Skip to main content

athena_dns/
lib.rs

1use serde::Serialize;
2use std::collections::BTreeSet;
3use std::env;
4use std::net::IpAddr;
5use tokio::net::lookup_host;
6
/// Default managed cluster name.
pub const DEFAULT_MANAGED_CLUSTER_NAME: &str = "athena-cluster";
/// Default managed base domain; the same value `managed_base_domain_from_pattern`
/// derives from [`DEFAULT_WILDCARD_HOST_PATTERN`].
pub const DEFAULT_MANAGED_BASE_DOMAIN: &str = "athena-cluster.com";
/// Wildcard host pattern used when `ATHENA_WILDCARD_HOST_PATTERN` is unset or blank.
pub const DEFAULT_WILDCARD_HOST_PATTERN: &str = "*.v3.athena-cluster.com";
/// Default base URL for wildcard targets.
/// NOTE(review): not referenced in this module; confirm how callers use it.
pub const DEFAULT_WILDCARD_TARGET_BASE_URL: &str = "https://pool.athena-cluster.com";

/// Environment variable whose trimmed, non-blank value overrides
/// [`DEFAULT_WILDCARD_HOST_PATTERN`].
const ENV_WILDCARD_HOST_PATTERN: &str = "ATHENA_WILDCARD_HOST_PATTERN";
13
14#[derive(Debug, Clone, Serialize, PartialEq, Eq, PartialOrd, Ord)]
15pub struct DnsLookupRecord {
16    pub address: String,
17    pub family: String,
18}
19
20#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
21pub struct DnsLookupOutcome {
22    pub host: String,
23    pub resolves: bool,
24    pub records: Vec<DnsLookupRecord>,
25    pub error: Option<String>,
26}
27
28pub fn wildcard_host_pattern_from_env() -> String {
29    env::var(ENV_WILDCARD_HOST_PATTERN)
30        .ok()
31        .map(|value| value.trim().to_string())
32        .filter(|value| !value.is_empty())
33        .unwrap_or_else(|| DEFAULT_WILDCARD_HOST_PATTERN.to_string())
34}
35
36pub fn managed_base_domain_from_pattern(pattern: &str) -> Option<String> {
37    let suffix = wildcard_suffix_from_pattern(pattern)?;
38    suffix
39        .split_once('.')
40        .map(|(_, remainder)| remainder.to_string())
41        .or(Some(suffix))
42}
43
/// Extracts the lowercase domain suffix from a `*.`-prefixed wildcard host pattern.
///
/// Surrounding whitespace and trailing dots are ignored, so `" *.V3.Example.com. "`
/// yields `v3.example.com`.
///
/// Returns `None` when any of these hold, since no valid host name could be built:
///
/// * the pattern does not start with `*.`;
/// * the remaining suffix is empty;
/// * the suffix contains another `*`;
/// * the suffix contains an empty label (a leading dot or consecutive dots).
pub fn wildcard_suffix_from_pattern(pattern: &str) -> Option<String> {
    let normalized = pattern.trim().trim_end_matches('.').to_ascii_lowercase();
    let suffix = normalized.strip_prefix("*.")?.trim();
    // An empty suffix splits into a single empty label, so this also covers "".
    let has_empty_label = suffix.split('.').any(str::is_empty);
    if has_empty_label || suffix.contains('*') {
        return None;
    }
    Some(suffix.to_string())
}
52
53pub fn build_host_from_prefix(prefix: &str, wildcard_pattern: &str) -> Option<String> {
54    let normalized_prefix = prefix.trim().trim_end_matches('.').to_ascii_lowercase();
55    if !is_prefix_label_valid(&normalized_prefix) {
56        return None;
57    }
58
59    let suffix = wildcard_suffix_from_pattern(wildcard_pattern)?;
60    Some(format!("{normalized_prefix}.{suffix}"))
61}
62
63pub fn wildcard_public_host_for_route_key(route_key: &str) -> Result<String, String> {
64    let wildcard_pattern = wildcard_host_pattern_from_env();
65    build_host_from_prefix(route_key, &wildcard_pattern).ok_or_else(|| {
66        format!(
67            "Failed to derive a wildcard host for route key '{}' from pattern '{}'.",
68            route_key, wildcard_pattern
69        )
70    })
71}
72
/// Extracts the PostgreSQL major version from a container image reference.
///
/// How the tag is located:
///
/// * It is read from the last path segment of the reference, so a registry host port
///   (`registry.local:5000/postgres:16`) is not mistaken for the tag.
/// * A trailing `@sha256:…` digest is ignored.
///
/// How the tag becomes a version:
///
/// * Any `-variant` suffix (`-alpine`, `-bookworm`) is dropped.
/// * The remaining version is reduced to its major component: `16.4` → `16`.
/// * Pre-10 releases keep two components: `9.6.24` → `9.6`.
/// * Tags that do not start with a number (e.g. `latest`) are returned unchanged.
///
/// Returns `None` when the reference has no tag or the version part is empty.
pub fn postgres_major_version_from_image(image: &str) -> Option<String> {
    let image = image.trim();
    // The `sha256:` colon inside a digest is not a tag separator.
    let reference = image.split_once('@').map_or(image, |(name, _digest)| name);
    // Colons before the final '/' belong to a registry `host:port`, not to the tag.
    let last_segment = reference.rsplit('/').next().unwrap_or(reference);
    let (_, tag) = last_segment.split_once(':')?;
    let version = tag.trim().split('-').next()?.trim();
    if version.is_empty() {
        return None;
    }
    Some(postgres_major_from_version(version).to_string())
}

