use core::cell::{Cell, UnsafeCell};
use core::marker::PhantomData;
use core::mem::MaybeUninit;
use core::ops::{Deref, DerefMut};
use embassy_sync::blocking_mutex::raw::RawMutex;
use embassy_sync::blocking_mutex::Mutex;
/// A single-value channel: senders overwrite one shared value and receivers observe the most
/// recent one.
///
/// Each `send` stores the value and advances an 8-bit message id. Receivers remember the id
/// they last saw and use it to detect new messages. The id skips `0` when it wraps (`0` means
/// "nothing sent yet"), so a receiver that misses a multiple of 255 sends in a row will not
/// notice them.
#[derive(Debug)]
pub struct Watch<M: RawMutex, T: Clone> {
    // All state lives behind this mutex; every access goes through `Mutex::lock`.
    mutex: Mutex<M, WatchState<T>>,
}

/// State protected by the `Watch` mutex.
///
/// Invariant: `data` is initialised if and only if `has_value` is `true`.
#[derive(Debug)]
struct WatchState<T: Clone> {
    /// Storage for the current value.
    data: UnsafeCell<MaybeUninit<T>>,
    /// Whether `data` currently holds an initialised value.
    has_value: Cell<bool>,
    /// Id of the most recent `send`; `0` until the first `send`. `clear` leaves it untouched,
    /// so a `send` after a `clear` still gets an id different from the one receivers last saw.
    current_id: Cell<u8>,
}

impl<T: Clone> WatchState<T> {
    /// Return a clone of the stored value, or `None` if the slot is empty.
    ///
    /// Must only be called while the `Watch` mutex is held.
    fn clone_value(&self) -> Option<T> {
        if !self.has_value.get() {
            return None;
        }
        // SAFETY: `has_value` is true, so `data` is initialised (struct invariant), and the
        // caller holds the mutex, so no other context writes to `data` concurrently.
        // NOTE(review): if `T::clone` re-enters this same `Watch` and the raw mutex is
        // reentrant, a nested `send`/`clear` could overwrite `data` while this borrow is live;
        // ruling that out would need a reentrancy guard (e.g. a `RefCell`).
        let value = unsafe { (*self.data.get()).assume_init_ref() };
        Some(value.clone())
    }

    /// Swap the stored value for `new` and return the previous value (if any).
    ///
    /// The previous value is handed back instead of dropped here so the caller can drop it
    /// after releasing the mutex. Must only be called while the `Watch` mutex is held.
    fn replace_value(&self, new: Option<T>) -> Option<T> {
        let slot = self.data.get();
        let old = if self.has_value.get() {
            // SAFETY: `has_value` is true, so the slot is initialised. The value is moved out
            // here and `has_value` is updated below before anything else can observe the slot.
            Some(unsafe { (*slot).assume_init_read() })
        } else {
            None
        };
        match new {
            Some(value) => {
                // SAFETY: the mutex is held and any previous value was moved out above, so
                // overwriting the slot neither leaks nor double-drops.
                unsafe { slot.write(MaybeUninit::new(value)) };
                self.has_value.set(true);
            }
            None => self.has_value.set(false),
        }
        old
    }
}

impl<T: Clone> Drop for WatchState<T> {
    fn drop(&mut self) {
        if self.has_value.get() {
            // SAFETY: `has_value` is true, so `data` is initialised; `&mut self` is exclusive.
            unsafe { self.data.get_mut().assume_init_drop() };
        }
    }
}

/// Mutating / change-tracking operations, reachable only through the handles in this module.
///
/// The trait is private, so code outside this module can neither call these methods directly
/// nor implement [`WatchBehavior`] (which requires this trait) for its own types.
trait SealedWatchBehavior<T> {
    /// If the current message id differs from `*id`, store it in `*id` and return a clone of
    /// the current value (`None` if the watch is empty); otherwise return `None`.
    fn try_changed(&self, id: &mut u8) -> Option<T>;
    /// Remove the current value, if any, without changing the message id.
    fn clear(&self);
    /// Store `val` as the current value and advance the message id.
    fn send(&self, val: T);
}

