// axum_response_cache/lib.rs

1//! This library provides [Axum middleware](`axum#middleware`) that caches HTTP responses to the
2//! incoming requests based on their HTTP method and path.
3//!
4//! The main struct is [`CacheLayer`]. It can be created with any cache that implements two traits
5//! from the [`cached`] crate: [`cached::Cached`] and [`cached::CloneCached`].
6//!
7//! The *current* version of [`CacheLayer`] is compatible only with services accepting
8//! Axum’s [`Request<Body>`](`http::Request<axum::body::Body>`) and returning
9//! [`axum::response::Response`], thus it is not compatible with non-Axum [`tower`] services.
10//!
11//! It’s possible to configure the layer to re-use an old expired response in case the wrapped
12//! service fails to produce a new successful response.
13//!
14//! Only successful responses are cached (responses with status codes outside of the `[200-299]`
15//! range are passed-through or ignored).
16//!
17//! The cache limits maximum size of the response’s body (128 MB by default).
18//!
19//! ## Examples
20//!
21//! To cache a response over a specific route, just wrap it in a [`CacheLayer`]:
22//!
23//! ```rust,no_run
24//! # use axum_08 as axum;
25//! use std::time::Duration;
26//! use axum::{Router, extract::Path, routing::get};
27//! use axum_response_cache::CacheLayer;
28//!
29//! #[tokio::main]
30//! async fn main() {
31//!     let mut router = Router::new()
32//!         .route(
33//!             "/hello/{name}",
34//!             get(|Path(name): Path<String>| async move { format!("Hello, {name}!") })
35//!                 // this will cache responses with each `:name` for 60 seconds.
36//!                 .layer(CacheLayer::with_lifespan(Duration::from_secs(60))),
37//!         );
38//!
39//!     let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await.unwrap();
40//!     axum::serve(listener, router).await.unwrap();
41//! }
42//! ```
43//!
44//! ### Reusing last successful response
45//!
46//! ```rust
47//! # use std::sync::atomic::{AtomicBool, Ordering};
48//! # use axum_08 as axum;
49//! use std::time::Duration;
50//! use axum::{
51//!     body::Body,
52//!     extract::Path,
53//!     http::status::StatusCode,
54//!     http::Request,
55//!     Router,
56//!     routing::get,
57//! };
58//! use axum_response_cache::CacheLayer;
59//! use tower::Service as _;
60//!
61//! // a handler that returns 200 OK only the first time it’s called
62//! async fn handler(Path(name): Path<String>) -> (StatusCode, String) {
63//!     static FIRST_RUN: AtomicBool = AtomicBool::new(true);
64//!     let first_run = FIRST_RUN.swap(false, Ordering::AcqRel);
65//!
66//!     if first_run {
67//!         (StatusCode::OK, format!("Hello, {name}"))
68//!     } else {
69//!         (StatusCode::INTERNAL_SERVER_ERROR, String::from("Error!"))
70//!     }
71//! }
72//!
73//! # #[tokio::main]
74//! # async fn main() {
75//! let mut router = Router::new()
76//!     .route("/hello/{name}", get(handler))
77//!     .layer(CacheLayer::with_lifespan(Duration::from_secs(60)).use_stale_on_failure());
78//!
79//! // first request will fire handler and get the response
80//! let status1 = router.call(Request::get("/hello/foo").body(Body::empty()).unwrap())
81//!     .await
82//!     .unwrap()
83//!     .status();
84//! assert_eq!(StatusCode::OK, status1);
85//!
86//! // second request will reuse the last response since the handler now returns ISE
87//! let status2 = router.call(Request::get("/hello/foo").body(Body::empty()).unwrap())
88//!     .await
89//!     .unwrap()
90//!     .status();
91//! assert_eq!(StatusCode::OK, status2);
92//! # }
93//! ```
94//!
95//! ### Serving static files
//! This middleware can be used to keep served files cached in memory, limiting hard drive load on
//! the server. To serve the files you can use the [`tower-http::services::ServeDir`](https://docs.rs/tower-http/latest/tower_http/services/struct.ServeDir.html) service, wrapped in the cache layer:
//! ```rust,ignore
//! let router = Router::new()
//!     .nest_service("/", ServeDir::new("static/"))
//!     .layer(CacheLayer::with_lifespan(Duration::from_secs(60)));
//! ```
101//!
102//! ### Limiting the body size
103//!
104//! ```rust
105//! # use axum_08 as axum;
106//! use std::time::Duration;
107//! use axum::{
108//!     body::Body,
109//!     extract::Path,
110//!     http::status::StatusCode,
111//!     http::Request,
112//!     Router,
113//!     routing::get,
114//! };
115//! use axum_response_cache::CacheLayer;
116//! use tower::Service as _;
117//!
118//! // returns a short string, well below the limit
119//! async fn ok_handler() -> &'static str {
120//!     "ok"
121//! }
122//!
123//! async fn too_long_handler() -> &'static str {
124//!     "a response that is well beyond the limit of the cache!"
125//! }
126//!
127//! # #[tokio::main]
128//! # async fn main() {
129//! let mut router = Router::new()
130//!     .route("/ok", get(ok_handler))
131//!     .route("/too_long", get(too_long_handler))
132//!     // limit max cached body to only 16 bytes
133//!     .layer(CacheLayer::with_lifespan(Duration::from_secs(60)).body_limit(16));
134//!
135//! let status_ok = router.call(Request::get("/ok").body(Body::empty()).unwrap())
136//!     .await
137//!     .unwrap()
138//!     .status();
139//! assert_eq!(StatusCode::OK, status_ok);
140//!
141//! let status_too_long = router.call(Request::get("/too_long").body(Body::empty()).unwrap())
142//!     .await
143//!     .unwrap()
144//!     .status();
145//! assert_eq!(StatusCode::INTERNAL_SERVER_ERROR, status_too_long);
146//! # }
147//! ```
148//! ### Manual Cache Invalidation
149//! This middleware allows manual cache invalidation by setting the `X-Invalidate-Cache` header in the request. This can be useful when you know the underlying data has changed and you want to force a fresh pull of data.
150//!
151//! ```rust
152//! # use axum_08 as axum;
153//! use std::time::Duration;
154//! use axum::{
155//!     body::Body,
156//!     extract::Path,
157//!     http::status::StatusCode,
158//!     http::Request,
159//!     Router,
160//!     routing::get,
161//! };
162//! use axum_response_cache::CacheLayer;
163//! use tower::Service as _;
164//!
165//! async fn handler(Path(name): Path<String>) -> (StatusCode, String) {
166//!     (StatusCode::OK, format!("Hello, {name}"))
167//! }
168//!
169//! # #[tokio::main]
170//! # async fn main() {
171//! let mut router = Router::new()
172//!     .route("/hello/{name}", get(handler))
173//!     .layer(CacheLayer::with_lifespan(Duration::from_secs(60)).allow_invalidation());
174//!
175//! // first request will fire handler and get the response
176//! let status1 = router.call(Request::get("/hello/foo").body(Body::empty()).unwrap())
177//!     .await
178//!     .unwrap()
179//!     .status();
180//! assert_eq!(StatusCode::OK, status1);
181//!
182//! // second request should return the cached response
183//! let status2 = router.call(Request::get("/hello/foo").body(Body::empty()).unwrap())
184//!     .await
185//!     .unwrap()
186//!     .status();
187//! assert_eq!(StatusCode::OK, status2);
188//!
189//! // third request with X-Invalidate-Cache header to invalidate the cache
190//! let status3 = router.call(
191//!     Request::get("/hello/foo")
192//!         .header("X-Invalidate-Cache", "true")
193//!         .body(Body::empty())
194//!         .unwrap(),
195//!     )
196//!     .await
197//!     .unwrap()
198//!     .status();
199//! assert_eq!(StatusCode::OK, status3);
200//!
201//! // fourth request to verify that the handler is called again
202//! let status4 = router.call(Request::get("/hello/foo").body(Body::empty()).unwrap())
203//!     .await
204//!     .unwrap()
205//!     .status();
206//! assert_eq!(StatusCode::OK, status4);
207//! # }
208//! ```
209//! Cache invalidation could be dangerous because it can allow a user to force the server to make a request to an external service or database. It is disabled by default, but can be enabled by calling the [`CacheLayer::allow_invalidation`] method.
210//!
211//! ## Using custom cache
212//! ```rust
213//! # use axum_08 as axum;
214//! use std::time::Duration;
215//! use axum::{Router, routing::get};
216//! use axum_response_cache::CacheLayer;
217//! // let’s use TimedSizedCache here
218//! use cached::stores::TimedSizedCache;
219//! # use axum::{body::Body, http::Request};
220//! # use tower::ServiceExt;
221//!
222//! # #[tokio::main]
223//! # async fn main() {
224//! let router: Router = Router::new()
225//!     .route("/hello", get(|| async { "Hello, world!" }))
226//!     // cache maximum value of 50 responses for one minute
227//!     .layer(CacheLayer::with(TimedSizedCache::with_size_and_lifespan(50, Duration::from_secs(60))));
228//! # // force type inference to resolve the exact type of router
229//! #     let _ = router.oneshot(Request::get("/hello").body(Body::empty()).unwrap()).await;
230//! # }
231//! ```
232//!
233//! ## Using custom keyer
234//! It’s possible to customize the cache’s key to include eg. the `Accept` header (so that
235//! different types of responses are cached separately based on the header).
236//!
237//! ```rust
238//! # use axum_08 as axum;
239//! use std::time::Duration;
240//! use axum::{Router, routing::get};
241//! use axum_response_cache::CacheLayer;
242//! # use axum::{body::Body, http::Request};
243//! # use tower::ServiceExt;
244//!
245//! # #[tokio::main]
246//! # async fn main() {
247//! // cache responses based on method, Accept header, and uri
248//! let keyer = |request: &Request<Body>| {
249//!     (
250//!         request.method().clone(),
251//!         request
252//!             .headers()
253//!             .get(axum::http::header::ACCEPT)
254//!             .and_then(|c| c.to_str().ok())
255//!             .unwrap_or("")
256//!             .to_string(),
257//!         request.uri().clone(),
258//!     )
259//! };
260//! let router: Router = Router::new()
261//!     .route("/hello", get(|| async { "Hello, world!" }))
262//!     .layer(CacheLayer::with_lifespan_and_keyer(Duration::from_secs(60), keyer));
263//! # // force type inference to resolve the exact type of router
264//! #     let _ = router.oneshot(Request::get("/hello").body(Body::empty()).unwrap()).await;
265//! # }
266//! ```
267//!
268//! ## Use cases
269//! Caching responses in memory (eg. using [`cached::TimedCache`]) might be useful when the
270//! underlying service produces the responses by:
271//! 1. doing heavy computation,
272//! 2. requesting external service(s) that might not be fully reliable or performant,
273//! 3. serving static files from disk.
274//!
275//! In those cases, if the response to identical requests does not change often over time, it might
276//! be desirable to re-use the same responses from memory without re-calculating them – skipping requests to data
277//! bases, external services, reading from disk.
278//!
279//! ### Using Axum 0.7
280//!
281//! By default, this library uses Axum 0.8. However, you can configure it to use Axum 0.7 by enabling the appropriate feature flag in your `Cargo.toml`.
282//!
283//! To use Axum 0.7, add the following to your `Cargo.toml`:
284//!
285//! ```toml
286//! [dependencies]
287//! axum-response-cache = { version = "0.3", features = ["axum07"], default-features = false }
288//! ```
289//!
290//! This will disable the default Axum 0.8 feature and enable the Axum 0.7 feature instead.
291
292use std::{
293    convert::Infallible,
294    fmt::Debug,
295    future::Future,
296    hash::Hash,
297    pin::Pin,
298    sync::{Arc, Mutex},
299    task::{Context, Poll},
300    time::Duration,
301};
302use tracing_futures::Instrument as _;
303
304#[cfg(feature = "axum07")]
305use axum_07 as axum;
306#[cfg(feature = "axum08")]
307use axum_08 as axum;
308
309use axum::body;
310use axum::{
311    body::{Body, Bytes},
312    http::{response::Parts, Request, StatusCode},
313    response::{IntoResponse, Response},
314};
315
316use cached::{Cached, CloneCached, TimedCache};
317use tower::{Layer, Service};
318use tracing::{debug, instrument};
319
/// The trait for objects used to obtain cache keys. See [`BasicKeyer`] for default implementation
/// returning `(http::Method, Uri)`.
pub trait Keyer {
    /// The type of the key under which responses are stored in the cache.
    type Key;

