#![doc = include_str!("../README.md")]
#![cfg_attr(not(test), no_std)]
use core::{
sync::atomic::Ordering::Relaxed,
task::{RawWaker, RawWakerVTable},
};
use portable_atomic::AtomicUsize;
/// Per-task waker state. A raw `Waker` built from a `&MyWaker` records
/// readiness by OR-ing `mask` into the shared bitmap behind `map_ptr`.
struct MyWaker {
    // Single bit identifying this task in the readiness bitmap.
    mask: usize,
    // Points at the executor's readiness bitmap; must stay valid for as long
    // as any Waker derived from this struct exists (run_all asserts this via
    // `counter` before returning).
    map_ptr: *const AtomicUsize,
    // Number of live `Waker`s currently pointing at this struct
    // (incremented on create/clone, decremented on drop).
    counter: AtomicUsize,
}
impl MyWaker {
    /// Creates waker state for one task: `mask` is the task's bit and
    /// `map_ptr` the shared readiness bitmap to set when woken.
    fn new(mask: usize, map_ptr: *const AtomicUsize) -> Self {
        Self {
            mask,
            map_ptr,
            counter: AtomicUsize::new(0),
        }
    }
    /// Builds a `Waker` whose data pointer is `self`; the vtable functions
    /// cast it back to `&MyWaker`. Increments `counter`; the matching
    /// decrement happens in `waker_drop` when the Waker is dropped.
    fn to_waker(&self) -> core::task::Waker {
        self.counter.fetch_add(1, Relaxed);
        // SAFETY: `self` must outlive the returned Waker and all of its
        // clones; `run_all` enforces this by asserting `counter == 0`
        // before the backing storage goes out of scope.
        unsafe { core::task::Waker::new(self as *const _ as *const (), &RAW_WAKER_VTABLE) }
    }
}
// Vtable shared by every waker; the data pointer is always a `*const MyWaker`.
const RAW_WAKER_VTABLE: RawWakerVTable =
    RawWakerVTable::new(waker_clone, waker_wake, waker_wake_by_ref, waker_drop);
