1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
// Copyright © 2021 Alexandra Frydl
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

//! Wait for multiple tasks concurrently.

use crate::channel;
use crate::prelude::*;
use crate::string::SharedString;
use crate::task::{self, Task};
use fnv::FnvHashMap;

/// The index of a [`Join`] task.
///
/// Each join assigns indices sequentially, starting at zero, in the order
/// tasks are added.
pub type Index = usize;

/// Concurrently waits for the results of multiple tasks.
///
/// Every added task is paired with a monitor task that forwards the task's
/// result over an internal channel; results are handed out in the order they
/// are received from that channel.
pub struct Join<T> {
  /// Tasks whose results have not yet been returned, keyed by their index.
  children: FnvHashMap<Index, Child>,
  /// The index that will be assigned to the next added task.
  next_index: Index,
  /// Receives the results sent by the monitor tasks.
  rx: channel::Receiver<Stopped<T>>,
  /// Cloned into each monitor task. NOTE(review): because the join itself
  /// keeps this sender, `rx` presumably never reports disconnection while the
  /// join is alive — confirm against `crate::channel`.
  tx: channel::Sender<Stopped<T>>,
}

/// A task in a [`Join`], tracked until its result has been returned.
struct Child {
  /// The name given to `Join::add_as`; empty for tasks added with `Join::add`.
  name: SharedString,
  /// The monitor task that forwards this child's result. It is only held,
  /// never read. NOTE(review): presumably dropping it cancels the monitor (and
  /// with it the child task) — confirm against `Task`'s drop behavior.
  _monitor: Task<()>,
}

/// A message sent from a task monitor.
struct Stopped<T> {
  index: usize,
  result: Result<T, Panic>,
}

impl<T> Join<T>
where
  T: Send + 'static,
{
  /// Creates a join that is not waiting on any tasks yet.
  pub fn new() -> Self {
    let (sender, receiver) = channel::unbounded();

    Self { tx: sender, rx: receiver, next_index: 0, children: default() }
  }

  /// Adds an unnamed task to the join and returns the index assigned to it.
  ///
  /// Equivalent to calling [`Join::add_as`] with an empty name.
  pub fn add(&mut self, task: impl task::Start<T>) -> Index {
    let unnamed = "";

    self.add_as(unnamed, task)
  }

  /// Adds a task with the given name to the join and returns the index
  /// assigned to it.
  ///
  /// Indices start at zero and increase by one for every task added.
  pub fn add_as(&mut self, name: impl Into<SharedString>, task: impl task::Start<T>) -> Index {
    let index = self.next_index;
    self.next_index += 1;

    let child = task.start();

    // The monitor waits for the child to stop, then forwards its result to
    // `rx` tagged with the child's index. A failed send is deliberately
    // ignored.
    let sender = self.tx.clone();
    let monitor = task::start(async move {
      let result = child.join().await;
      let _ = sender.send(Stopped { index, result }).await;
    });

    let name = name.into();
    self.children.insert(index, Child { name, _monitor: monitor });

    index
  }

  /// Waits for the next task to stop and returns its index, name, and result
  /// (`Err` if the task panicked).
  ///
  /// Returns `None` when no tasks remain, or if receiving from the internal
  /// channel fails.
  pub async fn next(&mut self) -> Option<StoppedTask<Result<T, task::Panic>>> {
    if self.children.is_empty() {
      return None;
    }

    let Stopped { index, result } = match self.rx.recv().await {
      Ok(stopped) => stopped,
      Err(_) => return None,
    };

    let child = self.children.remove(&index).expect("Received result from unknown child.");

    Some(StoppedTask { index, name: child.name, result })
  }

  /// Like [`Join::next`], but splits the outcome into a [`StoppedTask`] for
  /// tasks that finished and a [`PanickedTask`] for tasks that panicked.
  ///
  /// Returns `None` when no tasks remain.
  pub async fn try_next(&mut self) -> Option<Result<StoppedTask<T>, PanickedTask>> {
    self.next().await.map(|StoppedTask { index, name, result }| match result {
      Ok(value) => Ok(StoppedTask { index, name, result: value }),
      Err(panic) => Err(PanickedTask { index, name, panic }),
    })
  }

  /// Waits until every task has stopped, discarding all results.
  pub async fn drain(&mut self) {
    while let Some(_stopped) = self.next().await {}
  }

  /// Waits until every task has stopped, discarding all results, but returns
  /// early with the first [`PanickedTask`] encountered.
  pub async fn try_drain(&mut self) -> Result<(), PanickedTask> {
    while let Some(outcome) = self.try_next().await {
      outcome?;
    }

    Ok(())
  }
}

impl<T> Default for Join<T>
where
  T: Send + 'static,
{
  fn default() -> Self {
    Self::new()
  }
}

/// Information about a stopped task.
///
/// Returned by [`Join::next`] (where `T` is the task's `Result`) and, for
/// tasks that did not panic, by [`Join::try_next`].
#[derive(Debug)]
pub struct StoppedTask<T> {
  /// The index of the task, as returned when it was added.
  pub index: Index,
  /// The name of the task, or an empty string if it was added without one.
  pub name: SharedString,
  /// The result of the task.
  pub result: T,
}

/// Information about a task that panicked.
///
/// Returned as the error of [`Join::try_next`] and [`Join::try_drain`].
#[derive(Debug, Error)]
pub struct PanickedTask {
  /// The index of the task, as returned when it was added.
  pub index: Index,
  /// The name of the task, or an empty string if it was added without one.
  pub name: SharedString,
  /// The panic from the task.
  pub panic: task::Panic,
}

impl Display for PanickedTask {
  /// Formats the panic as e.g. ``Task #3 panicked.`` for an unnamed task or
  /// ``Task `worker` panicked with `boom`.`` for a named one, appending the
  /// panic value whenever `value_str` provides it.
  fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
    // Unnamed tasks (empty name) are identified by their index instead.
    match self.name.as_str() {
      "" => write!(f, "Task #{}", self.index)?,
      name => write!(f, "Task `{}`", name)?,
    }

    // The separating space lives here so both branches above get it; the
    // named branch previously rendered as "Task `name`panicked".
    write!(f, " panicked")?;

    if let Some(value) = self.panic.value_str() {
      write!(f, " with `{}`", value)?;
    }

    write!(f, ".")
  }
}