Skip to main content

gatel_core/hoops/
cache.rs

1use std::collections::HashMap;
2use std::sync::Mutex;
3use std::time::{Duration, Instant};
4
5use bytes::Bytes;
6use http::header::{
7    AGE, CACHE_CONTROL, ETAG, IF_MODIFIED_SINCE, IF_NONE_MATCH, LAST_MODIFIED, SET_COOKIE,
8};
9use http::{HeaderMap, HeaderValue, Method, StatusCode};
10use salvo::{Depot, FlowCtrl, Request, Response, async_trait};
11use tracing::debug;
12
13use crate::cache_control::{build_vary_key, parse_cache_control};
14use crate::config::CacheConfig;
15
16/// A cached response entry.
17#[derive(Clone)]
18struct CacheEntry {
19    status: StatusCode,
20    headers: HeaderMap,
21    body: Bytes,
22    inserted_at: Instant,
23    max_age: Duration,
24    etag: Option<String>,
25    last_modified: Option<String>,
26}
27
28impl CacheEntry {
29    /// Check whether this entry is still fresh.
30    fn is_fresh(&self) -> bool {
31        self.inserted_at.elapsed() < self.max_age
32    }
33
34    /// Compute the Age header value in seconds.
35    fn age_secs(&self) -> u64 {
36        self.inserted_at.elapsed().as_secs()
37    }
38}
39
/// Identity of a cached response: host, method, path and Vary discriminator.
///
/// `vary_key` is derived from the request header values named by the
/// response's `Vary` header, so that variants of the same URL are stored
/// under distinct keys.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct CacheKey {
    /// Value of the request `Host` header.
    host: String,
    /// Request method, in its string form.
    method: String,
    /// Request path, including the query string when present.
    path: String,
    /// Discriminator built from the request headers selected by `Vary`.
    vary_key: String,
}
49
50/// In-memory LRU response cache middleware.
51///
52/// Caches GET/HEAD responses with cacheable status codes (200, 301, 302, 304).
53/// Respects `Cache-Control` directives (`no-store`, `no-cache`, `max-age`, `s-maxage`).
54/// Includes `Vary` header values in the cache key.
55/// Supports conditional requests via `If-None-Match` (ETag) and `If-Modified-Since`.
56/// Skips caching responses with `Set-Cookie` headers.
57pub struct CacheHoop {
58    config: CacheConfig,
59    store: Mutex<CacheStore>,
60}
61
62struct CacheStore {
63    entries: HashMap<CacheKey, CacheEntry>,
64    /// Access order for LRU eviction — most recently accessed at the end.
65    access_order: Vec<CacheKey>,
66    max_entries: usize,
67}
68
69impl CacheStore {
70    fn new(max_entries: usize) -> Self {
71        Self {
72            entries: HashMap::new(),
73            access_order: Vec::new(),
74            max_entries,
75        }
76    }
77
78    /// Get an entry if it exists, updating LRU order.
79    fn get(&mut self, key: &CacheKey) -> Option<CacheEntry> {
80        let entry = self.entries.get(key)?;
81        if !entry.is_fresh() {
82            // Expired — remove it.
83            self.entries.remove(key);
84            self.access_order.retain(|k| k != key);
85            return None;
86        }
87        let entry = entry.clone();
88        // Move to end of access order (most recently used).
89        self.access_order.retain(|k| k != key);
90        self.access_order.push(key.clone());
91        Some(entry)
92    }
93
94    /// Insert an entry, evicting the least recently used if at capacity.
95    fn insert(&mut self, key: CacheKey, entry: CacheEntry) {
96        // If key already exists, remove old access order entry.
97        if self.entries.contains_key(&key) {
98            self.access_order.retain(|k| k != &key);
99        }
100
101        // Evict LRU entries if at capacity.
102        while self.entries.len() >= self.max_entries && !self.access_order.is_empty() {
103            let evicted = self.access_order.remove(0);
104            self.entries.remove(&evicted);
105            debug!(key = ?evicted.path, "evicted LRU cache entry");
106        }
107
108        self.access_order.push(key.clone());
109        self.entries.insert(key, entry);
110    }
111}
112
113impl CacheHoop {
114    pub fn new(config: &CacheConfig) -> Self {
115        debug!(
116            max_entries = config.max_entries,
117            max_entry_size = config.max_entry_size,
118            default_max_age = config.default_max_age.as_secs(),
119            "cache middleware initialized"
120        );
121        Self {
122            config: config.clone(),
123            store: Mutex::new(CacheStore::new(config.max_entries)),
124        }
125    }
126}
127
128#[async_trait]
129impl salvo::Handler for CacheHoop {
130    async fn handle(
131        &self,
132        req: &mut Request,
133        depot: &mut Depot,
134        res: &mut Response,
135        ctrl: &mut FlowCtrl,
136    ) {
137        let method = req.method().clone();
138
139        // Only cache GET and HEAD requests.
140        if method != Method::GET && method != Method::HEAD {
141            ctrl.call_next(req, depot, res).await;
142            return;
143        }
144
145        // Check request Cache-Control for no-store / no-cache.
146        let req_cache_control = parse_cache_control(
147            req.headers()
148                .get(CACHE_CONTROL)
149                .and_then(|v| v.to_str().ok())
150                .unwrap_or(""),
151        );
152        if req_cache_control.no_store {
153            ctrl.call_next(req, depot, res).await;
154            let _ = res.add_header("X-Cache", "BYPASS", true);
155            return;
156        }
157
158        let host = req
159            .headers()
160            .get(http::header::HOST)
161            .and_then(|v| v.to_str().ok())
162            .unwrap_or("")
163            .to_string();
164        let path = req
165            .uri()
166            .path_and_query()
167            .map(|pq| pq.to_string())
168            .unwrap_or_else(|| req.uri().path().to_string());
169
170        // Build a preliminary cache key (without vary — we check multiple vary keys).
171        let base_key = CacheKey {
172            host: host.clone(),
173            method: method.to_string(),
174            path: path.clone(),
175            vary_key: String::new(),
176        };
177
178        // Attempt cache lookup.
179        let cached = {
180            let mut store = self.store.lock().unwrap();
181            if let Some(entry) = store.get(&base_key) {
182                Some((base_key.clone(), entry))
183            } else {
184                let mut found = None;
185                let keys: Vec<CacheKey> = store
186                    .entries
187                    .keys()
188                    .filter(|k| k.host == host && k.method == method.as_str() && k.path == path)
189                    .cloned()
190                    .collect();
191                for key in keys {
192                    if let Some(entry) = store.get(&key) {
193                        let vary_key = build_vary_key(&entry.headers, req.headers());
194                        if key.vary_key == vary_key {
195                            found = Some((key, entry));
196                            break;
197                        }
198                    }
199                }
200                found
201            }
202        };
203
204        if let Some((_key, entry)) = cached {
205            if req_cache_control.no_cache {
206                debug!(path = path.as_str(), "no-cache directive, revalidating");
207            } else {
208                // Check conditional request: If-None-Match.
209                if let Some(inm) = req.headers().get(IF_NONE_MATCH)
210                    && let (Ok(inm_str), Some(etag)) = (inm.to_str(), &entry.etag)
211                    && inm_str.trim_matches('"') == etag.trim_matches('"')
212                {
213                    debug!(path = path.as_str(), "conditional cache hit (ETag), 304");
214                    res.status_code(StatusCode::NOT_MODIFIED);
215                    if let Ok(val) = etag.parse::<HeaderValue>() {
216                        res.headers_mut().insert(ETAG, val);
217                    }
218                    res.headers_mut()
219                        .insert(AGE, HeaderValue::from(entry.age_secs()));
220                    let _ = res.add_header("X-Cache", "HIT", true);
221                    ctrl.skip_rest();
222                    return;
223                }
224
225                // Check conditional request: If-Modified-Since.
226                if let Some(ims) = req.headers().get(IF_MODIFIED_SINCE)
227                    && let (Ok(ims_str), Some(lm)) = (ims.to_str(), &entry.last_modified)
228                    && ims_str == lm
229                {
230                    debug!(
231                        path = path.as_str(),
232                        "conditional cache hit (If-Modified-Since), 304"
233                    );
234                    res.status_code(StatusCode::NOT_MODIFIED);
235                    res.headers_mut()
236                        .insert(AGE, HeaderValue::from(entry.age_secs()));
237                    let _ = res.add_header("X-Cache", "HIT", true);
238                    ctrl.skip_rest();
239                    return;
240                }
241
242                debug!(
243                    path = path.as_str(),
244                    age = entry.age_secs(),
245                    "cache hit, serving cached response"
246                );
247
248                // Build response from cached entry.
249                res.status_code(entry.status);
250                *res.headers_mut() = entry.headers.clone();
251                res.headers_mut()
252                    .insert(AGE, HeaderValue::from(entry.age_secs()));
253                let _ = res.add_header("X-Cache", "HIT", true);
254                res.body(entry.body.to_vec());
255                ctrl.skip_rest();
256                return;
257            }
258        }
259
260        // Cache miss — call downstream.
261        ctrl.call_next(req, depot, res).await;
262
263        // Determine if the response is cacheable.
264        let status = res.status_code.unwrap_or(StatusCode::OK);
265        let cacheable_status = matches!(
266            status,
267            StatusCode::OK
268                | StatusCode::MOVED_PERMANENTLY
269                | StatusCode::FOUND
270                | StatusCode::NOT_MODIFIED
271        );
272
273        if !cacheable_status {
274            let _ = res.add_header("X-Cache", "BYPASS", true);
275            return;
276        }
277
278        // Skip if response has Set-Cookie.
279        if res.headers().contains_key(SET_COOKIE) {
280            debug!(path = path.as_str(), "response has Set-Cookie, not caching");
281            let _ = res.add_header("X-Cache", "BYPASS", true);
282            return;
283        }
284
285        // Parse response Cache-Control.
286        let resp_cache_control = parse_cache_control(
287            res.headers()
288                .get(CACHE_CONTROL)
289                .and_then(|v| v.to_str().ok())
290                .unwrap_or(""),
291        );
292
293        if resp_cache_control.no_store {
294            let _ = res.add_header("X-Cache", "BYPASS", true);
295            return;
296        }
297
298        // Determine max-age: s-maxage > max-age > default.
299        let max_age = resp_cache_control
300            .s_maxage
301            .or(resp_cache_control.max_age)
302            .unwrap_or(self.config.default_max_age);
303
304        if max_age.is_zero() {
305            let _ = res.add_header("X-Cache", "BYPASS", true);
306            return;
307        }
308
309        // Collect body to bytes for caching.
310        let body = res.take_body();
311        let body_bytes = match super::compress::collect_res_body_bytes(body).await {
312            Ok(b) => Bytes::from(b),
313            Err(_) => {
314                let _ = res.add_header("X-Cache", "BYPASS", true);
315                return;
316            }
317        };
318
319        // Check entry size limit.
320        if body_bytes.len() > self.config.max_entry_size {
321            debug!(
322                path = path.as_str(),
323                size = body_bytes.len(),
324                max = self.config.max_entry_size,
325                "response too large to cache"
326            );
327            let _ = res.add_header("X-Cache", "BYPASS", true);
328            res.body(body_bytes.to_vec());
329            return;
330        }
331
332        let etag = res
333            .headers()
334            .get(ETAG)
335            .and_then(|v| v.to_str().ok())
336            .map(|s| s.to_string());
337        let last_modified = res
338            .headers()
339            .get(LAST_MODIFIED)
340            .and_then(|v| v.to_str().ok())
341            .map(|s| s.to_string());
342
343        let vary_key = String::new();
344
345        let entry = CacheEntry {
346            status,
347            headers: res.headers().clone(),
348            body: body_bytes.clone(),
349            inserted_at: Instant::now(),
350            max_age,
351            etag,
352            last_modified,
353        };
354
355        let cache_key = CacheKey {
356            host,
357            method: method.to_string(),
358            path: path.clone(),
359            vary_key,
360        };
361
362        {
363            let mut store = self.store.lock().unwrap();
364            store.insert(cache_key, entry);
365        }
366
367        debug!(
368            path = path.as_str(),
369            max_age = max_age.as_secs(),
370            size = body_bytes.len(),
371            "cached response"
372        );
373
374        let _ = res.add_header("X-Cache", "MISS", true);
375        res.body(body_bytes.to_vec());
376    }
377}