mothership 0.0.100

Process supervisor with HTTP exposure - wrap, monitor, and expose your fleet
Documentation
//! CORS preflight response cache
//!
//! Caches OPTIONS responses from backends to reduce load. Cache key includes:
//! - Origin header
//! - Request path
//! - Bind name
//! - Host header
//! - Access-Control-Request-Method header
//! - Access-Control-Request-Headers header (normalized)

use moka::policy::Expiry;
use moka::sync::Cache;
use rama::http::header::{self, HeaderMap, HeaderName, HeaderValue};
use rama::http::{Response, StatusCode};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tracing::debug;

/// Identity of a cached CORS preflight response.
///
/// Two preflight requests share a cache entry only when every component
/// below compares equal.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct CorsCacheKey {
    /// `Origin` request header, stored verbatim.
    origin: String,
    /// Request path targeted by the preflight.
    path: String,
    /// Name of the bind the request arrived on.
    bind: String,
    /// `Host` request header, ASCII-lowercased (empty when absent).
    host: String,
    /// `Access-Control-Request-Method` request header (empty when absent).
    access_control_request_method: String,
    /// Normalized `Access-Control-Request-Headers` value (empty when absent).
    access_control_request_headers: String,
}

impl CorsCacheKey {
    /// Create a cache key from request headers
    pub fn from_request(path: &str, bind_name: &str, headers: &HeaderMap) -> Option<Self> {
        let origin = headers
            .get(header::ORIGIN)
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_string())?;

        // Don't cache null origin (privacy mode, file://, etc.)
        if origin == "null" {
            return None;
        }

        let acrm = headers
            .get(header::ACCESS_CONTROL_REQUEST_METHOD)
            .and_then(|v| v.to_str().ok())
            .unwrap_or("")
            .to_string();

        let acrh = headers
            .get(header::ACCESS_CONTROL_REQUEST_HEADERS)
            .and_then(|v| v.to_str().ok())
            .map(normalize_acrh)
            .unwrap_or_default();

        let host = headers
            .get(header::HOST)
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_ascii_lowercase())
            .unwrap_or_default();

        Some(Self {
            origin,
            path: path.to_string(),
            bind: bind_name.to_string(),
            host,
            access_control_request_method: acrm,
            access_control_request_headers: acrh,
        })
    }
}

/// Cached CORS response entry with per-entry TTL
#[derive(Debug, Clone)]
struct CorsCacheEntry {
    status: u16,
    headers: Vec<(String, String)>,
    ttl: Duration,
}

impl CorsCacheEntry {
    /// Create entry from response
    fn from_response(response: &Response, ttl: Duration) -> Self {
        let headers = response
            .headers()
            .iter()
            .filter(|(name, _)| is_cors_header(name))
            .map(|(name, value)| {
                (
                    name.as_str().to_string(),
                    value.to_str().unwrap_or("").to_string(),
                )
            })
            .collect();

        Self {
            status: response.status().as_u16(),
            headers,
            ttl,
        }
    }

    /// Convert entry back to response
    fn to_response(&self) -> Response {
        let mut builder =
            Response::builder().status(StatusCode::from_u16(self.status).unwrap_or(StatusCode::OK));

        for (name, value) in &self.headers {
            if let (Ok(name), Ok(value)) = (
                HeaderName::try_from(name.as_str()),
                HeaderValue::try_from(value.as_str()),
            ) {
                builder = builder.header(name, value);
            }
        }

        builder.body(rama::http::Body::empty()).unwrap()
    }
}

/// Per-entry expiry policy: every entry lives for the TTL stored alongside it.
///
/// That TTL comes from `Access-Control-Max-Age`, or the cache default when the
/// header is absent.
struct CorsExpiry;

impl Expiry<CorsCacheKey, CorsCacheEntry> for CorsExpiry {
    /// A newly inserted entry expires after its own TTL.
    fn expire_after_create(
        &self,
        _key: &CorsCacheKey,
        value: &CorsCacheEntry,
        _current_time: Instant,
    ) -> Option<Duration> {
        Some(value.ttl)
    }

    /// Replacing a live entry restarts the clock with the *new* value's TTL.
    ///
    /// moka's default implementation keeps the previous deadline. That would
    /// let a newer response inherit the replaced entry's lifetime, even when
    /// the newer response carries a shorter `Access-Control-Max-Age`.
    fn expire_after_update(
        &self,
        _key: &CorsCacheKey,
        value: &CorsCacheEntry,
        _current_time: Instant,
        _duration_until_expiry: Option<Duration>,
    ) -> Option<Duration> {
        Some(value.ttl)
    }
}

/// Cache of CORS preflight responses, keyed by [`CorsCacheKey`].
///
/// All operations take `&self`. Storage is a `moka` cache, and the lookup
/// statistics are plain atomic counters.
pub struct CorsCache {
    /// Cached entries, each expiring after its own TTL.
    cache: Cache<CorsCacheKey, CorsCacheEntry>,
    /// TTL used when a response has no parseable `Access-Control-Max-Age`.
    default_ttl: Duration,
    /// Number of lookups that found an entry.
    hits: AtomicU64,
    /// Number of lookups that found nothing.
    misses: AtomicU64,
}

