Skip to main content

diskann_platform/win/
file_io.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
//! Thin `unsafe` wrappers around two Windows API functions: `ReadFile` and
//! `GetQueuedCompletionStatus`.
//!
//! The wrappers simplify calling these functions and aim to provide easier error handling and a
//! safer interface: Win32 `BOOL` results plus `GetLastError` codes are translated into standard
//! Rust `io::Result` values, consistent with the rest of the Rust standard library.
9use std::io;
10use std::ptr;
11
12use windows_sys::Win32::{
13    Foundation::{GetLastError, ERROR_IO_PENDING, WAIT_TIMEOUT},
14    Storage::FileSystem::ReadFile,
15    System::IO::{GetQueuedCompletionStatus, OVERLAPPED},
16};
17
18use super::{DWORD, ULONG_PTR};
19use crate::{FileHandle, IOCompletionPort};
20
21/// Asynchronously queue a read request from a file into a buffer slice.
22///
23/// Wraps the unsafe Windows API function `ReadFile`, making it safe to call only when the overlapped buffer
24/// remains valid and unchanged anywhere else during the entire async operation.
25///
26/// Returns a boolean indicating whether the read operation completed synchronously or is pending.
27///
28/// # Safety
29///
30/// This function is marked as `unsafe` because it uses raw pointers and requires the caller to ensure
31/// that the buffer slice and the overlapped buffer stay valid during the whole async operation.
32///
33/// SAFETY: THIS IS NOT ENTIRELY SAFE! PLEASE READ!
34///
35/// This function is thread safe i.e. the same file handle can be used by multiple threads to read from the file
36/// as it uses the windows ReadFile API with async mode using OVERLAPPED structure.
37/// ReadFile Function - https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-readfile
38/// Synchronous and Asynchronous I/O - https://learn.microsoft.com/en-us/windows/win32/FileIO/synchronous-and-asynchronous-i-o
39///
40/// The only caveat is read operation is followed by polling on the handle using GetQueuedCompletionStatus API.
41/// If multiple threads are submitting read requests and polling then polling will return the completion status of any of the read requests.
42/// This is because GetQueuedCompletionStatus API returns the completion status of any of the read requests that are completed.
43pub unsafe fn read_file_to_slice<T>(
44    file_handle: &FileHandle,
45    buffer_slice: &mut [T],
46    overlapped: *mut OVERLAPPED,
47    offset: u64,
48) -> io::Result<bool> {
49    let num_bytes = std::mem::size_of_val(buffer_slice);
50    unsafe {
51        ptr::write(overlapped, std::mem::zeroed());
52        (*overlapped).Anonymous.Anonymous.Offset = offset as u32;
53        (*overlapped).Anonymous.Anonymous.OffsetHigh = (offset >> 32) as u32;
54    }
55
56    let win32_result: i32 = unsafe {
57        ReadFile(
58            file_handle.handle,
59            buffer_slice.as_mut_ptr().cast::<u8>(),
60            num_bytes as DWORD,
61            ptr::null_mut(),
62            overlapped,
63        )
64    };
65
66    // `ReadFile` returns zero on failure.
67    if win32_result == 0 {
68        let error = unsafe { GetLastError() };
69        return if error != ERROR_IO_PENDING {
70            Err(io::Error::from_raw_os_error(error as i32))
71        } else {
72            Ok(false)
73        };
74    }
75
76    Ok(true)
77}
78
79/// Retrieves the results of an asynchronous I/O operation on an I/O completion port.
80///
81/// Wraps the unsafe Windows API function `GetQueuedCompletionStatus`, making it safe to call only when the overlapped buffer
82/// remains valid and unchanged anywhere else during the entire async operation.
83///
84/// Returns a boolean indicating whether an I/O operation completed synchronously or is still pending.
85///
86/// # Safety
87///
88/// This function is marked as `unsafe` because it uses raw pointers and requires the caller to ensure
89/// that the overlapped buffer stays valid during the whole async operation.
90///
91/// SAFETY: THIS IS NOT ENTIRELY SAFE! PLEASE READ!
92///
93/// This function is thread safe i.e. the same file handle can be used by multiple threads to read from the file
94/// as it uses the windows ReadFile API with async mode using OVERLAPPED structure.
95/// ReadFile Function - https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-readfile
96/// Synchronous and Asynchronous I/O - https://learn.microsoft.com/en-us/windows/win32/FileIO/synchronous-and-asynchronous-i-o
97///
98/// The only caveat is read operation is followed by polling on the handle using GetQueuedCompletionStatus API.
99/// If multiple threads are submitting read requests and polling then polling will return the completion status of any of the read requests.
100/// This is because GetQueuedCompletionStatus API returns the completion status of any of the read requests that are completed.
101pub unsafe fn get_queued_completion_status(
102    completion_port: &IOCompletionPort,
103    lp_number_of_bytes: &mut DWORD,
104    lp_completion_key: &mut ULONG_PTR,
105    lp_overlapped: *mut *mut OVERLAPPED,
106    dw_milliseconds: DWORD,
107) -> io::Result<bool> {
108    let result = unsafe {
109        GetQueuedCompletionStatus(
110            *completion_port.mutex_guarded_handle()?,
111            lp_number_of_bytes,
112            lp_completion_key,
113            lp_overlapped,
114            dw_milliseconds,
115        )
116    };
117
118    match result {
119        0 => {
120            let error = unsafe { GetLastError() };
121            if error == WAIT_TIMEOUT {
122                Ok(false)
123            } else {
124                Err(io::Error::from_raw_os_error(error as i32))
125            }
126        }
127        _ => Ok(true),
128    }
129}
130
#[cfg(test)]
mod tests {
    use std::{fs::File, io::Write, path::PathBuf};

