Skip to main content

diskann_disk/utils/aligned_file_reader/
windows_aligned_file_reader.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5use std::{ptr, thread, time::Duration};
6
7use diskann::{ANNError, ANNResult};
8use diskann_platform::{
9    get_queued_completion_status, read_file_to_slice, ssd_io_context::IOContext, AccessMode,
10    FileHandle, IOCompletionPort, ShareMode, DWORD, OVERLAPPED, ULONG_PTR,
11};
12
13use super::traits::AlignedFileReader;
14use crate::utils::aligned_file_reader::AlignedRead;
15
/// Maximum number of read requests submitted to the OS before `read` waits for their
/// completions; larger request lists are processed in batches of this size.
pub const MAX_IO_CONCURRENCY: usize = 128;
/// Timeout passed to `get_queued_completion_status` when waiting for a completion.
pub const IO_COMPLETION_TIMEOUT: DWORD = u32::MAX; // Infinite timeout.
/// How long `read` sleeps before polling the completion port again when a poll reports that
/// no request has completed.
pub const ASYNC_IO_COMPLETION_CHECK_INTERVAL: Duration = Duration::from_micros(5);
19
/// `AlignedFileReader` implementation for Windows.
///
/// When you modify this type, run the benchmarks before and after the change to make sure
/// runtime does not regress:
///
/// ```text
/// # Run this before making your code change
/// cargo bench --bench bench_main -p diskann -- --save-baseline prior_to_change
///
/// # Run this after making your code change to generate comparison metrics
/// cargo bench --bench bench_main -p diskann -- --baseline prior_to_change
/// ```
pub struct WindowsAlignedFileReader {
    /// Holds the file handle and the I/O completion port that `new` creates and `read` uses.
    io_context: IOContext,
}
31
32impl WindowsAlignedFileReader {
33    pub fn new(fname: &str) -> ANNResult<Self> {
34        let mut io_context = IOContext::new();
35        tracing::debug!("Creating file handle for {}", fname);
36        match unsafe { FileHandle::new(fname, AccessMode::Read, ShareMode::Read) } {
37            Ok(file_handle) => io_context.file_handle = file_handle,
38            Err(err) => {
39                return Err(ANNError::log_io_error(err));
40            }
41        }
42
43        // Create a io completion port for the file handle, later it will be used to get the completion status.
44        match IOCompletionPort::new(&io_context.file_handle, None, 0, 0) {
45            Ok(io_completion_port) => io_context.io_completion_port = io_completion_port,
46            Err(err) => {
47                return Err(ANNError::log_io_error(err));
48            }
49        }
50
51        Ok(WindowsAlignedFileReader { io_context })
52    }
53}
54
55impl AlignedFileReader for WindowsAlignedFileReader {
56    // Read the data from the file by sending concurrent io requests in batches.
57    fn read(&mut self, read_requests: &mut [AlignedRead<u8>]) -> ANNResult<()> {
58        let n_requests = read_requests.len();
59        let n_batches = n_requests.div_ceil(MAX_IO_CONCURRENCY);
60        let ctx = &self.io_context;
61        let mut overlapped_in_out =
62            vec![unsafe { std::mem::zeroed::<OVERLAPPED>() }; MAX_IO_CONCURRENCY];
63
64        for batch_idx in 0..n_batches {
65            let batch_start = MAX_IO_CONCURRENCY * batch_idx;
66            let batch_size = std::cmp::min(n_requests - batch_start, MAX_IO_CONCURRENCY);
67
68            for j in 0..batch_size {
69                let req = &mut read_requests[batch_start + j];
70                let offset = req.offset();
71                let os = &mut overlapped_in_out[j];
72
73                match unsafe {
74                    read_file_to_slice(&ctx.file_handle, req.aligned_buf_mut(), os, offset)
75                } {
76                    Ok(_) => {}
77                    Err(error) => {
78                        return Err(ANNError::log_io_error(error));
79                    }
80                }
81            }
82
83            let mut n_read: DWORD = 0;
84            let mut n_complete: u64 = 0;
85            let mut completion_key: ULONG_PTR = 0;
86            let mut lp_os: *mut OVERLAPPED = ptr::null_mut();
87            while n_complete < batch_size as u64 {
88                match unsafe {
89                    get_queued_completion_status(
90                        &ctx.io_completion_port,
91                        &mut n_read,
92                        &mut completion_key,
93                        &mut lp_os,
94                        IO_COMPLETION_TIMEOUT,
95                    )
96                } {
97                    // An IO request completed.
98                    Ok(true) => n_complete += 1,
99                    // No IO request completed, continue to wait.
100                    Ok(false) => {
101                        thread::sleep(ASYNC_IO_COMPLETION_CHECK_INTERVAL);
102                    }
103                    // An error ocurred.
104                    Err(error) => return Err(ANNError::log_io_error(error)),
105                }
106            }
107        }
108
109        Ok(())
110    }
111}
112
113#[cfg(test)]
114mod tests {
115    use std::{
116        fs::File,
117        io::{BufReader, Read, Seek, SeekFrom},
118    };
119
120    use bincode::deserialize_from;
121    use diskann_utils::test_data_root;
122    use serde::{Deserialize, Serialize};
123
124    use super::*;
125    use crate::utils::aligned_file_reader::AlignedRead;
126    use diskann_providers::common::AlignedBoxWithSlice;
127
128    fn test_index_path() -> String {
129        test_data_root()
130            .join("disk_index_misc/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_aligned_reader_test.index")
131            .to_string_lossy()
132            .to_string()
133    }
134
135    fn truth_node_data_path() -> String {
136        test_data_root()
137            .join("disk_index_misc/disk_index_node_data_aligned_reader_truth.bin")
138            .to_string_lossy()
139            .to_string()
140    }
141
142    const DEFAULT_DISK_SECTOR_LEN: usize = 4096;
143
144    #[derive(Debug, Serialize, Deserialize)]
145    struct NodeData {
146        num_neighbors: u32,
147        coordinates: Vec<f32>,
148        neighbors: Vec<u32>,
149    }
150
151    impl PartialEq for NodeData {
152        fn eq(&self, other: &Self) -> bool {
153            self.num_neighbors == other.num_neighbors
154                && self.coordinates == other.coordinates
155                && self.neighbors == other.neighbors
156        }
157    }
158
159    #[test]
160    fn test_new_aligned_file_reader() {
161        // Replace "test_file_path" with actual file path
162        let result = WindowsAlignedFileReader::new(&test_index_path());
163        assert!(result.is_ok());
164    }
165
166    #[test]
167    fn test_read() {
168        let mut reader = WindowsAlignedFileReader::new(&test_index_path()).unwrap();
169
170        let read_length = 512; // adjust according to your logic
171        let num_read = 10;
172        let mut aligned_mem = AlignedBoxWithSlice::<u8>::new(read_length * num_read, 512).unwrap();
173
174        // create and add AlignedReads to the vector
175        let mut mem_slices = aligned_mem
176            .split_into_nonoverlapping_mut_slices(0..aligned_mem.len(), read_length)
177            .unwrap();
178
179        let mut aligned_reads: Vec<AlignedRead<'_, u8>> = mem_slices
180            .iter_mut()
181            .enumerate()
182            .map(|(i, slice)| {
183                let offset = (i * read_length) as u64;
184                AlignedRead::new(offset, slice).unwrap()
185            })
186            .collect();
187
188        let result = reader.read(&mut aligned_reads);
189        assert!(result.is_ok());
190
191        // Assert that the actual data is correct.
192        let mut file = File::open(test_index_path()).unwrap();
193        for current_read in aligned_reads {
194            let mut expected = vec![0; current_read.aligned_buf().len()];
195            file.seek(SeekFrom::Start(current_read.offset())).unwrap();
196            file.read_exact(&mut expected).unwrap();
197
198            assert_eq!(
199                expected,
200                current_read.aligned_buf(),
201                "aligned_buf did not contain the expected data"
202            );
203        }
204    }
205
206    #[test]
207    fn test_read_disk_index_by_sector() {
208        let mut reader = WindowsAlignedFileReader::new(&test_index_path()).unwrap();
209
210        let read_length = DEFAULT_DISK_SECTOR_LEN; // adjust according to your logic
211        let num_sector = 10;
212        let mut aligned_mem =
213            AlignedBoxWithSlice::<u8>::new(read_length * num_sector, 512).unwrap();
214
215        // Each slice will be used as the buffer for a read request of a sector.
216        let mut mem_slices = aligned_mem
217            .split_into_nonoverlapping_mut_slices(0..aligned_mem.len(), read_length)
218            .unwrap();
219
220        let mut aligned_reads: Vec<AlignedRead<'_, u8>> = mem_slices
221            .iter_mut()
222            .enumerate()
223            .map(|(sector_id, slice)| {
224                let offset = (sector_id * read_length) as u64;
225                AlignedRead::new(offset, slice).unwrap()
226            })
227            .collect();
228
229        let result = reader.read(&mut aligned_reads);
230        assert!(result.is_ok());
231
232        aligned_reads.iter().for_each(|read| {
233            assert_eq!(read.aligned_buf().len(), DEFAULT_DISK_SECTOR_LEN);
234        });
235
236        let disk_layout_meta = reconstruct_disk_meta(aligned_reads[0].aligned_buf_mut());
237        assert!(disk_layout_meta.len() > 9);
238
239        let dims = disk_layout_meta[1];
240        let num_pts = disk_layout_meta[0];
241        let node_len = disk_layout_meta[3];
242        let max_num_nodes_per_sector = disk_layout_meta[4];
243
244        assert!(node_len * max_num_nodes_per_sector < DEFAULT_DISK_SECTOR_LEN as u64);
245
246        let num_nbrs_start = (dims as usize) * std::mem::size_of::<f32>();
247        let nbrs_buf_start = num_nbrs_start + std::mem::size_of::<u32>();
248
249        let mut node_data_array = Vec::with_capacity(max_num_nodes_per_sector as usize * 9);
250
251        // Only validate the first 9 sectors with graph nodes.
252        (1..9).for_each(|sector_id| {
253            let sector_data = &mem_slices[sector_id];
254            for node_data in sector_data.chunks_exact(node_len as usize) {
255                // Extract coordinates data from the start of the node_data
256                let coordinates_end = (dims as usize) * std::mem::size_of::<f32>();
257                let coordinates = node_data[0..coordinates_end]
258                    .chunks_exact(std::mem::size_of::<f32>())
259                    .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
260                    .collect();
261
262                // Extract number of neighbors from the node_data
263                let neighbors_num = u32::from_le_bytes(
264                    node_data[num_nbrs_start..nbrs_buf_start]
265                        .try_into()
266                        .unwrap(),
267                );
268
269                let nbors_buf_end =
270                    nbrs_buf_start + (neighbors_num as usize) * std::mem::size_of::<u32>();
271
272                // Extract neighbors from the node data.
273                let mut neighbors = Vec::new();
274                for nbors_data in node_data[nbrs_buf_start..nbors_buf_end]
275                    .chunks_exact(std::mem::size_of::<u32>())
276                {
277                    let nbors_id = u32::from_le_bytes(nbors_data.try_into().unwrap());
278                    assert!(nbors_id < num_pts as u32);
279                    neighbors.push(nbors_id);
280                }
281
282                // Create NodeData struct and push it to the node_data_array
283                node_data_array.push(NodeData {
284                    num_neighbors: neighbors_num,
285                    coordinates,
286                    neighbors,
287                });
288            }
289        });
290
291        // Compare that each node read from the disk index are expected.
292        let node_data_truth_file = File::open(truth_node_data_path()).unwrap();
293        let reader = BufReader::new(node_data_truth_file);
294
295        let node_data_vec: Vec<NodeData> = deserialize_from(reader).unwrap();
296        for (node_from_node_data_file, node_from_disk_index) in
297            node_data_vec.iter().zip(node_data_array.iter())
298        {
299            // Verify that the NodeData from the file is equal to the NodeData in node_data_array
300            assert_eq!(node_from_node_data_file, node_from_disk_index);
301        }
302    }
303
304    #[test]
305    fn test_read_fail_invalid_file() {
306        let reader = WindowsAlignedFileReader::new("/invalid_path");
307        assert!(reader.is_err());
308    }
309
310    #[test]
311    #[allow(clippy::read_zero_byte_vec)]
312    fn test_read_no_requests() {
313        let mut reader = WindowsAlignedFileReader::new(&test_index_path()).unwrap();
314
315        let mut read_requests = Vec::<AlignedRead<u8>>::new();
316        let result = reader.read(&mut read_requests);
317        assert!(result.is_ok());
318    }
319
320    fn reconstruct_disk_meta(buffer: &[u8]) -> Vec<u64> {
321        let size_of_u64 = std::mem::size_of::<u64>();
322
323        let num_values = buffer.len() / size_of_u64;
324        let mut disk_layout_meta = Vec::with_capacity(num_values);
325        let meta_data = &buffer[8..];
326
327        for chunk in meta_data.chunks_exact(size_of_u64) {
328            let value = u64::from_le_bytes(chunk.try_into().unwrap());
329            disk_layout_meta.push(value);
330        }
331
332        disk_layout_meta
333    }
334}