use crate::{
CacheKey, CacheOptions, CachePolicy, CacheStorage, StoredEntry,
tee::TeeingReader,
validation::{AfterResponse, BeforeRequest},
};
use std::{sync::Arc, time::SystemTime};
use trillium::{Body, Conn, Handler, KnownHeaderName, Method};
use url::Url;
/// Default upper bound handed to new [`Cache`] values as `max_cacheable_size`: 16 MiB.
const DEFAULT_MAX_CACHEABLE_SIZE: u64 = 16 << 20;
/// Caching handler parameterised over the [`CacheStorage`] backend that holds stored responses.
#[derive(Debug)]
pub struct Cache<S>
where
    S: CacheStorage,
{
    /// Storage backend; reference-counted so clones of the handler share one store.
    storage: Arc<S>,
    /// Options consulted when deciding storability and when building a `CachePolicy`.
    options: CacheOptions,
    /// Largest response body length this cache will attempt to store.
    max_cacheable_size: u64,
}
/// Manual `Clone`: the storage `Arc` is cloned (not the storage itself), so `S` need not be
/// `Clone`; the remaining fields are plain copies.
impl<S> Clone for Cache<S>
where
    S: CacheStorage,
{
    fn clone(&self) -> Self {
        let Self {
            storage,
            options,
            max_cacheable_size,
        } = self;
        Self {
            storage: Arc::clone(storage),
            options: *options,
            max_cacheable_size: *max_cacheable_size,
        }
    }
}
impl<S> Cache<S>
where
    S: CacheStorage,
{
    /// Builds a cache over `storage` using `CacheOptions::default()` and
    /// `DEFAULT_MAX_CACHEABLE_SIZE` as the body size cap.
    pub fn new(storage: S) -> Self {
        let storage = Arc::new(storage);
        Self {
            storage,
            options: CacheOptions::default(),
            max_cacheable_size: DEFAULT_MAX_CACHEABLE_SIZE,
        }
    }

    /// Replaces the whole options value and returns the updated cache.
    pub fn with_options(self, options: CacheOptions) -> Self {
        Self { options, ..self }
    }

    /// Sets the `shared` flag on the current options and returns the updated cache.
    pub fn shared(self) -> Self {
        let mut options = self.options;
        options.shared = true;
        Self { options, ..self }
    }

    /// Overrides the largest body length the cache will attempt to store.
    pub fn with_max_cacheable_size(self, max: u64) -> Self {
        Self {
            max_cacheable_size: max,
            ..self
        }
    }

    /// Borrows the underlying storage backend.
    pub fn storage(&self) -> &S {
        self.storage.as_ref()
    }
}
/// Per-request marker that `Handler::run` stores in conn state and `Handler::before_send`
/// takes back out to decide what to do with the downstream response.
enum CacheCtx<E>
where
    E: StoredEntry,
{
    /// The response was already produced from storage; nothing is left to do.
    Hit,
    /// A stale stored entry matched and the request was forwarded for revalidation.
    Revalidation { stored: E, key: CacheKey },
    /// No usable stored entry; the downstream response may be stored under `key`.
    Miss { key: CacheKey },
    /// The request used an unsafe method; entries for `url` may need invalidating.
    Unsafe { url: Url },
}
/// Hand-written `Debug` so that `E` does not have to implement `Debug`: the stored entry of a
/// `Revalidation` is omitted and the struct is rendered as non-exhaustive.
impl<E> std::fmt::Debug for CacheCtx<E>
where
    E: StoredEntry,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Hit => write!(f, "Hit"),
            Self::Revalidation { key, stored: _ } => {
                let mut out = f.debug_struct("Revalidation");
                out.field("key", key);
                out.finish_non_exhaustive()
            }
            Self::Miss { key } => {
                let mut out = f.debug_struct("Miss");
                out.field("key", key);
                out.finish()
            }
            Self::Unsafe { url } => {
                let mut out = f.debug_struct("Unsafe");
                out.field("url", url);
                out.finish()
            }
        }
    }
}
/// Reconstructs the absolute request URL from the conn.
///
/// The scheme is `https` when the conn reports itself secure and `http` otherwise. Returns
/// `None` when the conn has no host or when the assembled string does not parse as a URL.
fn url_from_conn(conn: &Conn) -> Option<Url> {
    conn.host().and_then(|host| {
        let path_and_query = conn.path_and_query();
        let scheme = if conn.is_secure() { "https" } else { "http" };
        let absolute = format!("{scheme}://{host}{path_and_query}");
        Url::parse(&absolute).ok()
    })
}
/// Trillium [`Handler`] integration for [`Cache`].
///
/// `run` either answers the request from storage (halting the conn) or records a [`CacheCtx`]
/// in conn state; `before_send` takes that state back out and stores, revalidates, recovers
/// from, or invalidates based on the response that is about to be sent.
impl<S: CacheStorage> Handler for Cache<S> {
    /// Looks the request up in storage.
    ///
    /// * No reconstructible URL (no host): the conn is returned untouched.
    /// * Unsafe method: storage is not read; the conn is tagged `CacheCtx::Unsafe`.
    /// * A fresh candidate: its stored status, headers and body are served and the conn halted.
    /// * A fresh candidate matching the request's conditionals: status and headers are served
    ///   with an empty body and the conn halted.
    /// * A stale candidate that matches: the request headers are replaced with the ones the
    ///   policy produced and the conn is tagged `CacheCtx::Revalidation`.
    /// * Otherwise the conn is tagged `CacheCtx::Miss`.
    async fn run(&self, mut conn: Conn) -> Conn {
        let method = conn.method();
        let Some(url) = url_from_conn(&conn) else {
            log::trace!("cache: no host on request, passing through without caching");
            return conn;
        };
        let key = CacheKey::new(method, url.clone());
        log::trace!("cache: run {method} {url}");

        if !method.is_safe() {
            log::trace!("cache: unsafe method {method}, bypassing cache read");
            return conn.with_state(CacheCtx::<S::StoredEntry>::Unsafe { url });
        }

        let now = SystemTime::now();
        let entries = self.storage.get(&key).await;
        log::trace!("cache: {} stored candidate(s) for {key}", entries.len());

        for entry in entries {
            match entry.policy().before_request(conn.request_headers(), now) {
                BeforeRequest::Fresh(cached) => {
                    log::trace!("cache: hit (fresh) for {key}, serving cached response");
                    // Open the stored body BEFORE touching the conn: if this fails the request
                    // is passed through, and it must not carry the cached response headers
                    // along to whatever handles it next.
                    let body = match entry.open().await {
                        Ok(b) => b,
                        Err(e) => {
                            log::warn!(
                                "cache: open for hit failed for {key}: {e}, passing through"
                            );
                            return conn;
                        }
                    };
                    *conn.response_headers_mut() = cached.headers;
                    return conn
                        .with_state(CacheCtx::<S::StoredEntry>::Hit)
                        .with_status(cached.status)
                        .with_body(body)
                        .halt();
                }
                BeforeRequest::NotModified(cached) => {
                    log::trace!("cache: hit (fresh, conditional matches) for {key}, serving 304");
                    *conn.response_headers_mut() = cached.headers;
                    return conn
                        .with_state(CacheCtx::<S::StoredEntry>::Hit)
                        .with_status(cached.status)
                        .with_body(Body::default())
                        .halt();
                }
                BeforeRequest::Stale {
                    request_headers,
                    matches: true,
                } => {
                    log::trace!("cache: stale for {key}, sending conditional revalidation request");
                    *conn.request_headers_mut() = request_headers;
                    return conn.with_state(CacheCtx::Revalidation { stored: entry, key });
                }
                BeforeRequest::Stale { matches: false, .. } => {
                    log::trace!("cache: candidate vary-mismatch for {key}, trying next");
                    continue;
                }
            }
        }

        log::trace!("cache: miss for {key}, forwarding to downstream handler");
        conn.with_state(CacheCtx::<S::StoredEntry>::Miss { key })
    }

