#![allow(clippy::unwrap_used, clippy::expect_used)]
#![cfg(feature = "hnsw")]
#![cfg(feature = "hnsw")]
#![allow(clippy::float_cmp)]
#[path = "common/mod.rs"]
mod common;
use common::*;
/// Cosine distance must not depend on the query's magnitude.
///
/// The ordering check (`c` closer than `b`) alone would also pass for a naive
/// `1 - dot(q, x)` that silently assumes unit-length inputs. The extra
/// assertions compare the scaled query `[2, 0, 0]` against its unit-length
/// counterpart `a = [1, 0, 0]`, and these do detect a missing normalization.
#[test]
fn cosine_distance_handles_unnormalized_query() {
    use vicinity::distance::cosine_distance;

    let a = [1.0f32, 0.0, 0.0];
    let b = [0.0f32, 1.0, 0.0];
    let c = [0.707f32, 0.707, 0.0];
    // Same direction as `a`, but twice the length.
    let query = [2.0f32, 0.0, 0.0];

    let dist_query_b = cosine_distance(&query, &b);
    let dist_query_c = cosine_distance(&query, &c);
    assert!(
        dist_query_c < dist_query_b,
        "c ({:.4}) should be closer to query than b ({:.4})",
        dist_query_c,
        dist_query_b
    );

    // `query` and `a` point the same way, so their cosine distance should be ~0.
    // A naive `1 - dot` would report -1 here.
    let dist_query_a = cosine_distance(&query, &a);
    assert!(
        dist_query_a.abs() < 1e-3,
        "query is parallel to a, so distance should be ~0, got {:.4}",
        dist_query_a
    );

    // Scaling the query must not change its distance to any other vector.
    let dist_a_c = cosine_distance(&a, &c);
    assert!(
        (dist_query_c - dist_a_c).abs() < 1e-3,
        "distance should be scale-invariant: d(2a, c) = {:.4}, d(a, c) = {:.4}",
        dist_query_c,
        dist_a_c
    );
}
/// Smoke test for an index built with distinct `m` and `m_max` values.
///
/// The name refers to layer-0 connectivity, but the test only checks that a
/// top-20 search over 100 inserted vectors yields at least `m` hits.
#[test]
#[cfg(feature = "hnsw")]
fn layer0_uses_correct_connectivity() {
    use vicinity::hnsw::HNSWIndex;

    let dim = 32;
    let (m, m_max) = (8, 16);
    let mut index = HNSWIndex::new(dim, m, m_max).unwrap();

    // 100 deterministic sinusoidal vectors, each passed through `normalize`
    // before insertion.
    for id in 0..100u32 {
        let base = id as usize * dim;
        let raw: Vec<f32> = (0..dim)
            .map(|d| ((base + d) as f32 * 0.1).sin())
            .collect();
        index.add(id, normalize(&raw)).unwrap();
    }
    index.build().unwrap();

    let probe = normalize(&vec![1.0; dim]);
    let hits = index.search(&probe, 20, 100).unwrap();
    assert!(hits.len() >= m, "Should return at least M results");
}
/// Searching with an indexed vector must return that vector among the results.
///
/// If the search descended the layers in the wrong order, it could end in the
/// wrong region of the graph and miss the query's own entry.
#[test]
#[cfg(feature = "hnsw")]
fn layer_traversal_is_top_to_bottom() {
    use vicinity::hnsw::HNSWIndex;

    const TARGET: u32 = 250;

    let dim = 64;
    let n = 500;
    let vectors: Vec<Vec<f32>> = (0..n)
        .map(|seed| normalize(&random_vec(dim, seed)))
        .collect();

    let mut index = HNSWIndex::new(dim, 16, 32).unwrap();
    for (id, v) in (0u32..).zip(&vectors) {
        index.add_slice(id, v).unwrap();
    }
    index.build().unwrap();

    // Use a generous beam width (ef = n) so that a miss points at traversal
    // order rather than a too-small ef.
    let query = &vectors[TARGET as usize];
    let results = index.search(query, 50, n).unwrap();
    let found_self = results.iter().any(|&(id, _)| id == TARGET);
    assert!(
        found_self,
        "Should find the query vector itself - layer traversal may be wrong"
    );
}
/// Raising the search beam width (ef) from 50 to 200 must not lower recall by
/// more than 5% relative.
///
/// Recall is measured against `brute_force_knn` on a single indexed query.
#[test]
#[cfg(feature = "hnsw")]
fn stopping_condition_allows_sufficient_exploration() {
    use std::collections::HashSet;
    use vicinity::hnsw::HNSWIndex;

    let (dim, n, k) = (64, 500, 10);
    let vectors: Vec<Vec<f32>> = (0..n).map(|i| normalize(&random_vec(dim, i))).collect();

    let mut index = HNSWIndex::new(dim, 16, 32).unwrap();
    for (id, v) in vectors.iter().enumerate() {
        index.add(id as u32, v.clone()).unwrap();
    }
    index.build().unwrap();

    let query = &vectors[250];
    let truth: HashSet<u32> = brute_force_knn(query, &vectors, k).into_iter().collect();

    // Fraction of the exact top-k returned at each ef, evaluated in order 50, then 200.
    let [recall_ef50, recall_ef200] = [50, 200].map(|ef| {
        let found: HashSet<u32> = index
            .search(query, k, ef)
            .unwrap()
            .iter()
            .map(|(id, _)| *id)
            .collect();
        truth.intersection(&found).count() as f32 / k as f32
    });

    assert!(
        recall_ef200 >= recall_ef50 * 0.95,
        "Higher ef should not decrease recall: ef50={:.2}%, ef200={:.2}%",
        recall_ef50 * 100.0,
        recall_ef200 * 100.0
    );
}
/// Average recall@10 (ef = 100) over 10 indexed queries must exceed 70%.
///
/// The queries are spread through the dataset with a stride of 37.
#[test]
#[cfg(feature = "hnsw")]
fn recall_is_measurable() {
    use std::collections::HashSet;
    use vicinity::hnsw::HNSWIndex;

    let (dim, n, k) = (64, 500, 10);
    let vectors: Vec<Vec<f32>> = (0..n).map(|i| normalize(&random_vec(dim, i))).collect();

    let mut index = HNSWIndex::new(dim, 16, 32).unwrap();
    for (id, v) in vectors.iter().enumerate() {
        index.add(id as u32, v.clone()).unwrap();
    }
    index.build().unwrap();

    let n_queries = 10;
    let total_recall: f32 = (0..n_queries)
        .map(|q| {
            let query = &vectors[(q * 37) % n];
            let truth: HashSet<u32> = brute_force_knn(query, &vectors, k).into_iter().collect();
            let approx: HashSet<u32> = index
                .search(query, k, 100)
                .unwrap()
                .iter()
                .map(|(id, _)| *id)
                .collect();
            truth.intersection(&approx).count() as f32 / k as f32
        })
        .sum();

    let avg_recall = total_recall / n_queries as f32;
    assert!(
        avg_recall > 0.7,
        "Average recall ({:.2}%) is too low - possible configuration issue",
        avg_recall * 100.0
    );
}
/// Deterministic stand-in for a random vector.
///
/// Component `i` is `sin((31 * seed + 17 * i) * 0.001)`, and the whole vector
/// is passed through `normalize`. The fixtures are reproducible, but adjacent
/// seeds differ only by a 0.031-radian phase shift, so they are strongly
/// correlated rather than independent.
fn random_vec(dim: usize, seed: usize) -> Vec<f32> {
    let offset = seed * 31;
    let mut raw = Vec::with_capacity(dim);
    for i in 0..dim {
        let phase = (offset + i * 17) as f32 * 0.001;
        raw.push(phase.sin());
    }
    normalize(&raw)
}