use std::sync::Arc;
use tokio::sync::RwLock as ExtRwLock;
/// Lock type used when the `diagnostics` feature is enabled: a wrapper around
/// the shared tokio lock that logs a warning when an acquisition stays
/// pending past its timeout.
#[cfg(feature = "diagnostics")]
pub(super) type RwLock<T> = diagnostic::DiagnosticRwLock<T>;
/// Lock type used when the `diagnostics` feature is disabled: a plain
/// reference-counted tokio `RwLock`.
#[cfg(not(feature = "diagnostics"))]
pub(super) type RwLock<T> = Arc<ExtRwLock<T>>;
/// Builds a new lock around `value`, using the diagnostic wrapper because the
/// `diagnostics` feature is enabled.
#[cfg(feature = "diagnostics")]
pub(super) fn new<T>(value: T) -> RwLock<T> {
    diagnostic::DiagnosticRwLock::new(value)
}

/// Builds a new lock around `value`, using a plain reference-counted tokio
/// lock because the `diagnostics` feature is disabled.
#[cfg(not(feature = "diagnostics"))]
pub(super) fn new<T>(value: T) -> RwLock<T> {
    let lock = ExtRwLock::new(value);
    Arc::new(lock)
}
#[cfg(feature = "diagnostics")]
mod diagnostic {
use super::*;
use std::time::Duration;
use tokio::sync::{RwLockReadGuard, RwLockWriteGuard};
use tokio::time::timeout;
use tracing::warn;
/// Handle to a shared tokio `RwLock` that logs a warning whenever acquiring
/// the lock stays pending for longer than `timeout`.
pub(crate) struct DiagnosticRwLock<T> {
// The underlying lock, reference-counted so clones of this handle share it.
arc_lock: Arc<ExtRwLock<T>>,
// How long an acquisition may stay pending before a warning is logged.
timeout: Duration,
}
impl<T> DiagnosticRwLock<T> {
pub(crate) fn new(inner: T) -> Self {
Self {
arc_lock: Arc::new(ExtRwLock::new(inner)),
timeout: Duration::from_secs(1),
}
}
pub(crate) async fn read(&self) -> RwLockReadGuard<'_, T> {
loop {
match timeout(self.timeout, self.arc_lock.read()).await {
Ok(inner) => return inner,
Err(_) => {
warn!("Unable to acquire read in {:?}", self.timeout);
}
}
}
}
pub(crate) async fn write(&self) -> RwLockWriteGuard<'_, T> {
loop {
match timeout(self.timeout, self.arc_lock.write()).await {
Ok(inner) => return inner,
Err(_) => {
warn!("Unable to acquire write in {:?}", self.timeout);
}
}
}
}
}
impl<T> Clone for DiagnosticRwLock<T> {
fn clone(&self) -> Self {
Self {
arc_lock: self.arc_lock.clone(),
timeout: self.timeout,
}
}
}
#[cfg(test)]
mod diagnostic_tests {
    use super::*;

    /// While a write guard is held, a read attempt bounded by a short outer
    /// timeout must expire instead of obtaining the lock.
    #[tokio::test]
    async fn test_timeout_warning() {
        let lock = new(42);
        // Bound to a name so the write guard stays alive until the test ends.
        let _writer = lock.write().await;

        let deadline = Duration::from_millis(5);
        let outcome = tokio::time::timeout(deadline, lock.read()).await;

        assert!(outcome.is_err(), "Read lock should have timed out");
    }
}
}