use std::{
fmt,
future::Future,
hash::{Hash, Hasher},
num::NonZeroUsize,
pin::Pin,
sync::{Arc, Mutex},
task::{Context as TaskContext, Poll},
time::{Duration, Instant},
};
use axum::body::{Body, Bytes};
use axum::extract::Request;
use axum::response::Response;
use http::header::{ACCEPT, CACHE_CONTROL, CONTENT_TYPE};
use http_body_util::BodyExt;
use lru::LruCache;
use tower::{Layer, Service};
#[cfg(feature = "auth")]
use super::auth::AuthContext;
/// Identity dimensions that partition the cache so an entry stored for one
/// tenant/actor pair is never served to another.
#[derive(Debug, Clone, Default)]
pub struct CacheScope {
    /// Tenant the request belongs to, if known.
    pub tenant: Option<String>,
    /// Actor issuing the request, if known.
    pub actor: Option<String>,
}

impl CacheScope {
    /// Renders the scope as the string embedded in the cache key.
    ///
    /// The shapes are `t=<tenant>:a=<actor>`, `t=<tenant>`, `a=<actor>` and
    /// `_` when neither component is set. Any `:` or `\` inside a component
    /// is backslash-escaped so the `:` separator stays unambiguous: without
    /// escaping, tenant `"x:a=y"` with no actor would produce the same key as
    /// tenant `"x"` with actor `"y"`. Components free of those two characters
    /// render exactly as they are.
    pub(crate) fn as_key(&self) -> String {
        /// Backslash-escapes the separator (`:`) and the escape character itself.
        fn escape(component: &str) -> String {
            let mut escaped = String::with_capacity(component.len());
            for ch in component.chars() {
                if matches!(ch, ':' | '\\') {
                    escaped.push('\\');
                }
                escaped.push(ch);
            }
            escaped
        }

        match (self.tenant.as_deref(), self.actor.as_deref()) {
            (Some(t), Some(a)) => format!("t={}:a={}", escape(t), escape(a)),
            (Some(t), None) => format!("t={}", escape(t)),
            (None, Some(a)) => format!("a={}", escape(a)),
            (None, None) => "_".to_string(),
        }
    }
}
/// Tunables for the response cache.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of entries the LRU store is created with.
    pub capacity: NonZeroUsize,
    /// Maximum age of an entry that may still be served; `None` disables expiry.
    pub ttl: Option<Duration>,
    /// Largest body, in bytes, that takes part in caching: bigger request
    /// bodies bypass the cache and bigger response bodies are not stored.
    pub max_body_size: usize,
}

impl Default for CacheConfig {
    /// 10 000 entries, a 60-second TTL and a 1 MiB body limit.
    fn default() -> Self {
        // The literal is non-zero, so this `unwrap` can never fail.
        Self::with_capacity(NonZeroUsize::new(10_000).unwrap())
    }
}

impl CacheConfig {
    /// Builds a configuration with the given `capacity` and the default TTL
    /// (60 seconds) and body limit (1 MiB).
    pub fn with_capacity(capacity: NonZeroUsize) -> Self {
        Self {
            capacity,
            ttl: Some(Duration::from_secs(60)),
            max_body_size: 1 << 20,
        }
    }
}
/// Shared (`Send + Sync`) predicate that decides whether a request takes part in caching.
/// `CacheService::call` forwards the request straight to the inner service when it returns `false`.
pub type CachePredicate = Arc<dyn Fn(&Request<Body>) -> bool + Send + Sync>;
/// Tower [`Layer`] that wraps a service in a [`CacheService`].
#[derive(Clone)]
pub struct CacheLayer {
/// Response store; every service produced by `layer` receives a clone of this handle.
cache: RequestCache,
/// Decides which requests take part in caching.
predicate: CachePredicate,
}
impl fmt::Debug for CacheLayer {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("CacheLayer")
.field("capacity", &self.cache.config.capacity)
.field("ttl", &self.cache.config.ttl)
.field("max_body_size", &self.cache.config.max_body_size)
.finish()
}
}
impl CacheLayer {
    /// Creates a layer that selects cacheable requests with
    /// `default_cacheable_predicate`.
    pub fn new(config: CacheConfig) -> Self {
        let predicate: CachePredicate = Arc::new(default_cacheable_predicate);
        Self::with_predicate(config, predicate)
    }

    /// Creates a layer that selects cacheable requests with a caller-supplied
    /// predicate. A fresh, empty store is built from `config`.
    pub fn with_predicate(config: CacheConfig, predicate: CachePredicate) -> Self {
        let cache = RequestCache::new(config);
        Self { cache, predicate }
    }
}
impl<S> Layer<S> for CacheLayer {
    type Service = CacheService<S>;