/// Reduces a PostgreSQL version string to its major version (`17.2` → `17`, `9.6.24` → `9.6`).
fn postgres_major_from_version(version: &str) -> &str {
    let mut components = version.split('.');
    let first = components.next().unwrap_or(version);
    match first.parse::<u32>() {
        // From PostgreSQL 10 onward the major version is the first number.
        Ok(major) if major >= 10 => first,
        // Before 10 the major version was `X.Y`.
        Ok(_) => match components.next() {
            Some(minor) if !minor.is_empty() => &version[..first.len() + 1 + minor.len()],
            _ => first,
        },
        // Non-numeric tags cannot be reduced; keep them as they were.
        Err(_) => version,
    }
}
81
82pub async fn check_dns_host(host: &str) -> DnsLookupOutcome {
83    let normalized_host = match normalize_host(host) {
84        Some(value) => value,
85        None => {
86            return DnsLookupOutcome {
87                host: host.trim().to_string(),
88                resolves: false,
89                records: Vec::new(),
90                error: Some("host is empty or invalid".to_string()),
91            };
92        }
93    };
94
95    let lookup_result = lookup_host((normalized_host.clone(), 0)).await;
96    match lookup_result {
97        Ok(addresses) => {
98            let mut unique = BTreeSet::new();
99            for socket in addresses {
100                unique.insert(socket.ip());
101            }
102
103            let records = unique
104                .into_iter()
105                .map(|ip| DnsLookupRecord {
106                    address: ip.to_string(),
107                    family: ip_family(ip).to_string(),
108                })
109                .collect::<Vec<_>>();
110
111            DnsLookupOutcome {
112                host: normalized_host,
113                resolves: !records.is_empty(),
114                records,
115                error: None,
116            }
117        }
118        Err(err) => DnsLookupOutcome {
119            host: normalized_host,
120            resolves: false,
121            records: Vec::new(),
122            error: Some(err.to_string()),
123        },
124    }
125}
126
/// Reduces a host, URL, or connection string to a bare host name for DNS lookup.
///
/// The following are stripped, in this order:
///
/// 1. a `scheme://` prefix;
/// 2. any path, query string, or fragment;
/// 3. every entry after the first in a comma-separated host list;
/// 4. `user[:password]@` credentials.
///
/// Then the port is removed:
///
/// * `[v6]:port` yields the address inside the brackets;
/// * `host:port` (exactly one colon) yields `host`;
/// * input with two or more colons and no brackets is treated as a bare IPv6 address
///   and kept as-is.
///
/// Returns `None` when nothing usable remains, e.g. for blank input, `https://`,
/// `user@`, or `:5432`.
fn normalize_host(value: &str) -> Option<String> {
    let mut host = value.trim();
    if host.is_empty() {
        return None;
    }

    if let Some((_, remainder)) = host.split_once("://") {
        host = remainder;
    }
    // The authority section ends at the first '/', '?', or '#'.
    if let Some(end) = host.find(|ch: char| matches!(ch, '/' | '?' | '#')) {
        host = &host[..end];
    }
    if let Some((first, _)) = host.split_once(',') {
        host = first.trim();
    }
    // Credentials end at the last '@'; splitting there also copes with an unescaped
    // '@' inside a password.
    if let Some((_, without_userinfo)) = host.rsplit_once('@') {
        host = without_userinfo.trim();
    }
    if host.is_empty() {
        return None;
    }

    if let Some(bracketed) = host.strip_prefix('[') {
        return bracketed
            .split_once(']')
            .map(|(address, _)| address.to_string())
            .filter(|address| !address.is_empty());
    }

    // Exactly one colon means `host:port`; more than one is an unbracketed IPv6 literal.
    if host.matches(':').count() == 1 {
        return host
            .split_once(':')
            .map(|(name, _port)| name.to_string())
            .filter(|name| !name.is_empty());
    }

    Some(host.to_string())
}
162
/// Reports whether `prefix` can serve as the single leftmost label of a wildcard host.
///
/// Accepts 1 to 63 characters drawn from lowercase ASCII letters, ASCII digits, `-`,
/// and `_`. The upper bound is the DNS label length limit, so a longer prefix could
/// never produce a resolvable host. A `.` is rejected because it would split the prefix
/// into several labels.
fn is_prefix_label_valid(prefix: &str) -> bool {
    /// Maximum length of a single DNS label, in octets (RFC 1035 section 2.3.4).
    const MAX_DNS_LABEL_LEN: usize = 63;

    // Every accepted character is ASCII, so once the character check passes the byte
    // length checked here equals the character count.
    (1..=MAX_DNS_LABEL_LEN).contains(&prefix.len())
        && prefix
            .chars()
            .all(|ch| ch.is_ascii_lowercase() || ch.is_ascii_digit() || matches!(ch, '-' | '_'))
}
170
/// Names the family of `ip`: `"ipv4"` for IPv4 addresses, otherwise `"ipv6"`.
fn ip_family(ip: IpAddr) -> &'static str {
    if ip.is_ipv4() {
        "ipv4"
    } else {
        "ipv6"
    }
}
177
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serializes the tests in this module that mutate the process environment.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Sets `ENV_WILDCARD_HOST_PATTERN` while holding `ENV_LOCK`.
    ///
    /// The variable is removed again on drop, so it is cleaned up even when an
    /// assertion panics mid-test.
    struct WildcardPatternOverride {
        _lock: MutexGuard<'static, ()>,
    }

    impl WildcardPatternOverride {
        fn set(value: &str) -> Self {
            // The mutex guards `()`, so a lock poisoned by an earlier panicking test
            // carries no broken state; recover it instead of failing every later test.
            let lock = ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner);
            // SAFETY: mutating the environment is only sound while no other thread reads
            // or writes it concurrently. ENV_LOCK serializes the tests in this module
            // that touch it. NOTE(review): tests elsewhere in the crate that use the
            // environment without taking ENV_LOCK could still race with this.
            unsafe {
                env::set_var(ENV_WILDCARD_HOST_PATTERN, value);
            }
            Self { _lock: lock }
        }
    }

    impl Drop for WildcardPatternOverride {
        fn drop(&mut self) {
            // SAFETY: `drop` runs before the `_lock` field is released, so ENV_LOCK is
            // still held here; see `WildcardPatternOverride::set`.
            unsafe {
                env::remove_var(ENV_WILDCARD_HOST_PATTERN);
            }
        }
    }

    #[test]
    fn builds_client_host_from_default_pattern() {
        let host = build_host_from_prefix("logging", DEFAULT_WILDCARD_HOST_PATTERN);
        assert_eq!(host.as_deref(), Some("logging.v3.athena-cluster.com"));
    }

    #[test]
    fn normalizes_prefix_before_building_host() {
        let host = build_host_from_prefix("  Logging. ", DEFAULT_WILDCARD_HOST_PATTERN);
        assert_eq!(host.as_deref(), Some("logging.v3.athena-cluster.com"));
    }

    #[test]
    fn rejects_prefixes_that_are_not_a_single_label() {
        assert_eq!(build_host_from_prefix("", DEFAULT_WILDCARD_HOST_PATTERN), None);
        assert_eq!(build_host_from_prefix("a.b", DEFAULT_WILDCARD_HOST_PATTERN), None);
        assert_eq!(build_host_from_prefix("log ging", DEFAULT_WILDCARD_HOST_PATTERN), None);
    }

    #[test]
    fn rejects_patterns_without_a_single_leading_wildcard() {
        assert_eq!(wildcard_suffix_from_pattern("v3.athena-cluster.com"), None);
        assert_eq!(wildcard_suffix_from_pattern("*.*.athena-cluster.com"), None);
        assert_eq!(wildcard_suffix_from_pattern("*."), None);
    }

    #[test]
    fn resolves_managed_base_domain_from_pattern() {
        assert_eq!(
            managed_base_domain_from_pattern("*.v3.athena-cluster.com").as_deref(),
            Some("athena-cluster.com")
        );
    }

    #[test]
    fn extracts_postgres_major_version_from_image_tag() {
        assert_eq!(
            postgres_major_version_from_image("postgres:17-alpine").as_deref(),
            Some("17")
        );
    }

    #[test]
    fn postgres_image_without_tag_has_no_version() {
        assert_eq!(postgres_major_version_from_image("postgres"), None);
    }

    #[test]
    fn normalizes_host_port_and_ipv6_brackets() {
        assert_eq!(
            normalize_host("https://example.com:8443/path").as_deref(),
            Some("example.com")
        );
        assert_eq!(normalize_host("[::1]:5432").as_deref(), Some("::1"));
        assert_eq!(normalize_host("   "), None);
    }

    #[test]
    fn wildcard_host_pattern_uses_env_override() {
        let _env = WildcardPatternOverride::set("*.v4.athena-cluster.com");
        assert_eq!(wildcard_host_pattern_from_env(), "*.v4.athena-cluster.com");
    }

    #[test]
    fn wildcard_host_pattern_ignores_blank_env_override() {
        let _env = WildcardPatternOverride::set("   ");
        assert_eq!(wildcard_host_pattern_from_env(), DEFAULT_WILDCARD_HOST_PATTERN);
    }

    #[test]
    fn route_key_host_follows_env_override() {
        let _env = WildcardPatternOverride::set("*.v4.athena-cluster.com");
        assert_eq!(
            wildcard_public_host_for_route_key("logging"),
            Ok("logging.v4.athena-cluster.com".to_string())
        );
        assert!(wildcard_public_host_for_route_key("a.b").is_err());
    }
}
218}