use common::Vector;
use futures_util::{stream::FuturesUnordered, StreamExt};
use moka::future::Cache;
use std::sync::Arc;
use std::time::Duration;
/// Tunable parameters for the in-process vector cache.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of entries the cache may hold before eviction.
    pub max_capacity: u64,
    /// Time-to-live: an entry expires this long after insertion (`None` = no TTL).
    pub ttl: Option<Duration>,
    /// Time-to-idle: an entry expires this long after its last access (`None` = no TTI).
    pub tti: Option<Duration>,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
max_capacity: 100_000,
ttl: Some(Duration::from_secs(3600)), tti: Some(Duration::from_secs(600)), }
}
}
/// Composite cache key: a vector id scoped to its namespace.
/// `Arc<str>` keeps clones of the key cheap (refcount bump, no string copy).
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct CacheKey {
    /// Namespace the vector belongs to.
    pub namespace: Arc<str>,
    /// Id of the cached vector within that namespace.
    pub vector_id: Arc<str>,
}
impl CacheKey {
pub fn new(namespace: impl AsRef<str>, vector_id: impl AsRef<str>) -> Self {
Self {
namespace: Arc::from(namespace.as_ref()),
vector_id: Arc::from(vector_id.as_ref()),
}
}
}
/// In-process async cache of vectors keyed by `(namespace, vector_id)`,
/// backed by `moka`. Values are `Arc`-wrapped so cache hits are cheap to share.
#[derive(Clone)]
pub struct VectorCache {
    // moka's `Cache` is internally shared; cloning `VectorCache` yields a
    // handle to the same underlying cache, not an independent copy.
    cache: Cache<CacheKey, Arc<Vector>>,
    // Retained so `stats()` can report the configured capacity.
    config: CacheConfig,
}
impl VectorCache {
    /// Builds a cache from `config`.
    ///
    /// Invalidation-closure support is always enabled here because
    /// [`VectorCache::invalidate_namespace`] depends on it.
    pub fn new(config: CacheConfig) -> Self {
        let mut builder = Cache::builder()
            .max_capacity(config.max_capacity)
            .support_invalidation_closures();
        if let Some(ttl) = config.ttl {
            builder = builder.time_to_live(ttl);
        }
        if let Some(tti) = config.tti {
            builder = builder.time_to_idle(tti);
        }
        let cache = builder.build();
        Self { cache, config }
    }

    /// Convenience constructor using `CacheConfig::default()`.
    pub fn with_defaults() -> Self {
        Self::new(CacheConfig::default())
    }

    /// Looks up a vector by `(namespace, vector_id)`; `None` on a miss.
    pub async fn get(&self, namespace: &str, vector_id: &str) -> Option<Arc<Vector>> {
        let key = CacheKey::new(namespace, vector_id);
        self.cache.get(&key).await
    }

    /// Inserts `vector` into `namespace`, keyed by the vector's own id.
    pub async fn insert(&self, namespace: &str, vector: Vector) {
        let key = CacheKey::new(namespace, &vector.id);
        self.cache.insert(key, Arc::new(vector)).await;
    }

    /// Inserts all `vectors`, driving the individual inserts concurrently.
    pub async fn insert_batch(&self, namespace: &str, vectors: Vec<Vector>) {
        let mut futs: FuturesUnordered<_> = vectors
            .into_iter()
            .map(|v| self.insert(namespace, v))
            .collect();
        // Drain the stream so every insert completes before returning.
        while futs.next().await.is_some() {}
    }

    /// Removes a single entry, if present.
    pub async fn remove(&self, namespace: &str, vector_id: &str) {
        let key = CacheKey::new(namespace, vector_id);
        self.cache.remove(&key).await;
    }

    /// Removes each id in `vector_ids` from `namespace`.
    pub async fn remove_batch(&self, namespace: &str, vector_ids: &[String]) {
        for id in vector_ids {
            self.remove(namespace, id).await;
        }
    }

    /// Drops every entry belonging to `namespace`.
    pub async fn invalidate_namespace(&self, namespace: &str) {
        let ns: Arc<str> = Arc::from(namespace);
        self.cache
            .invalidate_entries_if(move |k, _v| *k.namespace == *ns)
            .expect(
                "invalidate_entries_if cannot fail: support_invalidation_closures() \
                 is enabled in VectorCache::new",
            );
        tracing::debug!(namespace = namespace, "Cache namespace invalidated");
    }

