#![warn(missing_docs)]
mod concurrency;
pub mod io;
mod reactor;
#[cfg(test)]
mod test;
pub mod time;
use std::{
cell::{Cell, RefCell, UnsafeCell},
collections::VecDeque,
fmt::Debug,
future::{poll_fn, Future},
num::NonZero,
pin::{pin, Pin},
rc::Rc,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
task::{Context, Poll, Wake, Waker},
thread::{self, ThreadId},
};
use atomic_waker::AtomicWaker;
use concurrent_queue::ConcurrentQueue;
use futures_core::future::LocalBoxFuture;
use slab::Slab;
#[doc(hidden)]
pub use concurrency::{JoinFuture, MergeFutureStream, MergeStream};
pub use io::Async;
use reactor::{Notifier, REACTOR};
/// Numeric identifier that can never be zero.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Ord, Eq, Hash)]
struct Id(NonZero<usize>);

impl Id {
    /// Builds an identifier from `n`.
    ///
    /// # Panics
    ///
    /// Panics if `n` is zero.
    const fn new(n: usize) -> Self {
        match NonZero::new(n) {
            Some(inner) => Self(inner),
            None => panic!("expected non-zero ID"),
        }
    }

    /// Returns the identifier following `self`, wrapping around to `1` when
    /// the underlying integer would overflow.
    const fn overflowing_incr(&self) -> Self {
        if let Some(successor) = self.0.checked_add(1) {
            Self(successor)
        } else {
            // Evaluated at compile time, so the wrap-around target can never
            // hit the panic path of `new` at runtime.
            const { Id::new(1) }
        }
    }
}
/// Lets an `Arc<Notifier>` be turned into a `Waker`: waking forwards to
/// `Notifier::notify`.
impl Wake for Notifier {
    fn wake(self: Arc<Self>) {
        // `Wake::wake` has no way to report failure, so whatever `notify`
        // returns is deliberately discarded.
        let _outcome = self.notify();
    }
}
pub fn block_on<T, F>(mut fut: F) -> T
where
F: Future<Output = T>,
{
let mut fut = pin!(fut);
let waker = REACTOR.with(|r| r.notifier()).into();
loop {
if let Poll::Ready(out) = fut.as_mut().poll(&mut Context::from_waker(&waker)) {
return out;
}
let wait_res = REACTOR.with(|r| r.wait());
if let Err(err) = wait_res {
log::error!(
"{:?} Error polling reactor: {err}",
std::thread::current().id()
);
}
}
}
/// Queue of task IDs that have been woken and are waiting to be polled.
///
/// Wakeups pushed from the thread that created the queue are stored in
/// `local`; wakeups pushed from any other thread are stored in `concurrent`.
#[derive(Debug)]
struct WakeQueue {
/// Waker that `TaskWaker` wakes after pushing a task ID; it is registered
/// through `Executor::register_base_waker`.
base_waker: AtomicWaker,
/// Thread that created the queue. The methods of `WakeQueue` only touch
/// `local` after checking that they run on this thread.
local_thread: ThreadId,
/// Wakeups pushed from `local_thread`.
local: UnsafeCell<VecDeque<usize>>,
/// Wakeups pushed from every other thread.
concurrent: ConcurrentQueue<usize>,
}
// SAFETY (NOTE(review) — confirm): `UnsafeCell` makes this type `!Sync` by
// default. Sharing it is only sound because every method of `WakeQueue` that
// touches `local` first checks that it runs on `local_thread`; any new access
// to `local` must keep that check. The other fields are assumed to be
// thread-safe types.
unsafe impl Send for WakeQueue {}
// SAFETY (NOTE(review) — confirm): same invariant as for `Send` above —
// `local` is only accessed on `local_thread`.
unsafe impl Sync for WakeQueue {}
impl WakeQueue {
    /// Creates a queue owned by the calling thread, with room for `capacity`
    /// local wakeups before the local buffer has to grow.
    fn with_capacity(capacity: usize) -> Self {
        let local = UnsafeCell::new(VecDeque::with_capacity(capacity));
        Self {
            base_waker: AtomicWaker::new(),
            local_thread: thread::current().id(),
            local,
            concurrent: ConcurrentQueue::unbounded(),
        }
    }