    /// Wraps `inner`, handing the new service clones of this layer's store
    /// handle and predicate.
    fn layer(&self, inner: S) -> Self::Service {
        let Self { cache, predicate } = self;
        CacheService {
            inner,
            cache: cache.clone(),
            predicate: Arc::clone(predicate),
        }
    }
}
/// Tower service that answers cacheable requests from a shared store and otherwise delegates to `inner`.
#[derive(Clone)]
pub struct CacheService<S> {
/// Wrapped service that produces the response on a cache miss or bypass.
inner: S,
/// Handle to the response store received from the originating layer.
cache: RequestCache,
/// Decides which requests take part in caching.
predicate: CachePredicate,
}
impl<S: fmt::Debug> fmt::Debug for CacheService<S> {
    /// Prints only the wrapped service; the store and predicate are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("CacheService");
        out.field("inner", &self.inner);
        out.finish()
    }
}
/// Caching behaviour of [`CacheService`].
///
/// Requests rejected by the predicate are forwarded untouched. For the rest
/// the request body is buffered, a key is derived from scope, method, URI,
/// `Content-Type`, `Accept` and a body hash, and a fresh cache entry is served
/// if one exists. On a miss the inner service is called, its response body is
/// buffered, successful responses within the size limit are stored, and the
/// response is returned with its original status, headers and extensions.
impl<S> Service<Request<Body>> for CacheService<S>
where
    S: Service<Request<Body>, Response = Response> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Send + 'static,
{
    type Response = Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Response, S::Error>> + Send>>;

    /// Readiness is delegated to the wrapped service.
    fn poll_ready(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Serves `request` from the cache when possible, otherwise through the
    /// wrapped service.
    ///
    /// # Panics
    ///
    /// Panics if the cache mutex has been poisoned by a panic in another task.
    fn call(&mut self, request: Request<Body>) -> Self::Future {
        if !(self.predicate)(&request) {
            // `self.inner` is the instance `poll_ready` reported ready, so it is
            // the one that must handle the call; a fresh clone carries no
            // readiness guarantee.
            return Box::pin(self.inner.call(request));
        }

        let cache = self.cache.clone();
        // Move the ready instance into the future and leave a fresh clone
        // behind for the next `poll_ready`/`call` cycle.
        let replacement = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, replacement);

        Box::pin(async move {
            let (req_parts, req_body) = request.into_parts();
            let bytes = match req_body.collect().await {
                Ok(collected) => collected.to_bytes(),
                Err(_) => {
                    // NOTE(review): an unreadable request body is forwarded as an
                    // empty body (existing best-effort behaviour); the inner
                    // service never sees the read error.
                    let req = Request::from_parts(req_parts, Body::empty());
                    return inner.call(req).await;
                }
            };

            // Oversized request bodies bypass the cache entirely.
            if bytes.len() > cache.config.max_body_size {
                let req = Request::from_parts(req_parts, Body::from(bytes));
                return inner.call(req).await;
            }

            #[cfg(feature = "auth")]
            let scope = req_parts
                .extensions
                .get::<AuthContext>()
                .map(|ctx| CacheScope {
                    tenant: ctx.tenant_id.clone(),
                    actor: ctx.subject.clone(),
                })
                .unwrap_or_default();
            #[cfg(not(feature = "auth"))]
            let scope = CacheScope::default();

            let key = build_cache_key(
                &req_parts.method,
                &req_parts.uri,
                &req_parts.headers,
                &scope,
                &bytes,
            );

            // The guard lives only inside this block so it is released before
            // the next `.await`.
            {
                let mut guard = cache.inner.lock().expect("cache lock poisoned");
                if let Some(entry) = guard.get(&key)
                    && !entry.is_expired(cache.config.ttl)
                {
                    return Ok(entry.to_response());
                }
            }

            let req = Request::from_parts(req_parts, Body::from(bytes));
            let response = inner.call(req).await?;
            let (res_parts, res_body) = response.into_parts();
            let response_bytes = match res_body.collect().await {
                Ok(collected) => collected.to_bytes(),
                Err(_) => {
                    // The body stream failed part-way: answer with the status
                    // alone and do not cache anything.
                    return Ok(Response::builder()
                        .status(res_parts.status)
                        .body(Body::empty())
                        .expect("valid response"));
                }
            };

            if res_parts.status.is_success()
                && response_bytes.len() <= cache.config.max_body_size
            {
                let entry = CachedResponse {
                    status: res_parts.status,
                    headers: res_parts
                        .headers
                        .iter()
                        .map(|(k, v)| (k.to_string(), v.as_bytes().to_vec()))
                        .collect(),
                    body: response_bytes.clone(),
                    created_at: Instant::now(),
                };
                let mut guard = cache.inner.lock().expect("cache lock poisoned");
                guard.put(key, entry);
            }

            // Reassemble from the original parts so the caller still receives
            // the headers, version and extensions set by the inner service;
            // only the body was swapped for its buffered copy.
            Ok(Response::from_parts(res_parts, Body::from(response_bytes)))
        })
    }
}
/// Clonable handle to the LRU store together with the configuration it was built from.
#[derive(Clone)]
struct RequestCache {
/// LRU map from request key to stored response; held in an `Arc<Mutex<..>>` so clones share one store.
inner: Arc<Mutex<LruCache<CacheKey, CachedResponse, gxhash::GxBuildHasher>>>,
/// Capacity, TTL and body-size limit consulted by the service.
config: CacheConfig,
}
impl RequestCache {
    /// Builds an empty LRU store bounded by `config.capacity` and keeps
    /// `config` alongside it.
    fn new(config: CacheConfig) -> Self {
        let store = LruCache::with_hasher(config.capacity, gxhash::GxBuildHasher::default());
        let inner = Arc::new(Mutex::new(store));
        Self { inner, config }
    }
}
/// Lookup key of a cached response.
#[derive(Debug, Clone, Eq, PartialEq)]
struct CacheKey {
    /// Rendered tenant/actor scope.
    scope: String,
    /// HTTP method of the request.
    method: String,
    /// Request URI as text.
    uri: String,
    /// `Content-Type` request header, or an empty string.
    content_type: String,
    /// `Accept` request header, or an empty string.
    accept: String,
    /// 64-bit hash standing in for the request body.
    body_hash: u64,
}

