1use std::collections::HashMap;
2use std::sync::Mutex;
3use std::time::{Duration, Instant};
4
5use bytes::Bytes;
6use http::header::{
7 AGE, CACHE_CONTROL, ETAG, IF_MODIFIED_SINCE, IF_NONE_MATCH, LAST_MODIFIED, SET_COOKIE,
8};
9use http::{HeaderMap, HeaderValue, Method, StatusCode};
10use salvo::{Depot, FlowCtrl, Request, Response, async_trait};
11use tracing::debug;
12
13use crate::cache_control::{build_vary_key, parse_cache_control};
14use crate::config::CacheConfig;
15
16#[derive(Clone)]
18struct CacheEntry {
19 status: StatusCode,
20 headers: HeaderMap,
21 body: Bytes,
22 inserted_at: Instant,
23 max_age: Duration,
24 etag: Option<String>,
25 last_modified: Option<String>,
26}
27
28impl CacheEntry {
29 fn is_fresh(&self) -> bool {
31 self.inserted_at.elapsed() < self.max_age
32 }
33
34 fn age_secs(&self) -> u64 {
36 self.inserted_at.elapsed().as_secs()
37 }
38}
39
/// Identity of a stored response: the request target plus a discriminator that separates
/// `Vary`-dependent representations of the same target.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct CacheKey {
    /// Host the request was addressed to.
    host: String,
    /// Request method as text (e.g. `GET`, `HEAD`).
    method: String,
    /// Request path, including the query string when there is one.
    path: String,
    /// Representation discriminator; compared against `build_vary_key` output on lookup.
    vary_key: String,
}
49
50pub struct CacheHoop {
58 config: CacheConfig,
59 store: Mutex<CacheStore>,
60}
61
62struct CacheStore {
63 entries: HashMap<CacheKey, CacheEntry>,
64 access_order: Vec<CacheKey>,
66 max_entries: usize,
67}
68
69impl CacheStore {
70 fn new(max_entries: usize) -> Self {
71 Self {
72 entries: HashMap::new(),
73 access_order: Vec::new(),
74 max_entries,
75 }
76 }
77
78 fn get(&mut self, key: &CacheKey) -> Option<CacheEntry> {
80 let entry = self.entries.get(key)?;
81 if !entry.is_fresh() {
82 self.entries.remove(key);
84 self.access_order.retain(|k| k != key);
85 return None;
86 }
87 let entry = entry.clone();
88 self.access_order.retain(|k| k != key);
90 self.access_order.push(key.clone());
91 Some(entry)
92 }
93
94 fn insert(&mut self, key: CacheKey, entry: CacheEntry) {
96 if self.entries.contains_key(&key) {
98 self.access_order.retain(|k| k != &key);
99 }
100
101 while self.entries.len() >= self.max_entries && !self.access_order.is_empty() {
103 let evicted = self.access_order.remove(0);
104 self.entries.remove(&evicted);
105 debug!(key = ?evicted.path, "evicted LRU cache entry");
106 }
107
108 self.access_order.push(key.clone());
109 self.entries.insert(key, entry);
110 }
111}
112
113impl CacheHoop {
114 pub fn new(config: &CacheConfig) -> Self {
115 debug!(
116 max_entries = config.max_entries,
117 max_entry_size = config.max_entry_size,
118 default_max_age = config.default_max_age.as_secs(),
119 "cache middleware initialized"
120 );
121 Self {
122 config: config.clone(),
123 store: Mutex::new(CacheStore::new(config.max_entries)),
124 }
125 }
126}
127
128#[async_trait]
129impl salvo::Handler for CacheHoop {
130 async fn handle(
131 &self,
132 req: &mut Request,
133 depot: &mut Depot,
134 res: &mut Response,
135 ctrl: &mut FlowCtrl,
136 ) {
137 let method = req.method().clone();
138
139 if method != Method::GET && method != Method::HEAD {
141 ctrl.call_next(req, depot, res).await;
142 return;
143 }
144
145 let req_cache_control = parse_cache_control(
147 req.headers()
148 .get(CACHE_CONTROL)
149 .and_then(|v| v.to_str().ok())
150 .unwrap_or(""),
151 );
152 if req_cache_control.no_store {
153 ctrl.call_next(req, depot, res).await;
154 let _ = res.add_header("X-Cache", "BYPASS", true);
155 return;
156 }
157
158 let host = req
159 .headers()
160 .get(http::header::HOST)
161 .and_then(|v| v.to_str().ok())
162 .unwrap_or("")
163 .to_string();
164 let path = req
165 .uri()
166 .path_and_query()
167 .map(|pq| pq.to_string())
168 .unwrap_or_else(|| req.uri().path().to_string());
169
170 let base_key = CacheKey {
172 host: host.clone(),
173 method: method.to_string(),
174 path: path.clone(),
175 vary_key: String::new(),
176 };
177
178 let cached = {
180 let mut store = self.store.lock().unwrap();
181 if let Some(entry) = store.get(&base_key) {
182 Some((base_key.clone(), entry))
183 } else {
184 let mut found = None;
185 let keys: Vec<CacheKey> = store
186 .entries
187 .keys()
188 .filter(|k| k.host == host && k.method == method.as_str() && k.path == path)
189 .cloned()
190 .collect();
191 for key in keys {
192 if let Some(entry) = store.get(&key) {
193 let vary_key = build_vary_key(&entry.headers, req.headers());
194 if key.vary_key == vary_key {
195 found = Some((key, entry));
196 break;
197 }
198 }
199 }
200 found
201 }
202 };
203
204 if let Some((_key, entry)) = cached {
205 if req_cache_control.no_cache {
206 debug!(path = path.as_str(), "no-cache directive, revalidating");
207 } else {
208 if let Some(inm) = req.headers().get(IF_NONE_MATCH)
210 && let (Ok(inm_str), Some(etag)) = (inm.to_str(), &entry.etag)
211 && inm_str.trim_matches('"') == etag.trim_matches('"')
212 {
213 debug!(path = path.as_str(), "conditional cache hit (ETag), 304");
214 res.status_code(StatusCode::NOT_MODIFIED);
215 if let Ok(val) = etag.parse::<HeaderValue>() {
216 res.headers_mut().insert(ETAG, val);
217 }
218 res.headers_mut()
219 .insert(AGE, HeaderValue::from(entry.age_secs()));
220 let _ = res.add_header("X-Cache", "HIT", true);
221 ctrl.skip_rest();
222 return;
223 }
224
225 if let Some(ims) = req.headers().get(IF_MODIFIED_SINCE)
227 && let (Ok(ims_str), Some(lm)) = (ims.to_str(), &entry.last_modified)
228 && ims_str == lm
229 {
230 debug!(
231 path = path.as_str(),
232 "conditional cache hit (If-Modified-Since), 304"
233 );
234 res.status_code(StatusCode::NOT_MODIFIED);
235 res.headers_mut()
236 .insert(AGE, HeaderValue::from(entry.age_secs()));
237 let _ = res.add_header("X-Cache", "HIT", true);
238 ctrl.skip_rest();
239 return;
240 }
241
242 debug!(
243 path = path.as_str(),
244 age = entry.age_secs(),
245 "cache hit, serving cached response"
246 );
247
248 res.status_code(entry.status);
250 *res.headers_mut() = entry.headers.clone();
251 res.headers_mut()
252 .insert(AGE, HeaderValue::from(entry.age_secs()));
253 let _ = res.add_header("X-Cache", "HIT", true);
254 res.body(entry.body.to_vec());
255 ctrl.skip_rest();
256 return;
257 }
258 }
259
260 ctrl.call_next(req, depot, res).await;
262
263 let status = res.status_code.unwrap_or(StatusCode::OK);
265 let cacheable_status = matches!(
266 status,
267 StatusCode::OK
268 | StatusCode::MOVED_PERMANENTLY
269 | StatusCode::FOUND
270 | StatusCode::NOT_MODIFIED
271 );
272
273 if !cacheable_status {
274 let _ = res.add_header("X-Cache", "BYPASS", true);
275 return;
276 }
277
278 if res.headers().contains_key(SET_COOKIE) {
280 debug!(path = path.as_str(), "response has Set-Cookie, not caching");
281 let _ = res.add_header("X-Cache", "BYPASS", true);
282 return;
283 }
284
285 let resp_cache_control = parse_cache_control(
287 res.headers()
288 .get(CACHE_CONTROL)
289 .and_then(|v| v.to_str().ok())
290 .unwrap_or(""),
291 );
292
293 if resp_cache_control.no_store {
294 let _ = res.add_header("X-Cache", "BYPASS", true);
295 return;
296 }
297
298 let max_age = resp_cache_control
300 .s_maxage
301 .or(resp_cache_control.max_age)
302 .unwrap_or(self.config.default_max_age);
303
304 if max_age.is_zero() {
305 let _ = res.add_header("X-Cache", "BYPASS", true);
306 return;
307 }
308
309 let body = res.take_body();
311 let body_bytes = match super::compress::collect_res_body_bytes(body).await {
312 Ok(b) => Bytes::from(b),
313 Err(_) => {
314 let _ = res.add_header("X-Cache", "BYPASS", true);
315 return;
316 }
317 };
318
319 if body_bytes.len() > self.config.max_entry_size {
321 debug!(
322 path = path.as_str(),
323 size = body_bytes.len(),
324 max = self.config.max_entry_size,
325 "response too large to cache"
326 );
327 let _ = res.add_header("X-Cache", "BYPASS", true);
328 res.body(body_bytes.to_vec());
329 return;
330 }
331
332 let etag = res
333 .headers()
334 .get(ETAG)
335 .and_then(|v| v.to_str().ok())
336 .map(|s| s.to_string());
337 let last_modified = res
338 .headers()
339 .get(LAST_MODIFIED)
340 .and_then(|v| v.to_str().ok())
341 .map(|s| s.to_string());
342
343 let vary_key = String::new();
344
345 let entry = CacheEntry {
346 status,
347 headers: res.headers().clone(),
348 body: body_bytes.clone(),
349 inserted_at: Instant::now(),
350 max_age,
351 etag,
352 last_modified,
353 };
354
355 let cache_key = CacheKey {
356 host,
357 method: method.to_string(),
358 path: path.clone(),
359 vary_key,
360 };
361
362 {
363 let mut store = self.store.lock().unwrap();
364 store.insert(cache_key, entry);
365 }
366
367 debug!(
368 path = path.as_str(),
369 max_age = max_age.as_secs(),
370 size = body_bytes.len(),
371 "cached response"
372 );
373
374 let _ = res.add_header("X-Cache", "MISS", true);
375 res.body(body_bytes.to_vec());
376 }
377}