use arc_swap::ArcSwapOption;
use dashmap::DashMap;
use derive_more::Debug;
use futures_delay_queue::{delay_queue, DelayHandle, DelayQueue, Receiver};
use futures_intrusive::buffer::GrowingHeapBuf;
use std::hash::Hash;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::time::Duration;
#[derive(Debug)]
pub struct ValueKeyPair<V> {
#[debug(skip)]
pub value: V,
pub key: ArcSwapOption<DelayHandle>,
}
impl<V> ValueKeyPair<V> {
pub fn new(value: V) -> Self {
Self {
value,
key: ArcSwapOption::default(),
}
}
}
/// A concurrent map whose entries may be scheduled for removal after a timeout.
///
/// Expiring keys are pushed into a delay queue; they are only removed from [`Self::inner`]
/// when the owner awaits `purge_expired`.
#[derive(Debug)]
pub struct TimedMap<K: Ord + 'static, V> {
    /// Producer side of the delay queue; `insert` schedules a key to be emitted later.
    sender: DelayQueue<K, GrowingHeapBuf<K>>,
    /// Consumer side of the delay queue; yields keys whose delay has elapsed.
    reciever: Receiver<K>,
    /// Backing storage for the entries.
    pub inner: DashMap<K, ValueKeyPair<V>>,
    /// When `true`, new entries get no timer and purging is skipped.
    disable_expiration: AtomicBool,
}

impl<K: Ord + 'static + Send + Hash, V> Default for TimedMap<K, V> {
    /// Builds an empty map with expiration enabled.
    fn default() -> Self {
        let (queue, expired_keys) = delay_queue();
        Self {
            sender: queue,
            reciever: expired_keys,
            inner: DashMap::default(),
            disable_expiration: AtomicBool::new(false),
        }
    }
}
impl<K: Ord, V> TimedMap<K, V> {
    /// Flips whether entries expire: enabled becomes disabled and vice versa.
    ///
    /// The flip is a single atomic read-modify-write, so concurrent toggles are never lost.
    /// (A separate `load` followed by a `compare_exchange` whose failure is ignored would
    /// silently drop one of two racing toggles.)
    pub fn toggle_expiration(&self) {
        self.disable_expiration.fetch_xor(true, Ordering::AcqRel);
    }
}
impl<K: Ord + Clone + Send + 'static + Sync + Hash, V: Send + 'static + Sync> TimedMap<K, V> {
    /// Creates an empty map with expiration enabled.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Inserts `value` under `key` without scheduling any expiration.
    ///
    /// NOTE(review): if `key` previously held an expirable entry, its pending timer is not
    /// cancelled here (cancelling is async), so a later `purge_expired` may still remove
    /// this entry when that old timer fires.
    pub fn insert_constant(&self, key: K, value: V) {
        let pair = ValueKeyPair::new(value);
        self.inner.insert(key, pair);
    }

    /// Inserts `value` under `key`, scheduling its removal after `timeout`.
    ///
    /// Any timer attached to a previous entry for `key` is cancelled first, whether or not
    /// expiration is currently enabled, so a stale timer cannot later evict the new value.
    /// When expiration is disabled, the value is stored without a timer and behaves like
    /// [`Self::insert_constant`].
    ///
    /// The previous timer can only be cancelled when the map holds the sole reference to
    /// its handle. If a caller still holds an `Arc` returned by
    /// [`Self::update_expiration_status`], the old timer keeps running.
    pub async fn insert_expirable(&self, key: K, value: V, timeout: Duration) {
        let pair = ValueKeyPair::new(value);
        // Detach the old handle in its own statement so the shard guard returned by `get`
        // is released before we await anything.
        let previous_handle = self
            .inner
            .get(&key)
            .and_then(|entry| entry.value().key.swap(None))
            .and_then(Arc::into_inner);
        if let Some(old_handle) = previous_handle {
            // An error here means the timer already fired; there is nothing left to cancel.
            let _ = old_handle.cancel().await;
        }
        if self.expires_entries() {
            let next_handle = self.sender.insert(key.clone(), timeout);
            pair.key.store(Some(Arc::new(next_handle)));
        }
        self.inner.insert(key, pair);
    }

    /// Returns the number of entries that currently carry an expiration timer handle.
    ///
    /// Despite the name, this counts entries scheduled to expire, whether or not their
    /// timer has fired yet. It does not count entries that have already been removed.
    pub fn len_expired(&self) -> usize {
        self.inner
            .iter()
            .filter(|entry| entry.value().key.load().is_some())
            .count()
    }

