use std::{
cell::Cell,
marker::PhantomData,
pin::Pin,
};
use windows::Win32::{
Foundation::{
CloseHandle,
FALSE,
HANDLE,
INVALID_HANDLE_VALUE,
NTSTATUS,
WAIT_TIMEOUT,
WIN32_ERROR,
},
Networking::WinSock::SOCKET,
System::IO::{
CreateIoCompletionPort,
GetQueuedCompletionStatusEx,
OVERLAPPED,
OVERLAPPED_ENTRY,
},
};
use crate::{
catnap::transport::error::translate_ntstatus,
collections::pin_slab::PinSlab,
expect_some,
runtime::{
fail::Fail,
SharedConditionVariable,
},
};
/// Outcome of a completed overlapped operation, decoded from its `OVERLAPPED`.
// `Debug`/`PartialEq`/`Eq` make results printable in diagnostics and comparable in tests; every
// field type already supports them.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct OverlappedResult {
    /// Completion key recorded for the operation (the key from the dequeued completion packet).
    pub completion_key: usize,
    /// Raw status taken from `OVERLAPPED::Internal`; use [`OverlappedResult::ok`] to interpret it.
    pub result: NTSTATUS,
    /// Number of bytes transferred, taken from `OVERLAPPED::InternalHigh`.
    pub bytes_transferred: u32,
}
/// Record for one overlapped operation: the kernel-visible `OVERLAPPED` plus the bookkeeping used
/// to wake the waiting future and to free the record's slab slot.
///
/// `#[repr(C)]` with `overlapped` as the FIRST field means a pointer to the record and a pointer to
/// its `OVERLAPPED` share one address; `marshal`/`unmarshal` rely on this to recover the whole
/// record from the `OVERLAPPED*` returned by the completion port. Do not reorder the fields.
#[repr(C)]
struct OverlappedCompletion<S: Unpin> {
    // Handed to the kernel; its `Internal`/`InternalHigh` are read back as status and byte count.
    overlapped: OVERLAPPED,
    // `Some` while a `do_io` future may still be waiting; taken by `process_overlapped` before it
    // signals, so `None` tells the waiter the completion has been recorded.
    condition_variable: Option<SharedConditionVariable>,
    // Index of this record in the owning `PinSlab`, or `None` if it does not live in one.
    pinslab_index: Option<usize>,
    // Completion key copied from the dequeued packet by `process_overlapped`.
    completion_key: usize,
    // Caller-supplied per-operation state (e.g. an I/O buffer), kept at a stable address.
    state: S,
}
/// Owner of a Win32 I/O completion port and of the pinned storage for every operation issued
/// through it with `do_io`.
pub struct IoCompletionPort<S: Unpin> {
    // Completion-port handle; closed in this type's `Drop` impl.
    iocp: HANDLE,
    // Pinned slots holding each operation's `OVERLAPPED` and state at a fixed address while the
    // kernel may still write to them.
    ops: PinSlab<OverlappedCompletion<S>>,
    // `Cell<()>` is `!Sync`, so this marker makes the port `!Sync` as well.
    _marker: PhantomData<Cell<()>>,
}
impl OverlappedResult {
pub fn new(overlapped: OVERLAPPED, completion_key: usize) -> OverlappedResult {
Self {
completion_key,
result: NTSTATUS(overlapped.Internal as i32),
bytes_transferred: overlapped.InternalHigh as u32,
}
}
pub fn ok(&self) -> Result<(), Fail> {
let win32_error: WIN32_ERROR = translate_ntstatus(self.result);
if win32_error.is_ok() {
Ok(())
} else {
Err(win32_error.into())
}
}
}
impl<S: Unpin> IoCompletionPort<S> {
    /// Creates a new completion port that is not yet associated with any file handle.
    ///
    /// # Errors
    ///
    /// Returns the converted Win32 error if `CreateIoCompletionPort` fails, or `EFAULT` if it
    /// reports success but hands back an invalid handle.
    pub fn new() -> Result<IoCompletionPort<S>, Fail> {
        // SAFETY: `INVALID_HANDLE_VALUE` with no existing port asks for a brand-new port; no
        // caller memory is passed. The last argument allows one concurrently running thread.
        let iocp: HANDLE = match unsafe { CreateIoCompletionPort(INVALID_HANDLE_VALUE, None, 0, 1) } {
            Ok(handle) => handle,
            Err(err) => return Err(err.into()),
        };
        if iocp.is_invalid() {
            const MSG: &str = "CreateIoCompletionPort succeeded but returned handle is invalid";
            error!("{}", MSG);
            return Err(Fail::new(libc::EFAULT, MSG));
        }
        Ok(IoCompletionPort {
            iocp,
            ops: PinSlab::<OverlappedCompletion<S>>::default(),
            _marker: PhantomData,
        })
    }

