use hashbrown::hash_map::RawEntryMut;
use hashbrown::HashMap;
use std::borrow::Borrow;
use std::collections::hash_map::RandomState;
use std::fmt;
use std::future::Future;
use std::hash::{BuildHasher, Hash, Hasher};
use std::marker::PhantomData;
use tokio::runtime::Handle;
use tokio::task::{AbortHandle, Id, JoinError, JoinSet, LocalSet};
/// A collection of tasks tracked by a [`JoinSet`], each registered under a key of type `K`.
///
/// Every task's output has type `V`. Keys are hashed with the `S` hash builder
/// (`RandomState` by default).
#[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))]
pub struct JoinMap<K, V, S = RandomState> {
    // Maps each key to the abort handle of the task currently registered under it.
    // The `Key` wrapper also records that task's `Id`, but hashes/compares by key only.
    tasks_by_key: HashMap<Key<K>, AbortHandle, S>,
    // Maps each registered task `Id` to the hash of its key, so the matching
    // `tasks_by_key` entry can be found by ID (see `get_by_id` / `remove_by_id`).
    hashes_by_task: HashMap<Id, u64, S>,
    // The set that owns the spawned tasks and yields their outputs.
    tasks: JoinSet<V>,
}
/// Key type stored in `JoinMap::tasks_by_key`: the user's key together with the
/// `Id` of the task currently registered under it.
///
/// `Hash` and `PartialEq` for this type are implemented by hand to look only at
/// `key`, so entries can be found by key regardless of which task ID they hold.
#[derive(Debug)]
struct Key<K> {
    // The user-supplied key.
    key: K,
    // ID of the task registered under `key`.
    id: Id,
}
impl<K, V> JoinMap<K, V> {
#[inline]
#[must_use]
pub fn new() -> Self {
Self::with_hasher(RandomState::new())
}
#[inline]
#[must_use]
pub fn with_capacity(capacity: usize) -> Self {
JoinMap::with_capacity_and_hasher(capacity, Default::default())
}
}
impl<K, V, S: Clone> JoinMap<K, V, S> {
    /// Creates an empty `JoinMap` that hashes keys with `hash_builder`.
    #[inline]
    #[must_use]
    pub fn with_hasher(hash_builder: S) -> Self {
        Self::with_capacity_and_hasher(0, hash_builder)
    }

