use std::
{
collections::{BTreeMap},
fmt,
mem,
os::
{
raw::c_void,
windows::io::{AsRawHandle, FromRawHandle, OwnedHandle, RawHandle}
},
ptr::null_mut,
sync::{Arc, RwLock, atomic::{AtomicBool, Ordering}}
};
use windows::
{
Win32::
{
Foundation::{GENERIC_ALL, HANDLE, INVALID_HANDLE_VALUE},
System::{IO::{CreateIoCompletionPort, GetQueuedCompletionStatusEx, OVERLAPPED_ENTRY}, Threading::INFINITE}
},
core::HRESULT
};
use crate::
{
TimerFd,
TimerReadRes,
error::{TimerErrorType, TimerResult},
map_timer_err,
timer_err,
timer_portable::{FdTimerMarker, PollEventType, TimerId, TimerPollOps, poll::PollInterrupt, windows::EventFd}
};
/// Hand-written FFI declarations for the `ntdll` wait-completion-packet exports that the
/// `windows` crate does not expose. They let `IocpCrap::wait` route "handle signaled"
/// notifications for arbitrary waitable handles into an I/O completion port.
pub mod nt_crap_not_covered_by_crate
{
    use std::os::{raw::c_void, windows::io::RawHandle};
    use windows::Win32::{Foundation::NTSTATUS};
    /// `#[repr(C)]` mirror of the NT `UNICODE_STRING` structure (field meanings as in the
    /// Windows headers). Only referenced as the pointee of `OBJECT_ATTRIBUTES::ObjectName`.
    #[allow(non_camel_case_types)]
    #[allow(non_snake_case)]
    #[repr(C)]
    #[derive(Debug)]
    pub struct UNICODE_STRING
    {
        Length: u16,
        MaximumLength: u16,
        Buffer: *mut u16,
    }
    /// `#[repr(C)]` mirror of the NT `OBJECT_ATTRIBUTES` structure. The visible caller never
    /// builds one; it passes a null pointer instead.
    #[allow(non_camel_case_types)]
    #[allow(non_snake_case)]
    #[repr(C)]
    #[derive(Debug)]
    pub struct OBJECT_ATTRIBUTES
    {
        Length: u32,
        RootDirectory: RawHandle,
        ObjectName: *mut UNICODE_STRING,
        Attributes: u32,
        SecurityDescriptor: *mut c_void,
        SecurityQualityOfService: *mut c_void,
    }
    #[link(name = "ntdll")]
    unsafe extern "system"
    {
        /// Creates a wait completion packet object.
        ///
        /// NOTE(review): despite its name, the first parameter is used as an OUT pointer. The
        /// only caller passes `&mut hps` (a `*mut c_void` local) and afterwards takes ownership
        /// of `hps` as the new packet handle. Consider renaming it to
        /// `WaitCompletionPacketHandle` and typing it `*mut RawHandle`.
        pub unsafe fn NtCreateWaitCompletionPacket(
            IoCompletionHandle: RawHandle, DesiredAccess: u32,
            ObjectAttributes: *mut OBJECT_ATTRIBUTES,
        ) -> NTSTATUS;
        /// Associates a wait completion packet with a completion port and a target waitable
        /// handle, so that a packet is queued to the port when the target is signaled.
        ///
        /// The caller in this file passes the handle's index as `IoStatusInformation` and reads
        /// it back from `OVERLAPPED_ENTRY::dwNumberOfBytesTransferred`. `AlreadySignaled`
        /// receives a flag that the caller does not inspect.
        pub unsafe fn NtAssociateWaitCompletionPacket(
            WaitCompletionPacketHandle: RawHandle,
            IoCompletionHandle: RawHandle,
            TargetObjectHandle: RawHandle,
            KeyContext: *mut c_void,
            ApcContext: *mut c_void,
            IoStatus: i32,
            IoStatusInformation: usize,
            AlreadySignaled: *mut u8,
        ) -> NTSTATUS;
    }
}
/// Owned I/O completion port used to wait on many waitable handles at once
/// (see `IocpCrap::wait`).
#[derive(Debug)]
pub struct IocpCrap
{
    /// The completion port handle; `OwnedHandle` closes it when dropped.
    handle: OwnedHandle,
}
impl fmt::Display for IocpCrap
{
    /// Renders the port as `IOCP:<numeric raw handle value>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
    {
        let raw_value = self.handle.as_raw_handle() as usize;
        f.write_fmt(format_args!("IOCP:{}", raw_value))
    }
}
impl AsRawHandle for IocpCrap
{
    /// Exposes the completion port's raw handle without transferring ownership.
    fn as_raw_handle(&self) -> RawHandle
    {
        let owned: &OwnedHandle = &self.handle;
        AsRawHandle::as_raw_handle(owned)
    }
}
impl IocpCrap
{
    /// Creates a fresh completion port that is not yet associated with any file handle.
    ///
    /// # Errors
    ///
    /// Returns `TimerErrorType::EPoll` when `CreateIoCompletionPort` fails.
    fn new() -> TimerResult<Self>
    {
        // SAFETY: INVALID_HANDLE_VALUE with no existing port asks the OS to create a new
        // completion port; no caller-owned memory is passed.
        let iocp =
            unsafe
            {
                CreateIoCompletionPort(INVALID_HANDLE_VALUE, None, 0, 0)
                    .map_err(|e|
                        map_timer_err!(TimerErrorType::EPoll(e), "CreateIoCompletionPort failed")
                    )?
            };

        return Ok(
            Self
            {
                // SAFETY: `iocp` comes from a successful CreateIoCompletionPort call and has no
                // other owner, so `OwnedHandle` may take responsibility for closing it.
                handle: unsafe { OwnedHandle::from_raw_handle(iocp.0) },
            }
        );
    }

