use crate::sync::{futex_wait_fast, NotSend};
use core::cell::UnsafeCell;
use core::fmt;
use core::sync::atomic::{
AtomicU32,
Ordering::{Acquire, Relaxed, Release},
};
use rusl::futex::futex_wake;
/// Raw futex-backed lock with no data attached; `Mutex` pairs it with the
/// protected value.
///
/// The single `futex` word encodes the lock state:
/// - `0`: unlocked.
/// - `1`: locked, and nobody has marked the lock as contended.
/// - `2`: locked and contended; a thread may be blocked on the futex, so the
///   unlocking thread must issue a wake.
struct InnerMutex {
/// Lock state word; see the type-level docs for the meaning of 0 / 1 / 2.
futex: AtomicU32,
}
impl InnerMutex {
/// Creates an unlocked lock (state `0`). `const`, so usable in constant contexts.
#[inline]
const fn new() -> Self {
Self {
futex: AtomicU32::new(0),
}
}
/// Attempts to take the lock without blocking by moving the state `0 -> 1`.
///
/// Returns `true` if the lock was acquired. `Acquire` on success pairs with the
/// `Release` swap in `unlock`, making the previous holder's writes visible; a
/// failed attempt enters no critical section, so its ordering is `Relaxed`.
#[inline]
fn try_lock(&self) -> bool {
self.futex.compare_exchange(0, 1, Acquire, Relaxed).is_ok()
}
/// Acquires the lock, blocking the calling thread until it is available.
///
/// Fast path: a single uncontended `0 -> 1` CAS. Any other observed state is
/// handled out of line by `lock_contended`. There is no owner tracking, so a
/// thread that calls this while already holding the lock never returns.
#[inline]
fn lock(&self) {
if self.futex.compare_exchange(0, 1, Acquire, Relaxed).is_err() {
self.lock_contended();
}
}
/// Slow path of `lock`: spin briefly, then mark the lock contended and block.
#[cold]
fn lock_contended(&self) {
// Spin first in case the holder releases the lock quickly.
let mut state = self.spin();
// If the lock became free while spinning, try to take it as `1`
// (uncontended) so our own unlock does not need to issue a wake.
if state == 0 {
match self.futex.compare_exchange(0, 1, Acquire, Relaxed) {
Ok(_) => return, Err(s) => state = s,
}
}
loop {
// Mark the lock contended (2). The swap is skipped when the state is
// already 2 to avoid a redundant write. If the swap returns 0, the lock
// was free and is now ours, held in state 2. That is conservative,
// because other threads may be waiting and our unlock must then wake one.
if state != 2 && self.futex.swap(2, Acquire) == 0 {
return;
}
// Block on the futex while it holds 2. NOTE(review): presumably this
// returns immediately if the word is no longer 2 (FUTEX_WAIT semantics);
// confirm against `futex_wait_fast`.
futex_wait_fast(&self.futex, 2);
// Woken (or spurious return): spin again, then re-evaluate.
state = self.spin();
}
}
/// Spins while the lock is held uncontended (state `1`), at most 100 times.
///
/// Returns the last observed state. It stops early on `0` (worth trying to
/// acquire) or `2` (already contended, so the caller goes to the futex).
/// Only loads are performed here; the lock word is never written while spinning.
fn spin(&self) -> u32 {
let mut spin = 100;
loop {
let state = self.futex.load(Relaxed);
if state != 1 || spin == 0 {
return state;
}
core::hint::spin_loop();
spin -= 1;
}
}
/// Releases the lock.
///
/// # Safety
///
/// The caller must currently hold the lock. The swap to `0` is unconditional,
/// so calling this otherwise would release another thread's lock and break
/// mutual exclusion.
///
/// `Release` publishes the critical section's writes to the next acquirer. If
/// the previous state was `2`, a waiter may be blocked, so one is woken.
#[inline]
unsafe fn unlock(&self) {
if self.futex.swap(0, Release) == 2 {
self.wake();
}
}
/// Out-of-line slow path of `unlock`: wake a thread blocked on `futex`.
///
/// NOTE(review): the `1` is presumably the maximum number of waiters to wake
/// (FUTEX_WAKE semantics); confirm in `rusl::futex`. The result is discarded.
#[cold]
fn wake(&self) {
let _ = futex_wake(&self.futex, 1);
}
}
/// A mutual-exclusion lock protecting a value of type `T`.
///
/// The value is reachable only through a [`MutexGuard`], obtained from
/// [`Mutex::lock`] or [`Mutex::try_lock`]; the lock is released when the guard
/// is dropped. `lock` hands back the guard directly: there is no poisoning.
pub struct Mutex<T: ?Sized> {
    /// Futex-based lock word guarding `data`.
    inner: InnerMutex,
    /// The protected value; only accessed while `inner` is held.
    data: UnsafeCell<T>,
}