    /// Associates socket `s` with this port so that its overlapped completions are queued here
    /// tagged with `completion_key`.
    // Silences dead-code warnings in builds where nothing calls this.
    #[allow(unused)]
    pub fn associate_socket(&self, s: SOCKET, completion_key: usize) -> Result<(), Fail> {
        self.associate_handle(HANDLE(s.0 as isize), completion_key)
    }

    /// Associates `handle` with this port so that its overlapped completions are queued here
    /// tagged with `completion_key`.
    ///
    /// # Errors
    ///
    /// Returns the converted Win32 error if the association fails.
    pub fn associate_handle(&self, handle: HANDLE, completion_key: usize) -> Result<(), Fail> {
        // SAFETY: both handles are passed by value; the concurrency argument is ignored when an
        // existing port is supplied.
        match unsafe { CreateIoCompletionPort(handle, self.iocp, completion_key, 0) } {
            Ok(_) => Ok(()),
            Err(err) => Err(err.into()),
        }
    }

    /// Runs one overlapped operation and returns what `finish` produces.
    ///
    /// `state` is moved into a slab slot next to the `OVERLAPPED` so both keep a fixed address
    /// while the kernel may use them. `start` receives the pinned state and the `OVERLAPPED`
    /// pointer and must issue the operation, returning `Ok(())` whenever a completion packet will
    /// be queued (including the pending case). Once [`Self::process_events`] dequeues that packet,
    /// `finish` receives the pinned state and the decoded [`OverlappedResult`].
    ///
    /// The slot is released on every exit path. If this future is dropped while the operation is
    /// still outstanding, the slot is kept alive (the kernel may still write into it) and is freed
    /// by `process_overlapped` when the completion is dequeued.
    ///
    /// # Errors
    ///
    /// `EINVAL` if no slab slot is available; otherwise whatever `start` or `finish` return.
    ///
    /// # Safety
    ///
    /// NOTE(review): the original does not spell out the contract. Presumably the caller must
    /// pass the `OVERLAPPED` pointer to at most one operation on a handle associated with this
    /// port, so that its completion is eventually dequeued by `process_events` on this port.
    pub async unsafe fn do_io<F1, F2, R>(&mut self, state: S, start: F1, finish: F2) -> Result<R, Fail>
    where
        for<'a> F1: FnOnce(Pin<&'a mut S>, *mut OVERLAPPED) -> Result<(), Fail>,
        for<'a> F2: FnOnce(Pin<&'a mut S>, OverlappedResult) -> Result<R, Fail>,
    {
        let pinslab_index: usize = match self.ops.insert(OverlappedCompletion::new(state)) {
            Some(index) => index,
            None => {
                return Err(Fail::new(
                    libc::EINVAL,
                    "Could not allocate space for overlapped completion",
                ))
            },
        };
        // From here on the guard owns the slot, so it is cleaned up even if this future is dropped
        // at the `.await` below.
        let mut operation: InFlightOperation<'_, S> = InFlightOperation {
            ops: &mut self.ops,
            index: pinslab_index,
            io_outstanding: false,
        };
        operation.completion().set_pinslab_index(pinslab_index);
        let overlapped: *mut OVERLAPPED = operation.completion().marshal();
        start(operation.completion().get_state(), overlapped)?;
        // The kernel may now write to the slot until its completion packet is dequeued.
        operation.io_outstanding = true;
        // `process_overlapped` takes the condition variable before signalling, so `None` means the
        // completion (status, byte count, key) has been recorded in the slot.
        while let Some(mut cv) = operation.condition_variable() {
            cv.wait().await;
        }
        operation.io_outstanding = false;
        let overlapped_result: OverlappedResult = {
            let completion: Pin<&mut OverlappedCompletion<S>> = operation.completion();
            OverlappedResult::new(completion.overlapped, completion.completion_key)
        };
        let pinned_state: Pin<&mut S> = operation.completion().get_state();
        finish(pinned_state, overlapped_result)
    }

    /// Dispatches one dequeued completion packet to the record that owns its `OVERLAPPED`.
    ///
    /// Packets without an `OVERLAPPED` are ignored. If the record still holds its condition
    /// variable, the completion key is stored and the waiter is signalled. Otherwise the waiting
    /// `do_io` future was dropped, and the record's slab slot (if it has one) is released.
    // Only reached through `process_events`, which may itself be unused in some builds.
    #[allow(unused)]
    fn process_overlapped(&mut self, entry: &OVERLAPPED_ENTRY) {
        if let Some(overlapped) = std::ptr::NonNull::new(entry.lpOverlapped) {
            // NOTE(review): assumes every OVERLAPPED queued to this port was produced by
            // `OverlappedCompletion::marshal`; a foreign OVERLAPPED would be misinterpreted.
            let mut overlapped: Pin<&mut OverlappedCompletion<S>> = OverlappedCompletion::unmarshal(overlapped);
            if let Some(mut cv) = overlapped.as_mut().take_cv() {
                debug_assert!(entry.dwNumberOfBytesTransferred as usize == overlapped.as_ref().overlapped.InternalHigh);
                overlapped.as_mut().get_mut().completion_key = entry.lpCompletionKey;
                cv.signal();
            } else {
                trace!("I/O dropped for completion key {}", entry.lpCompletionKey);
                if let Some(pinslab_index) = overlapped.as_mut().get_pinslab_index() {
                    self.ops.remove_unpin(pinslab_index);
                }
            }
        }
    }

    /// Drains every completion currently queued on the port without blocking.
    ///
    /// Packets are fetched in batches of `BATCH_SIZE` with a zero timeout; the loop stops after a
    /// short batch or when the port reports `WAIT_TIMEOUT` (nothing queued).
    ///
    /// # Errors
    ///
    /// Any other failure of `GetQueuedCompletionStatusEx`.
    // Silences dead-code warnings in builds where nothing calls this.
    #[allow(unused)]
    pub fn process_events(&mut self) -> Result<(), Fail> {
        const BATCH_SIZE: usize = 4;
        let mut entries: [OVERLAPPED_ENTRY; BATCH_SIZE] = [OVERLAPPED_ENTRY::default(); BATCH_SIZE];
        loop {
            let mut dequeued: u32 = 0;
            // SAFETY: `entries` and `dequeued` are live locals for the duration of the call; the
            // zero timeout keeps the call non-blocking.
            match unsafe { GetQueuedCompletionStatusEx(self.iocp, entries.as_mut_slice(), &mut dequeued, 0, FALSE) } {
                Ok(()) => {
                    for entry in &entries[..dequeued as usize] {
                        self.process_overlapped(entry);
                    }
                    // A short batch means the queue is empty.
                    if dequeued < BATCH_SIZE as u32 {
                        return Ok(());
                    }
                },
                Err(err) if err.code() == WIN32_ERROR(WAIT_TIMEOUT.0).into() => return Ok(()),
                Err(err) => return Err(Fail::from_win32_error(&err, true)),
            }
        }
    }
}

