use core::sync::atomic::{AtomicU64, Ordering};
use ax_task::{WaitQueue, current, might_sleep};
/// Raw mutex state: a wait queue plus an atomic word recording the current owner.
///
/// An `owner_id` of `0` means "unlocked" (that is what `new` stores and what
/// `is_locked` tests for); the locking code stores `current().id().as_u64()`
/// of the holder otherwise.
pub struct RawMutex {
/// Queue that contended lockers wait on and that `unlock` notifies.
wq: WaitQueue,
/// Id of the owning task, or `0` while the mutex is free.
owner_id: AtomicU64,
/// Per-lock lockdep state; only present when the `lockdep` feature is enabled.
#[cfg(feature = "lockdep")]
pub(crate) lockdep: crate::lockdep::LockdepMap,
}
impl RawMutex {
    /// Builds a mutex in the unlocked state: a fresh wait queue and an owner id of `0`.
    ///
    /// This is a `const fn`, so it can be used to initialize statics and constants.
    #[inline(always)]
    pub const fn new() -> Self {
        Self {
            owner_id: AtomicU64::new(0),
            wq: WaitQueue::new(),
            #[cfg(feature = "lockdep")]
            lockdep: crate::lockdep::LockdepMap::new(),
        }
    }

    /// Reports whether the owner id currently recorded in the mutex equals `owner_id`.
    #[inline(always)]
    fn is_owner(&self, owner_id: u64) -> bool {
        let recorded = self.owner_id.load(Ordering::Acquire);
        recorded == owner_id
    }
}
impl Default for RawMutex {
fn default() -> Self {
Self::new()
}
}
/// `lock_api` backend for [`RawMutex`]: a sleeping mutex that records the
/// holder's task id in `owner_id` (`0` encodes "unlocked").
///
/// Uncontended acquisition is a `0 -> current_id` compare-exchange. Contended
/// lockers wait on `wq` until they either observe themselves as owner or see
/// the mutex free; `unlock` writes the next `owner_id` from inside the
/// `notify_one_with` callback.
// NOTE(review): exclusivity on the wake-up path depends on which id
// `WaitQueue::notify_one_with` passes to its callback — confirm against ax_task.
unsafe impl lock_api::RawMutex for RawMutex {
// NOTE(review): `GuardSend` marks guards as sendable, but `unlock` below
// asserts that the releasing task is the recorded owner, so a guard dropped
// on a different task would hit that assertion — confirm this is intended.
type GuardMarker = lock_api::GuardSend;
/// Initial (unlocked) value required by `lock_api`.
// The value contains an atomic, which is what the silenced lint complains
// about; `lock_api::RawMutex` requires `INIT` to be an associated const.
#[allow(clippy::declare_interior_mutable_const)]
const INIT: Self = RawMutex::new();
/// Acquires the mutex, waiting on the wait queue while another task holds it.
///
/// # Panics
///
/// Panics if the compare-exchange reports that the current task is already
/// the recorded owner (recursive acquisition).
#[inline(always)]
fn lock(&self) {
// Presumably checks that the caller is in a context where blocking is
// allowed — semantics live in ax_task.
might_sleep();
let current_id = current().id().as_u64();
// With lockdep enabled, the acquisition is prepared before trying to take
// the lock and finished once it is held (see `finish()` below).
#[cfg(feature = "lockdep")]
let lockdep = crate::lockdep::LockdepAcquire::prepare(self);
loop {
// Fast path: claim a free mutex (0 -> our id). The weak variant may fail
// spuriously, which is fine because we are already in a retry loop.
match self.owner_id.compare_exchange_weak(
0,
current_id,
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(owner_id) => {
// `owner_id` is the value the failed exchange observed.
assert_ne!(
owner_id, current_id,
"Thread({current_id}) tried to acquire mutex it already owns.",
);
// Wait until either ownership has been written with our id or the
// mutex looks free again.
self.wq
.wait_until(|| self.is_owner(current_id) || !self.is_locked());
// Ownership was handed to us directly: no need to retry the exchange.
// Otherwise the mutex merely looked free, so loop and race for it.
if self.is_owner(current_id) {
break;
}
}
}
}
#[cfg(feature = "lockdep")]
lockdep.finish();
}
/// Tries to acquire the mutex with a single compare-exchange; never waits
/// on the wait queue. Returns `true` if the lock was taken.
// NOTE(review): this path does not block, yet it calls `might_sleep()` —
// confirm that rejecting non-sleepable contexts here is intended.
#[inline(always)]
fn try_lock(&self) -> bool {
might_sleep();
let current_id = current().id().as_u64();
#[cfg(feature = "lockdep")]
let lockdep = crate::lockdep::LockdepAcquire::prepare(self);
// Strong compare-exchange: a single attempt must not fail spuriously,
// since there is no retry loop here.
let acquired = self
.owner_id
.compare_exchange(0, current_id, Ordering::Acquire, Ordering::Relaxed)
.is_ok();
// The prepared lockdep acquisition is only finished on success; on
// failure it is simply dropped.
#[cfg(feature = "lockdep")]
if acquired {
lockdep.finish();
}
acquired
}
/// Releases the mutex and notifies one waiter.
///
/// # Safety
///
/// Per the `lock_api::RawMutex` contract, the mutex must be held in the
/// current context. This implementation additionally checks ownership at
/// runtime.
///
/// # Panics
///
/// Panics if the current task's id differs from the recorded owner id.
#[inline(always)]
unsafe fn unlock(&self) {
let owner_id = self.owner_id.load(Ordering::Acquire);
let current_id = current().id().as_u64();
assert_eq!(
owner_id, current_id,
"Thread({current_id}) tried to release mutex it doesn't own",
);
// The callback overwrites `owner_id` with whatever id the wait queue
// passes in; the previous value returned by `swap` is discarded.
// NOTE(review): presumably the id is the woken waiter's (direct hand-off)
// or 0 when nobody is waiting, and the `true` flag requests a reschedule —
// confirm both against `WaitQueue::notify_one_with`.
self.wq.notify_one_with(true, |id: u64| {
self.owner_id.swap(id, Ordering::Release);
});
#[cfg(feature = "lockdep")]
crate::lockdep::release(self);
}
/// Returns `true` while some owner id (non-zero) is recorded.
#[inline(always)]
fn is_locked(&self) -> bool {
self.owner_id.load(Ordering::Acquire) != 0
}
}
/// Mutual-exclusion lock protecting a `T`: `lock_api::Mutex` with [`RawMutex`] as its raw lock.
pub type Mutex<T> = lock_api::Mutex<RawMutex, T>;
/// Guard type for [`Mutex`]: `lock_api::MutexGuard` specialized to [`RawMutex`].
pub type MutexGuard<'a, T> = lock_api::MutexGuard<'a, RawMutex, T>;
#[cfg(all(test, not(target_os = "none")))]
mod tests {
    use std::sync::{Mutex as StdMutex, Once, OnceLock};

