1#![warn(missing_docs)]
2use actix_web::{
48 body::{BodySize, BoxBody, EitherBody, MessageBody},
49 dev::{forward_ready, Payload, Service, ServiceRequest, ServiceResponse, Transform},
50 http::header::HeaderMap,
51 web::{Bytes, BytesMut},
52 Error, HttpMessage,
53};
54use futures::{
55 future::{ready, LocalBoxFuture, Ready},
56 StreamExt,
57};
58use pin_project_lite::pin_project;
59use redis::{aio::MultiplexedConnection, AsyncCommands};
60use serde::{Deserialize, Serialize};
61use sha2::{Digest, Sha256};
62use std::{future::Future, marker::PhantomData, pin::Pin, rc::Rc};
63use std::{
64 sync::Arc,
65 task::{Context, Poll},
66};
67
68pub struct CacheContext<'a> {
74 pub method: &'a str,
76 pub path: &'a str,
78 pub query_string: &'a str,
80 pub headers: &'a HeaderMap,
82 pub body: &'a serde_json::Value,
84}
85
86type CachePredicate = Arc<dyn Fn(&CacheContext) -> bool + Send + Sync>;
91
92type CacheKeyFn = Arc<dyn Fn(&CacheContext) -> String + Send + Sync>;
97
98pub struct RedisCacheMiddleware {
103 redis_conn: Option<MultiplexedConnection>,
104 redis_url: String,
105 ttl: u64,
106 max_cacheable_size: usize,
107 cache_prefix: String,
108 cache_if: CachePredicate,
109 cache_key_fn: Option<CacheKeyFn>,
110}
111
112pub struct RedisCacheMiddlewareBuilder {
117 redis_url: String,
118 ttl: u64,
119 max_cacheable_size: usize,
120 cache_prefix: String,
121 cache_if: CachePredicate,
122 cache_key_fn: Option<CacheKeyFn>,
123}
124
125impl RedisCacheMiddlewareBuilder {
126 pub fn new(redis_url: impl Into<String>) -> Self {
140 Self {
141 redis_url: redis_url.into(),
142 ttl: 3600, max_cacheable_size: 1024 * 1024, cache_prefix: "cache:".to_string(),
145 cache_if: Arc::new(|_| true), cache_key_fn: None, }
148 }
149
150 pub fn ttl(mut self, seconds: u64) -> Self {
160 self.ttl = seconds;
161 self
162 }
163
164 pub fn max_cacheable_size(mut self, bytes: usize) -> Self {
176 self.max_cacheable_size = bytes;
177 self
178 }
179
180 pub fn cache_prefix(mut self, prefix: impl Into<String>) -> Self {
190 self.cache_prefix = prefix.into();
191 self
192 }
193
194 pub fn cache_if<F>(mut self, predicate: F) -> Self
222 where
223 F: Fn(&CacheContext) -> bool + Send + Sync + 'static,
224 {
225 self.cache_if = Arc::new(predicate);
226 self
227 }
228
229 pub fn with_cache_key<F>(mut self, key_fn: F) -> Self
249 where
250 F: Fn(&CacheContext) -> String + Send + Sync + 'static,
251 {
252 self.cache_key_fn = Some(Arc::new(key_fn));
253 self
254 }
255
256 pub fn build(self) -> RedisCacheMiddleware {
262 RedisCacheMiddleware {
263 redis_conn: None,
264 redis_url: self.redis_url,
265 ttl: self.ttl,
266 max_cacheable_size: self.max_cacheable_size,
267 cache_prefix: self.cache_prefix,
268 cache_if: self.cache_if,
269 cache_key_fn: self.cache_key_fn,
270 }
271 }
272}
273
274impl RedisCacheMiddleware {
275 pub fn new(redis_url: &str) -> Self {
287 RedisCacheMiddlewareBuilder::new(redis_url).build()
288 }
289}
290
291pub struct RedisCacheMiddlewareService<S> {
296 service: Rc<S>,
297 redis_conn: Option<MultiplexedConnection>,
298 redis_url: String,
299 ttl: u64,
300 max_cacheable_size: usize,
301 cache_prefix: String,
302 cache_if: CachePredicate,
303 cache_key_fn: Option<CacheKeyFn>,
304}
305
306#[derive(Serialize, Deserialize)]
307struct CachedResponse {
308 status: u16,
309 headers: Vec<(String, String)>,
310 body: Vec<u8>,
311}
312
313impl<S, B> Transform<S, ServiceRequest> for RedisCacheMiddleware
314where
315 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
316 S::Future: 'static,
317 B: 'static + MessageBody,
318{
319 type Response = ServiceResponse<EitherBody<B, BoxBody>>;
320 type Error = Error;
321 type Transform = RedisCacheMiddlewareService<S>;
322 type InitError = ();
323 type Future = Ready<Result<Self::Transform, Self::InitError>>;
324
325 fn new_transform(&self, service: S) -> Self::Future {
327 ready(Ok(RedisCacheMiddlewareService {
328 service: Rc::new(service),
329 redis_conn: self.redis_conn.clone(),
330 redis_url: self.redis_url.clone(),
331 ttl: self.ttl,
332 max_cacheable_size: self.max_cacheable_size,
333 cache_prefix: self.cache_prefix.clone(),
334 cache_if: self.cache_if.clone(),
335 cache_key_fn: self.cache_key_fn.clone(),
336 }))
337 }
338}
339
340pin_project! {
342 struct CacheResponseFuture<S, B>
343 where
344 B: MessageBody,
345 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
346 {
347 #[pin]
348 fut: S::Future,
349 should_cache: bool,
350 cache_key: String,
351 redis_conn: Option<MultiplexedConnection>,
352 redis_url: String,
353 ttl: u64,
354 max_cacheable_size: usize,
355 _marker: PhantomData<B>,
356 }
357}
358
359impl<S, B> Future for CacheResponseFuture<S, B>
361where
362 B: MessageBody + 'static,
363 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
364{
365 type Output = Result<ServiceResponse<EitherBody<B, BoxBody>>, Error>;
366
367 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
368 let this = self.project();
369
370 let res = futures_util::ready!(this.fut.poll(cx))?;
371
372 let status = res.status();
373 let headers = res.headers().clone();
374 let should_cache = *this.should_cache && status.is_success();
375
376 if !should_cache {
377 return Poll::Ready(Ok(res.map_body(|_, b| EitherBody::left(b))));
378 }
379
380 let cache_key = this.cache_key.clone();
381 let redis_url = this.redis_url.clone();
382 let redis_conn = this.redis_conn.clone();
383 let ttl = *this.ttl;
384 let max_size = *this.max_cacheable_size;
385
386 let res = res.map_body(move |_, body| {
387 let filtered_headers = headers
388 .iter()
389 .filter(|(name, _)| {
390 !["connection", "transfer-encoding", "content-length"]
391 .contains(&name.as_str().to_lowercase().as_str())
392 })
393 .map(|(name, value)| {
394 (
395 name.to_string(),
396 value.to_str().unwrap_or_default().to_string(),
397 )
398 })
399 .collect::<Vec<_>>();
400
401 EitherBody::right(BoxBody::new(CacheableBody {
402 body: body.boxed(),
403 status: status.as_u16(),
404 headers: filtered_headers,
405 body_accum: BytesMut::new(),
406 cache_key,
407 redis_conn,
408 redis_url,
409 ttl,
410 max_size,
411 }))
412 });
413
414 Poll::Ready(Ok(res))
415 }
416}
417
418pin_project! {
420 struct CacheableBody {
421 #[pin]
422 body: BoxBody,
423 status: u16,
424 headers: Vec<(String, String)>,
425 body_accum: BytesMut,
426 cache_key: String,
427 redis_conn: Option<MultiplexedConnection>,
428 redis_url: String,
429 ttl: u64,
430 max_size: usize,
431 }
432
433 impl PinnedDrop for CacheableBody {
434 fn drop(this: Pin<&mut Self>) {
435 let this = this.project();
436
437 let body_bytes = this.body_accum.clone().freeze();
438 let status = *this.status;
439 let headers = this.headers.clone();
440 let cache_key = this.cache_key.clone();
441 let mut redis_conn = this.redis_conn.take();
442 let redis_url = this.redis_url.clone();
443 let ttl = *this.ttl;
444 let max_size = *this.max_size;
445
446 if !body_bytes.is_empty() && body_bytes.len() <= max_size {
447 actix_web::rt::spawn(async move {
448 let cached_response = CachedResponse {
449 status,
450 headers,
451 body: body_bytes.to_vec(),
452 };
453
454 if let Ok(serialized) = rmp_serde::to_vec(&cached_response) {
455 if redis_conn.is_none() {
456 let client = redis::Client::open(redis_url.as_str())
457 .expect("Failed to connect to Redis");
458
459 let conn = client
460 .get_multiplexed_async_connection()
461 .await
462 .expect("Failed to get Redis connection");
463
464 redis_conn = Some(conn);
465 }
466
467 if let Some(conn) = redis_conn.as_mut() {
468 let _: Result<(), redis::RedisError> =
469 conn.set_ex(cache_key, serialized, ttl).await;
470 }
471 }
472 });
473 }
474 }
475 }
476}
477
478impl MessageBody for CacheableBody {
479 type Error = <BoxBody as MessageBody>::Error;
480
481 fn size(&self) -> BodySize {
482 self.body.size()
483 }
484
485 fn poll_next(
486 self: Pin<&mut Self>,
487 cx: &mut Context<'_>,
488 ) -> Poll<Option<Result<Bytes, Self::Error>>> {
489 let this = self.project();
490
491 match this.body.poll_next(cx) {
493 Poll::Ready(Some(Ok(chunk))) => {
494 this.body_accum.extend_from_slice(&chunk);
495 Poll::Ready(Some(Ok(chunk)))
496 }
497 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
498 Poll::Ready(None) => Poll::Ready(None),
499 Poll::Pending => Poll::Pending,
500 }
501 }
502}
503
504impl<S, B> Service<ServiceRequest> for RedisCacheMiddlewareService<S>
505where
506 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
507 S::Future: 'static,
508 B: MessageBody + 'static,
509{
510 type Response = ServiceResponse<EitherBody<B, BoxBody>>;
511 type Error = Error;
512 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
513
514 forward_ready!(service);
515
516 fn call(&self, mut req: ServiceRequest) -> Self::Future {
517 if let Some(cache_control) = req.headers().get("Cache-Control") {
519 if let Ok(cache_control_str) = cache_control.to_str() {
520 if cache_control_str.contains("no-cache") || cache_control_str.contains("no-store")
521 {
522 let fut = self.service.call(req);
523 return Box::pin(async move {
524 let res = fut.await?;
525 Ok(res.map_body(|_, b| EitherBody::left(b)))
526 });
527 }
528 }
529 }
530
531 let redis_url = self.redis_url.clone();
532 let mut redis_conn = self.redis_conn.clone();
533 let expiration = self.ttl;
534 let max_cacheable_size = self.max_cacheable_size;
535 let cache_prefix = self.cache_prefix.clone();
536 let service = Rc::clone(&self.service);
537 let cache_if = self.cache_if.clone();
538 let cache_key_fn = self.cache_key_fn.clone();
539
540 Box::pin(async move {
541 let body_bytes = req
542 .take_payload()
543 .fold(BytesMut::new(), move |mut body, chunk| async {
544 if let Ok(chunk) = chunk {
545 body.extend_from_slice(&chunk);
546 }
547 body
548 })
549 .await;
550
551 let cache_ctx = CacheContext {
552 method: req.method().as_str(),
553 path: req.path(),
554 query_string: req.query_string(),
555 headers: req.headers(),
556 body: &serde_json::from_slice(&body_bytes).unwrap_or(serde_json::Value::Null),
557 };
558
559 let should_cache = cache_if(&cache_ctx);
560
561 let base_key = if let Some(key_fn) = &cache_key_fn {
563 key_fn(&cache_ctx)
564 } else if body_bytes.is_empty() {
565 format!(
566 "{}:{}:{}",
567 req.method().as_str(),
568 req.path(),
569 req.query_string()
570 )
571 } else {
572 let body_hash = hex::encode(Sha256::digest(&body_bytes));
573 format!(
574 "{}:{}:{}:{}",
575 req.method().as_str(),
576 req.path(),
577 req.query_string(),
578 body_hash
579 )
580 };
581
582 req.set_payload(Payload::from(Bytes::from(body_bytes.clone())));
583
584 let hashed_key = hex::encode(Sha256::digest(base_key.as_bytes()));
585 let cache_key = format!("{}{}", cache_prefix, hashed_key);
586
587 let cached_result: Option<Vec<u8>> = if should_cache {
588 if redis_conn.is_none() {
589 let client = redis::Client::open(redis_url.as_str())
590 .expect("Failed to connect to Redis");
591
592 let conn = client
593 .get_multiplexed_async_connection()
594 .await
595 .expect("Failed to get Redis connection");
596
597 redis_conn = Some(conn);
598 }
599
600 let conn = redis_conn.as_mut().unwrap();
601 conn.get(&cache_key).await.unwrap_or(None)
602 } else {
603 None
604 };
605
606 if let Some(cached_data) = cached_result {
607 log::debug!("Cache hit for {}", cache_key);
608
609 match rmp_serde::from_slice::<CachedResponse>(&cached_data) {
610 Ok(cached_response) => {
611 let mut response = actix_web::HttpResponse::build(
612 actix_web::http::StatusCode::from_u16(cached_response.status)
613 .unwrap_or(actix_web::http::StatusCode::OK),
614 );
615
616 for (name, value) in cached_response.headers {
617 response.insert_header((name, value));
618 }
619
620 response.insert_header(("X-Cache", "HIT"));
621
622 let resp = response.body(cached_response.body);
623 return Ok(req
624 .into_response(resp)
625 .map_body(|_, b| EitherBody::right(BoxBody::new(b))));
626 }
627 Err(e) => {
628 log::error!("Failed to deserialize cached response: {}", e);
629 }
630 }
631 }
632
633 log::debug!("Cache miss for {}", cache_key);
634 let future = CacheResponseFuture::<S, B> {
635 fut: service.call(req),
636 should_cache,
637 cache_key,
638 redis_conn,
639 redis_url,
640 ttl: expiration,
641 max_cacheable_size,
642 _marker: PhantomData,
643 };
644
645 future.await
646 })
647 }
648}
649
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{http::header, test::TestRequest};

    /// Builds a Redis key exactly like the middleware does: `prefix` immediately followed
    /// by the hex-encoded SHA-256 of `base_key` (no separator).
    fn prefixed_key(prefix: &str, base_key: &str) -> String {
        format!("{}{}", prefix, hex::encode(Sha256::digest(base_key.as_bytes())))
    }

    /// Runs `check` against a context with an empty query string and a `Null` body.
    fn with_ctx<R>(
        method: &str,
        path: &str,
        headers: &header::HeaderMap,
        check: impl FnOnce(&CacheContext) -> R,
    ) -> R {
        let body = serde_json::Value::Null;
        let ctx = CacheContext {
            method,
            path,
            query_string: "",
            headers,
            body: &body,
        };
        check(&ctx)
    }

    #[actix_web::test]
    async fn test_builder_default_values() {
        let builder = RedisCacheMiddlewareBuilder::new("redis://localhost");
        assert_eq!(builder.ttl, 3600);
        assert_eq!(builder.max_cacheable_size, 1024 * 1024);
        assert_eq!(builder.cache_prefix, "cache:");
        assert_eq!(builder.redis_url, "redis://localhost");
    }

    #[actix_web::test]
    async fn test_builder_custom_values() {
        let builder = RedisCacheMiddlewareBuilder::new("redis://localhost")
            .ttl(60)
            .max_cacheable_size(512 * 1024)
            .cache_prefix("custom:");

        assert_eq!(builder.ttl, 60);
        assert_eq!(builder.max_cacheable_size, 512 * 1024);
        assert_eq!(builder.cache_prefix, "custom:");
    }

    #[actix_web::test]
    async fn test_builder_custom_predicate() {
        let builder = RedisCacheMiddlewareBuilder::new("redis://localhost")
            .cache_if(|ctx| ctx.method == "GET");
        let headers = header::HeaderMap::new();

        assert!(with_ctx("GET", "/test", &headers, |ctx| (builder.cache_if)(ctx)));
        assert!(!with_ctx("POST", "/test", &headers, |ctx| (builder.cache_if)(ctx)));
    }

    #[actix_web::test]
    async fn test_cache_key_generation() {
        let req = TestRequest::get().uri("/test").to_srv_request();

        let base_key = format!(
            "{}:{}:{}",
            req.method().as_str(),
            req.path(),
            req.query_string()
        );

        assert_eq!(base_key, "GET:/test:");
        assert_eq!(
            prefixed_key("test:", &base_key),
            format!("test:{}", hex::encode(Sha256::digest(b"GET:/test:")))
        );
    }

    #[actix_web::test]
    async fn test_cache_key_with_body() {
        let body_hash = hex::encode(Sha256::digest(b"test body"));
        let base_key = format!("{}:{}:{}:{}", "POST", "/test", "", body_hash);

        let expected = format!(
            "test:{}",
            hex::encode(Sha256::digest(
                format!("POST:/test::{}", body_hash).as_bytes()
            ))
        );

        assert_eq!(prefixed_key("test:", &base_key), expected);
    }

    #[actix_web::test]
    async fn test_cacheable_methods() {
        let methods = ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"];
        let headers = header::HeaderMap::new();

        let default_predicate = RedisCacheMiddlewareBuilder::new("redis://localhost").cache_if;
        for method in methods {
            assert!(
                with_ctx(method, "/test", &headers, |ctx| default_predicate(ctx)),
                "Method {} should be cacheable by default",
                method
            );
        }

        let custom_builder = RedisCacheMiddlewareBuilder::new("redis://localhost")
            .cache_if(|ctx| matches!(ctx.method, "GET" | "HEAD"));
        for method in methods {
            let expected = matches!(method, "GET" | "HEAD");
            assert_eq!(
                with_ctx(method, "/test", &headers, |ctx| (custom_builder.cache_if)(ctx)),
                expected,
                "Method {} should be cacheable: {}",
                method,
                expected
            );
        }
    }

    #[actix_web::test]
    async fn test_predicate_with_headers() {
        let predicate = |ctx: &CacheContext| !ctx.headers.contains_key("Authorization");

        let mut headers = header::HeaderMap::new();
        assert!(
            with_ctx("GET", "/test", &headers, predicate),
            "Request without Authorization should be cached"
        );

        headers.insert(
            header::AUTHORIZATION,
            header::HeaderValue::from_static("Bearer token"),
        );
        assert!(
            !with_ctx("GET", "/test", &headers, predicate),
            "Request with Authorization should not be cached"
        );
    }

    #[actix_web::test]
    async fn test_predicate_with_path_patterns() {
        let predicate =
            |ctx: &CacheContext| !ctx.path.starts_with("/admin") && !ctx.path.contains("/private/");
        let headers = header::HeaderMap::new();

        for path in ["/", "/api/users", "/public/resource", "/api/v1/data"] {
            assert!(
                with_ctx("GET", path, &headers, predicate),
                "Path {} should be cacheable",
                path
            );
        }

        for path in ["/admin", "/admin/users", "/users/private/profile"] {
            assert!(
                !with_ctx("GET", path, &headers, predicate),
                "Path {} should not be cacheable",
                path
            );
        }
    }

    #[actix_web::test]
    async fn test_cached_response_serialization() {
        let original = CachedResponse {
            status: 200,
            headers: vec![
                ("Content-Type".to_string(), "text/plain".to_string()),
                ("X-Test".to_string(), "value".to_string()),
            ],
            body: b"test response".to_vec(),
        };

        let bytes = rmp_serde::to_vec(&original).unwrap();
        let decoded: CachedResponse = rmp_serde::from_slice(&bytes).unwrap();

        assert_eq!(decoded.status, 200);
        assert_eq!(
            decoded.headers,
            vec![
                ("Content-Type".to_string(), "text/plain".to_string()),
                ("X-Test".to_string(), "value".to_string()),
            ]
        );
        assert_eq!(decoded.body, b"test response");
    }

    #[actix_web::test]
    async fn test_custom_cache_key() {
        let builder = RedisCacheMiddlewareBuilder::new("redis://localhost")
            .with_cache_key(|ctx| format!("{}:{}", ctx.method, ctx.path));

        // Mirrors the middleware's key derivation, including how the prefix is attached.
        let get_key = |method: &str, path: &str, query: &str, body: &[u8]| {
            let headers = header::HeaderMap::new();
            let body_json = serde_json::from_slice(body).unwrap_or(serde_json::Value::Null);
            let ctx = CacheContext {
                method,
                path,
                query_string: query,
                headers: &headers,
                body: &body_json,
            };

            let base_key = match &builder.cache_key_fn {
                Some(key_fn) => key_fn(&ctx),
                None => format!("{}:{}:{}", method, path, query),
            };

            prefixed_key(&builder.cache_prefix, &base_key)
        };

        // The custom key ignores the query string...
        let key1 = get_key("GET", "/users", "", b"");
        let key2 = get_key("GET", "/users", "page=1", b"");
        let key3 = get_key("GET", "/users", "page=2", b"");
        assert_eq!(key1, key2);
        assert_eq!(key1, key3);

        // ...but still distinguishes methods.
        let key_get = get_key("GET", "/resource", "", b"");
        let key_post = get_key("POST", "/resource", "", b"");
        assert_ne!(key_get, key_post);
    }
}