    /// Compute the cache key for the given request.
    fn get_key(&self, request: &Request<Body>) -> Self::Key;
}
327
/// Any `Fn(&Request<Body>) -> K` closure (or free function) can be used as a [`Keyer`]:
/// the cache key is whatever the closure returns for a given request.
impl<K, F> Keyer for F
where
    F: Fn(&Request<Body>) -> K + Send + Sync + 'static,
{
    type Key = K;

    fn get_key(&self, request: &Request<Body>) -> Self::Key {
        // Delegate straight to the wrapped closure.
        self(request)
    }
}
338
/// The basic caching strategy for the responses.
///
/// The responses are cached according to the HTTP method ([`axum::http::Method`]) and path
/// ([`axum::http::Uri`]) of the request they responded to.
///
/// This is the keyer used by [`CacheLayer::with`] and [`CacheLayer::with_lifespan`].
pub struct BasicKeyer;
344
345pub type BasicKey = (http::Method, http::Uri);
346
347impl Keyer for BasicKeyer {
348    type Key = BasicKey;
349
350    fn get_key(&self, request: &Request<Body>) -> Self::Key {
351        (request.method().clone(), request.uri().clone())
352    }
353}
354
/// The struct preserving all the headers and body of the cached response.
#[derive(Clone, Debug)]
pub struct CachedResponse {
    /// Status code, version, headers and extensions of the original response.
    parts: Parts,
    /// The fully-buffered response body.
    body: Bytes,
    /// When the response was cached; populated only when
    /// [`CacheLayer::add_response_headers`] is enabled (used to emit `X-Cache-Age`).
    timestamp: Option<std::time::Instant>,
}
362
363impl IntoResponse for CachedResponse {
364    fn into_response(self) -> Response {
365        let mut response = Response::from_parts(self.parts, Body::from(self.body));
366        if let Some(timestamp) = self.timestamp {
367            let age = timestamp.elapsed().as_secs();
368            response
369                .headers_mut()
370                .insert("X-Cache-Age", age.to_string().parse().unwrap());
371        }
372        response
373    }
374}
375
/// The main struct of the library. The layer providing caching to the wrapped service.
/// It is generic over the cache used (`C`) and a `Keyer` (`K`) used to obtain the key for cached
/// responses.
pub struct CacheLayer<C, K> {
    /// The shared response cache, handed to every [`CacheService`] this layer creates.
    cache: Arc<Mutex<C>>,
    /// Whether to serve an expired cached response when the inner service fails.
    use_stale: bool,
    /// Maximum number of body bytes to buffer and cache.
    limit: usize,
    /// Whether the `X-Invalidate-Cache` request header may evict entries.
    allow_invalidation: bool,
    /// Whether cached responses get an `X-Cache-Age` header.
    add_response_headers: bool,
    /// Strategy used to derive the cache key from a request.
    keyer: Arc<K>,
}
387
388impl<C, K> Clone for CacheLayer<C, K> {
389    fn clone(&self) -> Self {
390        Self {
391            cache: Arc::clone(&self.cache),
392            use_stale: self.use_stale,
393            limit: self.limit,
394            allow_invalidation: self.allow_invalidation,
395            add_response_headers: self.add_response_headers,
396            keyer: Arc::clone(&self.keyer),
397        }
398    }
399}
400
401impl<C, K> CacheLayer<C, K>
402where
403    C: Cached<K::Key, CachedResponse> + CloneCached<K::Key, CachedResponse>,
404    K: Keyer,
405    K::Key: Debug + Hash + Eq + Clone + Send + 'static,
406{
407    /// Create a new cache layer with a given cache and the default body size limit of 128 MB.
408    pub fn with_cache_and_keyer(cache: C, keyer: K) -> Self {
409        Self {
410            cache: Arc::new(Mutex::new(cache)),
411            use_stale: false,
412            limit: 128 * 1024 * 1024,
413            allow_invalidation: false,
414            add_response_headers: false,
415            keyer: Arc::new(keyer),
416        }
417    }
418
419    /// Switch the layer’s settings to preserve the last successful response even when it’s evicted
420    /// from the cache but the service failed to provide a new successful response (ie. eg. when
421    /// the underlying service responds with `404 NOT FOUND`, the cache will keep providing the last stale `200 OK`
422    /// response produced).
423    pub fn use_stale_on_failure(self) -> Self {
424        Self {
425            use_stale: true,
426            ..self
427        }
428    }
429
430    /// Change the maximum body size limit. If you want unlimited size, use [`usize::MAX`].
431    pub fn body_limit(self, new_limit: usize) -> Self {
432        Self {
433            limit: new_limit,
434            ..self
435        }
436    }
437
438    /// Allow manual cache invalidation by setting the `X-Invalidate-Cache` header in the request.
439    /// This will allow the cache to be invalidated for the given key.
440    pub fn allow_invalidation(self) -> Self {
441        Self {
442            allow_invalidation: true,
443            ..self
444        }
445    }
446
447    /// Allow the response headers to be included in the cached response.
448    pub fn add_response_headers(self) -> Self {
449        Self {
450            add_response_headers: true,
451            ..self
452        }
453    }
454}
455
456impl<C> CacheLayer<C, BasicKeyer>
457where
458    C: Cached<BasicKey, CachedResponse> + CloneCached<BasicKey, CachedResponse>,
459{
460    /// Create a new cache layer with a given cache and the default body size limit of 128 MB.
461    pub fn with(cache: C) -> Self {
462        Self {
463            cache: Arc::new(Mutex::new(cache)),
464            use_stale: false,
465            limit: 128 * 1024 * 1024,
466            allow_invalidation: false,
467            add_response_headers: false,
468            keyer: Arc::new(BasicKeyer),
469        }
470    }
471}
472
473impl CacheLayer<TimedCache<BasicKey, CachedResponse>, BasicKey> {
474    /// Create a new cache layer with the desired TTL
475    pub fn with_lifespan(
476        ttl: Duration,
477    ) -> CacheLayer<TimedCache<BasicKey, CachedResponse>, BasicKeyer> {
478        CacheLayer::with(TimedCache::with_lifespan(ttl))
479    }
480}
481
482impl<K> CacheLayer<TimedCache<K::Key, CachedResponse>, K>
483where
484    K: Keyer,
485    K::Key: Debug + Hash + Eq + Clone + Send + 'static,
486{
487    /// Create a new cache layer with the desired TTL
488    pub fn with_lifespan_and_keyer(
489        ttl: Duration,
490        keyer: K,
491    ) -> CacheLayer<TimedCache<K::Key, CachedResponse>, K> {
492        CacheLayer::with_cache_and_keyer(TimedCache::with_lifespan(ttl), keyer)
493    }
494}
495
496impl<S, C, K> Layer<S> for CacheLayer<C, K>
497where
498    K: Keyer,
499    K::Key: Debug + Hash + Eq + Clone + Send + 'static,
500{
501    type Service = CacheService<S, C, K>;
502
503    fn layer(&self, inner: S) -> Self::Service {
504        Self::Service {
505            inner,
506            cache: Arc::clone(&self.cache),
507            use_stale: self.use_stale,
508            limit: self.limit,
509            allow_invalidation: self.allow_invalidation,
510            add_response_headers: self.add_response_headers,
511            keyer: Arc::clone(&self.keyer),
512        }
513    }
514}
515
/// The [`tower::Service`] produced by [`CacheLayer`]: serves responses from the shared cache
/// and falls back to the wrapped `inner` service on a miss.
pub struct CacheService<S, C, K> {
    /// The wrapped service, called on cache misses (and on expired entries).
    inner: S,
    /// The response cache shared with the originating [`CacheLayer`].
    cache: Arc<Mutex<C>>,
    /// Whether to serve an expired cached response when `inner` fails.
    use_stale: bool,
    /// Maximum number of body bytes to buffer when caching a response.
    limit: usize,
    /// Whether the `X-Invalidate-Cache` request header may evict entries.
    allow_invalidation: bool,
    /// Whether cached responses get an `X-Cache-Age` header.
    add_response_headers: bool,
    /// Strategy used to derive the cache key from a request.
    keyer: Arc<K>,
}
525
526impl<S, C, K> Clone for CacheService<S, C, K>
527where
528    S: Clone,
529{
530    fn clone(&self) -> Self {
531        Self {
532            inner: self.inner.clone(),
533            cache: Arc::clone(&self.cache),
534            use_stale: self.use_stale,
535            limit: self.limit,
536            allow_invalidation: self.allow_invalidation,
537            add_response_headers: self.add_response_headers,
538            keyer: Arc::clone(&self.keyer),
539        }
540    }
541}
542
543impl<S, C, K> Service<Request<Body>> for CacheService<S, C, K>
544where
545    S: Service<Request<Body>, Response = Response, Error = Infallible> + Clone + Send,
546    S::Future: Send + 'static,
547    C: Cached<K::Key, CachedResponse> + CloneCached<K::Key, CachedResponse> + Send + 'static,
548    K: Keyer,
549    K::Key: Debug + Hash + Eq + Clone + Send + 'static,
550{
551    type Response = Response;
552    type Error = Infallible;
553    type Future = Pin<Box<dyn Future<Output = Result<Response, Infallible>> + Send + 'static>>;
554
555    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
556        self.inner.poll_ready(cx)
557    }
558
559    #[instrument(skip(self, request))]
560    fn call(&mut self, request: Request<Body>) -> Self::Future {
561        let mut inner = self.inner.clone();
562        let use_stale = self.use_stale;
563        let allow_invalidation = self.allow_invalidation;
564        let add_response_headers = self.add_response_headers;
565        let limit = self.limit;
566        let cache = Arc::clone(&self.cache);
567        let key = self.keyer.get_key(&request);
568
569        // Check for the custom header "X-Invalidate-Cache" if invalidation is allowed
570        if allow_invalidation && request.headers().contains_key("X-Invalidate-Cache") {
571            // Manually invalidate the cache for this key
572            cache.lock().unwrap().cache_remove(&key);
573            debug!("Cache invalidated manually for key {:?}", key);
574        }
575
576        let inner_fut = inner
577            .call(request)
578            .instrument(tracing::info_span!("inner_service"));
579        let (cached, evicted) = {
580            let mut guard = cache.lock().unwrap();
581            let (cached, evicted) = guard.cache_get_expired(&key);
582            if let (Some(stale), true) = (cached.as_ref(), evicted) {
583                // reinsert stale value immediately so that others don’t schedule their updating
584                debug!("Found stale value in cache, reinsterting and attempting refresh");
585                guard.cache_set(key.clone(), stale.clone());
586            }
587            (cached, evicted)
588        };
589
590        Box::pin(async move {
591            match (cached, evicted) {
592                (Some(value), false) => Ok(value.into_response()),
593                (Some(stale_value), true) => {
594                    let response = inner_fut.await.unwrap();
595                    if response.status().is_success() {
596                        Ok(update_cache(&cache, key, response, limit, add_response_headers).await)
597                    } else if use_stale {
598                        debug!("Returning stale value.");
599                        Ok(stale_value.into_response())
600                    } else {
601                        debug!("Stale value in cache, evicting and returning failed response.");
602                        cache.lock().unwrap().cache_remove(&key);
603                        Ok(response)
604                    }
605                }
606                (None, _) => {
607                    let response = inner_fut.await.unwrap();
608                    if response.status().is_success() {
609                        Ok(update_cache(&cache, key, response, limit, add_response_headers).await)
610                    } else {
611                        Ok(response)
612                    }
613                }
614            }
615        })
616    }
617}
618
/// Buffer the given `response` (up to `limit` body bytes), store it in the cache under `key`,
/// and return it re-assembled.
///
/// If the body exceeds `limit`, nothing is cached and a `500 INTERNAL SERVER ERROR`
/// response is returned instead.
#[instrument(skip(cache, response))]
async fn update_cache<C, K>(
    cache: &Arc<Mutex<C>>,
    key: K,
    response: Response,
    limit: usize,
    add_response_headers: bool,
) -> Response
where
    C: Cached<K, CachedResponse> + CloneCached<K, CachedResponse>,
    K: Debug + Hash + Eq + Clone + Send + 'static,
{
    let (parts, body) = response.into_parts();
    // `to_bytes` fails when the body is longer than `limit` bytes.
    let Ok(body) = body::to_bytes(body, limit).await else {
        return (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("File too big, over {limit} bytes"),
        )
            .into_response();
    };
    let value = CachedResponse {
        parts,
        body,
        // Record the caching time only when `X-Cache-Age` reporting is enabled.
        timestamp: if add_response_headers {
            Some(std::time::Instant::now())
        } else {
            None
        },
    };
    {
        // Scope the lock so it is released before the response is rebuilt.
        cache.lock().unwrap().cache_set(key, value.clone());
    }
    value.into_response()
}
653
654#[cfg(test)]
655mod tests {
656    use super::*;
657    use rand::Rng;
658    use std::sync::atomic::{AtomicIsize, Ordering};
659
660    #[cfg(feature = "axum07")]
661    use axum_07 as axum;
662    #[cfg(feature = "axum08")]
663    use axum_08 as axum;
664
665    use axum::{
666        extract::State,
667        http::{Request, StatusCode},
668        routing::get,
669        Router,
670    };
671
672    use tower::Service;
673
674    #[derive(Clone, Debug)]
675    struct Counter {
676        value: Arc<AtomicIsize>,
677    }
678
679    impl Counter {
680        fn new(init: isize) -> Self {
681            Self {
682                value: AtomicIsize::from(init).into(),
683            }
684        }
685
686        fn increment(&self) {
687            self.value.fetch_add(1, Ordering::Release);
688        }
689
690        fn read(&self) -> isize {
691            self.value.load(Ordering::Acquire)
692        }
693    }
694
695    #[tokio::test]
696    async fn should_use_cached_value() {
697        let handler = |State(cnt): State<Counter>| async move {
698            cnt.increment();
699            StatusCode::OK
700        };
701
702        let counter = Counter::new(0);
703        let cache = CacheLayer::with_lifespan(Duration::from_secs(60)).use_stale_on_failure();
704        let mut router = Router::new()
705            .route("/", get(handler).layer(cache))
706            .with_state(counter.clone());
707
708        for _ in 0..10 {
709            let status = router
710                .call(Request::get("/").body(Body::empty()).unwrap())
711                .await
712                .unwrap()
713                .status();
714            assert!(status.is_success(), "handler should return success");
715        }
716
717        assert_eq!(1, counter.read(), "handler should’ve been called only once");
718    }
719
    /// Non-2xx responses must never be cached: every one of the ten requests
    /// should reach the handler.
    #[tokio::test]
    async fn should_not_cache_unsuccessful_responses() {
        let handler = |State(cnt): State<Counter>| async move {
            cnt.increment();
            // Always fail, with a randomly picked error status.
            let responses = [
                StatusCode::BAD_REQUEST,
                StatusCode::INTERNAL_SERVER_ERROR,
                StatusCode::NOT_FOUND,
            ];
            let mut rng = rand::rng();
            responses[rng.random_range(0..responses.len())]
        };

        let counter = Counter::new(0);
        // `use_stale_on_failure` has no effect here: no success is ever cached.
        let cache = CacheLayer::with_lifespan(Duration::from_secs(60)).use_stale_on_failure();
        let mut router = Router::new()
            .route("/", get(handler).layer(cache))
            .with_state(counter.clone());

        for _ in 0..10 {
            let status = router
                .call(Request::get("/").body(Body::empty()).unwrap())
                .await
                .unwrap()
                .status();
            assert!(!status.is_success(), "handler should never return success");
        }

        assert_eq!(
            10,
            counter.read(),
            "handler should’ve been called for all requests"
        );
    }
754
    /// With `use_stale_on_failure`, an expired successful response keeps being
    /// served while the handler only produces errors.
    #[tokio::test]
    async fn should_use_last_correct_stale_value() {
        let handler = |State(cnt): State<Counter>| async move {
            let prev = cnt.value.fetch_add(1, Ordering::AcqRel);
            let responses = [
                StatusCode::BAD_REQUEST,
                StatusCode::INTERNAL_SERVER_ERROR,
                StatusCode::NOT_FOUND,
            ];
            let mut rng = rand::rng();

            // first response successful, later failed
            if prev == 0 {
                StatusCode::OK
            } else {
                responses[rng.random_range(0..responses.len())]
            }
        };

        let counter = Counter::new(0);
        // Short TTL so the cached entry expires mid-test.
        let cache = CacheLayer::with_lifespan(Duration::from_millis(100)).use_stale_on_failure();
        let mut router = Router::new()
            .route("/", get(handler).layer(cache))
            .with_state(counter);

        // feed the cache
        let status = router
            .call(Request::get("/").body(Body::empty()).unwrap())
            .await
            .unwrap()
            .status();
        assert!(status.is_success(), "handler should return success");

        // wait over 100 ms for cache eviction
        tokio::time::sleep(tokio::time::Duration::from_millis(105)).await;

        for _ in 1..10 {
            let status = router
                .call(Request::get("/").body(Body::empty()).unwrap())
                .await
                .unwrap()
                .status();
            assert!(
                status.is_success(),
                "cache should return stale successful value"
            );
        }
    }
803
    /// Without `use_stale_on_failure`, an expired entry is evicted and the
    /// handler’s failures are passed straight through to the client.
    #[tokio::test]
    async fn should_not_use_stale_values() {
        let handler = |State(cnt): State<Counter>| async move {
            let prev = cnt.value.fetch_add(1, Ordering::AcqRel);
            let responses = [
                StatusCode::BAD_REQUEST,
                StatusCode::INTERNAL_SERVER_ERROR,
                StatusCode::NOT_FOUND,
            ];
            let mut rng = rand::rng();

            // first response successful, later failed
            if prev == 0 {
                StatusCode::OK
            } else {
                responses[rng.random_range(0..responses.len())]
            }
        };

        let counter = Counter::new(0);
        // NOTE: no `use_stale_on_failure` here — the stale value must NOT be reused.
        let cache = CacheLayer::with_lifespan(Duration::from_millis(100));
        let mut router = Router::new()
            .route("/", get(handler).layer(cache))
            .with_state(counter.clone());

        // feed the cache
        let status = router
            .call(Request::get("/").body(Body::empty()).unwrap())
            .await
            .unwrap()
            .status();
        assert!(status.is_success(), "handler should return success");

        // wait over 100 ms for cache eviction
        tokio::time::sleep(tokio::time::Duration::from_millis(105)).await;

        for _ in 1..10 {
            let status = router
                .call(Request::get("/").body(Body::empty()).unwrap())
                .await
                .unwrap()
                .status();
            assert!(
                !status.is_success(),
                "cache should forward unsuccessful values"
            );
        }

        assert_eq!(
            10,
            counter.read(),
            "handler should’ve been called for all requests"
        );
    }
858
    /// When invalidation is not enabled on the layer, the `X-Invalidate-Cache`
    /// header must be ignored and the cached response kept.
    #[tokio::test]
    async fn should_not_invalidate_cache_when_disabled() {
        let handler = |State(cnt): State<Counter>| async move {
            cnt.increment();
            StatusCode::OK
        };

        let counter = Counter::new(0);
        // NOTE: `allow_invalidation()` deliberately NOT called.
        let cache = CacheLayer::with_lifespan(Duration::from_secs(60));
        let mut router = Router::new()
            .route("/", get(handler).layer(cache))
            .with_state(counter.clone());

        // First request to cache the response
        let status = router
            .call(Request::get("/").body(Body::empty()).unwrap())
            .await
            .unwrap()
            .status();
        assert!(status.is_success(), "handler should return success");

        // Second request should return the cached response - no increment
        let status = router
            .call(Request::get("/").body(Body::empty()).unwrap())
            .await
            .unwrap()
            .status();
        assert!(status.is_success(), "handler should return success");

        // Third request with X-Invalidate-Cache header should not invalidate the cache - no increment
        let status = router
            .call(
                Request::get("/")
                    .header("X-Invalidate-Cache", "true")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap()
            .status();
        assert!(status.is_success(), "handler should return success");

        // Fourth request should still return the cached response - no increment
        let status = router
            .call(Request::get("/").body(Body::empty()).unwrap())
            .await
            .unwrap()
            .status();
        assert!(status.is_success(), "handler should return success");

        assert_eq!(1, counter.read(), "handler should’ve been called only once");
    }
911
912    #[tokio::test]
913    async fn should_invalidate_cache_when_enabled() {
914        let handler = |State(cnt): State<Counter>| async move {
915            cnt.increment();
916            StatusCode::OK
917        };
918
919        let counter = Counter::new(0);
920        let cache = CacheLayer::with_lifespan(Duration::from_secs(60)).allow_invalidation();
921        let mut router = Router::new()
922            .route("/", get(handler).layer(cache))
923            .with_state(counter.clone());
924
925        // First request to cache the response
926        let status = router
927            .call(Request::get("/").body(Body::empty()).unwrap())
928            .await
929            .unwrap()
930            .status();
931        assert!(status.is_success(), "handler should return success");
932
933        // Second request should return the cached response - no increment
934        let status = router
935            .call(Request::get("/").body(Body::empty()).unwrap())
936            .await
937            .unwrap()
938            .status();
939        assert!(status.is_success(), "handler should return success");
940
941        // Third request with X-Invalidate-Cache header to invalidate the cache
942        let status = router
943            .call(
944                Request::get("/")
945                    .header("X-Invalidate-Cache", "true")
946                    .body(Body::empty())
947                    .unwrap(),
948            )
949            .await
950            .unwrap()
951            .status();
952        assert!(status.is_success(), "handler should return success");
953
954        // Fourth request to verify that the handler is called again
955        let status = router
956            .call(Request::get("/").body(Body::empty()).unwrap())
957            .await
958            .unwrap()
959            .status();
960        assert!(status.is_success(), "handler should return success");
961
962        assert_eq!(2, counter.read(), "handler should’ve been called twice");
963    }
964
965    #[tokio::test]
966    async fn should_not_include_age_header_when_disabled() {
967        let handler = |State(cnt): State<Counter>| async move {
968            cnt.increment();
969            StatusCode::OK
970        };
971
972        let counter = Counter::new(0);
973        let cache = CacheLayer::with_lifespan(Duration::from_secs(60));
974        let mut router = Router::new()
975            .route("/", get(handler).layer(cache))
976            .with_state(counter.clone());
977
978        // First request to cache the response
979        let response = router
980            .call(Request::get("/").body(Body::empty()).unwrap())
981            .await
982            .unwrap();
983        assert!(
984            response.status().is_success(),
985            "handler should return success"
986        );
987
988        // Second request should return the cached response
989        let response = router
990            .call(Request::get("/").body(Body::empty()).unwrap())
991            .await
992            .unwrap();
993        assert!(
994            response.status().is_success(),
995            "handler should return success"
996        );
997        assert!(
998            response.headers().get("X-Cache-Age").is_none(),
999            "Age header should not be present"
1000        );
1001
1002        assert_eq!(1, counter.read(), "handler should’ve been called only once");
1003    }
1004
1005    #[tokio::test]
1006    async fn should_include_age_header_when_enabled() {
1007        let handler = |State(cnt): State<Counter>| async move {
1008            cnt.increment();
1009            StatusCode::OK
1010        };
1011
1012        let counter = Counter::new(0);
1013        let cache = CacheLayer::with_lifespan(Duration::from_secs(60)).add_response_headers();
1014        let mut router = Router::new()
1015            .route("/", get(handler).layer(cache))
1016            .with_state(counter.clone());
1017
1018        // First request to cache the response
1019        let response = router
1020            .call(Request::get("/").body(Body::empty()).unwrap())
1021            .await
1022            .unwrap();
1023        assert!(
1024            response.status().is_success(),
1025            "handler should return success"
1026        );
1027
1028        // Age should be 0
1029        assert_eq!(
1030            response
1031                .headers()
1032                .get("X-Cache-Age")
1033                .and_then(|v| v.to_str().ok())
1034                .unwrap_or(""),
1035            "0",
1036            "Age header should be present and equal to 0"
1037        );
1038        // wait over 2s to age the cache
1039        tokio::time::sleep(tokio::time::Duration::from_millis(2100)).await;
1040        // Second request should return the cached response
1041        let response = router
1042            .call(Request::get("/").body(Body::empty()).unwrap())
1043            .await
1044            .unwrap();
1045
1046        assert_eq!(
1047            response
1048                .headers()
1049                .get("X-Cache-Age")
1050                .and_then(|v| v.to_str().ok())
1051                .unwrap_or(""),
1052            "2",
1053            "Age header should be present and equal to 2"
1054        );
1055
1056        assert_eq!(1, counter.read(), "handler should’ve been called only once");
1057    }
1058
1059    #[tokio::test]
1060    async fn should_cache_by_custom_keys() {
1061        let handler = |State(cnt): State<Counter>| async move {
1062            cnt.increment();
1063            StatusCode::OK
1064        };
1065
1066        let counter = Counter::new(0);
1067        let keyer = |request: &Request<Body>| {
1068            (
1069                request.method().clone(),
1070                request
1071                    .headers()
1072                    .get(axum::http::header::ACCEPT)
1073                    .and_then(|c| c.to_str().ok())
1074                    .unwrap_or("")
1075                    .to_string(),
1076                request.uri().clone(),
1077            )
1078        };
1079        let cache = CacheLayer::with_lifespan_and_keyer(Duration::from_secs(60), keyer)
1080            .add_response_headers();
1081        let mut router = Router::new()
1082            .route("/", get(handler).layer(cache))
1083            .with_state(counter.clone());
1084
1085        // First request to cache the response
1086        let response = router
1087            .call(Request::get("/").body(Body::empty()).unwrap())
1088            .await
1089            .unwrap();
1090        assert!(
1091            response.status().is_success(),
1092            "handler should return success"
1093        );
1094
1095        // Age should be 0
1096        assert_eq!(
1097            response
1098                .headers()
1099                .get("X-Cache-Age")
1100                .and_then(|v| v.to_str().ok())
1101                .unwrap_or(""),
1102            "0",
1103            "Age header should be present and equal to 0"
1104        );
1105
1106        // wait over 2s to age the cache
1107        tokio::time::sleep(tokio::time::Duration::from_millis(2100)).await;
1108        // Second request should return the cached response
1109        let response = router
1110            .call(Request::get("/").body(Body::empty()).unwrap())
1111            .await
1112            .unwrap();
1113
1114        assert_eq!(
1115            response
1116                .headers()
1117                .get("X-Cache-Age")
1118                .and_then(|v| v.to_str().ok())
1119                .unwrap_or(""),
1120            "2",
1121            "Age header should be present and equal to 2"
1122        );
1123
1124        // Request with a different accept header should return a new response
1125        let response = router
1126            .call(
1127                Request::get("/")
1128                    .header(axum::http::header::ACCEPT, "application/json")
1129                    .body(Body::empty())
1130                    .unwrap(),
1131            )
1132            .await
1133            .unwrap();
1134
1135        assert_eq!(
1136            response
1137                .headers()
1138                .get("X-Cache-Age")
1139                .and_then(|v| v.to_str().ok())
1140                .unwrap_or(""),
1141            "0",
1142            "Age header should be present and equal to 0"
1143        );
1144
1145        // wait over 2s to age the cache
1146        tokio::time::sleep(tokio::time::Duration::from_millis(2100)).await;
1147        // Second request should return the newly cached response
1148        let response = router
1149            .call(
1150                Request::get("/")
1151                    .header(axum::http::header::ACCEPT, "application/json")
1152                    .body(Body::empty())
1153                    .unwrap(),
1154            )
1155            .await
1156            .unwrap();
1157
1158        assert_eq!(
1159            response
1160                .headers()
1161                .get("X-Cache-Age")
1162                .and_then(|v| v.to_str().ok())
1163                .unwrap_or(""),
1164            "2",
1165            "Age header should be present and equal to 2"
1166        );
1167
1168        assert_eq!(
1169            2,
1170            counter.read(),
1171            "handler should’ve been called only twice"
1172        );
1173    }
1174}