axum_response_cache/
lib.rs

1//! This library provides [Axum middleware](`axum#middleware`) that caches HTTP responses to the
2//! incoming requests based on their HTTP method and path.
3//!
4//! The main struct is [`CacheLayer`]. It can be created with any cache that implements two traits
5//! from the [`cached`] crate: [`cached::Cached`] and [`cached::CloneCached`].
6//!
7//! The *current* version of [`CacheLayer`] is compatible only with services accepting
8//! Axum’s [`Request<Body>`](`http::Request<axum::body::Body>`) and returning
9//! [`axum::response::Response`], thus it is not compatible with non-Axum [`tower`] services.
10//!
11//! It’s possible to configure the layer to re-use an old expired response in case the wrapped
12//! service fails to produce a new successful response.
13//!
14//! Only successful responses are cached (responses with status codes outside of the `[200-299]`
15//! range are passed-through or ignored).
16//!
17//! The cache limits maximum size of the response’s body (128 MB by default).
18//!
19//! ## Examples
20//!
21//! To cache a response over a specific route, just wrap it in a [`CacheLayer`]:
22//!
23//! ```rust,no_run
24//! # use axum_08 as axum;
25//! use axum::{Router, extract::Path, routing::get};
26//! use axum_response_cache::CacheLayer;
27//!
28//! #[tokio::main]
29//! async fn main() {
30//!     let mut router = Router::new()
31//!         .route(
32//!             "/hello/{name}",
33//!             get(|Path(name): Path<String>| async move { format!("Hello, {name}!") })
34//!                 // this will cache responses with each `:name` for 60 seconds.
35//!                 .layer(CacheLayer::with_lifespan(60)),
36//!         );
37//!
38//!     let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await.unwrap();
39//!     axum::serve(listener, router).await.unwrap();
40//! }
41//! ```
42//!
43//! ### Reusing last successful response
44//!
45//! ```rust
46//! # use std::sync::atomic::{AtomicBool, Ordering};
47//! # use axum_08 as axum;
48//! use axum::{
49//!     body::Body,
50//!     extract::Path,
51//!     http::status::StatusCode,
52//!     http::Request,
53//!     Router,
54//!     routing::get,
55//! };
56//! use axum_response_cache::CacheLayer;
57//! use tower::Service as _;
58//!
59//! // a handler that returns 200 OK only the first time it’s called
60//! async fn handler(Path(name): Path<String>) -> (StatusCode, String) {
61//!     static FIRST_RUN: AtomicBool = AtomicBool::new(true);
62//!     let first_run = FIRST_RUN.swap(false, Ordering::AcqRel);
63//!
64//!     if first_run {
65//!         (StatusCode::OK, format!("Hello, {name}"))
66//!     } else {
67//!         (StatusCode::INTERNAL_SERVER_ERROR, String::from("Error!"))
68//!     }
69//! }
70//!
71//! # #[tokio::main]
72//! # async fn main() {
73//! let mut router = Router::new()
74//!     .route("/hello/{name}", get(handler))
75//!     .layer(CacheLayer::with_lifespan(60).use_stale_on_failure());
76//!
77//! // first request will fire handler and get the response
78//! let status1 = router.call(Request::get("/hello/foo").body(Body::empty()).unwrap())
79//!     .await
80//!     .unwrap()
81//!     .status();
82//! assert_eq!(StatusCode::OK, status1);
83//!
84//! // second request will reuse the last response since the handler now returns ISE
85//! let status2 = router.call(Request::get("/hello/foo").body(Body::empty()).unwrap())
86//!     .await
87//!     .unwrap()
88//!     .status();
89//! assert_eq!(StatusCode::OK, status2);
90//! # }
91//! ```
92//!
93//! ### Serving static files
94//! This middleware can be used to cache files served in memory to limit hard drive load on the
//! server. To serve files you can use the [`tower-http::services::ServeDir`](https://docs.rs/tower-http/latest/tower_http/services/struct.ServeDir.html) service.
96//! ```rust,ignore
97//! let router = Router::new().nest_service("/", ServeDir::new("static/"));
98//! ```
99//!
100//! ### Limiting the body size
101//!
102//! ```rust
103//! # use axum_08 as axum;
104//! use axum::{
105//!     body::Body,
106//!     extract::Path,
107//!     http::status::StatusCode,
108//!     http::Request,
109//!     Router,
110//!     routing::get,
111//! };
112//! use axum_response_cache::CacheLayer;
113//! use tower::Service as _;
114//!
115//! // returns a short string, well below the limit
116//! async fn ok_handler() -> &'static str {
117//!     "ok"
118//! }
119//!
120//! async fn too_long_handler() -> &'static str {
121//!     "a response that is well beyond the limit of the cache!"
122//! }
123//!
124//! # #[tokio::main]
125//! # async fn main() {
126//! let mut router = Router::new()
127//!     .route("/ok", get(ok_handler))
128//!     .route("/too_long", get(too_long_handler))
129//!     // limit max cached body to only 16 bytes
130//!     .layer(CacheLayer::with_lifespan(60).body_limit(16));
131//!
132//! let status_ok = router.call(Request::get("/ok").body(Body::empty()).unwrap())
133//!     .await
134//!     .unwrap()
135//!     .status();
136//! assert_eq!(StatusCode::OK, status_ok);
137//!
138//! let status_too_long = router.call(Request::get("/too_long").body(Body::empty()).unwrap())
139//!     .await
140//!     .unwrap()
141//!     .status();
142//! assert_eq!(StatusCode::INTERNAL_SERVER_ERROR, status_too_long);
143//! # }
144//! ```
145//! ### Manual Cache Invalidation
146//! This middleware allows manual cache invalidation by setting the `X-Invalidate-Cache` header in the request. This can be useful when you know the underlying data has changed and you want to force a fresh pull of data.
147//!
148//! ```rust
149//! # use axum_08 as axum;
150//! use axum::{
151//!     body::Body,
152//!     extract::Path,
153//!     http::status::StatusCode,
154//!     http::Request,
155//!     Router,
156//!     routing::get,
157//! };
158//! use axum_response_cache::CacheLayer;
159//! use tower::Service as _;
160//!
161//! async fn handler(Path(name): Path<String>) -> (StatusCode, String) {
162//!     (StatusCode::OK, format!("Hello, {name}"))
163//! }
164//!
165//! # #[tokio::main]
166//! # async fn main() {
167//! let mut router = Router::new()
168//!     .route("/hello/{name}", get(handler))
169//!     .layer(CacheLayer::with_lifespan(60).allow_invalidation());
170//!
171//! // first request will fire handler and get the response
172//! let status1 = router.call(Request::get("/hello/foo").body(Body::empty()).unwrap())
173//!     .await
174//!     .unwrap()
175//!     .status();
176//! assert_eq!(StatusCode::OK, status1);
177//!
178//! // second request should return the cached response
179//! let status2 = router.call(Request::get("/hello/foo").body(Body::empty()).unwrap())
180//!     .await
181//!     .unwrap()
182//!     .status();
183//! assert_eq!(StatusCode::OK, status2);
184//!
185//! // third request with X-Invalidate-Cache header to invalidate the cache
186//! let status3 = router.call(
187//!     Request::get("/hello/foo")
188//!         .header("X-Invalidate-Cache", "true")
189//!         .body(Body::empty())
190//!         .unwrap(),
191//!     )
192//!     .await
193//!     .unwrap()
194//!     .status();
195//! assert_eq!(StatusCode::OK, status3);
196//!
197//! // fourth request to verify that the handler is called again
198//! let status4 = router.call(Request::get("/hello/foo").body(Body::empty()).unwrap())
199//!     .await
200//!     .unwrap()
201//!     .status();
202//! assert_eq!(StatusCode::OK, status4);
203//! # }
204//! ```
205//! Cache invalidation could be dangerous because it can allow a user to force the server to make a request to an external service or database. It is disabled by default, but can be enabled by calling the [`CacheLayer::allow_invalidation`] method.
206//!
207//! ## Using custom cache
208//! ```rust
209//! # use axum_08 as axum;
210//! use axum::{Router, routing::get};
211//! use axum_response_cache::CacheLayer;
212//! // let’s use TimedSizedCache here
213//! use cached::stores::TimedSizedCache;
214//! # use axum::{body::Body, http::Request};
215//! # use tower::ServiceExt;
216//!
217//! # #[tokio::main]
218//! # async fn main() {
219//! let router: Router = Router::new()
220//!     .route("/hello", get(|| async { "Hello, world!" }))
221//!     // cache maximum value of 50 responses for one minute
222//!     .layer(CacheLayer::with(TimedSizedCache::with_size_and_lifespan(50, 60)));
223//! # // force type inference to resolve the exact type of router
224//! #     let _ = router.oneshot(Request::get("/hello").body(Body::empty()).unwrap()).await;
225//! # }
226//! ```
227//!
228//! ## Using custom keyer
229//! It’s possible to customize the cache’s key to include eg. the `Accept` header (so that
230//! different types of responses are cached separately based on the header).
231//!
232//! ```rust
233//! # use axum_08 as axum;
234//! use axum::{Router, routing::get};
235//! use axum_response_cache::CacheLayer;
236//! # use axum::{body::Body, http::Request};
237//! # use tower::ServiceExt;
238//!
239//! # #[tokio::main]
240//! # async fn main() {
241//! // cache responses based on method, Accept header, and uri
242//! let keyer = |request: &Request<Body>| {
243//!     (
244//!         request.method().clone(),
245//!         request
246//!             .headers()
247//!             .get(axum::http::header::ACCEPT)
248//!             .and_then(|c| c.to_str().ok())
249//!             .unwrap_or("")
250//!             .to_string(),
251//!         request.uri().clone(),
252//!     )
253//! };
254//! let router: Router = Router::new()
255//!     .route("/hello", get(|| async { "Hello, world!" }))
256//!     .layer(CacheLayer::with_lifespan_and_keyer(60, keyer));
257//! # // force type inference to resolve the exact type of router
258//! #     let _ = router.oneshot(Request::get("/hello").body(Body::empty()).unwrap()).await;
259//! # }
260//! ```
261//!
262//! ## Use cases
263//! Caching responses in memory (eg. using [`cached::TimedCache`]) might be useful when the
264//! underlying service produces the responses by:
265//! 1. doing heavy computation,
266//! 2. requesting external service(s) that might not be fully reliable or performant,
267//! 3. serving static files from disk.
268//!
269//! In those cases, if the response to identical requests does not change often over time, it might
270//! be desirable to re-use the same responses from memory without re-calculating them – skipping requests to data
271//! bases, external services, reading from disk.
272//!
273//! ### Using Axum 0.7
274//!
275//! By default, this library uses Axum 0.8. However, you can configure it to use Axum 0.7 by enabling the appropriate feature flag in your `Cargo.toml`.
276//!
277//! To use Axum 0.7, add the following to your `Cargo.toml`:
278//!
279//! ```toml
280//! [dependencies]
281//! axum-response-cache = { version = "0.3", features = ["axum07"], default-features = false }
282//! ```
283//!
284//! This will disable the default Axum 0.8 feature and enable the Axum 0.7 feature instead.
285
286use std::{
287    convert::Infallible,
288    fmt::Debug,
289    future::Future,
290    hash::Hash,
291    pin::Pin,
292    sync::{Arc, Mutex},
293    task::{Context, Poll},
294};
295use tracing_futures::Instrument as _;
296
297#[cfg(feature = "axum07")]
298use axum_07 as axum;
299#[cfg(feature = "axum08")]
300use axum_08 as axum;
301
302use axum::body;
303use axum::{
304    body::{Body, Bytes},
305    http::{response::Parts, Request, StatusCode},
306    response::{IntoResponse, Response},
307};
308
309use cached::{Cached, CloneCached, TimedCache};
310use tower::{Layer, Service};
311use tracing::{debug, instrument};
312
/// The trait for objects used to obtain cache keys. See [`BasicKeyer`] for default implementation
/// returning `(http::Method, Uri)`.
pub trait Keyer {
    /// The type of the key under which responses are stored in the cache.
    type Key;

