use std::collections::HashMap;
use std::future::Future;
use std::hash::Hash;
use std::sync::{Arc, Mutex};
use event_listener::Event;
// Mutable bookkeeping for one in-flight call; lives behind `Call::state`.
struct CallState<V> {
// Number of times a caller attached to this call as a follower
// (incremented in `do_` each time an existing entry is found for the key).
dups: usize,
// The leader's value once it has finished; `None` until then.
result: Option<V>,
}
// One in-flight execution, shared via `Arc` between the leader and its followers.
struct Call<V> {
// Notified by the leader when it finishes, and by `LeaderGuard` when the
// leader goes away without finishing; followers listen on it.
event: Event,
// Follower count and the published result.
state: Mutex<CallState<V>>,
}
// State shared by every clone of a `Group`.
struct GroupInner<K, V> {
// Calls currently registered, by key. `do_` inserts an entry when a caller
// becomes leader; entries are removed on completion, on leader drop, and by `forget`.
calls: Mutex<HashMap<K, Arc<Call<V>>>>,
}
/// Deduplicates concurrent work by key: while a call for a key is registered,
/// further `do_` callers for that key wait for its value instead of running
/// their own future.
///
/// Cloning a `Group` yields another handle to the same shared state.
pub struct Group<K, V> {
// Shared call table; `Arc` so that clones observe the same in-flight calls.
inner: Arc<GroupInner<K, V>>,
}
// Implemented by hand (rather than derived) so that cloning a handle does not
// require `K: Clone` or `V: Clone`.
impl<K, V> Clone for Group<K, V> {
    /// Returns a new handle that shares this group's call table.
    fn clone(&self) -> Self {
        let inner = Arc::clone(&self.inner);
        Self { inner }
    }
}
impl<K, V> Default for Group<K, V> {
fn default() -> Self {
Self::new()
}
}
// Outcome of looking a key up in the call table inside `do_`.
enum Role<V> {
// No call was registered for the key: this caller registered one and must
// run its own future.
Leader(Arc<Call<V>>),
// A call was already registered: this caller waits for that call's result.
Follower(Arc<Call<V>>),
}
// Cleanup guard the leader holds while awaiting its future. If it is dropped
// with `done == false` (the leader's future was dropped or panicked before
// publishing a result), its `Drop` impl cleans up the registration for `key`
// and notifies `call.event` so waiting followers do not hang.
struct LeaderGuard<'a, K: Eq + Hash, V> {
// The group's shared call table.
inner: &'a GroupInner<K, V>,
// Key under which `call` was registered.
key: K,
// The call this leader is responsible for.
call: &'a Arc<Call<V>>,
// Set to `true` by the leader once it has published its result itself,
// which turns the `Drop` impl into a no-op.
done: bool,
}
impl<'a, K: Eq + Hash, V> Drop for LeaderGuard<'a, K, V> {
fn drop(&mut self) {
if !self.done {
self.inner.calls.lock().unwrap().remove(&self.key);
self.call.event.notify(usize::MAX);
}
}
}
impl<K, V> Group<K, V> {
    /// Creates a group with no calls registered.
    pub fn new() -> Self {
        let calls = Mutex::new(HashMap::new());
        let inner = Arc::new(GroupInner { calls });
        Self { inner }
    }
}
impl<K, V> Group<K, V>
where
    K: Eq + Hash + Clone,
    V: Clone,
{
    /// Runs `fut` for `key`, sharing one execution among concurrent callers.
    ///
    /// If no call is registered for `key`, the caller becomes the *leader*: it
    /// registers a call, awaits its own `fut`, and publishes the value. If a
    /// call is already registered, the caller becomes a *follower* and waits
    /// for the leader's value; a follower that receives it never polls its own
    /// `fut`.
    ///
    /// Returns `(value, shared)`. `shared` is always `true` for a follower;
    /// for the leader it is `true` when at least one follower attached.
    ///
    /// If the leader goes away without publishing a value (its future is
    /// dropped or panics), woken followers retry, and one of them becomes the
    /// new leader using its own `fut`.
    ///
    /// # Panics
    ///
    /// Panics if one of the internal mutexes is poisoned.
    pub async fn do_<F>(&self, key: K, fut: F) -> (V, bool)
    where
        F: Future<Output = V>,
    {
        // Held in an `Option` because a follower may loop around and only
        // later become the leader that consumes the future.
        let mut fut = Some(fut);
        loop {
            let (role, listener) = {
                let mut calls = self.inner.calls.lock().unwrap();
                match calls.get(&key) {
                    Some(call) => {
                        call.state.lock().unwrap().dups += 1;
                        // Register the listener while the map lock is still
                        // held. A leader only notifies after it has taken this
                        // lock, so the notification cannot slip in between
                        // "found the call" and "started listening".
                        let listener = call.event.listen();
                        (Role::Follower(Arc::clone(call)), Some(listener))
                    }
                    None => {
                        let call = Arc::new(Call {
                            event: Event::new(),
                            state: Mutex::new(CallState {
                                dups: 0,
                                result: None,
                            }),
                        });
                        calls.insert(key.clone(), Arc::clone(&call));
                        (Role::Leader(call), None)
                    }
                }
            };

            let call = match role {
                Role::Follower(call) => {
                    let listener = listener.expect("follower always registers a listener");
                    // The leader may have finished between the lookup and now.
                    let ready = call.state.lock().unwrap().result.clone();
                    if let Some(v) = ready {
                        return (v, true);
                    }
                    listener.await;
                    let ready = call.state.lock().unwrap().result.clone();
                    if let Some(v) = ready {
                        return (v, true);
                    }
                    // Woken without a result: the leader was abandoned. Retry.
                    continue;
                }
                Role::Leader(call) => call,
            };

            // If this future is dropped (or `fut` panics) before `done` is
            // set, the guard unregisters the call and wakes the followers.
            let mut guard = LeaderGuard {
                inner: &self.inner,
                key: key.clone(),
                call: &call,
                done: false,
            };
            let value = fut
                .take()
                .expect("leader future already consumed")
                .await;
            // Clone before taking any lock so user code never runs under them.
            let for_followers = value.clone();
            let shared = {
                let mut calls = self.inner.calls.lock().unwrap();
                // Remove the entry only if it is still this call; after
                // `forget` the key may belong to a newer call.
                if matches!(calls.get(&key), Some(current) if Arc::ptr_eq(current, &call)) {
                    calls.remove(&key);
                }
                let mut state = call.state.lock().unwrap();
                state.result = Some(for_followers);
                state.dups > 0
            };
            guard.done = true;
            call.event.notify(usize::MAX);
            return (value, shared);
        }
    }

    /// Like [`Group::do_`], but runs the call on a spawned task and returns a
    /// receiver for its `(value, shared)` result.
    ///
    /// The task is started with `crate::spawn` and sends the result on a
    /// `crate::chan::bounded` channel created with capacity 1. A failed send
    /// is ignored.
    pub fn do_chan<F>(&self, key: K, fut: F) -> crate::chan::Receiver<(V, bool)>
    where
        F: Future<Output = V> + Send + 'static,
        K: Send + 'static,
        V: Send + 'static,
    {
        let (tx, rx) = crate::chan::bounded::<(V, bool)>(1);
        let group = self.clone();
        crate::spawn(async move {
            let result = group.do_(key, fut).await;
            // Best effort: the send result is deliberately discarded.
            let _ = tx.send(result).await;
        });
        rx
    }

    /// Removes the registration for `key`, so the next [`Group::do_`] call for
    /// it starts a new execution instead of joining the one in flight.
    ///
    /// Followers already attached to the old call keep their handle to it and
    /// still receive its result. A leader that completes afterwards leaves a
    /// newer registration for the same key in place.
    ///
    /// # Panics
    ///
    /// Panics if the call-table mutex is poisoned.
    pub fn forget(&self, key: &K) {
        self.inner.calls.lock().unwrap().remove(key);
    }
}