use serde::Serialize;
use std::collections::BTreeSet;
use std::env;
use std::net::IpAddr;
use tokio::net::lookup_host;
/// Default managed cluster name.
/// NOTE(review): not referenced in the visible part of this file; presumably
/// consumed by other modules — confirm.
pub const DEFAULT_MANAGED_CLUSTER_NAME: &str = "athena-cluster";
/// Default managed base domain. This is the value
/// `managed_base_domain_from_pattern` yields for `DEFAULT_WILDCARD_HOST_PATTERN`.
/// NOTE(review): not referenced in the visible part of this file — confirm usage.
pub const DEFAULT_MANAGED_BASE_DOMAIN: &str = "athena-cluster.com";
/// Wildcard host pattern used by `wildcard_host_pattern_from_env` when the
/// `ATHENA_WILDCARD_HOST_PATTERN` environment variable is unset, not valid
/// Unicode, or blank.
pub const DEFAULT_WILDCARD_HOST_PATTERN: &str = "*.v3.athena-cluster.com";
/// Default base URL for the wildcard target.
/// NOTE(review): not referenced in the visible part of this file; presumably the
/// upstream that wildcard hosts are routed to — confirm.
pub const DEFAULT_WILDCARD_TARGET_BASE_URL: &str = "https://pool.athena-cluster.com";
/// Name of the environment variable that overrides `DEFAULT_WILDCARD_HOST_PATTERN`.
const ENV_WILDCARD_HOST_PATTERN: &str = "ATHENA_WILDCARD_HOST_PATTERN";
/// A single IP address returned by a DNS lookup.
///
/// The derived ordering compares `address` first and `family` second, both as
/// plain strings.
#[derive(Debug, Clone, Serialize, PartialEq, Eq, PartialOrd, Ord)]
pub struct DnsLookupRecord {
/// Textual form of the resolved IP address.
pub address: String,
/// Address family of `address`; `check_dns_host` fills this with `"ipv4"` or `"ipv6"`.
pub family: String,
}
/// Result of checking whether a host resolves, as produced by `check_dns_host`.
#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
pub struct DnsLookupOutcome {
/// The host that was looked up after normalization, or the trimmed raw input
/// when it could not be normalized.
pub host: String,
/// `true` when the lookup produced at least one address.
pub resolves: bool,
/// Distinct addresses returned by the lookup; empty on failure.
pub records: Vec<DnsLookupRecord>,
/// Human-readable failure reason; `None` when the lookup itself succeeded.
pub error: Option<String>,
}
/// Returns the wildcard host pattern configured for this process.
///
/// Reads the `ATHENA_WILDCARD_HOST_PATTERN` environment variable and returns its
/// value with surrounding whitespace removed. Falls back to
/// `DEFAULT_WILDCARD_HOST_PATTERN` when the variable is missing, is not valid
/// Unicode, or is blank. The returned pattern is not validated here.
pub fn wildcard_host_pattern_from_env() -> String {
    match env::var(ENV_WILDCARD_HOST_PATTERN) {
        Ok(raw) if !raw.trim().is_empty() => raw.trim().to_string(),
        // Missing, non-Unicode and whitespace-only values all use the default.
        _ => DEFAULT_WILDCARD_HOST_PATTERN.to_string(),
    }
}
/// Derives the base domain from a wildcard host pattern.
///
/// The pattern is first reduced to its suffix via `wildcard_suffix_from_pattern`
/// (`None` is returned if that fails). The left-most label of the suffix is then
/// dropped, so `*.v3.athena-cluster.com` yields `athena-cluster.com`. A suffix
/// without any dot is returned unchanged.
///
/// Note that exactly one label is always removed when a dot is present, so a
/// two-label suffix such as `*.example.com` yields `com`.
pub fn managed_base_domain_from_pattern(pattern: &str) -> Option<String> {
    let suffix = wildcard_suffix_from_pattern(pattern)?;
    if let Some((_, parent)) = suffix.split_once('.') {
        return Some(parent.to_string());
    }
    // Single-label suffix: there is nothing to strip.
    Some(suffix)
}
/// Extracts the domain suffix that follows the leading `*.` of a wildcard pattern.
///
/// The pattern is trimmed, stripped of trailing dots and lower-cased (ASCII)
/// before inspection. Returns `None` when the pattern does not start with `*.`,
/// when nothing remains after the prefix, or when the remainder contains another
/// `*`. The suffix is not otherwise validated as a domain name.
pub fn wildcard_suffix_from_pattern(pattern: &str) -> Option<String> {
    let lowered = pattern.trim().trim_end_matches('.').to_ascii_lowercase();
    lowered
        .strip_prefix("*.")
        .map(str::trim)
        // Only a single, leading wildcard label is supported.
        .filter(|tail| !tail.is_empty() && !tail.contains('*'))
        .map(str::to_owned)
}
/// Builds a concrete host name by substituting `prefix` for the `*` in a wildcard pattern.
///
/// `prefix` is trimmed, stripped of trailing dots and lower-cased (ASCII); it must
/// then pass `is_prefix_label_valid`. The suffix comes from
/// `wildcard_suffix_from_pattern(wildcard_pattern)`.
///
/// Returns `None` when either the prefix or the pattern is rejected; otherwise
/// `Some("<prefix>.<suffix>")`.
pub fn build_host_from_prefix(prefix: &str, wildcard_pattern: &str) -> Option<String> {
    let label = prefix.trim().trim_end_matches('.').to_ascii_lowercase();
    if is_prefix_label_valid(&label) {
        wildcard_suffix_from_pattern(wildcard_pattern).map(|suffix| format!("{label}.{suffix}"))
    } else {
        None
    }
}
/// Derives the public wildcard host for `route_key`.
///
/// The wildcard pattern is taken from `wildcard_host_pattern_from_env`, and the
/// route key is used as the host prefix via `build_host_from_prefix`.
///
/// # Errors
///
/// Returns a descriptive message naming both the route key and the pattern when
/// no host can be built from them.
pub fn wildcard_public_host_for_route_key(route_key: &str) -> Result<String, String> {
    let pattern = wildcard_host_pattern_from_env();
    match build_host_from_prefix(route_key, &pattern) {
        Some(host) => Ok(host),
        None => Err(format!(
            "Failed to derive a wildcard host for route key '{}' from pattern '{}'.",
            route_key, pattern
        )),
    }
}
/// Extracts the version portion of a container image reference's tag.
///
/// The tag is looked up in the last path component only, so a registry port
/// (`registry.local:5000/postgres:17-alpine`) is not mistaken for the tag
/// separator, and a trailing `@<digest>` is ignored. The version is the part of
/// the tag before the first `-`, e.g. `17-alpine` yields `17`.
///
/// Returns `None` when the reference has no tag or the version part is empty.
///
/// NOTE(review): the value is not checked to be numeric and minor components are
/// kept (`postgres:17.2` yields `17.2`, `postgres:latest` yields `latest`) —
/// confirm callers are fine with that.
pub fn postgres_major_version_from_image(image: &str) -> Option<String> {
    // A digest suffix contains its own ':' (`@sha256:...`) and must not be read as a tag.
    let reference = image.split_once('@').map_or(image, |(name, _digest)| name);
    // Colons before the last '/' belong to the registry host (`host:port`), not the tag.
    let last_component = reference
        .rsplit_once('/')
        .map_or(reference, |(_, component)| component);
    let tag = last_component.split_once(':')?.1.trim();
    let version = tag.split('-').next()?.trim();
    (!version.is_empty()).then(|| version.to_string())
}
pub async fn check_dns_host(host: &str) -> DnsLookupOutcome {
let normalized_host = match normalize_host(host) {
Some(value) => value,
None => {
return DnsLookupOutcome {
host: host.trim().to_string(),
resolves: false,
records: Vec::new(),
error: Some("host is empty or invalid".to_string()),
};
}
};
let lookup_result = lookup_host((normalized_host.clone(), 0)).await;
match lookup_result {
Ok(addresses) => {
let mut unique = BTreeSet::new();
for socket in addresses {
unique.insert(socket.ip());
}
let records = unique
.into_iter()
.map(|ip| DnsLookupRecord {
address: ip.to_string(),
family: ip_family(ip).to_string(),
})
.collect::<Vec<_>>();
DnsLookupOutcome {
host: normalized_host,
resolves: !records.is_empty(),
records,
error: None,
}
}
Err(err) => DnsLookupOutcome {
host: normalized_host,
resolves: false,
records: Vec::new(),
error: Some(err.to_string()),
},
}
}
/// Reduces a loosely formatted host reference to the bare host name or IP address.
///
/// Accepts plain hosts, `host:port`, bracketed IPv6 (`[::1]:5432`), bare IPv6,
/// full URLs (`scheme://user:pass@host:port/path?query#fragment`) and
/// comma-separated host lists (only the first entry is used).
///
/// Processing steps:
/// 1. strip a leading `scheme://` — only when it precedes any `/`, `?` or `#`, so
///    a URL embedded in a path or query is not mistaken for the scheme;
/// 2. cut the authority at the first `/`, `?` or `#`;
/// 3. keep the first comma-separated entry;
/// 4. drop any `userinfo@` prefix;
/// 5. drop the port, or unwrap the brackets of an IPv6 literal.
///
/// Returns `None` when nothing usable remains (empty input, empty host part, or
/// an unterminated `[`).
fn normalize_host(value: &str) -> Option<String> {
    const AUTHORITY_END: &[char] = &['/', '?', '#'];

    let trimmed = value.trim();
    let without_scheme = match trimmed.split_once("://") {
        Some((scheme, rest)) if !scheme.contains(AUTHORITY_END) => rest,
        _ => trimmed,
    };
    let authority = without_scheme.split(AUTHORITY_END).next().unwrap_or("");
    let first_entry = authority.split(',').next().unwrap_or("").trim();
    // The last '@' separates userinfo from the host; hosts never contain '@'.
    let host = first_entry
        .rsplit_once('@')
        .map_or(first_entry, |(_, after_userinfo)| after_userinfo);
    if host.is_empty() {
        return None;
    }

    // Bracketed IPv6 literal, optionally followed by `:port`.
    if let Some(bracketed) = host.strip_prefix('[') {
        return bracketed
            .split_once(']')
            .map(|(address, _)| address.to_string())
            .filter(|address| !address.is_empty());
    }

    // Exactly one ':' means `host:port`; more than one is a bare IPv6 address.
    if host.matches(':').count() == 1 {
        return host
            .split_once(':')
            .map(|(name, _port)| name.to_string())
            .filter(|name| !name.is_empty());
    }

    Some(host.to_string())
}
/// Checks whether `prefix` can be used as the left-most label of a generated host.
///
/// A valid prefix is non-empty, at most 63 bytes long (the DNS limit for a single
/// label) and consists only of ASCII lowercase letters, ASCII digits, `-` and
/// `_`. Dots are therefore rejected, which keeps the prefix to a single label.
/// Hyphen placement is not checked.
fn is_prefix_label_valid(prefix: &str) -> bool {
    // Longer labels cannot be encoded in a DNS name, so the host could never resolve.
    const MAX_DNS_LABEL_LEN: usize = 63;

    !prefix.is_empty()
        && prefix.len() <= MAX_DNS_LABEL_LEN
        && prefix
            .bytes()
            .all(|byte| matches!(byte, b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_'))
}
/// Returns the address-family label for `ip`: `"ipv4"` or `"ipv6"`.
fn ip_family(ip: IpAddr) -> &'static str {
    if ip.is_ipv4() {
        "ipv4"
    } else {
        "ipv6"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Serializes tests in this module that mutate the process environment.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    #[test]
    fn builds_client_host_from_default_pattern() {
        let host = build_host_from_prefix("logging", DEFAULT_WILDCARD_HOST_PATTERN);
        assert_eq!(host.as_deref(), Some("logging.v3.athena-cluster.com"));
    }

    #[test]
    fn resolves_managed_base_domain_from_pattern() {
        assert_eq!(
            managed_base_domain_from_pattern("*.v3.athena-cluster.com").as_deref(),
            Some("athena-cluster.com")
        );
    }

    #[test]
    fn extracts_postgres_major_version_from_image_tag() {
        assert_eq!(
            postgres_major_version_from_image("postgres:17-alpine").as_deref(),
            Some("17")
        );
    }

    #[test]
    fn wildcard_host_pattern_uses_env_override() {
        // The guard protects no data, so a lock poisoned by another failing test
        // is still safe to reuse.
        let _guard = ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let previous = env::var_os(ENV_WILDCARD_HOST_PATTERN);

        // SAFETY: tests in this module only touch the environment while holding
        // ENV_LOCK. This does not guard against code outside the module accessing
        // the environment concurrently.
        unsafe {
            env::set_var(ENV_WILDCARD_HOST_PATTERN, "*.v4.athena-cluster.com");
        }
        let observed = wildcard_host_pattern_from_env();

        // Restore the environment before asserting so a failure cannot leak the
        // override into other tests.
        // SAFETY: still holding ENV_LOCK, see above.
        unsafe {
            match previous {
                Some(value) => env::set_var(ENV_WILDCARD_HOST_PATTERN, value),
                None => env::remove_var(ENV_WILDCARD_HOST_PATTERN),
            }
        }

        assert_eq!(observed, "*.v4.athena-cluster.com");
    }
}