impl Hash for CacheKey {
    /// Feeds every field into `state` in declaration order.
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Exhaustive destructuring: adding a field without hashing it becomes
        // a compile error instead of a silent omission.
        let Self {
            scope,
            method,
            uri,
            content_type,
            accept,
            body_hash,
        } = self;
        scope.hash(state);
        method.hash(state);
        uri.hash(state);
        content_type.hash(state);
        accept.hash(state);
        body_hash.hash(state);
    }
}
/// Assembles the cache key for a request from its scope, method, URI,
/// `Content-Type` and `Accept` headers, and a 64-bit hash of the body.
///
/// The body itself is not stored in the key: two bodies whose hashes collide
/// (with every other component equal) map to the same entry.
fn build_cache_key(
    method: &http::Method,
    uri: &http::Uri,
    headers: &http::HeaderMap,
    scope: &CacheScope,
    body: &Bytes,
) -> CacheKey {
    let scope_key = scope.as_key();
    let content_type = header_value(headers, CONTENT_TYPE);
    let accept = header_value(headers, ACCEPT);

    CacheKey {
        scope: scope_key,
        method: method.as_str().to_owned(),
        uri: uri.to_string(),
        content_type,
        accept,
        // Fixed seed of 0 so equal bodies always hash to the same value.
        body_hash: gxhash::gxhash64(body, 0),
    }
}
/// Returns the first value of header `name` as an owned string.
///
/// Yields an empty string when the header is absent or when `to_str` rejects
/// its value.
fn header_value(headers: &http::HeaderMap, name: http::header::HeaderName) -> String {
    match headers.get(name).map(|raw| raw.to_str()) {
        Some(Ok(text)) => text.to_owned(),
        _ => String::new(),
    }
}
/// Default policy for deciding whether a request takes part in caching.
///
/// A request is cacheable when all of the following hold:
/// * its path is neither under `/health/` nor exactly `/metrics`;
/// * no `Cache-Control` header line carries `no-store` or `no-cache`
///   (directive names are compared case-insensitively, and every occurrence
///   of the header is inspected, not just the first);
/// * it is a `GET`, or a `POST` whose path names a read-style RPC method
///   according to `is_read_rpc_method`.
fn default_cacheable_predicate(request: &Request<Body>) -> bool {
    let path = request.uri().path();
    if path.starts_with("/health/") || path == "/metrics" {
        return false;
    }

    // Cache directives are case-insensitive and the header may be repeated,
    // so normalise the case and look at every header line.
    let client_opts_out = request
        .headers()
        .get_all(CACHE_CONTROL)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .map(str::to_ascii_lowercase)
        .any(|value| value.contains("no-store") || value.contains("no-cache"));
    if client_opts_out {
        return false;
    }

    match request.method() {
        &http::Method::GET => true,
        &http::Method::POST => is_read_rpc_method(path),
        _ => false,
    }
}
/// Reports whether the last segment of `path` looks like a read-only RPC method.
///
/// Leading slashes are stripped, the text after the final `/` is taken as the
/// method name, and the name is checked (case-sensitively) against a fixed
/// list of prefixes.
fn is_read_rpc_method(path: &str) -> bool {
    /// Method-name prefixes treated as reads.
    const READ_PREFIXES: [&str; 6] = ["List", "Get", "Check", "Expand", "To", "Describe"];

    let trimmed = path.trim_start_matches('/');
    let method_name = trimmed
        .rsplit_once('/')
        .map_or(trimmed, |(_, last_segment)| last_segment);

    READ_PREFIXES
        .iter()
        .any(|&prefix| method_name.starts_with(prefix))
}
/// Snapshot of a response held in the cache.
#[derive(Debug, Clone)]
struct CachedResponse {
/// Status code of the stored response.
status: http::StatusCode,
/// Header name / raw value pairs captured from the stored response.
headers: Vec<(String, Vec<u8>)>,
/// Fully buffered response body.
body: Bytes,
/// Moment the entry was created; compared against the TTL in `is_expired`.
created_at: Instant,
}
impl CachedResponse {
    /// Returns `true` once the entry is older than `ttl`. With `ttl == None`
    /// entries never expire.
    fn is_expired(&self, ttl: Option<Duration>) -> bool {
        ttl.is_some_and(|limit| self.created_at.elapsed() > limit)
    }