    /// Removes the entry for `key`, if present.
    ///
    /// NOTE(review): the entry's pending timer, if any, is not cancelled. When it fires,
    /// `purge_expired` will remove whatever entry is stored under `key` at that time.
    pub fn remove(&self, key: &K) {
        let _ = self.inner.remove(key);
    }

    /// Reschedules the timer of the entry under `key` using `DelayHandle::reset(duration)`.
    ///
    /// Returns the new handle (also stored back into the entry). Returns `None` in these cases:
    /// - the key is absent;
    /// - the entry has no timer;
    /// - the handle is still shared elsewhere, in which case it is left in place untouched;
    /// - the reset fails (presumably because the timer already fired).
    pub async fn update_expiration_status(
        &self,
        key: &K,
        duration: Duration,
    ) -> Option<Arc<DelayHandle>> {
        // The shard guard from `get` is a synchronous lock. It is released at the end of this
        // statement, so it is never held across the `.await` below; holding it could block
        // other tasks that write to the same shard.
        let shared_handle = self.inner.get(key)?.value().key.swap(None)?;
        let previous_handle = match Arc::try_unwrap(shared_handle) {
            Ok(handle) => handle,
            Err(still_shared) => {
                // Another owner still holds the handle, so it cannot be reset by value.
                // Put it back rather than silently detaching it from the entry.
                if let Some(entry) = self.inner.get(key) {
                    entry.value().key.store(Some(still_shared));
                }
                return None;
            }
        };
        let next_handle = Arc::new(previous_handle.reset(duration).await.ok()?);
        if let Some(entry) = self.inner.get(key) {
            entry.value().key.store(Some(Arc::clone(&next_handle)));
        }
        Some(next_handle)
    }

    /// Returns `true` while expiration is enabled.
    pub fn expires_entries(&self) -> bool {
        !self.disable_expiration.load(Ordering::Acquire)
    }

    /// Removes every entry and closes the expiration receiver.
    ///
    /// NOTE(review): closing the receiver appears to permanently stop delivery of expired
    /// keys. Entries inserted after `clear` would then never be purged. Confirm this is
    /// intended.
    pub fn clear(&self) {
        self.inner.clear();
        self.reciever.close();
    }

    /// Removes every key whose expiration has been delivered by the delay queue.
    ///
    /// Receives expired keys until none arrives within 1 ms. Does nothing while expiration
    /// is disabled.
    pub async fn purge_expired(&self) {
        use tokio_util::time::FutureExt;
        if !self.expires_entries() {
            return;
        }
        // Short per-receive timeout: drain what is ready without waiting for future timers.
        let timeout = Duration::from_millis(1);
        while let Ok(Some(expired)) = self.reciever.receive().timeout(timeout).await {
            self.inner.remove(&expired);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::TimedMap;
    use std::sync::Arc;
    use tokio::time::{sleep, Duration};

    #[tokio::test(flavor = "multi_thread")]
    async fn test_purge_removes_expired() {
        let map: Arc<TimedMap<u64, u64>> = Arc::new(TimedMap::new());
        map.insert_expirable(1, 100, Duration::from_millis(50))
            .await;
        sleep(Duration::from_millis(80)).await;
        map.purge_expired().await;
        assert!(!map.inner.contains_key(&1));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_concurrent_inserts_and_purge() {
        let map: Arc<TimedMap<u64, u64>> = Arc::new(TimedMap::new());
        let handles: Vec<_> = (0..50u64)
            .map(|i| {
                let m = Arc::clone(&map);
                tokio::spawn(async move {
                    for j in 0..10u64 {
                        let k = i * 100 + j;
                        m.insert_expirable(k, k, Duration::from_millis(30)).await;
                    }
                })
            })
            .collect();
        for h in handles {
            // Surface panics from inserter tasks instead of silently ignoring the JoinError.
            h.await.expect("inserter task panicked");
        }
        sleep(Duration::from_millis(60)).await;
        map.purge_expired().await;
        assert_eq!(map.len_expired(), 0);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_toggle_expiration_flips_flag() {
        let map: TimedMap<u64, u64> = TimedMap::new();
        assert!(map.expires_entries());
        map.toggle_expiration();
        assert!(!map.expires_entries());
        map.toggle_expiration();
        assert!(map.expires_entries());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_constant_entries_survive_purge() {
        let map: TimedMap<u64, u64> = TimedMap::new();
        map.insert_constant(7, 70);
        map.insert_expirable(8, 80, Duration::from_millis(10)).await;
        sleep(Duration::from_millis(40)).await;
        map.purge_expired().await;
        assert!(map.inner.contains_key(&7));
        assert!(!map.inner.contains_key(&8));
    }
}