Skip to main content

fast_hnsw/
lib.rs

1//! # hnsw
2//!
3//! A pure-Rust implementation of **Hierarchical Navigable Small World** (HNSW)
4//! approximate nearest-neighbour search, following the algorithm from:
5//!
6//! > Malkov & Yashunin, *"Efficient and robust approximate nearest neighbor
7//! > search using Hierarchical Navigable Small World graphs"*,
8//! > IEEE TPAMI 2018.
9//!
10//! ## Quick start
11//!
12//! ```rust
//! use fast_hnsw::{Builder, Hnsw, SearchResult};
//! use fast_hnsw::distance::Euclidean;
15//!
16//! // Build an index.
17//! let mut index: Hnsw<Euclidean> = Builder::new()
18//!     .m(16)
19//!     .ef_construction(200)
20//!     .seed(42)
21//!     .build(Euclidean);
22//!
23//! // Insert vectors.
24//! for i in 0..100_u32 {
25//!     index.insert(vec![i as f32, (i * i) as f32]);
26//! }
27//!
28//! // Query: find 5 nearest neighbours with ef=50.
29//! let results: Vec<SearchResult> = index.search(&[10.0, 101.0], 5, 50);
30//! assert_eq!(results[0].id, 10);
31//! ```
32//!
33//! ## Distance metrics
34//!
35//! | Type                        | Description                         |
36//! |-----------------------------|-------------------------------------|
37//! | [`distance::Euclidean`]     | True L2 distance                    |
38//! | [`distance::SquaredEuclidean`] | L2² (faster, same NN order)      |
39//! | [`distance::Cosine`]        | 1 − cosine similarity               |
40//! | [`distance::DotProduct`]    | 1 − dot product                     |
41//! | [`distance::Manhattan`]     | L1 / taxicab distance               |
42//!
43//! Custom metrics are easy to add by implementing the [`distance::Distance`] trait.
44//!
45//! ## Feature flags
46//! *(none yet — this crate is dependency-light by design)*
47
48pub mod builder;
49pub mod distance;
50pub(crate) mod heap;
51pub mod hnsw;
52pub mod labeled;
53pub mod paired;
54pub mod payload;
55pub mod persist;
56
57pub use builder::Builder;
58pub use hnsw::{Config, Hnsw, IndexStats, PruneStrategy, SearchResult};
59pub use labeled::LabeledIndex;
60pub use paired::PairedIndex;
61
62// ─── Tests ────────────────────────────────────────────────────────────────────
63
64#[cfg(test)]
65mod tests {
66    use super::*;
67    use distance::{Cosine, Euclidean, Manhattan, SquaredEuclidean};
68    use labeled::LabeledIndex;
69    use paired::PairedIndex;
70    use crate::persist;
71
72    // ── helpers ──────────────────────────────────────────────────────────
73
74    fn build_index(n: usize, dim: usize, seed: u64) -> Hnsw<Euclidean> {
75        use rand::{Rng, SeedableRng};
76        let mut rng = rand::rngs::SmallRng::seed_from_u64(seed + 1_000);
77        let mut index = Builder::new()
78            .m(16)
79            .ef_construction(200)
80            .seed(seed)
81            .build(Euclidean);
82        for _ in 0..n {
83            let v: Vec<f32> = (0..dim).map(|_| rng.gen::<f32>()).collect();
84            index.insert(v);
85        }
86        index
87    }
88
89    /// Brute-force exact k-NN.
90    fn exact_knn(vectors: &[Vec<f32>], query: &[f32], k: usize) -> Vec<usize> {
91        let mut dists: Vec<(f32, usize)> = vectors
92            .iter()
93            .enumerate()
94            .map(|(i, v)| {
95                let d: f32 = v.iter().zip(query).map(|(a, b)| (a - b) * (a - b)).sum::<f32>().sqrt();
96                (d, i)
97            })
98            .collect();
99        dists.sort_by(|a, b| a.0.total_cmp(&b.0));
100        dists.iter().take(k).map(|(_, i)| *i).collect()
101    }
102
103    // ── unit tests ────────────────────────────────────────────────────────
104
105    #[test]
106    fn empty_index_returns_nothing() {
107        let index: Hnsw<Euclidean> = Builder::new().build(Euclidean);
108        assert!(index.search(&[1.0, 2.0], 5, 20).is_empty());
109        assert!(index.is_empty());
110        assert_eq!(index.len(), 0);
111    }
112
113    #[test]
114    fn single_vector_always_returned() {
115        let mut index = Builder::new().seed(0).build(Euclidean);
116        index.insert(vec![1.0, 2.0, 3.0]);
117        let res = index.search(&[0.0, 0.0, 0.0], 1, 10);
118        assert_eq!(res.len(), 1);
119        assert_eq!(res[0].id, 0);
120    }
121
122    #[test]
123    fn ids_are_assigned_sequentially() {
124        let mut index = Builder::new().seed(1).build(Euclidean);
125        for i in 0..20 {
126            let id = index.insert(vec![i as f32]);
127            assert_eq!(id, i);
128        }
129        assert_eq!(index.len(), 20);
130    }
131
132    #[test]
133    fn nearest_of_two_is_correct() {
134        let mut index = Builder::new().seed(2).build(Euclidean);
135        index.insert(vec![0.0, 0.0]); // id=0
136        index.insert(vec![10.0, 0.0]); // id=1
137        // Query very close to id=0
138        let res = index.search(&[0.1, 0.0], 1, 10);
139        assert_eq!(res[0].id, 0);
140        // Query very close to id=1
141        let res = index.search(&[9.9, 0.0], 1, 10);
142        assert_eq!(res[0].id, 1);
143    }
144
145    #[test]
146    fn distances_are_non_negative_and_ordered() {
147        let index = build_index(200, 16, 3);
148        let query: Vec<f32> = vec![0.5; 16];
149        let results = index.search(&query, 10, 50);
150        assert_eq!(results.len(), 10);
151        for w in results.windows(2) {
152            assert!(w[0].distance >= 0.0);
153            assert!(w[0].distance <= w[1].distance);
154        }
155    }
156
157    #[test]
158    fn k_larger_than_index_returns_all() {
159        let index = build_index(30, 4, 4);
160        let query = vec![0.5f32; 4];
161        let res = index.search(&query, 100, 200);
162        assert_eq!(res.len(), 30);
163    }
164
165    #[test]
166    fn stored_vectors_are_retrievable() {
167        let mut index = Builder::new().seed(5).build(Euclidean);
168        let vecs = vec![vec![1.0f32, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
169        for v in &vecs {
170            index.insert(v.clone());
171        }
172        for (i, v) in vecs.iter().enumerate() {
173            assert_eq!(index.get_vector(i), v.as_slice());
174        }
175    }
176
177    #[test]
178    fn dim_is_tracked() {
179        let mut index = Builder::new().seed(6).build(Euclidean);
180        assert_eq!(index.dim(), None);
181        index.insert(vec![1.0, 2.0, 3.0]);
182        assert_eq!(index.dim(), Some(3));
183    }
184
185    #[test]
186    #[should_panic(expected = "expected 3")]
187    fn wrong_dimension_panics() {
188        let mut index = Builder::new().seed(7).build(Euclidean);
189        index.insert(vec![1.0, 2.0, 3.0]);
190        index.insert(vec![1.0, 2.0]); // wrong dim → panic
191    }
192
193    // ── recall tests ──────────────────────────────────────────────────────
194
195    fn recall(index: &Hnsw<Euclidean>, vectors: &[Vec<f32>], k: usize, ef: usize, n_queries: usize) -> f64 {
196        use rand::{Rng, SeedableRng};
197        let mut rng = rand::rngs::SmallRng::seed_from_u64(99_999);
198        let dim = vectors[0].len();
199
200        let mut hits = 0usize;
201        let mut total = 0usize;
202
203        for _ in 0..n_queries {
204            let query: Vec<f32> = (0..dim).map(|_| rng.gen::<f32>()).collect();
205            let exact = exact_knn(vectors, &query, k);
206            let approx: Vec<usize> = index.search(&query, k, ef).iter().map(|r| r.id).collect();
207            let exact_set: std::collections::HashSet<usize> = exact.into_iter().collect();
208            for id in &approx {
209                if exact_set.contains(id) {
210                    hits += 1;
211                }
212            }
213            total += k;
214        }
215
216        hits as f64 / total as f64
217    }
218
219    #[test]
220    fn recall_128d_is_acceptable() {
221        use rand::{Rng, SeedableRng};
222        let mut rng = rand::rngs::SmallRng::seed_from_u64(77);
223        let dim = 128;
224        let n = 1_000;
225
226        let mut vectors: Vec<Vec<f32>> = Vec::with_capacity(n);
227        let mut index = Builder::new()
228            .m(16)
229            .ef_construction(200)
230            .seed(42)
231            .build(Euclidean);
232
233        for _ in 0..n {
234            let v: Vec<f32> = (0..dim).map(|_| rng.gen::<f32>()).collect();
235            index.insert(v.clone());
236            vectors.push(v);
237        }
238
239        let r = recall(&index, &vectors, 10, 100, 100);
240        println!("Recall@10 (128d, 1k vectors, ef=100): {:.2}%", r * 100.0);
241        // Expect ≥ 90 % recall with these parameters.
242        assert!(r >= 0.90, "recall {:.2}% is too low", r * 100.0);
243    }
244
245    #[test]
246    fn recall_32d_high_ef_is_near_perfect() {
247        use rand::{Rng, SeedableRng};
248        let mut rng = rand::rngs::SmallRng::seed_from_u64(55);
249        let dim = 32;
250        let n = 500;
251
252        let mut vectors: Vec<Vec<f32>> = Vec::with_capacity(n);
253        let mut index = Builder::new()
254            .m(32)
255            .ef_construction(400)
256            .seed(13)
257            .build(Euclidean);
258
259        for _ in 0..n {
260            let v: Vec<f32> = (0..dim).map(|_| rng.gen::<f32>()).collect();
261            index.insert(v.clone());
262            vectors.push(v);
263        }
264
265        let r = recall(&index, &vectors, 10, 500, 50);
266        println!("Recall@10 (32d, 500 vectors, ef=500): {:.2}%", r * 100.0);
267        assert!(r >= 0.98, "recall {:.2}% is too low", r * 100.0);
268    }
269
270    // ── distance metric tests ─────────────────────────────────────────────
271
272    #[test]
273    fn squared_euclidean_finds_correct_neighbour() {
274        let mut index = Builder::new().seed(10).build(SquaredEuclidean);
275        index.insert(vec![0.0, 0.0]); // id=0
276        index.insert(vec![1.0, 0.0]); // id=1
277        index.insert(vec![5.0, 0.0]); // id=2
278        let res = index.search(&[0.2, 0.0], 1, 10);
279        assert_eq!(res[0].id, 0);
280    }
281
282    #[test]
283    fn cosine_distance_orthogonal_vectors() {
284        let mut index = Builder::new().seed(11).build(Cosine);
285        index.insert(vec![1.0, 0.0]); // id=0
286        index.insert(vec![0.0, 1.0]); // id=1  orthogonal
287        index.insert(vec![0.9, 0.1]); // id=2  close to id=0
288        let res = index.search(&[1.0, 0.0], 1, 10);
289        assert_eq!(res[0].id, 0);
290    }
291
292    #[test]
293    fn manhattan_metric_correct_order() {
294        let mut index = Builder::new().seed(12).build(Manhattan);
295        index.insert(vec![0.0]);  // id=0, dist=1.0 from query 1.0
296        index.insert(vec![10.0]); // id=1, dist=9.0 from query 1.0
297        index.insert(vec![1.5]);  // id=2, dist=0.5 from query 1.0
298        let res = index.search(&[1.0], 1, 10);
299        assert_eq!(res[0].id, 2);
300    }
301
302    // ── edge cases ────────────────────────────────────────────────────────
303
304    #[test]
305    fn two_identical_vectors() {
306        let mut index = Builder::new().seed(20).build(Euclidean);
307        index.insert(vec![1.0, 1.0]); // id=0
308        index.insert(vec![1.0, 1.0]); // id=1  duplicate
309        let res = index.search(&[1.0, 1.0], 2, 10);
310        assert_eq!(res.len(), 2);
311        assert_eq!(res[0].distance, 0.0);
312        assert_eq!(res[1].distance, 0.0);
313    }
314
315    #[test]
316    fn one_dimensional_vectors() {
317        let mut index = Builder::new().seed(21).build(Euclidean);
318        for i in 0..50_u32 {
319            index.insert(vec![i as f32]);
320        }
321        let res = index.search(&[25.0], 3, 30);
322        let ids: Vec<usize> = res.iter().map(|r| r.id).collect();
323        assert!(ids.contains(&25));
324    }
325
326    #[test]
327    fn large_dimension_does_not_panic() {
328        let mut index = Builder::new().m(8).ef_construction(50).seed(22).build(Euclidean);
329        let dim: usize = 1024;
330        for i in 0..50_u32 {
331            let v: Vec<f32> = (0..dim).map(|j| (i as usize + j) as f32).collect();
332            index.insert(v);
333        }
334        let query: Vec<f32> = vec![1.0; dim];
335        let res = index.search(&query, 5, 20);
336        assert_eq!(res.len(), 5);
337    }
338
339    #[test]
340    fn simple_neighbour_selection_fallback() {
341        let mut index = Builder::new()
342            .m(16)
343            .ef_construction(100)
344            .heuristic(false) // use simple selection
345            .seed(30)
346            .build(Euclidean);
347        for i in 0..100_u32 {
348            index.insert(vec![i as f32, 0.0]);
349        }
350        let res = index.search(&[50.0, 0.0], 3, 30);
351        // Should include 50
352        assert!(res.iter().any(|r| r.id == 50));
353    }
354
355    // ── stats ─────────────────────────────────────────────────────────────
356
357    #[test]
358    fn stats_are_consistent() {
359        let index = build_index(500, 32, 50);
360        let stats = index.stats();
361        assert_eq!(stats.num_vectors, 500);
362        // Layer 0 must contain all nodes.
363        assert_eq!(stats.layer_counts[0], 500);
364        // Edge count must be even (undirected).
365        assert_eq!(stats.layer_edges[0] % 2, 0);
366        println!("{}", stats);
367    }
368
369    // ── Persistence tests ─────────────────────────────────────────────────
370
371    fn make_hnsw(n: usize, dim: usize, seed: u64) -> (Hnsw<Euclidean>, Vec<Vec<f32>>) {
372        use rand::{Rng, SeedableRng};
373        let mut rng = rand::rngs::SmallRng::seed_from_u64(seed + 5_000);
374        let mut index = Builder::new().m(16).ef_construction(200).seed(seed).build(Euclidean);
375        let mut corpus = Vec::with_capacity(n);
376        for _ in 0..n {
377            let v: Vec<f32> = (0..dim).map(|_| rng.gen::<f32>()).collect();
378            index.insert(v.clone());
379            corpus.push(v);
380        }
381        (index, corpus)
382    }
383
384    #[test]
385    fn persist_save_load_round_trip() {
386        let (orig, _) = make_hnsw(200, 16, 300);
387        let dir = tempdir();
388        let path = dir.join("test.hnsw");
389        persist::save(&orig, &path).expect("save failed");
390
391        let loaded = persist::load(&path, Euclidean).expect("load failed");
392        assert_eq!(orig.len(), loaded.len());
393        assert_eq!(orig.dim(), loaded.dim());
394        // Vectors must be identical.
395        for i in 0..orig.len() {
396            assert_eq!(orig.get_vector(i), loaded.get_vector(i),
397                       "vector {i} differs after load");
398        }
399        // Search results must be identical (same graph topology).
400        let q = vec![0.5f32; 16];
401        let r_orig   = orig.search(&q, 5, 50);
402        let r_loaded = loaded.search(&q, 5, 50);
403        assert_eq!(r_orig.len(), r_loaded.len());
404        for (a, b) in r_orig.iter().zip(r_loaded.iter()) {
405            assert_eq!(a.id, b.id, "search result id differs");
406            assert!((a.distance - b.distance).abs() < 1e-6,
407                    "distance differs: {} vs {}", a.distance, b.distance);
408        }
409    }
410
411    #[test]
412    fn persist_mmap_load_round_trip() {
413        let (orig, _) = make_hnsw(200, 16, 301);
414        let dir  = tempdir();
415        let path = dir.join("mmap_test.hnsw");
416        persist::save(&orig, &path).expect("save failed");
417
418        let mmap = persist::load_mmap(&path, Euclidean).expect("mmap load failed");
419        assert_eq!(orig.len(), mmap.len());
420        for i in 0..orig.len() {
421            assert_eq!(orig.get_vector(i), mmap.get_vector(i),
422                       "mmap vector {i} differs");
423        }
424        let q = vec![0.3f32; 16];
425        let r_orig = orig.search(&q, 5, 50);
426        let r_mmap = mmap.search(&q, 5, 50);
427        for (a, b) in r_orig.iter().zip(r_mmap.iter()) {
428            assert_eq!(a.id, b.id);
429        }
430    }
431
432    #[test]
433    fn persist_empty_index() {
434        let empty: Hnsw<Euclidean> = Builder::new().build(Euclidean);
435        let dir  = tempdir();
436        let path = dir.join("empty.hnsw");
437        persist::save(&empty, &path).expect("save empty failed");
438        let loaded = persist::load(&path, Euclidean).expect("load empty failed");
439        assert_eq!(loaded.len(), 0);
440        assert!(loaded.search(&[0.0, 1.0], 5, 10).is_empty());
441    }
442
443    // ── LabeledIndex tests ────────────────────────────────────────────────
444
445    #[test]
446    fn labeled_insert_and_search_u32() {
447        let mut idx: LabeledIndex<Euclidean, u32> =
448            Builder::new().seed(400).build_labeled(Euclidean);
449        idx.insert(vec![0.0, 0.0], 10_u32);
450        idx.insert(vec![1.0, 0.0], 20_u32);
451        idx.insert(vec![0.0, 1.0], 30_u32);
452
453        let hits = idx.search(&[0.1, 0.0], 1, 20);
454        assert_eq!(hits[0].payload, &10_u32);
455        assert_eq!(hits[0].id, 0);
456    }
457
458    #[test]
459    fn labeled_insert_and_search_string() {
460        let mut idx: LabeledIndex<Euclidean, String> =
461            Builder::new().seed(401).build_labeled(Euclidean);
462        idx.insert(vec![1.0, 0.0], "cat".to_string());
463        idx.insert(vec![0.0, 1.0], "dog".to_string());
464        idx.insert(vec![0.5, 0.5], "rabbit".to_string());
465
466        let hits = idx.search(&[0.9, 0.1], 1, 20);
467        assert_eq!(hits[0].payload, "cat");
468        assert_eq!(hits[0].embedding, &[1.0f32, 0.0]);
469    }
470
471    #[test]
472    fn labeled_search_returns_embedding() {
473        let mut idx: LabeledIndex<Euclidean, ()> =
474            Builder::new().seed(402).build_labeled(Euclidean);
475        let v = vec![3.0f32, 4.0];
476        idx.insert(v.clone(), ());
477        let hits = idx.search(&[3.0, 4.0], 1, 10);
478        assert_eq!(hits[0].embedding, v.as_slice());
479    }
480
481    #[test]
482    fn labeled_save_load_u32() {
483        let mut idx: LabeledIndex<Euclidean, u32> =
484            Builder::new().seed(410).build_labeled(Euclidean);
485        for i in 0..50_u32 {
486            idx.insert(vec![i as f32, (i * 2) as f32], i * 10);
487        }
488        let dir  = tempdir();
489        let path = dir.join("labeled_u32.hnsw");
490        idx.save(&path).expect("save failed");
491
492        let loaded = LabeledIndex::<Euclidean, u32>::load(&path, Euclidean)
493            .expect("load failed");
494        assert_eq!(loaded.len(), 50);
495        for i in 0..50_usize {
496            assert_eq!(loaded.get_payload(i), &(i as u32 * 10));
497            assert_eq!(loaded.get_embedding(i), &[i as f32, (i * 2) as f32]);
498        }
499        let hits = loaded.search(&[25.0, 50.0], 1, 30);
500        assert_eq!(hits[0].id, 25);
501        assert_eq!(hits[0].payload, &250_u32);
502    }
503
504    #[test]
505    fn labeled_save_load_string() {
506        let labels = ["alpha", "beta", "gamma", "delta", "epsilon"];
507        let mut idx: LabeledIndex<Euclidean, String> =
508            Builder::new().seed(411).build_labeled(Euclidean);
509        for (i, &s) in labels.iter().enumerate() {
510            idx.insert(vec![i as f32], s.to_string());
511        }
512        let dir  = tempdir();
513        let path = dir.join("labeled_str.hnsw");
514        idx.save(&path).expect("save failed");
515
516        let loaded = LabeledIndex::<Euclidean, String>::load(&path, Euclidean)
517            .expect("load failed");
518        for (i, &s) in labels.iter().enumerate() {
519            assert_eq!(loaded.get_payload(i), s);
520        }
521    }
522
523    #[test]
524    fn labeled_save_load_vec_f32_payload() {
525        // Payload is a secondary embedding (variable-width)
526        let mut idx: LabeledIndex<Euclidean, Vec<f32>> =
527            Builder::new().seed(412).build_labeled(Euclidean);
528        let primary = vec![1.0f32, 0.0];
529        let secondary = vec![0.0f32, 0.0, 1.0]; // different dim
530        idx.insert(primary.clone(), secondary.clone());
531        let dir  = tempdir();
532        let path = dir.join("labeled_vecf32.hnsw");
533        idx.save(&path).expect("save failed");
534
535        let loaded = LabeledIndex::<Euclidean, Vec<f32>>::load(&path, Euclidean)
536            .expect("load failed");
537        assert_eq!(loaded.get_payload(0), &secondary);
538    }
539
540    #[test]
541    fn labeled_mmap_load() {
542        let mut idx: LabeledIndex<Euclidean, u32> =
543            Builder::new().seed(420).build_labeled(Euclidean);
544        for i in 0..30_u32 {
545            idx.insert(vec![i as f32], i);
546        }
547        let dir  = tempdir();
548        let path = dir.join("labeled_mmap.hnsw");
549        idx.save(&path).expect("save failed");
550
551        let mmap = LabeledIndex::<Euclidean, u32>::load_mmap(&path, Euclidean)
552            .expect("mmap load failed");
553        assert_eq!(mmap.len(), 30);
554        for i in 0..30_usize {
555            assert_eq!(mmap.get_payload(i), &(i as u32));
556        }
557    }
558
559    // ── PairedIndex tests ─────────────────────────────────────────────────
560
561    #[test]
562    fn paired_insert_and_search_both_sides() {
563        let mut idx: PairedIndex<Euclidean, Euclidean> = Builder::new()
564            .m(16).ef_construction(50).seed(500)
565            .build_paired(Euclidean, Euclidean);
566
567        // Three items: each has a 2-D A-embedding and 3-D B-embedding.
568        idx.insert(vec![1.0, 0.0],       vec![0.9, 0.1, 0.0]);   // id=0
569        idx.insert(vec![0.0, 1.0],       vec![0.1, 0.8, 0.1]);   // id=1
570        idx.insert(vec![0.5, 0.5],       vec![0.3, 0.3, 0.4]);   // id=2
571
572        // Search A-space: query near item 0
573        let hits_a = idx.search_by_a(&[0.9, 0.1], 1, 20);
574        assert_eq!(hits_a[0].id, 0);
575        assert_eq!(hits_a[0].emb_b, &[0.9f32, 0.1, 0.0]);
576
577        // Search B-space: query near item 1
578        let hits_b = idx.search_by_b(&[0.1, 0.9, 0.0], 1, 20);
579        assert_eq!(hits_b[0].id, 1);
580        assert_eq!(hits_b[0].emb_a, &[0.0f32, 1.0]);
581    }
582
583    #[test]
584    fn paired_len_consistent() {
585        let mut idx: PairedIndex<Euclidean, Euclidean> =
586            PairedIndex::new(Default::default(), Euclidean, Default::default(), Euclidean);
587        assert_eq!(idx.len(), 0);
588        for i in 0..10_u32 {
589            idx.insert(vec![i as f32], vec![i as f32, i as f32]);
590            assert_eq!(idx.len(), i as usize + 1);
591        }
592    }
593
594    #[test]
595    fn paired_cross_side_retrieval() {
596        let mut idx: PairedIndex<Euclidean, Euclidean> = Builder::new()
597            .m(16).ef_construction(100).seed(501)
598            .build_paired(Euclidean, Euclidean);
599        // 20 items
600        for i in 0..20_u32 {
601            idx.insert(vec![i as f32, 0.0], vec![0.0, i as f32]);
602        }
603        // Search by A near item 10 → get B embedding of item 10
604        let hits = idx.search_by_a(&[10.0, 0.0], 1, 30);
605        assert_eq!(hits[0].id, 10);
606        assert_eq!(hits[0].emb_b, &[0.0f32, 10.0]);
607        // Confirm: searching by B near item 10 → get A embedding of item 10
608        let hits2 = idx.search_by_b(&[0.0, 10.0], 1, 30);
609        assert_eq!(hits2[0].id, 10);
610        assert_eq!(hits2[0].emb_a, &[10.0f32, 0.0]);
611    }
612
613    #[test]
614    fn paired_save_load() {
615        let mut idx: PairedIndex<Euclidean, Euclidean> = Builder::new()
616            .m(16).ef_construction(100).seed(510)
617            .build_paired(Euclidean, Euclidean);
618        for i in 0..50_u32 {
619            idx.insert(vec![i as f32], vec![i as f32, i as f32]);
620        }
621        let dir = tempdir();
622        let base = dir.join("paired");
623        idx.save(&base).expect("save failed");
624
625        let loaded = PairedIndex::<Euclidean, Euclidean>::load(&base, Euclidean, Euclidean)
626            .expect("load failed");
627        assert_eq!(loaded.len(), 50);
628        for i in 0..50_usize {
629            assert_eq!(loaded.get_emb_a(i), &[i as f32][..]);
630            assert_eq!(loaded.get_emb_b(i), &[i as f32, i as f32][..]);
631        }
632        let hits = loaded.search_by_a(&[25.0], 1, 30);
633        assert_eq!(hits[0].id, 25);
634    }
635
636    #[test]
637    fn paired_mmap_load() {
638        let mut idx: PairedIndex<Euclidean, Euclidean> = Builder::new()
639            .seed(520).build_paired(Euclidean, Euclidean);
640        for i in 0..30_u32 {
641            idx.insert(vec![i as f32, 0.0], vec![0.0, i as f32, 1.0]);
642        }
643        let dir  = tempdir();
644        let base = dir.join("paired_mmap");
645        idx.save(&base).expect("save failed");
646
647        let m = PairedIndex::<Euclidean, Euclidean>::load_mmap(&base, Euclidean, Euclidean)
648            .expect("mmap load failed");
649        assert_eq!(m.len(), 30);
650        // Spot-check a few vectors
651        for i in [0, 15, 29] {
652            assert_eq!(m.get_emb_a(i), &[i as f32, 0.0f32][..]);
653            assert_eq!(m.get_emb_b(i), &[0.0f32, i as f32, 1.0][..]);
654        }
655    }
656
657    // Helper: create a temp directory that lives for the duration of the test.
658    fn tempdir() -> std::path::PathBuf {
659        use std::time::{SystemTime, UNIX_EPOCH};
660        let ts = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().subsec_nanos();
661        let dir = std::env::temp_dir().join(format!("hnsw_test_{ts}"));
662        std::fs::create_dir_all(&dir).unwrap();
663        dir
664    }
665
666    // ── PruneStrategy tests ───────────────────────────────────────────────
667
668    /// Helper: build an index with a given prune strategy and return
669    /// (index, corpus) so callers can run recall checks.
670    fn build_with_prune(n: usize, dim: usize, seed: u64, ps: PruneStrategy)
671        -> (Hnsw<Euclidean>, Vec<Vec<f32>>)
672    {
673        use rand::{Rng, SeedableRng};
674        let mut rng = rand::rngs::SmallRng::seed_from_u64(seed + 2_000);
675        let mut index = Builder::new()
676            .m(16)
677            .ef_construction(200)
678            .prune_strategy(ps)
679            .seed(seed)
680            .build(Euclidean);
681        let mut corpus = Vec::with_capacity(n);
682        for _ in 0..n {
683            let v: Vec<f32> = (0..dim).map(|_| rng.gen::<f32>()).collect();
684            index.insert(v.clone());
685            corpus.push(v);
686        }
687        (index, corpus)
688    }
689
690    #[test]
691    fn prune_strategy_default_is_simple() {
692        // Ensure Config::default() picks Simple so users get the fastest
693        // behaviour out of the box without any builder call.
694        assert_eq!(Config::default().prune_strategy, PruneStrategy::Simple);
695        // Building via Builder without calling .prune_strategy() must also
696        // default to Simple.
697        let mut index = Builder::new().seed(0).build(Euclidean);
698        index.insert(vec![1.0, 2.0]);
699        // The index built successfully — no panics, correct result.
700        assert_eq!(index.search(&[1.0, 2.0], 1, 10)[0].id, 0);
701    }
702
703    #[test]
704    fn prune_strategy_simple_gives_acceptable_recall() {
705        let (index, corpus) = build_with_prune(500, 32, 101, PruneStrategy::Simple);
706        let r = recall(&index, &corpus, 10, 200, 50);
707        println!("Simple recall@10 (32d 500v ef=200): {:.2}%", r * 100.0);
708        assert!(r >= 0.95, "Simple recall {:.2}% too low", r * 100.0);
709    }
710
711    #[test]
712    fn prune_strategy_heuristic_gives_acceptable_recall() {
713        let (index, corpus) = build_with_prune(500, 32, 101, PruneStrategy::Heuristic);
714        let r = recall(&index, &corpus, 10, 200, 50);
715        println!("Heuristic recall@10 (32d 500v ef=200): {:.2}%", r * 100.0);
716        assert!(r >= 0.95, "Heuristic recall {:.2}% too low", r * 100.0);
717    }
718
719    #[test]
720    fn prune_strategy_heuristic_recall_ge_simple() {
721        // Heuristic must not be worse than Simple (it does strictly more work
722        // to preserve diversity).  Run both on the same data and seed.
723        let (idx_s, corpus) = build_with_prune(500, 128, 202, PruneStrategy::Simple);
724        let (idx_h, _)      = build_with_prune(500, 128, 202, PruneStrategy::Heuristic);
725        let r_s = recall(&idx_s, &corpus, 10, 100, 50);
726        let r_h = recall(&idx_h, &corpus, 10, 100, 50);
727        println!("Simple {:.2}%  Heuristic {:.2}%", r_s * 100.0, r_h * 100.0);
728        // Allow up to 1 pp slack for statistical noise in the random queries.
729        assert!(r_h + 0.01 >= r_s,
730            "Heuristic recall ({:.2}%) should be ≥ Simple ({:.2}%)",
731            r_h * 100.0, r_s * 100.0);
732    }
733
734    #[test]
735    fn max_level_grows_with_more_inserts() {
736        let index_small = build_index(10, 4, 60);
737        let index_large = build_index(10_000, 4, 60);
738        // With many more nodes the entry-point level is likely higher.
739        // This is probabilistic but almost certain with 10 000 vs 10 nodes.
740        let l_small = index_small.max_level().unwrap_or(0);
741        let l_large = index_large.max_level().unwrap_or(0);
742        println!("small max_level={l_small}, large max_level={l_large}");
743        assert!(l_large >= l_small);
744    }
745}