/// Read-side operations available on a watch through its handles.
// `private_bounds` is allowed on purpose: the private supertrait is what seals this trait.
#[allow(private_bounds)]
pub trait WatchBehavior<T: Clone>: SealedWatchBehavior<T> {
    /// Return a clone of the current value, if any. When `id` is given, it is set to the
    /// current message id, marking that message as seen.
    fn try_get(&self, id: Option<&mut u8>) -> Option<T>;
    /// Whether the watch currently holds a value.
    fn contains_value(&self) -> bool;
}

impl<M: RawMutex, T: Clone> SealedWatchBehavior<T> for Watch<M, T> {
    fn try_changed(&self, id: &mut u8) -> Option<T> {
        self.mutex.lock(|state| {
            let current_id = state.current_id.get();
            if current_id == *id {
                return None;
            }
            *id = current_id;
            state.clone_value()
        })
    }

    fn clear(&self) {
        // The message id is deliberately left alone: resetting it would let the next `send`
        // reuse an id that receivers have already seen, hiding that update from them.
        let previous = self.mutex.lock(|state| state.replace_value(None));
        // Dropped outside the critical section so `T`'s destructor never runs under the lock.
        drop(previous);
    }

    fn send(&self, val: T) {
        let previous = self.mutex.lock(|state| {
            let previous = state.replace_value(Some(val));
            // Skip 0 on wrap-around: it means "nothing sent yet", the id new receivers start at.
            let next_id = match state.current_id.get().wrapping_add(1) {
                0 => 1,
                id => id,
            };
            state.current_id.set(next_id);
            previous
        });
        // Dropped outside the critical section so `T`'s destructor never runs under the lock.
        drop(previous);
    }
}

impl<M: RawMutex, T: Clone> WatchBehavior<T> for Watch<M, T> {
    fn try_get(&self, id: Option<&mut u8>) -> Option<T> {
        self.mutex.lock(|state| {
            if let Some(id) = id {
                *id = state.current_id.get();
            }
            state.clone_value()
        })
    }

    fn contains_value(&self) -> bool {
        self.mutex.lock(|state| state.has_value.get())
    }
}

impl<M: RawMutex, T: Clone> Watch<M, T> {
    /// Create an empty `Watch`.
    pub const fn new() -> Self {
        Self {
            mutex: Mutex::new(WatchState {
                data: UnsafeCell::new(MaybeUninit::uninit()),
                has_value: Cell::new(false),
                current_id: Cell::new(0),
            }),
        }
    }

    /// Create a `Watch` that already holds `data`.
    ///
    /// The initial value is visible to `try_get` and `contains_value`. It is not reported by
    /// `try_changed`, because no `send` has happened yet (the message id is still `0`).
    pub const fn new_with(data: T) -> Self {
        Self {
            mutex: Mutex::new(WatchState {
                data: UnsafeCell::new(MaybeUninit::new(data)),
                has_value: Cell::new(true),
                current_id: Cell::new(0),
            }),
        }
    }

    /// Create a sending handle for this watch.
    pub fn sender(&self) -> Sender<'_, M, T> {
        Sender(Snd::new(self))
    }

    /// Create a receiving handle for this watch.
    pub fn receiver(&self) -> Receiver<'_, M, T> {
        Receiver(Rcv::new(self))
    }

    /// Id of the most recent `send` (`0` if nothing has been sent). `clear` does not reset it.
    pub fn get_msg_id(&self) -> u8 {
        self.mutex.lock(|state| state.current_id.get())
    }

    /// Return a clone of the current value, if any, without marking anything as seen.
    pub fn try_get(&self) -> Option<T> {
        WatchBehavior::try_get(self, None)
    }
}

