use std::{
convert::Infallible,
fmt::Debug,
future::Future,
hash::Hash,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
time::Duration,
};
use tracing_futures::Instrument as _;
#[cfg(feature = "axum07")]
use axum_07 as axum;
#[cfg(feature = "axum08")]
use axum_08 as axum;
use axum::body;
use axum::{
body::{Body, Bytes},
http::{response::Parts, Request, StatusCode},
response::{IntoResponse, Response},
};
use cached::{Cached, CloneCached, TimedCache};
use tower::{Layer, Service};
use tracing::{debug, instrument};
/// Computes the cache key under which a request's response is stored.
///
/// Any `Fn(&Request<Body>) -> K` that is `Send + Sync + 'static` is a keyer
/// through the blanket implementation below.
pub trait Keyer {
    /// The key type produced for each request.
    type Key;

    /// Builds the key for `request`.
    fn get_key(&self, request: &Request<Body>) -> Self::Key;
}

/// Blanket implementation so closures and plain functions can be used as keyers.
impl<K, F> Keyer for F
where
    F: Fn(&Request<Body>) -> K + Send + Sync + 'static,
{
    type Key = K;

    fn get_key(&self, req: &Request<Body>) -> K {
        let key_fn = self;
        key_fn(req)
    }
}
/// Default [`Keyer`]: keys responses by request method and URI.
///
/// Headers (for example `Accept`) are not part of the key; supply a custom
/// keyer if responses vary on them.
#[derive(Clone, Copy, Debug, Default)]
pub struct BasicKeyer;

/// Key produced by [`BasicKeyer`]: the request's method and its URI as returned
/// by `Request::uri` (path plus any query string).
pub type BasicKey = (http::Method, http::Uri);

impl Keyer for BasicKeyer {
    type Key = BasicKey;

    fn get_key(&self, request: &Request<Body>) -> Self::Key {
        (request.method().clone(), request.uri().clone())
    }
}
#[derive(Clone, Debug)]
pub struct CachedResponse {
parts: Parts,
body: Bytes,
timestamp: Option<std::time::Instant>,
}
impl IntoResponse for CachedResponse {
fn into_response(self) -> Response {
let mut response = Response::from_parts(self.parts, Body::from(self.body));
if let Some(timestamp) = self.timestamp {
let age = timestamp.elapsed().as_secs();
response
.headers_mut()
.insert("X-Cache-Age", age.to_string().parse().unwrap());
}
response
}
}
/// Tower [`Layer`] that caches successful responses of the wrapped service.
///
/// Build one with [`CacheLayer::with`], [`CacheLayer::with_lifespan`],
/// [`CacheLayer::with_cache_and_keyer`] or [`CacheLayer::with_lifespan_and_keyer`],
/// then adjust it with the builder methods.
pub struct CacheLayer<C, K> {
    /// Cache store shared by every service produced from this layer.
    cache: Arc<Mutex<C>>,
    /// Serve an expired entry when refreshing it yields a non-success status.
    use_stale: bool,
    /// Maximum response body size, in bytes, collected for caching.
    limit: usize,
    /// Honour the `X-Invalidate-Cache` request header.
    allow_invalidation: bool,
    /// Add an `X-Cache-Age` header to responses built from cache entries.
    add_response_headers: bool,
    /// Strategy that turns a request into a cache key.
    keyer: Arc<K>,
}

