reqwest_middleware_cache/
lib.rs1#![forbid(unsafe_code, future_incompatible)]
27#![deny(
28 missing_docs,
29 missing_debug_implementations,
30 missing_copy_implementations,
31 nonstandard_style,
32 unused_qualifications,
33 rustdoc::missing_doc_code_examples
34)]
35use std::time::SystemTime;
36
37use anyhow::{anyhow, Result};
38use http::{header::CACHE_CONTROL, HeaderValue, Method};
39use http_cache_semantics::{AfterResponse, BeforeRequest, CachePolicy};
40use reqwest::{Request, Response};
41use reqwest_middleware::{Error, Middleware, Next};
42use task_local_extensions::Extensions;
43
44pub mod managers;
46
47#[async_trait::async_trait]
49pub trait CacheManager {
50 async fn get(&self, req: &Request) -> Result<Option<(Response, CachePolicy)>>;
52 async fn put(&self, req: &Request, res: Response, policy: CachePolicy) -> Result<Response>;
54 async fn delete(&self, req: &Request) -> Result<()>;
56}
57
/// Controls how [`Cache`] consults and updates its store.
///
/// The variant names mirror the request cache modes of the WHATWG Fetch standard.
/// Only `GET` and `HEAD` requests are ever served from or written to the store.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheMode {
    /// Serve a stored response that is still fresh, and revalidate a stale one
    /// with the origin first. On a miss, fetch from the network and store the
    /// response if it is storable.
    Default,
    /// Bypass the store: never look up a stored response and never store the
    /// network response.
    NoStore,
    /// Skip the lookup and always go to the network, but still store the
    /// response if it is storable.
    Reload,
    /// When a response is stored, always revalidate it with the origin by sending
    /// the request with `Cache-Control: no-cache`. On a miss, fetch and store.
    NoCache,
    /// Serve any stored response regardless of its freshness, tagged with a
    /// `112 Disconnected operation` warning. On a miss, fetch and store.
    ForceCache,
    /// Serve any stored response regardless of its freshness, tagged with a
    /// `112 Disconnected operation` warning. On a miss, answer with a synthetic
    /// `504 Gateway Timeout` instead of using the network.
    OnlyIfCached,
}
89
90#[derive(Debug, Clone)]
92pub struct Cache<T: CacheManager + Send + Sync + 'static> {
93 pub mode: CacheMode,
95 pub cache_manager: T,
97}
98
99impl<T: CacheManager + Send + Sync + 'static> Cache<T> {
100 pub async fn run<'a>(
102 &'a self,
103 mut req: Request,
104 next: Next<'a>,
105 extensions: &mut Extensions,
106 ) -> Result<Response> {
107 let is_cacheable = (req.method() == Method::GET || req.method() == Method::HEAD)
108 && self.mode != CacheMode::NoStore
109 && self.mode != CacheMode::Reload;
110
111 if !is_cacheable {
112 return self.remote_fetch(req, next, extensions).await;
113 }
114
115 if let Some(store) = self.cache_manager.get(&req).await? {
116 let (mut res, policy) = store;
117 if let Some(warning_code) = get_warning_code(&res) {
118 #[allow(clippy::manual_range_contains)]
129 if warning_code >= 100 && warning_code < 200 {
130 res.headers_mut().remove(reqwest::header::WARNING);
131 }
132 }
133
134 match self.mode {
135 CacheMode::Default => Ok(self
136 .conditional_fetch(req, res, policy, next, extensions)
137 .await?),
138 CacheMode::NoCache => {
139 req.headers_mut()
140 .insert(CACHE_CONTROL, HeaderValue::from_str("no-cache")?);
141 Ok(self
142 .conditional_fetch(req, res, policy, next, extensions)
143 .await?)
144 }
145 CacheMode::ForceCache | CacheMode::OnlyIfCached => {
146 add_warning(&mut res, req.url(), 112, "Disconnected operation");
151 Ok(res)
152 }
153 _ => Ok(self.remote_fetch(req, next, extensions).await?),
154 }
155 } else {
156 match self.mode {
157 CacheMode::OnlyIfCached => {
158 let err_res = http::Response::builder()
160 .status(http::StatusCode::GATEWAY_TIMEOUT)
161 .body("")?;
162 Ok(err_res.into())
163 }
164 _ => Ok(self.remote_fetch(req, next, extensions).await?),
165 }
166 }
167 }
168
169 async fn conditional_fetch<'a>(
170 &self,
171 mut req: Request,
172 mut cached_res: Response,
173 mut policy: CachePolicy,
174 next: Next<'_>,
175 extensions: &mut Extensions,
176 ) -> Result<Response> {
177 let before_req = policy.before_request(&req, SystemTime::now());
178 match before_req {
179 BeforeRequest::Fresh(parts) => {
180 update_response_headers(parts, &mut cached_res);
181 return Ok(cached_res);
182 }
183 BeforeRequest::Stale {
184 request: parts,
185 matches,
186 } => {
187 if matches {
188 update_request_headers(parts, &mut req);
189 }
190 }
191 }
192 let copied_req = req.try_clone().ok_or_else(|| {
193 Error::Middleware(anyhow!(
194 "Request object is not cloneable. Are you passing a streaming body?".to_string()
195 ))
196 })?;
197 match self.remote_fetch(req, next, extensions).await {
198 Ok(cond_res) => {
199 if cond_res.status().is_server_error() && must_revalidate(&cached_res) {
200 add_warning(
206 &mut cached_res,
207 copied_req.url(),
208 111,
209 "Revalidation failed",
210 );
211 Ok(cached_res)
212 } else if cond_res.status() == http::StatusCode::NOT_MODIFIED {
213 let mut res = http::Response::builder()
214 .status(cond_res.status())
215 .body(cached_res.text().await?)?;
216 for (key, value) in cond_res.headers() {
217 res.headers_mut().append(key, value.clone());
218 }
219 let mut converted = Response::from(res);
220 let after_res =
221 policy.after_response(&copied_req, &cond_res, SystemTime::now());
222 match after_res {
223 AfterResponse::Modified(new_policy, parts) => {
224 policy = new_policy;
225 update_response_headers(parts, &mut converted);
226 }
227 AfterResponse::NotModified(new_policy, parts) => {
228 policy = new_policy;
229 update_response_headers(parts, &mut converted);
230 }
231 }
232 let res = self
233 .cache_manager
234 .put(&copied_req, converted, policy)
235 .await?;
236 Ok(res)
237 } else {
238 Ok(cached_res)
239 }
240 }
241 Err(e) => {
242 if must_revalidate(&cached_res) {
243 Err(e)
244 } else {
245 add_warning(
251 &mut cached_res,
252 copied_req.url(),
253 111,
254 "Revalidation failed",
255 );
256 add_warning(
263 &mut cached_res,
264 copied_req.url(),
265 199,
266 format!("Miscellaneous Warning {}", e).as_str(),
267 );
268 Ok(cached_res)
269 }
270 }
271 }
272 }
273
274 async fn remote_fetch<'a>(
275 &'a self,
276 req: Request,
277 next: Next<'a>,
278 extensions: &mut Extensions,
279 ) -> Result<Response> {
280 let copied_req = req.try_clone().ok_or_else(|| {
281 Error::Middleware(anyhow!(
282 "Request object is not clonable. Are you passing a streaming body?".to_string()
283 ))
284 })?;
285 let res = next.run(req, extensions).await?;
286 let is_method_get_head =
287 copied_req.method() == Method::GET || copied_req.method() == Method::HEAD;
288 let policy = CachePolicy::new(&copied_req, &res);
289 let is_cacheable = self.mode != CacheMode::NoStore
290 && is_method_get_head
291 && res.status() == http::StatusCode::OK
292 && policy.is_storable();
293 if is_cacheable {
294 Ok(self.cache_manager.put(&copied_req, res, policy).await?)
295 } else if !is_method_get_head {
296 self.cache_manager.delete(&copied_req).await?;
297 Ok(res)
298 } else {
299 Ok(res)
300 }
301 }
302}
303
304fn must_revalidate(res: &Response) -> bool {
305 if let Some(val) = res.headers().get(CACHE_CONTROL.as_str()) {
306 val.to_str()
307 .expect("Unable to convert header value to string")
308 .to_lowercase()
309 .contains("must-revalidate")
310 } else {
311 false
312 }
313}
314
315fn get_warning_code(res: &Response) -> Option<usize> {
316 res.headers().get(reqwest::header::WARNING).and_then(|hdr| {
317 hdr.to_str()
318 .expect("Unable to convert warning to string")
319 .chars()
320 .take(3)
321 .collect::<String>()
322 .parse()
323 .ok()
324 })
325}
326
327fn update_request_headers(parts: http::request::Parts, req: &mut Request) {
328 let headers = parts.headers;
329 for header in headers.iter() {
330 req.headers_mut().insert(header.0.clone(), header.1.clone());
331 }
332}
333
334fn update_response_headers(parts: http::response::Parts, res: &mut Response) {
335 for header in parts.headers.iter() {
336 res.headers_mut().insert(header.0.clone(), header.1.clone());
337 }
338}
339
340fn add_warning(res: &mut Response, uri: &reqwest::Url, code: usize, message: &str) {
341 let val = HeaderValue::from_str(
352 format!(
353 "{} {} {:?} \"{}\"",
354 code,
355 uri.host().expect("Invalid URL"),
356 message,
357 httpdate::fmt_http_date(SystemTime::now())
358 )
359 .as_str(),
360 )
361 .expect("Failed to generate warning string");
362 res.headers_mut().append(reqwest::header::WARNING, val);
363}
364
365#[async_trait::async_trait]
366impl<T: CacheManager + 'static + Send + Sync> Middleware for Cache<T> {
367 async fn handle(
368 &self,
369 req: Request,
370 extensions: &mut Extensions,
371 next: Next<'_>,
372 ) -> reqwest_middleware::Result<Response> {
373 let res = self.run(req, next, extensions).await?;
374 Ok(res)
375 }
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use http::{HeaderValue, Response};
    use std::str::FromStr;

    #[tokio::test]
    async fn can_get_warning_code() -> Result<()> {
        let url = reqwest::Url::from_str("https://example.com")?;
        let mut res = reqwest::Response::from(Response::new(""));
        add_warning(&mut res, &url, 111, "Revalidation failed");
        assert_eq!(get_warning_code(&res), Some(111));
        Ok(())
    }

    #[tokio::test]
    async fn warning_code_is_none_without_warning_header() -> Result<()> {
        let res = reqwest::Response::from(Response::new(""));
        assert_eq!(get_warning_code(&res), None);
        Ok(())
    }

    #[tokio::test]
    async fn can_check_revalidate() -> Result<()> {
        let mut res = Response::new("");
        res.headers_mut().append(
            "Cache-Control",
            HeaderValue::from_str("max-age=1733992, must-revalidate")?,
        );
        assert!(
            must_revalidate(&res.into()),
            "must-revalidate directive should be detected"
        );
        Ok(())
    }

    #[tokio::test]
    async fn revalidate_is_false_without_directive() -> Result<()> {
        let mut res = Response::new("");
        res.headers_mut()
            .append("Cache-Control", HeaderValue::from_str("max-age=60")?);
        assert!(
            !must_revalidate(&res.into()),
            "no must-revalidate directive is present"
        );
        Ok(())
    }
}