Skip to main content

memory_mcp/index/
in_memory.rs

1use std::{collections::HashMap, path::Path, sync::Mutex};
2
3use crate::{
4    error::MemoryError,
5    index::VectorStore,
6    types::{Scope, ScopeFilter},
7};
8
9// ---------------------------------------------------------------------------
10// InMemoryState
11// ---------------------------------------------------------------------------
12
13struct InMemoryState {
14    /// All entries keyed by qualified name, storing (scope, vector).
15    entries: HashMap<String, (Scope, Vec<f32>)>,
16    /// Monotonic key counter.
17    key_counter: u64,
18    /// Maps qualified name → assigned key.
19    key_map: HashMap<String, u64>,
20    /// Stored commit SHA.
21    commit_sha: Option<String>,
22}
23
24// ---------------------------------------------------------------------------
25// InMemoryStore
26// ---------------------------------------------------------------------------
27
28/// HashMap-based `VectorStore` implementation for tests.
29///
30/// Performs brute-force cosine similarity search — no HNSW required.
31/// `save` and `load` are no-ops. `is_ready` is configurable.
32#[non_exhaustive]
33pub struct InMemoryStore {
34    state: Mutex<InMemoryState>,
35    dimensions: usize,
36    ready: bool,
37}
38
39impl InMemoryStore {
40    /// Create a new `InMemoryStore` with the given embedding dimensionality.
41    ///
42    /// `is_ready` defaults to `true`. Use [`InMemoryStore::with_ready`] to
43    /// override.
44    pub fn new(dimensions: usize) -> Self {
45        Self {
46            state: Mutex::new(InMemoryState {
47                entries: HashMap::new(),
48                key_counter: 0,
49                key_map: HashMap::new(),
50                commit_sha: None,
51            }),
52            dimensions,
53            ready: true,
54        }
55    }
56
57    /// Override the value returned by `is_ready()`.
58    pub fn with_ready(mut self, ready: bool) -> Self {
59        self.ready = ready;
60        self
61    }
62}
63
/// Compute the cosine similarity of two equal-length vectors.
///
/// Sums are accumulated in `f64`. Squaring an `f32` whose magnitude exceeds
/// roughly `1.8e19` overflows `f32`, and squaring one below roughly `1e-23`
/// underflows to zero; doing the arithmetic in `f32` would therefore turn the
/// norms into `inf`/`0` and the result into `NaN` or a spurious `0.0`.
///
/// # Returns
///
/// A value in `[-1.0, 1.0]`. `0.0` is returned when either vector has zero
/// magnitude (including the empty slice) and when the quotient is `NaN`
/// (which happens for `NaN` or infinite components), so the result is never
/// `NaN`.
///
/// If the slices differ in length, the dot product covers their common prefix
/// while each norm covers its whole slice.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    /// Sum of squares of `v`, widened to `f64` before multiplying.
    fn squared_norm(v: &[f32]) -> f64 {
        v.iter().map(|&x| f64::from(x) * f64::from(x)).sum()
    }

    let dot: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();
    let norm_a = squared_norm(a).sqrt();
    let norm_b = squared_norm(b).sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let cosine = dot / (norm_a * norm_b);
    if cosine.is_nan() {
        // Non-finite input: report "no similarity" rather than leaking NaN.
        return 0.0;
    }
    // Narrowing is safe: the clamped value lies in [-1, 1].
    cosine.clamp(-1.0, 1.0) as f32
}
77
/// Map a cosine similarity onto a distance in which smaller means closer.
///
/// The distance is `1 - similarity`, matching the usearch convention, so a
/// similarity of `1.0` becomes a distance of `0.0`.
fn similarity_to_distance(sim: f32) -> f32 {
    const MAX_SIMILARITY: f32 = 1.0;
    MAX_SIMILARITY - sim
}
83
84impl crate::index::sealed::Sealed for InMemoryStore {}
85
86impl VectorStore for InMemoryStore {
87    // Scope filtering relies on qualified_name encoding the scope (e.g.
88    // "global:foo" or "project:bar:baz"), which UsearchStoreInner guarantees.
89    // The scope parameter is stored alongside the entry for search filtering
90    // but is not used as a separate index key.
91    fn add(
92        &self,
93        scope: &Scope,
94        vector: &[f32],
95        qualified_name: String,
96    ) -> Result<u64, MemoryError> {
97        if vector.len() != self.dimensions {
98            return Err(MemoryError::InvalidInput {
99                reason: format!(
100                    "expected {} dimensions, got {}",
101                    self.dimensions,
102                    vector.len()
103                ),
104            });
105        }
106        let mut state = self
107            .state
108            .lock()
109            .expect("lock poisoned — prior panic corrupted state");
110        let key = state.key_counter;
111        state.key_counter = state
112            .key_counter
113            .checked_add(1)
114            .expect("key space exhausted");
115        state
116            .entries
117            .insert(qualified_name.clone(), (scope.clone(), vector.to_vec()));
118        state.key_map.insert(qualified_name, key);
119        Ok(key)
120    }
121
122    fn remove(&self, _scope: &Scope, qualified_name: &str) -> Result<(), MemoryError> {
123        let mut state = self
124            .state
125            .lock()
126            .expect("lock poisoned — prior panic corrupted state");
127        state.entries.remove(qualified_name);
128        state.key_map.remove(qualified_name);
129        Ok(())
130    }
131
132    fn search(
133        &self,
134        filter: &ScopeFilter,
135        query: &[f32],
136        limit: usize,
137    ) -> Result<Vec<(u64, String, f32)>, MemoryError> {
138        if query.len() != self.dimensions {
139            return Err(MemoryError::InvalidInput {
140                reason: format!(
141                    "expected {} dimensions, got {}",
142                    self.dimensions,
143                    query.len()
144                ),
145            });
146        }
147        let state = self
148            .state
149            .lock()
150            .expect("lock poisoned — prior panic corrupted state");
151
152        let mut candidates: Vec<(u64, String, f32)> = state
153            .entries
154            .iter()
155            .filter(|(_, (scope, _))| scope_matches(filter, scope))
156            .map(|(name, (_, vec))| {
157                let key = state
158                    .key_map
159                    .get(name)
160                    .copied()
161                    .expect("invariant: every entry has a key_map entry");
162                let sim = cosine_similarity(query, vec);
163                let dist = similarity_to_distance(sim);
164                (key, name.clone(), dist)
165            })
166            .collect();
167
168        // Sort by ascending distance (lower distance = more similar).
169        candidates.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal));
170        candidates.truncate(limit);
171        Ok(candidates)
172    }
173
174    fn find_by_name(&self, qualified_name: &str) -> Option<u64> {
175        let state = self
176            .state
177            .lock()
178            .expect("lock poisoned — prior panic corrupted state");
179        state.key_map.get(qualified_name).copied()
180    }
181
182    fn save(&self, _dir: &Path) -> Result<(), MemoryError> {
183        // No-op: InMemoryStore is an in-memory test double.
184        Ok(())
185    }
186
187    fn is_ready(&self) -> bool {
188        self.ready
189    }
190
191    fn dimensions(&self) -> usize {
192        self.dimensions
193    }
194
195    fn commit_sha(&self) -> Option<String> {
196        let state = self
197            .state
198            .lock()
199            .expect("lock poisoned — prior panic corrupted state");
200        state.commit_sha.clone()
201    }
202
203    fn set_commit_sha(&self, sha: Option<&str>) {
204        let mut state = self
205            .state
206            .lock()
207            .expect("lock poisoned — prior panic corrupted state");
208        state.commit_sha = sha.map(|s| s.to_owned());
209    }
210}
211
212/// Returns `true` if `scope` should be included given `filter`.
213fn scope_matches(filter: &ScopeFilter, scope: &Scope) -> bool {
214    match filter {
215        ScopeFilter::All => true,
216        ScopeFilter::GlobalOnly => matches!(scope, Scope::Global),
217        ScopeFilter::ProjectAndGlobal(project_name) => match scope {
218            Scope::Global => true,
219            Scope::Project(p) => p == project_name,
220        },
221    }
222}
223
224// ---------------------------------------------------------------------------
225// Tests
226// ---------------------------------------------------------------------------
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::VectorStore;

    /// Store sized for the 4-element fixture vectors below.
    fn make_store() -> InMemoryStore {
        InMemoryStore::new(4)
    }

    /// Unit vector along the first axis.
    fn vec_a() -> Vec<f32> {
        vec![1.0, 0.0, 0.0, 0.0]
    }

    /// Unit vector along the second axis (orthogonal to `vec_a`).
    fn vec_b() -> Vec<f32> {
        vec![0.0, 1.0, 0.0, 0.0]
    }

    /// Unit vector along the third axis (orthogonal to `vec_a` and `vec_b`).
    fn vec_c() -> Vec<f32> {
        vec![0.0, 0.0, 1.0, 0.0]
    }

    /// Qualified names of search hits, in result order.
    fn names(hits: &[(u64, String, f32)]) -> Vec<&str> {
        hits.iter().map(|(_, name, _)| name.as_str()).collect()
    }

    // TC-02a: add + find_by_name
    #[test]
    fn tc02a_add_and_find_by_name() {
        let store: &dyn VectorStore = &make_store();
        store
            .add(&Scope::Global, &vec_a(), "global/mem1".to_string())
            .expect("add failed");
        assert!(
            store.find_by_name("global/mem1").is_some(),
            "TC-02a: find_by_name should return Some after add"
        );
    }

    // TC-02b: remove clears entry
    #[test]
    fn tc02b_remove_clears_entry() {
        let store: &dyn VectorStore = &make_store();
        store
            .add(&Scope::Global, &vec_a(), "global/mem2".to_string())
            .expect("add failed");
        store
            .remove(&Scope::Global, "global/mem2")
            .expect("remove failed");
        assert!(
            store.find_by_name("global/mem2").is_none(),
            "TC-02b: find_by_name should return None after remove"
        );
    }

    // Removing a name that was never added is not an error.
    #[test]
    fn in_memory_store_remove_unknown_name_is_ok() {
        let store: &dyn VectorStore = &make_store();
        let result = store.remove(&Scope::Global, "global/never-added");
        assert!(result.is_ok(), "remove of unknown name should return Ok");
    }

    // TC-02c: search with GlobalOnly filter
    #[test]
    fn tc02c_search_global_only() {
        let store: &dyn VectorStore = &make_store();
        let proj = Scope::Project("p".to_string());

        store
            .add(&Scope::Global, &vec_a(), "global/g1".to_string())
            .expect("add global");
        store
            .add(&proj, &vec_b(), "projects/p/p1".to_string())
            .expect("add project");

        let results = store
            .search(&ScopeFilter::GlobalOnly, &vec_a(), 10)
            .expect("search failed");
        let found = names(&results);
        assert!(found.contains(&"global/g1"), "should contain global");
        assert!(
            !found.contains(&"projects/p/p1"),
            "should NOT contain project"
        );
    }

    // TC-02d: search with ProjectAndGlobal filter
    #[test]
    fn tc02d_search_project_and_global() {
        let store: &dyn VectorStore = &make_store();
        let proj_a = Scope::Project("alpha".to_string());
        let proj_b = Scope::Project("beta".to_string());

        store
            .add(&Scope::Global, &vec_a(), "global/g1".to_string())
            .expect("add global");
        store
            .add(&proj_a, &vec_b(), "projects/alpha/a1".to_string())
            .expect("add alpha");
        store
            .add(&proj_b, &vec_c(), "projects/beta/b1".to_string())
            .expect("add beta");

        let results = store
            .search(
                &ScopeFilter::ProjectAndGlobal("alpha".to_string()),
                &vec_a(),
                10,
            )
            .expect("search failed");
        let found = names(&results);
        assert!(found.contains(&"global/g1"), "should contain global");
        assert!(found.contains(&"projects/alpha/a1"), "should contain alpha");
        assert!(
            !found.contains(&"projects/beta/b1"),
            "should NOT contain beta"
        );
    }

    // TC-02e: search with All filter
    #[test]
    fn tc02e_search_all() {
        let store: &dyn VectorStore = &make_store();
        let proj = Scope::Project("foo".to_string());

        store
            .add(&Scope::Global, &vec_a(), "global/x".to_string())
            .expect("add global");
        store
            .add(&proj, &vec_b(), "projects/foo/y".to_string())
            .expect("add project");

        let results = store
            .search(&ScopeFilter::All, &vec_a(), 10)
            .expect("search failed");
        let found = names(&results);
        assert!(found.contains(&"global/x"), "all should include global");
        assert!(
            found.contains(&"projects/foo/y"),
            "all should include project"
        );
    }

    // TC-05c: InMemoryStore returns same MemoryError variants
    // With correctly sized vectors add/search/remove all succeed; the error
    // paths are covered by the dimension-mismatch tests further down.
    #[test]
    fn tc05c_in_memory_store_returns_ok_variants() {
        let store: &dyn VectorStore = &make_store();
        let result = store.add(&Scope::Global, &vec_a(), "global/tc05c".to_string());
        assert!(
            result.is_ok(),
            "TC-05c: add should return Ok, got: {:?}",
            result
        );
        let result = store.search(&ScopeFilter::All, &vec_a(), 5);
        assert!(result.is_ok(), "TC-05c: search should return Ok");
        let result = store.remove(&Scope::Global, "global/tc05c");
        assert!(result.is_ok(), "TC-05c: remove should return Ok");
    }

    // TC-06b: InMemoryStore::is_ready() returns configured value
    #[test]
    fn tc06b_in_memory_store_is_ready_default_true() {
        let store = InMemoryStore::new(4);
        assert!(
            store.is_ready(),
            "TC-06b: InMemoryStore::is_ready() should return true by default"
        );
    }

    #[test]
    fn tc06b_in_memory_store_is_ready_configured_false() {
        let store = InMemoryStore::new(4).with_ready(false);
        assert!(
            !store.is_ready(),
            "TC-06b: InMemoryStore::is_ready() should return false when configured so"
        );
    }

    #[test]
    fn in_memory_store_dimensions() {
        let store = InMemoryStore::new(128);
        assert_eq!(store.dimensions(), 128);
    }

    #[test]
    fn in_memory_store_commit_sha_round_trip() {
        let store: &dyn VectorStore = &InMemoryStore::new(4);
        assert!(store.commit_sha().is_none());
        store.set_commit_sha(Some("abc123"));
        assert_eq!(store.commit_sha(), Some("abc123".to_string()));
        store.set_commit_sha(None);
        assert!(store.commit_sha().is_none());
    }

    #[test]
    fn in_memory_store_save_is_noop() {
        let store: &dyn VectorStore = &make_store();
        let dir = tempfile::tempdir().expect("tempdir");
        store
            .add(&Scope::Global, &vec_a(), "global/save-test".to_string())
            .expect("add");
        let result = store.save(dir.path());
        assert!(result.is_ok(), "save should be a no-op Ok");
    }

    #[test]
    fn in_memory_store_search_results_sorted_by_distance() {
        let store: &dyn VectorStore = &make_store();
        // vec_a is [1,0,0,0]; searching for [1,0,0,0] should rank it first.
        store
            .add(&Scope::Global, &vec_a(), "global/closest".to_string())
            .expect("add a");
        store
            .add(&Scope::Global, &vec_b(), "global/farther".to_string())
            .expect("add b");

        let results = store
            .search(&ScopeFilter::All, &vec_a(), 10)
            .expect("search");
        assert_eq!(results.len(), 2);
        assert!(
            results[0].2 <= results[1].2,
            "results should be sorted by ascending distance"
        );
        assert_eq!(results[0].1, "global/closest");
        // Identical vectors: cosine distance = 1 - 1 = 0.
        assert!(results[0].2.abs() < f32::EPSILON);
        // Orthogonal vectors: cosine distance = 1 - 0 = 1.
        assert!((results[1].2 - 1.0).abs() < f32::EPSILON);
    }

    // `limit` caps the number of hits and keeps the closest ones.
    #[test]
    fn in_memory_store_search_respects_limit() {
        let store: &dyn VectorStore = &make_store();
        store
            .add(&Scope::Global, &vec_a(), "global/a".to_string())
            .expect("add a");
        store
            .add(&Scope::Global, &vec_b(), "global/b".to_string())
            .expect("add b");
        store
            .add(&Scope::Global, &vec_c(), "global/c".to_string())
            .expect("add c");

        let top = store
            .search(&ScopeFilter::All, &vec_a(), 1)
            .expect("search limit 1");
        assert_eq!(names(&top), vec!["global/a"]);

        let two = store
            .search(&ScopeFilter::All, &vec_a(), 2)
            .expect("search limit 2");
        assert_eq!(two.len(), 2);

        let none = store
            .search(&ScopeFilter::All, &vec_a(), 0)
            .expect("search limit 0");
        assert!(none.is_empty(), "limit 0 should return no hits");
    }

    #[test]
    fn tc05c_in_memory_store_dimension_mismatch_returns_invalid_input() {
        let store = InMemoryStore::new(4);
        let wrong_dims = vec![1.0_f32, 0.0]; // 2 dims, store expects 4
        let err = store
            .add(&Scope::Global, &wrong_dims, "global/bad-dims".to_string())
            .unwrap_err();
        assert!(
            matches!(err, MemoryError::InvalidInput { .. }),
            "TC-05c: dimension mismatch should return InvalidInput, got: {:?}",
            err
        );
        assert!(
            store.find_by_name("global/bad-dims").is_none(),
            "a rejected add must not register the name"
        );
    }

    // `search` applies the same dimension check as `add`.
    #[test]
    fn in_memory_store_search_dimension_mismatch_returns_invalid_input() {
        let store = InMemoryStore::new(4);
        let wrong_dims = vec![1.0_f32, 0.0]; // 2 dims, store expects 4
        let err = store
            .search(&ScopeFilter::All, &wrong_dims, 10)
            .unwrap_err();
        assert!(
            matches!(err, MemoryError::InvalidInput { .. }),
            "search dimension mismatch should return InvalidInput, got: {:?}",
            err
        );
    }

    #[test]
    fn in_memory_store_upsert_overwrites() {
        let store: &dyn VectorStore = &make_store();
        let name = "global/upsert-me".to_string();
        let key1 = store
            .add(&Scope::Global, &vec_a(), name.clone())
            .expect("first add");
        let key2 = store
            .add(&Scope::Global, &vec_b(), name.clone())
            .expect("second add");
        // Keys should differ (monotonic counter).
        assert_ne!(key1, key2);
        // The latest key wins in find_by_name.
        assert_eq!(store.find_by_name(&name), Some(key2));
    }

    // Re-adding a name replaces the entry instead of duplicating it.
    #[test]
    fn in_memory_store_upsert_keeps_single_entry() {
        let store: &dyn VectorStore = &make_store();
        let name = "global/upsert-once".to_string();
        store
            .add(&Scope::Global, &vec_a(), name.clone())
            .expect("first add");
        let key2 = store
            .add(&Scope::Global, &vec_b(), name.clone())
            .expect("second add");

        let results = store
            .search(&ScopeFilter::All, &vec_b(), 10)
            .expect("search");
        assert_eq!(results.len(), 1, "upsert must not duplicate the entry");
        assert_eq!(results[0].0, key2);
        assert_eq!(results[0].1, name);
        // The stored vector is the second one, so the distance to vec_b is 0.
        assert!(results[0].2.abs() < f32::EPSILON);
    }
}