http_cache_reqwest/
lib.rs

1#![forbid(unsafe_code, future_incompatible)]
2#![deny(
3    missing_docs,
4    missing_debug_implementations,
5    missing_copy_implementations,
6    nonstandard_style,
7    unused_qualifications,
8    unused_import_braces,
9    unused_extern_crates,
10    trivial_casts,
11    trivial_numeric_casts
12)]
13#![allow(clippy::doc_lazy_continuation)]
14#![cfg_attr(docsrs, feature(doc_cfg))]
15//! The reqwest middleware implementation for http-cache.
16//! ```no_run
17//! use reqwest::Client;
18//! use reqwest_middleware::{ClientBuilder, Result};
19//! use http_cache_reqwest::{Cache, CacheMode, CACacheManager, HttpCache, HttpCacheOptions};
20//!
21//! #[tokio::main]
22//! async fn main() -> Result<()> {
23//!     let client = ClientBuilder::new(Client::new())
24//!         .with(Cache(HttpCache {
25//!             mode: CacheMode::Default,
26//!             manager: CACacheManager::default(),
27//!             options: HttpCacheOptions::default(),
28//!         }))
29//!         .build();
30//!     client
31//!         .get("https://developer.mozilla.org/en-US/docs/Web/HTTP/Caching")
32//!         .send()
33//!         .await?;
34//!     Ok(())
35//! }
36//! ```
37//!
38//! ## Overriding the cache mode
39//!
40//! The cache mode can be overridden on a per-request basis by making use of the
41//! `reqwest-middleware` extensions system.
42//!
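//! ```no_run
//! # use reqwest::Client;
//! # use reqwest_middleware::{ClientBuilder, Result};
//! # use http_cache_reqwest::{Cache, CacheMode, CACacheManager, HttpCache, HttpCacheOptions};
//! # #[tokio::main]
//! # async fn main() -> Result<()> {
//! # let client = ClientBuilder::new(Client::new())
//! #     .with(Cache(HttpCache {
//! #         mode: CacheMode::Default,
//! #         manager: CACacheManager::default(),
//! #         options: HttpCacheOptions::default(),
//! #     }))
//! #     .build();
//! client.get("https://developer.mozilla.org/en-US/docs/Web/HTTP/Caching")
//!     .with_extension(CacheMode::OnlyIfCached)
//!     .send()
//!     .await?;
//! # Ok(())
//! # }
//! ```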
49mod error;
50
51use anyhow::anyhow;
52
53pub use error::BadRequest;
54
55use std::{
56    collections::HashMap,
57    convert::{TryFrom, TryInto},
58    str::FromStr,
59    time::SystemTime,
60};
61
62pub use http::request::Parts;
63use http::{
64    header::{HeaderName, CACHE_CONTROL},
65    Extensions, HeaderValue, Method,
66};
67use http_cache::{
68    BoxError, HitOrMiss, Middleware, Result, XCACHE, XCACHELOOKUP,
69};
70use http_cache_semantics::CachePolicy;
71use reqwest::{Request, Response, ResponseBuilderExt};
72use reqwest_middleware::{Error, Next};
73use url::Url;
74
75pub use http_cache::{
76    CacheManager, CacheMode, CacheOptions, HttpCache, HttpCacheOptions,
77    HttpResponse,
78};
79
80#[cfg(feature = "manager-cacache")]
81#[cfg_attr(docsrs, doc(cfg(feature = "manager-cacache")))]
82pub use http_cache::CACacheManager;
83
84#[cfg(feature = "manager-moka")]
85#[cfg_attr(docsrs, doc(cfg(feature = "manager-moka")))]
86pub use http_cache::{MokaCache, MokaCacheBuilder, MokaManager};
87
88/// Wrapper for [`HttpCache`]
89#[derive(Debug)]
90pub struct Cache<T: CacheManager>(pub HttpCache<T>);
91
92/// Implements ['Middleware'] for reqwest
93pub(crate) struct ReqwestMiddleware<'a> {
94    pub req: Request,
95    pub next: Next<'a>,
96    pub extensions: &'a mut Extensions,
97}
98
99fn clone_req(request: &Request) -> std::result::Result<Request, Error> {
100    match request.try_clone() {
101        Some(r) => Ok(r),
102        None => Err(Error::Middleware(anyhow!(BadRequest))),
103    }
104}
105
106#[async_trait::async_trait]
107impl Middleware for ReqwestMiddleware<'_> {
108    fn overridden_cache_mode(&self) -> Option<CacheMode> {
109        self.extensions.get().cloned()
110    }
111    fn is_method_get_head(&self) -> bool {
112        self.req.method() == Method::GET || self.req.method() == Method::HEAD
113    }
114    fn policy(&self, response: &HttpResponse) -> Result<CachePolicy> {
115        Ok(CachePolicy::new(&self.parts()?, &response.parts()?))
116    }
117    fn policy_with_options(
118        &self,
119        response: &HttpResponse,
120        options: CacheOptions,
121    ) -> Result<CachePolicy> {
122        Ok(CachePolicy::new_options(
123            &self.parts()?,
124            &response.parts()?,
125            SystemTime::now(),
126            options,
127        ))
128    }
129    fn update_headers(&mut self, parts: &Parts) -> Result<()> {
130        for header in parts.headers.iter() {
131            self.req.headers_mut().insert(header.0.clone(), header.1.clone());
132        }
133        Ok(())
134    }
135    fn force_no_cache(&mut self) -> Result<()> {
136        self.req
137            .headers_mut()
138            .insert(CACHE_CONTROL, HeaderValue::from_str("no-cache")?);
139        Ok(())
140    }
141    fn parts(&self) -> Result<Parts> {
142        let copied_req = clone_req(&self.req)?;
143        let converted = match http::Request::try_from(copied_req) {
144            Ok(r) => r,
145            Err(e) => return Err(Box::new(e)),
146        };
147        Ok(converted.into_parts().0)
148    }
149    fn url(&self) -> Result<Url> {
150        Ok(self.req.url().clone())
151    }
152    fn method(&self) -> Result<String> {
153        Ok(self.req.method().as_ref().to_string())
154    }
155    async fn remote_fetch(&mut self) -> Result<HttpResponse> {
156        let copied_req = clone_req(&self.req)?;
157        let res = match self.next.clone().run(copied_req, self.extensions).await
158        {
159            Ok(r) => r,
160            Err(e) => return Err(Box::new(e)),
161        };
162        let mut headers = HashMap::new();
163        for header in res.headers() {
164            headers.insert(
165                header.0.as_str().to_owned(),
166                header.1.to_str()?.to_owned(),
167            );
168        }
169        let url = res.url().clone();
170        let status = res.status().into();
171        let version = res.version();
172        let body: Vec<u8> = match res.bytes().await {
173            Ok(b) => b,
174            Err(e) => return Err(Box::new(e)),
175        }
176        .to_vec();
177        Ok(HttpResponse {
178            body,
179            headers,
180            status,
181            url,
182            version: version.try_into()?,
183        })
184    }
185}
186
187// Converts an [`HttpResponse`] to a reqwest [`Response`]
188fn convert_response(response: HttpResponse) -> anyhow::Result<Response> {
189    let mut ret_res = http::Response::builder()
190        .status(response.status)
191        .url(response.url)
192        .version(response.version.into())
193        .body(response.body)?;
194    for header in response.headers {
195        ret_res.headers_mut().insert(
196            HeaderName::from_str(header.0.clone().as_str())?,
197            HeaderValue::from_str(header.1.clone().as_str())?,
198        );
199    }
200    Ok(Response::from(ret_res))
201}
202
203fn bad_header(e: reqwest::header::InvalidHeaderValue) -> Error {
204    Error::Middleware(anyhow!(e))
205}
206
207fn from_box_error(e: BoxError) -> Error {
208    Error::Middleware(anyhow!(e))
209}
210
211#[async_trait::async_trait]
212impl<T: CacheManager> reqwest_middleware::Middleware for Cache<T> {
213    async fn handle(
214        &self,
215        req: Request,
216        extensions: &mut Extensions,
217        next: Next<'_>,
218    ) -> std::result::Result<Response, Error> {
219        let mut middleware = ReqwestMiddleware { req, next, extensions };
220        if self
221            .0
222            .can_cache_request(&middleware)
223            .map_err(|e| Error::Middleware(anyhow!(e)))?
224        {
225            let res = self.0.run(middleware).await.map_err(from_box_error)?;
226            let converted = convert_response(res)?;
227            Ok(converted)
228        } else {
229            self.0
230                .run_no_cache(&mut middleware)
231                .await
232                .map_err(from_box_error)?;
233            let mut res = middleware
234                .next
235                .run(middleware.req, middleware.extensions)
236                .await?;
237
238            let miss =
239                HeaderValue::from_str(HitOrMiss::MISS.to_string().as_ref())
240                    .map_err(bad_header)?;
241            res.headers_mut().insert(XCACHE, miss.clone());
242            res.headers_mut().insert(XCACHELOOKUP, miss);
243            Ok(res)
244        }
245    }
246}
247
248#[cfg(test)]
249mod test;