use std::{
collections::{HashMap, HashSet},
io,
mem::ManuallyDrop,
os::{
raw::c_void,
windows::prelude::{
AsRawHandle, AsRawSocket, FromRawHandle, FromRawSocket, IntoRawHandle, IntoRawSocket,
RawHandle,
},
},
pin::Pin,
ptr::{null, NonNull},
sync::Arc,
task::Poll,
time::Duration,
};
use compio_buf::BufResult;
use compio_log::{instrument, trace};
use slab::Slab;
use windows_sys::Win32::{
Foundation::{ERROR_BUSY, ERROR_OPERATION_ABORTED, ERROR_TIMEOUT, WAIT_OBJECT_0, WAIT_TIMEOUT},
Networking::WinSock::{WSACleanup, WSAStartup, WSADATA},
System::{
Threading::{
CloseThreadpoolWait, CreateThreadpoolWait, SetThreadpoolWait,
WaitForThreadpoolWaitCallbacks, PTP_CALLBACK_INSTANCE, PTP_WAIT,
},
IO::OVERLAPPED,
},
};
use crate::{syscall, AsyncifyPool, Entry, OutEntries, ProactorBuilder};
pub(crate) mod op;
mod cp;
pub(crate) use windows_sys::Win32::Networking::WinSock::{
socklen_t, SOCKADDR_STORAGE as sockaddr_storage,
};
/// Raw OS handle type used throughout this driver (files, sockets, and the
/// completion port itself).
///
/// Sockets are carried in this type by casting their `RawSocket` value with `as`.
pub type RawFd = RawHandle;
/// Borrows the raw OS handle behind a value (named after the Unix trait of the
/// same name so driver code can stay platform-neutral).
pub trait AsRawFd {
    /// Returns the raw handle; ownership stays with `self`.
    fn as_raw_fd(&self) -> RawFd;
}

/// Builds a value that takes ownership of a raw OS handle.
pub trait FromRawFd {
    /// Wraps `raw` in `Self`.
    ///
    /// # Safety
    ///
    /// `raw` must be an open handle of the kind `Self` expects, and the caller must
    /// own it: implementations such as the one for `std::fs::File` take ownership
    /// and will close it.
    unsafe fn from_raw_fd(raw: RawFd) -> Self;
}

/// Consumes a value and hands back ownership of its raw OS handle.
pub trait IntoRawFd {
    /// Returns the handle; the caller becomes responsible for closing it.
    fn into_raw_fd(self) -> RawFd;
}
// `std::fs::File` already works in raw handles, so its conversions forward as-is.
impl AsRawFd for std::fs::File {
    fn as_raw_fd(&self) -> RawFd {
        AsRawHandle::as_raw_handle(self)
    }
}

impl FromRawFd for std::fs::File {
    unsafe fn from_raw_fd(fd: RawFd) -> Self {
        <Self as FromRawHandle>::from_raw_handle(fd)
    }
}

impl IntoRawFd for std::fs::File {
    fn into_raw_fd(self) -> RawFd {
        IntoRawHandle::into_raw_handle(self)
    }
}

// Sockets are identified by an integer `RawSocket`; it is reinterpreted as a
// handle value (and back) with plain `as` casts.
impl AsRawFd for socket2::Socket {
    fn as_raw_fd(&self) -> RawFd {
        AsRawSocket::as_raw_socket(self) as RawFd
    }
}

impl FromRawFd for socket2::Socket {
    unsafe fn from_raw_fd(fd: RawFd) -> Self {
        let raw_socket = fd as std::os::windows::io::RawSocket;
        <Self as FromRawSocket>::from_raw_socket(raw_socket)
    }
}

