use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use bytes::Bytes;
use dashmap::DashMap;
use http::{Response, header};
use http_body_util::{BodyExt, Full};
use hyper::Request;
use hyper::body::Incoming;
use crate::context::RequestContext;
use crate::middleware::{BoxFuture, Middleware, Next};
use crate::response::BoxBody;
pub(crate) const CACHE_TTL_HEADER: &str = "x-rapina-cache-ttl";
pub const CACHE_STATUS_HEADER: &str = "x-cache";
const CLEANUP_INTERVAL: u64 = 1000;
type CacheFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
#[derive(Clone, Debug)]
pub struct CachedResponse {
pub status: u16,
pub headers: Vec<(String, String)>,
pub body: Bytes,
}
pub trait CacheBackend: Send + Sync + 'static {
fn get(&self, key: &str) -> CacheFuture<'_, Option<CachedResponse>>;
fn set(&self, key: &str, response: CachedResponse, ttl: Duration) -> CacheFuture<'_, ()>;
fn invalidate_prefix(&self, prefix: &str) -> CacheFuture<'_, ()>;
}
struct CacheEntry {
response: CachedResponse,
expires_at: Instant,
created_at: Instant,
}
pub struct InMemoryCache {
entries: Arc<DashMap<String, CacheEntry>>,
max_entries: usize,
op_count: Arc<AtomicU64>,
}
impl InMemoryCache {
pub fn new(max_entries: usize) -> Self {
Self {
entries: Arc::new(DashMap::new()),
max_entries,
op_count: Arc::new(AtomicU64::new(0)),
}
}
fn maybe_cleanup(&self) {
let count = self.op_count.fetch_add(1, Ordering::Relaxed);
if count > 0 && count % CLEANUP_INTERVAL == 0 {
self.cleanup_expired();
}
}
fn cleanup_expired(&self) {
let now = Instant::now();
self.entries.retain(|_, entry| entry.expires_at > now);
}
fn evict_if_full(&self) {
if self.entries.len() < self.max_entries {
return;
}
self.cleanup_expired();
if self.entries.len() < self.max_entries {
return;
}
let oldest_key = self
.entries
.iter()
.min_by_key(|entry| entry.value().created_at)
.map(|entry| entry.key().clone());
if let Some(key) = oldest_key {
self.entries.remove(&key);
}
}
}
impl CacheBackend for InMemoryCache {
fn get(&self, key: &str) -> CacheFuture<'_, Option<CachedResponse>> {
self.maybe_cleanup();
let result = self.entries.get(key).and_then(|entry| {
if entry.expires_at > Instant::now() {
Some(entry.response.clone())
} else {
None
}
});
if result.is_none() {
self.entries
.remove_if(key, |_, entry| entry.expires_at <= Instant::now());
}
Box::pin(std::future::ready(result))
}
fn set(&self, key: &str, response: CachedResponse, ttl: Duration) -> CacheFuture<'_, ()> {
self.maybe_cleanup();
self.evict_if_full();
let now = Instant::now();
self.entries.insert(
key.to_string(),
CacheEntry {
response,
expires_at: now + ttl,
created_at: now,
},
);
Box::pin(std::future::ready(()))
}
fn invalidate_prefix(&self, prefix: &str) -> CacheFuture<'_, ()> {
self.entries.retain(|key, _| !key.starts_with(prefix));
Box::pin(std::future::ready(()))
}
}
pub enum CacheConfig {
InMemory { max_entries: usize },
#[cfg(feature = "cache-redis")]
Redis { url: String },
}
impl CacheConfig {
pub fn in_memory(max_entries: usize) -> Self {
CacheConfig::InMemory { max_entries }
}
#[cfg(feature = "cache-redis")]
pub fn redis(url: &str) -> Self {
CacheConfig::Redis {
url: url.to_string(),
}
}
pub async fn build(self) -> Result<Arc<dyn CacheBackend>, std::io::Error> {
match self {
CacheConfig::InMemory { max_entries } => Ok(Arc::new(InMemoryCache::new(max_entries))),
#[cfg(feature = "cache-redis")]
CacheConfig::Redis { url } => {
let backend = crate::cache_redis::RedisCache::connect(&url)
.await
.map_err(|e| {
std::io::Error::other(format!("Redis connection failed: {}", e))
})?;
Ok(Arc::new(backend))
}
}
}
}
pub struct CacheMiddleware {
backend: Arc<dyn CacheBackend>,
}
impl CacheMiddleware {
pub fn new(backend: Arc<dyn CacheBackend>) -> Self {
Self { backend }
}
}
impl Middleware for CacheMiddleware {
fn handle<'a>(
&'a self,
req: Request<Incoming>,
_ctx: &'a RequestContext,
next: Next<'a>,
) -> BoxFuture<'a, Response<BoxBody>> {
Box::pin(async move {
let method = req.method().clone();
let path = req.uri().path().to_string();
let query = req.uri().query().unwrap_or("").to_string();
if method == http::Method::GET {
let cache_key = build_cache_key(&path, &query);
if let Some(cached) = self.backend.get(&cache_key).await {
return build_response_from_cache(cached, "HIT");
}
let response = next.run(req).await;
if let Some(ttl) = extract_ttl_header(&response) {
let (parts, body) = response.into_parts();
let body_bytes = match body.collect().await {
Ok(collected) => collected.to_bytes(),
Err(_) => {
return Response::builder()
.status(http::StatusCode::INTERNAL_SERVER_ERROR)
.body(Full::new(Bytes::new()))
.unwrap();
}
};
let cached = CachedResponse {
status: parts.status.as_u16(),
headers: parts
.headers
.iter()
.filter(|(name, _)| name.as_str() != CACHE_TTL_HEADER)
.map(|(name, value)| {
(name.to_string(), value.to_str().unwrap_or("").to_string())
})
.collect(),
body: body_bytes.clone(),
};
self.backend
.set(&cache_key, cached, Duration::from_secs(ttl))
.await;
let mut response = Response::from_parts(parts, Full::new(body_bytes));
response.headers_mut().remove(CACHE_TTL_HEADER);
response
.headers_mut()
.insert(CACHE_STATUS_HEADER, http::HeaderValue::from_static("MISS"));
return response;
}
return response;
}
let response = next.run(req).await;
if is_mutation(&method) && response.status().is_success() {
let prefix = build_invalidation_prefix(&path);
self.backend.invalidate_prefix(&prefix).await;
}
response
})
}
}
fn build_cache_key(path: &str, query: &str) -> String {
if query.is_empty() {
format!("GET:{}", path)
} else {
let mut params: Vec<&str> = query.split('&').collect();
params.sort();
format!("GET:{}?{}", path, params.join("&"))
}
}
fn build_invalidation_prefix(path: &str) -> String {
let base = path
.rfind('/')
.filter(|&i| i > 0)
.map(|i| &path[..i])
.unwrap_or(path);
format!("GET:{}", base)
}
fn is_mutation(method: &http::Method) -> bool {
matches!(
*method,
http::Method::POST | http::Method::PUT | http::Method::DELETE | http::Method::PATCH
)
}
fn extract_ttl_header(response: &Response<BoxBody>) -> Option<u64> {
response
.headers()
.get(CACHE_TTL_HEADER)
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse().ok())
}
fn build_response_from_cache(cached: CachedResponse, status: &'static str) -> Response<BoxBody> {
let mut builder = Response::builder().status(cached.status);
for (name, value) in &cached.headers {
if let (Ok(name), Ok(value)) = (
header::HeaderName::from_bytes(name.as_bytes()),
header::HeaderValue::from_str(value),
) {
builder = builder.header(name, value);
}
}
let mut response = builder.body(Full::new(cached.body)).unwrap();
response
.headers_mut()
.insert(CACHE_STATUS_HEADER, http::HeaderValue::from_static(status));
response
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_in_memory_cache_set_and_get() {
let cache = InMemoryCache::new(100);
let response = CachedResponse {
status: 200,
headers: vec![("content-type".to_string(), "application/json".to_string())],
body: Bytes::from(r#"{"ok":true}"#),
};
cache
.set("key1", response.clone(), Duration::from_secs(60))
.await;
let result = cache.get("key1").await;
assert!(result.is_some());
let cached = result.unwrap();
assert_eq!(cached.status, 200);
assert_eq!(cached.body, Bytes::from(r#"{"ok":true}"#));
}
#[tokio::test]
async fn test_in_memory_cache_miss() {
let cache = InMemoryCache::new(100);
let result = cache.get("nonexistent").await;
assert!(result.is_none());
}
#[tokio::test]
async fn test_in_memory_cache_ttl_expiry() {
let cache = InMemoryCache::new(100);
let response = CachedResponse {
status: 200,
headers: vec![],
body: Bytes::from("data"),
};
cache.set("key1", response, Duration::from_millis(1)).await;
tokio::time::sleep(Duration::from_millis(10)).await;
let result = cache.get("key1").await;
assert!(result.is_none());
}
#[tokio::test]
async fn test_in_memory_cache_invalidate_prefix() {
let cache = InMemoryCache::new(100);
let response = CachedResponse {
status: 200,
headers: vec![],
body: Bytes::from("data"),
};
cache
.set("GET:/users", response.clone(), Duration::from_secs(60))
.await;
cache
.set("GET:/users/1", response.clone(), Duration::from_secs(60))
.await;
cache
.set("GET:/posts", response.clone(), Duration::from_secs(60))
.await;
cache.invalidate_prefix("GET:/users").await;
assert!(cache.get("GET:/users").await.is_none());
assert!(cache.get("GET:/users/1").await.is_none());
assert!(cache.get("GET:/posts").await.is_some());
}
#[tokio::test]
async fn test_in_memory_cache_max_entries_eviction() {
let cache = InMemoryCache::new(2);
let response = CachedResponse {
status: 200,
headers: vec![],
body: Bytes::from("data"),
};
cache
.set("key1", response.clone(), Duration::from_secs(60))
.await;
cache
.set("key2", response.clone(), Duration::from_secs(60))
.await;
cache
.set("key3", response.clone(), Duration::from_secs(60))
.await;
assert!(cache.get("key1").await.is_none());
assert!(cache.get("key2").await.is_some());
assert!(cache.get("key3").await.is_some());
}
#[tokio::test]
async fn test_in_memory_cache_cleanup_expired() {
let cache = InMemoryCache::new(100);
let response = CachedResponse {
status: 200,
headers: vec![],
body: Bytes::from("data"),
};
cache
.set("expired", response.clone(), Duration::from_millis(1))
.await;
cache
.set("fresh", response.clone(), Duration::from_secs(60))
.await;
tokio::time::sleep(Duration::from_millis(10)).await;
cache.cleanup_expired();
assert_eq!(cache.entries.len(), 1);
assert!(cache.entries.get("fresh").is_some());
}
#[test]
fn test_build_cache_key_no_query() {
assert_eq!(build_cache_key("/users", ""), "GET:/users");
}
#[test]
fn test_build_cache_key_with_query() {
let key = build_cache_key("/users", "page=1&sort=name");
assert_eq!(key, "GET:/users?page=1&sort=name");
}
#[test]
fn test_build_cache_key_sorts_query_params() {
let key1 = build_cache_key("/users", "sort=name&page=1");
let key2 = build_cache_key("/users", "page=1&sort=name");
assert_eq!(key1, key2);
}
#[test]
fn test_build_invalidation_prefix() {
assert_eq!(build_invalidation_prefix("/users/123"), "GET:/users");
assert_eq!(build_invalidation_prefix("/users"), "GET:/users");
assert_eq!(build_invalidation_prefix("/"), "GET:/");
}
#[test]
fn test_is_mutation() {
assert!(is_mutation(&http::Method::POST));
assert!(is_mutation(&http::Method::PUT));
assert!(is_mutation(&http::Method::DELETE));
assert!(is_mutation(&http::Method::PATCH));
assert!(!is_mutation(&http::Method::GET));
assert!(!is_mutation(&http::Method::HEAD));
}
#[test]
fn test_cache_config_in_memory() {
let config = CacheConfig::in_memory(500);
assert!(matches!(config, CacheConfig::InMemory { max_entries: 500 }));
}
#[test]
fn test_build_response_from_cache() {
let cached = CachedResponse {
status: 200,
headers: vec![("content-type".to_string(), "text/plain".to_string())],
body: Bytes::from("hello"),
};
let response = build_response_from_cache(cached, "HIT");
assert_eq!(response.status(), 200);
assert_eq!(response.headers().get(CACHE_STATUS_HEADER).unwrap(), "HIT");
assert_eq!(
response.headers().get("content-type").unwrap(),
"text/plain"
);
}
}