use std::{future::Future, sync::Arc};
use diskann_utils::Reborrow;
use diskann_utils::future::AssertSend;
use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction};
use crate::{
ANNError, ANNResult,
error::{ErrorExt, StandardError},
graph::{AdjacencyList, SearchOutputBuffer, workingset},
neighbor::Neighbor,
provider::{
Accessor, AsNeighbor, AsNeighborMut, BuildDistanceComputer, BuildQueryComputer,
DataProvider, HasId, NeighborAccessor,
},
utils::VectorId,
};
pub trait SearchExt: Accessor {
fn starting_points(&self)
-> impl std::future::Future<Output = ANNResult<Vec<Self::Id>>> + Send;
fn terminate_early(&mut self) -> bool {
false
}
fn is_not_start_point(
&self,
) -> impl std::future::Future<
Output = ANNResult<impl Fn(Self::Id) -> bool + Send + Sync + 'static>,
> + Send {
async move {
let starting_points = self.starting_points().await?;
Ok(move |id| !starting_points.contains(&id))
}
}
}
/// A read-only test over items of type `T`.
pub trait Predicate<T> {
/// Evaluates the predicate for `item` without mutating `self`.
fn eval(&self, item: &T) -> bool;
}
/// A test over items of type `T` that is allowed to update its own state while
/// being evaluated.
pub trait PredicateMut<T> {
/// Evaluates the predicate for `item`; the implementation may mutate `self`.
fn eval_mut(&mut self, item: &T) -> bool;
}
/// Marker trait for types that provide both the read-only ([`Predicate`]) and the
/// mutating ([`PredicateMut`]) form of a predicate. It adds no methods of its own.
pub trait HybridPredicate<T>: Predicate<T> + PredicateMut<T> {}
/// Predicate backed by a mutably borrowed `hashbrown::HashSet`.
///
/// * [`Predicate::eval`] answers "is this item absent from the set?".
/// * [`PredicateMut::eval_mut`] inserts a clone of the item and reports whether the
///   insertion added a new entry.
pub struct NotInMut<'a, K>(&'a mut hashbrown::HashSet<K>);

impl<'a, K> NotInMut<'a, K> {
    /// Wraps `set`; the borrow lasts for the lifetime of the returned value.
    pub fn new(set: &'a mut hashbrown::HashSet<K>) -> Self {
        NotInMut(set)
    }
}

impl<T: Eq + std::hash::Hash> Predicate<T> for NotInMut<'_, T> {
    /// Returns `true` when `item` is not currently in the wrapped set.
    fn eval(&self, item: &T) -> bool {
        let already_present = self.0.contains(item);
        !already_present
    }
}

impl<T: Clone + Eq + std::hash::Hash> PredicateMut<T> for NotInMut<'_, T> {
    /// Inserts a clone of `item` into the wrapped set and returns the result of
    /// `HashSet::insert`, i.e. `true` only if the item was not present before.
    fn eval_mut(&mut self, item: &T) -> bool {
        let owned = T::clone(item);
        self.0.insert(owned)
    }
}

