use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use dashmap::DashMap;
/// Counts requests per domain so that frequently requested ("hot") domains can
/// be selected periodically via [`PrefetchTracker::take_hot`].
///
/// Counters live in a concurrent [`DashMap`], so every method takes `&self`
/// and the tracker can be shared across threads through the `Arc` returned by
/// [`PrefetchTracker::new`].
pub struct PrefetchTracker {
    /// Per-domain request counters, keyed by the domain string as passed to
    /// `increment`.
    inner: DashMap<String, AtomicU32, ahash::RandomState>,
}

impl PrefetchTracker {
    /// Creates an empty tracker, already wrapped in an [`Arc`] for sharing.
    pub fn new() -> Arc<Self> {
        Arc::new(Self {
            inner: DashMap::with_hasher(ahash::RandomState::default()),
        })
    }

    /// Records one request for `domain`.
    ///
    /// The common case (domain already tracked) only takes a shard read lock
    /// and bumps the existing counter, so no `String` is allocated. The key is
    /// cloned into an owned `String` only the first time a domain is seen.
    ///
    /// NOTE(review): `fetch_add` wraps on overflow, so a domain hit more than
    /// `u32::MAX` times between two `take_hot` calls would appear cold. This is
    /// presumed unreachable in practice — confirm against expected traffic.
    pub fn increment(&self, domain: &str) {
        if let Some(counter) = self.inner.get(domain) {
            counter.fetch_add(1, Ordering::Relaxed);
        } else {
            // `entry` re-checks under the shard write lock and `or_insert_with`
            // only inserts when vacant, so threads racing to add the same new
            // domain all count into a single entry.
            self.inner
                .entry(domain.to_owned())
                .or_insert_with(|| AtomicU32::new(0))
                .fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Returns every domain whose count is at least `threshold`, and resets
    /// all counters by removing every entry from the tracker.
    ///
    /// Selection and removal happen in a single [`DashMap::retain`] pass. That
    /// pass holds each shard's write lock while it reads and drops the shard's
    /// entries. A concurrent [`increment`](Self::increment) therefore either
    /// lands before its shard is visited (and is judged in this round) or
    /// after (and starts a fresh counter for the next round). Separately
    /// snapshotting with `iter()` and then calling `clear()` would silently
    /// discard increments that arrived in between.
    ///
    /// The returned domains are in no particular order.
    pub fn take_hot(&self, threshold: u32) -> Vec<String> {
        let mut hot = Vec::new();
        self.inner.retain(|domain, count| {
            // We hold the shard write lock here, so plain `get_mut` access is fine.
            if *count.get_mut() >= threshold {
                hot.push(domain.clone());
            }
            // Drop every entry: this is the reset of all counters.
            false
        });
        hot
    }

    /// Test helper: current count for `domain`, or 0 if it is not tracked.
    #[cfg(test)]
    pub fn count_for(&self, domain: &str) -> u32 {
        self.inner
            .get(domain)
            .map(|v| v.load(Ordering::Relaxed))
            .unwrap_or(0)
    }

    /// Test helper: number of domains currently tracked.
    #[cfg(test)]
    pub fn len(&self) -> usize {
        self.inner.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn increment_and_count() {
        let tracker = PrefetchTracker::new();
        for domain in ["example.com", "example.com", "other.net"] {
            tracker.increment(domain);
        }
        assert_eq!(tracker.count_for("example.com"), 2);
        assert_eq!(tracker.count_for("other.net"), 1);
    }

    #[test]
    fn take_hot_filters_by_threshold() {
        let tracker = PrefetchTracker::new();
        (0..5).for_each(|_| tracker.increment("hot.com"));
        tracker.increment("cold.com");

        let selected = tracker.take_hot(5);
        assert_eq!(selected, vec!["hot.com"]);

        // Every counter is reset by `take_hot`, whether or not the domain was hot.
        for domain in ["hot.com", "cold.com"] {
            assert_eq!(tracker.count_for(domain), 0);
        }
    }

    #[test]
    fn take_hot_resets_all_counters() {
        let tracker = PrefetchTracker::new();
        ["a.com", "b.com"].iter().for_each(|d| tracker.increment(d));
        let _ = tracker.take_hot(1);
        assert_eq!(tracker.len(), 0);
    }
}