use async_trait::async_trait;
use std::fmt::Debug;
use thiserror::Error;
/// Result type returned by every fallible [`DnsProvider`] operation.
pub type DnsResult<T> = Result<T, DnsProviderError>;

/// Failure modes a [`DnsProvider`] implementation can report.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute.
#[derive(Debug, Error)]
pub enum DnsProviderError {
    /// Authenticating against the provider failed; the payload carries the detail.
    #[error("Authentication failed: {0}")]
    Authentication(String),

    /// No zone could be located for `domain`.
    #[error("Zone not found for domain '{domain}'")]
    ZoneNotFound { domain: String },

    /// The TXT record `record_name` could not be created; `message` explains why.
    #[error("Failed to create TXT record for '{record_name}': {message}")]
    RecordCreation {
        record_name: String,
        message: String,
    },

    /// The TXT record identified by `record_id` could not be deleted.
    #[error("Failed to delete TXT record '{record_id}': {message}")]
    RecordDeletion { record_id: String, message: String },

    /// A request to the provider's API failed; the payload carries the detail.
    #[error("API request failed: {0}")]
    ApiRequest(String),

    /// The provider throttled the request; retry after `retry_after_secs` seconds.
    #[error("Rate limited by DNS provider, retry after {retry_after_secs}s")]
    RateLimited { retry_after_secs: u64 },

    /// A request gave up after `elapsed_secs` seconds.
    #[error("Request timed out after {elapsed_secs}s")]
    Timeout { elapsed_secs: u64 },

    /// The provider's configuration is invalid; the payload says what is wrong.
    #[error("Invalid configuration: {0}")]
    Configuration(String),

    /// Credentials for the provider could not be loaded.
    #[error("Failed to load credentials: {0}")]
    Credentials(String),

    /// `domain` is not one this provider can manage.
    #[error("Domain '{domain}' is not supported by this provider")]
    UnsupportedDomain { domain: String },
}
/// A DNS backend that can publish and remove TXT records, such as the
/// `_acme-challenge` records named by `challenge_record_fqdn`.
///
/// Implementations must be `Send + Sync + Debug` so they can be shared across tasks
/// and logged.
#[async_trait]
pub trait DnsProvider: Send + Sync + Debug {
    /// Static, human-readable identifier for this provider implementation.
    fn name(&self) -> &'static str;

    /// Creates a TXT record named `record_name` holding `record_value` for `domain`.
    ///
    /// On success, returns an identifier for the new record. Pass it back to
    /// [`DnsProvider::delete_txt_record`] as `record_id`.
    ///
    /// NOTE(review): whether `record_name` is a label relative to `domain` or a fully
    /// qualified name is left to implementations — confirm the convention callers use.
    ///
    /// # Errors
    ///
    /// Returns a [`DnsProviderError`] if the record cannot be created.
    async fn create_txt_record(
        &self,
        domain: &str,
        record_name: &str,
        record_value: &str,
    ) -> DnsResult<String>;

    /// Removes the TXT record identified by `record_id` from `domain`.
    ///
    /// `record_id` is an identifier previously returned by
    /// [`DnsProvider::create_txt_record`].
    ///
    /// # Errors
    ///
    /// Returns a [`DnsProviderError`] if the record cannot be deleted.
    async fn delete_txt_record(&self, domain: &str, record_id: &str) -> DnsResult<()>;

    /// Reports whether this provider can manage records for `domain`.
    ///
    /// # Errors
    ///
    /// Returns a [`DnsProviderError`] if the answer cannot be determined.
    async fn supports_domain(&self, domain: &str) -> DnsResult<bool>;
}
/// Label placed in front of a domain to form its ACME challenge record name
/// (e.g. `_acme-challenge.example.com`).
pub const ACME_CHALLENGE_RECORD: &str = "_acme-challenge";

