Skip to main content

dynvec/
turbo_hnsw.rs

1//! HNSW topology over `turbovec` packed codes.
2//!
3//! Combines two existing pieces of this crate:
4//!
5//! * The hand-rolled HNSW graph from [`crate::index`] (Malkov &
6//!   Yashunin, TPAMI 2018) for sub-linear traversal.
7//! * The TurboQuant 2/3/4-bit codebook from the `turbovec`
8//!   crate for the per-vector storage and the per-pair scoring
9//!   kernel.
10//!
11//! The brute-force [`crate::turbo_index::TurboTable`] dominates
12//! at small corpora because its SIMD scan is so fast, but at
13//! 100k+ vectors the linear scan caps p99 latency. HNSW gives
14//! an `O(log N)` traversal: ~1000 distance ops per query at
15//! 100k versus 100k for the brute-force scan. This module pairs
16//! that traversal with a turbovec-derived per-pair scorer so
17//! the resulting search keeps the codec's compression and
18//! quantisation accuracy.
19//!
20//! # Distance kernel
21//!
22//! `turbovec`'s public search API (`TurboQuantIndex::search`
23//! and `search_with_mask`) is bulk-only: every call scores
24//! every block in the index, with at best a per-block skip
25//! when a contiguous 32-vector block has no allowed slots.
26//! That pattern is the wrong shape for HNSW, where each
27//! traversal step needs the score against ~M=16 scattered
28//! candidates and would re-pay the LUT-build + full scan
29//! cost at every step.
30//!
31//! Instead this module implements the per-pair scoring kernel
32//! in pure safe Rust, on top of the public `turbovec::codebook`
33//! and `turbovec::rotation` primitives. Each stored vector is
34//! quantised to the codebook lattice and persisted as one
35//! `u8` per coordinate (low `BITS` bits used) plus a per-
//! vector `f32` scale. The byte-per-code layout is `8 / BITS`
//! times bigger than `turbovec`'s bit-plane packed
38//! format but lets the scoring kernel walk a contiguous
39//! `u8` slice and feed a clean dot product, which auto-
40//! vectorises through LLVM. The crate-level
41//! `forbid(unsafe_code)` rules out the intrinsics-driven
42//! SIMD path that `turbovec::search` uses, so the layout
43//! choice is what unlocks SIMD here.
44//!
45//! Compared to the brute-force [`crate::turbo_index::TurboTable`]
46//! the per-vec memory is `8 / BITS` times bigger (e.g. 4x
47//! at 2-bit, 2x at 4-bit) but still smaller than the f32
48//! HNSW path, and the HNSW topology cuts the per-query
49//! work from `O(N)` distance calls to `O(log N)`.
50//!
51//! # Recall
52//!
53//! TQ+ per-coordinate calibration is disabled (identity shift
54//! and scale). Fitting TQ+ requires a batched first-add of at
55//! least 1000 vectors to estimate per-coordinate quantiles;
56//! the HNSW path is incremental, so an identity calibration is
57//! the honest default. The recall tests in
58//! `tests/turbo_hnsw.rs` confirm this stays inside the same
59//! `>= 85%` budget the brute [`crate::turbo_index::TurboTable`]
60//! tests use.
61
62use std::cmp::Ordering;
63use std::collections::{BinaryHeap, HashMap, HashSet};
64
65use turbovec::codebook::codebook;
66use turbovec::rotation::make_rotation_matrix;
67
68use crate::distance::Distance;
69use crate::index::{HnswParams, IndexError, NodeId, SearchResult};
70
71/// Distance abstraction that every `dynvec` ANN container
72/// honours.
73///
74/// The trait is intentionally narrow: a single
75/// `(NodeId, NodeId) -> f32` score is enough to drive the HNSW
76/// pruning heuristics (`select_neighbours` and
77/// `shrink_connections`). Query-to-node scoring during search
78/// is handled by each impl directly because the query's f32
79/// representation is in scope at that layer. Smaller scores
80/// mean closer.
81pub trait CodecDistance {
82    /// Score the stored vectors at `a` and `b` against each
83    /// other.
84    ///
85    /// The score is in the metric's smaller-is-closer
86    /// convention so the same heap comparator works
87    /// regardless of the underlying distance.
88    fn distance(&self, a: NodeId, b: NodeId) -> f32;
89}
90
91/// One node in the HNSW graph held by [`TurboHnswIndex`].
92#[derive(Clone, Debug)]
93struct TurboHnswNode {
94    id: NodeId,
95    /// Adjacency lists, one per layer. `levels[0]` is the base
96    /// layer; higher indices are the sparser upper layers.
97    levels: Vec<Vec<usize>>,
98    /// Soft-deleted node. Tombstoned nodes are skipped during
99    /// search but their adjacency stays so the graph topology
100    /// is preserved until a future compaction rebuilds.
101    deleted: bool,
102}
103
104impl TurboHnswNode {
105    fn level(&self) -> usize {
106        self.levels.len().saturating_sub(1)
107    }
108}
109
110/// HNSW graph over `turbovec`-packed codes.
111///
112/// `BITS` is the per-coordinate bit width; only `2`, `3`, and
113/// `4` are valid. Each vector occupies `dim * BITS / 8` bytes
114/// of packed storage plus a single `f32` per-vector scale.
115pub struct TurboHnswIndex<const BITS: u8> {
116    /// Distance metric the index was built with.
117    distance: Distance,
118    /// Frozen vector dimension. Must be a positive multiple of 8.
119    dim: u16,
120    /// HNSW tuning parameters; mirrors the f32 path in
121    /// [`crate::index::HnswIndex`].
122    params: HnswParams,
123
124    /// Random rotation matrix shared with the `turbovec`
125    /// encoder. Row-major, dim x dim.
126    rotation: Vec<f32>,
127    /// Lloyd-Max codebook boundaries: one f32 per quantisation
128    /// edge, with `2^BITS - 1` edges total.
129    boundaries: Vec<f32>,
130    /// Lloyd-Max codebook centroids: one f32 per quantisation
131    /// bucket, with `2^BITS` buckets total.
132    centroids: Vec<f32>,
133
134    /// Flat code buffer, one `u8` per coordinate. Slot `i`
135    /// occupies `[i * dim, (i + 1) * dim)`; only the low
136    /// `BITS` bits of each byte are populated. Sequential
137    /// access keeps the scoring loop SIMD-friendly.
138    packed: Vec<u8>,
139    /// Per-vector scale fitted by the encoder. Smaller-is-
140    /// closer scoring multiplies through this scale.
141    scales: Vec<f32>,
142
143    /// HNSW node table. Slot index in `nodes` matches slot
144    /// index in `packed` and `scales`.
145    nodes: Vec<TurboHnswNode>,
146    /// External-id lookup so `delete(NodeId)` and
147    /// `contains(NodeId)` are O(1).
148    id_to_idx: HashMap<NodeId, usize>,
149    /// Index of the entry-point node, or `None` for an empty
150    /// index.
151    entry: Option<usize>,
152    /// PRNG state for layer assignment.
153    rng_state: u64,
154    /// `mL` factor for level assignment, cached because every
155    /// insert calls it.
156    ml: f64,
157}
158
/// Min-heap entry on score; used as the search frontier.
///
/// `BinaryHeap` is a max-heap, so the ordering is reversed: the
/// candidate with the *smallest* score compares greatest and is
/// popped first. Ordering uses [`f32::total_cmp`], which is a
/// lawful total order even for NaN scores (a positive NaN sorts
/// as the worst candidate). Equality is derived from the same
/// comparison so `Eq` stays consistent with `Ord`.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    idx: usize,
    score: f32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}
