diskann_platform/win/file_io.rs
1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
//! Thin `unsafe` wrappers around two Windows API functions: `ReadFile` (used in overlapped mode)
//! and `GetQueuedCompletionStatus`.
//!
//! The wrappers translate the raw Win32 conventions (a zero `BOOL` return plus `GetLastError`)
//! into standard Rust `io::Result` values, for consistency with the rest of the standard library.
//! They remain `unsafe`: callers must keep every buffer and `OVERLAPPED` structure alive and
//! untouched until the corresponding asynchronous operation completes.
9use std::io;
10use std::ptr;
11
12use windows_sys::Win32::{
13 Foundation::{GetLastError, ERROR_IO_PENDING, WAIT_TIMEOUT},
14 Storage::FileSystem::ReadFile,
15 System::IO::{GetQueuedCompletionStatus, OVERLAPPED},
16};
17
18use super::{DWORD, ULONG_PTR};
19use crate::{FileHandle, IOCompletionPort};
20
21/// Asynchronously queue a read request from a file into a buffer slice.
22///
23/// Wraps the unsafe Windows API function `ReadFile`, making it safe to call only when the overlapped buffer
24/// remains valid and unchanged anywhere else during the entire async operation.
25///
26/// Returns a boolean indicating whether the read operation completed synchronously or is pending.
27///
28/// # Safety
29///
30/// This function is marked as `unsafe` because it uses raw pointers and requires the caller to ensure
31/// that the buffer slice and the overlapped buffer stay valid during the whole async operation.
32///
33/// SAFETY: THIS IS NOT ENTIRELY SAFE! PLEASE READ!
34///
35/// This function is thread safe i.e. the same file handle can be used by multiple threads to read from the file
36/// as it uses the windows ReadFile API with async mode using OVERLAPPED structure.
37/// ReadFile Function - https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-readfile
38/// Synchronous and Asynchronous I/O - https://learn.microsoft.com/en-us/windows/win32/FileIO/synchronous-and-asynchronous-i-o
39///
40/// The only caveat is read operation is followed by polling on the handle using GetQueuedCompletionStatus API.
41/// If multiple threads are submitting read requests and polling then polling will return the completion status of any of the read requests.
42/// This is because GetQueuedCompletionStatus API returns the completion status of any of the read requests that are completed.
43pub unsafe fn read_file_to_slice<T>(
44 file_handle: &FileHandle,
45 buffer_slice: &mut [T],
46 overlapped: *mut OVERLAPPED,
47 offset: u64,
48) -> io::Result<bool> {
49 let num_bytes = std::mem::size_of_val(buffer_slice);
50 unsafe {
51 ptr::write(overlapped, std::mem::zeroed());
52 (*overlapped).Anonymous.Anonymous.Offset = offset as u32;
53 (*overlapped).Anonymous.Anonymous.OffsetHigh = (offset >> 32) as u32;
54 }
55
56 let win32_result: i32 = unsafe {
57 ReadFile(
58 file_handle.handle,
59 buffer_slice.as_mut_ptr().cast::<u8>(),
60 num_bytes as DWORD,
61 ptr::null_mut(),
62 overlapped,
63 )
64 };
65
66 // `ReadFile` returns zero on failure.
67 if win32_result == 0 {
68 let error = unsafe { GetLastError() };
69 return if error != ERROR_IO_PENDING {
70 Err(io::Error::from_raw_os_error(error as i32))
71 } else {
72 Ok(false)
73 };
74 }
75
76 Ok(true)
77}
78
79/// Retrieves the results of an asynchronous I/O operation on an I/O completion port.
80///
81/// Wraps the unsafe Windows API function `GetQueuedCompletionStatus`, making it safe to call only when the overlapped buffer
82/// remains valid and unchanged anywhere else during the entire async operation.
83///
84/// Returns a boolean indicating whether an I/O operation completed synchronously or is still pending.
85///
86/// # Safety
87///
88/// This function is marked as `unsafe` because it uses raw pointers and requires the caller to ensure
89/// that the overlapped buffer stays valid during the whole async operation.
90///
91/// SAFETY: THIS IS NOT ENTIRELY SAFE! PLEASE READ!
92///
93/// This function is thread safe i.e. the same file handle can be used by multiple threads to read from the file
94/// as it uses the windows ReadFile API with async mode using OVERLAPPED structure.
95/// ReadFile Function - https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-readfile
96/// Synchronous and Asynchronous I/O - https://learn.microsoft.com/en-us/windows/win32/FileIO/synchronous-and-asynchronous-i-o
97///
98/// The only caveat is read operation is followed by polling on the handle using GetQueuedCompletionStatus API.
99/// If multiple threads are submitting read requests and polling then polling will return the completion status of any of the read requests.
100/// This is because GetQueuedCompletionStatus API returns the completion status of any of the read requests that are completed.
101pub unsafe fn get_queued_completion_status(
102 completion_port: &IOCompletionPort,
103 lp_number_of_bytes: &mut DWORD,
104 lp_completion_key: &mut ULONG_PTR,
105 lp_overlapped: *mut *mut OVERLAPPED,
106 dw_milliseconds: DWORD,
107) -> io::Result<bool> {
108 let result = unsafe {
109 GetQueuedCompletionStatus(
110 *completion_port.mutex_guarded_handle()?,
111 lp_number_of_bytes,
112 lp_completion_key,
113 lp_overlapped,
114 dw_milliseconds,
115 )
116 };
117
118 match result {
119 0 => {
120 let error = unsafe { GetLastError() };
121 if error == WAIT_TIMEOUT {
122 Ok(false)
123 } else {
124 Err(io::Error::from_raw_os_error(error as i32))
125 }
126 }
127 _ => Ok(true),
128 }
129}
130
#[cfg(test)]
mod tests {
    use std::{fs::File, io::Write, path::PathBuf};