impl<T: Clone + Eq + std::hash::Hash> HybridPredicate<T> for NotInMut<'_, T> {}
pub trait ExpandBeam<T>: BuildQueryComputer<T> + AsNeighbor + Sized {
fn expand_beam<Itr, P, F>(
&mut self,
ids: Itr,
computer: &Self::QueryComputer,
mut pred: P,
mut on_neighbors: F,
) -> impl std::future::Future<Output = ANNResult<()>> + Send
where
Itr: Iterator<Item = Self::Id> + Send,
P: HybridPredicate<Self::Id> + Send + Sync,
F: FnMut(f32, Self::Id) + Send,
{
async move {
let mut neighbors = AdjacencyList::new();
for id in ids {
self.get_neighbors(id, &mut neighbors).send().await?;
neighbors.retain(|i| pred.eval(i));
self.distances_unordered(neighbors.iter().copied(), computer, |distance, id| {
if pred.eval_mut(&id) {
on_neighbors(distance, id);
}
})
.send()
.await
.allow_transient("allowing transient error in beam expansion")?;
}
Ok(())
}
}
}
/// Strategy that builds the accessor used to search a `Provider` with queries of
/// type `T`.
pub trait SearchStrategy<Provider, T>: Send + Sync
where
Provider: DataProvider,
{
/// Query-side distance computer. For every pair of lifetimes it must implement
/// `PreprocessedDistanceFunction` over the search accessor's `ElementRef`, with
/// `f32` as the second type argument (presumably the distance value type).
type QueryComputer: for<'a, 'b> PreprocessedDistanceFunction<
<Self::SearchAccessor<'a> as Accessor>::ElementRef<'b>,
f32,
> + Send
+ Sync
+ 'static;
/// Error type returned when [`SearchStrategy::search_accessor`] fails.
type SearchAccessorError: StandardError;
/// Accessor handed to the search. It must support beam expansion for `T` using
/// `Self::QueryComputer`, use the provider's internal id type, and implement
/// [`SearchExt`].
type SearchAccessor<'a>: ExpandBeam<T, QueryComputer = Self::QueryComputer, Id = Provider::InternalId>
+ SearchExt;
/// Creates a search accessor borrowing `self`, `provider` and `context` for `'a`.
///
/// # Errors
/// Returns `Self::SearchAccessorError` if the accessor cannot be created.
fn search_accessor<'a>(
&'a self,
provider: &'a Provider,
context: &'a Provider::Context,
) -> Result<Self::SearchAccessor<'a>, Self::SearchAccessorError>;
}
/// Extends a [`SearchStrategy`] with a default post-processor that emits outputs of
/// type `O`. `O` defaults to the provider's internal id type.
pub trait DefaultPostProcessor<Provider, T, O = <Provider as DataProvider>::InternalId>:
SearchStrategy<Provider, T>
where
Provider: DataProvider,
O: Send,
{
/// Post-processor type; it must work with the strategy's search accessor for
/// every accessor lifetime.
type Processor: for<'a> SearchPostProcess<Self::SearchAccessor<'a>, T, O> + Send + Sync;
/// Returns an instance of the default post-processor.
fn default_post_processor(&self) -> Self::Processor;
}
/// Shorthand bound for "a [`SearchStrategy`] that also provides a
/// [`DefaultPostProcessor`]". The trait declares no items of its own.
pub trait DefaultSearchStrategy<Provider, T, O = <Provider as DataProvider>::InternalId>:
SearchStrategy<Provider, T> + DefaultPostProcessor<Provider, T, O>
where
Provider: DataProvider,
O: Send,
{
}
/// Blanket implementation: every type that satisfies both supertraits is
/// automatically a `DefaultSearchStrategy`.
impl<S, Provider, T, O> DefaultSearchStrategy<Provider, T, O> for S
where
S: SearchStrategy<Provider, T> + DefaultPostProcessor<Provider, T, O>,
Provider: DataProvider,
O: Send,
{
}
/// Expands to the two items an implementation of `DefaultPostProcessor` needs:
/// the `Processor` associated type and a `default_post_processor` method that
/// returns the `Default` value of that type.
///
/// Invoke it inside the `impl` block, passing the processor type, which must
/// implement `Default`.
#[macro_export]
macro_rules! default_post_processor {
    ($Processor:ty) => {
        type Processor = $Processor;

        fn default_post_processor(&self) -> Self::Processor {
            <$Processor as ::std::default::Default>::default()
        }
    };
}
/// Post-processing stage that consumes search candidates and writes results of type
/// `O` (by default the accessor's id type) into an output buffer.
pub trait SearchPostProcess<A, T, O = <A as HasId>::Id>
where
A: BuildQueryComputer<T>,
{
/// Error produced by [`SearchPostProcess::post_process`].
type Error: StandardError;
/// Processes `candidates` and writes into `output`.
///
/// * `accessor` - accessor available to the processor.
/// * `query` - the query value, passed by value.
/// * `computer` - query computer built for `accessor`.
/// * `candidates` - iterator of [`Neighbor`] values to process.
/// * `output` - destination buffer.
///
/// Resolves to a `usize`; the `CopyIds` implementation in this file returns the
/// value produced by `output.extend`, which is presumably the number of entries
/// written.
fn post_process<I, B>(
&self,
accessor: &mut A,
query: T,
computer: &<A as BuildQueryComputer<T>>::QueryComputer,
candidates: I,
output: &mut B,
) -> impl std::future::Future<Output = Result<usize, Self::Error>> + Send
where
I: Iterator<Item = Neighbor<A::Id>> + Send,
B: SearchOutputBuffer<O> + Send + ?Sized;
}
/// Post-processor that forwards each candidate's `(id, distance)` pair to the
/// output buffer unchanged. The accessor, query and computer are not used.
#[derive(Debug, Default, Clone, Copy)]
pub struct CopyIds;

