diskann_disk/search/provider/
disk_sector_graph.rs1#![warn(missing_docs)]
6
7use std::ops::Deref;
9
10use diskann::{ANNError, ANNResult};
11use diskann_providers::common::AlignedBoxWithSlice;
12
13use crate::{
14 data_model::GraphHeader,
15 utils::aligned_file_reader::{traits::AlignedFileReader, AlignedRead},
16};
17
/// Sector (block) length in bytes used by [`DiskSectorGraph::new`] when the graph
/// header reports layout version `0.0` or a stored block size of `0`.
const DEFAULT_DISK_SECTOR_LEN: usize = 4 * 1024;
19
20pub struct DiskSectorGraph<AlignedReaderType: AlignedFileReader> {
22 sector_reader: AlignedReaderType,
25 sectors_data: AlignedBoxWithSlice<u8>,
38 cur_sector_idx: u64,
40
41 num_nodes_per_sector: u64,
43
44 node_len: u64,
45
46 max_n_batch_sector_read: usize,
47
48 num_sectors_per_node: usize,
49
50 block_size: usize,
51}
52
53impl<AlignedReaderType: AlignedFileReader> DiskSectorGraph<AlignedReaderType> {
54 pub fn new(
56 sector_reader: AlignedReaderType,
57 header: &GraphHeader,
58 max_n_batch_sector_read: usize,
59 ) -> ANNResult<Self> {
60 let mut block_size = header.block_size() as usize;
61 let version = header.layout_version();
62 if (version.major_version() == 0 && version.minor_version() == 0) || block_size == 0 {
63 block_size = DEFAULT_DISK_SECTOR_LEN;
64 }
65
66 let num_nodes_per_sector = header.metadata().num_nodes_per_block;
67 let node_len = header.metadata().node_len;
68 let num_sectors_per_node = if num_nodes_per_sector > 0 {
69 1
70 } else {
71 (node_len as usize).div_ceil(block_size)
72 };
73
74 Ok(Self {
75 sector_reader,
76 sectors_data: AlignedBoxWithSlice::new(
77 max_n_batch_sector_read * num_sectors_per_node * block_size,
78 block_size,
79 )?,
80 cur_sector_idx: 0,
81 num_nodes_per_sector,
82 node_len,
83 max_n_batch_sector_read,
84 num_sectors_per_node,
85 block_size,
86 })
87 }
88
89 pub fn reconfigure(&mut self, max_n_batch_sector_read: usize) -> ANNResult<()> {
91 if max_n_batch_sector_read > self.max_n_batch_sector_read {
92 self.max_n_batch_sector_read = max_n_batch_sector_read;
93 self.sectors_data = AlignedBoxWithSlice::new(
94 max_n_batch_sector_read * self.num_sectors_per_node * self.block_size,
95 self.block_size,
96 )?;
97 }
98 Ok(())
99 }
100
101 pub fn reset(&mut self) {
103 self.cur_sector_idx = 0;
104 }
105
106 pub fn read_graph(&mut self, sectors_to_fetch: &[u64]) -> ANNResult<()> {
109 let cur_sector_idx_usize: usize = self.cur_sector_idx.try_into()?;
110 if sectors_to_fetch.len() > self.max_n_batch_sector_read - cur_sector_idx_usize {
111 return Err(ANNError::log_index_error(format_args!(
112 "Trying to read too many sectors. number of sectors to read: {}, max number of sectors can read: {}",
113 sectors_to_fetch.len(),
114 self.max_n_batch_sector_read - cur_sector_idx_usize,
115 )));
116 }
117
118 let len_per_node = self.num_sectors_per_node * self.block_size;
119 let mut sector_slices = self.sectors_data.split_into_nonoverlapping_mut_slices(
120 cur_sector_idx_usize * len_per_node
121 ..(cur_sector_idx_usize + sectors_to_fetch.len()) * len_per_node,
122 len_per_node,
123 )?;
124 let mut read_requests = Vec::with_capacity(sector_slices.len());
125 for (local_sector_idx, slice) in sector_slices.iter_mut().enumerate() {
126 let sector_id = sectors_to_fetch[local_sector_idx];
127 read_requests.push(AlignedRead::new(sector_id * self.block_size as u64, slice)?);
128 }
129
130 self.sector_reader.read(&mut read_requests)?;
131 self.cur_sector_idx += sectors_to_fetch.len() as u64;
132
133 Ok(())
134 }
135
136 #[inline]
137 pub fn node_disk_buf(&self, node_index_local: usize, vertex_id: u32) -> &[u8] {
139 let sector_buf = self.get_sector_buf(node_index_local);
141 let node_offset = self.get_node_offset(vertex_id);
142 §or_buf[node_offset..node_offset + self.node_len as usize]
143 }
144
145 #[inline]
147 fn get_sector_buf(&self, local_sector_idx: usize) -> &[u8] {
148 let len_per_node = self.num_sectors_per_node * self.block_size;
149 &self.sectors_data[local_sector_idx * len_per_node..(local_sector_idx + 1) * len_per_node]
150 }
151
152 #[inline]
154 fn get_node_offset(&self, vertex_id: u32) -> usize {
155 if self.num_nodes_per_sector == 0 {
156 0
158 } else {
159 (vertex_id as u64 % self.num_nodes_per_sector * self.node_len) as usize
161 }
162 }
163
164 #[inline]
165 pub fn node_sector_index(&self, vertex_id: u32) -> u64 {
167 1 + if self.num_nodes_per_sector > 0 {
168 vertex_id as u64 / self.num_nodes_per_sector
169 } else {
170 vertex_id as u64 * self.num_sectors_per_node as u64
171 }
172 }
173}
174
175impl<AlignedReaderType: AlignedFileReader> Deref for DiskSectorGraph<AlignedReaderType> {
176 type Target = [u8];
177
178 fn deref(&self) -> &Self::Target {
179 &self.sectors_data
180 }
181}
182
#[cfg(test)]
mod disk_sector_graph_test {
    use crate::utils::aligned_file_reader::{
        traits::AlignedReaderFactory, AlignedFileReaderFactory,
    };
    use diskann_utils::test_data_root;

