use std::ops::Deref;
use diskann_utils::{Reborrow, WithLifetime};
use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction};
use sealed::{BoundTo, Sealed};
use crate::{ANNError, ANNResult, error::ToRanked, graph::AdjacencyList, utils::VectorId};
/// A cloneable, thread-safe context that is given the chance to wrap a future
/// before that future is handed to a spawner.
pub trait ExecutionContext: Send + Sync + Clone + 'static {
/// Wraps `f`, returning a future that resolves to the same output type `T`.
///
/// The default implementation returns `f` unchanged.
fn wrap_spawn<F, T>(&self, f: F) -> impl std::future::Future<Output = T> + Send + 'static
where
F: std::future::Future<Output = T> + Send + 'static,
{
f
}
}
/// Base trait for a data provider: it names the context, id, error and guard
/// types used by the provider and converts between external and internal ids.
pub trait DataProvider: Sized + Send + Sync + 'static {
/// Context type passed by reference to every provider operation.
type Context: ExecutionContext;
/// Id type used internally; must implement [`VectorId`].
type InternalId: VectorId;
/// Id type exposed to callers.
type ExternalId: PartialEq + Send + Sync + 'static;
/// Error type returned by the id conversions (and by [`Delete`] operations).
type Error: ToRanked + std::fmt::Debug + Send + Sync + 'static;
/// Guard type handed out by [`SetElement::set_element`]; it carries an `InternalId`.
type Guard: Guard<Id = Self::InternalId> + 'static;
/// Converts the external id `gid` into its internal id.
fn to_internal_id(
&self,
context: &Self::Context,
gid: &Self::ExternalId,
) -> Result<Self::InternalId, Self::Error>;
/// Converts the internal id `id` into its external id.
fn to_external_id(
&self,
context: &Self::Context,
id: Self::InternalId,
) -> Result<Self::ExternalId, Self::Error>;
}
/// Extension of [`DataProvider`] for providers that can delete elements and
/// report whether an element is still valid.
pub trait Delete: DataProvider {
    /// Deletes the element addressed by the external id `gid`.
    fn delete(
        &self,
        context: &Self::Context,
        gid: &Self::ExternalId,
    ) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;

    /// Releases the internal id `id`.
    ///
    /// NOTE(review): how this differs from [`Delete::delete`] is defined by the
    /// implementor and cannot be determined from this trait alone.
    fn release(
        &self,
        context: &Self::Context,
        id: Self::InternalId,
    ) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;

    /// Reports the [`ElementStatus`] of the element with internal id `id`.
    fn status_by_internal_id(
        &self,
        context: &Self::Context,
        id: Self::InternalId,
    ) -> impl std::future::Future<Output = Result<ElementStatus, Self::Error>> + Send;

    /// Reports the [`ElementStatus`] of the element with external id `gid`.
    fn status_by_external_id(
        &self,
        context: &Self::Context,
        gid: &Self::ExternalId,
    ) -> impl std::future::Future<Output = Result<ElementStatus, Self::Error>> + Send;

    /// Looks up the status of every id produced by `itr` and passes each result
    /// to `f` together with the id it belongs to.
    ///
    /// The default implementation awaits [`Delete::status_by_internal_id`] one id
    /// at a time, in iteration order. Per-id failures are handed to `f` rather
    /// than returned, so the default always resolves to `Ok(())`.
    fn statuses_unordered<Itr, F>(
        &self,
        context: &Self::Context,
        itr: Itr,
        mut f: F,
    ) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send
    where
        Itr: Iterator<Item = Self::InternalId> + Send,
        F: FnMut(Result<ElementStatus, Self::Error>, Self::InternalId) + Send,
    {
        async move {
            for id in itr {
                let status = self.status_by_internal_id(context, id).await;
                f(status, id);
            }
            Ok(())
        }
    }
}
/// Status of an element as reported by the status queries of a provider.
///
/// The type is a plain two-state flag, so it derives `Eq` and `Hash` in addition
/// to `PartialEq`; this lets it be used as a map/set key and in `Eq`-bounded code.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ElementStatus {
    /// The element is valid.
    Valid,
    /// The element has been deleted.
    Deleted,
}

impl ElementStatus {
    /// Returns `true` if the status is [`ElementStatus::Valid`].
    pub fn is_valid(self) -> bool {
        matches!(self, Self::Valid)
    }