impl Eq for Candidate {}
impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Min-heap on score: compare `other` against `self`.
        other.score.total_cmp(&self.score)
    }
}
186
/// Max-heap entry on score; used to keep the top-K furthest in
/// the dynamic candidate set.
///
/// The candidate with the *largest* score (the current worst)
/// sits at the top of the heap. Ordering uses
/// [`f32::total_cmp`], a lawful total order even for NaN scores
/// (a positive NaN ranks as the worst and is evicted first).
/// Equality is derived from the same comparison so `Eq` stays
/// consistent with `Ord`.
#[derive(Clone, Copy, Debug)]
struct MaxCandidate {
    idx: usize,
    score: f32,
}

impl PartialEq for MaxCandidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}
impl Eq for MaxCandidate {}
impl PartialOrd for MaxCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for MaxCandidate {
    fn cmp(&self, other: &Self) -> Ordering {
        self.score.total_cmp(&other.score)
    }
}
213
214impl<const BITS: u8> TurboHnswIndex<BITS> {
215    /// Build an empty turbo-HNSW index for `dim` and the codec
216    /// metric.
217    ///
218    /// # Errors
219    ///
220    /// [`IndexError::Empty`] when `BITS` is outside `{2, 3, 4}`
221    /// or when `dim == 0`. [`IndexError::DimensionMismatch`]
222    /// when `dim` is not a positive multiple of 8 (the
223    /// `turbovec` codebook constraint); the `expected` field
224    /// is rounded up to the next multiple of 8 to give the
225    /// caller a workable suggestion.
226    pub fn new(distance: Distance, dim: u16, params: HnswParams) -> Result<Self, IndexError> {
227        if !(2..=4).contains(&BITS) {
228            return Err(IndexError::Empty);
229        }
230        if dim == 0 {
231            return Err(IndexError::Empty);
232        }
233        if !dim.is_multiple_of(8) {
234            return Err(IndexError::DimensionMismatch {
235                expected: ((dim / 8) + 1) * 8,
236                got: dim,
237            });
238        }
239        let dim_usize = usize::from(dim);
240        let bits_usize = usize::from(BITS);
241        let rotation = make_rotation_matrix(dim_usize);
242        let (boundaries, centroids) = codebook(bits_usize, dim_usize);
243        let ml = if params.m > 1 {
244            1.0 / f64::from(u32::try_from(params.m).unwrap_or(u32::MAX)).ln()
245        } else {
246            1.0
247        };
248        Ok(Self {
249            distance,
250            dim,
251            params,
252            rotation,
253            boundaries,
254            centroids,
255            packed: Vec::new(),
256            scales: Vec::new(),
257            nodes: Vec::new(),
258            id_to_idx: HashMap::new(),
259            entry: None,
260            rng_state: params.seed,
261            ml,
262        })
263    }
264
265    /// Number of live (non-deleted) nodes.
266    #[must_use]
267    pub fn len(&self) -> usize {
268        self.nodes.iter().filter(|n| !n.deleted).count()
269    }
270
271    /// `true` when no live nodes exist.
272    #[must_use]
273    pub fn is_empty(&self) -> bool {
274        self.len() == 0
275    }
276
277    /// Frozen vector dimension.
278    #[must_use]
279    pub fn dim(&self) -> u16 {
280        self.dim
281    }
282
283    /// Distance metric.
284    #[must_use]
285    pub fn distance_metric(&self) -> Distance {
286        self.distance
287    }
288
289    /// Bit width handed to `turbovec`.
290    #[must_use]
291    pub fn bits(&self) -> u8 {
292        BITS
293    }
294
295    /// `true` when `id` is currently a live node.
296    #[must_use]
297    pub fn contains(&self, id: NodeId) -> bool {
298        self.id_to_idx
299            .get(&id)
300            .is_some_and(|&idx| !self.nodes[idx].deleted)
301    }
302
303    /// Insert a new vector under `id`.
304    ///
305    /// # Errors
306    ///
307    /// [`IndexError::Empty`] for a zero-dim vector,
308    /// [`IndexError::DimensionMismatch`] when the vector's
309    /// dimension does not match the index's frozen dim, and
310    /// [`IndexError::Duplicate`] when `id` is already present.
311    pub fn insert(&mut self, id: NodeId, vector: Vec<f32>) -> Result<(), IndexError> {
312        if vector.is_empty() {
313            return Err(IndexError::Empty);
314        }
315        let got = u16::try_from(vector.len()).unwrap_or(u16::MAX);
316        if got != self.dim {
317            return Err(IndexError::DimensionMismatch {
318                expected: self.dim,
319                got,
320            });
321        }
322        if self.id_to_idx.contains_key(&id) {
323            return Err(IndexError::Duplicate(id));
324        }
325        // Cosine and Euclidean both decompose into an inner
326        // product on L2-normalised inputs; matching the
327        // `TurboTable` policy keeps the on-codec score
328        // comparable across the brute and HNSW paths.
329        let prepared = match self.distance {
330            Distance::Cosine | Distance::Euclidean => l2_normalise(&vector),
331            Distance::DotProduct => vector,
332        };
333        // Reject non-finite or huge-magnitude coordinates
334        // before they reach `turbovec::encode`, which would
335        // panic the process otherwise.
336        for v in &prepared {
337            if !v.is_finite() || v.abs() >= 1e16_f32 {
338                return Err(IndexError::Empty);
339            }
340        }
341        let dim_usize = usize::from(self.dim);
342        let bytes_per_vec = self.bytes_per_vec();
343        let (packed, scale) = self.encode_one(&prepared);
344        debug_assert_eq!(packed.len(), bytes_per_vec);
345        let _ = dim_usize;
346
347        // Encode-and-store happens before the graph wiring so
348        // the new node's slot index matches `nodes.len()` once
349        // the node record is pushed below.
350        self.packed.extend_from_slice(&packed);
351        self.scales.push(scale);
352
353        let level = self.random_level();
354        let mut levels: Vec<Vec<usize>> = Vec::with_capacity(level + 1);
355        for _ in 0..=level {
356            levels.push(Vec::new());
357        }
358
359        let new_idx = self.nodes.len();
360        self.nodes.push(TurboHnswNode {
361            id,
362            levels,
363            deleted: false,
364        });
365        self.id_to_idx.insert(id, new_idx);
366
367        let Some(entry) = self.entry else {
368            self.entry = Some(new_idx);
369            return Ok(());
370        };
371        let entry_level = self.nodes[entry].level();
372
373        // Phase 1: descend through layers above `level`,
374        // narrowing to the best entry point at level + 1.
375        let q_rot = self.rotate(&prepared);
376        let mut current = entry;
377        if entry_level > level {
378            for lc in (level + 1..=entry_level).rev() {
379                current = self.greedy_search_layer(&q_rot, current, lc, new_idx);
380            }
381        }
382
383        // Phase 2: at every layer from min(level, entry_level)
384        // down to 0, beam-search for ef_construction
385        // candidates and connect.
386        let start_layer = level.min(entry_level);
387        let mut entry_points = vec![current];
388        for lc in (0..=start_layer).rev() {
389            let neighbours = self.search_layer(
390                &q_rot,
391                &entry_points,
392                lc,
393                self.params.ef_construction,
394                /* skip_idx = */ Some(new_idx),
395            );
396            let m = if lc == 0 {
397                self.params.m0
398            } else {
399                self.params.m
400            };
401            let selected = Self::select_neighbours(&neighbours, m);
402            for &nb in &selected {
403                self.nodes[new_idx].levels[lc].push(nb);
404                self.nodes[nb].levels[lc].push(new_idx);
405                let cap = if lc == 0 {
406                    self.params.m0
407                } else {
408                    self.params.m
409                };
410                if self.nodes[nb].levels[lc].len() > cap {
411                    self.shrink_connections(nb, lc, cap);
412                }
413            }
414            entry_points = selected;
415            if entry_points.is_empty() {
416                entry_points = vec![current];
417            }
418        }
419
420        if level > entry_level {
421            self.entry = Some(new_idx);
422        }
423        Ok(())
424    }
425
426    /// Soft-delete `id`. Returns `true` when the id was a live
427    /// node, `false` otherwise.
428    pub fn delete(&mut self, id: NodeId) -> bool {
429        let Some(&idx) = self.id_to_idx.get(&id) else {
430            return false;
431        };
432        if self.nodes[idx].deleted {
433            return false;
434        }
435        self.nodes[idx].deleted = true;
436        true
437    }
438
439    /// Search for the `k` nearest neighbours of `query`.
440    ///
441    /// `ef` overrides the default `ef_search` beam width.
442    ///
443    /// # Errors
444    ///
445    /// [`IndexError::DimensionMismatch`] when the query's
446    /// dimension does not match the index's frozen dim.
447    pub fn search(
448        &self,
449        query: &[f32],
450        k: usize,
451        ef: Option<usize>,
452    ) -> Result<Vec<SearchResult>, IndexError> {
453        if query.is_empty() {
454            return Ok(Vec::new());
455        }
456        if self.nodes.is_empty() {
457            return Ok(Vec::new());
458        }
459        let got = u16::try_from(query.len()).unwrap_or(u16::MAX);
460        if got != self.dim {
461            return Err(IndexError::DimensionMismatch {
462                expected: self.dim,
463                got,
464            });
465        }
466
467        let prepared = match self.distance {
468            Distance::Cosine | Distance::Euclidean => l2_normalise(query),
469            Distance::DotProduct => query.to_vec(),
470        };
471        let q_rot = self.rotate(&prepared);
472
473        let mut entry = self.entry.unwrap_or(0);
474        let entry_level = self.nodes[entry].level();
475        let ef = ef.unwrap_or(self.params.ef_search).max(k);
476
477        for lc in (1..=entry_level).rev() {
478            entry = self.greedy_search_layer(&q_rot, entry, lc, usize::MAX);
479        }
480
481        let candidates = self.search_layer(&q_rot, &[entry], 0, ef, None);
482
483        let mut sorted = candidates;
484        sorted.sort_by(|a, b| {
485            a.score
486                .partial_cmp(&b.score)
487                .unwrap_or(std::cmp::Ordering::Equal)
488        });
489        Ok(sorted
490            .into_iter()
491            .filter(|c| !self.nodes[c.idx].deleted)
492            .take(k)
493            .map(|c| SearchResult {
494                id: self.nodes[c.idx].id,
495                score: c.score,
496            })
497            .collect())
498    }
499
500    /// Number of bytes occupied by one stored vector. With
501    /// the byte-per-code layout this is just `dim`.
502    fn bytes_per_vec(&self) -> usize {
503        usize::from(self.dim)
504    }
505
506    /// Single-vector encode: rotate, quantise to the codebook
507    /// lattice, fit a per-vector scale, and emit one `u8`
508    /// code per coordinate.
509    ///
510    /// The public `turbovec::encode::encode` runs the
511    /// quantisation pipeline through Rayon and is sized for
512    /// big batches; calling it once per HNSW insert pays the
513    /// thread-pool setup on every vector and balloons build
514    /// time by orders of magnitude. This routine reproduces
515    /// the per-row math (normalise, rotate, quantise, fit
516    /// per-vector scale) in scalar code so each insert stays
517    /// inexpensive. TQ+ per-coordinate calibration is the
518    /// identity (no shift, no scale); fitting it would need
519    /// a 1000-vector batch which the incremental HNSW path
520    /// does not have at insert time.
521    fn encode_one(&self, vector: &[f32]) -> (Vec<u8>, f32) {
522        let dim = usize::from(self.dim);
523        // 1. Norm and unit vector.
524        let mut norm_sq = 0.0_f32;
525        for &x in vector {
526            norm_sq += x * x;
527        }
528        let norm = norm_sq.sqrt();
529        let inv_norm = if norm > 1e-10 { 1.0 / norm } else { 0.0 };
530        let mut unit = vec![0.0_f32; dim];
531        for (d, slot) in unit.iter_mut().enumerate().take(dim) {
532            *slot = vector[d] * inv_norm;
533        }
534        // 2. Rotate: u_rot = R @ unit.
535        let u_rot = self.rotate(&unit);
536        // 3. Quantise to centroid codes; fit the scale by
537        // accumulating the unit-vec inner product against the
538        // chosen centroids.
539        let mut packed = vec![0_u8; dim];
540        let mut inner = 0.0_f32;
541        for (j, &uj) in u_rot.iter().enumerate().take(dim) {
542            let mut code = 0_u8;
543            for &b in &self.boundaries {
544                if uj > b {
545                    code += 1;
546                }
547            }
548            inner += uj * self.centroids[usize::from(code)];
549            packed[j] = code;
550        }
551        // 4. Per-vector scale: norm / <u_rot, x_hat>. Floor at
552        // 1e-10 so a vanishing inner product cannot produce
553        // an infinite scale.
554        let inner = inner.max(1e-10_f32);
555        let scale = norm / inner;
556        (packed, scale)
557    }
558
559    /// Borrow the contiguous byte slice for slot `slot`.
560    fn codes(&self, slot: usize) -> &[u8] {
561        let dim = usize::from(self.dim);
562        let row_start = slot * dim;
563        &self.packed[row_start..row_start + dim]
564    }
565
566    /// Multiply the rotation matrix by `q` and return `R @ q`.
567    fn rotate(&self, q: &[f32]) -> Vec<f32> {
568        let dim = usize::from(self.dim);
569        let mut out = vec![0.0_f32; dim];
570        for (d, slot) in out.iter_mut().enumerate().take(dim) {
571            let row = &self.rotation[d * dim..(d + 1) * dim];
572            let mut sum = 0.0_f32;
573            for (e, &qe) in q.iter().enumerate().take(dim) {
574                sum += row[e] * qe;
575            }
576            *slot = sum;
577        }
578        out
579    }
580
581    /// Inner-product surrogate: `<q_rot, x_hat[slot]> *
582    /// scale[slot]`.
583    ///
584    /// The result is the codec's similarity estimate, in
585    /// `(-||q||, +||q||)` for unit-normalised queries. The
586    /// metric mapping in [`Self::similarity_to_distance`]
587    /// turns that into a smaller-is-closer score.
588    fn similarity_query(&self, q_rot: &[f32], slot: usize) -> f32 {
589        let dim = usize::from(self.dim);
590        let codes = self.codes(slot);
591        let centroids = self.centroids.as_slice();
592        let mut acc = 0.0_f32;
593        for d in 0..dim {
594            acc += q_rot[d] * centroids[codes[d] as usize];
595        }
596        acc * self.scales[slot]
597    }
598
599    /// Inner-product surrogate between two stored slots.
600    ///
601    /// The rotation is orthogonal, so
602    /// `<v_a, v_b> ~= scale_a * scale_b * <x_hat_a, x_hat_b>`.
603    /// Both vectors are quantised, so the kernel sees the
604    /// double-quantisation error; recall on the pair-only path
605    /// (used by `shrink_connections`) is tighter than on the
606    /// query-to-stored path.
607    fn similarity_pair(&self, a: usize, b: usize) -> f32 {
608        let dim = usize::from(self.dim);
609        let ca = self.codes(a);
610        let cb = self.codes(b);
611        let centroids = self.centroids.as_slice();
612        let mut acc = 0.0_f32;
613        for d in 0..dim {
614            acc += centroids[ca[d] as usize] * centroids[cb[d] as usize];
615        }
616        acc * self.scales[a] * self.scales[b]
617    }
618
619    /// Map a codec similarity into the smaller-is-closer
620    /// distance convention used elsewhere in `dynvec`.
621    fn similarity_to_distance(&self, similarity: f32) -> f32 {
622        match self.distance {
623            Distance::DotProduct => -similarity,
624            Distance::Cosine => 1.0 - similarity,
625            Distance::Euclidean => (2.0 - 2.0 * similarity).max(0.0).sqrt(),
626        }
627    }
628
629    fn distance_query(&self, q_rot: &[f32], slot: usize) -> f32 {
630        self.similarity_to_distance(self.similarity_query(q_rot, slot))
631    }
632
633    fn distance_pair(&self, a: usize, b: usize) -> f32 {
634        self.similarity_to_distance(self.similarity_pair(a, b))
635    }
636
637    /// Greedy single-best descent at layer `lc`. `skip_idx` is
638    /// the slot index of the node currently being inserted, if
639    /// any; passing `usize::MAX` disables the filter for
640    /// search-time queries.
641    fn greedy_search_layer(
642        &self,
643        q_rot: &[f32],
644        entry: usize,
645        lc: usize,
646        skip_idx: usize,
647    ) -> usize {
648        let mut current = entry;
649        let mut current_score = self.distance_query(q_rot, current);
650        loop {
651            let mut improved = false;
652            let next = if lc < self.nodes[current].levels.len() {
653                let neighbours = self.nodes[current].levels[lc].as_slice();
654                let mut best = (current, current_score);
655                for &nb in neighbours {
656                    if nb == skip_idx {
657                        continue;
658                    }
659                    let s = self.distance_query(q_rot, nb);
660                    if s < best.1 {
661                        best = (nb, s);
662                        improved = true;
663                    }
664                }
665                best
666            } else {
667                (current, current_score)
668            };
669            current = next.0;
670            current_score = next.1;
671            if !improved {
672                break;
673            }
674        }
675        current
676    }
677
678    /// Beam search at layer `lc`. Returns up to `ef`
679    /// candidates ordered by the underlying max-heap.
680    fn search_layer(
681        &self,
682        q_rot: &[f32],
683        entry_points: &[usize],
684        lc: usize,
685        ef: usize,
686        skip_idx: Option<usize>,
687    ) -> Vec<MaxCandidate> {
688        let mut visited: HashSet<usize> = HashSet::new();
689        let mut frontier: BinaryHeap<Candidate> = BinaryHeap::new();
690        let mut top: BinaryHeap<MaxCandidate> = BinaryHeap::new();
691        for &ep in entry_points {
692            if Some(ep) == skip_idx {
693                continue;
694            }
695            if visited.insert(ep) {
696                let s = self.distance_query(q_rot, ep);
697                frontier.push(Candidate { idx: ep, score: s });
698                top.push(MaxCandidate { idx: ep, score: s });
699            }
700        }
701        while let Some(c) = frontier.pop() {
702            if top.len() >= ef {
703                if let Some(worst) = top.peek() {
704                    if c.score > worst.score {
705                        break;
706                    }
707                }
708            }
709            if lc < self.nodes[c.idx].levels.len() {
710                let neighbours = self.nodes[c.idx].levels[lc].as_slice();
711                for &nb in neighbours {
712                    if Some(nb) == skip_idx {
713                        continue;
714                    }
715                    if !visited.insert(nb) {
716                        continue;
717                    }
718                    let s = self.distance_query(q_rot, nb);
719                    let admit = match top.peek() {
720                        Some(worst) => s < worst.score || top.len() < ef,
721                        None => true,
722                    };
723                    if admit {
724                        frontier.push(Candidate { idx: nb, score: s });
725                        top.push(MaxCandidate { idx: nb, score: s });
726                        if top.len() > ef {
727                            top.pop();
728                        }
729                    }
730                }
731            }
732        }
733        top.into_vec()
734    }
735
736    /// Pick the top-`m` by closest-first heuristic.
737    fn select_neighbours(candidates: &[MaxCandidate], m: usize) -> Vec<usize> {
738        let mut sorted: Vec<MaxCandidate> = candidates.to_vec();
739        sorted.sort_by(|a, b| {
740            a.score
741                .partial_cmp(&b.score)
742                .unwrap_or(std::cmp::Ordering::Equal)
743        });
744        sorted.into_iter().take(m).map(|c| c.idx).collect()
745    }
746
747    /// Drop the longest edges from a node's adjacency list at
748    /// `lc` until it fits in `cap`. Uses [`Self::distance_pair`]
749    /// for stored-to-stored scoring; that double-quantisation
750    /// error is the cost of dropping the f32 fallback that
751    /// [`crate::index::HnswIndex::shrink_connections`] uses.
752    fn shrink_connections(&mut self, idx: usize, lc: usize, cap: usize) {
753        let neighbours = std::mem::take(&mut self.nodes[idx].levels[lc]);
754        let mut scored: Vec<(usize, f32)> = neighbours
755            .into_iter()
756            .map(|nb| {
757                let s = self.distance_pair(idx, nb);
758                (nb, s)
759            })
760            .collect();
761        scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
762        scored.truncate(cap);
763        self.nodes[idx].levels[lc] = scored.into_iter().map(|(nb, _)| nb).collect();
764    }
765
766    /// xorshift64* PRNG, deterministic for a given seed.
767    fn rand_unit(&mut self) -> f64 {
768        let mut x = self.rng_state;
769        x ^= x >> 12;
770        x ^= x << 25;
771        x ^= x >> 27;
772        self.rng_state = x;
773        let r = x.wrapping_mul(0x2545_F491_4F6C_DD1D);
774        let bits = (r >> 11) & ((1_u64 << 53) - 1);
775        // `bits` is in [0, 2^53), exactly representable as f64.
776        #[allow(
777            clippy::cast_precision_loss,
778            reason = "bits is in [0, 2^53), exactly representable as f64"
779        )]
780        let f = (bits as f64) / ((1_u64 << 53) as f64);
781        f
782    }
783
784    /// Random level assignment: `floor(-ln(uniform(0, 1)) *
785    /// mL)` capped at 16 to keep allocations sane.
786    fn random_level(&mut self) -> usize {
787        let r = self.rand_unit().max(f64::MIN_POSITIVE);
788        let level = (-r.ln() * self.ml).floor();
789        let clamped = level.clamp(0.0, 16.0);
790        // Clamped to [0, 16]; the cast is well-defined.
791        #[allow(
792            clippy::cast_possible_truncation,
793            clippy::cast_sign_loss,
794            reason = "clamped to [0, 16]"
795        )]
796        let lvl = clamped as usize;
797        lvl
798    }
799}
800
801impl<const BITS: u8> CodecDistance for TurboHnswIndex<BITS> {
802    fn distance(&self, a: NodeId, b: NodeId) -> f32 {
803        let Some(&sa) = self.id_to_idx.get(&a) else {
804            return f32::INFINITY;
805        };
806        let Some(&sb) = self.id_to_idx.get(&b) else {
807            return f32::INFINITY;
808        };
809        self.distance_pair(sa, sb)
810    }
811}
812
/// Scale `v` to unit L2 norm.
///
/// A zero (or empty) vector is returned unchanged. Non-finite
/// coordinates (NaN/inf) remain non-finite in the output, so
/// callers that check finiteness still reject them.
///
/// The common case squares and sums directly. When that sum
/// overflows to infinity or falls below `f32::MIN_POSITIVE`,
/// the vector is first rescaled by its largest magnitude so the
/// norm stays representable. Without that rescale a huge finite
/// vector would be flattened to all zeros (every `x / inf` is
/// 0), and a tiny one would be returned un-normalised.
fn l2_normalise(v: &[f32]) -> Vec<f32> {
    let n2: f32 = v.iter().map(|x| x * x).sum();
    if n2.is_finite() && n2 >= f32::MIN_POSITIVE {
        let n = n2.sqrt();
        return v.iter().map(|x| x / n).collect();
    }
    // Slow path: squares overflowed or underflowed (or the
    // input is zero / non-finite). `f32::max` ignores NaN, so
    // `max_abs` is the largest non-NaN magnitude.
    let max_abs = v.iter().fold(0.0_f32, |m, x| m.max(x.abs()));
    if max_abs == 0.0 || !max_abs.is_finite() {
        return v.to_vec();
    }
    // After dividing by `max_abs` every coordinate lies in
    // [-1, 1] with at least one at +/-1, so this norm is in
    // [1, sqrt(len)] (or NaN if a NaN coordinate is present).
    let scaled_norm = v
        .iter()
        .map(|x| {
            let s = x / max_abs;
            s * s
        })
        .sum::<f32>()
        .sqrt();
    v.iter().map(|x| (x / max_abs) / scaled_norm).collect()
}
821
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh 4-bit cosine index over 64-dim vectors with default
    /// HNSW parameters.
    fn cosine_index_4bit() -> TurboHnswIndex<4> {
        TurboHnswIndex::<4>::new(Distance::Cosine, 64, HnswParams::default())
            .expect("4-bit ctor")
    }

    /// Deterministic xorshift fixture producing `dim` values in
    /// [-1, 1); seed 0 is remapped because xorshift sticks at 0.
    #[allow(
        clippy::cast_precision_loss,
        clippy::cast_possible_truncation,
        reason = "test fixture: PRNG narrowed to f32"
    )]
    fn rand_vec(seed: u64, dim: usize) -> Vec<f32> {
        let mut state = match seed {
            0 => 0xDEAD_BEEF,
            s => s,
        };
        (0..dim)
            .map(|_| {
                state ^= state << 13;
                state ^= state >> 7;
                state ^= state << 17;
                let mantissa = (state >> 11) & ((1_u64 << 53) - 1);
                let unit = (mantissa as f64) / ((1_u64 << 53) as f64);
                (unit * 2.0 - 1.0) as f32
            })
            .collect()
    }

    #[test]
    fn insert_and_search_returns_self_first_4bit() {
        let mut idx = cosine_index_4bit();
        let target = rand_vec(42, 64);
        idx.insert(0, target.clone()).unwrap();
        for i in 1..50_u64 {
            let noise = rand_vec(i.wrapping_mul(1_000_003) + 1, 64);
            idx.insert(i, noise).unwrap();
        }
        let hits = idx.search(&target, 3, None).unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].id, 0);
    }

    #[test]
    fn delete_excludes_from_search() {
        let mut idx = cosine_index_4bit();
        (0..30_u64).for_each(|i| idx.insert(i, rand_vec(i + 1, 64)).unwrap());
        let probe = rand_vec(1, 64);
        let victim = idx.search(&probe, 5, None).unwrap()[0].id;
        assert!(idx.delete(victim));
        let remaining = idx.search(&probe, 5, None).unwrap();
        assert!(remaining.iter().all(|hit| hit.id != victim));
    }

    #[test]
    fn duplicate_id_rejected() {
        let mut idx = cosine_index_4bit();
        idx.insert(7, rand_vec(7, 64)).unwrap();
        let second = idx.insert(7, rand_vec(8, 64));
        assert!(matches!(second, Err(IndexError::Duplicate(7))));
    }

    #[test]
    fn dimension_mismatch_rejected() {
        let mut idx = cosine_index_4bit();
        let outcome = idx.insert(0, vec![0.1; 32]);
        assert!(matches!(outcome, Err(IndexError::DimensionMismatch { .. })));
    }

    #[test]
    fn empty_index_search_is_empty() {
        let idx = cosine_index_4bit();
        assert!(idx.search(&rand_vec(0, 64), 5, None).unwrap().is_empty());
    }

    #[test]
    fn ctor_rejects_misaligned_dim() {
        let built = TurboHnswIndex::<4>::new(Distance::Cosine, 7, HnswParams::default());
        assert!(matches!(
            built,
            Err(IndexError::DimensionMismatch {
                expected: 8,
                got: 7
            })
        ));
    }
}