actix_request_reply_cache/
lib.rs

1#![warn(missing_docs)]
2//! # Actix Request-Reply Cache
3//!
4//! A Redis-backed caching middleware for Actix Web that enables response caching.
5//!
6//! This library implements efficient HTTP response caching using Redis as a backend store,
7//! with functionality for fine-grained cache control through predicates that can examine
8//! request context to determine cacheability.
9//!
10//! ## Features
11//!
12//! - Redis-backed HTTP response caching
13//! - Configurable TTL (time-to-live) for cached responses
14//! - Customizable cache key prefix
15//! - Maximum cacheable response size configuration
16//! - Flexible cache control through predicate functions
//! - Bypasses the cache when the request's `Cache-Control` header contains `no-cache` or `no-store`
18//!
19//! ## Example
20//!
//! ```rust,no_run
//! use actix_web::{web, App, HttpServer};
//! use actix_request_reply_cache::RedisCacheMiddlewareBuilder;
//!
//! #[actix_web::main]
//! async fn main() -> std::io::Result<()> {
//!     // Create the cache middleware (`build` is synchronous; Redis is contacted lazily)
//!     let cache = RedisCacheMiddlewareBuilder::new("redis://127.0.0.1:6379")
//!         .ttl(60)  // Cache for 60 seconds
//!         .cache_if(|ctx| {
//!             // Only cache GET requests without Authorization header
//!             ctx.method == "GET" && !ctx.headers.contains_key("Authorization")
//!         })
//!         .build();
36//!         
37//!     HttpServer::new(move || {
38//!         App::new()
39//!             .wrap(cache.clone())
40//!             .service(web::resource("/").to(|| async { "Hello world!" }))
41//!     })
42//!     .bind(("127.0.0.1", 8080))?
43//!     .run()
44//!     .await
45//! }
46//! ```
47use actix_web::{
48    body::{BoxBody, EitherBody, MessageBody},
49    dev::{forward_ready, Payload, Service, ServiceRequest, ServiceResponse, Transform},
50    http::header::HeaderMap,
51    web::{Bytes, BytesMut},
52    Error, HttpMessage,
53};
54use futures::{
55    future::{ready, LocalBoxFuture, Ready},
56    StreamExt,
57};
58use redis::{aio::MultiplexedConnection, AsyncCommands};
59use serde::{Deserialize, Serialize};
60use sha2::{Digest, Sha256};
61use std::rc::Rc;
62use std::sync::Arc;
63
64/// Context used to determine if a request/response should be cached.
65///
66/// This struct contains information about the current request that can be
67/// examined by cache predicate functions to make caching decisions.
68pub struct CacheDecisionContext<'a> {
69    /// The HTTP method of the request (e.g., "GET", "POST")
70    pub method: &'a str,
71    /// The request path
72    pub path: &'a str,
73    /// The query string from the request URL
74    pub query_string: &'a str,
75    /// HTTP headers from the request
76    pub headers: &'a HeaderMap,
77    /// The request body as a byte slice
78    pub body: &'a [u8],
79}
80
81/// Function type for cache decision predicates.
82///
83/// This type represents functions that take a `CacheDecisionContext` and return
84/// a boolean indicating whether the response should be cached.
85type CachePredicate = Arc<dyn Fn(&CacheDecisionContext) -> bool + Send + Sync>;
86
87/// Redis-backed caching middleware for Actix Web.
88///
89/// This middleware intercepts responses, caches them in Redis, and serves
90/// cached responses for subsequent matching requests when available.
91pub struct RedisCacheMiddleware {
92    redis_conn: Option<MultiplexedConnection>,
93    redis_url: String,
94    ttl: u64,
95    max_cacheable_size: usize,
96    cache_prefix: String,
97    cache_if: CachePredicate,
98}
99
100/// Builder for configuring and creating the `RedisCacheMiddleware`.
101///
102/// Provides a fluent interface for configuring cache parameters such as TTL,
103/// maximum cacheable size, cache key prefix, and cache decision predicates.
104pub struct RedisCacheMiddlewareBuilder {
105    redis_url: String,
106    ttl: u64,
107    max_cacheable_size: usize,
108    cache_prefix: String,
109    cache_if: CachePredicate,
110}
111
112impl RedisCacheMiddlewareBuilder {
113    /// Creates a new builder with the given Redis URL.
114    ///
115    /// # Arguments
116    ///
117    /// * `redis_url` - The Redis connection URL (e.g., "redis://127.0.0.1:6379")
118    ///
119    /// # Returns
120    ///
121    /// A new `RedisCacheMiddlewareBuilder` with default settings:
122    /// - TTL: 3600 seconds (1 hour)
123    /// - Max cacheable size: 1MB
124    /// - Cache prefix: "cache:"
125    /// - Cache predicate: cache all responses
126    pub fn new(redis_url: impl Into<String>) -> Self {
127        Self {
128            redis_url: redis_url.into(),
129            ttl: 3600,                       // 1 hour default
130            max_cacheable_size: 1024 * 1024, // 1MB default
131            cache_prefix: "cache:".to_string(),
132            cache_if: Arc::new(|_| true), // Default: cache everything
133        }
134    }
135
136    /// Sets the TTL (time-to-live) for cached responses in seconds.
137    ///
138    /// # Arguments
139    ///
140    /// * `seconds` - The number of seconds a response should remain in the cache
141    ///
142    /// # Returns
143    ///
144    /// Self for method chaining
145    pub fn ttl(mut self, seconds: u64) -> Self {
146        self.ttl = seconds;
147        self
148    }
149
150    /// Sets the maximum size of responses that can be cached, in bytes.
151    ///
152    /// Responses larger than this size will not be cached.
153    ///
154    /// # Arguments
155    ///
156    /// * `bytes` - The maximum cacheable response size in bytes
157    ///
158    /// # Returns
159    ///
160    /// Self for method chaining
161    pub fn max_cacheable_size(mut self, bytes: usize) -> Self {
162        self.max_cacheable_size = bytes;
163        self
164    }
165
166    /// Sets the prefix used for Redis cache keys.
167    ///
168    /// # Arguments
169    ///
170    /// * `prefix` - The string prefix to use for all cache keys
171    ///
172    /// # Returns
173    ///
174    /// Self for method chaining
175    pub fn cache_prefix(mut self, prefix: impl Into<String>) -> Self {
176        self.cache_prefix = prefix.into();
177        self
178    }
179
180    /// Set a predicate function to determine if a response should be cached
181    ///
182    /// Example:
183    /// ```
184    /// builder.cache_if(|ctx| {
185    ///     // Only cache GET requests
186    ///     if ctx.method != "GET" {
187    ///         return false;
188    ///     }
189    ///     
190    ///     // Don't cache if Authorization header is present
191    ///     if ctx.headers.contains_key("Authorization") {
192    ///         return false;
193    ///     }
194    ///     
195    ///     // Don't cache responses to paths that start with /admin
196    ///     if ctx.path.starts_with("/admin") {
197    ///         return false;
198    ///     }
199    ///
200    ///     // Don't cache for a specific route if its body contains some field
201    ///     if ctx.path.starts_with("/api/users") && ctx.method == "POST" {
202    ///        if let Ok(user_json) = serde_json::from_slice::<serde_json::Value>(ctx.body) {
203    ///            // Check properties in the JSON to make caching decisions
204    ///            return user_json.get("role").and_then(|r| r.as_str()) != Some("admin");
205    ///        }
206    ///    }
207    ///     true
208    /// })
209    /// ```
210    pub fn cache_if<F>(mut self, predicate: F) -> Self
211    where
212        F: Fn(&CacheDecisionContext) -> bool + Send + Sync + 'static,
213    {
214        self.cache_if = Arc::new(predicate);
215        self
216    }
217
218    /// Builds and returns the configured `RedisCacheMiddleware`.
219    ///
220    /// # Returns
221    ///
222    /// A new `RedisCacheMiddleware` instance configured with the settings from this builder.
223    pub fn build(self) -> RedisCacheMiddleware {
224        RedisCacheMiddleware {
225            redis_conn: None,
226            redis_url: self.redis_url,
227            ttl: self.ttl,
228            max_cacheable_size: self.max_cacheable_size,
229            cache_prefix: self.cache_prefix,
230            cache_if: self.cache_if,
231        }
232    }
233}
234
235impl RedisCacheMiddleware {
236    /// Creates a new `RedisCacheMiddleware` with default settings.
237    ///
238    /// This is a convenience method that uses the builder with default settings.
239    ///
240    /// # Arguments
241    ///
242    /// * `redis_url` - The Redis connection URL
243    ///
244    /// # Returns
245    ///
246    /// A new `RedisCacheMiddleware` instance with default settings.
247    pub fn new(redis_url: &str) -> Self {
248        RedisCacheMiddlewareBuilder::new(redis_url).build()
249    }
250}
251
252/// Service implementation for the Redis cache middleware.
253///
254/// This struct is created by the `RedisCacheMiddleware` and handles
255/// the actual interception of requests and responses for caching.
256pub struct RedisCacheMiddlewareService<S> {
257    service: Rc<S>,
258    redis_conn: Option<MultiplexedConnection>,
259    redis_url: String,
260    ttl: u64,
261    max_cacheable_size: usize,
262    cache_prefix: String,
263    cache_if: CachePredicate,
264}
265
266#[derive(Serialize, Deserialize)]
267struct CachedResponse {
268    status: u16,
269    headers: Vec<(String, String)>,
270    body: Vec<u8>,
271}
272
273impl<S, B> Transform<S, ServiceRequest> for RedisCacheMiddleware
274where
275    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
276    S::Future: 'static,
277    B: 'static + Clone + MessageBody,
278{
279    type Response = ServiceResponse<EitherBody<B, BoxBody>>;
280    type Error = Error;
281    type Transform = RedisCacheMiddlewareService<S>;
282    type InitError = ();
283    type Future = Ready<Result<Self::Transform, Self::InitError>>;
284
285    /// Creates a new transform of the input service.
286    fn new_transform(&self, service: S) -> Self::Future {
287        ready(Ok(RedisCacheMiddlewareService {
288            service: Rc::new(service),
289            redis_conn: self.redis_conn.clone(),
290            redis_url: self.redis_url.clone(),
291            ttl: self.ttl,
292            max_cacheable_size: self.max_cacheable_size,
293            cache_prefix: self.cache_prefix.clone(),
294            cache_if: self.cache_if.clone(),
295        }))
296    }
297}
298
299impl<S, B> Service<ServiceRequest> for RedisCacheMiddlewareService<S>
300where
301    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
302    S::Future: 'static,
303    B: actix_web::body::MessageBody + 'static + Clone,
304{
305    type Response = ServiceResponse<EitherBody<B, BoxBody>>;
306    type Error = Error;
307    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
308
309    forward_ready!(service);
310
311    fn call(&self, mut req: ServiceRequest) -> Self::Future {
312        if let Some(cache_control) = req.headers().get("Cache-Control") {
313            if let Ok(cache_control_str) = cache_control.to_str() {
314                if cache_control_str.contains("no-cache") || cache_control_str.contains("no-store")
315                {
316                    let fut = self.service.call(req);
317                    return Box::pin(async move {
318                        let res = fut.await?;
319                        Ok(res.map_body(|_, b| EitherBody::left(b)))
320                    });
321                }
322            }
323        }
324
325        let redis_url = self.redis_url.clone();
326        let mut redis_conn = self.redis_conn.clone();
327        let expiration = self.ttl;
328        let max_cacheable_size = self.max_cacheable_size;
329        let cache_prefix = self.cache_prefix.clone();
330        let service = Rc::clone(&self.service);
331        let cache_if = self.cache_if.clone();
332
333        Box::pin(async move {
334            let body_bytes = req
335                .take_payload()
336                .fold(BytesMut::new(), move |mut body, chunk| async {
337                    if let Ok(chunk) = chunk {
338                        body.extend_from_slice(&chunk);
339                    }
340                    body
341                })
342                .await;
343
344            let cache_ctx = CacheDecisionContext {
345                method: req.method().as_str(),
346                path: req.path(),
347                query_string: req.query_string(),
348                headers: req.headers(),
349                body: &body_bytes,
350            };
351
352            let should_cache = cache_if(&cache_ctx);
353
354            req.set_payload(Payload::from(Bytes::from(body_bytes.clone())));
355
356            let base_key = if body_bytes.is_empty() {
357                format!(
358                    "{}:{}:{}",
359                    req.method().as_str(),
360                    req.path(),
361                    req.query_string()
362                )
363            } else {
364                let body_hash = hex::encode(Sha256::digest(&body_bytes));
365                format!(
366                    "{}:{}:{}:{}",
367                    req.method().as_str(),
368                    req.path(),
369                    req.query_string(),
370                    body_hash
371                )
372            };
373
374            let hashed_key = hex::encode(Sha256::digest(base_key.as_bytes()));
375            let cache_key = format!("{}{}", cache_prefix, hashed_key);
376
377            // Only try to get from cache if we're considering caching for this request
378            let cached_result: Option<String> = if should_cache {
379                if redis_conn.is_none() {
380                    let client = redis::Client::open(redis_url.as_str())
381                        .expect("Failed to connect to Redis");
382
383                    let conn = client
384                        .get_multiplexed_async_connection()
385                        .await
386                        .expect("Failed to get Redis connection");
387
388                    redis_conn = Some(conn);
389                }
390
391                let conn = redis_conn.as_mut().unwrap();
392                conn.get(&cache_key).await.unwrap_or(None)
393            } else {
394                None
395            };
396
397            if let Some(cached_data) = cached_result {
398                log::debug!("Cache hit for {}", cache_key);
399
400                // Deserialize cached response
401                match serde_json::from_str::<CachedResponse>(&cached_data) {
402                    Ok(cached_response) => {
403                        let mut response = actix_web::HttpResponse::build(
404                            actix_web::http::StatusCode::from_u16(cached_response.status)
405                                .unwrap_or(actix_web::http::StatusCode::OK),
406                        );
407
408                        for (name, value) in cached_response.headers {
409                            response.insert_header((name, value));
410                        }
411
412                        response.insert_header(("X-Cache", "HIT"));
413
414                        let resp = response.body(cached_response.body);
415                        return Ok(req
416                            .into_response(resp)
417                            .map_body(|_, b| EitherBody::right(BoxBody::new(b))));
418                    }
419                    Err(e) => {
420                        log::error!("Failed to deserialize cached response: {}", e);
421                    }
422                }
423            }
424
425            log::debug!("Cache miss for {}", cache_key);
426
427            let service_result = service.call(req).await?;
428
429            // Only store in cache if we're considering caching and the response is successful
430            if should_cache && service_result.status().is_success() {
431                let res = service_result.response();
432
433                let status = res.status().as_u16();
434
435                let headers = res
436                    .headers()
437                    .iter()
438                    .filter(|(name, _)| {
439                        !["connection", "transfer-encoding", "content-length"]
440                            .contains(&name.as_str().to_lowercase().as_str())
441                    })
442                    .map(|(name, value)| {
443                        (
444                            name.to_string(),
445                            value.to_str().unwrap_or_default().to_string(),
446                        )
447                    })
448                    .collect::<Vec<_>>();
449
450                if let Ok(body) = res.body().clone().try_into_bytes() {
451                    if !body.is_empty() && body.len() <= max_cacheable_size {
452                        let cached_response = CachedResponse {
453                            status,
454                            headers,
455                            body: body.to_vec(),
456                        };
457
458                        if let Ok(serialized) = serde_json::to_string(&cached_response) {
459                            if redis_conn.is_none() {
460                                let client = redis::Client::open(redis_url.as_str())
461                                    .expect("Failed to connect to Redis");
462
463                                let conn = client
464                                    .get_multiplexed_async_connection()
465                                    .await
466                                    .expect("Failed to get Redis connection");
467
468                                redis_conn = Some(conn);
469                            }
470
471                            let conn = redis_conn.as_mut().unwrap();
472                            let _: Result<(), redis::RedisError> =
473                                conn.set_ex(cache_key, serialized, expiration).await;
474                        }
475                    }
476                }
477            }
478
479            Ok(service_result.map_body(|_, b| EitherBody::left(b)))
480        })
481    }
482}
483
484#[cfg(test)]
485mod tests {
486    use super::*;
487    use actix_web::{http::header, test::TestRequest};
488
489    #[actix_web::test]
490    async fn test_builder_default_values() {
491        let builder = RedisCacheMiddlewareBuilder::new("redis://localhost");
492        assert_eq!(builder.ttl, 3600);
493        assert_eq!(builder.max_cacheable_size, 1024 * 1024);
494        assert_eq!(builder.cache_prefix, "cache:");
495        assert_eq!(builder.redis_url, "redis://localhost");
496    }
497
498    #[actix_web::test]
499    async fn test_builder_custom_values() {
500        let builder = RedisCacheMiddlewareBuilder::new("redis://localhost")
501            .ttl(60)
502            .max_cacheable_size(512 * 1024)
503            .cache_prefix("custom:");
504
505        assert_eq!(builder.ttl, 60);
506        assert_eq!(builder.max_cacheable_size, 512 * 1024);
507        assert_eq!(builder.cache_prefix, "custom:");
508    }
509
510    #[actix_web::test]
511    async fn test_builder_custom_predicate() {
512        let builder = RedisCacheMiddlewareBuilder::new("redis://localhost")
513            .cache_if(|ctx| ctx.method == "GET");
514
515        // Test the predicate
516        let get_ctx = CacheDecisionContext {
517            method: "GET",
518            path: "/test",
519            query_string: "",
520            headers: &header::HeaderMap::new(),
521            body: &[],
522        };
523
524        let post_ctx = CacheDecisionContext {
525            method: "POST",
526            path: "/test",
527            query_string: "",
528            headers: &header::HeaderMap::new(),
529            body: &[],
530        };
531
532        // The predicate should now only allow GET requests
533        assert!((builder.cache_if)(&get_ctx));
534        assert!(!(builder.cache_if)(&post_ctx));
535    }
536
537    #[actix_web::test]
538    async fn test_cache_key_generation() {
539        // Create a simple request
540        let req = TestRequest::get().uri("/test").to_srv_request();
541
542        // Extract the relevant parts for key generation
543        let method = req.method().as_str();
544        let path = req.path();
545        let query_string = req.query_string();
546
547        // Generate key manually as done in the middleware
548        let base_key = format!("{}:{}:{}", method, path, query_string);
549        let hashed_key = hex::encode(Sha256::digest(base_key.as_bytes()));
550        let cache_key = format!("test:{}", hashed_key);
551
552        // Now verify this matches what our middleware would generate
553        let expected_key = format!(
554            "test:{}",
555            hex::encode(Sha256::digest("GET:/test:".to_string().as_bytes()))
556        );
557
558        assert_eq!(cache_key, expected_key);
559    }
560
561    #[actix_web::test]
562    async fn test_cache_key_with_body() {
563        // Test case for when request has a body
564        let body_bytes = b"test body";
565        let body_hash = hex::encode(Sha256::digest(body_bytes));
566
567        // Generate key manually as done in the middleware
568        let base_key = format!("{}:{}:{}:{}", "POST", "/test", "", body_hash);
569        let hashed_key = hex::encode(Sha256::digest(base_key.as_bytes()));
570        let cache_key = format!("test:{}", hashed_key);
571
572        // Expected key when body is present
573        let expected_key = format!(
574            "test:{}",
575            hex::encode(Sha256::digest(
576                format!("POST:/test::{}", body_hash).as_bytes()
577            ))
578        );
579
580        assert_eq!(cache_key, expected_key);
581    }
582
583    #[actix_web::test]
584    async fn test_cacheable_methods() {
585        // Test different HTTP methods with default predicate
586        let builder = RedisCacheMiddlewareBuilder::new("redis://localhost");
587        let default_predicate = builder.cache_if;
588
589        let methods = ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"];
590
591        for method in methods {
592            let ctx = CacheDecisionContext {
593                method,
594                path: "/test",
595                query_string: "",
596                headers: &header::HeaderMap::new(),
597                body: &[],
598            };
599
600            // Default predicate should cache all methods
601            assert!(
602                (default_predicate)(&ctx),
603                "Method {} should be cacheable by default",
604                method
605            );
606        }
607
608        // Test with a custom predicate that only caches GET and HEAD
609        let custom_builder = RedisCacheMiddlewareBuilder::new("redis://localhost")
610            .cache_if(|ctx| matches!(ctx.method, "GET" | "HEAD"));
611
612        for method in methods {
613            let ctx = CacheDecisionContext {
614                method,
615                path: "/test",
616                query_string: "",
617                headers: &header::HeaderMap::new(),
618                body: &[],
619            };
620
621            // Check if method should be cached according to our custom predicate
622            let should_cache = matches!(method, "GET" | "HEAD");
623            assert_eq!(
624                (custom_builder.cache_if)(&ctx),
625                should_cache,
626                "Method {} should be cacheable: {}",
627                method,
628                should_cache
629            );
630        }
631    }
632
633    #[actix_web::test]
634    async fn test_predicate_with_headers() {
635        // Test predicate behavior with different headers
636
637        // Create a predicate that doesn't cache requests with Authorization header
638        let predicate = |ctx: &CacheDecisionContext| !ctx.headers.contains_key("Authorization");
639
640        // Test with empty headers
641        let mut headers = header::HeaderMap::new();
642        let ctx_no_auth = CacheDecisionContext {
643            method: "GET",
644            path: "/test",
645            query_string: "",
646            headers: &headers,
647            body: &[],
648        };
649
650        assert!(
651            predicate(&ctx_no_auth),
652            "Request without Authorization should be cached"
653        );
654
655        // Test with Authorization header
656        headers.insert(
657            header::AUTHORIZATION,
658            header::HeaderValue::from_static("Bearer token"),
659        );
660
661        let ctx_with_auth = CacheDecisionContext {
662            method: "GET",
663            path: "/test",
664            query_string: "",
665            headers: &headers,
666            body: &[],
667        };
668
669        assert!(
670            !predicate(&ctx_with_auth),
671            "Request with Authorization should not be cached"
672        );
673    }
674
675    #[actix_web::test]
676    async fn test_predicate_with_path_patterns() {
677        // Test predicate behavior with different path patterns
678
679        // Create a predicate that doesn't cache admin paths
680        let predicate = |ctx: &CacheDecisionContext| {
681            !ctx.path.starts_with("/admin") && !ctx.path.contains("/private/")
682        };
683
684        // Test paths that should be cached
685        let cacheable_paths = ["/", "/api/users", "/public/resource", "/api/v1/data"];
686
687        for path in cacheable_paths {
688            let ctx = CacheDecisionContext {
689                method: "GET",
690                path,
691                query_string: "",
692                headers: &header::HeaderMap::new(),
693                body: &[],
694            };
695
696            assert!(predicate(&ctx), "Path {} should be cacheable", path);
697        }
698
699        // Test paths that should not be cached
700        let non_cacheable_paths = ["/admin", "/admin/users", "/users/private/profile"];
701
702        for path in non_cacheable_paths {
703            let ctx = CacheDecisionContext {
704                method: "GET",
705                path,
706                query_string: "",
707                headers: &header::HeaderMap::new(),
708                body: &[],
709            };
710
711            assert!(!predicate(&ctx), "Path {} should not be cacheable", path);
712        }
713    }
714
715    #[actix_web::test]
716    async fn test_cached_response_serialization() {
717        // Test that CachedResponse can be properly serialized and deserialized
718        let cached_response = CachedResponse {
719            status: 200,
720            headers: vec![
721                ("Content-Type".to_string(), "text/plain".to_string()),
722                ("X-Test".to_string(), "value".to_string()),
723            ],
724            body: b"test response".to_vec(),
725        };
726
727        // Serialize
728        let serialized = serde_json::to_string(&cached_response).unwrap();
729
730        // Deserialize
731        let deserialized: CachedResponse = serde_json::from_str(&serialized).unwrap();
732
733        // Verify fields match
734        assert_eq!(deserialized.status, 200);
735        assert_eq!(deserialized.headers.len(), 2);
736        assert_eq!(deserialized.headers[0].0, "Content-Type");
737        assert_eq!(deserialized.headers[0].1, "text/plain");
738        assert_eq!(deserialized.headers[1].0, "X-Test");
739        assert_eq!(deserialized.headers[1].1, "value");
740        assert_eq!(deserialized.body, b"test response");
741    }
742}