use crate::cache::Cache;
use std::collections::HashMap;
use std::future::Future;
use std::hash::Hash;
use std::num::NonZeroUsize;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::Weak;
use thiserror::Error;
use tokio::sync::broadcast;
use tokio::sync::Mutex;
/// Future returned by the delegate for a single key.
///
/// It must be `Send` because it is driven to completion on a spawned Tokio
/// task. A `None` output is passed through to callers but never cached.
pub type DeduplicateFuture<V> = Pin<Box<dyn Future<Output = Option<V>> + Send>>;

/// Weak handle to the broadcast sender that will publish an in-flight result.
type InFlight<V> = Weak<broadcast::Sender<Option<V>>>;

/// In-flight requests, keyed by lookup key. The map sits behind an `Arc`, so it
/// is shared by clones of a `Deduplicate` and by the tasks it spawns.
type WaitMap<K, V> = Arc<Mutex<HashMap<K, InFlight<V>>>>;
/// Cache capacity used by `Deduplicate::new` (512 entries).
const DEFAULT_CACHE_CAPACITY: usize = 1 << 9;
/// Errors produced by [`Deduplicate`].
#[derive(Debug, Error)]
pub enum DeduplicateError {
    /// The delegated request this call was waiting on ended without
    /// delivering a result.
    #[error("Delegated get failed")]
    Failed,

