actix_request_reply_cache/
lib.rs

1#![warn(missing_docs)]
2//! # Actix Request-Reply Cache
3//!
4//! A Redis-backed caching middleware for Actix Web that enables response caching.
5//!
6//! This library implements efficient HTTP response caching using Redis as a backend store,
7//! with functionality for fine-grained cache control through predicates that can examine
8//! request context to determine cacheability.
9//!
10//! ## Features
11//!
12//! - Redis-backed HTTP response caching
13//! - Configurable TTL (time-to-live) for cached responses
14//! - Customizable cache key prefix
15//! - Maximum cacheable response size configuration
16//! - Flexible cache control through predicate functions
//! - Bypasses the cache for requests sent with `Cache-Control: no-cache` or `no-store`
18//!
19//! ## Example
20//!
//! ```rust,no_run
22//! use actix_web::{web, App, HttpServer};
23//! use actix_request_reply_cache::RedisCacheMiddlewareBuilder;
24//!
25//! #[actix_web::main]
26//! async fn main() -> std::io::Result<()> {
27//!     // Create the cache middleware
28//!     let cache = RedisCacheMiddlewareBuilder::new("redis://127.0.0.1:6379")
29//!         .ttl(60)  // Cache for 60 seconds
30//!         .cache_if(|ctx| {
31//!             // Only cache GET requests without Authorization header
32//!             ctx.method == "GET" && !ctx.headers.contains_key("Authorization")
33//!         })
//!         .build();
36//!         
37//!     HttpServer::new(move || {
38//!         App::new()
39//!             .wrap(cache.clone())
40//!             .service(web::resource("/").to(|| async { "Hello world!" }))
41//!     })
42//!     .bind(("127.0.0.1", 8080))?
43//!     .run()
44//!     .await
45//! }
46//! ```
47use actix_web::{
48    body::{BoxBody, EitherBody, MessageBody},
49    dev::{forward_ready, Payload, Service, ServiceRequest, ServiceResponse, Transform},
50    http::header::HeaderMap,
51    web::{Bytes, BytesMut},
52    Error, HttpMessage,
53};
54use futures::{
55    future::{ready, LocalBoxFuture, Ready},
56    StreamExt,
57};
58use redis::{aio::MultiplexedConnection, AsyncCommands};
59use serde::{Deserialize, Serialize};
60use sha2::{Digest, Sha256};
61use std::sync::Arc;
62
63/// Context used to determine if a request/response should be cached.
64///
65/// This struct contains information about the current request that can be
66/// examined by cache predicate functions to make caching decisions.
67pub struct CacheDecisionContext<'a> {
68    /// The HTTP method of the request (e.g., "GET", "POST")
69    pub method: &'a str,
70    /// The request path
71    pub path: &'a str,
72    /// The query string from the request URL
73    pub query_string: &'a str,
74    /// HTTP headers from the request
75    pub headers: &'a HeaderMap,
76    /// The request body as a byte slice
77    pub body: &'a [u8],
78}
79
80/// Function type for cache decision predicates.
81///
82/// This type represents functions that take a `CacheDecisionContext` and return
83/// a boolean indicating whether the response should be cached.
84type CachePredicate = Arc<dyn Fn(&CacheDecisionContext) -> bool + Send + Sync>;
85
86/// Redis-backed caching middleware for Actix Web.
87///
88/// This middleware intercepts responses, caches them in Redis, and serves
89/// cached responses for subsequent matching requests when available.
90pub struct RedisCacheMiddleware {
91    redis_conn: Option<MultiplexedConnection>,
92    redis_url: String,
93    ttl: u64,
94    max_cacheable_size: usize,
95    cache_prefix: String,
96    cache_if: CachePredicate,
97}
98
99/// Builder for configuring and creating the `RedisCacheMiddleware`.
100///
101/// Provides a fluent interface for configuring cache parameters such as TTL,
102/// maximum cacheable size, cache key prefix, and cache decision predicates.
103pub struct RedisCacheMiddlewareBuilder {
104    redis_url: String,
105    ttl: u64,
106    max_cacheable_size: usize,
107    cache_prefix: String,
108    cache_if: CachePredicate,
109}
110
111impl RedisCacheMiddlewareBuilder {
112    /// Creates a new builder with the given Redis URL.
113    ///
114    /// # Arguments
115    ///
116    /// * `redis_url` - The Redis connection URL (e.g., "redis://127.0.0.1:6379")
117    ///
118    /// # Returns
119    ///
120    /// A new `RedisCacheMiddlewareBuilder` with default settings:
121    /// - TTL: 3600 seconds (1 hour)
122    /// - Max cacheable size: 1MB
123    /// - Cache prefix: "cache:"
124    /// - Cache predicate: cache all responses
125    pub fn new(redis_url: impl Into<String>) -> Self {
126        Self {
127            redis_url: redis_url.into(),
128            ttl: 3600,                       // 1 hour default
129            max_cacheable_size: 1024 * 1024, // 1MB default
130            cache_prefix: "cache:".to_string(),
131            cache_if: Arc::new(|_| true), // Default: cache everything
132        }
133    }
134
135    /// Sets the TTL (time-to-live) for cached responses in seconds.
136    ///
137    /// # Arguments
138    ///
139    /// * `seconds` - The number of seconds a response should remain in the cache
140    ///
141    /// # Returns
142    ///
143    /// Self for method chaining
144    pub fn ttl(mut self, seconds: u64) -> Self {
145        self.ttl = seconds;
146        self
147    }
148
149    /// Sets the maximum size of responses that can be cached, in bytes.
150    ///
151    /// Responses larger than this size will not be cached.
152    ///
153    /// # Arguments
154    ///
155    /// * `bytes` - The maximum cacheable response size in bytes
156    ///
157    /// # Returns
158    ///
159    /// Self for method chaining
160    pub fn max_cacheable_size(mut self, bytes: usize) -> Self {
161        self.max_cacheable_size = bytes;
162        self
163    }
164
165    /// Sets the prefix used for Redis cache keys.
166    ///
167    /// # Arguments
168    ///
169    /// * `prefix` - The string prefix to use for all cache keys
170    ///
171    /// # Returns
172    ///
173    /// Self for method chaining
174    pub fn cache_prefix(mut self, prefix: impl Into<String>) -> Self {
175        self.cache_prefix = prefix.into();
176        self
177    }
178
179    /// Set a predicate function to determine if a response should be cached
180    ///
181    /// Example:
182    /// ```
183    /// builder.cache_if(|ctx| {
184    ///     // Only cache GET requests
185    ///     if ctx.method != "GET" {
186    ///         return false;
187    ///     }
188    ///     
189    ///     // Don't cache if Authorization header is present
190    ///     if ctx.headers.contains_key("Authorization") {
191    ///         return false;
192    ///     }
193    ///     
194    ///     // Don't cache responses to paths that start with /admin
195    ///     if ctx.path.starts_with("/admin") {
196    ///         return false;
197    ///     }
198    ///
199    ///     // Don't cache for a specific route if its body contains some field
200    ///     if ctx.path.starts_with("/api/users") && ctx.method == "POST" {
201    ///        if let Ok(user_json) = serde_json::from_slice::<serde_json::Value>(ctx.body) {
202    ///            // Check properties in the JSON to make caching decisions
203    ///            return user_json.get("role").and_then(|r| r.as_str()) != Some("admin");
204    ///        }
205    ///    }
206    ///     true
207    /// })
208    /// ```
209    pub fn cache_if<F>(mut self, predicate: F) -> Self
210    where
211        F: Fn(&CacheDecisionContext) -> bool + Send + Sync + 'static,
212    {
213        self.cache_if = Arc::new(predicate);
214        self
215    }
216
217    /// Builds and returns the configured `RedisCacheMiddleware`.
218    ///
219    /// # Returns
220    ///
221    /// A new `RedisCacheMiddleware` instance configured with the settings from this builder.
222    pub fn build(self) -> RedisCacheMiddleware {
223        RedisCacheMiddleware {
224            redis_conn: None,
225            redis_url: self.redis_url,
226            ttl: self.ttl,
227            max_cacheable_size: self.max_cacheable_size,
228            cache_prefix: self.cache_prefix,
229            cache_if: self.cache_if,
230        }
231    }
232}
233
234impl RedisCacheMiddleware {
235    /// Creates a new `RedisCacheMiddleware` with default settings.
236    ///
237    /// This is a convenience method that uses the builder with default settings.
238    ///
239    /// # Arguments
240    ///
241    /// * `redis_url` - The Redis connection URL
242    ///
243    /// # Returns
244    ///
245    /// A new `RedisCacheMiddleware` instance with default settings.
246    pub fn new(redis_url: &str) -> Self {
247        RedisCacheMiddlewareBuilder::new(redis_url).build()
248    }
249}
250
251/// Service implementation for the Redis cache middleware.
252///
253/// This struct is created by the `RedisCacheMiddleware` and handles
254/// the actual interception of requests and responses for caching.
255pub struct RedisCacheMiddlewareService<S> {
256    service: S,
257    redis_conn: Option<MultiplexedConnection>,
258    redis_url: String,
259    ttl: u64,
260    max_cacheable_size: usize,
261    cache_prefix: String,
262    cache_if: CachePredicate,
263}
264
265#[derive(Serialize, Deserialize)]
266struct CachedResponse {
267    status: u16,
268    headers: Vec<(String, String)>,
269    body: Vec<u8>,
270}
271
272impl<S, B> Transform<S, ServiceRequest> for RedisCacheMiddleware
273where
274    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static + Clone,
275    S::Future: 'static,
276    B: 'static + Clone + MessageBody,
277{
278    type Response = ServiceResponse<EitherBody<B, BoxBody>>;
279    type Error = Error;
280    type Transform = RedisCacheMiddlewareService<S>;
281    type InitError = ();
282    type Future = Ready<Result<Self::Transform, Self::InitError>>;
283
284    fn new_transform(&self, service: S) -> Self::Future {
285        ready(Ok(RedisCacheMiddlewareService {
286            service,
287            redis_conn: self.redis_conn.clone(),
288            redis_url: self.redis_url.clone(),
289            ttl: self.ttl,
290            max_cacheable_size: self.max_cacheable_size,
291            cache_prefix: self.cache_prefix.clone(),
292            cache_if: self.cache_if.clone(),
293        }))
294    }
295}
296
297impl<S, B> Service<ServiceRequest> for RedisCacheMiddlewareService<S>
298where
299    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static + Clone,
300    S::Future: 'static,
301    B: actix_web::body::MessageBody + 'static + Clone,
302{
303    type Response = ServiceResponse<EitherBody<B, BoxBody>>;
304    type Error = Error;
305    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
306
307    forward_ready!(service);
308
309    fn call(&self, mut req: ServiceRequest) -> Self::Future {
310        if let Some(cache_control) = req.headers().get("Cache-Control") {
311            if let Ok(cache_control_str) = cache_control.to_str() {
312                if cache_control_str.contains("no-cache") || cache_control_str.contains("no-store")
313                {
314                    let fut = self.service.call(req);
315                    return Box::pin(async move {
316                        let res = fut.await?;
317                        Ok(res.map_body(|_, b| EitherBody::left(b)))
318                    });
319                }
320            }
321        }
322
323        let redis_url = self.redis_url.clone();
324        let mut redis_conn = self.redis_conn.clone();
325        let expiration = self.ttl;
326        let max_cacheable_size = self.max_cacheable_size;
327        let cache_prefix = self.cache_prefix.clone();
328        let service = self.service.clone();
329        let cache_if = self.cache_if.clone();
330
331        Box::pin(async move {
332            let body_bytes = req
333                .take_payload()
334                .fold(BytesMut::new(), move |mut body, chunk| async {
335                    if let Ok(chunk) = chunk {
336                        body.extend_from_slice(&chunk);
337                    }
338                    body
339                })
340                .await;
341
342            let cache_ctx = CacheDecisionContext {
343                method: req.method().as_str(),
344                path: req.path(),
345                query_string: req.query_string(),
346                headers: req.headers(),
347                body: &body_bytes,
348            };
349
350            let should_cache = cache_if(&cache_ctx);
351
352            req.set_payload(Payload::from(Bytes::from(body_bytes.clone())));
353
354            let base_key = if body_bytes.is_empty() {
355                format!(
356                    "{}:{}:{}",
357                    req.method().as_str(),
358                    req.path(),
359                    req.query_string()
360                )
361            } else {
362                let body_hash = hex::encode(Sha256::digest(&body_bytes));
363                format!(
364                    "{}:{}:{}:{}",
365                    req.method().as_str(),
366                    req.path(),
367                    req.query_string(),
368                    body_hash
369                )
370            };
371
372            let hashed_key = hex::encode(Sha256::digest(base_key.as_bytes()));
373            let cache_key = format!("{}{}", cache_prefix, hashed_key);
374
375            // Only try to get from cache if we're considering caching for this request
376            let cached_result: Option<String> = if should_cache {
377                if redis_conn.is_none() {
378                    let client = redis::Client::open(redis_url.as_str())
379                        .expect("Failed to connect to Redis");
380
381                    let conn = client
382                        .get_multiplexed_async_connection()
383                        .await
384                        .expect("Failed to get Redis connection");
385
386                    redis_conn = Some(conn);
387                }
388
389                let conn = redis_conn.as_mut().unwrap();
390                conn.get(&cache_key).await.unwrap_or(None)
391            } else {
392                None
393            };
394
395            if let Some(cached_data) = cached_result {
396                log::debug!("Cache hit for {}", cache_key);
397
398                // Deserialize cached response
399                match serde_json::from_str::<CachedResponse>(&cached_data) {
400                    Ok(cached_response) => {
401                        let mut response = actix_web::HttpResponse::build(
402                            actix_web::http::StatusCode::from_u16(cached_response.status)
403                                .unwrap_or(actix_web::http::StatusCode::OK),
404                        );
405
406                        for (name, value) in cached_response.headers {
407                            response.insert_header((name, value));
408                        }
409
410                        response.insert_header(("X-Cache", "HIT"));
411
412                        let resp = response.body(cached_response.body);
413                        return Ok(req
414                            .into_response(resp)
415                            .map_body(|_, b| EitherBody::right(BoxBody::new(b))));
416                    }
417                    Err(e) => {
418                        log::error!("Failed to deserialize cached response: {}", e);
419                    }
420                }
421            }
422
423            log::debug!("Cache miss for {}", cache_key);
424
425            let service_result = service.call(req).await?;
426
427            // Only store in cache if we're considering caching and the response is successful
428            if should_cache && service_result.status().is_success() {
429                let res = service_result.response();
430
431                let status = res.status().as_u16();
432
433                let headers = res
434                    .headers()
435                    .iter()
436                    .filter(|(name, _)| {
437                        !["connection", "transfer-encoding", "content-length"]
438                            .contains(&name.as_str().to_lowercase().as_str())
439                    })
440                    .map(|(name, value)| {
441                        (
442                            name.to_string(),
443                            value.to_str().unwrap_or_default().to_string(),
444                        )
445                    })
446                    .collect::<Vec<_>>();
447
448                if let Ok(body) = res.body().clone().try_into_bytes() {
449                    if !body.is_empty() && body.len() <= max_cacheable_size {
450                        let cached_response = CachedResponse {
451                            status,
452                            headers,
453                            body: body.to_vec(),
454                        };
455
456                        if let Ok(serialized) = serde_json::to_string(&cached_response) {
457                            if redis_conn.is_none() {
458                                let client = redis::Client::open(redis_url.as_str())
459                                    .expect("Failed to connect to Redis");
460
461                                let conn = client
462                                    .get_multiplexed_async_connection()
463                                    .await
464                                    .expect("Failed to get Redis connection");
465
466                                redis_conn = Some(conn);
467                            }
468
469                            let conn = redis_conn.as_mut().unwrap();
470                            let _: Result<(), redis::RedisError> =
471                                conn.set_ex(cache_key, serialized, expiration).await;
472                        }
473                    }
474                }
475            }
476
477            Ok(service_result.map_body(|_, b| EitherBody::left(b)))
478        })
479    }
480}
481
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{http::header, test::TestRequest};

    /// Builds a decision context with an empty query string and body.
    fn ctx<'a>(
        method: &'a str,
        path: &'a str,
        headers: &'a header::HeaderMap,
    ) -> CacheDecisionContext<'a> {
        CacheDecisionContext {
            method,
            path,
            query_string: "",
            headers,
            body: &[],
        }
    }

    #[test]
    fn test_builder_default_values() {
        let b = RedisCacheMiddlewareBuilder::new("redis://localhost");
        assert_eq!(
            (
                b.ttl,
                b.max_cacheable_size,
                b.cache_prefix.as_str(),
                b.redis_url.as_str()
            ),
            (3600, 1024 * 1024, "cache:", "redis://localhost")
        );
    }

    #[test]
    fn test_builder_custom_values() {
        let b = RedisCacheMiddlewareBuilder::new("redis://localhost")
            .ttl(60)
            .max_cacheable_size(512 * 1024)
            .cache_prefix("custom:");

        assert_eq!(
            (b.ttl, b.max_cacheable_size, b.cache_prefix.as_str()),
            (60, 512 * 1024, "custom:")
        );
    }

    #[test]
    fn test_builder_custom_predicate() {
        let b = RedisCacheMiddlewareBuilder::new("redis://localhost")
            .cache_if(|c| c.method == "GET");
        let headers = header::HeaderMap::new();
        let allows = |method: &str| (b.cache_if)(&ctx(method, "/test", &headers));

        // Only GET passes the custom predicate.
        assert!(allows("GET"));
        assert!(!allows("POST"));
    }

    #[actix_web::test]
    async fn test_cache_key_generation() {
        let req = TestRequest::get().uri("/test").to_srv_request();

        // Mirror the middleware's key derivation for a body-less request.
        let base = [req.method().as_str(), req.path(), req.query_string()].join(":");
        let key = format!("test:{}", hex::encode(Sha256::digest(base.as_bytes())));

        let expected = format!("test:{}", hex::encode(Sha256::digest(b"GET:/test:")));
        assert_eq!(key, expected);
    }

    #[test]
    fn test_cache_key_with_body() {
        let body_hash = hex::encode(Sha256::digest(b"test body"));

        // With a body, its digest becomes a fourth key component.
        let base = ["POST", "/test", "", body_hash.as_str()].join(":");
        let key = format!("test:{}", hex::encode(Sha256::digest(base.as_bytes())));

        let expected = format!(
            "test:{}",
            hex::encode(Sha256::digest(format!("POST:/test::{}", body_hash).as_bytes()))
        );
        assert_eq!(key, expected);
    }

    #[test]
    fn test_cacheable_methods() {
        const METHODS: [&str; 7] = ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"];
        let headers = header::HeaderMap::new();

        let default_predicate = RedisCacheMiddlewareBuilder::new("redis://localhost").cache_if;
        for method in METHODS {
            assert!(
                default_predicate(&ctx(method, "/test", &headers)),
                "Method {} should be cacheable by default",
                method
            );
        }

        let custom = RedisCacheMiddlewareBuilder::new("redis://localhost")
            .cache_if(|c| matches!(c.method, "GET" | "HEAD"));
        for method in METHODS {
            let should_cache = matches!(method, "GET" | "HEAD");
            assert_eq!(
                (custom.cache_if)(&ctx(method, "/test", &headers)),
                should_cache,
                "Method {} should be cacheable: {}",
                method,
                should_cache
            );
        }
    }

    #[test]
    fn test_predicate_with_headers() {
        let predicate = |c: &CacheDecisionContext| !c.headers.contains_key("Authorization");
        let mut headers = header::HeaderMap::new();

        assert!(
            predicate(&ctx("GET", "/test", &headers)),
            "Request without Authorization should be cached"
        );

        headers.insert(
            header::AUTHORIZATION,
            header::HeaderValue::from_static("Bearer token"),
        );

        assert!(
            !predicate(&ctx("GET", "/test", &headers)),
            "Request with Authorization should not be cached"
        );
    }

    #[test]
    fn test_predicate_with_path_patterns() {
        let predicate = |c: &CacheDecisionContext| {
            !c.path.starts_with("/admin") && !c.path.contains("/private/")
        };
        let headers = header::HeaderMap::new();
        let cases = [
            ("/", true),
            ("/api/users", true),
            ("/public/resource", true),
            ("/api/v1/data", true),
            ("/admin", false),
            ("/admin/users", false),
            ("/users/private/profile", false),
        ];

        for (path, cacheable) in cases {
            let verdict = predicate(&ctx("GET", path, &headers));
            if cacheable {
                assert!(verdict, "Path {} should be cacheable", path);
            } else {
                assert!(!verdict, "Path {} should not be cacheable", path);
            }
        }
    }

    #[test]
    fn test_cached_response_serialization() {
        let original = CachedResponse {
            status: 200,
            headers: vec![
                ("Content-Type".to_string(), "text/plain".to_string()),
                ("X-Test".to_string(), "value".to_string()),
            ],
            body: b"test response".to_vec(),
        };

        // Round-trip through JSON, the format stored in Redis.
        let json = serde_json::to_string(&original).unwrap();
        let restored = serde_json::from_str::<CachedResponse>(&json).unwrap();

        assert_eq!(restored.status, 200);
        assert_eq!(
            restored.headers,
            [
                ("Content-Type".to_string(), "text/plain".to_string()),
                ("X-Test".to_string(), "value".to_string()),
            ]
        );
        assert_eq!(restored.body, b"test response");
    }
}