1use std::error::Error;
4use std::fmt;
5use std::future::Future;
6use std::sync::Arc;
7
8use tokio::task::{JoinError, JoinSet};
9
/// A default `max_concurrency` value offered to callers.
///
/// The functions in this file do not apply it implicitly; callers pass it explicitly.
pub const DEFAULT_MAX_CONCURRENCY: usize = 16;
12
/// Largest `max_concurrency` accepted by [`try_map_bounded`] and [`map_bounded_collect`].
///
/// Larger values are rejected with [`TaskGroupError::ExcessiveConcurrency`] before any
/// item is consumed or any task is spawned.
pub const MAX_CONCURRENCY: usize = 10_000;
18
/// Errors reported by [`try_map_bounded`] and [`map_bounded_collect`].
///
/// `E` is the error type returned by the user-supplied operation.
#[derive(Debug)]
#[non_exhaustive]
pub enum TaskGroupError<E> {
    /// `max_concurrency` was zero.
    ZeroConcurrency,
    /// `max_concurrency` was larger than [`MAX_CONCURRENCY`].
    ExcessiveConcurrency {
        /// The rejected value that was passed in.
        max_concurrency: usize,
        /// The largest accepted value (always [`MAX_CONCURRENCY`] when produced by this module).
        upper_bound: usize,
    },
    /// An operation returned `Err` (only produced by [`try_map_bounded`]).
    TaskFailed {
        /// Zero-based position of the failing item in the input iterator.
        index: usize,
        /// The error returned by the operation.
        error: E,
    },
    /// A spawned task could not be joined, for example because it panicked or was cancelled.
    TaskJoinFailed {
        /// Zero-based position of the failing item, when known.
        ///
        /// `try_map_bounded` and `map_bounded_collect` currently always report `None`, because
        /// they do not map a `JoinError` back to the item that produced it.
        index: Option<usize>,
        /// The underlying Tokio join error; `is_panic()` / `is_cancelled()` tell the causes apart.
        source: JoinError,
    },
}
57
58impl<E> fmt::Display for TaskGroupError<E>
59where
60 E: fmt::Display,
61{
62 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
63 match self {
64 Self::ZeroConcurrency => {
65 formatter.write_str("max_concurrency must be greater than zero")
66 }
67 Self::ExcessiveConcurrency {
68 max_concurrency,
69 upper_bound,
70 } => write!(
71 formatter,
72 "max_concurrency must be less than or equal to {upper_bound}, got {max_concurrency}"
73 ),
74 Self::TaskFailed { index, error } => {
75 write!(formatter, "task {index} failed: {error}")
76 }
77 Self::TaskJoinFailed {
78 index: Some(index),
79 source,
80 } => write!(formatter, "task {index} failed to join: {source}"),
81 Self::TaskJoinFailed {
82 index: None,
83 source,
84 } => write!(formatter, "task failed to join: {source}"),
85 }
86 }
87}
88
89impl<E> Error for TaskGroupError<E>
90where
91 E: Error + 'static,
92{
93 fn source(&self) -> Option<&(dyn Error + 'static)> {
94 match self {
95 Self::TaskFailed { error, .. } => Some(error),
96 Self::TaskJoinFailed { source, .. } => Some(source),
97 Self::ZeroConcurrency | Self::ExcessiveConcurrency { .. } => None,
98 }
99 }
100}
101
/// A successful operation result recorded by [`map_bounded_collect`].
#[derive(Debug, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub struct TaskSuccess<T> {
    /// Zero-based position of the item in the input iterator.
    pub index: usize,
    /// The value returned by the operation.
    pub value: T,
}
111
/// An operation error recorded by [`map_bounded_collect`].
#[derive(Debug, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub struct TaskFailure<E> {
    /// Zero-based position of the item in the input iterator.
    pub index: usize,
    /// The error returned by the operation.
    pub error: E,
}
121
/// Every operation outcome produced by [`map_bounded_collect`], split into successes and failures.
///
/// As returned by `map_bounded_collect`, each list is sorted by ascending `index`. Together
/// the two lists contain exactly one entry per input item.
#[derive(Debug, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub struct TaskGroupReport<T, E> {
    /// Operations that returned `Ok`.
    pub successes: Vec<TaskSuccess<T>>,
    /// Operations that returned `Err`.
    pub failures: Vec<TaskFailure<E>>,
}
135
136impl<T, E> TaskGroupReport<T, E> {
137 #[must_use]
139 pub fn is_success(&self) -> bool {
140 self.failures.is_empty()
141 }
142
143 #[must_use]
145 pub fn len(&self) -> usize {
146 self.successes.len() + self.failures.len()
147 }
148
149 #[must_use]
151 pub fn is_empty(&self) -> bool {
152 self.successes.is_empty() && self.failures.is_empty()
153 }
154}
155
/// What a single spawned task resolves to: the operation's result, tagged with the item's
/// zero-based input index so the collector can put results back in input order.
enum TaskOutcome<T, E> {
    /// The operation returned `Ok(value)`.
    Success { index: usize, value: T },
    /// The operation returned `Err(error)`.
    Failure { index: usize, error: E },
}
160
161pub async fn try_map_bounded<I, F, Fut, T, E>(
195 items: I,
196 max_concurrency: usize,
197 operation: F,
198) -> Result<Vec<T>, TaskGroupError<E>>
199where
200 I: IntoIterator,
201 I::Item: Send + 'static,
202 F: Fn(I::Item) -> Fut + Send + Sync + 'static,
203 Fut: Future<Output = Result<T, E>> + Send + 'static,
204 T: Send + 'static,
205 E: Send + 'static,
206{
207 validate_max_concurrency(max_concurrency)?;
208
209 let mut tasks = JoinSet::new();
210 let mut indexed_items = items.into_iter().enumerate();
211 let operation = Arc::new(operation);
212 let mut results = Vec::new();
213
214 fill_tasks(
215 &mut tasks,
216 &mut indexed_items,
217 max_concurrency,
218 &operation,
219 &mut results,
220 );
221
222 while let Some(result) = tasks.join_next().await {
223 match result {
224 Ok(TaskOutcome::Success { index, value }) => {
225 results[index] = Some(value);
226 fill_tasks(
227 &mut tasks,
228 &mut indexed_items,
229 max_concurrency,
230 &operation,
231 &mut results,
232 );
233 }
234 Ok(TaskOutcome::Failure { index, error }) => {
235 shutdown_tasks(&mut tasks).await;
236 return Err(TaskGroupError::TaskFailed { index, error });
237 }
238 Err(source) => {
239 shutdown_tasks(&mut tasks).await;
240 return Err(TaskGroupError::TaskJoinFailed {
241 index: None,
242 source,
243 });
244 }
245 }
246 }
247
248 Ok(results.into_iter().flatten().collect())
249}
250
251pub async fn map_bounded_collect<I, F, Fut, T, E>(
287 items: I,
288 max_concurrency: usize,
289 operation: F,
290) -> Result<TaskGroupReport<T, E>, TaskGroupError<E>>
291where
292 I: IntoIterator,
293 I::Item: Send + 'static,
294 F: Fn(I::Item) -> Fut + Send + Sync + 'static,
295 Fut: Future<Output = Result<T, E>> + Send + 'static,
296 T: Send + 'static,
297 E: Send + 'static,
298{
299 validate_max_concurrency(max_concurrency)?;
300
301 let mut tasks = JoinSet::new();
302 let mut indexed_items = items.into_iter().enumerate();
303 let operation = Arc::new(operation);
304 let mut successes = Vec::new();
305 let mut failures = Vec::new();
306 let mut slots = Vec::new();
307
308 fill_tasks(
309 &mut tasks,
310 &mut indexed_items,
311 max_concurrency,
312 &operation,
313 &mut slots,
314 );
315
316 while let Some(result) = tasks.join_next().await {
317 match result {
318 Ok(TaskOutcome::Success { index, value }) => {
319 successes.push(TaskSuccess { index, value });
320 }
321 Ok(TaskOutcome::Failure { index, error }) => {
322 failures.push(TaskFailure { index, error });
323 }
324 Err(source) => {
325 shutdown_tasks(&mut tasks).await;
326 return Err(TaskGroupError::TaskJoinFailed {
327 index: None,
328 source,
329 });
330 }
331 }
332
333 fill_tasks(
334 &mut tasks,
335 &mut indexed_items,
336 max_concurrency,
337 &operation,
338 &mut slots,
339 );
340 }
341
342 successes.sort_by_key(|success| success.index);
343 failures.sort_by_key(|failure| failure.index);
344
345 Ok(TaskGroupReport {
346 successes,
347 failures,
348 })
349}
350
351fn validate_max_concurrency<E>(max_concurrency: usize) -> Result<(), TaskGroupError<E>> {
352 if max_concurrency == 0 {
353 return Err(TaskGroupError::ZeroConcurrency);
354 }
355 if max_concurrency > MAX_CONCURRENCY {
356 return Err(TaskGroupError::ExcessiveConcurrency {
357 max_concurrency,
358 upper_bound: MAX_CONCURRENCY,
359 });
360 }
361 Ok(())
362}
363
364fn fill_tasks<I, F, Fut, T, E>(
365 tasks: &mut JoinSet<TaskOutcome<T, E>>,
366 indexed_items: &mut std::iter::Enumerate<I>,
367 max_concurrency: usize,
368 operation: &Arc<F>,
369 slots: &mut Vec<Option<T>>,
370) where
371 I: Iterator,
372 I::Item: Send + 'static,
373 F: Fn(I::Item) -> Fut + Send + Sync + 'static,
374 Fut: Future<Output = Result<T, E>> + Send + 'static,
375 T: Send + 'static,
376 E: Send + 'static,
377{
378 while tasks.len() < max_concurrency {
379 let Some((index, item)) = indexed_items.next() else {
380 break;
381 };
382
383 while slots.len() <= index {
384 slots.push(None);
385 }
386
387 let operation = Arc::clone(operation);
388 tasks.spawn(async move {
389 match operation(item).await {
390 Ok(value) => TaskOutcome::Success { index, value },
391 Err(error) => TaskOutcome::Failure { index, error },
392 }
393 });
394 }
395}
396
397async fn shutdown_tasks<T, E>(tasks: &mut JoinSet<TaskOutcome<T, E>>)
398where
399 T: Send + 'static,
400 E: Send + 'static,
401{
402 tasks.abort_all();
403 while tasks.join_next().await.is_some() {}
404}
405
// Unit tests for error formatting and for the bounded-concurrency runners.
//
// Tests that need a deterministic interleaving use `Notify`, whose `notify_one` stores a
// permit when nobody is waiting yet. `DropCounter` guards show that aborted task futures
// were dropped.
#[cfg(test)]
mod tests {
    use std::future::pending;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::{error, fmt};

