use std::convert::Infallible;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use axum::body::{to_bytes, Body};
use axum::http::{HeaderMap, HeaderName, HeaderValue, Request, Response, StatusCode};
use tower::Service;
use crate::cache::BoxedCache;
/// Serializable snapshot of a cached HTTP response.
///
/// Stored in the cache backend as a JSON string (via `serde_json`).
#[derive(serde::Serialize, serde::Deserialize)]
struct CachedResponse {
/// Numeric HTTP status code of the original response.
status: u16,
/// Header name/value pairs in their original iteration order; values that are not
/// valid strings (`HeaderValue::to_str` fails) are omitted when the snapshot is taken.
headers: Vec<(String, String)>,
/// Response body, base64-encoded with the standard alphabet.
body_b64: String,
}
/// Tower layer that caches successful `GET` responses in a [`BoxedCache`].
///
/// Configure it with the builder methods (`timeout`, `key_prefix`, `vary_on`) and apply it
/// with `tower::Layer::layer` to obtain a [`CachePageService`].
#[derive(Clone)]
pub struct CachePageLayer {
/// Cache backend that stores serialized responses.
cache: BoxedCache,
/// TTL handed to the backend for each stored entry.
timeout: Duration,
/// String written at the start of every cache key.
key_prefix: String,
/// Extra request headers whose values are folded into the cache key.
vary_on: Vec<HeaderName>,
}
impl CachePageLayer {
    /// Creates a layer that caches responses in `cache`.
    ///
    /// Defaults: a 60-second TTL, the key prefix `"rustango.cache_page"`, and no extra
    /// `vary_on` headers. The request host is always part of the cache key.
    #[must_use]
    pub fn new(cache: BoxedCache) -> Self {
        Self {
            cache,
            timeout: Duration::from_secs(60),
            key_prefix: "rustango.cache_page".to_owned(),
            vary_on: Vec::new(),
        }
    }

    /// Sets the TTL passed to the cache backend for every stored response.
    #[must_use]
    pub fn timeout(mut self, dur: Duration) -> Self {
        self.timeout = dur;
        self
    }

    /// Replaces the string written at the start of every cache key.
    ///
    /// Layers that share one backend but use distinct prefixes never read each
    /// other's entries.
    #[must_use]
    pub fn key_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.key_prefix = prefix.into();
        self
    }

    /// Adds request headers whose values become part of the cache key.
    ///
    /// The same names are also listed in the response's `Vary` header. Names are
    /// lowercased before parsing. Calls accumulate; they do not replace earlier names.
    ///
    /// # Panics
    ///
    /// Panics if any name is not a valid HTTP header name (for example, if it contains
    /// spaces, separators or non-ASCII characters). The panic message names the
    /// offending input.
    #[must_use]
    pub fn vary_on<I, S>(mut self, names: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        for n in names {
            let raw = n.as_ref();
            let lower = raw.to_ascii_lowercase();
            let h = HeaderName::from_bytes(lower.as_bytes()).unwrap_or_else(|e| {
                panic!("vary_on: {raw:?} is not a valid HTTP header name: {e}")
            });
            self.vary_on.push(h);
        }
        self
    }
}
impl<S> tower::Layer<S> for CachePageLayer {
    type Service = CachePageService<S>;

