1use parking_lot::{ArcMutexGuard, Mutex, RawMutex};
2use std::{collections::HashMap, hash::Hash, sync::Arc};
3
4pub struct Guard<'k, K: Eq + Hash + Clone> {
7 key: K,
8 _guard: ArcMutexGuard<RawMutex, ()>,
9 keyed_lock: &'k KeyedLock<K>,
10}
11
12impl<'k, K: Eq + Hash + Clone> Drop for Guard<'k, K> {
13 fn drop(&mut self) {
14 let mut registry = self.keyed_lock.0.lock();
15 if let Some(arc_mutex) = registry.get(&self.key) {
19 if Arc::strong_count(arc_mutex) == 2 {
20 registry.remove(&self.key);
21 }
22 }
23 }
24}
25
26pub struct OwnedGuard<K: Eq + Hash + Clone> {
29 key: K,
30 _guard: ArcMutexGuard<RawMutex, ()>,
31 keyed_lock: Arc<KeyedLock<K>>,
32}
33
34impl<K: Eq + Hash + Clone> Drop for OwnedGuard<K> {
35 fn drop(&mut self) {
36 let mut registry = self.keyed_lock.0.lock();
37 if let Some(arc_mutex) = registry.get(&self.key) {
41 if Arc::strong_count(arc_mutex) == 2 {
42 registry.remove(&self.key);
43 }
44 }
45 }
46}
47
48pub struct KeyedLock<K: Eq + Hash + Clone>(Mutex<HashMap<K, Arc<Mutex<()>>>>);
51
52impl<K: Eq + Hash + Clone> KeyedLock<K> {
53 #[must_use]
55 pub fn new() -> Self {
56 Self(Mutex::new(HashMap::new()))
57 }
58
59 pub fn lock(&self, key: K) -> Guard<'_, K> {
66 let _guard = self.lock_inner(&key);
67 Guard {
68 key,
69 _guard,
70 keyed_lock: self,
71 }
72 }
73
74 pub fn lock_owned(self: &Arc<Self>, key: K) -> OwnedGuard<K> {
81 let _guard = self.lock_inner(&key);
82 OwnedGuard {
83 key,
84 _guard,
85 keyed_lock: self.clone(),
86 }
87 }
88
89 fn lock_inner(&self, key: &K) -> ArcMutexGuard<RawMutex, ()> {
91 let key_lock = {
92 let mut registry = self.0.lock();
93 if let Some(notifies) = registry.get_mut(key) {
94 Arc::clone(notifies)
95 } else {
96 let new = Arc::new(Mutex::new(()));
97 registry.insert(key.clone(), new.clone());
98 new
99 }
100 };
101 key_lock.lock_arc()
102 }
103
104 #[cfg(test)]
105 fn registry_len(&self) -> usize {
106 self.0.lock().len()
107 }
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::thread;
    use std::time::{Duration, Instant};

    /// How long the "holder" side of the blocking tests keeps a key locked.
    ///
    /// The timing assertions measure from an instant taken *before* the hold starts,
    /// so thread start-up jitter cannot make them fail and the hold can stay short.
    const HOLD: Duration = Duration::from_millis(200);

    #[test]
    fn test_basic_lock() {
        let keyed_lock = KeyedLock::new();
        let _guard = keyed_lock.lock(1);
    }

    #[test]
    fn test_concurrent_access() {
        let keyed_lock = Arc::new(KeyedLock::new());
        // Number of threads currently inside the critical section for key 1.
        let in_flight = Arc::new(AtomicUsize::new(0));
        let mut handles = vec![];

        for _ in 0..10 {
            let keyed_lock_clone = Arc::clone(&keyed_lock);
            let in_flight_clone = Arc::clone(&in_flight);
            let handle = thread::spawn(move || {
                let _guard = keyed_lock_clone.lock(1);
                // Nobody else may be inside while we hold key 1.
                assert_eq!(in_flight_clone.fetch_add(1, Ordering::SeqCst), 0);
                thread::sleep(Duration::from_millis(1));
                in_flight_clone.fetch_sub(1, Ordering::SeqCst);
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(keyed_lock.registry_len(), 0);
    }

    #[test]
    fn test_lock_is_released() {
        let keyed_lock = KeyedLock::new();
        let guard = keyed_lock.lock(1);
        drop(guard);
        let _guard2 = keyed_lock.lock(1);
    }

    #[test]
    fn test_lock_reuse() {
        let keyed_lock = KeyedLock::new();
        let guard1 = keyed_lock.lock(1);
        drop(guard1);
        let guard2 = keyed_lock.lock(1);
        drop(guard2);
    }

    #[test]
    fn test_locks_different_keys() {
        let keyed_lock = KeyedLock::new();
        let _guard1 = keyed_lock.lock(1);
        let _guard2 = keyed_lock.lock(2);
    }

    #[test]
    fn test_multiple_keys_concurrently() {
        let keyed_lock = Arc::new(KeyedLock::new());
        let mut handles = vec![];

        for i in 0..10 {
            let keyed_lock_clone = Arc::clone(&keyed_lock);
            let handle = thread::spawn(move || {
                let _guard = keyed_lock_clone.lock(i);
                thread::sleep(Duration::from_millis(10));
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(keyed_lock.registry_len(), 0);
    }

    #[test]
    fn test_non_reentrant_lock() {
        let keyed_lock = Arc::new(KeyedLock::new());
        let keyed_lock_clone = Arc::clone(&keyed_lock);

        let guard = keyed_lock.lock(1);
        // Taken while `guard` is held and before the sleep below, so the waiter can
        // only acquire key 1 at least `HOLD` after this instant.
        let start = Instant::now();

        let handle = thread::spawn(move || {
            let _guard = keyed_lock_clone.lock(1);
            assert!(start.elapsed() >= HOLD);
        });

        thread::sleep(HOLD);
        drop(guard);

        handle.join().unwrap();
    }

    #[test]
    fn test_registry_cleanup() {
        let keyed_lock = KeyedLock::new();
        assert_eq!(keyed_lock.registry_len(), 0);

        let guard = keyed_lock.lock(1);
        assert_eq!(keyed_lock.registry_len(), 1);

        drop(guard);
        assert_eq!(keyed_lock.registry_len(), 0);
    }

    #[test]
    fn test_registry_cleanup_concurrent() {
        let keyed_lock = Arc::new(KeyedLock::new());
        assert_eq!(keyed_lock.registry_len(), 0);

        let guard1 = keyed_lock.lock(1);
        assert_eq!(keyed_lock.registry_len(), 1);

        let keyed_lock_clone = Arc::clone(&keyed_lock);
        let handle = thread::spawn(move || {
            let guard2 = keyed_lock_clone.lock(1);
            // Key 1 is locked by this thread, so it must be registered.
            assert_eq!(keyed_lock_clone.registry_len(), 1);
            drop(guard2);
        });

        // `guard1` is still held, so key 1 is still registered.
        assert_eq!(keyed_lock.registry_len(), 1);
        drop(guard1);

        handle.join().unwrap();

        assert_eq!(keyed_lock.registry_len(), 0);
    }

    #[test]
    fn test_registry_cleanup_arc() {
        let keyed_lock = Arc::new(KeyedLock::new());
        assert_eq!(keyed_lock.registry_len(), 0);

        let guard = keyed_lock.lock_owned(1);
        assert_eq!(keyed_lock.registry_len(), 1);

        drop(guard);
        assert_eq!(keyed_lock.registry_len(), 0);
    }

    #[test]
    fn test_lock_arc_concurrently() {
        let keyed_lock = Arc::new(KeyedLock::new());
        let mut handles = vec![];

        for i in 0..10 {
            let keyed_lock_clone = Arc::clone(&keyed_lock);
            let handle = thread::spawn(move || {
                let _guard = keyed_lock_clone.lock_owned(i);
                thread::sleep(Duration::from_millis(10));
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(keyed_lock.registry_len(), 0);
    }

    #[cfg(feature = "send_guard")]
    #[test]
    fn test_non_reentrant_lock_arc() {
        let keyed_lock = Arc::new(KeyedLock::new());

        let guard = keyed_lock.lock_owned(1);
        // Taken before the holder thread exists (and so before its sleep starts):
        // key 1 cannot become free until at least `HOLD` after this instant.
        let start = Instant::now();

        let handle = thread::spawn(move || {
            thread::sleep(HOLD);
            drop(guard);
        });

        let _guard = keyed_lock.lock(1);
        assert!(start.elapsed() >= HOLD);

        handle.join().unwrap();
    }

    #[test]
    fn test_basic_lock_arc() {
        let keyed_lock = Arc::new(KeyedLock::new());
        let _guard = keyed_lock.lock_owned(1);
    }

    #[test]
    fn test_lock_is_released_arc() {
        let keyed_lock = Arc::new(KeyedLock::new());
        let guard = keyed_lock.lock_owned(1);
        drop(guard);
        let _guard2 = keyed_lock.lock_owned(1);
    }

    #[test]
    fn test_lock_reuse_arc() {
        let keyed_lock = Arc::new(KeyedLock::new());
        let guard1 = keyed_lock.lock_owned(1);
        drop(guard1);
        let guard2 = keyed_lock.lock_owned(1);
        drop(guard2);
    }

    #[test]
    fn test_locks_different_keys_arc() {
        let keyed_lock = Arc::new(KeyedLock::new());
        let _guard1 = keyed_lock.lock_owned(1);
        let _guard2 = keyed_lock.lock_owned(2);
    }
}