use axum::{
extract::Request,
http::{
header::{CACHE_CONTROL, ETAG, IF_NONE_MATCH, LAST_MODIFIED},
HeaderValue, StatusCode,
},
response::Response,
};
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::{
collections::HashMap,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
time::{SystemTime, UNIX_EPOCH},
};
use crate::impl_default_new;
use tower::{Layer, Service};
#[cfg(feature = "logging")]
use tracing;
/// Default freshness lifetime of a cached response, in seconds.
const DEFAULT_CACHE_TTL: u64 = 300;
/// Default cap on the summed size of cached bodies: 100 MiB.
const DEFAULT_MAX_CACHE_SIZE: usize = 100 * 1024 * 1024;
/// Default cap on the number of cached entries.
const DEFAULT_MAX_CACHE_ENTRIES: usize = 10_000;
/// Request methods whose responses are cached by default.
const DEFAULT_CACHEABLE_METHODS: &[&str] = &["GET", "HEAD"];
/// Response status codes that are cached by default.
const DEFAULT_CACHEABLE_STATUS_CODES: &[u16] = &[200, 203, 204, 206, 300, 301, 404, 410];
/// One access record in the LRU heap: which cache key was touched and when.
#[derive(Clone, Debug)]
struct MinHeapEntry {
    /// Key of the cached response this record refers to.
    key: CacheKey,
    /// Unix time (seconds) of the recorded access; smaller means older.
    last_accessed: u64,
}
/// Binary min-heap of access records ordered by `last_accessed`, so the least
/// recently used key sits at the root.
#[derive(Clone, Debug)]
struct MinHeap {
    /// Heap-ordered storage: the children of index `i` live at `2i + 1` and `2i + 2`.
    entries: Vec<MinHeapEntry>,
}
impl MinHeap {
    /// Creates an empty heap.
    fn new() -> Self {
        Self {
            entries: Vec::new(),
        }
    }

    /// Inserts `entry` and restores the heap property in O(log n).
    fn push(&mut self, entry: MinHeapEntry) {
        self.entries.push(entry);
        let last = self.entries.len() - 1;
        self.sift_up(last);
    }

    /// Removes and returns the record with the smallest `last_accessed`, or
    /// `None` if the heap is empty.
    fn extract_min(&mut self) -> Option<MinHeapEntry> {
        if self.entries.is_empty() {
            return None;
        }
        // `swap_remove` moves the last element into the root without cloning the
        // removed record; sinking it restores the heap property.
        let min = self.entries.swap_remove(0);
        self.sift_down(0);
        Some(min)
    }

    /// Removes the first record whose key equals `key`; returns whether one was found.
    ///
    /// Locating the record is a linear scan (O(n)); repairing the heap is O(log n).
    fn remove(&mut self, key: &CacheKey) -> bool {
        let pos = match self.entries.iter().position(|entry| &entry.key == key) {
            Some(pos) => pos,
            None => return false,
        };
        self.entries.swap_remove(pos);
        if pos < self.entries.len() {
            // The element moved into `pos` may belong above it (smaller than its
            // parent) or below it; only one direction can apply.
            if pos > 0 && self.entries[pos].last_accessed < self.entries[(pos - 1) / 2].last_accessed
            {
                self.sift_up(pos);
            } else {
                self.sift_down(pos);
            }
        }
        true
    }

    /// Moves the element at `idx` towards the root while it is smaller than its parent.
    fn sift_up(&mut self, mut idx: usize) {
        while idx > 0 {
            let parent = (idx - 1) / 2;
            if self.entries[parent].last_accessed <= self.entries[idx].last_accessed {
                break;
            }
            self.entries.swap(parent, idx);
            idx = parent;
        }
    }

