1use std::{
287 convert::Infallible,
288 fmt::Debug,
289 future::Future,
290 hash::Hash,
291 pin::Pin,
292 sync::{Arc, Mutex},
293 task::{Context, Poll},
294};
295use tracing_futures::Instrument as _;
296
297#[cfg(feature = "axum07")]
298use axum_07 as axum;
299#[cfg(feature = "axum08")]
300use axum_08 as axum;
301
302use axum::body;
303use axum::{
304 body::{Body, Bytes},
305 http::{response::Parts, Request, StatusCode},
306 response::{IntoResponse, Response},
307};
308
309use cached::{Cached, CloneCached, TimedCache};
310use tower::{Layer, Service};
311use tracing::{debug, instrument};
312
313pub trait Keyer {
316 type Key;
317
318 fn get_key(&self, request: &Request<Body>) -> Self::Key;
319}
320
321impl<K, F> Keyer for F
322where
323 F: Fn(&Request<Body>) -> K + Send + Sync + 'static,
324{
325 type Key = K;
326
327 fn get_key(&self, request: &Request<Body>) -> Self::Key {
328 self(request)
329 }
330}
331
332pub struct BasicKeyer;
337
338pub type BasicKey = (http::Method, http::Uri);
339
340impl Keyer for BasicKeyer {
341 type Key = BasicKey;
342
343 fn get_key(&self, request: &Request<Body>) -> Self::Key {
344 (request.method().clone(), request.uri().clone())
345 }
346}
347
348#[derive(Clone, Debug)]
350pub struct CachedResponse {
351 parts: Parts,
352 body: Bytes,
353 timestamp: Option<std::time::Instant>,
354}
355
356impl IntoResponse for CachedResponse {
357 fn into_response(self) -> Response {
358 let mut response = Response::from_parts(self.parts, Body::from(self.body));
359 if let Some(timestamp) = self.timestamp {
360 let age = timestamp.elapsed().as_secs();
361 response
362 .headers_mut()
363 .insert("X-Cache-Age", age.to_string().parse().unwrap());
364 }
365 response
366 }
367}
368
/// Tower [`Layer`] that caches successful responses of the service it wraps.
///
/// Construct it with [`CacheLayer::with`], [`CacheLayer::with_lifespan`],
/// [`CacheLayer::with_cache_and_keyer`] or [`CacheLayer::with_lifespan_and_keyer`],
/// then adjust it with the builder methods.
pub struct CacheLayer<C, K> {
    /// Cache store shared by every service produced from this layer.
    cache: Arc<Mutex<C>>,
    /// Serve the stale entry when refreshing an expired one fails.
    use_stale: bool,
    /// Maximum number of body bytes buffered when caching a response.
    limit: usize,
    /// Honour the `X-Invalidate-Cache` request header.
    allow_invalidation: bool,
    /// Timestamp stored entries so responses carry `X-Cache-Age`.
    add_response_headers: bool,
    /// Derives the cache key from each request.
    keyer: Arc<K>,
}
380
381impl<C, K> Clone for CacheLayer<C, K> {
382 fn clone(&self) -> Self {
383 Self {
384 cache: Arc::clone(&self.cache),
385 use_stale: self.use_stale,
386 limit: self.limit,
387 allow_invalidation: self.allow_invalidation,
388 add_response_headers: self.add_response_headers,
389 keyer: Arc::clone(&self.keyer),
390 }
391 }
392}
393
394impl<C, K> CacheLayer<C, K>
395where
396 C: Cached<K::Key, CachedResponse> + CloneCached<K::Key, CachedResponse>,
397 K: Keyer,
398 K::Key: Debug + Hash + Eq + Clone + Send + 'static,
399{
400 pub fn with_cache_and_keyer(cache: C, keyer: K) -> Self {
402 Self {
403 cache: Arc::new(Mutex::new(cache)),
404 use_stale: false,
405 limit: 128 * 1024 * 1024,
406 allow_invalidation: false,
407 add_response_headers: false,
408 keyer: Arc::new(keyer),
409 }
410 }
411
412 pub fn use_stale_on_failure(self) -> Self {
417 Self {
418 use_stale: true,
419 ..self
420 }
421 }
422
423 pub fn body_limit(self, new_limit: usize) -> Self {
425 Self {
426 limit: new_limit,
427 ..self
428 }
429 }
430
431 pub fn allow_invalidation(self) -> Self {
434 Self {
435 allow_invalidation: true,
436 ..self
437 }
438 }
439
440 pub fn add_response_headers(self) -> Self {
442 Self {
443 add_response_headers: true,
444 ..self
445 }
446 }
447}
448
449impl<C> CacheLayer<C, BasicKeyer>
450where
451 C: Cached<BasicKey, CachedResponse> + CloneCached<BasicKey, CachedResponse>,
452{
453 pub fn with(cache: C) -> Self {
455 Self {
456 cache: Arc::new(Mutex::new(cache)),
457 use_stale: false,
458 limit: 128 * 1024 * 1024,
459 allow_invalidation: false,
460 add_response_headers: false,
461 keyer: Arc::new(BasicKeyer),
462 }
463 }
464}
465
466impl CacheLayer<TimedCache<BasicKey, CachedResponse>, BasicKey> {
467 pub fn with_lifespan(
469 ttl_sec: u64,
470 ) -> CacheLayer<TimedCache<BasicKey, CachedResponse>, BasicKeyer> {
471 CacheLayer::with(TimedCache::with_lifespan(ttl_sec))
472 }
473}
474
475impl<K> CacheLayer<TimedCache<K::Key, CachedResponse>, K>
476where
477 K: Keyer,
478 K::Key: Debug + Hash + Eq + Clone + Send + 'static,
479{
480 pub fn with_lifespan_and_keyer(
482 ttl_sec: u64,
483 keyer: K,
484 ) -> CacheLayer<TimedCache<K::Key, CachedResponse>, K> {
485 CacheLayer::with_cache_and_keyer(TimedCache::with_lifespan(ttl_sec), keyer)
486 }
487}
488
489impl<S, C, K> Layer<S> for CacheLayer<C, K>
490where
491 K: Keyer,
492 K::Key: Debug + Hash + Eq + Clone + Send + 'static,
493{
494 type Service = CacheService<S, C, K>;
495
496 fn layer(&self, inner: S) -> Self::Service {
497 Self::Service {
498 inner,
499 cache: Arc::clone(&self.cache),
500 use_stale: self.use_stale,
501 limit: self.limit,
502 allow_invalidation: self.allow_invalidation,
503 add_response_headers: self.add_response_headers,
504 keyer: Arc::clone(&self.keyer),
505 }
506 }
507}
508
/// Service produced by [`CacheLayer`]. It consults the shared cache before
/// forwarding a request to `inner`.
pub struct CacheService<S, C, K> {
    /// Wrapped service that produces fresh responses.
    inner: S,
    /// Cache store shared with the originating layer and its other services.
    cache: Arc<Mutex<C>>,
    /// Serve the stale entry when refreshing an expired one fails.
    use_stale: bool,
    /// Maximum number of body bytes buffered when caching a response.
    limit: usize,
    /// Honour the `X-Invalidate-Cache` request header.
    allow_invalidation: bool,
    /// Timestamp stored entries so responses carry `X-Cache-Age`.
    add_response_headers: bool,
    /// Derives the cache key from each request.
    keyer: Arc<K>,
}
518
519impl<S, C, K> Clone for CacheService<S, C, K>
520where
521 S: Clone,
522{
523 fn clone(&self) -> Self {
524 Self {
525 inner: self.inner.clone(),
526 cache: Arc::clone(&self.cache),
527 use_stale: self.use_stale,
528 limit: self.limit,
529 allow_invalidation: self.allow_invalidation,
530 add_response_headers: self.add_response_headers,
531 keyer: Arc::clone(&self.keyer),
532 }
533 }
534}
535
536impl<S, C, K> Service<Request<Body>> for CacheService<S, C, K>
537where
538 S: Service<Request<Body>, Response = Response, Error = Infallible> + Clone + Send,
539 S::Future: Send + 'static,
540 C: Cached<K::Key, CachedResponse> + CloneCached<K::Key, CachedResponse> + Send + 'static,
541 K: Keyer,
542 K::Key: Debug + Hash + Eq + Clone + Send + 'static,
543{
544 type Response = Response;
545 type Error = Infallible;
546 type Future = Pin<Box<dyn Future<Output = Result<Response, Infallible>> + Send + 'static>>;
547
548 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
549 self.inner.poll_ready(cx)
550 }
551
552 #[instrument(skip(self, request))]
553 fn call(&mut self, request: Request<Body>) -> Self::Future {
554 let mut inner = self.inner.clone();
555 let use_stale = self.use_stale;
556 let allow_invalidation = self.allow_invalidation;
557 let add_response_headers = self.add_response_headers;
558 let limit = self.limit;
559 let cache = Arc::clone(&self.cache);
560 let key = self.keyer.get_key(&request);
561
562 if allow_invalidation && request.headers().contains_key("X-Invalidate-Cache") {
564 cache.lock().unwrap().cache_remove(&key);
566 debug!("Cache invalidated manually for key {:?}", key);
567 }
568
569 let inner_fut = inner
570 .call(request)
571 .instrument(tracing::info_span!("inner_service"));
572 let (cached, evicted) = {
573 let mut guard = cache.lock().unwrap();
574 let (cached, evicted) = guard.cache_get_expired(&key);
575 if let (Some(stale), true) = (cached.as_ref(), evicted) {
576 debug!("Found stale value in cache, reinsterting and attempting refresh");
578 guard.cache_set(key.clone(), stale.clone());
579 }
580 (cached, evicted)
581 };
582
583 Box::pin(async move {
584 match (cached, evicted) {
585 (Some(value), false) => Ok(value.into_response()),
586 (Some(stale_value), true) => {
587 let response = inner_fut.await.unwrap();
588 if response.status().is_success() {
589 Ok(update_cache(&cache, key, response, limit, add_response_headers).await)
590 } else if use_stale {
591 debug!("Returning stale value.");
592 Ok(stale_value.into_response())
593 } else {
594 debug!("Stale value in cache, evicting and returning failed response.");
595 cache.lock().unwrap().cache_remove(&key);
596 Ok(response)
597 }
598 }
599 (None, _) => {
600 let response = inner_fut.await.unwrap();
601 if response.status().is_success() {
602 Ok(update_cache(&cache, key, response, limit, add_response_headers).await)
603 } else {
604 Ok(response)
605 }
606 }
607 }
608 })
609 }
610}
611
612#[instrument(skip(cache, response))]
613async fn update_cache<C, K>(
614 cache: &Arc<Mutex<C>>,
615 key: K,
616 response: Response,
617 limit: usize,
618 add_response_headers: bool,
619) -> Response
620where
621 C: Cached<K, CachedResponse> + CloneCached<K, CachedResponse>,
622 K: Debug + Hash + Eq + Clone + Send + 'static,
623{
624 let (parts, body) = response.into_parts();
625 let Ok(body) = body::to_bytes(body, limit).await else {
626 return (
627 StatusCode::INTERNAL_SERVER_ERROR,
628 format!("File too big, over {limit} bytes"),
629 )
630 .into_response();
631 };
632 let value = CachedResponse {
633 parts,
634 body,
635 timestamp: if add_response_headers {
636 Some(std::time::Instant::now())
637 } else {
638 None
639 },
640 };
641 {
642 cache.lock().unwrap().cache_set(key, value.clone());
643 }
644 value.into_response()
645}
646
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;
    use std::sync::atomic::{AtomicIsize, Ordering};

    #[cfg(feature = "axum07")]
    use axum_07 as axum;
    #[cfg(feature = "axum08")]
    use axum_08 as axum;

    use axum::{
        extract::State,
        http::{Request, StatusCode},
        routing::get,
        Router,
    };

    use tower::Service;

    /// Shared call counter used as router state by the test handlers.
    #[derive(Clone, Debug)]
    struct Counter {
        value: Arc<AtomicIsize>,
    }

    impl Counter {
        fn new(init: isize) -> Self {
            Self {
                value: Arc::new(AtomicIsize::new(init)),
            }
        }

        fn increment(&self) {
            self.value.fetch_add(1, Ordering::Release);
        }

        fn read(&self) -> isize {
            self.value.load(Ordering::Acquire)
        }
    }

    /// Error statuses the failing handlers choose from.
    const FAILURES: [StatusCode; 3] = [
        StatusCode::BAD_REQUEST,
        StatusCode::INTERNAL_SERVER_ERROR,
        StatusCode::NOT_FOUND,
    ];

    /// Picks one of `FAILURES` at random.
    fn random_failure() -> StatusCode {
        let mut rng = rand::rng();
        FAILURES[rng.random_range(0..FAILURES.len())]
    }

    /// Counts the call and always succeeds.
    async fn always_ok(State(cnt): State<Counter>) -> StatusCode {
        cnt.increment();
        StatusCode::OK
    }

    /// Counts the call and always fails with a random error status.
    async fn always_fail(State(cnt): State<Counter>) -> StatusCode {
        cnt.increment();
        random_failure()
    }

    /// Succeeds on the very first call only; every later call fails.
    async fn ok_then_fail(State(cnt): State<Counter>) -> StatusCode {
        let previous_calls = cnt.value.fetch_add(1, Ordering::AcqRel);
        if previous_calls == 0 {
            StatusCode::OK
        } else {
            random_failure()
        }
    }

    /// Builds a `GET /` request with an empty body and the given headers.
    fn get_root(headers: &[(&str, &str)]) -> Request<Body> {
        headers
            .iter()
            .fold(Request::get("/"), |builder, &(name, value)| {
                builder.header(name, value)
            })
            .body(Body::empty())
            .unwrap()
    }

    /// Drives one request through the router and returns its response.
    async fn send(router: &mut Router, req: Request<Body>) -> Response {
        router.call(req).await.unwrap()
    }

    /// Reads the `X-Cache-Age` header as text, or `""` when it is absent.
    fn cache_age(response: &Response) -> &str {
        response
            .headers()
            .get("X-Cache-Age")
            .and_then(|v| v.to_str().ok())
            .unwrap_or("")
    }

    #[tokio::test]
    async fn should_use_cached_value() {
        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(60).use_stale_on_failure();
        let mut router = Router::new()
            .route("/", get(always_ok).layer(cache))
            .with_state(counter.clone());

        for _ in 0..10 {
            let status = send(&mut router, get_root(&[])).await.status();
            assert!(status.is_success(), "handler should return success");
        }

        assert_eq!(1, counter.read(), "handler should’ve been called only once");
    }

    #[tokio::test]
    async fn should_not_cache_unsuccessful_responses() {
        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(60).use_stale_on_failure();
        let mut router = Router::new()
            .route("/", get(always_fail).layer(cache))
            .with_state(counter.clone());

        for _ in 0..10 {
            let status = send(&mut router, get_root(&[])).await.status();
            assert!(!status.is_success(), "handler should never return success");
        }

        assert_eq!(
            10,
            counter.read(),
            "handler should’ve been called for all requests"
        );
    }

    #[tokio::test]
    async fn should_use_last_correct_stale_value() {
        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(1).use_stale_on_failure();
        let mut router = Router::new()
            .route("/", get(ok_then_fail).layer(cache))
            .with_state(counter);

        let status = send(&mut router, get_root(&[])).await.status();
        assert!(status.is_success(), "handler should return success");

        tokio::time::sleep(tokio::time::Duration::from_millis(1050)).await;

        for _ in 1..10 {
            let status = send(&mut router, get_root(&[])).await.status();
            assert!(
                status.is_success(),
                "cache should return stale successful value"
            );
        }
    }

    #[tokio::test]
    async fn should_not_use_stale_values() {
        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(1);
        let mut router = Router::new()
            .route("/", get(ok_then_fail).layer(cache))
            .with_state(counter.clone());

        let status = send(&mut router, get_root(&[])).await.status();
        assert!(status.is_success(), "handler should return success");

        tokio::time::sleep(tokio::time::Duration::from_millis(1050)).await;

        for _ in 1..10 {
            let status = send(&mut router, get_root(&[])).await.status();
            assert!(
                !status.is_success(),
                "cache should forward unsuccessful values"
            );
        }

        assert_eq!(
            10,
            counter.read(),
            "handler should’ve been called for all requests"
        );
    }

    #[tokio::test]
    async fn should_not_invalidate_cache_when_disabled() {
        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(60);
        let mut router = Router::new()
            .route("/", get(always_ok).layer(cache))
            .with_state(counter.clone());

        let requests = [
            get_root(&[]),
            get_root(&[]),
            get_root(&[("X-Invalidate-Cache", "true")]),
            get_root(&[]),
        ];
        for req in requests {
            let status = send(&mut router, req).await.status();
            assert!(status.is_success(), "handler should return success");
        }

        assert_eq!(1, counter.read(), "handler should’ve been called only once");
    }

    #[tokio::test]
    async fn should_invalidate_cache_when_enabled() {
        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(60).allow_invalidation();
        let mut router = Router::new()
            .route("/", get(always_ok).layer(cache))
            .with_state(counter.clone());

        let requests = [
            get_root(&[]),
            get_root(&[]),
            get_root(&[("X-Invalidate-Cache", "true")]),
            get_root(&[]),
        ];
        for req in requests {
            let status = send(&mut router, req).await.status();
            assert!(status.is_success(), "handler should return success");
        }

        assert_eq!(2, counter.read(), "handler should’ve been called twice");
    }

    #[tokio::test]
    async fn should_not_include_age_header_when_disabled() {
        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(60);
        let mut router = Router::new()
            .route("/", get(always_ok).layer(cache))
            .with_state(counter.clone());

        let response = send(&mut router, get_root(&[])).await;
        assert!(
            response.status().is_success(),
            "handler should return success"
        );

        let response = send(&mut router, get_root(&[])).await;
        assert!(
            response.status().is_success(),
            "handler should return success"
        );
        assert!(
            response.headers().get("X-Cache-Age").is_none(),
            "Age header should not be present"
        );

        assert_eq!(1, counter.read(), "handler should’ve been called only once");
    }

    #[tokio::test]
    async fn should_include_age_header_when_enabled() {
        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(60).add_response_headers();
        let mut router = Router::new()
            .route("/", get(always_ok).layer(cache))
            .with_state(counter.clone());

        let response = send(&mut router, get_root(&[])).await;
        assert!(
            response.status().is_success(),
            "handler should return success"
        );
        assert_eq!(
            cache_age(&response),
            "0",
            "Age header should be present and equal to 0"
        );

        tokio::time::sleep(tokio::time::Duration::from_millis(2100)).await;

        let response = send(&mut router, get_root(&[])).await;
        assert_eq!(
            cache_age(&response),
            "2",
            "Age header should be present and equal to 2"
        );

        assert_eq!(1, counter.read(), "handler should’ve been called only once");
    }

    #[tokio::test]
    async fn should_cache_by_custom_keys() {
        let counter = Counter::new(0);
        let keyer = |request: &Request<Body>| {
            let accept = request
                .headers()
                .get(axum::http::header::ACCEPT)
                .and_then(|c| c.to_str().ok())
                .unwrap_or("")
                .to_string();
            (request.method().clone(), accept, request.uri().clone())
        };
        let cache = CacheLayer::with_lifespan_and_keyer(60, keyer).add_response_headers();
        let mut router = Router::new()
            .route("/", get(always_ok).layer(cache))
            .with_state(counter.clone());

        let json: &[(&str, &str)] = &[("accept", "application/json")];

        let response = send(&mut router, get_root(&[])).await;
        assert!(
            response.status().is_success(),
            "handler should return success"
        );
        assert_eq!(
            cache_age(&response),
            "0",
            "Age header should be present and equal to 0"
        );

        tokio::time::sleep(tokio::time::Duration::from_millis(2100)).await;

        let response = send(&mut router, get_root(&[])).await;
        assert_eq!(
            cache_age(&response),
            "2",
            "Age header should be present and equal to 2"
        );

        let response = send(&mut router, get_root(json)).await;
        assert_eq!(
            cache_age(&response),
            "0",
            "Age header should be present and equal to 0"
        );

        tokio::time::sleep(tokio::time::Duration::from_millis(2100)).await;

        let response = send(&mut router, get_root(json)).await;
        assert_eq!(
            cache_age(&response),
            "2",
            "Age header should be present and equal to 2"
        );

        assert_eq!(
            2,
            counter.read(),
            "handler should’ve been called only twice"
        );
    }
}