use std::{
fmt,
future::Future,
sync::atomic::{AtomicUsize, Ordering},
sync::Arc,
task::{Context, Poll},
};
use tokio::{sync::Semaphore, task::JoinSet as TokioJoinSet};
use crate::tokio_exports::{AbortHandle, Handle, Id, JoinError, LocalSet};
/// A [`TokioJoinSet`] wrapper that caps how many of its tasks run at once.
///
/// Every spawned task first waits for a permit from `active_semaphore`; tasks
/// that are still waiting are counted in `num_inactive_tasks`.
pub struct JoinSet<T> {
    /// Underlying Tokio set that owns the spawned (wrapped) tasks.
    inner_join_set: TokioJoinSet<T>,
    /// Maximum number of tasks allowed to run concurrently; also the
    /// semaphore's initial permit count.
    concurrency: usize,
    /// One permit per running task; a clone is handed to each spawned task.
    active_semaphore: Arc<Semaphore>,
    /// Number of spawned tasks that have not yet obtained a permit.
    num_inactive_tasks: Arc<AtomicUsize>,
}
impl<T> JoinSet<T> {
    /// Creates an empty set that lets at most `concurrency` tasks run at once.
    ///
    /// Tasks spawned beyond that limit wait (are "queued") until a running
    /// task releases its permit. With `concurrency == 0` no task can ever
    /// obtain a permit, so spawned tasks stay queued.
    ///
    /// # Panics
    ///
    /// Panics if `concurrency` is greater than [`Self::MAX_CONCURRENCY`]
    /// (the limit enforced by [`Semaphore::new`]).
    pub fn new(concurrency: usize) -> Self {
        Self {
            num_inactive_tasks: Arc::new(AtomicUsize::new(0)),
            inner_join_set: TokioJoinSet::new(),
            active_semaphore: Arc::new(Semaphore::new(concurrency)),
            concurrency,
        }
    }

    /// Number of tasks currently tracked by the set (queued, running, or
    /// finished but not yet joined).
    pub fn len(&self) -> usize {
        self.inner_join_set.len()
    }

    /// Returns `true` when the set tracks no tasks.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of permits currently held, i.e. tasks that are running.
    ///
    /// Tasks removed from the set with `detach_all` still count here while
    /// they hold a permit.
    pub fn num_active(&self) -> usize {
        self.concurrency - self.active_semaphore.available_permits()
    }

    /// Number of spawned tasks still waiting for a permit.
    pub fn num_queued(&self) -> usize {
        self.num_inactive_tasks.load(Ordering::Acquire)
    }

    /// Approximate number of tasks that finished but have not been joined.
    ///
    /// The three counters are read independently, so a task that is in the
    /// middle of moving from "queued" to "running" may briefly be counted in
    /// both; detached tasks also count as active while no longer in `len`.
    /// Saturating arithmetic keeps the estimate at 0 instead of underflowing
    /// (which would panic in debug builds) in those windows.
    pub fn num_completed(&self) -> usize {
        self.len()
            .saturating_sub(self.num_active())
            .saturating_sub(self.num_queued())
    }

    /// Largest `concurrency` accepted by [`JoinSet::new`].
    pub const MAX_CONCURRENCY: usize = Semaphore::MAX_PERMITS;
}
impl<T: 'static> JoinSet<T> {
fn wrap_task<F>(&self, task: F) -> impl Future<Output = T> + 'static
where
F: Future<Output = T> + 'static,
{
self.num_inactive_tasks.fetch_add(1, Ordering::Release);
let task_semaphore = self.active_semaphore.clone();
let task_inactive_count = self.num_inactive_tasks.clone();
async move {
let _permit = task_semaphore.acquire_owned().await.unwrap();
task_inactive_count.fetch_sub(1, Ordering::Release);
task.await
}
}
pub fn spawn<F>(&mut self, task: F) -> AbortHandle
where
F: Future<Output = T> + Send + 'static,
T: Send,
{
self.inner_join_set.spawn(self.wrap_task(task))
}
pub fn spawn_on<F>(&mut self, task: F, handle: &Handle) -> AbortHandle
where
F: Future<Output = T> + Send + 'static,
T: Send,
{
self.inner_join_set.spawn_on(self.wrap_task(task), handle)
}
pub fn spawn_local<F>(&mut self, task: F) -> AbortHandle
where
F: Future<Output = T> + 'static,
{
self.inner_join_set.spawn_local(self.wrap_task(task))
}
pub fn spawn_local_on<F>(&mut self, task: F, local_set: &LocalSet) -> AbortHandle
where
F: Future<Output = T> + 'static,
{
self.inner_join_set
.spawn_local_on(self.wrap_task(task), local_set)
}
pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
self.inner_join_set.join_next().await
}
pub async fn join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
self.inner_join_set.join_next_with_id().await
}
pub fn try_join_next(&mut self) -> Option<Result<T, JoinError>> {
self.inner_join_set.try_join_next()
}
pub fn try_join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
self.inner_join_set.try_join_next_with_id()
}
pub async fn join_all(self) -> Vec<T> {
self.inner_join_set.join_all().await
}
pub async fn shutdown(&mut self) {
self.inner_join_set.shutdown().await;
}
pub fn abort_all(&mut self) {
self.inner_join_set.abort_all();
}
pub fn detach_all(&mut self) {
self.inner_join_set.detach_all();
}
pub fn poll_join_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<T, JoinError>>> {
self.inner_join_set.poll_join_next(cx)
}
pub fn poll_join_next_with_id(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<(Id, T), JoinError>>> {
self.inner_join_set.poll_join_next_with_id(cx)
}
}
impl<T> Default for JoinSet<T> {
fn default() -> Self {
Self::new(8)
}
}
impl<T> fmt::Debug for JoinSet<T> {
    /// Reports the task counters and the concurrency limit (not the tasks).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let len = self.len();
        let active = self.num_active();
        let queued = self.num_queued();
        let completed = self.num_completed();

        let mut out = f.debug_struct("JoinSet");
        out.field("len", &len);
        out.field("active_tasks", &active);
        out.field("queued_tasks", &queued);
        out.field("num_completed", &completed);
        out.field("concurrency", &self.concurrency);
        out.finish()
    }
}