use core::future::Future;
use core::mem::MaybeUninit;
use core::ptr::addr_of_mut;
use core::sync::atomic::{AtomicU8, AtomicU16, AtomicU32, Ordering};
use core::task::Poll;
use atomic_waker::AtomicWaker;
/// Strategy for matching incoming replies to the requests that produced them.
///
/// A router hands out slot handles ([`ReplyRouter::try_acquire`] /
/// [`ReplyRouter::acquire`]). It encodes a slot into a request header with
/// [`ReplyRouter::write_header`] and recovers the slot id from a reply with
/// [`ReplyRouter::parse_header`]. It routes the reply payload with
/// [`ReplyRouter::deliver`].
pub trait ReplyRouter {
    /// Error that [`ReplyRouter::acquire`] may resolve to.
    type Error;

    /// Handle to one acquired slot; it may borrow from the router.
    type SlotHandle<'a>: RouterSlotHandle
    where
        Self: 'a;

    /// Number of header bytes that `write_header` writes and `parse_header` reads.
    const HEADER_LEN: usize;

    /// Returns a slot immediately, or `None` if none is available right now.
    fn try_acquire(&self) -> Option<Self::SlotHandle<'_>>;

    /// Resolves to a slot once one is available (or to `Self::Error`).
    fn acquire(&self) -> impl Future<Output = Result<Self::SlotHandle<'_>, Self::Error>>;

    /// Encodes `slot` into the header area at the start of `buf`.
    fn write_header(slot: &Self::SlotHandle<'_>, buf: &mut [u8]);

    /// Decodes a slot id from the header area at the start of `buf`.
    fn parse_header(buf: &[u8]) -> u8;

    /// Routes a reply `payload` to slot `slot_id`; implementations may discard it.
    fn deliver(&self, slot_id: u8, payload: &[u8]);
}

/// One slot acquired from a [`ReplyRouter`].
pub trait RouterSlotHandle {
    /// Resolves to the reply bytes received for this slot.
    fn recv_reply(&self) -> impl Future<Output = &[u8]>;

    /// Identifier of this slot (the value `write_header` encodes).
    fn slot_id(&self) -> u8;
}
#[derive(Debug, Default, Clone, Copy)]
pub struct Sequential;
impl ReplyRouter for Sequential {
const HEADER_LEN: usize = 0;
type SlotHandle<'a> = SequentialSlot;
type Error = core::convert::Infallible;
async fn acquire(&self) -> Result<SequentialSlot, core::convert::Infallible> {
Ok(SequentialSlot)
}
fn try_acquire(&self) -> Option<SequentialSlot> {
Some(SequentialSlot)
}
fn write_header(_slot: &SequentialSlot, _buf: &mut [u8]) {
}
fn parse_header(_buf: &[u8]) -> u8 {
0
}
fn deliver(&self, _slot_id: u8, _payload: &[u8]) {
}
}
pub struct SequentialSlot;
impl RouterSlotHandle for SequentialSlot {
fn slot_id(&self) -> u8 {
0
}
async fn recv_reply(&self) -> &[u8] {
&[]
}
}
/// Storage and signalling state for one multiplexed reply slot.
struct SlotCell<const BUF: usize> {
    /// Reply bytes; only the first `len` bytes are meaningful.
    data: core::cell::UnsafeCell<[u8; BUF]>,
    /// Length of the published reply; `0` means "no reply available".
    len: AtomicU16,
    /// Waker of the task polling `recv_reply` for this slot; woken by `deliver`.
    waker: AtomicWaker,
    /// Bumped each time the slot is freed, so `deliver` can detect reuse.
    generation: AtomicU8,
}

impl<const BUF: usize> SlotCell<BUF> {
    /// Returns an empty slot: zeroed buffer, no reply, generation 0.
    fn new() -> Self {
        Self {
            generation: AtomicU8::new(0),
            waker: AtomicWaker::new(),
            len: AtomicU16::new(0),
            data: core::cell::UnsafeCell::new([0u8; BUF]),
        }
    }

