use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use crate::runtime::Handle;
#[cfg(tokio_unstable)]
use crate::task::Id;
use crate::task::{AbortHandle, JoinError, JoinHandle, LocalSet};
use crate::util::IdleNotifiedSet;
/// A set of spawned tasks that all produce a value of type `T`.
///
/// Tasks are added through the `spawn*` methods, each of which hands back an
/// [`AbortHandle`], and their results are retrieved with `join_next` /
/// `poll_join_next`. When the set is dropped, `abort` is called on every join
/// handle it still holds.
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
pub struct JoinSet<T> {
// Join handles of the tasks in the set, kept in a container that the polling
// code queries for notified entries via `pop_notified`.
inner: IdleNotifiedSet<JoinHandle<T>>,
}
/// A task builder that inserts the tasks it spawns into a [`JoinSet`].
///
/// Obtained from `JoinSet::build_task`. Each `spawn*` method consumes the
/// builder, spawns the task through the wrapped `super::Builder`, and stores
/// the resulting join handle in the borrowed set, returning an
/// [`AbortHandle`] on success.
///
/// Only compiled when `tokio_unstable` is set and the `tracing` feature is
/// enabled.
#[cfg(all(tokio_unstable, feature = "tracing"))]
#[cfg_attr(docsrs, doc(cfg(all(tokio_unstable, feature = "tracing"))))]
#[must_use = "builders do nothing unless used to spawn a task"]
pub struct Builder<'a, T> {
// The set that receives the join handle of the spawned task.
joinset: &'a mut JoinSet<T>,
// The underlying task builder that performs the actual spawn.
builder: super::Builder<'a>,
}
impl<T> JoinSet<T> {
    /// Creates a new `JoinSet`.
    ///
    /// The set starts out with a freshly constructed inner handle container,
    /// so no task has been inserted into it yet.
    pub fn new() -> Self {
        let inner = IdleNotifiedSet::new();
        Self { inner }
    }

