#![cfg_attr(not(feature = "std"), no_std)]
#[cfg(all(feature = "heapless", feature = "unbounded"))]
compile_error!("Feature `heapless` is not compatible with feature `unbounded`.");
use core::future::{poll_fn, Future};
use core::marker::PhantomData;
use core::task::{Context, Poll};
extern crate alloc;
use alloc::rc::Rc;
use async_task::Runnable;
pub use async_task::{FallibleTask, Task};
use atomic_waker::AtomicWaker;
use futures_lite::FutureExt;
#[cfg(not(feature = "portable-atomic"))]
use alloc::sync::Arc;
#[cfg(feature = "portable-atomic")]
use portable_atomic_util::Arc;
use once_cell::sync::OnceCell;
#[cfg(feature = "std")]
pub use futures_lite::future::block_on;
/// An async executor whose spawned tasks are kept in a shared run queue.
///
/// `C` is the capacity passed to the bounded queue backends: `crossbeam_queue::ArrayQueue`
/// by default, or `heapless::mpmc::MpMcQueue` with the `heapless` feature. It is unused
/// with the `unbounded` feature, which uses `crossbeam_queue::SegQueue`. `'a` bounds the
/// lifetime of the futures that may be spawned.
pub struct Executor<'a, const C: usize = 64> {
    /// Shared run queue + waker; created lazily on first use so that `new` can be `const`.
    state: OnceCell<Arc<State<C>>>,
    /// Zero-sized marker: `UnsafeCell<&'a ()>` makes the type invariant over `'a`.
    _invariant: PhantomData<core::cell::UnsafeCell<&'a ()>>,
}
impl<'a, const C: usize> Executor<'a, C> {
pub const fn new() -> Self {
Self {
state: OnceCell::new(),
_invariant: PhantomData,
}
}
pub fn spawn<F>(&self, fut: F) -> Task<F::Output>
where
F: Future + Send + 'a,
F::Output: Send + 'a,
{
unsafe { self.spawn_unchecked(fut) }
}
pub fn try_tick(&self) -> bool {
if let Some(runnable) = self.try_runnable() {
runnable.run();
true
} else {
false
}
}
pub async fn tick(&self) {
self.runnable().await.run();
}
pub async fn run<F>(&self, fut: F) -> F::Output
where
F: Future + Send + 'a,
{
unsafe { self.run_unchecked(fut).await }
}
async fn runnable(&self) -> Runnable {
poll_fn(|ctx| self.poll_runnable(ctx)).await
}
fn poll_runnable(&self, ctx: &Context<'_>) -> Poll<Runnable> {
self.state().waker.register(ctx.waker());
if let Some(runnable) = self.try_runnable() {
Poll::Ready(runnable)
} else {
Poll::Pending
}
}
fn try_runnable(&self) -> Option<Runnable> {
let runnable;
#[cfg(not(feature = "heapless"))]
{
runnable = self.state().queue.pop();
}
#[cfg(feature = "heapless")]
{
runnable = self.state().queue.dequeue();
}
runnable
}
unsafe fn spawn_unchecked<F>(&self, fut: F) -> Task<F::Output>
where
F: Future,
{
let schedule = {
let state = self.state().clone();
move |runnable| {
#[cfg(all(not(feature = "heapless"), feature = "unbounded"))]
{
state.queue.push(runnable);
}
#[cfg(all(not(feature = "heapless"), not(feature = "unbounded")))]
{
state.queue.push(runnable).unwrap();
}
#[cfg(feature = "heapless")]
{
state.queue.enqueue(runnable).unwrap();
}
if let Some(waker) = state.waker.take() {
waker.wake();
}
}
};
let (runnable, task) = unsafe { async_task::spawn_unchecked(fut, schedule) };
runnable.schedule();
task
}
async unsafe fn run_unchecked<F>(&self, fut: F) -> F::Output
where
F: Future,
{
let run_forever = async {
loop {
self.tick().await;
}
};
run_forever.or(fut).await
}
fn state(&self) -> &Arc<State<C>> {
self.state.get_or_init(|| Arc::new(State::new()))
}
}
impl<'a, const C: usize> Default for Executor<'a, C> {
fn default() -> Self {
Self::new()
}
}
// SAFETY: Apart from the zero-sized `PhantomData<UnsafeCell<&'a ()>>` marker (present only
// for lifetime invariance, holding no data), the executor's only field is a
// `OnceCell<Arc<State<C>>>`. Within this file, futures reach the queue either through
// `Executor::spawn`, which requires `Send`, or through `LocalExecutor`, which is itself
// `!Send`/`!Sync`.
// NOTE(review): this relies on the selected queue type and `AtomicWaker` being
// `Send + Sync` when holding `Runnable` — re-check if the queue backend changes.
unsafe impl<'a, const C: usize> Send for Executor<'a, C> {}
unsafe impl<'a, const C: usize> Sync for Executor<'a, C> {}
/// A thread-bound executor that can spawn futures which are not `Send`.
///
/// `C` is forwarded to the wrapped `Executor` (run-queue capacity for the bounded
/// backends); `'a` bounds the lifetime of spawned futures.
pub struct LocalExecutor<'a, const C: usize = 64> {
    /// The wrapped executor; only reachable through this `!Send`/`!Sync` wrapper.
    executor: Executor<'a, C>,
    /// Zero-sized marker: `&Rc<()>` is `!Send` and `UnsafeCell` is `!Sync`, so the
    /// wrapper can neither be moved to nor shared with another thread.
    _not_send: PhantomData<core::cell::UnsafeCell<&'a Rc<()>>>,
}
impl<'a, const C: usize> LocalExecutor<'a, C> {
    /// Creates a new, empty thread-bound executor (usable in `const` contexts).
    pub const fn new() -> Self {
        Self {
            executor: Executor::<C>::new(),
            _not_send: PhantomData,
        }
    }

