use async_task::{Builder, Runnable};
use slab::Slab;
use std::{
cell::RefCell,
future::Future,
rc::Rc,
sync::{
atomic::{AtomicBool, Ordering},
mpsc, Arc, Mutex,
},
task::Waker,
};
use crate::{
sources::{
channel::ChannelError,
ping::{make_ping, Ping, PingError, PingSource},
EventSource,
},
Poll, PostAction, Readiness, Token, TokenFactory,
};
/// Event source that runs futures submitted through its paired [`Scheduler`].
///
/// The output of every completed future is passed to the callback given when
/// this source is inserted into the event loop.
#[derive(Debug)]
pub struct Executor<T> {
    /// State shared (via `Rc`) with the scheduler created alongside this executor.
    state: Rc<State<T>>,
    /// Readiness source paired with `Sender::wake_up`; pinged when runnables are queued.
    ping: PingSource,
}
/// Handle used to submit futures to the paired [`Executor`].
///
/// It holds an `Rc` to the shared state, so it is confined to the thread that
/// created it.
#[derive(Clone, Debug)]
pub struct Scheduler<T> {
    /// State shared with the executor.
    state: Rc<State<T>>,
}
/// State shared between an [`Executor`] and its [`Scheduler`].
#[derive(Debug)]
struct State<T> {
    /// Receiving end of the queue of runnables waiting to be polled; the index
    /// metadata is the task's key in `active_tasks`.
    incoming: mpsc::Receiver<Runnable<usize>>,
    /// Sending side of the queue, cloned into every task's schedule function
    /// (which must be `Send + Sync`).
    sender: Arc<Sender>,
    /// Per-task bookkeeping keyed by slab index. Set to `None` when the executor
    /// is dropped, after which scheduling fails with `ExecutorDestroyed`.
    active_tasks: RefCell<Option<Slab<Active<T>>>>,
}
/// Handle used by task schedule functions to queue runnables and wake the executor.
#[derive(Debug)]
struct Sender {
    /// Channel into `State::incoming`. A poisoned lock is recovered in `Sender::send`.
    sender: Mutex<mpsc::Sender<Runnable<usize>>>,
    /// Pinged to make the executor's `PingSource` ready.
    wake_up: Ping,
    /// Set once a wake-up ping has been sent. Reset at the start of
    /// `process_events` so the next queued runnable pings again.
    notified: AtomicBool,
}
/// Slab entry tracking the lifecycle of one scheduled task.
#[derive(Debug)]
enum Active<T> {
    /// Still running; holds the waker of the task's runnable.
    Future(Waker),
    /// Completed; holds the output until it is handed to the executor's callback.
    Finished(T),
}