    /// Moves the element at `idx` towards the leaves while a child is smaller.
    fn sift_down(&mut self, mut idx: usize) {
        let len = self.entries.len();
        loop {
            let left = 2 * idx + 1;
            let right = left + 1;
            let mut smallest = idx;
            if left < len && self.entries[left].last_accessed < self.entries[smallest].last_accessed
            {
                smallest = left;
            }
            if right < len
                && self.entries[right].last_accessed < self.entries[smallest].last_accessed
            {
                smallest = right;
            }
            if smallest == idx {
                break;
            }
            self.entries.swap(idx, smallest);
            idx = smallest;
        }
    }
}
// Project macro; presumably derives `Default` from `new()` — see `crate::impl_default_new`.
impl_default_new!(MinHeap);
/// Tunables for `CacheMiddleware`.
///
/// Every field is optional when deserializing: a missing field takes its value
/// from `CacheConfig::default()`, so partial configs such as
/// `{"ttl_seconds": 60}` are accepted.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct CacheConfig {
    /// Seconds a stored response stays fresh after it is cached.
    pub ttl_seconds: u64,
    /// Upper bound, in bytes, on the summed body size of all cached entries.
    pub max_size_bytes: usize,
    /// Upper bound on the number of cached entries.
    pub max_entries: usize,
    /// Methods whose responses may be cached; compared against the upper-cased
    /// request method.
    #[serde(default = "default_cacheable_methods")]
    pub cacheable_methods: Vec<String>,
    /// Response status codes that may be cached.
    #[serde(default = "default_cacheable_status_codes")]
    pub cacheable_status_codes: Vec<u16>,
}
fn default_cacheable_methods() -> Vec<String> {
DEFAULT_CACHEABLE_METHODS
.iter()
.map(|s| s.to_string())
.collect()
}
/// Serde default for `CacheConfig::cacheable_status_codes`: an owned copy of
/// `DEFAULT_CACHEABLE_STATUS_CODES`.
fn default_cacheable_status_codes() -> Vec<u16> {
    Vec::from(DEFAULT_CACHEABLE_STATUS_CODES)
}
impl Default for CacheConfig {
    /// Builds a config from the `DEFAULT_*` constants and the default method
    /// and status-code allow-lists.
    fn default() -> Self {
        let cacheable_methods = default_cacheable_methods();
        let cacheable_status_codes = default_cacheable_status_codes();
        Self {
            ttl_seconds: DEFAULT_CACHE_TTL,
            max_size_bytes: DEFAULT_MAX_CACHE_SIZE,
            max_entries: DEFAULT_MAX_CACHE_ENTRIES,
            cacheable_methods,
            cacheable_status_codes,
        }
    }
}
/// A buffered response held in the cache.
#[derive(Clone, Debug)]
struct CacheEntry {
    /// Response body bytes, reference-counted.
    body: Arc<Vec<u8>>,
    /// Name → value pairs captured from the upstream response and re-applied
    /// when the entry is served.
    headers: HashMap<String, HeaderValue>,
    /// Quoted ETag computed from the body.
    etag: String,
    /// Unix time (seconds) at which the entry was stored.
    last_modified: u64,
    /// Unix time (seconds) marking the end of the entry's freshness lifetime.
    expires_at: u64,
    /// Bytes charged against the cache's size limit (the body length).
    size: usize,
}
/// Identity of a cacheable request. Header and body contributions are kept as
/// SHA-256 hex digests rather than raw bytes.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct CacheKey {
    /// Request method, e.g. `GET`.
    method: String,
    /// Request URI without its query string.
    uri: String,
    /// Query-string portion of the URI, or empty.
    query_string: String,
    /// SHA-256 hex digest of the request headers that participate in the key.
    headers_hash: String,
    /// SHA-256 hex digest of the request body, when one was supplied.
    body_hash: Option<String>,
}
impl CacheKey {
    /// Builds the cache key for a request.
    ///
    /// * `method` — request method, stored verbatim.
    /// * `uri` — request URI, split at the first `?` into path and query string.
    /// * `body` — optional request body, hashed with SHA-256 when present.
    /// * `headers` — request headers. Every header except `Cache-Control`,
    ///   `Pragma`, `If-None-Match` and `If-Modified-Since` is folded into a
    ///   single SHA-256 digest, so no raw header value is retained.
    ///
    /// Headers are canonicalised as sorted `name:value` lines over the raw
    /// bytes, so the same header set yields the same key regardless of order.
    ///
    /// `Authorization` and `Cookie` deliberately participate. A response obtained
    /// with one set of credentials must never be served to a different client.
    ///
    /// The conditional headers are excluded so a revalidation request maps to
    /// the same entry as the request that populated it.
    pub fn new(
        method: &str,
        uri: &str,
        body: Option<&[u8]>,
        headers: &axum::http::HeaderMap,
    ) -> Self {
        // `split_once` keeps everything after the first `?`. Taking only the
        // segment between the first and second `?` would let distinct URIs collide.
        let (path, query) = uri.split_once('?').unwrap_or((uri, ""));

        let mut header_lines: Vec<Vec<u8>> = headers
            .iter()
            .filter(|(name, _)| {
                !matches!(
                    name.as_str(),
                    "cache-control" | "pragma" | "if-none-match" | "if-modified-since"
                )
            })
            .map(|(name, value)| {
                // Raw bytes: `to_str` fails on non-visible-ASCII values, which would
                // collapse different values to the same empty string.
                let mut line = Vec::with_capacity(name.as_str().len() + 1 + value.len());
                line.extend_from_slice(name.as_str().as_bytes());
                line.push(b':');
                line.extend_from_slice(value.as_bytes());
                line
            })
            .collect();
        header_lines.sort_unstable();
        // Header values cannot contain `\n`, so the separator keeps lines unambiguous.
        let headers_hash = Self::secure_hash(&header_lines.join(&b'\n'));

        Self {
            method: method.to_string(),
            uri: path.to_string(),
            query_string: query.to_string(),
            headers_hash,
            body_hash: body.map(Self::secure_hash),
        }
    }

