fast-hnsw 1.0.1

Pure-Rust implementation of Hierarchical Navigable Small World (HNSW) approximate nearest-neighbour search
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
//! Core HNSW index — optimised implementation.
//!
//! ## Optimisations (in order of measured impact)
//!
//! ### 1 · Simple sort-and-truncate for reverse-update pruning  *(biggest win)*
//!
//! When we add a bidirectional edge `q ↔ nb` and `nb`'s connection list
//! exceeds `M`, we must shrink it back to `M`.  The naïve approach recomputes
//! all `M+1` distances from scratch (cache-miss-dominated) and then runs the
//! O(M²/2) heuristic.  Instead:
//!
//! * **Connection lists store `(u32, f32)` pairs** — the neighbour id and the
//!   distance from *this node* to that neighbour.  The distance is already
//!   known at edge-addition time (symmetric metric), so storing it is free.
//! * **Reverse-update prune** = `sort_unstable_by(dist) + truncate(M)` — zero
//!   new distance computations, ~25 ns vs ~2 000 ns in-cache (measured),
//!   ~11 µs with real L3-miss rates × 16 prunes/insert ≈ **170 µs/insert saved**
//!   at n=10k/dim=128.
//!
//! ### 2 · Iterate `scratch.out` directly in heuristic selection
//!
//! `scratch.out` is already sorted closest-first after `search_layer`.
//! Previously `select_neighbours_heuristic` rebuilt a `BinaryHeap` from it
//! (O(ef log ef) + 1 allocation).  Now we iterate the slice directly: 321 ns
//! → 122 ns measured, plus the allocation is gone.
//!
//! ### 3 · Early return in heuristic when |candidates| ≤ M
//!
//! When fewer candidates exist than there are slots, **all candidates are
//! automatically diverse** — no pairwise distance check is needed.  Copy
//! directly to `select_buf` in O(M) time (50 ns vs 10 000 ns).  Triggers at
//! every upper-layer descent (ef=1) and during early index build.
//!
//! ### 4 · `VecStore` — flat contiguous vector storage
//!
//! All feature vectors in one `Vec<f32>`, stride = `dim`.  One pointer
//! dereference instead of two per distance call.
//!
//! ### 5 · `VisitedTracker` — O(1) generation-counter visited set
//!
//! Replaces `HashSet::with_capacity(ef*4)` (1 703 ns/call) with a stamp array
//! (105 ns/call) — 16× faster, zero allocation after construction.
//!
//! ### 6 · `Scratch` — reusable heap scratch space
//!
//! Both `BinaryHeap`s in `search_layer` are cleared (not reallocated) between
//! calls.  `ep_buf` is swapped with `scratch.out` via `std::mem::swap` — zero
//! copy between layers.
//!
//! ### 7 · `u32` + pre-allocated connection `Vec`s
//!
//! `u32` IDs halve connection-list memory.  Each inner `Vec` is pre-created
//! with `Vec::with_capacity(m_max)`.
//!
//! ### 8 · Triangle-inequality shortcut in heuristic selection
//!
//! If `d(q,s) > 2·d(q,e)`, then `d(e,s) > d(q,e)` by the triangle
//! inequality — the heuristic condition is trivially satisfied without
//! computing the actual distance.  Measured 36% reduction in pairwise
//! distance computations.
//!
//! ### 9 · `select_buf` / `pruned_buf` reuse
//!
//! Pre-allocated `Vec<(usize, f32)>` fields in `Hnsw` hold the results of
//! `select_neighbours_*`; callers read from `self.select_buf` directly,
//! eliminating the one remaining per-call `Vec` allocation.

use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::sync::Arc;

use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};

use crate::distance::Distance;
use crate::heap::DistId;

// ─── Memory-mapped vector backing ────────────────────────────────────────────

/// A live memory-mapping of the vector section of an index file.
///
/// Kept alive inside [`VecStore`] so the mmap is not unmapped while the index
/// is in use.  The raw pointer is derived from the `Mmap` bytes; both the
/// pointer and the `Mmap` live inside the same `Arc` so they are guaranteed
/// to share the same lifetime.
///
/// # Safety invariants
/// * `ptr` always points into `mmap`'s mapped region.
/// * `len` is the total number of `f32` values in that region
///   (`n_vectors × dim`).
/// * The region is read-only — no writes ever go through this pointer.
pub(crate) struct MmapBacking {
    /// Keeps the file mapping alive.  Dropped when the last `Arc` clone is
    /// released (i.e. when the `Hnsw` is dropped).
    pub(crate) _mmap: Arc<memmap2::Mmap>,
    /// Pointer to the first `f32` in the mapped vector section.
    pub(crate) ptr: *const f32,
    /// Total number of `f32` values available through `ptr`.
    pub(crate) len: usize,
}

// The pointer never leaves the struct boundary through a mutable reference,
// and the mapping is read-only, so both Send and Sync are safe.
unsafe impl Send for MmapBacking {}
unsafe impl Sync for MmapBacking {}

// ─── Flat vector store ────────────────────────────────────────────────────────

