use std::{future::Future, mem, pin::Pin};
use futures::{stream, Stream, StreamExt};
use tokio::{
runtime::Handle,
select, spawn,
task::{spawn_blocking, spawn_local, JoinError, JoinHandle, LocalSet},
};
#[allow(unused_imports)] use tokio::task::JoinSet;
/// A lightweight alternative to [`JoinSet`] that tracks at most two tokio tasks.
///
/// Handles are kept packed toward `first`: a new handle fills the first slot when it is free
/// and otherwise takes the second slot, aborting any task that was already there. Dropping the
/// set requests abortion of the tasks it still tracks.
#[derive(Debug)]
pub struct TwoJoinSet<T> {
    /// Handle in the first slot; filled before `second`.
    first: Option<JoinHandle<T>>,
    /// Handle in the second slot; replaced (and its task aborted) when a new handle is pushed
    /// while the first slot is occupied.
    second: Option<JoinHandle<T>>,
}

// Implemented by hand: `#[derive(Default)]` would add a `T: Default` bound, yet an empty set
// never needs a `T` value (e.g. `TwoJoinSet<std::io::Result<()>>` should be defaultable).
impl<T> Default for TwoJoinSet<T> {
    /// Returns an empty set with both slots free.
    fn default() -> Self {
        Self {
            first: None,
            second: None,
        }
    }
}
impl<T> TwoJoinSet<T> {
    /// Creates an empty set with both slots free.
    pub const fn new() -> Self {
        Self {
            first: None,
            second: None,
        }
    }

    /// Returns how many task handles are currently tracked (0, 1 or 2).
    pub fn len(&self) -> usize {
        match (&self.first, &self.second) {
            (Some(_), Some(_)) => 2,
            (Some(_), None) | (None, Some(_)) => 1,
            (None, None) => 0,
        }
    }

    /// Returns `true` when no task handle is tracked.
    pub fn is_empty(&self) -> bool {
        self.first.is_none() && self.second.is_none()
    }

    /// Returns the handle in the first slot, if any.
    pub fn first(&self) -> Option<&JoinHandle<T>> {
        self.first.as_ref()
    }

    /// Returns the handle in the second slot, if any.
    pub fn second(&self) -> Option<&JoinHandle<T>> {
        self.second.as_ref()
    }

