use crate::sync::ThreadLocal;
use crate::sync::{Mutex, MutexGuard};
use std::backtrace::Backtrace;
use std::fmt::{Debug, Display};
use std::ops::{Deref, DerefMut};
use std::panic::Location;
use std::sync::Arc;
/// A mutex that takes part in lock-order checking.
///
/// Every mutex carries a numeric rank (`kind`). [`LoroMutex::lock`] panics
/// when the rank recorded for the calling thread is greater than or equal to
/// the rank of the mutex being locked.
#[derive(Debug)]
pub struct LoroMutex<T> {
    /// The underlying mutex that protects the value.
    lock: Mutex<T>,
    /// Rank of this mutex, stored as the `u8` value of a [`LockKind`].
    kind: u8,
    /// Record of the most recently acquired lock; the same `Arc` is shared by
    /// every mutex created from one [`LoroLockGroup`].
    currently_locked_in_this_thread: Arc<ThreadLocal<Mutex<LockInfo>>>,
}
/// Snapshot describing one tracked lock acquisition.
///
/// The `Default` value (`kind == 0`, no location) is what the tracking slot
/// holds before any lock has been recorded.
#[derive(Clone, Copy, Debug, Default)]
struct LockInfo {
    /// Rank of the lock (the `u8` value of a [`LockKind`]).
    kind: u8,
    /// Source location of the `lock()` call, when one was recorded.
    caller_location: Option<&'static Location<'static>>,
}
impl Display for LockInfo {
    /// Renders the rank together with `file:line:column` of the recorded
    /// call site, or `None` when no call site was recorded.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Handle the "no location" case up front so the main path stays flat.
        let Some(site) = self.caller_location else {
            return write!(f, "LockInfo(kind: {}, location: None)", self.kind);
        };
        write!(
            f,
            "LockInfo(kind: {}, location: {}:{}:{})",
            self.kind,
            site.file(),
            site.line(),
            site.column()
        )
    }
}
/// A set of mutexes whose acquisition order is checked together.
///
/// All mutexes created through [`LoroLockGroup::new_lock`] share the tracking
/// state held here.
#[derive(Debug)]
pub struct LoroLockGroup {
    /// Shared slot recording the most recently acquired lock.
    g: Arc<ThreadLocal<Mutex<LockInfo>>>,
}
/// Rank of a lock inside a [`LoroLockGroup`].
///
/// The numeric value is the ordering key: a lock may only be acquired while
/// the rank recorded for the thread is strictly lower than its own.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LockKind {
    /// Rank 0 — equal to the default recorded rank, i.e. "nothing held".
    None = 0,
    /// Rank 1 — the lowest rank that can actually be acquired.
    Txn = 1,
    /// Rank 2.
    OpLog = 2,
    /// Rank 3.
    DocState = 3,
    /// Rank 4 — the highest rank.
    DiffCalculator = 4,
}
impl LoroLockGroup {
    /// Creates a group with a fresh tracking slot.
    pub fn new() -> Self {
        Self {
            g: Arc::new(ThreadLocal::new()),
        }
    }

