use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use tokio::sync::Semaphore;
use trust_dns_resolver::config::{ResolverConfig, ResolverOpts};
use trust_dns_resolver::TokioAsyncResolver;
#[derive(Debug, Error, Clone)]
pub enum DnsError {
#[error("DNS resolver creation failed: {0}")]
ResolverCreation(String),
#[error("DNS lookup failed: {0}")]
LookupFailed(String),
#[error("DNS timeout after {0}ms")]
Timeout(u64),
#[error("DNS rate limit exceeded, try again later")]
RateLimited,
#[error(
"DNS verification failed: IP not in forward lookup results (possible rebinding attack)"
)]
IpMismatch,
}
impl From<trust_dns_resolver::error::ResolveError> for DnsError {
fn from(e: trust_dns_resolver::error::ResolveError) -> Self {
DnsError::ResolverCreation(e.to_string())
}
}
/// Async DNS resolver with a cap on concurrent lookups, offering reverse (PTR),
/// forward (A/AAAA) and forward-confirmed reverse lookups.
#[derive(Debug)]
pub struct DnsResolver {
    /// Underlying trust-dns Tokio resolver that performs the queries.
    resolver: TokioAsyncResolver,
    /// Deadline wrapped around each individual lookup call.
    timeout: Duration,
    /// Concurrency limiter; permits are taken with `try_acquire`, so callers over
    /// the limit fail fast instead of waiting.
    semaphore: Arc<Semaphore>,
    /// Number of permits the semaphore was created with (used for reporting).
    max_concurrent: usize,
}
impl DnsResolver {
    /// Creates a resolver backed by `ResolverConfig::default()` with a fixed
    /// concurrency cap.
    ///
    /// * `timeout_ms` – used both as the resolver's per-request timeout and as the
    ///   outer deadline wrapped around every public lookup call.
    /// * `max_concurrent` – maximum number of lookups in flight at once. Calls beyond
    ///   the cap fail immediately with [`DnsError::RateLimited`] rather than queueing;
    ///   a value of `0` therefore rejects every lookup.
    ///
    /// # Errors
    ///
    /// The current implementation never returns `Err`; the `Result` is part of the
    /// public signature.
    ///
    /// # Panics
    ///
    /// `tokio::sync::Semaphore::new` panics if `max_concurrent` exceeds tokio's
    /// maximum permit count.
    pub async fn new(timeout_ms: u64, max_concurrent: usize) -> Result<Self, DnsError> {
        let timeout = Duration::from_millis(timeout_ms);

        let mut opts = ResolverOpts::default();
        opts.timeout = timeout;
        // NOTE(review): the outer `tokio::time::timeout` below uses the same duration
        // as this per-request timeout, so a retry that follows a *timed-out* first
        // attempt presumably cannot complete before the outer deadline fires.
        opts.attempts = 2;

        let resolver = TokioAsyncResolver::tokio(ResolverConfig::default(), opts);
        Ok(Self {
            resolver,
            timeout,
            semaphore: Arc::new(Semaphore::new(max_concurrent)),
            max_concurrent,
        })
    }

    /// Number of lookup permits currently free.
    pub fn available_permits(&self) -> usize {
        self.semaphore.available_permits()
    }

    /// Configured maximum number of concurrent lookups.
    pub fn max_concurrent(&self) -> usize {
        self.max_concurrent
    }

