use alloc::collections::VecDeque;
use alloc::rc::Rc;
use core::cell::{Cell, RefCell};
use core::future::Future;
use core::marker::PhantomData;
use core::mem::{forget, ManuallyDrop};
use core::num::NonZeroUsize;
use core::task::{Poll, Waker};
use crate::sync::Arc;
use crate::{Event, EventListenerRc, IntoNotification};
use async_task::{Runnable, Task};
use atomic_waker::AtomicWaker;
use concurrent_queue::ConcurrentQueue;
use futures_lite::prelude::*;
use slab::Slab;
/// An executor for spawned futures that distinguishes the thread it was
/// created on from every other thread when scheduling.
///
/// `'a` bounds the futures accepted by `spawn`; `T` is the source of thread
/// identities consulted by `schedule`.
pub struct Executor<'a, T = DefaultThreadId> {
// Shared state; `spawn` and `schedule` clone this handle into every task.
state: Arc<State<T>>,
// Ties the executor to `'a`. A shared reference to a `Cell` is neither `Send`
// nor `Sync`, so this marker also keeps the executor on a single thread.
_marker: PhantomData<&'a Cell<Rc<()>>>,
}
impl<T> Drop for Executor<'_, T> {
fn drop(&mut self) {
loop {
let mut thread_state = self.state.thread_state.borrow_mut();
let waker = match thread_state.active.drain().next() {
Some(waker) => waker,
None => break,
};
drop(thread_state);
waker.wake();
}
while self.state.task_queue.pop().is_ok()
&& self
.state
.thread_state
.borrow_mut()
.task_queue
.pop_front()
.is_some()
{}
unsafe {
ManuallyDrop::drop(&mut self.state.thread_state.borrow_mut());
}
}
}
/// State shared between the executor, its tickers and the schedule function
/// of every spawned task.
struct State<T> {
// Runnables pushed by `schedule` when it cannot prove it is running on the
// originating thread.
task_queue: ConcurrentQueue<Runnable>,
// Waker registered by the ticker that is currently waiting on `task_queue`.
mainstream_waker: AtomicWaker,
// Source of the current thread's identity.
thread_id: T,
// Identity reported by `thread_id` when the executor was constructed.
origin_thread: Option<NonZeroUsize>,
// Non-thread-safe part of the state; dropped by hand in `Executor::drop`.
thread_state: RefCell<ManuallyDrop<ThreadState>>,
}
// SAFETY: NOTE(review) — `State` contains a `RefCell` and (inside
// `ThreadState`) an `Rc`, neither of which is thread-safe by itself. These
// impls appear to rely on `thread_state` only ever being accessed from the
// originating thread (see the thread-id comparison in `schedule`); confirm
// that every access path upholds this.
unsafe impl<T: Send + Sync> Send for State<T> {}
unsafe impl<T: Send + Sync> Sync for State<T> {}
/// The part of the executor state that lives behind the `RefCell` in `State`.
struct ThreadState {
// Runnables scheduled from the originating thread that no waiting ticker
// took over directly.
task_queue: VecDeque<Runnable>,
// Event used to hand a runnable (`Some`) or a bare wake-up (`None`) to a
// waiting ticker.
thread_waker: Rc<Event<Option<Runnable>>>,
// One waker per spawned task; woken when the executor is dropped.
active: Slab<Waker>,
// Whether some ticker has claimed the job of polling `State::task_queue`.
is_mainstream_listening: bool,
}
impl<'a, T: Default + ThreadId + Send + Sync + 'static> Default for Executor<'a, T> {
    /// Builds an executor around the default value of the thread-id source.
    fn default() -> Self {
        let thread_id = T::default();
        Self::with_thread_id(thread_id)
    }
}
impl<'a> Executor<'a> {
    /// Builds an executor that identifies threads through `DefaultThreadId`.
    pub fn new() -> Self {
        let thread_id = DefaultThreadId::new();
        Self::with_thread_id(thread_id)
    }
}
impl<'a, T: ThreadId + Send + Sync + 'static> Executor<'a, T> {
    /// Builds an executor that uses `thread_id` to recognise its originating
    /// thread.
    ///
    /// The identity `thread_id` reports during this call is recorded as the
    /// origin.
    pub fn with_thread_id(thread_id: T) -> Self {
        let task_queue = ConcurrentQueue::unbounded();
        let mainstream_waker = AtomicWaker::new();
        let origin_thread = thread_id.id();
        let thread_state = RefCell::new(ManuallyDrop::new(ThreadState {
            task_queue: VecDeque::new(),
            thread_waker: Rc::new(Event::new()),
            active: Slab::new(),
            is_mainstream_listening: false,
        }));

        Self {
            state: Arc::new(State {
                task_queue,
                mainstream_waker,
                thread_id,
                origin_thread,
                thread_state,
            }),
            _marker: PhantomData,
        }
    }

    /// Calls `f` with mutable access to the thread-local state.
    ///
    /// The `RefCell` stays borrowed for the whole call, so `f` must not
    /// trigger another borrow of it.
    fn with_thread_local<R>(&self, f: impl FnOnce(&mut ThreadState) -> R) -> R {
        let mut guard = self.state.thread_state.borrow_mut();
        f(&mut **guard)
    }

    /// Returns `true` when neither the thread-local nor the shared queue
    /// holds a runnable.
    pub fn is_empty(&self) -> bool {
        let local_is_empty = self.with_thread_local(|local| local.task_queue.is_empty());
        local_is_empty && self.state.task_queue.is_empty()
    }

    /// Spawns `future` onto the executor and schedules it once.
    ///
    /// The task's waker is recorded in the `active` slab; a drop guard inside
    /// the wrapped future removes that entry again.
    pub fn spawn<O: 'a>(&self, future: impl Future<Output = O> + 'a) -> Task<O> {
        let (runnable, task) = self.with_thread_local(move |local| {
            // Key the waker will receive when it is inserted further down.
            let key = local.active.vacant_key();
            let shared = self.state.clone();
            let wrapped = async move {
                let _guard = CallOnDrop(move || {
                    let mut guarded = shared.thread_state.borrow_mut();
                    drop(guarded.active.try_remove(key));
                });
                future.await
            };
            let pair = unsafe { async_task::spawn_unchecked(wrapped, self.schedule()) };
            local.active.insert(pair.0.waker());
            pair
        });
        // Schedule only after the thread-local borrow has been released.
        runnable.schedule();
        task
    }

    /// Runs at most one queued runnable without waiting.
    ///
    /// The thread-local queue is tried first, then the shared one. Returns
    /// `true` if a runnable was run.
    pub fn try_tick(&self) -> bool {
        let from_local = self.with_thread_local(|local| {
            let next = local.task_queue.pop_front();
            if next.is_some() {
                local.thread_waker.notify(1.tag_with(|| None));
            }
            next
        });

        let next = from_local.or_else(|| {
            let from_shared = self.state.task_queue.pop().ok();
            if from_shared.is_some() {
                self.state.mainstream_waker.wake();
            }
            from_shared
        });

        match next {
            Some(runnable) => {
                runnable.run();
                true
            }
            None => false,
        }
    }

    /// Waits for work through a fresh `Ticker` and runs it.
    pub async fn tick(&self) {
        let mut ticker = Ticker::new(&self.state);
        ticker.tick().await;
    }

    /// Drives the executor until `f` completes and returns its output.
    pub async fn run<O>(&self, f: impl Future<Output = O>) -> O {
        // Never completes on its own; `or` resolves as soon as `f` does.
        let drive = async move {
            let mut ticker = Ticker::new(&self.state);
            loop {
                ticker.tick().await;
            }
        };
        f.or(drive).await
    }

    /// Builds the schedule function handed to `async_task`.
    ///
    /// When both the origin and the current thread report an id and the two
    /// are equal, the runnable is offered to a waiting ticker and otherwise
    /// appended to the thread-local queue. In every other case it goes to the
    /// shared queue and the mainstream waker is woken.
    #[cfg_attr(coverage, no_coverage)]
    fn schedule(&self) -> impl Fn(Runnable) {
        let shared = self.state.clone();
        move |runnable| {
            let on_origin_thread = match (shared.origin_thread, shared.thread_id.id()) {
                (Some(origin), Some(current)) => origin == current,
                _ => false,
            };

            if on_origin_thread {
                let mut local = shared.thread_state.borrow_mut();
                let mut slot = Some(runnable);
                // A listener that accepts the notification takes the
                // runnable out of `slot`.
                local.thread_waker.notify(1.tag_with(|| slot.take()));
                if let Some(unclaimed) = slot {
                    local.task_queue.push_back(unclaimed);
                }
            } else {
                match shared.task_queue.push(runnable) {
                    Ok(()) => shared.mainstream_waker.wake(),
                    // The rejected runnable is deliberately leaked rather
                    // than dropped.
                    Err(e) => forget(e.into_inner()),
                }
            }
        }
    }
}
/// Pulls runnables out of an executor's queues on behalf of `tick` / `run`.
struct Ticker<'a, T> {
// Executor state this ticker takes work from.
state: &'a State<T>,
// Set once this ticker has claimed the shared `State::task_queue`; the claim
// is given back in `Drop`.
is_mainstream: bool,
// Raw-pointer marker: keeps `Ticker` from being `Send` or `Sync`.
_marker: PhantomData<*const ()>,
}
impl<'a, T: ThreadId + Send + Sync + 'static> Ticker<'a, T> {
/// Creates a ticker for `state` that has not claimed the shared queue.
fn new(state: &'a State<T>) -> Self {
Self {
state,
is_mainstream: false,
_marker: PhantomData,
}
}
/// Waits until a runnable is available and runs it.
///
/// After running a runnable obtained by waiting, the listener is polled one
/// more time and a runnable delivered through it is run as well.
async fn tick(&mut self) {
// Listen on the thread-local event before looking at the queues.
let listener = {
let thread_state = self.state.thread_state.borrow_mut();
EventListenerRc::new(thread_state.thread_waker.clone())
};
futures_lite::pin!(listener);
loop {
{
let mut thread_state = self.state.thread_state.borrow_mut();
// Fast path: a runnable is already waiting in the thread-local queue.
if let Some(runnable) = thread_state.task_queue.pop_front() {
// Pass a bare wake-up on to one other listener.
thread_state.thread_waker.notify(1.tag_with(|| None));
// Release the borrow before running task code.
drop(thread_state);
runnable.run();
return;
}
// Claim the shared queue if no other ticker has done so.
if !thread_state.is_mainstream_listening {
thread_state.is_mainstream_listening = true;
self.is_mainstream = true;
}
}
let runnable = {
if self.is_mainstream {
// Poll the shared queue alongside the thread-local listener.
let mainstream_runnable = futures_lite::future::poll_fn(|cx| {
let mut waker_set = false;
loop {
if let Ok(runnable) = self.state.task_queue.pop() {
// The waker registered during this poll is no longer needed.
if waker_set {
self.state.mainstream_waker.take();
}
return Poll::Ready(Some(runnable));
}
// Register the waker, then check the queue once more before
// reporting `Pending`.
if !waker_set {
self.state.mainstream_waker.register(cx.waker());
waker_set = true;
continue;
}
return Poll::Pending;
}
});
(&mut listener).or(mainstream_runnable).await
} else {
(&mut listener).await
}
};
// `None` is a bare wake-up: go around the loop and check the queues again.
if let Some(runnable) = runnable {
runnable.run();
// Run a runnable that was handed to the listener in the meantime.
if let Some(Some(runnable)) = futures_lite::future::poll_once(listener).await {
runnable.run();
}
return;
}
}
}
}
impl<'a, T> Drop for Ticker<'a, T> {
    /// Gives the shared-queue role back if this ticker was holding it.
    fn drop(&mut self) {
        if !self.is_mainstream {
            return;
        }

        // Discard whatever waker is currently registered for the shared queue.
        drop(self.state.mainstream_waker.take());

        let mut local = self.state.thread_state.borrow_mut();
        local.is_mainstream_listening = false;
        // Send a bare wake-up to one waiting listener so that it re-checks
        // the flag.
        local.thread_waker.notify(1.tag_with(|| None));
    }
}
/// A source of identifiers for the calling thread.
///
/// # Safety
///
/// NOTE(review): the contract is not spelled out in this file. `schedule`
/// treats two equal `Some` ids as proof that it is running on the executor's
/// originating thread and then touches non-thread-safe state, so an
/// implementation presumably must never report the same id for two different
/// live threads — confirm.
pub unsafe trait ThreadId {
/// Returns the identifier of the calling thread, or `None` when no
/// identifier is available.
fn id(&self) -> Option<NonZeroUsize>;
}
unsafe impl<T: ThreadId + ?Sized> ThreadId for &T {
    /// Delegates to the referenced value.
    #[cfg_attr(coverage, no_coverage)]
    fn id(&self) -> Option<NonZeroUsize> {
        <T as ThreadId>::id(&**self)
    }
}
unsafe impl<T: ThreadId + ?Sized> ThreadId for &mut T {
    /// Delegates to the referenced value.
    #[cfg_attr(coverage, no_coverage)]
    fn id(&self) -> Option<NonZeroUsize> {
        <T as ThreadId>::id(&**self)
    }
}
unsafe impl<T: ThreadId + ?Sized> ThreadId for alloc::boxed::Box<T> {
    /// Delegates to the boxed value.
    #[cfg_attr(coverage, no_coverage)]
    fn id(&self) -> Option<NonZeroUsize> {
        <T as ThreadId>::id(&**self)
    }
}
unsafe impl<T: ThreadId + ?Sized> ThreadId for Rc<T> {
    /// Delegates to the value behind the `Rc`.
    #[cfg_attr(coverage, no_coverage)]
    fn id(&self) -> Option<NonZeroUsize> {
        <T as ThreadId>::id(&**self)
    }
}
unsafe impl<T: ThreadId + ?Sized> ThreadId for Arc<T> {
    /// Delegates to the value behind the `Arc`.
    #[cfg_attr(coverage, no_coverage)]
    fn id(&self) -> Option<NonZeroUsize> {
        <T as ThreadId>::id(&**self)
    }
}
/// Thread identity source backed by a `std` thread-local.
///
/// Only available with the `std` feature.
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct StdThreadId {
// Private unit field: the struct cannot be built with a literal outside this module.
_private: (),
}
#[cfg(feature = "std")]
impl StdThreadId {
    /// Creates a new `StdThreadId`.
    #[inline(always)]
    pub fn new() -> Self {
        StdThreadId { _private: () }
    }
}
#[cfg(feature = "std")]
unsafe impl ThreadId for StdThreadId {
    /// Identifies the calling thread by the address of one of its
    /// thread-locals.
    ///
    /// Returns `None` when the thread-local cannot be accessed, i.e. when
    /// `try_with` fails.
    fn id(&self) -> Option<NonZeroUsize> {
        std::thread_local! {
            static LOCAL: u8 = 0x03;
        }

        // A reference is never null, so `NonZeroUsize::new` yields the same
        // value the unchecked constructor would, without needing `unsafe`.
        LOCAL
            .try_with(|slot| NonZeroUsize::new(slot as *const u8 as usize))
            .ok()
            .flatten()
    }
}
/// A thread identity source that carries no data.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct NoThreadId {
    // Private unit field: the struct cannot be built with a literal outside
    // this module.
    _private: (),
}

