use precinct::{AxisBox, Region, RegionIndex, SearchParams};
use rand::Rng;
use std::time::Instant;
/// Parameters for a single benchmark scenario.
struct Config {
    /// Human-readable scenario name, printed in the result header.
    label: &'static str,
    /// Number of boxes inserted into the index.
    n: usize,
    /// Dimensionality of every box and query point.
    dim: usize,
    /// Number of random query points evaluated.
    n_queries: usize,
    /// Number of nearest results requested per query (the `k` in recall@k).
    k: usize,
    /// Mean per-axis half-width of the generated boxes.
    width_mean: f32,
    /// Half-range of the uniform jitter added to `width_mean` when generating
    /// half-widths (the generator multiplies it by a draw from `[-1, 1)`), so it
    /// is not a Gaussian standard deviation despite the name.
    width_std: f32,
}
/// Generates `cfg.n` random axis-aligned boxes in `cfg.dim` dimensions.
///
/// Each box centre is drawn uniformly from `[-1, 1)^dim`. Each per-axis
/// half-width is `width_mean + u * width_std`, with `u` uniform in `[-1, 1)`,
/// floored at `0.01` so half-widths stay strictly positive. `width_std`
/// therefore acts as the half-range of a uniform jitter, not a Gaussian
/// standard deviation.
fn generate_boxes(cfg: &Config, rng: &mut impl Rng) -> Vec<AxisBox> {
    let mut unit = || -> f32 { rng.random_range(-1.0..1.0) };
    let mut boxes = Vec::with_capacity(cfg.n);
    for _ in 0..cfg.n {
        // All centre coordinates are drawn before any half-width, per box.
        let center: Vec<f32> = (0..cfg.dim).map(|_| unit()).collect();
        let half_widths: Vec<f32> = (0..cfg.dim)
            .map(|_| (cfg.width_mean + unit() * cfg.width_std).max(0.01))
            .collect();
        boxes.push(AxisBox::from_center_offset(center, half_widths));
    }
    boxes
}
/// Draws `cfg.n_queries` query points, each with `cfg.dim` coordinates
/// sampled uniformly from `[-1, 1)`.
fn generate_queries(cfg: &Config, rng: &mut impl Rng) -> Vec<Vec<f32>> {
    let mut queries = Vec::with_capacity(cfg.n_queries);
    for _ in 0..cfg.n_queries {
        let point: Vec<f32> = (0..cfg.dim).map(|_| rng.random_range(-1.0..1.0)).collect();
        queries.push(point);
    }
    queries
}
/// Recall@k: the fraction of the exact top-`k` ids that also appear in the
/// approximate top-`k`.
///
/// Only ids are compared; scores are ignored. The denominator is the number of
/// distinct ids in the exact top-`k`, not `k` itself. This means a ground truth
/// shorter than `k` (e.g. fewer indexed items than `k`) can still reach 1.0.
/// When there is nothing to retrieve (`k == 0` or an empty ground truth) the
/// recall is defined as 1.0 rather than NaN, so it cannot poison an average.
fn recall_at_k(exact: &[(u32, f32)], approx: &[(u32, f32)], k: usize) -> f32 {
    use std::collections::HashSet;

    let exact_ids: HashSet<u32> = exact.iter().take(k).map(|&(id, _)| id).collect();
    if exact_ids.is_empty() {
        return 1.0;
    }
    let approx_ids: HashSet<u32> = approx.iter().take(k).map(|&(id, _)| id).collect();
    let hits = exact_ids.intersection(&approx_ids).count();
    hits as f32 / exact_ids.len() as f32
}
/// Runs one benchmark scenario end to end and prints the results to stdout.
///
/// Steps:
/// 1. Generate `cfg.n` random boxes and `cfg.n_queries` random query points.
/// 2. Build a `RegionIndex` over the boxes (timed; includes the per-box clones).
/// 3. Compute ground truth with `search_exhaustive` (timed).
/// 4. For each over-retrieval factor, time `search` over all queries, then
///    report mean recall@k against the ground truth and throughput.
/// 5. Report what fraction of the ground-truth top-k boxes contain their query.
///
/// # Panics
/// Panics if index creation, index build, or any approximate search returns an
/// error, or if `cfg.n` exceeds `u32::MAX`.
fn run_benchmark(cfg: &Config) {
    let mut rng = rand::rng();
    let boxes = generate_boxes(cfg, &mut rng);
    let queries = generate_queries(cfg, &mut rng);

    let t = Instant::now();
    let mut index =
        RegionIndex::new(cfg.dim, Default::default()).expect("failed to create RegionIndex");
    for (i, b) in boxes.iter().enumerate() {
        let id = u32::try_from(i).expect("box index does not fit in a u32 id");
        // Clone: `boxes` is still needed for the containment statistics below.
        index.add(id, b.clone());
    }
    index.build().expect("failed to build RegionIndex");
    let build_ms = t.elapsed().as_millis();

    let t = Instant::now();
    let ground_truth: Vec<Vec<(u32, f32)>> = queries
        .iter()
        .map(|q| index.search_exhaustive(q, cfg.k))
        .collect();
    let exhaustive_ms = t.elapsed().as_millis();

    println!(
        "\n=== {} (n={}, dim={}, width_mean={:.2}) ===",
        cfg.label, cfg.n, cfg.dim, cfg.width_mean
    );
    println!("Build: {}ms | Exhaustive: {}ms", build_ms, exhaustive_ms);

    for overretrieve in [1, 2, 5, 10, 20, 50] {
        // Time only the searches. Recall is scored after the clock stops so
        // the HashSet work in `recall_at_k` does not deflate the reported qps.
        let t = Instant::now();
        let results: Vec<_> = queries
            .iter()
            .map(|q| {
                let params = SearchParams {
                    ef: 200,
                    overretrieve,
                };
                index.search(q, cfg.k, params).expect("approximate search failed")
            })
            .collect();
        let search_time = t.elapsed();

        let total_recall: f32 = ground_truth
            .iter()
            .zip(&results)
            .map(|(exact, approx)| recall_at_k(exact, approx, cfg.k))
            .sum();
        let mean_recall = total_recall / queries.len() as f32;

        // Use fractional seconds: whole milliseconds truncate badly and are 0
        // for very fast sweeps, which would divide by zero.
        let qps = queries.len() as f64 / search_time.as_secs_f64();
        println!(
            " overretrieve={:>3}x recall@{}={:.4} search={}ms ({:.1} qps)",
            overretrieve,
            cfg.k,
            mean_recall,
            search_time.as_millis(),
            qps
        );
    }

    let mut inside_count = 0usize;
    let mut total_count = 0usize;
    for (q, exact) in queries.iter().zip(&ground_truth) {
        for &(id, _) in exact.iter().take(cfg.k) {
            if boxes[id as usize].contains(q) {
                inside_count += 1;
            }
            total_count += 1;
        }
    }
    println!(
        " (ground truth: {:.1}% of top-{} results contain the query point)",
        100.0 * inside_count as f64 / total_count as f64,
        cfg.k
    );
}
fn main() {
let configs = vec![
Config {
n: 10_000,
dim: 128,
n_queries: 200,
k: 10,
width_mean: 0.01,
width_std: 0.005,
label: "narrow (d=128, w=0.01)",
},
Config {
n: 10_000,
dim: 128,
n_queries: 200,
k: 10,
width_mean: 0.1,
width_std: 0.05,
label: "medium (d=128, w=0.1)",
},
Config {
n: 10_000,
dim: 128,
n_queries: 200,
k: 10,
width_mean: 0.5,
width_std: 0.2,
label: "wide (d=128, w=0.5)",
},
Config {
n: 10_000,
dim: 128,
n_queries: 200,
k: 10,
width_mean: 0.2,
width_std: 0.3,
label: "mixed hierarchy (d=128, w=0.2+/-0.3)",
},
Config {
n: 10_000,
dim: 400,
n_queries: 100,
k: 10,
width_mean: 0.1,
width_std: 0.05,
label: "medium (d=400, w=0.1)",
},
Config {
n: 50_000,
dim: 128,
n_queries: 100,
k: 10,
width_mean: 0.1,
width_std: 0.05,
label: "scale 50K (d=128, w=0.1)",
},
];
for cfg in &configs {
run_benchmark(cfg);
}
}