athena_rs 3.3.0

Database gateway API
Documentation
//! Retrieves and validates the `X-JDBC-URL` header for direct Postgres connections.
//!
//! When present, the gateway can connect directly to a database instead of using
//! a pre-configured `X-Athena-Client`. JDBC-style URLs (`jdbc:postgresql://...`) are
//! converted to the `postgres://` format used by sqlx, and the resulting target can be
//! checked against a host allowlist and a private/local-host policy before it is used.

use crate::api::headers::request_context::disallow_jdbc_routing;
use actix_web::HttpRequest;
use std::net::IpAddr;
use std::str::FromStr;
use tokio_postgres::Config as PgConfig;
use tokio_postgres::config::Host;

/// Extracts the JDBC URL from the `X-JDBC-URL` request header, if present.
///
/// Header-name lookup on actix-web requests is case-insensitive, so `X-JDBC-URL`,
/// `x-jdbc-url` and any other casing all resolve to the same header with one lookup.
///
/// Returns `None` when:
/// - `disallow_jdbc_routing(req)` reports that JDBC routing is disabled for this request;
/// - the header is absent;
/// - the header value cannot be converted to a string with `to_str`;
/// - the value is empty after trimming surrounding whitespace.
///
/// Otherwise returns the trimmed value as an owned `String`.
pub fn x_jdbc_url(req: &HttpRequest) -> Option<String> {
    if disallow_jdbc_routing(req) {
        return None;
    }

    req.headers()
        .get("X-JDBC-URL")
        .and_then(|value| value.to_str().ok())
        .map(str::trim)
        // Filter before allocating so an empty header costs nothing.
        .filter(|value| !value.is_empty())
        .map(str::to_owned)
}

/// Normalizes a PostgreSQL connection URL into the `postgres://` form expected by sqlx.
///
/// Surrounding whitespace is ignored. Accepted inputs:
/// - `jdbc:postgresql://host:port/database` → `postgres://host:port/database`
/// - `jdbc:postgresql://user:pass@host:port/database` → `postgres://user:pass@host:port/database`
/// - `postgres://…` and `postgresql://…` are returned as-is (after trimming).
///
/// Everything after the scheme, including query parameters such as `user`, `password`
/// or `sslmode`, is carried over verbatim. Any other scheme yields `None`.
pub fn jdbc_to_postgres_url(jdbc_url: &str) -> Option<String> {
    const JDBC_PREFIX: &str = "jdbc:postgresql://";
    let url = jdbc_url.trim();

    match url.strip_prefix(JDBC_PREFIX) {
        Some(rest) => Some(format!("postgres://{rest}")),
        None => {
            let is_native = ["postgres://", "postgresql://"]
                .iter()
                .any(|scheme| url.starts_with(*scheme));
            is_native.then(|| url.to_owned())
        }
    }
}

/// Returns `true` when `host` is permitted by the JDBC host allowlist.
///
/// An empty allowlist permits every host. Otherwise `host` matches an entry when it
/// equals the entry or is a subdomain of it: `replica.db.example.com` matches
/// `db.example.com`, but `evildb.example.com` does not.
///
/// Both sides are trimmed and compared ASCII-case-insensitively, so entries such as
/// `" DB.Example.com "` still match. Blank entries never match anything.
fn host_matches_allowlist(host: &str, allowed_hosts: &[String]) -> bool {
    if allowed_hosts.is_empty() {
        return true;
    }
    let normalized = host.trim().to_ascii_lowercase();
    allowed_hosts.iter().any(|entry| {
        let entry = entry.trim().to_ascii_lowercase();
        // A blank entry would otherwise accept any host ending in '.' through the
        // subdomain check below.
        if entry.is_empty() {
            return false;
        }
        normalized == entry
            || matches!(
                normalized.strip_suffix(entry.as_str()),
                Some(prefix) if prefix.ends_with('.')
            )
    })
}

