use crate::trace_utils::{trace_block, trace_future};
use std::future::Future;
use std::task::{Context, Poll};
use tokio::runtime::Handle;
use tokio::task::{AbortHandle, Id, JoinError, LocalSet};
/// A collection of spawned tasks backed by [`tokio::task::JoinSet`].
///
/// Every method on this type forwards to the wrapped set. The `Send`
/// spawn methods (`spawn`, `spawn_on`, `spawn_blocking`,
/// `spawn_blocking_on`) first pass their argument through the crate's
/// `trace_future` / `trace_block` helpers; the `spawn_local*` methods do not.
#[derive(Debug)]
pub struct JoinSet<T> {
/// The underlying tokio set that actually owns and drives the tasks.
inner: tokio::task::JoinSet<T>,
}
impl<T> Default for JoinSet<T> {
fn default() -> Self {
Self::new()
}
}
/// Construction and size queries; available for any `T`.
impl<T> JoinSet<T> {
    /// Creates a wrapper around a freshly constructed
    /// [`tokio::task::JoinSet`].
    pub fn new() -> Self {
        let inner = tokio::task::JoinSet::new();
        Self { inner }
    }

    /// Forwards to [`tokio::task::JoinSet::len`] on the wrapped set.
    pub fn len(&self) -> usize {
        tokio::task::JoinSet::len(&self.inner)
    }

    /// Forwards to [`tokio::task::JoinSet::is_empty`] on the wrapped set.
    pub fn is_empty(&self) -> bool {
        tokio::task::JoinSet::is_empty(&self.inner)
    }
}
impl<T: 'static> JoinSet<T> {
pub fn spawn<F>(&mut self, task: F) -> AbortHandle
where
F: Future<Output = T>,
F: Send + 'static,
T: Send,
{
self.inner.spawn(trace_future(task))
}
pub fn spawn_on<F>(&mut self, task: F, handle: &Handle) -> AbortHandle
where
F: Future<Output = T>,
F: Send + 'static,
T: Send,
{
self.inner.spawn_on(trace_future(task), handle)
}
pub fn spawn_local<F>(&mut self, task: F) -> AbortHandle
where
F: Future<Output = T>,
F: 'static,
{
self.inner.spawn_local(task)
}
pub fn spawn_local_on<F>(&mut self, task: F, local_set: &LocalSet) -> AbortHandle
where
F: Future<Output = T>,
F: 'static,
{
self.inner.spawn_local_on(task, local_set)
}
pub fn spawn_blocking<F>(&mut self, f: F) -> AbortHandle
where
F: FnOnce() -> T,
F: Send + 'static,
T: Send,
{
self.inner.spawn_blocking(trace_block(f))
}
pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle) -> AbortHandle
where
F: FnOnce() -> T,
F: Send + 'static,
T: Send,
{
self.inner.spawn_blocking_on(trace_block(f), handle)
}
pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
self.inner.join_next().await
}
pub fn try_join_next(&mut self) -> Option<Result<T, JoinError>> {
self.inner.try_join_next()
}
pub fn abort_all(&mut self) {
self.inner.abort_all()
}
pub fn detach_all(&mut self) {
self.inner.detach_all()
}
pub fn poll_join_next(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<T, JoinError>>> {
self.inner.poll_join_next(cx)
}
pub async fn join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
self.inner.join_next_with_id().await
}
pub fn try_join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
self.inner.try_join_next_with_id()
}
pub fn poll_join_next_with_id(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<(Id, T), JoinError>>> {
self.inner.poll_join_next_with_id(cx)
}
pub async fn shutdown(&mut self) {
self.inner.shutdown().await
}
pub async fn join_all(self) -> Vec<T> {
self.inner.join_all().await
}
}