Skip to main content

component_map/
async_fallible.rs

1use crate::{ComponentMap, Keyed, WithArgs};
2use futures::future::join_all;
3
4impl<Key, Args, Comp, FnInit> ComponentMap<Key, Args, Comp, FnInit> {
5    pub async fn try_init_async<Error>(
6        args: impl IntoIterator<Item = (Key, Args)>,
7        init: FnInit,
8    ) -> Result<Self, Error>
9    where
10        Key: Eq + std::hash::Hash,
11        FnInit: AsyncFn(&Key, &Args) -> Result<Comp, Error> + Clone,
12    {
13        let components_fut = args.into_iter().map(|(key, args)| {
14            let init = init.clone();
15            async move {
16                let result = (init)(&key, &args)
17                    .await
18                    .map(|component| WithArgs { component, args });
19
20                (key, result)
21            }
22        });
23
24        let map = join_all(components_fut)
25            .await
26            .into_iter()
27            .map(|(key, result)| result.map(|component| (key, component)))
28            .collect::<Result<_, _>>()?;
29
30        Ok(Self { map: map, init })
31    }
32
33    pub async fn try_reinit_all_async<Error>(
34        &mut self,
35    ) -> impl Iterator<Item = Keyed<&Key, Result<Comp, Error>>>
36    where
37        FnInit: AsyncFn(&Key, &Args) -> Result<Comp, Error> + Clone,
38    {
39        let next_components_fut = self
40            .map
41            .iter()
42            .map(|(key, component)| (self.init)(key, &component.args));
43
44        let next_components = join_all(next_components_fut).await;
45
46        self.map
47            .iter_mut()
48            .zip(next_components)
49            .map(|((key, prev), result)| {
50                let result = result.map(|next| std::mem::replace(&mut prev.component, next));
51
52                Keyed::new(key, result)
53            })
54    }
55
56    pub async fn try_reinit_async<Error>(
57        &mut self,
58        keys: impl IntoIterator<Item = Key>,
59    ) -> impl Iterator<Item = Keyed<Key, Option<Result<Comp, Error>>>>
60    where
61        Key: Eq + std::hash::Hash + Clone,
62        FnInit: AsyncFn(&Key, &Args) -> Result<Comp, Error> + Clone,
63    {
64        let next_components_fut = keys.into_iter().map(|key| {
65            let init = self.init.clone();
66
67            let args = self.map.get(&key).map(|component| &component.args);
68
69            async move {
70                let result = match args {
71                    Some(args) => Some((init)(&key, args).await),
72                    None => None,
73                };
74                Keyed::new(key, result)
75            }
76        });
77
78        let results = join_all(next_components_fut).await;
79
80        results.into_iter().map(|Keyed { key, value: result }| {
81            let prev = result
82                .map(|result| {
83                    result.map(|next| {
84                        self.map
85                            .get_mut(&key)
86                            .map(|component| std::mem::replace(&mut component.component, next))
87                    })
88                })
89                .transpose()
90                .map(Option::flatten);
91
92            Keyed::new(key, prev.transpose())
93        })
94    }
95
96    pub async fn try_update_async<Error>(
97        &mut self,
98        updates: impl IntoIterator<Item = (Key, Args)>,
99    ) -> impl Iterator<Item = Keyed<Key, Option<Result<WithArgs<Args, Comp>, Error>>>>
100    where
101        Key: Clone + Eq + std::hash::Hash,
102        FnInit: AsyncFn(&Key, &Args) -> Result<Comp, Error> + Clone,
103    {
104        let updated_components_fut = updates.into_iter().map(|(key, args)| {
105            let init = self.init.clone();
106            async move {
107                let result = (init)(&key, &args)
108                    .await
109                    .map(|component| WithArgs { component, args });
110
111                (key, result)
112            }
113        });
114
115        join_all(updated_components_fut)
116            .await
117            .into_iter()
118            .map(|(key, result)| {
119                let result = result.map(|component| self.map.insert(key.clone(), component));
120
121                Keyed::new(key, result.transpose())
122            })
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[derive(Debug, Clone, PartialEq, Eq)]
    struct Counter(usize);

    #[derive(Debug, Clone, PartialEq, Eq)]
    struct FailArgs {
        value: usize,
        should_fail: bool,
    }

    impl FailArgs {
        /// Args whose initialization succeeds.
        fn ok(value: usize) -> Self {
            Self {
                value,
                should_fail: false,
            }
        }

        /// Args whose initialization fails.
        fn failing(value: usize) -> Self {
            Self {
                value,
                should_fail: true,
            }
        }
    }

    #[derive(Debug, PartialEq, Eq)]
    struct TestError(String);

    /// Expands to an init closure that returns `Err(TestError("Failed"))` for
    /// args with `should_fail` set and `Ok(Counter(value * $factor))` otherwise.
    ///
    /// The closure copies what it needs out of `args` before building the
    /// future, so the future borrows nothing from its arguments.
    macro_rules! scaled_init {
        ($factor:expr) => {{
            let factor: usize = $factor;
            move |_key: &&str, args: &FailArgs| {
                let value = args.value;
                let should_fail = args.should_fail;
                async move {
                    if should_fail {
                        Err(TestError("Failed".to_string()))
                    } else {
                        Ok(Counter(value * factor))
                    }
                }
            }
        }};
    }

    #[tokio::test]
    async fn test_try_init_async_success() {
        let result = ComponentMap::try_init_async(
            [("key1", FailArgs::ok(1)), ("key2", FailArgs::ok(2))],
            scaled_init!(1),
        )
        .await;

        let manager = result.expect("all initializations succeed");
        assert_eq!(manager.map.len(), 2);
        assert_eq!(manager.map.get("key1").unwrap().component, Counter(1));
        assert_eq!(manager.map.get("key2").unwrap().component, Counter(2));
        // The args are stored alongside each component.
        assert_eq!(manager.map.get("key1").unwrap().args, FailArgs::ok(1));
    }

    #[tokio::test]
    async fn test_try_init_async_failure() {
        let result = ComponentMap::try_init_async(
            [("key1", FailArgs::ok(1)), ("key2", FailArgs::failing(2))],
            scaled_init!(1),
        )
        .await;

        assert_eq!(result.err(), Some(TestError("Failed".to_string())));
    }

    #[tokio::test]
    async fn test_try_init_async_empty() {
        let result: Result<ComponentMap<&str, FailArgs, Counter, _>, TestError> =
            ComponentMap::try_init_async([], scaled_init!(1)).await;

        assert_eq!(result.expect("empty input cannot fail").map.len(), 0);
    }

    #[tokio::test]
    async fn test_try_reinit_all_async_success() {
        let mut manager = ComponentMap::try_init_async(
            [("key1", FailArgs::ok(1)), ("key2", FailArgs::ok(2))],
            scaled_init!(2),
        )
        .await
        .unwrap();

        let results: Vec<_> = manager.try_reinit_all_async().await.collect();

        assert_eq!(results.len(), 2);
        for result in &results {
            // Each entry reports the component it replaced.
            let expected = match *result.key {
                "key1" => Counter(2),
                "key2" => Counter(4),
                other => panic!("unexpected key {other}"),
            };
            assert_eq!(result.value, Ok(expected));
        }

        assert_eq!(manager.map.get("key1").unwrap().component, Counter(2));
        assert_eq!(manager.map.get("key2").unwrap().component, Counter(4));
    }

    #[tokio::test]
    async fn test_try_reinit_all_async_with_failure() {
        let calls = Arc::new(AtomicUsize::new(0));

        // Fails only for `should_fail` args, and only from the third call on,
        // i.e. during re-initialization rather than the initial build.
        let init = {
            let calls = Arc::clone(&calls);
            move |_key: &&str, args: &FailArgs| {
                let calls = Arc::clone(&calls);
                let value = args.value;
                let should_fail = args.should_fail;
                async move {
                    // Single atomic read-and-increment.
                    let call_index = calls.fetch_add(1, Ordering::SeqCst);

                    if call_index >= 2 && should_fail {
                        Err(TestError("Failed on reinit".to_string()))
                    } else {
                        Ok(Counter(value * 2))
                    }
                }
            }
        };

        let mut manager = ComponentMap::try_init_async(
            [("key1", FailArgs::ok(1)), ("key2", FailArgs::failing(2))],
            init,
        )
        .await
        .unwrap();

        let results: Vec<_> = manager.try_reinit_all_async().await.collect();

        assert_eq!(results.len(), 2);
        assert_eq!(results.iter().filter(|r| r.value.is_err()).count(), 1);
        assert_eq!(results.iter().filter(|r| r.value.is_ok()).count(), 1);

        // Two calls for the initial build plus two for the re-initialization.
        assert_eq!(calls.load(Ordering::SeqCst), 4);
        // A failed re-initialization keeps the previously built component.
        assert_eq!(manager.map.get("key1").unwrap().component, Counter(2));
        assert_eq!(manager.map.get("key2").unwrap().component, Counter(4));
    }

    #[tokio::test]
    async fn test_try_reinit_all_async_empty() {
        let mut manager: ComponentMap<&str, FailArgs, Counter, _> =
            ComponentMap::try_init_async([], scaled_init!(1)).await.unwrap();

        let results: Vec<_> = manager.try_reinit_all_async().await.collect();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_try_reinit_async_success() {
        let mut manager = ComponentMap::try_init_async(
            [("key1", FailArgs::ok(1)), ("key2", FailArgs::ok(2))],
            scaled_init!(3),
        )
        .await
        .unwrap();

        let results: Vec<_> = manager.try_reinit_async(["key1"]).await.collect();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].key, "key1");
        // The previous component is reported back.
        assert_eq!(results[0].value, Some(Ok(Counter(3))));
        assert_eq!(manager.map.get("key1").unwrap().component, Counter(3));
        assert_eq!(manager.map.get("key2").unwrap().component, Counter(6));
    }

    #[tokio::test]
    async fn test_try_reinit_async_nonexistent_key() {
        let mut manager =
            ComponentMap::try_init_async([("key1", FailArgs::ok(1))], scaled_init!(1))
                .await
                .unwrap();

        let results: Vec<_> = manager.try_reinit_async(["nonexistent"]).await.collect();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].key, "nonexistent");
        assert!(results[0].value.is_none());
        // The existing entry is untouched.
        assert_eq!(manager.map.get("key1").unwrap().component, Counter(1));
    }

    #[tokio::test]
    async fn test_try_update_async_new_key_success() {
        let mut manager =
            ComponentMap::try_init_async([("key1", FailArgs::ok(1))], scaled_init!(1))
                .await
                .unwrap();

        let results: Vec<_> = manager
            .try_update_async([("key2", FailArgs::ok(20))])
            .await
            .collect();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].key, "key2");
        // A brand-new key has no previous component.
        assert!(results[0].value.is_none());
        assert_eq!(manager.map.len(), 2);
        assert_eq!(manager.map.get("key2").unwrap().component, Counter(20));
    }

    #[tokio::test]
    async fn test_try_update_async_failure() {
        let mut manager =
            ComponentMap::try_init_async([("key1", FailArgs::ok(1))], scaled_init!(1))
                .await
                .unwrap();

        let results: Vec<_> = manager
            .try_update_async([("key2", FailArgs::failing(20))])
            .await
            .collect();

        assert_eq!(results.len(), 1);
        assert!(matches!(results[0].value, Some(Err(_))));

        // Should not insert on error
        assert_eq!(manager.map.len(), 1);
        assert!(manager.map.get("key2").is_none());
    }

    #[tokio::test]
    async fn test_try_update_async_multiple_mixed() {
        let mut manager =
            ComponentMap::try_init_async([("key1", FailArgs::ok(1))], scaled_init!(1))
                .await
                .unwrap();

        let results: Vec<_> = manager
            .try_update_async([
                ("key2", FailArgs::ok(20)),
                ("key3", FailArgs::failing(30)),
                ("key4", FailArgs::ok(40)),
            ])
            .await
            .collect();

        // One report per update, in input order.
        let keys: Vec<_> = results.iter().map(|r| r.key).collect();
        assert_eq!(keys, ["key2", "key3", "key4"]);
        assert!(results[0].value.is_none());
        assert!(matches!(results[1].value, Some(Err(_))));
        assert!(results[2].value.is_none());

        // Check that only successful updates were inserted
        assert_eq!(manager.map.len(), 3); // key1, key2, key4
        assert_eq!(manager.map.get("key2").unwrap().component, Counter(20));
        assert!(manager.map.get("key3").is_none());
        assert_eq!(manager.map.get("key4").unwrap().component, Counter(40));
    }
}