    use ax_task as thread;

    use crate::Mutex;

    /// Makes sure the scheduler is initialized only once per test process.
    static INIT: Once = Once::new();
    /// Host mutex used to run the tests of this module one at a time.
    static TEST_LOCK: OnceLock<StdMutex<()>> = OnceLock::new();

    /// Initializes the `ax_task` scheduler on the first call; later calls do nothing.
    fn init_test_scheduler() {
        INIT.call_once(thread::init_scheduler);
    }

    /// Takes the test-serialization lock, creating it on first use.
    ///
    /// Panics if a previous holder panicked and poisoned the lock.
    fn lock_test_context() -> std::sync::MutexGuard<'static, ()> {
        TEST_LOCK
            .get_or_init(|| StdMutex::new(()))
            .lock()
            .expect("test serialization mutex poisoned")
    }

    /// Runs `f` while holding the serialization lock and with the scheduler initialized.
    fn with_test_context<R>(f: impl FnOnce() -> R) -> R {
        let _test_guard = lock_test_context();
        init_test_scheduler();
        f()
    }

    /// Yields to the scheduler roughly one time in three to shake up interleavings.
    fn may_interrupt() {
        if fastrand::u8(0..3) == 0 {
            thread::yield_now();
        }
    }

    /// Stress test: `2 * NUM_TASKS` tasks increment a shared counter under the
    /// mutex (half by 1, half by 2), and the test body polls until the
    /// expected total `NUM_ITERS * NUM_TASKS * 3` is reached.
    #[test]
    fn lots_and_lots() {
        with_test_context(|| {
            const NUM_TASKS: u32 = 10;
            const NUM_ITERS: u32 = 10_000;
            static M: Mutex<u32> = Mutex::new(0);

            fn inc(delta: u32) {
                for _ in 0..NUM_ITERS {
                    let mut val = M.lock();
                    *val += delta;
                    // Possibly yield while still holding the lock, then again after releasing it.
                    may_interrupt();
                    drop(val);
                    may_interrupt();
                }
            }

            for _ in 0..NUM_TASKS {
                thread::spawn(|| inc(1));
                thread::spawn(|| inc(2));
            }
            println!("spawn OK");

            // The spawned tasks are not joined; poll the counter instead.
            loop {
                let val = M.lock();
                if *val == NUM_ITERS * NUM_TASKS * 3 {
                    // `val` is released when it goes out of scope on `break`.
                    break;
                }
                may_interrupt();
                drop(val);
                may_interrupt();
            }

            assert_eq!(*M.lock(), NUM_ITERS * NUM_TASKS * 3);
            println!("Mutex test OK");
        });
    }

    /// Tests that only make sense when lock-dependency tracking is compiled in.
    #[cfg(feature = "lockdep")]
    mod lockdep_tests {
        use core::mem::ManuallyDrop;
        use std::panic::{AssertUnwindSafe, catch_unwind};

        use super::*;

        /// Replaces the current task's held-lock stack with a fresh, empty one.
        fn reset_lockdep_stack() {
            thread::with_current_lockdep_stack(|stack| *stack = thread::HeldLockStack::new());
        }

        /// Runs `f` in the shared test context with a clean lockdep stack before and after.
        fn with_lockdep_test<R>(f: impl FnOnce() -> R) -> R {
            with_test_context(|| {
                reset_lockdep_stack();
                let result = f();
                reset_lockdep_stack();
                result
            })
        }

        /// Asserts that `f` panics, then clears the lockdep stack it may have left behind.
        fn assert_lockdep_failure(f: impl FnOnce()) {
            let result = catch_unwind(AssertUnwindSafe(f));
            assert!(result.is_err());
            reset_lockdep_stack();
        }

        /// Locking a mutex that the current task already holds must panic.
        #[test]
        fn rejects_recursive_acquire() {
            with_lockdep_test(|| {
                let lock = Mutex::new(0usize);
                assert_lockdep_failure(|| {
                    let _guard = lock.lock();
                    let _guard2 = lock.lock();
                });
            });
        }

        /// After A-then-B has been taken once, taking A while holding B must panic.
        #[test]
        fn rejects_order_inversion() {
            with_lockdep_test(|| {
                let lock_a = Mutex::new(0usize);
                let lock_b = Mutex::new(0usize);
                // Take the locks once in A -> B order.
                {
                    let _guard_a = lock_a.lock();
                    let _guard_b = lock_b.lock();
                }
                // `ManuallyDrop` keeps the guard's destructor from ever running, so
                // `lock_b` is never unlocked (presumably to avoid an unlock after the
                // helper below has reset the lockdep stack). No extra `forget` is
                // needed: forgetting a `ManuallyDrop` does nothing.
                let _guard_b = ManuallyDrop::new(lock_b.lock());
                assert_lockdep_failure(|| {
                    let _guard_a = lock_a.lock();
                });
            });
        }

        /// Releasing A while B (acquired after A) is still held must panic.
        #[test]
        fn rejects_out_of_order_unlock() {
            with_lockdep_test(|| {
                let lock_a = Mutex::new(0usize);
                let lock_b = Mutex::new(0usize);
                let guard_a = lock_a.lock();
                // Never dropped; see `rejects_order_inversion` for the rationale.
                let _guard_b = ManuallyDrop::new(lock_b.lock());
                assert_lockdep_failure(|| drop(guard_a));
            });
        }
    }
}