    /// Wraps `inner` in a caching service.
    ///
    /// The key prefix and the vary-header list are moved behind `Arc`s, so cloning the
    /// resulting service does not copy them.
    fn layer(&self, inner: S) -> Self::Service {
        let Self {
            cache,
            timeout,
            key_prefix,
            vary_on,
        } = self;
        CachePageService {
            inner,
            cache: cache.clone(),
            timeout: *timeout,
            key_prefix: Arc::new(key_prefix.to_owned()),
            vary_on: Arc::new(vary_on.to_vec()),
        }
    }
}
/// Service produced by [`CachePageLayer`]; wraps `inner` with response caching.
///
/// The prefix and vary list are held in `Arc`s so the `Clone` performed on every call
/// only bumps reference counts.
#[derive(Clone)]
pub struct CachePageService<S> {
/// Wrapped service that produces fresh responses.
inner: S,
/// Cache backend shared with the layer.
cache: BoxedCache,
/// TTL handed to the backend for each stored entry.
timeout: Duration,
/// String written at the start of every cache key.
key_prefix: Arc<String>,
/// Extra request headers folded into the cache key.
vary_on: Arc<Vec<HeaderName>>,
}
impl<S> Service<Request<Body>> for CachePageService<S>
where
    S: Service<Request<Body>, Response = Response<Body>, Error = Infallible>
        + Clone
        + Send
        + 'static,
    S::Future: Send + 'static,
{
    type Response = Response<Body>;
    type Error = Infallible;
    type Future =
        Pin<Box<dyn Future<Output = Result<Response<Body>, Infallible>> + Send + 'static>>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Serves `GET` requests from the cache when possible. Otherwise it calls the inner
    /// service and stores eligible responses.
    ///
    /// A response is stored only when all of the following hold:
    /// - it is `200 OK`;
    /// - it does not opt out of shared caching (`Set-Cookie`, or `Cache-Control` with
    ///   `no-store` or `private`);
    /// - its body has a known size of at most `MAX_CACHEABLE_BODY_BYTES`.
    ///
    /// Any other response is passed through. A body is only consumed when it is known to
    /// fit, so large or streaming responses reach the client intact. Non-`GET` requests
    /// bypass the cache entirely.
    ///
    /// Cache backend failures are best-effort: lookup errors fall through to the inner
    /// service, and `set` errors are logged.
    fn call(&mut self, req: Request<Body>) -> Self::Future {
        // Brings `size_hint` into scope for `Body`.
        use axum::body::HttpBody as _;

        let cache = self.cache.clone();
        let timeout = self.timeout;
        let prefix = self.key_prefix.clone();
        let vary = self.vary_on.clone();
        // `poll_ready` readied `self.inner`. Move that instance into the future and leave
        // a fresh clone behind for the next call.
        let clone = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, clone);
        Box::pin(async move {
            if req.method() != axum::http::Method::GET {
                return inner.call(req).await;
            }
            let key = compute_cache_key(&prefix, &req, &vary);
            // Any lookup or decoding failure is treated as a miss.
            if let Ok(Some(serialized)) = cache.get(&key).await {
                if let Ok(stored) = serde_json::from_str::<CachedResponse>(&serialized) {
                    if let Some(resp) = stored.into_response(&vary) {
                        return Ok(resp);
                    }
                }
            }
            let resp = inner.call(req).await?;
            if resp.status() != StatusCode::OK || response_opts_out_of_page_cache(resp.headers())
            {
                return Ok(resp);
            }
            let (parts, body) = resp.into_parts();
            // Only buffer bodies whose size is known to fit. Once `to_bytes` hits its
            // limit, the partially read body is gone and could no longer be forwarded.
            let limit = u64::try_from(MAX_CACHEABLE_BODY_BYTES).unwrap_or(u64::MAX);
            if !matches!(body.size_hint().upper(), Some(n) if n <= limit) {
                tracing::debug!(
                    target: "rustango::cache_page",
                    max_bytes = MAX_CACHEABLE_BODY_BYTES,
                    "response body size unknown or above cache limit; passing through uncached"
                );
                let mut resp = Response::from_parts(parts, body);
                resp.headers_mut()
                    .insert(X_CACHE_STATUS, HeaderValue::from_static("BYPASS"));
                return Ok(resp);
            }
            let bytes = match to_bytes(body, MAX_CACHEABLE_BODY_BYTES).await {
                Ok(b) => b,
                Err(e) => {
                    // The body stream failed mid-read and its content is lost. Report a
                    // server error instead of sending a truncated 200.
                    tracing::warn!(
                        target: "rustango::cache_page",
                        error = %e,
                        "failed to buffer response body; returning 500"
                    );
                    let mut resp = Response::new(Body::empty());
                    *resp.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
                    resp.headers_mut()
                        .insert(X_CACHE_STATUS, HeaderValue::from_static("BYPASS"));
                    return Ok(resp);
                }
            };
            let stored = CachedResponse::from_parts(&parts, &bytes);
            if let Ok(json) = serde_json::to_string(&stored) {
                if let Err(e) = cache.set(&key, &json, Some(timeout)).await {
                    tracing::warn!(
                        target: "rustango::cache_page",
                        error = %e,
                        "cache backend rejected set(); response served fresh, not cached"
                    );
                }
            }
            let mut rebuilt = Response::from_parts(parts, Body::from(bytes));
            let headers = rebuilt.headers_mut();
            headers.insert(X_CACHE_STATUS, HeaderValue::from_static("MISS"));
            apply_vary_header(headers, &vary);
            Ok(rebuilt)
        })
    }
}

