use diskann_benchmark_runner::registry::Benchmarks;
// Stand-in for the real benchmarks when the `scalar-quantization` feature is disabled.
// `register_benchmarks` calls `imp::register(..)` under `cfg(not(feature = "scalar-quantization"))`,
// so this macro presumably defines that feature-off `imp` module for `IndexSQOperation`
// inputs — confirm against `crate::utils::stub_impl`.
crate::utils::stub_impl!("scalar-quantization", inputs::async_::IndexSQOperation);
/// Registers the scalar-quantization benchmarks with `benchmarks`.
///
/// With the `scalar-quantization` feature enabled, one benchmark is registered for each
/// supported (bit width, element type) combination, all named `async-sq-*`. With the
/// feature disabled, a single stub is registered under `async-sq` through the feature-off
/// `imp` module produced by `stub_impl!`.
pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) {
    #[cfg(feature = "scalar-quantization")]
    {
        use half::f16;

        benchmarks.register::<imp::ScalarQuantized<'static, 8, f32>>("async-sq-8-bit-f32");
        benchmarks.register::<imp::ScalarQuantized<'static, 4, f32>>("async-sq-4-bit-f32");
        benchmarks.register::<imp::ScalarQuantized<'static, 2, f32>>("async-sq-2-bit-f32");
        benchmarks.register::<imp::ScalarQuantized<'static, 1, f32>>("async-sq-1-bit-f32");
        benchmarks.register::<imp::ScalarQuantized<'static, 8, f16>>("async-sq-8-bit-f16");
        benchmarks.register::<imp::ScalarQuantized<'static, 4, f16>>("async-sq-4-bit-f16");
        benchmarks.register::<imp::ScalarQuantized<'static, 2, f16>>("async-sq-2-bit-f16");
        benchmarks.register::<imp::ScalarQuantized<'static, 1, f16>>("async-sq-1-bit-f16");
        benchmarks.register::<imp::ScalarQuantized<'static, 1, i8>>("async-sq-1-bit-i8");
    }

    // The stub is registered under this module's own prefix (matching the `async-sq-*`
    // names above) so it cannot be confused with, or collide with, another module's stub.
    #[cfg(not(feature = "scalar-quantization"))]
    imp::register("async-sq", benchmarks);
}
/// Scalar-quantization benchmark implementations.
///
/// Compiled only when the `scalar-quantization` feature is enabled.
#[cfg(feature = "scalar-quantization")]
mod imp {
    use std::{io::Write, sync::Arc};

    use anyhow::Context;
    use diskann_benchmark_runner::{
        describeln,
        dispatcher::{Description, DispatchRule, FailureScore, MatchScore},
        utils::{datatype, MicroSeconds},
        Benchmark, Checkpoint, Output,
    };
    use diskann_providers::{
        index::diskann_async::{self},
        model::{
            configuration::IndexConfiguration,
            graph::provider::async_::{common, inmem},
        },
    };
    use diskann_utils::views::{Matrix, MatrixView};
    use half::f16;

    use crate::{
        backend::index::{
            benchmarks::{run_build, run_search_outer, BuildAndSearch, FullPrecision},
            build::{self, load_index, only_single_insert, save_index, BuildStats},
            result::QuantBuildResult,
        },
        inputs::async_::{IndexSQOperation, IndexSource},
        utils::{self, datafiles},
    };