    /// Rebuilds an HTTP response from the stored status, headers and body.
    ///
    /// Stored headers are attached to the response, with repeated names kept
    /// as separate values. Entries that no longer parse as a header name or
    /// value are skipped. `Set-Cookie` is deliberately not replayed.
    fn to_response(&self) -> Response {
        let mut response = Response::new(Body::from(self.body.clone()));
        *response.status_mut() = self.status;

        let headers = response.headers_mut();
        for (name, value) in &self.headers {
            let (Ok(name), Ok(value)) = (
                http::HeaderName::from_bytes(name.as_bytes()),
                http::HeaderValue::from_bytes(value),
            ) else {
                continue;
            };
            // A cookie is addressed to the client that triggered the original
            // response; replaying it from a shared entry could hand one caller
            // another caller's cookie.
            if name == http::header::SET_COOKIE {
                continue;
            }
            // `append`, not `insert`: multi-valued headers must keep every value.
            let _ = headers.append(name, value);
        }

        response
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the default predicate and for the hit, scope and TTL
    //! behaviour of the cache service. The service tests count inner-service
    //! invocations so they fail when a request is wrongly served from, or
    //! wrongly skips, the cache.

    use super::*;
    use axum::body::Body;
    use axum::response::IntoResponse;
    use http::Request;
    use std::convert::Infallible;
    use std::sync::{Arc, Mutex};
    use tower::ServiceExt;

    /// Small configuration shared by the service tests.
    fn test_config() -> CacheConfig {
        CacheConfig {
            capacity: NonZeroUsize::new(100).unwrap(),
            ttl: Some(Duration::from_secs(60)),
            max_body_size: 1024,
        }
    }

    #[test]
    fn default_predicate_caches_get_requests() {
        let req = Request::get("/foo").body(Body::empty()).unwrap();
        assert!(default_cacheable_predicate(&req));
    }

    #[test]
    fn default_predicate_skips_health_and_metrics() {
        let health = Request::get("/health/live").body(Body::empty()).unwrap();
        let metrics = Request::get("/metrics").body(Body::empty()).unwrap();
        assert!(!default_cacheable_predicate(&health));
        assert!(!default_cacheable_predicate(&metrics));
    }

    #[test]
    fn default_predicate_caches_read_rpcs() {
        let list = Request::post("/iam.v1.TenantService/ListTenants")
            .body(Body::empty())
            .unwrap();
        let get = Request::post("/iam.v1.TenantService/GetTenant")
            .body(Body::empty())
            .unwrap();
        let create = Request::post("/iam.v1.TenantService/CreateTenant")
            .body(Body::empty())
            .unwrap();
        assert!(default_cacheable_predicate(&list));
        assert!(default_cacheable_predicate(&get));
        assert!(!default_cacheable_predicate(&create));
    }

    #[test]
    fn default_predicate_respects_cache_control_no_store() {
        let req = Request::get("/foo")
            .header(CACHE_CONTROL, "no-store")
            .body(Body::empty())
            .unwrap();
        assert!(!default_cacheable_predicate(&req));
    }

    #[test]
    fn default_predicate_respects_cache_control_no_cache() {
        let req = Request::get("/foo")
            .header(CACHE_CONTROL, "max-age=0, no-cache")
            .body(Body::empty())
            .unwrap();
        assert!(!default_cacheable_predicate(&req));
    }

    #[tokio::test]
    async fn cache_returns_hits_without_calling_inner_service() {
        let count = Arc::new(Mutex::new(0usize));
        let service = tower::service_fn({
            let count = Arc::clone(&count);
            move |_req: Request<Body>| {
                let count = Arc::clone(&count);
                async move {
                    *count.lock().unwrap() += 1;
                    Ok::<_, Infallible>("hello".into_response())
                }
            }
        });
        let mut cache = CacheLayer::new(test_config()).layer(service);

        let req1 = Request::get("/cached").body(Body::empty()).unwrap();
        let resp1 = cache.ready().await.unwrap().call(req1).await.unwrap();
        let body1 = resp1.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body1[..], b"hello");
        assert_eq!(*count.lock().unwrap(), 1);

        let req2 = Request::get("/cached").body(Body::empty()).unwrap();
        let resp2 = cache.ready().await.unwrap().call(req2).await.unwrap();
        let body2 = resp2.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body2[..], b"hello");
        // The second request must have been answered from the cache.
        assert_eq!(*count.lock().unwrap(), 1);
    }

    #[tokio::test]
    #[cfg(feature = "auth")]
    async fn cache_is_scoped_by_tenant_and_actor() {
        let count = Arc::new(Mutex::new(0usize));
        let service = tower::service_fn({
            let count = Arc::clone(&count);
            move |_req: Request<Body>| {
                let count = Arc::clone(&count);
                async move {
                    *count.lock().unwrap() += 1;
                    Ok::<_, Infallible>("response".into_response())
                }
            }
        });
        let mut cache = CacheLayer::new(test_config()).layer(service);

        let mut req_a = Request::get("/scoped").body(Body::empty()).unwrap();
        req_a
            .extensions_mut()
            .insert(AuthContext::authenticated("t1", "a1"));
        let _ = cache.ready().await.unwrap().call(req_a).await.unwrap();
        assert_eq!(*count.lock().unwrap(), 1);

        // Same tenant, different actor: must not reuse the first actor's entry.
        let mut req_b = Request::get("/scoped").body(Body::empty()).unwrap();
        req_b
            .extensions_mut()
            .insert(AuthContext::authenticated("t1", "a2"));
        let resp_b = cache.ready().await.unwrap().call(req_b).await.unwrap();
        let body_b = resp_b.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body_b[..], b"response");
        assert_eq!(*count.lock().unwrap(), 2);

        // Repeating the first actor's request is a hit within that scope.
        let mut req_a_again = Request::get("/scoped").body(Body::empty()).unwrap();
        req_a_again
            .extensions_mut()
            .insert(AuthContext::authenticated("t1", "a1"));
        let _ = cache.ready().await.unwrap().call(req_a_again).await.unwrap();
        assert_eq!(*count.lock().unwrap(), 2);
    }

    #[tokio::test]
    async fn cache_respects_ttl() {
        let count = Arc::new(Mutex::new(0usize));
        let service = tower::service_fn({
            let count = Arc::clone(&count);
            move |_req: Request<Body>| {
                let count = Arc::clone(&count);
                async move {
                    *count.lock().unwrap() += 1;
                    Ok::<_, Infallible>("response".into_response())
                }
            }
        });
        let config = CacheConfig {
            ttl: Some(Duration::from_millis(10)),
            ..test_config()
        };
        let mut cache = CacheLayer::new(config).layer(service);

        let req1 = Request::get("/ttl").body(Body::empty()).unwrap();
        let _ = cache.ready().await.unwrap().call(req1).await.unwrap();
        assert_eq!(*count.lock().unwrap(), 1);

        tokio::time::sleep(Duration::from_millis(20)).await;

        let req2 = Request::get("/ttl").body(Body::empty()).unwrap();
        let resp2 = cache.ready().await.unwrap().call(req2).await.unwrap();
        let body2 = resp2.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body2[..], b"response");
        // The entry had expired, so the inner service must have run again.
        assert_eq!(*count.lock().unwrap(), 2);
    }
}