use std::io;
use std::sync::Arc;
use tokio::sync::mpsc;
/// An I/O source that is not a `mio` source and is registered by handing it an
/// [`IONotifier`].
///
/// NOTE(review): presumably implementors keep the notifier and call its
/// `notify`/`blocking_notify` when they become readable or writable — confirm against
/// implementors.
pub trait GenericSource {
/// Registers the source, giving it the [`IONotifier`] to report readiness through.
fn register(&mut self, notifier: IONotifier) -> Result<(), io::Error>;
/// Undoes a previous [`GenericSource::register`].
fn deregister(&mut self) -> Result<(), io::Error>;
}
/// A mutably borrowed I/O source of one of the supported kinds.
pub enum IOSource<'a> {
/// A native `mio` event source; this variant does not exist on `wasm32`.
#[cfg(not(target_arch = "wasm32"))]
MIO(&'a mut dyn mio::event::Source),
/// A source implementing [`GenericSource`].
Generic(&'a mut dyn GenericSource),
/// No underlying source.
Empty,
}
/// State shared by every clone of an [`IONotifier`].
pub struct IONotifierInner {
    /// Id delivered on the channels with every notification.
    pub(crate) eid: usize,
    /// Channel notified when a readable event is reported.
    pub(crate) inbound: EventSender<usize>,
    /// Channel notified when a writable event is reported.
    pub(crate) outbound: EventSender<usize>,
}

/// Cheap-to-clone (`Arc`-backed) handle that turns readiness events into channel messages.
///
/// The fields of [`IONotifierInner`] are reachable directly through `Deref`.
#[derive(Clone)]
pub struct IONotifier(Arc<IONotifierInner>);

impl std::ops::Deref for IONotifier {
    type Target = IONotifierInner;

    fn deref(&self) -> &IONotifierInner {
        let Self(shared) = self;
        shared
    }
}

impl IONotifier {
    /// Creates a notifier that reports `eid` on `inbound` (readable) and `outbound` (writable).
    pub fn new(eid: usize, inbound: EventSender<usize>, outbound: EventSender<usize>) -> Self {
        let shared = IONotifierInner { eid, inbound, outbound };
        IONotifier(Arc::new(shared))
    }

    /// Pairs each channel with whether `event` selects it, inbound first, outbound second.
    fn channels_for(&self, event: IOInterest) -> [(bool, &EventSender<usize>); 2] {
        [(event.is_readable(), &self.inbound), (event.is_writable(), &self.outbound)]
    }

    /// Asynchronously sends `eid` to the inbound channel if `event` is readable, then to the
    /// outbound channel if it is writable.
    ///
    /// Returns `None` as soon as one of those sends (or the wake-up that follows it) fails;
    /// any remaining channel is then skipped.
    pub async fn notify(&self, event: IOInterest) -> Option<()> {
        for (selected, sender) in self.channels_for(event) {
            if selected {
                sender.notify(self.eid).await?;
            }
        }
        Some(())
    }

    /// Blocking counterpart of [`IONotifier::notify`], built on `EventSender::blocking_notify`.
    ///
    /// Same ordering and early-`None` behaviour as the async version.
    pub fn blocking_notify(&self, event: IOInterest) -> Option<()> {
        self.channels_for(event)
            .iter()
            .filter(|(selected, _)| *selected)
            .try_for_each(|(_, sender)| sender.blocking_notify(self.eid))
    }
}
/// Native backend: readiness polling via `mio`, task spawning and timers via `tokio`.
#[cfg(not(target_arch = "wasm32"))]
mod arch {
    use super::*;

    /// Cloneable handle that wakes the [`IOPoll`] it was created from.
    #[derive(Clone)]
    pub struct IOWaker(Arc<mio::Waker>);

    impl IOWaker {
        /// Creates a `mio::Waker` on `registry` whose wake-ups are reported under `token`.
        fn new(registry: &mio::Registry, token: mio::Token) -> Result<Self, io::Error> {
            let waker = mio::Waker::new(registry, token)?;
            Ok(Self(Arc::new(waker)))
        }

        /// Signals the poller this waker was registered with.
        pub fn wake(&self) -> Result<(), io::Error> {
            let Self(waker) = self;
            waker.wake()
        }
    }

    /// Token identifying a poll registration (`mio::Token`).
    pub type IOToken = mio::Token;

    /// Builds an [`IOToken`] from a raw id.
    ///
    /// Presumably named like a tuple-struct constructor so that `IOToken(n)` reads the same as
    /// on the wasm backend, where `IOToken` is an actual tuple struct.
    #[allow(non_snake_case)]
    pub fn IOToken(id: usize) -> IOToken {
        mio::Token(id)
    }

    /// Readiness interest set (`mio::Interest`).
    pub type IOInterest = mio::Interest;
    /// Event buffer filled by [`IOPoll::poll`].
    pub type IOEvents = mio::event::Events;
    /// A single readiness event.
    pub type IOEvent = mio::event::Event;

    /// Thin wrapper over `mio::Poll`.
    pub struct IOPoll(mio::Poll);

    impl IOPoll {
        /// Creates a new poller.
        pub fn new() -> Result<Self, io::Error> {
            mio::Poll::new().map(IOPoll)
        }

        /// Registers `src` under `token` for `interest`.
        pub fn register(
            &mut self, src: &mut dyn mio::event::Source, token: IOToken, interest: IOInterest,
        ) -> Result<(), io::Error> {
            let registry = self.0.registry();
            registry.register(src, token, interest)
        }

        /// Removes `src` from this poller.
        pub fn deregister(&self, src: &mut dyn mio::event::Source) -> Result<(), io::Error> {
            let registry = self.0.registry();
            registry.deregister(src)
        }

        /// Updates the token and interest of an already registered `src`.
        pub fn reregister(
            &self, src: &mut dyn mio::event::Source, token: IOToken, interest: IOInterest,
        ) -> Result<(), io::Error> {
            let registry = self.0.registry();
            registry.reregister(src, token, interest)
        }

        /// Creates an [`IOWaker`] whose wake-ups are reported under `token`.
        pub fn waker(&mut self, token: IOToken) -> Result<IOWaker, io::Error> {
            let registry = self.0.registry();
            IOWaker::new(registry, token)
        }

        /// Waits for readiness with no timeout and fills `events` with what became ready.
        pub fn poll(&mut self, events: &mut IOEvents) -> Result<(), io::Error> {
            let Self(poller) = self;
            poller.poll(events, None)
        }
    }

    /// Owning handle to a spawned task; dropping it aborts the task.
    pub struct TaskHandle(tokio::task::JoinHandle<()>);

    impl Drop for TaskHandle {
        fn drop(&mut self) {
            let TaskHandle(join) = self;
            join.abort();
        }
    }

    /// Spawns futures onto a tokio runtime through its `Handle`.
    #[derive(Clone)]
    pub struct AsyncSpawner(pub tokio::runtime::Handle);

    impl AsyncSpawner {
        /// Spawns `fut`; the returned tokio `JoinHandle` detaches (does not abort) when dropped.
        #[inline]
        pub fn spawn<F: std::future::Future<Output = ()> + Send + 'static>(
            &self, fut: F,
        ) -> tokio::task::JoinHandle<()> {
            let AsyncSpawner(runtime) = self;
            runtime.spawn(fut)
        }

        /// Spawns `fut` and returns a [`TaskHandle`] that aborts it when dropped.
        #[inline]
        pub fn spawn_with_task_handle<F: std::future::Future<Output = ()> + Send + 'static>(
            &self, fut: F,
        ) -> TaskHandle {
            let join = self.spawn(fut);
            TaskHandle(join)
        }
    }

    /// Waits for `duration` on tokio's timer.
    pub async fn sleep(duration: std::time::Duration) {
        let timer = tokio::time::sleep(duration);
        timer.await
    }

    /// Runs `fut`, giving up with `Elapsed` once `duration` has passed.
    pub async fn timeout<F: std::future::Future<Output = T> + Send + 'static, T: Send + 'static>(
        duration: std::time::Duration, fut: F,
    ) -> Result<T, tokio::time::error::Elapsed> {
        let bounded = tokio::time::timeout(duration, fut);
        bounded.await
    }
}
/// `wasm32` backend: emulates the native polling/spawning API with `tokio::sync::Notify`,
/// `wasm_bindgen_futures` and JavaScript timers.
#[cfg(target_arch = "wasm32")]
mod arch {
use super::*;
use wasm_bindgen::prelude::*;
use wasm_bindgen_futures::JsFuture;
/// Cloneable handle that wakes the [`IOPoll`] it came from by signalling a shared
/// `tokio::sync::Notify`.
#[derive(Clone)]
pub struct IOWaker(Arc<tokio::sync::Notify>);
impl IOWaker {
/// Signals the poller; always returns `Ok` on this backend.
pub fn wake(&self) -> Result<(), io::Error> {
self.0.notify_one();
Ok(())
}
}
/// Token identifying a poll registration; a plain wrapper around `usize`.
#[derive(Copy, Clone)]
pub struct IOToken(pub usize);
/// Minimal interest bitset standing in for `mio::Interest`: bit 0 = readable,
/// bit 1 = writable.
#[derive(Copy, Clone, PartialEq, Eq)]
pub struct IOInterest(u8);
impl IOInterest {
/// No interest bits set.
pub const EMPTY: Self = Self(0);
/// Readable interest (bit 0).
pub const READABLE: Self = Self(1 << 0);
/// Writable interest (bit 1).
pub const WRITABLE: Self = Self(1 << 1);
/// Returns `true` if the readable bit is set.
pub fn is_readable(&self) -> bool {
*self & Self::READABLE != Self::EMPTY
}
/// Returns `true` if the writable bit is set.
pub fn is_writable(&self) -> bool {
*self & Self::WRITABLE != Self::EMPTY
}
}
// `a | b`: union of two interest sets.
impl std::ops::BitOr for IOInterest {
type Output = Self;
fn bitor(self, rhs: Self) -> Self::Output {
Self(self.0 | rhs.0)
}
}
// `a & b`: intersection of two interest sets (used by the `is_*` checks above).
impl std::ops::BitAnd for IOInterest {
type Output = Self;
fn bitand(self, rhs: Self) -> Self::Output {
Self(self.0 & rhs.0)
}
}
/// Event buffer filled by [`IOPoll::poll`].
pub type IOEvents = Vec<IOEvent>;
/// A single readiness event: the token it belongs to plus the ready interest bits.
pub struct IOEvent {
token: IOToken,
state: IOInterest,
}
impl IOEvent {
/// Token of the registration that became ready.
pub fn token(&self) -> IOToken {
self.token
}
/// Whether the event reports readability.
pub fn is_readable(&self) -> bool {
self.state.is_readable()
}
/// Whether the event reports writability.
pub fn is_writable(&self) -> bool {
self.state.is_writable()
}
}
/// Poller for the wasm backend. Only the waker is supported: there is no source
/// registration, and `poll` reports nothing but waker notifications.
pub struct IOPoll {
// Shared with every `IOWaker` handed out by `waker()`.
waker_notifier: Arc<tokio::sync::Notify>,
// Token reported for waker events; `None` until `waker()` has been called.
waker_token: Option<IOToken>,
}
impl IOPoll {
/// Creates a poller. `_capacity` is accepted but unused on this backend.
pub fn new(_capacity: usize) -> Result<Self, io::Error> {
let waker_notifier = Arc::new(tokio::sync::Notify::new());
Ok(IOPoll {
waker_notifier,
waker_token: None,
})
}
/// Records `token` as the waker token and returns a waker sharing this poller's notifier.
/// Calling it again replaces the token; all returned wakers share the same notifier.
pub fn waker(&mut self, token: IOToken) -> Result<IOWaker, io::Error> {
self.waker_token = Some(token);
Ok(IOWaker(self.waker_notifier.clone()))
}
/// Clears `events`, waits for an [`IOWaker::wake`], then pushes a single READABLE event
/// carrying the waker token.
///
/// NOTE(review): `tokio::sync::Notify` stores at most one permit, so several wakes before a
/// poll presumably collapse into one event — confirm callers tolerate this.
///
/// # Panics
/// Panics with "waker must be created" if [`IOPoll::waker`] has not been called.
pub async fn poll(&mut self, events: &mut IOEvents) -> Result<(), io::Error> {
events.clear();
let token = self.waker_token.as_ref().expect("waker must be created");
self.waker_notifier.notified().await;
events.push(IOEvent {
token: *token,
state: IOInterest::READABLE,
});
Ok(())
}
}
/// Placeholder task handle: it has no `Drop` behaviour, so dropping it does not abort the task.
pub struct TaskHandle;
/// Spawns futures either on a `wasm_futures_executor` thread pool or on the current thread
/// via `wasm_bindgen_futures::spawn_local`.
#[derive(Clone)]
pub enum AsyncSpawner {
/// Spawn onto the given thread pool with `spawn_ok`.
WASMExecutor(wasm_futures_executor::ThreadPool),
/// Spawn onto the current thread with `wasm_bindgen_futures::spawn_local`.
SingleThreaded,
}
impl AsyncSpawner {
/// Spawns `fut` detached; nothing is returned that could await or cancel it.
pub fn spawn<F: std::future::Future<Output = ()> + Send + 'static>(&self, fut: F) {
match self {
Self::WASMExecutor(spawner) => spawner.spawn_ok(fut),
Self::SingleThreaded => wasm_bindgen_futures::spawn_local(fut),
}
}
/// Spawns `fut` and returns a placeholder [`TaskHandle`] (no cancellation on this backend).
pub fn spawn_with_task_handle<F: std::future::Future<Output = ()> + Send + 'static>(
&self, fut: F,
) -> TaskHandle {
self.spawn(fut);
TaskHandle
}
}
/// Sleeps for `duration` using a JavaScript `setTimeout`.
///
/// The delay is whole milliseconds (`as_millis`) clamped to `i32::MAX`. The timer runs in a
/// task spawned with `spawn_local`, which wakes this future through a `Notify` when the
/// timeout promise settles. If scheduling the timeout fails, the promise is rejected, the
/// error is discarded, and the sleep ends early.
pub async fn sleep(duration: std::time::Duration) {
static I32_MAX: u128 = i32::MAX as u128;
let notifier = Arc::new(tokio::sync::Notify::new());
let notifier_clone = notifier.clone();
wasm_bindgen_futures::spawn_local(async move {
let millis = duration.as_millis();
let delay = if millis > I32_MAX { i32::MAX } else { millis as i32 };
let mut cb = |resolve: js_sys::Function, reject: js_sys::Function| {
// Unchecked cast: assumes this runs inside a dedicated Web Worker.
let scope = js_sys::global().unchecked_into::<web_sys::DedicatedWorkerGlobalScope>();
if let Err(e) = scope.set_timeout_with_callback_and_timeout_and_arguments_0(&resolve, delay) {
reject.call1(&JsValue::NULL, &e).ok();
}
};
let p = js_sys::Promise::new(&mut cb);
// Resolved or rejected, either way the sleep is over.
JsFuture::from(p).await.ok();
notifier_clone.notify_one();
});
notifier.notified().await
}
/// Races `fut` against [`sleep`]`(duration)` and returns an `io::ErrorKind::TimedOut` error
/// ("Timeout") if the timer wins.
///
/// Both `fut` and the timer run as detached `spawn_local` tasks, so on timeout `fut` is not
/// cancelled: it keeps running, and its result is stored and then dropped unused. Unlike the
/// native backend (which returns `tokio::time::error::Elapsed`), the error type is
/// `io::Error`.
pub async fn timeout<F: std::future::Future<Output = T> + Send + 'static, T: Send + 'static>(
duration: std::time::Duration, fut: F,
) -> Result<T, io::Error> {
let notifier_fut = Arc::new(tokio::sync::Notify::new());
let notifier_timeout = Arc::new(tokio::sync::Notify::new());
let notifier_fut_clone = notifier_fut.clone();
let notifier_timeout_clone = notifier_timeout.clone();
let result = Arc::new(tokio::sync::Mutex::new(None));
let result_clone = Arc::clone(&result);
wasm_bindgen_futures::spawn_local(async move {
let res = fut.await;
*result_clone.lock().await = Some(res);
notifier_fut_clone.notify_one();
});
wasm_bindgen_futures::spawn_local(async move {
sleep(duration).await;
notifier_timeout_clone.notify_one();
});
// The result is stored before `notifier_fut` is notified, so the `unwrap` in the first arm
// cannot fail.
tokio::select! {
_ = notifier_fut.notified() => {
let res = result.lock().await.take().unwrap();
Ok(res)
},
_ = notifier_timeout.notified() => Err(io::Error::new(io::ErrorKind::TimedOut, "Timeout")),
}
}
}
pub use arch::*;
/// Receiving half of an async readiness channel created by [`new_poll_event`].
pub struct EventReceiver<T> {
    rx: mpsc::Receiver<T>,
}

