use std::sync::Arc;
use diskann_vector::distance::Metric;
use crate::{
graph::{
self, DiskANNIndex,
search::Knn,
test::{provider as test_provider, synthetic::Grid},
},
neighbor::Neighbor,
test::{
TestPath, TestRoot,
cmp::{assert_eq_verbose, verbose_eq},
get_or_save_test_results,
tokio::current_thread_runtime,
},
utils::IntoUsize,
};
/// Test root for the grid-search cases.
///
/// Each test passes `root().path()` to `_grid_search`, which pushes a case-specific name
/// onto it before loading or saving baseline results.
fn root() -> TestRoot {
    const CASES_DIR: &str = "graph/test/cases/grid_search";
    TestRoot::new(CASES_DIR)
}
/// One recorded k-NN search over a synthetic grid, as saved to and compared against the
/// test baseline by `_grid_search`.
///
/// NOTE: field order is significant for serialization and `Debug` output; do not reorder.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub(super) struct GridSearch {
    /// Human-readable statement of what this search is expected to demonstrate.
    pub(super) description: String,
    /// Query vector; its length equals the grid dimensionality.
    pub(super) query: Vec<f32>,
    /// Every one of the `k` neighbor slots handed to the search, converted with
    /// `Neighbor::as_tuple`. Slots presumably left unwritten keep `Neighbor::default()`.
    /// NOTE(review): presumably `(id, distance)` pairs — confirm against `Neighbor::as_tuple`.
    pub(super) results: Vec<(u32, f32)>,
    /// Comparison count reported by the search statistics (`cmps`).
    pub(super) comparisons: usize,
    /// Hop count reported by the search statistics.
    pub(super) hops: usize,
    /// Number of results reported by the search (`result_count`); asserted to be at most `k`.
    pub(super) num_results: usize,
    /// Dimensionality of the grid (`grid.dim()`).
    pub(super) grid_dims: usize,
    /// The `size` argument passed to `Provider::grid` when building the index.
    pub(super) grid_size: usize,
    /// Beam width used for this search.
    pub(super) beam_width: usize,
    /// Provider metrics captured after the search completed.
    pub(super) metrics: test_provider::Metrics,
}
// Field-wise equality for `GridSearch` over the listed fields. This is presumably what
// `assert_eq_verbose!` relies on when `_grid_search` compares fresh results against the
// saved baseline.
// NOTE(review): confirm in `crate::test::cmp` whether every struct field must be listed here.
verbose_eq!(GridSearch {
    query,
    description,
    results,
    comparisons,
    hops,
    num_results,
    grid_size,
    grid_dims,
    beam_width,
    metrics,
});
/// Builds a fresh `DiskANNIndex` over a synthetic `grid` test provider created with `size`.
///
/// The configuration has these properties:
///
/// * it uses the provider's own maximum degree together with `MaxDegree::same()`;
/// * it uses the L2 metric;
/// * it passes `100` as the third builder argument.
///
/// NOTE(review): the `100` is presumably the build-time search list size — confirm against
/// `graph::config::Builder::new`.
///
/// # Panics
///
/// Panics if the grid provider cannot be created or the configuration fails to build.
fn setup_grid_search(grid: Grid, size: usize) -> Arc<DiskANNIndex<test_provider::Provider>> {
    let provider = test_provider::Provider::grid(grid, size).unwrap();
    let max_degree = provider.max_degree();

    let config = graph::config::Builder::new(
        max_degree,
        graph::config::MaxDegree::same(),
        100,
        Metric::L2.into(),
    )
    .build()
    .unwrap();

    Arc::new(DiskANNIndex::new(config, provider, None))
}
/// Beam widths exercised for every query in `_grid_search`; each width yields its own
/// recorded `GridSearch` entry.
const BEAM_WIDTHS: [usize; 3] = [1, 2, 4];
/// Runs k-NN searches over a synthetic `grid` built with `size` and compares the collected
/// `GridSearch` records against the baseline named `search_{dim}_{size}` under `parent`.
///
/// Two queries are exercised, one of all `-1`s and one of all `size`s, each at every width
/// in `BEAM_WIDTHS`. Every run builds a fresh index, so the provider metrics captured for it
/// cover that single search only. Each run asserts that:
///
/// * at most `k` results are reported and the range-search second round never triggers;
/// * nothing is written to the provider (`set_vector`, `set_neighbors` and
///   `append_neighbors` are all 0);
/// * `get_neighbors` equals the reported hop count and `get_vector` equals the reported
///   comparison count;
/// * the search context is never spawned or cloned.
fn _grid_search(grid: Grid, size: usize, mut parent: TestPath<'_>) {
    let rt = current_thread_runtime();
    let dims: usize = grid.dim().into();

    // NOTE(review): the first line below ends in `to be\` with no trailing space, so the
    // rendered text reads "...to bethe closest...". It is kept verbatim because
    // `description` is one of the fields compared against the saved baseline; fixing the
    // typo requires regenerating the stored results at the same time.
    let description_0 = "With a query of all -1s, we expect the neighbor with all zeros to be\
        the closest. Due to how the grid is generated, this will be coordinate 0. \
        Next, there should be `dim` neighbors that are one further away. \
        Increasing the beam width should increase the number of comparisons.";
    let description_1 = "With a query of all `size`s, the start point is filtered by default \
        and should not appear in results.";

    let cases = [
        (vec![-1.0f32; dims], description_0),
        (vec![size as f32; dims], description_1),
    ];

    let mut results = Vec::new();
    for (query, desc) in cases {
        for beam_width in BEAM_WIDTHS {
            // Rebuild the index each time so the provider metrics reflect only this search.
            let index = setup_grid_search(grid, size);
            let params = Knn::new(10, 10, Some(beam_width)).unwrap();
            let context = test_provider::Context::new();
            let k = params.k_value().get();
            let mut neighbors = vec![Neighbor::<u32>::default(); k];

            let stats = rt
                .block_on(index.search(
                    params,
                    &test_provider::Strategy::new(),
                    &context,
                    query.as_slice(),
                    &mut crate::neighbor::BackInserter::new(neighbors.as_mut_slice()),
                ))
                .unwrap();
            // Exhaustive destructure: adding a new statistic forces this test to be revisited.
            let graph::index::SearchStats {
                cmps,
                hops,
                result_count,
                range_search_second_round,
            } = stats;

            let comparisons = cmps.into_usize();
            let hop_count = hops.into_usize();
            let num_results = result_count.into_usize();

            assert!(
                num_results <= k,
                "grid search should not return more than the requested number of neighbors",
            );
            assert!(
                !range_search_second_round,
                "range search should not activate for k-nearest-neighbors",
            );

            let metrics = index.provider().metrics();
            // A search must not write anything back to the provider.
            assert_eq!(metrics.set_vector, 0);
            assert_eq!(metrics.set_neighbors, 0);
            assert_eq!(metrics.append_neighbors, 0);
            assert_eq!(
                metrics.get_neighbors, hop_count,
                "recorded hops should have a one-to-one correspondence with `get_neighbors`",
            );
            assert_eq!(
                metrics.get_vector, comparisons,
                "recorded comparisons should have a one-to-one correspondence with `get_vector`",
            );

            let test_provider::ContextMetrics { spawns, clones } = context.metrics();
            assert_eq!(spawns, 0);
            assert_eq!(clones, 0);

            results.push(GridSearch {
                description: desc.to_string(),
                query: query.clone(),
                results: neighbors.into_iter().map(|n| n.as_tuple()).collect(),
                comparisons,
                hops: hop_count,
                num_results,
                grid_dims: dims,
                grid_size: size,
                beam_width,
                metrics,
            });
        }
    }

    let name = parent.push(format!("search_{}_{}", grid.dim(), size));
    let expected = get_or_save_test_results(&name, &results);
    assert_eq_verbose!(expected, results);
}
/// Grid search over a one-dimensional grid built with `size = 100`.
#[test]
fn grid_search_1_100() {
    let test_root = root();
    _grid_search(Grid::One, 100, test_root.path());
}
/// Grid search over a three-dimensional grid built with `size = 5`.
#[test]
fn grid_search_3_5() {
    let test_root = root();
    _grid_search(Grid::Three, 5, test_root.path());
}
/// Grid search over a four-dimensional grid built with `size = 4`.
#[test]
fn grid_search_4_4() {
    let test_root = root();
    _grid_search(Grid::Four, 4, test_root.path());
}