Skip to main content

hitbox_reqwest/
middleware.rs

1//! Cache middleware for reqwest-middleware.
2//!
3//! This module provides [`CacheMiddleware`] which implements the
4//! [`reqwest_middleware::Middleware`] trait to add caching capabilities
5//! to reqwest HTTP clients.
6//!
7//! See the [crate-level documentation](crate) for usage examples.
8
9use std::sync::Arc;
10
11use async_trait::async_trait;
12use hitbox::CacheStatusExt;
13use hitbox::backend::CacheBackend;
14use hitbox::concurrency::{ConcurrencyManager, NoopConcurrencyManager};
15use hitbox::config::CacheConfig;
16use hitbox::fsm::CacheFuture;
17use hitbox_core::DisabledOffload;
18use hitbox_http::{
19    BufferedBody, CacheableHttpRequest, CacheableHttpResponse, DEFAULT_CACHE_STATUS_HEADER,
20};
21use http::Extensions;
22use http::header::HeaderName;
23use reqwest::{Request, Response};
24use reqwest_middleware::{Middleware, Next, Result};
25
26use crate::upstream::{ReqwestUpstream, buffered_body_to_reqwest};
27
/// Type-state marker for a builder slot that has not been filled yet.
///
/// [`CacheMiddlewareBuilder`] starts with its backend and configuration slots
/// holding `NotSet`; the corresponding setter replaces the marker with a real
/// value (and changes the builder's type accordingly).
#[derive(Debug, Clone, Copy, Default)]
pub struct NotSet;
30
31/// Cache middleware for reqwest-middleware.
32///
33/// This middleware intercepts HTTP requests and responses, caching them
34/// according to the configured policy and predicates.
35///
36/// Use [`CacheMiddleware::builder()`] to construct an instance.
37/// See the [crate-level documentation](crate) for usage examples.
38pub struct CacheMiddleware<B, C, CM> {
39    backend: Arc<B>,
40    configuration: C,
41    concurrency_manager: CM,
42    /// Header name for cache status (HIT/MISS/STALE).
43    cache_status_header: HeaderName,
44}
45
46impl<B, C, CM> CacheMiddleware<B, C, CM> {
47    /// Creates a new cache middleware with explicit components.
48    ///
49    /// For most use cases, prefer [`CacheMiddleware::builder()`] which provides
50    /// a more ergonomic API with sensible defaults.
51    pub fn new(
52        backend: Arc<B>,
53        configuration: C,
54        concurrency_manager: CM,
55        cache_status_header: HeaderName,
56    ) -> Self {
57        Self {
58            backend,
59            configuration,
60            concurrency_manager,
61            cache_status_header,
62        }
63    }
64}
65
66impl CacheMiddleware<NotSet, NotSet, NoopConcurrencyManager> {
67    /// Creates a new builder for constructing cache middleware.
68    ///
69    /// Both [`backend()`](CacheMiddlewareBuilder::backend) and
70    /// [`config()`](CacheMiddlewareBuilder::config) must be called
71    /// before [`build()`](CacheMiddlewareBuilder::build).
72    ///
73    /// See the [crate-level documentation](crate) for usage examples.
74    pub fn builder() -> CacheMiddlewareBuilder<NotSet, NotSet, NoopConcurrencyManager> {
75        CacheMiddlewareBuilder::new()
76    }
77}
78
79impl<B, C, CM> Clone for CacheMiddleware<B, C, CM>
80where
81    C: Clone,
82    CM: Clone,
83{
84    fn clone(&self) -> Self {
85        Self {
86            backend: self.backend.clone(),
87            configuration: self.configuration.clone(),
88            concurrency_manager: self.concurrency_manager.clone(),
89            cache_status_header: self.cache_status_header.clone(),
90        }
91    }
92}
93
94#[async_trait]
95impl<B, C, CM> Middleware for CacheMiddleware<B, C, CM>
96where
97    B: CacheBackend + Send + Sync + 'static,
98    C: CacheConfig<CacheableHttpRequest<reqwest::Body>, CacheableHttpResponse<reqwest::Body>>
99        + Clone
100        + Send
101        + Sync
102        + 'static,
103    C::RequestPredicate: Clone + Send + Sync + 'static,
104    C::ResponsePredicate: Clone + Send + Sync + 'static,
105    C::Extractor: Clone + Send + Sync + 'static,
106    CM: ConcurrencyManager<Result<CacheableHttpResponse<reqwest::Body>>>
107        + Clone
108        + Send
109        + Sync
110        + 'static,
111{
112    async fn handle(
113        &self,
114        req: Request,
115        extensions: &mut Extensions,
116        next: Next<'_>,
117    ) -> Result<Response> {
118        // Convert reqwest::Request to http::Request<reqwest::Body>
119        let http_request: http::Request<reqwest::Body> = req
120            .try_into()
121            .map_err(|e: reqwest::Error| reqwest_middleware::Error::Reqwest(e))?;
122
123        // Wrap body with BufferedBody and create CacheableHttpRequest
124        let (parts, body) = http_request.into_parts();
125        let buffered_request = http::Request::from_parts(parts, BufferedBody::Passthrough(body));
126        let cacheable_req = CacheableHttpRequest::from_request(buffered_request);
127
128        // Create upstream wrapper
129        let upstream = ReqwestUpstream::new(next.clone(), extensions.clone());
130
131        // Create CacheFuture with DisabledOffload (no background revalidation)
132        // This allows us to use non-'static lifetimes
133        let cache_future: CacheFuture<
134            '_,
135            B,
136            CacheableHttpRequest<reqwest::Body>,
137            Result<CacheableHttpResponse<reqwest::Body>>,
138            ReqwestUpstream<'_>,
139            C::RequestPredicate,
140            C::ResponsePredicate,
141            C::Extractor,
142            CM,
143            DisabledOffload,
144        > = CacheFuture::new(
145            self.backend.clone(),
146            cacheable_req,
147            upstream,
148            self.configuration.request_predicates(),
149            self.configuration.response_predicates(),
150            self.configuration.extractors(),
151            Arc::new(self.configuration.policy().clone()),
152            DisabledOffload,
153            self.concurrency_manager.clone(),
154        );
155
156        // Execute cache future
157        let (response, cache_context) = cache_future.await;
158
159        // Convert CacheableHttpResponse back to reqwest::Response
160        let mut cacheable_response = response?;
161
162        // Add cache status header based on cache context
163        cacheable_response.cache_status(cache_context.status, &self.cache_status_header);
164
165        let http_response = cacheable_response.into_response();
166        let (parts, buffered_body) = http_response.into_parts();
167
168        // Convert BufferedBody back to reqwest::Body
169        let body = buffered_body_to_reqwest(buffered_body);
170        let http_response = http::Response::from_parts(parts, body);
171
172        // Convert to reqwest::Response
173        Ok(http_response.into())
174    }
175}
176
177/// Builder for constructing [`CacheMiddleware`] with a fluent API.
178///
179/// Obtained via [`CacheMiddleware::builder()`].
180/// Both [`backend()`](Self::backend) and [`config()`](Self::config)
181/// must be called before [`build()`](Self::build).
182///
183/// See the [crate-level documentation](crate) for usage examples.
184pub struct CacheMiddlewareBuilder<B, C, CM> {
185    backend: B,
186    configuration: C,
187    concurrency_manager: CM,
188    cache_status_header: Option<HeaderName>,
189}
190
191impl<B, C, CM> CacheMiddlewareBuilder<B, C, CM> {
192    /// Sets the cache backend.
193    pub fn backend<NB>(self, backend: NB) -> CacheMiddlewareBuilder<Arc<NB>, C, CM>
194    where
195        NB: CacheBackend,
196    {
197        CacheMiddlewareBuilder {
198            backend: Arc::new(backend),
199            configuration: self.configuration,
200            concurrency_manager: self.concurrency_manager,
201            cache_status_header: self.cache_status_header,
202        }
203    }
204
205    /// Sets the cache configuration.
206    ///
207    /// Use [`Config::builder()`](hitbox::Config::builder) to create a configuration.
208    pub fn config<NC>(self, configuration: NC) -> CacheMiddlewareBuilder<B, NC, CM> {
209        CacheMiddlewareBuilder {
210            backend: self.backend,
211            configuration,
212            concurrency_manager: self.concurrency_manager,
213            cache_status_header: self.cache_status_header,
214        }
215    }
216
217    /// Sets the concurrency manager for dogpile prevention.
218    ///
219    /// Defaults to [`NoopConcurrencyManager`](hitbox::concurrency::NoopConcurrencyManager) if not called.
220    pub fn concurrency_manager<NCM>(
221        self,
222        concurrency_manager: NCM,
223    ) -> CacheMiddlewareBuilder<B, C, NCM> {
224        CacheMiddlewareBuilder {
225            backend: self.backend,
226            configuration: self.configuration,
227            concurrency_manager,
228            cache_status_header: self.cache_status_header,
229        }
230    }
231
232    /// Sets the header name for cache status.
233    ///
234    /// The cache status header indicates whether a response was served from cache.
235    /// Possible values are `HIT`, `MISS`, or `STALE`.
236    ///
237    /// Defaults to `x-cache-status` if not set.
238    pub fn cache_status_header(self, header_name: HeaderName) -> Self {
239        CacheMiddlewareBuilder {
240            cache_status_header: Some(header_name),
241            ..self
242        }
243    }
244}
245
246impl<B, C, CM> CacheMiddlewareBuilder<Arc<B>, C, CM>
247where
248    B: CacheBackend,
249{
250    /// Builds the cache middleware.
251    ///
252    /// Both [`backend()`](Self::backend) and [`config()`](Self::config) must
253    /// be called before this method.
254    pub fn build(self) -> CacheMiddleware<B, C, CM> {
255        CacheMiddleware {
256            backend: self.backend,
257            configuration: self.configuration,
258            concurrency_manager: self.concurrency_manager,
259            cache_status_header: self
260                .cache_status_header
261                .unwrap_or(DEFAULT_CACHE_STATUS_HEADER),
262        }
263    }
264}
265
266impl CacheMiddlewareBuilder<NotSet, NotSet, NoopConcurrencyManager> {
267    /// Creates a new builder. Equivalent to [`CacheMiddleware::builder()`].
268    pub fn new() -> Self {
269        Self {
270            backend: NotSet,
271            configuration: NotSet,
272            concurrency_manager: NoopConcurrencyManager,
273            cache_status_header: None,
274        }
275    }
276}
277
278impl Default for CacheMiddlewareBuilder<NotSet, NotSet, NoopConcurrencyManager> {
279    fn default() -> Self {
280        Self::new()
281    }
282}