impl<T> Active<T> {
    /// Returns `true` once the task's output has been stored.
    fn is_finished(&self) -> bool {
        match self {
            Active::Finished(_) => true,
            Active::Future(_) => false,
        }
    }
}
impl<T> Scheduler<T> {
    /// Submit `future` to the paired [`Executor`].
    ///
    /// The future is first polled the next time the executor processes events.
    /// Once it completes, its output is handed to the callback registered with the
    /// executor's event source.
    ///
    /// # Errors
    ///
    /// Returns [`ExecutorDestroyed`] if the executor has already been dropped.
    pub fn schedule<Fut: 'static>(&self, future: Fut) -> Result<(), ExecutorDestroyed>
    where
        Fut: Future<Output = T>,
        T: 'static,
    {
        /// Records the wrapped future's outcome in slab slot `index` when dropped.
        ///
        /// If `value` was set, the slot becomes `Active::Finished(value)`.
        /// Otherwise the future was dropped before completing, and the slot is
        /// removed. Does nothing once the executor has taken the slab out of
        /// `State`.
        struct StoreOnDrop<'a, T> {
            index: usize,
            value: Option<T>,
            state: &'a State<T>,
        }
        impl<T> Drop for StoreOnDrop<'_, T> {
            fn drop(&mut self) {
                let mut active_tasks = self.state.active_tasks.borrow_mut();
                if let Some(active_tasks) = active_tasks.as_mut() {
                    if let Some(value) = self.value.take() {
                        active_tasks[self.index] = Active::Finished(value);
                    } else {
                        active_tasks.remove(self.index);
                    }
                }
            }
        }
        /// Compile-time check only: the schedule function must be `Send + Sync`,
        /// since it is reached through wakers, which can cross threads.
        fn assert_send_and_sync<T: Send + Sync>(_: &T) {}
        let mut active_guard = self.state.active_tasks.borrow_mut();
        let active_tasks = active_guard.as_mut().ok_or(ExecutorDestroyed)?;
        // Reserve the slot this task will occupy. The key is stored in the runnable's
        // metadata and in `StoreOnDrop`, so it must equal the key returned by the
        // `insert` below. `active_guard` stays held until then, so no other insertion
        // can claim it first.
        let index = active_tasks.vacant_key();
        // Wrap the user future so its outcome is written to slot `index` when it
        // completes or is dropped.
        let future = {
            let state = self.state.clone();
            async move {
                let mut guard = StoreOnDrop {
                    index,
                    value: None,
                    state: &state,
                };
                let value = future.await;
                guard.value = Some(value);
            }
        };
        // Invoked whenever the task needs polling: pushes its runnable onto the
        // executor's queue and pings the event loop.
        let schedule = {
            let sender = self.state.sender.clone();
            move |runnable| sender.send(runnable)
        };
        assert_send_and_sync(&schedule);
        // `spawn_local` because `Fut` is not required to be `Send`.
        let (runnable, task) = Builder::new()
            .metadata(index)
            .spawn_local(move |_| future, schedule);
        active_tasks.insert(Active::Future(runnable.waker()));
        // Release the slab borrow before queueing the first poll.
        drop(active_guard);
        runnable.schedule();
        // Detach the `Task` handle so dropping it does not cancel the future.
        task.detach();
        Ok(())
    }
}
impl Sender {
    /// Queue `runnable` for the executor and ping the event loop.
    ///
    /// The ping is only sent if the `notified` flag was not already set.
    ///
    /// # Panics
    ///
    /// Panics if the executor's receiving end has gone away. In that case the
    /// runnable is leaked with `mem::forget` instead of being dropped here.
    fn send(&self, runnable: Runnable<usize>) {
        // Only `send` runs under the lock, so a poisoned mutex cannot hide
        // half-updated state; recover the guard rather than propagating the poison.
        match self
            .sender
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .send(runnable)
        {
            Ok(()) => {}
            Err(undelivered) => {
                // The runnable is deliberately not dropped on this (possibly
                // foreign) thread. Presumably this is because the wrapped future
                // was spawned with `spawn_local` and may be `!Send`. Leak it
                // instead.
                std::mem::forget(undelivered);
                unreachable!("Attempted to send runnable to a stopped executor");
            }
        }

        // Only the first sender since the last `process_events` needs to wake
        // the loop; later ones find the flag already raised.
        let already_notified = self.notified.swap(true, Ordering::SeqCst);
        if !already_notified {
            self.wake_up.ping();
        }
    }
}
impl<T> Drop for Executor<T> {
    /// Tear down every task that is still pending.
    ///
    /// The teardown runs in three steps:
    /// - Take the slab out of the shared state, so later `schedule` calls return
    ///   [`ExecutorDestroyed`] and `StoreOnDrop` becomes a no-op.
    /// - Wake each pending task, which pushes its runnable onto the queue.
    /// - Drain the queue, so those runnables are dropped here on the executor's thread.
    fn drop(&mut self) {
        let active_tasks = self.state.active_tasks.borrow_mut().take().unwrap();
        for (_, task) in active_tasks {
            if let Active::Future(waker) = task {
                // Any panic from `wake` is caught and discarded so teardown keeps
                // going (and, if this drop runs during unwinding, a second panic
                // cannot escalate into an abort).
                std::panic::catch_unwind(|| waker.wake()).ok();
            }
        }
        // Drop every queued runnable; presumably this drops the task futures
        // they own (async-task semantics).
        while self.state.incoming.try_recv().is_ok() {}
    }
}
/// Error returned by [`Scheduler::schedule`] when the [`Executor`] has already
/// been dropped.
#[derive(thiserror::Error, Debug)]
#[error("the executor was destroyed")]
pub struct ExecutorDestroyed;
/// Create a future executor together with the scheduler that feeds it.
///
/// The returned [`Executor`] is an event source to insert into the event loop.
/// The [`Scheduler`] is the handle used to submit futures to it.
///
/// # Errors
///
/// Fails if the wake-up ping cannot be created.
pub fn executor<T>() -> crate::Result<(Executor<T>, Scheduler<T>)> {
    let (runnable_tx, runnable_rx) = mpsc::channel();
    let (wake_up, ping) = make_ping()?;

    let sender = Arc::new(Sender {
        sender: Mutex::new(runnable_tx),
        wake_up,
        notified: AtomicBool::new(false),
    });
    let shared = Rc::new(State {
        incoming: runnable_rx,
        active_tasks: RefCell::new(Some(Slab::new())),
        sender,
    });

    let scheduler = Scheduler {
        state: Rc::clone(&shared),
    };
    let executor = Executor {
        state: shared,
        ping,
    };
    Ok((executor, scheduler))
}
impl<T> EventSource for Executor<T> {
    type Event = T;
    type Metadata = ();
    type Ret = ();
    type Error = ExecutorError;

