use std::time::{Duration, SystemTime};
use trillium_caching_headers::{CacheControlDirective, CacheControlHeader, CachingHeadersExt};
use trillium_http::{Headers, KnownHeaderName, Method, Status};
/// Chooses which cache-control header governs a response.
///
/// A shared cache prefers `CDN-Cache-Control`, but only when all of the following hold:
/// - its raw value passes the structured-field sanity check;
/// - it parses successfully;
/// - it yields at least one directive.
///
/// Otherwise the regular `Cache-Control` header is used.
///
/// Returns the selected header (if any) and a flag that is `true` exactly when the
/// CDN-targeted header was selected.
pub(crate) fn effective_response_cache_control(
    response_headers: &Headers,
    options: &CacheOptions,
) -> (Option<CacheControlHeader>, bool) {
    // Every step is lazy, so private caches never even look at CDN-Cache-Control.
    let targeted = options
        .shared
        .then(|| response_headers.get_str(KnownHeaderName::CdnCacheControl))
        .flatten()
        .filter(|raw| looks_like_valid_sf_dictionary(raw))
        .and_then(|_| response_headers.cdn_cache_control())
        .filter(|cdn_cc| !cdn_cc.is_empty());

    match targeted {
        Some(cdn_cc) => (Some(cdn_cc), true),
        None => (response_headers.cache_control(), false),
    }
}
/// Cheaply checks whether `s` looks like an RFC 8941 structured-field Dictionary.
///
/// This is a deliberately lenient pre-check: uppercase keys are accepted and member values
/// are not validated. It gates whether a `CDN-Cache-Control` header may take precedence over
/// `Cache-Control`.
///
/// The rules applied are:
/// - Every comma-separated member must be non-empty and must start with a valid key.
/// - Commas inside double-quoted strings (with backslash escapes) do not separate members,
///   so a value like `private="a,b"` is kept intact.
/// - A key may be followed by `=` (a value) or `;` (parameters).
/// - Empty or whitespace-only input, empty members (e.g. a trailing comma), and unterminated
///   quoted strings make the whole input invalid.
fn looks_like_valid_sf_dictionary(s: &str) -> bool {
    let s = s.trim();
    if s.is_empty() {
        return false;
    }
    let Some(members) = split_dictionary_members(s) else {
        return false;
    };
    members.into_iter().all(|member| {
        let member = member.trim();
        if member.is_empty() {
            return false;
        }
        // The key runs up to the first `=` (start of a value) or `;` (start of parameters).
        let key_end = member
            .find(|c: char| c == '=' || c == ';')
            .unwrap_or(member.len());
        is_valid_sf_key(member[..key_end].trim_end())
    })
}

/// Splits `s` on commas that are not inside a double-quoted string.
///
/// Inside quotes, a backslash escapes the following character, so `\"` does not close the
/// string. Returns `None` if a quoted string is left unterminated.
fn split_dictionary_members(s: &str) -> Option<Vec<&str>> {
    let mut members = Vec::new();
    let mut start = 0;
    let mut in_quotes = false;
    let mut escaped = false;
    for (i, c) in s.char_indices() {
        if in_quotes {
            if escaped {
                escaped = false;
            } else if c == '\\' {
                escaped = true;
            } else if c == '"' {
                in_quotes = false;
            }
        } else if c == '"' {
            in_quotes = true;
        } else if c == ',' {
            members.push(&s[start..i]);
            // `,` is one byte, so `i + 1` is always a char boundary.
            start = i + 1;
        }
    }
    if in_quotes {
        return None;
    }
    members.push(&s[start..]);
    Some(members)
}

