Skip to main content

dynvec/
index.rs

1//! Approximate nearest-neighbour index.
2//!
3//! Implements a minimal Hierarchical Navigable Small World
4//! (HNSW) graph following the algorithm of Malkov & Yashunin,
5//! "Efficient and robust approximate nearest neighbor search
6//! using Hierarchical Navigable Small World graphs"
7//! (TPAMI 2018, arXiv:1603.09320).
8//!
9//! Why hand-rolled rather than `instant-distance` or `hnsw_rs`:
10//!
11//! * The index needs to interleave with our codec layer so that
12//!   the on-disk representation is a [`crate::encoding::EncodedVector`]
13//!   not an `f32` slice. Hooking that into a third-party crate
14//!   requires either keeping a parallel `Vec<f32>` cache (doubles
15//!   memory) or wrapping its `Point` trait in adapters (locks us
16//!   into that crate's API surface).
//! * We need explicit `delete` semantics. `instant-distance`
//!   does not expose deletion; we would have to maintain a
//!   tombstone set externally. Keeping the tombstone flag inside
//!   a hand-rolled HNSW instead is a small amount of code and
//!   keeps the public API honest.
22//! * No new third-party dependency, no review burden.
23//!
24//! Defaults:
25//! * `M = 16` (max bidirectional connections per layer)
26//! * `M0 = 32` (max connections at layer 0)
27//! * `ef_construction = 200`
28//! * `ef_search = 50`
29//! * Layer assignment uses `floor(-ln(rand()) * mL)` with
30//!   `mL = 1 / ln(M)` per the original paper.
31//!
32//! The index is single-threaded; coarser concurrency lives at
33//! the [`crate::storage`] layer where a per-table `Mutex` is
34//! held across an insert / search call.
35
36use std::cmp::Ordering;
37use std::collections::{BinaryHeap, HashMap, HashSet};
38
39use serde::{Deserialize, Serialize};
40
41use crate::distance::Distance;
42
43/// Stable identifier for the value an index node points back to.
44///
45/// The storage layer maps `NodeId` to a row key; the index is
46/// agnostic to the row format.
47pub type NodeId = u64;
48
49/// Tuneable parameters.
50#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
51pub struct HnswParams {
52    /// Max bidirectional links per node at every level above 0.
53    pub m: usize,
54    /// Max bidirectional links per node at level 0.
55    pub m0: usize,
56    /// Search beam width during insertion.
57    pub ef_construction: usize,
58    /// Default search beam width for queries. The query API can
59    /// override this per call.
60    pub ef_search: usize,
61    /// Random seed for layer assignment. Stored for
62    /// reproducibility; the `xorshift64` PRNG below is
63    /// deterministic for a given seed.
64    pub seed: u64,
65}
66
67impl Default for HnswParams {
68    fn default() -> Self {
69        Self {
70            m: 16,
71            m0: 32,
72            ef_construction: 200,
73            ef_search: 50,
74            seed: 0xDEAD_BEEF_CAFE_F00D,
75        }
76    }
77}
78
79/// One node in the graph.
80#[derive(Clone, Debug, Serialize, Deserialize)]
81struct HnswNode {
82    /// External identifier. Used by the storage layer to map back
83    /// to the persisted row.
84    id: NodeId,
85    /// `f32` representation of the vector. The index keeps a
86    /// decoded copy because the inner search loops touch every
87    /// component on the hot path; re-decoding on every distance
88    /// computation would dominate the runtime.
89    vector: Vec<f32>,
90    /// Adjacency lists, indexed by layer. `levels[0]` is the base
91    /// layer; higher indices are the sparser upper layers.
92    levels: Vec<Vec<usize>>,
93    /// Soft-deleted node. Tombstoned nodes are skipped during
94    /// search but their adjacency stays so the graph topology
95    /// is preserved until a future compaction rebuilds.
96    deleted: bool,
97}
98
99impl HnswNode {
100    fn level(&self) -> usize {
101        self.levels.len().saturating_sub(1)
102    }
103}
104
105/// Hand-rolled HNSW index over `f32` vectors.
106#[derive(Clone, Debug, Serialize, Deserialize)]
107pub struct HnswIndex {
108    params: HnswParams,
109    distance: Distance,
110    /// Storage of nodes by internal index. `NodeId -> internal idx`
111    /// is in [`Self::id_to_idx`].
112    nodes: Vec<HnswNode>,
113    /// External-id lookup.
114    id_to_idx: HashMap<NodeId, usize>,
115    /// Index of the entry-point node, or `None` for an empty
116    /// index.
117    entry: Option<usize>,
118    /// `mL` factor for level assignment, cached because every
119    /// insert calls it.
120    ml: f64,
121    /// PRNG state for layer assignment.
122    rng_state: u64,
123    /// Vector dimension. Frozen on first insert and enforced on
124    /// every subsequent insert; an attempt to insert a different
125    /// dimension is rejected by the storage layer before reaching
126    /// the index.
127    dim: u16,
128}
129
130/// Errors returned by the index.
131#[derive(Debug, thiserror::Error)]
132#[non_exhaustive]
133pub enum IndexError {
134    /// Vector dimension does not match the index dimension.
135    #[error("dimension mismatch: index has {expected}, got {got}")]
136    DimensionMismatch {
137        /// Index's frozen dimension.
138        expected: u16,
139        /// Caller's vector dimension.
140        got: u16,
141    },
142    /// Tried to insert a [`NodeId`] that already exists.
143    #[error("id {0} already present in the index")]
144    Duplicate(NodeId),
145    /// Empty input vector.
146    #[error("empty vector")]
147    Empty,
148}
149
150/// Result entry from a search query.
151#[derive(Clone, Debug, PartialEq)]
152pub struct SearchResult {
153    /// External identifier of the matched node.
154    pub id: NodeId,
155    /// Distance score; smaller is closer.
156    pub score: f32,
157}
158
/// Min-heap entry: the ordering is reversed on `score` so that a
/// [`BinaryHeap`] (a max-heap) pops the *smallest* score first.
///
/// Scores are compared with [`f32::total_cmp`], which is a total
/// order even when a score is NaN. Equality is defined through the
/// same comparison, so `Eq` and `Ord` always agree; `idx` does not
/// take part in either.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    /// Internal node index.
    idx: usize,
    /// Distance to the query; smaller is closer.
    score: f32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Derived from `cmp` so a NaN score is still equal to itself,
        // as `Eq` requires.
        self.cmp(other) == Ordering::Equal
    }
}
impl Eq for Candidate {}
impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Min-heap on score: compare in the inverted direction.
        other.score.total_cmp(&self.score)
    }
}
187
/// Max-heap entry on `score`: a [`BinaryHeap`] of these pops the
/// *largest* score first, which lets the dynamic candidate set
/// evict its furthest member.
///
/// Scores are compared with [`f32::total_cmp`], which is a total
/// order even when a score is NaN. Equality is defined through the
/// same comparison, so `Eq` and `Ord` always agree; `idx` does not
/// take part in either.
#[derive(Clone, Copy, Debug)]
struct MaxCandidate {
    /// Internal node index.
    idx: usize,
    /// Distance to the query; smaller is closer.
    score: f32,
}

