use std::{
collections::HashMap,
future::Future,
pin::Pin,
sync::LazyLock,
task::{Context, Poll},
time::{Duration, Instant},
};
use axum::{
body::{Body, to_bytes},
http::{
HeaderMap, Method, Request, StatusCode, Uri,
header::{self, CACHE_CONTROL, CONTENT_TYPE, HeaderValue, VARY},
},
response::Response,
};
use parking_lot::RwLock;
use tower::{Layer, Service};
/// Boxed, `Send` future returned by `CacheMiddleware`'s `Service::call`.
///
/// Boxing lets the pass-through, cache-hit and store-and-return paths of `call` all return
/// one concrete future type.
type BoxResponseFuture<E> = Pin<Box<dyn Future<Output = Result<Response, E>> + Send>>;
/// Process-wide response cache shared by every `CacheMiddleware` instance.
///
/// Keys are the strings built by `cache_key`: key prefix, canonical method, path and
/// query. Entries leave the map only when a lookup finds them expired or when a mutating
/// request invalidates their URL. Tests also clear it.
// NOTE(review): nothing else sweeps or bounds this map, so it grows with every distinct
// cached URL (query strings included). Consider a size cap or periodic sweep if URLs are
// client-controlled.
static CACHE_STORE: LazyLock<RwLock<HashMap<String, CachedResponse>>> =
LazyLock::new(|| RwLock::new(HashMap::new()));
/// A stored response: just enough to rebuild a reply on a cache hit.
///
/// Headers other than `Content-Type` are not kept. `Cache-Control` (and `Vary: Cookie` in
/// anonymous-only mode) are re-added when the entry is replayed.
#[derive(Clone)]
struct CachedResponse {
// Status of the original response.
status: StatusCode,
// The original `Content-Type` header, if it was present and readable as a string.
content_type: Option<String>,
// The response body.
body: String,
// Instant at or after which the entry is treated as a miss.
expires_at: Instant,
}
/// Tower layer that wraps a service in [`CacheMiddleware`].
///
/// `200 OK` responses to `GET` requests are kept in a process-wide in-memory store. They are
/// replayed for later `GET`/`HEAD` requests to the same path and query until they expire.
#[derive(Clone, Debug)]
pub struct CacheMiddlewareLayer {
    /// Lifetime of a cached entry, in seconds. Also advertised as `Cache-Control: max-age`.
    pub cache_timeout: u64,
    /// Namespace prepended to every cache key.
    pub key_prefix: String,
    /// When `true`, requests that carry a `Cookie` header bypass the cache, and replies are
    /// marked `Vary: Cookie`.
    pub cache_anonymous_only: bool,
}

impl Default for CacheMiddlewareLayer {
    /// Defaults to a 600-second (ten-minute) timeout, the `"rjango"` key prefix, and caching
    /// for all clients.
    fn default() -> Self {
        Self {
            cache_timeout: 600,
            key_prefix: "rjango".to_string(),
            cache_anonymous_only: false,
        }
    }
}
impl<S> Layer<S> for CacheMiddlewareLayer {
    type Service = CacheMiddleware<S>;

