Skip to main content

sabi/tokio/
async_group.rs

1// Copyright (C) 2024-2026 Takayuki Sato. All Rights Reserved.
2// This program is free software under MIT License.
3// See the file LICENSE in this distribution for more details.
4
5use super::AsyncGroup;
6
7use futures::future;
8use std::future::Future;
9use std::sync::Arc;
10
11impl AsyncGroup {
12    #[allow(clippy::new_without_default)]
13    pub fn new() -> Self {
14        Self {
15            tasks: Vec::new(),
16            names: Vec::new(),
17            _name: "".into(),
18        }
19    }
20
21    /// Adds a future to the AsyncGroup to be executed concurrently.
22    ///
23    /// The provided future will be polled along with others in this group.
24    ///
25    /// # Arguments
26    ///
27    /// * `future` - The future to add. It must implement `Future<Output = errs::Result<()>>`,
28    ///              `Send`, and have a `'static` lifetime.
29    #[allow(clippy::doc_overindented_list_items)]
30    pub fn add<Fut>(&mut self, future: Fut)
31    where
32        Fut: Future<Output = errs::Result<()>> + Send + 'static,
33    {
34        self.tasks.push(Box::pin(future));
35        self.names.push(self._name.clone());
36    }
37
38    pub(crate) async fn join_and_collect_errors_async(
39        self,
40        errors: &mut Vec<(Arc<str>, errs::Err)>,
41    ) {
42        if self.tasks.is_empty() {
43            return;
44        }
45
46        let result_all = future::join_all(self.tasks).await;
47        for (i, result) in result_all.into_iter().enumerate() {
48            if let Err(err) = result {
49                errors.push((self.names[i].clone(), err));
50            }
51        }
52    }
53
54    pub(crate) async fn join_and_ignore_errors_async(self) {
55        let _ = future::join_all(self.tasks).await;
56    }
57
58    #[inline]
59    pub async fn join_async(self) -> Vec<(Arc<str>, errs::Err)> {
60        let mut vec = Vec::new();
61        self.join_and_collect_errors_async(&mut vec).await;
62        vec
63    }
64}
65
#[cfg(test)]
mod tests_of_async_group {
    //! Tests for `AsyncGroup`.
    //!
    //! Fixtures:
    //! * `Reasons` is the error reason returned by failing futures.
    //! * `MyStruct` holds a shared string and a `fail` flag.
    //!   - `process` registers one future on an `AsyncGroup`.
    //!   - `process_multiple` registers two identical futures.
    //!   - Each future sleeps 100 ms, then locks the string and either upper-cases it or fails
    //!     with `Reasons::BadString` carrying the string's current value.
    use super::*;
    use std::sync::Arc;
    use tokio::sync::Mutex;
    use tokio::time;

    // Line-number anchor for the error-location assertions. The tests expect `errs::Err`'s
    // debug output to report the line of the `errs::Err::new` call. Those calls sit at
    // BASE_LINE + 30 (`process`) and BASE_LINE + 49 / BASE_LINE + 66 (`process_multiple`).
    // Do NOT add or remove lines between this constant and those `return Err(...)` sites.
    const BASE_LINE: u32 = line!();

    #[derive(Debug, PartialEq)]
    enum Reasons {
        BadString(String),
    }

    struct MyStruct {
        string: Arc<Mutex<String>>,
        fail: bool,
    }

    impl MyStruct {
        fn new(s: String, fail: bool) -> Self {
            Self {
                string: Arc::new(Mutex::new(s)),
                fail,
            }
        }

        fn process(&self, ag: &mut AsyncGroup) {
            let s_clone = self.string.clone();
            let fail = self.fail;
            ag.add(async move {
                // Sleep outside the lock. The `return Err` below must stay at BASE_LINE + 30.
                time::sleep(time::Duration::from_millis(100)).await;

                {
                    let mut s = s_clone.lock().await;
                    if fail {
                        return Err(errs::Err::new(Reasons::BadString((*s).to_string())));
                    }
                    *s = s.to_uppercase();
                }

                Ok(())
            });
        }