    use tokio::sync::Notify;
    use tokio::task::yield_now;
    use tokio::time::{Duration, sleep, timeout};

    use super::*;

    /// Increments `counter` when dropped.
    struct DropCounter {
        counter: Arc<AtomicUsize>,
    }

    impl Drop for DropCounter {
        fn drop(&mut self) {
            self.counter.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Minimal `Error` implementation for exercising `TaskGroupError`'s trait impls.
    #[derive(Debug, Eq, PartialEq)]
    struct StaticError(&'static str);

    impl fmt::Display for StaticError {
        fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
            formatter.write_str(self.0)
        }
    }

    impl error::Error for StaticError {}

    #[test]
    fn task_group_error_formats_validation_failures() {
        let zero = TaskGroupError::<StaticError>::ZeroConcurrency;
        let excessive = TaskGroupError::<StaticError>::ExcessiveConcurrency {
            max_concurrency: MAX_CONCURRENCY + 1,
            upper_bound: MAX_CONCURRENCY,
        };

        assert_eq!(
            zero.to_string(),
            "max_concurrency must be greater than zero"
        );
        assert_eq!(
            excessive.to_string(),
            format!(
                "max_concurrency must be less than or equal to {}, got {}",
                MAX_CONCURRENCY,
                MAX_CONCURRENCY + 1
            )
        );
        assert!(zero.source().is_none());
        assert!(excessive.source().is_none());
    }

    #[test]
    fn task_group_error_preserves_operation_error_source() {
        let error = TaskGroupError::TaskFailed {
            index: 3,
            error: StaticError("operation failed"),
        };

        assert_eq!(error.to_string(), "task 3 failed: operation failed");
        assert_eq!(
            error.source().map(ToString::to_string),
            Some("operation failed".to_owned())
        );
    }

    #[tokio::test]
    async fn try_map_bounded_preserves_input_order() {
        // Sleep lengths differ per item, so completion order need not match input order.
        let values = try_map_bounded([3, 1, 2], 2, |value| async move {
            sleep(Duration::from_millis((4 - value) * 10)).await;
            Ok::<_, &'static str>(value * 10)
        })
        .await
        .unwrap();

        assert_eq!(values, vec![30, 10, 20]);
    }

    #[tokio::test]
    async fn try_map_bounded_respects_concurrency_bound() {
        // `current` counts operations in progress; `peak` records its maximum.
        let current = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));

