Skip to main content

memory_mcp/
index.rs

use std::{
    collections::{hash_map::Entry, HashMap},
    path::Path,
    sync::{Mutex, RwLock},
};
6
7use usearch::{Index, IndexOptions, MetricKind, ScalarKind};
8
9use crate::{
10    error::MemoryError,
11    types::{validate_name, Scope, ScopeFilter},
12};
13
14// ---------------------------------------------------------------------------
15// VectorIndex
16// ---------------------------------------------------------------------------
17
/// Internal state kept behind the mutex.
///
/// All fields are mutated together under `VectorIndex::state`, so whenever the
/// lock is free the maps are (near-)inverse of each other and `next_key`
/// exceeds every key handed out by `add_with_next_key`.
struct VectorState {
    /// The underlying usearch HNSW index.
    index: Index,
    /// Maps usearch u64 keys → memory name strings.
    key_map: HashMap<u64, String>,
    /// Reverse map: memory name strings → usearch u64 keys (derived from key_map).
    /// After an upsert this may already point at a newer key while `key_map`
    /// still holds a stale entry for the old key — `remove` accounts for that.
    name_map: HashMap<String, u64>,
    /// Monotonic counter used to assign unique vector keys.
    next_key: u64,
    /// Commit SHA at the time this index was last saved/loaded.
    commit_sha: Option<String>,
}
30
/// Wraps `usearch::Index` and a key-map behind a single `std::sync::Mutex`.
///
/// `usearch::Index` is `Send + Sync`, and `HashMap` is `Send`, so
/// `VectorIndex` is `Send + Sync` via the mutex.
pub struct VectorIndex {
    // Every public method locks exactly once; `grow_if_needed_inner` takes an
    // already-locked `&VectorState` so composites never re-enter the mutex.
    state: Mutex<VectorState>,
}
38
39impl VectorIndex {
40    /// Initial capacity reserved when creating a new index.
41    const INITIAL_CAPACITY: usize = 1024;
42
43    /// Create a new HNSW index with cosine metric.
44    pub fn new(dimensions: usize) -> Result<Self, MemoryError> {
45        let options = IndexOptions {
46            dimensions,
47            metric: MetricKind::Cos,
48            quantization: ScalarKind::F32,
49            ..Default::default()
50        };
51        let index =
52            Index::new(&options).map_err(|e| MemoryError::Index(format!("create: {}", e)))?;
53        // usearch requires reserve() before any add() calls.
54        index
55            .reserve(Self::INITIAL_CAPACITY)
56            .map_err(|e| MemoryError::Index(format!("reserve: {}", e)))?;
57        Ok(Self {
58            state: Mutex::new(VectorState {
59                index,
60                key_map: HashMap::new(),
61                name_map: HashMap::new(),
62                next_key: 0,
63                commit_sha: None,
64            }),
65        })
66    }
67
68    /// Grow the index if it doesn't have room for `additional` more vectors.
69    ///
70    /// Operates on an already-locked `VectorState` reference so callers that
71    /// already hold the lock can call this without re-locking.
72    fn grow_if_needed_inner(state: &VectorState, additional: usize) -> Result<(), MemoryError> {
73        let current_capacity = state.index.capacity();
74        let current_size = state.index.size();
75        if current_size + additional > current_capacity {
76            let new_capacity = (current_capacity + additional).max(current_capacity * 2);
77            state
78                .index
79                .reserve(new_capacity)
80                .map_err(|e| MemoryError::Index(format!("reserve: {}", e)))?;
81        }
82        Ok(())
83    }
84
85    /// Ensure the index has capacity for at least `additional` more vectors.
86    pub fn grow_if_needed(&self, additional: usize) -> Result<(), MemoryError> {
87        let state = self
88            .state
89            .lock()
90            .expect("lock poisoned — prior panic corrupted state");
91        Self::grow_if_needed_inner(&state, additional)
92    }
93
94    /// Atomically increment and return the next unique vector key.
95    #[cfg(test)]
96    pub fn next_key(&self) -> u64 {
97        let mut state = self
98            .state
99            .lock()
100            .expect("lock poisoned — prior panic corrupted state");
101        let key = state.next_key;
102        state.next_key += 1;
103        key
104    }
105
106    /// Find the vector key associated with a qualified memory name.
107    pub fn find_key_by_name(&self, name: &str) -> Option<u64> {
108        let state = self
109            .state
110            .lock()
111            .expect("lock poisoned — prior panic corrupted state");
112        state.name_map.get(name).copied()
113    }
114
115    /// Add a vector under the given key, growing the index if necessary.
116    #[cfg(test)]
117    pub fn add(&self, key: u64, vector: &[f32], name: String) -> Result<(), MemoryError> {
118        let mut state = self
119            .state
120            .lock()
121            .expect("lock poisoned — prior panic corrupted state");
122        Self::grow_if_needed_inner(&state, 1)?;
123        state
124            .index
125            .add(key, vector)
126            .map_err(|e| MemoryError::Index(format!("add: {}", e)))?;
127        state.name_map.insert(name.clone(), key);
128        state.key_map.insert(key, name);
129        Ok(())
130    }
131
132    /// Atomically allocate the next key and add the vector in one lock acquisition.
133    /// Returns the assigned key on success. On failure the counter is not advanced.
134    pub fn add_with_next_key(&self, vector: &[f32], name: String) -> Result<u64, MemoryError> {
135        let mut state = self
136            .state
137            .lock()
138            .expect("lock poisoned — prior panic corrupted state");
139        Self::grow_if_needed_inner(&state, 1)?;
140        let key = state.next_key;
141        state
142            .index
143            .add(key, vector)
144            .map_err(|e| MemoryError::Index(format!("add: {}", e)))?;
145        state.name_map.insert(name.clone(), key);
146        state.key_map.insert(key, name);
147        state.next_key = state
148            .next_key
149            .checked_add(1)
150            .expect("vector key space exhausted");
151        Ok(key)
152    }
153
154    /// Search for the `limit` nearest neighbours of `query`.
155    ///
156    /// Returns `(key, distance)` pairs sorted by ascending distance.
157    pub fn search(
158        &self,
159        query: &[f32],
160        limit: usize,
161    ) -> Result<Vec<(u64, String, f32)>, MemoryError> {
162        let state = self
163            .state
164            .lock()
165            .expect("lock poisoned — prior panic corrupted state");
166        let matches = state
167            .index
168            .search(query, limit)
169            .map_err(|e| MemoryError::Index(format!("search: {}", e)))?;
170
171        let results = matches
172            .keys
173            .into_iter()
174            .zip(matches.distances)
175            .filter_map(|(key, dist)| {
176                state
177                    .key_map
178                    .get(&key)
179                    .map(|name| (key, name.clone(), dist))
180            })
181            .collect();
182        Ok(results)
183    }
184
185    /// Remove a vector by key.
186    pub fn remove(&self, key: u64) -> Result<(), MemoryError> {
187        let mut state = self
188            .state
189            .lock()
190            .expect("lock poisoned — prior panic corrupted state");
191        state
192            .index
193            .remove(key)
194            .map_err(|e| MemoryError::Index(format!("remove: {}", e)))?;
195        if let Some(name) = state.key_map.remove(&key) {
196            // Only remove from name_map if it still points to this key.
197            // An upsert may have already updated name_map to point to a newer key.
198            if state.name_map.get(&name).copied() == Some(key) {
199                state.name_map.remove(&name);
200            }
201        }
202        Ok(())
203    }
204
205    /// Return the commit SHA stored in the index metadata (if any).
206    pub fn commit_sha(&self) -> Option<String> {
207        let state = self
208            .state
209            .lock()
210            .expect("lock poisoned — prior panic corrupted state");
211        state.commit_sha.clone()
212    }
213
214    /// Set the commit SHA in the index metadata.
215    pub fn set_commit_sha(&self, sha: Option<&str>) {
216        let mut state = self
217            .state
218            .lock()
219            .expect("lock poisoned — prior panic corrupted state");
220        state.commit_sha = sha.map(|s| s.to_owned());
221    }
222
223    /// Persist the index to `path`. Also writes `<path>.keys.json`.
224    ///
225    /// If `commit_sha` is `Some`, it is written to the metadata alongside the
226    /// key map so the next load can verify freshness.
227    pub fn save(&self, path: &Path) -> Result<(), MemoryError> {
228        let path_str = path.to_str().ok_or_else(|| MemoryError::InvalidInput {
229            reason: "non-UTF-8 index path".to_string(),
230        })?;
231
232        let state = self
233            .state
234            .lock()
235            .expect("lock poisoned — prior panic corrupted state");
236        state
237            .index
238            .save(path_str)
239            .map_err(|e| MemoryError::Index(format!("save: {}", e)))?;
240
241        // Persist the key map and counter alongside the index.
242        let keys_path = format!("{}.keys.json", path_str);
243        let payload = serde_json::json!({
244            "key_map": &state.key_map,
245            "next_key": state.next_key,
246            "commit_sha": state.commit_sha,
247        });
248        let json = serde_json::to_string(&payload)
249            .map_err(|e| MemoryError::Index(format!("keymap serialise: {}", e)))?;
250        std::fs::write(&keys_path, json)?;
251
252        Ok(())
253    }
254
    /// Load an existing index from `path`. Also reads `<path>.keys.json`.
    ///
    /// The sidecar may be in the current format (`{key_map, next_key,
    /// commit_sha}`) or the legacy format (a bare key→name map). A missing
    /// sidecar yields empty maps; a missing `next_key` falls back to
    /// `max(key) + 1`.
    ///
    /// # Errors
    /// Returns `MemoryError` for a non-UTF-8 path, usearch init/load
    /// failures, sidecar I/O errors, or undeserialisable sidecar JSON.
    pub fn load(path: &Path) -> Result<Self, MemoryError> {
        let path_str = path.to_str().ok_or_else(|| MemoryError::InvalidInput {
            reason: "non-UTF-8 index path".to_string(),
        })?;

        // We need to know dimensions to create the IndexOptions for load.
        // usearch::Index::load() restores dimensions from the file, so we
        // use placeholder options here — they are overwritten on load.
        let options = IndexOptions {
            dimensions: 1, // overwritten by load()
            metric: MetricKind::Cos,
            quantization: ScalarKind::F32,
            ..Default::default()
        };
        let index = Index::new(&options)
            .map_err(|e| MemoryError::Index(format!("init for load: {}", e)))?;
        index
            .load(path_str)
            .map_err(|e| MemoryError::Index(format!("load: {}", e)))?;

        // Load the key map and counter from the sidecar, if present.
        let keys_path = format!("{}.keys.json", path_str);
        let (key_map, next_key, commit_sha): (HashMap<u64, String>, u64, Option<String>) =
            if std::path::Path::new(&keys_path).exists() {
                let json = std::fs::read_to_string(&keys_path)?;
                // Support both old format (bare HashMap) and new format ({key_map, next_key}).
                let value: serde_json::Value = serde_json::from_str(&json)
                    .map_err(|e| MemoryError::Index(format!("keymap deserialise: {}", e)))?;
                if value.is_object() && value.get("key_map").is_some() {
                    // New format: object with key_map / next_key / commit_sha.
                    let km: HashMap<u64, String> = serde_json::from_value(value["key_map"].clone())
                        .map_err(|e| MemoryError::Index(format!("keymap deserialise: {}", e)))?;
                    // Missing or non-integer next_key: derive it from the
                    // largest key seen so newly assigned keys never collide.
                    let nk: u64 = value["next_key"]
                        .as_u64()
                        .unwrap_or_else(|| km.keys().max().map(|k| k + 1).unwrap_or(0));
                    let sha: Option<String> = value
                        .get("commit_sha")
                        .and_then(|v| v.as_str())
                        .map(|s| s.to_string());
                    (km, nk, sha)
                } else {
                    // Legacy format: bare HashMap.
                    let km: HashMap<u64, String> = serde_json::from_value(value)
                        .map_err(|e| MemoryError::Index(format!("keymap deserialise: {}", e)))?;
                    let nk = km.keys().max().map(|k| k + 1).unwrap_or(0);
                    (km, nk, None)
                }
            } else {
                (HashMap::new(), 0, None)
            };

        // Rebuild the reverse map. Duplicate names collapse here, which the
        // length check below surfaces as a warning rather than an error.
        let name_map: HashMap<String, u64> = key_map.iter().map(|(&k, v)| (v.clone(), k)).collect();
        if key_map.len() != name_map.len() {
            tracing::warn!(
                key_map_len = key_map.len(),
                name_map_len = name_map.len(),
                "key_map and name_map have different sizes; index may contain duplicate names"
            );
        }

        Ok(Self {
            state: Mutex::new(VectorState {
                index,
                key_map,
                name_map,
                next_key,
                commit_sha,
            }),
        })
    }
325}
326
327// ---------------------------------------------------------------------------
328// ScopedIndex
329// ---------------------------------------------------------------------------
330
/// Manages multiple `VectorIndex` instances — one per scope (global, each
/// project) plus a combined "all" index. Every memory exists in exactly two
/// indexes: its scope-specific index + the "all" index.
///
/// `ScopedIndex` is `Send + Sync` because all inner state is protected by
/// `RwLock` / `Mutex`.
pub struct ScopedIndex {
    /// Per-scope indexes (global + each project). Project entries are
    /// created lazily by `add` the first time a scope is written to.
    scopes: RwLock<HashMap<Scope, VectorIndex>>,
    /// Combined index containing all vectors.
    all: VectorIndex,
    /// Embedding dimensions (needed to create new scope indexes).
    dimensions: usize,
}
345
346// Locking order: `scopes` (RwLock) is always acquired before any
347// `VectorIndex::state` (Mutex). Never hold a VectorIndex Mutex while
348// acquiring `scopes`. The `all` index is accessed directly (not through
349// `scopes`), but always while `scopes` is already held or after it has
350// been released — never in the reverse order.
351
352impl ScopedIndex {
353    /// Create a new `ScopedIndex` with empty global + all indexes.
354    pub fn new(dimensions: usize) -> Result<Self, MemoryError> {
355        let global = VectorIndex::new(dimensions)?;
356        let all = VectorIndex::new(dimensions)?;
357        let mut scopes = HashMap::new();
358        scopes.insert(Scope::Global, global);
359        Ok(Self {
360            scopes: RwLock::new(scopes),
361            all,
362            dimensions,
363        })
364    }
365
366    /// Insert `vector` into both the scope-specific index and the all-index.
367    ///
368    /// Handles upserts: if `qualified_name` already exists in either index, the
369    /// old entry is removed after the new one is successfully inserted.
370    ///
371    /// Returns the key assigned in the all-index.
372    pub fn add(
373        &self,
374        scope: &Scope,
375        vector: &[f32],
376        qualified_name: String,
377    ) -> Result<u64, MemoryError> {
378        // Write lock serialises the full find→insert→remove composite so
379        // concurrent upserts for the same name cannot interleave. Reads
380        // (via `search`) use a read lock and are not blocked by other reads.
381        let mut scopes = self.scopes.write().expect("scopes lock poisoned");
382
383        // Ensure scope index exists (inline, since we already hold write lock).
384        if !scopes.contains_key(scope) {
385            scopes.insert(scope.clone(), VectorIndex::new(self.dimensions)?);
386        }
387
388        let scope_idx = scopes
389            .get(scope)
390            .expect("scope index must exist after insert");
391
392        // Capture old keys before inserting new ones.
393        let old_scope_key = scope_idx.find_key_by_name(&qualified_name);
394        let old_all_key = self.all.find_key_by_name(&qualified_name);
395
396        // Insert into scope index first.
397        let new_scope_key = scope_idx.add_with_next_key(vector, qualified_name.clone())?;
398
399        // Insert into all-index; if this fails, roll back scope insert.
400        // Note: the rollback path is not unit-tested because usearch allocation
401        // failures are not injectable without a mock layer. The logic is simple
402        // (remove the key we just inserted) and covered by VectorIndex::remove's
403        // existing tests.
404        let all_key = match self.all.add_with_next_key(vector, qualified_name) {
405            Ok(key) => key,
406            Err(e) => {
407                let _ = scope_idx.remove(new_scope_key);
408                return Err(e);
409            }
410        };
411
412        // Both succeeded — now clean up old entries.
413        if let Some(key) = old_scope_key {
414            let _ = scope_idx.remove(key);
415        }
416        if let Some(key) = old_all_key {
417            let _ = self.all.remove(key);
418        }
419
420        Ok(all_key)
421    }
422
423    /// Remove a memory by qualified name from both the scope-specific index
424    /// and the all-index.
425    ///
426    /// Both removals are best-effort: an error in one does not prevent the
427    /// other from running. Returns `Ok(())` regardless of individual failures.
428    pub fn remove(&self, scope: &Scope, qualified_name: &str) -> Result<(), MemoryError> {
429        // Write lock serialises with concurrent adds for the same name.
430        let scopes = self.scopes.write().expect("scopes lock poisoned");
431
432        // Remove from scope index (best-effort).
433        if let Some(scope_idx) = scopes.get(scope) {
434            if let Some(key) = scope_idx.find_key_by_name(qualified_name) {
435                if let Err(e) = scope_idx.remove(key) {
436                    tracing::warn!(
437                        qualified_name = %qualified_name,
438                        error = %e,
439                        "scope index removal failed; continuing to all-index"
440                    );
441                }
442            }
443        }
444
445        // Remove from all-index (best-effort).
446        if let Some(key) = self.all.find_key_by_name(qualified_name) {
447            if let Err(e) = self.all.remove(key) {
448                tracing::warn!(
449                    qualified_name = %qualified_name,
450                    error = %e,
451                    "all-index removal failed"
452                );
453            }
454        }
455
456        Ok(())
457    }
458
459    /// Search for the nearest neighbours of `query`, routing to the correct
460    /// indexes based on `filter`.
461    ///
462    /// | `filter`               | Indexes searched          | Merge strategy             |
463    /// |------------------------|---------------------------|----------------------------|
464    /// | `GlobalOnly`           | `global`                  | Direct top-k               |
465    /// | `ProjectAndGlobal(p)`  | `global` + `projects/p`   | Merge by distance, top-k   |
466    /// | `All`                  | `all` combined index       | Direct top-k               |
467    pub fn search(
468        &self,
469        filter: &ScopeFilter,
470        query: &[f32],
471        limit: usize,
472    ) -> Result<Vec<(u64, String, f32)>, MemoryError> {
473        match filter {
474            ScopeFilter::All => self.all.search(query, limit),
475
476            ScopeFilter::GlobalOnly => {
477                let scopes = self.scopes.read().expect("scopes lock poisoned");
478                match scopes.get(&Scope::Global) {
479                    Some(global_idx) => global_idx.search(query, limit),
480                    None => Ok(Vec::new()),
481                }
482            }
483
484            ScopeFilter::ProjectAndGlobal(project_name) => {
485                let scopes = self.scopes.read().expect("scopes lock poisoned");
486                let project_scope = Scope::Project(project_name.clone());
487
488                let mut combined: Vec<(u64, String, f32)> = Vec::new();
489
490                if let Some(global_idx) = scopes.get(&Scope::Global) {
491                    let mut global_results = global_idx.search(query, limit)?;
492                    combined.append(&mut global_results);
493                }
494
495                if let Some(proj_idx) = scopes.get(&project_scope) {
496                    let mut proj_results = proj_idx.search(query, limit)?;
497                    combined.append(&mut proj_results);
498                }
499
500                // Deduplicate by qualified name (HashSet ensures non-adjacent dupes are caught).
501                let mut seen = std::collections::HashSet::new();
502                combined.retain(|(_, name, _)| seen.insert(name.clone()));
503                // Sort by ascending distance and take top-k.
504                combined.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal));
505                combined.truncate(limit);
506                Ok(combined)
507            }
508        }
509    }
510
511    /// Find the key for a given qualified name in the **all-index** (not scope-specific).
512    ///
513    /// This is the canonical lookup — the all-index contains every memory regardless of scope.
514    pub fn find_key_by_name(&self, qualified_name: &str) -> Option<u64> {
515        self.all.find_key_by_name(qualified_name)
516    }
517
518    /// Grow all indexes to accommodate `additional` more vectors.
519    ///
520    /// Reserved for future batch-insert operations; no production callers currently exist.
521    #[allow(dead_code)]
522    pub fn grow_if_needed(&self, additional: usize) -> Result<(), MemoryError> {
523        self.all.grow_if_needed(additional)?;
524        let scopes = self.scopes.read().expect("scopes lock poisoned");
525        for idx in scopes.values() {
526            idx.grow_if_needed(additional)?;
527        }
528        Ok(())
529    }
530
531    /// Persist all indexes to subdirectories under `dir`.
532    ///
533    /// Layout:
534    /// ```text
535    /// dir/
536    ///   all/index.usearch  (+ .keys.json)
537    ///   global/index.usearch
538    ///   projects/foo/index.usearch
539    /// ```
540    pub fn save(&self, dir: &Path) -> Result<(), MemoryError> {
541        std::fs::create_dir_all(dir)?;
542
543        // Write a dirty marker — if we crash mid-save, the next load will see
544        // this and ignore commit SHAs (forcing a fresh rebuild).
545        let marker = dir.join(".save-in-progress");
546        std::fs::write(&marker, b"")?;
547
548        // Persist all-index.
549        let all_dir = dir.join("all");
550        std::fs::create_dir_all(&all_dir)?;
551        self.all.save(&all_dir.join("index.usearch"))?;
552
553        // Persist per-scope indexes.
554        let scopes = self.scopes.read().expect("scopes lock poisoned");
555        for (scope, idx) in scopes.iter() {
556            let scope_dir = dir.join(scope.dir_prefix());
557            std::fs::create_dir_all(&scope_dir)?;
558            idx.save(&scope_dir.join("index.usearch"))?;
559        }
560
561        // Remove marker — save completed successfully.
562        let _ = std::fs::remove_file(&marker);
563
564        Ok(())
565    }
566
    /// Load all indexes from subdirectories under `dir`.
    ///
    /// Missing subdirectories are treated as empty — those scopes will be
    /// rebuilt incrementally on next use.
    ///
    /// # Errors
    /// Fails on per-index load errors, unreadable `projects/` entries, or a
    /// non-UTF-8 project directory name. A stale dirty marker is not an
    /// error: it discards the on-disk indexes and returns a fresh instance.
    pub fn load(dir: &Path, dimensions: usize) -> Result<Self, MemoryError> {
        // If a previous save was interrupted, the on-disk state may be
        // inconsistent (some indexes from current state, others from prior).
        // Rather than loading mixed data, start fresh — indexes are a cache
        // that can always be rebuilt from the source-of-truth markdown files.
        let dirty_marker = dir.join(".save-in-progress");
        if dirty_marker.exists() {
            tracing::warn!("detected interrupted index save — discarding indexes");
            let _ = std::fs::remove_file(&dirty_marker);
            return Self::new(dimensions);
        }

        // Load all-index (empty if it was never saved).
        let all_path = dir.join("all").join("index.usearch");
        let all = if all_path.exists() {
            VectorIndex::load(&all_path)?
        } else {
            VectorIndex::new(dimensions)?
        };

        let mut scopes: HashMap<Scope, VectorIndex> = HashMap::new();

        // Load global index (always present in `scopes`, even if empty).
        let global_path = dir.join("global").join("index.usearch");
        let global = if global_path.exists() {
            VectorIndex::load(&global_path)?
        } else {
            VectorIndex::new(dimensions)?
        };
        scopes.insert(Scope::Global, global);

        // Scan for project indexes under projects/*/
        let projects_dir = dir.join("projects");
        if projects_dir.is_dir() {
            let entries = std::fs::read_dir(&projects_dir)
                .map_err(|e| MemoryError::Index(format!("read projects dir: {}", e)))?;
            for entry in entries {
                let entry =
                    entry.map_err(|e| MemoryError::Index(format!("read dir entry: {}", e)))?;
                let path = entry.path();
                if path.is_dir() {
                    let project_name = path
                        .file_name()
                        .and_then(|n| n.to_str())
                        .map(|s| s.to_string())
                        .ok_or_else(|| {
                            MemoryError::Index("non-UTF-8 project directory name".to_string())
                        })?;
                    // Directory names come from the filesystem, so validate
                    // before turning them into Scope values; bad names are
                    // skipped (warn) rather than failing the whole load.
                    if let Err(e) = validate_name(&project_name) {
                        tracing::warn!(
                            project_name = %project_name,
                            error = %e,
                            "skipping project index with invalid name"
                        );
                        continue;
                    }
                    let index_path = path.join("index.usearch");
                    if index_path.exists() {
                        let idx = VectorIndex::load(&index_path)?;
                        scopes.insert(Scope::Project(project_name), idx);
                    }
                }
            }
        }

        Ok(Self {
            scopes: RwLock::new(scopes),
            all,
            dimensions,
        })
    }