    /// Initialises the slot at `ptr` to the same state as [`SlotCell::new`]
    /// without building it on the stack first.
    ///
    /// # Safety
    ///
    /// `ptr` must be non-null, properly aligned, and valid for writes of one
    /// `SlotCell<BUF>`. Any previous contents are overwritten without being dropped.
    unsafe fn init_in_place(ptr: *mut SlotCell<BUF>) {
        // SAFETY: the caller guarantees `ptr` is valid for writes. Each write
        // targets a distinct field through a raw-pointer projection, so no
        // reference to uninitialised memory is ever created. An all-zero
        // `UnsafeCell<[u8; BUF]>` is a valid value.
        unsafe {
            core::ptr::write_bytes(addr_of_mut!((*ptr).data), 0, 1);
            addr_of_mut!((*ptr).len).write(AtomicU16::new(0));
            addr_of_mut!((*ptr).generation).write(AtomicU8::new(0));
            addr_of_mut!((*ptr).waker).write(AtomicWaker::new());
        }
    }
}
/// Fixed-capacity reply router that multiplexes up to `N` (at most 32)
/// outstanding requests, each owning a `BUF`-byte reply buffer.
pub struct MuxedSlots<const N: usize, const BUF: usize> {
    /// Allocation bitmap: bit `i` is set while slot `i` is held by a guard.
    bitmap: AtomicU32,
    /// Per-slot reply storage.
    slots: [SlotCell<BUF>; N],
    /// Woken when a slot is freed so a task waiting in `acquire` can retry.
    alloc_waker: AtomicWaker,
}

// `Send` is not implemented manually: every field (the atomics, `AtomicWaker`,
// and `UnsafeCell<[u8; BUF]>`) is already `Send`, so the compiler derives it.
//
// SAFETY: `UnsafeCell` makes this type `!Sync` by default. Shared access to a
// slot's `data` is coordinated through the atomics. `deliver` writes the buffer
// before publishing `len` with `Release`. Readers (`recv_reply`, `try_recv_slot`)
// only form a `&[u8]` after observing a non-zero `len` with `Acquire`.
// NOTE(review): the checks in `deliver` do not by themselves exclude two
// concurrent deliveries to the same slot, or a write into a buffer still
// borrowed through an old `try_recv_slot` slice. Soundness assumes callers
// serialise deliveries per slot — confirm.
unsafe impl<const N: usize, const BUF: usize> Sync for MuxedSlots<N, BUF> {}
impl<const N: usize, const BUF: usize> MuxedSlots<N, BUF> {
    /// Compile-time guard: the allocation bitmap is a `u32`, so `N <= 32`.
    const _ASSERT_N_LE_32: () = assert!(N <= 32, "MuxedSlots: N must be <= 32");

    /// Creates a router with all `N` slots free.
    ///
    /// The whole struct (including `N * BUF` bytes of buffers) is built by
    /// value and may be materialised on the stack. For large routers prefer
    /// [`MuxedSlots::new_boxed`].
    pub fn new() -> Self {
        let () = Self::_ASSERT_N_LE_32;
        Self {
            bitmap: AtomicU32::new(0),
            slots: core::array::from_fn(|_| SlotCell::new()),
            alloc_waker: AtomicWaker::new(),
        }
    }

    /// Creates a router with all slots free, initialised directly in a heap
    /// allocation so that large `N * BUF` never passes through the stack.
    pub fn new_boxed() -> Box<Self> {
        let () = Self::_ASSERT_N_LE_32;
        let mut uninit: Box<MaybeUninit<Self>> = Box::new_uninit();
        let ptr: *mut Self = uninit.as_mut_ptr();
        // SAFETY: `ptr` points to a live, properly aligned heap allocation
        // sized for `Self`. Each field is written exactly once via raw-pointer
        // projections (no references to uninitialised memory are formed):
        // `bitmap`, `alloc_waker`, and every one of the `N` slots.
        unsafe {
            addr_of_mut!((*ptr).bitmap).write(AtomicU32::new(0));
            addr_of_mut!((*ptr).alloc_waker).write(AtomicWaker::new());
            let slots_ptr = addr_of_mut!((*ptr).slots).cast::<SlotCell<BUF>>();
            for i in 0..N {
                SlotCell::<BUF>::init_in_place(slots_ptr.add(i));
            }
        }
        // SAFETY: every field of `Self` was initialised above.
        unsafe { uninit.assume_init() }
    }

