Skip to main content

diskann_disk/search/provider/
disk_sector_graph.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5#![warn(missing_docs)]
6
7//! Sector graph
8use std::ops::Deref;
9
10use diskann::{ANNError, ANNResult};
11use diskann_providers::common::AlignedBoxWithSlice;
12
13use crate::{
14    data_model::GraphHeader,
15    utils::aligned_file_reader::{traits::AlignedFileReader, AlignedRead},
16};
17
/// Fallback sector length in bytes (4 KiB).
///
/// `DiskSectorGraph::new` uses this instead of the header's block size when the header reports
/// layout version 0.0 or a block size of 0.
const DEFAULT_DISK_SECTOR_LEN: usize = 4 * 1024;
19
20/// Sector graph read from disk index
21pub struct DiskSectorGraph<AlignedReaderType: AlignedFileReader> {
22    /// Ensure `sector_reader` is dropped before `sectors_data` by placing it before `sectors_data`.
23    /// Graph storage to read sectors
24    sector_reader: AlignedReaderType,
25    /// Sector bytes from disk
26    /// One sector has num_nodes_per_sector nodes
27    /// Each node's layout: {full precision vector:[T; DIM]}{num_nbrs: u32}{neighbors: [u32; num_nbrs]}
28    /// The fp vector is not aligned
29    ///
30    /// index info for multi-node sectors
31    /// node `i` is in sector: [i / num_nodes_per_sector]
32    /// offset in sector: [(i % num_nodes_per_sector) * node_len]
33    ///
34    /// index info for multi-sector nodes
35    /// node `i` is in sector: [i * max_node_len.div_ceil(block_size)]
36    /// offset in sector: [0]
37    sectors_data: AlignedBoxWithSlice<u8>,
38    /// Current sector index into which the next read reads data
39    cur_sector_idx: u64,
40
41    /// 0 for multi-sector nodes, >0 for multi-node sectors
42    num_nodes_per_sector: u64,
43
44    node_len: u64,
45
46    max_n_batch_sector_read: usize,
47
48    num_sectors_per_node: usize,
49
50    block_size: usize,
51}
52
53impl<AlignedReaderType: AlignedFileReader> DiskSectorGraph<AlignedReaderType> {
54    /// Create SectorGraph instance
55    pub fn new(
56        sector_reader: AlignedReaderType,
57        header: &GraphHeader,
58        max_n_batch_sector_read: usize,
59    ) -> ANNResult<Self> {
60        let mut block_size = header.block_size() as usize;
61        let version = header.layout_version();
62        if (version.major_version() == 0 && version.minor_version() == 0) || block_size == 0 {
63            block_size = DEFAULT_DISK_SECTOR_LEN;
64        }
65
66        let num_nodes_per_sector = header.metadata().num_nodes_per_block;
67        let node_len = header.metadata().node_len;
68        let num_sectors_per_node = if num_nodes_per_sector > 0 {
69            1
70        } else {
71            (node_len as usize).div_ceil(block_size)
72        };
73
74        Ok(Self {
75            sector_reader,
76            sectors_data: AlignedBoxWithSlice::new(
77                max_n_batch_sector_read * num_sectors_per_node * block_size,
78                block_size,
79            )?,
80            cur_sector_idx: 0,
81            num_nodes_per_sector,
82            node_len,
83            max_n_batch_sector_read,
84            num_sectors_per_node,
85            block_size,
86        })
87    }
88
89    /// Reconfigure SectorGraph if the max number of sectors to read is larger than the current one
90    pub fn reconfigure(&mut self, max_n_batch_sector_read: usize) -> ANNResult<()> {
91        if max_n_batch_sector_read > self.max_n_batch_sector_read {
92            self.max_n_batch_sector_read = max_n_batch_sector_read;
93            self.sectors_data = AlignedBoxWithSlice::new(
94                max_n_batch_sector_read * self.num_sectors_per_node * self.block_size,
95                self.block_size,
96            )?;
97        }
98        Ok(())
99    }
100
101    /// Reset SectorGraph
102    pub fn reset(&mut self) {
103        self.cur_sector_idx = 0;
104    }
105
106    /// Read sectors into sectors_data
107    /// They are in the same order as sectors_to_fetch
108    pub fn read_graph(&mut self, sectors_to_fetch: &[u64]) -> ANNResult<()> {
109        let cur_sector_idx_usize: usize = self.cur_sector_idx.try_into()?;
110        if sectors_to_fetch.len() > self.max_n_batch_sector_read - cur_sector_idx_usize {
111            return Err(ANNError::log_index_error(format_args!(
112                "Trying to read too many sectors. number of sectors to read: {}, max number of sectors can read: {}",
113                sectors_to_fetch.len(),
114                self.max_n_batch_sector_read - cur_sector_idx_usize,
115            )));
116        }
117
118        let len_per_node = self.num_sectors_per_node * self.block_size;
119        let mut sector_slices = self.sectors_data.split_into_nonoverlapping_mut_slices(
120            cur_sector_idx_usize * len_per_node
121                ..(cur_sector_idx_usize + sectors_to_fetch.len()) * len_per_node,
122            len_per_node,
123        )?;
124        let mut read_requests = Vec::with_capacity(sector_slices.len());
125        for (local_sector_idx, slice) in sector_slices.iter_mut().enumerate() {
126            let sector_id = sectors_to_fetch[local_sector_idx];
127            read_requests.push(AlignedRead::new(sector_id * self.block_size as u64, slice)?);
128        }
129
130        self.sector_reader.read(&mut read_requests)?;
131        self.cur_sector_idx += sectors_to_fetch.len() as u64;
132
133        Ok(())
134    }
135
136    #[inline]
137    /// Get node data by local index.
138    pub fn node_disk_buf(&self, node_index_local: usize, vertex_id: u32) -> &[u8] {
139        // get sector_buf where this node is located
140        let sector_buf = self.get_sector_buf(node_index_local);
141        let node_offset = self.get_node_offset(vertex_id);
142        &sector_buf[node_offset..node_offset + self.node_len as usize]
143    }
144
145    /// Get sector data by local index
146    #[inline]
147    fn get_sector_buf(&self, local_sector_idx: usize) -> &[u8] {
148        let len_per_node = self.num_sectors_per_node * self.block_size;
149        &self.sectors_data[local_sector_idx * len_per_node..(local_sector_idx + 1) * len_per_node]
150    }
151
152    /// Get offset of node in sectors_data
153    #[inline]
154    fn get_node_offset(&self, vertex_id: u32) -> usize {
155        if self.num_nodes_per_sector == 0 {
156            // multi-sector node
157            0
158        } else {
159            // multi node in a sector
160            (vertex_id as u64 % self.num_nodes_per_sector * self.node_len) as usize
161        }
162    }
163
164    #[inline]
165    /// Gets the index for the sector that contains the node with the given vertex_id
166    pub fn node_sector_index(&self, vertex_id: u32) -> u64 {
167        1 + if self.num_nodes_per_sector > 0 {
168            vertex_id as u64 / self.num_nodes_per_sector
169        } else {
170            vertex_id as u64 * self.num_sectors_per_node as u64
171        }
172    }
173}
174
175impl<AlignedReaderType: AlignedFileReader> Deref for DiskSectorGraph<AlignedReaderType> {
176    type Target = [u8];
177
178    fn deref(&self) -> &Self::Target {
179        &self.sectors_data
180    }
181}
182
#[cfg(test)]
mod disk_sector_graph_test {
    use diskann_utils::test_data_root;