// SAFETY: a `Mutex<T>` owns its `T`, so moving the mutex to another thread
// moves the value with it, which is sound exactly when `T: Send`.
unsafe impl<T> Send for Mutex<T> where T: ?Sized + Send {}

// SAFETY: sharing `&Mutex<T>` lets any thread obtain `&mut T` through a guard,
// but the lock ensures only one thread does so at a time. Passing exclusive
// access between threads this way requires `T: Send`; `T: Sync` is not needed.
unsafe impl<T> Sync for Mutex<T> where T: ?Sized + Send {}
/// Guard granting exclusive access to the data of a locked [`Mutex`].
///
/// Dereferences to `T`, and dropping it unlocks the mutex. The `NotSend` marker
/// presumably makes the guard `!Send`, so it unlocks on the thread that locked.
#[must_use = "if unused the Mutex will immediately unlock"]
#[clippy::has_significant_drop]
pub struct MutexGuard<'a, T: ?Sized + 'a> {
    /// The mutex this guard keeps locked.
    lock: &'a Mutex<T>,
    /// Marker field; see `NotSend`.
    _not_send: NotSend,
}

// SAFETY: through `&MutexGuard<T>` only `&T` is reachable (via `Deref`), so
// sharing the guard between threads is sound when `T: Sync`.
unsafe impl<T> Sync for MutexGuard<'_, T> where T: ?Sized + Sync {}
impl<T> Mutex<T> {
    /// Creates a new, unlocked mutex wrapping `t`.
    ///
    /// Being a `const fn`, it can be used to initialise a `static`.
    #[inline]
    pub const fn new(t: T) -> Mutex<T> {
        let data = UnsafeCell::new(t);
        Mutex {
            inner: InnerMutex::new(),
            data,
        }
    }
}
impl<T: ?Sized> Mutex<T> {
    /// Acquires the mutex, blocking the current thread until it is available.
    ///
    /// Returns a guard that dereferences to the protected value and releases
    /// the lock when dropped. The lock is not reentrant: calling this while the
    /// current thread already holds it never returns.
    pub fn lock(&self) -> MutexGuard<'_, T> {
        self.inner.lock();
        // SAFETY: the lock was acquired just above; responsibility for
        // releasing it passes to the returned guard, which unlocks on drop.
        unsafe { MutexGuard::new(self) }
    }

    /// Attempts to acquire the mutex without blocking.
    ///
    /// Returns `None` if the mutex is currently held by any thread,
    /// including the calling one.
    pub fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
        if self.inner.try_lock() {
            // SAFETY: `try_lock` succeeded, so the lock is held; ownership of
            // it passes to the returned guard.
            Some(unsafe { MutexGuard::new(self) })
        } else {
            None
        }
    }

    /// Consumes the mutex and returns the wrapped value.
    ///
    /// No locking is needed: owning `self` proves no guard can exist.
    #[inline]
    pub fn into_inner(self) -> T
    where
        T: Sized,
    {
        self.data.into_inner()
    }

    /// Returns a mutable reference to the wrapped value without locking.
    ///
    /// `&mut self` already guarantees exclusive access at compile time.
    #[inline]
    pub fn get_mut(&mut self) -> &mut T {
        self.data.get_mut()
    }
}
impl<T: Default> Default for Mutex<T> {
fn default() -> Mutex<T> {
Mutex::new(Default::default())
}
}
impl<T: ?Sized + fmt::Debug> fmt::Debug for Mutex<T> {
    /// Formats as `Mutex { data: .., .. }`.
    ///
    /// Uses `try_lock`, so formatting never blocks. If the mutex is already
    /// held, `<locked>` is printed in place of the data.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Stand-in printed when the data cannot be borrowed.
        struct LockedPlaceholder;