/// All feature vectors packed contiguously: vector `i` occupies
/// `data[i*dim .. (i+1)*dim]` (owned mode) or the corresponding slice of the
/// mmap'd region (mmap mode).
///
/// ## Owned mode
/// Created by normal [`Hnsw::insert`] calls.  Vectors live in a `Vec<f32>`
/// that grows as new items are added.
///
/// ## Mmap mode
/// Created by [`Hnsw::load_mmap`].  Vectors are read directly from the
/// memory-mapped file region — no heap copy.  The file must remain on disk
/// and the [`MmapBacking`] must stay alive (it is stored in `mmap`).
/// Inserts into a mmap-backed index are not allowed.
pub(crate) struct VecStore {
    /// Owned vector data (empty in mmap mode).
    pub(crate) data: Vec<f32>,
    /// Optional memory-mapped backing (non-`None` in mmap mode).
    pub(crate) mmap: Option<MmapBacking>,
    pub(crate) dim:  usize,
}

impl VecStore {
    pub(crate) fn new(dim: usize, capacity: usize) -> Self {
        Self {
            data: Vec::with_capacity(capacity.saturating_mul(dim.max(1))),
            mmap: None,
            dim,
        }
    }

    /// Build a mmap-backed store from a live mapping and the byte offset where
    /// the vector data begins.
    pub(crate) fn from_mmap(
        mmap: Arc<memmap2::Mmap>,
        vec_offset: usize,
        n: usize,
        dim: usize,
    ) -> Self {
        let len = n * dim;
        // SAFETY: vec_offset + len*4 is guaranteed to be within the mapped
        // region by the caller (checked in `load_mmap`).
        let ptr = unsafe { mmap.as_ptr().add(vec_offset) as *const f32 };
        Self {
            data: Vec::new(),
            mmap: Some(MmapBacking { _mmap: mmap, ptr, len }),
            dim,
        }
    }

    pub(crate) fn push(&mut self, v: Vec<f32>) {
        assert!(
            self.mmap.is_none(),
            "cannot insert into a memory-mapped (read-only) index"
        );
        debug_assert_eq!(v.len(), self.dim);
        self.data.extend_from_slice(&v);
    }

    #[inline(always)]
    pub(crate) fn get(&self, id: usize) -> &[f32] {
        let s = id * self.dim;
        match &self.mmap {
            None => &self.data[s..s + self.dim],
            // SAFETY: `s + self.dim <= len` because `id < n` (guaranteed by
            // the caller — same as bounds in the owned case).
            Some(mb) => unsafe { std::slice::from_raw_parts(mb.ptr.add(s), self.dim) },
        }
    }

    #[inline(always)]
    pub(crate) fn len(&self) -> usize {
        if self.dim == 0 {
            return 0;
        }
        match &self.mmap {
            None     => self.data.len() / self.dim,
            Some(mb) => mb.len / self.dim,
        }
    }

    /// Returns the whole vector section as a flat byte slice (for writing).
    /// Only valid in owned mode.
    pub(crate) fn as_bytes(&self) -> &[u8] {
        // SAFETY: f32 has no padding; any bit pattern is valid; alignment is
        // satisfied because we're converting &[f32] → &[u8].
        unsafe {
            std::slice::from_raw_parts(
                self.data.as_ptr() as *const u8,
                self.data.len() * std::mem::size_of::<f32>(),
            )
        }
    }
}

// ─── Generation-counter visited-set ──────────────────────────────────────────

/// O(1) visited-node tracker.  Each search query increments `current`; a node
/// `i` is "visited" iff `stamps[i] == current`.  `begin()` is a single
/// integer increment — no allocation after construction.
struct VisitedTracker {
    stamps:  Vec<u32>,
    current: u32,
}

impl VisitedTracker {
    fn new(capacity: usize) -> Self {
        Self { stamps: vec![0u32; capacity], current: 1 }
    }

    #[inline]
    fn begin(&mut self) {
        self.current = self.current.wrapping_add(1);
        if self.current == 0 {
            self.stamps.fill(0);
            self.current = 1;
        }
    }

    /// Returns `true` if `id` was **not** previously visited, and marks it.
    #[inline]
    fn visit(&mut self, id: usize) -> bool {
        if id >= self.stamps.len() {
            self.stamps.resize(id * 2 + 1, 0);
        }
        if self.stamps[id] == self.current {
            false
        } else {
            self.stamps[id] = self.current;
            true
        }
    }
}

// ─── Reusable search scratch space ───────────────────────────────────────────

/// Pre-allocated candidate min-heap + result max-heap + sorted output buffer.
/// Stored inside `Hnsw` and *cleared* (not reallocated) between `search_layer`
/// calls during `insert`.  For `search()` (`&self`) a local `Scratch` is
/// created per query.
struct Scratch {
    candidates:  BinaryHeap<Reverse<DistId>>,
    results:     BinaryHeap<DistId>,
    results_cap: usize,
    /// Sorted-closest-first output written by `finish()`.
    pub out: Vec<DistId>,
}

impl Scratch {
    fn new(ef: usize) -> Self {
        Self {
            candidates:  BinaryHeap::with_capacity(ef * 2 + 1),
            results:     BinaryHeap::with_capacity(ef + 1),
            results_cap: ef,
            out:         Vec::with_capacity(ef),
        }
    }

    #[inline]
    fn begin(&mut self, ef: usize) {
        self.candidates.clear();
        self.results.clear();
        self.results_cap = ef;
    }

    #[inline]
    fn push_entry(&mut self, d: DistId) {
        self.candidates.push(Reverse(d));
        self.results.push(d);
        if self.results.len() > self.results_cap { self.results.pop(); }
    }