        fn process_multiple(&self, ag: &mut AsyncGroup) {
            let s_clone = self.string.clone();
            let fail = self.fail;
            ag.add(async move {
                // Sleep outside the lock. The `return Err` below must stay at BASE_LINE + 49.
                time::sleep(time::Duration::from_millis(100)).await;

                {
                    let mut s = s_clone.lock().await;
                    if fail {
                        return Err(errs::Err::new(Reasons::BadString((*s).to_string())));
                    }
                    *s = s.to_uppercase();
                }

                Ok(())
            });

            let s_clone = self.string.clone();
            let fail = self.fail;
            ag.add(async move {
                // Sleep outside the lock. The `return Err` below must stay at BASE_LINE + 66.
                time::sleep(time::Duration::from_millis(100)).await;

                {
                    let mut s = s_clone.lock().await;
                    if fail {
                        return Err(errs::Err::new(Reasons::BadString((*s).to_string())));
                    }
                    *s = s.to_uppercase();
                }

                Ok(())
            });
        }
    }

    // Tests for `join_and_collect_errors_async`: error count, names, order and debug output.
    mod tests_of_join_and_collect_errors {
        use super::*;

        #[tokio::test]
        async fn zero() {
            let ag = AsyncGroup::new();

            let mut err_vec = Vec::new();
            ag.join_and_collect_errors_async(&mut err_vec).await;

            assert!(err_vec.is_empty());
        }

        #[tokio::test]
        async fn single_ok() {
            let mut ag = AsyncGroup::new();

            let struct_a = MyStruct::new("a".to_string(), false);
            assert_eq!(*struct_a.string.lock().await, "a");

            ag._name = "foo".into();
            struct_a.process(&mut ag);

            let mut errors = Vec::new();
            ag.join_and_collect_errors_async(&mut errors).await;

            assert!(errors.is_empty());
            assert_eq!(*struct_a.string.lock().await, "A");
        }

        #[tokio::test]
        async fn single_fail() {
            let mut ag = AsyncGroup::new();

            let struct_a = MyStruct::new("a".to_string(), true);
            assert_eq!(*struct_a.string.lock().await, "a");

            ag._name = "foo".into();
            struct_a.process(&mut ag);

            let mut errors = Vec::new();
            ag.join_and_collect_errors_async(&mut errors).await;

            assert_eq!(errors.len(), 1);
            assert_eq!(*struct_a.string.lock().await, "a");

            assert_eq!(errors[0].0, "foo".into());
            #[cfg(unix)]
            assert_eq!(
                format!("{:?}", errors[0].1),
                "errs::Err { reason = sabi::tokio::async_group::tests_of_async_group::Reasons BadString(\"a\"), file = src/tokio/async_group.rs, line = ".to_string() + &(BASE_LINE + 30).to_string() + " }"
            );
            #[cfg(windows)]
            assert_eq!(
                format!("{:?}", errors[0].1),
                "errs::Err { reason = sabi::tokio::async_group::tests_of_async_group::Reasons BadString(\"a\"), file = src\\tokio\\async_group.rs, line = ".to_string() + &(BASE_LINE + 30).to_string() + " }"
            );
        }

        #[tokio::test]
        async fn multiple_ok() {
            let mut ag = AsyncGroup::new();

            let struct_a = MyStruct::new("a".to_string(), false);
            assert_eq!(*struct_a.string.lock().await, "a".to_string());

            let struct_b = MyStruct::new("b".to_string(), false);
            assert_eq!(*struct_b.string.lock().await, "b".to_string());

            let struct_c = MyStruct::new("c".to_string(), false);
            assert_eq!(*struct_c.string.lock().await, "c".to_string());

            ag._name = "foo".into();
            struct_a.process(&mut ag);

            ag._name = "bar".into();
            struct_b.process(&mut ag);

            ag._name = "baz".into();
            struct_c.process(&mut ag);

            let mut err_vec = Vec::new();
            ag.join_and_collect_errors_async(&mut err_vec).await;

            assert_eq!(err_vec.len(), 0);

            assert_eq!(*struct_a.string.lock().await, "A");
            assert_eq!(*struct_b.string.lock().await, "B");
            assert_eq!(*struct_c.string.lock().await, "C");
        }

