use std::{fmt::Debug, num::NonZeroUsize};
use diskann_utils::future::SendFuture;
use thiserror::Error;
use super::Search;
use crate::{
ANNError, ANNErrorKind, ANNResult,
error::IntoANNResult,
graph::{
glue::{SearchExt, SearchPostProcess, SearchStrategy},
index::{DiskANNIndex, SearchStats},
search::record::NoopSearchRecord,
search_output_buffer::SearchOutputBuffer,
},
provider::{BuildQueryComputer, DataProvider},
};
/// Validation failures produced by [`Knn::new`] when the supplied search
/// parameters are rejected.
#[derive(Debug, Error)]
pub enum KnnSearchError {
/// `k_value` was greater than `l_value`; both values are carried so the
/// message can report them.
#[error("l_value ({l_value}) cannot be less than k_value ({k_value})")]
LLessThanK { l_value: usize, k_value: usize },
/// An explicit beam width of zero was supplied.
#[error("beam width cannot be zero")]
BeamWidthZero,
/// `k_value` was zero.
#[error("k_value cannot be zero")]
KZero,
/// `l_value` was zero.
#[error("l_value cannot be zero")]
LZero,
}
impl From<KnnSearchError> for ANNError {
#[track_caller]
fn from(err: KnnSearchError) -> Self {
Self::new(ANNErrorKind::IndexError, err)
}
}
/// Validated parameter set consumed by the `Search` implementation for this type.
///
/// Every field is a `NonZeroUsize`; [`Knn::new`] additionally guarantees that
/// `k_value <= l_value`.
#[derive(Debug, Clone, Copy)]
pub struct Knn {
/// NOTE(review): presumably the number of neighbors requested -- confirm.
/// The `search` body in this file does not read this field.
k_value: NonZeroUsize,
/// Passed as the first argument to `DiskANNIndex::search_scratch` when a
/// search starts.
l_value: NonZeroUsize,
/// Passed (wrapped in `Some`) to `search_internal`; [`Knn::new`] uses 1 when
/// the caller supplies `None`.
beam_width: NonZeroUsize,
}
impl Knn {
    /// Builds a validated parameter set.
    ///
    /// `beam_width` falls back to 1 when `None` is given.
    ///
    /// # Errors
    ///
    /// Checks run in this order and the first failure is returned:
    /// 1. [`KnnSearchError::KZero`] if `k_value == 0`
    /// 2. [`KnnSearchError::LZero`] if `l_value == 0`
    /// 3. [`KnnSearchError::LLessThanK`] if `l_value < k_value`
    /// 4. [`KnnSearchError::BeamWidthZero`] if `beam_width == Some(0)`
    pub fn new(
        k_value: usize,
        l_value: usize,
        beam_width: Option<usize>,
    ) -> Result<Self, KnnSearchError> {
        let Some(k_value) = NonZeroUsize::new(k_value) else {
            return Err(KnnSearchError::KZero);
        };
        let Some(l_value) = NonZeroUsize::new(l_value) else {
            return Err(KnnSearchError::LZero);
        };

        if l_value < k_value {
            return Err(KnnSearchError::LLessThanK {
                l_value: l_value.get(),
                k_value: k_value.get(),
            });
        }

        // `NonZeroUsize::MIN` is 1, the default beam width when none is requested.
        let beam_width = beam_width.map_or(Ok(NonZeroUsize::MIN), |requested| {
            NonZeroUsize::new(requested).ok_or(KnnSearchError::BeamWidthZero)
        })?;

        Ok(Self {
            k_value,
            l_value,
            beam_width,
        })
    }

    /// Shorthand for [`Knn::new`] with no explicit beam width.
    pub fn new_default(k_value: usize, l_value: usize) -> Result<Self, KnnSearchError> {
        Self::new(k_value, l_value, None)
    }

    /// Returns the stored `k_value`.
    #[inline]
    pub fn k_value(&self) -> NonZeroUsize {
        self.k_value
    }

    /// Returns the stored `l_value`.
    #[inline]
    pub fn l_value(&self) -> NonZeroUsize {
        self.l_value
    }

