use dashmap::{DashMap, Entry};
use std::collections::{BTreeSet, LinkedList};
use std::hash::Hash;
use std::sync::atomic::{AtomicU32, Ordering};

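/// A raw pointer to a waiting thread's stack-allocated futex word. The
/// pointee stays valid for as long as the pointer is queued, because the
/// waiting thread blocks inside `raw_lock` until it is woken through it.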
struct WaiterPtr(*const AtomicU32);

impl WaiterPtr {
    fn wake_up(self) {
        let ptr = self.0;
        let waiter = unsafe { &*ptr };
        // Set the flag before issuing the futex wake so the blocked thread's
        // re-check of its futex word observes 1 and exits its wait loop.
        waiter.store(1, Ordering::Release);
        atomic_wait::wake_one(ptr);
    }
}

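// SAFETY: the pointer is only ever read as an `AtomicU32`, which is itself
// thread-safe; the raw pointer is just how a releasing thread reaches the
// waiter's futex word from another thread.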
unsafe impl Sync for WaiterPtr {}
unsafe impl Send for WaiterPtr {}

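/// A lock manager providing exclusive, per-key locking. Each locked key owns
/// an entry in a sharded `DashMap` whose value is a FIFO queue of waiting
/// threads; the absence of an entry means the key is unlocked.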
pub struct LockManager<K: Eq + Hash + Clone> {
    map: DashMap<K, LinkedList<WaiterPtr>>,
}

impl<K: Eq + Hash + Clone> LockManager<K> {
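    /// Creates an empty `LockManager`.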
    pub fn new() -> Self {
        Self {
            map: DashMap::new(),
        }
    }

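    /// Creates a `LockManager` pre-sized for at least `capacity` keys.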
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            map: DashMap::with_capacity(capacity),
        }
    }

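    /// Creates a `LockManager` pre-sized for at least `capacity` keys and
    /// spread across `shard_amount` shards of the underlying `DashMap`.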
    pub fn with_capacity_and_shard_amount(capacity: usize, shard_amount: usize) -> Self {
        Self {
            map: DashMap::with_capacity_and_shard_amount(capacity, shard_amount),
        }
    }

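    /// Locks `key`, blocking the calling thread until every earlier holder of
    /// the same key has released it. The lock is released when the returned
    /// guard is dropped.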
    pub fn lock<'a, 'b>(&'a self, key: &'b K) -> LockGuard<'a, 'b, K> {
        self.raw_lock(key);
        LockGuard::<'a, 'b, K> { map: self, key }
    }

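    /// Locks every key in `keys`, blocking until all of them are held. The
    /// `BTreeSet` guarantees each caller acquires keys in ascending order, so
    /// overlapping batches cannot deadlock against each other.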
    pub fn batch_lock<'a, 'b>(&'a self, keys: &'b BTreeSet<K>) -> BatchLockGuard<'a, 'b, K> {
        for key in keys {
            self.raw_lock(key);
        }
        BatchLockGuard::<'a, 'b, K> { map: self, keys }
    }

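    // Shared acquire path: a vacant entry means the key is free, so insert an
    // empty waiter queue and take the lock immediately (flag set to 1). An
    // occupied entry means someone holds the key, so enqueue a pointer to our
    // stack-local flag and futex-wait until a releasing thread sets it; the
    // loop re-checks the flag to tolerate spurious wakeups.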
    fn raw_lock(&self, key: &K) {
        let waiter = AtomicU32::new(0);
        match self.map.entry(key.clone()) {
            Entry::Occupied(mut occupied_entry) => {
                occupied_entry.get_mut().push_back(WaiterPtr(&waiter as _));
            }
            Entry::Vacant(vacant_entry) => {
                vacant_entry.insert(Default::default());
                waiter.store(1, Ordering::Release);
            }
        };
        while waiter.load(Ordering::Acquire) == 0 {
            atomic_wait::wait(&waiter, 0);
        }
    }

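    // Shared release path: hand the lock directly to the first queued waiter
    // if there is one, otherwise drop the key's entry to mark it unlocked.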
    fn unlock(&self, key: &K) {
        match self.map.entry(key.clone()) {
            Entry::Occupied(mut occupied_entry) => match occupied_entry.get_mut().pop_front() {
                Some(waiter) => {
                    waiter.wake_up();
                }
                None => {
                    occupied_entry.remove();
                }
            },
            Entry::Vacant(_) => panic!("impossible: unlock a non-existent key!"),
        }
    }

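    // Releases a batch in descending key order, mirroring the ascending
    // acquisition order used by `batch_lock`.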
    fn batch_unlock(&self, keys: &BTreeSet<K>) {
        for key in keys.iter().rev() {
            self.unlock(key);
        }
    }
}

impl<K: Eq + Hash + Clone> Default for LockManager<K> {
    fn default() -> Self {
        Self::new()
    }
}

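/// RAII guard for a single locked key; the key is unlocked when the guard is
/// dropped.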
pub struct LockGuard<'a, 'b, K: Eq + Hash + Clone> {
    map: &'a LockManager<K>,
    key: &'b K,
}

impl<'a, 'b, K: Eq + Hash + Clone> Drop for LockGuard<'a, 'b, K> {
    fn drop(&mut self) {
        self.map.unlock(self.key);
    }
}

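/// RAII guard for a batch of locked keys; all keys are unlocked when the
/// guard is dropped.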
pub struct BatchLockGuard<'a, 'b, K: Eq + Hash + Clone> {
    map: &'a LockManager<K>,
    keys: &'b BTreeSet<K>,
}

impl<'a, 'b, K: Eq + Hash + Clone> Drop for BatchLockGuard<'a, 'b, K> {
    fn drop(&mut self) {
        self.map.batch_unlock(self.keys);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{atomic::AtomicUsize, Arc};

    #[test]
    fn test_lock_map_same_key() {
        let lock_map = Arc::new(LockManager::<u32>::new());
        let total = Arc::new(AtomicUsize::default());
        let current = Arc::new(AtomicU32::default());
        const N: usize = 1 << 12;
        const M: usize = 8;

        let threads = (0..M)
            .map(|_| {
                let lock_map = lock_map.clone();
                let total = total.clone();
                let current = current.clone();
                std::thread::spawn(move || {
                    for _ in 0..N {
                        let _guard = lock_map.lock(&1);
                        let now = current.fetch_add(1, Ordering::AcqRel);
                        assert_eq!(now, 0);
                        total.fetch_add(1, Ordering::AcqRel);
                        let now = current.fetch_sub(1, Ordering::AcqRel);
                        assert_eq!(now, 1);
                    }
                })
            })
            .collect::<Vec<_>>();
        threads.into_iter().for_each(|t| t.join().unwrap());
        assert_eq!(total.load(Ordering::Acquire), N * M);
    }

    #[test]
    fn test_lock_map_random_key() {
        let lock_map = Arc::new(LockManager::<u32>::with_capacity(128));
        let total = Arc::new(AtomicUsize::default());
        const N: usize = 1 << 20;
        const M: usize = 8;

        let threads = (0..M)
            .map(|_| {
                let lock_map = lock_map.clone();
                let total = total.clone();
                std::thread::spawn(move || {
                    for _ in 0..N {
                        let key = rand::random();
                        let _guard = lock_map.lock(&key);
                        total.fetch_add(1, Ordering::AcqRel);
                    }
                })
            })
            .collect::<Vec<_>>();
        threads.into_iter().for_each(|t| t.join().unwrap());
        assert_eq!(total.load(Ordering::Acquire), N * M);
    }

    #[test]
    fn test_batch_lock() {
        let lock_map = Arc::new(LockManager::<usize>::with_capacity_and_shard_amount(
            128, 16,
        ));
        let total = Arc::new(AtomicUsize::default());
        let current = Arc::new(AtomicU32::default());
        const N: usize = 1 << 12;
        const M: usize = 8;

        let threads = (0..M)
            .map(|i| {
                let lock_map = lock_map.clone();
                let total = total.clone();
                let current = current.clone();
                let state = (0..M).filter(|v| *v != i).collect::<BTreeSet<_>>();
                std::thread::spawn(move || {
                    for _ in 0..N {
                        let _guard = lock_map.batch_lock(&state);
                        let now = current.fetch_add(1, Ordering::AcqRel);
                        assert_eq!(now, 0);
                        total.fetch_add(1, Ordering::AcqRel);
                        let now = current.fetch_sub(1, Ordering::AcqRel);
                        assert_eq!(now, 1);
                    }
                })
            })
            .collect::<Vec<_>>();
        threads.into_iter().for_each(|t| t.join().unwrap());
        assert_eq!(total.load(Ordering::Acquire), N * M);
    }

    #[test]
    #[should_panic(expected = "impossible: unlock a non-existent key!")]
    fn test_invalid_unlock() {
        let lock_map = LockManager::<u32>::default();
        let _lock_guard = LockGuard {
            map: &lock_map,
            key: &42,
        };
    }
}