diskann_disk/utils/aligned_file_reader/
windows_aligned_file_reader.rs1use std::{ptr, thread, time::Duration};
6
7use diskann::{ANNError, ANNResult};
8use diskann_platform::{
9 get_queued_completion_status, read_file_to_slice, ssd_io_context::IOContext, AccessMode,
10 FileHandle, IOCompletionPort, ShareMode, DWORD, OVERLAPPED, ULONG_PTR,
11};
12
13use super::traits::AlignedFileReader;
14use crate::utils::aligned_file_reader::AlignedRead;
15
16pub const MAX_IO_CONCURRENCY: usize = 128;
17pub const IO_COMPLETION_TIMEOUT: DWORD = u32::MAX; pub const ASYNC_IO_COMPLETION_CHECK_INTERVAL: Duration = Duration::from_micros(5);
19
20pub struct WindowsAlignedFileReader {
29 io_context: IOContext,
30}
31
32impl WindowsAlignedFileReader {
33 pub fn new(fname: &str) -> ANNResult<Self> {
34 let mut io_context = IOContext::new();
35 tracing::debug!("Creating file handle for {}", fname);
36 match unsafe { FileHandle::new(fname, AccessMode::Read, ShareMode::Read) } {
37 Ok(file_handle) => io_context.file_handle = file_handle,
38 Err(err) => {
39 return Err(ANNError::log_io_error(err));
40 }
41 }
42
43 match IOCompletionPort::new(&io_context.file_handle, None, 0, 0) {
45 Ok(io_completion_port) => io_context.io_completion_port = io_completion_port,
46 Err(err) => {
47 return Err(ANNError::log_io_error(err));
48 }
49 }
50
51 Ok(WindowsAlignedFileReader { io_context })
52 }
53}
54
55impl AlignedFileReader for WindowsAlignedFileReader {
56 fn read(&mut self, read_requests: &mut [AlignedRead<u8>]) -> ANNResult<()> {
58 let n_requests = read_requests.len();
59 let n_batches = n_requests.div_ceil(MAX_IO_CONCURRENCY);
60 let ctx = &self.io_context;
61 let mut overlapped_in_out =
62 vec![unsafe { std::mem::zeroed::<OVERLAPPED>() }; MAX_IO_CONCURRENCY];
63
64 for batch_idx in 0..n_batches {
65 let batch_start = MAX_IO_CONCURRENCY * batch_idx;
66 let batch_size = std::cmp::min(n_requests - batch_start, MAX_IO_CONCURRENCY);
67
68 for j in 0..batch_size {
69 let req = &mut read_requests[batch_start + j];
70 let offset = req.offset();
71 let os = &mut overlapped_in_out[j];
72
73 match unsafe {
74 read_file_to_slice(&ctx.file_handle, req.aligned_buf_mut(), os, offset)
75 } {
76 Ok(_) => {}
77 Err(error) => {
78 return Err(ANNError::log_io_error(error));
79 }
80 }
81 }
82
83 let mut n_read: DWORD = 0;
84 let mut n_complete: u64 = 0;
85 let mut completion_key: ULONG_PTR = 0;
86 let mut lp_os: *mut OVERLAPPED = ptr::null_mut();
87 while n_complete < batch_size as u64 {
88 match unsafe {
89 get_queued_completion_status(
90 &ctx.io_completion_port,
91 &mut n_read,
92 &mut completion_key,
93 &mut lp_os,
94 IO_COMPLETION_TIMEOUT,
95 )
96 } {
97 Ok(true) => n_complete += 1,
99 Ok(false) => {
101 thread::sleep(ASYNC_IO_COMPLETION_CHECK_INTERVAL);
102 }
103 Err(error) => return Err(ANNError::log_io_error(error)),
105 }
106 }
107 }
108
109 Ok(())
110 }
111}
112
#[cfg(test)]
mod tests {
    use std::{
        fs::File,
        io::{BufReader, Read, Seek, SeekFrom},
    };

