use core::convert::Infallible;
use core::fmt::Debug;
use core::future::Future;
use core::ops::{Add, Div};
use core::pin::Pin;
use core::sync::atomic::{AtomicUsize, Ordering};
use core::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
use futures::task::AtomicWaker;
use generic_array::{ArrayLength, GenericArray};
use typenum::operator_aliases::{Quot, Sum};
use typenum::{Const, ToUInt, U};
use crate::datastore::generational;
/// Number of bits in a `usize`, as a typenum type-level unsigned integer.
type UsizeBits = U<{ usize::BITS as usize }>;
/// `UsizeBits - 1`. Adding this to a length before dividing by `UsizeBits`
/// turns typenum's truncating division into a division rounded up.
type UsizeBitsMinusOne = typenum::operator_aliases::Sub1<UsizeBits>;
/// Type-level `LEN + (usize::BITS - 1)` for `Const<LEN>`; the first step of
/// the rounded-up division performed by `DivCeilUsizeBits`.
trait AddUsizeBitsMinusOne {
type Output;
}
/// Type-level `ceil(LEN / usize::BITS)` for `Const<LEN>`.
trait DivCeilUsizeBits {
type Output;
}
/// Crate-private bound carried by the public executor types: it exposes how
/// many `usize` words are needed to hold one bit per index in `0..LEN`, as a
/// `generic_array` length.
trait Internal {
type LengthInWords: ArrayLength;
}
// `LEN + (usize::BITS - 1)`, computed on typenum's `U<LEN>` representation.
impl<const LEN: usize> AddUsizeBitsMinusOne for Const<LEN>
where
Const<LEN>: ToUInt<Output: Add<UsizeBitsMinusOne>>,
{
type Output = Sum<U<LEN>, UsizeBitsMinusOne>;
}
// `(LEN + usize::BITS - 1) / usize::BITS`: `LEN` divided by the word size,
// rounded up, because the numerator was pre-biased by `UsizeBitsMinusOne`.
impl<const LEN: usize> DivCeilUsizeBits for Const<LEN>
where
Const<LEN>: AddUsizeBitsMinusOne<Output: Div<UsizeBits>>,
{
type Output = Quot<<Const<LEN> as AddUsizeBitsMinusOne>::Output, UsizeBits>;
}
// Every `LEN` whose rounded-up word count is a usable `ArrayLength` gets an
// `Internal` impl, with that word count as `LengthInWords`.
impl<const LEN: usize> Internal for Const<LEN>
where
Const<LEN>: DivCeilUsizeBits<Output: ArrayLength>,
{
type LengthInWords = <Const<LEN> as DivCeilUsizeBits>::Output;
}
// Shorthand for the number of `usize` words needed to store `LEN` bits.
type LengthInWords<const LEN: usize> = <Const<LEN> as Internal>::LengthInWords;
/// State shared between the executor and its per-future wakers: a bitset with
/// one "needs polling" bit per index in `0..LEN`, plus the waker of the task
/// that runs the executor.
#[derive(Debug)]
struct WakerShared<const LEN: usize>
where
Const<LEN>: Internal,
{
/// Waker registered by `register_current` and woken by every `set`.
waker: AtomicWaker,
/// Bitset of indices that need polling; bit `i` lives in word
/// `i / usize::BITS` at bit position `i % usize::BITS`.
active: GenericArray<AtomicUsize, LengthInWords<LEN>>,
}
/// Splits a bit `index` into the position of the `usize` word that holds it
/// and a single-bit mask selecting that bit inside the word.
///
/// For example, on a 64-bit target index `65` maps to `(1, 0b10)`.
fn get_active_index_and_mask(index: usize) -> (usize, usize) {
    const WORD_BITS: usize = usize::BITS as usize;
    let word = index / WORD_BITS;
    let offset = index % WORD_BITS;
    (word, 1usize << offset)
}
impl<const LEN: usize> WakerShared<LEN>
where
Const<LEN>: Internal,
{
const fn new() -> Self {
let active = {
let mut active = GenericArray::uninit();
let mut index = 0;
let slice = active.as_mut_slice();
while index < slice.len() {
slice[index].write(AtomicUsize::new(usize::MAX));
index += 1;
}
unsafe { GenericArray::assume_init(active) }
};
Self {
waker: AtomicWaker::new(),
active,
}
}
fn reset(&self, index: usize) -> bool {
let (active_word, mask) = self.get_active_ref_and_mask(index);
let previous_value = active_word.fetch_and(!mask, Ordering::Relaxed);
(previous_value & mask) != 0
}
fn reset_all(&self) -> impl Iterator<Item = usize> + use<'_, LEN> {
(0..LEN).filter(|&index| self.reset(index))
}
fn set(&self, index: usize) -> bool {
let (active_word, mask) = self.get_active_ref_and_mask(index);
let previous_value = active_word.fetch_or(mask, Ordering::Relaxed);
self.waker.wake();
(previous_value & mask) != 0
}
async fn register_current(&self) {
core::future::poll_fn(|ctx| {
self.waker.register(ctx.waker());
Poll::Ready(())
})
.await;
}
fn get_active_ref_and_mask(&self, index: usize) -> (&AtomicUsize, usize) {
let (index, mask) = get_active_index_and_mask(index);
(&self.active[index], mask)
}
}
/// A per-future waker: waking it sets bit `index` in `shared`.
///
/// Values created by `invalid()` have `shared: None` and serve only as
/// placeholders while `ExecutorShared::new` fills its `bit_wakers` array.
#[derive(Debug)]
struct BitWaker<const LEN: usize>
where
Const<LEN>: Internal,
{
/// Bit index this waker sets; `usize::MAX` for placeholders.
index: usize,
/// Bitset to update when woken; `None` for placeholders.
shared: Option<&'static WakerShared<LEN>>,
}
impl<const LEN: usize> BitWaker<LEN>
where
    Const<LEN>: Internal,
{
    /// Vtable used by every `Waker` built from a `BitWaker`.
    ///
    /// The data pointer is always a `*const BitWaker<LEN>` derived from a
    /// `&'static BitWaker<LEN>` (see `as_waker`). Nothing is owned or
    /// reference-counted, so:
    /// - cloning copies the pointer;
    /// - `wake` behaves like `wake_by_ref`;
    /// - dropping does nothing.
    const VTABLE: &RawWakerVTable = &RawWakerVTable::new(
        // clone: the pointee is 'static, so the same pointer stays valid.
        |ptr| RawWaker::new(ptr, Self::VTABLE),
        // wake: there is nothing to release, so consuming == borrowing.
        |ptr| {
            // SAFETY: every RawWaker using this vtable was produced by
            // `as_waker` (or cloned from one) from a `&'static Self`, so `ptr`
            // points to a live `Self` for the rest of the program.
            unsafe { &*ptr.cast::<Self>() }.wake_by_ref()
        },
        // wake_by_ref
        |ptr| {
            // SAFETY: same provenance argument as for `wake` above.
            unsafe { &*ptr.cast::<Self>() }.wake_by_ref()
        },
        // drop: nothing to release.
        |_| {},
    );

    /// Placeholder used to initialize arrays in const context.
    ///
    /// It must be replaced before use; waking it panics.
    const fn invalid() -> Self {
        Self {
            index: usize::MAX,
            shared: None,
        }
    }

    /// Creates a waker that sets bit `index` in `shared` when woken.
    ///
    /// # Panics
    ///
    /// Panics if `index >= LEN`.
    const fn new(index: usize, shared: &'static WakerShared<LEN>) -> Self {
        assert!(index < LEN, "Future index out of bounds.");
        Self {
            index,
            shared: Some(shared),
        }
    }

    /// Marks this waker's future as needing a poll and wakes the executor task.
    ///
    /// # Panics
    ///
    /// Panics if called on a placeholder created by `invalid()`.
    fn wake_by_ref(&self) {
        let shared = self
            .shared
            .expect("woke a placeholder BitWaker created by BitWaker::invalid()");
        shared.set(self.index);
    }

    /// Wraps this `'static` bit waker in a standard `Waker` without allocating.
    fn as_waker(&'static self) -> Waker {
        let pointer = (&raw const *self).cast();
        // SAFETY: `pointer` refers to a `'static` `BitWaker<LEN>`, which is
        // exactly what every `VTABLE` entry expects, and no entry frees it.
        unsafe { Waker::new(pointer, Self::VTABLE) }
    }
}
/// Statically allocated state shared between an [`Executor`] and the wakers
/// it hands to its futures.
///
/// Intended to be created as a self-referential static, e.g.
/// `static SHARED: ExecutorShared<N> = ExecutorShared::new(&SHARED);`.
/// `ExecutorShared::new` points every per-future waker at the `shared` field
/// of the reference it is given.
#[derive(Debug)]
#[expect(private_bounds)]
pub struct ExecutorShared<const LEN: usize>
where
Const<LEN>: Internal,
{
/// Wake-up bitset plus the waker of the task running the executor.
shared: WakerShared<LEN>,
/// One waker per future; entry `i` sets bit `i` of the `WakerShared` that
/// was passed (via `&self.shared`) to `ExecutorShared::new`.
bit_wakers: [BitWaker<LEN>; LEN],
}
#[expect(private_bounds)]
impl<const LEN: usize> ExecutorShared<LEN>
where
    Const<LEN>: Internal,
{
    /// Builds the shared executor state.
    ///
    /// `self` should be a reference to the very `static` being initialized:
    /// `static S: ExecutorShared<N> = ExecutorShared::new(&S);`.
    /// Every per-future waker stores a pointer to `self.shared`, while the
    /// returned value gets its own freshly created wake-up state. Passing any
    /// other instance therefore wires the wakers to that other instance.
    pub const fn new(&'static self) -> Self {
        let target = &self.shared;
        let mut wakers = [const { BitWaker::invalid() }; LEN];
        let mut slot = 0;
        loop {
            if slot == LEN {
                break;
            }
            wakers[slot] = BitWaker::new(slot, target);
            slot += 1;
        }
        Self {
            shared: WakerShared::new(),
            bit_wakers: wakers,
        }
    }
}
/// A minimal executor that drives a fixed array of `LEN` never-completing
/// futures (`Output = Infallible`).
///
/// Only futures whose waker has fired since their last poll are re-polled.
#[expect(private_bounds)]
pub struct Executor<'a, const LEN: usize>
where
Const<LEN>: Internal,
{
/// Generation source, incremented after each polling pass.
source: Pin<&'a generational::Source>,
/// Static wake-up state; see `ExecutorShared`.
shared: &'static ExecutorShared<LEN>,
/// The futures being driven; future `i` is tracked by wake bit `i`.
futures: [Pin<&'a mut (dyn Future<Output = Infallible> + 'a)>; LEN],
}
impl<const LEN: usize> core::fmt::Debug for Executor<'_, LEN>
where
    Const<LEN>: Internal,
{
    /// Shows `source` and `shared`. The futures are type-erased trait objects,
    /// so they are rendered as the placeholder string `"<opaque>"`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self {
            source,
            shared,
            futures: _,
        } = self;
        let mut out = f.debug_struct("Executor");
        out.field("source", source);
        out.field("shared", shared);
        out.field("futures", &"<opaque>");
        out.finish()
    }
}
#[expect(private_bounds)]
impl<'a, const LEN: usize> Executor<'a, LEN>
where
    Const<LEN>: Internal,
{
    /// Creates an executor that drives `futures`, tracking wake-ups in `shared`.
    ///
    /// `source`'s generation is incremented after every polling pass.
    ///
    /// # Panics
    ///
    /// Panics if `shared`'s per-future wakers do not point at `shared`'s own
    /// wake-up state. That happens when `shared` was not created as
    /// `static S: ExecutorShared<N> = ExecutorShared::new(&S);`. Without this
    /// check, wake-ups would set bits this executor never reads, and each
    /// future would silently stop being polled after its first pass.
    pub fn new(
        shared: &'static ExecutorShared<LEN>,
        source: Pin<&'a generational::Source>,
        futures: [Pin<&'a mut (dyn Future<Output = Infallible> + 'a)>; LEN],
    ) -> Self {
        let wired_to_self = shared.bit_wakers.iter().all(|bit_waker| {
            bit_waker
                .shared
                .is_some_and(|target| core::ptr::eq(target, &shared.shared))
        });
        assert!(
            wired_to_self,
            "ExecutorShared must be created as `static S: ExecutorShared<N> = ExecutorShared::new(&S);`"
        );
        Self {
            source,
            shared,
            futures,
        }
    }

    /// Polls every future whose wake bit is set, then increments `source`'s
    /// generation.
    ///
    /// Each bit is cleared as the index is yielded, before that future is
    /// polled, so a wake-up during the poll re-sets the bit for a later pass.
    /// Returns `true` if at least one future was polled.
    pub(crate) fn run_once(&mut self) -> bool {
        let mut polled = false;
        for index in self.shared.shared.reset_all() {
            let future = &mut self.futures[index];
            let waker = self.shared.bit_wakers[index].as_waker();
            let mut context = Context::from_waker(&waker);
            // `Output = Infallible`, so `Poll::Ready` is uninhabited and
            // `Pending` is the only possible outcome.
            match future.as_mut().poll(&mut context) {
                Poll::Pending => {}
            }
            polled = true;
        }
        self.source.increment_generation();
        polled
    }

    /// Drives all futures forever.
    ///
    /// Each iteration:
    /// 1. registers the current task's waker with the shared state;
    /// 2. runs one `run_once` pass;
    /// 3. returns `Pending` once without waking itself.
    ///
    /// The task is resumed only when a per-future waker calls `set`, which
    /// wakes the registered waker. Registration happens before the pass, so a
    /// wake-up that arrives during the pass is not lost.
    pub async fn run(mut self) -> ! {
        loop {
            self.shared.shared.register_current().await;
            self.run_once();
            // Suspend exactly once; resumption relies on the waker registered above.
            let mut yielded = false;
            core::future::poll_fn(|_| {
                if yielded {
                    Poll::Ready(())
                } else {
                    yielded = true;
                    Poll::Pending
                }
            })
            .await;
        }
    }
}
// Unit tests for the bit math, the wake-up bitset, the per-future wakers, and
// an end-to-end run of the executor.
#[cfg(test)]
mod tests {
use core::pin::pin;
use core::task::Poll;
use std::vec::Vec;
use super::{BitWaker, Executor, ExecutorShared, WakerShared, get_active_index_and_mask};
use crate::datastore::generational;
// A length that fills exactly two `usize` words of the bitset.
const TWO_WORDS: usize = usize::BITS as usize * 2;
// Word/mask splitting on both sides of the first word boundary.
#[test]
fn calculate_indices() {
assert_eq!(get_active_index_and_mask(0), (0, 1 << 0));
assert_eq!(get_active_index_and_mask(1), (0, 1 << 1));
assert_eq!(
get_active_index_and_mask(usize::BITS as usize - 1),
(0, 1 << (usize::BITS as usize - 1))
);
assert_eq!(get_active_index_and_mask(usize::BITS as usize), (1, 1 << 0));
assert_eq!(
get_active_index_and_mask(usize::BITS as usize + 1),
(1, 1 << 1)
);
}
// A fresh `WakerShared` reports exactly the indices `0..LEN` as pending,
// for lengths 0 and just below, at, and above one word.
#[test]
fn waker_shared_initializes_as_all_awake() {
assert_eq!(
Vec::from_iter(WakerShared::<0>::new().reset_all()),
Vec::<usize>::new()
);
assert_eq!(
Vec::from_iter(WakerShared::<1>::new().reset_all()),
Vec::from_iter(0..1)
);
assert_eq!(
Vec::from_iter(WakerShared::<{ usize::BITS as usize - 1 }>::new().reset_all()),
Vec::from_iter(0..usize::BITS as usize - 1)
);
assert_eq!(
Vec::from_iter(WakerShared::<{ usize::BITS as usize }>::new().reset_all()),
Vec::from_iter(0..usize::BITS as usize)
);
assert_eq!(
Vec::from_iter(WakerShared::<{ usize::BITS as usize + 1 }>::new().reset_all()),
Vec::from_iter(0..usize::BITS as usize + 1)
);
}
// `BitWaker::new` accepts (and can wake) every index below `LEN`, and
// panics for `LEN` itself (the loop leaves `i == TWO_WORDS`).
#[test]
fn bitwaker_valid_indexes() {
static SHARED: WakerShared<TWO_WORDS> = WakerShared::new();
let mut i = 0;
while i < TWO_WORDS {
BitWaker::new(i, &SHARED).wake_by_ref();
i += 1;
}
assert!(std::panic::catch_unwind(|| BitWaker::new(i, &SHARED)).is_err());
}
// Exercises constructors and the `Debug` impl that the other tests do not reach.
#[test]
fn extra_code_coverage() {
static SHARED: ExecutorShared<1> = ExecutorShared::new(&SHARED);
let _ = ExecutorShared::new(&SHARED);
let source = pin!(generational::Source::new());
let futures = [pin!(async move { core::future::pending().await }) as _];
let executor = Executor::new(&SHARED, source.as_ref(), futures);
let _ = std::format!("{executor:?}");
let _ = BitWaker::<1>::invalid();
}
// End-to-end check. A future wakes itself (once by reference, once through a
// cloned waker) and returns `Pending`. `Executor::run` must poll it again,
// after which it signals the test thread. The executor thread is never
// joined; it stays blocked on the future that never completes. Not run
// under Miri (presumably because of the real thread and timeout — TODO confirm).
#[cfg(not(miri))] #[test]
fn executor() {
let (tx, rx) = std::sync::mpsc::channel();
std::thread::spawn({
move || {
let source = pin!(generational::Source::new());
static SHARED: ExecutorShared<1> = ExecutorShared::new(&SHARED);
let futures = [pin!(async move {
let mut yielded = false;
core::future::poll_fn(|cx| {
if yielded {
Poll::Ready(())
} else {
yielded = true;
cx.waker().wake_by_ref();
#[expect(clippy::waker_clone_wake)]
cx.waker().clone().wake();
Poll::Pending
}
})
.await;
let _ = tx.send(());
std::future::pending().await
}) as _];
let executor = Executor::new(&SHARED, source.as_ref(), futures);
futures::executor::block_on(executor.run());
}
});
assert!(rx.recv_timeout(std::time::Duration::from_secs(1)).is_ok());
}
}