use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::str::FromStr;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use hickory_proto::op::{Message, MessageType, OpCode, Query, ResponseCode};
use hickory_proto::rr::{Name, RData, RecordType};
use hickory_proto::ProtoError;
use hickory_resolver::config::ResolverConfig;
use hickory_resolver::net::runtime::TokioRuntimeProvider;
use hickory_resolver::TokioResolver;
use thiserror::Error;
use tokio::net::UdpSocket;
use tokio::runtime::Runtime;
use tokio::sync::Semaphore;
use tokio::time::timeout as tokio_timeout;
use crate::rdap::{self, RdapOutcome};
/// Outcome of asking a zone's authoritative nameservers about one name.
///
/// Every variant carries a human-readable `detail` explaining how the verdict
/// was reached (typically naming the nameserver that answered).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DnsVerdict {
    /// A nameserver answered `NOERROR` with at least one NS record in the
    /// answer or authority section (a delegation exists).
    Registered { detail: String },
    /// A nameserver answered `NXDOMAIN`.
    Available { detail: String },
    /// A nameserver answered `NOERROR` but returned no NS records.
    Nodata { detail: String },
    /// No usable answer was obtained: bad input, lookup errors, timeouts,
    /// unexpected response codes, or every nameserver in cooldown.
    Failure { detail: String },
}
/// Final verdict of a combined DNS (+ optional RDAP) check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FullVerdict {
    /// Verdict label; this module produces `"registered"`, `"available"` or `"failure"`.
    pub kind: &'static str,
    /// Human-readable explanation of how the verdict was reached.
    pub detail: String,
}
/// One unit of work for `DnsClient::check_full_batch`.
#[derive(Debug, Clone)]
pub struct FullCheckJob {
    /// Zone whose NS records are looked up; the name is then queried against
    /// those nameservers.
    pub zone: String,
    /// Name whose delegation status is checked.
    pub registered: String,
    /// RDAP base URL consulted only when DNS reports NODATA; `None` disables
    /// the RDAP fallback for this job.
    pub rdap_url: Option<String>,
}
/// Errors surfaced by this module.
///
/// `DnsClient::new` reports runtime, resolver and HTTP-client setup problems
/// as [`DnsError::Init`].
#[derive(Debug, Error)]
pub enum DnsError {
    /// Something needed by the client could not be constructed.
    #[error("resolver init: {0}")]
    Init(String),
    /// A DNS name failed to parse (converted automatically from `ProtoError`).
    #[error("name parse: {0}")]
    Name(#[from] ProtoError),
}
/// Maximum number of RDAP requests allowed in flight per RDAP host.
const RDAP_PER_HOST_CONCURRENCY: usize = 2;
/// Extra time an RDAP permit is held after its request finishes, spacing out
/// consecutive requests to the same host.
const RDAP_PER_HOST_MIN_GAP: Duration = Duration::from_millis(100);
/// How long a nameserver host stays blocked after a failure is recorded for it.
const NS_HOST_COOLDOWN: Duration = Duration::from_secs(5 * 60);
/// Cooldown registry for nameserver hosts, shared (via `Arc`) by concurrent checks.
///
/// A host counts as unhealthy until its recorded deadline passes. Expired
/// entries are evicted lazily, the next time that host is queried.
#[derive(Default)]
struct HostHealth {
    // host name -> (blocked until, reason the block was recorded)
    blocked: Mutex<HashMap<String, (Instant, String)>>,
}

impl HostHealth {
    /// Returns `false` while `host` is inside an active cooldown, `true` otherwise.
    /// A cooldown entry whose deadline has passed is removed as a side effect.
    fn is_healthy(&self, host: &str) -> bool {
        let mut blocked = self.blocked.lock().expect("host health lock");
        let cooling_down = blocked
            .get(host)
            .map(|(until, _)| Instant::now() < *until);
        match cooling_down {
            None => true,
            Some(true) => false,
            Some(false) => {
                blocked.remove(host);
                true
            }
        }
    }