/// Returns `true` when `host` points at a loopback, private, link-local or otherwise
/// non-public destination that `X-JDBC-URL` must not reach unless private hosts are allowed.
///
/// Unix-socket hosts are always treated as local. A TCP host is first normalized: it is
/// trimmed, ASCII-lowercased, and has one trailing `.` and any IPv6 `%zone` suffix removed.
/// It is then treated as local when it is:
/// - `localhost` or any `*.localhost` name;
/// - an IPv4 or IPv6 literal in a blocked range (see [`ipv4_is_private_or_local`] and
///   [`ipv6_is_private_or_local`]), including IPv4-mapped IPv6 such as `::ffff:127.0.0.1`;
/// - a name whose last label starts with a digit, e.g. `127.1`, `2130706433` or
///   `0x7f000001`. Public top-level domains never begin with a digit, but system resolvers
///   commonly accept these legacy `inet_aton`-style forms as IPv4 addresses. Such names
///   therefore fail closed instead of slipping past the literal-address checks.
///
/// Limitation: DNS names are not resolved here, so a public-looking name that resolves to
/// a private address is not detected by this function.
fn host_is_private_or_local(host: &Host) -> bool {
    match host {
        Host::Tcp(hostname) => tcp_hostname_is_private_or_local(hostname),
        #[cfg(unix)]
        Host::Unix(_) => true,
    }
}

/// Classifies a TCP host name or IP literal as described on [`host_is_private_or_local`].
fn tcp_hostname_is_private_or_local(hostname: &str) -> bool {
    let lowered = hostname.trim().to_ascii_lowercase();
    // `name.` and `name` denote the same fully-qualified host.
    let name = lowered.strip_suffix('.').unwrap_or(lowered.as_str());

    if name == "localhost" || name.ends_with(".localhost") {
        return true;
    }

    // Scoped IPv6 literals such as `fe80::1%eth0` do not parse as `IpAddr`. The zone only
    // selects an interface, so classify the address part.
    let address = name.split_once('%').map_or(name, |(addr, _zone)| addr);

    match IpAddr::from_str(address) {
        Ok(IpAddr::V4(ip)) => ipv4_is_private_or_local(ip),
        Ok(IpAddr::V6(ip)) => ipv6_is_private_or_local(ip),
        Err(_) => last_label_starts_with_digit(name),
    }
}

/// Returns `true` for IPv4 addresses that are not valid public destinations.
fn ipv4_is_private_or_local(ip: std::net::Ipv4Addr) -> bool {
    let [first, second, ..] = ip.octets();
    ip.is_loopback()
        || ip.is_private()
        || ip.is_link_local()
        || ip.is_unspecified()
        || ip.is_multicast()
        || ip.is_broadcast()
        // 0.0.0.0/8: "this network" (RFC 1122), never a legitimate remote target.
        || first == 0
        // 100.64.0.0/10: shared address space (RFC 6598), not globally routable.
        || (first == 100 && (second & 0b1100_0000) == 0b0100_0000)
}

/// Returns `true` for IPv6 addresses that are not valid public destinations.
fn ipv6_is_private_or_local(ip: std::net::Ipv6Addr) -> bool {
    // `::ffff:a.b.c.d` reaches the embedded IPv4 address, so apply the IPv4 rules to it.
    if let Some(v4) = ip.to_ipv4_mapped() {
        return ipv4_is_private_or_local(v4);
    }
    ip.is_loopback()
        || ip.is_unspecified()
        || ip.is_multicast()
        || ip.is_unique_local()
        || ip.is_unicast_link_local()
        // fec0::/10: deprecated site-local unicast (RFC 3879), still private in practice.
        || (ip.segments()[0] & 0xffc0) == 0xfec0
}

/// Returns `true` when the final dot-separated label of `name` starts with an ASCII digit.
fn last_label_starts_with_digit(name: &str) -> bool {
    name.rsplit('.')
        .next()
        .and_then(|label| label.chars().next())
        .is_some_and(|first| first.is_ascii_digit())
}