/// Scope guard owning the slab slot of one `do_io` call.
///
/// On drop it frees the slot, except while the kernel may still use it (`io_outstanding` is set
/// and no completion has been recorded yet). In that case it only takes the condition variable,
/// which marks the operation as abandoned so `process_overlapped` frees the slot once the
/// completion is dequeued.
struct InFlightOperation<'a, S: Unpin> {
    ops: &'a mut PinSlab<OverlappedCompletion<S>>,
    index: usize,
    io_outstanding: bool,
}

impl<S: Unpin> InFlightOperation<'_, S> {
    /// Pinned access to the guarded record.
    fn completion(&mut self) -> Pin<&mut OverlappedCompletion<S>> {
        expect_some!(self.ops.get_pin_mut(self.index), "slot is owned by this guard")
    }

    /// Clone of the record's condition variable, or `None` once a completion has been recorded.
    fn condition_variable(&mut self) -> Option<SharedConditionVariable> {
        self.completion().as_ref().get_cv()
    }
}

impl<S: Unpin> Drop for InFlightOperation<'_, S> {
    fn drop(&mut self) {
        // Completion still pending: leave the slot alive for the kernel and let
        // `process_overlapped` free it once the packet arrives.
        if self.io_outstanding && self.completion().take_cv().is_some() {
            return;
        }
        self.ops.remove_unpin(self.index);
    }
}
impl<S: Unpin> OverlappedCompletion<S> {
    /// Wraps `state` in a fresh record with a default `OVERLAPPED`, a new condition variable for
    /// the waiter, no slab index, and a zero completion key.
    pub fn new(state: S) -> Self {
        Self {
            overlapped: OVERLAPPED::default(),
            condition_variable: Some(SharedConditionVariable::default()),
            pinslab_index: None,
            completion_key: 0,
            state,
        }
    }

