use std::sync::Arc;
use diskann::{
ANNResult,
graph::{self, glue},
provider,
};
use diskann_benchmark_runner::utils::{MicroSeconds, percentiles};
use diskann_utils::{future::AsyncFriendly, views::Matrix};
use crate::{
recall,
search::{self, Search, graph::Strategy},
utils,
};
#[derive(Debug)]
pub struct KNN<DP, T, S>
where
DP: provider::DataProvider,
{
index: Arc<graph::DiskANNIndex<DP>>,
queries: Arc<Matrix<T>>,
strategy: Strategy<S>,
}
impl<DP, T, S> KNN<DP, T, S>
where
DP: provider::DataProvider,
{
pub fn new(
index: Arc<graph::DiskANNIndex<DP>>,
queries: Arc<Matrix<T>>,
strategy: Strategy<S>,
) -> anyhow::Result<Arc<Self>> {
strategy.length_compatible(queries.nrows())?;
Ok(Arc::new(Self {
index,
queries,
strategy,
}))
}
}
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct Metrics {
pub comparisons: u32,
pub hops: u32,
}
impl<DP, T, S> Search for KNN<DP, T, S>
where
DP: provider::DataProvider<Context: Default, ExternalId: search::Id>,
S: glue::SearchStrategy<DP, [T], DP::ExternalId> + Clone + AsyncFriendly,
T: AsyncFriendly + Clone,
{
type Id = DP::ExternalId;
type Parameters = graph::search::Knn;
type Output = Metrics;
fn num_queries(&self) -> usize {
self.queries.nrows()
}
fn id_count(&self, parameters: &Self::Parameters) -> search::IdCount {
search::IdCount::Fixed(parameters.k_value())
}
async fn search<O>(
&self,
parameters: &Self::Parameters,
buffer: &mut O,
index: usize,
) -> ANNResult<Self::Output>
where
O: graph::SearchOutputBuffer<DP::ExternalId> + Send,
{
let context = DP::Context::default();
let knn_search = *parameters;
let stats = self
.index
.search(
knn_search,
self.strategy.get(index)?,
&context,
self.queries.row(index),
buffer,
)
.await?;
Ok(Metrics {
comparisons: stats.cmps,
hops: stats.hops,
})
}
}
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Summary {
pub setup: search::Setup,
pub parameters: graph::search::Knn,
pub end_to_end_latencies: Vec<MicroSeconds>,
pub mean_latencies: Vec<f64>,
pub p90_latencies: Vec<MicroSeconds>,
pub p99_latencies: Vec<MicroSeconds>,
pub recall: recall::RecallMetrics,
pub mean_cmps: f64,
pub mean_hops: f64,
}
pub struct Aggregator<'a, I> {
groundtruth: &'a dyn crate::recall::Rows<I>,
recall_k: usize,
recall_n: usize,
}
impl<'a, I> Aggregator<'a, I> {
pub fn new(
groundtruth: &'a dyn crate::recall::Rows<I>,
recall_k: usize,
recall_n: usize,
) -> Self {
Self {
groundtruth,
recall_k,
recall_n,
}
}
}
impl<I> search::Aggregate<graph::search::Knn, I, Metrics> for Aggregator<'_, I>
where
I: crate::recall::RecallCompatible,
{
type Output = Summary;
fn aggregate(
&mut self,
run: search::Run<graph::search::Knn>,
mut results: Vec<search::SearchResults<I, Metrics>>,
) -> anyhow::Result<Summary> {
let recall = match results.first() {
Some(first) => crate::recall::knn(
self.groundtruth,
None,
first.ids().as_rows(),
self.recall_k,
self.recall_n,
true,
)?,
None => anyhow::bail!("Results must be non-empty"),
};
let mut mean_latencies = Vec::with_capacity(results.len());
let mut p90_latencies = Vec::with_capacity(results.len());
let mut p99_latencies = Vec::with_capacity(results.len());
results.iter_mut().for_each(|r| {
match percentiles::compute_percentiles(r.latencies_mut()) {
Ok(values) => {
let percentiles::Percentiles { mean, p90, p99, .. } = values;
mean_latencies.push(mean);
p90_latencies.push(p90);
p99_latencies.push(p99);
}
Err(_) => {
let zero = MicroSeconds::new(0);
mean_latencies.push(0.0);
p90_latencies.push(zero);
p99_latencies.push(zero);
}
}
});
Ok(Summary {
setup: run.setup().clone(),
parameters: *run.parameters(),
end_to_end_latencies: results.iter().map(|r| r.end_to_end_latency()).collect(),
recall,
mean_latencies,
p90_latencies,
p99_latencies,
mean_cmps: utils::average_all(
results
.iter()
.flat_map(|r| r.output().iter().map(|o| o.comparisons)),
),
mean_hops: utils::average_all(
results
.iter()
.flat_map(|r| r.output().iter().map(|o| o.hops)),
),
})
}
}
#[cfg(test)]
mod tests {
use std::num::NonZeroUsize;
use super::*;
use diskann::graph::test::provider;
#[test]
fn test_knn() {
let nearest_neighbors = 5;
let index = search::graph::test_grid_provider();
let mut queries = Matrix::new(0.0f32, 5, index.provider().dim());
queries.row_mut(0).copy_from_slice(&[0.0, 0.0, 0.0, 0.0]);
queries.row_mut(1).copy_from_slice(&[4.0, 0.0, 0.0, 0.0]);
queries.row_mut(2).copy_from_slice(&[0.0, 4.0, 0.0, 0.0]);
queries.row_mut(3).copy_from_slice(&[0.0, 0.0, 4.0, 0.0]);
queries.row_mut(4).copy_from_slice(&[0.0, 0.0, 0.0, 4.0]);
let queries = Arc::new(queries);
let knn = KNN::new(
index,
queries.clone(),
Strategy::broadcast(provider::Strategy::new()),
)
.unwrap();
let rt = crate::tokio::runtime(2).unwrap();
let results = search::search(
knn.clone(),
graph::search::Knn::new(nearest_neighbors, 10, None).unwrap(),
NonZeroUsize::new(2).unwrap(),
&rt,
)
.unwrap();
assert_eq!(results.len(), queries.nrows());
let rows = results.ids().as_rows();
assert_eq!(*rows.row(0).first().unwrap(), 0);
for r in 0..rows.nrows() {
assert_eq!(rows.row(r).len(), nearest_neighbors);
}
const TWO: NonZeroUsize = NonZeroUsize::new(2).unwrap();
let setup = search::Setup {
threads: TWO,
tasks: TWO,
reps: TWO,
};
let parameters = [
search::Run::new(
graph::search::Knn::new(nearest_neighbors, 10, None).unwrap(),
setup.clone(),
),
search::Run::new(
graph::search::Knn::new(nearest_neighbors, 15, None).unwrap(),
setup.clone(),
),
];
let recall_k = nearest_neighbors;
let recall_n = nearest_neighbors;
let all =
search::search_all(knn, parameters, Aggregator::new(rows, recall_k, recall_n)).unwrap();
assert_eq!(all.len(), 2);
for summary in all {
assert_eq!(summary.setup, setup);
assert_eq!(summary.end_to_end_latencies.len(), TWO.get());
assert_eq!(summary.mean_latencies.len(), TWO.get());
assert_eq!(summary.p90_latencies.len(), TWO.get());
assert_eq!(summary.p99_latencies.len(), TWO.get());
assert_ne!(summary.mean_cmps, 0.0);
assert_ne!(summary.mean_hops, 0.0);
let recall = summary.recall;
assert_eq!(recall.recall_k, recall_k);
assert_eq!(recall.recall_n, recall_n);
assert_eq!(recall.num_queries, queries.nrows());
assert_eq!(recall.average, 1.0, "we used a search as the groundtruth");
}
}
#[test]
fn test_knn_error() {
let index = search::graph::test_grid_provider();
let queries = Arc::new(Matrix::new(0.0f32, 1, index.provider().dim()));
let strategy = provider::Strategy::new();
let err = KNN::new(
index,
queries.clone(),
Strategy::collection([strategy, strategy]),
)
.unwrap_err();
let msg = err.to_string();
assert!(
msg.contains("2 strategies were provided when 1 was expected"),
"failed with {msg}"
);
}
}