http-smtp-rele 0.1.0

Minimal, secure HTTP-to-SMTP submission relay
Documentation
//! Authentication and access control.
//!
//! Implements RFC 040: API key authentication via Axum `FromRequestParts`.
//!
//! # Flow
//!
//! ```text
//! Request
//!   -> resolve client IP (socket peer / X-Forwarded-For via trusted proxy)
//!   -> check source CIDR allowlist (if configured); 403 if the client IP is not allowed
//!   -> extract token from Authorization: Bearer or X-API-Key header; 401 if missing/malformed
//!   -> constant-time compare against every configured api_key secret
//!   -> produce AuthContext when an enabled key matches
//!   -> 403 when no key matches or the matching key is disabled
//! ```
//!
//! # Security notes
//!
//! - Token comparison always uses constant-time equality to prevent timing attacks.
//! - Tokens are never logged; only `key_id` (non-secret) is propagated.
//! - Forwarded headers are only consulted when `trust_proxy_headers` is enabled and the peer IP
//!   is in `trusted_source_cidrs`.
//! - Source IP allowlist is enforced via `allowed_source_cidrs` (distinct field).

use std::{net::IpAddr, sync::Arc};

use axum::{
    extract::{FromRef, FromRequestParts},
    http::{request::Parts, StatusCode},
};
use ipnet::IpNet;
use subtle::ConstantTimeEq;
use tracing::warn;

use crate::{config::ApiKeyConfig, AppState};

// ---------------------------------------------------------------------------
// AuthContext
// ---------------------------------------------------------------------------

/// Evidence that a request presented a valid, enabled API key.
///
/// Built by this type's [`FromRequestParts`] implementation, so handlers obtain it simply by
/// taking `AuthContext` as an extractor argument. The key's secret is never copied into this
/// value — only its public identifier is kept.
#[derive(Debug, Clone)]
pub struct AuthContext {
    /// Public (non-secret) `id` of the API key that matched; safe to put in logs.
    pub key_id: String,
    /// Address of the client: the socket peer, or the `X-Forwarded-For`-derived address when
    /// proxy headers are enabled and the peer is a trusted proxy.
    pub client_ip: IpAddr,
    /// Rate limit override for this key, in tokens per minute, copied from its config.
    /// `None` means the global default applies.
    pub key_rate_limit_per_min: Option<u32>,
}

// ---------------------------------------------------------------------------
// Axum extractor
// ---------------------------------------------------------------------------

impl<S> FromRequestParts<S> for AuthContext
where
    Arc<AppState>: axum::extract::FromRef<S>,
    S: Send + Sync,
{
    type Rejection = (StatusCode, axum::Json<serde_json::Value>);

    /// Authenticate the request and build an [`AuthContext`].
    ///
    /// Steps, in order:
    /// 1. determine the client IP (peer address, or forwarded address when proxy headers are
    ///    trusted);
    /// 2. reject with 403 if an `allowed_source_cidrs` allowlist is configured and the client
    ///    IP is outside it;
    /// 3. reject with 401 if no usable token header is present;
    /// 4. look the token up among the configured keys; 403 if it matches nothing or a disabled
    ///    key.
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let shared = Arc::<AppState>::from_ref(state);
        let security = &shared.config.security;

        // Step 1: who is the client? Forwarded headers are only considered when enabled.
        let socket_ip = resolve_peer_ip(parts);
        let client_ip = if security.trust_proxy_headers {
            resolve_client_ip(parts, socket_ip, &security.trusted_source_cidrs)
        } else {
            socket_ip
        };

        // Step 2: source allowlist. An empty list admits everyone. This list is separate from
        // `trusted_source_cidrs`, which only governs trust in proxy headers.
        let allowlist = &security.allowed_source_cidrs;
        let source_permitted = allowlist.is_empty() || ip_in_cidrs(client_ip, allowlist);
        if !source_permitted {
            warn!(
                client_ip = %client_ip,
                "auth: client IP not in allowed_source_cidrs"
            );
            return Err(forbidden());
        }

        // Step 3: pull the credential out of the headers.
        let Some(token) = extract_token(parts) else {
            warn!(client_ip = %client_ip, "auth: missing or malformed token");
            return Err(unauthorized());
        };

        // Step 4: constant-time lookup among configured keys.
        match find_matching_key(&security.api_keys, token) {
            MatchResult::Matched(key_id, key_rate_limit_per_min) => Ok(Self {
                key_id,
                client_ip,
                key_rate_limit_per_min,
            }),
            MatchResult::Disabled(key_id) => {
                warn!(
                    client_ip = %client_ip,
                    key_id = %key_id,
                    "auth: key is disabled"
                );
                Err(forbidden())
            }
            MatchResult::NotFound => {
                warn!(client_ip = %client_ip, "auth: token not matched");
                Err(forbidden())
            }
        }
    }
}

