Skip to main content

component_map/
async_fallible.rs

1use crate::{ComponentManager, Keyed, WithArgs};
2use futures::future::join_all;
3
4impl<Key, Args, Comp, FnInit> ComponentManager<Key, Args, Comp, FnInit> {
5    pub async fn try_init_async<Error>(
6        args: impl IntoIterator<Item = (Key, Args)>,
7        init: FnInit,
8    ) -> Result<Self, Error>
9    where
10        Key: Eq + std::hash::Hash,
11        FnInit: AsyncFn(&Args) -> Result<Comp, Error> + Clone,
12    {
13        let components_fut = args.into_iter().map(|(key, args)| {
14            let init = init.clone();
15            async move {
16                let result = (init)(&args)
17                    .await
18                    .map(|component| WithArgs { component, args });
19
20                (key, result)
21            }
22        });
23
24        let map = join_all(components_fut)
25            .await
26            .into_iter()
27            .map(|(key, result)| result.map(|component| (key, component)))
28            .collect::<Result<_, _>>()?;
29
30        Ok(Self { map, init })
31    }
32
33    pub async fn try_reinit_all_async<Error>(
34        &mut self,
35    ) -> impl Iterator<Item = Keyed<&Key, Result<Comp, Error>>>
36    where
37        Key: Clone,
38        FnInit: AsyncFn(&Args) -> Result<Comp, Error> + Clone,
39    {
40        let next_components_fut = self
41            .map
42            .values()
43            .map(|component| (self.init)(&component.args));
44
45        let next_components = join_all(next_components_fut).await;
46
47        self.map
48            .iter_mut()
49            .zip(next_components)
50            .map(|((key, prev), result)| {
51                let result = result.map(|next| std::mem::replace(&mut prev.component, next));
52
53                Keyed::new(key, result)
54            })
55    }
56
57    pub async fn try_reinit_async<Error>(
58        &mut self,
59        keys: impl IntoIterator<Item = Key>,
60    ) -> impl Iterator<Item = Keyed<Key, Option<Result<Comp, Error>>>>
61    where
62        Key: Eq + std::hash::Hash + Clone,
63        FnInit: AsyncFn(&Args) -> Result<Comp, Error> + Clone,
64    {
65        let next_components_fut = keys.into_iter().map(|key| {
66            let init = self.init.clone();
67            let args = self.map.get(&key).map(|component| &component.args);
68            async move {
69                let result = match args {
70                    Some(args) => Some((init)(args).await),
71                    None => None,
72                };
73                Keyed::new(key, result)
74            }
75        });
76
77        let results = join_all(next_components_fut).await;
78
79        results.into_iter().map(|Keyed { key, value: result }| {
80            let prev = result
81                .map(|result| {
82                    result.map(|next| {
83                        self.map
84                            .get_mut(&key)
85                            .map(|component| std::mem::replace(&mut component.component, next))
86                    })
87                })
88                .transpose()
89                .map(Option::flatten);
90
91            Keyed::new(key, prev.transpose())
92        })
93    }
94
95    pub async fn try_update_async<Error>(
96        &mut self,
97        updates: impl IntoIterator<Item = (Key, Args)>,
98    ) -> impl Iterator<Item = Keyed<Key, Option<Result<WithArgs<Args, Comp>, Error>>>>
99    where
100        Key: Clone + Eq + std::hash::Hash,
101        FnInit: AsyncFn(&Args) -> Result<Comp, Error> + Clone,
102    {
103        let updated_components_fut = updates.into_iter().map(|(key, args)| {
104            let init = self.init.clone();
105            async move {
106                let result = (init)(&args)
107                    .await
108                    .map(|component| WithArgs { component, args });
109
110                (key, result)
111            }
112        });
113
114        join_all(updated_components_fut)
115            .await
116            .into_iter()
117            .map(|(key, result)| {
118                let result = result.map(|component| self.map.insert(key.clone(), component));
119
120                Keyed::new(key, result.transpose())
121            })
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    #[derive(Debug, Clone, PartialEq, Eq)]
    struct Counter(usize);

    #[derive(Debug, Clone, PartialEq, Eq)]
    struct FailArgs {
        value: usize,
        should_fail: bool,
    }

    #[derive(Debug, PartialEq, Eq)]
    struct TestError(String);

    // Expands to an `init` closure that produces `Counter(args.value * $multiplier)`, or
    // `TestError("Failed")` when `args.should_fail` is set. The returned future copies what
    // it needs out of `args`, so it does not borrow the argument. A macro (rather than a
    // helper function) keeps every `init` a plain closure, exactly as if written inline,
    // instead of an opaque `impl Fn` type.
    macro_rules! counter_init {
        ($multiplier:expr) => {
            |args: &FailArgs| {
                let value = args.value;
                let should_fail = args.should_fail;
                async move {
                    if should_fail {
                        Err(TestError("Failed".to_string()))
                    } else {
                        Ok(Counter(value * $multiplier))
                    }
                }
            }
        };
    }

    // Arguments with `should_fail: false`.
    fn succeeding(value: usize) -> FailArgs {
        FailArgs {
            value,
            should_fail: false,
        }
    }

    // Arguments with `should_fail: true`.
    fn failing(value: usize) -> FailArgs {
        FailArgs {
            value,
            should_fail: true,
        }
    }

    #[tokio::test]
    async fn test_try_init_async_success() {
        let init = counter_init!(1);

        let result = ComponentManager::try_init_async(
            [("key1", succeeding(1)), ("key2", succeeding(2))],
            init,
        )
        .await;

        let manager = result.expect("every initialization succeeds");
        assert_eq!(manager.components().len(), 2);
        assert_eq!(
            manager.components().get("key1").unwrap().component,
            Counter(1)
        );
        assert_eq!(
            manager.components().get("key2").unwrap().component,
            Counter(2)
        );
    }

    #[tokio::test]
    async fn test_try_init_async_failure() {
        let init = counter_init!(1);

        let result = ComponentManager::try_init_async(
            [("key1", succeeding(1)), ("key2", failing(2))],
            init,
        )
        .await;

        assert_eq!(result.err(), Some(TestError("Failed".to_string())));
    }

    #[tokio::test]
    async fn test_try_init_async_empty() {
        let init = counter_init!(1);

        let result: Result<ComponentManager<&str, FailArgs, Counter, _>, TestError> =
            ComponentManager::try_init_async([], init).await;

        assert_eq!(result.unwrap().components().len(), 0);
    }

    #[tokio::test]
    async fn test_try_reinit_all_async_success() {
        let init = counter_init!(2);

        let mut manager = ComponentManager::try_init_async(
            [("key1", succeeding(1)), ("key2", succeeding(2))],
            init,
        )
        .await
        .unwrap();

        let results: Vec<_> = manager.try_reinit_all_async().await.collect();

        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.value.is_ok()));

        assert_eq!(
            manager.components().get("key1").unwrap().component,
            Counter(2)
        );
        assert_eq!(
            manager.components().get("key2").unwrap().component,
            Counter(4)
        );
    }

    #[tokio::test]
    async fn test_try_reinit_all_async_with_failure() {
        // Counts every `init` call. The first two calls happen during construction, so only a
        // re-initialization (call index >= 2) of an entry with `should_fail` set fails.
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_for_init = Arc::clone(&calls);

        let init = move |args: &FailArgs| {
            let counter = Arc::clone(&calls_for_init);
            let value = args.value;
            let should_fail = args.should_fail;
            async move {
                // One atomic read-modify-write: reading the index and incrementing it cannot
                // be observed separately.
                let call_index = counter.fetch_add(1, Ordering::SeqCst);

                if call_index >= 2 && should_fail {
                    Err(TestError("Failed on reinit".to_string()))
                } else {
                    Ok(Counter(value * 2))
                }
            }
        };

        let mut manager = ComponentManager::try_init_async(
            [("key1", succeeding(1)), ("key2", failing(2))],
            init,
        )
        .await
        .unwrap();

        let results: Vec<_> = manager.try_reinit_all_async().await.collect();

        assert_eq!(results.len(), 2);

        let failures: Vec<_> = results.iter().filter(|r| r.value.is_err()).collect();
        assert_eq!(failures.len(), 1);
        assert_eq!(*failures[0].key, "key2");
        assert_eq!(
            failures[0].value,
            Err(TestError("Failed on reinit".to_string()))
        );

        let successes: Vec<_> = results.iter().filter(|r| r.value.is_ok()).collect();
        assert_eq!(successes.len(), 1);
        assert_eq!(*successes[0].key, "key1");
        // On success the replaced component is handed back.
        assert_eq!(successes[0].value, Ok(Counter(2)));

        // A failed re-initialization leaves the existing component in place.
        assert_eq!(
            manager.components().get("key1").unwrap().component,
            Counter(2)
        );
        assert_eq!(
            manager.components().get("key2").unwrap().component,
            Counter(4)
        );
        assert_eq!(calls.load(Ordering::SeqCst), 4);
    }

    #[tokio::test]
    async fn test_try_reinit_all_async_empty() {
        let init = counter_init!(1);

        let mut manager: ComponentManager<&str, FailArgs, Counter, _> =
            ComponentManager::try_init_async([], init).await.unwrap();

        let results: Vec<_> = manager.try_reinit_all_async().await.collect();
        assert_eq!(results.len(), 0);
    }

    #[tokio::test]
    async fn test_try_reinit_async_success() {
        let init = counter_init!(3);

        let mut manager = ComponentManager::try_init_async(
            [("key1", succeeding(1)), ("key2", succeeding(2))],
            init,
        )
        .await
        .unwrap();

        let results: Vec<_> = manager.try_reinit_async(["key1"]).await.collect();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].key, "key1");
        // The replaced component is handed back.
        assert_eq!(results[0].value, Some(Ok(Counter(3))));

        assert_eq!(
            manager.components().get("key1").unwrap().component,
            Counter(3)
        );
        // Keys that were not requested are left alone.
        assert_eq!(
            manager.components().get("key2").unwrap().component,
            Counter(6)
        );
    }

    #[tokio::test]
    async fn test_try_reinit_async_nonexistent_key() {
        let init = counter_init!(1);

        let mut manager = ComponentManager::try_init_async([("key1", succeeding(1))], init)
            .await
            .unwrap();

        let results: Vec<_> = manager.try_reinit_async(["nonexistent"]).await.collect();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].key, "nonexistent");
        assert!(results[0].value.is_none());
    }

    #[tokio::test]
    async fn test_try_reinit_async_preserves_key_order() {
        let init = counter_init!(1);

        let mut manager = ComponentManager::try_init_async([("key1", succeeding(1))], init)
            .await
            .unwrap();

        let results: Vec<_> = manager
            .try_reinit_async(["missing", "key1"])
            .await
            .collect();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].key, "missing");
        assert_eq!(results[0].value, None);
        assert_eq!(results[1].key, "key1");
        assert_eq!(results[1].value, Some(Ok(Counter(1))));
    }

    #[tokio::test]
    async fn test_try_update_async_new_key_success() {
        let init = counter_init!(1);

        let mut manager = ComponentManager::try_init_async([("key1", succeeding(1))], init)
            .await
            .unwrap();

        let results: Vec<_> = manager
            .try_update_async([("key2", succeeding(20))])
            .await
            .collect();

        assert_eq!(results.len(), 1);
        // A brand-new key has no previous entry to hand back.
        assert!(results[0].value.is_none());
        assert_eq!(manager.components().len(), 2);
        assert_eq!(
            manager.components().get("key2").unwrap().component,
            Counter(20)
        );
    }

    #[tokio::test]
    async fn test_try_update_async_existing_key_returns_previous() {
        let init = counter_init!(1);

        let mut manager = ComponentManager::try_init_async([("key1", succeeding(1))], init)
            .await
            .unwrap();

        let results: Vec<_> = manager
            .try_update_async([("key1", succeeding(10))])
            .await
            .collect();

        assert_eq!(results.len(), 1);
        let previous = results[0]
            .value
            .as_ref()
            .expect("an existing key reports its previous entry")
            .as_ref()
            .expect("the update succeeds");
        assert_eq!(previous.component, Counter(1));
        assert_eq!(previous.args, succeeding(1));

        assert_eq!(manager.components().len(), 1);
        assert_eq!(
            manager.components().get("key1").unwrap().component,
            Counter(10)
        );
    }

    #[tokio::test]
    async fn test_try_update_async_failure() {
        let init = counter_init!(1);

        let mut manager = ComponentManager::try_init_async([("key1", succeeding(1))], init)
            .await
            .unwrap();

        let results: Vec<_> = manager
            .try_update_async([("key2", failing(20))])
            .await
            .collect();

        assert_eq!(results.len(), 1);
        assert!(results[0].value.is_some());
        assert!(results[0].value.as_ref().unwrap().is_err());

        // Should not insert on error
        assert_eq!(manager.components().len(), 1);
        assert!(manager.components().get("key2").is_none());
    }

    #[tokio::test]
    async fn test_try_update_async_multiple_mixed() {
        let init = counter_init!(1);

        let mut manager = ComponentManager::try_init_async([("key1", succeeding(1))], init)
            .await
            .unwrap();

        let results: Vec<_> = manager
            .try_update_async([
                ("key2", succeeding(20)),
                ("key3", failing(30)),
                ("key4", succeeding(40)),
            ])
            .await
            .collect();

        assert_eq!(results.len(), 3);
        // Outcomes come back in input order.
        assert_eq!(results[0].key, "key2");
        assert!(results[0].value.is_none());
        assert_eq!(results[1].key, "key3");
        assert_eq!(
            results[1]
                .value
                .as_ref()
                .and_then(|outcome| outcome.as_ref().err()),
            Some(&TestError("Failed".to_string()))
        );
        assert_eq!(results[2].key, "key4");
        assert!(results[2].value.is_none());

        // Check that only successful updates were inserted
        assert_eq!(manager.components().len(), 3); // key1, key2, key4
        assert_eq!(
            manager.components().get("key2").unwrap().component,
            Counter(20)
        );
        assert!(manager.components().get("key3").is_none());
        assert_eq!(
            manager.components().get("key4").unwrap().component,
            Counter(40)
        );
    }
}