Skip to main content

autumn_web/cache/
layer.rs

1//! Tower middleware that caches HTTP GET responses.
2//!
3//! Only caches `GET` requests that produce `200 OK` responses.
4//! Non-GET methods and non-200 responses pass through untouched.
5//!
6//! # Usage
7//!
8//! ```rust,ignore
9//! use autumn_web::prelude::*;
10//! use autumn_web::cache::{CacheResponseLayer, MokaCache};
11//!
12//! let store = MokaCache::builder()
13//!     .max_capacity(1000)
14//!     .ttl(std::time::Duration::from_secs(300))
15//!     .build();
16//!
17//! #[get("/users/{id}")]
18//! #[intercept(CacheResponseLayer::from_cache(store))]
19//! async fn get_user(Path(id): Path<i32>) -> Json<User> { ... }
20//! ```
21
22use std::convert::Infallible;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
26use axum::body::Body;
27use axum::http::{Method, StatusCode};
28use http::Request;
29use http_body_util::BodyExt;
30use tower::{Layer, Service};
31
32use super::Cache;
33
34/// A cached HTTP response: status, headers, and body bytes.
35#[derive(Clone, serde::Deserialize, serde::Serialize)]
36struct CachedResponse {
37    status: u16,
38    headers: Vec<CachedHeader>,
39    body: Vec<u8>,
40}
41
42#[derive(Clone, serde::Deserialize, serde::Serialize)]
43struct CachedHeader {
44    name: String,
45    value: Vec<u8>,
46}
47
48fn cached_response_from_parts(
49    parts: &http::response::Parts,
50    body: &bytes::Bytes,
51) -> CachedResponse {
52    let headers = parts
53        .headers
54        .iter()
55        .map(|(name, value)| CachedHeader {
56            name: name.as_str().to_owned(),
57            value: value.as_bytes().to_vec(),
58        })
59        .collect();
60
61    CachedResponse {
62        status: parts.status.as_u16(),
63        headers,
64        body: body.to_vec(),
65    }
66}
67
68fn cached_response_into_response(cached: CachedResponse) -> Option<axum::response::Response> {
69    let status = StatusCode::from_u16(cached.status).ok()?;
70    let mut builder = axum::response::Response::builder().status(status);
71    let headers = builder.headers_mut()?;
72
73    for cached_header in cached.headers {
74        let name = http::HeaderName::from_bytes(cached_header.name.as_bytes()).ok()?;
75        let value = http::HeaderValue::from_bytes(&cached_header.value).ok()?;
76        headers.append(name, value);
77    }
78
79    builder.body(Body::from(cached.body)).ok()
80}
81
82/// Tower layer that caches HTTP GET responses.
83///
84/// Wrap around a handler via `#[intercept(CacheResponseLayer::from_cache(store))]`
85/// or construct manually and apply with `.layer()`.
86///
87/// Caching rules:
88/// - Only `GET` requests are cached.
89/// - Only `200 OK` responses are cached.
90/// - The cache key is the request URI path + query string.
91#[derive(Clone)]
92pub struct CacheResponseLayer {
93    store: Arc<dyn Cache>,
94}
95
96impl CacheResponseLayer {
97    /// Create a layer backed by the given cache store.
98    pub fn from_cache(store: impl Cache + 'static) -> Self {
99        Self {
100            store: Arc::new(store),
101        }
102    }
103
104    /// Create from an existing `Arc<dyn Cache>`.
105    pub fn from_shared(store: Arc<dyn Cache>) -> Self {
106        Self { store }
107    }
108
109    /// Create from the global cache registered in `AppState`.
110    ///
111    /// Returns `None` when no global cache has been registered (i.e. the app
112    /// is running with the default per-function Moka caches only).
113    #[must_use]
114    pub fn from_app(state: &crate::state::AppState) -> Option<Self> {
115        state.cache().map(Self::from_shared)
116    }
117}
118
119impl<S> Layer<S> for CacheResponseLayer {
120    type Service = CacheResponseService<S>;
121
122    fn layer(&self, inner: S) -> Self::Service {
123        CacheResponseService {
124            inner,
125            store: self.store.clone(),
126        }
127    }
128}
129
130/// The [`Service`] produced by [`CacheResponseLayer`].
131#[derive(Clone)]
132pub struct CacheResponseService<S> {
133    inner: S,
134    store: Arc<dyn Cache>,
135}
136
137impl<S> Service<Request<Body>> for CacheResponseService<S>
138where
139    S: Service<Request<Body>, Response = axum::response::Response, Error = Infallible>
140        + Clone
141        + Send
142        + 'static,
143    S::Future: Send,
144{
145    type Response = axum::response::Response;
146    type Error = Infallible;
147    type Future = std::pin::Pin<
148        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
149    >;
150
151    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
152        self.inner.poll_ready(cx)
153    }
154
155    fn call(&mut self, req: Request<Body>) -> Self::Future {
156        // Only cache GET requests
157        if req.method() != Method::GET {
158            return Box::pin(self.inner.call(req));
159        }
160
161        // ⚡ Bolt Optimization:
162        // Format the key into a stack-allocated buffer to avoid a heap allocation
163        // on every cache check. Fall back to allocating a String only if the URI
164        // is exceptionally long.
165        let mut buf = [0u8; 512];
166        let cache_key_str = {
167            let mut cursor = &mut buf[..];
168            if std::io::Write::write_fmt(&mut cursor, format_args!("http:{}", req.uri())).is_ok() {
169                let len = 512 - cursor.len();
170                std::str::from_utf8(&buf[..len]).unwrap_or_default()
171            } else {
172                ""
173            }
174        };
175
176        let store = self.store.clone();
177
178        let cache_hit = if cache_key_str.is_empty() {
179            // Fallback for very long URIs
180            super::get_cached::<CachedResponse>(store.as_ref(), &format!("http:{}", req.uri()))
181        } else {
182            super::get_cached::<CachedResponse>(store.as_ref(), cache_key_str)
183        };
184
185        // Check for a cache hit
186        if let Some(cached) = cache_hit
187            && let Some(resp) = cached_response_into_response(cached)
188        {
189            return Box::pin(async move { Ok(resp) });
190        }
191
192        // Cache miss — call the inner service
193        let mut inner = self.inner.clone();
194        let cache_key = if cache_key_str.is_empty() {
195            format!("http:{}", req.uri())
196        } else {
197            cache_key_str.to_owned()
198        };
199
200        Box::pin(async move {
201            let response = inner.call(req).await?;
202
203            // Only cache 200 OK responses
204            if response.status() != StatusCode::OK {
205                return Ok(response);
206            }
207
208            let (parts, body) = response.into_parts();
209
210            // Buffer the body
211            let Ok(collected) = body.collect().await else {
212                let resp = axum::response::Response::builder()
213                    .status(StatusCode::INTERNAL_SERVER_ERROR)
214                    .body(Body::empty())
215                    .expect("infallible response builder");
216                return Ok(resp);
217            };
218            let body_bytes = collected.to_bytes();
219
220            // Store in cache
221            let cached = cached_response_from_parts(&parts, &body_bytes);
222            super::insert_cached(store.as_ref(), &cache_key, cached, None);
223
224            // Reconstruct the response
225            let response = axum::response::Response::from_parts(parts, Body::from(body_bytes));
226            Ok(response)
227        })
228    }
229}
230
#[cfg(all(test, feature = "cache-moka"))]
mod tests {
    use super::*;
    use crate::cache::RawCacheBytes;
    use std::collections::HashMap;
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tower::{ServiceBuilder, ServiceExt};