        let values = try_map_bounded(0..10, 3, {
            let current = Arc::clone(&current);
            let peak = Arc::clone(&peak);
            move |value| {
                let current = Arc::clone(&current);
                let peak = Arc::clone(&peak);
                async move {
                    let active = current.fetch_add(1, Ordering::SeqCst) + 1;
                    peak.fetch_max(active, Ordering::SeqCst);
                    sleep(Duration::from_millis(5)).await;
                    current.fetch_sub(1, Ordering::SeqCst);
                    Ok::<_, &'static str>(value)
                }
            }
        })
        .await
        .unwrap();

        assert_eq!(values, (0..10).collect::<Vec<_>>());
        assert!(peak.load(Ordering::SeqCst) <= 3);
    }

    #[tokio::test]
    async fn try_map_bounded_aborts_and_drains_siblings_on_first_error() {
        let started = Arc::new(Notify::new());
        let dropped = Arc::new(AtomicUsize::new(0));

        let actual = try_map_bounded(0..2, 2, {
            let started = Arc::clone(&started);
            let dropped = Arc::clone(&dropped);
            move |value| {
                let started = Arc::clone(&started);
                let dropped = Arc::clone(&dropped);
                async move {
                    // Task 0 fails only after task 1 has installed its drop guard.
                    if value == 0 {
                        started.notified().await;
                        return Err("boom");
                    }

                    let _guard = DropCounter { counter: dropped };
                    started.notify_one();
                    // Never completes on its own; only an abort ends this task.
                    pending::<()>().await;
                    Ok::<_, &'static str>(value)
                }
            }
        })
        .await;

        assert!(matches!(
            actual,
            Err(TaskGroupError::TaskFailed {
                index: 0,
                error: "boom"
            })
        ));
        // The sibling was aborted and its guard dropped before `try_map_bounded` returned.
        assert_eq!(dropped.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn map_bounded_collect_records_all_operation_results() {
        let report = map_bounded_collect(0..5, 2, |value| async move {
            if value % 2 == 0 {
                Ok(value * 10)
            } else {
                Err(value)
            }
        })
        .await
        .unwrap();

        assert!(!report.is_success());
        assert_eq!(report.len(), 5);
        assert_eq!(
            report.successes,
            vec![
                TaskSuccess { index: 0, value: 0 },
                TaskSuccess {
                    index: 2,
                    value: 20
                },
                TaskSuccess {
                    index: 4,
                    value: 40
                },
            ]
        );
        assert_eq!(
            report.failures,
            vec![
                TaskFailure { index: 1, error: 1 },
                TaskFailure { index: 3, error: 3 },
            ]
        );
    }

    #[tokio::test]
    async fn map_bounded_collect_reports_empty_input() {
        let report = map_bounded_collect(Vec::<i32>::new(), 4, |value| async move {
            Ok::<_, StaticError>(value)
        })
        .await
        .unwrap();

        assert!(report.is_success());
        assert_eq!(report.len(), 0);
        assert!(report.is_empty());
    }

    #[tokio::test]
    async fn map_bounded_collect_rejects_invalid_concurrency() {
        let zero =
            map_bounded_collect([1], 0, |value| async move { Ok::<_, StaticError>(value) }).await;
        let excessive = map_bounded_collect([1], MAX_CONCURRENCY + 1, |value| async move {
            Ok::<_, StaticError>(value)
        })
        .await;

        assert!(matches!(zero, Err(TaskGroupError::ZeroConcurrency)));
        assert!(matches!(
            excessive,
            Err(TaskGroupError::ExcessiveConcurrency {
                max_concurrency,
                upper_bound: MAX_CONCURRENCY
            }) if max_concurrency == MAX_CONCURRENCY + 1
        ));
    }

    #[tokio::test]
    async fn rejects_zero_concurrency() {
        let actual =
            try_map_bounded([1], 0, |value| async move { Ok::<_, &'static str>(value) }).await;

        assert!(matches!(actual, Err(TaskGroupError::ZeroConcurrency)));
    }

    #[tokio::test]
    async fn rejects_excessive_concurrency() {
        let actual = try_map_bounded([1], MAX_CONCURRENCY + 1, |value| async move {
            Ok::<_, &'static str>(value)
        })
        .await;

        assert!(matches!(
            actual,
            Err(TaskGroupError::ExcessiveConcurrency {
                max_concurrency,
                upper_bound: MAX_CONCURRENCY
            }) if max_concurrency == MAX_CONCURRENCY + 1
        ));
    }

    #[tokio::test]
    async fn reports_join_failure_and_drains_remaining_tasks() {
        let sibling_started = Arc::new(Notify::new());
        let dropped = Arc::new(AtomicUsize::new(0));

        let actual = try_map_bounded(0..2, 2, {
            let sibling_started = Arc::clone(&sibling_started);
            let dropped = Arc::clone(&dropped);
            move |value| {
                let sibling_started = Arc::clone(&sibling_started);
                let dropped = Arc::clone(&dropped);
                async move {
                    // Task 0 panics only after task 1 has installed its drop guard.
                    if value == 0 {
                        sibling_started.notified().await;
                        panic!("task panic");
                    }

                    let _guard = DropCounter { counter: dropped };
                    sibling_started.notify_one();
                    pending::<()>().await;
                    Ok::<_, &'static str>(value)
                }
            }
        })
        .await;

        assert!(matches!(
            actual,
            Err(TaskGroupError::TaskJoinFailed {
                index: None,
                source,
            }) if source.is_panic()
        ));
        assert_eq!(dropped.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn join_failure_formats_and_exposes_source() {
        // Obtain a real `JoinError` from a panicking task, then re-wrap it with an index.
        let actual = try_map_bounded([1], 1, |_| async move {
            panic!("task panic");
            #[allow(unreachable_code)]
            Ok::<_, StaticError>(())
        })
        .await;

        let Err(TaskGroupError::TaskJoinFailed {
            index: None,
            source,
        }) = actual
        else {
            panic!("expected join failure");
        };
        let error = TaskGroupError::<StaticError>::TaskJoinFailed {
            index: Some(7),
            source,
        };

        assert!(error.to_string().starts_with("task 7 failed to join:"));
        assert!(error.source().is_some());
    }

    #[tokio::test]
    async fn map_bounded_collect_join_failure_drains_remaining_tasks() {
        let sibling_started = Arc::new(Notify::new());
        let dropped = Arc::new(AtomicUsize::new(0));

        let actual = map_bounded_collect(0..2, 2, {
            let sibling_started = Arc::clone(&sibling_started);
            let dropped = Arc::clone(&dropped);
            move |value| {
                let sibling_started = Arc::clone(&sibling_started);
                let dropped = Arc::clone(&dropped);
                async move {
                    // Task 0 panics only after task 1 has installed its drop guard.
                    if value == 0 {
                        sibling_started.notified().await;
                        panic!("task panic");
                    }

                    let _guard = DropCounter { counter: dropped };
                    sibling_started.notify_one();
                    pending::<()>().await;
                    Ok::<_, StaticError>(value)
                }
            }
        })
        .await;

        assert!(matches!(
            actual,
            Err(TaskGroupError::TaskJoinFailed {
                index: None,
                source,
            }) if source.is_panic()
        ));
        assert_eq!(dropped.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn dropping_try_map_bounded_future_aborts_started_tasks() {
        let started = Arc::new(AtomicUsize::new(0));
        let dropped = Arc::new(AtomicUsize::new(0));

        let task = tokio::spawn(try_map_bounded(0..4, 4, {
            let started = Arc::clone(&started);
            let dropped = Arc::clone(&dropped);
            move |value| {
                let started = Arc::clone(&started);
                let dropped = Arc::clone(&dropped);
                async move {
                    let _guard = DropCounter { counter: dropped };
                    started.fetch_add(1, Ordering::SeqCst);
                    pending::<()>().await;
                    Ok::<_, StaticError>(value)
                }
            }
        }));

        // Let the runner spawn all four inner tasks and let each of them start.
        while started.load(Ordering::SeqCst) < 4 {
            yield_now().await;
        }

        // Aborting the outer task drops its `JoinSet`, which aborts the inner tasks; their
        // futures are dropped asynchronously, so poll until every guard has run.
        task.abort();
        assert!(task.await.unwrap_err().is_cancelled());
        timeout(Duration::from_secs(1), async {
            while dropped.load(Ordering::SeqCst) < 4 {
                yield_now().await;
            }
        })
        .await
        .unwrap();

        assert_eq!(dropped.load(Ordering::SeqCst), 4);
    }

    #[tokio::test]
    async fn dropping_map_bounded_collect_future_aborts_started_tasks() {
        let started = Arc::new(AtomicUsize::new(0));
        let dropped = Arc::new(AtomicUsize::new(0));

        let task = tokio::spawn(map_bounded_collect(0..4, 4, {
            let started = Arc::clone(&started);
            let dropped = Arc::clone(&dropped);
            move |value| {
                let started = Arc::clone(&started);
                let dropped = Arc::clone(&dropped);
                async move {
                    let _guard = DropCounter { counter: dropped };
                    started.fetch_add(1, Ordering::SeqCst);
                    pending::<()>().await;
                    Ok::<_, StaticError>(value)
                }
            }
        }));

        // Let the runner spawn all four inner tasks and let each of them start.
        while started.load(Ordering::SeqCst) < 4 {
            yield_now().await;
        }

        // Aborting the outer task drops its `JoinSet`, which aborts the inner tasks; their
        // futures are dropped asynchronously, so poll until every guard has run.
        task.abort();
        assert!(task.await.unwrap_err().is_cancelled());
        timeout(Duration::from_secs(1), async {
            while dropped.load(Ordering::SeqCst) < 4 {
                yield_now().await;
            }
        })
        .await
        .unwrap();

        assert_eq!(dropped.load(Ordering::SeqCst), 4);
    }
}