impl PartialEq for MaxCandidate {
    fn eq(&self, other: &Self) -> bool {
        // Derived from `cmp` so a NaN score is still equal to itself,
        // as `Eq` requires.
        self.cmp(other) == Ordering::Equal
    }
}
impl Eq for MaxCandidate {}
impl PartialOrd for MaxCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for MaxCandidate {
    fn cmp(&self, other: &Self) -> Ordering {
        self.score.total_cmp(&other.score)
    }
}
214
215impl HnswIndex {
216    /// Build an empty index.
217    #[must_use]
218    pub fn new(distance: Distance, params: HnswParams) -> Self {
219        let ml = if params.m > 1 {
220            1.0 / f64::from(u32::try_from(params.m).unwrap_or(u32::MAX)).ln()
221        } else {
222            1.0
223        };
224        Self {
225            params,
226            distance,
227            nodes: Vec::new(),
228            id_to_idx: HashMap::new(),
229            entry: None,
230            ml,
231            rng_state: params.seed,
232            dim: 0,
233        }
234    }
235
236    /// Number of live (non-deleted) nodes.
237    #[must_use]
238    pub fn len(&self) -> usize {
239        self.nodes.iter().filter(|n| !n.deleted).count()
240    }
241
242    /// `true` when the index has no live nodes.
243    #[must_use]
244    pub fn is_empty(&self) -> bool {
245        self.len() == 0
246    }
247
248    /// Vector dimension, or 0 if the index is empty.
249    #[must_use]
250    pub fn dim(&self) -> u16 {
251        self.dim
252    }
253
254    /// Distance metric this index was built with.
255    #[must_use]
256    pub fn distance(&self) -> Distance {
257        self.distance
258    }
259
260    /// Insert a new vector under `id`.
261    ///
262    /// # Errors
263    ///
264    /// [`IndexError::Empty`] for a zero-dim vector,
265    /// [`IndexError::DimensionMismatch`] when the vector's
266    /// dimension differs from the index's frozen dimension,
267    /// and [`IndexError::Duplicate`] when `id` is already in
268    /// the index.
269    pub fn insert(&mut self, id: NodeId, vector: Vec<f32>) -> Result<(), IndexError> {
270        if vector.is_empty() {
271            return Err(IndexError::Empty);
272        }
273        let got = u16::try_from(vector.len()).unwrap_or(u16::MAX);
274        if self.nodes.is_empty() {
275            self.dim = got;
276        } else if self.dim != got {
277            return Err(IndexError::DimensionMismatch {
278                expected: self.dim,
279                got,
280            });
281        }
282        if self.id_to_idx.contains_key(&id) {
283            return Err(IndexError::Duplicate(id));
284        }
285
286        let level = self.random_level();
287        let mut levels: Vec<Vec<usize>> = Vec::with_capacity(level + 1);
288        for _ in 0..=level {
289            levels.push(Vec::new());
290        }
291
292        let new_idx = self.nodes.len();
293        self.nodes.push(HnswNode {
294            id,
295            vector,
296            levels,
297            deleted: false,
298        });
299        self.id_to_idx.insert(id, new_idx);
300
301        let Some(entry) = self.entry else {
302            self.entry = Some(new_idx);
303            return Ok(());
304        };
305        let entry_level = self.nodes[entry].level();
306
307        // Phase 1: descend through layers above `level` finding the
308        // best entry point for `level`.
309        let mut current = entry;
310        if entry_level > level {
311            for lc in (level + 1..=entry_level).rev() {
312                current = self.greedy_search_layer(current, new_idx, lc);
313            }
314        }
315
316        // Phase 2: at each layer from min(level, entry_level) down to
317        // 0, search for ef_construction candidates and connect.
318        let start_layer = level.min(entry_level);
319        let mut entry_points = vec![current];
320        for lc in (0..=start_layer).rev() {
321            let neighbours = self.search_layer(
322                new_idx,
323                &entry_points,
324                lc,
325                self.params.ef_construction,
326                /*include_deleted=*/ true,
327            );
328            let m = if lc == 0 {
329                self.params.m0
330            } else {
331                self.params.m
332            };
333            let selected = Self::select_neighbours(&neighbours, m);
334            // Bidirectional links.
335            for &nb in &selected {
336                self.nodes[new_idx].levels[lc].push(nb);
337                self.nodes[nb].levels[lc].push(new_idx);
338                // Shrink the neighbour's adjacency if it now exceeds
339                // the cap.
340                let cap = if lc == 0 {
341                    self.params.m0
342                } else {
343                    self.params.m
344                };
345                if self.nodes[nb].levels[lc].len() > cap {
346                    self.shrink_connections(nb, lc, cap);
347                }
348            }
349            entry_points = selected;
350            if entry_points.is_empty() {
351                entry_points = vec![current];
352            }
353        }
354
355        // If the new node sits above the previous entry point, it
356        // becomes the new entry point.
357        if level > entry_level {
358            self.entry = Some(new_idx);
359        }
360        Ok(())
361    }
362
363    /// Soft-delete `id`. The node remains in the graph for
364    /// connectivity but is filtered out of search results.
365    ///
366    /// Returns `true` when the id was present, `false` otherwise.
367    pub fn delete(&mut self, id: NodeId) -> bool {
368        let Some(&idx) = self.id_to_idx.get(&id) else {
369            return false;
370        };
371        if self.nodes[idx].deleted {
372            return false;
373        }
374        self.nodes[idx].deleted = true;
375        true
376    }
377
378    /// Search for the `k` nearest neighbours of `query`.
379    ///
380    /// `ef` controls the search beam width. Pass `None` to use the
381    /// index's default `ef_search`. A larger `ef` trades CPU for
382    /// recall.
383    ///
384    /// # Errors
385    ///
386    /// Returns [`IndexError::DimensionMismatch`] when the query
387    /// vector's dimension does not match the index's frozen
388    /// dimension.
389    pub fn search(
390        &self,
391        query: &[f32],
392        k: usize,
393        ef: Option<usize>,
394    ) -> Result<Vec<SearchResult>, IndexError> {
395        if query.is_empty() {
396            return Ok(Vec::new());
397        }
398        if self.nodes.is_empty() {
399            return Ok(Vec::new());
400        }
401        let got = u16::try_from(query.len()).unwrap_or(u16::MAX);
402        if self.dim != got {
403            return Err(IndexError::DimensionMismatch {
404                expected: self.dim,
405                got,
406            });
407        }
408
409        let mut entry = self.entry.unwrap_or(0);
410        let entry_level = self.nodes[entry].level();
411        let ef = ef.unwrap_or(self.params.ef_search).max(k);
412
413        // Greedy-descend through upper layers.
414        let query_owned = query.to_vec();
415        for lc in (1..=entry_level).rev() {
416            entry = self.greedy_search_layer_against(&query_owned, entry, lc);
417        }
418
419        let candidates = self.search_layer_against(&query_owned, &[entry], 0, ef, true);
420
421        let mut sorted: Vec<MaxCandidate> = candidates;
422        sorted.sort_by(|a, b| {
423            a.score
424                .partial_cmp(&b.score)
425                .unwrap_or(std::cmp::Ordering::Equal)
426        });
427        Ok(sorted
428            .into_iter()
429            .filter(|c| !self.nodes[c.idx].deleted)
430            .take(k)
431            .map(|c| SearchResult {
432                id: self.nodes[c.idx].id,
433                score: c.score,
434            })
435            .collect())
436    }
437
438    /// `true` when `id` is currently a live node in the index.
439    #[must_use]
440    pub fn contains(&self, id: NodeId) -> bool {
441        self.id_to_idx
442            .get(&id)
443            .is_some_and(|&idx| !self.nodes[idx].deleted)
444    }
445
446    /// Random level assignment via the original paper's formula:
447    /// `level = floor(-ln(uniform(0, 1)) * mL)`.
448    fn random_level(&mut self) -> usize {
449        let r = self.rand_unit();
450        // Guard r > 0 so `ln(0)` is impossible.
451        let r = r.max(f64::MIN_POSITIVE);
452        let level = (-r.ln() * self.ml).floor();
453        // Cap at a sane ceiling so a freak uniform sample does not
454        // allocate thousands of empty layers.
455        let max_level = 16_f64;
456        let clamped = level.clamp(0.0, max_level);
457        // The clamp guarantees `clamped` is in [0, 16]; the
458        // cast cannot truncate or sign-flip.
459        #[allow(
460            clippy::cast_possible_truncation,
461            clippy::cast_sign_loss,
462            reason = "clamped to [0, 16]"
463        )]
464        let lvl = clamped as usize;
465        lvl
466    }
467
468    /// xorshift64* PRNG, deterministic given the seed.
469    fn rand_unit(&mut self) -> f64 {
470        let mut x = self.rng_state;
471        x ^= x >> 12;
472        x ^= x << 25;
473        x ^= x >> 27;
474        self.rng_state = x;
475        let r = x.wrapping_mul(0x2545_F491_4F6C_DD1D);
476        // Take the top 53 bits and divide by 2^53 to get a
477        // uniform [0, 1) double.
478        let bits = (r >> 11) & ((1u64 << 53) - 1);
479        // bits is < 2^53, fits in f64 exactly.
480        #[allow(
481            clippy::cast_precision_loss,
482            reason = "bits is in [0, 2^53), exactly representable as f64"
483        )]
484        let f = (bits as f64) / ((1_u64 << 53) as f64);
485        f
486    }
487
488    /// Greedy search using the freshly-inserted node as the query.
489    fn greedy_search_layer(&self, entry: usize, query_idx: usize, lc: usize) -> usize {
490        let q = self.nodes[query_idx].vector.clone();
491        self.greedy_search_layer_against(&q, entry, lc)
492    }
493
494    /// Greedy single-best descent at layer `lc`.
495    fn greedy_search_layer_against(&self, query: &[f32], entry: usize, lc: usize) -> usize {
496        let mut current = entry;
497        let mut current_score = self.distance.score(query, &self.nodes[current].vector);
498        loop {
499            let mut improved = false;
500            if lc < self.nodes[current].levels.len() {
501                let neighbours: Vec<usize> = self.nodes[current].levels[lc].clone();
502                for nb in neighbours {
503                    let s = self.distance.score(query, &self.nodes[nb].vector);
504                    if s < current_score {
505                        current_score = s;
506                        current = nb;
507                        improved = true;
508                    }
509                }
510            }
511            if !improved {
512                break;
513            }
514        }
515        current
516    }
517
518    /// Beam search at layer `lc` with the freshly-inserted node as
519    /// the query.
520    fn search_layer(
521        &self,
522        query_idx: usize,
523        entry_points: &[usize],
524        lc: usize,
525        ef: usize,
526        include_deleted: bool,
527    ) -> Vec<MaxCandidate> {
528        let q = self.nodes[query_idx].vector.clone();
529        self.search_layer_against(&q, entry_points, lc, ef, include_deleted)
530    }
531
532    /// Beam search at layer `lc`. Returns up to `ef` candidates.
533    fn search_layer_against(
534        &self,
535        query: &[f32],
536        entry_points: &[usize],
537        lc: usize,
538        ef: usize,
539        include_deleted: bool,
540    ) -> Vec<MaxCandidate> {
541        let mut visited: HashSet<usize> = HashSet::new();
542        let mut frontier: BinaryHeap<Candidate> = BinaryHeap::new();
543        let mut top: BinaryHeap<MaxCandidate> = BinaryHeap::new();
544        for &ep in entry_points {
545            if visited.insert(ep) {
546                let s = self.distance.score(query, &self.nodes[ep].vector);
547                frontier.push(Candidate { idx: ep, score: s });
548                if include_deleted || !self.nodes[ep].deleted {
549                    top.push(MaxCandidate { idx: ep, score: s });
550                }
551            }
552        }
553        while let Some(c) = frontier.pop() {
554            // Stop when the closest unprocessed candidate is
555            // already worse than the current top.
556            if top.len() >= ef {
557                if let Some(worst) = top.peek() {
558                    if c.score > worst.score {
559                        break;
560                    }
561                }
562            }
563            if lc < self.nodes[c.idx].levels.len() {
564                let neighbours: Vec<usize> = self.nodes[c.idx].levels[lc].clone();
565                for nb in neighbours {
566                    if !visited.insert(nb) {
567                        continue;
568                    }
569                    let s = self.distance.score(query, &self.nodes[nb].vector);
570                    let admit = match top.peek() {
571                        Some(worst) => s < worst.score || top.len() < ef,
572                        None => true,
573                    };
574                    if admit {
575                        frontier.push(Candidate { idx: nb, score: s });
576                        if include_deleted || !self.nodes[nb].deleted {
577                            top.push(MaxCandidate { idx: nb, score: s });
578                            if top.len() > ef {
579                                top.pop();
580                            }
581                        }
582                    }
583                }
584            }
585        }
586        top.into_vec()
587    }
588
589    /// Pick the top-`m` neighbours from the candidate set using
590    /// the simple closest-first heuristic. Sufficient for the MVP;
591    /// the original paper offers a more sophisticated
592    /// "extend-by-heuristic" rule that we leave for a future tune.
593    fn select_neighbours(candidates: &[MaxCandidate], m: usize) -> Vec<usize> {
594        let mut sorted: Vec<MaxCandidate> = candidates.to_vec();
595        sorted.sort_by(|a, b| {
596            a.score
597                .partial_cmp(&b.score)
598                .unwrap_or(std::cmp::Ordering::Equal)
599        });
600        sorted.into_iter().take(m).map(|c| c.idx).collect()
601    }
602
603    /// Drop the longest edges from a node's adjacency list at `lc`
604    /// until it fits in `cap`.
605    fn shrink_connections(&mut self, idx: usize, lc: usize, cap: usize) {
606        let q = self.nodes[idx].vector.clone();
607        let neighbours = std::mem::take(&mut self.nodes[idx].levels[lc]);
608        let mut scored: Vec<(usize, f32)> = neighbours
609            .into_iter()
610            .map(|nb| {
611                let s = self.distance.score(&q, &self.nodes[nb].vector);
612                (nb, s)
613            })
614            .collect();
615        scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
616        scored.truncate(cap);
617        self.nodes[idx].levels[lc] = scored.into_iter().map(|(nb, _)| nb).collect();
618    }
619}
620
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distance::Distance;

    /// Deterministic pseudo-random vector of length `dim` with
    /// components in `[-1, 1)`, generated by an xorshift sequence
    /// started from `seed`.
    fn unit(seed: u64, dim: usize) -> Vec<f32> {
        let mut state = seed;
        (0..dim)
            .map(|_| {
                state ^= state << 13;
                state ^= state >> 7;
                state ^= state << 17;
                // 53 bits are exactly representable in f64; the final
                // narrowing to f32 is intentional because test data
                // does not need full f64 precision.
                let mantissa = (state >> 11) & ((1_u64 << 53) - 1);
                #[allow(
                    clippy::cast_precision_loss,
                    clippy::cast_possible_truncation,
                    reason = "test fixture; PRNG output narrowed to f32"
                )]
                let component = (((mantissa as f64) / ((1_u64 << 53) as f64)) * 2.0 - 1.0) as f32;
                component
            })
            .collect()
    }

    /// Fresh Euclidean index with default parameters.
    fn euclidean_index() -> HnswIndex {
        HnswIndex::new(Distance::Euclidean, HnswParams::default())
    }

    #[test]
    fn insert_and_search_small() {
        let mut index = euclidean_index();
        let target = unit(42, 8);
        index.insert(0, target.clone()).unwrap();
        for id in 1..50_u64 {
            let vector = unit(id.wrapping_mul(1_000_003) + 1, 8);
            index.insert(id, vector).unwrap();
        }
        let hits = index.search(&target, 3, None).unwrap();
        assert!(!hits.is_empty());
        // Node 0 holds exactly the query vector, so it has to be the
        // closest match.
        assert_eq!(hits[0].id, 0);
    }

    #[test]
    fn delete_excludes_from_search() {
        let mut index = euclidean_index();
        for id in 0..30_u64 {
            index.insert(id, unit(id + 1, 8)).unwrap();
        }
        let query = unit(1, 8);
        let victim = index.search(&query, 5, None).unwrap()[0].id;
        assert!(index.delete(victim));
        let remaining = index.search(&query, 5, None).unwrap();
        assert!(remaining.iter().all(|hit| hit.id != victim));
    }

    #[test]
    fn dimension_mismatch_rejected() {
        let mut index = euclidean_index();
        index.insert(0, vec![0.1, 0.2, 0.3]).unwrap();
        let outcome = index.insert(1, vec![0.1, 0.2]);
        assert!(matches!(
            outcome,
            Err(IndexError::DimensionMismatch { .. })
        ));
    }

    #[test]
    fn duplicate_id_rejected() {
        let mut index = euclidean_index();
        index.insert(7, vec![0.1, 0.2]).unwrap();
        let outcome = index.insert(7, vec![0.3, 0.4]);
        assert!(matches!(outcome, Err(IndexError::Duplicate(7))));
    }

    #[test]
    fn empty_index_search_is_empty() {
        let index = euclidean_index();
        let hits = index.search(&[0.1, 0.2], 5, None).unwrap();
        assert!(hits.is_empty());
    }
}
710}