use std::convert::Infallible;
use std::sync::Arc;
use std::task::{Context, Poll};
use axum::body::Body;
use axum::http::{Method, StatusCode};
use http::Request;
use http_body_util::BodyExt;
use tower::{Layer, Service};
use super::Cache;
/// Snapshot of an upstream response, stored in the cache so it can be replayed later.
///
/// `Clone` is derived; presumably the cache helpers (`super::get` / `super::insert`)
/// require it — TODO confirm against their bounds.
#[derive(Clone)]
struct CachedResponse {
    /// Status code of the stored response.
    status: StatusCode,
    /// Response headers exactly as received from the wrapped service.
    headers: http::HeaderMap,
    /// Fully buffered response body.
    body: bytes::Bytes,
}
/// Tower [`Layer`] that wraps a service in [`CacheResponseService`], which serves
/// `GET` requests from (and stores `200 OK` responses into) a shared [`Cache`].
#[derive(Clone)]
pub struct CacheResponseLayer {
    store: Arc<dyn Cache>,
}

impl CacheResponseLayer {
    /// Builds a layer that takes ownership of `store`; every service produced by this
    /// layer shares that single store.
    pub fn from_cache(store: impl Cache + 'static) -> Self {
        Self::from_shared(Arc::new(store))
    }

    /// Builds a layer around a store that is already shared, so the same cache can be
    /// used by other components as well.
    pub fn from_shared(store: Arc<dyn Cache>) -> Self {
        Self { store }
    }
}

impl<S> Layer<S> for CacheResponseLayer {
    type Service = CacheResponseService<S>;

    /// Wraps `inner`; the resulting service points at this layer's store.
    fn layer(&self, inner: S) -> Self::Service {
        let store = Arc::clone(&self.store);
        CacheResponseService { inner, store }
    }
}
/// Tower service produced by [`CacheResponseLayer`].
///
/// `GET` requests are answered from the shared cache when an entry exists; otherwise
/// they are forwarded and a `200 OK` response is buffered and stored. Requests with any
/// other method pass straight through to the wrapped service.
#[derive(Clone)]
pub struct CacheResponseService<S> {
    inner: S,
    store: Arc<dyn Cache>,
}

