use std::net::IpAddr;
use std::str::FromStr;
use std::time::Duration;
use hickory_resolver::config::{NameServerConfig, ResolveHosts, ResolverConfig, GOOGLE};
use hickory_resolver::net::runtime::TokioRuntimeProvider;
use hickory_resolver::net::NetError;
use hickory_resolver::proto::dnssec::PublicKey;
use hickory_resolver::proto::rr::rdata::CAA;
use hickory_resolver::proto::rr::{RData as HickoryRData, RecordType as HickoryRecordType};
use hickory_resolver::TokioResolver;
use tracing::{debug, instrument};
use super::records::{DnsRecord, RecordData, RecordType};
use crate::error::{Result, SeerError};
use crate::validation::normalize_domain;
/// Maps a raw hickory lookup result onto this crate's error model.
///
/// * `Ok(response)` becomes `Ok(Some(response))`.
/// * A "no records found" error becomes `Ok(None)`, because an empty answer is a
///   normal outcome for a DNS query.
/// * Any other failure becomes a [`SeerError::DnsError`] whose message is prefixed
///   with `record_type` (e.g. `"MX"`).
fn dns_lookup_or_empty<T>(
    result: std::result::Result<T, NetError>,
    record_type: &str,
) -> Result<Option<T>> {
    result.map(Some).or_else(|err| {
        if err.is_no_records_found() {
            Ok(None)
        } else {
            Err(SeerError::DnsError(format!(
                "{} lookup failed: {}",
                record_type, err
            )))
        }
    })
}
/// Query timeout used by `DnsResolver::new` until overridden with `with_timeout`.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Builds a Tokio-backed hickory resolver for `config`.
///
/// Every resolver built here uses the given per-query `timeout`, makes two
/// attempts, and never consults the local hosts file.
///
/// # Panics
/// Only if hickory's builder fails, which the `expect` message records as
/// impossible without TLS features.
fn build_resolver(config: ResolverConfig, timeout: Duration) -> TokioResolver {
    let mut resolver_builder =
        TokioResolver::builder_with_config(config, TokioRuntimeProvider::default());
    let options = resolver_builder.options_mut();
    options.timeout = timeout;
    options.attempts = 2;
    // Answers must come from DNS itself, not from /etc/hosts overrides.
    options.use_hosts_file = ResolveHosts::Never;
    resolver_builder
        .build()
        .expect("hickory resolver build is infallible without TLS features")
}
/// Async DNS lookup client built on hickory.
///
/// By default, queries go to Google public DNS over UDP and TCP, and the local
/// hosts file is ignored. Individual calls may name a different nameserver
/// instead.
#[derive(Clone)]
pub struct DnsResolver {
    /// Resolver used whenever a call does not supply its own nameserver.
    default_resolver: TokioResolver,
    /// Query timeout for the default resolver and for per-call resolvers.
    timeout: Duration,
}
/// Hand-written `Debug` that prints only the timeout.
///
/// The inner hickory resolver is intentionally left out. `finish_non_exhaustive`
/// renders a trailing `..`, so the output does not pretend to show every field.
impl std::fmt::Debug for DnsResolver {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DnsResolver")
            .field("timeout", &self.timeout)
            .finish_non_exhaustive()
    }
}
impl Default for DnsResolver {
fn default() -> Self {
Self::new()
}
}
impl DnsResolver {
/// Creates a resolver that queries Google public DNS (UDP and TCP) using the
/// module's default timeout.
pub fn new() -> Self {
    let timeout = DEFAULT_TIMEOUT;
    let default_resolver = build_resolver(ResolverConfig::udp_and_tcp(&GOOGLE), timeout);
    Self {
        timeout,
        default_resolver,
    }
}
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
self.default_resolver = build_resolver(ResolverConfig::udp_and_tcp(&GOOGLE), timeout);
self
}
/// Builds a one-off resolver that sends every query to `nameserver`.
///
/// `nameserver` may be an IP literal or a hostname. A hostname is resolved through
/// the default resolver, and every address it maps to becomes a nameserver. The
/// request is refused if any of those addresses is reserved (e.g. loopback or
/// private ranges), so callers cannot aim queries at internal infrastructure.
///
/// Each address is configured for both UDP and TCP, matching the default
/// resolver. With UDP alone, a truncated answer (large TXT or DNSKEY sets) has no
/// TCP transport to be retried on.
///
/// # Errors
/// [`SeerError::DnsError`] if the hostname fails to resolve, resolves to no
/// addresses, or any resulting address is reserved.
async fn create_custom_resolver(&self, nameserver: &str) -> Result<TokioResolver> {
    let ips: Vec<IpAddr> = if let Ok(ip) = nameserver.parse::<IpAddr>() {
        vec![ip]
    } else {
        let response = self
            .default_resolver
            .lookup_ip(nameserver)
            .await
            .map_err(|e| {
                SeerError::DnsError(format!(
                    "failed to resolve nameserver hostname {}: {}",
                    nameserver, e
                ))
            })?;
        let resolved: Vec<IpAddr> = response.iter().collect();
        if resolved.is_empty() {
            return Err(SeerError::DnsError(format!(
                "nameserver {} did not resolve to any addresses",
                nameserver
            )));
        }
        resolved
    };
    // All-or-nothing: one reserved address rejects the whole nameserver.
    for ip in &ips {
        if let Some(reason) = crate::validation::describe_reserved_ip(ip) {
            return Err(SeerError::DnsError(format!(
                "nameserver {} blocked: {}",
                nameserver, reason
            )));
        }
    }
    let mut config = ResolverConfig::from_parts(None, vec![], vec![]);
    for ip in ips {
        // UDP + TCP so truncated UDP answers can be retried over TCP.
        config.add_name_server(NameServerConfig::udp_and_tcp(ip));
    }
    Ok(build_resolver(config, self.timeout))
}
/// Resolves `record_type` records for `domain`, optionally through a specific
/// nameserver.
///
/// `domain` is normalized first; PTR queries also accept a bare IP address. SRV
/// queries expect the `_service._proto.name` form. All of this input validation
/// happens before any network I/O, so malformed input never triggers a nameserver
/// hostname lookup, and its error is never masked by a nameserver error.
///
/// An empty `Vec` means the name has no records of that type.
///
/// # Errors
/// - [`SeerError::InvalidInput`] for a malformed SRV name or invalid SRV labels.
/// - Any error from domain normalization.
/// - [`SeerError::DnsError`] if the nameserver is rejected or the lookup fails.
#[instrument(skip(self), fields(domain = %domain, record_type = %record_type))]
pub async fn resolve(
    &self,
    domain: &str,
    record_type: RecordType,
    nameserver: Option<&str>,
) -> Result<Vec<DnsRecord>> {
    let srv_format_error = || {
        SeerError::InvalidInput(
            "SRV records require service name format: _service._proto.name".to_string(),
        )
    };
    // Validate locally before doing any network I/O.
    let domain = prepare_query(domain, record_type)?;
    let srv_parts = if record_type == RecordType::SRV {
        Some(parse_srv_query(&domain).ok_or_else(srv_format_error)?)
    } else {
        None
    };
    let custom_resolver;
    let resolver = if let Some(ns) = nameserver {
        custom_resolver = self.create_custom_resolver(ns).await?;
        &custom_resolver
    } else {
        &self.default_resolver
    };
    // Without a custom nameserver, queries go to the built-in default resolver
    // (Google public DNS), not the system resolver.
    debug!(nameserver = nameserver.unwrap_or("default"), "Resolving DNS");
    match record_type {
        RecordType::A => self.resolve_a(resolver, &domain).await,
        RecordType::AAAA => self.resolve_aaaa(resolver, &domain).await,
        RecordType::CNAME => self.resolve_cname(resolver, &domain).await,
        RecordType::MX => self.resolve_mx(resolver, &domain).await,
        RecordType::NS => self.resolve_ns(resolver, &domain).await,
        RecordType::TXT => self.resolve_txt(resolver, &domain).await,
        RecordType::SOA => self.resolve_soa(resolver, &domain).await,
        RecordType::PTR => self.resolve_ptr(resolver, &domain).await,
        RecordType::SRV => {
            // Always `Some` here: SRV names were parsed (or rejected) above.
            let Some((service, protocol, name)) = srv_parts else {
                return Err(srv_format_error());
            };
            self.resolve_srv_core(resolver, &service, &protocol, &name)
                .await
        }
        RecordType::CAA => self.resolve_caa(resolver, &domain).await,
        RecordType::DNSKEY => self.resolve_dnskey(resolver, &domain).await,
        RecordType::DS => self.resolve_ds(resolver, &domain).await,
        RecordType::TLSA => self.resolve_tlsa(resolver, &domain).await,
        RecordType::SSHFP => self.resolve_sshfp(resolver, &domain).await,
        RecordType::NAPTR => self.resolve_naptr(resolver, &domain).await,
        RecordType::ANY => self.resolve_any(resolver, &domain).await,
    }
}
/// Resolves SRV records for `_{service}._{protocol}.{domain}`, optionally through a
/// specific nameserver.
///
/// `service` and `protocol` are given without their leading underscores and must
/// be valid labels. `domain` is used exactly as passed and is not normalized.
///
/// # Errors
/// [`SeerError::InvalidInput`] for invalid labels; [`SeerError::DnsError`] if the
/// nameserver is rejected or the lookup fails.
#[instrument(skip(self), fields(domain = %domain, service = %service, protocol = %protocol))]
pub async fn resolve_srv(
    &self,
    service: &str,
    protocol: &str,
    domain: &str,
    nameserver: Option<&str>,
) -> Result<Vec<DnsRecord>> {
    let custom_resolver = match nameserver {
        Some(ns) => Some(self.create_custom_resolver(ns).await?),
        None => None,
    };
    let resolver = custom_resolver.as_ref().unwrap_or(&self.default_resolver);
    self.resolve_srv_core(resolver, service, protocol, domain)
        .await
}
/// Shared SRV implementation: validates both labels, then queries
/// `_{service}._{protocol}.{domain}` on `resolver`.
///
/// Each returned record is named after the full query name. An empty answer yields
/// an empty `Vec`.
async fn resolve_srv_core(
    &self,
    resolver: &TokioResolver,
    service: &str,
    protocol: &str,
    domain: &str,
) -> Result<Vec<DnsRecord>> {
    // Label validation keeps callers from smuggling extra labels (e.g. dots) into
    // the query name.
    if !is_valid_srv_label(service) {
        return Err(SeerError::InvalidInput(format!(
            "invalid SRV service name: {}",
            service
        )));
    }
    if !is_valid_srv_label(protocol) {
        return Err(SeerError::InvalidInput(format!(
            "invalid SRV protocol name: {}",
            protocol
        )));
    }
    let query_name = format!("_{}._{}.{}", service, protocol, domain);
    let lookup = resolver.lookup(&query_name, HickoryRecordType::SRV).await;
    let response = match dns_lookup_or_empty(lookup, "SRV")? {
        Some(response) => response,
        None => return Ok(Vec::new()),
    };
    let records = response
        .answers()
        .iter()
        .filter_map(|answer| match &answer.data {
            HickoryRData::SRV(srv) => Some(DnsRecord {
                name: query_name.clone(),
                record_type: RecordType::SRV,
                ttl: answer.ttl,
                data: RecordData::SRV {
                    priority: srv.priority,
                    weight: srv.weight,
                    port: srv.port,
                    target: srv.target.to_string(),
                },
            }),
            _ => None,
        })
        .collect();
    Ok(records)
}
/// Looks up A records for `domain`.
///
/// Answers of any other type are skipped, and each record is labeled with the
/// queried `domain`. An empty answer yields an empty `Vec`.
async fn resolve_a(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
    let lookup = resolver.lookup(domain, HickoryRecordType::A).await;
    let response = match dns_lookup_or_empty(lookup, "A")? {
        Some(response) => response,
        None => return Ok(Vec::new()),
    };
    let records = response
        .answers()
        .iter()
        .filter_map(|answer| match &answer.data {
            HickoryRData::A(ipv4) => Some(DnsRecord {
                name: domain.to_string(),
                record_type: RecordType::A,
                ttl: answer.ttl,
                data: RecordData::A {
                    address: ipv4.0.to_string(),
                },
            }),
            _ => None,
        })
        .collect();
    Ok(records)
}
/// Looks up AAAA records for `domain`.
///
/// Answers of any other type are skipped, and each record is labeled with the
/// queried `domain`. An empty answer yields an empty `Vec`.
async fn resolve_aaaa(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
    let lookup = resolver.lookup(domain, HickoryRecordType::AAAA).await;
    let Some(response) = dns_lookup_or_empty(lookup, "AAAA")? else {
        return Ok(Vec::new());
    };
    let records = response
        .answers()
        .iter()
        .filter_map(|answer| match &answer.data {
            HickoryRData::AAAA(ipv6) => Some(DnsRecord {
                name: domain.to_string(),
                record_type: RecordType::AAAA,
                ttl: answer.ttl,
                data: RecordData::AAAA {
                    address: ipv6.0.to_string(),
                },
            }),
            _ => None,
        })
        .collect();
    Ok(records)
}
/// Looks up CNAME records for `domain`.
///
/// Answers of any other type are skipped, and each record is labeled with the
/// queried `domain`. An empty answer yields an empty `Vec`.
async fn resolve_cname(
    &self,
    resolver: &TokioResolver,
    domain: &str,
) -> Result<Vec<DnsRecord>> {
    let lookup = resolver.lookup(domain, HickoryRecordType::CNAME).await;
    let response = match dns_lookup_or_empty(lookup, "CNAME")? {
        Some(response) => response,
        None => return Ok(Vec::new()),
    };
    let records = response
        .answers()
        .iter()
        .filter_map(|answer| match &answer.data {
            HickoryRData::CNAME(alias) => Some(DnsRecord {
                name: domain.to_string(),
                record_type: RecordType::CNAME,
                ttl: answer.ttl,
                data: RecordData::CNAME {
                    target: alias.0.to_string(),
                },
            }),
            _ => None,
        })
        .collect();
    Ok(records)
}
/// Looks up MX records for `domain`, ordered by ascending preference.
///
/// Answers of any other type are skipped. An empty answer yields an empty `Vec`.
async fn resolve_mx(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
    let lookup = resolver.lookup(domain, HickoryRecordType::MX).await;
    let Some(response) = dns_lookup_or_empty(lookup, "MX")? else {
        return Ok(Vec::new());
    };
    let mut records: Vec<DnsRecord> = response
        .answers()
        .iter()
        .filter_map(|answer| match &answer.data {
            HickoryRData::MX(mx) => Some(DnsRecord {
                name: domain.to_string(),
                record_type: RecordType::MX,
                ttl: answer.ttl,
                data: RecordData::MX {
                    preference: mx.preference,
                    exchange: mx.exchange.to_string(),
                },
            }),
            _ => None,
        })
        .collect();
    // Stable sort: equal preferences keep the order the server returned them in.
    records.sort_by_key(|record| match &record.data {
        RecordData::MX { preference, .. } => *preference,
        _ => 0,
    });
    Ok(records)
}
/// Looks up NS records for `domain`.
///
/// Answers of any other type are skipped, and each record is labeled with the
/// queried `domain`. An empty answer yields an empty `Vec`.
async fn resolve_ns(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
    let lookup = resolver.lookup(domain, HickoryRecordType::NS).await;
    let response = match dns_lookup_or_empty(lookup, "NS")? {
        Some(response) => response,
        None => return Ok(Vec::new()),
    };
    let records = response
        .answers()
        .iter()
        .filter_map(|answer| match &answer.data {
            HickoryRData::NS(server) => Some(DnsRecord {
                name: domain.to_string(),
                record_type: RecordType::NS,
                ttl: answer.ttl,
                data: RecordData::NS {
                    nameserver: server.0.to_string(),
                },
            }),
            _ => None,
        })
        .collect();
    Ok(records)
}
/// Looks up TXT records for `domain`.
///
/// The character-strings inside a single TXT record are concatenated with no
/// separator. Bytes that are not valid UTF-8 are replaced lossily. An empty answer
/// yields an empty `Vec`.
async fn resolve_txt(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
    let lookup = resolver.lookup(domain, HickoryRecordType::TXT).await;
    let Some(response) = dns_lookup_or_empty(lookup, "TXT")? else {
        return Ok(Vec::new());
    };
    let records = response
        .answers()
        .iter()
        .filter_map(|answer| match &answer.data {
            HickoryRData::TXT(txt) => {
                let text: String = txt
                    .txt_data
                    .iter()
                    .map(|chunk| String::from_utf8_lossy(chunk))
                    .collect();
                Some(DnsRecord {
                    name: domain.to_string(),
                    record_type: RecordType::TXT,
                    ttl: answer.ttl,
                    data: RecordData::TXT { text },
                })
            }
            _ => None,
        })
        .collect();
    Ok(records)
}
/// Looks up the SOA record(s) for `domain`.
///
/// `refresh`, `retry` and `expire` are converted with `as u32`.
/// NOTE(review): if hickory exposes these as signed integers, `as` reinterprets
/// the bits — confirm that is the intended mapping.
async fn resolve_soa(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
    let lookup = resolver.lookup(domain, HickoryRecordType::SOA).await;
    let response = match dns_lookup_or_empty(lookup, "SOA")? {
        Some(response) => response,
        None => return Ok(Vec::new()),
    };
    let records = response
        .answers()
        .iter()
        .filter_map(|answer| match &answer.data {
            HickoryRData::SOA(soa) => Some(DnsRecord {
                name: domain.to_string(),
                record_type: RecordType::SOA,
                ttl: answer.ttl,
                data: RecordData::SOA {
                    mname: soa.mname.to_string(),
                    rname: soa.rname.to_string(),
                    serial: soa.serial,
                    refresh: soa.refresh as u32,
                    retry: soa.retry as u32,
                    expire: soa.expire as u32,
                    minimum: soa.minimum,
                },
            }),
            _ => None,
        })
        .collect();
    Ok(records)
}
/// Looks up PTR records.
///
/// `query` may be an IP address, which is turned into its `in-addr.arpa` or
/// `ip6.arpa` name, or a name that is used as-is. Records are labeled with the
/// name actually queried.
async fn resolve_ptr(&self, resolver: &TokioResolver, query: &str) -> Result<Vec<DnsRecord>> {
    let query = match IpAddr::from_str(query) {
        Ok(ip) => reverse_dns_name(&ip),
        Err(_) => query.to_string(),
    };
    let lookup = resolver.lookup(&query, HickoryRecordType::PTR).await;
    let Some(response) = dns_lookup_or_empty(lookup, "PTR")? else {
        return Ok(Vec::new());
    };
    let records = response
        .answers()
        .iter()
        .filter_map(|answer| match &answer.data {
            HickoryRData::PTR(pointer) => Some(DnsRecord {
                name: query.clone(),
                record_type: RecordType::PTR,
                ttl: answer.ttl,
                data: RecordData::PTR {
                    target: pointer.0.to_string(),
                },
            }),
            _ => None,
        })
        .collect();
    Ok(records)
}
/// Looks up CAA records for `domain`, flattening each one with `parse_caa`.
///
/// An empty answer yields an empty `Vec`.
async fn resolve_caa(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
    let lookup = resolver.lookup(domain, HickoryRecordType::CAA).await;
    let response = match dns_lookup_or_empty(lookup, "CAA")? {
        Some(response) => response,
        None => return Ok(Vec::new()),
    };
    let records = response
        .answers()
        .iter()
        .filter_map(|answer| match &answer.data {
            HickoryRData::CAA(caa) => {
                let (flags, tag, value) = parse_caa(caa);
                Some(DnsRecord {
                    name: domain.to_string(),
                    record_type: RecordType::CAA,
                    ttl: answer.ttl,
                    data: RecordData::CAA { flags, tag, value },
                })
            }
            _ => None,
        })
        .collect();
    Ok(records)
}
/// Looks up DNSKEY records for `domain`.
///
/// The public key is base64-encoded (standard alphabet). The protocol field is
/// reported as the constant 3 rather than read from the record.
async fn resolve_dnskey(
    &self,
    resolver: &TokioResolver,
    domain: &str,
) -> Result<Vec<DnsRecord>> {
    use base64::{engine::general_purpose::STANDARD, Engine};
    use hickory_resolver::proto::dnssec::rdata::DNSSECRData;
    let lookup = resolver.lookup(domain, HickoryRecordType::DNSKEY).await;
    let Some(response) = dns_lookup_or_empty(lookup, "DNSKEY")? else {
        return Ok(Vec::new());
    };
    let records = response
        .answers()
        .iter()
        .filter_map(|answer| match &answer.data {
            HickoryRData::DNSSEC(DNSSECRData::DNSKEY(dnskey)) => {
                let key = dnskey.public_key();
                let public_key = STANDARD.encode(key.public_bytes());
                Some(DnsRecord {
                    name: domain.to_string(),
                    record_type: RecordType::DNSKEY,
                    ttl: answer.ttl,
                    data: RecordData::DNSKEY {
                        flags: dnskey.flags(),
                        protocol: 3,
                        algorithm: u8::from(key.algorithm()),
                        public_key,
                    },
                })
            }
            _ => None,
        })
        .collect();
    Ok(records)
}
/// Looks up DS records for `domain`; the digest is rendered as uppercase hex.
///
/// An empty answer yields an empty `Vec`.
async fn resolve_ds(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
    use hickory_resolver::proto::dnssec::rdata::DNSSECRData;
    let lookup = resolver.lookup(domain, HickoryRecordType::DS).await;
    let response = match dns_lookup_or_empty(lookup, "DS")? {
        Some(response) => response,
        None => return Ok(Vec::new()),
    };
    let records = response
        .answers()
        .iter()
        .filter_map(|answer| match &answer.data {
            HickoryRData::DNSSEC(DNSSECRData::DS(ds)) => {
                let digest: String = ds
                    .digest()
                    .iter()
                    .map(|byte| format!("{:02X}", byte))
                    .collect();
                Some(DnsRecord {
                    name: domain.to_string(),
                    record_type: RecordType::DS,
                    ttl: answer.ttl,
                    data: RecordData::DS {
                        key_tag: ds.key_tag(),
                        algorithm: u8::from(ds.algorithm()),
                        digest_type: u8::from(ds.digest_type()),
                        digest,
                    },
                })
            }
            _ => None,
        })
        .collect();
    Ok(records)
}
/// Looks up TLSA records for `domain`; certificate data is rendered as
/// uppercase hex.
///
/// An empty answer yields an empty `Vec`.
async fn resolve_tlsa(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
    let lookup = resolver.lookup(domain, HickoryRecordType::TLSA).await;
    let Some(response) = dns_lookup_or_empty(lookup, "TLSA")? else {
        return Ok(Vec::new());
    };
    let records = response
        .answers()
        .iter()
        .filter_map(|answer| match &answer.data {
            HickoryRData::TLSA(tlsa) => {
                let cert_data: String = tlsa
                    .cert_data
                    .iter()
                    .map(|byte| format!("{:02X}", byte))
                    .collect();
                Some(DnsRecord {
                    name: domain.to_string(),
                    record_type: RecordType::TLSA,
                    ttl: answer.ttl,
                    data: RecordData::TLSA {
                        cert_usage: u8::from(tlsa.cert_usage),
                        selector: u8::from(tlsa.selector),
                        matching: u8::from(tlsa.matching),
                        cert_data,
                    },
                })
            }
            _ => None,
        })
        .collect();
    Ok(records)
}
/// Looks up SSHFP records for `domain`; the fingerprint is rendered as
/// uppercase hex.
///
/// An empty answer yields an empty `Vec`.
async fn resolve_sshfp(
    &self,
    resolver: &TokioResolver,
    domain: &str,
) -> Result<Vec<DnsRecord>> {
    let lookup = resolver.lookup(domain, HickoryRecordType::SSHFP).await;
    let response = match dns_lookup_or_empty(lookup, "SSHFP")? {
        Some(response) => response,
        None => return Ok(Vec::new()),
    };
    let records = response
        .answers()
        .iter()
        .filter_map(|answer| match &answer.data {
            HickoryRData::SSHFP(sshfp) => {
                let fingerprint: String = sshfp
                    .fingerprint
                    .iter()
                    .map(|byte| format!("{:02X}", byte))
                    .collect();
                Some(DnsRecord {
                    name: domain.to_string(),
                    record_type: RecordType::SSHFP,
                    ttl: answer.ttl,
                    data: RecordData::SSHFP {
                        algorithm: u8::from(sshfp.algorithm),
                        fingerprint_type: u8::from(sshfp.fingerprint_type),
                        fingerprint,
                    },
                })
            }
            _ => None,
        })
        .collect();
    Ok(records)
}
/// Looks up NAPTR records for `domain`.
///
/// The flags, services and regexp byte strings are decoded as UTF-8 (lossily).
/// An empty answer yields an empty `Vec`.
async fn resolve_naptr(
    &self,
    resolver: &TokioResolver,
    domain: &str,
) -> Result<Vec<DnsRecord>> {
    let lookup = resolver.lookup(domain, HickoryRecordType::NAPTR).await;
    let Some(response) = dns_lookup_or_empty(lookup, "NAPTR")? else {
        return Ok(Vec::new());
    };
    let records = response
        .answers()
        .iter()
        .filter_map(|answer| match &answer.data {
            HickoryRData::NAPTR(naptr) => Some(DnsRecord {
                name: domain.to_string(),
                record_type: RecordType::NAPTR,
                ttl: answer.ttl,
                data: RecordData::NAPTR {
                    order: naptr.order,
                    preference: naptr.preference,
                    flags: String::from_utf8_lossy(&naptr.flags).into_owned(),
                    services: String::from_utf8_lossy(&naptr.services).into_owned(),
                    regexp: String::from_utf8_lossy(&naptr.regexp).into_owned(),
                    replacement: naptr.replacement.to_string(),
                },
            }),
            _ => None,
        })
        .collect();
    Ok(records)
}
async fn resolve_any(&self, resolver: &TokioResolver, domain: &str) -> Result<Vec<DnsRecord>> {
let mut all_records = Vec::new();
let record_types = [
RecordType::A,
RecordType::AAAA,
RecordType::MX,
RecordType::NS,
RecordType::TXT,
RecordType::SOA,
RecordType::CAA,
];
let mut any_ok = false;
let mut last_err = None;
for record_type in record_types {
match self.resolve_type(resolver, domain, record_type).await {
Ok(records) => {
any_ok = true;
all_records.extend(records);
}
Err(e) => last_err = Some(e),
}
}
match last_err {
Some(e) if !any_ok => Err(e),
_ => Ok(all_records),
}
}
async fn resolve_type(
&self,
resolver: &TokioResolver,
domain: &str,
record_type: RecordType,
) -> Result<Vec<DnsRecord>> {
match record_type {
RecordType::A => self.resolve_a(resolver, domain).await,
RecordType::AAAA => self.resolve_aaaa(resolver, domain).await,
RecordType::CNAME => self.resolve_cname(resolver, domain).await,
RecordType::MX => self.resolve_mx(resolver, domain).await,
RecordType::NS => self.resolve_ns(resolver, domain).await,
RecordType::TXT => self.resolve_txt(resolver, domain).await,
RecordType::SOA => self.resolve_soa(resolver, domain).await,
RecordType::CAA => self.resolve_caa(resolver, domain).await,
RecordType::DNSKEY => self.resolve_dnskey(resolver, domain).await,
RecordType::DS => self.resolve_ds(resolver, domain).await,
_ => Err(SeerError::DnsError("unsupported record type".to_string())),
}
}
}
/// Result of an NS-based existence check for a domain.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DnsPresence {
    /// The NS lookup succeeded and returned at least one record.
    Present,
    /// The NS lookup succeeded but returned no records.
    Absent,
    /// The lookup failed, so presence could not be determined.
    Unknown,
}
/// Maps the outcome of an NS lookup onto [`DnsPresence`].
///
/// An error becomes `Unknown`, a non-empty answer becomes `Present`, and an empty
/// answer becomes `Absent`.
fn classify_ns_presence(result: &Result<Vec<DnsRecord>>) -> DnsPresence {
    match result {
        Err(_) => DnsPresence::Unknown,
        Ok(records) if !records.is_empty() => DnsPresence::Present,
        Ok(_) => DnsPresence::Absent,
    }
}
impl DnsResolver {
    /// Reports whether `domain` appears to exist, judged by an NS lookup through
    /// the default resolver.
    ///
    /// Lookup failures (including invalid input) yield [`DnsPresence::Unknown`].
    /// A successful lookup with no NS records yields [`DnsPresence::Absent`].
    /// NOTE(review): a name that exists but has no NS records of its own (e.g. a
    /// host below a zone apex) will therefore read as `Absent` — confirm intended.
    pub async fn presence(&self, domain: &str) -> DnsPresence {
        let ns_lookup = self.resolve(domain, RecordType::NS, None).await;
        classify_ns_presence(&ns_lookup)
    }
}
/// Normalizes user input into the name that will be queried.
///
/// For PTR lookups, a (whitespace-trimmed) IP literal is passed through in
/// canonical form so `resolve_ptr` can reverse it. Everything else goes through
/// `normalize_domain`.
fn prepare_query(domain: &str, record_type: RecordType) -> Result<String> {
    let ptr_ip = (record_type == RecordType::PTR)
        .then(|| domain.trim().parse::<IpAddr>().ok())
        .flatten();
    match ptr_ip {
        Some(ip) => Ok(ip.to_string()),
        None => normalize_domain(domain),
    }
}
/// Splits a dig-style SRV name `_service._proto.domain` into its parts.
///
/// The leading underscores are stripped from the service and protocol. `domain`
/// may contain further dots. Returns `None` if either of the first two labels
/// lacks its underscore, or if any of the three parts is empty.
fn parse_srv_query(name: &str) -> Option<(String, String, String)> {
    let (service_label, remainder) = name.split_once('.')?;
    let (protocol_label, domain) = remainder.split_once('.')?;
    let service = service_label.strip_prefix('_').filter(|s| !s.is_empty())?;
    let protocol = protocol_label.strip_prefix('_').filter(|p| !p.is_empty())?;
    if domain.is_empty() {
        return None;
    }
    Some((service.to_owned(), protocol.to_owned(), domain.to_owned()))
}
/// Builds the reverse-lookup (PTR) owner name for `ip`.
///
/// IPv4 `a.b.c.d` becomes `d.c.b.a.in-addr.arpa`. IPv6 becomes its 32 nibbles in
/// reverse order, dot-separated, under `ip6.arpa` (lowercase hex, no trailing dot).
fn reverse_dns_name(ip: &IpAddr) -> String {
    match ip {
        IpAddr::V4(v4) => {
            let [a, b, c, d] = v4.octets();
            format!("{}.{}.{}.{}.in-addr.arpa", d, c, b, a)
        }
        IpAddr::V6(v6) => {
            // 32 nibbles + 32 separators + "ip6.arpa" = 72 bytes.
            let mut name = String::with_capacity(72);
            let octets = v6.octets();
            for &byte in octets.iter().rev() {
                // Walking bytes backwards, the low nibble precedes the high one.
                for nibble in [byte & 0x0F, byte >> 4] {
                    name.push(
                        char::from_digit(u32::from(nibble), 16).expect("nibble is always 0-15"),
                    );
                    name.push('.');
                }
            }
            name.push_str("ip6.arpa");
            name
        }
    }
}
/// Flattens a hickory CAA record into `(flags, tag, value)`.
///
/// Only the issuer-critical bit (0x80) is reconstructed in `flags`. The value
/// bytes are decoded as UTF-8, lossily.
fn parse_caa(caa: &CAA) -> (u8, String, String) {
    let flags: u8 = if caa.issuer_critical { 0x80 } else { 0x00 };
    let value = String::from_utf8_lossy(&caa.value).into_owned();
    (flags, caa.tag.clone(), value)
}
/// Returns `true` if `label` is acceptable as an SRV service or protocol label.
///
/// A valid label is 1–63 bytes of ASCII letters, digits or hyphens, and does not
/// start or end with a hyphen. Dots, whitespace and non-ASCII are all rejected.
fn is_valid_srv_label(label: &str) -> bool {
    let bytes = label.as_bytes();
    let (Some(&first), Some(&last)) = (bytes.first(), bytes.last()) else {
        return false;
    };
    bytes.len() <= 63
        && first != b'-'
        && last != b'-'
        && bytes.iter().all(|&b| b.is_ascii_alphanumeric() || b == b'-')
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    #[test]
    fn record_type_from_str_accepts_lowercase() {
        assert_eq!(RecordType::from_str("a").unwrap(), RecordType::A);
        assert_eq!(RecordType::from_str("mx").unwrap(), RecordType::MX);
        assert_eq!(RecordType::from_str("cname").unwrap(), RecordType::CNAME);
        assert_eq!(RecordType::from_str("dnskey").unwrap(), RecordType::DNSKEY);
    }

    #[test]
    fn record_type_from_str_accepts_mixed_case() {
        assert_eq!(RecordType::from_str("Mx").unwrap(), RecordType::MX);
        assert_eq!(RecordType::from_str("cNaMe").unwrap(), RecordType::CNAME);
    }

    #[test]
    fn record_type_from_str_rejects_whitespace_padded() {
        assert!(RecordType::from_str(" A").is_err());
        assert!(RecordType::from_str("A ").is_err());
        assert!(RecordType::from_str("\tA\n").is_err());
    }

    #[test]
    fn record_type_from_str_rejects_unknown() {
        assert!(RecordType::from_str("NOTAREAL").is_err());
        assert!(RecordType::from_str("A1").is_err());
        assert!(RecordType::from_str("").is_err());
    }

    #[test]
    fn record_type_from_str_accepts_star_as_any() {
        assert_eq!(RecordType::from_str("*").unwrap(), RecordType::ANY);
        assert_eq!(RecordType::from_str("ANY").unwrap(), RecordType::ANY);
        assert_eq!(RecordType::from_str("any").unwrap(), RecordType::ANY);
    }

    #[test]
    fn srv_label_accepts_alphanumeric_and_hyphen() {
        assert!(is_valid_srv_label("http"));
        assert!(is_valid_srv_label("ldap-tls"));
        assert!(is_valid_srv_label("a1"));
        assert!(is_valid_srv_label("tcp"));
    }

    #[test]
    fn srv_label_rejects_empty() {
        assert!(!is_valid_srv_label(""));
    }

    #[test]
    fn srv_label_rejects_leading_or_trailing_hyphen() {
        assert!(!is_valid_srv_label("-http"));
        assert!(!is_valid_srv_label("http-"));
        assert!(!is_valid_srv_label("-"));
    }

    #[test]
    fn srv_label_rejects_dots() {
        assert!(!is_valid_srv_label("http.evil"));
        assert!(!is_valid_srv_label("a.b"));
    }

    #[test]
    fn srv_label_rejects_special_chars() {
        assert!(!is_valid_srv_label("http evil"));
        assert!(!is_valid_srv_label("http/evil"));
        assert!(!is_valid_srv_label("http\0"));
        assert!(!is_valid_srv_label("http\n"));
    }

    #[test]
    fn srv_label_rejects_over_63_chars() {
        let too_long = "a".repeat(64);
        assert!(!is_valid_srv_label(&too_long));
        let exactly_63 = "a".repeat(63);
        assert!(is_valid_srv_label(&exactly_63));
    }

    #[test]
    fn classify_ns_presence_absent_on_empty_ok() {
        let r: Result<Vec<DnsRecord>> = Ok(vec![]);
        assert_eq!(classify_ns_presence(&r), DnsPresence::Absent);
    }

    #[test]
    fn classify_ns_presence_present_on_records() {
        let rec = DnsRecord {
            name: "example.test.".to_string(),
            record_type: RecordType::NS,
            ttl: 3600,
            data: RecordData::NS {
                nameserver: "ns1.example.net.".to_string(),
            },
        };
        let r: Result<Vec<DnsRecord>> = Ok(vec![rec]);
        assert_eq!(classify_ns_presence(&r), DnsPresence::Present);
    }

    #[test]
    fn classify_ns_presence_unknown_on_error() {
        let r: Result<Vec<DnsRecord>> = Err(SeerError::DnsError("servfail".to_string()));
        assert_eq!(classify_ns_presence(&r), DnsPresence::Unknown);
    }

    #[test]
    fn reverse_dns_name_formats_ipv4_correctly() {
        let ip: IpAddr = Ipv4Addr::new(192, 0, 2, 1).into();
        assert_eq!(reverse_dns_name(&ip), "1.2.0.192.in-addr.arpa");
    }

    #[test]
    fn reverse_dns_name_formats_ipv6_correctly() {
        let ip: IpAddr = Ipv6Addr::LOCALHOST.into();
        let name = reverse_dns_name(&ip);
        assert!(
            name.ends_with(".ip6.arpa"),
            "must end with .ip6.arpa; got: {}",
            name
        );
        assert!(
            name.starts_with("1."),
            "expected '1.' prefix, got: {}",
            name
        );
        assert_eq!(name.len(), 72);
        // "1" followed by 31 zero nibbles.
        assert_eq!(name, format!("1{}.ip6.arpa", ".0".repeat(31)));
    }

    #[test]
    fn reverse_dns_name_matches_rfc3596_example() {
        // Example address from RFC 3596 section 2.5.
        let ip: IpAddr = Ipv6Addr::new(0x4321, 0, 1, 2, 3, 4, 0x567, 0x89ab).into();
        assert_eq!(
            reverse_dns_name(&ip),
            "b.a.9.8.7.6.5.0.4.0.0.0.3.0.0.0.2.0.0.0.1.0.0.0.0.0.0.0.1.2.3.4.ip6.arpa"
        );
    }

    #[test]
    fn resolver_new_has_default_timeout() {
        let r = DnsResolver::new();
        assert_eq!(r.timeout, DEFAULT_TIMEOUT);
    }

    #[test]
    fn resolver_with_timeout_overrides_default() {
        let custom = Duration::from_secs(42);
        let r = DnsResolver::new().with_timeout(custom);
        assert_eq!(r.timeout, custom);
    }

    #[test]
    fn resolver_default_matches_new() {
        let a = DnsResolver::default();
        let b = DnsResolver::new();
        assert_eq!(a.timeout, b.timeout);
    }

    #[tokio::test]
    async fn custom_resolver_rejects_invalid_input() {
        let r = DnsResolver::new();
        let err = r.create_custom_resolver("..").await.unwrap_err();
        let msg = err.to_string().to_lowercase();
        assert!(
            msg.contains("dns resolution failed") || msg.contains("invalid"),
            "expected resolution failure, got: {}",
            msg
        );
    }

    #[tokio::test]
    async fn custom_resolver_rejects_private_ipv4() {
        let r = DnsResolver::new();
        for reserved in ["127.0.0.1", "10.0.0.1", "192.168.1.1", "169.254.169.254"] {
            let err = r.create_custom_resolver(reserved).await.unwrap_err();
            let msg = err.to_string().to_lowercase();
            assert!(
                msg.contains("blocked") || msg.contains("reserved"),
                "reserved IP {} must be rejected, got error: {}",
                reserved,
                msg
            );
        }
    }

    #[tokio::test]
    async fn custom_resolver_rejects_loopback_ipv6() {
        let r = DnsResolver::new();
        let err = r.create_custom_resolver("::1").await.unwrap_err();
        let msg = err.to_string().to_lowercase();
        assert!(
            msg.contains("blocked") || msg.contains("reserved"),
            "::1 must be rejected, got error: {}",
            msg
        );
    }

    #[tokio::test]
    async fn custom_resolver_accepts_public_ipv4() {
        let r = DnsResolver::new();
        let result = r.create_custom_resolver("8.8.8.8").await;
        assert!(
            result.is_ok(),
            "8.8.8.8 must be accepted as a public nameserver, got: {:?}",
            result.err()
        );
    }

    #[tokio::test]
    async fn resolve_srv_rejects_invalid_service_label() {
        let r = DnsResolver::new();
        let result = r.resolve_srv("http.evil", "tcp", "example.com", None).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string().to_lowercase();
        assert!(
            msg.contains("invalid srv service"),
            "expected SRV service validation error, got: {}",
            msg
        );
    }

    #[tokio::test]
    async fn resolve_srv_rejects_invalid_protocol_label() {
        let r = DnsResolver::new();
        let result = r.resolve_srv("http", "tcp.evil", "example.com", None).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string().to_lowercase();
        assert!(
            msg.contains("invalid srv protocol"),
            "expected SRV protocol validation error, got: {}",
            msg
        );
    }

    #[tokio::test]
    async fn resolve_rejects_leading_dot_domain() {
        let r = DnsResolver::new();
        let result = r.resolve(".bad.example", RecordType::A, None).await;
        assert!(result.is_err(), "leading-dot domain must be rejected");
    }

    #[test]
    fn parse_srv_query_extracts_service_proto_and_name() {
        assert_eq!(
            parse_srv_query("_sip._tcp.example.com"),
            Some((
                "sip".to_string(),
                "tcp".to_string(),
                "example.com".to_string()
            ))
        );
    }

    #[test]
    fn parse_srv_query_keeps_multilabel_domain() {
        assert_eq!(
            parse_srv_query("_sip._tcp.sip.voice.google.com"),
            Some((
                "sip".to_string(),
                "tcp".to_string(),
                "sip.voice.google.com".to_string()
            ))
        );
    }

    #[test]
    fn parse_srv_query_rejects_bare_domain() {
        assert_eq!(parse_srv_query("example.com"), None);
    }

    #[test]
    fn parse_srv_query_rejects_missing_proto_label() {
        assert_eq!(parse_srv_query("_sip.example.com"), None);
    }

    #[test]
    fn parse_srv_query_rejects_empty_labels() {
        assert_eq!(parse_srv_query("_._tcp.example.com"), None);
        assert_eq!(parse_srv_query("_sip._.example.com"), None);
        assert_eq!(parse_srv_query("_sip._tcp."), None);
    }

    #[tokio::test]
    async fn resolve_rejects_bare_domain_for_srv_as_input_error() {
        let r = DnsResolver::new();
        let err = r
            .resolve("example.com", RecordType::SRV, None)
            .await
            .expect_err("bare-domain SRV must error");
        assert!(
            matches!(err, SeerError::InvalidInput(_)),
            "bare-domain SRV should be an input error, got: {err:?}"
        );
        assert!(err.to_string().contains("_service._proto"));
    }

    #[tokio::test]
    #[ignore = "live network"]
    async fn resolve_srv_via_dig_style_name_returns_records() {
        let r = DnsResolver::new();
        let records = r
            .resolve("_caldavs._tcp.google.com", RecordType::SRV, None)
            .await
            .expect("dig-style SRV lookup should succeed");
        assert!(!records.is_empty(), "expected SRV records");
        assert!(records.iter().all(|r| r.record_type == RecordType::SRV));
    }

    #[tokio::test]
    #[ignore = "live network"]
    async fn resolve_naptr_returns_records() {
        let r = DnsResolver::new();
        let records = r
            .resolve("sip2sip.info", RecordType::NAPTR, None)
            .await
            .expect("NAPTR lookup should succeed");
        assert!(!records.is_empty(), "expected NAPTR records");
        assert!(records.iter().all(|r| r.record_type == RecordType::NAPTR));
    }

    #[test]
    fn prepare_query_passes_ipv6_literal_through_for_ptr() {
        let out = prepare_query("2606:4700:4700::1111", RecordType::PTR).unwrap();
        assert_eq!(out, "2606:4700:4700::1111");
    }

    #[test]
    fn prepare_query_passes_ipv6_loopback_through_for_ptr() {
        let out = prepare_query("::1", RecordType::PTR).unwrap();
        assert_eq!(out, "::1");
    }

    #[test]
    fn prepare_query_passes_ipv4_literal_through_for_ptr() {
        let out = prepare_query("8.8.8.8", RecordType::PTR).unwrap();
        assert_eq!(out, "8.8.8.8");
    }

    #[test]
    fn prepare_query_normalizes_non_ip_ptr_names() {
        let out = prepare_query("1.1.1.1.in-addr.arpa", RecordType::PTR).unwrap();
        assert_eq!(out, "1.1.1.1.in-addr.arpa");
    }

    #[test]
    fn prepare_query_normalizes_domains_for_non_ptr() {
        let out = prepare_query("HTTPS://WWW.Example.com/path", RecordType::A).unwrap();
        assert_eq!(out, "example.com");
    }

    #[test]
    fn prepare_query_lowercases_plain_domain() {
        let out = prepare_query("Example.COM", RecordType::A).unwrap();
        assert_eq!(out, "example.com");
    }
}