    use bincode::deserialize_from;
    use diskann_utils::test_data_root;
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::utils::aligned_file_reader::AlignedRead;
    use diskann_providers::common::AlignedBoxWithSlice;

    /// Path of the disk index file that the read tests run against.
    fn test_index_path() -> String {
        let index_file = test_data_root().join(
            "disk_index_misc/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_aligned_reader_test.index",
        );
        index_file.to_string_lossy().to_string()
    }

    /// Path of the serialized node data used as ground truth.
    fn truth_node_data_path() -> String {
        let truth_file =
            test_data_root().join("disk_index_misc/disk_index_node_data_aligned_reader_truth.bin");
        truth_file.to_string_lossy().to_string()
    }

    /// Size in bytes of each read issued by the sector-level test.
    const DEFAULT_DISK_SECTOR_LEN: usize = 4096;

    /// One node as parsed from a sector, and as stored in the truth file.
    /// Equality compares all three fields.
    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct NodeData {
        num_neighbors: u32,
        coordinates: Vec<f32>,
        neighbors: Vec<u32>,
    }

    /// Constructing a reader for the existing test index succeeds.
    #[test]
    fn test_new_aligned_file_reader() {
        let index_path = test_index_path();
        assert!(WindowsAlignedFileReader::new(&index_path).is_ok());
    }

    /// Ten consecutive 512-byte reads return the same bytes as `std::fs::File`.
    #[test]
    fn test_read() {
        let mut reader = WindowsAlignedFileReader::new(&test_index_path()).unwrap();

        let read_length = 512;
        let num_read = 10;
        let mut aligned_mem = AlignedBoxWithSlice::<u8>::new(read_length * num_read, 512).unwrap();

        let total_len = aligned_mem.len();
        let mut mem_slices = aligned_mem
            .split_into_nonoverlapping_mut_slices(0..total_len, read_length)
            .unwrap();

        // Request `i` targets file offset `i * read_length`.
        let mut aligned_reads: Vec<AlignedRead<'_, u8>> = mem_slices
            .iter_mut()
            .enumerate()
            .map(|(i, slice)| AlignedRead::new((i * read_length) as u64, slice).unwrap())
            .collect();

        assert!(reader.read(&mut aligned_reads).is_ok());

        // Re-read every range through the standard library and compare.
        let mut file = File::open(test_index_path()).unwrap();
        for request in aligned_reads.iter() {
            let mut expected = vec![0; request.aligned_buf().len()];
            file.seek(SeekFrom::Start(request.offset())).unwrap();
            file.read_exact(&mut expected).unwrap();

            assert_eq!(
                expected,
                request.aligned_buf(),
                "aligned_buf did not contain the expected data"
            );
        }
    }

