use futures::prelude::*;
use std::{
collections::HashMap,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
pub use tokio::time::Instant;
use tokio_util::time::delay_queue::{self, DelayQueue};
/// Lifetime, in seconds, given to entries of a map built with `HashMapDelay::default()`.
const DEFAULT_DELAY: u64 = 30_u64;
/// A map whose entries each carry a timeout; once an entry's timeout elapses it is removed
/// from the map and handed back through `poll_expired` (or the `Stream` implementation).
#[derive(Debug)]
pub struct HashMapDelay<K, V>
where
    K: Eq + std::hash::Hash + Clone,
{
    /// Stored values, each paired with the handle of its pending expiration in `expirations`.
    entries: HashMap<K, MapEntry<V>>,
    /// Timer queue holding one key per live entry.
    expirations: DelayQueue<K>,
    /// Timeout used by `insert` when no explicit duration is supplied.
    default_entry_timeout: Duration,
}
#[derive(Debug)]
struct MapEntry<V> {
key: delay_queue::Key,
value: V,
}
impl<K, V> Default for HashMapDelay<K, V>
where
    K: std::cmp::Eq + std::hash::Hash + std::clone::Clone,
{
    /// Builds an empty map whose entries expire `DEFAULT_DELAY` seconds after insertion.
    fn default() -> Self {
        let timeout = Duration::from_secs(DEFAULT_DELAY);
        Self::new(timeout)
    }
}
impl<K, V> HashMapDelay<K, V>
where
K: std::cmp::Eq + std::hash::Hash + std::clone::Clone,
{
pub fn new(default_entry_timeout: Duration) -> Self {
HashMapDelay {
entries: HashMap::new(),
expirations: DelayQueue::new(),
default_entry_timeout,
}
}
pub fn with_capacity(default_entry_timeout: Duration, capacity: usize) -> Self {
HashMapDelay {
entries: HashMap::with_capacity(capacity),
expirations: DelayQueue::with_capacity(capacity),
default_entry_timeout,
}
}
pub fn insert(&mut self, key: K, value: V) {
self.insert_at(key, value, self.default_entry_timeout);
}
pub fn insert_at(&mut self, key: K, value: V, entry_duration: Duration) {
if let Some(entry) = self.entries.get(&key) {
self.expirations.reset(&entry.key, entry_duration);
} else {
let delay_key = self.expirations.insert(key.clone(), entry_duration);
let entry = MapEntry {
key: delay_key,
value,
};
self.entries.insert(key, entry);
}
}
pub fn get(&self, key: &K) -> Option<&V> {
self.entries.get(key).map(|entry| &entry.value)
}
pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
self.entries.get_mut(key).map(|entry| &mut entry.value)
}
pub fn contains_key(&self, key: &K) -> bool {
self.entries.contains_key(key)
}
pub fn len(&self) -> usize {
self.entries.len()
}
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
pub fn update_timeout(&mut self, key: &K, timeout: Duration) -> bool {
if let Some(entry) = self.entries.get(key) {
self.expirations.reset(&entry.key, timeout);
true
} else {
false
}
}
pub fn remove(&mut self, key: &K) -> Option<V> {
let entry = self.entries.remove(key)?;
self.expirations.remove(&entry.key);
Some(entry.value)
}
pub fn retain<F: FnMut(&K, &V) -> bool>(&mut self, mut f: F) {
let expiration = &mut self.expirations;
self.entries.retain(|key, entry| {
let result = f(key, &entry.value);
if !result {
expiration.remove(&entry.key);
}
result
})
}
pub fn clear(&mut self) {
self.entries.clear();
self.expirations.clear();
}
pub fn keys(&self) -> impl Iterator<Item = &K> {
self.entries.keys()
}
pub fn values(&self) -> impl Iterator<Item = &V> {
self.entries.values().map(|entry| &entry.value)
}
pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> {
self.entries.iter().map(|(k, entry)| (k, &entry.value))
}
pub fn shrink_to_fit(&mut self) {
self.entries.shrink_to_fit();
}
pub fn shrink_to(&mut self, capacity: usize) {
self.entries.shrink_to(capacity);
}
pub fn deadline(&self, key: &K) -> Option<Instant> {
self.entries
.get(key)
.map(|map_entry| self.expirations.deadline(&map_entry.key))
}
pub fn poll_expired(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<(K, V), String>>> {
match self.expirations.poll_expired(cx) {
Poll::Ready(Some(key)) => match self.entries.remove(key.get_ref()) {
Some(entry) => Poll::Ready(Some(Ok((key.into_inner(), entry.value)))),
None => Poll::Ready(Some(Err("Value no longer exists in expirations".into()))),
},
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
}
impl<K, V> Stream for HashMapDelay<K, V>
where
    K: std::cmp::Eq + std::hash::Hash + std::clone::Clone + Unpin,
    V: Unpin,
{
    type Item = Result<(K, V), String>;

    /// Yields expired entries by delegating to [`HashMapDelay::poll_expired`].
    ///
    /// With the `Unpin` bounds above the map itself is `Unpin`, so the pin can be unwrapped
    /// into a plain mutable reference.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this: &mut Self = Pin::into_inner(self);
        this.poll_expired(cx)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Re-arming a timer and polling the stream twice in one task wake-up must not panic.
    #[tokio::test]
    async fn should_not_panic() {
        let (key, value) = (2u8, 0);
        let mut map = HashMapDelay::default();

        map.insert(key, value);
        map.update_timeout(&key, Duration::from_secs(100));

        future::poll_fn(|cx: &mut Context<'_>| {
            for _ in 0..2 {
                let _ = map.poll_next_unpin(cx);
            }
            Poll::Ready(())
        })
        .await;

        map.insert(key, value);
        map.update_timeout(&key, Duration::from_secs(100));
    }

    /// An entry inserted with a short timeout is yielded by the stream and then gone.
    #[tokio::test]
    async fn basic_insert() {
        let (key, value) = (2u8, 10);
        let mut map = HashMapDelay::default();

        map.insert(key, value);
        assert!(map.contains_key(&key));
        map.remove(&key);
        assert!(!map.contains_key(&key));

        map.insert_at(key, value, Duration::from_millis(50));
        assert!(map.contains_key(&key));

        let deadline = tokio::time::Instant::now() + Duration::from_millis(100);
        let polled = tokio::time::timeout_at(deadline, map.next()).await;
        match polled {
            Ok(Some(Ok((expired_key, expired_value)))) => {
                assert_eq!(expired_value, value);
                assert_eq!(expired_key, key);
            }
            Ok(Some(Err(_))) => panic!("Polling the map failed"),
            Ok(None) => panic!("Entry did not exist, stream terminated"),
            Err(_) => panic!("Entry did not expire"),
        }
        assert_eq!(map.len(), 0);
    }

    /// `clear` empties both the value map and the expiration queue.
    #[tokio::test]
    async fn insert_clear() {
        let (key, value) = (2u8, 10);
        let mut map = HashMapDelay::default();

        map.insert(key, value);
        assert!(map.contains_key(&key));

        map.clear();
        assert!(!map.contains_key(&key));
        assert_eq!(map.expirations.len(), 0);
    }
}