use std::{fmt::Debug, future::Future, num::NonZeroUsize};
use crate::storage::{StorageReadProvider, StorageWriteProvider};
#[cfg(test)]
use diskann::neighbor::Neighbor;
use diskann::{
ANNError, ANNResult,
graph::AdjacencyList,
provider::{
DataProvider, DefaultAccessor, DefaultContext, Delete, ElementStatus, ExecutionContext,
NeighborAccessor, NeighborAccessorMut, NoopGuard, SetElement,
},
utils::{IntoUsize, ONE, VectorRepr},
};
use diskann_utils::future::AsyncFriendly;
use diskann_vector::distance::Metric;
use crate::{
model::graph::provider::async_::{
SimpleNeighborProviderAsync, StartPoints, TableDeleteProviderAsync,
common::{
CreateDeleteProvider, CreateVectorStore, NoDeletes, NoStore, PrefetchCacheLineLevel,
SetElementHelper, VectorStore,
},
},
storage::{AsyncIndexMetadata, AsyncQuantLoadContext, DiskGraphOnly, LoadWith, SaveWith},
};
/// A data provider that bundles the pieces of a graph index:
///
/// * a base vector store `U`,
/// * an auxiliary vector store `V` (defaults to [`NoStore`]),
/// * an adjacency-list graph over `u32` ids,
/// * a delete-tracking provider `D` (defaults to [`NoDeletes`]),
/// * a contiguous range of start (frozen) point ids.
///
/// `Ctx` is the execution-context type exposed as [`DataProvider::Context`]; it is only
/// carried at the type level.
pub struct DefaultProvider<U, V = NoStore, D = NoDeletes, Ctx = DefaultContext> {
    /// Base vector store; `SetElement` writes every vector here.
    pub base_vectors: U,
    /// Auxiliary vector store; `SetElement` writes every vector here as well.
    pub aux_vectors: V,
    /// Adjacency lists for every slot (regular points and start points).
    pub(crate) neighbor_provider: SimpleNeighborProviderAsync<u32>,
    /// Delete-tracking state. It is only consulted by the `Delete` impl, which requires
    /// `TableDeleteProviderAsync`.
    pub(super) deleted: D,
    /// Distance metric the vector stores were created or loaded with.
    pub(super) metric: Metric,
    /// Ids reserved for start points. They begin right after the `capacity()` regular
    /// slots and end at `total_points()`.
    pub(super) start_points: StartPoints,
    /// Ties `Ctx` to the provider type without storing a value.
    context: std::marker::PhantomData<Ctx>,
}
/// Construction parameters for [`DefaultProvider::new_empty`].
#[derive(Debug, Clone)]
pub struct DefaultProviderParameters {
    /// Number of regular (non-start) point slots to allocate.
    pub max_points: usize,
    /// Number of start points, allocated after the regular slots.
    pub frozen_points: NonZeroUsize,
    /// Vector dimensionality.
    /// NOTE(review): `new_empty` does not read this; the vector-store precursors carry
    /// their own dimension. Confirm whether other callers rely on it.
    pub dim: usize,
    /// Distance metric forwarded to the vector stores.
    pub metric: Metric,
    /// Optional prefetch lookahead forwarded to the vector-store precursors.
    pub prefetch_lookahead: Option<usize>,
    /// Optional prefetch cache-line level.
    /// NOTE(review): `new_empty` does not read this. Confirm where it is consumed.
    pub prefetch_cache_line_level: Option<PrefetchCacheLineLevel>,
    /// Maximum degree passed to the neighbor provider.
    pub max_degree: u32,
}
impl DefaultProviderParameters {
    /// Build parameters with a single start point and no prefetch configuration.
    ///
    /// `max_points`, `dim`, `metric`, and `max_degree` are taken verbatim. Both prefetch
    /// options are left as `None`.
    pub fn simple(max_points: usize, dim: usize, metric: Metric, max_degree: u32) -> Self {
        let frozen_points = ONE;
        Self {
            dim,
            max_degree,
            max_points,
            metric,
            frozen_points,
            prefetch_cache_line_level: None,
            prefetch_lookahead: None,
        }
    }
}
impl<U, V, D, Ctx> DefaultProvider<U, V, D, Ctx> {
pub fn new_empty<CU, CV, CD>(
params: DefaultProviderParameters,
base_precursor: CU,
aux_precursor: CV,
delete_precursor: CD,
) -> ANNResult<Self>
where
CU: CreateVectorStore<Target = U>,
CV: CreateVectorStore<Target = V>,
CD: CreateDeleteProvider<Target = D>,
{
let npts = params.max_points + params.frozen_points.get();
Ok(Self {
base_vectors: base_precursor.create(npts, params.metric, params.prefetch_lookahead),
aux_vectors: aux_precursor.create(npts, params.metric, params.prefetch_lookahead),
neighbor_provider: SimpleNeighborProviderAsync::new(npts, 1, params.max_degree, 1.0),
deleted: delete_precursor.create(npts),
metric: params.metric,
start_points: StartPoints::new(params.max_points as u32, params.frozen_points)?,
context: std::marker::PhantomData,
})
}
#[cfg(test)]
pub(crate) fn is_not_start_point(&self) -> impl Fn(&Neighbor<u32>) -> bool {
let range = self.start_points.range();
move |neighbor| !range.contains(&neighbor.id)
}
pub fn starting_points(&self) -> ANNResult<Vec<u32>> {
Ok(self.start_points.range().collect())
}
pub fn iter(&self) -> std::ops::Range<u32> {
0..self.start_points.end()
}
pub fn neighbors(&self) -> &SimpleNeighborProviderAsync<u32> {
&self.neighbor_provider
}
pub fn num_start_points(&self) -> usize {
self.start_points.len()
}
pub fn capacity(&self) -> usize {
self.start_points.start().into_usize()
}
pub fn total_points(&self) -> usize {
self.start_points.end().into_usize()
}
}
impl<U, V, D, Ctx> IntoIterator for &DefaultProvider<U, V, D, Ctx> {
type Item = u32;
type IntoIter = std::ops::Range<u32>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl<U, V, Ctx> DefaultProvider<U, V, TableDeleteProviderAsync, Ctx> {
    /// Clear the delete table so that no slot is reported as deleted.
    pub fn clear_delete_set(&self) {
        let table = &self.deleted;
        table.clear();
    }
}
impl<U, V, D, Ctx> DefaultProvider<U, V, D, Ctx>
where
    U: VectorStore,
    V: VectorStore,
{
    /// Return `(base, aux)`: each store's `count_for_get_vector` value, base store first.
    pub fn counts_for_get_vector(&self) -> (usize, usize) {
        let base_count = self.base_vectors.count_for_get_vector();
        let aux_count = self.aux_vectors.count_for_get_vector();
        (base_count, aux_count)
    }
}
/// Initialize the vectors stored at a provider's start-point ids.
pub trait SetStartPoints<T>
where
    T: ?Sized + 'static,
{
    /// Assign the items of `itr` to the start points, in start-point id order.
    ///
    /// # Errors
    ///
    /// The `DefaultProvider` implementation in this file fails if `itr` does not yield
    /// exactly one item per start point, or if a vector store rejects an item.
    /// NOTE(review): other implementations may impose different checks.
    fn set_start_points<'a, Itr>(&self, itr: Itr) -> ANNResult<()>
    where
        Itr: ExactSizeIterator<Item = &'a T> + 'a;
}
impl<T, U, V, D> SetStartPoints<[T]> for DefaultProvider<U, V, D>
where
    U: SetElementHelper<T>,
    V: SetElementHelper<T>,
    T: std::fmt::Debug + 'static,
{
    /// Write one vector per start point into both stores (auxiliary first, then base).
    ///
    /// # Errors
    ///
    /// Fails before writing anything if `itr.len()` differs from the number of start
    /// points. Otherwise it stops at the first store error, so earlier start points may
    /// already have been written.
    fn set_start_points<'a, Itr>(&self, itr: Itr) -> ANNResult<()>
    where
        Itr: ExactSizeIterator<Item = &'a [T]> + 'a,
    {
        let ids = self.start_points.range();
        let (expected, provided) = (ids.len(), itr.len());
        if expected != provided {
            return Err(ANNError::log_async_index_error(format!(
                "expected `itr` to contain `{}` items, instead it has {}",
                expected, provided,
            )));
        }

        ids.zip(itr).try_for_each(|(id, vector)| {
            self.aux_vectors.set_element(&id, vector)?;
            self.base_vectors.set_element(&id, vector)?;
            Ok(())
        })
    }
}
impl<U, V, D, Ctx> SaveWith<(u32, AsyncIndexMetadata)> for DefaultProvider<U, V, D, Ctx>
where
    U: AsyncFriendly + SaveWith<AsyncIndexMetadata>,
    V: AsyncFriendly + SaveWith<AsyncIndexMetadata>,
    D: AsyncFriendly,
    ANNError: From<U::Error> + From<V::Error>,
    Ctx: ExecutionContext,
{
    type Ok = ();
    type Error = ANNError;

    /// Persist the base vectors, the auxiliary vectors, and then the graph.
    ///
    /// The vector stores receive only the metadata half of `auxiliary`. The neighbor
    /// provider receives the whole tuple. The delete set is not written.
    async fn save_with<P>(
        &self,
        provider: &P,
        auxiliary: &(u32, AsyncIndexMetadata),
    ) -> Result<Self::Ok, Self::Error>
    where
        P: StorageWriteProvider,
    {
        let (_, metadata) = auxiliary;
        self.base_vectors.save_with(provider, metadata).await?;
        self.aux_vectors.save_with(provider, metadata).await?;
        self.neighbor_provider.save_with(provider, auxiliary).await?;
        Ok(())
    }
}
impl<U, V, D, Ctx> SaveWith<(u32, u32, DiskGraphOnly)> for DefaultProvider<U, V, D, Ctx>
where
    U: AsyncFriendly,
    V: AsyncFriendly,
    D: AsyncFriendly,
    Ctx: ExecutionContext,
{
    type Ok = ();
    type Error = ANNError;

    /// Persist only the graph by delegating to the neighbor provider. No vectors are
    /// written.
    async fn save_with<P>(
        &self,
        provider: &P,
        auxiliary: &(u32, u32, DiskGraphOnly),
    ) -> Result<Self::Ok, Self::Error>
    where
        P: StorageWriteProvider,
    {
        let graph = &self.neighbor_provider;
        graph.save_with(provider, auxiliary).await?;
        Ok(())
    }
}
impl<U, V, D, Ctx> LoadWith<AsyncQuantLoadContext> for DefaultProvider<U, V, D, Ctx>
where
    U: VectorStore + LoadWith<AsyncQuantLoadContext>,
    V: VectorStore + AsyncFriendly + LoadWith<AsyncQuantLoadContext>,
    D: AsyncFriendly + LoadWith<usize>,
    ANNError: From<U::Error> + From<V::Error> + From<D::Error>,
    Ctx: ExecutionContext,
{
    type Error = ANNError;

    /// Load the vector stores, delete state, and graph.
    ///
    /// The start points are reconstructed as the last `ctx.num_frozen_points` ids, where
    /// the total id count is the larger of the two stores' `total()`.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - a component fails to load;
    /// - the stored index has fewer points than the requested number of start points;
    /// - the number of regular points does not fit in a `u32` id.
    async fn load_with<P>(provider: &P, ctx: &AsyncQuantLoadContext) -> ANNResult<Self>
    where
        P: StorageReadProvider,
    {
        let base_vectors = U::load_with(provider, ctx).await?;
        let aux_vectors = V::load_with(provider, ctx).await?;
        let deleted = D::load_with(provider, &base_vectors.total()).await?;

        let num_frozen = ctx.num_frozen_points.get();
        let npts = std::cmp::max(base_vectors.total(), aux_vectors.total());
        let valid_points = npts.checked_sub(num_frozen).ok_or_else(|| {
            // Report `npts`, the value actually compared, not just the base store's total.
            ANNError::log_index_error(format_args!(
                "Expected {} start points but the stored index only has {} total points",
                num_frozen, npts,
            ))
        })?;
        // Ids are `u32`. An `as` cast would silently wrap and misplace the start points.
        let first_start_point = u32::try_from(valid_points).map_err(|_| {
            ANNError::log_index_error(format_args!(
                "stored index has {} non-start points, which does not fit in a u32 id",
                valid_points,
            ))
        })?;
        let start_points = StartPoints::new(first_start_point, ctx.num_frozen_points)?;

        Ok(Self {
            base_vectors,
            aux_vectors,
            neighbor_provider: SimpleNeighborProviderAsync::load_with(provider, ctx).await?,
            deleted,
            metric: ctx.metric,
            start_points,
            context: std::marker::PhantomData,
        })
    }
}
impl LoadWith<usize> for NoDeletes {
    type Error = ANNError;