        #[tokio::test]
        async fn multiple_processes_and_single_fail() {
            let mut ag = AsyncGroup::new();

            let struct_a = MyStruct::new("a".to_string(), false);
            assert_eq!(*struct_a.string.lock().await, "a");

            let struct_b = MyStruct::new("b".to_string(), true);
            assert_eq!(*struct_b.string.lock().await, "b");

            let struct_c = MyStruct::new("c".to_string(), false);
            assert_eq!(*struct_c.string.lock().await, "c");

            ag._name = "foo".into();
            struct_a.process(&mut ag);

            ag._name = "bar".into();
            struct_b.process(&mut ag);

            ag._name = "baz".into();
            struct_c.process(&mut ag);

            let mut err_vec = Vec::new();
            ag.join_and_collect_errors_async(&mut err_vec).await;

            assert_eq!(err_vec.len(), 1);
            assert_eq!(err_vec[0].0, "bar".into());
            #[cfg(unix)]
            assert_eq!(
                format!("{:?}", err_vec[0].1),
                "errs::Err { reason = sabi::tokio::async_group::tests_of_async_group::Reasons BadString(\"b\"), file = src/tokio/async_group.rs, line = ".to_string() + &(BASE_LINE + 30).to_string() + " }",
            );
            #[cfg(windows)]
            assert_eq!(
                format!("{:?}", err_vec[0].1),
                "errs::Err { reason = sabi::tokio::async_group::tests_of_async_group::Reasons BadString(\"b\"), file = src\\tokio\\async_group.rs, line = ".to_string() + &(BASE_LINE + 30).to_string() + " }",
            );

            assert_eq!(*struct_a.string.lock().await, "A");
            assert_eq!(*struct_b.string.lock().await, "b");
            assert_eq!(*struct_c.string.lock().await, "C");
        }

        #[tokio::test]
        async fn multiple_fail() {
            let mut ag = AsyncGroup::new();

            let struct_a = MyStruct::new("a".to_string(), true);
            assert_eq!(*struct_a.string.lock().await, "a");

            let struct_b = MyStruct::new("b".to_string(), true);
            assert_eq!(*struct_b.string.lock().await, "b");

            let struct_c = MyStruct::new("c".to_string(), true);
            assert_eq!(*struct_c.string.lock().await, "c");

            ag._name = "foo".into();
            struct_a.process(&mut ag);

            ag._name = "bar".into();
            struct_b.process(&mut ag);

            ag._name = "baz".into();
            struct_c.process(&mut ag);

            let mut err_vec = Vec::new();
            ag.join_and_collect_errors_async(&mut err_vec).await;

            assert_eq!(err_vec.len(), 3);

            assert_eq!(err_vec[0].0, "foo".into());
            assert_eq!(err_vec[1].0, "bar".into());
            assert_eq!(err_vec[2].0, "baz".into());

            #[cfg(unix)]
            assert_eq!(
                format!("{:?}", err_vec[0].1),
                "errs::Err { reason = sabi::tokio::async_group::tests_of_async_group::Reasons BadString(\"a\"), file = src/tokio/async_group.rs, line = ".to_string() + &(BASE_LINE + 30).to_string() + " }",
            );
            #[cfg(windows)]
            assert_eq!(
                format!("{:?}", err_vec[0].1),
                "errs::Err { reason = sabi::tokio::async_group::tests_of_async_group::Reasons BadString(\"a\"), file = src\\tokio\\async_group.rs, line = ".to_string() + &(BASE_LINE + 30).to_string() + " }",
            );
            #[cfg(unix)]
            assert_eq!(
                format!("{:?}", err_vec[1].1),
                "errs::Err { reason = sabi::tokio::async_group::tests_of_async_group::Reasons BadString(\"b\"), file = src/tokio/async_group.rs, line = ".to_string() + &(BASE_LINE + 30).to_string() + " }",
            );
            #[cfg(windows)]
            assert_eq!(
                format!("{:?}", err_vec[1].1),
                "errs::Err { reason = sabi::tokio::async_group::tests_of_async_group::Reasons BadString(\"b\"), file = src\\tokio\\async_group.rs, line = ".to_string() + &(BASE_LINE + 30).to_string() + " }",
            );
            #[cfg(unix)]
            assert_eq!(
                format!("{:?}", err_vec[2].1),
                "errs::Err { reason = sabi::tokio::async_group::tests_of_async_group::Reasons BadString(\"c\"), file = src/tokio/async_group.rs, line = ".to_string() + &(BASE_LINE + 30).to_string() + " }",
            );
            #[cfg(windows)]
            assert_eq!(
                format!("{:?}", err_vec[2].1),
                "errs::Err { reason = sabi::tokio::async_group::tests_of_async_group::Reasons BadString(\"c\"), file = src\\tokio\\async_group.rs, line = ".to_string() + &(BASE_LINE + 30).to_string() + " }",
            );

            assert_eq!(*struct_a.string.lock().await, "a");
            assert_eq!(*struct_b.string.lock().await, "b");
            assert_eq!(*struct_c.string.lock().await, "c");
        }

