exec_rs/
sync.rs

1use std::{
2    hash::Hash,
3    sync::{
4        atomic::{AtomicUsize, Ordering},
5        Arc,
6    },
7};
8
9use crate::{Invoker, ModeWrapper};
10
11/// Task executor that can synchronise tasks by value of a key provided when submitting a task.
12///
13/// For example, if i32 is used as key type, then tasks submitted with keys 3 and 5 may run concurrently
14/// but several tasks submitted with key 7 are synchronised by a mutex mapped to the key.
15///
16/// Manages a concurrent hash map that maps [`ReferenceCountedMutex`](struct.ReferenceCountedMutex.html)
17/// elements to the used keys. The [`ReferenceCountedMutex`](struct.ReferenceCountedMutex.html) struct
18/// holds a mutex used for synchronisation and removes itself from the map automatically if not used by
19/// any thread anymore by managing an atomic reference counter. If the counter is decremented from 1 to
20/// 0 the element is removed from the map and the counter cannot be incremented back up again. If the counter
21/// reached 0 future increments fail and a new `ReferenceCountedMutex` is created instead. When creating
22/// a new `ReferenceCountedMutex` and inserting it to the map fails because another thread has already
23/// created an element for the same key, the current thread tries to use the found existing element instead
24/// as long as its reference counter is valid (greater than 0), else it retries creating the element.
25///
26/// The type of the key used for synchronisation must be able to be used as a key for the map and thus
27/// must implement `Sync + Send + Clone + Hash + Ord` and have a static lifetime.
28pub struct MutexSync<K>
29where
30    K: 'static + Sync + Send + Clone + Hash + Ord,
31{
32    mutex_map: flurry::HashMap<K, ReferenceCountedMutex>,
33}
34
35impl<K> Default for MutexSync<K>
36where
37    K: 'static + Sync + Send + Clone + Hash + Ord,
38{
39    fn default() -> Self {
40        MutexSync {
41            mutex_map: flurry::HashMap::new(),
42        }
43    }
44}
45
46impl<K> MutexSync<K>
47where
48    K: 'static + Sync + Send + Clone + Hash + Ord,
49{
50    pub fn new() -> Self {
51        Self::default()
52    }
53
54    /// Submits a task for execution using the provided key for synchronisation. Uses the mutex of the
55    /// [`ReferenceCountedMutex`](struct.ReferenceCountedMutex.html) mapped to the key to synchronise
56    /// execution of the task. The `ReferenceCountedMutex` removes itself from the map automatically
57    /// as soon as no thread is using it anymore by managing an atomic reference counter.
58    ///
59    /// If the mutex map does not already contain an entry for the provided key, meaning no task is
60    /// currently running with a mutex mapped to the same key, the current thread attempts to insert
61    /// a new [`ReferenceCountedMutex`](struct.ReferenceCountedMutex.html), with an initial reference
62    /// count of 1, and if it succeeds acquires the mutex and runs the task, else it tries to use the
63    /// found existing `ReferenceCountedMutex` if its reference counter is valid (greater than 0) or
64    /// else retries creating the `ReferenceCountedMutex`.
65    ///
66    /// If the mutex map already contains a `ReferenceCountedMutex` mapped to the same key, meaning there
67    /// are threads running using the same key, the current thread attempts to increment the reference
68    /// counter on the found `ReferenceCountedMutex` and if it succeeds it waits to acquire the mutex and
69    /// then executes the task, if it fails, because the counter has been decremented to 0 because another
70    /// thread is in the process of removing the mutex from the map, it tries to create a new
71    /// `ReferenceCountedMutex`, same as above.
72    pub fn evaluate<R, F: FnOnce() -> R>(&self, key: K, task: F) -> R {
73        let mutex_map = self.mutex_map.pin();
74
75        let rc_mutex = if let Some(mutex) = mutex_map.get(&key) {
76            if mutex.increment_rc() > 0 {
77                mutex
78            } else {
79                Self::create_mutex(&key, &mutex_map)
80            }
81        } else {
82            Self::create_mutex(&key, &mutex_map)
83        };
84
85        let _guard = rc_mutex.lock(&key, &mutex_map);
86        task()
87    }
88
89    #[inline]
90    fn create_mutex<'a>(
91        key: &K,
92        map_ref: &'a flurry::HashMapRef<'a, K, ReferenceCountedMutex>,
93    ) -> &'a ReferenceCountedMutex {
94        let mut mutex = ReferenceCountedMutex::new();
95        loop {
96            match map_ref.try_insert(key.clone(), mutex) {
97                Ok(mutex_ref) => break mutex_ref,
98                Err(insert_err) => {
99                    let curr = insert_err.current;
100                    if curr.increment_rc() > 0 {
101                        break curr;
102                    } else {
103                        mutex = insert_err.not_inserted;
104                    }
105                }
106            }
107        }
108    }
109}
110
111/// Struct that holds the mutex used for synchronisation and manages removing itself from the
112/// containing map once no longer referenced by any threads. Removes itself from the map when
113/// decrementing the counter from 1 to 0 and makes sure that the counter cannot be incremented
114/// back up once reaching 0 in case a thread finds a ReferenceCountedMutex that is in the
115/// process of being removed from the map.
116pub struct ReferenceCountedMutex {
117    mutex: parking_lot::Mutex<()>,
118    rc: AtomicUsize,
119}
120
121impl ReferenceCountedMutex {
122    /// Create a new ReferenceCountedMutex with an initial reference count of 1.
123    fn new() -> Self {
124        ReferenceCountedMutex {
125            mutex: parking_lot::Mutex::new(()),
126            rc: AtomicUsize::new(1),
127        }
128    }
129
130    /// Requires the lock for the underlying mutex and returns a `ReferenceCountedMutexGuard`
131    /// that additionally calls `decrement_rc` when dropped.
132    fn lock<'a, K>(
133        &'a self,
134        key: &'a K,
135        map: &'a flurry::HashMapRef<'a, K, ReferenceCountedMutex>,
136    ) -> ReferenceCountedMutexGuard<'a, K>
137    where
138        K: 'static + Sync + Send + Clone + Hash + Ord,
139    {
140        let _mutex_guard = self.mutex.lock();
141
142        ReferenceCountedMutexGuard {
143            map,
144            key,
145            mutex: self,
146            _mutex_guard,
147        }
148    }
149
150    /// Attempts to increment the reference counter, failing to do so if it has reached 0 already.
151    /// Callers can check whether the increment succeed by checking whether the witnessed value is 0.
152    fn increment_rc(&self) -> usize {
153        let curr = self.rc.load(Ordering::Relaxed);
154
155        // disallow incrementing once it reached 0
156        if curr == 0 {
157            return curr;
158        }
159
160        let mut expected = curr;
161
162        loop {
163            match self.rc.compare_exchange_weak(
164                expected,
165                expected + 1,
166                Ordering::Relaxed,
167                Ordering::Relaxed,
168            ) {
169                Ok(witnessed) => break witnessed,
170                Err(witnessed) if witnessed == 0 => break witnessed,
171                Err(witnessed) => expected = witnessed,
172            }
173        }
174    }
175
176    /// Decrements the reference counter, removing the entry from the map if the previous value was 0.
177    /// This is protected against race conditions since ReferenceCountedMutex elements cannot be used
178    /// anymore once the reference counter has reached 0, so even if some other thread might still
179    /// find this ReferenceCountedMutex in the map after this thread has decremented the rc to 0 but
180    /// before this thread removed the element from the map, the other thread will fail to increment
181    /// the reference counter and thus has to create a new ReferenceCountedMutex element.
182    fn decrement_rc<K>(&self, key: &K, map_ref: &flurry::HashMapRef<K, ReferenceCountedMutex>)
183    where
184        K: 'static + Sync + Send + Clone + Hash + Ord,
185    {
186        let curr = self.rc.fetch_sub(1, Ordering::Relaxed);
187
188        if curr == 1 {
189            map_ref.remove(key);
190        }
191    }
192}
193
194/// Wrapper struct containing the actual mutex guard for the underlying Mutex that additionally calls
195/// `ReferenceCountedMutex::decrement_rc` when dropped.
196struct ReferenceCountedMutexGuard<'a, K>
197where
198    K: 'static + Sync + Send + Clone + Hash + Ord,
199{
200    map: &'a flurry::HashMapRef<'a, K, ReferenceCountedMutex>,
201    key: &'a K,
202    mutex: &'a ReferenceCountedMutex,
203    _mutex_guard: parking_lot::MutexGuard<'a, ()>,
204}
205
206impl<K> Drop for ReferenceCountedMutexGuard<'_, K>
207where
208    K: 'static + Sync + Send + Clone + Hash + Ord,
209{
210    fn drop(&mut self) {
211        self.mutex.decrement_rc(self.key, self.map);
212    }
213}
214
215/// Struct that implements the [`ModeWrapper`](trait.ModeWrapper.html) and [`Invoker`](trait.Invoker.html)
216/// traits for any type that borrows [`MutexSync`](struct.MutexSync.html) and a specific key. Enables using
217/// [`MutexSync`](struct.MutexSync.html) as a `ModeWrapper` or `Invoker`.
218pub struct MutexSyncExecutor<K, M>
219where
220    K: 'static + Sync + Send + Clone + Hash + Ord,
221    M: std::borrow::Borrow<MutexSync<K>> + 'static,
222{
223    key: K,
224    mutex_sync: M,
225}
226
227impl<T, K, M> ModeWrapper<'static, T> for MutexSyncExecutor<K, M>
228where
229    T: 'static,
230    K: 'static + Sync + Send + Clone + Hash + Ord,
231    M: std::borrow::Borrow<MutexSync<K>> + 'static,
232{
233    fn wrap<'f>(
234        self: Arc<Self>,
235        task: Box<(dyn FnOnce() -> T + 'f)>,
236    ) -> Box<(dyn FnOnce() -> T + 'f)> {
237        Box::new(move || self.mutex_sync.borrow().evaluate(self.key.clone(), task))
238    }
239}
240
241impl<K, M> Invoker for MutexSyncExecutor<K, M>
242where
243    K: 'static + Sync + Send + Clone + Hash + Ord,
244    M: std::borrow::Borrow<MutexSync<K>> + 'static,
245{
246    fn do_invoke<'f, T: 'f, F: FnOnce() -> T + 'f>(
247        &'f self,
248        mode: Option<&'f super::Mode<'f, T>>,
249        task: F,
250    ) -> T {
251        self.mutex_sync.borrow().evaluate(self.key.clone(), || {
252            if let Some(mode) = mode {
253                super::invoke(mode, task)
254            } else {
255                task()
256            }
257        })
258    }
259}
260
#[cfg(test)]
mod tests {