    /// Blocks `host` for `NS_HOST_COOLDOWN` starting now, replacing any
    /// earlier entry for the same host.
    fn mark(&self, host: &str, reason: String) {
        let mut blocked = self.blocked.lock().expect("host health lock");
        let until = Instant::now() + NS_HOST_COOLDOWN;
        blocked.insert(host.to_owned(), (until, reason));
    }
}
/// Per-RDAP-host semaphores, created lazily on first use.
#[derive(Default)]
struct RdapThrottle {
    // RDAP host -> semaphore holding RDAP_PER_HOST_CONCURRENCY permits
    hosts: Mutex<HashMap<String, Arc<Semaphore>>>,
}

impl RdapThrottle {
    /// Returns the semaphore shared by every request to `host`, creating it
    /// with `RDAP_PER_HOST_CONCURRENCY` permits if this is the first request.
    fn permit_for(&self, host: &str) -> Arc<Semaphore> {
        let mut hosts = self.hosts.lock().expect("rdap throttle lock");
        let semaphore = hosts
            .entry(host.to_owned())
            .or_insert_with(|| Arc::new(Semaphore::new(RDAP_PER_HOST_CONCURRENCY)));
        Arc::clone(semaphore)
    }
}
/// Blocking facade over an async hickory resolver plus an RDAP HTTP client.
///
/// The client owns its own Tokio runtime, and every public method drives work
/// with `Runtime::block_on`. Tokio panics if `block_on` is called from inside
/// an async runtime context, so call these methods from synchronous code only.
pub struct DnsClient {
    /// Recursive resolver used for zone NS lookups and nameserver address lookups.
    resolver: TokioResolver,
    /// Runtime that every public method blocks on.
    runtime: Arc<Runtime>,
    /// HTTP client for RDAP lookups (built by `rdap::build_client`).
    http: reqwest::Client,
    /// Per-RDAP-host concurrency limiter, shared by every batch this client runs.
    rdap_throttle: Arc<RdapThrottle>,
    /// Nameserver cooldown registry, shared by every check this client runs.
    host_health: Arc<HostHealth>,
    /// Per-operation timeout for raw UDP send/receive; also given to the RDAP client.
    timeout: Duration,
}
impl DnsClient {
    /// Builds a client with its own Tokio runtime, resolver and RDAP HTTP client.
    ///
    /// The resolver is built from `TokioResolver::builder_tokio()`. If that
    /// builder cannot be created, it falls back to `ResolverConfig::default()`.
    /// `timeout` bounds each raw UDP send/receive to an authoritative server and
    /// is also passed to `rdap::build_client`.
    ///
    /// # Errors
    ///
    /// Returns [`DnsError::Init`] if the runtime, the resolver or the HTTP client
    /// cannot be created.
    pub fn new(timeout: Duration) -> Result<Self, DnsError> {
        let runtime = Runtime::new().map_err(|e| DnsError::Init(e.to_string()))?;
        let runtime = Arc::new(runtime);
        let resolver = runtime
            .block_on(async {
                let builder = TokioResolver::builder_tokio().unwrap_or_else(|_| {
                    TokioResolver::builder_with_config(
                        ResolverConfig::default(),
                        TokioRuntimeProvider::default(),
                    )
                });
                builder.build()
            })
            .map_err(|e| DnsError::Init(e.to_string()))?;
        let http = rdap::build_client(timeout).map_err(|e| DnsError::Init(e.to_string()))?;
        Ok(Self {
            resolver,
            runtime,
            http,
            rdap_throttle: Arc::new(RdapThrottle::default()),
            host_health: Arc::new(HostHealth::default()),
            timeout,
        })
    }

    /// Checks one name against `zone`'s authoritative nameservers (DNS only).
    ///
    /// # Panics
    ///
    /// Panics if called from within an async runtime context, because it uses
    /// `Runtime::block_on`.
    pub fn check_authoritative(&self, zone: &str, registered: &str) -> DnsVerdict {
        self.runtime.block_on(check_authoritative_async(
            &self.resolver,
            &self.host_health,
            zone,
            registered,
            self.timeout,
        ))
    }