    /// Cache backend that only keeps raw bytes. Typed inserts are discarded, so
    /// it exercises the serialized path, like an out-of-process store would.
    #[derive(Default)]
    struct RawOnlyCache {
        entries: Mutex<HashMap<String, Vec<u8>>>,
    }

    impl Cache for RawOnlyCache {
        fn get_value(&self, key: &str) -> Option<Arc<dyn std::any::Any + Send + Sync>> {
            self.entries
                .lock()
                .expect("raw cache lock poisoned")
                .get(key)
                .cloned()
                .map(|bytes| Arc::new(RawCacheBytes(bytes)) as Arc<dyn std::any::Any + Send + Sync>)
        }

        // Deliberately a no-op: only the raw-byte path stores anything.
        fn insert_value(&self, _key: &str, _value: Arc<dyn std::any::Any + Send + Sync>) {}

        fn insert_raw_bytes(&self, key: &str, bytes: Vec<u8>, _ttl: Option<std::time::Duration>) {
            self.entries
                .lock()
                .expect("raw cache lock poisoned")
                .insert(key.to_owned(), bytes);
        }

        fn invalidate(&self, key: &str) {
            self.entries
                .lock()
                .expect("raw cache lock poisoned")
                .remove(key);
        }

        fn clear(&self) {
            self.entries
                .lock()
                .expect("raw cache lock poisoned")
                .clear();
        }
    }

