1#![warn(missing_docs)]
2use actix_web::{
48 body::{BoxBody, EitherBody, MessageBody},
49 dev::{forward_ready, Payload, Service, ServiceRequest, ServiceResponse, Transform},
50 http::header::HeaderMap,
51 web::{Bytes, BytesMut},
52 Error, HttpMessage,
53};
54use futures::{
55 future::{ready, LocalBoxFuture, Ready},
56 StreamExt,
57};
58use redis::{aio::MultiplexedConnection, AsyncCommands};
59use serde::{Deserialize, Serialize};
60use sha2::{Digest, Sha256};
61use std::sync::Arc;
62
63pub struct CacheDecisionContext<'a> {
68 pub method: &'a str,
70 pub path: &'a str,
72 pub query_string: &'a str,
74 pub headers: &'a HeaderMap,
76 pub body: &'a [u8],
78}
79
80type CachePredicate = Arc<dyn Fn(&CacheDecisionContext) -> bool + Send + Sync>;
85
86pub struct RedisCacheMiddleware {
91 redis_conn: MultiplexedConnection,
92 ttl: u64,
93 max_cacheable_size: usize,
94 cache_prefix: String,
95 cache_if: CachePredicate,
96}
97
98pub struct RedisCacheMiddlewareBuilder {
103 redis_url: String,
104 ttl: u64,
105 max_cacheable_size: usize,
106 cache_prefix: String,
107 cache_if: CachePredicate,
108}
109
110impl RedisCacheMiddlewareBuilder {
111 pub fn new(redis_url: impl Into<String>) -> Self {
125 Self {
126 redis_url: redis_url.into(),
127 ttl: 3600, max_cacheable_size: 1024 * 1024, cache_prefix: "cache:".to_string(),
130 cache_if: Arc::new(|_| true), }
132 }
133
134 pub fn ttl(mut self, seconds: u64) -> Self {
144 self.ttl = seconds;
145 self
146 }
147
148 pub fn max_cacheable_size(mut self, bytes: usize) -> Self {
160 self.max_cacheable_size = bytes;
161 self
162 }
163
164 pub fn cache_prefix(mut self, prefix: impl Into<String>) -> Self {
174 self.cache_prefix = prefix.into();
175 self
176 }
177
178 pub fn cache_if<F>(mut self, predicate: F) -> Self
209 where
210 F: Fn(&CacheDecisionContext) -> bool + Send + Sync + 'static,
211 {
212 self.cache_if = Arc::new(predicate);
213 self
214 }
215
216 pub async fn build(self) -> RedisCacheMiddleware {
226 let client =
227 redis::Client::open(self.redis_url.as_str()).expect("Failed to connect to Redis");
228
229 let redis_conn = client
230 .get_multiplexed_async_connection()
231 .await
232 .expect("Failed to get Redis connection");
233
234 RedisCacheMiddleware {
235 redis_conn,
236 ttl: self.ttl,
237 max_cacheable_size: self.max_cacheable_size,
238 cache_prefix: self.cache_prefix,
239 cache_if: self.cache_if,
240 }
241 }
242}
243
244impl RedisCacheMiddleware {
245 pub async fn new(redis_url: &str) -> Self {
257 RedisCacheMiddlewareBuilder::new(redis_url).build().await
258 }
259}
260
261pub struct RedisCacheMiddlewareService<S> {
266 service: S,
267 redis_conn: MultiplexedConnection,
268 ttl: u64,
269 max_cacheable_size: usize,
270 cache_prefix: String,
271 cache_if: CachePredicate,
272}
273
/// Response snapshot stored in Redis, serialized to JSON with `serde_json`.
///
/// Field order determines the layout of the serialized JSON, so keep it stable.
#[derive(Serialize, Deserialize)]
struct CachedResponse {
    /// HTTP status code of the original response.
    status: u16,
    /// Response headers as `(name, value)` pairs, in the order they were read.
    headers: Vec<(String, String)>,
    /// Raw response body bytes (serde_json encodes these as an array of numbers).
    body: Vec<u8>,
}
280
281impl<S, B> Transform<S, ServiceRequest> for RedisCacheMiddleware
282where
283 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static + Clone,
284 S::Future: 'static,
285 B: 'static + Clone + MessageBody,
286{
287 type Response = ServiceResponse<EitherBody<B, BoxBody>>;
288 type Error = Error;
289 type Transform = RedisCacheMiddlewareService<S>;
290 type InitError = ();
291 type Future = Ready<Result<Self::Transform, Self::InitError>>;
292
293 fn new_transform(&self, service: S) -> Self::Future {
294 ready(Ok(RedisCacheMiddlewareService {
295 service,
296 redis_conn: self.redis_conn.clone(),
297 ttl: self.ttl,
298 max_cacheable_size: self.max_cacheable_size,
299 cache_prefix: self.cache_prefix.clone(),
300 cache_if: self.cache_if.clone(),
301 }))
302 }
303}
304
305impl<S, B> Service<ServiceRequest> for RedisCacheMiddlewareService<S>
306where
307 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static + Clone,
308 S::Future: 'static,
309 B: actix_web::body::MessageBody + 'static + Clone,
310{
311 type Response = ServiceResponse<EitherBody<B, BoxBody>>;
312 type Error = Error;
313 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
314
315 forward_ready!(service);
316
317 fn call(&self, mut req: ServiceRequest) -> Self::Future {
318 if let Some(cache_control) = req.headers().get("Cache-Control") {
319 if let Ok(cache_control_str) = cache_control.to_str() {
320 if cache_control_str.contains("no-cache") || cache_control_str.contains("no-store")
321 {
322 let fut = self.service.call(req);
323 return Box::pin(async move {
324 let res = fut.await?;
325 Ok(res.map_body(|_, b| EitherBody::left(b)))
326 });
327 }
328 }
329 }
330
331 let mut redis_conn = self.redis_conn.clone();
332 let expiration = self.ttl;
333 let max_cacheable_size = self.max_cacheable_size;
334 let cache_prefix = self.cache_prefix.clone();
335 let service = self.service.clone();
336 let cache_if = self.cache_if.clone();
337
338 Box::pin(async move {
339 let body_bytes = req
340 .take_payload()
341 .fold(BytesMut::new(), move |mut body, chunk| async {
342 if let Ok(chunk) = chunk {
343 body.extend_from_slice(&chunk);
344 }
345 body
346 })
347 .await;
348
349 let cache_ctx = CacheDecisionContext {
350 method: req.method().as_str(),
351 path: req.path(),
352 query_string: req.query_string(),
353 headers: req.headers(),
354 body: &body_bytes,
355 };
356
357 let should_cache = cache_if(&cache_ctx);
358
359 req.set_payload(Payload::from(Bytes::from(body_bytes.clone())));
360
361 let base_key = if body_bytes.is_empty() {
362 format!(
363 "{}:{}:{}",
364 req.method().as_str(),
365 req.path(),
366 req.query_string()
367 )
368 } else {
369 let body_hash = hex::encode(Sha256::digest(&body_bytes));
370 format!(
371 "{}:{}:{}:{}",
372 req.method().as_str(),
373 req.path(),
374 req.query_string(),
375 body_hash
376 )
377 };
378
379 let hashed_key = hex::encode(Sha256::digest(base_key.as_bytes()));
380 let cache_key = format!("{}{}", cache_prefix, hashed_key);
381
382 let cached_result: Option<String> = if should_cache {
384 redis_conn.get(&cache_key).await.unwrap_or(None)
385 } else {
386 None
387 };
388
389 if let Some(cached_data) = cached_result {
390 log::debug!("Cache hit for {}", cache_key);
391
392 match serde_json::from_str::<CachedResponse>(&cached_data) {
394 Ok(cached_response) => {
395 let mut response = actix_web::HttpResponse::build(
396 actix_web::http::StatusCode::from_u16(cached_response.status)
397 .unwrap_or(actix_web::http::StatusCode::OK),
398 );
399
400 for (name, value) in cached_response.headers {
401 response.insert_header((name, value));
402 }
403
404 response.insert_header(("X-Cache", "HIT"));
405
406 let resp = response.body(cached_response.body);
407 return Ok(req
408 .into_response(resp)
409 .map_body(|_, b| EitherBody::right(BoxBody::new(b))));
410 }
411 Err(e) => {
412 log::error!("Failed to deserialize cached response: {}", e);
413 }
414 }
415 }
416
417 log::debug!("Cache miss for {}", cache_key);
418
419 let service_result = service.call(req).await?;
420
421 if should_cache && service_result.status().is_success() {
423 let res = service_result.response();
424
425 let status = res.status().as_u16();
426
427 let headers = res
428 .headers()
429 .iter()
430 .filter(|(name, _)| {
431 !["connection", "transfer-encoding", "content-length"]
432 .contains(&name.as_str().to_lowercase().as_str())
433 })
434 .map(|(name, value)| {
435 (
436 name.to_string(),
437 value.to_str().unwrap_or_default().to_string(),
438 )
439 })
440 .collect::<Vec<_>>();
441
442 if let Ok(body) = res.body().clone().try_into_bytes() {
443 if !body.is_empty() && body.len() <= max_cacheable_size {
444 let cached_response = CachedResponse {
445 status,
446 headers,
447 body: body.to_vec(),
448 };
449
450 if let Ok(serialized) = serde_json::to_string(&cached_response) {
451 let _: Result<(), redis::RedisError> =
452 redis_conn.set_ex(cache_key, serialized, expiration).await;
453 }
454 }
455 }
456 }
457
458 Ok(service_result.map_body(|_, b| EitherBody::left(b)))
459 })
460 }
461}
462
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{http::header, test::TestRequest};

    /// Builds a decision context with an empty query string and an empty body.
    fn context<'a>(
        method: &'a str,
        path: &'a str,
        headers: &'a header::HeaderMap,
    ) -> CacheDecisionContext<'a> {
        CacheDecisionContext {
            method,
            path,
            query_string: "",
            headers,
            body: &[],
        }
    }

    /// Hex-encoded SHA-256 digest of `input`.
    fn sha256_hex(input: &[u8]) -> String {
        hex::encode(Sha256::digest(input))
    }

    #[actix_web::test]
    async fn test_builder_default_values() {
        let defaults = RedisCacheMiddlewareBuilder::new("redis://localhost");

        assert_eq!(defaults.redis_url, "redis://localhost");
        assert_eq!(defaults.ttl, 3600);
        assert_eq!(defaults.max_cacheable_size, 1024 * 1024);
        assert_eq!(defaults.cache_prefix, "cache:");
    }

    #[actix_web::test]
    async fn test_builder_custom_values() {
        let configured = RedisCacheMiddlewareBuilder::new("redis://localhost")
            .ttl(60)
            .max_cacheable_size(512 * 1024)
            .cache_prefix("custom:");

        assert_eq!(configured.ttl, 60);
        assert_eq!(configured.max_cacheable_size, 512 * 1024);
        assert_eq!(configured.cache_prefix, "custom:");
    }

    #[actix_web::test]
    async fn test_builder_custom_predicate() {
        let builder =
            RedisCacheMiddlewareBuilder::new("redis://localhost").cache_if(|ctx| ctx.method == "GET");
        let no_headers = header::HeaderMap::new();

        assert!((builder.cache_if)(&context("GET", "/test", &no_headers)));
        assert!(!(builder.cache_if)(&context("POST", "/test", &no_headers)));
    }

    #[actix_web::test]
    async fn test_cache_key_generation() {
        let req = TestRequest::get().uri("/test").to_srv_request();

        let base_key = [req.method().as_str(), req.path(), req.query_string()].join(":");
        let cache_key = format!("test:{}", sha256_hex(base_key.as_bytes()));

        assert_eq!(cache_key, format!("test:{}", sha256_hex(b"GET:/test:")));
    }

    #[actix_web::test]
    async fn test_cache_key_with_body() {
        let body_hash = sha256_hex(b"test body");

        let base_key = ["POST", "/test", "", body_hash.as_str()].join(":");
        let cache_key = format!("test:{}", sha256_hex(base_key.as_bytes()));

        let expected_key = format!(
            "test:{}",
            sha256_hex(format!("POST:/test::{}", body_hash).as_bytes())
        );

        assert_eq!(cache_key, expected_key);
    }

    #[actix_web::test]
    async fn test_cacheable_methods() {
        let methods = ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"];
        let no_headers = header::HeaderMap::new();

        let builder = RedisCacheMiddlewareBuilder::new("redis://localhost");
        let default_predicate = builder.cache_if;

        for method in methods {
            assert!(
                default_predicate(&context(method, "/test", &no_headers)),
                "Method {} should be cacheable by default",
                method
            );
        }

        let custom = RedisCacheMiddlewareBuilder::new("redis://localhost")
            .cache_if(|ctx| matches!(ctx.method, "GET" | "HEAD"));

        for method in methods {
            let should_cache = matches!(method, "GET" | "HEAD");
            assert_eq!(
                (custom.cache_if)(&context(method, "/test", &no_headers)),
                should_cache,
                "Method {} should be cacheable: {}",
                method,
                should_cache
            );
        }
    }

    #[actix_web::test]
    async fn test_predicate_with_headers() {
        let predicate = |ctx: &CacheDecisionContext| !ctx.headers.contains_key("Authorization");
        let mut headers = header::HeaderMap::new();

        assert!(
            predicate(&context("GET", "/test", &headers)),
            "Request without Authorization should be cached"
        );

        headers.insert(
            header::AUTHORIZATION,
            header::HeaderValue::from_static("Bearer token"),
        );

        assert!(
            !predicate(&context("GET", "/test", &headers)),
            "Request with Authorization should not be cached"
        );
    }

    #[actix_web::test]
    async fn test_predicate_with_path_patterns() {
        let predicate = |ctx: &CacheDecisionContext| {
            !(ctx.path.starts_with("/admin") || ctx.path.contains("/private/"))
        };
        let no_headers = header::HeaderMap::new();

        for path in ["/", "/api/users", "/public/resource", "/api/v1/data"] {
            assert!(
                predicate(&context("GET", path, &no_headers)),
                "Path {} should be cacheable",
                path
            );
        }

        for path in ["/admin", "/admin/users", "/users/private/profile"] {
            assert!(
                !predicate(&context("GET", path, &no_headers)),
                "Path {} should not be cacheable",
                path
            );
        }
    }

    #[actix_web::test]
    async fn test_cached_response_serialization() {
        let original = CachedResponse {
            status: 200,
            headers: vec![
                ("Content-Type".to_string(), "text/plain".to_string()),
                ("X-Test".to_string(), "value".to_string()),
            ],
            body: b"test response".to_vec(),
        };

        let json = serde_json::to_string(&original).unwrap();
        let decoded: CachedResponse = serde_json::from_str(&json).unwrap();

        assert_eq!(decoded.status, 200);

        let expected_headers = [("Content-Type", "text/plain"), ("X-Test", "value")];
        assert_eq!(decoded.headers.len(), expected_headers.len());
        for ((name, value), (want_name, want_value)) in
            decoded.headers.iter().zip(expected_headers)
        {
            assert_eq!(name, want_name);
            assert_eq!(value, want_value);
        }

        assert_eq!(decoded.body, b"test response");
    }
}