// async_map_reduce/managers.rs

use std::collections::VecDeque;
use std::marker::PhantomData;
use std::thread;

use crate::entities::{CtxWrapper, MapReduceError, ResWrapper, ThreadError, Worker};
use crate::traits::{SCFunMF, SCFunRF, Sss};

/// Context handed out by [`manager`] until a real one is supplied.
const DEFAULT_CONTEXT: () = ();

/// First stage of the map/reduce builder.
///
/// It only carries a borrowed context, which is later passed by reference to
/// every map and reduce call.
pub struct Manager<'a, Ctx> {
    context: &'a Ctx,
}

13pub struct MapManager<'a, Ctx, Req, Resp: Sss, MF: Fn(&Ctx, usize, &Req) -> Resp> {
14    phantom_req: PhantomData<Req>,
15    manager: Manager<'a, Ctx>,
16    fun: MF,
17}
18
19pub struct ReduceManager<
20    'a,
21    Ctx,
22    Req,
23    Resp: Sss,
24    MF: Fn(&Ctx, usize, &Req) -> Resp,
25    RF: Fn(&Ctx, Resp, Resp) -> Resp,
26> {
27    map_manager: MapManager<'a, Ctx, Req, Resp, MF>,
28    fun: RF,
29}
30
31pub struct ReduceRManager<
32    'a,
33    Ctx,
34    Req,
35    ResOk: Sss,
36    ResErr: Sss,
37    MF: Fn(&Ctx, usize, &Req) -> Result<ResOk, ResErr>,
38    RF: Fn(&Ctx, ResOk, ResOk) -> Result<ResOk, ResErr>,
39> {
40    map_manager: MapManager<'a, Ctx, Req, Result<ResOk, ResErr>, MF>,
41    fun: RF,
42}
43
44pub struct FullManager<
45    'a,
46    Ctx,
47    Req,
48    Resp: Sss,
49    MF: Fn(&Ctx, usize, &Req) -> Resp,
50    RF: Fn(&Ctx, Resp, Resp) -> Resp,
51> {
52    reduce_manager: ReduceManager<'a, Ctx, Req, Resp, MF, RF>,
53    default_value: Resp,
54}
55
56pub struct FullRManager<
57    'a,
58    Ctx,
59    Req,
60    ROk: Sss,
61    RErr: Sss,
62    MF: Fn(&Ctx, usize, &Req) -> Result<ROk, RErr>,
63    RF: Fn(&Ctx, ROk, ROk) -> Result<ROk, RErr>,
64> {
65    reduce_manager: ReduceRManager<'a, Ctx, Req, ROk, RErr, MF, RF>,
66    default_value: ROk,
67}
68
69pub fn manager() -> Manager<'static, ()> {
70    Manager {
71        context: &DEFAULT_CONTEXT,
72    }
73}
74
75impl<'a, Ctx> Manager<'a, Ctx> {
76    pub fn context<NewCtx>(self, context: &NewCtx) -> Manager<NewCtx> {
77        Manager { context }
78    }
79
80    pub fn map<Req, Resp: Sss, MF: Fn(&Ctx, usize, &Req) -> Resp>(
81        self,
82        fun: MF,
83    ) -> MapManager<'a, Ctx, Req, Resp, MF> {
84        MapManager {
85            manager: self,
86            fun,
87            phantom_req: PhantomData::default(),
88        }
89    }
90}
91
92impl<'a, Ctx, Req, Resp: Sss, MF: SCFunMF<Ctx, Req, Resp>> MapManager<'a, Ctx, Req, Resp, MF> {
93    pub fn reduce<RF: SCFunRF<Ctx, Resp>>(
94        self,
95        fun: RF,
96    ) -> ReduceManager<'a, Ctx, Req, Resp, MF, RF> {
97        ReduceManager {
98            map_manager: self,
99            fun,
100        }
101    }
102}
103
104impl<'a, Ctx, Req, ROk: Sss, RErr: Sss, MF: SCFunMF<Ctx, Req, Result<ROk, RErr>>>
105    MapManager<'a, Ctx, Req, Result<ROk, RErr>, MF>
106{
107    pub fn reduce_result<RF: 'static + Fn(&Ctx, ROk, ROk) -> Result<ROk, RErr>>(
108        self,
109        fun: RF,
110    ) -> ReduceRManager<'a, Ctx, Req, ROk, RErr, MF, RF> {
111        ReduceRManager {
112            map_manager: self,
113            fun,
114        }
115    }
116}
117
118impl<
119        'a,
120        Ctx,
121        Req,
122        Resp: Sss,
123        MF: SCFunMF<Ctx, Req, Resp>,
124        RF: 'static + Fn(&Ctx, Resp, Resp) -> Resp,
125    > ReduceManager<'a, Ctx, Req, Resp, MF, RF>
126{
127    pub fn default(self, default_value: Resp) -> FullManager<'a, Ctx, Req, Resp, MF, RF> {
128        FullManager {
129            reduce_manager: self,
130            default_value,
131        }
132    }
133}
134
135impl<
136        'a,
137        Ctx,
138        Req,
139        Resp: Sss + Default,
140        MF: SCFunMF<Ctx, Req, Resp>,
141        RF: 'static + Fn(&Ctx, Resp, Resp) -> Resp,
142    > ReduceManager<'a, Ctx, Req, Resp, MF, RF>
143{
144    pub fn run(self, chunks: &[Req]) -> Result<Resp, ThreadError> {
145        self.default(Resp::default()).run(chunks)
146    }
147}
148
149impl<
150        'a,
151        Ctx,
152        Req,
153        ROk: Sss,
154        RErr: Sss,
155        MF: SCFunMF<Ctx, Req, Result<ROk, RErr>>,
156        RF: 'static + Fn(&Ctx, ROk, ROk) -> Result<ROk, RErr>,
157    > ReduceRManager<'a, Ctx, Req, ROk, RErr, MF, RF>
158{
159    fn default(self, default_value: ROk) -> FullRManager<'a, Ctx, Req, ROk, RErr, MF, RF> {
160        FullRManager {
161            reduce_manager: self,
162            default_value,
163        }
164    }
165}
166
167impl<
168        'a,
169        Ctx,
170        Req,
171        ROk: Sss + Default,
172        RErr: Sss,
173        MF: SCFunMF<Ctx, Req, Result<ROk, RErr>>,
174        RF: 'static + Fn(&Ctx, ROk, ROk) -> Result<ROk, RErr>,
175    > ReduceRManager<'a, Ctx, Req, ROk, RErr, MF, RF>
176{
177    pub fn run(self, chunks: &[Req]) -> Result<ROk, MapReduceError<RErr>> {
178        self.default(ROk::default()).run(chunks)
179    }
180}
181
182fn make_workers<Ctx, Req, Resp: Sss, MF: Fn(&Ctx, usize, &Req) -> Resp>(
183    map_manager: &MapManager<Ctx, Req, Resp, MF>,
184    chunks: &[Req],
185) -> Vec<Worker<ResWrapper<Resp>>> {
186    let fun = CtxWrapper::new(&map_manager.fun);
187    let ctx = CtxWrapper::new(map_manager.manager.context);
188    let mut workers = Vec::new();
189    for (id, chunk) in chunks.iter().enumerate() {
190        let request = CtxWrapper::new(chunk);
191        let handler: thread::JoinHandle<ResWrapper<Resp>> = thread::spawn(move || {
192            let res = fun.get::<MF>()(ctx.get::<Ctx>(), id, request.get::<Req>());
193            ResWrapper::new(res)
194        });
195        let worker = Worker {
196            thread: Box::new(handler),
197        };
198        workers.push(worker);
199    }
200    workers
201}
202
203impl<
204        'a,
205        Ctx,
206        Req,
207        Resp: Sss,
208        MF: SCFunMF<Ctx, Req, Resp>,
209        RF: 'static + Fn(&Ctx, Resp, Resp) -> Resp,
210    > FullManager<'a, Ctx, Req, Resp, MF, RF>
211{
212    pub fn run(self, chunks: &[Req]) -> Result<Resp, ThreadError> {
213        let mut workers = make_workers(&self.reduce_manager.map_manager, chunks);
214        let mut result = self.default_value;
215        let ctx = self.reduce_manager.map_manager.manager.context;
216        let mut failed = None;
217        workers.reverse();
218        for _ in 0..workers.len() {
219            let worker = workers.pop().unwrap();
220            let data = match worker.thread.join() {
221                Ok(val) => val.get(),
222                Err(err) => {
223                    failed = Some(err);
224                    break;
225                }
226            };
227            result = (self.reduce_manager.fun)(ctx, result, data)
228        }
229        if let Some(err) = failed {
230            for worker in workers {
231                let _ = worker.thread.join();
232            }
233            Err(err)
234        } else {
235            Ok(result)
236        }
237    }
238}
239
240impl<
241        'a,
242        Ctx,
243        Req,
244        ROk: Sss,
245        RErr: Sss,
246        MF: 'static + Fn(&Ctx, usize, &Req) -> Result<ROk, RErr>,
247        RF: 'static + Fn(&Ctx, ROk, ROk) -> Result<ROk, RErr>,
248    > FullRManager<'a, Ctx, Req, ROk, RErr, MF, RF>
249{
250    fn do_map_reduce(
251        self,
252        workers: &mut Vec<Worker<ResWrapper<Result<ROk, RErr>>>>,
253    ) -> Result<ROk, MapReduceError<RErr>> {
254        let ctx = self.reduce_manager.map_manager.manager.context;
255        let mut result = self.default_value;
256        for _ in 0..workers.len() {
257            let worker = workers.pop().unwrap();
258            let worker_res_wrapper = match worker.thread.join() {
259                Ok(wrapper) => wrapper,
260                Err(err) => return Err(MapReduceError::ThreadFailed(err)),
261            };
262            let worker_res = match worker_res_wrapper.get() {
263                Ok(val) => val,
264                Err(err) => return Err(MapReduceError::Custom(err)),
265            };
266            match (self.reduce_manager.fun)(ctx, result, worker_res) {
267                Ok(val) => result = val,
268                Err(err) => return Err(MapReduceError::Custom(err)),
269            }
270        }
271        Ok(result)
272    }
273
274    pub fn run(self, chunks: &[Req]) -> Result<ROk, MapReduceError<RErr>> {
275        let mut workers = make_workers(&self.reduce_manager.map_manager, chunks);
276        workers.reverse();
277        let result = self.do_map_reduce(&mut workers);
278        for rest_worker in workers {
279            let _ = rest_worker.thread.join();
280        }
281        result
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;
    use std::time::Duration;

    #[test]
    fn test_manager() {
        let res = manager()
            .map(|_, _, n: &usize| n.to_string())
            .reduce(|_, a: String, b: String| format!("{}{}", a, b))
            .run(&[1, 2, 3])
            .unwrap();
        assert_eq!(res, "123");
    }

    #[test]
    fn test_with_context() {
        let var: usize = 5;
        let res = manager()
            .context(&var)
            .map(|ctx: &usize, _, n: &usize| (n + ctx).to_string())
            .reduce(|_, a: String, b: String| format!("{}{}", a, b))
            .run(&[1, 2, 3])
            .unwrap();
        assert_eq!(res, "678");
    }

    #[test]
    fn test_complex_types() {
        let arg = Box::new("abc");
        let args: Vec<Box<dyn Fn(&str) -> String>> = vec![
            Box::new(|s: &str| format!("{}a", s)),
            Box::new(|s: &str| format!("{}b", s)),
            Box::new(|s: &str| format!("{}c", s)),
        ];
        let res = manager()
            .context(&*arg)
            .map(|ctx, _, f: &Box<dyn Fn(&str) -> String>| f(ctx))
            .reduce(|_, a: String, b: String| format!("({}{})", a, b))
            .run(&args)
            .unwrap();
        assert_eq!(res, "(((abca)abcb)abcc)")
    }

    #[test]
    fn test_slices() {
        let params: &[&[usize]] = &[&[1, 2, 3], &[4, 5]];
        let res = manager()
            .map(|_, _, s: &&[usize]| s.len())
            .reduce(|_, a, b| a + b)
            .run(params)
            .unwrap();
        assert_eq!(res, 5);
    }

    #[test]
    fn test_result_ok() {
        let res = manager()
            .map(|_, _, s: &&str| u8::from_str(s))
            .reduce_result(|_, a: u8, b: u8| Ok(a + b))
            .run(&["1", "2", "3"]);
        assert!(matches!(res, Ok(6)));
    }

    #[test]
    fn test_result_custom_default() {
        // `default` replaces `ROk::default()` as the initial accumulator.
        let res = manager()
            .map(|_, _, s: &&str| u8::from_str(s))
            .reduce_result(|_, a: u8, b: u8| Ok(a + b))
            .default(10)
            .run(&["1", "2", "3"]);
        assert!(matches!(res, Ok(16)));
    }

    #[test]
    fn test_result_err_map() {
        let res = manager()
            .map(|_, _, s: &&str| u8::from_str(s))
            .reduce_result(|_, a: u8, b: u8| Ok(a + b))
            .run(&["1", "arr", "3"]);
        let err = match res {
            Err(MapReduceError::Custom(err)) => err.to_string(),
            _ => panic!("expected the parse error of chunk 1"),
        };
        assert_eq!(err, "invalid digit found in string");
    }

    #[test]
    fn test_result_err_reduce() {
        let res = manager()
            .map(|_, _, s: &u8| Ok(*s))
            .reduce_result(|_, _: u8, _: u8| Err(()))
            .run(&[1, 2, 3]);
        assert!(matches!(res, Err(MapReduceError::Custom(()))));
    }

    #[test]
    fn test_sleeps_and_errors() {
        // Chunks finish after 100, 100, 200, 300 and 300 ms, so chunk 2's
        // error is found while chunks 3 and 4 are still running. This
        // exercises the early-return path.
        let res = manager()
            .map(|_, _, s: &u8| {
                thread::sleep(Duration::from_millis(u64::from(*s) * 100));
                if *s <= 1 {
                    Ok(*s)
                } else {
                    Err(())
                }
            })
            .reduce_result(|_, a, b| Ok(a + b))
            .run(&[1, 1, 2, 3, 3]);
        assert!(matches!(res, Err(MapReduceError::Custom(()))));
    }

    #[test]
    fn test_error_with_capturing_map_fn() {
        // The map closure owns heap data. Chunk 0 fails immediately while the
        // other workers are still sleeping; afterwards they read `suffix`.
        // The closure must stay alive until they have been joined.
        let suffix = String::from("!");
        let res = manager()
            .map(move |_, _, n: &u64| -> Result<String, ()> {
                if *n == 0 {
                    return Err(());
                }
                thread::sleep(Duration::from_millis(*n));
                Ok(format!("{}{}", n, suffix))
            })
            .reduce_result(|_, a: String, b: String| Ok(a + &b))
            .run(&[0, 50, 50]);
        assert!(matches!(res, Err(MapReduceError::Custom(()))));
    }

    #[test]
    #[should_panic(expected = "reduce failed")]
    fn test_reduce_panic_propagates() {
        // The reducer panics while two workers are still sleeping. The panic
        // must reach the caller, and those workers are joined while unwinding.
        let _ = manager()
            .map(|_, _, n: &u64| {
                thread::sleep(Duration::from_millis(*n));
                *n
            })
            .reduce(|_, _: u64, _: u64| -> u64 { panic!("reduce failed") })
            .run(&[0, 50, 50]);
    }

    #[test]
    fn test_thread_id() {
        let res: String = manager()
            .map(|_, thread, val: &char| format!("{}:{}", thread, val))
            .reduce(|_, a, b| format!("[{}{}]", a, b))
            .run(&['a', 'b', 'c'])
            .unwrap();
        assert_eq!(res, "[[[0:a]1:b]2:c]")
    }

    struct StructWrapper {
        data: usize,
    }

    fn func(context: &str, data: &[StructWrapper]) -> usize {
        manager()
            .context(&context)
            .map(|_, _, val: &&[StructWrapper]| val.iter().map(|s| s.data).sum::<usize>())
            .reduce(|_, a, b| a + b)
            .run(&[&data[0..2], &data[2..4]])
            .expect("summing the chunks failed")
    }

    fn func_res(context: &str, data: &[StructWrapper]) -> Result<Vec<String>, String> {
        manager()
            .context(&context)
            .map(|ctx, _, val: &&[StructWrapper]| {
                let sum: usize = val.iter().map(|s| s.data).sum();
                Ok::<_, String>(vec![format!("{}: {}", ctx, sum)])
            })
            .reduce_result(|_, mut a, mut b| {
                a.append(&mut b);
                Ok(a)
            })
            .run(&[&data[0..2], &data[2..4]])
            .map_err(|_| "err".to_string())
    }

    #[test]
    fn test_map_per_chunks() {
        let wrapped = vec![
            StructWrapper { data: 1 },
            StructWrapper { data: 2 },
            StructWrapper { data: 3 },
            StructWrapper { data: 4 },
        ];
        assert_eq!(func("hello", &wrapped), 10);
        assert_eq!(wrapped[0].data, 1);
    }

    #[test]
    fn test_map_result_per_chunks() {
        let ctx = "hi";
        let wrapped = vec![
            StructWrapper { data: 1 },
            StructWrapper { data: 2 },
            StructWrapper { data: 3 },
            StructWrapper { data: 4 },
        ];
        assert_eq!(
            func_res(ctx, &wrapped),
            Ok(vec!["hi: 3".to_string(), "hi: 7".to_string()])
        );
        assert_eq!(wrapped[0].data, 1);
        assert_eq!(ctx, "hi");
    }
}