//! HTTP response caching middleware for [`reqwest`], built on `reqwest_middleware`.
//!
//! [`Cache`] implements `reqwest_middleware::Middleware`. Freshness and revalidation
//! decisions are delegated to `http_cache_semantics::CachePolicy`, and storage is
//! delegated to a user-supplied [`CacheManager`]. How the cache is consulted is
//! selected per [`Cache`] instance with a [`CacheMode`].
#![forbid(unsafe_code, future_incompatible)]
#![deny(
missing_docs,
missing_debug_implementations,
missing_copy_implementations,
nonstandard_style,
unused_qualifications,
rustdoc::missing_doc_code_examples
)]
use std::time::SystemTime;
use anyhow::{anyhow, Result};
use http::{header::CACHE_CONTROL, HeaderValue, Method};
use http_cache_semantics::{AfterResponse, BeforeRequest, CachePolicy};
use reqwest::{Request, Response};
use reqwest_middleware::{Error, Middleware, Next};
use task_local_extensions::Extensions;
/// Cache storage backends.
///
/// NOTE(review): the module body is not visible here; presumably it holds concrete
/// [`CacheManager`] implementations — confirm and expand this description.
pub mod managers;
/// Storage backend used by [`Cache`] to persist responses together with their
/// [`CachePolicy`].
///
/// All methods are async (via `async_trait`) and report failures through
/// `anyhow::Result`; [`Cache`] propagates those errors to the caller.
#[async_trait::async_trait]
pub trait CacheManager {
    /// Looks up the stored response and its policy for `req`.
    ///
    /// Returns `Ok(None)` on a cache miss.
    async fn get(&self, req: &Request) -> Result<Option<(Response, CachePolicy)>>;

    /// Stores `res` together with `policy` under `req`.
    ///
    /// The returned response is what [`Cache`] hands back to the caller in place of
    /// `res`, so implementations that consume the body must return an equivalent response.
    async fn put(&self, req: &Request, res: Response, policy: CachePolicy) -> Result<Response>;

