//! Runtime lock-ordering (lockdep) checks for `RawMutex`.
//!
//! Each tracked mutex is lazily assigned a small numeric id, and a global
//! graph stores the transitive closure of every observed "held while
//! acquiring" relation. Acquiring a lock that already reaches a currently
//! held lock is a potential deadlock and panics immediately instead.

use core::{
    cell::UnsafeCell,
    panic::Location,
    sync::atomic::{AtomicBool, AtomicU32, Ordering},
};

use ax_task::HeldLock;

use crate::mutex::RawMutex;
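/// Maximum number of distinct mutex instances the checker can track.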
const MAX_LOCKS: usize = 1024;
const WORDS_PER_ROW: usize = MAX_LOCKS.div_ceil(64);
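/// Per-acquisition state threaded from `prepare_acquire` to `finish_acquire`:
/// the lock's id and the caller location reported in diagnostics.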
type LockdepState = (u32, &'static Location<'static>);
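/// Per-mutex lockdep handle: a lazily assigned lock id, where 0 means
/// "not yet assigned".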
pub(crate) struct LockdepMap {
id: AtomicU32,
}
impl LockdepMap {
pub(crate) const fn new() -> Self {
Self {
id: AtomicU32::new(0),
}
}
pub(crate) fn lock_id(&self) -> Option<u32> {
match self.id.load(Ordering::Acquire) {
0 => None,
id => Some(id),
}
}
}
impl Default for LockdepMap {
fn default() -> Self {
Self::new()
}
}
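/// Directed "acquired-before" graph over lock ids, stored as its transitive
/// closure: bit `b` of row `reachability[a]` is set iff some task has held
/// lock `a` (directly or transitively) while acquiring lock `b`. Each row is
/// `MAX_LOCKS` bits packed into `WORDS_PER_ROW` 64-bit words.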
struct LockGraph {
reachability: [[u64; WORDS_PER_ROW]; MAX_LOCKS],
}
impl LockGraph {
const fn new() -> Self {
Self {
reachability: [[0; WORDS_PER_ROW]; MAX_LOCKS],
}
}
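    /// Returns whether lock `from` is known to be acquired before lock `to`.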
fn reaches(&self, from: u32, to: u32) -> bool {
let Some(row) = self.reachability.get(from as usize) else {
return false;
};
let word = (to as usize) / 64;
let bit = (to as usize) % 64;
row.get(word)
.is_some_and(|entry| (*entry & (1u64 << bit)) != 0)
}
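    /// Records the edge `before -> after` and restores the transitive
    /// closure: everything reachable from `after` (plus `after` itself)
    /// becomes reachable from `before` and from every lock that already
    /// reaches `before`.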
fn add_order(&mut self, before: u32, after: u32, max_id: u32) {
let mut closure = self.reachability[after as usize];
let word = (after as usize) / 64;
let bit = (after as usize) % 64;
closure[word] |= 1u64 << bit;
for row in 1..max_id {
if row == before || self.reaches(row, before) {
for (slot, extra) in self.reachability[row as usize].iter_mut().zip(closure) {
*slot |= extra;
}
}
}
}
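    /// Panics if acquiring `lock_id` with the given held-lock stack would be
    /// recursive or would invert an already recorded lock order.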
fn assert_can_acquire(
&self,
held_locks: &ax_task::HeldLockStack,
lock_id: u32,
addr: usize,
caller: &'static Location<'static>,
) {
assert!(
!held_locks.contains(lock_id),
"lockdep: recursive mutex acquisition detected for id={} addr={:#x} at {} with held \
stack {:?}",
lock_id,
addr,
caller,
held_locks
);
for held in held_locks.iter() {
assert!(
!self.reaches(lock_id, held.id),
"lockdep: lock order inversion detected while acquiring id={} addr={:#x} at {}; \
held lock {:?}; stack {:?}",
lock_id,
addr,
caller,
held,
held_locks
);
}
}
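    /// Adds a `held -> lock_id` edge for every currently held lock, then
    /// pushes the new lock onto the per-task held stack.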
fn record_acquire(
&mut self,
held_locks: &mut ax_task::HeldLockStack,
lock_id: u32,
addr: usize,
caller: &'static Location<'static>,
max_id: u32,
) {
        for held in held_locks.iter() {
            self.add_order(held.id, lock_id, max_id);
        }
held_locks.push(HeldLock {
id: lock_id,
addr,
caller,
});
}
}
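/// Global lock graph guarded by a hand-rolled spinlock; the checker cannot
/// protect its own state with `RawMutex` without recursing into itself.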
struct GraphState {
lock: AtomicBool,
graph: UnsafeCell<LockGraph>,
}
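// SAFETY: the `UnsafeCell` contents are only accessed through `with_graph`,
// which serializes all access via the `lock` flag.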
unsafe impl Sync for GraphState {}
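/// RAII guard for the graph spinlock, using test-and-test-and-set so waiters
/// spin on a plain load instead of hammering the cache line with CAS.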
struct GraphGuard;
impl GraphGuard {
fn acquire() -> Self {
while GRAPH_STATE
.lock
.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
.is_err()
{
while GRAPH_STATE.lock.load(Ordering::Acquire) {
core::hint::spin_loop();
}
}
Self
}
}
impl Drop for GraphGuard {
fn drop(&mut self) {
GRAPH_STATE.lock.store(false, Ordering::Release);
}
}
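/// Next lock id to hand out; starts at 1 because 0 means "unassigned".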
static NEXT_LOCK_ID: AtomicU32 = AtomicU32::new(1);
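/// The single global lock-order graph shared by all tracked mutexes.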
static GRAPH_STATE: GraphState = GraphState {
lock: AtomicBool::new(false),
graph: UnsafeCell::new(LockGraph::new()),
};
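/// Runs `f` with exclusive access to the global lock graph.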
fn with_graph<R>(f: impl FnOnce(&mut LockGraph) -> R) -> R {
let _guard = GraphGuard::acquire();
let graph = unsafe { &mut *GRAPH_STATE.graph.get() };
f(graph)
}
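/// Exclusive upper bound on the ids handed out so far, clamped to `MAX_LOCKS`.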
fn current_max_lock_id() -> u32 {
NEXT_LOCK_ID.load(Ordering::Acquire).min(MAX_LOCKS as u32)
}
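/// Returns the mutex's lock id, assigning one under the graph lock on first
/// use (double-checked so the common path is a single atomic load).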
fn ensure_lock_id(map: &LockdepMap) -> u32 {
let existing = map.id.load(Ordering::Acquire);
if existing != 0 {
return existing;
}
let _guard = GraphGuard::acquire();
let existing = map.id.load(Ordering::Acquire);
if existing != 0 {
return existing;
}
let new_id = NEXT_LOCK_ID.fetch_add(1, Ordering::AcqRel);
assert!(
(new_id as usize) < MAX_LOCKS,
"lockdep: exceeded maximum tracked mutex instances ({MAX_LOCKS})"
);
map.id.store(new_id, Ordering::Release);
new_id
}
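/// First phase of an acquisition: checks the held-lock stack against the
/// graph *before* the caller blocks on the mutex, so an ordering violation
/// panics instead of deadlocking.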
fn prepare_acquire(
map: &LockdepMap,
addr: usize,
caller: &'static Location<'static>,
) -> LockdepState {
let lock_id = ensure_lock_id(map);
with_graph(|graph| {
ax_task::with_current_lockdep_stack(|stack| {
graph.assert_can_acquire(stack, lock_id, addr, caller)
});
});
(lock_id, caller)
}
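/// Second phase of an acquisition: once the mutex is actually held, records
/// the new ordering edges and pushes the lock onto the per-task stack.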
fn finish_acquire(lockdep: LockdepState, addr: usize) {
    let (lock_id, caller) = lockdep;
    with_graph(|graph| {
        // Read the id bound while holding the graph lock: ids (and their
        // edges) added by concurrently created locks must not be skipped
        // when the closure is propagated.
        let max_id = current_max_lock_id();
        ax_task::with_current_lockdep_stack(|stack| {
            graph.record_acquire(stack, lock_id, addr, caller, max_id);
        });
    });
}
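/// Two-phase acquisition guard: [`LockdepAcquire::prepare`] runs the
/// inversion check before the caller blocks on the underlying lock, and
/// [`LockdepAcquire::finish`] records the new ordering edges once the lock
/// is actually held.
///
/// A sketch of the intended call site, assuming `RawMutex` has some internal
/// blocking `acquire_inner` step (the name is illustrative):
///
/// ```ignore
/// let dep = LockdepAcquire::prepare(self); // panic on inversion, not deadlock
/// self.acquire_inner();                    // block on the real mutex
/// dep.finish();                            // record the new ordering edges
/// ```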
pub(crate) struct LockdepAcquire {
addr: usize,
state: LockdepState,
}
impl LockdepAcquire {
#[inline(always)]
#[track_caller]
pub(crate) fn prepare(lock: &RawMutex) -> Self {
        let addr = lock as *const RawMutex as usize;
let state = prepare_acquire(&lock.lockdep, addr, Location::caller());
Self { addr, state }
}
#[inline(always)]
pub(crate) fn finish(self) {
finish_acquire(self.state, self.addr);
}
}
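/// Removes the lock from the current task's held stack on unlock; a mutex
/// that never went through `prepare`/`finish` has no id and is ignored.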
#[inline(always)]
pub(crate) fn release(lock: &RawMutex) {
let Some(lock_id) = lock.lockdep.lock_id() else {
return;
};
ax_task::with_current_lockdep_stack(|stack| stack.pop_checked(lock_id));
}