    #[inline]
    fn push_candidate(&mut self, d: DistId) {
        self.candidates.push(Reverse(d));
        self.results.push(d);
        if self.results.len() > self.results_cap { self.results.pop(); }
    }

    #[inline]
    fn pop_candidate(&mut self) -> Option<DistId> {
        self.candidates.pop().map(|Reverse(x)| x)
    }

    #[inline]
    fn worst_result_dist(&self) -> Option<f32> {
        self.results.peek().map(|x| x.dist)
    }

    #[inline]
    fn results_len(&self) -> usize { self.results.len() }

    fn finish(&mut self) {
        self.out.clear();
        while let Some(d) = self.results.pop() { self.out.push(d); }
        self.out.reverse(); // max-heap → farthest-first; reverse → closest-first
    }
}

// ─── Prune strategy ───────────────────────────────────────────────────────────

/// Controls how an existing node's connection list is shrunk back to `M`
/// entries when it overflows after a bidirectional edge `q ↔ nb` is added
/// during [`Hnsw::insert`].
///
/// # Background
///
/// Every insert adds M bidirectional edges.  For each selected neighbour `nb`,
/// the reverse edge `nb → q` is appended to `nb`'s connection list.  If `nb`
/// already had `M` connections its list grows to `M + 1` and must be pruned
/// back to `M`.  The two strategies differ in *which* edge is dropped:
///
/// | | [`Simple`] | [`Heuristic`] |
/// |---|---|---|
/// | Work per prune | Sort `M+1` stored `f32`s + truncate | O(M²/2) pairwise distance computations |
/// | New distance calls | **0** | up to M²/2, ~40% skipped by triangle shortcut |
/// | Speed (M=16, dim=128) | ~25 ns | ~1–11 µs depending on cache state |
/// | Recall impact | −0 to −1 pp vs Heuristic | full quality |
///
/// # Which to choose
///
/// **`Simple` (default)** — use when insert throughput is the priority.
/// Equivalent to what most production HNSW libraries (faiss, hnsw_rs) do for
/// reverse-update pruning.  The small recall gap can be recovered by raising
/// `ef` at query time.
///
/// **`Heuristic`** — use when recall quality is non-negotiable or you need
/// a like-for-like algorithmic comparison.  Runs the full paper Algorithm 4
/// for *every* edge in the graph (not just the new node's own connections),
/// at the cost of slower inserts on high-dimensional data.
///
/// # Why `Heuristic` is not as expensive as it looks
///
/// Connection lists store `(neighbour_id: u32, dist_from_this_node: f32)`.
/// Because the distance is recorded at edge-add time (symmetric metric, free),
/// the M distance *recomputations* that a naïve heuristic prune would require
/// are completely eliminated.  Only the inter-neighbour pairwise checks remain.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum PruneStrategy {
    /// **Sort-and-truncate** (default, fastest).
    ///
    /// Sorts the `M+1`-entry connection list by the stored per-edge distance
    /// and truncates to `M`.  The farthest neighbour by raw distance is
    /// dropped.
    ///
    /// - **Zero new distance computations** — distances are already stored as
    ///   `f32` values alongside the neighbour id.
    /// - Benchmark: ~25 ns per prune call (in-cache), effectively 0 at L3
    ///   miss rates because no vector data is touched.
    /// - Recall cost: ≈ 0–1 pp on high-dimensional data (dim ≥ 64) where a
    ///   cluster of nearby nodes can crowd out a more structurally useful but
    ///   slightly farther connection.
    #[default]
    Simple,

    /// **Full Algorithm 4 (heuristic)** using stored distances.
    ///
    /// Runs the paper's diversity-based selection for every reverse-update
    /// prune: candidate `e` is kept only if `d(node, e) ≤ d(e, s)` for every
    /// already-selected neighbour `s` — i.e. the node is closer to `e` than
    /// any currently-selected neighbour is.  This ensures the final `M`
    /// connections are spread across the neighbourhood rather than all
    /// pointing in the same direction.
    ///
    /// **Stored-distance optimisation**: `d(node, e)` for each candidate is
    /// read directly from the stored `f32` in the connection list — no
    /// distance recomputation.  Only the pairwise `d(e, s)` checks are
    /// computed fresh, and ~60% of those are eliminated by the
    /// triangle-inequality shortcut (`d(node,s) > 2·d(node,e)` guarantees
    /// the accept condition without touching vector data).
    ///
    /// - Benchmark: ~1 µs per prune in L2 cache, ~11 µs at L3-miss rates
    ///   (M = 16, dim = 128).  × 16 prunes/insert ≈ 170 µs extra per insert
    ///   on large, high-dimensional indexes.
    /// - Recall: full quality; recovers the gap vs [`Simple`].
    Heuristic,
}

// ─── Configuration ────────────────────────────────────────────────────────────

