pumps/
concurrency_base.rs

/// Spawns a tokio task that pulls inputs from a channel, keeps at most a
/// configured number of futures in flight in a `FuturesContainer`, and hands
/// every value the container yields to caller-supplied code.
///
/// # Arguments
///
/// * `input_receiver = <expr>;` — receiver polled with `.recv()` inside the
///   spawned `async move` task (so whatever it names is moved into the task).
/// * `concurrency = <expr>;` — concurrency settings, evaluated exactly once
///   before the task is spawned. Its `preserve_order` field is passed to
///   `FuturesContainer::new`, and its `concurrency` field caps how many
///   futures may be in the container before new input stops being accepted.
/// * `on_input(input, in_progress) => { ... },` — runs for every received
///   input; `in_progress` names the `FuturesContainer` so the block can push
///   work onto it.
/// * `on_progress(output, output_sender) => { ... }` — runs for every value
///   yielded by the container; `output_sender` names the sender half of the
///   output channel. `break` inside either block stops the task.
///
/// # Returns
///
/// `(output_receiver, join_handle)`: the receiver half of the output channel
/// (capacity 1) and the `JoinHandle` of the spawned task.
///
/// The task's loop ends when every `select!` branch is disabled — e.g. the
/// input channel is closed and nothing is in flight — or when a block `break`s.
#[macro_export]
macro_rules! concurrency_base {
    (
        input_receiver = $input_receiver:expr;
        concurrency = $concurrency_var_name:expr;

        on_input($input_var_name:ident, $in_progress_var_name:ident) => $input:block,
        on_progress($output_var_name:ident, $output_sender_var_name:ident) => $output:block
    ) => {{
        // Evaluate the `concurrency` expression exactly once, in the caller's
        // context, rather than re-substituting it on every loop iteration
        // inside the spawned task.
        let concurrency_settings = $concurrency_var_name;

        // Capacity 1: `on_progress` can only get one value ahead of the
        // consumer, so a slow or absent consumer stalls the loop.
        let ($output_sender_var_name, output_receiver) = tokio::sync::mpsc::channel(1);

        let join_handle = tokio::spawn(async move {
            let mut $in_progress_var_name =
                $crate::concurrency::FuturesContainer::new(concurrency_settings.preserve_order);

            loop {
                let in_progress_len = $in_progress_var_name.len();
                tokio::select! {
                    // Poll branches in declaration order: accept new input
                    // first while below the cap, otherwise drain in-flight work.
                    biased;

                    Some($input_var_name) = $input_receiver.recv(),
                        if in_progress_len < concurrency_settings.concurrency => {
                        $input
                    },
                    // Only poll the container when it holds work; otherwise an
                    // empty container would disable this branch anyway.
                    Some($output_var_name) = $in_progress_var_name.next(),
                        if in_progress_len > 0 => {
                        $output
                    },
                    // Every branch is disabled (e.g. input closed and nothing
                    // in flight): the pump is finished.
                    else => break
                }
            }
        });

        (output_receiver, join_handle)
    }};
}

#[cfg(test)]
mod test {
    use tokio::sync::mpsc;

    use crate::{
        concurrency::Concurrency,
        test_utils::{FutureTimings, TestValue},
    };

    // Shared pump used by every test below. It spawns `concurrency_base!` over
    // `$input_receiver` and pushes `$map_fn(input)` for every input. Every
    // result is forwarded to the output channel, stopping as soon as a send
    // fails (i.e. the output receiver was dropped).
    // Expands to `(output_receiver, join_handle)`.
    macro_rules! tracked_pump {
        ($concurrency:ident, $input_receiver:ident, $map_fn:ident) => {
            concurrency_base! {
                input_receiver = $input_receiver;
                concurrency = $concurrency;

                on_input(input, in_progress) => {
                    let f = $map_fn(input);
                    in_progress.push_back(f);
                },
                on_progress(output, output_sender) => {
                    if output_sender.send(output).await.is_err() {
                        break;
                    }
                }
            }
        };
    }