    /// Tries to take a concurrency permit without waiting.
    ///
    /// Returns `None` (after logging a warning) when every permit is in use. The
    /// permit is released when the returned guard is dropped.
    async fn acquire_permit(&self) -> Option<tokio::sync::SemaphorePermit<'_>> {
        match self.semaphore.try_acquire() {
            Ok(permit) => Some(permit),
            Err(_) => {
                tracing::warn!(
                    "DNS rate limit reached: {}/{} permits in use",
                    // Permits may be released between the failed `try_acquire` and
                    // this read; saturate so the diagnostic can never underflow.
                    self.max_concurrent
                        .saturating_sub(self.semaphore.available_permits()),
                    self.max_concurrent
                );
                None
            }
        }
    }

    /// Looks up the PTR record for `ip` and returns the first hostname, without the
    /// trailing root dot(s).
    ///
    /// Returns `Ok(None)` when there is no usable PTR record or the resolver reports
    /// an error (logged at debug level).
    ///
    /// # Errors
    ///
    /// * [`DnsError::RateLimited`] if no concurrency permit is available.
    /// * [`DnsError::Timeout`] if the lookup exceeds the configured deadline.
    pub async fn reverse_lookup(&self, ip: IpAddr) -> Result<Option<String>, DnsError> {
        let _permit = self.acquire_permit().await.ok_or(DnsError::RateLimited)?;
        match tokio::time::timeout(self.timeout, self.resolver.reverse_lookup(ip)).await {
            Ok(Ok(response)) => Ok(response
                .iter()
                .next()
                .map(|record| record.to_string().trim_end_matches('.').to_string())
                // A PTR naming the root (".") trims to "", which is not a hostname
                // that can be forward-resolved.
                .filter(|name| !name.is_empty())),
            Ok(Err(e)) => {
                tracing::debug!("Reverse DNS lookup for {} failed: {}", ip, e);
                Ok(None)
            }
            Err(_) => {
                tracing::debug!(
                    "Reverse DNS lookup for {} timed out after {}ms",
                    ip,
                    self.timeout.as_millis()
                );
                Err(DnsError::Timeout(self.timeout.as_millis() as u64))
            }
        }
    }

    /// Resolves `hostname` to all of its IP addresses.
    ///
    /// # Errors
    ///
    /// * [`DnsError::RateLimited`] if no concurrency permit is available.
    /// * [`DnsError::LookupFailed`] if the resolver returns an error.
    /// * [`DnsError::Timeout`] if the lookup exceeds the configured deadline.
    pub async fn forward_lookup(&self, hostname: &str) -> Result<Vec<IpAddr>, DnsError> {
        let _permit = self.acquire_permit().await.ok_or(DnsError::RateLimited)?;
        match tokio::time::timeout(self.timeout, self.resolver.lookup_ip(hostname)).await {
            Ok(Ok(response)) => Ok(response.iter().collect()),
            Ok(Err(e)) => {
                tracing::debug!("Forward DNS lookup for {} failed: {}", hostname, e);
                Err(DnsError::LookupFailed(e.to_string()))
            }
            Err(_) => {
                tracing::debug!(
                    "Forward DNS lookup for {} timed out after {}ms",
                    hostname,
                    self.timeout.as_millis()
                );
                Err(DnsError::Timeout(self.timeout.as_millis() as u64))
            }
        }
    }

    /// Forward-confirmed reverse DNS check: `ip` is verified when its PTR hostname
    /// resolves back to a set of addresses containing `ip`.
    ///
    /// Returns `(verified, hostname)`.
    /// * `(false, None)` – no usable PTR record.
    /// * `(false, Some(h))` – `h` is an IP-address literal, the forward lookup failed
    ///   or timed out, or `ip` was not among the forward results.
    /// * `(true, Some(h))` – confirmed.
    ///
    /// # Errors
    ///
    /// * [`DnsError::RateLimited`] from either lookup.
    /// * [`DnsError::Timeout`] from the reverse lookup only; a forward-lookup timeout
    ///   is reported as "not verified".
    pub async fn verify_ip(&self, ip: IpAddr) -> Result<(bool, Option<String>), DnsError> {
        let hostname = match self.reverse_lookup(ip).await? {
            Some(h) => h,
            None => return Ok((false, None)),
        };

        // An IP-literal "hostname" would confirm itself without any DNS query.
        if Self::is_ip_literal(&hostname) {
            tracing::warn!(
                ip = %ip,
                hostname = %hostname,
                "DNS verification failed: PTR record is an IP address literal"
            );
            return Ok((false, Some(hostname)));
        }

        let resolved_ips = match self.forward_lookup(&hostname).await {
            Ok(ips) => ips,
            // Rate limiting is surfaced so callers can retry later; any other forward
            // failure simply means the address could not be verified.
            Err(DnsError::RateLimited) => return Err(DnsError::RateLimited),
            Err(_) => return Ok((false, Some(hostname))),
        };

        let verified = resolved_ips.contains(&ip);
        if !verified {
            tracing::warn!(
                ip = %ip,
                hostname = %hostname,
                resolved_ips = ?resolved_ips,
                "DNS rebinding check failed: requesting IP not in forward lookup results"
            );
        }
        Ok((verified, Some(hostname)))
    }

    /// Strict forward-confirmed reverse DNS check that returns the verified hostname
    /// or an error.
    ///
    /// # Errors
    ///
    /// * [`DnsError::LookupFailed`] if there is no usable PTR record, the PTR value is
    ///   an IP-address literal, or the forward lookup fails.
    /// * [`DnsError::IpMismatch`] if `ip` is not among the forward lookup results.
    /// * [`DnsError::RateLimited`] / [`DnsError::Timeout`] from either lookup.
    pub async fn verify_ip_strict(&self, ip: IpAddr) -> Result<String, DnsError> {
        let hostname = match self.reverse_lookup(ip).await? {
            Some(h) => h,
            None => return Err(DnsError::LookupFailed("No PTR record".to_string())),
        };

        // An IP-literal "hostname" would confirm itself without any DNS query.
        if Self::is_ip_literal(&hostname) {
            tracing::warn!(
                ip = %ip,
                hostname = %hostname,
                "DNS verification failed: PTR record is an IP address literal"
            );
            return Err(DnsError::LookupFailed(
                "PTR record is an IP address literal".to_string(),
            ));
        }

        let resolved_ips = self.forward_lookup(&hostname).await?;
        if !resolved_ips.contains(&ip) {
            tracing::warn!(
                ip = %ip,
                hostname = %hostname,
                resolved_ips = ?resolved_ips,
                "DNS rebinding attack detected: IP not in forward lookup results"
            );
            return Err(DnsError::IpMismatch);
        }
        Ok(hostname)
    }

    /// Returns `true` if `hostname` parses as an IPv4 or IPv6 address literal.
    ///
    /// Such names cannot be meaningfully forward-confirmed. trust-dns-resolver's
    /// `lookup_ip` answers IP-literal input with the literal itself instead of
    /// querying DNS, so a PTR record naming its own address would otherwise always
    /// "verify".
    fn is_ip_literal(hostname: &str) -> bool {
        hostname.parse::<IpAddr>().is_ok()
    }
}