#[cfg(loom)]
use loom::sync::{Mutex, MutexGuard};
#[cfg(not(loom))]
use std::sync::{Mutex, MutexGuard};
use std::collections::HashMap;
use crate::{LockError, LockMode, ResourceId, TxnId};
/// Multiplier used for Fibonacci (multiplicative) hashing of resource ids.
///
/// The value is 2^64 divided by the golden ratio, rounded to an odd integer. The shard
/// for a resource is taken from the top bits of `id * FIB_HASH` (wrapping).
const FIB_HASH: u64 = 0x9E37_79B9_7F4A_7C15;
/// One transaction's hold on a resource.
#[derive(Clone, Copy)]
struct Holder {
/// Transaction that holds the lock.
txn: TxnId,
/// Mode in which the lock is currently held.
mode: LockMode,
}
/// Lock state of a single resource.
struct LockEntry {
/// Every transaction currently holding the resource, in no particular order
/// (removals use `swap_remove`).
holders: Vec<Holder>,
}
impl LockEntry {
    /// Creates an entry that has no holders yet.
    #[inline]
    fn new() -> Self {
        let holders = Vec::new();
        Self { holders }
    }
}
/// Mutable state of one shard; always accessed through the shard's mutex.
struct ShardInner {
/// Lock entry for each resource of this shard that currently has at least one holder.
locks: HashMap<ResourceId, LockEntry>,
/// Reverse index: for each transaction, the resources of this shard it holds.
/// Used by `release_all` to avoid scanning every lock entry.
by_txn: HashMap<TxnId, Vec<ResourceId>>,
}
impl ShardInner {
    /// Creates shard state with no locks and an empty per-transaction index.
    fn new() -> Self {
        let locks = HashMap::new();
        let by_txn = HashMap::new();
        Self { locks, by_txn }
    }
}
/// One independently locked partition of the lock table.
struct Shard {
/// Shard state behind its own mutex, so operations on resources in different
/// shards do not contend with each other.
inner: Mutex<ShardInner>,
}
/// Sharded table of per-resource locks held by transactions.
///
/// Each resource is routed to one shard by hashing its id, and every shard guards its
/// own state with a separate mutex.
#[must_use = "a LockManager that is dropped immediately releases every lock it holds"]
pub struct LockManager {
/// The shards; the length is always a power of two.
shards: Box<[Shard]>,
/// Base-2 logarithm of `shards.len()`, i.e. how many top hash bits select a shard.
bits: u32,
}
impl LockManager {
/// Creates a manager sized for the current machine.
///
/// The shard count targets four shards per unit of available parallelism (treated as
/// one when it cannot be queried), kept within `16..=1024`. [`Self::with_shards`] then
/// rounds that target up to a power of two.
pub fn new() -> Self {
    let parallelism = std::thread::available_parallelism().map_or(1, |n| n.get());
    // Clamp *before* any power-of-two rounding: the multiplication saturates, and
    // rounding a saturated value up to a power of two would overflow. After the clamp
    // the value is at most 1024, so the rounding done by `with_shards` is always safe.
    let target = parallelism.saturating_mul(4).clamp(16, 1024);
    Self::with_shards(target)
}
/// Creates a manager with at least `shards` shards.
///
/// The count is rounded up to the next power of two (zero is treated as one) so that a
/// shard can be selected from the top bits of a hash.
///
/// # Panics
///
/// Panics if `shards` is larger than the largest power of two representable in
/// `usize`, because the count could not be rounded up without overflowing.
pub fn with_shards(shards: usize) -> Self {
    /// Largest power of two that fits in `usize`.
    const MAX_SHARDS: usize = 1 << (usize::BITS - 1);

    let requested = shards.max(1);
    // Without this check an overflowing `next_power_of_two` yields 0 in release
    // builds, producing a manager with no shards that fails on first use.
    assert!(
        requested <= MAX_SHARDS,
        "shard count {requested} cannot be rounded up to a power of two",
    );
    let n = requested.next_power_of_two();
    let bits = n.trailing_zeros();
    let table = (0..n)
        .map(|_| Shard {
            inner: Mutex::new(ShardInner::new()),
        })
        .collect::<Vec<_>>()
        .into_boxed_slice();
    Self {
        shards: table,
        bits,
    }
}
/// Returns the number of shards in this manager.
#[inline]
#[must_use]
pub fn shards(&self) -> usize {
self.shards.len()
}
/// Tries to take `res` for `txn` in `mode`, failing instead of waiting for other holders.
///
/// * If `txn` already holds `res` in a mode that covers `mode`, nothing changes.
/// * If `txn` is the only holder, its held mode is replaced by `mode`.
/// * If `txn` does not hold `res`, it is added as a holder provided every existing
///   holder's mode is compatible with `mode`.
///
/// # Errors
///
/// Returns [`LockError::Conflict`] when the request cannot be granted: either an
/// upgrade is requested while other transactions also hold the resource, or an
/// existing holder's mode is incompatible with `mode`.
pub fn try_acquire(
    &self,
    txn: TxnId,
    res: ResourceId,
    mode: LockMode,
) -> Result<(), LockError> {
    let mut shard = self.lock_shard(res);
    let ShardInner { locks, by_txn } = &mut *shard;
    let holders = &mut locks.entry(res).or_insert_with(LockEntry::new).holders;

    // Request from a transaction that already holds the resource.
    let holder_count = holders.len();
    if let Some(own) = holders.iter_mut().find(|h| h.txn == txn) {
        return if own.mode.covers(mode) {
            Ok(())
        } else if holder_count == 1 {
            // NOTE(review): the held mode is overwritten, not combined; this assumes
            // `mode` is the stronger one whenever `covers` is false — confirm against
            // `LockMode`.
            own.mode = mode;
            Ok(())
        } else {
            Err(LockError::Conflict)
        };
    }

    // First request from this transaction: every current holder must tolerate it.
    // A freshly inserted entry has no holders, so it never takes this early return.
    if holders.iter().any(|h| !h.mode.compatible_with(mode)) {
        return Err(LockError::Conflict);
    }
    holders.push(Holder { txn, mode });
    by_txn.entry(txn).or_default().push(res);
    Ok(())
}
/// Releases the lock `txn` holds on `res`.
///
/// The resource's entry is dropped once its last holder is gone, and the resource is
/// removed from the transaction's reverse index.
///
/// # Errors
///
/// Returns [`LockError::NotHeld`] if `txn` does not currently hold `res`.
pub fn release(&self, txn: TxnId, res: ResourceId) -> Result<(), LockError> {
    let mut shard = self.lock_shard(res);
    let ShardInner { locks, by_txn } = &mut *shard;

    let entry = locks.get_mut(&res).ok_or(LockError::NotHeld)?;
    let idx = entry
        .holders
        .iter()
        .position(|h| h.txn == txn)
        .ok_or(LockError::NotHeld)?;

    // Holder order carries no meaning, so the O(1) removal is fine.
    let _ = entry.holders.swap_remove(idx);
    if entry.holders.is_empty() {
        let _ = locks.remove(&res);
    }
    Self::forget_resource(by_txn, txn, res);
    Ok(())
}
/// Releases every lock held by `txn` and returns how many were released.
///
/// Shards are visited and locked one at a time, so the release is not atomic across
/// shards.
pub fn release_all(&self, txn: TxnId) -> usize {
    let mut total = 0usize;
    for shard in self.shards.iter() {
        let mut inner = Self::lock(shard);
        let ShardInner { locks, by_txn } = &mut *inner;

        // Taking the transaction's list out of the index also clears the index entry;
        // a shard where the transaction holds nothing yields an empty iteration.
        for res in by_txn.remove(&txn).into_iter().flatten() {
            let Some(entry) = locks.get_mut(&res) else {
                continue;
            };
            let Some(idx) = entry.holders.iter().position(|h| h.txn == txn) else {
                continue;
            };
            let _ = entry.holders.swap_remove(idx);
            total += 1;
            if entry.holders.is_empty() {
                let _ = locks.remove(&res);
            }
        }
    }
    total
}
/// Returns how many transactions currently hold `res` (zero when it is not locked).
#[must_use]
pub fn holder_count(&self, res: ResourceId) -> usize {
    let shard = self.lock_shard(res);
    match shard.locks.get(&res) {
        Some(entry) => entry.holders.len(),
        None => 0,
    }
}
/// Returns the mode in which `txn` holds `res`, or `None` if it does not hold it.
#[must_use]
pub fn mode_held(&self, txn: TxnId, res: ResourceId) -> Option<LockMode> {
    let shard = self.lock_shard(res);
    let entry = shard.locks.get(&res)?;
    entry
        .holders
        .iter()
        .find_map(|h| (h.txn == txn).then_some(h.mode))
}
/// Removes `res` from the list of resources recorded for `txn` in a shard's reverse
/// index, dropping the transaction's entry once its list is empty.
///
/// Does nothing when `txn` has no entry in `by_txn`.
#[inline]
fn forget_resource(by_txn: &mut HashMap<TxnId, Vec<ResourceId>>, txn: TxnId, res: ResourceId) {
    let Some(owned) = by_txn.get_mut(&txn) else {
        return;
    };
    if let Some(idx) = owned.iter().position(|candidate| *candidate == res) {
        let _ = owned.swap_remove(idx);
    }
    if owned.is_empty() {
        let _ = by_txn.remove(&txn);
    }
}
/// Locks and returns the state of the shard that `res` is routed to.
#[inline]
fn lock_shard(&self, res: ResourceId) -> MutexGuard<'_, ShardInner> {
    let idx = self.shard_index(res);
    Self::lock(&self.shards[idx])
}
/// Locks a shard's mutex, ignoring poisoning.
///
/// If the mutex is poisoned, the guard is recovered from the error instead of
/// propagating it.
#[inline]
fn lock(shard: &Shard) -> MutexGuard<'_, ShardInner> {
    shard
        .inner
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
}
/// Maps a resource to the index of its shard.
///
/// The resource id is multiplied by `FIB_HASH` (wrapping) and the top `bits` bits of
/// the product are used as the index, so the result is always below `shards.len()`.
#[inline]
fn shard_index(&self, res: ResourceId) -> usize {
    match self.bits {
        // A single shard needs no hashing; this also avoids shifting a `u64` by 64.
        0 => 0,
        bits => {
            let mixed = res.get().wrapping_mul(FIB_HASH);
            (mixed >> (u64::BITS - bits)) as usize
        }
    }
}
}
impl Default for LockManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(all(test, not(loom)))]
#[allow(clippy::unwrap_used)]
mod tests {
use super::{FIB_HASH, LockManager};
use crate::{LockError, LockMode, ResourceId, TxnId};
/// Builds a `(TxnId, ResourceId)` pair from raw numbers.
fn ids(t: u64, r: u64) -> (TxnId, ResourceId) {
    let txn = TxnId::new(t);
    let res = ResourceId::new(r);
    (txn, res)
}
/// Several transactions can hold the same resource in shared mode at once.
#[test]
fn test_shared_locks_coexist() {
    let manager = LockManager::new();
    let res = ResourceId::new(1);
    for raw in 1..=3 {
        manager
            .try_acquire(TxnId::new(raw), res, LockMode::Shared)
            .unwrap();
    }
    assert_eq!(manager.holder_count(res), 3);
}
/// A shared request from another transaction conflicts with an exclusive holder.
#[test]
fn test_exclusive_excludes_shared() {
    let manager = LockManager::new();
    let (writer, res) = ids(1, 1);
    manager
        .try_acquire(writer, res, LockMode::Exclusive)
        .unwrap();

    let outcome = manager.try_acquire(TxnId::new(2), res, LockMode::Shared);
    assert_eq!(outcome, Err(LockError::Conflict));
}
/// A second exclusive request from another transaction conflicts with the first.
#[test]
fn test_exclusive_excludes_exclusive() {
    let manager = LockManager::new();
    let (writer, res) = ids(1, 1);
    manager
        .try_acquire(writer, res, LockMode::Exclusive)
        .unwrap();

    let outcome = manager.try_acquire(TxnId::new(2), res, LockMode::Exclusive);
    assert_eq!(outcome, Err(LockError::Conflict));
}
/// An exclusive request from another transaction conflicts with a shared holder.
#[test]
fn test_shared_blocks_other_exclusive() {
    let manager = LockManager::new();
    let (reader, res) = ids(1, 1);
    manager.try_acquire(reader, res, LockMode::Shared).unwrap();

    let outcome = manager.try_acquire(TxnId::new(2), res, LockMode::Exclusive);
    assert_eq!(outcome, Err(LockError::Conflict));
}
/// Acquiring the same mode twice succeeds and leaves a single holder record.
#[test]
fn test_reacquire_same_mode_is_idempotent() {
    let manager = LockManager::new();
    let (txn, res) = ids(1, 1);
    for _ in 0..2 {
        manager.try_acquire(txn, res, LockMode::Shared).unwrap();
    }
    assert_eq!(manager.holder_count(res), 1);
}
/// Requesting shared while already holding exclusive keeps the exclusive hold.
#[test]
fn test_request_weaker_than_held_is_noop() {
    let manager = LockManager::new();
    let (txn, res) = ids(1, 1);
    manager.try_acquire(txn, res, LockMode::Exclusive).unwrap();
    manager.try_acquire(txn, res, LockMode::Shared).unwrap();

    let held = manager.mode_held(txn, res);
    assert_eq!(held, Some(LockMode::Exclusive));
    let holders = manager.holder_count(res);
    assert_eq!(holders, 1);
}
/// The only holder of a resource can upgrade from shared to exclusive.
#[test]
fn test_upgrade_sole_holder_succeeds() {
    let manager = LockManager::new();
    let (txn, res) = ids(1, 1);
    manager.try_acquire(txn, res, LockMode::Shared).unwrap();
    manager.try_acquire(txn, res, LockMode::Exclusive).unwrap();

    let held = manager.mode_held(txn, res);
    assert_eq!(held, Some(LockMode::Exclusive));
    let holders = manager.holder_count(res);
    assert_eq!(holders, 1);
}
/// An upgrade conflicts while another reader is present, and the shared hold is kept.
#[test]
fn test_upgrade_blocked_by_other_reader() {
    let manager = LockManager::new();
    let res = ResourceId::new(1);
    let (upgrader, bystander) = (TxnId::new(1), TxnId::new(2));
    manager.try_acquire(upgrader, res, LockMode::Shared).unwrap();
    manager
        .try_acquire(bystander, res, LockMode::Shared)
        .unwrap();

    let attempt = manager.try_acquire(upgrader, res, LockMode::Exclusive);
    assert_eq!(attempt, Err(LockError::Conflict));
    assert_eq!(manager.mode_held(upgrader, res), Some(LockMode::Shared));
}
/// An exclusive request succeeds only after every shared holder has released.
#[test]
fn test_release_frees_resource_for_exclusive() {
    let manager = LockManager::new();
    let res = ResourceId::new(1);
    let (first, second, writer) = (TxnId::new(1), TxnId::new(2), TxnId::new(3));
    manager.try_acquire(first, res, LockMode::Shared).unwrap();
    manager.try_acquire(second, res, LockMode::Shared).unwrap();

    manager.release(first, res).unwrap();
    // One reader is still present, so the writer must be refused with the specific
    // conflict error (not just "some error"), matching the other conflict tests.
    assert_eq!(
        manager.try_acquire(writer, res, LockMode::Exclusive),
        Err(LockError::Conflict)
    );

    manager.release(second, res).unwrap();
    manager
        .try_acquire(writer, res, LockMode::Exclusive)
        .unwrap();
}
/// Releasing a lock that is not held reports `NotHeld`, both for an unlocked resource
/// and for a transaction that is not among the holders.
#[test]
fn test_release_not_held_errors() {
    let manager = LockManager::new();
    let (holder, res) = ids(1, 1);
    let before_any_lock = manager.release(holder, res);
    assert_eq!(before_any_lock, Err(LockError::NotHeld));

    manager.try_acquire(holder, res, LockMode::Shared).unwrap();
    let stranger = TxnId::new(9);
    assert_eq!(manager.release(stranger, res), Err(LockError::NotHeld));
}
/// Releasing the same lock a second time reports `NotHeld`.
#[test]
fn test_double_release_errors() {
    let manager = LockManager::new();
    let (txn, res) = ids(1, 1);
    manager.try_acquire(txn, res, LockMode::Exclusive).unwrap();
    manager.release(txn, res).unwrap();

    let second_release = manager.release(txn, res);
    assert_eq!(second_release, Err(LockError::NotHeld));
}
/// `release_all` releases every lock of a transaction across shards and reports the count.
#[test]
fn test_release_all_drops_every_lock() {
    let manager = LockManager::with_shards(8);
    let txn = TxnId::new(1);
    for raw in 0..50 {
        let res = ResourceId::new(raw);
        manager.try_acquire(txn, res, LockMode::Exclusive).unwrap();
    }

    assert_eq!(manager.release_all(txn), 50);
    for raw in 0..50 {
        let remaining = manager.holder_count(ResourceId::new(raw));
        assert_eq!(remaining, 0);
    }
    // Nothing is left to release the second time.
    assert_eq!(manager.release_all(txn), 0);
}
/// `release_all` for one transaction does not disturb another holder of the same resource.
#[test]
fn test_release_all_leaves_other_txns_alone() {
    let manager = LockManager::new();
    let res = ResourceId::new(1);
    let (leaving, staying) = (TxnId::new(1), TxnId::new(2));
    manager.try_acquire(leaving, res, LockMode::Shared).unwrap();
    manager.try_acquire(staying, res, LockMode::Shared).unwrap();

    let released = manager.release_all(leaving);
    assert_eq!(released, 1);
    assert_eq!(manager.mode_held(staying, res), Some(LockMode::Shared));
    assert_eq!(manager.holder_count(res), 1);
}
/// After its only holder releases, a resource has no holders and can be taken
/// exclusively by another transaction.
#[test]
fn test_resource_fully_released_can_be_taken_exclusively() {
    let manager = LockManager::new();
    let res = ResourceId::new(42);
    let (previous, next) = (TxnId::new(1), TxnId::new(2));

    manager
        .try_acquire(previous, res, LockMode::Exclusive)
        .unwrap();
    manager.release(previous, res).unwrap();
    assert_eq!(manager.holder_count(res), 0);

    manager.try_acquire(next, res, LockMode::Exclusive).unwrap();
}
/// `with_shards` rounds the requested count up to a power of two; zero becomes one.
#[test]
fn test_with_shards_rounds_up_to_power_of_two() {
    let cases = [(1, 1), (3, 4), (5, 8), (0, 1), (64, 64)];
    for (requested, expected) in cases {
        assert_eq!(LockManager::with_shards(requested).shards(), expected);
    }
}
/// With a single shard every resource maps to index zero.
#[test]
fn test_single_shard_routes_everything_to_index_zero() {
    let manager = LockManager::with_shards(1);
    for raw in 0..1000 {
        let idx = manager.shard_index(ResourceId::new(raw));
        assert_eq!(idx, 0);
    }
}
/// Shard indices never reach the shard count.
#[test]
fn test_shard_index_within_bounds() {
    let manager = LockManager::with_shards(16);
    for raw in 0..10_000 {
        let idx = manager.shard_index(ResourceId::new(raw));
        assert!(idx < 16);
    }
}
/// Sequential resource ids reach every one of the 16 shards.
#[test]
fn test_sequential_ids_spread_across_shards() {
    let manager = LockManager::with_shards(16);
    let mut visited = [false; 16];
    for raw in 0..256 {
        let idx = manager.shard_index(ResourceId::new(raw));
        visited[idx] = true;
    }
    assert!(!visited.contains(&false));
}
/// Exclusive locks on resources that live in different shards do not interfere.
#[test]
fn test_locks_in_different_shards_are_independent() {
    let manager = LockManager::with_shards(16);
    let a = ResourceId::new(1);
    let b = ResourceId::new(2);
    // Verify the premise: without this the test would still pass if both resources
    // happened to share a shard, and would then say nothing about shard independence.
    assert_ne!(manager.shard_index(a), manager.shard_index(b));

    manager
        .try_acquire(TxnId::new(1), a, LockMode::Exclusive)
        .unwrap();
    manager
        .try_acquire(TxnId::new(2), b, LockMode::Exclusive)
        .unwrap();
    assert_eq!(manager.holder_count(a), 1);
    assert_eq!(manager.holder_count(b), 1);
}
/// The hash multiplier is odd.
#[test]
fn test_fib_hash_constant_is_odd() {
    let lowest_bit = FIB_HASH % 2;
    assert_eq!(lowest_bit, 1);
}
/// Eight threads repeatedly take and release a shared lock on one resource; every
/// operation succeeds and no holder is left behind.
#[test]
fn test_concurrent_shared_acquire_release_is_consistent() {
    use std::sync::Arc;
    use std::thread;

    let manager = Arc::new(LockManager::new());
    let res = ResourceId::new(7);

    let workers: Vec<_> = (0..8u64)
        .map(|raw| {
            let manager = Arc::clone(&manager);
            thread::spawn(move || {
                let txn = TxnId::new(raw);
                for _ in 0..1000 {
                    manager.try_acquire(txn, res, LockMode::Shared).unwrap();
                    manager.release(txn, res).unwrap();
                }
            })
        })
        .collect();

    for worker in workers {
        worker.join().unwrap();
    }
    assert_eq!(manager.holder_count(res), 0);
}
/// Eight threads compete for an exclusive lock; a shared counter checks that at most
/// one thread is ever inside the locked region.
#[test]
fn test_concurrent_exclusive_is_mutually_exclusive() {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;

    let manager = Arc::new(LockManager::new());
    let in_critical = Arc::new(AtomicUsize::new(0));
    let res = ResourceId::new(11);

    let workers: Vec<_> = (0..8u64)
        .map(|raw| {
            let manager = Arc::clone(&manager);
            let in_critical = Arc::clone(&in_critical);
            thread::spawn(move || {
                let txn = TxnId::new(raw);
                for _ in 0..2000 {
                    // Losing the race is expected; only winners enter the region.
                    if manager
                        .try_acquire(txn, res, LockMode::Exclusive)
                        .is_err()
                    {
                        continue;
                    }
                    let already_inside = in_critical.fetch_add(1, Ordering::SeqCst);
                    assert_eq!(already_inside, 0);
                    in_critical.fetch_sub(1, Ordering::SeqCst);
                    manager.release(txn, res).unwrap();
                }
            })
        })
        .collect();

    for worker in workers {
        worker.join().unwrap();
    }
    assert_eq!(manager.holder_count(res), 0);
}
}