    use super::*;
    use crate::data_model::{GraphLayoutVersion, GraphMetadata};

    /// Concrete reader type produced by the test reader factory.
    type TestReader = <AlignedFileReaderFactory as AlignedReaderFactory>::AlignedReaderType;

    /// Path to the small aligned-reader test index in the test data directory.
    fn test_index_path() -> String {
        test_data_root()
            .join("disk_index_misc/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_aligned_reader_test.index")
            .to_string_lossy()
            .to_string()
    }

    /// Opens a reader over the test index.
    fn test_reader() -> TestReader {
        AlignedFileReaderFactory::new(test_index_path())
            .build()
            .unwrap()
    }

    /// Builds a graph directly, bypassing `new`. It has a 512-byte buffer, 64-byte
    /// blocks, 32-byte nodes and a capacity of 4 slots.
    fn test_initialize_disk_sector_graph(
        num_nodes_per_sector: u64,
        num_sectors_per_node: usize,
        sector_reader: TestReader,
    ) -> DiskSectorGraph<TestReader> {
        DiskSectorGraph {
            sectors_data: AlignedBoxWithSlice::new(512, 512).unwrap(),
            sector_reader,
            cur_sector_idx: 0,
            num_nodes_per_sector,
            node_len: 32,
            max_n_batch_sector_read: 4,
            num_sectors_per_node,
            block_size: 64,
        }
    }

    #[test]
    fn test_new_disk_sector_graph_multi_node_per_sector() {
        let metadata = GraphMetadata::new(1000, 32, 500, 32, 2, 20, 50, 1024, 256);
        let header = GraphHeader::new(metadata, 64, GraphLayoutVersion::new(1, 0));
        let graph = DiskSectorGraph::new(test_reader(), &header, 2).unwrap();
        assert_eq!(graph.sectors_data.len(), 128);
        assert_eq!(graph.num_sectors_per_node, 1);
        assert_eq!(graph.num_nodes_per_sector, 2);
    }

    #[test]
    fn test_new_disk_sector_graph_multi_sector_per_node() {
        let metadata = GraphMetadata::new(1000, 32, 500, 128, 0, 20, 50, 1024, 256);
        let header = GraphHeader::new(metadata, 64, GraphLayoutVersion::new(1, 0));
        let graph = DiskSectorGraph::new(test_reader(), &header, 2).unwrap();
        assert_eq!(graph.sectors_data.len(), 256);
        assert_eq!(graph.num_sectors_per_node, 2);
        assert_eq!(graph.num_nodes_per_sector, 0);
    }

    #[test]
    fn test_new_disk_sector_graph_old_version_data() {
        let metadata = GraphMetadata::new(1000, 32, 500, 128, 0, 20, 50, 1024, 256);
        let header = GraphHeader::new(metadata, 9999, GraphLayoutVersion::new(0, 0));
        let graph = DiskSectorGraph::new(test_reader(), &header, 2).unwrap();
        assert_eq!(graph.block_size, DEFAULT_DISK_SECTOR_LEN);
    }

    #[test]
    fn get_sector_buf_test() {
        let graph = test_initialize_disk_sector_graph(2, 1, test_reader());
        let sector_buf = graph.get_sector_buf(0);
        assert_eq!(sector_buf.len(), 64);
    }