    /// Bitmask with the low `N` bits set (the slot indices that exist).
    const fn valid_mask() -> u32 {
        if N == 32 { u32::MAX } else { (1u32 << N) - 1 }
    }

    /// Lock-free claim of the lowest-numbered free slot.
    ///
    /// Returns the slot index together with the generation observed right
    /// after claiming it, or `None` when every slot is taken.
    fn try_alloc(&self) -> Option<(u8, u8)> {
        loop {
            let current = self.bitmap.load(Ordering::Acquire);
            let free = (!current) & Self::valid_mask();
            if free == 0 {
                return None;
            }
            // `free != 0`, so `trailing_zeros() < 32` and the cast is lossless.
            let idx = free.trailing_zeros() as u8;
            let claimed = current | (1u32 << idx);
            if self
                .bitmap
                .compare_exchange_weak(current, claimed, Ordering::AcqRel, Ordering::Acquire)
                .is_ok()
            {
                // The Acquire half pairs with the Release `fetch_and` in
                // `free_slot`, so the previous owner's generation bump is visible.
                let gen_val = self.slots[usize::from(idx)].generation.load(Ordering::Acquire);
                return Some((idx, gen_val));
            }
            // Lost a race (or a spurious weak-CAS failure): retry with a fresh snapshot.
        }
    }

    /// Releases slot `idx` and wakes a task waiting in `acquire`.
    ///
    /// The slot is reset (no reply, generation bumped, stale receiver waker
    /// dropped) *before* its bitmap bit is cleared, so a new owner never sees
    /// the previous owner's state.
    fn free_slot(&self, idx: u8) {
        let slot = &self.slots[usize::from(idx)];
        slot.len.store(0, Ordering::Release);
        slot.generation.fetch_add(1, Ordering::Release);
        // Drop the waker left by the previous owner's `recv_reply`. Otherwise it
        // stays alive and would be woken spuriously by the next owner's delivery.
        // This must happen while we still own the slot, or we could discard the
        // next owner's freshly registered waker.
        drop(slot.waker.take());
        self.bitmap.fetch_and(!(1u32 << idx), Ordering::Release);
        self.alloc_waker.wake();
    }

    /// Returns the reply currently published in slot `slot_id`, if any.
    ///
    /// Returns `None` for an out-of-range id or when the stored length is 0.
    /// This does not check whether the slot is currently allocated.
    /// NOTE(review): the returned slice borrows the router, not a slot guard,
    /// so it can outlive the slot's allocation; callers must not hold it
    /// across a later delivery into the same slot.
    pub fn try_recv_slot(&self, slot_id: u8) -> Option<&[u8]> {
        let slot = self.slots.get(usize::from(slot_id))?;
        let len = usize::from(slot.len.load(Ordering::Acquire));
        if len == 0 {
            return None;
        }
        // SAFETY: a non-zero `len` observed with Acquire pairs with the Release
        // store in `deliver`, so the first `len` bytes were written before it.
        let arr = unsafe { &*slot.data.get() };
        Some(&arr[..len])
    }
}
impl<const N: usize, const BUF: usize> Default for MuxedSlots<N, BUF> {
fn default() -> Self {
Self::new()
}
}
impl<const N: usize, const BUF: usize> ReplyRouter for MuxedSlots<N, BUF> {
    /// One header byte carrying the slot index.
    const HEADER_LEN: usize = 1;
    type SlotHandle<'a> = MuxedSlotGuard<'a, N, BUF>;
    type Error = core::convert::Infallible;

    /// Waits until a slot is free and claims it; never resolves to `Err`.
    ///
    /// NOTE(review): `alloc_waker` is a single `AtomicWaker`, and a new
    /// registration replaces the previous one. If several tasks wait here at
    /// once, only the most recently registered one is woken by `free_slot`;
    /// the others may stay pending. A waiter queue would be needed for robust
    /// multi-waiter progress.
    async fn acquire(&self) -> Result<MuxedSlotGuard<'_, N, BUF>, core::convert::Infallible> {
        let guard = core::future::poll_fn(|cx| {
            // Fast path: don't touch the waker when a slot is already free.
            if let Some(guard) = self.try_acquire() {
                return Poll::Ready(guard);
            }
            // Register before retrying, so a slot freed between the first
            // attempt and the registration still wakes this task.
            self.alloc_waker.register(cx.waker());
            match self.try_acquire() {
                Some(guard) => Poll::Ready(guard),
                None => Poll::Pending,
            }
        })
        .await;
        Ok(guard)
    }