impl NoThreadId {
    /// Creates a new `NoThreadId`.
    #[inline(always)]
    pub fn new() -> Self {
        NoThreadId { _private: () }
    }
}
unsafe impl ThreadId for NoThreadId {
/// Always returns `None`, so `schedule` never takes the same-thread path
/// and every runnable goes through the shared queue.
#[inline(always)]
fn id(&self) -> Option<NonZeroUsize> {
None
}
}
/// Thread identity source used when `Executor` is not given one explicitly.
///
/// Wraps `StdThreadId` when the `std` feature is enabled and `NoThreadId`
/// otherwise.
#[doc(hidden)]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct DefaultThreadId {
// Implementation selected by the `std` feature.
#[cfg(feature = "std")]
inner: StdThreadId,
#[cfg(not(feature = "std"))]
inner: NoThreadId,
}
impl DefaultThreadId {
    /// Creates a new `DefaultThreadId`; equivalent to `Default::default()`.
    #[inline(always)]
    pub fn new() -> Self {
        <Self as Default>::default()
    }
}
unsafe impl ThreadId for DefaultThreadId {
#[inline(always)]
fn id(&self) -> Option<NonZeroUsize> {
self.inner.id()
}
}
/// Guard that invokes the wrapped closure when it is dropped.
struct CallOnDrop<F: FnMut()>(F);

