1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
// Copyright © 2021 Alexandra Frydl
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

//! Wait for multiple fallible tasks concurrently.

use crate::prelude::*;
use crate::string::SharedString;
use crate::task;

/// The index identifying a task within a [`task::Join`] (and therefore within
/// a [`TryJoin`]), as returned by [`TryJoin::add`] and [`TryJoin::add_as`].
pub type Index = usize;

/// Concurrently waits for the results of multiple tasks that may return an
/// error.
///
/// Every task added to this join produces a `Result<T, E>`. When results are
/// retrieved through [`TryJoin::next`], a task that panicked is reported as an
/// `E` converted from the [`task::Panic`]. The wrapped [`task::Join`] is also
/// reachable directly through `Deref`/`DerefMut`.
#[derive(Deref, DerefMut)]
pub struct TryJoin<T, E> {
  // The underlying join; each task's own output is the `Result<T, E>` it
  // returned, which `next` flattens together with any panic.
  #[deref]
  #[deref_mut]
  join: task::Join<Result<T, E>>,
}

impl<T, E> TryJoin<T, E>
where
  T: Send + 'static,
  E: From<task::Panic> + Send + 'static,
{
  /// Creates an empty join.
  pub fn new() -> Self {
    Self { join: task::Join::new() }
  }

  /// Adds a task to the join, returning its index.
  pub fn add(&mut self, task: impl task::Start<Result<T, E>>) -> Index {
    self.join.add(task)
  }

  /// Adds a named task to the join, returning its index.
  pub fn add_as(
    &mut self,
    name: impl Into<SharedString>,
    task: impl task::Start<Result<T, E>>,
  ) -> Index {
    self.join.add_as(name, task)
  }

  /// Waits for the next stopped task.
  ///
  /// If all tasks have stopped, this function returns `None`.
  pub async fn next(&mut self) -> Option<StoppedTask<T, E>> {
    let task = self.join.next().await?;

    Some(StoppedTask {
      index: task.index,
      name: task.name,
      result: task.result.map_err(E::from).and_then(|res| res),
    })
  }

  /// Waits for the next stopped task and returns its information as a
  /// [`Result`].
  ///
  /// If all tasks have stopped, this function returns `None`.
  pub async fn try_next(&mut self) -> Option<Result<FinishedTask<T>, FailedTask<E>>> {
    let task = self.next().await?;

    Some(match task.result {
      Ok(output) => Ok(FinishedTask { index: task.index, name: task.name, output }),
      Err(error) => Err(FailedTask { index: task.index, name: task.name, error }),
    })
  }

  /// Waits for all tasks to stop, dropping their results.
  pub async fn drain(&mut self) {
    self.join.drain().await
  }

  /// Waits for all tasks to stop, dropping their results, until a task fails.
  pub async fn try_drain(&mut self) -> Result<(), FailedTask<E>> {
    while self.try_next().await.transpose()?.is_some() {}

    Ok(())
  }
}

impl<T, E> Default for TryJoin<T, E>
where
  T: Send + 'static,
  E: From<Panic> + Send + 'static,
{
  fn default() -> Self {
    Self::new()
  }
}

/// Information about a stopped task, whether it succeeded or failed.
///
/// Produced by [`TryJoin::next`].
#[derive(Debug)]
pub struct StoppedTask<T, E> {
  /// The index the task was assigned when it was added to the join.
  pub index: Index,
  /// The name of the task; an empty string is treated as "no name".
  pub name: SharedString,
  /// The task's output, or its error (a panic is converted into `E`).
  pub result: Result<T, E>,
}

/// Information about a task that finished successfully.
///
/// Produced by [`TryJoin::try_next`] for tasks whose result was `Ok`.
#[derive(Debug)]
pub struct FinishedTask<T> {
  /// The index the task was assigned when it was added to the join.
  pub index: Index,
  /// The name of the task; an empty string is treated as "no name".
  pub name: SharedString,
  /// The value the task returned.
  pub output: T,
}

/// Information about a task that failed.
///
/// Produced by [`TryJoin::try_next`] and [`TryJoin::try_drain`] for tasks
/// whose result was `Err`, including tasks that panicked.
#[derive(Debug)]
pub struct FailedTask<E> {
  /// The index the task was assigned when it was added to the join.
  pub index: Index,
  /// The name of the task; an empty string is treated as "no name" (see the
  /// `Display` impl, which falls back to the index in that case).
  pub name: SharedString,
  /// The error the task returned, or the error converted from its panic.
  pub error: E,
}

/// Formats as ``Task `name` failed. <error>`` for a named task, or
/// `Task #<index> failed. <error>` when the name is empty.
impl<E> Display for FailedTask<E>
where
  E: Display,
{
  fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
    // An empty name is treated as "unnamed", so identify the task by index.
    match self.name.as_str() {
      "" => write!(f, "Task #{}", self.index)?,
      name => write!(f, "Task `{}`", name)?,
    }

    // The separating space lives here so both arms above render the same way;
    // previously only the unnamed arm emitted it.
    write!(f, " failed. {}", self.error)
  }
}

impl<E> Error for FailedTask<E> where E: Debug + Display {}