#![no_std]
mod macros;
use core::cell::UnsafeCell;
use core::future::poll_fn;
use core::ops::{Deref, DerefMut};
use core::task::Poll;
use embassy_sync::{
blocking_mutex::{Mutex, raw::RawMutex},
waitqueue::WakerRegistration,
};
/// Raw pointer to a pool slot (or to a value projected out of one) owned by a guard.
///
/// Raw pointers never implement `Send`/`Sync`, so this wrapper is what lets the guards that
/// hold it opt into those traits.
#[repr(transparent)]
#[derive(Debug)]
struct BufferPtr<T: ?Sized>(*mut T);

// SAFETY: a `BufferPtr` is used as a unique, `&mut T`-like handle to a single pool slot, so it
// follows the same auto-trait rules as `&mut T`. Moving it to another thread hands over mutable
// access to the `T`, which requires `T: Send`. Sharing it lets other threads obtain `&T`, which
// requires `T: Sync`. The previous unconditional impls allowed safe code to send or share
// non-thread-safe `T` values across threads through a guard.
unsafe impl<T: ?Sized + Send> Send for BufferPtr<T> {}
unsafe impl<T: ?Sized + Sync> Sync for BufferPtr<T> {}
/// Bookkeeping shared by a [`BufferPool`] and its guards, always accessed through the pool's mutex.
struct State {
    /// Waker registered by `BufferPool::take` while no slot is free; woken when a guard is dropped.
    waker: WakerRegistration,
    /// Free-slot bitmask: bit `i` set means slot `i` may be handed out.
    available: u32,
}
/// A fixed set of `N` values of type `T`, lent out one slot at a time through guards.
///
/// Which slots are free is recorded in a bitmask protected by the `M` raw mutex. The slot
/// storage itself sits in an `UnsafeCell` and is only reached through the guard that claimed
/// the slot.
pub struct BufferPool<M, T, const N: usize>
where
    M: RawMutex,
{
    /// Backing storage; each slot is lent to at most one guard at a time.
    buffer: UnsafeCell<[T; N]>,
    /// Free-slot bitmask plus the waker of a task waiting for a slot.
    state: Mutex<M, State>,
}

// SAFETY: moving the pool moves the `T` values it owns, so `T: Send` is required; the mutex
// needs `M: Send` to move with it.
unsafe impl<M, T, const N: usize> Send for BufferPool<M, T, N>
where
    M: RawMutex + Send,
    T: Send,
{
}

// SAFETY: through a shared pool, each slot is lent to one guard at a time (the bitmask is
// updated under the mutex), so threads receive exclusive access to the `T`s. This is the same
// requirement as `Mutex<T>: Sync`, i.e. `T: Send`.
unsafe impl<M, T, const N: usize> Sync for BufferPool<M, T, N>
where
    M: RawMutex + Sync,
    T: Send,
{
}
impl<M: RawMutex, T: Copy, const N: usize> BufferPool<M, T, N> {
pub const fn new(buffer: [T; N]) -> Self {
assert!(N > 0 && N <= 32);
Self {
buffer: UnsafeCell::new(buffer),
state: Mutex::new(State {
available: u32::MAX >> (32 - N),
waker: WakerRegistration::new(),
}),
}
}
pub fn try_take(&'static self) -> Option<BufferGuard<M, T>> {
unsafe {
self.state.lock_mut(|state| {
if state.available == 0 {
return None;
}
let index = state.available.trailing_zeros() as usize;
state.available &= !(1 << index);
let buffer = &mut (*self.buffer.get())[index];
Some(BufferGuard {
store: &self.state,
ptr: BufferPtr(buffer),
index,
})
})
}
}
pub fn take(&'static self) -> impl Future<Output = BufferGuard<M, T>> {
poll_fn(|cx| unsafe {
self.state.lock_mut(|state| {
if state.available == 0 {
state.waker.register(cx.waker());
return Poll::Pending;
}
let index = state.available.trailing_zeros() as usize;
state.available &= !(1 << index);
let buffer = &mut (*self.buffer.get())[index];
Poll::Ready(BufferGuard {
store: &self.state,
ptr: BufferPtr(buffer),
index,
})
})
})
}
}
/// Exclusive access to one slot of a [`BufferPool`], obtained from `try_take` or `take`.
///
/// Dereferences to the slot's `T`. When the guard is dropped, the slot is marked free again
/// and the task waiting in `take`, if any, is woken.
#[must_use = "dropping the guard immediately returns the buffer to the pool"]
pub struct BufferGuard<M: RawMutex + 'static, T> {
    /// Mutex-protected state of the pool that owns the slot.
    store: &'static Mutex<M, State>,
    /// Pointer to the claimed slot inside the pool's storage.
    ptr: BufferPtr<T>,
    /// Index of the claimed slot; its bit in `available` is set again on drop.
    index: usize,
}
impl<M: RawMutex + 'static, T> Drop for BufferGuard<M, T> {
fn drop(&mut self) {
unsafe {
self.store.lock_mut(|state| {
state.available |= 1 << self.index;
state.waker.wake();
});
}
}
}
impl<M, T> Deref for BufferGuard<M, T>
where
    M: RawMutex + 'static,
{
    type Target = T;

    /// Shared view of the claimed slot.
    fn deref(&self) -> &T {
        let slot: *const T = self.ptr.0;
        // SAFETY: `ptr` points into a `'static` pool, at a slot reserved for this guard until
        // it is dropped. Holding `&self` rules out a live `&mut T` from `deref_mut`.
        unsafe { &*slot }
    }
}
impl<M, T> DerefMut for BufferGuard<M, T>
where
    M: RawMutex + 'static,
{
    /// Exclusive view of the claimed slot.
    fn deref_mut(&mut self) -> &mut T {
        let slot = self.ptr.0;
        // SAFETY: the slot is reserved for this guard alone, and `&mut self` guarantees no
        // other reference obtained through this guard is alive.
        unsafe { &mut *slot }
    }
}
impl<M, T> BufferGuard<M, T>
where
    M: RawMutex + 'static,
{
    /// Converts the guard into one that dereferences to a `U` borrowed from the slot.
    ///
    /// The slot stays claimed and is released when the returned guard is dropped. If `fun`
    /// panics, `orig` is dropped normally and the slot is released.
    pub fn map<U: ?Sized>(
        mut orig: Self,
        fun: impl FnOnce(&mut T) -> &mut U,
    ) -> MappedBufferGuard<M, U> {
        // Project through `DerefMut`; the borrow of `orig` ends once the result is a raw pointer.
        let value: *mut U = fun(&mut *orig);
        let (store, index) = (orig.store, orig.index);
        // Ownership of the slot moves to the mapped guard, so `orig` must not release it.
        core::mem::forget(orig);
        MappedBufferGuard {
            store,
            index,
            value,
        }
    }
}
/// A guard narrowed by `BufferGuard::map` / `MappedBufferGuard::map` to a `U` inside a slot.
///
/// Dropping it returns the whole underlying slot to the pool.
#[must_use = "dropping the guard immediately returns the buffer to the pool"]
pub struct MappedBufferGuard<M: RawMutex + 'static, U: ?Sized> {
    /// Mutex-protected state of the pool that owns the slot.
    store: &'static Mutex<M, State>,
    /// Index of the claimed slot; its bit in `available` is set again on drop.
    index: usize,
    /// Pointer to the projected value inside the slot.
    value: *mut U,
}