    /// Returns `true` when called from the thread that created the queue.
    fn on_owner_thread(&self) -> bool {
        thread::current().id() == self.local_thread
    }

    /// Enqueues `val`: into the local buffer when called on the owning
    /// thread, into the concurrent queue otherwise.
    fn push(&self, val: usize) {
        if !self.on_owner_thread() {
            // A failed push is ignored on purpose; the queue is created
            // unbounded, so this is not expected to fail in practice.
            let _ = self.concurrent.push(val);
            return;
        }
        // SAFETY: we are on the owning thread, the only one that accesses
        // `local`, and no reference into the buffer outlives this statement.
        unsafe { (*self.local.get()).push_back(val) };
    }

    /// Calls `f` for every wakeup that was queued when this method started.
    ///
    /// Both queue lengths are sampled up front, so wakeups pushed while `f`
    /// runs are left for the next call. Does nothing when called from a
    /// thread other than the owning one.
    fn drain_for_each<F: FnMut(usize)>(&self, mut f: F) {
        if !self.on_owner_thread() {
            return;
        }
        // SAFETY: owning thread; the borrow ends with the `len` call.
        let local_len = unsafe { (*self.local.get()).len() };
        let con_len = self.concurrent.len();
        log::trace!(
            "{:?} {local_len} local wakeups, {con_len} concurrent wakeups",
            std::thread::current().id()
        );
        let mut remaining = local_len;
        while remaining > 0 {
            // SAFETY: owning thread. The buffer is re-borrowed for each pop
            // and released before `f` runs, so `f` may push to this queue.
            let popped = unsafe { (*self.local.get()).pop_front() };
            f(popped.unwrap());
            remaining -= 1;
        }
        for val in self.concurrent.try_iter().take(con_len) {
            f(val);
        }
    }

    /// Discards every queued wakeup and leaves `init_val` as the only entry.
    ///
    /// Does nothing when called from a thread other than the owning one.
    fn reset(&self, init_val: usize) {
        if !self.on_owner_thread() {
            return;
        }
        // SAFETY: owning thread; nothing else can observe the buffer while
        // these two calls run.
        unsafe {
            (*self.local.get()).clear();
            (*self.local.get()).push_back(init_val);
        }
        // Drop whatever other threads queued so far.
        while self.concurrent.pop().is_ok() {}
    }
}
/// Shared state behind the [`Waker`] of a single task.
///
/// Waking pushes `task_id` onto the executor's wake queue and then wakes the
/// queue's base waker. The `awoken` flag deduplicates wakeups: a task is
/// queued at most once between two calls to [`TaskWaker::to_sleep`].
struct TaskWaker {
    /// Queue that receives `task_id` when the task is woken.
    queue: Arc<WakeQueue>,
    /// `true` while a wakeup for this task is already queued.
    awoken: AtomicBool,
    /// ID pushed onto `queue` on wake.
    task_id: usize,
}

impl Wake for TaskWaker {
    fn wake(self: Arc<Self>) {
        self.wake_by_ref();
    }

    // Overridden because the default implementation clones the `Arc` on
    // every call just to forward to `wake`.
    fn wake_by_ref(self: &Arc<Self>) {
        // `swap` is a read-modify-write even when the flag is already set.
        // Paired with the `AcqRel` swap in `to_sleep`, this guarantees that
        // either this call observes `false` and queues the task, or the poll
        // following the executor's `to_sleep` observes everything written
        // before this wake. With relaxed operations a waker on another thread
        // could see a stale `true` while the poll misses its writes, losing
        // the wakeup.
        if !self.awoken.swap(true, Ordering::AcqRel) {
            self.queue.push(self.task_id);
            self.queue.base_waker.wake();
        }
    }
}

