singleflight/
lib.rs

use std::fmt::Debug;
use std::sync::Arc;

use crossbeam_utils::sync::WaitGroup;
use hashbrown::HashMap;
use parking_lot::Mutex;
6
7// Call is an in-flight or completed call to work.
8#[derive(Clone, Debug)]
9struct Call<T>
10where
11    T: Clone + Debug,
12{
13    wg: WaitGroup,
14    res: Option<T>,
15}
16
17impl<T> Call<T>
18where
19    T: Clone + Debug,
20{
21    fn new() -> Call<T> {
22        Call {
23            wg: WaitGroup::new(),
24            res: None,
25        }
26    }
27}
28
29/// Group represents a class of work and creates a space in which units of work
30/// can be executed with duplicate suppression.
31#[derive(Default)]
32pub struct Group<T>
33where
34    T: Clone + Debug,
35{
36    m: Mutex<HashMap<String, Box<Call<T>>>>,
37}
38
39impl<T> Group<T>
40where
41    T: Clone + Debug,
42{
43    /// Create a new Group to do work with.
44    pub fn new() -> Group<T> {
45        Group {
46            m: Mutex::new(HashMap::new()),
47        }
48    }
49
50    /// Execute and return the value for a given function, making sure that only one
51    /// operation is in-flight at a given moment. If a duplicate call comes in, that caller will
52    /// wait until the original call completes and return the same value.
53    pub fn work<F>(&self, key: &str, func: F) -> T
54    where
55        F: Fn() -> T,
56    {
57        let mut m = self.m.lock();
58
59        if let Some(c) = m.get(key) {
60            let c = c.clone();
61            drop(m);
62            c.wg.wait();
63            return c.res.unwrap();
64        }
65
66        let c = Call::new();
67        let wg = c.wg.clone();
68        let job = m.entry(key.to_owned()).or_insert(Box::new(c));
69        job.res = Some(func());
70        drop(m);
71        drop(wg);
72
73        let mut m = self.m.lock();
74        let c = m.remove(key).unwrap();
75        drop(m);
76
77        c.res.unwrap()
78    }
79}
80
#[cfg(test)]
mod tests {
    use super::Group;

    /// Value every closure in these tests is expected to produce.
    const RES: usize = 7;

    /// A lone caller gets back exactly what its closure returns.
    #[test]
    fn test_simple() {
        let group = Group::new();
        assert_eq!(group.work("key", || RES), RES);
    }

    /// Ten scoped threads request the same key and each must observe `RES`.
    #[test]
    fn test_multiple_threads() {
        use std::time::Duration;

        use crossbeam_utils::thread;

        // Stand-in for slow work: sleeps for 500 nanoseconds before answering.
        fn expensive_fn() -> usize {
            std::thread::sleep(Duration::from_nanos(500));
            RES
        }

        let group = Group::new();
        let outcome = thread::scope(|scope| {
            for _ in 0..10 {
                scope.spawn(|_| {
                    assert_eq!(group.work("key", expensive_fn), RES);
                });
            }
        });
        // `scope` reports an error if any spawned thread panicked.
        outcome.unwrap();
    }
}