use moka::policy::Expiry;
use moka::sync::Cache;
use rama::http::header::{self, HeaderMap, HeaderName, HeaderValue};
use rama::http::{Response, StatusCode};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tracing::debug;
/// Identity of a cached CORS preflight response.
///
/// A cached response is only reused when every field matches, so responses are isolated per
/// origin, path, bind name, host, requested method and requested header set.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct CorsCacheKey {
    // Value of the `Origin` request header, kept verbatim.
    origin: String,
    // Request path as supplied by the caller.
    path: String,
    // Bind/listener name as supplied by the caller (NOTE(review): presumably the listener the
    // request arrived on — confirm against callers).
    bind: String,
    // `Host` request header, ASCII-lowercased; empty when absent or unreadable.
    host: String,
    // `Access-Control-Request-Method` value; empty when absent or unreadable.
    access_control_request_method: String,
    // `Access-Control-Request-Headers` after `normalize_acrh`; empty when absent or unreadable.
    access_control_request_headers: String,
}
impl CorsCacheKey {
pub fn from_request(path: &str, bind_name: &str, headers: &HeaderMap) -> Option<Self> {
let origin = headers
.get(header::ORIGIN)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())?;
if origin == "null" {
return None;
}
let acrm = headers
.get(header::ACCESS_CONTROL_REQUEST_METHOD)
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_string();
let acrh = headers
.get(header::ACCESS_CONTROL_REQUEST_HEADERS)
.and_then(|v| v.to_str().ok())
.map(normalize_acrh)
.unwrap_or_default();
let host = headers
.get(header::HOST)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_ascii_lowercase())
.unwrap_or_default();
Some(Self {
origin,
path: path.to_string(),
bind: bind_name.to_string(),
host,
access_control_request_method: acrm,
access_control_request_headers: acrh,
})
}
}
#[derive(Debug, Clone)]
struct CorsCacheEntry {
status: u16,
headers: Vec<(String, String)>,
ttl: Duration,
}
impl CorsCacheEntry {
fn from_response(response: &Response, ttl: Duration) -> Self {
let headers = response
.headers()
.iter()
.filter(|(name, _)| is_cors_header(name))
.map(|(name, value)| {
(
name.as_str().to_string(),
value.to_str().unwrap_or("").to_string(),
)
})
.collect();
Self {
status: response.status().as_u16(),
headers,
ttl,
}
}
fn to_response(&self) -> Response {
let mut builder =
Response::builder().status(StatusCode::from_u16(self.status).unwrap_or(StatusCode::OK));
for (name, value) in &self.headers {
if let (Ok(name), Ok(value)) = (
HeaderName::try_from(name.as_str()),
HeaderValue::try_from(value.as_str()),
) {
builder = builder.header(name, value);
}
}
builder.body(rama::http::Body::empty()).unwrap()
}
}
/// Per-entry expiry policy: each cache entry lives for the TTL stored inside it.
struct CorsExpiry;
impl Expiry<CorsCacheKey, CorsCacheEntry> for CorsExpiry {
    /// A freshly inserted entry expires after its own TTL.
    fn expire_after_create(
        &self,
        _key: &CorsCacheKey,
        value: &CorsCacheEntry,
        _current_time: Instant,
    ) -> Option<Duration> {
        Some(value.ttl)
    }

    /// Replacing an existing key restarts the clock with the *new* entry's TTL.
    ///
    /// moka's default implementation keeps the old remaining lifetime. A refreshed preflight,
    /// possibly with a different `Access-Control-Max-Age`, would then inherit the stale
    /// entry's expiry instead of its own.
    fn expire_after_update(
        &self,
        _key: &CorsCacheKey,
        value: &CorsCacheEntry,
        _current_time: Instant,
        _current_duration: Option<Duration>,
    ) -> Option<Duration> {
        Some(value.ttl)
    }
}
/// Bounded cache of CORS preflight responses, plus hit/miss counters.
pub struct CorsCache {
    /// Number of lookups that found a live entry.
    hits: AtomicU64,
    /// Number of lookups that found nothing.
    misses: AtomicU64,
    /// Lifetime used when a response has no parseable `Access-Control-Max-Age`.
    default_ttl: Duration,
    /// Backing store; per-entry expiry comes from `CorsExpiry`.
    cache: Cache<CorsCacheKey, CorsCacheEntry>,
}
impl CorsCache {
    /// Creates a cache holding at most `max_entries` preflight responses.
    ///
    /// `default_ttl` applies to responses that carry no parseable `Access-Control-Max-Age`.
    pub fn new(default_ttl: Duration, max_entries: usize) -> Self {
        let cache = Cache::builder()
            .max_capacity(max_entries as u64)
            .expire_after(CorsExpiry)
            .build();
        Self {
            cache,
            default_ttl,
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
        }
    }