    /// Wraps `inner`, giving the new middleware its own copy of this layer's settings.
    fn layer(&self, inner: S) -> Self::Service {
        let Self {
            cache_timeout,
            key_prefix,
            cache_anonymous_only,
        } = self;
        CacheMiddleware {
            inner,
            cache_timeout: *cache_timeout,
            key_prefix: key_prefix.clone(),
            cache_anonymous_only: *cache_anonymous_only,
        }
    }
}
/// Service produced by `CacheMiddlewareLayer`.
///
/// Answers `GET`/`HEAD` requests from the shared response cache when it can, and forwards
/// everything else to `inner`.
#[derive(Clone, Debug)]
pub struct CacheMiddleware<S> {
    /// The wrapped service.
    inner: S,
    /// Entry lifetime in seconds, copied from the layer.
    cache_timeout: u64,
    /// Cache-key namespace, copied from the layer.
    key_prefix: String,
    /// Whether requests that carry a `Cookie` header bypass the cache.
    cache_anonymous_only: bool,
}
/// Maps a request method to the method name used in its cache key.
///
/// `HEAD` is folded onto `GET` so both share one entry. The HEAD reply is built from the
/// same stored response, just without a body. Every other method keeps its own name.
#[must_use]
fn canonical_cache_method(method: &Method) -> &str {
    match *method {
        Method::GET | Method::HEAD => Method::GET.as_str(),
        _ => method.as_str(),
    }
}
/// Builds the store key `"{prefix}:{method}:{path}?{query}"` for a request.
///
/// The method is canonicalised first, so `GET` and `HEAD` produce the same key. A missing
/// query is rendered as an empty string after the `?`.
#[must_use]
fn cache_key(prefix: &str, method: &Method, uri: &Uri) -> String {
    let method = canonical_cache_method(method);
    let path = uri.path();
    let query = uri.query().unwrap_or("");
    format!("{prefix}:{method}:{path}?{query}")
}
/// Returns `true` for the methods that may be answered from the cache: `GET` and `HEAD`.
#[must_use]
fn is_cache_lookup_method(method: &Method) -> bool {
    *method == Method::GET || *method == Method::HEAD
}
#[must_use]
fn is_mutating_method(method: &Method) -> bool {
!is_cache_lookup_method(method)
}
/// Reports whether the response's `Cache-Control` marks it as `private`.
///
/// Every `Cache-Control` header line is inspected, because directives may be split across
/// several lines. A directive matches on its name alone, so the field-name form
/// `private="Set-Cookie"` also counts as private. Header values that are not visible ASCII
/// are skipped.
#[must_use]
fn cache_control_is_private(headers: &HeaderMap) -> bool {
    headers
        .get_all(CACHE_CONTROL)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|value| value.split(','))
        // Keep only the directive name: `private="X"` -> `private`.
        .map(|directive| directive.split_once('=').map_or(directive, |(name, _)| name))
        .any(|name| name.trim().eq_ignore_ascii_case("private"))
}
/// Adds `value` to the response's `Vary` header unless it is already listed
/// (case-insensitively).
///
/// Tokens from every existing `Vary` line are merged into a single comma-separated header.
/// Reading all lines matters because the final `insert` replaces every existing line, so
/// reading only the first would silently drop the rest. Values that are not visible ASCII
/// cannot be read as text and are not carried over.
///
/// # Panics
///
/// Panics if `value` is not a valid header value. Callers in this file pass only
/// `"Cookie"`.
fn append_vary(headers: &mut HeaderMap, value: &str) {
    let mut vary_values: Vec<String> = headers
        .get_all(VARY)
        .iter()
        .filter_map(|header| header.to_str().ok())
        .flat_map(|header| header.split(','))
        .map(str::trim)
        .filter(|entry| !entry.is_empty())
        .map(ToOwned::to_owned)
        .collect();
    if vary_values
        .iter()
        .any(|entry| entry.eq_ignore_ascii_case(value))
    {
        return;
    }
    vary_values.push(value.to_string());
    headers.insert(
        VARY,
        HeaderValue::from_str(&vary_values.join(", ")).expect("vary header value should be valid"),
    );
}
/// Drops the cached entry for `uri` under `prefix`, if there is one.
///
/// The key is always built for `GET`. Since `HEAD` shares that key, this single removal
/// covers both lookup methods.
fn invalidate_cache_entry(prefix: &str, uri: &Uri) {
    let get_key = cache_key(prefix, &Method::GET, uri);
    let mut store = CACHE_STORE.write();
    store.remove(&get_key);
}
/// Rebuilds a response from the live cache entry stored under `key`, if one exists.
///
/// The reply carries the stored status, the stored `Content-Type`, and
/// `Cache-Control: max-age={cache_timeout}`. In anonymous-only mode it also gets
/// `Vary: Cookie`. `HEAD` requests get an empty body.
///
/// Returns `None` on a miss. An expired entry also yields `None` and is removed, but only if
/// it is still expired once the write lock is held.
#[must_use]
fn cached_response(
    key: &str,
    request_method: &Method,
    cache_timeout: u64,
    cache_anonymous_only: bool,
) -> Option<Response> {
    // Copy out what the reply needs while holding only the read lock. The body string is
    // cloned only for a live entry that will actually be sent.
    let live = {
        let store = CACHE_STORE.read();
        let entry = store.get(key)?;
        (Instant::now() < entry.expires_at).then(|| {
            let body = if *request_method == Method::HEAD {
                Body::empty()
            } else {
                Body::from(entry.body.clone())
            };
            (entry.status, entry.content_type.clone(), body)
        })
    };
    let Some((status, content_type, body)) = live else {
        // Re-check under the write lock. Between releasing the read lock and taking this one,
        // a concurrent miss may have stored a fresh entry under the same key, and that entry
        // must survive.
        let mut store = CACHE_STORE.write();
        if store
            .get(key)
            .is_some_and(|entry| Instant::now() >= entry.expires_at)
        {
            store.remove(key);
        }
        return None;
    };
    let mut builder = Response::builder().status(status);
    if let Some(content_type) = content_type.as_deref() {
        builder = builder.header(CONTENT_TYPE, content_type);
    }
    builder = builder.header(CACHE_CONTROL, format!("max-age={cache_timeout}"));
    let mut response = builder
        .body(body)
        .expect("cached response should build");
    if cache_anonymous_only {
        append_vary(response.headers_mut(), "Cookie");
    }
    Some(response)
}
#[must_use]
fn should_cache_response(
method: &Method,
status: StatusCode,
headers: &HeaderMap,
cache_anonymous_only: bool,
cookie_present: bool,
) -> bool {
*method == Method::GET
&& status == StatusCode::OK
&& !cache_control_is_private(headers)
&& !(cache_anonymous_only && cookie_present)
}
impl<S> Service<Request<Body>> for CacheMiddleware<S>
where
S: Service<Request<Body>, Response = Response> + Send + 'static,
S::Future: Send + 'static,
{
type Response = Response;
type Error = S::Error;
type Future = BoxResponseFuture<Self::Error>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, request: Request<Body>) -> Self::Future {
let request_method = request.method().clone();
let request_uri = request.uri().clone();
let cookie_present = request.headers().contains_key(header::COOKIE);
let cache_timeout = self.cache_timeout;
let key_prefix = self.key_prefix.clone();
let cache_anonymous_only = self.cache_anonymous_only;
let key = cache_key(&key_prefix, &request_method, &request_uri);
if is_mutating_method(&request_method) {
invalidate_cache_entry(&key_prefix, &request_uri);
return Box::pin(self.inner.call(request));
}
if is_cache_lookup_method(&request_method)
&& !(cache_anonymous_only && cookie_present)
&& let Some(response) =
cached_response(&key, &request_method, cache_timeout, cache_anonymous_only)
{
return Box::pin(async move { Ok(response) });
}
let future = self.inner.call(request);
Box::pin(async move {
let response = future.await?;
if !should_cache_response(
&request_method,
response.status(),
response.headers(),
cache_anonymous_only,
cookie_present,
) {
let mut response = response;
if cache_anonymous_only {
append_vary(response.headers_mut(), "Cookie");
}
return Ok(response);
}
let (mut parts, body) = response.into_parts();
let body_bytes = to_bytes(body, usize::MAX)
.await
.expect("cache middleware should be able to buffer response bodies");
let body_text = String::from_utf8_lossy(&body_bytes).into_owned();
let content_type = parts
.headers
.get(CONTENT_TYPE)
.and_then(|value| value.to_str().ok())
.map(ToOwned::to_owned);
parts.headers.insert(
CACHE_CONTROL,
HeaderValue::from_str(&format!("max-age={cache_timeout}"))
.expect("cache-control header should be valid"),
);
if cache_anonymous_only {
append_vary(&mut parts.headers, "Cookie");
}
CACHE_STORE.write().insert(
key,
CachedResponse {
status: parts.status,
content_type,
body: body_text,
expires_at: Instant::now() + Duration::from_secs(cache_timeout),
},
);
Ok(Response::from_parts(parts, Body::from(body_bytes)))
})
}
}
#[cfg(test)]
mod tests {
    use std::{
        convert::Infallible,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
    };