/// Checks that a PostgreSQL connection string only targets hosts permitted by policy.
///
/// # Arguments
/// - `postgres_url`: connection string, typically the output of [`jdbc_to_postgres_url`].
/// - `allow_private_hosts`: when `false`, hosts classified as private or local by
///   `host_is_private_or_local` are rejected.
/// - `allowed_hosts`: host allowlist; an empty slice permits any host.
///
/// Every host in the string is checked, and so is every `hostaddr` entry. A client that
/// honors `hostaddr` dials that IP address instead of resolving the host name, so leaving
/// it unchecked would bypass both the allowlist and the private-host policy.
///
/// URL-form strings whose user-info/host split is ambiguous are rejected. So are host
/// names containing characters that cannot appear in a host name or IP literal. This
/// keeps the host validated here identical to the host a standard URL parser would
/// connect to.
///
/// # Errors
/// Returns a human-readable message when the string:
/// - cannot be parsed;
/// - has an ambiguous authority;
/// - names no host;
/// - uses a Unix-socket host;
/// - targets a host or address rejected by the sanity check, the allowlist or the
///   private-host policy.
///
/// The underlying parse error is not included in the message, since the input may carry
/// credentials.
pub fn validate_postgres_target(
    postgres_url: &str,
    allow_private_hosts: bool,
    allowed_hosts: &[String],
) -> Result<(), String> {
    let parsed = PgConfig::from_str(postgres_url).map_err(|_| {
        "X-JDBC-URL could not be parsed as a PostgreSQL connection string".to_string()
    })?;

    if !url_authority_is_unambiguous(postgres_url) {
        return Err("X-JDBC-URL user-info/host section is ambiguous".to_string());
    }

    let hosts = parsed.get_hosts();
    if hosts.is_empty() {
        return Err("X-JDBC-URL must include a host".to_string());
    }

    for host in hosts {
        match host {
            Host::Tcp(hostname) => {
                check_tcp_target(hostname, allow_private_hosts, allowed_hosts)?;
            }
            #[cfg(unix)]
            Host::Unix(_) => {
                return Err("X-JDBC-URL unix socket hosts are not allowed".to_string());
            }
        }
    }

    // `hostaddr` replaces DNS resolution of the hosts checked above, so the addresses
    // actually dialed must pass the same policy.
    for addr in parsed.get_hostaddrs() {
        check_tcp_target(&addr.to_string(), allow_private_hosts, allowed_hosts)?;
    }

    Ok(())
}

/// Applies the host-name sanity check, the allowlist and the private-host policy to a
/// single TCP host name or IP literal.
fn check_tcp_target(
    hostname: &str,
    allow_private_hosts: bool,
    allowed_hosts: &[String],
) -> Result<(), String> {
    // Host names and IP literals (including an IPv6 `%zone`) only use these characters.
    // Anything else, e.g. `@`, `#` or `/`, means the string was split differently than a
    // URL parser would split it.
    let well_formed = !hostname.is_empty()
        && hostname
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || matches!(c, '.' | '-' | '_' | ':' | '%'));
    if !well_formed {
        return Err(
            "X-JDBC-URL host contains characters that are not valid in a host name".to_string(),
        );
    }

    if !host_matches_allowlist(hostname, allowed_hosts) {
        return Err("X-JDBC-URL host is not in the configured JDBC host allowlist".to_string());
    }

    if !allow_private_hosts && host_is_private_or_local(&Host::Tcp(hostname.to_owned())) {
        return Err("X-JDBC-URL private or local hosts are blocked by policy".to_string());
    }

    Ok(())
}

