use core::cell::{RefCell, UnsafeCell};
use core::future::poll_fn;
use core::ops::{Deref, DerefMut};
use core::task::Poll;
use crate::blocking_mutex::raw::RawMutex;
use crate::blocking_mutex::Mutex as BlockingMutex;
use crate::waitqueue::WakerRegistration;
/// Error returned by [`Mutex::try_lock`] when the mutex is already locked.
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct TryLockError;
/// Lock bookkeeping for [`Mutex`]; it is only read or written inside the blocking mutex's
/// critical section (`state.lock(..)`).
struct State {
    /// `true` from the moment a [`MutexGuard`] is handed out until that guard is dropped.
    locked: bool,
    /// Waker registered by a `lock()` future that found the mutex held; woken on unlock.
    waker: WakerRegistration,
}
/// Async mutex protecting a value of type `T`.
///
/// [`Mutex::lock`] returns a future that stays pending while another [`MutexGuard`] is alive,
/// instead of blocking. The raw mutex type `M` only guards the short critical sections that
/// read and update the lock state; it is not held while the caller uses the protected value.
pub struct Mutex<M, T>
where
    M: RawMutex,
    T: ?Sized,
{
    /// Lock flag and waiter registration, accessed only inside `M`'s critical section.
    state: BlockingMutex<M, RefCell<State>>,
    /// The protected value; reached through a [`MutexGuard`] or via exclusive ownership
    /// (`into_inner` / `get_mut`).
    inner: UnsafeCell<T>,
}
// SAFETY: moving the `Mutex` to another thread moves the `T` it owns (requires `T: Send`)
// together with the raw mutex `M` (requires `M: Send`).
unsafe impl<M: RawMutex + Send, T: ?Sized + Send> Send for Mutex<M, T> {}
// SAFETY: sharing `&Mutex` lets any thread lock it and obtain `&mut T` through a guard, one
// holder at a time, so `T: Send` is sufficient (the same bound `std::sync::Mutex` uses). The
// raw mutex `M` is entered concurrently from those threads, so it must be `Sync`.
unsafe impl<M: RawMutex + Sync, T: ?Sized + Send> Sync for Mutex<M, T> {}
impl<M, T> Mutex<M, T>
where
    M: RawMutex,
{
    /// Creates a new, unlocked mutex wrapping `value`.
    ///
    /// Because this is a `const fn`, the mutex can be constructed in a `static` initializer.
    pub const fn new(value: T) -> Self {
        let initial = State {
            locked: false,
            waker: WakerRegistration::new(),
        };
        Self {
            state: BlockingMutex::new(RefCell::new(initial)),
            inner: UnsafeCell::new(value),
        }
    }
}
impl<M, T> Mutex<M, T>
where
    M: RawMutex,
    T: ?Sized,
{
    /// Locks the mutex, waiting asynchronously until it becomes available.
    ///
    /// Every poll inspects the `locked` flag inside the blocking critical section. If the mutex
    /// is free it is marked locked and a [`MutexGuard`] is returned. Otherwise the current
    /// task's waker is registered so that dropping the current guard wakes it, and the future
    /// stays pending.
    ///
    /// NOTE(review): `State` keeps a single `WakerRegistration`, so how several concurrent
    /// waiters are handled depends on `WakerRegistration::register` — confirm before relying
    /// on any fairness property.
    pub async fn lock(&self) -> MutexGuard<'_, M, T> {
        poll_fn(|cx| {
            let acquired = self.state.lock(|cell| {
                let mut state = cell.borrow_mut();
                if !state.locked {
                    state.locked = true;
                    return true;
                }
                // Still held elsewhere: remember who to wake when the guard is dropped.
                state.waker.register(cx.waker());
                false
            });
            if !acquired {
                return Poll::Pending;
            }
            Poll::Ready(MutexGuard { mutex: self })
        })
        .await
    }

    /// Attempts to lock the mutex without waiting.
    ///
    /// # Errors
    ///
    /// Returns [`TryLockError`] if the mutex is currently locked. No waker is registered in
    /// that case.
    pub fn try_lock(&self) -> Result<MutexGuard<'_, M, T>, TryLockError> {
        let acquired = self.state.lock(|cell| {
            let mut state = cell.borrow_mut();
            let was_free = !state.locked;
            if was_free {
                state.locked = true;
            }
            was_free
        });
        acquired
            .then(|| MutexGuard { mutex: self })
            .ok_or(TryLockError)
    }

    /// Consumes the mutex and returns the protected value.
    ///
    /// Taking `self` by value means no guard (which borrows the mutex) can be outstanding,
    /// so no locking is needed.
    pub fn into_inner(self) -> T
    where
        T: Sized,
    {
        self.inner.into_inner()
    }

    /// Returns a mutable reference to the protected value.
    ///
    /// The exclusive `&mut self` borrow rules out any live guard, so the lock is bypassed.
    pub fn get_mut(&mut self) -> &mut T {
        self.inner.get_mut()
    }
}
/// Guard returned by [`Mutex::lock`] and [`Mutex::try_lock`].
///
/// While this guard exists the mutex is marked locked. The guard dereferences to the
/// protected value and unlocks the mutex when dropped.
#[must_use = "if unused the Mutex will immediately unlock"]
pub struct MutexGuard<'a, M, T>
where
    M: RawMutex,
    T: ?Sized,
{
    mutex: &'a Mutex<M, T>,
}