    /// Marks every entry invalid. Removal completes once moka's pending
    /// maintenance tasks run (see `run_pending_tasks`).
    pub fn clear(&self) {
        self.cache.invalidate_all();
    }

    /// Point-in-time occupancy snapshot. Counts may lag recent operations
    /// until `run_pending_tasks` is awaited.
    pub fn stats(&self) -> CacheStats {
        CacheStats {
            entry_count: self.cache.entry_count(),
            weighted_size: self.cache.weighted_size(),
            max_capacity: self.config.max_capacity,
        }
    }

    /// Flushes moka's internal maintenance queue (evictions, invalidations).
    pub async fn run_pending_tasks(&self) {
        self.cache.run_pending_tasks().await;
    }
}
/// Point-in-time occupancy numbers reported by `VectorCache::stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of live entries as reported by the underlying cache.
    pub entry_count: u64,
    /// Total weighted size as reported by the underlying cache.
    pub weighted_size: u64,
    /// Configured maximum capacity of the cache.
    pub max_capacity: u64,
}
impl CacheStats {
    /// Current fill level as a percentage of `max_capacity`.
    /// Returns 0.0 for a zero-capacity cache to avoid division by zero.
    pub fn utilization(&self) -> f64 {
        match self.max_capacity {
            0 => 0.0,
            cap => self.entry_count as f64 / cap as f64 * 100.0,
        }
    }
}
/// Caching decorator around a `VectorStorage` backend: a per-process
/// `VectorCache` plus an optional shared Redis tier.
pub struct CachedStorage<S> {
    // Authoritative storage backend; writes always go here first.
    inner: S,
    // Local (per-process) cache tier.
    cache: VectorCache,
    // Optional cross-process cache tier; also used to publish invalidations.
    redis: Option<crate::RedisCache>,
}
impl<S> CachedStorage<S> {
    /// Wraps `inner` with the given in-process cache and optional Redis tier.
    pub fn new(inner: S, cache: VectorCache, redis: Option<crate::RedisCache>) -> Self {
        Self { inner, cache, redis }
    }

    /// Wraps `inner` with a default-configured cache and no Redis tier.
    pub fn with_default_cache(inner: S) -> Self {
        Self::new(inner, VectorCache::with_defaults(), None)
    }

    /// The in-process cache layer.
    pub fn cache(&self) -> &VectorCache {
        &self.cache
    }

    /// The wrapped storage backend.
    pub fn inner(&self) -> &S {
        &self.inner
    }

    /// The Redis tier, if one was configured.
    pub fn redis(&self) -> Option<&crate::RedisCache> {
        self.redis.as_ref()
    }
}
#[async_trait::async_trait]
impl<S: crate::VectorStorage> crate::VectorStorage for CachedStorage<S> {
    /// Writes through to the backend, then populates both cache tiers and
    /// publishes an invalidation so peer processes refresh their local caches.
    async fn upsert(
        &self,
        namespace: &common::NamespaceId,
        vectors: Vec<common::Vector>,
    ) -> common::Result<usize> {
        // Backend write first: the caches must never hold data the store rejected.
        let count = self.inner.upsert(namespace, vectors.clone()).await?;
        if let Some(ref redis) = self.redis {
            redis.set_batch(namespace, &vectors).await;
            let ids: Vec<String> = vectors.iter().map(|v| v.id.clone()).collect();
            redis
                .publish_invalidation(&crate::CacheInvalidation::Vectors {
                    namespace: namespace.to_string(),
                    ids,
                })
                .await;
        }
        // The Redis tier only borrowed `vectors`, so the local cache can take
        // ownership here — avoids a second deep clone of the whole batch.
        self.cache.insert_batch(namespace, vectors).await;
        Ok(count)
    }

