use crate::api::headers::request_context::disallow_jdbc_routing;
use actix_web::HttpRequest;
use std::net::IpAddr;
use std::str::FromStr;
use tokio_postgres::Config as PgConfig;
use tokio_postgres::config::Host;
/// Returns the trimmed, non-empty value of the `X-JDBC-URL` request header.
///
/// Returns `None` when:
/// * `disallow_jdbc_routing(req)` reports that JDBC routing is not permitted for this request,
/// * the header is absent,
/// * the header value is not valid text according to `HeaderValue::to_str`, or
/// * the value is blank after trimming surrounding whitespace.
pub fn x_jdbc_url(req: &HttpRequest) -> Option<String> {
    if disallow_jdbc_routing(req) {
        return None;
    }
    // Header-name lookup in actix-web is case-insensitive (names are normalised to lowercase),
    // so this single `get` also finds `x-jdbc-url`, `X-Jdbc-Url`, etc.
    req.headers()
        .get("X-JDBC-URL")
        .and_then(|value| value.to_str().ok())
        .map(str::trim)
        // Filter before allocating so blank headers never produce a `String`.
        .filter(|value| !value.is_empty())
        .map(str::to_owned)
}
/// Converts a PostgreSQL JDBC URL into a native `postgres://` connection URL.
///
/// * `jdbc:postgresql://…` has only its leading scheme rewritten to `postgres://`.
/// * `postgres://…` and `postgresql://…` are returned unchanged.
/// * Any other input yields `None`.
///
/// Surrounding whitespace is trimmed before matching and is never part of the result.
pub fn jdbc_to_postgres_url(jdbc_url: &str) -> Option<String> {
    const JDBC_PREFIX: &str = "jdbc:postgresql://";
    const NATIVE_SCHEMES: [&str; 2] = ["postgres://", "postgresql://"];

    let url = jdbc_url.trim();
    match url.strip_prefix(JDBC_PREFIX) {
        Some(rest) => Some(format!("postgres://{}", rest)),
        None => NATIVE_SCHEMES
            .iter()
            .any(|scheme| url.starts_with(scheme))
            .then(|| url.to_owned()),
    }
}
/// Returns `true` when `host` is permitted by `allowed_hosts`.
///
/// An empty allowlist permits every host. Otherwise a host matches an entry when it equals
/// the entry or is a subdomain of it: entry `example.com` admits `db.example.com` but not
/// `badexample.com`.
///
/// Both the host and each entry are compared case-insensitively, after trimming whitespace
/// and trailing dots, so `example.com.` is the same FQDN as `example.com`. Entries that are
/// blank after normalisation never match anything.
fn host_matches_allowlist(host: &str, allowed_hosts: &[String]) -> bool {
    if allowed_hosts.is_empty() {
        return true;
    }
    let host = normalize_host_name(host);
    allowed_hosts.iter().any(|entry| {
        let entry = normalize_host_name(entry);
        // A blank entry would otherwise match hosts ending in "." — skip it instead.
        !entry.is_empty()
            && (host == entry
                || host
                    .strip_suffix(entry.as_str())
                    // Require a label boundary so "badexample.com" does not match "example.com".
                    .is_some_and(|prefix| prefix.ends_with('.')))
    })
}

/// Lower-cases a host name and strips surrounding whitespace and trailing dots so that
/// equivalent spellings of the same name compare equal.
fn normalize_host_name(name: &str) -> String {
    name.trim().trim_end_matches('.').to_ascii_lowercase()
}
/// Returns `true` when `host` refers to a loopback, private, link-local, or otherwise
/// non-public destination. Unix-socket hosts always count as local.
///
/// The check is purely syntactic. Apart from `localhost` and `*.localhost`, host *names* are
/// not resolved here. A public-looking name whose DNS records point at a private address is
/// therefore not detected.
fn host_is_private_or_local(host: &Host) -> bool {
    match host {
        Host::Tcp(hostname) => tcp_host_is_private_or_local(hostname),
        #[cfg(unix)]
        Host::Unix(_) => true,
    }
}

