#![forbid(unsafe_code, future_incompatible)]
#![deny(
missing_docs,
missing_debug_implementations,
missing_copy_implementations,
nonstandard_style,
unused_qualifications,
rustdoc::missing_doc_code_examples
)]
mod error;
mod managers;
mod middleware;
pub use error::CacheError;
#[cfg(feature = "manager-cacache")]
pub use managers::cacache::CACacheManager;
use http::{header::CACHE_CONTROL, request, response, StatusCode};
use std::{collections::HashMap, str::FromStr, time::SystemTime};
use http_cache_semantics::{AfterResponse, BeforeRequest, CachePolicy};
use serde::{Deserialize, Serialize};
use url::Url;
/// Result type used throughout this crate, with [`CacheError`] as the error type.
pub type Result<T> = std::result::Result<T, CacheError>;
/// HTTP protocol version of a cached response.
///
/// Each variant is (de)serialized by serde as the protocol string given in its
/// `serde(rename)` attribute, e.g. `"HTTP/1.1"`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Deserialize, Serialize)]
#[non_exhaustive]
pub enum HttpVersion {
    /// HTTP/0.9, serialized as `"HTTP/0.9"`.
    #[serde(rename = "HTTP/0.9")]
    Http09,
    /// HTTP/1.0, serialized as `"HTTP/1.0"`.
    #[serde(rename = "HTTP/1.0")]
    Http10,
    /// HTTP/1.1, serialized as `"HTTP/1.1"`.
    #[serde(rename = "HTTP/1.1")]
    Http11,
    /// HTTP/2, serialized as `"HTTP/2.0"`.
    #[serde(rename = "HTTP/2.0")]
    H2,
    /// HTTP/3, serialized as `"HTTP/3.0"`.
    #[serde(rename = "HTTP/3.0")]
    H3,
}
/// A response in the form the cache stores and returns it.
///
/// Derives serde's `Serialize`/`Deserialize`, so a storage backend can write it out
/// and read it back. Field order is part of the serialized layout; keep it stable.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct HttpResponse {
    /// Raw response body bytes.
    pub body: Vec<u8>,
    /// Response headers, one value per key (inserting an existing key replaces its value).
    pub headers: HashMap<String, String>,
    /// Numeric HTTP status code.
    pub status: u16,
    /// URL of the request this response answers.
    pub url: Url,
    /// Protocol version of the response.
    pub version: HttpVersion,
}
impl HttpResponse {
    /// Converts this response's status and headers into `http::response::Parts`.
    ///
    /// The body is not carried over. Stored keys that differ only in ASCII case map
    /// to the same `http` header name, so only one of their values is kept.
    ///
    /// # Errors
    ///
    /// Returns an error if the status code is rejected by `http::response::Builder`,
    /// or if any stored header name or value is not valid for the `http` crate.
    pub fn parts(&self) -> Result<response::Parts> {
        let mut converted =
            response::Builder::new().status(self.status).body(())?;
        let headers = converted.headers_mut();
        for (name, value) in &self.headers {
            headers.insert(
                http::header::HeaderName::from_str(name)?,
                http::HeaderValue::from_str(value)?,
            );
        }
        Ok(converted.into_parts().0)
    }

