use async_trait::async_trait;
use rucora_core::error::ProviderError;
use rucora_core::retrieval::{SearchResult, VectorQuery, VectorRecord, VectorStore};
use dashmap::DashMap;
/// A [`VectorStore`] that keeps every record in process memory.
///
/// Records live in a concurrent [`DashMap`] keyed by record id, which is why every
/// trait method can operate through `&self`. Nothing is persisted: the contents are
/// gone once the store is dropped.
///
/// The derived `Clone` copies the map, so a clone is an independent store rather
/// than a second handle onto the same records.
#[derive(Clone, Debug, Default)]
pub struct InMemoryVectorStore {
    /// Stored records, keyed by their `id`.
    records: DashMap<String, VectorRecord>,
}
impl InMemoryVectorStore {
pub fn new() -> Self {
Self {
records: DashMap::new(),
}
}
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
if a.is_empty() || b.is_empty() || a.len() != b.len() {
return 0.0;
}
let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let magnitude_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let magnitude_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if magnitude_a == 0.0 || magnitude_b == 0.0 {
return 0.0;
}
dot_product / (magnitude_a * magnitude_b)
}
}
#[async_trait]
impl VectorStore for InMemoryVectorStore {
    /// Inserts each record under its `id`, replacing any record already stored with
    /// that id. Within one batch, a later record overwrites an earlier one with the
    /// same id because inserts happen in order.
    async fn upsert(&self, records: Vec<VectorRecord>) -> Result<(), ProviderError> {
        for record in records {
            self.records.insert(record.id.clone(), record);
        }
        Ok(())
    }

    /// Removes the records with the given ids; ids that are not present are ignored.
    async fn delete(&self, ids: Vec<String>) -> Result<(), ProviderError> {
        for id in ids {
            self.records.remove(&id);
        }
        Ok(())
    }

    /// Returns clones of the stored records for the given ids, in the order of `ids`.
    /// Unknown ids are skipped rather than reported as errors.
    async fn get(&self, ids: Vec<String>) -> Result<Vec<VectorRecord>, ProviderError> {
        Ok(ids
            .iter()
            .filter_map(|id| self.records.get(id).map(|entry| entry.value().clone()))
            .collect())
    }

    /// Ranks every stored record by cosine similarity to `query.vector` and returns
    /// at most `query.top_k` results, highest score first.
    ///
    /// When `query.score_threshold` is set, records scoring below it (or scoring NaN)
    /// are dropped before their id/text/metadata are cloned or sorted. Results never
    /// carry the stored vector (`vector` is always `None`).
    async fn search(&self, query: VectorQuery) -> Result<Vec<SearchResult>, ProviderError> {
        // Sort key that ranks NaN below every real score. Without this the comparator
        // is not a total order, which gives inconsistent rankings and may make
        // `sort_by` panic.
        fn rank(score: f32) -> f32 {
            if score.is_nan() {
                f32::NEG_INFINITY
            } else {
                score
            }
        }

        let threshold = query.score_threshold;
        let mut results: Vec<SearchResult> = self
            .records
            .iter()
            .filter_map(|record| {
                let score = Self::cosine_similarity(&query.vector, &record.vector);
                // Filter before cloning so rejected records cost no allocations.
                let passes_threshold = match threshold {
                    Some(min_score) => score >= min_score,
                    None => true,
                };
                passes_threshold.then(|| SearchResult {
                    id: record.id.clone(),
                    score,
                    text: record.text.clone(),
                    metadata: record.metadata.clone(),
                    vector: None,
                })
            })
            .collect();

        // Descending by score; `total_cmp` is a total order once NaN is mapped away.
        results.sort_by(|a, b| rank(b.score).total_cmp(&rank(a.score)));
        results.truncate(query.top_k);
        Ok(results)
    }

    /// Removes every record from the store.
    async fn clear(&self) -> Result<(), ProviderError> {
        self.records.clear();
        Ok(())
    }

    /// Returns the number of records currently stored.
    async fn count(&self) -> Result<usize, ProviderError> {
        Ok(self.records.len())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Exercises upsert, count, get, search, delete and clear end to end.
    #[tokio::test]
    async fn test_in_memory_vector_store_basic() {
        let store = InMemoryVectorStore::new();
        store
            .upsert(vec![
                VectorRecord::new("doc1", vec![1.0, 0.0]).with_text("文档 1"),
                VectorRecord::new("doc2", vec![0.0, 1.0]).with_text("文档 2"),
            ])
            .await
            .unwrap();
        assert_eq!(store.count().await.unwrap(), 2);

        let records = store.get(vec!["doc1".to_string()]).await.unwrap();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].id, "doc1");

        let results = store
            .search(VectorQuery::new(vec![1.0, 0.0]).with_top_k(10))
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "doc1");

        store.delete(vec!["doc1".to_string()]).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 1);

        store.clear().await.unwrap();
        assert_eq!(store.count().await.unwrap(), 0);
    }

    /// Records with metadata are searchable and come back ordered by descending score.
    #[tokio::test]
    async fn test_in_memory_vector_store_metadata() {
        let store = InMemoryVectorStore::new();
        store
            .upsert(vec![
                VectorRecord::new("a", vec![1.0, 0.0]).with_metadata(json!({"k": 1})),
                VectorRecord::new("b", vec![0.9, 0.1]).with_metadata(json!({"k": 2})),
            ])
            .await
            .unwrap();
        let results = store
            .search(VectorQuery::new(vec![1.0, 0.0]).with_top_k(3))
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
        assert!(results[0].score > results[1].score);
    }

    /// `score_threshold` filters out low scores and `top_k` caps the result count.
    #[tokio::test]
    async fn test_in_memory_vector_store_threshold_and_top_k() {
        let store = InMemoryVectorStore::new();
        store
            .upsert(vec![
                VectorRecord::new("same", vec![1.0, 0.0]),
                VectorRecord::new("close", vec![0.9, 0.1]),
                VectorRecord::new("orthogonal", vec![0.0, 1.0]),
            ])
            .await
            .unwrap();

        let mut query = VectorQuery::new(vec![1.0, 0.0]).with_top_k(10);
        query.score_threshold = Some(0.5);
        let results = store.search(query).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "same");
        assert_eq!(results[1].id, "close");
        assert!(results.iter().all(|r| r.vector.is_none()));

        let results = store
            .search(VectorQuery::new(vec![1.0, 0.0]).with_top_k(1))
            .await
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "same");
    }

    /// Pure computation, so a plain `#[test]` suffices (no async runtime needed).
    #[test]
    fn test_cosine_similarity() {
        let sim = InMemoryVectorStore::cosine_similarity(&[1.0, 0.0], &[1.0, 0.0]);
        assert!((sim - 1.0).abs() < 0.001);
        let sim = InMemoryVectorStore::cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]);
        assert!(sim.abs() < 0.001);
        let sim = InMemoryVectorStore::cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]);
        assert!((sim + 1.0).abs() < 0.001);
    }

    /// Empty, mismatched-length and zero-magnitude inputs all score exactly 0.0.
    #[test]
    fn test_cosine_similarity_degenerate_inputs() {
        assert_eq!(InMemoryVectorStore::cosine_similarity(&[], &[]), 0.0);
        assert_eq!(InMemoryVectorStore::cosine_similarity(&[1.0], &[1.0, 0.0]), 0.0);
        assert_eq!(InMemoryVectorStore::cosine_similarity(&[0.0, 0.0], &[1.0, 0.0]), 0.0);
    }
}