use std::cell::Cell;
use tokio::sync::{RwLockReadGuard, RwLockWriteGuard};
// Per-task record of the lock level most recently taken through a
// `LeveledRwLock` (debug builds only). `LevelGuard::acquire` compares against
// and raises it; `LevelGuard`'s `Drop` writes the previous value back.
// The value only exists while the task runs inside a `LOCK_LEVEL.scope(..)`
// future; outside such a scope `try_with` fails and the checks are skipped.
// NOTE(review): the `scope` call that installs the initial value is not in this
// file — presumably it starts at `Cell::new(0)`; confirm against the caller.
// NOTE(review): a task-local is shared by every future polled by the same task,
// so locks acquired concurrently within one task (e.g. under `join!`/`select!`)
// share a single level — verify no call site relies on that.
tokio::task_local! {
#[cfg(debug_assertions)]
pub(crate) static LOCK_LEVEL: Cell<usize>;
}
/// Debug-only token remembering the lock level that was in effect before a
/// leveled lock acquisition raised it.
///
/// Produced by `LevelGuard::acquire`; its `Drop` impl writes `prev` back into
/// the task's `LOCK_LEVEL`, so the level is restored when the lock guard that
/// owns this token is dropped.
#[cfg(debug_assertions)]
pub(crate) struct LevelGuard {
/// `LOCK_LEVEL` value observed just before this guard raised it.
prev: usize,
}
#[cfg(debug_assertions)]
impl LevelGuard {
pub(crate) fn acquire(l: usize) -> Option<Self> {
LOCK_LEVEL
.try_with(|cell| {
let prev = cell.get();
assert!(
l > prev,
"Lock ordering violation: tried to acquire level-{l} lock \
while level-{prev} is already held. \
See the LOCK ORDER comment in consumer/mod.rs."
);
cell.set(l);
LevelGuard { prev }
})
.ok()
}
}
#[cfg(debug_assertions)]
impl Drop for LevelGuard {
    /// Writes the level saved by `LevelGuard::acquire` back into `LOCK_LEVEL`.
    ///
    /// Best-effort: outside a `LOCK_LEVEL` scope there is nothing to restore,
    /// and the access error is deliberately ignored. The restore is
    /// unconditional. If an outer guard is dropped while an inner one is still
    /// alive, the recorded level is first too low (until the inner guard
    /// drops) and then too high (after it does).
    fn drop(&mut self) {
        let saved = self.prev;
        let _: Result<(), _> = LOCK_LEVEL.try_with(|level| level.set(saved));
    }
}
/// An async reader-writer lock tagged with a compile-time lock level `L`.
///
/// Wraps [`tokio::sync::RwLock`]. In debug builds, `read` and `write` require
/// `L` to be strictly greater than the level the current task has already
/// recorded in `LOCK_LEVEL` (when the task runs inside a `LOCK_LEVEL` scope),
/// and panic with a lock-ordering message otherwise. Release builds perform no
/// checks.
pub(crate) struct LeveledRwLock<const L: usize, T>(tokio::sync::RwLock<T>);
impl<const L: usize, T> LeveledRwLock<L, T> {
    /// Creates a level-`L` lock protecting `val`.
    #[inline]
    pub(crate) fn new(val: T) -> Self {
        let inner = tokio::sync::RwLock::new(val);
        Self(inner)
    }

