1#![doc = include_str!("../README.md")]
2#![forbid(unsafe_code)]
3#![warn(clippy::dbg_macro, clippy::use_debug)]
4#![warn(missing_docs, missing_debug_implementations, clippy::todo)]
5
6use std::{
7 collections::HashMap,
8 hash::Hash,
9 sync::{
10 atomic::{AtomicUsize, Ordering},
11 Arc,
12 },
13};
14
15use tokio::sync::{Mutex, OwnedRwLockReadGuard, OwnedRwLockWriteGuard, RwLock, TryLockError};
16
17#[derive(Debug)]
20pub struct KeyRwLock<K> {
21 locks: Mutex<HashMap<K, Arc<RwLock<()>>>>,
23 accesses: AtomicUsize,
25}
26
27impl<K> Default for KeyRwLock<K> {
28 fn default() -> Self {
29 Self {
30 locks: Mutex::default(),
31 accesses: AtomicUsize::default(),
32 }
33 }
34}
35
36impl<K> KeyRwLock<K>
37where
38 K: Eq + Hash + Send + Clone,
39{
40 #[must_use]
42 pub fn new() -> Self {
43 Self::default()
44 }
45
46 pub async fn read(&self, key: K) -> OwnedRwLockReadGuard<()> {
49 let mut locks = self.locks.lock().await;
50
51 if self.accesses.fetch_add(1, Ordering::Relaxed) % 1000 == 0 {
52 Self::clean_up(&mut locks);
53 }
54
55 let lock = locks.entry(key).or_default().clone();
56 drop(locks);
57
58 lock.read_owned().await
59 }
60
61 pub async fn write(&self, key: K) -> OwnedRwLockWriteGuard<()> {
64 let mut locks = self.locks.lock().await;
65
66 if self.accesses.fetch_add(1, Ordering::Relaxed) % 1000 == 0 {
67 Self::clean_up(&mut locks);
68 }
69
70 let lock = locks.entry(key).or_default().clone();
71 drop(locks);
72
73 lock.write_owned().await
74 }
75
76 pub async fn try_read(&self, key: K) -> Result<OwnedRwLockReadGuard<()>, TryLockError> {
79 let mut locks = self.locks.lock().await;
80
81 if self.accesses.fetch_add(1, Ordering::Relaxed) % 1000 == 0 {
82 Self::clean_up(&mut locks);
83 }
84
85 let lock = locks.entry(key).or_default().clone();
86 drop(locks);
87
88 lock.try_read_owned()
89 }
90
91 pub async fn try_write(&self, key: K) -> Result<OwnedRwLockWriteGuard<()>, TryLockError> {
94 let mut locks = self.locks.lock().await;
95
96 if self.accesses.fetch_add(1, Ordering::Relaxed) % 1000 == 0 {
97 Self::clean_up(&mut locks);
98 }
99
100 let lock = locks.entry(key).or_default().clone();
101 drop(locks);
102
103 lock.try_write_owned()
104 }
105
106 pub async fn clean(&self) {
108 let mut locks = self.locks.lock().await;
109 Self::clean_up(&mut locks);
110 }
111
112 fn clean_up(locks: &mut HashMap<K, Arc<RwLock<()>>>) {
114 let mut to_remove = Vec::new();
115 for (key, lock) in locks.iter() {
116 if lock.try_write().is_ok() {
117 to_remove.push(key.clone());
118 }
119 }
120 for key in to_remove {
121 locks.remove(&key);
122 }
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_basic_functionality() {
        let lock = KeyRwLock::new();

        let _foo = lock.write("foo").await;
        let _bar = lock.read("bar").await;

        // "foo" is write-locked: neither readers nor writers may enter.
        assert!(lock.try_read("foo").await.is_err());
        assert!(lock.try_write("foo").await.is_err());

        // "bar" is read-locked: further readers are fine, writers are not.
        assert!(lock.try_read("bar").await.is_ok());
        assert!(lock.try_write("bar").await.is_err());
    }

    #[tokio::test]
    async fn test_keys_are_independent() {
        let lock = KeyRwLock::new();

        let _a = lock.write("a").await;

        assert!(lock.try_write("b").await.is_ok());
        assert!(lock.try_read("b").await.is_ok());
    }

    #[tokio::test]
    async fn test_clean_up() {
        let lock = KeyRwLock::new();
        let _foo_write = lock.write("foo_write").await;
        let bar_write = lock.write("bar_write").await;
        let foo_read = lock.read("foo_read").await;
        let _bar_read = lock.read("bar_read").await;
        assert_eq!(lock.locks.lock().await.len(), 4);

        drop(foo_read);
        drop(bar_write);
        lock.clean().await;
        assert_eq!(lock.locks.lock().await.len(), 2);
    }

    #[tokio::test]
    async fn test_clean_removes_all_released_locks() {
        let lock = KeyRwLock::new();
        drop(lock.write("a").await);
        drop(lock.read("b").await);

        lock.clean().await;
        assert!(lock.locks.lock().await.is_empty());

        // A cleaned-up key can be locked again from scratch.
        assert!(lock.try_write("a").await.is_ok());
    }
}