use core::fmt::Debug;
use std::{
marker::PhantomPinned,
os::windows::prelude::RawSocket,
pin::Pin,
sync::{Arc, Mutex},
};
use windows_sys::Win32::{
Foundation::{ERROR_INVALID_HANDLE, ERROR_IO_PENDING, HANDLE, STATUS_CANCELLED},
Networking::WinSock::{
WSAGetLastError, WSAIoctl, SIO_BASE_HANDLE, SIO_BSP_HANDLE, SIO_BSP_HANDLE_POLL,
SIO_BSP_HANDLE_SELECT, SOCKET_ERROR,
},
System::WindowsProgramming::IO_STATUS_BLOCK,
};
use super::{afd, from_overlapped, into_overlapped, Afd, AfdPollInfo, Event};
/// Lifecycle of the single AFD poll operation a [`SockState`] may have outstanding.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SockPollStatus {
    /// No poll operation has been submitted (or the last one was consumed).
    Idle,
    /// A poll operation has been submitted and has not yet been fed back.
    Pending,
    /// The outstanding poll operation was asked to cancel; its completion is still awaited.
    Cancelled,
}
/// Registration record for one socket.
///
/// It holds the raw socket, the token/interest pair associated with it, and
/// (once populated elsewhere) its shared AFD poll state.
#[derive(Debug)]
pub struct SocketState {
    /// The socket exactly as supplied to [`SocketState::new`].
    pub socket: RawSocket,
    /// Shared poll state for this socket; `new` leaves it as `None`.
    pub inner: Option<Pin<Arc<Mutex<SockState>>>>,
    /// Token associated with this socket (presumably reported back with its events).
    pub token: mio::Token,
    /// Readiness interest associated with this socket.
    pub interest: mio::Interest,
}

impl SocketState {
    /// Builds a fresh record with no poll state, token `0`, and readable interest.
    pub fn new(socket: RawSocket) -> Self {
        let inner = None;
        let token = mio::Token(0);
        let interest = mio::Interest::READABLE;
        Self {
            socket,
            inner,
            token,
            interest,
        }
    }
}
/// Per-socket AFD poll state, handed around as `Pin<Arc<Mutex<SockState>>>`.
///
/// The type is made `!Unpin` with `PhantomPinned` because `iosb` and
/// `poll_info` are passed by pointer to `Afd::poll`. NOTE(review): presumably
/// the kernel writes to them until the operation completes, so they must not
/// move while a poll is outstanding.
pub struct SockState {
    // ---- buffers handed to the AFD driver ----
    /// Status block passed to `Afd::poll` / `Afd::cancel`; its `Status` is
    /// inspected when a poll is fed back.
    pub iosb: IO_STATUS_BLOCK,
    /// Poll request (filled before submission) and result (read on feed-back).
    pub poll_info: AfdPollInfo,

    // ---- identity ----
    /// AFD handle used to submit and cancel poll operations.
    pub afd: Arc<Afd>,
    /// Base socket handle resolved from the user's socket at construction.
    pub base_socket: RawSocket,

    // ---- event bookkeeping ----
    /// Event mask the user currently wants reported.
    pub user_evts: u32,
    /// Event mask covered by the currently pending poll operation (0 when none).
    pub pending_evts: u32,
    /// Opaque user value echoed back in every reported event.
    pub user_data: u64,

    // ---- lifecycle ----
    /// State of the outstanding poll operation, if any.
    pub poll_status: SockPollStatus,
    /// Set once the socket is scheduled for removal; no new polls may start.
    pub delete_pending: bool,
    /// Raw OS error from the most recent failed `update`, cleared at the start of each `update`.
    pub error: Option<i32>,