    /// Returns the lowercase hex SHA-256 digest of `data`.
    fn secure_hash(data: &[u8]) -> String {
        let mut hasher = Sha256::new();
        hasher.update(data);
        format!("{:x}", hasher.finalize())
    }
}
/// Tower layer that caches HTTP responses in memory, keyed by `CacheKey`.
///
/// The cache state lives behind `Arc`s, so every clone of the middleware (and
/// every service it produces) shares the same store and counters.
#[derive(Clone)]
pub struct CacheMiddleware {
    /// Limits and allow-lists governing what is cached.
    config: CacheConfig,
    /// Stored responses.
    cache: Arc<DashMap<CacheKey, Arc<CacheEntry>>>,
    /// Running total of the `size` of stored entries, in bytes.
    current_size: Arc<AtomicUsize>,
    /// Running count of stored entries.
    entry_count: Arc<AtomicUsize>,
    /// Last-access heaps consulted by LRU eviction.
    lru_heap: Arc<DashMap<CacheKey, MinHeap>>,
    /// Expiry second (Unix time) → keys whose entries expire at that second.
    expiration_index: Arc<DashMap<u64, Vec<CacheKey>>>,
}
impl CacheMiddleware {
    /// Creates a middleware with an empty cache governed by `config`.
    pub fn new(config: CacheConfig) -> Self {
        Self {
            config,
            cache: Arc::new(DashMap::new()),
            current_size: Arc::new(AtomicUsize::new(0)),
            entry_count: Arc::new(AtomicUsize::new(0)),
            lru_heap: Arc::new(DashMap::new()),
            expiration_index: Arc::new(DashMap::new()),
        }
    }

    /// Computes a strong ETag for `body`: its lowercase hex SHA-256 digest,
    /// wrapped in double quotes.
    #[inline]
    pub fn generate_etag(body: &[u8]) -> String {
        let mut hasher = Sha256::new();
        hasher.update(body);
        format!("\"{:x}\"", hasher.finalize())
    }