/// Returns `true` when a response must not be stored in this shared, server-side cache.
///
/// This is the case when the response:
/// - sets a cookie (replaying `Set-Cookie` would hand one user's cookie to everyone), or
/// - carries a `Cache-Control` `no-store` or `private` directive.
///
/// Directive names are compared case-insensitively. `Cache-Control` values that are
/// not valid strings are ignored.
fn response_opts_out_of_page_cache(headers: &HeaderMap) -> bool {
    if headers.contains_key(axum::http::header::SET_COOKIE) {
        return true;
    }
    headers
        .get_all(axum::http::header::CACHE_CONTROL)
        .iter()
        .filter_map(|v| v.to_str().ok())
        .flat_map(|s| s.split(','))
        // Keep only the directive name, e.g. `private` from `private="set-cookie"`.
        .filter_map(|directive| directive.split('=').next())
        .map(str::trim)
        .any(|name| name.eq_ignore_ascii_case("no-store") || name.eq_ignore_ascii_case("private"))
}
/// Largest response body, in bytes (1 MiB), that the page cache will buffer and store.
const MAX_CACHEABLE_BODY_BYTES: usize = 1024 * 1024;
/// Diagnostic response header set by the page cache.
///
/// - `HIT`: the response was replayed from the cache.
/// - `MISS`: the response was freshly generated and eligible for caching.
/// - `BYPASS`: the body was not buffered for caching.
///
/// It is excluded when a response is snapshotted, so a stale value is never replayed.
const X_CACHE_STATUS: HeaderName = HeaderName::from_static("x-cache-status");
/// Builds the cache key for `req`.
///
/// The key is `"{prefix}|"` followed by these fields, in order:
/// 1. the method;
/// 2. the path;
/// 3. the query (empty when absent);
/// 4. the host;
/// 5. for each `vary_on` header: its name, a `#{count}|` value count, and then every
///    value of that header in order.
///
/// Text fields use `write_lp` framing (`"{len}:{value}|"`). Header values that are not
/// valid strings are written as `x{hex bytes}|` instead of being dropped. Fields
/// therefore never run together, and distinct requests cannot share a key because a
/// value was unreadable or split across several header lines (HTTP/2 sends `Cookie`
/// as multiple fields).
fn compute_cache_key(prefix: &str, req: &Request<Body>, vary_on: &[HeaderName]) -> String {
    use std::fmt::Write as _;
    let mut k = String::with_capacity(prefix.len() + 128);
    let _ = write!(&mut k, "{prefix}|");
    write_lp(&mut k, req.method().as_str());
    write_lp(&mut k, req.uri().path());
    write_lp(&mut k, req.uri().query().unwrap_or(""));
    // HTTP/2 clients send `:authority` instead of a `Host` header; it is exposed through
    // the request URI, so fall back to it rather than keying every such host as "".
    let host = req
        .headers()
        .get(axum::http::header::HOST)
        .and_then(|h| h.to_str().ok())
        .or_else(|| req.uri().authority().map(|a| a.as_str()))
        .unwrap_or("");
    write_lp(&mut k, host);
    for name in vary_on {
        let values = req.headers().get_all(name);
        write_lp(&mut k, name.as_str());
        // The count starts with `#` (never a digit), so it cannot be confused with a
        // length-prefixed field.
        let _ = write!(&mut k, "#{}|", values.iter().count());
        for v in values.iter() {
            match v.to_str() {
                Ok(s) => write_lp(&mut k, s),
                Err(_) => {
                    // Lossless encoding for opaque bytes. The `x` tag cannot start a
                    // length-prefixed field, and hex digits never contain `|`.
                    k.push('x');
                    for b in v.as_bytes() {
                        let _ = write!(&mut k, "{b:02x}");
                    }
                    k.push('|');
                }
            }
        }
    }
    k
}
/// Appends `s` to `buf` as a length-prefixed field: `"{byte_len}:{s}|"`.
///
/// The prefix is the UTF-8 byte length of `s`. Because of it, a field boundary is
/// still unambiguous when `s` itself contains `:` or `|`.
fn write_lp(buf: &mut String, s: &str) {
    let len_digits = s.len().to_string();
    buf.reserve(len_digits.len() + s.len() + 2);
    buf.push_str(&len_digits);
    buf.push(':');
    buf.push_str(s);
    buf.push('|');
}
fn apply_vary_header(headers: &mut HeaderMap, vary_on: &[HeaderName]) {
use std::fmt::Write as _;
let mut parts: Vec<String> = Vec::with_capacity(vary_on.len() + 1);
parts.push("host".to_owned());
for n in vary_on {
parts.push(n.as_str().to_owned());
}
let mut s = String::new();
for (i, p) in parts.iter().enumerate() {
if i > 0 {
s.push_str(", ");
}
let _ = write!(&mut s, "{p}");
}
if let Ok(v) = HeaderValue::from_str(&s) {
headers.append(axum::http::header::VARY, v);
}
}
impl CachedResponse {
    /// Snapshots response `parts` and a fully buffered `body` into a serializable form.
    ///
    /// `X-Cache-Status` is left out because it is recomputed on replay. Header values
    /// that are not valid strings are skipped, since they cannot be stored as `String`.
    fn from_parts(parts: &axum::http::response::Parts, body: &[u8]) -> Self {
        use base64::engine::general_purpose::STANDARD as B64;
        use base64::Engine as _;
        let headers = parts
            .headers
            .iter()
            .filter(|(name, _)| **name != X_CACHE_STATUS)
            .filter_map(|(name, value)| {
                value
                    .to_str()
                    .ok()
                    .map(|text| (name.as_str().to_owned(), text.to_owned()))
            })
            .collect();
        Self {
            status: parts.status.as_u16(),
            headers,
            body_b64: B64.encode(body),
        }
    }