    /// Finishes whatever `run` started, based on the [`CacheCtx`] found in conn state.
    ///
    /// * No state or `Hit`: the conn is returned as is.
    /// * `Revalidation`: a server-error status combined with a stale-if-error eligible entry
    ///   serves the stored response; otherwise the response is handed to
    ///   `handle_revalidation` (or passed through when there is no status).
    /// * `Miss`: the response is handed to `handle_miss` (or passed through when there is no
    ///   status).
    /// * `Unsafe`: on a success or redirection status, the GET and HEAD entries for the request
    ///   URL are invalidated, plus those for same-origin `Location` / `Content-Location`
    ///   targets.
    async fn before_send(&self, mut conn: Conn) -> Conn {
        let Some(ctx) = conn.take_state::<CacheCtx<S::StoredEntry>>() else {
            return conn;
        };
        match ctx {
            CacheCtx::Hit => conn,
            CacheCtx::Revalidation { stored, key } => {
                let now = SystemTime::now();
                let origin_failed = conn.status().is_some_and(|s| s.is_server_error());
                if origin_failed && stored.policy().is_sie_eligible(now) {
                    log::trace!(
                        "cache: stale-if-error recovery for {} (downstream {:?}), serving stale",
                        conn.method(),
                        conn.status()
                    );
                    return apply_stale(conn, stored, now).await;
                }
                if conn.status().is_none() {
                    log::trace!("cache: downstream produced no status, passing through");
                    return conn;
                }
                self.handle_revalidation(conn, stored, key).await
            }
            CacheCtx::Miss { key } => {
                if conn.status().is_none() {
                    log::trace!("cache: downstream produced no status, passing through");
                    return conn;
                }
                self.handle_miss(conn, key).await
            }
            CacheCtx::Unsafe { url } => {
                let Some(status) = conn.status() else {
                    return conn;
                };
                if status.is_success() || status.is_redirection() {
                    log::trace!(
                        "cache: unsafe method {} → {}, invalidating GET and HEAD entries for {url}",
                        conn.method(),
                        status
                    );
                    self.invalidate_url(&url).await;

                    let request_origin = url.origin();
                    for header in [KnownHeaderName::Location, KnownHeaderName::ContentLocation] {
                        let Some(value) = conn.response_headers().get_str(header) else {
                            continue;
                        };
                        let Ok(target) = url.join(value) else {
                            continue;
                        };
                        // RFC 9111 §4.4: never invalidate a URI whose origin (scheme, host AND
                        // port) differs from the request's. Comparing the host alone would let
                        // a response on one port or scheme evict entries of another origin.
                        if target.origin() != request_origin {
                            continue;
                        }
                        log::trace!(
                            "cache: unsafe method secondary invalidation via {header}: {target}"
                        );
                        self.invalidate_url(&target).await;
                    }
                }
                conn
            }
        }
    }
}
impl<S> Cache<S>
where
    S: CacheStorage,
{
    /// Invalidates the stored entries for `url` under the GET key and then the HEAD key.
    async fn invalidate_url(&self, url: &Url) {
        for method in [Method::Get, Method::Head] {
            let key = CacheKey::new(method, url.clone());
            self.storage.invalidate(&key).await;
        }
    }

    /// Resolves a revalidation once the downstream response is available.
    ///
    /// When the stored policy reports the response as not modified, the entry's policy is
    /// refreshed and the stored status, headers and body replace the downstream response.
    /// When it reports modified, the stored entry is dropped and the response is treated as
    /// a miss.
    ///
    /// # Panics
    ///
    /// Panics if the conn has no status; the caller checks this before calling.
    async fn handle_revalidation(
        &self,
        mut conn: Conn,
        mut stored: S::StoredEntry,
        key: CacheKey,
    ) -> Conn {
        let now = SystemTime::now();
        let status = conn.status().expect("checked above");
        match stored.policy().after_response(
            conn.request_headers(),
            status,
            conn.response_headers(),
            now,
        ) {
            AfterResponse::Modified => {
                drop(stored);
                self.handle_miss(conn, key).await
            }
            AfterResponse::NotModified(refreshed, reusable) => {
                log::trace!(
                    "cache: revalidation 304 for {key}, reusing stored body and refreshing entry"
                );
                // A failed refresh is only logged; the stored body is still served.
                if let Err(e) = stored.refresh_policy(refreshed).await {
                    log::warn!("cache: refresh_policy failed for {key}: {e}");
                }
                match stored.open().await {
                    Err(e) => {
                        log::warn!("cache: open after 304 failed for {key}: {e}, passing through");
                    }
                    Ok(body) => {
                        *conn.response_headers_mut() = reusable.headers;
                        conn.set_status(reusable.status);
                        conn.set_body(body);
                    }
                }
                conn
            }
        }
    }

    /// Attempts to store the downstream response under `key` while it streams to the client.
    ///
    /// The response is passed through unchanged when it is not storable, when its body
    /// reports a length above `max_cacheable_size`, when the storage `put` fails, or when
    /// there is no response body. Otherwise the body is wrapped in a `TeeingReader` fed with
    /// the handle returned by `put`.
    ///
    /// # Panics
    ///
    /// Panics if the conn has no status; callers check this before calling.
    async fn handle_miss(&self, mut conn: Conn, key: CacheKey) -> Conn {
        let status = conn.status().expect("checked above");
        let storable = CachePolicy::is_storable(
            conn.method(),
            conn.request_headers(),
            status,
            conn.response_headers(),
            &self.options,
        );
        if !storable {
            log::trace!("cache: miss for {key}, response not storable, passing through");
            return conn;
        }

        // Only a body that reports its length can be rejected up front.
        let oversized = conn
            .response_body()
            .and_then(|body| body.len())
            .filter(|declared| *declared > self.max_cacheable_size);
        if let Some(len) = oversized {
            log::trace!(
                "cache: miss for {key}, body {len} > max {}, not caching",
                self.max_cacheable_size
            );
            return conn;
        }

        let policy = CachePolicy::new(
            conn.method(),
            conn.request_headers(),
            status,
            conn.response_headers().clone(),
            SystemTime::now(),
            self.options,
        );
        let put_handle = match self.storage.put(key.clone(), policy).await {
            Err(e) => {
                log::warn!("cache: put({key}) failed: {e}, passing through");
                return conn;
            }
            Ok(handle) => handle,
        };

        match conn.take_response_body() {
            None => {
                log::trace!("cache: miss for {key}, no body, passing through");
            }
            Some(body) => {
                // Read the length before the body is consumed by `without_chunked_framing`.
                let len = body.len();
                log::trace!("cache: miss for {key}, streaming through tee");
                let tee = TeeingReader::new(
                    body.without_chunked_framing(),
                    put_handle,
                    self.max_cacheable_size,
                );
                conn.set_body(Body::new_with_trailers(tee, len));
            }
        }
        conn
    }
}
/// Replaces the conn's response with the one stored in `stored`, as computed by its policy
/// for time `now`.
///
/// If the stored body cannot be opened, a warning is logged and the conn is returned with its
/// existing response untouched.
async fn apply_stale<E>(mut conn: Conn, stored: E, now: SystemTime) -> Conn
where
    E: StoredEntry,
{
    let snapshot = stored.policy().cached_response(now);
    match stored.open().await {
        Ok(body) => {
            *conn.response_headers_mut() = snapshot.headers;
            conn.set_status(snapshot.status);
            conn.set_body(body);
        }
        Err(e) => {
            log::warn!("cache: open for stale serve failed: {e}, passing through");
        }
    }
    conn
}
// End-to-end tests: each one mounts `Cache` in front of a small inner handler and uses the
// number of times the inner handler ran to tell cache hits from misses.
#[cfg(test)]
mod tests {
use super::*;
use crate::InMemoryStorage;
use std::sync::atomic::{AtomicUsize, Ordering};
use trillium_testing::{TestResult, TestServer, harness, test};
// Inner handler that counts its invocations and answers `body-{n}`, where `n` is the counter
// value before the increment, with a configurable Cache-Control and an optional ETag.
#[derive(Debug, Clone)]
struct CountingHandler {
counter: Arc<AtomicUsize>,
cache_control: &'static str,
etag: Option<&'static str>,
}
impl CountingHandler {
// Starts at zero invocations, with no ETag.
fn new(cache_control: &'static str) -> Self {
Self {
counter: Arc::new(AtomicUsize::new(0)),
cache_control,
etag: None,
}
}
// Makes the handler emit `etag` and answer matching If-None-Match requests with 304.
fn with_etag(mut self, etag: &'static str) -> Self {
self.etag = Some(etag);
self
}
}
impl Handler for CountingHandler {
async fn run(&self, conn: Conn) -> Conn {
// Every invocation is counted, including the ones answered with 304 below.
let n = self.counter.fetch_add(1, Ordering::SeqCst);
if let Some(etag) = self.etag
&& conn.request_headers().get_str(KnownHeaderName::IfNoneMatch) == Some(etag)
{
return conn
.with_response_header(KnownHeaderName::Etag, etag)
.with_status(304)
.halt();
}
let mut conn = conn
.with_response_header(KnownHeaderName::CacheControl, self.cache_control)
.ok(format!("body-{n}"));
if let Some(etag) = self.etag {
conn.response_headers_mut()
.insert(KnownHeaderName::Etag, etag);
}
conn
}
}
// Cache with fresh in-memory storage and default settings, in front of `inner`.
fn cache_app(inner: CountingHandler) -> impl Handler {
(Cache::new(InMemoryStorage::new()), inner)
}
// A response with max-age=600 is reused: the second GET returns the first body and the inner
// handler is not invoked again.
#[test(harness)]
async fn first_request_misses_subsequent_request_hits() -> TestResult {
let inner = CountingHandler::new("max-age=600");
let counter = inner.counter.clone();
let app = TestServer::new(cache_app(inner)).await;
let r1 = app.get("/x").await;
r1.assert_ok().assert_body("body-0");
let r2 = app.get("/x").await;
r2.assert_ok().assert_body("body-0");
assert_eq!(
counter.load(Ordering::SeqCst),
1,
"inner handler only hit once"
);
Ok(())
}
// Two different paths each reach the inner handler.
#[test(harness)]
async fn different_urls_dont_collide() -> TestResult {
let inner = CountingHandler::new("max-age=600");
let counter = inner.counter.clone();
let app = TestServer::new(cache_app(inner)).await;
app.get("/a").await.assert_body("body-0");
app.get("/b").await.assert_body("body-1");
assert_eq!(counter.load(Ordering::SeqCst), 2);
Ok(())
}
// With Cache-Control: no-store both GETs reach the inner handler.
#[test(harness)]
async fn no_store_response_is_not_cached() -> TestResult {
let inner = CountingHandler::new("no-store");
let counter = inner.counter.clone();
let app = TestServer::new(cache_app(inner)).await;
app.get("/x").await.assert_body("body-0");
app.get("/x").await.assert_body("body-1");
assert_eq!(counter.load(Ordering::SeqCst), 2);
Ok(())
}
// A POST to the same path evicts the cached GET response. The POST itself consumes counter
// value 1, which is why the following GET sees `body-2`.
#[test(harness)]
async fn post_invalidates_existing_entry() -> TestResult {
let inner = CountingHandler::new("max-age=600");
let counter = inner.counter.clone();
let app = TestServer::new(cache_app(inner)).await;
app.get("/x").await.assert_body("body-0");
let _ = app.post("/x").await;
app.get("/x").await.assert_body("body-2");
assert_eq!(counter.load(Ordering::SeqCst), 3);
Ok(())
}
// With max-age=0 and an ETag, the second GET reaches the inner handler (counter becomes 2) yet
// the client still receives the originally stored body.
#[test(harness)]
async fn stale_with_etag_revalidates_to_304() -> TestResult {
let inner = CountingHandler::new("max-age=0").with_etag(r#""v1""#);
let counter = inner.counter.clone();
let app = TestServer::new(cache_app(inner)).await;
app.get("/x").await.assert_body("body-0");
assert_eq!(counter.load(Ordering::SeqCst), 1);
let r2 = app.get("/x").await;
r2.assert_ok().assert_body("body-0");
assert_eq!(counter.load(Ordering::SeqCst), 2);
Ok(())
}
// Responses carrying `Vary: Accept-Encoding` are kept apart per Accept-Encoding value: gzip and
// br each reach the inner handler once, and the repeated gzip request does not.
#[test(harness)]
async fn vary_isolates_entries_by_request_header() -> TestResult {
// Echoes the request's Accept-Encoding in the body and counts invocations.
#[derive(Debug, Clone, Default)]
struct VaryHandler(Arc<AtomicUsize>);
impl Handler for VaryHandler {
async fn run(&self, conn: Conn) -> Conn {
self.0.fetch_add(1, Ordering::SeqCst);
let ae = conn
.request_headers()
.get_str(KnownHeaderName::AcceptEncoding)
.unwrap_or("none")
.to_string();
conn.with_response_header(KnownHeaderName::CacheControl, "max-age=600")
.with_response_header(KnownHeaderName::Vary, "Accept-Encoding")
.ok(format!("body-for-{ae}"))
}
}
let inner = VaryHandler::default();
let counter = inner.0.clone();
let app = TestServer::new((Cache::new(InMemoryStorage::new()), inner)).await;
app.get("/x")
.with_request_header(KnownHeaderName::AcceptEncoding, "gzip")
.await
.assert_body("body-for-gzip");
app.get("/x")
.with_request_header(KnownHeaderName::AcceptEncoding, "br")
.await
.assert_body("body-for-br");
app.get("/x")
.with_request_header(KnownHeaderName::AcceptEncoding, "gzip")
.await
.assert_body("body-for-gzip");
assert_eq!(counter.load(Ordering::SeqCst), 2);
Ok(())
}
// The cap is 3 while the bodies ("body-0", "body-1") are 6 bytes long: both GETs are served by
// the inner handler and nothing is reused.
#[test(harness)]
async fn oversized_body_is_served_but_not_cached() -> TestResult {
let inner = CountingHandler::new("max-age=600");
let counter = inner.counter.clone();
let app = TestServer::new((
Cache::new(InMemoryStorage::new()).with_max_cacheable_size(3),
inner,
))
.await;
app.get("/x").await.assert_body("body-0");
app.get("/x").await.assert_body("body-1");
assert_eq!(counter.load(Ordering::SeqCst), 2);
Ok(())
}
// stale-if-error: the first response is stored with `max-age=0, stale-if-error=3600`; when the
// inner handler then answers 500, the client still receives the stored "stable" body.
#[test(harness)]
async fn sie_serves_stale_on_5xx() -> TestResult {
// Succeeds on the first invocation and answers 500 on every later one.
#[derive(Debug, Clone)]
struct FlakyHandler(Arc<AtomicUsize>);
impl Handler for FlakyHandler {
async fn run(&self, conn: Conn) -> Conn {
let n = self.0.fetch_add(1, Ordering::SeqCst);
if n == 0 {
conn.with_response_header(
KnownHeaderName::CacheControl,
"max-age=0, stale-if-error=3600",
)
.ok("stable")
} else {
conn.with_status(500).halt()
}
}
}
let inner = FlakyHandler(Arc::new(AtomicUsize::new(0)));
let counter = inner.0.clone();
let app = TestServer::new((Cache::new(InMemoryStorage::new()), inner)).await;
app.get("/x").await.assert_ok().assert_body("stable");
assert_eq!(counter.load(Ordering::SeqCst), 1);
let r2 = app.get("/x").await;
r2.assert_ok().assert_body("stable");
assert_eq!(counter.load(Ordering::SeqCst), 2);
Ok(())
}
}