    /// Returns `true` if the status is [`ElementStatus::Deleted`].
    pub fn is_deleted(self) -> bool {
        matches!(self, Self::Deleted)
    }
}
/// Associates a type with the id type it works with.
pub trait HasId {
/// The id type; must implement [`VectorId`].
type Id: VectorId;
}
/// A shared reference reports the same id type as the value it points to.
impl<T: HasId> HasId for &T {
    type Id = T::Id;
}
/// An exclusive reference reports the same id type as the value it points to.
impl<T: HasId> HasId for &mut T {
    type Id = T::Id;
}
/// Extension of [`DataProvider`] for providers that can store an element of type `T`.
pub trait SetElement<T>: DataProvider {
/// Error type returned when storing an element fails.
type SetError: ToRanked + std::fmt::Debug + Send + Sync + 'static;
/// Stores `element` under the external id `id`.
///
/// On success the future resolves to the provider's [`DataProvider::Guard`].
fn set_element(
&self,
context: &Self::Context,
id: &Self::ExternalId,
element: T,
) -> impl std::future::Future<Output = Result<Self::Guard, Self::SetError>> + Send;
}
/// A guard object that carries an id and can be asynchronously completed.
///
/// NOTE(review): what it means for a guard to be dropped without calling
/// `complete` is up to the implementor — not visible from this trait.
pub trait Guard: Send + Sync + 'static {
/// Type of the id carried by the guard.
type Id;
/// Consumes the guard and completes it.
fn complete(self) -> impl std::future::Future<Output = ()> + Send;
/// Returns the id carried by the guard.
fn id(&self) -> Self::Id;
}
/// A guard that only stores an id and performs no work.
#[derive(Debug, Default)]
pub struct NoopGuard<I>(I);

impl<I> NoopGuard<I> {
    /// Creates a guard that wraps `id`.
    pub fn new(id: I) -> Self {
        NoopGuard(id)
    }
}
impl<I> Guard for NoopGuard<I>
where
I: Send + Sync + Copy + 'static,
{
type Id = I;
async fn complete(self) {}
fn id(&self) -> Self::Id {
self.0
}
}
/// Asynchronous access to elements addressed by [`HasId::Id`].
pub trait Accessor: HasId + Send + Sync {
    /// Borrowed view of an element, as passed to callbacks and distance functions.
    type ElementRef<'a>;

    /// Value returned by [`Accessor::get_element`]. For every lifetime `'b` it can
    /// be reborrowed into an `ElementRef<'b>`.
    type Element<'a>: for<'b> Reborrow<'b, Target = Self::ElementRef<'b>> + Send + Sync
    where
        Self: 'a;

    /// Error type returned when an element cannot be retrieved.
    type GetError: ToRanked + std::fmt::Debug + Send + Sync + 'static;

    /// Retrieves the element with the given `id`.
    ///
    /// The returned element may borrow from `self`.
    fn get_element(
        &mut self,
        id: Self::Id,
    ) -> impl std::future::Future<Output = Result<Self::Element<'_>, Self::GetError>> + Send;

    /// Invokes `f` with the element and id for every id produced by `itr`.
    ///
    /// The default implementation fetches one element at a time with
    /// [`Accessor::get_element`], in iteration order, and stops at the first
    /// retrieval error, returning it.
    fn on_elements_unordered<Itr, F>(
        &mut self,
        itr: Itr,
        mut f: F,
    ) -> impl std::future::Future<Output = Result<(), Self::GetError>> + Send
    where
        Self: Sync,
        Itr: Iterator<Item = Self::Id> + Send,
        F: Send + for<'a> FnMut(Self::ElementRef<'a>, Self::Id),
    {
        async move {
            for id in itr {
                // The fetched element is a temporary that lives only for this call.
                f(self.get_element(id).await?.reborrow(), id);
            }
            Ok(())
        }
    }
}
/// An [`Accessor`] whose elements can be converted to and from a second
/// representation described by `Map`.
///
/// NOTE(review): the name suggests `Map` is the form stored in a cache; the
/// cache itself is not part of this file.
pub trait CacheableAccessor: Accessor {
/// Lifetime-parameterised type family of the cached representation.
type Map: WithLifetime;
/// Converts a cached value with lifetime `'a` into an `Element<'a>`.
fn from_cached<'a>(element: <Self::Map as WithLifetime>::Of<'a>) -> Self::Element<'a>
where
Self: 'a;
/// Views a borrowed `Element<'b>` as a reference to its cached representation.
fn as_cached<'a, 'b>(element: &'a Self::Element<'b>) -> &'a <Self::Map as WithLifetime>::Of<'b>
where
Self: 'a + 'b;
}
/// An [`Accessor`] that can build a distance function over two of its elements.
pub trait BuildDistanceComputer: Accessor {
/// Error returned when the distance computer cannot be built; convertible to [`ANNError`].
type DistanceComputerError: std::error::Error + Into<ANNError> + Send + Sync + 'static;
/// Distance function accepting any two `ElementRef`s, each with its own lifetime.
type DistanceComputer: for<'a, 'b> DistanceFunction<Self::ElementRef<'a>, Self::ElementRef<'b>>
+ Send
+ Sync
+ 'static;
/// Builds the distance computer.
fn build_distance_computer(
&self,
) -> Result<Self::DistanceComputer, Self::DistanceComputerError>;
}
/// An [`Accessor`] that can build a preprocessed distance function from a query
/// of type `T`.
pub trait BuildQueryComputer<T>: Accessor {
    /// Error returned when the query computer cannot be built; convertible to [`ANNError`].
    type QueryComputerError: std::error::Error + Into<ANNError> + Send + Sync + 'static;