impl CorsCache {
    /// Upper bound applied to a backend-supplied `Access-Control-Max-Age`.
    ///
    /// Browsers cap their own preflight caches well below extreme values
    /// (Firefox at 24 h, Chromium at 2 h). Honoring larger backend values here
    /// would only pin a possibly outdated CORS policy in the proxy. The
    /// configured `default_ttl` is not subject to this cap.
    const MAX_BACKEND_TTL: Duration = Duration::from_secs(24 * 60 * 60);

    /// Create a new CORS cache.
    ///
    /// * `default_ttl` — lifetime of entries whose response carries no
    ///   parseable `Access-Control-Max-Age`.
    /// * `max_entries` — capacity handed to moka's `max_capacity`.
    pub fn new(default_ttl: Duration, max_entries: usize) -> Self {
        let cache = Cache::builder()
            .max_capacity(max_entries as u64)
            .expire_after(CorsExpiry)
            .build();

        Self {
            cache,
            default_ttl,
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
        }
    }

    /// Look up `key`, returning a freshly built response on a hit.
    ///
    /// Every call increments either the hit or the miss counter.
    pub fn get(&self, key: &CorsCacheKey) -> Option<Response> {
        if let Some(entry) = self.cache.get(key) {
            self.hits.fetch_add(1, Ordering::Relaxed);
            debug!(
                origin = %key.origin,
                path = %key.path,
                bind = %key.bind,
                host = %key.host,
                "CORS cache hit"
            );
            return Some(entry.to_response());
        }
        self.misses.fetch_add(1, Ordering::Relaxed);
        None
    }

    /// Store `response` under `key` if it is worth replaying.
    ///
    /// Only 2xx responses that carry at least one CORS allow-header are
    /// cached. The TTL is taken from `Access-Control-Max-Age`, capped at
    /// [`Self::MAX_BACKEND_TTL`]; otherwise `default_ttl` is used. A zero TTL
    /// means "do not cache": nothing is stored, and any existing entry for
    /// `key` is dropped.
    pub fn insert(&self, key: CorsCacheKey, response: &Response) {
        // Only successful responses with CORS headers are worth replaying.
        if !response.status().is_success() || !has_cors_headers(response.headers()) {
            return;
        }

        let ttl = extract_max_age(response.headers())
            .map(|max_age| max_age.min(Self::MAX_BACKEND_TTL))
            .unwrap_or(self.default_ttl);

        // Storing an already-expired entry would only occupy capacity (and
        // could push out live entries). Honor "don't cache" by evicting any
        // stale copy instead.
        if ttl.is_zero() {
            self.cache.invalidate(&key);
            return;
        }

        debug!(
            origin = %key.origin,
            path = %key.path,
            bind = %key.bind,
            host = %key.host,
            ttl_secs = ttl.as_secs(),
            "CORS cache insert"
        );

        self.cache
            .insert(key, CorsCacheEntry::from_response(response, ttl));
    }

    /// Return `(hits, misses, entry_count)`.
    ///
    /// The entry count may lag behind recent inserts until moka runs its
    /// pending maintenance tasks.
    pub fn stats(&self) -> (u64, u64, u64) {
        (
            self.hits.load(Ordering::Relaxed),
            self.misses.load(Ordering::Relaxed),
            self.cache.entry_count(),
        )
    }

    /// Process pending tasks (for testing)
    #[cfg(test)]
    fn sync(&self) {
        self.cache.run_pending_tasks();
    }
}

/// Normalize an `Access-Control-Request-Headers` value into a canonical
/// cache-key form.
///
/// Each comma-separated name is trimmed and lowercased; empty items are
/// dropped. The names are then sorted and de-duplicated and joined with `,`
/// (no spaces).
///
/// Lowercasing must happen *before* sorting. Sorting first orders by the
/// original case (`'X' < 'a'`), so `"X-Foo, accept"` and `"accept, x-foo"`
/// would produce different keys for the same set of headers.
fn normalize_acrh(header: &str) -> String {
    let mut names: Vec<String> = header
        .split(',')
        .map(str::trim)
        .filter(|name| !name.is_empty())
        .map(str::to_lowercase)
        .collect();
    names.sort_unstable();
    names.dedup();
    names.join(",")
}

/// Parse the `Access-Control-Max-Age` response header as whole seconds.
///
/// Returns `None` when the header is absent or not visible ASCII. It also
/// returns `None` when the value does not parse as a `u64` (for example `-1`
/// or `10.5`).
fn extract_max_age(headers: &HeaderMap) -> Option<Duration> {
    let raw = headers.get(header::ACCESS_CONTROL_MAX_AGE)?.to_str().ok()?;
    let seconds: u64 = raw.parse().ok()?;
    Some(Duration::from_secs(seconds))
}