    /// Removes any stored entry for `req`.
    ///
    /// [`Cache`] calls this after forwarding any request whose method is neither
    /// `GET` nor `HEAD`, regardless of the response status.
    async fn delete(&self, req: &Request) -> Result<()>;
}
/// Selects how a [`Cache`] consults and updates its store for each request.
///
/// Only `GET` and `HEAD` requests are ever looked up in the cache.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheMode {
    /// Serve a stored response if it is fresh, and revalidate it with the origin if it is stale.
    /// On a miss, fetch from the network and store the response if it is cacheable.
    Default,
    /// Bypass the cache completely: never look up and never store responses.
    NoStore,
    /// Skip the lookup and always fetch from the network, but still store the
    /// response if it is cacheable.
    Reload,
    /// Add `Cache-Control: no-cache` to the request before consulting the cache policy,
    /// then revalidate as in [`CacheMode::Default`].
    NoCache,
    /// Serve any stored response regardless of freshness, tagged with a
    /// `112 Disconnected operation` warning. Fetch from the network only on a miss.
    ForceCache,
    /// Like [`CacheMode::ForceCache`], but a miss yields a `504 Gateway Timeout`
    /// response without contacting the network.
    OnlyIfCached,
}
#[derive(Debug, Clone)]
pub struct Cache<T: CacheManager + Send + Sync + 'static> {
pub mode: CacheMode,
pub cache_manager: T,
}
impl<T: CacheManager + Send + Sync + 'static> Cache<T> {
pub async fn run<'a>(
&'a self,
mut req: Request,
next: Next<'a>,
extensions: &mut Extensions,
) -> Result<Response> {
let is_cacheable = (req.method() == Method::GET || req.method() == Method::HEAD)
&& self.mode != CacheMode::NoStore
&& self.mode != CacheMode::Reload;
if !is_cacheable {
return self.remote_fetch(req, next, extensions).await;
}
if let Some(store) = self.cache_manager.get(&req).await? {
let (mut res, policy) = store;
if let Some(warning_code) = get_warning_code(&res) {
#[allow(clippy::manual_range_contains)]
if warning_code >= 100 && warning_code < 200 {
res.headers_mut().remove(reqwest::header::WARNING);
}
}
match self.mode {
CacheMode::Default => Ok(self
.conditional_fetch(req, res, policy, next, extensions)
.await?),
CacheMode::NoCache => {
req.headers_mut()
.insert(CACHE_CONTROL, HeaderValue::from_str("no-cache")?);
Ok(self
.conditional_fetch(req, res, policy, next, extensions)
.await?)
}
CacheMode::ForceCache | CacheMode::OnlyIfCached => {
add_warning(&mut res, req.url(), 112, "Disconnected operation");
Ok(res)
}
_ => Ok(self.remote_fetch(req, next, extensions).await?),
}
} else {
match self.mode {
CacheMode::OnlyIfCached => {
let err_res = http::Response::builder()
.status(http::StatusCode::GATEWAY_TIMEOUT)
.body("")?;
Ok(err_res.into())
}
_ => Ok(self.remote_fetch(req, next, extensions).await?),
}
}
}
async fn conditional_fetch<'a>(
&self,
mut req: Request,
mut cached_res: Response,
mut policy: CachePolicy,
next: Next<'_>,
extensions: &mut Extensions,
) -> Result<Response> {
let before_req = policy.before_request(&req, SystemTime::now());
match before_req {
BeforeRequest::Fresh(parts) => {
update_response_headers(parts, &mut cached_res);
return Ok(cached_res);
}
BeforeRequest::Stale {
request: parts,
matches,
} => {
if matches {
update_request_headers(parts, &mut req);
}
}
}
let copied_req = req.try_clone().ok_or_else(|| {
Error::Middleware(anyhow!(
"Request object is not cloneable. Are you passing a streaming body?".to_string()
))
})?;
match self.remote_fetch(req, next, extensions).await {
Ok(cond_res) => {
if cond_res.status().is_server_error() && must_revalidate(&cached_res) {
add_warning(
&mut cached_res,
copied_req.url(),
111,
"Revalidation failed",
);
Ok(cached_res)
} else if cond_res.status() == http::StatusCode::NOT_MODIFIED {
let mut res = http::Response::builder()
.status(cond_res.status())
.body(cached_res.text().await?)?;
for (key, value) in cond_res.headers() {
res.headers_mut().append(key, value.clone());
}
let mut converted = Response::from(res);
let after_res =
policy.after_response(&copied_req, &cond_res, SystemTime::now());
match after_res {
AfterResponse::Modified(new_policy, parts) => {
policy = new_policy;
update_response_headers(parts, &mut converted);
}
AfterResponse::NotModified(new_policy, parts) => {
policy = new_policy;
update_response_headers(parts, &mut converted);
}
}
let res = self
.cache_manager
.put(&copied_req, converted, policy)
.await?;
Ok(res)
} else {
Ok(cached_res)
}
}
Err(e) => {
if must_revalidate(&cached_res) {
Err(e)
} else {
add_warning(
&mut cached_res,
copied_req.url(),
111,
"Revalidation failed",
);
add_warning(
&mut cached_res,
copied_req.url(),
199,
format!("Miscellaneous Warning {}", e).as_str(),
);
Ok(cached_res)
}
}
}
}
async fn remote_fetch<'a>(
&'a self,
req: Request,
next: Next<'a>,
extensions: &mut Extensions,
) -> Result<Response> {
let copied_req = req.try_clone().ok_or_else(|| {
Error::Middleware(anyhow!(
"Request object is not clonable. Are you passing a streaming body?".to_string()
))
})?;
let res = next.run(req, extensions).await?;
let is_method_get_head =
copied_req.method() == Method::GET || copied_req.method() == Method::HEAD;
let policy = CachePolicy::new(&copied_req, &res);
let is_cacheable = self.mode != CacheMode::NoStore
&& is_method_get_head
&& res.status() == http::StatusCode::OK
&& policy.is_storable();
if is_cacheable {
Ok(self.cache_manager.put(&copied_req, res, policy).await?)
} else if !is_method_get_head {
self.cache_manager.delete(&copied_req).await?;
Ok(res)
} else {
Ok(res)
}
}
}
/// Reports whether any `Cache-Control` header on `res` carries the `must-revalidate`
/// directive.
///
/// Every `Cache-Control` value is inspected, not just the first. Directives are split
/// on commas and compared case-insensitively. These headers come from the network, so
/// bytes that are not valid UTF-8 are decoded lossily instead of causing a panic.
fn must_revalidate(res: &Response) -> bool {
    res.headers()
        .get_all(CACHE_CONTROL)
        .iter()
        .any(|value| {
            String::from_utf8_lossy(value.as_bytes())
                .split(',')
                .any(|directive| directive.trim().eq_ignore_ascii_case("must-revalidate"))
        })
}
/// Parses the warn-code from the first `Warning` header on `res`.
///
/// The first three bytes of the header value are parsed as a number. Returns `None`
/// when there is no `Warning` header or those bytes do not form a number. Values with
/// non-ASCII bytes (which come from the network) no longer cause a panic.
fn get_warning_code(res: &Response) -> Option<usize> {
    let value = res.headers().get(reqwest::header::WARNING)?;
    // Mapping each byte to a char keeps ASCII input identical to the old `to_str`
    // path. Non-ASCII bytes become non-digit chars, so the parse simply fails.
    let code: String = value
        .as_bytes()
        .iter()
        .take(3)
        .map(|&byte| char::from(byte))
        .collect();
    code.parse().ok()
}
/// Applies the request headers computed by the cache policy to `req`.
///
/// Each header name present in `parts` replaces all existing values of that name on
/// `req`, and multi-valued headers keep every value. Headers not present in `parts`
/// are left untouched.
fn update_request_headers(parts: http::request::Parts, req: &mut Request) {
    // `HeaderMap`'s `Extend<(Option<HeaderName>, T)>` impl replaces values per name
    // and keeps all of them. Inserting value-by-value would keep only the last value
    // of a repeated header.
    req.headers_mut().extend(parts.headers);
}
/// Applies the response headers computed by the cache policy to `res`.
///
/// Each header name present in `parts` replaces all existing values of that name on
/// `res`. Multi-valued headers (e.g. `Set-Cookie`, `Warning`) keep every value instead
/// of collapsing to the last one. Headers not present in `parts` are left untouched.
fn update_response_headers(parts: http::response::Parts, res: &mut Response) {
    // Same replace-per-name, keep-all-values semantics as `HeaderMap::extend` with
    // another `HeaderMap`.
    res.headers_mut().extend(parts.headers);
}
/// Appends a `Warning` header to `res`.
///
/// The header has the form `<code> <agent> "<message>" "<HTTP-date of now>"`:
///
/// * `agent` is the host of `uri`, or the `-` pseudonym when the URL has no host
///   (previously a panic).
/// * `message` is rendered with `Debug` formatting, which wraps it in double quotes
///   and escapes quotes, backslashes and control characters.
///
/// Warnings are advisory, so if the value still cannot form a valid header it is
/// skipped instead of panicking.
fn add_warning(res: &mut Response, uri: &reqwest::Url, code: usize, message: &str) {
    let agent = uri.host_str().unwrap_or("-");
    let text = format!(
        "{} {} {:?} \"{}\"",
        code,
        agent,
        message,
        httpdate::fmt_http_date(SystemTime::now())
    );
    if let Ok(val) = HeaderValue::from_str(&text) {
        res.headers_mut().append(reqwest::header::WARNING, val);
    }
}
#[async_trait::async_trait]
impl<T: CacheManager + 'static + Send + Sync> Middleware for Cache<T> {
async fn handle(
&self,
req: Request,
extensions: &mut Extensions,
next: Next<'_>,
) -> reqwest_middleware::Result<Response> {
let res = self.run(req, next, extensions).await?;
Ok(res)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use http::{HeaderValue, Response};
    use std::str::FromStr;

    /// A warning written by `add_warning` can be read back by `get_warning_code`.
    #[tokio::test]
    async fn can_get_warning_code() -> Result<()> {
        let target = reqwest::Url::from_str("https://example.com")?;
        let mut response = reqwest::Response::from(Response::new(""));
        add_warning(&mut response, &target, 111, "Revalidation failed");
        assert_eq!(get_warning_code(&response), Some(111));
        Ok(())
    }

    /// `must-revalidate` is detected alongside other Cache-Control directives.
    #[tokio::test]
    async fn can_check_revalidate() -> Result<()> {
        let directives = HeaderValue::from_str("max-age=1733992, must-revalidate")?;
        let mut raw = Response::new("");
        raw.headers_mut().append("Cache-Control", directives);
        let response: reqwest::Response = raw.into();
        assert!(must_revalidate(&response));
        Ok(())
    }
}