#![forbid(unsafe_code)]
#![warn(clippy::pedantic)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::missing_panics_doc)]
use std::{
cell::RefCell,
future::Future,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll, Wake, Waker},
thread::Thread,
};
use crossbeam_queue::SegQueue;
use futures::{channel::oneshot, FutureExt};
/// A type-erased, heap-pinned unit of work queued on an [`Executor`].
///
/// The boxed future yields `()` and must be `Send + Sync`. Being a
/// `Box<dyn …>`, its object lifetime defaults to `'static`, so it cannot
/// borrow shorter-lived data.
pub type Task = Pin<Box<dyn Future<Output = ()> + Sync + Send>>;
thread_local! {
    /// The executor installed on this thread by `Executor::register`, or
    /// `None` if none has been registered. Read by `context`.
    // `const` initializer: the value is a constant expression, so no lazy
    // initialization check is needed on each access.
    static EXECUTOR: RefCell<Option<Arc<Executor>>> = const { RefCell::new(None) };
}
/// Runs `f` with the executor registered on the current thread and returns
/// its result.
///
/// The thread-local slot stays borrowed while `f` runs, so `f` must not
/// register a new executor on this thread.
///
/// # Panics
///
/// Panics if no executor has been registered on this thread.
pub(crate) fn context<R>(f: impl FnOnce(&Arc<Executor>) -> R) -> R {
    EXECUTOR.with(|slot| {
        let guard = slot.borrow();
        match guard.as_ref() {
            Some(executor) => f(executor),
            None => panic!("spawn called outside of an executor context"),
        }
    })
}
pub fn spawn<F>(fut: F) -> JoinHandle<F::Output>
where
F: Future + Send + Sync + 'static,
F::Output: Send,
{
context(|e| e.spawn(fut))
}
/// A task executor whose queued tasks are run by any thread that calls
/// `poll_once` on it.
#[derive(Default)]
pub struct Executor {
    // Tasks that are ready to be polled.
    tasks: SegQueue<Task>,
    // Threads that found `tasks` empty in `poll_once` and parked; `wake`
    // pops one of them to unpark after pushing a task. Entries may be stale
    // (a thread can be listed more than once, or after it resumed).
    threads: SegQueue<Thread>,
}
impl Executor {
    /// Installs this executor as the current thread's executor.
    ///
    /// After this call, the free function [`spawn`] called on this thread
    /// submits tasks to it. Any executor previously registered on this thread
    /// is replaced.
    pub fn register(self: &Arc<Self>) {
        EXECUTOR.with(|exec| *exec.borrow_mut() = Some(Arc::clone(self)));
    }

    /// Pushes a runnable task onto the queue and unparks one idle thread, if
    /// any has advertised itself in `threads`.
    fn wake(&self, task: Task) {
        self.tasks.push(task);
        if let Some(t) = self.threads.pop() {
            t.unpark();
        }
    }

    /// Pops the next runnable task, parking the current thread while the
    /// queue is empty.
    fn next_task(&self) -> Task {
        loop {
            if let Some(task) = self.tasks.pop() {
                return task;
            }
            // Advertise this thread as idle *before* the final check. A task
            // pushed after the first `pop` is then either seen by the check
            // below, or its `wake` finds this thread in `threads` and unparks
            // it. Without the re-check, such a task could sit in the queue
            // while this thread sleeps.
            self.threads.push(std::thread::current());
            if let Some(task) = self.tasks.pop() {
                // Our `threads` entry may now be stale; a later unpark just
                // causes a spurious wake-up, which this loop tolerates.
                return task;
            }
            std::thread::park();
        }
    }