    /// A cache operation was requested, but caching is disabled
    /// (the deduplicator was built with a capacity of zero).
    #[error("Cache not enabled")]
    NoCache,
}
/// Coalesces concurrent asynchronous lookups of the same key into one call to
/// a delegate, and optionally caches the values it produces.
///
/// `G` is the delegate that fetches the value for a key, `K` is the key type
/// and `V` is the value type. Clones share the table of in-flight requests,
/// because it is held behind an `Arc`. Whether clones also share cached values
/// depends on `Cache`'s `Clone` implementation (not visible here).
#[derive(Clone)]
pub struct Deduplicate<G, K, V>
where
    G: Fn(K) -> DeduplicateFuture<V>,
    K: Eq + Hash + Clone + Send,
    V: Send + Clone,
{
    /// Fetches the value for a key when no in-flight request or cached value exists.
    delegate: G,
    /// Result cache; `None` when caching is disabled (capacity of zero).
    storage: Option<Cache<K, V>>,
    /// In-flight requests: for each key, a weak handle to the sender that
    /// will broadcast that request's result.
    wait_map: WaitMap<K, V>,
}
impl<G, K, V> Deduplicate<G, K, V>
where
G: Fn(K) -> DeduplicateFuture<V>,
K: Clone + Send + Eq + Hash + 'static,
V: Clone + Send + 'static,
{
pub fn new(delegate: G) -> Self {
Self::with_capacity(delegate, DEFAULT_CACHE_CAPACITY)
}
pub fn with_capacity(delegate: G, capacity: usize) -> Self {
let storage = if capacity > 0 {
let val = unsafe { NonZeroUsize::new_unchecked(capacity) };
Some(Cache::new(val))
} else {
None
};
Self {
delegate,
wait_map: Arc::new(Mutex::new(HashMap::new())),
storage,
}
}
pub fn clear(&self) {
if let Some(storage) = &self.storage {
storage.clear();
}
}
pub fn count(&self) -> usize {
match &self.storage {
Some(s) => s.count(),
None => 0,
}
}
#[allow(clippy::await_holding_lock)]
pub async fn get(&self, key: K) -> Result<Option<V>, DeduplicateError> {
let mut locked_wait_map = self.wait_map.lock().await;
match locked_wait_map.get(&key) {
Some(weak) => {
if let Some(strong) = weak.upgrade() {
let mut receiver = strong.subscribe();
drop(strong);
drop(locked_wait_map);
receiver.recv().await.map_err(|_| DeduplicateError::Failed)
} else {
let _ = locked_wait_map.remove(&key);
Err(DeduplicateError::Failed)
}
}
None => {
let (sender, mut receiver) = broadcast::channel(1);
let sender = Arc::new(sender);
locked_wait_map.insert(key.clone(), Arc::downgrade(&sender));
drop(locked_wait_map);
if let Some(storage) = &self.storage {
if let Some(value) = storage.get(&key) {
let mut locked_wait_map = self.wait_map.lock().await;
let _ = locked_wait_map.remove(&key);
let _ = sender.send(Some(value.clone()));
return Ok(Some(value));
}
}
let fut = (self.delegate)(key.clone());
let k = key.clone();
let wait_map = self.wait_map.clone();
tokio::spawn(async move {
let value = fut.await;
let mut locked_wait_map = wait_map.lock().await;
let _ = locked_wait_map.remove(&k);
let _ = sender.send(value);
});
let result = receiver.recv().await.map_err(|_| DeduplicateError::Failed);
let mut locked_wait_map = self.wait_map.lock().await;
let _ = locked_wait_map.remove(&key);
let res = result?;
if let Some(storage) = &self.storage {
if let Some(v) = &res {
storage.insert(key, v.clone());
}
}
Ok(res)
}
}
}
pub fn insert(&self, key: K, value: V) -> Result<(), DeduplicateError> {
if let Some(storage) = &self.storage {
storage.insert(key, value);
Ok(())
} else {
Err(DeduplicateError::NoCache)
}
}
pub fn set_delegate(&mut self, delegate: G) {
self.clear();
self.delegate = delegate;
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{Duration, Instant};

    /// Delegate that sleeps for 1–2 s and then panics whenever the sleep length
    /// is even, so roughly half of the requests fail.
    fn get(_key: usize) -> DeduplicateFuture<String> {
        let fut = async {
            let num = rand::thread_rng().gen_range(1000..2000);
            tokio::time::sleep(Duration::from_millis(num)).await;
            if num % 2 == 0 {
                panic!("BAD NUMBER");
            }
            Some("test".to_string())
        };
        Box::pin(fut)
    }

    /// Time between the earliest and the latest completion in `results`.
    fn completion_spread(results: &[(Instant, bool)]) -> Duration {
        let first = results.iter().map(|(at, _)| *at).min().expect("no results");
        let last = results.iter().map(|(at, _)| *at).max().expect("no results");
        last - first
    }

    /// Number of successful requests in `results`.
    fn passed(results: &[(Instant, bool)]) -> usize {
        results.iter().filter(|(_, ok)| *ok).count()
    }

    /// Runs five rounds. Each round issues 100 concurrent `get(5)` calls
    /// through `deduplicate` and 100 independent delegate calls. It then checks
    /// that the deduplicated calls complete together and succeed or fail as
    /// a group.
    async fn test_harness<G>(deduplicate: Deduplicate<G, usize, String>)
    where
        G: Fn(usize) -> DeduplicateFuture<String>,
    {
        let no_panic_get = |_x: usize| async {
            let num = rand::thread_rng().gen_range(1000..2000);
            tokio::time::sleep(Duration::from_millis(num)).await;
            Some("test".to_string())
        };
        let deduplicate = Arc::new(deduplicate);
        for i in 1..6 {
            let start = Instant::now();
            let dedup_hdls = (0..100).map(|_| {
                let my_deduplicate = Arc::clone(&deduplicate);
                async move {
                    let is_ok = my_deduplicate.get(5).await.is_ok();
                    (Instant::now(), is_ok)
                }
            });
            let dedup_result = futures::future::join_all(dedup_hdls).await;
            let slower_hdls = (0..100).map(|_| async move {
                let is_ok = no_panic_get(5).await.is_some();
                (Instant::now(), is_ok)
            });
            let slower_result = futures::future::join_all(slower_hdls).await;

            let dedup_range = completion_spread(&dedup_result);
            let slower_range = completion_spread(&slower_result);
            println!("iteration: {}", i);
            println!("dedup_range: {:?}", dedup_range);
            println!("slower_range: {:?}", slower_range);
            // Deduplicated callers share one result, so they finish together.
            assert!(dedup_range <= slower_range);

            let dedup_passed = passed(&dedup_result);
            let slower_passed = passed(&slower_result);
            // One shared request: either every caller succeeds or every caller fails.
            assert!(dedup_passed == 0 || dedup_passed == 100);
            assert_eq!(slower_passed, 100);
            println!("dedup passed: {:?}", dedup_passed);
            println!("slower passed: {:?}", slower_passed);
            println!("elapsed: {:?}\n", start.elapsed());
        }
    }

    /// Builds a delegate that counts its invocations and resolves after 50 ms.
    fn counting_delegate(
        calls: Arc<AtomicUsize>,
    ) -> impl Fn(usize) -> DeduplicateFuture<String> {
        move |_key: usize| -> DeduplicateFuture<String> {
            calls.fetch_add(1, Ordering::SeqCst);
            Box::pin(async {
                tokio::time::sleep(Duration::from_millis(50)).await;
                Some("test".to_string())
            })
        }
    }

    #[tokio::test]
    async fn it_deduplicates_correctly_with_cache() {
        let no_panic_get = |_x: usize| -> DeduplicateFuture<String> {
            let fut = async {
                let num = rand::thread_rng().gen_range(1000..2000);
                tokio::time::sleep(Duration::from_millis(num)).await;
                Some("test".to_string())
            };
            Box::pin(fut)
        };
        test_harness(Deduplicate::new(no_panic_get)).await
    }

    #[tokio::test]
    async fn it_deduplicates_correctly_without_cache() {
        test_harness(Deduplicate::with_capacity(get, 0)).await
    }

    #[tokio::test]
    async fn it_calls_the_delegate_once_for_concurrent_requests() {
        let calls = Arc::new(AtomicUsize::new(0));
        let deduplicate = Deduplicate::with_capacity(counting_delegate(Arc::clone(&calls)), 0);
        let results = futures::future::join_all((0..10).map(|_| deduplicate.get(5))).await;
        assert!(results
            .iter()
            .all(|result| matches!(result, Ok(Some(value)) if value == "test")));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn it_serves_repeated_requests_from_the_cache() {
        let calls = Arc::new(AtomicUsize::new(0));
        let deduplicate = Deduplicate::new(counting_delegate(Arc::clone(&calls)));
        for _ in 0..3 {
            assert_eq!(deduplicate.get(5).await.unwrap(), Some("test".to_string()));
        }
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}