//! L1 In-Memory Cache using Moka
//!
//! High-performance concurrent cache for vectors with LRU eviction.

use common::Vector;
use futures_util::{stream::FuturesUnordered, StreamExt};
use moka::future::Cache;
use std::sync::Arc;
use std::time::Duration;

/// Tunable parameters for the L1 cache.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of vectors to cache
    pub max_capacity: u64,
    /// Time-to-live for cached entries
    pub ttl: Option<Duration>,
    /// Time-to-idle for cached entries (evict if not accessed)
    pub tti: Option<Duration>,
}

impl Default for CacheConfig {
    /// Defaults: 100k entries, 1 hour TTL, 10 minutes idle eviction.
    fn default() -> Self {
        Self {
            max_capacity: 100_000,
            ttl: Some(Duration::from_secs(60 * 60)),
            tti: Some(Duration::from_secs(10 * 60)),
        }
    }
}

/// Composite cache key: (namespace, vector id).
///
/// Both components are `Arc<str>` so cloning a key is a refcount bump,
/// not a string copy.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct CacheKey {
    pub namespace: Arc<str>,
    pub vector_id: Arc<str>,
}

impl CacheKey {
    /// Build a key from any string-like namespace and vector id.
    pub fn new(namespace: impl AsRef<str>, vector_id: impl AsRef<str>) -> Self {
        let namespace: Arc<str> = Arc::from(namespace.as_ref());
        let vector_id: Arc<str> = Arc::from(vector_id.as_ref());
        Self {
            namespace,
            vector_id,
        }
    }
}

48/// L1 in-memory vector cache
49#[derive(Clone)]
50pub struct VectorCache {
51    cache: Cache<CacheKey, Arc<Vector>>,
52    config: CacheConfig,
53}
54
55impl VectorCache {
56    /// Create a new cache with the given configuration
57    pub fn new(config: CacheConfig) -> Self {
58        let mut builder = Cache::builder()
59            .max_capacity(config.max_capacity)
60            .support_invalidation_closures();
61
62        if let Some(ttl) = config.ttl {
63            builder = builder.time_to_live(ttl);
64        }
65
66        if let Some(tti) = config.tti {
67            builder = builder.time_to_idle(tti);
68        }
69
70        let cache = builder.build();
71
72        Self { cache, config }
73    }
74
75    /// Create with default configuration
76    pub fn with_defaults() -> Self {
77        Self::new(CacheConfig::default())
78    }
79
80    /// Get a vector from the cache
81    pub async fn get(&self, namespace: &str, vector_id: &str) -> Option<Arc<Vector>> {
82        let key = CacheKey::new(namespace, vector_id);
83        self.cache.get(&key).await
84    }
85
86    /// Insert a vector into the cache
87    pub async fn insert(&self, namespace: &str, vector: Vector) {
88        let key = CacheKey::new(namespace, &vector.id);
89        self.cache.insert(key, Arc::new(vector)).await;
90    }
91
92    /// Insert multiple vectors into the cache
93    pub async fn insert_batch(&self, namespace: &str, vectors: Vec<Vector>) {
94        let mut futs: FuturesUnordered<_> = vectors
95            .into_iter()
96            .map(|v| self.insert(namespace, v))
97            .collect();
98        while futs.next().await.is_some() {}
99    }
100
101    /// Remove a vector from the cache
102    pub async fn remove(&self, namespace: &str, vector_id: &str) {
103        let key = CacheKey::new(namespace, vector_id);
104        self.cache.remove(&key).await;
105    }
106
107    /// Remove multiple vectors from the cache
108    pub async fn remove_batch(&self, namespace: &str, vector_ids: &[String]) {
109        for id in vector_ids {
110            self.remove(namespace, id).await;
111        }
112    }
113
114    /// Invalidate all entries for a namespace.
115    pub async fn invalidate_namespace(&self, namespace: &str) {
116        let ns: Arc<str> = Arc::from(namespace);
117        self.cache
118            .invalidate_entries_if(move |k, _v| *k.namespace == *ns)
119            .expect("invalidate_entries_if failed");
120        tracing::debug!(namespace = namespace, "Cache namespace invalidated");
121    }
122
123    /// Clear the entire cache
124    pub fn clear(&self) {
125        self.cache.invalidate_all();
126    }
127
128    /// Get cache statistics
129    pub fn stats(&self) -> CacheStats {
130        CacheStats {
131            entry_count: self.cache.entry_count(),
132            weighted_size: self.cache.weighted_size(),
133            max_capacity: self.config.max_capacity,
134        }
135    }
136
137    /// Run pending maintenance tasks (eviction, etc.)
138    pub async fn run_pending_tasks(&self) {
139        self.cache.run_pending_tasks().await;
140    }
141}
142
/// Point-in-time cache statistics.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of entries in the cache
    pub entry_count: u64,
    /// Weighted size of the cache
    pub weighted_size: u64,
    /// Maximum capacity
    pub max_capacity: u64,
}

