1use crate::deadlock::{DeadlockDetector, ResourceId, ResourceInfo, ResourceKind};
7use crate::inspector::Inspector;
8use crate::instrument::current_task_id;
9use crate::sync::{LockMetrics, MetricsTracker, WaitTimer};
10
11use std::fmt;
12use std::ops::{Deref, DerefMut};
13use std::sync::Arc;
14use tokio::sync::RwLock as TokioRwLock;
15
/// An async read-write lock wrapping [`tokio::sync::RwLock`] with
/// deadlock-detector bookkeeping and separate read/write acquisition metrics.
pub struct RwLock<T> {
    /// The underlying tokio lock that actually protects the value.
    inner: TokioRwLock<T>,
    /// Human-readable name, also passed to the detector's `ResourceInfo`.
    name: String,
    /// Id under which this lock was registered with the deadlock detector.
    resource_id: ResourceId,
    /// Acquisition metrics for shared (read) access.
    read_metrics: Arc<MetricsTracker>,
    /// Acquisition metrics for exclusive (write) access.
    write_metrics: Arc<MetricsTracker>,
}
71
72impl<T> RwLock<T> {
73 pub fn new(value: T, name: impl Into<String>) -> Self {
89 let name = name.into();
90 let resource_info = ResourceInfo::new(ResourceKind::RwLock, name.clone());
91 let resource_id = resource_info.id;
92
93 let detector = Inspector::global().deadlock_detector();
95 let _ = detector.register_resource(resource_info);
96
97 Self {
98 inner: TokioRwLock::new(value),
99 name,
100 resource_id,
101 read_metrics: Arc::new(MetricsTracker::new()),
102 write_metrics: Arc::new(MetricsTracker::new()),
103 }
104 }
105
106 pub async fn read(&self) -> RwLockReadGuard<'_, T> {
123 let detector = Inspector::global().deadlock_detector();
124 let task_id = current_task_id();
125
126 if let Some(tid) = task_id {
128 detector.wait_for(tid, self.resource_id);
129 }
130
131 let timer = WaitTimer::start();
132
133 let guard = self.inner.read().await;
135
136 let wait_time = timer.elapsed_if_contended();
138 self.read_metrics.record_acquisition(wait_time);
139
140 if let Some(tid) = task_id {
142 detector.acquire(tid, self.resource_id);
143 }
144
145 RwLockReadGuard {
146 guard,
147 resource_id: self.resource_id,
148 task_id,
149 detector: detector.clone(),
150 }
151 }
152
153 pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
170 let detector = Inspector::global().deadlock_detector();
171 let task_id = current_task_id();
172
173 if let Some(tid) = task_id {
175 detector.wait_for(tid, self.resource_id);
176 }
177
178 let timer = WaitTimer::start();
179
180 let guard = self.inner.write().await;
182
183 let wait_time = timer.elapsed_if_contended();
185 self.write_metrics.record_acquisition(wait_time);
186
187 if let Some(tid) = task_id {
189 detector.acquire(tid, self.resource_id);
190 }
191
192 RwLockWriteGuard {
193 guard,
194 resource_id: self.resource_id,
195 task_id,
196 detector: detector.clone(),
197 }
198 }
199
200 pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
204 let detector = Inspector::global().deadlock_detector();
205 let task_id = current_task_id();
206
207 match self.inner.try_read() {
208 Ok(guard) => {
209 self.read_metrics.record_acquisition(None);
210
211 if let Some(tid) = task_id {
212 detector.acquire(tid, self.resource_id);
213 }
214
215 Some(RwLockReadGuard {
216 guard,
217 resource_id: self.resource_id,
218 task_id,
219 detector: detector.clone(),
220 })
221 }
222 Err(_) => None,
223 }
224 }
225
226 pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
230 let detector = Inspector::global().deadlock_detector();
231 let task_id = current_task_id();
232
233 match self.inner.try_write() {
234 Ok(guard) => {
235 self.write_metrics.record_acquisition(None);
236
237 if let Some(tid) = task_id {
238 detector.acquire(tid, self.resource_id);
239 }
240
241 Some(RwLockWriteGuard {
242 guard,
243 resource_id: self.resource_id,
244 task_id,
245 detector: detector.clone(),
246 })
247 }
248 Err(_) => None,
249 }
250 }
251
252 #[must_use]
270 pub fn metrics(&self) -> (LockMetrics, LockMetrics) {
271 (
272 self.read_metrics.get_metrics(),
273 self.write_metrics.get_metrics(),
274 )
275 }
276
277 #[must_use]
279 pub fn read_metrics(&self) -> LockMetrics {
280 self.read_metrics.get_metrics()
281 }
282
283 #[must_use]
285 pub fn write_metrics(&self) -> LockMetrics {
286 self.write_metrics.get_metrics()
287 }
288
289 pub fn reset_metrics(&self) {
291 self.read_metrics.reset();
292 self.write_metrics.reset();
293 }
294
295 #[must_use]
297 pub fn name(&self) -> &str {
298 &self.name
299 }
300
301 #[must_use]
303 pub fn resource_id(&self) -> ResourceId {
304 self.resource_id
305 }
306
307 pub fn into_inner(self) -> T {
309 self.inner.into_inner()
310 }
311
312 pub fn get_mut(&mut self) -> &mut T {
316 self.inner.get_mut()
317 }
318}
319
320impl<T: fmt::Debug> fmt::Debug for RwLock<T> {
321 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
322 let (read_metrics, write_metrics) = self.metrics();
323 f.debug_struct("RwLock")
324 .field("name", &self.name)
325 .field("resource_id", &self.resource_id)
326 .field("read_acquisitions", &read_metrics.acquisitions)
327 .field("write_acquisitions", &write_metrics.acquisitions)
328 .finish()
329 }
330}
331
/// RAII guard for shared access, returned by `RwLock::read` and `RwLock::try_read`.
///
/// Dropping it reports the release to the deadlock detector (when a task id
/// was recorded) and then drops the inner tokio guard.
pub struct RwLockReadGuard<'a, T> {
    /// The underlying tokio read guard that grants access to the value.
    guard: tokio::sync::RwLockReadGuard<'a, T>,
    /// Detector id of the lock this guard belongs to.
    resource_id: ResourceId,
    /// Task that acquired the lock, if `current_task_id` returned one; when
    /// `None`, the detector was never told about this hold.
    task_id: Option<crate::task::TaskId>,
    /// Detector handle used to report the release on drop.
    detector: DeadlockDetector,
}
342
343impl<T> Deref for RwLockReadGuard<'_, T> {
344 type Target = T;
345
346 fn deref(&self) -> &Self::Target {
347 &self.guard
348 }
349}
350
351impl<T> Drop for RwLockReadGuard<'_, T> {
352 fn drop(&mut self) {
353 if let Some(tid) = self.task_id {
354 self.detector.release(tid, self.resource_id);
355 }
356 }
357}
358
359impl<T: fmt::Debug> fmt::Debug for RwLockReadGuard<'_, T> {
360 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
361 f.debug_struct("RwLockReadGuard")
362 .field("value", &*self.guard)
363 .field("resource_id", &self.resource_id)
364 .finish()
365 }
366}
367
/// RAII guard for exclusive access, returned by `RwLock::write` and `RwLock::try_write`.
///
/// Dropping it reports the release to the deadlock detector (when a task id
/// was recorded) and then drops the inner tokio guard.
pub struct RwLockWriteGuard<'a, T> {
    /// The underlying tokio write guard that grants mutable access to the value.
    guard: tokio::sync::RwLockWriteGuard<'a, T>,
    /// Detector id of the lock this guard belongs to.
    resource_id: ResourceId,
    /// Task that acquired the lock, if `current_task_id` returned one; when
    /// `None`, the detector was never told about this hold.
    task_id: Option<crate::task::TaskId>,
    /// Detector handle used to report the release on drop.
    detector: DeadlockDetector,
}
378
379impl<T> Deref for RwLockWriteGuard<'_, T> {
380 type Target = T;
381
382 fn deref(&self) -> &Self::Target {
383 &self.guard
384 }
385}
386
387impl<T> DerefMut for RwLockWriteGuard<'_, T> {
388 fn deref_mut(&mut self) -> &mut Self::Target {
389 &mut self.guard
390 }
391}
392
393impl<T> Drop for RwLockWriteGuard<'_, T> {
394 fn drop(&mut self) {
395 if let Some(tid) = self.task_id {
396 self.detector.release(tid, self.resource_id);
397 }
398 }
399}
400
401impl<T: fmt::Debug> fmt::Debug for RwLockWriteGuard<'_, T> {
402 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
403 f.debug_struct("RwLockWriteGuard")
404 .field("value", &*self.guard)
405 .field("resource_id", &self.resource_id)
406 .finish()
407 }
408}
409
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_basic_read_write() {
        let lock = RwLock::new(42, "test_lock");

        {
            // Shared access observes the initial value.
            let guard = lock.read().await;
            assert_eq!(*guard, 42);
        }

        {
            // Exclusive access can mutate through the guard.
            let mut guard = lock.write().await;
            *guard = 100;
        }

        let guard = lock.read().await;
        assert_eq!(*guard, 100);

        let (read_metrics, write_metrics) = lock.metrics();
        assert_eq!(read_metrics.acquisitions, 2);
        assert_eq!(write_metrics.acquisitions, 1);
    }

    #[tokio::test]
    async fn test_concurrent_readers() {
        // `Arc` is already in scope through `use super::*`.
        let lock = Arc::new(RwLock::new(vec![1, 2, 3], "shared_vec"));
        let mut handles = Vec::new();

        for _ in 0..5 {
            let reader_lock = Arc::clone(&lock);
            handles.push(tokio::spawn(async move {
                let guard = reader_lock.read().await;
                assert_eq!(guard.len(), 3);
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }

        let read_metrics = lock.read_metrics();
        assert_eq!(read_metrics.acquisitions, 5);
    }

    #[tokio::test]
    async fn test_try_read_write() {
        let lock = RwLock::new(42, "test_lock");

        let guard = lock.try_read();
        assert!(guard.is_some());
        drop(guard);

        let guard = lock.try_write();
        assert!(guard.is_some());

        // A held writer excludes readers.
        let guard2 = lock.try_read();
        assert!(guard2.is_none());

        drop(guard);

        let guard3 = lock.try_read();
        assert!(guard3.is_some());
    }

    #[tokio::test]
    async fn test_try_write_blocked_by_readers() {
        let lock = RwLock::new(0, "contended_lock");

        // Shared access does not exclude other readers...
        let first = lock.try_read();
        let second = lock.try_read();
        assert!(first.is_some());
        assert!(second.is_some());

        // ...but does exclude a writer.
        assert!(lock.try_write().is_none());

        drop(first);
        drop(second);
        assert!(lock.try_write().is_some());
    }

    #[tokio::test]
    async fn test_failed_try_attempts_are_not_counted() {
        let lock = RwLock::new(1, "try_metrics_lock");

        let writer = lock
            .try_write()
            .expect("uncontended try_write must succeed");
        assert!(lock.try_read().is_none());
        drop(writer);
        assert!(lock.try_read().is_some());

        // Only the successful attempts are recorded.
        assert_eq!(lock.read_metrics().acquisitions, 1);
        assert_eq!(lock.write_metrics().acquisitions, 1);
    }

    #[tokio::test]
    async fn test_name_and_debug() {
        let lock = RwLock::new((), "named_lock");
        assert_eq!(lock.name(), "named_lock");
        assert!(format!("{:?}", lock).contains("named_lock"));
    }

    #[tokio::test]
    async fn test_into_inner() {
        let lock = RwLock::new(vec![1, 2, 3], "vec_lock");
        let inner = lock.into_inner();
        assert_eq!(inner, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_get_mut() {
        let mut lock = RwLock::new(42, "mut_lock");
        *lock.get_mut() = 100;
        let guard = lock.read().await;
        assert_eq!(*guard, 100);
    }
}