#![warn(missing_docs)]
#![deny(unconditional_recursion)]
use chrono::prelude::*;
use http::HeaderMap;
use http::HeaderValue;
use http::Method;
use http::Request;
use http::Response;
use http::StatusCode;
use http::Uri;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::time::Duration;
use std::time::SystemTime;
/// Status codes whose responses may be stored even without explicit freshness
/// information (checked as the last alternative in `CachePolicy::is_storable`).
///
/// 206 is listed here but is absent from `UNDERSTOOD_STATUSES`, so `is_storable`
/// still rejects partial responses.
const STATUS_CODE_CACHEABLE_BY_DEFAULT: &[u16] = &[
    // 2xx
    200, 203, 204, 206,
    // 3xx
    300, 301,
    // 4xx
    404, 405, 410, 414,
    // 5xx
    501,
];

/// Status codes this implementation knows how to cache; any other status makes the
/// response non-storable.
const UNDERSTOOD_STATUSES: &[u16] = &[
    // 2xx
    200, 203, 204,
    // 3xx
    300, 301, 302, 303, 307, 308,
    // 4xx
    404, 405, 410, 414,
    // 5xx
    501,
];

/// Header names dropped by `copy_without_hop_by_hop_headers`.
///
/// `date` is not a hop-by-hop header; it is listed so that responses served from
/// cache carry the freshly generated `Date` that `cached_response` inserts.
const HOP_BY_HOP_HEADERS: &[&str] = &[
    "date",
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
];

/// Headers of a `304 Not Modified` response that must NOT overwrite the stored
/// response's headers when it is refreshed in `after_response`.
const EXCLUDED_FROM_REVALIDATION_UPDATE: &[&str] = &[
    "content-length",
    "content-encoding",
    "transfer-encoding",
    "content-range",
];

