use diskann_benchmark_runner::registry::Benchmarks;
const NAME: &str = "product-exhaustive-search";
crate::utils::stub_impl!("product-quantization", inputs::exhaustive::Product);
/// Registers the product-quantization exhaustive-search benchmark under `NAME`.
///
/// With the `product-quantization` feature enabled, the real `imp::ProductQ` benchmark is
/// registered. Without it, registration is delegated to `imp::register`
/// (presumably supplied by the `stub_impl!` invocation above — confirm against that macro).
pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) {
    #[cfg(feature = "product-quantization")]
    {
        benchmarks.register(NAME, imp::ProductQ);
    }

    #[cfg(not(feature = "product-quantization"))]
    {
        imp::register(NAME, benchmarks);
    }
}
#[cfg(feature = "product-quantization")]
mod imp {
use std::io::Write;
use diskann_benchmark_runner::{
dispatcher::{FailureScore, MatchScore},
utils::{percentiles, MicroSeconds},
Benchmark, Output,
};
use diskann_quantization::{product::train::TrainQuantizer, CompressInto};
use indicatif::{ProgressBar, ProgressStyle};
use rayon::iter::{IndexedParallelIterator, ParallelIterator};
use serde::Serialize;
use crate::{
backend::exhaustive::algos::{self, LinearSearch},
inputs,
utils::{
datafiles::{self, ConvertingLoad},
recall, SimilarityMeasure,
},
};
/// Writes a single `label: value` line to the writer/formatter named by `$out`.
///
/// The label is right-aligned in a 19-character column. `$spec` is the format-string
/// literal used for the value, and every token after it is forwarded verbatim to
/// `writeln!` as the format arguments (a trailing comma is therefore accepted).
macro_rules! write_field {
    ($out:ident, $label:tt, $spec:literal, $($args:tt)*) => {
        writeln!($out, concat!("{:>19}: ", $spec), $label, $($args)*)
    };
}
/// Creates a progress bar of length `count` that draws to `draw_target`.
///
/// `message` is spliced into the style template in front of the elapsed time, the bar
/// and the percentage placeholders.
///
/// # Errors
///
/// Returns an error when `ProgressStyle::with_template` rejects the assembled template.
fn make_progress_bar(
    message: &str,
    count: usize,
    draw_target: indicatif::ProgressDrawTarget,
) -> anyhow::Result<ProgressBar> {
    // The bar is created before the template is parsed, so the order of side effects on
    // the error path stays as it always was.
    let bar = ProgressBar::with_draw_target(Some(count as u64), draw_target);

    let template = format!(
        "{} [{{elapsed_precise}}] {{wide_bar}} {{percent}}",
        message
    );
    let style = ProgressStyle::with_template(&template)?;
    bar.set_style(style);

    Ok(bar)
}
/// Benchmark that trains a product quantizer on the input data, compresses every row with
/// it, and then runs a linear search over the compressed rows.
#[derive(Debug, Clone, Copy)]
pub(super) struct ProductQ;
impl ProductQ {
/// Runs the whole benchmark for `input`, writing a textual report to `output`.
///
/// Stages, in order: echo the input, load the data, train the quantizer (timed),
/// compress the data (timed), load queries and groundtruth, run the linear search,
/// compute recalls, and print the final `Results`.
///
/// # Errors
///
/// Propagates failures from loading, thread-pool construction, training, compression,
/// search, recall computation and writing to `output`. Also fails when
/// `input.search.recalls.recall_n` is empty.
pub(super) fn run(
&self,
input: &inputs::exhaustive::Product,
mut output: &mut dyn Output,
) -> anyhow::Result<Results> {
// Echo the benchmark input before doing any work.
writeln!(output, "{}", input)?;
// Data loading happens before the training timer starts, so it is not part of
// `training_time`.
let data = f32::converting_load(datafiles::BinFile(&input.data), input.data_type)?;
// ---- Training (timed) ----
// The timed region covers parameter setup, chunk-offset computation, thread-pool
// construction and teardown, pivot training and construction of the PQ table.
let start = std::time::Instant::now();
// NOTE(review): the literal `5` is presumably the number of training iterations —
// confirm against `LightPQTrainingParameters::new`.
let parameters = diskann_quantization::product::train::LightPQTrainingParameters::new(
input.num_pq_centers.get(),
5,
);
let offsets = diskann_providers::model::pq::calculate_chunk_offsets_auto(
data.ncols(),
input.num_pq_chunks.get(),
);
// Training runs on a dedicated pool sized by `compression_threads`; the pool is dropped
// at the end of this block, i.e. still inside the timed region.
let base = {
let threadpool = rayon::ThreadPoolBuilder::new()
.num_threads(input.compression_threads.get())
.build()?;
threadpool.install(|| -> anyhow::Result<_> {
Ok(parameters.train(
data.as_view(),
diskann_quantization::views::ChunkOffsetsView::new(offsets.as_slice())?,
diskann_quantization::Parallelism::Rayon,
&diskann_quantization::random::StdRngBuilder::new(input.seed),
&diskann_quantization::cancel::DontCancel,
)?)
})?
};
// NOTE(review): the all-zero vector of length `ncols` is presumably a centroid/offset
// that disables centering — confirm against `FixedChunkPQTable::new`.
let quantizer = diskann_providers::model::pq::FixedChunkPQTable::new(
data.ncols(),
base.flatten().into(),
vec![0.0; data.ncols()].into(),
offsets.into(),
)?;
let training_time: MicroSeconds = start.elapsed().into();
// ---- Compression (timed) ----
// The timed region covers building a second pool (again sized by `compression_threads`),
// creating the progress bar, compressing every row and finishing the bar.
let start = std::time::Instant::now();
let store = {
let threadpool = rayon::ThreadPoolBuilder::new()
.num_threads(input.compression_threads.get())
.build()?;
let compression_progress =
make_progress_bar("compressing", data.nrows(), output.draw_target())?;
// `quantizer` is moved into the store here.
let store = threadpool
.install(|| Store::new(data.as_view(), quantizer, &compression_progress))?;
compression_progress.finish();
store
};
let compression_time: MicroSeconds = start.elapsed().into();
// ---- Search ----
// Queries are loaded with the same `data_type` as the base data.
let queries =
f32::converting_load(datafiles::BinFile(&input.search.queries), input.data_type)?;
let groundtruth =
datafiles::load_groundtruth(datafiles::BinFile(&input.search.groundtruth))?;
let search_progress =
make_progress_bar("running search", queries.nrows(), output.draw_target())?;
// Search uses its own pool, sized by `search.num_threads`.
let threadpool = rayon::ThreadPoolBuilder::new()
.num_threads(input.search.num_threads.get())
.build()?;
// The search is run once, with the last configured `recall_n` value.
// NOTE(review): presumably this is the largest value, which assumes `recall_n` is
// sorted ascending — confirm against input validation.
let recall_n = input
.search
.recalls
.recall_n
.last()
.ok_or_else(|| anyhow::anyhow!("expected at least one value for `recall_n`"))?;
let plan = Plan {
measure: input.distance,
};
let r = threadpool.install(|| {
algos::linear_search(
&store,
queries.as_view(),
&plan,
*recall_n,
&search_progress,
)
})?;
// Recalls for all configured (k, n) values are computed from that single run's ids.
let recalls = recall::compute_multiple_recalls(
&r.ids,
&groundtruth,
&input.search.recalls.recall_k,
&input.search.recalls.recall_n,
)?;
let search_results = SearchResults::new(r, input.search.num_threads.get(), recalls)?;
// The search bar is only finished after recalls and statistics were computed
// successfully; on any earlier error it is dropped unfinished.
search_progress.finish();
let result = Results {
training_time,
compression_time,
search_results,
};
// Print the report, separated from the progress output by two blank lines.
writeln!(output, "\n\n{}", result)?;
Ok(result)
}
}
impl Benchmark for ProductQ {
type Input = inputs::exhaustive::Product;
type Output = Results;
fn try_match(
&self,
_input: &inputs::exhaustive::Product,
) -> Result<MatchScore, FailureScore> {
Ok(MatchScore(0))
}
fn description(
&self,
f: &mut std::fmt::Formatter<'_>,
input: Option<&inputs::exhaustive::Product>,
) -> std::fmt::Result {
if input.is_none() {
writeln!(f, "- Exhaustive search for product quantization",)?;
writeln!(f, "- Requires `float32` data")?;
}
Ok(())
}
fn run(
&self,
input: &inputs::exhaustive::Product,
_checkpoint: diskann_benchmark_runner::Checkpoint<'_>,
output: &mut dyn Output,
) -> anyhow::Result<Results> {
self.run(input, output)
}
}
/// Outcome of one benchmark run: the two timed build phases plus the search metrics.
#[derive(Debug, Serialize)]
pub(super) struct Results {
    /// Wall-clock time of the training phase.
    training_time: MicroSeconds,
    /// Wall-clock time of the compression phase.
    compression_time: MicroSeconds,
    /// Throughput, latency and recall figures of the search phase.
    search_results: SearchResults,
}

/// Human-readable report: both phase timings in seconds, then the search results.
impl std::fmt::Display for Results {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let training_secs = self.training_time.as_seconds();
        let compression_secs = self.compression_time.as_seconds();