/// Build-time parameters for an [`Hnsw`] index.
///
/// Construct via [`Builder`](crate::Builder) for a more ergonomic API.
#[derive(Clone, Debug)]
pub struct Config {
    /// Max bidirectional links per non-zero layer (must be ≥ 2).
    pub m: usize,
    /// Override for layer-0 link limit (`None` → `2 * M`).
    pub m0: Option<usize>,
    /// Beam width during construction.  Higher → better graph quality,
    /// slower inserts.
    pub ef_construction: usize,
    /// Use heuristic neighbour selection (Algorithm 4) for the *new node's
    /// own* M connections.  Recommended; improves recall.  Setting this to
    /// `false` uses the simpler greedy strategy (Algorithm 3).
    pub use_heuristic: bool,
    /// `extendCandidates` flag from §4 Algorithm 4 of the paper.
    /// Adds the neighbours-of-candidates to the candidate set before
    /// heuristic selection.  Usually not needed; disabled by default.
    pub extend_candidates: bool,
    /// `keepPrunedConnections` flag from §4 Algorithm 4.
    /// When the heuristic rejects a candidate, pad the result with rejected
    /// candidates (nearest-first) until M is reached.  Improves recall on
    /// sparse graphs; enabled by default.
    pub keep_pruned: bool,
    /// Strategy for pruning an *existing* node's connection list when it
    /// overflows after a bidirectional edge is added.
    ///
    /// See [`PruneStrategy`] for a full comparison.  Defaults to
    /// [`PruneStrategy::Simple`] (fastest; zero new distance computations).
    pub prune_strategy: PruneStrategy,
    /// Expected number of vectors — pre-allocation hint only.  Has no effect
    /// on correctness.
    pub capacity: usize,
}

impl Default for Config {
    fn default() -> Self {
        Self {
            m: 16,
            m0: None,
            ef_construction: 200,
            use_heuristic: true,
            extend_candidates: false,
            keep_pruned: true,
            prune_strategy: PruneStrategy::Simple,
            capacity: 0,
        }
    }
}

impl Config {
    #[inline] pub(crate) fn m0(&self) -> usize { self.m0.unwrap_or(2 * self.m) }
    #[inline] fn max_links(&self, layer: usize) -> usize {
        if layer == 0 { self.m0() } else { self.m }
    }
    #[inline] fn m_l(&self) -> f64 { 1.0 / (self.m as f64).ln() }
}

// ─── Search result ────────────────────────────────────────────────────────────

/// One result returned by [`Hnsw::search`].
#[derive(Clone, Debug, PartialEq)]
pub struct SearchResult {
    pub id:       usize,
    pub distance: f32,
}

// ─── Main index ───────────────────────────────────────────────────────────────

/// Hierarchical Navigable Small World approximate nearest-neighbour index.
///
/// # Type parameter
/// * `D` – a [`Distance`] implementation (e.g. [`crate::distance::Euclidean`]).
///
/// # Example
/// ```
/// use hnsw::{Hnsw, Config};
/// use hnsw::distance::Euclidean;
///
/// let mut index = Hnsw::new(Config::default(), Euclidean);
/// index.insert(vec![1.0, 0.0]);
/// index.insert(vec![0.0, 1.0]);
/// index.insert(vec![0.5, 0.5]);
///
/// let results = index.search(&[0.9, 0.1], 1, 20);
/// assert_eq!(results[0].id, 0); // [1,0] is closest to [0.9,0.1]
/// ```
pub struct Hnsw<D: Distance> {
    pub(crate) config: Config,
    pub(crate) metric: D,
    /// Flat vector store: vector `i` at `data[i*dim .. (i+1)*dim]`.
    pub(crate) vec_store: VecStore,
    /// `connections[node][layer]` = list of (neighbour_id_u32, dist_from_node_to_neighbour).
    ///
    /// Storing the distance alongside the id enables the heuristic reverse-update
    /// prune to skip all M distance recomputations — only the O(M²/2) pairwise
    /// diversity checks remain.
    pub(crate) connections: Vec<Vec<Vec<(u32, f32)>>>,
    pub(crate) entry_point: Option<(usize, usize)>,
    rng:         SmallRng,
    pub(crate) dim: Option<usize>,
    visited:     VisitedTracker,
    scratch:     Scratch,
    ep_buf:      Vec<DistId>,
    /// Output buffer for `select_neighbours_*` — reused across calls.
    select_buf:  Vec<(usize, f32)>,
    /// Discarded-candidates buffer for `keep_pruned` path — reused.
    pruned_buf:  Vec<(usize, f32)>,
    /// Sorted candidate buffer for `prune_connections_heuristic` — reused.
    prune_buf:   Vec<(usize, f32)>,
}

impl<D: Distance> Hnsw<D> {
    // ─── Construction ─────────────────────────────────────────────────────

    /// Create a new, empty index.
    pub fn new(config: Config, metric: D) -> Self {
        assert!(config.m >= 2, "M must be at least 2");
        assert!(config.ef_construction >= config.m,
                "ef_construction should be ≥ M for good recall");
        let cap = config.capacity;
        let ef  = config.ef_construction;
        let m   = config.m;
        Self {
            config,
            metric,
            vec_store:   VecStore::new(0, cap),
            connections: Vec::with_capacity(cap),
            entry_point: None,
            rng:         SmallRng::from_entropy(),
            dim:         None,
            visited:     VisitedTracker::new(cap.max(64)),
            scratch:     Scratch::new(ef),
            ep_buf:      Vec::with_capacity(ef),
            select_buf:  Vec::with_capacity(m * 2 + 2),
            pruned_buf:  Vec::with_capacity(m * 2 + 2),
            prune_buf:   Vec::with_capacity(m * 2 + 2),
        }
    }