/// Parsed `Cache-Control` directives: directive name -> argument, where the argument
/// is `None` when the directive was written without `=value`.
type CacheControl = HashMap<Box<str>, Option<Box<str>>>;
/// Parses every `Cache-Control` header value into a [`CacheControl`] map.
///
/// Directive names are lowercased, because RFC 7234 §5.2 makes them
/// case-insensitive. Arguments have surrounding whitespace and double quotes removed.
///
/// If the same directive appears more than once with different arguments, the
/// directives are considered invalid and `must-revalidate` is added. For a response,
/// this prevents it from being served stale. Repeats are compared after unquoting, so
/// `max-age="5"` and `max-age=5` are the same value.
///
/// Header values that are not visible ASCII are skipped. Parsing is naive: values are
/// split on every `,`, so quoted arguments that contain commas are not supported.
fn parse_cache_control<'a>(headers: impl IntoIterator<Item = &'a HeaderValue>) -> CacheControl {
    let mut cc = CacheControl::new();
    let mut is_valid = true;
    for h in headers.into_iter().filter_map(|v| v.to_str().ok()) {
        for part in h.split(',') {
            let mut kv = part.splitn(2, '=');
            // `splitn` always yields at least one item; blank parts give an empty key.
            let k = kv.next().unwrap_or("").trim();
            if k.is_empty() {
                continue;
            }
            let k = k.to_ascii_lowercase();
            // Unquote before storing AND before comparing, so equal repeats agree.
            let v = kv.next().map(|v| v.trim().trim_matches('"'));
            match cc.entry(k.into_boxed_str()) {
                Entry::Occupied(e) => {
                    if e.get().as_deref() != v {
                        is_valid = false;
                    }
                }
                Entry::Vacant(e) => {
                    e.insert(v.map(Box::from));
                }
            }
        }
    }
    if !is_valid {
        cc.insert("must-revalidate".into(), None);
    }
    cc
}
/// Serializes a [`CacheControl`] map back into a `Cache-Control` header value.
///
/// Directives are separated by `", "` in the map's iteration order. An argument is
/// written bare when it is a non-empty run of ASCII letters and digits. Otherwise it
/// is wrapped in double quotes; no escaping is performed.
fn format_cache_control(cc: &CacheControl) -> String {
    let mut rendered = String::new();
    for (directive, argument) in cc.iter() {
        if !rendered.is_empty() {
            rendered.push_str(", ");
        }
        rendered.push_str(directive);

        if let Some(arg) = argument {
            rendered.push('=');
            let bare_token = !arg.is_empty() && arg.bytes().all(|b| b.is_ascii_alphanumeric());
            if bare_token {
                rendered.push_str(arg);
            } else {
                rendered.push('"');
                rendered.push_str(arg);
                rendered.push('"');
            }
        }
    }
    rendered
}
/// Configuration controlling how a [`CachePolicy`] interprets caching headers.
///
/// Field order is kept stable because it determines the `Debug` output and, with the
/// `with_serde` feature, the serialized form.
#[derive(Debug, Copy, Clone)]
#[cfg_attr(feature = "with_serde", derive(serde_derive::Serialize, serde_derive::Deserialize))]
pub struct CacheOptions {
    /// Apply the rules for a shared cache.
    ///
    /// When `true`, the policy applies these stricter rules:
    /// - `private` responses are not stored;
    /// - `s-maxage` and `proxy-revalidate` are honored;
    /// - responses to requests with `Authorization` need `must-revalidate`, `public`
    ///   or `s-maxage` to be stored;
    /// - responses with `Set-Cookie` get a zero lifetime unless they are `public` or
    ///   `immutable`.
    pub shared: bool,
    /// Fraction of the time between the response's `Date` and its `Last-Modified`
    /// used as a heuristic freshness lifetime.
    ///
    /// Only applies when the response has no `s-maxage`, `max-age` or `Expires`.
    pub cache_heuristic: f32,
    /// Minimum freshness lifetime for responses with `Cache-Control: immutable` whose
    /// lifetime is not set by `max-age` (or `s-maxage` for shared caches).
    ///
    /// Not applied when `Expires` is present but unparseable; such responses get zero.
    pub immutable_min_time_to_live: Duration,
    /// When `true`, responses whose `Cache-Control` contains both `pre-check` and
    /// `post-check` are cleaned up before being interpreted.
    ///
    /// The `pre-check`, `post-check`, `no-cache`, `no-store` and `must-revalidate`
    /// directives are removed, and the `Expires` and `Pragma` headers are dropped.
    pub ignore_cargo_cult: bool,
}
impl Default for CacheOptions {
fn default() -> Self {
Self {
shared: true,
cache_heuristic: 0.1,
immutable_min_time_to_live: Duration::from_secs(24 * 3600),
ignore_cargo_cult: false,
}
}
}
/// Caching decisions for one stored request/response pair.
///
/// Holds copies of the request and response headers, the request URI and method, the
/// response status, the [`CacheOptions`] it was built with, both parsed
/// `Cache-Control` headers, and the time the response was received. With the
/// `with_serde` feature it can be (de)serialized.
///
/// Field order is kept stable because it determines the serialized form.
#[derive(Debug)]
#[cfg_attr(feature = "with_serde", derive(serde_derive::Serialize, serde_derive::Deserialize))]
pub struct CachePolicy {
    /// Headers of the request the stored response answered.
    #[cfg_attr(feature = "with_serde", serde(with = "http_serde::header_map"))]
    req: HeaderMap,
    /// Headers of the stored response (possibly rewritten when `ignore_cargo_cult`
    /// strips directives).
    #[cfg_attr(feature = "with_serde", serde(with = "http_serde::header_map"))]
    res: HeaderMap,
    /// URI of the stored request.
    #[cfg_attr(feature = "with_serde", serde(with = "http_serde::uri"))]
    uri: Uri,
    /// Status of the stored response.
    #[cfg_attr(feature = "with_serde", serde(with = "http_serde::status_code"))]
    status: StatusCode,
    /// Method of the stored request.
    #[cfg_attr(feature = "with_serde", serde(with = "http_serde::method"))]
    method: Method,
    /// Options the policy was created with.
    opts: CacheOptions,
    /// Parsed response `Cache-Control`. May also contain `no-cache` derived from
    /// `Pragma: no-cache` when the response has no `Cache-Control` header.
    res_cc: CacheControl,
    /// Parsed request `Cache-Control`.
    req_cc: CacheControl,
    /// When the response was received.
    ///
    /// Used to compute the resident time in `age`, and as the fallback server date
    /// when the response `Date` header is missing or unparseable.
    response_time: SystemTime,
}
impl CachePolicy {
    /// Creates a policy for `req` and the response `res` it received.
    ///
    /// The response is assumed to have arrived just now, and the default
    /// [`CacheOptions`] are used. The URI, method, status and both header maps are
    /// copied, so the policy does not borrow `req` or `res`.
    #[inline]
    pub fn new<Req: RequestLike, Res: ResponseLike>(req: &Req, res: &Res) -> Self {
        let uri = req.uri();
        let status = res.status();
        let method = req.method().clone();
        let res = res.headers().clone();
        let req = req.headers().clone();
        Self::from_details(
            uri,
            method,
            status,
            req,
            res,
            SystemTime::now(),
            Default::default(),
        )
    }

    /// Like [`CachePolicy::new`], but with explicit options.
    ///
    /// `response_time` is the time the response was received, and `opts` replaces
    /// the default [`CacheOptions`].
    #[inline]
    pub fn new_options<Req: RequestLike, Res: ResponseLike>(
        req: &Req,
        res: &Res,
        response_time: SystemTime,
        opts: CacheOptions,
    ) -> Self {
        let uri = req.uri();
        let status = res.status();
        let method = req.method().clone();
        let res = res.headers().clone();
        let req = req.headers().clone();
        Self::from_details(uri, method, status, req, res, response_time, opts)
    }

