use std::collections::HashMap;
use std::sync::Mutex;
use std::time::{Duration, Instant as StdInstant};
use base64::Engine;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use hickory_resolver::TokioResolver;
use sha2::{Digest, Sha256};
use tokio::time::Instant;
use tracing::{debug, warn};
use crate::error::{Error, Result};
/// Label prepended to a domain to build its DNS-01 challenge record name.
const ACME_CHALLENGE_PREFIX: &str = "_acme-challenge.";

/// Default DNS propagation timeout: two minutes.
pub const DEFAULT_PROPAGATION_TIMEOUT: Duration = Duration::from_secs(2 * 60);

/// Default delay between DNS propagation polls: four seconds.
pub const DEFAULT_PROPAGATION_INTERVAL: Duration = Duration::from_secs(4);

/// Public resolvers (Google, then Cloudflare) used when no system
/// nameservers are available.
pub const DEFAULT_NAMESERVERS: &[&str] = &[
    "8.8.8.8:53",
    "8.8.4.4:53",
    "1.1.1.1:53",
    "1.0.0.1:53",
];

/// DNS query timeout: ten seconds.
pub const DNS_TIMEOUT: Duration = Duration::from_secs(10);
/// Returns `domain` as a fully-qualified name by ensuring it ends with a dot.
///
/// Names that already end in `.` are returned unchanged.
pub fn to_fqdn(domain: &str) -> String {
    let mut fqdn = String::with_capacity(domain.len() + 1);
    fqdn.push_str(domain);
    if !fqdn.ends_with('.') {
        fqdn.push('.');
    }
    fqdn
}
/// Returns `fqdn` with a single trailing dot removed, if it has one.
pub fn from_fqdn(fqdn: &str) -> String {
    match fqdn.strip_suffix('.') {
        Some(relative) => relative.to_owned(),
        None => fqdn.to_owned(),
    }
}
/// Returns the fully-qualified DNS-01 TXT record name for `domain`.
///
/// `domain` may carry a trailing dot and/or a leading `*.` wildcard label.
/// Examples:
/// - `example.com` → `_acme-challenge.example.com.`
/// - `*.example.com` → `_acme-challenge.example.com.`
pub fn challenge_record_name(domain: &str) -> String {
    let clean = from_fqdn(domain);
    // RFC 8555 validates a wildcard identifier at its base domain. A record
    // name with a literal `*` label ("_acme-challenge.*.example.com.") is
    // never what the CA queries.
    let base = clean.strip_prefix("*.").unwrap_or(&clean);
    to_fqdn(&format!("{ACME_CHALLENGE_PREFIX}{base}"))
}
/// Computes the DNS-01 TXT record value for a key authorization string.
///
/// The value is the unpadded, URL-safe base64 encoding of the SHA-256 digest
/// of `key_auth`.
pub fn challenge_record_value(key_auth: &str) -> String {
    let mut hasher = Sha256::new();
    hasher.update(key_auth.as_bytes());
    URL_SAFE_NO_PAD.encode(hasher.finalize())
}
/// Reports whether `domain` is a syntactically valid multi-label hostname.
///
/// Input handling:
/// - One trailing dot and one leading `*.` wildcard label are ignored.
/// - The remaining name must be 1–253 bytes long and contain at least two labels.
///
/// Each label must:
/// - be 1–63 bytes long,
/// - contain only ASCII letters, digits and `-`,
/// - not start or end with `-`.
pub fn is_valid_domain(domain: &str) -> bool {
    let host = domain.strip_suffix('.').unwrap_or(domain);
    let host = host.strip_prefix("*.").unwrap_or(host);

    // A name without any dot is a single label (e.g. a bare TLD).
    if host.is_empty() || host.len() > 253 || !host.contains('.') {
        return false;
    }

    host.split('.').all(|label| {
        let length_ok = (1..=63).contains(&label.len());
        let hyphens_ok = !label.starts_with('-') && !label.ends_with('-');
        let charset_ok = label
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || b == b'-');
        length_ok && hyphens_ok && charset_ok
    })
}
/// Produces a normalized, restricted-charset form of `domain`.
///
/// Steps, in order:
/// 1. ASCII-lowercase the input.
/// 2. Strip one leading `*.` and one trailing `.`.
/// 3. Replace every character other than an ASCII letter, digit, `-` or `.`
///    with `_`.
pub fn sanitize_domain(domain: &str) -> String {
    let lowered = domain.to_ascii_lowercase();
    let mut trimmed: &str = &lowered;
    if let Some(rest) = trimmed.strip_prefix("*.") {
        trimmed = rest;
    }
    if let Some(rest) = trimmed.strip_suffix('.') {
        trimmed = rest;
    }

    let mut out = String::with_capacity(trimmed.len());
    for ch in trimmed.chars() {
        out.push(match ch {
            'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '.' => ch,
            _ => '_',
        });
    }
    out
}
/// Polls DNS until a TXT record at `fqdn` equals `expected_value`.
///
/// The record is queried immediately, then roughly every `interval`, until it
/// is seen or `timeout` elapses. The final attempt is made at the deadline.
/// Individual lookup failures are logged at `warn` level and retried rather
/// than aborting the check. `fqdn` may be given with or without a trailing dot.
///
/// # Errors
///
/// Returns [`Error::Timeout`] if the record was not observed before the
/// deadline. The `Ok` variant is always `Ok(true)`.
pub async fn check_dns_propagation(
    fqdn: &str,
    expected_value: &str,
    timeout: Duration,
    interval: Duration,
) -> Result<bool> {
    // Query the absolute name (with trailing dot). A relative name lets the
    // resolver try search-domain suffixes first when `ndots` is high (e.g.
    // `ndots:5` in Kubernetes pods). That sends extra queries to unrelated
    // names before the real one.
    let fqdn_normalized = to_fqdn(fqdn);
    let deadline = Instant::now() + timeout;
    debug!(
        fqdn = %fqdn_normalized,
        expected_value,
        "starting DNS propagation check"
    );
    loop {
        match try_lookup_txt(&fqdn_normalized, expected_value).await {
            Ok(true) => {
                debug!(
                    fqdn = %fqdn_normalized,
                    "DNS propagation confirmed"
                );
                return Ok(true);
            }
            Ok(false) => {
                debug!(
                    fqdn = %fqdn_normalized,
                    "TXT record not yet propagated, will retry"
                );
            }
            Err(e) => {
                warn!(
                    fqdn = %fqdn_normalized,
                    error = %e,
                    "DNS lookup error during propagation check"
                );
            }
        }
        let now = Instant::now();
        if now >= deadline {
            return Err(Error::Timeout(format!(
                "DNS propagation check for {fqdn_normalized} timed out after {timeout:?}"
            )));
        }
        // Never sleep past the deadline. Otherwise the last attempt could land
        // up to a full `interval` after `timeout`. `checked_add` keeps a huge
        // `interval` from overflowing.
        let wake_at = now
            .checked_add(interval)
            .map_or(deadline, |next| next.min(deadline));
        tokio::time::sleep_until(wake_at).await;
    }
}
/// Looks up TXT records for `name` and reports whether any equals `expected_value`.
///
/// Leading and trailing `"` characters are stripped from each record before
/// comparing. A lookup failure is returned as the error's display string.
async fn try_lookup_txt(name: &str, expected_value: &str) -> std::result::Result<bool, String> {
    let records = lookup_txt(name).await.map_err(|err| err.to_string())?;
    let matched = records
        .iter()
        .map(|record| record.trim_matches('"'))
        .any(|value| value == expected_value);
    Ok(matched)
}
async fn lookup_txt(fqdn: &str) -> Result<Vec<String>> {
let resolver = TokioResolver::builder_tokio()
.map_err(|e| Error::Other(format!("failed to create DNS resolver: {e}")))?
.build();
let response = resolver
.txt_lookup(fqdn)
.await
.map_err(|e| Error::Other(format!("DNS TXT lookup failed for {fqdn}: {e}")))?;
Ok(response
.iter()
.map(|r: &hickory_resolver::proto::rr::rdata::TXT| r.to_string())
.collect())
}
/// Lifetime of a cached zone lookup result: five minutes.
const ZONE_CACHE_TTL: Duration = Duration::from_secs(300);

