use std::{
io,
mem::MaybeUninit,
os::windows::io::{AsRawHandle, FromRawHandle, OwnedHandle, RawHandle},
ptr::null_mut,
time::Duration,
};
use compio_log::*;
use windows_sys::Win32::{
Foundation::{
ERROR_BAD_COMMAND, ERROR_BROKEN_PIPE, ERROR_HANDLE_EOF, ERROR_IO_INCOMPLETE,
ERROR_NETNAME_DELETED, ERROR_NO_DATA, ERROR_PIPE_CONNECTED, ERROR_PIPE_NOT_CONNECTED,
FACILITY_NTWIN32, INVALID_HANDLE_VALUE, NTSTATUS, RtlNtStatusToDosError, STATUS_SUCCESS,
},
Storage::FileSystem::SetFileCompletionNotificationModes,
System::{
IO::{
CreateIoCompletionPort, GetQueuedCompletionStatusEx, OVERLAPPED_ENTRY,
PostQueuedCompletionStatus,
},
SystemServices::ERROR_SEVERITY_ERROR,
Threading::INFINITE,
WindowsProgramming::{FILE_SKIP_COMPLETION_PORT_ON_SUCCESS, FILE_SKIP_SET_EVENT_ON_HANDLE},
},
};
use crate::{Overlapped, RawFd, syscall};
cfg_if::cfg_if! {
if #[cfg(feature = "iocp-global")] {
mod global;
pub use global::*;
} else {
mod multi;
pub use multi::*;
}
}
/// Owner of a Windows I/O completion port.
///
/// The port handle is stored as an [`OwnedHandle`], so it is closed when this
/// value is dropped.
struct CompletionPort {
    port: OwnedHandle,
}

impl CompletionPort {
    /// Capacity requested for the entry buffer that [`Self::poll`] dequeues into.
    pub const DEFAULT_CAPACITY: usize = 1024;

    /// Creates a new completion port with a concurrency value of 1.
    ///
    /// # Errors
    ///
    /// Returns the last OS error if `CreateIoCompletionPort` fails.
    pub fn new() -> io::Result<Self> {
        // SAFETY: INVALID_HANDLE_VALUE together with a null existing port asks the
        // OS to create a brand-new port; no caller-owned memory is involved.
        let port = unsafe { CreateIoCompletionPort(INVALID_HANDLE_VALUE, null_mut(), 0, 1) };
        if port.is_null() {
            return Err(io::Error::last_os_error());
        }
        trace!("new iocp handle: {port:p}");
        // SAFETY: `port` is a non-null handle that was just created and is not
        // owned by anything else, so taking ownership of it is sound.
        let port = unsafe { OwnedHandle::from_raw_handle(port) };
        Ok(Self { port })
    }

    /// Associates `fd` with this port (completion key 0) and sets
    /// `FILE_SKIP_COMPLETION_PORT_ON_SUCCESS | FILE_SKIP_SET_EVENT_ON_HANDLE`
    /// on it.
    ///
    /// With `FILE_SKIP_COMPLETION_PORT_ON_SUCCESS`, operations on `fd` that
    /// complete synchronously do not queue a completion packet, so the caller
    /// must handle immediate success itself.
    ///
    /// # Errors
    ///
    /// Returns the OS error from `CreateIoCompletionPort` or
    /// `SetFileCompletionNotificationModes`.
    pub fn attach(&self, fd: RawFd) -> io::Result<()> {
        syscall!(
            BOOL,
            CreateIoCompletionPort(fd, self.port.as_raw_handle(), 0, 0) as isize
        )?;
        syscall!(
            BOOL,
            SetFileCompletionNotificationModes(
                fd,
                (FILE_SKIP_COMPLETION_PORT_ON_SUCCESS | FILE_SKIP_SET_EVENT_ON_HANDLE) as _
            )
        )?;
        Ok(())
    }

