use std::{
hash::{BuildHasher, BuildHasherDefault, DefaultHasher, Hash},
marker::PhantomData,
num::NonZeroUsize,
sync::{Mutex, MutexGuard},
};
use crate::batch::{KeyBatch, MAX_BATCH_KEYS};
/// Error returned when a stripe mutex could not be acquired because it is poisoned.
///
/// A `std::sync::Mutex` becomes poisoned when a thread panics while holding its guard.
/// The error deliberately carries no guard: callers that want to keep using the stripe
/// have to clear the poison flag first.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StripedPoisonError;

impl std::fmt::Display for StripedPoisonError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("striped lock stripe is poisoned")
    }
}

// Implementing `Error` lets callers use `?` with `Box<dyn Error>` and similar wrappers.
impl std::error::Error for StripedPoisonError {}
/// Guard for a single stripe of a striped lock.
///
/// The stripe stays locked for as long as this value is alive and is released when it is
/// dropped.
#[must_use = "the stripe is unlocked as soon as the guard is dropped"]
#[derive(Debug)]
pub struct StripedLockGuard<'l> {
    // Never read: it exists only so that dropping the wrapper releases the mutex.
    _guard: MutexGuard<'l, ()>,
}
/// Guard for every stripe acquired by one batch lock operation.
///
/// All held stripes stay locked for as long as this value is alive and are released
/// together when it is dropped.
#[must_use = "the stripes are unlocked as soon as the guard is dropped"]
#[derive(Debug)]
pub struct StripedBatchLockGuard<'l> {
    // Fixed-capacity storage, so taking a batch never allocates. A slot is `None` when it
    // is unused or when its key shares a stripe that an earlier slot already holds.
    _guards: [Option<MutexGuard<'l, ()>>; MAX_BATCH_KEYS],
}
/// A fixed set of mutexes ("stripes") addressed by key hash.
///
/// A key of type `K` is hashed with a hasher built by `H` and the hash is reduced modulo
/// the number of stripes, so equal keys always map to the same mutex while unrelated keys
/// usually map to different ones.
pub struct StripedLock<K, H = BuildHasherDefault<DefaultHasher>>
where
    K: Hash,
    H: BuildHasher,
{
    /// Builds the hashers that turn keys into stripe selectors.
    hasher_builder: H,
    /// The stripes themselves; the slice length is fixed at construction.
    locks: Box<[Mutex<()>]>,
    /// Ties the lock to its key type without storing a key.
    ///
    /// `fn() -> K` rather than `K`: it stays covariant in `K` but is always `Send + Sync`,
    /// so the key type cannot prevent the lock from being shared between threads.
    phantom: PhantomData<fn() -> K>,
}
impl<K: Hash> StripedLock<K, BuildHasherDefault<DefaultHasher>> {
    /// Creates a striped lock with `locks` stripes, hashing keys with the default
    /// `BuildHasherDefault<DefaultHasher>` hasher builder.
    pub fn new(locks: NonZeroUsize) -> Self {
        let hasher_builder = BuildHasherDefault::<DefaultHasher>::default();
        Self::with_hasher(hasher_builder, locks)
    }
}
impl<K, H> StripedLock<K, H>
where
    K: Hash,
    H: BuildHasher,
{
    /// Creates a striped lock with `locks` stripes whose keys are hashed by `hasher_builder`.
    ///
    /// Taking a [`NonZeroUsize`] guarantees at least one stripe, so reducing a hash modulo
    /// the stripe count can never divide by zero.
    pub fn with_hasher(hasher_builder: H, locks: NonZeroUsize) -> Self {
        let locks: Box<[Mutex<()>]> = (0..locks.get()).map(|_| Mutex::new(())).collect();
        Self {
            hasher_builder,
            locks,
            phantom: PhantomData,
        }
    }

    /// Locks the stripe that `key` hashes to, blocking until it is available.
    ///
    /// Different keys may share a stripe. Because the stripes are plain `std::sync::Mutex`es,
    /// locking a second key on the same thread while a guard is still alive can deadlock or
    /// panic if both keys map to the same stripe.
    ///
    /// # Errors
    ///
    /// Returns [`StripedPoisonError`] if the stripe is poisoned.
    pub fn lock(&self, key: K) -> Result<StripedLockGuard<'_>, StripedPoisonError> {
        let hash = self.hasher_builder.hash_one(key);
        stripe_for(&self.locks, hash)
            .lock()
            .map(|guard| StripedLockGuard { _guard: guard })
            .map_err(|_| StripedPoisonError)
    }

    /// Locks the stripes of every key in `batch` and returns one guard holding all of them.
    ///
    /// The key hashes are obtained from `batch.into_hash_array(&hasher_builder)`, reduced to
    /// stripe indices, sorted, and locked in ascending order; a stripe shared by several keys
    /// is locked only once.
    ///
    /// # Errors
    ///
    /// Returns [`StripedPoisonError`] if any required stripe is poisoned. Stripes acquired
    /// before the failure are released again before returning.
    ///
    /// # Panics
    ///
    /// Panics if the batch reports zero keys or more than `MAX_BATCH_KEYS` keys.
    pub fn lock_batch<B>(&self, batch: B) -> Result<StripedBatchLockGuard<'_>, StripedPoisonError>
    where
        B: KeyBatch<K, H>,
    {
        // Non-generic worker: keeps the locking logic from being duplicated per batch type.
        fn inner<'l>(
            locks: &'l [Mutex<()>],
            stripes: &mut [u64],
        ) -> Result<StripedBatchLockGuard<'l>, StripedPoisonError> {
            assert!(!stripes.is_empty(), "lock_batch requires at least one key");
            assert!(
                stripes.len() <= MAX_BATCH_KEYS,
                "lock_batch supports at most MAX_BATCH_KEYS keys"
            );

            // Turn raw hashes into stripe indices in place.
            let stripe_count = locks.len() as u64;
            for hash in stripes.iter_mut() {
                *hash %= stripe_count;
            }
            // Every batch call acquires in ascending stripe order; that consistent order is
            // what keeps two overlapping batches from deadlocking against each other.
            stripes.sort_unstable();

            let mut guards: [Option<MutexGuard<'l, ()>>; MAX_BATCH_KEYS] =
                std::array::from_fn(|_| None);
            let mut previous = None;
            for (slot, &stripe) in guards.iter_mut().zip(stripes.iter()) {
                // After sorting, duplicates are adjacent. Re-locking a stripe this thread
                // already holds must be skipped because `Mutex` is not reentrant.
                if previous != Some(stripe) {
                    // `stripe < locks.len()`, so the cast back to `usize` is lossless.
                    let guard = locks[stripe as usize]
                        .lock()
                        .map_err(|_| StripedPoisonError)?;
                    *slot = Some(guard);
                    previous = Some(stripe);
                }
            }
            Ok(StripedBatchLockGuard { _guards: guards })
        }

        let (mut hashes, filled) = batch.into_hash_array(&self.hasher_builder);
        inner(&self.locks, &mut hashes[..filled])
    }

    /// Reports whether the stripe that `key` hashes to is currently poisoned.
    pub fn is_poisoned(&self, key: K) -> bool {
        let hash = self.hasher_builder.hash_one(key);
        stripe_for(&self.locks, hash).is_poisoned()
    }

    /// Clears the poisoned flag of the stripe that `key` hashes to.
    ///
    /// The flag belongs to the stripe, so this also affects every other key that maps to
    /// the same stripe.
    pub fn clear_poison(&self, key: K) {
        let hash = self.hasher_builder.hash_one(key);
        stripe_for(&self.locks, hash).clear_poison();
    }
}

/// Returns the stripe mutex selected by `hash`.
///
/// `locks` must be non-empty; the constructors guarantee this by taking a `NonZeroUsize`.
fn stripe_for(locks: &[Mutex<()>], hash: u64) -> &Mutex<()> {
    // The remainder is strictly smaller than `locks.len()`, so it always fits in `usize`.
    let idx = (hash % locks.len() as u64) as usize;
    &locks[idx]
}