use crate::deadlock::{DeadlockDetector, ResourceId, ResourceInfo, ResourceKind};
use crate::inspector::Inspector;
use crate::instrument::current_task_id;
use crate::sync::{LockMetrics, MetricsTracker, WaitTimer};
use std::fmt;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use tokio::sync::Mutex as TokioMutex;
/// An async mutex built on [`tokio::sync::Mutex`] that records lock metrics
/// and reports waits, acquisitions and releases of the lock to the global
/// deadlock detector obtained from `Inspector::global()`.
///
/// NOTE(review): the resource is registered with the detector in `new`, but
/// nothing in this type unregisters it (there is no `Drop` impl) — confirm
/// whether the detector is expected to retain entries for dropped mutexes.
pub struct Mutex<T> {
    /// The underlying tokio mutex that provides the actual mutual exclusion.
    inner: TokioMutex<T>,
    /// Human-readable name supplied at construction, used for diagnostics.
    name: String,
    /// Identifier under which this mutex was registered with the detector.
    resource_id: ResourceId,
    /// Acquisition/contention counters for this mutex.
    metrics: Arc<MetricsTracker>,
}
impl<T> Mutex<T> {
    /// Creates a new instrumented mutex holding `value`.
    ///
    /// `name` is kept for diagnostics and is also used to build the
    /// [`ResourceInfo`] that is registered with the global deadlock detector.
    pub fn new(value: T, name: impl Into<String>) -> Self {
        let name = name.into();
        let resource_info = ResourceInfo::new(ResourceKind::Mutex, name.clone());
        let resource_id = resource_info.id;
        let detector = Inspector::global().deadlock_detector();
        // The registration result is intentionally discarded.
        // NOTE(review): confirm a failed registration is harmless for the
        // later wait_for/acquire/release calls that use `resource_id`.
        let _ = detector.register_resource(resource_info);
        Self {
            inner: TokioMutex::new(value),
            name,
            resource_id,
            metrics: Arc::new(MetricsTracker::new()),
        }
    }

