use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{Mutex, OwnedMutexGuard};
use super::PvDatabase;
/// Registry of per-record lock "gates", keyed by record name.
///
/// Each gate is an async `tokio::sync::Mutex<()>` shared through an `Arc`.
/// The map that owns the gates is protected by a blocking `std::sync::Mutex`.
/// `Default` yields an empty registry.
#[derive(Default)]
pub(crate) struct RecordLockRegistry {
// Record name -> shared gate. Populated by `gate_for`.
gates: std::sync::Mutex<HashMap<String, Arc<Mutex<()>>>>,
}
impl RecordLockRegistry {
    /// Returns the shared gate (an async mutex) for `record`, creating it on
    /// first use.
    ///
    /// The returned `Arc` is a clone of the one stored in the registry, so all
    /// callers asking for the same `record` string contend on the same mutex.
    /// The registry's own blocking mutex is held only for the duration of this
    /// synchronous call.
    ///
    /// NOTE(review): nothing in this impl removes gates again, so the map
    /// grows with the number of distinct names ever requested — confirm that
    /// this set is bounded by the callers.
    ///
    /// # Panics
    ///
    /// Panics if the registry's internal mutex is poisoned.
    fn gate_for(&self, record: &str) -> Arc<Mutex<()>> {
        let mut gates = self
            .gates
            .lock()
            .expect("record-lock registry mutex poisoned");

        // Common case: the gate already exists. Looking it up by `&str` avoids
        // allocating an owned `String` key on every lock acquisition.
        if let Some(existing) = gates.get(record) {
            return Arc::clone(existing);
        }

        // First request for this name: create the gate and remember it.
        let gate = Arc::new(Mutex::new(()));
        gates.insert(record.to_owned(), Arc::clone(&gate));
        gate
    }
}
/// Guard for a single record's write lock, returned by
/// [`PvDatabase::lock_record`].
///
/// The record stays locked for as long as this value is alive; dropping it
/// releases the lock. Marked `#[must_use]` so that an acquisition whose guard
/// is discarded on the spot (and therefore protects nothing) is reported by
/// the compiler.
#[must_use = "the record lock is released as soon as the guard is dropped"]
pub struct RecordWriteGuard {
    /// Held only so that dropping the struct releases the lock; never read.
    _guard: OwnedMutexGuard<()>,
}
/// Guard for the locks of a whole set of records, returned by
/// [`PvDatabase::lock_records`].
///
/// Every contained lock stays held while this value is alive; dropping it
/// drops the inner guards and thereby releases all of them.
#[must_use = "the locked epoch ends as soon as the guard is dropped"]
pub struct ManyRecordWriteGuard {
// One owned guard per distinct record; kept only for their `Drop`, never read.
_guards: Vec<OwnedMutexGuard<()>>,
}
impl PvDatabase {
    /// Acquires the write lock of a single record, waiting until it is free.
    ///
    /// `record` is first passed through `resolve_alias`; when that yields
    /// `None` the name is used exactly as given. The lock is then taken on the
    /// gate registered under the resulting name (presumably so that an alias
    /// and its target share one gate — verify against `resolve_alias`).
    ///
    /// The lock is held until the returned guard is dropped.
    pub async fn lock_record(&self, record: &str) -> RecordWriteGuard {
        let target = match self.resolve_alias(record).await {
            Some(resolved) => resolved,
            None => record.to_string(),
        };
        let held = self
            .inner
            .record_locks
            .gate_for(&target)
            .lock_owned()
            .await;
        RecordWriteGuard { _guard: held }
    }

    /// Acquires the write locks of several records and returns one guard that
    /// holds all of them.
    ///
    /// Each name goes through the same alias resolution as in
    /// [`PvDatabase::lock_record`]. The resolved names are sorted and
    /// de-duplicated before any lock is taken:
    ///
    /// * sorting gives every caller the same acquisition order, so two calls
    ///   with overlapping sets cannot end up waiting on each other in a cycle;
    /// * de-duplication ensures this call never awaits a gate it already
    ///   holds, which would wait forever.
    ///
    /// An empty `records` yields a guard that holds nothing.
    pub async fn lock_records<I, S>(&self, records: I) -> ManyRecordWriteGuard
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        // Phase 1: resolve every requested name.
        let mut targets: Vec<String> = Vec::new();
        for item in records {
            let raw: &str = item.as_ref();
            let resolved = match self.resolve_alias(raw).await {
                Some(name) => name,
                None => raw.to_string(),
            };
            targets.push(resolved);
        }

        // Phase 2: fix a global order and drop repeats.
        targets.sort_unstable();
        targets.dedup();

