use super::super::distance::CachedSimdDistance;
use super::super::layer::NodeId;
use super::super::ordered_float::OrderedFloat;
use super::search_pools::{BitVecVisited, CANDIDATE_HEAP_POOL, POOL_MAX, RESULT_HEAP_POOL};
use super::search_state::{gather_unvisited_neighbors, process_batch_results, SearchState};
use super::{NativeHnsw, NO_ENTRY_POINT};
use crate::distance::DistanceMetric;
use rustc_hash::FxHashSet;
use smallvec::SmallVec;
use std::cmp::Reverse;
/// Cosine distance `1 - cos(theta)` between `a` and `b`.
///
/// Returns `1.0` when either input has zero magnitude. The dot product pairs
/// elements with `zip`, so trailing elements of the longer slice only affect
/// that slice's own magnitude.
// Reference implementation kept for ad-hoc test comparisons; currently unused.
#[allow(dead_code)]
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    let magnitude = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();
    let (len_a, len_b) = (magnitude(a), magnitude(b));
    if len_a == 0.0 || len_b == 0.0 {
        return 1.0;
    }
    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    1.0 - dot / (len_a * len_b)
}
/// After three pushes, the candidate min-heap must surface the closest node,
/// the result max-heap must surface the furthest node, and every pushed node
/// must be marked visited.
#[test]
fn test_search_state_new_and_push() {
    let mut state = SearchState::new(0);
    let pushed = [(10, 0.5), (20, 0.1), (30, 0.9)];
    for &(node, dist) in &pushed {
        state.push_candidate(node, dist);
    }

    // Candidates are wrapped in `Reverse`, so the heap root is the nearest node.
    let Reverse((OrderedFloat(top_dist), top_node)) = *state
        .candidates
        .peek()
        .expect("candidates should not be empty");
    assert_eq!(top_node, 20, "min-heap should surface closest candidate");
    assert!((top_dist - 0.1).abs() < f32::EPSILON);

    // Results are an unwrapped max-heap, so the root is the furthest kept node.
    let (OrderedFloat(furthest_dist), furthest_node) =
        *state.results.peek().expect("results should not be empty");
    assert_eq!(furthest_node, 30, "max-heap should surface furthest result");
    assert!((furthest_dist - 0.9).abs() < f32::EPSILON);

    for &(node, _) in &pushed {
        assert!(state.visited.contains(node));
    }
}
/// `should_terminate` must fire only when the result set already holds `ef`
/// entries and the next candidate is further than the worst kept result.
#[test]
fn test_search_state_should_terminate() {
    let ef = 3;
    let stagnation_limit = 10;

    let mut full = SearchState::new(0);
    for (node, dist) in [(1, 0.1), (2, 0.3), (3, 0.5)] {
        full.push_candidate(node, dist);
    }
    assert!(
        full.should_terminate(0.6, ef, stagnation_limit),
        "should terminate: c_dist > furthest and results full"
    );
    assert!(
        !full.should_terminate(0.4, ef, stagnation_limit),
        "should not terminate: c_dist < furthest"
    );

    // Only two of the `ef` result slots are filled, so the search must continue
    // even for a distant candidate.
    let mut partial = SearchState::new(0);
    for (node, dist) in [(1, 0.1), (2, 0.3)] {
        partial.push_candidate(node, dist);
    }
    assert!(
        !partial.should_terminate(0.6, ef, stagnation_limit),
        "should not terminate: results not yet full"
    );
}
/// Consecutive non-improving rounds must accumulate until `stagnation_limit`
/// forces termination. A single improving round must reset the counter.
#[test]
fn test_search_state_stagnation() {
    let ef = 2;
    let stagnation_limit = 3;
    let mut state = SearchState::new(0);
    state.push_candidate(1, 0.1);
    state.push_candidate(2, 0.3);
    assert_eq!(state.stagnation_count, 0);

    // Below the limit the search keeps running.
    for expected in [1, 2] {
        state.update_stagnation(false);
        assert_eq!(state.stagnation_count, expected);
        assert!(!state.should_terminate(0.0, ef, stagnation_limit));
    }

    // Reaching the limit stops the search, even though a candidate at 0.0 is
    // closer than every kept result.
    state.update_stagnation(false);
    assert_eq!(state.stagnation_count, 3);
    assert!(
        state.should_terminate(0.0, ef, stagnation_limit),
        "should terminate after reaching stagnation limit"
    );

    // An improving round clears the counter.
    state.update_stagnation(true);
    assert_eq!(state.stagnation_count, 0);
    assert!(!state.should_terminate(0.0, ef, stagnation_limit));
}
/// `into_sorted_results(None)` must return every kept result, ordered by
/// ascending distance.
#[test]
fn test_search_state_into_sorted_results() {
    let mut state = SearchState::new(0);
    let inserted = [(10, 0.7), (20, 0.2), (30, 0.5), (40, 0.1), (50, 0.9)];
    for (node, dist) in inserted {
        state.push_candidate(node, dist);
    }

    let sorted = state.into_sorted_results(None);
    assert_eq!(sorted.len(), 5);

    let order: Vec<_> = sorted.iter().map(|entry| entry.0).collect();
    assert_eq!(order, vec![40, 20, 30, 10, 50]);

    for pair in sorted.windows(2) {
        let (prev, next) = (pair[0].1, pair[1].1);
        assert!(
            prev <= next,
            "results must be sorted by distance ascending: {} <= {}",
            prev,
            next,
        );
    }
}
/// Neighbors already present in the visited set must be skipped by
/// `gather_unvisited_neighbors`; all remaining neighbors must be returned.
#[test]
fn test_gather_unvisited_neighbors_filters_visited() {
    let dim = 4;
    let mut vectors =
        crate::perf_optimizations::ContiguousVectors::new(dim, 10).expect("alloc should succeed");
    for row in 0..5_usize {
        // Row `r` holds the consecutive values r*dim .. (r+1)*dim.
        #[allow(clippy::cast_precision_loss)]
        let v: Vec<f32> = (row * dim..(row + 1) * dim).map(|x| x as f32).collect();
        vectors.push(&v).expect("push should succeed");
    }

    let neighbors: Vec<NodeId> = (0..5).collect();
    let mut visited = BitVecVisited::with_capacity(5);
    for seen in [1, 3] {
        visited.insert(seen);
    }

    let unvisited: SmallVec<[(NodeId, &[f32]); 32]> =
        gather_unvisited_neighbors(&neighbors, &mut visited, &vectors, false);
    let ids: Vec<NodeId> = unvisited.iter().map(|&(id, _)| id).collect();

    assert_eq!(ids.len(), 3, "should exclude 2 visited nodes");
    for expected in [0, 2, 4] {
        assert!(ids.contains(&expected));
    }
    assert!(!ids.contains(&1), "visited node 1 must be excluded");
    assert!(!ids.contains(&3), "visited node 3 must be excluded");
}
/// `gather_unvisited_neighbors` must record each neighbor it encounters in the
/// visited set passed to it.
#[test]
fn test_gather_unvisited_neighbors_marks_visited() {
    let dim = 4;
    let node_count = 3_usize;
    let mut vectors =
        crate::perf_optimizations::ContiguousVectors::new(dim, 10).expect("alloc should succeed");
    for row in 0..node_count {
        #[allow(clippy::cast_precision_loss)]
        let v: Vec<f32> = (0..dim).map(|col| (row * dim + col) as f32).collect();
        vectors.push(&v).expect("push should succeed");
    }

    let neighbors: Vec<NodeId> = (0..node_count).collect();
    let mut visited = BitVecVisited::with_capacity(node_count);
    let _unvisited = gather_unvisited_neighbors(&neighbors, &mut visited, &vectors, false);

    for id in 0..node_count {
        assert!(visited.contains(id), "node {id} should be marked visited");
    }
}
/// On an empty state, `process_batch_results` must admit every scored neighbor
/// into both heaps and report an improvement. Afterwards the candidate min-heap
/// must surface the nearest node and the result max-heap the furthest one.
#[test]
// Fixture values are small integers (< 12) converted to f32, so the cast is exact.
#[allow(clippy::cast_precision_loss)]
fn test_process_batch_results_updates_heaps() {
    let mut state = SearchState::new(0);
    let ef = 10;
    let dim = 4;
    let vecs: Vec<Vec<f32>> = (0..3)
        .map(|i| (0..dim).map(|j| (i * dim + j) as f32).collect())
        .collect();
    let batch: Vec<(NodeId, &[f32])> = vecs
        .iter()
        .enumerate()
        .map(|(i, v)| (i, v.as_slice()))
        .collect();
    let distances = vec![0.3_f32, 0.1, 0.5];
    let improved = process_batch_results(&batch, &distances, ef, &mut state);
    assert!(improved, "first batch should improve empty state");
    assert_eq!(state.candidates.len(), 3);
    assert_eq!(state.results.len(), 3);

    // Candidate heap is a min-heap via `Reverse`: node 1 (0.1) is nearest.
    let Reverse((OrderedFloat(min_dist), min_node)) = *state.candidates.peek().expect("non-empty");
    assert_eq!(min_node, 1);
    assert!((min_dist - 0.1).abs() < f32::EPSILON);

    // Result heap is a max-heap: node 2 (0.5) is the furthest kept result.
    let (OrderedFloat(max_dist), max_node) = *state.results.peek().expect("non-empty");
    assert_eq!(max_node, 2);
    assert!((max_dist - 0.5).abs() < f32::EPSILON);
}
/// Once the result set holds `ef` entries, a closer batch neighbor must evict
/// the current furthest result. A neighbor further than every surviving result
/// must not be admitted.
#[test]
fn test_process_batch_results_evicts_furthest_when_full() {
    let mut state = SearchState::new(0);
    let ef = 3;
    state.push_candidate(10, 0.2);
    state.push_candidate(20, 0.4);
    state.push_candidate(30, 0.6);
    let dim = 4;
    let v_close: Vec<f32> = vec![1.0; dim];
    let v_far: Vec<f32> = vec![2.0; dim];
    let batch: Vec<(NodeId, &[f32])> = vec![(40, v_close.as_slice()), (50, v_far.as_slice())];
    let distances = vec![0.3_f32, 0.8];
    let improved = process_batch_results(&batch, &distances, ef, &mut state);
    assert!(
        improved,
        "batch with closer candidate should improve results"
    );
    assert_eq!(state.results.len(), ef, "results must not exceed ef");
    let result_ids: Vec<NodeId> = state.results.iter().map(|(_, id)| *id).collect();
    assert!(
        !result_ids.contains(&30),
        "node 30 (dist 0.6) should have been evicted"
    );
    assert!(
        result_ids.contains(&40),
        "node 40 (dist 0.3) should have been admitted"
    );
    // Node 50 (0.8) is worse than every surviving result (max 0.4), so it must
    // stay out regardless of the order the batch is processed in.
    assert!(
        !result_ids.contains(&50),
        "node 50 (dist 0.8) must not be admitted"
    );
    // The two closest original results must survive the eviction.
    assert!(
        result_ids.contains(&10),
        "node 10 (dist 0.2) should have been kept"
    );
    assert!(
        result_ids.contains(&20),
        "node 20 (dist 0.4) should have been kept"
    );
}
/// End-to-end recall check. HNSW top-`k` answers for queries drawn from the
/// indexed set must, on average, overlap the brute-force ground truth by at
/// least 90%.
#[test]
#[allow(clippy::cast_precision_loss)] // index -> float conversions on small fixture sizes
fn test_refactored_search_recall_matches_original() {
    /// Ids of the `k` corpus vectors nearest to `query` by squared Euclidean
    /// distance, computed exhaustively. The stable sort keeps lower ids first on ties.
    fn exact_top_k(corpus: &[Vec<f32>], query: &[f32], k: usize) -> Vec<NodeId> {
        let mut scored: Vec<(NodeId, f32)> = corpus
            .iter()
            .enumerate()
            .map(|(id, v)| {
                let sq_dist: f32 = v.iter().zip(query).map(|(a, b)| (a - b) * (a - b)).sum();
                (id, sq_dist)
            })
            .collect();
        scored.sort_by(|x, y| x.1.total_cmp(&y.1));
        scored.into_iter().take(k).map(|(id, _)| id).collect()
    }

    let dim = 32;
    let n = 200;
    let k = 10;
    let ef_search = 64;
    let n_queries = 10;
    let hnsw = NativeHnsw::new(CachedSimdDistance::new(DistanceMetric::Euclidean, 32), 16, 100, n);

    // Deterministic, smoothly varying fixture vectors.
    let vectors: Vec<Vec<f32>> = (0..n)
        .map(|i| {
            (0..dim)
                .map(|j| ((i * dim + j) as f32 * 0.001).sin())
                .collect()
        })
        .collect();
    for v in &vectors {
        hnsw.insert(v).expect("insert should succeed in test");
    }

    let stride = n / n_queries;
    let total_recall: f64 = (0..n_queries)
        .map(|q_idx| {
            let query = &vectors[q_idx * stride];
            let found: Vec<NodeId> = hnsw
                .search(query, k, ef_search)
                .iter()
                .map(|(id, _)| *id)
                .collect();
            let truth = exact_top_k(&vectors, query, k);
            let hits = found.iter().filter(|id| truth.contains(id)).count();
            hits as f64 / k as f64
        })
        .sum();

    let avg_recall = total_recall / n_queries as f64;
    assert!(
        avg_recall >= 0.90,
        "search recall must be >= 90% (got {:.1}%); \
         if this fails after refactoring, the extraction broke correctness",
        avg_recall * 100.0,
    );
}
/// `into_sorted_results(Some(3))` must keep only the three nearest results,
/// in ascending distance order.
#[test]
#[allow(clippy::cast_precision_loss)]
fn test_into_sorted_results_with_limit() {
    let mut state = SearchState::new(0);
    // Node i sits at distance (10 - i) * 0.1, so node 9 is the nearest (0.1).
    for i in 0..10_usize {
        state.push_candidate(i, (10 - i) as f32 * 0.1);
    }
    assert_eq!(state.results.len(), 10);

    let sorted = state.into_sorted_results(Some(3));
    assert_eq!(sorted.len(), 3, "limit=3 should return exactly 3 results");

    for pair in sorted.windows(2) {
        assert!(
            pair[0].1 <= pair[1].1,
            "results must be sorted ascending: {} <= {}",
            pair[0].1,
            pair[1].1,
        );
    }

    let expected = [("first", 0.1), ("second", 0.2), ("third", 0.3)];
    for (entry, (ordinal, want)) in sorted.iter().zip(expected) {
        assert!(
            (entry.1 - want).abs() < f32::EPSILON,
            "{ordinal} result should be dist {want}, got {}",
            entry.1,
        );
    }
}
/// `into_sorted_results(None)` must return every stored result in ascending
/// distance order.
#[test]
#[allow(clippy::cast_precision_loss)]
fn test_into_sorted_results_without_limit() {
    let mut state = SearchState::new(0);
    for i in 0..10_usize {
        state.push_candidate(i, (10 - i) as f32 * 0.1);
    }

    let sorted = state.into_sorted_results(None);
    assert_eq!(sorted.len(), 10, "None limit should return all results");

    // Compare each entry with its successor.
    for (prev, next) in sorted.iter().zip(sorted.iter().skip(1)) {
        assert!(
            prev.1 <= next.1,
            "results must be sorted ascending: {} <= {}",
            prev.1,
            next.1,
        );
    }
}
/// A limit larger than the number of stored results must return all of them,
/// still in ascending distance order.
#[test]
#[allow(clippy::cast_precision_loss)]
fn test_into_sorted_results_limit_greater_than_results() {
    let mut state = SearchState::new(0);
    let stored = 5_usize;
    for i in 0..stored {
        state.push_candidate(i, (stored - i) as f32 * 0.1);
    }

    let sorted = state.into_sorted_results(Some(10));
    assert_eq!(
        sorted.len(),
        5,
        "limit > len should return all available results"
    );

    for (prev, next) in sorted.iter().zip(sorted.iter().skip(1)) {
        assert!(
            prev.1 <= next.1,
            "results must be sorted ascending: {} <= {}",
            prev.1,
            next.1,
        );
    }
}
/// A limit of zero must yield no results, even when the state holds some.
#[test]
fn test_into_sorted_results_limit_zero() {
    let mut state = SearchState::new(0);
    for (node, dist) in [(0, 0.5), (1, 0.1), (2, 0.9)] {
        state.push_candidate(node, dist);
    }
    assert!(
        state.into_sorted_results(Some(0)).is_empty(),
        "limit=0 should return empty vec"
    );
}
/// A state that never received candidates must yield no results, with or
/// without a limit.
#[test]
fn test_into_sorted_results_empty_state() {
    let cases = [
        (None, "empty state with None should return empty vec"),
        (Some(5), "empty state with Some(5) should return empty vec"),
    ];
    for (limit, message) in cases {
        let sorted = SearchState::new(0).into_sorted_results(limit);
        assert!(sorted.is_empty(), "{}", message);
    }
}
/// Inserting an id must make exactly that id visible to `contains`.
#[test]
fn test_bitvec_visited_insert_and_contains() {
    let mut visited = BitVecVisited::with_capacity(1000);
    let id = 42;
    assert!(!visited.contains(id));
    visited.insert(id);
    assert!(visited.contains(id));
    assert!(!visited.contains(id + 1));
}
/// `insert` must return `true` only the first time a given id is added.
#[test]
fn test_bitvec_visited_insert_returns_newly_inserted() {
    let mut visited = BitVecVisited::with_capacity(100);
    let first_ten = visited.insert(10);
    let repeat_ten = visited.insert(10);
    let first_eleven = visited.insert(11);
    assert!(first_ten);
    assert!(!repeat_ten);
    assert!(first_eleven);
}
/// `clear` must forget previously inserted ids.
#[test]
fn test_bitvec_visited_clear_resets() {
    let mut visited = BitVecVisited::with_capacity(100);
    let id = 50;
    visited.insert(id);
    assert!(visited.contains(id));
    visited.clear();
    assert!(!visited.contains(id));
}
/// `clear` must reset bits without changing the length of the backing word
/// buffer.
#[test]
fn test_bitvec_visited_clear_preserves_capacity() {
    let mut visited = BitVecVisited::with_capacity(1000);
    let initial_len = visited.words.len();
    // Touch the highest id below the requested capacity before clearing.
    visited.insert(999);
    visited.clear();
    assert_eq!(visited.words.len(), initial_len);
}
/// Inserting an id well past the initial capacity must still record it, and
/// must not mark the id just below it.
#[test]
fn test_bitvec_visited_out_of_bounds_grows() {
    let mut visited = BitVecVisited::with_capacity(10);
    let far_id = 100;
    visited.insert(far_id);
    assert!(visited.contains(far_id));
    assert!(!visited.contains(far_id - 1));
}
/// A set created with zero capacity must still accept and report id 0.
#[test]
fn test_bitvec_visited_zero_capacity() {
    let mut visited = BitVecVisited::with_capacity(0);
    let first = 0;
    assert!(!visited.contains(first));
    visited.insert(first);
    assert!(visited.contains(first));
}
/// Ids on either side of the 63/64 and 127 boundaries must be tracked
/// independently (assumes 64-bit storage words), and unmarked ids must stay clear.
#[test]
fn test_bitvec_visited_word_boundary() {
    let mut visited = BitVecVisited::with_capacity(128);
    let marked = [0, 1, 62, 63, 64, 65, 126, 127];
    let unmarked = [2, 32, 66, 100];

    for &id in &marked {
        visited.insert(id);
    }
    for &id in &marked {
        assert!(visited.contains(id), "should contain {id}");
    }
    for &id in &unmarked {
        assert!(!visited.contains(id), "should not contain {id}");
    }
}
/// `BitVecVisited` must agree with an `FxHashSet` reference model on every
/// `insert` return value and on `contains` across ids 0..10_000.
#[test]
fn test_bitvec_visited_identical_to_hashset() {
    let mut bv = BitVecVisited::with_capacity(10_000);
    let mut model = FxHashSet::default();
    let ids = [0, 1, 42, 999, 5000, 9999, 7, 128, 255, 256, 1023];

    for id in ids {
        let (bv_new, hs_new) = (bv.insert(id), model.insert(id));
        assert_eq!(
            bv_new, hs_new,
            "insert return mismatch at {id}: bv={bv_new}, hs={hs_new}"
        );
    }
    for i in 0..10_000 {
        assert_eq!(bv.contains(i), model.contains(&i), "contains mismatch at {i}");
    }
}
/// Recall guard for the bit-vector visited set. On 500 deterministic vectors,
/// average recall@10 against brute force must stay at or above 95%.
#[test]
#[allow(clippy::cast_precision_loss)] // index -> float conversions on small fixture sizes
fn test_bitvec_visited_recall_regression() {
    /// Ids of the `k` corpus vectors nearest to `query` by squared Euclidean
    /// distance, computed exhaustively. The stable sort keeps lower ids first on ties.
    fn exact_top_k(corpus: &[Vec<f32>], query: &[f32], k: usize) -> Vec<NodeId> {
        let mut scored: Vec<(NodeId, f32)> = corpus
            .iter()
            .enumerate()
            .map(|(id, v)| {
                let sq_dist: f32 = v.iter().zip(query).map(|(a, b)| (a - b) * (a - b)).sum();
                (id, sq_dist)
            })
            .collect();
        scored.sort_by(|x, y| x.1.total_cmp(&y.1));
        scored.into_iter().take(k).map(|(id, _)| id).collect()
    }

    let dim = 32;
    let n = 500;
    let k = 10;
    let ef_search = 128;
    let n_queries = 20;
    let hnsw = NativeHnsw::new(CachedSimdDistance::new(DistanceMetric::Euclidean, 32), 16, 200, n);

    let vectors: Vec<Vec<f32>> = (0..n)
        .map(|i| {
            (0..dim)
                .map(|j| ((i * dim + j) as f32 * 0.001).sin())
                .collect()
        })
        .collect();
    for v in &vectors {
        hnsw.insert(v).expect("insert should succeed");
    }

    let stride = n / n_queries;
    let total_recall: f64 = (0..n_queries)
        .map(|q_idx| {
            let query = &vectors[q_idx * stride];
            let found: Vec<NodeId> = hnsw
                .search(query, k, ef_search)
                .iter()
                .map(|(id, _)| *id)
                .collect();
            let truth = exact_top_k(&vectors, query, k);
            let hits = found.iter().filter(|id| truth.contains(id)).count();
            hits as f64 / k as f64
        })
        .sum();

    let avg_recall = total_recall / n_queries as f64;
    assert!(
        avg_recall >= 0.95,
        "BitVecVisited recall@{k} must be >= 95% (got {:.1}%); \
         bitvec visited set regression detected",
        avg_recall * 100.0,
    );
}
/// A `SearchState` created after a previous one was dropped must start with
/// empty candidate and result heaps (they may be recycled from the pools) and
/// must accept new candidates normally.
#[test]
fn test_heap_pool_reuses_allocations() {
    let mut warmup = SearchState::new(100);
    warmup.push_candidate(1, 0.5);
    warmup.push_candidate(2, 0.3);
    // Dropping releases the now-dirty heaps, presumably back to the pools.
    drop(warmup);

    let mut state = SearchState::new(100);
    assert!(
        state.candidates.is_empty(),
        "pooled candidate heap must be empty on acquire"
    );
    assert!(
        state.results.is_empty(),
        "pooled result heap must be empty on acquire"
    );
    for (node, dist) in [(10, 0.1), (20, 0.9)] {
        state.push_candidate(node, dist);
    }
    assert_eq!(state.candidates.len(), 2);
    assert_eq!(state.results.len(), 2);
}
/// When more `SearchState`s are dropped than the pools may retain, each
/// thread-local heap pool must stay capped at `POOL_MAX` entries.
#[test]
fn test_heap_pool_bounded_size() {
    // Keep every state alive simultaneously so that dropping them hands back
    // `POOL_MAX + 2` heaps at once (assuming states return heaps to the pools on
    // drop). Creating and dropping states one at a time would only ever recycle
    // a single heap and never exercise the cap.
    let live: Vec<_> = (0..POOL_MAX + 2)
        .map(|_| {
            let mut state = SearchState::new(50);
            state.push_candidate(1, 0.5);
            state
        })
        .collect();
    drop(live);

    let count = CANDIDATE_HEAP_POOL.with(|pool| pool.borrow().len());
    assert!(
        count <= POOL_MAX,
        "candidate pool must not exceed POOL_MAX ({count} > {})",
        POOL_MAX,
    );

    let result_count = RESULT_HEAP_POOL.with(|pool| pool.borrow().len());
    assert!(
        result_count <= POOL_MAX,
        "result pool must not exceed POOL_MAX ({result_count} > {})",
        POOL_MAX,
    );
}
/// Recall guard for pooled heaps. On 500 deterministic vectors, average
/// recall@10 against brute force must stay at or above 95%.
#[test]
#[allow(clippy::cast_precision_loss)] // index -> float conversions on small fixture sizes
fn test_heap_pool_recall_regression() {
    /// Ids of the `k` corpus vectors nearest to `query` by squared Euclidean
    /// distance, computed exhaustively. The stable sort keeps lower ids first on ties.
    fn exact_top_k(corpus: &[Vec<f32>], query: &[f32], k: usize) -> Vec<NodeId> {
        let mut scored: Vec<(NodeId, f32)> = corpus
            .iter()
            .enumerate()
            .map(|(id, v)| {
                let sq_dist: f32 = v.iter().zip(query).map(|(a, b)| (a - b) * (a - b)).sum();
                (id, sq_dist)
            })
            .collect();
        scored.sort_by(|x, y| x.1.total_cmp(&y.1));
        scored.into_iter().take(k).map(|(id, _)| id).collect()
    }

    let dim = 32;
    let n = 500;
    let k = 10;
    let ef_search = 128;
    let n_queries = 20;
    let hnsw = NativeHnsw::new(CachedSimdDistance::new(DistanceMetric::Euclidean, 32), 16, 200, n);

    let vectors: Vec<Vec<f32>> = (0..n)
        .map(|i| {
            (0..dim)
                .map(|j| ((i * dim + j) as f32 * 0.001).sin())
                .collect()
        })
        .collect();
    for v in &vectors {
        hnsw.insert(v).expect("insert should succeed");
    }

    let stride = n / n_queries;
    let total_recall: f64 = (0..n_queries)
        .map(|q_idx| {
            let query = &vectors[q_idx * stride];
            let found: Vec<NodeId> = hnsw
                .search(query, k, ef_search)
                .iter()
                .map(|(id, _)| *id)
                .collect();
            let truth = exact_top_k(&vectors, query, k);
            let hits = found.iter().filter(|id| truth.contains(id)).count();
            hits as f64 / k as f64
        })
        .sum();

    let avg_recall = total_recall / n_queries as f64;
    assert!(
        avg_recall >= 0.95,
        "Pooled heap recall@{k} must be >= 95% (got {:.1}%); \
         heap pool regression detected",
        avg_recall * 100.0,
    );
}
/// Enabling prefetch in `gather_unvisited_neighbors` must not change which
/// neighbors are returned, or their order.
#[test]
fn test_gather_unvisited_neighbors_with_prefetch() {
    let dim = 64;
    let n = 10_usize;
    let mut vectors =
        crate::perf_optimizations::ContiguousVectors::new(dim, n).expect("alloc should succeed");
    for i in 0..n {
        #[allow(clippy::cast_precision_loss)]
        let v: Vec<f32> = (0..dim).map(|j| (i * dim + j) as f32).collect();
        vectors.push(&v).expect("push should succeed");
    }

    let neighbors: Vec<NodeId> = (0..n).collect();
    let mut visited_no_pf = BitVecVisited::with_capacity(n);
    let mut visited_pf = BitVecVisited::with_capacity(n);
    // Both runs start from the same pre-visited ids.
    for id in [2, 5, 8] {
        visited_no_pf.insert(id);
        visited_pf.insert(id);
    }

    let ids_no_pf: Vec<NodeId> =
        gather_unvisited_neighbors(&neighbors, &mut visited_no_pf, &vectors, false)
            .iter()
            .map(|&(id, _)| id)
            .collect();
    let ids_pf: Vec<NodeId> = gather_unvisited_neighbors(&neighbors, &mut visited_pf, &vectors, true)
        .iter()
        .map(|&(id, _)| id)
        .collect();

    assert_eq!(ids_no_pf, ids_pf, "prefetch must not alter gather results");
    assert_eq!(ids_pf, vec![0, 1, 3, 4, 6, 7, 9]);
}
/// `cached_furthest` must start at `f32::MAX` while the result heap is empty,
/// then track the largest distance pushed so far.
#[test]
#[allow(clippy::float_cmp)] // exact comparison against the f32::MAX sentinel is intended
fn test_cached_furthest_tracks_push_candidate() {
    let mut state = SearchState::new(0);
    assert_eq!(
        state.cached_furthest,
        f32::MAX,
        "cached_furthest must be f32::MAX when results heap is empty"
    );

    // (node, pushed distance, expected cached_furthest afterwards, failure text)
    let steps = [
        (1, 0.3, 0.3, "cached_furthest must track single-element heap root"),
        (2, 0.1, 0.3, "cached_furthest must remain at max distance"),
        (3, 0.9, 0.9, "cached_furthest must update to new max"),
    ];
    for (node, dist, expected, what) in steps {
        state.push_candidate(node, dist);
        assert!(
            (state.cached_furthest - expected).abs() < f32::EPSILON,
            "{what}: got {}",
            state.cached_furthest,
        );
    }
}
/// After a batch neighbor evicts the furthest result (0.6), `cached_furthest`
/// must be refreshed to the new furthest distance (0.4).
#[test]
fn test_cached_furthest_tracks_batch_eviction() {
    let ef = 3;
    let mut state = SearchState::new(0);
    for (node, dist) in [(10, 0.2), (20, 0.4), (30, 0.6)] {
        state.push_candidate(node, dist);
    }

    let payload: Vec<f32> = vec![1.0; 4];
    let batch: Vec<(NodeId, &[f32])> = vec![(40, payload.as_slice())];
    let distances = vec![0.1_f32];
    process_batch_results(&batch, &distances, ef, &mut state);

    assert!(
        (state.cached_furthest - 0.4).abs() < f32::EPSILON,
        "cached_furthest must refresh after eviction: got {}",
        state.cached_furthest,
    );
}
/// A freshly constructed index must hold the `NO_ENTRY_POINT` sentinel and
/// return no results for a search.
#[test]
fn test_atomic_entry_point_starts_as_sentinel() {
    use super::super::distance::CpuDistance;
    use std::sync::atomic::Ordering;

    let hnsw = NativeHnsw::new(CpuDistance::new(DistanceMetric::Euclidean), 8, 32, 100);
    assert_eq!(
        hnsw.entry_point.load(Ordering::Acquire),
        NO_ENTRY_POINT,
        "fresh index must have NO_ENTRY_POINT sentinel"
    );

    let query = [1.0_f32, 2.0, 3.0, 4.0];
    assert!(
        hnsw.search(&query, 5, 64).is_empty(),
        "search on empty index must return empty"
    );
}
/// The first inserted vector must replace the sentinel and become the entry
/// point (node 0).
#[test]
fn test_atomic_entry_point_set_after_insert() {
    use std::sync::atomic::Ordering;

    let hnsw = NativeHnsw::new(CachedSimdDistance::new(DistanceMetric::Euclidean, 32), 8, 32, 100);
    hnsw.insert(&[1.0, 2.0, 3.0, 4.0])
        .expect("insert should succeed");

    let entry = hnsw.entry_point.load(Ordering::Acquire);
    assert_ne!(
        entry, NO_ENTRY_POINT,
        "entry_point must be set after first insert"
    );
    assert_eq!(entry, 0, "first inserted node should be entry point");
}
/// After 2000 inserts the index must have a valid entry point and at least one
/// layer above the base layer.
///
/// NOTE(review): despite the name, this only checks `max_layer > 0`. It does
/// not verify that the entry point itself lives on `max_layer`.
#[test]
fn test_atomic_entry_point_promotes_higher_layer() {
    use std::sync::atomic::Ordering;

    let engine = CachedSimdDistance::new(DistanceMetric::Euclidean, 32);
    let hnsw = NativeHnsw::new(engine, 16, 100, 5000);
    for i in 0..2000_usize {
        // Vector i holds the consecutive values i*32 .. (i+1)*32.
        #[allow(clippy::cast_precision_loss)]
        let v: Vec<f32> = (i * 32..(i + 1) * 32).map(|x| x as f32).collect();
        hnsw.insert(&v).expect("insert should succeed");
    }

    assert_ne!(hnsw.entry_point.load(Ordering::Acquire), NO_ENTRY_POINT);
    let max_layer = hnsw.max_layer.load(Ordering::Relaxed);
    assert!(
        max_layer > 0,
        "max_layer must be > 0 with 2000 nodes (got {max_layer})"
    );
}
/// Recall guard for the prefetching search path. On 500 deterministic 64-dim
/// vectors, average recall@10 against brute force must stay at or above 95%.
#[test]
#[allow(clippy::cast_precision_loss)] // index -> float conversions on small fixture sizes
fn test_prefetch_recall_regression() {
    /// Ids of the `k` corpus vectors nearest to `query` by squared Euclidean
    /// distance, computed exhaustively. The stable sort keeps lower ids first on ties.
    fn exact_top_k(corpus: &[Vec<f32>], query: &[f32], k: usize) -> Vec<NodeId> {
        let mut scored: Vec<(NodeId, f32)> = corpus
            .iter()
            .enumerate()
            .map(|(id, v)| {
                let sq_dist: f32 = v.iter().zip(query).map(|(a, b)| (a - b) * (a - b)).sum();
                (id, sq_dist)
            })
            .collect();
        scored.sort_by(|x, y| x.1.total_cmp(&y.1));
        scored.into_iter().take(k).map(|(id, _)| id).collect()
    }

    let dim = 64;
    let n = 500;
    let k = 10;
    let ef_search = 128;
    let n_queries = 20;
    let hnsw = NativeHnsw::new(CachedSimdDistance::new(DistanceMetric::Euclidean, 64), 16, 200, n);

    let vectors: Vec<Vec<f32>> = (0..n)
        .map(|i| {
            (0..dim)
                .map(|j| ((i * dim + j) as f32 * 0.001).sin())
                .collect()
        })
        .collect();
    for v in &vectors {
        hnsw.insert(v).expect("insert should succeed");
    }

    let stride = n / n_queries;
    let total_recall: f64 = (0..n_queries)
        .map(|q_idx| {
            let query = &vectors[q_idx * stride];
            let found: Vec<NodeId> = hnsw
                .search(query, k, ef_search)
                .iter()
                .map(|(id, _)| *id)
                .collect();
            let truth = exact_top_k(&vectors, query, k);
            let hits = found.iter().filter(|id| truth.contains(id)).count();
            hits as f64 / k as f64
        })
        .sum();

    let avg_recall = total_recall / n_queries as f64;
    assert!(
        avg_recall >= 0.95,
        "Prefetch recall@{k} must be >= 95% (got {:.1}%); \
         prefetch regression detected",
        avg_recall * 100.0,
    );
}