    /// Shared constructor.
    ///
    /// Parses both `Cache-Control` headers once. It then applies the
    /// `ignore_cargo_cult` cleanup and the `Pragma: no-cache` fallback to the stored
    /// response.
    fn from_details(
        uri: Uri,
        method: Method,
        status: StatusCode,
        req: HeaderMap,
        mut res: HeaderMap,
        response_time: SystemTime,
        opts: CacheOptions,
    ) -> Self {
        let mut res_cc = parse_cache_control(res.get_all("cache-control"));
        let req_cc = parse_cache_control(req.get_all("cache-control"));

        // With `ignore_cargo_cult`, a response carrying the non-standard
        // `pre-check`/`post-check` pair has its cache-preventing directives dropped.
        // The stored header is rewritten to match.
        if opts.ignore_cargo_cult
            && res_cc.get("pre-check").is_some()
            && res_cc.get("post-check").is_some()
        {
            res_cc.remove("pre-check");
            res_cc.remove("post-check");
            res_cc.remove("no-cache");
            res_cc.remove("no-store");
            res_cc.remove("must-revalidate");
            res.insert(
                "cache-control",
                HeaderValue::from_str(&format_cache_control(&res_cc)).unwrap(),
            );
            res.remove("expires");
            res.remove("pragma");
        }

        // Without a Cache-Control header, `Pragma: no-cache` acts like `no-cache`.
        if !res.contains_key("cache-control")
            && res
                .get_str("pragma")
                .map_or(false, |p| p.contains("no-cache"))
        {
            res_cc.insert("no-cache".into(), None);
        }

        Self {
            status,
            method,
            res,
            req,
            res_cc,
            req_cc,
            opts,
            uri,
            response_time,
        }
    }

    /// Returns `true` if the response may be stored in the cache at all.
    ///
    /// Mirrors the storage rules of RFC 7234 §3. All of the following must hold:
    /// - neither the request nor the response says `no-store`;
    /// - the method is GET or HEAD, or POST with explicit expiration;
    /// - the status is understood;
    /// - a shared cache respects `private` and `Authorization`;
    /// - the response has explicit freshness information, `public`, or a status
    ///   that is cacheable by default.
    pub fn is_storable(&self) -> bool {
        !self.req_cc.contains_key("no-store") &&
            (Method::GET == self.method ||
                Method::HEAD == self.method ||
                (Method::POST == self.method && self.has_explicit_expiration())) &&
            UNDERSTOOD_STATUSES.contains(&self.status.as_u16()) &&
            !self.res_cc.contains_key("no-store") &&
            (!self.opts.shared || !self.res_cc.contains_key("private")) &&
            (!self.opts.shared ||
                !self.req.contains_key("authorization") ||
                self.allows_storing_authenticated()) &&
            (self.res.contains_key("expires") ||
                self.res_cc.contains_key("max-age") ||
                (self.opts.shared && self.res_cc.contains_key("s-maxage")) ||
                self.res_cc.contains_key("public") ||
                STATUS_CODE_CACHEABLE_BY_DEFAULT.contains(&self.status.as_u16()))
    }

    /// Whether the response sets its own lifetime.
    ///
    /// The lifetime may come from `s-maxage` (shared caches only), `max-age`, or
    /// `Expires`.
    fn has_explicit_expiration(&self) -> bool {
        (self.opts.shared && self.res_cc.contains_key("s-maxage"))
            || self.res_cc.contains_key("max-age")
            || self.res.contains_key("expires")
    }

    /// Checks whether the stored response can be used for `req` at time `now`.
    ///
    /// Returns [`BeforeRequest::Fresh`] with the response head to serve when `req`
    /// matches the stored request exactly and the stored response satisfies the
    /// request's directives.
    ///
    /// Otherwise returns [`BeforeRequest::Stale`] carrying the request to send
    /// upstream, with `matches` set when `req` matched the stored request exactly:
    /// - when `req` targets the same resource (same method, or `HEAD`), the request
    ///   is a conditional revalidation request;
    /// - otherwise it carries `req`'s end-to-end headers, without `If-Range`,
    ///   `If-None-Match` and `If-Modified-Since`.
    ///
    /// In both cases the method and URI of the returned request are those of the
    /// stored request.
    pub fn before_request<Req: RequestLike>(&self, req: &Req, now: SystemTime) -> BeforeRequest {
        let req_headers = req.headers();
        let (exact_match, may_revalidate) = self.request_matches(req);
        if exact_match && self.satisfies_without_revalidation(req_headers, now) {
            BeforeRequest::Fresh(self.cached_response(now))
        } else if may_revalidate {
            BeforeRequest::Stale {
                request: self.revalidation_request(req),
                matches: exact_match,
            }
        } else {
            // Not the same resource, so the stored validators don't apply. Forward
            // only end-to-end headers and drop conditional ones, as
            // `revalidation_request` does for non-storable responses.
            let mut headers = Self::copy_without_hop_by_hop_headers(req_headers);
            headers.remove("if-range");
            headers.remove("if-none-match");
            headers.remove("if-modified-since");
            BeforeRequest::Stale {
                request: self.request_from_headers(headers),
                matches: false,
            }
        }
    }