impl IntoRawFd for socket2::Socket {
    fn into_raw_fd(self) -> RawFd {
        IntoRawSocket::into_raw_socket(self) as RawFd
    }
}
/// How `Driver::push` submits an operation.
pub enum OpType {
    /// `OpCode::operate` is called directly on the driver thread. Its result comes
    /// back immediately (`Poll::Ready`) or, if `Pending`, is expected to arrive later
    /// as a completion on the port.
    Overlapped,
    /// The op is run synchronously on the driver's thread pool, and its result is
    /// posted to the completion port when it finishes.
    Blocking,
    /// A thread-pool wait is registered on the given handle. When the wait succeeds
    /// the op is run on the pool thread and its result is posted to the port.
    Event(RawFd),
}
/// An operation the IOCP driver knows how to submit.
///
/// Callers in this module pass `optr` as the `OVERLAPPED` header (the first field
/// of the `#[repr(C)]` `Overlapped`) of the allocation that owns `self`.
pub trait OpCode {
    /// Chooses the submission strategy; the default is [`OpType::Overlapped`].
    fn op_type(&self) -> OpType {
        OpType::Overlapped
    }

    /// Starts the operation (overlapped ops) or performs it outright (blocking and
    /// event ops, which must not return `Poll::Pending`).
    ///
    /// # Safety
    ///
    /// `optr` must point to the `OVERLAPPED` header of the `Overlapped` that owns
    /// `self`, and that allocation must remain valid while the OS may use it.
    unsafe fn operate(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll<io::Result<usize>>;

    /// Asks an in-flight operation to stop. The default does nothing and reports
    /// success.
    ///
    /// # Safety
    ///
    /// Same pointer requirements as [`OpCode::operate`].
    unsafe fn cancel(self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> io::Result<()> {
        let _ = optr;
        Ok(())
    }
}
/// IOCP-based proactor driver.
// Fields are dropped in declaration order after `Drop::drop` runs; keep the order
// stable unless the teardown sequence has been re-checked.
pub(crate) struct Driver {
    /// The I/O completion port that every operation reports to.
    port: cp::Port,
    /// Active thread-pool waits of `OpType::Event` ops, keyed by `user_data`.
    waits: HashMap<usize, WinThreadpollWait>,
    /// `user_data` of ops for which `cancel` was called and whose aborted status
    /// has not yet been reported by `push` or `poll`.
    cancelled: HashSet<usize>,
    /// Thread pool that runs `OpType::Blocking` ops.
    pool: AsyncifyPool,
    /// Packet posted to the port by `NotifyHandle::notify`. Completions carrying its
    /// `Driver::NOTIFY` user data are dropped by `poll`.
    notify_overlapped: Arc<Overlapped<()>>,
}
impl Driver {
    /// `user_data` reserved for the notification packet; completions carrying it
    /// are wake-ups and are filtered out of `poll`'s results.
    const NOTIFY: usize = usize::MAX;

    /// Creates the driver. This starts Winsock (version 2.2, `0x202`), creates the
    /// completion port, and obtains the blocking-op thread pool from `builder`.
    ///
    /// # Errors
    ///
    /// Returns the `WSAStartup` error, or the error from creating the completion
    /// port. In the latter case `WSACleanup` is called before returning, so the
    /// successful `WSAStartup` stays balanced. `Drop`, which normally does that,
    /// never runs for a driver that was never constructed.
    pub fn new(builder: &ProactorBuilder) -> io::Result<Self> {
        instrument!(compio_log::Level::TRACE, "new", ?builder);
        let mut data: WSADATA = unsafe { std::mem::zeroed() };
        syscall!(SOCKET, WSAStartup(0x202, &mut data))?;
        let port = match cp::Port::new() {
            Ok(port) => port,
            Err(e) => {
                // Best effort: the port error is the one worth reporting.
                syscall!(SOCKET, WSACleanup()).ok();
                return Err(e);
            }
        };
        let driver = port.as_raw_handle() as _;
        Ok(Self {
            port,
            waits: HashMap::default(),
            cancelled: HashSet::default(),
            pool: builder.create_or_get_thread_pool(),
            notify_overlapped: Arc::new(Overlapped::new(driver, Self::NOTIFY, ())),
        })
    }

    /// Boxes `op` into a [`RawOp`] tagged with `user_data` and this driver's port
    /// handle.
    pub fn create_op<T: OpCode + 'static>(&self, user_data: usize, op: T) -> RawOp {
        RawOp::new(self.port.as_raw_handle() as _, user_data, op)
    }

    /// Associates `fd` with the completion port (delegates to `cp::Port::attach`).
    pub fn attach(&mut self, fd: RawFd) -> io::Result<()> {
        self.port.attach(fd)
    }

    /// Marks `user_data` as cancelled and, if it is still in `registry`, asks the op
    /// to cancel itself.
    ///
    /// The mark makes the op report `ERROR_OPERATION_ABORTED`, either when it is
    /// pushed later or when its completion is collected by `poll`.
    pub fn cancel(&mut self, user_data: usize, registry: &mut Slab<RawOp>) {
        instrument!(compio_log::Level::TRACE, "cancel", user_data);
        trace!("cancel RawOp");
        self.cancelled.insert(user_data);
        if let Some(op) = registry.get_mut(user_data) {
            let overlapped_ptr = op.as_mut_ptr();
            let op = op.as_op_pin();
            trace!("call OpCode::cancel");
            // Cancellation is best effort; a failure here is deliberately ignored.
            unsafe { op.cancel(overlapped_ptr.cast()) }.ok();
        }
    }

    /// Submits the op registered under `user_data`.
    ///
    /// If `cancel` was already called for `user_data`, returns
    /// `Ready(Err(ERROR_OPERATION_ABORTED))`. Otherwise the behavior depends on
    /// [`OpCode::op_type`]:
    /// - `Overlapped`: returns whatever `OpCode::operate` returns.
    /// - `Blocking`: dispatches the op to the thread pool and returns `Pending`, or
    ///   `Ready(Err(ERROR_BUSY))` if the pool rejects the job.
    /// - `Event(e)`: registers a thread-pool wait on `e` and returns `Pending`. An
    ///   error while creating the wait is returned as `Ready(Err(_))`.
    pub fn push(&mut self, user_data: usize, op: &mut RawOp) -> Poll<io::Result<usize>> {
        instrument!(compio_log::Level::TRACE, "push", user_data);
        if self.cancelled.remove(&user_data) {
            trace!("pushed RawOp already cancelled");
            Poll::Ready(Err(io::Error::from_raw_os_error(
                ERROR_OPERATION_ABORTED as _,
            )))
        } else {
            trace!("push RawOp");
            let optr = op.as_mut_ptr();
            let op_pin = op.as_op_pin();
            match op_pin.op_type() {
                // `optr` is the op's own `Overlapped`, whose first field is `OVERLAPPED`.
                OpType::Overlapped => unsafe { op_pin.operate(optr.cast()) },
                OpType::Blocking => {
                    if self.push_blocking(op)? {
                        Poll::Pending
                    } else {
                        Poll::Ready(Err(io::Error::from_raw_os_error(ERROR_BUSY as _)))
                    }
                }
                OpType::Event(e) => {
                    self.waits.insert(
                        user_data,
                        WinThreadpollWait::new(self.port.handle(), e, op)?,
                    );
                    Poll::Pending
                }
            }
        }
    }

    /// Runs `op` on the blocking thread pool and posts its result to the port.
    ///
    /// Returns `Ok(false)` if the pool rejected the job.
    fn push_blocking(&mut self, op: &mut RawOp) -> io::Result<bool> {
        // Lets the raw pointer cross to the pool thread.
        struct SendWrapper<T>(T);
        // SAFETY: only the pool job touches the op through this pointer until it posts
        // the result. NOTE(review): this relies on the caller not freeing the op before
        // its completion is collected.
        unsafe impl<T> Send for SendWrapper<T> {}

        // Capture the heap `Overlapped`, not the `RawOp`. The caller stores the `RawOp`
        // by value (e.g. in a `Slab`, which may reallocate while the job runs), but the
        // boxed `Overlapped` it points to never moves.
        let optr = SendWrapper(op.as_mut_ptr());
        let port = self.port.handle();
        Ok(self
            .pool
            .dispatch(move || {
                // Using the whole wrapper makes the closure capture it. Under
                // edition-2021 disjoint captures, `optr.0` alone (not `Send`) would be
                // captured instead.
                #[allow(clippy::redundant_locals)]
                let optr = optr;
                let optr = optr.0;
                // SAFETY: `optr` points to the op's live heap allocation, which is never
                // moved while the op exists.
                let op = unsafe { Pin::new_unchecked(&mut (*optr).op) };
                let res = match unsafe { op.operate(optr.cast()) } {
                    Poll::Pending => unreachable!("this operation is not overlapped"),
                    Poll::Ready(res) => res,
                };
                port.post(res, optr).ok();
            })
            .is_ok())
    }

    /// Converts a raw completion into the entry reported to the caller.
    ///
    /// Notify wake-ups (`Driver::NOTIFY`) yield `None`. For real ops, any
    /// thread-pool wait registered for the op is dropped. If the op was cancelled,
    /// its result is replaced by `ERROR_OPERATION_ABORTED`.
    fn create_entry(
        cancelled: &mut HashSet<usize>,
        waits: &mut HashMap<usize, WinThreadpollWait>,
        entry: Entry,
    ) -> Option<Entry> {
        let user_data = entry.user_data();
        if user_data != Self::NOTIFY {
            waits.remove(&user_data);
            let result = if cancelled.remove(&user_data) {
                Err(io::Error::from_raw_os_error(ERROR_OPERATION_ABORTED as _))
            } else {
                entry.into_result()
            };
            Some(Entry::new(user_data, result))
        } else {
            None
        }
    }

    /// Polls the completion port (honoring `timeout` as interpreted by
    /// `cp::Port::poll`) and appends the resulting entries to `entries`, skipping
    /// notify wake-ups.
    ///
    /// # Safety
    ///
    /// NOTE(review): the contract is not spelled out here. Presumably every op whose
    /// completion may be returned must still be alive in the caller's registry.
    pub unsafe fn poll(
        &mut self,
        timeout: Option<Duration>,
        mut entries: OutEntries<impl Extend<usize>>,
    ) -> io::Result<()> {
        instrument!(compio_log::Level::TRACE, "poll", ?timeout);
        entries.extend(
            self.port
                .poll(timeout)?
                .filter_map(|e| Self::create_entry(&mut self.cancelled, &mut self.waits, e)),
        );
        Ok(())
    }

    /// Returns a [`NotifyHandle`] that posts this driver's notification packet to
    /// its port. This never fails.
    pub fn handle(&self) -> io::Result<NotifyHandle> {
        Ok(NotifyHandle::new(
            self.port.handle(),
            self.notify_overlapped.clone(),
        ))
    }
}
/// The driver's raw fd is its completion port handle.
impl AsRawFd for Driver {
    fn as_raw_fd(&self) -> RawFd {
        let Self { port, .. } = self;
        port.as_raw_handle()
    }
}
impl Drop for Driver {
    /// Balances the `WSAStartup` performed in `Driver::new`. A failure cannot be
    /// reported from `drop`, so it is deliberately ignored.
    fn drop(&mut self) {
        let _ = syscall!(SOCKET, WSACleanup());
    }
}
/// Wakes the driver's `poll` by posting the notification overlapped handed out by
/// `Driver::handle` to the driver's completion port.
pub struct NotifyHandle {
    port: cp::PortHandle,
    overlapped: Arc<Overlapped<()>>,
}

impl NotifyHandle {
    fn new(port: cp::PortHandle, overlapped: Arc<Overlapped<()>>) -> Self {
        NotifyHandle { port, overlapped }
    }