/// Returns `true` if `s` is a structured-field key.
///
/// A key is a letter or `*`, followed by any number of letters, digits, `_`, `-`, `.` or `*`.
///
/// RFC 8941 only allows lowercase letters. Uppercase is accepted here as well, so mixed-case
/// directives such as `MaX-aGe` still pass.
fn is_valid_sf_key(s: &str) -> bool {
    let mut chars = s.chars();
    let Some(first) = chars.next() else {
        return false;
    };
    if !first.is_ascii_alphabetic() && first != '*' {
        return false;
    }
    chars.all(|c| c.is_ascii_alphanumeric() || matches!(c, '_' | '-' | '.' | '*'))
}
/// Configuration knobs that a [`CachePolicy`] carries and evaluates responses against.
#[derive(Debug, Copy, Clone, fieldwork::Fieldwork)]
#[fieldwork(get, set, get_mut, with, rename_predicates)]
pub struct CacheOptions {
    /// Whether this is a shared cache rather than a private one.
    ///
    /// When `true`, a valid, non-empty `CDN-Cache-Control` response header takes precedence
    /// over `Cache-Control`.
    pub(crate) shared: bool,
    /// Heuristic freshness factor.
    ///
    /// NOTE(review): presumably the fraction of the time since `Last-Modified` that is used as
    /// the heuristic lifetime when no explicit freshness is given. Confirm at the use site.
    pub(crate) cache_heuristic: f32,
    /// Minimum time-to-live associated with immutable responses.
    ///
    /// NOTE(review): exact semantics are defined where this is read, outside this view.
    #[field(copy)]
    pub(crate) immutable_min_time_to_live: Duration,
}
impl Default for CacheOptions {
    /// Builds the default options:
    /// - a private (non-shared) cache;
    /// - a heuristic factor of `0.1`;
    /// - a one-day minimum time-to-live for immutable responses.
    fn default() -> Self {
        const ONE_DAY: Duration = Duration::from_secs(60 * 60 * 24);
        Self {
            immutable_min_time_to_live: ONE_DAY,
            cache_heuristic: 0.1,
            shared: false,
        }
    }
}
/// Caching metadata captured from a single request/response exchange.
#[derive(Debug, Clone)]
pub struct CachePolicy {
    /// Method of the request that produced the response.
    pub(crate) request_method: Method,
    /// One entry per name listed in the response's `Vary` header(s).
    ///
    /// Each entry pairs the lowercased header name with the value the request carried for
    /// that header, or `None` if the request did not carry it.
    pub(crate) vary_snapshot: Vec<(String, Option<String>)>,
    /// Status of the response.
    pub(crate) response_status: Status,
    /// The response headers as passed to `CachePolicy::new`.
    pub(crate) response_headers: Headers,
    /// The cache-control directives in effect for the response.
    ///
    /// This is one of:
    /// - `CDN-Cache-Control` or `Cache-Control`, as selected by
    ///   `effective_response_cache_control`;
    /// - a synthesized `no-cache` when neither applies and `Pragma` contains `no-cache`;
    /// - `None` otherwise.
    pub(crate) response_cache_control: Option<CacheControlHeader>,
    /// `true` when `response_cache_control` came from the CDN-targeted header.
    pub(crate) targeted_cc_in_effect: bool,
    /// Response timestamp supplied by the caller when the policy was built.
    pub(crate) response_time: SystemTime,
    /// Options this policy was built with.
    pub(crate) options: CacheOptions,
}
impl CachePolicy {
    /// Returns `true` when both policies recorded the same `Vary` snapshot.
    ///
    /// "Same" means the same header names, in the same order, with the same request values.
    pub fn same_variant_as(&self, other: &Self) -> bool {
        let (mine, theirs) = (&self.vary_snapshot, &other.vary_snapshot);
        mine == theirs
    }

