use async_trait::async_trait;
use common::{DakeraError, NamespaceId, Result, Vector, VectorId};
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use crate::traits::VectorStorage;
/// Vector storage that keeps every namespace and vector in process memory.
///
/// All data sits behind a single `RwLock` wrapped in an `Arc`, so cloning this
/// struct (via the derived `Clone`) produces another handle to the *same*
/// underlying map rather than a copy of the data.
#[derive(Clone)]
pub struct InMemoryStorage {
/// Outer map: namespace id -> that namespace's vectors, keyed by vector id.
namespaces: Arc<RwLock<HashMap<NamespaceId, HashMap<VectorId, Vector>>>>,
}
impl InMemoryStorage {
    /// Creates a storage handle that contains no namespaces.
    pub fn new() -> Self {
        // Build the shared, lock-protected map first, then move it into the struct.
        let namespaces = Arc::new(RwLock::new(HashMap::new()));
        Self { namespaces }
    }
}
impl Default for InMemoryStorage {
fn default() -> Self {
Self::new()
}
}
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, this returns `0`
/// instead of panicking.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        // Clock is set before 1970: fall back to the zero duration.
        Err(_) => 0,
    }
}
#[async_trait]
impl VectorStorage for InMemoryStorage {
async fn upsert(&self, namespace: &NamespaceId, vectors: Vec<Vector>) -> Result<usize> {
let mut namespaces = self.namespaces.write();
let ns = namespaces.entry(namespace.clone()).or_default();
if !vectors.is_empty() {
let now = now_secs();
let expected_dim = ns
.values()
.find(|v| !v.is_expired_at(now))
.map(|v| v.values.len())
.unwrap_or_else(|| vectors[0].values.len());
for v in &vectors {
if v.values.len() != expected_dim {
return Err(DakeraError::DimensionMismatch {
expected: expected_dim,
actual: v.values.len(),
});
}
}
}
let count = vectors.len();
for mut vector in vectors {
vector.apply_ttl();
ns.insert(vector.id.clone(), vector);
}
tracing::debug!(
namespace = %namespace,
count = count,
"Upserted vectors"
);
Ok(count)
}
async fn get(&self, namespace: &NamespaceId, ids: &[VectorId]) -> Result<Vec<Vector>> {
let namespaces = self.namespaces.read();
let ns = namespaces
.get(namespace)
.ok_or_else(|| DakeraError::NamespaceNotFound(namespace.clone()))?;
let now = now_secs();
Ok(ids
.iter()
.filter_map(|id| ns.get(id).cloned())
.filter(|v| !v.is_expired_at(now))
.collect())
}
async fn get_all(&self, namespace: &NamespaceId) -> Result<Vec<Vector>> {
let namespaces = self.namespaces.read();
let ns = namespaces
.get(namespace)
.ok_or_else(|| DakeraError::NamespaceNotFound(namespace.clone()))?;
let now = now_secs();
Ok(ns
.values()
.filter(|v| !v.is_expired_at(now))
.cloned()
.collect())
}
async fn delete(&self, namespace: &NamespaceId, ids: &[VectorId]) -> Result<usize> {
let mut namespaces = self.namespaces.write();
let ns = namespaces
.get_mut(namespace)
.ok_or_else(|| DakeraError::NamespaceNotFound(namespace.clone()))?;
let mut deleted = 0;
for id in ids {
if ns.remove(id).is_some() {
deleted += 1;
}
}
tracing::debug!(
namespace = %namespace,
deleted = deleted,
"Deleted vectors"
);
Ok(deleted)
}
async fn namespace_exists(&self, namespace: &NamespaceId) -> Result<bool> {
Ok(self.namespaces.read().contains_key(namespace))
}
async fn ensure_namespace(&self, namespace: &NamespaceId) -> Result<()> {
self.namespaces
.write()
.entry(namespace.clone())
.or_default();
Ok(())
}
async fn count(&self, namespace: &NamespaceId) -> Result<usize> {
let namespaces = self.namespaces.read();
let now = now_secs();
Ok(namespaces
.get(namespace)
.map(|ns| ns.values().filter(|v| !v.is_expired_at(now)).count())
.unwrap_or(0))
}
async fn dimension(&self, namespace: &NamespaceId) -> Result<Option<usize>> {
let namespaces = self.namespaces.read();
let now = now_secs();
Ok(namespaces
.get(namespace)
.and_then(|ns| ns.values().find(|v| !v.is_expired_at(now)))
.map(|v| v.values.len()))
}
async fn list_namespaces(&self) -> Result<Vec<NamespaceId>> {
let namespaces = self.namespaces.read();
Ok(namespaces.keys().cloned().collect())
}
async fn delete_namespace(&self, namespace: &NamespaceId) -> Result<bool> {
let mut namespaces = self.namespaces.write();
let existed = namespaces.remove(namespace).is_some();
if existed {
tracing::debug!(
namespace = %namespace,
"Deleted namespace"
);
}
Ok(existed)
}
async fn cleanup_expired(&self, namespace: &NamespaceId) -> Result<usize> {
let mut namespaces = self.namespaces.write();
let ns = match namespaces.get_mut(namespace) {
Some(ns) => ns,
None => return Ok(0),
};
let now = now_secs();
let before_count = ns.len();
ns.retain(|_, v| !v.is_expired_at(now));
let removed = before_count - ns.len();
if removed > 0 {
tracing::debug!(
namespace = %namespace,
removed = removed,
"Cleaned up expired vectors"
);
}
Ok(removed)
}
async fn cleanup_all_expired(&self) -> Result<usize> {
let mut namespaces = self.namespaces.write();
let mut total_removed = 0;
let now = now_secs();
for (namespace, ns) in namespaces.iter_mut() {
let before_count = ns.len();
ns.retain(|_, v| !v.is_expired_at(now));
let removed = before_count - ns.len();
total_removed += removed;
if removed > 0 {
tracing::debug!(
namespace = %namespace,
removed = removed,
"Cleaned up expired vectors"
);
}
}
if total_removed > 0 {
tracing::info!(
total_removed = total_removed,
"Cleaned up expired vectors across all namespaces"
);
}
Ok(total_removed)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Template vector: the given id, no values, no metadata, no TTL and no
    /// expiry. Tests override individual fields with struct-update syntax.
    fn blank(id: &str) -> Vector {
        Vector {
            id: id.to_string(),
            values: Vec::new(),
            metadata: None,
            ttl_seconds: None,
            expires_at: None,
        }
    }

    /// Unix timestamp (seconds) lying 100 seconds in the past.
    fn past_timestamp() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
            - 100
    }

