use core::cell::RefCell;
#[cfg(any(feature = "async", feature = "embassy-net"))]
use core::task::Waker;
use critical_section::Mutex;
/// A value guarded by a `critical_section::Mutex<RefCell<_>>`.
///
/// The contents are only reachable through closures passed to [`with`](Self::with)
/// or [`try_with`](Self::try_with) (plus `with_ref` in some builds). Each of these runs
/// its closure inside `critical_section::with`. Because the borrow cannot escape
/// the closure, it never outlives the critical section. [`new`](Self::new) is
/// `const`, so the cell can back a `static`.
pub struct CriticalSectionCell<T> {
    inner: Mutex<RefCell<T>>,
}

impl<T> CriticalSectionCell<T> {
    /// Creates a cell holding `value`.
    pub const fn new(value: T) -> Self {
        Self {
            inner: Mutex::new(RefCell::new(value)),
        }
    }

    /// Runs `f` with mutable access to the contents and returns its result.
    ///
    /// The critical section is held for the entire call to `f`, so `f` should be
    /// kept short.
    ///
    /// # Panics
    ///
    /// Panics if the contents are already borrowed. This happens when the method is
    /// called re-entrantly from inside another `with`, `try_with` or `with_ref`
    /// closure on the same cell. Use [`try_with`](Self::try_with) where that can occur.
    #[inline]
    pub fn with<R, F>(&self, f: F) -> R
    where
        F: FnOnce(&mut T) -> R,
    {
        critical_section::with(|cs| {
            let mut value = self.inner.borrow_ref_mut(cs);
            f(&mut value)
        })
    }

    /// Like [`with`](Self::with), but never panics on re-entrant access.
    ///
    /// Returns `None`, without calling `f`, when the contents are already borrowed.
    /// Otherwise returns `Some` of `f`'s result.
    #[inline]
    pub fn try_with<R, F>(&self, f: F) -> Option<R>
    where
        F: FnOnce(&mut T) -> R,
    {
        critical_section::with(|cs| {
            self.inner
                .borrow(cs)
                .try_borrow_mut()
                .ok()
                .map(|mut value| f(&mut value))
        })
    }

    /// Runs `f` with shared access to the contents and returns its result.
    ///
    /// # Panics
    ///
    /// Panics if the contents are currently mutably borrowed, i.e. on a re-entrant
    /// call from inside a `with` or `try_with` closure on the same cell.
    #[cfg(any(feature = "embassy-net", test))]
    #[inline]
    pub fn with_ref<R, F>(&self, f: F) -> R
    where
        F: FnOnce(&T) -> R,
    {
        critical_section::with(|cs| {
            let value = self.inner.borrow_ref(cs);
            f(&value)
        })
    }
}

// SAFETY: every access to `inner` happens inside `critical_section::with`, which
// serialises access between threads and interrupt contexts. The `RefCell` rejects
// re-entrant aliasing. The contents may still be used from whichever context holds
// the critical section, so `T` must be `Send`. This matches the bound
// `critical_section::Mutex` itself requires for `Sync`. Without it, `!Send` values
// (e.g. `Rc`) could be shared across threads, which is unsound.
unsafe impl<T: Send> Sync for CriticalSectionCell<T> {}
/// Single-slot waker storage protected by a critical section.
///
/// Holds at most one [`Waker`]. [`register`](Self::register) stores a waker, or
/// keeps the current one if it would wake the same task. [`wake`](Self::wake)
/// removes the stored waker, if any, and wakes it.
#[cfg(any(feature = "async", feature = "embassy-net"))]
pub struct AtomicWaker {
    waker: CriticalSectionCell<Option<Waker>>,
}

#[cfg(any(feature = "async", feature = "embassy-net"))]
impl AtomicWaker {
    /// Creates an `AtomicWaker` with no waker registered.
    pub const fn new() -> Self {
        Self {
            waker: CriticalSectionCell::new(None),
        }
    }

