use http::{HeaderMap, Method, StatusCode};
use std::collections::HashSet;
use std::fmt;
use std::sync::Arc;
use std::time::Duration;
use crate::logging::MLLoggingConfig;
use crate::streaming::StreamingPolicy;
use crate::tags::TagPolicy;
/// Shared, thread-safe predicate over an HTTP method.
///
/// Stored in `CachePolicy::method_predicate` and consulted by
/// `CachePolicy::should_cache_method`.
type MethodPredicateFn = Arc<dyn Fn(&Method) -> bool + Send + Sync>;
/// Shared, thread-safe callback that produces tag strings from a request
/// method and URI.
///
/// Stored in `CachePolicy::tag_extractor` and invoked by
/// `CachePolicy::extract_tags`.
type TagExtractorFn = Arc<dyn Fn(&Method, &http::Uri) -> Vec<String> + Send + Sync>;
/// Configuration describing which responses may be cached and for how long.
///
/// Values are set through `CachePolicy::new`, `Default`, or the chainable
/// `with_*` builders, and read back through the accessor methods. Several
/// fields are only stored and exposed here; how they are enforced is decided
/// by code outside this file.
#[derive(Clone)]
pub struct CachePolicy {
// TTL returned by `ttl_for` for statuses contained in `cache_statuses`.
ttl: Duration,
// TTL returned by `ttl_for` for client-error statuses that are NOT in
// `cache_statuses`; a zero duration disables that fallback.
negative_ttl: Duration,
// Only stored/exposed here. NOTE(review): presumably the window in which a
// stale entry may still be served while it is refreshed — confirm with caller.
stale_while_revalidate: Duration,
// Only stored/exposed here. NOTE(review): presumably how long before expiry
// a refresh is triggered — confirm with caller.
refresh_before: Duration,
// Optional upper bound on body size; `None` means no bound is configured.
// Enforcement happens outside this file.
max_body_size: Option<usize>,
// Optional lower bound on body size; `None` means no bound is configured.
// Enforcement happens outside this file.
min_body_size: Option<usize>,
// Status codes (as `u16`) for which `ttl_for` returns the primary `ttl`.
cache_statuses: HashSet<u16>,
// Only stored/exposed here. NOTE(review): presumably whether Cache-Control
// headers are honoured — confirm with caller.
respect_cache_control: bool,
// Optional override for `should_cache_method`; when `None`, only GET and
// HEAD are considered cacheable.
method_predicate: Option<MethodPredicateFn>,
// Lower-cased header names. When `Some`, `headers_to_cache` keeps only
// these headers; when `None`, every header is kept.
header_allowlist: Option<HashSet<String>>,
// Only stored/exposed here; consumers decide what "streaming bodies" means.
allow_streaming_bodies: bool,
// Compression settings, returned by value from `compression()`.
compression: CompressionConfig,
// Logging configuration defined in `crate::logging`.
ml_logging: MLLoggingConfig,
// Tag policy; `extract_tags` checks its `enabled` flag and passes extracted
// tags through its `validate_tags`.
tag_policy: TagPolicy,
// Optional callback producing tags for a request; without it
// `extract_tags` returns an empty list.
tag_extractor: Option<TagExtractorFn>,
// Streaming configuration defined in `crate::streaming`.
streaming_policy: StreamingPolicy,
}
/// Compression algorithm selected by a compression configuration.
///
/// The enum is a plain, field-less value type, so it also derives
/// `PartialEq`/`Eq`: callers can compare strategies with `==` instead of
/// having to pattern-match.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CompressionStrategy {
    /// Do not compress.
    None,
    /// Compress with gzip.
    Gzip,
}
/// Compression settings carried by a cache policy.
///
/// The `Default` implementation in this file uses
/// `CompressionStrategy::None` with a `min_size` of 1024.
#[derive(Clone, Copy, Debug)]
pub struct CompressionConfig {
/// Which compression algorithm to use.
pub strategy: CompressionStrategy,
/// Size threshold associated with compression.
/// NOTE(review): presumably the minimum body size in bytes before
/// compression is applied — confirm against the code that consumes it.
pub min_size: usize,
}
impl Default for CompressionConfig {
fn default() -> Self {
Self {
strategy: CompressionStrategy::None,
min_size: 1024,
}
}
}
impl fmt::Debug for CachePolicy {
    /// Formats the plain-data settings of the policy.
    ///
    /// * `cache_statuses` and `header_allowlist` are backed by `HashSet`s,
    ///   whose iteration order is unspecified; both are rendered in sorted
    ///   order so the output is stable from run to run.
    /// * `method_predicate` and `tag_extractor` are closures and cannot be
    ///   formatted; `ml_logging`, `tag_policy` and `streaming_policy` are not
    ///   printed either. The output therefore ends with `..`
    ///   (`finish_non_exhaustive`) to make clear that fields were left out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // A BTreeSet keeps the `{a, b, c}` set notation while fixing the order.
        let cache_statuses: std::collections::BTreeSet<u16> =
            self.cache_statuses.iter().copied().collect();

