reqwest_middleware_cache/
lib.rs

1//! A caching middleware for Reqwest that follows HTTP caching rules.
2//! By default it uses [`cacache`](https://github.com/zkat/cacache-rs) as the backend cache manager.
3//!
4//! ## Example
5//!
6//! ```no_run
7//! use reqwest::Client;
8//! use reqwest_middleware::{ClientBuilder, Result};
9//! use reqwest_middleware_cache::{managers::CACacheManager, Cache, CacheMode};
10//!
11//! #[tokio::main]
12//! async fn main() -> Result<()> {
13//!     let client = ClientBuilder::new(Client::new())
14//!         .with(Cache {
15//!             mode: CacheMode::Default,
16//!             cache_manager: CACacheManager::default(),
17//!         })
18//!         .build();
19//!     client
20//!         .get("https://developer.mozilla.org/en-US/docs/Web/HTTP/Caching")
21//!         .send()
22//!         .await?;
23//!     Ok(())
24//! }
25//! ```
26#![forbid(unsafe_code, future_incompatible)]
27#![deny(
28    missing_docs,
29    missing_debug_implementations,
30    missing_copy_implementations,
31    nonstandard_style,
32    unused_qualifications,
33    rustdoc::missing_doc_code_examples
34)]
35use std::time::SystemTime;
36
37use anyhow::{anyhow, Result};
38use http::{header::CACHE_CONTROL, HeaderValue, Method};
39use http_cache_semantics::{AfterResponse, BeforeRequest, CachePolicy};
40use reqwest::{Request, Response};
41use reqwest_middleware::{Error, Middleware, Next};
42use task_local_extensions::Extensions;
43
/// Backend cache managers; the cacache-backed `CACacheManager` is the default.
pub mod managers;
46
47/// A trait providing methods for storing, reading, and removing cache records.
48#[async_trait::async_trait]
49pub trait CacheManager {
50    /// Attempts to pull a cached reponse and related policy from cache.
51    async fn get(&self, req: &Request) -> Result<Option<(Response, CachePolicy)>>;
52    /// Attempts to cache a response and related policy.
53    async fn put(&self, req: &Request, res: Response, policy: CachePolicy) -> Result<Response>;
54    /// Attempts to remove a record from cache.
55    async fn delete(&self, req: &Request) -> Result<()>;
56}
57
/// Controls how [`Cache`] consults and updates the HTTP cache.
///
/// Similar to [make-fetch-happen cache options](https://github.com/npm/make-fetch-happen#--optscache).
/// Passed in when the [`Cache`] struct is being built.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheMode {
    /// Will inspect the HTTP cache on the way to the network.
    /// If there is a fresh response it will be used.
    /// If there is a stale response a conditional request will be created,
    /// and a normal request otherwise.
    /// It then updates the HTTP cache with the response.
    /// If the revalidation request fails (for example, on a 500 or if you're offline),
    /// the stale response will be returned.
    Default,
    /// Behaves as if there is no HTTP cache at all: responses are never read from
    /// or stored in the cache.
    NoStore,
    /// Behaves as if there is no HTTP cache on the way to the network.
    /// Ergo, it creates a normal request and updates the HTTP cache with the response.
    Reload,
    /// Creates a conditional request if there is a response in the HTTP cache
    /// (the outgoing request is sent with `Cache-Control: no-cache`)
    /// and a normal request otherwise. It then updates the HTTP cache with the response.
    NoCache,
    /// Uses any response in the HTTP cache matching the request,
    /// not paying attention to staleness; a cached response is returned with a
    /// `112 Disconnected operation` `Warning` header. If there was no response,
    /// it creates a normal request and updates the HTTP cache with the response.
    ForceCache,
    /// Uses any response in the HTTP cache matching the request,
    /// not paying attention to staleness; a cached response is returned with a
    /// `112 Disconnected operation` `Warning` header. If there was no response,
    /// no network request is made: a synthesized `504 Gateway Timeout` response
    /// with an empty body is returned (not an error).
    OnlyIfCached,
}
89
90/// Caches requests according to http spec
91#[derive(Debug, Clone)]
92pub struct Cache<T: CacheManager + Send + Sync + 'static> {
93    /// Determines the manager behavior
94    pub mode: CacheMode,
95    /// Manager instance that implements the CacheManager trait
96    pub cache_manager: T,
97}
98
99impl<T: CacheManager + Send + Sync + 'static> Cache<T> {
100    /// Called by the Reqwest middleware handle method when a request is made.
101    pub async fn run<'a>(
102        &'a self,
103        mut req: Request,
104        next: Next<'a>,
105        extensions: &mut Extensions,
106    ) -> Result<Response> {
107        let is_cacheable = (req.method() == Method::GET || req.method() == Method::HEAD)
108            && self.mode != CacheMode::NoStore
109            && self.mode != CacheMode::Reload;
110
111        if !is_cacheable {
112            return self.remote_fetch(req, next, extensions).await;
113        }
114
115        if let Some(store) = self.cache_manager.get(&req).await? {
116            let (mut res, policy) = store;
117            if let Some(warning_code) = get_warning_code(&res) {
118                // https://tools.ietf.org/html/rfc7234#section-4.3.4
119                //
120                // If a stored response is selected for update, the cache MUST:
121                //
122                // * delete any Warning header fields in the stored response with
123                //   warn-code 1xx (see Section 5.5);
124                //
125                // * retain any Warning header fields in the stored response with
126                //   warn-code 2xx;
127                //
128                #[allow(clippy::manual_range_contains)]
129                if warning_code >= 100 && warning_code < 200 {
130                    res.headers_mut().remove(reqwest::header::WARNING);
131                }
132            }
133
134            match self.mode {
135                CacheMode::Default => Ok(self
136                    .conditional_fetch(req, res, policy, next, extensions)
137                    .await?),
138                CacheMode::NoCache => {
139                    req.headers_mut()
140                        .insert(CACHE_CONTROL, HeaderValue::from_str("no-cache")?);
141                    Ok(self
142                        .conditional_fetch(req, res, policy, next, extensions)
143                        .await?)
144                }
145                CacheMode::ForceCache | CacheMode::OnlyIfCached => {
146                    //   112 Disconnected operation
147                    // SHOULD be included if the cache is intentionally disconnected from
148                    // the rest of the network for a period of time.
149                    // (https://tools.ietf.org/html/rfc2616#section-14.46)
150                    add_warning(&mut res, req.url(), 112, "Disconnected operation");
151                    Ok(res)
152                }
153                _ => Ok(self.remote_fetch(req, next, extensions).await?),
154            }
155        } else {
156            match self.mode {
157                CacheMode::OnlyIfCached => {
158                    // ENOTCACHED
159                    let err_res = http::Response::builder()
160                        .status(http::StatusCode::GATEWAY_TIMEOUT)
161                        .body("")?;
162                    Ok(err_res.into())
163                }
164                _ => Ok(self.remote_fetch(req, next, extensions).await?),
165            }
166        }
167    }
168
169    async fn conditional_fetch<'a>(
170        &self,
171        mut req: Request,
172        mut cached_res: Response,
173        mut policy: CachePolicy,
174        next: Next<'_>,
175        extensions: &mut Extensions,
176    ) -> Result<Response> {
177        let before_req = policy.before_request(&req, SystemTime::now());
178        match before_req {
179            BeforeRequest::Fresh(parts) => {
180                update_response_headers(parts, &mut cached_res);
181                return Ok(cached_res);
182            }
183            BeforeRequest::Stale {
184                request: parts,
185                matches,
186            } => {
187                if matches {
188                    update_request_headers(parts, &mut req);
189                }
190            }
191        }
192        let copied_req = req.try_clone().ok_or_else(|| {
193            Error::Middleware(anyhow!(
194                "Request object is not cloneable. Are you passing a streaming body?".to_string()
195            ))
196        })?;
197        match self.remote_fetch(req, next, extensions).await {
198            Ok(cond_res) => {
199                if cond_res.status().is_server_error() && must_revalidate(&cached_res) {
200                    //   111 Revalidation failed
201                    //   MUST be included if a cache returns a stale response
202                    //   because an attempt to revalidate the response failed,
203                    //   due to an inability to reach the server.
204                    // (https://tools.ietf.org/html/rfc2616#section-14.46)
205                    add_warning(
206                        &mut cached_res,
207                        copied_req.url(),
208                        111,
209                        "Revalidation failed",
210                    );
211                    Ok(cached_res)
212                } else if cond_res.status() == http::StatusCode::NOT_MODIFIED {
213                    let mut res = http::Response::builder()
214                        .status(cond_res.status())
215                        .body(cached_res.text().await?)?;
216                    for (key, value) in cond_res.headers() {
217                        res.headers_mut().append(key, value.clone());
218                    }
219                    let mut converted = Response::from(res);
220                    let after_res =
221                        policy.after_response(&copied_req, &cond_res, SystemTime::now());
222                    match after_res {
223                        AfterResponse::Modified(new_policy, parts) => {
224                            policy = new_policy;
225                            update_response_headers(parts, &mut converted);
226                        }
227                        AfterResponse::NotModified(new_policy, parts) => {
228                            policy = new_policy;
229                            update_response_headers(parts, &mut converted);
230                        }
231                    }
232                    let res = self
233                        .cache_manager
234                        .put(&copied_req, converted, policy)
235                        .await?;
236                    Ok(res)
237                } else {
238                    Ok(cached_res)
239                }
240            }
241            Err(e) => {
242                if must_revalidate(&cached_res) {
243                    Err(e)
244                } else {
245                    //   111 Revalidation failed
246                    //   MUST be included if a cache returns a stale response
247                    //   because an attempt to revalidate the response failed,
248                    //   due to an inability to reach the server.
249                    // (https://tools.ietf.org/html/rfc2616#section-14.46)
250                    add_warning(
251                        &mut cached_res,
252                        copied_req.url(),
253                        111,
254                        "Revalidation failed",
255                    );
256                    //   199 Miscellaneous warning
257                    //   The warning text MAY include arbitrary information to
258                    //   be presented to a human user, or logged. A system
259                    //   receiving this warning MUST NOT take any automated
260                    //   action, besides presenting the warning to the user.
261                    // (https://tools.ietf.org/html/rfc2616#section-14.46)
262                    add_warning(
263                        &mut cached_res,
264                        copied_req.url(),
265                        199,
266                        format!("Miscellaneous Warning {}", e).as_str(),
267                    );
268                    Ok(cached_res)
269                }
270            }
271        }
272    }
273
274    async fn remote_fetch<'a>(
275        &'a self,
276        req: Request,
277        next: Next<'a>,
278        extensions: &mut Extensions,
279    ) -> Result<Response> {
280        let copied_req = req.try_clone().ok_or_else(|| {
281            Error::Middleware(anyhow!(
282                "Request object is not clonable. Are you passing a streaming body?".to_string()
283            ))
284        })?;
285        let res = next.run(req, extensions).await?;
286        let is_method_get_head =
287            copied_req.method() == Method::GET || copied_req.method() == Method::HEAD;
288        let policy = CachePolicy::new(&copied_req, &res);
289        let is_cacheable = self.mode != CacheMode::NoStore
290            && is_method_get_head
291            && res.status() == http::StatusCode::OK
292            && policy.is_storable();
293        if is_cacheable {
294            Ok(self.cache_manager.put(&copied_req, res, policy).await?)
295        } else if !is_method_get_head {
296            self.cache_manager.delete(&copied_req).await?;
297            Ok(res)
298        } else {
299            Ok(res)
300        }
301    }
302}
303
304fn must_revalidate(res: &Response) -> bool {
305    if let Some(val) = res.headers().get(CACHE_CONTROL.as_str()) {
306        val.to_str()
307            .expect("Unable to convert header value to string")
308            .to_lowercase()
309            .contains("must-revalidate")
310    } else {
311        false
312    }
313}
314
315fn get_warning_code(res: &Response) -> Option<usize> {
316    res.headers().get(reqwest::header::WARNING).and_then(|hdr| {
317        hdr.to_str()
318            .expect("Unable to convert warning to string")
319            .chars()
320            .take(3)
321            .collect::<String>()
322            .parse()
323            .ok()
324    })
325}
326
327fn update_request_headers(parts: http::request::Parts, req: &mut Request) {
328    let headers = parts.headers;
329    for header in headers.iter() {
330        req.headers_mut().insert(header.0.clone(), header.1.clone());
331    }
332}
333
334fn update_response_headers(parts: http::response::Parts, res: &mut Response) {
335    for header in parts.headers.iter() {
336        res.headers_mut().insert(header.0.clone(), header.1.clone());
337    }
338}
339
340fn add_warning(res: &mut Response, uri: &reqwest::Url, code: usize, message: &str) {
341    //   Warning    = "Warning" ":" 1#warning-value
342    // warning-value = warn-code SP warn-agent SP warn-text [SP warn-date]
343    // warn-code  = 3DIGIT
344    // warn-agent = ( host [ ":" port ] ) | pseudonym
345    //                 ; the name or pseudonym of the server adding
346    //                 ; the Warning header, for use in debugging
347    // warn-text  = quoted-string
348    // warn-date  = <"> HTTP-date <">
349    // (https://tools.ietf.org/html/rfc2616#section-14.46)
350    //
351    let val = HeaderValue::from_str(
352        format!(
353            "{} {} {:?} \"{}\"",
354            code,
355            uri.host().expect("Invalid URL"),
356            message,
357            httpdate::fmt_http_date(SystemTime::now())
358        )
359        .as_str(),
360    )
361    .expect("Failed to generate warning string");
362    res.headers_mut().append(reqwest::header::WARNING, val);
363}
364
365#[async_trait::async_trait]
366impl<T: CacheManager + 'static + Send + Sync> Middleware for Cache<T> {
367    async fn handle(
368        &self,
369        req: Request,
370        extensions: &mut Extensions,
371        next: Next<'_>,
372    ) -> reqwest_middleware::Result<Response> {
373        let res = self.run(req, next, extensions).await?;
374        Ok(res)
375    }
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use http::{HeaderValue, Response};
    use std::str::FromStr;

    // These helpers are synchronous, so plain `#[test]` is enough; no async runtime.

    #[test]
    fn can_get_warning_code() -> Result<()> {
        let url = reqwest::Url::from_str("https://example.com")?;
        let mut res = reqwest::Response::from(Response::new(""));
        add_warning(&mut res, &url, 111, "Revalidation failed");
        assert_eq!(get_warning_code(&res), Some(111));
        Ok(())
    }

    #[test]
    fn warning_value_has_rfc_shape() -> Result<()> {
        let url = reqwest::Url::from_str("https://example.com/path")?;
        let mut res = reqwest::Response::from(Response::new(""));
        add_warning(&mut res, &url, 112, "Disconnected operation");
        let value = res
            .headers()
            .get(reqwest::header::WARNING)
            .expect("add_warning should append a Warning header")
            .to_str()?;
        assert!(
            value.starts_with("112 example.com \"Disconnected operation\" \""),
            "unexpected Warning value: {}",
            value
        );
        Ok(())
    }

    #[test]
    fn missing_warning_yields_no_code() {
        let res = reqwest::Response::from(Response::new(""));
        assert_eq!(get_warning_code(&res), None);
    }

    #[test]
    fn can_check_revalidate() -> Result<()> {
        let mut res = Response::new("");
        res.headers_mut().append(
            "Cache-Control",
            HeaderValue::from_str("max-age=1733992, must-revalidate")?,
        );
        assert!(
            must_revalidate(&res.into()),
            "must-revalidate directive was not detected"
        );
        Ok(())
    }

    #[test]
    fn revalidate_is_false_without_directive() -> Result<()> {
        let bare = reqwest::Response::from(Response::new(""));
        assert!(!must_revalidate(&bare));

        let mut res = Response::new("");
        res.headers_mut()
            .append("Cache-Control", HeaderValue::from_str("max-age=60")?);
        assert!(!must_revalidate(&res.into()));
        Ok(())
    }
}