use byteorder::{ByteOrder, LittleEndian};
use gnuplot::*;
use hnsw::*;
use rand::distributions::Standard;
use rand::{Rng, SeedableRng};
use rand_pcg::Pcg64;
use space::Metric;
use space::Neighbor;
use std::cell::RefCell;
use std::io::Read;
use std::path::PathBuf;
use structopt::StructOpt;
/// Euclidean (L2) distance between `f32` slices, usable as a `space::Metric`.
struct Euclidean;

impl Metric<&[f32]> for Euclidean {
    /// Distances are reported as the raw bit pattern of the `f32` result.
    type Unit = u32;

    /// Returns the bits of `sqrt(sum((a[i] - b[i])^2))`.
    ///
    /// The two slices are walked in lockstep with `zip`, so when their lengths
    /// differ only the common prefix contributes to the distance.
    ///
    /// NOTE(review): comparing the returned `u32`s presumably relies on
    /// non-negative floats ordering the same way as their bit patterns; a NaN
    /// coordinate would break that assumption — confirm inputs are finite.
    fn distance(&self, a: &&[f32], b: &&[f32]) -> u32 {
        let squared_sum: f32 = a
            .iter()
            .zip(b.iter())
            .map(|(&x, &y)| (x - y).powi(2))
            .sum();
        squared_sum.sqrt().to_bits()
    }
}
// Command-line options of the recall benchmark.
//
// The field doc comments below are picked up by structopt and shown as the
// per-option text in `--help`. This header is deliberately a plain comment so
// that it does not compete with the explicit `about` attribute.
#[derive(Debug, StructOpt)]
#[structopt(name = "recall", about = "Generates recall graphs for HNSW")]
struct Opt {
    /// Value of the HNSW M parameter; must be a multiple of 4 between 4 and 52 (M0 is set to 2 * M)
    #[structopt(short = "m", long = "max_edges", default_value = "24")]
    m: usize,
    /// Number of descriptors in the search space
    #[structopt(short = "s", long = "size", default_value = "10000")]
    size: usize,
    /// Number of query descriptors
    #[structopt(short = "q", long = "queries", default_value = "10000")]
    num_queries: usize,
    /// Number of leading f32 values of each descriptor that are used; must not exceed the descriptor stride
    #[structopt(short = "l", long = "dimensions", default_value = "64")]
    dimensions: usize,
    /// First ef value to benchmark (inclusive)
    #[structopt(short = "b", long = "beginning_ef", default_value = "1")]
    beginning_ef: usize,
    /// Last ef value to benchmark (inclusive)
    #[structopt(short = "e", long = "ending_ef", default_value = "64")]
    ending_ef: usize,
    /// Number of nearest neighbors to search for; must not exceed the dataset size
    #[structopt(short = "k", long = "neighbors", default_value = "2")]
    k: usize,
    /// File of little-endian f32 descriptors (search space first, then queries); random data is generated when omitted
    #[structopt(short = "f", long = "file")]
    file: Option<PathBuf>,
    /// Number of f32 values per descriptor in the input data
    #[structopt(short = "d", long = "descriptor_stride", default_value = "64")]
    descriptor_stride: usize,
    /// ef_construction value used while building the HNSW
    #[structopt(short = "c", long = "ef_construction", default_value = "400")]
    ef_construction: usize,
}
/// Measures HNSW search recall and throughput for one `M`/`M0` configuration.
///
/// The search space and the queries are either read from `opt.file`
/// (little-endian `f32`s: `opt.size` descriptors followed by `opt.num_queries`
/// descriptors, each `opt.descriptor_stride` values long) or generated from two
/// fixed-seed PCG generators. Only the first `opt.dimensions` values of every
/// descriptor take part in distance computations.
///
/// For every `ef` in `opt.beginning_ef..=opt.ending_ef` the search is
/// benchmarked with `easybench`, cycling through the queries. A returned
/// neighbor counts as correct when its distance to the query does not exceed
/// the brute-force distance of the true `k`-th nearest neighbor.
///
/// Returns `(recalls, times)` with one entry per `ef`: the fraction of correct
/// neighbors among all neighbors requested, and the number of lookups per
/// second derived from the measured nanoseconds per iteration.
///
/// # Panics
///
/// Panics when `opt.k` is zero or larger than `opt.size`, when
/// `opt.num_queries` or `opt.descriptor_stride` is zero, when `opt.dimensions`
/// exceeds `opt.descriptor_stride`, or when the input file cannot be opened or
/// is too short.
fn process<const M: usize, const M0: usize>(opt: &Opt) -> (Vec<f64>, Vec<f64>) {
    assert!(
        opt.k <= opt.size,
        "You must choose a dataset size larger or equal to the test search size"
    );
    // Validate the remaining options up front: each of these would otherwise
    // only surface later as an opaque slice-index, `chunks_exact` or `unwrap`
    // panic, possibly after the expensive setup work has already been done.
    assert!(
        opt.k >= 1,
        "The number of neighbors (-k) must be at least 1"
    );
    assert!(
        opt.num_queries >= 1,
        "The number of queries (-q) must be at least 1"
    );
    assert!(
        opt.descriptor_stride >= 1,
        "The descriptor stride (-d) must be at least 1"
    );
    assert!(
        opt.dimensions <= opt.descriptor_stride,
        "The number of dimensions (-l) must not exceed the descriptor stride (-d)"
    );

    let rng = Pcg64::from_seed([5; 32]);

    // Flat buffers of `f32`s; they are split into descriptors further below.
    let (search_space, query_strings): (Vec<f32>, Vec<f32>) = if let Some(filepath) = &opt.file {
        eprintln!(
            "Reading {} search space descriptors of size {} f32s from file \"{}\"...",
            opt.size,
            opt.descriptor_stride,
            filepath.display()
        );
        let mut file = std::fs::File::open(filepath).expect("unable to open file");
        let mut search_space = vec![0u8; opt.size * opt.descriptor_stride * 4];
        file.read_exact(&mut search_space).expect(
            "unable to read enough search descriptors from the file (try decreasing -s/-q)",
        );
        let search_space = search_space
            .chunks_exact(4)
            .map(LittleEndian::read_f32)
            .collect();
        eprintln!("Done.");

        // The query descriptors directly follow the search space in the file.
        eprintln!(
            "Reading {} query descriptors of size {} f32s from file \"{}\"...",
            opt.num_queries,
            opt.descriptor_stride,
            filepath.display()
        );
        let mut query_strings = vec![0u8; opt.num_queries * opt.descriptor_stride * 4];
        file.read_exact(&mut query_strings)
            .expect("unable to read enough query descriptors from the file (try decreasing -q/-s)");
        let query_strings = query_strings
            .chunks_exact(4)
            .map(LittleEndian::read_f32)
            .collect();
        eprintln!("Done.");

        (search_space, query_strings)
    } else {
        eprintln!("Generating {} random bitstrings...", opt.size);
        let search_space: Vec<f32> = rng
            .sample_iter(&Standard)
            .take(opt.size * opt.descriptor_stride)
            .collect();
        eprintln!("Done.");

        // A second generator with a different seed keeps the queries
        // independent from the search space.
        let rng = Pcg64::from_seed([6; 32]);
        eprintln!(
            "Generating {} independent random query strings...",
            opt.num_queries
        );
        let query_strings: Vec<f32> = rng
            .sample_iter(&Standard)
            .take(opt.num_queries * opt.descriptor_stride)
            .collect();
        eprintln!("Done.");

        (search_space, query_strings)
    };

    // Split the flat buffers into one slice per descriptor, keeping only the
    // first `dimensions` values of each stride.
    let search_space: Vec<_> = search_space
        .chunks_exact(opt.descriptor_stride)
        .map(|c| &c[..opt.dimensions])
        .collect();
    let query_strings: Vec<_> = query_strings
        .chunks_exact(opt.descriptor_stride)
        .map(|c| &c[..opt.dimensions])
        .collect();

    eprintln!(
        "Computing the correct nearest neighbor distance for all {} queries...",
        opt.num_queries
    );
    // Brute force: for every query keep the `k` smallest distances in sorted
    // order; the last of them is the distance of the true k-th nearest neighbor.
    let correct_worst_distances: Vec<_> = query_strings
        .iter()
        .cloned()
        .map(|feature| {
            // One extra slot: an element is inserted before the list is cut
            // back down to `k` entries.
            let mut nearest: Vec<u32> = Vec::with_capacity(opt.k + 1);
            for distance in search_space.iter().map(|n| Euclidean.distance(n, &feature)) {
                let pos = nearest.binary_search(&distance).unwrap_or_else(|e| e);
                nearest.insert(pos, distance);
                nearest.truncate(opt.k);
            }
            nearest
                .last()
                .copied()
                .expect("k >= 1 and k <= size guarantee at least one distance")
        })
        .collect();
    eprintln!("Done.");

    eprintln!("Generating HNSW...");
    let mut hnsw: Hnsw<_, _, Pcg64, M, M0> = Hnsw::new_params(
        Euclidean,
        Params::new().ef_construction(opt.ef_construction),
    );
    let mut searcher: Searcher<_> = Searcher::default();
    for feature in &search_space {
        hnsw.insert(*feature, &mut searcher);
    }
    eprintln!("Done.");

    eprintln!("Computing recall graph...");
    let efs = opt.beginning_ef..=opt.ending_ef;
    // The searcher and the endlessly cycling query iterator live in a `RefCell`
    // so the benchmark closure can mutate them through a shared borrow.
    let state = RefCell::new((searcher, query_strings.iter().cloned().enumerate().cycle()));
    let (recalls, times): (Vec<f64>, Vec<f64>) = efs
        .map(|ef| {
            let correct = RefCell::new(0usize);
            // Output buffer for `k` neighbors, filled with all-ones placeholders.
            let dest = vec![
                Neighbor {
                    index: !0,
                    distance: !0,
                };
                opt.k
            ];
            let stats = easybench::bench_env(dest, |mut dest| {
                let mut refmut = state.borrow_mut();
                let (searcher, query) = &mut *refmut;
                let (ix, query_feature) = query
                    .next()
                    .expect("cycling over a non-empty query list never ends");
                let correct_worst_distance = correct_worst_distances[ix];
                for &mut neighbor in hnsw.nearest(&query_feature, ef, searcher, &mut dest) {
                    // Anything at least as close as the true k-th nearest
                    // neighbor counts as a correct result.
                    if Euclidean.distance(&search_space[neighbor.index], &query_feature)
                        <= correct_worst_distance
                    {
                        *correct.borrow_mut() += 1;
                    }
                }
            });
            let correct = correct.into_inner();
            let recall = correct as f64 / (stats.iterations * opt.k) as f64;
            // Nanoseconds per iteration -> iterations per second.
            let lookups_per_second = (stats.ns_per_iter * 0.1f64.powi(9)).recip();
            (recall, lookups_per_second)
        })
        .unzip();
    eprintln!("Done.");

    (recalls, times)
}
fn main() {
let opt = Opt::from_args();
let (recalls, times) = {
match opt.m {
4 => process::<4, 8>(&opt),
8 => process::<8, 16>(&opt),
12 => process::<12, 24>(&opt),
16 => process::<16, 32>(&opt),
20 => process::<20, 40>(&opt),
24 => process::<24, 48>(&opt),
28 => process::<28, 56>(&opt),
32 => process::<32, 64>(&opt),
36 => process::<36, 72>(&opt),
40 => process::<40, 80>(&opt),
44 => process::<44, 88>(&opt),
48 => process::<48, 96>(&opt),
52 => process::<52, 104>(&opt),
_ => {
eprintln!("Only M between 4 and 52 inclusive and multiples of 4 are allowed");
return;
}
}
};
let mut fg = Figure::new();
fg.axes2d()
.set_title(
&format!(
"{}-NN Recall Graph (dimensions = {}, size = {}, M = {})",
opt.k, opt.dimensions, opt.size, opt.m
),
&[],
)
.set_x_label("Recall Rate", &[])
.set_y_label("Lookups per second", &[])
.lines(&recalls, ×, &[LineWidth(2.0), Color("blue")])
.set_y_ticks(Some((Auto, 2)), &[], &[])
.set_grid_options(true, &[LineStyle(DotDotDash), Color("black")])
.set_minor_grid_options(&[LineStyle(SmallDot), Color("red")])
.set_x_grid(true)
.set_y_grid(true)
.set_y_minor_grid(true);
fg.show().expect("unable to show gnuplot");
}