// SAFETY: a shared `&MutexGuard` yields `&T` (via `Deref`) on every thread that holds it, so
// sharing the guard across threads is only sound when `T: Sync`. Without this explicit impl,
// the auto-derived `Sync` would follow `&Mutex<M, T>: Sync`, which only requires `T: Send`.
// That would let a `Send + !Sync` payload such as `Cell<_>` be accessed concurrently through
// `&guard`. The remaining bounds mirror the auto impl (`M: Sync`, `T: Send`), so this only
// withdraws `Sync` for non-`Sync` payloads. `Send` for the guard is still auto-derived.
unsafe impl<'a, M, T> Sync for MutexGuard<'a, M, T>
where
    M: RawMutex + Sync,
    T: ?Sized + Send + Sync,
{
}
impl<'a, M, T> Drop for MutexGuard<'a, M, T>
where
    M: RawMutex,
    T: ?Sized,
{
    /// Unlocks the mutex and wakes the waker stored by a pending `lock()` call.
    fn drop(&mut self) {
        let mutex = self.mutex;
        mutex.state.lock(|cell| {
            // NOTE(review): the failure mode when the state is already borrowed is defined
            // by the crate's `unwrap!` macro.
            let mut state = unwrap!(cell.try_borrow_mut());
            // Mark the mutex free, then let the waiting task re-poll and try again.
            state.locked = false;
            state.waker.wake();
        });
    }
}
impl<'a, M, T> Deref for MutexGuard<'a, M, T>
where
    M: RawMutex,
    T: ?Sized,
{
    type Target = T;

    /// Borrows the protected value.
    fn deref(&self) -> &Self::Target {
        let cell = &self.mutex.inner;
        // SAFETY: a guard is only created after `locked` flips from false to true, and the
        // flag is cleared only when that guard drops, so this is the sole guard. Any `&mut T`
        // must come from `DerefMut` on this same guard, which the borrow checker excludes
        // while this `&self` borrow is live.
        unsafe { &*cell.get() }
    }
}
impl<'a, M, T> DerefMut for MutexGuard<'a, M, T>
where
    M: RawMutex,
    T: ?Sized,
{
    /// Mutably borrows the protected value.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let raw = self.mutex.inner.get();
        // SAFETY: this guard is the only one for the mutex (the `locked` flag is set for its
        // whole lifetime), and `&mut self` prevents any other borrow through it, so the
        // resulting `&mut T` is unique.
        unsafe { &mut *raw }
    }
}