    /// Checks every `(zone, registered)` pair with DNS only.
    ///
    /// At most `concurrency` checks run at once; `0` is treated as `1`.
    /// Results come back in input order. A check whose task fails to join
    /// (e.g. panics) is reported as [`DnsVerdict::Failure`].
    ///
    /// # Panics
    ///
    /// Panics if called from within an async runtime context (`Runtime::block_on`).
    pub fn check_batch(&self, pairs: Vec<(String, String)>, concurrency: usize) -> Vec<DnsVerdict> {
        let resolver = self.resolver.clone();
        let health = self.host_health.clone();
        let timeout = self.timeout;
        self.runtime.block_on(async move {
            let semaphore = Arc::new(Semaphore::new(concurrency.max(1)));
            let tasks: Vec<_> = pairs
                .into_iter()
                .map(|(zone, registered)| {
                    let sem = semaphore.clone();
                    let res = resolver.clone();
                    let h = health.clone();
                    tokio::spawn(async move {
                        // The semaphore is never closed, so acquiring cannot fail.
                        let _permit = sem.acquire_owned().await.expect("semaphore not closed");
                        check_authoritative_async(&res, &h, &zone, &registered, timeout).await
                    })
                })
                .collect();
            // Await in spawn order so output order matches input order.
            let mut out = Vec::with_capacity(tasks.len());
            for t in tasks {
                out.push(t.await.unwrap_or_else(|e| DnsVerdict::Failure {
                    detail: format!("task join: {e}"),
                }));
            }
            out
        })
    }

