use alloc::{format, string::String};
use core::{cell::RefCell, panic::Location};
use std::{
fs::File,
path::{Path, PathBuf},
};
use super::rank::{LockRank, LockRankSet};
use crate::FastHashSet;
/// Rank-tracking data returned by [`RwLockReadGuard::forget`] and consumed by
/// [`RwLock::force_unlock_read`]: the lock this thread had recorded as held
/// before the forgotten acquisition, or `None` if nothing was recorded (no
/// prior lock, or observation not enabled).
pub type RankData = Option<HeldLock>;
/// A `parking_lot::Mutex` wrapper that tags the lock with a [`LockRank`] so
/// that acquisitions can be recorded for lock-order observation.
pub struct Mutex<T> {
    inner: parking_lot::Mutex<T>,
    // Rank passed to `acquire` every time this mutex is locked.
    rank: LockRank,
}

/// Guard returned by [`Mutex::lock`].
///
/// Fields drop in declaration order, so the underlying mutex is unlocked
/// before `_state` restores this thread's previously-held lock record.
pub struct MutexGuard<'a, T> {
    inner: parking_lot::MutexGuard<'a, T>,
    // Restores the thread's rank-tracking state when dropped.
    _state: LockStateGuard,
}
impl<T> Mutex<T> {
    /// Wraps `value` in a new mutex tagged with `rank`.
    pub fn new(rank: LockRank, value: T) -> Mutex<T> {
        let inner = parking_lot::Mutex::new(value);
        Mutex { inner, rank }
    }

    /// Locks the mutex, blocking until it becomes available.
    ///
    /// This thread's rank-tracking state is updated with the caller's
    /// location before blocking on the underlying lock.
    #[track_caller]
    pub fn lock(&self) -> MutexGuard<'_, T> {
        let previous = acquire(self.rank, Location::caller());
        let inner = self.inner.lock();
        MutexGuard {
            inner,
            _state: LockStateGuard { saved: previous },
        }
    }

    /// Consumes the mutex and returns the protected value.
    pub fn into_inner(self) -> T {
        let Mutex { inner, .. } = self;
        inner.into_inner()
    }
}
impl<'a, T> core::ops::Deref for MutexGuard<'a, T> {
    type Target = T;

    /// Shared access to the data protected by the held mutex.
    fn deref(&self) -> &T {
        // Deref coercion through the `parking_lot` guard.
        &self.inner
    }
}

impl<'a, T> core::ops::DerefMut for MutexGuard<'a, T> {
    /// Exclusive access to the data protected by the held mutex.
    fn deref_mut(&mut self) -> &mut T {
        &mut self.inner
    }
}
impl<T: core::fmt::Debug> core::fmt::Debug for Mutex<T> {
    /// Formats exactly like the wrapped `parking_lot::Mutex`; the rank is not
    /// shown.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        core::fmt::Debug::fmt(&self.inner, f)
    }
}
/// A `parking_lot::RwLock` wrapper that tags the lock with a [`LockRank`] so
/// that read and write acquisitions can be recorded for lock-order
/// observation.
pub struct RwLock<T> {
    inner: parking_lot::RwLock<T>,
    // Rank passed to `acquire` on every `read` and `write`.
    rank: LockRank,
}

/// Guard returned by [`RwLock::read`] and [`RwLockWriteGuard::downgrade`].
///
/// Fields drop in declaration order, so the read lock is released before
/// `_state` restores this thread's previously-held lock record.
pub struct RwLockReadGuard<'a, T> {
    inner: parking_lot::RwLockReadGuard<'a, T>,
    // Restores the thread's rank-tracking state when dropped.
    _state: LockStateGuard,
}

/// Guard returned by [`RwLock::write`].
///
/// Fields drop in declaration order, so the write lock is released before
/// `_state` restores this thread's previously-held lock record.
pub struct RwLockWriteGuard<'a, T> {
    inner: parking_lot::RwLockWriteGuard<'a, T>,
    // Restores the thread's rank-tracking state when dropped.
    _state: LockStateGuard,
}
impl<T> RwLock<T> {
    /// Wraps `value` in a new reader-writer lock tagged with `rank`.
    pub fn new(rank: LockRank, value: T) -> RwLock<T> {
        RwLock {
            inner: parking_lot::RwLock::new(value),
            rank,
        }
    }

