use std::{iter, sync::Arc};
use diskann_vector::distance::Metric;
use crate::{
graph::{
self, AdjacencyList, ConsolidateKind, DiskANNIndex,
test::provider::{self as test_provider, Provider, Strategy},
},
provider::Delete,
provider::NeighborAccessor,
test::tokio::current_thread_runtime,
};
use super::helpers::{
assert_neighbors, create_2d_unit_square, generate_2d_square_adjacency_list, setup_2d_square,
};
/// Builds a small in-memory [`DiskANNIndex`] over `vectors` for consolidation tests.
///
/// Point `i` gets `vectors[i]` as its data and `adjacency_lists[i]` as its out-neighbors.
/// An extra start point is added with id `vectors.len()`, located at `[0.5; dim]`, and linked
/// to the first `provider_max_degree` (5) data points. The provider is configured with a max
/// degree of 5; the index is built with a pruned degree of 4 and `MaxDegree::same()`.
///
/// # Panics
///
/// Panics if `vectors` is empty, if `vectors` and `adjacency_lists` differ in length, if the
/// point count does not fit in a `u32` id, or if provider/index construction fails.
fn setup_consolidation_index(
    vectors: Vec<Vec<f32>>,
    adjacency_lists: Vec<AdjacencyList<u32>>,
) -> Arc<DiskANNIndex<Provider>> {
    // `vectors[0]` below needs at least one point; fail with a readable message instead of an
    // index-out-of-bounds panic.
    assert!(
        !vectors.is_empty(),
        "consolidation fixture needs at least one vector"
    );
    // `zip` below silently truncates to the shorter input, which would quietly drop points or
    // adjacency lists from a malformed fixture.
    assert_eq!(
        vectors.len(),
        adjacency_lists.len(),
        "every vector needs exactly one adjacency list"
    );

    let num_points = vectors.len();
    let dim = vectors[0].len();
    // The start point takes the first id past the data points.
    let start_id = u32::try_from(num_points).expect("point count must fit in a u32 id");
    let provider_max_degree = 5;
    let pruned_degree = 4;

    let provider_config = test_provider::Config::new(
        Metric::L2,
        provider_max_degree,
        test_provider::StartPoint::new(start_id, vec![0.5; dim]),
    )
    .unwrap();

    // Link the start point to (at most) the first `provider_max_degree` data points.
    let start_neighbors =
        AdjacencyList::from_iter_untrusted((0..start_id).take(provider_max_degree));

    // Every `id` is < `num_points`, which was checked to fit in a u32, so the cast is lossless.
    let points = vectors
        .into_iter()
        .zip(adjacency_lists)
        .enumerate()
        .map(|(id, (vec, adj))| (id as u32, vec, adj));

    let provider = Provider::new_from(
        provider_config,
        iter::once((start_id, start_neighbors)),
        points,
    )
    .unwrap();

    // NOTE(review): the `10` is presumably the build-time search list size — confirm against
    // `graph::config::Builder::new`.
    let index_config = graph::config::Builder::new(
        pruned_degree,
        graph::config::MaxDegree::same(),
        10,
        Metric::L2.into(),
    )
    .build()
    .unwrap();

    Arc::new(DiskANNIndex::new(index_config, provider, None))
}
/// Consolidating a vertex whose vector cannot be fetched must report
/// `ConsolidateKind::FailedVectorRetrieval` rather than returning an error.
#[test]
fn flaky_consolidate_returns_failed_retrieval() {
    let rt = current_thread_runtime();

    // Seven 2-D points: the unit-square corners followed by three farther-out points.
    let vectors: Vec<Vec<f32>> = [
        (0.0, 0.0),
        (0.0, 1.0),
        (1.0, 0.0),
        (1.0, 1.0),
        (2.0, 2.0),
        (0.0, 2.0),
        (2.0, 0.0),
    ]
    .iter()
    .map(|&(x, y)| vec![x, y])
    .collect();

    // Out-neighbors per point, indexed by point id.
    // NOTE(review): point 6 appears in no list here, and the start point only links to 0..=4,
    // so point 6 is unreachable — confirm this is intended.
    let adjacency_lists: Vec<AdjacencyList<u32>> = [
        &[1, 2, 3, 4, 5][..],
        &[0, 3, 4][..],
        &[0, 3, 4][..],
        &[1, 2, 4][..],
        &[0, 1, 2, 3][..],
        &[0, 1, 2][..],
        &[0, 1, 2, 3, 5][..],
    ]
    .iter()
    .map(|ids| AdjacencyList::from_iter_untrusted(ids.iter().copied()))
    .collect();

    let index = setup_consolidation_index(vectors, adjacency_lists);
    let ctx = test_provider::Context::new();

    // NOTE(review): `with_transient(true, [0])` presumably makes retrieval of vector 0 fail
    // with a transient error — confirm against `test::provider::Strategy`.
    let flaky_strategy = Strategy::with_transient(true, [0]);

    let outcome = rt.block_on(index.consolidate_vector(&flaky_strategy, &ctx, 0));
    assert_eq!(
        outcome.unwrap(),
        ConsolidateKind::FailedVectorRetrieval,
        "consolidate should handle transient errors gracefully"
    );
}
/// Consolidating a vertex that was deleted from the provider reports `Deleted`.
#[test]
fn consolidate_deleted_vertex_returns_deleted() {
    let rt = current_thread_runtime();
    let square_graph = generate_2d_square_adjacency_list();
    let index = setup_2d_square(create_2d_unit_square(), square_graph, 4);
    let ctx = test_provider::Context::new();
    let strategy = Strategy::new();

    // Remove vertex 3 first, then ask the index to consolidate that same vertex.
    let deleted_id: u32 = 3;
    rt.block_on(index.data_provider.delete(&ctx, &deleted_id))
        .unwrap();

    let kind = rt
        .block_on(index.consolidate_vector(&strategy, &ctx, deleted_id))
        .unwrap();
    assert_eq!(kind, ConsolidateKind::Deleted);
}
/// With an intact square graph and no deletions, consolidating node 0 reports `Complete`.
#[test]
fn consolidate_nothing_to_do_returns_complete() {
    let rt = current_thread_runtime();
    let square_graph = generate_2d_square_adjacency_list();
    let index = setup_2d_square(create_2d_unit_square(), square_graph, 4);
    let ctx = test_provider::Context::new();
    let strategy = Strategy::new();

    // Futures are lazy: nothing runs until `block_on` drives it.
    let consolidate = index.consolidate_vector(&strategy, &ctx, 0);
    assert_eq!(rt.block_on(consolidate).unwrap(), ConsolidateKind::Complete);
}
/// After deleting corner 3, consolidating every node reconnects the survivors to each other.
#[test]
fn consolidate_repairs_after_deletion() {
    let rt = current_thread_runtime();

    // Unit-square corners 0..=3 plus node 4; every corner links to node 4, and node 4 links
    // back to all four corners.
    let graph = vec![
        AdjacencyList::from_iter_untrusted([1, 2, 4]),
        AdjacencyList::from_iter_untrusted([0, 3, 4]),
        AdjacencyList::from_iter_untrusted([0, 3, 4]),
        AdjacencyList::from_iter_untrusted([1, 2, 4]),
        AdjacencyList::from_iter_untrusted([0, 1, 2, 3]),
    ];
    let index = setup_2d_square(create_2d_unit_square(), graph, 4);
    let ctx = test_provider::Context::new();
    let strategy = Strategy::new();

    rt.block_on(index.data_provider.delete(&ctx, &3)).unwrap();

    // Only the deleted node may report `Deleted`; no node may fail to fetch its vector.
    for id in 0..5u32 {
        let kind = rt
            .block_on(index.consolidate_vector(&strategy, &ctx, id))
            .unwrap();
        match id {
            3 => assert_eq!(kind, ConsolidateKind::Deleted),
            _ => assert_ne!(kind, ConsolidateKind::FailedVectorRetrieval),
        }
    }

    // Each surviving node should now point at exactly the other three survivors.
    for (node, expected) in [
        (0, [1, 2, 4]),
        (1, [0, 2, 4]),
        (2, [0, 1, 4]),
        (4, [0, 1, 2]),
    ] {
        assert_neighbors(&rt, &index, node, &expected);
    }
}
/// With no deletions, consolidating an over-full node only prunes it down to the pruned degree.
#[test]
fn consolidate_prune_only_no_deleted_neighbors() {
    // Single source of truth for the degree bound: it configures the index and is checked
    // after consolidation, so the setup argument and the assertion cannot drift apart.
    let pruned_degree = 2;

    let rt = current_thread_runtime();
    let adjacency_lists = generate_2d_square_adjacency_list();
    let index = setup_2d_square(create_2d_unit_square(), adjacency_lists, pruned_degree);
    let ctx = test_provider::Context::new();
    let strategy = Strategy::new();

    // Precondition: node 4 starts with four neighbors, more than `pruned_degree`.
    assert_neighbors(&rt, &index, 4, &[0, 1, 2, 3]);

    let result = rt
        .block_on(index.consolidate_vector(&strategy, &ctx, 4))
        .unwrap();
    assert_eq!(result, ConsolidateKind::Complete);

    let mut list = AdjacencyList::new();
    rt.block_on(index.provider().neighbors().get_neighbors(4, &mut list))
        .unwrap();
    assert!(
        list.len() <= pruned_degree,
        "start node should have degree <= pruned_degree ({}) after consolidation, got {}",
        pruned_degree,
        list.len()
    );
}