    #[test]
    fn get_node_offset_test_multi_node_per_sector() {
        let graph = test_initialize_disk_sector_graph(4, 1, test_reader());

        assert_eq!(graph.get_node_offset(0), 0);
        assert_eq!(graph.get_node_offset(1), 32);
        assert_eq!(graph.get_node_offset(2), 64);
        assert_eq!(graph.get_node_offset(3), 96);
        assert_eq!(graph.get_node_offset(4), 0);
        assert_eq!(graph.get_node_offset(5), 32);
        assert_eq!(graph.get_node_offset(6), 64);
        assert_eq!(graph.get_node_offset(7), 96);
    }

    #[test]
    fn get_node_offset_test_multi_sector_per_node() {
        let graph = test_initialize_disk_sector_graph(0, 2, test_reader());

        assert_eq!(graph.get_node_offset(0), 0);
        assert_eq!(graph.get_node_offset(1), 0);
        assert_eq!(graph.get_node_offset(2), 0);
        assert_eq!(graph.get_node_offset(3), 0);
        assert_eq!(graph.get_node_offset(4), 0);
        assert_eq!(graph.get_node_offset(5), 0);
    }

    #[test]
    fn node_sector_index_test_multi_node_per_sector() {
        let graph = test_initialize_disk_sector_graph(4, 1, test_reader());

        assert_eq!(graph.node_sector_index(0), 1);
        assert_eq!(graph.node_sector_index(3), 1);
        assert_eq!(graph.node_sector_index(4), 2);
        assert_eq!(graph.node_sector_index(5), 2);
        assert_eq!(graph.node_sector_index(7), 2);
        assert_eq!(graph.node_sector_index(8), 3);
        assert_eq!(graph.node_sector_index(1023), 256);
        assert_eq!(graph.node_sector_index(1024), 257);
        assert_eq!(graph.node_sector_index(2047), 512);
        assert_eq!(graph.node_sector_index(2048), 513);
    }

    #[test]
    fn node_sector_index_test_multi_sector_per_node() {
        let graph = test_initialize_disk_sector_graph(0, 2, test_reader());

        assert_eq!(graph.node_sector_index(0), 1);
        assert_eq!(graph.node_sector_index(3), 7);
        assert_eq!(graph.node_sector_index(4), 9);
        assert_eq!(graph.node_sector_index(5), 11);
        assert_eq!(graph.node_sector_index(7), 15);
        assert_eq!(graph.node_sector_index(8), 17);
        assert_eq!(graph.node_sector_index(1023), 2047);
        assert_eq!(graph.node_sector_index(1024), 2049);
        assert_eq!(graph.node_sector_index(2047), 4095);
        assert_eq!(graph.node_sector_index(2048), 4097);
    }

    #[test]
    fn test_read_graph_max_sectors() {
        let mut disk_sector_graph = test_initialize_disk_sector_graph(0, 2, test_reader());

        // Six sectors exceed the four-slot capacity.
        let sectors_to_fetch: [u64; 6] = [1, 2, 3, 4, 5, 6];
        let result = disk_sector_graph.read_graph(&sectors_to_fetch);

        assert!(result.is_err());
    }

    #[test]
    fn test_read_graph_respects_current_offset() {
        let mut disk_sector_graph = test_initialize_disk_sector_graph(0, 2, test_reader());
        // Three of four slots are already in use, so only one more sector fits.
        disk_sector_graph.cur_sector_idx = 3;

        let result = disk_sector_graph.read_graph(&[1, 2]);

        assert!(result.is_err());
        assert_eq!(disk_sector_graph.cur_sector_idx, 3);
    }

    #[test]
    fn test_reconfigure_only_grows() {
        let mut graph = test_initialize_disk_sector_graph(0, 2, test_reader());

        graph.reconfigure(2).unwrap();
        assert_eq!(graph.max_n_batch_sector_read, 4);
        assert_eq!(graph.sectors_data.len(), 512);

        graph.reconfigure(8).unwrap();
        assert_eq!(graph.max_n_batch_sector_read, 8);
        assert_eq!(graph.sectors_data.len(), 8 * 2 * 64);
    }

    #[test]
    fn test_reset() {
        let mut graph = test_initialize_disk_sector_graph(2, 1, test_reader());
        graph.cur_sector_idx = 2;

        graph.reset();

        assert_eq!(graph.cur_sector_idx, 0);
    }

    #[test]
    fn test_disk_sector_graph_deref() {
        let graph = test_initialize_disk_sector_graph(1, 1, test_reader());
        let data: &[u8] = &graph;
        assert_eq!(data.len(), 512);
    }
}
360}