    /// Registers `waker` to be woken by the next call to [`wake`](Self::wake).
    ///
    /// If the stored waker already satisfies `Waker::will_wake(waker)`, it is kept
    /// and `waker` is not cloned. Otherwise `waker` is cloned into the slot. Any
    /// previously stored waker is dropped without being woken.
    pub fn register(&self, waker: &Waker) {
        self.waker.with(|slot| {
            // Avoid the `Waker::clone` vtable call when the stored waker already
            // targets the same task.
            let up_to_date = matches!(slot, Some(existing) if existing.will_wake(waker));
            if !up_to_date {
                *slot = Some(waker.clone());
            }
        });
    }

    /// Takes the registered waker, if any, and wakes it.
    ///
    /// The waker is removed inside the critical section but woken only after the
    /// critical section has been released. `Waker::wake` therefore never runs while
    /// it is held. A second call without an intervening `register` does nothing.
    #[inline]
    pub fn wake(&self) {
        let waker = self.waker.with(|slot| slot.take());
        if let Some(w) = waker {
            w.wake();
        }
    }

    /// Returns `true` if a waker is currently stored (test helper).
    #[cfg(test)]
    pub fn is_registered(&self) -> bool {
        self.waker.with_ref(|slot| slot.is_some())
    }
}

#[cfg(any(feature = "async", feature = "embassy-net"))]
impl Default for AtomicWaker {
    fn default() -> Self {
        Self::new()
    }
}

// `AtomicWaker` is `Send + Sync` through auto traits. Its only field is a
// `CriticalSectionCell<Option<Waker>>`, and `Waker` is `Send + Sync`. This is
// asserted at compile time rather than via blanket `unsafe impl`s, so adding a
// non-thread-safe field later fails to build instead of silently becoming unsound.
#[cfg(any(feature = "async", feature = "embassy-net"))]
const _: fn() = || {
    fn assert_send_sync<T: Send + Sync>() {}
    assert_send_sync::<AtomicWaker>();
};
#[cfg(test)]
// `extern crate std` below indicates the crate is `no_std`. These tests use `std`
// types (`Arc`, `std::sync::atomic`), so the lints that prefer `core`/`alloc` paths
// are silenced for this test module only.
#[allow(clippy::std_instead_of_core, clippy::std_instead_of_alloc)]
mod tests {
    extern crate std;

    use super::*;

    #[cfg(any(feature = "async", feature = "embassy-net"))]
    use core::task::{RawWaker, RawWakerVTable};
    #[cfg(any(feature = "async", feature = "embassy-net"))]
    use std::sync::atomic::{AtomicUsize, Ordering};
    #[cfg(any(feature = "async", feature = "embassy-net"))]
    use std::sync::Arc;

    /// Counts how many times the test waker built by `test_waker` has been woken.
    #[cfg(any(feature = "async", feature = "embassy-net"))]
    struct WakeCounter {
        count: AtomicUsize,
    }

    #[cfg(any(feature = "async", feature = "embassy-net"))]
    impl WakeCounter {
        fn new() -> Arc<Self> {
            Arc::new(Self {
                count: AtomicUsize::new(0),
            })
        }

        fn count(&self) -> usize {
            self.count.load(Ordering::SeqCst)
        }
    }