    /// Returns the current Unix time in whole seconds (0 if the clock reads
    /// before the epoch). Used as the time an entry was stored.
    #[inline]
    pub fn generate_last_modified() -> u64 {
        Self::now()
    }

    /// Builds the `CacheKey` for a request; thin wrapper over `CacheKey::new`.
    #[inline]
    pub fn generate_cache_key(
        method: &str,
        uri: &str,
        body: Option<&[u8]>,
        headers: &axum::http::HeaderMap,
    ) -> CacheKey {
        CacheKey::new(method, uri, body, headers)
    }

    /// Returns `true` when the upper-cased `method` is listed in
    /// `cacheable_methods` and `status` is listed in `cacheable_status_codes`.
    #[inline]
    pub fn should_cache(&self, method: &str, status: u16) -> bool {
        let method_upper = method.to_uppercase();
        self.config
            .cacheable_methods
            .iter()
            .any(|m| m == &method_upper)
            && self.config.cacheable_status_codes.contains(&status)
    }

    /// Returns `true` once the clock has reached `expires_at`.
    ///
    /// `clear_expired_entries` purges every deadline `<= now`. Using `>=` here
    /// keeps the lookup check and the sweeper in agreement at the boundary second.
    #[inline]
    fn is_expired(&self, expires_at: u64) -> bool {
        Self::now() >= expires_at
    }