    /// Removes the condition variable and returns it, leaving `None` behind.
    pub fn take_cv(self: Pin<&mut Self>) -> Option<SharedConditionVariable> {
        let record: &mut Self = Pin::into_inner(self);
        record.condition_variable.take()
    }

    /// Returns a clone of the condition variable without clearing it.
    pub fn get_cv(self: Pin<&Self>) -> Option<SharedConditionVariable> {
        let record: &Self = self.get_ref();
        record.condition_variable.clone()
    }

    /// Stores the index of this record in its owning `PinSlab`.
    pub fn set_pinslab_index(self: Pin<&mut Self>, index: usize) {
        Pin::into_inner(self).pinslab_index = Some(index);
    }

    /// Returns the slab index stored by `set_pinslab_index`, if any.
    pub fn get_pinslab_index(self: Pin<&mut Self>) -> Option<usize> {
        self.as_ref().get_ref().pinslab_index
    }

    /// Returns this record's address typed as an `OVERLAPPED*`, suitable for handing to Win32.
    ///
    /// Because the struct is `#[repr(C)]` with `overlapped` first, the same address also denotes
    /// the whole record, which is what lets `unmarshal` recover it.
    pub fn marshal(self: Pin<&mut Self>) -> *mut OVERLAPPED {
        // Derived from the whole record (not just the `overlapped` field) so the pointer keeps
        // access to every field that `unmarshal` later reaches through it.
        let record: *mut Self = Pin::into_inner(self);
        record.cast::<OVERLAPPED>()
    }

    /// Reinterprets an `OVERLAPPED*` produced by `marshal` as the record that contains it.
    ///
    /// NOTE(review): this safe fn has unchecked preconditions. `overlapped` must come from
    /// `marshal` on a record that is still alive and not otherwise borrowed, and the lifetime `'a`
    /// is chosen freely by the caller.
    pub fn unmarshal<'a>(overlapped: std::ptr::NonNull<OVERLAPPED>) -> Pin<&'a mut Self> {
        let record: *mut Self = overlapped.cast::<Self>().as_ptr();
        // SAFETY: relies on the caller upholding the preconditions documented above.
        unsafe { Pin::new_unchecked(&mut *record) }
    }

    /// Projects the pinned record onto its caller-supplied state.
    pub fn get_state(self: Pin<&mut Self>) -> Pin<&mut S> {
        // `S: Unpin`, so pinning the field again needs no unsafe code.
        let record: &mut Self = Pin::into_inner(self);
        Pin::new(&mut record.state)
    }
}
impl<S: Unpin> Drop for IoCompletionPort<S> {
    /// Closes the completion-port handle owned by this port.
    fn drop(&mut self) {
        // SAFETY: `self.iocp` is the handle this port owns, and nothing can use it after drop.
        // The result is discarded on purpose: `drop` has no way to report a failed close.
        std::mem::drop(unsafe { CloseHandle(self.iocp) });
    }
}
// Tests for the completion-port wrapper: OVERLAPPED marshalling, completion dispatch, and real
// overlapped I/O on named pipes driven through a `SharedDemiRuntime`.
#[cfg(test)]
mod tests {
    use crate::{
        ensure_eq,
        runtime::{
            conditional_yield_with_timeout,
            SharedDemiRuntime,
        },
        OperationResult,
        QDesc,
        QToken,
    };
    use futures::pin_mut;
    use std::{
        cell::UnsafeCell,
        iter,
        ptr::NonNull,
        rc::Rc,
        sync::atomic::{
            AtomicU32,
            Ordering,
        },
        task::{
            Context,
            Poll,
        },
        time::{
            Duration,
            Instant,
        },
    };

