athena_rs 3.3.0

Database gateway API
Documentation
//! Host-based client routing helpers.
//!
//! Supports wildcard domain patterns like `*.v3.athena-db.com` and extracts
//! the `{prefix}` label so requests can be routed to a mapped Athena client.

use actix_web::http::header::{HOST, HeaderMap};
use std::env;

// Public defaults.

/// Wildcard host pattern used when `ATHENA_WILDCARD_HOST_PATTERN` is unset or blank.
pub const DEFAULT_WILDCARD_HOST_PATTERN: &str = "*.v3.athena-db.com";
/// Default target base URL for wildcard-routed clients.
///
/// NOTE(review): not referenced inside this module; presumably consumed by callers.
pub const DEFAULT_WILDCARD_TARGET_BASE_URL: &str = "https://pool.athena-db.com";

// Environment variable names read by this module.

/// Overrides [`DEFAULT_WILDCARD_HOST_PATTERN`] when set to a non-blank value.
const ENV_WILDCARD_HOST_PATTERN: &str = "ATHENA_WILDCARD_HOST_PATTERN";
/// Disables wildcard host routing when set to `0`, `false`, `no`, or `off`
/// (case-insensitive, surrounding whitespace ignored); routing is on otherwise.
const ENV_WILDCARD_HOST_ROUTING_ENABLED: &str = "ATHENA_WILDCARD_HOST_ROUTING_ENABLED";

/// Looks up header `name` and returns its value trimmed of surrounding whitespace.
///
/// Returns `None` when the header is absent, is not visible ASCII (i.e.
/// `HeaderValue::to_str` fails), or is empty after trimming.
fn header_value_as_trimmed_str(headers: &HeaderMap, name: &str) -> Option<String> {
    let value = headers.get(name)?.to_str().ok()?.trim();
    if value.is_empty() {
        None
    } else {
        Some(value.to_owned())
    }
}

/// Normalizes a raw `Host` header value into a bare, lowercase hostname.
///
/// Steps:
/// 1. Trim surrounding whitespace.
/// 2. Lowercase ASCII.
/// 3. Discard everything from the first `:` onward (the port).
/// 4. Strip any trailing root dot(s) of a fully-qualified name.
///
/// Returns `None` for input that is empty after normalization. It also returns
/// `None` for bracketed IPv6 literals (`[::1]:8080`), which are never eligible
/// for wildcard routing.
fn normalize_host(host_header: &str) -> Option<String> {
    let raw = host_header.trim().to_ascii_lowercase();

    // Bracketed IPv6 hosts are not eligible for wildcard routing.
    if raw.starts_with('[') {
        return None;
    }

    // Remove the `:port` *before* stripping the root dot. Otherwise a value such
    // as `name.example.com.:443` would keep its trailing dot and never match a
    // wildcard suffix.
    let host = raw
        .split_once(':')
        .map_or(raw.as_str(), |(host, _port)| host)
        .trim()
        .trim_end_matches('.');

    (!host.is_empty()).then(|| host.to_owned())
}

/// Extracts the fixed suffix from a single-level wildcard pattern.
///
/// `"*.v3.athena-db.com"` yields `"v3.athena-db.com"`. The pattern is trimmed,
/// stripped of trailing dots, and lowercased first. Returns `None` when the
/// pattern does not start with `*.`, when the suffix is empty, or when the
/// suffix contains another `*`.
fn wildcard_suffix_from_pattern(pattern: &str) -> Option<String> {
    let lowered = pattern.trim().trim_end_matches('.').to_ascii_lowercase();
    let Some(rest) = lowered.strip_prefix("*.") else {
        return None;
    };
    let rest = rest.trim();
    let acceptable = !rest.is_empty() && !rest.contains('*');
    acceptable.then(|| rest.to_owned())
}

