#![deny(missing_debug_implementations, nonstandard_style)]
#![warn(missing_docs, unreachable_pub)]
use async_std::channel::{self, Receiver, Sender};
use async_std::future::Future;
use async_std::stream::Stream;
use async_std::task::{self, JoinHandle};
use async_std::task::{Context, Poll};
use core::pin::Pin;
/// A handle for spawning tasks into a group created by [`group`].
///
/// Every task spawned through this handle is tracked by the group's
/// [`GroupJoinHandle`], which resolves once all of them have completed, or as
/// soon as one of them returns an error.
pub struct TaskGroup<E> {
    /// Sends each newly spawned task to the group's [`GroupJoinHandle`].
    sender: Sender<ChildHandle<E>>,
}
impl<E> std::fmt::Debug for TaskGroup<E> {
    /// Prints the group opaquely as `TaskGroup { .. }`.
    ///
    /// Written by hand rather than derived so that it does not require `E: Debug`;
    /// the inner channel sender is intentionally not printed.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut debug = formatter.debug_struct("TaskGroup");
        debug.finish_non_exhaustive()
    }
}
/// Creates a new task group and runs `f` as its root task.
///
/// `f` receives the group's [`TaskGroup`] handle and must hand it back on
/// success. The group stays open to new tasks until every handle to it has
/// been dropped.
///
/// The returned [`GroupJoinHandle`] resolves to:
/// - `Ok(())` once the group is closed and every task has succeeded, or
/// - the first error returned by any task, including the root task itself.
pub fn group<E, Fut, F>(f: F) -> GroupJoinHandle<E>
where
    E: Send + 'static,
    F: FnOnce(TaskGroup<E>) -> Fut,
    Fut: Future<Output = Result<TaskGroup<E>, E>> + Send + 'static,
{
    let (sender, receiver) = channel::unbounded();
    let group = TaskGroup { sender };
    let join_handle = GroupJoinHandle::new(receiver);
    let root = f(group.clone());
    group.spawn(async move {
        // Propagate the root task's error instead of discarding it.
        // On success, the handed-back group is dropped here, which releases the
        // root's sender so the group can close once all other handles are gone.
        root.await.map(|_group| ())
    });
    join_handle
}
impl<E> TaskGroup<E>
where
E: Send + 'static,
{
pub fn spawn<F>(&self, f: F)
where
F: Future<Output = Result<(), E>> + Send + 'static,
{
let join = task::spawn(f);
self.sender
.try_send(ChildHandle { handle: join })
.expect("Sending a task to the channel failed");
}
pub fn spawn_local<F>(&self, f: F)
where
F: Future<Output = Result<(), E>> + 'static,
{
let join = task::spawn_local(f);
self.sender
.try_send(ChildHandle { handle: join })
.expect("Sending a task to the channel failed");
}
pub fn build(&self) -> GroupBuilder<'_, E> {
GroupBuilder {
task_group: self,
builder: task::Builder::new(),
}
}
pub fn is_closed(&self) -> bool {
self.sender.is_closed()
}
fn clone(&self) -> Self {
Self {
sender: self.sender.clone(),
}
}
}
/// Configures a task, for example its name, before spawning it into a
/// [`TaskGroup`].
///
/// Created by [`TaskGroup::build`].
#[must_use = "a GroupBuilder does nothing until `spawn` or `spawn_local` is called"]
#[derive(Debug)]
pub struct GroupBuilder<'a, E> {
    /// The group that the spawned task is registered with.
    task_group: &'a TaskGroup<E>,
    /// The async-std builder carrying the task's configuration.
    builder: task::Builder,
}
impl<'a, E> GroupBuilder<'a, E>
where
E: Send + 'static,
{
pub fn name<A: AsRef<String>>(mut self, name: A) -> Self {
self.builder = self.builder.name(name.as_ref().to_owned());
self
}
pub fn spawn<F>(self, future: F)
where
F: Future<Output = Result<(), E>> + Send + 'static,
{
let handle = self.builder.spawn(future).unwrap();
self.task_group
.sender
.try_send(ChildHandle { handle })
.expect("Sending a task to the channel failed");
}
pub fn spawn_local<F>(self, future: F)
where
F: Future<Output = Result<(), E>> + 'static,
{
let handle = self.builder.local(future).unwrap();
self.task_group
.sender
.try_send(ChildHandle { handle })
.expect("Sending a task to the channel failed");
}
}
/// A task that has been spawned into a group and is tracked by [`GroupJoinHandle`].
#[derive(Debug)]
struct ChildHandle<E> {
    /// The async-std handle that is polled for the task's result.
    handle: JoinHandle<Result<(), E>>,
}
impl<E> ChildHandle<E> {
    /// Projects a pinned `ChildHandle` onto its pinned inner `JoinHandle` so that
    /// the task can be polled.
    fn pin_join(self: Pin<&mut Self>) -> Pin<&mut JoinHandle<Result<(), E>>> {
        // SAFETY: this is a structural pin projection onto `handle`. It is sound
        // because `handle` is never moved out of a pinned `ChildHandle`:
        // - the `Drop` impl for `ChildHandle` does not touch the field;
        // - there is no manual `Unpin` impl;
        // - the struct is not `repr(packed)`.
        // NOTE(review): if async-std's `JoinHandle` is `Unpin` (believed so — confirm),
        // `Pin::new(&mut self.get_mut().handle)` would avoid `unsafe` entirely.
        unsafe { self.map_unchecked_mut(|s| &mut s.handle) }
    }
}
// NOTE(review): this `Drop` impl is a no-op. It drops nothing beyond what the
// compiler would drop anyway. Its only visible effect is that `handle` cannot be
// moved out of a `ChildHandle` by destructuring.
// Dropping the inner async-std `JoinHandle` presumably detaches the task rather
// than cancelling it, so children released by `GroupJoinHandle` may keep
// running — confirm against the async-std `JoinHandle` docs.
impl<E> Drop for ChildHandle<E> {
    /// Intentionally empty; see the note above this impl.
    fn drop(&mut self) {}
}
/// A future that resolves when a task group has finished.
///
/// It resolves to:
/// - `Ok(())` once every [`TaskGroup`] handle has been dropped and every task
///   has completed successfully, or
/// - the first `Err` observed from any task in the group.
///
/// Created by [`group`].
#[must_use = "dropping a GroupJoinHandle discards the group's result"]
#[derive(Debug)]
pub struct GroupJoinHandle<E> {
    /// Receives newly spawned tasks. `None` once the channel has closed, or
    /// after an error has been returned.
    channel: Option<Receiver<ChildHandle<E>>>,
    /// Tasks that have been received but have not yet finished.
    children: Vec<Pin<Box<ChildHandle<E>>>>,
}
impl<E> GroupJoinHandle<E> {
    /// Builds a join handle that starts with no tracked children and pulls new
    /// ones from `channel`.
    fn new(channel: Receiver<ChildHandle<E>>) -> Self {
        let children = Vec::new();
        let channel = Some(channel);
        Self { channel, children }
    }
}
impl<E> Future for GroupJoinHandle<E> {
    type Output = Result<(), E>;