impl TaskWaker {
    /// Creates the waker state for task `task_id`, initially not awoken.
    fn new(queue: Arc<WakeQueue>, task_id: usize) -> Self {
        Self {
            awoken: AtomicBool::new(false),
            queue,
            task_id,
        }
    }

    /// Creates the shared waker state together with a [`Waker`] backed by it.
    fn waker_pair(queue: Arc<WakeQueue>, task_id: usize) -> (Arc<Self>, Waker) {
        let this = Arc::new(Self::new(queue, task_id));
        let waker = this.clone().into();
        (this, waker)
    }

    /// Re-arms the waker so that the next wake queues the task again.
    ///
    /// Must be called right before the task is polled. See `wake_by_ref` for
    /// why this is a swap and not a plain store.
    fn to_sleep(&self) {
        let _ = self.awoken.swap(false, Ordering::AcqRel);
    }
}
/// A future handed to `Executor::spawn` that has not been polled yet.
struct SpawnedTask<'a> {
/// The spawned future, boxed and with its output erased to `()`.
future: LocalBoxFuture<'a, ()>,
/// State shared with the task's `TaskHandle` (cancellation flag and waker).
handle_data: Rc<HandleData>,
}
struct Task<'a> {
future: LocalBoxFuture<'a, ()>,
handle_data: Rc<HandleData>,
waker_pair: (Arc<TaskWaker>, Waker),
}
impl<'a> Task<'a> {
fn poll(&mut self) -> Poll<()> {
let (waker_data, waker) = &self.waker_pair;
waker_data.to_sleep();
self.future.as_mut().poll(&mut Context::from_waker(waker))
}
fn from_spawned(spawned_task: SpawnedTask<'a>, waker_pair: (Arc<TaskWaker>, Waker)) -> Self {
let handle_data = spawned_task.handle_data;
handle_data.waker.set(Some(waker_pair.1.clone()));
Self {
future: spawned_task.future,
handle_data,
waker_pair,
}
}
}
/// Single-threaded executor that drives a main future together with any
/// number of spawned tasks.
pub struct Executor<'a> {
/// Tasks that have been polled at least once and are still pending, keyed by
/// task ID.
tasks: RefCell<Slab<Task<'a>>>,
/// Tasks that were spawned but have not been polled yet.
spawned: RefCell<Vec<SpawnedTask<'a>>>,
/// IDs of the tasks that were woken and need to be polled.
wake_queue: Arc<WakeQueue>,
}
impl Default for Executor<'_> {
fn default() -> Self {
Self::new()
}
}
/// Task ID reserved for the main future passed to `Executor::run`.
///
/// `Executor::poll_spawned` asserts that no spawned task is ever assigned
/// this ID.
const MAIN_TASK_ID: usize = usize::MAX;
impl<'a> Executor<'a> {
    /// Creates an executor with a small default capacity.
    pub fn new() -> Self {
        Self::with_capacity(4)
    }