    /// Build a test service that returns a fixed body and counts calls.
    fn counting_service(
        counter: Arc<AtomicUsize>,
        body: &'static str,
    ) -> impl Service<
        Request<Body>,
        Response = axum::response::Response,
        Error = Infallible,
        Future = impl std::future::Future<Output = Result<axum::response::Response, Infallible>> + Send,
    > + Clone
    + Send
    + 'static {
        let body = body.to_owned();
        tower::service_fn(move |_req: Request<Body>| {
            let counter = counter.clone();
            let body = body.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok(axum::response::Response::builder()
                    .status(StatusCode::OK)
                    .body(Body::from(body))
                    .expect("infallible response builder"))
            }
        })
    }

    /// Waits for `svc` to become ready, then sends `req` through it.
    async fn send<S>(svc: &mut S, req: Request<Body>) -> axum::response::Response
    where
        S: Service<Request<Body>, Response = axum::response::Response, Error = Infallible>,
    {
        // `Infallible` has no values, so neither error branch can be reached.
        let ready = svc.ready().await.unwrap_or_else(|never| match never {});
        ready.call(req).await.unwrap_or_else(|never| match never {})
    }

    /// Sends a bodiless `GET` request for `uri` through `svc`.
    async fn send_get<S>(svc: &mut S, uri: &str) -> axum::response::Response
    where
        S: Service<Request<Body>, Response = axum::response::Response, Error = Infallible>,
    {
        let req = Request::get(uri)
            .body(Body::empty())
            .expect("test request must be valid");
        send(svc, req).await
    }

    /// Buffers a response body so it can be compared.
    async fn body_bytes(resp: axum::response::Response) -> bytes::Bytes {
        http_body_util::BodyExt::collect(resp.into_body())
            .await
            .expect("collecting an in-memory body should not fail")
            .to_bytes()
    }

    #[tokio::test]
    async fn caches_get_responses() {
        let store = super::super::MokaCache::new(100, None);
        let counter = Arc::new(AtomicUsize::new(0));

        let mut svc = ServiceBuilder::new()
            .layer(CacheResponseLayer::from_cache(store))
            .service(counting_service(counter.clone(), "hello"));

        // First request — cache miss
        let resp = send_get(&mut svc, "/test").await;
        assert_eq!(resp.status(), StatusCode::OK);
        let body = body_bytes(resp).await;
        assert_eq!(body.as_ref(), b"hello");
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        // Second request — cache hit, inner service NOT called
        let resp = send_get(&mut svc, "/test").await;
        assert_eq!(resp.status(), StatusCode::OK);
        let body = body_bytes(resp).await;
        assert_eq!(body.as_ref(), b"hello");
        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "inner should not be called again"
        );
    }

    #[tokio::test]
    async fn caches_get_responses_with_raw_byte_backends() {
        let store = Arc::new(RawOnlyCache::default());
        let counter = Arc::new(AtomicUsize::new(0));

        let inner = {
            let counter = counter.clone();
            tower::service_fn(move |_req: Request<Body>| {
                let counter = counter.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, Infallible>(
                        axum::response::Response::builder()
                            .status(StatusCode::OK)
                            .header("x-cache-test", "persisted")
                            .body(Body::from("redis-like"))
                            .expect("infallible response builder"),
                    )
                }
            })
        };

        let mut svc = ServiceBuilder::new()
            .layer(CacheResponseLayer::from_shared(store))
            .service(inner);

        let resp = send_get(&mut svc, "/redis-backed").await;
        assert_eq!(resp.status(), StatusCode::OK);

        let resp = send_get(&mut svc, "/redis-backed").await;
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers()
                .get("x-cache-test")
                .and_then(|v| v.to_str().ok()),
            Some("persisted")
        );
        let body = body_bytes(resp).await;
        assert_eq!(body.as_ref(), b"redis-like");
        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "raw-byte backends should cache HTTP responses"
        );
    }

    #[tokio::test]
    async fn does_not_cache_post_requests() {
        let store = super::super::MokaCache::new(100, None);
        let counter = Arc::new(AtomicUsize::new(0));

        let mut svc = ServiceBuilder::new()
            .layer(CacheResponseLayer::from_cache(store))
            .service(counting_service(counter.clone(), "created"));

        let post_items = || {
            Request::post("/items")
                .body(Body::empty())
                .expect("test request must be valid")
        };

        let _resp = send(&mut svc, post_items()).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        let _resp = send(&mut svc, post_items()).await;
        assert_eq!(
            counter.load(Ordering::SeqCst),
            2,
            "POST should not be cached"
        );
    }

    #[tokio::test]
    async fn does_not_cache_non_200_responses() {
        let store = super::super::MokaCache::new(100, None);
        let counter = Arc::new(AtomicUsize::new(0));

        let svc_inner = {
            let counter = counter.clone();
            tower::service_fn(move |_req: Request<Body>| {
                let counter = counter.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, Infallible>(
                        axum::response::Response::builder()
                            .status(StatusCode::NOT_FOUND)
                            .body(Body::from("not found"))
                            .expect("infallible response builder"),
                    )
                }
            })
        };

        let mut svc = ServiceBuilder::new()
            .layer(CacheResponseLayer::from_cache(store))
            .service(svc_inner);

        let resp = send_get(&mut svc, "/missing").await;
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);

        let resp = send_get(&mut svc, "/missing").await;
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
        assert_eq!(
            counter.load(Ordering::SeqCst),
            2,
            "404 should not be cached"
        );
    }

    #[tokio::test]
    async fn different_uris_cached_separately() {
        let store = super::super::MokaCache::new(100, None);
        let counter = Arc::new(AtomicUsize::new(0));

        let mut svc = ServiceBuilder::new()
            .layer(CacheResponseLayer::from_cache(store))
            .service(counting_service(counter.clone(), "ok"));

        let _resp = send_get(&mut svc, "/a").await;
        let _resp = send_get(&mut svc, "/b").await;
        assert_eq!(
            counter.load(Ordering::SeqCst),
            2,
            "different URIs should miss"
        );

        // But repeating /a should hit
        let _resp = send_get(&mut svc, "/a").await;
        assert_eq!(counter.load(Ordering::SeqCst), 2, "/a should be cached");
    }

    #[test]
    fn from_shared_accepts_arc() {
        let store = Arc::new(super::super::MokaCache::new(100, None));
        // Just verify from_shared compiles and the layer can be used
        let _layer = CacheResponseLayer::from_shared(store);
    }

    #[tokio::test]
    async fn caches_get_responses_very_long_uri() {
        let store = super::super::MokaCache::new(100, None);
        let counter = Arc::new(AtomicUsize::new(0));

        let mut svc = ServiceBuilder::new()
            .layer(CacheResponseLayer::from_cache(store))
            .service(counting_service(counter.clone(), "hello"));

        // Longer than the 512-byte stack buffer, forcing the heap-key fallback.
        let long_uri = format!("/test/{}", "a".repeat(1000));

        let resp1 = send_get(&mut svc, &long_uri).await;
        assert_eq!(resp1.status(), StatusCode::OK);
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        let resp2 = send_get(&mut svc, &long_uri).await;
        assert_eq!(resp2.status(), StatusCode::OK);
        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "Should be cached despite long URI"
        );
    }
}