#[cfg_attr(target_has_atomic = "ptr", path = "run_queue_atomics.rs")]
#[cfg_attr(not(target_has_atomic = "ptr"), path = "run_queue_critical_section.rs")]
mod run_queue;
#[cfg_attr(all(cortex_m, target_has_atomic = "8"), path = "state_atomics_arm.rs")]
#[cfg_attr(all(not(cortex_m), target_has_atomic = "8"), path = "state_atomics.rs")]
#[cfg_attr(not(target_has_atomic = "8"), path = "state_critical_section.rs")]
mod state;
#[cfg(feature = "integrated-timers")]
mod timer_queue;
pub(crate) mod util;
#[cfg_attr(feature = "turbowakers", path = "waker_turbo.rs")]
mod waker;
use core::future::Future;
use core::marker::PhantomData;
use core::mem;
use core::pin::Pin;
use core::ptr::NonNull;
use core::task::{Context, Poll};
#[cfg(feature = "integrated-timers")]
use embassy_time_driver::{self, AlarmHandle};
#[cfg(feature = "rtos-trace")]
use rtos_trace::trace;
use self::run_queue::{RunQueue, RunQueueItem};
use self::state::State;
use self::util::{SyncUnsafeCell, UninitCell};
pub use self::waker::task_from_waker;
use super::SpawnToken;
/// Type-independent part of a task: the data the executor needs without knowing the
/// concrete future type.
///
/// A [`TaskRef`] points at one of these, and it is embedded as the first field of
/// [`TaskStorage`].
pub(crate) struct TaskHeader {
/// Spawn / run-queue state of the task (see the `state` module).
pub(crate) state: State,
/// Per-task bookkeeping used by the executor's [`RunQueue`].
pub(crate) run_queue_item: RunQueueItem,
/// Executor this task belongs to. Starts as `None`; `SyncExecutor::spawn` stores `Some(self)`.
pub(crate) executor: SyncUnsafeCell<Option<&'static SyncExecutor>>,
/// Type-erased poll entry point. Starts as `None`; `AvailableTask::initialize_impl`
/// stores `TaskStorage::<F>::poll` here.
poll_fn: SyncUnsafeCell<Option<unsafe fn(TaskRef)>>,
/// Earliest requested timer wake-up for this task. `u64::MAX` is written when no
/// deadline is pending; `schedule_wake` lowers it with `min`.
/// NOTE(review): presumably expressed in time-driver ticks — confirm.
#[cfg(feature = "integrated-timers")]
pub(crate) expires_at: SyncUnsafeCell<u64>,
/// Per-task bookkeeping used by the executor's timer queue.
#[cfg(feature = "integrated-timers")]
pub(crate) timer_queue_item: timer_queue::TimerQueueItem,
}
/// A cheap, copyable, type-erased handle to a task, stored as a pointer to its
/// [`TaskHeader`].
#[derive(Clone, Copy)]
pub struct TaskRef {
// Non-null pointer to the task's header; `header()` hands it out as `&'static TaskHeader`.
ptr: NonNull<TaskHeader>,
}
// SAFETY: a `TaskRef` is used like a `&'static TaskHeader` (see `TaskRef::header`), so it is
// `Send` / `Sync` exactly when such a reference is.
unsafe impl Send for TaskRef where &'static TaskHeader: Send {}
unsafe impl Sync for TaskRef where &'static TaskHeader: Sync {}
impl TaskRef {
    /// Builds a handle to the header embedded in `task`.
    fn new<F: Future + 'static>(task: &'static TaskStorage<F>) -> Self {
        // `TaskStorage` is `#[repr(C)]` with the header as its first field, so a pointer
        // to the storage is also a valid pointer to the header.
        let header: NonNull<TaskHeader> = NonNull::from(task).cast();
        Self { ptr: header }
    }

    /// Rebuilds a handle from a raw header pointer.
    ///
    /// # Safety
    ///
    /// `ptr` must be non-null: it is wrapped with `NonNull::new_unchecked`.
    /// NOTE(review): it presumably must also originate from [`TaskRef::as_ptr`] and
    /// point at a live `'static` task — confirm against callers.
    pub(crate) unsafe fn from_ptr(ptr: *const TaskHeader) -> Self {
        let raw = ptr as *mut TaskHeader;
        Self {
            ptr: NonNull::new_unchecked(raw),
        }
    }

    /// Returns the task header this handle points to.
    pub(crate) fn header(self) -> &'static TaskHeader {
        // SAFETY: a `TaskRef` is built either from a `&'static TaskStorage` (`new`) or
        // from a pointer whose validity the caller of `from_ptr` vouches for.
        unsafe { &*self.ptr.as_ptr() }
    }

    /// Returns the handle as a raw pointer to the task header.
    pub(crate) fn as_ptr(self) -> *const TaskHeader {
        let raw: *mut TaskHeader = self.ptr.as_ptr();
        raw
    }
}
/// Storage for a single task running a future of type `F`.
///
/// The layout is `#[repr(C)]` with `raw` first on purpose: `TaskRef::new` casts a pointer
/// to the storage into a pointer to its header, and `TaskStorage::poll` casts it back.
#[repr(C)]
pub struct TaskStorage<F: Future + 'static> {
// Type-independent header; must stay the first field (see the struct docs).
raw: TaskHeader,
// Slot for the future: written by `AvailableTask::initialize_impl`, dropped in place by
// `TaskStorage::poll` once the future completes.
future: UninitCell<F>, }
impl<F: Future + 'static> TaskStorage<F> {
    /// Empty storage as an associated constant, usable as an array-repeat initializer
    /// (see `TaskPool::new`).
    const NEW: Self = Self::new();

    /// Creates empty storage: no executor, no poll function, no future.
    pub const fn new() -> Self {
        Self {
            future: UninitCell::uninit(),
            raw: TaskHeader {
                state: State::new(),
                run_queue_item: RunQueueItem::new(),
                executor: SyncUnsafeCell::new(None),
                poll_fn: SyncUnsafeCell::new(None),
                #[cfg(feature = "integrated-timers")]
                expires_at: SyncUnsafeCell::new(0),
                #[cfg(feature = "integrated-timers")]
                timer_queue_item: timer_queue::TimerQueueItem::new(),
            },
        }
    }

    /// Tries to claim this storage and set it up with the future produced by `future`.
    ///
    /// Returns a failed token (`SpawnToken::new_failed`) when the storage cannot be
    /// claimed; in that case `future` is never called.
    pub fn spawn(&'static self, future: impl FnOnce() -> F) -> SpawnToken<impl Sized> {
        if let Some(claimed) = AvailableTask::claim(self) {
            claimed.initialize(future)
        } else {
            SpawnToken::new_failed()
        }
    }

    /// Type-erased poll entry point; this is what gets stored in `TaskHeader::poll_fn`.
    ///
    /// # Safety
    ///
    /// `p` must point at the header of a `TaskStorage<F>` (same `F`) whose future slot
    /// currently holds an initialized future.
    unsafe fn poll(p: TaskRef) {
        // The header is the first field of the `#[repr(C)]` storage, so the header
        // pointer is also the storage pointer.
        let storage = &*(p.as_ptr() as *const TaskStorage<F>);
        let pinned = Pin::new_unchecked(storage.future.as_mut());
        let waker = waker::from_task(p);
        let mut cx = Context::from_waker(&waker);

        // `if let` (rather than `.is_ready()`) keeps the future's output alive until the
        // teardown below has run, exactly like a `match` on the poll result would.
        if let Poll::Ready(_) = pinned.poll(&mut cx) {
            // The task finished: destroy the future and release the slot.
            storage.future.drop_in_place();
            storage.raw.state.despawn();
            #[cfg(feature = "integrated-timers")]
            storage.raw.expires_at.set(u64::MAX);
        }

        // The waker is deliberately leaked instead of dropped.
        // NOTE(review): presumably dropping a task waker is a no-op — confirm in `waker`.
        mem::forget(waker);
    }

    /// Compile-time check that `TaskStorage<F>` is `Sync`; never meant to be called.
    #[doc(hidden)]
    #[allow(dead_code)]
    fn _assert_sync(self) {
        fn assert_sync<T: Sync>(_: T) {}
        assert_sync(self)
    }
}
/// A [`TaskStorage`] that has been successfully claimed (see [`AvailableTask::claim`]) and
/// is waiting to be initialized with a future.
pub struct AvailableTask<F: Future + 'static> {
// The claimed storage.
task: &'static TaskStorage<F>,
}
impl<F: Future + 'static> AvailableTask<F> {
    /// Tries to claim `task`.
    ///
    /// Returns `Some` when `task.raw.state.spawn()` reports success, `None` otherwise.
    pub fn claim(task: &'static TaskStorage<F>) -> Option<Self> {
        if task.raw.state.spawn() {
            Some(Self { task })
        } else {
            None
        }
    }

    /// Installs the poll function, constructs the future in place and returns the spawn
    /// token. `S` is only the marker type carried by the returned token.
    fn initialize_impl<S>(self, future: impl FnOnce() -> F) -> SpawnToken<S> {
        let storage = self.task;
        unsafe {
            storage.raw.poll_fn.set(Some(TaskStorage::<F>::poll));
            storage.future.write_in_place(future);
            SpawnToken::new(TaskRef::new(storage))
        }
    }

    /// Initializes the claimed storage with the future produced by `future` and returns
    /// the token used to spawn it.
    pub fn initialize(self, future: impl FnOnce() -> F) -> SpawnToken<F> {
        self.initialize_impl::<F>(future)
    }

    /// Like [`AvailableTask::initialize`], but the returned token is typed by `FutFn`
    /// instead of by the future type `F`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the caller contract is not stated in this file. Presumably `FutFn`
    /// must be the type of the closure passed as `future` — confirm before use.
    #[doc(hidden)]
    pub unsafe fn __initialize_async_fn<FutFn>(self, future: impl FnOnce() -> F) -> SpawnToken<FutFn> {
        self.initialize_impl::<FutFn>(future)
    }
}
/// A fixed-size pool of `N` [`TaskStorage`] slots, allowing up to `N` tasks running the
/// same future type `F` to be claimed at once.
pub struct TaskPool<F: Future + 'static, const N: usize> {
// The slots; `spawn_impl` scans them in order and uses the first one it can claim.
pool: [TaskStorage<F>; N],
}
impl<F: Future + 'static, const N: usize> TaskPool<F, N> {
    /// Creates a pool whose `N` slots are all empty.
    pub const fn new() -> Self {
        Self {
            pool: [TaskStorage::NEW; N],
        }
    }

    /// Claims the first available slot, in pool order, and initializes it with `future`.
    ///
    /// Returns a failed token when no slot can be claimed; `future` is not called then.
    /// `T` is only the marker type carried by the returned token.
    fn spawn_impl<T>(&'static self, future: impl FnOnce() -> F) -> SpawnToken<T> {
        for slot in self.pool.iter() {
            if let Some(claimed) = AvailableTask::claim(slot) {
                return claimed.initialize_impl::<T>(future);
            }
        }
        SpawnToken::new_failed()
    }

    /// Tries to spawn a task in the pool, using the future produced by `future`.
    ///
    /// Returns a failed token when every slot is already taken.
    pub fn spawn(&'static self, future: impl FnOnce() -> F) -> SpawnToken<impl Sized> {
        self.spawn_impl::<F>(future)
    }

    /// Like [`TaskPool::spawn`], but the returned token is typed by the closure type
    /// `FutFn` instead of by the future type `F`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the caller contract is not stated in this file — confirm before use.
    #[doc(hidden)]
    pub unsafe fn _spawn_async_fn<FutFn>(&'static self, future: FutFn) -> SpawnToken<impl Sized>
    where
        FutFn: FnOnce() -> F,
    {
        self.spawn_impl::<FutFn>(future)
    }
}
/// Opaque context value that is handed to the externally provided `__pender` function
/// whenever the executor needs to be notified.
#[derive(Clone, Copy)]
pub(crate) struct Pender(*mut ());