    /// Rebuilds a response from the snapshot.
    ///
    /// The result is tagged `X-Cache-Status: HIT`, and the `Vary` header for `vary_on`
    /// is appended.
    ///
    /// Returns `None` if the stored body is not valid base64 or the response cannot be
    /// built. Stored headers that no longer parse are skipped, and an unknown status
    /// code falls back to `200 OK`.
    fn into_response(self, vary_on: &[HeaderName]) -> Option<Response<Body>> {
        use base64::engine::general_purpose::STANDARD as B64;
        use base64::Engine as _;
        let Self {
            status,
            headers: stored_headers,
            body_b64,
        } = self;
        let decoded = B64.decode(&body_b64).ok()?;
        let status = StatusCode::from_u16(status).unwrap_or(StatusCode::OK);
        let mut resp = Response::builder()
            .status(status)
            .body(Body::from(decoded))
            .ok()?;
        let out = resp.headers_mut();
        stored_headers
            .into_iter()
            .filter_map(|(name, value)| {
                let parsed_name = HeaderName::from_bytes(name.as_bytes()).ok()?;
                let parsed_value = HeaderValue::from_str(&value).ok()?;
                Some((parsed_name, parsed_value))
            })
            .for_each(|(n, v)| {
                out.append(n, v);
            });
        out.insert(X_CACHE_STATUS, HeaderValue::from_static("HIT"));
        apply_vary_header(out, vary_on);
        Some(resp)
    }
}
/// Builder for `Cache-Control` header values.
///
/// Start with `CacheControl::new()`, chain directive setters, then call `build()`.
/// `public` and `private` are mutually exclusive: the last one set wins.
#[derive(Default, Clone, Debug)]
#[must_use = "call .build() to produce the HeaderValue"]
pub struct CacheControl {
/// `max-age=<secs>` when set.
max_age: Option<u64>,
/// Emit `public`.
public: bool,
/// Emit `private`.
private: bool,
/// Emit `no-cache`.
no_cache: bool,
/// Emit `no-store`.
no_store: bool,
/// Emit `must-revalidate`.
must_revalidate: bool,
/// `s-maxage=<secs>` when set.
s_maxage: Option<u64>,
}
impl CacheControl {
    /// Starts a builder with no directives set.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets `max-age=<secs>`.
    pub fn max_age(self, secs: u64) -> Self {
        Self {
            max_age: Some(secs),
            ..self
        }
    }

    /// Sets `s-maxage=<secs>` (the shared-cache lifetime).
    pub fn s_maxage(self, secs: u64) -> Self {
        Self {
            s_maxage: Some(secs),
            ..self
        }
    }

    /// Enables `public` and clears any earlier `private`.
    pub fn public(self) -> Self {
        Self {
            public: true,
            private: false,
            ..self
        }
    }

    /// Enables `private` and clears any earlier `public`.
    pub fn private(self) -> Self {
        Self {
            private: true,
            public: false,
            ..self
        }
    }

    /// Enables `no-cache`.
    pub fn no_cache(self) -> Self {
        Self {
            no_cache: true,
            ..self
        }
    }

    /// Enables `no-store`.
    pub fn no_store(self) -> Self {
        Self {
            no_store: true,
            ..self
        }
    }

