1use std::{
293 convert::Infallible,
294 fmt::Debug,
295 future::Future,
296 hash::Hash,
297 pin::Pin,
298 sync::{Arc, Mutex},
299 task::{Context, Poll},
300 time::Duration,
301};
302use tracing_futures::Instrument as _;
303
304#[cfg(feature = "axum07")]
305use axum_07 as axum;
306#[cfg(feature = "axum08")]
307use axum_08 as axum;
308
309use axum::body;
310use axum::{
311 body::{Body, Bytes},
312 http::{response::Parts, Request, StatusCode},
313 response::{IntoResponse, Response},
314};
315
316use cached::{Cached, CloneCached, TimedCache};
317use tower::{Layer, Service};
318use tracing::{debug, instrument};
319
/// Strategy for deriving a cache key from an incoming request.
///
/// The cache is looked up and filled under the key returned by [`Keyer::get_key`], so
/// requests that produce equal keys share one cache entry.
pub trait Keyer {
    /// Type of the keys this keyer produces.
    type Key;

    /// Computes the cache key for `request`.
    fn get_key(&self, request: &Request<Body>) -> Self::Key;
}
327
328impl<K, F> Keyer for F
329where
330 F: Fn(&Request<Body>) -> K + Send + Sync + 'static,
331{
332 type Key = K;
333
334 fn get_key(&self, request: &Request<Body>) -> Self::Key {
335 self(request)
336 }
337}
338
/// [`Keyer`] that identifies a request by its HTTP method and URI (see [`BasicKey`]).
///
/// This is the keyer installed by `CacheLayer::with`.
pub struct BasicKeyer;
344
/// Key produced by [`BasicKeyer`]: the request method paired with the request URI.
pub type BasicKey = (http::Method, http::Uri);
346
347impl Keyer for BasicKeyer {
348 type Key = BasicKey;
349
350 fn get_key(&self, request: &Request<Body>) -> Self::Key {
351 (request.method().clone(), request.uri().clone())
352 }
353}
354
/// A fully buffered response in the form it is stored in the cache.
#[derive(Clone, Debug)]
pub struct CachedResponse {
    /// Head of the response, as split off by `Response::into_parts`.
    parts: Parts,
    /// The complete response body, already read into memory.
    body: Bytes,
    /// Creation time of the entry. Only `Some` when response headers were requested on the
    /// layer; it is then used to compute the `X-Cache-Age` header.
    timestamp: Option<std::time::Instant>,
}
362
363impl IntoResponse for CachedResponse {
364 fn into_response(self) -> Response {
365 let mut response = Response::from_parts(self.parts, Body::from(self.body));
366 if let Some(timestamp) = self.timestamp {
367 let age = timestamp.elapsed().as_secs();
368 response
369 .headers_mut()
370 .insert("X-Cache-Age", age.to_string().parse().unwrap());
371 }
372 response
373 }
374}
375
/// Tower [`Layer`] that wraps a service in a [`CacheService`].
///
/// `C` is the cache storage and `K` the [`Keyer`] used to compute cache keys. The cache and
/// the keyer are held in `Arc`s, so clones of the layer and every service it creates share
/// them.
pub struct CacheLayer<C, K> {
    /// Cache storage, shared behind a mutex.
    cache: Arc<Mutex<C>>,
    /// Serve the expired entry when refreshing it yields an unsuccessful response.
    use_stale: bool,
    /// Maximum number of body bytes read when buffering a response for the cache.
    limit: usize,
    /// Honour the `X-Invalidate-Cache` request header.
    allow_invalidation: bool,
    /// Timestamp new entries so that responses carry an `X-Cache-Age` header.
    add_response_headers: bool,
    /// Computes the cache key of each request.
    keyer: Arc<K>,
}
387
388impl<C, K> Clone for CacheLayer<C, K> {
389 fn clone(&self) -> Self {
390 Self {
391 cache: Arc::clone(&self.cache),
392 use_stale: self.use_stale,
393 limit: self.limit,
394 allow_invalidation: self.allow_invalidation,
395 add_response_headers: self.add_response_headers,
396 keyer: Arc::clone(&self.keyer),
397 }
398 }
399}
400
401impl<C, K> CacheLayer<C, K>
402where
403 C: Cached<K::Key, CachedResponse> + CloneCached<K::Key, CachedResponse>,
404 K: Keyer,
405 K::Key: Debug + Hash + Eq + Clone + Send + 'static,
406{
407 pub fn with_cache_and_keyer(cache: C, keyer: K) -> Self {
409 Self {
410 cache: Arc::new(Mutex::new(cache)),
411 use_stale: false,
412 limit: 128 * 1024 * 1024,
413 allow_invalidation: false,
414 add_response_headers: false,
415 keyer: Arc::new(keyer),
416 }
417 }
418
419 pub fn use_stale_on_failure(self) -> Self {
424 Self {
425 use_stale: true,
426 ..self
427 }
428 }
429
430 pub fn body_limit(self, new_limit: usize) -> Self {
432 Self {
433 limit: new_limit,
434 ..self
435 }
436 }
437
438 pub fn allow_invalidation(self) -> Self {
441 Self {
442 allow_invalidation: true,
443 ..self
444 }
445 }
446
447 pub fn add_response_headers(self) -> Self {
449 Self {
450 add_response_headers: true,
451 ..self
452 }
453 }
454}
455
456impl<C> CacheLayer<C, BasicKeyer>
457where
458 C: Cached<BasicKey, CachedResponse> + CloneCached<BasicKey, CachedResponse>,
459{
460 pub fn with(cache: C) -> Self {
462 Self {
463 cache: Arc::new(Mutex::new(cache)),
464 use_stale: false,
465 limit: 128 * 1024 * 1024,
466 allow_invalidation: false,
467 add_response_headers: false,
468 keyer: Arc::new(BasicKeyer),
469 }
470 }
471}
472
473impl CacheLayer<TimedCache<BasicKey, CachedResponse>, BasicKey> {
474 pub fn with_lifespan(
476 ttl: Duration,
477 ) -> CacheLayer<TimedCache<BasicKey, CachedResponse>, BasicKeyer> {
478 CacheLayer::with(TimedCache::with_lifespan(ttl))
479 }
480}
481
482impl<K> CacheLayer<TimedCache<K::Key, CachedResponse>, K>
483where
484 K: Keyer,
485 K::Key: Debug + Hash + Eq + Clone + Send + 'static,
486{
487 pub fn with_lifespan_and_keyer(
489 ttl: Duration,
490 keyer: K,
491 ) -> CacheLayer<TimedCache<K::Key, CachedResponse>, K> {
492 CacheLayer::with_cache_and_keyer(TimedCache::with_lifespan(ttl), keyer)
493 }
494}
495
496impl<S, C, K> Layer<S> for CacheLayer<C, K>
497where
498 K: Keyer,
499 K::Key: Debug + Hash + Eq + Clone + Send + 'static,
500{
501 type Service = CacheService<S, C, K>;
502
503 fn layer(&self, inner: S) -> Self::Service {
504 Self::Service {
505 inner,
506 cache: Arc::clone(&self.cache),
507 use_stale: self.use_stale,
508 limit: self.limit,
509 allow_invalidation: self.allow_invalidation,
510 add_response_headers: self.add_response_headers,
511 keyer: Arc::clone(&self.keyer),
512 }
513 }
514}
515
/// Service created by [`CacheLayer`]: answers requests from the shared cache and falls back
/// to `inner` when no usable entry exists.
pub struct CacheService<S, C, K> {
    /// The wrapped service.
    inner: S,
    /// Cache storage shared with the layer and its other services.
    cache: Arc<Mutex<C>>,
    /// Serve the expired entry when refreshing it yields an unsuccessful response.
    use_stale: bool,
    /// Maximum number of body bytes read when buffering a response for the cache.
    limit: usize,
    /// Honour the `X-Invalidate-Cache` request header.
    allow_invalidation: bool,
    /// Timestamp new entries so that responses carry an `X-Cache-Age` header.
    add_response_headers: bool,
    /// Computes the cache key of each request.
    keyer: Arc<K>,
}
525
526impl<S, C, K> Clone for CacheService<S, C, K>
527where
528 S: Clone,
529{
530 fn clone(&self) -> Self {
531 Self {
532 inner: self.inner.clone(),
533 cache: Arc::clone(&self.cache),
534 use_stale: self.use_stale,
535 limit: self.limit,
536 allow_invalidation: self.allow_invalidation,
537 add_response_headers: self.add_response_headers,
538 keyer: Arc::clone(&self.keyer),
539 }
540 }
541}
542
543impl<S, C, K> Service<Request<Body>> for CacheService<S, C, K>
544where
545 S: Service<Request<Body>, Response = Response, Error = Infallible> + Clone + Send,
546 S::Future: Send + 'static,
547 C: Cached<K::Key, CachedResponse> + CloneCached<K::Key, CachedResponse> + Send + 'static,
548 K: Keyer,
549 K::Key: Debug + Hash + Eq + Clone + Send + 'static,
550{
551 type Response = Response;
552 type Error = Infallible;
553 type Future = Pin<Box<dyn Future<Output = Result<Response, Infallible>> + Send + 'static>>;
554
555 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
556 self.inner.poll_ready(cx)
557 }
558
559 #[instrument(skip(self, request))]
560 fn call(&mut self, request: Request<Body>) -> Self::Future {
561 let mut inner = self.inner.clone();
562 let use_stale = self.use_stale;
563 let allow_invalidation = self.allow_invalidation;
564 let add_response_headers = self.add_response_headers;
565 let limit = self.limit;
566 let cache = Arc::clone(&self.cache);
567 let key = self.keyer.get_key(&request);
568
569 if allow_invalidation && request.headers().contains_key("X-Invalidate-Cache") {
571 cache.lock().unwrap().cache_remove(&key);
573 debug!("Cache invalidated manually for key {:?}", key);
574 }
575
576 let inner_fut = inner
577 .call(request)
578 .instrument(tracing::info_span!("inner_service"));
579 let (cached, evicted) = {
580 let mut guard = cache.lock().unwrap();
581 let (cached, evicted) = guard.cache_get_expired(&key);
582 if let (Some(stale), true) = (cached.as_ref(), evicted) {
583 debug!("Found stale value in cache, reinsterting and attempting refresh");
585 guard.cache_set(key.clone(), stale.clone());
586 }
587 (cached, evicted)
588 };
589
590 Box::pin(async move {
591 match (cached, evicted) {
592 (Some(value), false) => Ok(value.into_response()),
593 (Some(stale_value), true) => {
594 let response = inner_fut.await.unwrap();
595 if response.status().is_success() {
596 Ok(update_cache(&cache, key, response, limit, add_response_headers).await)
597 } else if use_stale {
598 debug!("Returning stale value.");
599 Ok(stale_value.into_response())
600 } else {
601 debug!("Stale value in cache, evicting and returning failed response.");
602 cache.lock().unwrap().cache_remove(&key);
603 Ok(response)
604 }
605 }
606 (None, _) => {
607 let response = inner_fut.await.unwrap();
608 if response.status().is_success() {
609 Ok(update_cache(&cache, key, response, limit, add_response_headers).await)
610 } else {
611 Ok(response)
612 }
613 }
614 }
615 })
616 }
617}
618
619#[instrument(skip(cache, response))]
620async fn update_cache<C, K>(
621 cache: &Arc<Mutex<C>>,
622 key: K,
623 response: Response,
624 limit: usize,
625 add_response_headers: bool,
626) -> Response
627where
628 C: Cached<K, CachedResponse> + CloneCached<K, CachedResponse>,
629 K: Debug + Hash + Eq + Clone + Send + 'static,
630{
631 let (parts, body) = response.into_parts();
632 let Ok(body) = body::to_bytes(body, limit).await else {
633 return (
634 StatusCode::INTERNAL_SERVER_ERROR,
635 format!("File too big, over {limit} bytes"),
636 )
637 .into_response();
638 };
639 let value = CachedResponse {
640 parts,
641 body,
642 timestamp: if add_response_headers {
643 Some(std::time::Instant::now())
644 } else {
645 None
646 },
647 };
648 {
649 cache.lock().unwrap().cache_set(key, value.clone());
650 }
651 value.into_response()
652}
653
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;
    use std::sync::atomic::{AtomicIsize, Ordering};

    #[cfg(feature = "axum07")]
    use axum_07 as axum;
    #[cfg(feature = "axum08")]
    use axum_08 as axum;

    use axum::{
        extract::State,
        http::{Request, StatusCode},
        routing::get,
        Router,
    };

    use tower::Service;

    /// Shared call counter used as router state.
    #[derive(Clone, Debug)]
    struct Counter {
        value: Arc<AtomicIsize>,
    }

    impl Counter {
        fn new(init: isize) -> Self {
            Self {
                value: Arc::new(AtomicIsize::new(init)),
            }
        }

        fn increment(&self) {
            self.value.fetch_add(1, Ordering::Release);
        }

        fn read(&self) -> isize {
            self.value.load(Ordering::Acquire)
        }
    }

    /// Picks one of three unsuccessful statuses at random.
    fn random_failure() -> StatusCode {
        let failures = [
            StatusCode::BAD_REQUEST,
            StatusCode::INTERNAL_SERVER_ERROR,
            StatusCode::NOT_FOUND,
        ];
        let pick = rand::rng().random_range(0..failures.len());
        failures[pick]
    }

    /// Handler that counts the call and always succeeds.
    async fn always_ok(State(cnt): State<Counter>) -> StatusCode {
        cnt.increment();
        StatusCode::OK
    }

    /// Handler that counts the call and always fails.
    async fn always_failing(State(cnt): State<Counter>) -> StatusCode {
        cnt.increment();
        random_failure()
    }

    /// Handler that succeeds on its very first call and fails on every later one.
    async fn ok_then_failing(State(cnt): State<Counter>) -> StatusCode {
        let prev = cnt.value.fetch_add(1, Ordering::AcqRel);
        if prev == 0 {
            StatusCode::OK
        } else {
            random_failure()
        }
    }

    /// Builds a bodyless `GET /` request.
    fn get_root() -> Request<Body> {
        Request::get("/").body(Body::empty()).unwrap()
    }

    /// Builds a bodyless `GET /` request that asks for cache invalidation.
    fn get_root_invalidating() -> Request<Body> {
        Request::get("/")
            .header("X-Invalidate-Cache", "true")
            .body(Body::empty())
            .unwrap()
    }

    /// Builds a bodyless `GET /` request that accepts JSON.
    fn get_root_json() -> Request<Body> {
        Request::get("/")
            .header(axum::http::header::ACCEPT, "application/json")
            .body(Body::empty())
            .unwrap()
    }

    /// Sends `request` through `router` and returns the response.
    async fn send(router: &mut Router, request: Request<Body>) -> Response {
        router.call(request).await.unwrap()
    }

    /// Reads the `X-Cache-Age` header as text; yields `""` when it is missing or unreadable.
    fn cache_age(response: &Response) -> &str {
        response
            .headers()
            .get("X-Cache-Age")
            .and_then(|v| v.to_str().ok())
            .unwrap_or("")
    }

    #[tokio::test]
    async fn should_use_cached_value() {
        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(Duration::from_secs(60)).use_stale_on_failure();
        let mut router = Router::new()
            .route("/", get(always_ok).layer(cache))
            .with_state(counter.clone());

        for _ in 0..10 {
            let status = send(&mut router, get_root()).await.status();
            assert!(status.is_success(), "handler should return success");
        }

        assert_eq!(1, counter.read(), "handler should’ve been called only once");
    }

    #[tokio::test]
    async fn should_not_cache_unsuccessful_responses() {
        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(Duration::from_secs(60)).use_stale_on_failure();
        let mut router = Router::new()
            .route("/", get(always_failing).layer(cache))
            .with_state(counter.clone());

        for _ in 0..10 {
            let status = send(&mut router, get_root()).await.status();
            assert!(!status.is_success(), "handler should never return success");
        }

        assert_eq!(
            10,
            counter.read(),
            "handler should’ve been called for all requests"
        );
    }

    #[tokio::test]
    async fn should_use_last_correct_stale_value() {
        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(Duration::from_millis(100)).use_stale_on_failure();
        let mut router = Router::new()
            .route("/", get(ok_then_failing).layer(cache))
            .with_state(counter);

        // The first request succeeds and fills the cache.
        let status = send(&mut router, get_root()).await.status();
        assert!(status.is_success(), "handler should return success");

        // Wait past the 100 ms lifespan so the entry is expired.
        tokio::time::sleep(Duration::from_millis(105)).await;

        for _ in 1..10 {
            let status = send(&mut router, get_root()).await.status();
            assert!(
                status.is_success(),
                "cache should return stale successful value"
            );
        }
    }

    #[tokio::test]
    async fn should_not_use_stale_values() {
        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(Duration::from_millis(100));
        let mut router = Router::new()
            .route("/", get(ok_then_failing).layer(cache))
            .with_state(counter.clone());

        // The first request succeeds and fills the cache.
        let status = send(&mut router, get_root()).await.status();
        assert!(status.is_success(), "handler should return success");

        // Wait past the 100 ms lifespan so the entry is expired.
        tokio::time::sleep(Duration::from_millis(105)).await;

        for _ in 1..10 {
            let status = send(&mut router, get_root()).await.status();
            assert!(
                !status.is_success(),
                "cache should forward unsuccessful values"
            );
        }

        assert_eq!(
            10,
            counter.read(),
            "handler should’ve been called for all requests"
        );
    }

    #[tokio::test]
    async fn should_not_invalidate_cache_when_disabled() {
        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(Duration::from_secs(60));
        let mut router = Router::new()
            .route("/", get(always_ok).layer(cache))
            .with_state(counter.clone());

        // Two plain requests, one asking for invalidation, then a plain one again.
        let status = send(&mut router, get_root()).await.status();
        assert!(status.is_success(), "handler should return success");

        let status = send(&mut router, get_root()).await.status();
        assert!(status.is_success(), "handler should return success");

        let status = send(&mut router, get_root_invalidating()).await.status();
        assert!(status.is_success(), "handler should return success");

        let status = send(&mut router, get_root()).await.status();
        assert!(status.is_success(), "handler should return success");

        assert_eq!(1, counter.read(), "handler should’ve been called only once");
    }

    #[tokio::test]
    async fn should_invalidate_cache_when_enabled() {
        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(Duration::from_secs(60)).allow_invalidation();
        let mut router = Router::new()
            .route("/", get(always_ok).layer(cache))
            .with_state(counter.clone());

        // Two plain requests, one asking for invalidation, then a plain one again.
        let status = send(&mut router, get_root()).await.status();
        assert!(status.is_success(), "handler should return success");

        let status = send(&mut router, get_root()).await.status();
        assert!(status.is_success(), "handler should return success");

        let status = send(&mut router, get_root_invalidating()).await.status();
        assert!(status.is_success(), "handler should return success");

        let status = send(&mut router, get_root()).await.status();
        assert!(status.is_success(), "handler should return success");

        assert_eq!(2, counter.read(), "handler should’ve been called twice");
    }

    #[tokio::test]
    async fn should_not_include_age_header_when_disabled() {
        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(Duration::from_secs(60));
        let mut router = Router::new()
            .route("/", get(always_ok).layer(cache))
            .with_state(counter.clone());

        let response = send(&mut router, get_root()).await;
        assert!(
            response.status().is_success(),
            "handler should return success"
        );

        let response = send(&mut router, get_root()).await;
        assert!(
            response.status().is_success(),
            "handler should return success"
        );
        assert!(
            response.headers().get("X-Cache-Age").is_none(),
            "Age header should not be present"
        );

        assert_eq!(1, counter.read(), "handler should’ve been called only once");
    }

    #[tokio::test]
    async fn should_include_age_header_when_enabled() {
        let counter = Counter::new(0);
        let cache = CacheLayer::with_lifespan(Duration::from_secs(60)).add_response_headers();
        let mut router = Router::new()
            .route("/", get(always_ok).layer(cache))
            .with_state(counter.clone());

        let response = send(&mut router, get_root()).await;
        assert!(
            response.status().is_success(),
            "handler should return success"
        );
        assert_eq!(
            cache_age(&response),
            "0",
            "Age header should be present and equal to 0"
        );

        // Let the entry age by a little over two seconds.
        tokio::time::sleep(Duration::from_millis(2100)).await;

        let response = send(&mut router, get_root()).await;
        assert_eq!(
            cache_age(&response),
            "2",
            "Age header should be present and equal to 2"
        );

        assert_eq!(1, counter.read(), "handler should’ve been called only once");
    }

    #[tokio::test]
    async fn should_cache_by_custom_keys() {
        let counter = Counter::new(0);
        // Key on method, `Accept` header and URI, so each `Accept` value gets its own entry.
        let keyer = |request: &Request<Body>| {
            let accept = request
                .headers()
                .get(axum::http::header::ACCEPT)
                .and_then(|c| c.to_str().ok())
                .unwrap_or("")
                .to_string();
            (request.method().clone(), accept, request.uri().clone())
        };
        let cache = CacheLayer::with_lifespan_and_keyer(Duration::from_secs(60), keyer)
            .add_response_headers();
        let mut router = Router::new()
            .route("/", get(always_ok).layer(cache))
            .with_state(counter.clone());

        // Requests without an `Accept` header.
        let response = send(&mut router, get_root()).await;
        assert!(
            response.status().is_success(),
            "handler should return success"
        );
        assert_eq!(
            cache_age(&response),
            "0",
            "Age header should be present and equal to 0"
        );

        tokio::time::sleep(Duration::from_millis(2100)).await;

        let response = send(&mut router, get_root()).await;
        assert_eq!(
            cache_age(&response),
            "2",
            "Age header should be present and equal to 2"
        );

        // Requests with `Accept: application/json`.
        let response = send(&mut router, get_root_json()).await;
        assert_eq!(
            cache_age(&response),
            "0",
            "Age header should be present and equal to 0"
        );

        tokio::time::sleep(Duration::from_millis(2100)).await;

        let response = send(&mut router, get_root_json()).await;
        assert_eq!(
            cache_age(&response),
            "2",
            "Age header should be present and equal to 2"
        );

        assert_eq!(
            2,
            counter.read(),
            "handler should’ve been called only twice"
        );
    }
}