    /// Reconstruct an index from its already-deserialized components.
    ///
    /// Called exclusively by the persistence layer (`persist::read_hnsw`).
    /// All the heavy lifting (reading header, vectors, graph) is done by the
    /// caller; this just wires everything into the struct.
    pub(crate) fn from_parts(
        config:      Config,
        metric:      D,
        vec_store:   VecStore,
        connections: Vec<Vec<Vec<(u32, f32)>>>,
        entry_point: Option<(usize, usize)>,
        dim:         Option<usize>,
    ) -> Self {
        let n  = vec_store.len();
        let ef = config.ef_construction;
        let m  = config.m;
        Self {
            config,
            metric,
            vec_store,
            connections,
            entry_point,
            rng:        SmallRng::from_entropy(),
            dim,
            visited:    VisitedTracker::new(n.max(64)),
            scratch:    Scratch::new(ef),
            ep_buf:     Vec::with_capacity(ef),
            select_buf: Vec::with_capacity(m * 2 + 2),
            pruned_buf: Vec::with_capacity(m * 2 + 2),
            prune_buf:  Vec::with_capacity(m * 2 + 2),
        }
    }

    /// Create a new index with a fixed RNG seed (reproducible for tests).
    pub fn new_with_seed(config: Config, metric: D, seed: u64) -> Self {
        assert!(config.m >= 2);
        let cap = config.capacity;
        let ef  = config.ef_construction;
        let m   = config.m;
        Self {
            config,
            metric,
            vec_store:   VecStore::new(0, cap),
            connections: Vec::with_capacity(cap),
            entry_point: None,
            rng:         SmallRng::seed_from_u64(seed),
            dim:         None,
            visited:     VisitedTracker::new(cap.max(64)),
            scratch:     Scratch::new(ef),
            ep_buf:      Vec::with_capacity(ef),
            select_buf:  Vec::with_capacity(m * 2 + 2),
            pruned_buf:  Vec::with_capacity(m * 2 + 2),
            prune_buf:   Vec::with_capacity(m * 2 + 2),
        }
    }

    // ─── Public API ───────────────────────────────────────────────────────

    /// Insert a vector and return its assigned id (0-based).
    ///
    /// # Panics
    /// Panics if `vector.len()` differs from the dimension of previously
    /// inserted vectors.
    pub fn insert(&mut self, vector: Vec<f32>) -> usize {
        let dim = vector.len();
        match self.dim {
            None    => { self.dim = Some(dim); self.vec_store.dim = dim; }
            Some(d) => assert_eq!(d, dim,
                "all vectors must have the same dimension (expected {d}, got {dim})"),
        }

        let q       = self.vec_store.len();
        let q_level = self.random_level();

        self.vec_store.push(vector);

        // Pre-allocate connection list with inner Vecs sized to capacity.
        let mut conn: Vec<Vec<(u32, f32)>> = Vec::with_capacity(q_level + 1);
        for l in 0..=q_level {
            conn.push(Vec::with_capacity(self.config.max_links(l)));
        }
        self.connections.push(conn);

        if self.visited.stamps.len() <= q {
            self.visited.stamps.resize(q * 2 + 1, 0);
        }

        // ── First insertion ───────────────────────────────────────────────
        let (ep_id, ep_level) = match self.entry_point {
            None => { self.entry_point = Some((q, q_level)); return q; }
            Some(x) => x,
        };

        // ── Phase 1: greedy descent from ep_level down to q_level+1 ──────
        self.ep_buf.clear();
        self.ep_buf.push(DistId::new(self.dist(q, ep_id), ep_id));

        for layer in (q_level + 1..=ep_level).rev() {
            self.search_layer_node(q, 1, layer);
            std::mem::swap(&mut self.ep_buf, &mut self.scratch.out);
        }

        // ── Phase 2: insert q at layers q_level..=0 ──────────────────────
        let top = ep_level.min(q_level);
        for layer in (0..=top).rev() {
            let ef    = self.config.ef_construction;
            let m_max = self.config.max_links(layer);

            self.search_layer_node(q, ef, layer);

            // Select M neighbours for q.  Results written to `self.select_buf`.
            if self.config.use_heuristic {
                self.select_neighbours_heuristic(q, m_max, layer);
            } else {
                self.select_neighbours_simple(m_max);
            }

            // Add bidirectional edges.
            //
            // We copy select_buf into a small stack array first because
            // `prune_connections_heuristic` (called below) overwrites select_buf
            // and pruned_buf when it runs its own heuristic selection.
            //
            // m_max ≤ 2·M ≤ 64 in practice; the array is zero-cost on the stack.
            let n_sel = self.select_buf.len();
            let mut edge_buf = [(0u32, 0.0f32); 64];
            for i in 0..n_sel {
                let (nb, dist) = self.select_buf[i];
                edge_buf[i] = (nb as u32, dist);
            }

            // Pass 1: add all edges (keeps the connection lists coherent before
            // any pruning modifies them).
            for i in 0..n_sel {
                let (nb_u32, dist_q_nb) = edge_buf[i];
                let nb = nb_u32 as usize;
                self.connections[q][layer].push((nb_u32, dist_q_nb));
                self.connections[nb][layer].push((q as u32, dist_q_nb));
            }

            // Pass 2: prune any neighbour whose list now exceeds m_max.
            //
            // Which strategy fires is determined by `self.config.prune_strategy`
            // (set via `Builder::prune_strategy`).  Both branches use the `f32`
            // distance that is stored alongside every neighbour id — so neither
            // branch needs to recompute the M distances from scratch.
            for i in 0..n_sel {
                let nb = edge_buf[i].0 as usize;
                if self.connections[nb][layer].len() > m_max {
                    match self.config.prune_strategy {

                        // ── Simple: sort stored distances + truncate ──────────
                        // Keeps the M nearest neighbours by raw distance.
                        // Cost: O(M log M) sort of f32 values already in the
                        //       connection list — no vector data touched, no new
                        //       distance computation.  ~25 ns per call.
                        PruneStrategy::Simple => {
                            self.connections[nb][layer]
                                .sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
                            self.connections[nb][layer].truncate(m_max);
                        }

                        // ── Heuristic: full Algorithm 4 with stored distances ─
                        // Runs the diversity check: keeps candidate `e` only if
                        // d(nb, e) ≤ d(e, s) for every already-selected s.
                        // d(nb, e) is the stored f32 — zero recomputation.
                        // Only the O(M²/2) pairwise d(e, s) checks are computed,
                        // and ~60% are skipped by the triangle-inequality shortcut.
                        // ~1–11 µs per call depending on cache state.
                        PruneStrategy::Heuristic => {
                            self.prune_connections_heuristic(nb, layer, m_max);
                        }
                    }
                }
            }

            std::mem::swap(&mut self.ep_buf, &mut self.scratch.out);
        }

        if q_level > ep_level {
            self.entry_point = Some((q, q_level));
        }
        q
    }