// ---------------------------------------------------------------------------
// Token extraction
// ---------------------------------------------------------------------------

/// Extract the API token from `Authorization: Bearer <token>` or `X-API-Key: <token>`.
///
/// Priority: `Authorization` > `X-API-Key`. If an `Authorization` header is present at all it
/// is authoritative. `None` is returned, and `X-API-Key` is *not* consulted, when the header:
/// - is not valid visible-ASCII,
/// - does not use the `Bearer` scheme, or
/// - carries an empty token.
///
/// The scheme name is matched case-insensitively (RFC 7235 §2.1, RFC 6750 §2.1), and whitespace
/// around the token is ignored. Empty tokens are always rejected. They are then reported as
/// "missing or malformed" (401) and can never be compared against a misconfigured empty secret.
fn extract_token(parts: &Parts) -> Option<&str> {
    if let Some(value) = parts.headers.get(axum::http::header::AUTHORIZATION) {
        // `?` rejects (returns None) instead of falling through to X-API-Key: a present but
        // malformed Authorization header must not be overridden by another header.
        let (scheme, token) = value.to_str().ok()?.split_once(' ')?;
        return if scheme.eq_ignore_ascii_case("Bearer") {
            non_empty(token.trim())
        } else {
            None
        };
    }

    // X-API-Key: <token>  (fallback, only when Authorization is absent)
    parts
        .headers
        .get("x-api-key")
        .and_then(|v| v.to_str().ok())
        .map(str::trim)
        .and_then(non_empty)
}

/// Return `Some(s)` unless `s` is the empty string.
fn non_empty(s: &str) -> Option<&str> {
    (!s.is_empty()).then_some(s)
}

// ---------------------------------------------------------------------------
// Key matching
// ---------------------------------------------------------------------------

/// Outcome of looking up a presented token among the configured API keys.
enum MatchResult {
    /// Token matched an enabled key: `(key_id, rate_limit_per_min)`.
    Matched(String, Option<u32>),
    /// Token matched a key that is configured but disabled: `(key_id)`.
    Disabled(String),
    /// No configured key has this secret.
    NotFound,
}

/// Look `token` up among `keys`, comparing against every secret in constant time.
///
/// The scan never stops early: every key's secret is compared even after a hit, so the time
/// taken does not reveal the matching key's position. If several keys share the same secret,
/// the last one in `keys` wins. The enabled flag is checked only after the scan.
fn find_matching_key(keys: &[ApiKeyConfig], token: &str) -> MatchResult {
    let presented = token.as_bytes();

    // `fold` visits every element unconditionally (unlike `find`, which short-circuits).
    let hit = keys.iter().fold(None, |previous, candidate| {
        if ct_eq_bytes(presented, candidate.secret.expose().as_bytes()) {
            Some(candidate)
        } else {
            previous
        }
    });

    let Some(key) = hit else {
        return MatchResult::NotFound;
    };
    if key.enabled {
        MatchResult::Matched(key.id.clone(), key.rate_limit_per_min)
    } else {
        MatchResult::Disabled(key.id.clone())
    }
}

/// Compare two byte slices for equality using `subtle`'s constant-time comparison.
///
/// A length mismatch is rejected up front. This reveals only whether the lengths differ, not
/// any content, and token lengths are not treated as secret.
fn ct_eq_bytes(a: &[u8], b: &[u8]) -> bool {
    let same_length = a.len() == b.len();
    same_length && bool::from(a.ct_eq(b))
}

// ---------------------------------------------------------------------------
// Client IP resolution
// ---------------------------------------------------------------------------

/// Return the IP address of the socket peer, taken from Axum's `ConnectInfo<SocketAddr>`.
///
/// When that extension is absent (e.g., in tests), this falls back to `127.0.0.1`.
/// NOTE(review): with the loopback fallback, a server that is not served with connect-info
/// makes every request look local. Such requests may then satisfy loopback entries in
/// `allowed_source_cidrs` / `trusted_source_cidrs`. Confirm connect-info is always installed.
fn resolve_peer_ip(parts: &Parts) -> IpAddr {
    use axum::extract::ConnectInfo;
    use std::net::{Ipv4Addr, SocketAddr};

    match parts.extensions.get::<ConnectInfo<SocketAddr>>() {
        Some(ConnectInfo(addr)) => addr.ip(),
        None => IpAddr::V4(Ipv4Addr::LOCALHOST),
    }
}