    /// Whether the stored response can serve a request with `req_headers` at `now`
    /// without contacting the origin.
    ///
    /// Honors the request's `no-cache`/`Pragma`, `max-age`, `min-fresh` and
    /// `max-stale` directives.
    fn satisfies_without_revalidation(&self, req_headers: &HeaderMap, now: SystemTime) -> bool {
        let req_cc = parse_cache_control(req_headers.get_all("cache-control"));
        if req_cc.contains_key("no-cache")
            || req_headers
                .get_str("pragma")
                .map_or(false, |v| v.contains("no-cache"))
        {
            return false;
        }
        // The client won't accept a response older than `max-age` seconds.
        if let Some(max_age) = req_cc
            .get("max-age")
            .and_then(|v| v.as_ref())
            .and_then(|p| p.parse().ok())
        {
            if self.age(now) > Duration::from_secs(max_age) {
                return false;
            }
        }
        // The client wants at least `min-fresh` seconds of freshness left.
        if let Some(min_fresh) = req_cc
            .get("min-fresh")
            .and_then(|v| v.as_ref())
            .and_then(|p| p.parse().ok())
        {
            if self.time_to_live(now) < Duration::from_secs(min_fresh) {
                return false;
            }
        }
        // A stale response is only acceptable when all of these hold:
        // - the client sent `max-stale`;
        // - its limit, if any, exceeds the staleness;
        // - the response does not say `must-revalidate`.
        if self.is_stale(now) {
            let max_stale = req_cc.get("max-stale");
            let has_max_stale = max_stale.is_some();
            let max_stale = max_stale
                .and_then(|m| m.as_ref())
                .and_then(|s| s.parse().ok());
            let allows_stale = !self.res_cc.contains_key("must-revalidate")
                && has_max_stale
                && max_stale.map_or(true, |val| {
                    // Cannot underflow: the response is stale, so age >= max_age.
                    Duration::from_secs(val) > self.age(now) - self.max_age()
                });
            if !allows_stale {
                return false;
            }
        }
        true
    }

    /// Compares `req` with the stored request.
    ///
    /// Returns `(exact_match, may_revalidate)`:
    /// - `exact_match` holds when the URI, `Host`, the `Vary`-nominated headers and
    ///   the method all match;
    /// - `may_revalidate` holds when the resource matches and the method is either
    ///   the same or `HEAD`.
    fn request_matches<Req: RequestLike>(&self, req: &Req) -> (bool, bool) {
        let same_resource = req.is_same_uri(&self.uri)
            && self.req.get("host") == req.headers().get("host")
            && self.vary_matches(req);
        let same_method = self.method == req.method();
        let exact_match = same_resource && same_method;
        // A HEAD request may revalidate the stored response, but only for the same
        // resource.
        let may_revalidate = same_resource && (same_method || Method::HEAD == req.method());
        (exact_match, may_revalidate)
    }

    /// Whether a shared cache may store a response to a request with `Authorization`.
    ///
    /// Requires `must-revalidate`, `public` or `s-maxage` on the response.
    fn allows_storing_authenticated(&self) -> bool {
        self.res_cc.contains_key("must-revalidate")
            || self.res_cc.contains_key("public")
            || self.res_cc.contains_key("s-maxage")
    }

    /// Whether every header named in the stored response's `Vary` has the same
    /// (first) value in `req` as in the stored request.
    ///
    /// `Vary: *` never matches.
    fn vary_matches<Req: RequestLike>(&self, req: &Req) -> bool {
        for name in get_all_comma(self.res.get_all("vary")) {
            if name == "*" {
                return false;
            }
            let name = name.trim().to_ascii_lowercase();
            if req.headers().get(&name) != self.req.get(&name) {
                return false;
            }
        }
        true
    }

