1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
use std::future::Future;
use tokio::task::JoinSet;

pub mod prelude {
  pub use super::{IntoJoinSet as _, JoinSetFromIter as _};
}

/// Extension trait that turns an iterator into a [`JoinSet`].
///
/// A blanket implementation covers every [`Iterator`], so having this trait in
/// scope is enough to call its methods on any iterator.
pub trait JoinSetFromIter: Iterator {
  /// Collects an iterator of futures into a [`JoinSet`] with output type `T`.
  ///
  /// This delegates to `JoinSet`'s `FromIterator` implementation, which tokio
  /// documents as equivalent to calling `JoinSet::spawn` on every item, so the
  /// requirements of `spawn` (notably an active Tokio runtime) apply here too.
  fn join_set<T>(self) -> JoinSet<T>
  where
    Self: Sized,
    Self::Item: Future<Output = T> + Send + 'static,
    T: Send + 'static,
  {
    self.collect()
  }

  /// Maps every item to a future with `f` and collects the resulting futures
  /// into a [`JoinSet`], exactly like `self.map(f).join_set()`.
  ///
  /// Only the produced futures `F` and their output `T` have to be
  /// `Send + 'static`. The iterator's own items are merely handed to `f` and
  /// never stored in the set, so they may be borrowed or non-`Send` values.
  fn join_set_by<T, F, M>(self, f: M) -> JoinSet<T>
  where
    Self: Sized,
    F: Future<Output = T> + Send + 'static,
    T: Send + 'static,
    M: FnMut(Self::Item) -> F,
  {
    // `join_set` constrains only the mapped item type (`F`), which is why no
    // bound on `Self::Item` is needed here.
    self.map(f).join_set()
  }
}

// Blanket implementation: every iterator type, including unsized ones such as
// `dyn Iterator`, implements `JoinSetFromIter`. The provided methods carry
// their own `Self: Sized` bounds.
impl<I: Iterator + ?Sized> JoinSetFromIter for I {}

pub trait IntoJoinSet<F, T>: IntoIterator
where
  Self: Sized,
  <Self as IntoIterator>::Item: Future<Output = T> + Send + 'static,
  F: Future<Output = T> + Send + 'static,
  T: Send + 'static,
{
  fn into_join_set(self) -> JoinSet<T> {
    self.into_iter().join_set()
  }

  fn into_join_set_by<M>(self, f: M) -> JoinSet<T>
  where
    M: FnMut(Self::Item) -> F,
  {
    self.into_iter().join_set_by(f)
  }
}

// A `Vec` of spawnable futures gets `into_join_set` / `into_join_set_by`
// through the trait's provided methods; nothing needs overriding.
impl<F: Future<Output = T> + Send + 'static, T: Send + 'static> IntoJoinSet<F, T> for Vec<F> {}

#[cfg(test)]
mod tests {
  use super::*;
  use std::future;

  /// `join_set_by` spawns one task per item and every task yields the value it
  /// was created from.
  #[tokio::test]
  async fn join_set_by() {
    // A range is already an iterator, so no `.into_iter()` is needed.
    let mut set = (0..10).join_set_by(future::ready);

    assert_eq!(set.len(), 10);

    let mut outputs = Vec::with_capacity(set.len());
    while let Some(result) = set.join_next().await {
      outputs.push(result.expect("spawned task should finish without panicking"));
    }

    // Tasks may complete in any order, so sort before comparing.
    outputs.sort_unstable();
    let expected: Vec<i32> = (0..10).collect();
    assert_eq!(outputs, expected);
  }
}