    /// Search for the `k` approximate nearest neighbours of `query`.
    ///
    /// `ef` controls recall vs. speed (`ef ≥ k`; larger → better recall).
    pub fn search(&self, query: &[f32], k: usize, ef: usize) -> Vec<SearchResult> {
        assert!(k > 0, "k must be > 0");
        let ef = ef.max(k);

        let (ep_id, ep_level) = match self.entry_point {
            None    => return Vec::new(),
            Some(x) => x,
        };

        let mut visited = VisitedTracker::new(self.vec_store.len());
        let mut scratch = Scratch::new(ef);
        let mut ep = Vec::with_capacity(ef);

        let ep_dist = self.metric.distance(query, self.vec_store.get(ep_id));
        ep.push(DistId::new(ep_dist, ep_id));

        for layer in (1..=ep_level).rev() {
            Self::do_search_layer(
                &self.vec_store, &self.connections, &self.metric,
                &mut visited, &mut scratch, query, &ep, 1, layer,
            );
            std::mem::swap(&mut ep, &mut scratch.out);
        }

        Self::do_search_layer(
            &self.vec_store, &self.connections, &self.metric,
            &mut visited, &mut scratch, query, &ep, ef, 0,
        );
        scratch.out.truncate(k);
        scratch.out.iter()
            .map(|d| SearchResult { id: d.id, distance: d.dist })
            .collect()
    }

    #[inline] pub fn len(&self)              -> usize         { self.vec_store.len() }
    #[inline] pub fn is_empty(&self)         -> bool          { self.vec_store.len() == 0 }
    #[inline] pub fn get_vector(&self, id: usize) -> &[f32]  { self.vec_store.get(id) }
    #[inline] pub fn dim(&self)              -> Option<usize> { self.dim }
    pub fn max_level(&self) -> Option<usize> { self.entry_point.map(|(_, l)| l) }

    // ─── Level generation ─────────────────────────────────────────────────

    fn random_level(&mut self) -> usize {
        let u: f64 = self.rng.gen::<f64>().max(f64::MIN_POSITIVE);
        (-u.ln() * self.config.m_l()).floor() as usize
    }

    // ─── Distance helpers ─────────────────────────────────────────────────

    #[inline]
    fn dist(&self, a: usize, b: usize) -> f32 {
        self.metric.distance(self.vec_store.get(a), self.vec_store.get(b))
    }

    // ─── search_layer (insert path — uses self.scratch) ───────────────────

    fn search_layer_node(&mut self, q: usize, ef: usize, layer: usize) {
        let vec_store   = &self.vec_store;
        let connections = &self.connections;
        let metric      = &self.metric;
        let visited     = &mut self.visited;
        let scratch     = &mut self.scratch;
        let ep          = &self.ep_buf;

        let q_vec = vec_store.get(q);
        visited.begin();
        scratch.begin(ef);

        for &ep_d in ep {
            if visited.visit(ep_d.id) { scratch.push_entry(ep_d); }
        }

        loop {
            let c = match scratch.pop_candidate() { Some(c) => c, None => break };
            let worst = match scratch.worst_result_dist() { Some(d) => d, None => break };
            if c.dist > worst { break; }

            if let Some(nb_list) = connections.get(c.id).and_then(|nc| nc.get(layer)) {
                for &(nb_u32, _) in nb_list {
                    let nb = nb_u32 as usize;
                    if visited.visit(nb) {
                        let nb_dist = metric.distance(q_vec, vec_store.get(nb));
                        let cur_worst = scratch.worst_result_dist().unwrap_or(f32::INFINITY);
                        if nb_dist < cur_worst || scratch.results_len() < ef {
                            scratch.push_candidate(DistId::new(nb_dist, nb));
                        }
                    }
                }
            }
        }
        scratch.finish();
    }