    /// Copies `in_headers` without hop-by-hop headers.
    ///
    /// Removes the headers in `HOP_BY_HOP_HEADERS`, any header named in
    /// `Connection`, and `1xx` warnings. Every value of multi-valued headers
    /// (e.g. `Set-Cookie`) is preserved.
    fn copy_without_hop_by_hop_headers(in_headers: &HeaderMap) -> HeaderMap {
        let mut headers = HeaderMap::with_capacity(in_headers.len());
        for (h, v) in in_headers
            .iter()
            .filter(|(h, _)| !HOP_BY_HOP_HEADERS.contains(&h.as_str()))
        {
            // `append`, not `insert`: `iter` yields each value of a multi-valued
            // header separately, and `insert` would keep only the last one.
            headers.append(h.clone(), v.clone());
        }
        for name in get_all_comma(in_headers.get_all("connection")) {
            headers.remove(name);
        }
        // Warnings with 1xx codes describe freshness and must not be forwarded.
        let new_warnings = join(
            get_all_comma(in_headers.get_all("warning")).filter(|warning| {
                !warning.trim_start().starts_with('1')
            }),
        );
        if new_warnings.is_empty() {
            headers.remove("warning");
        } else {
            headers.insert("warning", HeaderValue::from_str(&new_warnings).unwrap());
        }
        headers
    }