    use super::*;
    use ::futures::FutureExt;
    use anyhow::{
        anyhow,
        bail,
        ensure,
        Result,
    };
    use futures::Future;
    use windows::{
        core::{
            s,
            HRESULT,
            PCSTR,
        },
        Win32::{
            Foundation::{
                ERROR_IO_PENDING,
                GENERIC_WRITE,
            },
            Storage::FileSystem::{
                CreateFileA,
                ReadFile,
                WriteFile,
                FILE_FLAGS_AND_ATTRIBUTES,
                FILE_FLAG_FIRST_PIPE_INSTANCE,
                FILE_FLAG_OVERLAPPED,
                FILE_SHARE_NONE,
                OPEN_EXISTING,
                PIPE_ACCESS_DUPLEX,
            },
            System::{
                Pipes::{
                    ConnectNamedPipe,
                    CreateNamedPipeA,
                    PIPE_READMODE_MESSAGE,
                    PIPE_REJECT_REMOTE_CLIENTS,
                    PIPE_TYPE_MESSAGE,
                },
                IO::PostQueuedCompletionStatus,
            },
        },
    };

    // Owning wrapper that closes the wrapped Win32 handle when dropped.
    struct SafeHandle(HANDLE);

    // Future adapter that forwards only the first poll to `future`; every later poll resolves to
    // an EALREADY failure. NOTE(review): not referenced by any test in this module.
    struct PollOnceFuture<F: Future<Output = (QDesc, OperationResult)>> {
        future: F,
        count: usize,
    }

    impl<F: Future<Output = (QDesc, OperationResult)>> Future for PollOnceFuture<F> {
        type Output = (QDesc, OperationResult);

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<(QDesc, OperationResult)> {
            if self.count == 0 {
                unsafe { self.as_mut().get_unchecked_mut() }.count += 1;
                // Structural pin projection onto the wrapped future.
                unsafe { self.map_unchecked_mut(|s| &mut s.future) }.poll(cx)
            } else {
                Poll::Ready((
                    QDesc::from(0),
                    OperationResult::Failed(Fail::new(libc::EALREADY, "called more than once")),
                ))
            }
        }
    }

    impl Drop for SafeHandle {
        fn drop(&mut self) {
            // Best-effort close; failures are ignored in tests.
            let _ = unsafe { CloseHandle(self.0) };
        }
    }

    // Creates a completion port, converting the `Fail` into an `anyhow` error.
    fn make_iocp<S: Unpin>() -> Result<IoCompletionPort<S>> {
        IoCompletionPort::new().map_err(|err: Fail| anyhow!("Failed to create I/O completion port: {}", err))
    }

    // Queues a synthetic completion packet (0 bytes, `completion_key`, `overlapped`) on `iocp`.
    fn post_completion<S: Unpin>(
        iocp: &IoCompletionPort<S>,
        overlapped: *const OVERLAPPED,
        completion_key: usize,
    ) -> Result<()> {
        unsafe { PostQueuedCompletionStatus(iocp.iocp, 0, completion_key, Some(overlapped)) }
            .map_err(|err| anyhow!("PostQueuedCompletionStatus failed: {}", err))
    }

    // Wraps an `anyhow` error in a `Fail` with errno EFAULT.
    fn anyhow_fail(error: anyhow::Error) -> Fail {
        Fail::new(libc::EFAULT, error.to_string().as_str())
    }

    // Maps the result of an overlapped Win32 call to `Result<(), Fail>`, treating
    // ERROR_IO_PENDING (operation started, completion will be queued) as success.
    fn is_overlapped_ok(result: Result<(), windows::core::Error>) -> Result<(), Fail> {
        if let Err(err) = result {
            if err.code() == HRESULT::from(ERROR_IO_PENDING) {
                Ok(())
            } else {
                Err(err.into())
            }
        } else {
            Ok(())
        }
    }

    // Adapts a `Result<OperationResult, Fail>` future into the `(QDesc, OperationResult)` shape
    // these tests pass to `insert_io_coroutine`; errors become `OperationResult::Failed` and the
    // descriptor is always 0.
    async fn run_as_io_op<F: Future<Output = Result<OperationResult, Fail>>>(future: F) -> (QDesc, OperationResult) {
        match future.await {
            Ok(result) => (QDesc::from(0), result),
            Err(err) => (QDesc::from(0), OperationResult::Failed(err)),
        }
    }