/// Resolve the effective client IP, honouring `X-Forwarded-For` only when the socket peer is a
/// trusted proxy.
///
/// Each proxy *appends* the address it received the connection from to `X-Forwarded-For`. Any
/// entry left of the nearest trusted hop was supplied by the client and can be forged, so the
/// leftmost value must not be trusted.
///
/// This function instead walks the hop list right-to-left, across all `X-Forwarded-For` header
/// lines in order. Hops that are themselves in `trusted_cidrs` are skipped; the first untrusted
/// hop is returned as the client.
///
/// Edge cases:
/// - If an entry is not a valid IP address (or a header value is not visible-ASCII), the walk
///   stops and returns the last hop that was vouched for by a trusted party.
/// - If every hop is trusted, the leftmost hop is returned.
/// - `peer_ip` is returned unchanged when the peer is not trusted or no usable header exists.
fn resolve_client_ip(parts: &Parts, peer_ip: IpAddr, trusted_cidrs: &[String]) -> IpAddr {
    if !ip_in_cidrs(peer_ip, trusted_cidrs) {
        return peer_ip;
    }

    // Later header lines were added by proxies closer to us, so iterate lines in reverse and
    // each line's comma-separated entries in reverse.
    let hops_right_to_left = parts
        .headers
        .get_all("x-forwarded-for")
        .iter()
        .rev()
        // A non-visible-ASCII value becomes an unparsable empty hop, which ends the walk.
        .flat_map(|value| value.to_str().unwrap_or("").rsplit(','));

    let mut client_ip = peer_ip;
    for hop in hops_right_to_left {
        let Ok(hop_ip) = hop.trim().parse::<IpAddr>() else {
            break;
        };
        // `hop_ip` was reported by a trusted party (the peer or a trusted hop to its right).
        client_ip = hop_ip;
        if !ip_in_cidrs(hop_ip, trusted_cidrs) {
            // First untrusted address: this is the real client; anything further left is
            // client-controlled.
            break;
        }
    }
    client_ip
}

// ---------------------------------------------------------------------------
// CIDR helpers
// ---------------------------------------------------------------------------

/// Return `true` if `ip` falls inside any of the CIDR strings in `cidrs`.
///
/// Entries that fail to parse as an [`IpNet`] are skipped and never match. An empty list
/// matches nothing.
///
/// An IPv4-mapped IPv6 address (`::ffff:a.b.c.d`, as reported for IPv4 clients by dual-stack
/// `[::]` listeners) is also tested in its canonical IPv4 form, so IPv4 CIDRs still apply to
/// it. The original form is still tested as well, so IPv6 CIDRs keep matching as before.
fn ip_in_cidrs(ip: IpAddr, cidrs: &[String]) -> bool {
    let canonical = ip.to_canonical();
    cidrs
        .iter()
        .filter_map(|s| s.parse::<IpNet>().ok())
        .any(|net| net.contains(&ip) || net.contains(&canonical))
}

// ---------------------------------------------------------------------------
// Error responses
// ---------------------------------------------------------------------------

/// Build the JSON rejection shared by the auth failure responses.
///
/// The body has the shape `{"status": "error", "code": <code>, "message": <message>}`.
fn error_response(
    status: StatusCode,
    code: &str,
    message: &str,
) -> (StatusCode, axum::Json<serde_json::Value>) {
    let body = serde_json::json!({
        "status": "error",
        "code": code,
        "message": message
    });
    (status, axum::Json(body))
}

/// 401 rejection: no usable credential was presented.
fn unauthorized() -> (StatusCode, axum::Json<serde_json::Value>) {
    error_response(
        StatusCode::UNAUTHORIZED,
        "unauthorized",
        "Authentication required",
    )
}

/// 403 rejection: the request is not permitted (source IP or key).
fn forbidden() -> (StatusCode, axum::Json<serde_json::Value>) {
    error_response(StatusCode::FORBIDDEN, "forbidden", "Access denied")
}