    /// Preprocessed distance function that maps an `ElementRef` to an `f32`.
    type QueryComputer: for<'a> PreprocessedDistanceFunction<Self::ElementRef<'a>, f32>
        + Send
        + Sync
        + 'static;

    /// Builds a query computer from the query `from`.
    fn build_query_computer(
        &self,
        from: T,
    ) -> Result<Self::QueryComputer, Self::QueryComputerError>;

    /// Evaluates `computer` on the element of every id produced by `vec_id_itr`
    /// and passes each `(distance, id)` pair to `f`.
    ///
    /// The default implementation forwards to
    /// [`Accessor::on_elements_unordered`], so errors and ordering follow that
    /// method.
    fn distances_unordered<Itr, F>(
        &mut self,
        vec_id_itr: Itr,
        computer: &Self::QueryComputer,
        mut f: F,
    ) -> impl std::future::Future<Output = Result<(), Self::GetError>> + Send
    where
        Itr: Iterator<Item = Self::Id> + Send,
        F: Send + FnMut(f32, Self::Id),
    {
        // The closure is written inline so that its higher-ranked signature is
        // inferred from the bound on `on_elements_unordered`.
        self.on_elements_unordered(vec_id_itr, move |element, id| {
            f(computer.evaluate_similarity(element), id)
        })
    }
}
/// Read access to the neighbor list of an id.
///
/// Methods take `self` by value and hand it back inside the `Ok` result.
pub trait NeighborAccessor: HasId + Sized + Send + Sync {
/// Retrieves the neighbors of `id` using the caller-provided `neighbors` list.
///
/// NOTE(review): presumably the previous contents of `neighbors` are replaced —
/// confirm against implementations.
fn get_neighbors(
self,
id: Self::Id,
neighbors: &mut AdjacencyList<Self::Id>,
) -> impl std::future::Future<Output = ANNResult<Self>> + Send;
}
/// Write access to neighbor lists, in addition to the read access of
/// [`NeighborAccessor`].
///
/// As with [`NeighborAccessor`], every method consumes `self` and returns it in
/// the `Ok` result.
pub trait NeighborAccessorMut: NeighborAccessor {
    /// Sets the neighbor list of `id` to `neighbors`.
    fn set_neighbors(
        self,
        id: Self::Id,
        neighbors: &[Self::Id],
    ) -> impl std::future::Future<Output = ANNResult<Self>> + Send;

    /// Appends `neighbors` for `id`.
    ///
    /// NOTE(review): exact append semantics (deduplication, capacity limits) are
    /// defined by the implementor.
    fn append_vector(
        self,
        id: Self::Id,
        neighbors: &[Self::Id],
    ) -> impl std::future::Future<Output = ANNResult<Self>> + Send;

