#![allow(clippy::needless_range_loop)]
use anndists::dist::DistL2;
use cpu_time::ProcessTime;
use rayon::prelude::*;
use rust_diskann::{DiskANN, DiskAnnError, DiskAnnParams};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
mod utils;
use utils::*;
/// Returns the Euclidean (L2) distance between `a` and `b`.
///
/// Only the first `a.len()` components of `b` are read, so `b` may be longer
/// than `a`.
///
/// # Panics
/// Panics if `b` has fewer elements than `a`.
fn euclid(a: &[f32], b: &[f32]) -> f32 {
    // Slice up front so a too-short `b` panics instead of being silently
    // truncated by `zip`.
    let b_head = &b[..a.len()];
    a.iter()
        .zip(b_head)
        .fold(0.0f32, |acc, (x, y)| {
            let diff = x - y;
            acc + diff * diff
        })
        .sqrt()
}
/// Benchmarks DiskANN on the Fashion-MNIST (784-d, Euclidean) benchmark file.
///
/// Steps:
/// 1. Load `./fashion-mnist-784-euclidean.hdf5`.
/// 2. Build the index at `diskann_mnist.db`, or open it if that file already exists.
/// 3. Run every test query in parallel and recompute the distance from the query
///    to each returned id.
/// 4. Print the recall rate (returned distances not exceeding the true k-th
///    distance), the fraction of requested neighbours actually returned, the mean
///    ratio of the worst returned distance to the true k-th distance, throughput
///    and CPU / wall-clock timings.
///
/// # Errors
/// Propagates any `DiskAnnError` raised while building or opening the index.
///
/// # Panics
/// Panics if the HDF5 file cannot be loaded, or if the ground-truth table in the
/// file has no neighbours per query.
fn main() -> Result<(), DiskAnnError> {
    let fname = String::from("./fashion-mnist-784-euclidean.hdf5");
    println!("\n\nDiskANN benchmark on {:?}", fname);
    let anndata = annhdf5::AnnBenchmarkData::new(fname)
        .expect("Failed to load fashion-mnist-784-euclidean.hdf5");

    let knbn_max = anndata.test_distances.dim().1;
    let nb_elem = anndata.train_data.len();
    let nb_search = anndata.test_data.len();
    println!("Train size : {}", nb_elem);
    println!("Test size : {}", nb_search);
    println!("Ground-truth k per query in file: {}", knbn_max);

    // Build parameters.
    let max_degree = 48;
    let build_beam_width = 128;
    let alpha = 1.2;
    let passes = 1usize;
    let extra_seeds = 1usize;

    // Search parameters.
    let search_k = 10;
    let search_beam = 384;

    // Rank (1-based) of the ground-truth neighbour used as the recall threshold.
    // Checked up front so a degenerate file fails before any expensive work.
    let true_k = search_k.min(knbn_max);
    assert!(
        true_k > 0,
        "ground truth in file has no neighbours per query"
    );

    let index_path = "diskann_mnist.db";
    let index = if !std::path::Path::new(index_path).exists() {
        // The training vectors are only needed to build, so copy them out of
        // the dataset here rather than paying for the clone on every run.
        let train_vectors: Vec<Vec<f32>> = anndata
            .train_data
            .iter()
            .map(|pair| pair.0.clone())
            .collect();
        println!(
            "\nBuilding DiskANN index: n={}, dim={}, max_degree={}, \
             build_beam={}, alpha={}, passes={}, extra_seeds={}",
            train_vectors.len(),
            train_vectors.first().map_or(0, Vec::len),
            max_degree,
            build_beam_width,
            alpha,
            passes,
            extra_seeds
        );
        let params = DiskAnnParams {
            max_degree,
            build_beam_width,
            alpha,
            passes,
            extra_seeds,
        };
        let start_cpu = ProcessTime::now();
        // `Instant` is monotonic, so measuring elapsed time cannot fail.
        let start_wall = std::time::Instant::now();
        let idx = DiskANN::<f32, DistL2>::build_index_with_params(
            &train_vectors,
            DistL2 {},
            index_path,
            params,
        )?;
        let cpu_time: Duration = start_cpu.elapsed();
        let wall_time = start_wall.elapsed();
        println!(
            "Build complete. CPU time: {:?}, wall time: {:?}",
            cpu_time, wall_time
        );
        idx
    } else {
        println!("\nIndex file {} exists, opening…", index_path);
        let start_wall = std::time::Instant::now();
        let idx = DiskANN::<f32, DistL2>::open_index_with(index_path, DistL2)?;
        let wall_time = start_wall.elapsed();
        println!(
            "Opened index: {} vectors, dim={}, metric={} in {:?}",
            idx.num_vectors, idx.dim, idx.distance_name, wall_time
        );
        idx
    };
    let index = Arc::new(index);

    // Without queries every average below would be 0/0.
    if nb_search == 0 {
        println!("\nNo test queries in file; nothing to search.");
        return Ok(());
    }

    println!(
        "\nSearching {} queries with k={}, beam_width={} …",
        nb_search, search_k, search_beam
    );
    let start_cpu = ProcessTime::now();
    let start_wall = std::time::Instant::now();
    // For each query: the distances to the returned ids, in ascending order.
    let results_dists: Vec<Vec<f32>> = anndata
        .test_data
        .par_iter()
        .map(|q| {
            let ids = index.search(q, search_k, search_beam);
            let mut ds = Vec::with_capacity(ids.len());
            for &id in &ids {
                let v = index.get_vector(id as usize);
                ds.push(euclid(q, &v));
            }
            // `total_cmp` gives a total order, so a NaN distance cannot panic the sort.
            ds.sort_unstable_by(f32::total_cmp);
            ds
        })
        .collect();
    let cpu_time = start_cpu.elapsed();
    let wall_time = start_wall.elapsed();

    // Accumulate the per-query statistics as running totals.
    let mut total_recall = 0usize;
    let mut total_returned = 0usize;
    let mut ratio_sum = 0.0f32;
    for (i, dists) in results_dists.iter().enumerate() {
        let true_row = anndata.test_distances.row(i);
        let gt_kth = true_row[true_k - 1];
        total_returned += dists.len();
        // A returned neighbour counts as a hit when it is no farther than the
        // true k-th neighbour.
        total_recall += dists.iter().filter(|&&d| d <= gt_kth).count();
        // Worst returned distance relative to the true k-th distance; 0 when
        // the search returned nothing.
        ratio_sum += dists.last().map_or(0.0, |&worst| worst / gt_kth);
    }

    let nb_queries = results_dists.len();
    let mean_recall = (total_recall as f32) / ((search_k * nb_queries) as f32);
    let mean_frac_returned = (total_returned as f32) / ((nb_queries * search_k) as f32);
    let mean_last_ratio = ratio_sum / (nb_queries as f32);
    let search_sys_time_us = wall_time.as_micros() as f32;
    let req_per_s = (nb_search as f32) * 1.0e6_f32 / search_sys_time_us;

    println!(
        "\n mean fraction nb returned by search {:?}",
        mean_frac_returned
    );
    println!("\n last distances ratio {:?}", mean_last_ratio);
    println!(
        "\n recall rate for {:?} is {:?} , nb req /s {:?}",
        anndata.fname, mean_recall, req_per_s
    );
    println!(
        "\n total cpu time for search requests {:?} , system time {:?}",
        cpu_time, wall_time
    );
    Ok(())
}