    /// Fresh storage with one namespace of the given name already created.
    async fn setup(name: &str) -> (InMemoryStorage, String) {
        let storage = InMemoryStorage::new();
        let namespace = name.to_string();
        storage.ensure_namespace(&namespace).await.unwrap();
        (storage, namespace)
    }

    /// Number of entries physically stored in `namespace`, expired or not.
    fn raw_len(storage: &InMemoryStorage, namespace: &String) -> usize {
        storage.namespaces.read().get(namespace).unwrap().len()
    }

    #[tokio::test]
    async fn test_upsert_and_get() {
        let (storage, namespace) = setup("test").await;
        let batch = vec![Vector {
            values: vec![1.0, 2.0, 3.0],
            ..blank("v1")
        }];

        let upserted = storage.upsert(&namespace, batch).await.unwrap();
        assert_eq!(upserted, 1);

        let fetched = storage.get(&namespace, &["v1".to_string()]).await.unwrap();
        assert_eq!(fetched.len(), 1);
        assert_eq!(fetched[0].id, "v1");
    }

    #[tokio::test]
    async fn test_dimension_mismatch() {
        let (storage, namespace) = setup("test").await;

        // The first vector fixes the namespace dimension at 3.
        let three_dims = Vector {
            values: vec![1.0, 2.0, 3.0],
            ..blank("v1")
        };
        storage.upsert(&namespace, vec![three_dims]).await.unwrap();

        let two_dims = Vector {
            values: vec![1.0, 2.0],
            ..blank("v2")
        };
        let outcome = storage.upsert(&namespace, vec![two_dims]).await;
        assert!(matches!(outcome, Err(DakeraError::DimensionMismatch { .. })));
    }

    #[tokio::test]
    async fn test_delete() {
        let (storage, namespace) = setup("test").await;
        let batch = vec![
            Vector {
                values: vec![1.0],
                ..blank("v1")
            },
            Vector {
                values: vec![2.0],
                ..blank("v2")
            },
        ];
        storage.upsert(&namespace, batch).await.unwrap();

        let deleted = storage
            .delete(&namespace, &["v1".to_string()])
            .await
            .unwrap();
        assert_eq!(deleted, 1);

        let remaining = storage.count(&namespace).await.unwrap();
        assert_eq!(remaining, 1);
    }