    /// Builds the response head to serve from cache at `now`.
    ///
    /// It consists of the stored end-to-end headers plus a fresh `Age` and `Date`.
    /// A `113` warning is added when a heuristic lifetime over one day is used and
    /// the response is older than one day.
    fn cached_response(&self, now: SystemTime) -> http::response::Parts {
        let mut headers = Self::copy_without_hop_by_hop_headers(&self.res);
        let age = self.age(now);
        let day = Duration::from_secs(3600 * 24);
        if age > day && !self.has_explicit_expiration() && self.max_age() > day {
            headers.append(
                "warning",
                HeaderValue::from_static(r#"113 - "rfc7234 5.5.4""#),
            );
        }
        // A `now` before the Unix epoch is clamped rather than panicking.
        let timestamp = now
            .duration_since(SystemTime::UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0);
        let date = DateTime::<Utc>::from_utc(NaiveDateTime::from_timestamp(timestamp as _, 0), Utc);
        // RFC 7234 §1.2.1: delta-seconds that overflow are sent as 2^31 (the old
        // `as u32` cast silently wrapped).
        let age_secs = age.as_secs().min(2_147_483_648);
        headers.insert("age", HeaderValue::from(age_secs));
        // RFC 7231 §7.1.1.1: senders must generate HTTP-dates in IMF-fixdate form
        // (`Sun, 06 Nov 1994 08:49:37 GMT`), not RFC 2822's `+0000` form.
        let http_date = date.format("%a, %d %b %Y %H:%M:%S GMT").to_string();
        headers.insert(
            "date",
            HeaderValue::from_str(&http_date).expect("IMF-fixdate is a valid header value"),
        );
        let mut parts = Response::builder()
            .status(self.status)
            .body(())
            .unwrap()
            .into_parts().0;
        parts.headers = headers;
        parts
    }

    /// The response's `Date` header, or the time the response was received when
    /// `Date` is missing or unparseable.
    fn raw_server_date(&self) -> SystemTime {
        self.res
            .get_str("date")
            .and_then(|d| DateTime::parse_from_rfc2822(d).ok())
            .and_then(|d| {
                SystemTime::UNIX_EPOCH.checked_add(Duration::from_secs(d.timestamp() as _))
            })
            .unwrap_or(self.response_time)
    }

    /// Current age of the stored response at `now`.
    ///
    /// This is the `Age` header plus the time since the response was received. The
    /// sum saturates instead of panicking on absurdly large `Age` values.
    pub fn age(&self, now: SystemTime) -> Duration {
        let age = self.age_header_value();
        match now.duration_since(self.response_time) {
            Ok(resident_time) => age.saturating_add(resident_time),
            Err(_) => age,
        }
    }

    /// The stored response's `Age` header in seconds, or zero if missing or invalid.
    fn age_header_value(&self) -> Duration {
        Duration::from_secs(
            self.res
                .get_str("age")
                .and_then(|v| v.parse().ok())
                .unwrap_or(0),
        )
    }

    /// Freshness lifetime of the stored response.
    ///
    /// Zero when the response is not storable, has `no-cache`, or has `Vary: *`.
    /// In a shared cache it is also zero when the response has `Set-Cookie` without
    /// `public`/`immutable`, or has `proxy-revalidate`.
    ///
    /// Otherwise the first that applies, in order:
    /// 1. `s-maxage` (shared caches only);
    /// 2. `max-age`;
    /// 3. `Expires` relative to the server date (zero if unparseable);
    /// 4. a heuristic fraction of the time since `Last-Modified`;
    /// 5. the `immutable` minimum.
    fn max_age(&self) -> Duration {
        if !self.is_storable() || self.res_cc.contains_key("no-cache") {
            return Duration::from_secs(0);
        }
        if self.opts.shared
            && (self.res.contains_key("set-cookie")
                && !self.res_cc.contains_key("public")
                && !self.res_cc.contains_key("immutable"))
        {
            return Duration::from_secs(0);
        }
        if self.res.get_str("vary").map(str::trim) == Some("*") {
            return Duration::from_secs(0);
        }
        if self.opts.shared {
            if self.res_cc.contains_key("proxy-revalidate") {
                return Duration::from_secs(0);
            }
            if let Some(s_max) = self.res_cc.get("s-maxage").and_then(|v| v.as_ref()) {
                return Duration::from_secs(s_max.parse().unwrap_or(0));
            }
        }
        if let Some(max_age) = self.res_cc.get("max-age").and_then(|v| v.as_ref()) {
            return Duration::from_secs(max_age.parse().unwrap_or(0));
        }
        let default_min_ttl = if self.res_cc.contains_key("immutable") {
            self.opts.immutable_min_time_to_live
        } else {
            Duration::from_secs(0)
        };
        let server_date = self.raw_server_date();
        if let Some(expires) = self.res.get_str("expires") {
            return match DateTime::parse_from_rfc2822(expires) {
                // An invalid Expires means "already expired".
                Err(_) => Duration::from_secs(0),
                Ok(expires) => {
                    let expires = SystemTime::UNIX_EPOCH
                        + Duration::from_secs(expires.timestamp().max(0) as _);
                    default_min_ttl.max(expires.duration_since(server_date).unwrap_or_default())
                }
            };
        }
        if let Some(last_modified) = self.res.get_str("last-modified") {
            if let Ok(last_modified) = DateTime::parse_from_rfc2822(last_modified) {
                let last_modified = SystemTime::UNIX_EPOCH
                    + Duration::from_secs(last_modified.timestamp().max(0) as _);
                if let Ok(diff) = server_date.duration_since(last_modified) {
                    let secs_left = diff.as_secs() as f64 * self.opts.cache_heuristic as f64;
                    return default_min_ttl.max(Duration::from_secs(secs_left as _));
                }
            }
        }
        default_min_ttl
    }

    /// Remaining freshness at `now`; zero once the response is stale.
    pub fn time_to_live(&self, now: SystemTime) -> Duration {
        self.max_age()
            .checked_sub(self.age(now))
            .unwrap_or_default()
    }

    /// Returns `true` once the response's age at `now` has reached its freshness
    /// lifetime.
    pub fn is_stale(&self, now: SystemTime) -> bool {
        self.max_age() <= self.age(now)
    }

    /// Builds a conditional request to revalidate the stored response.
    ///
    /// It starts from `incoming_req`'s end-to-end headers and adds the stored
    /// `ETag`/`Last-Modified` as validators. Weak validators are dropped where they
    /// are not allowed. If the stored response isn't storable, all conditional
    /// headers are removed instead.
    fn revalidation_request<Req: RequestLike>(&self, incoming_req: &Req) -> http::request::Parts {
        let mut headers = Self::copy_without_hop_by_hop_headers(incoming_req.headers());
        // Range requests are not supported.
        headers.remove("if-range");
        if !self.is_storable() {
            headers.remove("if-none-match");
            headers.remove("if-modified-since");
            return self.request_from_headers(headers);
        }
        if let Some(etag) = self.res.get_str("etag") {
            let if_none = join(get_all_comma(headers.get_all("if-none-match")).chain(Some(etag)));
            headers.insert("if-none-match", HeaderValue::from_str(&if_none).unwrap());
        }
        // Weak validators are only allowed on simple (non-subrange) GET requests.
        let forbids_weak_validators = self.method != Method::GET
            || headers.contains_key("accept-ranges")
            || headers.contains_key("if-match")
            || headers.contains_key("if-unmodified-since");
        if forbids_weak_validators {
            headers.remove("if-modified-since");
            let etags = join(
                get_all_comma(headers.get_all("if-none-match"))
                    .filter(|etag| !etag.trim_start().starts_with("W/")),
            );
            if etags.is_empty() {
                headers.remove("if-none-match");
            } else {
                headers.insert("if-none-match", HeaderValue::from_str(&etags).unwrap());
            }
        } else if !headers.contains_key("if-modified-since") {
            if let Some(last_modified) = self.res.get_str("last-modified") {
                headers.insert(
                    "if-modified-since",
                    HeaderValue::from_str(last_modified).unwrap(),
                );
            }
        }
        self.request_from_headers(headers)
    }

    /// Request head with the stored method and URI and the given `headers`.
    fn request_from_headers(&self, headers: HeaderMap) -> http::request::Parts {
        let mut parts = Request::builder()
            .method(self.method.clone())
            .uri(self.uri.clone())
            .body(())
            .unwrap()
            .into_parts().0;
        parts.headers = headers;
        parts
    }

    /// Updates the policy with `response`, received for `request` at `response_time`.
    ///
    /// When `response` is a `304 Not Modified` whose validators match the stored
    /// response:
    /// - the validators compared are the `ETag` or, failing that, `Last-Modified`;
    /// - every stored header also present in the 304 takes all of the 304's values,
    ///   except the headers in `EXCLUDED_FROM_REVALIDATION_UPDATE`;
    /// - the stored status is kept;
    /// - [`AfterResponse::NotModified`] is returned.
    ///
    /// Otherwise `response` replaces the stored one and [`AfterResponse::Modified`]
    /// is returned. Both variants carry the new policy and the response head to serve.
    pub fn after_response<Req: RequestLike, Res: ResponseLike>(
        &self,
        request: &Req,
        response: &Res,
        response_time: SystemTime,
    ) -> AfterResponse {
        let response_headers = response.headers();
        // Validators of the stored response (`self.res`) vs. the new one.
        let old_etag = self.res.get_str("etag").map(str::trim);
        let old_last_modified = self.res.get_str("last-modified").map(str::trim);
        let new_etag = response_headers.get_str("etag").map(str::trim);
        let new_last_modified = response_headers.get_str("last-modified").map(str::trim);

        // Select whether the stored response is the one this 304 refers to
        // (RFC 7234 §4.3.4).
        let matches = if response.status() != StatusCode::NOT_MODIFIED {
            false
        } else if let Some(strong) = new_etag.filter(|etag| !etag.starts_with("W/")) {
            // A strong validator must equal the stored ETag (stored weakness ignored).
            old_etag.map(|e| e.trim_start_matches("W/")) == Some(strong)
        } else if let (Some(old), Some(new)) = (old_etag, new_etag) {
            // Weak comparison.
            old.trim_start_matches("W/") == new.trim_start_matches("W/")
        } else if old_last_modified.is_some() {
            old_last_modified == new_last_modified
        } else {
            // Neither side carries any validator.
            old_etag.is_none()
                && new_etag.is_none()
                && old_last_modified.is_none()
                && new_last_modified.is_none()
        };

        let (new_status, new_response_headers) = if matches {
            let mut merged = HeaderMap::with_capacity(self.res.keys_len());
            for name in self.res.keys() {
                let refreshed = response_headers.contains_key(name)
                    && !EXCLUDED_FROM_REVALIDATION_UPDATE.contains(&name.as_str());
                let source = if refreshed { response_headers } else { &self.res };
                // Copy every value so multi-valued headers survive the merge.
                for value in source.get_all(name) {
                    merged.append(name.clone(), value.clone());
                }
            }
            (self.status, merged)
        } else {
            (response.status(), response_headers.clone())
        };

        let new_policy = CachePolicy::from_details(
            request.uri(),
            request.method().clone(),
            new_status,
            request.headers().clone(),
            new_response_headers,
            response_time,
            self.opts,
        );
        let new_response = new_policy.cached_response(response_time);
        // `matches` already implies the response was a 304.
        if matches {
            AfterResponse::NotModified(new_policy, new_response)
        } else {
            AfterResponse::Modified(new_policy, new_response)
        }
    }
}
/// Outcome of [`CachePolicy::after_response`].
///
/// Both variants carry the updated policy and the response head to serve, as built
/// by the new policy.
pub enum AfterResponse {
    /// The response was `304 Not Modified` and its validators matched the stored
    /// response.
    ///
    /// The new policy keeps the stored status, with headers refreshed from the 304.
    NotModified(CachePolicy, http::response::Parts),
    /// Any other response; the new policy describes this response.
    Modified(CachePolicy, http::response::Parts),
}
/// Flattens all occurrences of a comma-separated header into trimmed items.
///
/// Values that are not valid visible-ASCII strings are skipped. Empty items (for
/// example from a trailing comma) are yielded as `""`.
fn get_all_comma<'a>(
    all: impl IntoIterator<Item = &'a HeaderValue>,
) -> impl Iterator<Item = &'a str> {
    all.into_iter()
        .filter_map(|header| header.to_str().ok())
        .flat_map(|text| text.split(','))
        .map(str::trim)
}
/// Convenience lookup of a header value as a string slice.
trait GetHeaderStr {
    /// Returns the first value of header `k` as `&str`.
    ///
    /// Returns `None` when the header is absent, or when its value is not a valid
    /// string.
    fn get_str(&self, k: &str) -> Option<&str>;
}