impl<A, T> SearchPostProcess<A, T> for CopyIds
where
    A: BuildQueryComputer<T>,
{
    /// Copying cannot fail.
    type Error = std::convert::Infallible;

    /// Extends `output` with the `(id, distance)` pairs of `candidates` and returns
    /// an already-completed future carrying the value reported by `output.extend`.
    ///
    /// The copy happens eagerly when this method is called, not when the returned
    /// future is polled.
    fn post_process<I, B>(
        &self,
        _accessor: &mut A,
        _query: T,
        _computer: &A::QueryComputer,
        candidates: I,
        output: &mut B,
    ) -> impl std::future::Future<Output = Result<usize, Self::Error>> + Send
    where
        I: Iterator<Item = Neighbor<A::Id>> + Send,
        B: SearchOutputBuffer<A::Id> + Send + ?Sized,
    {
        let pairs = candidates.map(|neighbor| (neighbor.id, neighbor.distance));
        let written = output.extend(pairs);
        std::future::ready(Ok(written))
    }
}
/// One stage of a post-processing chain. Unlike [`SearchPostProcess`], a step
/// receives the processor that follows it (`next`) and decides how to invoke it.
pub trait SearchPostProcessStep<A, T, O = <A as HasId>::Id>
where
A: BuildQueryComputer<T>,
{
/// Error type of this step, parameterised by the error type of the next stage.
type Error<NextError>: StandardError
where
NextError: StandardError;
/// Accessor type that the next stage operates on; it shares `A`'s id type.
type NextAccessor: BuildQueryComputer<T, Id = A::Id>;
/// Runs this step.
///
/// * `next` - the following stage of the chain.
/// * `accessor`, `query`, `computer` - search state available to the step.
/// * `candidates` - iterator of [`Neighbor`] values to process.
/// * `output` - destination buffer.
///
/// Resolves to a `usize` on success, or to `Self::Error<Next::Error>` on failure.
fn post_process_step<I, B, Next>(
&self,
next: &Next,
accessor: &mut A,
query: T,
computer: &<A as BuildQueryComputer<T>>::QueryComputer,
candidates: I,
output: &mut B,
) -> impl std::future::Future<Output = Result<usize, Self::Error<Next::Error>>> + Send
where
I: Iterator<Item = Neighbor<A::Id>> + Send,
B: SearchOutputBuffer<O> + Send + ?Sized,
Next: SearchPostProcess<Self::NextAccessor, T, O> + Sync;
}
/// Post-processing step that removes the accessor's starting points from the
/// candidate stream before handing the remaining candidates to the next stage.
#[derive(Debug, Default, Clone, Copy)]
pub struct FilterStartPoints;