/// Sending half of an async readiness channel; every successful send also wakes the poller
/// through the shared [`IOWaker`].
pub struct EventSender<T> {
    waker: IOWaker,
    tx: mpsc::Sender<T>,
}

// Written by hand: `#[derive(Clone)]` would add a needless `T: Clone` bound, but only the
// channel handle and the waker are cloned.
impl<T> Clone for EventSender<T> {
    fn clone(&self) -> Self {
        let waker = self.waker.clone();
        let tx = self.tx.clone();
        EventSender { tx, waker }
    }
}

/// Creates a bounded tokio channel with `channel_size` slots whose sender wakes `waker`
/// after each send.
///
/// # Panics
/// `tokio::sync::mpsc::channel` panics when `channel_size` is 0.
pub fn new_poll_event<T>(waker: &IOWaker, channel_size: usize) -> (EventSender<T>, EventReceiver<T>) {
    let (tx, rx) = mpsc::channel(channel_size);
    let sender = EventSender { waker: waker.clone(), tx };
    let receiver = EventReceiver { rx };
    (sender, receiver)
}
impl<T> EventSender<T> {
    /// Sends `v`, blocking the current thread while the channel is full, then wakes the poller.
    ///
    /// Returns `None` if the receiver is gone or the wake-up fails.
    ///
    /// # Panics
    /// Per tokio's documentation, `blocking_send` panics when called from within an
    /// asynchronous execution context; use [`EventSender::notify`] there.
    pub fn blocking_notify(&self, v: T) -> Option<()> {
        match self.tx.blocking_send(v) {
            Ok(()) => self.waker.wake().ok(),
            Err(_) => None,
        }
    }