    /// Spawns `fut` and returns a [`Task`] handle for its output.
    ///
    /// Neither the future nor its output has to be `Send`.
    ///
    /// # Panics
    ///
    /// With a bounded run queue, scheduling into a full queue panics.
    pub fn spawn<F>(&self, fut: F) -> Task<F::Output>
    where
        F: Future + 'a,
        F::Output: 'a,
    {
        // SAFETY: `LocalExecutor` is `!Send`/`!Sync` (see `_not_send`), so runnables of this
        // possibly-`!Send` future are only run via `try_tick`/`tick`/`run` on the owning
        // thread, and its borrows are bounded by `'a`.
        // NOTE(review): a runnable still queued when the last `Arc<State>` clone is dropped
        // elsewhere (e.g. by a waker on another thread) would drop `F` there — confirm.
        unsafe { self.executor.spawn_unchecked(fut) }
    }

    /// Runs a single queued runnable if one is available; returns whether one ran.
    pub fn try_tick(&self) -> bool {
        self.executor.try_tick()
    }

    /// Waits until a runnable is queued, then runs it.
    pub async fn tick(&self) {
        self.executor.tick().await
    }

    /// Runs spawned tasks until `fut` completes, returning its output.
    ///
    /// `fut` is polled in place rather than spawned, so it needs no `Send` or `'a` bound.
    pub async fn run<F>(&self, fut: F) -> F::Output
    where
        F: Future,
    {
        // SAFETY: `run_unchecked` only polls `fut` inline and runs queued runnables on the
        // current thread, which owns this `!Send` executor.
        unsafe { self.executor.run_unchecked(fut) }.await
    }
}
impl<'a, const C: usize> Default for LocalExecutor<'a, C> {
fn default() -> Self {
Self::new()
}
}
/// State shared by an executor and the schedule closures of the tasks it spawned.
struct State<const C: usize> {
    /// Unbounded run queue (`unbounded` feature without `heapless`).
    #[cfg(all(not(feature = "heapless"), feature = "unbounded"))]
    queue: crossbeam_queue::SegQueue<Runnable>,
    /// Run queue created with capacity `C` (default configuration).
    #[cfg(all(not(feature = "heapless"), not(feature = "unbounded")))]
    queue: crossbeam_queue::ArrayQueue<Runnable>,
    /// Fixed-capacity run queue sized by `C` (`heapless` feature).
    #[cfg(feature = "heapless")]
    queue: heapless::mpmc::MpMcQueue<Runnable, C>,
    /// Waker registered by the task polling for runnables; woken after each push.
    /// Holds a single waker at a time.
    waker: AtomicWaker,
}
impl<const C: usize> State<C> {
fn new() -> Self {
Self {
#[cfg(all(not(feature = "heapless"), feature = "unbounded"))]
queue: crossbeam_queue::SegQueue::new(),
#[cfg(all(not(feature = "heapless"), not(feature = "unbounded")))]
queue: crossbeam_queue::ArrayQueue::new(C),
#[cfg(feature = "heapless")]
queue: heapless::mpmc::MpMcQueue::new(),
waker: AtomicWaker::new(),
}
}
}