use futures::{Future, future::poll_fn};
use std::{
    collections::{HashMap, HashSet},
    task::{Context, Poll, ready},
};
use tokio::task::{Id, JoinError, JoinSet};
#[derive(Debug, Default)]
pub struct JoinMap<K, V> {
keys: HashSet<K>,
joinset: JoinSet<(K, V)>,
}
impl<K, V> JoinMap<K, V> {
pub fn new() -> Self {
Self { keys: HashSet::new(), joinset: JoinSet::new() }
}
pub fn len(&self) -> usize {
self.joinset.len()
}
pub fn is_empty(&self) -> bool {
self.joinset.is_empty()
}
}
impl<K, V> JoinMap<K, V>
where
K: Eq + std::hash::Hash + Clone + Send + Sync + 'static,
V: 'static,
{
pub fn spawn<F>(&mut self, key: K, future: F)
where
F: Future<Output = (K, V)> + Send + 'static,
V: Send,
{
if self.keys.insert(key) {
self.joinset.spawn(future);
}
}
pub fn contains_key(&self, key: &K) -> bool {
self.keys.contains(key)
}
pub async fn join_next(&mut self) -> Option<Result<(K, V), JoinError>> {
poll_fn(|cx| self.poll_join_next(cx)).await
}
pub fn poll_join_next(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<(K, V), JoinError>>> {
match ready!(self.joinset.poll_join_next(cx)) {
Some(Ok((key, value))) => {
self.keys.remove(&key);
Poll::Ready(Some(Ok((key, value))))
}
Some(Err(err)) => Poll::Ready(Some(Err(err))),
None => Poll::Ready(None),
}
}
}