    /// Applies [`NeighborAccessorMut::set_neighbors`] for every `(id, list)` pair
    /// produced by `iter`.
    ///
    /// The default implementation processes pairs sequentially, in iteration
    /// order, and stops at the first error.
    fn set_neighbors_bulk<I, T>(
        mut self,
        iter: I,
    ) -> impl std::future::Future<Output = ANNResult<Self>> + Send
    where
        I: Iterator<Item = (Self::Id, T)> + Send,
        T: Deref<Target = [Self::Id]> + Send,
    {
        async move {
            for (id, list) in iter {
                // Each call consumes the accessor and returns it for the next round.
                self = self.set_neighbors(id, list.deref()).await?;
            }
            Ok(self)
        }
    }
}
/// Lets `&mut T` act as a [`NeighborAccessor`] by forwarding to the delegate
/// obtained from `T::delegate_neighbor`.
impl<T> NeighborAccessor for &mut T
where
T: AsNeighbor,
{
async fn get_neighbors(
self,
id: Self::Id,
neighbors: &mut AdjacencyList<Self::Id>,
) -> ANNResult<Self> {
// The delegate returned by the call is discarded; the original `&mut T` is
// returned so the caller can keep using it.
self.delegate_neighbor()
.get_neighbors(id, neighbors)
.await?;
Ok(self)
}
}
/// Lets `&mut T` act as a [`NeighborAccessorMut`] by forwarding every operation
/// to the delegate obtained from `T::delegate_neighbor`.
impl<T> NeighborAccessorMut for &mut T
where
T: AsNeighborMut,
{
// Forwards to the delegate's `set_neighbors`, then returns the original reference.
async fn set_neighbors(self, id: Self::Id, neighbors: &[Self::Id]) -> ANNResult<Self> {
self.delegate_neighbor()
.set_neighbors(id, neighbors)
.await?;
Ok(self)
}
// Forwards to the delegate's `append_vector`, then returns the original reference.
async fn append_vector(self, id: Self::Id, neighbors: &[Self::Id]) -> ANNResult<Self> {
self.delegate_neighbor()
.append_vector(id, neighbors)
.await?;
Ok(self)
}
// Forwards to the delegate's `set_neighbors_bulk` (so a delegate-specific bulk
// implementation is used when there is one) instead of the trait default.
async fn set_neighbors_bulk<I, U>(self, iter: I) -> ANNResult<Self>
where
I: Iterator<Item = (Self::Id, U)> + Send,
U: Deref<Target = [Self::Id]> + Send,
{
self.delegate_neighbor().set_neighbors_bulk(iter).await?;
Ok(self)
}
}
/// Produces a [`NeighborAccessor`] delegate from an exclusive borrow of `self`
/// lasting for `'this`.
///
/// The second parameter is bound by the private `Sealed` trait and defaults to
/// `BoundTo<&'this Self>`.
/// NOTE(review): presumably this default exists to introduce an implied
/// `Self: 'this` bound so that `for<'a> DelegateNeighbor<'a>` bounds are usable —
/// confirm with the author.
pub trait DelegateNeighbor<'this, Lifetime: Sealed = BoundTo<&'this Self>>:
HasId + Send + Sync
{
/// The delegate type; it shares the id type of `Self`.
type Delegate: NeighborAccessor<Id = Self::Id>;
/// Creates the delegate from an exclusive borrow of `self`.
fn delegate_neighbor(&'this mut self) -> Self::Delegate;
}
/// Any `Copy` type that is already a [`NeighborAccessor`] is its own delegate.
impl<'this, T> DelegateNeighbor<'this> for T
where
T: Copy + NeighborAccessor,
{
type Delegate = Self;
fn delegate_neighbor(&'this mut self) -> Self::Delegate {
// `T: Copy`, so the delegate is simply a copy of the accessor.
*self
}
}
/// Shorthand for "implements [`DelegateNeighbor`] for every lifetime".
pub trait AsNeighbor: for<'a> DelegateNeighbor<'a> {}
/// Like [`AsNeighbor`], with the extra requirement that the delegate is a
/// [`NeighborAccessorMut`].
pub trait AsNeighborMut: for<'a> DelegateNeighbor<'a, Delegate: NeighborAccessorMut> {}
// Blanket impls: both shorthands are implemented automatically for every type
// that satisfies the corresponding bound.
impl<T> AsNeighbor for T where T: for<'a> DelegateNeighbor<'a> {}
impl<T> AsNeighborMut for T where T: for<'a> DelegateNeighbor<'a, Delegate: NeighborAccessorMut> {}
/// A [`DataProvider`] that can hand out a default accessor borrowing from itself.
pub trait DefaultAccessor: DataProvider {
/// Accessor type; its id type is the provider's `InternalId`.
type Accessor<'a>: HasId<Id = Self::InternalId>
where
Self: 'a;
/// Creates the default accessor, borrowing `self`.
fn default_accessor(&self) -> Self::Accessor<'_>;
}
/// A zero-sized execution context with no state of its own.
///
/// `Debug` is derived so the type can be used wherever a context is required to
/// be debuggable (for example in generic code bounded by
/// `ExecutionContext + Debug`).
#[derive(Debug, Default, Clone)]
pub struct DefaultContext;

impl std::fmt::Display for DefaultContext {
    /// Writes the fixed text `default context`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("default context")
    }
}
// No overrides: `DefaultContext` uses the provided `wrap_spawn`, which returns
// the future unchanged.
impl ExecutionContext for DefaultContext {}
// Private module holding the marker types used by `DelegateNeighbor`'s defaulted
// type parameter. Because the module itself is private, code outside this file
// cannot name `Sealed` or `BoundTo`.
mod sealed {
/// Marker trait bounding `DelegateNeighbor`'s second type parameter.
pub trait Sealed: Sized {}
/// Wrapper type; the only implementor of [`Sealed`].
pub struct BoundTo<T>(T);
impl<T> Sealed for BoundTo<T> {}
}
#[cfg(test)]
mod tests {
use std::{
collections::HashMap,
future::Future,
pin::Pin,
sync::{
Arc, Mutex,
atomic::{AtomicUsize, Ordering},
},
task,
};
use pin_project::{pin_project, pinned_drop};
use super::*;
use crate::{always_escalate, error::Infallible};
/// `DefaultContext` renders as "default context" and occupies no memory.
#[test]
fn test_default_context() {
    let rendered = DefaultContext.to_string();
    assert_eq!(rendered, "default context");

    let size = std::mem::size_of::<DefaultContext>();
    assert_eq!(size, 0, "expected DefaultContext to be an empty class");
}
/// Shared counters behind a `TestContext`.
#[derive(Debug)]
struct TestContextInner {
// Incremented once per `wrap_spawn` call.
spawned: AtomicUsize,
// Incremented once per dropped `SpawnCounter`.
dropped: AtomicUsize,
}
/// Test execution context; clones share the same counters through the `Arc`.
#[derive(Debug, Clone)]
struct TestContext {
inner: Arc<TestContextInner>,
}
impl Default for TestContext {
    /// Creates a context whose counters both start at zero.
    fn default() -> Self {
        let counters = TestContextInner {
            spawned: AtomicUsize::new(0),
            dropped: AtomicUsize::new(0),
        };
        Self {
            inner: Arc::new(counters),
        }
    }
}
impl std::fmt::Display for TestContext {
    /// Writes the fixed text `test context`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("test context")
    }
}
/// Future wrapper returned by `TestContext::wrap_spawn`; it records its own drop
/// in the parent context's `dropped` counter.
#[pin_project(PinnedDrop)]
pub struct SpawnCounter<F> {
// The wrapped future, structurally pinned so it can be polled through the wrapper.
#[pin]
inner: F,
// Context whose counters are updated.
parent: TestContext,
}
// Drop hook for `SpawnCounter`: bumps the parent's `dropped` counter by one.
#[pinned_drop]
impl<F> PinnedDrop for SpawnCounter<F> {
fn drop(self: Pin<&mut Self>) {
self.parent.inner.dropped.fetch_add(1, Ordering::AcqRel);
}
}
/// Polling a `SpawnCounter` polls the wrapped future and yields its output.
impl<F> Future for SpawnCounter<F>
where
    F: Future,
{
    type Output = F::Output;

    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
        let projected = self.project();
        projected.inner.poll(cx)
    }
}
impl ExecutionContext for TestContext {
    /// Counts the call in `spawned` and wraps `f` in a [`SpawnCounter`] that will
    /// count its own drop.
    fn wrap_spawn<F, T>(&self, f: F) -> impl Future<Output = T> + Send + 'static
    where
        F: Future<Output = T> + Send + 'static,
    {
        self.inner.spawned.fetch_add(1, Ordering::AcqRel);
        let parent = self.clone();
        SpawnCounter { inner: f, parent }
    }
}
// Recursively spawns `width` tokio tasks per level for `depth` levels, wrapping
// every spawned future with `context.wrap_spawn`, and waits for all of them.
//
// Written as a plain `fn` returning `impl Future + Send + 'static` (hence the
// clippy allow) rather than as an `async fn`.
// NOTE(review): presumably required so the recursive call has a nameable
// `Send + 'static` future type for `tokio::spawn` — confirm.
#[allow(clippy::manual_async_fn)]
fn test_spawning<Context>(
context: Context,
width: usize,
depth: usize,
) -> impl Future<Output = ()> + Send + 'static
where
Context: ExecutionContext + 'static + std::fmt::Debug,
{
async move {
// Base case: nothing is spawned at depth 0.
if depth == 0 {
return;
}
// Spawn all children first, then await them, so they can run concurrently.
let handles: Box<[_]> = (0..width)
.map(|_| {
let clone = context.clone();
tokio::spawn(context.wrap_spawn(test_spawning(clone, width, depth - 1)))
})
.collect();
for h in handles {
h.await.unwrap();
}
}
}
/// Every future passed through `wrap_spawn` is counted exactly once when it is
/// wrapped and exactly once when it is dropped.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_task_spawning() {
    let context = TestContext::default();
    let spawned = || context.inner.spawned.load(Ordering::Acquire);
    let dropped = || context.inner.dropped.load(Ordering::Acquire);
    assert_eq!(spawned(), 0);
    assert_eq!(dropped(), 0);

    let (width, depth): (usize, usize) = (10, 3);
    test_spawning(context.clone(), width, depth).await;

    // Level `k` of the recursion spawns `width^k` tasks, for k = 1..=depth.
    let expected: usize = (1..=depth)
        .map(|level| width.pow(level.try_into().unwrap()))
        .sum();
    assert_eq!(spawned(), expected);
    assert_eq!(dropped(), expected);
}
/// `NoopGuard` reports the id it was built with, whether or not it is completed.
#[tokio::test]
async fn test_noop_guard() {
    // A guard that is explicitly completed.
    {
        let completed = NoopGuard::<usize>::new(10);
        assert_eq!(completed.id(), 10);
        completed.complete().await;
    }

    // A guard that simply goes out of scope.
    {
        let abandoned = NoopGuard::<usize>::new(5);
        assert_eq!(abandoned.id(), 5);
    }
}
/// Each `ElementStatus` variant answers `true` to exactly one predicate.
#[test]
fn simple_status_test() {
    let cases = [
        (ElementStatus::Valid, true, false),
        (ElementStatus::Deleted, false, true),
    ];
    for (status, expect_valid, expect_deleted) in cases {
        assert_eq!(status.is_valid(), expect_valid);
        assert_eq!(status.is_deleted(), expect_deleted);
    }
}
/// Test provider: a mutex-protected map from id to a `(f32, String)` pair.
struct SimpleProvider {
data: Mutex<HashMap<u32, (f32, String)>>,
}
impl SimpleProvider {
    /// Creates a provider holding a single entry, `(v, st)`, stored under `u32::MAX`.
    fn new(v: f32, st: String) -> Self {
        let initial = HashMap::from([(u32::MAX, (v, st))]);
        Self {
            data: Mutex::new(initial),
        }
    }
}
/// `SimpleProvider` uses `u32` for both id kinds, so the conversions are identities.
impl DataProvider for SimpleProvider {
type Context = DefaultContext;
type InternalId = u32;
type ExternalId = u32;
type Error = ANNError;
type Guard = NoopGuard<u32>;
// Identity mapping: external and internal ids are the same value.
fn to_internal_id(&self, _context: &DefaultContext, gid: &u32) -> Result<u32, ANNError> {
Ok(*gid)
}
// Identity mapping: external and internal ids are the same value.
fn to_external_id(&self, _context: &DefaultContext, id: u32) -> Result<u32, ANNError> {
Ok(id)
}
}
/// Error returned by the test accessors when an id has no entry.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Missing;