    /// Writes `res` into `*optr` (when `optr` is non-null) using the
    /// `Internal` / `InternalHigh` encoding that [`Self::poll`] decodes, then
    /// posts `optr` to the port.
    ///
    /// On success `Internal` is `STATUS_SUCCESS` and `InternalHigh` is the
    /// transferred count. On error `Internal` holds the OS error converted to an
    /// `NTSTATUS`; `ERROR_BAD_COMMAND` is used when the error has no OS code.
    /// A null `optr` is posted unchanged.
    ///
    /// # Errors
    ///
    /// Returns the OS error if `PostQueuedCompletionStatus` fails.
    pub fn post(&self, res: io::Result<usize>, optr: *mut Overlapped) -> io::Result<()> {
        // SAFETY: callers are assumed to pass either null or a pointer to a live
        // `Overlapped` that nothing else is accessing concurrently.
        if let Some(overlapped) = unsafe { optr.as_mut() } {
            match &res {
                Ok(transferred) => {
                    overlapped.base.Internal = STATUS_SUCCESS as _;
                    overlapped.base.InternalHigh = *transferred;
                }
                Err(e) => {
                    let code = e.raw_os_error().unwrap_or(ERROR_BAD_COMMAND as _);
                    overlapped.base.Internal = ntstatus_from_win32(code) as _;
                }
            }
        }
        self.post_raw(optr)
    }

    /// Posts `optr` to the port as-is, with a zero byte count and completion
    /// key 0. `optr` may be null.
    ///
    /// # Errors
    ///
    /// Returns the OS error if `PostQueuedCompletionStatus` fails.
    pub fn post_raw(&self, optr: *const Overlapped) -> io::Result<()> {
        syscall!(
            BOOL,
            PostQueuedCompletionStatus(self.port.as_raw_handle() as _, 0, 0, optr.cast())
        )?;
        Ok(())
    }

    /// Converts an optional timeout into the millisecond argument of
    /// `GetQueuedCompletionStatusEx`.
    ///
    /// `None` becomes `INFINITE`. A finite timeout is rounded *up* to a whole
    /// millisecond, so a non-zero sub-millisecond timeout does not degrade into
    /// a non-blocking poll. It is also capped at `INFINITE - 1`, so a huge
    /// timeout neither wraps around in the `u32` conversion nor collides with
    /// the `INFINITE` sentinel.
    fn timeout_millis(timeout: Option<Duration>) -> u32 {
        let Some(timeout) = timeout else {
            return INFINITE;
        };
        let millis = timeout.as_nanos().div_ceil(1_000_000);
        u32::try_from(millis).map_or(INFINITE - 1, |ms| ms.min(INFINITE - 1))
    }

    /// Dequeues up to `entries.len()` completion packets into `entries`,
    /// waiting at most `timeout` (`None` waits forever).
    ///
    /// Returns how many leading elements of `entries` were initialized.
    ///
    /// # Errors
    ///
    /// Returns the OS error from `GetQueuedCompletionStatusEx`. This includes
    /// the case where the wait expired with nothing dequeued.
    pub fn poll_raw(
        &self,
        timeout: Option<Duration>,
        entries: &mut [MaybeUninit<OVERLAPPED_ENTRY>],
    ) -> io::Result<usize> {
        let mut recv_count = 0;
        let timeout = Self::timeout_millis(timeout);
        // The API takes a `u32` count; an even larger slice is simply under-used.
        let capacity = u32::try_from(entries.len()).unwrap_or(u32::MAX);
        syscall!(
            BOOL,
            GetQueuedCompletionStatusEx(
                self.port.as_raw_handle() as _,
                entries.as_mut_ptr().cast(),
                capacity,
                &mut recv_count,
                timeout,
                0
            )
        )?;
        trace!("recv_count: {recv_count}");
        Ok(recv_count as _)
    }