    use super::*;
    use crate::data_model::{GraphLayoutVersion, GraphMetadata};
    use crate::utils::aligned_file_reader::{
        traits::AlignedReaderFactory, AlignedFileReaderFactory,
    };

    /// Reader type produced by the factory used throughout these tests.
    type TestReader = <AlignedFileReaderFactory as AlignedReaderFactory>::AlignedReaderType;

    /// Path of the on-disk index file the test readers are opened on.
    fn test_index_path() -> String {
        let index_file = test_data_root()
            .join("disk_index_misc/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_aligned_reader_test.index");
        index_file.to_string_lossy().to_string()
    }

    /// Open a reader on the test index, panicking if it cannot be built.
    fn open_test_reader() -> TestReader {
        AlignedFileReaderFactory::new(test_index_path())
            .build()
            .unwrap()
    }

    /// Build a graph directly from field values, bypassing `DiskSectorGraph::new`.
    ///
    /// Fixed values: a 512-byte buffer aligned to 512, `node_len` 32, capacity 4, `block_size` 64.
    fn test_initialize_disk_sector_graph(
        num_nodes_per_sector: u64,
        num_sectors_per_node: usize,
        sector_reader: TestReader,
    ) -> DiskSectorGraph<TestReader> {
        DiskSectorGraph {
            sector_reader,
            sectors_data: AlignedBoxWithSlice::new(512, 512).unwrap(),
            cur_sector_idx: 0,
            num_nodes_per_sector,
            node_len: 32,
            max_n_batch_sector_read: 4,
            num_sectors_per_node,
            block_size: 64,
        }
    }

