use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use parking_lot::RwLock;
/// Identity of a cacheable HTTP request: its method, its URL, and a digest of its body.
///
/// Two requests map to the same key when their ASCII-upper-cased methods, their URLs
/// (compared verbatim) and their body digests are all equal. A missing body is hashed
/// exactly like an empty one, so `None` and `Some(&[])` produce the same key.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct CacheKey {
    /// Request method, stored ASCII-upper-cased.
    pub method: String,
    /// Request URL exactly as supplied; no normalization is applied.
    pub url: String,
    /// 64-bit `DefaultHasher` digest of the body bytes (empty slice when absent).
    pub body_hash: u64,
}

impl CacheKey {
    /// Builds a key from raw request parts.
    ///
    /// `method` is ASCII-upper-cased so `"get"` and `"GET"` compare equal. `body` is
    /// reduced to a 64-bit digest, so distinct bodies are told apart only up to hash
    /// collisions.
    pub fn new(method: &str, url: &str, body: Option<&[u8]>) -> Self {
        use std::hash::{Hash, Hasher};

        let bytes: &[u8] = match body {
            Some(payload) => payload,
            None => &[],
        };
        let body_hash = {
            let mut state = std::collections::hash_map::DefaultHasher::new();
            bytes.hash(&mut state);
            state.finish()
        };

        Self {
            method: method.to_ascii_uppercase(),
            url: url.to_owned(),
            body_hash,
        }
    }
}
/// A cached JSON payload paired with the instant it was stored.
#[derive(Clone)]
struct CacheEntry {
    /// Payload handed back (cloned) on a cache hit.
    value: serde_json::Value,
    /// Insertion time; freshness is judged by comparing `stored_at.elapsed()` to the TTL.
    stored_at: Instant,
}
/// How a single [`HttpResponseCache::lookup`] call was resolved.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CacheOutcome {
    /// A fresh entry was found and its value returned.
    Hit,
    /// No entry existed for the key, or the cache is disabled.
    Miss,
    /// An entry existed but had outlived the TTL; it was removed and no value returned.
    Expired,
}
/// Running counters describing how [`HttpResponseCache`] lookups were resolved.
///
/// `Debug` is derived rather than hand-written: the derived output
/// (`CacheStats { hits: .., misses: .., expirations: .. }`) is exactly what the former
/// manual impl produced. `Clone`/`Copy`/`PartialEq`/`Eq` make a snapshot cheap to copy
/// and compare.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct CacheStats {
    /// Lookups answered from a fresh entry.
    pub hits: u64,
    /// Lookups that found no entry, including every lookup on a disabled cache.
    pub misses: u64,
    /// Entries dropped for outliving the TTL, whether found stale by a lookup or removed
    /// in bulk by [`HttpResponseCache::evict_expired`].
    pub expirations: u64,
}
/// TTL-based cache of JSON HTTP responses keyed by [`CacheKey`].
///
/// All state lives behind an `Arc` and `RwLock`s. Cloning the handle shares that state,
/// so entries inserted through one clone are visible through every other clone.
#[derive(Clone)]
pub struct HttpResponseCache {
    inner: Arc<HttpResponseCacheInner>,
}

/// State shared by every clone of an [`HttpResponseCache`].
struct HttpResponseCacheInner {
    /// Stored responses. Stale entries stay here until a lookup or `evict_expired` removes them.
    entries: RwLock<HashMap<CacheKey, CacheEntry>>,
    /// Hit/miss/expiration counters, guarded separately from `entries`.
    stats: RwLock<CacheStats>,
    /// Lifetime of an entry; a zero TTL disables the cache entirely.
    ttl: Duration,
}
impl HttpResponseCache {
pub fn new(ttl: Duration) -> Self {
Self {
inner: Arc::new(HttpResponseCacheInner {
entries: RwLock::new(HashMap::new()),
stats: RwLock::new(CacheStats::default()),
ttl,
}),
}
}
pub fn is_disabled(&self) -> bool {
self.inner.ttl.is_zero()
}
pub fn ttl(&self) -> Duration {
self.inner.ttl
}
pub fn lookup(&self, key: &CacheKey) -> (CacheOutcome, Option<serde_json::Value>) {
if self.is_disabled() {
self.inner.stats.write().misses += 1;
return (CacheOutcome::Miss, None);
}
{
let map = self.inner.entries.read();
if let Some(entry) = map.get(key) {
if entry.stored_at.elapsed() <= self.inner.ttl {
self.inner.stats.write().hits += 1;
return (CacheOutcome::Hit, Some(entry.value.clone()));
}
} else {
self.inner.stats.write().misses += 1;
return (CacheOutcome::Miss, None);
}
}
let mut map = self.inner.entries.write();
if let Some(entry) = map.get(key) {
if entry.stored_at.elapsed() > self.inner.ttl {
map.remove(key);
self.inner.stats.write().expirations += 1;
return (CacheOutcome::Expired, None);
}
self.inner.stats.write().hits += 1;
return (CacheOutcome::Hit, Some(entry.value.clone()));
}
self.inner.stats.write().misses += 1;
(CacheOutcome::Miss, None)
}
pub fn insert(&self, key: CacheKey, value: serde_json::Value) {
if self.is_disabled() {
return;
}
self.inner.entries.write().insert(
key,
CacheEntry {
value,
stored_at: Instant::now(),
},
);
}
pub fn evict_expired(&self) -> usize {
if self.is_disabled() {
return 0;
}
let mut map = self.inner.entries.write();
let before = map.len();
let ttl = self.inner.ttl;
map.retain(|_, e| e.stored_at.elapsed() <= ttl);
let removed = before - map.len();
if removed > 0 {
self.inner.stats.write().expirations += removed as u64;
}
removed
}
pub fn stats(&self) -> (u64, u64, u64) {
let s = self.inner.stats.read();
(s.hits, s.misses, s.expirations)
}
pub fn len(&self) -> usize {
self.inner.entries.read().len()
}
pub fn is_empty(&self) -> bool {
self.inner.entries.read().is_empty()
}
}
impl std::fmt::Debug for HttpResponseCache {
    /// Renders the TTL, the current entry count, and a snapshot of the counters.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (hits, misses, expirations) = self.stats();
        let len = self.len();
        let mut out = f.debug_struct("HttpResponseCache");
        out.field("ttl", &self.inner.ttl);
        out.field("len", &len);
        out.field("hits", &hits);
        out.field("misses", &misses);
        out.field("expirations", &expirations);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn disabled_cache_always_misses() {
        let cache = HttpResponseCache::new(Duration::from_secs(0));
        assert!(cache.is_disabled());
        let key = CacheKey::new("GET", "https://x", None);
        cache.insert(key.clone(), serde_json::json!("v"));
        // `insert` is a no-op on a disabled cache, so nothing was stored.
        assert!(cache.is_empty());
        let (out, val) = cache.lookup(&key);
        assert_eq!(out, CacheOutcome::Miss);
        assert!(val.is_none());
        assert_eq!(cache.stats(), (0, 1, 0));
        assert_eq!(cache.evict_expired(), 0);
    }

