use std::{
ops::Deref,
sync::{Arc, Mutex},
};
use async_channel::{Receiver, Sender};
use futures_lite::prelude::*;
use slab::Slab;
/// Owning side of a task nursery.
///
/// Bundles a [`NurseryHandle`] (used to spawn tasks) with the receiving ends of
/// the two channels that the handle's clones send on.
pub struct Nursery {
/// The nursery's own handle; also exposed through `Deref`.
nhandle: NurseryHandle,
/// Receives the errors that spawned tasks send when their policy is `OnError::Propagate`.
recv_error: Receiver<anyhow::Error>,
/// Receives a `()` notification whenever a spawned task's body exits.
recv_term: Receiver<()>,
}
impl Deref for Nursery {
type Target = NurseryHandle;
fn deref(&self) -> &Self::Target {
&self.nhandle
}
}
impl Default for Nursery {
fn default() -> Self {
Self::new()
}
}
impl Nursery {
pub fn new() -> Self {
let (send_error, recv_error) = async_channel::unbounded();
let (send_term, recv_term) = async_channel::unbounded();
Self {
nhandle: NurseryHandle {
task_holder: Arc::new(Mutex::new(Slab::default())),
send_error,
send_term,
},
recv_error,
recv_term,
}
}
pub fn handle(&self) -> NurseryHandle {
self.nhandle.clone()
}
pub async fn wait(self) -> anyhow::Result<()> {
let a = async {
if self.recv_error.sender_count() > 1 {
let next_error = self
.recv_error
.recv()
.await
.expect("recv_error should never fail");
Err(next_error)
} else {
Ok(())
}
};
let b = async {
while self.recv_error.sender_count() > 1 {
self.recv_term.recv().await?;
}
Ok(())
};
a.or(b).await
}
pub fn wait_sync(self) -> anyhow::Result<()> {
futures_lite::future::block_on(self.wait())
}
}
/// Cloneable handle used to spawn tasks into a nursery.
#[derive(Clone)]
pub struct NurseryHandle {
/// Slab holding the `Task` handles of spawned tasks; the `Arc` is shared by all clones.
task_holder: Arc<Mutex<Slab<async_executor::Task<()>>>>,
/// Sending end of the channel on which tasks propagate their errors.
send_error: Sender<anyhow::Error>,
/// Sending end of the channel on which tasks announce that their body has exited.
send_term: Sender<()>,
}
impl NurseryHandle {
    /// Spawns a task into the nursery through `crate::spawn`.
    ///
    /// `task_gen` is called with a handle to the same nursery and must return the
    /// future to run. If that future resolves to `Err`, `on_error` decides what
    /// happens: `OnError::Custom` closures are invoked repeatedly until they yield
    /// `Ignore` or `Propagate`; `Propagate` sends the error over the nursery's
    /// error channel, `Ignore` drops it.
    ///
    /// When the task body exits, a notification is sent on the termination
    /// channel and the task's entry is removed from the shared slab.
    ///
    /// # Panics
    ///
    /// Panics if the task-holder mutex is poisoned.
    pub fn spawn<F: Future<Output = anyhow::Result<()>> + Send + 'static>(
        &self,
        mut on_error: OnError,
        task_gen: impl FnOnce(NurseryHandle) -> F + Send + 'static,
    ) {
        let child_handle = self.clone();
        let send_error = self.send_error.clone();
        let send_term = self.send_term.clone();
        let task_holder = self.task_holder.clone();
        let (send_tid, recv_tid) = async_oneshot::oneshot();
        let task = crate::spawn(async move {
            // Fires when this body exits, whether it ran to completion or was dropped.
            scopeguard::defer!({
                let _ = send_term.try_send(());
            });
            // Bound after the guard on purpose: locals drop in reverse order, so
            // this error sender is released before the termination notification
            // goes out and the waiter never observes a stale sender count.
            let send_error = send_error;
            // The handle is moved into the user's future and released with it at
            // the end of this statement.
            let result = task_gen(child_handle).await;
            if let Err(err) = result {
                // Resolve chained custom policies down to a final decision.
                let propagate = loop {
                    match on_error {
                        OnError::Ignore => break false,
                        OnError::Propagate => break true,
                        OnError::Custom(f) => on_error = f(&err),
                    }
                };
                if propagate {
                    let _ = send_error.send(err).await;
                }
            }
            // Release our own slab entry on success as well as on failure, so
            // finished tasks do not accumulate in the slab.
            if let Ok(tid) = recv_tid.await {
                drop(task_holder.lock().unwrap().remove(tid));
            }
        });
        let task_id = self.task_holder.lock().unwrap().insert(task);
        // If the task is already gone it can no longer remove itself; do it here.
        if send_tid.send(task_id).is_err() {
            drop(self.task_holder.lock().unwrap().remove(task_id));
        }
    }
}
/// Policy applied by `NurseryHandle::spawn` when a task's future resolves to `Err`.
pub enum OnError {
/// Drop the error without reporting it.
Ignore,
/// Send the error over the nursery's error channel.
Propagate,
/// Call the closure with the error and continue with the policy it returns,
/// which may itself be another `Custom`.
Custom(Box<dyn FnOnce(&anyhow::Error) -> OnError + Send + Sync>),
}
impl OnError {
pub fn custom(f: impl FnOnce(&anyhow::Error) -> OnError + 'static + Send + Sync) -> Self {
Self::Custom(Box::new(f))
}
pub fn ignore_with(f: impl FnOnce(&anyhow::Error) + 'static + Send + Sync) -> Self {
Self::Custom(Box::new(move |e| {
f(e);
Self::Ignore
}))
}
pub fn propagate_with(f: impl FnOnce(&anyhow::Error) + 'static + Send + Sync) -> Self {
Self::Custom(Box::new(move |e| {
f(e);
Self::Propagate
}))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;

    /// An outer task (errors ignored) spawns an inner task that fails with a
    /// propagating policy; waiting on the nursery must report that failure.
    #[test]
    fn nursery_simple() {
        let nursery = Nursery::new();
        let counter = Arc::new(AtomicUsize::new(0));
        // `counter` is moved into the task and dropped there.
        let outer = move |handle: NurseryHandle| async move {
            eprintln!("hello world");
            let log_error = OnError::propagate_with(|e| eprintln!("error: {}", e));
            handle.spawn(log_error, |_| async {
                eprintln!("attempt");
                anyhow::bail!("oh no");
            });
            drop(handle);
            drop(counter);
            Ok(())
        };
        nursery.spawn(OnError::Ignore, outer);
        let outcome = nursery.wait_sync();
        assert!(outcome.is_err())
    }
}