    /// Returns the stored beam width.
    #[inline]
    pub fn beam_width(&self) -> NonZeroUsize {
        self.beam_width
    }
}
impl<DP, S, T> Search<DP, S, T> for Knn
where
DP: DataProvider,
S: SearchStrategy<DP, T>,
T: Copy + Send + Sync,
{
type Output = SearchStats;
fn search<O, PP, OB>(
self,
index: &DiskANNIndex<DP>,
strategy: &S,
processor: PP,
context: &DP::Context,
query: T,
output: &mut OB,
) -> impl SendFuture<ANNResult<Self::Output>>
where
O: Send,
PP: for<'a> SearchPostProcess<S::SearchAccessor<'a>, T, O> + Send + Sync,
OB: SearchOutputBuffer<O> + Send + ?Sized,
{
async move {
let mut accessor = strategy
.search_accessor(&index.data_provider, context)
.into_ann_result()?;
let computer = accessor.build_query_computer(query).into_ann_result()?;
let start_ids = accessor.starting_points().await?;
let mut scratch = index.search_scratch(self.l_value.get(), start_ids.len());
let stats = index
.search_internal(
Some(self.beam_width.get()),
&start_ids,
&mut accessor,
&computer,
&mut scratch,
&mut NoopSearchRecord::new(),
)
.await?;
let result_count = processor
.post_process(&mut accessor, query, &computer, scratch.best.iter(), output)
.await
.into_ann_result()?;
Ok(stats.finish(result_count as u32))
}
}
}
/// A [`Knn`] parameter set paired with a mutable borrow of a recorder.
///
/// The `Search` implementation for this type passes `recorder` to
/// `search_internal` in the position where the plain [`Knn`] implementation
/// passes a `NoopSearchRecord`.
#[derive(Debug)]
pub struct RecordedKnn<'r, SR: ?Sized> {
/// Search parameters; `l_value` and `beam_width` are read from here.
pub inner: Knn,
/// Recorder borrowed for `'r` and forwarded to `search_internal`.
pub recorder: &'r mut SR,
}
impl<'r, SR: ?Sized> RecordedKnn<'r, SR> {
pub fn new(inner: Knn, recorder: &'r mut SR) -> Self {
Self { inner, recorder }
}
}
impl<'r, DP, S, T, SR> Search<DP, S, T> for RecordedKnn<'r, SR>
where
DP: DataProvider,
S: SearchStrategy<DP, T>,
T: Copy + Send + Sync,
SR: super::record::SearchRecord<DP::InternalId> + ?Sized,
{
type Output = SearchStats;
fn search<O, PP, OB>(
self,
index: &DiskANNIndex<DP>,
strategy: &S,
processor: PP,
context: &DP::Context,
query: T,
output: &mut OB,
) -> impl SendFuture<ANNResult<Self::Output>>
where
O: Send,
PP: for<'a> SearchPostProcess<S::SearchAccessor<'a>, T, O> + Send + Sync,
OB: SearchOutputBuffer<O> + Send + ?Sized,
{
async move {
let mut accessor = strategy
.search_accessor(&index.data_provider, context)
.into_ann_result()?;
let computer = accessor.build_query_computer(query).into_ann_result()?;
let start_ids = accessor.starting_points().await?;
let mut scratch = index.search_scratch(self.inner.l_value.get(), start_ids.len());
let stats = index
.search_internal(
Some(self.inner.beam_width.get()),
&start_ids,
&mut accessor,
&computer,
&mut scratch,
self.recorder,
)
.await?;
let result_count = processor
.post_process(&mut accessor, query, &computer, scratch.best.iter(), output)
.await
.into_ann_result()?;
Ok(stats.finish(result_count as u32))
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises every acceptance and rejection path of `Knn::new`.
    #[test]
    fn test_knn_search_validation() {
        // Accepted: k < l with default beam, k < l with explicit beam, and k == l.
        for (k, l, beam) in [(10, 100, None), (10, 100, Some(4)), (10, 10, None)] {
            assert!(Knn::new(k, l, beam).is_ok());
        }

        // Rejected: one case per error variant.
        let zero_k = Knn::new(0, 100, None);
        assert!(matches!(zero_k, Err(KnnSearchError::KZero)));

        let zero_l = Knn::new(10, 0, None);
        assert!(matches!(zero_l, Err(KnnSearchError::LZero)));

        let l_below_k = Knn::new(100, 10, None);
        assert!(matches!(
            l_below_k,
            Err(KnnSearchError::LLessThanK { .. })
        ));

        let zero_beam = Knn::new(10, 100, Some(0));
        assert!(matches!(zero_beam, Err(KnnSearchError::BeamWidthZero)));
    }
}