    /// Returns the number of join handles currently stored in the set, as
    /// reported by the inner container.
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns whether the inner container reports the set as empty.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }
}
impl<T: 'static> JoinSet<T> {
/// Returns a [`Builder`] that can configure a task before spawning it into
/// this set.
///
/// The builder mutably borrows the set for as long as it lives. Only
/// available when `tokio_unstable` is set and the `tracing` feature is
/// enabled.
#[cfg(all(tokio_unstable, feature = "tracing"))]
#[cfg_attr(docsrs, doc(cfg(all(tokio_unstable, feature = "tracing"))))]
pub fn build_task(&mut self) -> Builder<'_, T> {
    let builder = super::Builder::new();
    Builder {
        joinset: self,
        builder,
    }
}
/// Spawns `task` with `crate::spawn` and stores its join handle in this set.
///
/// Returns an [`AbortHandle`] taken from the new task's join handle before
/// the handle is stored.
#[track_caller]
pub fn spawn<F>(&mut self, task: F) -> AbortHandle
where
    F: Future<Output = T> + Send + 'static,
    T: Send,
{
    let join_handle = crate::spawn(task);
    self.insert(join_handle)
}
/// Spawns `task` through the given runtime `handle` and stores its join
/// handle in this set.
///
/// Returns an [`AbortHandle`] taken from the new task's join handle before
/// the handle is stored.
#[track_caller]
pub fn spawn_on<F>(&mut self, task: F, handle: &Handle) -> AbortHandle
where
    F: Future<Output = T> + Send + 'static,
    T: Send,
{
    let join_handle = handle.spawn(task);
    self.insert(join_handle)
}
/// Spawns `task` with `crate::task::spawn_local` and stores its join handle
/// in this set.
///
/// Unlike [`spawn`](Self::spawn), the future is not required to be `Send`.
/// Returns an [`AbortHandle`] taken from the new task's join handle before
/// the handle is stored.
#[track_caller]
pub fn spawn_local<F>(&mut self, task: F) -> AbortHandle
where
    F: Future<Output = T> + 'static,
{
    let join_handle = crate::task::spawn_local(task);
    self.insert(join_handle)
}
/// Spawns `task` on the provided [`LocalSet`] and stores its join handle in
/// this set.
///
/// The future is not required to be `Send`. Returns an [`AbortHandle`] taken
/// from the new task's join handle before the handle is stored.
#[track_caller]
pub fn spawn_local_on<F>(&mut self, task: F, local_set: &LocalSet) -> AbortHandle
where
    F: Future<Output = T> + 'static,
{
    let join_handle = local_set.spawn_local(task);
    self.insert(join_handle)
}
/// Runs the closure `f` via `crate::runtime::spawn_blocking` and stores the
/// resulting join handle in this set.
///
/// Returns an [`AbortHandle`] taken from the new task's join handle before
/// the handle is stored.
#[track_caller]
pub fn spawn_blocking<F>(&mut self, f: F) -> AbortHandle
where
    F: FnOnce() -> T + Send + 'static,
    T: Send,
{
    let join_handle = crate::runtime::spawn_blocking(f);
    self.insert(join_handle)
}
/// Runs the closure `f` via `spawn_blocking` on the given runtime `handle`
/// and stores the resulting join handle in this set.
///
/// Returns an [`AbortHandle`] taken from the new task's join handle before
/// the handle is stored.
#[track_caller]
pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle) -> AbortHandle
where
    F: FnOnce() -> T + Send + 'static,
    T: Send,
{
    let join_handle = handle.spawn_blocking(f);
    self.insert(join_handle)
}
/// Stores an already-spawned task's join handle in the set.
///
/// Returns the [`AbortHandle`] derived from `jh`.
fn insert(&mut self, jh: JoinHandle<T>) -> AbortHandle {
    // The abort handle has to be taken first: `jh` is moved into the inner
    // container on the next line.
    let abort_handle = jh.abort_handle();

    // Install the waker supplied by the entry's context as the task's join
    // waker.
    // NOTE(review): presumably this is what flags the entry as notified once
    // the task completes — confirm against `IdleNotifiedSet`.
    self.inner
        .insert_idle(jh)
        .with_value_and_context(|handle, entry_cx| handle.set_join_waker(entry_cx.waker()));

    abort_handle
}
/// Waits until one of the tasks in the set completes and returns its result.
///
/// This is the `async` wrapper around `poll_join_next`: it resolves to
/// `None` when that method reports the set as empty, and to
/// `Some(result)` when a task's join handle has completed.
pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
    let next_completed = crate::future::poll_fn(|cx| self.poll_join_next(cx));
    next_completed.await
}
/// Waits until one of the tasks in the set completes and returns its result
/// together with the task's [`Id`].
///
/// This is the `async` wrapper around `poll_join_next_with_id`: it resolves
/// to `None` when that method reports the set as empty. On success the
/// output is paired with the task ID; on failure only the `JoinError` is
/// returned. Only available when `tokio_unstable` is set.
#[cfg(tokio_unstable)]
#[cfg_attr(docsrs, doc(cfg(tokio_unstable)))]
pub async fn join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
    let next_completed = crate::future::poll_fn(|cx| self.poll_join_next_with_id(cx));
    next_completed.await
}
/// Aborts every task in the set and waits until the set has been drained.
///
/// Calls `abort_all` and then keeps awaiting `join_next`, discarding each
/// result, until it returns `None`.
pub async fn shutdown(&mut self) {
    self.abort_all();
    loop {
        // Results (including cancellation errors) are intentionally ignored.
        if self.join_next().await.is_none() {
            break;
        }
    }
}
/// Calls `abort` on the join handle of every task in the set.
///
/// This method does not wait for the tasks or collect their results; use
/// `join_next` for that, or `shutdown` to abort and wait in one step.
pub fn abort_all(&mut self) {
    self.inner.for_each(|handle| {
        handle.abort();
    });
}
/// Drains every join handle out of the set and drops it without calling
/// `abort`.
///
/// NOTE(review): going by the method name, the tasks are expected to keep
/// running detached once their handles are dropped — confirm against the
/// `JoinHandle` drop semantics.
pub fn detach_all(&mut self) {
    self.inner.drain(std::mem::drop);
}
/// Polls for one of the tasks in the set to complete.
///
/// # Returns
///
/// * `Poll::Ready(None)` when no entry was notified and the set is empty.
/// * `Poll::Pending` when no entry was notified but the set still holds
///   tasks (the waker of `cx` was handed to `pop_notified`), or when the
///   popped entry's join handle turned out not to be ready, in which case
///   the waker of `cx` is woken right away before returning.
/// * `Poll::Ready(Some(result))` when a join handle completed; the handle is
///   removed from the set and its result is returned.
pub fn poll_join_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<T, JoinError>>> {
    let mut entry = match self.inner.pop_notified(cx.waker()) {
        Some(notified) => notified,
        None => {
            return if self.is_empty() {
                Poll::Ready(None)
            } else {
                Poll::Pending
            };
        }
    };

    // Poll the join handle using the context supplied by the entry itself.
    let polled = entry.with_value_and_context(|handle, entry_cx| Pin::new(handle).poll(entry_cx));

    match polled {
        Poll::Ready(outcome) => {
            // The task is done, so its handle no longer belongs in the set.
            let _handle = entry.remove();
            Poll::Ready(Some(outcome))
        }
        Poll::Pending => {
            // The entry was flagged as notified but its handle is not ready.
            // NOTE(review): presumably a throttled or spurious wakeup —
            // confirm. Wake ourselves so the caller polls again.
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }
}
/// Polls for one of the tasks in the set to complete, pairing a successful
/// output with the task's [`Id`].
///
/// # Returns
///
/// * `Poll::Ready(None)` when no entry was notified and the set is empty.
/// * `Poll::Pending` when no entry was notified but the set still holds
///   tasks (the waker of `cx` was handed to `pop_notified`), or when the
///   popped entry's join handle turned out not to be ready, in which case
///   the waker of `cx` is woken right away before returning.
/// * `Poll::Ready(Some(Ok((id, output))))` when a task finished
///   successfully; `id` is read from the removed join handle.
/// * `Poll::Ready(Some(Err(error)))` when the join handle yielded a
///   `JoinError`; the error is passed through unchanged.
///
/// Only available when `tokio_unstable` is set.
#[cfg(tokio_unstable)]
#[cfg_attr(docsrs, doc(cfg(tokio_unstable)))]
pub fn poll_join_next_with_id(
    &mut self,
    cx: &mut Context<'_>,
) -> Poll<Option<Result<(Id, T), JoinError>>> {
    let mut entry = match self.inner.pop_notified(cx.waker()) {
        Some(notified) => notified,
        None => {
            return if self.is_empty() {
                Poll::Ready(None)
            } else {
                Poll::Pending
            };
        }
    };

    // Poll the join handle using the context supplied by the entry itself.
    let polled = entry.with_value_and_context(|handle, entry_cx| Pin::new(handle).poll(entry_cx));

    match polled {
        Poll::Ready(outcome) => {
            // The task is done, so its handle no longer belongs in the set.
            let handle = entry.remove();
            // Only a successful output gets tagged with the task ID; the ID
            // is not looked up on the error path.
            let tagged = match outcome {
                Ok(output) => Ok((handle.id(), output)),
                Err(error) => Err(error),
            };
            Poll::Ready(Some(tagged))
        }
        Poll::Pending => {
            // The entry was flagged as notified but its handle is not ready.
            // NOTE(review): presumably a throttled or spurious wakeup —
            // confirm. Wake ourselves so the caller polls again.
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }
}
}
impl<T> Drop for JoinSet<T> {
    /// Drains the set, calling `abort` on each join handle before the handle
    /// itself is dropped.
    fn drop(&mut self) {
        self.inner.drain(|task| {
            task.abort();
        });
    }
}
impl<T> fmt::Debug for JoinSet<T> {
    /// Formats the set as a `JoinSet` struct exposing only its current `len`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let len = self.len();
        let mut out = f.debug_struct("JoinSet");
        out.field("len", &len);
        out.finish()
    }
}
impl<T> Default for JoinSet<T> {
fn default() -> Self {
Self::new()
}
}
#[cfg(all(tokio_unstable, feature = "tracing"))]
#[cfg_attr(docsrs, doc(cfg(all(tokio_unstable, feature = "tracing"))))]
impl<'a, T: 'static> Builder<'a, T> {
/// Passes `name` to the wrapped task builder and returns the updated
/// builder, still bound to the same set.
pub fn name(self, name: &'a str) -> Self {
    let Self { joinset, builder } = self;
    Self {
        joinset,
        builder: builder.name(name),
    }
}
#[track_caller]
pub fn spawn<F>(self, future: F) -> std::io::Result<AbortHandle>
where
F: Future<Output = T>,
F: Send + 'static,
T: Send,
{
Ok(self.joinset.insert(self.builder.spawn(future)?))
}
#[track_caller]
pub fn spawn_on<F>(self, future: F, handle: &Handle) -> std::io::Result<AbortHandle>
where
F: Future<Output = T>,
F: Send + 'static,
T: Send,
{
Ok(self.joinset.insert(self.builder.spawn_on(future, handle)?))
}
#[track_caller]
pub fn spawn_local<F>(self, future: F) -> std::io::Result<AbortHandle>
where
F: Future<Output = T>,
F: 'static,
{
Ok(self.joinset.insert(self.builder.spawn_local(future)?))
}
#[track_caller]
pub fn spawn_local_on<F>(self, future: F, local_set: &LocalSet) -> std::io::Result<AbortHandle>
where
F: Future<Output = T>,
F: 'static,
{
Ok(self
.joinset
.insert(self.builder.spawn_local_on(future, local_set)?))
}
}
#[cfg(all(tokio_unstable, feature = "tracing"))]
#[cfg_attr(docsrs, doc(cfg(all(tokio_unstable, feature = "tracing"))))]
impl<'a, T> fmt::Debug for Builder<'a, T> {
    /// Formats the builder as `join_set::Builder`, showing the borrowed set
    /// and the wrapped task builder.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { joinset, builder } = self;
        let mut out = f.debug_struct("join_set::Builder");
        out.field("joinset", joinset);
        out.field("builder", builder);
        out.finish()
    }
}