Skip to main content

diskann_benchmark_core/streaming/graph/
drop_deleted.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{ops::Range, sync::Arc};
7
8use diskann::{ANNResult, graph, provider};
9
10use crate::build::{Build, ids::ToIdSized};
11
12/// A [`Build`] stage that invokes
13/// [`drop_deleted_neighbors`](diskann::graph::DiskANNIndex::drop_deleted_neighbors)    
14/// on a collection of points.
15///
16/// The collection of points is determined by an implementation of [`ToIdSized`].
17pub struct DropDeleted<DP>
18where
19    DP: provider::DataProvider,
20{
21    index: Arc<graph::DiskANNIndex<DP>>,
22    only_orphans: bool,
23    to_id: Box<dyn ToIdSized<DP::InternalId>>,
24}
25
26impl<DP> DropDeleted<DP>
27where
28    DP: provider::DataProvider,
29{
30    /// Construct a new [`DropDeleted`] build stage.
31    ///
32    /// This [`Build`] object will run for all Ids provided by `to_id`, invoking
33    /// [`diskann::graph::DiskANNIndex::drop_deleted_neighbors`] on each ID.
34    ///
35    /// Argument `only_orphans` is passed directly to the method on [`diskann::graph::DiskANNIndex`].
36    ///
37    /// # Notes
38    ///
39    /// This method is a little different from other stages since it uses internal IDs
40    /// rather than external IDs. As such, users are **not** encouraged to use it.
41    pub fn new(
42        index: Arc<graph::DiskANNIndex<DP>>,
43        only_orphans: bool,
44        to_id: impl ToIdSized<DP::InternalId> + 'static,
45    ) -> Arc<Self> {
46        Arc::new(Self {
47            index,
48            only_orphans,
49            to_id: Box::new(to_id),
50        })
51    }
52}
53
54impl<DP> Build for DropDeleted<DP>
55where
56    DP: provider::DataProvider<Context: Default> + provider::Delete + provider::DefaultAccessor,
57    for<'a> <DP as provider::DefaultAccessor>::Accessor<'a>: provider::AsNeighborMut,
58{
59    type Output = ();
60
61    fn num_data(&self) -> usize {
62        self.to_id.len()
63    }
64
65    async fn build(&self, range: Range<usize>) -> ANNResult<Self::Output> {
66        let mut accessor = self.index.provider().default_accessor();
67        for i in range {
68            let context = DP::Context::default();
69            self.index
70                .drop_deleted_neighbors(
71                    &context,
72                    &mut accessor,
73                    self.to_id.to_id(i)?,
74                    self.only_orphans,
75                )
76                .await?;
77        }
78        Ok(())
79    }
80}
81
82///////////
83// Tests //
84///////////
85
#[cfg(test)]
mod tests {
    use super::*;

    use std::num::NonZeroUsize;

    use diskann::{
        graph::test::provider,
        provider::{Delete, NeighborAccessor},
        utils::ONE,
    };

    use crate::{build, streaming::graph::test};

    // Build an index, delete every even-numbered entry, then run `drop_deleted` over
    // the whole index.
    //
    // The index is left broken afterwards, but the only thing under test here is that
    // `drop_deleted` runs correctly.
    #[test]
    fn test_drop_deleted() {
        let (index, num_points) = test::build_test_index();
        let rt = crate::tokio::runtime(2).unwrap();

        let ctx = provider::Context::new();
        let num_points: u32 = num_points.try_into().unwrap();

        // Delete the even IDs: 0, 2, 4, ...
        for even in (0..num_points).step_by(2) {
            rt.block_on(index.provider().delete(&ctx, &even)).unwrap();
        }

        let stage = DropDeleted::new(index.clone(), false, build::ids::Range::new(0..num_points));
        let parallelism = build::Parallelism::dynamic(ONE, NonZeroUsize::new(2).unwrap());
        let _ = build::build(stage, parallelism, &rt).unwrap();

        let accessor = index.provider().neighbors();
        let mut neighbors = diskann::graph::AdjacencyList::new();

        // `drop_deleted` short-circuits already deleted entries, so only the odd
        // indices (1, 3, 5, ...) are inspected.
        for odd in (1..num_points).step_by(2) {
            rt.block_on(accessor.get_neighbors(odd, &mut neighbors))
                .unwrap();

            assert!(!neighbors.is_empty());
            for n in neighbors.iter() {
                assert!(
                    !n.is_multiple_of(2),
                    "all multiples of 2 should be removed for entry {}: {:?}",
                    odd,
                    neighbors,
                );
            }
        }
    }
}