impl GetHeaderStr for HeaderMap {
    #[inline]
    fn get_str(&self, k: &str) -> Option<&str> {
        let value = self.get(k)?;
        value.to_str().ok()
    }
}
/// Joins header items with `", "`.
///
/// A separator is written only once the output is non-empty, so leading empty items
/// contribute nothing.
fn join<'a>(parts: impl Iterator<Item = &'a str>) -> String {
    parts.fold(String::new(), |mut acc, piece| {
        if !acc.is_empty() {
            acc.push_str(", ");
        }
        acc.push_str(piece);
        acc
    })
}
/// Outcome of [`CachePolicy::before_request`].
pub enum BeforeRequest {
    /// The stored response can be served without contacting the origin.
    ///
    /// Contains the response head to send, with freshly generated `Age` and `Date`
    /// headers.
    Fresh(http::response::Parts),
    /// The stored response cannot be served as-is.
    Stale {
        /// The request to send upstream. It may be a conditional revalidation request.
        request: http::request::Parts,
        /// Meant to report whether the incoming request matched the stored request.
        ///
        /// NOTE(review): confirm the exact conditions against
        /// `CachePolicy::before_request`, which sets this field.
        matches: bool,
    },
}
impl BeforeRequest {
pub fn satisfies_without_revalidation(&self) -> bool {
match self {
Self::Fresh(_) => true,
_ => false,
}
}
}
/// Request types a [`CachePolicy`] can be built from or checked against.
pub trait RequestLike {
    /// Returns an owned copy of the request URI.
    fn uri(&self) -> Uri;
    /// Returns `true` if the request URI equals `other`.
    fn is_same_uri(&self, other: &Uri) -> bool;
    /// The request method.
    fn method(&self) -> &Method;
    /// The request headers.
    fn headers(&self) -> &HeaderMap;
}
/// Response types a [`CachePolicy`] can be built from or updated with.
pub trait ResponseLike {
    /// The response status code.
    fn status(&self) -> StatusCode;
    /// The response headers.
    fn headers(&self) -> &HeaderMap;
}
// Each method below calls the inherent `Request` method of the same name. Inherent
// methods take precedence over trait methods during method resolution, so these do
// not recurse (the crate denies `unconditional_recursion`).
impl<Body> RequestLike for Request<Body> {
    fn uri(&self) -> Uri {
        // Inherent `Request::uri` returns `&Uri`; clone it to return an owned value.
        self.uri().clone()
    }
    fn is_same_uri(&self, other: &Uri) -> bool {
        self.uri() == other
    }
    fn method(&self) -> &Method {
        self.method()
    }
    fn headers(&self) -> &HeaderMap {
        self.headers()
    }
}
impl RequestLike for http::request::Parts {
fn uri(&self) -> Uri {
self.uri.clone()
}
fn is_same_uri(&self, other: &Uri) -> bool {
&self.uri == other
}
fn method(&self) -> &Method {
&self.method
}
fn headers(&self) -> &HeaderMap {
&self.headers
}
}
// These calls resolve to the inherent `Response` methods, which take precedence over
// the trait methods of the same name, so they do not recurse.
impl<Body> ResponseLike for Response<Body> {
    fn status(&self) -> StatusCode {
        self.status()
    }
    fn headers(&self) -> &HeaderMap {
        self.headers()
    }
}
/// [`ResponseLike`] for response heads split off an `http::Response`.
impl ResponseLike for http::response::Parts {
    fn status(&self) -> StatusCode {
        let code: StatusCode = self.status;
        code
    }
    fn headers(&self) -> &HeaderMap {
        let headers = &self.headers;
        headers
    }
}
// [`RequestLike`] for `reqwest::Request`, available with the `reqwest` feature.
// `method` and `headers` call reqwest's inherent methods, which take precedence over
// the trait methods of the same name, so they do not recurse.
#[cfg(feature = "reqwest")]
impl RequestLike for reqwest::Request {
    fn uri(&self) -> Uri {
        // Re-parses reqwest's URL string as an `http::Uri`; panics if the parser
        // rejects it.
        self.url().as_str().parse().expect("Uri and Url are incompatible!?")
    }
    fn is_same_uri(&self, other: &Uri) -> bool {
        // Compares the URL's string form against `other` using http's `str`/`Uri`
        // equality. NOTE(review): confirm how that equality normalizes components.
        self.url().as_str() == other
    }
    fn method(&self) -> &Method {
        self.method()
    }
    fn headers(&self) -> &HeaderMap {
        self.headers()
    }
}
// [`ResponseLike`] for `reqwest::Response`, available with the `reqwest` feature.
// Both methods call reqwest's inherent methods, which take precedence over the trait
// methods of the same name, so they do not recurse.
#[cfg(feature = "reqwest")]
impl ResponseLike for reqwest::Response {
    fn status(&self) -> StatusCode {
        self.status()
    }
    fn headers(&self) -> &HeaderMap {
        self.headers()
    }
}