    /// Values of every stored header whose key equals `name`, ignoring ASCII case.
    ///
    /// HTTP header names are case-insensitive, but `headers` is a plain `HashMap`.
    /// Its keys can arrive in any casing: e.g. `Warning` written by `add_warning`
    /// versus the lowercase names that `update_headers` copies from an
    /// `http::HeaderMap`.
    fn header_values<'a>(
        &'a self,
        name: &'a str,
    ) -> impl Iterator<Item = &'a String> + 'a {
        self.headers
            .iter()
            .filter(move |(key, _)| key.eq_ignore_ascii_case(name))
            .map(|(_, value)| value)
    }

    /// Removes every stored header whose key equals `name`, ignoring ASCII case.
    fn remove_header(&mut self, name: &str) {
        self.headers.retain(|key, _| !key.eq_ignore_ascii_case(name));
    }

    /// Parses the warn-code of the stored `Warning` header.
    ///
    /// The header name is matched case-insensitively. The code is read from the
    /// first three characters of the value. Returns `None` when there is no such
    /// header or those characters do not parse as a number.
    pub fn warning_code(&self) -> Option<usize> {
        self.header_values("warning").find_map(|value| {
            value.chars().take(3).collect::<String>().parse().ok()
        })
    }

    /// Sets the `Warning` header, replacing any existing one in any casing.
    ///
    /// The value has the form `<code> <agent> <message> "<HTTP-date of now>"`:
    /// - `agent` is the host of `url`, or `-` when the URL has no host.
    /// - `message` is quoted with Rust's `Debug` formatting for strings.
    pub fn add_warning(&mut self, url: Url, code: usize, message: &str) {
        // warning-value = warn-code SP warn-agent SP warn-text [SP warn-date]
        // (RFC 7234 §5.5). warn-agent may be a pseudonym; "-" is used instead of
        // panicking when the URL carries no host (e.g. `file:` or `data:` URLs).
        let agent = url.host_str().unwrap_or("-");
        let value = format!(
            "{} {} {:?} \"{}\"",
            code,
            agent,
            message,
            httpdate::fmt_http_date(SystemTime::now())
        );
        // Drop differently-cased duplicates so `parts()` never sees two `Warning`
        // entries competing for the same `http` header name.
        self.remove_header("warning");
        self.headers.insert("Warning".to_string(), value);
    }

    /// Removes the `Warning` header, whatever the casing of its stored key.
    pub fn remove_warning(&mut self) {
        self.remove_header("warning");
    }

    /// Copies every header in `parts` into this response.
    ///
    /// An incoming header replaces any stored header with the same name, compared
    /// case-insensitively. When `parts` holds several values for one name, the
    /// last one wins.
    ///
    /// # Errors
    ///
    /// Fails if a header value is not visible ASCII (`HeaderValue::to_str`).
    /// Headers copied before the failing one remain applied.
    pub fn update_headers(&mut self, parts: response::Parts) -> Result<()> {
        for (name, value) in parts.headers.iter() {
            let value = value.to_str()?.to_string();
            self.remove_header(name.as_str());
            self.headers.insert(name.as_str().to_string(), value);
        }
        Ok(())
    }

    /// Whether a stored `Cache-Control` header contains the `must-revalidate` directive.
    ///
    /// Both the header name and the directive are matched case-insensitively.
    pub fn must_revalidate(&self) -> bool {
        self.header_values(CACHE_CONTROL.as_str())
            .any(|value| value.to_lowercase().contains("must-revalidate"))
    }
}
/// Adapter between [`Cache`] and a concrete HTTP client integration.
///
/// `Cache` drives a request entirely through these methods:
/// - the method and URL key the cache lookup;
/// - the request parts feed the cache policy;
/// - the headers can be rewritten before the request is sent;
/// - [`Middleware::remote_fetch`] performs the network round-trip.
#[async_trait::async_trait]
pub(crate) trait Middleware {
    /// The request method as a string; `Cache` uses it with [`Middleware::url`]
    /// to address cache entries.
    fn method(&self) -> Result<String>;

    /// The request URL; `Cache` uses it with [`Middleware::method`] to address
    /// cache entries.
    fn url(&self) -> Result<&Url>;

    /// Whether the request method is GET or HEAD (presumably the only methods
    /// whose responses `Cache` may look up or store).
    fn is_method_get_head(&self) -> bool;

    /// The request head as `http` request parts, handed to the cache policy.
    fn parts(&self) -> Result<request::Parts>;

    /// Expected to apply the headers in `parts` to the outgoing request; `Cache`
    /// uses it to send the revalidation headers produced for a stale entry.
    fn update_headers(&mut self, parts: request::Parts) -> Result<()>;

    /// Expected to mark the outgoing request as no-cache; `Cache` calls it for
    /// `CacheMode::NoCache` hits before revalidating.
    fn set_no_cache(&mut self) -> Result<()>;

    /// Builds the cache policy for `response` paired with this request.
    fn policy(&self, response: &HttpResponse) -> Result<CachePolicy>;

    /// Sends the request to the remote server and returns its response.
    async fn remote_fetch(&mut self) -> Result<HttpResponse>;
}
/// Storage backend that [`Cache`] uses to persist responses and their cache policies.
///
/// Entries are addressed by a request method plus URL; `Cache` looks up and
/// stores entries under the request's own method.
#[async_trait::async_trait]
pub trait CacheManager {
    /// Looks up the entry stored for `method` and `url`.
    ///
    /// Returns `Ok(None)` when nothing is stored. Otherwise returns the stored
    /// response together with the [`CachePolicy`] it was saved with.
    async fn get(
        &self,
        method: &str,
        url: &Url,
    ) -> Result<Option<(HttpResponse, CachePolicy)>>;

    /// Stores `res` and its `policy` under `method` and `url`.
    ///
    /// [`Cache`] returns the resulting response to its caller, so implementations
    /// would normally hand back the response they stored.
    async fn put(
        &self,
        method: &str,
        url: &Url,
        res: HttpResponse,
        policy: CachePolicy,
    ) -> Result<HttpResponse>;