        impl fmt::Debug for LockedPlaceholder {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.write_str("<locked>")
            }
        }

        let mut builder = f.debug_struct("Mutex");
        match self.try_lock() {
            Some(guard) => {
                builder.field("data", &&*guard);
            }
            None => {
                builder.field("data", &LockedPlaceholder);
            }
        }
        builder.finish_non_exhaustive()
    }
}
impl<'mutex, T: ?Sized> MutexGuard<'mutex, T> {
    /// Wraps an already-locked `lock` in a guard.
    ///
    /// # Safety
    ///
    /// The caller must hold `lock`'s inner lock and hand responsibility for
    /// releasing it to the returned guard. The guard exposes `&mut T` while it
    /// lives and unlocks when dropped.
    unsafe fn new(lock: &'mutex Mutex<T>) -> MutexGuard<'mutex, T> {
        let _not_send = NotSend::new();
        MutexGuard { lock, _not_send }
    }
}
impl<T: ?Sized> core::ops::Deref for MutexGuard<'_, T> {
    type Target = T;

    /// Shared access to the protected value.
    fn deref(&self) -> &T {
        let cell = &self.lock.data;
        // SAFETY: the guard exists, so the lock is held and no other thread can
        // form a reference to the data. The returned borrow is tied to `&self`
        // and cannot outlive the guard.
        unsafe { &*cell.get() }
    }
}
impl<T: ?Sized> core::ops::DerefMut for MutexGuard<'_, T> {
    /// Exclusive access to the protected value.
    fn deref_mut(&mut self) -> &mut T {
        let cell = &self.lock.data;
        // SAFETY: the lock is held for the guard's lifetime, and the `&mut self`
        // borrow ensures no other reference obtained through this guard is live.
        unsafe { &mut *cell.get() }
    }
}
impl<T: ?Sized> Drop for MutexGuard<'_, T> {
    /// Releases the lock held by this guard.
    #[inline]
    fn drop(&mut self) {
        let inner = &self.lock.inner;
        // SAFETY: a `MutexGuard` is only constructed while the lock is held,
        // and each guard releases it exactly once, here.
        unsafe { inner.unlock() }
    }
}
impl<T: ?Sized + fmt::Debug> fmt::Debug for MutexGuard<'_, T> {
    /// Formats the guarded value exactly as `T`'s own `Debug` impl would.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value: &T = self;
        fmt::Debug::fmt(value, f)
    }
}
#[cfg(test)]
mod tests {
    use crate::sync::Mutex;
    use core::time::Duration;

    /// Many threads increment a shared counter; every increment must land.
    #[test]
    fn lock_threaded_mutex() {
        let count = std::sync::Arc::new(Mutex::new(0));
        let mut handles = std::vec::Vec::new();
        for _i in 0..15 {
            let count_c = count.clone();
            let handle = std::thread::spawn(move || {
                let mut guard = count_c.lock();
                // Hold the lock briefly so other threads hit the contended path.
                std::thread::sleep(Duration::from_millis(1));
                *guard += 1;
            });
            handles.push(handle);
        }
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(15, *count.lock());
    }

    /// `try_lock` succeeds on a free mutex, fails while another thread holds
    /// it, and succeeds again once that thread has released it.
    #[test]
    fn try_lock_threaded_mutex() {
        let val = std::sync::Arc::new(Mutex::new(0));
        let val_c = val.clone();
        assert_eq!(0, *val_c.try_lock().unwrap());

        let (locked_tx, locked_rx) = std::sync::mpsc::channel::<()>();
        let (release_tx, release_rx) = std::sync::mpsc::channel::<()>();
        let holder = std::thread::spawn(move || {
            let _guard = val_c.lock();
            locked_tx.send(()).unwrap();
            // Keep holding the lock until the main thread has probed it. A
            // disconnect (main thread panicked) also releases it.
            let _ = release_rx.recv();
        });

        // Wait until the holder definitely owns the lock instead of guessing
        // with a sleep, which was flaky on loaded machines.
        locked_rx.recv().unwrap();
        assert!(val.try_lock().is_none());

        release_tx.send(()).unwrap();
        holder.join().unwrap();
        assert!(val.try_lock().is_some());
    }
}