    /// Acquires shared read access, blocking until it is available.
    ///
    /// This thread's rank-tracking state is updated with the caller's
    /// location before blocking on the underlying lock.
    #[track_caller]
    pub fn read(&self) -> RwLockReadGuard<'_, T> {
        let saved = acquire(self.rank, Location::caller());
        RwLockReadGuard {
            inner: self.inner.read(),
            _state: LockStateGuard { saved },
        }
    }

    /// Acquires exclusive write access, blocking until it is available.
    ///
    /// This thread's rank-tracking state is updated with the caller's
    /// location before blocking on the underlying lock.
    #[track_caller]
    pub fn write(&self) -> RwLockWriteGuard<'_, T> {
        let saved = acquire(self.rank, Location::caller());
        RwLockWriteGuard {
            inner: self.inner.write(),
            _state: LockStateGuard { saved },
        }
    }

    /// Releases a read lock whose guard was discarded with
    /// [`RwLockReadGuard::forget`], first setting this thread's held-lock
    /// record to `data`.
    ///
    /// # Safety
    ///
    /// The current thread must logically own a read lock on `self` whose
    /// guard was discarded with [`RwLockReadGuard::forget`] and that has not
    /// already been released. Read-unlocking a lock that is not read-locked
    /// is undefined behavior.
    ///
    /// `data` should be the [`RankData`] returned by that `forget` call so
    /// that the rank-tracking state is restored correctly.
    pub unsafe fn force_unlock_read(&self, data: RankData) {
        release(data);
        // SAFETY: the caller guarantees (see `# Safety`) that this thread
        // logically owns a read lock on `self.inner` whose guard was
        // forgotten. That is exactly the precondition of
        // `parking_lot::RwLock::force_unlock_read`.
        unsafe { self.inner.force_unlock_read() };
    }
}
impl<'a, T> RwLockReadGuard<'a, T> {
    /// Discards the guard without releasing the underlying read lock, and
    /// returns the rank data to pass later to [`RwLock::force_unlock_read`].
    ///
    /// Only the `parking_lot` guard is forgotten. `_state` is still dropped
    /// when this function returns, so this thread's held-lock record is
    /// restored even though the lock itself stays read-locked.
    pub fn forget(this: Self) -> RankData {
        // Leak just the parking_lot guard so the read lock is not released.
        core::mem::forget(this.inner);
        // `Option<HeldLock>` is `Copy`, so this copies the saved value out;
        // the remaining `this._state` is dropped (running `release`) at the
        // end of the function.
        this._state.saved
    }
}
impl<'a, T> RwLockWriteGuard<'a, T> {
pub fn downgrade(this: Self) -> RwLockReadGuard<'a, T> {
RwLockReadGuard {
inner: parking_lot::RwLockWriteGuard::downgrade(this.inner),
_state: this._state,
}
}
}
impl<T: core::fmt::Debug> core::fmt::Debug for RwLock<T> {
    /// Formats exactly like the wrapped `parking_lot::RwLock`; the rank is
    /// not shown.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        core::fmt::Debug::fmt(&self.inner, f)
    }
}
impl<'a, T> core::ops::Deref for RwLockReadGuard<'a, T> {
    type Target = T;

    /// Shared access to the data protected by the read lock.
    fn deref(&self) -> &T {
        // Deref coercion through the `parking_lot` guard.
        &self.inner
    }
}

impl<'a, T> core::ops::Deref for RwLockWriteGuard<'a, T> {
    type Target = T;

    /// Shared access to the data protected by the write lock.
    fn deref(&self) -> &T {
        &self.inner
    }
}