// SAFETY: the pointer is never dereferenced here; it is only passed along to `__pender`.
// NOTE(review): whatever `__pender` does with it must itself be thread-safe — confirm.
unsafe impl Send for Pender {}
unsafe impl Sync for Pender {}

impl Pender {
    /// Calls the externally defined `__pender` with the stored context pointer.
    pub(crate) fn pend(self) {
        extern "Rust" {
            fn __pender(context: *mut ());
        }

        let Self(context) = self;
        unsafe { __pender(context) };
    }
}
/// The executor implementation proper. [`Executor`] is a transparent wrapper around it.
pub(crate) struct SyncExecutor {
// Tasks waiting to be polled.
run_queue: RunQueue,
// Notified from `enqueue` (when the run queue asks for it) and from the alarm callback.
pender: Pender,
// Timer bookkeeping: expired entries are dequeued at the start of every poll-loop
// iteration, and each polled task is passed to `update` afterwards.
#[cfg(feature = "integrated-timers")]
pub(crate) timer_queue: timer_queue::TimerQueue,
// Time-driver alarm allocated in `new`; armed with the next expiration at the end of `poll`.
#[cfg(feature = "integrated-timers")]
alarm: AlarmHandle,
}
impl SyncExecutor {
/// Creates an executor that signals pending work through `pender`.
///
/// With the `integrated-timers` feature this also allocates an alarm from the time
/// driver; the `unwrap!` fails if none can be allocated.
pub(crate) fn new(pender: Pender) -> Self {
#[cfg(feature = "integrated-timers")]
let alarm = unsafe { unwrap!(embassy_time_driver::allocate_alarm()) };
Self {
run_queue: RunQueue::new(),
pender,
#[cfg(feature = "integrated-timers")]
timer_queue: timer_queue::TimerQueue::new(),
#[cfg(feature = "integrated-timers")]
alarm,
}
}
/// Pushes `task` onto the run queue and pends the executor if the queue asks for it.
///
/// # Safety
///
/// NOTE(review): the contract is not stated here. Callers in this file only invoke it
/// after `state.run_enqueue()` succeeded or right after spawning — confirm the exact
/// requirement in the `run_queue` module.
#[inline(always)]
unsafe fn enqueue(&self, task: TaskRef) {
#[cfg(feature = "rtos-trace")]
trace::task_ready_begin(task.as_ptr() as u32);
// Only pend when `RunQueue::enqueue` returns `true`.
// NOTE(review): presumably that means the queue was empty before — confirm.
if self.run_queue.enqueue(task) {
self.pender.pend();
}
}
/// Alarm callback registered in `poll`; `ctx` is the executor pointer passed there.
/// It only pends the executor, so that the next `poll` processes expired timers.
#[cfg(feature = "integrated-timers")]
fn alarm_callback(ctx: *mut ()) {
let this: &Self = unsafe { &*(ctx as *const Self) };
this.pender.pend();
}
/// Binds `task` to this executor and enqueues it for its first poll.
///
/// # Safety
///
/// NOTE(review): the contract is not stated here. Presumably `task` must be freshly
/// initialized and not already spawned on another executor — confirm against callers.
pub(super) unsafe fn spawn(&'static self, task: TaskRef) {
task.header().executor.set(Some(self));
#[cfg(feature = "rtos-trace")]
trace::task_new(task.as_ptr() as u32);
self.enqueue(task);
}
/// Polls every task currently in the run queue.
///
/// With `integrated-timers`, expired timers are turned into wake-ups first, and the
/// loop repeats until the alarm for the next expiration could be armed.
///
/// # Safety
///
/// NOTE(review): the contract is not stated here. Presumably it must not be called
/// reentrantly or concurrently on the same executor — confirm.
pub(crate) unsafe fn poll(&'static self) {
// (Re)register the alarm callback with this executor as its context.
#[cfg(feature = "integrated-timers")]
embassy_time_driver::set_alarm_callback(self.alarm, Self::alarm_callback, self as *const _ as *mut ());
// Without `integrated-timers` the loop body runs exactly once, hence the allow.
#[allow(clippy::never_loop)]
loop {
// Move tasks whose timers have expired into the run queue. The no-pend variant is
// used: we are already polling, so there is no need to notify the pender.
#[cfg(feature = "integrated-timers")]
self.timer_queue
.dequeue_expired(embassy_time_driver::now(), wake_task_no_pend);
self.run_queue.dequeue_all(|p| {
let task = p.header();
// Clear the deadline before polling; `schedule_wake` calls made during the poll
// lower it again.
#[cfg(feature = "integrated-timers")]
task.expires_at.set(u64::MAX);
// Skip the task when `run_dequeue` reports failure.
// NOTE(review): presumably the task is no longer spawned at this point — confirm
// in the `state` module.
if !task.state.run_dequeue() {
return;
}
#[cfg(feature = "rtos-trace")]
trace::task_exec_begin(p.as_ptr() as u32);
// Run the task through its type-erased poll function.
// NOTE(review): `unwrap_unchecked` relies on `poll_fn` having been set by
// `AvailableTask::initialize_impl` before the task could be enqueued.
task.poll_fn.get().unwrap_unchecked()(p);
#[cfg(feature = "rtos-trace")]
trace::task_exec_end();
// Let the timer queue pick up the deadline the task may have set while polling.
#[cfg(feature = "integrated-timers")]
self.timer_queue.update(p);
});
#[cfg(feature = "integrated-timers")]
{
// Arm the alarm for the next expiration; if `set_alarm` returns `false`, run
// another iteration instead of stopping.
// NOTE(review): presumably `false` means the timestamp is already in the past.
let next_expiration = self.timer_queue.next_expiration();
if embassy_time_driver::set_alarm(self.alarm, next_expiration) {
break;
}
}
#[cfg(not(feature = "integrated-timers"))]
{
break;
}
}
#[cfg(feature = "rtos-trace")]
trace::system_idle();
}
}
/// Public raw executor: a thin wrapper around [`SyncExecutor`].
///
/// `#[repr(transparent)]` lets `Executor::wrap` reinterpret a `&SyncExecutor` as an
/// `&Executor`.
#[repr(transparent)]
pub struct Executor {
// The wrapped executor implementation.
pub(crate) inner: SyncExecutor,
// Zero-sized marker: a raw pointer is neither `Send` nor `Sync`, so neither is `Executor`.
_not_sync: PhantomData<*mut ()>,
}
impl Executor {
pub(crate) unsafe fn wrap(inner: &SyncExecutor) -> &Self {
mem::transmute(inner)
}
pub fn new(context: *mut ()) -> Self {
Self {
inner: SyncExecutor::new(Pender(context)),
_not_sync: PhantomData,
}
}
pub(super) unsafe fn spawn(&'static self, task: TaskRef) {
self.inner.spawn(task)
}
pub unsafe fn poll(&'static self) {
self.inner.poll()
}
pub fn spawner(&'static self) -> super::Spawner {
super::Spawner::new(self)
}
}
/// Wakes a task by its [`TaskRef`].
///
/// If `state.run_enqueue()` succeeds, the task is pushed onto its executor's run queue
/// and the executor is pended when the queue asks for it. Otherwise nothing happens.
pub fn wake_task(task: TaskRef) {
    let header = task.header();
    if !header.state.run_enqueue() {
        return;
    }

    // NOTE(review): `unwrap_unchecked` relies on `run_enqueue` only succeeding for tasks
    // whose `executor` was set by `SyncExecutor::spawn` — confirm in the `state` module.
    unsafe {
        header.executor.get().unwrap_unchecked().enqueue(task);
    }
}
/// Wakes a task by its [`TaskRef`] without ever pending the executor.
///
/// If `state.run_enqueue()` succeeds, the task is pushed straight onto its executor's
/// run queue and the queue's return value is ignored. `SyncExecutor::poll` uses this for
/// expired timers, where the executor is already running.
pub fn wake_task_no_pend(task: TaskRef) {
    let header = task.header();
    if !header.state.run_enqueue() {
        return;
    }

    // NOTE(review): `unwrap_unchecked` relies on `run_enqueue` only succeeding for tasks
    // whose `executor` was set by `SyncExecutor::spawn` — confirm in the `state` module.
    unsafe {
        let owner = header.executor.get().unwrap_unchecked();
        owner.run_queue.enqueue(task);
    }
}
/// Timer-queue driver that records wake-up deadlines directly in the task headers.
#[cfg(feature = "integrated-timers")]
struct TimerQueue;