    /// Claims a free slot without waiting, or returns `None` if all are taken.
    fn try_acquire(&self) -> Option<MuxedSlotGuard<'_, N, BUF>> {
        self.try_alloc().map(|(idx, gen_val)| MuxedSlotGuard {
            router: self,
            index: idx,
            generation: gen_val,
        })
    }

    /// Writes the slot index into `buf[0]`.
    ///
    /// # Panics
    ///
    /// Panics if `buf` is empty.
    fn write_header(slot: &MuxedSlotGuard<'_, N, BUF>, buf: &mut [u8]) {
        buf[0] = slot.index;
    }

    /// Reads the slot index from `buf[0]`.
    ///
    /// # Panics
    ///
    /// Panics if `buf` is empty.
    fn parse_header(buf: &[u8]) -> u8 {
        buf[0]
    }

    /// Publishes `payload` as the reply for slot `slot_id` and wakes its receiver.
    ///
    /// The payload is silently dropped when:
    /// - `slot_id >= N` (also a `debug_assert!` failure in debug builds);
    /// - the slot is not currently allocated;
    /// - a reply is already published in the slot. A receiver may hold a
    ///   `&[u8]` into the buffer, so overwriting it would alias that borrow;
    /// - the slot's generation changed while copying (it was freed meanwhile).
    ///
    /// At most `min(BUF, u16::MAX)` bytes are kept, because the length is
    /// published through an `AtomicU16`. An empty payload stores length 0,
    /// which receivers cannot tell apart from "no reply yet".
    ///
    /// NOTE(review): the allocation/duplicate/generation checks and the final
    /// `len` store are not one atomic step. A delivery racing with `free_slot`,
    /// or a second concurrent `deliver` to the same slot, can still write into
    /// the buffer. This assumes deliveries for a given slot are serialised by
    /// the caller — confirm.
    fn deliver(&self, slot_id: u8, payload: &[u8]) {
        debug_assert!((slot_id as usize) < N);
        let Some(slot) = self.slots.get(usize::from(slot_id)) else {
            return;
        };
        // `slot_id < N <= 32` here, so the shift cannot overflow.
        if self.bitmap.load(Ordering::Acquire) & (1u32 << slot_id) == 0 {
            return;
        }
        // A reply is already published: its bytes may be borrowed by a receiver.
        if slot.len.load(Ordering::Acquire) != 0 {
            return;
        }
        let gen_before = slot.generation.load(Ordering::Acquire);
        // Clamp to what `len: AtomicU16` can represent. A plain `as u16` would
        // wrap for `BUF > u16::MAX`; a 65536-byte reply would become "empty".
        let len16 = u16::try_from(payload.len().min(BUF)).unwrap_or(u16::MAX);
        let len = usize::from(len16);
        // SAFETY: the slot is allocated and `len == 0` was just observed, so
        // neither `recv_reply` nor `try_recv_slot` hands out a borrow of the
        // current contents (both require `len > 0`). See the NOTE above for
        // the concurrent cases these checks do not cover.
        unsafe {
            let data = &mut *slot.data.get();
            data[..len].copy_from_slice(&payload[..len]);
        }
        if slot.generation.load(Ordering::Acquire) != gen_before {
            return;
        }
        // Release pairs with the Acquire loads of `len` on the receive side,
        // making the buffer writes above visible before the length.
        slot.len.store(len16, Ordering::Release);
        slot.waker.wake();
    }
}
/// Ownership of one allocated [`MuxedSlots`] slot; dropping it frees the slot.
pub struct MuxedSlotGuard<'a, const N: usize, const BUF: usize> {
    /// Router the slot belongs to.
    router: &'a MuxedSlots<N, BUF>,
    /// Index of the owned slot (the value written into the request header).
    index: u8,
    /// Slot generation observed when the slot was claimed. Within this file it
    /// is only read by the unit tests, hence the lint allowance.
    #[allow(dead_code)]
    generation: u8,
}