        write_field!(f, "Training Time", "{}s", training_secs)?;
        write_field!(f, "Compression Time", "{}s", compression_secs)?;
        writeln!(f, "{}", self.search_results)
    }
}
/// Aggregated metrics of a single linear-search run.
#[derive(Debug, Serialize)]
struct SearchResults {
    /// Number of threads the search was configured with.
    num_threads: usize,
    /// Total time reported by the search.
    time: MicroSeconds,
    /// Number of result rows divided by the total time in seconds.
    qps: f64,
    /// Mean, p90 and p99 of the per-query preprocessing latencies.
    mean_preprocess: f64,
    p90_preprocess: MicroSeconds,
    p99_preprocess: MicroSeconds,
    /// Mean, p90 and p99 of the per-query search latencies.
    mean_search: f64,
    p90_search: MicroSeconds,
    p99_search: MicroSeconds,
    /// Recall figures, one entry per evaluated configuration.
    recalls: Vec<recall::RecallMetrics>,
}

impl SearchResults {
    /// Summarises the raw output of a linear search.
    ///
    /// `search` is taken by value and mutably because `compute_percentiles` needs mutable
    /// access to the latency buffers.
    ///
    /// # Errors
    ///
    /// Returns `CannotBeEmpty` when `compute_percentiles` rejects either latency buffer;
    /// the preprocessing latencies are checked first.
    fn new(
        mut search: LinearSearch,
        num_threads: usize,
        recalls: Vec<recall::RecallMetrics>,
    ) -> Result<Self, percentiles::CannotBeEmpty> {
        let preprocess = percentiles::compute_percentiles(&mut search.preprocess)?;
        let scan = percentiles::compute_percentiles(&mut search.search)?;

        let num_queries = search.ids.nrows() as f64;
        let qps = num_queries / search.total.as_seconds();

        Ok(Self {
            num_threads,
            time: search.total,
            qps,
            mean_preprocess: preprocess.mean,
            p90_preprocess: preprocess.p90,
            p99_preprocess: preprocess.p99,
            mean_search: scan.mean,
            p90_search: scan.p90,
            p99_search: scan.p99,
            recalls,
        })
    }
}
/// Human-readable report: headline figures, mean (p99) latencies, then a recall table.
impl std::fmt::Display for SearchResults {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Headline numbers.
        write_field!(f, "Total", "{:.2}s", self.time.as_seconds())?;
        write_field!(f, "QPS", "{:.3}", self.qps)?;
        write_field!(f, "Num Threads", "{}", self.num_threads)?;