impl<A, T, O> SearchPostProcessStep<A, T, O> for FilterStartPoints
where
    A: BuildQueryComputer<T> + SearchExt,
    T: Copy + Send + Sync,
{
    /// Every failure, including one from the next stage, is reported as `ANNError`.
    type Error<NextError>
        = ANNError
    where
        NextError: StandardError;

    /// The accessor is passed on to the next stage unchanged.
    type NextAccessor = A;

    /// Obtains the "is not a start point" predicate from `accessor`, filters
    /// `candidates` with it, and delegates to `next.post_process`.
    ///
    /// Returns whatever count the next stage reports.
    ///
    /// # Errors
    /// * Propagates an error from `SearchExt::is_not_start_point`.
    /// * Converts an error from `next` into `ANNError` and attaches the context
    ///   `"after filtering start points"`.
    async fn post_process_step<I, B, Next>(
        &self,
        next: &Next,
        accessor: &mut A,
        query: T,
        computer: &A::QueryComputer,
        candidates: I,
        output: &mut B,
    ) -> ANNResult<usize>
    where
        I: Iterator<Item = Neighbor<A::Id>> + Send,
        B: SearchOutputBuffer<O> + Send + ?Sized,
        Next: SearchPostProcess<A, T, O> + Sync,
    {
        let filter = accessor.is_not_start_point().await?;
        next.post_process(
            accessor,
            query,
            computer,
            candidates.filter(|n| filter(n.id)),
            output,
        )
        .await
        .map_err(|err| {
            // The target type must be spelled out: `Into::into` alone leaves it as an
            // inference variable, and the `.context(..)` method call below cannot be
            // resolved on an unknown type (E0282).
            let err: ANNError = err.into();
            err.context("after filtering start points")
        })
    }
}
/// Composition of a post-processing step (`head`) with the processor that follows
/// it (`tail`). The pair itself implements [`SearchPostProcess`].
#[derive(Debug, Default, Clone, Copy)]
pub struct Pipeline<Head, Tail> {
    head: Head,
    tail: Tail,
}

impl<Head, Tail> Pipeline<Head, Tail> {
    /// Creates a pipeline that runs `head` with `tail` as its next stage.
    pub fn new(head: Head, tail: Tail) -> Self {
        Pipeline { head, tail }
    }
}