impl<'a, T> core::ops::DerefMut for RwLockWriteGuard<'a, T> {
    /// Exclusive access to the data protected by the write lock.
    fn deref_mut(&mut self) -> &mut T {
        &mut self.inner
    }
}
/// Rank-tracking state held by every guard in this module; its `Drop` impl
/// passes `saved` back to `release`.
struct LockStateGuard {
    // The value `acquire` returned: the lock recorded as held just before
    // this acquisition, if any.
    saved: Option<HeldLock>,
}
impl Drop for LockStateGuard {
    /// Puts this thread's held-lock record back to what it was before the
    /// guarded acquisition.
    fn drop(&mut self) {
        let previous = self.saved;
        release(previous);
    }
}
/// Records that the current thread is acquiring a lock of rank `new_rank` at
/// `location`. Returns the lock previously recorded as held, so the new
/// guard can hand it back to [`release`] when dropped.
///
/// On the thread's first call, observation is turned on if the
/// `WGPU_CORE_LOCK_OBSERVE_DIR` environment variable is set. A log file is
/// then created in that directory and every lock rank is written to it.
/// Otherwise observation is disabled for this thread and every call returns
/// `None`.
///
/// While observation is enabled, an acquisition made while another lock is
/// recorded as held is written to the log as an (older, newer) pair.
///
/// # Panics
///
/// Panics if the observation log file cannot be created, or if writing a
/// log entry fails.
fn acquire(new_rank: LockRank, location: &'static Location<'static>) -> Option<HeldLock> {
    LOCK_STATE.with_borrow_mut(|state| match *state {
        ThreadState::Disabled => None,
        ThreadState::Initial => {
            // Use `var_os` rather than `var`: `var` fails on values that are
            // not valid Unicode. That would silently disable observation for
            // a perfectly usable (non-UTF-8) directory path.
            let Some(dir) = std::env::var_os("WGPU_CORE_LOCK_OBSERVE_DIR") else {
                *state = ThreadState::Disabled;
                return None;
            };
            let mut log = ObservationLog::create(dir)
                .expect("Failed to open lock observation file (does the dir exist?)");
            for rank in LockRankSet::all().iter() {
                log.write_rank(rank);
            }
            // This is the first lock recorded on this thread. Nothing was
            // held before it, so there is no acquisition pair to log.
            *state = ThreadState::Enabled {
                held_lock: Some(HeldLock {
                    rank: new_rank,
                    location,
                }),
                log,
            };
            None
        }
        ThreadState::Enabled {
            ref mut held_lock,
            ref mut log,
        } => {
            if let Some(older) = held_lock.as_ref() {
                log.write_acquisition(older, new_rank, location);
            }
            held_lock.replace(HeldLock {
                rank: new_rank,
                location,
            })
        }
    })
}
/// Sets the current thread's held-lock record back to `saved`.
///
/// Does nothing unless lock observation is enabled on this thread.
fn release(saved: Option<HeldLock>) {
    LOCK_STATE.with_borrow_mut(|state| match state {
        ThreadState::Enabled { held_lock, .. } => *held_lock = saved,
        ThreadState::Initial | ThreadState::Disabled => {}
    });
}
// Per-thread lock observation state. Every thread starts in `Initial`; the
// first lock acquisition on the thread (in `acquire`) moves it to `Disabled`
// or `Enabled`.
std::thread_local! {
    static LOCK_STATE: RefCell<ThreadState> = const { RefCell::new(ThreadState::Initial) };
}

/// Lock observation state for a single thread.
enum ThreadState {
    /// No lock has been acquired on this thread yet, so the environment has
    /// not been checked.
    Initial,
    /// Observation is off for this thread: the observe-directory environment
    /// variable was not usable at the first acquisition.
    Disabled,
    /// Observation is on for this thread.
    Enabled {
        /// Most recent acquisition currently recorded as held, if any.
        held_lock: Option<HeldLock>,
        /// This thread's observation log.
        log: ObservationLog,
    },
}
/// Record of one lock acquisition: the lock's rank and the source location
/// of the `lock`/`read`/`write` call that acquired it.
#[derive(Debug, Copy, Clone)]
pub struct HeldLock {
    rank: LockRank,
    location: &'static Location<'static>,
}
/// A thread's lock observation log file, plus bookkeeping for writing it.
struct ObservationLog {
    log_file: File,
    // Locations already written as `Action::Location` entries, so each one
    // is described only once.
    locations_seen: FastHashSet<*const Location<'static>>,
    // Scratch buffer, cleared and reused to serialize each entry.
    buffer: String,
}
impl ObservationLog {
    /// Creates (truncating if present) `locks-<pid>.<thread id>.ron` inside
    /// `dir`.
    ///
    /// # Errors
    ///
    /// Returns the I/O error if the file cannot be created.
    fn create(dir: impl AsRef<Path>) -> Result<Self, std::io::Error> {
        let file_name = format!(
            "locks-{}.{:?}.ron",
            std::process::id(),
            std::thread::current().id()
        );
        let path: PathBuf = dir.as_ref().join(file_name);
        Ok(ObservationLog {
            log_file: File::create(&path)?,
            locations_seen: FastHashSet::default(),
            buffer: String::new(),
        })
    }