/// Whether a response header belongs to the CORS set kept in a cached entry.
///
/// The retained set is the `Access-Control-Allow-*` family, plus
/// `Access-Control-Max-Age`, `Access-Control-Expose-Headers` and `Vary`.
fn is_cors_header(name: &HeaderName) -> bool {
    let retained = [
        header::ACCESS_CONTROL_ALLOW_ORIGIN,
        header::ACCESS_CONTROL_ALLOW_METHODS,
        header::ACCESS_CONTROL_ALLOW_HEADERS,
        header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
        header::ACCESS_CONTROL_MAX_AGE,
        header::ACCESS_CONTROL_EXPOSE_HEADERS,
        header::VARY,
    ];
    retained.iter().any(|candidate| candidate == name)
}

/// Whether a response carries at least one of the core CORS allow-headers.
///
/// The headers checked are `Access-Control-Allow-Origin`, `-Methods` and
/// `-Headers`.
fn has_cors_headers(headers: &HeaderMap) -> bool {
    [
        header::ACCESS_CONTROL_ALLOW_ORIGIN,
        header::ACCESS_CONTROL_ALLOW_METHODS,
        header::ACCESS_CONTROL_ALLOW_HEADERS,
    ]
    .iter()
    .any(|name| headers.contains_key(name))
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Request headers with an `Origin` and a `Host`, as a browser preflight sends.
    fn origin_headers(origin: &str, host: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(header::ORIGIN, origin.parse().unwrap());
        headers.insert(header::HOST, host.parse().unwrap());
        headers
    }

    /// A successful preflight response allowing `https://example.com`.
    fn allow_example_response() -> Response {
        Response::builder()
            .status(StatusCode::OK)
            .header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "https://example.com")
            .header(header::ACCESS_CONTROL_ALLOW_METHODS, "GET, POST")
            .body(rama::http::Body::empty())
            .unwrap()
    }

    #[test]
    fn test_normalize_acrh() {
        let cases = [
            ("Content-Type", "content-type"),
            ("Authorization, Content-Type", "authorization,content-type"),
            (
                "  X-Custom ,  Content-Type  ,  Authorization  ",
                "authorization,content-type,x-custom",
            ),
            ("", ""),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_acrh(input), expected, "input: {input:?}");
        }
    }

    #[test]
    fn test_cache_key_from_request() {
        let mut headers = origin_headers("https://example.com", "api.example.com");
        headers.insert(
            header::ACCESS_CONTROL_REQUEST_METHOD,
            "POST".parse().unwrap(),
        );
        headers.insert(
            header::ACCESS_CONTROL_REQUEST_HEADERS,
            "Content-Type, Authorization".parse().unwrap(),
        );

        let key = CorsCacheKey::from_request("/api/login", "http", &headers).unwrap();
        assert_eq!(key.origin, "https://example.com");
        assert_eq!(key.path, "/api/login");
        assert_eq!(key.bind, "http");
        assert_eq!(key.host, "api.example.com");
        assert_eq!(key.access_control_request_method, "POST");
        assert_eq!(
            key.access_control_request_headers,
            "authorization,content-type"
        );
    }

    #[test]
    fn test_cache_key_null_origin() {
        let mut headers = HeaderMap::new();
        headers.insert(header::ORIGIN, "null".parse().unwrap());

        assert!(CorsCacheKey::from_request("/api/login", "http", &headers).is_none());
    }

    #[test]
    fn test_cache_basic() {
        let cache = CorsCache::new(Duration::from_secs(3600), 1000);
        let headers = origin_headers("https://example.com", "api.example.com");
        let key = CorsCacheKey::from_request("/api/test", "http", &headers).unwrap();

        // Empty cache: first lookup is a miss.
        assert!(cache.get(&key).is_none());

        cache.insert(key.clone(), &allow_example_response());
        cache.sync(); // moka processes inserts lazily

        // Same key after insert: hit.
        assert!(cache.get(&key).is_some());

        let (hits, misses, entries) = cache.stats();
        assert_eq!(hits, 1);
        assert_eq!(misses, 1);
        assert_eq!(entries, 1);
    }

    #[test]
    fn test_cache_isolated_by_bind_and_host() {
        let cache = CorsCache::new(Duration::from_secs(3600), 1000);
        let headers = origin_headers("https://example.com", "api.example.com");
        let key = CorsCacheKey::from_request("/api/test", "http", &headers).unwrap();

        cache.insert(key, &allow_example_response());
        cache.sync();

        // Same origin and path, different Host: must not hit.
        let admin_headers = origin_headers("https://example.com", "admin.example.com");
        let other_host_key =
            CorsCacheKey::from_request("/api/test", "http", &admin_headers).unwrap();
        assert!(cache.get(&other_host_key).is_none());

        // Same origin, path and Host, different bind: must not hit.
        let other_bind_key = CorsCacheKey::from_request("/api/test", "https", &headers).unwrap();
        assert!(cache.get(&other_bind_key).is_none());
    }
}