    /// Removes the entry stored for `method` and `url`.
    ///
    /// Whether deleting a missing entry is an error is up to the implementation.
    async fn delete(&self, method: &str, url: &Url) -> Result<()>;
}
/// How a [`Cache`] treats GET and HEAD requests.
///
/// Requests with any other method are always sent to the remote server,
/// whatever the mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheMode {
    /// Serve a cached response while the cache policy says it is fresh.
    ///
    /// A stale hit is revalidated with the remote server. On a miss the response
    /// is fetched and stored when it is a `200` that the policy deems storable.
    Default,
    /// The cache is not consulted and the fetched response is not stored.
    NoStore,
    /// The cache is not consulted and the fetched response is not stored.
    ///
    /// NOTE(review): the Fetch standard's `reload` mode updates the cache with the
    /// new response; this implementation does not.
    Reload,
    /// On a hit, the middleware is asked to mark the request no-cache before the
    /// cached response is checked against the remote server.
    ///
    /// Misses behave like [`CacheMode::Default`].
    NoCache,
    /// On a hit, the cached response is returned without checking freshness,
    /// with a `112 Disconnected operation` warning added.
    ///
    /// Misses behave like [`CacheMode::Default`].
    ForceCache,
    /// Hits behave like [`CacheMode::ForceCache`].
    ///
    /// On a miss no request is sent and a synthetic `504` response is returned.
    OnlyIfCached,
}
/// HTTP cache that decides, per request, whether to answer from storage or the network.
///
/// Pairs the [`CacheMode`] applied to every request with the [`CacheManager`] that
/// reads and writes stored responses.
#[derive(Debug, Clone)]
pub struct Cache<T: CacheManager + Send + Sync + 'static> {
    /// How requests interact with the cache; see [`CacheMode`].
    pub mode: CacheMode,
    /// Backend that stores responses and their cache policies.
    pub cache_manager: T,
}
impl<T: CacheManager + Send + Sync + 'static> Cache<T> {
    /// Handles a single request through the cache according to [`Cache::mode`].
    ///
    /// * Methods other than GET/HEAD are always sent to the remote server.
    ///   Afterwards the stored GET and HEAD entries for the same URL are
    ///   invalidated, best effort.
    /// * GET/HEAD in [`CacheMode::NoStore`] or [`CacheMode::Reload`] go straight to
    ///   the remote server, and the response is not stored.
    /// * Otherwise the cache manager is consulted, and hits and misses are handled
    ///   as described on each [`CacheMode`] variant.
    ///
    /// # Errors
    ///
    /// Propagates errors from the middleware (request conversion, remote fetch) and
    /// from the cache manager's `get`/`put`.
    #[allow(dead_code)]
    pub(crate) async fn run(
        &self,
        mut middleware: impl Middleware,
    ) -> Result<HttpResponse> {
        if !middleware.is_method_get_head() {
            // Never answered from the cache, but a successful unsafe request can
            // make stored GET/HEAD responses for the URL stale; `remote_fetch`
            // performs that invalidation.
            return self.remote_fetch(&mut middleware).await;
        }
        if self.mode == CacheMode::NoStore || self.mode == CacheMode::Reload {
            return middleware.remote_fetch().await;
        }
        let cached = self
            .cache_manager
            .get(&middleware.method()?, middleware.url()?)
            .await?;
        match cached {
            Some((mut res, policy)) => {
                // 1xx warn-codes (RFC 7234 §5.5) describe an earlier freshness or
                // validation state, so they are dropped before the entry is reused.
                if matches!(res.warning_code(), Some(code) if (100..200).contains(&code))
                {
                    res.remove_warning();
                }
                match self.mode {
                    CacheMode::Default => {
                        self.conditional_fetch(middleware, res, policy).await
                    }
                    CacheMode::NoCache => {
                        middleware.set_no_cache()?;
                        self.conditional_fetch(middleware, res, policy).await
                    }
                    CacheMode::ForceCache | CacheMode::OnlyIfCached => {
                        let res_url = res.url.clone();
                        res.add_warning(res_url, 112, "Disconnected operation");
                        Ok(res)
                    }
                    // Already handled above; listed so the match stays exhaustive.
                    CacheMode::NoStore | CacheMode::Reload => {
                        self.remote_fetch(&mut middleware).await
                    }
                }
            }
            None if self.mode == CacheMode::OnlyIfCached => Ok(HttpResponse {
                body: b"GatewayTimeout".to_vec(),
                headers: Default::default(),
                status: 504,
                url: middleware.url()?.clone(),
                version: HttpVersion::Http11,
            }),
            None => self.remote_fetch(&mut middleware).await,
        }
    }

    /// Sends the request to the remote server, then updates the cache from the result.
    ///
    /// GET/HEAD responses are stored when cacheable (see `store_if_cacheable`).
    /// For any other method, the stored GET and HEAD entries for the URL are
    /// invalidated, best effort.
    #[allow(dead_code)]
    async fn remote_fetch(
        &self,
        middleware: &mut impl Middleware,
    ) -> Result<HttpResponse> {
        let res = middleware.remote_fetch().await?;
        if middleware.is_method_get_head() {
            self.store_if_cacheable(middleware, res).await
        } else {
            self.invalidate(middleware).await;
            Ok(res)
        }
    }

    /// Stores `res` when it is cacheable and returns the resulting response.
    ///
    /// `res` is stored when all of the following hold:
    /// - the request is GET/HEAD;
    /// - the mode is neither `NoStore` nor `Reload`;
    /// - the status is `200`;
    /// - the policy built by the middleware is storable.
    ///
    /// Otherwise `res` is returned unchanged.
    #[allow(dead_code)]
    async fn store_if_cacheable(
        &self,
        middleware: &mut impl Middleware,
        res: HttpResponse,
    ) -> Result<HttpResponse> {
        let policy = middleware.policy(&res)?;
        let is_cacheable = middleware.is_method_get_head()
            && self.mode != CacheMode::NoStore
            && self.mode != CacheMode::Reload
            && res.status == 200
            && policy.is_storable();
        if is_cacheable {
            self.cache_manager
                .put(&middleware.method()?, middleware.url()?, res, policy)
                .await
        } else {
            Ok(res)
        }
    }

    /// Best-effort removal of the stored GET and HEAD entries for the request URL.
    ///
    /// Entries are stored under the method that fetched them, so the unsafe
    /// request's own method would not address them. Failures are ignored: the
    /// request has already reached the server, and failing here would hide its
    /// response from the caller.
    #[allow(dead_code)]
    async fn invalidate(&self, middleware: &mut impl Middleware) {
        if let Ok(url) = middleware.url() {
            for method in &["GET", "HEAD"] {
                let _ = self.cache_manager.delete(method, url).await;
            }
        }
    }

    /// Answers a cache hit, revalidating it with the remote server when stale.
    ///
    /// Outcomes:
    /// - A fresh entry is returned from the cache.
    /// - A `304` merges the new headers and policy into the stored entry.
    /// - Any other non-5xx response replaces the cached copy; it is stored when
    ///   cacheable.
    /// - On a 5xx, the cached response is served, with a `111` warning when it
    ///   carries `must-revalidate`.
    /// - On a transport error, the cached response is served with a `111`
    ///   warning, unless it carries `must-revalidate`, in which case the error
    ///   is returned.
    #[allow(dead_code)]
    async fn conditional_fetch(
        &self,
        mut middleware: impl Middleware,
        mut cached_res: HttpResponse,
        mut policy: CachePolicy,
    ) -> Result<HttpResponse> {
        let before_req =
            policy.before_request(&middleware.parts()?, SystemTime::now());
        match before_req {
            BeforeRequest::Fresh(parts) => {
                cached_res.update_headers(parts)?;
                return Ok(cached_res);
            }
            BeforeRequest::Stale { request: parts, matches } => {
                // Apply the policy's revalidation headers only when it reports
                // that they match this request.
                if matches {
                    middleware.update_headers(parts)?;
                }
            }
        }
        let req_url = middleware.url()?.clone();
        match middleware.remote_fetch().await {
            Ok(cond_res) => {
                let status = StatusCode::from_u16(cond_res.status)?;
                if status.is_server_error() {
                    if cached_res.must_revalidate() {
                        cached_res.add_warning(req_url, 111, "Revalidation failed");
                    }
                    Ok(cached_res)
                } else if status == StatusCode::NOT_MODIFIED {
                    let after_res = policy.after_response(
                        &middleware.parts()?,
                        &cond_res.parts()?,
                        SystemTime::now(),
                    );
                    match after_res {
                        AfterResponse::Modified(new_policy, parts)
                        | AfterResponse::NotModified(new_policy, parts) => {
                            policy = new_policy;
                            cached_res.update_headers(parts)?;
                        }
                    }
                    self.cache_manager
                        .put(&middleware.method()?, &req_url, cached_res, policy)
                        .await
                } else {
                    // A full response means the stored copy is no longer valid;
                    // RFC 7234 §4.3.3 requires using it instead of the cached one.
                    self.store_if_cacheable(&mut middleware, cond_res).await
                }
            }
            Err(e) => {
                if cached_res.must_revalidate() {
                    Err(e)
                } else {
                    cached_res.add_warning(req_url, 111, "Revalidation failed");
                    Ok(cached_res)
                }
            }
        }
    }
}