use std::time::{Duration, SystemTime};
use http::header::HeaderValue;
use crate::rkyvutil::OwnedArchive;
use self::control::CacheControl;
mod control;
/// Settings that influence how the caching rules in this module are applied.
///
/// The derived `Default` leaves `shared` as `false`.
#[derive(
Clone,
Debug,
Default,
rkyv::Archive,
rkyv::Deserialize,
rkyv::Portable,
rkyv::Serialize,
bytecheck::CheckBytes,
)]
#[rkyv(as = Self)]
#[repr(C)]
struct CacheConfig {
/// Whether the cache behaves as a shared cache. When set, responses marked
/// `private` are not storable and the `s-maxage` directive is consulted.
shared: bool,
}
/// Holds the request-side information needed to produce a [`CachePolicy`]
/// once the matching response is available.
#[derive(Debug)]
pub struct CachePolicyBuilder {
/// Cache configuration that is carried over into the built policy.
config: CacheConfig,
/// Summary of the request captured when the builder was created.
request: Request,
/// Full copy of the request headers; read at build time to record the
/// values of the headers named by the response's `Vary` header.
request_headers: http::HeaderMap,
}
impl CachePolicyBuilder {
    /// Starts a builder for `request` using the default cache configuration.
    ///
    /// The request headers are cloned so that they remain available when the
    /// response arrives.
    pub fn new(request: &reqwest::Request) -> Self {
        Self {
            config: CacheConfig::default(),
            request_headers: request.headers().clone(),
            request: Request::from(request),
        }
    }