    /// Waits for shared access and returns a guard that dereferences to `T`.
    ///
    /// In debug builds the lock-order check runs *before* waiting on the lock,
    /// so a violation is reported before the task blocks. If this future is
    /// dropped while waiting, the recorded level is restored with it.
    ///
    /// # Panics
    ///
    /// In debug builds, inside a `LOCK_LEVEL` scope, panics if `L` is not
    /// strictly greater than the level currently recorded for this task.
    #[inline]
    pub(crate) async fn read(&self) -> LeveledReadGuard<'_, L, T> {
        #[cfg(debug_assertions)]
        let level_guard = LevelGuard::acquire(L);
        LeveledReadGuard {
            guard: self.0.read().await,
            #[cfg(debug_assertions)]
            _level_guard: level_guard,
        }
    }

    /// Waits for exclusive access and returns a guard that dereferences
    /// (mutably) to `T`.
    ///
    /// The debug-build lock-order check behaves exactly as for
    /// [`read`](Self::read).
    ///
    /// # Panics
    ///
    /// In debug builds, inside a `LOCK_LEVEL` scope, panics if `L` is not
    /// strictly greater than the level currently recorded for this task.
    #[inline]
    pub(crate) async fn write(&self) -> LeveledWriteGuard<'_, L, T> {
        #[cfg(debug_assertions)]
        let level_guard = LevelGuard::acquire(L);
        LeveledWriteGuard {
            guard: self.0.write().await,
            #[cfg(debug_assertions)]
            _level_guard: level_guard,
        }
    }

    /// Attempts to take shared access without waiting.
    ///
    /// Unlike [`read`](Self::read), this neither checks nor records a lock
    /// level and hands back the raw tokio guard.
    ///
    /// # Errors
    ///
    /// Returns [`tokio::sync::TryLockError`] if the lock is currently held
    /// exclusively.
    #[inline]
    pub(crate) fn try_read(&self) -> Result<RwLockReadGuard<'_, T>, tokio::sync::TryLockError> {
        tokio::sync::RwLock::try_read(&self.0)
    }

    /// Attempts to take exclusive access without waiting.
    ///
    /// Like [`try_read`](Self::try_read), this performs no lock-level check and
    /// returns the raw tokio guard.
    ///
    /// # Errors
    ///
    /// Returns [`tokio::sync::TryLockError`] if the lock is currently held.
    // NOTE(review): `dead_code` is allowed presumably because only the tests
    // call this today; confirm, and drop the allow once it has a real caller.
    #[allow(dead_code)]
    #[inline]
    pub(crate) fn try_write(
        &self,
    ) -> Result<RwLockWriteGuard<'_, T>, tokio::sync::TryLockError> {
        tokio::sync::RwLock::try_write(&self.0)
    }
}
/// Shared-access guard returned by [`LeveledRwLock::read`].
///
/// Dereferences to the protected value. Dropping it releases the read lock
/// and, in debug builds, restores the task's previous lock level.
// `#[must_use]` restores the protection the wrapped tokio guard gives:
// `lock.read().await;` on its own would unlock immediately.
#[must_use = "if unused the lock is released immediately"]
pub(crate) struct LeveledReadGuard<'a, const L: usize, T> {
    // Fields drop in declaration order: the lock is released first, then the
    // lock level is restored.
    /// Underlying tokio read guard.
    guard: RwLockReadGuard<'a, T>,
    /// Level bookkeeping from `read`; `None` when the guard was created outside
    /// a `LOCK_LEVEL` scope.
    #[cfg(debug_assertions)]
    _level_guard: Option<LevelGuard>,
}
impl<const L: usize, T> std::ops::Deref for LeveledReadGuard<'_, L, T> {
    type Target = T;

    /// Borrows the value protected by the held read lock.
    #[inline]
    fn deref(&self) -> &T {
        std::ops::Deref::deref(&self.guard)
    }
}
/// Exclusive-access guard returned by [`LeveledRwLock::write`].
///
/// Dereferences (mutably) to the protected value. Dropping it releases the
/// write lock and, in debug builds, restores the task's previous lock level.
// `#[must_use]` restores the protection the wrapped tokio guard gives:
// `lock.write().await;` on its own would unlock immediately.
#[must_use = "if unused the lock is released immediately"]
pub(crate) struct LeveledWriteGuard<'a, const L: usize, T> {
    // Fields drop in declaration order: the lock is released first, then the
    // lock level is restored.
    /// Underlying tokio write guard.
    guard: RwLockWriteGuard<'a, T>,
    /// Level bookkeeping from `write`; `None` when the guard was created
    /// outside a `LOCK_LEVEL` scope.
    #[cfg(debug_assertions)]
    _level_guard: Option<LevelGuard>,
}
impl<const L: usize, T> std::ops::Deref for LeveledWriteGuard<'_, L, T> {
    type Target = T;

    /// Borrows the value protected by the held write lock.
    #[inline]
    fn deref(&self) -> &T {
        std::ops::Deref::deref(&self.guard)
    }
}
impl<const L: usize, T> std::ops::DerefMut for LeveledWriteGuard<'_, L, T> {
    /// Mutably borrows the value protected by the held write lock.
    #[inline]
    fn deref_mut(&mut self) -> &mut T {
        std::ops::DerefMut::deref_mut(&mut self.guard)
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn test_leveled_lock_new_and_basic_access() {
        let lock = LeveledRwLock::<3, u32>::new(42);
        let guard = lock.try_read().expect("try_read should succeed");
        assert_eq!(*guard, 42);
    }

    #[test]
    fn test_try_write_returns_guard() {
        let lock = LeveledRwLock::<2, Vec<i32>>::new(vec![1, 2, 3]);
        let mut guard = lock.try_write().expect("try_write should succeed");
        guard.push(4);
        drop(guard);
        assert_eq!(*lock.try_read().unwrap(), vec![1, 2, 3, 4]);
    }

    #[tokio::test]
    async fn test_async_read_and_write() {
        let lock = LeveledRwLock::<1, String>::new("hello".to_string());
        {
            let r = lock.read().await;
            assert_eq!(*r, "hello");
        }
        {
            let mut w = lock.write().await;
            w.push_str(" world");
        }
        let r = lock.read().await;
        assert_eq!(*r, "hello world");
    }

    /// Runs `fut` inside a fresh `LOCK_LEVEL` scope (starting at level 0) so
    /// the debug-build ordering checks are active.
    #[cfg(debug_assertions)]
    async fn with_level_scope<F: std::future::Future>(fut: F) -> F::Output {
        LOCK_LEVEL.scope(Cell::new(0), fut).await
    }

    /// Acquiring locks in strictly increasing level order is allowed.
    #[cfg(debug_assertions)]
    #[tokio::test]
    async fn test_increasing_levels_are_allowed() {
        let low = LeveledRwLock::<1, u8>::new(1);
        let high = LeveledRwLock::<2, u8>::new(2);
        with_level_scope(async {
            let a = low.read().await;
            let b = high.write().await;
            assert_eq!((*a, *b), (1, 2));
        })
        .await;
    }

    /// Taking a lower-level lock while a higher one is held trips the check.
    #[cfg(debug_assertions)]
    #[tokio::test]
    #[should_panic(expected = "Lock ordering violation")]
    async fn test_lower_level_after_higher_panics() {
        let low = LeveledRwLock::<1, u8>::new(1);
        let high = LeveledRwLock::<2, u8>::new(2);
        with_level_scope(async {
            let _high = high.read().await;
            let _low = low.read().await;
        })
        .await;
    }

    /// The comparison is strict, so two locks of the same level also trip it.
    #[cfg(debug_assertions)]
    #[tokio::test]
    #[should_panic(expected = "Lock ordering violation")]
    async fn test_same_level_twice_panics() {
        let first = LeveledRwLock::<1, u8>::new(1);
        let second = LeveledRwLock::<1, u8>::new(2);
        with_level_scope(async {
            let _first = first.read().await;
            let _second = second.read().await;
        })
        .await;
    }

    /// Dropping a guard puts the task's previous level back.
    #[cfg(debug_assertions)]
    #[tokio::test]
    async fn test_level_restored_after_guard_drop() {
        let low = LeveledRwLock::<1, u8>::new(1);
        let high = LeveledRwLock::<2, u8>::new(2);
        with_level_scope(async {
            {
                let _high = high.write().await;
                assert_eq!(LOCK_LEVEL.with(|level| level.get()), 2);
            }
            assert_eq!(LOCK_LEVEL.with(|level| level.get()), 0);
            // The lower level is legal again once the higher guard is gone.
            assert_eq!(*low.read().await, 1);
        })
        .await;
    }

    /// Outside a `LOCK_LEVEL` scope no ordering is enforced.
    #[tokio::test]
    async fn test_checks_skipped_outside_level_scope() {
        let low = LeveledRwLock::<1, u8>::new(1);
        let high = LeveledRwLock::<2, u8>::new(2);
        let h = high.read().await;
        let l = low.read().await;
        assert_eq!((*h, *l), (2, 1));
    }
}