use std::{
future::{ready, Ready},
sync::Arc,
time::{Duration, Instant},
};
use actix_service::{Service, Transform};
use actix_web::{
body::MessageBody,
dev::{forward_ready, ServiceRequest, ServiceResponse},
Error, HttpResponse,
};
use indexmap::IndexMap;
use parking_lot::Mutex;
/// Maximum number of entries held in one cache store; inserting into a full
/// store evicts an existing entry first.
const MAX_CACHE_ENTRIES: usize = 1000;
/// Maximum response body size, in bytes (1 MiB), that is buffered and stored.
const MAX_CACHEABLE_BODY_BYTES: usize = 1024 * 1024;
/// A stored response plus the metadata needed to decide whether it may be
/// replayed for a later request.
struct CacheEntry {
    /// Status code of the stored response.
    status: actix_web::http::StatusCode,
    /// Response headers captured at store time and replayed on a hit.
    headers: actix_web::http::header::HeaderMap,
    /// Fully buffered response body.
    body: actix_web::web::Bytes,
    /// Instant at or after which the entry is treated as expired.
    expires_at: Instant,
    /// `(lowercased header name, request header value)` pairs for the headers
    /// named by the response's `Vary`; a later request must carry the same
    /// values for this entry to be served.
    vary_headers: Vec<(String, String)>,
}
/// Builds the store key for a response: `"{host}::{uri}"`, followed by one
/// `"\x00{name}={value}"` segment per Vary pair, in the given order.
///
/// NUL is used as the segment separator because it is not a valid byte in
/// HTTP header values or request targets, so the suffix boundary is
/// unambiguous. With no Vary pairs the key is just `"{host}::{uri}"`.
fn build_cache_key(host: &str, uri: &str, vary_headers: &[(String, String)]) -> String {
    vary_headers
        .iter()
        .fold(format!("{}::{}", host, uri), |mut acc, (header, val)| {
            acc.push('\x00');
            acc.push_str(header);
            acc.push('=');
            acc.push_str(val);
            acc
        })
}
/// Resolves the header names listed in a response `Vary` value against the
/// request headers.
///
/// `vary` is a comma-separated list of header names. Each name is trimmed and
/// ASCII-lowercased. Empty tokens (from an absent `Vary` header or stray
/// commas) and the `*` wildcard are skipped. Returns `(name, value)` pairs,
/// where `value` is the request's first value for that header. It is `""` when
/// the header is absent or cannot be read as a string.
fn extract_vary_values(
    req_headers: &actix_web::http::header::HeaderMap,
    vary: &str,
) -> Vec<(String, String)> {
    vary.split(',')
        .map(|token| token.trim().to_ascii_lowercase())
        // Without the emptiness check, a missing/empty Vary produced a bogus
        // ("", "") pair and an invalid header-name lookup on every match.
        .filter(|name| !name.is_empty() && name != "*")
        .map(|name| {
            let value = req_headers
                .get(name.as_str())
                .and_then(|v| v.to_str().ok())
                .unwrap_or("")
                .to_owned();
            (name, value)
        })
        .collect()
}
/// Checks that every recorded `(name, value)` Vary pair equals the current
/// request's first value for that header. A header that is absent, or cannot
/// be read as a string, is compared as `""`. An empty `cached_vary` always
/// matches.
fn vary_matches(
    req_headers: &actix_web::http::header::HeaderMap,
    cached_vary: &[(String, String)],
) -> bool {
    for (header_name, recorded) in cached_vary {
        let observed = match req_headers.get(header_name.as_str()) {
            Some(hv) => hv.to_str().unwrap_or(""),
            None => "",
        };
        if observed != recorded.as_str() {
            return false;
        }
    }
    true
}
/// Cheaply clonable handle to the entry map. All clones refer to the same
/// `Arc`-wrapped, mutex-guarded `IndexMap`.
#[derive(Clone)]
struct CacheStore {
    inner: Arc<Mutex<IndexMap<String, CacheEntry>>>,
}
impl CacheStore {
    /// Returns a store with no entries.
    fn new() -> Self {
        let entries: IndexMap<String, CacheEntry> = IndexMap::new();
        let guarded = Mutex::new(entries);
        CacheStore {
            inner: Arc::new(guarded),
        }
    }
}
/// Middleware factory for an in-memory cache of `GET` responses.
///
/// `ttl` controls how long each stored response is considered fresh.
#[derive(Clone)]
pub struct Cache {
    ttl: Duration,
}
impl Cache {
    /// Creates a cache factory whose stored responses stay fresh for `ttl`.
    pub fn new(ttl: Duration) -> Self {
        Cache { ttl }
    }
}
/// Builds a [`CacheMiddleware`] around the next service in the chain.
///
/// Each `new_transform` call creates its own empty `CacheStore`, so the
/// middleware instances it produces do not share cached entries.
impl<S, B> Transform<S, ServiceRequest> for Cache
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
    S::Future: 'static,
    B: MessageBody + 'static,
{
    type Response = ServiceResponse;
    type Error = Error;
    type Transform = CacheMiddleware<S>;
    type InitError = ();
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: S) -> Self::Future {
        let middleware = CacheMiddleware {
            cache: CacheStore::new(),
            ttl: self.ttl,
            service,
        };
        ready(Ok(middleware))
    }
}
/// Service produced by [`Cache`]. It wraps the inner service and consults its
/// in-memory store before forwarding `GET` requests.
pub struct CacheMiddleware<S> {
    /// The wrapped (inner) service.
    service: S,
    /// Lifetime given to each newly stored entry.
    ttl: Duration,
    /// Handle to the entry map; cloned into each request's response future.
    cache: CacheStore,
}
impl<S, B> Service<ServiceRequest> for CacheMiddleware<S>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
S::Future: 'static,
B: MessageBody + 'static,
{
type Response = ServiceResponse;
type Error = Error;
type Future =
std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>>>>;
forward_ready!(service);
fn call(&self, req: ServiceRequest) -> Self::Future {
let is_upgrade = req.headers().contains_key(actix_web::http::header::UPGRADE);
if req.method() != actix_web::http::Method::GET || is_upgrade {
let fut = self.service.call(req);
return Box::pin(async move {
let res = fut.await?;
Ok(res.map_into_boxed_body())
});
}
let uri = req.uri().to_string();
let host = req.connection_info().host().to_string();
let cache_prefix = format!("{}::{}", host, uri);
let now = Instant::now();
let req_headers = req.headers().clone();
{
let mut cache = self.cache.inner.lock();
let mut hit_key: Option<String> = None;
let mut expired_keys: Vec<String> = Vec::new();
for (key, entry) in cache.iter() {
if !key.starts_with(&cache_prefix) {
continue;
}
if entry.expires_at <= now {
expired_keys.push(key.clone());
continue;
}
if vary_matches(&req_headers, &entry.vary_headers) {
hit_key = Some(key.clone());
break;
}
}
for k in expired_keys {
cache.swap_remove(&k);
}
if let Some(key) = hit_key {
if let Some(entry) = cache.get(&key) {
let (req, _) = req.into_parts();
let mut builder = HttpResponse::build(entry.status);
for (name, value) in &entry.headers {
builder.insert_header((name.clone(), value.clone()));
}
let response = builder.body(entry.body.clone());
return Box::pin(ready(Ok(ServiceResponse::new(req, response))));
}
}
}
let ttl = self.ttl;
let cache = self.cache.clone();
let fut = self.service.call(req);
Box::pin(async move {
let res = fut.await?;
if !res.status().is_success() {
return Ok(res.map_into_boxed_body());
}
let (req, response) = res.into_parts();
let status = response.status();
let headers = response.headers().clone();
let is_sse = headers
.get(actix_web::http::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.map(|s| s.contains("text/event-stream") || s.contains("websocket"))
.unwrap_or(false);
let is_response_upgrade = headers.contains_key(actix_web::http::header::UPGRADE);
let vary_star = headers
.get("vary")
.and_then(|v| v.to_str().ok())
.map(|v| v.trim() == "*")
.unwrap_or(false);
let declared_len = headers
.get(actix_web::http::header::CONTENT_LENGTH)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<usize>().ok());
let should_cache = !is_sse
&& !is_response_upgrade
&& !vary_star
&& declared_len.map(|l| l <= MAX_CACHEABLE_BODY_BYTES).unwrap_or(false);
if !should_cache {
return Ok(ServiceResponse::new(req, response).map_into_boxed_body());
}
let vary_str = headers
.get("vary")
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_owned();
let vary_values = extract_vary_values(&req_headers, &vary_str);
let cache_key = build_cache_key(&host, &uri, &vary_values);
let mut body = std::pin::pin!(response.into_body());
let mut bytes = actix_web::web::BytesMut::new();
while let Some(chunk_res) = std::future::poll_fn(|cx| {
body.as_mut().poll_next(cx)
})
.await
{
let chunk = chunk_res.map_err(|e| {
let boxed: Box<dyn std::error::Error> = e.into();
actix_web::error::ErrorInternalServerError(boxed.to_string())
})?;
bytes.extend_from_slice(&chunk);
if bytes.len() > MAX_CACHEABLE_BODY_BYTES {
return Err(actix_web::error::ErrorPayloadTooLarge(
"Response body exceeded cache size limit despite Content-Length claim",
));
}
}
let bytes = bytes.freeze();
let mut map = cache.inner.lock();
map.retain(|_, e| e.expires_at > now);
if map.len() >= MAX_CACHE_ENTRIES {
map.swap_remove_index(0);
}
map.insert(
cache_key,
CacheEntry {
status,
headers: headers.clone(),
body: bytes.clone(),
expires_at: now + ttl,
vary_headers: vary_values,
},
);
Ok(ServiceResponse::new(
req,
HttpResponse::build(status).body(bytes),
))
})
}
}