use serde::{Deserialize, Serialize};
use std::{fmt, io::Write};
use diskann::utils::VectorRepr;
use diskann_benchmark_runner::{
benchmark::{PassFail, Regression},
dispatcher::{DispatchRule, FailureScore, MatchScore},
output::Output,
utils::{
datatype::{DataType, Type},
fmt::Table,
num::{relative_change, NonNegativeFinite},
},
Any, Benchmark, CheckDeserialization, Checker, Checkpoint, Input,
};
use diskann_providers::storage::FileStorageProvider;
use half::f16;
use crate::{
backend::disk_index::{
build::{build_disk_index, DiskBuildStats},
search::{search_disk_index, DiskSearchStats},
},
inputs::disk::{DiskIndexLoad, DiskIndexOperation, DiskIndexSource},
};
/// Disk-index benchmark, generic over the vector element type `T`.
///
/// Carries no runtime data: `PhantomData` only pins `T` so that one benchmark
/// instance can be registered per supported element type (see
/// `register_benchmarks`) and selected via `try_match`.
struct DiskIndex<T> {
    // Zero-sized marker tying this instance to element type `T`.
    _vector_type: std::marker::PhantomData<T>,
}
/// Statistics produced by one disk-index benchmark run.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct DiskIndexStats {
    /// Build statistics; `None` when the index was loaded from disk rather
    /// than built during the run (see `DiskIndexSource::Load`).
    pub(super) build: Option<DiskBuildStats>,
    /// Statistics from the search phase (always executed).
    pub(super) search: DiskSearchStats,
}
impl<T> DiskIndex<T>
where
T: VectorRepr,
{
fn new() -> Self {
Self {
_vector_type: std::marker::PhantomData,
}
}
}
impl<T> Benchmark for DiskIndex<T>
where
    T: VectorRepr + 'static,
    Type<T>: DispatchRule<DataType>,
{
    type Input = DiskIndexOperation;
    type Output = DiskIndexStats;

    /// Dispatch hook: this instance handles `input` only when `T` matches the
    /// `data_type` declared by the operation's source (load or build).
    fn try_match(&self, input: &DiskIndexOperation) -> Result<MatchScore, FailureScore> {
        match &input.source {
            DiskIndexSource::Load(load) => Type::<T>::try_match(&load.data_type),
            DiskIndexSource::Build(build) => Type::<T>::try_match(&build.data_type),
        }
    }

    /// Describe this benchmark, delegating to the type dispatcher so the
    /// requested data type is included when an input is available.
    fn description(
        &self,
        f: &mut std::fmt::Formatter<'_>,
        input: Option<&DiskIndexOperation>,
    ) -> std::fmt::Result {
        match input {
            Some(arg) => match &arg.source {
                DiskIndexSource::Load(load) => Type::<T>::description(f, Some(&load.data_type)),
                DiskIndexSource::Build(build) => Type::<T>::description(f, Some(&build.data_type)),
            },
            None => Type::<T>::description(f, None::<&DataType>),
        }
    }

    /// Run the benchmark: obtain an index (either by building it or by using
    /// a prebuilt one), then execute the search phase against it.
    ///
    /// For a `Build` source the freshly built index is loaded back from
    /// `build.save_path`; for a `Load` source no build statistics are
    /// reported (`build: None` in the returned stats).
    fn run(
        &self,
        input: &DiskIndexOperation,
        _checkpoint: Checkpoint<'_>,
        mut output: &mut dyn Output,
    ) -> anyhow::Result<DiskIndexStats> {
        writeln!(output, "{}", input.source)?;
        // Resolve the source into (optional build stats, load configuration).
        // Errors propagate directly from the build arm; the load arm is
        // infallible, so no Result wrapping is needed around the match.
        let (build_stats, index_load) = match &input.source {
            DiskIndexSource::Load(load) => (None, load.clone()),
            DiskIndexSource::Build(build) => {
                let stats = build_disk_index::<T, _>(&FileStorageProvider, build)?;
                let load = DiskIndexLoad {
                    data_type: build.data_type,
                    load_path: build.save_path.clone(),
                };
                (Some(stats), load)
            }
        };
        if let Some(build_stats) = &build_stats {
            writeln!(output, "{}", build_stats)?;
        }
        writeln!(output, "{}", input.search_phase)?;
        let search_stats =
            search_disk_index::<T, _>(&index_load, &input.search_phase, &FileStorageProvider)?;
        writeln!(output, "{}", search_stats)?;
        Ok(DiskIndexStats {
            build: build_stats,
            search: search_stats,
        })
    }
}
/// Register one disk-index regression benchmark per supported element type,
/// so the runner's dispatcher can route an operation to the instance whose
/// `T` matches the input's declared `data_type` (via `Benchmark::try_match`).
pub(super) fn register_benchmarks(benchmarks: &mut diskann_benchmark_runner::registry::Benchmarks) {
    benchmarks.register_regression("disk-index-f32", DiskIndex::<f32>::new());
    benchmarks.register_regression("disk-index-f16", DiskIndex::<f16>::new());
    benchmarks.register_regression("disk-index-u8", DiskIndex::<u8>::new());
    benchmarks.register_regression("disk-index-i8", DiskIndex::<i8>::new());
}
/// Per-metric regression tolerances for the disk-index benchmark.
///
/// Each value is a non-negative fraction: 0.10 permits a 10% regression
/// (see `example()`, and the `* 100.0` rendering in the report table).
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub(super) struct DiskIndexTolerance {
    /// Allowed relative increase in index build time.
    build_time_regression: NonNegativeFinite,
    /// Allowed relative decrease in queries per second.
    qps_regression: NonNegativeFinite,
    /// Allowed relative decrease in recall.
    recall_regression: NonNegativeFinite,
    /// Allowed relative increase in mean I/O count.
    mean_ios_regression: NonNegativeFinite,
    /// Allowed relative increase in mean comparison count.
    mean_comps_regression: NonNegativeFinite,
    /// Allowed relative increase in mean query latency.
    mean_latency_regression: NonNegativeFinite,
    /// Allowed relative increase in 95th-percentile query latency.
    p95_latency_regression: NonNegativeFinite,
}
impl DiskIndexTolerance {
    /// Tag string identifying this input type; the single source of truth
    /// used by the `Input::tag` implementation below.
    const fn tag() -> &'static str {
        "disk-index-tolerance"
    }
}
impl CheckDeserialization for DiskIndexTolerance {
    /// No cross-field invariants to verify: each field's own
    /// `NonNegativeFinite` type already constrains its value, so every
    /// successfully deserialized tolerance set is accepted as-is.
    fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> {
        Ok(())
    }
}
impl Input for DiskIndexTolerance {
    /// Tag used to identify this input type in serialized configurations.
    fn tag() -> &'static str {
        Self::tag()
    }
    /// Deserialize a tolerance set from JSON and hand it to the checker,
    /// which produces the type-erased `Any` wrapper used by the runner.
    fn try_deserialize(
        serialized: &serde_json::Value,
        checker: &mut Checker,
    ) -> anyhow::Result<Any> {
        checker.any(Self::deserialize(serialized)?)
    }
    /// Build an example tolerance configuration: 10% for most metrics, with
    /// a tighter 1% budget for recall.
    fn example() -> anyhow::Result<serde_json::Value> {
        // The `const` + `match` dance forces evaluation at compile time, so
        // an out-of-range literal becomes a build error instead of a panic
        // at run time.
        const DEFAULT: NonNegativeFinite = match NonNegativeFinite::new(0.10) {
            Ok(v) => v,
            Err(_) => panic!("use a non-negative finite value"),
        };
        const RECALL: NonNegativeFinite = match NonNegativeFinite::new(0.01) {
            Ok(v) => v,
            Err(_) => panic!("use a non-negative finite value"),
        };
        // `?` converts the serde_json error into anyhow::Error, so the
        // Ok-wrapping here is required, not redundant.
        Ok(serde_json::to_value(DiskIndexTolerance {
            build_time_regression: DEFAULT,
            qps_regression: DEFAULT,
            recall_regression: RECALL,
            mean_ios_regression: DEFAULT,
            mean_comps_regression: DEFAULT,
            mean_latency_regression: DEFAULT,
            p95_latency_regression: DEFAULT,
        })?)
    }
}
/// Whether an increase in a metric's value counts as improvement or
/// regression (e.g. latency is lower-is-better, QPS is higher-is-better).
#[derive(Clone, Copy)]
enum Direction {
    /// Smaller values are better; growth beyond tolerance is a regression.
    LowerIsBetter,
    /// Larger values are better; shrinkage beyond tolerance is a regression.
    HigherIsBetter,
}
/// One row of the regression report: a single metric compared across runs.
#[derive(Debug, Serialize)]
struct MetricComparison {
    /// Metric name, possibly prefixed with the search-L value (e.g. "L50:qps").
    metric: String,
    /// Value observed in the baseline run.
    before: f64,
    /// Value observed in the candidate run.
    after: f64,
    /// Formatted relative change (e.g. "1.234%"), or "invalid" when the
    /// change could not be computed.
    change_pct: String,
    /// Tolerance as a fraction (0.10 == 10%); rendered with `* 100.0`.
    tolerance_pct: f64,
    /// Whether this metric stayed within its tolerance.
    passed: bool,
    /// Empty on pass; "REGRESSION" or an error message on failure.
    remark: String,
}
/// Full regression-check report. Used as both the pass and fail payload of
/// `PassFail`; only the wrapping variant differs.
#[derive(Debug, Serialize)]
struct DiskIndexCheckResult {
    // One entry per metric compared (build time plus per-L search metrics).
    comparisons: Vec<MetricComparison>,
}
impl fmt::Display for DiskIndexCheckResult {
    /// Render the comparisons as an aligned six-column table, one row per
    /// metric. The remark column is populated only for rows that carry a
    /// non-empty remark (i.e. failures).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let columns = ["Metric", "Before", "After", "Change", "Tolerance", "Remark"];
        let mut table = Table::new(columns, self.comparisons.len());
        for (row_idx, cmp) in self.comparisons.iter().enumerate() {
            let mut row = table.row(row_idx);
            let cells = [
                cmp.metric.clone(),
                format!("{:.3}", cmp.before),
                format!("{:.3}", cmp.after),
                cmp.change_pct.clone(),
                format!("{:.1}%", cmp.tolerance_pct * 100.0),
            ];
            for (col, cell) in cells.into_iter().enumerate() {
                row.insert(cell, col);
            }
            // Leave the remark cell untouched (blank) for passing rows.
            if !cmp.remark.is_empty() {
                row.insert(cmp.remark.clone(), 5);
            }
        }
        table.fmt(f)
    }
}
/// Compare one metric between a baseline (`before`) and candidate (`after`)
/// run, producing a report row.
///
/// `passed` is the running overall verdict: it is forced to `false` whenever
/// this metric fails (change beyond tolerance, or the change could not be
/// computed), and is left untouched otherwise.
fn check_metric(
    name: String,
    direction: Direction,
    before: f64,
    after: f64,
    tolerance: NonNegativeFinite,
    passed: &mut bool,
) -> MetricComparison {
    let (change_pct, remark, metric_passed) = match relative_change(before, after) {
        // The change could not be computed at all: report it and fail the row.
        Err(e) => ("invalid".to_string(), e.to_string(), false),
        Ok(change) => {
            // Normalize the signed change so that a larger value is always
            // worse: for higher-is-better metrics the sign is flipped.
            let regression = match direction {
                Direction::LowerIsBetter => change,
                Direction::HigherIsBetter => -change,
            };
            let within = regression <= tolerance.get();
            let remark = if within {
                String::new()
            } else {
                "REGRESSION".to_string()
            };
            (format!("{:.3}%", change * 100.0), remark, within)
        }
    };
    // A single failing metric fails the overall check.
    if !metric_passed {
        *passed = false;
    }
    MetricComparison {
        metric: name,
        before,
        after,
        change_pct,
        tolerance_pct: tolerance.get(),
        passed: metric_passed,
        remark,
    }
}
impl<T> Regression for DiskIndex<T>
where
T: VectorRepr + 'static,
Type<T>: DispatchRule<DataType>,
{
type Tolerances = DiskIndexTolerance;
type Pass = DiskIndexCheckResult;
type Fail = DiskIndexCheckResult;
fn check(
&self,
tolerances: &DiskIndexTolerance,
_input: &DiskIndexOperation,
before: &DiskIndexStats,
after: &DiskIndexStats,
) -> anyhow::Result<PassFail<DiskIndexCheckResult, DiskIndexCheckResult>> {
use Direction::{HigherIsBetter, LowerIsBetter};
let mut passed = true;
let mut comparisons = Vec::new();
if let (Some(b_build), Some(a_build)) = (&before.build, &after.build) {
comparisons.push(check_metric(
"build_time".to_string(),
LowerIsBetter,
b_build.build_time_seconds(),
a_build.build_time_seconds(),
tolerances.build_time_regression,
&mut passed,
));
}
anyhow::ensure!(
before.search.search_results_per_l.len() == after.search.search_results_per_l.len(),
"before has {} search_l entries but after has {}",
before.search.search_results_per_l.len(),
after.search.search_results_per_l.len(),
);
for (b_sr, a_sr) in before
.search
.search_results_per_l
.iter()
.zip(after.search.search_results_per_l.iter())
{
anyhow::ensure!(
b_sr.search_l == a_sr.search_l,
"search_l mismatch: before={} after={}",
b_sr.search_l,
a_sr.search_l,
);
let prefix = if before.search.search_results_per_l.len() > 1 {
format!("L{}:", b_sr.search_l)
} else {
String::new()
};
comparisons.push(check_metric(
format!("{prefix}qps"),
HigherIsBetter,
b_sr.qps as f64,
a_sr.qps as f64,
tolerances.qps_regression,
&mut passed,
));
comparisons.push(check_metric(
format!("{prefix}recall"),
HigherIsBetter,
b_sr.recall as f64,
a_sr.recall as f64,
tolerances.recall_regression,
&mut passed,
));
comparisons.push(check_metric(
format!("{prefix}mean_latency"),
LowerIsBetter,
b_sr.mean_latency,
a_sr.mean_latency,
tolerances.mean_latency_regression,
&mut passed,
));
comparisons.push(check_metric(
format!("{prefix}p95_latency"),
LowerIsBetter,
b_sr.p95_latency.as_f64(),
a_sr.p95_latency.as_f64(),
tolerances.p95_latency_regression,
&mut passed,
));
comparisons.push(check_metric(
format!("{prefix}mean_ios"),
LowerIsBetter,
b_sr.mean_ios,
a_sr.mean_ios,
tolerances.mean_ios_regression,
&mut passed,
));
comparisons.push(check_metric(
format!("{prefix}mean_comparisons"),
LowerIsBetter,
b_sr.mean_comparisons,
a_sr.mean_comparisons,
tolerances.mean_comps_regression,
&mut passed,
));
}
let result = DiskIndexCheckResult { comparisons };
if passed {
Ok(PassFail::Pass(result))
} else {
Ok(PassFail::Fail(result))
}
}
}