    use axum::{
        body::to_bytes,
        http::{Request, StatusCode, header},
    };
    use tower::{ServiceExt, service_fn};

    use super::*;

    // Each test uses its own key prefix. `CACHE_STORE` is process-wide and tests run in
    // parallel, so clearing it from one test would race with the others.

    fn build_layer(prefix: &str, timeout: u64) -> CacheMiddlewareLayer {
        CacheMiddlewareLayer {
            cache_timeout: timeout,
            key_prefix: prefix.to_string(),
            cache_anonymous_only: false,
        }
    }

    /// Wraps a counting handler in the cache layer.
    ///
    /// The handler answers `200 OK` (`text/plain`) with body `"{label}-{n}"`, where `n` is how
    /// many times it has run. The returned counter exposes that count.
    fn counting_service(
        prefix: &str,
        timeout: u64,
        label: &'static str,
    ) -> (
        impl Service<Request<Body>, Response = Response, Error = Infallible> + Clone,
        Arc<AtomicUsize>,
    ) {
        let counter = Arc::new(AtomicUsize::new(0));
        let handler_counter = Arc::clone(&counter);
        let service =
            build_layer(prefix, timeout).layer(service_fn(move |_request: Request<Body>| {
                let handler_counter = Arc::clone(&handler_counter);
                async move {
                    let current = handler_counter.fetch_add(1, Ordering::SeqCst) + 1;
                    Ok::<_, Infallible>(
                        Response::builder()
                            .status(StatusCode::OK)
                            .header(header::CONTENT_TYPE, "text/plain")
                            .body(Body::from(format!("{label}-{current}")))
                            .expect("response should build"),
                    )
                }
            }));
        (service, counter)
    }