    /// Reads ten sectors, parses the nodes of sectors 1..9 and compares them
    /// with the truth file.
    #[test]
    fn test_read_disk_index_by_sector() {
        let mut reader = WindowsAlignedFileReader::new(&test_index_path()).unwrap();

        let read_length = DEFAULT_DISK_SECTOR_LEN;
        let num_sector = 10;
        let mut aligned_mem =
            AlignedBoxWithSlice::<u8>::new(read_length * num_sector, 512).unwrap();

        let total_len = aligned_mem.len();
        let mut mem_slices = aligned_mem
            .split_into_nonoverlapping_mut_slices(0..total_len, read_length)
            .unwrap();

        let mut aligned_reads: Vec<AlignedRead<'_, u8>> = mem_slices
            .iter_mut()
            .enumerate()
            .map(|(sector_id, slice)| {
                AlignedRead::new((sector_id * read_length) as u64, slice).unwrap()
            })
            .collect();

        assert!(reader.read(&mut aligned_reads).is_ok());

        for request in aligned_reads.iter() {
            assert_eq!(request.aligned_buf().len(), DEFAULT_DISK_SECTOR_LEN);
        }

        // Sector 0 carries the layout metadata as a sequence of u64 values.
        let disk_layout_meta = reconstruct_disk_meta(aligned_reads[0].aligned_buf_mut());
        assert!(disk_layout_meta.len() > 9);

        let num_pts = disk_layout_meta[0];
        let dims = disk_layout_meta[1];
        let node_len = disk_layout_meta[3];
        let max_num_nodes_per_sector = disk_layout_meta[4];

        assert!(node_len * max_num_nodes_per_sector < DEFAULT_DISK_SECTOR_LEN as u64);

        // Byte layout of one node as parsed below: `dims` f32 coordinates, then a
        // u32 neighbor count, then that many u32 neighbor ids.
        let num_nbrs_start = (dims as usize) * std::mem::size_of::<f32>();
        let nbrs_buf_start = num_nbrs_start + std::mem::size_of::<u32>();

        let mut node_data_array = Vec::with_capacity(max_num_nodes_per_sector as usize * 9);

        for sector_data in (1..9).map(|sector_id| &mem_slices[sector_id]) {
            for node_data in sector_data.chunks_exact(node_len as usize) {
                let coordinates: Vec<f32> = node_data[0..num_nbrs_start]
                    .chunks_exact(std::mem::size_of::<f32>())
                    .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
                    .collect();

                let neighbors_num = u32::from_le_bytes(
                    node_data[num_nbrs_start..nbrs_buf_start]
                        .try_into()
                        .unwrap(),
                );

                let nbors_buf_end =
                    nbrs_buf_start + (neighbors_num as usize) * std::mem::size_of::<u32>();

                // Every neighbor id has to be below the point count.
                let neighbors: Vec<u32> = node_data[nbrs_buf_start..nbors_buf_end]
                    .chunks_exact(std::mem::size_of::<u32>())
                    .map(|raw_id| {
                        let nbors_id = u32::from_le_bytes(raw_id.try_into().unwrap());
                        assert!(nbors_id < num_pts as u32);
                        nbors_id
                    })
                    .collect();

                node_data_array.push(NodeData {
                    num_neighbors: neighbors_num,
                    coordinates,
                    neighbors,
                });
            }
        }

        let truth_reader = BufReader::new(File::open(truth_node_data_path()).unwrap());
        let node_data_vec: Vec<NodeData> = deserialize_from(truth_reader).unwrap();

        // `zip` stops at the shorter sequence, so only the common prefix is compared.
        node_data_vec
            .iter()
            .zip(node_data_array.iter())
            .for_each(|(node_from_node_data_file, node_from_disk_index)| {
                assert_eq!(node_from_node_data_file, node_from_disk_index);
            });
    }

    /// Constructing a reader for a path that does not exist fails.
    #[test]
    fn test_read_fail_invalid_file() {
        assert!(WindowsAlignedFileReader::new("/invalid_path").is_err());
    }

    /// `read` with an empty request list succeeds.
    #[test]
    // The vector is empty on purpose: the test exercises the zero-request path.
    #[allow(clippy::read_zero_byte_vec)]
    fn test_read_no_requests() {
        let mut reader = WindowsAlignedFileReader::new(&test_index_path()).unwrap();

        let mut read_requests: Vec<AlignedRead<u8>> = Vec::new();
        assert!(reader.read(&mut read_requests).is_ok());
    }

    /// Decodes `buffer` as little-endian u64 values, skipping its first 8 bytes
    /// and ignoring any trailing bytes that do not fill a whole value.
    ///
    /// Panics if `buffer` is shorter than 8 bytes.
    fn reconstruct_disk_meta(buffer: &[u8]) -> Vec<u64> {
        buffer[8..]
            .chunks_exact(std::mem::size_of::<u64>())
            .map(|chunk| u64::from_le_bytes(chunk.try_into().unwrap()))
            .collect()
    }
}