    // Renamed from `hit_then_miss_after_ttl`: the second lookup reports `Expired`, not `Miss`.
    #[test]
    fn hit_then_expired_after_ttl() {
        let cache = HttpResponseCache::new(Duration::from_millis(50));
        let key = CacheKey::new("GET", "https://x", None);
        cache.insert(key.clone(), serde_json::json!("v"));
        let (out, val) = cache.lookup(&key);
        assert_eq!(out, CacheOutcome::Hit);
        assert_eq!(val, Some(serde_json::json!("v")));
        std::thread::sleep(Duration::from_millis(80));
        let (out, val) = cache.lookup(&key);
        assert_eq!(out, CacheOutcome::Expired);
        assert!(val.is_none());
    }

    #[test]
    fn expired_entry_is_removed_then_misses() {
        let cache = HttpResponseCache::new(Duration::from_millis(20));
        let key = CacheKey::new("GET", "https://x", None);
        cache.insert(key.clone(), serde_json::json!("v"));
        std::thread::sleep(Duration::from_millis(40));
        assert_eq!(cache.lookup(&key).0, CacheOutcome::Expired);
        // The stale entry was dropped by the first lookup, so the next one is a plain miss.
        assert_eq!(cache.lookup(&key).0, CacheOutcome::Miss);
        assert!(cache.is_empty());
        assert_eq!(cache.stats(), (0, 1, 1));
    }

    #[test]
    fn body_hash_separates_keys() {
        let a = CacheKey::new("POST", "https://x", Some(b"a"));
        let b = CacheKey::new("POST", "https://x", Some(b"b"));
        assert_ne!(a, b);
        let cache = HttpResponseCache::new(Duration::from_secs(60));
        cache.insert(a.clone(), serde_json::json!(1));
        let (out, _) = cache.lookup(&b);
        assert_eq!(out, CacheOutcome::Miss);
    }

    #[test]
    fn missing_and_empty_body_share_a_key() {
        let a = CacheKey::new("post", "https://x", None);
        let b = CacheKey::new("POST", "https://x", Some(&b""[..]));
        assert_eq!(a, b);
        assert_eq!(a.method, "POST");
    }

    #[test]
    fn method_difference_separates_keys() {
        let a = CacheKey::new("GET", "https://x", None);
        let b = CacheKey::new("POST", "https://x", None);
        assert_ne!(a, b);
    }

    #[test]
    fn evict_expired_drops_old_entries() {
        let cache = HttpResponseCache::new(Duration::from_millis(20));
        for i in 0..5 {
            cache.insert(
                CacheKey::new("GET", &format!("https://x/{i}"), None),
                serde_json::json!(i),
            );
        }
        std::thread::sleep(Duration::from_millis(40));
        let evicted = cache.evict_expired();
        assert_eq!(evicted, 5);
        assert_eq!(cache.len(), 0);
        // Bulk eviction is reflected in the `expirations` counter.
        assert_eq!(cache.stats().2, 5);
    }

    #[test]
    fn stats_counters_increment() {
        let cache = HttpResponseCache::new(Duration::from_secs(60));
        let key = CacheKey::new("GET", "https://x", None);
        let (_, _) = cache.lookup(&key);
        cache.insert(key.clone(), serde_json::json!("v"));
        let (_, _) = cache.lookup(&key);
        let (_, _) = cache.lookup(&key);
        let (hits, misses, _exp) = cache.stats();
        assert_eq!(hits, 2);
        assert_eq!(misses, 1);
    }
}