mod run_queue;
#[cfg(feature = "integrated-timers")]
mod timer_queue;
pub(crate) mod util;
mod waker;
use core::cell::Cell;
use core::future::Future;
use core::pin::Pin;
use core::ptr::NonNull;
use core::task::{Context, Poll};
use core::{mem, ptr};
use atomic_polyfill::{AtomicU32, Ordering};
use critical_section::CriticalSection;
#[cfg(feature = "integrated-timers")]
use embassy_time::driver::{self, AlarmHandle};
#[cfg(feature = "integrated-timers")]
use embassy_time::Instant;
#[cfg(feature = "rtos-trace")]
use rtos_trace::trace;
use self::run_queue::{RunQueue, RunQueueItem};
use self::util::UninitCell;
pub use self::waker::task_from_waker;
use super::SpawnToken;
/// Task state bit: the task storage is in use.
///
/// Set (together with `STATE_RUN_QUEUED`) by `TaskStorage::spawn_mark_used` and cleared by
/// `TaskStorage::poll` once the task's future has returned `Poll::Ready`.
pub(crate) const STATE_SPAWNED: u32 = 1 << 0;
/// Task state bit: the task has been marked as queued for polling.
///
/// Set when the task is spawned and by `wake_task`; cleared by `Executor::poll` right before the
/// task is polled.
pub(crate) const STATE_RUN_QUEUED: u32 = 1 << 1;
/// Task state bit reserved for timer-queue bookkeeping.
// NOTE(review): this flag is neither read nor written in this file; presumably it is managed by
// the `timer_queue` module — confirm there.
#[cfg(feature = "integrated-timers")]
pub(crate) const STATE_TIMER_QUEUED: u32 = 1 << 2;
/// Non-generic bookkeeping data shared by every task, independent of the future type.
///
/// `TaskStorage<F>` embeds this as its first field, and the executor and wakers in this file only
/// ever handle tasks as `NonNull<TaskHeader>`.
pub struct TaskHeader {
/// Bit set made of the `STATE_*` flags.
pub(crate) state: AtomicU32,
/// Per-task data owned by the run queue.
// NOTE(review): presumably the intrusive link used by `RunQueue` — confirm in `run_queue`.
pub(crate) run_queue_item: RunQueueItem,
// `executor`: the executor this task was spawned on. Null after `new()`, set by
// `Executor::spawn`, and dereferenced by `wake_task`.
// `poll_fn`: type-erased poll entry point. Written by `TaskStorage::spawn_initialize` and read
// by `Executor::poll`; uninitialised before the first spawn.
pub(crate) executor: Cell<*const Executor>, pub(crate) poll_fn: UninitCell<unsafe fn(NonNull<TaskHeader>)>,
/// Earliest instant at which the task asked to be woken.
///
/// Reset to `Instant::MAX` by `Executor::poll` before each poll and lowered (via `min`) by
/// `TimerQueue::schedule_wake`.
#[cfg(feature = "integrated-timers")]
pub(crate) expires_at: Cell<Instant>,
/// Per-task data owned by the timer queue.
// NOTE(review): presumably the intrusive link used by `timer_queue::TimerQueue` — confirm.
#[cfg(feature = "integrated-timers")]
pub(crate) timer_queue_item: timer_queue::TimerQueueItem,
}
impl TaskHeader {
/// Creates the header of a task that has never been spawned.
///
/// The state word is `0` (no `STATE_*` bit set), the executor pointer is null and `poll_fn` is
/// left uninitialised. With `integrated-timers`, `expires_at` starts at tick 0.
pub(crate) const fn new() -> Self {
Self {
state: AtomicU32::new(0),
run_queue_item: RunQueueItem::new(),
// Filled in by `Executor::spawn`.
executor: Cell::new(ptr::null()),
// Filled in by `TaskStorage::spawn_initialize`.
poll_fn: UninitCell::uninit(),
#[cfg(feature = "integrated-timers")]
expires_at: Cell::new(Instant::from_ticks(0)),
#[cfg(feature = "integrated-timers")]
timer_queue_item: timer_queue::TimerQueueItem::new(),
}
}
}
/// Storage for a single task whose future has type `F`: the raw header plus room for the future.
// `#[repr(C)]` with `raw` as the first field is what makes the pointer casts between
// `TaskStorage<F>` and `TaskHeader` in `spawn_initialize` and `poll` valid.
#[repr(C)]
pub struct TaskStorage<F: Future + 'static> {
// Must remain the first field; see the pointer casts mentioned above.
raw: TaskHeader,
// Uninitialised until `spawn_initialize` writes the future into it; dropped in place by `poll`
// once the future completes.
future: UninitCell<F>, }
impl<F: Future + 'static> TaskStorage<F> {
    /// An unspawned storage as an associated constant, so it can be used as an array-repeat
    /// operand (see `TaskPool::new`).
    const NEW: Self = Self::new();

    /// Creates an empty task storage: a fresh header and no future.
    pub const fn new() -> Self {
        Self {
            raw: TaskHeader::new(),
            future: UninitCell::uninit(),
        }
    }

    /// Tries to claim this storage for a new task.
    ///
    /// On success `future()` is called to build the future in place and a token for the task is
    /// returned. If the storage is already in use, the closure is not called and a failed token
    /// (`SpawnToken::new_failed`) is returned instead.
    pub fn spawn(&'static self, future: impl FnOnce() -> F) -> SpawnToken<impl Sized> {
        if self.spawn_mark_used() {
            // The successful claim above gives us exclusive use of the storage.
            unsafe {
                let task = self.spawn_initialize(future);
                SpawnToken::<F>::new(task)
            }
        } else {
            SpawnToken::<F>::new_failed()
        }
    }

    /// Atomically moves the state word from `0` to "spawned + run-queued".
    ///
    /// Returns `true` if this call claimed the storage, `false` if the state was not `0`.
    fn spawn_mark_used(&'static self) -> bool {
        let claimed = STATE_SPAWNED | STATE_RUN_QUEUED;
        let outcome = self
            .raw
            .state
            .compare_exchange(0, claimed, Ordering::AcqRel, Ordering::Acquire);
        outcome.is_ok()
    }

    /// Stores the poll function and the freshly built future, then returns the task as a
    /// pointer to its header.
    ///
    /// # Safety
    ///
    /// This writes into the storage's uninitialised cells through a shared reference. Both
    /// callers in this file invoke it only right after `spawn_mark_used` returned `true`.
    unsafe fn spawn_initialize(&'static self, future: impl FnOnce() -> F) -> NonNull<TaskHeader> {
        self.raw.poll_fn.write(Self::poll);
        self.future.write(future());

        // `raw` is the first field of this `#[repr(C)]` struct, so the storage's address is also
        // the header's address.
        let header = self as *const Self as *const TaskHeader as *mut TaskHeader;
        NonNull::new_unchecked(header)
    }

    /// Type-erased poll entry point stored in `TaskHeader::poll_fn`.
    ///
    /// # Safety
    ///
    /// `p` must point to the header of a `TaskStorage<F>` whose future has been initialised.
    unsafe fn poll(p: NonNull<TaskHeader>) {
        // Undo the cast performed in `spawn_initialize`.
        let storage = &*p.cast::<Self>().as_ptr();

        let fut = Pin::new_unchecked(storage.future.as_mut());
        let waker = waker::from_task(p);
        let mut cx = Context::from_waker(&waker);

        if let Poll::Ready(_) = fut.poll(&mut cx) {
            // Drop the finished future first, then clear the spawned bit.
            storage.future.drop_in_place();
            storage.raw.state.fetch_and(!STATE_SPAWNED, Ordering::AcqRel);
        }

        // Deliberately skip the waker's destructor.
        mem::forget(waker);
    }
}
// `TaskStorage` holds `Cell`/`UninitCell` fields and is therefore not `Sync` automatically; this
// impl asserts it manually. `spawn` takes `&'static self`, so storages are expected to live in
// statics, which requires `Sync`.
// NOTE(review): there is no `F: Send` bound here; soundness presumably relies on how
// `SpawnToken` restricts which executor a task may be spawned on — confirm.
unsafe impl<F: Future + 'static> Sync for TaskStorage<F> {}
/// A fixed-size pool of `N` task storages for futures of type `F`.
///
/// `spawn` hands out the first storage it manages to claim.
pub struct TaskPool<F: Future + 'static, const N: usize> {
// One slot per task; each slot tracks its own in-use state in its header.
pool: [TaskStorage<F>; N],
}
impl<F: Future + 'static, const N: usize> TaskPool<F, N> {
    /// Creates a pool in which every one of the `N` slots is unspawned.
    pub const fn new() -> Self {
        Self {
            pool: [TaskStorage::NEW; N],
        }
    }

    /// Tries to spawn a task in the first slot that can be claimed.
    ///
    /// Slots are tried in order. `future()` is called only after a slot has been claimed. If no
    /// slot can be claimed, a failed token (`SpawnToken::new_failed`) is returned.
    pub fn spawn(&'static self, future: impl FnOnce() -> F) -> SpawnToken<impl Sized> {
        // `find` stops at the first slot whose claim succeeds, so at most one slot is marked.
        match self.pool.iter().find(|slot| slot.spawn_mark_used()) {
            Some(slot) => unsafe {
                let task = slot.spawn_initialize(future);
                SpawnToken::<F>::new(task)
            },
            None => SpawnToken::<F>::new_failed(),
        }
    }

    /// Same slot search as `spawn`, but the returned token is parameterised by the closure type
    /// `FutFn` instead of the future type `F`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the contract that makes this `unsafe` is not stated in this file;
    /// presumably it concerns the `Send`-ness conveyed by the token's type parameter — confirm
    /// against the macro that calls this.
    #[doc(hidden)]
    pub unsafe fn _spawn_async_fn<FutFn>(&'static self, future: FutFn) -> SpawnToken<impl Sized>
    where
        FutFn: FnOnce() -> F,
    {
        match self.pool.iter().find(|slot| slot.spawn_mark_used()) {
            Some(slot) => SpawnToken::<FutFn>::new(slot.spawn_initialize(future)),
            None => SpawnToken::<FutFn>::new_failed(),
        }
    }
}
/// Raw executor state: a run queue of tasks to poll plus the callback used to signal that the
/// queue needs servicing.
pub struct Executor {
/// Tasks waiting to be polled; drained by `poll`.
run_queue: RunQueue,
/// Invoked as `signal_fn(signal_ctx)` by `enqueue` when `RunQueue::enqueue` returns `true`.
signal_fn: fn(*mut ()),
/// Opaque pointer passed unchanged to `signal_fn`.
signal_ctx: *mut (),
/// Tasks waiting on a timer; expired entries are woken at the start of every `poll` pass.
#[cfg(feature = "integrated-timers")]
pub(crate) timer_queue: timer_queue::TimerQueue,
/// Alarm allocated in `new` and re-armed at the end of every `poll` pass.
#[cfg(feature = "integrated-timers")]
alarm: AlarmHandle,
}
impl Executor {
/// Creates a new executor.
///
/// `signal_fn` and `signal_ctx` are stored; `enqueue` calls `signal_fn(signal_ctx)` whenever
/// `RunQueue::enqueue` returns `true`.
///
/// With the `integrated-timers` feature, an alarm is allocated from the time driver and the
/// same `signal_fn`/`signal_ctx` pair is registered as its callback.
pub fn new(signal_fn: fn(*mut ()), signal_ctx: *mut ()) -> Self {
// NOTE(review): `unwrap!` is defined outside this file; presumably it panics when the driver
// has no alarm left to allocate — confirm.
#[cfg(feature = "integrated-timers")]
let alarm = unsafe { unwrap!(driver::allocate_alarm()) };
#[cfg(feature = "integrated-timers")]
driver::set_alarm_callback(alarm, signal_fn, signal_ctx);
Self {
run_queue: RunQueue::new(),
signal_fn,
signal_ctx,
#[cfg(feature = "integrated-timers")]
timer_queue: timer_queue::TimerQueue::new(),
#[cfg(feature = "integrated-timers")]
alarm,
}
}
/// Puts `task` into the run queue and calls the signal function if `RunQueue::enqueue`
/// returns `true`.
///
/// The `cs` token shows that the caller is inside a critical section.
///
/// # Safety
///
/// NOTE(review): the exact contract is not stated in this file. Both callers set
/// `STATE_RUN_QUEUED` on the task before calling this; presumably `task` must be valid and not
/// already in the queue — confirm against `RunQueue::enqueue`.
#[inline(always)]
unsafe fn enqueue(&self, cs: CriticalSection, task: NonNull<TaskHeader>) {
// The task's address, cast to `u32`, is what gets passed to the tracer.
#[cfg(feature = "rtos-trace")]
trace::task_ready_begin(task.as_ptr() as u32);
// NOTE(review): `true` presumably means the queue was empty before this push, so the
// executor only needs signalling then — confirm in `run_queue`.
if self.run_queue.enqueue(cs, task) {
(self.signal_fn)(self.signal_ctx)
}
}
/// Binds `task` to this executor and queues it for its first poll.
///
/// # Safety
///
/// `task` is dereferenced, so it must point to a valid `TaskHeader`. Nothing here sets
/// `STATE_RUN_QUEUED`; `TaskStorage::spawn_mark_used` sets it when the storage is claimed.
pub(super) unsafe fn spawn(&'static self, task: NonNull<TaskHeader>) {
// Record the owning executor so `wake_task` can find it later.
task.as_ref().executor.set(self);
#[cfg(feature = "rtos-trace")]
trace::task_new(task.as_ptr() as u32);
critical_section::with(|cs| {
self.enqueue(cs, task);
})
}
/// Polls every task currently in the run queue.
///
/// With `integrated-timers`, each pass first wakes expired timer tasks and afterwards re-arms
/// the alarm; the pass is repeated for as long as `driver::set_alarm` returns `false`. Without
/// the feature a single pass is made.
///
/// # Safety
///
/// Task pointers taken from the run queue are dereferenced and their stored `poll_fn` is
/// called, so every queued task must still be valid.
/// NOTE(review): presumably this must also not be called reentrantly on the same executor —
/// confirm; the requirement is not visible in this file.
pub unsafe fn poll(&'static self) {
loop {
// Pass the current time to the timer queue; every task it hands back is woken, which puts
// it into the run queue drained just below.
#[cfg(feature = "integrated-timers")]
self.timer_queue.dequeue_expired(Instant::now(), |task| wake_task(task));
self.run_queue.dequeue_all(|p| {
let task = p.as_ref();
// Reset the deadline before polling; `TimerQueue::schedule_wake` lowers it again if it is
// called for this task.
#[cfg(feature = "integrated-timers")]
task.expires_at.set(Instant::MAX);
// Clear the run-queued bit; `fetch_and` returns the state as it was before.
let state = task.state.fetch_and(!STATE_RUN_QUEUED, Ordering::AcqRel);
// The spawned bit is cleared by `TaskStorage::poll` when the future finishes. A task
// that was queued again before finishing can still show up here; it is skipped.
if state & STATE_SPAWNED == 0 {
return;
}
#[cfg(feature = "rtos-trace")]
trace::task_exec_begin(p.as_ptr() as u32);
// Run the task through the type-erased poll function stored at spawn time.
task.poll_fn.read()(p as _);
#[cfg(feature = "rtos-trace")]
trace::task_exec_end();
// NOTE(review): presumably (re)inserts the task into the timer queue according to the
// `expires_at` it now holds — confirm in `timer_queue`.
#[cfg(feature = "integrated-timers")]
self.timer_queue.update(p);
});
#[cfg(feature = "integrated-timers")]
{
// Arm the alarm for the next expiration. `true` ends the loop; `false` triggers another
// pass.
// NOTE(review): `false` presumably means the timestamp has already passed — confirm
// against the driver docs.
let next_expiration = self.timer_queue.next_expiration();
if driver::set_alarm(self.alarm, next_expiration.as_ticks()) {
break;
}
}
#[cfg(not(feature = "integrated-timers"))]
{
break;
}
}
#[cfg(feature = "rtos-trace")]
trace::system_idle();
}
/// Returns a `Spawner` bound to this executor.
pub fn spawner(&'static self) -> super::Spawner {
super::Spawner::new(self)
}
}
/// Marks `task` as run-queued and hands it to its executor, unless it is already queued or is
/// not spawned, in which case nothing happens.
///
/// The whole check-and-update runs inside a critical section, which is why plain relaxed
/// load/store is used on the state word instead of an atomic read-modify-write.
///
/// # Safety
///
/// `task` must point to a valid `TaskHeader`. If the task is spawned and not yet queued, its
/// `executor` pointer is dereferenced and must therefore be valid as well.
pub unsafe fn wake_task(task: NonNull<TaskHeader>) {
    critical_section::with(|cs| {
        let header = task.as_ref();
        let current = header.state.load(Ordering::Relaxed);

        let already_queued = current & STATE_RUN_QUEUED != 0;
        let spawned = current & STATE_SPAWNED != 0;

        if !already_queued && spawned {
            // Mark first, then enqueue.
            header.state.store(current | STATE_RUN_QUEUED, Ordering::Relaxed);

            let owner = &*header.executor.get();
            owner.enqueue(cs, task);
        }
    })
}
/// Zero-sized type on which `embassy_time::queue::TimerQueue` is implemented for this executor.
///
/// Not to be confused with `timer_queue::TimerQueue`, the per-executor queue stored in
/// `Executor`.
#[cfg(feature = "integrated-timers")]
struct TimerQueue;
#[cfg(feature = "integrated-timers")]
impl embassy_time::queue::TimerQueue for TimerQueue {
    /// Records that the task behind `waker` wants to be woken at `at`.
    ///
    /// Only the task's `expires_at` is updated, keeping the earlier of the current value and
    /// `at`. No queue is touched here.
    fn schedule_wake(&'static self, at: Instant, waker: &core::task::Waker) {
        let header = unsafe { waker::task_from_waker(waker).as_ref() };
        header.expires_at.set(header.expires_at.get().min(at));
    }
}
// Declares the static `TIMER_QUEUE` instance through `embassy_time`'s `timer_queue_impl!` macro.
// NOTE(review): presumably this registers it as the global timer queue that `embassy_time`
// timers call into — confirm against the macro's documentation.
#[cfg(feature = "integrated-timers")]
embassy_time::timer_queue_impl!(static TIMER_QUEUE: TimerQueue = TimerQueue);
/// Callbacks the `rtos-trace` crate uses to query the executor.
#[cfg(feature = "rtos-trace")]
impl rtos_trace::RtosTraceOSCallbacks for Executor {
/// Intentionally empty: no task list is reported to the tracer.
// NOTE(review): presumably because the executor keeps no registry of all existing tasks —
// confirm.
fn task_list() {
}
/// Current time in microseconds, taken from `Instant::now()`.
#[cfg(feature = "integrated-timers")]
fn time() -> u64 {
Instant::now().as_micros()
}
/// Without `integrated-timers` this always reports `0`.
#[cfg(not(feature = "integrated-timers"))]
fn time() -> u64 {
0
}
}
// Registers `Executor` as the provider of the global OS callbacks for `rtos-trace`.
#[cfg(feature = "rtos-trace")]
rtos_trace::global_os_callbacks! {Executor}