Skip to main content

adk_rag/
inmemory.rs

1//! In-memory vector store using cosine similarity.
2//!
//! This module provides [`InMemoryVectorStore`], a vector store that needs no
//! external service: it is backed by a `HashMap` protected by a
//! `tokio::sync::RwLock`. It is suitable for development, testing, and
//! small-scale use cases.
6
7use std::collections::HashMap;
8
9use async_trait::async_trait;
10use tokio::sync::RwLock;
11
12use crate::document::{Chunk, SearchResult};
13use crate::error::{RagError, Result};
14use crate::vectorstore::VectorStore;
15
/// An in-memory vector store using cosine similarity for search.
///
/// Collections are stored as nested `HashMap`s: collection name → chunk ID → chunk.
/// All access goes through a `tokio::sync::RwLock`, so searches take a shared
/// (read) lock while mutations take an exclusive (write) lock.
///
/// Only whole [`Chunk`]s are kept; the embedding dimension passed to
/// `create_collection` is not recorded anywhere in this structure.
///
/// # Example
///
/// ```rust,ignore
/// use adk_rag::{InMemoryVectorStore, VectorStore};
///
/// let store = InMemoryVectorStore::new();
/// store.create_collection("docs", 384).await?;
/// ```
#[derive(Debug, Default)]
pub struct InMemoryVectorStore {
    /// Collection name → (chunk ID → chunk), guarded by an async read/write lock.
    collections: RwLock<HashMap<String, HashMap<String, Chunk>>>,
}
33
34impl InMemoryVectorStore {
35    /// Create a new empty in-memory vector store.
36    pub fn new() -> Self {
37        Self::default()
38    }
39}
40
/// Compute the cosine similarity between two vectors.
///
/// The result is the dot product of `a` and `b` divided by the product of
/// their Euclidean (L2) norms. All sums are accumulated in `f64` so that
/// squaring very large or very small `f32` components neither overflows to
/// infinity nor underflows to zero.
///
/// If the slices differ in length, only the overlapping prefix contributes to
/// the dot product while each norm covers its whole slice, which is the same
/// as zero-padding the shorter vector.
///
/// Returns `0.0` if either vector has zero magnitude (including empty
/// slices) or if the similarity is not a finite number, e.g. because a
/// component is NaN or infinite. Otherwise the result lies in `[-1.0, 1.0]`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Sum of squares of one vector, widened to f64 before multiplying.
    let squared_norm = |v: &[f32]| -> f64 {
        v.iter()
            .map(|&x| {
                let x = f64::from(x);
                x * x
            })
            .sum()
    };

    let dot: f64 = a.iter().zip(b.iter()).map(|(&x, &y)| f64::from(x) * f64::from(y)).sum();
    let denom = squared_norm(a).sqrt() * squared_norm(b).sqrt();

    // Zero magnitude, or NaN/inf anywhere in the inputs: the similarity is
    // undefined, so report "no similarity" instead of leaking NaN to callers
    // that sort by score.
    if denom == 0.0 || !denom.is_finite() || !dot.is_finite() {
        return 0.0;
    }

    // Rounding can push the ratio marginally outside the mathematical range.
    // The narrowing cast is safe: the value is already within [-1, 1].
    (dot / denom).clamp(-1.0, 1.0) as f32
}
54
55#[async_trait]
56impl VectorStore for InMemoryVectorStore {
57    async fn create_collection(&self, name: &str, _dimensions: usize) -> Result<()> {
58        let mut collections = self.collections.write().await;
59        collections.entry(name.to_string()).or_default();
60        Ok(())
61    }
62
63    async fn delete_collection(&self, name: &str) -> Result<()> {
64        let mut collections = self.collections.write().await;
65        collections.remove(name);
66        Ok(())
67    }
68
69    async fn upsert(&self, collection: &str, chunks: &[Chunk]) -> Result<()> {
70        let mut collections = self.collections.write().await;
71        let store = collections.get_mut(collection).ok_or_else(|| RagError::VectorStoreError {
72            backend: "InMemory".to_string(),
73            message: format!("collection '{collection}' does not exist"),
74        })?;
75        for chunk in chunks {
76            store.insert(chunk.id.clone(), chunk.clone());
77        }
78        Ok(())
79    }
80
81    async fn delete(&self, collection: &str, ids: &[&str]) -> Result<()> {
82        let mut collections = self.collections.write().await;
83        let store = collections.get_mut(collection).ok_or_else(|| RagError::VectorStoreError {
84            backend: "InMemory".to_string(),
85            message: format!("collection '{collection}' does not exist"),
86        })?;
87        for id in ids {
88            store.remove(*id);
89        }
90        Ok(())
91    }
92
93    async fn search(
94        &self,
95        collection: &str,
96        embedding: &[f32],
97        top_k: usize,
98    ) -> Result<Vec<SearchResult>> {
99        let collections = self.collections.read().await;
100        let store = collections.get(collection).ok_or_else(|| RagError::VectorStoreError {
101            backend: "InMemory".to_string(),
102            message: format!("collection '{collection}' does not exist"),
103        })?;
104
105        let mut scored: Vec<SearchResult> = store
106            .values()
107            .map(|chunk| {
108                let score = cosine_similarity(&chunk.embedding, embedding);
109                SearchResult { chunk: chunk.clone(), score }
110            })
111            .collect();
112
113        scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
114        scored.truncate(top_k);
115        Ok(scored)
116    }
117}