// rig/pipeline/parallel.rs

1use futures::{join, try_join};
2
3use super::{Op, TryOp};
4
/// An op that runs two ops concurrently on the same input.
///
/// Its output is the pair `(op1 output, op2 output)`. Wider fan-outs are built by
/// nesting, e.g. `Parallel::new(a, Parallel::new(b, c))`, which is the shape the
/// `parallel!` / `try_parallel!` macros generate.
#[derive(Debug, Clone)]
pub struct Parallel<Op1, Op2> {
    op1: Op1,
    op2: Op2,
}

impl<Op1, Op2> Parallel<Op1, Op2> {
    /// Combines `op1` and `op2` into a single op that runs both on each input.
    pub fn new(op1: Op1, op2: Op2) -> Self {
        Self { op1, op2 }
    }
}
15
16impl<Op1, Op2> Op for Parallel<Op1, Op2>
17where
18    Op1: Op,
19    Op1::Input: Clone,
20    Op2: Op<Input = Op1::Input>,
21{
22    type Input = Op1::Input;
23    type Output = (Op1::Output, Op2::Output);
24
25    #[inline]
26    async fn call(&self, input: Self::Input) -> Self::Output {
27        join!(self.op1.call(input.clone()), self.op2.call(input))
28    }
29}
30
31impl<Op1, Op2> TryOp for Parallel<Op1, Op2>
32where
33    Op1: TryOp,
34    Op1::Input: Clone,
35    Op2: TryOp<Input = Op1::Input, Error = Op1::Error>,
36{
37    type Input = Op1::Input;
38    type Output = (Op1::Output, Op2::Output);
39    type Error = Op1::Error;
40
41    #[inline]
42    async fn try_call(&self, input: Self::Input) -> Result<Self::Output, Self::Error> {
43        try_join!(self.op1.try_call(input.clone()), self.op2.try_call(input))
44    }
45}
46
// Token muncher behind `parallel!`, modeled on the std `join!` macro:
// see https://doc.rust-lang.org/src/core/future/join.rs.html#48
//
// Each op is tagged with its tuple position, written as one `_` per op that
// precedes it; the last op's tag additionally ends in `+`. Once every op is
// tagged, the ops are combined with `parallel_op!` and the nested output is
// unpacked into a flat tuple with `tuple_pattern!`.
#[macro_export]
macro_rules! parallel_internal {
    // Nothing left to munch: combine the ops and flatten the nested output.
    (
        current_position: [
            $($total:tt)*
        ]
        values_and_positions: [
            $(
                $op:tt ( $($slot:tt)* )
            )*
        ]
        munching: []
    ) => ({
        use $crate::pipeline::op::Op;

        let combined = $crate::parallel_op!($($op),*);
        combined.map(|output| {
            ($(
                {
                    let $crate::tuple_pattern!(x $($slot)*) = output;
                    x
                }
            ),+)
        })
    });

    // Two or more ops left: tag the head with the current position, grow the
    // position by one `_`, and keep munching the tail.
    (
        // One `_` per op already tagged: "_ _ _".
        current_position: [
            $($seen:tt)*
        ]
        // Ops tagged so far, each followed by its position: `a () b ( _ ) …`.
        values_and_positions: [
            $($tagged:tt)*
        ]
        // Ops still waiting to be tagged.
        munching: [
            $head:tt
            $($tail:tt)+
        ]
    ) => (
        $crate::parallel_internal! {
            current_position: [
                $($seen)*
                _
            ]
            values_and_positions: [
                $($tagged)*
                $head ( $($seen)* )
            ]
            munching: [
                $($tail)*
            ]
        }
    );

    // Exactly one op left: tag it and append `+`, which makes `tuple_pattern!`
    // bind the innermost value directly instead of the head of a nested pair.
    (
        current_position: [
            $($seen:tt)*
        ]
        values_and_positions: [
            $($tagged:tt)*
        ]
        munching: [
            $last:tt
        ]
    ) => (
        $crate::parallel_internal! {
            current_position: [
                $($seen)*
                _
            ]
            values_and_positions: [
                $($tagged)*
                $last ( $($seen)* + )
            ]
            munching: []
        }
    );
}
134
// Combines a comma-separated list of ops (one token tree each) into a
// right-nested chain of `Parallel` ops:
// `parallel_op!(a, b, c)` => `Parallel::new(a, Parallel::new(b, c))`,
// whose output has the shape `(a_out, (b_out, c_out))`.
#[macro_export]
macro_rules! parallel_op {
    // A lone op needs no wrapping. This arm also terminates the recursion
    // below, so `parallel!(op)` / `try_parallel!(op)` no longer expand to an
    // unmatched `parallel_op!()`.
    ($op:tt) => {
        $op
    };
    // Two or more ops: pair the head with the chain built from the rest.
    ($op1:tt $(, $ops:tt)+) => {
        $crate::pipeline::parallel::Parallel::new(
            $op1,
            $crate::parallel_op!($($ops),+)
        )
    };
}
147
// Expands to a pattern that binds `$name` to one slot of a right-nested tuple
// such as `(a, (b, (c, d)))`. Each `_` descends one level of nesting, and a
// trailing `+` marks the innermost slot, which is bound as-is.
// Examples: `tuple_pattern!(x)` => `(x, ..)`,
// `tuple_pattern!(x _)` => `(_, (x, ..))`,
// `tuple_pattern!(x _ _ +)` => `(_, (_, x))`.
#[macro_export]
macro_rules! tuple_pattern {
    // Skip the head of the current pair and keep descending.
    ($name:ident _ $($rest:tt)*) => {
        (_, $crate::tuple_pattern!($name $($rest)*))
    };
    // Innermost slot: bind the value itself.
    ($name:ident +) => {
        $name
    };
    // Any other slot: bind the head of the current pair.
    ($name:ident) => {
        ($name, ..)
    };
}
160
/// Runs several ops concurrently on the same input and returns their outputs as
/// a flat tuple, in argument order.
///
/// `parallel!(a, b, c)` combines the ops into `Parallel::new(a, Parallel::new(b, c))`
/// and maps the nested output `(a_out, (b_out, c_out))` to `(a_out, b_out, c_out)`.
/// All ops must take the same (cloneable) input type. A trailing comma is accepted.
#[macro_export]
macro_rules! parallel {
    ($($op:expr),+ $(,)?) => {
        $crate::parallel_internal! {
            current_position: []
            values_and_positions: []
            munching: [$($op)+]
        }
    };
}
173
// Token muncher behind `try_parallel!`, modeled on the std `join!` macro:
// see https://doc.rust-lang.org/src/core/future/join.rs.html#48
//
// Works like `parallel_internal!`: each op is tagged with its tuple position
// (one `_` per preceding op, plus a trailing `+` on the last op). The final arm
// combines the ops with `parallel_op!` and flattens the nested success value
// with `TryOp::map_ok` and `tuple_pattern!`.
#[macro_export]
macro_rules! try_parallel_internal {
    // Nothing left to munch: combine the ops and flatten the nested `Ok` output.
    (
        current_position: [
            $($total:tt)*
        ]
        values_and_positions: [
            $(
                $op:tt ( $($slot:tt)* )
            )*
        ]
        munching: []
    ) => ({
        use $crate::pipeline::try_op::TryOp;

        let combined = $crate::parallel_op!($($op),*);
        combined.map_ok(|output| {
            ($(
                {
                    let $crate::tuple_pattern!(x $($slot)*) = output;
                    x
                }
            ),+)
        })
    });

    // Two or more ops left: tag the head with the current position, grow the
    // position by one `_`, and keep munching the tail.
    (
        // One `_` per op already tagged: "_ _ _".
        current_position: [
            $($seen:tt)*
        ]
        // Ops tagged so far, each followed by its position: `a () b ( _ ) …`.
        values_and_positions: [
            $($tagged:tt)*
        ]
        // Ops still waiting to be tagged.
        munching: [
            $head:tt
            $($tail:tt)+
        ]
    ) => (
        $crate::try_parallel_internal! {
            current_position: [
                $($seen)*
                _
            ]
            values_and_positions: [
                $($tagged)*
                $head ( $($seen)* )
            ]
            munching: [
                $($tail)*
            ]
        }
    );

    // Exactly one op left: tag it and append `+`, which makes `tuple_pattern!`
    // bind the innermost value directly instead of the head of a nested pair.
    (
        current_position: [
            $($seen:tt)*
        ]
        values_and_positions: [
            $($tagged:tt)*
        ]
        munching: [
            $last:tt
        ]
    ) => (
        $crate::try_parallel_internal! {
            current_position: [
                $($seen)*
                _
            ]
            values_and_positions: [
                $($tagged)*
                $last ( $($seen)* + )
            ]
            munching: []
        }
    );
}
260
/// Fallible counterpart of `parallel!`: runs several `TryOp`s concurrently on the
/// same input and, when all succeed, yields `Ok` with their outputs as a flat
/// tuple in argument order.
///
/// All ops must take the same (cloneable) input type and share one error type;
/// an error from any op becomes the overall result. A trailing comma is accepted.
#[macro_export]
macro_rules! try_parallel {
    ($($op:expr),+ $(,)?) => {
        $crate::try_parallel_internal! {
            current_position: []
            values_and_positions: []
            munching: [$($op)+]
        }
    };
}
273
274pub use parallel;
275pub use parallel_internal;
276
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pipeline::{
        self,
        op::{Sequential, map},
        passthrough, then,
    };

    #[tokio::test]
    async fn test_parallel() {
        let op1 = map(|x: i32| x + 1);
        let op2 = map(|x: i32| x * 3);
        let pipeline = Parallel::new(op1, op2);

        let result = pipeline.call(1).await;
        assert_eq!(result, (2, 3));
    }

    #[tokio::test]
    async fn test_parallel_nested() {
        let op1 = map(|x: i32| x + 1);
        let op2 = map(|x: i32| x * 3);
        let op3 = map(|x: i32| format!("{x} is the number!"));
        let op4 = map(|x: i32| x - 1);

        let pipeline = Parallel::new(Parallel::new(Parallel::new(op1, op2), op3), op4);

        let result = pipeline.call(1).await;
        assert_eq!(result, (((2, 3), "1 is the number!".to_string()), 0));
    }

    #[tokio::test]
    async fn test_parallel_nested_rev() {
        let op1 = map(|x: i32| x + 1);
        let op2 = map(|x: i32| x * 3);
        let op3 = map(|x: i32| format!("{x} is the number!"));
        let op4 = map(|x: i32| x == 1);

        let pipeline = Parallel::new(op1, Parallel::new(op2, Parallel::new(op3, op4)));

        let result = pipeline.call(1).await;
        assert_eq!(result, (2, (3, ("1 is the number!".to_string(), true))));
    }

    #[tokio::test]
    async fn test_sequential_and_parallel() {
        let op1 = map(|x: i32| x + 1);
        let op2 = map(|x: i32| x * 2);
        let op3 = map(|x: i32| x * 3);
        let op4 = map(|(x, y): (i32, i32)| x + y);

        let pipeline = Sequential::new(Sequential::new(op1, Parallel::new(op2, op3)), op4);

        let result = pipeline.call(1).await;
        assert_eq!(result, 10);
    }

    #[tokio::test]
    async fn test_parallel_chain_compile_check() {
        let _ = pipeline::new().chain(
            Parallel::new(
                map(|x: i32| x + 1),
                Parallel::new(
                    map(|x: i32| x * 3),
                    Parallel::new(
                        map(|x: i32| format!("{x} is the number!")),
                        map(|x: i32| x == 1),
                    ),
                ),
            )
            .map(|(r1, (r2, (r3, r4)))| (r1, r2, r3, r4)),
        );
    }

    #[tokio::test]
    async fn test_parallel_pass_through() {
        let pipeline = then(|x| {
            let op = Parallel::new(Parallel::new(passthrough(), passthrough()), passthrough());

            async move {
                let ((r1, r2), r3) = op.call(x).await;
                (r1, r2, r3)
            }
        });

        let result = pipeline.call(1).await;
        assert_eq!(result, (1, 1, 1));
    }

    #[tokio::test]
    async fn test_parallel_macro() {
        let op2 = map(|x: i32| x * 2);

        let pipeline = parallel!(
            passthrough(),
            op2,
            map(|x: i32| format!("{x} is the number!")),
            map(|x: i32| x == 1)
        );

        let result = pipeline.call(1).await;
        assert_eq!(result, (1, 2, "1 is the number!".to_string(), true));
    }

    #[tokio::test]
    async fn test_try_parallel_chain_compile_check() {
        let chain = pipeline::new().chain(
            Parallel::new(
                map(|x: i32| Ok::<_, String>(x + 1)),
                Parallel::new(
                    map(|x: i32| Ok::<_, String>(x * 3)),
                    Parallel::new(
                        map(|x: i32| Err::<i32, _>(format!("{x} is the number!"))),
                        map(|x: i32| Ok::<_, String>(x == 1)),
                    ),
                ),
            )
            .map_ok(|(r1, (r2, (r3, r4)))| (r1, r2, r3, r4)),
        );

        let response = chain.call(1).await;
        assert_eq!(response, Err("1 is the number!".to_string()));
    }

    #[tokio::test]
    async fn test_try_parallel_macro_ok() {
        let op2 = map(|x: i32| Ok::<_, String>(x * 2));

        let pipeline = try_parallel!(
            map(|x: i32| Ok::<_, String>(x)),
            op2,
            map(|x: i32| Ok::<_, String>(format!("{x} is the number!"))),
            map(|x: i32| Ok::<_, String>(x == 1))
        );

        let result = pipeline.try_call(1).await;
        assert_eq!(result, Ok((1, 2, "1 is the number!".to_string(), true)));
    }

    #[tokio::test]
    async fn test_try_parallel_macro_err() {
        let op2 = map(|x: i32| Ok::<_, String>(x * 2));

        let pipeline = try_parallel!(
            map(|x: i32| Ok::<_, String>(x)),
            op2,
            map(|x: i32| Err::<i32, _>(format!("{x} is the number!"))),
            map(|x: i32| Ok::<_, String>(x == 1))
        );

        let result = pipeline.try_call(1).await;
        assert_eq!(result, Err("1 is the number!".to_string()));
    }

    // Two ops give the shortest macro expansion: the second op's slot is bound
    // directly (`(_, x)`) rather than through a nested pair. The trailing comma
    // exercises the `$(,)?` in the macro matcher.
    #[tokio::test]
    async fn test_parallel_macro_two_ops() {
        let pipeline = parallel!(map(|x: i32| x + 1), map(|x: i32| x * 10),);

        let result = pipeline.call(2).await;
        assert_eq!(result, (3, 20));
    }

    #[tokio::test]
    async fn test_try_parallel_macro_two_ops() {
        let pipeline = try_parallel!(
            map(|x: i32| Ok::<_, String>(x + 1)),
            map(|x: i32| Ok::<_, String>(x * 10)),
        );

        let result = pipeline.try_call(2).await;
        assert_eq!(result, Ok((3, 20)));
    }
}