// azoth_core/lock_manager.rs
use std::sync::{Arc, PoisonError, RwLock, RwLockReadGuard, RwLockWriteGuard};
use xxhash_rust::xxh3::xxh3_64;

4pub struct LockManager {
9 stripes: Vec<Arc<RwLock<()>>>,
10 num_stripes: usize,
11}
12
13impl LockManager {
14 pub fn new(num_stripes: usize) -> Self {
19 assert!(num_stripes > 0, "num_stripes must be positive");
20 let stripes = (0..num_stripes)
21 .map(|_| Arc::new(RwLock::new(())))
22 .collect();
23
24 Self {
25 stripes,
26 num_stripes,
27 }
28 }
29
30 fn stripe_index(&self, key: &[u8]) -> usize {
32 let hash = xxh3_64(key);
33 (hash as usize) % self.num_stripes
34 }
35
36 pub fn read_lock(&self, key: &[u8]) -> RwLockReadGuard<'_, ()> {
43 let idx = self.stripe_index(key);
44 self.stripes[idx].read().expect("Lock poisoned")
45 }
46
47 pub fn write_lock(&self, key: &[u8]) -> RwLockWriteGuard<'_, ()> {
54 let idx = self.stripe_index(key);
55 self.stripes[idx].write().expect("Lock poisoned")
56 }
57
58 pub fn num_stripes(&self) -> usize {
60 self.num_stripes
61 }
62}
63
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::thread;

    #[test]
    fn test_lock_manager_basic() {
        let lm = LockManager::new(256);
        assert_eq!(lm.num_stripes(), 256);

        // Read guards for two keys can be held at the same time.
        let _lock1 = lm.read_lock(b"key1");
        let _lock2 = lm.read_lock(b"key2");
    }

    #[test]
    fn test_stripe_distribution() {
        let lm = LockManager::new(256);

        let keys: [&[u8]; 4] = [b"key1", b"key2", b"key3", b""];
        for key in keys {
            let idx = lm.stripe_index(key);
            assert!(idx < 256);
            // The mapping must be stable so every caller locks the same stripe for a key.
            assert_eq!(idx, lm.stripe_index(key));
        }

        // Many distinct keys should spread over many stripes instead of piling onto a few.
        let used: HashSet<usize> = (0..1000)
            .map(|i| lm.stripe_index(format!("key{i}").as_bytes()))
            .collect();
        assert!(used.len() > 128, "only {} stripes used", used.len());
    }

    #[test]
    fn test_single_stripe_maps_every_key_to_zero() {
        let lm = LockManager::new(1);
        assert_eq!(lm.stripe_index(b"a"), 0);
        assert_eq!(lm.stripe_index(b"some other key"), 0);
    }

    #[test]
    #[should_panic(expected = "num_stripes must be positive")]
    fn test_zero_stripes_panics() {
        let _ = LockManager::new(0);
    }

    #[test]
    fn test_concurrent_readers() {
        let lm = Arc::new(LockManager::new(256));
        let idx = lm.stripe_index(b"same_key");

        // Hold a read guard here, then check from another thread that the stripe
        // still admits readers but refuses writers. The try_* calls never block,
        // so a regression fails the test instead of hanging it.
        let _held = lm.read_lock(b"same_key");
        let lm1 = Arc::clone(&lm);
        thread::spawn(move || {
            assert!(lm1.stripes[idx].try_read().is_ok());
            assert!(lm1.stripes[idx].try_write().is_err());
        })
        .join()
        .unwrap();
    }

    #[test]
    fn test_writer_excludes_readers_and_writers() {
        let lm = Arc::new(LockManager::new(256));
        let idx = lm.stripe_index(b"same_key");

        let _held = lm.write_lock(b"same_key");
        let lm1 = Arc::clone(&lm);
        thread::spawn(move || {
            assert!(lm1.stripes[idx].try_read().is_err());
            assert!(lm1.stripes[idx].try_write().is_err());
        })
        .join()
        .unwrap();
    }
}