        // Latencies are shown as "mean (p99)"; the p90 values are not printed.
        write_field!(
            f,
            "Preprocess Latency",
            "{:.1}us ({:.1})",
            self.mean_preprocess,
            self.p99_preprocess.as_f64(),
        )?;
        write_field!(
            f,
            "Search Latency",
            "{:.1}us ({:.1})",
            self.mean_search,
            self.p99_search.as_f64(),
        )?;
        writeln!(f)?;

        // One table row per recall entry; recall is shown as a percentage.
        let mut table = diskann_benchmark_runner::utils::fmt::Table::new(
            ["K", "N", "Recall (%)"],
            self.recalls.len(),
        );
        for (index, metrics) in self.recalls.iter().enumerate() {
            let mut row = table.row(index);
            row.insert(metrics.recall_k, 0);
            row.insert(metrics.recall_n, 1);
            row.insert(format!("{:.3}", 100.0 * metrics.average), 2);
        }

        write!(f, "{}", table)
    }
}
/// Compressed dataset together with the PQ table that produced it.
pub(super) struct Store {
    /// One row of `u8` codes per input vector, `get_num_chunks()` codes wide.
    data: diskann_utils::views::Matrix<u8>,
    /// The table used for compression; also borrowed later to build query computers.
    quantizer: diskann_providers::model::pq::FixedChunkPQTable,
}

