#![warn(missing_docs)]
use chrono::prelude::*;
use http::HeaderMap;
use http::HeaderValue;
use http::Method;
use http::Request;
use http::Response;
use http::StatusCode;
use http::Uri;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::time::Duration;
use std::time::SystemTime;
/// Status codes that `is_storable` accepts even when the response carries no
/// `Expires` header and no `max-age`, `s-maxage` or `public` directive.
const STATUS_CODE_CACHEABLE_BY_DEFAULT: &[u16] =
&[200, 203, 204, 206, 300, 301, 404, 405, 410, 414, 501];
/// Status codes a response must have for `is_storable` to return `true`.
///
/// 206 is listed in `STATUS_CODE_CACHEABLE_BY_DEFAULT` but not here, so a
/// partial response is never reported as storable.
const UNDERSTOOD_STATUSES: &[u16] = &[
200, 203, 204, 300, 301, 302, 303, 307, 308, 404, 405, 410, 414, 501,
];
/// Lowercase header names that `copy_without_hop_by_hop_headers` never copies.
///
/// `date` is on the list too; `response_headers` writes its own `date` value
/// after the copy.
const HOP_BY_HOP_HEADERS: &[&str] = &[
"date",
"connection",
"keep-alive",
"proxy-authenticate",
"proxy-authorization",
"te",
"trailer",
"transfer-encoding",
"upgrade",
];
/// Lowercase header names that `revalidated_policy` exempts from its normal
/// rule for combining stored headers with a revalidation response.
const EXCLUDED_FROM_REVALIDATION_UPDATE: &[&str] = &[
"content-length",
"content-encoding",
"transfer-encoding",
"content-range",
];
/// Parsed `Cache-Control` directives: directive name mapped to its value, or
/// `None` when the directive was written without `=value`.
type CacheControl = HashMap<Box<str>, Option<Box<str>>>;
/// Parses every `Cache-Control` header value in `headers` into a directive map.
///
/// Each value is split on commas and each part on its first `=`. Header values
/// that cannot be read as a string (`to_str` fails) and parts with an empty
/// name are skipped.
///
/// Directive names are stored lowercased: they are case-insensitive on the
/// wire and every lookup in this file uses the lowercase spelling. Values have
/// surrounding whitespace and surrounding double quotes removed.
///
/// When a directive is repeated with a *different* value the header is
/// ambiguous; in that case `must-revalidate` is added to the returned map.
fn parse_cache_control<'a>(headers: impl IntoIterator<Item = &'a HeaderValue>) -> CacheControl {
    let mut cc = CacheControl::new();
    let mut is_valid = true;
    for h in headers.into_iter().filter_map(|v| v.to_str().ok()) {
        for part in h.split(',') {
            let mut kv = part.splitn(2, '=');
            // `splitn` always yields at least one item; an all-whitespace part
            // simply produces an empty name and is skipped below.
            let k = kv.next().unwrap_or("").trim();
            if k.is_empty() {
                continue;
            }
            // Unquote *before* comparing with an earlier occurrence, so that
            // `max-age="5", max-age=5` is not mistaken for a conflict.
            let v = kv.next().map(|v| v.trim().trim_matches('"'));
            match cc.entry(k.to_ascii_lowercase().into_boxed_str()) {
                Entry::Occupied(e) => {
                    if e.get().as_deref() != v {
                        is_valid = false;
                    }
                }
                Entry::Vacant(e) => {
                    e.insert(v.map(From::from));
                }
            }
        }
    }
    if !is_valid {
        // Conflicting duplicates: force revalidation rather than guess.
        cc.insert("must-revalidate".into(), None);
    }
    cc
}
/// Renders a directive map back into a `Cache-Control` header value.
///
/// Directives are joined with `", "` in the map's iteration order. A value is
/// written bare when it is non-empty and purely ASCII alphanumeric, and wrapped
/// in double quotes otherwise.
fn format_cache_control(cc: &CacheControl) -> String {
    let mut rendered = String::new();
    for (name, value) in cc.iter() {
        if !rendered.is_empty() {
            rendered.push_str(", ");
        }
        rendered.push_str(name);
        match value.as_deref() {
            None => {}
            Some(text) => {
                let bare = !text.is_empty() && text.bytes().all(|b| b.is_ascii_alphanumeric());
                rendered.push('=');
                if bare {
                    rendered.push_str(text);
                } else {
                    rendered.push('"');
                    rendered.push_str(text);
                    rendered.push('"');
                }
            }
        }
    }
    rendered
}
/// Settings that control how a `CachePolicy` interprets a response.
#[derive(Debug, Copy, Clone)]
#[cfg_attr(feature = "with_serde", derive(serde_derive::Serialize, serde_derive::Deserialize))]
pub struct CachePolicyOptions {
/// When `true`, shared-cache rules apply: `private` responses are not
/// storable, requests with `Authorization` need explicit permission from the
/// response, and `s-maxage` / `proxy-revalidate` are honored.
pub shared: bool,
/// Fraction of the time elapsed since `Last-Modified` that `max_age` uses as
/// the freshness lifetime when no explicit expiration is given.
pub cache_heuristic: f32,
/// Lower bound on the freshness lifetime of responses carrying the
/// `immutable` directive when the lifetime comes from `Expires` or from the
/// `Last-Modified` heuristic.
pub immutable_min_time_to_live: Duration,
/// When `true` and the response has both `pre-check` and `post-check`
/// directives, those two plus `no-cache`, `no-store` and `must-revalidate`
/// are dropped, along with the `Expires` and `Pragma` headers.
pub ignore_cargo_cult: bool,
/// When `true`, the response's `Date` header is used as its date (if it is
/// within 8 hours of `response_time`); otherwise `response_time` is used.
pub trust_server_date: bool,
/// Reference time for age calculations; time elapsed between this and `now`
/// is added to the response's age. Presumably the moment the response was
/// received (the `Default` impl uses the current time).
pub response_time: SystemTime,
}
impl Default for CachePolicyOptions {
fn default() -> Self {
Self {
shared: true,
cache_heuristic: 0.1,
immutable_min_time_to_live: Duration::from_secs(24 * 3600),
ignore_cargo_cult: false,
trust_server_date: true,
response_time: SystemTime::now(),
}
}
}
/// Caching-relevant details of one request/response pair together with the
/// options used to interpret them.
#[derive(Debug)]
#[cfg_attr(feature = "with_serde", derive(serde_derive::Serialize, serde_derive::Deserialize))]
pub struct CachePolicy {
/// Headers of the request the policy was created from.
#[cfg_attr(feature = "with_serde", serde(with = "http_serde::header_map"))]
req: HeaderMap,
/// Headers of the response; `from_details` may have rewritten them when
/// `ignore_cargo_cult` is enabled.
#[cfg_attr(feature = "with_serde", serde(with = "http_serde::header_map"))]
res: HeaderMap,
/// URI of the original request.
#[cfg_attr(feature = "with_serde", serde(with = "http_serde::uri"))]
uri: Uri,
/// Status code of the response.
#[cfg_attr(feature = "with_serde", serde(with = "http_serde::status_code"))]
status: StatusCode,
/// Method of the original request.
#[cfg_attr(feature = "with_serde", serde(with = "http_serde::method"))]
method: Method,
/// Options the policy was created with.
opts: CachePolicyOptions,
/// Parsed response `Cache-Control`, possibly with `no-cache` added from a
/// `Pragma` header when no `Cache-Control` header was present.
res_cc: CacheControl,
/// Parsed request `Cache-Control`.
req_cc: CacheControl,
}
impl CachePolicy {
/// Creates a policy from a request and the response it produced.
///
/// The URI, method, status and both header maps are cloned, so the policy
/// does not borrow from `req` or `res`.
#[inline]
pub fn new<Req: RequestLike, Res: ResponseLike>(
    req: &Req,
    res: &Res,
    opts: CachePolicyOptions,
) -> Self {
    Self::from_details(
        req.uri().clone(),
        req.method().clone(),
        res.status(),
        req.headers().clone(),
        res.headers().clone(),
        opts,
    )
}
/// Assembles a policy from already-extracted request and response parts.
///
/// Parses both `Cache-Control` headers. With `opts.ignore_cargo_cult` set and
/// both `pre-check` and `post-check` present in the response, those directives
/// plus `no-cache`, `no-store` and `must-revalidate` are removed, the response
/// `cache-control` header is rewritten from the remaining directives, and the
/// `expires` and `pragma` headers are dropped. Afterwards, a `pragma` header
/// containing `no-cache` counts as the `no-cache` directive when the response
/// has no `cache-control` header.
fn from_details(
    uri: Uri,
    method: Method,
    status: StatusCode,
    req: HeaderMap,
    mut res: HeaderMap,
    opts: CachePolicyOptions,
) -> Self {
    let mut res_cc = parse_cache_control(res.get_all("cache-control"));
    let req_cc = parse_cache_control(req.get_all("cache-control"));

    let is_cargo_cult = opts.ignore_cargo_cult
        && res_cc.contains_key("pre-check")
        && res_cc.contains_key("post-check");
    if is_cargo_cult {
        for directive in &[
            "pre-check",
            "post-check",
            "no-cache",
            "no-store",
            "must-revalidate",
        ] {
            res_cc.remove(*directive);
        }
        let rewritten = format_cache_control(&res_cc);
        res.insert("cache-control", HeaderValue::from_str(&rewritten).unwrap());
        for header in &["expires", "pragma"] {
            res.remove(*header);
        }
    }

    if !res.contains_key("cache-control") {
        let pragma_no_cache = match res.get_str("pragma") {
            Some(pragma) => pragma.contains("no-cache"),
            None => false,
        };
        if pragma_no_cache {
            res_cc.insert("no-cache".into(), None);
        }
    }

    Self {
        uri,
        method,
        status,
        req,
        res,
        opts,
        res_cc,
        req_cc,
    }
}
/// Returns `true` if the response may be stored in the cache at all.
///
/// All of the following must hold: neither side sent `no-store`; the method
/// is GET, HEAD, or POST with explicit expiration; the status is in
/// `UNDERSTOOD_STATUSES`; for a shared cache the response is not `private`
/// and, if the request carried `Authorization`, the response explicitly
/// allows storing; and the response has freshness information or a status
/// that is cacheable by default.
pub fn is_storable(&self) -> bool {
    if self.req_cc.contains_key("no-store") || self.res_cc.contains_key("no-store") {
        return false;
    }

    let method_ok = self.method == Method::GET
        || self.method == Method::HEAD
        || (self.method == Method::POST && self.has_explicit_expiration());
    if !method_ok {
        return false;
    }

    let code = self.status.as_u16();
    if !UNDERSTOOD_STATUSES.contains(&code) {
        return false;
    }

    if self.opts.shared {
        if self.res_cc.contains_key("private") {
            return false;
        }
        if self.req.contains_key("authorization") && !self.allows_storing_authenticated() {
            return false;
        }
    }

    self.res.contains_key("expires")
        || self.res_cc.contains_key("max-age")
        || (self.opts.shared && self.res_cc.contains_key("s-maxage"))
        || self.res_cc.contains_key("public")
        || STATUS_CODE_CACHEABLE_BY_DEFAULT.contains(&code)
}
/// Returns `true` if the response states its own lifetime via `max-age`, an
/// `Expires` header, or (shared caches only) `s-maxage`.
fn has_explicit_expiration(&self) -> bool {
    if self.res_cc.contains_key("max-age") || self.res.contains_key("expires") {
        return true;
    }
    self.opts.shared && self.res_cc.contains_key("s-maxage")
}
/// Returns `true` if the stored response can answer `req` at time `now`
/// without contacting the origin server.
///
/// The request is refused when it sends `no-cache` (as a directive or via
/// `Pragma`), when its `max-age` or `min-fresh` constraints are not met, when
/// the response is stale and the request does not permit that staleness, or
/// when the request does not match the stored one (URI, host, method, `Vary`).
///
/// A stale response is only acceptable if the response lacks
/// `must-revalidate` and the request carries `max-stale`: without a value any
/// staleness is accepted, with a numeric value the staleness must be below it.
/// A `max-stale` value that is not a number permits nothing.
pub fn satisfies_without_revalidation<Req: RequestLike>(
    &self,
    req: &Req,
    now: SystemTime,
) -> bool {
    let req_headers = req.headers();
    let req_cc = parse_cache_control(req_headers.get_all("cache-control"));
    let pragma_no_cache = req_headers
        .get_str("pragma")
        .map_or(false, |v| v.contains("no-cache"));
    if req_cc.contains_key("no-cache") || pragma_no_cache {
        return false;
    }

    // Numeric value of a request directive; `None` if absent, bare or malformed.
    let directive_secs = |name: &str| -> Option<u64> {
        req_cc
            .get(name)
            .and_then(|v| v.as_deref())
            .and_then(|s| s.parse().ok())
    };

    if let Some(max_age) = directive_secs("max-age") {
        if self.age(now) > Duration::from_secs(max_age) {
            return false;
        }
    }
    if let Some(min_fresh) = directive_secs("min-fresh") {
        if self.time_to_live(now) < Duration::from_secs(min_fresh) {
            return false;
        }
    }

    if self.is_stale(now) {
        if self.res_cc.contains_key("must-revalidate") {
            return false;
        }
        // Checked subtraction: never panic even if age and lifetime disagree.
        let staleness = self
            .age(now)
            .checked_sub(self.max_age())
            .unwrap_or_default();
        let allows_stale = match req_cc.get("max-stale") {
            None => false,
            // Bare `max-stale`: the client accepts any staleness.
            Some(None) => true,
            Some(Some(limit)) => match limit.parse::<u64>() {
                Ok(secs) => Duration::from_secs(secs) > staleness,
                // A malformed limit must not be read as "unlimited".
                Err(_) => false,
            },
        };
        if !allows_stale {
            return false;
        }
    }

    self.request_matches(req, false)
}
/// Checks whether `req` targets the same resource as the stored request.
///
/// URI and `host` header must be equal, the method must be the stored method
/// (or HEAD when `allow_head_method` is set), and every header named by the
/// stored response's `Vary` must match.
fn request_matches<Req: RequestLike>(&self, req: &Req, allow_head_method: bool) -> bool {
    if self.uri != *req.uri() {
        return false;
    }
    if self.req.get("host") != req.headers().get("host") {
        return false;
    }
    let requested = req.method();
    let method_ok =
        self.method == *requested || (allow_head_method && *requested == Method::HEAD);
    method_ok && self.vary_matches(req)
}
/// Returns `true` if the response carries `must-revalidate`, `public` or
/// `s-maxage`, the directives `is_storable` requires before a shared cache may
/// store a response to a request with `Authorization`.
fn allows_storing_authenticated(&self) -> bool {
    ["must-revalidate", "public", "s-maxage"]
        .iter()
        .any(|directive| self.res_cc.contains_key(*directive))
}
/// Returns `true` if every header listed in the stored response's `Vary` has
/// the same value in `req` as in the stored request. `Vary: *` never matches.
fn vary_matches<Req: RequestLike>(&self, req: &Req) -> bool {
    let presented = req.headers();
    get_all_comma(self.res.get_all("vary")).all(|field| {
        if field == "*" {
            return false;
        }
        let field = field.trim().to_ascii_lowercase();
        presented.get(&field) == self.req.get(&field)
    })
}
/// Copies `in_headers`, leaving out everything that must not be forwarded.
///
/// Dropped are the names in `HOP_BY_HOP_HEADERS`, every header named in a
/// `Connection` header, and every `Warning` value that starts with `1`. The
/// surviving warnings are folded into a single comma-separated `warning`
/// header. Repeated headers (e.g. several `Set-Cookie` lines) keep all of
/// their values.
fn copy_without_hop_by_hop_headers(in_headers: &HeaderMap) -> HeaderMap {
    let mut headers = HeaderMap::with_capacity(in_headers.len());
    for (name, value) in in_headers {
        if !HOP_BY_HOP_HEADERS.contains(&name.as_str()) {
            // `append`, not `insert`: iteration yields one pair per value and
            // `insert` would discard all but the last value of a repeated header.
            headers.append(name.clone(), value.clone());
        }
    }
    for name in get_all_comma(in_headers.get_all("connection")) {
        headers.remove(name);
    }
    let new_warnings = join(
        get_all_comma(in_headers.get_all("warning"))
            .filter(|warning| !warning.trim_start().starts_with('1')),
    );
    if new_warnings.is_empty() {
        headers.remove("warning");
    } else {
        // `insert` replaces the individually copied warning values with the
        // filtered, joined list.
        headers.insert(
            "warning",
            HeaderValue::from_str(&new_warnings)
                .expect("joined from values that were already valid header strings"),
        );
    }
    headers
}
/// Builds the headers to send when serving the stored response at `now`.
///
/// Starts from the stored headers without hop-by-hop headers, appends warning
/// 113 when the response is older than a day, has no explicit expiration and
/// a lifetime above a day, and sets `age` and `date`.
///
/// `date` is written as an HTTP IMF-fixdate (`Sun, 06 Nov 1994 08:49:37 GMT`).
/// A `now` before the Unix epoch is treated as the epoch instead of
/// panicking; if the timestamp cannot be represented as a calendar date, no
/// `date` header is emitted.
pub fn response_headers(&self, now: SystemTime) -> HeaderMap {
    let mut headers = Self::copy_without_hop_by_hop_headers(&self.res);
    let age = self.age(now);
    let day = Duration::from_secs(3600 * 24);
    if age > day && !self.has_explicit_expiration() && self.max_age() > day {
        headers.append(
            "warning",
            HeaderValue::from_static(r#"113 - "rfc7234 5.5.4""#),
        );
    }

    // Full seconds, no narrowing cast: a wrapped value would understate the age.
    headers.insert(
        "age",
        HeaderValue::from_str(&age.as_secs().to_string())
            .expect("decimal digits are a valid header value"),
    );

    let timestamp = now
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    if let Some(date) = NaiveDateTime::from_timestamp_opt(timestamp as i64, 0) {
        // The timestamp is seconds since the epoch, i.e. already UTC.
        let formatted = date.format("%a, %d %b %Y %H:%M:%S GMT").to_string();
        headers.insert(
            "date",
            HeaderValue::from_str(&formatted).expect("formatted date is plain ASCII"),
        );
    }
    headers
}
/// Date attributed to the response: the server-provided date when
/// `trust_server_date` is set, otherwise `response_time`.
fn date(&self) -> SystemTime {
    if !self.opts.trust_server_date {
        return self.opts.response_time;
    }
    self.server_date()
}
/// Returns the response's `Date` header as a `SystemTime`, provided it parses
/// and differs from `response_time` by less than 8 hours; otherwise returns
/// `response_time`.
fn server_date(&self) -> SystemTime {
    let received = self.opts.response_time;

    let header_date = match self
        .res
        .get_str("date")
        .map(|raw| DateTime::parse_from_rfc2822(raw))
    {
        Some(Ok(parsed)) => {
            SystemTime::UNIX_EPOCH.checked_add(Duration::from_secs(parsed.timestamp() as u64))
        }
        _ => None,
    };
    let claimed = match header_date {
        Some(claimed) => claimed,
        None => return received,
    };

    // Absolute difference between the two clocks, whichever one is ahead.
    let drift = match received.duration_since(claimed) {
        Ok(behind) => behind,
        Err(ahead) => ahead.duration(),
    };
    if drift < Duration::from_secs(8 * 3600) {
        claimed
    } else {
        received
    }
}
/// Age of the response at `now`.
///
/// This is the larger of the `Age` header and the gap between the response's
/// date and `response_time`, plus the time elapsed from `response_time` to
/// `now`. Gaps that would be negative count as zero.
pub fn age(&self, now: SystemTime) -> Duration {
    let received = self.opts.response_time;
    let reported = self.age_header();
    let apparent = received.duration_since(self.date()).unwrap_or_default();
    let resident = now.duration_since(received).unwrap_or_default();
    reported.max(apparent) + resident
}
/// Value of the response's `Age` header in seconds, or zero when it is
/// missing or not a non-negative integer.
fn age_header(&self) -> Duration {
    let secs = match self.res.get_str("age") {
        Some(raw) => raw.parse::<u64>().unwrap_or(0),
        None => 0,
    };
    Duration::from_secs(secs)
}
/// Freshness lifetime of the response; zero means "always stale".
///
/// Zero is returned for responses that are not storable, carry `no-cache`,
/// have `Vary: *`, or (shared caches) set a cookie without `public` or
/// `immutable`, or carry `proxy-revalidate`. Otherwise the first available
/// source wins: `s-maxage` (shared caches), `max-age`, `Expires` relative to
/// the server date, then `cache_heuristic` times the time since
/// `Last-Modified`. The last two are raised to `immutable_min_time_to_live`
/// for `immutable` responses; that minimum is also the fallback.
pub fn max_age(&self) -> Duration {
    let zero = Duration::from_secs(0);
    let shared = self.opts.shared;
    let has = |directive: &str| self.res_cc.contains_key(directive);

    if !self.is_storable() || has("no-cache") {
        return zero;
    }
    if shared && self.res.contains_key("set-cookie") && !has("public") && !has("immutable") {
        return zero;
    }
    if self.res.get_str("vary").map(str::trim) == Some("*") {
        return zero;
    }

    if shared {
        if has("proxy-revalidate") {
            return zero;
        }
        if let Some(raw) = self.res_cc.get("s-maxage").and_then(|v| v.as_deref()) {
            return Duration::from_secs(raw.parse().unwrap_or(0));
        }
    }
    if let Some(raw) = self.res_cc.get("max-age").and_then(|v| v.as_deref()) {
        return Duration::from_secs(raw.parse().unwrap_or(0));
    }

    let floor = if has("immutable") {
        self.opts.immutable_min_time_to_live
    } else {
        zero
    };
    let server_date = self.server_date();

    if let Some(raw) = self.res.get_str("expires") {
        return match DateTime::parse_from_rfc2822(raw) {
            Ok(parsed) => {
                let expires_at = SystemTime::UNIX_EPOCH
                    + Duration::from_secs(parsed.timestamp().max(0) as u64);
                let remaining = expires_at.duration_since(server_date).unwrap_or_default();
                floor.max(remaining)
            }
            // An unparseable `Expires` means "already expired".
            Err(_) => zero,
        };
    }

    let heuristic = self
        .res
        .get_str("last-modified")
        .and_then(|raw| DateTime::parse_from_rfc2822(raw).ok())
        .map(|parsed| {
            SystemTime::UNIX_EPOCH + Duration::from_secs(parsed.timestamp().max(0) as u64)
        })
        .and_then(|modified_at| server_date.duration_since(modified_at).ok())
        .map(|elapsed| {
            let secs = elapsed.as_secs() as f64 * self.opts.cache_heuristic as f64;
            Duration::from_secs(secs as u64)
        });
    match heuristic {
        Some(lifetime) => floor.max(lifetime),
        None => floor,
    }
}
/// Remaining freshness at `now`: `max_age` minus `age`, or zero once the age
/// has reached the lifetime.
pub fn time_to_live(&self, now: SystemTime) -> Duration {
    let lifetime = self.max_age();
    let current_age = self.age(now);
    if lifetime > current_age {
        lifetime - current_age
    } else {
        Duration::from_secs(0)
    }
}
/// Returns `true` once the response's age at `now` has reached or exceeded
/// its freshness lifetime.
pub fn is_stale(&self, now: SystemTime) -> bool {
    let lifetime = self.max_age();
    let current_age = self.age(now);
    current_age >= lifetime
}
/// Builds the request headers for revalidating the stored response.
///
/// Starts from `incoming_req`'s headers without hop-by-hop headers and without
/// `if-range`. If the request does not match the stored one, or the stored
/// response is not storable, the conditional headers are removed and nothing
/// is added. Otherwise the stored `etag` is appended to `if-none-match`, and
/// the stored `last-modified` is used as `if-modified-since` when the request
/// has none. For non-GET methods, or when `accept-ranges`, `if-match` or
/// `if-unmodified-since` is present, `if-modified-since` and weak (`W/`) ETags
/// are stripped instead.
pub fn revalidation_headers<Req: RequestLike>(&self, incoming_req: &Req) -> HeaderMap {
    let mut headers = Self::copy_without_hop_by_hop_headers(incoming_req.headers());
    headers.remove("if-range");

    let usable = self.request_matches(incoming_req, true) && self.is_storable();
    if !usable {
        // Validators of a different or unstorable response must not be sent.
        headers.remove("if-none-match");
        headers.remove("if-modified-since");
        return headers;
    }

    if let Some(etag) = self.res.get_str("etag") {
        let combined = join(
            get_all_comma(headers.get_all("if-none-match")).chain(std::iter::once(etag)),
        );
        headers.insert("if-none-match", HeaderValue::from_str(&combined).unwrap());
    }

    let weak_forbidden = self.method != Method::GET
        || ["accept-ranges", "if-match", "if-unmodified-since"]
            .iter()
            .any(|name| headers.contains_key(*name));

    if weak_forbidden {
        headers.remove("if-modified-since");
        let strong = join(
            get_all_comma(headers.get_all("if-none-match"))
                .filter(|tag| !tag.trim_start().starts_with("W/")),
        );
        if strong.is_empty() {
            headers.remove("if-none-match");
        } else {
            headers.insert("if-none-match", HeaderValue::from_str(&strong).unwrap());
        }
    } else if !headers.contains_key("if-modified-since") {
        if let Some(last_modified) = self.res.get_str("last-modified") {
            headers.insert(
                "if-modified-since",
                HeaderValue::from_str(last_modified).unwrap(),
            );
        }
    }
    headers
}
/// Combines the stored response with the origin's answer to a revalidation
/// request and returns the policy for the result.
///
/// The answer matches the stored response only if its status is 304 and its
/// validators agree with the stored ones: a strong `etag` must equal the
/// stored ETag (ignoring a stored `W/` prefix), weak ETags are compared
/// without `W/`, otherwise `last-modified` values are compared; if neither
/// side has any validator the answer matches as well.
///
/// On a match `response` is updated in place so that it describes the stored
/// response refreshed by the 304: headers sent with the 304 replace the stored
/// ones, stored headers missing from the 304 are carried over, the headers in
/// `EXCLUDED_FROM_REVALIDATION_UPDATE` always come from the stored response,
/// and the status becomes the stored status.
///
/// In the returned value `modified` is `true` when the status was not 304,
/// and `matches` reports whether the update above took place.
pub fn revalidated_policy<ReqB, ResB>(
    &self,
    request: &Request<ReqB>,
    response: &mut Response<ResB>,
) -> RevalidatedPolicy {
    let modified = response.status() != StatusCode::NOT_MODIFIED;

    let matches = if modified {
        false
    } else {
        // "Old" validators belong to the stored response, "new" ones to the 304.
        let old_etag = self.res.get_str("etag").map(str::trim);
        let old_last_modified = self.res.get_str("last-modified").map(str::trim);
        let new_etag = response.headers().get_str("etag").map(str::trim);
        let new_last_modified = response.headers().get_str("last-modified").map(str::trim);

        match (old_etag, new_etag) {
            // Strong validator in the 304: the stored one must be identical.
            (old, Some(new)) if !new.starts_with("W/") => {
                old.map(|etag| etag.trim_start_matches("W/")) == Some(new)
            }
            // Weak validator in the 304: compare ignoring the weakness marker.
            (Some(old), Some(new)) => {
                old.trim_start_matches("W/") == new.trim_start_matches("W/")
            }
            _ => match old_last_modified {
                Some(_) => old_last_modified == new_last_modified,
                // No validator on either side: the single stored response is
                // the one being refreshed.
                None => old_etag.is_none() && new_etag.is_none() && new_last_modified.is_none(),
            },
        }
    };

    if matches {
        let merged = response.headers_mut();
        // These describe the stored body; a 304 has no body of its own, so its
        // values for them must not leak into the merged response.
        for name in EXCLUDED_FROM_REVALIDATION_UPDATE {
            merged.remove(*name);
        }
        // Carry over every stored header the 304 did not supply (including the
        // excluded ones just removed), keeping all values of repeated headers.
        for name in self.res.keys() {
            if !merged.contains_key(name) {
                for value in self.res.get_all(name) {
                    merged.append(name.clone(), value.clone());
                }
            }
        }
        *response.status_mut() = self.status;
    }

    RevalidatedPolicy {
        policy: CachePolicy::new(request, &*response, self.opts),
        modified,
        matches,
    }
}
}
/// Result of `CachePolicy::revalidated_policy`.
pub struct RevalidatedPolicy {
/// Policy built from the revalidation request and the response as it stands
/// after `revalidated_policy` has processed it.
pub policy: CachePolicy,
/// `true` when the revalidation response's status was not 304 Not Modified.
pub modified: bool,
/// `true` when the revalidation response was found to correspond to the
/// stored response.
pub matches: bool,
}
/// Iterates over the comma-separated items of all given header values, each
/// trimmed of surrounding whitespace. Values that cannot be read as a string
/// are skipped; empty items are yielded as empty strings.
fn get_all_comma<'a>(
    all: impl IntoIterator<Item = &'a HeaderValue>,
) -> impl Iterator<Item = &'a str> {
    all.into_iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|text| text.split(','))
        .map(str::trim)
}
/// Convenience lookup of a header value as a string slice.
trait GetHeaderStr {
/// Returns the value stored under `k` as `&str`, or `None` if there is no
/// such value or it cannot be represented as a string.
fn get_str(&self, k: &str) -> Option<&str>;
}
impl GetHeaderStr for HeaderMap {
    /// Looks up `k` with `HeaderMap::get` and converts the value with
    /// `to_str`; a missing header or a failed conversion yields `None`.
    #[inline]
    fn get_str(&self, k: &str) -> Option<&str> {
        let value = self.get(k)?;
        value.to_str().ok()
    }
}
/// Concatenates `parts`, placing `", "` before every part that is added while
/// the output is already non-empty.
///
/// Note the consequence: leading empty parts leave no separator behind, while
/// empty parts in the middle do.
fn join<'a>(parts: impl Iterator<Item = &'a str>) -> String {
    parts.fold(String::new(), |mut joined, piece| {
        if !joined.is_empty() {
            joined.push_str(", ");
        }
        joined.push_str(piece);
        joined
    })
}
/// The parts of a request that `CachePolicy` reads.
pub trait RequestLike {
/// The request URI.
fn uri(&self) -> &Uri;
/// The request method.
fn method(&self) -> &Method;
/// The request headers.
fn headers(&self) -> &HeaderMap;
}
/// The parts of a response that `CachePolicy` reads.
pub trait ResponseLike {
/// The response status code.
fn status(&self) -> StatusCode;
/// The response headers.
fn headers(&self) -> &HeaderMap;
}
impl<Body> RequestLike for Request<Body> {
fn uri(&self) -> &Uri {
self.uri()
}
fn method(&self) -> &Method {
self.method()
}
fn headers(&self) -> &HeaderMap {
self.headers()
}
}
impl<Body> ResponseLike for Response<Body> {
fn status(&self) -> StatusCode {
self.status()
}
fn headers(&self) -> &HeaderMap {
self.headers()
}
}