/// Reports whether `prefix` is a non-empty single label made only of lowercase
/// ASCII letters, ASCII digits, `-`, or `_`.
///
/// The allowed alphabet excludes `.`, which guarantees exactly one label. Any
/// non-ASCII byte is rejected as well.
fn is_prefix_label_valid(prefix: &str) -> bool {
    if prefix.is_empty() {
        return false;
    }
    prefix
        .bytes()
        .all(|byte| matches!(byte, b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_'))
}

/// Returns the wildcard host pattern used for prefix extraction.
///
/// Reads `ATHENA_WILDCARD_HOST_PATTERN` and returns it trimmed. Falls back to
/// [`DEFAULT_WILDCARD_HOST_PATTERN`] when the variable is unset, is not valid
/// Unicode, or is blank.
pub fn wildcard_host_pattern() -> String {
    match env::var(ENV_WILDCARD_HOST_PATTERN) {
        Ok(configured) if !configured.trim().is_empty() => configured.trim().to_owned(),
        _ => DEFAULT_WILDCARD_HOST_PATTERN.to_owned(),
    }
}

/// Reports whether wildcard host routing is enabled.
///
/// Routing is on by default. It is turned off only when
/// `ATHENA_WILDCARD_HOST_ROUTING_ENABLED` is one of `0`, `false`, `no`, or `off`
/// (case-insensitive, surrounding whitespace ignored). An unset or non-Unicode
/// value leaves routing enabled.
pub fn wildcard_host_routing_enabled() -> bool {
    let Ok(configured) = env::var(ENV_WILDCARD_HOST_ROUTING_ENABLED) else {
        return true;
    };
    let configured = configured.trim();
    let disabled = ["0", "false", "no", "off"]
        .iter()
        .any(|flag| configured.eq_ignore_ascii_case(flag));
    !disabled
}

/// Extracts the single-label `{prefix}` from `host_header` when it matches
/// `wildcard_pattern`.
///
/// For example, with the pattern `*.v3.athena-db.com`, the host
/// `Reporting.v3.athena-db.com:443` yields `"reporting"`. The host is normalized
/// first (lowercased, port and trailing dot removed).
///
/// Returns `None` in any of these cases:
/// - the pattern is not a valid `*.suffix` pattern;
/// - the host does not end in `.suffix`;
/// - the remaining prefix is empty, spans several labels, or uses characters
///   outside `[a-z0-9_-]`.
pub fn extract_prefix_from_host(host_header: &str, wildcard_pattern: &str) -> Option<String> {
    let suffix = wildcard_suffix_from_pattern(wildcard_pattern)?;
    let host = normalize_host(host_header)?;

    // Peel off the suffix, then the separating dot; what remains must be exactly one label.
    let prefix = host.strip_suffix(suffix.as_str())?.strip_suffix('.')?;
    is_prefix_label_valid(prefix).then(|| prefix.to_owned())
}

/// Returns a wildcard host prefix candidate when host routing should be applied.
///
/// Returns `None` in any of these cases:
/// - routing is disabled via `ATHENA_WILDCARD_HOST_ROUTING_ENABLED`;
/// - the caller already supplied a non-blank `X-Athena-Client` or `X-JDBC-URL`;
/// - the `Host` header is missing or blank;
/// - the host does not match the configured wildcard pattern.
///
/// NOTE(review): only the `Host` header is consulted; confirm that HTTP/2
/// requests (which carry `:authority`) populate it before reaching this point.
pub fn infer_wildcard_host_prefix(headers: &HeaderMap) -> Option<String> {
    if !wildcard_host_routing_enabled() {
        return None;
    }

    let caller_chose_target = ["X-Athena-Client", "X-JDBC-URL"]
        .into_iter()
        .any(|name| header_value_as_trimmed_str(headers, name).is_some());
    if caller_chose_target {
        return None;
    }

    header_value_as_trimmed_str(headers, HOST.as_str())
        .and_then(|host| extract_prefix_from_host(&host, &wildcard_host_pattern()))
}

#[cfg(test)]
mod tests {
    use super::{
        DEFAULT_WILDCARD_HOST_PATTERN, ENV_WILDCARD_HOST_PATTERN,
        ENV_WILDCARD_HOST_ROUTING_ENABLED, extract_prefix_from_host, infer_wildcard_host_prefix,
        wildcard_host_pattern, wildcard_host_routing_enabled,
    };
    use actix_web::http::header::{HOST, HeaderMap, HeaderName, HeaderValue};
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serializes every test in this module that reads or writes the process
    /// environment. The test harness runs tests on parallel threads, so without
    /// this lock a test that disables routing could break an `infer_*` test.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`]. If an earlier test panicked while holding the
    /// lock, the poisoned lock is recovered so one failure does not cascade into
    /// every env-dependent test.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Sets or clears an environment variable. The previous value is restored
    /// on drop, including when the test panics mid-way.
    ///
    /// Only create one while the [`lock_env`] guard is held. Declare it after
    /// that guard so it is dropped first.
    struct ScopedEnvVar {
        key: &'static str,
        previous: Option<OsString>,
    }

    impl ScopedEnvVar {
        fn set(key: &'static str, value: &str) -> Self {
            let previous = std::env::var_os(key);
            // SAFETY: the caller holds ENV_LOCK, so no other test in this module
            // touches the environment concurrently. NOTE(review): tests in other
            // modules are assumed not to read or write these variables meanwhile.
            unsafe { std::env::set_var(key, value) };
            Self { key, previous }
        }

        fn unset(key: &'static str) -> Self {
            let previous = std::env::var_os(key);
            // SAFETY: same invariant as `ScopedEnvVar::set`.
            unsafe { std::env::remove_var(key) };
            Self { key, previous }
        }
    }

    impl Drop for ScopedEnvVar {
        fn drop(&mut self) {
            // SAFETY: same invariant as `ScopedEnvVar::set`. This guard is
            // dropped before the ENV_LOCK guard acquired ahead of it.
            match self.previous.take() {
                Some(value) => unsafe { std::env::set_var(self.key, value) },
                None => unsafe { std::env::remove_var(self.key) },
            }
        }
    }

    /// Holds [`ENV_LOCK`] with both routing variables cleared, so the code
    /// under test sees the built-in defaults regardless of the ambient
    /// environment. Fields drop in declaration order: the variables are
    /// restored before the lock is released.
    struct DefaultEnv {
        _pattern: ScopedEnvVar,
        _enabled: ScopedEnvVar,
        _lock: MutexGuard<'static, ()>,
    }

    fn default_env() -> DefaultEnv {
        let lock = lock_env();
        DefaultEnv {
            _pattern: ScopedEnvVar::unset(ENV_WILDCARD_HOST_PATTERN),
            _enabled: ScopedEnvVar::unset(ENV_WILDCARD_HOST_ROUTING_ENABLED),
            _lock: lock,
        }
    }

    #[test]
    fn extracts_prefix_from_matching_host() {
        let prefix =
            extract_prefix_from_host("athena_logging.v3.athena-db.com", "*.v3.athena-db.com");
        assert_eq!(prefix.as_deref(), Some("athena_logging"));
    }

    #[test]
    fn extracts_prefix_and_strips_port() {
        let prefix =
            extract_prefix_from_host("Reporting.v3.athena-db.com:443", "*.v3.athena-db.com");
        assert_eq!(prefix.as_deref(), Some("reporting"));
    }

    #[test]
    fn rejects_multi_label_prefix() {
        let prefix = extract_prefix_from_host("foo.bar.v3.athena-db.com", "*.v3.athena-db.com");
        assert!(prefix.is_none());
    }

    #[test]
    fn rejects_non_matching_suffix() {
        let prefix =
            extract_prefix_from_host("athena_logging.v2.athena-db.com", "*.v3.athena-db.com");
        assert!(prefix.is_none());
    }

    #[test]
    fn infer_skips_when_explicit_headers_exist() {
        let _env = default_env();
        let mut headers = HeaderMap::new();
        headers.insert(
            HOST,
            HeaderValue::from_static("athena_logging.v3.athena-db.com"),
        );
        headers.insert(
            HeaderName::from_static("x-athena-client"),
            HeaderValue::from_static("explicit"),
        );
        assert!(infer_wildcard_host_prefix(&headers).is_none());
    }

    #[test]
    fn env_defaults_are_set() {
        let _env = default_env();
        assert_eq!(wildcard_host_pattern(), DEFAULT_WILDCARD_HOST_PATTERN);
        assert!(wildcard_host_routing_enabled());
    }

    #[test]
    fn extract_accepts_hyphen_underscore_and_digits() {
        let prefix =
            extract_prefix_from_host("acme_42-prod.v3.athena-db.com", "*.v3.athena-db.com");
        assert_eq!(prefix.as_deref(), Some("acme_42-prod"));
    }

    #[test]
    fn extract_lowercases_prefix() {
        let prefix = extract_prefix_from_host("AcMe_42.v3.athena-db.com", "*.v3.athena-db.com");
        assert_eq!(prefix.as_deref(), Some("acme_42"));
    }

    #[test]
    fn extract_rejects_empty_prefix() {
        let prefix = extract_prefix_from_host("v3.athena-db.com", "*.v3.athena-db.com");
        assert!(prefix.is_none());
    }

    #[test]
    fn extract_rejects_invalid_prefix_character() {
        let prefix = extract_prefix_from_host("acme!prod.v3.athena-db.com", "*.v3.athena-db.com");
        assert!(prefix.is_none());
    }

    #[test]
    fn extract_accepts_trailing_dot() {
        let prefix =
            extract_prefix_from_host("athena_logging.v3.athena-db.com.", "*.v3.athena-db.com");
        assert_eq!(prefix.as_deref(), Some("athena_logging"));
    }

    #[test]
    fn extract_rejects_ipv6_host() {
        let prefix = extract_prefix_from_host("[::1]:4052", "*.v3.athena-db.com");
        assert!(prefix.is_none());
    }

    #[test]
    fn extract_rejects_invalid_wildcard_pattern() {
        let prefix =
            extract_prefix_from_host("athena_logging.v3.athena-db.com", "*.*.athena-db.com");
        assert!(prefix.is_none());
    }

    #[test]
    fn infer_returns_none_without_host_header() {
        let _env = default_env();
        let headers = HeaderMap::new();
        assert!(infer_wildcard_host_prefix(&headers).is_none());
    }

    #[test]
    fn infer_skips_when_jdbc_header_exists() {
        let _env = default_env();
        let mut headers = HeaderMap::new();
        headers.insert(
            HOST,
            HeaderValue::from_static("athena_logging.v3.athena-db.com"),
        );
        headers.insert(
            HeaderName::from_static("x-jdbc-url"),
            HeaderValue::from_static("jdbc:postgresql://localhost:5432/db"),
        );
        assert!(infer_wildcard_host_prefix(&headers).is_none());
    }

    #[test]
    fn infer_resolves_prefix_from_host_without_explicit_headers() {
        let _env = default_env();
        let mut headers = HeaderMap::new();
        headers.insert(
            HOST,
            HeaderValue::from_static("athena_logging.v3.athena-db.com"),
        );
        let prefix = infer_wildcard_host_prefix(&headers);
        assert_eq!(prefix.as_deref(), Some("athena_logging"));
    }

    #[test]
    fn infer_resolves_prefix_from_host_with_port() {
        let _env = default_env();
        let mut headers = HeaderMap::new();
        headers.insert(
            HOST,
            HeaderValue::from_static("reporting.v3.athena-db.com:443"),
        );
        let prefix = infer_wildcard_host_prefix(&headers);
        assert_eq!(prefix.as_deref(), Some("reporting"));
    }

    #[test]
    fn wildcard_host_pattern_uses_env_override() {
        let _lock = lock_env();
        let _pattern = ScopedEnvVar::set(ENV_WILDCARD_HOST_PATTERN, "*.v4.athena-db.com");
        assert_eq!(wildcard_host_pattern(), "*.v4.athena-db.com");
    }

    #[test]
    fn wildcard_host_routing_enabled_respects_false_values() {
        let _lock = lock_env();
        for value in ["0", "false", "no", "off", "FALSE", " Off "] {
            let _enabled = ScopedEnvVar::set(ENV_WILDCARD_HOST_ROUTING_ENABLED, value);
            assert!(
                !wildcard_host_routing_enabled(),
                "value '{value}' should disable"
            );
        }
    }

    #[test]
    fn wildcard_host_routing_enabled_accepts_trueish_values() {
        let _lock = lock_env();
        for value in ["1", "true", "yes", "on", "TRUE"] {
            let _enabled = ScopedEnvVar::set(ENV_WILDCARD_HOST_ROUTING_ENABLED, value);
            assert!(
                wildcard_host_routing_enabled(),
                "value '{value}' should enable"
            );
        }
    }
}