    /// Creates an executor whose internal collections are pre-allocated for
    /// `capacity` tasks.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            tasks: RefCell::new(Slab::with_capacity(capacity)),
            spawned: RefCell::new(Vec::with_capacity(capacity)),
            wake_queue: Arc::new(WakeQueue::with_capacity(capacity)),
        }
    }

    /// Queues `fut` to run on this executor and returns a handle to it.
    ///
    /// The future is not polled by this call; it is polled for the first time
    /// while the executor is being driven by [`Executor::run`] or
    /// [`Executor::block_on`]. Awaiting the returned [`TaskHandle`] yields the
    /// future's output.
    pub fn spawn<T: 'a>(&self, fut: impl Future<Output = T> + 'a) -> TaskHandle<T> {
        let ret = Rc::new(RetData {
            value: Cell::new(None),
            waker: Cell::new(None),
        });
        let ret_clone = ret.clone();
        let handle_data = Rc::<HandleData>::default();
        let mut spawned = self.spawned.borrow_mut();
        spawned.push(SpawnedTask {
            future: Box::pin(async move {
                let retval = fut.await;
                let ret = ret_clone;
                // Publish the output, then wake whoever awaits the handle.
                ret.value.set(Some(retval));
                if let Some(waker) = ret.waker.take() {
                    waker.wake();
                }
            }),
            handle_data: handle_data.clone(),
        });
        TaskHandle { ret, handle_data }
    }

    /// Spawns the future produced by `f`, handing `f` a clone of the
    /// reference-counted executor so the task can spawn further tasks itself.
    pub fn spawn_rc<T: 'a, Fut: Future<Output = T> + 'a, F>(self: Rc<Self>, f: F) -> TaskHandle<T>
    where
        F: FnOnce(Rc<Self>) -> Fut + 'a,
    {
        let cl = self.clone();
        self.spawn(f(cl))
    }

    /// Registers the waker that is notified whenever a task is woken.
    fn register_base_waker(&self, base_waker: &Waker) {
        self.wake_queue.base_waker.register(base_waker);
    }

    /// Polls every task that has a queued wakeup, removing the tasks that
    /// finished or were cancelled.
    ///
    /// Returns `true` if the main task was among the woken ones.
    fn poll_tasks(&self) -> bool {
        let mut main_task_awoken = false;
        let mut tasks = self.tasks.borrow_mut();
        self.wake_queue.drain_for_each(|task_id| {
            if task_id == MAIN_TASK_ID {
                main_task_awoken = true;
            }
            // IDs without a matching task (already removed) are ignored.
            else if let Some(task) = tasks.get_mut(task_id) {
                if task.handle_data.cancelled.get() || task.poll().is_ready() {
                    tasks.remove(task_id);
                }
            }
        });
        main_task_awoken
    }

    /// Gives every newly spawned task its first poll and stores the ones that
    /// are still pending.
    ///
    /// Tasks cancelled before their first poll are dropped without being
    /// polled. Tasks spawned while this method runs are picked up in the same
    /// call.
    fn poll_spawned(&self) {
        let mut tasks = self.tasks.borrow_mut();
        loop {
            // Pop in a statement of its own: in a `while let` scrutinee the
            // `RefMut` temporary would stay alive for the whole loop body, and
            // a task calling `spawn` during its first poll would then panic
            // on the second mutable borrow of `spawned`.
            let next = self.spawned.borrow_mut().pop();
            let Some(spawned_task) = next else {
                break;
            };
            if spawned_task.handle_data.cancelled.get() {
                continue;
            }
            let next_vacancy = tasks.vacant_entry();
            let task_id = next_vacancy.key();
            assert_ne!(task_id, MAIN_TASK_ID);
            let waker_pair = TaskWaker::waker_pair(self.wake_queue.clone(), task_id);
            let mut task = Task::from_spawned(spawned_task, waker_pair);
            // A task that completes on its first poll is never stored.
            if task.poll().is_pending() {
                next_vacancy.insert(task);
            }
        }
    }

    /// Blocks the current thread until `fut` completes, driving spawned tasks
    /// in the meantime, and returns the output of `fut`.
    pub fn block_on<T>(&self, fut: impl Future<Output = T>) -> T {
        block_on(self.run(fut))
    }

    /// Drives `fut` and all spawned tasks until `fut` completes.
    ///
    /// Once `fut` has produced its output, every task that is still pending
    /// or not yet started is dropped.
    pub async fn run<T>(&self, fut: impl Future<Output = T>) -> T {
        let mut fut = pin!(fut);
        let (main_waker_data, main_waker) =
            TaskWaker::waker_pair(self.wake_queue.clone(), MAIN_TASK_ID);
        // Start from a queue that contains only the main task, so that `fut`
        // gets polled on the first pass.
        self.wake_queue.reset(MAIN_TASK_ID);
        let out = poll_fn(move |cx| {
            // The outer waker may change between polls; register it each time.
            self.register_base_waker(cx.waker());
            let main_task_awoken = self.poll_tasks();
            if main_task_awoken {
                main_waker_data.to_sleep();
                if let Poll::Ready(out) = fut.as_mut().poll(&mut Context::from_waker(&main_waker)) {
                    return Poll::Ready(out);
                }
            }
            self.poll_spawned();
            Poll::Pending
        })
        .await;
        self.tasks.borrow_mut().clear();
        self.spawned.borrow_mut().clear();
        out
    }
}
/// Result slot shared between a spawned task and its [`TaskHandle`].
struct RetData<T> {
    /// Output of the task, present once the task has finished and until the
    /// handle takes it.
    value: Cell<Option<T>>,
    /// Waker of the context that last polled the handle.
    waker: Cell<Option<Waker>>,
}

