async_inspect/sync/
rwlock.rs

1//! Tracked `RwLock` implementation
2//!
//! A near drop-in replacement for `tokio::sync::RwLock` (its constructor also
//! takes a name) that automatically tracks contention and integrates with
//! async-inspect's deadlock detection.
5
6use crate::deadlock::{DeadlockDetector, ResourceId, ResourceInfo, ResourceKind};
7use crate::inspector::Inspector;
8use crate::instrument::current_task_id;
9use crate::sync::{LockMetrics, MetricsTracker, WaitTimer};
10
11use std::fmt;
12use std::ops::{Deref, DerefMut};
13use std::sync::Arc;
14use tokio::sync::RwLock as TokioRwLock;
15
16/// A tracked read-write lock that automatically records contention metrics
17/// and integrates with deadlock detection.
18///
19/// This is a drop-in replacement for `tokio::sync::RwLock` with additional
20/// observability features. It tracks read and write operations separately.
21///
22/// # Example
23///
24/// ```rust,no_run
25/// use async_inspect::sync::RwLock;
26///
27/// #[tokio::main]
28/// async fn main() {
29///     let lock = RwLock::new(vec![1, 2, 3], "shared_data");
30///
31///     // Multiple readers can access simultaneously
32///     let lock = std::sync::Arc::new(lock);
33///     let mut handles = vec![];
34///
35///     for i in 0..5 {
36///         let l = lock.clone();
37///         handles.push(tokio::spawn(async move {
38///             let guard = l.read().await;
39///             println!("Reader {}: {:?}", i, *guard);
40///         }));
41///     }
42///
43///     // Writer has exclusive access
44///     {
45///         let mut guard = lock.write().await;
46///         guard.push(4);
47///     }
48///
49///     for h in handles {
50///         h.await.unwrap();
51///     }
52///
53///     // Check contention metrics
54///     let (read_metrics, write_metrics) = lock.metrics();
55///     println!("Read acquisitions: {}", read_metrics.acquisitions);
56///     println!("Write acquisitions: {}", write_metrics.acquisitions);
57/// }
58/// ```
59pub struct RwLock<T> {
60    /// The underlying Tokio `RwLock`
61    inner: TokioRwLock<T>,
62    /// Name for debugging/display
63    name: String,
64    /// Resource ID for deadlock detection
65    resource_id: ResourceId,
66    /// Contention metrics for read operations
67    read_metrics: Arc<MetricsTracker>,
68    /// Contention metrics for write operations
69    write_metrics: Arc<MetricsTracker>,
70}
71
72impl<T> RwLock<T> {
73    /// Create a new tracked `RwLock` with a name for identification.
74    ///
75    /// # Arguments
76    ///
77    /// * `value` - The initial value to protect
78    /// * `name` - A descriptive name for debugging and metrics
79    ///
80    /// # Example
81    ///
82    /// ```rust,no_run
83    /// use async_inspect::sync::RwLock;
84    /// use std::collections::HashMap;
85    ///
86    /// let lock = RwLock::new(HashMap::<String, i32>::new(), "user_cache");
87    /// ```
88    pub fn new(value: T, name: impl Into<String>) -> Self {
89        let name = name.into();
90        let resource_info = ResourceInfo::new(ResourceKind::RwLock, name.clone());
91        let resource_id = resource_info.id;
92
93        // Register with deadlock detector
94        let detector = Inspector::global().deadlock_detector();
95        let _ = detector.register_resource(resource_info);
96
97        Self {
98            inner: TokioRwLock::new(value),
99            name,
100            resource_id,
101            read_metrics: Arc::new(MetricsTracker::new()),
102            write_metrics: Arc::new(MetricsTracker::new()),
103        }
104    }
105
106    /// Acquire a read lock, blocking until it's available.
107    ///
108    /// Multiple readers can hold the lock simultaneously, but writers
109    /// must wait for all readers to release.
110    ///
111    /// # Example
112    ///
113    /// ```rust,no_run
114    /// use async_inspect::sync::RwLock;
115    ///
116    /// # async fn example() {
117    /// let lock = RwLock::new(42, "my_value");
118    /// let guard = lock.read().await;
119    /// println!("Value: {}", *guard);
120    /// # }
121    /// ```
122    pub async fn read(&self) -> RwLockReadGuard<'_, T> {
123        let detector = Inspector::global().deadlock_detector();
124        let task_id = current_task_id();
125
126        // Record that we're waiting for this resource
127        if let Some(tid) = task_id {
128            detector.wait_for(tid, self.resource_id);
129        }
130
131        let timer = WaitTimer::start();
132
133        // Actually acquire the lock
134        let guard = self.inner.read().await;
135
136        // Record metrics
137        let wait_time = timer.elapsed_if_contended();
138        self.read_metrics.record_acquisition(wait_time);
139
140        // Record successful acquisition
141        if let Some(tid) = task_id {
142            detector.acquire(tid, self.resource_id);
143        }
144
145        RwLockReadGuard {
146            guard,
147            resource_id: self.resource_id,
148            task_id,
149            detector: detector.clone(),
150        }
151    }
152
153    /// Acquire a write lock, blocking until it's available.
154    ///
155    /// Writers have exclusive access - no other readers or writers
156    /// can hold the lock simultaneously.
157    ///
158    /// # Example
159    ///
160    /// ```rust,no_run
161    /// use async_inspect::sync::RwLock;
162    ///
163    /// # async fn example() {
164    /// let lock = RwLock::new(42, "my_value");
165    /// let mut guard = lock.write().await;
166    /// *guard = 100;
167    /// # }
168    /// ```
169    pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
170        let detector = Inspector::global().deadlock_detector();
171        let task_id = current_task_id();
172
173        // Record that we're waiting for this resource
174        if let Some(tid) = task_id {
175            detector.wait_for(tid, self.resource_id);
176        }
177
178        let timer = WaitTimer::start();
179
180        // Actually acquire the lock
181        let guard = self.inner.write().await;
182
183        // Record metrics
184        let wait_time = timer.elapsed_if_contended();
185        self.write_metrics.record_acquisition(wait_time);
186
187        // Record successful acquisition
188        if let Some(tid) = task_id {
189            detector.acquire(tid, self.resource_id);
190        }
191
192        RwLockWriteGuard {
193            guard,
194            resource_id: self.resource_id,
195            task_id,
196            detector: detector.clone(),
197        }
198    }
199
200    /// Try to acquire a read lock immediately.
201    ///
202    /// Returns `None` if a writer is holding the lock.
203    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
204        let detector = Inspector::global().deadlock_detector();
205        let task_id = current_task_id();
206
207        match self.inner.try_read() {
208            Ok(guard) => {
209                self.read_metrics.record_acquisition(None);
210
211                if let Some(tid) = task_id {
212                    detector.acquire(tid, self.resource_id);
213                }
214
215                Some(RwLockReadGuard {
216                    guard,
217                    resource_id: self.resource_id,
218                    task_id,
219                    detector: detector.clone(),
220                })
221            }
222            Err(_) => None,
223        }
224    }
225
226    /// Try to acquire a write lock immediately.
227    ///
228    /// Returns `None` if any readers or writers are holding the lock.
229    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
230        let detector = Inspector::global().deadlock_detector();
231        let task_id = current_task_id();
232
233        match self.inner.try_write() {
234            Ok(guard) => {
235                self.write_metrics.record_acquisition(None);
236
237                if let Some(tid) = task_id {
238                    detector.acquire(tid, self.resource_id);
239                }
240
241                Some(RwLockWriteGuard {
242                    guard,
243                    resource_id: self.resource_id,
244                    task_id,
245                    detector: detector.clone(),
246                })
247            }
248            Err(_) => None,
249        }
250    }
251
252    /// Get the current contention metrics for this `RwLock`.
253    ///
254    /// Returns a tuple of (`read_metrics`, `write_metrics`).
255    ///
256    /// # Example
257    ///
258    /// ```rust,no_run
259    /// use async_inspect::sync::RwLock;
260    ///
261    /// # async fn example() {
262    /// let lock = RwLock::new(42, "my_value");
263    /// // ... some operations ...
264    /// let (read_metrics, write_metrics) = lock.metrics();
265    /// println!("Read acquisitions: {}", read_metrics.acquisitions);
266    /// println!("Write acquisitions: {}", write_metrics.acquisitions);
267    /// # }
268    /// ```
269    #[must_use]
270    pub fn metrics(&self) -> (LockMetrics, LockMetrics) {
271        (
272            self.read_metrics.get_metrics(),
273            self.write_metrics.get_metrics(),
274        )
275    }
276
277    /// Get only the read metrics.
278    #[must_use]
279    pub fn read_metrics(&self) -> LockMetrics {
280        self.read_metrics.get_metrics()
281    }
282
283    /// Get only the write metrics.
284    #[must_use]
285    pub fn write_metrics(&self) -> LockMetrics {
286        self.write_metrics.get_metrics()
287    }
288
289    /// Reset all contention metrics.
290    pub fn reset_metrics(&self) {
291        self.read_metrics.reset();
292        self.write_metrics.reset();
293    }
294
295    /// Get the name of this `RwLock`.
296    #[must_use]
297    pub fn name(&self) -> &str {
298        &self.name
299    }
300
301    /// Get the resource ID for deadlock detection.
302    #[must_use]
303    pub fn resource_id(&self) -> ResourceId {
304        self.resource_id
305    }
306
307    /// Consume the `RwLock` and return the inner value.
308    pub fn into_inner(self) -> T {
309        self.inner.into_inner()
310    }
311
312    /// Get a mutable reference to the inner value without locking.
313    ///
314    /// This is safe because we have exclusive access via `&mut self`.
315    pub fn get_mut(&mut self) -> &mut T {
316        self.inner.get_mut()
317    }
318}
319
320impl<T: fmt::Debug> fmt::Debug for RwLock<T> {
321    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
322        let (read_metrics, write_metrics) = self.metrics();
323        f.debug_struct("RwLock")
324            .field("name", &self.name)
325            .field("resource_id", &self.resource_id)
326            .field("read_acquisitions", &read_metrics.acquisitions)
327            .field("write_acquisitions", &write_metrics.acquisitions)
328            .finish()
329    }
330}
331
332/// RAII guard for a tracked `RwLock` read lock.
333///
334/// When this guard is dropped, the lock is released and the deadlock
335/// detector is notified.
336pub struct RwLockReadGuard<'a, T> {
337    guard: tokio::sync::RwLockReadGuard<'a, T>,
338    resource_id: ResourceId,
339    task_id: Option<crate::task::TaskId>,
340    detector: DeadlockDetector,
341}
342
343impl<T> Deref for RwLockReadGuard<'_, T> {
344    type Target = T;
345
346    fn deref(&self) -> &Self::Target {
347        &self.guard
348    }
349}
350
351impl<T> Drop for RwLockReadGuard<'_, T> {
352    fn drop(&mut self) {
353        if let Some(tid) = self.task_id {
354            self.detector.release(tid, self.resource_id);
355        }
356    }
357}
358
359impl<T: fmt::Debug> fmt::Debug for RwLockReadGuard<'_, T> {
360    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
361        f.debug_struct("RwLockReadGuard")
362            .field("value", &*self.guard)
363            .field("resource_id", &self.resource_id)
364            .finish()
365    }
366}
367
368/// RAII guard for a tracked `RwLock` write lock.
369///
370/// When this guard is dropped, the lock is released and the deadlock
371/// detector is notified.
372pub struct RwLockWriteGuard<'a, T> {
373    guard: tokio::sync::RwLockWriteGuard<'a, T>,
374    resource_id: ResourceId,
375    task_id: Option<crate::task::TaskId>,
376    detector: DeadlockDetector,
377}
378
379impl<T> Deref for RwLockWriteGuard<'_, T> {
380    type Target = T;
381
382    fn deref(&self) -> &Self::Target {
383        &self.guard
384    }
385}
386
387impl<T> DerefMut for RwLockWriteGuard<'_, T> {
388    fn deref_mut(&mut self) -> &mut Self::Target {
389        &mut self.guard
390    }
391}
392
393impl<T> Drop for RwLockWriteGuard<'_, T> {
394    fn drop(&mut self) {
395        if let Some(tid) = self.task_id {
396            self.detector.release(tid, self.resource_id);
397        }
398    }
399}
400
401impl<T: fmt::Debug> fmt::Debug for RwLockWriteGuard<'_, T> {
402    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
403        f.debug_struct("RwLockWriteGuard")
404            .field("value", &*self.guard)
405            .field("resource_id", &self.resource_id)
406            .finish()
407    }
408}
409
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_basic_read_write() {
        let lock = RwLock::new(42, "test_lock");

        // Read
        {
            let guard = lock.read().await;
            assert_eq!(*guard, 42);
        }

        // Write
        {
            let mut guard = lock.write().await;
            *guard = 100;
        }

        // Read again
        let guard = lock.read().await;
        assert_eq!(*guard, 100);

        let (read_metrics, write_metrics) = lock.metrics();
        assert_eq!(read_metrics.acquisitions, 2);
        assert_eq!(write_metrics.acquisitions, 1);
    }

    #[tokio::test]
    async fn test_concurrent_readers() {
        use std::sync::Arc;

        let lock = Arc::new(RwLock::new(vec![1, 2, 3], "shared_vec"));
        let mut handles = vec![];

        // Spawn multiple readers
        for _ in 0..5 {
            let l = lock.clone();
            handles.push(tokio::spawn(async move {
                let guard = l.read().await;
                assert_eq!(guard.len(), 3);
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let read_metrics = lock.read_metrics();
        assert_eq!(read_metrics.acquisitions, 5);
    }

    #[tokio::test]
    async fn test_try_read_write() {
        let lock = RwLock::new(42, "test_lock");

        // Should succeed when no locks held
        let guard = lock.try_read();
        assert!(guard.is_some());
        drop(guard);

        // Try write should succeed when no locks held
        let guard = lock.try_write();
        assert!(guard.is_some());

        // Try read should fail when write lock is held
        let guard2 = lock.try_read();
        assert!(guard2.is_none());

        drop(guard);

        // Should succeed now
        let guard3 = lock.try_read();
        assert!(guard3.is_some());
    }

    #[tokio::test]
    async fn test_try_write_fails_while_reader_held() {
        let lock = RwLock::new(0, "reader_blocks_writer");

        // A held read guard must exclude writers.
        let reader = lock.read().await;
        assert!(lock.try_write().is_none());
        drop(reader);

        assert!(lock.try_write().is_some());

        // Failed `try_*` attempts must not be counted as acquisitions.
        let (read_metrics, write_metrics) = lock.metrics();
        assert_eq!(read_metrics.acquisitions, 1);
        assert_eq!(write_metrics.acquisitions, 1);
    }

    #[tokio::test]
    async fn test_name_and_debug() {
        let lock = RwLock::new((), "named_lock");
        assert_eq!(lock.name(), "named_lock");
        assert!(format!("{:?}", lock).contains("named_lock"));
    }

    #[tokio::test]
    async fn test_into_inner() {
        let lock = RwLock::new(vec![1, 2, 3], "vec_lock");
        let inner = lock.into_inner();
        assert_eq!(inner, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_get_mut() {
        let mut lock = RwLock::new(42, "mut_lock");
        *lock.get_mut() = 100;
        let guard = lock.read().await;
        assert_eq!(*guard, 100);
    }
}