    /// Enables `must-revalidate`.
    pub fn must_revalidate(self) -> Self {
        Self {
            must_revalidate: true,
            ..self
        }
    }

    /// Renders the enabled directives, joined with `", "`.
    ///
    /// The order is fixed regardless of call order: `max-age`, `s-maxage`, `public`,
    /// `private`, `no-cache`, `no-store`, `must-revalidate`. An empty builder yields an
    /// empty header value.
    pub fn build(self) -> HeaderValue {
        let valued = [("max-age", self.max_age), ("s-maxage", self.s_maxage)];
        let flags = [
            ("public", self.public),
            ("private", self.private),
            ("no-cache", self.no_cache),
            ("no-store", self.no_store),
            ("must-revalidate", self.must_revalidate),
        ];
        let directives: Vec<String> = valued
            .iter()
            .filter_map(|&(name, secs)| secs.map(|n| format!("{name}={n}")))
            .chain(
                flags
                    .iter()
                    .filter(|&&(_, enabled)| enabled)
                    .map(|&(name, _)| name.to_owned()),
            )
            .collect();
        HeaderValue::from_str(&directives.join(", ")).expect("ASCII directive string")
    }
}
#[must_use]
pub fn never_cache() -> HeaderValue {
CacheControl::new()
.no_store()
.no_cache()
.must_revalidate()
.max_age(0)
.build()
}
/// Builds a `Vary` header value by joining `names` with `", "`.
///
/// Names are copied verbatim. They are neither case-normalized nor validated as header
/// names.
///
/// # Panics
///
/// Panics if the joined string is not a valid header value (for example, if it contains
/// control characters).
#[must_use]
pub fn vary_on<I, S>(names: I) -> HeaderValue
where
    I: IntoIterator<Item = S>,
    S: AsRef<str>,
{
    let mut joined = String::new();
    for (idx, name) in names.into_iter().enumerate() {
        if idx > 0 {
            joined.push_str(", ");
        }
        joined.push_str(name.as_ref());
    }
    HeaderValue::from_str(&joined)
        .expect("vary_on: header names must be ASCII without control characters")
}
// No-op function with no callers in this source. `dead_code` is allowed only to silence
// the unused-function warning.
// NOTE(review): appears removable (`HeaderMap` is already used elsewhere in this file);
// confirm nothing else references it before deleting.
#[allow(dead_code)]
fn _trait_check(_h: &HeaderMap) {}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cache_control_builds_expected_directive() {
        let v = CacheControl::new().max_age(60).public().build();
        let s = v.to_str().unwrap();
        assert!(s.contains("max-age=60"));
        assert!(s.contains("public"));
    }

    #[test]
    fn never_cache_emits_full_no_store_directive() {
        let s = never_cache().to_str().unwrap().to_string();
        assert!(s.contains("no-store"));
        assert!(s.contains("no-cache"));
        assert!(s.contains("must-revalidate"));
        assert!(s.contains("max-age=0"));
    }

    #[test]
    fn public_and_private_are_mutually_exclusive() {
        let s = CacheControl::new()
            .public()
            .private()
            .build()
            .to_str()
            .unwrap()
            .to_string();
        assert!(s.contains("private"), "last call wins");
        assert!(!s.contains("public"));
    }

    #[test]
    fn vary_on_joins_with_comma_space() {
        let v = vary_on(["cookie", "accept-language"]);
        assert_eq!(v.to_str().unwrap(), "cookie, accept-language");
    }

    /// Directive order is fixed by `build`, independent of call order.
    #[test]
    fn cache_control_emits_directives_in_fixed_order() {
        let v = CacheControl::new()
            .must_revalidate()
            .no_store()
            .s_maxage(30)
            .max_age(10)
            .build();
        assert_eq!(
            v.to_str().unwrap(),
            "max-age=10, s-maxage=30, no-store, must-revalidate"
        );
    }

    #[test]
    fn cache_control_empty_builder_yields_empty_value() {
        assert_eq!(CacheControl::new().build().to_str().unwrap(), "");
    }

    /// Cache-key fields are framed by UTF-8 byte length, so embedded `|` or `:` cannot
    /// make two different field sequences collide.
    #[test]
    fn write_lp_length_prefixes_by_bytes() {
        let mut s = String::new();
        write_lp(&mut s, "ab|c");
        write_lp(&mut s, "é");
        write_lp(&mut s, "");
        assert_eq!(s, "4:ab|c|2:é|0:|");
    }
}