    /// Consumes the builder and combines it with `response` into a
    /// [`CachePolicy`].
    ///
    /// The `Vary` information is derived from the saved request headers and
    /// the response headers before the response itself is summarized.
    pub fn build(self, response: &reqwest::Response) -> CachePolicy {
        let Self {
            config,
            request,
            request_headers,
        } = self;
        let vary = Vary::from_request_response_headers(&request_headers, response.headers());
        let response = Response::from(response);
        CachePolicy {
            config,
            request,
            response,
            vary,
        }
    }
}
/// Everything recorded about one request/response exchange that is needed to
/// decide later whether the stored response may be reused.
#[derive(Debug, rkyv::Archive, rkyv::Deserialize, rkyv::Serialize)]
#[rkyv(derive(Debug))]
pub struct CachePolicy {
/// Cache configuration in effect for this entry.
config: CacheConfig,
/// Summary of the request that produced the response.
request: Request,
/// Summary of the response.
response: Response,
/// Request header values nominated by the response's `Vary` header.
vary: Vary,
}
impl CachePolicy {
    /// Converts this policy into its owned archived form.
    ///
    /// # Panics
    ///
    /// Panics if `OwnedArchive::from_unarchived` reports an error.
    pub fn to_archived(&self) -> OwnedArchive<Self> {
        let archived = OwnedArchive::from_unarchived(self);
        archived.expect("all possible values can be archived")
    }
}
impl ArchivedCachePolicy {
/// Decides how the stored response may be used to answer `request`.
///
/// * [`BeforeRequest::NoMatch`] is returned when the stored response was never
///   storable, when the URL differs, or when the method is neither `GET` nor
///   `HEAD`.
/// * [`BeforeRequest::Fresh`] is returned when the `Vary` fields match, the
///   stored response carries no `no-cache` directive and it is still fresh.
/// * Otherwise revalidation headers are written into `request` and
///   [`BeforeRequest::Stale`] is returned with a builder for the new policy.
pub fn before_request(&self, request: &mut reqwest::Request) -> BeforeRequest {
    // One clock reading is taken up front and used for the freshness check.
    let now = SystemTime::now();

    if !self.is_storable() {
        tracing::trace!(
            "Request {} does not match cache request {} because it isn't storable",
            request.url(),
            self.request.uri,
        );
        return BeforeRequest::NoMatch;
    }

    let same_url = self.request.uri == request.url().as_str();
    if !same_url {
        tracing::trace!(
            "Request {} does not match cache URL of {}",
            request.url(),
            self.request.uri,
        );
        return BeforeRequest::NoMatch;
    }

    let supported_method =
        request.method() == http::Method::GET || request.method() == http::Method::HEAD;
    if !supported_method {
        tracing::trace!(
            "Method {:?} for request {} is not supported by this cache",
            request.method(),
            request.url(),
        );
        return BeforeRequest::NoMatch;
    }

    let vary_matches = self.vary.matches(request.headers());
    if !vary_matches {
        tracing::trace!(
            "Request {} does not match cached request because of the 'Vary' header",
            request.url(),
        );
    }

    // Short-circuiting keeps the freshness check from running when the
    // `Vary` fields differ or the response demands revalidation.
    let reusable = vary_matches
        && !self.response.headers.cc.no_cache
        && self.is_fresh(now, request);
    if reusable {
        return BeforeRequest::Fresh;
    }

    self.set_revalidation_headers(request);
    BeforeRequest::Stale(self.new_cache_policy_builder(request))
}
/// Classifies a revalidation `response` against the stored response.
///
/// A new policy is always built from `cache_policy_builder` and `response`.
/// When the resource is judged unmodified, the new policy's status is
/// replaced by the stored response's status before it is returned.
pub fn after_response(
    &self,
    cache_policy_builder: CachePolicyBuilder,
    response: &reqwest::Response,
) -> AfterResponse {
    let mut new_policy = cache_policy_builder.build(response);
    if !self.is_modified(&new_policy) {
        // Keep the status code of the stored response for the refreshed entry.
        new_policy.response.status = self.response.status.into();
        return AfterResponse::NotModified(new_policy);
    }
    AfterResponse::Modified(new_policy)
}
/// Reports whether `new_policy` describes a resource that differs from the
/// stored one.
///
/// Any status other than 304 counts as modified. For a 304, the resource is
/// unmodified when both sides carry equal strong etags, or equal
/// last-modified timestamps, or when neither side has any validator at all.
/// Every other combination is treated as modified.
fn is_modified(&self, new_policy: &CachePolicy) -> bool {
    let old_headers = &self.response.headers;
    let new_headers = &new_policy.response.headers;

    if new_policy.response.status != 304 {
        tracing::trace!(
            "Resource is modified because status is {:?} and not 304",
            new_policy.response.status
        );
        return true;
    }

    if let (Some(old_etag), Some(new_etag)) =
        (old_headers.etag.as_ref(), new_headers.etag.as_ref())
    {
        // Only strong validators on both sides are compared.
        let both_strong = !old_etag.weak && !new_etag.weak;
        if both_strong && old_etag.value == new_etag.value {
            tracing::trace!(
                "Resource is not modified because old and new etag values ({:?}) match",
                new_etag.value,
            );
            return false;
        }
    }

    if let (Some(old_last_modified), Some(new_last_modified)) = (
        old_headers.last_modified_unix_timestamp.as_ref(),
        new_headers.last_modified_unix_timestamp.as_ref(),
    ) {
        if old_last_modified == new_last_modified {
            tracing::trace!(
                "Resource is not modified because modified times ({new_last_modified:?}) match",
            );
            return false;
        }
    }

    let no_validators = old_headers.etag.is_none()
        && new_headers.etag.is_none()
        && old_headers.last_modified_unix_timestamp.is_none()
        && new_headers.last_modified_unix_timestamp.is_none();
    if no_validators {
        tracing::trace!(
            "Resource is not modified because there are no etags or last modified \
             timestamps, so we assume the 304 status is correct",
        );
        return false;
    }

    true
}
/// Adds conditional-request headers to `request` based on the validators of
/// the stored response.
///
/// A strong etag is appended as `if-none-match`. The stored last-modified
/// time is written as `if-modified-since` unless the request already carries
/// that header. Values that cannot be turned into a header are skipped.
fn set_revalidation_headers(&self, request: &mut reqwest::Request) {
    // Weak etags are never sent.
    let strong_etag = self
        .response
        .headers
        .etag
        .as_ref()
        .filter(|etag| !etag.weak);
    if let Some(header) = strong_etag.and_then(|etag| HeaderValue::from_bytes(&etag.value).ok())
    {
        request.headers_mut().append("if-none-match", header);
    }

    // A caller-provided `if-modified-since` is left untouched.
    if request.headers().contains_key("if-modified-since") {
        return;
    }
    let Some(&last_modified_unix_timestamp) =
        self.response.headers.last_modified_unix_timestamp.as_ref()
    else {
        return;
    };
    if let Some(last_modified) = unix_timestamp_to_header(last_modified_unix_timestamp.into()) {
        request
            .headers_mut()
            .insert("if-modified-since", last_modified);
    }
}
/// Reports whether the recorded response was permitted to be stored.
///
/// The checks run in order. The response is rejected when the request method
/// is neither `GET` nor `HEAD`, when the status is below 200 or is 206/304,
/// when either side carries `no-store`, and (for a shared cache) when the
/// response is `private` or the request had an `Authorization` header that
/// the response does not explicitly permit caching for. After that, the
/// response is accepted as soon as one positive signal is found: `public`,
/// `private` (non-shared cache only), an `Expires` header, `max-age`,
/// `s-maxage` (shared cache only) or a heuristically cacheable status code.
/// Each decision is reported through a `tracing::trace!` message.
pub fn is_storable(&self) -> bool {
    // Status codes that may be stored without explicit freshness
    // information. 206 is listed here but is already rejected above the
    // point where this list is consulted.
    const HEURISTICALLY_CACHEABLE_STATUS_CODES: &[u16] =
        &[200, 203, 204, 206, 300, 301, 308, 404, 405, 410, 414, 501];

    if !matches!(
        self.request.method,
        ArchivedMethod::Get | ArchivedMethod::Head
    ) {
        tracing::trace!(
            "Response from {} is not storable because of the request method {:?}",
            self.request.uri,
            self.request.method
        );
        return false;
    }
    if !self.response.has_final_status() {
        tracing::trace!(
            "Response from {} is not storable because it has \
             a non-final status code {:?}",
            self.request.uri,
            self.response.status,
        );
        return false;
    }
    if self.response.status == 206 || self.response.status == 304 {
        tracing::trace!(
            "Response from {} is not storable because it has \
             an unsupported status code {:?}",
            self.request.uri,
            self.response.status,
        );
        return false;
    }
    if self.request.headers.cc.no_store {
        tracing::trace!(
            "Response from {} is not storable because its request has \
             a 'no-store' cache-control directive",
            self.request.uri,
        );
        return false;
    }
    if self.response.headers.cc.no_store {
        tracing::trace!(
            "Response from {} is not storable because it has \
             a 'no-store' cache-control directive",
            self.request.uri,
        );
        return false;
    }
    if self.config.shared {
        if self.response.headers.cc.private {
            tracing::trace!(
                "Response from {} is not storable because this is a shared \
                 cache and has a 'private' cache-control directive",
                self.request.uri,
            );
            return false;
        }
        if self.request.headers.authorization && !self.allows_authorization_storage() {
            // This branch is taken when the response did NOT opt in to
            // storage of authorized requests, so the message says so.
            tracing::trace!(
                "Response from {} is not storable because this is a shared \
                 cache and the request has an 'Authorization' header set and \
                 the response has not indicated that caching requests with an \
                 'Authorization' header is allowed",
                self.request.uri,
            );
            return false;
        }
    }
    if self.response.headers.cc.public {
        tracing::trace!(
            "Response from {} is storable because it has \
             a 'public' cache-control directive",
            self.request.uri,
        );
        return true;
    }
    if !self.config.shared && self.response.headers.cc.private {
        // Only reachable for a non-shared cache.
        tracing::trace!(
            "Response from {} is storable because this is not a shared cache \
             and has a 'private' cache-control directive",
            self.request.uri,
        );
        return true;
    }
    if self.response.headers.expires_unix_timestamp.is_some() {
        tracing::trace!(
            "Response from {} is storable because it has an \
             'Expires' header set",
            self.request.uri,
        );
        return true;
    }
    if self.response.headers.cc.max_age_seconds.is_some() {
        tracing::trace!(
            "Response from {} is storable because it has an \
             'max-age' cache-control directive",
            self.request.uri,
        );
        return true;
    }
    if self.config.shared && self.response.headers.cc.s_maxage_seconds.is_some() {
        tracing::trace!(
            "Response from {} is storable because this is a shared cache \
             and has a 's-maxage' cache-control directive",
            self.request.uri,
        );
        return true;
    }
    if HEURISTICALLY_CACHEABLE_STATUS_CODES.contains(&self.response.status.into()) {
        tracing::trace!(
            "Response from {} is storable because it has a \
             heuristically cacheable status code {:?}",
            self.request.uri,
            self.response.status,
        );
        return true;
    }
    tracing::trace!(
        "Response from {} is not storable because it does not meet any \
         of the necessary criteria (e.g., it doesn't have an 'Expires' \
         header set or a 'max-age' cache-control directive)",
        self.request.uri,
    );
    false
}
/// Reports whether the response carries a directive (`must-revalidate`,
/// `public` or `s-maxage`) that this module accepts as permission to store a
/// response to a request that had an `Authorization` header.
fn allows_authorization_storage(&self) -> bool {
    let cc = &self.response.headers.cc;
    cc.must_revalidate || cc.public || cc.s_maxage_seconds.is_some()
}
/// Reports whether the stored response may be served for `request` without
/// contacting the server, given the current time `now`.
///
/// Unless the stored response is marked `immutable`, the request's own
/// `cache-control` directives are honored first: `no-cache` always forces a
/// miss, `max-age` bounds the acceptable age, and `min-fresh` requires the
/// remaining freshness (lifetime minus current age) to be at least the given
/// number of seconds. After that, a response older than its freshness
/// lifetime is only accepted when `allows_stale` permits it.
fn is_fresh(&self, now: SystemTime, request: &reqwest::Request) -> bool {
    let freshness_lifetime = self.freshness_lifetime().as_secs();
    let age = self.age(now).as_secs();
    if !self.response.headers.cc.immutable {
        let reqcc = request
            .headers()
            .get_all("cache-control")
            .iter()
            .collect::<CacheControl>();
        if reqcc.no_cache {
            tracing::trace!(
                "Request to {} does not have a fresh cache entry because \
                 it has a 'no-cache' cache-control directive",
                request.url(),
            );
            return false;
        }
        if let Some(&max_age) = reqcc.max_age_seconds.as_ref() {
            if age > max_age {
                tracing::trace!(
                    "Request to {} does not have a fresh cache entry because \
                     the cached response's age is {} seconds and the max age \
                     allowed by the request is {} seconds",
                    request.url(),
                    age,
                    max_age,
                );
                return false;
            }
        }
        if let Some(&min_fresh) = reqcc.min_fresh_seconds.as_ref() {
            // Remaining freshness is the lifetime minus the current age.
            // Subtracting the wall-clock timestamp instead would always
            // yield zero and reject every request with a non-zero
            // `min-fresh`.
            let time_to_live = freshness_lifetime.saturating_sub(age);
            if time_to_live < min_fresh {
                tracing::trace!(
                    "Request to {} does not have a fresh cache entry because \
                     the request set a 'min-fresh' cache-control directive, \
                     and its time-to-live is {} seconds but it needs to be \
                     at least {} seconds",
                    request.url(),
                    time_to_live,
                    min_fresh,
                );
                return false;
            }
        }
    }
    if age > freshness_lifetime {
        let allows_stale = self.allows_stale(now);
        if !allows_stale {
            tracing::trace!(
                "Request to {} does not have a fresh cache entry because \
                 its age is {} seconds, it is greater than the freshness \
                 lifetime of {} seconds and stale cached responses are not \
                 allowed",
                request.url(),
                age,
                freshness_lifetime,
            );
            return false;
        }
    }
    true
}
/// Reports whether a stale stored response may still be served at `now`.
///
/// Staleness is never permitted when the response has `must-revalidate`.
/// Otherwise it is permitted only when the cached request carried a
/// `max-stale` directive and the amount by which the response has exceeded
/// its freshness lifetime does not exceed that value.
fn allows_stale(&self, now: SystemTime) -> bool {
    if self.response.headers.cc.must_revalidate {
        tracing::trace!(
            "Request to {} has a cached response that does not \
             permit staleness because the response has a 'must-revalidate' \
             cache-control directive set",
            self.request.uri,
        );
        return false;
    }

    if let Some(&max_stale) = self.request.headers.cc.max_stale_seconds.as_ref() {
        let age_secs = self.age(now).as_secs();
        let lifetime_secs = self.freshness_lifetime().as_secs();
        let stale_amount = age_secs.saturating_sub(lifetime_secs);
        let within_limit = stale_amount <= max_stale.into();
        if within_limit {
            tracing::trace!(
                "Request to {} has a cached response that allows staleness \
                 in this case because the stale amount is {} seconds and the \
                 'max-stale' cache-control directive set by the cached request \
                 is {} seconds",
                self.request.uri,
                stale_amount,
                max_stale,
            );
            return true;
        }
    }

    tracing::trace!(
        "Request to {} has a cached response that does not allow staleness",
        self.request.uri,
    );
    false
}
/// Computes the current age of the stored response as of `now`.
///
/// The initial age is the larger of the apparent age (response time minus the
/// `Date` header) and the corrected `Age` header value (the `Age` header plus
/// the request/response delay). The time the response has been resident in
/// the cache is then added on top.
///
/// All arithmetic saturates: the `Age` header is supplied by the server and
/// may be as large as `u64::MAX`, so a plain addition could overflow.
fn age(&self, now: SystemTime) -> Duration {
    let apparent_age =
        u64::from(self.response.unix_timestamp).saturating_sub(self.response.header_date());
    let response_delay = u64::from(self.response.unix_timestamp)
        .saturating_sub(self.request.unix_timestamp.into());
    let corrected_age_value = self.response.header_age().saturating_add(response_delay);
    let corrected_initial_age = apparent_age.max(corrected_age_value);
    let resident_age = unix_timestamp(now).saturating_sub(self.response.unix_timestamp.into());
    // Saturate rather than overflow when the initial age is already huge.
    let current_age = corrected_initial_age.saturating_add(resident_age);
    Duration::from_secs(current_age)
}
/// Determines how long the stored response stays fresh.
///
/// Sources are consulted in this order: `s-maxage` (shared caches only),
/// `max-age`, the `Expires` header relative to the response date, and finally
/// a fixed 600 second heuristic when a last-modified time is present. With
/// none of these, the lifetime is zero.
fn freshness_lifetime(&self) -> Duration {
    let headers = &self.response.headers;

    let shared_max_age = if self.config.shared {
        headers.cc.s_maxage_seconds.as_ref()
    } else {
        None
    };
    if let Some(&s_maxage) = shared_max_age {
        let duration = Duration::from_secs(s_maxage.into());
        tracing::trace!(
            "Freshness lifetime found via shared \
             cache-control max age setting: {duration:?}"
        );
        return duration;
    }

    if let Some(&max_age) = headers.cc.max_age_seconds.as_ref() {
        let duration = Duration::from_secs(max_age.into());
        tracing::trace!(
            "Freshness lifetime found via cache-control max age setting: {duration:?}"
        );
        return duration;
    }

    if let Some(&expires) = headers.expires_unix_timestamp.as_ref() {
        let response_date = self.response.header_date();
        let duration = Duration::from_secs(u64::from(expires).saturating_sub(response_date));
        tracing::trace!("Freshness lifetime found via expires header: {duration:?}");
        return duration;
    }

    if headers.last_modified_unix_timestamp.is_none() {
        tracing::trace!("Could not determine freshness lifetime, assuming none exists");
        return Duration::ZERO;
    }

    let duration = Duration::from_secs(600);
    tracing::trace!(
        "Freshness lifetime heuristically assumed \
         because of presence of last-modified header: {duration:?}"
    );
    duration
}
/// Creates a builder for a replacement policy from `request`, reusing this
/// policy's configuration.
fn new_cache_policy_builder(&self, request: &reqwest::Request) -> CachePolicyBuilder {
    CachePolicyBuilder {
        config: self.config.clone(),
        request: Request::from(request),
        request_headers: request.headers().clone(),
    }
}
}
/// Outcome of checking a stored response against a new request.
#[derive(Debug)]
// `Stale` stores a `CachePolicyBuilder` inline while the other variants carry
// no data, which is the size difference this lint reports.
#[allow(clippy::large_enum_variant)]
pub enum BeforeRequest {
/// The stored response may be used as is.
Fresh,
/// The stored response needs revalidation; the builder is used to create
/// the new policy once the response arrives.
Stale(CachePolicyBuilder),
/// The stored response does not apply to the request.
NoMatch,
}
/// Outcome of comparing a revalidation response with the stored response.
/// Both variants carry the newly built policy.
#[derive(Debug)]
pub enum AfterResponse {
/// The resource was judged unchanged.
NotModified(CachePolicy),
/// The resource was judged changed.
Modified(CachePolicy),
}
/// The parts of an HTTP request that are recorded for caching decisions.
#[derive(Debug, rkyv::Archive, rkyv::Deserialize, rkyv::Serialize)]
#[rkyv(derive(Debug))]
struct Request {
/// The request URL rendered as a string.
uri: String,
/// The request method, reduced to the cases this module distinguishes.
method: Method,
/// Cache-relevant information extracted from the request headers.
headers: RequestHeaders,
/// Seconds since the Unix epoch at which this record was created.
unix_timestamp: u64,
}
impl<'a> From<&'a reqwest::Request> for Request {
    /// Summarizes `from` and stamps the record with the current time.
    fn from(from: &'a reqwest::Request) -> Self {
        let uri = from.url().to_string();
        let method = Method::from(from.method());
        let headers = RequestHeaders::from(from.headers());
        let unix_timestamp = unix_timestamp(SystemTime::now());
        Self {
            uri,
            method,
            headers,
            unix_timestamp,
        }
    }
}
/// Cache-relevant information extracted from request headers.
#[derive(Debug, rkyv::Archive, rkyv::Deserialize, rkyv::Serialize)]
#[rkyv(derive(Debug))]
struct RequestHeaders {
/// Directives collected from all `cache-control` headers.
cc: CacheControl,
/// Whether an `authorization` header was present.
authorization: bool,
}
impl<'a> From<&'a http::HeaderMap> for RequestHeaders {
fn from(from: &'a http::HeaderMap) -> Self {
Self {
cc: from.get_all("cache-control").iter().collect(),
authorization: from.contains_key("authorization"),
}
}
}
/// The request methods this module distinguishes.
#[derive(Debug, rkyv::Archive, rkyv::Deserialize, rkyv::Serialize)]
#[rkyv(derive(Debug))]
#[repr(u8)]
enum Method {
/// HTTP `GET`.
Get,
/// HTTP `HEAD`.
Head,
/// Any other method.
Unrecognized,
}
impl<'a> From<&'a http::Method> for Method {
    /// Maps `GET` and `HEAD` to their own variants and everything else to
    /// [`Method::Unrecognized`].
    fn from(from: &'a http::Method) -> Self {
        if from == http::Method::GET {
            return Self::Get;
        }
        if from == http::Method::HEAD {
            return Self::Head;
        }
        Self::Unrecognized
    }
}
/// The parts of an HTTP response that are recorded for caching decisions.
#[derive(Debug, rkyv::Archive, rkyv::Deserialize, rkyv::Serialize)]
#[rkyv(derive(Debug))]
struct Response {
/// The numeric HTTP status code.
status: u16,
/// Cache-relevant information extracted from the response headers.
headers: ResponseHeaders,
/// Seconds since the Unix epoch at which this record was created.
unix_timestamp: u64,
}
impl ArchivedResponse {
    /// Value of the `Age` header in seconds, or `0` when it was not recorded.
    fn header_age(&self) -> u64 {
        let recorded = self.headers.age_seconds.as_ref();
        recorded.map(u64::from).unwrap_or(0)
    }