impl CacheStats {
    /// Cache utilization as a percentage of `max_capacity`
    /// (returns 0.0 when the capacity is 0 to avoid dividing by zero).
    pub fn utilization(&self) -> f64 {
        match self.max_capacity {
            0 => 0.0,
            cap => self.entry_count as f64 / cap as f64 * 100.0,
        }
    }
}

164/// Cached storage wrapper that adds L1 caching to any VectorStorage
165pub struct CachedStorage<S> {
166    inner: S,
167    cache: VectorCache,
168    redis: Option<crate::RedisCache>,
169}
170
171impl<S> CachedStorage<S> {
172    pub fn new(inner: S, cache: VectorCache, redis: Option<crate::RedisCache>) -> Self {
173        Self {
174            inner,
175            cache,
176            redis,
177        }
178    }
179
180    pub fn with_default_cache(inner: S) -> Self {
181        Self::new(inner, VectorCache::with_defaults(), None)
182    }
183
184    pub fn cache(&self) -> &VectorCache {
185        &self.cache
186    }
187
188    pub fn inner(&self) -> &S {
189        &self.inner
190    }
191
192    pub fn redis(&self) -> Option<&crate::RedisCache> {
193        self.redis.as_ref()
194    }
195}
196
197#[async_trait::async_trait]
198impl<S: crate::VectorStorage> crate::VectorStorage for CachedStorage<S> {
199    async fn upsert(
200        &self,
201        namespace: &common::NamespaceId,
202        vectors: Vec<common::Vector>,
203    ) -> common::Result<usize> {
204        let count = self.inner.upsert(namespace, vectors.clone()).await?;
205        // Populate L1 cache with upserted vectors
206        self.cache.insert_batch(namespace, vectors.clone()).await;
207        // Populate L1.5 Redis and publish invalidation for HA peers
208        if let Some(ref redis) = self.redis {
209            redis.set_batch(namespace, &vectors).await;
210            let ids: Vec<String> = vectors.iter().map(|v| v.id.clone()).collect();
211            redis
212                .publish_invalidation(&crate::CacheInvalidation::Vectors {
213                    namespace: namespace.to_string(),
214                    ids,
215                })
216                .await;
217        }
218        Ok(count)
219    }
220
221    async fn get(
222        &self,
223        namespace: &common::NamespaceId,
224        ids: &[common::VectorId],
225    ) -> common::Result<Vec<common::Vector>> {
226        let mut found = Vec::new();
227        let mut missing_ids: Vec<String> = Vec::new();
228
229        // Check L1 Moka first
230        for id in ids {
231            if let Some(v) = self.cache.get(namespace, id).await {
232                found.push((*v).clone());
233            } else {
234                missing_ids.push(id.clone());
235            }
236        }
237        if missing_ids.is_empty() {
238            return Ok(found);
239        }
240
241        // Check L1.5 Redis
242        if let Some(ref redis) = self.redis {
243            let from_redis = redis.get_multi(namespace, &missing_ids).await;
244            let redis_found_ids: std::collections::HashSet<String> =
245                from_redis.iter().map(|v| v.id.clone()).collect();
246            for v in &from_redis {
247                self.cache.insert(namespace, v.clone()).await; // backfill L1
248            }
249            found.extend(from_redis);
250            missing_ids.retain(|id| !redis_found_ids.contains(id));
251        }
252        if missing_ids.is_empty() {
253            return Ok(found);
254        }
255
256        // Fall through to backing store
257        let from_store = self.inner.get(namespace, &missing_ids).await?;
258        for v in &from_store {
259            self.cache.insert(namespace, v.clone()).await; // backfill L1
260            if let Some(ref redis) = self.redis {
261                redis.set(namespace, v).await; // backfill L1.5
262            }
263        }
264        found.extend(from_store);
265        Ok(found)
266    }
267
268    async fn get_all(
269        &self,
270        namespace: &common::NamespaceId,
271    ) -> common::Result<Vec<common::Vector>> {
272        let vectors = self.inner.get_all(namespace).await?;
273        // Backfill L1 cache so subsequent individual get() calls will hit
274        for v in &vectors {
275            self.cache.insert(namespace, v.clone()).await;
276        }
277        // Backfill L1.5 Redis if configured
278        if let Some(ref redis) = self.redis {
279            redis.set_batch(namespace, &vectors).await;
280        }
281        Ok(vectors)
282    }
283
284    async fn delete(
285        &self,
286        namespace: &common::NamespaceId,
287        ids: &[common::VectorId],
288    ) -> common::Result<usize> {
289        let count = self.inner.delete(namespace, ids).await?;
290        self.cache.remove_batch(namespace, ids).await;
291        if let Some(ref redis) = self.redis {
292            let id_strings: Vec<String> = ids.iter().map(|s| s.to_string()).collect();
293            redis.delete(namespace, &id_strings).await;
294            redis
295                .publish_invalidation(&crate::CacheInvalidation::Vectors {
296                    namespace: namespace.to_string(),
297                    ids: id_strings,
298                })
299                .await;
300        }
301        Ok(count)
302    }
303
304    async fn namespace_exists(&self, namespace: &common::NamespaceId) -> common::Result<bool> {
305        self.inner.namespace_exists(namespace).await
306    }
307
308    async fn ensure_namespace(&self, namespace: &common::NamespaceId) -> common::Result<()> {
309        self.inner.ensure_namespace(namespace).await
310    }
311
312    async fn count(&self, namespace: &common::NamespaceId) -> common::Result<usize> {
313        self.inner.count(namespace).await
314    }
315
316    async fn dimension(&self, namespace: &common::NamespaceId) -> common::Result<Option<usize>> {
317        self.inner.dimension(namespace).await
318    }
319
320    async fn list_namespaces(&self) -> common::Result<Vec<common::NamespaceId>> {
321        self.inner.list_namespaces().await
322    }
323
324    async fn delete_namespace(&self, namespace: &common::NamespaceId) -> common::Result<bool> {
325        let result = self.inner.delete_namespace(namespace).await?;
326        self.cache.invalidate_namespace(namespace).await;
327        if let Some(ref redis) = self.redis {
328            redis.invalidate_namespace(namespace).await;
329            redis
330                .publish_invalidation(&crate::CacheInvalidation::Namespace(namespace.to_string()))
331                .await;
332        }
333        Ok(result)
334    }
335
336    async fn cleanup_expired(&self, namespace: &common::NamespaceId) -> common::Result<usize> {
337        self.inner.cleanup_expired(namespace).await
338    }
339
340    async fn cleanup_all_expired(&self) -> common::Result<usize> {
341        self.inner.cleanup_all_expired().await
342    }
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal test vector with the given id and values.
    fn vec_of(id: &str, values: Vec<f32>) -> Vector {
        Vector {
            id: id.to_string(),
            values,
            metadata: None,
            ttl_seconds: None,
            expires_at: None,
        }
    }

    #[tokio::test]
    async fn test_cache_insert_and_get() {
        let cache = VectorCache::with_defaults();

        cache
            .insert("test_ns", vec_of("v1", vec![1.0, 2.0, 3.0]))
            .await;

        let retrieved = cache.get("test_ns", "v1").await;
        assert!(retrieved.is_some());

        let retrieved = retrieved.unwrap();
        assert_eq!(retrieved.id, "v1");
        assert_eq!(retrieved.values, vec![1.0, 2.0, 3.0]);
    }

    #[tokio::test]
    async fn test_cache_miss() {
        let cache = VectorCache::with_defaults();
        assert!(cache.get("test_ns", "nonexistent").await.is_none());
    }

    #[tokio::test]
    async fn test_cache_remove() {
        let cache = VectorCache::with_defaults();

        cache
            .insert("test_ns", vec_of("v1", vec![1.0, 2.0, 3.0]))
            .await;
        assert!(cache.get("test_ns", "v1").await.is_some());

        cache.remove("test_ns", "v1").await;
        cache.run_pending_tasks().await;

        assert!(cache.get("test_ns", "v1").await.is_none());
    }

    #[tokio::test]
    async fn test_cache_batch_operations() {
        let cache = VectorCache::with_defaults();

        let batch = vec![
            vec_of("v1", vec![1.0]),
            vec_of("v2", vec![2.0]),
            vec_of("v3", vec![3.0]),
        ];
        cache.insert_batch("test_ns", batch).await;

        for id in ["v1", "v2", "v3"] {
            assert!(cache.get("test_ns", id).await.is_some());
        }

        cache
            .remove_batch("test_ns", &["v1".to_string(), "v2".to_string()])
            .await;
        cache.run_pending_tasks().await;

        assert!(cache.get("test_ns", "v1").await.is_none());
        assert!(cache.get("test_ns", "v2").await.is_none());
        assert!(cache.get("test_ns", "v3").await.is_some());
    }

    #[tokio::test]
    async fn test_cache_stats() {
        let cache = VectorCache::new(CacheConfig {
            max_capacity: 1000,
            ttl: None,
            tti: None,
        });

        for i in 0..10 {
            cache
                .insert("test_ns", vec_of(&format!("v{}", i), vec![i as f32]))
                .await;
        }

        // Verify entries are retrievable.
        for i in 0..10 {
            assert!(cache.get("test_ns", &format!("v{}", i)).await.is_some());
        }

        assert_eq!(cache.stats().max_capacity, 1000);
    }

    #[tokio::test]
    async fn test_cache_namespace_isolation() {
        let cache = VectorCache::with_defaults();

        // Same vector id in two namespaces must not collide.
        cache.insert("ns1", vec_of("same_id", vec![1.0])).await;
        cache.insert("ns2", vec_of("same_id", vec![2.0])).await;

        let from_ns1 = cache.get("ns1", "same_id").await.unwrap();
        let from_ns2 = cache.get("ns2", "same_id").await.unwrap();

        assert_eq!(from_ns1.values, vec![1.0]);
        assert_eq!(from_ns2.values, vec![2.0]);
    }

    #[tokio::test]
    async fn test_cache_clear() {
        let cache = VectorCache::with_defaults();

        for i in 0..5 {
            cache
                .insert("test_ns", vec_of(&format!("v{}", i), vec![i as f32]))
                .await;
        }

        // Verify entries exist before clear.
        for i in 0..5 {
            assert!(cache.get("test_ns", &format!("v{}", i)).await.is_some());
        }

        cache.clear();
        cache.run_pending_tasks().await;

        // Verify entries are gone after clear.
        for i in 0..5 {
            assert!(cache.get("test_ns", &format!("v{}", i)).await.is_none());
        }
    }
}