    use super::*;
    use crate::win::file_handle::{AccessMode, ShareMode};

    const CONTENTS: &[u8] = b"Hello, world!";

    /// Temporary file that is deleted on drop, so it is cleaned up even when an assertion fails.
    struct TempFile(PathBuf);

    impl TempFile {
        /// Creates `name` inside the system temp directory, prefixed with the process id so that
        /// concurrent test runs do not collide and the working directory is not polluted.
        fn with_contents(name: &str, contents: &[u8]) -> Self {
            let path = std::env::temp_dir()
                .join(format!("diskann_file_io_{}_{}", std::process::id(), name));
            // The `File` is a temporary, so it is closed at the end of this statement.
            File::create(&path)
                .and_then(|mut file| file.write_all(contents))
                .expect("failed to create temporary test file");
            TempFile(path)
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort cleanup. A failure here must not mask the test's own outcome.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn test_read_file_to_slice() {
        let temp_file = TempFile::with_contents("read_file_to_slice.txt", CONTENTS);

        let mut buffer: [u8; 512] = [0; 512];
        let mut overlapped = unsafe { std::mem::zeroed::<OVERLAPPED>() };
        {
            let path = temp_file
                .0
                .to_str()
                .expect("temporary path should be valid UTF-8");
            let file_handle =
                unsafe { FileHandle::new(path, AccessMode::Read, ShareMode::Read) }.unwrap();

            // `buffer` and `overlapped` outlive `file_handle`, satisfying the safety contract.
            let result =
                unsafe { read_file_to_slice(&file_handle, &mut buffer, &mut overlapped, 0) };

            // NOTE(review): this assumes the tiny read completes synchronously. A pending read
            // (`Ok(false)`) would need to be awaited before inspecting the buffer.
            result.expect("read_file_to_slice should succeed");
            let result_str = std::str::from_utf8(&buffer[..CONTENTS.len()]).unwrap();
            assert_eq!(result_str.as_bytes(), CONTENTS);
        }
    }
}