    /// Sends one body-less request through `service` and returns the response.
    async fn send<S>(service: S, method: Method, uri: &str) -> Response
    where
        S: Service<Request<Body>, Response = Response, Error = Infallible>,
    {
        service
            .oneshot(
                Request::builder()
                    .method(method)
                    .uri(uri)
                    .body(Body::empty())
                    .expect("request should build"),
            )
            .await
            .expect("service should respond")
    }

    async fn response_body(response: Response) -> String {
        let body = to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("body should be readable");
        String::from_utf8(body.to_vec()).expect("body should be valid utf-8")
    }

    #[tokio::test]
    async fn test_cache_miss_passes_through() {
        let (service, counter) = counting_service("cache-miss", 60, "response");
        let response = send(service, Method::GET, "/articles").await;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        assert_eq!(response_body(response).await, "response-1");
    }

    #[tokio::test]
    async fn test_cache_hit_returns_cached_response() {
        let (service, counter) = counting_service("cache-hit", 60, "cached");
        let first = send(service.clone(), Method::GET, "/cached").await;
        let second = send(service, Method::GET, "/cached").await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        assert_eq!(response_body(first).await, "cached-1");
        assert_eq!(response_body(second).await, "cached-1");
    }

    #[tokio::test]
    async fn test_post_request_not_cached() {
        let (service, counter) = counting_service("post-not-cached", 60, "post");
        let first = send(service.clone(), Method::POST, "/submit").await;
        let second = send(service, Method::POST, "/submit").await;
        assert_eq!(counter.load(Ordering::SeqCst), 2);
        assert_eq!(response_body(first).await, "post-1");
        assert_eq!(response_body(second).await, "post-2");
    }

    #[tokio::test]
    async fn test_cache_respects_timeout() {
        let (service, counter) = counting_service("timeout", 0, "timeout");
        let first = send(service.clone(), Method::GET, "/timeout").await;
        let second = send(service, Method::GET, "/timeout").await;
        assert_eq!(counter.load(Ordering::SeqCst), 2);
        assert_eq!(response_body(first).await, "timeout-1");
        assert_eq!(response_body(second).await, "timeout-2");
    }

    #[tokio::test]
    async fn test_head_request_uses_get_cache_entry() {
        let (service, counter) = counting_service("head-cache", 60, "head");
        let _ = send(service.clone(), Method::GET, "/head-test").await;
        let response = send(service, Method::HEAD, "/head-test").await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response_body(response).await, "");
    }

    #[tokio::test]
    async fn test_mutating_request_invalidates_matching_cache_key() {
        let (service, counter) = counting_service("invalidate", 60, "item");
        let first_get = send(service.clone(), Method::GET, "/item/1").await;
        let _ = send(service.clone(), Method::POST, "/item/1").await;
        let second_get = send(service, Method::GET, "/item/1").await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
        assert_eq!(response_body(first_get).await, "item-1");
        assert_eq!(response_body(second_get).await, "item-3");
    }
}