Skip to main content

engram/
vector_embedded.rs

//! `EmbeddedVectorStore` — in-process brute-force cosine-similarity store.
//!
//! Suitable for development, testing, and production use-cases with fewer than
//! ~100K vectors. All data lives in a `HashMap` protected by a `std::sync::RwLock`.
//! No I/O happens while the lock is held, so `tokio` tasks can call these methods
//! without waiting on external resources. Note, however, that `search` scores every
//! stored vector (an O(n × d) scan) while holding the read lock, so with large stores
//! a single search can occupy the calling task, and delay writers, noticeably.

9use crate::fact::FactId;
10use crate::scope::Scope;
11use crate::store::MemoryError;
12use crate::vector::{VectorFilter, VectorMatch, VectorStore};
13use async_trait::async_trait;
14use std::collections::HashMap;
15use std::sync::RwLock;

// ---------------------------------------------------------------------------
// Internal entry type
// ---------------------------------------------------------------------------

21struct VectorEntry {
22    embedding: Vec<f32>,
23    metadata: serde_json::Value,
24    /// Pre-computed L2 norm of `embedding`.
25    norm: f32,
26}

// ---------------------------------------------------------------------------
// EmbeddedVectorStore
// ---------------------------------------------------------------------------

32/// In-memory brute-force vector store.
33///
34/// # Construction
35///
36/// ```
37/// use engram::vector_embedded::EmbeddedVectorStore;
38///
39/// let store = EmbeddedVectorStore::new(128);
40/// ```
41pub struct EmbeddedVectorStore {
42    /// Expected dimensionality of every stored vector.
43    dimensions: usize,
44    entries: RwLock<HashMap<FactId, VectorEntry>>,
45}
46
47impl EmbeddedVectorStore {
48    /// Create a new store that accepts embeddings of `dimensions` dimensions.
49    pub fn new(dimensions: usize) -> Self {
50        Self {
51            dimensions,
52            entries: RwLock::new(HashMap::new()),
53        }
54    }
55
56    // -----------------------------------------------------------------------
57    // Private helpers
58    // -----------------------------------------------------------------------
59
60    /// Compute the L2 (Euclidean) norm of `v`.
61    fn compute_norm(v: &[f32]) -> f32 {
62        v.iter().map(|x| x * x).sum::<f32>().sqrt()
63    }
64
65    /// Cosine similarity between two pre-normalised vectors.
66    ///
67    /// Returns `0.0` when either norm is zero to avoid NaN propagation.
68    fn cosine_similarity(a: &[f32], a_norm: f32, b: &[f32], b_norm: f32) -> f32 {
69        if a_norm == 0.0 || b_norm == 0.0 {
70            return 0.0;
71        }
72        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
73        dot / (a_norm * b_norm)
74    }
75}

// ---------------------------------------------------------------------------
// VectorStore implementation
// ---------------------------------------------------------------------------

81#[async_trait]
82impl VectorStore for EmbeddedVectorStore {
83    async fn upsert(
84        &self,
85        id: FactId,
86        embedding: Vec<f32>,
87        metadata: serde_json::Value,
88    ) -> Result<(), MemoryError> {
89        if embedding.len() != self.dimensions {
90            return Err(MemoryError::Embedding(format!(
91                "dimension mismatch: expected {}, got {}",
92                self.dimensions,
93                embedding.len()
94            )));
95        }
96        let norm = Self::compute_norm(&embedding);
97        let entry = VectorEntry {
98            embedding,
99            metadata,
100            norm,
101        };
102        self.entries
103            .write()
104            .map_err(|e| MemoryError::Database(format!("lock poisoned: {e}")))?
105            .insert(id, entry);
106        Ok(())
107    }
108
109    async fn search(
110        &self,
111        query: &[f32],
112        filter: &VectorFilter,
113        top_k: usize,
114    ) -> Result<Vec<VectorMatch>, MemoryError> {
115        let query_norm = Self::compute_norm(query);
116        let min_score = filter.min_score.unwrap_or(f32::NEG_INFINITY);
117
118        let entries = self
119            .entries
120            .read()
121            .map_err(|e| MemoryError::Database(format!("lock poisoned: {e}")))?;
122
123        let mut matches: Vec<VectorMatch> = entries
124            .iter()
125            .filter_map(|(id, entry)| {
126                // Scope filtering: if a scope filter is set we check that the
127                // stored metadata contains a matching "scope" key. For now we
128                // do a best-effort match on the stored JSON value; a full
129                // implementation would store the Scope as a typed field.
130                if let Some(_filter_scope) = &filter.scope {
131                    // Scope-aware filtering requires metadata inspection.
132                    // The embedded store stores raw JSON metadata; callers that
133                    // need scope filtering should encode scope fields in the
134                    // metadata they pass to `upsert` and filter after retrieval.
135                    // We leave this as a no-op pass-through for now — scoped
136                    // deletes are handled by `delete_by_scope`.
137                }
138
139                let score =
140                    Self::cosine_similarity(query, query_norm, &entry.embedding, entry.norm);
141                if score < min_score {
142                    return None;
143                }
144                Some(VectorMatch {
145                    id: *id,
146                    score,
147                    metadata: entry.metadata.clone(),
148                })
149            })
150            .collect();
151
152        // Sort by descending similarity score; break ties by id for stability.
153        matches.sort_by(|a, b| {
154            b.score
155                .partial_cmp(&a.score)
156                .unwrap_or(std::cmp::Ordering::Equal)
157                .then_with(|| a.id.cmp(&b.id))
158        });
159        matches.truncate(top_k);
160        Ok(matches)
161    }
162
163    async fn delete(&self, id: FactId) -> Result<(), MemoryError> {
164        self.entries
165            .write()
166            .map_err(|e| MemoryError::Database(format!("lock poisoned: {e}")))?
167            .remove(&id);
168        Ok(())
169    }
170
171    async fn delete_by_scope(&self, _scope: &Scope) -> Result<u64, MemoryError> {
172        // A full implementation would inspect stored metadata to find entries
173        // that belong to `scope`. For now we clear all entries — sufficient for
174        // the current use-cases and easily replaced once metadata carries typed
175        // scope fields.
176        let mut entries = self
177            .entries
178            .write()
179            .map_err(|e| MemoryError::Database(format!("lock poisoned: {e}")))?;
180        let count = entries.len() as u64;
181        entries.clear();
182        Ok(count)
183    }
184}