use std::sync::Arc;
use diskann::{
ANNResult,
graph::{Config, DiskANNIndex},
utils::VectorRepr,
};
use diskann_utils::future::AsyncFriendly;
use crate::model::{
self,
graph::provider::async_::{
common::{CreateDeleteProvider, CreateVectorStore, NoDeletes, NoStore},
inmem::{
CreateFullPrecision, DefaultProvider, DefaultProviderParameters, DefaultQuant,
FullPrecisionProvider,
},
},
};
/// Construct a test-friendly graph [`Config`] together with matching
/// [`DefaultProviderParameters`] in one call.
///
/// `modify` is applied to the config builder before `build()`, so tests can
/// tweak individual knobs without repeating the builder boilerplate.
#[cfg(test)]
pub(crate) fn simplified_builder(
    l_search: usize,
    pruned_degree: usize,
    metric: diskann_vector::distance::Metric,
    dim: usize,
    max_points: usize,
    modify: impl FnOnce(&mut diskann::graph::config::Builder),
) -> ANNResult<(Config, DefaultProviderParameters)> {
    let config = diskann::graph::config::Builder::new_with(
        pruned_degree,
        diskann::graph::config::MaxDegree::default_slack(),
        l_search,
        metric.into(),
        modify,
    )
    .build()?;
    let params = DefaultProviderParameters {
        max_points,
        // A single frozen (start) point is reserved in addition to `max_points`.
        frozen_points: diskann::utils::ONE,
        metric,
        dim,
        prefetch_lookahead: None,
        prefetch_cache_line_level: None,
        // Keep the provider's degree bound in sync with the built graph config.
        max_degree: config.max_degree_u32().get(),
    };
    Ok((config, params))
}
/// Train a product-quantization table over the rows of `data`.
///
/// Generates PQ pivots (`NUM_PQ_CENTROIDS` centroids, split into
/// `num_pq_chunks` chunks) and packages them into a [`model::pq::FixedChunkPQTable`].
pub fn train_pq<Pool>(
    data: diskann_utils::views::MatrixView<f32>,
    num_pq_chunks: usize,
    rng: &mut dyn rand::RngCore,
    pool: Pool,
) -> ANNResult<model::pq::FixedChunkPQTable>
where
    Pool: crate::utils::AsThreadPool,
{
    let dim = data.ncols();
    // NOTE(review): `5` and `false` are positional knobs of
    // `GeneratePivotArguments` — presumably an iteration count and a boolean
    // toggle; confirm against its definition before changing them.
    let pivot_args = model::GeneratePivotArguments::new(
        data.nrows(),
        data.ncols(),
        model::pq::NUM_PQ_CENTROIDS,
        num_pq_chunks,
        5,
        false,
    )?;
    // Output buffers filled in by pivot generation below.
    let mut centroid = vec![0.0; dim];
    let mut offsets = vec![0; num_pq_chunks + 1];
    let mut full_pivot_data = vec![0.0; model::pq::NUM_PQ_CENTROIDS * dim];
    model::pq::generate_pq_pivots_from_membuf(
        &pivot_args,
        data.as_slice(),
        &mut centroid,
        &mut offsets,
        &mut full_pivot_data,
        rng,
        // NOTE(review): looks like a cancellation/abort flag — confirm.
        &mut (false),
        pool,
    )?;
    model::pq::FixedChunkPQTable::new(dim, full_pivot_data.into(), centroid.into(), offsets.into())
}
/// In-memory index over full-precision vectors with no secondary vector store.
pub type MemoryIndex<T, D = NoDeletes> = Arc<DiskANNIndex<FullPrecisionProvider<T, NoStore, D>>>;
/// In-memory index keeping full-precision vectors alongside a quantized store `Q`.
pub type QuantMemoryIndex<T, Q, D = NoDeletes> = Arc<DiskANNIndex<FullPrecisionProvider<T, Q, D>>>;
/// [`QuantMemoryIndex`] specialized to the default (PQ) quantization.
pub type PQMemoryIndex<T, D = NoDeletes> = QuantMemoryIndex<T, DefaultQuant, D>;
/// Index holding only quantized vectors (no full-precision store). Unlike the
/// aliases above, this one is not wrapped in `Arc`.
pub type QuantOnlyIndex<Q, D = NoDeletes> = DiskANNIndex<DefaultProvider<NoStore, Q, D>>;
/// Create an empty in-memory, full-precision index (no quantized store).
///
/// The delete-provider precursor `deleter` determines whether the index
/// supports deletions; errors from provider construction are propagated.
pub fn new_index<T, D>(
    config: Config,
    params: DefaultProviderParameters,
    deleter: D,
) -> ANNResult<MemoryIndex<T, D::Target>>
where
    T: VectorRepr,
    D: CreateDeleteProvider,
    D::Target: AsyncFriendly,
{
    let precursor = CreateFullPrecision::new(params.dim, params.prefetch_cache_line_level);
    let provider = DefaultProvider::new_empty(params, precursor, NoStore, deleter)?;
    let index = DiskANNIndex::new(config, provider, None);
    Ok(Arc::new(index))
}
/// Create an empty in-memory index storing full-precision vectors plus a
/// secondary quantized store built from `quant`.
///
/// Errors from provider construction are propagated to the caller.
pub fn new_quant_index<T, Q, D>(
    config: Config,
    params: DefaultProviderParameters,
    quant: Q,
    deleter: D,
) -> ANNResult<QuantMemoryIndex<T, Q::Target, D::Target>>
where
    T: VectorRepr,
    Q: CreateVectorStore,
    Q::Target: AsyncFriendly,
    D: CreateDeleteProvider,
    D::Target: AsyncFriendly,
{
    let precursor = CreateFullPrecision::new(params.dim, params.prefetch_cache_line_level);
    let provider = DefaultProvider::new_empty(params, precursor, quant, deleter)?;
    let index = DiskANNIndex::new(config, provider, None);
    Ok(Arc::new(index))
}
/// Create an empty index that keeps only quantized vectors — no
/// full-precision store. Note the returned index is not `Arc`-wrapped.
pub fn new_quant_only_index<Q, D>(
    config: Config,
    params: DefaultProviderParameters,
    quant: Q,
    deleter: D,
) -> ANNResult<QuantOnlyIndex<Q::Target, D::Target>>
where
    Q: CreateVectorStore,
    Q::Target: AsyncFriendly,
    D: CreateDeleteProvider,
    D::Target: AsyncFriendly,
{
    let provider = DefaultProvider::new_empty(params, NoStore, quant, deleter)?;
    Ok(DiskANNIndex::new(config, provider, None))
}
#[cfg(test)]
pub(crate) mod tests {
use std::{
collections::HashSet,
marker::PhantomData,
num::{NonZeroU32, NonZeroUsize},
sync::{Arc, Mutex},
};
use crate::storage::VirtualStorageProvider;
use diskann::graph::test::synthetic::Grid;
use diskann::{
graph::{
self, AdjacencyList, InplaceDeleteMethod, StartPointStrategy,
config::IntraBatchCandidates,
glue::{
DefaultSearchStrategy, InplaceDeleteStrategy, InsertStrategy, MultiInsertStrategy,
SearchStrategy,
},
index::{QueryLabelProvider, QueryVisitDecision},
search::{Knn, Range},
search_output_buffer,
},
neighbor::Neighbor,
provider::{
AsNeighbor, AsNeighborMut, BuildQueryComputer, DataProvider, DefaultContext, Delete,
ExecutionContext, Guard, NeighborAccessor, NeighborAccessorMut, SetElement,
},
utils::{IntoUsize, ONE},
};
use diskann_quantization::scalar::train::ScalarQuantizationParameters;
use diskann_utils::{test_data_root, views::Matrix};
use diskann_vector::{
DistanceFunction, PureDistanceFunction,
distance::{Metric, SquaredL2},
};
use rand::{distr::Distribution, rngs::StdRng, seq::SliceRandom};
use rstest::rstest;
use super::*;
use crate::{
model::graph::provider::{
async_::{
TableDeleteProviderAsync,
common::{FullPrecision, Hybrid, NoDeletes, Quantized, TableBasedDeletes},
inmem::{self, DefaultQuant, SetStartPoints},
},
layers::BetaFilter,
},
storage::StorageReadProvider,
test_utils::{
assert_range_results_exactly_match, assert_top_k_exactly_match, groundtruth, is_match,
},
utils::{VectorDataIterator, create_rnd_from_seed_in_tests},
};
/// Identity hook for `simplified_builder`: leaves the config builder untouched.
fn no_modify(_: &mut diskann::graph::config::Builder) {}
/// Pack an iterator of equal-length vectors into a dense `Matrix<To>`,
/// converting each element via `Into`.
///
/// Panics if any input vector's length differs from `dim`.
pub(crate) fn squish<'a, To, T, Itr>(data: Itr, dim: usize) -> diskann_utils::views::Matrix<To>
where
    To: Clone + Default,
    T: Clone + Into<To> + 'a,
    Itr: ExactSizeIterator<Item = &'a Vec<T>> + 'a,
{
    let mut mat = diskann_utils::views::Matrix::new(To::default(), data.len(), dim);
    for (row, src) in std::iter::zip(mat.row_iter_mut(), data) {
        assert_eq!(
            src.len(),
            dim,
            "all elements in data must have the same length"
        );
        for (dst, value) in std::iter::zip(row.iter_mut(), src.iter()) {
            *dst = value.clone().into();
        }
    }
    mat
}
/// A paged-search test case: a query vector paired with its expected
/// neighbors.
pub(crate) struct PagedSearch<T> {
    // The query vector to search for.
    query: Vec<T>,
    // Expected neighbors for `query`.
    groundtruth: Vec<Neighbor<u32>>,
}
impl<T> PagedSearch<T> {
    /// Bundle a query with its groundtruth.
    pub(crate) fn new(query: Vec<T>, groundtruth: Vec<Neighbor<u32>>) -> Self {
        Self { query, groundtruth }
    }
}
/// Insert `source[i]` under external id `i` into `provider`.
///
/// Asserts that the provider assigns internal id `i` as well — i.e. this
/// helper is only valid for providers with an identity id mapping.
pub(crate) async fn populate_data<DP, Ctx, T>(provider: &DP, context: &Ctx, source: &[Vec<T>])
where
    Ctx: ExecutionContext,
    DP: DataProvider<Context = Ctx, InternalId = u32, ExternalId = u32>
        + for<'a> SetElement<&'a [T]>,
{
    for (i, v) in source.iter().enumerate() {
        let guard = provider.set_element(context, &(i as u32), v).await.unwrap();
        assert_eq!(
            guard.id(),
            i as u32,
            "populate_data only works properly for providers with the identity mapping"
        );
        // Release the guard, committing the insertion.
        guard.complete().await;
    }
}
/// Overwrite the adjacency list of node `i` with `source[i]` for every entry.
pub(crate) async fn populate_graph<NA>(accessor: &mut NA, source: &[AdjacencyList<u32>])
where
    NA: AsNeighborMut<Id = u32>,
{
    for (i, v) in source.iter().enumerate() {
        accessor.set_neighbors(i as u32, v).await.unwrap();
    }
}
/// Look up the synthetic `Grid` generator for `dim` dimensions, panicking
/// with a descriptive message when the dimension is unsupported.
pub(crate) fn grid_from_dim(dim: usize) -> Grid {
    match Grid::from_dim(dim) {
        Some(grid) => grid,
        None => panic!("{dim}-dimensions is not supported for grid-generation"),
    }
}
/// Copy each row of `matrix` into an owned `Vec`, preserving row order.
fn grid_to_vecs<T: Clone>(matrix: &Matrix<T>) -> Vec<Vec<T>> {
    let mut rows = Vec::with_capacity(matrix.nrows());
    for i in 0..matrix.nrows() {
        rows.push(matrix.row(i).to_vec());
    }
    rows
}
/// Test-data factory: produce the points of a `dim`-dimensional grid with
/// `size` steps per axis, as vectors of `Self`.
pub(crate) trait GenerateGrid: Sized {
    fn generate_grid(dim: usize, size: usize) -> Vec<Vec<Self>>;
}
impl GenerateGrid for f32 {
    fn generate_grid(dim: usize, size: usize) -> Vec<Vec<Self>> {
        grid_to_vecs(&grid_from_dim(dim).data(size))
    }
}
impl GenerateGrid for i8 {
    fn generate_grid(dim: usize, size: usize) -> Vec<Vec<Self>> {
        // The narrowing cast panics if a grid coordinate exceeds i8::MAX,
        // so callers must keep `size` small.
        grid_to_vecs(&grid_from_dim(dim).data_as(size, |v| i8::try_from(v).unwrap()))
    }
}
impl GenerateGrid for u8 {
    fn generate_grid(dim: usize, size: usize) -> Vec<Vec<Self>> {
        // Same caveat as the i8 impl: coordinates must fit in u8.
        grid_to_vecs(&grid_from_dim(dim).data_as(size, |v| u8::try_from(v).unwrap()))
    }
}
/// Common knobs shared by the search-test helpers below.
#[derive(Debug)]
struct SearchParameters<Ctx> {
    // Execution context forwarded to the data provider.
    context: Ctx,
    // Search list size (L).
    search_l: usize,
    // Number of results requested (K).
    search_k: usize,
    // How many leading results the checker closure inspects.
    to_check: usize,
}
/// Run a one-shot KNN search and feed each of the first `to_check`
/// `(id, distance)` results to `checker`, panicking with context on the
/// first failure.
async fn test_search<DP, S, Q, Checker>(
    index: &DiskANNIndex<DP>,
    parameters: &SearchParameters<DP::Context>,
    strategy: S,
    query: Q,
    mut checker: Checker,
) where
    DP: DataProvider<InternalId = u32>,
    S: DefaultSearchStrategy<DP, Q>,
    Q: Copy + std::fmt::Debug + Send + Sync,
    Checker: FnMut(usize, (u32, f32)) -> Result<(), Box<dyn std::fmt::Display>>,
{
    // Results are written into parallel id/distance buffers of size K.
    let mut ids = vec![0; parameters.search_k];
    let mut distances = vec![0.0; parameters.search_k];
    let mut result_output_buffer =
        search_output_buffer::IdDistance::new(&mut ids, &mut distances);
    let graph_search =
        graph::search::Knn::new_default(parameters.search_k, parameters.search_l).unwrap();
    index
        .search(
            graph_search,
            &strategy,
            &parameters.context,
            query,
            &mut result_output_buffer,
        )
        .await
        .unwrap();
    for i in 0..parameters.to_check {
        println!("{ids:?}");
        if let Err(message) = checker(i, (ids[i], distances[i])) {
            panic!(
                "Check failed for result {} with error: {}. Query = {:?}. Result: ({}, {})",
                i, message, query, ids[i], distances[i]
            );
        }
    }
}
/// Like `test_search`, but wraps the KNN parameters in a `MultihopSearch`
/// driven by `filter`, so only filter-matching candidates are returned.
async fn test_multihop_search<DP, S, Q, Checker>(
    index: &DiskANNIndex<DP>,
    parameters: &SearchParameters<DP::Context>,
    strategy: &S,
    query: Q,
    mut checker: Checker,
    filter: &dyn QueryLabelProvider<DP::InternalId>,
) where
    DP: DataProvider<InternalId = u32>,
    S: DefaultSearchStrategy<DP, Q>,
    Q: Copy + std::fmt::Debug + Send + Sync,
    Checker: FnMut(usize, (u32, f32)) -> Result<(), Box<dyn std::fmt::Display>>,
{
    let mut ids = vec![0; parameters.search_k];
    let mut distances = vec![0.0; parameters.search_k];
    let mut result_output_buffer =
        search_output_buffer::IdDistance::new(&mut ids, &mut distances);
    let search_params = Knn::new_default(parameters.search_k, parameters.search_l).unwrap();
    let multihop = graph::search::MultihopSearch::new(search_params, filter);
    index
        .search(
            multihop,
            strategy,
            &parameters.context,
            query,
            &mut result_output_buffer,
        )
        .await
        .unwrap();
    for i in 0..parameters.to_check {
        println!("{ids:?}");
        if let Err(message) = checker(i, (ids[i], distances[i])) {
            panic!(
                "Check failed for result {} with error: {}. Query = {:?}. Result: ({}, {})",
                i, message, query, ids[i], distances[i]
            );
        }
    }
}
/// Drive the paged-search API page by page, matching every returned
/// neighbor against `groundtruth` (removing matches as they are found).
///
/// Returns early once `max_candidates` results have been matched; panics if
/// any page contains a result not present in the remaining groundtruth.
async fn test_paged_search<DP, S, Q>(
    index: &DiskANNIndex<DP>,
    strategy: S,
    parameters: &SearchParameters<DP::Context>,
    query: Q,
    groundtruth: &mut Vec<Neighbor<u32>>,
    max_candidates: usize,
) where
    DP: DataProvider<InternalId = u32>,
    S: SearchStrategy<DP, Q> + 'static,
    Q: Copy + std::fmt::Debug + Send + Sync,
{
    assert!(max_candidates <= groundtruth.len());
    let mut state = index
        .start_paged_search(strategy, &parameters.context, query, parameters.search_l)
        .await
        .unwrap();
    // Each page yields at most `search_k` neighbors.
    let mut buffer = vec![Neighbor::<u32>::default(); parameters.search_k];
    let mut iter = 0;
    let mut seen = 0;
    while !groundtruth.is_empty() {
        let count = index
            .next_search_results::<S, Q>(
                &parameters.context,
                &mut state,
                parameters.search_k,
                &mut buffer,
            )
            .await
            .unwrap();
        for (i, b) in buffer.iter().enumerate().take(count) {
            // Matching uses a small relative tolerance on distances.
            let m = is_match(groundtruth, *b, 0.01);
            match m {
                None => {
                    let last = groundtruth.len();
                    let start = last - last.min(10);
                    panic!(
                        "Remaining Groundtruth: {:?}\n\
                        Could not match: {:?} on iteration {}, position {}.\n\
                        Remaining entries: {:?}",
                        &groundtruth[start..],
                        b,
                        iter,
                        i,
                        &buffer[i..],
                    );
                }
                Some(j) => groundtruth.remove(j),
            };
            seen += 1;
            if seen == max_candidates {
                return;
            }
        }
        iter += 1;
    }
}
/// Shared grid-index validation: checks an origin query and a far-corner
/// query against both the full-precision and quantized strategies, then
/// replays every entry of `paged_queries` through the paged-search API with
/// both strategies.
pub(crate) async fn check_grid_search<DP, T, FS, QS>(
    index: &DiskANNIndex<DP>,
    vectors: &[Vec<T>],
    paged_queries: &[PagedSearch<T>],
    full_strategy: FS,
    quant_strategy: QS,
) where
    DP: DataProvider<InternalId = u32, Context: Default>,
    FS: for<'a> DefaultSearchStrategy<DP, &'a [T]> + Clone + 'static,
    QS: for<'a> DefaultSearchStrategy<DP, &'a [T]> + Clone + 'static,
    T: Default + Clone + Send + Sync + std::fmt::Debug,
{
    let dim = vectors[0].len();
    let num_points = vectors.len();
    // Query at the origin: id 0 must be the exact match (distance 0) and
    // every other checked neighbor must sit at squared distance 1.
    let query = vec![T::default(); dim];
    let parameters = SearchParameters {
        context: Default::default(),
        search_l: 10,
        search_k: dim + 1,
        to_check: dim + 1,
    };
    let checker = |position, (id, distance)| -> Result<(), Box<dyn std::fmt::Display>> {
        if position == 0 {
            if id != 0 {
                return Err(Box::new("expected the nearest neighbor to be 0"));
            }
            if distance != 0.0 {
                return Err(Box::new("expected the nearest distance to be 0"));
            }
        } else if distance != 1.0 {
            return Err(Box::new(
                "expected corner query close neighbor to have distance 1.0",
            ));
        }
        Ok(())
    };
    test_search(
        index,
        &parameters,
        full_strategy.clone(),
        query.as_slice(),
        checker,
    )
    .await;
    test_search(
        index,
        &parameters,
        quant_strategy.clone(),
        query.as_slice(),
        checker,
    )
    .await;
    // Second query: the last vector in `vectors`; only the top result is
    // checked.
    let query = vectors.last().unwrap();
    let parameters = SearchParameters {
        to_check: 1,
        ..parameters
    };
    assert_eq!(vectors.len(), num_points);
    let checker = |position, (id, distance)| -> Result<(), Box<dyn std::fmt::Display>> {
        assert_eq!(position, 0);
        if id as usize != num_points - 2 {
            return Err(Box::new(format!(
                "expected {} as the nearest id",
                num_points - 2
            )));
        }
        if distance != dim as f32 {
            return Err(Box::new(format!("nearest distance should be {}", dim)));
        }
        Ok(())
    };
    test_search(
        index,
        &parameters,
        full_strategy.clone(),
        query.as_slice(),
        checker,
    )
    .await;
    test_search(
        index,
        &parameters,
        quant_strategy.clone(),
        query.as_slice(),
        checker,
    )
    .await;
    // Finally, run every provided paged query with both strategies.
    let parameters = SearchParameters {
        context: Default::default(),
        search_l: 10,
        search_k: dim + 1,
        to_check: dim + 1,
    };
    for paged in paged_queries {
        let mut gt = paged.groundtruth.clone();
        let max_candidates = gt.len();
        test_paged_search(
            index,
            full_strategy.clone(),
            &parameters,
            &paged.query,
            &mut gt,
            max_candidates,
        )
        .await;
        let mut gt = paged.groundtruth.clone();
        test_paged_search(
            index,
            quant_strategy.clone(),
            &parameters,
            &paged.query,
            &mut gt,
            max_candidates,
        )
        .await
    }
}
/// Build a grid index from precomputed adjacency lists and vectors, then
/// validate KNN and paged search via `check_grid_search`.
#[rstest]
#[case(1, 100)]
#[case(3, 7)]
#[case(4, 5)]
#[tokio::test]
async fn grid_search(#[case] dim: usize, #[case] grid_size: usize) {
    let l = 10;
    let max_degree = 2 * dim;
    let num_points = (grid_size).pow(dim as u32);
    let (config, parameters) =
        simplified_builder(l, max_degree, Metric::L2, dim, num_points, no_modify).unwrap();
    let mut adjacency_lists = grid_from_dim(dim).neighbors(grid_size);
    let mut vectors = f32::generate_grid(dim, grid_size);
    assert_eq!(adjacency_lists.len(), num_points);
    assert_eq!(vectors.len(), num_points);
    // Append the start point (the far corner), linked to the last grid node.
    adjacency_lists.push((num_points as u32 - 1).into());
    vectors.push(vec![grid_size as f32; dim]);
    let table = train_pq(
        squish(vectors.iter(), dim).as_view(),
        2.min(dim), &mut create_rnd_from_seed_in_tests(0x04a8832604476965),
        1usize,
    )
    .unwrap();
    let index = new_quant_index::<f32, _, _>(config, parameters, table, NoDeletes).unwrap();
    let neighbor_accessor = &mut index.provider().neighbors();
    populate_data(&index.data_provider, &DefaultContext, &vectors).await;
    populate_graph(neighbor_accessor, &adjacency_lists).await;
    // Groundtruth is computed over the grid points only (start point excluded).
    let corpus: diskann_utils::views::Matrix<f32> =
        squish(vectors.iter().take(num_points), dim);
    let mut paged_tests = Vec::new();
    let query = vec![0.0; dim];
    let gt = groundtruth(corpus.as_view(), &query, |a, b| SquaredL2::evaluate(a, b));
    paged_tests.push(PagedSearch::new(query, gt));
    let query = vectors.last().unwrap();
    let gt = groundtruth(corpus.as_view(), query, |a, b| SquaredL2::evaluate(a, b));
    paged_tests.push(PagedSearch::new(query.clone(), gt));
    check_grid_search(
        &index,
        &vectors,
        &paged_tests,
        FullPrecision,
        Hybrid::new(None),
    )
    .await;
}
// Shorthand aliases so the `#[values(...)]` lists below stay compact.
const IBC_NONE: IntraBatchCandidates = IntraBatchCandidates::None;
const IBC_ALL: IntraBatchCandidates = IntraBatchCandidates::All;
/// Build grid indices through the insertion APIs (single insert and
/// multi-insert, full-precision and hybrid strategies) and validate each
/// resulting index with `check_grid_search`.
#[rstest]
#[tokio::test]
async fn grid_search_with_build<T>(
    #[values(PhantomData::<f32>, PhantomData::<i8>, PhantomData::<u8>)] _v: PhantomData<T>,
    #[values((1, 100), (3, 7), (4, 5))] dim_and_size: (usize, usize),
    #[values(IBC_NONE, IBC_ALL)] intra_batch_candidates: IntraBatchCandidates,
) where
    T: VectorRepr + GenerateGrid + Into<f32>,
{
    let dim = dim_and_size.0;
    let grid_size = dim_and_size.1;
    let l = 10;
    let max_degree = 2 * dim;
    let minibatch_par = 10;
    let max_fp_vecs_per_prune = Some(2);
    let hybrid = Hybrid::new(max_fp_vecs_per_prune);
    let num_points = (grid_size).pow(dim as u32);
    let (config, parameters) =
        simplified_builder(l, max_degree, Metric::L2, dim, num_points, |p| {
            p.max_minibatch_par(minibatch_par)
                .intra_batch_candidates(intra_batch_candidates);
        })
        .unwrap();
    let mut vectors: Vec<Vec<T>> = T::generate_grid(dim, grid_size);
    assert_eq!(vectors.len(), num_points);
    // Append the far-corner start point as an extra row.
    vectors.push(vec![
        <T as num_traits::FromPrimitive>::from_usize(grid_size)
            .unwrap();
        dim
    ]);
    let matrix: Matrix<T> = squish::<T, T, _>(vectors.iter(), dim);
    let table = train_pq(
        matrix.map(|i| (*i).into()).as_view(),
        2.min(dim), &mut create_rnd_from_seed_in_tests(0x04a8832604476965),
        1usize,
    )
    .unwrap();
    // Each scenario below starts from a fresh, empty index whose start
    // point is the appended corner row.
    let init_index = || {
        let index = new_quant_index::<T, _, _>(
            config.clone(),
            parameters.clone(),
            table.clone(),
            NoDeletes,
        )
        .unwrap();
        index
            .provider()
            .set_start_points(std::iter::once(matrix.row(num_points)))
            .unwrap();
        index
    };
    // Scenario 1: one-by-one inserts with the full-precision strategy.
    {
        let index = init_index();
        let ctx = Default::default();
        for (i, v) in matrix.row_iter().take(num_points).enumerate() {
            index
                .insert(FullPrecision, &ctx, &(i as u32), v)
                .await
                .unwrap();
        }
        check_grid_search(&index, &vectors, &[], FullPrecision, hybrid).await;
    }
    // Scenario 2: one-by-one inserts with the hybrid strategy.
    {
        let index = init_index();
        let ctx = Default::default();
        for (i, v) in matrix.row_iter().take(num_points).enumerate() {
            index.insert(hybrid, &ctx, &(i as u32), v).await.unwrap();
        }
        check_grid_search(&index, &vectors, &[], FullPrecision, hybrid).await;
    }
    // Scenario 3: chunked multi-insert (full-precision).
    {
        let index = init_index();
        let ctx = Default::default();
        let chunk_size = 2 * minibatch_par;
        for (batch, batch_data) in matrix
            .subview(0..num_points)
            .unwrap()
            .window_iter(chunk_size)
            .enumerate()
        {
            let batch_data = Arc::new(batch_data.to_owned());
            let start = batch * chunk_size;
            let batch_ids: Arc<[u32]> = (start..start + batch_data.nrows())
                .map(|i| i as u32)
                .collect();
            index
                .multi_insert::<_, Matrix<T>>(FullPrecision, &ctx, batch_data, batch_ids)
                .await
                .unwrap();
        }
        check_grid_search(&index, &vectors, &[], FullPrecision, hybrid).await;
    }
    // Scenario 4: a single multi-insert of the whole dataset (hybrid).
    {
        let index = init_index();
        let ctx = Default::default();
        let batch = Arc::new(matrix.subview(0..num_points).unwrap().to_owned());
        let batch_ids: Arc<[u32]> = (0..num_points as u32).collect();
        index
            .multi_insert::<_, Matrix<T>>(hybrid, &ctx, batch, batch_ids)
            .await
            .unwrap();
        check_grid_search(&index, &vectors, &[], FullPrecision, hybrid).await;
    }
}
/// Test-data factory: random vectors of approximate norm `radius`, with a
/// deterministic start point appended as the final element.
trait GenerateSphericalData: Sized {
    fn generate_spherical(
        num: usize,
        dim: usize,
        radius: f32,
        rng: &mut StdRng,
    ) -> Vec<Vec<Self>>;
}
/// Implements `GenerateSphericalData` for a numeric element type: `num`
/// random vectors of approximate norm `radius`, plus one extra start point
/// `(radius, 0, ..., 0)` appended at index `num`.
macro_rules! impl_generate_spherical_data {
    ($T:ty) => {
        impl GenerateSphericalData for $T {
            fn generate_spherical(
                num: usize,
                dim: usize,
                radius: f32,
                rng: &mut StdRng,
            ) -> Vec<Vec<Self>> {
                use diskann_utils::sampling::random::WithApproximateNorm;
                let mut vectors = Vec::with_capacity(num);
                for _ in 0..num {
                    let vector = <$T>::with_approximate_norm(dim, radius, rng);
                    vectors.push(vector);
                }
                assert_eq!(vectors.len(), num);
                // The start point lives at index `num`, after the data.
                let mut start_point = vec![<$T>::default(); dim];
                start_point[0] = radius as $T;
                vectors.push(start_point);
                vectors
            }
        }
    };
}
impl_generate_spherical_data!(f32);
impl_generate_spherical_data!(i8);
impl_generate_spherical_data!(u8);
/// Parameters for `test_spherical_data_impl`.
struct SphericalTest {
    // Number of data points.
    num: usize,
    // Vector dimensionality.
    dim: usize,
    // Approximate norm of the generated vectors.
    radius: f32,
    // Number of random queries issued against the built index.
    num_queries: usize,
}
/// Build an index over random spherical data with `strategy`, then issue
/// `num_queries` random queries and compare both one-shot and paged search
/// results against brute-force groundtruth.
async fn test_spherical_data_impl<T, S>(
    strategy: S,
    metric: Metric,
    params: SphericalTest,
    rng: &mut StdRng,
) where
    T: VectorRepr + GenerateSphericalData + Into<f32>,
    S: for<'a> InsertStrategy<FullPrecisionProvider<T, DefaultQuant>, &'a [T]>
        + for<'a> DefaultSearchStrategy<FullPrecisionProvider<T, DefaultQuant>, &'a [T]>
        + Clone
        + 'static,
    rand::distr::StandardUniform: Distribution<T>,
{
    let SphericalTest {
        num,
        dim,
        radius,
        num_queries,
    } = params;
    let ctx = &Default::default();
    let l_search = 10;
    let (config, params) =
        simplified_builder(l_search, 3 * dim, metric, dim, num, no_modify).unwrap();
    let data = T::generate_spherical(num, dim, radius, rng);
    let table = {
        let train_data: diskann_utils::views::Matrix<f32> = squish(data.iter(), dim);
        train_pq(train_data.as_view(), 2.min(dim), rng, 1usize).unwrap()
    };
    let index = new_quant_index::<T, _, _>(config, params, table, NoDeletes).unwrap();
    // `generate_spherical` places the start point at index `num`.
    index
        .provider()
        .set_start_points(std::iter::once(data[num].as_slice()))
        .unwrap();
    for (i, v) in data.iter().take(num).enumerate() {
        index
            .insert(strategy.clone(), ctx, &(i as u32), v.as_slice())
            .await
            .unwrap();
    }
    let distribution = rand::distr::StandardUniform {};
    // Groundtruth corpus excludes the start point.
    let data = squish::<T, T, _>(data.iter().take(num), dim);
    let distance = T::distance(metric, None);
    let parameters = SearchParameters {
        context: Default::default(),
        search_l: 20,
        search_k: 10,
        to_check: 10,
    };
    for _ in 0..num_queries {
        let query: Vec<T> = (0..dim).map(|_| distribution.sample(rng)).collect();
        let mut gt = groundtruth(data.as_view(), &query, |a, b| {
            distance.evaluate_similarity(a, b)
        });
        // Accept either the expected id or a different id at an identical
        // distance (ties are legitimate).
        let checker = |position, (id, distance)| -> Result<(), Box<dyn std::fmt::Display>> {
            let expected: Neighbor<u32> = gt[gt.len() - 1 - position];
            if id != expected.id {
                if distance == expected.distance {
                    Ok(())
                } else {
                    Err(Box::new(format!(
                        "expected neighbor {:?}, but found {}",
                        expected, id
                    )))
                }
            } else if distance != expected.distance {
                Err(Box::new(format!(
                    "expected neighbor {:?}, but found {}",
                    expected, distance
                )))
            } else {
                Ok(())
            }
        };
        test_search(
            &index,
            &parameters,
            strategy.clone(),
            query.as_slice(),
            checker,
        )
        .await;
        test_paged_search(
            &index,
            strategy.clone(),
            &parameters,
            query.as_slice(),
            &mut gt,
            3 * parameters.search_k,
        )
        .await;
    }
}
// PhantomData markers used as rstest case arguments to select the element type.
const PF32: PhantomData<f32> = PhantomData;
const PU8: PhantomData<u8> = PhantomData;
const PI8: PhantomData<i8> = PhantomData;
/// Parameterized spherical-data search test; the RNG seed is derived from
/// the case parameters so each case is deterministic but distinct.
#[rstest]
#[case(PF32, FullPrecision, Metric::L2, 100, 4, 1.5)]
#[case(PF32, Hybrid::new(Some(6)), Metric::L2, 100, 4, 1.5)]
#[case(PF32, FullPrecision, Metric::InnerProduct, 93, 5, 543.5)]
#[case(PF32, Hybrid::new(Some(8)), Metric::InnerProduct, 93, 5, 543.3)]
#[case(PF32, FullPrecision, Metric::Cosine, 77, 7, 2.5)]
#[case(PF32, Hybrid::new(Some(32)), Metric::Cosine, 77, 7, 2.5)]
#[case(PU8, FullPrecision, Metric::L2, 100, 7, 43.0)]
#[case(PU8, FullPrecision, Metric::Cosine, 93, 5, 46.0)]
#[case(PU8, FullPrecision, Metric::InnerProduct, 77, 6, 47.0)]
#[case(PI8, FullPrecision, Metric::L2, 100, 7, 43.0)]
#[case(PI8, FullPrecision, Metric::Cosine, 93, 5, 46.0)]
#[case(PI8, FullPrecision, Metric::InnerProduct, 77, 6, 47.0)]
#[tokio::test]
async fn test_sphere_search<T, S>(
    #[case] ty: PhantomData<T>,
    #[case] strategy: S,
    #[case] metric: Metric,
    #[case] num: usize,
    #[case] dim: usize,
    #[case] radius: f32,
) where
    T: VectorRepr + GenerateSphericalData + Into<f32>,
    S: for<'a> InsertStrategy<FullPrecisionProvider<T, DefaultQuant>, &'a [T]>
        + for<'a> DefaultSearchStrategy<FullPrecisionProvider<T, DefaultQuant>, &'a [T]>
        + Clone
        + 'static,
    rand::distr::StandardUniform: Distribution<T>,
{
    use std::hash::{DefaultHasher, Hash, Hasher};
    // Seed the RNG from (type, num, dim) so distinct cases get distinct
    // but reproducible random streams.
    let rng = &mut {
        let mut s = DefaultHasher::new();
        ty.hash(&mut s);
        num.hash(&mut s);
        dim.hash(&mut s);
        create_rnd_from_seed_in_tests(s.finish())
    };
    let num_queries = 4;
    test_spherical_data_impl::<T, _>(
        strategy,
        metric,
        SphericalTest {
            num,
            dim,
            radius,
            num_queries,
        },
        rng,
    )
    .await;
}
/// Query filter that accepts exactly the even internal ids.
#[derive(Debug)]
struct EvenFilter;
impl QueryLabelProvider<u32> for EvenFilter {
    fn is_match(&self, id: u32) -> bool {
        // Equivalent to `id.is_multiple_of(2)`.
        id % 2 == 0
    }
}
async fn test_beta_filtering(
filter: Arc<dyn QueryLabelProvider<u32>>,
dim: usize,
grid_size: usize,
) {
let l = 10;
let max_degree = 2 * dim;
let num_points = (grid_size).pow(dim as u32);
let (config, parameters) =
simplified_builder(l, max_degree, Metric::L2, dim, num_points, no_modify).unwrap();
let mut adjacency_lists = Grid::Three.neighbors(grid_size);
let mut vectors = f32::generate_grid(dim, grid_size);
assert_eq!(adjacency_lists.len(), num_points);
assert_eq!(vectors.len(), num_points);
adjacency_lists.push((num_points as u32 - 1).into());
vectors.push(vec![grid_size as f32; dim]);
let table = train_pq(
squish(vectors.iter(), dim).as_view(),
2.min(dim), &mut create_rnd_from_seed_in_tests(0x04a8832604476965),
1usize,
)
.unwrap();
let index = new_quant_index::<f32, _, _>(config, parameters, table, NoDeletes).unwrap();
let neighbor_accessor = &mut index.provider().neighbors();
populate_data(&index.data_provider, &DefaultContext, &vectors).await;
populate_graph(neighbor_accessor, &adjacency_lists).await;
let beta = 0.5;
let corpus: diskann_utils::views::Matrix<f32> =
squish(vectors.iter().take(num_points), dim);
let query = vec![grid_size as f32; dim];
let parameters = SearchParameters {
context: Default::default(),
search_l: 40,
search_k: 20,
to_check: 20,
};
let gt = {
let mut gt = groundtruth(corpus.as_view(), &query, |a, b| SquaredL2::evaluate(a, b));
for n in gt.iter_mut() {
if filter.is_match(n.id) {
n.distance *= beta;
}
}
gt.sort_unstable_by(|a, b| a.cmp(b).reverse());
gt
};
let mut gt_clone = gt.clone();
let strategy = BetaFilter::new(FullPrecision, filter.clone(), beta);
test_search(
&index,
¶meters,
strategy.clone(),
query.as_slice(),
|_, (id, distance)| -> Result<(), Box<dyn std::fmt::Display>> {
if let Some(position) = is_match(>_clone, Neighbor::new(id, distance), 0.0) {
gt_clone.remove(position);
Ok(())
} else {
if id.into_usize() == num_points + 1 {
return Err(Box::new("The start point should not be returned"));
}
Err(Box::new("mismatch"))
}
},
)
.await;
let paged_parameters = SearchParameters {
search_k: 10,
search_l: 40,
..parameters
};
test_paged_search(
&index,
strategy,
&paged_parameters,
query.as_slice(),
&mut gt.clone(),
60,
)
.await;
}
/// Beta-filtered search with the even-id filter on a 3-d, 7-per-axis grid.
#[tokio::test]
async fn test_even_filtering_beta() {
    let filter = Arc::new(EvenFilter);
    test_beta_filtering(filter, 3, 7).await;
}
async fn test_multihop_filtering(
filter: &dyn QueryLabelProvider<u32>,
dim: usize,
grid_size: usize,
) {
let l = 10;
let max_degree = 2 * dim;
let num_points = (grid_size).pow(dim as u32);
let (config, parameters) =
simplified_builder(l, max_degree, Metric::L2, dim, num_points, no_modify).unwrap();
let mut adjacency_lists = Grid::Three.neighbors(grid_size);
let mut vectors = f32::generate_grid(dim, grid_size);
assert_eq!(adjacency_lists.len(), num_points);
assert_eq!(vectors.len(), num_points);
adjacency_lists.push((num_points as u32 - 1).into());
vectors.push(vec![grid_size as f32; dim]);
let table = train_pq(
squish(vectors.iter(), dim).as_view(),
2.min(dim), &mut create_rnd_from_seed_in_tests(0x04a8832604476965),
1usize,
)
.unwrap();
let index = new_quant_index::<f32, _, _>(config, parameters, table, NoDeletes).unwrap();
let neighbor_accessor = &mut index.provider().neighbors();
populate_data(&index.data_provider, &DefaultContext, &vectors).await;
populate_graph(neighbor_accessor, &adjacency_lists).await;
let corpus: diskann_utils::views::Matrix<f32> =
squish(vectors.iter().take(num_points), dim);
let query = vec![grid_size as f32; dim];
let parameters = SearchParameters {
context: DefaultContext,
search_l: 40,
search_k: 20,
to_check: 20,
};
let gt = {
let mut gt = groundtruth(corpus.as_view(), &query, |a, b| SquaredL2::evaluate(a, b));
gt.retain(|n| filter.is_match(n.id));
gt.sort_unstable_by(|a, b| a.cmp(b).reverse());
gt
};
let mut gt_clone = gt.clone();
let strategy = FullPrecision;
test_multihop_search(
&index,
¶meters,
&strategy.clone(),
query.as_slice(),
|_, (id, distance)| -> Result<(), Box<dyn std::fmt::Display>> {
if let Some(position) = is_match(>_clone, Neighbor::new(id, distance), 0.0) {
gt_clone.remove(position);
Ok(())
} else {
if id.into_usize() == num_points + 1 {
return Err(Box::new("The start point should not be returned"));
}
Err(Box::new("mismatch"))
}
},
filter,
)
.await;
}
/// Multihop-filtered search with the even-id filter on a 3-d, 7-per-axis grid.
#[tokio::test]
async fn test_even_filtering_multihop() {
    test_multihop_filtering(&EvenFilter, 3, 7).await;
}
/// Counters recorded by `CallbackFilter::on_visit`.
#[derive(Debug, Clone, Default)]
struct CallbackMetrics {
    // Total number of `on_visit` invocations.
    total_visits: usize,
    // Visits that were rejected (candidate matched the blocked id).
    rejected_count: usize,
    // Visits whose distance was rescaled (candidate matched the adjusted id).
    adjusted_count: usize,
    // Every visited id, in visit order.
    visited_ids: Vec<u32>,
}
/// `QueryLabelProvider` test double: rejects one id, rescales the distance
/// of another, and records metrics about every visit.
#[derive(Debug)]
struct CallbackFilter {
    // Id whose visits are rejected outright.
    blocked: u32,
    // Id whose distance is multiplied by `adjustment_factor`.
    adjusted: u32,
    adjustment_factor: f32,
    // Mutex gives interior mutability: `on_visit` only receives `&self`.
    metrics: Mutex<CallbackMetrics>,
}
impl CallbackFilter {
    /// Create a filter that rejects `blocked` and multiplies the distance of
    /// `adjusted` by `adjustment_factor`.
    fn new(blocked: u32, adjusted: u32, adjustment_factor: f32) -> Self {
        let metrics = Mutex::new(CallbackMetrics::default());
        Self {
            blocked,
            adjusted,
            adjustment_factor,
            metrics,
        }
    }
    /// Ids seen by `on_visit`, in visit order.
    fn hits(&self) -> Vec<u32> {
        let guard = self
            .metrics
            .lock()
            .expect("callback metrics mutex should not be poisoned");
        guard.visited_ids.clone()
    }
    /// Snapshot of all recorded metrics.
    fn metrics(&self) -> CallbackMetrics {
        let guard = self
            .metrics
            .lock()
            .expect("callback metrics mutex should not be poisoned");
        guard.clone()
    }
}
impl QueryLabelProvider<u32> for CallbackFilter {
    // Label matching accepts everything; filtering happens in `on_visit`.
    fn is_match(&self, _: u32) -> bool {
        true
    }
    /// Record the visit, then reject the blocked id, rescale the adjusted
    /// id's distance, and accept everything else unchanged.
    fn on_visit(&self, neighbor: Neighbor<u32>) -> QueryVisitDecision<u32> {
        let mut metrics = self
            .metrics
            .lock()
            .expect("callback metrics mutex should not be poisoned");
        metrics.total_visits += 1;
        metrics.visited_ids.push(neighbor.id);
        if neighbor.id == self.blocked {
            metrics.rejected_count += 1;
            return QueryVisitDecision::Reject;
        }
        if neighbor.id == self.adjusted {
            metrics.adjusted_count += 1;
            let adjusted =
                Neighbor::new(neighbor.id, neighbor.distance * self.adjustment_factor);
            return QueryVisitDecision::Accept(adjusted);
        }
        QueryVisitDecision::Accept(neighbor)
    }
}
/// End-to-end check that the multihop `on_visit` callback is invoked, can
/// reject a candidate (which must then be absent from the results), and can
/// rescale a candidate's distance before ranking.
#[tokio::test]
async fn test_multihop_callback_enforces_filtering() {
    // `Grid::Three` below matches this hard-coded dim of 3.
    let dim = 3;
    let grid_size: usize = 5;
    let l = 10;
    let max_degree = 2 * dim;
    let num_points = (grid_size).pow(dim as u32);
    let (config, parameters) =
        simplified_builder(l, max_degree, Metric::L2, dim, num_points, no_modify).unwrap();
    let mut adjacency_lists = Grid::Three.neighbors(grid_size);
    let mut vectors = f32::generate_grid(dim, grid_size);
    // Append the start point (the far corner), linked to the last grid node.
    adjacency_lists.push((num_points as u32 - 1).into());
    vectors.push(vec![grid_size as f32; dim]);
    let table = train_pq(
        squish(vectors.iter(), dim).as_view(),
        2.min(dim),
        &mut create_rnd_from_seed_in_tests(0xdd81b895605c73d4),
        1usize,
    )
    .unwrap();
    let index = new_quant_index::<f32, _, _>(config, parameters, table, NoDeletes).unwrap();
    let neighbor_accessor = &mut index.provider().neighbors();
    populate_data(&index.data_provider, &DefaultContext, &vectors).await;
    populate_graph(neighbor_accessor, &adjacency_lists).await;
    let corpus: diskann_utils::views::Matrix<f32> =
        squish(vectors.iter().take(num_points), dim);
    let query = vec![grid_size as f32; dim];
    let parameters = SearchParameters {
        context: DefaultContext,
        search_l: 40,
        search_k: 20,
        to_check: 10,
    };
    let mut ids = vec![0; parameters.search_k];
    let mut distances = vec![0.0; parameters.search_k];
    let mut result_output_buffer =
        search_output_buffer::IdDistance::new(&mut ids, &mut distances);
    // Two near-the-query candidates: one to reject, one to distance-adjust.
    let blocked = (num_points - 2) as u32;
    let adjusted = (num_points - 1) as u32;
    let mut baseline_gt =
        groundtruth(corpus.as_view(), &query, |a, b| SquaredL2::evaluate(a, b));
    baseline_gt.sort_unstable_by(|a, b| a.cmp(b).reverse());
    assert!(
        baseline_gt.iter().any(|n| n.id == blocked),
        "blocked candidate must exist in groundtruth"
    );
    let baseline_adjusted_distance = baseline_gt
        .iter()
        .find(|n| n.id == adjusted)
        .expect("adjusted node should exist in groundtruth")
        .distance;
    let filter = CallbackFilter::new(blocked, adjusted, 0.5);
    let search_params = Knn::new_default(parameters.search_k, parameters.search_l).unwrap();
    let multihop = graph::search::MultihopSearch::new(search_params, &filter);
    let stats = index
        .search(
            multihop,
            &FullPrecision,
            &parameters.context,
            query.as_slice(),
            &mut result_output_buffer,
        )
        .await
        .unwrap();
    let callback_metrics = filter.metrics();
    assert!(
        stats.result_count >= parameters.to_check as u32,
        "expected at least {} results, got {}",
        parameters.to_check,
        stats.result_count
    );
    assert!(
        callback_metrics.total_visits > 0,
        "callback should have been invoked at least once"
    );
    assert!(
        filter.hits().contains(&blocked),
        "callback must evaluate the blocked candidate (visited {} candidates)",
        callback_metrics.total_visits
    );
    assert_eq!(
        callback_metrics.rejected_count, 1,
        "exactly one candidate (blocked={}) should be rejected",
        blocked
    );
    // Only the results actually produced (and requested for checking) are
    // inspected below.
    let produced = stats.result_count as usize;
    let inspected = produced.min(parameters.to_check);
    assert!(
        !ids.iter().take(inspected).any(|&id| id == blocked),
        "blocked candidate {} should not appear in final results (found in: {:?})",
        blocked,
        &ids[..inspected]
    );
    assert!(
        callback_metrics.adjusted_count >= 1,
        "adjusted candidate {} should have been visited",
        adjusted
    );
    let adjusted_idx = ids
        .iter()
        .take(inspected)
        .position(|&id| id == adjusted)
        .expect("adjusted candidate should be present in results");
    // The reported distance must reflect the 0.5 rescale from the callback.
    let expected_distance = baseline_adjusted_distance * 0.5;
    assert!(
        (distances[adjusted_idx] - expected_distance).abs() < 1e-5,
        "callback should adjust distances before ranking: \
        expected {:.6}, got {:.6} (baseline: {:.6}, factor: 0.5)",
        expected_distance,
        distances[adjusted_idx],
        baseline_adjusted_distance
    );
    println!(
        "test_multihop_callback_enforces_filtering metrics:\n\
        - total callback visits: {}\n\
        - rejected count: {}\n\
        - adjusted count: {}\n\
        - search hops: {}\n\
        - search comparisons: {}\n\
        - result count: {}",
        callback_metrics.total_visits,
        callback_metrics.rejected_count,
        callback_metrics.adjusted_count,
        stats.hops,
        stats.cmps,
        stats.result_count
    );
}
#[tokio::test]
async fn test_inplace_delete_2d() {
    // Run the 2-D in-place delete scenario once per search strategy.
    let full_precision = FullPrecision;
    test_inplace_delete_2d_impl(full_precision).await;
    let hybrid = Hybrid::new(None);
    test_inplace_delete_2d_impl(hybrid).await;
}
// Builds a tiny 2-D quantized index (4 unit-square corner points plus one
// frozen start point at the center), deletes node 3 in place using the given
// strategy, and then verifies both the delete status and the exact repaired
// adjacency lists of every node.
async fn test_inplace_delete_2d_impl<S>(strategy: S)
where
    S: InplaceDeleteStrategy<TestProvider>
        + for<'a> SearchStrategy<TestProvider, S::DeleteElement<'a>>
        + Sync
        + std::clone::Clone,
{
    let dim = 2;
    // l_search = 10, pruned_degree = 4, max_points = 4 (plus one frozen point).
    let (config, parameters) = simplified_builder(
        10, 4, Metric::L2, dim, 4, no_modify,
    )
    .unwrap();
    // Degenerate single-chunk PQ table: pivots and centroid all at the origin.
    // The quantized distances are meaningless here; this test only checks
    // graph surgery, not search quality.
    let pqtable = model::pq::FixedChunkPQTable::new(
        dim,
        Box::new([0.0, 0.0]),
        Box::new([0.0, 0.0]),
        Box::new([0, 2]),
    )
    .unwrap();
    let index =
        new_quant_index::<f32, _, _>(config, parameters, pqtable, TableBasedDeletes).unwrap();
    let neighbor_accessor = &mut index.provider().neighbors();
    // The start point sits at the center of the unit square; it becomes
    // internal id 4 (after the 4 data points, ids 0..=3).
    let start_point: &[f32] = &[0.5, 0.5];
    index
        .provider()
        .set_start_points(std::iter::once(start_point))
        .unwrap();
    let vectors = [
        vec![0.0, 0.0],
        vec![0.0, 1.0],
        vec![1.0, 0.0],
        vec![1.0, 1.0],
    ];
    // Hand-built graph: each corner links to the start point (4) and one
    // other corner; the start point links to all corners.
    let adjacency_lists = [
        AdjacencyList::from_iter_untrusted([4, 1]),
        AdjacencyList::from_iter_untrusted([4, 0]),
        AdjacencyList::from_iter_untrusted([4, 3]),
        AdjacencyList::from_iter_untrusted([4, 2]),
        AdjacencyList::from_iter_untrusted([0, 1, 2, 3]),
    ];
    let ctx = DefaultContext;
    populate_graph(neighbor_accessor, &adjacency_lists).await;
    populate_data(&index.data_provider, &ctx, &vectors).await;
    // Delete node 3 in place, replacing it with up to 3 neighbors found via
    // a visited-set + top-k repair search.
    index
        .inplace_delete(
            strategy,
            &ctx,
            &3, 3, InplaceDeleteMethod::VisitedAndTopK {
                k_value: 4,
                l_value: 10,
            },
        )
        .await
        .unwrap();
    // Node 3 must be marked deleted in the data provider.
    assert!(
        index
            .data_provider
            .status_by_internal_id(&ctx, 3)
            .await
            .unwrap()
            .is_deleted()
    );
    let neighbor_accessor = &mut index.provider().neighbors();
    // The start point no longer points at the deleted node.
    {
        let mut list = AdjacencyList::new();
        neighbor_accessor.get_neighbors(4, &mut list).await.unwrap();
        list.sort();
        assert_eq!(&*list, &[0, 1, 2]);
    }
    // Node 2 (previously linked to 3) was repaired with replacement edges.
    {
        let mut list = AdjacencyList::new();
        neighbor_accessor.get_neighbors(2, &mut list).await.unwrap();
        list.sort();
        assert_eq!(&*list, &[0, 1, 4]);
    }
    {
        let mut list = AdjacencyList::new();
        neighbor_accessor.get_neighbors(0, &mut list).await.unwrap();
        list.sort();
        assert_eq!(&*list, &[1, 2, 4]);
    }
    {
        let mut list = AdjacencyList::new();
        neighbor_accessor.get_neighbors(1, &mut list).await.unwrap();
        list.sort();
        assert_eq!(&*list, &[0, 2, 4]);
    }
    // The deleted node's own adjacency list is cleared.
    {
        let mut list = AdjacencyList::new();
        neighbor_accessor.get_neighbors(3, &mut list).await.unwrap();
        assert!(list.is_empty());
    }
}
// Same tiny 2-D graph as the in-place delete test, but exercises the
// lazy-delete path: mark node 3 deleted, then run `consolidate_vector` over
// every node and check the deleted node was spliced out of all adjacency
// lists.
#[tokio::test]
async fn test_consolidate_deletes_2d() {
    let dim = 2;
    // l_search = 10, pruned_degree = 4, max_points = 4 (plus one frozen point).
    let (config, parameters) = simplified_builder(
        10, 4, Metric::L2, dim, 4, no_modify,
    )
    .unwrap();
    // Degenerate single-chunk PQ table (all zeros); quantized distances are
    // irrelevant — only graph surgery is verified here.
    let pqtable = model::pq::FixedChunkPQTable::new(
        dim,
        Box::new([0.0, 0.0]),
        Box::new([0.0, 0.0]),
        Box::new([0, 2]),
    )
    .unwrap();
    let index =
        new_quant_index::<f32, _, _>(config, parameters, pqtable, TableBasedDeletes).unwrap();
    let neighbor_accessor = &mut index.provider().neighbors();
    // Center start point becomes internal id 4.
    let start_point: &[f32] = &[0.5, 0.5];
    index
        .provider()
        .set_start_points(std::iter::once(start_point))
        .unwrap();
    let vectors = [
        vec![0.0, 0.0],
        vec![0.0, 1.0],
        vec![1.0, 0.0],
        vec![1.0, 1.0],
    ];
    // Each corner links to the start point and two other corners (including
    // node 3); the start point links to all corners.
    let adjacency_lists = [
        AdjacencyList::from_iter_untrusted([4, 1, 2]),
        AdjacencyList::from_iter_untrusted([4, 0, 3]),
        AdjacencyList::from_iter_untrusted([4, 3, 0]),
        AdjacencyList::from_iter_untrusted([4, 2, 1]),
        AdjacencyList::from_iter_untrusted([0, 1, 2, 3]),
    ];
    let ctx = DefaultContext;
    populate_graph(neighbor_accessor, &adjacency_lists).await;
    populate_data(&index.data_provider, &ctx, &vectors).await;
    // Sanity: exactly one start point, and it is id 4.
    let starting_point_ids = index.provider().starting_points().unwrap();
    assert!(starting_point_ids.contains(&4));
    assert!(starting_point_ids.len() == 1);
    // Lazy delete: only marks node 3; the graph is untouched until
    // consolidation below.
    index
        .data_provider
        .delete(&ctx, &3_u32)
        .await
        .expect("Error in delete");
    for vector_id in 0..5 {
        index
            .consolidate_vector(&FullPrecision, &ctx, vector_id as u32)
            .await
            .expect("Error in consolidate_vector");
    }
    // After consolidation no list may still reference the deleted node 3.
    let neighbor_accessor = &mut index.provider().neighbors();
    {
        let mut list = AdjacencyList::new();
        neighbor_accessor.get_neighbors(0, &mut list).await.unwrap();
        list.sort();
        assert_eq!(&*list, &[1, 2, 4]);
    }
    {
        let mut list = AdjacencyList::new();
        neighbor_accessor.get_neighbors(1, &mut list).await.unwrap();
        list.sort();
        assert_eq!(&*list, &[0, 2, 4]);
    }
    {
        let mut list = AdjacencyList::new();
        neighbor_accessor.get_neighbors(2, &mut list).await.unwrap();
        list.sort();
        assert_eq!(&*list, &[0, 1, 4]);
    }
    {
        let mut list = AdjacencyList::new();
        neighbor_accessor.get_neighbors(4, &mut list).await.unwrap();
        list.sort();
        assert_eq!(&*list, &[0, 1, 2]);
    }
}
// Path (relative to the test-data root / virtual storage overlay) of a
// 256-point SIFT learn subset used as the build/query corpus in these tests.
const SIFTSMALL: &str = "/sift/siftsmall_learn_256pts.fbin";
// Builds an index over the small SIFT corpus with each (strategy, batchsize)
// combination, checks full graph connectivity, and requires *exact* top-k
// recall for every corpus point queried against itself — under both the
// full-precision and hybrid search strategies.
#[rstest]
#[tokio::test]
async fn test_sift_build_and_search<S>(
    #[values(FullPrecision, Hybrid::new(None))] build_strategy: S,
    #[values(1, 10)] batchsize: usize,
) where
    S: for<'a> InsertStrategy<TestProvider, &'a [f32]>
        + MultiInsertStrategy<TestProvider, Matrix<f32>>
        + Clone,
{
    let ctx = &DefaultContext;
    let parameters = InitParams {
        l_build: 64,
        max_degree: 16,
        metric: Metric::L2,
        batchsize: NonZeroUsize::new(batchsize).unwrap(),
    };
    let (index, data) = init_from_file(
        build_strategy.clone(),
        parameters,
        SIFTSMALL,
        8,
        StartPointStrategy::RandomSamples {
            nsamples: ONE,
            seed: 0xe058c9c57864dd1e,
        },
    )
    .await;
    // Every data point plus the single start point must be reachable.
    let starting_points = index.provider().starting_points().unwrap();
    let neighbor_accessor = &mut index.provider().neighbors();
    assert_eq!(
        index
            .count_reachable_nodes(&starting_points, neighbor_accessor)
            .await
            .unwrap(),
        data.nrows() + 1,
    );
    let top_k = 10;
    let search_l = 32;
    let mut ids = vec![0; top_k];
    let mut distances = vec![0.0; top_k];
    for (q, query) in data.row_iter().enumerate() {
        // Brute-force groundtruth under squared L2 for this query.
        let gt = groundtruth(data.as_view(), query, |a, b| SquaredL2::evaluate(a, b));
        // Full-precision search must match groundtruth exactly.
        {
            let mut result_output_buffer =
                search_output_buffer::IdDistance::new(&mut ids, &mut distances);
            let graph_search = graph::search::Knn::new_default(top_k, search_l).unwrap();
            index
                .search(
                    graph_search,
                    &FullPrecision,
                    ctx,
                    query,
                    &mut result_output_buffer,
                )
                .await
                .unwrap();
        }
        assert_top_k_exactly_match(q, &gt, &ids, &distances, top_k);
        // Hybrid search (quantized traversal + reranking) must match too.
        {
            let mut result_output_buffer =
                search_output_buffer::IdDistance::new(&mut ids, &mut distances);
            let graph_search = graph::search::Knn::new_default(top_k, search_l).unwrap();
            index
                .search(
                    graph_search,
                    &Hybrid::new(None),
                    ctx,
                    query,
                    &mut result_output_buffer,
                )
                .await
                .unwrap();
        }
        assert_top_k_exactly_match(q, &gt, &ids, &distances, top_k);
    }
}
// Range-search variant of the SIFT build test. Radii include negative and
// zero bounds (expect empty result sets) and large bounds (expect many
// results); also exercises the inner-radius option and checks that a small
// starting L never yields duplicate ids.
#[rstest]
#[tokio::test]
async fn test_sift_build_and_range_search<S>(
    #[values(FullPrecision, Hybrid::new(None))] build_strategy: S,
    #[values(1, 10)] batchsize: usize,
    #[values((-2.0,-1.0), (-1.0, 0.0), (40000.0,50000.0), (50000.0,75000.0))] radii: (f32, f32),
) where
    S: for<'a> InsertStrategy<TestProvider, &'a [f32]>
        + MultiInsertStrategy<TestProvider, Matrix<f32>>
        + Clone,
{
    let ctx = &DefaultContext;
    let parameters = InitParams {
        l_build: 64,
        max_degree: 16,
        metric: Metric::L2,
        batchsize: NonZeroUsize::new(batchsize).unwrap(),
    };
    let (index, data) = init_from_file(
        build_strategy.clone(),
        parameters,
        SIFTSMALL,
        8,
        StartPointStrategy::RandomSamples {
            nsamples: ONE,
            seed: 0xe058c9c57864dd1e,
        },
    )
    .await;
    let starting_l_value = 32;
    let lower_l_value = 4;
    // radii = (inner, outer): outer is the search radius, inner optionally
    // excludes results closer than it.
    let radius = radii.1;
    let inner_radius = radii.0;
    for (q, query) in data.row_iter().enumerate() {
        let gt = groundtruth(data.as_view(), query, |a, b| SquaredL2::evaluate(a, b));
        // Full-precision range search must exactly match groundtruth.
        {
            let range_search = Range::new(starting_l_value, radius).unwrap();
            let mut results: Vec<Neighbor<u32>> = Vec::new();
            let _ = index
                .search(range_search, &FullPrecision, ctx, query, &mut results)
                .await
                .unwrap();
            let ids: Vec<u32> = results.iter().map(|n| n.id).collect();
            assert_range_results_exactly_match(q, &gt, &ids, radius, None);
        }
        // Hybrid range search must match as well.
        {
            let range_search = Range::new(starting_l_value, radius).unwrap();
            let mut results: Vec<Neighbor<u32>> = Vec::new();
            let _ = index
                .search(range_search, &Hybrid::new(None), ctx, query, &mut results)
                .await
                .unwrap();
            let ids: Vec<u32> = results.iter().map(|n| n.id).collect();
            assert_range_results_exactly_match(q, &gt, &ids, radius, None);
        }
        // Annulus search: inner radius excludes the closest matches.
        {
            assert!(inner_radius <= radius);
            let range_search = Range::with_options(
                None,
                starting_l_value,
                None,
                radius,
                Some(inner_radius),
                1.0,
                1.0,
            )
            .unwrap();
            let mut results: Vec<Neighbor<u32>> = Vec::new();
            let _ = index
                .search(range_search, &FullPrecision, ctx, query, &mut results)
                .await
                .unwrap();
            let ids: Vec<u32> = results.iter().map(|n| n.id).collect();
            assert_range_results_exactly_match(q, &gt, &ids, radius, Some(inner_radius));
        }
        // With a small L the results may be incomplete, but must never
        // contain duplicate ids.
        {
            let range_search = Range::new(lower_l_value, radius).unwrap();
            let mut results: Vec<Neighbor<u32>> = Vec::new();
            let _ = index
                .search(range_search, &FullPrecision, ctx, query, &mut results)
                .await
                .unwrap();
            let mut ids_set = std::collections::HashSet::new();
            for n in &results {
                assert!(ids_set.insert(n.id));
            }
        }
    }
}
// Reads a vector file from the test-data overlay, constructs an index via
// `create_fn` (seeded with a deterministically chosen random row as the
// start point), then populates it with `build_fn`. Returns the index and
// the loaded corpus.
async fn init_and_build_index_from_file<C, B, DP>(
    file: &str,
    create_fn: C,
    build_fn: B,
) -> (Arc<DiskANNIndex<DP>>, Arc<Matrix<f32>>)
where
    C: FnOnce(Arc<Matrix<f32>>, &[f32]) -> Arc<DiskANNIndex<DP>>,
    B: AsyncFnOnce(Arc<DiskANNIndex<DP>>, Arc<Matrix<f32>>),
    DP: DataProvider<Context = DefaultContext, ExternalId = u32>
        + for<'a> diskann::provider::SetElement<&'a [f32]>,
{
    let storage = VirtualStorageProvider::new_overlay(test_data_root());
    let mut reader = storage.open_reader(file).unwrap();
    let corpus = Arc::new(diskann_utils::io::read_bin::<f32>(&mut reader).unwrap());
    // Fixed seed keeps the chosen start point stable across runs.
    let rng = &mut create_rnd_from_seed_in_tests(0xe058c9c57864dd1e);
    let chosen_row = rand::Rng::random_range(rng, 0..corpus.nrows());
    let index = create_fn(corpus.clone(), corpus.row(chosen_row));
    build_fn(index.clone(), corpus.clone()).await;
    (index, corpus)
}
// Populates `index` with every row of `data`, one insert at a time, using
// the `Quantized` insert strategy. Row position doubles as the external id.
async fn build_using_single_insert<DP>(index: Arc<DiskANNIndex<DP>>, data: Arc<Matrix<f32>>)
where
    DP: DataProvider<Context = DefaultContext, ExternalId = u32>
        + for<'a> diskann::provider::SetElement<&'a [f32]>,
    Quantized: for<'a> InsertStrategy<DP, &'a [f32]> + Clone + Send + Sync,
{
    let ctx = &DefaultContext;
    let mut external_id: u32 = 0;
    for row in data.row_iter() {
        index.insert(Quantized, ctx, &external_id, row).await.unwrap();
        external_id += 1;
    }
}
macro_rules! scalar_quant_test {
($name:ident, $nbits:literal, $search_l:literal) => {
#[tokio::test]
async fn $name() {
let ctx = &DefaultContext;
let parameters = InitParams {
l_build: 64,
max_degree: 16,
metric: Metric::L2,
batchsize: NonZeroUsize::new(1).unwrap(),
};
let create_fn = |data: Arc<Matrix<f32>>, start_point: &[f32]| {
let quantizer = ScalarQuantizationParameters::default().train(data.as_view());
let (config, params) =
parameters.materialize(data.nrows(), data.ncols()).unwrap();
let index = new_quant_index::<f32, _, _>(
config,
params,
inmem::WithBits::<$nbits>::new(quantizer),
NoDeletes,
)
.unwrap();
index
.provider()
.set_start_points(std::iter::once(start_point))
.unwrap();
index
};
let (index, data) =
init_and_build_index_from_file(SIFTSMALL, create_fn, build_using_single_insert)
.await;
let neighbor_accessor = &mut index.provider().neighbors();
assert_eq!(
index
.count_reachable_nodes(
&index.provider().starting_points().unwrap(),
neighbor_accessor
)
.await
.unwrap(),
data.nrows() + 1,
);
let top_k = 8;
let search_l = $search_l; let mut ids = vec![0; top_k];
let mut distances = vec![0.0; top_k];
for (q, query) in data.row_iter().enumerate() {
let gt = groundtruth(data.as_view(), query, |a, b| SquaredL2::evaluate(a, b));
{
let mut result_output_buffer =
search_output_buffer::IdDistance::new(&mut ids, &mut distances);
let graph_search =
graph::search::Knn::new_default(top_k, search_l).unwrap();
index
.search(
graph_search,
&FullPrecision,
ctx,
query,
&mut result_output_buffer,
)
.await
.unwrap();
}
assert_top_k_exactly_match(q, >, &ids, &distances, top_k);
{
let mut result_output_buffer =
search_output_buffer::IdDistance::new(&mut ids, &mut distances);
let graph_search =
graph::search::Knn::new_default(top_k, search_l).unwrap();
index
.search(
graph_search,
&Quantized,
ctx,
query,
&mut result_output_buffer,
)
.await
.unwrap();
}
assert_top_k_exactly_match(q, >, &ids, &distances, top_k);
}
}
};
}
// Instantiate the scalar-quantization tests at several bit widths. The
// 1-bit quantizer is far coarser, so it needs a much larger search L (130)
// to still reach exact top-k recall than the 4/7/8-bit variants (20).
scalar_quant_test!(test_sift_build_and_search_scalar_q_1bit, 1, 130);
scalar_quant_test!(test_sift_build_and_search_scalar_q_4bit, 4, 20);
scalar_quant_test!(test_sift_build_and_search_scalar_q_8bit, 8, 20);
scalar_quant_test!(test_sift_build_and_search_scalar_q_7bit, 7, 20);
// Generates a test for a quantized-ONLY index (no full-precision vectors
// retained) at `$nbits` bits per component. Since there is no reranking,
// the assertion is weaker than exact recall: a self-query need only contain
// its own id somewhere in the top-k results.
macro_rules! scalar_only_test {
    ($name:ident, $nbits:literal) => {
        #[tokio::test]
        async fn $name() {
            let ctx = &DefaultContext;
            let parameters = InitParams {
                l_build: 64,
                max_degree: 16,
                metric: Metric::L2,
                batchsize: NonZeroUsize::new(1).unwrap(),
            };
            let create_fn = |data: Arc<Matrix<f32>>, start_point: &[f32]| {
                let quantizer = ScalarQuantizationParameters::default().train(data.as_view());
                let (config, params) =
                    parameters.materialize(data.nrows(), data.ncols()).unwrap();
                // Quant-only index: the provider stores only the quantized
                // representation.
                let index = Arc::new(
                    new_quant_only_index(
                        config,
                        params,
                        inmem::WithBits::<$nbits>::new(quantizer),
                        NoDeletes,
                    )
                    .unwrap(),
                );
                index
                    .provider()
                    .set_start_points(std::iter::once(start_point))
                    .unwrap();
                index
            };
            let (index, data) =
                init_and_build_index_from_file(SIFTSMALL, create_fn, build_using_single_insert)
                    .await;
            // Every data point plus the start point must be reachable.
            let neighbor_accessor = &mut index.provider().neighbors();
            assert_eq!(
                index
                    .count_reachable_nodes(
                        &index.provider().starting_points().unwrap(),
                        neighbor_accessor
                    )
                    .await
                    .unwrap(),
                data.nrows() + 1,
            );
            let top_k = 10;
            let mut ids = vec![0; top_k];
            let mut distances = vec![0.0; top_k];
            for (q, query) in data.row_iter().enumerate() {
                {
                    let mut result_output_buffer =
                        search_output_buffer::IdDistance::new(&mut ids, &mut distances);
                    let graph_search = graph::search::Knn::new_default(top_k, top_k).unwrap();
                    index
                        .search(
                            graph_search,
                            &Quantized,
                            ctx,
                            query,
                            &mut result_output_buffer,
                        )
                        .await
                        .unwrap();
                }
                // Weak recall check: the query's own id must be retrieved.
                assert!(ids.contains(&(q as u32)));
            }
        }
    };
}
// Quant-only variants at the same bit widths as the two-level tests above.
scalar_only_test!(test_sift_quant_only_build_and_search_scalar_1bit, 1);
scalar_only_test!(test_sift_quant_only_build_and_search_scalar_4bit, 4);
scalar_only_test!(test_sift_quant_only_build_and_search_scalar_8bit, 8);
scalar_only_test!(test_sift_quant_only_build_and_search_scalar_7bit, 7);
// Two-level index whose secondary store is a spherical quantizer. Builds
// with the spherical build strategy, requires exact top-k recall for both
// full-precision and spherical (4-bit transposed query layout) search, and
// finally checks that a build-strategy query computer reports the
// `SameAsData` layout.
#[tokio::test]
async fn test_sift_build_and_search_spherical() {
    let ctx = &DefaultContext;
    let parameters = InitParams {
        l_build: 64,
        max_degree: 16,
        metric: Metric::L2,
        batchsize: NonZeroUsize::new(1).unwrap(),
    };
    // Fixed seed so quantizer training (and thus the test) is deterministic.
    let rng = &mut create_rnd_from_seed_in_tests(0x56870bccb0c44b66);
    let create_fn = |data: Arc<Matrix<f32>>, start_point: &[f32]| {
        let quantizer = diskann_quantization::spherical::SphericalQuantizer::train(
            data.as_view(),
            diskann_quantization::algorithms::transforms::TransformKind::PaddingHadamard {
                target_dim: diskann_quantization::algorithms::transforms::TargetDim::Natural,
            },
            Metric::L2.try_into().unwrap(),
            diskann_quantization::spherical::PreScale::ReciprocalMeanNorm,
            rng,
            diskann_quantization::alloc::GlobalAllocator,
        )
        .unwrap();
        let (config, params) = parameters.materialize(data.nrows(), data.ncols()).unwrap();
        let index = new_quant_index::<f32, _, _>(
            config,
            params,
            diskann_quantization::spherical::iface::Impl::<1>::new(quantizer).unwrap(),
            NoDeletes,
        )
        .unwrap();
        index
            .provider()
            .set_start_points(std::iter::once(start_point))
            .unwrap();
        index
    };
    // Insert one row at a time with the spherical build strategy.
    let build_fn = async |index: Arc<DiskANNIndex<_>>, data: Arc<Matrix<f32>>| {
        let ctx = &DefaultContext;
        let strategy = inmem::spherical::Quantized::build();
        for (i, vector) in data.row_iter().enumerate() {
            index
                .insert(strategy, ctx, &(i as u32), vector)
                .await
                .unwrap()
        }
    };
    let (index, data) = init_and_build_index_from_file(SIFTSMALL, create_fn, build_fn).await;
    // Every data point plus the start point must be reachable.
    let neighbor_accessor = &mut index.provider().neighbors();
    assert_eq!(
        index
            .count_reachable_nodes(
                &index.provider().starting_points().unwrap(),
                neighbor_accessor
            )
            .await
            .unwrap(),
        data.nrows() + 1,
    );
    let top_k = 5;
    let search_l = 80;
    let mut ids = vec![0; top_k];
    let mut distances = vec![0.0; top_k];
    for (q, query) in data.row_iter().enumerate() {
        let gt = groundtruth(data.as_view(), query, |a, b| SquaredL2::evaluate(a, b));
        // Full-precision search path.
        let mut output = search_output_buffer::IdDistance::new(&mut ids, &mut distances);
        let graph_search = graph::search::Knn::new_default(top_k, search_l).unwrap();
        index
            .search(graph_search, &FullPrecision, ctx, query, &mut output)
            .await
            .unwrap();
        assert_top_k_exactly_match(q, &gt, &ids, &distances, top_k);
        // Spherical search path with the four-bit transposed query layout.
        let mut output = search_output_buffer::IdDistance::new(&mut ids, &mut distances);
        let strategy = inmem::spherical::Quantized::search(
            diskann_quantization::spherical::iface::QueryLayout::FourBitTransposed,
        );
        let graph_search = graph::search::Knn::new_default(top_k, search_l).unwrap();
        index
            .search(graph_search, &strategy, ctx, query, &mut output)
            .await
            .unwrap();
        assert_top_k_exactly_match(q, &gt, &ids, &distances, top_k);
    }
    // The build strategy's query computer should use the data layout.
    let strategy = inmem::spherical::Quantized::build();
    let accessor = strategy.search_accessor(index.provider(), ctx).unwrap();
    let computer = accessor.build_query_computer(data.row(0)).unwrap();
    assert_eq!(
        computer.layout(),
        diskann_quantization::spherical::iface::QueryLayout::SameAsData
    );
}
// Spherical-quantization-ONLY index (no full-precision store). With no
// reranking, the assertion is the weak self-recall check: each query must
// retrieve its own id. Also verifies the build strategy's query-computer
// layout.
// NOTE(review): the trailing underscore in the test name looks like a typo;
// renaming is deliberately avoided here to keep the test id stable.
#[tokio::test]
async fn test_sift_spherical_only_build_and_search_() {
    let ctx = &DefaultContext;
    // Fixed seed so quantizer training is deterministic.
    let rng = &mut create_rnd_from_seed_in_tests(0x56870bccb0c44b66);
    let create_fn = |data: Arc<Matrix<f32>>, start_points: &[f32]| {
        let quantizer = diskann_quantization::spherical::SphericalQuantizer::train(
            data.as_view(),
            diskann_quantization::algorithms::transforms::TransformKind::PaddingHadamard {
                target_dim: diskann_quantization::algorithms::transforms::TargetDim::Natural,
            },
            Metric::L2.try_into().unwrap(),
            diskann_quantization::spherical::PreScale::ReciprocalMeanNorm,
            rng,
            diskann_quantization::alloc::GlobalAllocator,
        )
        .unwrap();
        let (config, params) =
            simplified_builder(64, 16, Metric::L2, data.ncols(), data.nrows(), no_modify)
                .unwrap();
        let index = new_quant_only_index(
            config,
            params,
            diskann_quantization::spherical::iface::Impl::<1>::new(quantizer).unwrap(),
            NoDeletes,
        )
        .unwrap();
        index
            .provider()
            .set_start_points(std::iter::once(start_points))
            .unwrap();
        Arc::new(index)
    };
    // Insert one row at a time with the spherical build strategy.
    let build_fn = async |index: Arc<DiskANNIndex<_>>, data: Arc<Matrix<f32>>| {
        let ctx = &DefaultContext;
        let strategy = inmem::spherical::Quantized::build();
        for (i, vector) in data.row_iter().enumerate() {
            index
                .insert(strategy, ctx, &(i as u32), vector)
                .await
                .unwrap()
        }
    };
    let (index, data) = init_and_build_index_from_file(SIFTSMALL, create_fn, build_fn).await;
    // Every data point plus the start point must be reachable.
    let neighbor_accessor = &mut index.provider().neighbors();
    assert_eq!(
        index
            .count_reachable_nodes(
                &index.provider().starting_points().unwrap(),
                neighbor_accessor
            )
            .await
            .unwrap(),
        data.nrows() + 1,
    );
    let top_k = 5;
    let search_l = 80;
    let mut ids = vec![0; top_k];
    let mut distances = vec![0.0; top_k];
    for (q, query) in data.row_iter().enumerate() {
        let mut output = search_output_buffer::IdDistance::new(&mut ids, &mut distances);
        let strategy = inmem::spherical::Quantized::search(
            diskann_quantization::spherical::iface::QueryLayout::FourBitTransposed,
        );
        let graph_search = graph::search::Knn::new_default(top_k, search_l).unwrap();
        index
            .search(graph_search, &strategy, ctx, query, &mut output)
            .await
            .unwrap();
        // Weak recall check: the query's own id must be retrieved.
        assert!(ids.contains(&(q as u32)));
    }
    // The build strategy's query computer should use the data layout.
    // (Fully qualified call needed here to pin the provider type.)
    let strategy = inmem::spherical::Quantized::build();
    let accessor = <inmem::spherical::Quantized as SearchStrategy<
        DefaultProvider<NoStore, inmem::spherical::SphericalStore>,
        &[f32],
    >>::search_accessor(&strategy, index.provider(), ctx)
    .unwrap();
    let computer = accessor.build_query_computer(data.row(0)).unwrap();
    assert_eq!(
        computer.layout(),
        diskann_quantization::spherical::iface::QueryLayout::SameAsData
    );
}
// PQ-ONLY index: trains a 32-chunk product-quantization table on the corpus
// and builds an index that stores only PQ codes. With 32 chunks the codes
// are accurate enough that exact top-k recall is still expected.
#[tokio::test]
async fn test_sift_pq_only_build_and_search() {
    let ctx = &DefaultContext;
    let create_fn = |data: Arc<Matrix<f32>>, start_points: &[f32]| {
        // Deterministic PQ training via a fixed seed; pool size 1.
        let pq_table = train_pq(
            data.as_view(),
            32,
            &mut create_rnd_from_seed_in_tests(0xe3c52ef001bc7ade),
            1,
        )
        .unwrap();
        let (config, parameters) =
            simplified_builder(64, 16, Metric::L2, data.ncols(), data.nrows(), no_modify)
                .unwrap();
        let index =
            Arc::new(new_quant_only_index(config, parameters, pq_table, NoDeletes).unwrap());
        index
            .provider()
            .set_start_points(std::iter::once(start_points))
            .unwrap();
        index
    };
    let (index, data) =
        init_and_build_index_from_file(SIFTSMALL, create_fn, build_using_single_insert).await;
    // Every data point plus the start point must be reachable.
    let neighbor_accessor = &mut index.provider().neighbors();
    assert_eq!(
        index
            .count_reachable_nodes(
                &index.provider().starting_points().unwrap(),
                neighbor_accessor
            )
            .await
            .unwrap(),
        data.nrows() + 1,
    );
    let top_k = 10;
    let search_l = 32;
    let mut ids = vec![0; top_k];
    let mut distances = vec![0.0; top_k];
    for (q, query) in data.row_iter().enumerate() {
        let gt = groundtruth(data.as_view(), query, |a, b| SquaredL2::evaluate(a, b));
        let mut result_output_buffer =
            search_output_buffer::IdDistance::new(&mut ids, &mut distances);
        let graph_search = graph::search::Knn::new_default(top_k, search_l).unwrap();
        index
            .search(
                graph_search,
                &Quantized,
                ctx,
                query,
                &mut result_output_buffer,
            )
            .await
            .unwrap();
        assert_top_k_exactly_match(q, &gt, &ids, &distances, top_k);
    }
}
// Walks every node id yielded by `itr` and asserts its adjacency list
// contains neither the node itself (self loop) nor any repeated neighbor.
async fn check_graph_for_self_loops_or_duplicates<NA, Itr>(accessor: &mut NA, itr: Itr)
where
    NA: AsNeighbor<Id = u32>,
    Itr: Iterator<Item = u32>,
{
    for node in itr {
        let mut neighbors = AdjacencyList::new();
        accessor
            .get_neighbors(node, &mut neighbors)
            .await
            .expect("Error in get_neighbors");
        // No self loops.
        assert!(!neighbors.contains(node));
        // No duplicates: deduping a sorted copy must not shrink it.
        let original_len = neighbors.len();
        let mut sorted: Vec<_> = neighbors.into();
        sorted.sort();
        sorted.dedup();
        assert_eq!(sorted.len(), original_len);
    }
}
// Full-precision in-memory provider with the default quantizer and a
// table-based (lazy) delete provider — the provider most tests in this
// module are written against.
type TestProvider =
    FullPrecisionProvider<f32, DefaultQuant, TableDeleteProviderAsync, DefaultContext>;
// Index type built over `TestProvider`.
type TestIndex = DiskANNIndex<TestProvider>;
// High-level build parameters used by these tests; `InitParams::materialize`
// converts them into a concrete (Config, DefaultProviderParameters) pair.
#[derive(Debug, Clone, Copy)]
pub struct InitParams {
    // Candidate-list size (L) used while building the graph.
    pub l_build: usize,
    // Pruned out-degree bound of the graph.
    pub max_degree: usize,
    // Distance metric for build and search.
    pub metric: Metric,
    // Insert minibatch size; 1 selects the single-insert code path.
    pub batchsize: NonZeroUsize,
}
impl InitParams {
    // Turn these high-level parameters into a graph `Config` plus the
    // in-memory provider parameters for `npoints` vectors of `dim`
    // dimensions, wiring the minibatch size into the config builder.
    pub fn materialize(
        &self,
        npoints: usize,
        dim: usize,
    ) -> ANNResult<(Config, DefaultProviderParameters)> {
        let minibatch = self.batchsize;
        simplified_builder(
            self.l_build,
            self.max_degree,
            self.metric,
            dim,
            npoints,
            move |builder| {
                builder.max_minibatch_par(minibatch.into());
            },
        )
    }
}
// Streams vectors from `file` into an existing index. Start points are
// computed from `train_data` via `start_strategy` and installed first.
// With batchsize == 1 each vector is inserted individually; otherwise
// vectors are copied into a matrix and inserted via `multi_insert` in
// batches, with external ids assigned sequentially from 0 in file order.
pub async fn build_index<S, U, V, D>(
    index: &Arc<DiskANNIndex<DefaultProvider<U, V, D>>>,
    strategy: S,
    parameters: InitParams,
    file: &str,
    start_strategy: StartPointStrategy,
    train_data: diskann_utils::views::MatrixView<'_, f32>,
) where
    DefaultProvider<U, V, D>: DataProvider<ExternalId = u32, Context = DefaultContext>
        + for<'a> SetElement<&'a [f32]>
        + SetStartPoints<[f32]>,
    S: for<'a> InsertStrategy<DefaultProvider<U, V, D>, &'a [f32]>
        + MultiInsertStrategy<DefaultProvider<U, V, D>, Matrix<f32>>
        + Clone,
{
    let ctx = &DefaultContext;
    let storage = VirtualStorageProvider::new_overlay(test_data_root());
    let mut iter = VectorDataIterator::<_, f32>::new(file, None, &storage).unwrap();
    let start_vectors: Matrix<f32> = start_strategy.compute(train_data).unwrap();
    index
        .provider()
        .set_start_points(start_vectors.row_iter())
        .unwrap();
    let batchsize: usize = parameters.batchsize.into();
    if batchsize == 1 {
        // Single-insert path: enumeration index doubles as the external id.
        for (i, (vector, _)) in iter.enumerate() {
            index
                .insert(strategy.clone(), ctx, &(i as u32), &vector)
                .await
                .unwrap()
        }
    } else {
        // Batched path: copy up to `batchsize` vectors into a matrix, build
        // the matching id slice, and hand both to multi_insert.
        let mut i: u32 = 0;
        while let Some(data) = iter.next_n(batchsize) {
            // NOTE(review): the batch matrix width is taken from
            // `start_vectors.ncols()` — assumes start points and file
            // vectors share the same dimensionality; confirm upstream.
            let mut vectors = Matrix::new(0.0f32, data.len(), start_vectors.ncols());
            let ids: Arc<[_]> = std::iter::zip(vectors.row_iter_mut(), data.iter())
                .map(|(dst, (v, _))| {
                    dst.copy_from_slice(v);
                    let id = i;
                    i += 1;
                    id
                })
                .collect();
            index
                .multi_insert::<S, _>(strategy.clone(), ctx, Arc::new(vectors), ids)
                .await
                .unwrap();
        }
    }
}
// Creates the standard PQ-backed test index for `file`: loads the corpus,
// trains a deterministic PQ table on it, materializes the configuration
// from `parameters`, and builds the graph with `strategy`. Returns the
// index and the loaded training data.
async fn init_from_file<S>(
    strategy: S,
    parameters: InitParams,
    file: &str,
    num_pq_chunks: usize,
    startpoint: StartPointStrategy,
) -> (Arc<TestIndex>, diskann_utils::views::Matrix<f32>)
where
    S: for<'a> InsertStrategy<TestProvider, &'a [f32]>
        + MultiInsertStrategy<TestProvider, Matrix<f32>>
        + Clone,
{
    let storage = VirtualStorageProvider::new_overlay(test_data_root());
    let mut reader = storage.open_reader(file).unwrap();
    let train_data = diskann_utils::io::read_bin::<f32>(&mut reader).unwrap();
    // Deterministic PQ training (fixed seed, pool size 1).
    let table = train_pq(
        train_data.as_view(),
        num_pq_chunks,
        &mut create_rnd_from_seed_in_tests(0xe3c52ef001bc7ade),
        1,
    )
    .unwrap();
    let (config, params) = parameters
        .materialize(train_data.nrows(), train_data.ncols())
        .unwrap();
    let index = new_quant_index(config, params, table, TableBasedDeletes).unwrap();
    build_index(
        &index,
        strategy,
        parameters,
        file,
        startpoint,
        train_data.as_view(),
    )
    .await;
    (index, train_data)
}
// Deletes the first `points_to_delete` vectors from a SIFT-built index one
// at a time via `inplace_delete`, then verifies: delete statuses, that
// `drop_deleted_neighbors` leaves no surviving node pointing at a deleted
// one, that deleted nodes have empty adjacency lists, and that the graph
// has no self loops or duplicate edges.
#[rstest]
#[tokio::test]
async fn inplace_delete_on_sift<S>(
    #[values(FullPrecision, Hybrid::new(None))] strategy: S,
    #[values(20, 100)] points_to_delete: u32,
    #[values(
        InplaceDeleteMethod::VisitedAndTopK{k_value:5, l_value:10},
        InplaceDeleteMethod::TwoHopAndOneHop,
        InplaceDeleteMethod::OneHop,
    )]
    delete_method: InplaceDeleteMethod,
) where
    S: for<'a> InsertStrategy<TestProvider, &'a [f32]>
        + for<'a> SearchStrategy<TestProvider, &'a [f32]>
        + for<'a> InplaceDeleteStrategy<TestProvider, DeleteElement<'a> = &'a [f32]>
        + MultiInsertStrategy<TestProvider, Matrix<f32>>
        + Clone,
{
    let ctx = &DefaultContext;
    let parameters = InitParams {
        l_build: 10,
        max_degree: 32,
        metric: Metric::L2,
        batchsize: NonZeroUsize::new(1).unwrap(),
    };
    let (index, data) = init_from_file(
        strategy.clone(),
        parameters,
        SIFTSMALL,
        8,
        StartPointStrategy::RandomSamples {
            nsamples: ONE,
            seed: 0x440f42ab05085ba2,
        },
    )
    .await;
    // Each deleted node is patched with up to 3 replacement edges.
    let num_to_replace = 3;
    for id in 0..points_to_delete {
        index
            .inplace_delete(strategy.clone(), ctx, &id, num_to_replace, delete_method)
            .await
            .unwrap();
    }
    // All requested ids must now be marked deleted.
    for id in 0..points_to_delete {
        assert!(
            index
                .data_provider
                .status_by_external_id(ctx, &id)
                .await
                .unwrap()
                .is_deleted()
        );
    }
    let num_start_points = index
        .provider()
        .starting_points()
        .expect("Error in get_starting_point_ids")
        .len();
    // Scrub remaining references to deleted nodes from every list
    // (including the frozen start points at the end of the id space).
    let neighbor_accessor = &mut index.provider().neighbors();
    for id in 0..data.nrows() + num_start_points {
        index
            .drop_deleted_neighbors(ctx, neighbor_accessor, id.try_into().unwrap(), false)
            .await
            .unwrap();
    }
    // Surviving nodes must no longer reference any deleted neighbor.
    for id in points_to_delete.into_usize()..data.nrows() + num_start_points {
        assert!(
            !(index.is_any_neighbor_deleted(ctx, neighbor_accessor, id.try_into().unwrap()))
                .await
                .expect("Error in is_any_neighbor_deleted")
        );
    }
    // Deleted nodes must have empty adjacency lists.
    // (`adj_list` is reused across iterations — assumes get_neighbors
    // overwrites rather than appends, as the expected lists are all empty.)
    let mut adj_list = AdjacencyList::new();
    for id in 0..points_to_delete {
        neighbor_accessor
            .get_neighbors(id, &mut adj_list)
            .await
            .expect("Error in get_neighbors");
        assert!(adj_list.is_empty());
    }
    check_graph_for_self_loops_or_duplicates(
        neighbor_accessor,
        (&index.data_provider).into_iter(),
    )
    .await;
}
// Batched counterpart of `inplace_delete_on_sift`: deletes the first
// `points_to_delete` vectors in one `multi_inplace_delete` call and then
// runs the same post-conditions (statuses, scrubbed references, empty lists
// for deleted nodes, no self loops/duplicates).
#[rstest]
#[tokio::test]
async fn multi_inplace_delete_on_sift<S>(
    #[values(FullPrecision, Hybrid::new(None))] strategy: S,
    #[values(20, 100)] points_to_delete: u32,
    #[values(
        InplaceDeleteMethod::VisitedAndTopK{k_value:5, l_value:10},
        InplaceDeleteMethod::TwoHopAndOneHop
    )]
    delete_method: InplaceDeleteMethod,
) where
    S: for<'a> InsertStrategy<TestProvider, &'a [f32]>
        + for<'a> SearchStrategy<TestProvider, &'a [f32]>
        + for<'a> InplaceDeleteStrategy<TestProvider, DeleteElement<'a> = &'a [f32]>
        + MultiInsertStrategy<TestProvider, Matrix<f32>>
        + Clone,
{
    let ctx = &DefaultContext;
    let parameters = InitParams {
        l_build: 10,
        max_degree: 32,
        metric: Metric::L2,
        batchsize: NonZeroUsize::new(1).unwrap(),
    };
    let (index, data) = init_from_file(
        strategy.clone(),
        parameters,
        SIFTSMALL,
        8,
        StartPointStrategy::RandomSamples {
            nsamples: ONE,
            seed: 0x440f42ab05085ba2,
        },
    )
    .await;
    let num_to_replace = 3;
    // Delete all ids in a single batched call.
    let ids: Vec<u32> = (0..points_to_delete).collect();
    let ids = Arc::new(ids.as_slice());
    index
        .multi_inplace_delete(
            strategy,
            ctx,
            (&**ids).into(),
            num_to_replace,
            delete_method,
        )
        .await
        .unwrap();
    // All requested ids must now be marked deleted.
    for id in 0..points_to_delete {
        assert!(
            index
                .data_provider
                .status_by_external_id(ctx, &id)
                .await
                .unwrap()
                .is_deleted()
        );
    }
    // NOTE(review): this reads starting_points() off `data_provider`
    // directly, while the single-delete sibling uses
    // `index.provider().starting_points()` — presumably equivalent; confirm
    // and unify.
    let num_start_points = index
        .data_provider
        .starting_points()
        .expect("Error in get_starting_point_ids")
        .len();
    // Scrub remaining references to deleted nodes from every list.
    let neighbor_accessor = &mut index.provider().neighbors();
    for id in 0..data.nrows() + num_start_points {
        index
            .drop_deleted_neighbors(ctx, neighbor_accessor, id.try_into().unwrap(), false)
            .await
            .unwrap();
    }
    // Surviving nodes must no longer reference any deleted neighbor.
    for id in points_to_delete.into_usize()..data.nrows() + num_start_points {
        assert!(
            !(index.is_any_neighbor_deleted(ctx, neighbor_accessor, id.try_into().unwrap()))
                .await
                .expect("Error in is_any_neighbor_deleted")
        );
    }
    // Deleted nodes must have empty adjacency lists (buffer reuse assumes
    // get_neighbors overwrites; expected lists are all empty).
    let mut adj_list = AdjacencyList::new();
    for id in 0..points_to_delete {
        neighbor_accessor
            .get_neighbors(id, &mut adj_list)
            .await
            .expect("Error in get_neighbors");
        assert!(adj_list.is_empty());
    }
    check_graph_for_self_loops_or_duplicates(
        neighbor_accessor,
        (&index.data_provider).into_iter(),
    )
    .await;
}
// Lazy-delete path on the SIFT corpus: mark the first `points_to_delete`
// vectors deleted, run `consolidate_vector` over every node (data + start
// points), then verify surviving nodes reference no deleted neighbors and
// the graph stays free of self loops and duplicate edges.
#[rstest]
#[tokio::test]
async fn test_sift_256_vectors_with_consolidate_deletes(
    #[values(20, 100)] points_to_delete: u32,
) {
    let ctx = &DefaultContext;
    let parameters = InitParams {
        l_build: 10,
        max_degree: 32,
        metric: Metric::L2,
        batchsize: NonZeroUsize::new(1).unwrap(),
    };
    let (index, data) = init_from_file(
        FullPrecision,
        parameters,
        SIFTSMALL,
        8,
        StartPointStrategy::RandomSamples {
            nsamples: ONE,
            seed: 0x440f42ab05085ba2,
        },
    )
    .await;
    // Lazy delete: only marks status; graph repair happens in
    // consolidate_vector below.
    for id in 0..points_to_delete {
        index
            .data_provider
            .delete(ctx, &id)
            .await
            .expect("Error in delete");
    }
    for id in 0..points_to_delete {
        assert!(
            index
                .data_provider
                .status_by_external_id(ctx, &id)
                .await
                .unwrap()
                .is_deleted()
        );
    }
    let num_start_points = index
        .provider()
        .starting_points()
        .expect("Error in get_starting_point_ids")
        .len();
    // Consolidate every node, including the frozen start points.
    let total_points = data.nrows() + num_start_points;
    for id in 0..total_points {
        index
            .consolidate_vector(&FullPrecision, ctx, id.try_into().unwrap())
            .await
            .expect("Error in consolidate_vector");
    }
    // Surviving nodes must no longer reference any deleted neighbor.
    let neighbor_accessor = &mut index.provider().neighbors();
    for id in points_to_delete.into_usize()..total_points {
        assert!(
            !(index.is_any_neighbor_deleted(ctx, neighbor_accessor, id.try_into().unwrap()))
                .await
                .expect("Error in is_any_neighbor_deleted")
        );
    }
    check_graph_for_self_loops_or_duplicates(
        neighbor_accessor,
        (&index.data_provider).into_iter(),
    )
    .await;
}
#[tokio::test]
async fn test_final_prune() {
    // During incremental build the graph may temporarily exceed the
    // configured degree bound; a final `prune_range` pass must bring the
    // maximum observed degree back within `max_degree`.
    let ctx = &DefaultContext;
    let max_degree = 32;
    let parameters = InitParams {
        l_build: 15,
        max_degree,
        metric: Metric::L2,
        batchsize: NonZeroUsize::new(1).unwrap(),
    };
    let (index, _) = init_from_file(
        FullPrecision,
        parameters,
        SIFTSMALL,
        8,
        StartPointStrategy::RandomSamples {
            nsamples: ONE,
            seed: 0x986ce825cbe015e9,
        },
    )
    .await;
    let accessor = &mut index.provider().neighbors();
    // Precondition: the freshly built graph overshoots the bound.
    let before = index.get_degree_stats(accessor).await.unwrap();
    assert!(before.max_degree.into_usize() > max_degree);
    index
        .prune_range(&FullPrecision, ctx, 0..256)
        .await
        .unwrap();
    // Postcondition: pruning restored the degree invariant.
    let after = index.get_degree_stats(accessor).await.unwrap();
    assert!(after.max_degree.into_usize() <= max_degree);
}
// Re-insert test for the hybrid strategy: builds an index (32 PQ chunks),
// then re-inserts every vector in shuffled order under the same external
// ids — effectively replacing existing entries — and checks connectivity
// and exact top-k recall afterward.
#[rstest]
#[tokio::test]
async fn test_replace_sift_256_vectors_with_quant_vectors(
    #[values(None, Some(32))] max_fp_vecs_per_prune: Option<usize>,
    #[values(1, 3)] insert_minibatch_size: usize,
) {
    let ctx = &DefaultContext;
    let parameters = InitParams {
        l_build: 35,
        max_degree: 32,
        metric: Metric::L2,
        batchsize: NonZeroUsize::new(insert_minibatch_size).unwrap(),
    };
    let (index, data) = init_from_file(
        Hybrid::new(max_fp_vecs_per_prune),
        parameters,
        SIFTSMALL,
        32,
        StartPointStrategy::RandomSamples {
            nsamples: ONE,
            seed: 0x812b98835db95971,
        },
    )
    .await;
    // Build a shuffled copy of the corpus; deterministic seed for stability.
    let mut indices: Vec<_> = (0..data.nrows()).collect();
    let rng = &mut create_rnd_from_seed_in_tests(0x7dc205fcda38d3a3);
    indices.shuffle(rng);
    let mut queries = diskann_utils::views::Matrix::new(0.0, data.nrows(), data.ncols());
    std::iter::zip(queries.row_iter_mut(), indices.iter()).for_each(|(row, i)| {
        row.copy_from_slice(data.row(*i));
    });
    // Re-insert: position in the shuffled order becomes the external id, so
    // each existing id now maps to a (generally) different vector.
    for (pos, query) in queries.row_iter().enumerate() {
        index
            .insert(
                Hybrid::new(max_fp_vecs_per_prune),
                ctx,
                &(pos as u32),
                query,
            )
            .await
            .unwrap();
    }
    // Graph must remain fully connected after all replacements.
    assert_eq!(
        index
            .count_reachable_nodes(
                &index.provider().starting_points().unwrap(),
                &mut index.provider().neighbors()
            )
            .await
            .unwrap(),
        data.nrows() + 1,
    );
    let top_k = 4;
    let search_l = 40;
    let mut ids = vec![0; top_k];
    let mut distances = vec![0.0; top_k];
    for (q, query) in queries.row_iter().enumerate() {
        // Groundtruth is computed against the shuffled matrix, matching the
        // re-inserted id assignment.
        let gt = groundtruth(queries.as_view(), query, |a, b| SquaredL2::evaluate(a, b));
        let mut result_output_buffer =
            search_output_buffer::IdDistance::new(&mut ids, &mut distances);
        let graph_search = graph::search::Knn::new_default(top_k, search_l).unwrap();
        index
            .search(
                graph_search,
                &Hybrid::new(max_fp_vecs_per_prune),
                ctx,
                query,
                &mut result_output_buffer,
            )
            .await
            .unwrap();
        // Diagnostic output to ease debugging on recall failures.
        println!(
            "gt = {:?}, ids = {:?}, distance = {:?}",
            &gt[gt.len() - 2 * top_k..],
            ids,
            distances
        );
        assert_top_k_exactly_match(q, &gt, &ids, &distances, top_k);
    }
}
/// Shared body for `test_one_level_index_same_as_two_level`: builds the same
/// SIFT data twice through different construction paths (one via
/// `init_from_file`, one via `new_index` + `build_index` with table-based
/// deletes) and asserts both indices end up with identical adjacency lists
/// for every vertex.
async fn test_one_level_index_same_as_two_level_impl(batchsize: NonZeroUsize) {
    let parameters = InitParams {
        l_build: 64,
        max_degree: 16,
        metric: Metric::L2,
        batchsize,
    };
    // Same start-point seed for both builds, so the comparison is apples to
    // apples.
    let start_point = StartPointStrategy::RandomSamples {
        nsamples: ONE,
        seed: 0xe058c9c57864dd1e,
    };
    let (quant_index, data) =
        init_from_file(FullPrecision, parameters, SIFTSMALL, 8, start_point).await;
    let (config, params) = parameters.materialize(data.nrows(), data.ncols()).unwrap();
    let full_index = new_index(config, params, TableBasedDeletes).unwrap();
    build_index(
        &full_index,
        FullPrecision,
        parameters,
        SIFTSMALL,
        start_point,
        data.as_view(),
    )
    .await;
    let iter = (&quant_index.data_provider).into_iter();
    let mut from_quant = AdjacencyList::new();
    let mut from_full = AdjacencyList::new();
    for id in iter {
        quant_index
            .data_provider
            .neighbors()
            .get_neighbors(id, &mut from_quant)
            .await
            .unwrap();
        full_index
            .data_provider
            .neighbors()
            .get_neighbors(id, &mut from_full)
            .await
            .unwrap();
        // Neighbor order is an implementation detail; compare as sorted sets.
        from_quant.sort();
        from_full.sort();
        assert_eq!(from_quant, from_full);
    }
}
/// Runs the one-level-vs-two-level equivalence check for a batch size of
/// one and a larger batch size of ten.
#[tokio::test]
async fn test_one_level_index_same_as_two_level() {
    for batch in [1, 10] {
        test_one_level_index_same_as_two_level_impl(NonZeroUsize::new(batch).unwrap()).await;
    }
}
/// Builds from SIFT through the `Flaky` test strategy (presumably injecting
/// a failure every 9th operation per `Flaky::new(9)` — see its definition)
/// and verifies the index still ends up fully reachable with exact top-k
/// recall, i.e. the build is robust to transient failures.
#[tokio::test]
async fn test_flaky_build() {
    let parameters = InitParams {
        l_build: 64,
        max_degree: 16,
        metric: Metric::L2,
        batchsize: NonZeroUsize::new(1).unwrap(),
    };
    let start_point = StartPointStrategy::RandomSamples {
        nsamples: ONE,
        seed: 0xb4de0a1298a86eea,
    };
    let (index, data) = init_from_file(
        inmem::test::Flaky::new(9),
        parameters,
        SIFTSMALL,
        8,
        start_point,
    )
    .await;
    let neighbor_accessor = &mut index.provider().neighbors();
    // Every vector plus the start point must be reachable despite the
    // injected failures.
    assert_eq!(
        index
            .count_reachable_nodes(
                &index.provider().starting_points().unwrap(),
                neighbor_accessor
            )
            .await
            .unwrap(),
        data.nrows() + 1,
    );
    let top_k = 10;
    let search_l = 32;
    let mut ids = vec![0; top_k];
    let mut distances = vec![0.0; top_k];
    let ctx = &DefaultContext;
    // Query with every data vector; results must exactly match the
    // brute-force squared-L2 groundtruth.
    for (q, query) in data.row_iter().enumerate() {
        let gt = groundtruth(data.as_view(), query, |a, b| SquaredL2::evaluate(a, b));
        let mut result_output_buffer =
            search_output_buffer::IdDistance::new(&mut ids, &mut distances);
        let graph_search = graph::search::Knn::new_default(top_k, search_l).unwrap();
        index
            .search(
                graph_search,
                &FullPrecision,
                ctx,
                query,
                &mut result_output_buffer,
            )
            .await
            .unwrap();
        assert_top_k_exactly_match(q, &gt, &ids, &distances, top_k);
    }
}
/// Builds a small 3-d index over a 10x10x10 integer grid (plus an origin
/// start point) whose insert path is configured with `retry` attempts and
/// the given `saturated` flag, so tests can compare resulting graph density.
///
/// Ids start at 1; the origin start point is set explicitly beforehand.
async fn create_retry_saturated_index(
    retry: NonZeroU32,
    saturated: bool,
) -> ANNResult<MemoryIndex<f32>> {
    let (config, params) = simplified_builder(5, 10, Metric::L2, 3, 1001, |builder| {
        builder.insert_retry(graph::config::experimental::InsertRetry::new(
            retry,
            NonZeroU32::new(20).unwrap(),
            saturated,
        ));
    })
    .unwrap();
    let index = new_index::<f32, _>(config, params, NoDeletes).unwrap();
    let mut id_counter = 1;
    let start = vec![0.0, 0.0, 0.0];
    index
        .provider()
        .set_start_points(vec![start.as_slice()].into_iter())
        .unwrap();
    // Insert every point of the grid with sequential ids.
    for x in 1..11 {
        for y in 1..11 {
            for z in 1..11 {
                let vec = vec![x as f32, y as f32, z as f32];
                // `id_counter` is `Copy`; borrow it directly (the previous
                // `.clone()` was redundant — clippy `clone_on_copy`).
                index
                    .insert(FullPrecision, &DefaultContext, &id_counter, &vec)
                    .await?;
                id_counter += 1;
            }
        }
    }
    Ok(index)
}
/// Verifies that enabling saturation during insert-retry produces a denser
/// graph: the saturated index must end up with a strictly higher average
/// degree than an otherwise-identical unsaturated index.
#[tokio::test]
async fn test_saturate_index() {
    let index_sat = create_retry_saturated_index(NonZeroU32::new(1).unwrap(), true)
        .await
        .unwrap();
    let mut accessor_sat = inmem::FullAccessor::new(index_sat.provider());
    let res_sat = index_sat.get_degree_stats(&mut accessor_sat).await.unwrap();
    let index_unsat = create_retry_saturated_index(NonZeroU32::new(1).unwrap(), false)
        .await
        .unwrap();
    let mut accessor_unsat = inmem::FullAccessor::new(index_unsat.provider());
    // BUG FIX: the unsaturated stats must be read from `index_unsat`, not
    // `index_sat` (a copy-paste error previously made this comparison
    // self-referential).
    let res_unsat = index_unsat
        .get_degree_stats(&mut accessor_unsat)
        .await
        .unwrap();
    assert!(
        res_sat.avg_degree > res_unsat.avg_degree,
        "Saturated index should have higher average degree than the unsaturated index"
    );
}
/// Verifies that allowing more insert retries (3 vs. 1, both unsaturated)
/// yields a denser graph, i.e. a strictly higher average degree.
#[tokio::test]
async fn test_retry_index() {
    let index_retry = create_retry_saturated_index(NonZeroU32::new(3).unwrap(), false)
        .await
        .unwrap();
    let mut accessor_retry = inmem::FullAccessor::new(index_retry.provider());
    let res_retry = index_retry
        .get_degree_stats(&mut accessor_retry)
        .await
        .unwrap();
    let index_single = create_retry_saturated_index(NonZeroU32::new(1).unwrap(), false)
        .await
        .unwrap();
    let mut accessor_single = inmem::FullAccessor::new(index_single.provider());
    // BUG FIX: the single-retry stats were previously read from the
    // multi-retry index (copy-paste error), making the comparison
    // self-referential.
    let res_single = index_single
        .get_degree_stats(&mut accessor_single)
        .await
        .unwrap();
    assert!(
        res_retry.avg_degree > res_single.avg_degree,
        "Index built with more insert retries should have higher average degree"
    );
}
/// End-to-end diversity search over an in-memory index: 256 random 128-d
/// vectors are labeled round-robin with labels 1..=5, and a diverse search
/// is asked to cap each label at `diverse_results_k` occurrences in the
/// result list. Verifies the cap holds and that at least two distinct
/// labels appear.
#[cfg(feature = "experimental_diversity_search")]
#[tokio::test]
async fn test_inmemory_search_diversity_search() {
    use diskann::neighbor::AttributeValueProvider;
    use rand::Rng;
    use std::collections::HashMap;
    /// Minimal label provider backed by a hash map from vertex id to label.
    #[derive(Debug, Clone)]
    struct TestAttributeProvider {
        attributes: HashMap<u32, u32>,
    }
    impl TestAttributeProvider {
        fn new() -> Self {
            Self {
                attributes: HashMap::new(),
            }
        }
        fn insert(&mut self, id: u32, attribute: u32) {
            self.attributes.insert(id, attribute);
        }
    }
    impl diskann::provider::HasId for TestAttributeProvider {
        type Id = u32;
    }
    impl AttributeValueProvider for TestAttributeProvider {
        type Value = u32;
        fn get(&self, id: Self::Id) -> Option<Self::Value> {
            self.attributes.get(&id).copied()
        }
    }
    // Generate the dataset from a fixed seed so the test is deterministic.
    let dim = 128;
    let num_points = 256;
    let mut data_vectors = Vec::new();
    let mut rng = create_rnd_from_seed_in_tests(42);
    for _ in 0..num_points {
        let vec: Vec<f32> = (0..dim).map(|_| rng.random_range(0.0..1.0)).collect();
        data_vectors.push(vec);
    }
    let l_build = 50;
    let max_degree = 32;
    let (config, parameters) =
        simplified_builder(l_build, max_degree, Metric::L2, dim, num_points, no_modify)
            .unwrap();
    let index = new_index::<f32, _>(config, parameters, NoDeletes).unwrap();
    index
        .provider()
        .set_start_points(std::iter::once(data_vectors[0].as_slice()))
        .unwrap();
    for (i, vec) in data_vectors.iter().enumerate() {
        index
            .insert(FullPrecision, &DefaultContext, &(i as u32), vec.as_slice())
            .await
            .unwrap();
    }
    // Label points round-robin with labels 1..=5; also label the (frozen)
    // id just past the data range.
    let mut attribute_provider = TestAttributeProvider::new();
    for i in 0..num_points {
        let label = ((i % 5) + 1) as u32;
        attribute_provider.insert(i as u32, label);
    }
    attribute_provider.insert(num_points as u32, 1);
    let attribute_provider = std::sync::Arc::new(attribute_provider);
    let query = vec![0.5f32; dim];
    let return_list_size = 10;
    let search_list_size = 20;
    // Each label may appear at most once in the diverse result list.
    let diverse_results_k = 1;
    let mut indices = vec![0u32; return_list_size];
    let mut distances = vec![0f32; return_list_size];
    let mut result_output_buffer =
        diskann::graph::IdDistance::new(&mut indices, &mut distances);
    let diverse_params = diskann::graph::DiverseSearchParams::new(
        0, diverse_results_k,
        attribute_provider.clone(),
    );
    let search_params = diskann::graph::search::Knn::new(
        return_list_size,
        search_list_size,
        None, )
    .unwrap();
    let diverse_search = diskann::graph::search::Diverse::new(search_params, diverse_params);
    let result = index
        .search(
            diverse_search,
            &FullPrecision,
            &DefaultContext,
            query.as_slice(),
            &mut result_output_buffer,
        )
        .await;
    assert!(result.is_ok(), "Expected diversity search to succeed");
    let stats = result.unwrap();
    assert!(
        stats.result_count as usize <= return_list_size,
        "Expected result count to be <= {}",
        return_list_size
    );
    assert!(
        stats.result_count > 0,
        "Expected to get some search results"
    );
    println!("\n=== In-Memory Diversity Search Results ===");
    println!("Query: [0.5f32; {}]", dim);
    println!("diverse_results_k: {}", diverse_results_k);
    println!("Total results: {}\n", stats.result_count);
    println!("{:<10} {:<15} {:<10}", "Vertex ID", "Distance", "Label");
    println!("{}", "-".repeat(35));
    for i in 0..stats.result_count as usize {
        let attribute_value = attribute_provider.get(indices[i]).unwrap_or(0);
        println!(
            "{:<10} {:<15.2} {:<10}",
            indices[i], distances[i], attribute_value
        );
    }
    // Distances must come back non-negative and sorted ascending.
    for i in 0..(stats.result_count as usize).saturating_sub(1) {
        assert!(distances[i] >= 0.0, "Expected non-negative distance");
        assert!(
            distances[i] <= distances[i + 1],
            "Expected distances to be sorted in ascending order"
        );
    }
    // Tally how often each label appears among the returned ids.
    let mut attribute_counts = HashMap::new();
    for item in indices.iter().take(stats.result_count as usize) {
        if let Some(attribute_value) = attribute_provider.get(*item) {
            *attribute_counts.entry(attribute_value).or_insert(0) += 1;
        }
    }
    println!("\n=== Attribute Distribution ===");
    let mut sorted_attrs: Vec<_> = attribute_counts.iter().collect();
    sorted_attrs.sort_by_key(|(k, _)| *k);
    for (attribute_value, count) in &sorted_attrs {
        println!(
            "Label {}: {} occurrences (max allowed: {})",
            attribute_value, count, diverse_results_k
        );
    }
    println!("Total unique labels: {}", attribute_counts.len());
    println!("================================\n");
    // Core diversity guarantee: no label exceeds diverse_results_k.
    for (attribute_value, count) in &attribute_counts {
        println!(
            "Assert: Label {} has {} occurrences (max: {})",
            attribute_value, count, diverse_results_k
        );
        assert!(
            *count <= diverse_results_k,
            "Attribute value {} appears {} times, which exceeds diverse_results_k of {}",
            attribute_value,
            count,
            diverse_results_k
        );
    }
    println!(
        "Assert: Found {} unique labels (expected at least 2)",
        attribute_counts.len()
    );
    assert!(
        attribute_counts.len() >= 2,
        "Expected at least 2 different attribute values for diversity, got {}",
        attribute_counts.len()
    );
}
/// Multihop filter that rejects every candidate during traversal (see its
/// `on_visit`), while `is_match` only admits explicitly allowed ids into
/// the final result list.
#[derive(Debug)]
struct RejectAllFilter {
    // Ids permitted to appear in the result list (checked by `is_match`).
    allowed_in_results: HashSet<u32>,
}
impl RejectAllFilter {
fn only<I: IntoIterator<Item = u32>>(ids: I) -> Self {
Self {
allowed_in_results: ids.into_iter().collect(),
}
}
}
impl QueryLabelProvider<u32> for RejectAllFilter {
    /// A vertex may enter the result list only if it was whitelisted.
    fn is_match(&self, vec_id: u32) -> bool {
        self.allowed_in_results.contains(&vec_id)
    }
    /// Rejects every visited neighbor, so traversal never expands past the
    /// entry points.
    fn on_visit(&self, _neighbor: Neighbor<u32>) -> QueryVisitDecision<u32> {
        QueryVisitDecision::Reject
    }
}
/// Multihop filter that records every visited vertex id and terminates the
/// search as soon as `target` is visited.
#[derive(Debug)]
struct TerminatingFilter {
    // Vertex id that triggers termination when visited.
    target: u32,
    // Visit log; a mutex because `on_visit` only receives `&self`.
    hits: Mutex<Vec<u32>>,
}
impl TerminatingFilter {
fn new(target: u32) -> Self {
Self {
target,
hits: Mutex::new(Vec::new()),
}
}
fn hits(&self) -> Vec<u32> {
self.hits
.lock()
.expect("mutex should not be poisoned")
.clone()
}
}
impl QueryLabelProvider<u32> for TerminatingFilter {
    /// Only the target vertex is allowed into the result list.
    fn is_match(&self, vec_id: u32) -> bool {
        vec_id == self.target
    }
    /// Logs every visit, then terminates on the target and accepts the
    /// neighbor unchanged otherwise.
    fn on_visit(&self, neighbor: Neighbor<u32>) -> QueryVisitDecision<u32> {
        self.hits
            .lock()
            .expect("mutex should not be poisoned")
            .push(neighbor.id);
        if neighbor.id == self.target {
            QueryVisitDecision::Terminate
        } else {
            QueryVisitDecision::Accept(neighbor)
        }
    }
}
/// When `on_visit` rejects every neighbor, the multihop search cannot
/// expand past its entry points and must therefore return zero results —
/// even though `is_match` would admit id 0.
#[tokio::test]
async fn test_multihop_reject_all_returns_zero_results() {
    // 4x4x4 grid of 3-d points plus one extra far-corner point.
    let dim = 3;
    let grid_size: usize = 4;
    let l = 10;
    let max_degree = 2 * dim;
    let num_points = (grid_size).pow(dim as u32);
    let (config, parameters) =
        simplified_builder(l, max_degree, Metric::L2, dim, num_points, no_modify).unwrap();
    let mut adjacency_lists = Grid::Three.neighbors(grid_size);
    let mut vectors = f32::generate_grid(dim, grid_size);
    // Extra vertex at the far corner; its adjacency entry is derived from
    // the last grid id.
    adjacency_lists.push((num_points as u32 - 1).into());
    vectors.push(vec![grid_size as f32; dim]);
    let table = train_pq(
        squish(vectors.iter(), dim).as_view(),
        2.min(dim),
        &mut create_rnd_from_seed_in_tests(0x1234567890abcdef),
        1usize,
    )
    .unwrap();
    let index = new_quant_index::<f32, _, _>(config, parameters, table, NoDeletes).unwrap();
    let neighbor_accessor = &mut index.provider().neighbors();
    populate_data(&index.data_provider, &DefaultContext, &vectors).await;
    populate_graph(neighbor_accessor, &adjacency_lists).await;
    let query = vec![grid_size as f32; dim];
    let mut ids = vec![0; 10];
    let mut distances = vec![0.0; 10];
    let mut result_output_buffer =
        search_output_buffer::IdDistance::new(&mut ids, &mut distances);
    // `is_match` would admit id 0, but `on_visit` rejects everything, so the
    // traversal can never reach it.
    let filter = RejectAllFilter::only([0_u32]);
    let search_params = Knn::new_default(10, 20).unwrap();
    let multihop = graph::search::MultihopSearch::new(search_params, &filter);
    let stats = index
        .search(
            multihop,
            &FullPrecision,
            &DefaultContext,
            query.as_slice(),
            &mut result_output_buffer,
        )
        .await
        .unwrap();
    assert_eq!(
        stats.result_count, 0,
        "rejecting all via on_visit should result in zero results"
    );
}
/// A filter that terminates the search as soon as its target vertex is
/// visited must still have recorded the target and surface at least one
/// result (the target itself, per `TerminatingFilter::is_match`).
#[tokio::test]
async fn test_multihop_early_termination() {
    // 5x5x5 grid plus one extra far-corner point (same setup as the other
    // multihop tests).
    let dim = 3;
    let grid_size: usize = 5;
    let l = 10;
    let max_degree = 2 * dim;
    let num_points = (grid_size).pow(dim as u32);
    let (config, parameters) =
        simplified_builder(l, max_degree, Metric::L2, dim, num_points, no_modify).unwrap();
    let mut adjacency_lists = Grid::Three.neighbors(grid_size);
    let mut vectors = f32::generate_grid(dim, grid_size);
    adjacency_lists.push((num_points as u32 - 1).into());
    vectors.push(vec![grid_size as f32; dim]);
    let table = train_pq(
        squish(vectors.iter(), dim).as_view(),
        2.min(dim),
        &mut create_rnd_from_seed_in_tests(0xfedcba0987654321),
        1usize,
    )
    .unwrap();
    let index = new_quant_index::<f32, _, _>(config, parameters, table, NoDeletes).unwrap();
    let neighbor_accessor = &mut index.provider().neighbors();
    populate_data(&index.data_provider, &DefaultContext, &vectors).await;
    populate_graph(neighbor_accessor, &adjacency_lists).await;
    let query = vec![grid_size as f32; dim];
    let mut ids = vec![0; 10];
    let mut distances = vec![0.0; 10];
    let mut result_output_buffer =
        search_output_buffer::IdDistance::new(&mut ids, &mut distances);
    // Terminate once the middle of the grid is reached.
    let target = (num_points / 2) as u32;
    let filter = TerminatingFilter::new(target);
    let search_params = Knn::new_default(10, 40).unwrap();
    let multihop = graph::search::MultihopSearch::new(search_params, &filter);
    let stats = index
        .search(
            multihop,
            &FullPrecision,
            &DefaultContext,
            query.as_slice(),
            &mut result_output_buffer,
        )
        .await
        .unwrap();
    let hits = filter.hits();
    assert!(
        hits.contains(&target),
        "search should have visited the target"
    );
    assert!(
        stats.result_count >= 1,
        "should have at least one result (the target)"
    );
}
/// Compares a baseline multihop search against one whose filter adjusts
/// the distance of a chosen "boosted" point; if the boosted point was
/// visited it must appear in the adjusted results and rank no worse than
/// in the baseline.
#[tokio::test]
async fn test_multihop_distance_adjustment_affects_ranking() {
    // 4x4x4 grid plus one extra far-corner point.
    let dim = 3;
    let grid_size: usize = 4;
    let l = 10;
    let max_degree = 2 * dim;
    let num_points = (grid_size).pow(dim as u32);
    let (config, parameters) =
        simplified_builder(l, max_degree, Metric::L2, dim, num_points, no_modify).unwrap();
    let mut adjacency_lists = Grid::Three.neighbors(grid_size);
    let mut vectors = f32::generate_grid(dim, grid_size);
    adjacency_lists.push((num_points as u32 - 1).into());
    vectors.push(vec![grid_size as f32; dim]);
    let table = train_pq(
        squish(vectors.iter(), dim).as_view(),
        2.min(dim),
        &mut create_rnd_from_seed_in_tests(0xabcdef1234567890),
        1usize,
    )
    .unwrap();
    let index = new_quant_index::<f32, _, _>(config, parameters, table, NoDeletes).unwrap();
    let neighbor_accessor = &mut index.provider().neighbors();
    populate_data(&index.data_provider, &DefaultContext, &vectors).await;
    populate_graph(neighbor_accessor, &adjacency_lists).await;
    let query = vec![0.0; dim];
    // Baseline run with `EvenFilter` (defined elsewhere in this file) and no
    // distance adjustment.
    let mut baseline_ids = vec![0; 10];
    let mut baseline_distances = vec![0.0; 10];
    let mut baseline_buffer =
        search_output_buffer::IdDistance::new(&mut baseline_ids, &mut baseline_distances);
    let search_params = Knn::new_default(10, 20).unwrap();
    let multihop = graph::search::MultihopSearch::new(search_params, &EvenFilter);
    let baseline_stats = index
        .search(
            multihop,
            &FullPrecision,
            &DefaultContext,
            query.as_slice(),
            &mut baseline_buffer,
        )
        .await
        .unwrap();
    // Adjusted run: `CallbackFilter` (defined elsewhere in this file)
    // presumably scales the boosted point's distance by 0.01 when it is
    // visited — NOTE(review): confirm against its definition.
    let boosted_point = (num_points - 2) as u32;
    let filter = CallbackFilter::new(u32::MAX, boosted_point, 0.01);
    let mut adjusted_ids = vec![0; 10];
    let mut adjusted_distances = vec![0.0; 10];
    let mut adjusted_buffer =
        search_output_buffer::IdDistance::new(&mut adjusted_ids, &mut adjusted_distances);
    let search_params = Knn::new_default(10, 20).unwrap();
    let multihop = graph::search::MultihopSearch::new(search_params, &filter);
    let adjusted_stats = index
        .search(
            multihop,
            &FullPrecision,
            &DefaultContext,
            query.as_slice(),
            &mut adjusted_buffer,
        )
        .await
        .unwrap();
    assert!(
        baseline_stats.result_count > 0,
        "baseline should have results"
    );
    assert!(
        adjusted_stats.result_count > 0,
        "adjusted should have results"
    );
    // Rank of the boosted point in each result list (None if absent).
    let boosted_in_baseline = baseline_ids
        .iter()
        .take(baseline_stats.result_count as usize)
        .position(|&id| id == boosted_point);
    let boosted_in_adjusted = adjusted_ids
        .iter()
        .take(adjusted_stats.result_count as usize)
        .position(|&id| id == boosted_point);
    // Only assert if the traversal actually visited the boosted point.
    if filter.hits().contains(&boosted_point) {
        assert!(
            boosted_in_adjusted.is_some(),
            "boosted point should appear in adjusted results when visited"
        );
        if let (Some(baseline_pos), Some(adjusted_pos)) =
            (boosted_in_baseline, boosted_in_adjusted)
        {
            assert!(
                adjusted_pos <= baseline_pos,
                "boosted point should rank equal or better after distance reduction"
            );
        }
    }
}
/// A filter that returns `Terminate` after a fixed number of visits must
/// stop the traversal shortly thereafter; the final assertion allows a
/// small slack over `max_visits` for visits already in flight.
#[tokio::test]
async fn test_multihop_terminate_stops_traversal() {
    /// Counts visits and terminates the search once `max_visits` is reached.
    #[derive(Debug)]
    struct TerminateAfterN {
        max_visits: usize,
        visits: Mutex<usize>,
    }
    impl TerminateAfterN {
        fn new(max_visits: usize) -> Self {
            Self {
                max_visits,
                visits: Mutex::new(0),
            }
        }
        fn visit_count(&self) -> usize {
            *self.visits.lock().unwrap()
        }
    }
    impl QueryLabelProvider<u32> for TerminateAfterN {
        // Every vertex may enter the result list.
        fn is_match(&self, _: u32) -> bool {
            true
        }
        // Accept until the visit budget is exhausted, then terminate.
        fn on_visit(&self, neighbor: Neighbor<u32>) -> QueryVisitDecision<u32> {
            let mut visits = self.visits.lock().unwrap();
            *visits += 1;
            if *visits >= self.max_visits {
                QueryVisitDecision::Terminate
            } else {
                QueryVisitDecision::Accept(neighbor)
            }
        }
    }
    // 5x5x5 grid plus one extra far-corner point.
    let dim = 3;
    let grid_size: usize = 5;
    let l = 10;
    let max_degree = 2 * dim;
    let num_points = (grid_size).pow(dim as u32);
    let (config, parameters) =
        simplified_builder(l, max_degree, Metric::L2, dim, num_points, no_modify).unwrap();
    let mut adjacency_lists = Grid::Three.neighbors(grid_size);
    let mut vectors = f32::generate_grid(dim, grid_size);
    adjacency_lists.push((num_points as u32 - 1).into());
    vectors.push(vec![grid_size as f32; dim]);
    let table = train_pq(
        squish(vectors.iter(), dim).as_view(),
        2.min(dim),
        &mut create_rnd_from_seed_in_tests(0x9876543210fedcba),
        1usize,
    )
    .unwrap();
    let index = new_quant_index::<f32, _, _>(config, parameters, table, NoDeletes).unwrap();
    let neighbor_accessor = &mut index.provider().neighbors();
    populate_data(&index.data_provider, &DefaultContext, &vectors).await;
    populate_graph(neighbor_accessor, &adjacency_lists).await;
    let query = vec![grid_size as f32; dim];
    let mut ids = vec![0; 10];
    let mut distances = vec![0.0; 10];
    let mut result_output_buffer =
        search_output_buffer::IdDistance::new(&mut ids, &mut distances);
    let max_visits = 5;
    let filter = TerminateAfterN::new(max_visits);
    // A generous search list (100) so that, without early termination, far
    // more than `max_visits` vertices would be visited.
    let search_params = Knn::new_default(10, 100).unwrap();
    let multihop = graph::search::MultihopSearch::new(search_params, &filter);
    let _stats = index
        .search(
            multihop,
            &FullPrecision,
            &DefaultContext,
            query.as_slice(),
            &mut result_output_buffer,
        )
        .await
        .unwrap();
    assert!(
        filter.visit_count() <= max_visits + 10,
        "search should have terminated early, got {} visits",
        filter.visit_count()
    );
}
/// Regression test: f32 values near `f16::MAX` overflow to +infinity when
/// converted to f16; inserting such vectors and then searching must not
/// panic, and searches must still return results.
#[tokio::test]
async fn vectors_with_infinity_values_should_be_inserted_and_searched_without_panic() {
    let l_build: usize = 20;
    let insert_count = l_build + 10;
    const VECTORS_DIMENSION: usize = 384;
    // Start close enough to f16::MAX that the last generated vectors
    // overflow to +inf on conversion (verified by the asserts below).
    let vector_value_start: f32 = half::f16::MAX.to_f32() - insert_count as f32 / 3.0;
    let (config, mut parameters) = simplified_builder(
        l_build,
        32,
        Metric::L2,
        VECTORS_DIMENSION,
        insert_count,
        |_| {},
    )
    .unwrap();
    parameters.frozen_points = NonZeroUsize::new(1).unwrap();
    let index = new_index::<half::f16, _>(config, parameters, NoDeletes).unwrap();
    // Constant-valued vectors with values increasing by 1.0 per vector.
    let vectors = (0..insert_count)
        .map(move |i| [half::f16::from_f32(vector_value_start + i as f32); VECTORS_DIMENSION])
        .collect::<Vec<_>>();
    // Guard the test premise: finite at the start, infinite at the end.
    assert_ne!(
        vectors[0][0],
        half::f16::INFINITY,
        "First vector should not have infinity value"
    );
    assert_eq!(
        vectors[vectors.len() - 1][0],
        half::f16::INFINITY,
        "Last vector should have infinity value"
    );
    for (i, vector) in vectors.iter().take(insert_count).enumerate() {
        let vector_id = i as u32;
        index
            .insert(FullPrecision, &DefaultContext, &vector_id, vector)
            .await
            .unwrap();
    }
    // Search with a single zero-valued query; must complete without panic
    // and return a non-empty result.
    let query_count: usize = 1;
    let queries: Vec<half::f16> = vec![half::f16::default(); query_count * VECTORS_DIMENSION];
    let top_k = l_build;
    let search_l = l_build;
    let mut ids = vec![0; top_k];
    let mut distances = vec![0.0; top_k];
    let ctx = DefaultContext;
    let search_params = graph::search::Knn::new_default(top_k, search_l).unwrap();
    for i in 0..query_count {
        let query_vector = &queries[i * VECTORS_DIMENSION..(i + 1) * VECTORS_DIMENSION];
        let mut result_output_buffer =
            search_output_buffer::IdDistance::new(&mut ids, &mut distances);
        let search_result = index
            .search(
                search_params,
                &FullPrecision,
                &ctx,
                query_vector,
                &mut result_output_buffer,
            )
            .await
            .unwrap();
        assert!(
            search_result.result_count > 0,
            "Expected non-empty result for query {}",
            i
        );
    }
}
}