use std::future::Future;
use std::task::Context;
use std::task::Poll;
use std::task::Waker;
use tokio::task::AbortHandle;
use tokio::task::JoinError;
use super::task::MaskFutureAsSend;
use super::task::MaskResultAsSend;
/// A set of spawned tasks whose futures and outputs are not required to be
/// `Send`, backed by a [`tokio::task::JoinSet`].
///
/// Each future is wrapped in `MaskFutureAsSend` (and each result in
/// `MaskResultAsSend`) so the tokio set will accept it. Unlike the tokio type,
/// `poll_join_next` reports `Poll::Pending` rather than `Poll::Ready(None)`
/// when the set is empty, and that caller is woken by the next `spawn`.
pub struct JoinSet<T> {
    /// Underlying tokio set that owns the masked tasks.
    joinset: tokio::task::JoinSet<MaskResultAsSend<T>>,
    /// Waker of the most recent `poll_join_next` caller that found the set
    /// empty; taken and woken by `spawn`.
    waker: Option<Waker>,
}
impl<T> Default for JoinSet<T> {
fn default() -> Self {
Self {
joinset: Default::default(),
waker: None,
}
}
}
impl<T: 'static> JoinSet<T> {
    /// Spawns `task` on this set and returns an [`AbortHandle`] that can be used
    /// to cancel it remotely.
    ///
    /// Unlike [`tokio::task::JoinSet::spawn`], `task` and its output do not need
    /// to be `Send`. The task is handed to the tokio set, so the tokio
    /// documentation applies: it starts running in the background immediately,
    /// and spawning outside a Tokio runtime panics.
    ///
    /// If an earlier [`poll_join_next`](Self::poll_join_next) call found the set
    /// empty, the waker it stashed is woken here so that caller polls again.
    #[track_caller]
    pub fn spawn<F>(&mut self, task: F) -> AbortHandle
    where
        F: Future<Output = T> + 'static,
    {
        // SAFETY: `MaskFutureAsSend` asserts `Send` for a future that may not be
        // `Send`, purely so tokio's `JoinSet` will accept it.
        // NOTE(review): this is presumably sound only when the set is driven by a
        // single-threaded runtime, so the future never moves between threads —
        // confirm against the safety contract of `MaskFutureAsSend::new`.
        let handle = self.joinset.spawn(unsafe { MaskFutureAsSend::new(task) });

        // A caller that polled while the set was empty was told `Pending` by us,
        // not by the inner set; wake it so it re-polls and registers its waker
        // with the inner set now that a task exists.
        if let Some(waker) = self.waker.take() {
            waker.wake();
        }
        handle
    }

    /// Number of tasks currently held by the set.
    pub fn len(&self) -> usize {
        self.joinset.len()
    }

    /// Returns `true` if the set holds no tasks.
    pub fn is_empty(&self) -> bool {
        self.joinset.is_empty()
    }

    /// Polls for the next completed task in the set.
    ///
    /// Yields `Poll::Ready(Ok(output))` for a task that finished, or
    /// `Poll::Ready(Err(_))` with the [`JoinError`] reported by tokio (for
    /// example, for an aborted or panicked task).
    ///
    /// Unlike [`tokio::task::JoinSet::poll_join_next`], an empty set yields
    /// `Poll::Pending` instead of `Poll::Ready(None)`. The caller's waker is
    /// stashed and woken by the next [`spawn`](Self::spawn).
    pub fn poll_join_next(&mut self, cx: &mut Context<'_>) -> Poll<Result<T, JoinError>> {
        match self.joinset.poll_join_next(cx) {
            Poll::Ready(Some(res)) => Poll::Ready(res.map(|res| res.into_inner())),
            Poll::Ready(None) => {
                // Nothing to join yet: remember whom to wake when `spawn` adds a
                // task. Only the most recent caller's waker is kept.
                self.waker = Some(cx.waker().clone());
                Poll::Pending
            }
            Poll::Pending => Poll::Pending,
        }
    }

    /// Waits for the next task in the set to complete and returns its output.
    ///
    /// Returns `None` when the set is empty. This delegates directly to
    /// [`tokio::task::JoinSet::join_next`], so unlike
    /// [`poll_join_next`](Self::poll_join_next) it does not wait for a task to
    /// be spawned into an empty set.
    pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
        self.joinset
            .join_next()
            .await
            .map(|result| result.map(|res| res.into_inner()))
    }

    /// Requests cancellation of every task in the set.
    ///
    /// Delegates to [`tokio::task::JoinSet::abort_all`]. Per the tokio docs,
    /// the tasks stay in the set until they are joined.
    pub fn abort_all(&mut self) {
        self.joinset.abort_all();
    }

    /// Removes every task from the set without aborting it.
    ///
    /// Delegates to [`tokio::task::JoinSet::detach_all`]. Per the tokio docs,
    /// detached tasks keep running in the background.
    pub fn detach_all(&mut self) {
        self.joinset.detach_all();
    }
}