Skip to main content

ai_memory/
hnsw.rs

1// Copyright 2026 AlphaOne LLC
2// SPDX-License-Identifier: Apache-2.0
3
//! HNSW (Hierarchical Navigable Small World) vector index for fast approximate
//! nearest-neighbor search over memory embeddings.
//!
//! Built on `instant-distance`. The index is constructed at startup from all
//! stored embeddings. New memories added during the session go into an overflow
//! list that is scanned linearly alongside the HNSW results — the index is
//! rebuilt synchronously inside `insert` as soon as the overflow reaches
//! `REBUILD_THRESHOLD` entries.
11
12use instant_distance::{Builder, HnswMap, Point, Search};
13use std::sync::Mutex;
14use std::sync::atomic::{AtomicU64, Ordering};
15
/// Overflow length at which `insert` rebuilds the HNSW graph. The rebuild
/// fires as soon as the overflow list holds this many entries (`>=`).
const REBUILD_THRESHOLD: usize = 200;

/// Cap on the total number of stored entries. When an insert pushes the
/// count above this value (`>`), the oldest entries are evicted to prevent
/// unbounded memory growth.
const MAX_ENTRIES: usize = 100_000;
21
22// ---------------------------------------------------------------------------
23// v0.6.3.1 (P3, G2): eviction observability
24//
25// `MAX_ENTRIES`-triggered eviction in `insert()` previously dropped the
26// oldest embeddings silently — operators near the cap lost recall quality
27// invisibly. The two counters below + the structured `hnsw.eviction`
28// tracing event close that gap:
29//
30//   - `INDEX_EVICTIONS_TOTAL` — cumulative count surfaced via
31//     `db::stats().index_evictions_total` (and capabilities).
32//   - `LAST_EVICTION_AT_NANOS` — wall-clock UNIX nanoseconds of the most
33//     recent eviction; capabilities derive `hnsw.evicted_recently` from
34//     this with a 60 s rolling window.
35//
36// Process-local. The counters reset on restart because the index itself
37// resets on restart. Both atomics are touched only on the eviction edge
38// (rare: requires >100k vectors), so there is no measurable hot-path cost.
39// ---------------------------------------------------------------------------
40
// Cumulative number of entries evicted by the `MAX_ENTRIES` cap.
static INDEX_EVICTIONS_TOTAL: AtomicU64 = AtomicU64::new(0);
// UNIX-epoch nanoseconds of the latest eviction; 0 means "never evicted".
static LAST_EVICTION_AT_NANOS: AtomicU64 = AtomicU64::new(0);
43
44/// Cumulative HNSW oldest-eviction count since process start.
45///
46/// Surfaces in `memory_stats`. Non-zero indicates the in-memory vector
47/// index has hit `MAX_ENTRIES` and dropped older embeddings; recall
48/// quality may have degraded for evicted ids until they are re-inserted
49/// (e.g. on next access via `recall` touch path).
50#[must_use]
51pub fn index_evictions_total() -> u64 {
52    INDEX_EVICTIONS_TOTAL.load(Ordering::Relaxed)
53}
54
55/// Whether an eviction occurred within the trailing `window_secs`.
56///
57/// Used by capabilities (P1) to set `hnsw.evicted_recently` so operators
58/// can see ongoing pressure on the cap, not just the cumulative count.
59/// Returns `false` when no evictions have ever happened in this process.
60#[must_use]
61pub fn evicted_recently(window_secs: u64) -> bool {
62    let last = LAST_EVICTION_AT_NANOS.load(Ordering::Relaxed);
63    if last == 0 {
64        return false;
65    }
66    let now_nanos = std::time::SystemTime::now()
67        .duration_since(std::time::UNIX_EPOCH)
68        .map(|d| d.as_nanos())
69        .unwrap_or(0);
70    // Saturating math: clock can move backwards on some VMs.
71    let elapsed_nanos = u128::from(u64::MAX).min(now_nanos.saturating_sub(u128::from(last)));
72    elapsed_nanos < u128::from(window_secs).saturating_mul(1_000_000_000)
73}
74
75/// Reset the eviction counters. Test-only — production callers must not
76/// reach into the counter directly. The function is `pub` (rather than
77/// `pub(crate)`) so the integration-test crate at `tests/` can drive it
78/// alongside the public `index_evictions_total()` accessor; renaming
79/// keeps the intent obvious at every call site.
80#[doc(hidden)]
81pub fn reset_eviction_counters_for_test() {
82    INDEX_EVICTIONS_TOTAL.store(0, Ordering::Relaxed);
83    LAST_EVICTION_AT_NANOS.store(0, Ordering::Relaxed);
84}
85
86/// A point in the HNSW index — wraps a dense embedding vector.
87#[derive(Clone, Debug)]
88pub struct EmbeddingPoint(pub Vec<f32>);
89
90impl instant_distance::Point for EmbeddingPoint {
91    fn distance(&self, other: &Self) -> f32 {
92        // Cosine distance = 1 - cosine_similarity.
93        // Embeddings are L2-normalised so dot product = cosine similarity.
94        let dot: f32 = self.0.iter().zip(other.0.iter()).map(|(a, b)| a * b).sum();
95        1.0 - dot
96    }
97}
98
/// Thread-safe HNSW index over memory embeddings.
///
/// All state sits behind one [`Mutex`]; every method takes `&self` and locks
/// it for the duration of the call.
pub struct VectorIndex {
    /// Mutex-guarded index state: the built HNSW graph (if any), the
    /// overflow list, and the full entry list used for rebuilds.
    inner: Mutex<IndexState>,
}
104
/// Mutable state of a [`VectorIndex`], only ever accessed under its mutex.
struct IndexState {
    /// The built HNSW graph mapping embedding points to memory ids.
    /// `None` when the graph was last built from an empty entry list.
    hnsw: Option<HnswMap<EmbeddingPoint, String>>,
    /// Entries added after the last rebuild. Searched linearly.
    overflow: Vec<(String, Vec<f32>)>,
    /// All entries (for rebuild). Kept in sync with the index + overflow.
    all_entries: Vec<(String, Vec<f32>)>,
}
112
/// A search result from the vector index.
#[derive(Debug, Clone)]
pub struct VectorHit {
    /// Memory id of the matching entry.
    pub id: String,
    /// Distance between the query and this entry (smaller is closer).
    pub distance: f32,
}
119
120impl VectorIndex {
121    /// Build a new index from a list of (`memory_id`, embedding) pairs.
122    pub fn build(entries: Vec<(String, Vec<f32>)>) -> Self {
123        let hnsw = Self::build_hnsw(&entries);
124        VectorIndex {
125            inner: Mutex::new(IndexState {
126                hnsw,
127                overflow: Vec::new(),
128                all_entries: entries,
129            }),
130        }
131    }
132
133    /// Build an empty index.
134    pub fn empty() -> Self {
135        VectorIndex {
136            inner: Mutex::new(IndexState {
137                hnsw: None,
138                overflow: Vec::new(),
139                all_entries: Vec::new(),
140            }),
141        }
142    }
143
144    fn build_hnsw(entries: &[(String, Vec<f32>)]) -> Option<HnswMap<EmbeddingPoint, String>> {
145        if entries.is_empty() {
146            return None;
147        }
148        let points: Vec<EmbeddingPoint> = entries
149            .iter()
150            .map(|(_, emb)| EmbeddingPoint(emb.clone()))
151            .collect();
152        let values: Vec<String> = entries.iter().map(|(id, _)| id.clone()).collect();
153        Some(Builder::default().build(points, values))
154    }
155
156    /// Add a new entry to the index (goes to overflow until next rebuild).
157    pub fn insert(&self, id: String, embedding: Vec<f32>) {
158        let mut state = match self.inner.lock() {
159            Ok(s) => s,
160            Err(poisoned) => poisoned.into_inner(),
161        };
162        state.all_entries.push((id.clone(), embedding.clone()));
163        state.overflow.push((id, embedding));
164
165        // Auto-rebuild if overflow is large
166        if state.overflow.len() >= REBUILD_THRESHOLD {
167            state.hnsw = Self::build_hnsw(&state.all_entries);
168            state.overflow.clear();
169        }
170
171        // Evict oldest entries if over capacity
172        if state.all_entries.len() > MAX_ENTRIES {
173            let excess = state.all_entries.len() - MAX_ENTRIES;
174            // v0.6.3.1 (P3, G2): emit one structured tracing event per evicted
175            // id BEFORE we drop the rows so operators can post-mortem which
176            // memories lost their semantic-search affordance. Bumping the
177            // counter and last-eviction timestamp surfaces aggregate pressure
178            // through `memory_stats` and capabilities. The drain itself is
179            // unchanged — observability only.
180            for (evicted_id, _) in state.all_entries.iter().take(excess) {
181                tracing::warn!(
182                    target: "hnsw.eviction",
183                    evicted_id = %evicted_id,
184                    reason = "max_entries_reached",
185                    max_entries = MAX_ENTRIES,
186                    "hnsw index evicting oldest entry: cap reached"
187                );
188            }
189            #[allow(clippy::cast_possible_truncation)]
190            let evicted = excess as u64;
191            INDEX_EVICTIONS_TOTAL.fetch_add(evicted, Ordering::Relaxed);
192
193            state.all_entries.drain(..excess);
194            state.hnsw = Self::build_hnsw(&state.all_entries);
195            state.overflow.clear();
196
197            // Record completion time AFTER the rebuild. `evicted_recently` is
198            // a "did we evict in the trailing N seconds" check; an operator
199            // asking that wants the operation completion time, not the
200            // start. At cap, build_hnsw dominates wall time (~minutes at
201            // 100k entries) — using the start would make evicted_recently
202            // misreport even immediately after insert returns.
203            let now_nanos = std::time::SystemTime::now()
204                .duration_since(std::time::UNIX_EPOCH)
205                .map(|d| d.as_nanos())
206                .unwrap_or(0);
207            let now_nanos_u64 = u64::try_from(now_nanos).unwrap_or(u64::MAX);
208            LAST_EVICTION_AT_NANOS.store(now_nanos_u64, Ordering::Relaxed);
209        }
210    }
211
212    /// Remove an entry by ID (marks for exclusion; cleaned up on rebuild).
213    pub fn remove(&self, id: &str) {
214        let mut state = match self.inner.lock() {
215            Ok(s) => s,
216            Err(poisoned) => poisoned.into_inner(),
217        };
218        state.all_entries.retain(|(eid, _)| eid != id);
219        state.overflow.retain(|(eid, _)| eid != id);
220        // Note: the HNSW index itself is immutable — removed IDs are filtered
221        // from search results. A rebuild will fully remove them.
222    }
223
224    /// Search for the `k` nearest neighbors to the query embedding.
225    ///
226    /// Combines HNSW approximate search with linear scan of overflow entries.
227    /// Returns results sorted by ascending distance (closest first).
228    pub fn search(&self, query: &[f32], k: usize) -> Vec<VectorHit> {
229        let state = match self.inner.lock() {
230            Ok(s) => s,
231            Err(poisoned) => poisoned.into_inner(),
232        };
233        let query_point = EmbeddingPoint(query.to_vec());
234
235        let mut results: Vec<VectorHit> = Vec::with_capacity(k * 2);
236
237        // Collect valid IDs from all_entries for filtering removed entries
238        let valid_ids: std::collections::HashSet<&str> = state
239            .all_entries
240            .iter()
241            .map(|(id, _)| id.as_str())
242            .collect();
243
244        // Search the HNSW index
245        if let Some(ref hnsw) = state.hnsw {
246            let mut search = Search::default();
247            for item in hnsw.search(&query_point, &mut search) {
248                if !valid_ids.contains(item.value.as_str()) {
249                    continue; // Removed entry
250                }
251                results.push(VectorHit {
252                    id: item.value.clone(),
253                    distance: item.distance,
254                });
255                if results.len() >= k * 2 {
256                    break;
257                }
258            }
259        }
260
261        // Linear scan of overflow entries
262        let mut overflow_hits: Vec<VectorHit> = state
263            .overflow
264            .iter()
265            .map(|(id, emb)| {
266                let point = EmbeddingPoint(emb.clone());
267                VectorHit {
268                    id: id.clone(),
269                    distance: query_point.distance(&point),
270                }
271            })
272            .collect();
273        overflow_hits.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
274
275        results.extend(overflow_hits);
276
277        // Deduplicate by ID (prefer lower distance)
278        let mut seen = std::collections::HashSet::new();
279        results.retain(|hit| seen.insert(hit.id.clone()));
280
281        // Sort by distance and truncate
282        results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
283        results.truncate(k);
284        results
285    }
286
287    /// Return the total number of indexed entries (HNSW + overflow).
288    pub fn len(&self) -> usize {
289        let state = match self.inner.lock() {
290            Ok(s) => s,
291            Err(poisoned) => poisoned.into_inner(),
292        };
293        state.all_entries.len()
294    }
295
296    /// Force a full rebuild of the HNSW index from all entries.
297    #[allow(dead_code)]
298    pub fn rebuild(&self) {
299        let mut state = match self.inner.lock() {
300            Ok(s) => s,
301            Err(poisoned) => poisoned.into_inner(),
302        };
303        state.hnsw = Self::build_hnsw(&state.all_entries);
304        state.overflow.clear();
305    }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// Scale `values` to unit L2 length.
    fn make_embedding(values: &[f32]) -> Vec<f32> {
        let squared_len: f32 = values.iter().map(|v| v * v).sum();
        let norm = squared_len.sqrt();
        values.iter().map(|v| v / norm).collect()
    }

    /// Normalised `dim`-dimensional vector whose only non-zero component sits
    /// on axis `i % dim`, with raw magnitude `1 + i * step` before scaling.
    fn biased_axis_embedding(i: usize, dim: usize, step: f32) -> Vec<f32> {
        let mut raw = vec![0.0_f32; dim];
        #[allow(clippy::cast_precision_loss)]
        let offset = i as f32;
        raw[i % dim] = 1.0 + offset * step;
        make_embedding(&raw)
    }

    #[test]
    fn empty_index_returns_empty() {
        let hits = VectorIndex::empty().search(&[1.0, 0.0, 0.0], 10);
        assert!(hits.is_empty());
    }

    #[test]
    fn basic_search() {
        let idx = VectorIndex::build(vec![
            (String::from("a"), make_embedding(&[1.0, 0.0, 0.0])),
            (String::from("b"), make_embedding(&[0.0, 1.0, 0.0])),
            (String::from("c"), make_embedding(&[0.0, 0.0, 1.0])),
        ]);
        let hits = idx.search(&make_embedding(&[1.0, 0.1, 0.0]), 2);
        assert_eq!(hits.len(), 2);
        // "a" is the closest axis to [1, 0.1, 0].
        assert_eq!(hits[0].id, "a");
    }

    #[test]
    fn insert_and_search_overflow() {
        let idx = VectorIndex::build(vec![(String::from("a"), make_embedding(&[1.0, 0.0, 0.0]))]);
        idx.insert(String::from("b"), make_embedding(&[0.9, 0.1, 0.0]));

        let hits = idx.search(&make_embedding(&[1.0, 0.0, 0.0]), 2);
        let ids: Vec<&str> = hits.iter().map(|h| h.id.as_str()).collect();
        assert_eq!(ids, ["a", "b"]);
    }

    #[test]
    fn remove_excludes_from_results() {
        let idx = VectorIndex::build(vec![
            (String::from("a"), make_embedding(&[1.0, 0.0, 0.0])),
            (String::from("b"), make_embedding(&[0.9, 0.1, 0.0])),
        ]);
        idx.remove("a");
        let hits = idx.search(&make_embedding(&[1.0, 0.0, 0.0]), 5);
        assert!(!hits.iter().any(|h| h.id == "a"));
    }

    // -----------------------------------------------------------------
    // W11/S11b — rebuild + batched-insert hardening
    // -----------------------------------------------------------------

    #[test]
    fn test_rebuild_preserves_all_entries() {
        // Twelve single-axis vectors with slightly different raw magnitudes.
        // After an explicit rebuild every id must still be reachable through
        // a search whose top-k covers the whole set.
        let raw: Vec<(String, Vec<f32>)> = (0..12)
            .map(|i| (format!("id-{i}"), biased_axis_embedding(i, 16, 0.01)))
            .collect();

        let idx = VectorIndex::build(raw.clone());
        idx.rebuild();
        assert_eq!(idx.len(), raw.len());

        // Ask for more hits than there are entries so nothing is cut off.
        let query = make_embedding(&[1.0; 16]);
        let found: std::collections::HashSet<String> = idx
            .search(&query, raw.len() * 2)
            .into_iter()
            .map(|h| h.id)
            .collect();
        for (id, _) in &raw {
            assert!(
                found.contains(id),
                "rebuild must preserve id {id}, found: {:?}",
                found
            );
        }
    }

    #[test]
    fn test_remove_then_search_excludes_id() {
        let probe = make_embedding(&[1.0, 0.0, 0.0, 0.0]);
        let idx = VectorIndex::build(vec![
            (String::from("alpha"), make_embedding(&[1.0, 0.0, 0.0, 0.0])),
            (String::from("beta"), make_embedding(&[0.9, 0.1, 0.0, 0.0])),
            (String::from("gamma"), make_embedding(&[0.8, 0.2, 0.0, 0.0])),
        ]);

        // Before removal alpha is present in the results.
        let before = idx.search(&probe, 5);
        assert!(before.iter().any(|h| h.id == "alpha"));

        idx.remove("alpha");

        // After removal alpha must not come back for any k.
        for k in 1..=10 {
            let hits = idx.search(&probe, k);
            assert!(
                hits.iter().all(|h| h.id != "alpha"),
                "removed id `alpha` resurfaced with k={k}: {:?}",
                hits.iter().map(|h| &h.id).collect::<Vec<_>>()
            );
        }

        // The remaining entries are still reachable.
        let after = idx.search(&probe, 5);
        let ids: Vec<&str> = after.iter().map(|h| h.id.as_str()).collect();
        assert!(ids.contains(&"beta"));
        assert!(ids.contains(&"gamma"));
    }

    // -----------------------------------------------------------------
    // W12-H — small edge cases
    // -----------------------------------------------------------------

    #[test]
    fn empty_index_len_is_zero() {
        assert_eq!(VectorIndex::empty().len(), 0);
    }

    #[test]
    fn build_with_empty_entries_search_empty() {
        let idx = VectorIndex::build(Vec::new());
        assert_eq!(idx.len(), 0);
        assert!(idx.search(&[1.0, 0.0, 0.0], 5).is_empty());
    }

    #[test]
    fn search_with_k_zero_returns_empty() {
        let idx = VectorIndex::build(vec![(String::from("a"), make_embedding(&[1.0, 0.0, 0.0]))]);
        let hits = idx.search(&make_embedding(&[1.0, 0.0, 0.0]), 0);
        assert!(hits.is_empty());
    }

    #[test]
    fn rebuild_on_empty_does_not_crash() {
        let idx = VectorIndex::empty();
        idx.rebuild();
        assert_eq!(idx.len(), 0);
    }

    #[test]
    fn insert_increases_len() {
        let idx = VectorIndex::empty();
        idx.insert(String::from("a"), make_embedding(&[1.0, 0.0, 0.0]));
        idx.insert(String::from("b"), make_embedding(&[0.0, 1.0, 0.0]));
        assert_eq!(idx.len(), 2);
    }

    #[test]
    fn embedding_point_distance_orthogonal() {
        let x_axis = EmbeddingPoint(vec![1.0, 0.0, 0.0]);
        let y_axis = EmbeddingPoint(vec![0.0, 1.0, 0.0]);
        // Orthogonal unit vectors: 1 - dot = 1 - 0 = 1.
        assert!((x_axis.distance(&y_axis) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn embedding_point_distance_identical_is_zero() {
        let point = EmbeddingPoint(make_embedding(&[1.0, 1.0, 1.0]));
        // L2-normalised vector against itself: 1 - 1 = 0.
        assert!(point.distance(&point).abs() < 1e-6);
    }

    #[test]
    fn remove_on_empty_index_is_noop() {
        let idx = VectorIndex::empty();
        idx.remove("nonexistent");
        assert_eq!(idx.len(), 0);
    }

    #[test]
    fn insert_triggers_auto_rebuild_at_threshold() {
        // REBUILD_THRESHOLD = 200, so 205 inserts into a fresh index cross
        // the auto-rebuild branch in `insert`.
        let idx = VectorIndex::empty();
        (0..205_usize)
            .for_each(|i| idx.insert(format!("id-{i}"), biased_axis_embedding(i, 8, 0.001)));
        assert_eq!(idx.len(), 205);

        // Search keeps working after the automatic rebuild.
        let hits = idx.search(&make_embedding(&[1.0_f32; 8]), 5);
        assert_eq!(hits.len(), 5);
    }

    #[test]
    fn test_rebuild_after_batch_insert_settles() {
        // Fill an empty index with N entries, rebuild explicitly, then check
        // that a top-K search (K <= N) yields exactly K sorted, unique hits.
        let n = 25_usize;
        let idx = VectorIndex::empty();
        for i in 0..n {
            idx.insert(format!("id-{i}"), biased_axis_embedding(i, 8, 0.001));
        }
        // N is below REBUILD_THRESHOLD, so the rebuild has to be forced.
        idx.rebuild();
        assert_eq!(idx.len(), n);

        let k = 5;
        let hits = idx.search(&make_embedding(&[1.0; 8]), k);
        assert_eq!(
            hits.len(),
            k,
            "post-rebuild search top-{k} must return exactly {k} hits, got {:?}",
            hits.iter().map(|h| &h.id).collect::<Vec<_>>()
        );

        // Closest first.
        for pair in hits.windows(2) {
            assert!(
                pair[0].distance <= pair[1].distance,
                "search results must be ascending by distance: {} > {}",
                pair[0].distance,
                pair[1].distance
            );
        }

        // Each id appears only once.
        let mut unique_ids = std::collections::HashSet::new();
        for hit in &hits {
            assert!(
                unique_ids.insert(hit.id.clone()),
                "duplicate id in search: {}",
                hit.id
            );
        }
    }
}