#![cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
use hive_gpu::cuda::{CudaContext, CudaIvfIndex};
use hive_gpu::types::{GpuDistanceMetric, GpuVector, IvfConfig};
use std::collections::HashMap;
use std::sync::Arc;
/// Reports whether the current test should bail out because no CUDA device
/// is present. When it returns `true` it has already logged a notice to
/// stderr, so callers can simply `return`.
fn skip_if_no_gpu() -> bool {
    let unavailable = !CudaContext::is_available();
    if unavailable {
        eprintln!("[cuda_ivf] no CUDA device detected; test is a no-op");
    }
    unavailable
}
/// Tiny deterministic PRNG (a 64-bit linear congruential generator) so the
/// synthetic test data is reproducible without depending on an RNG crate.
struct SeededRng(u64);

impl SeededRng {
    /// LCG multiplier.
    const MUL: u64 = 6_364_136_223_846_793_005;
    /// LCG increment.
    const INC: u64 = 1_442_695_040_888_963_407;

    /// Creates a generator whose stream is fully determined by `seed`.
    fn new(seed: u64) -> Self {
        SeededRng(seed)
    }

    /// Advances the state and maps its top 53 bits to a float in [-1, 1].
    ///
    /// The upper bound is reachable because the 53-bit integer is rounded to
    /// `f32` precision before the division.
    fn next_f32(&mut self) -> f32 {
        let state = self.0.wrapping_mul(Self::MUL).wrapping_add(Self::INC);
        self.0 = state;
        let unit = (state >> 11) as f32 / (1u64 << 53) as f32;
        unit * 2.0 - 1.0
    }

    /// Draws `dim` consecutive values from the stream.
    fn gen_vec(&mut self, dim: usize) -> Vec<f32> {
        let mut values = Vec::with_capacity(dim);
        for _ in 0..dim {
            values.push(self.next_f32());
        }
        values
    }
}
/// Wraps `data` in a [`GpuVector`] with the given `id` and an empty metadata map.
fn make_vec(id: &str, data: Vec<f32>) -> GpuVector {
    let metadata = HashMap::new();
    GpuVector {
        id: String::from(id),
        data,
        metadata,
    }
}
/// Generates `n_centers * n_per_center` vectors grouped around `n_centers` centers.
///
/// Each center has `dim` components drawn from [`SeededRng`] and scaled by 5.0.
/// Every member of a cluster is its center plus per-component noise in
/// [-0.2, 0.2]. Vectors are emitted cluster by cluster with ids `v0`, `v1`, ...,
/// so ids `v{c * n_per_center}` through `v{(c + 1) * n_per_center - 1}`
/// belong to center `c`.
///
/// The centers are the first draws from `SeededRng::new(seed)`, so a caller
/// can rebuild them by replaying the same seed.
fn clustered_dataset(
    n_centers: usize,
    n_per_center: usize,
    dim: usize,
    seed: u64,
) -> Vec<GpuVector> {
    let mut rng = SeededRng::new(seed);
    let centers: Vec<Vec<f32>> = (0..n_centers)
        .map(|_| rng.gen_vec(dim).into_iter().map(|x| x * 5.0).collect())
        .collect();
    let mut out = Vec::with_capacity(n_centers * n_per_center);
    for center in &centers {
        for _ in 0..n_per_center {
            // The noise draws happen in component order, matching one
            // `next_f32` call per dimension.
            let data: Vec<f32> = center
                .iter()
                .map(|&c| c + rng.next_f32() * 0.2)
                .collect();
            // Sequential ids: the running length is the global vector index.
            let id = format!("v{}", out.len());
            out.push(make_vec(&id, data));
        }
    }
    out
}
/// Exact brute-force top-`k` by inner product.
///
/// Returns the ids of the `k` vectors with the largest `v · query`, best first.
/// Assumes every vector has the same length as `query`; `zip` silently
/// truncates to the shorter one otherwise.
fn cpu_top_k_dot(vectors: &[GpuVector], query: &[f32], k: usize) -> Vec<String> {
    let scored: Vec<(&str, f32)> = vectors
        .iter()
        .map(|v| {
            let dot = v.data.iter().zip(query).map(|(a, b)| a * b).sum::<f32>();
            (v.id.as_str(), dot)
        })
        .collect();
    top_k_ids(scored, k, true)
}