    /// Acquires the lock, waiting asynchronously until it becomes available.
    ///
    /// When `current_task_id()` yields an id, the wait is reported to the
    /// deadlock detector before blocking, and the acquisition is reported once
    /// the lock is held. The wait time reported by the timer (if contended) is
    /// recorded in this mutex's metrics.
    ///
    /// NOTE(review): if this future is dropped while still waiting (e.g. as a
    /// losing `tokio::select!` branch or under a timeout), the `wait_for` edge
    /// registered below is never followed by `acquire`/`release`. Confirm that
    /// the detector clears such stale wait edges; otherwise cancellation can
    /// leave a phantom wait-for edge behind.
    pub async fn lock(&self) -> MutexGuard<'_, T> {
        let detector = Inspector::global().deadlock_detector();
        let task_id = current_task_id();
        if let Some(tid) = task_id {
            detector.wait_for(tid, self.resource_id);
        }
        let timer = WaitTimer::start();
        let guard = self.inner.lock().await;
        let wait_time = timer.elapsed_if_contended();
        self.metrics.record_acquisition(wait_time);
        if let Some(tid) = task_id {
            detector.acquire(tid, self.resource_id);
        }
        MutexGuard {
            guard,
            resource_id: self.resource_id,
            task_id,
            detector: detector.clone(),
        }
    }

    /// Attempts to acquire the lock without waiting.
    ///
    /// Returns `None` if the lock is currently held. On success, the
    /// acquisition is recorded in the metrics with no wait time (`None`).
    /// For tasks with an id, it is also reported to the deadlock detector.
    /// No wait edge is registered, because this call never blocks.
    ///
    /// The result must be used: dropping the returned guard right away
    /// releases the lock immediately.
    #[must_use]
    pub fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
        // Bail out before touching the detector when the lock is busy.
        let guard = self.inner.try_lock().ok()?;
        self.metrics.record_acquisition(None);
        let detector = Inspector::global().deadlock_detector();
        let task_id = current_task_id();
        if let Some(tid) = task_id {
            detector.acquire(tid, self.resource_id);
        }
        Some(MutexGuard {
            guard,
            resource_id: self.resource_id,
            task_id,
            detector: detector.clone(),
        })
    }

    /// Returns the current lock metrics for this mutex.
    #[must_use]
    pub fn metrics(&self) -> LockMetrics {
        self.metrics.get_metrics()
    }

    /// Resets the metrics tracked for this mutex.
    pub fn reset_metrics(&self) {
        self.metrics.reset();
    }

    /// Returns the diagnostic name given at construction.
    #[must_use]
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the id under which this mutex was registered with the
    /// deadlock detector.
    #[must_use]
    pub fn resource_id(&self) -> ResourceId {
        self.resource_id
    }

    /// Consumes the mutex and returns the protected value.
    ///
    /// No detector call is made here; the resource is not unregistered.
    pub fn into_inner(self) -> T {
        self.inner.into_inner()
    }

    /// Returns a mutable reference to the protected value.
    ///
    /// The `&mut self` borrow already guarantees exclusive access, so no
    /// locking, metrics recording, or detector reporting takes place.
    pub fn get_mut(&mut self) -> &mut T {
        self.inner.get_mut()
    }
}
impl<T: fmt::Debug> fmt::Debug for Mutex<T> {
    /// Renders the mutex's name, its detector resource id, and the current
    /// acquisition/contention counters. The protected value is not shown.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let snapshot = self.metrics();
        let mut out = formatter.debug_struct("Mutex");
        out.field("name", &self.name);
        out.field("resource_id", &self.resource_id);
        out.field("acquisitions", &snapshot.acquisitions);
        out.field("contentions", &snapshot.contentions);
        out.finish()
    }
}
/// RAII guard returned by [`Mutex::lock`] and [`Mutex::try_lock`].
///
/// Dereferences to the protected value. When dropped, it first reports the
/// release to the deadlock detector (if the acquiring task had an id). It
/// then releases the underlying tokio guard, which unlocks the mutex.
///
/// Marked `#[must_use]` because discarding the guard unlocks immediately,
/// which is never what the caller of `lock` intended.
#[must_use = "if unused the Mutex will immediately unlock"]
pub struct MutexGuard<'a, T> {
    /// The underlying tokio guard; dropping it unlocks the mutex.
    guard: tokio::sync::MutexGuard<'a, T>,
    /// Identifier of the mutex this guard was obtained from.
    resource_id: ResourceId,
    /// Task id captured at acquisition time; `None` means the detector was
    /// not told about this acquisition and will not be told about the release.
    task_id: Option<crate::task::TaskId>,
    /// Detector handle used to report the release on drop.
    detector: DeadlockDetector,
}
impl<T> Deref for MutexGuard<'_, T> {
    type Target = T;

    /// Borrows the protected value through the inner tokio guard.
    fn deref(&self) -> &T {
        &*self.guard
    }
}
impl<T> DerefMut for MutexGuard<'_, T> {
    /// Mutably borrows the protected value through the inner tokio guard.
    fn deref_mut(&mut self) -> &mut T {
        &mut *self.guard
    }
}
impl<T> Drop for MutexGuard<'_, T> {
    /// Reports the release to the deadlock detector when the guard was
    /// acquired by a task with an id.
    ///
    /// This runs before the struct's fields are dropped. The detector
    /// therefore hears about the release just before the `guard` field
    /// actually unlocks the tokio mutex.
    fn drop(&mut self) {
        let Some(owner) = self.task_id else {
            return;
        };
        self.detector.release(owner, self.resource_id);
    }
}
impl<T: fmt::Debug> fmt::Debug for MutexGuard<'_, T> {
    /// Shows the guarded value alongside the resource id of its mutex.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value: &T = &self.guard;
        formatter
            .debug_struct("MutexGuard")
            .field("value", value)
            .field("resource_id", &self.resource_id)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::time::{sleep, Duration};

    /// A write made under one guard is visible to the next lock, and both
    /// acquisitions are counted.
    #[tokio::test]
    async fn test_basic_lock_unlock() {
        let mutex = Mutex::new(42, "test_mutex");

        let mut first = mutex.lock().await;
        assert_eq!(*first, 42);
        *first = 100;
        drop(first);

        let second = mutex.lock().await;
        assert_eq!(*second, 100);
        assert_eq!(mutex.metrics().acquisitions, 2);
    }

    /// `try_lock` fails while a guard is alive and succeeds again after the
    /// guard is dropped.
    #[tokio::test]
    async fn test_try_lock() {
        let mutex = Mutex::new(42, "test_mutex");

        let held = mutex.try_lock();
        assert!(held.is_some());
        assert!(mutex.try_lock().is_none());

        drop(held);
        assert!(mutex.try_lock().is_some());
    }

    /// Several tasks holding the lock across an await point must produce at
    /// least one contended acquisition.
    #[tokio::test]
    async fn test_contention_metrics() {
        let mutex = Arc::new(Mutex::new(0, "contended_mutex"));

        let handles: Vec<_> = (0..5)
            .map(|_| {
                let shared = Arc::clone(&mutex);
                tokio::spawn(async move {
                    let mut counter = shared.lock().await;
                    sleep(Duration::from_millis(10)).await;
                    *counter += 1;
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }

        let stats = mutex.metrics();
        assert_eq!(stats.acquisitions, 5);
        assert!(stats.contentions > 0);
    }

    /// Consuming the mutex hands back the original value.
    #[tokio::test]
    async fn test_into_inner() {
        let mutex = Mutex::new(vec![1, 2, 3], "vec_mutex");
        assert_eq!(mutex.into_inner(), vec![1, 2, 3]);
    }

    /// Writes through `get_mut` are observed by a later `lock`.
    #[tokio::test]
    async fn test_get_mut() {
        let mut mutex = Mutex::new(42, "mut_mutex");
        *mutex.get_mut() = 100;
        let value = mutex.lock().await;
        assert_eq!(*value, 100);
    }
}