1use parking_lot::Mutex as SyncMutex;
2use std::{collections::HashMap, hash::Hash, sync::Arc};
3use tokio::sync::{Mutex, OwnedMutexGuard};
4
5pub struct Guard<'k, K: Eq + Hash + Clone + Send> {
8 key: K,
9 _guard: OwnedMutexGuard<()>,
10 keyed_lock: &'k KeyedLock<K>,
11}
12
13impl<'k, K: Eq + Hash + Clone + Send> Drop for Guard<'k, K> {
14 fn drop(&mut self) {
15 let mut registry = self.keyed_lock.0.lock();
16 if let Some(arc_mutex) = registry.get(&self.key) {
20 if Arc::strong_count(arc_mutex) == 2 {
21 registry.remove(&self.key);
22 }
23 }
24 }
25}
26
27pub struct OwnedGuard<K: Eq + Hash + Clone + Send> {
30 key: K,
31 _guard: OwnedMutexGuard<()>,
32 keyed_lock: Arc<KeyedLock<K>>,
33}
34
35impl<K: Eq + Hash + Clone + Send> Drop for OwnedGuard<K> {
36 fn drop(&mut self) {
37 let mut registry = self.keyed_lock.0.lock();
38 if let Some(arc_mutex) = registry.get(&self.key) {
42 if Arc::strong_count(arc_mutex) == 2 {
43 registry.remove(&self.key);
44 }
45 }
46 }
47}
48
49pub struct KeyedLock<K: Eq + Hash + Clone + Send>(SyncMutex<HashMap<K, Arc<Mutex<()>>>>);
52
53impl<K: Eq + Hash + Clone + Send> KeyedLock<K> {
54 #[must_use]
56 pub fn new() -> Self {
57 Self(SyncMutex::new(HashMap::new()))
58 }
59
60 pub async fn lock<'a>(&'a self, key: K) -> Guard<'a, K> {
67 let _guard = self.lock_inner(&key).await;
68 Guard {
69 key,
70 _guard,
71 keyed_lock: self,
72 }
73 }
74
75 pub async fn lock_owned(self: &Arc<Self>, key: K) -> OwnedGuard<K> {
82 let _guard = self.lock_inner(&key).await;
83 OwnedGuard {
84 key,
85 _guard,
86 keyed_lock: self.clone(),
87 }
88 }
89
90 async fn lock_inner(&self, key: &K) -> OwnedMutexGuard<()> {
92 let key_lock = {
93 let mut registry = self.0.lock();
94 if let Some(notifies) = registry.get_mut(key) {
95 Arc::clone(notifies)
96 } else {
97 let new = Arc::new(Mutex::new(()));
98 registry.insert(key.clone(), new.clone());
99 new
100 }
101 };
102 key_lock.lock_owned().await
103 }
104
105 #[cfg(test)]
106 fn registry_len(&self) -> usize {
107 self.0.lock().len()
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::time::{sleep, timeout};

    #[tokio::test]
    async fn test_lock_unlock() {
        let keyed_lock = KeyedLock::new();
        let guard = keyed_lock.lock(1).await;
        drop(guard);
        assert_eq!(keyed_lock.registry_len(), 0);
    }

    #[tokio::test]
    async fn test_lock_contention() {
        let keyed_lock = Arc::new(KeyedLock::new());
        let keyed_lock_clone = Arc::clone(&keyed_lock);

        let guard1 = keyed_lock.lock(1).await;

        let task = tokio::spawn(async move {
            // Bind the guard so the lock is genuinely held until the task ends.
            let _guard = keyed_lock_clone.lock(1).await;
        });

        // Let the spawned task start and block on key 1.
        sleep(Duration::from_millis(10)).await;
        assert!(!task.is_finished());

        drop(guard1);
        // Await the task itself so a panic inside it fails this test; the
        // timeout keeps a regression from hanging the suite.
        timeout(Duration::from_secs(1), task)
            .await
            .expect("waiter was not granted the lock after release")
            .expect("waiter task panicked");

        // Both guards are gone, so the key must have been pruned.
        assert_eq!(keyed_lock.registry_len(), 0);
    }

    #[tokio::test]
    async fn test_owned_lock_unlock() {
        let keyed_lock = Arc::new(KeyedLock::new());
        let guard = keyed_lock.lock_owned(1).await;
        assert_eq!(keyed_lock.registry_len(), 1);
        drop(guard);
        assert_eq!(keyed_lock.registry_len(), 0);
    }

    #[tokio::test]
    async fn test_registry_cleanup() {
        let keyed_lock = KeyedLock::new();
        assert_eq!(keyed_lock.registry_len(), 0);

        let guard = keyed_lock.lock(1).await;
        assert_eq!(keyed_lock.registry_len(), 1);
        drop(guard);

        assert_eq!(keyed_lock.registry_len(), 0);
    }

    #[tokio::test]
    async fn test_multiple_keys() {
        let keyed_lock = KeyedLock::new();
        let guard1 = keyed_lock.lock(1).await;
        let guard2 = keyed_lock.lock(2).await;

        assert_eq!(keyed_lock.registry_len(), 2);

        drop(guard1);
        assert_eq!(keyed_lock.registry_len(), 1);

        drop(guard2);
        assert_eq!(keyed_lock.registry_len(), 0);
    }
}