impl<F: FnMut()> Drop for CallOnDrop<F> {
    fn drop(&mut self) {
        let CallOnDrop(callback) = self;
        callback();
    }
}
#[cfg(test)]
mod tests {
use super::*;
use futures_lite::future;
// Two tasks spawned before `run`: both must complete with their outputs,
// leave the executor empty and apply their side effects to `slot`.
#[test]
fn smoke() {
future::block_on(async {
let slot = RefCell::new(0);
let ex = Executor::new();
let t1 = ex.spawn(async {
*slot.borrow_mut() += 1;
17
});
let t2 = ex.spawn(async {
*slot.borrow_mut() += 2;
18
});
ex.run(async {
assert_eq!(t1.await, 17);
assert_eq!(t2.await, 18);
assert!(ex.is_empty());
assert_eq!(*slot.borrow(), 3);
})
.await;
});
}
// Same scenario as `smoke`, but with `NoThreadId`, so every runnable is
// scheduled through the shared queue.
#[test]
fn smoke_no_thread_id() {
future::block_on(async {
let slot = RefCell::new(0);
let ex = Executor::with_thread_id(NoThreadId::new());
let t1 = ex.spawn(async {
*slot.borrow_mut() += 1;
17
});
let t2 = ex.spawn(async {
*slot.borrow_mut() += 2;
18
});
ex.run(async {
assert_eq!(t1.await, 17);
assert_eq!(t2.await, 18);
assert!(ex.is_empty());
assert_eq!(*slot.borrow(), 3);
})
.await;
});
}
// Exercises `try_tick` and `tick` with a task that yields five times while
// its `Task` handle is awaited on another thread.
#[cfg(feature = "std")]
#[test]
fn try_tick() {
use std::thread;
future::block_on(async {
let ex = Executor::new();
// Nothing has been spawned yet, so there is nothing to run.
assert!(!ex.try_tick());
let task = ex.spawn({
let mut polls_left = 5;
async move {
while polls_left > 0 {
future::yield_now().await;
polls_left -= 1;
}
}
});
ex.run(async {
assert!(ex.try_tick());
assert!(ex.try_tick());
// Await the task handle from a different thread.
thread::spawn(move || {
future::block_on(task);
});
assert!(ex.try_tick());
ex.tick().await;
// Drain whatever is still queued.
while !ex.is_empty() {
ex.tick().await;
}
})
.await;
})
}
// `Executor::default()` must work with an explicitly chosen thread-id type.
#[test]
fn default_smoke() {
let _: Executor<'_, NoThreadId> = Executor::default();
}
}