    /// Posts the notification packet to the port.
    ///
    /// # Errors
    ///
    /// Propagates the error from posting to the completion port.
    pub fn notify(&self) -> io::Result<()> {
        let packet: &Overlapped<()> = &self.overlapped;
        self.port.post_raw(packet)
    }
}
/// A registered thread-pool wait that runs an `OpType::Event` operation once its
/// handle is signaled.
struct WinThreadpollWait {
    /// Thread-pool wait object; cancelled and closed in `Drop`.
    wait: PTP_WAIT,
    // Never read, but it must stay alive: its address was passed to
    // `CreateThreadpoolWait` as the callback context. Fields drop only after
    // `Drop::drop` has waited for outstanding callbacks.
    #[allow(dead_code)]
    context: Box<WinThreadpollWaitContext>,
}
impl WinThreadpollWait {
pub fn new(port: cp::PortHandle, event: RawFd, op: &mut RawOp) -> io::Result<Self> {
let mut context = Box::new(WinThreadpollWaitContext { port, op });
let wait = syscall!(
BOOL,
CreateThreadpoolWait(
Some(Self::wait_callback),
(&mut *context) as *mut WinThreadpollWaitContext as _,
null()
)
)?;
unsafe {
SetThreadpoolWait(wait, event as _, null());
}
Ok(Self { wait, context })
}
unsafe extern "system" fn wait_callback(
_instance: PTP_CALLBACK_INSTANCE,
context: *mut c_void,
_wait: PTP_WAIT,
result: u32,
) {
let context = &*(context as *mut WinThreadpollWaitContext);
let res = match result {
WAIT_OBJECT_0 => Ok(0),
WAIT_TIMEOUT => Err(io::Error::from_raw_os_error(ERROR_TIMEOUT as _)),
_ => Err(io::Error::from_raw_os_error(result as _)),
};
let res = if res.is_err() {
res
} else {
let op = unsafe { &mut *context.op };
op.operate_blocking()
};
context.port.post(res, (*context.op).as_mut_ptr()).ok();
}
}
impl Drop for WinThreadpollWait {
fn drop(&mut self) {
unsafe {
SetThreadpoolWait(self.wait, 0, null());
WaitForThreadpoolWaitCallbacks(self.wait, 1);
CloseThreadpoolWait(self.wait);
}
}
}
struct WinThreadpollWaitContext {
port: cp::PortHandle,
op: *mut RawOp,
}
/// Allocation that wraps each operation: a Win32 `OVERLAPPED` header followed by
/// driver bookkeeping and the op itself.
///
/// `#[repr(C)]` keeps `base` at offset 0, which is why pointers to this struct are
/// `.cast()` to `*mut OVERLAPPED` throughout this module.
#[repr(C)]
pub struct Overlapped<T: ?Sized> {
    /// Header handed to the OS; zero-initialized by `Overlapped::new`.
    pub base: OVERLAPPED,
    /// Raw handle of the completion port this op belongs to.
    pub driver: RawFd,
    /// Identifier of the op (its registry key), reported back with its completion.
    pub user_data: usize,
    /// The operation payload.
    pub op: T,
}
impl<T> Overlapped<T> {
pub(crate) fn new(driver: RawFd, user_data: usize, op: T) -> Self {
Self {
base: unsafe { std::mem::zeroed() },
driver,
user_data,
op,
}
}
}
// SAFETY (NOTE(review): rationale inferred, please confirm): `Overlapped<()>` is
// only used for the shared notification packet. In this module that packet is
// cloned and posted but never mutated after construction. The auto traits are
// missing only because `RawFd` (a `RawHandle` pointer) and `OVERLAPPED` contain raw
// pointers.
unsafe impl Send for Overlapped<()> {}
unsafe impl Sync for Overlapped<()> {}
/// Pointer to a type-erased, heap-allocated [`Overlapped`] op plus its completion
/// state.
///
/// `set_cancelled` and `set_result` each return `true` when the other has already
/// happened. Presumably that is the caller's signal that the op can now be released.
pub(crate) struct RawOp {
    /// Allocation from `Box::into_raw` in `RawOp::new`. It is freed by `into_inner`,
    /// or by `Drop` only once a result has been set.
    op: NonNull<Overlapped<dyn OpCode>>,
    /// Set by `set_cancelled`.
    cancelled: bool,
    /// Completion result, set by `set_result`.
    result: Option<io::Result<usize>>,
}
impl RawOp {
    /// Boxes `op` inside an [`Overlapped`] and keeps only the raw pointer.
    pub(crate) fn new(driver: RawFd, user_data: usize, op: impl OpCode + 'static) -> Self {
        let op = Overlapped::new(driver, user_data, op);
        let op = Box::new(op) as Box<Overlapped<dyn OpCode>>;
        Self {
            // `Box::into_raw` never returns null.
            op: unsafe { NonNull::new_unchecked(Box::into_raw(op)) },
            cancelled: false,
            result: None,
        }
    }

