use std::{ops::Range, sync::Arc};
use diskann::{
ANNError, ANNErrorKind, ANNResult,
graph::{self, glue},
provider,
};
use diskann_utils::{future::AsyncFriendly, views::Matrix};
use crate::build::{Build, ids::ToId};
/// Index builder that inserts rows of a dense matrix into a
/// `graph::DiskANNIndex` through the index's `multi_insert` API.
///
/// Instances are created with [`MultiInsert::new`], which returns them wrapped in
/// an [`Arc`] so they can be shared by the build driver.
pub struct MultiInsert<DP: provider::DataProvider, T, S> {
    /// The index that receives the inserted vectors.
    index: Arc<graph::DiskANNIndex<DP>>,
    /// Source vectors; each row is one data point (`num_data` reports `nrows`).
    data: Arc<Matrix<T>>,
    /// Insertion strategy, cloned for every `multi_insert` call.
    strategy: S,
    /// Converts a row index of `data` into the external ID used for insertion.
    to_id: Box<dyn ToId<DP::ExternalId>>,
}
impl<DP: provider::DataProvider, T, S> MultiInsert<DP, T, S> {
    /// Creates a shared `MultiInsert` builder.
    ///
    /// # Parameters
    ///
    /// - `index`: the index the data will be inserted into.
    /// - `data`: the matrix whose rows are inserted.
    /// - `strategy`: the `multi_insert` strategy (cloned for each batch).
    /// - `to_id`: converter from a row index to the point's external ID. It is
    ///   boxed as a trait object, so the builder's type does not depend on `I`.
    ///
    /// # Returns
    ///
    /// The new builder wrapped in an [`Arc`].
    pub fn new<I>(
        index: Arc<graph::DiskANNIndex<DP>>,
        data: Arc<Matrix<T>>,
        strategy: S,
        to_id: I,
    ) -> Arc<Self>
    where
        I: ToId<DP::ExternalId>,
    {
        let boxed_to_id: Box<dyn ToId<DP::ExternalId>> = Box::new(to_id);
        let builder = Self {
            index,
            data,
            strategy,
            to_id: boxed_to_id,
        };
        Arc::new(builder)
    }
}
impl<DP, T, S> Build for MultiInsert<DP, T, S>
where
    DP: provider::DataProvider<Context: Default> + for<'a> provider::SetElement<&'a [T]>,
    S: glue::MultiInsertStrategy<DP, Matrix<T>> + Clone + 'static,
    T: AsyncFriendly + Clone,
{
    type Output = ();

    /// Number of rows in the source matrix, i.e. the number of points to insert.
    fn num_data(&self) -> usize {
        self.data.nrows()
    }

    /// Inserts the rows of the source matrix selected by `range` with a single
    /// `multi_insert` call on the index.
    ///
    /// Each row index `i` in `range` is converted to an external ID through the
    /// configured [`ToId`] converter. A fresh default `DP::Context` is used for
    /// the insertion.
    ///
    /// # Errors
    ///
    /// - An [`ANNErrorKind::Opaque`] error if the matrix returns no sub-view for
    ///   `range` (for example, when `range` extends past the last row).
    /// - Any error produced while converting a row index to an external ID.
    /// - Any error returned by `multi_insert`.
    async fn build(&self, range: Range<usize>) -> ANNResult<Self::Output> {
        // The steps run in this order: bounds check -> ID conversion -> row copy.
        // The bounds check stays first, so an out-of-range request keeps reporting
        // the out-of-bounds error. Converting IDs before `to_owned` means a failing
        // conversion does not pay for a full copy of the sub-matrix. The block
        // scope ends the borrow of `self.data` before the `.await` below.
        let (vectors, ids) = {
            let view = self.data.subview(range.clone()).ok_or_else(|| {
                #[derive(Debug)]
                struct OutOfBounds {
                    max: usize,
                    start: usize,
                    end: usize,
                }

                impl std::fmt::Display for OutOfBounds {
                    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                        write!(
                            f,
                            "tried to access data with {} rows at range [{}, {})",
                            self.max, self.start, self.end
                        )
                    }
                }

                ANNError::message(
                    ANNErrorKind::Opaque,
                    OutOfBounds {
                        max: self.data.nrows(),
                        start: range.start,
                        end: range.end,
                    },
                )
            })?;

            // `Range` is already an iterator; collecting into `ANNResult` stops at
            // the first failed conversion.
            let ids = range
                .map(|i| self.to_id.to_id(i))
                .collect::<ANNResult<Arc<[_]>>>()?;

            (Arc::new(view.to_owned()), ids)
        };

        let context = DP::Context::default();
        self.index
            .multi_insert::<S, _>(self.strategy.clone(), &context, vectors, ids)
            .await?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::num::NonZeroUsize;
use diskann::{
graph::{
DiskANNIndex,
test::{provider, synthetic},
},
provider::NeighborAccessor,
utils::IntoUsize,
};
use crate::build;
/// End-to-end check of `MultiInsert`.
///
/// The test:
/// - builds an index over a small synthetic grid;
/// - verifies that every data row and the start point end up with a non-empty
///   adjacency list;
/// - verifies that a range past the end of the data is rejected with the
///   out-of-bounds error message.
#[test]
fn test_multi_insert() {
let grid = synthetic::Grid::Four;
let size = 4;
// ID of the single start point. NOTE(review): presumably `u32::MAX` is chosen so
// it cannot collide with the row-index IDs (0..nrows) produced by `Identity` —
// confirm.
let start_id = u32::MAX;
let distance = diskann_vector::distance::Metric::L2;
let start_point = grid.start_point(size);
let data = Arc::new(grid.data(size));
// NOTE(review): the second argument appears to be the provider's max degree
// (it is read back through `provider.max_degree()` below) — confirm.
let provider_config = provider::Config::new(
distance,
2 * grid.dim().into_usize(),
std::iter::once(provider::StartPoint::new(start_id, start_point)),
)
.unwrap();
let provider = provider::Provider::new(provider_config);
// The first argument is 3 below the provider's max degree, which is also
// passed as `MaxDegree`. NOTE(review): the first argument is presumably the
// target/pruned degree and `20` the build search-list size — confirm against
// `config::Builder::new`.
let index_config = diskann::graph::config::Builder::new(
provider.max_degree().checked_sub(3).unwrap(),
diskann::graph::config::MaxDegree::new(provider.max_degree()),
20,
distance.into(),
)
.build()
.unwrap();
let index = Arc::new(DiskANNIndex::new(index_config, provider, None));
let rt = crate::tokio::runtime(1).unwrap();
// `Identity::<u32>` presumably maps row `i` to external ID `i` — which is
// what the neighbor checks below rely on.
let builder = MultiInsert::new(
index.clone(),
data.clone(),
provider::Strategy::new(),
build::ids::Identity::<u32>::new(),
);
// Run the full build sequentially. NOTE(review): `10` is presumably the
// number of rows handed to each `build` call — confirm.
let _ = build::build(
builder.clone(),
build::Parallelism::sequential(NonZeroUsize::new(10).unwrap()),
&rt,
)
.unwrap();
// Every data row and the start point must have at least one neighbor.
// NOTE(review): `v` is reused across calls; this assumes `get_neighbors`
// replaces its contents rather than appending. If it appends, the later
// checks could pass vacuously.
let accessor = index.provider().neighbors();
let mut v = diskann::graph::AdjacencyList::new();
for i in 0..data.nrows() {
rt.block_on(accessor.get_neighbors(i.try_into().unwrap(), &mut v))
.unwrap();
assert!(!v.is_empty());
}
rt.block_on(accessor.get_neighbors(start_id, &mut v))
.unwrap();
assert!(!v.is_empty());
// A range starting at `nrows` lies entirely outside the data. `build` must
// reject it with its out-of-bounds error message.
let err = rt
.block_on(builder.build(data.nrows()..data.nrows() + 1))
.unwrap_err();
let msg = err.to_string();
assert!(
msg.contains("tried to access data"),
"actual message: {msg}"
);
}
}