/// Classifies a TCP host string, which may be a DNS name or an IP literal.
fn tcp_host_is_private_or_local(hostname: &str) -> bool {
    // Drop trailing dots so that "localhost." or "127.0.0.1." cannot slip past the
    // comparisons below.
    let normalized = hostname.trim().trim_end_matches('.').to_ascii_lowercase();
    if normalized == "localhost" || normalized.ends_with(".localhost") {
        return true;
    }
    // IPv6 literals may carry a zone id ("fe80::1%eth0"), which `IpAddr::from_str` rejects.
    // Classify the address part instead of falling through to "public name".
    let address = normalized
        .split_once('%')
        .map_or(normalized.as_str(), |(address, _zone)| address);
    match IpAddr::from_str(address) {
        Ok(ip) => ip_is_private_or_local(ip),
        // Not a canonical IP literal. `IpAddr::from_str` rejects legacy numeric spellings such
        // as "127.1", "2130706433", "0x7f.0.0.1" or "0177.0.0.1", but the system resolver may
        // still accept them. Fail closed on anything shaped like one.
        Err(_) => looks_like_legacy_ipv4(address),
    }
}

/// Returns `true` for IP addresses that are not public unicast destinations.
fn ip_is_private_or_local(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => ipv4_is_private_or_local(v4),
        IpAddr::V6(v6) => {
            v6.is_loopback()
                || v6.is_unspecified()
                || v6.is_multicast()
                || v6.is_unique_local()
                || v6.is_unicast_link_local()
                // IPv4-mapped / IPv4-compatible forms (e.g. "::ffff:127.0.0.1") embed an IPv4
                // destination. Judge them by that address.
                || v6.to_ipv4().is_some_and(ipv4_is_private_or_local)
        }
    }
}

/// Returns `true` for IPv4 addresses that are not public unicast destinations.
fn ipv4_is_private_or_local(ip: std::net::Ipv4Addr) -> bool {
    let [first, second, ..] = ip.octets();
    ip.is_loopback()
        || ip.is_private()
        || ip.is_link_local()
        || ip.is_unspecified()
        || ip.is_multicast()
        || ip.is_broadcast()
        // 0.0.0.0/8 ("this network") is never a valid remote destination.
        || first == 0
        // 100.64.0.0/10 shared address space (RFC 6598, carrier-grade NAT).
        || (first == 100 && (64..=127).contains(&second))
}