// Written by hand instead of derived so cloning does not require `C: Clone`
// or `K: Clone`; both are shared through `Arc`.
impl<C, K> Clone for CacheLayer<C, K> {
    fn clone(&self) -> Self {
        let Self {
            cache,
            use_stale,
            limit,
            allow_invalidation,
            add_response_headers,
            keyer,
        } = self;
        Self {
            cache: Arc::clone(cache),
            keyer: Arc::clone(keyer),
            use_stale: *use_stale,
            limit: *limit,
            allow_invalidation: *allow_invalidation,
            add_response_headers: *add_response_headers,
        }
    }
}
impl<C, K> CacheLayer<C, K>
where
    C: Cached<K::Key, CachedResponse> + CloneCached<K::Key, CachedResponse>,
    K: Keyer,
    K::Key: Debug + Hash + Eq + Clone + Send + 'static,
{
    /// Creates a layer backed by `cache`, keying requests with `keyer`.
    ///
    /// Defaults:
    /// * stale entries are not served on failure;
    /// * response bodies of up to 128 MiB are collected;
    /// * manual invalidation is off;
    /// * no `X-Cache-Age` header is added.
    #[must_use]
    pub fn with_cache_and_keyer(cache: C, keyer: K) -> Self {
        Self {
            cache: Arc::new(Mutex::new(cache)),
            use_stale: false,
            limit: 128 * 1024 * 1024,
            allow_invalidation: false,
            add_response_headers: false,
            keyer: Arc::new(keyer),
        }
    }

    /// When refreshing an expired entry returns a non-success status, respond
    /// with the expired entry instead of the failed response.
    #[must_use]
    pub fn use_stale_on_failure(self) -> Self {
        Self {
            use_stale: true,
            ..self
        }
    }

    /// Sets the maximum body size, in bytes, collected when caching a response.
    ///
    /// If collecting a successful response's body fails (including by exceeding
    /// this limit), the client receives `500 Internal Server Error` instead.
    #[must_use]
    pub fn body_limit(self, new_limit: usize) -> Self {
        Self {
            limit: new_limit,
            ..self
        }
    }

    /// Lets clients drop the cache entry for their request by sending an
    /// `X-Invalidate-Cache` header (any value).
    #[must_use]
    pub fn allow_invalidation(self) -> Self {
        Self {
            allow_invalidation: true,
            ..self
        }
    }

    /// Adds an `X-Cache-Age` header (entry age in whole seconds) to responses
    /// produced by the cache, including the response that first populates an
    /// entry.
    #[must_use]
    pub fn add_response_headers(self) -> Self {
        Self {
            add_response_headers: true,
            ..self
        }
    }
}
impl<C> CacheLayer<C, BasicKeyer>
where
    C: Cached<BasicKey, CachedResponse> + CloneCached<BasicKey, CachedResponse>,
{
    /// Creates a layer backed by `cache`, keyed by request method and URI
    /// ([`BasicKeyer`]), with the same defaults as
    /// [`CacheLayer::with_cache_and_keyer`].
    #[must_use]
    pub fn with(cache: C) -> Self {
        // Delegate so the default settings live in exactly one constructor.
        Self::with_cache_and_keyer(cache, BasicKeyer)
    }
}
// NOTE(review): this impl's second type argument is `BasicKey`, not
// `BasicKeyer`, so `Self` differs from the returned type. It is left as-is
// because the impl's self type is part of how callers can name this function.
impl CacheLayer<TimedCache<BasicKey, CachedResponse>, BasicKey> {
    /// Creates a [`BasicKeyer`]-keyed layer over a [`TimedCache`] whose entries
    /// are treated as expired once `ttl` has passed.
    pub fn with_lifespan(
        ttl: Duration,
    ) -> CacheLayer<TimedCache<BasicKey, CachedResponse>, BasicKeyer> {
        let store: TimedCache<BasicKey, CachedResponse> = TimedCache::with_lifespan(ttl);
        CacheLayer::with(store)
    }
}
impl<K> CacheLayer<TimedCache<K::Key, CachedResponse>, K>
where
    K: Keyer,
    K::Key: Debug + Hash + Eq + Clone + Send + 'static,
{
    /// Creates a layer that keys requests with `keyer` and stores responses in
    /// a [`TimedCache`] whose entries are treated as expired after `ttl`.
    pub fn with_lifespan_and_keyer(
        ttl: Duration,
        keyer: K,
    ) -> CacheLayer<TimedCache<K::Key, CachedResponse>, K> {
        let store: TimedCache<K::Key, CachedResponse> = TimedCache::with_lifespan(ttl);
        Self::with_cache_and_keyer(store, keyer)
    }
}
impl<S, C, K> Layer<S> for CacheLayer<C, K>
where
K: Keyer,
K::Key: Debug + Hash + Eq + Clone + Send + 'static,
{
type Service = CacheService<S, C, K>;
fn layer(&self, inner: S) -> Self::Service {
Self::Service {
inner,
cache: Arc::clone(&self.cache),
use_stale: self.use_stale,
limit: self.limit,
allow_invalidation: self.allow_invalidation,
add_response_headers: self.add_response_headers,
keyer: Arc::clone(&self.keyer),
}
}
}
/// Service produced by [`CacheLayer`]: wraps `inner` and consults the shared
/// cache before forwarding requests.
pub struct CacheService<S, C, K> {
    /// The wrapped service that produces fresh responses.
    inner: S,
    /// Cache store shared with the originating layer and its other services.
    cache: Arc<Mutex<C>>,
    /// Serve an expired entry when its refresh yields a non-success status.
    use_stale: bool,
    /// Maximum response body size, in bytes, collected for caching.
    limit: usize,
    /// Honour the `X-Invalidate-Cache` request header.
    allow_invalidation: bool,
    /// Add an `X-Cache-Age` header to responses built from cache entries.
    add_response_headers: bool,
    /// Strategy that turns a request into a cache key.
    keyer: Arc<K>,
}