    use crate::Invoker;

    use super::{MutexSync, MutexSyncExecutor};
    use std::sync::{
        atomic::{AtomicBool, AtomicI32, Ordering},
        Arc,
    };

    /// Five outer threads each spawn one inner thread per key in 0..15. The inner threads run
    /// concurrently, so every key is contended by five tasks. The test checks that no two tasks
    /// with the same key ever overlap.
    #[test]
    fn it_works() {
        let mutex_sync = Arc::new(MutexSync::<i32>::new());
        let failed = Arc::new(AtomicBool::new(false));
        let running_set = Arc::new(flurry::HashSet::<i32>::new());

        let mut handles = Vec::with_capacity(5);

        for _ in 0..5 {
            let mutex_sync = mutex_sync.clone();
            let failed = failed.clone();
            let running_set = running_set.clone();

            let handle = std::thread::spawn(move || {
                // Spawn all inner threads before joining any of them; joining inside the loop
                // would serialise the keys and defeat the purpose of the test.
                let mut handles = Vec::with_capacity(15);

                for i in 0..15 {
                    let mutex_sync = mutex_sync.clone();
                    let failed = failed.clone();
                    let running_set = running_set.clone();

                    let handle = std::thread::spawn(move || {
                        let running_set = running_set.pin();
                        mutex_sync.evaluate(i, || {
                            if running_set.contains(&i) {
                                failed.store(true, Ordering::Relaxed);
                            }

                            running_set.insert(i);

                            std::thread::sleep(std::time::Duration::from_secs(1));

                            if !running_set.contains(&i) {
                                failed.store(true, Ordering::Relaxed);
                            }

                            std::thread::sleep(std::time::Duration::from_secs(1));
                            running_set.remove(&i);

                            if running_set.contains(&i) {
                                failed.store(true, Ordering::Relaxed);
                            }
                        })
                    });

                    handles.push(handle);
                }

                for handle in handles {
                    handle.join().unwrap();
                }
            });

            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert!(!failed.load(Ordering::Relaxed));
    }

    /// Tasks with different keys must be able to run at the same time.
    #[test]
    fn test_concurrent_different_key() {
        let running = Arc::new(AtomicBool::new(false));
        let failed = Arc::new(AtomicBool::new(false));

        let mutex_sync = Arc::new(MutexSync::<i32>::new());

        let mut handles = Vec::with_capacity(2);

        let mutex_sync1 = mutex_sync.clone();
        let running1 = running.clone();
        let handle1 = std::thread::spawn(move || {
            mutex_sync1.evaluate(1, move || {
                running1.store(true, Ordering::Relaxed);
                std::thread::sleep(std::time::Duration::from_secs(5));
                running1.store(false, Ordering::Relaxed);
            });
        });
        handles.push(handle1);

        let mutex_sync2 = mutex_sync.clone();
        let running2 = running.clone();
        let failed2 = failed.clone();
        let handle2 = std::thread::spawn(move || {
            mutex_sync2.evaluate(2, move || {
                std::thread::sleep(std::time::Duration::from_secs(3));

                if !running2.load(Ordering::Relaxed) {
                    failed2.store(true, Ordering::Relaxed);
                }
            });
        });
        handles.push(handle2);

        for handle in handles {
            handle.join().unwrap();
        }

        assert!(!failed.load(Ordering::Relaxed));
    }

    /// `MutexSyncExecutor` used directly as an `Invoker` and wrapped into a `Mode` must keep
    /// tasks with the same key mutually exclusive.
    #[test]
    fn test_mutex_sync_executor() {
        let mutex_sync = Arc::new(MutexSync::<i32>::new());
        let failed = Arc::new(AtomicBool::new(false));
        let running_set = Arc::new(flurry::HashSet::<i32>::new());
        let multiplier_map = Arc::new(flurry::HashMap::<i32, AtomicI32>::new());

        {
            let map = multiplier_map.pin();
            for i in 0..5 {
                map.insert(i, AtomicI32::new(0));
            }
        }

        let mutex_sync_executor = MutexSyncExecutor {
            key: 1,
            mutex_sync: MutexSync::<i32>::new(),
        };

        assert_eq!(mutex_sync_executor.invoke(|| 4), 4);

        let mut handles = Vec::with_capacity(25);

        for _ in 0..5 {
            for i in 0..5 {
                let failed = failed.clone();
                let failed2 = failed.clone();
                let running_set = running_set.clone();
                let multiplier_map = multiplier_map.clone();

                let executor = MutexSyncExecutor {
                    key: i,
                    mutex_sync: mutex_sync.clone(),
                };

                let handle = std::thread::spawn(move || {
                    let running_set = running_set.pin();
                    executor.invoke(move || {
                        if running_set.contains(&i) {
                            failed.store(true, Ordering::Relaxed);
                        }

                        running_set.insert(i);

                        std::thread::sleep(std::time::Duration::from_secs(1));

                        if !running_set.contains(&i) {
                            failed.store(true, Ordering::Relaxed);
                        }

                        std::thread::sleep(std::time::Duration::from_secs(1));
                        running_set.remove(&i);

                        if running_set.contains(&i) {
                            failed.store(true, Ordering::Relaxed);
                        }
                    });

                    let mode = crate::Mode::<i32>::new().with(executor);
                    let result = crate::invoke(&mode, move || {
                        let multiplier_map = multiplier_map.pin();
                        let multiplier = multiplier_map.get(&i).unwrap();
                        multiplier.store(2, Ordering::Relaxed);
                        std::thread::sleep(std::time::Duration::from_secs(1));
                        let result = multiplier.load(Ordering::Relaxed) * 4;
                        multiplier.store(0, Ordering::Relaxed);
                        result
                    });

                    if result != 8 {
                        failed2.store(true, Ordering::Relaxed);
                    }
                });

                handles.push(handle);
            }
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert!(!failed.load(Ordering::Relaxed));
    }

    /// A panicking task must still release its reference so the map entry is removed.
    #[test]
    fn test_remove_mutex_on_panic() {
        let mutex_sync = Arc::new(MutexSync::<i32>::new());

        let m = mutex_sync.clone();
        let handle = std::thread::spawn(move || {
            m.evaluate(1, || {
                panic!("test panic");
            });
        });

        let _ = handle.join();
        assert!(mutex_sync.mutex_map.is_empty());
    }
}