use crate::error::{AzothError, Result};
use parking_lot::{Mutex, MutexGuard};
use std::collections::BTreeSet;
use std::time::Duration;
use xxhash_rust::xxh3::xxh3_64;
/// Default lock-acquisition timeout, expressed in milliseconds.
///
/// `LockManager::with_stripes` converts this value into the manager's default
/// timeout.
pub const DEFAULT_LOCK_TIMEOUT_MS: u64 = 5_000;
/// Striped lock table: every key is hashed onto one of a fixed number of
/// mutexes ("stripes"), so unrelated keys rarely contend while memory use
/// stays bounded regardless of how many distinct keys are locked.
///
/// Two different keys may map to the same stripe; locking one of them then
/// also excludes the other.
#[derive(Debug)]
pub struct LockManager {
    /// One mutex per stripe; a key's stripe is selected by hashing the key.
    stripes: Vec<Mutex<()>>,
    /// Number of stripes; set by the constructor to the length of `stripes`
    /// and never zero.
    num_stripes: usize,
    /// Timeout used by the acquisition methods that do not take an explicit one.
    default_timeout: Duration,
}
/// RAII guard holding every stripe lock obtained by a multi-key acquisition.
///
/// All held stripes are released when this value is dropped. The guard must
/// therefore be bound to a named variable (for example `let _guard = ...;`)
/// for as long as the protected section runs; discarding it immediately
/// releases the locks.
#[must_use = "the stripe locks are released as soon as the guard is dropped"]
#[derive(Debug)]
pub struct MultiLockGuard<'a> {
    /// Held stripe guards; never read, kept alive only so the locks stay held.
    _guards: Vec<MutexGuard<'a, ()>>,
}
impl LockManager {
pub fn new(num_stripes: usize, default_timeout: Duration) -> Self {
let num_stripes = if num_stripes == 0 {
tracing::warn!(
"LockManager created with num_stripes=0, defaulting to 1. \
This is a configuration error — set ARCANA_KEY_LOCK_STRIPES > 0."
);
1
} else {
num_stripes
};
let stripes = (0..num_stripes).map(|_| Mutex::new(())).collect();
Self {
stripes,
num_stripes,
default_timeout,
}
}
pub fn with_stripes(num_stripes: usize) -> Self {
Self::new(num_stripes, Duration::from_millis(DEFAULT_LOCK_TIMEOUT_MS))
}
fn stripe_index(&self, key: &[u8]) -> usize {
let hash = xxh3_64(key);
(hash as usize) % self.num_stripes
}
pub fn acquire_keys<K: AsRef<[u8]>>(&self, keys: &[K]) -> Result<MultiLockGuard<'_>> {
self.acquire_keys_with_timeout(keys, self.default_timeout)
}
pub fn acquire_keys_with_timeout<K: AsRef<[u8]>>(
&self,
keys: &[K],
timeout: Duration,
) -> Result<MultiLockGuard<'_>> {
if keys.is_empty() {
return Ok(MultiLockGuard {
_guards: Vec::new(),
});
}
let stripe_indices: BTreeSet<usize> =
keys.iter().map(|k| self.stripe_index(k.as_ref())).collect();
let mut guards = Vec::with_capacity(stripe_indices.len());
for stripe_idx in stripe_indices {
match self.stripes[stripe_idx].try_lock_for(timeout) {
Some(guard) => guards.push(guard),
None => {
return Err(AzothError::LockTimeout {
timeout_ms: timeout.as_millis() as u64,
});
}
}
}
Ok(MultiLockGuard { _guards: guards })
}
pub fn lock(&self, key: &[u8]) -> Result<MutexGuard<'_, ()>> {
let idx = self.stripe_index(key);
self.stripes[idx]
.try_lock_for(self.default_timeout)
.ok_or(AzothError::LockTimeout {
timeout_ms: self.default_timeout.as_millis() as u64,
})
}
pub fn num_stripes(&self) -> usize {
self.num_stripes
}
pub fn default_timeout(&self) -> Duration {
self.default_timeout
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Spawns a thread that repeatedly locks `first` and `second` together,
    /// passing them to the manager in exactly that order.
    fn spawn_pair_locker(
        manager: Arc<LockManager>,
        first: &'static [u8],
        second: &'static [u8],
    ) -> thread::JoinHandle<()> {
        thread::spawn(move || {
            for _ in 0..100 {
                let _held = manager.acquire_keys(&[first, second]).unwrap();
                thread::sleep(Duration::from_micros(10));
            }
        })
    }

    /// A manager reports its stripe count and can lock a couple of keys.
    #[test]
    fn test_lock_manager_basic() {
        let manager = LockManager::with_stripes(256);
        assert_eq!(manager.num_stripes(), 256);

        let _held = manager.acquire_keys(&[b"key1", b"key2"]).unwrap();
    }

    /// Stripe indices stay in range and are stable for a given key.
    #[test]
    fn test_stripe_distribution() {
        let manager = LockManager::with_stripes(256);
        let first = manager.stripe_index(b"key1");

        let sample: [&[u8]; 3] = [b"key1", b"key2", b"key3"];
        for key in sample.iter().copied() {
            assert!(manager.stripe_index(key) < 256);
        }

        assert_eq!(first, manager.stripe_index(b"key1"));
    }

    /// Locking no keys at all succeeds.
    #[test]
    fn test_empty_keys() {
        let manager = LockManager::with_stripes(256);
        let no_keys: &[&[u8]] = &[];
        let _held = manager.acquire_keys(no_keys).unwrap();
    }

    /// Repeating a key must not make the call block on its own stripe.
    #[test]
    fn test_duplicate_keys_deduplicated() {
        let manager = LockManager::with_stripes(256);
        let repeated = [b"key1"; 3];
        let _held = manager.acquire_keys(&repeated).unwrap();
    }

    /// Two threads locking the same pair in opposite orders both finish.
    #[test]
    fn test_sorted_acquisition_prevents_deadlock() {
        let shared = Arc::new(LockManager::with_stripes(256));

        let forward = spawn_pair_locker(Arc::clone(&shared), b"key_a", b"key_b");
        let backward = spawn_pair_locker(Arc::clone(&shared), b"key_b", b"key_a");

        forward.join().unwrap();
        backward.join().unwrap();
    }

    /// With a single stripe held, a second acquisition times out.
    #[test]
    fn test_timeout_works() {
        let shared = Arc::new(LockManager::new(1, Duration::from_millis(50)));
        let _held = shared.acquire_keys(&[b"any_key"]).unwrap();

        let contender = Arc::clone(&shared);
        let timed_out = thread::spawn(move || {
            let outcome = contender.acquire_keys(&[b"another_key"]);
            matches!(outcome, Err(AzothError::LockTimeout { .. }))
        })
        .join()
        .unwrap();

        assert!(timed_out, "Should have timed out");
    }

    /// Ten threads, each locking its own key, all complete.
    #[test]
    fn test_concurrent_different_stripes() {
        let shared = Arc::new(LockManager::with_stripes(256));

        let mut workers = Vec::with_capacity(10);
        for i in 0..10 {
            let manager = Arc::clone(&shared);
            workers.push(thread::spawn(move || {
                let key = format!("unique_key_{}", i);
                let _held = manager.acquire_keys(&[key.as_bytes()]).unwrap();
                thread::sleep(Duration::from_millis(10));
            }));
        }

        for worker in workers {
            worker.join().unwrap();
        }
    }
}