    /// Opts out of `Unpin` so the kernel-visible buffers stay at a fixed address.
    _pinned: PhantomPinned,
}
impl SockState {
    /// Creates an idle poll state for `raw_socket` bound to the given AFD handle.
    ///
    /// # Errors
    ///
    /// Propagates the error from `get_base_socket` if no base handle can be
    /// resolved for `raw_socket`.
    pub fn new(raw_socket: RawSocket, afd: Arc<Afd>) -> std::io::Result<SockState> {
        Ok(SockState {
            // SAFETY: IO_STATUS_BLOCK is a plain C struct (integers / a union of an
            // integer and a pointer); the all-zero bit pattern is a valid value.
            iosb: unsafe { std::mem::zeroed() },
            // SAFETY: NOTE(review): assumes AfdPollInfo is a plain-data `repr(C)`
            // struct for which all-zero bytes are valid — confirm in the `afd` module.
            poll_info: unsafe { std::mem::zeroed() },
            afd,
            base_socket: get_base_socket(raw_socket)?,
            user_evts: 0,
            pending_evts: 0,
            user_data: 0,
            poll_status: SockPollStatus::Idle,
            delete_pending: false,
            error: None,
            _pinned: PhantomPinned,
        })
    }

    /// Brings the outstanding AFD poll operation in line with `user_evts`.
    ///
    /// - `Pending`: if the pending poll already covers every known event the
    ///   user wants, nothing is done. Otherwise the poll is cancelled; a new one
    ///   is expected to be submitted after the cancellation completes.
    /// - `Cancelled`: nothing to do until the cancelled poll completes.
    /// - `Idle`: a new poll for `user_evts | POLL_LOCAL_CLOSE` is submitted on
    ///   `base_socket`, and the state becomes `Pending`.
    ///
    /// `self_arc` must be the `Arc` that owns `self`. A clone of it is handed to
    /// the poll as its overlapped pointer.
    ///
    /// # Errors
    ///
    /// Returns the cancel or poll error and records its raw OS code in
    /// `self.error`. If the poll fails with `ERROR_INVALID_HANDLE`, the socket is
    /// marked for deletion and `Ok(())` is returned instead.
    ///
    /// # Panics
    ///
    /// Panics if the socket is already marked for deletion.
    pub fn update(&mut self, self_arc: &Pin<Arc<Mutex<SockState>>>) -> std::io::Result<()> {
        assert!(!self.delete_pending);
        // Reset any error left over from a previous update.
        self.error = None;

        match self.poll_status {
            SockPollStatus::Pending => {
                // Only cancel when the user wants a known event the pending poll
                // does not cover. A spurious completion for an event no longer
                // wanted is handled later, when a new poll is submitted.
                if (self.user_evts & afd::KNOWN_EVENTS & !self.pending_evts) != 0 {
                    if let Err(e) = self.cancel() {
                        self.error = e.raw_os_error();
                        return Err(e);
                    }
                }
            }
            SockPollStatus::Cancelled => {
                // Still waiting for the cancelled operation to complete.
            }
            SockPollStatus::Idle => {
                self.poll_info.exclusive = 0;
                self.poll_info.number_of_handles = 1;
                self.poll_info.timeout = i64::MAX;
                self.poll_info.handles[0].handle = self.base_socket as HANDLE;
                self.poll_info.handles[0].status = 0;
                // Always ask for local-close so a closed socket can be detected
                // in `feed_event`.
                self.poll_info.handles[0].events = self.user_evts | afd::POLL_LOCAL_CLOSE;

                // Hand an extra strong reference to the in-flight operation.
                // NOTE(review): presumably reclaimed by `from_overlapped` on completion.
                let overlapped_ptr = into_overlapped(self_arc.clone());
                // SAFETY: `self` lives inside a pinned `Arc` (`self_arc`), and that
                // `Arc` is kept alive by the reference just passed as the overlapped
                // pointer, so `poll_info` and `iosb` remain valid and unmoved for the
                // duration of the operation.
                let result = unsafe {
                    self.afd
                        .poll(&mut self.poll_info, &mut self.iosb, overlapped_ptr)
                };

                if let Err(e) = result {
                    match e.raw_os_error() {
                        // The overlapped operation is in flight; this is expected.
                        Some(code) if code == ERROR_IO_PENDING as i32 => {}
                        code => {
                            // The operation never started, so take back the
                            // reference handed out above.
                            drop(from_overlapped(overlapped_ptr as *mut _));
                            if code == Some(ERROR_INVALID_HANDLE as i32) {
                                // The socket is gone; schedule it for removal.
                                self.mark_delete();
                                return Ok(());
                            }
                            // Covers both OS errors and errors with no OS code
                            // (which previously caused an `unwrap` panic).
                            self.error = code;
                            return Err(e);
                        }
                    }
                }

                self.poll_status = SockPollStatus::Pending;
                self.pending_evts = self.user_evts;
            }
        }
        Ok(())
    }

