use std::borrow::Borrow;
use std::collections::HashMap;
use std::hash::BuildHasher;
use std::hash::Hash;
use std::hash::RandomState;
use std::sync::Arc;
use crate::internal::Mutex;
use crate::once::OnceCell;
#[cfg(test)]
mod tests;
/// Coordinates keyed units of work through a mutex-protected map of shared
/// [`OnceCell`]s.
///
/// Each key maps to at most one reference-counted cell at a time; callers
/// asking for the same key while an entry exists receive a clone of the same
/// `Arc`. `S` is the hasher builder used by the underlying [`HashMap`] and
/// defaults to [`RandomState`].
#[derive(Debug)]
pub struct Group<K, V, S = RandomState> {
/// Cells currently registered per key. An entry is inserted on demand by
/// `work`/`try_work` and removed again after a successful run or by `forget`.
map: Mutex<HashMap<K, Arc<OnceCell<V>>, S>>,
}
impl<K, V, S> Default for Group<K, V, S>
where
    K: Eq + Hash + Clone,
    V: Clone,
    S: BuildHasher + Clone + Default,
{
    /// Builds an empty group whose map uses the default value of the hasher
    /// builder `S`.
    fn default() -> Self {
        let hasher = S::default();
        Self::with_hasher(hasher)
    }
}
impl<K, V> Group<K, V, RandomState>
where
    K: Eq + Hash + Clone,
    V: Clone,
{
    /// Creates an empty group that hashes keys with a freshly created
    /// [`RandomState`].
    pub fn new() -> Self {
        // `HashMap::new()` is itself backed by a new `RandomState`; spelling the
        // hasher out keeps this constructor parallel to `with_hasher`.
        let cells = HashMap::with_hasher(RandomState::new());
        Self {
            map: Mutex::new(cells),
        }
    }
}
impl<K, V, S> Group<K, V, S>
where
K: Eq + Hash + Clone,
V: Clone,
S: BuildHasher + Clone,
{
pub fn with_hasher(hasher: S) -> Self {
Self {
map: Mutex::new(HashMap::with_hasher(hasher)),
}
}
pub async fn work<F>(&self, key: K, func: F) -> V
where
F: AsyncFnOnce() -> V,
{
let cell = {
let mut map = self.map.lock();
map.entry(key.clone())
.or_insert_with(|| Arc::new(OnceCell::new()))
.clone()
};
let res = cell
.get_or_init(async || {
let result = func().await;
let mut map = self.map.lock();
if let Some(existing) = map.get(&key) {
if Arc::ptr_eq(&cell, existing) {
map.remove(&key);
}
}
result
})
.await;
res.clone()
}
pub async fn try_work<E, F>(&self, key: K, func: F) -> Result<V, E>
where
F: AsyncFnOnce() -> Result<V, E>,
{
let cell = {
let mut map = self.map.lock();
map.entry(key.clone())
.or_insert_with(|| Arc::new(OnceCell::new()))
.clone()
};
let res = cell
.get_or_try_init(async || {
let result = func().await?;
let mut map = self.map.lock();
if let Some(existing) = map.get(&key) {
if Arc::ptr_eq(&cell, existing) {
map.remove(&key);
}
}
Ok(result)
})
.await?;
Ok(res.clone())
}
pub fn forget<Q>(&self, key: &Q)
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
let mut map = self.map.lock();
map.remove(key);
}
}