/// Exact brute-force top-`k` by squared Euclidean distance.
///
/// Returns the ids of the `k` vectors nearest to `query`, nearest first.
/// Squared distance ranks identically to true L2 distance and skips the
/// square root. The same length assumption as [`cpu_top_k_dot`] applies.
fn cpu_top_k_l2(vectors: &[GpuVector], query: &[f32], k: usize) -> Vec<String> {
    let scored: Vec<(&str, f32)> = vectors
        .iter()
        .map(|v| {
            let d = v
                .data
                .iter()
                .zip(query)
                .map(|(a, b)| (a - b).powi(2))
                .sum::<f32>();
            (v.id.as_str(), d)
        })
        .collect();
    top_k_ids(scored, k, false)
}

/// Sorts `(id, score)` pairs and returns the first `k` ids as owned strings.
///
/// `descending = true` puts the largest score first. Otherwise the smallest
/// comes first. `f32::total_cmp` is a genuine total order, including NaN, so the
/// sort never receives an inconsistent comparator. `partial_cmp` with an
/// `Equal` fallback is not total when NaN is present, and `sort_by` may panic
/// on such comparators. Ties keep input order because the sort is stable.
fn top_k_ids(mut scored: Vec<(&str, f32)>, k: usize, descending: bool) -> Vec<String> {
    if descending {
        scored.sort_by(|a, b| b.1.total_cmp(&a.1));
    } else {
        scored.sort_by(|a, b| a.1.total_cmp(&b.1));
    }
    scored
        .into_iter()
        .take(k)
        .map(|(id, _)| id.to_owned())
        .collect()
}
/// Fraction of the distinct ground-truth ids that appear in `returned`.
///
/// Each id in `returned` counts at most once. Without this, an index that
/// repeats one correct hit `k` times would score a perfect recall. An empty
/// `ground_truth` yields `1.0` (nothing to retrieve) rather than NaN from a
/// 0/0 division.
fn recall(returned: &[String], ground_truth: &[String]) -> f32 {
    use std::collections::HashSet;

    let gt_set: HashSet<&str> = ground_truth.iter().map(String::as_str).collect();
    if gt_set.is_empty() {
        return 1.0;
    }
    let returned_set: HashSet<&str> = returned.iter().map(String::as_str).collect();
    let hits = returned_set.intersection(&gt_set).count();
    hits as f32 / gt_set.len() as f32
}
#[test]
fn new_rejects_bad_config() {
    if skip_if_no_gpu() {
        return;
    }
    let ctx = Arc::new(CudaContext::new().unwrap());

    // A zero dimension is rejected, and the message mentions the dimension.
    let msg = CudaIvfIndex::new(
        Arc::clone(&ctx),
        0,
        GpuDistanceMetric::DotProduct,
        IvfConfig::default(),
    )
    .expect_err("dim=0 must fail")
    .to_string();
    assert!(msg.to_lowercase().contains("dimension"));

    // n_list = 0 is rejected, and the message names the field.
    let no_lists = IvfConfig {
        n_list: 0,
        ..IvfConfig::default()
    };
    let msg = CudaIvfIndex::new(Arc::clone(&ctx), 16, GpuDistanceMetric::DotProduct, no_lists)
        .expect_err("n_list=0 must fail")
        .to_string();
    assert!(msg.contains("n_list"));

    // nprobe larger than n_list is rejected, and the message names the field.
    let over_probe = IvfConfig {
        n_list: 8,
        nprobe: 16,
        ..IvfConfig::default()
    };
    let msg = CudaIvfIndex::new(ctx, 16, GpuDistanceMetric::DotProduct, over_probe)
        .expect_err("nprobe > n_list must fail")
        .to_string();
    assert!(msg.contains("nprobe"));
}
#[test]
fn build_rejects_empty_input() {
    if skip_if_no_gpu() {
        return;
    }
    let ctx = Arc::new(CudaContext::new().unwrap());
    let config = IvfConfig {
        n_list: 4,
        nprobe: 2,
        ..IvfConfig::default()
    };
    let mut idx = CudaIvfIndex::new(ctx, 8, GpuDistanceMetric::DotProduct, config).unwrap();

    // Building from an empty slice must return an error.
    let outcome = idx.build(&[]);
    assert!(outcome.is_err());
}
#[test]
fn build_rejects_too_few_vectors() {
    if skip_if_no_gpu() {
        return;
    }
    let ctx = Arc::new(CudaContext::new().unwrap());
    let config = IvfConfig {
        n_list: 16,
        nprobe: 4,
        ..IvfConfig::default()
    };
    let mut idx = CudaIvfIndex::new(ctx, 8, GpuDistanceMetric::DotProduct, config).unwrap();

    // Only 8 vectors for 16 lists: fewer vectors than lists must be an error.
    let mut vectors = Vec::with_capacity(8);
    for i in 0..8 {
        vectors.push(make_vec(&format!("v{i}"), vec![i as f32; 8]));
    }
    assert!(idx.build(&vectors).is_err());
}
#[test]
fn build_populates_all_clusters_balanced_on_synthetic_data() {
    if skip_if_no_gpu() {
        return;
    }
    let ctx = Arc::new(CudaContext::new().unwrap());
    let n_centers = 8;
    let per_center = 50;
    let dim = 16;
    // Shared between dataset generation and center reconstruction below.
    let seed = 42;
    let vectors = clustered_dataset(n_centers, per_center, dim, seed);
    let mut idx = CudaIvfIndex::new(
        ctx,
        dim,
        GpuDistanceMetric::DotProduct,
        IvfConfig {
            n_list: n_centers,
            nprobe: 1,
            training_sample_size: vectors.len(),
            kmeans_iters: 20,
            seed: Some(7),
        },
    )
    .unwrap();
    idx.build(&vectors).unwrap();
    assert!(idx.is_trained());
    assert_eq!(idx.vector_count(), n_centers * per_center);

    // Replay the dataset's RNG stream. `clustered_dataset` draws its centers
    // first from `SeededRng::new(seed)`, each scaled by 5.0, so this rebuilds
    // the exact cluster centers.
    let mut rng = SeededRng::new(seed);
    let centers: Vec<Vec<f32>> = (0..n_centers)
        .map(|_| rng.gen_vec(dim).into_iter().map(|x| x * 5.0).collect())
        .collect();

    // With nprobe = 1 (one probed list per query, assuming standard IVF
    // semantics), a full page of 10 results indicates the probed list
    // holds at least 10 vectors.
    for center in &centers {
        let results = idx.search(center, 10).unwrap();
        assert_eq!(results.len(), 10);
    }
}
#[test]
fn set_nprobe_validates_and_applies() {
    if skip_if_no_gpu() {
        return;
    }
    let ctx = Arc::new(CudaContext::new().unwrap());
    let (n_centers, dim) = (8, 8);
    let vectors = clustered_dataset(n_centers, 20, dim, 1);
    let config = IvfConfig {
        n_list: n_centers,
        nprobe: 2,
        training_sample_size: vectors.len(),
        kmeans_iters: 10,
        seed: Some(3),
    };
    let mut idx = CudaIvfIndex::new(ctx, dim, GpuDistanceMetric::DotProduct, config).unwrap();
    idx.build(&vectors).unwrap();

    // The configured value is reported, and a valid update takes effect.
    assert_eq!(idx.nprobe(), 2);
    idx.set_nprobe(4).unwrap();
    assert_eq!(idx.nprobe(), 4);

    // Zero and anything above n_list are rejected.
    for bad in [0, n_centers + 1] {
        assert!(idx.set_nprobe(bad).is_err());
    }
}
#[test]
fn recall_at_10_against_bruteforce_dotproduct() {
    if skip_if_no_gpu() {
        return;
    }
    let ctx = Arc::new(CudaContext::new().unwrap());
    let n = 10_000;
    let dim = 64;
    let mut rng = SeededRng::new(123);
    let vectors: Vec<GpuVector> = (0..n)
        .map(|i| make_vec(&format!("v{i}"), rng.gen_vec(dim)))
        .collect();
    let n_list = 128;
    let mut idx = CudaIvfIndex::new(
        ctx,
        dim,
        GpuDistanceMetric::DotProduct,
        IvfConfig {
            n_list,
            nprobe: n_list / 4,
            training_sample_size: 2_048,
            kmeans_iters: 15,
            seed: Some(7),
        },
    )
    .unwrap();
    idx.build(&vectors).unwrap();

    // Queries continue the dataset's RNG stream, so the whole test is
    // deterministic. Each query is scored against the exact CPU top-10.
    let n_queries = 50;
    let mut total_recall = 0.0f32;
    for _ in 0..n_queries {
        let q = rng.gen_vec(dim);
        let gt = cpu_top_k_dot(&vectors, &q, 10);
        let got_ids: Vec<String> = idx
            .search(&q, 10)
            .unwrap()
            .into_iter()
            .map(|r| r.id)
            .collect();
        total_recall += recall(&got_ids, &gt);
    }
    let mean_recall = total_recall / n_queries as f32;
    assert!(
        mean_recall >= 0.70,
        "recall@10 was {mean_recall:.3}, expected >= 0.70 at nprobe=n_list/4"
    );
    eprintln!("[cuda_ivf] recall@10 = {mean_recall:.3} at nprobe = n_list/4");
}
#[test]
fn recall_at_10_against_bruteforce_euclidean() {
    if skip_if_no_gpu() {
        return;
    }
    let ctx = Arc::new(CudaContext::new().unwrap());
    let n = 10_000;
    let dim = 64;
    let mut rng = SeededRng::new(456);
    let vectors: Vec<GpuVector> = (0..n)
        .map(|i| make_vec(&format!("v{i}"), rng.gen_vec(dim)))
        .collect();
    let n_list = 128;
    let mut idx = CudaIvfIndex::new(
        ctx,
        dim,
        GpuDistanceMetric::Euclidean,
        IvfConfig {
            n_list,
            nprobe: n_list / 4,
            training_sample_size: 2_048,
            kmeans_iters: 15,
            seed: Some(11),
        },
    )
    .unwrap();
    idx.build(&vectors).unwrap();

    // Queries continue the dataset's RNG stream, so the whole test is
    // deterministic. Each query is scored against the exact CPU L2 top-10.
    let n_queries = 50;
    let mut total_recall = 0.0f32;
    for _ in 0..n_queries {
        let q = rng.gen_vec(dim);
        let gt = cpu_top_k_l2(&vectors, &q, 10);
        let got_ids: Vec<String> = idx
            .search(&q, 10)
            .unwrap()
            .into_iter()
            .map(|r| r.id)
            .collect();
        total_recall += recall(&got_ids, &gt);
    }
    let mean_recall = total_recall / n_queries as f32;
    assert!(
        mean_recall >= 0.75,
        "recall@10 for L2 was {mean_recall:.3}, expected >= 0.75 at nprobe=n_list/4"
    );
    eprintln!("[cuda_ivf] L2 recall@10 = {mean_recall:.3}");
}
#[test]
fn higher_nprobe_increases_recall() {
    if skip_if_no_gpu() {
        return;
    }
    let ctx = Arc::new(CudaContext::new().unwrap());
    let n = 5_000;
    let dim = 32;
    let mut rng = SeededRng::new(777);
    let vectors: Vec<GpuVector> = (0..n)
        .map(|i| make_vec(&format!("v{i}"), rng.gen_vec(dim)))
        .collect();
    let n_list = 64;
    let mut idx = CudaIvfIndex::new(
        ctx,
        dim,
        GpuDistanceMetric::DotProduct,
        IvfConfig {
            n_list,
            nprobe: 2,
            training_sample_size: 1_024,
            kmeans_iters: 12,
            seed: Some(2024),
        },
    )
    .unwrap();
    idx.build(&vectors).unwrap();

    // One query, scored against the exact CPU top-10 at two nprobe settings.
    let query = rng.gen_vec(dim);
    let gt = cpu_top_k_dot(&vectors, &query, 10);

    let ids_low: Vec<String> = idx
        .search(&query, 10)
        .unwrap()
        .into_iter()
        .map(|r| r.id)
        .collect();
    let low_recall = recall(&ids_low, &gt);

    // Probing every list should reduce to an exhaustive scan.
    idx.set_nprobe(n_list).unwrap();
    let ids_full: Vec<String> = idx
        .search(&query, 10)
        .unwrap()
        .into_iter()
        .map(|r| r.id)
        .collect();
    let full_recall = recall(&ids_full, &gt);

    assert!(
        full_recall >= low_recall,
        "full-scan recall {full_recall:.3} should beat nprobe=2 recall {low_recall:.3}"
    );
    assert!(
        full_recall >= 0.95,
        "nprobe = n_list should yield near-perfect recall, got {full_recall:.3}"
    );
}