    /// Consumes the result of the outstanding poll operation.
    ///
    /// NOTE(review): presumably called when the overlapped poll completes.
    ///
    /// The state is always reset to `Idle` with no pending events. Returns an
    /// [`Event`] holding the reported events masked by `user_evts`, or `None`
    /// in any of these cases:
    /// - the socket is marked for deletion;
    /// - the poll was cancelled or reported nothing the user wants;
    /// - the socket reported a local close (it is then marked for deletion).
    ///
    /// Reported events are removed from `user_evts`, so they are not reported
    /// again until the user re-arms them via `set_event`.
    pub fn feed_event(&mut self) -> Option<Event> {
        self.poll_status = SockPollStatus::Idle;
        self.pending_evts = 0;

        if self.delete_pending {
            return None;
        }

        // SAFETY: `Status` and `Pointer` share an integer-sized union; reading
        // the `Status` view of plain integer data is always defined.
        let status = unsafe { self.iosb.Anonymous.Status };

        let afd_events = if status == STATUS_CANCELLED {
            // The poll request was cancelled.
            0
        } else if status < 0 {
            // The request itself failed; surface it as a connect failure.
            afd::POLL_CONNECT_FAIL
        } else if self.poll_info.number_of_handles < 1 {
            // The poll succeeded but reported no socket.
            0
        } else if self.poll_info.handles[0].events & afd::POLL_LOCAL_CLOSE != 0 {
            // The socket was closed locally.
            self.mark_delete();
            return None;
        } else {
            self.poll_info.handles[0].events
        };

        let afd_events = afd_events & self.user_evts;
        if afd_events == 0 {
            return None;
        }

        // Disarm what is being reported; the user must re-arm via `set_event`.
        self.user_evts &= !afd_events;

        Some(Event {
            data: self.user_data,
            flags: afd_events,
        })
    }

    /// Schedules this socket for removal (idempotent).
    ///
    /// If a poll is pending, a cancel is attempted first. A cancel failure is
    /// deliberately ignored because this is best-effort teardown.
    pub fn mark_delete(&mut self) {
        if self.delete_pending {
            return;
        }
        if self.poll_status == SockPollStatus::Pending {
            let _ = self.cancel();
        }
        self.delete_pending = true;
    }

    /// Replaces the wanted event mask and user data.
    ///
    /// `POLL_CONNECT_FAIL` and `POLL_ABORT` are always added to the mask.
    /// Returns `true` if the new mask contains events not covered by the
    /// currently pending poll (presumably meaning `update` should be called).
    pub fn set_event(&mut self, ev: Event) -> bool {
        let events = ev.flags | afd::POLL_CONNECT_FAIL | afd::POLL_ABORT;
        self.user_evts = events;
        self.user_data = ev.data;
        (events & !self.pending_evts) != 0
    }