impl<A, T, O, Head, Tail> SearchPostProcess<A, T, O> for Pipeline<Head, Tail>
where
    A: BuildQueryComputer<T>,
    Head: SearchPostProcessStep<A, T, O>,
    Tail: SearchPostProcess<Head::NextAccessor, T, O> + Sync,
{
    /// The head step's error type, instantiated with the tail's error type.
    type Error = Head::Error<Tail::Error>;

    /// Forwards all arguments to `head.post_process_step`, supplying `tail` as the
    /// next stage, and returns the future produced by that call.
    fn post_process<I, B>(
        &self,
        accessor: &mut A,
        query: T,
        computer: &<A as BuildQueryComputer<T>>::QueryComputer,
        candidates: I,
        output: &mut B,
    ) -> impl std::future::Future<Output = Result<usize, Self::Error>> + Send
    where
        I: Iterator<Item = Neighbor<A::Id>> + Send,
        B: SearchOutputBuffer<O> + Send + ?Sized,
    {
        let Pipeline { head, tail } = self;
        head.post_process_step(tail, accessor, query, computer, candidates, output)
    }
}
/// A [`SearchStrategy`] that can additionally be used for insertion: it supplies a
/// prune strategy and an accessor for the search performed during insert.
pub trait InsertStrategy<Provider, T>: SearchStrategy<Provider, T> + 'static
where
    Provider: DataProvider,
{
    /// Prune strategy associated with this insert strategy.
    type PruneStrategy: PruneStrategy<Provider>;

    /// Returns the prune strategy.
    fn prune_strategy(&self) -> Self::PruneStrategy;

    /// Creates the accessor used for the search performed as part of an insert.
    ///
    /// The default implementation simply delegates to
    /// [`SearchStrategy::search_accessor`].
    ///
    /// # Errors
    /// Returns `Self::SearchAccessorError` if the accessor cannot be created.
    fn insert_search_accessor<'a>(
        &'a self,
        provider: &'a Provider,
        context: &'a Provider::Context,
    ) -> Result<Self::SearchAccessor<'a>, Self::SearchAccessorError> {
        <Self as SearchStrategy<Provider, T>>::search_accessor(self, provider, context)
    }
}
/// Strategy that supplies the pieces needed for pruning over a `Provider`: a
/// working-set type, a distance computer, and an accessor.
pub trait PruneStrategy<Provider>: Send + Sync + 'static
where
Provider: DataProvider,
{
/// Working-set type; the prune accessor must be able to `Fill` it.
type WorkingSet: Send + Sync;
/// Distance computer between two elements. It must implement `DistanceFunction`
/// over any two `ElementRef`s of the prune accessor, with `f32` as the third type
/// argument (presumably the distance value type).
type DistanceComputer: for<'a, 'b, 'c, 'd> DistanceFunction<
<Self::PruneAccessor<'a> as Accessor>::ElementRef<'b>,
<Self::PruneAccessor<'c> as Accessor>::ElementRef<'d>,
f32,
> + Send
+ Sync
+ 'static;
/// Accessor used while pruning. It uses the provider's internal id type, builds
/// `Self::DistanceComputer`, provides `AsNeighborMut`, and can fill
/// `Self::WorkingSet`.
type PruneAccessor<'a>: Accessor<Id = Provider::InternalId>
+ BuildDistanceComputer<DistanceComputer = Self::DistanceComputer>
+ AsNeighborMut
+ workingset::Fill<Self::WorkingSet>;
/// Error type returned when [`PruneStrategy::prune_accessor`] fails.
type PruneAccessorError: StandardError;
/// Creates a working set from `capacity`.
// NOTE(review): the exact meaning of `capacity` (hint vs. hard limit) is not
// visible here — confirm against implementations.
fn create_working_set(&self, capacity: usize) -> Self::WorkingSet;
/// Creates a prune accessor borrowing `self`, `provider` and `context` for `'a`.
///
/// # Errors
/// Returns `Self::PruneAccessorError` if the accessor cannot be created.
fn prune_accessor<'a>(
&'a self,
provider: &'a Provider,
context: &'a Provider::Context,
) -> Result<Self::PruneAccessor<'a>, Self::PruneAccessorError>;
}
/// Strategy for inserting a whole [`Batch`] `B` into a `Provider`.
pub trait MultiInsertStrategy<Provider, B>: Send + Sync
where
Provider: DataProvider,
B: Batch,
{
/// Working-set type shared with the per-element insert strategy's prune strategy.
type WorkingSet: Send + Sync + 'static;
/// Value produced by [`MultiInsertStrategy::finish`]; it can be viewed as a
/// `Self::WorkingSet` through `workingset::AsWorkingSet`.
type Seed: workingset::AsWorkingSet<Self::WorkingSet> + Send + Sync + 'static;
/// Error returned by [`MultiInsertStrategy::finish`]; convertible into `ANNError`.
type FinishError: Into<ANNError> + std::fmt::Debug + Send + Sync;
/// Per-element insert strategy. It must work for batch elements of every
/// lifetime, and its prune strategy must use `Self::WorkingSet`.
type InsertStrategy: for<'a> InsertStrategy<
Provider,
B::Element<'a>,
PruneStrategy: PruneStrategy<Provider, WorkingSet = Self::WorkingSet>,
>;
/// Returns the per-element insert strategy.
fn insert_strategy(&self) -> Self::InsertStrategy;
/// Asynchronously produces a `Self::Seed` from the batch and the given ids.
///
/// * `batch` - the shared batch.
/// * `ids` - exact-size iterator of internal ids.
// NOTE(review): presumably `ids` holds the internal ids assigned to the batch
// elements, in batch order — confirm against the caller.
///
/// # Errors
/// Resolves to `Self::FinishError` on failure.
fn finish<Itr>(
&self,
provider: &Provider,
context: &Provider::Context,
batch: &Arc<B>,
ids: Itr,
) -> impl std::future::Future<Output = Result<Self::Seed, Self::FinishError>> + Send
where
Itr: ExactSizeIterator<Item = Provider::InternalId> + Send;
}
/// An indexable collection of elements that can be shared across threads.
pub trait Batch: Send + Sync + 'static {
    /// Element handed out by [`Batch::get`]; it may borrow from the batch.
    type Element<'a>: Copy;

    /// Number of elements in the batch.
    fn len(&self) -> usize;

    /// Returns the element at position `i`.
    fn get(&self, i: usize) -> Self::Element<'_>;

    /// Returns `true` when [`Batch::len`] reports zero.
    fn is_empty(&self) -> bool {
        let count = self.len();
        count == 0
    }
}
/// A matrix is a batch whose elements are its rows.
impl<T> Batch for diskann_utils::views::Matrix<T>
where
    T: Send + Sync + 'static,
{
    /// Each element is one row, borrowed as a slice.
    type Element<'a> = &'a [T];

    /// The batch length is the matrix's `nrows()`.
    fn len(&self) -> usize {
        self.nrows()
    }

    /// Returns `row(i)` of the matrix.
    // NOTE(review): behaviour for an out-of-range `i` is decided by `Matrix::row`,
    // which is not visible in this file.
    fn get(&self, i: usize) -> &[T] {
        self.row(i)
    }
}
/// Strategy bundle for in-place deletion over a `Provider`: how to obtain the
/// element being deleted, and which search, post-processing and prune strategies
/// to use.
pub trait InplaceDeleteStrategy<Provider>: Send + Sync + 'static
where
Provider: DataProvider,
{
/// Borrowed form of the element being deleted; used as the query type of the
/// delete search.
type DeleteElement<'a>: Copy + Send + Sync;
/// Owning guard returned by [`InplaceDeleteStrategy::get_delete_element`]; it can
/// be reborrowed as a `Self::DeleteElement<'a>`.
type DeleteElementGuard: Send
+ Sync
+ for<'a> Reborrow<'a, Target = Self::DeleteElement<'a>>
+ 'static;
/// Error returned when the delete element cannot be retrieved.
type DeleteElementError: StandardError;
/// Prune strategy used by the delete.
type PruneStrategy: PruneStrategy<Provider>;
/// Accessor used by the delete search; supports beam expansion for
/// `Self::DeleteElement<'a>` and uses the provider's internal id type.
type DeleteSearchAccessor<'a>: ExpandBeam<Self::DeleteElement<'a>, Id = Provider::InternalId>
+ SearchExt;
/// Post-processor applied to the delete search's candidates.
type SearchPostProcessor: for<'a> SearchPostProcess<Self::DeleteSearchAccessor<'a>, Self::DeleteElement<'a>>
+ Send
+ Sync;
/// Search strategy whose accessor is `Self::DeleteSearchAccessor<'a>`.
type SearchStrategy: for<'a> SearchStrategy<
Provider,
Self::DeleteElement<'a>,
SearchAccessor<'a> = Self::DeleteSearchAccessor<'a>,
>;
/// Returns the prune strategy.
fn prune_strategy(&self) -> Self::PruneStrategy;
/// Returns the search strategy.
fn search_strategy(&self) -> Self::SearchStrategy;
/// Returns the search post-processor.
fn search_post_processor(&self) -> Self::SearchPostProcessor;
/// Asynchronously retrieves the element identified by `id` as a guard.
///
/// # Errors
/// Resolves to `Self::DeleteElementError` on failure.
fn get_delete_element<'a>(
&'a self,
provider: &'a Provider,
context: &'a Provider::Context,
id: Provider::InternalId,
) -> impl Future<Output = Result<Self::DeleteElementGuard, Self::DeleteElementError>> + Send;
}
/// Source of an iterator `I` whose items implement `VectorId`.
pub trait IdIterator<I>
where
I: Iterator<Item: VectorId>,
{
/// Asynchronously produces the iterator.
///
/// Note that, unlike most futures declared in this file, the returned future
/// carries no `Send` bound.
///
/// # Errors
/// Resolves to an `ANNError` on failure.
fn id_iterator(&mut self) -> impl std::future::Future<Output = Result<I, ANNError>>;
}
// Unit tests for the default post-processing path (`CopyIds`), driven through a
// minimal in-memory provider and accessor.
#[cfg(test)]
mod tests {
use std::sync::{
Arc,
atomic::{AtomicUsize, Ordering},
};
use diskann_vector::PreprocessedDistanceFunction;
use futures_util::future;
use super::*;
use crate::{
ANNResult, neighbor,
provider::{DelegateNeighbor, ExecutionContext, HasId, NeighborAccessor},
};
/// Provider whose data is one `f32` per id, stored in a `Vec`.
struct SimpleProvider {
items: Vec<f32>,
}
/// Execution context carrying a shared counter. The accessor below bumps it on
/// every successful `get_element`.
#[derive(Default, Clone)]
struct CountGetVector {
count: Arc<AtomicUsize>,
}
impl ExecutionContext for CountGetVector {}
impl CountGetVector {
/// Current value of the counter.
fn count(&self) -> usize {
self.count.load(Ordering::Relaxed)
}
/// Resets the counter to zero.
fn clear(&self) {
self.count.store(0, Ordering::Relaxed)
}
}
// Internal and external ids are both `u32` and map to each other unchanged.
impl DataProvider for SimpleProvider {
type Context = CountGetVector;
type InternalId = u32;
type ExternalId = u32;
type Error = ANNError;
type Guard = crate::provider::NoopGuard<u32>;
fn to_internal_id(
&self,
_context: &CountGetVector,
gid: &Self::ExternalId,
) -> Result<Self::InternalId, Self::Error> {
Ok(*gid)
}
fn to_external_id(
&self,
_context: &CountGetVector,
id: Self::InternalId,
) -> Result<Self::ExternalId, Self::Error> {
Ok(id)
}
}
/// Accessor over `SimpleProvider` that also holds the counting context.
#[derive(Clone, Copy)]
struct Retriever<'a> {
provider: &'a SimpleProvider,
count: &'a CountGetVector,
}
// Id 0 is the only starting point.
impl SearchExt for Retriever<'_> {
async fn starting_points(&self) -> ANNResult<Vec<u32>> {
Ok(vec![0])
}
}
impl<'a> Retriever<'a> {
fn new(provider: &'a SimpleProvider, count: &'a CountGetVector) -> Self {
Self { provider, count }
}
}
impl HasId for Retriever<'_> {
type Id = u32;
}
impl Accessor for Retriever<'_> {
type Element<'a>
= f32
where
Self: 'a;
type ElementRef<'a> = f32;
type GetError = ANNError;
// Looks the value up eagerly (incrementing the counter on a hit) and wraps the
// result in a future. An out-of-range id panics.
fn get_element(
&mut self,
id: Self::Id,
) -> impl std::future::Future<Output = Result<Self::Element<'_>, Self::GetError>> + Send
{
let result = match self.provider.items.get(id as usize) {
Some(v) => {
self.count.count.fetch_add(1, Ordering::Relaxed);
Ok(*v)
}
None => panic!("invalid id: {}", id),
};
async move { result }
}
}
// Every node has an empty adjacency list: the output buffer is cleared and the
// accessor handed back.
impl NeighborAccessor for Retriever<'_> {
fn get_neighbors(
self,
_id: Self::Id,
neighbors: &mut AdjacencyList<Self::Id>,
) -> impl Future<Output = ANNResult<Self>> + Send {
neighbors.clear();
future::ok(self)
}
}
/// Query computer that must never be evaluated; doing so panics.
struct QueryComputer;
impl PreprocessedDistanceFunction<f32, f32> for QueryComputer {
fn evaluate_similarity(&self, _changing: f32) -> f32 {
panic!("this method should not be called")
}
}
impl BuildQueryComputer<f32> for Retriever<'_> {
type QueryComputerError = ANNError;
type QueryComputer = QueryComputer;
fn build_query_computer(&self, _from: f32) -> Result<QueryComputer, ANNError> {
Ok(QueryComputer)
}
}
// Uses the default `expand_beam` implementation.
impl ExpandBeam<f32> for Retriever<'_> {}
/// Search strategy that hands out `Retriever` accessors.
struct Strategy;
impl SearchStrategy<SimpleProvider, f32> for Strategy {
type QueryComputer = QueryComputer;
type SearchAccessorError = ANNError;
type SearchAccessor<'a> = Retriever<'a>;
fn search_accessor<'a>(
&'a self,
provider: &'a SimpleProvider,
context: &'a CountGetVector,
) -> Result<Self::SearchAccessor<'a>, Self::SearchAccessorError> {
Ok(Retriever::new(provider, context))
}
}
// `CopyIds` is the default post-processor for `Strategy`.
impl DefaultPostProcessor<SimpleProvider, f32> for Strategy {
default_post_processor!(CopyIds);
}
// Checks the fixtures first, then verifies that the default post-processor copies
// candidates into the output for every combination of input and output length in
// 0..10, without fetching any element.
#[tokio::test(flavor = "current_thread")]
async fn test_default_post_process() {
let ctx = CountGetVector::default();
let strategy = Strategy;
let num_points: usize = 100;
let provider = SimpleProvider {
items: (0..num_points).map(|i| i as f32).collect(),
};
// Sanity checks on the fixtures themselves.
assert_eq!(provider.to_internal_id(&ctx, &10).unwrap(), 10);
assert_eq!(provider.to_external_id(&ctx, 10).unwrap(), 10);
let mut accessor = strategy.search_accessor(&provider, &ctx).unwrap();
assert_eq!(accessor.starting_points().await.unwrap().as_slice(), &[0]);
for i in 0..num_points {
assert_eq!(accessor.get_element(i as u32).await.unwrap(), i as f32);
}
let mut neighbors = AdjacencyList::new();
accessor
.delegate_neighbor()
.get_neighbors(0, &mut neighbors)
.await
.unwrap();
assert_eq!(neighbors.len(), 0);
// One counter increment per `get_element` call above.
assert_eq!(ctx.count(), num_points);
ctx.clear();
let query = 11.5;
let computer = accessor.build_query_computer(query).unwrap();
for input_len in 0..10 {
let input: Vec<_> = (0..input_len)
.map(|i| Neighbor::<u32>::new(i as u32, i as f32))
.collect();
for output_len in 0..10 {
let mut output = vec![Neighbor::<u32>::default(); output_len];
let count = strategy
.default_post_processor()
.post_process(
&mut accessor,
query,
&computer,
input.iter().copied(),
&mut neighbor::BackInserter::new(output.as_mut_slice()),
)
.await
.unwrap();
// The reported count is bounded by both the input and the output length.
assert_eq!(count, input_len.min(output_len));
for (i, n) in output.iter().take(count).enumerate() {
assert_eq!(i, n.id as usize);
assert_eq!(i as f32, n.distance);
}
// Slots past `count` must still hold the values they were initialised with.
for n in output.iter().skip(count) {
assert_eq!(n.id, 0);
assert_eq!(n.distance, 0.0);
}
}
}
// Post-processing with `CopyIds` must not have fetched any element.
assert_eq!(ctx.count(), 0);
}
}