    /// Reads through the tiers in order: local cache, Redis, then the backend.
    /// Anything found in a lower tier is promoted into the tiers above it.
    /// Result order is tier order, not necessarily the order of `ids`.
    async fn get(
        &self,
        namespace: &common::NamespaceId,
        ids: &[common::VectorId],
    ) -> common::Result<Vec<common::Vector>> {
        let mut found = Vec::new();
        let mut missing_ids: Vec<String> = Vec::new();
        // Tier 1: local cache.
        for id in ids {
            if let Some(v) = self.cache.get(namespace, id).await {
                found.push((*v).clone());
            } else {
                missing_ids.push(id.clone());
            }
        }
        if missing_ids.is_empty() {
            return Ok(found);
        }
        // Tier 2: Redis (if configured); promote hits into the local cache.
        if let Some(ref redis) = self.redis {
            let from_redis = redis.get_multi(namespace, &missing_ids).await;
            let redis_found_ids: std::collections::HashSet<String> =
                from_redis.iter().map(|v| v.id.clone()).collect();
            for v in &from_redis {
                self.cache.insert(namespace, v.clone()).await;
            }
            found.extend(from_redis);
            missing_ids.retain(|id| !redis_found_ids.contains(id));
        }
        if missing_ids.is_empty() {
            return Ok(found);
        }
        // Tier 3: the backend; promote hits into both cache tiers.
        let from_store = self.inner.get(namespace, &missing_ids).await?;
        for v in &from_store {
            self.cache.insert(namespace, v.clone()).await;
            if let Some(ref redis) = self.redis {
                redis.set(namespace, v).await;
            }
        }
        found.extend(from_store);
        Ok(found)
    }

    /// Fetches everything from the backend and warms both cache tiers.
    async fn get_all(
        &self,
        namespace: &common::NamespaceId,
    ) -> common::Result<Vec<common::Vector>> {
        let vectors = self.inner.get_all(namespace).await?;
        for v in &vectors {
            self.cache.insert(namespace, v.clone()).await;
        }
        if let Some(ref redis) = self.redis {
            redis.set_batch(namespace, &vectors).await;
        }
        Ok(vectors)
    }

    /// Deletes from the backend, evicts from both tiers, and publishes an
    /// invalidation for peer processes.
    async fn delete(
        &self,
        namespace: &common::NamespaceId,
        ids: &[common::VectorId],
    ) -> common::Result<usize> {
        let count = self.inner.delete(namespace, ids).await?;
        self.cache.remove_batch(namespace, ids).await;
        if let Some(ref redis) = self.redis {
            let id_strings: Vec<String> = ids.iter().map(|s| s.to_string()).collect();
            redis.delete(namespace, &id_strings).await;
            redis
                .publish_invalidation(&crate::CacheInvalidation::Vectors {
                    namespace: namespace.to_string(),
                    ids: id_strings,
                })
                .await;
        }
        Ok(count)
    }

    // The following metadata queries bypass the caches and hit the backend
    // directly, so they always reflect the authoritative state.

    async fn namespace_exists(&self, namespace: &common::NamespaceId) -> common::Result<bool> {
        self.inner.namespace_exists(namespace).await
    }

    async fn ensure_namespace(&self, namespace: &common::NamespaceId) -> common::Result<()> {
        self.inner.ensure_namespace(namespace).await
    }

    async fn count(&self, namespace: &common::NamespaceId) -> common::Result<usize> {
        self.inner.count(namespace).await
    }

    async fn dimension(&self, namespace: &common::NamespaceId) -> common::Result<Option<usize>> {
        self.inner.dimension(namespace).await
    }

    async fn list_namespaces(&self) -> common::Result<Vec<common::NamespaceId>> {
        self.inner.list_namespaces().await
    }

    /// Drops the namespace in the backend, then purges it from both cache
    /// tiers and notifies peers.
    async fn delete_namespace(&self, namespace: &common::NamespaceId) -> common::Result<bool> {
        let result = self.inner.delete_namespace(namespace).await?;
        self.cache.invalidate_namespace(namespace).await;
        if let Some(ref redis) = self.redis {
            redis.invalidate_namespace(namespace).await;
            redis
                .publish_invalidation(&crate::CacheInvalidation::Namespace(namespace.to_string()))
                .await;
        }
        Ok(result)
    }

    async fn cleanup_expired(&self, namespace: &common::NamespaceId) -> common::Result<usize> {
        self.inner.cleanup_expired(namespace).await
    }