    /// Requests cancellation of the pending poll operation.
    ///
    /// On success the state becomes `Cancelled` with no pending events.
    ///
    /// # Errors
    ///
    /// Propagates the error from `Afd::cancel`; the state is left unchanged.
    ///
    /// # Panics
    ///
    /// Panics if no poll operation is pending.
    pub fn cancel(&mut self) -> std::io::Result<()> {
        if self.poll_status != SockPollStatus::Pending {
            unreachable!("Invalid poll status during cancel");
        }
        // SAFETY: a poll is pending, so `iosb` is the status block that was
        // handed to `Afd::poll`, and it stays pinned inside the owning `Arc`.
        unsafe {
            self.afd.cancel(&mut self.iosb)?;
        }
        self.poll_status = SockPollStatus::Cancelled;
        self.pending_evts = 0;
        Ok(())
    }
}
/// Debug output for [`SockState`].
///
/// The previous implementation was `unimplemented!()`. Because `SocketState`
/// derives `Debug` and may contain a `SockState`, formatting it panicked.
/// Only the plain-data bookkeeping fields are printed. The kernel buffers
/// (`iosb`, `poll_info`) and the AFD handle are omitted, since they are not
/// known to implement `Debug`.
impl Debug for SockState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SockState")
            .field("base_socket", &self.base_socket)
            .field("user_evts", &self.user_evts)
            .field("pending_evts", &self.pending_evts)
            .field("user_data", &self.user_data)
            .field("poll_status", &self.poll_status)
            .field("delete_pending", &self.delete_pending)
            .field("error", &self.error)
            .finish_non_exhaustive()
    }
}
/// Performs the same best-effort teardown as [`SockState::mark_delete`] when dropped.
impl Drop for SockState {
    fn drop(&mut self) {
        Self::mark_delete(self);
    }
}
/// Resolves the base provider socket underneath `raw_socket`.
///
/// `SIO_BASE_HANDLE` is tried first. If it fails, the `SIO_BSP_HANDLE_*`
/// ioctls are tried in order, and the first result that differs from
/// `raw_socket` is accepted. If none yields such a handle, the error code
/// from the initial `SIO_BASE_HANDLE` attempt is returned.
fn get_base_socket(raw_socket: RawSocket) -> std::io::Result<RawSocket> {
    let base_handle_err = match try_get_base_socket(raw_socket, SIO_BASE_HANDLE) {
        Ok(base) => return Ok(base),
        Err(code) => code,
    };

    // Reaching this point suggests an LSP is intercepting the socket, so a
    // fallback result equal to the original socket is not treated as a success.
    [SIO_BSP_HANDLE_SELECT, SIO_BSP_HANDLE_POLL, SIO_BSP_HANDLE]
        .iter()
        .filter_map(|&ioctl| try_get_base_socket(raw_socket, ioctl).ok())
        .find(|&candidate| candidate != raw_socket)
        .ok_or_else(|| std::io::Error::from_raw_os_error(base_handle_err))
}
/// Issues a single handle-query `ioctl` on `raw_socket` via `WSAIoctl`.
///
/// Returns the socket handle written to the output buffer, or the
/// `WSAGetLastError` code if the call returned `SOCKET_ERROR`.
fn try_get_base_socket(raw_socket: RawSocket, ioctl: u32) -> Result<RawSocket, i32> {
    let mut out_handle: RawSocket = 0;
    let mut returned_len: u32 = 0;
    let out_ptr = (&mut out_handle as *mut RawSocket).cast::<std::ffi::c_void>();

    // SAFETY: no input buffer is passed (null, length 0); the output buffer
    // points at a live `RawSocket` whose exact size is passed; `returned_len`
    // is a live `u32`; the call is synchronous (no overlapped, no completion routine).
    let rc = unsafe {
        WSAIoctl(
            raw_socket as usize,
            ioctl,
            std::ptr::null_mut(),
            0,
            out_ptr,
            std::mem::size_of::<RawSocket>() as u32,
            &mut returned_len,
            std::ptr::null_mut(),
            None,
        )
    };

    if rc == SOCKET_ERROR {
        // SAFETY: WSAGetLastError has no preconditions; it reads the calling
        // thread's last Winsock error set by the failed call above.
        return Err(unsafe { WSAGetLastError() });
    }
    Ok(out_handle)
}