Skip to main content

diskann_benchmark_core/build/graph/
multi.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{ops::Range, sync::Arc};
7
8use diskann::{
9    ANNError, ANNErrorKind, ANNResult,
10    graph::{self, glue},
11    provider,
12};
13use diskann_utils::{future::AsyncFriendly, views::Matrix};
14
15use crate::build::{Build, ids::ToId};
16
17/// A built-in helper for benchmarking [multi-insert](graph::DiskANNIndex::multi_insert).
18///
19/// This is intended to be used in conjunction with [`crate::build::build`] and [`crate::build::build_tracked`].
20///
21/// # Notes
22///
23/// The multi-insert API for [`diskann::graph::DiskANNIndex`] parallelizes insertion internally. When using
24/// [`crate::build::build`], users should use [`crate::build::Parallelism::sequential`].
25pub struct MultiInsert<DP, T, S>
26where
27    DP: provider::DataProvider,
28{
29    index: Arc<graph::DiskANNIndex<DP>>,
30    data: Arc<Matrix<T>>,
31    strategy: S,
32    to_id: Box<dyn ToId<DP::ExternalId>>,
33}
34
35impl<DP, T, S> MultiInsert<DP, T, S>
36where
37    DP: provider::DataProvider,
38{
39    /// Construct a new [`MultiInsert`] builder for the given `index`.
40    ///
41    /// Vectors will be inserted using all rows of `data` with `strategy` used
42    /// for the [`diskann::graph::glue::InsertStrategy`].
43    ///
44    /// Parameter `to_id` will be used to convert row indices of `data` (`0..data.nrows()`)
45    /// to external IDs.
46    pub fn new<I>(
47        index: Arc<graph::DiskANNIndex<DP>>,
48        data: Arc<Matrix<T>>,
49        strategy: S,
50        to_id: I,
51    ) -> Arc<Self>
52    where
53        I: ToId<DP::ExternalId>,
54    {
55        Arc::new(Self {
56            index,
57            data,
58            strategy,
59            to_id: Box::new(to_id),
60        })
61    }
62}
63
64impl<DP, T, S> Build for MultiInsert<DP, T, S>
65where
66    DP: provider::DataProvider<Context: Default> + for<'a> provider::SetElement<&'a [T]>,
67    S: glue::MultiInsertStrategy<DP, Matrix<T>> + Clone + 'static,
68    T: AsyncFriendly + Clone,
69{
70    type Output = ();
71
72    fn num_data(&self) -> usize {
73        self.data.nrows()
74    }
75
76    async fn build(&self, range: Range<usize>) -> ANNResult<Self::Output> {
77        let vectors = self
78            .data
79            .subview(range.clone())
80            .ok_or_else(|| {
81                #[derive(Debug)]
82                struct OutOfBounds {
83                    max: usize,
84                    start: usize,
85                    end: usize,
86                }
87
88                impl std::fmt::Display for OutOfBounds {
89                    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90                        write!(
91                            f,
92                            "tried to access data with {} rows at range [{}, {})",
93                            self.max, self.start, self.end
94                        )
95                    }
96                }
97
98                ANNError::message(
99                    ANNErrorKind::Opaque,
100                    OutOfBounds {
101                        max: self.data.nrows(),
102                        start: range.start,
103                        end: range.end,
104                    },
105                )
106            })?
107            .to_owned();
108
109        let ids: ANNResult<Arc<[_]>> = range.into_iter().map(|i| self.to_id.to_id(i)).collect();
110        let context = DP::Context::default();
111        self.index
112            .multi_insert::<S, _>(self.strategy.clone(), &context, Arc::new(vectors), ids?)
113            .await?;
114
115        Ok(())
116    }
117}
118
119///////////
120// Tests //
121///////////
122
#[cfg(test)]
mod tests {
    use super::*;

    use std::num::NonZeroUsize;

    use diskann::{
        graph::{
            DiskANNIndex,
            test::{provider, synthetic},
        },
        provider::NeighborAccessor,
        utils::IntoUsize,
    };

    use crate::build;

    /// End-to-end check of [`MultiInsert`] on a small synthetic grid:
    ///
    /// 1. every inserted row and the start point end up with a non-empty adjacency list;
    /// 2. requesting rows past the end of the data fails with an out-of-bounds message.
    #[test]
    fn test_multi_insert() {
        let metric = diskann_vector::distance::Metric::L2;
        let grid = synthetic::Grid::Four;
        let grid_size = 4;
        let start_id = u32::MAX;

        let start_vector = grid.start_point(grid_size);
        let data = Arc::new(grid.data(grid_size));
        let nrows = data.nrows();

        let provider = provider::Provider::new(
            provider::Config::new(
                metric,
                2 * grid.dim().into_usize(),
                std::iter::once(provider::StartPoint::new(start_id, start_vector)),
            )
            .unwrap(),
        );

        let index_config = diskann::graph::config::Builder::new(
            provider.max_degree().checked_sub(3).unwrap(),
            diskann::graph::config::MaxDegree::new(provider.max_degree()),
            20,
            metric.into(),
        )
        .build()
        .unwrap();
        let index = Arc::new(DiskANNIndex::new(index_config, provider, None));

        let rt = crate::tokio::runtime(1).unwrap();
        let builder = MultiInsert::new(
            Arc::clone(&index),
            Arc::clone(&data),
            provider::Strategy::new(),
            build::ids::Identity::<u32>::new(),
        );

        // Multi-insert parallelizes internally, so the driver runs sequentially.
        let _ = build::build(
            Arc::clone(&builder),
            build::Parallelism::sequential(NonZeroUsize::new(10).unwrap()),
            &rt,
        )
        .unwrap();

        // Every inserted row, plus the start point, must have been wired into the graph.
        let accessor = index.provider().neighbors();
        let mut neighbors = diskann::graph::AdjacencyList::new();
        let inserted_ids = (0..nrows).map(|row| u32::try_from(row).unwrap());
        for id in inserted_ids.chain(std::iter::once(start_id)) {
            rt.block_on(accessor.get_neighbors(id, &mut neighbors))
                .unwrap();
            assert!(!neighbors.is_empty());
        }

        // A range that starts at the end of the data must be rejected with a
        // descriptive out-of-bounds message.
        let past_end = nrows..nrows + 1;
        let err = rt.block_on(builder.build(past_end)).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("tried to access data"),
            "actual message: {msg}"
        );
    }
}
209}