use crate::deadlock::{DeadlockDetector, ResourceId, ResourceInfo, ResourceKind};
use crate::inspector::Inspector;
use crate::instrument::current_task_id;
use crate::sync::{LockMetrics, MetricsTracker, WaitTimer};
use std::fmt;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use tokio::sync::RwLock as TokioRwLock;
/// An async read-write lock wrapping [`tokio::sync::RwLock`] that reports its
/// waits, acquisitions and releases to the global deadlock detector and keeps
/// separate acquisition metrics for readers and writers.
pub struct RwLock<T> {
    // The tokio lock that actually protects the value.
    inner: TokioRwLock<T>,
    // Name supplied to `new`; also used to build this lock's `ResourceInfo`.
    name: String,
    // Id under which this lock was registered with the deadlock detector.
    resource_id: ResourceId,
    // Acquisition metrics for shared (read) access.
    read_metrics: Arc<MetricsTracker>,
    // Acquisition metrics for exclusive (write) access.
    write_metrics: Arc<MetricsTracker>,
}
impl<T> RwLock<T> {
    /// Creates a lock protecting `value`, identified by `name`.
    ///
    /// A `ResourceInfo` of kind `ResourceKind::RwLock` is built from `name` and
    /// registered with the global deadlock detector. The value returned by the
    /// registration call is discarded, so construction proceeds regardless of it.
    ///
    /// NOTE(review): nothing in this file unregisters the resource when the lock
    /// is dropped or consumed via `into_inner` — confirm the detector does not
    /// need that.
    pub fn new(value: T, name: impl Into<String>) -> Self {
        let name = name.into();
        let resource_info = ResourceInfo::new(ResourceKind::RwLock, name.clone());
        let resource_id = resource_info.id;
        let detector = Inspector::global().deadlock_detector();
        // Best-effort registration: its outcome is intentionally ignored.
        let _ = detector.register_resource(resource_info);
        Self {
            inner: TokioRwLock::new(value),
            name,
            resource_id,
            read_metrics: Arc::new(MetricsTracker::new()),
            write_metrics: Arc::new(MetricsTracker::new()),
        }
    }

