1use common::Vector;
6use futures_util::{stream::FuturesUnordered, StreamExt};
7use moka::future::Cache;
8use std::sync::Arc;
9use std::time::Duration;
10
/// Tuning knobs for the in-process vector cache.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Capacity bound handed to the cache builder's `max_capacity`.
    pub max_capacity: u64,
    /// Time-to-live for entries; `None` means no TTL is configured on the builder.
    pub ttl: Option<Duration>,
    /// Time-to-idle for entries; `None` means no TTI is configured on the builder.
    pub tti: Option<Duration>,
}

impl Default for CacheConfig {
    /// Returns a configuration with a capacity of 100,000, a one-hour TTL and
    /// a ten-minute TTI.
    fn default() -> Self {
        let one_hour = Duration::from_secs(3600);
        let ten_minutes = Duration::from_secs(600);

        CacheConfig {
            max_capacity: 100_000,
            ttl: Some(one_hour),
            tti: Some(ten_minutes),
        }
    }
}
31
/// Composite lookup key: a vector id scoped by the namespace it belongs to.
///
/// Equality and hashing are derived over both fields, so the same id in two
/// different namespaces yields two distinct keys.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct CacheKey {
    /// Namespace the vector belongs to.
    pub namespace: Arc<str>,
    /// Identifier of the vector inside that namespace.
    pub vector_id: Arc<str>,
}

impl CacheKey {
    /// Builds a key by copying both string arguments into new `Arc<str>` values.
    pub fn new(namespace: impl AsRef<str>, vector_id: impl AsRef<str>) -> Self {
        let ns_text: &str = namespace.as_ref();
        let id_text: &str = vector_id.as_ref();

        CacheKey {
            namespace: Arc::from(ns_text),
            vector_id: Arc::from(id_text),
        }
    }
}
47
48#[derive(Clone)]
50pub struct VectorCache {
51 cache: Cache<CacheKey, Arc<Vector>>,
52 config: CacheConfig,
53}
54
55impl VectorCache {
56 pub fn new(config: CacheConfig) -> Self {
58 let mut builder = Cache::builder().max_capacity(config.max_capacity);
59
60 if let Some(ttl) = config.ttl {
61 builder = builder.time_to_live(ttl);
62 }
63
64 if let Some(tti) = config.tti {
65 builder = builder.time_to_idle(tti);
66 }
67
68 let cache = builder.build();
69
70 Self { cache, config }
71 }
72
73 pub fn with_defaults() -> Self {
75 Self::new(CacheConfig::default())
76 }
77
78 pub async fn get(&self, namespace: &str, vector_id: &str) -> Option<Arc<Vector>> {
80 let key = CacheKey::new(namespace, vector_id);
81 self.cache.get(&key).await
82 }
83
84 pub async fn insert(&self, namespace: &str, vector: Vector) {
86 let key = CacheKey::new(namespace, &vector.id);
87 self.cache.insert(key, Arc::new(vector)).await;
88 }
89
90 pub async fn insert_batch(&self, namespace: &str, vectors: Vec<Vector>) {
92 let mut futs: FuturesUnordered<_> = vectors
93 .into_iter()
94 .map(|v| self.insert(namespace, v))
95 .collect();
96 while futs.next().await.is_some() {}
97 }
98
99 pub async fn remove(&self, namespace: &str, vector_id: &str) {
101 let key = CacheKey::new(namespace, vector_id);
102 self.cache.remove(&key).await;
103 }
104
105 pub async fn remove_batch(&self, namespace: &str, vector_ids: &[String]) {
107 for id in vector_ids {
108 self.remove(namespace, id).await;
109 }
110 }
111
112 pub async fn invalidate_namespace(&self, namespace: &str) {
114 let ns: Arc<str> = Arc::from(namespace);
115 self.cache
116 .invalidate_entries_if(move |k, _v| *k.namespace == *ns)
117 .expect("invalidate_entries_if failed");
118 tracing::debug!(namespace = namespace, "Cache namespace invalidated");
119 }
120
121 pub fn clear(&self) {
123 self.cache.invalidate_all();
124 }
125
126 pub fn stats(&self) -> CacheStats {
128 CacheStats {
129 entry_count: self.cache.entry_count(),
130 weighted_size: self.cache.weighted_size(),
131 max_capacity: self.config.max_capacity,
132 }
133 }
134
135 pub async fn run_pending_tasks(&self) {
137 self.cache.run_pending_tasks().await;
138 }
139}
140
/// Point-in-time counters describing a cache.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of entries reported by the cache.
    pub entry_count: u64,
    /// Weighted size reported by the cache.
    pub weighted_size: u64,
    /// Capacity the cache was configured with.
    pub max_capacity: u64,
}