    /// Waits for the next tracked task to complete and returns its result, or `None` if the
    /// set is empty.
    ///
    /// When both slots are occupied, the first slot is polled first (`biased`). The completed
    /// handle is removed; if it was the first one, the second handle moves up into the first
    /// slot.
    ///
    /// Handles are polled in place and removed only after they complete, so dropping the
    /// returned future before it resolves leaves the set unchanged.
    pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
        match (&mut self.first, &mut self.second) {
            (Some(first), Some(second)) => {
                let (first, second) = (Pin::new(first), Pin::new(second));
                let (result, first_finished) = select! {
                    biased;
                    r = first => (r, true),
                    r = second => (r, false),
                };
                if first_finished {
                    // Keep handles packed toward the first slot.
                    self.first = self.second.take();
                } else {
                    self.second = None;
                }
                Some(result)
            }
            (Some(handle), None) | (None, Some(handle)) => {
                // Await through `&mut` (`JoinHandle` is `Unpin`) instead of taking the handle out
                // first: if this future is cancelled mid-await, the handle must stay in the set
                // rather than being dropped, which would detach the task and lose its result.
                let result = handle.await;
                // Only one slot was occupied in this arm, so the set is now empty.
                self.first = None;
                self.second = None;
                Some(result)
            }
            (None, None) => None,
        }
    }

    /// Aborts every tracked task, then awaits each handle and empties the set.
    ///
    /// The awaited results (including cancellation errors) are discarded.
    pub async fn shutdown(&mut self) {
        self.abort_all();
        for maybe in self.iter_raw_mut() {
            if let Some(handle) = maybe.take() {
                _ = handle.await;
            }
        }
    }

    /// Calls `abort` on every tracked handle. The handles remain in the set.
    #[inline]
    pub fn abort_all(&mut self) {
        self.iter().for_each(|handle| handle.abort())
    }

    /// Removes both handles from the set without aborting them.
    ///
    /// Per tokio's `JoinHandle` semantics, dropping a handle detaches its task.
    pub fn detach_all(&mut self) {
        self.iter_raw_mut().for_each(|maybe| _ = maybe.take())
    }

    /// Removes every tracked task that has already finished (per `JoinHandle::is_finished`)
    /// and returns a stream yielding their results in slot order. Running tasks stay in the
    /// set, packed toward the first slot.
    ///
    /// The stream is lazy: results are read only when it is polled, and dropping it unpolled
    /// discards them.
    pub fn try_join_both(&mut self) -> impl Stream<Item = Result<T, JoinError>> {
        let mut maybe_handles = [None, None];
        for (index, maybe) in self.iter_raw_mut().enumerate() {
            match maybe {
                Some(handle) if handle.is_finished() => maybe_handles[index] = mem::take(maybe),
                _ => {}
            }
        }
        self.fix_order();
        let finished_handles = maybe_handles.into_iter().flatten();
        stream::iter(finished_handles).then(|handle| handle)
    }

    /// Moves a lone handle from the second slot into the empty first slot, keeping handles
    /// packed toward `first`.
    #[inline]
    fn fix_order(&mut self) {
        if self.first.is_none() {
            self.first = self.second.take();
        }
    }

    /// Stores `join_handle` in the set.
    ///
    /// If the first slot is free the handle goes there. Otherwise it goes into the second
    /// slot, and any task previously held in the second slot is aborted.
    #[inline]
    pub fn push_handle(&mut self, join_handle: JoinHandle<T>) {
        match self.first {
            Some(_) => {
                if let Some(handle) = self.second.replace(join_handle) {
                    handle.abort();
                }
            }
            None => self.first = Some(join_handle),
        }
    }

    /// Iterates over the tracked handles in slot order (first, then second).
    pub fn iter(&self) -> impl DoubleEndedIterator<Item = &JoinHandle<T>> {
        [&self.first, &self.second]
            .into_iter()
            .filter_map(|maybe| maybe.as_ref())
    }

    /// Iterates mutably over both raw slots (including empty ones) in slot order.
    #[inline]
    fn iter_raw_mut(&mut self) -> impl DoubleEndedIterator<Item = &mut Option<JoinHandle<T>>> {
        [&mut self.first, &mut self.second].into_iter()
    }
}
impl<T: Send + 'static> TwoJoinSet<T> {
    /// Spawns `task` with [`tokio::spawn`] and tracks its handle.
    ///
    /// Tasks that had already finished are removed first, and their results are returned as a
    /// lazy stream. The new handle then fills the first slot if it is free, otherwise the
    /// second slot, aborting any task previously held there.
    pub fn spawn<F: Future<Output = T> + Send + 'static>(
        &mut self,
        task: F,
    ) -> impl Stream<Item = Result<T, JoinError>> {
        let already_finished = self.try_join_both();
        let new_handle = spawn(task);
        self.push_handle(new_handle);
        already_finished
    }

    /// Same as `spawn`, but runs `task` on the runtime behind `handle` (via [`Handle::spawn`]).
    pub fn spawn_on<F: Future<Output = T> + Send + 'static>(
        &mut self,
        task: F,
        handle: &Handle,
    ) -> impl Stream<Item = Result<T, JoinError>> {
        let already_finished = self.try_join_both();
        let new_handle = handle.spawn(task);
        self.push_handle(new_handle);
        already_finished
    }

    /// Same as `spawn`, but runs the closure `f` with [`tokio::task::spawn_blocking`].
    pub fn spawn_blocking<F: FnOnce() -> T + Send + 'static>(
        &mut self,
        f: F,
    ) -> impl Stream<Item = Result<T, JoinError>> {
        let already_finished = self.try_join_both();
        let new_handle = spawn_blocking(f);
        self.push_handle(new_handle);
        already_finished
    }

    /// Same as `spawn_blocking`, but uses the runtime behind `handle`
    /// (via [`Handle::spawn_blocking`]).
    pub fn spawn_blocking_on<F: FnOnce() -> T + Send + 'static>(
        &mut self,
        f: F,
        handle: &Handle,
    ) -> impl Stream<Item = Result<T, JoinError>> {
        let already_finished = self.try_join_both();
        let new_handle = handle.spawn_blocking(f);
        self.push_handle(new_handle);
        already_finished
    }
}
impl<T: 'static> TwoJoinSet<T> {
    /// Spawns `task` (which need not be `Send`) with [`tokio::task::spawn_local`] and tracks its
    /// handle.
    ///
    /// Tasks that had already finished are removed first, and their results are returned as a
    /// lazy stream. The new handle then fills the first slot if it is free, otherwise the
    /// second slot, aborting any task previously held there. See `tokio::task::spawn_local` for
    /// the context it must be called from.
    pub fn spawn_local<F: Future<Output = T> + 'static>(
        &mut self,
        task: F,
    ) -> impl Stream<Item = Result<T, JoinError>> {
        let already_finished = self.try_join_both();
        let new_handle = spawn_local(task);
        self.push_handle(new_handle);
        already_finished
    }

    /// Same as `spawn_local`, but spawns onto `local_set` (via [`LocalSet::spawn_local`]).
    pub fn spawn_local_on<F: Future<Output = T> + 'static>(
        &mut self,
        task: F,
        local_set: &LocalSet,
    ) -> impl Stream<Item = Result<T, JoinError>> {
        let already_finished = self.try_join_both();
        let new_handle = local_set.spawn_local(task);
        self.push_handle(new_handle);
        already_finished
    }
}
impl<T> Drop for TwoJoinSet<T> {
    /// Calls `JoinHandle::abort` on every task still tracked. The handles are not awaited;
    /// they are simply dropped along with the set afterwards.
    fn drop(&mut self) {
        for handle in self.first.iter().chain(self.second.iter()) {
            handle.abort();
        }
    }
}
#[cfg(test)]
mod tests;