use std::{num::NonZeroUsize, sync::Arc};
use diskann_utils::views::MatrixView;
use diskann_vector::distance::Metric;
use crate::{
graph::{
self, DiskANNIndex,
config::IntraBatchCandidates,
search::Knn,
test::{provider as test_provider, synthetic::Grid},
},
neighbor::Neighbor,
test::{
TestPath, TestRoot,
cmp::{assert_eq_verbose, verbose_eq},
get_or_save_test_results,
tokio::current_thread_runtime,
},
utils::IntoUsize,
};
use super::{DUMP_GRAPH_STATE, grid_search::GridSearch};
/// Returns the test root under which this module's stored results live.
fn root() -> TestRoot {
    const CASE_DIR: &str = "graph/test/cases/grid_insert";
    TestRoot::new(CASE_DIR)
}
/// Creates a test provider for `grid` that uses the L2 metric.
///
/// The provider's maximum degree is twice the grid dimensionality, and its start point is
/// `grid.start_point(size)` registered under the id `u32::MAX`.
///
/// # Panics
///
/// Panics if `test_provider::Config::new` rejects the configuration.
fn empty_provider(grid: Grid, size: usize) -> test_provider::Provider {
    let start = test_provider::StartPoint::new(u32::MAX, grid.start_point(size));
    let degree_limit: usize = 2 * (grid.dim() as usize);
    let provider_config = test_provider::Config::new(Metric::L2, degree_limit, start).unwrap();
    test_provider::Provider::new(provider_config)
}
/// Wraps `provider` in a `DiskANNIndex` configured for the grid tests.
///
/// The target degree is two below the provider's maximum degree, but never below 2 and
/// never above the provider's maximum itself. `intra_batch_candidates` and
/// `max_minibatch_par` are forwarded to the config builder unchanged.
///
/// # Panics
///
/// Panics if the index configuration fails to build.
fn build_index(
    provider: test_provider::Provider,
    intra_batch_candidates: IntraBatchCandidates,
    max_minibatch_par: usize,
) -> Arc<DiskANNIndex<test_provider::Provider>> {
    let degree_limit = provider.max_degree();
    // `min(degree_limit)` keeps the target valid even when the provider allows fewer than
    // two neighbors, in which case the target equals the provider's limit.
    let pruned_degree = degree_limit.saturating_sub(2).max(2).min(degree_limit);

    let builder = graph::config::Builder::new_with(
        pruned_degree,
        graph::config::MaxDegree::new(degree_limit),
        // NOTE(review): presumably the build-time search list size; confirm against
        // `Builder::new_with`.
        100,
        Metric::L2.into(),
        |b| {
            b.intra_batch_candidates(intra_batch_candidates)
                .max_minibatch_par(max_minibatch_par);
        },
    );
    let config = builder.build().unwrap();

    Arc::new(DiskANNIndex::new(config, provider, None))
}
/// Inserts every row of `data` into `index` and returns the context used for the inserts.
///
/// * `batchsize == None`: rows are inserted one at a time through `index.insert`.
/// * `batchsize == Some(n)`: rows are inserted through `index.multi_insert` in consecutive
///   chunks of at most `n` rows; the final chunk may be shorter.
///
/// In both modes the id of a row is its position in `data`. `working_set_reuse` is
/// forwarded to `test_provider::Strategy::with_options`.
///
/// # Panics
///
/// Panics if an insert fails, if a chunk cannot be sliced out of `data`, or if a row
/// position does not fit in a `u32` id.
fn run_build(
    index: &Arc<DiskANNIndex<test_provider::Provider>>,
    data: MatrixView<'_, f32>,
    batchsize: Option<NonZeroUsize>,
    working_set_reuse: bool,
    runtime: &tokio::runtime::Runtime,
) -> test_provider::Context {
    // Row positions double as ids; fail loudly instead of wrapping on oversized inputs.
    let to_id = |row: usize| u32::try_from(row).expect("row index must fit in a u32 id");

    let strategy = test_provider::Strategy::with_options(working_set_reuse);
    let context = test_provider::Context::new();
    match batchsize {
        None => {
            for (row, vector) in data.row_iter().enumerate() {
                runtime
                    .block_on(index.insert(strategy.clone(), &context, &to_id(row), vector))
                    .unwrap();
            }
        }
        Some(batchsize) => {
            let total = data.nrows();
            let mut start = 0usize;
            while start < total {
                // Saturate so that a huge batch size clamps to `total` instead of overflowing.
                let stop = start.saturating_add(batchsize.get()).min(total);
                let batch = Arc::new(data.subview(start..stop).unwrap().to_owned());
                runtime
                    .block_on(index.multi_insert::<test_provider::Strategy, _>(
                        strategy.clone(),
                        &context,
                        batch,
                        (start..stop).map(to_id).collect(),
                    ))
                    .unwrap();
                start = stop;
            }
        }
    }
    context
}
/// Returns the provider's neighbor lists as `(id, neighbors)` pairs when
/// `DUMP_GRAPH_STATE` is enabled, and `None` otherwise.
fn maybe_dump_graph(index: &DiskANNIndex<test_provider::Provider>) -> Option<Vec<(u32, Vec<u32>)>> {
    if !DUMP_GRAPH_STATE {
        return None;
    }

    let adjacency: Vec<(u32, Vec<u32>)> = index
        .provider()
        .dump_neighbors(true)
        .into_iter()
        .map(|(id, list)| (id, list.into()))
        .collect();
    Some(adjacency)
}
/// Runs the two fixed k-nearest-neighbor queries against `index` and records one
/// `GridSearch` per query, in query order.
///
/// The queries are a vector of all `-1.0` and a vector of all `size`, each with
/// `grid.dim()` components. Each description is `description_prefix` followed by a
/// query-specific sentence. Provider metrics are sampled after each search and stored
/// with that search's record.
///
/// # Panics
///
/// Panics if a search fails or if it reports that the range-search second round ran.
fn run_searches(
    index: &DiskANNIndex<test_provider::Provider>,
    grid: Grid,
    size: usize,
    description_prefix: &str,
    runtime: &tokio::runtime::Runtime,
) -> Vec<GridSearch> {
    let dims: usize = grid.dim().into();
    let cases = [
        (
            vec![-1.0f32; dims],
            format!(
                "{} Search with query of all -1s. \
                 The nearest neighbor should be coordinate 0 (all zeros).",
                description_prefix,
            ),
        ),
        (
            vec![size as f32; dims],
            format!(
                "{} Search with query of all `size`. \
                 The start point should appear as it is not filtered by default.",
                description_prefix,
            ),
        ),
    ];

    cases
        .into_iter()
        .map(|(query, description)| {
            let knn = Knn::new(10, 10, None).unwrap();
            let context = test_provider::Context::new();
            // Pre-size the output buffer to `k` default entries for the back inserter.
            let mut found = vec![Neighbor::<u32>::default(); knn.k_value().get()];

            let stats = runtime
                .block_on(index.search(
                    knn,
                    &test_provider::Strategy::new(),
                    &context,
                    &*query,
                    &mut crate::neighbor::BackInserter::new(found.as_mut_slice()),
                ))
                .unwrap();
            assert!(
                !stats.range_search_second_round,
                "range search should not activate for k-nearest-neighbors",
            );

            let metrics = index.provider().metrics();
            GridSearch {
                query,
                description,
                results: found.into_iter().map(|i| i.as_tuple()).collect(),
                comparisons: stats.cmps.into_usize(),
                hops: stats.hops.into_usize(),
                num_results: stats.result_count.into_usize(),
                grid_dims: dims,
                grid_size: size,
                beam_width: knn.beam_width().get(),
                metrics,
            }
        })
        .collect()
}
/// Snapshot of one grid-insert run that is serialized and compared against stored results.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct GridInsertBaseline {
/// Human-readable summary of the insert mode and grid that produced this baseline.
description: String,
/// Dimensionality of the grid (`grid.dim()`).
grid_dims: usize,
/// Grid size parameter the points were generated with.
grid_size: usize,
/// Number of points inserted (`grid.num_points(size)`).
num_inserted: usize,
/// Provider metrics sampled right after the build, before any search ran.
insert_metrics: test_provider::Metrics,
/// Metrics reported by the context that was used for the inserts.
context_metrics: test_provider::ContextMetrics,
/// One record per query issued by `run_searches`.
searches: Vec<GridSearch>,
/// Optional `(id, neighbors)` dump of the graph. Left out of the serialized form when
/// `None` and defaulted when absent on deserialization.
#[serde(skip_serializing_if = "Option::is_none", default)]
graph_state: Option<Vec<(u32, Vec<u32>)>>,
}
// Registers every field of `GridInsertBaseline` with the crate's `verbose_eq!` macro.
// NOTE(review): presumably this generates the field-by-field comparison used by
// `assert_eq_verbose!`; keep this list in sync with the struct's fields.
verbose_eq!(GridInsertBaseline {
description,
grid_dims,
grid_size,
num_inserted,
insert_metrics,
context_metrics,
searches,
graph_state,
});
/// Formats `ibc` for human-readable descriptions: `ibc=none`, `ibc=max(<n>)` or `ibc=all`.
fn ibc_label(ibc: IntraBatchCandidates) -> String {
    let setting = match ibc {
        IntraBatchCandidates::None => String::from("none"),
        IntraBatchCandidates::Max(n) => format!("max({})", n),
        IntraBatchCandidates::All => String::from("all"),
    };
    format!("ibc={}", setting)
}
/// Builds the name under which a run's baseline is stored.
///
/// The result has the form `insert_<dim>_<size>_<mode>/<candidates>`, where `<mode>` is
/// `single` or `batch_<n>` and `<candidates>` is `ibc_none`, `ibc_max_<n>` or `ibc_all`.
fn baseline_name(
    grid: Grid,
    size: usize,
    batchsize: Option<NonZeroUsize>,
    ibc: IntraBatchCandidates,
) -> String {
    let mode = batchsize.map_or_else(|| String::from("single"), |bs| format!("batch_{}", bs));
    let candidates = match ibc {
        IntraBatchCandidates::None => String::from("ibc_none"),
        IntraBatchCandidates::Max(n) => format!("ibc_max_{}", n),
        IntraBatchCandidates::All => String::from("ibc_all"),
    };
    format!("insert_{}_{}_{}/{}", grid.dim(), size, mode, candidates)
}
/// Parameters for a single `_grid_build_and_search` run.
struct TestParams {
/// Grid shape that supplies the points and the start point.
grid: Grid,
/// Size parameter passed to the grid's data, point-count and start-point helpers.
size: usize,
/// `None` inserts points one at a time; `Some(n)` inserts them in chunks of at most `n`.
batchsize: Option<NonZeroUsize>,
/// Intra-batch candidate setting forwarded to the index configuration.
intra_batch_candidates: IntraBatchCandidates,
/// Value forwarded to the config builder's `max_minibatch_par`.
max_minibatch_par: usize,
}
/// Builds a grid index, checks it against the stored baseline, and cross-checks
/// working-set reuse.
///
/// Steps:
/// 1. Insert every point of the grid (one-by-one or batched, per `batchsize`) with
///    working-set reuse disabled, then check the provider's consistency.
/// 2. Run the fixed queries from `run_searches`, assemble a `GridInsertBaseline`, and
///    compare it with the value returned by `get_or_save_test_results` for the name
///    derived from `baseline_name` under `parent`.
/// 3. Rebuild from scratch with working-set reuse enabled and assert that the neighbor
///    dump is identical and that the build issued no more `get_vector` calls than the
///    build without reuse.
///
/// # Panics
///
/// Panics if a build or search fails, if the provider is inconsistent, or if any of the
/// comparisons above does not hold.
fn _grid_build_and_search(params: TestParams, mut parent: TestPath<'_>) {
    let rt = current_thread_runtime();
    let TestParams {
        grid,
        size,
        batchsize,
        intra_batch_candidates,
        max_minibatch_par,
    } = params;
    let num_points = grid.num_points(size);
    let grid_data = grid.data(size);

    let index = build_index(
        empty_provider(grid, size),
        intra_batch_candidates,
        max_minibatch_par,
    );
    let insert_context = run_build(&index, grid_data.as_view(), batchsize, false, &rt);
    let insert_metrics = index.provider().metrics();
    // Keep the build-only counter: the searches below run against the same provider, so
    // any `get_vector` calls they make must not count toward the reuse comparison.
    let get_vector_without_reuse = insert_metrics.get_vector;
    index.provider().is_consistent().unwrap();
    let graph_state = maybe_dump_graph(&index);

    let mode_desc = match batchsize {
        None => "one-by-one".to_string(),
        Some(bs) if bs.get() >= num_points => "batch (all-at-once)".to_string(),
        Some(bs) => format!("batch (chunks of {})", bs),
    };
    let description_prefix = format!(
        "After inserting {} points ({}, {}) into a {}D grid of size {}.",
        num_points,
        mode_desc,
        ibc_label(intra_batch_candidates),
        grid.dim(),
        size,
    );
    let searches = run_searches(&index, grid, size, &description_prefix, &rt);

    let baseline = GridInsertBaseline {
        description: description_prefix,
        grid_dims: grid.dim().into(),
        grid_size: size,
        num_inserted: num_points,
        insert_metrics,
        context_metrics: insert_context.metrics(),
        searches,
        graph_state,
    };
    let name = parent.push(baseline_name(grid, size, batchsize, intra_batch_candidates));
    let expected = get_or_save_test_results(&name, &baseline);
    assert_eq_verbose!(expected, baseline);

    // Working-set reuse must not change the resulting graph.
    let reuse_index = build_index(
        empty_provider(grid, size),
        intra_batch_candidates,
        max_minibatch_par,
    );
    let _ = run_build(&reuse_index, grid_data.as_view(), batchsize, true, &rt);
    assert_eq_verbose!(
        index.provider().dump_neighbors(true),
        reuse_index.provider().dump_neighbors(true),
    );

    // Compare build against build. No search ran on `reuse_index`, so its counter is
    // checked against the snapshot taken before the searches on `index`.
    let get_vector_with_reuse = reuse_index.provider().metrics().get_vector;
    assert!(
        get_vector_with_reuse <= get_vector_without_reuse,
        "with reuse: {}, without reuse: {}",
        get_vector_with_reuse,
        get_vector_without_reuse,
    );
}
/// Asserts that a batched build gives the same provider metrics and search results on a
/// current-thread runtime as on a multi-thread runtime with two worker threads.
///
/// # Panics
///
/// Panics if the multi-thread runtime cannot be built, if a build or search fails, or if
/// the metrics or search results differ between the two runtimes.
fn _assert_thread_invariant(
    grid: Grid,
    size: usize,
    batchsize: NonZeroUsize,
    intra_batch_candidates: IntraBatchCandidates,
    max_minibatch_par: usize,
) {
    let rt_st = current_thread_runtime();
    let grid_data = grid.data(size);

    // Builds a fresh index on `rt` and snapshots the provider metrics right after the build.
    let build_on = |rt: &tokio::runtime::Runtime| {
        let index = build_index(
            empty_provider(grid, size),
            intra_batch_candidates,
            max_minibatch_par,
        );
        run_build(&index, grid_data.as_view(), Some(batchsize), false, rt);
        let metrics = index.provider().metrics();
        (index, metrics)
    };

    let (index_st, metrics_st) = build_on(&rt_st);

    let rt_mt = tokio::runtime::Builder::new_multi_thread()
        .worker_threads(2)
        .build()
        .expect("multi-thread runtime should build");
    let (index_mt, metrics_mt) = build_on(&rt_mt);

    assert_eq_verbose!(metrics_st, metrics_mt);

    let prefix = "Thread invariance check.";
    let searches_st = run_searches(&index_st, grid, size, prefix, &rt_st);
    let searches_mt = run_searches(&index_mt, grid, size, prefix, &rt_mt);
    assert_eq_verbose!(searches_st, searches_mt);
}
/// One-by-one inserts into a 1-D grid of size 100.
#[test]
fn single_1d_100() {
    let params = TestParams {
        grid: Grid::One,
        size: 100,
        batchsize: None,
        intra_batch_candidates: IntraBatchCandidates::None,
        max_minibatch_par: 1,
    };
    _grid_build_and_search(params, root().path());
}
/// One-by-one inserts into a 3-D grid of size 5.
#[test]
fn single_3d_5() {
    let params = TestParams {
        grid: Grid::Three,
        size: 5,
        batchsize: None,
        intra_batch_candidates: IntraBatchCandidates::None,
        max_minibatch_par: 1,
    };
    _grid_build_and_search(params, root().path());
}
/// One-by-one inserts into a 4-D grid of size 4.
#[test]
fn single_4d_4() {
    let params = TestParams {
        grid: Grid::Four,
        size: 4,
        batchsize: None,
        intra_batch_candidates: IntraBatchCandidates::None,
        max_minibatch_par: 1,
    };
    _grid_build_and_search(params, root().path());
}
/// Returns a batch size equal to the number of points in the grid, so that the whole grid
/// is inserted as a single batch.
///
/// # Panics
///
/// Panics if the grid contains no points.
fn all_at_once(grid: Grid, size: usize) -> NonZeroUsize {
    let point_count = grid.num_points(size);
    NonZeroUsize::new(point_count).unwrap()
}
/// Single-batch insert of a 1-D grid of size 100 without intra-batch candidates.
#[test]
fn batch_all_ibc_none_1d_100() {
    let grid = Grid::One;
    let size = 100;
    let params = TestParams {
        grid,
        size,
        batchsize: Some(all_at_once(grid, size)),
        intra_batch_candidates: IntraBatchCandidates::None,
        max_minibatch_par: 2,
    };
    _grid_build_and_search(params, root().path());
}
/// Single-batch insert of a 3-D grid of size 5 without intra-batch candidates.
#[test]
fn batch_all_ibc_none_3d_5() {
    let grid = Grid::Three;
    let size = 5;
    let params = TestParams {
        grid,
        size,
        batchsize: Some(all_at_once(grid, size)),
        intra_batch_candidates: IntraBatchCandidates::None,
        max_minibatch_par: 2,
    };
    _grid_build_and_search(params, root().path());
}
/// Single-batch insert of a 4-D grid of size 4 without intra-batch candidates.
#[test]
fn batch_all_ibc_none_4d_4() {
    let grid = Grid::Four;
    let size = 4;
    let params = TestParams {
        grid,
        size,
        batchsize: Some(all_at_once(grid, size)),
        intra_batch_candidates: IntraBatchCandidates::None,
        max_minibatch_par: 2,
    };
    _grid_build_and_search(params, root().path());
}
/// Single-batch insert of a 3-D grid of size 5 with `IntraBatchCandidates::new(4)`.
#[test]
fn batch_all_ibc_4_3d_5() {
    let grid = Grid::Three;
    let size = 5;
    let params = TestParams {
        grid,
        size,
        batchsize: Some(all_at_once(grid, size)),
        intra_batch_candidates: IntraBatchCandidates::new(4),
        max_minibatch_par: 2,
    };
    _grid_build_and_search(params, root().path());
}
/// Single-batch insert of a 4-D grid of size 4 with `IntraBatchCandidates::new(4)`.
#[test]
fn batch_all_ibc_4_4d_4() {
    let grid = Grid::Four;
    let size = 4;
    let params = TestParams {
        grid,
        size,
        batchsize: Some(all_at_once(grid, size)),
        intra_batch_candidates: IntraBatchCandidates::new(4),
        max_minibatch_par: 2,
    };
    _grid_build_and_search(params, root().path());
}
/// Single-batch insert of a 1-D grid of size 100 with `IntraBatchCandidates::All`.
#[test]
fn batch_all_ibc_all_1d_100() {
    let grid = Grid::One;
    let size = 100;
    let params = TestParams {
        grid,
        size,
        batchsize: Some(all_at_once(grid, size)),
        intra_batch_candidates: IntraBatchCandidates::All,
        max_minibatch_par: 2,
    };
    _grid_build_and_search(params, root().path());
}
/// Single-batch insert of a 3-D grid of size 5 with `IntraBatchCandidates::All`.
#[test]
fn batch_all_ibc_all_3d_5() {
    let grid = Grid::Three;
    let size = 5;
    let params = TestParams {
        grid,
        size,
        batchsize: Some(all_at_once(grid, size)),
        intra_batch_candidates: IntraBatchCandidates::All,
        max_minibatch_par: 2,
    };
    _grid_build_and_search(params, root().path());
}
/// Single-batch insert of a 4-D grid of size 4 with `IntraBatchCandidates::All`.
#[test]
fn batch_all_ibc_all_4d_4() {
    let grid = Grid::Four;
    let size = 4;
    let params = TestParams {
        grid,
        size,
        batchsize: Some(all_at_once(grid, size)),
        intra_batch_candidates: IntraBatchCandidates::All,
        max_minibatch_par: 2,
    };
    _grid_build_and_search(params, root().path());
}
/// Inserts a 3-D grid of size 5 in chunks of 25 without intra-batch candidates.
#[test]
fn batch_25_ibc_none_3d_5() {
    let params = TestParams {
        grid: Grid::Three,
        size: 5,
        batchsize: NonZeroUsize::new(25),
        intra_batch_candidates: IntraBatchCandidates::None,
        max_minibatch_par: 2,
    };
    _grid_build_and_search(params, root().path());
}
/// Inserts a 4-D grid of size 4 in chunks of 25 without intra-batch candidates.
#[test]
fn batch_25_ibc_none_4d_4() {
    let params = TestParams {
        grid: Grid::Four,
        size: 4,
        batchsize: NonZeroUsize::new(25),
        intra_batch_candidates: IntraBatchCandidates::None,
        max_minibatch_par: 2,
    };
    _grid_build_and_search(params, root().path());
}
/// Inserts a 3-D grid of size 5 in chunks of 25 with `IntraBatchCandidates::All`.
#[test]
fn batch_25_ibc_all_3d_5() {
    let params = TestParams {
        grid: Grid::Three,
        size: 5,
        batchsize: NonZeroUsize::new(25),
        intra_batch_candidates: IntraBatchCandidates::All,
        max_minibatch_par: 2,
    };
    _grid_build_and_search(params, root().path());
}
/// Inserts a 4-D grid of size 4 in chunks of 25 with `IntraBatchCandidates::All`.
#[test]
fn batch_25_ibc_all_4d_4() {
    let params = TestParams {
        grid: Grid::Four,
        size: 4,
        batchsize: NonZeroUsize::new(25),
        intra_batch_candidates: IntraBatchCandidates::All,
        max_minibatch_par: 2,
    };
    _grid_build_and_search(params, root().path());
}
/// Thread-invariance check for a single-batch build of a 3-D grid of size 5 with
/// `IntraBatchCandidates::All`.
#[test]
fn thread_invariant_batch_all_ibc_all_3d_5() {
    let grid = Grid::Three;
    let size = 5;
    let batchsize = all_at_once(grid, size);
    _assert_thread_invariant(grid, size, batchsize, IntraBatchCandidates::All, 2);
}
/// Thread-invariance check for a 4-D grid of size 4 built in chunks of 25 without
/// intra-batch candidates.
#[test]
fn thread_invariant_batch_25_ibc_none_4d_4() {
    let batchsize = NonZeroUsize::new(25).unwrap();
    _assert_thread_invariant(Grid::Four, 4, batchsize, IntraBatchCandidates::None, 2);
}