/// Control state shared between a task and its [`TaskHandle`].
#[derive(Default)]
struct HandleData {
    /// Set once the handle has requested cancellation.
    cancelled: Cell<bool>,
    /// Waker that `TaskHandle::cancel` wakes, if one has been stored.
    waker: Cell<Option<Waker>>,
}

/// Handle to a spawned task.
///
/// Awaiting the handle yields the task's output; the handle can also be used
/// to cancel the task.
pub struct TaskHandle<T> {
    ret: Rc<RetData<T>>,
    handle_data: Rc<HandleData>,
}

impl<T> TaskHandle<T> {
    /// Marks the task as cancelled and wakes the stored task waker, if any.
    pub fn cancel(&self) {
        let shared = &self.handle_data;
        shared.cancelled.set(true);
        let Some(task_waker) = shared.waker.take() else {
            return;
        };
        task_waker.wake();
    }

    /// Returns `true` while the task's output is stored and has not yet been
    /// taken by polling the handle.
    pub fn is_finished(&self) -> bool {
        let slot: *mut Option<T> = self.ret.value.as_ptr();
        // SAFETY: the pointer comes from a live `Cell`, and the shared state
        // sits behind `Rc`, so it cannot be accessed from another thread. The
        // reference created here only lives for the `is_some` call, during
        // which nothing else touches the cell.
        unsafe { (*slot).is_some() }
    }

    /// Returns `true` if [`TaskHandle::cancel`] has been called.
    pub fn is_cancelled(&self) -> bool {
        self.handle_data.cancelled.get()
    }
}