impl<M: RawMutex, T: Clone> Default for Watch<M, T> {
    /// Same as [`Watch::new`]: an empty watch.
    fn default() -> Self {
        Self::new()
    }
}
#[derive(Debug)]
pub struct Snd<'a, T: Clone, W: WatchBehavior<T> + ?Sized> {
watch: &'a W,
_phantom: PhantomData<T>,
}
impl<'a, T: Clone, W: WatchBehavior<T> + ?Sized> Clone for Snd<'a, T, W> {
fn clone(&self) -> Self {
Self {
watch: self.watch,
_phantom: PhantomData,
}
}
}
impl<'a, T: Clone, W: WatchBehavior<T> + ?Sized> Snd<'a, T, W> {
fn new(watch: &'a W) -> Self {
Self {
watch,
_phantom: PhantomData,
}
}
pub fn send(&self, val: T) {
self.watch.send(val)
}
pub fn clear(&self) {
self.watch.clear()
}
pub fn try_get(&self) -> Option<T> {
self.watch.try_get(None)
}
pub fn contains_value(&self) -> bool {
self.watch.contains_value()
}
}
/// Sending handle for a [`Watch`]; dereferences to [`Snd`], which provides the operations.
#[derive(Debug)]
pub struct Sender<'a, M: RawMutex, T: Clone>(Snd<'a, T, Watch<M, T>>);

impl<'a, M: RawMutex, T: Clone> Clone for Sender<'a, M, T> {
    fn clone(&self) -> Self {
        let Self(inner) = self;
        Sender(inner.clone())
    }
}

impl<'a, M: RawMutex, T: Clone> Deref for Sender<'a, M, T> {
    type Target = Snd<'a, T, Watch<M, T>>;

    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        inner
    }
}

impl<'a, M: RawMutex, T: Clone> DerefMut for Sender<'a, M, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(inner) = self;
        inner
    }
}
/// Receiving half of a watch, generic over the watch implementation `W`.
///
/// Remembers the id of the last message it observed (`at_id`, starting at `0`) so that
/// `try_changed` can tell whether the watch has moved on since.
#[derive(Debug)]
pub struct Rcv<'a, T: Clone, W: WatchBehavior<T> + ?Sized> {
    watch: &'a W,
    at_id: u8,
    _phantom: PhantomData<T>,
}

impl<'a, T: Clone, W: WatchBehavior<T> + ?Sized> Rcv<'a, T, W> {
    /// Wrap a reference to `watch`; the receiver starts at message id `0`.
    fn new(watch: &'a W) -> Self {
        Self {
            watch,
            at_id: 0,
            _phantom: PhantomData,
        }
    }

    /// Return a clone of the watch's current value, if any, and record the current message
    /// id as seen, so a following `try_changed` returns `None` until the id changes.
    pub fn try_get(&mut self) -> Option<T> {
        self.watch.try_get(Some(&mut self.at_id))
    }

    /// If the watch's message id differs from the last one this receiver saw, record it as
    /// seen and return the watch's current value; otherwise return `None`.
    pub fn try_changed(&mut self) -> Option<T> {
        self.watch.try_changed(&mut self.at_id)
    }

    /// Whether the watch currently holds a value.
    pub fn contains_value(&self) -> bool {
        self.watch.contains_value()
    }
}

/// Receiving handle for a [`Watch`]; dereferences to [`Rcv`], which provides the operations.
#[derive(Debug)]
pub struct Receiver<'a, M: RawMutex, T: Clone>(Rcv<'a, T, Watch<M, T>>);

impl<'a, M: RawMutex, T: Clone> Deref for Receiver<'a, M, T> {
    type Target = Rcv<'a, T, Watch<M, T>>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

// Needed because `Rcv::try_get` / `Rcv::try_changed` take `&mut self`.
impl<'a, M: RawMutex, T: Clone> DerefMut for Receiver<'a, M, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
#[cfg(test)]
mod tests {
    use super::Watch;
    use embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex;

    // Each test declares its own `static` watch so tests stay independent even when the
    // harness runs them concurrently.

    /// Every send is reported exactly once by `try_changed`.
    #[test]
    fn multiple_sends() {
        static WATCH: Watch<CriticalSectionRawMutex, u8> = Watch::new();
        let mut rcv = WATCH.receiver();
        let snd = WATCH.sender();
        assert_eq!(rcv.try_changed(), None);
        snd.send(20);
        assert_eq!(rcv.try_changed(), Some(20));
        assert_eq!(rcv.try_changed(), None);
        snd.send(30);
        assert_eq!(rcv.try_changed(), Some(30));
        assert_eq!(rcv.try_changed(), None);
    }

