1use std::fmt::Debug;
2
3use crossbeam_utils::sync::WaitGroup;
4use hashbrown::HashMap;
5use parking_lot::Mutex;
6
7#[derive(Clone, Debug)]
9struct Call<T>
10where
11 T: Clone + Debug,
12{
13 wg: WaitGroup,
14 res: Option<T>,
15}
16
17impl<T> Call<T>
18where
19 T: Clone + Debug,
20{
21 fn new() -> Call<T> {
22 Call {
23 wg: WaitGroup::new(),
24 res: None,
25 }
26 }
27}
28
29#[derive(Default)]
32pub struct Group<T>
33where
34 T: Clone + Debug,
35{
36 m: Mutex<HashMap<String, Box<Call<T>>>>,
37}
38
39impl<T> Group<T>
40where
41 T: Clone + Debug,
42{
43 pub fn new() -> Group<T> {
45 Group {
46 m: Mutex::new(HashMap::new()),
47 }
48 }
49
50 pub fn work<F>(&self, key: &str, func: F) -> T
54 where
55 F: Fn() -> T,
56 {
57 let mut m = self.m.lock();
58
59 if let Some(c) = m.get(key) {
60 let c = c.clone();
61 drop(m);
62 c.wg.wait();
63 return c.res.unwrap();
64 }
65
66 let c = Call::new();
67 let wg = c.wg.clone();
68 let job = m.entry(key.to_owned()).or_insert(Box::new(c));
69 job.res = Some(func());
70 drop(m);
71 drop(wg);
72
73 let mut m = self.m.lock();
74 let c = m.remove(key).unwrap();
75 drop(m);
76
77 c.res.unwrap()
78 }
79}
80
#[cfg(test)]
mod tests {
    use super::Group;

    /// Value every closure in these tests is expected to produce.
    const RES: usize = 7;

    /// A lone caller gets back exactly what its closure returns.
    #[test]
    fn test_simple() {
        let group: Group<usize> = Group::new();
        let outcome = group.work("key", || RES);
        assert_eq!(outcome, RES);
    }

    /// Ten threads asking for the same key must all observe the same value.
    #[test]
    fn test_multiple_threads() {
        use crossbeam_utils::thread::scope;
        use std::time::Duration;

        // Stand-in for a costly computation: a 500 ns pause, then the answer.
        fn slow_answer() -> usize {
            std::thread::sleep(Duration::from_nanos(500));
            RES
        }

        let group: Group<usize> = Group::new();
        let joined = scope(|workers| {
            for _ in 0..10 {
                workers.spawn(|_| {
                    let outcome = group.work("key", slow_answer);
                    assert_eq!(outcome, RES);
                });
            }
        });
        // Surfaces a panic from any worker thread as a test failure.
        joined.unwrap();
    }
}