/// Returns `true` for strings shaped like an `inet_aton`-style IPv4 address: 1–4
/// dot-separated parts, each either decimal digits or a `0x` hexadecimal number.
fn looks_like_legacy_ipv4(candidate: &str) -> bool {
    candidate.split('.').count() <= 4
        && candidate.split('.').all(|part| match part.strip_prefix("0x") {
            Some(hex) => hex.chars().all(|c| c.is_ascii_hexdigit()),
            None => !part.is_empty() && part.chars().all(|c| c.is_ascii_digit()),
        })
}
/// Checks that a PostgreSQL connection string taken from `X-JDBC-URL` targets an
/// acceptable server, before any connection is attempted.
///
/// * `postgres_url` – connection string parsed with `PgConfig::from_str`.
/// * `allow_private_hosts` – when `false`, hosts classified by `host_is_private_or_local`
///   are rejected.
/// * `allowed_hosts` – host allowlist passed to `host_matches_allowlist`. An empty slice
///   allows any host.
///
/// # Errors
///
/// Returns a human-readable message when the string:
/// * cannot be parsed,
/// * names no host,
/// * pins the connection to explicit `hostaddr` IPs,
/// * names a host outside the allowlist,
/// * names a private/local host while those are blocked, or
/// * names a Unix-socket host.
pub fn validate_postgres_target(
    postgres_url: &str,
    allow_private_hosts: bool,
    allowed_hosts: &[String],
) -> Result<(), String> {
    // The underlying parser error is not surfaced to the caller.
    let parsed = PgConfig::from_str(postgres_url).map_err(|_| {
        "X-JDBC-URL could not be parsed as a PostgreSQL connection string".to_string()
    })?;
    let hosts = parsed.get_hosts();
    if hosts.is_empty() {
        return Err("X-JDBC-URL must include a host".to_string());
    }
    // `hostaddr` makes the driver connect to the given IPs instead of resolving `host`. An
    // allowlisted host name could then front for any (e.g. internal) address, and neither
    // check below would see it. Refuse it outright.
    if !parsed.get_hostaddrs().is_empty() {
        return Err("X-JDBC-URL hostaddr overrides are not allowed".to_string());
    }
    for host in hosts {
        match host {
            Host::Tcp(hostname) => {
                if !host_matches_allowlist(hostname, allowed_hosts) {
                    return Err(
                        "X-JDBC-URL host is not in the configured JDBC host allowlist".to_string(),
                    );
                }
                if !allow_private_hosts && host_is_private_or_local(host) {
                    return Err(
                        "X-JDBC-URL private or local hosts are blocked by policy".to_string(),
                    );
                }
            }
            #[cfg(unix)]
            Host::Unix(_) => {
                return Err("X-JDBC-URL unix socket hosts are not allowed".to_string());
            }
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_jdbc_to_postgres_url() {
        assert_eq!(
            jdbc_to_postgres_url("jdbc:postgresql://localhost:5432/mydb"),
            Some("postgres://localhost:5432/mydb".to_string())
        );
        assert_eq!(
            jdbc_to_postgres_url("jdbc:postgresql://user:pass@host/db"),
            Some("postgres://user:pass@host/db".to_string())
        );
        assert_eq!(
            jdbc_to_postgres_url("postgres://localhost/db"),
            Some("postgres://localhost/db".to_string())
        );
        assert_eq!(jdbc_to_postgres_url("jdbc:mysql://localhost/db"), None);
    }

    #[test]
    fn test_jdbc_to_postgres_url_trims_and_keeps_native_schemes() {
        assert_eq!(
            jdbc_to_postgres_url("  jdbc:postgresql://db.example.com/app  "),
            Some("postgres://db.example.com/app".to_string())
        );
        assert_eq!(
            jdbc_to_postgres_url("postgresql://db.example.com/app"),
            Some("postgresql://db.example.com/app".to_string())
        );
        assert_eq!(jdbc_to_postgres_url(""), None);
    }

    #[test]
    fn test_validate_postgres_target_rejects_local_when_private_blocked() {
        let result: Result<(), String> =
            validate_postgres_target("postgres://localhost:5432/test", false, &[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_postgres_target_rejects_private_ip_literals_when_blocked() {
        for url in [
            "postgres://10.0.0.5:5432/test",
            "postgres://[::1]:5432/test",
            "postgres://169.254.169.254/test",
        ] {
            assert!(
                validate_postgres_target(url, false, &[]).is_err(),
                "{} should be rejected",
                url
            );
        }
    }

    #[test]
    fn test_validate_postgres_target_allows_private_when_permitted() {
        assert!(validate_postgres_target("postgres://10.0.0.5:5432/test", true, &[]).is_ok());
    }

    #[test]
    fn test_validate_postgres_target_checks_every_host() {
        let result = validate_postgres_target("postgres://db.example.com,10.0.0.5/test", false, &[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_postgres_target_allows_allowlisted_host() {
        let allowed: Vec<String> = vec!["db.example.com".to_string()];
        let result = validate_postgres_target(
            "postgres://user:pass@db.example.com:5432/test",
            false,
            &allowed,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_postgres_target_allowlist_matches_subdomains_only_on_label_boundary() {
        let allowed: Vec<String> = vec!["example.com".to_string()];
        assert!(validate_postgres_target("postgres://db.example.com/test", true, &allowed).is_ok());
        assert!(
            validate_postgres_target("postgres://badexample.com/test", true, &allowed).is_err()
        );
    }

    #[test]
    fn test_validate_postgres_target_rejects_non_allowlisted_host() {
        let allowed: Vec<String> = vec!["db.example.com".to_string()];
        let result = validate_postgres_target(
            "postgres://user:pass@other.example.com:5432/test",
            true,
            &allowed,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_postgres_target_rejects_unparsable_input() {
        assert!(validate_postgres_target("not a connection string", true, &[]).is_err());
    }
}