    /// Sends `v`, waiting asynchronously for capacity, then wakes the poller.
    ///
    /// Returns `None` if the receiver is gone or the wake-up fails.
    pub async fn notify(&self, v: T) -> Option<()> {
        if self.tx.send(v).await.is_err() {
            return None;
        }
        self.waker.wake().ok()
    }
}

impl<T> EventReceiver<T> {
    /// Returns the next queued value without waiting, or `None` if the channel is currently
    /// empty or disconnected (the two cases are not distinguished).
    pub fn try_listen(&mut self) -> Option<T> {
        self.rx.try_recv().ok()
    }

    /// Waits for the next value; yields `None` once all senders are gone and the queue is empty.
    #[allow(dead_code)]
    pub async fn listen(&mut self) -> Option<T> {
        self.rx.recv().await
    }
}

/// Async readiness channel halves carrying `usize` ids (presumably endpoint ids).
pub type EndpointEventReceiver = EventReceiver<usize>;
pub type EndpointEventSender = EventSender<usize>;
/// Receiving half of a bounded `std::sync::mpsc` readiness channel built by
/// [`new_poll_event_blocking`].
pub struct EventBlockingReceiver<T> {
    rx: std::sync::mpsc::Receiver<T>,
}

/// Sending half of a bounded `std::sync::mpsc` readiness channel; each successful send also
/// wakes the poller through the shared [`IOWaker`].
pub struct EventBlockingSender<T> {
    waker: IOWaker,
    tx: std::sync::mpsc::SyncSender<T>,
}

