diskann_benchmark_core/build/graph/
multi.rs1use std::{ops::Range, sync::Arc};
7
8use diskann::{
9 ANNError, ANNErrorKind, ANNResult,
10 graph::{self, glue},
11 provider,
12};
13use diskann_utils::{future::AsyncFriendly, views::Matrix};
14
15use crate::build::{Build, ids::ToId};
16
17pub struct MultiInsert<DP, T, S>
26where
27 DP: provider::DataProvider,
28{
29 index: Arc<graph::DiskANNIndex<DP>>,
30 data: Arc<Matrix<T>>,
31 strategy: S,
32 to_id: Box<dyn ToId<DP::ExternalId>>,
33}
34
35impl<DP, T, S> MultiInsert<DP, T, S>
36where
37 DP: provider::DataProvider,
38{
39 pub fn new<I>(
47 index: Arc<graph::DiskANNIndex<DP>>,
48 data: Arc<Matrix<T>>,
49 strategy: S,
50 to_id: I,
51 ) -> Arc<Self>
52 where
53 I: ToId<DP::ExternalId>,
54 {
55 Arc::new(Self {
56 index,
57 data,
58 strategy,
59 to_id: Box::new(to_id),
60 })
61 }
62}
63
64impl<DP, T, S> Build for MultiInsert<DP, T, S>
65where
66 DP: provider::DataProvider<Context: Default> + for<'a> provider::SetElement<&'a [T]>,
67 S: glue::MultiInsertStrategy<DP, Matrix<T>> + Clone + 'static,
68 T: AsyncFriendly + Clone,
69{
70 type Output = ();
71
72 fn num_data(&self) -> usize {
73 self.data.nrows()
74 }
75
76 async fn build(&self, range: Range<usize>) -> ANNResult<Self::Output> {
77 let vectors = self
78 .data
79 .subview(range.clone())
80 .ok_or_else(|| {
81 #[derive(Debug)]
82 struct OutOfBounds {
83 max: usize,
84 start: usize,
85 end: usize,
86 }
87
88 impl std::fmt::Display for OutOfBounds {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 write!(
91 f,
92 "tried to access data with {} rows at range [{}, {})",
93 self.max, self.start, self.end
94 )
95 }
96 }
97
98 ANNError::message(
99 ANNErrorKind::Opaque,
100 OutOfBounds {
101 max: self.data.nrows(),
102 start: range.start,
103 end: range.end,
104 },
105 )
106 })?
107 .to_owned();
108
109 let ids: ANNResult<Arc<[_]>> = range.into_iter().map(|i| self.to_id.to_id(i)).collect();
110 let context = DP::Context::default();
111 self.index
112 .multi_insert::<S, _>(self.strategy.clone(), &context, Arc::new(vectors), ids?)
113 .await?;
114
115 Ok(())
116 }
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    use std::num::NonZeroUsize;

    use diskann::{
        graph::{
            DiskANNIndex,
            test::{provider, synthetic},
        },
        provider::NeighborAccessor,
        utils::IntoUsize,
    };

    use crate::build;

    /// End-to-end check of [`MultiInsert`].
    ///
    /// Builds an index over a small synthetic grid dataset and asserts that every
    /// dataset row, as well as the start point, ends up with a non-empty adjacency
    /// list. It then asserts that a build request for a range past the end of the
    /// data fails with the out-of-bounds message produced by `MultiInsert::build`.
    #[test]
    fn test_multi_insert() {
        // Synthetic grid dataset and the start point the provider is seeded with.
        let grid = synthetic::Grid::Four;
        let size = 4;
        // NOTE(review): `u32::MAX` is presumably chosen so the start point's id
        // cannot collide with the row ids `0..data.nrows()` — confirm.
        let start_id = u32::MAX;
        let distance = diskann_vector::distance::Metric::L2;

        let start_point = grid.start_point(size);
        let data = Arc::new(grid.data(size));

        // NOTE(review): the second argument appears to be the provider's maximum
        // degree, derived from the grid dimensionality — confirm against
        // `provider::Config::new`.
        let provider_config = provider::Config::new(
            distance,
            2 * grid.dim().into_usize(),
            std::iter::once(provider::StartPoint::new(start_id, start_point)),
        )
        .unwrap();

        let provider = provider::Provider::new(provider_config);

        // The first argument (presumably the pruned/target degree) is kept 3 below
        // the provider's max degree. The `20` is presumably the search list size
        // — TODO confirm against `config::Builder::new`.
        let index_config = diskann::graph::config::Builder::new(
            provider.max_degree().checked_sub(3).unwrap(),
            diskann::graph::config::MaxDegree::new(provider.max_degree()),
            20,
            distance.into(),
        )
        .build()
        .unwrap();

        let index = Arc::new(DiskANNIndex::new(index_config, provider, None));

        // NOTE(review): `runtime(1)` presumably creates a runtime with a single
        // worker, and the `10` passed to `Parallelism::sequential` presumably sets
        // the number of rows per `build` call — confirm.
        let rt = crate::tokio::runtime(1).unwrap();
        let builder = MultiInsert::new(
            index.clone(),
            data.clone(),
            provider::Strategy::new(),
            // Presumably maps row `i` to external id `i` (as `u32`).
            build::ids::Identity::<u32>::new(),
        );
        let _ = build::build(
            builder.clone(),
            build::Parallelism::sequential(NonZeroUsize::new(10).unwrap()),
            &rt,
        )
        .unwrap();

        // Every dataset row should have been given at least one neighbor.
        let accessor = index.provider().neighbors();
        let mut v = diskann::graph::AdjacencyList::new();

        for i in 0..data.nrows() {
            rt.block_on(accessor.get_neighbors(i.try_into().unwrap(), &mut v))
                .unwrap();
            assert!(!v.is_empty());
        }

        // The start point should also be connected into the graph.
        rt.block_on(accessor.get_neighbors(start_id, &mut v))
            .unwrap();
        assert!(!v.is_empty());

        // A range that starts at `nrows()` is out of bounds and must be rejected
        // with the `OutOfBounds` message built in `MultiInsert::build`.
        let err = rt
            .block_on(builder.build(data.nrows()..data.nrows() + 1))
            .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("tried to access data"),
            "actual message: {msg}"
        );
    }
}