    // `marshal` must return an address equal to both the `overlapped` field and the record
    // itself, and `unmarshal` must round-trip that address back to the same record.
    #[test]
    fn test_marshal_unmarshal() -> Result<()> {
        let mut completion: Pin<Box<OverlappedCompletion<()>>> = Box::pin(OverlappedCompletion::new(()));
        ensure_eq!(
            completion.as_mut().marshal() as *const OVERLAPPED as usize,
            &completion.overlapped as *const OVERLAPPED as usize
        );
        ensure_eq!(
            completion.as_mut().marshal() as usize,
            completion.as_ref().get_ref() as *const OverlappedCompletion<()> as usize
        );
        ensure_eq!(
            (&completion.overlapped as *const OVERLAPPED) as usize,
            (completion.as_ref().get_ref() as *const OverlappedCompletion<()>) as usize
        );
        let overlapped_ptr: NonNull<OVERLAPPED> = NonNull::new(completion.as_mut().marshal()).unwrap();
        let unmarshalled: Pin<&mut OverlappedCompletion<()>> = OverlappedCompletion::unmarshal(overlapped_ptr);
        ensure_eq!(
            unmarshalled.as_ref().get_ref() as *const OverlappedCompletion<()> as usize,
            completion.as_ref().get_ref() as *const OverlappedCompletion<()> as usize
        );
        Ok(())
    }

    // A coroutine waits on a record's condition variable. Posting a completion for that record
    // and calling `process_events` must clear the condition variable, record the completion key,
    // and let the coroutine finish.
    #[test]
    fn test_event_processor() -> Result<()> {
        const COMPLETION_KEY: usize = 123;
        let mut iocp: IoCompletionPort<()> = make_iocp()?;
        let overlapped: OverlappedCompletion<()> = OverlappedCompletion::new(());
        let mut cv: SharedConditionVariable = overlapped.condition_variable.clone().unwrap();
        // Pinned on the stack: this record is not in the port's slab (its `pinslab_index` is None).
        pin_mut!(overlapped);
        let mut runtime: SharedDemiRuntime = SharedDemiRuntime::default();
        let server = run_as_io_op(async move {
            cv.wait().await;
            Ok(OperationResult::Close)
        })
        .fuse();
        let server_task: QToken = runtime.insert_io_coroutine("ioc_server", Box::pin(server)).unwrap();
        // The coroutine must block on the condition variable until the completion is processed.
        ensure!(runtime.run_any(&[server_task], Duration::ZERO).is_none());
        post_completion(&iocp, overlapped.as_mut().marshal(), COMPLETION_KEY)?;
        iocp.process_events()?;
        ensure!(
            overlapped.condition_variable.is_none(),
            "yielder should be cleared by iocp"
        );
        ensure_eq!(
            overlapped.as_ref().completion_key,
            COMPLETION_KEY,
            "completion key not updated"
        );
        ensure!(runtime.run_any(&[server_task], Duration::ZERO).is_some());
        Ok(())
    }