// SAFETY: the guard behaves like a `&mut U` plus a handle to the pool's mutex. Sending it is
// sound when `&mut U` is `Send` (`U: Send`) and the mutex may be locked from another thread
// on drop (`M: Sync`, which makes `Mutex<M, State>: Sync`). Without this impl the raw pointer
// made every mapped guard `!Send`, unlike the `BufferGuard` it was mapped from.
unsafe impl<M: RawMutex + Sync + 'static, U: ?Sized + Send> Send for MappedBufferGuard<M, U> {}

// SAFETY: a shared `&MappedBufferGuard` only yields `&U` (through `Deref`), so sharing it
// requires `U: Sync`. `M: Sync` conservatively mirrors the `Send` impl.
unsafe impl<M: RawMutex + Sync + 'static, U: ?Sized + Sync> Sync for MappedBufferGuard<M, U> {}
impl<M: RawMutex + 'static, U: ?Sized> Drop for MappedBufferGuard<M, U> {
fn drop(&mut self) {
unsafe {
self.store.lock_mut(|state| {
state.available |= 1 << self.index;
state.waker.wake();
});
}
}
}
impl<M, U> Deref for MappedBufferGuard<M, U>
where
    M: RawMutex + 'static,
    U: ?Sized,
{
    type Target = U;

    /// Shared view of the projected value.
    fn deref(&self) -> &U {
        let projected: *const U = self.value;
        // SAFETY: `value` was derived from a `&mut` into a slot reserved for this guard, and
        // holding `&self` rules out a live `&mut U` from `deref_mut`.
        unsafe { &*projected }
    }
}
impl<M, U> DerefMut for MappedBufferGuard<M, U>
where
    M: RawMutex + 'static,
    U: ?Sized,
{
    /// Exclusive view of the projected value.
    fn deref_mut(&mut self) -> &mut U {
        let projected = self.value;
        // SAFETY: the slot is reserved for this guard alone, and `&mut self` guarantees no
        // other reference obtained through this guard is alive.
        unsafe { &mut *projected }
    }
}
impl<M, U> MappedBufferGuard<M, U>
where
    M: RawMutex + 'static,
    U: ?Sized,
{
    /// Narrows the guard further, to a `V` borrowed from the current `U`.
    ///
    /// The slot stays claimed and is released when the returned guard is dropped. If `fun`
    /// panics, `orig` is dropped normally and the slot is released.
    pub fn map<V: ?Sized>(
        mut orig: Self,
        fun: impl FnOnce(&mut U) -> &mut V,
    ) -> MappedBufferGuard<M, V> {
        // Project through `DerefMut`; the borrow of `orig` ends once the result is a raw pointer.
        let value: *mut V = fun(&mut *orig);
        let (store, index) = (orig.store, orig.index);
        // Ownership of the slot moves to the new guard, so `orig` must not release it.
        core::mem::forget(orig);
        MappedBufferGuard {
            store,
            index,
            value,
        }
    }
}