        // Phase 3: take the gates one by one in that order.
        let mut held = Vec::with_capacity(targets.len());
        for target in targets.iter() {
            held.push(
                self.inner
                    .record_locks
                    .gate_for(target)
                    .lock_owned()
                    .await,
            );
        }
        ManyRecordWriteGuard { _guards: held }
    }
}
// Tests for the record-lock API on `PvDatabase`: mutual exclusion on a single
// record, exclusion between a multi-record guard and a single-record lock,
// and absence of deadlock for overlapping multi-record acquisitions.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
// A second `lock_record` on the same name must wait until the first guard is
// dropped. The counter encodes the order: the holder adds 1 before releasing,
// the waiter adds 10 after acquiring, so the final value must be 11.
#[tokio::test]
async fn lock_record_excludes_same_record() {
let db = PvDatabase::new();
let order = Arc::new(AtomicUsize::new(0));
let g = db.lock_record("ai:1").await;
let db2 = db.clone();
let order2 = order.clone();
let h = tokio::spawn(async move {
let _g2 = db2.lock_record("ai:1").await;
order2.fetch_add(10, Ordering::SeqCst);
});
// Give the spawned task time to run; it must still be waiting on the lock,
// so the counter is untouched.
tokio::time::sleep(Duration::from_millis(20)).await;
assert_eq!(order.load(Ordering::SeqCst), 0);
order.fetch_add(1, Ordering::SeqCst);
drop(g);
h.await.unwrap();
assert_eq!(order.load(Ordering::SeqCst), 11);
}
// While a multi-record guard covering "g:b" is held, a single-record lock on
// "g:b" must not be granted; it must be granted once the guard is dropped.
#[tokio::test]
async fn lock_records_excludes_single_member_write() {
let db = PvDatabase::new();
let many = db.lock_records(["g:a", "g:b", "g:c"]).await;
let db2 = db.clone();
let acquired = Arc::new(AtomicUsize::new(0));
let acquired2 = acquired.clone();
let h = tokio::spawn(async move {
let _g = db2.lock_record("g:b").await;
acquired2.store(1, Ordering::SeqCst);
});
tokio::time::sleep(Duration::from_millis(20)).await;
assert_eq!(
acquired.load(Ordering::SeqCst),
0,
"single-member write must block while ManyRecordWriteGuard is held"
);
drop(many);
h.await.unwrap();
assert_eq!(acquired.load(Ordering::SeqCst), 1);
}
// Two tasks repeatedly lock the same three records, listed in opposite order.
// Both must finish within the timeout, i.e. the request order must not lead
// to a deadlock.
#[tokio::test]
async fn lock_records_overlapping_sets_no_deadlock() {
let db = PvDatabase::new();
let db_a = db.clone();
let ta = tokio::spawn(async move {
for _ in 0..50 {
let _g = db_a.lock_records(["x", "y", "z"]).await;
// Yield while still holding the guard so the other task gets a chance to contend.
tokio::task::yield_now().await;
}
});
let db_b = db.clone();
let tb = tokio::spawn(async move {
for _ in 0..50 {
let _g = db_b.lock_records(["z", "y", "x"]).await;
tokio::task::yield_now().await;
}
});
// A deadlock shows up as a timeout failure instead of a hung test run.
tokio::time::timeout(Duration::from_secs(5), async {
ta.await.unwrap();
tb.await.unwrap();
})
.await
.expect("overlapping lock_records sets must not deadlock");
}
// Sets {RECA, RECB} and {RECB, RECC} share RECB: the second acquisition must
// not complete while the first guard is held, and must complete afterwards.
#[tokio::test]
async fn overlapping_epochs_are_mutually_exclusive_and_deadlock_free() {
let db = PvDatabase::new();
let a = vec!["RECA".to_string(), "RECB".to_string()];
let b = vec!["RECB".to_string(), "RECC".to_string()];
let guard_a = db.lock_records(&a).await;
let db2 = db.clone();
let b2 = b.clone();
let handle = tokio::spawn(async move { db2.lock_records(&b2).await });
// NOTE(review): the assertion below also holds if the spawned task has not
// been polled yet, so a single yield is weaker evidence of blocking than the
// sleep used in the tests above.
tokio::task::yield_now().await;
assert!(!handle.is_finished(), "epoch B must block on shared RECB");
drop(guard_a);
// NOTE(review): no timeout here; if epoch B never acquired its locks this
// await would hang rather than fail.
let _guard_b = handle.await.expect("epoch B task");
}
// Two guards over different records are taken back to back in one task while
// both stay alive. The test has no assertions and no timeout: it passes by
// running to completion, and a regression would show up as a hang.
#[tokio::test]
async fn disjoint_epochs_do_not_block_each_other() {
let db = PvDatabase::new();
let _g1 = db.lock_records(&["X1".to_string()]).await;
let _g2 = db.lock_records(&["X2".to_string()]).await;
}
}