    /// Value of the `Date` header as seconds since the Unix epoch, falling
    /// back to the time the response was recorded.
    fn header_date(&self) -> u64 {
        let fallback = self.unix_timestamp;
        self.headers
            .date_unix_timestamp
            .unwrap_or(fallback)
            .into()
    }

    /// Whether the status code is 200 or above.
    fn has_final_status(&self) -> bool {
        let is_final = self.status >= 200;
        is_final
    }
}
impl<'a> From<&'a reqwest::Response> for Response {
    /// Summarizes `from` and stamps the record with the current time.
    fn from(from: &'a reqwest::Response) -> Self {
        let status = from.status().as_u16();
        let headers = ResponseHeaders::from(from.headers());
        let unix_timestamp = unix_timestamp(SystemTime::now());
        Self {
            status,
            headers,
            unix_timestamp,
        }
    }
}
/// Cache-relevant information extracted from response headers.
///
/// A field is `None` when the header is absent or could not be parsed.
#[derive(Debug, rkyv::Archive, rkyv::Deserialize, rkyv::Serialize)]
#[rkyv(derive(Debug))]
struct ResponseHeaders {
/// Directives collected from all `cache-control` headers.
cc: CacheControl,
/// The `age` header, in seconds.
age_seconds: Option<u64>,
/// The `date` header as seconds since the Unix epoch.
date_unix_timestamp: Option<u64>,
/// The `expires` header as seconds since the Unix epoch.
expires_unix_timestamp: Option<u64>,
/// The `last-modified` header as seconds since the Unix epoch.
last_modified_unix_timestamp: Option<u64>,
/// The parsed `etag` header.
etag: Option<ETag>,
}
impl<'a> From<&'a http::HeaderMap> for ResponseHeaders {
fn from(from: &'a http::HeaderMap) -> Self {
Self {
cc: from.get_all("cache-control").iter().collect(),
age_seconds: from
.get("age")
.and_then(|header| parse_seconds(header.as_bytes())),
date_unix_timestamp: from
.get("date")
.and_then(|header| header.to_str().ok())
.and_then(rfc2822_to_unix_timestamp),
expires_unix_timestamp: from
.get("expires")
.and_then(|header| header.to_str().ok())
.and_then(rfc2822_to_unix_timestamp),
last_modified_unix_timestamp: from
.get("last-modified")
.and_then(|header| header.to_str().ok())
.and_then(rfc2822_to_unix_timestamp),
etag: from
.get("etag")
.map(|header| ETag::parse(header.as_bytes())),
}
}
}
/// A parsed `etag` header value.
#[derive(Debug, rkyv::Archive, rkyv::Deserialize, rkyv::Serialize)]
#[rkyv(derive(Debug))]
struct ETag {
/// The header bytes, with a leading `W/` removed when one was present.
value: Vec<u8>,
/// Whether the header value started with `W/`.
weak: bool,
}
impl ETag {
    /// Splits a raw `etag` header value into its weak marker and the
    /// remaining bytes.
    ///
    /// A value starting with `W/` is weak and is stored without that prefix;
    /// any other value is stored unchanged as a strong tag.
    fn parse(header_value: &[u8]) -> Self {
        match header_value.strip_prefix(b"W/") {
            Some(rest) => Self {
                value: rest.to_vec(),
                weak: true,
            },
            None => Self {
                value: header_value.to_vec(),
                weak: false,
            },
        }
    }
}
/// The request header values nominated by a response's `Vary` header.
#[derive(Debug, rkyv::Archive, rkyv::Deserialize, rkyv::Serialize)]
#[rkyv(derive(Debug))]
struct Vary {
/// One entry per header name listed in `Vary`. A single entry named `*`
/// means the policy never matches any request.
fields: Vec<VaryField>,
}
impl Vary {
    /// Builds a `Vary` whose only field is `*`, which no request matches.
    fn always_fails_to_match() -> Self {
        let wildcard = VaryField {
            name: "*".to_string(),
            value: vec![],
        };
        Self {
            fields: vec![wildcard],
        }
    }