// Manual impl: only `S` has to be `Clone`; the cache and keyer are shared via `Arc`.
impl<S, C, K> Clone for CacheService<S, C, K>
where
    S: Clone,
{
    fn clone(&self) -> Self {
        let Self {
            inner,
            cache,
            use_stale,
            limit,
            allow_invalidation,
            add_response_headers,
            keyer,
        } = self;
        Self {
            inner: inner.clone(),
            cache: Arc::clone(cache),
            keyer: Arc::clone(keyer),
            use_stale: *use_stale,
            limit: *limit,
            allow_invalidation: *allow_invalidation,
            add_response_headers: *add_response_headers,
        }
    }
}
impl<S, C, K> Service<Request<Body>> for CacheService<S, C, K>
where
    S: Service<Request<Body>, Response = Response, Error = Infallible> + Clone + Send,
    S::Future: Send + 'static,
    C: Cached<K::Key, CachedResponse> + CloneCached<K::Key, CachedResponse> + Send + 'static,
    K: Keyer,
    K::Key: Debug + Hash + Eq + Clone + Send + 'static,
{
    type Response = Response;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Response, Infallible>> + Send + 'static>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Serves `request` from the cache when possible, otherwise from the inner service.
    ///
    /// * Fresh entry: returned directly; the inner service is not called.
    /// * Expired entry: written back into the cache, then the inner service is
    ///   called. A success replaces the entry. On failure the expired entry is
    ///   returned if `use_stale_on_failure` was set; otherwise the entry is
    ///   evicted and the failed response is returned.
    /// * No entry: the inner service is called and a successful response is cached.
    ///
    /// With invalidation enabled, an `X-Invalidate-Cache` request header removes
    /// the entry before the lookup.
    #[instrument(skip(self, request))]
    fn call(&mut self, request: Request<Body>) -> Self::Future {
        // `poll_ready` drove `self.inner` to readiness, so that instance must
        // handle this request; an un-polled clone is left behind for the next
        // call. Calling a fresh clone instead breaks services that reserve
        // capacity in `poll_ready` (e.g. concurrency limits or buffers).
        let fresh = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, fresh);

        let use_stale = self.use_stale;
        let limit = self.limit;
        let add_response_headers = self.add_response_headers;
        let cache = Arc::clone(&self.cache);
        let key = self.keyer.get_key(&request);

        if self.allow_invalidation && request.headers().contains_key("X-Invalidate-Cache") {
            cache.lock().unwrap().cache_remove(&key);
            debug!("Cache invalidated manually for key {:?}", key);
        }

        let (cached, evicted) = {
            let mut guard = cache.lock().unwrap();
            let (cached, evicted) = guard.cache_get_expired(&key);
            if let (Some(stale), true) = (cached.as_ref(), evicted) {
                debug!("Found stale value in cache, reinserting and attempting refresh");
                guard.cache_set(key.clone(), stale.clone());
            }
            (cached, evicted)
        };

        let stale = match (cached, evicted) {
            // Fresh hit: `inner` is dropped unused, releasing anything it
            // reserved in `poll_ready`.
            (Some(value), false) => {
                return Box::pin(async move { Ok(value.into_response()) });
            }
            (stale, _) => stale,
        };

        let inner_fut = inner
            .call(request)
            .instrument(tracing::info_span!("inner_service"));
        Box::pin(async move {
            let response = match inner_fut.await {
                Ok(response) => response,
                Err(never) => match never {},
            };
            if response.status().is_success() {
                return Ok(update_cache(&cache, key, response, limit, add_response_headers).await);
            }
            match (stale, use_stale) {
                (Some(stale_value), true) => {
                    debug!("Returning stale value.");
                    Ok(stale_value.into_response())
                }
                (Some(_), false) => {
                    debug!("Stale value in cache, evicting and returning failed response.");
                    cache.lock().unwrap().cache_remove(&key);
                    Ok(response)
                }
                (None, _) => Ok(response),
            }
        })
    }
}
/// Collects `response`'s body, stores the result under `key` and returns a
/// response rebuilt from the stored entry.
///
/// At most `limit` bytes are collected. If collecting fails, nothing is cached
/// and a `500 Internal Server Error` is returned in place of the original
/// response. When `add_response_headers` is set, the entry is timestamped, so
/// the returned response carries an `X-Cache-Age` header.
#[instrument(skip(cache, response))]
async fn update_cache<C, K>(
    cache: &Arc<Mutex<C>>,
    key: K,
    response: Response,
    limit: usize,
    add_response_headers: bool,
) -> Response
where
    C: Cached<K, CachedResponse> + CloneCached<K, CachedResponse>,
    K: Debug + Hash + Eq + Clone + Send + 'static,
{
    let (parts, stream) = response.into_parts();
    let bytes = match body::to_bytes(stream, limit).await {
        Ok(bytes) => bytes,
        // NOTE(review): every `to_bytes` error is reported with this message,
        // not only the size limit being exceeded.
        Err(_) => {
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("File too big, over {limit} bytes"),
            )
                .into_response();
        }
    };
    let entry = CachedResponse {
        parts,
        body: bytes,
        timestamp: add_response_headers.then(std::time::Instant::now),
    };
    {
        // Scope the guard so the lock is released before the response is built.
        let mut store = cache.lock().unwrap();
        store.cache_set(key, entry.clone());
    }
    entry.into_response()
}
// Tests drive a real axum `Router` through the cache layer and use a shared
// `Counter` as handler state to count how often the handler actually runs.
#[cfg(test)]
mod tests {
use super::*;
use rand::Rng;
use std::sync::atomic::{AtomicIsize, Ordering};
// Same cfg-gated alias as the parent module so `axum::` names the enabled version.
#[cfg(feature = "axum07")]
use axum_07 as axum;
#[cfg(feature = "axum08")]
use axum_08 as axum;
use axum::{
extract::State,
http::{Request, StatusCode},
routing::get,
Router,
};
use tower::Service;
// Atomic invocation counter, cloned into handlers as router state.
#[derive(Clone, Debug)]
struct Counter {
value: Arc<AtomicIsize>,
}
impl Counter {
fn new(init: isize) -> Self {
Self {
value: AtomicIsize::from(init).into(),
}
}
fn increment(&self) {
self.value.fetch_add(1, Ordering::Release);
}
fn read(&self) -> isize {
self.value.load(Ordering::Acquire)
}
}
// Ten identical GETs within a 60 s lifespan: only the first reaches the handler.
#[tokio::test]
async fn should_use_cached_value() {
let handler = |State(cnt): State<Counter>| async move {
cnt.increment();
StatusCode::OK
};
let counter = Counter::new(0);
let cache = CacheLayer::with_lifespan(Duration::from_secs(60)).use_stale_on_failure();
let mut router = Router::new()
.route("/", get(handler).layer(cache))
.with_state(counter.clone());
for _ in 0..10 {
let status = router
.call(Request::get("/").body(Body::empty()).unwrap())
.await
.unwrap()
.status();
assert!(status.is_success(), "handler should return success");
}
assert_eq!(1, counter.read(), "handler should’ve been called only once");
}
// The handler only ever returns error statuses, so nothing is cached and
// every request reaches it.
#[tokio::test]
async fn should_not_cache_unsuccessful_responses() {
let handler = |State(cnt): State<Counter>| async move {
cnt.increment();
let responses = [
StatusCode::BAD_REQUEST,
StatusCode::INTERNAL_SERVER_ERROR,
StatusCode::NOT_FOUND,
];
let mut rng = rand::rng();
responses[rng.random_range(0..responses.len())]
};
let counter = Counter::new(0);
let cache = CacheLayer::with_lifespan(Duration::from_secs(60)).use_stale_on_failure();
let mut router = Router::new()
.route("/", get(handler).layer(cache))
.with_state(counter.clone());
for _ in 0..10 {
let status = router
.call(Request::get("/").body(Body::empty()).unwrap())
.await
.unwrap()
.status();
assert!(!status.is_success(), "handler should never return success");
}
assert_eq!(
10,
counter.read(),
"handler should’ve been called for all requests"
);
}
// The first call succeeds and is cached; afterwards the handler only returns
// errors. The 105 ms sleep must exceed the 100 ms lifespan so the entry is
// expired; with `use_stale_on_failure` the stale success must be served.
// (The layer writes the stale entry back into the cache on the refresh attempt,
// so later iterations may be answered from the cache directly.)
#[tokio::test]
async fn should_use_last_correct_stale_value() {
let handler = |State(cnt): State<Counter>| async move {
let prev = cnt.value.fetch_add(1, Ordering::AcqRel);
let responses = [
StatusCode::BAD_REQUEST,
StatusCode::INTERNAL_SERVER_ERROR,
StatusCode::NOT_FOUND,
];
let mut rng = rand::rng();
if prev == 0 {
StatusCode::OK
} else {
responses[rng.random_range(0..responses.len())]
}
};
let counter = Counter::new(0);
let cache = CacheLayer::with_lifespan(Duration::from_millis(100)).use_stale_on_failure();
let mut router = Router::new()
.route("/", get(handler).layer(cache))
.with_state(counter);
let status = router
.call(Request::get("/").body(Body::empty()).unwrap())
.await
.unwrap()
.status();
assert!(status.is_success(), "handler should return success");
tokio::time::sleep(tokio::time::Duration::from_millis(105)).await;
for _ in 1..10 {
let status = router
.call(Request::get("/").body(Body::empty()).unwrap())
.await
.unwrap()
.status();
assert!(
status.is_success(),
"cache should return stale successful value"
);
}
}
// Same setup without `use_stale_on_failure`: the expired entry is evicted on
// the first failed refresh and errors are forwarded, so the handler runs for
// all 10 requests.
#[tokio::test]
async fn should_not_use_stale_values() {
let handler = |State(cnt): State<Counter>| async move {
let prev = cnt.value.fetch_add(1, Ordering::AcqRel);
let responses = [
StatusCode::BAD_REQUEST,
StatusCode::INTERNAL_SERVER_ERROR,
StatusCode::NOT_FOUND,
];
let mut rng = rand::rng();
if prev == 0 {
StatusCode::OK
} else {
responses[rng.random_range(0..responses.len())]
}
};
let counter = Counter::new(0);
let cache = CacheLayer::with_lifespan(Duration::from_millis(100));
let mut router = Router::new()
.route("/", get(handler).layer(cache))
.with_state(counter.clone());
let status = router
.call(Request::get("/").body(Body::empty()).unwrap())
.await
.unwrap()
.status();
assert!(status.is_success(), "handler should return success");
tokio::time::sleep(tokio::time::Duration::from_millis(105)).await;
for _ in 1..10 {
let status = router
.call(Request::get("/").body(Body::empty()).unwrap())
.await
.unwrap()
.status();
assert!(
!status.is_success(),
"cache should forward unsuccessful values"
);
}
assert_eq!(
10,
counter.read(),
"handler should’ve been called for all requests"
);
}
// `X-Invalidate-Cache` must be ignored when `allow_invalidation` was not set.
#[tokio::test]
async fn should_not_invalidate_cache_when_disabled() {
let handler = |State(cnt): State<Counter>| async move {
cnt.increment();
StatusCode::OK
};
let counter = Counter::new(0);
let cache = CacheLayer::with_lifespan(Duration::from_secs(60));
let mut router = Router::new()
.route("/", get(handler).layer(cache))
.with_state(counter.clone());
let status = router
.call(Request::get("/").body(Body::empty()).unwrap())
.await
.unwrap()
.status();
assert!(status.is_success(), "handler should return success");
let status = router
.call(Request::get("/").body(Body::empty()).unwrap())
.await
.unwrap()
.status();
assert!(status.is_success(), "handler should return success");
let status = router
.call(
Request::get("/")
.header("X-Invalidate-Cache", "true")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap()
.status();
assert!(status.is_success(), "handler should return success");
let status = router
.call(Request::get("/").body(Body::empty()).unwrap())
.await
.unwrap()
.status();
assert!(status.is_success(), "handler should return success");
assert_eq!(1, counter.read(), "handler should’ve been called only once");
}
// With `allow_invalidation`, the header drops the entry, so the invalidating
// request reaches the handler a second time and re-populates the cache.
#[tokio::test]
async fn should_invalidate_cache_when_enabled() {
let handler = |State(cnt): State<Counter>| async move {
cnt.increment();
StatusCode::OK
};
let counter = Counter::new(0);
let cache = CacheLayer::with_lifespan(Duration::from_secs(60)).allow_invalidation();
let mut router = Router::new()
.route("/", get(handler).layer(cache))
.with_state(counter.clone());
let status = router
.call(Request::get("/").body(Body::empty()).unwrap())
.await
.unwrap()
.status();
assert!(status.is_success(), "handler should return success");
let status = router
.call(Request::get("/").body(Body::empty()).unwrap())
.await
.unwrap()
.status();
assert!(status.is_success(), "handler should return success");
let status = router
.call(
Request::get("/")
.header("X-Invalidate-Cache", "true")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap()
.status();
assert!(status.is_success(), "handler should return success");
let status = router
.call(Request::get("/").body(Body::empty()).unwrap())
.await
.unwrap()
.status();
assert!(status.is_success(), "handler should return success");
assert_eq!(2, counter.read(), "handler should’ve been called twice");
}
// Without `add_response_headers`, cached responses carry no `X-Cache-Age`.
#[tokio::test]
async fn should_not_include_age_header_when_disabled() {
let handler = |State(cnt): State<Counter>| async move {
cnt.increment();
StatusCode::OK
};
let counter = Counter::new(0);
let cache = CacheLayer::with_lifespan(Duration::from_secs(60));
let mut router = Router::new()
.route("/", get(handler).layer(cache))
.with_state(counter.clone());
let response = router
.call(Request::get("/").body(Body::empty()).unwrap())
.await
.unwrap();
assert!(
response.status().is_success(),
"handler should return success"
);
let response = router
.call(Request::get("/").body(Body::empty()).unwrap())
.await
.unwrap();
assert!(
response.status().is_success(),
"handler should return success"
);
assert!(
response.headers().get("X-Cache-Age").is_none(),
"Age header should not be present"
);
assert_eq!(1, counter.read(), "handler should’ve been called only once");
}
// `X-Cache-Age` is whole seconds (`as_secs`), so the first response reports "0"
// and, after a real 2.1 s sleep, the cached one reports "2". The sleep is real
// wall-clock time because the age comes from `std::time::Instant`.
#[tokio::test]
async fn should_include_age_header_when_enabled() {
let handler = |State(cnt): State<Counter>| async move {
cnt.increment();
StatusCode::OK
};
let counter = Counter::new(0);
let cache = CacheLayer::with_lifespan(Duration::from_secs(60)).add_response_headers();
let mut router = Router::new()
.route("/", get(handler).layer(cache))
.with_state(counter.clone());
let response = router
.call(Request::get("/").body(Body::empty()).unwrap())
.await
.unwrap();
assert!(
response.status().is_success(),
"handler should return success"
);
assert_eq!(
response
.headers()
.get("X-Cache-Age")
.and_then(|v| v.to_str().ok())
.unwrap_or(""),
"0",
"Age header should be present and equal to 0"
);
tokio::time::sleep(tokio::time::Duration::from_millis(2100)).await;
let response = router
.call(Request::get("/").body(Body::empty()).unwrap())
.await
.unwrap();
assert_eq!(
response
.headers()
.get("X-Cache-Age")
.and_then(|v| v.to_str().ok())
.unwrap_or(""),
"2",
"Age header should be present and equal to 2"
);
assert_eq!(1, counter.read(), "handler should’ve been called only once");
}
// The custom keyer includes the `Accept` header, so requests that differ only
// in `Accept` get separate entries: the handler runs once per distinct key.
#[tokio::test]
async fn should_cache_by_custom_keys() {
let handler = |State(cnt): State<Counter>| async move {
cnt.increment();
StatusCode::OK
};
let counter = Counter::new(0);
let keyer = |request: &Request<Body>| {
(
request.method().clone(),
request
.headers()
.get(axum::http::header::ACCEPT)
.and_then(|c| c.to_str().ok())
.unwrap_or("")
.to_string(),
request.uri().clone(),
)
};
let cache = CacheLayer::with_lifespan_and_keyer(Duration::from_secs(60), keyer)
.add_response_headers();
let mut router = Router::new()
.route("/", get(handler).layer(cache))
.with_state(counter.clone());
let response = router
.call(Request::get("/").body(Body::empty()).unwrap())
.await
.unwrap();
assert!(
response.status().is_success(),
"handler should return success"
);
assert_eq!(
response
.headers()
.get("X-Cache-Age")
.and_then(|v| v.to_str().ok())
.unwrap_or(""),
"0",
"Age header should be present and equal to 0"
);
tokio::time::sleep(tokio::time::Duration::from_millis(2100)).await;
let response = router
.call(Request::get("/").body(Body::empty()).unwrap())
.await
.unwrap();
assert_eq!(
response
.headers()
.get("X-Cache-Age")
.and_then(|v| v.to_str().ok())
.unwrap_or(""),
"2",
"Age header should be present and equal to 2"
);
let response = router
.call(
Request::get("/")
.header(axum::http::header::ACCEPT, "application/json")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(
response
.headers()
.get("X-Cache-Age")
.and_then(|v| v.to_str().ok())
.unwrap_or(""),
"0",
"Age header should be present and equal to 0"
);
tokio::time::sleep(tokio::time::Duration::from_millis(2100)).await;
let response = router
.call(
Request::get("/")
.header(axum::http::header::ACCEPT, "application/json")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(
response
.headers()
.get("X-Cache-Age")
.and_then(|v| v.to_str().ok())
.unwrap_or(""),
"2",
"Age header should be present and equal to 2"
);
assert_eq!(
2,
counter.read(),
"handler should’ve been called only twice"
);
}
}