Skip to main content

cognate_rag/
memory.rs

1//! In-memory vector store implementation.
2//!
3//! Suitable for prototyping and testing.  All documents are stored in a
4//! `Vec` protected by an `Arc<Mutex>` and similarity is computed as
5//! cosine similarity over `f32` vectors.
6
7use crate::{Document, Vector, VectorStore};
8use async_trait::async_trait;
9use std::sync::{Arc, Mutex};
10
11/// A simple in-memory [`VectorStore`].
12///
13/// Thread-safe via `Arc<Mutex<_>>` — suitable for concurrent Axum handlers.
14/// For production use, replace with a dedicated vector database.
15#[derive(Debug, Clone, Default)]
16pub struct MemoryVectorStore {
17    documents: Arc<Mutex<Vec<Document>>>,
18}
19
20impl MemoryVectorStore {
21    /// Create an empty store.
22    pub fn new() -> Self {
23        Self::default()
24    }
25
26    /// Return the number of stored documents.
27    pub fn len(&self) -> usize {
28        self.documents.lock().unwrap().len()
29    }
30
31    /// Return `true` if the store contains no documents.
32    pub fn is_empty(&self) -> bool {
33        self.len() == 0
34    }
35}
36
37#[async_trait]
38impl VectorStore for MemoryVectorStore {
39    async fn add_documents(
40        &self,
41        docs: Vec<Document>,
42    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
43        self.documents.lock().unwrap().extend(docs);
44        Ok(())
45    }
46
47    async fn search(
48        &self,
49        query_vector: Vector,
50        limit: usize,
51    ) -> Result<Vec<Document>, Box<dyn std::error::Error + Send + Sync>> {
52        let documents = self.documents.lock().unwrap();
53
54        let mut scored: Vec<(f32, &Document)> = documents
55            .iter()
56            .filter_map(|doc| {
57                doc.embedding
58                    .as_ref()
59                    .map(|emb| (cosine_similarity(&query_vector, emb), doc))
60            })
61            .collect();
62
63        // Sort descending by score — NaN is treated as -infinity so it sinks to the bottom.
64        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Less));
65
66        Ok(scored
67            .into_iter()
68            .take(limit)
69            .map(|(_, doc)| doc.clone())
70            .collect())
71    }
72}
73
/// Compute the cosine of the angle between `v1` and `v2`.
///
/// The inputs are expected to have the same length. If they differ, the dot
/// product only covers the overlapping prefix, while each norm is still taken
/// over the full slice.
///
/// Yields `0.0` when either input has a norm of zero (including empty slices).
fn cosine_similarity(v1: &[f32], v2: &[f32]) -> f32 {
    // Euclidean (L2) length of a slice.
    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();

    let len1 = magnitude(v1);
    let len2 = magnitude(v2);

    // Guard against dividing by zero for degenerate inputs.
    if len1 == 0.0 || len2 == 0.0 {
        return 0.0;
    }

    let inner_product: f32 = v1.iter().zip(v2).map(|(x, y)| x * y).sum();
    inner_product / (len1 * len2)
}
87
#[cfg(test)]
mod tests {
    use super::*;
    use crate::VectorStore;

    /// Build a [`Document`] with null metadata and the given embedding.
    fn doc(id: &str, content: &str, embedding: Vec<f32>) -> Document {
        Document {
            id: id.to_owned(),
            content: content.to_owned(),
            metadata: serde_json::Value::Null,
            embedding: Some(embedding),
        }
    }

    /// Create a fresh store and load `docs` into it.
    async fn store_with(docs: Vec<Document>) -> MemoryVectorStore {
        let store = MemoryVectorStore::new();
        store.add_documents(docs).await.unwrap();
        store
    }

    /// The single best hit must be the document pointing the same way as the query.
    #[tokio::test]
    async fn test_search_returns_closest() {
        let corpus = vec![
            doc("1", "close", vec![1.0, 0.0, 0.0]),
            doc("2", "far", vec![0.0, 1.0, 0.0]),
            doc("3", "medium", vec![0.7, 0.7, 0.0]),
        ];
        let store = store_with(corpus).await;

        let hits = store.search(vec![1.0, 0.0, 0.0], 1).await.unwrap();

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].content, "close");
    }

    /// Asking for two results out of three stored documents returns exactly two.
    #[tokio::test]
    async fn test_search_respects_limit() {
        let corpus = vec![
            doc("a", "a", vec![1.0, 0.0]),
            doc("b", "b", vec![0.8, 0.6]),
            doc("c", "c", vec![0.0, 1.0]),
        ];
        let store = store_with(corpus).await;

        let hits = store.search(vec![1.0, 0.0], 2).await.unwrap();

        assert_eq!(hits.len(), 2);
    }

    /// A zero-norm input must produce a similarity of exactly zero.
    #[test]
    fn test_cosine_zero_vector() {
        let similarity = cosine_similarity(&[0.0, 0.0], &[1.0, 0.0]);
        assert_eq!(similarity, 0.0);
    }
}