impl std::fmt::Display for Missing {
    /// Writes the fixed text `key is missing`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("key is missing")
    }
}

impl std::error::Error for Missing {}
/// Converts `Missing` into an `ANNError` through `ANNError::log_async_error`.
impl From<Missing> for ANNError {
// Marked cold: error conversion is not expected on the hot path.
#[cold]
fn from(missing: Missing) -> ANNError {
ANNError::log_async_error(missing)
}
}
// NOTE(review): presumably generates the `ToRanked` implementation required by
// `Accessor::GetError`, classifying `Missing` as an error that always escalates —
// confirm against the macro definition.
always_escalate!(Missing);
/// Accessor over the `f32` half of each `SimpleProvider` entry.
struct FloatAccessor<'a>(&'a SimpleProvider);
impl HasId for FloatAccessor<'_> {
type Id = u32;
}
/// Returns elements by value (`f32`) and overrides the bulk method so that the
/// provider lock is taken only once for the whole iteration.
impl Accessor for FloatAccessor<'_> {
    type Element<'a>
        = f32
    where
        Self: 'a;
    type ElementRef<'a> = f32;
    type GetError = Missing;

    /// Looks up `id` under the provider lock and returns an already-resolved future.
    fn get_element(
        &mut self,
        id: u32,
    ) -> impl Future<Output = Result<Self::Element<'_>, Self::GetError>> + Send {
        let table = self.0.data.lock().unwrap();
        let looked_up = table.get(&id).map(|entry| entry.0).ok_or(Missing);
        std::future::ready(looked_up)
    }

    /// Calls `f` for every id in `itr`, stopping with `Missing` at the first id
    /// that has no entry.
    async fn on_elements_unordered<Itr, F>(
        &mut self,
        itr: Itr,
        mut f: F,
    ) -> Result<(), Self::GetError>
    where
        Self: Sync,
        Itr: Iterator<Item = u32>,
        F: Send + FnMut(f32, u32),
    {
        let table = self.0.data.lock().unwrap();
        for id in itr {
            let Some(entry) = table.get(&id) else {
                return Err(Missing);
            };
            f(entry.0, id);
        }
        Ok(())
    }
}
/// Accessor over the `String` half of each `SimpleProvider` entry.
struct StringAccessor<'a> {
provider: &'a SimpleProvider,
// Scratch buffer that holds a copy of the most recently fetched string.
buf: String,
}
impl<'a> StringAccessor<'a> {
    /// Creates an accessor over `provider` with an empty scratch buffer.
    fn new(provider: &'a SimpleProvider) -> Self {
        let buf = String::new();
        Self { provider, buf }
    }
}
// Ids are plain `u32` keys into the provider's map.
impl HasId for StringAccessor<'_> {
type Id = u32;
}
/// Returns elements as `&str` borrowed from the accessor's own buffer; relies on
/// the default `on_elements_unordered`.
impl Accessor for StringAccessor<'_> {
    type Element<'a>
        = &'a str
    where
        Self: 'a;
    type ElementRef<'a> = &'a str;
    type GetError = Missing;

    /// Copies the string stored for `id` into the internal buffer and returns a
    /// borrow of that buffer, or `Missing` if there is no entry.
    fn get_element(
        &mut self,
        id: u32,
    ) -> impl Future<Output = Result<Self::Element<'_>, Self::GetError>> + Send {
        let table = self.provider.data.lock().unwrap();
        let looked_up = match table.get(&id) {
            Some((_, text)) => {
                // Reuses the buffer's allocation where possible.
                self.buf.clone_from(text);
                Ok(self.buf.as_str())
            }
            None => Err(Missing),
        };
        std::future::ready(looked_up)
    }
}
/// Exercises `get_element` and `on_elements_unordered` for both test accessors:
/// `FloatAccessor` (which overrides the bulk method) and `StringAccessor` (which
/// uses the trait's default).
#[tokio::test]
async fn test_default_implementations() {
// The provider starts with one entry under `u32::MAX`; add ids 0, 1 and 2.
let provider = SimpleProvider::new(-1.0, "hello".to_string());
{
let mut data = provider.data.lock().unwrap();
data.insert(0, (0.0, "world".to_string()));
data.insert(1, (1.0, "foo".to_string()));
data.insert(2, (2.0, "bar".to_string()));
}
{
let mut accessor = FloatAccessor(&provider);
assert_eq!(accessor.get_element(0).await.unwrap(), 0.0);
assert_eq!(accessor.get_element(1).await.unwrap(), 1.0);
assert_eq!(accessor.get_element(u32::MAX).await.unwrap(), -1.0);
// Elements are delivered in the order the ids were supplied.
let mut v = Vec::new();
accessor
.on_elements_unordered([2, 1, 0].into_iter(), |element, id| v.push((element, id)))
.await
.unwrap();
assert_eq!(&v, &[(2.0, 2), (1.0, 1), (0.0, 0)]);
// Id 3 has no entry, so the call fails with `Missing`.
let err = accessor
.on_elements_unordered([2, 1, 0, 3].into_iter(), |element, id| {
v.push((element, id))
})
.await
.unwrap_err();
assert_eq!(err, Missing);
}
{
let mut accessor = StringAccessor::new(&provider);
assert_eq!(accessor.get_element(0).await.unwrap(), "world");
assert_eq!(accessor.get_element(1).await.unwrap(), "foo");
assert_eq!(accessor.get_element(u32::MAX).await.unwrap(), "hello");
let expected = [("bar", 2), ("foo", 1), ("world", 0)];
let mut expected_iter = expected.into_iter();
accessor
.on_elements_unordered([2, 1, 0].into_iter(), |element, id| {
assert_eq!((element, id), expected_iter.next().unwrap());
})
.await
.unwrap();
assert!(expected_iter.next().is_none());
// With the missing id 3 at the end, the three valid ids are still delivered
// before the error is returned.
let mut expected_iter = expected.into_iter();
let err = accessor
.on_elements_unordered([2, 1, 0, 3].into_iter(), |element, id| {
assert_eq!((element, id), expected_iter.next().unwrap());
})
.await
.unwrap_err();
assert_eq!(err, Missing);
assert!(expected_iter.next().is_none());
}
}
/// Backing storage for the accessor-pattern tests: one fixed vector of bytes.
#[derive(Debug)]
struct Store {
    data: Box<[u8]>,
}