/// One cached zone lookup result and the moment it stops being valid.
struct ZoneCacheEntry {
    /// Cached zone name; `None` records that no zone was found.
    zone: Option<String>,
    /// Entry is served only while the current time is before this instant.
    expires_at: StdInstant,
}

/// Process-wide zone cache, created lazily on first access.
static ZONE_CACHE: std::sync::OnceLock<Mutex<HashMap<String, ZoneCacheEntry>>> =
    std::sync::OnceLock::new();

/// Returns the process-wide zone cache, initializing it to an empty map on
/// first use.
fn zone_cache() -> &'static Mutex<HashMap<String, ZoneCacheEntry>> {
    ZONE_CACHE.get_or_init(Default::default)
}
/// Looks up `domain` in the zone cache.
///
/// Returns:
/// - `Some(zone)` for an unexpired entry, where `zone` itself may be `None`
///   (a cached "no zone" result);
/// - `None` on a miss, an expired entry, or a poisoned cache lock.
fn zone_cache_get(domain: &str) -> Option<Option<String>> {
    let guard = zone_cache().lock().ok()?;
    guard
        .get(domain)
        .filter(|entry| entry.expires_at > StdInstant::now())
        .map(|entry| entry.zone.clone())
}
/// Stores `zone` for `domain` in the zone cache, valid for [`ZONE_CACHE_TTL`].
///
/// The cache holds at most `MAX_ENTRIES` keys. When a new key would exceed that:
/// 1. Expired entries are purged.
/// 2. If the cache is still full, the entry closest to expiry is evicted.
///
/// Does nothing if the cache mutex is poisoned.
fn zone_cache_set(domain: &str, zone: Option<String>) {
    /// Upper bound on the number of distinct cached domains.
    const MAX_ENTRIES: usize = 1000;

    let Ok(mut map) = zone_cache().lock() else {
        return;
    };
    let now = StdInstant::now();

    // Overwriting an existing key never grows the map; only new keys need room.
    if map.len() >= MAX_ENTRIES && !map.contains_key(domain) {
        map.retain(|_, entry| now < entry.expires_at);
        // Purging alone does not bound the map: if every entry is still fresh
        // it would keep growing. Evict the soonest-to-expire entry instead.
        if map.len() >= MAX_ENTRIES {
            let oldest = map
                .iter()
                .min_by_key(|(_, entry)| entry.expires_at)
                .map(|(key, _)| key.clone());
            if let Some(key) = oldest {
                map.remove(&key);
            }
        }
    }

    map.insert(
        domain.to_string(),
        ZoneCacheEntry {
            zone,
            expires_at: now + ZONE_CACHE_TTL,
        },
    );
}
/// Removes every entry from the zone cache.
///
/// If the cache mutex is poisoned, the cache is left untouched.
pub fn clear_zone_cache() {
    let Ok(mut entries) = zone_cache().lock() else {
        return;
    };
    entries.clear();
}
/// Returns the zone that contains `fqdn`, without performing DNS queries.
///
/// The cache is consulted under the lowercased name with any trailing dot
/// removed. A fresh entry there, such as an SOA-based answer stored by
/// [`find_zone_by_fqdn_async`], is returned as-is.
///
/// Otherwise the zone is guessed as the last two labels of the lowercased
/// name. This guess is wrong for multi-label public suffixes such as `co.uk`.
/// Returns `None` for single-label names.
pub fn find_zone_by_fqdn(fqdn: &str) -> Option<String> {
    let normalized = from_fqdn(fqdn).to_lowercase();
    if let Some(cached) = zone_cache_get(&normalized) {
        return cached;
    }
    // Deliberately not cached. The heuristic is cheap, and the cache is shared
    // with `find_zone_by_fqdn_async`. Storing this guess there would make the
    // async lookup return it, skipping its SOA probing, for the whole TTL.
    find_zone_by_fqdn_heuristic(&normalized)
}
/// Finds the DNS zone containing `fqdn` by probing for SOA records.
///
/// Lookup order:
/// 1. The input is lowercased and any trailing dot is removed; that name is
///    both the cache key and the lookup input, so the result does not depend
///    on input casing or on which caller filled the cache.
/// 2. Candidates are tried from most to least specific (`a.b.example.com.`,
///    `b.example.com.`, `example.com.`). The bare top-level label is never
///    tried. The first candidate whose SOA lookup succeeds is returned as an
///    FQDN.
/// 3. If no SOA is found, or a resolver cannot be created, the zone is guessed
///    as the last two labels.
///
/// Results are cached for [`ZONE_CACHE_TTL`], including `None` for names with
/// fewer than two labels.
pub async fn find_zone_by_fqdn_async(fqdn: &str) -> Option<String> {
    let normalized = from_fqdn(fqdn).to_lowercase();
    if let Some(cached) = zone_cache_get(&normalized) {
        return cached;
    }
    let labels: Vec<&str> = normalized.split('.').collect();
    if labels.len() < 2 {
        zone_cache_set(&normalized, None);
        return None;
    }

    let result = match TokioResolver::builder_tokio() {
        Ok(builder) => {
            let resolver = builder.build();
            let mut found = None;
            // Stopping at `labels.len() - 1` skips the single-label TLD, so
            // every candidate has at least two labels.
            for start in 0..labels.len() - 1 {
                // Absolute name, so search domains are never appended.
                let zone_fqdn = to_fqdn(&labels[start..].join("."));
                match resolver.soa_lookup(&zone_fqdn).await {
                    Ok(_) => {
                        debug!(zone = %zone_fqdn, "found SOA record for zone");
                        found = Some(zone_fqdn);
                        break;
                    }
                    Err(e) => {
                        debug!(
                            zone = %zone_fqdn,
                            error = %e,
                            "no SOA record for candidate zone"
                        );
                    }
                }
            }
            found.or_else(|| find_zone_by_fqdn_heuristic(&normalized))
        }
        Err(_) => find_zone_by_fqdn_heuristic(&normalized),
    };

    zone_cache_set(&normalized, result.clone());
    result
}
/// Guesses the zone for `fqdn` as its last two labels, returned as an FQDN.
///
/// This is purely textual and makes no DNS queries: `a.b.example.com` gives
/// `example.com.`. It is wrong for multi-label public suffixes (for example,
/// `example.co.uk` gives `co.uk.`). Returns `None` for single-label names.
fn find_zone_by_fqdn_heuristic(fqdn: &str) -> Option<String> {
    let name = from_fqdn(fqdn);
    let mut from_right = name.rsplitn(3, '.');
    let last = from_right.next()?;
    let second_last = from_right.next()?;
    Some(to_fqdn(&format!("{second_last}.{last}")))
}
/// Ensures every nameserver address carries an explicit port, adding `:53`
/// where none is present.
///
/// Rewrites performed:
/// - `host` or `1.2.3.4` becomes `host:53` or `1.2.3.4:53`.
/// - `[::1]` becomes `[::1]:53`; `[::1]:5353` is left alone.
/// - A bare IPv6 literal such as `2001:db8::1` (the form `/etc/resolv.conf`
///   uses) becomes `[2001:db8::1]:53`.
/// - Anything with exactly one `:` is taken to be `host:port` and left alone.
pub fn populate_nameserver_ports(servers: &mut [String]) {
    for server in servers.iter_mut() {
        let colons = server.matches(':').count();
        if server.starts_with('[') {
            // Bracketed IPv6: a port is present only after the closing bracket.
            if !server.contains("]:") {
                server.push_str(":53");
            }
        } else if colons == 0 {
            server.push_str(":53");
        } else if colons > 1 {
            // Several colons without brackets can only be a bare IPv6 literal.
            // It must be bracketed before a port can be attached.
            *server = format!("[{server}]:53");
        }
    }
}
/// Returns the nameservers listed in `/etc/resolv.conf` on Unix systems.
///
/// Each entry is the second whitespace-separated token of a line starting
/// with `nameserver`, returned as written (no port is added). Falls back to
/// [`DEFAULT_NAMESERVERS`] when:
/// - the file cannot be read,
/// - it lists no such entries, or
/// - the platform is not Unix.
pub fn system_nameservers() -> Vec<String> {
    #[cfg(unix)]
    {
        let listed: Vec<String> = std::fs::read_to_string("/etc/resolv.conf")
            .map(|contents| {
                contents
                    .lines()
                    .map(str::trim)
                    .filter(|line| line.starts_with("nameserver"))
                    .filter_map(|line| line.split_whitespace().nth(1))
                    .map(str::to_string)
                    .collect::<Vec<String>>()
            })
            .unwrap_or_default();
        if !listed.is_empty() {
            return listed;
        }
    }
    DEFAULT_NAMESERVERS.iter().map(|s| s.to_string()).collect()
}
/// Chooses the recursive nameservers to query, each with an explicit port.
///
/// Selection order:
/// 1. `custom`, if it is non-empty.
/// 2. Otherwise the system nameservers.
/// 3. Otherwise [`DEFAULT_NAMESERVERS`].
///
/// Every entry is then passed through [`populate_nameserver_ports`].
pub fn recursive_nameservers(custom: &[String]) -> Vec<String> {
    let mut servers = if !custom.is_empty() {
        custom.to_vec()
    } else {
        match system_nameservers() {
            found if !found.is_empty() => found,
            _ => DEFAULT_NAMESERVERS.iter().map(|s| s.to_string()).collect(),
        }
    };
    populate_nameserver_ports(&mut servers);
    servers
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fqdn_conversions_round_trip() {
        for input in ["example.com", "example.com."] {
            assert_eq!(to_fqdn(input), "example.com.", "to_fqdn({input:?})");
            assert_eq!(from_fqdn(input), "example.com", "from_fqdn({input:?})");
        }
    }

    #[test]
    fn challenge_record_names() {
        let cases = [
            ("example.com", "_acme-challenge.example.com."),
            ("example.com.", "_acme-challenge.example.com."),
            ("sub.example.com", "_acme-challenge.sub.example.com."),
        ];
        for (domain, expected) in cases {
            assert_eq!(challenge_record_name(domain), expected, "domain: {domain:?}");
        }
    }

    #[test]
    fn challenge_record_values() {
        let first = challenge_record_value("token.thumbprint");
        assert_eq!(first, challenge_record_value("token.thumbprint"));
        assert!(!first.chars().any(|c| matches!(c, '=' | '+' | '/')));
        // SHA-256("test"), base64url without padding.
        assert_eq!(
            challenge_record_value("test"),
            "n4bQgYhMfWWaL-qgxVrQFaO_TxsrC4Is0V1sFbDwCgg"
        );
    }

    #[test]
    fn domain_validation() {
        let valid = [
            "example.com",
            "sub.example.com",
            "a.b.c.example.com",
            "example.com.",
            "*.example.com",
            "xn--nxasmq6b.example.com",
        ];
        for domain in valid {
            assert!(is_valid_domain(domain), "expected valid: {domain:?}");
        }

        let invalid = [
            "",
            ".",
            "com",
            "-bad.com",
            "bad-.com",
            "ex ample.com",
            "example..com",
            ".example.com",
        ];
        for domain in invalid {
            assert!(!is_valid_domain(domain), "expected invalid: {domain:?}");
        }
    }

    #[test]
    fn domain_sanitization() {
        let cases = [
            ("Example.COM", "example.com"),
            ("*.example.com", "example.com"),
            ("example.com.", "example.com"),
            ("ex@mple.com", "ex_mple.com"),
        ];
        for (input, expected) in cases {
            assert_eq!(sanitize_domain(input), expected, "input: {input:?}");
        }
    }

    #[test]
    fn zone_lookup_heuristic() {
        for input in ["sub.example.com", "example.com", "example.com."] {
            assert_eq!(
                find_zone_by_fqdn(input),
                Some("example.com.".to_string()),
                "input: {input:?}"
            );
        }
        assert_eq!(find_zone_by_fqdn("com"), None);
    }

    #[test]
    fn nameserver_port_population() {
        let mut bare = vec!["8.8.8.8".to_string()];
        populate_nameserver_ports(&mut bare);
        assert_eq!(bare, vec!["8.8.8.8:53"]);

        let mut with_port = vec!["8.8.8.8:5353".to_string()];
        populate_nameserver_ports(&mut with_port);
        assert_eq!(with_port, vec!["8.8.8.8:5353"]);
    }

    #[test]
    fn recursive_nameserver_selection() {
        let defaults = recursive_nameservers(&[]);
        assert!(!defaults.is_empty());
        assert!(defaults.iter().all(|s| s.ends_with(":53")));

        let custom = vec!["10.0.0.1".to_string()];
        assert_eq!(recursive_nameservers(&custom), vec!["10.0.0.1:53"]);
    }
}