impl Store {
    /// Compresses every row of `input` with `quantizer`, ticking `progress` once per row.
    ///
    /// Rows are processed through the parallel row iterators, so the work runs on
    /// whichever rayon pool the caller has installed.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by `compress_into`.
    fn new(
        input: diskann_utils::views::MatrixView<f32>,
        quantizer: diskann_providers::model::pq::FixedChunkPQTable,
        progress: &ProgressBar,
    ) -> anyhow::Result<Self> {
        let num_chunks = quantizer.get_num_chunks();
        // Zero-filled code matrix: one row per input vector.
        let mut data = diskann_utils::views::Matrix::new(0, input.nrows(), num_chunks);

        #[allow(clippy::disallowed_methods)]
        data.par_row_iter_mut()
            .zip(input.par_row_iter())
            .try_for_each(|(code, vector)| -> anyhow::Result<()> {
                quantizer.compress_into(vector, code)?;
                progress.inc(1);
                Ok(())
            })?;

        Ok(Self { data, quantizer })
    }
}
/// Search-time configuration used to build a distance computer for each query.
struct Plan {
// Similarity measure converted (via `into()`) and handed to the query computer.
measure: SimilarityMeasure,
}
/// Exposes the compressed rows of a `Store` through the `QuantStore` interface.
impl algos::QuantStore for Store {
    /// Each item is one row of PQ codes.
    type Item<'a>
        = &'a [u8]
    where
        Self: 'a;

    /// Iterates over the rows of the code matrix, in `row_iter` order.
    fn iter(&self) -> impl Iterator<Item = Self::Item<'_>> {
        let Self { data, .. } = self;
        data.row_iter()
    }
}
/// Builds per-query PQ distance computers that borrow the table held by a `Store`.
impl algos::CreateQuantComputer<Store> for Plan {
    type Computer<'a> = diskann_providers::model::pq::distance::QueryComputer<
        &'a diskann_providers::model::pq::FixedChunkPQTable,
    >;

    /// Creates the computer for `query` using this plan's similarity measure.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `QueryComputer::new`.
    fn create_quant_computer<'a>(
        &self,
        store: &'a Store,
        query: &[f32],
    ) -> anyhow::Result<Self::Computer<'a>> {
        let table = &store.quantizer;
        let computer = diskann_providers::model::pq::distance::QueryComputer::new(
            table,
            self.measure.into(),
            query,
            None,
        )?;
        Ok(computer)
    }
}
}