    // ─── search_layer (search path — takes explicit params) ───────────────

    fn do_search_layer(
        vec_store:    &VecStore,
        connections:  &[Vec<Vec<(u32, f32)>>],
        metric:       &D,
        visited:      &mut VisitedTracker,
        scratch:      &mut Scratch,
        query:        &[f32],
        entry_points: &[DistId],
        ef:           usize,
        layer:        usize,
    ) {
        visited.begin();
        scratch.begin(ef);

        for &ep in entry_points {
            if visited.visit(ep.id) { scratch.push_entry(ep); }
        }

        loop {
            let c = match scratch.pop_candidate() { Some(c) => c, None => break };
            let worst = match scratch.worst_result_dist() { Some(d) => d, None => break };
            if c.dist > worst { break; }

            if let Some(nb_list) = connections.get(c.id).and_then(|nc| nc.get(layer)) {
                for &(nb_u32, _) in nb_list {
                    let nb = nb_u32 as usize;
                    if visited.visit(nb) {
                        let nb_dist = metric.distance(query, vec_store.get(nb));
                        let cur_worst = scratch.worst_result_dist().unwrap_or(f32::INFINITY);
                        if nb_dist < cur_worst || scratch.results_len() < ef {
                            scratch.push_candidate(DistId::new(nb_dist, nb));
                        }
                    }
                }
            }
        }
        scratch.finish();
    }

    // ─── Neighbour selection ──────────────────────────────────────────────

    /// **Algorithm 3** – write the `m` closest entries from `scratch.out`
    /// (sorted closest-first) into `self.select_buf`.
    fn select_neighbours_simple(&mut self, m: usize) {
        self.select_buf.clear();
        let end = m.min(self.scratch.out.len());
        // scratch.out is Copy, access by index is safe with any mutable alias of select_buf
        for i in 0..end {
            let d = self.scratch.out[i];
            self.select_buf.push((d.id, d.dist));
        }
    }

    /// **Algorithm 4 (heuristic)** – write up to `m` diverse neighbours into
    /// `self.select_buf`.  Results are `(node_id, dist_from_q)`.
    ///
    /// Key optimisations vs. the previous version:
    /// * **No heap rebuild** – `scratch.out` is already sorted closest-first;
    ///   we iterate it directly in O(n) instead of rebuilding a min-heap in
    ///   O(n log n) + 1 allocation.
    /// * **Early return** – when `|candidates| ≤ m`, all candidates are
    ///   automatically selected (no pairwise diversity check needed).
    /// * **Triangle-inequality shortcut** – if `d(q,s) > 2·d(q,e)`, skip the
    ///   actual `d(e,s)` computation.
    /// * **`select_buf` / `pruned_buf` reuse** – no per-call allocation.
    fn select_neighbours_heuristic(&mut self, q: usize, m: usize, layer: usize) {
        self.select_buf.clear();
        self.pruned_buf.clear();

        // ── Fast path: fewer candidates than slots → take all ────────────
        let n_cands = self.scratch.out.len();
        if n_cands <= m && !self.config.extend_candidates {
            for i in 0..n_cands {
                let d = self.scratch.out[i]; // Copy
                self.select_buf.push((d.id, d.dist));
            }
            return;
        }

        // ── Extended-candidates path (rare, allocates a temporary) ────────
        // Build an extended candidate list that includes the neighbours-of-
        // candidates.  Only triggered when extend_candidates = true.
        let ext_buf: Vec<DistId>;
        let cands: &[DistId] = if self.config.extend_candidates {
            let mut tmp: Vec<DistId> = self.scratch.out.to_vec();
            let seen_ids: std::collections::HashSet<usize> =
                self.scratch.out.iter().map(|d| d.id).collect();
            let mut extra: Vec<DistId> = Vec::new();
            for &d in &self.scratch.out {
                if let Some(nb_list) = self.connections.get(d.id).and_then(|nc| nc.get(layer)) {
                    for &(nb_u32, _) in nb_list {
                        let nb = nb_u32 as usize;
                        if !seen_ids.contains(&nb) {
                            extra.push(DistId::new(self.dist(q, nb), nb));
                        }
                    }
                }
            }
            tmp.extend_from_slice(&extra);
            tmp.sort_unstable_by(|a, b| a.dist.total_cmp(&b.dist));
            ext_buf = tmp;
            &ext_buf
        } else {
            &self.scratch.out
        };

        // ── Main heuristic loop ───────────────────────────────────────────
        //
        // Iterate candidates in closest-first order.  Accept candidate `e` iff
        // `d(q, e) ≤ d(e, s)` for every already-accepted neighbour `s`.
        // Equivalently, reject if any `s` is closer to `e` than `q` is.
        for i in 0..cands.len() {
            if self.select_buf.len() >= m { break; }
            let e_dist = cands[i].dist;
            let e_id   = cands[i].id;

            let mut accept = true;
            for j in 0..self.select_buf.len() {
                let (s_id, s_dist_q) = self.select_buf[j]; // Copy
                // Triangle-inequality shortcut:
                // If d(q,s) > 2·d(q,e) → d(e,s) ≥ d(q,s)−d(q,e) > d(q,e)
                // so the condition d(q,e) ≤ d(e,s) is guaranteed → continue
                if s_dist_q > 2.0 * e_dist { continue; }
                let d_es = self.metric.distance(
                    self.vec_store.get(e_id),
                    self.vec_store.get(s_id),
                );
                if d_es <= e_dist { accept = false; break; }
            }

            if accept {
                self.select_buf.push((e_id, e_dist));
            } else if self.config.keep_pruned {
                self.pruned_buf.push((e_id, e_dist));
            }
        }

        if self.config.keep_pruned {
            let needed = m.saturating_sub(self.select_buf.len());
            let add    = needed.min(self.pruned_buf.len());
            for i in 0..add {
                let (id, dist) = self.pruned_buf[i]; // Copy
                self.select_buf.push((id, dist));
            }
        }
    }