    // End-to-end overlapped I/O: the server coroutine accepts a connection with ConnectNamedPipe
    // and then reads one message with ReadFile, both through `do_io`; a synchronous client writes
    // MESSAGE and the server checks it received exactly that.
    #[test]
    fn test_overlapped_named_pipe() -> Result<()> {
        const MESSAGE: &str = "Hello world!";
        const PIPE_NAME: PCSTR = s!(r"\\.\pipe\demikernel-test-pipe");
        const BUFFER_SIZE: u32 = 128;
        const COMPLETION_KEY: usize = 0xFEEDF00D;
        let server_pipe: SafeHandle = SafeHandle(unsafe {
            CreateNamedPipeA(
                PIPE_NAME,
                PIPE_ACCESS_DUPLEX | FILE_FLAG_FIRST_PIPE_INSTANCE | FILE_FLAG_OVERLAPPED,
                PIPE_TYPE_MESSAGE | PIPE_READMODE_MESSAGE | PIPE_REJECT_REMOTE_CLIENTS,
                1,
                BUFFER_SIZE,
                BUFFER_SIZE,
                0,
                None,
            )
        }?);
        // Progress counter bumped by the server coroutine: 1 = ConnectNamedPipe issued,
        // 2 = connection completed. The test body polls it through `server_state_view`.
        let server_state: Rc<AtomicU32> = Rc::new(AtomicU32::new(0));
        let server_state_view: Rc<AtomicU32> = server_state.clone();
        // The port sits in an UnsafeCell so the coroutine (via `iocp_ref`) and the driver loop
        // (via `iocp.get_mut()`) can both reach it. NOTE(review): this produces aliasing `&mut`
        // references and relies on the single-threaded interleaving of the runtime.
        let mut iocp: UnsafeCell<IoCompletionPort<Rc<Vec<u8>>>> = UnsafeCell::new(make_iocp().map_err(anyhow_fail)?);
        iocp.get_mut().associate_handle(server_pipe.0, COMPLETION_KEY)?;
        let iocp_ref: &mut IoCompletionPort<Rc<Vec<u8>>> = unsafe { &mut *iocp.get() };
        let server = Box::pin(
            run_as_io_op(async move {
                // Step 1: wait for a client to connect.
                unsafe {
                    iocp_ref.do_io(
                        Rc::new(Vec::<u8>::new()),
                        |_: Pin<&mut Rc<Vec<u8>>>, overlapped: *mut OVERLAPPED| -> Result<(), Fail> {
                            server_state.fetch_add(1, Ordering::Relaxed);
                            is_overlapped_ok(ConnectNamedPipe(server_pipe.0, Some(overlapped)))
                        },
                        |_: Pin<&mut Rc<Vec<u8>>>, result: OverlappedResult| -> Result<(), Fail> { result.ok() },
                    )
                }
                .await?;
                server_state.fetch_add(1, Ordering::Relaxed);
                // Step 2: read one message into a BUFFER_SIZE buffer owned by the do_io state.
                let mut buffer: Rc<Vec<u8>> =
                    Rc::new(iter::repeat(0u8).take(BUFFER_SIZE as usize).collect::<Vec<u8>>());
                buffer = unsafe {
                    iocp_ref.do_io(
                        buffer,
                        |state: Pin<&mut Rc<Vec<u8>>>, overlapped: *mut OVERLAPPED| -> Result<(), Fail> {
                            let vec: &mut Vec<u8> = Rc::get_mut(state.get_mut()).unwrap();
                            vec.resize(BUFFER_SIZE as usize, 0u8);
                            is_overlapped_ok(ReadFile(
                                server_pipe.0,
                                Some(vec.as_mut_slice()),
                                None,
                                Some(overlapped),
                            ))
                        },
                        |mut state: Pin<&mut Rc<Vec<u8>>>, result: OverlappedResult| -> Result<Rc<Vec<u8>>, Fail> {
                            match result.ok() {
                                Ok(()) => {
                                    // NOTE(review): message reads "not bytes received"; likely
                                    // intended "no bytes received".
                                    if result.bytes_transferred == 0 {
                                        Err(Fail::new(libc::EINVAL, "not bytes received"))
                                    } else {
                                        // Trim the buffer to the bytes actually received.
                                        Rc::get_mut(state.as_mut().get_mut())
                                            .unwrap()
                                            .resize(result.bytes_transferred as usize, 0u8);
                                        Ok(Rc::clone(Pin::get_mut(state)))
                                    }
                                },
                                Err(fail) => Err(fail),
                            }
                        },
                    )
                }
                .await?;
                let message: &str = std::str::from_utf8(buffer.as_slice())
                    .map_err(|_| Fail::new(libc::EINVAL, "utf8 conversion failed"))?;
                if message != MESSAGE {
                    let err_msg: String = format!("expected \"{}\", got \"{}\"", MESSAGE, message);
                    Err(Fail::new(libc::EINVAL, err_msg.as_str()))
                } else {
                    Ok(OperationResult::Close)
                }
            })
            .fuse(),
        );
        let mut runtime: SharedDemiRuntime = SharedDemiRuntime::default();
        let server_task: QToken = runtime.insert_io_coroutine("ioc_server", server).unwrap();
        // Drives completions and the runtime until the server reaches `state`; any completion of
        // the server task before then is reported as an error.
        let mut wait_for_state = |state| -> Result<(), Fail> {
            while server_state_view.load(Ordering::Relaxed) < state {
                iocp.get_mut().process_events()?;
                if let Some(result) = runtime.run_any(&[server_task], Duration::ZERO) {
                    return match result {
                        (_, _, OperationResult::Failed(e)) => Err(e),
                        _ => Err(Fail::new(libc::EFAULT, "server completed early unexpectedly")),
                    };
                }
            }
            Ok(())
        };
        // Server has issued ConnectNamedPipe; now connect a synchronous client.
        wait_for_state(1)?;
        let client_handle: SafeHandle = SafeHandle(unsafe {
            CreateFileA(
                PIPE_NAME,
                GENERIC_WRITE.0,
                FILE_SHARE_NONE,
                None,
                OPEN_EXISTING,
                FILE_FLAGS_AND_ATTRIBUTES::default(),
                HANDLE(0),
            )
        }?);
        // Server has observed the connection; write the message and close the client.
        wait_for_state(2)?;
        let mut bytes_written: u32 = 0;
        unsafe {
            WriteFile(
                client_handle.0,
                Some(MESSAGE.as_bytes()),
                Some(&mut bytes_written),
                None,
            )?;
        }
        std::mem::drop(client_handle);
        // Pump completions until the server task finishes, then check it reported success.
        let result: OperationResult = loop {
            iocp.get_mut().process_events()?;
            runtime.poll();
            if let Some((_, result)) = runtime.get_completed_task(&server_task) {
                break result;
            }
        };
        match result {
            OperationResult::Close => Ok(()),
            _ => bail!("server did not complete successfully"),
        }
    }