/// TTL to use for challenge records.
///
/// NOTE(review): presumably seconds (the unit DNS uses for TTLs); where this value is
/// consumed is not visible here — confirm.
pub const CHALLENGE_TTL: u32 = 60;
/// Drops one leading wildcard label (`*.`) from `domain`, if present.
///
/// Any other input is returned unchanged. The result is a sub-slice of `domain`;
/// nothing is allocated. For example, `*.example.com` becomes `example.com`.
pub fn normalize_domain(domain: &str) -> &str {
    const WILDCARD_PREFIX: &str = "*.";
    let stripped = domain.strip_prefix(WILDCARD_PREFIX);
    stripped.unwrap_or(domain)
}
/// Builds the name of the ACME challenge TXT record for `domain`.
///
/// A leading wildcard is removed first with [`normalize_domain`], so both
/// `example.com` and `*.example.com` map to `_acme-challenge.example.com`.
pub fn challenge_record_fqdn(domain: &str) -> String {
    let labels = [ACME_CHALLENGE_RECORD, normalize_domain(domain)];
    labels.join(".")
}
#[cfg(test)]
mod tests {
    use super::*;
    use parking_lot::Mutex;
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicU64, Ordering};

    #[test]
    fn test_normalize_domain() {
        assert_eq!(normalize_domain("example.com"), "example.com");
        assert_eq!(normalize_domain("*.example.com"), "example.com");
        assert_eq!(normalize_domain("sub.example.com"), "sub.example.com");
        assert_eq!(normalize_domain("*.sub.example.com"), "sub.example.com");
        // Only one leading `*.` is stripped; a `*` elsewhere is left untouched.
        assert_eq!(normalize_domain("a.*.example.com"), "a.*.example.com");
    }

    #[test]
    fn test_challenge_record_fqdn() {
        assert_eq!(
            challenge_record_fqdn("example.com"),
            "_acme-challenge.example.com"
        );
        assert_eq!(
            challenge_record_fqdn("*.example.com"),
            "_acme-challenge.example.com"
        );
        assert_eq!(
            challenge_record_fqdn("sub.example.com"),
            "_acme-challenge.sub.example.com"
        );
    }

    #[test]
    fn test_dns_provider_error_display() {
        // Pin the full rendered message of every variant so wording changes are deliberate.
        let cases = [
            (
                DnsProviderError::Authentication("bad token".to_string()),
                "Authentication failed: bad token",
            ),
            (
                DnsProviderError::ZoneNotFound {
                    domain: "test.com".to_string(),
                },
                "Zone not found for domain 'test.com'",
            ),
            (
                DnsProviderError::RecordCreation {
                    record_name: "_acme-challenge".to_string(),
                    message: "API error".to_string(),
                },
                "Failed to create TXT record for '_acme-challenge': API error",
            ),
            (
                DnsProviderError::RecordDeletion {
                    record_id: "record-7".to_string(),
                    message: "not found".to_string(),
                },
                "Failed to delete TXT record 'record-7': not found",
            ),
            (
                DnsProviderError::ApiRequest("HTTP 500".to_string()),
                "API request failed: HTTP 500",
            ),
            (
                DnsProviderError::RateLimited {
                    retry_after_secs: 60,
                },
                "Rate limited by DNS provider, retry after 60s",
            ),
            (
                DnsProviderError::Timeout { elapsed_secs: 30 },
                "Request timed out after 30s",
            ),
            (
                DnsProviderError::Configuration("missing zone id".to_string()),
                "Invalid configuration: missing zone id",
            ),
            (
                DnsProviderError::Credentials("token file unreadable".to_string()),
                "Failed to load credentials: token file unreadable",
            ),
            (
                DnsProviderError::UnsupportedDomain {
                    domain: "other.org".to_string(),
                },
                "Domain 'other.org' is not supported by this provider",
            ),
        ];
        for (err, expected) in &cases {
            assert_eq!(err.to_string(), *expected);
        }
    }

    /// In-memory [`DnsProvider`] for tests.
    ///
    /// Records are keyed by `(domain, record_name)` and store `(record_id, record_value)`.
    /// Creating a record for an existing key replaces the earlier entry.
    #[derive(Debug)]
    pub struct MockDnsProvider {
        /// `(domain, record_name)` -> `(record_id, record_value)`.
        pub records: Mutex<HashMap<(String, String), (String, String)>>,
        /// Zones this mock manages; their subdomains are supported too.
        pub supported_domains: Vec<String>,
        /// Next numeric suffix for generated ids; starts at 1, so the first id is `record-1`.
        pub record_counter: AtomicU64,
        /// When set, `create_txt_record` always fails with `RecordCreation`.
        pub fail_on_create: bool,
        /// When set, `delete_txt_record` always fails with `RecordDeletion`.
        pub fail_on_delete: bool,
    }

    impl MockDnsProvider {
        /// Creates an empty mock that manages `supported_domains` and never fails.
        pub fn new(supported_domains: Vec<String>) -> Self {
            Self {
                records: Mutex::new(HashMap::new()),
                supported_domains,
                record_counter: AtomicU64::new(1),
                fail_on_create: false,
                fail_on_delete: false,
            }
        }

        /// Makes every subsequent `create_txt_record` call fail.
        pub fn with_failure_on_create(mut self) -> Self {
            self.fail_on_create = true;
            self
        }

        /// Makes every subsequent `delete_txt_record` call fail.
        pub fn with_failure_on_delete(mut self) -> Self {
            self.fail_on_delete = true;
            self
        }

        /// Returns a copy of the `(record_id, record_value)` stored for the key, if any.
        pub fn get_record(&self, domain: &str, record_name: &str) -> Option<(String, String)> {
            self.records
                .lock()
                .get(&(domain.to_string(), record_name.to_string()))
                .cloned()
        }

        /// Number of records currently stored.
        pub fn record_count(&self) -> usize {
            self.records.lock().len()
        }
    }

    #[async_trait]
    impl DnsProvider for MockDnsProvider {
        fn name(&self) -> &'static str {
            "mock"
        }

        async fn create_txt_record(
            &self,
            domain: &str,
            record_name: &str,
            record_value: &str,
        ) -> DnsResult<String> {
            if self.fail_on_create {
                return Err(DnsProviderError::RecordCreation {
                    record_name: record_name.to_string(),
                    message: "Mock failure".to_string(),
                });
            }
            let record_id = format!(
                "record-{}",
                self.record_counter.fetch_add(1, Ordering::SeqCst)
            );
            self.records.lock().insert(
                (domain.to_string(), record_name.to_string()),
                (record_id.clone(), record_value.to_string()),
            );
            Ok(record_id)
        }

        async fn delete_txt_record(&self, domain: &str, record_id: &str) -> DnsResult<()> {
            if self.fail_on_delete {
                return Err(DnsProviderError::RecordDeletion {
                    record_id: record_id.to_string(),
                    message: "Mock failure".to_string(),
                });
            }
            // Delete only within `domain`. If a caller passes the wrong domain, the record
            // stays in place and the bug shows up in tests instead of silently succeeding.
            // Deleting an id that does not exist is a no-op.
            self.records
                .lock()
                .retain(|(d, _), (id, _)| d != domain || id != record_id);
            Ok(())
        }

        async fn supports_domain(&self, domain: &str) -> DnsResult<bool> {
            let normalized = normalize_domain(domain);
            // Exact match, or a subdomain. A subdomain must end with `.<zone>`, so
            // `badexample.com` is not treated as part of `example.com`.
            Ok(self.supported_domains.iter().any(|zone| {
                normalized == zone.as_str()
                    || matches!(
                        normalized.strip_suffix(zone.as_str()),
                        Some(prefix) if prefix.ends_with('.')
                    )
            }))
        }
    }

    #[tokio::test]
    async fn test_mock_provider_create_record() {
        let provider = MockDnsProvider::new(vec!["example.com".to_string()]);
        let record_id = provider
            .create_txt_record("example.com", "_acme-challenge", "test-value")
            .await
            .unwrap();
        assert!(record_id.starts_with("record-"));
        assert_eq!(provider.record_count(), 1);
        let (stored_id, stored_value) = provider
            .get_record("example.com", "_acme-challenge")
            .unwrap();
        assert_eq!(stored_id, record_id);
        assert_eq!(stored_value, "test-value");
    }

    #[tokio::test]
    async fn test_mock_provider_record_ids_are_unique() {
        let provider = MockDnsProvider::new(vec!["example.com".to_string()]);
        let first = provider
            .create_txt_record("example.com", "_acme-challenge", "a")
            .await
            .unwrap();
        let second = provider
            .create_txt_record("sub.example.com", "_acme-challenge", "b")
            .await
            .unwrap();
        assert_ne!(first, second);
        assert_eq!(provider.record_count(), 2);
    }

    #[tokio::test]
    async fn test_mock_provider_delete_record() {
        let provider = MockDnsProvider::new(vec!["example.com".to_string()]);
        let record_id = provider
            .create_txt_record("example.com", "_acme-challenge", "test-value")
            .await
            .unwrap();
        assert_eq!(provider.record_count(), 1);
        provider
            .delete_txt_record("example.com", &record_id)
            .await
            .unwrap();
        assert_eq!(provider.record_count(), 0);
    }

    #[tokio::test]
    async fn test_mock_provider_delete_is_scoped_to_domain() {
        let provider = MockDnsProvider::new(vec!["example.com".to_string()]);
        let record_id = provider
            .create_txt_record("example.com", "_acme-challenge", "test-value")
            .await
            .unwrap();
        provider
            .delete_txt_record("other.com", &record_id)
            .await
            .unwrap();
        assert_eq!(provider.record_count(), 1);
        provider
            .delete_txt_record("example.com", &record_id)
            .await
            .unwrap();
        assert_eq!(provider.record_count(), 0);
    }

    #[tokio::test]
    async fn test_mock_provider_delete_unknown_record_is_noop() {
        let provider = MockDnsProvider::new(vec!["example.com".to_string()]);
        provider
            .create_txt_record("example.com", "_acme-challenge", "test-value")
            .await
            .unwrap();
        provider
            .delete_txt_record("example.com", "record-999")
            .await
            .unwrap();
        assert_eq!(provider.record_count(), 1);
    }

    #[tokio::test]
    async fn test_mock_provider_supports_domain() {
        let provider = MockDnsProvider::new(vec!["example.com".to_string()]);
        assert!(provider.supports_domain("example.com").await.unwrap());
        assert!(provider.supports_domain("sub.example.com").await.unwrap());
        assert!(provider.supports_domain("*.example.com").await.unwrap());
        assert!(!provider.supports_domain("other.com").await.unwrap());
        // A shared suffix without a label boundary is not a subdomain.
        assert!(!provider.supports_domain("badexample.com").await.unwrap());
    }

    #[tokio::test]
    async fn test_mock_provider_failure_on_create() {
        let provider =
            MockDnsProvider::new(vec!["example.com".to_string()]).with_failure_on_create();
        let result = provider
            .create_txt_record("example.com", "_acme-challenge", "test-value")
            .await;
        assert!(matches!(
            result,
            Err(DnsProviderError::RecordCreation { .. })
        ));
        // A failed create must not leave a record behind.
        assert_eq!(provider.record_count(), 0);
    }

    #[tokio::test]
    async fn test_mock_provider_failure_on_delete() {
        let provider =
            MockDnsProvider::new(vec!["example.com".to_string()]).with_failure_on_delete();
        let record_id = provider
            .create_txt_record("example.com", "_acme-challenge", "test-value")
            .await
            .unwrap();
        let result = provider.delete_txt_record("example.com", &record_id).await;
        assert!(matches!(
            result,
            Err(DnsProviderError::RecordDeletion { .. })
        ));
        // A failed delete must leave the record in place.
        assert_eq!(provider.record_count(), 1);
    }
}