1#![warn(missing_docs)]
use actix_web::{
    body::{BoxBody, EitherBody, MessageBody},
    dev::{forward_ready, Payload, Service, ServiceRequest, ServiceResponse, Transform},
    http::header::HeaderMap,
    web::{Bytes, BytesMut},
    Error, HttpMessage,
};
use futures::{
    future::{ready, LocalBoxFuture, Ready},
    StreamExt,
};
use redis::{aio::MultiplexedConnection, AsyncCommands};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::cell::RefCell;
use std::rc::Rc;
use std::sync::Arc;
63
64pub struct CacheDecisionContext<'a> {
69 pub method: &'a str,
71 pub path: &'a str,
73 pub query_string: &'a str,
75 pub headers: &'a HeaderMap,
77 pub body: &'a [u8],
79}
80
81type CachePredicate = Arc<dyn Fn(&CacheDecisionContext) -> bool + Send + Sync>;
86
87pub struct RedisCacheMiddleware {
92 redis_conn: Option<MultiplexedConnection>,
93 redis_url: String,
94 ttl: u64,
95 max_cacheable_size: usize,
96 cache_prefix: String,
97 cache_if: CachePredicate,
98}
99
100pub struct RedisCacheMiddlewareBuilder {
105 redis_url: String,
106 ttl: u64,
107 max_cacheable_size: usize,
108 cache_prefix: String,
109 cache_if: CachePredicate,
110}
111
112impl RedisCacheMiddlewareBuilder {
113 pub fn new(redis_url: impl Into<String>) -> Self {
127 Self {
128 redis_url: redis_url.into(),
129 ttl: 3600, max_cacheable_size: 1024 * 1024, cache_prefix: "cache:".to_string(),
132 cache_if: Arc::new(|_| true), }
134 }
135
136 pub fn ttl(mut self, seconds: u64) -> Self {
146 self.ttl = seconds;
147 self
148 }
149
150 pub fn max_cacheable_size(mut self, bytes: usize) -> Self {
162 self.max_cacheable_size = bytes;
163 self
164 }
165
166 pub fn cache_prefix(mut self, prefix: impl Into<String>) -> Self {
176 self.cache_prefix = prefix.into();
177 self
178 }
179
180 pub fn cache_if<F>(mut self, predicate: F) -> Self
211 where
212 F: Fn(&CacheDecisionContext) -> bool + Send + Sync + 'static,
213 {
214 self.cache_if = Arc::new(predicate);
215 self
216 }
217
218 pub fn build(self) -> RedisCacheMiddleware {
224 RedisCacheMiddleware {
225 redis_conn: None,
226 redis_url: self.redis_url,
227 ttl: self.ttl,
228 max_cacheable_size: self.max_cacheable_size,
229 cache_prefix: self.cache_prefix,
230 cache_if: self.cache_if,
231 }
232 }
233}
234
235impl RedisCacheMiddleware {
236 pub fn new(redis_url: &str) -> Self {
248 RedisCacheMiddlewareBuilder::new(redis_url).build()
249 }
250}
251
252pub struct RedisCacheMiddlewareService<S> {
257 service: Rc<S>,
258 redis_conn: Option<MultiplexedConnection>,
259 redis_url: String,
260 ttl: u64,
261 max_cacheable_size: usize,
262 cache_prefix: String,
263 cache_if: CachePredicate,
264}
265
266#[derive(Serialize, Deserialize)]
267struct CachedResponse {
268 status: u16,
269 headers: Vec<(String, String)>,
270 body: Vec<u8>,
271}
272
273impl<S, B> Transform<S, ServiceRequest> for RedisCacheMiddleware
274where
275 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
276 S::Future: 'static,
277 B: 'static + Clone + MessageBody,
278{
279 type Response = ServiceResponse<EitherBody<B, BoxBody>>;
280 type Error = Error;
281 type Transform = RedisCacheMiddlewareService<S>;
282 type InitError = ();
283 type Future = Ready<Result<Self::Transform, Self::InitError>>;
284
285 fn new_transform(&self, service: S) -> Self::Future {
287 ready(Ok(RedisCacheMiddlewareService {
288 service: Rc::new(service),
289 redis_conn: self.redis_conn.clone(),
290 redis_url: self.redis_url.clone(),
291 ttl: self.ttl,
292 max_cacheable_size: self.max_cacheable_size,
293 cache_prefix: self.cache_prefix.clone(),
294 cache_if: self.cache_if.clone(),
295 }))
296 }
297}
298
299impl<S, B> Service<ServiceRequest> for RedisCacheMiddlewareService<S>
300where
301 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
302 S::Future: 'static,
303 B: actix_web::body::MessageBody + 'static + Clone,
304{
305 type Response = ServiceResponse<EitherBody<B, BoxBody>>;
306 type Error = Error;
307 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
308
309 forward_ready!(service);
310
311 fn call(&self, mut req: ServiceRequest) -> Self::Future {
312 if let Some(cache_control) = req.headers().get("Cache-Control") {
313 if let Ok(cache_control_str) = cache_control.to_str() {
314 if cache_control_str.contains("no-cache") || cache_control_str.contains("no-store")
315 {
316 let fut = self.service.call(req);
317 return Box::pin(async move {
318 let res = fut.await?;
319 Ok(res.map_body(|_, b| EitherBody::left(b)))
320 });
321 }
322 }
323 }
324
325 let redis_url = self.redis_url.clone();
326 let mut redis_conn = self.redis_conn.clone();
327 let expiration = self.ttl;
328 let max_cacheable_size = self.max_cacheable_size;
329 let cache_prefix = self.cache_prefix.clone();
330 let service = Rc::clone(&self.service);
331 let cache_if = self.cache_if.clone();
332
333 Box::pin(async move {
334 let body_bytes = req
335 .take_payload()
336 .fold(BytesMut::new(), move |mut body, chunk| async {
337 if let Ok(chunk) = chunk {
338 body.extend_from_slice(&chunk);
339 }
340 body
341 })
342 .await;
343
344 let cache_ctx = CacheDecisionContext {
345 method: req.method().as_str(),
346 path: req.path(),
347 query_string: req.query_string(),
348 headers: req.headers(),
349 body: &body_bytes,
350 };
351
352 let should_cache = cache_if(&cache_ctx);
353
354 req.set_payload(Payload::from(Bytes::from(body_bytes.clone())));
355
356 let base_key = if body_bytes.is_empty() {
357 format!(
358 "{}:{}:{}",
359 req.method().as_str(),
360 req.path(),
361 req.query_string()
362 )
363 } else {
364 let body_hash = hex::encode(Sha256::digest(&body_bytes));
365 format!(
366 "{}:{}:{}:{}",
367 req.method().as_str(),
368 req.path(),
369 req.query_string(),
370 body_hash
371 )
372 };
373
374 let hashed_key = hex::encode(Sha256::digest(base_key.as_bytes()));
375 let cache_key = format!("{}{}", cache_prefix, hashed_key);
376
377 let cached_result: Option<String> = if should_cache {
379 if redis_conn.is_none() {
380 let client = redis::Client::open(redis_url.as_str())
381 .expect("Failed to connect to Redis");
382
383 let conn = client
384 .get_multiplexed_async_connection()
385 .await
386 .expect("Failed to get Redis connection");
387
388 redis_conn = Some(conn);
389 }
390
391 let conn = redis_conn.as_mut().unwrap();
392 conn.get(&cache_key).await.unwrap_or(None)
393 } else {
394 None
395 };
396
397 if let Some(cached_data) = cached_result {
398 log::debug!("Cache hit for {}", cache_key);
399
400 match serde_json::from_str::<CachedResponse>(&cached_data) {
402 Ok(cached_response) => {
403 let mut response = actix_web::HttpResponse::build(
404 actix_web::http::StatusCode::from_u16(cached_response.status)
405 .unwrap_or(actix_web::http::StatusCode::OK),
406 );
407
408 for (name, value) in cached_response.headers {
409 response.insert_header((name, value));
410 }
411
412 response.insert_header(("X-Cache", "HIT"));
413
414 let resp = response.body(cached_response.body);
415 return Ok(req
416 .into_response(resp)
417 .map_body(|_, b| EitherBody::right(BoxBody::new(b))));
418 }
419 Err(e) => {
420 log::error!("Failed to deserialize cached response: {}", e);
421 }
422 }
423 }
424
425 log::debug!("Cache miss for {}", cache_key);
426
427 let service_result = service.call(req).await?;
428
429 if should_cache && service_result.status().is_success() {
431 let res = service_result.response();
432
433 let status = res.status().as_u16();
434
435 let headers = res
436 .headers()
437 .iter()
438 .filter(|(name, _)| {
439 !["connection", "transfer-encoding", "content-length"]
440 .contains(&name.as_str().to_lowercase().as_str())
441 })
442 .map(|(name, value)| {
443 (
444 name.to_string(),
445 value.to_str().unwrap_or_default().to_string(),
446 )
447 })
448 .collect::<Vec<_>>();
449
450 if let Ok(body) = res.body().clone().try_into_bytes() {
451 if !body.is_empty() && body.len() <= max_cacheable_size {
452 let cached_response = CachedResponse {
453 status,
454 headers,
455 body: body.to_vec(),
456 };
457
458 if let Ok(serialized) = serde_json::to_string(&cached_response) {
459 if redis_conn.is_none() {
460 let client = redis::Client::open(redis_url.as_str())
461 .expect("Failed to connect to Redis");
462
463 let conn = client
464 .get_multiplexed_async_connection()
465 .await
466 .expect("Failed to get Redis connection");
467
468 redis_conn = Some(conn);
469 }
470
471 let conn = redis_conn.as_mut().unwrap();
472 let _: Result<(), redis::RedisError> =
473 conn.set_ex(cache_key, serialized, expiration).await;
474 }
475 }
476 }
477 }
478
479 Ok(service_result.map_body(|_, b| EitherBody::left(b)))
480 })
481 }
482}
483
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{http::header, test::TestRequest};

    /// Every HTTP method exercised by `test_cacheable_methods`.
    const METHODS: [&str; 7] = ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"];

    /// Builds a decision context with no query string and no body.
    fn context<'a>(
        method: &'a str,
        path: &'a str,
        headers: &'a header::HeaderMap,
    ) -> CacheDecisionContext<'a> {
        CacheDecisionContext {
            method,
            path,
            query_string: "",
            headers,
            body: &[],
        }
    }

    /// Hashes `base_key` with SHA-256, hex-encodes it and prepends the `test:` prefix.
    fn prefixed_key(base_key: &str) -> String {
        let digest = Sha256::digest(base_key.as_bytes());
        format!("test:{}", hex::encode(digest))
    }

    #[actix_web::test]
    async fn test_builder_default_values() {
        let builder = RedisCacheMiddlewareBuilder::new("redis://localhost");

        assert_eq!(builder.redis_url, "redis://localhost");
        assert_eq!(builder.ttl, 3600);
        assert_eq!(builder.max_cacheable_size, 1024 * 1024);
        assert_eq!(builder.cache_prefix, "cache:");
    }

    #[actix_web::test]
    async fn test_builder_custom_values() {
        let builder = RedisCacheMiddlewareBuilder::new("redis://localhost")
            .cache_prefix("custom:")
            .max_cacheable_size(512 * 1024)
            .ttl(60);

        assert_eq!(builder.cache_prefix, "custom:");
        assert_eq!(builder.max_cacheable_size, 512 * 1024);
        assert_eq!(builder.ttl, 60);
    }

    #[actix_web::test]
    async fn test_builder_custom_predicate() {
        let headers = header::HeaderMap::new();
        let builder = RedisCacheMiddlewareBuilder::new("redis://localhost")
            .cache_if(|ctx| ctx.method == "GET");

        assert!((builder.cache_if)(&context("GET", "/test", &headers)));
        assert!(!(builder.cache_if)(&context("POST", "/test", &headers)));
    }

    #[actix_web::test]
    async fn test_cache_key_generation() {
        let req = TestRequest::get().uri("/test").to_srv_request();

        let base_key = format!(
            "{}:{}:{}",
            req.method().as_str(),
            req.path(),
            req.query_string()
        );

        assert_eq!(prefixed_key(&base_key), prefixed_key("GET:/test:"));
    }

    #[actix_web::test]
    async fn test_cache_key_with_body() {
        let body_hash = hex::encode(Sha256::digest(b"test body"));

        let base_key = format!("{}:{}:{}:{}", "POST", "/test", "", body_hash);
        let expected_base = format!("POST:/test::{}", body_hash);

        assert_eq!(prefixed_key(&base_key), prefixed_key(&expected_base));
    }

    #[actix_web::test]
    async fn test_cacheable_methods() {
        let headers = header::HeaderMap::new();

        let builder = RedisCacheMiddlewareBuilder::new("redis://localhost");
        let default_predicate = builder.cache_if;
        for method in METHODS {
            assert!(
                default_predicate(&context(method, "/test", &headers)),
                "Method {} should be cacheable by default",
                method
            );
        }

        let custom_builder = RedisCacheMiddlewareBuilder::new("redis://localhost")
            .cache_if(|ctx| matches!(ctx.method, "GET" | "HEAD"));
        for method in METHODS {
            let should_cache = method == "GET" || method == "HEAD";
            assert_eq!(
                (custom_builder.cache_if)(&context(method, "/test", &headers)),
                should_cache,
                "Method {} should be cacheable: {}",
                method,
                should_cache
            );
        }
    }

    #[actix_web::test]
    async fn test_predicate_with_headers() {
        let predicate = |ctx: &CacheDecisionContext| !ctx.headers.contains_key("Authorization");
        let mut headers = header::HeaderMap::new();

        assert!(
            predicate(&context("GET", "/test", &headers)),
            "Request without Authorization should be cached"
        );

        headers.insert(
            header::AUTHORIZATION,
            header::HeaderValue::from_static("Bearer token"),
        );

        assert!(
            !predicate(&context("GET", "/test", &headers)),
            "Request with Authorization should not be cached"
        );
    }

    #[actix_web::test]
    async fn test_predicate_with_path_patterns() {
        let predicate = |ctx: &CacheDecisionContext| {
            let is_admin = ctx.path.starts_with("/admin");
            let is_private = ctx.path.contains("/private/");
            !(is_admin || is_private)
        };
        let headers = header::HeaderMap::new();

        for path in ["/", "/api/users", "/public/resource", "/api/v1/data"] {
            assert!(
                predicate(&context("GET", path, &headers)),
                "Path {} should be cacheable",
                path
            );
        }

        for path in ["/admin", "/admin/users", "/users/private/profile"] {
            assert!(
                !predicate(&context("GET", path, &headers)),
                "Path {} should not be cacheable",
                path
            );
        }
    }

    #[actix_web::test]
    async fn test_cached_response_serialization() {
        let expected_headers = vec![
            ("Content-Type".to_string(), "text/plain".to_string()),
            ("X-Test".to_string(), "value".to_string()),
        ];
        let original = CachedResponse {
            status: 200,
            headers: expected_headers.clone(),
            body: b"test response".to_vec(),
        };

        let json = serde_json::to_string(&original).unwrap();
        let decoded: CachedResponse = serde_json::from_str(&json).unwrap();

        assert_eq!(decoded.status, 200);
        assert_eq!(decoded.headers, expected_headers);
        assert_eq!(decoded.body, b"test response");
    }
}