use super::event::SafeEvent;
use super::wrappers::Session;
use crate::platform::wintun::queue::SessionQueueT;
use parking_lot::Mutex;
use std::io::{self, Read, Write};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Waker};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use windows::{
Win32::Foundation::HANDLE,
Win32::System::Threading::{WaitForMultipleObjects, WAIT_ABANDONED_0, WAIT_OBJECT_0},
Win32::System::WindowsProgramming::INFINITE,
};
/// Tokio `AsyncRead`/`AsyncWrite` adapter over a Wintun [`Session`].
///
/// Reads that would block park the calling task and hand the wait off to a blocking
/// task running `wait_for_read`. Dropping the queue signals `shutdown_event` so that
/// waiting task can return.
pub struct AsyncTokioQueue {
    /// The underlying session; locked for every read and write.
    session: Arc<Mutex<Session>>,
    /// Set in `Drop`; the read-wait task waits on it alongside the session's read event.
    shutdown_event: Arc<SafeEvent>,
    /// Waker/wait-task state shared between `poll_read` and the read-wait task.
    data_ready: Arc<Mutex<DataReadinessHandler>>,
}
impl SessionQueueT for AsyncTokioQueue {
fn new(session: Session) -> Self {
Self {
session: Arc::new(Mutex::new(session)),
shutdown_event: Arc::new(SafeEvent::new(true, false)),
data_ready: Default::default(),
}
}
}
impl Drop for AsyncTokioQueue {
    /// Signals the shutdown event so that a `wait_for_read` call blocked on it can return
    /// rather than keep waiting on the session's read event.
    fn drop(&mut self) {
        let shutdown = &self.shutdown_event;
        shutdown.set_event();
    }
}
/// `WaitForMultipleObjects` result for "the handle at index 1 was signalled"
/// (index 1 is the session read event in `wait_for_read`).
const WAIT_OBJECT_1: u32 = 1 + WAIT_OBJECT_0;
/// `WaitForMultipleObjects` abandoned-object result for the handle at index 1.
const WAIT_ABANDONED_1: u32 = 1 + WAIT_ABANDONED_0;
/// State shared between `poll_read` and the blocking task that waits for read readiness.
#[derive(Default)]
struct DataReadinessHandler {
    /// Handle of the most recently spawned wait task; stored but not awaited anywhere in
    /// this file.
    tokio_wait_thread: Option<tokio::task::JoinHandle<()>>,
    /// Waker of the task parked in `poll_read`. It is `Some` while a wait is outstanding
    /// and is taken (left `None`) when the read event fires.
    waker: Option<Waker>,
}
fn wait_for_read(
read_event: HANDLE,
shutdown_event: Arc<SafeEvent>,
data_ready: Arc<Mutex<DataReadinessHandler>>,
) {
let result =
unsafe { WaitForMultipleObjects(&[shutdown_event.handle(), read_event], false, INFINITE) };
match result {
WAIT_OBJECT_0 => {}
WAIT_OBJECT_1 => {
let mut data_ready = data_ready.lock();
if let Some(waker) = (*data_ready).waker.take() {
waker.wake()
}
}
WAIT_ABANDONED_0 => {}
WAIT_ABANDONED_1 => {
panic!("Read event deleted unexpectedly");
}
e => {
panic!("Unexpected event result: {e:?}");
}
}
}
impl AsyncRead for AsyncTokioQueue {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let mut data_ready = self.data_ready.lock();
if (*data_ready).waker.is_some() {
(*data_ready).waker = Some(cx.waker().clone());
return Poll::Pending;
}
let mut session = self.session.lock();
match session.read(buf.initialize_unfilled()) {
Ok(n) => {
buf.set_filled(n);
Poll::Ready(Ok(()))
}
Err(e) => {
if e.kind() == io::ErrorKind::WouldBlock {
(*data_ready).waker = Some(cx.waker().clone());
let read_event = session.read_event();
let inner_shutdown_event = self.shutdown_event.clone();
let inner_data_ready = self.data_ready.clone();
(*data_ready).tokio_wait_thread =
Some(tokio::task::spawn_blocking(move || {
wait_for_read(read_event, inner_shutdown_event, inner_data_ready);
}));
Poll::Pending
} else {
Poll::Ready(Err(e))
}
}
}
}
}
impl AsyncWrite for AsyncTokioQueue {
    /// Writes `buf` to the session.
    ///
    /// If the session reports `WouldBlock`, nothing in this type waits for write readiness.
    /// Instead the task asks to be polled again immediately and retries the write.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        // Take the result first so the session lock is released before we decide what to do.
        let result = self.session.lock().write(buf);
        match result {
            Ok(len) => Poll::Ready(Ok(len)),
            Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
                // Re-schedule ourselves directly. `tokio::task::spawn_local` (used to do
                // this indirectly) panics when called outside a `LocalSet`, whereas waking
                // the task's own waker works on any executor.
                cx.waker().wake_by_ref();
                Poll::Pending
            }
            Err(err) => Poll::Ready(Err(err)),
        }
    }

    /// No-op: this type keeps no write buffer of its own (`poll_write` hands data straight
    /// to the session), so there is nothing to flush.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// No-op: no write-side shutdown is performed here.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }
}