#![cfg(feature = "hnsw")]
#![allow(clippy::unwrap_used, clippy::expect_used)]
#[path = "common/mod.rs"]
mod common;
use common::*;
use std::collections::HashSet;
use vicinity::hnsw::filtered::{acorn_search, acorn_search_with_stats, AcornConfig, FnFilter};
use vicinity::hnsw::{HNSWIndex, HNSWParams};
/// Brute-force k-nearest-neighbour search under cosine distance.
///
/// Returns the `k` closest `(internal_index, distance)` pairs in ascending
/// distance order. Used as the ground-truth oracle for recall measurements.
fn exact_knn_cosine(vectors: &[Vec<f32>], query: &[f32], k: usize) -> Vec<(u32, f32)> {
    let mut scored: Vec<(u32, f32)> = Vec::with_capacity(vectors.len());
    for (idx, vec) in vectors.iter().enumerate() {
        scored.push((idx as u32, vicinity::distance::cosine_distance(vec, query)));
    }
    // f32 distances are finite here, so partial_cmp cannot be None.
    scored.sort_by(|lhs, rhs| lhs.1.partial_cmp(&rhs.1).unwrap());
    scored.truncate(k);
    scored
}
// Default ef_search (candidate-list width) for tests that do not exercise
// the ef/recall trade-off explicitly.
const DEFAULT_EF: usize = 50;
#[test]
fn test_hnsw_basic_build_and_query() {
    // Smoke test: build a small index and confirm a self-query returns the
    // queried vector first with ~zero distance.
    let dim = 32;
    let n = 500;
    let vectors: Vec<Vec<f32>> = random_vectors(n, dim, 42)
        .iter()
        .map(|v| normalize(v))
        .collect();
    let mut hnsw = HNSWIndex::new(dim, 16, 16).expect("Failed to create index");
    for (i, v) in vectors.iter().enumerate() {
        hnsw.add(i as u32, v.clone()).expect("Failed to add vector");
    }
    hnsw.build().expect("Failed to build index");
    let results = hnsw
        .search(&vectors[0], 10, DEFAULT_EF)
        .expect("Search failed");
    assert!(!results.is_empty(), "Search should return results");
    assert_eq!(results[0].0, 0, "First result should be internal index 0");
    assert!(results[0].1 < 0.01, "Distance to self should be ~0");
}
#[test]
fn test_hnsw_recall_quality() {
    // Average recall@10 over 50 random queries must clear 0.85 with generous
    // build (ef_construction=200) and search (ef=150) parameters.
    let dim = 64;
    let n = 1000;
    let k = 10;
    let n_queries = 50;
    let vectors: Vec<Vec<f32>> = random_vectors(n, dim, 123)
        .iter()
        .map(|v| normalize(v))
        .collect();
    let queries: Vec<Vec<f32>> = random_vectors(n_queries, dim, 456)
        .iter()
        .map(|v| normalize(v))
        .collect();
    let mut hnsw = HNSWIndex::with_params(
        dim,
        HNSWParams {
            m: 32,
            m_max: 32,
            ef_construction: 200,
            ef_search: 100,
            ..Default::default()
        },
    )
    .expect("Failed to create index");
    for (i, v) in vectors.iter().enumerate() {
        hnsw.add(i as u32, v.clone()).expect("Failed to add");
    }
    hnsw.build().expect("Failed to build");
    // Per-query recall against the brute-force oracle, summed then averaged.
    let total_recall: f32 = queries
        .iter()
        .map(|query| {
            let exact = exact_knn_cosine(&vectors, query, k);
            let approx = hnsw.search(query, k, 150).expect("Search failed");
            recall_at_k_sets(&exact, &approx, k)
        })
        .sum();
    let avg_recall = total_recall / n_queries as f32;
    assert!(
        avg_recall >= 0.85,
        "Average recall@{} should be >= 0.85, got {}",
        k,
        avg_recall
    );
}
#[test]
fn test_hnsw_empty_index_errors() {
    // Searching an index with no vectors (and no build) must be an error.
    let dim = 32;
    let hnsw = HNSWIndex::new(dim, 16, 16).expect("Failed to create");
    let zero_query = vec![0.0f32; dim];
    assert!(
        hnsw.search(&zero_query, 10, DEFAULT_EF).is_err(),
        "Empty unbuilt index should error on search"
    );
}
#[test]
fn test_hnsw_single_vector() {
    // A one-element index must return exactly that element with its doc_id.
    let dim = 16;
    let point = normalize(&vec![1.0f32; dim]);
    let mut hnsw = HNSWIndex::new(dim, 16, 16).expect("Failed to create");
    hnsw.add(42, point.clone()).expect("Failed to add");
    hnsw.build().expect("Failed to build");
    let results = hnsw.search(&point, 10, DEFAULT_EF).expect("Search failed");
    assert_eq!(results.len(), 1, "Should find exactly one result");
    assert_eq!(results[0].0, 42, "Should return the inserted doc_id");
}
#[test]
fn test_hnsw_high_dimensional() {
    // Embedding-sized (768-d) vectors: a self-query must still find itself.
    let dim = 768;
    let n = 100;
    let vectors: Vec<Vec<f32>> = random_vectors(n, dim, 789)
        .iter()
        .map(|v| normalize(v))
        .collect();
    let mut hnsw = HNSWIndex::new(dim, 32, 32).expect("Failed to create");
    for (i, v) in vectors.iter().enumerate() {
        hnsw.add(i as u32, v.clone()).expect("Failed to add");
    }
    hnsw.build().expect("Failed to build");
    let results = hnsw
        .search(&vectors[50], 10, DEFAULT_EF)
        .expect("Search failed");
    assert!(!results.is_empty());
    assert_eq!(results[0].0, 50, "Should return doc_id 50");
}
#[test]
fn test_hnsw_returns_k_results() {
    // For every k <= n the search must return exactly k hits; for k > n it
    // must return all n vectors.
    let dim = 32;
    let n = 100;
    let vectors: Vec<Vec<f32>> = random_vectors(n, dim, 111)
        .iter()
        .map(|v| normalize(v))
        .collect();
    let mut hnsw = HNSWIndex::new(dim, 16, 16).expect("Failed to create");
    for (i, v) in vectors.iter().enumerate() {
        hnsw.add(i as u32, v.clone()).expect("Failed to add");
    }
    hnsw.build().expect("Failed to build");
    for &k in &[1, 5, 10, 50, 100] {
        let results = hnsw.search(&vectors[0], k, 100).expect("Search failed");
        let expected = k.min(n);
        assert_eq!(
            results.len(),
            expected,
            "Should return {} results for k={}, got {}",
            expected,
            k,
            results.len()
        );
    }
    // k far beyond n: clamp to the full collection.
    let results = hnsw.search(&vectors[0], 200, 200).expect("Search failed");
    assert_eq!(results.len(), n, "Should return all {} vectors", n);
}
#[test]
fn test_hnsw_results_sorted_by_distance() {
    // Hits must come back in non-decreasing distance order (small float
    // tolerance for rounding noise).
    let dim = 32;
    let n = 200;
    let k = 20;
    let vectors: Vec<Vec<f32>> = random_vectors(n, dim, 222)
        .iter()
        .map(|v| normalize(v))
        .collect();
    let mut hnsw = HNSWIndex::new(dim, 16, 16).expect("Failed to create");
    for (i, v) in vectors.iter().enumerate() {
        hnsw.add(i as u32, v.clone()).expect("Failed to add");
    }
    hnsw.build().expect("Failed to build");
    let query = normalize(&random_vectors(1, dim, 333).pop().unwrap());
    let results = hnsw.search(&query, k, 100).expect("Search failed");
    for pair in results.windows(2) {
        assert!(
            pair[1].1 >= pair[0].1 - 1e-6,
            "Results not sorted: {:?} vs {:?}",
            pair[0],
            pair[1]
        );
    }
}
#[test]
fn test_hnsw_dimension_validation() {
    // Constructing an index with dimension 0 must be rejected.
    assert!(
        HNSWIndex::new(0, 16, 16).is_err(),
        "Zero dimension should fail"
    );
}
#[test]
fn test_hnsw_query_dimension_mismatch() {
    // A query whose length differs from the index dimension must error.
    let dim = 32;
    let mut hnsw = HNSWIndex::new(dim, 16, 16).expect("Failed to create");
    hnsw.add(0, normalize(&vec![1.0; dim]))
        .expect("Failed to add");
    hnsw.build().expect("Failed to build");
    let oversized = vec![1.0; dim + 1];
    assert!(
        hnsw.search(&oversized, 10, DEFAULT_EF).is_err(),
        "Wrong dimension query should error"
    );
}
#[test]
fn test_hnsw_with_custom_params() {
    // `with_params` construction path: non-default degree and construction
    // width, then confirm a basic query succeeds.
    let dim = 32;
    let n = 200;
    let vectors: Vec<Vec<f32>> = random_vectors(n, dim, 444)
        .iter()
        .map(|v| normalize(v))
        .collect();
    let mut hnsw = HNSWIndex::with_params(
        dim,
        HNSWParams {
            m: 32,
            m_max: 32,
            ef_construction: 100,
            ef_search: 100,
            ..Default::default()
        },
    )
    .expect("Failed to create");
    for (i, v) in vectors.iter().enumerate() {
        hnsw.add(i as u32, v.clone()).expect("Failed to add");
    }
    hnsw.build().expect("Failed to build");
    let results = hnsw.search(&vectors[0], 10, 100).expect("Search failed");
    assert!(!results.is_empty());
}
#[test]
fn test_hnsw_repeated_builds_idempotent() {
    // Calling build() twice must neither error nor corrupt the index.
    let dim = 16;
    let unit = normalize(&vec![1.0; dim]);
    let mut hnsw = HNSWIndex::new(dim, 16, 16).expect("Failed to create");
    hnsw.add(0, unit.clone()).expect("Failed to add");
    hnsw.build().expect("First build");
    hnsw.build().expect("Second build should be idempotent");
    let results = hnsw.search(&unit, 10, DEFAULT_EF).expect("Search failed");
    assert_eq!(results.len(), 1);
}
#[test]
fn test_hnsw_ef_tradeoff() {
    // Raising ef_search from 20 to 200 must not meaningfully hurt recall
    // (a 0.1 dip is tolerated for search noise).
    let dim = 64;
    let n = 500;
    let k = 10;
    let vectors: Vec<Vec<f32>> = random_vectors(n, dim, 999)
        .iter()
        .map(|v| normalize(v))
        .collect();
    let query = normalize(&random_vectors(1, dim, 1000).pop().unwrap());
    let mut hnsw = HNSWIndex::new(dim, 16, 16).expect("Failed to create");
    for (i, v) in vectors.iter().enumerate() {
        hnsw.add(i as u32, v.clone()).expect("Failed to add");
    }
    hnsw.build().expect("Failed to build");
    let exact = exact_knn_cosine(&vectors, &query, k);
    let recall_low = {
        let approx = hnsw.search(&query, k, 20).expect("Search failed");
        recall_at_k_sets(&exact, &approx, k)
    };
    let recall_high = {
        let approx = hnsw.search(&query, k, 200).expect("Search failed");
        recall_at_k_sets(&exact, &approx, k)
    };
    assert!(
        recall_high >= recall_low - 0.1,
        "Higher ef_search should not significantly decrease recall: {} vs {}",
        recall_low,
        recall_high
    );
}
#[test]
fn test_hnsw_cannot_add_after_build() {
    // The index is frozen after build(): further add() calls must fail.
    let dim = 16;
    let mut hnsw = HNSWIndex::new(dim, 16, 16).expect("Failed to create");
    hnsw.add(0, normalize(&vec![1.0; dim]))
        .expect("Failed to add");
    hnsw.build().expect("Failed to build");
    assert!(
        hnsw.add(1, normalize(&vec![2.0; dim])).is_err(),
        "Adding after build should fail"
    );
}
#[test]
fn test_hnsw_search_before_build_fails() {
    // search() on a non-empty but unbuilt index must error.
    let dim = 16;
    let mut hnsw = HNSWIndex::new(dim, 16, 16).expect("Failed to create");
    hnsw.add(0, normalize(&vec![1.0; dim]))
        .expect("Failed to add");
    assert!(
        hnsw.search(&normalize(&vec![1.0; dim]), 10, DEFAULT_EF).is_err(),
        "Search before build should fail"
    );
}
#[test]
fn test_hnsw_cosine_similarity_property() {
    // Two copies of the same vector must both sit at ~zero cosine distance
    // from a query equal to that vector.
    let dim = 32;
    let v = normalize(&vec![1.0; dim]);
    let mut hnsw = HNSWIndex::new(dim, 16, 16).expect("Failed to create");
    hnsw.add(0, v.clone()).expect("Failed to add");
    hnsw.add(1, v.clone()).expect("Failed to add");
    hnsw.build().expect("Failed to build");
    let results = hnsw.search(&v, 10, DEFAULT_EF).expect("Search failed");
    assert!(
        results[0].1 < 0.01,
        "Distance to identical vector should be ~0"
    );
    assert!(
        results[1].1 < 0.01,
        "Distance to identical vector should be ~0"
    );
}
#[test]
fn test_hnsw_orthogonal_vectors() {
    // Orthogonal unit vectors have cosine distance 1; the query vector
    // itself sits at ~0 and must rank first.
    let dim = 4;
    let e1 = vec![1.0, 0.0, 0.0, 0.0];
    let e2 = vec![0.0, 1.0, 0.0, 0.0];
    let mut hnsw = HNSWIndex::new(dim, 16, 16).expect("Failed to create");
    hnsw.add(0, e1.clone()).expect("Failed to add");
    hnsw.add(1, e2.clone()).expect("Failed to add");
    hnsw.build().expect("Failed to build");
    let results = hnsw.search(&e1, 10, DEFAULT_EF).expect("Search failed");
    assert_eq!(results[0].0, 0);
    assert!(results[0].1 < 0.01);
    assert_eq!(results[1].0, 1);
    assert!(
        (results[1].1 - 1.0).abs() < 0.01,
        "Orthogonal vector distance should be ~1"
    );
}
#[test]
fn test_hnsw_recall_monotonic_with_ef() {
    // Recall should be (approximately) non-decreasing as ef grows; a 0.1
    // dip is tolerated per step for search noise.
    let dim = 32;
    let n = 500;
    let k = 10;
    let vectors: Vec<Vec<f32>> = random_vectors(n, dim, 42)
        .iter()
        .map(|v| normalize(v))
        .collect();
    let mut hnsw = HNSWIndex::new(dim, 16, 16).expect("Failed to create index");
    for (i, v) in vectors.iter().enumerate() {
        hnsw.add(i as u32, v.clone()).expect("Failed to add vector");
    }
    hnsw.build().expect("Failed to build index");
    let test_queries: Vec<Vec<f32>> = random_vectors(20, dim, 999)
        .iter()
        .map(|v| normalize(v))
        .collect();
    for query in &test_queries {
        let exact = exact_knn_cosine(&vectors, query, k);
        let mut prev_recall = 0.0_f32;
        for ef in [10, 20, 50, 100, 200] {
            let results = hnsw.search(query, k, ef).expect("Search failed");
            let recall = recall_at_k_sets(&exact, &results, k);
            assert!(
                recall >= prev_recall - 0.1,
                "Recall decreased from {} to {} when ef increased to {}",
                prev_recall,
                recall,
                ef
            );
            prev_recall = recall;
        }
    }
}
#[test]
fn test_hnsw_search_returns_valid_indices() {
    // Every hit must carry an in-range doc_id and a cosine distance inside
    // the valid [0, 2] interval (with float slack).
    let dim = 16;
    let n = 200;
    let vectors: Vec<Vec<f32>> = random_vectors(n, dim, 123)
        .iter()
        .map(|v| normalize(v))
        .collect();
    let mut hnsw = HNSWIndex::new(dim, 16, 16).expect("Failed to create index");
    for (i, v) in vectors.iter().enumerate() {
        hnsw.add(i as u32, v.clone()).expect("Failed to add vector");
    }
    hnsw.build().expect("Failed to build index");
    let test_queries: Vec<Vec<f32>> = random_vectors(50, dim, 456)
        .iter()
        .map(|v| normalize(v))
        .collect();
    for query in &test_queries {
        let results = hnsw.search(query, 20, DEFAULT_EF).expect("Search failed");
        for &(doc_id, dist) in &results {
            assert!(
                (doc_id as usize) < n,
                "Search returned invalid doc_id {} (n={})",
                doc_id,
                n
            );
            assert!(
                dist >= 0.0 && dist <= 2.0 + 1e-5,
                "Invalid cosine distance: {}",
                dist
            );
        }
    }
}
#[test]
fn test_hnsw_deterministic_search() {
    // Repeated identical queries must yield identical result lists.
    let dim = 32;
    let n = 300;
    let vectors: Vec<Vec<f32>> = random_vectors(n, dim, 789)
        .iter()
        .map(|v| normalize(v))
        .collect();
    let mut hnsw = HNSWIndex::new(dim, 16, 16).expect("Failed to create index");
    for (i, v) in vectors.iter().enumerate() {
        hnsw.add(i as u32, v.clone()).expect("Failed to add vector");
    }
    hnsw.build().expect("Failed to build index");
    let query = normalize(&vec![1.0; dim]);
    let first = hnsw.search(&query, 10, DEFAULT_EF).expect("Search failed");
    let second = hnsw.search(&query, 10, DEFAULT_EF).expect("Search failed");
    let third = hnsw.search(&query, 10, DEFAULT_EF).expect("Search failed");
    assert_eq!(first, second, "Search should be deterministic");
    assert_eq!(second, third, "Search should be deterministic");
}
#[test]
fn test_hnsw_results_unique() {
    // A single search must never report the same doc_id twice.
    let dim = 32;
    let n = 400;
    let k = 50;
    let vectors: Vec<Vec<f32>> = random_vectors(n, dim, 321)
        .iter()
        .map(|v| normalize(v))
        .collect();
    let mut hnsw = HNSWIndex::new(dim, 16, 16).expect("Failed to create index");
    for (i, v) in vectors.iter().enumerate() {
        hnsw.add(i as u32, v.clone()).expect("Failed to add vector");
    }
    hnsw.build().expect("Failed to build index");
    let test_queries: Vec<Vec<f32>> = random_vectors(30, dim, 654)
        .iter()
        .map(|v| normalize(v))
        .collect();
    for query in &test_queries {
        let results = hnsw.search(query, k, 100).expect("Search failed");
        let ids: Vec<u32> = results.iter().map(|(id, _)| *id).collect();
        let distinct: HashSet<u32> = ids.iter().copied().collect();
        assert_eq!(
            ids.len(),
            distinct.len(),
            "Search returned duplicate indices"
        );
    }
}
#[test]
fn test_hnsw_graph_connectivity() {
    // Self-reachability proxy for graph connectivity: >= 95% of nodes must
    // appear in the top-10 of a search for their own vector.
    let dim = 32;
    let n = 200;
    let vectors: Vec<Vec<f32>> = random_vectors(n, dim, 777)
        .iter()
        .map(|v| normalize(v))
        .collect();
    let mut hnsw = HNSWIndex::new(dim, 16, 100).expect("Failed to create");
    for (i, v) in vectors.iter().enumerate() {
        hnsw.add(i as u32, v.clone()).expect("Failed to add");
    }
    hnsw.build().expect("Failed to build");
    let reachable = vectors
        .iter()
        .enumerate()
        .filter(|&(i, v)| {
            hnsw.search(v, 10, 200)
                .expect("Search failed")
                .iter()
                .any(|&(id, _)| id == i as u32)
        })
        .count();
    let reachability = reachable as f32 / n as f32;
    assert!(
        reachability >= 0.95,
        "Only {:.1}% of nodes reachable from themselves (expected >= 95%)",
        reachability * 100.0
    );
}
#[test]
fn test_distance_simd_scalar_agreement() {
    // Cross-check the library distance kernels (possibly SIMD) against
    // naive scalar math over awkward lengths: remainder lanes, powers of
    // two, and odd sizes.
    for dim in [1, 2, 3, 7, 16, 31, 32, 33, 64, 128, 255, 256, 513] {
        let lhs: Vec<f32> = (0..dim)
            .map(|i| ((i * 31 + 7) as f32 * 0.001).sin())
            .collect();
        let rhs: Vec<f32> = (0..dim)
            .map(|i| ((i * 17 + 3) as f32 * 0.001).cos())
            .collect();
        // Scalar reference values.
        let scalar_dot: f32 = lhs.iter().zip(rhs.iter()).map(|(x, y)| x * y).sum();
        let scalar_l2: f32 = lhs
            .iter()
            .zip(rhs.iter())
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f32>()
            .sqrt();
        let norm_lhs: f32 = lhs.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_rhs: f32 = rhs.iter().map(|x| x * x).sum::<f32>().sqrt();
        let scalar_cosine = if norm_lhs > 0.0 && norm_rhs > 0.0 {
            1.0 - scalar_dot / (norm_lhs * norm_rhs)
        } else {
            1.0
        };
        // Library values under test.
        let simd_l2 = vicinity::distance::l2_distance(&lhs, &rhs);
        let simd_cosine = vicinity::distance::cosine_distance(&lhs, &rhs);
        let eps = 1e-4;
        assert!(
            (simd_l2 - scalar_l2).abs() < eps,
            "L2 mismatch at dim={}: simd={}, scalar={}",
            dim,
            simd_l2,
            scalar_l2
        );
        assert!(
            (simd_cosine - scalar_cosine).abs() < eps,
            "Cosine mismatch at dim={}: simd={}, scalar={}",
            dim,
            simd_cosine,
            scalar_cosine
        );
    }
}
#[test]
fn test_hnsw_edge_case_k_values() {
    // k=1 self-query returns self; k far beyond n returns at most n hits.
    let dim = 16;
    let n = 50;
    let vectors: Vec<Vec<f32>> = random_vectors(n, dim, 42)
        .iter()
        .map(|v| normalize(v))
        .collect();
    let mut hnsw = HNSWIndex::new(dim, 8, 50).expect("Failed to create");
    for (i, v) in vectors.iter().enumerate() {
        hnsw.add(i as u32, v.clone()).expect("Failed to add");
    }
    hnsw.build().expect("Failed to build");
    let single = hnsw.search(&vectors[0], 1, 50).expect("Search failed");
    assert_eq!(single.len(), 1, "k=1 should return exactly 1 result");
    assert_eq!(single[0].0, 0, "k=1 query on self should return self");
    let oversized = hnsw
        .search(&vectors[0], n + 100, 200)
        .expect("Search failed");
    assert!(
        oversized.len() <= n,
        "k > n should return at most n={} results, got {}",
        n,
        oversized.len()
    );
}
// Oracle test for ACORN filtered search at ~50% selectivity (even ids only).
// The proximity graph is built by brute force (true k-NN edges, inserted in
// both directions), so recall loss is attributable to the traversal rather
// than graph quality.
#[test]
fn test_filtered_search_oracle() {
let dim = 32;
let n = 200;
let k = 10;
let n_queries = 20;
let neighbors_per_node = 32;
let vectors: Vec<Vec<f32>> = random_vectors(n, dim, 7777)
.into_iter()
.map(|v| normalize(&v))
.collect();
// Predicate category: id parity; target_category 0 keeps even ids.
let category_of = |id: u32| -> u32 { id % 2 };
// Exact k-NN graph: each node links to its neighbors_per_node closest
// peers, every edge made bidirectional.
let mut graph: Vec<HashSet<u32>> = (0..n).map(|_| HashSet::new()).collect();
for i in 0..n {
let mut dists: Vec<(u32, f32)> = (0..n)
.filter(|&j| j != i)
.map(|j| {
(
j as u32,
vicinity::distance::cosine_distance(&vectors[i], &vectors[j]),
)
})
.collect();
dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
for &(j, _) in dists.iter().take(neighbors_per_node) {
graph[i].insert(j);
graph[j as usize].insert(i as u32); }
}
let graph: Vec<Vec<u32>> = graph.into_iter().map(|s| s.into_iter().collect()).collect();
let queries: Vec<Vec<f32>> = random_vectors(n_queries, dim, 8888)
.into_iter()
.map(|v| normalize(&v))
.collect();
let target_category: u32 = 0;
let mut total_overlap = 0usize;
let mut total_expected = 0usize;
for query in &queries {
// Brute-force filtered ground truth: top-k among even-id vectors only.
let mut gt: Vec<(u32, f32)> = (0..n as u32)
.filter(|&id| category_of(id) == target_category)
.map(|id| {
(
id,
vicinity::distance::cosine_distance(&vectors[id as usize], query),
)
})
.collect();
gt.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
gt.truncate(k);
let gt_ids: HashSet<u32> = gt.iter().map(|(id, _)| *id).collect();
// Entry point: globally nearest node, chosen without regard to the filter.
let entry_point = (0..n as u32)
.min_by(|&a, &b| {
let da = vicinity::distance::cosine_distance(&vectors[a as usize], query);
let db = vicinity::distance::cosine_distance(&vectors[b as usize], query);
da.partial_cmp(&db).unwrap()
})
.unwrap();
let filter = FnFilter(|id: u32| category_of(id) == target_category);
let config = AcornConfig {
enable_two_hop: true,
two_hop_threshold: 0.5,
max_two_hop_neighbors: 64,
ef_search: 200,
};
let results = acorn_search(
k,
&config,
&filter,
|id| graph[id as usize].clone(),
|id| vicinity::distance::cosine_distance(&vectors[id as usize], query),
entry_point,
)
.expect("acorn_search failed");
// Hard guarantee: no result may violate the predicate.
for (id, _) in &results {
assert_eq!(
category_of(*id),
target_category,
"Filtered result {} has wrong category",
id
);
}
let result_ids: HashSet<u32> = results.iter().map(|(id, _)| *id).collect();
let overlap = gt_ids.intersection(&result_ids).count();
total_overlap += overlap;
total_expected += gt_ids.len();
}
// Soft guarantee: aggregate recall across all queries stays above 0.5.
let avg_recall = total_overlap as f32 / total_expected as f32;
assert!(
avg_recall >= 0.5,
"Filtered search recall too low: {:.2} (expected >= 0.50)",
avg_recall
);
}
// ACORN under very low selectivity: only ids divisible by 40 pass (10 of
// 400 nodes, ~2.5%). With just 8 graph neighbours per node the filtered
// subgraph is extremely sparse, so this exercises routing through
// predicate-failing nodes while still returning only valid ids.
#[test]
fn test_acorn_low_selectivity_returns_valid_results() {
let dim = 32;
let n = 400;
let k = 5;
let n_queries = 20;
let neighbors_per_node = 8;
let vectors: Vec<Vec<f32>> = random_vectors(n, dim, 7777)
.into_iter()
.map(|v| normalize(&v))
.collect();
// Predicate passes only ids 0, 40, 80, ... (~2.5% selectivity).
let category_of = |id: u32| -> bool { id.is_multiple_of(40) };
// Exact k-NN graph with bidirectional edges.
let mut graph: Vec<HashSet<u32>> = (0..n).map(|_| HashSet::new()).collect();
for i in 0..n {
let mut dists: Vec<(u32, f32)> = (0..n)
.filter(|&j| j != i)
.map(|j| {
(
j as u32,
vicinity::distance::cosine_distance(&vectors[i], &vectors[j]),
)
})
.collect();
dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
for &(j, _) in dists.iter().take(neighbors_per_node) {
graph[i].insert(j);
graph[j as usize].insert(i as u32);
}
}
let graph: Vec<Vec<u32>> = graph.into_iter().map(|s| s.into_iter().collect()).collect();
let queries: Vec<Vec<f32>> = random_vectors(n_queries, dim, 8888)
.into_iter()
.map(|v| normalize(&v))
.collect();
let mut total_overlap = 0usize;
let mut total_expected = 0usize;
for query in &queries {
// Brute-force filtered top-k as ground truth.
let mut gt: Vec<(u32, f32)> = (0..n as u32)
.filter(|&id| category_of(id))
.map(|id| {
(
id,
vicinity::distance::cosine_distance(&vectors[id as usize], query),
)
})
.collect();
gt.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
gt.truncate(k);
if gt.is_empty() {
continue;
}
let gt_ids: HashSet<u32> = gt.iter().map(|(id, _)| *id).collect();
// Entry point: globally nearest node, chosen without regard to the filter.
let entry_point = (0..n as u32)
.min_by(|&a, &b| {
let da = vicinity::distance::cosine_distance(&vectors[a as usize], query);
let db = vicinity::distance::cosine_distance(&vectors[b as usize], query);
da.partial_cmp(&db).unwrap()
})
.unwrap();
let filter = FnFilter(|id: u32| category_of(id));
let config = AcornConfig {
enable_two_hop: true,
two_hop_threshold: 0.5,
max_two_hop_neighbors: 16,
ef_search: 64,
};
let results = acorn_search(
k,
&config,
&filter,
|id| graph[id as usize].clone(),
|id| vicinity::distance::cosine_distance(&vectors[id as usize], query),
entry_point,
)
.expect("acorn_search failed");
// Hard guarantee: every returned id satisfies the predicate.
for (id, _) in &results {
assert!(
category_of(*id),
"acorn_search returned id {} that fails the predicate",
id
);
}
let result_ids: HashSet<u32> = results.iter().map(|(id, _)| *id).collect();
total_overlap += gt_ids.intersection(&result_ids).count();
total_expected += gt_ids.len();
}
// Soft guarantee: aggregate recall floor of 0.5 even at this selectivity.
let recall = total_overlap as f32 / total_expected as f32;
assert!(
recall >= 0.5,
"ACORN recall at ~2.5% selectivity dropped below floor: {:.3} < 0.50",
recall
);
}
// Verifies the 2-hop expansion branch actually executes at sparse
// selectivity: with a ~2.5% predicate and an 8-neighbour exact k-NN graph,
// `acorn_search_with_stats` must report at least one 2-hop invocation and
// at least one node examined via that branch.
#[test]
fn test_acorn_two_hop_branch_fires_at_sparse_selectivity() {
let dim = 32;
let n = 400;
let k = 5;
let neighbors_per_node = 8;
let vectors: Vec<Vec<f32>> = random_vectors(n, dim, 7777)
.into_iter()
.map(|v| normalize(&v))
.collect();
// Predicate passes only ids divisible by 40 (10 of 400 nodes).
let category_of = |id: u32| -> bool { id.is_multiple_of(40) };
// Exact k-NN graph with bidirectional edges.
let mut graph: Vec<HashSet<u32>> = (0..n).map(|_| HashSet::new()).collect();
for i in 0..n {
let mut dists: Vec<(u32, f32)> = (0..n)
.filter(|&j| j != i)
.map(|j| {
(
j as u32,
vicinity::distance::cosine_distance(&vectors[i], &vectors[j]),
)
})
.collect();
dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
for &(j, _) in dists.iter().take(neighbors_per_node) {
graph[i].insert(j);
graph[j as usize].insert(i as u32);
}
}
let graph: Vec<Vec<u32>> = graph.into_iter().map(|s| s.into_iter().collect()).collect();
// A single query suffices; only the traversal statistics are asserted.
let query = normalize(&random_vectors(1, dim, 8888).into_iter().next().unwrap());
// Entry point: globally nearest node, chosen without regard to the filter.
let entry_point = (0..n as u32)
.min_by(|&a, &b| {
let da = vicinity::distance::cosine_distance(&vectors[a as usize], &query);
let db = vicinity::distance::cosine_distance(&vectors[b as usize], &query);
da.partial_cmp(&db).unwrap()
})
.unwrap();
let filter = FnFilter(|id: u32| category_of(id));
// Two-hop expansion explicitly enabled.
let config = AcornConfig {
enable_two_hop: true,
two_hop_threshold: 0.5,
max_two_hop_neighbors: 16,
ef_search: 64,
};
let (_results, stats) = acorn_search_with_stats(
k,
&config,
&filter,
|id| graph[id as usize].clone(),
|id| vicinity::distance::cosine_distance(&vectors[id as usize], &query),
entry_point,
)
.expect("acorn_search_with_stats failed");
assert!(
stats.two_hop_invocations >= 1,
"ACORN 2-hop branch did not fire at ~2.5% selectivity: \
two_hop_invocations={}, two_hop_nodes_examined={}. \
likely cause: enable_two_hop wired wrong or the predicate-failing \
path was short-circuited.",
stats.two_hop_invocations,
stats.two_hop_nodes_examined,
);
assert!(
stats.two_hop_nodes_examined >= 1,
"ACORN 2-hop branch fired ({} times) but examined zero new nodes. \
likely cause: SearchState dedup is masking the branch's contribution.",
stats.two_hop_invocations,
);
}