/// Returns `false` for URL-form strings whose user-info/host split differs between
/// tokio-postgres (used for validation here) and WHATWG-style URL parsers such as the
/// `url` crate.
///
/// NOTE(review): the module docs say the `postgres://` form is consumed by sqlx, which
/// presumably parses it with such a URL parser — confirm against the connection code.
///
/// tokio-postgres treats everything before the first `@` anywhere in the string as
/// credentials. URL parsers instead end the authority at the first `/`, `?` or `#` and
/// split credentials at the last `@` inside it. The two agree only when the authority
/// holds at most one `@`, and, if it holds none, no `@` appears later in the string.
/// Key/value connection strings have no authority and are accepted.
fn url_authority_is_unambiguous(postgres_url: &str) -> bool {
    let Some(rest) = ["postgres://", "postgresql://"]
        .iter()
        .find_map(|scheme| postgres_url.strip_prefix(*scheme))
    else {
        return true;
    };

    let authority_end = rest
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(rest.len());
    let (authority, tail) = rest.split_at(authority_end);

    match authority.matches('@').count() {
        0 => !tail.contains('@'),
        1 => true,
        _ => false,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_jdbc_to_postgres_url() {
        assert_eq!(
            jdbc_to_postgres_url("jdbc:postgresql://localhost:5432/mydb"),
            Some("postgres://localhost:5432/mydb".to_string())
        );
        assert_eq!(
            jdbc_to_postgres_url("jdbc:postgresql://user:pass@host/db"),
            Some("postgres://user:pass@host/db".to_string())
        );
        assert_eq!(
            jdbc_to_postgres_url("postgres://localhost/db"),
            Some("postgres://localhost/db".to_string())
        );
        assert_eq!(jdbc_to_postgres_url("jdbc:mysql://localhost/db"), None);
    }

    #[test]
    fn test_jdbc_to_postgres_url_trims_and_passes_through_native_schemes() {
        assert_eq!(
            jdbc_to_postgres_url("  jdbc:postgresql://h:5432/db?sslmode=require  "),
            Some("postgres://h:5432/db?sslmode=require".to_string())
        );
        assert_eq!(
            jdbc_to_postgres_url("postgresql://h/db"),
            Some("postgresql://h/db".to_string())
        );
        assert_eq!(jdbc_to_postgres_url(""), None);
    }

    #[test]
    fn test_validate_postgres_target_rejects_local_when_private_blocked() {
        let result: Result<(), String> =
            validate_postgres_target("postgres://localhost:5432/test", false, &[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_postgres_target_allows_local_when_private_allowed() {
        let result = validate_postgres_target("postgres://localhost:5432/test", true, &[]);
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_postgres_target_rejects_private_ip_literals() {
        let urls = [
            "postgres://10.0.0.5:5432/test",
            "postgres://127.0.0.1/test",
            "postgres://169.254.169.254/test",
            "postgres://[::1]:5432/test",
        ];
        for url in urls.iter() {
            assert!(
                validate_postgres_target(url, false, &[]).is_err(),
                "{url} should be rejected"
            );
        }
    }

    #[test]
    fn test_validate_postgres_target_allows_allowlisted_host() {
        let allowed: Vec<String> = vec!["db.example.com".to_string()];
        let result = validate_postgres_target(
            "postgres://user:pass@db.example.com:5432/test",
            false,
            &allowed,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_postgres_target_allowlist_matches_subdomains_not_lookalikes() {
        let allowed: Vec<String> = vec!["db.example.com".to_string()];
        assert!(
            validate_postgres_target("postgres://replica.db.example.com/test", true, &allowed)
                .is_ok()
        );
        assert!(
            validate_postgres_target("postgres://evildb.example.com/test", true, &allowed)
                .is_err()
        );
    }

    #[test]
    fn test_validate_postgres_target_rejects_non_allowlisted_host() {
        let allowed: Vec<String> = vec!["db.example.com".to_string()];
        let result = validate_postgres_target(
            "postgres://user:pass@other.example.com:5432/test",
            true,
            &allowed,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_postgres_target_rejects_unparsable_input() {
        assert!(validate_postgres_target("not a connection string", true, &[]).is_err());
    }

    #[cfg(unix)]
    #[test]
    fn test_validate_postgres_target_rejects_unix_sockets() {
        let result = validate_postgres_target("host=/var/run/postgresql dbname=test", true, &[]);
        assert!(result.is_err());
    }
}