use std::future::Future;
use futures::{stream::FuturesUnordered, StreamExt};
use tokio::task::{JoinError, JoinHandle};
/// A set of spawned Tokio tasks that can be awaited in completion order.
///
/// Tasks are added through a shared reference (`&self`, see [`spawn`](Self::spawn)
/// and [`insert`](Self::insert)); waiting for them requires `&mut self`.
/// Dropping the set requests cancellation of every task it still tracks.
#[derive(Debug, Default)]
pub struct LateJoinSet {
    /// Join handles of the tracked tasks; polled together so results come back
    /// in the order the tasks finish rather than the order they were added.
    tasks: FuturesUnordered<JoinHandle<()>>,
}

impl LateJoinSet {
    /// Spawns `task` on the current Tokio runtime and tracks its handle.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, exactly like [`tokio::spawn`].
    /// `#[track_caller]` makes that panic point at the caller of this method.
    #[track_caller]
    pub fn spawn(&self, task: impl Future<Output = ()> + Send + 'static) {
        self.insert(tokio::spawn(task));
    }

    /// Starts tracking an already-spawned task, so it is joined and aborted
    /// together with the rest of the set.
    pub fn insert(&self, task: JoinHandle<()>) {
        self.tasks.push(task);
    }

    /// Waits for the next tracked task to finish and returns its outcome.
    ///
    /// Returns `None` when the set is empty. A task that panicked or was
    /// cancelled yields `Some(Err(_))`.
    pub async fn join_next(&mut self) -> Option<Result<(), JoinError>> {
        self.tasks.next().await
    }

    /// Waits until every tracked task has finished.
    ///
    /// Each task's outcome, including panics and cancellations, is discarded.
    /// Use [`join_next`](Self::join_next) to inspect the results instead.
    pub async fn join_all(&mut self) {
        // `join_next` returns `None` only once the set is empty, and `&mut self`
        // prevents new tasks from being pushed meanwhile, so the set is already
        // empty here and no explicit `clear` is needed.
        while self.join_next().await.is_some() {}
    }

    /// Requests cancellation of every tracked task.
    ///
    /// The handles stay in the set. Awaiting them through
    /// [`join_next`](Self::join_next) will most likely yield a cancelled
    /// `JoinError`, unless a task had already completed before the abort.
    pub fn abort_all(&self) {
        self.tasks.iter().for_each(JoinHandle::abort);
    }
}
impl Drop for LateJoinSet {
    /// Requests cancellation of every task still tracked by the set, then
    /// releases all of the join handles.
    fn drop(&mut self) {
        for handle in self.tasks.iter() {
            handle.abort();
        }
        self.tasks.clear();
    }
}