    /// Records, for every header name listed in the response's `Vary`
    /// headers, the value that the request sent for it.
    ///
    /// Names are trimmed and lowercased. A header missing from the request is
    /// recorded with an empty value. `Vary` header values that are not valid
    /// text are skipped. Encountering `*` short-circuits to a policy that
    /// never matches.
    fn from_request_response_headers(
        request: &http::HeaderMap,
        response: &http::HeaderMap,
    ) -> Self {
        let names = response
            .get_all("vary")
            .iter()
            .filter_map(|header| header.to_str().ok())
            .flat_map(|csv| csv.split(','))
            .map(|raw| raw.trim().to_ascii_lowercase());

        let mut fields = vec![];
        for name in names {
            if name == "*" {
                return Self::always_fails_to_match();
            }
            let value = request
                .get(&name)
                .map(|header| header.as_bytes().to_vec())
                .unwrap_or_default();
            fields.push(VaryField { name, value });
        }
        Self { fields }
    }
}
impl ArchivedVary {
    /// Reports whether `request_headers` carries the same value as the
    /// recorded one for every nominated header.
    ///
    /// A header missing from `request_headers` is compared as an empty value.
    /// A field named `*` never matches.
    fn matches(&self, request_headers: &http::HeaderMap) -> bool {
        self.fields.iter().all(|field| {
            if field.name == "*" {
                return false;
            }
            let presented = request_headers
                .get(field.name.as_str())
                .map_or(&b""[..], |header| header.as_bytes());
            field.value.as_slice() == presented
        })
    }
}
/// One header nominated by `Vary`, together with the value the original
/// request sent for it.
#[derive(Debug, rkyv::Archive, rkyv::Deserialize, rkyv::Serialize)]
#[rkyv(derive(Debug))]
struct VaryField {
/// Lowercased header name, or `*`.
name: String,
/// Bytes of the request's value for this header; empty when absent.
value: Vec<u8>,
}
/// Converts `time` to whole seconds since the Unix epoch.
///
/// A `SystemTime` can represent instants before the epoch (for example when
/// the system clock is badly misconfigured). Such values are clamped to `0`
/// instead of panicking, which matches the saturating arithmetic used by the
/// callers.
fn unix_timestamp(time: SystemTime) -> u64 {
    time.duration_since(SystemTime::UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs())
}
/// Parses `s` as an RFC 2822 date and returns it as seconds since the Unix
/// epoch, or `None` when parsing fails or the instant is before the epoch.
fn rfc2822_to_unix_timestamp(s: &str) -> Option<u64> {
    let timestamp = rfc2822_to_datetime(s)?;
    u64::try_from(timestamp.as_second()).ok()
}
/// Parses `s` as an RFC 2822 date, returning `None` on any parse error.
fn rfc2822_to_datetime(s: &str) -> Option<jiff::Timestamp> {
    let parser = jiff::fmt::rfc2822::DateTimeParser::new();
    parser.parse_timestamp(s).ok()
}
/// Formats `seconds` since the Unix epoch as a date header value, or `None`
/// when the time cannot be formatted or is not a valid header value.
fn unix_timestamp_to_header(seconds: u64) -> Option<HeaderValue> {
    let formatted = unix_timestamp_to_rfc2822(seconds)?;
    HeaderValue::from_str(&formatted).ok()
}
/// Formats `seconds` since the Unix epoch as an HTTP date string using
/// jiff's RFC 9110 printer.
///
/// Returns `None` when `seconds` cannot be represented as a timestamp or the
/// printer reports an error.
fn unix_timestamp_to_rfc2822(seconds: u64) -> Option<String> {
    use jiff::fmt::rfc2822::DateTimePrinter;
    let timestamp = unix_timestamp_to_datetime(seconds)?;
    // The printer takes the timestamp by reference.
    DateTimePrinter::new()
        .timestamp_to_rfc9110_string(&timestamp)
        .ok()
}
/// Converts `seconds` since the Unix epoch into a `jiff::Timestamp`.
///
/// Returns `None` when the value does not fit in an `i64` or is rejected by
/// `jiff::Timestamp::from_second`.
fn unix_timestamp_to_datetime(seconds: u64) -> Option<jiff::Timestamp> {
    let signed = i64::try_from(seconds).ok()?;
    jiff::Timestamp::from_second(signed).ok()
}
/// Parses `value` as a non-negative decimal number of seconds.
///
/// Only ASCII digits are accepted, so signs, whitespace and fractions are
/// rejected. An empty input or a number that does not fit in a `u64` yields
/// `None`.
fn parse_seconds(value: &[u8]) -> Option<u64> {
    // `str::parse` alone would accept a leading `+`, so screen the bytes first.
    let has_non_digit = value.iter().any(|byte| !byte.is_ascii_digit());
    if has_non_digit {
        return None;
    }
    let text = std::str::from_utf8(value).ok()?;
    text.parse::<u64>().ok()
}