    /// Builds a `Waker` whose data pointer is `counter` (via `Arc::into_raw`).
    ///
    /// Every live `RawWaker` created here owns exactly one strong count of the
    /// `Arc`. Waking increments `WakeCounter::count`.
    #[cfg(any(feature = "async", feature = "embassy-net"))]
    fn test_waker(counter: Arc<WakeCounter>) -> Waker {
        fn clone_fn(ptr: *const ()) -> RawWaker {
            // SAFETY: `ptr` came from `Arc::into_raw` (in `test_waker` or here), and the
            // waker being cloned still owns that strong count.
            let arc = unsafe { Arc::from_raw(ptr as *const WakeCounter) };
            let cloned = Arc::clone(&arc);
            // The original waker keeps its own count; only the clone gets a new one.
            core::mem::forget(arc);
            RawWaker::new(Arc::into_raw(cloned) as *const (), &VTABLE)
        }
        fn wake_fn(ptr: *const ()) {
            // SAFETY: as in `clone_fn`. `wake` consumes the waker, so its strong count
            // is released when `arc` is dropped at the end of this function.
            let arc = unsafe { Arc::from_raw(ptr as *const WakeCounter) };
            arc.count.fetch_add(1, Ordering::SeqCst);
        }
        fn wake_by_ref_fn(ptr: *const ()) {
            // SAFETY: as in `clone_fn`. The waker is only borrowed, so its count is
            // kept alive with `forget`.
            let arc = unsafe { Arc::from_raw(ptr as *const WakeCounter) };
            arc.count.fetch_add(1, Ordering::SeqCst);
            core::mem::forget(arc);
        }
        fn drop_fn(ptr: *const ()) {
            // SAFETY: as in `clone_fn`. Dropping the waker releases its strong count.
            drop(unsafe { Arc::from_raw(ptr as *const WakeCounter) });
        }
        static VTABLE: RawWakerVTable =
            RawWakerVTable::new(clone_fn, wake_fn, wake_by_ref_fn, drop_fn);
        let raw = RawWaker::new(Arc::into_raw(counter) as *const (), &VTABLE);
        // SAFETY: the vtable functions above uphold the `RawWakerVTable` contract for
        // an `Arc<WakeCounter>` data pointer, and `WakeCounter` is `Send + Sync`.
        unsafe { Waker::from_raw(raw) }
    }

    #[test]
    fn critical_section_cell_new() {
        let cell: CriticalSectionCell<u32> = CriticalSectionCell::new(42);
        let value = cell.with(|v| *v);
        assert_eq!(value, 42);
    }

    #[test]
    fn critical_section_cell_with_mutates() {
        let cell: CriticalSectionCell<u32> = CriticalSectionCell::new(0);
        cell.with(|v| *v += 10);
        let value = cell.with(|v| *v);
        assert_eq!(value, 10);
    }

    #[test]
    fn critical_section_cell_with_returns_value() {
        let cell: CriticalSectionCell<u32> = CriticalSectionCell::new(42);
        let result = cell.with(|v| *v * 2);
        assert_eq!(result, 84);
    }

    #[test]
    fn critical_section_cell_try_with_succeeds() {
        let cell: CriticalSectionCell<u32> = CriticalSectionCell::new(42);
        let result = cell.try_with(|v| *v);
        assert_eq!(result, Some(42));
    }

    #[test]
    fn critical_section_cell_try_with_returns_none_when_reentered() {
        let cell: CriticalSectionCell<u32> = CriticalSectionCell::new(42);
        // The outer `with` holds the mutable borrow, so the nested `try_with` must
        // report `None` instead of panicking.
        let nested = cell.with(|_| cell.try_with(|v| *v));
        assert_eq!(nested, None);
        // The cell is usable again once the outer borrow has ended.
        assert_eq!(cell.try_with(|v| *v), Some(42));
    }

    // `with_ref` is always compiled under `cfg(test)`, so this test needs no feature gate.
    #[test]
    fn critical_section_cell_with_ref_reads() {
        let cell: CriticalSectionCell<u32> = CriticalSectionCell::new(42);
        let value = cell.with_ref(|v| *v);
        assert_eq!(value, 42);
    }

    #[test]
    fn critical_section_cell_static_usage() {
        static CELL: CriticalSectionCell<u32> = CriticalSectionCell::new(0);
        CELL.with(|v| *v = 100);
        let value = CELL.with(|v| *v);
        assert_eq!(value, 100);
    }

    #[cfg(any(feature = "async", feature = "embassy-net"))]
    #[test]
    fn atomic_waker_new_is_empty() {
        let waker = AtomicWaker::new();
        assert!(!waker.is_registered());
    }

