// storage/cache.rs
1//! L1 In-Memory Cache using Moka
2//!
3//! High-performance concurrent cache for vectors with LRU eviction.
4
5use common::Vector;
6use futures_util::{stream::FuturesUnordered, StreamExt};
7use moka::future::Cache;
8use std::sync::Arc;
9use std::time::Duration;
10
/// Tuning knobs for the L1 cache.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of vectors to cache
    pub max_capacity: u64,
    /// Time-to-live for cached entries
    pub ttl: Option<Duration>,
    /// Time-to-idle for cached entries (evict if not accessed)
    pub tti: Option<Duration>,
}

impl Default for CacheConfig {
    /// Defaults: 100k entries, one hour TTL, ten minutes idle eviction.
    fn default() -> Self {
        CacheConfig {
            max_capacity: 100_000,
            // Entries live at most one hour...
            ttl: Some(Duration::from_secs(60 * 60)),
            // ...and are evicted after ten minutes without access.
            tti: Some(Duration::from_secs(10 * 60)),
        }
    }
}
31
32/// Cache key combining namespace and vector ID
/// Composite cache key: (namespace, vector id).
///
/// Fields are `Arc<str>` so cloning a key is a refcount bump, not a copy.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct CacheKey {
    pub namespace: Arc<str>,
    pub vector_id: Arc<str>,
}

impl CacheKey {
    /// Build a key from any string-like namespace and vector id.
    pub fn new(namespace: impl AsRef<str>, vector_id: impl AsRef<str>) -> Self {
        let ns: Arc<str> = Arc::from(namespace.as_ref());
        let id: Arc<str> = Arc::from(vector_id.as_ref());
        CacheKey {
            namespace: ns,
            vector_id: id,
        }
    }
}
47
/// L1 in-memory vector cache
///
/// Thin wrapper around a moka async `Cache` keyed by (namespace, vector id).
/// `Clone` is cheap: moka's `Cache` is internally reference-counted, so all
/// clones share the same underlying storage.
#[derive(Clone)]
pub struct VectorCache {
    // Values are Arc'd so gets hand out shared references without copying vectors.
    cache: Cache<CacheKey, Arc<Vector>>,
    // Retained so `stats()` can report the configured max capacity.
    config: CacheConfig,
}
54
55impl VectorCache {
56    /// Create a new cache with the given configuration
57    pub fn new(config: CacheConfig) -> Self {
58        let mut builder = Cache::builder().max_capacity(config.max_capacity);
59
60        if let Some(ttl) = config.ttl {
61            builder = builder.time_to_live(ttl);
62        }
63
64        if let Some(tti) = config.tti {
65            builder = builder.time_to_idle(tti);
66        }
67
68        let cache = builder.build();
69
70        Self { cache, config }
71    }
72
73    /// Create with default configuration
74    pub fn with_defaults() -> Self {
75        Self::new(CacheConfig::default())
76    }
77
78    /// Get a vector from the cache
79    pub async fn get(&self, namespace: &str, vector_id: &str) -> Option<Arc<Vector>> {
80        let key = CacheKey::new(namespace, vector_id);
81        self.cache.get(&key).await
82    }
83
84    /// Insert a vector into the cache
85    pub async fn insert(&self, namespace: &str, vector: Vector) {
86        let key = CacheKey::new(namespace, &vector.id);
87        self.cache.insert(key, Arc::new(vector)).await;
88    }
89
90    /// Insert multiple vectors into the cache
91    pub async fn insert_batch(&self, namespace: &str, vectors: Vec<Vector>) {
92        let mut futs: FuturesUnordered<_> = vectors
93            .into_iter()
94            .map(|v| self.insert(namespace, v))
95            .collect();
96        while futs.next().await.is_some() {}
97    }
98
99    /// Remove a vector from the cache
100    pub async fn remove(&self, namespace: &str, vector_id: &str) {
101        let key = CacheKey::new(namespace, vector_id);
102        self.cache.remove(&key).await;
103    }
104
105    /// Remove multiple vectors from the cache
106    pub async fn remove_batch(&self, namespace: &str, vector_ids: &[String]) {
107        for id in vector_ids {
108            self.remove(namespace, id).await;
109        }
110    }
111
112    /// Invalidate all entries for a namespace.
113    pub async fn invalidate_namespace(&self, namespace: &str) {
114        let ns: Arc<str> = Arc::from(namespace);
115        self.cache
116            .invalidate_entries_if(move |k, _v| *k.namespace == *ns)
117            .expect("invalidate_entries_if failed");
118        tracing::debug!(namespace = namespace, "Cache namespace invalidated");
119    }
120
121    /// Clear the entire cache
122    pub fn clear(&self) {
123        self.cache.invalidate_all();
124    }
125
126    /// Get cache statistics
127    pub fn stats(&self) -> CacheStats {
128        CacheStats {
129            entry_count: self.cache.entry_count(),
130            weighted_size: self.cache.weighted_size(),
131            max_capacity: self.config.max_capacity,
132        }
133    }
134
135    /// Run pending maintenance tasks (eviction, etc.)
136    pub async fn run_pending_tasks(&self) {
137        self.cache.run_pending_tasks().await;
138    }
139}
140
/// Point-in-time cache statistics snapshot.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of entries in the cache
    pub entry_count: u64,
    /// Weighted size of the cache
    pub weighted_size: u64,
    /// Maximum capacity
    pub max_capacity: u64,
}