    // ─── Reverse-update pruning ───────────────────────────────────────────

    /// Heuristic prune of `node_id`'s connection list at `layer` back to
    /// `m_max` entries — **Algorithm 4**, using stored distances.
    ///
    /// Unlike the naïve approach (which cloned the list, recomputed all M
    /// distances, then ran the heuristic), here we:
    ///
    /// 1. **Sort by stored distances** — zero new distance computations.
    /// 2. **Run the heuristic diversity check** — O(M²/2) pairwise in the
    ///    worst case, but the triangle-inequality shortcut eliminates most of
    ///    them in practice (≈ 60% skipped in benchmarks).
    /// 3. **Write the result back in-place** — no heap allocation; results
    ///    land in `self.select_buf` (reused) then moved to the connection list.
    ///
    /// Caller must ensure `connections[node_id][layer].len() == m_max + 1`.
    fn prune_connections_heuristic(&mut self, node_id: usize, layer: usize, m_max: usize) {
        // ── Step 1: load stored (id, dist_from_node) into prune_buf ──────
        self.prune_buf.clear();
        let conn_len = self.connections[node_id][layer].len();
        for i in 0..conn_len {
            let (nb_u32, dist) = self.connections[node_id][layer][i];
            self.prune_buf.push((nb_u32 as usize, dist));
        }
        // Sort closest-first by stored distance — no distance computation.
        self.prune_buf.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));

        // ── Step 2: heuristic selection into select_buf ───────────────────
        self.select_buf.clear();
        self.pruned_buf.clear();

        // Fast path: if we already have ≤ m_max (shouldn't happen here, but
        // guard for correctness), nothing to do.
        // Normal path: iterate sorted candidates, apply diversity criterion.
        for i in 0..self.prune_buf.len() {
            if self.select_buf.len() >= m_max { break; }
            let (e_id, e_dist) = self.prune_buf[i]; // Copy — no borrow held

            let mut accept = true;
            for j in 0..self.select_buf.len() {
                let (s_id, s_dist_node) = self.select_buf[j]; // Copy
                // Triangle-inequality shortcut:
                // d(node,s) > 2·d(node,e)  →  d(e,s) ≥ d(node,s)−d(node,e) > d(node,e)
                // so the accept condition d(node,e) ≤ d(e,s) is guaranteed.
                if s_dist_node > 2.0 * e_dist { continue; }
                let d_es = self.metric.distance(
                    self.vec_store.get(e_id),
                    self.vec_store.get(s_id),
                );
                if d_es <= e_dist { accept = false; break; }
            }

            if accept {
                self.select_buf.push((e_id, e_dist));
            } else if self.config.keep_pruned {
                self.pruned_buf.push((e_id, e_dist));
            }
        }

        if self.config.keep_pruned {
            let needed = m_max.saturating_sub(self.select_buf.len());
            let add    = needed.min(self.pruned_buf.len());
            for i in 0..add {
                let (id, dist) = self.pruned_buf[i];
                self.select_buf.push((id, dist));
            }
        }

        // ── Step 3: write result back to the connection list ──────────────
        self.connections[node_id][layer].clear();
        for i in 0..self.select_buf.len() {
            let (id, dist) = self.select_buf[i];
            self.connections[node_id][layer].push((id as u32, dist));
        }
    }

    // ─── Stats / debug ────────────────────────────────────────────────────

    /// Return a human-readable summary of the index structure.
    pub fn stats(&self) -> IndexStats {
        let max_level = self.entry_point.map(|(_, l)| l).unwrap_or(0);
        let mut layer_counts = vec![0usize; max_level + 1];
        let mut layer_edges  = vec![0usize; max_level + 1];
        for node_conn in &self.connections {
            for (l, conn) in node_conn.iter().enumerate() {
                layer_counts[l] += 1;
                layer_edges[l]  += conn.len();
            }
        }
        IndexStats { num_vectors: self.vec_store.len(), max_level, layer_counts, layer_edges }
    }
}

/// Summary statistics about an [`Hnsw`] index.
#[derive(Debug)]
pub struct IndexStats {
    pub num_vectors:  usize,
    pub max_level:    usize,
    pub layer_counts: Vec<usize>,
    pub layer_edges:  Vec<usize>,
}

impl std::fmt::Display for IndexStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "HNSW index — {} vectors", self.num_vectors)?;
        writeln!(f, "  Max level : {}", self.max_level)?;
        for l in (0..=self.max_level).rev() {
            writeln!(f, "  Layer {:>3} : {:>6} nodes, {:>7} directed edges",
                     l, self.layer_counts[l], self.layer_edges[l])?;
        }
        Ok(())
    }
}