    /// `NoDeletes` holds no state, so loading ignores both the storage provider and the
    /// point count.
    async fn load_with<P>(_provider: &P, _: &usize) -> ANNResult<Self>
    where
        P: StorageReadProvider,
    {
        Ok(Self)
    }
}
impl LoadWith<usize> for TableDeleteProviderAsync {
    type Error = ANNError;

    /// Construct a fresh table via `TableDeleteProviderAsync::new(num_points)`. Nothing
    /// is read from `provider`.
    async fn load_with<P>(_: &P, num_points: &usize) -> ANNResult<Self>
    where
        P: StorageReadProvider,
    {
        let capacity = *num_points;
        Ok(Self::new(capacity))
    }
}
impl<U, V, D, Ctx> DataProvider for DefaultProvider<U, V, D, Ctx>
where
    U: AsyncFriendly,
    V: AsyncFriendly,
    D: AsyncFriendly,
    Ctx: ExecutionContext,
{
    type Context = Ctx;
    type InternalId = u32;
    type ExternalId = u32;
    type Error = ANNError;
    type Guard = NoopGuard<u32>;

    /// Identity mapping: an external id is its own internal id. This never fails.
    fn to_internal_id(
        &self,
        _context: &Self::Context,
        gid: &Self::ExternalId,
    ) -> Result<Self::InternalId, Self::Error> {
        let internal: u32 = *gid;
        Ok(internal)
    }

    /// Identity mapping: an internal id is its own external id. This never fails.
    fn to_external_id(
        &self,
        _context: &Self::Context,
        id: Self::InternalId,
    ) -> Result<Self::ExternalId, Self::Error> {
        let external: u32 = id;
        Ok(external)
    }
}
impl<U, V, Ctx> Delete for DefaultProvider<U, V, TableDeleteProviderAsync, Ctx>
where
    U: AsyncFriendly,
    V: AsyncFriendly,
    Ctx: ExecutionContext,
{
    /// Return a deleted slot to service with an empty adjacency list.
    ///
    /// The neighbor list is cleared first. The slot is un-deleted only if that succeeds,
    /// so a failed reset leaves the slot marked deleted rather than valid with stale
    /// neighbors.
    ///
    /// # Errors
    ///
    /// Propagates the neighbor-provider error, with added context, if the reset fails.
    fn release(
        &self,
        _context: &Ctx,
        id: Self::InternalId,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
        let res = self
            .neighbor_provider
            .set_neighbors_sync(id.into_usize(), &[])
            .map_err(|err| err.context(format!("resetting neighbors for undeleted id {}", id)));
        if res.is_ok() {
            self.deleted.undelete(id.into_usize());
        }
        std::future::ready(res)
    }

    /// Mark `gid` as deleted in the delete table. Always resolves to `Ok(())`.
    #[inline]
    fn delete(
        &self,
        _context: &Ctx,
        gid: &Self::ExternalId,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
        self.deleted.delete(gid.into_usize());
        std::future::ready(Ok(()))
    }

    /// External and internal ids coincide, so this forwards to `status_by_internal_id`.
    #[inline]
    fn status_by_external_id(
        &self,
        context: &Ctx,
        gid: &Self::ExternalId,
    ) -> impl Future<Output = Result<ElementStatus, Self::Error>> + Send {
        self.status_by_internal_id(context, *gid)
    }

    /// Report `Deleted` if the delete table flags `id`, and `Valid` otherwise.
    #[inline]
    fn status_by_internal_id(
        &self,
        _context: &Ctx,
        id: Self::InternalId,
    ) -> impl Future<Output = Result<ElementStatus, Self::Error>> + Send {
        let status = if self.deleted.is_deleted(id.into_usize()) {
            ElementStatus::Deleted
        } else {
            ElementStatus::Valid
        };
        std::future::ready(Ok(status))
    }
}
/// Read access to adjacency lists through a shared reference. The body is synchronous,
/// so the returned future completes on its first poll.
impl NeighborAccessor for &SimpleNeighborProviderAsync<u32> {
    /// Fill `neighbors` with the adjacency list of `id` and hand the accessor back.
    async fn get_neighbors(
        self,
        id: Self::Id,
        neighbors: &mut AdjacencyList<Self::Id>,
    ) -> ANNResult<Self> {
        let slot = id.into_usize();
        self.get_neighbors_sync(slot, neighbors)?;
        Ok(self)
    }
}
/// Write access to adjacency lists through a shared reference. Both operations are
/// synchronous, so the returned futures complete on their first poll.
impl NeighborAccessorMut for &SimpleNeighborProviderAsync<u32> {
    /// Replace the adjacency list of `id` with `neighbors`.
    async fn set_neighbors(self, id: u32, neighbors: &[u32]) -> ANNResult<Self> {
        let slot = id.into_usize();
        self.set_neighbors_sync(slot, neighbors)?;
        Ok(self)
    }