    /// Blocks until at least one handle in `timer_handlers` is signaled or `timeout` expires.
    ///
    /// For every handle, a wait completion packet is created and associated with this port.
    /// The handle's position in `timer_handlers` travels as the packet's
    /// `IoStatusInformation` and comes back in `OVERLAPPED_ENTRY::dwNumberOfBytesTransferred`.
    /// The packet handles are closed when this call returns.
    ///
    /// # Arguments
    ///
    /// * `timer_handlers` - waitable handles to watch.
    /// * `timeout` - wait time in milliseconds; `None` waits forever (`INFINITE`).
    ///
    /// # Returns
    ///
    /// The `TimerId`s of the handles whose packets were dequeued; empty on timeout.
    ///
    /// # Errors
    ///
    /// `TimerErrorType::EPoll` when creating or associating a packet fails, or when dequeuing
    /// fails for a reason other than timeout.
    fn wait(&self, timer_handlers: &[RawHandle], timeout: Option<u32>) -> TimerResult<Vec<TimerId>>
    {
        // Packet handles must stay open for the whole wait; dropping this vector closes them.
        let mut hpacks: Vec<OwnedHandle> = Vec::with_capacity(timer_handlers.len());

        for (index, th) in timer_handlers.iter().enumerate()
        {
            let mut hps: *mut c_void = null_mut();

            // SAFETY: the first argument points at the live local `hps`, which receives the new
            // packet handle; no object attributes are supplied.
            unsafe
            {
                nt_crap_not_covered_by_crate::
                NtCreateWaitCompletionPacket(
                    &mut hps as *mut _ as *mut _,
                    GENERIC_ALL.0,
                    null_mut())
                .ok()
                .map_err(|e|
                    map_timer_err!(TimerErrorType::EPoll(e), "NtCreateWaitCompletionPacket failed 1")
                )?;
            };

            // SAFETY: the successful call above stored a fresh handle in `hps` owned by no one else.
            let hpack = unsafe { OwnedHandle::from_raw_handle(hps) };
            let mut already_signaled: u8 = 0;

            // SAFETY: all handles are valid for the duration of the call and `already_signaled`
            // is a live local. `index` is echoed back as dwNumberOfBytesTransferred.
            unsafe
            {
                nt_crap_not_covered_by_crate::
                NtAssociateWaitCompletionPacket(hpack.as_raw_handle(), self.handle.as_raw_handle(), *th, null_mut(), null_mut(), 0,
                    index, &mut already_signaled as *mut u8)
                .ok()
                .map_err(|e|
                    map_timer_err!(TimerErrorType::EPoll(e), "NtAssociateWaitCompletionPacket failed")
                )?;
            }

            hpacks.push(hpack);
        }

        let poll_timeout = timeout.unwrap_or(INFINITE);

        // One slot per watched handle is enough to drain every packet this call associated.
        // SAFETY: OVERLAPPED_ENTRY holds only integers and raw pointers, so all-zero is valid.
        let mut overlp_res: Vec<OVERLAPPED_ENTRY> = vec![unsafe { mem::zeroed() }; timer_handlers.len()];
        let mut complitions: u32 = 0;

        // SAFETY: `overlp_res` and `complitions` are live locals for the duration of the call.
        let res =
            unsafe
            {
                GetQueuedCompletionStatusEx(
                    HANDLE(self.handle.as_raw_handle()),
                    &mut overlp_res,
                    &mut complitions as *mut _,
                    poll_timeout,
                    false.into()
                )
            };

        // An expired timeout is reported as Win32 WAIT_TIMEOUT (258), which the `windows`
        // crate surfaces as HRESULT 0x80070102.
        const ERR_TIMEOUT: HRESULT = HRESULT::from_win32(258);
        if let Err(ref e) = res && e.code() == ERR_TIMEOUT
        {
            return Ok(Vec::new());
        }

        res
            .map_err(|e|
                map_timer_err!(TimerErrorType::EPoll(e), "GetQueuedCompletionStatusEx failed")
            )?;

        // Bounds-check the echoed index. A packet that was queued after an earlier call dequeued
        // (but before that call closed its packet handles) is still sitting in the port. It
        // carries an index into that earlier handle list, which may be out of range here, so
        // skip it instead of panicking.
        // NOTE(review): such stale in-range indices can still be misattributed.
        let pet: Vec<TimerId> =
            overlp_res
                .iter()
                .take(complitions as usize)
                .filter_map(|ov| timer_handlers.get(ov.dwNumberOfBytesTransferred as usize).copied())
                .map(|raw| TimerId::from(raw))
                .collect();

        return Ok(pet);
    }
}
/// Timer poller backed by an I/O completion port. Timers are registered by their raw
/// handle, and `poll` waits on all of them plus an internal wakeup event.
#[derive(Debug)]
pub struct TimerEventWatch
{
    /// Completion port that `poll` waits on.
    epoll: IocpCrap,
    /// Event whose handle is appended to every wait. `interrupt_poll` writes to it, and
    /// `get_poll_interruptor` hands out a weak reference to it.
    wakeup_event: Arc<EventFd>,
    /// `true` while some thread is inside `poll`; used to reject concurrent polls.
    polling_flag: AtomicBool,
    /// Registered timers keyed by the `TimerId` built from their raw handle.
    timers: RwLock<BTreeMap<TimerId, TimerFd>>,
}
impl fmt::Display for TimerEventWatch
{
    /// Renders as `fd:<completion port handle>, cnt:<registered timer count>`. The count
    /// reads `locked` when the registry cannot be read immediately (write-locked or poisoned).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
    {
        let port_value = self.epoll.as_raw_handle() as usize;
        let timer_count = match self.timers.try_read()
        {
            Ok(guard) => guard.len().to_string(),
            Err(_) => String::from("locked"),
        };
        write!(f, "fd:{}, cnt:{}", port_value, timer_count)
    }
}
impl Eq for TimerEventWatch {}
impl PartialEq for TimerEventWatch
{
    /// Two watchers are equal exactly when they use the same completion port handle.
    fn eq(&self, other: &Self) -> bool
    {
        let lhs_port = self.epoll.as_raw_handle();
        let rhs_port = other.epoll.as_raw_handle();
        lhs_port == rhs_port
    }
}
impl TimerPollOps for TimerEventWatch
{
fn new() -> TimerResult<Self>
{
let epoll = IocpCrap::new()?;
let wakeup_event =
EventFd::new("wakeup_event".into())?;
return Ok(
Self
{
epoll: epoll,
wakeup_event: Arc::new(wakeup_event),
polling_flag: AtomicBool::new(false),
timers: RwLock::new(BTreeMap::new())
}
);
}
fn add(&self, timer: TimerFd) -> TimerResult<()>
{
let mut timers_lock =
self
.timers
.write()
.map_or_else(|e| e.into_inner(), |v| v);
let false =
timers_lock.contains_key(&(timer.as_raw_handle() as usize))
else
{
timer_err!(TimerErrorType::Duplicate, "can not add timer {} to epoll, reason duplicate",
timer)
};
timers_lock.insert(TimerId::from(timer.as_raw_handle()), timer);
return Ok(());
}
fn delete<FD: FdTimerMarker>(&self, timer: &FD) -> TimerResult<()>
{
let mut timers_lock =
self
.timers
.write()
.map_or_else(|e| e.into_inner(), |v| v);
let true =
timers_lock.contains_key(&timer.as_timer_id())
else
{
timer_err!(TimerErrorType::Duplicate, "can not add timer {} to epoll, reason does not exist",
timer.as_timer_id())
};
let _ = timers_lock.remove(&timer.as_timer_id());
return Ok(());
}
fn poll(&self, timeout: Option<i32>) -> TimerResult<Option<Vec<PollEventType>>>
{
if self.polling_flag.swap(true, Ordering::SeqCst) == true
{
timer_err!(TimerErrorType::EPollAlreadyPolling,
"epoll fd: '{}' other thread already polling", self.epoll);
}
let poll_ops = ||
{
let rd =
self
.timers
.read()
.unwrap_or_else(|e| e.into_inner());
let mut raw_handlers =
rd.iter().map(|(tid, _)| tid.0 as RawHandle).collect::<Vec<RawHandle>>();
raw_handlers.push(self.wakeup_event.as_raw_handle());
drop(rd);
let res =
self
.epoll
.wait(&raw_handlers, timeout.map(|v| v as u32));
return res;
};
let evs_res = poll_ops();
self.polling_flag.store(false, Ordering::SeqCst);
let poll_res: Vec<PollEventType> =
evs_res?
.into_iter()
.map(|res|
PollEventType::TimerRes(res, TimerReadRes::ok())
)
.collect();
if poll_res.is_empty() == false
{
return Ok(Some(poll_res));
}
else
{
return Ok(None);
}
}
fn get_count(&self) -> usize
{
return
self
.timers
.read()
.map_or_else(|e| e.into_inner(), |v| v)
.len();
}
fn get_poll_interruptor(&self) -> PollInterrupt
{
return PollInterrupt::new(Arc::downgrade(&self.wakeup_event));
}
fn interrupt_poll(&self) -> bool
{
return self.wakeup_event.write(0).is_ok();
}
}