use core::cell::UnsafeCell;
use core::ops::{Deref, DerefMut};
use embassy_sync::blocking_mutex::raw::RawMutex;
use crate::utils::init::{init, Init, UnsafeCellInit};
use crate::utils::sync::blocking::raw::MatterRawMutex;
use super::signal::Signal;
/// Error returned by the non-waiting lock attempts (`IfMutex::try_lock` and
/// `IfMutex::try_lock_if`) when the lock could not be taken immediately:
/// either the mutex is already locked, or the supplied condition rejected
/// the protected data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct TryLockError;
/// An async mutex whose lock operations can additionally be gated on a
/// condition over the protected data (see `lock_if` / `try_lock_if` / `with`).
///
/// `M` selects the raw (blocking) mutex used by the internal `Signal`.
pub struct IfMutex<T: ?Sized, M: RawMutex = MatterRawMutex> {
    /// Lock flag: `true` while the mutex is held, `false` while it is free.
    state: Signal<bool, M>,
    /// The protected data. Must remain the last field because `T` may be
    /// unsized.
    inner: UnsafeCell<T>,
}
// SAFETY (review note): within this file `inner` is only dereferenced either
// from inside a `state.wait` / `state.modify` closure while the lock flag is
// clear, or through a guard that is created right after the flag was set.
// These impls therefore rely on `Signal` serializing those closures through
// `M` — presumably it does; confirm against `Signal`.
unsafe impl<T, M> Send for IfMutex<T, M>
where
    T: ?Sized + Send,
    M: RawMutex + Send,
{
}

// SAFETY (review note): same reasoning as for `Send` above. `T: Send` (rather
// than `T: Sync`) is required because locking from another thread hands out
// `&mut T` there.
unsafe impl<T, M> Sync for IfMutex<T, M>
where
    T: ?Sized + Send,
    M: RawMutex + Sync,
{
}
impl<T, M: RawMutex> IfMutex<T, M> {
    /// Create a new, unlocked mutex that takes ownership of `value`.
    #[inline(always)]
    pub const fn new(value: T) -> Self {
        Self {
            // The mutex starts out free.
            state: Signal::new(false),
            inner: UnsafeCell::new(value),
        }
    }