    /// Wraps `value` in a [`LoroMutex`] of rank `kind` that shares this
    /// group's tracking state.
    ///
    /// Note: `lock()` only succeeds when the mutex's rank is strictly greater
    /// than the recorded rank, and the default recorded rank is 0, so a mutex
    /// created with `LockKind::None` can never be locked.
    pub fn new_lock<T>(&self, value: T, kind: LockKind) -> LoroMutex<T> {
        let rank = kind as u8;
        let shared_state = Arc::clone(&self.g);
        LoroMutex {
            lock: Mutex::new(value),
            kind: rank,
            currently_locked_in_this_thread: shared_state,
        }
    }
}
impl Default for LoroLockGroup {
fn default() -> Self {
Self::new()
}
}
impl<T> LoroMutex<T> {
    /// Acquires the mutex after checking the lock order.
    ///
    /// The call site is captured through `#[track_caller]` so that order
    /// violations can report where each lock was taken.
    ///
    /// # Panics
    ///
    /// Panics with "Locking order violation" when the rank currently recorded
    /// in the tracking slot is greater than or equal to this mutex's rank.
    #[track_caller]
    pub fn lock(&self) -> LoroMutexGuard<'_, T> {
        let requested = LockInfo {
            kind: self.kind,
            caller_location: Some(Location::caller()),
        };
        let slot = self.currently_locked_in_this_thread.get_or_default();
        // Copy the record out so the slot's own lock is released immediately.
        let previous = *slot.lock();
        if previous.kind >= requested.kind {
            panic!(
                "Locking order violation. Current lock: {}, New lock: {}",
                previous, requested
            );
        }
        // Only record the new lock once the underlying mutex is really held.
        let guard = self.lock.lock_with_kind("LoroMutex");
        *slot.lock() = requested;
        LoroMutexGuard {
            guard,
            _inner: LoroMutexGuardInner {
                inner: self,
                this: requested,
                last: previous,
            },
        }
    }

    /// Reports whether the underlying mutex is currently locked.
    pub fn is_locked(&self) -> bool {
        self.lock.is_locked()
    }

    /// Returns `true` when `lock()` would pass the order check right now,
    /// i.e. the recorded rank is strictly lower than this mutex's rank.
    pub(crate) fn can_lock_in_this_thread(&self) -> bool {
        let held = *self.currently_locked_in_this_thread.get_or_default().lock();
        self.kind > held.kind
    }
}
/// Guard returned by [`LoroMutex::lock`]; dereferences to the protected value.
///
/// Fields are dropped in declaration order, so the underlying mutex guard is
/// released first and the order-tracking record is rolled back afterwards.
pub struct LoroMutexGuard<'a, T> {
    /// Guard of the underlying mutex.
    guard: MutexGuard<'a, T>,
    /// Rolls back the tracking record when dropped.
    _inner: LoroMutexGuardInner<'a, T>,
}
/// Order-tracking half of a [`LoroMutexGuard`].
///
/// Its `Drop` implementation verifies the release order and restores the
/// record that was current before this lock was taken.
struct LoroMutexGuardInner<'a, T> {
    /// The mutex this guard belongs to (gives access to the tracking slot).
    inner: &'a LoroMutex<T>,
    /// Record written to the tracking slot when this lock was acquired.
    this: LockInfo,
    /// Record that was in the tracking slot just before acquisition.
    last: LockInfo,
}
impl<T> Deref for LoroMutexGuard<'_, T> {
    type Target = T;

    /// Borrows the protected value through the underlying guard.
    fn deref(&self) -> &T {
        &*self.guard
    }
}
impl<T> DerefMut for LoroMutexGuard<'_, T> {
    /// Mutably borrows the protected value through the underlying guard.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut *self.guard
    }
}
impl<T> Debug for LoroMutexGuard<'_, T>
where
    T: Debug,
{
    /// Prints the guard as a struct with a single `data` field holding the
    /// underlying guard's debug representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // NOTE(review): the label is "LoroMutex" although this type is the
        // guard; kept as-is so the debug output does not change.
        let mut out = f.debug_struct("LoroMutex");
        out.field("data", &self.guard);
        out.finish()
    }
}
impl<'a, T> LoroMutexGuard<'a, T> {
pub fn take_guard(self) -> MutexGuard<'a, T> {
self.guard
}
}
impl<T> Drop for LoroMutexGuardInner<'_, T> {
    /// Rolls the tracking slot back to the record that preceded this lock.
    ///
    /// # Panics
    ///
    /// Panics with "Locking release order violation" (after printing a
    /// backtrace to stderr) when the slot's current rank is not this guard's
    /// rank, i.e. the guard is not the most recently recorded lock.
    fn drop(&mut self) {
        let slot = self.inner.currently_locked_in_this_thread.get_or_default();
        // Copy the record out so the slot's own lock is not held while we
        // decide what to do.
        let active = *slot.lock();
        if active.kind == self.this.kind {
            *slot.lock() = self.last;
            return;
        }

        let bt = Backtrace::capture();
        eprintln!("Locking release order violation callstack:\n{}", bt);
        panic!(
            "Locking release order violation. self.this: {}, self.last: {}, current: {}",
            self.this, self.last, active
        );
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Taking a lower-ranked lock while a higher-ranked one is held panics.
    #[test]
    #[should_panic(expected = "Locking order violation")]
    fn test_locking_order_violation_shows_caller() {
        let group = LoroLockGroup::new();
        let mutex1 = group.new_lock(1, LockKind::DocState);
        let mutex2 = group.new_lock(2, LockKind::Txn);
        let _guard1 = mutex1.lock();
        // Txn (1) ranks below the DocState (3) lock that is still held.
        let _guard2 = mutex2.lock();
    }

    /// Locks of increasing rank can be taken one after another when each
    /// guard is released before the next lock is acquired.
    #[test]
    fn test_locking_order_when_dropped_in_order() {
        let group = LoroLockGroup::new();
        let mutex1 = group.new_lock(1, LockKind::Txn);
        let mutex2 = group.new_lock(2, LockKind::OpLog);
        let mutex3 = group.new_lock(3, LockKind::DocState);
        let _guard1 = mutex1.lock();
        drop(_guard1);
        let _guard2 = mutex2.lock();
        drop(_guard2);
        let _guard3 = mutex3.lock();
    }

    /// Releasing an outer guard before the inner one is a release-order
    /// violation. The expected message is pinned so that an unrelated panic
    /// cannot make this test pass.
    #[test]
    #[should_panic(expected = "Locking release order violation")]
    fn test_locking_order_when_not_dropped_in_reverse_order() {
        let group = LoroLockGroup::new();
        let mutex1 = group.new_lock(1, LockKind::Txn);
        let mutex2 = group.new_lock(2, LockKind::OpLog);
        let _guard1 = mutex1.lock();
        let _guard2 = mutex2.lock();
        // The tracking record still names OpLog, so dropping Txn first panics.
        drop(_guard1);
        drop(_guard2);
    }

    /// After an inner guard is dropped the record falls back to the outer
    /// lock, so a lock ranked between the two can then be taken.
    #[test]
    fn test_dropping_should_restore_last_lock_info_0() {
        let group = LoroLockGroup::new();
        let mutex1 = group.new_lock(1, LockKind::Txn);
        let mutex2 = group.new_lock(2, LockKind::OpLog);
        let mutex3 = group.new_lock(3, LockKind::DocState);
        let _guard1 = mutex1.lock();
        let _guard3 = mutex3.lock();
        drop(_guard3);
        // Record is back to Txn (1); OpLog (2) is allowed again.
        let _guard2 = mutex2.lock();
        drop(_guard2);
    }

    /// After an inner guard is dropped the outer lock is still recorded, so a
    /// lower-ranked lock is still rejected. The expected message is pinned so
    /// that an unrelated panic cannot make this test pass.
    #[test]
    #[should_panic(expected = "Locking order violation")]
    fn test_dropping_should_restore_last_lock_info_1() {
        let group = LoroLockGroup::new();
        let mutex1 = group.new_lock(1, LockKind::Txn);
        let mutex2 = group.new_lock(2, LockKind::OpLog);
        let mutex3 = group.new_lock(3, LockKind::DocState);
        let _guard2 = mutex2.lock();
        let _guard3 = mutex3.lock();
        drop(_guard3);
        // Record is back to OpLog (2); Txn (1) must be rejected.
        let _guard1 = mutex1.lock();
    }

    /// Two locks of the same rank can be used one after the other. (Despite
    /// the name, the first guard is released before the second lock is taken.)
    #[test]
    fn test_nested_locking_same_kind() {
        let group = LoroLockGroup::new();
        let mutex1 = group.new_lock(1, LockKind::Txn);
        let mutex2 = group.new_lock(2, LockKind::Txn);
        let guard1 = mutex1.lock();
        drop(guard1);
        let _guard2 = mutex2.lock();
    }

    /// Pins the numeric ranks that the ordering check relies on.
    #[test]
    fn test_lock_kind_enum_values() {
        assert_eq!(LockKind::None as u8, 0);
        assert_eq!(LockKind::Txn as u8, 1);
        assert_eq!(LockKind::OpLog as u8, 2);
        assert_eq!(LockKind::DocState as u8, 3);
        assert_eq!(LockKind::DiffCalculator as u8, 4);
    }

    /// `is_locked` reflects whether a guard is currently alive.
    #[test]
    fn test_is_locked_functionality() {
        let group = LoroLockGroup::new();
        let mutex = group.new_lock(42, LockKind::Txn);
        assert!(!mutex.is_locked());
        let _guard = mutex.lock();
        assert!(mutex.is_locked());
    }

    /// The order-violation message must carry the call sites of both the
    /// held lock and the rejected one.
    ///
    /// `#[should_panic(expected = ..)]` can only match a single substring, so
    /// the panic payload is captured and inspected directly.
    #[test]
    fn test_panic_message_contains_location_info() {
        let group = LoroLockGroup::new();
        let mutex1 = group.new_lock(1, LockKind::DocState);
        let mutex2 = group.new_lock(2, LockKind::Txn);
        // Without `#[track_caller]` this yields a location inside this
        // function, i.e. the same file as the `lock()` calls below.
        let this_file = std::panic::Location::caller().file();
        let _guard1 = mutex1.lock();
        // The order check panics before the underlying mutex is touched, so
        // `mutex2` is left unlocked and `_guard1` is released normally.
        let payload = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let _guard2 = mutex2.lock();
        }))
        .expect_err("locking Txn while DocState is held must panic");
        let message = payload
            .downcast_ref::<String>()
            .map(String::as_str)
            .or_else(|| payload.downcast_ref::<&str>().copied())
            .expect("panic payload should be a string");
        assert!(
            message.contains("Locking order violation"),
            "unexpected panic message: {message}"
        );
        assert!(
            message.contains(this_file),
            "panic message should name the calling file: {message}"
        );
        assert!(
            !message.contains("location: None"),
            "both locks should report a call site: {message}"
        );
    }

    /// A panic while the guard is held poisons the mutex; the next `lock()`
    /// is expected to panic with a message containing "poisoned LoroMutex".
    // NOTE(review): that message presumably comes from
    // `lock_with_kind("LoroMutex")` in `crate::sync` — not visible here.
    #[test]
    #[should_panic(expected = "poisoned LoroMutex")]
    fn test_poisoned_lock_panics_after_unwind() {
        let group = LoroLockGroup::new();
        let mutex = group.new_lock(42, LockKind::Txn);
        let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let _guard = mutex.lock();
            panic!("poison the lock");
        }));
        let _ = mutex.lock();
    }
}