#[cfg(feature = "integrated-timers")]
impl embassy_time_queue_driver::TimerQueue for TimerQueue {
    /// Requests that the task behind `waker` be woken at time `at`.
    ///
    /// Only the earliest deadline is kept: the header's `expires_at` becomes the minimum
    /// of its current value and `at`.
    fn schedule_wake(&'static self, at: u64, waker: &core::task::Waker) {
        let header = waker::task_from_waker(waker).header();
        unsafe {
            let current = header.expires_at.get();
            header.expires_at.set(core::cmp::min(current, at));
        }
    }
}

// Registers `TIMER_QUEUE` as the timer-queue implementation.
#[cfg(feature = "integrated-timers")]
embassy_time_queue_driver::timer_queue_impl!(static TIMER_QUEUE: TimerQueue = TimerQueue);
#[cfg(feature = "rtos-trace")]
impl rtos_trace::RtosTraceOSCallbacks for Executor {
    /// Reports nothing: this implementation does not send a task list to the tracer.
    fn task_list() {}

    /// Current time in microseconds, derived from the time driver's tick counter.
    ///
    /// This goes through `embassy_time_driver` directly because no `Instant` type is
    /// imported in this module.
    #[cfg(feature = "integrated-timers")]
    fn time() -> u64 {
        /// Greatest common divisor, evaluated at compile time.
        const fn gcd(mut a: u64, mut b: u64) -> u64 {
            while b != 0 {
                let remainder = a % b;
                a = b;
                b = remainder;
            }
            a
        }

        const MICROS_PER_SECOND: u64 = 1_000_000;
        // Reducing both factors by their GCD keeps the intermediate product small,
        // which delays overflow of the multiplication below.
        const GCD: u64 = gcd(embassy_time_driver::TICK_HZ, MICROS_PER_SECOND);

        embassy_time_driver::now() * (MICROS_PER_SECOND / GCD) / (embassy_time_driver::TICK_HZ / GCD)
    }

    /// Without `integrated-timers` there is no time source here, so the time is always 0.
    #[cfg(not(feature = "integrated-timers"))]
    fn time() -> u64 {
        0
    }
}

// Registers `Executor` as the provider of the global rtos-trace OS callbacks.
#[cfg(feature = "rtos-trace")]
rtos_trace::global_os_callbacks! {Executor}