    // No client ever connects, so the ConnectNamedPipe wrapped in a 100µs
    // `conditional_yield_with_timeout` must be abandoned and the coroutine must fail with
    // ETIMEDOUT.
    #[test]
    fn test_cancel_io() -> Result<()> {
        const PIPE_NAME: PCSTR = s!(r"\\.\pipe\demikernel-test-cancel-pipe");
        const BUFFER_SIZE: u32 = 128;
        const COMPLETION_KEY: usize = 0xFEEDF00D;
        let server_pipe: SafeHandle = SafeHandle(unsafe {
            CreateNamedPipeA(
                PIPE_NAME,
                PIPE_ACCESS_DUPLEX | FILE_FLAG_FIRST_PIPE_INSTANCE | FILE_FLAG_OVERLAPPED,
                PIPE_TYPE_MESSAGE | PIPE_READMODE_MESSAGE | PIPE_REJECT_REMOTE_CLIENTS,
                1,
                BUFFER_SIZE,
                BUFFER_SIZE,
                0,
                None,
            )
        }?);
        let server_state: Rc<AtomicU32> = Rc::new(AtomicU32::new(0));
        let server_state_view: Rc<AtomicU32> = server_state.clone();
        // Same UnsafeCell aliasing scheme as `test_overlapped_named_pipe`.
        let mut iocp: UnsafeCell<IoCompletionPort<()>> = UnsafeCell::new(make_iocp().map_err(anyhow_fail)?);
        iocp.get_mut().associate_handle(server_pipe.0, COMPLETION_KEY)?;
        let iocp_ref: &mut IoCompletionPort<()> = unsafe { &mut *iocp.get() };
        let server = run_as_io_op(async move {
            match conditional_yield_with_timeout(
                unsafe {
                    iocp_ref.do_io(
                        (),
                        |_: Pin<&mut ()>, overlapped: *mut OVERLAPPED| -> Result<(), Fail> {
                            server_state.fetch_add(1, Ordering::Relaxed);
                            is_overlapped_ok(ConnectNamedPipe(server_pipe.0, Some(overlapped)))
                        },
                        |_: Pin<&mut ()>, result: OverlappedResult| -> Result<(), Fail> { result.ok() },
                    )
                },
                Duration::from_micros(100),
            )
            .await
            {
                Ok(_) => Ok(OperationResult::Close),
                Err(e) => Ok(OperationResult::Failed(e)),
            }
        })
        .fuse();
        let mut runtime: SharedDemiRuntime = SharedDemiRuntime::default();
        let server_task: QToken = runtime.insert_io_coroutine("ioc_server", Box::pin(server)).unwrap();
        // Inserting the coroutine must not run it.
        ensure!(
            server_state_view.load(Ordering::Relaxed) < 1,
            "server execution should not start yet"
        );
        let iocp_ref: &mut IoCompletionPort<()> = unsafe { &mut *iocp.get() };
        iocp_ref.process_events()?;
        ensure!(
            runtime.run_any(&[server_task], Duration::ZERO).is_none(),
            "server should not be done"
        );
        // Advance the runtime clock to real time until the timeout fires and the task completes.
        let result: OperationResult = loop {
            runtime.advance_clock(Instant::now());
            iocp.get_mut().process_events()?;
            if let Some((i, _, result)) = runtime.run_any(&[server_task], Duration::ZERO) {
                ensure_eq!(i, 0);
                break result;
            }
        };
        match result {
            OperationResult::Failed(Fail { errno, cause: _ }) if errno == libc::ETIMEDOUT => Ok(()),
            OperationResult::Failed(e) => bail!("coroutine failed with unexpected code: {:?}", e),
            _ => bail!("expected coroutine to fail"),
        }
    }
}