    /// Current Unix time in whole seconds, or 0 if the clock is before the epoch.
    #[inline]
    fn now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0)
    }

    /// Evicts least-recently-used entries until `min_needed` more bytes fit under
    /// `max_size_bytes` and the entry count is below `max_entries`.
    ///
    /// Stops early in three cases: no heap yields a candidate, the cache is empty,
    /// or `MAX_ATTEMPTS` entries have been removed. Heap records whose key is no
    /// longer cached are discarded without counting towards that limit.
    fn evict_lru(&self, min_needed: usize) {
        const MAX_ATTEMPTS: usize = 100;
        let max_entries = self.config.max_entries;
        let max_size = self.config.max_size_bytes;
        let mut freed = 0;
        let mut removed_count = 0;
        loop {
            let size_now = self.current_size.load(Ordering::Acquire);
            let count_now = self.entry_count.load(Ordering::Acquire);
            if size_now.saturating_add(min_needed) <= max_size && count_now < max_entries {
                break;
            }
            if removed_count >= MAX_ATTEMPTS {
                #[cfg(feature = "logging")]
                tracing::warn!("LRU eviction hit safety limit of {} entries", MAX_ATTEMPTS);
                break;
            }
            // Pop the oldest record in a statement of its own, so the heap shard
            // lock is released before any `cache` shard is locked. Holding both
            // (heap first) could deadlock against code that updates heaps while
            // holding a `cache` guard.
            let victim = match self
                .lru_heap
                .iter_mut()
                .find_map(|mut shard| shard.value_mut().extract_min())
            {
                Some(victim) => victim,
                None => break,
            };
            if let Some((_, cache_entry)) = self.cache.remove(&victim.key) {
                self.remove_from_expiration_index(&victim.key, cache_entry.expires_at);
                self.current_size
                    .fetch_sub(cache_entry.size, Ordering::Release);
                self.entry_count.fetch_sub(1, Ordering::Release);
                freed += cache_entry.size;
                removed_count += 1;
            }
            if self.entry_count.load(Ordering::Acquire) == 0 {
                break;
            }
        }
        #[cfg(feature = "logging")]
        if removed_count > 0 {
            tracing::debug!(
                freed_bytes = freed,
                removed_entries = removed_count,
                "LRU eviction completed"
            );
        }
    }

    /// Records that `key` expires at `expires_at` (Unix seconds).
    fn add_to_expiration_index(&self, key: &CacheKey, expires_at: u64) {
        self.expiration_index
            .entry(expires_at)
            .or_default()
            .push(key.clone());
    }

    /// Removes every occurrence of `key` from the `expires_at` bucket, if that
    /// bucket exists.
    fn remove_from_expiration_index(&self, key: &CacheKey, expires_at: u64) {
        if let Some(mut bucket) = self.expiration_index.get_mut(&expires_at) {
            bucket.retain(|k| k != key);
        }
    }

    /// Purges every entry whose deadline is `<= now` and returns how many were
    /// removed.
    ///
    /// A key can still sit in an old bucket after being re-stored with a later
    /// deadline. An entry is therefore only dropped if its own `expires_at` has
    /// actually passed.
    fn clear_expired_entries(&self) -> usize {
        let now = Self::now();
        let due: Vec<u64> = self
            .expiration_index
            .iter()
            .map(|bucket| *bucket.key())
            .filter(|expires_at| *expires_at <= now)
            .collect();
        let mut removed = 0;
        for expires_at in due {
            if let Some((_, keys)) = self.expiration_index.remove(&expires_at) {
                for key in keys {
                    if let Some((_, entry)) = self
                        .cache
                        .remove_if(&key, |_, entry| entry.expires_at <= now)
                    {
                        self.entry_count.fetch_sub(1, Ordering::Relaxed);
                        self.current_size.fetch_sub(entry.size, Ordering::Relaxed);
                        removed += 1;
                    }
                }
            }
        }
        removed
    }

    /// Returns `true` if `needed` more bytes, or one more entry, would exceed
    /// the configured limits.
    fn over_limits(&self, needed: usize) -> bool {
        let size = self.current_size.load(Ordering::Relaxed);
        let count = self.entry_count.load(Ordering::Relaxed);
        size.saturating_add(needed) > self.config.max_size_bytes
            || count >= self.config.max_entries
    }

    /// Purges expired entries. If the limits are still exceeded afterwards, it
    /// then runs LRU eviction.
    fn cleanup_and_evict(&self, needed: usize) {
        let _ = self.clear_expired_entries();
        if self.over_limits(needed) {
            self.evict_lru(needed);
        }
    }

    /// Makes room for an entry of `needed` bytes when the limits would otherwise
    /// be exceeded. Capacity is freed by purging expired entries first, then by
    /// LRU eviction.
    #[inline]
    fn enforce_size_limit(&self, needed: usize) {
        if self.over_limits(needed) {
            self.cleanup_and_evict(needed);
        }
    }

    /// Marks `key` as accessed now in the LRU heaps. Does nothing if `key` is
    /// not tracked in any heap.
    fn update_access_time(&self, key: &CacheKey) {
        let now = Self::now();
        for mut shard in self.lru_heap.iter_mut() {
            let heap = shard.value_mut();
            if heap.remove(key) {
                heap.push(MinHeapEntry {
                    key: key.clone(),
                    last_accessed: now,
                });
                return;
            }
        }
    }
}
impl<S> Layer<S> for CacheMiddleware {
    type Service = CacheService<S>;

