use core::{
cell::UnsafeCell,
fmt,
ops::{Deref, DerefMut},
sync::atomic::{AtomicBool, Ordering},
};
use super::WaitQueue;
/// A mutual-exclusion lock protecting a value of type `T`.
///
/// [`Mutex::lock`] waits on an internal wait queue until the lock can be
/// taken; [`Mutex::try_lock`] makes a single attempt and never waits. Access to
/// the value goes through the returned [`MutexGuard`].
pub struct Mutex<T: ?Sized> {
    // `true` while the lock is held (set by `acquire_lock`, cleared by
    // `release_lock`).
    lock: AtomicBool,
    // Waiters in `lock()`; one of them is woken on every unlock.
    queue: WaitQueue,
    // The protected value. `UnsafeCell` permits mutation through `&Mutex<T>`.
    // It must stay the last field so that `T` may be unsized.
    val: UnsafeCell<T>,
}
impl<T> Mutex<T> {
pub const fn new(val: T) -> Self {
Self {
lock: AtomicBool::new(false),
queue: WaitQueue::new(),
val: UnsafeCell::new(val),
}
}
}
impl<T: ?Sized> Mutex<T> {
#[track_caller]
pub fn lock(&self) -> MutexGuard<'_, T> {
self.queue.wait_until(|| self.try_lock())
}
pub fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
self.acquire_lock()
.then(|| unsafe { MutexGuard::new(self) })
}
pub fn get_mut(&mut self) -> &mut T {
self.val.get_mut()
}
fn unlock(&self) {
self.release_lock();
self.queue.wake_one();
}
fn acquire_lock(&self) -> bool {
self.lock
.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
.is_ok()
}
fn release_lock(&self) {
self.lock.store(false, Ordering::Release);
}
}
impl<T: ?Sized + fmt::Debug> fmt::Debug for Mutex<T> {
    /// Formats the mutex, showing the protected value when it is available.
    ///
    /// Formatting `self.val` directly would only print the opaque
    /// `UnsafeCell { .. }` and leave `T: Debug` unused. Reading the value
    /// without the lock would race with the current owner. So, like
    /// `std::sync::Mutex`, this makes one non-blocking `try_lock`:
    /// - On success, it prints the value and then releases the lock. Dropping
    ///   the guard also wakes one waiter.
    /// - If the lock is held (including by the caller), it prints `<locked>`
    ///   and never waits.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut d = f.debug_struct("Mutex");
        match self.try_lock() {
            Some(guard) => {
                d.field("data", &&*guard);
            }
            None => {
                d.field("data", &format_args!("<locked>"));
            }
        }
        // The lock flag and wait queue are intentionally not shown.
        d.finish_non_exhaustive()
    }
}
// SAFETY: Moving a `Mutex<T>` to another thread moves the owned `T` with it,
// which is sound when `T: Send`.
// NOTE(review): this impl also asserts that the `WaitQueue` field may be sent
// across threads; confirm against `WaitQueue`'s own auto traits.
unsafe impl<T: ?Sized + Send> Send for Mutex<T> {}
// SAFETY: Through `&Mutex<T>`, any thread can lock and obtain `&mut T`, one
// thread at a time. That effectively hands `T` between threads, so `T: Send`
// is required. `T: Sync` is not required because the lock serializes every
// access to the value.
unsafe impl<T: ?Sized + Send> Sync for Mutex<T> {}
/// A guard granting exclusive access to the value inside a [`Mutex`].
///
/// It is created by `Mutex::lock` / `Mutex::try_lock` and dereferences to
/// `T`. Dropping it releases the lock and wakes one waiter.
#[clippy::has_significant_drop]
#[must_use = "if unused the Mutex will immediately unlock"]
pub struct MutexGuard<'a, T: ?Sized> {
    // The mutex this guard has locked; while the guard lives, its lock flag
    // is set.
    mutex: &'a Mutex<T>,
}
impl<'a, T: ?Sized> MutexGuard<'a, T> {
unsafe fn new(mutex: &'a Mutex<T>) -> MutexGuard<'a, T> {
MutexGuard { mutex }
}
}
impl<T: ?Sized> Deref for MutexGuard<'_, T> {
    type Target = T;

    /// Borrows the protected value for as long as the guard is borrowed.
    fn deref(&self) -> &T {
        let ptr: *mut T = self.mutex.val.get();
        // SAFETY: A live guard means the lock is held, so no other guard can
        // hand out references. The result borrows `self`, so no `&mut T` can
        // be obtained from this guard while it is alive.
        unsafe { &*ptr }
    }
}
impl<T: ?Sized> DerefMut for MutexGuard<'_, T> {
    /// Mutably borrows the protected value for as long as the guard is
    /// mutably borrowed.
    fn deref_mut(&mut self) -> &mut T {
        let ptr: *mut T = self.mutex.val.get();
        // SAFETY: A live guard means the lock is held, and `&mut self` rules
        // out any other borrow made through this guard, so this is the only
        // reference to the value.
        unsafe { &mut *ptr }
    }
}
impl<T: ?Sized> Drop for MutexGuard<'_, T> {
    /// Releases the lock held by this guard and wakes one waiter.
    fn drop(&mut self) {
        let mutex = self.mutex;
        mutex.unlock();
    }
}
impl<T: ?Sized + fmt::Debug> fmt::Debug for MutexGuard<'_, T> {
    /// Formats the protected value exactly as `T`'s own `Debug` does.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Deref coercion: `&MutexGuard<T>` -> `&T`.
        let value: &T = self;
        <T as fmt::Debug>::fmt(value, f)
    }
}
// Guards are deliberately not `Send` (this negative impl needs the nightly
// `negative_impls` feature). NOTE(review): presumably the lock must be
// released on the thread/task that acquired it; confirm the rationale.
impl<T: ?Sized> !Send for MutexGuard<'_, T> {}
// SAFETY: Through `&MutexGuard`, only shared access `&T` is reachable (via
// `Deref`; `DerefMut` needs `&mut`). That is sound exactly when `T: Sync`.
unsafe impl<T: ?Sized + Sync> Sync for MutexGuard<'_, T> {}
impl<'a, T: ?Sized> MutexGuard<'a, T> {
pub fn get_lock(guard: &MutexGuard<'a, T>) -> &'a Mutex<T> {
guard.mutex
}
}
#[cfg(ktest)]
mod test {
    use super::*;
    use crate::prelude::*;

    /// A failed `try_lock` must leave an already-held lock held.
    #[ktest]
    fn try_lock_does_not_unlock() {
        let mutex = Mutex::new(0);
        let is_locked = || mutex.lock.load(Ordering::Relaxed);

        assert!(!is_locked());
        let held = mutex.lock();
        assert!(is_locked());

        // A second attempt must fail and must not clear the flag.
        assert!(mutex.try_lock().is_none());
        assert!(is_locked());

        drop(held);
    }
}