use std::{num::NonZeroUsize, sync::Arc};
use diskann::{
ANNError, ANNResult,
graph::{
self, ConsolidateKind, InplaceDeleteMethod,
glue::{
Batch, DefaultSearchStrategy, InplaceDeleteStrategy, InsertStrategy,
MultiInsertStrategy, PruneStrategy, SearchStrategy,
},
index::{DegreeStats, PagedSearchState, PartitionedNeighbors, SearchState},
search_output_buffer,
},
neighbor::Neighbor,
provider::{AsNeighbor, AsNeighborMut, DataProvider, Delete, SetElement},
utils::ONE,
};
use crate::storage::{LoadWith, StorageReadProvider};
/// Blocking facade over the asynchronous [`graph::DiskANNIndex`].
///
/// Each public method drives the corresponding async operation on `inner`
/// to completion with `handle.block_on(..)`, so callers never need an
/// async context of their own.
pub struct DiskANNIndex<DP: DataProvider> {
/// The underlying asynchronous index implementation.
pub inner: Arc<graph::DiskANNIndex<DP>>,
/// Owned runtime kept alive for the index's lifetime when constructed via
/// the `*_runtime` constructors; `None` when an external handle was supplied.
_runtime: Option<tokio::runtime::Runtime>,
/// Handle the blocking wrappers use to run futures (cloned from `_runtime`
/// when one is owned, otherwise provided by the caller).
handle: tokio::runtime::Handle,
}
/// Build an owned multi-threaded tokio runtime together with a handle to it.
///
/// NOTE(review): the builder does not call `enable_all()`, so the runtime has
/// no I/O or timer drivers — presumably the index only needs task execution;
/// confirm if providers ever use `tokio::time`/`tokio::net`.
fn create_multi_thread_runtime() -> (tokio::runtime::Runtime, tokio::runtime::Handle) {
    // Runtime construction failure is treated as unrecoverable here.
    #[allow(clippy::expect_used)]
    let runtime = tokio::runtime::Builder::new_multi_thread()
        .build()
        .expect("failed to create tokio runtime");
    let runtime_handle = runtime.handle().clone();
    (runtime, runtime_handle)
}
/// Build an owned single-threaded (current-thread) tokio runtime together
/// with a handle to it.
///
/// NOTE(review): like the multi-thread variant, no `enable_all()` — the
/// runtime carries no I/O or timer drivers; confirm none are required.
fn create_current_thread_runtime() -> (tokio::runtime::Runtime, tokio::runtime::Handle) {
    // Runtime construction failure is treated as unrecoverable here.
    #[allow(clippy::expect_used)]
    let runtime = tokio::runtime::Builder::new_current_thread()
        .build()
        .expect("failed to create tokio runtime");
    let runtime_handle = runtime.handle().clone();
    (runtime, runtime_handle)
}
impl<DP> DiskANNIndex<DP>
where
DP: DataProvider,
{
/// Create an index that owns a fresh multi-threaded tokio runtime.
/// Uses a thread hint of `ONE` for the inner index.
pub fn new_with_multi_thread_runtime(config: graph::Config, data_provider: DP) -> Self {
let (rt, handle) = create_multi_thread_runtime();
Self::new_internal(config, data_provider, Some(rt), handle, Some(ONE))
}
/// Create an index that owns a fresh current-thread tokio runtime.
/// Uses a thread hint of `ONE` for the inner index.
pub fn new_with_current_thread_runtime(config: graph::Config, data_provider: DP) -> Self {
let (rt, handle) = create_current_thread_runtime();
Self::new_internal(config, data_provider, Some(rt), handle, Some(ONE))
}
/// Create an index that runs on an externally owned runtime via `handle`.
///
/// No runtime is stored, so the runtime behind `handle` must outlive this
/// index. `thread_hint` is forwarded to the inner index unchanged.
pub fn new_with_handle(
config: graph::Config,
data_provider: DP,
handle: tokio::runtime::Handle,
thread_hint: Option<NonZeroUsize>,
) -> Self {
Self::new_internal(config, data_provider, None, handle, thread_hint)
}
/// Shared constructor: wraps a new inner index in an `Arc` and records the
/// (optional) owned runtime plus the handle used by the blocking wrappers.
fn new_internal(
config: graph::Config,
data_provider: DP,
runtime: Option<tokio::runtime::Runtime>,
handle: tokio::runtime::Handle,
thread_hint: Option<NonZeroUsize>,
) -> Self {
let inner = Arc::new(graph::DiskANNIndex::new(config, data_provider, thread_hint));
Self {
inner,
_runtime: runtime,
handle,
}
}
/// Run an arbitrary future built from the inner index to completion on this
/// index's runtime, blocking the current thread until it resolves.
///
/// Escape hatch for async operations that have no dedicated wrapper below.
pub fn run<F, Fut, R>(&self, f: F) -> R
where
F: FnOnce(&Arc<graph::DiskANNIndex<DP>>) -> Fut,
Fut: core::future::Future<Output = R>,
{
self.handle.block_on(f(&self.inner))
}
/// Load an index from storage, blocking on a freshly created multi-threaded
/// runtime which the returned index then owns.
pub fn load_with_multi_thread_runtime<T, P>(provider: &P, auxiliary: &T) -> ANNResult<Self>
where
graph::DiskANNIndex<DP>: LoadWith<T, Error = ANNError>,
P: StorageReadProvider,
{
let (rt, handle) = create_multi_thread_runtime();
let inner = handle.block_on(graph::DiskANNIndex::<DP>::load_with(provider, auxiliary))?;
Ok(Self {
inner: Arc::new(inner),
_runtime: Some(rt),
handle,
})
}
/// Load an index from storage, blocking on a freshly created current-thread
/// runtime which the returned index then owns.
pub fn load_with_current_thread_runtime<T, P>(provider: &P, auxiliary: &T) -> ANNResult<Self>
where
graph::DiskANNIndex<DP>: LoadWith<T, Error = ANNError>,
P: StorageReadProvider,
{
let (rt, handle) = create_current_thread_runtime();
let inner = handle.block_on(graph::DiskANNIndex::<DP>::load_with(provider, auxiliary))?;
Ok(Self {
inner: Arc::new(inner),
_runtime: Some(rt),
handle,
})
}
/// Load an index from storage on an externally owned runtime via `handle`.
/// The runtime behind `handle` must outlive the returned index (none is stored).
pub fn load_with_handle<T, P>(
provider: &P,
auxiliary: &T,
handle: tokio::runtime::Handle,
) -> ANNResult<Self>
where
graph::DiskANNIndex<DP>: LoadWith<T, Error = ANNError>,
P: StorageReadProvider,
{
let inner = handle.block_on(graph::DiskANNIndex::<DP>::load_with(provider, auxiliary))?;
Ok(Self {
inner: Arc::new(inner),
_runtime: None,
handle,
})
}
/// Blocking wrapper over `inner.insert`: insert a single vector under the
/// given external `id`.
pub fn insert<S, T>(
&self,
strategy: S,
context: &DP::Context,
id: &DP::ExternalId,
vector: T,
) -> ANNResult<()>
where
S: InsertStrategy<DP, T>,
DP: SetElement<T>,
T: Copy + Send,
{
self.handle
.block_on(self.inner.insert(strategy, context, id, vector))
}
/// Blocking wrapper over `inner.multi_insert`: insert a batch of vectors
/// paired positionally with `ids`.
pub fn multi_insert<S, B>(
&self,
strategy: S,
context: &DP::Context,
vectors: Arc<B>,
ids: Arc<[DP::ExternalId]>,
) -> ANNResult<()>
where
Self: 'static,
S: MultiInsertStrategy<DP, B>,
B: Batch,
DP: for<'a> SetElement<B::Element<'a>>,
{
self.handle
.block_on(self.inner.multi_insert(strategy, context, vectors, ids))
}
/// Blocking wrapper over `inner.is_any_neighbor_deleted`: report whether any
/// neighbor of `vector_id` is marked deleted.
pub fn is_any_neighbor_deleted<NA>(
&self,
context: &DP::Context,
accessor: &mut NA,
vector_id: DP::InternalId,
) -> ANNResult<bool>
where
DP: Delete,
NA: AsNeighbor<Id = DP::InternalId>,
{
self.handle.block_on(
self.inner
.is_any_neighbor_deleted(context, accessor, vector_id),
)
}
/// Blocking wrapper over `inner.drop_adj_list`: drop the adjacency list of
/// `vector_id`.
pub fn drop_adj_list<NA>(&self, accessor: &mut NA, vector_id: DP::InternalId) -> ANNResult<()>
where
NA: AsNeighborMut<Id = DP::InternalId>,
{
self.handle
.block_on(self.inner.drop_adj_list(accessor, vector_id))
}
/// Blocking wrapper over `inner.get_undeleted_neighbors`: fetch the
/// neighbors of `vector_id`, partitioned per the inner index's rules.
#[allow(clippy::type_complexity)]
pub fn get_undeleted_neighbors<NA>(
&self,
context: &DP::Context,
accessor: &mut NA,
vector_id: DP::InternalId,
) -> ANNResult<PartitionedNeighbors<DP::InternalId>>
where
DP: Delete,
NA: AsNeighbor<Id = DP::InternalId>,
{
self.handle.block_on(
self.inner
.get_undeleted_neighbors(context, accessor, vector_id),
)
}
/// Blocking wrapper over `inner.inplace_delete`: delete `id` in place,
/// replacing up to `num_to_replace` references using the given method.
pub fn inplace_delete<S>(
&self,
strategy: S,
context: &DP::Context,
id: &DP::ExternalId,
num_to_replace: usize,
inplace_delete_method: InplaceDeleteMethod,
) -> ANNResult<()>
where
S: InplaceDeleteStrategy<DP> + Sync + Clone,
DP: Delete,
{
self.handle.block_on(self.inner.inplace_delete(
strategy,
context,
id,
num_to_replace,
inplace_delete_method,
))
}
/// Blocking wrapper over `inner.drop_deleted_neighbors`: remove deleted
/// neighbors of `vector_id` (only orphans when `only_orphans` is set).
pub fn drop_deleted_neighbors<NA>(
&self,
context: &DP::Context,
accessor: &mut NA,
vector_id: DP::InternalId,
only_orphans: bool,
) -> ANNResult<ConsolidateKind>
where
DP: Delete,
NA: AsNeighborMut<Id = DP::InternalId>,
{
self.handle.block_on(self.inner.drop_deleted_neighbors(
context,
accessor,
vector_id,
only_orphans,
))
}
/// Blocking wrapper over `inner.consolidate_vector`: consolidate the
/// adjacency list of `vector_id` using the given prune strategy.
pub fn consolidate_vector<S>(
&self,
strategy: &S,
context: &DP::Context,
vector_id: DP::InternalId,
) -> ANNResult<ConsolidateKind>
where
DP: Delete,
S: PruneStrategy<DP>,
{
self.handle
.block_on(self.inner.consolidate_vector(strategy, context, vector_id))
}
/// Blocking wrapper over `inner.search`: run a search for `query`, writing
/// results into `output` and returning the search-kind-specific output
/// (e.g. statistics).
pub fn search<S, T, O, OB, P>(
&self,
search_params: P,
strategy: &S,
context: &DP::Context,
query: T,
output: &mut OB,
) -> ANNResult<P::Output>
where
P: graph::search::Search<DP, S, T>,
S: DefaultSearchStrategy<DP, T, O>,
O: Send,
OB: search_output_buffer::SearchOutputBuffer<O> + Send + ?Sized,
{
self.handle.block_on(
self.inner
.search(search_params, strategy, context, query, output),
)
}
/// Blocking wrapper over `inner.start_paged_search`: begin a paged search
/// and return its resumable state.
#[allow(clippy::type_complexity)]
pub fn start_paged_search<S, T>(
&self,
strategy: S,
context: &DP::Context,
query: T,
l_value: usize,
) -> ANNResult<PagedSearchState<DP, S, S::QueryComputer>>
where
S: SearchStrategy<DP, T> + 'static,
T: Copy + Send,
{
self.handle.block_on(
self.inner
.start_paged_search(strategy, context, query, l_value),
)
}
/// Blocking wrapper over `inner.start_paged_search_with_init_ids`: begin a
/// paged search seeded from `init_ids` when provided.
#[allow(clippy::type_complexity)]
pub fn start_paged_search_with_init_ids<S, T>(
&self,
strategy: S,
context: &DP::Context,
query: T,
l_value: usize,
init_ids: Option<&[DP::InternalId]>,
) -> ANNResult<PagedSearchState<DP, S, S::QueryComputer>>
where
S: SearchStrategy<DP, T> + 'static,
T: Copy + Send,
{
self.handle.block_on(
self.inner
.start_paged_search_with_init_ids(strategy, context, query, l_value, init_ids),
)
}
/// Blocking wrapper over `inner.next_search_results`: fetch up to `k` more
/// results from an in-progress paged search into `result_output`, returning
/// the number produced.
pub fn next_search_results<S, T>(
&self,
context: &DP::Context,
search_state: &mut SearchState<DP::InternalId, (S, S::QueryComputer)>,
k: usize,
result_output: &mut [Neighbor<DP::InternalId>],
) -> ANNResult<usize>
where
S: SearchStrategy<DP, T>,
{
self.handle.block_on(self.inner.next_search_results(
context,
search_state,
k,
result_output,
))
}
/// Blocking wrapper over `inner.count_reachable_nodes`: count nodes
/// reachable from `start_points` through the graph.
pub fn count_reachable_nodes<NA>(
&self,
start_points: &[DP::InternalId],
accessor: &mut NA,
) -> ANNResult<usize>
where
NA: AsNeighbor<Id = DP::InternalId>,
{
self.handle
.block_on(self.inner.count_reachable_nodes(start_points, accessor))
}
/// Blocking wrapper over `inner.get_degree_stats`: compute degree statistics
/// over all vectors in the provider.
pub fn get_degree_stats<NA>(&self, accessor: &mut NA) -> ANNResult<DegreeStats>
where
for<'a> &'a DP: IntoIterator<Item = DP::InternalId, IntoIter: Send>,
NA: AsNeighbor<Id = DP::InternalId>,
{
self.handle.block_on(self.inner.get_degree_stats(accessor))
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use diskann::{
graph::{self, search_output_buffer},
provider::DefaultContext,
utils::ONE,
};
use diskann_utils::test_data_root;
use diskann_vector::distance::Metric;
use super::DiskANNIndex;
use crate::{
index::diskann_async,
model::{
configuration::IndexConfiguration,
graph::provider::async_::{
common::{FullPrecision, TableBasedDeletes},
inmem::{self, CreateFullPrecision, DefaultProvider},
},
},
storage::{AsyncIndexMetadata, SaveWith, StorageReadProvider, VirtualStorageProvider},
utils::create_rnd_from_seed_in_tests,
};
// End-to-end round trip through the blocking facade: build an index, insert
// vectors, save it to in-memory storage, reload it with the synchronous
// loader, and verify that searching for an inserted vector returns itself.
#[test]
fn test_save_then_sync_load_round_trip() {
let save_path = "/index";
let file_path = "/sift/siftsmall_learn_256pts.fbin";
// Read the SIFT training vectors from the bundled test-data overlay.
let train_data = {
let storage = VirtualStorageProvider::new_overlay(test_data_root());
let mut reader = storage.open_reader(file_path).unwrap();
diskann_utils::io::read_bin::<f32>(&mut reader).unwrap()
};
// Train a product-quantization table (fixed seed for determinism).
let pq_bytes = 8;
let pq_table = diskann_async::train_pq(
train_data.as_view(),
pq_bytes,
&mut create_rnd_from_seed_in_tests(0xe3c52ef001bc7ade),
crate::utils::create_thread_pool(2).unwrap().as_ref(),
)
.unwrap();
// Build-time configuration and parameters for the graph.
let (build_config, parameters) = diskann_async::simplified_builder(
20,
32,
Metric::L2,
train_data.ncols(),
train_data.nrows(),
|_| {},
)
.unwrap();
let fp_precursor =
CreateFullPrecision::new(parameters.dim, parameters.prefetch_cache_line_level);
let data_provider =
DefaultProvider::new_empty(parameters, fp_precursor, pq_table, TableBasedDeletes)
.unwrap();
// Exercise the current-thread-runtime constructor of the blocking facade.
let index =
DiskANNIndex::new_with_current_thread_runtime(build_config.clone(), data_provider);
let storage = VirtualStorageProvider::new_memory();
let ctx = DefaultContext;
// Insert every training row, using its row index as the external id.
for (i, v) in train_data.row_iter().enumerate() {
index.insert(FullPrecision, &ctx, &(i as u32), v).unwrap();
}
let save_metadata = AsyncIndexMetadata::new(save_path.to_string());
let storage_ref = &storage;
let metadata_ref = &save_metadata;
// Save through the generic `run` escape hatch, since `save_with` has no
// dedicated blocking wrapper.
index
.run(|inner| {
let inner = Arc::clone(inner);
async move { inner.save_with(storage_ref, metadata_ref).await }
})
.unwrap();
let load_config = IndexConfiguration::new(
Metric::L2,
train_data.ncols(),
train_data.nrows(),
ONE,
1,
build_config,
);
type TestProvider = inmem::FullPrecisionProvider<
f32,
crate::model::graph::provider::async_::FastMemoryQuantVectorProviderAsync,
crate::model::graph::provider::async_::TableDeleteProviderAsync,
>;
// Reload with the synchronous loader under test.
let loaded: DiskANNIndex<TestProvider> =
DiskANNIndex::load_with_current_thread_runtime(&storage, &(save_path, load_config))
.unwrap();
let top_k = 5;
let search_l = 20;
let mut ids = vec![0u32; top_k];
let mut distances = vec![0.0f32; top_k];
let mut output = search_output_buffer::IdDistance::new(&mut ids, &mut distances);
// Query with row 0 itself: the nearest neighbor under L2 must be row 0
// at distance exactly 0.
let query = train_data.row(0);
let kind = graph::search::Knn::new_default(top_k, search_l).unwrap();
let stats = loaded
.search(kind, &FullPrecision, &DefaultContext, query, &mut output)
.unwrap();
assert_eq!(stats.result_count, top_k as u32);
assert_eq!(ids[0], 0);
assert_eq!(distances[0], 0.0);
}
}