use parking_lot::{Condvar, Mutex};
use std::collections::HashMap;
use std::sync::OnceLock;
/// Identifier of the connection that owns an advisory lock.
pub type ConnId = u64;

/// Table of advisory locks keyed by an `i64`; each key is owned by at most
/// one [`ConnId`] at a time.
///
/// Ownership is not counted. A connection that already owns a key may acquire
/// it again (the call succeeds immediately). A single
/// [`release`](Self::release) frees the key no matter how many acquires
/// preceded it.
pub struct AdvisoryLocks {
    /// Current owner of every held key; keys absent from the map are free.
    state: Mutex<HashMap<i64, ConnId>>,
    /// Broadcast whenever one or more keys are freed, so that blocked
    /// [`acquire`](Self::acquire) callers can re-check their key.
    cv: Condvar,
}

impl AdvisoryLocks {
    /// Creates an empty lock table.
    fn new() -> Self {
        let state = Mutex::new(HashMap::new());
        let cv = Condvar::new();
        Self { state, cv }
    }

    /// Attempts to take `key` for `conn` without blocking.
    ///
    /// Returns `true` if `conn` owns `key` afterwards. That covers both a free
    /// key and one `conn` already held. Returns `false` if another connection
    /// owns it. In that case the table is left unchanged.
    pub fn try_acquire(&self, key: i64, conn: ConnId) -> bool {
        let mut owners = self.state.lock();
        // `or_insert` claims the key only when it is vacant; otherwise it yields
        // the existing owner, which is then compared against the caller.
        *owners.entry(key).or_insert(conn) == conn
    }

    /// Takes `key` for `conn`, blocking until it becomes available.
    ///
    /// Returns at once if `conn` already owns `key`. Otherwise it waits on the
    /// condition variable until the key is free, then claims it. There is no
    /// timeout: this waits for as long as another connection holds the key.
    pub fn acquire(&self, key: i64, conn: ConnId) {
        let mut owners = self.state.lock();
        // Re-check after every wakeup. A single condition variable is
        // broadcast for *any* freed key (and condvar wakeups may be spurious),
        // so our key can still be taken.
        while let Some(holder) = owners.get(&key).copied() {
            if holder == conn {
                return;
            }
            self.cv.wait(&mut owners);
        }
        owners.insert(key, conn);
    }

    /// Releases `key` if and only if it is currently owned by `conn`.
    ///
    /// Returns `true` and wakes all blocked acquirers when the key was freed.
    /// Returns `false`, leaving the table untouched, when the key is free or
    /// owned by a different connection.
    pub fn release(&self, key: i64, conn: ConnId) -> bool {
        let mut owners = self.state.lock();
        if owners.get(&key) != Some(&conn) {
            return false;
        }
        owners.remove(&key);
        self.cv.notify_all();
        true
    }

    /// Releases every key owned by `conn` and returns how many were freed.
    ///
    /// Waiters are woken only when at least one key was actually released.
    pub fn release_all(&self, conn: ConnId) -> usize {
        let mut owners = self.state.lock();
        let held_before = owners.len();
        owners.retain(|_, holder| *holder != conn);
        match held_before - owners.len() {
            0 => 0,
            freed => {
                self.cv.notify_all();
                freed
            }
        }
    }

    /// Test helper: reports whether any connection currently owns `key`.
    #[cfg(test)]
    pub fn is_held(&self, key: i64) -> bool {
        let owners = self.state.lock();
        owners.contains_key(&key)
    }
}
/// Lazily initialised, process-wide advisory lock table backing [`global`].
static GLOBAL: OnceLock<AdvisoryLocks> = OnceLock::new();

/// Returns the shared, process-wide [`AdvisoryLocks`] table.
///
/// The table is built on the first call. [`OnceLock::get_or_init`] runs the
/// constructor at most once, so every caller gets the same instance.
pub fn global() -> &'static AdvisoryLocks {
    let build = AdvisoryLocks::new;
    GLOBAL.get_or_init(build)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// Time a contending thread is given to reach `Condvar::wait` before the
    /// holder releases. Correctness does not depend on it, because the waiter
    /// cannot finish while the key is held. It only makes it likely that the
    /// blocking path is actually exercised.
    const SETTLE: Duration = Duration::from_millis(20);

    #[test]
    fn try_acquire_and_release() {
        let locks = AdvisoryLocks::new();
        assert!(locks.try_acquire(1, 100));
        assert!(!locks.try_acquire(1, 200), "other conn cannot steal");
        assert!(locks.try_acquire(1, 100), "same conn is reentrant");
        assert!(locks.release(1, 100));
        assert!(!locks.is_held(1));
    }

    #[test]
    fn release_all_drops_only_owned() {
        let locks = AdvisoryLocks::new();
        assert!(locks.try_acquire(1, 100));
        assert!(locks.try_acquire(2, 100));
        assert!(locks.try_acquire(3, 200));
        assert_eq!(locks.release_all(100), 2);
        assert!(!locks.is_held(1));
        assert!(!locks.is_held(2));
        assert!(locks.is_held(3));
    }

    #[test]
    fn release_mismatch_returns_false() {
        let locks = AdvisoryLocks::new();
        assert!(locks.try_acquire(5, 100));
        assert!(!locks.release(5, 999));
        assert!(locks.is_held(5));
    }

    #[test]
    fn acquire_is_reentrant_without_blocking() {
        let locks = AdvisoryLocks::new();
        locks.acquire(9, 100);
        // Would block forever if the owner check in `acquire` were broken.
        locks.acquire(9, 100);
        assert!(!locks.try_acquire(9, 200), "other conn cannot steal");
    }

    #[test]
    fn acquire_blocks_until_release() {
        let locks = AdvisoryLocks::new();
        assert!(locks.try_acquire(7, 100));
        thread::scope(|s| {
            let waiter = s.spawn(|| locks.acquire(7, 200));
            thread::sleep(SETTLE);
            assert!(
                !waiter.is_finished(),
                "acquire must block while another conn holds the key"
            );
            assert!(locks.release(7, 100));
            waiter.join().expect("waiter thread panicked");
        });
        assert!(!locks.try_acquire(7, 100), "key now belongs to the waiter");
        assert!(locks.release(7, 200));
    }

    #[test]
    fn release_all_wakes_waiters() {
        let locks = AdvisoryLocks::new();
        assert!(locks.try_acquire(1, 100));
        assert!(locks.try_acquire(2, 100));
        thread::scope(|s| {
            let waiter = s.spawn(|| {
                locks.acquire(1, 200);
                locks.acquire(2, 200);
            });
            thread::sleep(SETTLE);
            assert!(
                !waiter.is_finished(),
                "acquire must block while another conn holds the key"
            );
            assert_eq!(locks.release_all(100), 2);
            waiter.join().expect("waiter thread panicked");
        });
        assert_eq!(locks.release_all(200), 2, "waiter took both keys");
    }

    #[test]
    fn global_is_a_singleton() {
        assert!(std::ptr::eq(global(), global()));
    }
}