1use common::Vector;
6use futures_util::{stream::FuturesUnordered, StreamExt};
7use moka::future::Cache;
8use std::sync::Arc;
9use std::time::Duration;
10
/// Tuning knobs for [`VectorCache`].
#[derive(Clone, Debug)]
pub struct CacheConfig {
    /// Capacity bound handed to the cache builder's `max_capacity`.
    pub max_capacity: u64,
    /// Time-to-live passed to the cache builder; `None` leaves TTL unconfigured.
    pub ttl: Option<Duration>,
    /// Time-to-idle passed to the cache builder; `None` leaves TTI unconfigured.
    pub tti: Option<Duration>,
}
21
22impl Default for CacheConfig {
23 fn default() -> Self {
24 Self {
25 max_capacity: 100_000,
26 ttl: Some(Duration::from_secs(3600)), tti: Some(Duration::from_secs(600)), }
29 }
30}
31
/// Composite cache key: a namespace plus the id of a vector inside it.
///
/// Both parts are stored as `Arc<str>`, so cloning a key only bumps
/// reference counts.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct CacheKey {
    /// Namespace the vector belongs to.
    pub namespace: Arc<str>,
    /// Identifier of the vector within `namespace`.
    pub vector_id: Arc<str>,
}
38
39impl CacheKey {
40 pub fn new(namespace: impl AsRef<str>, vector_id: impl AsRef<str>) -> Self {
41 Self {
42 namespace: Arc::from(namespace.as_ref()),
43 vector_id: Arc::from(vector_id.as_ref()),
44 }
45 }
46}
47
48#[derive(Clone)]
50pub struct VectorCache {
51 cache: Cache<CacheKey, Arc<Vector>>,
52 config: CacheConfig,
53}
54
55impl VectorCache {
56 pub fn new(config: CacheConfig) -> Self {
58 let mut builder = Cache::builder()
59 .max_capacity(config.max_capacity)
60 .support_invalidation_closures();
61
62 if let Some(ttl) = config.ttl {
63 builder = builder.time_to_live(ttl);
64 }
65
66 if let Some(tti) = config.tti {
67 builder = builder.time_to_idle(tti);
68 }
69
70 let cache = builder.build();
71
72 Self { cache, config }
73 }
74
75 pub fn with_defaults() -> Self {
77 Self::new(CacheConfig::default())
78 }
79
80 pub async fn get(&self, namespace: &str, vector_id: &str) -> Option<Arc<Vector>> {
82 let key = CacheKey::new(namespace, vector_id);
83 self.cache.get(&key).await
84 }
85
86 pub async fn insert(&self, namespace: &str, vector: Vector) {
88 let key = CacheKey::new(namespace, &vector.id);
89 self.cache.insert(key, Arc::new(vector)).await;
90 }
91
92 pub async fn insert_batch(&self, namespace: &str, vectors: Vec<Vector>) {
94 let mut futs: FuturesUnordered<_> = vectors
95 .into_iter()
96 .map(|v| self.insert(namespace, v))
97 .collect();
98 while futs.next().await.is_some() {}
99 }
100
101 pub async fn remove(&self, namespace: &str, vector_id: &str) {
103 let key = CacheKey::new(namespace, vector_id);
104 self.cache.remove(&key).await;
105 }
106
107 pub async fn remove_batch(&self, namespace: &str, vector_ids: &[String]) {
109 for id in vector_ids {
110 self.remove(namespace, id).await;
111 }
112 }
113
114 pub async fn invalidate_namespace(&self, namespace: &str) {
116 let ns: Arc<str> = Arc::from(namespace);
117 self.cache
118 .invalidate_entries_if(move |k, _v| *k.namespace == *ns)
119 .expect("invalidate_entries_if failed");
120 tracing::debug!(namespace = namespace, "Cache namespace invalidated");
121 }
122
123 pub fn clear(&self) {
125 self.cache.invalidate_all();
126 }
127
128 pub fn stats(&self) -> CacheStats {
130 CacheStats {
131 entry_count: self.cache.entry_count(),
132 weighted_size: self.cache.weighted_size(),
133 max_capacity: self.config.max_capacity,
134 }
135 }
136
137 pub async fn run_pending_tasks(&self) {
139 self.cache.run_pending_tasks().await;
140 }
141}
142
/// Point-in-time counters reported by [`VectorCache::stats`].
#[derive(Clone, Debug)]
pub struct CacheStats {
    /// Entry count as reported by the cache.
    pub entry_count: u64,
    /// Weighted size as reported by the cache.
    pub weighted_size: u64,
    /// Capacity the cache was configured with.
    pub max_capacity: u64,
}
153
154impl CacheStats {
155 pub fn utilization(&self) -> f64 {
157 if self.max_capacity == 0 {
158 return 0.0;
159 }
160 (self.entry_count as f64 / self.max_capacity as f64) * 100.0
161 }
162}
163
164pub struct CachedStorage<S> {
166 inner: S,
167 cache: VectorCache,
168 redis: Option<crate::RedisCache>,
169}
170
171impl<S> CachedStorage<S> {
172 pub fn new(inner: S, cache: VectorCache, redis: Option<crate::RedisCache>) -> Self {
173 Self {
174 inner,
175 cache,
176 redis,
177 }
178 }
179
180 pub fn with_default_cache(inner: S) -> Self {
181 Self::new(inner, VectorCache::with_defaults(), None)
182 }
183
184 pub fn cache(&self) -> &VectorCache {
185 &self.cache
186 }
187
188 pub fn inner(&self) -> &S {
189 &self.inner
190 }
191
192 pub fn redis(&self) -> Option<&crate::RedisCache> {
193 self.redis.as_ref()
194 }
195}
196
197#[async_trait::async_trait]
198impl<S: crate::VectorStorage> crate::VectorStorage for CachedStorage<S> {
199 async fn upsert(
200 &self,
201 namespace: &common::NamespaceId,
202 vectors: Vec<common::Vector>,
203 ) -> common::Result<usize> {
204 let count = self.inner.upsert(namespace, vectors.clone()).await?;
205 self.cache.insert_batch(namespace, vectors.clone()).await;
207 if let Some(ref redis) = self.redis {
209 redis.set_batch(namespace, &vectors).await;
210 let ids: Vec<String> = vectors.iter().map(|v| v.id.clone()).collect();
211 redis
212 .publish_invalidation(&crate::CacheInvalidation::Vectors {
213 namespace: namespace.to_string(),
214 ids,
215 })
216 .await;
217 }
218 Ok(count)
219 }
220
221 async fn get(
222 &self,
223 namespace: &common::NamespaceId,
224 ids: &[common::VectorId],
225 ) -> common::Result<Vec<common::Vector>> {
226 let mut found = Vec::new();
227 let mut missing_ids: Vec<String> = Vec::new();
228
229 for id in ids {
231 if let Some(v) = self.cache.get(namespace, id).await {
232 found.push((*v).clone());
233 } else {
234 missing_ids.push(id.clone());
235 }
236 }
237 if missing_ids.is_empty() {
238 return Ok(found);
239 }
240
241 if let Some(ref redis) = self.redis {
243 let from_redis = redis.get_multi(namespace, &missing_ids).await;
244 let redis_found_ids: std::collections::HashSet<String> =
245 from_redis.iter().map(|v| v.id.clone()).collect();
246 for v in &from_redis {
247 self.cache.insert(namespace, v.clone()).await; }
249 found.extend(from_redis);
250 missing_ids.retain(|id| !redis_found_ids.contains(id));
251 }
252 if missing_ids.is_empty() {
253 return Ok(found);
254 }
255
256 let from_store = self.inner.get(namespace, &missing_ids).await?;
258 for v in &from_store {
259 self.cache.insert(namespace, v.clone()).await; if let Some(ref redis) = self.redis {
261 redis.set(namespace, v).await; }
263 }
264 found.extend(from_store);
265 Ok(found)
266 }
267
268 async fn get_all(
269 &self,
270 namespace: &common::NamespaceId,
271 ) -> common::Result<Vec<common::Vector>> {
272 let vectors = self.inner.get_all(namespace).await?;
273 for v in &vectors {
275 self.cache.insert(namespace, v.clone()).await;
276 }
277 if let Some(ref redis) = self.redis {
279 redis.set_batch(namespace, &vectors).await;
280 }
281 Ok(vectors)
282 }
283
284 async fn delete(
285 &self,
286 namespace: &common::NamespaceId,
287 ids: &[common::VectorId],
288 ) -> common::Result<usize> {
289 let count = self.inner.delete(namespace, ids).await?;
290 self.cache.remove_batch(namespace, ids).await;
291 if let Some(ref redis) = self.redis {
292 let id_strings: Vec<String> = ids.iter().map(|s| s.to_string()).collect();
293 redis.delete(namespace, &id_strings).await;
294 redis
295 .publish_invalidation(&crate::CacheInvalidation::Vectors {
296 namespace: namespace.to_string(),
297 ids: id_strings,
298 })
299 .await;
300 }
301 Ok(count)
302 }
303
304 async fn namespace_exists(&self, namespace: &common::NamespaceId) -> common::Result<bool> {
305 self.inner.namespace_exists(namespace).await
306 }
307
308 async fn ensure_namespace(&self, namespace: &common::NamespaceId) -> common::Result<()> {
309 self.inner.ensure_namespace(namespace).await
310 }
311
312 async fn count(&self, namespace: &common::NamespaceId) -> common::Result<usize> {
313 self.inner.count(namespace).await
314 }
315
316 async fn dimension(&self, namespace: &common::NamespaceId) -> common::Result<Option<usize>> {
317 self.inner.dimension(namespace).await
318 }
319
320 async fn list_namespaces(&self) -> common::Result<Vec<common::NamespaceId>> {
321 self.inner.list_namespaces().await
322 }
323
324 async fn delete_namespace(&self, namespace: &common::NamespaceId) -> common::Result<bool> {
325 let result = self.inner.delete_namespace(namespace).await?;
326 self.cache.invalidate_namespace(namespace).await;
327 if let Some(ref redis) = self.redis {
328 redis.invalidate_namespace(namespace).await;
329 redis
330 .publish_invalidation(&crate::CacheInvalidation::Namespace(namespace.to_string()))
331 .await;
332 }
333 Ok(result)
334 }
335
336 async fn cleanup_expired(&self, namespace: &common::NamespaceId) -> common::Result<usize> {
337 self.inner.cleanup_expired(namespace).await
338 }
339
340 async fn cleanup_all_expired(&self) -> common::Result<usize> {
341 self.inner.cleanup_all_expired().await
342 }
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a vector with the given id and values and no metadata or expiry.
    fn make_vector(id: &str, values: Vec<f32>) -> Vector {
        Vector {
            id: id.to_string(),
            values,
            metadata: None,
            ttl_seconds: None,
            expires_at: None,
        }
    }

    /// An inserted vector can be read back with its id and values intact.
    #[tokio::test]
    async fn test_cache_insert_and_get() {
        let cache = VectorCache::with_defaults();

        cache
            .insert("test_ns", make_vector("v1", vec![1.0, 2.0, 3.0]))
            .await;

        let lookup = cache.get("test_ns", "v1").await;
        assert!(lookup.is_some());

        let hit = lookup.unwrap();
        assert_eq!(hit.id, "v1");
        assert_eq!(hit.values, vec![1.0, 2.0, 3.0]);
    }

    /// Looking up an id that was never inserted yields `None`.
    #[tokio::test]
    async fn test_cache_miss() {
        let cache = VectorCache::with_defaults();

        let lookup = cache.get("test_ns", "nonexistent").await;
        assert!(lookup.is_none());
    }

    /// A removed entry is no longer returned.
    #[tokio::test]
    async fn test_cache_remove() {
        let cache = VectorCache::with_defaults();

        cache
            .insert("test_ns", make_vector("v1", vec![1.0, 2.0, 3.0]))
            .await;
        assert!(cache.get("test_ns", "v1").await.is_some());

        cache.remove("test_ns", "v1").await;
        cache.run_pending_tasks().await;

        assert!(cache.get("test_ns", "v1").await.is_none());
    }

    /// Batch insert makes every vector visible; batch remove evicts only the
    /// listed ids.
    #[tokio::test]
    async fn test_cache_batch_operations() {
        let cache = VectorCache::with_defaults();

        let batch = vec![
            make_vector("v1", vec![1.0]),
            make_vector("v2", vec![2.0]),
            make_vector("v3", vec![3.0]),
        ];
        cache.insert_batch("test_ns", batch).await;

        for id in ["v1", "v2", "v3"] {
            assert!(cache.get("test_ns", id).await.is_some());
        }

        cache
            .remove_batch("test_ns", &["v1".to_string(), "v2".to_string()])
            .await;
        cache.run_pending_tasks().await;

        for id in ["v1", "v2"] {
            assert!(cache.get("test_ns", id).await.is_none());
        }
        assert!(cache.get("test_ns", "v3").await.is_some());
    }

    /// `stats` reports the configured capacity.
    #[tokio::test]
    async fn test_cache_stats() {
        let config = CacheConfig {
            max_capacity: 1000,
            ttl: None,
            tti: None,
        };
        let cache = VectorCache::new(config);

        for i in 0..10 {
            cache
                .insert("test_ns", make_vector(&format!("v{}", i), vec![i as f32]))
                .await;
        }

        for i in 0..10 {
            assert!(cache.get("test_ns", &format!("v{}", i)).await.is_some());
        }

        let stats = cache.stats();
        assert_eq!(stats.max_capacity, 1000);
    }

    /// The same vector id in two namespaces maps to two independent entries.
    #[tokio::test]
    async fn test_cache_namespace_isolation() {
        let cache = VectorCache::with_defaults();

        cache.insert("ns1", make_vector("same_id", vec![1.0])).await;
        cache.insert("ns2", make_vector("same_id", vec![2.0])).await;

        let from_ns1 = cache.get("ns1", "same_id").await.unwrap();
        let from_ns2 = cache.get("ns2", "same_id").await.unwrap();

        assert_eq!(from_ns1.values, vec![1.0]);
        assert_eq!(from_ns2.values, vec![2.0]);
    }

    /// `clear` drops every entry.
    #[tokio::test]
    async fn test_cache_clear() {
        let cache = VectorCache::with_defaults();

        for i in 0..5 {
            cache
                .insert("test_ns", make_vector(&format!("v{}", i), vec![i as f32]))
                .await;
        }

        for i in 0..5 {
            assert!(cache.get("test_ns", &format!("v{}", i)).await.is_some());
        }

        cache.clear();
        cache.run_pending_tasks().await;

        for i in 0..5 {
            assert!(cache.get("test_ns", &format!("v{}", i)).await.is_none());
        }
    }
}