use super::{
bf_cache::{self, Cache},
error::CacheAccessError,
provider::{self as cache_provider, NeighborStatus},
utils::{CacheKey, Graph, HitStats, KeyGen, LocalStats},
};
use diskann::{
error::{RankedError, ToRanked, TransientError},
graph::{AdjacencyList, test::provider as test_provider, test::provider::Context, workingset},
provider::{self as core_provider},
};
use diskann_utils::future::AsyncFriendly;
use diskann_vector::distance::Metric;
/// Identifies which kind of payload a cache entry holds for a given `u32` id.
///
/// The `u8` discriminant is handed to `CacheKey::new` together with the id by the
/// `KeyGen` implementation for this type.
#[derive(Debug, Clone, Copy)]
#[repr(u8)]
enum Tag {
/// Tag used by the neighbor-list view built in `ExampleCache::neighbors`.
AdjacencyList = 0,
/// Tag used as the key generator for element (vector) lookups in `CacheAccessor`.
Vector = 1,
}
/// Every `Tag` variant. Iterated by `ExampleCache::invalidate` so that all entries
/// belonging to an id are deleted together.
const TAGS: [Tag; 2] = [Tag::AdjacencyList, Tag::Vector];
impl KeyGen<u32> for Tag {
    type Key = CacheKey<u32>;

    /// Builds the cache key for `id`, passing this tag's `u8` discriminant as the
    /// second component of the key.
    fn generate(&self, id: u32) -> Self::Key {
        let discriminant = *self as u8;
        CacheKey::new(id, discriminant.into())
    }
}
/// Example cache that stores both adjacency lists and vectors in a single `Cache`,
/// keeping separate hit/miss statistics for each kind of entry.
#[derive(Debug)]
pub struct ExampleCache {
/// Backing store shared by adjacency-list and vector entries.
cache: Cache,
/// Optional list of ids for which `CacheAccessor::try_get_neighbors` reports
/// `NeighborStatus::Uncacheable` instead of consulting the cache.
uncacheable: Option<Vec<u32>>,
/// Statistics handed to the `Graph` view created by `neighbors`.
neighbor_stats: HitStats,
/// Statistics wrapped in a `LocalStats` for element lookups by each accessor.
vector_stats: HitStats,
}
impl ExampleCache {
    /// Creates a new cache whose backing `Cache` is constructed from `bytes`
    /// (presumably the capacity in bytes; confirm against `Cache::new`).
    ///
    /// When `uncacheable` is `Some`, neighbor lookups through a `CacheAccessor` for any
    /// id in the list report `NeighborStatus::Uncacheable` instead of reading the cache.
    ///
    /// # Panics
    ///
    /// Panics if `Cache::new` returns an error for `bytes`.
    pub fn new(
        bytes: diskann_quantization::num::PowerOfTwo,
        uncacheable: Option<Vec<u32>>,
    ) -> Self {
        // The constructor's signature is infallible, so a failure to build the backing
        // store is surfaced as a panic with an explicit message rather than a bare unwrap.
        let cache = Cache::new(bytes)
            .expect("ExampleCache::new: failed to construct the backing cache");
        Self {
            cache,
            uncacheable,
            neighbor_stats: HitStats::new(),
            vector_stats: HitStats::new(),
        }
    }

    /// Deletes every cache entry associated with `id`, one per `Tag` variant.
    fn invalidate(&self, id: u32) {
        for tag in TAGS {
            self.cache.delete(tag.generate(id));
        }
    }

    /// Returns a `Graph` view over the cache for adjacency lists, keyed with
    /// `Tag::AdjacencyList` and recording into `neighbor_stats`.
    ///
    /// `max_degree` is forwarded unchanged to `Graph::new`.
    fn neighbors(&self, max_degree: usize) -> Graph<'_, u32, Tag> {
        Graph::new(
            &self.cache,
            max_degree,
            Tag::AdjacencyList,
            &self.neighbor_stats,
        )
    }
}
impl cache_provider::Evict<u32> for ExampleCache {
fn evict(&self, id: u32) {
self.invalidate(id)
}
}
/// Per-accessor handle onto an `ExampleCache`.
///
/// Neighbor operations go through `graph`; element operations go directly to the
/// underlying cache (reached via `graph.cache()`) using `cacher` and `keygen`.
#[derive(Debug)]
pub struct CacheAccessor<'a, T> {
/// Adjacency-list view; also the route to the underlying cache for element lookups.
graph: Graph<'a, u32, Tag>,
/// Helper passed to the cache's `get`/`set` calls for elements.
cacher: T,
/// Hit/miss counters updated by `get_cached`.
stats: LocalStats<'a>,
/// Ids for which `try_get_neighbors` returns `NeighborStatus::Uncacheable`.
uncacheable: Option<&'a [u32]>,
/// Tag used to generate keys for element lookups.
keygen: Tag,
}
impl<T> cache_provider::ElementCache<u32, diskann_utils::lifetime::Slice<T>>
    for CacheAccessor<'_, bf_cache::VecCacher<T>>