    /// Benchmark driver for an index whose vectors are scalar-quantized to `NBITS` bits,
    /// with full-precision element type `T`.
    pub(super) struct ScalarQuantized<'a, const NBITS: usize, T> {
        /// The parsed benchmark input.
        input: &'a IndexSQOperation,
        /// Ties the element type `T` to the driver without storing a value of it.
        _type: std::marker::PhantomData<T>,
    }

    impl<'a, const NBITS: usize, T> ScalarQuantized<'a, NBITS, T> {
        /// Wraps `input` for execution.
        ///
        /// # Panics
        ///
        /// Panics if `input.num_bits != NBITS`. `try_match` rejects such inputs, so this
        /// is presumably only reachable through a dispatcher bug.
        fn new(input: &'a IndexSQOperation) -> Self {
            assert_eq!(input.num_bits, NBITS);
            Self {
                input,
                _type: std::marker::PhantomData,
            }
        }
    }

    // Implements `Benchmark` and `BuildAndSearch` for one (bit width, element type) pair.
    macro_rules! impl_sq_build {
        ($N:literal, $T: ty) => {
            impl Benchmark for ScalarQuantized<'static, $N, $T> {
                type Input = IndexSQOperation;
                type Output = QuantBuildResult;

                // Returns `Ok(MatchScore(0))` when every requirement holds. Otherwise the
                // failure score accumulates: 1 for a multi-insert build, 1 if the
                // full-precision benchmark for this element type rejects the operation, and
                // `10 + |bits - num_bits|` for a bit-width mismatch (so the closest bit
                // width ranks best among failures).
                fn try_match(input: &IndexSQOperation) -> Result<MatchScore, FailureScore> {
                    let mut failure_score: Option<u32> = None;
                    match input.index_operation.source {
                        IndexSource::Load(_) => {}
                        IndexSource::Build(ref build) => {
                            if build.multi_insert.is_some() {
                                failure_score = Some(1);
                            }
                        }
                    }
                    if <FullPrecision<'static, $T> as Benchmark>::try_match(&input.index_operation)
                        .is_err()
                    {
                        *failure_score.get_or_insert(0) += 1;
                    }
                    if input.num_bits != $N {
                        *failure_score.get_or_insert(0) +=
                            10 + ($N as usize).abs_diff(input.num_bits) as u32;
                    }
                    match failure_score {
                        None => Ok(MatchScore(0)),
                        Some(score) => Err(FailureScore(score)),
                    }
                }

                // With no input, describes what this benchmark supports. With an input,
                // lists each reason the input does not match.
                fn description(
                    f: &mut std::fmt::Formatter<'_>,
                    input: Option<&IndexSQOperation>,
                ) -> std::fmt::Result {
                    match input {
                        None => {
                            describeln!(
                                f,
                                "- Index Build and Search using {} scalar quantized bits",
                                $N
                            )?;
                            describeln!(
                                f,
                                "- Requires `{}` data",
                                Description::<datatype::DataType, datatype::Type<$T>>::new(),
                            )?;
                            describeln!(f, "- Implements `squared_l2` or `inner_product` distance",)?;
                            describeln!(f, "- Does not support multi-insert")?;
                        }
                        Some(input) => {
                            if input.num_bits != $N {
                                describeln!(
                                    f,
                                    "- Expected {} bits, instead got {}",
                                    $N,
                                    input.num_bits
                                )?;
                            }
                            // Propagate formatter errors like the rest of this function
                            // instead of panicking on them.
                            let mut check_match =
                                |data_type: &datatype::DataType| -> std::fmt::Result {
                                    if datatype::Type::<$T>::try_match(data_type).is_err() {
                                        describeln!(
                                            f,
                                            "- Only `{}` data type is supported. Instead, got {}",
                                            Description::<datatype::DataType, datatype::Type<$T>>::new(),
                                            data_type
                                        )?;
                                    }
                                    Ok(())
                                };
                            match &input.index_operation.source {
                                IndexSource::Load(load) => {
                                    check_match(&load.data_type)?;
                                }
                                IndexSource::Build(build) => {
                                    check_match(&build.data_type)?;
                                    if build.multi_insert.is_some() {
                                        describeln!(
                                            f,
                                            "- Scalar Quantization does not support multi-insert"
                                        )?;
                                    }
                                }
                            }
                        }
                    }
                    Ok(())
                }

                fn run(
                    input: &IndexSQOperation,
                    checkpoint: Checkpoint<'_>,
                    output: &mut dyn Output,
                ) -> anyhow::Result<QuantBuildResult> {
                    let sq = ScalarQuantized::<$N, $T>::new(input);
                    BuildAndSearch::run(sq, checkpoint, output)
                }
            }

            impl<'a> BuildAndSearch<'a> for ScalarQuantized<'a, $N, $T> {
                type Data = QuantBuildResult;

                // Loads or builds (training the quantizer first) the index, optionally saves
                // it, then runs the configured search phase.
                fn run(
                    self,
                    checkpoint: Checkpoint<'_>,
                    mut output: &mut dyn Output,
                ) -> Result<Self::Data, anyhow::Error> {
                    writeln!(output, "{}", self.input)?;

                    let (index, build_stats, quant_training_time) =
                        match &self.input.index_operation.source {
                            IndexSource::Load(load) => {
                                let index_config: &IndexConfiguration = &load.to_config()?;
                                let index = utils::tokio::block_on(load_index::<_>(
                                    &load.load_path,
                                    index_config,
                                ))?;
                                // Nothing is trained or built when loading.
                                (Arc::new(index), None::<BuildStats>, MicroSeconds::new(0))
                            }
                            IndexSource::Build(build) => {
                                let data: Arc<Matrix<$T>> = Arc::new(datafiles::load_dataset(
                                    datafiles::BinFile(&build.data),
                                )?);

                                // Time only the quantizer training.
                                let start = std::time::Instant::now();
                                let quantizer =
                                    diskann_quantization::scalar::train::ScalarQuantizationParameters::new(
                                        diskann_quantization::num::Positive::new(
                                            self.input.standard_deviations,
                                        )
                                        // The trailing space before `\` matters: the line
                                        // continuation also strips the next line's indent.
                                        .context(
                                            "please file a bug report, this should not have made it past the \
                                             front end",
                                        )?,
                                    )
                                    .train(data.as_view());
                                let quant_training_time: MicroSeconds = start.elapsed().into();

                                let create_index = |data_view: MatrixView<$T>| {
                                    let index = diskann_async::new_quant_index::<$T, _, _>(
                                        self.input.try_as_config()?.build()?,
                                        self.input
                                            .inmem_parameters(data_view.nrows(), data_view.ncols())?,
                                        inmem::WithBits::<$N>::new(quantizer),
                                        common::NoDeletes,
                                    )?;
                                    build::set_start_points(
                                        index.provider(),
                                        data_view,
                                        build.start_point_strategy,
                                    )?;
                                    Ok(index)
                                };

                                let (index, build_stats) = run_build(
                                    &build,
                                    common::Quantized,
                                    None,
                                    output,
                                    create_index,
                                    only_single_insert,
                                )?;
                                if let Some(save_path) = &build.save_path {
                                    utils::tokio::block_on(save_index(index.clone(), save_path))?;
                                }
                                (index, Some(build_stats), quant_training_time)
                            }
                        };

                    // `use_fp_for_search` selects which accessor `run_search_outer` uses:
                    // `common::FullPrecision` or `common::Quantized`.
                    let build = if self.input.use_fp_for_search {
                        run_search_outer(
                            &self.input.index_operation.search_phase,
                            common::FullPrecision,
                            index,
                            build_stats,
                            checkpoint,
                        )?
                    } else {
                        run_search_outer(
                            &self.input.index_operation.search_phase,
                            common::Quantized,
                            index,
                            build_stats,
                            checkpoint,
                        )?
                    };

                    let result = QuantBuildResult {
                        quant_training_time,
                        build,
                    };
                    writeln!(output, "\n\n{}", result)?;
                    Ok(result)
                }
            }
        };
    }

    impl_sq_build!(8, f32);
    impl_sq_build!(4, f32);
    impl_sq_build!(2, f32);
    impl_sq_build!(1, f32);
    impl_sq_build!(8, f16);
    impl_sq_build!(4, f16);
    impl_sq_build!(2, f16);
    impl_sq_build!(1, f16);
    impl_sq_build!(1, i8);
}