1use std::{
2 hash::Hash,
3 sync::{
4 atomic::{AtomicUsize, Ordering},
5 Arc,
6 },
7};
8
9use crate::{Invoker, ModeWrapper};
10
11pub struct MutexSync<K>
29where
30 K: 'static + Sync + Send + Clone + Hash + Ord,
31{
32 mutex_map: flurry::HashMap<K, ReferenceCountedMutex>,
33}
34
35impl<K> Default for MutexSync<K>
36where
37 K: 'static + Sync + Send + Clone + Hash + Ord,
38{
39 fn default() -> Self {
40 MutexSync {
41 mutex_map: flurry::HashMap::new(),
42 }
43 }
44}
45
46impl<K> MutexSync<K>
47where
48 K: 'static + Sync + Send + Clone + Hash + Ord,
49{
50 pub fn new() -> Self {
51 Self::default()
52 }
53
54 pub fn evaluate<R, F: FnOnce() -> R>(&self, key: K, task: F) -> R {
73 let mutex_map = self.mutex_map.pin();
74
75 let rc_mutex = if let Some(mutex) = mutex_map.get(&key) {
76 if mutex.increment_rc() > 0 {
77 mutex
78 } else {
79 Self::create_mutex(&key, &mutex_map)
80 }
81 } else {
82 Self::create_mutex(&key, &mutex_map)
83 };
84
85 let _guard = rc_mutex.lock(&key, &mutex_map);
86 task()
87 }
88
89 #[inline]
90 fn create_mutex<'a>(
91 key: &K,
92 map_ref: &'a flurry::HashMapRef<'a, K, ReferenceCountedMutex>,
93 ) -> &'a ReferenceCountedMutex {
94 let mut mutex = ReferenceCountedMutex::new();
95 loop {
96 match map_ref.try_insert(key.clone(), mutex) {
97 Ok(mutex_ref) => break mutex_ref,
98 Err(insert_err) => {
99 let curr = insert_err.current;
100 if curr.increment_rc() > 0 {
101 break curr;
102 } else {
103 mutex = insert_err.not_inserted;
104 }
105 }
106 }
107 }
108 }
109}
110
111pub struct ReferenceCountedMutex {
117 mutex: parking_lot::Mutex<()>,
118 rc: AtomicUsize,
119}
120
121impl ReferenceCountedMutex {
122 fn new() -> Self {
124 ReferenceCountedMutex {
125 mutex: parking_lot::Mutex::new(()),
126 rc: AtomicUsize::new(1),
127 }
128 }
129
130 fn lock<'a, K>(
133 &'a self,
134 key: &'a K,
135 map: &'a flurry::HashMapRef<'a, K, ReferenceCountedMutex>,
136 ) -> ReferenceCountedMutexGuard<'a, K>
137 where
138 K: 'static + Sync + Send + Clone + Hash + Ord,
139 {
140 let _mutex_guard = self.mutex.lock();
141
142 ReferenceCountedMutexGuard {
143 map,
144 key,
145 mutex: self,
146 _mutex_guard,
147 }
148 }
149
150 fn increment_rc(&self) -> usize {
153 let curr = self.rc.load(Ordering::Relaxed);
154
155 if curr == 0 {
157 return curr;
158 }
159
160 let mut expected = curr;
161
162 loop {
163 match self.rc.compare_exchange_weak(
164 expected,
165 expected + 1,
166 Ordering::Relaxed,
167 Ordering::Relaxed,
168 ) {
169 Ok(witnessed) => break witnessed,
170 Err(witnessed) if witnessed == 0 => break witnessed,
171 Err(witnessed) => expected = witnessed,
172 }
173 }
174 }
175
176 fn decrement_rc<K>(&self, key: &K, map_ref: &flurry::HashMapRef<K, ReferenceCountedMutex>)
183 where
184 K: 'static + Sync + Send + Clone + Hash + Ord,
185 {
186 let curr = self.rc.fetch_sub(1, Ordering::Relaxed);
187
188 if curr == 1 {
189 map_ref.remove(key);
190 }
191 }
192}
193
194struct ReferenceCountedMutexGuard<'a, K>
197where
198 K: 'static + Sync + Send + Clone + Hash + Ord,
199{
200 map: &'a flurry::HashMapRef<'a, K, ReferenceCountedMutex>,
201 key: &'a K,
202 mutex: &'a ReferenceCountedMutex,
203 _mutex_guard: parking_lot::MutexGuard<'a, ()>,
204}
205
206impl<K> Drop for ReferenceCountedMutexGuard<'_, K>
207where
208 K: 'static + Sync + Send + Clone + Hash + Ord,
209{
210 fn drop(&mut self) {
211 self.mutex.decrement_rc(self.key, self.map);
212 }
213}
214
215pub struct MutexSyncExecutor<K, M>
219where
220 K: 'static + Sync + Send + Clone + Hash + Ord,
221 M: std::borrow::Borrow<MutexSync<K>> + 'static,
222{
223 key: K,
224 mutex_sync: M,
225}
226
227impl<T, K, M> ModeWrapper<'static, T> for MutexSyncExecutor<K, M>
228where
229 T: 'static,
230 K: 'static + Sync + Send + Clone + Hash + Ord,
231 M: std::borrow::Borrow<MutexSync<K>> + 'static,
232{
233 fn wrap<'f>(
234 self: Arc<Self>,
235 task: Box<(dyn FnOnce() -> T + 'f)>,
236 ) -> Box<(dyn FnOnce() -> T + 'f)> {
237 Box::new(move || self.mutex_sync.borrow().evaluate(self.key.clone(), task))
238 }
239}
240
241impl<K, M> Invoker for MutexSyncExecutor<K, M>
242where
243 K: 'static + Sync + Send + Clone + Hash + Ord,
244 M: std::borrow::Borrow<MutexSync<K>> + 'static,
245{
246 fn do_invoke<'f, T: 'f, F: FnOnce() -> T + 'f>(
247 &'f self,
248 mode: Option<&'f super::Mode<'f, T>>,
249 task: F,
250 ) -> T {
251 self.mutex_sync.borrow().evaluate(self.key.clone(), || {
252 if let Some(mode) = mode {
253 super::invoke(mode, task)
254 } else {
255 task()
256 }
257 })
258 }
259}
260
#[cfg(test)]
mod tests {

