keyed_lock/
sync.rs

1use parking_lot::{ArcMutexGuard, Mutex, RawMutex};
2use std::{collections::HashMap, hash::Hash, sync::Arc};
3
4/// An RAII implementation of a scoped lock. When this structure is dropped
5/// (falls out of scope), the lock is released.
6pub struct Guard<'k, K: Eq + Hash + Clone> {
7    key: K,
8    _guard: ArcMutexGuard<RawMutex, ()>,
9    keyed_lock: &'k KeyedLock<K>,
10}
11
12impl<'k, K: Eq + Hash + Clone> Drop for Guard<'k, K> {
13    fn drop(&mut self) {
14        let mut registry = self.keyed_lock.0.lock();
15        // If the strong count is 2, it means only the registry and current guard
16        // hold a reference to the Arc. In this case, we can safely remove the
17        // key from the registry.
18        if let Some(arc_mutex) = registry.get(&self.key) {
19            if Arc::strong_count(arc_mutex) == 2 {
20                registry.remove(&self.key);
21            }
22        }
23    }
24}
25
26/// An RAII implementation of a scoped lock for an `Arc<KeyedLock>`. When this
27/// structure is dropped (falls out of scope), the lock is released.
28pub struct OwnedGuard<K: Eq + Hash + Clone> {
29    key: K,
30    _guard: ArcMutexGuard<RawMutex, ()>,
31    keyed_lock: Arc<KeyedLock<K>>,
32}
33
34impl<K: Eq + Hash + Clone> Drop for OwnedGuard<K> {
35    fn drop(&mut self) {
36        let mut registry = self.keyed_lock.0.lock();
37        // If the strong count is 2, it means only the registry and current guard
38        // hold a reference to the Arc. In this case, we can safely remove the
39        // key from the registry.
40        if let Some(arc_mutex) = registry.get(&self.key) {
41            if Arc::strong_count(arc_mutex) == 2 {
42                registry.remove(&self.key);
43            }
44        }
45    }
46}
47
48/// A lock that provides mutually exclusive access to a resource, where the
49/// resource is identified by a key.
50pub struct KeyedLock<K: Eq + Hash + Clone>(Mutex<HashMap<K, Arc<Mutex<()>>>>);
51
52impl<K: Eq + Hash + Clone> KeyedLock<K> {
53    /// Creates a new `KeyedLock`.
54    #[must_use]
55    pub fn new() -> Self {
56        Self(Mutex::new(HashMap::new()))
57    }
58
59    /// Acquires a lock for a given key.
60    ///
61    /// If the lock is already held by another task, this method will wait until
62    /// the lock is released.
63    ///
64    /// When the returned `Guard` is dropped, the lock is released.
65    pub fn lock(&self, key: K) -> Guard<'_, K> {
66        let _guard = self.lock_inner(&key);
67        Guard {
68            key,
69            _guard,
70            keyed_lock: self,
71        }
72    }
73
74    /// Acquires a lock for a given key, returning an `OwnedGuard`.
75    ///
76    /// This method is for use with `Arc<KeyedLock>`. If the lock is already
77    /// held by another task, this method will wait until the lock is released.
78    ///
79    /// When the returned `OwnedGuard` is dropped, the lock is released.
80    pub fn lock_owned(self: &Arc<Self>, key: K) -> OwnedGuard<K> {
81        let _guard = self.lock_inner(&key);
82        OwnedGuard {
83            key,
84            _guard,
85            keyed_lock: self.clone(),
86        }
87    }
88
89    /// Gets or creates a mutex for a key and locks it.
90    fn lock_inner(&self, key: &K) -> ArcMutexGuard<RawMutex, ()> {
91        let key_lock = {
92            let mut registry = self.0.lock();
93            if let Some(notifies) = registry.get_mut(key) {
94                Arc::clone(notifies)
95            } else {
96                let new = Arc::new(Mutex::new(()));
97                registry.insert(key.clone(), new.clone());
98                new
99            }
100        };
101        key_lock.lock_arc()
102    }
103
104    #[cfg(test)]
105    fn registry_len(&self) -> usize {
106        self.0.lock().len()
107    }
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    /// How long a test keeps a key locked so that a contending thread has a
    /// chance to reach (and block on) the same key. The assertions below do
    /// not depend on this value; it only controls how much contention the
    /// tests actually exercise.
    const CONTENTION_WINDOW: Duration = Duration::from_millis(100);

    #[test]
    fn test_basic_lock() {
        let keyed_lock = KeyedLock::new();
        let _guard = keyed_lock.lock(1);
        // The lock is held here.
        // When _guard goes out of scope, the lock is released.
    }

    #[test]
    fn test_concurrent_access() {
        let keyed_lock = Arc::new(KeyedLock::new());
        let mut handles = vec![];

        for _ in 0..10 {
            let keyed_lock_clone = Arc::clone(&keyed_lock);
            let handle = thread::spawn(move || {
                let _guard = keyed_lock_clone.lock(1);
                // Do some work while holding the lock.
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        // Every guard has been dropped, so the shared key must have been pruned.
        assert_eq!(keyed_lock.registry_len(), 0);
    }

    #[test]
    fn test_lock_is_released() {
        let keyed_lock = KeyedLock::new();
        let guard = keyed_lock.lock(1);
        drop(guard);
        // The lock should be released now.
        let _guard2 = keyed_lock.lock(1);
    }

    #[test]
    fn test_lock_reuse() {
        let keyed_lock = KeyedLock::new();
        let guard1 = keyed_lock.lock(1);
        drop(guard1);
        let guard2 = keyed_lock.lock(1);
        drop(guard2);
    }

    #[test]
    fn test_locks_different_keys() {
        let keyed_lock = KeyedLock::new();
        let _guard1 = keyed_lock.lock(1);
        let _guard2 = keyed_lock.lock(2);
        // Locks for different keys should not block each other.
    }

    #[test]
    fn test_multiple_keys_concurrently() {
        let keyed_lock = Arc::new(KeyedLock::new());
        let mut handles = vec![];

        for i in 0..10 {
            let keyed_lock_clone = Arc::clone(&keyed_lock);
            let handle = thread::spawn(move || {
                let _guard = keyed_lock_clone.lock(i);
                // Do some work while holding the lock.
                thread::sleep(Duration::from_millis(10));
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(keyed_lock.registry_len(), 0);
    }

    #[test]
    fn test_non_reentrant_lock() {
        let keyed_lock = Arc::new(KeyedLock::new());
        let keyed_lock_clone = Arc::clone(&keyed_lock);
        let released = Arc::new(AtomicBool::new(false));
        let released_clone = Arc::clone(&released);

        // Acquire the lock in the main thread.
        let guard = keyed_lock.lock(1);

        // Another thread contending for the same key must not get it until
        // the main thread has released it.
        let handle = thread::spawn(move || {
            let _guard = keyed_lock_clone.lock(1);
            assert!(
                released_clone.load(Ordering::SeqCst),
                "acquired the key while another thread still held it"
            );
        });

        thread::sleep(CONTENTION_WINDOW);
        released.store(true, Ordering::SeqCst);
        drop(guard);

        handle.join().unwrap();
    }

    #[test]
    fn test_registry_cleanup() {
        let keyed_lock = KeyedLock::new();
        assert_eq!(keyed_lock.registry_len(), 0);

        // Lock a key, registry should have one entry.
        let guard = keyed_lock.lock(1);
        assert_eq!(keyed_lock.registry_len(), 1);

        // Drop the guard, registry should be empty.
        drop(guard);
        assert_eq!(keyed_lock.registry_len(), 0);
    }

    #[test]
    fn test_registry_cleanup_concurrent() {
        let keyed_lock = Arc::new(KeyedLock::new());
        assert_eq!(keyed_lock.registry_len(), 0);

        let guard1 = keyed_lock.lock(1);
        assert_eq!(keyed_lock.registry_len(), 1);

        let keyed_lock_clone = Arc::clone(&keyed_lock);
        let handle = thread::spawn(move || {
            // This will block until guard1 is dropped.
            let guard2 = keyed_lock_clone.lock(1);
            // The registry should still contain the key.
            assert_eq!(keyed_lock_clone.registry_len(), 1);
            drop(guard2);
        });

        // The spawned thread may or may not have reached `lock(1)` yet, but
        // guard1 still holds the key, so its entry must remain registered.
        assert_eq!(keyed_lock.registry_len(), 1);
        drop(guard1);

        handle.join().unwrap();

        // After all guards are dropped, the registry should be empty.
        assert_eq!(keyed_lock.registry_len(), 0);
    }

    #[test]
    fn test_registry_cleanup_arc() {
        let keyed_lock = Arc::new(KeyedLock::new());
        assert_eq!(keyed_lock.registry_len(), 0);

        // Lock a key, registry should have one entry.
        let guard = keyed_lock.lock_owned(1);
        assert_eq!(keyed_lock.registry_len(), 1);

        // Drop the guard, registry should be empty.
        drop(guard);
        assert_eq!(keyed_lock.registry_len(), 0);
    }

    #[test]
    fn test_lock_arc_concurrently() {
        let keyed_lock = Arc::new(KeyedLock::new());
        let mut handles = vec![];

        for i in 0..10 {
            let keyed_lock_clone = Arc::clone(&keyed_lock);
            let handle = thread::spawn(move || {
                let _guard = keyed_lock_clone.lock_owned(i);
                // Do some work while holding the lock.
                thread::sleep(Duration::from_millis(10));
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(keyed_lock.registry_len(), 0);
    }

    #[cfg(feature = "send_guard")]
    #[test]
    fn test_non_reentrant_lock_arc() {
        let keyed_lock = Arc::new(KeyedLock::new());
        let released = Arc::new(AtomicBool::new(false));
        let released_clone = Arc::clone(&released);

        // Acquire the lock in the main thread.
        let guard = keyed_lock.lock_owned(1);

        // Hand the owned guard to another thread, which releases it later.
        let handle = thread::spawn(move || {
            thread::sleep(CONTENTION_WINDOW);
            released_clone.store(true, Ordering::SeqCst);
            drop(guard);
        });

        // This must block until the spawned thread drops the owned guard.
        let _guard = keyed_lock.lock(1);
        assert!(
            released.load(Ordering::SeqCst),
            "acquired the key while the owned guard was still held"
        );

        handle.join().unwrap();
    }

    #[test]
    fn test_basic_lock_arc() {
        let keyed_lock = Arc::new(KeyedLock::new());
        let _guard = keyed_lock.lock_owned(1);
        // The lock is held here.
        // When _guard goes out of scope, the lock is released.
    }

    #[test]
    fn test_lock_is_released_arc() {
        let keyed_lock = Arc::new(KeyedLock::new());
        let guard = keyed_lock.lock_owned(1);
        drop(guard);
        // The lock should be released now.
        let _guard2 = keyed_lock.lock_owned(1);
    }

    #[test]
    fn test_lock_reuse_arc() {
        let keyed_lock = Arc::new(KeyedLock::new());
        let guard1 = keyed_lock.lock_owned(1);
        drop(guard1);
        let guard2 = keyed_lock.lock_owned(1);
        drop(guard2);
    }

    #[test]
    fn test_locks_different_keys_arc() {
        let keyed_lock = Arc::new(KeyedLock::new());
        let _guard1 = keyed_lock.lock_owned(1);
        let _guard2 = keyed_lock.lock_owned(2);
        // Locks for different keys should not block each other.
    }
}
342}