use super::distance::CachedSimdDistance;
use super::dual_precision::{DualPrecisionConfig, DualPrecisionHnsw};
use crate::distance::DistanceMetric;
/// A freshly constructed index holds no vectors and has not trained its quantizer.
#[test]
fn test_create_dual_precision_hnsw() {
    let dim = 128;
    let hnsw = DualPrecisionHnsw::new(
        CachedSimdDistance::new(DistanceMetric::Euclidean, dim),
        dim,
        16,
        100,
        1000,
    )
    .expect("test");

    assert!(hnsw.is_empty());
    assert!(!hnsw.is_quantizer_trained());
}
/// Ten inserts stay below the training threshold, so the quantizer must remain untrained.
/// The threshold is presumably the last constructor argument (1000 here).
#[test]
fn test_insert_before_quantizer_training() {
    let dim = 32;
    let engine = CachedSimdDistance::new(DistanceMetric::Euclidean, dim);
    let mut hnsw = DualPrecisionHnsw::new(engine, dim, 16, 100, 1000).expect("test");

    for row in 0..10 {
        // Row `row` holds the consecutive integers row*dim .. row*dim + dim.
        let start = row * dim;
        let vector: Vec<f32> = (start..start + dim).map(|value| value as f32).collect();
        hnsw.insert(&vector).expect("test");
    }

    assert_eq!(hnsw.len(), 10);
    assert!(!hnsw.is_quantizer_trained(), "Should not train yet");
}
/// With a training threshold of 100 (last constructor argument), inserting exactly
/// 100 vectors must train the quantizer automatically.
#[test]
fn test_quantizer_trains_after_threshold() {
    let dim = 32;
    let engine = CachedSimdDistance::new(DistanceMetric::Euclidean, dim);
    let mut hnsw = DualPrecisionHnsw::new(engine, dim, 16, 100, 100).expect("test");

    // Smooth sine rows give the quantizer a non-degenerate value range to learn.
    let sine_row = |row: usize| -> Vec<f32> {
        (0..dim)
            .map(|col| ((row * dim + col) as f32 * 0.01).sin())
            .collect()
    };
    for row in 0..100 {
        hnsw.insert(&sine_row(row)).expect("test");
    }

    assert!(
        hnsw.is_quantizer_trained(),
        "Quantizer should be trained after threshold"
    );
}
/// `force_train_quantizer` trains the quantizer even though only 50 vectors were
/// inserted, well below the automatic threshold of 1000.
#[test]
fn test_force_train_quantizer() {
    let dim = 32;
    let mut hnsw = DualPrecisionHnsw::new(
        CachedSimdDistance::new(DistanceMetric::Euclidean, dim),
        dim,
        16,
        100,
        1000,
    )
    .expect("test");

    (0..50).for_each(|row| {
        let vector: Vec<f32> = (0..dim).map(|col| (row * dim + col) as f32).collect();
        hnsw.insert(&vector).expect("test");
    });

    // Not trained automatically yet ...
    assert!(!hnsw.is_quantizer_trained());
    // ... but trained once explicitly requested.
    hnsw.force_train_quantizer();
    assert!(hnsw.is_quantizer_trained());
}
/// Before any quantizer training, querying with an exact copy of row 0 must
/// return node 0 as the best match.
#[test]
fn test_search_before_quantizer_training() {
    let dim = 32;
    let engine = CachedSimdDistance::new(DistanceMetric::Euclidean, dim);
    let mut hnsw = DualPrecisionHnsw::new(engine, dim, 16, 100, 1000).expect("test");

    let ramp_row = |row: usize| -> Vec<f32> {
        (0..dim).map(|col| (row * dim + col) as f32).collect()
    };
    for row in 0..50 {
        hnsw.insert(&ramp_row(row)).expect("test");
    }

    // Row 0 is simply 0.0, 1.0, ..., 31.0.
    let results = hnsw.search(&ramp_row(0), 10, 50);
    assert!(!results.is_empty());
    assert_eq!(results[0].0, 0);
}
/// After forcing quantizer training, an exact copy of row 0 must still rank
/// node 0 first.
#[test]
fn test_search_after_quantizer_training() {
    let dim = 32;
    let engine = CachedSimdDistance::new(DistanceMetric::Euclidean, dim);
    let mut hnsw = DualPrecisionHnsw::new(engine, dim, 16, 100, 1000).expect("test");

    for row in 0..50 {
        let first = row * dim;
        let vector: Vec<f32> = (first..first + dim).map(|x| x as f32).collect();
        hnsw.insert(&vector).expect("test");
    }
    hnsw.force_train_quantizer();

    let query: Vec<f32> = (0..dim).map(|x| x as f32).collect();
    let results = hnsw.search(&query, 10, 50);
    assert!(!results.is_empty());
    assert_eq!(results[0].0, 0);
}
/// A 200-vector, 128-dim index searched after forced training must return at
/// least 5 neighbours, in non-decreasing distance order.
#[test]
fn test_dual_precision_recall() {
    let dim = 128;
    let engine = CachedSimdDistance::new(DistanceMetric::Euclidean, dim);
    let mut hnsw = DualPrecisionHnsw::new(engine, dim, 32, 200, 1000).expect("test");

    for row in 0..200 {
        let vector: Vec<f32> = (0..dim)
            .map(|col| ((row * dim + col) as f32 * 0.01).sin())
            .collect();
        hnsw.insert(&vector).expect("test");
    }
    hnsw.force_train_quantizer();

    // The query coincides with row 0.
    let query: Vec<f32> = (0..dim).map(|col| (col as f32 * 0.01).sin()).collect();
    let results = hnsw.search(&query, 10, 100);

    assert!(results.len() >= 5, "Should find at least 5 neighbors");
    for pair in results.windows(2) {
        assert!(
            pair[1].1 >= pair[0].1,
            "Results should be sorted by distance"
        );
    }
}
/// Inserting after forced training keeps growing the index, and vectors added
/// post-training (e.g. row 75) are reachable by search.
#[test]
fn test_insert_after_quantizer_training() {
    let dim = 32;
    let engine = CachedSimdDistance::new(DistanceMetric::Euclidean, dim);
    let mut hnsw = DualPrecisionHnsw::new(engine, dim, 16, 100, 1000).expect("test");

    let ramp_row = |row: usize| -> Vec<f32> {
        (0..dim).map(|col| (row * dim + col) as f32).collect()
    };

    // First half before training, second half after.
    for row in 0..50 {
        hnsw.insert(&ramp_row(row)).expect("test");
    }
    hnsw.force_train_quantizer();
    for row in 50..100 {
        hnsw.insert(&ramp_row(row)).expect("test");
    }

    assert_eq!(hnsw.len(), 100);
    let results = hnsw.search(&ramp_row(75), 5, 50);
    assert!(!results.is_empty());
}
/// After forced training, results must come back ordered by exact distance,
/// i.e. the reranking stage re-sorts the quantized candidates.
#[test]
fn test_quantized_reranking_uses_asymmetric_distance() {
    let dim = 64;
    let engine = CachedSimdDistance::new(DistanceMetric::Euclidean, dim);
    let mut hnsw = DualPrecisionHnsw::new(engine, dim, 16, 100, 500).expect("test");

    let sine_row = |row: usize| -> Vec<f32> {
        (0..dim)
            .map(|col| ((row * dim + col) as f32 * 0.01).sin())
            .collect()
    };
    for row in 0..200 {
        hnsw.insert(&sine_row(row)).expect("test");
    }

    hnsw.force_train_quantizer();
    assert!(hnsw.is_quantizer_trained());

    let results = hnsw.search(&sine_row(0), 10, 50);
    assert!(!results.is_empty());
    for pair in results.windows(2) {
        assert!(
            pair[1].1 >= pair[0].1,
            "Results must be sorted by exact distance after reranking"
        );
    }
}
/// Querying with an exact copy of vector 0 after forced training must still
/// surface node 0 among the top 10.
#[test]
fn test_quantized_reranking_maintains_recall() {
    let dim = 128;
    let engine = CachedSimdDistance::new(DistanceMetric::Euclidean, dim);
    let mut hnsw = DualPrecisionHnsw::new(engine, dim, 32, 200, 1000).expect("test");

    let mut vectors: Vec<Vec<f32>> = Vec::with_capacity(500);
    for row in 0..500 {
        vectors.push(
            (0..dim)
                .map(|col| ((row * dim + col) as f32 * 0.001).cos())
                .collect(),
        );
    }
    for vector in &vectors {
        hnsw.insert(vector).expect("test");
    }
    hnsw.force_train_quantizer();

    let results = hnsw.search(&vectors[0], 10, 100);
    let found_exact = results.iter().map(|(id, _)| *id).any(|id| id == 0);
    assert!(
        found_exact,
        "Quantized reranking should maintain high recall"
    );
}
/// Searching with int8 traversal enabled must still return a non-empty,
/// distance-sorted result list after exact reranking.
#[test]
fn test_search_with_int8_traversal_enabled() {
    let engine = CachedSimdDistance::new(DistanceMetric::Euclidean, 64);
    let mut hnsw = DualPrecisionHnsw::new(engine, 64, 16, 100, 500).expect("test");
    for i in 0..200 {
        let v: Vec<f32> = (0..64)
            .map(|j| ((i * 64 + j) as f32 * 0.01).sin())
            .collect();
        hnsw.insert(&v).expect("test");
    }
    hnsw.force_train_quantizer();
    let query: Vec<f32> = (0..64).map(|j| (j as f32 * 0.01).sin()).collect();
    // The default `min_index_size` (10_000) is far larger than the 200 vectors
    // indexed here. Lower it to 0, as `run_dual_precision_self_query` does, so the
    // int8 traversal path is presumably exercised rather than skipped by the
    // size gate. TODO(review): confirm against `search_with_config`.
    let config = DualPrecisionConfig {
        oversampling_ratio: 4,
        use_int8_traversal: true,
        min_index_size: 0,
        ..Default::default()
    };
    let results = hnsw.search_with_config(&query, 10, 50, &config);
    assert!(!results.is_empty());
    for pair in results.windows(2) {
        assert!(pair[1].1 >= pair[0].1, "Results should be sorted");
    }
}
/// Int8 traversal (with exact reranking) must recover at least 90% of the ids
/// returned by the plain f32 search for the same query.
#[test]
fn test_int8_traversal_recall_vs_f32() {
    let engine = CachedSimdDistance::new(DistanceMetric::Euclidean, 128);
    let mut hnsw = DualPrecisionHnsw::new(engine, 128, 32, 200, 1000).expect("test");
    let vectors: Vec<Vec<f32>> = (0..500)
        .map(|i| {
            (0..128)
                .map(|j| ((i * 128 + j) as f32 * 0.001).cos())
                .collect()
        })
        .collect();
    for v in &vectors {
        hnsw.insert(v).expect("test");
    }
    hnsw.force_train_quantizer();
    let query = &vectors[0];
    let f32_results = hnsw.search(query, 10, 100);
    // Without overriding `min_index_size` (default 10_000), a 500-vector index
    // presumably never takes the int8 path. Both searches would then be the
    // same f32 search, and the recall check below would be vacuous.
    let config = DualPrecisionConfig {
        oversampling_ratio: 4,
        use_int8_traversal: true,
        min_index_size: 0,
        ..Default::default()
    };
    let int8_results = hnsw.search_with_config(query, 10, 100, &config);
    let f32_ids: std::collections::HashSet<_> = f32_results.iter().map(|(id, _)| *id).collect();
    let int8_ids: std::collections::HashSet<_> = int8_results.iter().map(|(id, _)| *id).collect();
    let overlap = f32_ids.intersection(&int8_ids).count();
    let recall = overlap as f64 / f32_results.len().max(1) as f64;
    assert!(
        recall >= 0.90,
        "Int8 traversal recall should be >= 90%, got {:.2}%",
        recall * 100.0
    );
}
/// Pins the default values of [`DualPrecisionConfig`].
#[test]
fn test_dual_precision_config_defaults() {
    let DualPrecisionConfig {
        oversampling_ratio,
        use_int8_traversal,
        min_index_size,
        ..
    } = DualPrecisionConfig::default();

    assert_eq!(oversampling_ratio, 4);
    assert!(use_int8_traversal);
    assert_eq!(min_index_size, 10_000);
}
/// The exact rerank stage with the cached Euclidean engine must report true
/// Euclidean distance, not squared distance.
///
/// For the all-zeros and all-ones vectors in `dim` dimensions, the distance
/// is `sqrt(dim)`. A squared-distance bug would report `dim` instead.
#[test]
fn test_rerank_euclidean_returns_sqrt_not_squared_with_cached_engine() {
    // `CachedSimdDistance` is already imported at module scope.
    let dim = 32;
    let engine = CachedSimdDistance::new(DistanceMetric::Euclidean, dim);
    let mut hnsw = DualPrecisionHnsw::new(engine, dim, 16, 100, 1000).expect("test");
    let v0 = vec![0.0_f32; dim];
    let v1 = vec![1.0_f32; dim];
    hnsw.insert(&v0).expect("test");
    hnsw.insert(&v1).expect("test");
    hnsw.force_train_quantizer();
    let results = hnsw.search(&v0, 2, 50);
    assert!(
        results.len() >= 2,
        "Expected at least 2 results, got {}",
        results.len()
    );
    let v1_dist = results
        .iter()
        .find(|(id, _)| *id == 1)
        .map(|(_, d)| *d)
        .expect("v1 should be in results");
    let expected = (dim as f32).sqrt();
    let tolerance = 0.01;
    assert!(
        (v1_dist - expected).abs() < tolerance,
        "Distance to v1 should be sqrt({dim}) ~= {expected:.3}, got {v1_dist:.3} \
         (if ~{dim}.0, transform_score was not applied)"
    );
}
/// Deterministic pseudo-random unit vector of length `dim` derived from `seed`.
///
/// The seed is scrambled, then stepped through a 64-bit LCG (the well-known
/// MMIX multiplier/increment). The top 24 bits of each state are mapped into
/// `[-1, 1)`, and the vector is finally scaled to unit L2 norm. The same
/// `(seed, dim)` always yields the same vector.
pub(super) fn unit_vector(seed: u64, dim: usize) -> Vec<f32> {
    const LCG_MUL: u64 = 6_364_136_223_846_793_005;
    const LCG_INC: u64 = 1_442_695_040_888_963_407;
    // 2^23: dividing a 24-bit integer by it yields a value in [0, 2).
    const HALF_RANGE: f32 = 8_388_608.0;

    let mut state = seed.wrapping_mul(0x9E37_79B9_7F4A_7C15).wrapping_add(1);
    let mut components = Vec::with_capacity(dim);
    for _ in 0..dim {
        state = state.wrapping_mul(LCG_MUL).wrapping_add(LCG_INC);
        let top_bits = (state >> 40) as f32;
        components.push(top_bits / HALF_RANGE - 1.0);
    }

    let length = components.iter().map(|c| c * c).sum::<f32>().sqrt();
    components.iter_mut().for_each(|c| *c /= length);
    components
}
/// Builds `n` deterministic unit vectors of dimension `dim`, then overwrites
/// the last 10 slots with normalised, noise-perturbed copies of
/// `vectors[query_id]`.
///
/// Slot `s` uses noise scale `0.1 + 0.04 * s`, so later slots tend to sit
/// farther from the query. This gives a search a known cluster of true
/// near neighbours to find.
///
/// # Panics
/// Panics if `n < 10` or if `query_id` falls inside the planted tail
/// (`query_id >= n - 10`). In that case the query vector itself would be
/// overwritten.
pub(super) fn planted_unit_vectors(n: usize, dim: usize, query_id: usize) -> Vec<Vec<f32>> {
    // Number of near-duplicates planted at the tail of the set.
    const PLANTED: usize = 10;
    // A real `assert` (not `debug_assert!`) so release-mode test runs reject
    // overlap too. `checked_sub` avoids a bare overflow panic when `n < PLANTED`.
    let first_planted = match n.checked_sub(PLANTED) {
        Some(start) if query_id < start => start,
        _ => panic!(
            "query must not overlap planted slots (n = {n}, query_id = {query_id})"
        ),
    };
    let mut vectors: Vec<Vec<f32>> = (0..n).map(|i| unit_vector(i as u64 + 1, dim)).collect();
    for slot in 0..PLANTED {
        let noise = unit_vector(1_000 + slot as u64, dim);
        let eps = 0.1 + 0.04 * slot as f32;
        let mut v: Vec<f32> = vectors[query_id]
            .iter()
            .zip(noise.iter())
            .map(|(a, b)| a + eps * b)
            .collect();
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        for x in &mut v {
            *x /= norm;
        }
        vectors[first_planted + slot] = v;
    }
    vectors
}
/// Exact top-`k` ids for `query`, found by scoring every vector with `metric`.
///
/// Ids are ordered best-first. That is descending score when
/// `metric.higher_is_better()` and ascending otherwise. If `vectors` holds
/// fewer than `k` entries, all ids are returned.
pub(super) fn brute_force_top_ids(
    vectors: &[Vec<f32>],
    query: &[f32],
    metric: DistanceMetric,
    k: usize,
) -> Vec<usize> {
    let descending = metric.higher_is_better();

    let mut ranked: Vec<(usize, f32)> = Vec::with_capacity(vectors.len());
    for (id, candidate) in vectors.iter().enumerate() {
        ranked.push((id, metric.calculate(query, candidate)));
    }

    ranked.sort_unstable_by(|lhs, rhs| {
        let ascending = lhs.1.total_cmp(&rhs.1);
        if descending {
            ascending.reverse()
        } else {
            ascending
        }
    });

    ranked.into_iter().take(k).map(|(id, _)| id).collect()
}
/// Asserts the quality of a self-query's `results`:
///
/// 1. `query_id` is ranked first.
/// 2. Its score is the ideal self-score for `metric`. When higher scores are
///    better (e.g. cosine / dot product on unit vectors), it must exceed 0.99.
///    Otherwise the score is a distance, presumably e.g. Euclidean, and must
///    be below 0.01.
/// 3. The returned ids cover at least 95% of the brute-force top-`k`.
///
/// # Panics
/// Panics (test failure) if any check fails, or if `query_id` is out of
/// bounds for `vectors`.
pub(super) fn assert_top1_and_recall(
    results: &[(usize, f32)],
    vectors: &[Vec<f32>],
    query_id: usize,
    metric: DistanceMetric,
    k: usize,
) {
    assert!(!results.is_empty(), "search returned no results");
    let (top_id, top_score) = results[0];
    assert_eq!(
        top_id, query_id,
        "self-query must rank first for {metric:?}, got node {top_id} (score {top_score})"
    );
    // The ideal self-score depends on the metric's direction. A fixed `> 0.99`
    // only holds for similarity-style metrics.
    if metric.higher_is_better() {
        assert!(
            top_score > 0.99,
            "self-similarity must be maximal for {metric:?}, got {top_score}"
        );
    } else {
        assert!(
            top_score < 0.01,
            "self-distance must be minimal for {metric:?}, got {top_score}"
        );
    }
    let expected = brute_force_top_ids(vectors, &vectors[query_id], metric, k);
    let got: std::collections::HashSet<usize> = results.iter().map(|(id, _)| *id).collect();
    let overlap = expected.iter().filter(|id| got.contains(*id)).count();
    // Divide by the number of ids brute force could actually return. With
    // fewer than `k` vectors, dividing by `k` would make full recall unreachable.
    let recall = overlap as f64 / expected.len().max(1) as f64;
    assert!(
        recall >= 0.95,
        "recall@{k} vs brute-force must be >= 0.95 for {metric:?}, got {recall:.2}"
    );
}
/// Shared driver for the self-query tests.
///
/// Indexes 100 planted 32-dim unit vectors under `metric` and forces quantizer
/// training. It then searches for vector 42 and checks top-1 plus recall@10
/// against brute force.
///
/// When `use_int8_traversal` is true, the search goes through
/// `search_with_config` with int8 traversal requested. Otherwise it uses the
/// plain `search` entry point.
fn run_dual_precision_self_query(metric: DistanceMetric, use_int8_traversal: bool) {
    let (dim, n, k) = (32, 100, 10);
    let query_id = 42_usize;
    let engine = CachedSimdDistance::new(metric, dim);
    let mut hnsw = DualPrecisionHnsw::new(engine, dim, 16, 200, 1000).expect("test");
    let vectors = planted_unit_vectors(n, dim, query_id);
    for v in &vectors {
        hnsw.insert(v).expect("test");
    }
    hnsw.force_train_quantizer();
    assert!(hnsw.is_quantizer_trained());
    let query = &vectors[query_id];
    let results = if use_int8_traversal {
        // Request int8 traversal explicitly rather than inheriting it from
        // `Default`, so a change of default cannot silently turn this branch
        // into a plain f32 search. `min_index_size: 0` presumably keeps the
        // 100-vector index from being excluded by the size gate.
        let config = DualPrecisionConfig {
            use_int8_traversal: true,
            min_index_size: 0,
            ..Default::default()
        };
        hnsw.search_with_config(query, k, 100, &config)
    } else {
        hnsw.search(query, k, 100)
    };
    assert_top1_and_recall(&results, &vectors, query_id, metric, k);
}
/// Cosine self-query through the plain (non-int8) search path.
#[test]
fn test_dual_precision_cosine_rerank_keeps_best_candidates() {
    let int8 = false;
    run_dual_precision_self_query(DistanceMetric::Cosine, int8);
}
/// Dot-product self-query through the plain (non-int8) search path.
#[test]
fn test_dual_precision_dot_product_rerank_keeps_best_candidates() {
    let int8 = false;
    run_dual_precision_self_query(DistanceMetric::DotProduct, int8);
}
/// Cosine self-query through the int8-traversal search path.
#[test]
fn test_int8_traversal_cosine_rerank_keeps_best_candidates() {
    let int8 = true;
    run_dual_precision_self_query(DistanceMetric::Cosine, int8);
}
/// Dot-product self-query through the int8-traversal search path.
#[test]
fn test_int8_traversal_dot_product_rerank_keeps_best_candidates() {
    let int8 = true;
    run_dual_precision_self_query(DistanceMetric::DotProduct, int8);
}
/// The exact rerank stage with the cached cosine engine must report cosine
/// similarity scores.
///
/// `v0` is the uniform unit vector. `v1` is a unit vector supported on the
/// first half of the coordinates. Their cosine similarity is exactly
/// `(dim/2) * (1/sqrt(dim)) * (1/sqrt(dim/2)) = 1/sqrt(2)`.
#[test]
fn test_rerank_cosine_applies_transform_with_cached_engine() {
    // `CachedSimdDistance` is already imported at module scope.
    let dim = 32;
    let engine = CachedSimdDistance::new(DistanceMetric::Cosine, dim);
    let mut hnsw = DualPrecisionHnsw::new(engine, dim, 16, 100, 1000).expect("test");
    let norm = 1.0 / (dim as f32).sqrt();
    let v0: Vec<f32> = vec![norm; dim];
    let mut v1 = vec![0.0_f32; dim];
    let v1_norm = 1.0 / (dim as f32 / 2.0).sqrt();
    for slot in v1.iter_mut().take(dim / 2) {
        *slot = v1_norm;
    }
    hnsw.insert(&v0).expect("test");
    hnsw.insert(&v1).expect("test");
    hnsw.force_train_quantizer();
    let results = hnsw.search(&v0, 2, 50);
    assert!(!results.is_empty());
    for (id, score) in &results {
        assert!(
            *score >= 0.0 && *score <= 1.0,
            "Cosine score for node {id} should be in [0,1], got {score}"
        );
    }
    // The range check alone cannot tell similarity (~0.707) from, e.g., a raw
    // `1 - cos` value (~0.293), which also lies in [0,1]. Pin v1's exact score.
    let v1_score = results
        .iter()
        .find(|(id, _)| *id == 1)
        .map(|(_, s)| *s)
        .expect("v1 should be in results");
    let expected = std::f32::consts::FRAC_1_SQRT_2;
    let tolerance = 0.01;
    assert!(
        (v1_score - expected).abs() < tolerance,
        "Cosine similarity to v1 should be 1/sqrt(2) ~= {expected:.3}, got {v1_score:.3}"
    );
}