impl<T> Future for TaskHandle<T> {
    type Output = T;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let shared = &self.ret;
        if let Some(val) = shared.value.take() {
            return Poll::Ready(val);
        }
        // Not finished yet: remember the most recent waker, reusing the
        // previously stored one where possible.
        let registered = match shared.waker.take() {
            Some(mut existing) => {
                existing.clone_from(cx.waker());
                existing
            }
            None => cx.waker().clone(),
        };
        shared.waker.set(Some(registered));
        Poll::Pending
    }
}
// Unit tests for the executor, its wake queue and task handles.
#[cfg(test)]
mod tests {
use std::{future::pending, time::Duration};
use crate::{test::MockWaker, time::sleep};
use super::*;
// Pending tasks are stored after their first poll; tasks that complete on
// their first poll are never stored.
#[test]
fn spawn_and_poll() {
let ex = Executor::new();
assert_eq!(ex.tasks.borrow().len(), 0);
ex.spawn(pending::<()>());
ex.spawn(pending::<()>());
ex.spawn(pending::<()>());
ex.poll_tasks();
ex.poll_spawned();
assert_eq!(ex.tasks.borrow().len(), 3);
ex.spawn(async {});
ex.spawn(async {});
ex.poll_tasks();
ex.poll_spawned();
// The two `async {}` tasks finished immediately, so the count is unchanged.
assert_eq!(ex.tasks.borrow().len(), 3);
}
// A task that wakes itself on every poll queues exactly one wakeup per poll
// and notifies the base waker.
#[test]
fn task_waker() {
let base_waker = Arc::new(MockWaker::default());
let mut n = 0;
let ex = Executor::new();
ex.register_base_waker(&base_waker.clone().into());
ex.spawn(poll_fn(|cx| {
n += 1;
cx.waker().wake_by_ref();
Poll::<()>::Pending
}));
ex.poll_spawned();
assert_eq!(unsafe { (*ex.wake_queue.local.get()).len() }, 1);
assert!(base_waker.get());
ex.poll_tasks();
assert_eq!(unsafe { (*ex.wake_queue.local.get()).len() }, 1);
// Dropping the executor ends the closure's borrow of `n`.
drop(ex);
assert_eq!(n, 2);
}
// Cancelling works both before the first poll and once the task is stored.
#[test]
fn cancel() {
let ex = Executor::new();
assert_eq!(ex.tasks.borrow().len(), 0);
let task = ex.spawn(pending::<()>());
task.cancel();
assert!(task.is_cancelled());
ex.poll_tasks();
ex.poll_spawned();
// Cancelled before its first poll: the task is never stored.
assert_eq!(ex.tasks.borrow().len(), 0);
let task = ex.spawn(pending::<()>());
assert!(!task.is_cancelled());
ex.poll_tasks();
ex.poll_spawned();
assert_eq!(ex.tasks.borrow().len(), 1);
task.cancel();
ex.poll_tasks();
ex.poll_spawned();
// Cancelled while stored: the wakeup from `cancel` gets it removed.
assert_eq!(ex.tasks.borrow().len(), 0);
}
// Pushes from the creating thread land in the local buffer, pushes from
// other threads in the concurrent queue; `reset` leaves a single entry.
#[test]
fn wake_queue() {
let queue = WakeQueue::with_capacity(4);
queue.push(12);
queue.push(13);
thread::scope(|s| {
let queue = &queue;
for i in 0..10 {
s.spawn(move || queue.push(i));
}
});
assert_eq!(queue.concurrent.len(), 10);
assert_eq!(unsafe { (*queue.local.get()).len() }, 2);
let mut elems = vec![];
queue.drain_for_each(|e| elems.push(e));
// The order across the two queues is not asserted, so sort first.
elems.sort_unstable();
assert_eq!(elems, &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 12, 13]);
queue.push(12);
queue.push(13);
queue.reset(6);
assert_eq!(queue.concurrent.len(), 0);
assert_eq!(unsafe { (*queue.local.get()).len() }, 1);
queue.drain_for_each(|e| assert_eq!(e, 6));
}
// When `run` is polled with a different waker, later wakeups must reach the
// most recently supplied waker.
#[test]
fn switch_waker() {
let ex = Executor::new();
let waker1 = Arc::new(MockWaker::default());
let waker2 = Arc::new(MockWaker::default());
let mut fut = pin!(ex.run(async {
let _bg = ex.spawn(sleep(Duration::from_millis(100)));
sleep(Duration::from_millis(50)).await;
pending::<()>().await;
}));
assert!(fut
.as_mut()
.poll(&mut Context::from_waker(&waker1.clone().into()))
.is_pending());
REACTOR.with(|r| r.wait()).unwrap();
assert!(waker1.get());
assert!(fut
.as_mut()
.poll(&mut Context::from_waker(&waker2.clone().into()))
.is_pending());
REACTOR.with(|r| r.wait()).unwrap();
assert!(waker2.get());
}
// Same waker-switch check as above, for a future built with `join!`.
#[test]
fn switch_waker_join() {
let waker1 = Arc::new(MockWaker::default());
let waker2 = Arc::new(MockWaker::default());
let mut fut = pin!(join!(
sleep(Duration::from_millis(50)),
sleep(Duration::from_millis(100))
));
assert!(fut
.as_mut()
.poll(&mut Context::from_waker(&waker1.clone().into()))
.is_pending());
REACTOR.with(|r| r.wait()).unwrap();
assert!(waker1.get());
assert!(fut
.as_mut()
.poll(&mut Context::from_waker(&waker2.clone().into()))
.is_pending());
REACTOR.with(|r| r.wait()).unwrap();
assert!(waker2.get());
}
}