642
643    /// Read the commit SHA from the all-index metadata.
644    pub fn commit_sha(&self) -> Option<String> {
645        self.all.commit_sha()
646    }
647
648    /// Set the commit SHA on all sub-indexes.
649    pub fn set_commit_sha(&self, sha: Option<&str>) {
650        self.all.set_commit_sha(sha);
651        let scopes = self.scopes.read().expect("scopes lock poisoned");
652        for idx in scopes.values() {
653            idx.set_commit_sha(sha);
654        }
655    }
656}
657
658// ---------------------------------------------------------------------------
659// Tests
660// ---------------------------------------------------------------------------
661
662#[cfg(test)]
663mod tests {
664    use super::*;
665
666    fn make_index() -> VectorIndex {
667        VectorIndex::new(4).expect("failed to create index")
668    }
669
670    fn dummy_vec() -> Vec<f32> {
671        vec![1.0, 0.0, 0.0, 0.0]
672    }
673
674    /// Verify that `remove(old_key)` does NOT clobber `name_map` when an
675    /// upsert has already updated `name_map` to point to a newer key.
676    ///
677    /// Pattern: add_with_next_key("name") → old_key
678    ///          add_with_next_key("name") → new_key  (name_map now points to new_key)
679    ///          remove(old_key)
680    ///          find_key_by_name("name") must return new_key (not None)
681    #[test]
682    fn remove_old_key_does_not_clobber_upserted_name_map_entry() {
683        let index = make_index();
684        let v = dummy_vec();
685
686        // First insert — establishes old_key.
687        let old_key = index
688            .add_with_next_key(&v, "global/foo".to_string())
689            .expect("first add failed");
690
691        // Upsert (second insert for same name) — name_map now points to new_key.
692        let new_key = index
693            .add_with_next_key(&v, "global/foo".to_string())
694            .expect("second add failed");
695
696        assert_ne!(old_key, new_key, "keys must differ");
697
698        // Remove the OLD key — should not disturb name_map's entry for new_key.
699        index.remove(old_key).expect("remove failed");
700
701        // name_map must still resolve "global/foo" to new_key.
702        assert_eq!(
703            index.find_key_by_name("global/foo"),
704            Some(new_key),
705            "name_map entry for new_key was incorrectly removed"
706        );
707    }
708
709    /// Removing the current (only) key should clear the name_map entry.
710    #[test]
711    fn remove_only_key_clears_name_map() {
712        let index = make_index();
713        let v = dummy_vec();
714
715        let key = index
716            .add_with_next_key(&v, "global/bar".to_string())
717            .expect("add failed");
718
719        index.remove(key).expect("remove failed");
720
721        assert_eq!(
722            index.find_key_by_name("global/bar"),
723            None,
724            "name_map entry should have been cleared"
725        );
726    }
727
728    // -----------------------------------------------------------------------
729    // ScopedIndex tests
730    // -----------------------------------------------------------------------
731
732    fn make_scoped() -> ScopedIndex {
733        ScopedIndex::new(8).expect("failed to create scoped index")
734    }
735
736    fn vec_a() -> Vec<f32> {
737        vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
738    }
739
740    fn vec_b() -> Vec<f32> {
741        vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
742    }
743
744    fn vec_c() -> Vec<f32> {
745        vec![0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
746    }
747
748    #[test]
749    fn scoped_index_add_inserts_into_scope_and_all() {
750        let si = make_scoped();
751        let scope = Scope::Global;
752        let name = "global/memory-a".to_string();
753
754        si.add(&scope, &vec_a(), name.clone()).expect("add failed");
755
756        // Should be findable in the all-index via find_key_by_name.
757        assert!(
758            si.find_key_by_name(&name).is_some(),
759            "should be in all-index"
760        );
761
762        // Should also be in scope-specific index — verify via search.
763        let results = si
764            .search(&ScopeFilter::GlobalOnly, &vec_a(), 5)
765            .expect("search failed");
766        assert!(
767            results.iter().any(|(_, n, _)| n == &name),
768            "should be found in global search"
769        );
770    }
771
772    #[test]
773    fn scoped_index_remove_removes_from_both() {
774        let si = make_scoped();
775        let scope = Scope::Global;
776        let name = "global/memory-rm".to_string();
777
778        si.add(&scope, &vec_a(), name.clone()).expect("add failed");
779        assert!(si.find_key_by_name(&name).is_some(), "should exist");
780
781        si.remove(&scope, &name).expect("remove failed");
782
783        assert!(
784            si.find_key_by_name(&name).is_none(),
785            "should be gone from all-index"
786        );
787
788        let results = si
789            .search(&ScopeFilter::GlobalOnly, &vec_a(), 5)
790            .expect("search failed");
791        assert!(
792            !results.iter().any(|(_, n, _)| n == &name),
793            "should not appear in global search after removal"
794        );
795    }
796
797    #[test]
798    fn scoped_index_search_global_only() {
799        let si = make_scoped();
800        let proj = Scope::Project("myproj".to_string());
801
802        si.add(&Scope::Global, &vec_a(), "global/mem-global".to_string())
803            .expect("add global failed");
804        si.add(&proj, &vec_b(), "projects/myproj/mem-proj".to_string())
805            .expect("add project failed");
806
807        let results = si
808            .search(&ScopeFilter::GlobalOnly, &vec_a(), 5)
809            .expect("search failed");
810
811        let names: Vec<&str> = results.iter().map(|(_, n, _)| n.as_str()).collect();
812        assert!(
813            names.contains(&"global/mem-global"),
814            "should contain global"
815        );
816        assert!(
817            !names.contains(&"projects/myproj/mem-proj"),
818            "should NOT contain project memory"
819        );
820    }
821
822    #[test]
823    fn scoped_index_search_project_and_global() {
824        let si = make_scoped();
825        let proj_a = Scope::Project("alpha".to_string());
826        let proj_b = Scope::Project("beta".to_string());
827
828        si.add(&Scope::Global, &vec_a(), "global/g1".to_string())
829            .expect("add global failed");
830        si.add(&proj_a, &vec_b(), "projects/alpha/a1".to_string())
831            .expect("add alpha failed");
832        si.add(&proj_b, &vec_c(), "projects/beta/b1".to_string())
833            .expect("add beta failed");
834
835        let results = si
836            .search(
837                &ScopeFilter::ProjectAndGlobal("alpha".to_string()),
838                &vec_a(),
839                10,
840            )
841            .expect("search failed");
842
843        let names: Vec<&str> = results.iter().map(|(_, n, _)| n.as_str()).collect();
844        assert!(names.contains(&"global/g1"), "should contain global");
845        assert!(names.contains(&"projects/alpha/a1"), "should contain alpha");
846        assert!(
847            !names.contains(&"projects/beta/b1"),
848            "should NOT contain beta"
849        );
850    }
851
852    #[test]
853    fn scoped_index_search_all() {
854        let si = make_scoped();
855        let proj = Scope::Project("foo".to_string());
856
857        si.add(&Scope::Global, &vec_a(), "global/x".to_string())
858            .expect("add global");
859        si.add(&proj, &vec_b(), "projects/foo/y".to_string())
860            .expect("add project");
861
862        let results = si
863            .search(&ScopeFilter::All, &vec_a(), 10)
864            .expect("search failed");
865
866        let names: Vec<&str> = results.iter().map(|(_, n, _)| n.as_str()).collect();
867        assert!(names.contains(&"global/x"), "all should include global");
868        assert!(
869            names.contains(&"projects/foo/y"),
870            "all should include project"
871        );
872    }
873
874    #[test]
875    fn scoped_index_upsert_replaces_old_entry() {
876        let si = make_scoped();
877        let name = "global/memo".to_string();
878        si.add(&Scope::Global, &vec_a(), name.clone()).unwrap();
879        si.add(&Scope::Global, &vec_b(), name.clone()).unwrap();
880        // Should have exactly one entry in all-index search.
881        let results = si.search(&ScopeFilter::All, &vec_b(), 10).unwrap();
882        assert_eq!(
883            results.iter().filter(|(_, n, _)| n == &name).count(),
884            1,
885            "upsert should leave exactly one entry for the name"
886        );
887    }
888
889    #[test]
890    fn scoped_index_dirty_marker_discards_indexes() {
891        let dir = tempfile::tempdir().expect("tempdir");
892        let si = ScopedIndex::new(8).expect("create");
893        si.add(&Scope::Global, &vec_a(), "global/test-mem".to_string())
894            .expect("add");
895        si.set_commit_sha(Some("abc123"));
896        si.save(dir.path()).expect("save");
897
898        // Simulate interrupted save by re-creating the marker.
899        std::fs::write(dir.path().join(".save-in-progress"), b"").unwrap();
900
901        // Load should discard all indexes and return fresh empty ones.
902        let loaded = ScopedIndex::load(dir.path(), 8).expect("load");
903        assert!(
904            loaded.commit_sha().is_none(),
905            "dirty marker should result in no SHA"
906        );
907        assert!(
908            loaded.find_key_by_name("global/test-mem").is_none(),
909            "dirty marker should discard all indexed data"
910        );
911        assert!(
912            !dir.path().join(".save-in-progress").exists(),
913            "marker should be cleaned up"
914        );
915    }
916
917    #[test]
918    fn scoped_index_save_load_round_trip() {
919        let dir = tempfile::tempdir().expect("tempdir");
920        let si = ScopedIndex::new(8).expect("create");
921        let proj = Scope::Project("rtrip".to_string());
922
923        si.add(&Scope::Global, &vec_a(), "global/rt-global".to_string())
924            .expect("add global");
925        si.add(&proj, &vec_b(), "projects/rtrip/rt-proj".to_string())
926            .expect("add project");
927
928        si.save(dir.path()).expect("save failed");
929
930        let loaded = ScopedIndex::load(dir.path(), 8).expect("load failed");
931
932        // Verify all-index finds both memories.
933        assert!(
934            loaded.find_key_by_name("global/rt-global").is_some(),
935            "global memory should survive round-trip"
936        );
937        assert!(
938            loaded.find_key_by_name("projects/rtrip/rt-proj").is_some(),
939            "project memory should survive round-trip"
940        );
941
942        // Verify search still works after reload.
943        let results = loaded
944            .search(
945                &ScopeFilter::ProjectAndGlobal("rtrip".to_string()),
946                &vec_a(),
947                10,
948            )
949            .expect("search failed");
950        let names: Vec<&str> = results.iter().map(|(_, n, _)| n.as_str()).collect();
951        assert!(names.contains(&"global/rt-global"));
952        assert!(names.contains(&"projects/rtrip/rt-proj"));
953    }
954
955    #[test]
956    fn scoped_index_same_short_name_different_scopes_coexist() {
957        let si = make_scoped();
958        si.add(&Scope::Global, &vec_a(), "global/foo".to_string())
959            .unwrap();
960        si.add(
961            &Scope::Project("p".into()),
962            &vec_b(),
963            "projects/p/foo".to_string(),
964        )
965        .unwrap();
966        assert!(si.find_key_by_name("global/foo").is_some());
967        assert!(si.find_key_by_name("projects/p/foo").is_some());
968        assert_ne!(
969            si.find_key_by_name("global/foo"),
970            si.find_key_by_name("projects/p/foo"),
971            "different scopes should have distinct keys"
972        );
973    }
974}