use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};

use axum::body::{to_bytes, Body};
use axum::http::HeaderValue;
use axum::http::StatusCode;
use axum::response::Response;
use dashmap::DashMap;
use futures::future::BoxFuture;
use once_cell::sync::Lazy;

use crate::web::context::RequestContext;
use crate::web::interceptors::{Interceptor, NextHandler};
32
33#[derive(Clone)]
35struct CachedEntry {
36 bytes: bytes::Bytes,
37 content_type: Option<HeaderValue>,
38 status: u16,
39 inserted_at: Instant,
40 ttl: Duration,
41}
42
43impl CachedEntry {
44 #[inline]
45 fn is_fresh(&self) -> bool {
46 self.inserted_at.elapsed() < self.ttl
47 }
48}
49
50static STORE: Lazy<DashMap<String, CachedEntry>> = Lazy::new(DashMap::new);
53static HITS: AtomicU64 = AtomicU64::new(0);
54static MISSES: AtomicU64 = AtomicU64::new(0);
55
56#[derive(Debug, Clone, Copy, serde::Serialize)]
58pub struct CacheStats {
59 pub hits: u64,
60 pub misses: u64,
61 pub entries: u64,
62}
63
64#[inline]
65pub fn stats() -> CacheStats {
66 CacheStats {
67 hits: HITS.load(Ordering::Relaxed),
68 misses: MISSES.load(Ordering::Relaxed),
69 entries: STORE.len() as u64,
70 }
71}
72
73pub fn clear() {
75 STORE.clear();
76}
77
78#[inline]
82fn key_for(ctx: &RequestContext) -> String {
83 if let Some(spec) = ctx.route_spec() {
84 if !spec.cache_key.is_empty() {
85 return spec.cache_key.to_owned();
86 }
87 }
88 let q = ctx.query_string().unwrap_or("");
89 if q.is_empty() {
90 format!("{} {}", ctx.method(), ctx.path())
91 } else {
92 format!("{} {}?{}", ctx.method(), ctx.path(), q)
93 }
94}
95
/// Interceptor that serves and populates the in-memory response cache for
/// routes whose spec sets a non-zero `cache_ttl_secs`.
pub struct CacheInterceptor;
100
101impl Interceptor for CacheInterceptor {
102 fn around(
103 &'static self,
104 ctx: RequestContext,
105 next: NextHandler,
106 ) -> BoxFuture<'static, Response> {
107 Box::pin(async move {
108 let ttl_secs = ctx.route_spec().map(|s| s.cache_ttl_secs).unwrap_or(0);
109 if ttl_secs == 0 {
110 return next.run(ctx).await;
112 }
113 let key = key_for(&ctx);
114
115 if let Some(entry) = STORE.get(&key) {
117 if entry.is_fresh() {
118 HITS.fetch_add(1, Ordering::Relaxed);
119 let mut resp = Response::builder()
120 .status(entry.status)
121 .header("x-cache", "HIT")
122 .body(Body::from(entry.bytes.clone()))
123 .expect("cache hit response builds");
124 if let Some(ct) = entry.content_type.clone() {
125 resp.headers_mut().insert("content-type", ct);
126 }
127 return resp;
128 } else {
129 drop(entry);
132 STORE.remove(&key);
133 }
134 }
135
136 MISSES.fetch_add(1, Ordering::Relaxed);
137 let resp = next.run(ctx).await;
139 if !resp.status().is_success() {
140 return resp;
141 }
142
143 let (mut parts, body) = resp.into_parts();
144 const MAX_CACHE_BODY: usize = 8 * 1024 * 1024;
145 let bytes = to_bytes(body, MAX_CACHE_BODY).await.unwrap_or_default();
146 let content_type = parts.headers.get("content-type").cloned();
147
148 STORE.insert(
149 key,
150 CachedEntry {
151 bytes: bytes.clone(),
152 content_type,
153 status: parts.status.as_u16(),
154 inserted_at: Instant::now(),
155 ttl: Duration::from_secs(ttl_secs),
156 },
157 );
158 parts
159 .headers
160 .insert("x-cache", HeaderValue::from_static("MISS"));
161 Response::from_parts(parts, Body::from(bytes))
162 })
163 }
164}