Skip to main content

ai_memory/
hnsw.rs

1// Copyright 2026 AlphaOne LLC
2// SPDX-License-Identifier: Apache-2.0
3
4//! HNSW (Hierarchical Navigable Small World) vector index for fast approximate
5//! nearest-neighbor search over memory embeddings.
6//!
7//! Built on `instant-distance`. The index is constructed at startup from all
8//! stored embeddings. New memories added during the session go into an overflow
9//! list that is scanned linearly alongside the HNSW results — the index is
10//! rebuilt lazily once the overflow exceeds a threshold.
11
12use instant_distance::{Builder, HnswMap, Point, Search};
13use std::sync::Mutex;
14
/// Number of overflow entries at which `insert` rebuilds the HNSW graph and
/// empties the overflow list.
const REBUILD_THRESHOLD: usize = 200;

/// Upper bound on retained entries. Once an insert pushes the total past this
/// value, the oldest entries are evicted so memory use stays bounded.
const MAX_ENTRIES: usize = 100_000;
20
/// Dense embedding vector stored as one point of the HNSW index.
///
/// The wrapped `Vec<f32>` is the raw embedding. The distance implementation in
/// this file treats it as L2-normalised.
#[derive(Debug, Clone)]
pub struct EmbeddingPoint(pub Vec<f32>);
24
25impl instant_distance::Point for EmbeddingPoint {
26    fn distance(&self, other: &Self) -> f32 {
27        // Cosine distance = 1 - cosine_similarity.
28        // Embeddings are L2-normalised so dot product = cosine similarity.
29        let dot: f32 = self.0.iter().zip(other.0.iter()).map(|(a, b)| a * b).sum();
30        1.0 - dot
31    }
32}
33
34/// Thread-safe HNSW index over memory embeddings.
35pub struct VectorIndex {
36    /// The built HNSW index — maps embedding points to memory IDs.
37    inner: Mutex<IndexState>,
38}
39
40struct IndexState {
41    hnsw: Option<HnswMap<EmbeddingPoint, String>>,
42    /// Entries added after the last rebuild. Searched linearly.
43    overflow: Vec<(String, Vec<f32>)>,
44    /// All entries (for rebuild). Kept in sync with the index + overflow.
45    all_entries: Vec<(String, Vec<f32>)>,
46}
47
/// One match returned by [`VectorIndex::search`].
#[derive(Clone, Debug)]
pub struct VectorHit {
    /// Memory ID of the matched entry.
    pub id: String,
    /// Distance between the query and the matched embedding; smaller is closer.
    pub distance: f32,
}
54
55impl VectorIndex {
56    /// Build a new index from a list of (`memory_id`, embedding) pairs.
57    pub fn build(entries: Vec<(String, Vec<f32>)>) -> Self {
58        let hnsw = Self::build_hnsw(&entries);
59        VectorIndex {
60            inner: Mutex::new(IndexState {
61                hnsw,
62                overflow: Vec::new(),
63                all_entries: entries,
64            }),
65        }
66    }
67
68    /// Build an empty index.
69    pub fn empty() -> Self {
70        VectorIndex {
71            inner: Mutex::new(IndexState {
72                hnsw: None,
73                overflow: Vec::new(),
74                all_entries: Vec::new(),
75            }),
76        }
77    }
78
79    fn build_hnsw(entries: &[(String, Vec<f32>)]) -> Option<HnswMap<EmbeddingPoint, String>> {
80        if entries.is_empty() {
81            return None;
82        }
83        let points: Vec<EmbeddingPoint> = entries
84            .iter()
85            .map(|(_, emb)| EmbeddingPoint(emb.clone()))
86            .collect();
87        let values: Vec<String> = entries.iter().map(|(id, _)| id.clone()).collect();
88        Some(Builder::default().build(points, values))
89    }
90
91    /// Add a new entry to the index (goes to overflow until next rebuild).
92    pub fn insert(&self, id: String, embedding: Vec<f32>) {
93        let mut state = match self.inner.lock() {
94            Ok(s) => s,
95            Err(poisoned) => poisoned.into_inner(),
96        };
97        state.all_entries.push((id.clone(), embedding.clone()));
98        state.overflow.push((id, embedding));
99
100        // Auto-rebuild if overflow is large
101        if state.overflow.len() >= REBUILD_THRESHOLD {
102            state.hnsw = Self::build_hnsw(&state.all_entries);
103            state.overflow.clear();
104        }
105
106        // Evict oldest entries if over capacity
107        if state.all_entries.len() > MAX_ENTRIES {
108            let excess = state.all_entries.len() - MAX_ENTRIES;
109            state.all_entries.drain(..excess);
110            state.hnsw = Self::build_hnsw(&state.all_entries);
111            state.overflow.clear();
112        }
113    }
114
115    /// Remove an entry by ID (marks for exclusion; cleaned up on rebuild).
116    pub fn remove(&self, id: &str) {
117        let mut state = match self.inner.lock() {
118            Ok(s) => s,
119            Err(poisoned) => poisoned.into_inner(),
120        };
121        state.all_entries.retain(|(eid, _)| eid != id);
122        state.overflow.retain(|(eid, _)| eid != id);
123        // Note: the HNSW index itself is immutable — removed IDs are filtered
124        // from search results. A rebuild will fully remove them.
125    }
126
127    /// Search for the `k` nearest neighbors to the query embedding.
128    ///
129    /// Combines HNSW approximate search with linear scan of overflow entries.
130    /// Returns results sorted by ascending distance (closest first).
131    pub fn search(&self, query: &[f32], k: usize) -> Vec<VectorHit> {
132        let state = match self.inner.lock() {
133            Ok(s) => s,
134            Err(poisoned) => poisoned.into_inner(),
135        };
136        let query_point = EmbeddingPoint(query.to_vec());
137
138        let mut results: Vec<VectorHit> = Vec::with_capacity(k * 2);
139
140        // Collect valid IDs from all_entries for filtering removed entries
141        let valid_ids: std::collections::HashSet<&str> = state
142            .all_entries
143            .iter()
144            .map(|(id, _)| id.as_str())
145            .collect();
146
147        // Search the HNSW index
148        if let Some(ref hnsw) = state.hnsw {
149            let mut search = Search::default();
150            for item in hnsw.search(&query_point, &mut search) {
151                if !valid_ids.contains(item.value.as_str()) {
152                    continue; // Removed entry
153                }
154                results.push(VectorHit {
155                    id: item.value.clone(),
156                    distance: item.distance,
157                });
158                if results.len() >= k * 2 {
159                    break;
160                }
161            }
162        }
163
164        // Linear scan of overflow entries
165        let mut overflow_hits: Vec<VectorHit> = state
166            .overflow
167            .iter()
168            .map(|(id, emb)| {
169                let point = EmbeddingPoint(emb.clone());
170                VectorHit {
171                    id: id.clone(),
172                    distance: query_point.distance(&point),
173                }
174            })
175            .collect();
176        overflow_hits.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
177
178        results.extend(overflow_hits);
179
180        // Deduplicate by ID (prefer lower distance)
181        let mut seen = std::collections::HashSet::new();
182        results.retain(|hit| seen.insert(hit.id.clone()));
183
184        // Sort by distance and truncate
185        results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
186        results.truncate(k);
187        results
188    }
189
190    /// Return the total number of indexed entries (HNSW + overflow).
191    pub fn len(&self) -> usize {
192        let state = match self.inner.lock() {
193            Ok(s) => s,
194            Err(poisoned) => poisoned.into_inner(),
195        };
196        state.all_entries.len()
197    }
198
199    /// Force a full rebuild of the HNSW index from all entries.
200    #[allow(dead_code)]
201    pub fn rebuild(&self) {
202        let mut state = match self.inner.lock() {
203            Ok(s) => s,
204            Err(poisoned) => poisoned.into_inner(),
205        };
206        state.hnsw = Self::build_hnsw(&state.all_entries);
207        state.overflow.clear();
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Scale `values` to unit L2 length.
    fn make_embedding(values: &[f32]) -> Vec<f32> {
        let squared_sum: f32 = values.iter().map(|v| v * v).sum();
        let norm = squared_sum.sqrt();
        values.iter().map(|v| v / norm).collect()
    }

    /// Normalised vector with a single non-zero component on axis `i % dim`.
    /// The raw magnitude is `1 + i * step`, so the L2 norm is not trivially 1.
    fn axis_embedding(i: usize, dim: usize, step: f32) -> Vec<f32> {
        let mut raw = vec![0.0_f32; dim];
        #[allow(clippy::cast_precision_loss)]
        let ordinal = i as f32;
        raw[i % dim] = 1.0 + ordinal * step;
        make_embedding(&raw)
    }

    /// IDs of `hits`, in result order.
    fn ids(hits: &[VectorHit]) -> Vec<&str> {
        hits.iter().map(|h| h.id.as_str()).collect()
    }

    #[test]
    fn empty_index_returns_empty() {
        let hits = VectorIndex::empty().search(&[1.0, 0.0, 0.0], 10);
        assert!(hits.is_empty());
    }

    #[test]
    fn basic_search() {
        let idx = VectorIndex::build(vec![
            ("a".into(), make_embedding(&[1.0, 0.0, 0.0])),
            ("b".into(), make_embedding(&[0.0, 1.0, 0.0])),
            ("c".into(), make_embedding(&[0.0, 0.0, 1.0])),
        ]);
        let hits = idx.search(&make_embedding(&[1.0, 0.1, 0.0]), 2);
        assert_eq!(hits.len(), 2);
        // "a" is the closest stored vector to [1, 0.1, 0].
        assert_eq!(hits[0].id, "a");
    }

    #[test]
    fn insert_and_search_overflow() {
        let idx = VectorIndex::build(vec![("a".into(), make_embedding(&[1.0, 0.0, 0.0]))]);
        idx.insert("b".into(), make_embedding(&[0.9, 0.1, 0.0]));
        let hits = idx.search(&make_embedding(&[1.0, 0.0, 0.0]), 2);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].id, "a");
        assert_eq!(hits[1].id, "b");
    }

    #[test]
    fn remove_excludes_from_results() {
        let idx = VectorIndex::build(vec![
            ("a".into(), make_embedding(&[1.0, 0.0, 0.0])),
            ("b".into(), make_embedding(&[0.9, 0.1, 0.0])),
        ]);
        idx.remove("a");
        let hits = idx.search(&make_embedding(&[1.0, 0.0, 0.0]), 5);
        assert!(hits.iter().all(|h| h.id != "a"));
    }

    // -----------------------------------------------------------------
    // W11/S11b — rebuild + batched-insert hardening
    // -----------------------------------------------------------------

    #[test]
    fn test_rebuild_preserves_all_entries() {
        // A dozen axis-aligned vectors in 16 dimensions: rebuild the index
        // and check every id is still returned by a top-k that covers them.
        let raw: Vec<(String, Vec<f32>)> = (0..12)
            .map(|i| (format!("id-{i}"), axis_embedding(i, 16, 0.01)))
            .collect();

        let idx = VectorIndex::build(raw.clone());
        idx.rebuild();
        assert_eq!(idx.len(), raw.len());

        // Asking for more hits than there are entries must surface every id.
        let query = make_embedding(&[1.0; 16]);
        let found: HashSet<String> = idx
            .search(&query, raw.len() * 2)
            .into_iter()
            .map(|h| h.id)
            .collect();
        for (id, _) in &raw {
            assert!(
                found.contains(id),
                "rebuild must preserve id {id}, found: {:?}",
                found
            );
        }
    }

    #[test]
    fn test_remove_then_search_excludes_id() {
        let query = make_embedding(&[1.0, 0.0, 0.0, 0.0]);
        let idx = VectorIndex::build(vec![
            ("alpha".into(), make_embedding(&[1.0, 0.0, 0.0, 0.0])),
            ("beta".into(), make_embedding(&[0.9, 0.1, 0.0, 0.0])),
            ("gamma".into(), make_embedding(&[0.8, 0.2, 0.0, 0.0])),
        ]);

        // Before removal, alpha is present in the results.
        let before = idx.search(&query, 5);
        assert!(before.iter().any(|h| h.id == "alpha"));

        idx.remove("alpha");

        // After removal, alpha must never come back, whatever k is.
        for k in 1..=10 {
            let hits = idx.search(&query, k);
            assert!(
                hits.iter().all(|h| h.id != "alpha"),
                "removed id `alpha` resurfaced with k={k}: {:?}",
                ids(&hits)
            );
        }

        // The remaining entries are still reachable.
        let after = idx.search(&query, 5);
        let remaining = ids(&after);
        assert!(remaining.contains(&"beta"));
        assert!(remaining.contains(&"gamma"));
    }

    // -----------------------------------------------------------------
    // W12-H — small edge cases
    // -----------------------------------------------------------------

    #[test]
    fn empty_index_len_is_zero() {
        assert_eq!(VectorIndex::empty().len(), 0);
    }

    #[test]
    fn build_with_empty_entries_search_empty() {
        let idx = VectorIndex::build(Vec::new());
        assert_eq!(idx.len(), 0);
        assert!(idx.search(&[1.0, 0.0, 0.0], 5).is_empty());
    }

    #[test]
    fn search_with_k_zero_returns_empty() {
        let idx = VectorIndex::build(vec![("a".into(), make_embedding(&[1.0, 0.0, 0.0]))]);
        let hits = idx.search(&make_embedding(&[1.0, 0.0, 0.0]), 0);
        assert!(hits.is_empty());
    }

    #[test]
    fn rebuild_on_empty_does_not_crash() {
        let idx = VectorIndex::empty();
        idx.rebuild();
        assert_eq!(idx.len(), 0);
    }

    #[test]
    fn insert_increases_len() {
        let idx = VectorIndex::empty();
        idx.insert("a".into(), make_embedding(&[1.0, 0.0, 0.0]));
        idx.insert("b".into(), make_embedding(&[0.0, 1.0, 0.0]));
        assert_eq!(idx.len(), 2);
    }

    #[test]
    fn embedding_point_distance_orthogonal() {
        let x_axis = EmbeddingPoint(vec![1.0, 0.0, 0.0]);
        let y_axis = EmbeddingPoint(vec![0.0, 1.0, 0.0]);
        // Orthogonal unit vectors: dot = 0, so distance = 1.
        let gap = x_axis.distance(&y_axis);
        assert!((gap - 1.0).abs() < 1e-6);
    }

    #[test]
    fn embedding_point_distance_identical_is_zero() {
        let point = EmbeddingPoint(make_embedding(&[1.0, 1.0, 1.0]));
        // A unit vector dotted with itself is 1, so distance = 0.
        assert!(point.distance(&point).abs() < 1e-6);
    }

    #[test]
    fn remove_on_empty_index_is_noop() {
        let idx = VectorIndex::empty();
        idx.remove("nonexistent");
        assert_eq!(idx.len(), 0);
    }

    #[test]
    fn insert_triggers_auto_rebuild_at_threshold() {
        // REBUILD_THRESHOLD is 200, so 205 inserts into a fresh index go
        // through the auto-rebuild branch of `insert`.
        let idx = VectorIndex::empty();
        for i in 0..205_usize {
            idx.insert(format!("id-{i}"), axis_embedding(i, 8, 0.001));
        }
        assert_eq!(idx.len(), 205);

        // Search keeps working after the automatic rebuild.
        let query = make_embedding(&[1.0_f32; 8]);
        assert_eq!(idx.search(&query, 5).len(), 5);
    }

    #[test]
    fn test_rebuild_after_batch_insert_settles() {
        // Insert a batch into an empty index, force a rebuild, then check
        // that a top-K search (K <= len) returns exactly K sorted, unique hits.
        let idx = VectorIndex::empty();
        let n = 25_usize;
        (0..n).for_each(|i| idx.insert(format!("id-{i}"), axis_embedding(i, 8, 0.001)));

        // The batch is below REBUILD_THRESHOLD, so rebuild explicitly.
        idx.rebuild();
        assert_eq!(idx.len(), n);

        let k = 5;
        let hits = idx.search(&make_embedding(&[1.0; 8]), k);
        assert_eq!(
            hits.len(),
            k,
            "post-rebuild search top-{k} must return exactly {k} hits, got {:?}",
            ids(&hits)
        );

        // Closest first: distances never decrease along the result list.
        for pair in hits.windows(2) {
            let (near, far) = (&pair[0], &pair[1]);
            assert!(
                near.distance <= far.distance,
                "search results must be ascending by distance: {} > {}",
                near.distance,
                far.distance
            );
        }

        // Each id appears at most once.
        let mut seen = HashSet::new();
        for hit in &hits {
            assert!(
                seen.insert(hit.id.clone()),
                "duplicate id in search: {}",
                hit.id
            );
        }
    }
}