    /// Creates an empty `JoinMap` pre-sized for `capacity` tasks that hashes keys
    /// with `hash_builder`.
    ///
    /// `S: Clone` is required because each of the two internal maps receives its
    /// own copy of the hash builder.
    #[inline]
    #[must_use]
    pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
        Self {
            tasks_by_key: HashMap::with_capacity_and_hasher(capacity, hash_builder.clone()),
            hashes_by_task: HashMap::with_capacity_and_hasher(capacity, hash_builder),
            tasks: JoinSet::new(),
        }
    }

    /// Returns the number of tasks currently registered in the map.
    pub fn len(&self) -> usize {
        let len = self.tasks_by_key.len();
        // Every registered task has exactly one entry in each internal map.
        debug_assert_eq!(len, self.hashes_by_task.len());
        len
    }

    /// Returns `true` if no tasks are registered in the map.
    pub fn is_empty(&self) -> bool {
        let empty = self.tasks_by_key.is_empty();
        debug_assert_eq!(empty, self.hashes_by_task.is_empty());
        empty
    }

    /// Returns the number of tasks the map can hold without reallocating.
    ///
    /// This is the smaller of the two internal maps' capacities. Both maps always
    /// hold the same number of entries, but they are keyed differently (by `K` and
    /// by task `Id`). Their internal layouts, and hence their reported capacities,
    /// are therefore not guaranteed to match (for example after removals). Asserting
    /// equality here could spuriously panic in debug builds.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.tasks_by_key
            .capacity()
            .min(self.hashes_by_task.capacity())
    }
}
impl<K, V, S> JoinMap<K, V, S>
where
    K: Hash + Eq,
    V: 'static,
    S: BuildHasher,
{
    /// Spawns `task` through the internal [`JoinSet`] (`JoinSet::spawn`) and
    /// registers it under `key`.
    ///
    /// If a task is already registered under `key`, that task is aborted and
    /// replaced by the new one. The replaced task is removed from the map, so
    /// [`join_next`](Self::join_next) will not yield it.
    #[track_caller]
    pub fn spawn<F>(&mut self, key: K, task: F)
    where
        F: Future<Output = V>,
        F: Send + 'static,
        V: Send,
    {
        let task = self.tasks.spawn(task);
        self.insert(key, task);
    }

    /// Like [`spawn`](Self::spawn), but spawns `task` on the runtime behind
    /// `handle` (`JoinSet::spawn_on`).
    #[track_caller]
    pub fn spawn_on<F>(&mut self, key: K, task: F, handle: &Handle)
    where
        F: Future<Output = V>,
        F: Send + 'static,
        V: Send,
    {
        let task = self.tasks.spawn_on(task, handle);
        self.insert(key, task);
    }

    /// Like [`spawn`](Self::spawn), but runs the blocking closure `f` via
    /// `JoinSet::spawn_blocking`.
    #[track_caller]
    pub fn spawn_blocking<F>(&mut self, key: K, f: F)
    where
        F: FnOnce() -> V,
        F: Send + 'static,
        V: Send,
    {
        let task = self.tasks.spawn_blocking(f);
        self.insert(key, task);
    }

    /// Like [`spawn_blocking`](Self::spawn_blocking), but uses the runtime behind
    /// `handle` (`JoinSet::spawn_blocking_on`).
    #[track_caller]
    pub fn spawn_blocking_on<F>(&mut self, key: K, f: F, handle: &Handle)
    where
        F: FnOnce() -> V,
        F: Send + 'static,
        V: Send,
    {
        let task = self.tasks.spawn_blocking_on(f, handle);
        self.insert(key, task);
    }

    /// Like [`spawn`](Self::spawn), but for a non-`Send` future, spawned via
    /// `JoinSet::spawn_local`.
    #[track_caller]
    pub fn spawn_local<F>(&mut self, key: K, task: F)
    where
        F: Future<Output = V>,
        F: 'static,
    {
        let task = self.tasks.spawn_local(task);
        self.insert(key, task);
    }

    /// Like [`spawn_local`](Self::spawn_local), but spawns onto `local_set`
    /// (`JoinSet::spawn_local_on`).
    #[track_caller]
    pub fn spawn_local_on<F>(&mut self, key: K, task: F, local_set: &LocalSet)
    where
        F: Future<Output = V>,
        F: 'static,
    {
        let task = self.tasks.spawn_local_on(task, local_set);
        self.insert(key, task);
    }

    /// Registers the freshly spawned task identified by `abort` under `key`.
    ///
    /// If `key` is already present, the entry's stored key and handle are replaced.
    /// The previous task is aborted and its ID is dropped from `hashes_by_task`, so
    /// its eventual completion is skipped by `join_next`.
    fn insert(&mut self, key: K, abort: AbortHandle) {
        let hash = self.hash(&key);
        let id = abort.id();
        let map_key = Key { id, key };

        // Look up with the precomputed hash so the same value can be recorded in
        // `hashes_by_task` for ID-based lookups later.
        let entry = self
            .tasks_by_key
            .raw_entry_mut()
            .from_hash(hash, |k| k.key == map_key.key);
        match entry {
            RawEntryMut::Occupied(mut occ) => {
                // Replace the stored key as well, so it carries the new task's ID.
                let Key { id: prev_id, .. } = occ.insert_key(map_key);
                occ.insert(abort).abort();
                let _prev_hash = self.hashes_by_task.remove(&prev_id);
                debug_assert_eq!(Some(hash), _prev_hash);
            }
            RawEntryMut::Vacant(vac) => {
                vac.insert(map_key, abort);
            }
        }

        let _prev = self.hashes_by_task.insert(id, hash);
        debug_assert!(_prev.is_none(), "no prior task should have had the same ID");
    }

    /// Waits for the next task in the map to complete and returns its key together
    /// with its output (or the `JoinError` it failed with).
    ///
    /// Returns `None` once the underlying `JoinSet` has no more tasks.
    ///
    /// Completions of tasks that were replaced by a later `spawn*` call under the
    /// same key are skipped rather than reported. Those tasks were already removed
    /// from the map, so they have no key to return. Returning `None` for them would
    /// wrongly signal that the map is empty while other tasks are still running.
    pub async fn join_next(&mut self) -> Option<(K, Result<V, JoinError>)> {
        loop {
            let (res, id) = match self.tasks.join_next_with_id().await? {
                Ok((id, output)) => (Ok(output), id),
                Err(e) => {
                    let id = e.id();
                    (Err(e), id)
                }
            };
            if let Some(key) = self.remove_by_id(id) {
                return Some((key, res));
            }
            // Otherwise this was a replaced (no longer registered) task: keep waiting.
        }
    }

    /// Aborts every task and waits until all of them have completed, discarding
    /// their keys and outputs.
    pub async fn shutdown(&mut self) {
        self.abort_all();
        while self.join_next().await.is_some() {}
    }

    /// Aborts the task registered under `key`, if any.
    ///
    /// Returns `true` if a task was registered under `key`. The entry is not
    /// removed here; the task is still yielded by `join_next` once it completes.
    pub fn abort<Q: ?Sized>(&mut self, key: &Q) -> bool
    where
        Q: Hash + Eq,
        K: Borrow<Q>,
    {
        match self.get_by_key(key) {
            Some((_, handle)) => {
                handle.abort();
                true
            }
            None => false,
        }
    }

    /// Aborts every task whose key satisfies `predicate`.
    ///
    /// Entries are iterated without being removed, so the keys of aborted tasks
    /// can still be returned by later `join_next` calls.
    pub fn abort_matching(&mut self, mut predicate: impl FnMut(&K) -> bool) {
        for (Key { ref key, .. }, task) in &self.tasks_by_key {
            if predicate(key) {
                task.abort();
            }
        }
    }

    /// Returns an iterator over the keys of all registered tasks, in arbitrary order.
    pub fn keys(&self) -> JoinMapKeys<'_, K, V> {
        JoinMapKeys {
            iter: self.tasks_by_key.keys(),
            _value: PhantomData,
        }
    }

    /// Returns `true` if a task is registered under `key`.
    pub fn contains_key<Q: ?Sized>(&self, key: &Q) -> bool
    where
        Q: Hash + Eq,
        K: Borrow<Q>,
    {
        self.get_by_key(key).is_some()
    }

    /// Returns `true` if the task with ID `task` is registered in the map.
    pub fn contains_task(&self, task: &Id) -> bool {
        self.get_by_id(task).is_some()
    }

    /// Reserves room for at least `additional` more tasks in both internal maps.
    #[inline]
    pub fn reserve(&mut self, additional: usize) {
        self.tasks_by_key.reserve(additional);
        self.hashes_by_task.reserve(additional);
    }

    /// Shrinks both internal maps as much as possible.
    #[inline]
    pub fn shrink_to_fit(&mut self) {
        self.hashes_by_task.shrink_to_fit();
        self.tasks_by_key.shrink_to_fit();
    }

    /// Shrinks both internal maps, keeping room for at least `min_capacity` tasks.
    #[inline]
    pub fn shrink_to(&mut self, min_capacity: usize) {
        self.hashes_by_task.shrink_to(min_capacity);
        self.tasks_by_key.shrink_to(min_capacity);
    }

    /// Looks up the entry registered under `key`.
    fn get_by_key<'map, Q: ?Sized>(&'map self, key: &Q) -> Option<(&'map Key<K>, &'map AbortHandle)>
    where
        Q: Hash + Eq,
        K: Borrow<Q>,
    {
        let hash = self.hash(key);
        self.tasks_by_key
            .raw_entry()
            .from_hash(hash, |k| k.key.borrow() == key)
    }

    /// Looks up the entry for the task `id`, using the key hash recorded for it.
    fn get_by_id<'map>(&'map self, id: &Id) -> Option<(&'map Key<K>, &'map AbortHandle)> {
        let hash = self.hashes_by_task.get(id)?;
        self.tasks_by_key
            .raw_entry()
            .from_hash(*hash, |k| &k.id == id)
    }

    /// Removes the task `id` from both internal maps, returning its key.
    ///
    /// Returns `None` if `id` is not registered (e.g. it was replaced by a later
    /// `spawn*` under the same key).
    fn remove_by_id(&mut self, id: Id) -> Option<K> {
        // Removing the ID here also clears it from `hashes_by_task`; no second
        // removal is needed below.
        let hash = self.hashes_by_task.remove(&id)?;
        let entry = self
            .tasks_by_key
            .raw_entry_mut()
            .from_hash(hash, |k| k.id == id);
        let (Key { id: _key_id, key }, handle) = match entry {
            RawEntryMut::Occupied(entry) => entry.remove_entry(),
            RawEntryMut::Vacant(_) => return None,
        };
        debug_assert_eq!(_key_id, id);
        debug_assert_eq!(id, handle.id());
        Some(key)
    }

    /// Hashes `key` with the hash builder of `tasks_by_key`.
    ///
    /// `Key<K>` hashes only its `key` field, so this produces the same hash the map
    /// computes for the corresponding `Key<K>`. That agreement is what makes the
    /// raw-entry lookups above valid.
    #[inline]
    fn hash<Q: ?Sized>(&self, key: &Q) -> u64
    where
        Q: Hash,
    {
        let mut hasher = self.tasks_by_key.hasher().build_hasher();
        key.hash(&mut hasher);
        hasher.finish()
    }
}
impl<K, V, S> JoinMap<K, V, S>
where
    V: 'static,
{
    /// Aborts every task in the map via `JoinSet::abort_all`.
    ///
    /// No entries are removed here, so the aborted tasks are still yielded by
    /// `join_next` when they complete.
    pub fn abort_all(&mut self) {
        self.tasks.abort_all();
    }

    /// Detaches every task (`JoinSet::detach_all`) and forgets all registered keys.
    ///
    /// After this call the map is empty; the detached tasks' outputs can no longer
    /// be obtained through it.
    pub fn detach_all(&mut self) {
        let Self {
            tasks,
            tasks_by_key,
            hashes_by_task,
        } = self;
        tasks.detach_all();
        tasks_by_key.clear();
        hashes_by_task.clear();
    }
}
impl<K: fmt::Debug, V, S> fmt::Debug for JoinMap<K, V, S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
struct KeySet<'a, K: fmt::Debug, S>(&'a HashMap<Key<K>, AbortHandle, S>);
impl<K: fmt::Debug, S> fmt::Debug for KeySet<'_, K, S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_map()
.entries(self.0.keys().map(|Key { key, id }| (key, id)))
.finish()
}
}
f.debug_struct("JoinMap")
.field("tasks", &KeySet(&self.tasks_by_key))
.finish()
}
}
impl<K, V> Default for JoinMap<K, V> {
fn default() -> Self {
Self::new()
}
}
impl<K: Hash> Hash for Key<K> {
#[inline]
fn hash<H: Hasher>(&self, hasher: &mut H) {
self.key.hash(hasher);
}
}
/// Two `Key`s are equal when their user keys are equal; task IDs are ignored,
/// consistent with the `Hash` impl.
impl<K: PartialEq> PartialEq for Key<K> {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        PartialEq::eq(&self.key, &other.key)
    }
}

impl<K: Eq> Eq for Key<K> {}
/// Iterator over the keys of a [`JoinMap`], returned by `JoinMap::keys`.
///
/// Yields `&K` in arbitrary order. `V` appears only in a `PhantomData` marker,
/// tying the iterator type to the map's output type.
#[derive(Debug, Clone)]
pub struct JoinMapKeys<'a, K, V> {
    // Underlying iterator over the internal `Key<K>` entries (key + task ID).
    iter: hashbrown::hash_map::Keys<'a, Key<K>, AbortHandle>,
    // Marker for the map's value type; no `V` is stored.
    _value: PhantomData<&'a V>,
}
impl<'a, K, V> Iterator for JoinMapKeys<'a, K, V> {
type Item = &'a K;
fn next(&mut self) -> Option<&'a K> {
self.iter.next().map(|key| &key.key)
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.iter.size_hint()
}
}
impl<'a, K, V> ExactSizeIterator for JoinMapKeys<'a, K, V> {
fn len(&self) -> usize {
self.iter.len()
}
}
impl<'a, K, V> std::iter::FusedIterator for JoinMapKeys<'a, K, V> {}