// pumps/concurrency_base.rs
/// Spawns a Tokio task that runs a bounded-concurrency input → futures → output pump.
///
/// The expansion is an expression evaluating to `(output_receiver, join_handle)`:
/// - `output_receiver` is the receiving half of a `tokio::sync::mpsc` channel created with
///   capacity 1; its sending half is bound to `$output_sender_var_name` inside `on_progress`.
/// - `join_handle` is the `JoinHandle` returned by `tokio::spawn` for the pump task.
///
/// Inside the task, `$in_progress_var_name` is a `$crate::concurrency::FuturesContainer`
/// built from `concurrency.preserve_order`. Each loop iteration runs a `tokio::select!`:
/// - `on_input` runs for every item received from `input_receiver`. That branch is only
///   enabled while fewer than `concurrency.concurrency` entries are in the container.
/// - `on_progress` runs for every item yielded by the container's `next()`. That branch is
///   only enabled while the container is non-empty.
/// - `biased;` polls the branches in the order written, so a ready input is taken before a
///   ready output.
/// - The `else` arm breaks the loop once every branch is disabled. In practice that means
///   the input channel is closed and nothing is in flight. It also breaks immediately if
///   `concurrency.concurrency` is 0.
/// - Either block may also `break` to stop the pump.
///
/// Both blocks execute as `select!` handlers. While `on_progress` awaits (for example, a send
/// into the full capacity-1 output channel), no new input is read and the container is not
/// polled.
///
/// Note: `$input_receiver` and `$concurrency_var_name` are expanded inside the spawned
/// `async move` block and re-evaluated on every loop iteration, so pass plain variables.
#[macro_export]
macro_rules! concurrency_base {
    (
        input_receiver = $input_receiver:expr;
        concurrency = $concurrency_var_name:expr;

        on_input($input_var_name:ident, $in_progress_var_name:ident) => $input:block,
        on_progress($output_var_name:ident, $output_sender_var_name:ident) => $output:block
    ) => {{
        let ($output_sender_var_name, results_rx) = tokio::sync::mpsc::channel(1);

        let pump_task = tokio::spawn(async move {
            let mut $in_progress_var_name =
                $crate::concurrency::FuturesContainer::new($concurrency_var_name.preserve_order);

            loop {
                // Read the in-flight count once so both branch guards see the same value.
                let in_flight = $in_progress_var_name.len();

                tokio::select! {
                    biased;

                    Some($input_var_name) = $input_receiver.recv(),
                        if in_flight < $concurrency_var_name.concurrency =>
                    {
                        $input
                    },

                    Some($output_var_name) = $in_progress_var_name.next(),
                        if in_flight > 0 =>
                    {
                        $output
                    },

                    else => break
                }
            }
        });

        (results_rx, pump_task)
    }};
}

#[cfg(test)]
mod test {
    use tokio::sync::mpsc;

    use crate::{
        concurrency::Concurrency,
        test_utils::{FutureTimings, TestValue},
    };

    /// Expands to a `concurrency_base!` pump that pushes `$map_fn(input)` for every received
    /// input and forwards each yielded result to the output channel. It breaks out of the
    /// pump loop as soon as a send fails.
    macro_rules! spawn_forwarding_pump {
        ($concurrency:expr, $input_receiver:ident, $map_fn:ident) => {
            concurrency_base! {
                input_receiver = $input_receiver;
                concurrency = $concurrency;

                on_input(input, in_progress) => {
                    let tracked = $map_fn(input);
                    in_progress.push_back(tracked);
                },
                on_progress(output, output_sender) => {
                    if output_sender.send(output).await.is_err() {
                        break;
                    }
                }
            }
        };
    }

    #[tokio::test]
    async fn serial() {
        let concurrency = Concurrency::serial();
        let (input_sender, mut input_receiver) = mpsc::channel(100);
        let timings = FutureTimings::new();
        let map_fn = timings.get_tracked_fn(|value| value.id);

        let (mut output_receiver, _join_handle) =
            spawn_forwarding_pump!(concurrency, input_receiver, map_fn);

        for (id, delay) in [(1, 30), (2, 20), (3, 10)] {
            input_sender.send(TestValue::new(id, delay)).await.unwrap();
        }

        // Serial execution: results arrive in submission order.
        for expected in [1, 2, 3] {
            assert_eq!(output_receiver.recv().await, Some(expected));
        }

        // Each item only starts after the previous one.
        assert!(timings.run_after(3, 2).await);
        assert!(timings.run_after(2, 1).await);

        // Closing the input lets the pump finish and drop its output sender.
        drop(input_sender);
        assert_eq!(output_receiver.recv().await, None);
    }

    #[tokio::test]
    async fn concurrency_2_unordered() {
        let concurrency = Concurrency::concurrent_unordered(2);
        let (input_sender, mut input_receiver) = mpsc::channel(100);
        let timings = FutureTimings::new();
        let map_fn = timings.get_tracked_fn(|value| value.id);

        let (mut output_receiver, _join_handle) =
            spawn_forwarding_pump!(concurrency, input_receiver, map_fn);

        for (id, delay) in [(1, 20), (2, 10), (3, 15)] {
            input_sender.send(TestValue::new(id, delay)).await.unwrap();
        }

        // Unordered: results are emitted as soon as they complete.
        for expected in [2, 1, 3] {
            assert_eq!(output_receiver.recv().await, Some(expected));
        }

        // Two slots: 3 only starts once 2 has freed one, while 1 is still running.
        assert!(timings.run_in_parallel(1, 2).await);
        assert!(timings.run_in_parallel(1, 3).await);
        assert!(timings.run_after(3, 2).await);

        drop(input_sender);
        assert_eq!(output_receiver.recv().await, None);
    }

    #[tokio::test]
    async fn concurrency_2_ordered() {
        let concurrency = Concurrency::concurrent_ordered(2);
        let (input_sender, mut input_receiver) = mpsc::channel(100);
        let timings = FutureTimings::new();
        let map_fn = timings.get_tracked_fn(|value| value.id);

        let (mut output_receiver, _join_handle) =
            spawn_forwarding_pump!(concurrency, input_receiver, map_fn);

        for (id, delay) in [(1, 20), (2, 10), (3, 15)] {
            input_sender.send(TestValue::new(id, delay)).await.unwrap();
        }

        // Ordered: results keep submission order even though 2 finishes before 1.
        for expected in [1, 2, 3] {
            assert_eq!(output_receiver.recv().await, Some(expected));
        }

        assert!(timings.run_in_parallel(1, 2).await);
        assert!(timings.run_after(3, 2).await);

        drop(input_sender);
        assert_eq!(output_receiver.recv().await, None);
    }

    #[tokio::test]
    async fn concurrency_2_ordered_stops_without_consumer() {
        let concurrency = Concurrency::concurrent_ordered(2);
        let (input_sender, mut input_receiver) = mpsc::channel(100);
        let timings = FutureTimings::new();
        let map_fn = timings.get_tracked_fn(|value| value.id);

        // Bound to a name (not `_`) so the output channel stays open but is never read.
        // The pump therefore blocks on its capacity-1 output channel instead of exiting.
        let (_output_receiver, _join_handle) =
            spawn_forwarding_pump!(concurrency, input_receiver, map_fn);

        for id in 1..=4 {
            input_sender.send(TestValue::new(id, 10)).await.unwrap();
        }

        tokio::time::sleep(std::time::Duration::from_millis(500)).await;

        for id in [1, 2] {
            assert!(timings.is_completed(id).await);
        }
        for id in [3, 4] {
            assert!(!timings.is_completed(id).await);
        }
    }
}