impl<S> Service<Request<Body>> for CacheResponseService<S>
where
    S: Service<Request<Body>, Response = axum::response::Response, Error = Infallible>
        + Clone
        + Send
        + 'static,
    S::Future: Send,
{
    type Response = axum::response::Response;
    type Error = Infallible;
    type Future = std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
    >;

    /// Readiness is delegated to the wrapped service. A cache hit does not consume it,
    /// so the next `poll_ready` finds `self.inner` still ready.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        // Only GET is cacheable; `self.inner` was polled ready, so it may be called directly.
        if req.method() != Method::GET {
            return Box::pin(self.inner.call(req));
        }

        // The key is "http:" plus the request URI as rendered by `Display`; request
        // headers (e.g. `Host`, or anything a `Vary` header would name) are not part of it.
        let cache_key = format!("http:{}", req.uri());

        if let Some(cached) = super::get::<CachedResponse>(self.store.as_ref(), &cache_key) {
            // Building from parts cannot fail, unlike `Response::builder().body(..)`.
            let mut response = axum::response::Response::new(Body::from(cached.body));
            *response.status_mut() = cached.status;
            *response.headers_mut() = cached.headers;
            return Box::pin(std::future::ready(Ok::<_, Infallible>(response)));
        }

        // `poll_ready` was driven on `self.inner`, so that exact instance must serve this
        // request. Calling a fresh clone instead would skip the readiness handshake (and
        // panic for services that hold per-instance permits, such as `Buffer`), so move
        // the ready instance out and leave the clone behind for the next request.
        let replacement = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, replacement);
        let store = Arc::clone(&self.store);

        Box::pin(async move {
            let response = inner.call(req).await?;
            if response.status() != StatusCode::OK {
                return Ok(response);
            }

            let (parts, body) = response.into_parts();
            let Ok(collected) = body.collect().await else {
                // The upstream body failed mid-stream: there is no complete payload to
                // cache or forward, so report a server error instead.
                let mut failure = axum::response::Response::new(Body::empty());
                *failure.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
                return Ok(failure);
            };
            let payload = collected.to_bytes();

            // NOTE(review): headers are stored verbatim, including any `Set-Cookie` or
            // `Cache-Control: private/no-store`, and replayed to every later requester of
            // this URI — confirm that is intended for the routes this layer wraps.
            let entry = CachedResponse {
                status: parts.status,
                headers: parts.headers.clone(),
                body: payload.clone(),
            };
            super::insert(store.as_ref(), &cache_key, entry);

            Ok(axum::response::Response::from_parts(parts, Body::from(payload)))
        })
    }
}
#[cfg(all(test, feature = "cache-moka"))]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tower::ServiceExt;

    /// Inner service that increments `counter` on every invocation and always answers
    /// with `status` and `body`.
    fn counting_service(
        counter: Arc<AtomicUsize>,
        status: StatusCode,
        body: &'static str,
    ) -> impl Service<
        Request<Body>,
        Response = axum::response::Response,
        Error = Infallible,
        Future = impl std::future::Future<Output = Result<axum::response::Response, Infallible>> + Send,
    > + Clone
           + Send
           + 'static {
        tower::service_fn(move |_req: Request<Body>| {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok::<_, Infallible>(
                    axum::response::Response::builder()
                        .status(status)
                        .body(Body::from(body))
                        .expect("static test response is valid"),
                )
            }
        })
    }

    /// Wraps `inner` in a caching layer backed by a fresh, private Moka store.
    fn cached<S>(inner: S) -> CacheResponseService<S> {
        CacheResponseLayer::from_cache(super::super::MokaCache::new(100, None)).layer(inner)
    }

    /// Builds a bodiless request with the given method and URI.
    fn request(method: Method, uri: &str) -> Request<Body> {
        Request::builder()
            .method(method)
            .uri(uri)
            .body(Body::empty())
            .expect("test request is valid")
    }

    /// Drives `svc` to readiness, sends `req`, and returns the status and collected body.
    async fn send<S>(svc: &mut S, req: Request<Body>) -> (StatusCode, bytes::Bytes)
    where
        S: Service<Request<Body>, Response = axum::response::Response, Error = Infallible>,
    {
        // `Infallible` has no values, so the empty match proves the `Err` arms unreachable.
        let resp = svc
            .ready()
            .await
            .unwrap_or_else(|never| match never {})
            .call(req)
            .await
            .unwrap_or_else(|never| match never {});
        let status = resp.status();
        let body = http_body_util::BodyExt::collect(resp.into_body())
            .await
            .expect("response body should collect")
            .to_bytes();
        (status, body)
    }

    #[tokio::test]
    async fn caches_get_responses() {
        let counter = Arc::new(AtomicUsize::new(0));
        let mut svc = cached(counting_service(counter.clone(), StatusCode::OK, "hello"));
        for _ in 0..2 {
            let (status, body) = send(&mut svc, request(Method::GET, "/test")).await;
            assert_eq!(status, StatusCode::OK);
            assert_eq!(body.as_ref(), b"hello");
            assert_eq!(
                counter.load(Ordering::SeqCst),
                1,
                "inner should not be called again"
            );
        }
    }

    #[tokio::test]
    async fn does_not_cache_post_requests() {
        let counter = Arc::new(AtomicUsize::new(0));
        let mut svc = cached(counting_service(counter.clone(), StatusCode::OK, "created"));
        for expected_calls in 1..=2 {
            let _ = send(&mut svc, request(Method::POST, "/items")).await;
            assert_eq!(
                counter.load(Ordering::SeqCst),
                expected_calls,
                "POST should not be cached"
            );
        }
    }

    #[tokio::test]
    async fn does_not_cache_non_200_responses() {
        let counter = Arc::new(AtomicUsize::new(0));
        let mut svc = cached(counting_service(
            counter.clone(),
            StatusCode::NOT_FOUND,
            "not found",
        ));
        for _ in 0..2 {
            let (status, _) = send(&mut svc, request(Method::GET, "/missing")).await;
            assert_eq!(status, StatusCode::NOT_FOUND);
        }
        assert_eq!(
            counter.load(Ordering::SeqCst),
            2,
            "404 should not be cached"
        );
    }

    #[tokio::test]
    async fn different_uris_cached_separately() {
        let counter = Arc::new(AtomicUsize::new(0));
        let mut svc = cached(counting_service(counter.clone(), StatusCode::OK, "ok"));
        for uri in ["/a", "/b"] {
            let _ = send(&mut svc, request(Method::GET, uri)).await;
        }
        assert_eq!(
            counter.load(Ordering::SeqCst),
            2,
            "different URIs should miss"
        );
        let _ = send(&mut svc, request(Method::GET, "/a")).await;
        assert_eq!(counter.load(Ordering::SeqCst), 2, "/a should be cached");
    }

    #[test]
    fn from_shared_accepts_arc() {
        let store = Arc::new(super::super::MokaCache::new(100, None));
        let _layer = CacheResponseLayer::from_shared(store);
    }

    #[tokio::test]
    async fn caches_get_responses_very_long_uri() {
        let counter = Arc::new(AtomicUsize::new(0));
        let mut svc = cached(counting_service(counter.clone(), StatusCode::OK, "hello"));
        // A key well over a thousand bytes must still round-trip through the cache.
        let long_uri = format!("/test/{}", "a".repeat(1000));
        for _ in 0..2 {
            let (status, _) = send(&mut svc, request(Method::GET, &long_uri)).await;
            assert_eq!(status, StatusCode::OK);
            assert_eq!(
                counter.load(Ordering::SeqCst),
                1,
                "Should be cached despite long URI"
            );
        }
    }
}