// ---------------------------------------------------------------------------
// Unit tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    // Unit tests for key matching, constant-time comparison and CIDR membership.
    use super::*;
    use crate::config::{ApiKeyConfig, SecretString};

    fn make_key(id: &str, secret: &str, enabled: bool) -> ApiKeyConfig {
        ApiKeyConfig {
            id: id.to_string(),
            secret: SecretString::new(secret),
            enabled,
            description: None,
            allowed_recipient_domains: vec![],
            rate_limit_per_min: None,
        }
    }

    #[test]
    fn matching_key_returns_key_id() {
        let keys = vec![make_key("svc-a", "secret-a", true)];
        match find_matching_key(&keys, "secret-a") {
            MatchResult::Matched(id, _) => assert_eq!(id, "svc-a"),
            _ => panic!("expected Matched"),
        }
    }

    #[test]
    fn wrong_token_returns_not_found() {
        let keys = vec![make_key("svc-a", "secret-a", true)];
        assert!(matches!(
            find_matching_key(&keys, "wrong"),
            MatchResult::NotFound
        ));
    }

    #[test]
    fn same_length_wrong_token_returns_not_found() {
        // Exercises the constant-time content comparison, not the length short-cut.
        let keys = vec![make_key("svc-a", "secret-a", true)];
        assert!(matches!(
            find_matching_key(&keys, "secret-b"),
            MatchResult::NotFound
        ));
    }

    #[test]
    fn empty_key_list_returns_not_found() {
        let keys: Vec<ApiKeyConfig> = Vec::new();
        assert!(matches!(
            find_matching_key(&keys, "anything"),
            MatchResult::NotFound
        ));
    }

    #[test]
    fn disabled_key_returns_disabled() {
        let keys = vec![make_key("svc-a", "secret-a", false)];
        assert!(matches!(
            find_matching_key(&keys, "secret-a"),
            MatchResult::Disabled(_)
        ));
    }

    #[test]
    fn disabled_key_reports_its_id() {
        let keys = vec![
            make_key("svc-on", "token-on", true),
            make_key("svc-off", "token-off", false),
        ];
        match find_matching_key(&keys, "token-off") {
            MatchResult::Disabled(id) => assert_eq!(id, "svc-off"),
            _ => panic!("expected Disabled for svc-off"),
        }
    }

    #[test]
    fn multiple_keys_correct_one_matches() {
        let keys = vec![
            make_key("svc-a", "token-aaa", true),
            make_key("svc-b", "token-bbb", true),
        ];
        match find_matching_key(&keys, "token-bbb") {
            MatchResult::Matched(id, _) => assert_eq!(id, "svc-b"),
            _ => panic!("expected Matched for svc-b"),
        }
    }

    #[test]
    fn per_key_rate_limit_is_propagated() {
        let mut key = make_key("svc-a", "secret-a", true);
        key.rate_limit_per_min = Some(30);
        let keys = vec![key];
        match find_matching_key(&keys, "secret-a") {
            MatchResult::Matched(_, limit) => assert_eq!(limit, Some(30)),
            _ => panic!("expected Matched"),
        }
    }

    #[test]
    fn ip_in_cidrs_loopback() {
        let cidrs = vec!["127.0.0.1/32".to_string()];
        assert!(ip_in_cidrs("127.0.0.1".parse().unwrap(), &cidrs));
        assert!(!ip_in_cidrs("10.0.0.1".parse().unwrap(), &cidrs));
    }

    #[test]
    fn ip_in_cidrs_range() {
        let cidrs = vec!["10.0.0.0/8".to_string()];
        assert!(ip_in_cidrs("10.1.2.3".parse().unwrap(), &cidrs));
        assert!(!ip_in_cidrs("192.168.1.1".parse().unwrap(), &cidrs));
    }

    #[test]
    fn ip_in_cidrs_ipv6() {
        let cidrs = vec!["::1/128".to_string(), "fd00::/8".to_string()];
        assert!(ip_in_cidrs("::1".parse().unwrap(), &cidrs));
        assert!(ip_in_cidrs("fd12::34".parse().unwrap(), &cidrs));
        assert!(!ip_in_cidrs("2001:db8::1".parse().unwrap(), &cidrs));
    }

    #[test]
    fn invalid_cidr_entries_are_ignored() {
        let mixed = vec!["not-a-cidr".to_string(), "10.0.0.0/8".to_string()];
        assert!(ip_in_cidrs("10.1.1.1".parse().unwrap(), &mixed));
        let only_invalid = vec!["garbage".to_string()];
        assert!(!ip_in_cidrs("10.1.1.1".parse().unwrap(), &only_invalid));
    }

    #[test]
    fn empty_cidr_list_returns_false() {
        assert!(!ip_in_cidrs("127.0.0.1".parse().unwrap(), &[]));
    }

    #[test]
    fn different_length_tokens_do_not_match() {
        assert!(!ct_eq_bytes(b"short", b"longer-token"));
    }

    #[test]
    fn equal_tokens_match() {
        assert!(ct_eq_bytes(b"token-123", b"token-123"));
        assert!(!ct_eq_bytes(b"token-123", b"token-124"));
    }
}