async_inspect/sync/
mutex.rs

1//! Tracked Mutex implementation
2//!
3//! A drop-in replacement for `tokio::sync::Mutex` that automatically tracks
4//! contention and integrates with async-inspect's deadlock detection.
5
6use crate::deadlock::{DeadlockDetector, ResourceId, ResourceInfo, ResourceKind};
7use crate::inspector::Inspector;
8use crate::instrument::current_task_id;
9use crate::sync::{LockMetrics, MetricsTracker, WaitTimer};
10
11use std::fmt;
12use std::ops::{Deref, DerefMut};
13use std::sync::Arc;
14use tokio::sync::Mutex as TokioMutex;
15
16/// A tracked mutex that automatically records contention metrics
17/// and integrates with deadlock detection.
18///
19/// This is a drop-in replacement for `tokio::sync::Mutex` with additional
20/// observability features.
21///
22/// # Example
23///
24/// ```rust,no_run
25/// use async_inspect::sync::Mutex;
26///
27/// #[tokio::main]
28/// async fn main() {
29///     let mutex = Mutex::new(0, "counter");
30///
31///     // Spawn multiple tasks that contend for the lock
32///     let mutex = std::sync::Arc::new(mutex);
33///     let mut handles = vec![];
34///
35///     for i in 0..10 {
36///         let m = mutex.clone();
37///         handles.push(tokio::spawn(async move {
38///             let mut guard = m.lock().await;
39///             *guard += 1;
40///         }));
41///     }
42///
43///     for h in handles {
44///         h.await.unwrap();
45///     }
46///
47///     // Check contention metrics
48///     let metrics = mutex.metrics();
49///     println!("Total acquisitions: {}", metrics.acquisitions);
50///     println!("Contentions: {}", metrics.contentions);
51///     println!("Contention rate: {:.1}%", metrics.contention_rate() * 100.0);
52/// }
53/// ```
54pub struct Mutex<T> {
55    /// The underlying Tokio mutex
56    inner: TokioMutex<T>,
57    /// Name for debugging/display
58    name: String,
59    /// Resource ID for deadlock detection
60    resource_id: ResourceId,
61    /// Contention metrics
62    metrics: Arc<MetricsTracker>,
63}
64
65impl<T> Mutex<T> {
66    /// Create a new tracked mutex with a name for identification.
67    ///
68    /// # Arguments
69    ///
70    /// * `value` - The initial value to protect
71    /// * `name` - A descriptive name for debugging and metrics
72    ///
73    /// # Example
74    ///
75    /// ```rust,no_run
76    /// use async_inspect::sync::Mutex;
77    ///
78    /// let mutex = Mutex::new(vec![1, 2, 3], "shared_vector");
79    /// ```
80    pub fn new(value: T, name: impl Into<String>) -> Self {
81        let name = name.into();
82        let resource_info = ResourceInfo::new(ResourceKind::Mutex, name.clone());
83        let resource_id = resource_info.id;
84
85        // Register with deadlock detector
86        let detector = Inspector::global().deadlock_detector();
87        let _ = detector.register_resource(resource_info);
88
89        Self {
90            inner: TokioMutex::new(value),
91            name,
92            resource_id,
93            metrics: Arc::new(MetricsTracker::new()),
94        }
95    }
96
97    /// Acquire the lock, blocking until it's available.
98    ///
99    /// This method automatically:
100    /// - Records wait time if there's contention
101    /// - Notifies the deadlock detector
102    /// - Tracks acquisition metrics
103    ///
104    /// # Example
105    ///
106    /// ```rust,no_run
107    /// use async_inspect::sync::Mutex;
108    ///
109    /// # async fn example() {
110    /// let mutex = Mutex::new(42, "my_value");
111    /// let guard = mutex.lock().await;
112    /// println!("Value: {}", *guard);
113    /// # }
114    /// ```
115    pub async fn lock(&self) -> MutexGuard<'_, T> {
116        let detector = Inspector::global().deadlock_detector();
117        let task_id = current_task_id();
118
119        // Record that we're waiting for this resource
120        if let Some(tid) = task_id {
121            detector.wait_for(tid, self.resource_id);
122        }
123
124        let timer = WaitTimer::start();
125
126        // Actually acquire the lock
127        let guard = self.inner.lock().await;
128
129        // Record metrics
130        let wait_time = timer.elapsed_if_contended();
131        self.metrics.record_acquisition(wait_time);
132
133        // Record successful acquisition
134        if let Some(tid) = task_id {
135            detector.acquire(tid, self.resource_id);
136        }
137
138        MutexGuard {
139            guard,
140            resource_id: self.resource_id,
141            task_id,
142            detector: detector.clone(),
143        }
144    }
145
146    /// Try to acquire the lock immediately.
147    ///
148    /// Returns `None` if the lock is already held.
149    ///
150    /// # Example
151    ///
152    /// ```rust,no_run
153    /// use async_inspect::sync::Mutex;
154    ///
155    /// # async fn example() {
156    /// let mutex = Mutex::new(42, "my_value");
157    /// let result = mutex.try_lock();
158    /// if let Some(guard) = result {
159    ///     println!("Got the lock: {}", *guard);
160    /// } else {
161    ///     println!("Lock is held by another task");
162    /// }
163    /// # }
164    /// ```
165    pub fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
166        let detector = Inspector::global().deadlock_detector();
167        let task_id = current_task_id();
168
169        match self.inner.try_lock() {
170            Ok(guard) => {
171                // Immediate acquisition - no contention
172                self.metrics.record_acquisition(None);
173
174                if let Some(tid) = task_id {
175                    detector.acquire(tid, self.resource_id);
176                }
177
178                Some(MutexGuard {
179                    guard,
180                    resource_id: self.resource_id,
181                    task_id,
182                    detector: detector.clone(),
183                })
184            }
185            Err(_) => None,
186        }
187    }
188
189    /// Get the current contention metrics for this mutex.
190    ///
191    /// # Example
192    ///
193    /// ```rust,no_run
194    /// use async_inspect::sync::Mutex;
195    ///
196    /// # async fn example() {
197    /// let mutex = Mutex::new(42, "my_value");
198    /// // ... some operations ...
199    /// let metrics = mutex.metrics();
200    /// println!("Acquisitions: {}", metrics.acquisitions);
201    /// println!("Contention rate: {:.1}%", metrics.contention_rate() * 100.0);
202    /// # }
203    /// ```
204    #[must_use]
205    pub fn metrics(&self) -> LockMetrics {
206        self.metrics.get_metrics()
207    }
208
209    /// Reset the contention metrics.
210    pub fn reset_metrics(&self) {
211        self.metrics.reset();
212    }
213
214    /// Get the name of this mutex.
215    #[must_use]
216    pub fn name(&self) -> &str {
217        &self.name
218    }
219
220    /// Get the resource ID for deadlock detection.
221    #[must_use]
222    pub fn resource_id(&self) -> ResourceId {
223        self.resource_id
224    }
225
226    /// Consume the mutex and return the inner value.
227    ///
228    /// # Panics
229    ///
230    /// This method will panic if the mutex is poisoned.
231    pub fn into_inner(self) -> T {
232        self.inner.into_inner()
233    }
234
235    /// Get a mutable reference to the inner value without locking.
236    ///
237    /// This is safe because we have exclusive access via `&mut self`.
238    pub fn get_mut(&mut self) -> &mut T {
239        self.inner.get_mut()
240    }
241}
242
243impl<T: fmt::Debug> fmt::Debug for Mutex<T> {
244    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
245        let metrics = self.metrics();
246        f.debug_struct("Mutex")
247            .field("name", &self.name)
248            .field("resource_id", &self.resource_id)
249            .field("acquisitions", &metrics.acquisitions)
250            .field("contentions", &metrics.contentions)
251            .finish()
252    }
253}
254
255/// RAII guard for a tracked mutex.
256///
257/// When this guard is dropped, the lock is released and the deadlock
258/// detector is notified.
259pub struct MutexGuard<'a, T> {
260    guard: tokio::sync::MutexGuard<'a, T>,
261    resource_id: ResourceId,
262    task_id: Option<crate::task::TaskId>,
263    detector: DeadlockDetector,
264}
265
266impl<T> Deref for MutexGuard<'_, T> {
267    type Target = T;
268
269    fn deref(&self) -> &Self::Target {
270        &self.guard
271    }
272}
273
274impl<T> DerefMut for MutexGuard<'_, T> {
275    fn deref_mut(&mut self) -> &mut Self::Target {
276        &mut self.guard
277    }
278}
279
280impl<T> Drop for MutexGuard<'_, T> {
281    fn drop(&mut self) {
282        // Notify deadlock detector that we're releasing the lock
283        if let Some(tid) = self.task_id {
284            self.detector.release(tid, self.resource_id);
285        }
286    }
287}
288
289impl<T: fmt::Debug> fmt::Debug for MutexGuard<'_, T> {
290    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
291        f.debug_struct("MutexGuard")
292            .field("value", &*self.guard)
293            .field("resource_id", &self.resource_id)
294            .finish()
295    }
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// Two sequential locks: the write made under the first guard must be
    /// visible under the second, and both acquisitions must be counted.
    #[tokio::test]
    async fn test_basic_lock_unlock() {
        let mutex = Mutex::new(42, "test_mutex");

        let mut first = mutex.lock().await;
        assert_eq!(*first, 42);
        *first = 100;
        drop(first);

        let second = mutex.lock().await;
        assert_eq!(*second, 100);

        assert_eq!(mutex.metrics().acquisitions, 2);
    }

    /// `try_lock` succeeds when free, fails while a guard is alive, and
    /// succeeds again once that guard is dropped.
    #[tokio::test]
    async fn test_try_lock() {
        let mutex = Mutex::new(42, "test_mutex");

        let held = mutex.try_lock();
        assert!(held.is_some());

        assert!(mutex.try_lock().is_none());

        drop(held);

        assert!(mutex.try_lock().is_some());
    }

    /// Several tasks holding the lock across an await must produce contention.
    #[tokio::test]
    async fn test_contention_metrics() {
        use std::sync::Arc;
        use tokio::time::{sleep, Duration};

        let mutex = Arc::new(Mutex::new(0, "contended_mutex"));

        // Collect eagerly so every task is spawned before any is awaited.
        let handles: Vec<_> = (0..5)
            .map(|_| {
                let m = Arc::clone(&mutex);
                tokio::spawn(async move {
                    let mut guard = m.lock().await;
                    // Hold the lock briefly to cause contention
                    sleep(Duration::from_millis(10)).await;
                    *guard += 1;
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }

        let snapshot = mutex.metrics();
        assert_eq!(snapshot.acquisitions, 5);
        // At least some contention should have occurred
        assert!(snapshot.contentions > 0);
    }

    /// Consuming the mutex hands back the original value.
    #[tokio::test]
    async fn test_into_inner() {
        let value = Mutex::new(vec![1, 2, 3], "vec_mutex").into_inner();
        assert_eq!(value, vec![1, 2, 3]);
    }

    /// A write through `get_mut` is visible to a later `lock`.
    #[tokio::test]
    async fn test_get_mut() {
        let mut mutex = Mutex::new(42, "mut_mutex");
        let slot = mutex.get_mut();
        *slot = 100;
        assert_eq!(*mutex.lock().await, 100);
    }
}