    /// Pinned mutable access to the type-erased op.
    pub fn as_op_pin(&mut self) -> Pin<&mut dyn OpCode> {
        // The op lives in a heap allocation that is never moved while `self` exists.
        unsafe { Pin::new_unchecked(&mut self.op.as_mut().op) }
    }

    /// Raw pointer to the whole [`Overlapped`]. Callers `.cast()` it to the
    /// `*mut OVERLAPPED` given to the OS, which is valid because of `#[repr(C)]`.
    pub fn as_mut_ptr(&mut self) -> *mut Overlapped<dyn OpCode> {
        self.op.as_ptr()
    }

    /// Marks the op as cancelled and returns whether a result had already been set.
    pub fn set_cancelled(&mut self) -> bool {
        self.cancelled = true;
        self.has_result()
    }

    /// Stores the completion result and returns whether the op had been cancelled.
    pub fn set_result(&mut self, res: io::Result<usize>) -> bool {
        self.result = Some(res);
        self.cancelled
    }

    /// Whether a completion result has been stored.
    pub fn has_result(&self) -> bool {
        self.result.is_some()
    }

    /// Frees the allocation and returns the result together with the concrete op.
    ///
    /// # Safety
    ///
    /// `T` must be the exact op type this `RawOp` was created with.
    ///
    /// # Panics
    ///
    /// Panics if no result has been set.
    pub unsafe fn into_inner<T: OpCode>(self) -> BufResult<usize, T> {
        // `ManuallyDrop` keeps `Drop` from freeing the allocation a second time.
        let mut this = ManuallyDrop::new(self);
        let overlapped: Box<Overlapped<T>> = Box::from_raw(this.op.cast().as_ptr());
        BufResult(this.result.take().unwrap(), overlapped.op)
    }

    /// Runs the op synchronously and returns its result.
    ///
    /// # Panics
    ///
    /// Panics if `operate` returns `Poll::Pending`. Only use this for ops that
    /// complete synchronously.
    fn operate_blocking(&mut self) -> io::Result<usize> {
        let optr = self.as_mut_ptr();
        let op = self.as_op_pin();
        let res = unsafe { op.operate(optr.cast()) };
        match res {
            Poll::Pending => unreachable!("this operation is not overlapped"),
            Poll::Ready(res) => res,
        }
    }
}
impl Drop for RawOp {
    /// Frees the op's allocation once a result has been recorded.
    ///
    /// Without a result the allocation is left alone (leaked). Presumably this is
    /// because the operation may still be in flight and freeing it would let the OS
    /// or a pool thread write into released memory.
    fn drop(&mut self) {
        if !self.has_result() {
            return;
        }
        // SAFETY: `op` came from `Box::into_raw` in `RawOp::new`, and `into_inner` (the
        // only other path that frees it) bypasses this `Drop` via `ManuallyDrop`.
        drop(unsafe { Box::from_raw(self.op.as_ptr()) });
    }
}