    /// Acquires shared read access, waiting asynchronously until it is granted.
    ///
    /// If `current_task_id()` yields an id, the task is reported to the deadlock
    /// detector as waiting before the await and as holding the lock afterwards;
    /// the returned guard reports the release when dropped. One acquisition,
    /// together with the value of `WaitTimer::elapsed_if_contended`, is recorded
    /// in the read metrics.
    ///
    /// NOTE(review): if this future is dropped while still waiting (for example
    /// inside `tokio::select!` or a timeout), the `wait_for` edge registered here
    /// is never followed by `acquire`/`release` and may linger in the detector —
    /// confirm whether `DeadlockDetector` clears it or needs an explicit cancel.
    #[must_use = "if unused the RwLock will immediately unlock"]
    pub async fn read(&self) -> RwLockReadGuard<'_, T> {
        let detector = Inspector::global().deadlock_detector();
        let task_id = current_task_id();
        if let Some(tid) = task_id {
            detector.wait_for(tid, self.resource_id);
        }
        let timer = WaitTimer::start();
        let guard = self.inner.read().await;
        let wait_time = timer.elapsed_if_contended();
        self.read_metrics.record_acquisition(wait_time);
        if let Some(tid) = task_id {
            detector.acquire(tid, self.resource_id);
        }
        RwLockReadGuard {
            guard,
            resource_id: self.resource_id,
            task_id,
            detector: detector.clone(),
        }
    }

    /// Acquires exclusive write access, waiting asynchronously until it is granted.
    ///
    /// Detector reporting mirrors `read`: a wait edge before the await, an
    /// acquisition afterwards (both only when `current_task_id()` yields an id),
    /// and a release when the returned guard is dropped. One acquisition and its
    /// contended wait time are recorded in the write metrics.
    ///
    /// NOTE(review): same cancellation caveat as `read` — a future dropped while
    /// waiting leaves its `wait_for` edge without a matching `acquire`.
    #[must_use = "if unused the RwLock will immediately unlock"]
    pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
        let detector = Inspector::global().deadlock_detector();
        let task_id = current_task_id();
        if let Some(tid) = task_id {
            detector.wait_for(tid, self.resource_id);
        }
        let timer = WaitTimer::start();
        let guard = self.inner.write().await;
        let wait_time = timer.elapsed_if_contended();
        self.write_metrics.record_acquisition(wait_time);
        if let Some(tid) = task_id {
            detector.acquire(tid, self.resource_id);
        }
        RwLockWriteGuard {
            guard,
            resource_id: self.resource_id,
            task_id,
            detector: detector.clone(),
        }
    }

    /// Attempts to acquire shared read access without waiting.
    ///
    /// Returns `None` if the lock cannot be taken immediately; in that case
    /// nothing is recorded in the metrics or reported to the detector. On
    /// success one uncontended acquisition (`None` wait time) is recorded and,
    /// when a task id is available, the acquisition is reported. No wait edge
    /// is ever registered because this call never blocks.
    #[must_use = "if unused the RwLock will immediately unlock"]
    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
        let guard = self.inner.try_read().ok()?;
        self.read_metrics.record_acquisition(None);
        // Only look up the detector and task once the lock is actually held.
        let detector = Inspector::global().deadlock_detector();
        let task_id = current_task_id();
        if let Some(tid) = task_id {
            detector.acquire(tid, self.resource_id);
        }
        Some(RwLockReadGuard {
            guard,
            resource_id: self.resource_id,
            task_id,
            detector: detector.clone(),
        })
    }

    /// Attempts to acquire exclusive write access without waiting.
    ///
    /// Returns `None` if the lock cannot be taken immediately, recording and
    /// reporting nothing. On success one uncontended acquisition is recorded in
    /// the write metrics and, when a task id is available, reported to the
    /// detector.
    #[must_use = "if unused the RwLock will immediately unlock"]
    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
        let guard = self.inner.try_write().ok()?;
        self.write_metrics.record_acquisition(None);
        // Only look up the detector and task once the lock is actually held.
        let detector = Inspector::global().deadlock_detector();
        let task_id = current_task_id();
        if let Some(tid) = task_id {
            detector.acquire(tid, self.resource_id);
        }
        Some(RwLockWriteGuard {
            guard,
            resource_id: self.resource_id,
            task_id,
            detector: detector.clone(),
        })
    }

    /// Returns the read and write metrics as a `(read, write)` pair.
    #[must_use]
    pub fn metrics(&self) -> (LockMetrics, LockMetrics) {
        (self.read_metrics(), self.write_metrics())
    }

    /// Returns the metrics collected for read acquisitions.
    #[must_use]
    pub fn read_metrics(&self) -> LockMetrics {
        self.read_metrics.get_metrics()
    }

    /// Returns the metrics collected for write acquisitions.
    #[must_use]
    pub fn write_metrics(&self) -> LockMetrics {
        self.write_metrics.get_metrics()
    }

    /// Resets both the read and the write metrics trackers.
    pub fn reset_metrics(&self) {
        self.read_metrics.reset();
        self.write_metrics.reset();
    }

    /// Returns the name this lock was created with.
    #[must_use]
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the id this lock was registered under in the deadlock detector.
    #[must_use]
    pub fn resource_id(&self) -> ResourceId {
        self.resource_id
    }

    /// Consumes the lock and returns the protected value.
    ///
    /// Metrics and detector state are not touched.
    pub fn into_inner(self) -> T {
        self.inner.into_inner()
    }

    /// Returns a mutable reference to the value without locking.
    ///
    /// The `&mut self` borrow statically guarantees no guard is outstanding, so
    /// no metrics are recorded and nothing is reported to the detector.
    pub fn get_mut(&mut self) -> &mut T {
        self.inner.get_mut()
    }
}
/// Formats the lock's name, detector id and per-side acquisition counts.
///
/// The protected value itself is not printed, so no `T: Debug` bound is
/// required; locks around non-`Debug` values are formattable too.
impl<T> fmt::Debug for RwLock<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (read_metrics, write_metrics) = self.metrics();
        f.debug_struct("RwLock")
            .field("name", &self.name)
            .field("resource_id", &self.resource_id)
            .field("read_acquisitions", &read_metrics.acquisitions)
            .field("write_acquisitions", &write_metrics.acquisitions)
            .finish()
    }
}
/// Guard granting shared read access to an instrumented [`RwLock`].
///
/// Dereferences to the protected value. When dropped, it reports the release
/// to the deadlock detector (if the acquiring task had an id) and then releases
/// the underlying tokio read lock.
#[must_use = "if unused the RwLock will immediately unlock"]
pub struct RwLockReadGuard<'a, T> {
    /// Underlying tokio guard; dropping it releases the read lock.
    guard: tokio::sync::RwLockReadGuard<'a, T>,
    /// Detector id of the lock this guard belongs to.
    resource_id: ResourceId,
    /// Task that acquired the guard, if `current_task_id()` returned one.
    task_id: Option<crate::task::TaskId>,
    /// Detector handle used to report the release on drop.
    detector: DeadlockDetector,
}
impl<T> Deref for RwLockReadGuard<'_, T> {
    type Target = T;

    /// Borrows the protected value for as long as the guard itself is borrowed.
    fn deref(&self) -> &T {
        &*self.guard
    }
}
impl<T> Drop for RwLockReadGuard<'_, T> {
    /// Reports the release to the deadlock detector when a task id was captured.
    ///
    /// Struct fields are dropped after this method returns, so the detector is
    /// told about the release just before the inner tokio guard unlocks.
    fn drop(&mut self) {
        let Some(task) = self.task_id else {
            // Acquired outside an identified task: no acquisition was reported.
            return;
        };
        self.detector.release(task, self.resource_id);
    }
}
/// Formats the guarded value together with the owning lock's detector id.
impl<T: fmt::Debug> fmt::Debug for RwLockReadGuard<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value: &T = &self.guard;
        let mut out = f.debug_struct("RwLockReadGuard");
        out.field("value", value);
        out.field("resource_id", &self.resource_id);
        out.finish()
    }
}
/// Guard granting exclusive write access to an instrumented [`RwLock`].
///
/// Dereferences (mutably) to the protected value. When dropped, it reports the
/// release to the deadlock detector (if the acquiring task had an id) and then
/// releases the underlying tokio write lock.
#[must_use = "if unused the RwLock will immediately unlock"]
pub struct RwLockWriteGuard<'a, T> {
    /// Underlying tokio guard; dropping it releases the write lock.
    guard: tokio::sync::RwLockWriteGuard<'a, T>,
    /// Detector id of the lock this guard belongs to.
    resource_id: ResourceId,
    /// Task that acquired the guard, if `current_task_id()` returned one.
    task_id: Option<crate::task::TaskId>,
    /// Detector handle used to report the release on drop.
    detector: DeadlockDetector,
}
impl<T> Deref for RwLockWriteGuard<'_, T> {
    type Target = T;

    /// Shared view of the value held under the write lock.
    fn deref(&self) -> &T {
        &*self.guard
    }
}
impl<T> DerefMut for RwLockWriteGuard<'_, T> {
    /// Mutable view of the value; only the write guard offers this.
    fn deref_mut(&mut self) -> &mut T {
        &mut *self.guard
    }
}
impl<T> Drop for RwLockWriteGuard<'_, T> {
    /// Reports the release to the deadlock detector when a task id was captured.
    ///
    /// Struct fields are dropped after this method returns, so the detector is
    /// told about the release just before the inner tokio guard unlocks.
    fn drop(&mut self) {
        let Some(task) = self.task_id else {
            // Acquired outside an identified task: no acquisition was reported.
            return;
        };
        self.detector.release(task, self.resource_id);
    }
}
/// Formats the guarded value together with the owning lock's detector id.
impl<T: fmt::Debug> fmt::Debug for RwLockWriteGuard<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let value: &T = &self.guard;
        let mut out = f.debug_struct("RwLockWriteGuard");
        out.field("value", value);
        out.field("resource_id", &self.resource_id);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_basic_read_write() {
        let lock = RwLock::new(42, "test_lock");
        {
            let guard = lock.read().await;
            assert_eq!(*guard, 42);
        }
        {
            let mut guard = lock.write().await;
            *guard = 100;
        }
        let guard = lock.read().await;
        assert_eq!(*guard, 100);
        let (read_metrics, write_metrics) = lock.metrics();
        assert_eq!(read_metrics.acquisitions, 2);
        assert_eq!(write_metrics.acquisitions, 1);
    }

    #[tokio::test]
    async fn test_concurrent_readers() {
        // `Arc` is already in scope through `use super::*`.
        let lock = Arc::new(RwLock::new(vec![1, 2, 3], "shared_vec"));
        let mut handles = vec![];
        for _ in 0..5 {
            let l = Arc::clone(&lock);
            handles.push(tokio::spawn(async move {
                let guard = l.read().await;
                assert_eq!(guard.len(), 3);
            }));
        }
        for h in handles {
            h.await.unwrap();
        }
        let read_metrics = lock.read_metrics();
        assert_eq!(read_metrics.acquisitions, 5);
    }

    #[tokio::test]
    async fn test_try_read_write() {
        let lock = RwLock::new(42, "test_lock");
        let guard = lock.try_read();
        assert!(guard.is_some());
        drop(guard);
        let guard = lock.try_write();
        assert!(guard.is_some());
        let guard2 = lock.try_read();
        assert!(guard2.is_none());
        drop(guard);
        let guard3 = lock.try_read();
        assert!(guard3.is_some());
    }

    #[tokio::test]
    async fn test_failed_try_does_not_record_metrics() {
        let lock = RwLock::new(0, "contended_lock");
        assert_eq!(lock.name(), "contended_lock");
        let writer = lock.write().await;
        // Both attempts fail while the writer is held and must not count.
        assert!(lock.try_read().is_none());
        assert!(lock.try_write().is_none());
        drop(writer);
        let (read_metrics, write_metrics) = lock.metrics();
        assert_eq!(read_metrics.acquisitions, 0);
        assert_eq!(write_metrics.acquisitions, 1);
    }

    #[tokio::test]
    async fn test_into_inner() {
        let lock = RwLock::new(vec![1, 2, 3], "vec_lock");
        let inner = lock.into_inner();
        assert_eq!(inner, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_get_mut() {
        let mut lock = RwLock::new(42, "mut_lock");
        *lock.get_mut() = 100;
        let guard = lock.read().await;
        assert_eq!(*guard, 100);
    }
}