// axum_response_cache/src/lib.rs

1//! This library provides [Axum middleware](`axum#middleware`) that caches HTTP responses to the
2//! incoming requests based on their HTTP method and path.
3//!
4//! The main struct is [`CacheLayer`]. It can be created with any cache that implements two traits
5//! from the [`cached`] crate: [`cached::Cached`] and [`cached::CloneCached`].
6//!
7//! The *current* version of [`CacheLayer`] is compatible only with services accepting
8//! Axum’s [`Request<Body>`](`http::Request<axum::body::Body>`) and returning
9//! [`axum::response::Response`], thus it is not compatible with non-Axum [`tower`] services.
10//!
11//! It’s possible to configure the layer to re-use an old expired response in case the wrapped
12//! service fails to produce a new successful response.
13//!
14//! Only successful responses are cached (responses with status codes outside of the `[200-299]`
15//! range are passed-through or ignored).
16//!
17//! The cache limits maximum size of the response’s body (128 MB by default).
18//!
19//! ## Examples
20//!
21//! To cache a response over a specific route, just wrap it in a [`CacheLayer`]:
22//!
23//! ```rust,no_run
24//! # use axum_08 as axum;
25//! use axum::{Router, extract::Path, routing::get};
26//! use axum_response_cache::CacheLayer;
27//!
28//! #[tokio::main]
29//! async fn main() {
30//!     let mut router = Router::new()
31//!         .route(
32//!             "/hello/{name}",
33//!             get(|Path(name): Path<String>| async move { format!("Hello, {name}!") })
34//!                 // this will cache responses with each `:name` for 60 seconds.
35//!                 .layer(CacheLayer::with_lifespan(60)),
36//!         );
37//!
38//!     let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await.unwrap();
39//!     axum::serve(listener, router).await.unwrap();
40//! }
41//! ```
42//!
43//! ### Reusing last successful response
44//!
45//! ```rust
46//! # use std::sync::atomic::{AtomicBool, Ordering};
47//! # use axum_08 as axum;
48//! use axum::{
49//!     body::Body,
50//!     extract::Path,
51//!     http::status::StatusCode,
52//!     http::Request,
53//!     Router,
54//!     routing::get,
55//! };
56//! use axum_response_cache::CacheLayer;
57//! use tower::Service as _;
58//!
59//! // a handler that returns 200 OK only the first time it’s called
60//! async fn handler(Path(name): Path<String>) -> (StatusCode, String) {
61//!     static FIRST_RUN: AtomicBool = AtomicBool::new(true);
62//!     let first_run = FIRST_RUN.swap(false, Ordering::AcqRel);
63//!
64//!     if first_run {
65//!         (StatusCode::OK, format!("Hello, {name}"))
66//!     } else {
67//!         (StatusCode::INTERNAL_SERVER_ERROR, String::from("Error!"))
68//!     }
69//! }
70//!
71//! # #[tokio::main]
72//! # async fn main() {
73//! let mut router = Router::new()
74//!     .route("/hello/{name}", get(handler))
75//!     .layer(CacheLayer::with_lifespan(60).use_stale_on_failure());
76//!
77//! // first request will fire handler and get the response
78//! let status1 = router.call(Request::get("/hello/foo").body(Body::empty()).unwrap())
79//!     .await
80//!     .unwrap()
81//!     .status();
82//! assert_eq!(StatusCode::OK, status1);
83//!
84//! // second request will reuse the last response since the handler now returns ISE
85//! let status2 = router.call(Request::get("/hello/foo").body(Body::empty()).unwrap())
86//!     .await
87//!     .unwrap()
88//!     .status();
89//! assert_eq!(StatusCode::OK, status2);
90//! # }
91//! ```
92//!
93//! ### Serving static files
94//! This middleware can be used to cache files served in memory to limit hard drive load on the
95//! server. To serve files you can use [`tower-http::services::ServeDir`](https://docs.rs/tower-http/latest/tower_http/services/struct.ServeDir.html) layer.
96//! ```rust,ignore
97//! let router = Router::new().nest_service("/", ServeDir::new("static/"));
98//! ```
99//!
100//! ### Limiting the body size
101//!
102//! ```rust
103//! # use axum_08 as axum;
104//! use axum::{
105//!     body::Body,
106//!     extract::Path,
107//!     http::status::StatusCode,
108//!     http::Request,
109//!     Router,
110//!     routing::get,
111//! };
112//! use axum_response_cache::CacheLayer;
113//! use tower::Service as _;
114//!
115//! // returns a short string, well below the limit
116//! async fn ok_handler() -> &'static str {
117//!     "ok"
118//! }
119//!
120//! async fn too_long_handler() -> &'static str {
121//!     "a response that is well beyond the limit of the cache!"
122//! }
123//!
124//! # #[tokio::main]
125//! # async fn main() {
126//! let mut router = Router::new()
127//!     .route("/ok", get(ok_handler))
128//!     .route("/too_long", get(too_long_handler))
129//!     // limit max cached body to only 16 bytes
130//!     .layer(CacheLayer::with_lifespan(60).body_limit(16));
131//!
132//! let status_ok = router.call(Request::get("/ok").body(Body::empty()).unwrap())
133//!     .await
134//!     .unwrap()
135//!     .status();
136//! assert_eq!(StatusCode::OK, status_ok);
137//!
138//! let status_too_long = router.call(Request::get("/too_long").body(Body::empty()).unwrap())
139//!     .await
140//!     .unwrap()
141//!     .status();
142//! assert_eq!(StatusCode::INTERNAL_SERVER_ERROR, status_too_long);
143//! # }
144//! ```
145//! ### Manual Cache Invalidation
146//! This middleware allows manual cache invalidation by setting the `X-Invalidate-Cache` header in the request. This can be useful when you know the underlying data has changed and you want to force a fresh pull of data.
147//!
148//! ```rust
149//! # use axum_08 as axum;
150//! use axum::{
151//!     body::Body,
152//!     extract::Path,
153//!     http::status::StatusCode,
154//!     http::Request,
155//!     Router,
156//!     routing::get,
157//! };
158//! use axum_response_cache::CacheLayer;
159//! use tower::Service as _;
160//!
161//! async fn handler(Path(name): Path<String>) -> (StatusCode, String) {
162//!     (StatusCode::OK, format!("Hello, {name}"))
163//! }
164//!
165//! # #[tokio::main]
166//! # async fn main() {
167//! let mut router = Router::new()
168//!     .route("/hello/{name}", get(handler))
169//!     .layer(CacheLayer::with_lifespan(60).allow_invalidation());
170//!
171//! // first request will fire handler and get the response
172//! let status1 = router.call(Request::get("/hello/foo").body(Body::empty()).unwrap())
173//!     .await
174//!     .unwrap()
175//!     .status();
176//! assert_eq!(StatusCode::OK, status1);
177//!
178//! // second request should return the cached response
179//! let status2 = router.call(Request::get("/hello/foo").body(Body::empty()).unwrap())
180//!     .await
181//!     .unwrap()
182//!     .status();
183//! assert_eq!(StatusCode::OK, status2);
184//!
185//! // third request with X-Invalidate-Cache header to invalidate the cache
186//! let status3 = router.call(
187//!     Request::get("/hello/foo")
188//!         .header("X-Invalidate-Cache", "true")
189//!         .body(Body::empty())
190//!         .unwrap(),
191//!     )
192//!     .await
193//!     .unwrap()
194//!     .status();
195//! assert_eq!(StatusCode::OK, status3);
196//!
197//! // fourth request to verify that the handler is called again
198//! let status4 = router.call(Request::get("/hello/foo").body(Body::empty()).unwrap())
199//!     .await
200//!     .unwrap()
201//!     .status();
202//! assert_eq!(StatusCode::OK, status4);
203//! # }
204//! ```
205//! Cache invalidation could be dangerous because it can allow a user to force the server to make a request to an external service or database. It is disabled by default, but can be enabled by calling the [`CacheLayer::allow_invalidation`] method.
206//!
207//! ## Using custom cache
208//!
209//! ```rust
210//! # use axum_08 as axum;
211//! use axum::{Router, routing::get};
212//! use axum_response_cache::CacheLayer;
213//! // let’s use TimedSizedCache here
214//! use cached::stores::TimedSizedCache;
215//! # use axum::{body::Body, http::Request};
216//! # use tower::ServiceExt;
217//!
218//! # #[tokio::main]
219//! # async fn main() {
220//! let router: Router = Router::new()
221//!     .route("/hello", get(|| async { "Hello, world!" }))
222//!     // cache maximum value of 50 responses for one minute
223//!     .layer(CacheLayer::with(TimedSizedCache::with_size_and_lifespan(50, 60)));
224//! # // force type inference to resolve the exact type of router
225//! #     let _ = router.oneshot(Request::get("/hello").body(Body::empty()).unwrap()).await;
226//! # }
227//! ```
228//!
229//! ## Use cases
230//! Caching responses in memory (eg. using [`cached::TimedCache`]) might be useful when the
231//! underlying service produces the responses by:
232//! 1. doing heavy computation,
233//! 2. requesting external service(s) that might not be fully reliable or performant,
234//! 3. serving static files from disk.
235//!
236//! In those cases, if the response to identical requests does not change often over time, it might
//! be desirable to re-use the same responses from memory without re-calculating them – skipping
//! requests to databases, external services, reading from disk.
239//!
240//! ### Using Axum 0.7
241//!
242//! By default, this library uses Axum 0.8. However, you can configure it to use Axum 0.7 by enabling the appropriate feature flag in your `Cargo.toml`.
243//!
244//! To use Axum 0.7, add the following to your `Cargo.toml`:
245//!
246//! ```toml
247//! [dependencies]
248//! axum-response-cache = { version = "0.1.2", features = ["axum07"], default-features = false }
249//! ```
250//!
251//! This will disable the default Axum 0.8 feature and enable the Axum 0.7 feature instead.
252
253use std::{
254    convert::Infallible,
255    future::Future,
256    pin::Pin,
257    sync::{Arc, Mutex},
258    task::{Context, Poll},
259};
260use tracing_futures::Instrument as _;
261
262#[cfg(feature = "axum07")]
263use axum_07 as axum;
264#[cfg(feature = "axum08")]
265use axum_08 as axum;
266
267use axum::body;
268use axum::{
269    body::{Body, Bytes},
270    http::{response::Parts, Request, StatusCode},
271    response::{IntoResponse, Response},
272};
273
274use cached::{Cached, CloneCached, TimedCache};
275use tower::{Layer, Service};
276use tracing::{debug, instrument};
277
/// The caching key for the responses.
///
/// The responses are cached according to the HTTP method ([`axum::http::Method`]) and path
/// ([`axum::http::Uri`]) of the request they responded to.
type Key = (http::Method, http::Uri);
283
/// The struct preserving all the headers and body of the cached response.
#[derive(Clone, Debug)]
pub struct CachedResponse {
    // Status, version, headers, and extensions of the original response.
    parts: Parts,
    // The fully-buffered response body (size bounded by the layer's `limit`).
    body: Bytes,
    // When the entry was stored. `Some` only when `add_response_headers` is
    // enabled; used to emit the `X-Cache-Age` header on replay.
    timestamp: Option<std::time::Instant>,
}
291
292impl IntoResponse for CachedResponse {
293    fn into_response(self) -> Response {
294        let mut response = Response::from_parts(self.parts, Body::from(self.body));
295        if let Some(timestamp) = self.timestamp {
296            let age = timestamp.elapsed().as_secs();
297            response
298                .headers_mut()
299                .insert("X-Cache-Age", age.to_string().parse().unwrap());
300        }
301        response
302    }
303}
304
/// The main struct of the library. The layer providing caching to the wrapped service.
#[derive(Clone)]
pub struct CacheLayer<C> {
    // Shared response store; `Arc<Mutex<_>>` so every service created by this
    // layer (and its clones) sees the same entries.
    cache: Arc<Mutex<C>>,
    // Serve an expired entry when the inner service fails to produce a new
    // successful response (see `use_stale_on_failure`).
    use_stale: bool,
    // Maximum number of body bytes to buffer and cache.
    limit: usize,
    // Honour the `X-Invalidate-Cache` request header (see `allow_invalidation`).
    allow_invalidation: bool,
    // Attach the `X-Cache-Age` header to responses served from the cache.
    add_response_headers: bool,
}
314
315impl<C> CacheLayer<C>
316where
317    C: Cached<Key, CachedResponse> + CloneCached<Key, CachedResponse>,
318{
319    /// Create a new cache layer with a given cache and the default body size limit of 128 MB.
320    pub fn with(cache: C) -> Self {
321        Self {
322            cache: Arc::new(Mutex::new(cache)),
323            use_stale: false,
324            limit: 128 * 1024 * 1024,
325            allow_invalidation: false,
326            add_response_headers: false,
327        }
328    }
329
330    /// Switch the layer’s settings to preserve the last successful response even when it’s evicted
331    /// from the cache but the service failed to provide a new successful response (ie. eg. when
332    /// the underlying service responds with `404 NOT FOUND`, the cache will keep providing the last stale `200 OK`
333    /// response produced).
334    pub fn use_stale_on_failure(self) -> Self {
335        Self {
336            use_stale: true,
337            ..self
338        }
339    }
340
341    /// Change the maximum body size limit. If you want unlimited size, use [`usize::MAX`].
342    pub fn body_limit(self, new_limit: usize) -> Self {
343        Self {
344            limit: new_limit,
345            ..self
346        }
347    }
348
349    /// Allow manual cache invalidation by setting the `X-Invalidate-Cache` header in the request.
350    /// This will allow the cache to be invalidated for the given key.
351    pub fn allow_invalidation(self) -> Self {
352        Self {
353            allow_invalidation: true,
354            ..self
355        }
356    }
357
358    /// Allow the response headers to be included in the cached response.
359    pub fn add_response_headers(self) -> Self {
360        Self {
361            add_response_headers: true,
362            ..self
363        }
364    }
365}
366
367impl CacheLayer<TimedCache<Key, CachedResponse>> {
368    /// Create a new cache layer with the desired TTL in seconds
369    pub fn with_lifespan(ttl_sec: u64) -> CacheLayer<TimedCache<Key, CachedResponse>> {
370        CacheLayer::with(TimedCache::with_lifespan(ttl_sec))
371    }
372}
373
374impl<S, C> Layer<S> for CacheLayer<C> {
375    type Service = CacheService<S, C>;
376
377    fn layer(&self, inner: S) -> Self::Service {
378        Self::Service {
379            inner,
380            cache: Arc::clone(&self.cache),
381            use_stale: self.use_stale,
382            limit: self.limit,
383            allow_invalidation: self.allow_invalidation,
384            add_response_headers: self.add_response_headers,
385        }
386    }
387}
388
// The tower service produced by `CacheLayer::layer`; wraps `inner` with
// response caching. Fields mirror `CacheLayer` (see its documentation).
#[derive(Clone)]
pub struct CacheService<S, C> {
    // The wrapped service that produces fresh responses.
    inner: S,
    // Store shared with the layer and all sibling service clones.
    cache: Arc<Mutex<C>>,
    // Serve an expired entry when the inner service fails.
    use_stale: bool,
    // Maximum number of body bytes to buffer and cache.
    limit: usize,
    // Honour the `X-Invalidate-Cache` request header.
    allow_invalidation: bool,
    // Attach the `X-Cache-Age` header to cached responses.
    add_response_headers: bool,
}
398
399impl<S, C> Service<Request<Body>> for CacheService<S, C>
400where
401    S: Service<Request<Body>, Response = Response, Error = Infallible> + Clone + Send,
402    S::Future: Send + 'static,
403    C: Cached<Key, CachedResponse> + CloneCached<Key, CachedResponse> + Send + 'static,
404{
405    type Response = Response;
406    type Error = Infallible;
407    type Future = Pin<Box<dyn Future<Output = Result<Response, Infallible>> + Send + 'static>>;
408
409    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
410        self.inner.poll_ready(cx)
411    }
412
413    #[instrument(skip(self, request))]
414    fn call(&mut self, request: Request<Body>) -> Self::Future {
415        let mut inner = self.inner.clone();
416        let use_stale = self.use_stale;
417        let allow_invalidation = self.allow_invalidation;
418        let add_response_headers = self.add_response_headers;
419        let limit = self.limit;
420        let cache = Arc::clone(&self.cache);
421        let key = (request.method().clone(), request.uri().clone());
422
423        // Check for the custom header "X-Invalidate-Cache" if invalidation is allowed
424        if allow_invalidation && request.headers().contains_key("X-Invalidate-Cache") {
425            // Manually invalidate the cache for this key
426            cache.lock().unwrap().cache_remove(&key);
427            debug!("Cache invalidated manually for key {:?}", key);
428        }
429
430        let inner_fut = inner
431            .call(request)
432            .instrument(tracing::info_span!("inner_service"));
433        let (cached, evicted) = {
434            let mut guard = cache.lock().unwrap();
435            let (cached, evicted) = guard.cache_get_expired(&key);
436            if let (Some(stale), true) = (cached.as_ref(), evicted) {
437                // reinsert stale value immediately so that others don’t schedule their updating
438                debug!("Found stale value in cache, reinsterting and attempting refresh");
439                guard.cache_set(key.clone(), stale.clone());
440            }
441            (cached, evicted)
442        };
443
444        Box::pin(async move {
445            match (cached, evicted) {
446                (Some(value), false) => Ok(value.into_response()),
447                (Some(stale_value), true) => {
448                    let response = inner_fut.await.unwrap();
449                    if response.status().is_success() {
450                        Ok(update_cache(&cache, key, response, limit, add_response_headers).await)
451                    } else if use_stale {
452                        debug!("Returning stale value.");
453                        Ok(stale_value.into_response())
454                    } else {
455                        debug!("Stale value in cache, evicting and returning failed response.");
456                        cache.lock().unwrap().cache_remove(&key);
457                        Ok(response)
458                    }
459                }
460                (None, _) => {
461                    let response = inner_fut.await.unwrap();
462                    if response.status().is_success() {
463                        Ok(update_cache(&cache, key, response, limit, add_response_headers).await)
464                    } else {
465                        Ok(response)
466                    }
467                }
468            }
469        })
470    }
471}
472
473#[instrument(skip(cache, response))]
474async fn update_cache<C: Cached<Key, CachedResponse> + CloneCached<Key, CachedResponse>>(
475    cache: &Arc<Mutex<C>>,
476    key: Key,
477    response: Response,
478    limit: usize,
479    add_response_headers: bool,
480) -> Response {
481    let (parts, body) = response.into_parts();
482    let Ok(body) = body::to_bytes(body, limit).await else {
483        return (
484            StatusCode::INTERNAL_SERVER_ERROR,
485            format!("File too big, over {limit} bytes"),
486        )
487            .into_response();
488    };
489    let value = CachedResponse {
490        parts,
491        body,
492        timestamp: if add_response_headers {
493            Some(std::time::Instant::now())
494        } else {
495            None
496        },
497    };
498    {
499        cache.lock().unwrap().cache_set(key, value.clone());
500    }
501    value.into_response()
502}
503
/// End-to-end tests driving the middleware through a real Axum `Router` with
/// `tower::Service::call`, counting handler invocations to observe cache hits.
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;
    use std::sync::atomic::{AtomicIsize, Ordering};

    #[cfg(feature = "axum07")]
    use axum_07 as axum;
    #[cfg(feature = "axum08")]
    use axum_08 as axum;

    use axum::{
        extract::State,
        http::{Request, StatusCode},
        routing::get,
        Router,
    };

    use tower::Service;

    // Thread-safe call counter shared with the handlers via Axum state; lets
    // the tests assert how many times the inner handler actually ran.
    #[derive(Clone, Debug)]
    struct Counter {
        value: Arc<AtomicIsize>,
    }

    impl Counter {
        fn new(init: isize) -> Self {
            Self {
                value: AtomicIsize::from(init).into(),
            }
        }

        fn increment(&self) {
            self.value.fetch_add(1, Ordering::Release);
        }

        fn read(&self) -> isize {
            self.value.load(Ordering::Acquire)
        }
    }

    /// A fresh cache entry should be replayed: 10 requests, 1 handler call.
    #[tokio::test]
    async fn should_use_cached_value() {
        let handler = |State(cnt): State<Counter>| async move {
            cnt.increment();
            StatusCode::OK
        };

        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(60).use_stale_on_failure();
        let mut router = Router::new()
            .route("/", get(handler).layer(cache))
            .with_state(counter.clone());

        for _ in 0..10 {
            let status = router
                .call(Request::get("/").body(Body::empty()).unwrap())
                .await
                .unwrap()
                .status();
            assert!(status.is_success(), "handler should return success");
        }

        assert_eq!(1, counter.read(), "handler should’ve been called only once");
    }

    /// Non-2xx responses must pass through uncached: every request hits the handler.
    #[tokio::test]
    async fn should_not_cache_unsuccessful_responses() {
        let handler = |State(cnt): State<Counter>| async move {
            cnt.increment();
            let responses = [
                StatusCode::BAD_REQUEST,
                StatusCode::INTERNAL_SERVER_ERROR,
                StatusCode::NOT_FOUND,
            ];
            let mut rng = rand::thread_rng();
            responses[rng.gen_range(0..responses.len())]
        };

        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(60).use_stale_on_failure();
        let mut router = Router::new()
            .route("/", get(handler).layer(cache))
            .with_state(counter.clone());

        for _ in 0..10 {
            let status = router
                .call(Request::get("/").body(Body::empty()).unwrap())
                .await
                .unwrap()
                .status();
            assert!(!status.is_success(), "handler should never return success");
        }

        assert_eq!(
            10,
            counter.read(),
            "handler should’ve been called for all requests"
        );
    }

    /// With `use_stale_on_failure`, an expired 200 is served while refreshes fail.
    #[tokio::test]
    async fn should_use_last_correct_stale_value() {
        let handler = |State(cnt): State<Counter>| async move {
            let prev = cnt.value.fetch_add(1, Ordering::AcqRel);
            let responses = [
                StatusCode::BAD_REQUEST,
                StatusCode::INTERNAL_SERVER_ERROR,
                StatusCode::NOT_FOUND,
            ];
            let mut rng = rand::thread_rng();

            // first response successful, later failed
            if prev == 0 {
                StatusCode::OK
            } else {
                responses[rng.gen_range(0..responses.len())]
            }
        };

        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(1).use_stale_on_failure();
        let mut router = Router::new()
            .route("/", get(handler).layer(cache))
            .with_state(counter);

        // feed the cache
        let status = router
            .call(Request::get("/").body(Body::empty()).unwrap())
            .await
            .unwrap()
            .status();
        assert!(status.is_success(), "handler should return success");

        // wait over 1s for cache eviction
        tokio::time::sleep(tokio::time::Duration::from_millis(1050)).await;

        for _ in 1..10 {
            let status = router
                .call(Request::get("/").body(Body::empty()).unwrap())
                .await
                .unwrap()
                .status();
            assert!(
                status.is_success(),
                "cache should return stale successful value"
            );
        }
    }

    /// Without `use_stale_on_failure`, expired entries are evicted and failures forwarded.
    #[tokio::test]
    async fn should_not_use_stale_values() {
        let handler = |State(cnt): State<Counter>| async move {
            let prev = cnt.value.fetch_add(1, Ordering::AcqRel);
            let responses = [
                StatusCode::BAD_REQUEST,
                StatusCode::INTERNAL_SERVER_ERROR,
                StatusCode::NOT_FOUND,
            ];
            let mut rng = rand::thread_rng();

            // first response successful, later failed
            if prev == 0 {
                StatusCode::OK
            } else {
                responses[rng.gen_range(0..responses.len())]
            }
        };

        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(1);
        let mut router = Router::new()
            .route("/", get(handler).layer(cache))
            .with_state(counter.clone());

        // feed the cache
        let status = router
            .call(Request::get("/").body(Body::empty()).unwrap())
            .await
            .unwrap()
            .status();
        assert!(status.is_success(), "handler should return success");

        // wait over 1s for cache eviction
        tokio::time::sleep(tokio::time::Duration::from_millis(1050)).await;

        for _ in 1..10 {
            let status = router
                .call(Request::get("/").body(Body::empty()).unwrap())
                .await
                .unwrap()
                .status();
            assert!(
                !status.is_success(),
                "cache should forward unsuccessful values"
            );
        }

        assert_eq!(
            10,
            counter.read(),
            "handler should’ve been called for all requests"
        );
    }

    /// The `X-Invalidate-Cache` header is ignored unless `allow_invalidation` is set.
    #[tokio::test]
    async fn should_not_invalidate_cache_when_disabled() {
        let handler = |State(cnt): State<Counter>| async move {
            cnt.increment();
            StatusCode::OK
        };

        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(60);
        let mut router = Router::new()
            .route("/", get(handler).layer(cache))
            .with_state(counter.clone());

        // First request to cache the response
        let status = router
            .call(Request::get("/").body(Body::empty()).unwrap())
            .await
            .unwrap()
            .status();
        assert!(status.is_success(), "handler should return success");

        // Second request should return the cached response - no increment
        let status = router
            .call(Request::get("/").body(Body::empty()).unwrap())
            .await
            .unwrap()
            .status();
        assert!(status.is_success(), "handler should return success");

        // Third request with X-Invalidate-Cache header should not invalidate the cache - no increment
        let status = router
            .call(
                Request::get("/")
                    .header("X-Invalidate-Cache", "true")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap()
            .status();
        assert!(status.is_success(), "handler should return success");

        // Fourth request should still return the cached response - no increment
        let status = router
            .call(Request::get("/").body(Body::empty()).unwrap())
            .await
            .unwrap()
            .status();
        assert!(status.is_success(), "handler should return success");

        assert_eq!(1, counter.read(), "handler should’ve been called only once");
    }

    /// With `allow_invalidation`, the header evicts the entry and forces a refresh.
    #[tokio::test]
    async fn should_invalidate_cache_when_enabled() {
        let handler = |State(cnt): State<Counter>| async move {
            cnt.increment();
            StatusCode::OK
        };

        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(60).allow_invalidation();
        let mut router = Router::new()
            .route("/", get(handler).layer(cache))
            .with_state(counter.clone());

        // First request to cache the response
        let status = router
            .call(Request::get("/").body(Body::empty()).unwrap())
            .await
            .unwrap()
            .status();
        assert!(status.is_success(), "handler should return success");

        // Second request should return the cached response - no increment
        let status = router
            .call(Request::get("/").body(Body::empty()).unwrap())
            .await
            .unwrap()
            .status();
        assert!(status.is_success(), "handler should return success");

        // Third request with X-Invalidate-Cache header to invalidate the cache
        let status = router
            .call(
                Request::get("/")
                    .header("X-Invalidate-Cache", "true")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap()
            .status();
        assert!(status.is_success(), "handler should return success");

        // Fourth request to verify that the handler is called again
        let status = router
            .call(Request::get("/").body(Body::empty()).unwrap())
            .await
            .unwrap()
            .status();
        assert!(status.is_success(), "handler should return success");

        assert_eq!(2, counter.read(), "handler should’ve been called twice");
    }

    /// Without `add_response_headers`, cached replays carry no `X-Cache-Age`.
    #[tokio::test]
    async fn should_not_include_age_header_when_disabled() {
        let handler = |State(cnt): State<Counter>| async move {
            cnt.increment();
            StatusCode::OK
        };

        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(60);
        let mut router = Router::new()
            .route("/", get(handler).layer(cache))
            .with_state(counter.clone());

        // First request to cache the response
        let response = router
            .call(Request::get("/").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert!(
            response.status().is_success(),
            "handler should return success"
        );

        // Second request should return the cached response
        let response = router
            .call(Request::get("/").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert!(
            response.status().is_success(),
            "handler should return success"
        );
        assert!(
            response.headers().get("X-Cache-Age").is_none(),
            "Age header should not be present"
        );

        assert_eq!(1, counter.read(), "handler should’ve been called only once");
    }

    /// With `add_response_headers`, `X-Cache-Age` reports the entry's age in seconds.
    #[tokio::test]
    async fn should_include_age_header_when_enabled() {
        let handler = |State(cnt): State<Counter>| async move {
            cnt.increment();
            StatusCode::OK
        };

        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(60).add_response_headers();
        let mut router = Router::new()
            .route("/", get(handler).layer(cache))
            .with_state(counter.clone());

        // First request to cache the response
        let response = router
            .call(Request::get("/").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert!(
            response.status().is_success(),
            "handler should return success"
        );

        // Age should be 0
        assert_eq!(
            response
                .headers()
                .get("X-Cache-Age")
                .and_then(|v| v.to_str().ok())
                .unwrap_or(""),
            "0",
            "Age header should be present and equal to 0"
        );
        // wait over 2s to age the cache
        tokio::time::sleep(tokio::time::Duration::from_millis(2100)).await;
        // Second request should return the cached response
        let response = router
            .call(Request::get("/").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(
            response
                .headers()
                .get("X-Cache-Age")
                .and_then(|v| v.to_str().ok())
                .unwrap_or(""),
            "2",
            "Age header should be present and equal to 2"
        );

        assert_eq!(1, counter.read(), "handler should’ve been called only once");
    }
}