1use std::collections::HashMap;
8
9use async_trait::async_trait;
10use tokio::sync::RwLock;
11
12use crate::document::{Chunk, SearchResult};
13use crate::error::{RagError, Result};
14use crate::vectorstore::VectorStore;
15
16#[derive(Debug, Default)]
30pub struct InMemoryVectorStore {
31 collections: RwLock<HashMap<String, HashMap<String, Chunk>>>,
32}
33
34impl InMemoryVectorStore {
35 pub fn new() -> Self {
37 Self::default()
38 }
39}
40
/// Computes the cosine similarity between two vectors.
///
/// Returns a value in `[-1.0, 1.0]` for well-formed finite input, and `0.0`
/// when the similarity is undefined:
///
/// * the slices have different lengths (comparing vectors of different
///   dimensionality is meaningless; previously the dot product was silently
///   truncated to the shorter slice while the norms used the full slices), or
/// * either vector has zero magnitude (this also covers empty slices).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Lengths match, so one pass can accumulate the dot product and both
    // squared norms without changing the order of the floating-point sums.
    let mut dot = 0.0_f32;
    let mut norm_a_sq = 0.0_f32;
    let mut norm_b_sq = 0.0_f32;
    for (x, y) in a.iter().zip(b) {
        dot += x * y;
        norm_a_sq += x * x;
        norm_b_sq += y * y;
    }

    let norm_a = norm_a_sq.sqrt();
    let norm_b = norm_b_sq.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}
54
55#[async_trait]
56impl VectorStore for InMemoryVectorStore {
57 async fn create_collection(&self, name: &str, _dimensions: usize) -> Result<()> {
58 let mut collections = self.collections.write().await;
59 collections.entry(name.to_string()).or_default();
60 Ok(())
61 }
62
63 async fn delete_collection(&self, name: &str) -> Result<()> {
64 let mut collections = self.collections.write().await;
65 collections.remove(name);
66 Ok(())
67 }
68
69 async fn upsert(&self, collection: &str, chunks: &[Chunk]) -> Result<()> {
70 let mut collections = self.collections.write().await;
71 let store = collections.get_mut(collection).ok_or_else(|| RagError::VectorStoreError {
72 backend: "InMemory".to_string(),
73 message: format!("collection '{collection}' does not exist"),
74 })?;
75 for chunk in chunks {
76 store.insert(chunk.id.clone(), chunk.clone());
77 }
78 Ok(())
79 }
80
81 async fn delete(&self, collection: &str, ids: &[&str]) -> Result<()> {
82 let mut collections = self.collections.write().await;
83 let store = collections.get_mut(collection).ok_or_else(|| RagError::VectorStoreError {
84 backend: "InMemory".to_string(),
85 message: format!("collection '{collection}' does not exist"),
86 })?;
87 for id in ids {
88 store.remove(*id);
89 }
90 Ok(())
91 }
92
93 async fn search(
94 &self,
95 collection: &str,
96 embedding: &[f32],
97 top_k: usize,
98 ) -> Result<Vec<SearchResult>> {
99 let collections = self.collections.read().await;
100 let store = collections.get(collection).ok_or_else(|| RagError::VectorStoreError {
101 backend: "InMemory".to_string(),
102 message: format!("collection '{collection}' does not exist"),
103 })?;
104
105 let mut scored: Vec<SearchResult> = store
106 .values()
107 .map(|chunk| {
108 let score = cosine_similarity(&chunk.embedding, embedding);
109 SearchResult { chunk: chunk.clone(), score }
110 })
111 .collect();
112
113 scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
114 scored.truncate(top_k);
115 Ok(scored)
116 }
117}