impl Store {
    /// Creates a store holding the bytes `[1, 2, 3, 4]`.
    fn new() -> Self {
        let data = vec![1, 2, 3, 4].into_boxed_slice();
        Self { data }
    }

    /// Number of bytes in the stored vector.
    fn dim(&self) -> usize {
        self.data.len()
    }
}
// Implements the boilerplate shared by the test accessors for type `$T`:
// `HasId` with `u32` ids, and `BuildDistanceComputer` using the `u8` vector
// representation's distance type built for the L2 metric.
macro_rules! common_test_accessor {
($T:ty) => {
impl HasId for $T {
type Id = u32;
}
impl BuildDistanceComputer for $T {
type DistanceComputerError = Infallible;
type DistanceComputer = <u8 as crate::utils::VectorRepr>::Distance;
fn build_distance_computer(&self) -> Result<Self::DistanceComputer, Infallible> {
Ok(<u8 as crate::utils::VectorRepr>::distance(
diskann_vector::distance::Metric::L2,
None,
))
}
}
};
}
/// Accessor pattern 1: every `get_element` call returns a freshly allocated copy.
struct Allocating<'a> {
store: &'a Store,
}
impl<'a> Allocating<'a> {
fn new(store: &'a Store) -> Self {
Self { store }
}
}
common_test_accessor!(Allocating<'_>);
/// The element is an owned `Box<[u8]>`; the element reference is a slice.
impl Accessor for Allocating<'_> {
type Element<'a>
= Box<[u8]>
where
Self: 'a;
type ElementRef<'a> = &'a [u8];
type GetError = Infallible;
// Ignores the id and clones the store's data.
async fn get_element(&mut self, _: u32) -> Result<Box<[u8]>, Infallible> {
Ok(self.store.data.clone())
}
}
/// Accessor pattern 2: `get_element` hands out a slice that borrows directly
/// from the store.
struct Forwarding<'a> {
store: &'a Store,
}
impl<'a> Forwarding<'a> {
fn new(store: &'a Store) -> Self {
Self { store }
}
}
common_test_accessor!(Forwarding<'_>);
/// The element type carries the store's lifetime `'provider`, not the lifetime
/// of the borrow of the accessor.
impl<'provider> Accessor for Forwarding<'provider> {
type Element<'a>
= &'provider [u8]
where
Self: 'a;
type ElementRef<'a> = &'a [u8];
type GetError = Infallible;
// Ignores the id and returns a slice of the store's data.
async fn get_element(&mut self, _: u32) -> Result<&'provider [u8], Infallible> {
Ok(&*self.store.data)
}
}
/// Accessor pattern 3: `get_element` returns a custom wrapper type (`Wrapped`)
/// instead of a plain slice.
struct Wrapping<'a> {
store: &'a Store,
}
impl<'a> Wrapping<'a> {
fn new(store: &'a Store) -> Self {
Self { store }
}
}
/// Newtype around a byte slice, used as the element type of `Wrapping`.
#[derive(Debug)]
struct Wrapped<'a>(&'a [u8]);
// Reborrowing a `Wrapped` yields the inner slice.
impl<'a> Reborrow<'a> for Wrapped<'_> {
type Target = &'a [u8];
fn reborrow(&'a self) -> Self::Target {
self.0
}
}
// Converting to an owned box copies the wrapped bytes.
impl From<Wrapped<'_>> for Box<[u8]> {
fn from(wrapped: Wrapped<'_>) -> Self {
wrapped.0.into()
}
}
common_test_accessor!(Wrapping<'_>);
/// The element is a `Wrapped` tied to the borrow of the accessor; the element
/// reference is a slice.
impl Accessor for Wrapping<'_> {
type Element<'a>
= Wrapped<'a>
where
Self: 'a;
type ElementRef<'a> = &'a [u8];
type GetError = Infallible;
// Ignores the id and wraps a slice of the store's data.
async fn get_element(&mut self, _: u32) -> Result<Wrapped<'_>, Infallible> {
Ok(Wrapped(&self.store.data))
}
}
/// Accessor pattern 4: elements are copied into a buffer owned by the accessor,
/// and `get_element` returns a borrow of that buffer.
#[derive(Debug)]
struct Sharing<'a> {
    store: &'a Store,
    local: Box<[u8]>,
}