impl CacheStats {
    /// Cache utilization as a percentage of `max_capacity`.
    ///
    /// Returns 0.0 when the capacity is zero, avoiding a NaN from a
    /// zero division.
    pub fn utilization(&self) -> f64 {
        match self.max_capacity {
            0 => 0.0,
            cap => (self.entry_count as f64 / cap as f64) * 100.0,
        }
    }
}
161
/// Cached storage wrapper that adds L1 caching to any VectorStorage
///
/// Tiers, fastest first: an in-process L1 moka cache, an optional
/// "L1.5" Redis cache shared between replicas, then the backing store `S`.
pub struct CachedStorage<S> {
    // Authoritative backing store; misses and writes always go through it.
    inner: S,
    // L1: per-process moka cache.
    cache: VectorCache,
    // L1.5: optional shared Redis tier; `None` disables it entirely.
    redis: Option<crate::RedisCache>,
}
168
169impl<S> CachedStorage<S> {
170    pub fn new(inner: S, cache: VectorCache, redis: Option<crate::RedisCache>) -> Self {
171        Self {
172            inner,
173            cache,
174            redis,
175        }
176    }
177
178    pub fn with_default_cache(inner: S) -> Self {
179        Self::new(inner, VectorCache::with_defaults(), None)
180    }
181
182    pub fn cache(&self) -> &VectorCache {
183        &self.cache
184    }
185
186    pub fn inner(&self) -> &S {
187        &self.inner
188    }
189
190    pub fn redis(&self) -> Option<&crate::RedisCache> {
191        self.redis.as_ref()
192    }
193}
194
195#[async_trait::async_trait]
196impl<S: crate::VectorStorage> crate::VectorStorage for CachedStorage<S> {
197    async fn upsert(
198        &self,
199        namespace: &common::NamespaceId,
200        vectors: Vec<common::Vector>,
201    ) -> common::Result<usize> {
202        let count = self.inner.upsert(namespace, vectors.clone()).await?;
203        // Populate L1 cache with upserted vectors
204        self.cache.insert_batch(namespace, vectors.clone()).await;
205        // Populate L1.5 Redis and publish invalidation for HA peers
206        if let Some(ref redis) = self.redis {
207            redis.set_batch(namespace, &vectors).await;
208            let ids: Vec<String> = vectors.iter().map(|v| v.id.clone()).collect();
209            redis
210                .publish_invalidation(&crate::CacheInvalidation::Vectors {
211                    namespace: namespace.to_string(),
212                    ids,
213                })
214                .await;
215        }
216        Ok(count)
217    }
218
219    async fn get(
220        &self,
221        namespace: &common::NamespaceId,
222        ids: &[common::VectorId],
223    ) -> common::Result<Vec<common::Vector>> {
224        let mut found = Vec::new();
225        let mut missing_ids: Vec<String> = Vec::new();
226
227        // Check L1 Moka first
228        for id in ids {
229            if let Some(v) = self.cache.get(namespace, id).await {
230                found.push((*v).clone());
231            } else {
232                missing_ids.push(id.clone());
233            }
234        }
235        if missing_ids.is_empty() {
236            return Ok(found);
237        }
238
239        // Check L1.5 Redis
240        if let Some(ref redis) = self.redis {
241            let from_redis = redis.get_multi(namespace, &missing_ids).await;
242            let redis_found_ids: std::collections::HashSet<String> =
243                from_redis.iter().map(|v| v.id.clone()).collect();
244            for v in &from_redis {
245                self.cache.insert(namespace, v.clone()).await; // backfill L1
246            }
247            found.extend(from_redis);
248            missing_ids.retain(|id| !redis_found_ids.contains(id));
249        }
250        if missing_ids.is_empty() {
251            return Ok(found);
252        }
253
254        // Fall through to backing store
255        let from_store = self.inner.get(namespace, &missing_ids).await?;
256        for v in &from_store {
257            self.cache.insert(namespace, v.clone()).await; // backfill L1
258            if let Some(ref redis) = self.redis {
259                redis.set(namespace, v).await; // backfill L1.5
260            }
261        }
262        found.extend(from_store);
263        Ok(found)
264    }
265
266    async fn get_all(
267        &self,
268        namespace: &common::NamespaceId,
269    ) -> common::Result<Vec<common::Vector>> {
270        let vectors = self.inner.get_all(namespace).await?;
271        // Backfill L1 cache so subsequent individual get() calls will hit
272        for v in &vectors {
273            self.cache.insert(namespace, v.clone()).await;
274        }
275        // Backfill L1.5 Redis if configured
276        if let Some(ref redis) = self.redis {
277            redis.set_batch(namespace, &vectors).await;
278        }
279        Ok(vectors)
280    }
281
282    async fn delete(
283        &self,
284        namespace: &common::NamespaceId,
285        ids: &[common::VectorId],
286    ) -> common::Result<usize> {
287        let count = self.inner.delete(namespace, ids).await?;
288        self.cache.remove_batch(namespace, ids).await;
289        if let Some(ref redis) = self.redis {
290            let id_strings: Vec<String> = ids.iter().map(|s| s.to_string()).collect();
291            redis.delete(namespace, &id_strings).await;
292            redis
293                .publish_invalidation(&crate::CacheInvalidation::Vectors {
294                    namespace: namespace.to_string(),
295                    ids: id_strings,
296                })
297                .await;
298        }
299        Ok(count)
300    }
301
302    async fn namespace_exists(&self, namespace: &common::NamespaceId) -> common::Result<bool> {
303        self.inner.namespace_exists(namespace).await
304    }
305
306    async fn ensure_namespace(&self, namespace: &common::NamespaceId) -> common::Result<()> {
307        self.inner.ensure_namespace(namespace).await
308    }
309
310    async fn count(&self, namespace: &common::NamespaceId) -> common::Result<usize> {
311        self.inner.count(namespace).await
312    }
313
314    async fn dimension(&self, namespace: &common::NamespaceId) -> common::Result<Option<usize>> {
315        self.inner.dimension(namespace).await
316    }
317
318    async fn list_namespaces(&self) -> common::Result<Vec<common::NamespaceId>> {
319        self.inner.list_namespaces().await
320    }
321
322    async fn delete_namespace(&self, namespace: &common::NamespaceId) -> common::Result<bool> {
323        let result = self.inner.delete_namespace(namespace).await?;
324        self.cache.invalidate_namespace(namespace).await;
325        if let Some(ref redis) = self.redis {
326            redis.invalidate_namespace(namespace).await;
327            redis
328                .publish_invalidation(&crate::CacheInvalidation::Namespace(namespace.to_string()))
329                .await;
330        }
331        Ok(result)
332    }
333
334    async fn cleanup_expired(&self, namespace: &common::NamespaceId) -> common::Result<usize> {
335        self.inner.cleanup_expired(namespace).await
336    }
337
338    async fn cleanup_all_expired(&self) -> common::Result<usize> {
339        self.inner.cleanup_all_expired().await
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a test vector with no metadata and no expiry — the five-field
    /// literal was previously repeated in every test.
    fn test_vector(id: &str, values: Vec<f32>) -> Vector {
        Vector {
            id: id.to_string(),
            values,
            metadata: None,
            ttl_seconds: None,
            expires_at: None,
        }
    }

    #[tokio::test]
    async fn test_cache_insert_and_get() {
        let cache = VectorCache::with_defaults();

        cache
            .insert("test_ns", test_vector("v1", vec![1.0, 2.0, 3.0]))
            .await;

        let retrieved = cache.get("test_ns", "v1").await;
        assert!(retrieved.is_some());

        let retrieved = retrieved.unwrap();
        assert_eq!(retrieved.id, "v1");
        assert_eq!(retrieved.values, vec![1.0, 2.0, 3.0]);
    }

    #[tokio::test]
    async fn test_cache_miss() {
        let cache = VectorCache::with_defaults();

        let retrieved = cache.get("test_ns", "nonexistent").await;
        assert!(retrieved.is_none());
    }

    #[tokio::test]
    async fn test_cache_remove() {
        let cache = VectorCache::with_defaults();

        cache
            .insert("test_ns", test_vector("v1", vec![1.0, 2.0, 3.0]))
            .await;
        assert!(cache.get("test_ns", "v1").await.is_some());

        cache.remove("test_ns", "v1").await;
        cache.run_pending_tasks().await;

        assert!(cache.get("test_ns", "v1").await.is_none());
    }

    #[tokio::test]
    async fn test_cache_batch_operations() {
        let cache = VectorCache::with_defaults();

        let vectors = vec![
            test_vector("v1", vec![1.0]),
            test_vector("v2", vec![2.0]),
            test_vector("v3", vec![3.0]),
        ];

        cache.insert_batch("test_ns", vectors).await;

        assert!(cache.get("test_ns", "v1").await.is_some());
        assert!(cache.get("test_ns", "v2").await.is_some());
        assert!(cache.get("test_ns", "v3").await.is_some());

        cache
            .remove_batch("test_ns", &["v1".to_string(), "v2".to_string()])
            .await;
        cache.run_pending_tasks().await;

        assert!(cache.get("test_ns", "v1").await.is_none());
        assert!(cache.get("test_ns", "v2").await.is_none());
        assert!(cache.get("test_ns", "v3").await.is_some());
    }

    #[tokio::test]
    async fn test_cache_stats() {
        let cache = VectorCache::new(CacheConfig {
            max_capacity: 1000,
            ttl: None,
            tti: None,
        });

        for i in 0..10 {
            cache
                .insert("test_ns", test_vector(&format!("v{}", i), vec![i as f32]))
                .await;
        }

        // Verify entries are retrievable
        for i in 0..10 {
            assert!(cache.get("test_ns", &format!("v{}", i)).await.is_some());
        }

        let stats = cache.stats();
        assert_eq!(stats.max_capacity, 1000);
    }

    #[tokio::test]
    async fn test_cache_namespace_isolation() {
        let cache = VectorCache::with_defaults();

        // Same vector id in two namespaces must not collide.
        cache.insert("ns1", test_vector("same_id", vec![1.0])).await;
        cache.insert("ns2", test_vector("same_id", vec![2.0])).await;

        let from_ns1 = cache.get("ns1", "same_id").await.unwrap();
        let from_ns2 = cache.get("ns2", "same_id").await.unwrap();

        assert_eq!(from_ns1.values, vec![1.0]);
        assert_eq!(from_ns2.values, vec![2.0]);
    }

    #[tokio::test]
    async fn test_cache_clear() {
        let cache = VectorCache::with_defaults();

        for i in 0..5 {
            cache
                .insert("test_ns", test_vector(&format!("v{}", i), vec![i as f32]))
                .await;
        }

        // Verify entries exist before clear
        for i in 0..5 {
            assert!(cache.get("test_ns", &format!("v{}", i)).await.is_some());
        }

        cache.clear();
        cache.run_pending_tasks().await;

        // Verify entries are gone after clear
        for i in 0..5 {
            assert!(cache.get("test_ns", &format!("v{}", i)).await.is_none());
        }
    }
}