Skip to main content

diskann_benchmark_core/build/graph/
single.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{ops::Range, sync::Arc};
7
8use diskann::{
9    ANNResult,
10    graph::{self, glue},
11    provider,
12};
13use diskann_utils::{future::AsyncFriendly, views::Matrix};
14
15use crate::build::{Build, ids::ToId};
16
17/// A built-in helper for benchmarking [insert](graph::DiskANNIndex::insert).
18///
19/// This is intended to be used in conjunction with [`crate::build::build`] and [`crate::build::build_tracked`].
20#[derive(Debug)]
21pub struct SingleInsert<DP, T, S>
22where
23    DP: provider::DataProvider,
24{
25    index: Arc<graph::DiskANNIndex<DP>>,
26    data: Arc<Matrix<T>>,
27    strategy: S,
28    to_id: Box<dyn ToId<DP::ExternalId>>,
29}
30
31impl<DP, T, S> SingleInsert<DP, T, S>
32where
33    DP: provider::DataProvider,
34{
35    /// Construct a new [`SingleInsert`] builder for the given `index`.
36    ///
37    /// Vectors will be inserted using all rows of `data` with `strategy` used
38    /// for the [`diskann::graph::glue::InsertStrategy`].
39    ///
40    /// Parameter `to_id` will be used to convert row indices of `data` (`0..data.nrows()`)
41    /// to external IDs.
42    pub fn new<I>(
43        index: Arc<graph::DiskANNIndex<DP>>,
44        data: Arc<Matrix<T>>,
45        strategy: S,
46        to_id: I,
47    ) -> Arc<Self>
48    where
49        I: ToId<DP::ExternalId>,
50    {
51        Arc::new(Self {
52            index,
53            data,
54            strategy,
55            to_id: Box::new(to_id),
56        })
57    }
58}
59
60impl<DP, T, S> Build for SingleInsert<DP, T, S>
61where
62    DP: provider::DataProvider<Context: Default> + for<'a> provider::SetElement<&'a [T]>,
63    S: for<'a> glue::InsertStrategy<DP, &'a [T]> + Clone + AsyncFriendly,
64    T: AsyncFriendly + Clone,
65{
66    type Output = ();
67
68    fn num_data(&self) -> usize {
69        self.data.nrows()
70    }
71
72    async fn build(&self, range: Range<usize>) -> ANNResult<Self::Output> {
73        for i in range {
74            let context = DP::Context::default();
75            self.index
76                .insert(
77                    self.strategy.clone(),
78                    &context,
79                    &self.to_id.to_id(i)?,
80                    self.data.row(i),
81                )
82                .await?;
83        }
84        Ok(())
85    }
86}
87
88///////////
89// Tests //
90///////////
91
#[cfg(test)]
mod tests {
    use super::*;

    use diskann::{
        graph::test::{provider, synthetic},
        provider::NeighborAccessor,
        utils::{IntoUsize, ONE},
    };

    use crate::build;

    /// Builds a small synthetic grid index through [`SingleInsert`], then checks that every
    /// inserted node and the start point have at least one neighbor.
    #[test]
    fn test_single_insert() {
        let metric = diskann_vector::distance::Metric::L2;
        let grid = synthetic::Grid::Four;
        let grid_size = 4;
        let start_id = u32::MAX;

        let start_vector = grid.start_point(grid_size);
        let data = Arc::new(grid.data(grid_size));

        let start_points = std::iter::once(provider::StartPoint::new(start_id, start_vector));
        let provider_config =
            provider::Config::new(metric, 2 * grid.dim().into_usize(), start_points).unwrap();
        let provider = provider::Provider::new(provider_config);

        let max_degree = provider.max_degree();
        let index_config = diskann::graph::config::Builder::new(
            max_degree.checked_sub(3).unwrap(),
            diskann::graph::config::MaxDegree::new(max_degree),
            20,
            metric.into(),
        )
        .build()
        .unwrap();

        let index = Arc::new(diskann::graph::DiskANNIndex::new(
            index_config,
            provider,
            None,
        ));

        let rt = crate::tokio::runtime(1).unwrap();
        let builder = SingleInsert::new(
            index.clone(),
            data.clone(),
            provider::Strategy::new(),
            build::ids::Identity::<u32>::new(),
        );
        let _ = build::build(builder, build::Parallelism::dynamic(ONE, ONE), &rt).unwrap();

        // Every inserted row should now be connected to the graph.
        let accessor = index.provider().neighbors();
        let mut neighbors = diskann::graph::AdjacencyList::new();

        for row in 0..data.nrows() {
            let id = row.try_into().unwrap();
            rt.block_on(accessor.get_neighbors(id, &mut neighbors))
                .unwrap();
            assert!(!neighbors.is_empty());
        }

        // The start point must be connected as well.
        rt.block_on(accessor.get_neighbors(start_id, &mut neighbors))
            .unwrap();
        assert!(!neighbors.is_empty());
    }
}