use crate::fact::FactId;
use crate::scope::Scope;
use crate::store::MemoryError;
use crate::vector::{VectorFilter, VectorMatch, VectorStore};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::RwLock;
/// A single stored vector together with its caller-supplied metadata.
struct VectorEntry {
    /// Cached L2 norm of `embedding`, computed once when the entry is inserted
    /// so similarity queries only need a dot product per entry.
    norm: f32,
    /// The raw embedding components.
    embedding: Vec<f32>,
    /// Arbitrary JSON metadata, cloned into every search hit for this entry.
    metadata: serde_json::Value,
}
/// In-memory [`VectorStore`] backed by a `HashMap` behind an `RwLock`.
///
/// Every stored embedding must have exactly `dimensions` components.
pub struct EmbeddedVectorStore {
    /// Stored vectors keyed by fact id.
    entries: RwLock<HashMap<FactId, VectorEntry>>,
    /// Required length of every embedding accepted by this store.
    dimensions: usize,
}
impl EmbeddedVectorStore {
    /// Creates an empty store that accepts embeddings of exactly
    /// `dimensions` components.
    pub fn new(dimensions: usize) -> Self {
        let entries = RwLock::new(HashMap::new());
        Self { dimensions, entries }
    }

    /// Returns the Euclidean (L2) length of `v`.
    fn compute_norm(v: &[f32]) -> f32 {
        let squared_sum: f32 = v.iter().map(|&c| c * c).sum();
        squared_sum.sqrt()
    }

    /// Cosine similarity of `a` and `b`, using their precomputed norms.
    ///
    /// Returns `0.0` when either norm is zero, so that zero vectors never
    /// divide by zero. If the slices differ in length, only the overlapping
    /// prefix contributes to the dot product, because `zip` stops at the
    /// shorter slice.
    fn cosine_similarity(a: &[f32], a_norm: f32, b: &[f32], b_norm: f32) -> f32 {
        let degenerate = a_norm == 0.0 || b_norm == 0.0;
        if degenerate {
            0.0
        } else {
            let dot: f32 = a.iter().zip(b).map(|(&p, &q)| p * q).sum();
            dot / (a_norm * b_norm)
        }
    }
}
/// Converts a poisoned-lock error into the store's error type.
fn lock_poisoned<E: std::fmt::Display>(err: E) -> MemoryError {
    MemoryError::Database(format!("lock poisoned: {err}"))
}

/// Returns an `Embedding` error unless `actual` equals the store's configured
/// dimensionality `expected`.
fn check_dimensions(expected: usize, actual: usize) -> Result<(), MemoryError> {
    if actual == expected {
        Ok(())
    } else {
        Err(MemoryError::Embedding(format!(
            "dimension mismatch: expected {expected}, got {actual}"
        )))
    }
}

#[async_trait]
impl VectorStore for EmbeddedVectorStore {
    /// Inserts or replaces the vector stored under `id`.
    ///
    /// # Errors
    /// - `MemoryError::Embedding` if `embedding` does not have exactly
    ///   `self.dimensions` components.
    /// - `MemoryError::Database` if the entries lock is poisoned.
    async fn upsert(
        &self,
        id: FactId,
        embedding: Vec<f32>,
        metadata: serde_json::Value,
    ) -> Result<(), MemoryError> {
        check_dimensions(self.dimensions, embedding.len())?;
        // Cache the norm up front so each search only needs one dot product per entry.
        let norm = Self::compute_norm(&embedding);
        let entry = VectorEntry {
            embedding,
            metadata,
            norm,
        };
        self.entries.write().map_err(lock_poisoned)?.insert(id, entry);
        Ok(())
    }

    /// Returns up to `top_k` entries ranked by cosine similarity to `query`,
    /// highest first. Ties are broken by ascending id.
    ///
    /// Entries scoring below `filter.min_score` are excluded, as are entries
    /// whose score is NaN (possible when stored or query components are
    /// non-finite).
    ///
    /// # Errors
    /// - `MemoryError::Embedding` if `query` does not have exactly
    ///   `self.dimensions` components.
    /// - `MemoryError::Database` if the entries lock is poisoned.
    async fn search(
        &self,
        query: &[f32],
        filter: &VectorFilter,
        top_k: usize,
    ) -> Result<Vec<VectorMatch>, MemoryError> {
        // `cosine_similarity` zips the slices. A query of the wrong length
        // would therefore be silently truncated and yield meaningless scores,
        // so reject it the same way `upsert` rejects a bad embedding.
        check_dimensions(self.dimensions, query.len())?;
        if top_k == 0 {
            return Ok(Vec::new());
        }
        let query_norm = Self::compute_norm(query);
        let min_score = filter.min_score.unwrap_or(f32::NEG_INFINITY);
        // NOTE(review): `filter.scope` is NOT applied. Entries in this store
        // carry no scope, so every stored vector is a candidate whatever scope
        // was requested. TODO: confirm whether scope should be matched against
        // each entry's `metadata`.
        let entries = self.entries.read().map_err(lock_poisoned)?;
        let mut matches: Vec<VectorMatch> = entries
            .iter()
            .filter_map(|(id, entry)| {
                let score =
                    Self::cosine_similarity(query, query_norm, &entry.embedding, entry.norm);
                // Drop NaN explicitly. `NaN < min_score` is false and would let it
                // through, and NaN cannot be ranked meaningfully.
                if score.is_nan() || score < min_score {
                    return None;
                }
                Some(VectorMatch {
                    id: *id,
                    score,
                    metadata: entry.metadata.clone(),
                })
            })
            .collect();
        // Results own their data, so release the read lock before sorting.
        drop(entries);
        // `total_cmp` is a total order, which `sort_by` requires; the ascending
        // id tiebreak keeps the output deterministic.
        matches.sort_by(|a, b| b.score.total_cmp(&a.score).then_with(|| a.id.cmp(&b.id)));
        matches.truncate(top_k);
        Ok(matches)
    }

    /// Removes the vector stored under `id`. Removing an absent id is not an error.
    ///
    /// # Errors
    /// `MemoryError::Database` if the entries lock is poisoned.
    async fn delete(&self, id: FactId) -> Result<(), MemoryError> {
        self.entries.write().map_err(lock_poisoned)?.remove(&id);
        Ok(())
    }

    /// Removes the entries belonging to `_scope` and returns how many were removed.
    ///
    /// NOTE(review): entries are not tagged with a scope in this store, so this
    /// currently removes EVERY entry, whatever `_scope` is. That is only
    /// correct if each `EmbeddedVectorStore` instance serves a single scope.
    /// TODO: confirm with callers before sharing one store across scopes.
    ///
    /// # Errors
    /// `MemoryError::Database` if the entries lock is poisoned.
    async fn delete_by_scope(&self, _scope: &Scope) -> Result<u64, MemoryError> {
        let mut entries = self.entries.write().map_err(lock_poisoned)?;
        let count = entries.len() as u64;
        entries.clear();
        Ok(count)
    }
}