    /// Captures the caching-relevant parts of a request/response exchange.
    ///
    /// The cache-control directives come from `effective_response_cache_control`. When that
    /// yields nothing and the response's `Pragma` header contains `no-cache`, a `no-cache`
    /// directive is substituted.
    ///
    /// The request's values for every header named by the response's `Vary` header are
    /// recorded for later variant comparison.
    pub(crate) fn new(
        request_method: Method,
        request_headers: &Headers,
        response_status: Status,
        response_headers: Headers,
        response_time: SystemTime,
        options: CacheOptions,
    ) -> Self {
        let (explicit_cache_control, targeted_cc_in_effect) =
            effective_response_cache_control(&response_headers, &options);

        // Legacy fallback: with no cache-control directives at all, honor `Pragma: no-cache`.
        let response_cache_control = explicit_cache_control.or_else(|| {
            response_headers
                .get_str(KnownHeaderName::Pragma)
                .filter(|pragma| pragma.contains("no-cache"))
                .map(|_| CacheControlHeader::from(CacheControlDirective::NoCache))
        });

        let vary_snapshot = build_vary_snapshot(&response_headers, request_headers);

        Self {
            request_method,
            vary_snapshot,
            response_status,
            response_headers,
            response_cache_control,
            targeted_cc_in_effect,
            response_time,
            options,
        }
    }
}
/// Records, for each header name listed in the response's `Vary` header(s), the value the
/// request carried for that header.
///
/// How names are handled:
/// - Multiple `Vary` lines are processed in order.
/// - Names are trimmed and lowercased.
/// - Empty entries are skipped.
/// - `Vary` values for which `as_str` yields `None` are ignored.
/// - A header absent from the request is recorded as `None`.
fn build_vary_snapshot(
    response_headers: &Headers,
    request_headers: &Headers,
) -> Vec<(String, Option<String>)> {
    let mut snapshot = Vec::new();
    if let Some(vary_lines) = response_headers.get_values(KnownHeaderName::Vary) {
        for line in vary_lines.iter().filter_map(|value| value.as_str()) {
            for name in line.split(',').map(str::trim).filter(|name| !name.is_empty()) {
                let name = name.to_ascii_lowercase();
                let request_value = request_headers.get_str(name.as_str()).map(String::from);
                snapshot.push((name, request_value));
            }
        }
    }
    snapshot
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use trillium_client::ConnExt;
    use trillium_http::KnownHeaderName::*;

    /// Builds an expected vary snapshot from borrowed `(name, request value)` pairs.
    fn snapshot_of(pairs: &[(&str, Option<&str>)]) -> Vec<(String, Option<String>)> {
        pairs
            .iter()
            .map(|&(name, value)| (name.to_owned(), value.map(str::to_owned)))
            .collect()
    }

    /// Names spread across several `Vary` header lines are all captured, in order.
    #[test]
    fn vary_snapshot_handles_multiple_header_lines() {
        let mut conn = exchange(
            Method::Get,
            &[(AcceptEncoding, "gzip"), (AcceptLanguage, "en-US")],
            Status::Ok,
            &[(Vary, "Accept-Encoding")],
        );
        conn.response_headers_mut().append(Vary, "Accept-Language");

        let policy = policy_from(&conn, SystemTime::now(), private_cache());

        assert_eq!(
            policy.vary_snapshot,
            snapshot_of(&[
                ("accept-encoding", Some("gzip")),
                ("accept-language", Some("en-US")),
            ])
        );
    }

    /// A `*` on a later `Vary` line is still recorded, even after an empty first line.
    #[test]
    fn vary_snapshot_captures_star_from_second_line() {
        let mut conn = exchange(Method::Get, &[], Status::Ok, &[(Vary, "")]);
        conn.response_headers_mut().append(Vary, "*");

        let policy = policy_from(&conn, SystemTime::now(), private_cache());

        let names: Vec<&str> = policy
            .vary_snapshot
            .iter()
            .map(|(name, _)| name.as_str())
            .collect();
        assert!(names.contains(&"*"));
    }

    /// A comma-separated `Vary` line yields one lowercased entry per named header.
    #[test]
    fn vary_snapshot_captures_named_request_headers() {
        let conn = exchange(
            Method::Get,
            &[(AcceptEncoding, "gzip"), (AcceptLanguage, "en-US")],
            Status::Ok,
            &[(Vary, "Accept-Encoding, Accept-Language")],
        );

        let policy = policy_from(&conn, SystemTime::now(), private_cache());

        assert_eq!(
            policy.vary_snapshot,
            snapshot_of(&[
                ("accept-encoding", Some("gzip")),
                ("accept-language", Some("en-US")),
            ])
        );
    }

    /// Accepted and rejected inputs for the lenient structured-field dictionary check.
    #[test]
    fn sf_dictionary_validator() {
        let accepted = [
            "max-age=600",
            "no-store",
            "max-age=600, no-store",
            r#"max-age="600""#,
            "MaX-aGe=3600",
        ];
        for input in accepted {
            assert!(
                looks_like_valid_sf_dictionary(input),
                "expected {input:?} to be accepted"
            );
        }

        let rejected = ["max-age=10000, &&&&&", "&&&&&", "", " ", "max-age=600,"];
        for input in rejected {
            assert!(
                !looks_like_valid_sf_dictionary(input),
                "expected {input:?} to be rejected"
            );
        }
    }

    /// A header named by `Vary` but absent from the request is recorded as `None`.
    #[test]
    fn vary_snapshot_records_absent_request_header_as_none() {
        let conn = exchange(Method::Get, &[], Status::Ok, &[(Vary, "Accept-Encoding")]);

        let policy = policy_from(&conn, SystemTime::now(), private_cache());

        assert_eq!(policy.vary_snapshot, snapshot_of(&[("accept-encoding", None)]));
    }
}