    /// Returns a rebuilt copy of the cached response for `key`, if one is live.
    ///
    /// Every call bumps either the hit or the miss counter.
    pub fn get(&self, key: &CorsCacheKey) -> Option<Response> {
        if let Some(entry) = self.cache.get(key) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            debug!(
                origin = %key.origin,
                path = %key.path,
                bind = %key.bind,
                host = %key.host,
                "CORS cache hit"
            );
            return Some(entry.to_response());
        }
        self.misses.fetch_add(1, Ordering::Relaxed);
        None
    }

    /// Stores `response` under `key` when it is worth caching.
    ///
    /// The response is skipped when any of these holds:
    /// - it is not a 2xx;
    /// - it has none of `Access-Control-Allow-Origin`, `-Methods`, `-Headers`;
    /// - its effective TTL is zero.
    ///
    /// The TTL is `Access-Control-Max-Age` when present and parseable, otherwise the cache
    /// default.
    pub fn insert(&self, key: CorsCacheKey, response: &Response) {
        if !response.status().is_success() {
            return;
        }
        if !has_cors_headers(response.headers()) {
            return;
        }
        let ttl = extract_max_age(response.headers()).unwrap_or(self.default_ttl);
        // `Access-Control-Max-Age: 0` explicitly asks that the preflight not be cached. A
        // zero-TTL entry would be dead on arrival and would only add churn to the cache.
        if ttl.is_zero() {
            return;
        }
        debug!(
            origin = %key.origin,
            path = %key.path,
            bind = %key.bind,
            host = %key.host,
            ttl_secs = ttl.as_secs(),
            "CORS cache insert"
        );
        self.cache
            .insert(key, CorsCacheEntry::from_response(response, ttl));
    }

    /// Returns `(hits, misses, entry_count)`.
    ///
    /// The entry count comes from moka and may lag until its pending maintenance tasks run.
    pub fn stats(&self) -> (u64, u64, u64) {
        (
            self.hits.load(Ordering::Relaxed),
            self.misses.load(Ordering::Relaxed),
            self.cache.entry_count(),
        )
    }

    /// Test helper: flushes moka's pending maintenance so `entry_count` is up to date.
    #[cfg(test)]
    fn sync(&self) {
        self.cache.run_pending_tasks();
    }
}
/// Canonicalises an `Access-Control-Request-Headers` value so equivalent requests share a key.
///
/// The value is processed in these steps:
/// 1. Split on `,`, trim each name, and drop empty names.
/// 2. Lowercase each name.
/// 3. Sort the names and remove duplicates.
/// 4. Join the result with `,` and no spaces.
///
/// An empty or all-separator input yields `""`.
fn normalize_acrh(header: &str) -> String {
    let mut names: Vec<String> = header
        .split(',')
        .map(str::trim)
        .filter(|name| !name.is_empty())
        .map(str::to_lowercase)
        .collect();
    // Lowercase *before* sorting. Sorting mixed-case names first would put `X-Foo` ahead of
    // `content-type` but `x-foo` after it, giving one header set several different keys.
    names.sort_unstable();
    // Header names are a set; repeating one does not change what the preflight asks for.
    names.dedup();
    names.join(",")
}
/// Reads `Access-Control-Max-Age` as a whole number of seconds.
///
/// Returns `None` when the header is absent, is not visible ASCII, or does not parse as a `u64`.
fn extract_max_age(headers: &HeaderMap) -> Option<Duration> {
    let raw = headers.get(header::ACCESS_CONTROL_MAX_AGE)?.to_str().ok()?;
    let seconds: u64 = raw.parse().ok()?;
    Some(Duration::from_secs(seconds))
}
/// Reports whether `name` is one of the response headers replayed from the CORS cache.
fn is_cors_header(name: &HeaderName) -> bool {
    let replayed = [
        header::ACCESS_CONTROL_ALLOW_ORIGIN,
        header::ACCESS_CONTROL_ALLOW_METHODS,
        header::ACCESS_CONTROL_ALLOW_HEADERS,
        header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
        header::ACCESS_CONTROL_MAX_AGE,
        header::ACCESS_CONTROL_EXPOSE_HEADERS,
        header::VARY,
    ];
    replayed.iter().any(|candidate| candidate == name)
}
/// Reports whether a response carries at least one core CORS allow header.
///
/// The headers checked are `Access-Control-Allow-Origin`, `-Methods` and `-Headers`.
fn has_cors_headers(headers: &HeaderMap) -> bool {
    [
        header::ACCESS_CONTROL_ALLOW_ORIGIN,
        header::ACCESS_CONTROL_ALLOW_METHODS,
        header::ACCESS_CONTROL_ALLOW_HEADERS,
    ]
    .iter()
    .any(|name| headers.contains_key(name))
}
// Unit tests for key derivation, header normalization and basic cache behavior.
#[cfg(test)]
mod tests {
    use super::*;
    // Normalization trims, lowercases and sorts names, and maps empty input to "".
    #[test]
    fn test_normalize_acrh() {
        assert_eq!(normalize_acrh("Content-Type"), "content-type");
        assert_eq!(
            normalize_acrh("Authorization, Content-Type"),
            "authorization,content-type"
        );
        assert_eq!(
            normalize_acrh(" X-Custom , Content-Type , Authorization "),
            "authorization,content-type,x-custom"
        );
        assert_eq!(normalize_acrh(""), "");
    }
    // Every key field is populated from the request; ACRH comes back normalized.
    #[test]
    fn test_cache_key_from_request() {
        let mut headers = HeaderMap::new();
        headers.insert(header::ORIGIN, "https://example.com".parse().unwrap());
        headers.insert(header::HOST, "api.example.com".parse().unwrap());
        headers.insert(
            header::ACCESS_CONTROL_REQUEST_METHOD,
            "POST".parse().unwrap(),
        );
        headers.insert(
            header::ACCESS_CONTROL_REQUEST_HEADERS,
            "Content-Type, Authorization".parse().unwrap(),
        );
        let key = CorsCacheKey::from_request("/api/login", "http", &headers).unwrap();
        assert_eq!(key.origin, "https://example.com");
        assert_eq!(key.path, "/api/login");
        assert_eq!(key.bind, "http");
        assert_eq!(key.host, "api.example.com");
        assert_eq!(key.access_control_request_method, "POST");
        assert_eq!(
            key.access_control_request_headers,
            "authorization,content-type"
        );
    }
    // The literal `null` origin must never produce a key.
    #[test]
    fn test_cache_key_null_origin() {
        let mut headers = HeaderMap::new();
        headers.insert(header::ORIGIN, "null".parse().unwrap());
        let key = CorsCacheKey::from_request("/api/login", "http", &headers);
        assert!(key.is_none());
    }
    // Miss, insert, then hit; the counters and the entry count reflect exactly that sequence.
    #[test]
    fn test_cache_basic() {
        let cache = CorsCache::new(Duration::from_secs(3600), 1000);
        let mut headers = HeaderMap::new();
        headers.insert(header::ORIGIN, "https://example.com".parse().unwrap());
        headers.insert(header::HOST, "api.example.com".parse().unwrap());
        let key = CorsCacheKey::from_request("/api/test", "http", &headers).unwrap();
        assert!(cache.get(&key).is_none());
        let response = Response::builder()
            .status(StatusCode::OK)
            .header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "https://example.com")
            .header(header::ACCESS_CONTROL_ALLOW_METHODS, "GET, POST")
            .body(rama::http::Body::empty())
            .unwrap();
        cache.insert(key.clone(), &response);
        // Flush moka's pending work so `entry_count` is accurate.
        cache.sync();
        let cached = cache.get(&key);
        assert!(cached.is_some());
        let (hits, misses, entries) = cache.stats();
        assert_eq!(hits, 1);
        assert_eq!(misses, 1);
        assert_eq!(entries, 1);
    }
    // An entry cached for one host/bind pair is not visible under a different host or bind.
    #[test]
    fn test_cache_isolated_by_bind_and_host() {
        let cache = CorsCache::new(Duration::from_secs(3600), 1000);
        let mut headers = HeaderMap::new();
        headers.insert(header::ORIGIN, "https://example.com".parse().unwrap());
        headers.insert(header::HOST, "api.example.com".parse().unwrap());
        let key = CorsCacheKey::from_request("/api/test", "http", &headers).unwrap();
        let response = Response::builder()
            .status(StatusCode::OK)
            .header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "https://example.com")
            .header(header::ACCESS_CONTROL_ALLOW_METHODS, "GET, POST")
            .body(rama::http::Body::empty())
            .unwrap();
        cache.insert(key, &response);
        cache.sync();
        // Same origin, path and bind, but a different Host: must miss.
        headers.insert(header::HOST, "admin.example.com".parse().unwrap());
        let other_host_key = CorsCacheKey::from_request("/api/test", "http", &headers).unwrap();
        assert!(cache.get(&other_host_key).is_none());
        // Original Host restored, but a different bind name: must also miss.
        headers.insert(header::HOST, "api.example.com".parse().unwrap());
        let other_bind_key = CorsCacheKey::from_request("/api/test", "https", &headers).unwrap();
        assert!(cache.get(&other_bind_key).is_none());
    }
}