//! A caching middleware for Reqwest that follows HTTP caching rules.
//! By default it uses [`cacache`](https://github.com/zkat/cacache-rs) as the backend cache manager.
//!
//! ## Example
//!
//! ```no_run
//! use reqwest::Client;
//! use reqwest_middleware::{ClientBuilder, Result};
//! use reqwest_middleware_cache::{managers::CACacheManager, Cache, CacheMode};
//!
//! #[tokio::main]
//! async fn main() -> Result<()> {
//! let client = ClientBuilder::new(Client::new())
//! .with(Cache {
//! mode: CacheMode::Default,
//! cache_manager: CACacheManager::default(),
//! })
//! .build();
//! client
//! .get("https://developer.mozilla.org/en-US/docs/Web/HTTP/Caching")
//! .send()
//! .await?;
//! Ok(())
//! }
//! ```
#![forbid(unsafe_code, future_incompatible)]
#![deny(
missing_docs,
missing_debug_implementations,
missing_copy_implementations,
nonstandard_style,
unused_qualifications,
rustdoc::missing_doc_code_examples
)]
use std::time::SystemTime;
use anyhow::{anyhow, Result};
use http::{header::CACHE_CONTROL, HeaderValue, Method};
use http_cache_semantics::{AfterResponse, BeforeRequest, CachePolicy};
use reqwest::{Request, Response};
use reqwest_middleware::{Error, Middleware, Next};
use task_local_extensions::Extensions;
/// Backend cache managers, cacache is the default.
pub mod managers;
/// A trait providing methods for storing, reading, and removing cache records.
#[async_trait::async_trait]
pub trait CacheManager {
/// Attempts to pull a cached reponse and related policy from cache.
async fn get(&self, req: &Request) -> Result<Option<(Response, CachePolicy)>>;
/// Attempts to cache a response and related policy.
async fn put(&self, req: &Request, res: Response, policy: CachePolicy) -> Result<Response>;
/// Attempts to remove a record from cache.
async fn delete(&self, req: &Request) -> Result<()>;
}
/// Similar to [make-fetch-happen cache options](https://github.com/npm/make-fetch-happen#--optscache).
/// Passed in when the [`Cache`] struct is being built.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheMode {
/// Will inspect the HTTP cache on the way to the network.
/// If there is a fresh response it will be used.
/// If there is a stale response a conditional request will be created,
/// and a normal request otherwise.
/// It then updates the HTTP cache with the response.
/// If the revalidation request fails (for example, on a 500 or if you're offline),
/// the stale response will be returned.
Default,
/// Behaves as if there is no HTTP cache at all.
NoStore,
/// Behaves as if there is no HTTP cache on the way to the network.
/// Ergo, it creates a normal request and updates the HTTP cache with the response.
Reload,
/// Creates a conditional request if there is a response in the HTTP cache
/// and a normal request otherwise. It then updates the HTTP cache with the response.
NoCache,
/// Uses any response in the HTTP cache matching the request,
/// not paying attention to staleness. If there was no response,
/// it creates a normal request and updates the HTTP cache with the response.
ForceCache,
/// Uses any response in the HTTP cache matching the request,
/// not paying attention to staleness. If there was no response,
/// it returns a network error. (Can only be used when request’s mode is "same-origin".
/// Any cached redirects will be followed assuming request’s redirect mode is "follow"
/// and the redirects do not violate request’s mode.)
OnlyIfCached,
}
/// Caches requests according to http spec
#[derive(Debug, Clone)]
pub struct Cache<T: CacheManager + Send + Sync + 'static> {
/// Determines the manager behavior
pub mode: CacheMode,
/// Manager instance that implements the CacheManager trait
pub cache_manager: T,
}
impl<T: CacheManager + Send + Sync + 'static> Cache<T> {
/// Called by the Reqwest middleware handle method when a request is made.
pub async fn run<'a>(
&'a self,
mut req: Request,
next: Next<'a>,
extensions: &mut Extensions,
) -> Result<Response> {
let is_cacheable = (req.method() == Method::GET || req.method() == Method::HEAD)
&& self.mode != CacheMode::NoStore
&& self.mode != CacheMode::Reload;
if !is_cacheable {
return self.remote_fetch(req, next, extensions).await;
}
if let Some(store) = self.cache_manager.get(&req).await? {
let (mut res, policy) = store;
if let Some(warning_code) = get_warning_code(&res) {
// https://tools.ietf.org/html/rfc7234#section-4.3.4
//
// If a stored response is selected for update, the cache MUST:
//
// * delete any Warning header fields in the stored response with
// warn-code 1xx (see Section 5.5);
//
// * retain any Warning header fields in the stored response with
// warn-code 2xx;
//
#[allow(clippy::manual_range_contains)]
if warning_code >= 100 && warning_code < 200 {
res.headers_mut().remove(reqwest::header::WARNING);
}
}
match self.mode {
CacheMode::Default => Ok(self
.conditional_fetch(req, res, policy, next, extensions)
.await?),
CacheMode::NoCache => {
req.headers_mut()
.insert(CACHE_CONTROL, HeaderValue::from_str("no-cache")?);
Ok(self
.conditional_fetch(req, res, policy, next, extensions)
.await?)
}
CacheMode::ForceCache | CacheMode::OnlyIfCached => {
// 112 Disconnected operation
// SHOULD be included if the cache is intentionally disconnected from
// the rest of the network for a period of time.
// (https://tools.ietf.org/html/rfc2616#section-14.46)
add_warning(&mut res, req.url(), 112, "Disconnected operation");
Ok(res)
}
_ => Ok(self.remote_fetch(req, next, extensions).await?),
}
} else {
match self.mode {
CacheMode::OnlyIfCached => {
// ENOTCACHED
let err_res = http::Response::builder()
.status(http::StatusCode::GATEWAY_TIMEOUT)
.body("")?;
Ok(err_res.into())
}
_ => Ok(self.remote_fetch(req, next, extensions).await?),
}
}
}
async fn conditional_fetch<'a>(
&self,
mut req: Request,
mut cached_res: Response,
mut policy: CachePolicy,
next: Next<'_>,
extensions: &mut Extensions,
) -> Result<Response> {
let before_req = policy.before_request(&req, SystemTime::now());
match before_req {
BeforeRequest::Fresh(parts) => {
update_response_headers(parts, &mut cached_res);
return Ok(cached_res);
}
BeforeRequest::Stale {
request: parts,
matches,
} => {
if matches {
update_request_headers(parts, &mut req);
}
}
}
let copied_req = req.try_clone().ok_or_else(|| {
Error::Middleware(anyhow!(
"Request object is not cloneable. Are you passing a streaming body?".to_string()
))
})?;
match self.remote_fetch(req, next, extensions).await {
Ok(cond_res) => {
if cond_res.status().is_server_error() && must_revalidate(&cached_res) {
// 111 Revalidation failed
// MUST be included if a cache returns a stale response
// because an attempt to revalidate the response failed,
// due to an inability to reach the server.
// (https://tools.ietf.org/html/rfc2616#section-14.46)
add_warning(
&mut cached_res,
copied_req.url(),
111,
"Revalidation failed",
);
Ok(cached_res)
} else if cond_res.status() == http::StatusCode::NOT_MODIFIED {
let mut res = http::Response::builder()
.status(cond_res.status())
.body(cached_res.text().await?)?;
for (key, value) in cond_res.headers() {
res.headers_mut().append(key, value.clone());
}
let mut converted = Response::from(res);
let after_res =
policy.after_response(&copied_req, &cond_res, SystemTime::now());
match after_res {
AfterResponse::Modified(new_policy, parts) => {
policy = new_policy;
update_response_headers(parts, &mut converted);
}
AfterResponse::NotModified(new_policy, parts) => {
policy = new_policy;
update_response_headers(parts, &mut converted);
}
}
let res = self
.cache_manager
.put(&copied_req, converted, policy)
.await?;
Ok(res)
} else {
Ok(cached_res)
}
}
Err(e) => {
if must_revalidate(&cached_res) {
Err(e)
} else {
// 111 Revalidation failed
// MUST be included if a cache returns a stale response
// because an attempt to revalidate the response failed,
// due to an inability to reach the server.
// (https://tools.ietf.org/html/rfc2616#section-14.46)
add_warning(
&mut cached_res,
copied_req.url(),
111,
"Revalidation failed",
);
// 199 Miscellaneous warning
// The warning text MAY include arbitrary information to
// be presented to a human user, or logged. A system
// receiving this warning MUST NOT take any automated
// action, besides presenting the warning to the user.
// (https://tools.ietf.org/html/rfc2616#section-14.46)
add_warning(
&mut cached_res,
copied_req.url(),
199,
format!("Miscellaneous Warning {}", e).as_str(),
);
Ok(cached_res)
}
}
}
}
async fn remote_fetch<'a>(
&'a self,
req: Request,
next: Next<'a>,
extensions: &mut Extensions,
) -> Result<Response> {
let copied_req = req.try_clone().ok_or_else(|| {
Error::Middleware(anyhow!(
"Request object is not clonable. Are you passing a streaming body?".to_string()
))
})?;
let res = next.run(req, extensions).await?;
let is_method_get_head =
copied_req.method() == Method::GET || copied_req.method() == Method::HEAD;
let policy = CachePolicy::new(&copied_req, &res);
let is_cacheable = self.mode != CacheMode::NoStore
&& is_method_get_head
&& res.status() == http::StatusCode::OK
&& policy.is_storable();
if is_cacheable {
Ok(self.cache_manager.put(&copied_req, res, policy).await?)
} else if !is_method_get_head {
self.cache_manager.delete(&copied_req).await?;
Ok(res)
} else {
Ok(res)
}
}
}
fn must_revalidate(res: &Response) -> bool {
if let Some(val) = res.headers().get(CACHE_CONTROL.as_str()) {
val.to_str()
.expect("Unable to convert header value to string")
.to_lowercase()
.contains("must-revalidate")
} else {
false
}
}
fn get_warning_code(res: &Response) -> Option<usize> {
res.headers().get(reqwest::header::WARNING).and_then(|hdr| {
hdr.to_str()
.expect("Unable to convert warning to string")
.chars()
.take(3)
.collect::<String>()
.parse()
.ok()
})
}
fn update_request_headers(parts: http::request::Parts, req: &mut Request) {
let headers = parts.headers;
for header in headers.iter() {
req.headers_mut().insert(header.0.clone(), header.1.clone());
}
}
fn update_response_headers(parts: http::response::Parts, res: &mut Response) {
for header in parts.headers.iter() {
res.headers_mut().insert(header.0.clone(), header.1.clone());
}
}
fn add_warning(res: &mut Response, uri: &reqwest::Url, code: usize, message: &str) {
// Warning = "Warning" ":" 1#warning-value
// warning-value = warn-code SP warn-agent SP warn-text [SP warn-date]
// warn-code = 3DIGIT
// warn-agent = ( host [ ":" port ] ) | pseudonym
// ; the name or pseudonym of the server adding
// ; the Warning header, for use in debugging
// warn-text = quoted-string
// warn-date = <"> HTTP-date <">
// (https://tools.ietf.org/html/rfc2616#section-14.46)
//
let val = HeaderValue::from_str(
format!(
"{} {} {:?} \"{}\"",
code,
uri.host().expect("Invalid URL"),
message,
httpdate::fmt_http_date(SystemTime::now())
)
.as_str(),
)
.expect("Failed to generate warning string");
res.headers_mut().append(reqwest::header::WARNING, val);
}
#[async_trait::async_trait]
impl<T: CacheManager + 'static + Send + Sync> Middleware for Cache<T> {
async fn handle(
&self,
req: Request,
extensions: &mut Extensions,
next: Next<'_>,
) -> reqwest_middleware::Result<Response> {
let res = self.run(req, next, extensions).await?;
Ok(res)
}
}
#[cfg(test)]
mod tests {
use super::*;
use anyhow::Result;
use http::{HeaderValue, Response};
use std::str::FromStr;
#[tokio::test]
async fn can_get_warning_code() -> Result<()> {
let url = reqwest::Url::from_str("https://example.com")?;
let mut res = reqwest::Response::from(Response::new(""));
add_warning(&mut res, &url, 111, "Revalidation failed");
let code = get_warning_code(&res).unwrap();
assert_eq!(code, 111);
Ok(())
}
#[tokio::test]
async fn can_check_revalidate() -> Result<()> {
let mut res = Response::new("");
res.headers_mut().append(
"Cache-Control",
HeaderValue::from_str("max-age=1733992, must-revalidate")?,
);
let check = must_revalidate(&res.into());
assert!(check, "{}", true);
Ok(())
}
}