        #[tokio::test]
        async fn data_src_execute_multiple_setup_handles() {
            let mut ag = AsyncGroup::new();

            let struct_d = MyStruct::new("d".to_string(), false);
            assert_eq!(*struct_d.string.lock().await, "d");

            ag._name = "foo".into();
            // Register two futures under one name, as the test name says. This is the
            // success-path counterpart of
            // `collect_all_errors_if_data_src_executes_multiple_setup_handles`.
            struct_d.process_multiple(&mut ag);

            let mut err_vec = Vec::new();
            ag.join_and_collect_errors_async(&mut err_vec).await;

            assert_eq!(err_vec.len(), 0);

            assert_eq!(*struct_d.string.lock().await, "D");
        }

        #[tokio::test]
        async fn collect_all_errors_if_data_src_executes_multiple_setup_handles() {
            let mut ag = AsyncGroup::new();

            let struct_d = MyStruct::new("d".to_string(), true);
            assert_eq!(*struct_d.string.lock().await, "d");

            ag._name = "foo".into();
            struct_d.process_multiple(&mut ag);

            let mut err_vec = Vec::new();
            ag.join_and_collect_errors_async(&mut err_vec).await;

            assert_eq!(err_vec.len(), 2);

            assert_eq!(err_vec[0].0, "foo".into());
            assert_eq!(err_vec[1].0, "foo".into());

            #[cfg(unix)]
            assert_eq!(
                format!("{:?}", err_vec[0].1),
                "errs::Err { reason = sabi::tokio::async_group::tests_of_async_group::Reasons BadString(\"d\"), file = src/tokio/async_group.rs, line = ".to_string() + &(BASE_LINE + 49).to_string() + " }",
            );
            #[cfg(windows)]
            assert_eq!(
                format!("{:?}", err_vec[0].1),
                "errs::Err { reason = sabi::tokio::async_group::tests_of_async_group::Reasons BadString(\"d\"), file = src\\tokio\\async_group.rs, line = ".to_string() + &(BASE_LINE + 49).to_string() + " }"
            );

            #[cfg(unix)]
            assert_eq!(
                format!("{:?}", err_vec[1].1),
                "errs::Err { reason = sabi::tokio::async_group::tests_of_async_group::Reasons BadString(\"d\"), file = src/tokio/async_group.rs, line = ".to_string() + &(BASE_LINE + 66).to_string() + " }",
            );
            #[cfg(windows)]
            assert_eq!(
                format!("{:?}", err_vec[1].1),
                "errs::Err { reason = sabi::tokio::async_group::tests_of_async_group::Reasons BadString(\"d\"), file = src\\tokio\\async_group.rs, line = ".to_string() + &(BASE_LINE + 66).to_string() + " }"
            );

            assert_eq!(*struct_d.string.lock().await, "d");
        }
    }

    // Tests for `join_and_ignore_errors_async`: only side effects are observable, since
    // results are discarded.
    mod tests_join_and_ignore_errors_async {
        use super::*;

        #[tokio::test]
        async fn zero() {
            let ag = AsyncGroup::new();

            ag.join_and_ignore_errors_async().await;
        }

        #[tokio::test]
        async fn single_ok() {
            let mut ag = AsyncGroup::new();

            let struct_a = MyStruct::new("a".to_string(), false);
            assert_eq!(*struct_a.string.lock().await, "a".to_string());

            ag._name = "foo".into();
            struct_a.process(&mut ag);

            ag.join_and_ignore_errors_async().await;
            assert_eq!(*struct_a.string.lock().await, "A".to_string());
        }

        #[tokio::test]
        async fn single_fail() {
            let mut ag = AsyncGroup::new();

            let struct_a = MyStruct::new("a".to_string(), true);
            assert_eq!(*struct_a.string.lock().await, "a".to_string());

            ag._name = "foo".into();
            struct_a.process(&mut ag);

            ag.join_and_ignore_errors_async().await;
            assert_eq!(*struct_a.string.lock().await, "a".to_string());
        }

