// key_rwlock/lib.rs

1#![doc = include_str!("../README.md")]
2#![forbid(unsafe_code)]
3#![warn(clippy::dbg_macro, clippy::use_debug)]
4#![warn(missing_docs, missing_debug_implementations, clippy::todo)]
5
6use std::{
7    collections::HashMap,
8    hash::Hash,
9    sync::{
10        atomic::{AtomicUsize, Ordering},
11        Arc,
12    },
13};
14
15use tokio::sync::{Mutex, OwnedRwLockReadGuard, OwnedRwLockWriteGuard, RwLock, TryLockError};
16
17/// An async reader-writer lock, that locks based on a key, while allowing other
18/// keys to lock independently. Based on a [HashMap] of [RwLock]s.
#[derive(Debug)]
pub struct KeyRwLock<K> {
    /// The inner map of locks for specific keys. Guarded by an async [Mutex]
    /// so concurrent tasks can mutate the map; each value is an [Arc] so an
    /// owned guard ([RwLock::read_owned]/[RwLock::write_owned]) can outlive
    /// the map borrow.
    locks: Mutex<HashMap<K, Arc<RwLock<()>>>>,
    /// Number of lock accesses. Drives the periodic cleanup (every 1000
    /// accesses); updated with [Ordering::Relaxed] because only the count
    /// matters, not its ordering relative to other memory operations.
    accesses: AtomicUsize,
}
26
27impl<K> Default for KeyRwLock<K> {
28    fn default() -> Self {
29        Self {
30            locks: Mutex::default(),
31            accesses: AtomicUsize::default(),
32        }
33    }
34}
35
36impl<K> KeyRwLock<K>
37where
38    K: Eq + Hash + Send + Clone,
39{
40    /// Create new instance of a [KeyRwLock]
41    #[must_use]
42    pub fn new() -> Self {
43        Self::default()
44    }
45
46    /// Lock this key with shared read access, returning a guard. Cleans up
47    /// locks every 1000 accesses.
48    pub async fn read(&self, key: K) -> OwnedRwLockReadGuard<()> {
49        let mut locks = self.locks.lock().await;
50
51        if self.accesses.fetch_add(1, Ordering::Relaxed) % 1000 == 0 {
52            Self::clean_up(&mut locks);
53        }
54
55        let lock = locks.entry(key).or_default().clone();
56        drop(locks);
57
58        lock.read_owned().await
59    }
60
61    /// Lock this key with exclusive write access, returning a guard. Cleans up
62    /// locks every 1000 accesses.
63    pub async fn write(&self, key: K) -> OwnedRwLockWriteGuard<()> {
64        let mut locks = self.locks.lock().await;
65
66        if self.accesses.fetch_add(1, Ordering::Relaxed) % 1000 == 0 {
67            Self::clean_up(&mut locks);
68        }
69
70        let lock = locks.entry(key).or_default().clone();
71        drop(locks);
72
73        lock.write_owned().await
74    }
75
76    /// Try lock this key with shared read access, returning immediately. Cleans
77    /// up locks every 1000 accesses.
78    pub async fn try_read(&self, key: K) -> Result<OwnedRwLockReadGuard<()>, TryLockError> {
79        let mut locks = self.locks.lock().await;
80
81        if self.accesses.fetch_add(1, Ordering::Relaxed) % 1000 == 0 {
82            Self::clean_up(&mut locks);
83        }
84
85        let lock = locks.entry(key).or_default().clone();
86        drop(locks);
87
88        lock.try_read_owned()
89    }
90
91    /// Try lock this key with exclusive write access, returning immediately.
92    /// Cleans up locks every 1000 accesses.
93    pub async fn try_write(&self, key: K) -> Result<OwnedRwLockWriteGuard<()>, TryLockError> {
94        let mut locks = self.locks.lock().await;
95
96        if self.accesses.fetch_add(1, Ordering::Relaxed) % 1000 == 0 {
97            Self::clean_up(&mut locks);
98        }
99
100        let lock = locks.entry(key).or_default().clone();
101        drop(locks);
102
103        lock.try_write_owned()
104    }
105
106    /// Clean up by removing locks that are not locked.
107    pub async fn clean(&self) {
108        let mut locks = self.locks.lock().await;
109        Self::clean_up(&mut locks);
110    }
111
112    /// Remove locks that are not locked currently.
113    fn clean_up(locks: &mut HashMap<K, Arc<RwLock<()>>>) {
114        let mut to_remove = Vec::new();
115        for (key, lock) in locks.iter() {
116            if lock.try_write().is_ok() {
117                to_remove.push(key.clone());
118            }
119        }
120        for key in to_remove {
121            locks.remove(&key);
122        }
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    /// Guards on one key must not block other keys, while conflicting
    /// access on the same key is excluded.
    #[tokio::test]
    async fn test_basic_functionality() {
        let lock = KeyRwLock::new();

        let _foo = lock.write("foo").await;
        let _bar = lock.read("bar").await;

        // "foo" is write-locked: neither readers nor writers may enter.
        assert!(lock.try_read("foo").await.is_err());
        assert!(lock.try_write("foo").await.is_err());

        // "bar" is read-locked: additional readers succeed, writers do not.
        assert!(lock.try_read("bar").await.is_ok());
        assert!(lock.try_write("bar").await.is_err());
    }

    /// `clean` must drop exactly the entries whose guards were released.
    #[tokio::test]
    async fn test_clean_up() {
        let lock = KeyRwLock::new();
        let _foo_write = lock.write("foo_write").await;
        let _bar_write = lock.write("bar_write").await;
        let _foo_read = lock.read("foo_read").await;
        let _bar_read = lock.read("bar_read").await;
        assert_eq!(lock.locks.lock().await.len(), 4);

        // Release two of the four guards; cleaning must remove only those.
        drop(_foo_read);
        drop(_bar_write);
        lock.clean().await;
        assert_eq!(lock.locks.lock().await.len(), 2);
    }
}
157}