1use crate::{ComponentMap, Keyed, WithArgs};
2use futures::future::join_all;
3
4impl<Key, Args, Comp, FnInit> ComponentMap<Key, Args, Comp, FnInit> {
5 pub async fn try_init_async<Error>(
6 args: impl IntoIterator<Item = (Key, Args)>,
7 init: FnInit,
8 ) -> Result<Self, Error>
9 where
10 Key: Eq + std::hash::Hash,
11 FnInit: AsyncFn(&Key, &Args) -> Result<Comp, Error> + Clone,
12 {
13 let components_fut = args.into_iter().map(|(key, args)| {
14 let init = init.clone();
15 async move {
16 let result = (init)(&key, &args)
17 .await
18 .map(|component| WithArgs { component, args });
19
20 (key, result)
21 }
22 });
23
24 let map = join_all(components_fut)
25 .await
26 .into_iter()
27 .map(|(key, result)| result.map(|component| (key, component)))
28 .collect::<Result<_, _>>()?;
29
30 Ok(Self { map: map, init })
31 }
32
33 pub async fn try_reinit_all_async<Error>(
34 &mut self,
35 ) -> impl Iterator<Item = Keyed<&Key, Result<Comp, Error>>>
36 where
37 FnInit: AsyncFn(&Key, &Args) -> Result<Comp, Error> + Clone,
38 {
39 let next_components_fut = self
40 .map
41 .iter()
42 .map(|(key, component)| (self.init)(key, &component.args));
43
44 let next_components = join_all(next_components_fut).await;
45
46 self.map
47 .iter_mut()
48 .zip(next_components)
49 .map(|((key, prev), result)| {
50 let result = result.map(|next| std::mem::replace(&mut prev.component, next));
51
52 Keyed::new(key, result)
53 })
54 }
55
56 pub async fn try_reinit_async<Error>(
57 &mut self,
58 keys: impl IntoIterator<Item = Key>,
59 ) -> impl Iterator<Item = Keyed<Key, Option<Result<Comp, Error>>>>
60 where
61 Key: Eq + std::hash::Hash + Clone,
62 FnInit: AsyncFn(&Key, &Args) -> Result<Comp, Error> + Clone,
63 {
64 let next_components_fut = keys.into_iter().map(|key| {
65 let init = self.init.clone();
66
67 let args = self.map.get(&key).map(|component| &component.args);
68
69 async move {
70 let result = match args {
71 Some(args) => Some((init)(&key, args).await),
72 None => None,
73 };
74 Keyed::new(key, result)
75 }
76 });
77
78 let results = join_all(next_components_fut).await;
79
80 results.into_iter().map(|Keyed { key, value: result }| {
81 let prev = result
82 .map(|result| {
83 result.map(|next| {
84 self.map
85 .get_mut(&key)
86 .map(|component| std::mem::replace(&mut component.component, next))
87 })
88 })
89 .transpose()
90 .map(Option::flatten);
91
92 Keyed::new(key, prev.transpose())
93 })
94 }
95
96 pub async fn try_update_async<Error>(
97 &mut self,
98 updates: impl IntoIterator<Item = (Key, Args)>,
99 ) -> impl Iterator<Item = Keyed<Key, Option<Result<WithArgs<Args, Comp>, Error>>>>
100 where
101 Key: Clone + Eq + std::hash::Hash,
102 FnInit: AsyncFn(&Key, &Args) -> Result<Comp, Error> + Clone,
103 {
104 let updated_components_fut = updates.into_iter().map(|(key, args)| {
105 let init = self.init.clone();
106 async move {
107 let result = (init)(&key, &args)
108 .await
109 .map(|component| WithArgs { component, args });
110
111 (key, result)
112 }
113 });
114
115 join_all(updated_components_fut)
116 .await
117 .into_iter()
118 .map(|(key, result)| {
119 let result = result.map(|component| self.map.insert(key.clone(), component));
120
121 Keyed::new(key, result.transpose())
122 })
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::{ready, Ready};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    #[derive(Debug, Clone, PartialEq, Eq)]
    struct Counter(usize);

    #[derive(Debug, Clone, PartialEq, Eq)]
    struct FailArgs {
        value: usize,
        should_fail: bool,
    }

    #[derive(Debug, PartialEq, Eq)]
    struct TestError(String);

    /// Shorthand constructor for `FailArgs`.
    fn fail_args(value: usize, should_fail: bool) -> FailArgs {
        FailArgs { value, should_fail }
    }

    /// Shared init logic for the tests.
    ///
    /// Fails with `TestError("Failed")` when `args.should_fail` is set. Otherwise it
    /// yields `Counter(args.value * factor)`.
    fn scaled_counter(args: &FailArgs, factor: usize) -> Ready<Result<Counter, TestError>> {
        ready(if args.should_fail {
            Err(TestError("Failed".to_string()))
        } else {
            Ok(Counter(args.value * factor))
        })
    }

    #[tokio::test]
    async fn test_try_init_async_success() {
        let init = |_key: &&str, args: &FailArgs| scaled_counter(args, 1);

        let manager = ComponentMap::try_init_async(
            [("key1", fail_args(1, false)), ("key2", fail_args(2, false))],
            init,
        )
        .await
        .expect("all initializations succeed");

        assert_eq!(manager.map.len(), 2);
        assert_eq!(manager.map.get("key1").unwrap().component, Counter(1));
        assert_eq!(manager.map.get("key2").unwrap().component, Counter(2));
    }

    #[tokio::test]
    async fn test_try_init_async_failure() {
        let init = |_key: &&str, args: &FailArgs| scaled_counter(args, 1);

        let result = ComponentMap::try_init_async(
            [("key1", fail_args(1, false)), ("key2", fail_args(2, true))],
            init,
        )
        .await;

        assert!(result.is_err());
        assert_eq!(result.err(), Some(TestError("Failed".to_string())));
    }

    #[tokio::test]
    async fn test_try_init_async_empty() {
        let init = |_key: &&str, args: &FailArgs| scaled_counter(args, 1);

        let result: Result<ComponentMap<&str, FailArgs, Counter, _>, TestError> =
            ComponentMap::try_init_async([], init).await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap().map.len(), 0);
    }

    #[tokio::test]
    async fn test_try_reinit_all_async_success() {
        let init = |_key: &&str, args: &FailArgs| scaled_counter(args, 2);

        let mut manager = ComponentMap::try_init_async(
            [("key1", fail_args(1, false)), ("key2", fail_args(2, false))],
            init,
        )
        .await
        .unwrap();

        let results: Vec<_> = manager.try_reinit_all_async().await.collect();

        assert_eq!(results.len(), 2);
        // Every successful reinit hands back the component it replaced.
        let mut previous: Vec<usize> = results
            .iter()
            .map(|r| r.value.as_ref().expect("reinit succeeds").0)
            .collect();
        previous.sort_unstable();
        assert_eq!(previous, [2, 4]);

        assert_eq!(manager.map.get("key1").unwrap().component, Counter(2));
        assert_eq!(manager.map.get("key2").unwrap().component, Counter(4));
    }

    #[tokio::test]
    async fn test_try_reinit_all_async_with_failure() {
        let calls = Arc::new(AtomicUsize::new(0));
        let init = {
            let calls = Arc::clone(&calls);
            move |_key: &&str, args: &FailArgs| {
                // Calls 0 and 1 are the initial build; only re-initializations may fail.
                let call = calls.fetch_add(1, Ordering::SeqCst);
                ready(if call >= 2 && args.should_fail {
                    Err(TestError("Failed on reinit".to_string()))
                } else {
                    Ok(Counter(args.value * 2))
                })
            }
        };

        let mut manager = ComponentMap::try_init_async(
            [("key1", fail_args(1, false)), ("key2", fail_args(2, true))],
            init,
        )
        .await
        .unwrap();

        let results: Vec<_> = manager.try_reinit_all_async().await.collect();

        assert_eq!(calls.load(Ordering::SeqCst), 4);
        assert_eq!(results.len(), 2);
        let failed: Vec<&str> = results
            .iter()
            .filter(|r| r.value.is_err())
            .map(|r| *r.key)
            .collect();
        assert_eq!(failed, ["key2"]);
        let successes = results.iter().filter(|r| r.value.is_ok()).count();
        assert_eq!(successes, 1);
    }

    #[tokio::test]
    async fn test_try_reinit_all_async_empty() {
        let init = |_key: &&str, args: &FailArgs| scaled_counter(args, 1);

        let mut manager: ComponentMap<&str, FailArgs, Counter, _> =
            ComponentMap::try_init_async([], init).await.unwrap();

        let results: Vec<_> = manager.try_reinit_all_async().await.collect();
        assert_eq!(results.len(), 0);
    }

    #[tokio::test]
    async fn test_try_reinit_async_success() {
        let init = |_key: &&str, args: &FailArgs| scaled_counter(args, 3);

        let mut manager = ComponentMap::try_init_async(
            [("key1", fail_args(1, false)), ("key2", fail_args(2, false))],
            init,
        )
        .await
        .unwrap();

        let results: Vec<_> = manager.try_reinit_async(["key1"]).await.collect();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].key, "key1");
        assert!(results[0].value.as_ref().unwrap().is_ok());
        assert_eq!(manager.map.get("key1").unwrap().component, Counter(3));
        assert_eq!(manager.map.get("key2").unwrap().component, Counter(6));
    }

    #[tokio::test]
    async fn test_try_reinit_async_nonexistent_key() {
        let init = |_key: &&str, args: &FailArgs| scaled_counter(args, 1);

        let mut manager = ComponentMap::try_init_async([("key1", fail_args(1, false))], init)
            .await
            .unwrap();

        let results: Vec<_> = manager.try_reinit_async(["nonexistent"]).await.collect();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].key, "nonexistent");
        assert!(results[0].value.is_none());
    }

    #[tokio::test]
    async fn test_try_update_async_new_key_success() {
        let init = |_key: &&str, args: &FailArgs| scaled_counter(args, 1);

        let mut manager = ComponentMap::try_init_async([("key1", fail_args(1, false))], init)
            .await
            .unwrap();

        let results: Vec<_> = manager
            .try_update_async([("key2", fail_args(20, false))])
            .await
            .collect();

        assert_eq!(results.len(), 1);
        assert!(results[0].value.is_none());
        assert_eq!(manager.map.len(), 2);
        assert_eq!(manager.map.get("key2").unwrap().component, Counter(20));
    }

    #[tokio::test]
    async fn test_try_update_async_existing_key_returns_previous() {
        let init = |_key: &&str, args: &FailArgs| scaled_counter(args, 1);

        let mut manager = ComponentMap::try_init_async([("key1", fail_args(1, false))], init)
            .await
            .unwrap();

        let results: Vec<_> = manager
            .try_update_async([("key1", fail_args(10, false))])
            .await
            .collect();

        assert_eq!(results.len(), 1);
        let previous = results[0]
            .value
            .as_ref()
            .expect("key1 already existed")
            .as_ref()
            .expect("update succeeds");
        assert_eq!(previous.component, Counter(1));
        assert_eq!(previous.args, fail_args(1, false));

        assert_eq!(manager.map.len(), 1);
        assert_eq!(manager.map.get("key1").unwrap().component, Counter(10));
    }

    #[tokio::test]
    async fn test_try_update_async_failure() {
        let init = |_key: &&str, args: &FailArgs| scaled_counter(args, 1);

        let mut manager = ComponentMap::try_init_async([("key1", fail_args(1, false))], init)
            .await
            .unwrap();

        let results: Vec<_> = manager
            .try_update_async([("key2", fail_args(20, true))])
            .await
            .collect();

        assert_eq!(results.len(), 1);
        assert!(results[0].value.is_some());
        assert!(results[0].value.as_ref().unwrap().is_err());

        assert_eq!(manager.map.len(), 1);
        assert!(manager.map.get("key2").is_none());
    }

    #[tokio::test]
    async fn test_try_update_async_multiple_mixed() {
        let init = |_key: &&str, args: &FailArgs| scaled_counter(args, 1);

        let mut manager = ComponentMap::try_init_async([("key1", fail_args(1, false))], init)
            .await
            .unwrap();

        let results: Vec<_> = manager
            .try_update_async([
                ("key2", fail_args(20, false)),
                ("key3", fail_args(30, true)),
                ("key4", fail_args(40, false)),
            ])
            .await
            .collect();

        assert_eq!(results.len(), 3);
        // Results come back in input order.
        assert!(results[0].value.is_none());
        assert!(matches!(results[1].value, Some(Err(_))));
        assert!(results[2].value.is_none());

        assert_eq!(manager.map.len(), 3);
        assert!(manager.map.get("key2").is_some());
        assert!(manager.map.get("key3").is_none());
        assert!(manager.map.get("key4").is_some());
    }
}