#![cfg(feature = "scope")]
use tokio::task::JoinSet;
/// Runs `body` with a fresh [`JoinSet`] it can spawn fallible tasks into,
/// and does not resolve until those tasks have been dealt with.
///
/// * If `body` returns `Ok(value)`, every spawned task is awaited. If a task
///   yields `Err`, the remaining tasks are shut down and that error is
///   returned in place of `value`. Otherwise `Ok(value)` is returned.
/// * If `body` returns `Err(err)`, every spawned task is aborted and awaited
///   (via [`JoinSet::shutdown`]) and `Err(err)` is returned.
///
/// # Errors
///
/// Returns the error produced by `body`, or otherwise the first error
/// (in completion order) produced by one of the spawned tasks.
pub async fn scope<A, E>(
    body: impl FnOnce(&mut JoinSet<Result<(), E>>) -> Result<A, E>,
) -> Result<A, E>
where
    E: 'static,
{
    let mut tasks: JoinSet<Result<(), E>> = JoinSet::new();
    match body(&mut tasks) {
        Ok(value) => {
            // Every task must finish successfully before `value` is handed back.
            join_all(&mut tasks).await?;
            Ok(value)
        }
        Err(err) => {
            // The body failed: cancel outstanding work instead of waiting on it.
            tasks.shutdown().await;
            Err(err)
        }
    }
}
/// Waits for the next task in `set` to finish and returns its result.
///
/// Returns `None` once `set` is empty.
///
/// A task that was cancelled (for example through the `AbortHandle`
/// returned by `JoinSet::spawn`) is treated as having succeeded with
/// `Ok(())`.
///
/// A task that panicked is not treated as success. The remaining tasks are
/// shut down and the original panic payload is re-raised, so a panic inside
/// the scope is not silently turned into success.
async fn join_next<E>(set: &mut JoinSet<Result<(), E>>) -> Option<Result<(), E>>
where
    E: 'static,
{
    match set.join_next().await? {
        Ok(outcome) => Some(outcome),
        Err(join_err) if join_err.is_panic() => {
            // Stop sibling tasks and wait for them before unwinding, so no task
            // outlives the scope that is about to panic.
            set.shutdown().await;
            std::panic::resume_unwind(join_err.into_panic())
        }
        // Cancellation is a deliberate stop, not a failure of the task.
        Err(_cancelled) => Some(Ok(())),
    }
}
/// Drives `set` until every task has finished.
///
/// Returns `Ok(())` once `set` is empty. When a task yields `Err`, the
/// remaining tasks are shut down (aborted and awaited via
/// `JoinSet::shutdown`) and that error is returned immediately.
async fn join_all<E>(set: &mut JoinSet<Result<(), E>>) -> Result<(), E>
where
    E: 'static,
{
    while let Some(outcome) = join_next(set).await {
        if let Err(err) = outcome {
            // First failure wins; do not leave the other tasks running.
            set.shutdown().await;
            return Err(err);
        }
    }
    Ok(())
}
#[cfg(test)]
mod test {
    use super::*;
    use std::{sync::Arc, time::Duration};
    use tokio::{sync::Mutex, time::sleep};

    /// Every task spawned inside the scope has run to completion by the time
    /// `scope` resolves.
    #[tokio::test]
    async fn test_simple_scope() {
        let task_load = 100;
        let counter = scope::<_, ()>(|tasker| {
            let counter = Arc::new(Mutex::new(0usize));
            for _ in 0..task_load {
                let c = Arc::clone(&counter);
                tasker.spawn(async move {
                    sleep(Duration::from_millis(100)).await;
                    *c.lock().await += 1;
                    Ok(())
                });
            }
            Ok(counter)
        })
        .await
        .unwrap();
        assert_eq!(*counter.lock().await, task_load);
    }

    /// An error returned by a spawned task becomes the scope's result even
    /// though the body itself succeeded.
    #[tokio::test]
    async fn test_task_error_propagates() {
        let result = scope::<(), &'static str>(|tasker| {
            tasker.spawn(async { Ok(()) });
            tasker.spawn(async { Err("task failed") });
            Ok(())
        })
        .await;
        assert_eq!(result, Err("task failed"));
    }

    /// An error returned by the body is passed through, and tasks that were
    /// already spawned are cancelled rather than awaited to completion.
    #[tokio::test]
    async fn test_body_error_cancels_tasks() {
        let finished = Arc::new(Mutex::new(false));
        let flag = Arc::clone(&finished);
        let result = scope::<(), &'static str>(|tasker| {
            tasker.spawn(async move {
                sleep(Duration::from_secs(60)).await;
                *flag.lock().await = true;
                Ok(())
            });
            Err("body failed")
        })
        .await;
        assert_eq!(result, Err("body failed"));
        assert!(!*finished.lock().await);
    }
}