// Written by hand: `#[derive(Clone)]` would add a needless `T: Clone` bound.
impl<T> Clone for EventBlockingSender<T> {
    fn clone(&self) -> Self {
        let waker = self.waker.clone();
        let tx = self.tx.clone();
        EventBlockingSender { tx, waker }
    }
}

/// Creates a bounded `std::sync::mpsc::sync_channel(channel_size)` whose sender wakes `waker`
/// after each send.
pub fn new_poll_event_blocking<T>(
    waker: &IOWaker, channel_size: usize,
) -> (EventBlockingSender<T>, EventBlockingReceiver<T>) {
    let (tx, rx) = std::sync::mpsc::sync_channel(channel_size);
    let sender = EventBlockingSender { waker: waker.clone(), tx };
    (sender, EventBlockingReceiver { rx })
}

impl<T> EventBlockingSender<T> {
    /// Sends `v`, blocking the current thread while the channel is full, then wakes the poller.
    ///
    /// Returns `None` if the receiver has been dropped or the wake-up fails.
    pub fn notify(&self, v: T) -> Option<()> {
        self.tx.send(v).ok().and_then(|()| self.waker.wake().ok())
    }
}

impl<T> EventBlockingReceiver<T> {
    /// Returns the next queued value *without blocking* (despite the name), or `None` if the
    /// channel is currently empty or disconnected.
    pub fn listen(&mut self) -> Option<T> {
        self.rx.try_recv().ok()
    }
}
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use tracing::Level;
pub struct TaskPool {
_loop: TaskHandle,
task_tx: mpsc::Sender<Pin<Box<dyn Future<Output = ()> + Send>>>,
}
macro_rules! taskpool_log {
($lvl: expr, $id: expr, $($arg: tt)+) => {
tracing::event!(target: "TaskPool", $lvl, $($arg)+)
}
}
impl TaskPool {
async fn control_loop(
max_pooled: usize, mut task_rx: mpsc::Receiver<Pin<Box<dyn Future<Output = ()> + Send>>>, spawner: AsyncSpawner,
) {
let mut task_id: usize = 0;
let quota = Arc::new(tokio::sync::Semaphore::new(max_pooled));
let mut pending = HashMap::new(); let mut finished_ids = mpsc::unbounded_channel(); loop {
let permit = match quota.clone().acquire_owned().await {
Err(e) => {
taskpool_log!(Level::ERROR, local_id, "loop: failed to acquire quota: {:?}", e);
continue
}
Ok(p) => p,
};
while let Ok(id) = finished_ids.1.try_recv() {
pending.remove(&id);
}
let fut = match task_rx.recv().await {
Some(t) => t,
None => return,
};
let id = task_id;
task_id = task_id.wrapping_add(1);
let finished = finished_ids.0.clone();
pending.insert(
id,
spawner.spawn_with_task_handle(async move {
fut.await;
finished.send(id).ok();
drop(permit);
}),
);
}
}
pub fn new(max_pooled: usize, spawner: &AsyncSpawner) -> Self {
let (task_tx, task_rx) = mpsc::channel(10);
Self {
_loop: spawner.spawn_with_task_handle(Self::control_loop(max_pooled, task_rx, spawner.clone())),
task_tx,
}
}
pub async fn submit<F: Future<Output = ()> + Send + 'static>(&self, fut: F) -> Result<(), ()> {
self.task_tx.send(Box::pin(fut)).await.map_err(|_| ())
}
}
use parking_lot::Mutex;
use std::sync::atomic::{AtomicUsize, Ordering};
pub struct NonblockingTaskPool {
task_id: AtomicUsize,
quota: Arc<tokio::sync::Semaphore>,
pending: Arc<Mutex<HashMap<usize, TaskHandle>>>,
spawner: AsyncSpawner,
}
impl NonblockingTaskPool {
pub fn new(max_pooled: usize, spawner: &AsyncSpawner) -> Self {
Self {
task_id: AtomicUsize::new(0),
quota: Arc::new(tokio::sync::Semaphore::new(max_pooled)),
pending: Arc::new(Mutex::new(HashMap::new())),
spawner: spawner.clone(),
}
}
pub fn try_submit<F: Future<Output = ()> + Send + 'static>(&self, fut: F) -> Result<(), ()> {
let permit = self.quota.clone().try_acquire_owned().map_err(|_| ())?;
let id = self.task_id.fetch_add(1, Ordering::Relaxed);
let spawner = self.spawner.clone();
let pending = self.pending.clone();
self.pending.lock().insert(
id,
spawner.spawn_with_task_handle(async move {
fut.await;
pending.lock().remove(&id);
drop(permit);
}),
);
Ok(())
}
}