    /// Build an initializer for a new, unlocked mutex whose protected data is
    /// itself produced by the initializer `value`.
    pub fn init<I>(value: I) -> impl Init<Self>
    where
        I: Init<T>,
    {
        init!(Self {
            state: Signal::new(false),
            inner <- UnsafeCell::init(value),
        })
    }
}
impl<T, M> IfMutex<T, M>
where
    T: ?Sized,
    M: RawMutex,
{
    /// Lock the mutex unconditionally.
    ///
    /// Shorthand for [`Self::lock_if`] with a condition that always holds.
    pub async fn lock(&self) -> IfMutexGuard<'_, T, M> {
        self.lock_if(|_| true).await
    }

    /// Lock the mutex once it is free *and* `f` accepts the protected data.
    ///
    /// `f` is only called while the lock flag is clear. It may be called
    /// several times before the future resolves (presumably each time
    /// `Signal::wait` re-runs its closure), so it should be cheap.
    ///
    /// The condition is accepted as `FnMut`, matching [`Self::try_lock_if`]
    /// and [`Self::with`]; every `Fn` closure still qualifies.
    pub async fn lock_if<F>(&self, mut f: F) -> IfMutexGuard<'_, T, M>
    where
        F: FnMut(&T) -> bool,
    {
        self.state
            .wait(|locked| {
                // `&&` short-circuits, so the data is only inspected while the
                // lock flag is clear.
                //
                // SAFETY: the flag is clear, so no guard (the only other way
                // to reach `inner`) exists. This also assumes `Signal` does not
                // run its closures concurrently — confirm against `Signal`.
                if !*locked && f(unsafe { &*self.inner.get() }) {
                    *locked = true;
                    Some(())
                } else {
                    None
                }
            })
            .await;

        IfMutexGuard { mutex: self }
    }

    /// Wait until the mutex is free and `f` returns `Some`, then return that
    /// value.
    ///
    /// `f` receives mutable access to the protected data and is only called
    /// while the lock flag is clear. Returning `None` means "not yet": the
    /// future keeps waiting and `f` will be called again later. No guard is
    /// handed out; the lock is released again before this method returns.
    pub async fn with<F, R>(&self, mut f: F) -> R
    where
        F: FnMut(&mut T) -> Option<R>,
    {
        let result = self
            .state
            .wait(|locked| {
                if *locked {
                    return None;
                }

                // SAFETY: the flag is clear, so no guard exists. This also
                // assumes `Signal` does not run its closures concurrently —
                // confirm against `Signal`.
                let outcome = f(unsafe { &mut *self.inner.get() });

                if outcome.is_some() {
                    // Mark as locked so that the guard dropped below passes
                    // its `assert!(*locked)` check.
                    *locked = true;
                }

                outcome
            })
            .await;

        // Release right away by dropping a guard: its `Drop` clears the flag
        // and goes through the same release path as a regular unlock.
        drop(IfMutexGuard { mutex: self });

        result
    }

    /// Try to lock the mutex without waiting.
    ///
    /// Shorthand for [`Self::try_lock_if`] with a condition that always holds.
    ///
    /// # Errors
    ///
    /// Returns [`TryLockError`] if the mutex is already locked.
    pub fn try_lock(&self) -> Result<IfMutexGuard<'_, T, M>, TryLockError> {
        self.try_lock_if(|_| true)
    }

    /// Try to lock the mutex without waiting, but only if `f` accepts the
    /// protected data.
    ///
    /// `f` is not called when the mutex is already locked.
    ///
    /// # Errors
    ///
    /// Returns [`TryLockError`] if the mutex is already locked or if `f`
    /// returned `false`.
    pub fn try_lock_if<F>(&self, mut f: F) -> Result<IfMutexGuard<'_, T, M>, TryLockError>
    where
        F: FnMut(&T) -> bool,
    {
        self.state.modify(|locked| {
            // The first tuple element is always `false` here; presumably it
            // tells `Signal::modify` whether to wake waiters, which taking the
            // lock never needs — confirm against `Signal`.
            //
            // SAFETY: `f` is only reached while the flag is clear, so no guard
            // exists. This also assumes `Signal` does not run its closures
            // concurrently — confirm against `Signal`.
            if !*locked && f(unsafe { &*self.inner.get() }) {
                *locked = true;
                (false, Ok(()))
            } else {
                (false, Err(TryLockError))
            }
        })?;

        Ok(IfMutexGuard { mutex: self })
    }

    /// Consume the mutex and return the protected data.
    pub fn into_inner(self) -> T
    where
        T: Sized,
    {
        self.inner.into_inner()
    }

    /// Get mutable access to the protected data without locking.
    ///
    /// No locking is needed: `&mut self` already guarantees exclusive access.
    pub fn get_mut(&mut self) -> &mut T {
        self.inner.get_mut()
    }
}
/// Guard representing a held [`IfMutex`] lock.
///
/// Dereferences to the protected data; dropping the guard releases the lock.
///
/// NOTE(review): the guard only stores `&IfMutex`, so its auto-derived `Sync`
/// follows `IfMutex: Sync`, which requires just `T: Send`. Sharing `&guard`
/// across threads would then expose `&T` for a `T` that is not `Sync` —
/// confirm whether an explicit `T: Sync` requirement is needed here.
pub struct IfMutexGuard<'a, T: ?Sized, M: RawMutex = MatterRawMutex> {
    /// The mutex this guard keeps locked.
    mutex: &'a IfMutex<T, M>,
}
impl<T: ?Sized, M: RawMutex> Drop for IfMutexGuard<'_, T, M> {
    /// Release the lock held by this guard.
    ///
    /// Panics if the lock flag is not set, which would mean a guard existed
    /// without the mutex being locked.
    fn drop(&mut self) {
        let release = |locked: &mut bool| {
            assert!(*locked);
            *locked = false;

            // `true`: presumably asks `Signal::modify` to wake waiters so they
            // can retry taking the lock — confirm against `Signal`.
            (true, ())
        };

        self.mutex.state.modify(release)
    }
}
impl<T: ?Sized, M: RawMutex> Deref for IfMutexGuard<'_, T, M> {
    type Target = T;

    /// Shared access to the data protected by the locked mutex.
    fn deref(&self) -> &Self::Target {
        let data: *const T = self.mutex.inner.get();

        // SAFETY: a guard is only created after the lock flag was set and
        // clears it again on drop, so while `self` is alive no other guard
        // hands out access to `inner`.
        unsafe { &*data }
    }
}
impl<T: ?Sized, M: RawMutex> DerefMut for IfMutexGuard<'_, T, M> {
    /// Exclusive access to the data protected by the locked mutex.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let data: *mut T = self.mutex.inner.get();

        // SAFETY: a guard is only created after the lock flag was set and
        // clears it again on drop, so while `self` is alive no other guard
        // hands out access to `inner`; `&mut self` rules out aliasing through
        // this guard itself.
        unsafe { &mut *data }
    }
}