    #[tokio::test]
    async fn test_get_all() {
        let (storage, namespace) = setup("test").await;
        let batch = vec![
            Vector {
                values: vec![1.0, 2.0],
                ..blank("v1")
            },
            Vector {
                values: vec![3.0, 4.0],
                ..blank("v2")
            },
        ];
        storage.upsert(&namespace, batch).await.unwrap();

        let everything = storage.get_all(&namespace).await.unwrap();
        assert_eq!(everything.len(), 2);
    }

    #[tokio::test]
    async fn test_ttl_expired_vectors_filtered() {
        let (storage, namespace) = setup("test").await;
        let batch = vec![
            Vector {
                values: vec![1.0, 2.0],
                expires_at: Some(past_timestamp()),
                ..blank("expired")
            },
            Vector {
                values: vec![3.0, 4.0],
                ..blank("valid")
            },
        ];
        storage.upsert(&namespace, batch).await.unwrap();

        // Every read path must hide the expired entry.
        let wanted = ["expired".to_string(), "valid".to_string()];
        let fetched = storage.get(&namespace, &wanted).await.unwrap();
        assert_eq!(fetched.len(), 1);
        assert_eq!(fetched[0].id, "valid");

        let everything = storage.get_all(&namespace).await.unwrap();
        assert_eq!(everything.len(), 1);
        assert_eq!(everything[0].id, "valid");

        let live = storage.count(&namespace).await.unwrap();
        assert_eq!(live, 1);
    }

    #[tokio::test]
    async fn test_ttl_applied_on_upsert() {
        let (storage, namespace) = setup("test").await;
        let with_ttl = Vector {
            values: vec![1.0, 2.0],
            ttl_seconds: Some(3600),
            ..blank("with_ttl")
        };
        storage.upsert(&namespace, vec![with_ttl]).await.unwrap();

        // Inspect the stored entry directly, bypassing the expiry-aware getters.
        let guard = storage.namespaces.read();
        let stored = guard.get(&namespace).unwrap().get("with_ttl").unwrap();
        assert!(stored.expires_at.is_some());

        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let expires_at = stored.expires_at.unwrap();
        // Expiry must lie in the future, at most TTL (+1s of slack) from now.
        assert!(expires_at > now);
        assert!(expires_at <= now + 3601);
    }

    #[tokio::test]
    async fn test_cleanup_expired() {
        let (storage, namespace) = setup("test").await;
        let stale = past_timestamp();
        let batch = vec![
            Vector {
                values: vec![1.0],
                expires_at: Some(stale),
                ..blank("expired1")
            },
            Vector {
                values: vec![2.0],
                expires_at: Some(stale),
                ..blank("expired2")
            },
            Vector {
                values: vec![3.0],
                ..blank("valid")
            },
        ];
        storage.upsert(&namespace, batch).await.unwrap();

        // Expired entries are still physically present before cleanup.
        assert_eq!(raw_len(&storage, &namespace), 3);

        let removed = storage.cleanup_expired(&namespace).await.unwrap();
        assert_eq!(removed, 2);

        {
            let guard = storage.namespaces.read();
            let ns = guard.get(&namespace).unwrap();
            assert_eq!(ns.len(), 1);
            assert!(ns.contains_key("valid"));
        }
    }

    #[tokio::test]
    async fn test_cleanup_all_expired() {
        let (storage, ns1) = setup("test1").await;
        let ns2 = "test2".to_string();
        storage.ensure_namespace(&ns2).await.unwrap();

        let stale = past_timestamp();
        let first_batch = vec![Vector {
            values: vec![1.0],
            expires_at: Some(stale),
            ..blank("expired")
        }];
        storage.upsert(&ns1, first_batch).await.unwrap();

        let second_batch = vec![
            Vector {
                values: vec![2.0],
                expires_at: Some(stale),
                ..blank("expired")
            },
            Vector {
                values: vec![3.0],
                ..blank("valid")
            },
        ];
        storage.upsert(&ns2, second_batch).await.unwrap();

        let removed = storage.cleanup_all_expired().await.unwrap();
        assert_eq!(removed, 2);

        // Emptied namespaces are kept; only their contents are purged.
        assert_eq!(raw_len(&storage, &ns1), 0);
        assert_eq!(raw_len(&storage, &ns2), 1);
    }
}