// batch_lock — lib.rs

//! A thread-safe key-based locking mechanism for managing concurrent access to resources.
//!
//! This crate provides a flexible lock manager that allows fine-grained locking based on keys.
//! It supports both single-key and batch locking operations with RAII-style lock guards.
//!
//! # Features
//!
//! - Thread-safe key-based locking
//! - Support for both single-key and batch locking operations
//! - RAII-style lock guards
//! - Configurable capacity and sharding
//!
//! # Examples
//!
//! Single-key locking:
//! ```
//! use batch_lock::LockManager;
//!
//! let lock_manager = LockManager::<String>::new();
//! let key = "resource_1".to_string();
//! let guard = lock_manager.lock(&key);
//! // Critical section - exclusive access guaranteed
//! // Guard automatically releases the lock when dropped
//! ```
//!
//! Batch locking:
//! ```
//! use batch_lock::LockManager;
//! use std::collections::BTreeSet;
//!
//! let lock_manager = LockManager::<String>::new();
//! let mut keys = BTreeSet::new();
//! keys.insert("resource_1".to_string());
//! keys.insert("resource_2".to_string());
//!
//! let guard = lock_manager.batch_lock(&keys);
//! // Critical section - exclusive access to all keys guaranteed
//! // Guard automatically releases all locks when dropped
//! ```
40use dashmap::{DashMap, Entry};
41use std::collections::{BTreeSet, LinkedList};
42use std::hash::Hash;
43use std::sync::atomic::{AtomicU32, Ordering};
44
/// Raw pointer to a waiter's futex word: an `AtomicU32` that lives on the
/// stack of a thread blocked in `raw_lock`. Protocol: 0 = still waiting,
/// 1 = lock handed over (see `WaiterPtr::wake_up`).
struct WaiterPtr(*const AtomicU32);
46
impl WaiterPtr {
    /// Hands the lock to the waiting thread: publish 1 into its futex word,
    /// then wake it. The store must precede the wake so that the waiter's
    /// `while waiter.load(..) == 0` loop in `raw_lock` observes 1 and exits.
    fn wake_up(self) {
        let ptr = self.0;
        // SAFETY: the pointee is the `AtomicU32` on the stack of a thread
        // currently blocked in `raw_lock`; that frame stays alive until the
        // thread observes the value 1, which only this store publishes.
        let waiter = unsafe { &*ptr };
        waiter.store(1, Ordering::Release);
        // NOTE(review): after the store above, the waiter may return and pop
        // its stack frame, leaving `ptr` dangling. This assumes
        // `atomic_wait::wake_one` is address-based (futex-style) and never
        // dereferences the pointer — confirm against the atomic_wait docs.
        atomic_wait::wake_one(ptr);
    }
}
55
// SAFETY: the raw pointer is only dereferenced in `wake_up`, and the pointee
// is kept alive by the blocked owner thread until the hand-over completes
// (see `raw_lock`), so moving/sharing the pointer across threads is sound.
unsafe impl Sync for WaiterPtr {}
unsafe impl Send for WaiterPtr {}
58
/// A thread-safe lock manager that provides key-based locking capabilities.
///
/// The `LockManager` allows concurrent access control based on keys of type `K`.
/// It supports both single-key and batch locking operations.
///
/// Type parameter:
/// - `K`: The key type that must implement `Eq + Hash + Clone` traits
pub struct LockManager<K: Eq + Hash + Clone> {
    // Invariant: a key is present in the map iff it is currently locked.
    // The list holds the threads queued behind the current holder, FIFO:
    // `push_back` in `raw_lock`, `pop_front` in `unlock`.
    map: DashMap<K, LinkedList<WaiterPtr>>,
}
69
70impl<K: Eq + Hash + Clone> LockManager<K> {
71    /// Creates a new `LockManager` instance with default capacity.
72    pub fn new() -> Self {
73        Self {
74            map: DashMap::new(),
75        }
76    }
77
78    /// Creates a new `LockManager` with the specified capacity.
79    ///
80    /// # Arguments
81    /// * `capacity` - The initial capacity for the internal map
82    pub fn with_capacity(capacity: usize) -> Self {
83        Self {
84            map: DashMap::with_capacity(capacity),
85        }
86    }
87
88    /// Creates a new `LockManager` with specified capacity and shard amount.
89    ///
90    /// # Arguments
91    /// * `capacity` - The initial capacity for the internal map
92    /// * `shard_amount` - The number of shards to use for internal concurrency
93    pub fn with_capacity_and_shard_amount(capacity: usize, shard_amount: usize) -> Self {
94        Self {
95            map: DashMap::with_capacity_and_shard_amount(capacity, shard_amount),
96        }
97    }
98
99    /// Acquires a lock for a single key.
100    ///
101    /// This method will block until the lock can be acquired.
102    ///
103    /// # Arguments
104    /// * `key` - The key to lock
105    ///
106    /// # Returns
107    /// Returns a `LockGuard` that will automatically release the lock when dropped
108    pub fn lock<'a, 'b>(&'a self, key: &'b K) -> LockGuard<'a, 'b, K> {
109        self.raw_lock(key);
110        LockGuard::<'a, 'b, K> { map: self, key }
111    }
112
113    /// Acquires locks for multiple keys atomically.
114    ///
115    /// This method will block until all locks can be acquired. The locks are acquired
116    /// in a consistent order to prevent deadlocks.
117    ///
118    /// # Arguments
119    /// * `keys` - A `BTreeSet` containing the keys to lock
120    ///
121    /// # Returns
122    /// Returns a `BatchLockGuard` that will automatically release all locks when dropped
123    pub fn batch_lock<'a, 'b>(&'a self, keys: &'b BTreeSet<K>) -> BatchLockGuard<'a, 'b, K> {
124        for key in keys {
125            self.raw_lock(key);
126        }
127        BatchLockGuard::<'a, 'b, K> { map: self, keys }
128    }
129
130    fn raw_lock(&self, key: &K) {
131        let waiter = AtomicU32::new(0);
132        match self.map.entry(key.clone()) {
133            Entry::Occupied(mut occupied_entry) => {
134                occupied_entry.get_mut().push_back(WaiterPtr(&waiter as _));
135            }
136            Entry::Vacant(vacant_entry) => {
137                vacant_entry.insert(Default::default());
138                waiter.store(1, Ordering::Release);
139            }
140        };
141        while waiter.load(Ordering::Acquire) == 0 {
142            atomic_wait::wait(&waiter, 0);
143        }
144    }
145
146    fn unlock(&self, key: &K) {
147        match self.map.entry(key.clone()) {
148            Entry::Occupied(mut occupied_entry) => match occupied_entry.get_mut().pop_front() {
149                Some(waiter) => {
150                    waiter.wake_up();
151                }
152                None => {
153                    occupied_entry.remove();
154                }
155            },
156            Entry::Vacant(_) => panic!("impossible: unlock a non-existent key!"),
157        }
158    }
159
160    fn batch_unlock(&self, keys: &BTreeSet<K>) {
161        for key in keys.iter().rev() {
162            self.unlock(key);
163        }
164    }
165}
166
167impl<K: Eq + Hash + Clone> Default for LockManager<K> {
168    fn default() -> Self {
169        Self::new()
170    }
171}
172
173/// RAII guard for a single locked key.
174///
175/// When this guard is dropped, the lock will be automatically released.
176pub struct LockGuard<'a, 'b, K: Eq + Hash + Clone> {
177    map: &'a LockManager<K>,
178    key: &'b K,
179}
180
impl<'a, 'b, K: Eq + Hash + Clone> Drop for LockGuard<'a, 'b, K> {
    /// Releases the key's lock, waking the next queued waiter if any.
    fn drop(&mut self) {
        self.map.unlock(self.key);
    }
}
186
187/// RAII guard for multiple locked keys.
188///
189/// When this guard is dropped, all locks will be automatically released
190/// in the reverse order they were acquired.
191pub struct BatchLockGuard<'a, 'b, K: Eq + Hash + Clone> {
192    map: &'a LockManager<K>,
193    keys: &'b BTreeSet<K>,
194}
195
impl<'a, 'b, K: Eq + Hash + Clone> Drop for BatchLockGuard<'a, 'b, K> {
    /// Releases all held keys (in descending key order — see `batch_unlock`).
    fn drop(&mut self) {
        self.map.batch_unlock(self.keys);
    }
}
201
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{atomic::AtomicUsize, Arc};

    /// Mutual exclusion on a single key: M threads contend for key `1`;
    /// `current` counts threads inside the critical section and must never
    /// exceed 1, while `total` verifies no iteration was lost.
    #[test]
    fn test_lock_map_same_key() {
        let lock_map = Arc::new(LockManager::<u32>::new());
        let total = Arc::new(AtomicUsize::default());
        let current = Arc::new(AtomicU32::default());
        const N: usize = 1 << 12;
        const M: usize = 8;

        let threads = (0..M)
            .map(|_| {
                let lock_map = lock_map.clone();
                let total = total.clone();
                let current = current.clone();
                std::thread::spawn(move || {
                    for _ in 0..N {
                        let _guard = lock_map.lock(&1);
                        // We must be alone inside the critical section:
                        // the counter goes 0 -> 1 on entry...
                        let now = current.fetch_add(1, Ordering::AcqRel);
                        assert_eq!(now, 0);
                        total.fetch_add(1, Ordering::AcqRel);
                        // ...and 1 -> 0 on exit, still under the lock.
                        let now = current.fetch_sub(1, Ordering::AcqRel);
                        assert_eq!(now, 1);
                    }
                })
            })
            .collect::<Vec<_>>();
        threads.into_iter().for_each(|t| t.join().unwrap());
        // Every iteration of every thread made it through the lock.
        assert_eq!(total.load(Ordering::Acquire), N * M);
    }

    /// Stress test with mostly-disjoint random keys: exercises concurrent
    /// insertion/removal of many map entries rather than queueing depth.
    #[test]
    fn test_lock_map_random_key() {
        let lock_map = Arc::new(LockManager::<u32>::with_capacity(128));
        let total = Arc::new(AtomicUsize::default());
        const N: usize = 1 << 20;
        const M: usize = 8;

        let threads = (0..M)
            .map(|_| {
                let lock_map = lock_map.clone();
                let total = total.clone();
                std::thread::spawn(move || {
                    for _ in 0..N {
                        let key = rand::random();
                        let _guard = lock_map.lock(&key);
                        total.fetch_add(1, Ordering::AcqRel);
                    }
                })
            })
            .collect::<Vec<_>>();
        threads.into_iter().for_each(|t| t.join().unwrap());
        assert_eq!(total.load(Ordering::Acquire), N * M);
    }

    /// Batch locking with overlapping key sets: thread `i` locks every key
    /// except `i`, so any two threads' sets intersect and their critical
    /// sections must be mutually exclusive (checked via `current`).
    #[test]
    fn test_batch_lock() {
        let lock_map = Arc::new(LockManager::<usize>::with_capacity_and_shard_amount(
            128, 16,
        ));
        let total = Arc::new(AtomicUsize::default());
        let current = Arc::new(AtomicU32::default());
        const N: usize = 1 << 12;
        const M: usize = 8;

        let threads = (0..M)
            .map(|i| {
                let lock_map = lock_map.clone();
                let total = total.clone();
                let current = current.clone();
                // All keys in 0..M except this thread's own index.
                let state = (0..M).filter(|v| *v != i).collect::<BTreeSet<_>>();
                std::thread::spawn(move || {
                    for _ in 0..N {
                        let _guard = lock_map.batch_lock(&state);
                        let now = current.fetch_add(1, Ordering::AcqRel);
                        assert_eq!(now, 0);
                        total.fetch_add(1, Ordering::AcqRel);
                        let now = current.fetch_sub(1, Ordering::AcqRel);
                        assert_eq!(now, 1);
                    }
                })
            })
            .collect::<Vec<_>>();
        threads.into_iter().for_each(|t| t.join().unwrap());
        assert_eq!(total.load(Ordering::Acquire), N * M);
    }

    /// Dropping a guard for a key that was never locked must hit the
    /// "non-existent key" panic in `unlock` (guard built directly, bypassing
    /// `lock`, which is possible here because tests live in the same crate).
    #[test]
    #[should_panic(expected = "impossible: unlock a non-existent key!")]
    fn test_invalid_unlock() {
        let lock_map = LockManager::<u32>::default();
        let _lock_guard = LockGuard {
            map: &lock_map,
            key: &42,
        };
    }
}
301}