        #[tokio::test]
        async fn multiple_ok() {
            let mut ag = AsyncGroup::new();

            let struct_a = MyStruct::new("a".to_string(), false);
            assert_eq!(*struct_a.string.lock().await, "a".to_string());

            let struct_b = MyStruct::new("b".to_string(), false);
            assert_eq!(*struct_b.string.lock().await, "b".to_string());

            let struct_c = MyStruct::new("c".to_string(), false);
            assert_eq!(*struct_c.string.lock().await, "c".to_string());

            ag._name = "foo".into();
            struct_a.process(&mut ag);

            ag._name = "bar".into();
            struct_b.process(&mut ag);

            ag._name = "baz".into();
            struct_c.process(&mut ag);

            ag.join_and_ignore_errors_async().await;

            assert_eq!(*struct_a.string.lock().await, "A");
            assert_eq!(*struct_b.string.lock().await, "B");
            assert_eq!(*struct_c.string.lock().await, "C");
        }

        #[tokio::test]
        async fn multiple_processes_and_single_fail() {
            let mut ag = AsyncGroup::new();

            let struct_a = MyStruct::new("a".to_string(), false);
            assert_eq!(*struct_a.string.lock().await, "a".to_string());

            let struct_b = MyStruct::new("b".to_string(), true);
            assert_eq!(*struct_b.string.lock().await, "b".to_string());

            let struct_c = MyStruct::new("c".to_string(), false);
            assert_eq!(*struct_c.string.lock().await, "c".to_string());

            ag._name = "foo".into();
            struct_a.process(&mut ag);

            ag._name = "bar".into();
            struct_b.process(&mut ag);

            ag._name = "baz".into();
            struct_c.process(&mut ag);

            ag.join_and_ignore_errors_async().await;

            assert_eq!(*struct_a.string.lock().await, "A");
            assert_eq!(*struct_b.string.lock().await, "b");
            assert_eq!(*struct_c.string.lock().await, "C");
        }

        #[tokio::test]
        async fn multiple_fail() {
            let mut ag = AsyncGroup::new();

            let struct_a = MyStruct::new("a".to_string(), true);
            assert_eq!(*struct_a.string.lock().await, "a".to_string());

            let struct_b = MyStruct::new("b".to_string(), true);
            assert_eq!(*struct_b.string.lock().await, "b".to_string());

            let struct_c = MyStruct::new("c".to_string(), true);
            assert_eq!(*struct_c.string.lock().await, "c".to_string());

            ag._name = "foo".into();
            struct_a.process(&mut ag);

            ag._name = "bar".into();
            struct_b.process(&mut ag);

            ag._name = "baz".into();
            struct_c.process(&mut ag);

            ag.join_and_ignore_errors_async().await;

            assert_eq!(*struct_a.string.lock().await, "a");
            assert_eq!(*struct_b.string.lock().await, "b");
            assert_eq!(*struct_c.string.lock().await, "c");
        }
    }

    // Tests for the public `join_async`, which had no direct coverage.
    mod tests_of_join_async {
        use super::*;

        #[tokio::test]
        async fn zero() {
            let ag = AsyncGroup::new();

            assert!(ag.join_async().await.is_empty());
        }

        #[tokio::test]
        async fn returns_only_failures_in_insertion_order() {
            let mut ag = AsyncGroup::new();

            let struct_a = MyStruct::new("a".to_string(), true);
            let struct_b = MyStruct::new("b".to_string(), false);
            let struct_c = MyStruct::new("c".to_string(), true);

            ag._name = "foo".into();
            struct_a.process(&mut ag);

            ag._name = "bar".into();
            struct_b.process(&mut ag);

            ag._name = "baz".into();
            struct_c.process(&mut ag);

            let errors = ag.join_async().await;

            assert_eq!(errors.len(), 2);
            assert_eq!(errors[0].0, "foo".into());
            assert_eq!(errors[1].0, "baz".into());

            assert_eq!(*struct_a.string.lock().await, "a");
            assert_eq!(*struct_b.string.lock().await, "B");
            assert_eq!(*struct_c.string.lock().await, "c");
        }
    }
}