actix_request_reply_cache/
lib.rs

1#![warn(missing_docs)]
2//! # Actix Request-Reply Cache
3//!
4//! A Redis-backed caching middleware for Actix Web that enables response caching.
5//!
6//! This library implements efficient HTTP response caching using Redis as a backend store,
7//! with functionality for fine-grained cache control through predicates that can examine
8//! request context to determine cacheability.
9//!
10//! ## Features
11//!
12//! - Redis-backed HTTP response caching
13//! - Configurable TTL (time-to-live) for cached responses
14//! - Customizable cache key prefix
15//! - Maximum cacheable response size configuration
16//! - Flexible cache control through predicate functions
17//! - Respects standard HTTP cache control headers
18//!
19//! ## Example
20//!
//! ```rust,no_run
//! use actix_web::{web, App, HttpServer};
//! use actix_request_reply_cache::RedisCacheMiddlewareBuilder;
//!
//! #[actix_web::main]
//! async fn main() -> std::io::Result<()> {
//!     // Create the cache middleware (building it does not contact Redis)
//!     let cache = RedisCacheMiddlewareBuilder::new("redis://127.0.0.1:6379")
//!         .ttl(60)  // Cache for 60 seconds
//!         .cache_if(|ctx| {
//!             // Only cache GET requests without Authorization header
//!             ctx.method == "GET" && !ctx.headers.contains_key("Authorization")
//!         })
//!         .build();
//!
//!     HttpServer::new(move || {
//!         App::new()
//!             .wrap(cache.clone())
//!             .service(web::resource("/").to(|| async { "Hello world!" }))
//!     })
//!     .bind(("127.0.0.1", 8080))?
//!     .run()
//!     .await
//! }
//! ```
47use actix_web::{
48    body::{BodySize, BoxBody, EitherBody, MessageBody},
49    dev::{forward_ready, Payload, Service, ServiceRequest, ServiceResponse, Transform},
50    http::header::HeaderMap,
51    web::{Bytes, BytesMut},
52    Error, HttpMessage,
53};
54use futures::{
55    future::{ready, LocalBoxFuture, Ready},
56    StreamExt,
57};
58use pin_project_lite::pin_project;
59use redis::{aio::MultiplexedConnection, AsyncCommands};
60use serde::{Deserialize, Serialize};
61use sha2::{Digest, Sha256};
62use std::{future::Future, marker::PhantomData, pin::Pin, rc::Rc};
63use std::{
64    sync::Arc,
65    task::{Context, Poll},
66};
67
68/// Context containing request information for cache operations.
69///
70/// This struct contains information about the current request that can be used for:
71/// - Making caching decisions through predicate functions
72/// - Generating custom cache keys
73pub struct CacheContext<'a> {
74    /// The HTTP method of the request (e.g., "GET", "POST")
75    pub method: &'a str,
76    /// The request path
77    pub path: &'a str,
78    /// The query string from the request URL
79    pub query_string: &'a str,
80    /// HTTP headers from the request
81    pub headers: &'a HeaderMap,
82    /// The request body as a byte slice
83    pub body: &'a serde_json::Value,
84}
85
86/// Function type for cache decision predicates.
87///
88/// This type represents functions that take a `CacheDecisionContext` and return
89/// a boolean indicating whether the response should be cached.
90type CachePredicate = Arc<dyn Fn(&CacheContext) -> bool + Send + Sync>;
91
92/// Function type for custom cache key generation.
93///
94/// This type represents functions that take a `CacheDecisionContext` and return
95/// a string to be used as the base for the cache key.
96type CacheKeyFn = Arc<dyn Fn(&CacheContext) -> String + Send + Sync>;
97
98/// Redis-backed caching middleware for Actix Web.
99///
100/// This middleware intercepts responses, caches them in Redis, and serves
101/// cached responses for subsequent matching requests when available.
102pub struct RedisCacheMiddleware {
103    redis_conn: Option<MultiplexedConnection>,
104    redis_url: String,
105    ttl: u64,
106    max_cacheable_size: usize,
107    cache_prefix: String,
108    cache_if: CachePredicate,
109    cache_key_fn: Option<CacheKeyFn>,
110}
111
112/// Builder for configuring and creating the `RedisCacheMiddleware`.
113///
114/// Provides a fluent interface for configuring cache parameters such as TTL,
115/// maximum cacheable size, cache key prefix, and cache decision predicates.
116pub struct RedisCacheMiddlewareBuilder {
117    redis_url: String,
118    ttl: u64,
119    max_cacheable_size: usize,
120    cache_prefix: String,
121    cache_if: CachePredicate,
122    cache_key_fn: Option<CacheKeyFn>,
123}
124
125impl RedisCacheMiddlewareBuilder {
126    /// Creates a new builder with the given Redis URL.
127    ///
128    /// # Arguments
129    ///
130    /// * `redis_url` - The Redis connection URL (e.g., "redis://127.0.0.1:6379")
131    ///
132    /// # Returns
133    ///
134    /// A new `RedisCacheMiddlewareBuilder` with default settings:
135    /// - TTL: 3600 seconds (1 hour)
136    /// - Max cacheable size: 1MB
137    /// - Cache prefix: "cache:"
138    /// - Cache predicate: cache all responses
139    pub fn new(redis_url: impl Into<String>) -> Self {
140        Self {
141            redis_url: redis_url.into(),
142            ttl: 3600,                       // 1 hour default
143            max_cacheable_size: 1024 * 1024, // 1MB default
144            cache_prefix: "cache:".to_string(),
145            cache_if: Arc::new(|_| true), // Default: cache everything
146            cache_key_fn: None,           // Default: use standard key generation
147        }
148    }
149
150    /// Sets the TTL (time-to-live) for cached responses in seconds.
151    ///
152    /// # Arguments
153    ///
154    /// * `seconds` - The number of seconds a response should remain in the cache
155    ///
156    /// # Returns
157    ///
158    /// Self for method chaining
159    pub fn ttl(mut self, seconds: u64) -> Self {
160        self.ttl = seconds;
161        self
162    }
163
164    /// Sets the maximum size of responses that can be cached, in bytes.
165    ///
166    /// Responses larger than this size will not be cached.
167    ///
168    /// # Arguments
169    ///
170    /// * `bytes` - The maximum cacheable response size in bytes
171    ///
172    /// # Returns
173    ///
174    /// Self for method chaining
175    pub fn max_cacheable_size(mut self, bytes: usize) -> Self {
176        self.max_cacheable_size = bytes;
177        self
178    }
179
180    /// Sets the prefix used for Redis cache keys.
181    ///
182    /// # Arguments
183    ///
184    /// * `prefix` - The string prefix to use for all cache keys
185    ///
186    /// # Returns
187    ///
188    /// Self for method chaining
189    pub fn cache_prefix(mut self, prefix: impl Into<String>) -> Self {
190        self.cache_prefix = prefix.into();
191        self
192    }
193
194    /// Set a predicate function to determine if a response should be cached
195    ///
196    /// Example:
197    /// ```
198    /// builder.cache_if(|ctx| {
199    ///     // Only cache GET requests
200    ///     if ctx.method != "GET" {
201    ///         return false;
202    ///     }
203    ///     
204    ///     // Don't cache if Authorization header is present
205    ///     if ctx.headers.contains_key("Authorization") {
206    ///         return false;
207    ///     }
208    ///     
209    ///     // Don't cache responses to paths that start with /admin
210    ///     if ctx.path.starts_with("/admin") {
211    ///         return false;
212    ///     }
213    ///
214    ///     // Don't cache for a specific route if its body contains some field
215    ///     if ctx.path.starts_with("/api/users") && ctx.method == "POST" {
216    ///         return ctx.body.get("role").and_then(|r| r.as_str()) != Some("admin");
217    ///     }
218    ///     true
219    /// })
220    /// ```
221    pub fn cache_if<F>(mut self, predicate: F) -> Self
222    where
223        F: Fn(&CacheContext) -> bool + Send + Sync + 'static,
224    {
225        self.cache_if = Arc::new(predicate);
226        self
227    }
228
229    /// Set a custom function to determine the cache key.
230    ///
231    /// By default, cache keys are based on HTTP method, path, query string, and
232    /// (if present) a hash of the request body. This method lets you specify a custom
233    /// function to generate the base key before hashing.
234    ///
235    /// Example:
236    /// ```
237    /// builder.with_cache_key(|ctx| {
238    ///     // Only use method and path for the cache key (ignore query params and body)
239    ///     format!("{}:{}", ctx.method, ctx.path)
240    ///     
241    ///     // Or include specific query parameters
242    ///     // let user_id = ctx.query_string.split('&')
243    ///     //     .find(|p| p.starts_with("user_id="))
244    ///     //     .unwrap_or("");
245    ///     // format!("{}:{}:{}", ctx.method, ctx.path, user_id)
246    /// })
247    /// ```
248    pub fn with_cache_key<F>(mut self, key_fn: F) -> Self
249    where
250        F: Fn(&CacheContext) -> String + Send + Sync + 'static,
251    {
252        self.cache_key_fn = Some(Arc::new(key_fn));
253        self
254    }
255
256    /// Builds and returns the configured `RedisCacheMiddleware`.
257    ///
258    /// # Returns
259    ///
260    /// A new `RedisCacheMiddleware` instance configured with the settings from this builder.
261    pub fn build(self) -> RedisCacheMiddleware {
262        RedisCacheMiddleware {
263            redis_conn: None,
264            redis_url: self.redis_url,
265            ttl: self.ttl,
266            max_cacheable_size: self.max_cacheable_size,
267            cache_prefix: self.cache_prefix,
268            cache_if: self.cache_if,
269            cache_key_fn: self.cache_key_fn,
270        }
271    }
272}
273
274impl RedisCacheMiddleware {
275    /// Creates a new `RedisCacheMiddleware` with default settings.
276    ///
277    /// This is a convenience method that uses the builder with default settings.
278    ///
279    /// # Arguments
280    ///
281    /// * `redis_url` - The Redis connection URL
282    ///
283    /// # Returns
284    ///
285    /// A new `RedisCacheMiddleware` instance with default settings.
286    pub fn new(redis_url: &str) -> Self {
287        RedisCacheMiddlewareBuilder::new(redis_url).build()
288    }
289}
290
291/// Service implementation for the Redis cache middleware.
292///
293/// This struct is created by the `RedisCacheMiddleware` and handles
294/// the actual interception of requests and responses for caching.
295pub struct RedisCacheMiddlewareService<S> {
296    service: Rc<S>,
297    redis_conn: Option<MultiplexedConnection>,
298    redis_url: String,
299    ttl: u64,
300    max_cacheable_size: usize,
301    cache_prefix: String,
302    cache_if: CachePredicate,
303    cache_key_fn: Option<CacheKeyFn>,
304}
305
/// A response as stored in Redis, MessagePack-encoded with `rmp_serde`.
///
/// NOTE(review): entries are written with `rmp_serde::to_vec`, which (per rmp-serde's
/// docs) encodes structs positionally rather than by field name. Reordering or
/// inserting fields would make already-cached entries fail to decode. Confirm before
/// changing the layout.
#[derive(Serialize, Deserialize)]
struct CachedResponse {
    // HTTP status code of the original response.
    status: u16,
    // Header name/value pairs replayed on a cache hit.
    headers: Vec<(String, String)>,
    // Raw response body bytes.
    body: Vec<u8>,
}
312
313impl<S, B> Transform<S, ServiceRequest> for RedisCacheMiddleware
314where
315    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
316    S::Future: 'static,
317    B: 'static + MessageBody,
318{
319    type Response = ServiceResponse<EitherBody<B, BoxBody>>;
320    type Error = Error;
321    type Transform = RedisCacheMiddlewareService<S>;
322    type InitError = ();
323    type Future = Ready<Result<Self::Transform, Self::InitError>>;
324
325    /// Creates a new transform of the input service.
326    fn new_transform(&self, service: S) -> Self::Future {
327        ready(Ok(RedisCacheMiddlewareService {
328            service: Rc::new(service),
329            redis_conn: self.redis_conn.clone(),
330            redis_url: self.redis_url.clone(),
331            ttl: self.ttl,
332            max_cacheable_size: self.max_cacheable_size,
333            cache_prefix: self.cache_prefix.clone(),
334            cache_if: self.cache_if.clone(),
335            cache_key_fn: self.cache_key_fn.clone(),
336        }))
337    }
338}
339
// Future the middleware awaits on a cache miss. It resolves the wrapped service's
// response and, when `should_cache` is set, wraps the response body so its bytes can
// be captured and written to Redis.
pin_project! {
    struct CacheResponseFuture<S, B>
    where
        B: MessageBody,
        S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    {
        // In-flight response future of the wrapped service.
        #[pin]
        fut: S::Future,
        // Whether the request passed the request-side caching checks.
        should_cache: bool,
        // Fully prefixed Redis key the response would be stored under.
        cache_key: String,
        // Connection to reuse for the write, if one is already open.
        redis_conn: Option<MultiplexedConnection>,
        // URL used to open a connection when `redis_conn` is `None`.
        redis_url: String,
        // Expiry of the stored entry, in seconds.
        ttl: u64,
        // Largest body (in bytes) that will be stored.
        max_cacheable_size: usize,
        // Ties the otherwise unused body type `B` to the struct.
        _marker: PhantomData<B>,
    }
}
358
359// Implement the Future trait for your response future
360impl<S, B> Future for CacheResponseFuture<S, B>
361where
362    B: MessageBody + 'static,
363    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
364{
365    type Output = Result<ServiceResponse<EitherBody<B, BoxBody>>, Error>;
366
367    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
368        let this = self.project();
369
370        let res = futures_util::ready!(this.fut.poll(cx))?;
371
372        let status = res.status();
373        let headers = res.headers().clone();
374        let should_cache = *this.should_cache && status.is_success();
375
376        if !should_cache {
377            return Poll::Ready(Ok(res.map_body(|_, b| EitherBody::left(b))));
378        }
379
380        let cache_key = this.cache_key.clone();
381        let redis_url = this.redis_url.clone();
382        let redis_conn = this.redis_conn.clone();
383        let ttl = *this.ttl;
384        let max_size = *this.max_cacheable_size;
385
386        let res = res.map_body(move |_, body| {
387            let filtered_headers = headers
388                .iter()
389                .filter(|(name, _)| {
390                    !["connection", "transfer-encoding", "content-length"]
391                        .contains(&name.as_str().to_lowercase().as_str())
392                })
393                .map(|(name, value)| {
394                    (
395                        name.to_string(),
396                        value.to_str().unwrap_or_default().to_string(),
397                    )
398                })
399                .collect::<Vec<_>>();
400
401            EitherBody::right(BoxBody::new(CacheableBody {
402                body: body.boxed(),
403                status: status.as_u16(),
404                headers: filtered_headers,
405                body_accum: BytesMut::new(),
406                cache_key,
407                redis_conn,
408                redis_url,
409                ttl,
410                max_size,
411            }))
412        });
413
414        Poll::Ready(Ok(res))
415    }
416}
417
418// Define the body wrapper that will accumulate data
419pin_project! {
420    struct CacheableBody {
421        #[pin]
422        body: BoxBody,
423        status: u16,
424        headers: Vec<(String, String)>,
425        body_accum: BytesMut,
426        cache_key: String,
427        redis_conn: Option<MultiplexedConnection>,
428        redis_url: String,
429        ttl: u64,
430        max_size: usize,
431    }
432
433    impl PinnedDrop for CacheableBody {
434        fn drop(this: Pin<&mut Self>) {
435            let this = this.project();
436
437            let body_bytes = this.body_accum.clone().freeze();
438            let status = *this.status;
439            let headers = this.headers.clone();
440            let cache_key = this.cache_key.clone();
441            let mut redis_conn = this.redis_conn.take();
442            let redis_url = this.redis_url.clone();
443            let ttl = *this.ttl;
444            let max_size = *this.max_size;
445
446            if !body_bytes.is_empty() && body_bytes.len() <= max_size {
447                actix_web::rt::spawn(async move {
448                    let cached_response = CachedResponse {
449                        status,
450                        headers,
451                        body: body_bytes.to_vec(),
452                    };
453
454                    if let Ok(serialized) = rmp_serde::to_vec(&cached_response) {
455                        if redis_conn.is_none() {
456                            let client = redis::Client::open(redis_url.as_str())
457                                .expect("Failed to connect to Redis");
458
459                            let conn = client
460                                .get_multiplexed_async_connection()
461                                .await
462                                .expect("Failed to get Redis connection");
463
464                            redis_conn = Some(conn);
465                        }
466
467                        if let Some(conn) = redis_conn.as_mut() {
468                            let _: Result<(), redis::RedisError> =
469                                conn.set_ex(cache_key, serialized, ttl).await;
470                        }
471                    }
472                });
473            }
474        }
475    }
476}
477
478impl MessageBody for CacheableBody {
479    type Error = <BoxBody as MessageBody>::Error;
480
481    fn size(&self) -> BodySize {
482        self.body.size()
483    }
484
485    fn poll_next(
486        self: Pin<&mut Self>,
487        cx: &mut Context<'_>,
488    ) -> Poll<Option<Result<Bytes, Self::Error>>> {
489        let this = self.project();
490
491        // Poll the inner body and accumulate data
492        match this.body.poll_next(cx) {
493            Poll::Ready(Some(Ok(chunk))) => {
494                this.body_accum.extend_from_slice(&chunk);
495                Poll::Ready(Some(Ok(chunk)))
496            }
497            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
498            Poll::Ready(None) => Poll::Ready(None),
499            Poll::Pending => Poll::Pending,
500        }
501    }
502}
503
504impl<S, B> Service<ServiceRequest> for RedisCacheMiddlewareService<S>
505where
506    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
507    S::Future: 'static,
508    B: MessageBody + 'static,
509{
510    type Response = ServiceResponse<EitherBody<B, BoxBody>>;
511    type Error = Error;
512    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
513
514    forward_ready!(service);
515
516    fn call(&self, mut req: ServiceRequest) -> Self::Future {
517        // Skip caching if Cache-Control says no-cache/no-store
518        if let Some(cache_control) = req.headers().get("Cache-Control") {
519            if let Ok(cache_control_str) = cache_control.to_str() {
520                if cache_control_str.contains("no-cache") || cache_control_str.contains("no-store")
521                {
522                    let fut = self.service.call(req);
523                    return Box::pin(async move {
524                        let res = fut.await?;
525                        Ok(res.map_body(|_, b| EitherBody::left(b)))
526                    });
527                }
528            }
529        }
530
531        let redis_url = self.redis_url.clone();
532        let mut redis_conn = self.redis_conn.clone();
533        let expiration = self.ttl;
534        let max_cacheable_size = self.max_cacheable_size;
535        let cache_prefix = self.cache_prefix.clone();
536        let service = Rc::clone(&self.service);
537        let cache_if = self.cache_if.clone();
538        let cache_key_fn = self.cache_key_fn.clone();
539
540        Box::pin(async move {
541            let body_bytes = req
542                .take_payload()
543                .fold(BytesMut::new(), move |mut body, chunk| async {
544                    if let Ok(chunk) = chunk {
545                        body.extend_from_slice(&chunk);
546                    }
547                    body
548                })
549                .await;
550
551            let cache_ctx = CacheContext {
552                method: req.method().as_str(),
553                path: req.path(),
554                query_string: req.query_string(),
555                headers: req.headers(),
556                body: &serde_json::from_slice(&body_bytes).unwrap_or(serde_json::Value::Null),
557            };
558
559            let should_cache = cache_if(&cache_ctx);
560
561            // Generate cache key using custom function if provided, otherwise use the default
562            let base_key = if let Some(key_fn) = &cache_key_fn {
563                key_fn(&cache_ctx)
564            } else if body_bytes.is_empty() {
565                format!(
566                    "{}:{}:{}",
567                    req.method().as_str(),
568                    req.path(),
569                    req.query_string()
570                )
571            } else {
572                let body_hash = hex::encode(Sha256::digest(&body_bytes));
573                format!(
574                    "{}:{}:{}:{}",
575                    req.method().as_str(),
576                    req.path(),
577                    req.query_string(),
578                    body_hash
579                )
580            };
581
582            req.set_payload(Payload::from(Bytes::from(body_bytes.clone())));
583
584            let hashed_key = hex::encode(Sha256::digest(base_key.as_bytes()));
585            let cache_key = format!("{}{}", cache_prefix, hashed_key);
586
587            let cached_result: Option<Vec<u8>> = if should_cache {
588                if redis_conn.is_none() {
589                    let client = redis::Client::open(redis_url.as_str())
590                        .expect("Failed to connect to Redis");
591
592                    let conn = client
593                        .get_multiplexed_async_connection()
594                        .await
595                        .expect("Failed to get Redis connection");
596
597                    redis_conn = Some(conn);
598                }
599
600                let conn = redis_conn.as_mut().unwrap();
601                conn.get(&cache_key).await.unwrap_or(None)
602            } else {
603                None
604            };
605
606            if let Some(cached_data) = cached_result {
607                log::debug!("Cache hit for {}", cache_key);
608
609                match rmp_serde::from_slice::<CachedResponse>(&cached_data) {
610                    Ok(cached_response) => {
611                        let mut response = actix_web::HttpResponse::build(
612                            actix_web::http::StatusCode::from_u16(cached_response.status)
613                                .unwrap_or(actix_web::http::StatusCode::OK),
614                        );
615
616                        for (name, value) in cached_response.headers {
617                            response.insert_header((name, value));
618                        }
619
620                        response.insert_header(("X-Cache", "HIT"));
621
622                        let resp = response.body(cached_response.body);
623                        return Ok(req
624                            .into_response(resp)
625                            .map_body(|_, b| EitherBody::right(BoxBody::new(b))));
626                    }
627                    Err(e) => {
628                        log::error!("Failed to deserialize cached response: {}", e);
629                    }
630                }
631            }
632
633            log::debug!("Cache miss for {}", cache_key);
634            let future = CacheResponseFuture::<S, B> {
635                fut: service.call(req),
636                should_cache,
637                cache_key,
638                redis_conn,
639                redis_url,
640                ttl: expiration,
641                max_cacheable_size,
642                _marker: PhantomData,
643            };
644
645            future.await
646        })
647    }
648}
649
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{http::header, test::TestRequest};

    #[actix_web::test]
    async fn test_builder_default_values() {
        let builder = RedisCacheMiddlewareBuilder::new("redis://localhost");
        assert_eq!(builder.ttl, 3600);
        assert_eq!(builder.max_cacheable_size, 1024 * 1024);
        assert_eq!(builder.cache_prefix, "cache:");
        assert_eq!(builder.redis_url, "redis://localhost");
    }

    #[actix_web::test]
    async fn test_builder_custom_values() {
        let builder = RedisCacheMiddlewareBuilder::new("redis://localhost")
            .ttl(60)
            .max_cacheable_size(512 * 1024)
            .cache_prefix("custom:");

        assert_eq!(builder.ttl, 60);
        assert_eq!(builder.max_cacheable_size, 512 * 1024);
        assert_eq!(builder.cache_prefix, "custom:");
    }

    #[actix_web::test]
    async fn test_builder_custom_predicate() {
        let builder = RedisCacheMiddlewareBuilder::new("redis://localhost")
            .cache_if(|ctx| ctx.method == "GET");

        // Test the predicate
        let get_ctx = CacheContext {
            method: "GET",
            path: "/test",
            query_string: "",
            headers: &header::HeaderMap::new(),
            body: &serde_json::Value::Null,
        };

        let post_ctx = CacheContext {
            method: "POST",
            path: "/test",
            query_string: "",
            headers: &header::HeaderMap::new(),
            body: &serde_json::Value::Null,
        };

        // The predicate should now only allow GET requests
        assert!((builder.cache_if)(&get_ctx));
        assert!(!(builder.cache_if)(&post_ctx));
    }

    #[actix_web::test]
    async fn test_cache_key_generation() {
        // Create a simple request
        let req = TestRequest::get().uri("/test").to_srv_request();

        // Extract the relevant parts for key generation
        let method = req.method().as_str();
        let path = req.path();
        let query_string = req.query_string();

        // Generate key manually as done in the middleware
        let base_key = format!("{}:{}:{}", method, path, query_string);
        let hashed_key = hex::encode(Sha256::digest(base_key.as_bytes()));
        let cache_key = format!("test:{}", hashed_key);

        // Now verify this matches what our middleware would generate
        let expected_key = format!("test:{}", hex::encode(Sha256::digest(b"GET:/test:")));

        assert_eq!(cache_key, expected_key);
    }

    #[actix_web::test]
    async fn test_cache_key_with_body() {
        // Test case for when request has a body
        let body_bytes = b"test body";
        let body_hash = hex::encode(Sha256::digest(body_bytes));

        // Generate key manually as done in the middleware
        let base_key = format!("{}:{}:{}:{}", "POST", "/test", "", body_hash);
        let hashed_key = hex::encode(Sha256::digest(base_key.as_bytes()));
        let cache_key = format!("test:{}", hashed_key);

        // Expected key when body is present
        let expected_key = format!(
            "test:{}",
            hex::encode(Sha256::digest(
                format!("POST:/test::{}", body_hash).as_bytes()
            ))
        );

        assert_eq!(cache_key, expected_key);
    }

    #[actix_web::test]
    async fn test_cacheable_methods() {
        // Test different HTTP methods with default predicate
        let builder = RedisCacheMiddlewareBuilder::new("redis://localhost");
        let default_predicate = builder.cache_if;

        let methods = ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"];

        for method in methods {
            let ctx = CacheContext {
                method,
                path: "/test",
                query_string: "",
                headers: &header::HeaderMap::new(),
                body: &serde_json::Value::Null,
            };

            // Default predicate should cache all methods
            assert!(
                (default_predicate)(&ctx),
                "Method {} should be cacheable by default",
                method
            );
        }

        // Test with a custom predicate that only caches GET and HEAD
        let custom_builder = RedisCacheMiddlewareBuilder::new("redis://localhost")
            .cache_if(|ctx| matches!(ctx.method, "GET" | "HEAD"));

        for method in methods {
            let ctx = CacheContext {
                method,
                path: "/test",
                query_string: "",
                headers: &header::HeaderMap::new(),
                body: &serde_json::Value::Null,
            };

            // Check if method should be cached according to our custom predicate
            let should_cache = matches!(method, "GET" | "HEAD");
            assert_eq!(
                (custom_builder.cache_if)(&ctx),
                should_cache,
                "Method {} should be cacheable: {}",
                method,
                should_cache
            );
        }
    }

    #[actix_web::test]
    async fn test_predicate_with_headers() {
        // Test predicate behavior with different headers

        // Create a predicate that doesn't cache requests with Authorization header
        let predicate = |ctx: &CacheContext| !ctx.headers.contains_key("Authorization");

        // Test with empty headers
        let mut headers = header::HeaderMap::new();
        let ctx_no_auth = CacheContext {
            method: "GET",
            path: "/test",
            query_string: "",
            headers: &headers,
            body: &serde_json::Value::Null,
        };

        assert!(
            predicate(&ctx_no_auth),
            "Request without Authorization should be cached"
        );

        // Test with Authorization header
        headers.insert(
            header::AUTHORIZATION,
            header::HeaderValue::from_static("Bearer token"),
        );

        let ctx_with_auth = CacheContext {
            method: "GET",
            path: "/test",
            query_string: "",
            headers: &headers,
            body: &serde_json::Value::Null,
        };

        assert!(
            !predicate(&ctx_with_auth),
            "Request with Authorization should not be cached"
        );
    }

    #[actix_web::test]
    async fn test_predicate_with_path_patterns() {
        // Test predicate behavior with different path patterns

        // Create a predicate that doesn't cache admin paths
        let predicate =
            |ctx: &CacheContext| !ctx.path.starts_with("/admin") && !ctx.path.contains("/private/");

        // Test paths that should be cached
        let cacheable_paths = ["/", "/api/users", "/public/resource", "/api/v1/data"];

        for path in cacheable_paths {
            let ctx = CacheContext {
                method: "GET",
                path,
                query_string: "",
                headers: &header::HeaderMap::new(),
                body: &serde_json::Value::Null,
            };

            assert!(predicate(&ctx), "Path {} should be cacheable", path);
        }

        // Test paths that should not be cached
        let non_cacheable_paths = ["/admin", "/admin/users", "/users/private/profile"];

        for path in non_cacheable_paths {
            let ctx = CacheContext {
                method: "GET",
                path,
                query_string: "",
                headers: &header::HeaderMap::new(),
                body: &serde_json::Value::Null,
            };

            assert!(!predicate(&ctx), "Path {} should not be cacheable", path);
        }
    }

    #[actix_web::test]
    async fn test_cached_response_serialization() {
        // Test that CachedResponse can be properly serialized and deserialized
        let cached_response = CachedResponse {
            status: 200,
            headers: vec![
                ("Content-Type".to_string(), "text/plain".to_string()),
                ("X-Test".to_string(), "value".to_string()),
            ],
            body: b"test response".to_vec(),
        };

        // Serialize
        let serialized = rmp_serde::to_vec(&cached_response).unwrap();

        // Deserialize
        let deserialized: CachedResponse = rmp_serde::from_slice(&serialized).unwrap();

        // Verify fields match
        assert_eq!(deserialized.status, 200);
        assert_eq!(deserialized.headers.len(), 2);
        assert_eq!(deserialized.headers[0].0, "Content-Type");
        assert_eq!(deserialized.headers[0].1, "text/plain");
        assert_eq!(deserialized.headers[1].0, "X-Test");
        assert_eq!(deserialized.headers[1].1, "value");
        assert_eq!(deserialized.body, b"test response");
    }

    #[actix_web::test]
    async fn test_custom_cache_key() {
        // Create a builder with a custom cache key function that only uses method and path
        let builder = RedisCacheMiddlewareBuilder::new("redis://localhost")
            .with_cache_key(|ctx| format!("{}:{}", ctx.method, ctx.path));

        // Create a function that extracts our cache key generation logic
        let get_key = |method: &str, path: &str, query: &str, body: &[u8]| {
            // Build a CacheContext the way the middleware does
            let headers = header::HeaderMap::new();
            let body_json = serde_json::from_slice(body).unwrap_or(serde_json::Value::Null);
            let ctx = CacheContext {
                method,
                path,
                query_string: query,
                headers: &headers,
                body: &body_json,
            };

            // Get the base key using our cache key function
            let base_key = if let Some(key_fn) = &builder.cache_key_fn {
                key_fn(&ctx)
            } else {
                format!("{}:{}:{}", method, path, query)
            };

            // Hash it and apply the prefix exactly as the middleware does: the prefix is
            // concatenated directly (the default "cache:" already ends with ':').
            let hashed_key = hex::encode(Sha256::digest(base_key.as_bytes()));
            format!("{}{}", builder.cache_prefix, hashed_key)
        };

        // Test with different query strings that should now produce the same cache key
        let key1 = get_key("GET", "/users", "", b"");
        let key2 = get_key("GET", "/users", "page=1", b"");
        let key3 = get_key("GET", "/users", "page=2", b"");

        // Keys should be the same since our custom function ignores query string
        assert_eq!(key1, key2);
        assert_eq!(key1, key3);
        assert!(key1.starts_with("cache:"));
        assert!(!key1.starts_with("cache::"));

        // Test with different methods that should produce different cache keys
        let key_get = get_key("GET", "/resource", "", b"");
        let key_post = get_key("POST", "/resource", "", b"");

        // Should be different keys
        assert_ne!(key_get, key_post);
    }
}