axum_response_cache/lib.rs
//! This library provides [Axum middleware](`axum#middleware`) that caches HTTP responses to
//! incoming requests based on their HTTP method and URI (path and query string).
3//!
4//! The main struct is [`CacheLayer`]. It can be created with any cache that implements two traits
5//! from the [`cached`] crate: [`cached::Cached`] and [`cached::CloneCached`].
6//!
7//! The *current* version of [`CacheLayer`] is compatible only with services accepting
8//! Axum’s [`Request<Body>`](`http::Request<axum::body::Body>`) and returning
9//! [`axum::response::Response`], thus it is not compatible with non-Axum [`tower`] services.
10//!
11//! It’s possible to configure the layer to re-use an old expired response in case the wrapped
12//! service fails to produce a new successful response.
13//!
14//! Only successful responses are cached (responses with status codes outside of the `[200-299]`
15//! range are passed-through or ignored).
16//!
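//! The cache limits the maximum size of a response’s body (128 MB by default); a successful
//! response whose body exceeds the limit is replaced with a `500 Internal Server Error`.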
18//!
19//! ## Examples
20//!
21//! To cache a response over a specific route, just wrap it in a [`CacheLayer`]:
22//!
23//! ```rust,no_run
24//! # use axum_08 as axum;
25//! use axum::{Router, extract::Path, routing::get};
26//! use axum_response_cache::CacheLayer;
27//!
28//! #[tokio::main]
29//! async fn main() {
30//! let mut router = Router::new()
31//! .route(
32//! "/hello/{name}",
33//! get(|Path(name): Path<String>| async move { format!("Hello, {name}!") })
//!             // this caches a separate response for each `{name}` for 60 seconds.
35//! .layer(CacheLayer::with_lifespan(60)),
36//! );
37//!
38//! let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await.unwrap();
39//! axum::serve(listener, router).await.unwrap();
40//! }
41//! ```
42//!
43//! ### Reusing last successful response
44//!
45//! ```rust
46//! # use std::sync::atomic::{AtomicBool, Ordering};
47//! # use axum_08 as axum;
48//! use axum::{
49//! body::Body,
50//! extract::Path,
51//! http::status::StatusCode,
52//! http::Request,
53//! Router,
54//! routing::get,
55//! };
56//! use axum_response_cache::CacheLayer;
57//! use tower::Service as _;
58//!
59//! // a handler that returns 200 OK only the first time it’s called
60//! async fn handler(Path(name): Path<String>) -> (StatusCode, String) {
61//! static FIRST_RUN: AtomicBool = AtomicBool::new(true);
62//! let first_run = FIRST_RUN.swap(false, Ordering::AcqRel);
63//!
64//! if first_run {
65//! (StatusCode::OK, format!("Hello, {name}"))
66//! } else {
67//! (StatusCode::INTERNAL_SERVER_ERROR, String::from("Error!"))
68//! }
69//! }
70//!
71//! # #[tokio::main]
72//! # async fn main() {
73//! let mut router = Router::new()
74//! .route("/hello/{name}", get(handler))
75//! .layer(CacheLayer::with_lifespan(60).use_stale_on_failure());
76//!
77//! // first request will fire handler and get the response
78//! let status1 = router.call(Request::get("/hello/foo").body(Body::empty()).unwrap())
79//! .await
80//! .unwrap()
81//! .status();
82//! assert_eq!(StatusCode::OK, status1);
83//!
84//! // second request will reuse the last response since the handler now returns ISE
85//! let status2 = router.call(Request::get("/hello/foo").body(Body::empty()).unwrap())
86//! .await
87//! .unwrap()
88//! .status();
89//! assert_eq!(StatusCode::OK, status2);
90//! # }
91//! ```
92//!
93//! ### Serving static files
//! This middleware can be used to cache served files in memory to reduce disk load on the
//! server. To serve files you can use the [`tower_http::services::ServeDir`](https://docs.rs/tower-http/latest/tower_http/services/struct.ServeDir.html) service.
96//! ```rust,ignore
97//! let router = Router::new().nest_service("/", ServeDir::new("static/"));
98//! ```
99//!
100//! ### Limiting the body size
101//!
102//! ```rust
103//! # use axum_08 as axum;
104//! use axum::{
105//! body::Body,
106//! extract::Path,
107//! http::status::StatusCode,
108//! http::Request,
109//! Router,
110//! routing::get,
111//! };
112//! use axum_response_cache::CacheLayer;
113//! use tower::Service as _;
114//!
115//! // returns a short string, well below the limit
116//! async fn ok_handler() -> &'static str {
117//! "ok"
118//! }
119//!
120//! async fn too_long_handler() -> &'static str {
121//! "a response that is well beyond the limit of the cache!"
122//! }
123//!
124//! # #[tokio::main]
125//! # async fn main() {
126//! let mut router = Router::new()
127//! .route("/ok", get(ok_handler))
128//! .route("/too_long", get(too_long_handler))
129//! // limit max cached body to only 16 bytes
130//! .layer(CacheLayer::with_lifespan(60).body_limit(16));
131//!
132//! let status_ok = router.call(Request::get("/ok").body(Body::empty()).unwrap())
133//! .await
134//! .unwrap()
135//! .status();
136//! assert_eq!(StatusCode::OK, status_ok);
137//!
138//! let status_too_long = router.call(Request::get("/too_long").body(Body::empty()).unwrap())
139//! .await
140//! .unwrap()
141//! .status();
142//! assert_eq!(StatusCode::INTERNAL_SERVER_ERROR, status_too_long);
143//! # }
144//! ```
145//! ### Manual Cache Invalidation
146//! This middleware allows manual cache invalidation by setting the `X-Invalidate-Cache` header in the request. This can be useful when you know the underlying data has changed and you want to force a fresh pull of data.
147//!
148//! ```rust
149//! # use axum_08 as axum;
150//! use axum::{
151//! body::Body,
152//! extract::Path,
153//! http::status::StatusCode,
154//! http::Request,
155//! Router,
156//! routing::get,
157//! };
158//! use axum_response_cache::CacheLayer;
159//! use tower::Service as _;
160//!
161//! async fn handler(Path(name): Path<String>) -> (StatusCode, String) {
162//! (StatusCode::OK, format!("Hello, {name}"))
163//! }
164//!
165//! # #[tokio::main]
166//! # async fn main() {
167//! let mut router = Router::new()
168//! .route("/hello/{name}", get(handler))
169//! .layer(CacheLayer::with_lifespan(60).allow_invalidation());
170//!
171//! // first request will fire handler and get the response
172//! let status1 = router.call(Request::get("/hello/foo").body(Body::empty()).unwrap())
173//! .await
174//! .unwrap()
175//! .status();
176//! assert_eq!(StatusCode::OK, status1);
177//!
178//! // second request should return the cached response
179//! let status2 = router.call(Request::get("/hello/foo").body(Body::empty()).unwrap())
180//! .await
181//! .unwrap()
182//! .status();
183//! assert_eq!(StatusCode::OK, status2);
184//!
185//! // third request with X-Invalidate-Cache header to invalidate the cache
186//! let status3 = router.call(
187//! Request::get("/hello/foo")
188//! .header("X-Invalidate-Cache", "true")
189//! .body(Body::empty())
190//! .unwrap(),
191//! )
192//! .await
193//! .unwrap()
194//! .status();
195//! assert_eq!(StatusCode::OK, status3);
196//!
//! // fourth request is served from the cache re-populated by the third request
198//! let status4 = router.call(Request::get("/hello/foo").body(Body::empty()).unwrap())
199//! .await
200//! .unwrap()
201//! .status();
202//! assert_eq!(StatusCode::OK, status4);
203//! # }
204//! ```
//! Cache invalidation can be dangerous: any client able to send the header can bypass the cache
//! and force the server to re-run the handler, including any requests it makes to external
//! services or databases. It is therefore disabled by default, but can be enabled by calling the
//! [`CacheLayer::allow_invalidation`] method.
206//!
207//! ## Using custom cache
208//!
209//! ```rust
210//! # use axum_08 as axum;
211//! use axum::{Router, routing::get};
212//! use axum_response_cache::CacheLayer;
213//! // let’s use TimedSizedCache here
214//! use cached::stores::TimedSizedCache;
215//! # use axum::{body::Body, http::Request};
216//! # use tower::ServiceExt;
217//!
218//! # #[tokio::main]
219//! # async fn main() {
220//! let router: Router = Router::new()
221//! .route("/hello", get(|| async { "Hello, world!" }))
222//! // cache maximum value of 50 responses for one minute
223//! .layer(CacheLayer::with(TimedSizedCache::with_size_and_lifespan(50, 60)));
224//! # // force type inference to resolve the exact type of router
225//! # let _ = router.oneshot(Request::get("/hello").body(Body::empty()).unwrap()).await;
226//! # }
227//! ```
228//!
229//! ## Use cases
//! Caching responses in memory (e.g. using [`cached::TimedCache`]) might be useful when the
231//! underlying service produces the responses by:
232//! 1. doing heavy computation,
233//! 2. requesting external service(s) that might not be fully reliable or performant,
234//! 3. serving static files from disk.
235//!
//! In those cases, if the response to identical requests does not change often over time, it might
//! be desirable to re-use the same responses from memory without re-calculating them – skipping
//! requests to databases and external services, or reads from disk.
239//!
//! ## Using Axum 0.7
241//!
242//! By default, this library uses Axum 0.8. However, you can configure it to use Axum 0.7 by enabling the appropriate feature flag in your `Cargo.toml`.
243//!
244//! To use Axum 0.7, add the following to your `Cargo.toml`:
245//!
246//! ```toml
247//! [dependencies]
248//! axum-response-cache = { version = "0.1.2", features = ["axum07"], default-features = false }
249//! ```
250//!
251//! This will disable the default Axum 0.8 feature and enable the Axum 0.7 feature instead.
252
253use std::{
254 convert::Infallible,
255 future::Future,
256 pin::Pin,
257 sync::{Arc, Mutex},
258 task::{Context, Poll},
259};
260use tracing_futures::Instrument as _;
261
262#[cfg(feature = "axum07")]
263use axum_07 as axum;
264#[cfg(feature = "axum08")]
265use axum_08 as axum;
266
267use axum::body;
268use axum::{
269 body::{Body, Bytes},
270 http::{response::Parts, Request, StatusCode},
271 response::{IntoResponse, Response},
272};
273
274use cached::{Cached, CloneCached, TimedCache};
275use tower::{Layer, Service};
276use tracing::{debug, instrument};
277
278/// The caching key for the responses.
279///
280/// The responses are cached according to the HTTP method [`axum::http::Method`]) and path
281/// ([`axum::http::Uri`]) of the request they responded to.
282type Key = (http::Method, http::Uri);
283
284/// The struct preserving all the headers and body of the cached response.
285#[derive(Clone, Debug)]
286pub struct CachedResponse {
287 parts: Parts,
288 body: Bytes,
289 timestamp: Option<std::time::Instant>,
290}
291
292impl IntoResponse for CachedResponse {
293 fn into_response(self) -> Response {
294 let mut response = Response::from_parts(self.parts, Body::from(self.body));
295 if let Some(timestamp) = self.timestamp {
296 let age = timestamp.elapsed().as_secs();
297 response
298 .headers_mut()
299 .insert("X-Cache-Age", age.to_string().parse().unwrap());
300 }
301 response
302 }
303}
304
/// The main struct of the library: the layer providing caching to the wrapped service.
///
/// Every service produced by a layer (and by any clone of it) shares the same cache.
pub struct CacheLayer<C> {
    /// The response store, shared with every `CacheService` this layer creates.
    cache: Arc<Mutex<C>>,
    /// Serve the last successful response when refreshing an expired entry fails.
    use_stale: bool,
    /// Maximum size, in bytes, of a response body that is buffered and cached.
    limit: usize,
    /// Honour the `X-Invalidate-Cache` request header.
    allow_invalidation: bool,
    /// Record caching timestamps so that an `X-Cache-Age` header is added to responses.
    add_response_headers: bool,
}

// Implemented by hand: `#[derive(Clone)]` would require `C: Clone`, although only the `Arc`
// around the cache is cloned. Dropping that bound lets caches that are not `Clone` be used with
// APIs that require a `Clone` layer (e.g. axum's `layer` methods).
impl<C> Clone for CacheLayer<C> {
    fn clone(&self) -> Self {
        Self {
            cache: Arc::clone(&self.cache),
            use_stale: self.use_stale,
            limit: self.limit,
            allow_invalidation: self.allow_invalidation,
            add_response_headers: self.add_response_headers,
        }
    }
}
314
315impl<C> CacheLayer<C>
316where
317 C: Cached<Key, CachedResponse> + CloneCached<Key, CachedResponse>,
318{
319 /// Create a new cache layer with a given cache and the default body size limit of 128 MB.
320 pub fn with(cache: C) -> Self {
321 Self {
322 cache: Arc::new(Mutex::new(cache)),
323 use_stale: false,
324 limit: 128 * 1024 * 1024,
325 allow_invalidation: false,
326 add_response_headers: false,
327 }
328 }
329
330 /// Switch the layer’s settings to preserve the last successful response even when it’s evicted
331 /// from the cache but the service failed to provide a new successful response (ie. eg. when
332 /// the underlying service responds with `404 NOT FOUND`, the cache will keep providing the last stale `200 OK`
333 /// response produced).
334 pub fn use_stale_on_failure(self) -> Self {
335 Self {
336 use_stale: true,
337 ..self
338 }
339 }
340
341 /// Change the maximum body size limit. If you want unlimited size, use [`usize::MAX`].
342 pub fn body_limit(self, new_limit: usize) -> Self {
343 Self {
344 limit: new_limit,
345 ..self
346 }
347 }
348
349 /// Allow manual cache invalidation by setting the `X-Invalidate-Cache` header in the request.
350 /// This will allow the cache to be invalidated for the given key.
351 pub fn allow_invalidation(self) -> Self {
352 Self {
353 allow_invalidation: true,
354 ..self
355 }
356 }
357
358 /// Allow the response headers to be included in the cached response.
359 pub fn add_response_headers(self) -> Self {
360 Self {
361 add_response_headers: true,
362 ..self
363 }
364 }
365}
366
367impl CacheLayer<TimedCache<Key, CachedResponse>> {
368 /// Create a new cache layer with the desired TTL in seconds
369 pub fn with_lifespan(ttl_sec: u64) -> CacheLayer<TimedCache<Key, CachedResponse>> {
370 CacheLayer::with(TimedCache::with_lifespan(ttl_sec))
371 }
372}
373
374impl<S, C> Layer<S> for CacheLayer<C> {
375 type Service = CacheService<S, C>;
376
377 fn layer(&self, inner: S) -> Self::Service {
378 Self::Service {
379 inner,
380 cache: Arc::clone(&self.cache),
381 use_stale: self.use_stale,
382 limit: self.limit,
383 allow_invalidation: self.allow_invalidation,
384 add_response_headers: self.add_response_headers,
385 }
386 }
387}
388
/// The service produced by [`CacheLayer`]: answers requests from the shared cache when it can
/// and forwards them to the wrapped `inner` service otherwise.
pub struct CacheService<S, C> {
    /// The wrapped service.
    inner: S,
    /// The response store shared with the layer and all sibling services.
    cache: Arc<Mutex<C>>,
    /// Serve the last successful response when refreshing an expired entry fails.
    use_stale: bool,
    /// Maximum size, in bytes, of a response body that is buffered and cached.
    limit: usize,
    /// Honour the `X-Invalidate-Cache` request header.
    allow_invalidation: bool,
    /// Record caching timestamps so that an `X-Cache-Age` header is added to responses.
    add_response_headers: bool,
}

// Implemented by hand so that only `S: Clone` is required: `#[derive(Clone)]` would also demand
// `C: Clone`, even though the cache itself is shared through an `Arc`.
impl<S: Clone, C> Clone for CacheService<S, C> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            cache: Arc::clone(&self.cache),
            use_stale: self.use_stale,
            limit: self.limit,
            allow_invalidation: self.allow_invalidation,
            add_response_headers: self.add_response_headers,
        }
    }
}
398
399impl<S, C> Service<Request<Body>> for CacheService<S, C>
400where
401 S: Service<Request<Body>, Response = Response, Error = Infallible> + Clone + Send,
402 S::Future: Send + 'static,
403 C: Cached<Key, CachedResponse> + CloneCached<Key, CachedResponse> + Send + 'static,
404{
405 type Response = Response;
406 type Error = Infallible;
407 type Future = Pin<Box<dyn Future<Output = Result<Response, Infallible>> + Send + 'static>>;
408
409 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
410 self.inner.poll_ready(cx)
411 }
412
413 #[instrument(skip(self, request))]
414 fn call(&mut self, request: Request<Body>) -> Self::Future {
415 let mut inner = self.inner.clone();
416 let use_stale = self.use_stale;
417 let allow_invalidation = self.allow_invalidation;
418 let add_response_headers = self.add_response_headers;
419 let limit = self.limit;
420 let cache = Arc::clone(&self.cache);
421 let key = (request.method().clone(), request.uri().clone());
422
423 // Check for the custom header "X-Invalidate-Cache" if invalidation is allowed
424 if allow_invalidation && request.headers().contains_key("X-Invalidate-Cache") {
425 // Manually invalidate the cache for this key
426 cache.lock().unwrap().cache_remove(&key);
427 debug!("Cache invalidated manually for key {:?}", key);
428 }
429
430 let inner_fut = inner
431 .call(request)
432 .instrument(tracing::info_span!("inner_service"));
433 let (cached, evicted) = {
434 let mut guard = cache.lock().unwrap();
435 let (cached, evicted) = guard.cache_get_expired(&key);
436 if let (Some(stale), true) = (cached.as_ref(), evicted) {
437 // reinsert stale value immediately so that others don’t schedule their updating
438 debug!("Found stale value in cache, reinsterting and attempting refresh");
439 guard.cache_set(key.clone(), stale.clone());
440 }
441 (cached, evicted)
442 };
443
444 Box::pin(async move {
445 match (cached, evicted) {
446 (Some(value), false) => Ok(value.into_response()),
447 (Some(stale_value), true) => {
448 let response = inner_fut.await.unwrap();
449 if response.status().is_success() {
450 Ok(update_cache(&cache, key, response, limit, add_response_headers).await)
451 } else if use_stale {
452 debug!("Returning stale value.");
453 Ok(stale_value.into_response())
454 } else {
455 debug!("Stale value in cache, evicting and returning failed response.");
456 cache.lock().unwrap().cache_remove(&key);
457 Ok(response)
458 }
459 }
460 (None, _) => {
461 let response = inner_fut.await.unwrap();
462 if response.status().is_success() {
463 Ok(update_cache(&cache, key, response, limit, add_response_headers).await)
464 } else {
465 Ok(response)
466 }
467 }
468 }
469 })
470 }
471}
472
473#[instrument(skip(cache, response))]
474async fn update_cache<C: Cached<Key, CachedResponse> + CloneCached<Key, CachedResponse>>(
475 cache: &Arc<Mutex<C>>,
476 key: Key,
477 response: Response,
478 limit: usize,
479 add_response_headers: bool,
480) -> Response {
481 let (parts, body) = response.into_parts();
482 let Ok(body) = body::to_bytes(body, limit).await else {
483 return (
484 StatusCode::INTERNAL_SERVER_ERROR,
485 format!("File too big, over {limit} bytes"),
486 )
487 .into_response();
488 };
489 let value = CachedResponse {
490 parts,
491 body,
492 timestamp: if add_response_headers {
493 Some(std::time::Instant::now())
494 } else {
495 None
496 },
497 };
498 {
499 cache.lock().unwrap().cache_set(key, value.clone());
500 }
501 value.into_response()
502}
503
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicIsize, Ordering};

    #[cfg(feature = "axum07")]
    use axum_07 as axum;
    #[cfg(feature = "axum08")]
    use axum_08 as axum;

    use axum::{
        extract::State,
        http::{Request, StatusCode},
        routing::get,
        Router,
    };

    use tower::Service;

    /// Unsuccessful statuses that the failing test handlers rotate through.
    const FAILURE_STATUSES: [StatusCode; 3] = [
        StatusCode::BAD_REQUEST,
        StatusCode::INTERNAL_SERVER_ERROR,
        StatusCode::NOT_FOUND,
    ];

    /// Picks a failure status from the handler's call count. Unlike a random pick, this keeps
    /// every test run reproducible while still exercising several error codes.
    fn failure_status(call: isize) -> StatusCode {
        FAILURE_STATUSES[call.unsigned_abs() % FAILURE_STATUSES.len()]
    }

    /// Builds a `GET` request for `uri` with an empty body.
    fn get_request(uri: &str) -> Request<Body> {
        Request::get(uri).body(Body::empty()).unwrap()
    }

    /// Sends `request` through `router` and returns its (infallible) response.
    async fn send(router: &mut Router, request: Request<Body>) -> Response {
        router.call(request).await.unwrap()
    }

    /// Call counter shared with the handlers through axum's `State`.
    #[derive(Clone, Debug)]
    struct Counter {
        value: Arc<AtomicIsize>,
    }

    impl Counter {
        fn new(init: isize) -> Self {
            Self {
                value: AtomicIsize::from(init).into(),
            }
        }

        fn increment(&self) {
            self.value.fetch_add(1, Ordering::Release);
        }

        fn read(&self) -> isize {
            self.value.load(Ordering::Acquire)
        }
    }

    #[tokio::test]
    async fn should_use_cached_value() {
        let handler = |State(cnt): State<Counter>| async move {
            cnt.increment();
            StatusCode::OK
        };

        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(60).use_stale_on_failure();
        let mut router = Router::new()
            .route("/", get(handler).layer(cache))
            .with_state(counter.clone());

        for _ in 0..10 {
            let status = send(&mut router, get_request("/")).await.status();
            assert!(status.is_success(), "handler should return success");
        }

        assert_eq!(1, counter.read(), "handler should’ve been called only once");
    }

    #[tokio::test]
    async fn should_not_cache_unsuccessful_responses() {
        let handler = |State(cnt): State<Counter>| async move {
            let call = cnt.value.fetch_add(1, Ordering::AcqRel);
            failure_status(call)
        };

        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(60).use_stale_on_failure();
        let mut router = Router::new()
            .route("/", get(handler).layer(cache))
            .with_state(counter.clone());

        for _ in 0..10 {
            let status = send(&mut router, get_request("/")).await.status();
            assert!(!status.is_success(), "handler should never return success");
        }

        assert_eq!(
            10,
            counter.read(),
            "handler should’ve been called for all requests"
        );
    }

    #[tokio::test]
    async fn should_use_last_correct_stale_value() {
        let handler = |State(cnt): State<Counter>| async move {
            let prev = cnt.value.fetch_add(1, Ordering::AcqRel);
            // first response successful, later ones fail
            if prev == 0 {
                StatusCode::OK
            } else {
                failure_status(prev)
            }
        };

        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(1).use_stale_on_failure();
        let mut router = Router::new()
            .route("/", get(handler).layer(cache))
            .with_state(counter);

        // feed the cache
        let status = send(&mut router, get_request("/")).await.status();
        assert!(status.is_success(), "handler should return success");

        // wait over 1s for cache eviction
        tokio::time::sleep(tokio::time::Duration::from_millis(1050)).await;

        for _ in 1..10 {
            let status = send(&mut router, get_request("/")).await.status();
            assert!(
                status.is_success(),
                "cache should return stale successful value"
            );
        }
    }

    #[tokio::test]
    async fn should_not_use_stale_values() {
        let handler = |State(cnt): State<Counter>| async move {
            let prev = cnt.value.fetch_add(1, Ordering::AcqRel);
            // first response successful, later ones fail
            if prev == 0 {
                StatusCode::OK
            } else {
                failure_status(prev)
            }
        };

        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(1);
        let mut router = Router::new()
            .route("/", get(handler).layer(cache))
            .with_state(counter.clone());

        // feed the cache
        let status = send(&mut router, get_request("/")).await.status();
        assert!(status.is_success(), "handler should return success");

        // wait over 1s for cache eviction
        tokio::time::sleep(tokio::time::Duration::from_millis(1050)).await;

        for _ in 1..10 {
            let status = send(&mut router, get_request("/")).await.status();
            assert!(
                !status.is_success(),
                "cache should forward unsuccessful values"
            );
        }

        assert_eq!(
            10,
            counter.read(),
            "handler should’ve been called for all requests"
        );
    }

    #[tokio::test]
    async fn should_not_invalidate_cache_when_disabled() {
        let handler = |State(cnt): State<Counter>| async move {
            cnt.increment();
            StatusCode::OK
        };

        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(60);
        let mut router = Router::new()
            .route("/", get(handler).layer(cache))
            .with_state(counter.clone());

        // First request to cache the response
        let status = send(&mut router, get_request("/")).await.status();
        assert!(status.is_success(), "handler should return success");

        // Second request should return the cached response - no increment
        let status = send(&mut router, get_request("/")).await.status();
        assert!(status.is_success(), "handler should return success");

        // Third request with X-Invalidate-Cache header should not invalidate the cache - no increment
        let invalidate = Request::get("/")
            .header("X-Invalidate-Cache", "true")
            .body(Body::empty())
            .unwrap();
        let status = send(&mut router, invalidate).await.status();
        assert!(status.is_success(), "handler should return success");

        // Fourth request should still return the cached response - no increment
        let status = send(&mut router, get_request("/")).await.status();
        assert!(status.is_success(), "handler should return success");

        assert_eq!(1, counter.read(), "handler should’ve been called only once");
    }

    #[tokio::test]
    async fn should_invalidate_cache_when_enabled() {
        let handler = |State(cnt): State<Counter>| async move {
            cnt.increment();
            StatusCode::OK
        };

        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(60).allow_invalidation();
        let mut router = Router::new()
            .route("/", get(handler).layer(cache))
            .with_state(counter.clone());

        // First request to cache the response
        let status = send(&mut router, get_request("/")).await.status();
        assert!(status.is_success(), "handler should return success");

        // Second request should return the cached response - no increment
        let status = send(&mut router, get_request("/")).await.status();
        assert!(status.is_success(), "handler should return success");

        // Third request with X-Invalidate-Cache header invalidates the cache - handler runs again
        let invalidate = Request::get("/")
            .header("X-Invalidate-Cache", "true")
            .body(Body::empty())
            .unwrap();
        let status = send(&mut router, invalidate).await.status();
        assert!(status.is_success(), "handler should return success");

        // Fourth request is served from the cache re-populated by the third one - no increment
        let status = send(&mut router, get_request("/")).await.status();
        assert!(status.is_success(), "handler should return success");

        assert_eq!(2, counter.read(), "handler should’ve been called twice");
    }

    #[tokio::test]
    async fn should_not_include_age_header_when_disabled() {
        let handler = |State(cnt): State<Counter>| async move {
            cnt.increment();
            StatusCode::OK
        };

        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(60);
        let mut router = Router::new()
            .route("/", get(handler).layer(cache))
            .with_state(counter.clone());

        // First request to cache the response
        let response = send(&mut router, get_request("/")).await;
        assert!(
            response.status().is_success(),
            "handler should return success"
        );

        // Second request should return the cached response
        let response = send(&mut router, get_request("/")).await;
        assert!(
            response.status().is_success(),
            "handler should return success"
        );
        assert!(
            response.headers().get("X-Cache-Age").is_none(),
            "Age header should not be present"
        );

        assert_eq!(1, counter.read(), "handler should’ve been called only once");
    }

    #[tokio::test]
    async fn should_include_age_header_when_enabled() {
        let handler = |State(cnt): State<Counter>| async move {
            cnt.increment();
            StatusCode::OK
        };

        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(60).add_response_headers();
        let mut router = Router::new()
            .route("/", get(handler).layer(cache))
            .with_state(counter.clone());

        // First request to cache the response
        let response = send(&mut router, get_request("/")).await;
        assert!(
            response.status().is_success(),
            "handler should return success"
        );

        // Age should be 0
        assert_eq!(
            response
                .headers()
                .get("X-Cache-Age")
                .and_then(|v| v.to_str().ok())
                .unwrap_or(""),
            "0",
            "Age header should be present and equal to 0"
        );

        // wait over 2s to age the cache
        tokio::time::sleep(tokio::time::Duration::from_millis(2100)).await;

        // Second request should return the cached response
        let response = send(&mut router, get_request("/")).await;
        assert_eq!(
            response
                .headers()
                .get("X-Cache-Age")
                .and_then(|v| v.to_str().ok())
                .unwrap_or(""),
            "2",
            "Age header should be present and equal to 2"
        );

        assert_eq!(1, counter.read(), "handler should’ve been called only once");
    }
}