    /// Takes one task from the queue, parking until one is available, and
    /// polls it once.
    ///
    /// If the task is still pending, it is handed to a fresh [`TaskWaker`],
    /// which puts it back on the queue when woken. A wake that arrives while
    /// the task is being polled is recorded, and the task is re-queued
    /// immediately instead of being lost.
    pub fn poll_once(self: Arc<Self>) {
        let mut task = self.next_task();
        let task_waker = Arc::new(TaskWaker {
            executor: Arc::clone(&self),
            state: Mutex::new(WakerState::Polling),
        });
        let waker = Waker::from(Arc::clone(&task_waker));
        let mut cx = Context::from_waker(&waker);

        if task.as_mut().poll(&mut cx).is_ready() {
            // Finished: the task is dropped here. Later wakes through stale
            // wakers only flip the waker's state and have no effect.
            return;
        }

        let mut state = task_waker.state.lock().unwrap();
        let previous = std::mem::replace(&mut *state, WakerState::Done);
        match previous {
            // Woken while being polled: the future expects another poll.
            WakerState::Notified => {
                drop(state);
                self.wake(task);
            }
            // No wake yet: park the task on the waker until it fires.
            WakerState::Polling => *state = WakerState::Waiting(task),
            WakerState::Waiting(_) | WakerState::Done => {
                unreachable!("task waker left the polling state while its task was being polled")
            }
        }
    }

    /// Queues `fut` as a new task and returns a handle to its output.
    ///
    /// If the returned [`JoinHandle`] is dropped, the task still runs and its
    /// output is discarded.
    pub fn spawn<F>(&self, fut: F) -> JoinHandle<F::Output>
    where
        F: Future + Send + Sync + 'static,
        F::Output: Send,
    {
        let (sender, handle) = JoinHandle::new();
        // `send` fails only when the handle was dropped; ignore that case.
        let fut = Box::pin(fut.map(|out| sender.send(out).unwrap_or_default()));
        self.wake(fut);
        handle
    }
}

/// Waker created for a single poll of a task.
///
/// It takes ownership of the task if that poll returns `Pending`.
struct TaskWaker {
    executor: Arc<Executor>,
    state: Mutex<WakerState>,
}

/// Where a task stands relative to the [`TaskWaker`] created for its latest poll.
enum WakerState {
    /// `poll_once` is currently polling the task.
    Polling,
    /// The waker fired during the poll; `poll_once` will re-queue the task.
    Notified,
    /// The poll returned `Pending`; the task waits here until woken.
    Waiting(Task),
    /// The task was handed back to the executor; further wakes are no-ops.
    Done,
}

impl Wake for TaskWaker {
    fn wake(self: Arc<Self>) {
        self.wake_by_ref();
    }

    /// Re-queues the task if it is waiting on this waker.
    ///
    /// If the task is mid-poll, the wake is recorded so that `poll_once`
    /// re-queues the task instead of the wake being lost.
    fn wake_by_ref(self: &Arc<Self>) {
        let mut state = self.state.lock().unwrap();
        let previous = std::mem::replace(&mut *state, WakerState::Done);
        match previous {
            WakerState::Waiting(task) => {
                // Release the lock before touching the executor's queues.
                drop(state);
                self.executor.wake(task);
            }
            WakerState::Polling | WakerState::Notified => *state = WakerState::Notified,
            WakerState::Done => {}
        }
    }
}
/// Handle to a spawned task; awaiting it yields the task's output.
///
/// Dropping the handle does not cancel the task: the task keeps running and
/// its output is discarded.
pub struct JoinHandle<R>(oneshot::Receiver<R>);
// Unconditionally `Unpin`, whatever the output type, so `poll` can reach the
// inner receiver through `Pin<&mut Self>` without pin projection.
impl<T> Unpin for JoinHandle<T> {}
impl<R> JoinHandle<R> {
    /// Creates a connected pair: a sender that delivers the task's output and
    /// the handle that receives it.
    #[must_use]
    pub fn new() -> (oneshot::Sender<R>, Self) {
        let (tx, rx) = oneshot::channel::<R>();
        let handle = Self(rx);
        (tx, handle)
    }
}
impl<R> Future for JoinHandle<R> {
    type Output = R;

    /// Polls for the spawned task's output.
    ///
    /// # Panics
    ///
    /// Panics if the task was dropped before producing a value. This happens,
    /// for example, if its poll panicked or it was discarded while still
    /// pending, because that also drops the sender half of the channel.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.0
            .poll_unpin(cx)
            .map(|res| res.expect("spawned task was dropped before it produced a value"))
    }
}