    #[tokio::test]
    async fn serial() {
        let concurrency = Concurrency::serial();
        let (input_sender, mut input_receiver) = mpsc::channel(100);

        let timings = FutureTimings::new();
        let map_fn = timings.get_tracked_fn(|value| value.id);

        let (mut output_receiver, _join_handle) =
            tracked_pump!(concurrency, input_receiver, map_fn);

        // Values are sent with decreasing duration, but are executed one at a
        // time in input order.
        input_sender.send(TestValue::new(1, 30)).await.unwrap();
        input_sender.send(TestValue::new(2, 20)).await.unwrap();
        input_sender.send(TestValue::new(3, 10)).await.unwrap();

        assert_eq!(output_receiver.recv().await, Some(1));
        assert_eq!(output_receiver.recv().await, Some(2));
        assert_eq!(output_receiver.recv().await, Some(3));

        assert!(timings.run_after(3, 2).await);
        assert!(timings.run_after(2, 1).await);

        drop(input_sender);

        assert_eq!(output_receiver.recv().await, None);
    }

    #[tokio::test]
    async fn concurrency_2_unordered() {
        let concurrency = Concurrency::concurrent_unordered(2);
        let (input_sender, mut input_receiver) = mpsc::channel(100);

        let timings = FutureTimings::new();
        let map_fn = timings.get_tracked_fn(|value| value.id);

        let (mut output_receiver, _join_handle) =
            tracked_pump!(concurrency, input_receiver, map_fn);

        // (2) finishes first and is emitted first; (1) runs concurrently with
        // both (2) and (3), and (3) starts once (2) has freed a slot.
        input_sender.send(TestValue::new(1, 20)).await.unwrap();
        input_sender.send(TestValue::new(2, 10)).await.unwrap();
        input_sender.send(TestValue::new(3, 15)).await.unwrap();

        assert_eq!(output_receiver.recv().await, Some(2));
        assert_eq!(output_receiver.recv().await, Some(1));
        assert_eq!(output_receiver.recv().await, Some(3));

        assert!(timings.run_in_parallel(1, 2).await);
        assert!(timings.run_in_parallel(1, 3).await);
        assert!(timings.run_after(3, 2).await);

        drop(input_sender);

        assert_eq!(output_receiver.recv().await, None);
    }

    #[tokio::test]
    async fn concurrency_2_ordered() {
        let concurrency = Concurrency::concurrent_ordered(2);
        let (input_sender, mut input_receiver) = mpsc::channel(100);

        let timings = FutureTimings::new();
        let map_fn = timings.get_tracked_fn(|value| value.id);

        let (mut output_receiver, _join_handle) =
            tracked_pump!(concurrency, input_receiver, map_fn);

        // (2) finishes before (1) but is emitted after it to preserve input
        // order; (1) and (2) run concurrently, and (3) runs after (2).
        input_sender.send(TestValue::new(1, 20)).await.unwrap();
        input_sender.send(TestValue::new(2, 10)).await.unwrap();
        input_sender.send(TestValue::new(3, 15)).await.unwrap();

        assert_eq!(output_receiver.recv().await, Some(1));
        assert_eq!(output_receiver.recv().await, Some(2));
        assert_eq!(output_receiver.recv().await, Some(3));

        assert!(timings.run_in_parallel(1, 2).await);
        assert!(timings.run_after(3, 2).await);

        drop(input_sender);

        assert_eq!(output_receiver.recv().await, None);
    }

    #[tokio::test]
    async fn concurrency_2_ordered_stops_without_consumer() {
        let concurrency = Concurrency::concurrent_ordered(2);
        let (input_sender, mut input_receiver) = mpsc::channel(100);

        let timings = FutureTimings::new();
        let map_fn = timings.get_tracked_fn(|value| value.id);

        // The receiver must stay alive (a named `_output_receiver`, not `_`):
        // dropping it would make `send` fail and end the pump early instead of
        // exercising back-pressure.
        let (_output_receiver, _join_handle) =
            tracked_pump!(concurrency, input_receiver, map_fn);

        input_sender.send(TestValue::new(1, 10)).await.unwrap();
        input_sender.send(TestValue::new(2, 10)).await.unwrap();
        input_sender.send(TestValue::new(3, 10)).await.unwrap();
        input_sender.send(TestValue::new(4, 10)).await.unwrap();

        tokio::time::sleep(std::time::Duration::from_millis(500)).await;

        // Nobody reads the output channel. Its single slot fills up and the
        // pump blocks on `send`, so (3) and (4) never complete.
        assert!(timings.is_completed(1).await);
        assert!(timings.is_completed(2).await);
        assert!(!timings.is_completed(3).await);
        assert!(!timings.is_completed(4).await);
    }
}