    #[cfg(any(feature = "async", feature = "embassy-net"))]
    #[test]
    fn atomic_waker_default_is_empty() {
        let waker = AtomicWaker::default();
        assert!(!waker.is_registered());
    }

    #[cfg(any(feature = "async", feature = "embassy-net"))]
    #[test]
    fn atomic_waker_register_stores_waker() {
        let atomic_waker = AtomicWaker::new();
        let counter = WakeCounter::new();
        let waker = test_waker(counter.clone());
        atomic_waker.register(&waker);
        assert!(atomic_waker.is_registered());
    }

    #[cfg(any(feature = "async", feature = "embassy-net"))]
    #[test]
    fn atomic_waker_wake_calls_waker() {
        let atomic_waker = AtomicWaker::new();
        let counter = WakeCounter::new();
        let waker = test_waker(counter.clone());
        atomic_waker.register(&waker);
        assert_eq!(counter.count(), 0);
        atomic_waker.wake();
        assert_eq!(counter.count(), 1);
    }

    #[cfg(any(feature = "async", feature = "embassy-net"))]
    #[test]
    fn atomic_waker_wake_clears_waker() {
        let atomic_waker = AtomicWaker::new();
        let counter = WakeCounter::new();
        let waker = test_waker(counter.clone());
        atomic_waker.register(&waker);
        assert!(atomic_waker.is_registered());
        atomic_waker.wake();
        assert!(!atomic_waker.is_registered());
    }

    #[cfg(any(feature = "async", feature = "embassy-net"))]
    #[test]
    fn atomic_waker_wake_without_registered_is_noop() {
        let atomic_waker = AtomicWaker::new();
        atomic_waker.wake();
        assert!(!atomic_waker.is_registered());
    }

    #[cfg(any(feature = "async", feature = "embassy-net"))]
    #[test]
    fn atomic_waker_register_overwrites_previous() {
        let atomic_waker = AtomicWaker::new();
        let counter1 = WakeCounter::new();
        let counter2 = WakeCounter::new();
        let waker1 = test_waker(counter1.clone());
        let waker2 = test_waker(counter2.clone());
        atomic_waker.register(&waker1);
        atomic_waker.register(&waker2);
        atomic_waker.wake();
        assert_eq!(counter1.count(), 0);
        assert_eq!(counter2.count(), 1);
    }

    #[cfg(any(feature = "async", feature = "embassy-net"))]
    #[test]
    fn atomic_waker_register_same_waker_does_not_reclone() {
        let atomic_waker = AtomicWaker::new();
        let counter = WakeCounter::new();
        let waker = test_waker(counter.clone());
        atomic_waker.register(&waker);
        let refs_after_first = Arc::strong_count(&counter);
        // `will_wake` is true for the stored clone, so no new clone may be taken.
        atomic_waker.register(&waker);
        assert_eq!(Arc::strong_count(&counter), refs_after_first);
    }

    #[cfg(any(feature = "async", feature = "embassy-net"))]
    #[test]
    fn atomic_waker_releases_waker_after_wake() {
        let atomic_waker = AtomicWaker::new();
        let counter = WakeCounter::new();
        let waker = test_waker(counter.clone());
        atomic_waker.register(&waker);
        let refs_while_registered = Arc::strong_count(&counter);
        atomic_waker.wake();
        // The stored clone was consumed by `wake`, releasing one strong count.
        assert_eq!(Arc::strong_count(&counter), refs_while_registered - 1);
        drop(waker);
        assert_eq!(Arc::strong_count(&counter), 1);
    }

    #[cfg(any(feature = "async", feature = "embassy-net"))]
    #[test]
    fn atomic_waker_double_wake_only_wakes_once() {
        let atomic_waker = AtomicWaker::new();
        let counter = WakeCounter::new();
        let waker = test_waker(counter.clone());
        atomic_waker.register(&waker);
        atomic_waker.wake();
        atomic_waker.wake();
        assert_eq!(counter.count(), 1);
    }
}