use std::{iter, sync::Arc};
use diskann_utils::views::Matrix;
use diskann_vector::distance::Metric;
use crate::{
graph::{
self, AdjacencyList, DiskANNIndex,
test::{
provider::{self as test_provider, Provider, StartPoint},
synthetic::Grid,
},
},
provider::NeighborAccessor,
};
/// Returns the vectors of the two-dimensional synthetic grid produced by
/// `Grid::Two.data(2)`.
///
/// Judging by this helper's name these are the corners of the unit square —
/// TODO confirm point count and row ordering against `Grid::data`.
pub(super) fn create_2d_unit_square() -> Matrix<f32> {
    let grid = Grid::Two;
    grid.data(2)
}
/// Builds a [`DiskANNIndex`] over `vectors` using a caller-supplied graph.
///
/// * `vectors` — one row per data point; row `i` is inserted with id `i`.
/// * `adjacency_lists` — entry `i` (for `i < vectors.nrows()`) becomes the
///   neighbor list of id `i`. If an entry exists at index `vectors.nrows()`,
///   it becomes the start point's neighbor list; otherwise the start point is
///   linked to every data point. Entries past that index are not inserted,
///   but their lengths still count towards the provider's maximum degree.
/// * `pruned_degree` — the index's pruned degree (built with
///   `MaxDegree::same()`).
///
/// The start point uses id `vectors.nrows()` and the vector `[0.5; dim]`.
/// The provider's maximum degree is the largest of `pruned_degree` and the
/// length of every supplied (or defaulted) neighbor list.
///
/// # Panics
///
/// Panics if fewer adjacency lists than vectors are supplied, if
/// `vectors.nrows()` does not fit in a `u32`, or if building the provider
/// configuration, the provider, or the index configuration fails.
pub(super) fn setup_2d_square(
    vectors: Matrix<f32>,
    adjacency_lists: Vec<AdjacencyList<u32>>,
    pruned_degree: usize,
) -> Arc<DiskANNIndex<Provider>> {
    let num_points = vectors.nrows();
    let dim = vectors.ncols();
    assert!(
        adjacency_lists.len() >= num_points,
        "need at least one adjacency list per vector, got {} lists for {} vectors",
        adjacency_lists.len(),
        num_points,
    );

    // The start point takes the first id after the data points. A checked
    // conversion keeps an oversized input from silently wrapping onto an
    // existing id, as a bare `as u32` cast would.
    let start_id = u32::try_from(num_points).expect("number of vectors must fit in a u32 id");
    let start_adj = adjacency_lists
        .get(num_points)
        .cloned()
        .unwrap_or_else(|| AdjacencyList::from_iter_untrusted(0..start_id));

    // Largest of `pruned_degree`, the start list, and every supplied list, so
    // none of them exceeds the configured provider degree.
    let provider_max_degree = adjacency_lists
        .iter()
        .map(|list| list.len())
        .fold(pruned_degree.max(start_adj.len()), usize::max);

    let provider_config = test_provider::Config::new(
        Metric::L2,
        provider_max_degree,
        test_provider::StartPoint::new(start_id, vec![0.5; dim]),
    )
    .unwrap();

    // Ids follow row order; `take` drops the start-point list (and any extras)
    // so exactly one list is paired with each data vector.
    let points = (0u32..)
        .zip(vectors.row_iter().zip(adjacency_lists.into_iter().take(num_points)))
        .map(|(id, (row, adj))| (id, row.to_vec(), adj));
    let provider =
        Provider::new_from(provider_config, iter::once((start_id, start_adj)), points).unwrap();

    // `10` is presumably the build-time search-list size — TODO confirm
    // against `graph::config::Builder::new`.
    let index_config = graph::config::Builder::new(
        pruned_degree,
        graph::config::MaxDegree::same(),
        10,
        Metric::L2.into(),
    )
    .build()
    .unwrap();
    Arc::new(DiskANNIndex::new(index_config, provider, None))
}
/// Builds a [`DiskANNIndex`] over the synthetic 2-D grid from [`Grid::Two`].
///
/// `size` and `start_id` are forwarded to `Grid::setup`. That setup supplies:
///
/// * the start point's id and vector;
/// * the start point's neighbors;
/// * the data points handed to the provider.
///
/// `pruned_degree` is the index's pruned degree (built with
/// `MaxDegree::same()`).
///
/// The provider's maximum degree is `2 * dim`, raised to `pruned_degree`
/// when that is larger. This mirrors `setup_2d_square`, which also never
/// configures the provider below `pruned_degree`.
///
/// # Panics
///
/// Panics if building the provider configuration, the provider, or the index
/// configuration fails.
pub(super) fn setup_2d_square_using_synthetics_grid(
    size: usize,
    start_id: u32,
    pruned_degree: usize,
) -> Arc<DiskANNIndex<Provider>> {
    let grid = Grid::Two;
    let setup = grid.setup(size, start_id);

    // `2 * dim` is presumably the largest neighborhood a grid node has (one
    // step either way along each axis) — TODO confirm against `Grid::setup`.
    // Without the `max`, a `pruned_degree` above `2 * dim` would leave the
    // provider configured below the index's own pruned degree.
    let grid_degree: usize = (grid.dim() * 2).into();
    let provider_config = test_provider::Config::new(
        Metric::L2,
        grid_degree.max(pruned_degree),
        StartPoint::new(setup.start_id(), setup.start_point()),
    )
    .unwrap();
    let provider =
        Provider::new_from(provider_config, setup.start_neighbors(), setup.setup()).unwrap();

    // `10` is presumably the build-time search-list size — TODO confirm
    // against `graph::config::Builder::new`.
    let index_config = graph::config::Builder::new(
        pruned_degree,
        graph::config::MaxDegree::same(),
        10,
        Metric::L2.into(),
    )
    .build()
    .unwrap();
    Arc::new(DiskANNIndex::new(index_config, provider, None))
}
/// Returns a hand-built five-entry graph, intended by its name for a
/// four-point square dataset.
///
/// * Entries 0 and 1 link to each other; entries 2 and 3 link to each other.
/// * Entries 0–3 each also link to id 4.
/// * Entry 4 links to ids 0–3.
///
/// Passed to `setup_2d_square` together with four vectors, entry 4 becomes
/// the start point's neighbor list, since that function gives the start
/// point id `vectors.nrows()`.
pub(super) fn generate_2d_square_adjacency_list() -> Vec<AdjacencyList<u32>> {
    let neighbor_table: [&[u32]; 5] = [&[1, 4], &[0, 4], &[3, 4], &[2, 4], &[0, 1, 2, 3]];
    neighbor_table
        .iter()
        .map(|ids| AdjacencyList::from_iter_untrusted(ids.iter().copied()))
        .collect()
}
/// Asserts that node `id` in `index`'s provider stores exactly `expected`.
///
/// The stored neighbor list is fetched by blocking on `rt`, then sorted
/// before the comparison. `expected` must therefore be given in ascending
/// order.
///
/// # Panics
///
/// Panics if fetching the neighbor list fails or if the sorted list differs
/// from `expected`.
pub(super) fn assert_neighbors(
    rt: &tokio::runtime::Runtime,
    index: &DiskANNIndex<Provider>,
    id: u32,
    expected: &[u32],
) {
    let actual = {
        let mut neighbors = AdjacencyList::new();
        rt.block_on(
            index
                .provider()
                .neighbors()
                .get_neighbors(id, &mut neighbors),
        )
        .expect("get_neighbors failed");
        neighbors.sort();
        neighbors
    };
    assert_eq!(&*actual, expected, "neighbors of node {id}");
}