1use std::sync::Arc;
10
11use async_trait::async_trait;
12use hitbox::CacheStatusExt;
13use hitbox::backend::CacheBackend;
14use hitbox::concurrency::{ConcurrencyManager, NoopConcurrencyManager};
15use hitbox::config::CacheConfig;
16use hitbox::fsm::CacheFuture;
17use hitbox_core::DisabledOffload;
18use hitbox_http::{
19 BufferedBody, CacheableHttpRequest, CacheableHttpResponse, DEFAULT_CACHE_STATUS_HEADER,
20};
21use http::Extensions;
22use http::header::HeaderName;
23use reqwest::{Request, Response};
24use reqwest_middleware::{Middleware, Next, Result};
25
26use crate::upstream::{ReqwestUpstream, buffered_body_to_reqwest};
27
/// Type-state marker for a builder slot that has not been filled in yet.
///
/// `CacheMiddlewareBuilder::new` starts with `NotSet` for both the backend and
/// the configuration type parameters; the `backend` / `config` setters replace
/// it with the real types.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NotSet;
30
/// `reqwest_middleware` middleware that routes requests through a hitbox
/// `CacheFuture`, consulting the cache backend before the rest of the chain.
///
/// Construct it with [`CacheMiddleware::builder`] or [`CacheMiddleware::new`].
pub struct CacheMiddleware<B, C, CM> {
    /// Cache storage; held in an `Arc` so clones of the middleware share it.
    backend: Arc<B>,
    /// Source of the request/response predicates, extractors and policy used
    /// for every request.
    configuration: C,
    /// Concurrency manager; a clone is handed to each `CacheFuture`.
    concurrency_manager: CM,
    /// Header name passed to `cache_status` to record the cache status on
    /// every successful response.
    cache_status_header: HeaderName,
}
45
46impl<B, C, CM> CacheMiddleware<B, C, CM> {
47 pub fn new(
52 backend: Arc<B>,
53 configuration: C,
54 concurrency_manager: CM,
55 cache_status_header: HeaderName,
56 ) -> Self {
57 Self {
58 backend,
59 configuration,
60 concurrency_manager,
61 cache_status_header,
62 }
63 }
64}
65
66impl CacheMiddleware<NotSet, NotSet, NoopConcurrencyManager> {
67 pub fn builder() -> CacheMiddlewareBuilder<NotSet, NotSet, NoopConcurrencyManager> {
75 CacheMiddlewareBuilder::new()
76 }
77}
78
79impl<B, C, CM> Clone for CacheMiddleware<B, C, CM>
80where
81 C: Clone,
82 CM: Clone,
83{
84 fn clone(&self) -> Self {
85 Self {
86 backend: self.backend.clone(),
87 configuration: self.configuration.clone(),
88 concurrency_manager: self.concurrency_manager.clone(),
89 cache_status_header: self.cache_status_header.clone(),
90 }
91 }
92}
93
94#[async_trait]
95impl<B, C, CM> Middleware for CacheMiddleware<B, C, CM>
96where
97 B: CacheBackend + Send + Sync + 'static,
98 C: CacheConfig<CacheableHttpRequest<reqwest::Body>, CacheableHttpResponse<reqwest::Body>>
99 + Clone
100 + Send
101 + Sync
102 + 'static,
103 C::RequestPredicate: Clone + Send + Sync + 'static,
104 C::ResponsePredicate: Clone + Send + Sync + 'static,
105 C::Extractor: Clone + Send + Sync + 'static,
106 CM: ConcurrencyManager<Result<CacheableHttpResponse<reqwest::Body>>>
107 + Clone
108 + Send
109 + Sync
110 + 'static,
111{
112 async fn handle(
113 &self,
114 req: Request,
115 extensions: &mut Extensions,
116 next: Next<'_>,
117 ) -> Result<Response> {
118 let http_request: http::Request<reqwest::Body> = req
120 .try_into()
121 .map_err(|e: reqwest::Error| reqwest_middleware::Error::Reqwest(e))?;
122
123 let (parts, body) = http_request.into_parts();
125 let buffered_request = http::Request::from_parts(parts, BufferedBody::Passthrough(body));
126 let cacheable_req = CacheableHttpRequest::from_request(buffered_request);
127
128 let upstream = ReqwestUpstream::new(next.clone(), extensions.clone());
130
131 let cache_future: CacheFuture<
134 '_,
135 B,
136 CacheableHttpRequest<reqwest::Body>,
137 Result<CacheableHttpResponse<reqwest::Body>>,
138 ReqwestUpstream<'_>,
139 C::RequestPredicate,
140 C::ResponsePredicate,
141 C::Extractor,
142 CM,
143 DisabledOffload,
144 > = CacheFuture::new(
145 self.backend.clone(),
146 cacheable_req,
147 upstream,
148 self.configuration.request_predicates(),
149 self.configuration.response_predicates(),
150 self.configuration.extractors(),
151 Arc::new(self.configuration.policy().clone()),
152 DisabledOffload,
153 self.concurrency_manager.clone(),
154 );
155
156 let (response, cache_context) = cache_future.await;
158
159 let mut cacheable_response = response?;
161
162 cacheable_response.cache_status(cache_context.status, &self.cache_status_header);
164
165 let http_response = cacheable_response.into_response();
166 let (parts, buffered_body) = http_response.into_parts();
167
168 let body = buffered_body_to_reqwest(buffered_body);
170 let http_response = http::Response::from_parts(parts, body);
171
172 Ok(http_response.into())
174 }
175}
176
177pub struct CacheMiddlewareBuilder<B, C, CM> {
185 backend: B,
186 configuration: C,
187 concurrency_manager: CM,
188 cache_status_header: Option<HeaderName>,
189}
190
191impl<B, C, CM> CacheMiddlewareBuilder<B, C, CM> {
192 pub fn backend<NB>(self, backend: NB) -> CacheMiddlewareBuilder<Arc<NB>, C, CM>
194 where
195 NB: CacheBackend,
196 {
197 CacheMiddlewareBuilder {
198 backend: Arc::new(backend),
199 configuration: self.configuration,
200 concurrency_manager: self.concurrency_manager,
201 cache_status_header: self.cache_status_header,
202 }
203 }
204
205 pub fn config<NC>(self, configuration: NC) -> CacheMiddlewareBuilder<B, NC, CM> {
209 CacheMiddlewareBuilder {
210 backend: self.backend,
211 configuration,
212 concurrency_manager: self.concurrency_manager,
213 cache_status_header: self.cache_status_header,
214 }
215 }
216
217 pub fn concurrency_manager<NCM>(
221 self,
222 concurrency_manager: NCM,
223 ) -> CacheMiddlewareBuilder<B, C, NCM> {
224 CacheMiddlewareBuilder {
225 backend: self.backend,
226 configuration: self.configuration,
227 concurrency_manager,
228 cache_status_header: self.cache_status_header,
229 }
230 }
231
232 pub fn cache_status_header(self, header_name: HeaderName) -> Self {
239 CacheMiddlewareBuilder {
240 cache_status_header: Some(header_name),
241 ..self
242 }
243 }
244}
245
246impl<B, C, CM> CacheMiddlewareBuilder<Arc<B>, C, CM>
247where
248 B: CacheBackend,
249{
250 pub fn build(self) -> CacheMiddleware<B, C, CM> {
255 CacheMiddleware {
256 backend: self.backend,
257 configuration: self.configuration,
258 concurrency_manager: self.concurrency_manager,
259 cache_status_header: self
260 .cache_status_header
261 .unwrap_or(DEFAULT_CACHE_STATUS_HEADER),
262 }
263 }
264}
265
266impl CacheMiddlewareBuilder<NotSet, NotSet, NoopConcurrencyManager> {
267 pub fn new() -> Self {
269 Self {
270 backend: NotSet,
271 configuration: NotSet,
272 concurrency_manager: NoopConcurrencyManager,
273 cache_status_header: None,
274 }
275 }
276}
277
278impl Default for CacheMiddlewareBuilder<NotSet, NotSet, NoopConcurrencyManager> {
279 fn default() -> Self {
280 Self::new()
281 }
282}