        let header_allowlist = self.header_allowlist.as_ref().map(|set| {
            let mut names: Vec<&String> = set.iter().collect();
            names.sort_unstable();
            names
        });

        f.debug_struct("CachePolicy")
            .field("ttl", &self.ttl)
            .field("negative_ttl", &self.negative_ttl)
            .field("stale_while_revalidate", &self.stale_while_revalidate)
            .field("refresh_before", &self.refresh_before)
            .field("max_body_size", &self.max_body_size)
            .field("min_body_size", &self.min_body_size)
            .field("cache_statuses", &cache_statuses)
            .field("respect_cache_control", &self.respect_cache_control)
            .field("header_allowlist", &header_allowlist)
            .field("allow_streaming_bodies", &self.allow_streaming_bodies)
            .field("compression", &self.compression)
            .finish_non_exhaustive()
    }
}
impl CachePolicy {
pub fn new(
ttl: Duration,
negative_ttl: Duration,
statuses: impl IntoIterator<Item = u16>,
) -> Self {
Self {
ttl,
negative_ttl,
stale_while_revalidate: Duration::from_secs(0),
refresh_before: Duration::from_secs(0),
max_body_size: None,
min_body_size: None,
cache_statuses: statuses.into_iter().collect(),
respect_cache_control: true,
method_predicate: None,
header_allowlist: None,
allow_streaming_bodies: false,
compression: CompressionConfig::default(),
ml_logging: MLLoggingConfig::default(),
tag_policy: TagPolicy::default(),
tag_extractor: None,
streaming_policy: StreamingPolicy::default(),
}
}
pub fn ttl_for(&self, status: StatusCode) -> Option<Duration> {
if self.cache_statuses.contains(&status.as_u16()) {
Some(self.ttl)
} else if status.is_client_error() && !self.negative_ttl.is_zero() {
Some(self.negative_ttl)
} else {
None
}
}
pub fn should_cache_method(&self, method: &Method) -> bool {
if let Some(predicate) = &self.method_predicate {
predicate(method)
} else {
matches!(method, &Method::GET | &Method::HEAD)
}
}
pub fn respect_cache_control(&self) -> bool {
self.respect_cache_control
}
pub fn max_body_size(&self) -> Option<usize> {
self.max_body_size
}
pub fn min_body_size(&self) -> Option<usize> {
self.min_body_size
}
pub fn allow_streaming_bodies(&self) -> bool {
self.allow_streaming_bodies
}
pub fn compression(&self) -> CompressionConfig {
self.compression
}
pub fn refresh_before(&self) -> Duration {
self.refresh_before
}
pub fn ttl(&self) -> Duration {
self.ttl
}
pub fn negative_ttl(&self) -> Duration {
self.negative_ttl
}
pub fn stale_while_revalidate(&self) -> Duration {
self.stale_while_revalidate
}
pub fn ml_logging(&self) -> &MLLoggingConfig {
&self.ml_logging
}
pub fn tag_policy(&self) -> &TagPolicy {
&self.tag_policy
}
pub fn streaming_policy(&self) -> &StreamingPolicy {
&self.streaming_policy
}
pub fn extract_tags(&self, method: &Method, uri: &http::Uri) -> Vec<String> {
if !self.tag_policy.enabled {
return Vec::new();
}
if let Some(ref extractor) = self.tag_extractor {
let tags = extractor(method, uri);
self.tag_policy.validate_tags(tags)
} else {
Vec::new()
}
}
pub fn headers_to_cache(&self, headers: &HeaderMap) -> Vec<(String, Vec<u8>)> {
match &self.header_allowlist {
Some(allowlist) => headers
.iter()
.filter(|(name, _)| allowlist.contains(&name.as_str().to_ascii_lowercase()))
.map(|(name, value)| (name.as_str().to_owned(), value.as_bytes().to_vec()))
.collect(),
None => headers
.iter()
.map(|(name, value)| (name.as_str().to_owned(), value.as_bytes().to_vec()))
.collect(),
}
}
pub fn with_ttl(mut self, ttl: Duration) -> Self {
self.ttl = ttl;
self
}
pub fn with_negative_ttl(mut self, ttl: Duration) -> Self {
self.negative_ttl = ttl;
self
}
pub fn with_statuses(mut self, statuses: impl IntoIterator<Item = u16>) -> Self {
self.cache_statuses = statuses.into_iter().collect();
self
}
pub fn with_stale_while_revalidate(mut self, duration: Duration) -> Self {
self.stale_while_revalidate = duration;
self
}
pub fn with_refresh_before(mut self, duration: Duration) -> Self {
self.refresh_before = duration;
self
}
pub fn with_max_body_size(mut self, size: Option<usize>) -> Self {
self.max_body_size = size;
self
}
pub fn with_min_body_size(mut self, size: Option<usize>) -> Self {
self.min_body_size = size;
self
}
pub fn with_allow_streaming_bodies(mut self, allow: bool) -> Self {
self.allow_streaming_bodies = allow;
self
}
pub fn with_compression(mut self, config: CompressionConfig) -> Self {
self.compression = config;
self
}
pub fn with_respect_cache_control(mut self, enabled: bool) -> Self {
self.respect_cache_control = enabled;
self
}
pub fn with_method_predicate<F>(mut self, predicate: F) -> Self
where
F: Fn(&Method) -> bool + Send + Sync + 'static,
{
self.method_predicate = Some(Arc::new(predicate));
self
}
pub fn with_header_allowlist<I, S>(mut self, headers: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.header_allowlist = Some(
headers
.into_iter()
.map(|h| h.into().to_ascii_lowercase())
.collect(),
);
self
}
pub fn with_ml_logging(mut self, config: MLLoggingConfig) -> Self {
self.ml_logging = config;
self
}
pub fn with_tag_policy(mut self, policy: TagPolicy) -> Self {
self.tag_policy = policy;
self
}
pub fn with_tag_extractor<F>(mut self, extractor: F) -> Self
where
F: Fn(&Method, &http::Uri) -> Vec<String> + Send + Sync + 'static,
{
self.tag_extractor = Some(Arc::new(extractor));
self
}
pub fn with_streaming_policy(mut self, policy: StreamingPolicy) -> Self {
self.streaming_policy = policy;
self
}
}
impl Default for CachePolicy {
fn default() -> Self {
Self {
ttl: Duration::from_secs(60),
negative_ttl: Duration::from_secs(5),
stale_while_revalidate: Duration::from_secs(0),
refresh_before: Duration::from_secs(0),
max_body_size: None,
min_body_size: None,
cache_statuses: HashSet::from([200, 203, 300, 301, 404]),
respect_cache_control: true,
method_predicate: None,
header_allowlist: None,
allow_streaming_bodies: false,
compression: CompressionConfig::default(),
ml_logging: MLLoggingConfig::default(),
tag_policy: TagPolicy::default(),
tag_extractor: None,
streaming_policy: StreamingPolicy::default(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use http::{HeaderMap, HeaderName, HeaderValue, Method, StatusCode};
    use std::collections::HashSet;

    /// A status in the cacheable set is given the primary TTL.
    #[test]
    fn ttl_for_prefers_primary_ttl_for_cacheable_status() {
        let primary = Duration::from_secs(123);
        let policy = CachePolicy::default().with_ttl(primary);

        assert_eq!(policy.ttl_for(StatusCode::OK), Some(primary));
    }

    /// Unlisted 4xx statuses get the negative TTL; 5xx statuses get nothing.
    #[test]
    fn ttl_for_uses_negative_ttl_for_client_error() {
        let negative = Duration::from_secs(9);
        let policy = CachePolicy::default().with_negative_ttl(negative);

        let client_error = policy.ttl_for(StatusCode::BAD_REQUEST);
        let server_error = policy.ttl_for(StatusCode::INTERNAL_SERVER_ERROR);

        assert_eq!(client_error, Some(negative));
        assert_eq!(server_error, None);
    }

    /// Only allowlisted headers survive `headers_to_cache`.
    #[test]
    fn headers_to_cache_respects_allowlist() {
        let mut request_headers = HeaderMap::new();
        request_headers.insert(
            HeaderName::from_static("content-type"),
            HeaderValue::from_static("text/plain"),
        );
        request_headers.insert(
            HeaderName::from_static("x-cacheable"),
            HeaderValue::from_static("yes"),
        );

        let policy = CachePolicy::default().with_header_allowlist(["content-type"]);

        let kept = policy.headers_to_cache(&request_headers);
        assert_eq!(
            kept,
            vec![("content-type".to_owned(), b"text/plain".to_vec())]
        );
    }

    /// A custom predicate replaces the GET/HEAD default entirely.
    #[test]
    fn method_predicate_overrides_default_behavior() {
        let post_only = CachePolicy::default().with_method_predicate(|m| *m == Method::POST);

        assert!(!post_only.should_cache_method(&Method::GET));
        assert!(post_only.should_cache_method(&Method::POST));
    }

    /// `with_statuses` replaces the status set rather than extending it.
    #[test]
    fn with_statuses_updates_allowlist() {
        let policy = CachePolicy::default().with_statuses([201, 202]);
        let expected_statuses: HashSet<u16> = HashSet::from([201_u16, 202_u16]);

        assert_eq!(policy.ttl_for(StatusCode::CREATED), Some(policy.ttl()));
        assert_eq!(policy.ttl_for(StatusCode::OK), None);
        assert_eq!(policy.cache_statuses, expected_statuses);
    }
}