keyed_lock/
async.rs

1use parking_lot::Mutex as SyncMutex;
2use std::{collections::HashMap, hash::Hash, sync::Arc};
3use tokio::sync::{Mutex, OwnedMutexGuard};
4
5/// An RAII implementation of a scoped lock. When this structure is dropped
6/// (falls out of scope), the lock is released.
7pub struct Guard<'k, K: Eq + Hash + Clone + Send> {
8    key: K,
9    _guard: OwnedMutexGuard<()>,
10    keyed_lock: &'k KeyedLock<K>,
11}
12
13impl<'k, K: Eq + Hash + Clone + Send> Drop for Guard<'k, K> {
14    fn drop(&mut self) {
15        let mut registry = self.keyed_lock.0.lock();
16        // If the strong count is 2, it means only the registry and current guard
17        // hold a reference to the Arc. In this case, we can safely remove the
18        // key from the registry.
19        if let Some(arc_mutex) = registry.get(&self.key) {
20            if Arc::strong_count(arc_mutex) == 2 {
21                registry.remove(&self.key);
22            }
23        }
24    }
25}
26
27/// An RAII implementation of a scoped lock for an `Arc<KeyedLock>`. When this
28/// structure is dropped (falls out of scope), the lock is released.
29pub struct OwnedGuard<K: Eq + Hash + Clone + Send> {
30    key: K,
31    _guard: OwnedMutexGuard<()>,
32    keyed_lock: Arc<KeyedLock<K>>,
33}
34
35impl<K: Eq + Hash + Clone + Send> Drop for OwnedGuard<K> {
36    fn drop(&mut self) {
37        let mut registry = self.keyed_lock.0.lock();
38        // If the strong count is 2, it means only the registry and current guard
39        // hold a reference to the Arc. In this case, we can safely remove the
40        // key from the registry.
41        if let Some(arc_mutex) = registry.get(&self.key) {
42            if Arc::strong_count(arc_mutex) == 2 {
43                registry.remove(&self.key);
44            }
45        }
46    }
47}
48
49/// A lock that provides mutually exclusive access to a resource, where the
50/// resource is identified by a key.
51pub struct KeyedLock<K: Eq + Hash + Clone + Send>(SyncMutex<HashMap<K, Arc<Mutex<()>>>>);
52
53impl<K: Eq + Hash + Clone + Send> KeyedLock<K> {
54    /// Creates a new `KeyedLock`.
55    #[must_use]
56    pub fn new() -> Self {
57        Self(SyncMutex::new(HashMap::new()))
58    }
59
60    /// Acquires a lock for a given key.
61    ///
62    /// If the lock is already held by another task, this method will wait until
63    /// the lock is released.
64    ///
65    /// When the returned `Guard` is dropped, the lock is released.
66    pub async fn lock<'a>(&'a self, key: K) -> Guard<'a, K> {
67        let _guard = self.lock_inner(&key).await;
68        Guard {
69            key,
70            _guard,
71            keyed_lock: self,
72        }
73    }
74
75    /// Acquires a lock for a given key, returning an `OwnedGuard`.
76    ///
77    /// This method is for use with `Arc<KeyedLock>`. If the lock is already
78    /// held by another task, this method will wait until the lock is released.
79    ///
80    /// When the returned `OwnedGuard` is dropped, the lock is released.
81    pub async fn lock_owned(self: &Arc<Self>, key: K) -> OwnedGuard<K> {
82        let _guard = self.lock_inner(&key).await;
83        OwnedGuard {
84            key,
85            _guard,
86            keyed_lock: self.clone(),
87        }
88    }
89
90    /// Gets or creates a mutex for a key and locks it.
91    async fn lock_inner(&self, key: &K) -> OwnedMutexGuard<()> {
92        let key_lock = {
93            let mut registry = self.0.lock();
94            if let Some(notifies) = registry.get_mut(key) {
95                Arc::clone(notifies)
96            } else {
97                let new = Arc::new(Mutex::new(()));
98                registry.insert(key.clone(), new.clone());
99                new
100            }
101        };
102        key_lock.lock_owned().await
103    }
104
105    #[cfg(test)]
106    fn registry_len(&self) -> usize {
107        self.0.lock().len()
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::time::sleep;

    /// Pause that gives the spawned contender a chance to make progress.
    const SETTLE: Duration = Duration::from_millis(10);

    #[tokio::test]
    async fn test_lock_unlock() {
        let locks = KeyedLock::new();
        drop(locks.lock(1).await);
    }

    #[tokio::test]
    async fn test_lock_contention() {
        let locks = Arc::new(KeyedLock::new());
        let held = locks.lock(1).await;

        // A second task tries the same key and releases it right away.
        let contender = tokio::spawn({
            let locks = Arc::clone(&locks);
            async move {
                drop(locks.lock(1).await);
            }
        });

        sleep(SETTLE).await;
        assert!(!contender.is_finished());

        drop(held);
        sleep(SETTLE).await;
        assert!(contender.is_finished());
    }

    #[tokio::test]
    async fn test_owned_lock_unlock() {
        let locks = Arc::new(KeyedLock::new());
        drop(locks.lock_owned(1).await);
    }

    #[tokio::test]
    async fn test_registry_cleanup() {
        let locks = KeyedLock::new();
        assert_eq!(locks.registry_len(), 0);

        let held = locks.lock(1).await;
        assert_eq!(locks.registry_len(), 1);

        drop(held);
        assert_eq!(locks.registry_len(), 0);
    }

    #[tokio::test]
    async fn test_multiple_keys() {
        let locks = KeyedLock::new();
        let first = locks.lock(1).await;
        let second = locks.lock(2).await;
        assert_eq!(locks.registry_len(), 2);

        // Releasing each key removes exactly its own entry.
        drop(first);
        assert_eq!(locks.registry_len(), 1);
        drop(second);
        assert_eq!(locks.registry_len(), 0);
    }
}