where
    T: AsyncFriendly + bytemuck::Pod + Default + std::fmt::Debug,
{
    type Error = CacheAccessError;

    /// Looks up the element stored for `k`.
    ///
    /// Records a hit when an entry is found and a miss when it is not. A failure of the
    /// underlying cache read is returned as `CacheAccessError::read` and updates neither
    /// counter.
    fn get_cached(&mut self, k: u32) -> Result<Option<&[T]>, CacheAccessError> {
        let key = self.keygen.generate(k);
        let found = self
            .graph
            .cache()
            .get(key, &mut self.cacher)
            .map_err(|err| CacheAccessError::read(k, err))?;

        if let Some(value) = found {
            self.stats.hit();
            Ok(Some(value))
        } else {
            self.stats.miss();
            Ok(None)
        }
    }

    /// Stores `element` under the key generated for `k`, reporting a failure of the
    /// underlying cache write as `CacheAccessError::write`.
    fn set_cached(&mut self, k: u32, element: &&[T]) -> Result<(), CacheAccessError> {
        let key = self.keygen.generate(k);
        match self.graph.cache().set(key, &mut self.cacher, element) {
            Ok(()) => Ok(()),
            Err(err) => Err(CacheAccessError::write(k, err)),
        }
    }
}
impl<T> cache_provider::NeighborCache<u32> for CacheAccessor<'_, T>
where
    T: AsyncFriendly,
{
    type Error = CacheAccessError;

    /// Attempts to read the adjacency list of `id` from the cache into `neighbors`.
    ///
    /// Ids on the `uncacheable` list never touch the cache: a miss is recorded on the
    /// graph statistics and `NeighborStatus::Uncacheable` is returned. All other ids are
    /// delegated to the underlying `Graph`.
    fn try_get_neighbors(
        &mut self,
        id: u32,
        neighbors: &mut AdjacencyList<u32>,
    ) -> Result<NeighborStatus, CacheAccessError> {
        let excluded = self.uncacheable.is_some_and(|ids| ids.contains(&id));
        if excluded {
            self.graph.stats_mut().miss();
            return Ok(NeighborStatus::Uncacheable);
        }
        self.graph.try_get_neighbors(id, neighbors)
    }

    /// Delegates to `Graph::set_neighbors` for `id`.
    fn set_neighbors(&mut self, id: u32, neighbors: &[u32]) -> Result<(), CacheAccessError> {
        self.graph.set_neighbors(id, neighbors)
    }

    /// Delegates to `Graph::invalidate_neighbors` for `id`.
    fn invalidate_neighbors(&mut self, id: u32) {
        self.graph.invalidate_neighbors(id)
    }
}
impl<'a> cache_provider::AsCacheAccessorFor<'a, test_provider::Accessor<'a>> for ExampleCache {
    type Accessor = CacheAccessor<'a, bf_cache::VecCacher<f32>>;
    type Error = diskann::error::Infallible;

    /// Pairs `inner` with a fresh `CacheAccessor` over this cache.
    ///
    /// The neighbor view is sized from the inner provider's `max_degree()` and the
    /// vector cacher from its `dim()`. Element lookups record into `vector_stats` and
    /// use `Tag::Vector` for key generation. This method always returns `Ok`.
    fn as_cache_accessor_for(
        &'a self,
        inner: test_provider::Accessor<'a>,
    ) -> Result<
        cache_provider::CachingAccessor<test_provider::Accessor<'a>, Self::Accessor>,
        Self::Error,
    > {
        let provider = inner.provider();

        let graph = self.neighbors(provider.max_degree());
        let cacher = bf_cache::VecCacher::<f32>::new(provider.dim());
        let uncacheable = self.uncacheable.as_deref();
        let stats = LocalStats::new(&self.vector_stats);

        let cache_accessor = CacheAccessor {
            graph,
            cacher,
            stats,
            uncacheable,
            keygen: Tag::Vector,
        };
        Ok(cache_provider::CachingAccessor::new(inner, cache_accessor))
    }
}
/// Working-set map type filled by the `CachedFill` implementation below: `u32` ids
/// mapped to boxed `f32` slices.
type WorkingSet = workingset::Map<u32, Box<[f32]>, workingset::map::Ref<[f32]>>;
/// Shorthand for the `CacheAccessor` specialization that caches `f32` vectors.
type AccessorCache<'a> = CacheAccessor<'a, bf_cache::VecCacher<f32>>;
impl<'a> cache_provider::CachedFill<AccessorCache<'a>, WorkingSet> for test_provider::Accessor<'a> {
    /// Fills `set` with an element for each id yielded by `itr`, preferring the cache.
    ///
    /// For every id whose slot in `set` is still vacant:
    /// * on a cache hit, the cached element is converted and inserted;
    /// * on a cache miss, the element is fetched from `self`, written to the cache and
    ///   then inserted.
    ///
    /// Slots that are already seeded or occupied are left untouched.
    ///
    /// # Errors
    ///
    /// Cache read/write failures are returned as `CachingError::Cache`. A fetch error
    /// that ranks as transient is acknowledged and the id is skipped; any other fetch
    /// error aborts the fill with `CachingError::Inner`.
    fn cached_fill<'b, Itr>(
        &'b mut self,
        cache: &'b mut AccessorCache<'a>,
        set: &'b mut WorkingSet,
        itr: Itr,
    ) -> impl diskann_utils::future::SendFuture<
        Result<Self::View<'b>, cache_provider::CachingError<Self::Error, CacheAccessError>>,
    >
    where
        Itr: ExactSizeIterator<Item = u32> + Clone + Send + Sync,
    {
        use cache_provider::{CachingError, ElementCache};
        use diskann::provider::{Accessor, CacheableAccessor};

        async move {
            set.prepare(itr.clone());

            for id in itr {
                // Only vacant slots need any work.
                let workingset::map::Entry::Vacant(slot) = set.entry(id) else {
                    continue;
                };

                // Fast path: the element is already cached.
                let cached = cache
                    .get_cached(id)
                    .map_err(CachingError::<Self::Error, _>::Cache)?;
                if let Some(hit) = cached {
                    slot.insert(Self::from_cached(hit).into());
                    continue;
                }

                // Slow path: fetch from the inner accessor.
                let fetched = match self.get_element(id).await {
                    Ok(fetched) => fetched,
                    Err(err) => match err.to_ranked() {
                        RankedError::Transient(transient) => {
                            transient.acknowledge("error during mapping of element to cache");
                            continue;
                        }
                        RankedError::Error(critical) => {
                            return Err(CachingError::Inner(critical));
                        }
                    },
                };

                // Populate the cache before handing ownership to the working set.
                cache
                    .set_cached(id, Self::as_cached(&fetched))
                    .map_err(CachingError::Cache)?;
                slot.insert(fetched.into());
            }

            Ok(set.view())
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use diskann::{
graph::{DiskANNIndex, glue::SearchStrategy},
provider::{
Accessor, DataProvider, Delete, NeighborAccessor, NeighborAccessorMut, SetElement,
},
};
use diskann_quantization::num::PowerOfTwo;
use diskann_utils::views::Matrix;
use diskann_vector::{PureDistanceFunction, distance::SquaredL2};
use rstest::rstest;
use crate::{
index::diskann_async::tests as async_tests,
model::graph::provider::async_::caching::provider::{AsCacheAccessorFor, CachingProvider},
};
/// Builds a two-dimensional L2 test provider (max degree 10, start point `u32::MAX`
/// at the origin) wrapped in an `ExampleCache` of `1024 * 16` bytes.
///
/// `uncacheable` is forwarded to `ExampleCache::new`.
fn test_provider(
    uncacheable: Option<Vec<u32>>,
) -> CachingProvider<test_provider::Provider, ExampleCache> {
    let dim = 2;
    let start = test_provider::StartPoint::new(u32::MAX, vec![0.0; dim]);
    let config = test_provider::Config::new(Metric::L2, 10, start).unwrap();

    let inner = test_provider::Provider::new(config);
    let cache = ExampleCache::new(PowerOfTwo::new(1024 * 16).unwrap(), uncacheable);
    CachingProvider::new(inner, cache)
}
fn make_accessor<'a>(
provider: &'a CachingProvider<test_provider::Provider, ExampleCache>,
) -> cache_provider::CachingAccessor<
test_provider::Accessor<'a>,
CacheAccessor<'a, bf_cache::VecCacher<f32>>,
> {
provider
.cache()
.as_cache_accessor_for(test_provider::Accessor::new(provider.inner()))
.unwrap()
}
/// Walks a single element through insert, cached reads, invalidation, neighbor
/// updates, delete and release, checking the inner provider's metrics and the cache's
/// hit/miss counters after every step.
#[tokio::test]
async fn basic_operations_happy_path() {
    let provider = test_provider(None);
    let ctx = &Context::new();

    // Nothing has been inserted yet, so id translation fails both ways.
    assert!(provider.to_external_id(ctx, 0).is_err());
    assert!(provider.to_internal_id(ctx, &0).is_err());

    // Inserting reaches the inner provider exactly once.
    assert_eq!(provider.inner().metrics().set_vector, 0);
    provider.set_element(ctx, &0, &[1.0, 2.0]).await.unwrap();
    assert_eq!(provider.inner().metrics().set_vector, 1);
    assert_eq!(provider.to_external_id(ctx, 0).unwrap(), 0);
    assert_eq!(provider.to_internal_id(ctx, &0).unwrap(), 0);

    // First vector read: a miss that is served by the inner provider.
    {
        let mut accessor = make_accessor(&provider);
        assert_eq!(provider.inner().metrics().get_vector, 0);
        let vector = accessor.get_element(0).await.unwrap();
        assert_eq!(vector, &[1.0, 2.0]);
        assert_eq!(accessor.cache().stats.get_local_misses(), 1);
        assert_eq!(accessor.cache().stats.get_local_hits(), 0);
    }
    assert_eq!(provider.inner().metrics().get_vector, 1);

    // Second vector read: a hit, the inner provider is not consulted.
    {
        let mut accessor = make_accessor(&provider);
        let vector = accessor.get_element(0).await.unwrap();
        assert_eq!(vector, &[1.0, 2.0]);
        assert_eq!(accessor.cache().stats.get_local_misses(), 0);
        assert_eq!(accessor.cache().stats.get_local_hits(), 1);
    }
    assert_eq!(provider.inner().metrics().get_vector, 1);

    let mut neighbors = AdjacencyList::new();

    // Write neighbors, then read them back: the first read is a miss.
    {
        let mut accessor = make_accessor(&provider);
        assert_eq!(provider.inner().metrics().set_neighbors, 0);
        accessor.set_neighbors(0, &[1, 2, 3]).await.unwrap();
        assert_eq!(provider.inner().metrics().set_neighbors, 1);

        assert_eq!(provider.inner().metrics().get_neighbors, 0);
        accessor.get_neighbors(0, &mut neighbors).await.unwrap();
        assert_eq!(accessor.cache().graph.stats().get_local_misses(), 1);
        assert_eq!(accessor.cache().graph.stats().get_local_hits(), 0);
        assert_eq!(&*neighbors, &[1, 2, 3]);
    }
    assert_eq!(provider.inner().metrics().get_neighbors, 1);

    // Second neighbor read: a hit.
    {
        let mut accessor = make_accessor(&provider);
        neighbors.clear();
        accessor.get_neighbors(0, &mut neighbors).await.unwrap();
        assert_eq!(&*neighbors, &[1, 2, 3]);
        assert_eq!(accessor.cache().graph.stats().get_local_misses(), 0);
        assert_eq!(accessor.cache().graph.stats().get_local_hits(), 1);
    }
    assert_eq!(provider.inner().metrics().get_neighbors, 1);

    // After invalidation the vector has to be fetched from the inner provider again.
    {
        let mut accessor = make_accessor(&provider);
        provider.cache().invalidate(0);
        let vector = accessor.get_element(0).await.unwrap();
        assert_eq!(vector, &[1.0, 2.0]);
        assert_eq!(accessor.cache().stats.get_local_misses(), 1);
        assert_eq!(accessor.cache().stats.get_local_hits(), 0);
    }
    assert_eq!(provider.inner().metrics().get_vector, 2);

    // ... and is cached again afterwards.
    {
        let mut accessor = make_accessor(&provider);
        let vector = accessor.get_element(0).await.unwrap();
        assert_eq!(vector, &[1.0, 2.0]);
        assert_eq!(accessor.cache().stats.get_local_misses(), 0);
        assert_eq!(accessor.cache().stats.get_local_hits(), 1);
    }
    assert_eq!(provider.inner().metrics().get_vector, 2);

    {
        let mut accessor = make_accessor(&provider);

        // The invalidation above also dropped the adjacency list: this read misses.
        neighbors.clear();
        accessor.get_neighbors(0, &mut neighbors).await.unwrap();
        assert_eq!(&*neighbors, &[1, 2, 3]);
        assert_eq!(provider.inner().metrics().get_neighbors, 2);
        assert_eq!(accessor.cache().graph.stats().get_local_misses(), 1);
        assert_eq!(accessor.cache().graph.stats().get_local_hits(), 0);

        // Overwriting the neighbors makes the next read miss again.
        accessor.set_neighbors(0, &[2, 3, 4]).await.unwrap();
        accessor.get_neighbors(0, &mut neighbors).await.unwrap();
        assert_eq!(&*neighbors, &[2, 3, 4]);
        assert_eq!(provider.inner().metrics().set_neighbors, 2);
        assert_eq!(provider.inner().metrics().get_neighbors, 3);
        assert_eq!(accessor.cache().graph.stats().get_local_misses(), 2);
        assert_eq!(accessor.cache().graph.stats().get_local_hits(), 0);

        // Appending does not count as a `set_neighbors` on the inner provider, but the
        // following read still misses.
        accessor.append_vector(0, &[1]).await.unwrap();
        accessor.get_neighbors(0, &mut neighbors).await.unwrap();
        assert_eq!(&*neighbors, &[2, 3, 4, 1]);
        assert_eq!(provider.inner().metrics().set_neighbors, 2);
        assert_eq!(provider.inner().metrics().get_neighbors, 4);
        assert_eq!(accessor.cache().graph.stats().get_local_misses(), 3);
        assert_eq!(accessor.cache().graph.stats().get_local_hits(), 0);

        // Status before the delete: id 0 is valid, id 1 is unknown.
        assert_eq!(
            provider.status_by_internal_id(ctx, 0).await.unwrap(),
            core_provider::ElementStatus::Valid
        );
        assert_eq!(
            provider.status_by_external_id(ctx, &0).await.unwrap(),
            core_provider::ElementStatus::Valid
        );
        assert!(provider.status_by_internal_id(ctx, 1).await.is_err());
        assert!(provider.status_by_external_id(ctx, &1).await.is_err());

        provider.delete(ctx, &0).await.unwrap();

        // Status after the delete: id 0 is reported as deleted.
        assert_eq!(
            provider.status_by_internal_id(ctx, 0).await.unwrap(),
            core_provider::ElementStatus::Deleted
        );
        assert_eq!(
            provider.status_by_external_id(ctx, &0).await.unwrap(),
            core_provider::ElementStatus::Deleted
        );
        assert!(provider.status_by_internal_id(ctx, 1).await.is_err());
        assert!(provider.status_by_external_id(ctx, &1).await.is_err());

        // A deleted element is still readable, and both reads are cache hits.
        let vector = accessor.get_element(0).await.unwrap();
        assert_eq!(vector, &[1.0, 2.0]);
        assert_eq!(provider.inner().metrics().get_vector, 2);
        assert_eq!(accessor.cache().stats.get_local_misses(), 0);
        assert_eq!(accessor.cache().stats.get_local_hits(), 1);

        accessor.get_neighbors(0, &mut neighbors).await.unwrap();
        assert_eq!(&*neighbors, &[2, 3, 4, 1]);
        assert_eq!(provider.inner().metrics().set_neighbors, 2);
        assert_eq!(provider.inner().metrics().get_neighbors, 4);
        assert_eq!(accessor.cache().graph.stats().get_local_misses(), 3);
        assert_eq!(accessor.cache().graph.stats().get_local_hits(), 1);

        provider.release(ctx, 0).await.unwrap();

        // After the release every lookup for id 0 fails. The cache records a miss for
        // each failed read while the inner provider's read metrics stay unchanged.
        assert!(provider.status_by_internal_id(ctx, 0).await.is_err());
        assert!(provider.status_by_external_id(ctx, &0).await.is_err());

        assert!(accessor.get_element(0).await.is_err());
        assert_eq!(provider.inner().metrics().get_vector, 2);
        assert_eq!(accessor.cache().stats.get_local_misses(), 1);
        assert_eq!(accessor.cache().stats.get_local_hits(), 1);

        assert!(accessor.get_neighbors(0, &mut neighbors).await.is_err());
        assert_eq!(provider.inner().metrics().set_neighbors, 2);
        assert_eq!(provider.inner().metrics().get_neighbors, 4);
        assert_eq!(accessor.cache().graph.stats().get_local_misses(), 4);
        assert_eq!(accessor.cache().graph.stats().get_local_hits(), 1);
    }

    // The shared statistics hold the totals of all accessors used above.
    assert_eq!(provider.cache().vector_stats.get_hits(), 3);
    assert_eq!(provider.cache().vector_stats.get_misses(), 3);
    assert_eq!(provider.cache().neighbor_stats.get_hits(), 2);
    assert_eq!(provider.cache().neighbor_stats.get_misses(), 5);
}
#[tokio::test]
async fn test_uncacheable() {
let uncacheable = u32::MAX;
let provider = test_provider(Some(vec![uncacheable]));
let ctx = &Context::new();
let mut accessor = provider
.cache()
.as_cache_accessor_for(test_provider::Accessor::new(provider.inner()))
.unwrap();
provider.set_element(ctx, &0, &[1.0, 2.0]).await.unwrap();
assert_eq!(provider.inner().metrics().set_neighbors, 0);
accessor.set_neighbors(0, &[1, 2, 3]).await.unwrap();
assert_eq!(
provider.inner().metrics().set_neighbors,
1,
);
let mut list = AdjacencyList::new();
assert_eq!(provider.inner().metrics().get_neighbors, 0);
accessor.get_neighbors(0, &mut list).await.unwrap();
assert_eq!(
provider.inner().metrics().get_neighbors,
1,
);
assert_eq!(
accessor.cache().graph.stats().get_local_misses(),
1,
);
assert_eq!(accessor.cache().graph.stats().get_local_hits(), 0);
assert_eq!(&*list, &[1, 2, 3]);
list.clear();
accessor.get_neighbors(0, &mut list).await.unwrap();
assert_eq!(&*list, &[1, 2, 3]);
assert_eq!(provider.inner().metrics().get_neighbors, 1);
assert_eq!(accessor.cache().graph.stats().get_local_misses(), 1);
assert_eq!(
accessor.cache().graph.stats().get_local_hits(),
1,
);
assert_eq!(provider.inner().metrics().set_neighbors, 1);
accessor.set_neighbors(uncacheable, &[4, 5]).await.unwrap();
assert_eq!(
provider.inner().metrics().set_neighbors,
2,
);
assert_eq!(provider.inner().metrics().get_neighbors, 1);
accessor
.get_neighbors(uncacheable, &mut list)
.await
.unwrap();
assert_eq!(
provider.inner().metrics().get_neighbors,
2,
);
assert_eq!(
accessor.cache().graph.stats().get_local_misses(),
2,
);
assert_eq!(accessor.cache().graph.stats().get_local_hits(), 1);
assert_eq!(&*list, &[4, 5]);
assert_eq!(provider.inner().metrics().get_neighbors, 2);
accessor
.get_neighbors(uncacheable, &mut list)
.await
.unwrap();
assert_eq!(
provider.inner().metrics().get_neighbors,
3,
);
assert_eq!(
accessor.cache().graph.stats().get_local_misses(),
3,
);
assert_eq!(accessor.cache().graph.stats().get_local_hits(), 1);
assert_eq!(&*list, &[4, 5]);
}
#[rstest]
#[case(1, 100)]
#[case(3, 7)]
#[case(4, 5)]
#[tokio::test]
async fn grid_search(#[case] dim: usize, #[case] grid_size: usize) {
let l = 10;
let max_degree = 2 * dim;
let num_points = (grid_size).pow(dim as u32);
let start_id = u32::MAX;
let start_point = vec![grid_size as f32; dim];
let metric = Metric::L2;
let cache_size = PowerOfTwo::new(128 * 1024).unwrap();
let index_config = diskann::graph::config::Builder::new(
max_degree,
diskann::graph::config::MaxDegree::default_slack(),
l,
metric.into(),
)
.build()
.unwrap();
let test_config = test_provider::Config::new(
metric,
index_config.max_degree().get(),
test_provider::StartPoint::new(start_id, start_point.clone()),
)
.unwrap();
let ctx = &Context::new();
let mut vectors = <f32 as async_tests::GenerateGrid>::generate_grid(dim, grid_size);
let provider = CachingProvider::new(
test_provider::Provider::new(test_config),
ExampleCache::new(cache_size, None),
);
let index = Arc::new(DiskANNIndex::new(index_config, provider, None));
let adjacency_lists = async_tests::grid_from_dim(dim).neighbors(grid_size);
assert_eq!(adjacency_lists.len(), num_points);
assert_eq!(vectors.len(), num_points);
let strategy = cache_provider::Cached::new(test_provider::Strategy::new());
async_tests::populate_data(index.provider(), ctx, &vectors).await;
{
let mut accessor =
<cache_provider::Cached<test_provider::Strategy> as SearchStrategy<
cache_provider::CachingProvider<test_provider::Provider, ExampleCache>,
&[f32],
>>::search_accessor(&strategy, index.provider(), ctx)
.unwrap();
async_tests::populate_graph(&mut accessor, &adjacency_lists).await;
accessor
.set_neighbors(start_id, &[num_points as u32 - 1])
.await
.unwrap();
}
let corpus: diskann_utils::views::Matrix<f32> = async_tests::squish(vectors.iter(), dim);
let mut paged_tests = Vec::new();
let query = vec![0.0; dim];
let gt = crate::test_utils::groundtruth(corpus.as_view(), &query, |a, b| {
SquaredL2::evaluate(a, b)
});
paged_tests.push(async_tests::PagedSearch::new(query, gt));
let gt = crate::test_utils::groundtruth(corpus.as_view(), &start_point, |a, b| {
SquaredL2::evaluate(a, b)
});
paged_tests.push(async_tests::PagedSearch::new(start_point.clone(), gt));
vectors.push(start_point.clone());
async_tests::check_grid_search(
&index,
&vectors,
&paged_tests,
strategy.clone(),
strategy.clone(),
)
.await;
}
/// Prints the inner provider's read/write metrics and the cache's hit/miss counters,
/// then asserts that the number of reads reaching the inner provider equals the number
/// of cache misses, for neighbors and for vectors.
fn check_stats(caching: &CachingProvider<test_provider::Provider, ExampleCache>) {
    let provider = caching.inner();
    let cache = caching.cache();

    let neighbor_reads = provider.metrics().get_neighbors;
    let neighbor_writes = provider.metrics().set_neighbors;
    let vector_reads = provider.metrics().get_vector;
    let vector_writes = provider.metrics().set_vector;

    let neighbor_hits = cache.neighbor_stats.get_hits();
    let neighbor_misses = cache.neighbor_stats.get_misses();
    let vector_hits = cache.vector_stats.get_hits();
    let vector_misses = cache.vector_stats.get_misses();

    println!("neighbor reads: {}", neighbor_reads);
    println!("neighbor writes: {}", neighbor_writes);
    println!("vector reads: {}", vector_reads);
    println!("vector writes: {}", vector_writes);
    println!("neighbor hits: {}", neighbor_hits);
    println!("neighbor misses: {}", neighbor_misses);
    println!("vector hits: {}", vector_hits);
    println!("vector misses: {}", vector_misses);

    assert_eq!(neighbor_reads, neighbor_misses);
    assert_eq!(vector_reads, vector_misses);
}
/// Builds the index through the caching provider, once with one-at-a-time inserts and
/// once with a single `multi_insert`, and runs the shared grid-search checks plus the
/// cache statistics check on each.
#[rstest]
#[tokio::test]
async fn grid_search_with_build(
    #[values((1, 100), (3, 7), (4, 5))] dim_and_size: (usize, usize),
) {
    let (dim, grid_size) = dim_and_size;
    let l = 10;
    let metric = Metric::L2;
    let start_point = vec![grid_size as f32; dim];
    let cache_size = PowerOfTwo::new(128 * 1024).unwrap();
    let max_degree = 2 * dim;
    let num_points = grid_size.pow(dim as u32);

    let mut vectors = <f32 as async_tests::GenerateGrid>::generate_grid(dim, grid_size);

    let index_config = diskann::graph::config::Builder::new_with(
        max_degree,
        diskann::graph::config::MaxDegree::default_slack(),
        l,
        metric.into(),
        |b| {
            b.max_minibatch_par(10);
        },
    )
    .build()
    .unwrap();

    let start = test_provider::StartPoint::new(u32::MAX, start_point.clone());
    let test_config =
        test_provider::Config::new(metric, index_config.max_degree().get(), start).unwrap();

    // The extra trailing vector has the same coordinates as the start point.
    assert_eq!(vectors.len(), num_points);
    vectors.push(vec![grid_size as f32; dim]);

    // Each scenario below starts from a fresh provider, cache and index.
    let init_index = || {
        let provider = CachingProvider::new(
            test_provider::Provider::new(test_config.clone()),
            ExampleCache::new(cache_size, None),
        );
        Arc::new(DiskANNIndex::new(index_config.clone(), provider, None))
    };

    let strategy = cache_provider::Cached::new(test_provider::Strategy::new());
    let ctx = &Context::new();

    // Scenario 1: insert the grid points one at a time.
    {
        let index = init_index();
        for (id, vector) in vectors.iter().enumerate().take(num_points) {
            index
                .insert(strategy.clone(), ctx, &(id as u32), vector.as_slice())
                .await
                .unwrap();
        }
        check_stats(index.provider());

        async_tests::check_grid_search(
            &index,
            &vectors,
            &[],
            strategy.clone(),
            strategy.clone(),
        )
        .await;
        check_stats(index.provider());
    }

    // Scenario 2: insert all grid points in a single batch.
    {
        let index = init_index();
        let batch = Arc::new(async_tests::squish(vectors.iter().take(num_points), dim));
        let ids: Arc<[u32]> = (0..num_points as u32).collect();
        index
            .multi_insert::<_, Matrix<f32>>(strategy.clone(), ctx, batch, ids)
            .await
            .unwrap();

        async_tests::check_grid_search(
            &index,
            &vectors,
            &[],
            strategy.clone(),
            strategy.clone(),
        )
        .await;
        check_stats(index.provider());
    }
}
/// Deletes one corner of a four-point 2D square in place and checks the adjacency
/// lists observed through the caching accessor afterwards.
#[tokio::test]
async fn test_inplace_delete_2d() {
    let metric = Metric::L2;
    let num_points = 4;
    let start_id = num_points as u32;
    let start_point = vec![0.5, 0.5];
    let cache_size = PowerOfTwo::new(128 * 1024).unwrap();
    let strategy = cache_provider::Cached::new(test_provider::Strategy::new());

    let index_config = diskann::graph::config::Builder::new(
        4,
        diskann::graph::config::MaxDegree::default_slack(),
        10,
        metric.into(),
    )
    .build()
    .unwrap();

    let ctx = &Context::new();
    let start = test_provider::StartPoint::new(start_id, start_point.clone());
    let test_config =
        test_provider::Config::new(metric, index_config.max_degree().get(), start).unwrap();

    let caching_provider = CachingProvider::new(
        test_provider::Provider::new(test_config),
        ExampleCache::new(cache_size, None),
    );
    let index = DiskANNIndex::new(index_config, caching_provider, None);

    // The four corners of the unit square; the start point sits at its centre.
    let vectors = [
        vec![0.0, 0.0],
        vec![0.0, 1.0],
        vec![1.0, 0.0],
        vec![1.0, 1.0],
    ];
    // One adjacency list per id 0..=4 (presumably assigned by position; confirm
    // against `populate_graph`).
    let adjacency_lists = [
        AdjacencyList::from_iter_untrusted([4, 1]),
        AdjacencyList::from_iter_untrusted([4, 0]),
        AdjacencyList::from_iter_untrusted([4, 3]),
        AdjacencyList::from_iter_untrusted([4, 2]),
        AdjacencyList::from_iter_untrusted([0, 1, 2, 3]),
    ];

    let mut accessor = <cache_provider::Cached<test_provider::Strategy> as SearchStrategy<
        cache_provider::CachingProvider<test_provider::Provider, ExampleCache>,
        &[f32],
    >>::search_accessor(&strategy, index.provider(), ctx)
    .unwrap();

    async_tests::populate_data(index.provider(), ctx, &vectors).await;
    async_tests::populate_graph(&mut accessor, &adjacency_lists).await;

    let method = diskann::graph::InplaceDeleteMethod::VisitedAndTopK {
        k_value: 4,
        l_value: 10,
    };
    index
        .inplace_delete(strategy.clone(), ctx, &3, 3, method)
        .await
        .unwrap();

    // The deleted point is flagged as deleted by the provider.
    let status = index
        .data_provider
        .status_by_internal_id(ctx, 3)
        .await
        .unwrap();
    assert!(status.is_deleted());

    // Expected (sorted) neighbors of each surviving id, checked in this order.
    let expectations: [(u32, [u32; 3]); 4] = [
        (4, [0, 1, 2]),
        (2, [0, 1, 4]),
        (0, [1, 2, 4]),
        (1, [0, 2, 4]),
    ];
    for (id, expected) in expectations {
        let mut list = AdjacencyList::new();
        accessor.get_neighbors(id, &mut list).await.unwrap();
        list.sort();
        assert_eq!(&*list, &expected);
    }

    // The deleted point itself is left without neighbors.
    let mut list = AdjacencyList::new();
    accessor.get_neighbors(3, &mut list).await.unwrap();
    assert!(list.is_empty());
}
}