    async fn cleanup_all_expired(&self) -> common::Result<usize> {
        self.inner.cleanup_all_expired().await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal vector with no metadata or expiry, cutting the
    /// repeated five-field struct literal out of every test.
    fn test_vector(id: &str, values: Vec<f32>) -> Vector {
        Vector {
            id: id.to_string(),
            values,
            metadata: None,
            ttl_seconds: None,
            expires_at: None,
        }
    }

    #[tokio::test]
    async fn test_cache_insert_and_get() {
        let cache = VectorCache::with_defaults();
        cache
            .insert("test_ns", test_vector("v1", vec![1.0, 2.0, 3.0]))
            .await;
        let retrieved = cache.get("test_ns", "v1").await;
        assert!(retrieved.is_some());
        let retrieved = retrieved.unwrap();
        assert_eq!(retrieved.id, "v1");
        assert_eq!(retrieved.values, vec![1.0, 2.0, 3.0]);
    }

    #[tokio::test]
    async fn test_cache_miss() {
        let cache = VectorCache::with_defaults();
        let retrieved = cache.get("test_ns", "nonexistent").await;
        assert!(retrieved.is_none());
    }

    #[tokio::test]
    async fn test_cache_remove() {
        let cache = VectorCache::with_defaults();
        cache
            .insert("test_ns", test_vector("v1", vec![1.0, 2.0, 3.0]))
            .await;
        assert!(cache.get("test_ns", "v1").await.is_some());
        cache.remove("test_ns", "v1").await;
        cache.run_pending_tasks().await;
        assert!(cache.get("test_ns", "v1").await.is_none());
    }

    #[tokio::test]
    async fn test_cache_batch_operations() {
        let cache = VectorCache::with_defaults();
        let vectors = vec![
            test_vector("v1", vec![1.0]),
            test_vector("v2", vec![2.0]),
            test_vector("v3", vec![3.0]),
        ];
        cache.insert_batch("test_ns", vectors).await;
        assert!(cache.get("test_ns", "v1").await.is_some());
        assert!(cache.get("test_ns", "v2").await.is_some());
        assert!(cache.get("test_ns", "v3").await.is_some());
        cache
            .remove_batch("test_ns", &["v1".to_string(), "v2".to_string()])
            .await;
        cache.run_pending_tasks().await;
        assert!(cache.get("test_ns", "v1").await.is_none());
        assert!(cache.get("test_ns", "v2").await.is_none());
        assert!(cache.get("test_ns", "v3").await.is_some());
    }

    #[tokio::test]
    async fn test_cache_stats() {
        let cache = VectorCache::new(CacheConfig {
            max_capacity: 1000,
            ttl: None,
            tti: None,
        });
        for i in 0..10 {
            cache
                .insert("test_ns", test_vector(&format!("v{}", i), vec![i as f32]))
                .await;
        }
        for i in 0..10 {
            assert!(cache.get("test_ns", &format!("v{}", i)).await.is_some());
        }
        let stats = cache.stats();
        assert_eq!(stats.max_capacity, 1000);
    }

    #[tokio::test]
    async fn test_cache_namespace_isolation() {
        let cache = VectorCache::with_defaults();
        // Same vector id in two namespaces must not collide.
        cache.insert("ns1", test_vector("same_id", vec![1.0])).await;
        cache.insert("ns2", test_vector("same_id", vec![2.0])).await;
        let from_ns1 = cache.get("ns1", "same_id").await.unwrap();
        let from_ns2 = cache.get("ns2", "same_id").await.unwrap();
        assert_eq!(from_ns1.values, vec![1.0]);
        assert_eq!(from_ns2.values, vec![2.0]);
    }

    #[tokio::test]
    async fn test_cache_clear() {
        let cache = VectorCache::with_defaults();
        for i in 0..5 {
            cache
                .insert("test_ns", test_vector(&format!("v{}", i), vec![i as f32]))
                .await;
        }
        for i in 0..5 {
            assert!(cache.get("test_ns", &format!("v{}", i)).await.is_some());
        }
        cache.clear();
        cache.run_pending_tasks().await;
        for i in 0..5 {
            assert!(cache.get("test_ns", &format!("v{}", i)).await.is_none());
        }
    }
}