use std::sync::Arc;
use diskann::{
ANNResult,
graph::{self, glue},
provider,
};
use diskann_utils::{future::AsyncFriendly, views::Matrix};
use crate::search::{self, Search, graph::Strategy};
/// A [`Search`] implementation that runs one multi-hop, label-filtered k-NN
/// search per query row.
///
/// Query `i` is searched using `queries.row(i)`, the strategy returned by
/// `strategy.get(i)`, and the label filter `labels[i]`.
#[derive(Debug)]
pub struct MultiHop<DP, T, S>
where
DP: provider::DataProvider,
{
/// Index that every query is run against.
index: Arc<graph::DiskANNIndex<DP>>,
/// Query vectors, one query per row.
queries: Arc<Matrix<T>>,
/// Search strategy: either broadcast to all queries or one per query.
strategy: Strategy<S>,
/// Per-query label filters. `MultiHop::new` checks that there is exactly one per query row.
labels: Arc<[Arc<dyn graph::index::QueryLabelProvider<DP::InternalId>>]>,
}
impl<DP, T, S> MultiHop<DP, T, S>
where
    DP: provider::DataProvider,
{
    /// Builds a multi-hop filtered search driver over `index`.
    ///
    /// Exactly one label provider is expected per query row: `labels[i]` is the
    /// filter applied when `queries.row(i)` is searched.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// * `strategy` is not length-compatible with the number of query rows
    ///   (as decided by `Strategy::length_compatible`), or
    /// * `labels.len()` differs from `queries.nrows()`.
    ///
    /// The strategy check runs first, so its error wins when both are wrong.
    pub fn new(
        index: Arc<graph::DiskANNIndex<DP>>,
        queries: Arc<Matrix<T>>,
        strategy: Strategy<S>,
        labels: Arc<[Arc<dyn graph::index::QueryLabelProvider<DP::InternalId>>]>,
    ) -> anyhow::Result<Arc<Self>> {
        let num_queries = queries.nrows();
        strategy.length_compatible(num_queries)?;

        let num_labels = labels.len();
        if num_labels != num_queries {
            return Err(anyhow::anyhow!(
                "Number of label providers ({}) must be equal to the number of queries ({})",
                num_labels,
                num_queries
            ));
        }

        let this = Self {
            index,
            queries,
            strategy,
            labels,
        };
        Ok(Arc::new(this))
    }
}
impl<DP, T, S> Search for MultiHop<DP, T, S>
where
    DP: provider::DataProvider<Context: Default, ExternalId: search::Id>,
    S: glue::SearchStrategy<DP, [T], DP::ExternalId> + Clone + AsyncFriendly,
    T: AsyncFriendly + Clone,
{
    type Id = DP::ExternalId;
    type Parameters = graph::search::Knn;
    type Output = super::knn::Metrics;

    /// One search is issued per query row.
    fn num_queries(&self) -> usize {
        self.queries.nrows()
    }

    /// Every query asks for exactly `k` ids, where `k` comes from the k-NN parameters.
    fn id_count(&self, parameters: &Self::Parameters) -> search::IdCount {
        let k = parameters.k_value();
        search::IdCount::Fixed(k)
    }

    /// Runs a multi-hop, label-filtered search for query row `index`.
    ///
    /// Results are written into `buffer`. The comparison and hop counts reported
    /// by the index are returned as `knn::Metrics`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `Strategy::get` and from the index search itself.
    async fn search<O>(
        &self,
        parameters: &Self::Parameters,
        buffer: &mut O,
        index: usize,
    ) -> ANNResult<Self::Output>
    where
        O: graph::SearchOutputBuffer<DP::ExternalId> + Send,
    {
        let context = DP::Context::default();

        // Combine the k-NN parameters with this query's label filter.
        let filter = &*self.labels[index];
        let search_kind = graph::search::MultihopSearch::new(*parameters, filter);

        let strategy = self.strategy.get(index)?;
        let query = self.queries.row(index);

        let stats = self
            .index
            .search(search_kind, strategy, &context, query, buffer)
            .await?;

        let metrics = super::knn::Metrics {
            comparisons: stats.cmps,
            hops: stats.hops,
        };
        Ok(metrics)
    }
}
// Unit tests for `MultiHop`: an end-to-end filtered search over the test grid
// provider, and the validation errors reported by `MultiHop::new`.
#[cfg(test)]
mod tests {
use std::num::NonZeroUsize;
use super::*;
// NOTE: this `provider` (the test provider) shadows the glob-imported `diskann::provider`.
use diskann::graph::{index::QueryLabelProvider, test::provider};
/// Label filter that accepts only even ids.
#[derive(Debug)]
struct NoOdds;
impl graph::index::QueryLabelProvider<u32> for NoOdds {
fn is_match(&self, id: u32) -> bool {
id.is_multiple_of(2)
}
}
#[test]
fn test_multihop() {
let nearest_neighbors = 5;
let index = search::graph::test_grid_provider();
// Five queries: the origin, plus a point at 4.0 along each of the four axes.
// `copy_from_slice` panics on a length mismatch, so this assumes the grid
// provider has dim == 4.
let mut queries = Matrix::new(0.0f32, 5, index.provider().dim());
queries.row_mut(0).copy_from_slice(&[0.0, 0.0, 0.0, 0.0]);
queries.row_mut(1).copy_from_slice(&[4.0, 0.0, 0.0, 0.0]);
queries.row_mut(2).copy_from_slice(&[0.0, 4.0, 0.0, 0.0]);
queries.row_mut(3).copy_from_slice(&[0.0, 0.0, 4.0, 0.0]);
queries.row_mut(4).copy_from_slice(&[0.0, 0.0, 0.0, 4.0]);
let queries = Arc::new(queries);
// A single strategy is broadcast to every query; each query gets its own NoOdds filter.
let multihop = MultiHop::new(
index,
queries.clone(),
Strategy::broadcast(provider::Strategy::new()),
(0..queries.nrows())
.map(|_| -> Arc<dyn QueryLabelProvider<_>> { Arc::new(NoOdds {}) })
.collect(),
)
.unwrap();
let rt = crate::tokio::runtime(2).unwrap();
let results = search::search(
multihop.clone(),
graph::search::Knn::new(nearest_neighbors, 10, None).unwrap(),
NonZeroUsize::new(2).unwrap(),
&rt,
)
.unwrap();
assert_eq!(results.len(), queries.nrows());
let rows = results.ids().as_rows();
// Presumably id 0 sits at the origin of the test grid, so it is the top hit
// for query 0. TODO confirm against `test_grid_provider`.
assert_eq!(*rows.row(0).first().unwrap(), 0);
// Each query returns a full set of `nearest_neighbors` ids, and all of them pass the filter.
for r in 0..rows.nrows() {
assert_eq!(rows.row(r).len(), nearest_neighbors);
for &id in rows.row(r) {
assert_eq!(id % 2, 0, "Found odd ID {} in row {}", id, r);
}
}
// Second phase: run two parameter sets through `search_all`. The ids from the
// search above serve as the groundtruth.
const TWO: NonZeroUsize = NonZeroUsize::new(2).unwrap();
let setup = search::Setup {
threads: TWO,
tasks: TWO,
reps: TWO,
};
let parameters = [
search::Run::new(
graph::search::Knn::new(nearest_neighbors, 10, None).unwrap(),
setup.clone(),
),
search::Run::new(
graph::search::Knn::new(nearest_neighbors, 15, None).unwrap(),
setup.clone(),
),
];
let recall_k = nearest_neighbors;
let recall_n = nearest_neighbors;
let all = search::search_all(
multihop,
parameters,
search::graph::knn::Aggregator::new(rows, recall_k, recall_n),
)
.unwrap();
// One summary per `Run`; each records one latency entry per repetition (reps == 2).
assert_eq!(all.len(), 2);
for summary in all {
assert_eq!(summary.setup, setup);
assert_eq!(summary.end_to_end_latencies.len(), TWO.get());
assert_eq!(summary.mean_latencies.len(), TWO.get());
assert_eq!(summary.p90_latencies.len(), TWO.get());
assert_eq!(summary.p99_latencies.len(), TWO.get());
assert_ne!(summary.mean_cmps, 0.0);
assert_ne!(summary.mean_hops, 0.0);
let recall = summary.recall;
assert_eq!(recall.recall_k, recall_k);
assert_eq!(recall.recall_n, recall_n);
assert_eq!(recall.num_queries, queries.nrows());
assert_eq!(recall.average, 1.0, "we used a search as the groundtruth");
}
}
#[test]
fn test_multihop_error() {
let index = search::graph::test_grid_provider();
// Two queries but three label providers: deliberately one too many labels.
let queries = Arc::new(Matrix::new(0.0f32, 2, index.provider().dim()));
let labels: Arc<[_]> = (0..queries.nrows() + 1)
.map(|_| -> Arc<dyn QueryLabelProvider<_>> { Arc::new(NoOdds {}) })
.collect();
// `strategy` is reused after being moved into the array below, which relies on
// `provider::Strategy` being `Copy`.
let strategy = provider::Strategy::new();
// Both the strategy count (1 vs 2) and the label count (3 vs 2) are wrong here.
// `MultiHop::new` checks the strategy first, so its error is the one reported.
let err = MultiHop::new(
index.clone(),
queries.clone(),
Strategy::collection([strategy]),
labels.clone(),
)
.unwrap_err();
let msg = err.to_string();
assert!(
msg.contains("1 strategy was provided when 2 were expected"),
"failed with {msg}"
);
// With a broadcast strategy the strategy check passes, so the label-count check is reached.
let err = MultiHop::new(
index,
queries.clone(),
Strategy::broadcast(strategy),
labels.clone(),
)
.unwrap_err();
let msg = err.to_string();
assert!(
msg.contains(
"Number of label providers (3) must be equal to the number of queries (2)"
),
"failed with {msg}"
);
}
}