    /// Runs a DNS check for each job, falling back to RDAP on NODATA when the
    /// job has an RDAP URL.
    ///
    /// At most `concurrency` jobs run at once; `0` is treated as `1`. RDAP
    /// traffic is additionally limited per RDAP host. Results come back in
    /// input order, and a job whose task fails to join is reported with kind
    /// `"failure"`.
    ///
    /// # Panics
    ///
    /// Panics if called from within an async runtime context (`Runtime::block_on`).
    pub fn check_full_batch(
        &self,
        jobs: Vec<FullCheckJob>,
        concurrency: usize,
    ) -> Vec<FullVerdict> {
        let resolver = self.resolver.clone();
        let http = self.http.clone();
        let throttle = self.rdap_throttle.clone();
        let health = self.host_health.clone();
        let timeout = self.timeout;
        self.runtime.block_on(async move {
            let semaphore = Arc::new(Semaphore::new(concurrency.max(1)));
            let tasks: Vec<_> = jobs
                .into_iter()
                .map(|job| {
                    let sem = semaphore.clone();
                    let res = resolver.clone();
                    let http = http.clone();
                    let throttle = throttle.clone();
                    let h = health.clone();
                    tokio::spawn(async move {
                        // The semaphore is never closed, so acquiring cannot fail.
                        let _permit = sem.acquire_owned().await.expect("semaphore not closed");
                        run_full_job(&res, &http, &throttle, &h, job, timeout).await
                    })
                })
                .collect();
            // Await in spawn order so output order matches input order.
            let mut out = Vec::with_capacity(tasks.len());
            for t in tasks {
                out.push(t.await.unwrap_or_else(|e| FullVerdict {
                    kind: "failure",
                    detail: format!("task join: {e}"),
                }));
            }
            out
        })
    }
}
async fn run_full_job(
resolver: &TokioResolver,
http: &reqwest::Client,
throttle: &RdapThrottle,
health: &HostHealth,
job: FullCheckJob,
timeout: Duration,
) -> FullVerdict {
let dns =
check_authoritative_async(resolver, health, &job.zone, &job.registered, timeout).await;
match dns {
DnsVerdict::Registered { detail } => FullVerdict {
kind: "registered",
detail,
},
DnsVerdict::Available { detail } => FullVerdict {
kind: "available",
detail,
},
DnsVerdict::Failure { detail } => FullVerdict {
kind: "failure",
detail,
},
DnsVerdict::Nodata { detail } => match job.rdap_url {
Some(base) => {
let host = rdap_host(&base).unwrap_or_else(|| base.clone());
let semaphore = throttle.permit_for(&host);
let permit = semaphore.acquire_owned().await.expect("rdap permit");
let outcome = rdap::lookup(http, &job.registered, &base).await;
tokio::time::sleep(RDAP_PER_HOST_MIN_GAP).await;
drop(permit);
match outcome {
Some(RdapOutcome::Registered { detail: rdetail }) => FullVerdict {
kind: "registered",
detail: format!("{detail}; {rdetail}"),
},
Some(RdapOutcome::Available { detail: rdetail }) => FullVerdict {
kind: "available",
detail: format!("{detail}; {rdetail}"),
},
None => FullVerdict {
kind: "available",
detail: format!("{detail}; RDAP inconclusive"),
},
}
}
None => FullVerdict {
kind: "available",
detail: format!("{detail} (no RDAP configured for zone '{}')", job.zone),
},
},
}
}
/// Extracts the lowercase authority (`host[:port]`) of an RDAP base URL, used
/// as the per-host throttling key.
///
/// Works with or without a scheme (`https://host/path` or `host/path`). The
/// authority ends at the first `/`, `?` or `#`. Returns `None` when the URL
/// has no authority (e.g. `"https://"`, `""` or `"/rdap"`), so callers can
/// fall back to a different key.
fn rdap_host(base_url: &str) -> Option<String> {
    // Only the first "://" separates the scheme; later ones may appear in a query.
    let rest = base_url
        .split_once("://")
        .map_or(base_url, |(_, after_scheme)| after_scheme);
    let authority = rest
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or("");
    (!authority.is_empty()).then(|| authority.to_ascii_lowercase())
}
/// Asks `zone`'s authoritative nameservers whether `registered` is delegated.
///
/// Steps:
/// 1. Parse both names, failing fast (before any network I/O) on bad input.
/// 2. Resolve the zone's NS set through the recursive `resolver`.
/// 3. Skip nameservers currently in cooldown according to `health`.
/// 4. For each remaining nameserver, resolve an address (IPv4 preferred),
///    send a non-recursive NS query, and return the first classifiable answer.
///
/// A nameserver whose address cannot be resolved, or whose query fails, is put
/// into cooldown. A response with an rcode other than NOERROR/NXDOMAIN moves on
/// to the next nameserver without a cooldown. If no nameserver yields a
/// verdict, the last error seen is returned as [`DnsVerdict::Failure`].
async fn check_authoritative_async(
    resolver: &TokioResolver,
    health: &HostHealth,
    zone: &str,
    registered: &str,
    timeout: Duration,
) -> DnsVerdict {
    let zone_name = match Name::from_str(zone) {
        Ok(n) => n,
        Err(e) => {
            return DnsVerdict::Failure {
                detail: format!("bad zone {zone}: {e}"),
            }
        }
    };
    // Validate the queried name up front so malformed input never costs a
    // network round-trip for the zone's NS set.
    let registered_name = match Name::from_str(registered) {
        Ok(n) => n,
        Err(e) => {
            return DnsVerdict::Failure {
                detail: format!("bad domain {registered}: {e}"),
            }
        }
    };

    let all_nameservers: Vec<String> = match resolver.lookup(zone_name, RecordType::NS).await {
        Ok(lookup) => lookup
            .answers()
            .iter()
            .filter_map(|record| match &record.data {
                // Drop the trailing root dot. The trimmed name is used for the
                // address lookup, as the health key, and in detail messages.
                RData::NS(ns) => Some(ns.0.to_utf8().trim_end_matches('.').to_string()),
                _ => None,
            })
            .collect(),
        Err(e) => {
            return DnsVerdict::Failure {
                detail: format!("zone NS lookup failed: {e}"),
            }
        }
    };
    if all_nameservers.is_empty() {
        return DnsVerdict::Failure {
            detail: format!("zone {zone} has no NS"),
        };
    }

    let healthy: Vec<String> = all_nameservers
        .into_iter()
        .filter(|ns| health.is_healthy(ns))
        .collect();
    if healthy.is_empty() {
        return DnsVerdict::Failure {
            detail: format!("all {zone} nameservers in cooldown"),
        };
    }

    let mut last_error: Option<String> = None;
    for ns in healthy {
        let ip = match resolver.lookup_ip(&ns).await {
            Ok(addrs) => match addrs
                .iter()
                .find(|ip| ip.is_ipv4())
                .or_else(|| addrs.iter().next())
            {
                Some(ip) => ip,
                None => {
                    let reason = format!("no A/AAAA for {ns}");
                    health.mark(&ns, reason.clone());
                    last_error = Some(reason);
                    continue;
                }
            },
            Err(e) => {
                health.mark(&ns, format!("resolve: {e}"));
                last_error = Some(format!("resolve {ns}: {e}"));
                continue;
            }
        };
        match query_authoritative_async(&registered_name, ip, timeout).await {
            Ok(response) => match classify(&response, &ns) {
                Some(verdict) => return verdict,
                None => {
                    // The server answered, but not with a classifiable rcode:
                    // try the next one without putting this host in cooldown.
                    last_error = Some(format!(
                        "unexpected rcode {:?} via {ns}",
                        response.metadata.response_code
                    ));
                }
            },
            Err(e) => {
                health.mark(&ns, format!("query: {e}"));
                last_error = Some(format!("query {ns}: {e}"));
            }
        }
    }
    DnsVerdict::Failure {
        detail: last_error.unwrap_or_else(|| "all nameservers failed".to_string()),
    }
}
async fn query_authoritative_async(
name: &Name,
ip: IpAddr,
timeout: Duration,
) -> Result<Message, String> {
let mut msg = Message::new(rand::random::<u16>(), MessageType::Query, OpCode::Query);
msg.metadata.recursion_desired = false;
msg.add_query(Query::query(name.clone(), RecordType::NS));
let bytes = msg.to_vec().map_err(|e| e.to_string())?;
let bind_addr = if ip.is_ipv6() { "[::]:0" } else { "0.0.0.0:0" };
let sock = UdpSocket::bind(bind_addr)
.await
.map_err(|e| e.to_string())?;
sock.connect(SocketAddr::new(ip, 53))
.await
.map_err(|e| e.to_string())?;
let send = sock.send(&bytes);
tokio_timeout(timeout, send)
.await
.map_err(|_| "send timeout".to_string())?
.map_err(|e| e.to_string())?;
let mut buf = vec![0u8; 4096];
let recv = sock.recv(&mut buf);
let n = tokio_timeout(timeout, recv)
.await
.map_err(|_| "recv timeout".to_string())?
.map_err(|e| e.to_string())?;
Message::from_vec(&buf[..n]).map_err(|e| e.to_string())
}
/// Maps an authoritative server's reply to a verdict.
///
/// - `NXDOMAIN` → [`DnsVerdict::Available`]
/// - `NOERROR` with any NS record in the answer or authority section →
///   [`DnsVerdict::Registered`]
/// - `NOERROR` without NS records → [`DnsVerdict::Nodata`]
/// - any other response code → `None`; `check_authoritative_async` then moves
///   on to the next nameserver
///
/// `ns` only labels the verdict's `detail`.
fn classify(response: &Message, ns: &str) -> Option<DnsVerdict> {
    let verdict = match response.metadata.response_code {
        ResponseCode::NXDomain => DnsVerdict::Available {
            detail: format!("NXDOMAIN via {ns}"),
        },
        ResponseCode::NoError => {
            let delegated = response
                .answers
                .iter()
                .chain(response.authorities.iter())
                .any(|record| record.record_type() == RecordType::NS);
            if delegated {
                DnsVerdict::Registered {
                    detail: format!("delegation via {ns}"),
                }
            } else {
                DnsVerdict::Nodata {
                    detail: format!("NODATA via {ns}"),
                }
            }
        }
        _ => return None,
    };
    Some(verdict)
}