1use crate::{ComponentMap, Keyed, WithArgs};
2use futures::future::join_all;
3
4impl<Key, Args, Comp, FnInit> ComponentMap<Key, Args, Comp, FnInit> {
5 pub async fn try_init_async<Error>(
6 args: impl IntoIterator<Item = (Key, Args)>,
7 init: FnInit,
8 ) -> Result<Self, Error>
9 where
10 Key: Eq + std::hash::Hash,
11 FnInit: AsyncFn(&Args) -> Result<Comp, Error> + Clone,
12 {
13 let components_fut = args.into_iter().map(|(key, args)| {
14 let init = init.clone();
15 async move {
16 let result = (init)(&args)
17 .await
18 .map(|component| WithArgs { component, args });
19
20 (key, result)
21 }
22 });
23
24 let map = join_all(components_fut)
25 .await
26 .into_iter()
27 .map(|(key, result)| result.map(|component| (key, component)))
28 .collect::<Result<_, _>>()?;
29
30 Ok(Self { map, init })
31 }
32
33 pub async fn try_reinit_all_async<Error>(
34 &mut self,
35 ) -> impl Iterator<Item = Keyed<&Key, Result<Comp, Error>>>
36 where
37 Key: Clone,
38 FnInit: AsyncFn(&Args) -> Result<Comp, Error> + Clone,
39 {
40 let next_components_fut = self
41 .map
42 .values()
43 .map(|component| (self.init)(&component.args));
44
45 let next_components = join_all(next_components_fut).await;
46
47 self.map
48 .iter_mut()
49 .zip(next_components)
50 .map(|((key, prev), result)| {
51 let result = result.map(|next| std::mem::replace(&mut prev.component, next));
52
53 Keyed::new(key, result)
54 })
55 }
56
57 pub async fn try_reinit_async<Error>(
58 &mut self,
59 keys: impl IntoIterator<Item = Key>,
60 ) -> impl Iterator<Item = Keyed<Key, Option<Result<Comp, Error>>>>
61 where
62 Key: Eq + std::hash::Hash + Clone,
63 FnInit: AsyncFn(&Args) -> Result<Comp, Error> + Clone,
64 {
65 let next_components_fut = keys.into_iter().map(|key| {
66 let init = self.init.clone();
67 let args = self.map.get(&key).map(|component| &component.args);
68 async move {
69 let result = match args {
70 Some(args) => Some((init)(args).await),
71 None => None,
72 };
73 Keyed::new(key, result)
74 }
75 });
76
77 let results = join_all(next_components_fut).await;
78
79 results.into_iter().map(|Keyed { key, value: result }| {
80 let prev = result
81 .map(|result| {
82 result.map(|next| {
83 self.map
84 .get_mut(&key)
85 .map(|component| std::mem::replace(&mut component.component, next))
86 })
87 })
88 .transpose()
89 .map(Option::flatten);
90
91 Keyed::new(key, prev.transpose())
92 })
93 }
94
95 pub async fn try_update_async<Error>(
96 &mut self,
97 updates: impl IntoIterator<Item = (Key, Args)>,
98 ) -> impl Iterator<Item = Keyed<Key, Option<Result<WithArgs<Args, Comp>, Error>>>>
99 where
100 Key: Clone + Eq + std::hash::Hash,
101 FnInit: AsyncFn(&Args) -> Result<Comp, Error> + Clone,
102 {
103 let updated_components_fut = updates.into_iter().map(|(key, args)| {
104 let init = self.init.clone();
105 async move {
106 let result = (init)(&args)
107 .await
108 .map(|component| WithArgs { component, args });
109
110 (key, result)
111 }
112 });
113
114 join_all(updated_components_fut)
115 .await
116 .into_iter()
117 .map(|(key, result)| {
118 let result = result.map(|component| self.map.insert(key.clone(), component));
119
120 Keyed::new(key, result.transpose())
121 })
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Component type produced by the test initialisers.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct Counter(usize);

    /// Initialiser arguments: the value to build from and whether to fail.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct FailArgs {
        value: usize,
        should_fail: bool,
    }

    /// Error type returned by the test initialisers.
    #[derive(Debug, PartialEq, Eq)]
    struct TestError(String);

    /// Shorthand constructor for [`FailArgs`].
    fn fail_args(value: usize, should_fail: bool) -> FailArgs {
        FailArgs { value, should_fail }
    }

    #[tokio::test]
    async fn test_try_init_async_success() {
        let init = |args: &FailArgs| {
            let (value, should_fail) = (args.value, args.should_fail);
            async move {
                if should_fail {
                    return Err(TestError("Failed".to_string()));
                }
                Ok(Counter(value))
            }
        };

        let result = ComponentMap::try_init_async(
            [("key1", fail_args(1, false)), ("key2", fail_args(2, false))],
            init,
        )
        .await;

        assert!(result.is_ok());
        let manager = result.unwrap();
        assert_eq!(manager.components().len(), 2);
        assert_eq!(
            manager.components().get("key1").unwrap().component,
            Counter(1)
        );
        assert_eq!(
            manager.components().get("key2").unwrap().component,
            Counter(2)
        );
    }

    #[tokio::test]
    async fn test_try_init_async_failure() {
        let init = |args: &FailArgs| {
            let (value, should_fail) = (args.value, args.should_fail);
            async move {
                if should_fail {
                    return Err(TestError("Failed".to_string()));
                }
                Ok(Counter(value))
            }
        };

        let result = ComponentMap::try_init_async(
            [("key1", fail_args(1, false)), ("key2", fail_args(2, true))],
            init,
        )
        .await;

        assert!(result.is_err());
        assert_eq!(result.err().unwrap(), TestError("Failed".to_string()));
    }

    #[tokio::test]
    async fn test_try_init_async_empty() {
        let init = |args: &FailArgs| {
            let (value, should_fail) = (args.value, args.should_fail);
            async move {
                if should_fail {
                    return Err(TestError("Failed".to_string()));
                }
                Ok(Counter(value))
            }
        };

        let result: Result<ComponentMap<&str, FailArgs, Counter, _>, TestError> =
            ComponentMap::try_init_async([], init).await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap().components().len(), 0);
    }

    #[tokio::test]
    async fn test_try_reinit_all_async_success() {
        let init = |args: &FailArgs| {
            let (value, should_fail) = (args.value, args.should_fail);
            async move {
                if should_fail {
                    return Err(TestError("Failed".to_string()));
                }
                Ok(Counter(value * 2))
            }
        };

        let mut manager = ComponentMap::try_init_async(
            [("key1", fail_args(1, false)), ("key2", fail_args(2, false))],
            init,
        )
        .await
        .unwrap();

        let results: Vec<_> = manager.try_reinit_all_async().await.collect();

        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.value.is_ok()));

        assert_eq!(
            manager.components().get("key1").unwrap().component,
            Counter(2)
        );
        assert_eq!(
            manager.components().get("key2").unwrap().component,
            Counter(4)
        );
    }

    #[tokio::test]
    async fn test_try_reinit_all_async_with_failure() {
        let call_count = Arc::new(Mutex::new(0));
        let call_count_clone = call_count.clone();

        // The first two calls (initial construction) always succeed; later
        // calls fail for args flagged with `should_fail`.
        let init = move |args: &FailArgs| {
            let call_count = call_count_clone.clone();
            let (value, should_fail) = (args.value, args.should_fail);
            async move {
                let count = {
                    let mut calls = call_count.lock().unwrap();
                    let seen = *calls;
                    *calls += 1;
                    seen
                };

                if count >= 2 && should_fail {
                    return Err(TestError("Failed on reinit".to_string()));
                }
                Ok(Counter(value * 2))
            }
        };

        let mut manager = ComponentMap::try_init_async(
            [("key1", fail_args(1, false)), ("key2", fail_args(2, true))],
            init,
        )
        .await
        .unwrap();

        let results: Vec<_> = manager.try_reinit_all_async().await.collect();

        assert_eq!(results.len(), 2);
        let failures: Vec<_> = results.iter().filter(|r| r.value.is_err()).collect();
        assert_eq!(failures.len(), 1);
        let successes: Vec<_> = results.iter().filter(|r| r.value.is_ok()).collect();
        assert_eq!(successes.len(), 1);
    }

    #[tokio::test]
    async fn test_try_reinit_all_async_empty() {
        let init = |args: &FailArgs| {
            let (value, should_fail) = (args.value, args.should_fail);
            async move {
                if should_fail {
                    return Err(TestError("Failed".to_string()));
                }
                Ok(Counter(value))
            }
        };

        let mut manager: ComponentMap<&str, FailArgs, Counter, _> =
            ComponentMap::try_init_async([], init).await.unwrap();

        let results: Vec<_> = manager.try_reinit_all_async().await.collect();
        assert_eq!(results.len(), 0);
    }

    #[tokio::test]
    async fn test_try_reinit_async_success() {
        let init = |args: &FailArgs| {
            let (value, should_fail) = (args.value, args.should_fail);
            async move {
                if should_fail {
                    return Err(TestError("Failed".to_string()));
                }
                Ok(Counter(value * 3))
            }
        };

        let mut manager = ComponentMap::try_init_async(
            [("key1", fail_args(1, false)), ("key2", fail_args(2, false))],
            init,
        )
        .await
        .unwrap();

        let results: Vec<_> = manager.try_reinit_async(["key1"]).await.collect();

        assert_eq!(results.len(), 1);
        assert!(results[0].value.as_ref().unwrap().is_ok());
        assert_eq!(
            manager.components().get("key1").unwrap().component,
            Counter(3)
        );
        assert_eq!(
            manager.components().get("key2").unwrap().component,
            Counter(6)
        );
    }

    #[tokio::test]
    async fn test_try_reinit_async_nonexistent_key() {
        let init = |args: &FailArgs| {
            let (value, should_fail) = (args.value, args.should_fail);
            async move {
                if should_fail {
                    return Err(TestError("Failed".to_string()));
                }
                Ok(Counter(value))
            }
        };

        let mut manager = ComponentMap::try_init_async([("key1", fail_args(1, false))], init)
            .await
            .unwrap();

        let results: Vec<_> = manager.try_reinit_async(["nonexistent"]).await.collect();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].key, "nonexistent");
        assert!(results[0].value.is_none());
    }

    #[tokio::test]
    async fn test_try_update_async_new_key_success() {
        let init = |args: &FailArgs| {
            let (value, should_fail) = (args.value, args.should_fail);
            async move {
                if should_fail {
                    return Err(TestError("Failed".to_string()));
                }
                Ok(Counter(value))
            }
        };

        let mut manager = ComponentMap::try_init_async([("key1", fail_args(1, false))], init)
            .await
            .unwrap();

        let results: Vec<_> = manager
            .try_update_async([("key2", fail_args(20, false))])
            .await
            .collect();

        assert_eq!(results.len(), 1);
        assert!(results[0].value.is_none());
        assert_eq!(manager.components().len(), 2);
        assert_eq!(
            manager.components().get("key2").unwrap().component,
            Counter(20)
        );
    }

    #[tokio::test]
    async fn test_try_update_async_failure() {
        let init = |args: &FailArgs| {
            let (value, should_fail) = (args.value, args.should_fail);
            async move {
                if should_fail {
                    return Err(TestError("Failed".to_string()));
                }
                Ok(Counter(value))
            }
        };

        let mut manager = ComponentMap::try_init_async([("key1", fail_args(1, false))], init)
            .await
            .unwrap();

        let results: Vec<_> = manager
            .try_update_async([("key2", fail_args(20, true))])
            .await
            .collect();

        assert_eq!(results.len(), 1);
        assert!(results[0].value.is_some());
        assert!(results[0].value.as_ref().unwrap().is_err());

        assert_eq!(manager.components().len(), 1);
        assert!(manager.components().get("key2").is_none());
    }

    #[tokio::test]
    async fn test_try_update_async_multiple_mixed() {
        let init = |args: &FailArgs| {
            let (value, should_fail) = (args.value, args.should_fail);
            async move {
                if should_fail {
                    return Err(TestError("Failed".to_string()));
                }
                Ok(Counter(value))
            }
        };

        let mut manager = ComponentMap::try_init_async([("key1", fail_args(1, false))], init)
            .await
            .unwrap();

        let results: Vec<_> = manager
            .try_update_async([
                ("key2", fail_args(20, false)),
                ("key3", fail_args(30, true)),
                ("key4", fail_args(40, false)),
            ])
            .await
            .collect();

        assert_eq!(results.len(), 3);

        assert_eq!(manager.components().len(), 3);
        assert!(manager.components().get("key2").is_some());
        assert!(manager.components().get("key3").is_none());
        assert!(manager.components().get("key4").is_some());
    }
}