    /// Logs that a lock of rank `new_rank` was acquired at `new_location`
    /// while `older_lock` was held. Either location is described first if it
    /// has not appeared in the log before.
    fn write_acquisition(
        &mut self,
        older_lock: &HeldLock,
        new_rank: LockRank,
        new_location: &'static Location<'static>,
    ) {
        let older_location = older_lock.location;
        self.write_location(older_location);
        self.write_location(new_location);

        let acquisition = Action::Acquisition {
            older_rank: older_lock.rank.bit.number(),
            older_location: addr(older_location),
            newer_rank: new_rank.bit.number(),
            newer_location: addr(new_location),
        };
        self.write_action(&acquisition);
    }

    /// Writes an `Action::Location` entry for `location`, unless one has
    /// already been written to this log.
    fn write_location(&mut self, location: &'static Location<'static>) {
        let first_sighting = self.locations_seen.insert(location);
        if !first_sighting {
            return;
        }

        let described = Action::Location {
            address: addr(location),
            file: location.file(),
            line: location.line(),
            column: location.column(),
        };
        self.write_action(&described);
    }

    /// Writes an `Action::Rank` entry describing `rank`.
    fn write_rank(&mut self, rank: LockRankSet) {
        let described = Action::Rank {
            bit: rank.number(),
            member_name: rank.member_name(),
            const_name: rank.const_name(),
        };
        self.write_action(&described);
    }

    /// Serializes `action` as RON, appends a newline, and writes the result
    /// to the log file.
    ///
    /// # Panics
    ///
    /// Panics if serialization or the file write fails.
    fn write_action(&mut self, action: &Action) {
        use std::io::Write;

        let buffer = &mut self.buffer;
        buffer.clear();
        ron::ser::to_writer(&mut *buffer, action)
            .expect("error serializing `lock::observing::Action`");
        buffer.push('\n');

        self.log_file
            .write_all(buffer.as_bytes())
            .expect("error writing `lock::observing::Action`");
    }
}
/// One entry in the lock observation log. Each entry is serialized with
/// `ron` and written followed by a newline.
///
/// Variant and field names are emitted by the derived `Serialize` impl, so
/// they are part of the log format.
#[derive(serde::Serialize)]
enum Action {
    /// Describes a source location, identified by its `address`. Written the
    /// first time that location appears in the log.
    Location {
        address: usize,
        file: &'static str,
        line: u32,
        column: u32,
    },
    /// Describes one lock rank: its bit number and names.
    Rank {
        bit: u32,
        member_name: &'static str,
        const_name: &'static str,
    },
    /// A lock of rank `newer_rank` was acquired at `newer_location` while a
    /// lock of rank `older_rank`, acquired at `older_location`, was recorded
    /// as held. The locations are addresses matching `Location` entries.
    Acquisition {
        older_rank: u32,
        older_location: usize,
        newer_rank: u32,
        newer_location: usize,
    },
}
impl LockRankSet {
    /// Returns the index of the lowest set bit in this set. For a set holding
    /// a single rank (as used in this module), that is the rank's bit number.
    fn number(self) -> u32 {
        let raw_bits = self.bits();
        raw_bits.trailing_zeros()
    }
}
/// Returns the address that `t` refers to, as an integer. The log uses it as
/// an identifier for `Location`s.
fn addr<T>(t: &T) -> usize {
    let ptr: *const T = t;
    ptr as usize
}