impl<const N: usize, const BUF: usize> Drop for MuxedSlotGuard<'_, N, BUF> {
    /// Returns the slot to the router's free pool.
    fn drop(&mut self) {
        let (owner, slot_index) = (self.router, self.index);
        owner.free_slot(slot_index);
    }
}
impl<const N: usize, const BUF: usize> RouterSlotHandle for MuxedSlotGuard<'_, N, BUF> {
fn slot_id(&self) -> u8 {
self.index
}
async fn recv_reply(&self) -> &[u8] {
core::future::poll_fn(|cx| {
let slot = &self.router.slots[self.index as usize];
let len = slot.len.load(Ordering::Acquire);
if len > 0 {
let arr = unsafe { &*slot.data.get() };
let data = &arr[..len as usize];
Poll::Ready(data)
} else {
slot.waker.register(cx.waker());
let len = slot.len.load(Ordering::Acquire);
if len > 0 {
let arr = unsafe { &*slot.data.get() };
let data = &arr[..len as usize];
Poll::Ready(data)
} else {
Poll::Pending
}
}
})
.await
}
}
/// Copyable record of a multiplexed slot id.
///
/// NOTE(review): it only stores the id; how callers use it (e.g. to reply to a
/// slot outside a guard's lifetime) is not visible here.
#[derive(Debug, Clone, Copy)]
pub struct MuxedReplyToken {
    slot_id: u8,
}

impl MuxedReplyToken {
    /// Wraps `slot_id` as-is; it is not validated against any router's capacity.
    pub fn new(slot_id: u8) -> Self {
        MuxedReplyToken { slot_id }
    }