impl CacheStats {
    /// Returns `entry_count` as a percentage of `max_capacity`.
    ///
    /// A `max_capacity` of zero yields `0.0` instead of dividing by zero.
    pub fn utilization(&self) -> f64 {
        match self.max_capacity {
            0 => 0.0,
            capacity => {
                let filled = self.entry_count as f64;
                (filled / capacity as f64) * 100.0
            }
        }
    }
}
161
162pub struct CachedStorage<S> {
164 inner: S,
165 cache: VectorCache,
166 redis: Option<crate::RedisCache>,
167}
168
169impl<S> CachedStorage<S> {
170 pub fn new(inner: S, cache: VectorCache, redis: Option<crate::RedisCache>) -> Self {
171 Self {
172 inner,
173 cache,
174 redis,
175 }
176 }
177
178 pub fn with_default_cache(inner: S) -> Self {
179 Self::new(inner, VectorCache::with_defaults(), None)
180 }
181
182 pub fn cache(&self) -> &VectorCache {
183 &self.cache
184 }
185
186 pub fn inner(&self) -> &S {
187 &self.inner
188 }
189
190 pub fn redis(&self) -> Option<&crate::RedisCache> {
191 self.redis.as_ref()
192 }
193}
194
195#[async_trait::async_trait]
196impl<S: crate::VectorStorage> crate::VectorStorage for CachedStorage<S> {
197 async fn upsert(
198 &self,
199 namespace: &common::NamespaceId,
200 vectors: Vec<common::Vector>,
201 ) -> common::Result<usize> {
202 let count = self.inner.upsert(namespace, vectors.clone()).await?;
203 self.cache.insert_batch(namespace, vectors.clone()).await;
205 if let Some(ref redis) = self.redis {
207 redis.set_batch(namespace, &vectors).await;
208 let ids: Vec<String> = vectors.iter().map(|v| v.id.clone()).collect();
209 redis
210 .publish_invalidation(&crate::CacheInvalidation::Vectors {
211 namespace: namespace.to_string(),
212 ids,
213 })
214 .await;
215 }
216 Ok(count)
217 }
218
219 async fn get(
220 &self,
221 namespace: &common::NamespaceId,
222 ids: &[common::VectorId],
223 ) -> common::Result<Vec<common::Vector>> {
224 let mut found = Vec::new();
225 let mut missing_ids: Vec<String> = Vec::new();
226
227 for id in ids {
229 if let Some(v) = self.cache.get(namespace, id).await {
230 found.push((*v).clone());
231 } else {
232 missing_ids.push(id.clone());
233 }
234 }
235 if missing_ids.is_empty() {
236 return Ok(found);
237 }
238
239 if let Some(ref redis) = self.redis {
241 let from_redis = redis.get_multi(namespace, &missing_ids).await;
242 let redis_found_ids: std::collections::HashSet<String> =
243 from_redis.iter().map(|v| v.id.clone()).collect();
244 for v in &from_redis {
245 self.cache.insert(namespace, v.clone()).await; }
247 found.extend(from_redis);
248 missing_ids.retain(|id| !redis_found_ids.contains(id));
249 }
250 if missing_ids.is_empty() {
251 return Ok(found);
252 }
253
254 let from_store = self.inner.get(namespace, &missing_ids).await?;
256 for v in &from_store {
257 self.cache.insert(namespace, v.clone()).await; if let Some(ref redis) = self.redis {
259 redis.set(namespace, v).await; }
261 }
262 found.extend(from_store);
263 Ok(found)
264 }
265
266 async fn get_all(
267 &self,
268 namespace: &common::NamespaceId,
269 ) -> common::Result<Vec<common::Vector>> {
270 let vectors = self.inner.get_all(namespace).await?;
271 for v in &vectors {
273 self.cache.insert(namespace, v.clone()).await;
274 }
275 if let Some(ref redis) = self.redis {
277 redis.set_batch(namespace, &vectors).await;
278 }
279 Ok(vectors)
280 }
281
282 async fn delete(
283 &self,
284 namespace: &common::NamespaceId,
285 ids: &[common::VectorId],
286 ) -> common::Result<usize> {
287 let count = self.inner.delete(namespace, ids).await?;
288 self.cache.remove_batch(namespace, ids).await;
289 if let Some(ref redis) = self.redis {
290 let id_strings: Vec<String> = ids.iter().map(|s| s.to_string()).collect();
291 redis.delete(namespace, &id_strings).await;
292 redis
293 .publish_invalidation(&crate::CacheInvalidation::Vectors {
294 namespace: namespace.to_string(),
295 ids: id_strings,
296 })
297 .await;
298 }
299 Ok(count)
300 }
301
302 async fn namespace_exists(&self, namespace: &common::NamespaceId) -> common::Result<bool> {
303 self.inner.namespace_exists(namespace).await
304 }
305
306 async fn ensure_namespace(&self, namespace: &common::NamespaceId) -> common::Result<()> {
307 self.inner.ensure_namespace(namespace).await
308 }
309
310 async fn count(&self, namespace: &common::NamespaceId) -> common::Result<usize> {
311 self.inner.count(namespace).await
312 }
313
314 async fn dimension(&self, namespace: &common::NamespaceId) -> common::Result<Option<usize>> {
315 self.inner.dimension(namespace).await
316 }
317
318 async fn list_namespaces(&self) -> common::Result<Vec<common::NamespaceId>> {
319 self.inner.list_namespaces().await
320 }
321
322 async fn delete_namespace(&self, namespace: &common::NamespaceId) -> common::Result<bool> {
323 let result = self.inner.delete_namespace(namespace).await?;
324 self.cache.invalidate_namespace(namespace).await;
325 if let Some(ref redis) = self.redis {
326 redis.invalidate_namespace(namespace).await;
327 redis
328 .publish_invalidation(&crate::CacheInvalidation::Namespace(namespace.to_string()))
329 .await;
330 }
331 Ok(result)
332 }
333
334 async fn cleanup_expired(&self, namespace: &common::NamespaceId) -> common::Result<usize> {
335 self.inner.cleanup_expired(namespace).await
336 }
337
338 async fn cleanup_all_expired(&self) -> common::Result<usize> {
339 self.inner.cleanup_all_expired().await
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a vector with the given id and values and no metadata or expiry.
    fn make_vector(id: &str, values: Vec<f32>) -> Vector {
        Vector {
            id: id.to_string(),
            values,
            metadata: None,
            ttl_seconds: None,
            expires_at: None,
        }
    }

    #[tokio::test]
    async fn test_cache_insert_and_get() {
        let cache = VectorCache::with_defaults();

        cache
            .insert("test_ns", make_vector("v1", vec![1.0, 2.0, 3.0]))
            .await;

        let retrieved = cache.get("test_ns", "v1").await;
        assert!(retrieved.is_some());

        let retrieved = retrieved.unwrap();
        assert_eq!(retrieved.id, "v1");
        assert_eq!(retrieved.values, vec![1.0, 2.0, 3.0]);
    }

    #[tokio::test]
    async fn test_cache_miss() {
        let cache = VectorCache::with_defaults();

        let retrieved = cache.get("test_ns", "nonexistent").await;
        assert!(retrieved.is_none());
    }

    #[tokio::test]
    async fn test_cache_remove() {
        let cache = VectorCache::with_defaults();

        cache
            .insert("test_ns", make_vector("v1", vec![1.0, 2.0, 3.0]))
            .await;
        assert!(cache.get("test_ns", "v1").await.is_some());

        cache.remove("test_ns", "v1").await;
        cache.run_pending_tasks().await;

        assert!(cache.get("test_ns", "v1").await.is_none());
    }

    #[tokio::test]
    async fn test_cache_batch_operations() {
        let cache = VectorCache::with_defaults();

        let vectors = vec![
            make_vector("v1", vec![1.0]),
            make_vector("v2", vec![2.0]),
            make_vector("v3", vec![3.0]),
        ];

        cache.insert_batch("test_ns", vectors).await;

        assert!(cache.get("test_ns", "v1").await.is_some());
        assert!(cache.get("test_ns", "v2").await.is_some());
        assert!(cache.get("test_ns", "v3").await.is_some());

        cache
            .remove_batch("test_ns", &["v1".to_string(), "v2".to_string()])
            .await;
        cache.run_pending_tasks().await;

        assert!(cache.get("test_ns", "v1").await.is_none());
        assert!(cache.get("test_ns", "v2").await.is_none());
        assert!(cache.get("test_ns", "v3").await.is_some());
    }

    #[tokio::test]
    async fn test_cache_stats() {
        let cache = VectorCache::new(CacheConfig {
            max_capacity: 1000,
            ttl: None,
            tti: None,
        });

        for i in 0..10 {
            cache
                .insert("test_ns", make_vector(&format!("v{}", i), vec![i as f32]))
                .await;
        }

        for i in 0..10 {
            assert!(cache.get("test_ns", &format!("v{}", i)).await.is_some());
        }

        // The cache's counters are approximate until pending tasks have run.
        cache.run_pending_tasks().await;

        let stats = cache.stats();
        assert_eq!(stats.max_capacity, 1000);
        assert_eq!(stats.entry_count, 10);
        assert!((stats.utilization() - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_stats_utilization_zero_capacity() {
        let stats = CacheStats {
            entry_count: 5,
            weighted_size: 5,
            max_capacity: 0,
        };
        assert_eq!(stats.utilization(), 0.0);
    }

    #[tokio::test]
    async fn test_cache_namespace_isolation() {
        let cache = VectorCache::with_defaults();

        cache.insert("ns1", make_vector("same_id", vec![1.0])).await;
        cache.insert("ns2", make_vector("same_id", vec![2.0])).await;

        let from_ns1 = cache.get("ns1", "same_id").await.unwrap();
        let from_ns2 = cache.get("ns2", "same_id").await.unwrap();

        assert_eq!(from_ns1.values, vec![1.0]);
        assert_eq!(from_ns2.values, vec![2.0]);
    }

    #[tokio::test]
    async fn test_cache_clear() {
        let cache = VectorCache::with_defaults();

        for i in 0..5 {
            cache
                .insert("test_ns", make_vector(&format!("v{}", i), vec![i as f32]))
                .await;
        }

        for i in 0..5 {
            assert!(cache.get("test_ns", &format!("v{}", i)).await.is_some());
        }

        cache.clear();
        cache.run_pending_tasks().await;

        for i in 0..5 {
            assert!(cache.get("test_ns", &format!("v{}", i)).await.is_none());
        }
    }
}