    /// `try_get` works identically on the watch, a receiver and a sender.
    #[test]
    fn all_try_get() {
        static WATCH: Watch<CriticalSectionRawMutex, u8> = Watch::new();
        let mut rcv = WATCH.receiver();
        let snd = WATCH.sender();
        assert_eq!(WATCH.try_get(), None);
        assert_eq!(rcv.try_get(), None);
        assert_eq!(snd.try_get(), None);
        snd.send(10);
        assert_eq!(WATCH.try_get(), Some(10));
        assert_eq!(rcv.try_get(), Some(10));
        assert_eq!(snd.try_get(), Some(10));
    }

    /// Values obtained from the watch stay valid after later sends.
    #[test]
    fn once_lock_like() {
        static CONFIG0: u8 = 10;
        static CONFIG1: u8 = 20;
        static WATCH: Watch<CriticalSectionRawMutex, &'static u8> = Watch::new();
        let mut rcv = WATCH.receiver();
        let snd = WATCH.sender();
        assert_eq!(rcv.try_changed(), None);
        snd.send(&CONFIG0);
        let rcv0 = rcv.try_changed().unwrap();
        assert_eq!(rcv0, &10);
        snd.send(&CONFIG1);
        let rcv1 = rcv.try_changed();
        assert_eq!(rcv1, Some(&20));
        assert_eq!(rcv.try_changed(), None);
        assert_eq!(rcv0, &CONFIG0);
        assert_eq!(rcv1, Some(&CONFIG1));
    }

    #[test]
    fn sender_modify() {
        static WATCH: Watch<CriticalSectionRawMutex, u8> = Watch::new();
        let mut rcv = WATCH.receiver();
        let snd = WATCH.sender();
        snd.send(10);
        assert_eq!(rcv.try_changed(), Some(10));
    }

    /// A receiver created after a send still sees that value as a change.
    #[test]
    fn receive_after_create() {
        static WATCH: Watch<CriticalSectionRawMutex, u8> = Watch::new();
        let snd = WATCH.sender();
        snd.send(10);
        let mut rcv = WATCH.receiver();
        assert_eq!(rcv.try_changed(), Some(10));
    }

    /// Receivers track what they have seen independently of each other.
    #[test]
    fn multiple_receivers() {
        static WATCH: Watch<CriticalSectionRawMutex, u8> = Watch::new();
        let mut rcv0 = WATCH.receiver();
        let mut rcv1 = WATCH.receiver();
        let snd = WATCH.sender();
        assert_eq!(rcv0.try_changed(), None);
        assert_eq!(rcv1.try_changed(), None);
        snd.send(0);
        assert_eq!(rcv0.try_changed(), Some(0));
        assert_eq!(rcv0.try_changed(), None);
        assert_eq!(rcv1.try_changed(), Some(0));
        assert_eq!(rcv1.try_changed(), None);
    }

    #[test]
    fn clone_senders() {
        static WATCH: Watch<CriticalSectionRawMutex, u8> = Watch::new();
        let snd0 = WATCH.sender();
        let snd1 = snd0.clone();
        let mut rcv = WATCH.receiver();
        snd0.send(10);
        assert_eq!(rcv.try_changed(), Some(10));
        snd1.send(20);
        assert_eq!(rcv.try_changed(), Some(20));
    }

    #[test]
    fn contains_value() {
        static WATCH: Watch<CriticalSectionRawMutex, u8> = Watch::new();
        let rcv = WATCH.receiver();
        let snd = WATCH.sender();
        assert!(!rcv.contains_value());
        assert!(!snd.contains_value());
        snd.send(10);
        assert!(rcv.contains_value());
        assert!(snd.contains_value());
    }

    /// `clear` empties the watch for every kind of handle.
    #[test]
    fn clear_removes_value() {
        static WATCH: Watch<CriticalSectionRawMutex, u8> = Watch::new();
        let rcv = WATCH.receiver();
        let snd = WATCH.sender();
        snd.send(10);
        assert!(rcv.contains_value());
        snd.clear();
        assert!(!rcv.contains_value());
        assert!(!snd.contains_value());
        assert_eq!(WATCH.try_get(), None);
        assert_eq!(snd.try_get(), None);
    }
}