    /// Returns the wrapped slot id.
    pub fn slot_id(&self) -> u8 {
        let MuxedReplyToken { slot_id } = *self;
        slot_id
    }
}
/// Four-slot [`MuxedSlots`] router with `BUF`-byte reply buffers.
pub type MuxedSlots4<const BUF: usize> = MuxedSlots<4, BUF>;
/// Eight-slot [`MuxedSlots`] router with `BUF`-byte reply buffers.
pub type MuxedSlots8<const BUF: usize> = MuxedSlots<8, BUF>;
// Unit tests for the routers. Several tests use the tokio runtime
// (`#[tokio::test]`, `tokio::runtime::Builder`, `tokio::sync::oneshot`).
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    // Fills a 4-slot router and checks that each index is handed out exactly
    // once. Also checks that a freed index is the one reacquired and that the
    // bitmap is empty after all guards drop.
    #[test]
    fn bitmap_alloc_free() {
        let router = MuxedSlots::<4, 64>::new();
        let s0 = router.try_acquire().expect("slot 0");
        let s1 = router.try_acquire().expect("slot 1");
        let s2 = router.try_acquire().expect("slot 2");
        let s3 = router.try_acquire().expect("slot 3");
        let mut ids = [s0.slot_id(), s1.slot_id(), s2.slot_id(), s3.slot_id()];
        ids.sort();
        assert_eq!(ids, [0, 1, 2, 3]);
        // All four slots are held, so the router is full.
        assert!(router.try_acquire().is_none());
        let id2 = s2.slot_id();
        drop(s2);
        // The only free slot is the one just released.
        let s2b = router.try_acquire().expect("reacquired slot");
        assert_eq!(s2b.slot_id(), id2);
        assert!(router.try_acquire().is_none());
        drop(s0);
        drop(s1);
        drop(s2b);
        drop(s3);
        assert_eq!(router.bitmap.load(Ordering::Relaxed), 0);
    }

    // `Sequential` adds no header bytes and always reports slot 0.
    #[test]
    fn sequential_zero_overhead() {
        assert_eq!(Sequential::HEADER_LEN, 0);
        let seq = Sequential;
        let slot = seq.try_acquire().unwrap();
        assert_eq!(slot.slot_id(), 0);
    }

    // The one-byte header written for a slot parses back to the same slot id.
    #[test]
    fn muxed_header_round_trip() {
        let router = MuxedSlots::<8, 64>::new();
        let slot = router.try_acquire().unwrap();
        let mut buf = [0u8; 1];
        MuxedSlots::<8, 64>::write_header(&slot, &mut buf);
        let parsed = MuxedSlots::<8, 64>::parse_header(&buf);
        assert_eq!(parsed, slot.slot_id());
    }

    // A reply delivered before `recv_reply` is polled is returned on the first poll.
    #[test]
    fn deliver_and_recv() {
        let router = MuxedSlots::<4, 64>::new();
        let slot = router.try_acquire().unwrap();
        let id = slot.slot_id();
        router.deliver(id, b"hello");
        let rt = tokio::runtime::Builder::new_current_thread()
            .build()
            .unwrap();
        let reply = rt.block_on(slot.recv_reply());
        assert_eq!(reply, b"hello");
    }

    // Delivering to a slot after its guard was dropped leaves the bitmap empty.
    // NOTE(review): this only checks the bitmap; it does not assert that the
    // payload was not stored (e.g. via `try_recv_slot`).
    #[test]
    fn deliver_to_freed_slot_is_discarded() {
        let router = MuxedSlots::<4, 64>::new();
        let slot = router.try_acquire().unwrap();
        let id = slot.slot_id();
        drop(slot);
        router.deliver(id, b"orphaned reply");
        assert_eq!(router.bitmap.load(Ordering::Relaxed), 0);
    }

    // Each free of the single slot bumps its generation by one.
    // NOTE(review): despite the name, no stale delivery is attempted here.
    #[test]
    fn generation_counter_prevents_stale_delivery() {
        let router = MuxedSlots::<1, 64>::new();
        let slot1 = router.try_acquire().unwrap();
        let id = slot1.slot_id();
        let gen1 = slot1.generation;
        drop(slot1);
        let slot2 = router.try_acquire().unwrap();
        // With N = 1 the same slot must be reused.
        assert_eq!(slot2.slot_id(), id);
        assert_eq!(slot2.generation, gen1.wrapping_add(1));
        drop(slot2);
        let gen_now = router.slots[id as usize].generation.load(Ordering::Relaxed);
        assert_eq!(gen_now, gen1.wrapping_add(2));
    }

    // Two spawned tasks each hold a slot of a 2-slot router at the same time
    // and must receive different ids. The `done_*` channels keep both slots
    // held until the ids have been compared.
    #[tokio::test]
    async fn concurrent_acquire_slots() {
        let router = Arc::new(MuxedSlots::<2, 64>::new());
        let (tx1, rx1) = tokio::sync::oneshot::channel();
        let (tx2, rx2) = tokio::sync::oneshot::channel();
        let (done_tx1, done_rx1) = tokio::sync::oneshot::channel::<()>();
        let (done_tx2, done_rx2) = tokio::sync::oneshot::channel::<()>();
        let r1 = router.clone();
        tokio::spawn(async move {
            let slot = r1.acquire().await.unwrap();
            tx1.send(slot.slot_id()).unwrap();
            let _ = done_rx1.await;
            drop(slot);
        });
        let r2 = router.clone();
        tokio::spawn(async move {
            let slot = r2.acquire().await.unwrap();
            tx2.send(slot.slot_id()).unwrap();
            let _ = done_rx2.await;
            drop(slot);
        });
        let id1 = rx1.await.unwrap();
        let id2 = rx2.await.unwrap();
        assert_ne!(id1, id2);
        let _ = done_tx1.send(());
        let _ = done_tx2.send(());
    }

    // With the only slot held, a spawned `acquire` must pend and then complete
    // once the slot is dropped. The `yield_now` presumably lets the spawned
    // task run and register its waker first — confirm against the runtime flavor.
    #[tokio::test]
    async fn acquire_blocks_when_full_then_wakes() {
        let router = Arc::new(MuxedSlots::<1, 64>::new());
        let slot = router.acquire().await.unwrap();
        let id = slot.slot_id();
        let r2 = router.clone();
        let handle = tokio::spawn(async move {
            let s = r2.acquire().await.unwrap();
            s.slot_id()
        });
        tokio::task::yield_now().await;
        drop(slot);
        let id2 = handle.await.unwrap();
        // The waiter received the slot that was just released.
        assert_eq!(id2, id);
    }

    // A single waiter is woken by the release, acquires slot 0, and frees it again.
    #[tokio::test]
    async fn single_waiter_chain_works() {
        let router = Arc::new(MuxedSlots::<1, 64>::new());
        let slot = router.acquire().await.unwrap();
        let r2 = router.clone();
        let handle = tokio::spawn(async move {
            let s = r2.acquire().await.unwrap();
            let id = s.slot_id();
            drop(s);
            id
        });
        tokio::task::yield_now().await;
        drop(slot);
        let id = handle.await.unwrap();
        assert_eq!(id, 0);
    }

    // A guard dropped at the end of its scope frees its slot.
    // NOTE(review): this does not cancel a *pending* `acquire` future.
    #[tokio::test]
    async fn cancel_safety_drop_after_acquire() {
        let router = MuxedSlots::<2, 64>::new();
        {
            let _slot = router.acquire().await.unwrap();
        }
        assert_eq!(router.bitmap.load(Ordering::Relaxed), 0);
        let slot = router.try_acquire().unwrap();
        assert!(slot.slot_id() < 2);
    }

    // The token returns the id it was constructed with.
    #[test]
    fn muxed_reply_token_round_trip() {
        let token = MuxedReplyToken::new(7);
        assert_eq!(token.slot_id(), 7);
    }

    // N = 32 exercises the full-width `u32::MAX` mask: all 32 ids are handed
    // out exactly once, then the router reports full.
    #[test]
    fn max_32_slots() {
        let router = MuxedSlots::<32, 16>::new();
        let mut guards = Vec::new();
        for _ in 0..32 {
            guards.push(router.try_acquire().unwrap());
        }
        assert!(router.try_acquire().is_none());
        let mut ids: Vec<u8> = guards.iter().map(|g| g.slot_id()).collect();
        ids.sort();
        let expected: Vec<u8> = (0..32).collect();
        assert_eq!(ids, expected);
    }

    // `new_boxed` yields the same initial state as `new`, and the boxed
    // router supports allocation, delivery, and release.
    #[test]
    fn new_boxed_produces_equivalent_state() {
        let router: Box<MuxedSlots<4, 64>> = MuxedSlots::<4, 64>::new_boxed();
        assert_eq!(router.bitmap.load(Ordering::Relaxed), 0);
        for i in 0..4 {
            assert_eq!(router.slots[i].len.load(Ordering::Relaxed), 0);
            assert_eq!(router.slots[i].generation.load(Ordering::Relaxed), 0);
        }
        let s0 = router.try_acquire().expect("slot 0");
        let s1 = router.try_acquire().expect("slot 1");
        let s2 = router.try_acquire().expect("slot 2");
        let s3 = router.try_acquire().expect("slot 3");
        assert!(router.try_acquire().is_none());
        let mut ids = [s0.slot_id(), s1.slot_id(), s2.slot_id(), s3.slot_id()];
        ids.sort();
        assert_eq!(ids, [0, 1, 2, 3]);
        router.deliver(s1.slot_id(), b"hello");
        let rt = tokio::runtime::Builder::new_current_thread()
            .build()
            .unwrap();
        let reply = rt.block_on(s1.recv_reply());
        assert_eq!(reply, b"hello");
        drop(s0);
        drop(s1);
        drop(s2);
        drop(s3);
        assert_eq!(router.bitmap.load(Ordering::Relaxed), 0);
    }

    // A router with 32 * 128 KiB = 4 MiB of buffers is built on a thread with
    // a 1 MiB stack (`1 << 20`). `new_boxed` must not place the struct on the stack.
    #[test]
    fn new_boxed_on_restricted_stack() {
        std::thread::Builder::new()
            .stack_size(1 << 20)
            .spawn(|| {
                let router: Box<MuxedSlots<32, 131_072>> = MuxedSlots::<32, 131_072>::new_boxed();
                assert_eq!(router.bitmap.load(Ordering::Relaxed), 0);
                let slot = router.try_acquire().expect("slot on big router");
                assert!(slot.slot_id() < 32);
            })
            .expect("spawn restricted-stack thread")
            .join()
            .expect("restricted-stack thread panicked (likely stack overflow)");
    }
}