    #[test]
    fn test_new_disk_sector_graph_multi_node_per_sector() {
        let metadata = GraphMetadata::new(1000, 32, 500, 32, 2, 20, 50, 1024, 256);
        let header = GraphHeader::new(metadata, 64, GraphLayoutVersion::new(1, 0));

        let graph = DiskSectorGraph::new(open_test_reader(), &header, 2).unwrap();

        assert_eq!(graph.sectors_data.len(), 128);
        assert_eq!(graph.num_sectors_per_node, 1);
        assert_eq!(graph.num_nodes_per_sector, 2);
    }

    #[test]
    fn test_new_disk_sector_graph_multi_sector_per_node() {
        let metadata = GraphMetadata::new(1000, 32, 500, 128, 0, 20, 50, 1024, 256);
        let header = GraphHeader::new(metadata, 64, GraphLayoutVersion::new(1, 0));

        let graph = DiskSectorGraph::new(open_test_reader(), &header, 2).unwrap();

        assert_eq!(graph.sectors_data.len(), 256);
        assert_eq!(graph.num_sectors_per_node, 2);
        assert_eq!(graph.num_nodes_per_sector, 0);
    }

    #[test]
    fn test_new_disk_sector_graph_old_version_data() {
        // Layout version 0.0 must ignore the header's block size (9999 here).
        let metadata = GraphMetadata::new(1000, 32, 500, 128, 0, 20, 50, 1024, 256);
        let header = GraphHeader::new(metadata, 9999, GraphLayoutVersion::new(0, 0));

        let graph = DiskSectorGraph::new(open_test_reader(), &header, 2).unwrap();

        assert_eq!(graph.block_size, DEFAULT_DISK_SECTOR_LEN);
    }

    #[test]
    fn get_sector_buf_test() {
        let graph = test_initialize_disk_sector_graph(2, 1, open_test_reader());

        assert_eq!(graph.get_sector_buf(0).len(), 64);
    }

    #[test]
    fn get_node_offset_test_multi_node_per_sector() {
        let graph = test_initialize_disk_sector_graph(4, 1, open_test_reader());

        // (vertex_id, expected offset): offsets wrap every 4 nodes.
        let expectations = [
            (0, 0),
            (1, 32),
            (2, 64),
            (3, 96),
            (4, 0),
            (5, 32),
            (6, 64),
            (7, 96),
        ];
        for (vertex_id, expected_offset) in expectations {
            assert_eq!(
                graph.get_node_offset(vertex_id),
                expected_offset,
                "vertex_id {vertex_id}"
            );
        }
    }

    #[test]
    fn get_node_offset_test_multi_sector_per_node() {
        let graph = test_initialize_disk_sector_graph(0, 2, open_test_reader());

        // Multi-sector nodes always start at the beginning of their slot.
        for vertex_id in 0..=5 {
            assert_eq!(graph.get_node_offset(vertex_id), 0, "vertex_id {vertex_id}");
        }
    }

    #[test]
    fn node_sector_index_test_multi_node_per_sector() {
        let graph = test_initialize_disk_sector_graph(4, 1, open_test_reader());

        // (vertex_id, expected sector index)
        let expectations = [
            (0, 1),
            (3, 1),
            (4, 2),
            (5, 2),
            (7, 2),
            (8, 3),
            (1023, 256),
            (1024, 257),
            (2047, 512),
            (2048, 513),
        ];
        for (vertex_id, expected_sector) in expectations {
            assert_eq!(
                graph.node_sector_index(vertex_id),
                expected_sector,
                "vertex_id {vertex_id}"
            );
        }
    }

    #[test]
    fn node_sector_index_test_multi_sector_per_node() {
        let graph = test_initialize_disk_sector_graph(0, 2, open_test_reader());

        // (vertex_id, expected sector index)
        let expectations = [
            (0, 1),
            (3, 7),
            (4, 9),
            (5, 11),
            (7, 15),
            (8, 17),
            (1023, 2047),
            (1024, 2049),
            (2047, 4095),
            (2048, 4097),
        ];
        for (vertex_id, expected_sector) in expectations {
            assert_eq!(
                graph.node_sector_index(vertex_id),
                expected_sector,
                "vertex_id {vertex_id}"
            );
        }
    }

    #[test]
    fn test_read_graph_max_sectors() {
        let mut disk_sector_graph = test_initialize_disk_sector_graph(0, 2, open_test_reader());

        // Six sectors exceed the capacity of four, so the request must be rejected.
        let sectors_to_fetch: [u64; 6] = [1, 2, 3, 4, 5, 6];
        let result = disk_sector_graph.read_graph(&sectors_to_fetch);

        assert!(result.is_err());
    }

    #[test]
    fn test_disk_sector_graph_deref() {
        let graph = test_initialize_disk_sector_graph(1, 1, open_test_reader());

        // Going through `Deref` exposes the whole 512-byte buffer.
        let data: &[u8] = &graph;
        assert_eq!(data.len(), 512);
    }
}