    /// Append `new_neighbor_ids` to the adjacency list of `id` via `append_vector_sync`.
    async fn append_vector(self, id: u32, new_neighbor_ids: &[u32]) -> ANNResult<Self> {
        let slot = id.into_usize();
        self.append_vector_sync(slot, new_neighbor_ids)?;
        Ok(self)
    }
}
impl<U, V, D, Ctx> DefaultAccessor for DefaultProvider<U, V, D, Ctx>
where
    U: AsyncFriendly,
    V: AsyncFriendly,
    D: AsyncFriendly,
    Ctx: ExecutionContext,
{
    type Accessor<'a> = &'a SimpleNeighborProviderAsync<u32>;

    /// The default accessor is a shared borrow of the graph's neighbor provider.
    fn default_accessor(&self) -> Self::Accessor<'_> {
        &self.neighbor_provider
    }
}
impl<U, V, D, Ctx, T> SetElement<&[T]> for DefaultProvider<U, V, D, Ctx>
where
    T: VectorRepr,
    U: AsyncFriendly + SetElementHelper<T>,
    V: AsyncFriendly + SetElementHelper<T>,
    D: AsyncFriendly,
    Ctx: ExecutionContext,
{
    type SetError = ANNError;

    /// Store `element` at `id` in the auxiliary store and then the base store.
    ///
    /// Work stops at the first failure. If the base write fails, the auxiliary store has
    /// already been updated. On success this resolves to a no-op guard for `id`.
    fn set_element(
        &self,
        _context: &Self::Context,
        id: &u32,
        element: &[T],
    ) -> impl Future<Output = Result<Self::Guard, Self::SetError>> + Send {
        let outcome = self
            .aux_vectors
            .set_element(id, element)
            .and_then(|_| self.base_vectors.set_element(id, element))
            .map(|_| NoopGuard::new(*id));
        std::future::ready(outcome)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::graph::provider::async_::{
        common::{NoStore, TableBasedDeletes},
        inmem::CreateFullPrecision,
    };

    /// Assert that `$id` reports `$expected` through both the internal-id and the
    /// external-id status queries. Must be expanded inside an `async` context.
    macro_rules! assert_status {
        ($provider:expr, $ctx:expr, $id:expr, $expected:expr) => {{
            assert_eq!(
                $provider.status_by_internal_id($ctx, $id).await.unwrap(),
                $expected
            );
            assert_eq!(
                $provider.status_by_external_id($ctx, &$id).await.unwrap(),
                $expected
            );
        }};
    }

    #[tokio::test]
    async fn test_data_provider_and_delete_interface() {
        let ctx = &DefaultContext;
        let params = DefaultProviderParameters {
            max_points: 10,
            frozen_points: NonZeroUsize::new(2).unwrap(),
            dim: 5,
            metric: Metric::L2,
            prefetch_lookahead: None,
            prefetch_cache_line_level: None,
            max_degree: (64.0 * 1.2) as u32,
        };
        let provider = DefaultProvider::new_empty(
            params,
            CreateFullPrecision::<f32>::new(5, None),
            NoStore,
            TableBasedDeletes,
        )
        .unwrap();

        // 10 regular slots followed by 2 start points.
        assert_eq!((&provider).into_iter(), 0..(10 + 2));
        let all_ids = provider.iter();

        // Ids map to themselves; every slot starts valid and can be deleted.
        for id in all_ids.clone() {
            assert_eq!(provider.to_external_id(ctx, id).unwrap(), id);
            assert_eq!(provider.to_internal_id(ctx, &id).unwrap(), id);
            assert_status!(provider, ctx, id, ElementStatus::Valid);
            provider.delete(ctx, &id).await.unwrap();
            assert_status!(provider, ctx, id, ElementStatus::Deleted);
        }

        // Releasing a deleted slot makes it valid again and wipes its neighbors.
        for id in all_ids.clone() {
            provider
                .neighbors()
                .set_neighbors(id, &[1, 2])
                .await
                .unwrap();
            provider.release(ctx, id).await.unwrap();
            assert_status!(provider, ctx, id, ElementStatus::Valid);

            let mut adjacency = AdjacencyList::new();
            provider
                .neighbors()
                .get_neighbors(id, &mut adjacency)
                .await
                .unwrap();
            assert!(adjacency.to_vec().is_empty());

            provider.delete(ctx, &id).await.unwrap();
        }

        // Clearing the delete set revalidates everything at once.
        provider.clear_delete_set();
        for id in all_ids {
            assert_status!(provider, ctx, id, ElementStatus::Valid);
        }

        // NOTE(review): id 100 is out of range, but the vector also has 4 elements
        // against `dim` 5. Either check could be what makes this fail.
        assert!(
            provider
                .set_element(ctx, &100, &[1.0, 2.0, 3.0, 4.0])
                .await
                .is_err()
        );
    }
}