1#![warn(missing_docs)]
2use actix_web::{
48 body::{BodySize, BoxBody, EitherBody, MessageBody},
49 dev::{forward_ready, Payload, Service, ServiceRequest, ServiceResponse, Transform},
50 http::header::HeaderMap,
51 web::{Bytes, BytesMut},
52 Error, HttpMessage,
53};
54use futures::{
55 future::{ready, LocalBoxFuture, Ready},
56 StreamExt,
57};
58use pin_project_lite::pin_project;
59use redis::{aio::MultiplexedConnection, AsyncCommands};
60use serde::{Deserialize, Serialize};
61use sha2::{Digest, Sha256};
62use std::{future::Future, marker::PhantomData, pin::Pin, rc::Rc};
63use std::{
64 sync::Arc,
65 task::{Context, Poll},
66};
67
68pub struct CacheDecisionContext<'a> {
73 pub method: &'a str,
75 pub path: &'a str,
77 pub query_string: &'a str,
79 pub headers: &'a HeaderMap,
81 pub body: &'a [u8],
83}
84
85type CachePredicate = Arc<dyn Fn(&CacheDecisionContext) -> bool + Send + Sync>;
90
91pub struct RedisCacheMiddleware {
96 redis_conn: Option<MultiplexedConnection>,
97 redis_url: String,
98 ttl: u64,
99 max_cacheable_size: usize,
100 cache_prefix: String,
101 cache_if: CachePredicate,
102}
103
104pub struct RedisCacheMiddlewareBuilder {
109 redis_url: String,
110 ttl: u64,
111 max_cacheable_size: usize,
112 cache_prefix: String,
113 cache_if: CachePredicate,
114}
115
116impl RedisCacheMiddlewareBuilder {
117 pub fn new(redis_url: impl Into<String>) -> Self {
131 Self {
132 redis_url: redis_url.into(),
133 ttl: 3600, max_cacheable_size: 1024 * 1024, cache_prefix: "cache:".to_string(),
136 cache_if: Arc::new(|_| true), }
138 }
139
140 pub fn ttl(mut self, seconds: u64) -> Self {
150 self.ttl = seconds;
151 self
152 }
153
154 pub fn max_cacheable_size(mut self, bytes: usize) -> Self {
166 self.max_cacheable_size = bytes;
167 self
168 }
169
170 pub fn cache_prefix(mut self, prefix: impl Into<String>) -> Self {
180 self.cache_prefix = prefix.into();
181 self
182 }
183
184 pub fn cache_if<F>(mut self, predicate: F) -> Self
215 where
216 F: Fn(&CacheDecisionContext) -> bool + Send + Sync + 'static,
217 {
218 self.cache_if = Arc::new(predicate);
219 self
220 }
221
222 pub fn build(self) -> RedisCacheMiddleware {
228 RedisCacheMiddleware {
229 redis_conn: None,
230 redis_url: self.redis_url,
231 ttl: self.ttl,
232 max_cacheable_size: self.max_cacheable_size,
233 cache_prefix: self.cache_prefix,
234 cache_if: self.cache_if,
235 }
236 }
237}
238
239impl RedisCacheMiddleware {
240 pub fn new(redis_url: &str) -> Self {
252 RedisCacheMiddlewareBuilder::new(redis_url).build()
253 }
254}
255
256pub struct RedisCacheMiddlewareService<S> {
261 service: Rc<S>,
262 redis_conn: Option<MultiplexedConnection>,
263 redis_url: String,
264 ttl: u64,
265 max_cacheable_size: usize,
266 cache_prefix: String,
267 cache_if: CachePredicate,
268}
269
270#[derive(Serialize, Deserialize)]
271struct CachedResponse {
272 status: u16,
273 headers: Vec<(String, String)>,
274 body: Vec<u8>,
275}
276
277impl<S, B> Transform<S, ServiceRequest> for RedisCacheMiddleware
278where
279 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
280 S::Future: 'static,
281 B: 'static + MessageBody,
282{
283 type Response = ServiceResponse<EitherBody<B, BoxBody>>;
284 type Error = Error;
285 type Transform = RedisCacheMiddlewareService<S>;
286 type InitError = ();
287 type Future = Ready<Result<Self::Transform, Self::InitError>>;
288
289 fn new_transform(&self, service: S) -> Self::Future {
291 ready(Ok(RedisCacheMiddlewareService {
292 service: Rc::new(service),
293 redis_conn: self.redis_conn.clone(),
294 redis_url: self.redis_url.clone(),
295 ttl: self.ttl,
296 max_cacheable_size: self.max_cacheable_size,
297 cache_prefix: self.cache_prefix.clone(),
298 cache_if: self.cache_if.clone(),
299 }))
300 }
301}
302
303pin_project! {
305 struct CacheResponseFuture<S, B>
306 where
307 B: MessageBody,
308 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
309 {
310 #[pin]
311 fut: S::Future,
312 should_cache: bool,
313 cache_key: String,
314 redis_conn: Option<MultiplexedConnection>,
315 redis_url: String,
316 ttl: u64,
317 max_cacheable_size: usize,
318 _marker: PhantomData<B>,
319 }
320}
321
322impl<S, B> Future for CacheResponseFuture<S, B>
324where
325 B: MessageBody + 'static,
326 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
327{
328 type Output = Result<ServiceResponse<EitherBody<B, BoxBody>>, Error>;
329
330 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
331 let this = self.project();
332
333 let res = futures_util::ready!(this.fut.poll(cx))?;
334
335 let status = res.status();
336 let headers = res.headers().clone();
337 let should_cache = *this.should_cache && status.is_success();
338
339 if !should_cache {
340 return Poll::Ready(Ok(res.map_body(|_, b| EitherBody::left(b))));
341 }
342
343 let cache_key = this.cache_key.clone();
344 let redis_url = this.redis_url.clone();
345 let redis_conn = this.redis_conn.clone();
346 let ttl = *this.ttl;
347 let max_size = *this.max_cacheable_size;
348
349 let res = res.map_body(move |_, body| {
350 let filtered_headers = headers
351 .iter()
352 .filter(|(name, _)| {
353 !["connection", "transfer-encoding", "content-length"]
354 .contains(&name.as_str().to_lowercase().as_str())
355 })
356 .map(|(name, value)| {
357 (
358 name.to_string(),
359 value.to_str().unwrap_or_default().to_string(),
360 )
361 })
362 .collect::<Vec<_>>();
363
364 EitherBody::right(BoxBody::new(CacheableBody {
365 body: body.boxed(),
366 status: status.as_u16(),
367 headers: filtered_headers,
368 body_accum: BytesMut::new(),
369 cache_key,
370 redis_conn,
371 redis_url,
372 ttl,
373 max_size,
374 }))
375 });
376
377 Poll::Ready(Ok(res))
378 }
379}
380
381pin_project! {
383 struct CacheableBody {
384 #[pin]
385 body: BoxBody,
386 status: u16,
387 headers: Vec<(String, String)>,
388 body_accum: BytesMut,
389 cache_key: String,
390 redis_conn: Option<MultiplexedConnection>,
391 redis_url: String,
392 ttl: u64,
393 max_size: usize,
394 }
395
396 impl PinnedDrop for CacheableBody {
397 fn drop(this: Pin<&mut Self>) {
398 let this = this.project();
399
400 let body_bytes = this.body_accum.clone().freeze();
401 let status = *this.status;
402 let headers = this.headers.clone();
403 let cache_key = this.cache_key.clone();
404 let mut redis_conn = this.redis_conn.take();
405 let redis_url = this.redis_url.clone();
406 let ttl = *this.ttl;
407 let max_size = *this.max_size;
408
409 if !body_bytes.is_empty() && body_bytes.len() <= max_size {
410 actix_web::rt::spawn(async move {
411 let cached_response = CachedResponse {
412 status,
413 headers,
414 body: body_bytes.to_vec(),
415 };
416
417 if let Ok(serialized) = serde_json::to_string(&cached_response) {
418 if redis_conn.is_none() {
419 let client = redis::Client::open(redis_url.as_str())
420 .expect("Failed to connect to Redis");
421
422 let conn = client
423 .get_multiplexed_async_connection()
424 .await
425 .expect("Failed to get Redis connection");
426
427 redis_conn = Some(conn);
428 }
429
430 if let Some(conn) = redis_conn.as_mut() {
431 let _: Result<(), redis::RedisError> =
432 conn.set_ex(cache_key, serialized, ttl).await;
433 }
434 }
435 });
436 }
437 }
438 }
439}
440
441impl MessageBody for CacheableBody {
442 type Error = <BoxBody as MessageBody>::Error;
443
444 fn size(&self) -> BodySize {
445 self.body.size()
446 }
447
448 fn poll_next(
449 self: Pin<&mut Self>,
450 cx: &mut Context<'_>,
451 ) -> Poll<Option<Result<Bytes, Self::Error>>> {
452 let this = self.project();
453
454 match this.body.poll_next(cx) {
456 Poll::Ready(Some(Ok(chunk))) => {
457 this.body_accum.extend_from_slice(&chunk);
458 Poll::Ready(Some(Ok(chunk)))
459 }
460 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
461 Poll::Ready(None) => Poll::Ready(None),
462 Poll::Pending => Poll::Pending,
463 }
464 }
465}
466
467impl<S, B> Service<ServiceRequest> for RedisCacheMiddlewareService<S>
468where
469 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
470 S::Future: 'static,
471 B: MessageBody + 'static,
472{
473 type Response = ServiceResponse<EitherBody<B, BoxBody>>;
474 type Error = Error;
475 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
476
477 forward_ready!(service);
478
479 fn call(&self, mut req: ServiceRequest) -> Self::Future {
480 if let Some(cache_control) = req.headers().get("Cache-Control") {
482 if let Ok(cache_control_str) = cache_control.to_str() {
483 if cache_control_str.contains("no-cache") || cache_control_str.contains("no-store")
484 {
485 let fut = self.service.call(req);
486 return Box::pin(async move {
487 let res = fut.await?;
488 Ok(res.map_body(|_, b| EitherBody::left(b)))
489 });
490 }
491 }
492 }
493
494 let redis_url = self.redis_url.clone();
495 let mut redis_conn = self.redis_conn.clone();
496 let expiration = self.ttl;
497 let max_cacheable_size = self.max_cacheable_size;
498 let cache_prefix = self.cache_prefix.clone();
499 let service = Rc::clone(&self.service);
500 let cache_if = self.cache_if.clone();
501
502 Box::pin(async move {
503 let body_bytes = req
504 .take_payload()
505 .fold(BytesMut::new(), move |mut body, chunk| async {
506 if let Ok(chunk) = chunk {
507 body.extend_from_slice(&chunk);
508 }
509 body
510 })
511 .await;
512
513 let cache_ctx = CacheDecisionContext {
514 method: req.method().as_str(),
515 path: req.path(),
516 query_string: req.query_string(),
517 headers: req.headers(),
518 body: &body_bytes,
519 };
520
521 let should_cache = cache_if(&cache_ctx);
522
523 req.set_payload(Payload::from(Bytes::from(body_bytes.clone())));
524
525 let base_key = if body_bytes.is_empty() {
527 format!(
528 "{}:{}:{}",
529 req.method().as_str(),
530 req.path(),
531 req.query_string()
532 )
533 } else {
534 let body_hash = hex::encode(Sha256::digest(&body_bytes));
535 format!(
536 "{}:{}:{}:{}",
537 req.method().as_str(),
538 req.path(),
539 req.query_string(),
540 body_hash
541 )
542 };
543
544 let hashed_key = hex::encode(Sha256::digest(base_key.as_bytes()));
545 let cache_key = format!("{}{}", cache_prefix, hashed_key);
546
547 let cached_result: Option<String> = if should_cache {
548 if redis_conn.is_none() {
549 let client = redis::Client::open(redis_url.as_str())
550 .expect("Failed to connect to Redis");
551
552 let conn = client
553 .get_multiplexed_async_connection()
554 .await
555 .expect("Failed to get Redis connection");
556
557 redis_conn = Some(conn);
558 }
559
560 let conn = redis_conn.as_mut().unwrap();
561 conn.get(&cache_key).await.unwrap_or(None)
562 } else {
563 None
564 };
565
566 if let Some(cached_data) = cached_result {
567 log::debug!("Cache hit for {}", cache_key);
568
569 match serde_json::from_str::<CachedResponse>(&cached_data) {
570 Ok(cached_response) => {
571 let mut response = actix_web::HttpResponse::build(
572 actix_web::http::StatusCode::from_u16(cached_response.status)
573 .unwrap_or(actix_web::http::StatusCode::OK),
574 );
575
576 for (name, value) in cached_response.headers {
577 response.insert_header((name, value));
578 }
579
580 response.insert_header(("X-Cache", "HIT"));
581
582 let resp = response.body(cached_response.body);
583 return Ok(req
584 .into_response(resp)
585 .map_body(|_, b| EitherBody::right(BoxBody::new(b))));
586 }
587 Err(e) => {
588 log::error!("Failed to deserialize cached response: {}", e);
589 }
590 }
591 }
592
593 log::debug!("Cache miss for {}", cache_key);
594 let future = CacheResponseFuture::<S, B> {
595 fut: service.call(req),
596 should_cache,
597 cache_key,
598 redis_conn,
599 redis_url,
600 ttl: expiration,
601 max_cacheable_size,
602 _marker: PhantomData,
603 };
604
605 future.await
606 })
607 }
608}
609
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{http::header, test::TestRequest};

    /// Builds a context with an empty query string and an empty body.
    fn context<'a>(
        method: &'a str,
        path: &'a str,
        headers: &'a header::HeaderMap,
    ) -> CacheDecisionContext<'a> {
        CacheDecisionContext {
            method,
            path,
            query_string: "",
            headers,
            body: &[],
        }
    }

    /// A fresh builder carries the documented defaults.
    #[actix_web::test]
    async fn test_builder_default_values() {
        let defaults = RedisCacheMiddlewareBuilder::new("redis://localhost");

        assert_eq!(defaults.redis_url, "redis://localhost");
        assert_eq!(defaults.cache_prefix, "cache:");
        assert_eq!(defaults.max_cacheable_size, 1024 * 1024);
        assert_eq!(defaults.ttl, 3600);
    }

    /// Each setter overrides exactly its own field.
    #[actix_web::test]
    async fn test_builder_custom_values() {
        let customised = RedisCacheMiddlewareBuilder::new("redis://localhost")
            .cache_prefix("custom:")
            .max_cacheable_size(512 * 1024)
            .ttl(60);

        assert_eq!(customised.ttl, 60);
        assert_eq!(customised.max_cacheable_size, 512 * 1024);
        assert_eq!(customised.cache_prefix, "custom:");
    }

    /// A predicate installed through `cache_if` is the one that gets called.
    #[actix_web::test]
    async fn test_builder_custom_predicate() {
        let builder = RedisCacheMiddlewareBuilder::new("redis://localhost")
            .cache_if(|ctx| ctx.method == "GET");
        let no_headers = header::HeaderMap::new();

        assert!((builder.cache_if)(&context("GET", "/test", &no_headers)));
        assert!(!(builder.cache_if)(&context("POST", "/test", &no_headers)));
    }

    /// Key layout for a request without a body: prefix + sha256("METHOD:path:query").
    #[actix_web::test]
    async fn test_cache_key_generation() {
        let req = TestRequest::get().uri("/test").to_srv_request();

        let base_key = format!(
            "{}:{}:{}",
            req.method().as_str(),
            req.path(),
            req.query_string()
        );
        let cache_key = format!("test:{}", hex::encode(Sha256::digest(base_key.as_bytes())));

        let expected_digest = hex::encode(Sha256::digest("GET:/test:".to_string().as_bytes()));
        let expected_key = format!("test:{}", expected_digest);

        assert_eq!(cache_key, expected_key);
    }

    /// Key layout for a request with a body: the body digest is appended to the base key.
    #[actix_web::test]
    async fn test_cache_key_with_body() {
        let body_hash = hex::encode(Sha256::digest(b"test body"));

        let base_key = format!("{}:{}:{}:{}", "POST", "/test", "", body_hash);
        let cache_key = format!("test:{}", hex::encode(Sha256::digest(base_key.as_bytes())));

        let expected_base = format!("POST:/test::{}", body_hash);
        let expected_key = format!(
            "test:{}",
            hex::encode(Sha256::digest(expected_base.as_bytes()))
        );

        assert_eq!(cache_key, expected_key);
    }

    /// The default predicate accepts every method; a custom one can narrow that down.
    #[actix_web::test]
    async fn test_cacheable_methods() {
        let methods = ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"];
        let no_headers = header::HeaderMap::new();

        let default_predicate = RedisCacheMiddlewareBuilder::new("redis://localhost").cache_if;
        for method in methods {
            let ctx = context(method, "/test", &no_headers);
            assert!(
                (default_predicate)(&ctx),
                "Method {} should be cacheable by default",
                method
            );
        }

        let custom_builder = RedisCacheMiddlewareBuilder::new("redis://localhost")
            .cache_if(|ctx| matches!(ctx.method, "GET" | "HEAD"));
        for method in methods {
            let ctx = context(method, "/test", &no_headers);
            let should_cache = matches!(method, "GET" | "HEAD");
            assert_eq!(
                (custom_builder.cache_if)(&ctx),
                should_cache,
                "Method {} should be cacheable: {}",
                method,
                should_cache
            );
        }
    }

    /// A predicate can inspect request headers.
    #[actix_web::test]
    async fn test_predicate_with_headers() {
        let predicate = |ctx: &CacheDecisionContext| !ctx.headers.contains_key("Authorization");

        let mut headers = header::HeaderMap::new();
        assert!(
            predicate(&context("GET", "/test", &headers)),
            "Request without Authorization should be cached"
        );

        headers.insert(
            header::AUTHORIZATION,
            header::HeaderValue::from_static("Bearer token"),
        );
        assert!(
            !predicate(&context("GET", "/test", &headers)),
            "Request with Authorization should not be cached"
        );
    }

    /// A predicate can inspect the request path.
    #[actix_web::test]
    async fn test_predicate_with_path_patterns() {
        let predicate = |ctx: &CacheDecisionContext| {
            !ctx.path.starts_with("/admin") && !ctx.path.contains("/private/")
        };
        let no_headers = header::HeaderMap::new();

        let cacheable_paths = ["/", "/api/users", "/public/resource", "/api/v1/data"];
        for path in cacheable_paths {
            let ctx = context("GET", path, &no_headers);
            assert!(predicate(&ctx), "Path {} should be cacheable", path);
        }

        let non_cacheable_paths = ["/admin", "/admin/users", "/users/private/profile"];
        for path in non_cacheable_paths {
            let ctx = context("GET", path, &no_headers);
            assert!(!predicate(&ctx), "Path {} should not be cacheable", path);
        }
    }

    /// A cached response survives a JSON round trip unchanged.
    #[actix_web::test]
    async fn test_cached_response_serialization() {
        let original = CachedResponse {
            status: 200,
            headers: vec![
                ("Content-Type".to_string(), "text/plain".to_string()),
                ("X-Test".to_string(), "value".to_string()),
            ],
            body: b"test response".to_vec(),
        };

        let serialized = serde_json::to_string(&original).unwrap();
        let restored: CachedResponse = serde_json::from_str(&serialized).unwrap();

        assert_eq!(restored.status, 200);
        assert_eq!(restored.headers.len(), 2);
        assert_eq!(restored.headers[0].0, "Content-Type");
        assert_eq!(restored.headers[0].1, "text/plain");
        assert_eq!(restored.headers[1].0, "X-Test");
        assert_eq!(restored.headers[1].1, "value");
        assert_eq!(restored.body, b"test response");
    }
}