    /// Run queued runnables and pass each completed future's output to `callback`.
    ///
    /// At most `MAX_RUNNABLES_PER_DISPATCH` runnables are run per call. The
    /// wake-up ping's readiness is cleared only when the queue was fully drained,
    /// so leftover work should leave the source ready for a later dispatch.
    ///
    /// `callback` is invoked with no borrow of the task slab held. It may
    /// therefore call [`Scheduler::schedule`] to submit further futures.
    ///
    /// # Errors
    ///
    /// Returns [`ExecutorError::WakeError`] if clearing the ping's readiness fails.
    fn process_events<F>(
        &mut self,
        readiness: Readiness,
        token: Token,
        mut callback: F,
    ) -> Result<PostAction, Self::Error>
    where
        F: FnMut(T, &mut ()),
    {
        /// Upper bound on the number of runnables processed in one call.
        const MAX_RUNNABLES_PER_DISPATCH: usize = 1024;

        let state = &self.state;

        // Reset first, so any task woken while we are running others pings again.
        state.sender.notified.store(false, Ordering::SeqCst);

        let mut queue_drained = false;
        for _ in 0..MAX_RUNNABLES_PER_DISPATCH {
            let runnable = match state.incoming.try_recv() {
                Ok(runnable) => runnable,
                Err(_) => {
                    queue_drained = true;
                    break;
                }
            };

            let index = *runnable.metadata();
            runnable.run();

            // Extract the output (if the task just finished) in its own scope so the
            // `RefCell` borrow is released BEFORE `callback` runs. The callback may
            // call `Scheduler::schedule`, which also borrows `active_tasks`
            // mutably; holding the guard across the callback would panic with
            // `BorrowMutError`.
            let output = {
                let mut active_guard = state.active_tasks.borrow_mut();
                let active_tasks = active_guard
                    .as_mut()
                    .expect("active task slab is only taken when the executor is dropped");
                match active_tasks.get(index) {
                    Some(task) if task.is_finished() => match active_tasks.remove(index) {
                        Active::Finished(output) => Some(output),
                        Active::Future(_) => unreachable!("slot was just checked to be finished"),
                    },
                    _ => None,
                }
            };

            if let Some(output) = output {
                callback(output, &mut ());
            }
        }

        // Only clear the ping once no runnables remain queued.
        if queue_drained {
            self.ping
                .process_events(readiness, token, |(), &mut ()| {})
                .map_err(ExecutorError::WakeError)?;
        }

        Ok(PostAction::Continue)
    }

    /// Register the wake-up ping source with `poll`.
    fn register(&mut self, poll: &mut Poll, token_factory: &mut TokenFactory) -> crate::Result<()> {
        self.ping.register(poll, token_factory)?;
        Ok(())
    }

    /// Re-register the wake-up ping source with `poll`.
    fn reregister(
        &mut self,
        poll: &mut Poll,
        token_factory: &mut TokenFactory,
    ) -> crate::Result<()> {
        self.ping.reregister(poll, token_factory)?;
        Ok(())
    }

    /// Unregister the wake-up ping source from `poll`.
    fn unregister(&mut self, poll: &mut Poll) -> crate::Result<()> {
        self.ping.unregister(poll)?;
        Ok(())
    }
}
/// Errors that can occur while the [`Executor`] processes events.
#[derive(thiserror::Error, Debug)]
pub enum ExecutorError {
    /// Error adding new futures to the executor.
    // NOTE(review): not constructed anywhere in this file — confirm it is still needed.
    #[error("error adding new futures")]
    NewFutureError(ChannelError),
    /// Error returned while clearing the readiness of the wake-up ping source.
    #[error("error processing wake events")]
    WakeError(PingError),
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// A ready future is run on the first dispatch after it is scheduled, and its
    /// output reaches the event source callback.
    #[test]
    fn ready() {
        let mut event_loop = crate::EventLoop::<u32>::try_new().unwrap();
        let loop_handle = event_loop.handle();

        let (executor_source, scheduler) = executor::<u32>().unwrap();
        loop_handle
            .insert_source(executor_source, move |output, &mut (), received| {
                *received = output;
            })
            .unwrap();

        let mut received = 0;
        let answer = async { 42 };

        // Nothing has been scheduled yet, so a dispatch must leave `received` untouched.
        event_loop
            .dispatch(Some(Duration::ZERO), &mut received)
            .unwrap();
        assert_eq!(received, 0);

        // After scheduling, the next dispatch runs the future to completion.
        scheduler.schedule(answer).unwrap();
        event_loop
            .dispatch(Some(Duration::ZERO), &mut received)
            .unwrap();
        assert_eq!(received, 42);
    }
}