    /// Drives the group.
    ///
    /// Each poll first drains tasks spawned since the last poll from the
    /// channel, then polls every tracked task.
    ///
    /// Resolves in one of two ways:
    /// - To the first error observed from a task. The remaining tasks are then
    ///   released from tracking, and the channel is closed.
    /// - To `Ok(())`, once the channel has closed and every tracked task has
    ///   completed successfully.
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        // Both fields are `Unpin`: the receiver is, and children are individually
        // pinned in boxes. So the handle can be accessed through a plain `&mut`.
        let this = self.get_mut();

        // Collect every task spawned since the last poll.
        // The receiver is taken out so that `Ready(None)` (all senders gone, queue
        // drained) leaves `channel` as `None`.
        // `Pin::new` (rather than `new_unchecked`) makes the compiler check that the
        // receiver is `Unpin`. It must be, since it is moved back after polling.
        if let Some(mut channel) = this.channel.take() {
            this.channel = loop {
                match Pin::new(&mut channel).poll_next(ctx) {
                    Poll::Pending => break Some(channel),
                    Poll::Ready(Some(child)) => this.children.push(Box::pin(child)),
                    Poll::Ready(None) => break None,
                }
            };
        }

        // Poll every tracked task. Finished tasks are removed with `swap_remove`,
        // which moves the last task into `index`. That task is polled next, so
        // `index` only advances on `Pending`.
        let mut index = 0;
        while let Some(child) = this.children.get_mut(index) {
            match child.as_mut().pin_join().poll(ctx) {
                Poll::Pending => index += 1,
                Poll::Ready(Ok(())) => {
                    drop(this.children.swap_remove(index));
                }
                Poll::Ready(Err(error)) => {
                    // Stop tracking the rest of the group and close the channel.
                    this.children.clear();
                    this.channel = None;
                    return Poll::Ready(Err(error));
                }
            }
        }

        // Done only when no task is outstanding and no new task can arrive.
        if this.children.is_empty() && this.channel.is_none() {
            Poll::Ready(Ok(()))
        } else {
            Poll::Pending
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use anyhow::anyhow;

    /// A group whose root task spawns nothing completes successfully.
    #[async_std::test]
    async fn no_task() {
        let handle = group(|group| async move { Ok::<_, ()>(group) });
        assert!(handle.await.is_ok());
    }

    /// A single succeeding child lets the group complete successfully.
    #[async_std::test]
    async fn one_empty_task() {
        let handle = group(|group| async move {
            group.spawn(async move { Ok(()) });
            Ok::<_, ()>(group)
        });
        assert!(handle.await.is_ok());
    }

    /// Many succeeding children let the group complete successfully.
    #[async_std::test]
    async fn many_empty_tasks() {
        let handle = group(|group| async move {
            for _ in 0..16 {
                group.spawn(async { Ok(()) });
            }
            Ok::<_, ()>(group)
        });
        assert!(handle.await.is_ok());
    }

    /// A failing child task's error is reported by the join handle.
    #[async_std::test]
    async fn child_task_errors() {
        let handle = group(|group| async move {
            group.spawn(async { Err(anyhow!("idk!")) });
            Ok(group)
        });
        let res = handle.await;
        assert!(res.is_err());
        assert_eq!(format!("{:?}", res), "Err(idk!)");
    }

    /// When exactly one child fails among many, its error is reported.
    #[async_std::test]
    async fn one_failing_task_among_many() {
        let handle = group(|group| async move {
            // The failing task is spawned last. Once the join handle observes the
            // error it closes the channel, and any later `spawn` would panic.
            for i in 0..8_u32 {
                group.spawn(async move {
                    if i == 7 {
                        Err(i)
                    } else {
                        Ok(())
                    }
                });
            }
            Ok(group)
        });
        assert_eq!(handle.await, Err(7));
    }

    /// Tasks spawned through the builder are tracked like any other task.
    #[async_std::test]
    async fn builder_spawn() {
        let handle = group(|group| async move {
            group
                .build()
                .name(Box::new(String::from("child")))
                .spawn(async { Ok(()) });
            Ok::<_, ()>(group)
        });
        assert!(handle.await.is_ok());
    }
}