impl<'a> Sharing<'a> {
    /// Creates an accessor whose local buffer is zero-filled and has the same
    /// length as the store's vector.
    fn new(store: &'a Store) -> Self {
        let local = vec![0; store.dim()].into_boxed_slice();
        Self { store, local }
    }
}

common_test_accessor!(Sharing<'_>);
/// Both the element and the element reference are slices borrowed from the
/// accessor's local buffer.
impl Accessor for Sharing<'_> {
type Element<'a>
= &'a [u8]
where
Self: 'a;
type ElementRef<'a> = &'a [u8];
type GetError = Infallible;
// Ignores the id, copies the store's data into the local buffer, and returns
// a borrow of that buffer.
async fn get_element(&mut self, _: u32) -> Result<&[u8], Infallible> {
self.local.copy_from_slice(&self.store.data);
Ok(&self.local)
}
}
/// Runs the same check against each of the four accessor patterns: fetch an
/// element, reborrow it, and feed it to the accessor's distance computer.
///
/// The store holds `[1, 2, 3, 4]` and `base` is `[2, 3, 4, 5]`; every case
/// expects the computed value to be `4.0`.
#[tokio::test]
async fn test_accessor_patterns() {
let store = Store::new();
let base: &[u8] = &[2, 3, 4, 5];
// Owned element (`Box<[u8]>`).
{
let mut accessor = Allocating::new(&store);
let computer = accessor.build_distance_computer().unwrap();
let element = accessor.get_element(0).await.unwrap();
assert_eq!(computer.evaluate_similarity(base, element.reborrow()), 4.0);
}
// Element borrowed from the store.
{
let mut accessor = Forwarding::new(&store);
let computer = accessor.build_distance_computer().unwrap();
let element = accessor.get_element(0).await.unwrap();
assert_eq!(computer.evaluate_similarity(base, element.reborrow()), 4.0);
}
// Element wrapped in a custom type.
{
let mut accessor = Wrapping::new(&store);
let computer = accessor.build_distance_computer().unwrap();
let element = accessor.get_element(0).await.unwrap();
assert_eq!(computer.evaluate_similarity(base, element.reborrow()), 4.0);
}
// Element borrowed from the accessor's own buffer.
{
let mut accessor = Sharing::new(&store);
let computer = accessor.build_distance_computer().unwrap();
let element = accessor.get_element(0).await.unwrap();
assert_eq!(computer.evaluate_similarity(base, element.reborrow()), 4.0);
}
}
}