    use windows_sys::Win32::System::IO::GetOverlappedResult;

    use super::*;
    use crate::win::file_handle::{AccessMode, ShareMode};

    /// Deletes the wrapped file when dropped, so a failing assertion does not leave it behind.
    struct TempFile(PathBuf);

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort cleanup; the file may already be gone.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn test_read_file_to_slice() {
        const CONTENT: &[u8] = b"Hello, world!";

        // A per-process path in the temp dir avoids collisions with other runs and keeps the
        // working directory clean. `temp` is declared first so it is dropped (and the file
        // deleted) only after the file handle below has been closed.
        let temp = TempFile(std::env::temp_dir().join(format!(
            "diskann_platform_file_io_test_{}.txt",
            std::process::id()
        )));
        {
            let mut file = File::create(&temp.0).unwrap();
            file.write_all(CONTENT).unwrap();
        }

        // `buffer` and `overlapped` are declared outside the handle's scope so they outlive any
        // in-flight read, as `read_file_to_slice`'s safety contract requires.
        let mut buffer: [u8; 512] = [0; 512];
        let mut overlapped = unsafe { std::mem::zeroed::<OVERLAPPED>() };
        {
            let file_handle = unsafe {
                FileHandle::new(temp.0.to_str().unwrap(), AccessMode::Read, ShareMode::Read)
            }
            .unwrap();

            // Call the function under test.
            let completed_synchronously =
                unsafe { read_file_to_slice(&file_handle, &mut buffer, &mut overlapped, 0) }
                    .expect("ReadFile failed");

            if !completed_synchronously {
                // The read was queued: block until the kernel has filled `buffer` before
                // inspecting it.
                let mut bytes_read: u32 = 0;
                let ok = unsafe {
                    GetOverlappedResult(file_handle.handle, &overlapped, &mut bytes_read, 1)
                };
                assert_ne!(ok, 0, "GetOverlappedResult failed: {}", io::Error::last_os_error());
                assert_eq!(bytes_read as usize, CONTENT.len());
            }

            assert_eq!(&buffer[..CONTENT.len()], CONTENT);
        }
    }
}
167}