actix_request_reply_cache/
lib.rs

1#![warn(missing_docs)]
2//! # Actix Request-Reply Cache
3//!
4//! A Redis-backed caching middleware for Actix Web that enables response caching.
5//!
6//! This library implements efficient HTTP response caching using Redis as a backend store,
7//! with functionality for fine-grained cache control through predicates that can examine
8//! request context to determine cacheability.
9//!
10//! ## Features
11//!
12//! - Redis-backed HTTP response caching
13//! - Configurable TTL (time-to-live) for cached responses
14//! - Customizable cache key prefix
15//! - Maximum cacheable response size configuration
16//! - Flexible cache control through predicate functions
//! - Bypasses the cache for requests sent with `Cache-Control: no-cache` or `no-store`
18//!
19//! ## Example
20//!
//! ```rust,no_run
22//! use actix_web::{web, App, HttpServer};
23//! use actix_request_reply_cache::RedisCacheMiddlewareBuilder;
24//!
25//! #[actix_web::main]
26//! async fn main() -> std::io::Result<()> {
27//!     // Create the cache middleware
28//!     let cache = RedisCacheMiddlewareBuilder::new("redis://127.0.0.1:6379")
29//!         .ttl(60)  // Cache for 60 seconds
30//!         .cache_if(|ctx| {
31//!             // Only cache GET requests without Authorization header
32//!             ctx.method == "GET" && !ctx.headers.contains_key("Authorization")
33//!         })
34//!         .build()
35//!         .await;
36//!         
37//!     HttpServer::new(move || {
38//!         App::new()
39//!             .wrap(cache.clone())
40//!             .service(web::resource("/").to(|| async { "Hello world!" }))
41//!     })
42//!     .bind(("127.0.0.1", 8080))?
43//!     .run()
44//!     .await
45//! }
46//! ```
47use actix_web::{
48    body::{BoxBody, EitherBody, MessageBody},
49    dev::{forward_ready, Payload, Service, ServiceRequest, ServiceResponse, Transform},
50    http::header::HeaderMap,
51    web::{Bytes, BytesMut},
52    Error, HttpMessage,
53};
54use futures::{
55    future::{ready, LocalBoxFuture, Ready},
56    StreamExt,
57};
58use redis::{aio::MultiplexedConnection, AsyncCommands};
59use serde::{Deserialize, Serialize};
60use sha2::{Digest, Sha256};
61use std::sync::Arc;
62
63/// Context used to determine if a request/response should be cached.
64///
65/// This struct contains information about the current request that can be
66/// examined by cache predicate functions to make caching decisions.
67pub struct CacheDecisionContext<'a> {
68    /// The HTTP method of the request (e.g., "GET", "POST")
69    pub method: &'a str,
70    /// The request path
71    pub path: &'a str,
72    /// The query string from the request URL
73    pub query_string: &'a str,
74    /// HTTP headers from the request
75    pub headers: &'a HeaderMap,
76    /// The request body as a byte slice
77    pub body: &'a [u8],
78}
79
80/// Function type for cache decision predicates.
81///
82/// This type represents functions that take a `CacheDecisionContext` and return
83/// a boolean indicating whether the response should be cached.
84type CachePredicate = Arc<dyn Fn(&CacheDecisionContext) -> bool + Send + Sync>;
85
86/// Redis-backed caching middleware for Actix Web.
87///
88/// This middleware intercepts responses, caches them in Redis, and serves
89/// cached responses for subsequent matching requests when available.
90pub struct RedisCacheMiddleware {
91    redis_conn: MultiplexedConnection,
92    ttl: u64,
93    max_cacheable_size: usize,
94    cache_prefix: String,
95    cache_if: CachePredicate,
96}
97
98/// Builder for configuring and creating the `RedisCacheMiddleware`.
99///
100/// Provides a fluent interface for configuring cache parameters such as TTL,
101/// maximum cacheable size, cache key prefix, and cache decision predicates.
102pub struct RedisCacheMiddlewareBuilder {
103    redis_url: String,
104    ttl: u64,
105    max_cacheable_size: usize,
106    cache_prefix: String,
107    cache_if: CachePredicate,
108}
109
110impl RedisCacheMiddlewareBuilder {
111    /// Creates a new builder with the given Redis URL.
112    ///
113    /// # Arguments
114    ///
115    /// * `redis_url` - The Redis connection URL (e.g., "redis://127.0.0.1:6379")
116    ///
117    /// # Returns
118    ///
119    /// A new `RedisCacheMiddlewareBuilder` with default settings:
120    /// - TTL: 3600 seconds (1 hour)
121    /// - Max cacheable size: 1MB
122    /// - Cache prefix: "cache:"
123    /// - Cache predicate: cache all responses
124    pub fn new(redis_url: impl Into<String>) -> Self {
125        Self {
126            redis_url: redis_url.into(),
127            ttl: 3600,                       // 1 hour default
128            max_cacheable_size: 1024 * 1024, // 1MB default
129            cache_prefix: "cache:".to_string(),
130            cache_if: Arc::new(|_| true), // Default: cache everything
131        }
132    }
133
134    /// Sets the TTL (time-to-live) for cached responses in seconds.
135    ///
136    /// # Arguments
137    ///
138    /// * `seconds` - The number of seconds a response should remain in the cache
139    ///
140    /// # Returns
141    ///
142    /// Self for method chaining
143    pub fn ttl(mut self, seconds: u64) -> Self {
144        self.ttl = seconds;
145        self
146    }
147
148    /// Sets the maximum size of responses that can be cached, in bytes.
149    ///
150    /// Responses larger than this size will not be cached.
151    ///
152    /// # Arguments
153    ///
154    /// * `bytes` - The maximum cacheable response size in bytes
155    ///
156    /// # Returns
157    ///
158    /// Self for method chaining
159    pub fn max_cacheable_size(mut self, bytes: usize) -> Self {
160        self.max_cacheable_size = bytes;
161        self
162    }
163
164    /// Sets the prefix used for Redis cache keys.
165    ///
166    /// # Arguments
167    ///
168    /// * `prefix` - The string prefix to use for all cache keys
169    ///
170    /// # Returns
171    ///
172    /// Self for method chaining
173    pub fn cache_prefix(mut self, prefix: impl Into<String>) -> Self {
174        self.cache_prefix = prefix.into();
175        self
176    }
177
178    /// Set a predicate function to determine if a response should be cached
179    ///
180    /// Example:
181    /// ```
182    /// builder.cache_if(|ctx| {
183    ///     // Only cache GET requests
184    ///     if ctx.method != "GET" {
185    ///         return false;
186    ///     }
187    ///     
188    ///     // Don't cache if Authorization header is present
189    ///     if ctx.headers.contains_key("Authorization") {
190    ///         return false;
191    ///     }
192    ///     
193    ///     // Don't cache responses to paths that start with /admin
194    ///     if ctx.path.starts_with("/admin") {
195    ///         return false;
196    ///     }
197    ///
198    ///     // Don't cache for a specific route if its body contains some field
199    ///     if ctx.path.starts_with("/api/users") && ctx.method == "POST" {
200    ///        if let Ok(user_json) = serde_json::from_slice::<serde_json::Value>(ctx.body) {
201    ///            // Check properties in the JSON to make caching decisions
202    ///            return user_json.get("role").and_then(|r| r.as_str()) != Some("admin");
203    ///        }
204    ///    }
205    ///     true
206    /// })
207    /// ```
208    pub fn cache_if<F>(mut self, predicate: F) -> Self
209    where
210        F: Fn(&CacheDecisionContext) -> bool + Send + Sync + 'static,
211    {
212        self.cache_if = Arc::new(predicate);
213        self
214    }
215
216    /// Builds and returns the configured `RedisCacheMiddleware`.
217    ///
218    /// # Returns
219    ///
220    /// A new `RedisCacheMiddleware` instance configured with the settings from this builder.
221    ///
222    /// # Panics
223    ///
224    /// This method will panic if it cannot connect to Redis using the provided URL.
225    pub async fn build(self) -> RedisCacheMiddleware {
226        let client =
227            redis::Client::open(self.redis_url.as_str()).expect("Failed to connect to Redis");
228
229        let redis_conn = client
230            .get_multiplexed_async_connection()
231            .await
232            .expect("Failed to get Redis connection");
233
234        RedisCacheMiddleware {
235            redis_conn,
236            ttl: self.ttl,
237            max_cacheable_size: self.max_cacheable_size,
238            cache_prefix: self.cache_prefix,
239            cache_if: self.cache_if,
240        }
241    }
242}
243
244impl RedisCacheMiddleware {
245    /// Creates a new `RedisCacheMiddleware` with default settings.
246    ///
247    /// This is a convenience method that uses the builder with default settings.
248    ///
249    /// # Arguments
250    ///
251    /// * `redis_url` - The Redis connection URL
252    ///
253    /// # Returns
254    ///
255    /// A new `RedisCacheMiddleware` instance with default settings.
256    pub async fn new(redis_url: &str) -> Self {
257        RedisCacheMiddlewareBuilder::new(redis_url).build().await
258    }
259}
260
261/// Service implementation for the Redis cache middleware.
262///
263/// This struct is created by the `RedisCacheMiddleware` and handles
264/// the actual interception of requests and responses for caching.
265pub struct RedisCacheMiddlewareService<S> {
266    service: S,
267    redis_conn: MultiplexedConnection,
268    ttl: u64,
269    max_cacheable_size: usize,
270    cache_prefix: String,
271    cache_if: CachePredicate,
272}
273
274#[derive(Serialize, Deserialize)]
275struct CachedResponse {
276    status: u16,
277    headers: Vec<(String, String)>,
278    body: Vec<u8>,
279}
280
281impl<S, B> Transform<S, ServiceRequest> for RedisCacheMiddleware
282where
283    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static + Clone,
284    S::Future: 'static,
285    B: 'static + Clone + MessageBody,
286{
287    type Response = ServiceResponse<EitherBody<B, BoxBody>>;
288    type Error = Error;
289    type Transform = RedisCacheMiddlewareService<S>;
290    type InitError = ();
291    type Future = Ready<Result<Self::Transform, Self::InitError>>;
292
293    fn new_transform(&self, service: S) -> Self::Future {
294        ready(Ok(RedisCacheMiddlewareService {
295            service,
296            redis_conn: self.redis_conn.clone(),
297            ttl: self.ttl,
298            max_cacheable_size: self.max_cacheable_size,
299            cache_prefix: self.cache_prefix.clone(),
300            cache_if: self.cache_if.clone(),
301        }))
302    }
303}
304
305impl<S, B> Service<ServiceRequest> for RedisCacheMiddlewareService<S>
306where
307    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static + Clone,
308    S::Future: 'static,
309    B: actix_web::body::MessageBody + 'static + Clone,
310{
311    type Response = ServiceResponse<EitherBody<B, BoxBody>>;
312    type Error = Error;
313    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
314
315    forward_ready!(service);
316
317    fn call(&self, mut req: ServiceRequest) -> Self::Future {
318        if let Some(cache_control) = req.headers().get("Cache-Control") {
319            if let Ok(cache_control_str) = cache_control.to_str() {
320                if cache_control_str.contains("no-cache") || cache_control_str.contains("no-store")
321                {
322                    let fut = self.service.call(req);
323                    return Box::pin(async move {
324                        let res = fut.await?;
325                        Ok(res.map_body(|_, b| EitherBody::left(b)))
326                    });
327                }
328            }
329        }
330
331        let mut redis_conn = self.redis_conn.clone();
332        let expiration = self.ttl;
333        let max_cacheable_size = self.max_cacheable_size;
334        let cache_prefix = self.cache_prefix.clone();
335        let service = self.service.clone();
336        let cache_if = self.cache_if.clone();
337
338        Box::pin(async move {
339            let body_bytes = req
340                .take_payload()
341                .fold(BytesMut::new(), move |mut body, chunk| async {
342                    if let Ok(chunk) = chunk {
343                        body.extend_from_slice(&chunk);
344                    }
345                    body
346                })
347                .await;
348
349            let cache_ctx = CacheDecisionContext {
350                method: req.method().as_str(),
351                path: req.path(),
352                query_string: req.query_string(),
353                headers: req.headers(),
354                body: &body_bytes,
355            };
356
357            let should_cache = cache_if(&cache_ctx);
358
359            req.set_payload(Payload::from(Bytes::from(body_bytes.clone())));
360
361            let base_key = if body_bytes.is_empty() {
362                format!(
363                    "{}:{}:{}",
364                    req.method().as_str(),
365                    req.path(),
366                    req.query_string()
367                )
368            } else {
369                let body_hash = hex::encode(Sha256::digest(&body_bytes));
370                format!(
371                    "{}:{}:{}:{}",
372                    req.method().as_str(),
373                    req.path(),
374                    req.query_string(),
375                    body_hash
376                )
377            };
378
379            let hashed_key = hex::encode(Sha256::digest(base_key.as_bytes()));
380            let cache_key = format!("{}{}", cache_prefix, hashed_key);
381
382            // Only try to get from cache if we're considering caching for this request
383            let cached_result: Option<String> = if should_cache {
384                redis_conn.get(&cache_key).await.unwrap_or(None)
385            } else {
386                None
387            };
388
389            if let Some(cached_data) = cached_result {
390                log::debug!("Cache hit for {}", cache_key);
391
392                // Deserialize cached response
393                match serde_json::from_str::<CachedResponse>(&cached_data) {
394                    Ok(cached_response) => {
395                        let mut response = actix_web::HttpResponse::build(
396                            actix_web::http::StatusCode::from_u16(cached_response.status)
397                                .unwrap_or(actix_web::http::StatusCode::OK),
398                        );
399
400                        for (name, value) in cached_response.headers {
401                            response.insert_header((name, value));
402                        }
403
404                        response.insert_header(("X-Cache", "HIT"));
405
406                        let resp = response.body(cached_response.body);
407                        return Ok(req
408                            .into_response(resp)
409                            .map_body(|_, b| EitherBody::right(BoxBody::new(b))));
410                    }
411                    Err(e) => {
412                        log::error!("Failed to deserialize cached response: {}", e);
413                    }
414                }
415            }
416
417            log::debug!("Cache miss for {}", cache_key);
418
419            let service_result = service.call(req).await?;
420
421            // Only store in cache if we're considering caching and the response is successful
422            if should_cache && service_result.status().is_success() {
423                let res = service_result.response();
424
425                let status = res.status().as_u16();
426
427                let headers = res
428                    .headers()
429                    .iter()
430                    .filter(|(name, _)| {
431                        !["connection", "transfer-encoding", "content-length"]
432                            .contains(&name.as_str().to_lowercase().as_str())
433                    })
434                    .map(|(name, value)| {
435                        (
436                            name.to_string(),
437                            value.to_str().unwrap_or_default().to_string(),
438                        )
439                    })
440                    .collect::<Vec<_>>();
441
442                if let Ok(body) = res.body().clone().try_into_bytes() {
443                    if !body.is_empty() && body.len() <= max_cacheable_size {
444                        let cached_response = CachedResponse {
445                            status,
446                            headers,
447                            body: body.to_vec(),
448                        };
449
450                        if let Ok(serialized) = serde_json::to_string(&cached_response) {
451                            let _: Result<(), redis::RedisError> =
452                                redis_conn.set_ex(cache_key, serialized, expiration).await;
453                        }
454                    }
455                }
456            }
457
458            Ok(service_result.map_body(|_, b| EitherBody::left(b)))
459        })
460    }
461}
462
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{http::header, test::TestRequest};

    /// Builds a body-less, query-less decision context over `headers`.
    fn decision_ctx<'a>(
        method: &'a str,
        path: &'a str,
        headers: &'a header::HeaderMap,
    ) -> CacheDecisionContext<'a> {
        CacheDecisionContext {
            method,
            path,
            query_string: "",
            headers,
            body: &[],
        }
    }

    #[actix_web::test]
    async fn test_builder_default_values() {
        let defaults = RedisCacheMiddlewareBuilder::new("redis://localhost");

        assert_eq!(defaults.redis_url, "redis://localhost");
        assert_eq!(defaults.ttl, 3600);
        assert_eq!(defaults.max_cacheable_size, 1024 * 1024);
        assert_eq!(defaults.cache_prefix, "cache:");
    }

    #[actix_web::test]
    async fn test_builder_custom_values() {
        let configured = RedisCacheMiddlewareBuilder::new("redis://localhost")
            .cache_prefix("custom:")
            .max_cacheable_size(512 * 1024)
            .ttl(60);

        assert_eq!(configured.ttl, 60);
        assert_eq!(configured.max_cacheable_size, 512 * 1024);
        assert_eq!(configured.cache_prefix, "custom:");
    }

    #[actix_web::test]
    async fn test_builder_custom_predicate() {
        let empty = header::HeaderMap::new();
        let only_get =
            RedisCacheMiddlewareBuilder::new("redis://localhost").cache_if(|ctx| ctx.method == "GET");

        // Only GET requests should pass the predicate.
        assert!((only_get.cache_if)(&decision_ctx("GET", "/test", &empty)));
        assert!(!(only_get.cache_if)(&decision_ctx("POST", "/test", &empty)));
    }

    #[actix_web::test]
    async fn test_cache_key_generation() {
        let req = TestRequest::get().uri("/test").to_srv_request();

        // Mirror the middleware's `METHOD:path:query` key layout.
        let base_key = [req.method().as_str(), req.path(), req.query_string()].join(":");
        let cache_key = format!("test:{}", hex::encode(Sha256::digest(base_key.as_bytes())));

        let expected_key = format!("test:{}", hex::encode(Sha256::digest(b"GET:/test:")));

        assert_eq!(cache_key, expected_key);
    }

    #[actix_web::test]
    async fn test_cache_key_with_body() {
        let body_hash = hex::encode(Sha256::digest(b"test body"));

        // With a body, its hash is appended as a fourth segment.
        let base_key = ["POST", "/test", "", body_hash.as_str()].join(":");
        let cache_key = format!("test:{}", hex::encode(Sha256::digest(base_key.as_bytes())));

        let expected_key = format!(
            "test:{}",
            hex::encode(Sha256::digest(format!("POST:/test::{}", body_hash).as_bytes()))
        );

        assert_eq!(cache_key, expected_key);
    }

    #[actix_web::test]
    async fn test_cacheable_methods() {
        const METHODS: [&str; 7] = ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"];
        let empty = header::HeaderMap::new();

        // The default predicate accepts every method.
        let default_predicate = RedisCacheMiddlewareBuilder::new("redis://localhost").cache_if;
        for method in METHODS {
            assert!(
                default_predicate(&decision_ctx(method, "/test", &empty)),
                "Method {} should be cacheable by default",
                method
            );
        }

        // A custom predicate restricted to GET and HEAD.
        let get_or_head = RedisCacheMiddlewareBuilder::new("redis://localhost")
            .cache_if(|ctx| matches!(ctx.method, "GET" | "HEAD"))
            .cache_if;
        for method in METHODS {
            let expected = matches!(method, "GET" | "HEAD");
            assert_eq!(
                get_or_head(&decision_ctx(method, "/test", &empty)),
                expected,
                "Method {} should be cacheable: {}",
                method,
                expected
            );
        }
    }

    #[actix_web::test]
    async fn test_predicate_with_headers() {
        // Requests carrying an Authorization header must not be cached.
        let predicate = |ctx: &CacheDecisionContext| !ctx.headers.contains_key("Authorization");

        let mut headers = header::HeaderMap::new();
        assert!(
            predicate(&decision_ctx("GET", "/test", &headers)),
            "Request without Authorization should be cached"
        );

        headers.insert(
            header::AUTHORIZATION,
            header::HeaderValue::from_static("Bearer token"),
        );
        assert!(
            !predicate(&decision_ctx("GET", "/test", &headers)),
            "Request with Authorization should not be cached"
        );
    }

    #[actix_web::test]
    async fn test_predicate_with_path_patterns() {
        // Admin paths and anything under a /private/ segment are excluded.
        let predicate = |ctx: &CacheDecisionContext| {
            !(ctx.path.starts_with("/admin") || ctx.path.contains("/private/"))
        };
        let empty = header::HeaderMap::new();

        for path in ["/", "/api/users", "/public/resource", "/api/v1/data"] {
            assert!(
                predicate(&decision_ctx("GET", path, &empty)),
                "Path {} should be cacheable",
                path
            );
        }

        for path in ["/admin", "/admin/users", "/users/private/profile"] {
            assert!(
                !predicate(&decision_ctx("GET", path, &empty)),
                "Path {} should not be cacheable",
                path
            );
        }
    }

    #[actix_web::test]
    async fn test_cached_response_serialization() {
        let original = CachedResponse {
            status: 200,
            headers: vec![
                ("Content-Type".to_string(), "text/plain".to_string()),
                ("X-Test".to_string(), "value".to_string()),
            ],
            body: b"test response".to_vec(),
        };

        // Round-trip through JSON and compare every field.
        let json = serde_json::to_string(&original).unwrap();
        let restored: CachedResponse = serde_json::from_str(&json).unwrap();

        assert_eq!(restored.status, 200);
        assert_eq!(
            restored.headers,
            vec![
                ("Content-Type".to_string(), "text/plain".to_string()),
                ("X-Test".to_string(), "value".to_string()),
            ]
        );
        assert_eq!(restored.body, b"test response");
    }
}