    /// Compute the cache key for the given request; the request is only inspected, not consumed.
    fn get_key(&self, request: &Request<Body>) -> Self::Key;
}
320
/// Any `Fn(&Request<Body>) -> K` closure or function can be used directly as a [`Keyer`];
/// the cache key is whatever the closure returns.
impl<K, F> Keyer for F
where
    F: Fn(&Request<Body>) -> K + Send + Sync + 'static,
{
    type Key = K;

    fn get_key(&self, request: &Request<Body>) -> Self::Key {
        // Delegate straight to the wrapped closure.
        self(request)
    }
}
331
/// The basic caching strategy for the responses.
///
/// The responses are cached according to the HTTP method ([`axum::http::Method`]) and path
/// ([`axum::http::Uri`]) of the request they responded to.
pub struct BasicKeyer;

/// The key type produced by [`BasicKeyer`]: the request’s method and URI.
pub type BasicKey = (http::Method, http::Uri);
339
340impl Keyer for BasicKeyer {
341    type Key = BasicKey;
342
343    fn get_key(&self, request: &Request<Body>) -> Self::Key {
344        (request.method().clone(), request.uri().clone())
345    }
346}
347
/// The struct preserving all the headers and body of the cached response.
#[derive(Clone, Debug)]
pub struct CachedResponse {
    // The head of the original response (status, headers, etc.).
    parts: Parts,
    // The fully buffered response body.
    body: Bytes,
    // When the response was stored; `Some` only when the layer was configured to add
    // response headers — used to emit the `X-Cache-Age` header.
    timestamp: Option<std::time::Instant>,
}
355
356impl IntoResponse for CachedResponse {
357    fn into_response(self) -> Response {
358        let mut response = Response::from_parts(self.parts, Body::from(self.body));
359        if let Some(timestamp) = self.timestamp {
360            let age = timestamp.elapsed().as_secs();
361            response
362                .headers_mut()
363                .insert("X-Cache-Age", age.to_string().parse().unwrap());
364        }
365        response
366    }
367}
368
/// The main struct of the library. The layer providing caching to the wrapped service.
/// It is generic over the cache used (`C`) and a `Keyer` (`K`) used to obtain the key for cached
/// responses.
pub struct CacheLayer<C, K> {
    // The shared cache; every service built from this layer reuses the same instance.
    cache: Arc<Mutex<C>>,
    // Whether to serve an expired entry when the inner service fails to produce a success.
    use_stale: bool,
    // Maximum number of body bytes that will be buffered and cached.
    limit: usize,
    // Whether the `X-Invalidate-Cache` request header may drop a cached entry.
    allow_invalidation: bool,
    // Whether to attach an `X-Cache-Age` header to responses that pass through the cache.
    add_response_headers: bool,
    // Shared strategy used to derive a cache key from each request.
    keyer: Arc<K>,
}
380
381impl<C, K> Clone for CacheLayer<C, K> {
382    fn clone(&self) -> Self {
383        Self {
384            cache: Arc::clone(&self.cache),
385            use_stale: self.use_stale,
386            limit: self.limit,
387            allow_invalidation: self.allow_invalidation,
388            add_response_headers: self.add_response_headers,
389            keyer: Arc::clone(&self.keyer),
390        }
391    }
392}
393
impl<C, K> CacheLayer<C, K>
where
    C: Cached<K::Key, CachedResponse> + CloneCached<K::Key, CachedResponse>,
    K: Keyer,
    K::Key: Debug + Hash + Eq + Clone + Send + 'static,
{
    /// Create a new cache layer with a given cache and keyer, and the default body size limit
    /// of 128 MB.
    pub fn with_cache_and_keyer(cache: C, keyer: K) -> Self {
        Self {
            cache: Arc::new(Mutex::new(cache)),
            use_stale: false,
            // default body limit: 128 MB
            limit: 128 * 1024 * 1024,
            allow_invalidation: false,
            add_response_headers: false,
            keyer: Arc::new(keyer),
        }
    }

    /// Switch the layer’s settings to preserve the last successful response even when it’s evicted
    /// from the cache but the service failed to provide a new successful response (e.g. when
    /// the underlying service responds with `404 NOT FOUND`, the cache will keep providing the
    /// last stale `200 OK` response produced).
    pub fn use_stale_on_failure(self) -> Self {
        Self {
            use_stale: true,
            ..self
        }
    }

    /// Change the maximum body size limit. If you want unlimited size, use [`usize::MAX`].
    /// Responses whose body cannot be buffered within the limit are not cached and are replaced
    /// with a `500 INTERNAL SERVER ERROR` response.
    pub fn body_limit(self, new_limit: usize) -> Self {
        Self {
            limit: new_limit,
            ..self
        }
    }

    /// Allow manual cache invalidation by setting the `X-Invalidate-Cache` header in the request.
    /// This will allow the cache to be invalidated for the given key. See the crate-level docs
    /// for why this is disabled by default.
    pub fn allow_invalidation(self) -> Self {
        Self {
            allow_invalidation: true,
            ..self
        }
    }

    /// Add an `X-Cache-Age` response header (seconds since the response was cached) to responses
    /// served via the cache.
    pub fn add_response_headers(self) -> Self {
        Self {
            add_response_headers: true,
            ..self
        }
    }
}
448
449impl<C> CacheLayer<C, BasicKeyer>
450where
451    C: Cached<BasicKey, CachedResponse> + CloneCached<BasicKey, CachedResponse>,
452{
453    /// Create a new cache layer with a given cache and the default body size limit of 128 MB.
454    pub fn with(cache: C) -> Self {
455        Self {
456            cache: Arc::new(Mutex::new(cache)),
457            use_stale: false,
458            limit: 128 * 1024 * 1024,
459            allow_invalidation: false,
460            add_response_headers: false,
461            keyer: Arc::new(BasicKeyer),
462        }
463    }
464}
465
466impl CacheLayer<TimedCache<BasicKey, CachedResponse>, BasicKey> {
467    /// Create a new cache layer with the desired TTL in seconds
468    pub fn with_lifespan(
469        ttl_sec: u64,
470    ) -> CacheLayer<TimedCache<BasicKey, CachedResponse>, BasicKeyer> {
471        CacheLayer::with(TimedCache::with_lifespan(ttl_sec))
472    }
473}
474
impl<K> CacheLayer<TimedCache<K::Key, CachedResponse>, K>
where
    K: Keyer,
    K::Key: Debug + Hash + Eq + Clone + Send + 'static,
{
    /// Create a new cache layer with the desired TTL in seconds and a custom keyer
    /// (see the crate-level docs for an example using the `Accept` header).
    pub fn with_lifespan_and_keyer(
        ttl_sec: u64,
        keyer: K,
    ) -> CacheLayer<TimedCache<K::Key, CachedResponse>, K> {
        CacheLayer::with_cache_and_keyer(TimedCache::with_lifespan(ttl_sec), keyer)
    }
}
488
489impl<S, C, K> Layer<S> for CacheLayer<C, K>
490where
491    K: Keyer,
492    K::Key: Debug + Hash + Eq + Clone + Send + 'static,
493{
494    type Service = CacheService<S, C, K>;
495
496    fn layer(&self, inner: S) -> Self::Service {
497        Self::Service {
498            inner,
499            cache: Arc::clone(&self.cache),
500            use_stale: self.use_stale,
501            limit: self.limit,
502            allow_invalidation: self.allow_invalidation,
503            add_response_headers: self.add_response_headers,
504            keyer: Arc::clone(&self.keyer),
505        }
506    }
507}
508
/// The tower service produced by [`CacheLayer`]: wraps an inner service and serves cached
/// responses when available. The configuration fields mirror those of [`CacheLayer`].
pub struct CacheService<S, C, K> {
    // The wrapped service that produces fresh responses.
    inner: S,
    // Cache shared with the layer and all sibling services.
    cache: Arc<Mutex<C>>,
    // Serve an expired entry when the inner service fails to produce a success.
    use_stale: bool,
    // Maximum number of body bytes buffered and cached.
    limit: usize,
    // Honour the `X-Invalidate-Cache` request header.
    allow_invalidation: bool,
    // Attach an `X-Cache-Age` header to responses passing through the cache.
    add_response_headers: bool,
    // Shared strategy deriving a cache key from each request.
    keyer: Arc<K>,
}
518
519impl<S, C, K> Clone for CacheService<S, C, K>
520where
521    S: Clone,
522{
523    fn clone(&self) -> Self {
524        Self {
525            inner: self.inner.clone(),
526            cache: Arc::clone(&self.cache),
527            use_stale: self.use_stale,
528            limit: self.limit,
529            allow_invalidation: self.allow_invalidation,
530            add_response_headers: self.add_response_headers,
531            keyer: Arc::clone(&self.keyer),
532        }
533    }
534}
535
536impl<S, C, K> Service<Request<Body>> for CacheService<S, C, K>
537where
538    S: Service<Request<Body>, Response = Response, Error = Infallible> + Clone + Send,
539    S::Future: Send + 'static,
540    C: Cached<K::Key, CachedResponse> + CloneCached<K::Key, CachedResponse> + Send + 'static,
541    K: Keyer,
542    K::Key: Debug + Hash + Eq + Clone + Send + 'static,
543{
544    type Response = Response;
545    type Error = Infallible;
546    type Future = Pin<Box<dyn Future<Output = Result<Response, Infallible>> + Send + 'static>>;
547
548    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
549        self.inner.poll_ready(cx)
550    }
551
552    #[instrument(skip(self, request))]
553    fn call(&mut self, request: Request<Body>) -> Self::Future {
554        let mut inner = self.inner.clone();
555        let use_stale = self.use_stale;
556        let allow_invalidation = self.allow_invalidation;
557        let add_response_headers = self.add_response_headers;
558        let limit = self.limit;
559        let cache = Arc::clone(&self.cache);
560        let key = self.keyer.get_key(&request);
561
562        // Check for the custom header "X-Invalidate-Cache" if invalidation is allowed
563        if allow_invalidation && request.headers().contains_key("X-Invalidate-Cache") {
564            // Manually invalidate the cache for this key
565            cache.lock().unwrap().cache_remove(&key);
566            debug!("Cache invalidated manually for key {:?}", key);
567        }
568
569        let inner_fut = inner
570            .call(request)
571            .instrument(tracing::info_span!("inner_service"));
572        let (cached, evicted) = {
573            let mut guard = cache.lock().unwrap();
574            let (cached, evicted) = guard.cache_get_expired(&key);
575            if let (Some(stale), true) = (cached.as_ref(), evicted) {
576                // reinsert stale value immediately so that others don’t schedule their updating
577                debug!("Found stale value in cache, reinsterting and attempting refresh");
578                guard.cache_set(key.clone(), stale.clone());
579            }
580            (cached, evicted)
581        };
582
583        Box::pin(async move {
584            match (cached, evicted) {
585                (Some(value), false) => Ok(value.into_response()),
586                (Some(stale_value), true) => {
587                    let response = inner_fut.await.unwrap();
588                    if response.status().is_success() {
589                        Ok(update_cache(&cache, key, response, limit, add_response_headers).await)
590                    } else if use_stale {
591                        debug!("Returning stale value.");
592                        Ok(stale_value.into_response())
593                    } else {
594                        debug!("Stale value in cache, evicting and returning failed response.");
595                        cache.lock().unwrap().cache_remove(&key);
596                        Ok(response)
597                    }
598                }
599                (None, _) => {
600                    let response = inner_fut.await.unwrap();
601                    if response.status().is_success() {
602                        Ok(update_cache(&cache, key, response, limit, add_response_headers).await)
603                    } else {
604                        Ok(response)
605                    }
606                }
607            }
608        })
609    }
610}
611
612#[instrument(skip(cache, response))]
613async fn update_cache<C, K>(
614    cache: &Arc<Mutex<C>>,
615    key: K,
616    response: Response,
617    limit: usize,
618    add_response_headers: bool,
619) -> Response
620where
621    C: Cached<K, CachedResponse> + CloneCached<K, CachedResponse>,
622    K: Debug + Hash + Eq + Clone + Send + 'static,
623{
624    let (parts, body) = response.into_parts();
625    let Ok(body) = body::to_bytes(body, limit).await else {
626        return (
627            StatusCode::INTERNAL_SERVER_ERROR,
628            format!("File too big, over {limit} bytes"),
629        )
630            .into_response();
631    };
632    let value = CachedResponse {
633        parts,
634        body,
635        timestamp: if add_response_headers {
636            Some(std::time::Instant::now())
637        } else {
638            None
639        },
640    };
641    {
642        cache.lock().unwrap().cache_set(key, value.clone());
643    }
644    value.into_response()
645}
646
647#[cfg(test)]
648mod tests {
649    use super::*;
650    use rand::Rng;
651    use std::sync::atomic::{AtomicIsize, Ordering};
652
653    #[cfg(feature = "axum07")]
654    use axum_07 as axum;
655    #[cfg(feature = "axum08")]
656    use axum_08 as axum;
657
658    use axum::{
659        extract::State,
660        http::{Request, StatusCode},
661        routing::get,
662        Router,
663    };
664
665    use tower::Service;
666
667    #[derive(Clone, Debug)]
668    struct Counter {
669        value: Arc<AtomicIsize>,
670    }
671
672    impl Counter {
673        fn new(init: isize) -> Self {
674            Self {
675                value: AtomicIsize::from(init).into(),
676            }
677        }
678
679        fn increment(&self) {
680            self.value.fetch_add(1, Ordering::Release);
681        }
682
683        fn read(&self) -> isize {
684            self.value.load(Ordering::Acquire)
685        }
686    }
687
688    #[tokio::test]
689    async fn should_use_cached_value() {
690        let handler = |State(cnt): State<Counter>| async move {
691            cnt.increment();
692            StatusCode::OK
693        };
694
695        let counter = Counter::new(0);
696        let cache = CacheLayer::with_lifespan(60).use_stale_on_failure();
697        let mut router = Router::new()
698            .route("/", get(handler).layer(cache))
699            .with_state(counter.clone());
700
701        for _ in 0..10 {
702            let status = router
703                .call(Request::get("/").body(Body::empty()).unwrap())
704                .await
705                .unwrap()
706                .status();
707            assert!(status.is_success(), "handler should return success");
708        }
709
710        assert_eq!(1, counter.read(), "handler should’ve been called only once");
711    }
712
713    #[tokio::test]
714    async fn should_not_cache_unsuccessful_responses() {
715        let handler = |State(cnt): State<Counter>| async move {
716            cnt.increment();
717            let responses = [
718                StatusCode::BAD_REQUEST,
719                StatusCode::INTERNAL_SERVER_ERROR,
720                StatusCode::NOT_FOUND,
721            ];
722            let mut rng = rand::rng();
723            responses[rng.random_range(0..responses.len())]
724        };
725
726        let counter = Counter::new(0);
727        let cache = CacheLayer::with_lifespan(60).use_stale_on_failure();
728        let mut router = Router::new()
729            .route("/", get(handler).layer(cache))
730            .with_state(counter.clone());
731
732        for _ in 0..10 {
733            let status = router
734                .call(Request::get("/").body(Body::empty()).unwrap())
735                .await
736                .unwrap()
737                .status();
738            assert!(!status.is_success(), "handler should never return success");
739        }
740
741        assert_eq!(
742            10,
743            counter.read(),
744            "handler should’ve been called for all requests"
745        );
746    }
747
    /// With `use_stale_on_failure`, an evicted-but-successful response must keep being served
    /// while the handler only produces errors. The 1-second TTL plus the sleep below makes the
    /// single cached entry expire before the follow-up requests.
    #[tokio::test]
    async fn should_use_last_correct_stale_value() {
        let handler = |State(cnt): State<Counter>| async move {
            // `prev` is the pre-increment count, so only the very first call sees 0.
            let prev = cnt.value.fetch_add(1, Ordering::AcqRel);
            let responses = [
                StatusCode::BAD_REQUEST,
                StatusCode::INTERNAL_SERVER_ERROR,
                StatusCode::NOT_FOUND,
            ];
            let mut rng = rand::rng();

            // first response successful, later failed
            if prev == 0 {
                StatusCode::OK
            } else {
                responses[rng.random_range(0..responses.len())]
            }
        };

        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(1).use_stale_on_failure();
        let mut router = Router::new()
            .route("/", get(handler).layer(cache))
            .with_state(counter);

        // feed the cache
        let status = router
            .call(Request::get("/").body(Body::empty()).unwrap())
            .await
            .unwrap()
            .status();
        assert!(status.is_success(), "handler should return success");

        // wait over 1s for cache eviction
        tokio::time::sleep(tokio::time::Duration::from_millis(1050)).await;

        // Every later request hits the (failing) handler, but the stale OK entry is returned.
        for _ in 1..10 {
            let status = router
                .call(Request::get("/").body(Body::empty()).unwrap())
                .await
                .unwrap()
                .status();
            assert!(
                status.is_success(),
                "cache should return stale successful value"
            );
        }
    }
796
    /// Without `use_stale_on_failure`, an expired entry must NOT shadow new error responses:
    /// after eviction, the handler's failures are forwarded and the handler is hit every time.
    #[tokio::test]
    async fn should_not_use_stale_values() {
        let handler = |State(cnt): State<Counter>| async move {
            // `prev` is the pre-increment count, so only the very first call sees 0.
            let prev = cnt.value.fetch_add(1, Ordering::AcqRel);
            let responses = [
                StatusCode::BAD_REQUEST,
                StatusCode::INTERNAL_SERVER_ERROR,
                StatusCode::NOT_FOUND,
            ];
            let mut rng = rand::rng();

            // first response successful, later failed
            if prev == 0 {
                StatusCode::OK
            } else {
                responses[rng.random_range(0..responses.len())]
            }
        };

        let counter = Counter::new(0);
        // NOTE: no `.use_stale_on_failure()` here — that's the behavior under test.
        let cache = CacheLayer::with_lifespan(1);
        let mut router = Router::new()
            .route("/", get(handler).layer(cache))
            .with_state(counter.clone());

        // feed the cache
        let status = router
            .call(Request::get("/").body(Body::empty()).unwrap())
            .await
            .unwrap()
            .status();
        assert!(status.is_success(), "handler should return success");

        // wait over 1s for cache eviction
        tokio::time::sleep(tokio::time::Duration::from_millis(1050)).await;

        for _ in 1..10 {
            let status = router
                .call(Request::get("/").body(Body::empty()).unwrap())
                .await
                .unwrap()
                .status();
            assert!(
                !status.is_success(),
                "cache should forward unsuccessful values"
            );
        }

        assert_eq!(
            10,
            counter.read(),
            "handler should’ve been called for all requests"
        );
    }
851
852    #[tokio::test]
853    async fn should_not_invalidate_cache_when_disabled() {
854        let handler = |State(cnt): State<Counter>| async move {
855            cnt.increment();
856            StatusCode::OK
857        };
858
859        let counter = Counter::new(0);
860        let cache = CacheLayer::with_lifespan(60);
861        let mut router = Router::new()
862            .route("/", get(handler).layer(cache))
863            .with_state(counter.clone());
864
865        // First request to cache the response
866        let status = router
867            .call(Request::get("/").body(Body::empty()).unwrap())
868            .await
869            .unwrap()
870            .status();
871        assert!(status.is_success(), "handler should return success");
872
873        // Second request should return the cached response - no increment
874        let status = router
875            .call(Request::get("/").body(Body::empty()).unwrap())
876            .await
877            .unwrap()
878            .status();
879        assert!(status.is_success(), "handler should return success");
880
881        // Third request with X-Invalidate-Cache header should not invalidate the cache - no increment
882        let status = router
883            .call(
884                Request::get("/")
885                    .header("X-Invalidate-Cache", "true")
886                    .body(Body::empty())
887                    .unwrap(),
888            )
889            .await
890            .unwrap()
891            .status();
892        assert!(status.is_success(), "handler should return success");
893
894        // Fourth request should still return the cached response - no increment
895        let status = router
896            .call(Request::get("/").body(Body::empty()).unwrap())
897            .await
898            .unwrap()
899            .status();
900        assert!(status.is_success(), "handler should return success");
901
902        assert_eq!(1, counter.read(), "handler should’ve been called only once");
903    }
904
905    #[tokio::test]
906    async fn should_invalidate_cache_when_enabled() {
907        let handler = |State(cnt): State<Counter>| async move {
908            cnt.increment();
909            StatusCode::OK
910        };
911
912        let counter = Counter::new(0);
913        let cache = CacheLayer::with_lifespan(60).allow_invalidation();
914        let mut router = Router::new()
915            .route("/", get(handler).layer(cache))
916            .with_state(counter.clone());
917
918        // First request to cache the response
919        let status = router
920            .call(Request::get("/").body(Body::empty()).unwrap())
921            .await
922            .unwrap()
923            .status();
924        assert!(status.is_success(), "handler should return success");
925
926        // Second request should return the cached response - no increment
927        let status = router
928            .call(Request::get("/").body(Body::empty()).unwrap())
929            .await
930            .unwrap()
931            .status();
932        assert!(status.is_success(), "handler should return success");
933
934        // Third request with X-Invalidate-Cache header to invalidate the cache
935        let status = router
936            .call(
937                Request::get("/")
938                    .header("X-Invalidate-Cache", "true")
939                    .body(Body::empty())
940                    .unwrap(),
941            )
942            .await
943            .unwrap()
944            .status();
945        assert!(status.is_success(), "handler should return success");
946
947        // Fourth request to verify that the handler is called again
948        let status = router
949            .call(Request::get("/").body(Body::empty()).unwrap())
950            .await
951            .unwrap()
952            .status();
953        assert!(status.is_success(), "handler should return success");
954
955        assert_eq!(2, counter.read(), "handler should’ve been called twice");
956    }
957
958    #[tokio::test]
959    async fn should_not_include_age_header_when_disabled() {
960        let handler = |State(cnt): State<Counter>| async move {
961            cnt.increment();
962            StatusCode::OK
963        };
964
965        let counter = Counter::new(0);
966        let cache = CacheLayer::with_lifespan(60);
967        let mut router = Router::new()
968            .route("/", get(handler).layer(cache))
969            .with_state(counter.clone());
970
971        // First request to cache the response
972        let response = router
973            .call(Request::get("/").body(Body::empty()).unwrap())
974            .await
975            .unwrap();
976        assert!(
977            response.status().is_success(),
978            "handler should return success"
979        );
980
981        // Second request should return the cached response
982        let response = router
983            .call(Request::get("/").body(Body::empty()).unwrap())
984            .await
985            .unwrap();
986        assert!(
987            response.status().is_success(),
988            "handler should return success"
989        );
990        assert!(
991            response.headers().get("X-Cache-Age").is_none(),
992            "Age header should not be present"
993        );
994
995        assert_eq!(1, counter.read(), "handler should’ve been called only once");
996    }
997
998    #[tokio::test]
999    async fn should_include_age_header_when_enabled() {
1000        let handler = |State(cnt): State<Counter>| async move {
1001            cnt.increment();
1002            StatusCode::OK
1003        };
1004
1005        let counter = Counter::new(0);
1006        let cache = CacheLayer::with_lifespan(60).add_response_headers();
1007        let mut router = Router::new()
1008            .route("/", get(handler).layer(cache))
1009            .with_state(counter.clone());
1010
1011        // First request to cache the response
1012        let response = router
1013            .call(Request::get("/").body(Body::empty()).unwrap())
1014            .await
1015            .unwrap();
1016        assert!(
1017            response.status().is_success(),
1018            "handler should return success"
1019        );
1020
1021        // Age should be 0
1022        assert_eq!(
1023            response
1024                .headers()
1025                .get("X-Cache-Age")
1026                .and_then(|v| v.to_str().ok())
1027                .unwrap_or(""),
1028            "0",
1029            "Age header should be present and equal to 0"
1030        );
1031        // wait over 2s to age the cache
1032        tokio::time::sleep(tokio::time::Duration::from_millis(2100)).await;
1033        // Second request should return the cached response
1034        let response = router
1035            .call(Request::get("/").body(Body::empty()).unwrap())
1036            .await
1037            .unwrap();
1038
1039        assert_eq!(
1040            response
1041                .headers()
1042                .get("X-Cache-Age")
1043                .and_then(|v| v.to_str().ok())
1044                .unwrap_or(""),
1045            "2",
1046            "Age header should be present and equal to 2"
1047        );
1048
1049        assert_eq!(1, counter.read(), "handler should’ve been called only once");
1050    }
1051
1052    #[tokio::test]
1053    async fn should_cache_by_custom_keys() {
1054        let handler = |State(cnt): State<Counter>| async move {
1055            cnt.increment();
1056            StatusCode::OK
1057        };
1058
1059        let counter = Counter::new(0);
1060        let keyer = |request: &Request<Body>| {
1061            (
1062                request.method().clone(),
1063                request
1064                    .headers()
1065                    .get(axum::http::header::ACCEPT)
1066                    .and_then(|c| c.to_str().ok())
1067                    .unwrap_or("")
1068                    .to_string(),
1069                request.uri().clone(),
1070            )
1071        };
1072        let cache = CacheLayer::with_lifespan_and_keyer(60, keyer).add_response_headers();
1073        let mut router = Router::new()
1074            .route("/", get(handler).layer(cache))
1075            .with_state(counter.clone());
1076
1077        // First request to cache the response
1078        let response = router
1079            .call(Request::get("/").body(Body::empty()).unwrap())
1080            .await
1081            .unwrap();
1082        assert!(
1083            response.status().is_success(),
1084            "handler should return success"
1085        );
1086
1087        // Age should be 0
1088        assert_eq!(
1089            response
1090                .headers()
1091                .get("X-Cache-Age")
1092                .and_then(|v| v.to_str().ok())
1093                .unwrap_or(""),
1094            "0",
1095            "Age header should be present and equal to 0"
1096        );
1097
1098        // wait over 2s to age the cache
1099        tokio::time::sleep(tokio::time::Duration::from_millis(2100)).await;
1100        // Second request should return the cached response
1101        let response = router
1102            .call(Request::get("/").body(Body::empty()).unwrap())
1103            .await
1104            .unwrap();
1105
1106        assert_eq!(
1107            response
1108                .headers()
1109                .get("X-Cache-Age")
1110                .and_then(|v| v.to_str().ok())
1111                .unwrap_or(""),
1112            "2",
1113            "Age header should be present and equal to 2"
1114        );
1115
1116        // Request with a different accept header should return a new response
1117        let response = router
1118            .call(
1119                Request::get("/")
1120                    .header(axum::http::header::ACCEPT, "application/json")
1121                    .body(Body::empty())
1122                    .unwrap(),
1123            )
1124            .await
1125            .unwrap();
1126
1127        assert_eq!(
1128            response
1129                .headers()
1130                .get("X-Cache-Age")
1131                .and_then(|v| v.to_str().ok())
1132                .unwrap_or(""),
1133            "0",
1134            "Age header should be present and equal to 0"
1135        );
1136
1137        // wait over 2s to age the cache
1138        tokio::time::sleep(tokio::time::Duration::from_millis(2100)).await;
1139        // Second request should return the newly cached response
1140        let response = router
1141            .call(
1142                Request::get("/")
1143                    .header(axum::http::header::ACCEPT, "application/json")
1144                    .body(Body::empty())
1145                    .unwrap(),
1146            )
1147            .await
1148            .unwrap();
1149
1150        assert_eq!(
1151            response
1152                .headers()
1153                .get("X-Cache-Age")
1154                .and_then(|v| v.to_str().ok())
1155                .unwrap_or(""),
1156            "2",
1157            "Age header should be present and equal to 2"
1158        );
1159
1160        assert_eq!(
1161            2,
1162            counter.read(),
1163            "handler should’ve been called only twice"
1164        );
1165    }
1166}