    use crate::Invoker;

    use super::{MutexSync, MutexSyncExecutor};
    use std::sync::{
        atomic::{AtomicBool, AtomicI32, Ordering},
        Arc,
    };

    /// Tasks evaluated under the same key must never overlap: while one task
    /// has its key in `running_set`, no other task for that key may start.
    #[test]
    fn it_works() {
        let mutex_sync = Arc::new(MutexSync::<i32>::new());
        let failed = Arc::new(AtomicBool::new(false));
        let running_set = Arc::new(flurry::HashSet::<i32>::new());

        let mut handles = Vec::with_capacity(5);

        for _ in 0..5 {
            let mutex_sync = mutex_sync.clone();
            let failed = failed.clone();
            let running_set = running_set.clone();

            let handle = std::thread::spawn(move || {
                for i in 0..15 {
                    let mutex_sync = mutex_sync.clone();
                    let failed = failed.clone();
                    let running_set = running_set.clone();

                    // Each key is handled on its own thread, which is joined
                    // before moving on to the next key; contention on a key
                    // comes from the five outer threads.
                    std::thread::spawn(move || {
                        let running_set = running_set.pin();
                        mutex_sync.evaluate(i, || {
                            // Nobody else may be inside the section for `i`.
                            if running_set.contains(&i) {
                                failed.store(true, Ordering::Relaxed);
                            }

                            running_set.insert(i);

                            std::thread::sleep(std::time::Duration::from_secs(1));

                            // Nobody else may have cleared our marker.
                            if !running_set.contains(&i) {
                                failed.store(true, Ordering::Relaxed);
                            }

                            std::thread::sleep(std::time::Duration::from_secs(1));
                            running_set.remove(&i);

                            // Nobody else may have re-inserted the marker.
                            if running_set.contains(&i) {
                                failed.store(true, Ordering::Relaxed);
                            }
                        })
                    })
                    .join()
                    .unwrap();
                }
            });

            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert!(!failed.load(Ordering::Relaxed));
    }

    /// Tasks under different keys must be able to run at the same time: the
    /// task for key 2 checks, three seconds in, that the five-second task for
    /// key 1 is still running.
    #[test]
    fn test_concurrent_different_key() {
        let running = Arc::new(AtomicBool::new(false));
        let failed = Arc::new(AtomicBool::new(false));

        let mutex_sync = Arc::new(MutexSync::<i32>::new());

        let mut handles = Vec::with_capacity(2);

        let mutex_sync1 = mutex_sync.clone();
        let running1 = running.clone();
        let handle1 = std::thread::spawn(move || {
            mutex_sync1.evaluate(1, move || {
                running1.store(true, Ordering::Relaxed);
                std::thread::sleep(std::time::Duration::from_secs(5));
                running1.store(false, Ordering::Relaxed);
            });
        });
        handles.push(handle1);

        let mutex_sync2 = mutex_sync.clone();
        let running2 = running.clone();
        let failed2 = failed.clone();
        let handle2 = std::thread::spawn(move || {
            mutex_sync2.evaluate(2, move || {
                std::thread::sleep(std::time::Duration::from_secs(3));

                if !running2.load(Ordering::Relaxed) {
                    failed2.store(true, Ordering::Relaxed);
                }
            });
        });
        handles.push(handle2);

        for handle in handles {
            handle.join().unwrap();
        }

        assert!(!failed.load(Ordering::Relaxed));
    }

    /// Exercises `MutexSyncExecutor` both through `Invoker::invoke` and as
    /// part of a `Mode`, checking that same-key work is serialized.
    #[test]
    fn test_mutex_sync_executor() {
        let mutex_sync = Arc::new(MutexSync::<i32>::new());
        let failed = Arc::new(AtomicBool::new(false));
        let running_set = Arc::new(flurry::HashSet::<i32>::new());
        let multiplier_map = Arc::new(flurry::HashMap::<i32, AtomicI32>::new());

        {
            let map = multiplier_map.pin();
            for i in 0..5 {
                map.insert(i, AtomicI32::new(0));
            }
        }

        // An executor may also own its `MutexSync` directly.
        let mutex_sync_executor = MutexSyncExecutor {
            key: 1,
            mutex_sync: MutexSync::<i32>::new(),
        };

        assert_eq!(mutex_sync_executor.invoke(|| 4), 4);

        let mut handles = Vec::with_capacity(25);

        for _ in 0..5 {
            for i in 0..5 {
                let failed = failed.clone();
                let failed2 = failed.clone();
                let running_set = running_set.clone();
                let multiplier_map = multiplier_map.clone();

                let executor = MutexSyncExecutor {
                    key: i,
                    mutex_sync: mutex_sync.clone(),
                };

                let handle = std::thread::spawn(move || {
                    let running_set = running_set.pin();
                    executor.invoke(move || {
                        if running_set.contains(&i) {
                            failed.store(true, Ordering::Relaxed);
                        }

                        running_set.insert(i);

                        std::thread::sleep(std::time::Duration::from_secs(1));

                        if !running_set.contains(&i) {
                            failed.store(true, Ordering::Relaxed);
                        }

                        std::thread::sleep(std::time::Duration::from_secs(1));
                        running_set.remove(&i);

                        if running_set.contains(&i) {
                            failed.store(true, Ordering::Relaxed);
                        }
                    });

                    let mode = crate::Mode::<i32>::new().with(executor);
                    let result = crate::invoke(&mode, move || {
                        let multiplier_map = multiplier_map.pin();
                        let multiplier = multiplier_map.get(&i).unwrap();
                        // If another task for key `i` ran concurrently it
                        // could reset the multiplier to 0 during the sleep.
                        multiplier.store(2, Ordering::Relaxed);
                        std::thread::sleep(std::time::Duration::from_secs(1));
                        let result = multiplier.load(Ordering::Relaxed) * 4;
                        multiplier.store(0, Ordering::Relaxed);
                        result
                    });

                    if result != 8 {
                        failed2.store(true, Ordering::Relaxed);
                    }
                });

                handles.push(handle);
            }
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert!(!failed.load(Ordering::Relaxed));
    }

    /// A panicking task must still release its reference so the key's entry
    /// does not linger in the map.
    #[test]
    fn test_remove_mutex_on_panic() {
        let mutex_sync = Arc::new(MutexSync::<i32>::new());

        let m = mutex_sync.clone();
        let handle = std::thread::spawn(move || {
            m.evaluate(1, || {
                panic!("test panic");
            });
        });

        let _ = handle.join();
        assert!(mutex_sync.mutex_map.is_empty());
    }
}