use std::sync::Arc;
use tokio::sync::RwLock;
use crate::error::Result;
use crate::retriever::types::Document;
use crate::retriever::traits::Retriever;
struct StoredEntry {
document: Document,
embedding: Vec<f32>,
}
pub struct InMemoryVectorStore {
embedding_model: Arc<dyn crate::model::ErasedEmbeddingModel>,
entries: RwLock<Vec<StoredEntry>>,
}
impl InMemoryVectorStore {
pub fn new(embedding_model: Arc<dyn crate::model::ErasedEmbeddingModel>) -> Self {
Self {
embedding_model,
entries: RwLock::new(Vec::new()),
}
}
pub async fn add(&self, doc: Document) -> Result<()> {
let texts = [doc.content.as_str()];
let embeddings = self.embedding_model.embed_erased(&texts).await?;
let embedding = embeddings.into_iter().next().unwrap_or_default();
self.entries.write().await.push(StoredEntry {
document: doc,
embedding,
});
Ok(())
}
pub async fn add_many(&self, docs: Vec<Document>) -> Result<()> {
let texts: Vec<&str> = docs.iter().map(|d| d.content.as_str()).collect();
let embeddings = self.embedding_model.embed_erased(&texts).await?;
let mut entries = self.entries.write().await;
for (doc, embedding) in docs.into_iter().zip(embeddings) {
entries.push(StoredEntry { document: doc, embedding });
}
Ok(())
}
pub async fn len(&self) -> usize {
self.entries.read().await.len()
}
pub async fn is_empty(&self) -> bool {
self.entries.read().await.is_empty()
}
}
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
if a.len() != b.len() || a.is_empty() {
return 0.0;
}
let mut dot = 0.0f32;
let mut norm_a = 0.0f32;
let mut norm_b = 0.0f32;
for (x, y) in a.iter().zip(b.iter()) {
dot += x * y;
norm_a += x * x;
norm_b += y * y;
}
let denom = norm_a.sqrt() * norm_b.sqrt();
if denom == 0.0 { 0.0 } else { dot / denom }
}
impl Retriever for InMemoryVectorStore {
async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<Document>> {
let texts = [query];
let query_embeddings = self.embedding_model.embed_erased(&texts).await?;
let query_vec = query_embeddings.into_iter().next().unwrap_or_default();
let entries = self.entries.read().await;
let mut scored: Vec<(f64, &StoredEntry)> = entries
.iter()
.map(|entry| {
let sim = cosine_similarity(&query_vec, &entry.embedding) as f64;
(sim, entry)
})
.collect();
scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
Ok(scored
.into_iter()
.take(top_k)
.map(|(score, entry)| entry.document.clone().with_score(score))
.collect())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_cosine_similarity_identical() {
let a = vec![1.0, 0.0, 0.0];
let b = vec![1.0, 0.0, 0.0];
let sim = cosine_similarity(&a, &b);
assert!((sim - 1.0).abs() < 1e-6);
}
#[test]
fn test_cosine_similarity_orthogonal() {
let a = vec![1.0, 0.0];
let b = vec![0.0, 1.0];
let sim = cosine_similarity(&a, &b);
assert!(sim.abs() < 1e-6);
}
#[test]
fn test_cosine_similarity_empty() {
assert_eq!(cosine_similarity(&[], &[]), 0.0);
}
}