    /// Wraps `inner` in a `CacheService` that shares this middleware's cache.
    fn layer(&self, inner: S) -> Self::Service {
        let middleware = self.clone();
        CacheService { middleware, inner }
    }
}
/// Service produced by `CacheMiddleware`'s `Layer` impl. It consults the shared
/// cache before delegating to `inner`.
#[derive(Clone)]
pub struct CacheService<S> {
    /// The wrapped service that produces responses on a cache miss.
    inner: S,
    /// Handle to the shared cache state and configuration.
    middleware: CacheMiddleware,
}
impl<S> Service<Request> for CacheService<S>
where
S: Service<Request, Response = Response> + Send + 'static + Clone,
S::Future: Send + 'static,
{
type Response = Response;
type Error = S::Error;
type Future = futures_util::future::BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request) -> Self::Future {
let middleware = self.middleware.clone();
let method = req.method().to_string();
let uri = req.uri().to_string();
if !middleware.should_cache(&method, 200) {
let mut inner = self.inner.clone();
return Box::pin(async move { inner.call(req).await });
}
let cache_key = CacheMiddleware::generate_cache_key(&method, &uri, None, req.headers());
if let Some(if_none_match) = req.headers().get(IF_NONE_MATCH) {
if let Some(entry) = middleware.cache.get(&cache_key) {
if if_none_match
.to_str()
.ok()
.map(|s| s == entry.etag)
.unwrap_or(false)
{
let mut response = Response::new(axum::body::Body::empty());
*response.status_mut() = StatusCode::NOT_MODIFIED;
return Box::pin(async move { Ok(response) });
}
}
}
if let Some(entry) = middleware.cache.get(&cache_key) {
if !middleware.is_expired(entry.expires_at) {
middleware.update_access_time(&cache_key);
let body_bytes = bytes::Bytes::copy_from_slice(&entry.body);
let mut response = Response::new(axum::body::Body::from(body_bytes));
if let Ok(etag_value) = HeaderValue::from_str(&entry.etag) {
response.headers_mut().insert(ETAG, etag_value);
}
if let Ok(lm_value) = HeaderValue::from_str(&entry.last_modified.to_string()) {
response.headers_mut().insert(LAST_MODIFIED, lm_value);
}
if let Ok(cc_value) =
HeaderValue::from_str(&format!("max-age={}", middleware.config.ttl_seconds))
{
response.headers_mut().insert(CACHE_CONTROL, cc_value);
}
for (name, value) in &entry.headers {
if let Ok(name) = axum::http::HeaderName::from_bytes(name.as_bytes()) {
response.headers_mut().insert(name, value.clone());
}
}
return Box::pin(async move { Ok(response) });
}
}
let mut inner = self.inner.clone();
Box::pin(async move {
let response = inner.call(req).await?;
let status = response.status().as_u16();
if middleware.should_cache(&method, status) {
let (parts, body) = response.into_parts();
let body_bytes = match axum::body::to_bytes(body, 10 * 1024 * 1024).await {
Ok(bytes) => bytes.to_vec(),
Err(_e) => {
#[cfg(feature = "logging")]
tracing::error!(error = %_e, "Failed to convert response body to bytes for caching");
let response = Response::from_parts(parts, axum::body::Body::empty());
return Ok(response);
}
};
let body_len = body_bytes.len();
let etag = CacheMiddleware::generate_etag(&body_bytes);
let last_modified = CacheMiddleware::generate_last_modified();
let expires_at = last_modified + middleware.config.ttl_seconds;
let mut headers = HashMap::new();
for (name, value) in parts.headers.iter() {
if name != CACHE_CONTROL && name != ETAG && name != LAST_MODIFIED {
headers.insert(name.as_str().to_string(), value.clone());
}
}
let entry = CacheEntry {
body: Arc::new(body_bytes),
headers,
etag: etag.clone(),
last_modified,
expires_at,
size: body_len,
};
let entry_size = body_len;
middleware.enforce_size_limit(entry_size);
let current_size = middleware.current_size.load(Ordering::Relaxed);
if current_size + entry_size <= middleware.config.max_size_bytes {
let entry_arc = Arc::new(entry);
let cache_key_for_heap = cache_key.clone();
let cache_key_for_expiry = cache_key.clone();
middleware
.cache
.insert(cache_key.clone(), Arc::clone(&entry_arc));
if let Some(mut shard) = middleware.lru_heap.iter_mut().next() {
let heap = shard.value_mut();
heap.push(MinHeapEntry {
key: cache_key_for_heap,
last_accessed: CacheMiddleware::now(),
});
}
middleware.add_to_expiration_index(&cache_key_for_expiry, expires_at);
middleware
.current_size
.fetch_add(entry_size, Ordering::Relaxed);
middleware.entry_count.fetch_add(1, Ordering::Relaxed);
let body_bytes = bytes::Bytes::copy_from_slice(&entry_arc.body);
let mut response =
Response::from_parts(parts, axum::body::Body::from(body_bytes));
if let Ok(etag_value) = HeaderValue::from_str(&entry_arc.etag) {
response.headers_mut().insert(ETAG, etag_value);
}
if let Ok(lm_value) =
HeaderValue::from_str(&entry_arc.last_modified.to_string())
{
response.headers_mut().insert(LAST_MODIFIED, lm_value);
}
if let Ok(cc_value) =
HeaderValue::from_str(&format!("max-age={}", middleware.config.ttl_seconds))
{
response.headers_mut().insert(CACHE_CONTROL, cc_value);
}
return Ok(response);
}
let body_bytes = bytes::Bytes::copy_from_slice(&entry.body);
let mut response = Response::from_parts(parts, axum::body::Body::from(body_bytes));
if let Ok(etag_value) = HeaderValue::from_str(&etag) {
response.headers_mut().insert(ETAG, etag_value);
}
if let Ok(lm_value) = HeaderValue::from_str(&last_modified.to_string()) {
response.headers_mut().insert(LAST_MODIFIED, lm_value);
}
if let Ok(cc_value) =
HeaderValue::from_str(&format!("max-age={}", middleware.config.ttl_seconds))
{
response.headers_mut().insert(CACHE_CONTROL, cc_value);
}
Ok(response)
} else {
let (parts, body) = response.into_parts();
Ok(Response::from_parts(parts, body))
}
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a cache key for a request that carries no headers.
    fn make_key(method: &str, uri: &str, body: Option<&[u8]>) -> CacheKey {
        CacheMiddleware::generate_cache_key(method, uri, body, &axum::http::HeaderMap::new())
    }

    #[test]
    fn test_default_cache_config() {
        let CacheConfig {
            ttl_seconds,
            max_size_bytes,
            cacheable_methods,
            ..
        } = CacheConfig::default();
        assert_eq!(ttl_seconds, 300);
        assert_eq!(max_size_bytes, 100 * 1024 * 1024);
        assert!(cacheable_methods.iter().any(|m| m == "GET"));
    }

    #[test]
    fn test_etag_generation() {
        let greeting: &[u8] = b"Hello, World!";
        let first = CacheMiddleware::generate_etag(greeting);
        assert_eq!(first, CacheMiddleware::generate_etag(greeting));
        assert_ne!(first, CacheMiddleware::generate_etag(b"Different content"));
    }

    #[test]
    fn test_cache_key_generation() {
        let get_users = make_key("GET", "/api/users", None);
        assert_eq!(get_users, make_key("GET", "/api/users", None));
        let post_users = make_key("POST", "/api/users", Some(b"{\"name\":\"test\"}"));
        assert_ne!(get_users, post_users);
    }

    #[test]
    fn test_should_cache() {
        let middleware = CacheMiddleware::new(CacheConfig::default());
        let cases = [
            ("GET", 200, true),
            ("GET", 404, true),
            ("POST", 200, false),
            ("GET", 500, false),
        ];
        for (method, status, expected) in cases {
            assert_eq!(
                middleware.should_cache(method, status),
                expected,
                "{} {}",
                method,
                status
            );
        }
    }

    #[test]
    fn test_cache_key_with_body() {
        let plain = make_key("GET", "/api/users", None);
        assert_eq!(plain, make_key("GET", "/api/users", None));
        assert_ne!(plain, make_key("POST", "/api/users", Some(b"body")));
        assert_ne!(plain, make_key("GET", "/api/users", Some(b"different body")));
    }

    #[test]
    fn test_min_heap_operations() {
        let mut heap = MinHeap::new();
        assert!(heap.entries.is_empty());
        for (path, stamp) in [("/a", 1), ("/b", 3), ("/c", 2)] {
            heap.push(MinHeapEntry {
                key: make_key("GET", path, Some(b"")),
                last_accessed: stamp,
            });
        }
        assert_eq!(heap.entries.len(), 3);
        assert_eq!(heap.extract_min().unwrap().last_accessed, 1);
    }
}