use aletheiadb::core::id::NodeId;
use aletheiadb::index::vector::{DistanceMetric, HnswIndexBuilder, Quantization, VectorIndex};
use std::collections::HashSet;
/// Builds `count` deterministic vectors, each holding `dims` components.
///
/// Component `j` of vector `i` is `((i * 17 + j * 31) % 1000) / 1000`, so every
/// value lies in `[0.0, 1.0)` and the same arguments always give the same data.
fn generate_vectors(count: usize, dims: usize) -> Vec<Vec<f32>> {
    let mut vectors = Vec::with_capacity(count);
    for row in 0..count {
        let mut components = Vec::with_capacity(dims);
        for col in 0..dims {
            // Fold the linear mix of both indices into 0..1000 before scaling.
            let bucket = (row * 17 + col * 31) % 1000;
            components.push(bucket as f32 / 1000.0);
        }
        vectors.push(components);
    }
    vectors
}
/// Returns the fraction of distinct ids in `baseline` that also occur in `test`.
///
/// Only the id half of each pair is compared; the `f32` half is ignored.
/// An empty `baseline` yields `1.0`.
fn calculate_recall(baseline: &[(NodeId, f32)], test: &[(NodeId, f32)]) -> f64 {
    let expected: HashSet<_> = baseline.iter().map(|&(id, _)| id).collect();
    if expected.is_empty() {
        // Nothing to find, so nothing was missed.
        return 1.0;
    }

    let found: HashSet<_> = test.iter().map(|&(id, _)| id).collect();
    let hits = expected.iter().filter(|id| found.contains(*id)).count();
    hits as f64 / expected.len() as f64
}
/// Compares an index built with `Quantization::F16` against one built without a
/// quantization setting, over the same 1000 generated 128-component vectors.
///
/// Passes when the recall of the F16 results relative to the baseline results,
/// averaged over 10 queries, is at least 0.80 and both indexes report non-zero
/// memory usage.
#[test]
fn test_f16_quantization_recall() {
    let dims = 128;
    let dataset = generate_vectors(1000, dims);

    // Baseline: no quantization requested on the builder.
    let baseline_index = HnswIndexBuilder::new(dims, DistanceMetric::Cosine)
        .ef_construction(200)
        .ef_search(100)
        .build()
        .unwrap();

    // Quantized index; note it searches with ef_search 200 versus 100 for the baseline.
    let quantized_index = HnswIndexBuilder::new(dims, DistanceMetric::Cosine)
        .ef_construction(200)
        .ef_search(200)
        .quantization(Quantization::F16)
        .build()
        .unwrap();

    // Insert identical data into both indexes; ids start at 1.
    for (raw_id, embedding) in (1u64..).zip(dataset.iter()) {
        let node = NodeId::new(raw_id).unwrap();
        baseline_index.add(node, embedding).unwrap();
        quantized_index.add(node, embedding).unwrap();
    }

    // Query with every 100th stored vector (positions 0, 100, ..., 900).
    let num_queries = 10;
    let total_recall = dataset
        .iter()
        .step_by(100)
        .take(num_queries)
        .fold(0.0_f64, |acc, query| {
            let expected = baseline_index.search(query, 10).unwrap();
            let actual = quantized_index.search(query, 10).unwrap();
            acc + calculate_recall(&expected, &actual)
        });

    let avg_recall = total_recall / num_queries as f64;
    assert!(
        avg_recall >= 0.80,
        "F16 recall {:.2}% is below 80% threshold",
        avg_recall * 100.0
    );

    let baseline_bytes = baseline_index.memory_usage();
    let quantized_bytes = quantized_index.memory_usage();
    assert!(baseline_bytes > 0, "F32 index should report non-zero memory");
    assert!(quantized_bytes > 0, "F16 index should report non-zero memory");
}
/// Compares an index built with `Quantization::I8` against one built without a
/// quantization setting, over the same 1000 generated 128-component vectors.
///
/// Passes when the recall of the I8 results relative to the baseline results,
/// averaged over 10 queries, is at least 0.80.
#[test]
fn test_i8_quantization_recall() {
    let dims = 128;
    let dataset = generate_vectors(1000, dims);

    // Baseline: no quantization requested on the builder.
    let baseline_index = HnswIndexBuilder::new(dims, DistanceMetric::Cosine)
        .ef_construction(200)
        .ef_search(100)
        .build()
        .unwrap();

    // Same build and search parameters, plus I8 quantization.
    let quantized_index = HnswIndexBuilder::new(dims, DistanceMetric::Cosine)
        .ef_construction(200)
        .ef_search(100)
        .quantization(Quantization::I8)
        .build()
        .unwrap();

    // Insert identical data into both indexes; ids start at 1.
    for (raw_id, embedding) in (1u64..).zip(dataset.iter()) {
        let node = NodeId::new(raw_id).unwrap();
        baseline_index.add(node, embedding).unwrap();
        quantized_index.add(node, embedding).unwrap();
    }

    // Query with every 100th stored vector (positions 0, 100, ..., 900).
    let num_queries = 10;
    let total_recall = dataset
        .iter()
        .step_by(100)
        .take(num_queries)
        .fold(0.0_f64, |acc, query| {
            let expected = baseline_index.search(query, 10).unwrap();
            let actual = quantized_index.search(query, 10).unwrap();
            acc + calculate_recall(&expected, &actual)
        });

    let avg_recall = total_recall / num_queries as f64;
    assert!(
        avg_recall >= 0.80,
        "I8 recall {:.2}% is below 80% threshold",
        avg_recall * 100.0
    );
}
/// Verifies that the quantization mode passed to the builder is the one the
/// built index reports back, for both `F16` and `I8`.
#[test]
fn test_quantization_preserved() {
    // Both cases share the same builder recipe; only the quantization mode varies.
    let build_with = |mode: Quantization| {
        HnswIndexBuilder::new(64, DistanceMetric::Cosine)
            .quantization(mode)
            .build()
            .unwrap()
    };

    let half_precision_index = build_with(Quantization::F16);
    assert_eq!(half_precision_index.quantization(), Quantization::F16);

    let int8_index = build_with(Quantization::I8);
    assert_eq!(int8_index.quantization(), Quantization::I8);
}