/// Default per-entry timeout, in seconds, used by `HashSetDelay::default()`.
const DEFAULT_DELAY: u64 = 30;
use futures::prelude::*;
use std::{
collections::HashMap,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
pub use tokio::time::Instant;
use tokio_util::time::delay_queue::{self, DelayQueue};
/// A set of keys in which every key carries its own timeout and is dropped once it elapses.
///
/// Keys that time out are handed back through [`HashSetDelay::poll_expired`] (and therefore the
/// `Stream` implementation), at which point they are no longer members of the set.
///
/// The type itself carries no trait bounds; the `impl` blocks that need
/// `K: Eq + Hash + Clone + Unpin` state them, so merely naming `HashSetDelay<K>` does not force
/// callers to repeat those bounds.
#[derive(Debug)]
pub struct HashSetDelay<K> {
    /// Index from each live key to its handle inside `expirations`; the handle is what allows a
    /// key's timer to be reset, queried or cancelled.
    entries: HashMap<K, delay_queue::Key>,
    /// The pending timers. Each stored value is a clone of the key it tracks, so an expired item
    /// can be looked up (and removed) in `entries`.
    expirations: DelayQueue<K>,
    /// Timeout applied by `insert` when the caller does not supply an explicit duration.
    default_entry_timeout: Duration,
}
/// An empty set whose entries, when added via `insert`, expire after `DEFAULT_DELAY` seconds.
impl<K> Default for HashSetDelay<K>
where
    K: Eq + std::hash::Hash + Clone + Unpin,
{
    fn default() -> Self {
        let timeout = Duration::from_secs(DEFAULT_DELAY);
        Self::new(timeout)
    }
}
impl<K> HashSetDelay<K>
where
    K: std::cmp::Eq + std::hash::Hash + std::clone::Clone + Unpin,
{
    /// Creates an empty set whose entries expire after `default_entry_timeout` when added with
    /// [`insert`](Self::insert).
    pub fn new(default_entry_timeout: Duration) -> Self {
        HashSetDelay {
            entries: HashMap::new(),
            expirations: DelayQueue::new(),
            default_entry_timeout,
        }
    }

    /// Creates an empty set with room for at least `capacity` entries.
    ///
    /// Both the key index and the underlying `DelayQueue` are preallocated.
    pub fn with_capacity(default_entry_timeout: Duration, capacity: usize) -> Self {
        HashSetDelay {
            entries: HashMap::with_capacity(capacity),
            expirations: DelayQueue::with_capacity(capacity),
            default_entry_timeout,
        }
    }

    /// Inserts `key` using the set's default timeout.
    ///
    /// If `key` is already present, its timer is restarted with the default timeout.
    pub fn insert(&mut self, key: K) {
        self.insert_at(key, self.default_entry_timeout);
    }

    /// Inserts `key` so that it expires once `entry_duration` has elapsed.
    ///
    /// Despite the name, `entry_duration` is a relative duration, not an absolute instant.
    /// If `key` is already present, its existing timer is reset rather than a second timer
    /// being created.
    pub fn insert_at(&mut self, key: K, entry_duration: Duration) {
        // Share the reset path with `update_timeout` so a key can never own two timers.
        if !self.update_timeout(&key, entry_duration) {
            let delay_key = self.expirations.insert(key.clone(), entry_duration);
            self.entries.insert(key, delay_key);
        }
    }

    /// Restarts the timer of an existing `key` so it expires once `timeout` has elapsed.
    ///
    /// Returns `true` if `key` was present and its timer was reset. Returns `false` if `key` is
    /// absent; nothing is inserted in that case.
    pub fn update_timeout(&mut self, key: &K, timeout: Duration) -> bool {
        // Only the timer inside `expirations` changes; the handle stored in `entries` stays the
        // same, so a shared borrow of `entries` is sufficient.
        match self.entries.get(key) {
            Some(delay_key) => {
                self.expirations.reset(delay_key, timeout);
                true
            }
            None => false,
        }
    }

    /// Returns `true` if `key` is currently in the set.
    pub fn contains_key(&self, key: &K) -> bool {
        self.entries.contains_key(key)
    }

    /// Returns the number of keys currently in the set.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` if the set holds no keys.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Removes `key` and cancels its pending timer.
    ///
    /// Returns `true` if `key` was present.
    pub fn remove(&mut self, key: &K) -> bool {
        if let Some(delay_key) = self.entries.remove(key) {
            self.expirations.remove(&delay_key);
            true
        } else {
            false
        }
    }

    /// Removes every key and cancels all pending timers.
    // NOTE(review): `dead_code` is presumably allowed because some consumers never call this
    // method; confirm it is still required.
    #[allow(dead_code)]
    pub fn clear(&mut self) {
        self.entries.clear();
        self.expirations.clear();
    }

    /// Returns an iterator over the keys currently in the set, in arbitrary order.
    pub fn iter(&self) -> impl Iterator<Item = &K> {
        self.entries.keys()
    }

    /// Releases spare capacity held by both the key index and the timer storage.
    ///
    /// The timer storage cannot move live entries, so its capacity may stay above `len()`.
    pub fn shrink_to_fit(&mut self) {
        self.entries.shrink_to_fit();
        // Without this, the `DelayQueue` slab would keep its peak capacity after a burst.
        self.expirations.shrink_to_fit();
    }

    /// Shrinks the key index's capacity to no less than `capacity`.
    ///
    /// Only the key index is affected; the timer storage is left untouched.
    pub fn shrink_to(&mut self, capacity: usize) {
        self.entries.shrink_to(capacity);
    }

    /// Returns the instant at which `key` is scheduled to expire, or `None` if it is absent.
    pub fn deadline(&self, key: &K) -> Option<Instant> {
        self.entries
            .get(key)
            .map(|queue_key| self.expirations.deadline(queue_key))
    }

    /// Polls for the next expired key, removing it from the set.
    ///
    /// Returns:
    /// - `Ready(Some(Ok(key)))` when `key` expired. It has been removed from the set.
    /// - `Ready(Some(Err(_)))` when the queue yielded a key the index no longer tracks. This is
    ///   not expected while `entries` and `expirations` are kept in sync.
    /// - `Ready(None)` when the underlying `DelayQueue` reports it is exhausted.
    /// - `Pending` otherwise.
    pub fn poll_expired(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<K, String>>> {
        match self.expirations.poll_expired(cx) {
            Poll::Ready(Some(key)) => match self.entries.remove(key.get_ref()) {
                Some(_delay_key) => Poll::Ready(Some(Ok(key.into_inner()))),
                // The timer fired but its key is missing from `entries`.
                None => Poll::Ready(Some(Err("Value no longer exists in expirations".into()))),
            },
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}
/// Streams expired keys; every item is the outcome of one `HashSetDelay::poll_expired` call.
impl<K> Stream for HashSetDelay<K>
where
    K: Eq + std::hash::Hash + Clone + Unpin,
{
    type Item = Result<K, String>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `get_mut` needs `Self: Unpin`, which the `K: Unpin` bound supplies.
        let this: &mut Self = self.get_mut();
        this.poll_expired(cx)
    }
}