1#![warn(missing_docs)]
2use actix_web::{
48 body::{BoxBody, EitherBody, MessageBody},
49 dev::{forward_ready, Payload, Service, ServiceRequest, ServiceResponse, Transform},
50 http::header::HeaderMap,
51 web::{Bytes, BytesMut},
52 Error, HttpMessage,
53};
54use futures::{
55 future::{ready, LocalBoxFuture, Ready},
56 StreamExt,
57};
58use redis::{aio::MultiplexedConnection, AsyncCommands};
59use serde::{Deserialize, Serialize};
60use sha2::{Digest, Sha256};
61use std::sync::Arc;
62
63pub struct CacheDecisionContext<'a> {
68 pub method: &'a str,
70 pub path: &'a str,
72 pub query_string: &'a str,
74 pub headers: &'a HeaderMap,
76 pub body: &'a [u8],
78}
79
80type CachePredicate = Arc<dyn Fn(&CacheDecisionContext) -> bool + Send + Sync>;
85
86pub struct RedisCacheMiddleware {
91 redis_conn: Option<MultiplexedConnection>,
92 redis_url: String,
93 ttl: u64,
94 max_cacheable_size: usize,
95 cache_prefix: String,
96 cache_if: CachePredicate,
97}
98
99pub struct RedisCacheMiddlewareBuilder {
104 redis_url: String,
105 ttl: u64,
106 max_cacheable_size: usize,
107 cache_prefix: String,
108 cache_if: CachePredicate,
109}
110
111impl RedisCacheMiddlewareBuilder {
112 pub fn new(redis_url: impl Into<String>) -> Self {
126 Self {
127 redis_url: redis_url.into(),
128 ttl: 3600, max_cacheable_size: 1024 * 1024, cache_prefix: "cache:".to_string(),
131 cache_if: Arc::new(|_| true), }
133 }
134
135 pub fn ttl(mut self, seconds: u64) -> Self {
145 self.ttl = seconds;
146 self
147 }
148
149 pub fn max_cacheable_size(mut self, bytes: usize) -> Self {
161 self.max_cacheable_size = bytes;
162 self
163 }
164
165 pub fn cache_prefix(mut self, prefix: impl Into<String>) -> Self {
175 self.cache_prefix = prefix.into();
176 self
177 }
178
179 pub fn cache_if<F>(mut self, predicate: F) -> Self
210 where
211 F: Fn(&CacheDecisionContext) -> bool + Send + Sync + 'static,
212 {
213 self.cache_if = Arc::new(predicate);
214 self
215 }
216
217 pub fn build(self) -> RedisCacheMiddleware {
223 RedisCacheMiddleware {
224 redis_conn: None,
225 redis_url: self.redis_url,
226 ttl: self.ttl,
227 max_cacheable_size: self.max_cacheable_size,
228 cache_prefix: self.cache_prefix,
229 cache_if: self.cache_if,
230 }
231 }
232}
233
234impl RedisCacheMiddleware {
235 pub fn new(redis_url: &str) -> Self {
247 RedisCacheMiddlewareBuilder::new(redis_url).build()
248 }
249}
250
251pub struct RedisCacheMiddlewareService<S> {
256 service: S,
257 redis_conn: Option<MultiplexedConnection>,
258 redis_url: String,
259 ttl: u64,
260 max_cacheable_size: usize,
261 cache_prefix: String,
262 cache_if: CachePredicate,
263}
264
265#[derive(Serialize, Deserialize)]
266struct CachedResponse {
267 status: u16,
268 headers: Vec<(String, String)>,
269 body: Vec<u8>,
270}
271
272impl<S, B> Transform<S, ServiceRequest> for RedisCacheMiddleware
273where
274 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static + Clone,
275 S::Future: 'static,
276 B: 'static + Clone + MessageBody,
277{
278 type Response = ServiceResponse<EitherBody<B, BoxBody>>;
279 type Error = Error;
280 type Transform = RedisCacheMiddlewareService<S>;
281 type InitError = ();
282 type Future = Ready<Result<Self::Transform, Self::InitError>>;
283
284 fn new_transform(&self, service: S) -> Self::Future {
285 ready(Ok(RedisCacheMiddlewareService {
286 service,
287 redis_conn: self.redis_conn.clone(),
288 redis_url: self.redis_url.clone(),
289 ttl: self.ttl,
290 max_cacheable_size: self.max_cacheable_size,
291 cache_prefix: self.cache_prefix.clone(),
292 cache_if: self.cache_if.clone(),
293 }))
294 }
295}
296
297impl<S, B> Service<ServiceRequest> for RedisCacheMiddlewareService<S>
298where
299 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static + Clone,
300 S::Future: 'static,
301 B: actix_web::body::MessageBody + 'static + Clone,
302{
303 type Response = ServiceResponse<EitherBody<B, BoxBody>>;
304 type Error = Error;
305 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
306
307 forward_ready!(service);
308
309 fn call(&self, mut req: ServiceRequest) -> Self::Future {
310 if let Some(cache_control) = req.headers().get("Cache-Control") {
311 if let Ok(cache_control_str) = cache_control.to_str() {
312 if cache_control_str.contains("no-cache") || cache_control_str.contains("no-store")
313 {
314 let fut = self.service.call(req);
315 return Box::pin(async move {
316 let res = fut.await?;
317 Ok(res.map_body(|_, b| EitherBody::left(b)))
318 });
319 }
320 }
321 }
322
323 let redis_url = self.redis_url.clone();
324 let mut redis_conn = self.redis_conn.clone();
325 let expiration = self.ttl;
326 let max_cacheable_size = self.max_cacheable_size;
327 let cache_prefix = self.cache_prefix.clone();
328 let service = self.service.clone();
329 let cache_if = self.cache_if.clone();
330
331 Box::pin(async move {
332 let body_bytes = req
333 .take_payload()
334 .fold(BytesMut::new(), move |mut body, chunk| async {
335 if let Ok(chunk) = chunk {
336 body.extend_from_slice(&chunk);
337 }
338 body
339 })
340 .await;
341
342 let cache_ctx = CacheDecisionContext {
343 method: req.method().as_str(),
344 path: req.path(),
345 query_string: req.query_string(),
346 headers: req.headers(),
347 body: &body_bytes,
348 };
349
350 let should_cache = cache_if(&cache_ctx);
351
352 req.set_payload(Payload::from(Bytes::from(body_bytes.clone())));
353
354 let base_key = if body_bytes.is_empty() {
355 format!(
356 "{}:{}:{}",
357 req.method().as_str(),
358 req.path(),
359 req.query_string()
360 )
361 } else {
362 let body_hash = hex::encode(Sha256::digest(&body_bytes));
363 format!(
364 "{}:{}:{}:{}",
365 req.method().as_str(),
366 req.path(),
367 req.query_string(),
368 body_hash
369 )
370 };
371
372 let hashed_key = hex::encode(Sha256::digest(base_key.as_bytes()));
373 let cache_key = format!("{}{}", cache_prefix, hashed_key);
374
375 let cached_result: Option<String> = if should_cache {
377 if redis_conn.is_none() {
378 let client = redis::Client::open(redis_url.as_str())
379 .expect("Failed to connect to Redis");
380
381 let conn = client
382 .get_multiplexed_async_connection()
383 .await
384 .expect("Failed to get Redis connection");
385
386 redis_conn = Some(conn);
387 }
388
389 let conn = redis_conn.as_mut().unwrap();
390 conn.get(&cache_key).await.unwrap_or(None)
391 } else {
392 None
393 };
394
395 if let Some(cached_data) = cached_result {
396 log::debug!("Cache hit for {}", cache_key);
397
398 match serde_json::from_str::<CachedResponse>(&cached_data) {
400 Ok(cached_response) => {
401 let mut response = actix_web::HttpResponse::build(
402 actix_web::http::StatusCode::from_u16(cached_response.status)
403 .unwrap_or(actix_web::http::StatusCode::OK),
404 );
405
406 for (name, value) in cached_response.headers {
407 response.insert_header((name, value));
408 }
409
410 response.insert_header(("X-Cache", "HIT"));
411
412 let resp = response.body(cached_response.body);
413 return Ok(req
414 .into_response(resp)
415 .map_body(|_, b| EitherBody::right(BoxBody::new(b))));
416 }
417 Err(e) => {
418 log::error!("Failed to deserialize cached response: {}", e);
419 }
420 }
421 }
422
423 log::debug!("Cache miss for {}", cache_key);
424
425 let service_result = service.call(req).await?;
426
427 if should_cache && service_result.status().is_success() {
429 let res = service_result.response();
430
431 let status = res.status().as_u16();
432
433 let headers = res
434 .headers()
435 .iter()
436 .filter(|(name, _)| {
437 !["connection", "transfer-encoding", "content-length"]
438 .contains(&name.as_str().to_lowercase().as_str())
439 })
440 .map(|(name, value)| {
441 (
442 name.to_string(),
443 value.to_str().unwrap_or_default().to_string(),
444 )
445 })
446 .collect::<Vec<_>>();
447
448 if let Ok(body) = res.body().clone().try_into_bytes() {
449 if !body.is_empty() && body.len() <= max_cacheable_size {
450 let cached_response = CachedResponse {
451 status,
452 headers,
453 body: body.to_vec(),
454 };
455
456 if let Ok(serialized) = serde_json::to_string(&cached_response) {
457 if redis_conn.is_none() {
458 let client = redis::Client::open(redis_url.as_str())
459 .expect("Failed to connect to Redis");
460
461 let conn = client
462 .get_multiplexed_async_connection()
463 .await
464 .expect("Failed to get Redis connection");
465
466 redis_conn = Some(conn);
467 }
468
469 let conn = redis_conn.as_mut().unwrap();
470 let _: Result<(), redis::RedisError> =
471 conn.set_ex(cache_key, serialized, expiration).await;
472 }
473 }
474 }
475 }
476
477 Ok(service_result.map_body(|_, b| EitherBody::left(b)))
478 })
479 }
480}
481
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{http::header, test::TestRequest};

    /// Hex-encoded SHA-256 digest, the hashing scheme used for cache keys.
    fn sha256_hex(data: &[u8]) -> String {
        hex::encode(Sha256::digest(data))
    }

    /// Builds a decision context with an empty query string and an empty body.
    fn context<'a>(
        method: &'a str,
        path: &'a str,
        headers: &'a header::HeaderMap,
    ) -> CacheDecisionContext<'a> {
        CacheDecisionContext {
            method,
            path,
            query_string: "",
            headers,
            body: &[],
        }
    }

    #[actix_web::test]
    async fn test_builder_default_values() {
        let defaults = RedisCacheMiddlewareBuilder::new("redis://localhost");

        assert_eq!(defaults.redis_url, "redis://localhost");
        assert_eq!(defaults.ttl, 3600);
        assert_eq!(defaults.max_cacheable_size, 1024 * 1024);
        assert_eq!(defaults.cache_prefix, "cache:");
    }

    #[actix_web::test]
    async fn test_builder_custom_values() {
        let custom = RedisCacheMiddlewareBuilder::new("redis://localhost")
            .cache_prefix("custom:")
            .max_cacheable_size(512 * 1024)
            .ttl(60);

        assert_eq!(custom.ttl, 60);
        assert_eq!(custom.max_cacheable_size, 512 * 1024);
        assert_eq!(custom.cache_prefix, "custom:");
    }

    #[actix_web::test]
    async fn test_builder_custom_predicate() {
        let builder = RedisCacheMiddlewareBuilder::new("redis://localhost")
            .cache_if(|ctx| ctx.method == "GET");
        let headers = header::HeaderMap::new();

        assert!((builder.cache_if)(&context("GET", "/test", &headers)));
        assert!(!(builder.cache_if)(&context("POST", "/test", &headers)));
    }

    #[actix_web::test]
    async fn test_cache_key_generation() {
        let req = TestRequest::get().uri("/test").to_srv_request();

        let base_key = [req.method().as_str(), req.path(), req.query_string()].join(":");
        let cache_key = format!("test:{}", sha256_hex(base_key.as_bytes()));

        assert_eq!(cache_key, format!("test:{}", sha256_hex(b"GET:/test:")));
    }

    #[actix_web::test]
    async fn test_cache_key_with_body() {
        let body_hash = sha256_hex(b"test body");

        let base_key = ["POST", "/test", "", body_hash.as_str()].join(":");
        let cache_key = format!("test:{}", sha256_hex(base_key.as_bytes()));

        let expected_base = format!("POST:/test::{}", body_hash);
        assert_eq!(
            cache_key,
            format!("test:{}", sha256_hex(expected_base.as_bytes()))
        );
    }

    #[actix_web::test]
    async fn test_cacheable_methods() {
        const METHODS: [&str; 7] = ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"];
        let headers = header::HeaderMap::new();

        // The default predicate accepts every method.
        let default_predicate = RedisCacheMiddlewareBuilder::new("redis://localhost").cache_if;
        METHODS.iter().for_each(|&method| {
            assert!(
                (default_predicate)(&context(method, "/test", &headers)),
                "Method {} should be cacheable by default",
                method
            );
        });

        // A custom predicate restricted to safe, body-less methods.
        let custom_builder = RedisCacheMiddlewareBuilder::new("redis://localhost")
            .cache_if(|ctx| matches!(ctx.method, "GET" | "HEAD"));
        for &method in METHODS.iter() {
            let should_cache = matches!(method, "GET" | "HEAD");
            assert_eq!(
                (custom_builder.cache_if)(&context(method, "/test", &headers)),
                should_cache,
                "Method {} should be cacheable: {}",
                method,
                should_cache
            );
        }
    }

    #[actix_web::test]
    async fn test_predicate_with_headers() {
        let predicate = |ctx: &CacheDecisionContext| !ctx.headers.contains_key("Authorization");
        let mut headers = header::HeaderMap::new();

        assert!(
            predicate(&context("GET", "/test", &headers)),
            "Request without Authorization should be cached"
        );

        headers.insert(
            header::AUTHORIZATION,
            header::HeaderValue::from_static("Bearer token"),
        );

        assert!(
            !predicate(&context("GET", "/test", &headers)),
            "Request with Authorization should not be cached"
        );
    }

    #[actix_web::test]
    async fn test_predicate_with_path_patterns() {
        // Reject everything under /admin and any path containing a /private/ segment.
        let predicate = |ctx: &CacheDecisionContext| {
            !(ctx.path.starts_with("/admin") || ctx.path.contains("/private/"))
        };
        let headers = header::HeaderMap::new();

        for path in ["/", "/api/users", "/public/resource", "/api/v1/data"] {
            assert!(
                predicate(&context("GET", path, &headers)),
                "Path {} should be cacheable",
                path
            );
        }

        for path in ["/admin", "/admin/users", "/users/private/profile"] {
            assert!(
                !predicate(&context("GET", path, &headers)),
                "Path {} should not be cacheable",
                path
            );
        }
    }

    #[actix_web::test]
    async fn test_cached_response_serialization() {
        let expected_headers = vec![
            ("Content-Type".to_string(), "text/plain".to_string()),
            ("X-Test".to_string(), "value".to_string()),
        ];
        let original = CachedResponse {
            status: 200,
            headers: expected_headers.clone(),
            body: b"test response".to_vec(),
        };

        let json = serde_json::to_string(&original).unwrap();
        let restored: CachedResponse = serde_json::from_str(&json).unwrap();

        assert_eq!(restored.status, 200);
        assert_eq!(restored.headers, expected_headers);
        assert_eq!(restored.body, b"test response");
    }
}