    /// Waits for completions and decodes each packet into a [`RawEntry`].
    ///
    /// - When `current_driver` is `Some`, a packet whose `Overlapped::driver`
    ///   differs is re-posted to that driver's port instead of being returned.
    ///   Re-post failures are only logged.
    /// - Packets posted without an `OVERLAPPED` (null pointer) are skipped.
    /// - A negative `NTSTATUS` is converted to a Win32 error. The codes listed
    ///   in the `match` below are reported as `Ok(0)` instead of errors.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Self::poll_raw`], except
    /// `ERROR_NETNAME_DELETED`, which yields an empty iterator.
    pub fn poll(
        &self,
        timeout: Option<Duration>,
        current_driver: Option<RawFd>,
    ) -> io::Result<impl Iterator<Item = RawEntry>> {
        let mut entries = Vec::with_capacity(Self::DEFAULT_CAPACITY);
        let len = match self.poll_raw(timeout, entries.spare_capacity_mut()) {
            Ok(len) => len,
            Err(e) if e.raw_os_error() == Some(ERROR_NETNAME_DELETED as _) => 0,
            Err(e) => return Err(e),
        };
        // SAFETY: `GetQueuedCompletionStatusEx` initialized exactly the first
        // `len` entries of the spare capacity, and `len` never exceeds the
        // number of entries that were offered to it.
        unsafe { entries.set_len(len) };
        Ok(entries.into_iter().filter_map(move |entry| {
            let overlapped_ptr: *mut Overlapped = entry.lpOverlapped.cast();
            // A packet without an OVERLAPPED (e.g. from `post_raw` with a null
            // pointer) carries no operation to complete. Dequeuing it has
            // already woken this thread, so drop it instead of dereferencing
            // a null pointer.
            // SAFETY: non-null pointers queued on this port are assumed to point
            // at live `Overlapped` values belonging to in-flight operations.
            let overlapped = unsafe { overlapped_ptr.as_ref() }?;
            if let Some(current_driver) = current_driver
                && overlapped.driver != current_driver
            {
                if let Err(_e) = syscall!(
                    BOOL,
                    PostQueuedCompletionStatus(
                        overlapped.driver as _,
                        entry.dwNumberOfBytesTransferred,
                        entry.lpCompletionKey,
                        entry.lpOverlapped,
                    )
                ) {
                    error!(
                        "fail to repost entry ({}, {}, {:p}) to driver {:p}: {:?}",
                        entry.dwNumberOfBytesTransferred,
                        entry.lpCompletionKey,
                        entry.lpOverlapped,
                        overlapped.driver,
                        _e
                    );
                }
                return None;
            }
            let status = overlapped.base.Internal as NTSTATUS;
            let res = if status >= 0 {
                Ok(overlapped.base.InternalHigh)
            } else {
                let error = unsafe { RtlNtStatusToDosError(status) };
                match error {
                    ERROR_IO_INCOMPLETE
                    | ERROR_NETNAME_DELETED
                    | ERROR_HANDLE_EOF
                    | ERROR_BROKEN_PIPE
                    | ERROR_PIPE_CONNECTED
                    | ERROR_PIPE_NOT_CONNECTED
                    | ERROR_NO_DATA => Ok(0),
                    _ => Err(io::Error::from_raw_os_error(error as _)),
                }
            };
            Some(RawEntry::new(overlapped_ptr, res))
        }))
    }
}
impl AsRawHandle for CompletionPort {
    /// Borrows the completion port's raw handle without transferring ownership;
    /// the handle remains owned (and eventually closed) by `self`.
    fn as_raw_handle(&self) -> RawHandle {
        let Self { port } = self;
        port.as_raw_handle()
    }
}
/// One dequeued completion: the `Overlapped` it refers to plus its decoded
/// result.
pub(crate) struct RawEntry {
    /// Pointer to the operation's `Overlapped`, as carried by the packet.
    pub overlapped: *mut Overlapped,
    /// Outcome of the operation. When produced by `CompletionPort::poll`, the
    /// success value is the byte count stored in `InternalHigh` (or 0 for the
    /// error codes that `poll` treats as benign).
    pub result: io::Result<usize>,
}

impl RawEntry {
    /// Pairs an overlapped pointer with its completion result.
    pub fn new(overlapped: *mut Overlapped, result: io::Result<usize>) -> Self {
        RawEntry { result, overlapped }
    }
}
/// Converts a Win32 error code into an `NTSTATUS`, following the construction
/// of the Windows SDK `NTSTATUS_FROM_WIN32` macro.
///
/// Values `<= 0` are returned unchanged. A positive code keeps only its low
/// 16 bits and is tagged with `FACILITY_NTWIN32` and error severity, which makes
/// the result negative (a failure status).
#[inline]
fn ntstatus_from_win32(x: i32) -> NTSTATUS {
    match x {
        ..=0 => x,
        code => {
            let severity = ERROR_SEVERITY_ERROR as NTSTATUS;
            let facility = (FACILITY_NTWIN32 << 16) as NTSTATUS;
            severity | facility | (code & 0xFFFF)
        }
    }
}