use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use axum::body::{to_bytes, Body};
use axum::http::HeaderValue;
use axum::response::Response;
use dashmap::DashMap;
use futures::future::BoxFuture;
use once_cell::sync::Lazy;
use crate::web::context::RequestContext;
use crate::web::interceptors::{Interceptor, NextHandler};
/// A buffered response plus the metadata needed to replay it from the cache.
#[derive(Clone)]
struct CachedEntry {
    /// HTTP status code of the stored response.
    status: u16,
    /// The stored response's `content-type` header, if it had one.
    content_type: Option<HeaderValue>,
    /// Buffered response body.
    bytes: bytes::Bytes,
    /// When the entry was stored; its age is measured from this instant.
    inserted_at: Instant,
    /// How long after `inserted_at` the entry may still be served.
    ttl: Duration,
}

impl CachedEntry {
    /// Returns `true` while strictly less than `ttl` has elapsed since `inserted_at`.
    #[inline]
    fn is_fresh(&self) -> bool {
        let age = self.inserted_at.elapsed();
        age < self.ttl
    }
}
/// Process-wide response cache, keyed by the string produced by `key_for`.
static STORE: Lazy<DashMap<String, CachedEntry>> = Lazy::new(DashMap::new);

/// Maximum number of entries; inserts of new keys are rejected once it is reached.
static CAPACITY: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(10_000);

/// Requests answered from a fresh cache entry.
static HITS: AtomicU64 = AtomicU64::new(0);
/// Cacheable (non-zero TTL) requests that were not answered from the cache.
static MISSES: AtomicU64 = AtomicU64::new(0);
/// Responses that were not stored because the cache was at capacity.
static REJECTED: AtomicU64 = AtomicU64::new(0);
/// Sets the maximum number of entries the cache will hold.
///
/// Only affects later inserts: lowering the limit does not evict entries
/// that are already stored.
#[doc(hidden)]
pub fn set_capacity(max_entries: usize) {
    let ordering = Ordering::Relaxed;
    CAPACITY.store(max_entries, ordering);
}
/// Drops every cache entry whose TTL has elapsed.
///
/// Without this, a stale entry is only removed when a request for the same
/// key happens to find it expired.
pub(crate) fn sweep_expired() {
    STORE.retain(|_key, entry| entry.is_fresh());
}
/// Spawns a detached Tokio task that calls `sweep_expired` every `interval`.
///
/// Missed ticks are skipped rather than replayed back to back. The task loops
/// forever, and its `JoinHandle` is not returned.
///
/// # Panics
///
/// Panics if `interval` is zero, or if called outside a Tokio runtime
/// (`tokio::spawn` requires one).
pub(crate) fn spawn_sweeper(interval: Duration) {
    // `tokio::time::interval` rejects a zero period by panicking. Inside the
    // detached task, that panic would silently kill the sweeper and leave
    // expired entries in place, so reject it here where the caller sees it.
    assert!(!interval.is_zero(), "cache sweeper interval must be non-zero");
    tokio::spawn(async move {
        let mut tick = tokio::time::interval(interval);
        tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
        loop {
            tick.tick().await;
            sweep_expired();
        }
    });
}
/// Point-in-time snapshot of the cache counters, as produced by `stats`.
#[derive(Debug, Clone, Copy, serde::Serialize)]
pub struct CacheStats {
    /// Requests answered from a fresh cache entry.
    pub hits: u64,
    /// Cacheable (non-zero TTL) requests that were not answered from the cache.
    pub misses: u64,
    /// Entries currently stored, including expired ones not yet removed.
    pub entries: u64,
    /// Responses that were not stored because the cache was at capacity.
    pub rejected: u64,
}
/// Returns a snapshot of the cache counters and the current entry count.
///
/// Each value is read independently with relaxed ordering. Under concurrent
/// traffic, the fields are therefore not guaranteed to be mutually consistent.
#[inline]
pub fn stats() -> CacheStats {
    let hits = HITS.load(Ordering::Relaxed);
    let misses = MISSES.load(Ordering::Relaxed);
    let entries = STORE.len() as u64;
    let rejected = REJECTED.load(Ordering::Relaxed);
    CacheStats {
        hits,
        misses,
        entries,
        rejected,
    }
}
/// Removes every cached entry.
///
/// The hit, miss and rejected counters are left unchanged.
pub fn clear() {
    let store = &*STORE;
    store.clear();
}
/// Builds the cache key for a request.
///
/// If the route spec supplies a non-empty `cache_key`, that string is used
/// verbatim, so method, path and query are not part of the key. Otherwise the
/// key is `"<method> <path>"`, followed by `?<query>` when the query string is
/// present and non-empty.
#[inline]
fn key_for(ctx: &RequestContext) -> String {
    let custom = ctx.route_spec().filter(|spec| !spec.cache_key.is_empty());
    if let Some(spec) = custom {
        return spec.cache_key.to_owned();
    }
    match ctx.query_string().unwrap_or("") {
        "" => format!("{} {}", ctx.method(), ctx.path()),
        query => format!("{} {}?{}", ctx.method(), ctx.path(), query),
    }
}
pub struct CacheInterceptor;
impl Interceptor for CacheInterceptor {
fn around(
&'static self,
ctx: RequestContext,
next: NextHandler,
) -> BoxFuture<'static, Response> {
Box::pin(async move {
let ttl_secs = ctx.route_spec().map(|s| s.cache_ttl_secs).unwrap_or(0);
if ttl_secs == 0 {
return next.run(ctx).await;
}
let key = key_for(&ctx);
if let Some(entry) = STORE.get(&key) {
if entry.is_fresh() {
HITS.fetch_add(1, Ordering::Relaxed);
let mut resp = Response::builder()
.status(entry.status)
.header("x-cache", "HIT")
.body(Body::from(entry.bytes.clone()))
.expect("cache hit response builds");
if let Some(ct) = entry.content_type.clone() {
resp.headers_mut().insert("content-type", ct);
}
return resp;
} else {
drop(entry);
STORE.remove(&key);
}
}
MISSES.fetch_add(1, Ordering::Relaxed);
let resp = next.run(ctx).await;
if !resp.status().is_success() {
return resp;
}
let (mut parts, body) = resp.into_parts();
const MAX_CACHE_BODY: usize = 8 * 1024 * 1024;
let bytes = to_bytes(body, MAX_CACHE_BODY).await.unwrap_or_default();
let content_type = parts.headers.get("content-type").cloned();
if STORE.len() < CAPACITY.load(Ordering::Relaxed) || STORE.contains_key(&key) {
STORE.insert(
key,
CachedEntry {
bytes: bytes.clone(),
content_type,
status: parts.status.as_u16(),
inserted_at: Instant::now(),
ttl: Duration::from_secs(ttl_secs),
},
);
} else {
REJECTED.fetch_add(1, Ordering::Relaxed);
}
parts
.headers
.insert("x-cache", HeaderValue::from_static("MISS"));
Response::from_parts(parts, Body::from(bytes))
})
}
}