/// Clone handler: a second `Waker` will now reference this `MyWaker`, so
/// bump the outstanding-waker count before handing out the new `RawWaker`.
unsafe fn waker_clone(me: *const ()) -> RawWaker {
    let state = unsafe { &*me.cast::<MyWaker>() };
    state.counter.fetch_add(1, Relaxed);
    RawWaker::new(me, &RAW_WAKER_VTABLE)
}
/// Wake-by-value handler: `wake()` consumes the `Waker`, so this is
/// wake-by-ref followed by releasing the consumed waker's reference count.
unsafe fn waker_wake(me: *const ()) {
    unsafe { waker_wake_by_ref(me) };
    unsafe { waker_drop(me) }
}
/// Wake-by-ref handler: marks the task ready by setting its bit in the
/// shared readiness bitmap; the executor loop picks it up on its next pass.
unsafe fn waker_wake_by_ref(me: *const ()) {
    let state = unsafe { &*me.cast::<MyWaker>() };
    let ready_map = unsafe { &*state.map_ptr };
    ready_map.fetch_or(state.mask, Relaxed);
}
/// Drop handler: one `Waker` referencing this `MyWaker` is gone; release
/// its reference count so `run_all` can verify none leaked.
unsafe fn waker_drop(me: *const ()) {
    let state = unsafe { &*me.cast::<MyWaker>() };
    state.counter.fetch_sub(1, Relaxed);
}
/// Compile-time guard used by [`run_all`]: referencing `OK` fails the build
/// unless `N` readiness bits fit into a single `usize`.
struct AssertLessThanSizeOfUsize<const N: usize>;
impl<const N: usize> AssertLessThanSizeOfUsize<N> {
    // Const assertion, evaluated only when `OK` is referenced. The message
    // previously claimed "less than size_of::<usize>()", but the condition
    // checks against usize::BITS (bits, not bytes) and is non-strict.
    const OK: () = assert!(
        N <= (usize::BITS as usize),
        "N must not exceed usize::BITS"
    );
}
/// Polls the given futures to completion on the current thread.
///
/// Each task owns one bit in a shared readiness bitmap; its waker sets that
/// bit, and the scheduler polls exactly the tasks whose bits are set.
/// Supports up to `usize::BITS` tasks (enforced at compile time).
///
/// When nothing is ready it waits with `wfe` on Cortex-M (feature
/// `cortex-m`) and spins otherwise.
pub fn run_all<const N: usize>(tasks: [&mut dyn Future<Output = ()>; N]) {
    // Compile-time guard: one readiness bit per task must fit in a usize.
    let () = AssertLessThanSizeOfUsize::<N>::OK;
    // Bitmask with the low N bits set: one "still running" bit per task.
    // Special-case N == usize::BITS, where `(1 << N) - 1` would overflow the
    // shift (panic in debug, wrong mask in release).
    let mut live_tasks = if N == usize::BITS as usize {
        usize::MAX
    } else {
        (1 << N) - 1
    };
    // Every task starts "ready" so each gets an initial poll.
    let ready_map = AtomicUsize::new(live_tasks);
    let mut task_list: heapless::Vec<(&mut dyn Future<Output = ()>, MyWaker), N> =
        heapless::Vec::new();
    for (i, t) in tasks.into_iter().enumerate() {
        // Capacity is exactly N, so this push cannot fail.
        let _ = task_list.push((t, MyWaker::new(1 << i, &ready_map as *const _)));
    }
    while live_tasks != 0 {
        // Claim the current batch of ready bits, discarding bits of
        // already-completed tasks: a stale waker clone firing late must not
        // cause a future to be polled again after it returned `Ready`
        // (that would violate the `Future` contract).
        let mut mask = ready_map.swap(0, Relaxed) & live_tasks;
        if mask == 0 {
            #[cfg(feature = "cortex-m")]
            cortex_m::asm::wfe();
            #[cfg(not(feature = "cortex-m"))]
            core::hint::spin_loop();
            continue;
        }
        while mask != 0 {
            // Isolate the lowest set bit (two's-complement trick), then
            // clear it from the batch.
            let task_mask = mask & (!mask + 1);
            mask ^= task_mask;
            let task_idx = task_mask.trailing_zeros() as usize;
            let task = &mut task_list[task_idx];
            let waker = task.1.to_waker();
            let mut context = core::task::Context::from_waker(&waker);
            // SAFETY: the future lives in `task_list` for the whole loop and
            // is never moved out of it, so pinning it in place is sound.
            let fut = unsafe { core::pin::Pin::new_unchecked(&mut *task.0) };
            if let core::task::Poll::Ready(()) = fut.poll(&mut context) {
                // Retire the finished task.
                live_tasks ^= task_mask
            }
        }
    }
    // Every waker handed out (including clones the tasks stashed away) must
    // be dropped by now, otherwise `MyWaker` pointers would dangle.
    assert!(task_list.iter().all(|v| v.1.counter.load(Relaxed) == 0));
}
#[cfg(test)]
mod test {
    use core::task::{Context, Poll};
    use futures::task::AtomicWaker;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    /// Future that stays `Pending` until `completed` is flipped (from a
    /// background thread); records whether it was ever polled.
    #[derive(Default)]
    struct TestFuture {
        polled: AtomicBool,
        completed: Arc<AtomicBool>,
        waker: Arc<AtomicWaker>,
    }
    impl TestFuture {
        fn new() -> Self {
            Default::default()
        }
        /// Spawns a thread that marks this future complete after `ms`
        /// milliseconds and then wakes the most recently registered waker.
        fn complete_after_millis(&self, ms: u64) {
            let cmplt = self.completed.clone();
            let waker = self.waker.clone();
            std::thread::spawn(move || {
                std::thread::sleep(std::time::Duration::from_millis(ms));
                cmplt.store(true, Ordering::Relaxed);
                waker.wake();
            });
        }
    }
    impl Future for TestFuture {
        type Output = ();
        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            self.polled.store(true, Ordering::Relaxed);
            if self.completed.load(Ordering::Relaxed) {
                Poll::Ready(())
            } else {
                // Re-register on every poll so the completion thread always
                // wakes the current waker, not a stale one.
                self.waker.register(cx.waker());
                Poll::Pending
            }
        }
    }
    /// An already-ready future runs to completion immediately.
    #[test]
    fn one_task() {
        let mut f = std::future::ready(());
        super::run_all([&mut f]);
    }
    /// Several already-ready futures all complete.
    #[test]
    fn multiple_tasks() {
        let mut f1 = std::future::ready(());
        let mut f2 = std::future::ready(());
        super::run_all([&mut f1, &mut f2]);
    }
    /// The executor waits for an externally-completed future's wake.
    #[test]
    fn task_async_completion() {
        let mut f = TestFuture::new();
        f.complete_after_millis(100);
        super::run_all([&mut f]);
        assert!(f.polled.load(Ordering::Relaxed));
    }
    /// N == 0 is valid: `run_all` returns without polling anything.
    #[test]
    fn no_task() {
        super::run_all::<0>([]);
    }
    /// Two futures completing at different times both make progress.
    #[test]
    fn task_concurrent_progress() {
        let mut task1 = TestFuture::new();
        let mut task2 = TestFuture::new();
        task1.complete_after_millis(100);
        task2.complete_after_millis(200);
        super::run_all([&mut task1, &mut task2]);
        assert!(task1.polled.load(Ordering::Relaxed));
        assert!(task2.polled.load(Ordering::Relaxed));
    }
    /// A self-waking future is re-polled exactly once per wake — the
    /// executor must not poll more (busy-loop) or fewer (lost wake) times.
    #[test]
    fn task_polled_exact_count() {
        struct PartialFuture {
            polled_count: AtomicUsize,
            complete_after: usize,
        }
        impl PartialFuture {
            fn new(complete_after: usize) -> Self {
                Self {
                    polled_count: AtomicUsize::new(0),
                    complete_after,
                }
            }
        }
        impl Future for PartialFuture {
            type Output = ();
            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                // Self-wake first so the executor schedules another poll.
                cx.waker().wake_by_ref(); let count = self.polled_count.fetch_add(1, Ordering::Relaxed);
                if count >= self.complete_after {
                    Poll::Ready(())
                } else {
                    Poll::Pending
                }
            }
        }
        let mut task = PartialFuture::new(3);
        super::run_all([&mut task]);
        // Polls 0..=2 return Pending, poll 3 returns Ready: 4 polls total.
        assert_eq!(task.polled_count.load(Ordering::Relaxed), 4);
    }
}