use crate::{args::Args, opt_name::OptName};
use derive_more::{Debug, Display};
use exn::{Result, ResultExt as _, bail};
use hickory_client::client::{Client, ClientHandle as _};
use hickory_proto::{
op::ResponseCode,
rr::{DNSClass, Name, RData, Record, RecordType},
runtime::TokioRuntimeProvider,
tcp::TcpClientStream,
udp::UdpClientStream,
xfer::DnsResponse,
};
use hickory_resolver::{
TokioResolver,
config::{LookupIpStrategy, ResolverOpts},
name_server::GenericConnector,
};
use itertools::Itertools as _;
use std::{
collections::{BTreeSet, HashMap, HashSet},
fmt,
future::Future,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
pin::Pin,
sync::RwLock,
};
/// Errors raised while bootstrapping the resolver or while tracing a delegation path.
#[derive(Debug, Display)]
#[allow(clippy::missing_docs_in_private_items, reason = "self-explanatory")]
pub enum ResolverError {
    /// The NS lookup for the contained name failed.
    #[display("NS lookup failed for {_0}")]
    NsLookup(String),
    /// The address lookup for the contained host name failed.
    #[display("IP lookup failed for {_0}")]
    IpLookup(String),
    /// Connecting the UDP client to the contained server failed.
    #[display("Client connect failed for {_0}")]
    ClientConnect(OptName),
    /// The query for the contained name and record type failed.
    #[display("Client query failed for {_1} {_0}")]
    ClientQuery(Name, RecordType),
    /// Creating the TCP client for the contained server failed.
    // The server is part of the message so it reads like `ClientConnect`.
    #[display("Client creation failed for {_0}")]
    ClientNew(OptName),
    /// The system-backed tokio resolver could not be built.
    #[display("Failed to build tokio resolver")]
    BuildTokioResolver,
    /// No usable address was found for the configured server host name.
    #[display("No IP address found for hostname: {_0}")]
    NoIpForHostname(String),
    /// A nested recursion step at the contained depth failed.
    #[display("do recurse depth {_0}")]
    DoRecurse(usize),
    /// The results lock could not be acquired for reading.
    #[display("Failed to acquire read lock")]
    ReadLock,
    /// The results lock could not be acquired for writing.
    #[display("Failed to acquire write lock")]
    WriteLock,
}
impl std::error::Error for ResolverError {}
/// Evaluates to `true` when `$ip` belongs to an address family enabled in
/// `$self.arguments`. When neither `ipv4` nor `ipv6` is set, every address is accepted.
macro_rules! is_ip_allowed {
    ($self:expr, $ip:expr) => {
        (!$self.arguments.ipv4 && !$self.arguments.ipv6)
            || ($self.arguments.ipv4 && $ip.is_ipv4())
            || ($self.arguments.ipv6 && $ip.is_ipv6())
    };
}
/// Cache key: the address of the queried server paired with the queried name.
type CacheKey = (IpAddr, Name);
/// Set of (server address, name) pairs that have already been handled.
type Cache = HashSet<CacheKey>;
/// Everything collected from one authoritative server: the set of answer records
/// and the response code of the most recent answer stored for it.
#[derive(Clone, Debug, Default)]
struct FullResult {
// Ordered set, so identical records reported more than once are stored once.
records: BTreeSet<Record>,
// Overwritten each time a result is added for the same server.
response_code: ResponseCode,
}
/// Abstraction over the lookups needed to bootstrap and to resolve name servers
/// that come without glue; a mock implementation is generated for tests.
#[cfg_attr(test, mockall::automock)]
pub trait NameResolver: Send + Sync {
/// Returns the name server names registered for `name`.
async fn ns_lookup(&self, name: &str) -> Result<Vec<Name>, ResolverError>;
/// Returns the addresses found for host `name`.
async fn lookup_ip(&self, name: &str) -> Result<Vec<IpAddr>, ResolverError>;
}
/// The parts of a DNS response that the recursion logic inspects.
#[derive(Clone, Debug)]
pub struct QueryResult {
/// Whether the response was flagged authoritative.
pub authoritative: bool,
/// Records of the answer section.
pub answers: Vec<Record>,
/// Records of the name server (authority) section.
pub name_servers: Vec<Record>,
/// Records of the additional section (used as glue).
pub additionals: Vec<Record>,
/// Response code reported by the server.
pub response_code: ResponseCode,
}
impl Default for QueryResult {
    /// An empty, non-authoritative response with code `NoError`.
    fn default() -> Self {
        Self {
            response_code: ResponseCode::NoError,
            authoritative: false,
            answers: Vec::new(),
            name_servers: Vec::new(),
            additionals: Vec::new(),
        }
    }
}
impl From<DnsResponse> for QueryResult {
    /// Copies the flag, the three record sections and the response code out of `resp`.
    fn from(resp: DnsResponse) -> Self {
        let authoritative = resp.authoritative();
        let response_code = resp.response_code();
        let answers = resp.answers().to_vec();
        let name_servers = resp.name_servers().to_vec();
        let additionals = resp.additionals().to_vec();
        Self {
            authoritative,
            answers,
            name_servers,
            additionals,
            response_code,
        }
    }
}
/// Abstraction over sending a single DNS query to a specific server;
/// a mock implementation is generated for tests.
#[cfg_attr(test, mockall::automock)]
pub trait DnsQuerier: Send + Sync {
/// Asks `server` for records of type `query_type` for `name`.
async fn query(
&self,
server: &OptName,
name: &Name,
query_type: RecordType,
) -> Result<QueryResult, ResolverError>;
}
/// [`NameResolver`] implementation that delegates to a hickory [`TokioResolver`].
pub struct TokioNameResolver(TokioResolver);
impl NameResolver for TokioNameResolver {
    /// Looks up the NS records of `name` and returns the name server names.
    async fn ns_lookup(&self, name: &str) -> Result<Vec<Name>, ResolverError> {
        let lookup = self
            .0
            .ns_lookup(name)
            .await
            .or_raise(|| ResolverError::NsLookup(name.to_owned()))?;
        let names: Vec<Name> = lookup.iter().map(|ns| ns.0.clone()).collect();
        Ok(names)
    }
    /// Looks up the addresses of `name`.
    async fn lookup_ip(&self, name: &str) -> Result<Vec<IpAddr>, ResolverError> {
        let lookup = self
            .0
            .lookup_ip(name)
            .await
            .or_raise(|| ResolverError::IpLookup(name.to_owned()))?;
        let addresses: Vec<IpAddr> = lookup.iter().collect();
        Ok(addresses)
    }
}
/// [`DnsQuerier`] that opens a fresh UDP or TCP client for every query.
pub struct DefaultDnsQuerier {
// Use TCP instead of UDP.
tcp: bool,
// Timeout handed to the client stream.
timeout: std::time::Duration,
// Optional local address to bind to (the port is left to the OS).
source_address: Option<IpAddr>,
// When set, EDNS is disabled on the client.
no_edns0: bool,
}
impl DefaultDnsQuerier {
    /// Copies the transport-related settings out of `args`.
    const fn from_args(args: &Args) -> Self {
        Self {
            no_edns0: args.no_edns0,
            source_address: args.source_address,
            timeout: args.timeout,
            tcp: args.tcp,
        }
    }
    /// Sends one query for `name`/`query_type` to `server` over UDP.
    async fn udp_query(
        &self,
        server: &OptName,
        name: &Name,
        query_type: RecordType,
    ) -> Result<DnsResponse, ResolverError> {
        // Port 0 lets the OS choose the local port.
        let bind_addr = self.source_address.map(|ip| SocketAddr::new(ip, 0));
        let connect = UdpClientStream::builder(server.into(), TokioRuntimeProvider::new())
            .with_timeout(Some(self.timeout))
            .with_bind_addr(bind_addr)
            .build();
        let (mut client, background) = Client::connect(connect)
            .await
            .or_raise(|| ResolverError::ClientConnect(server.clone()))?;
        if self.no_edns0 {
            client.disable_edns();
        } else {
            client.enable_edns();
        }
        // The background task drives the exchange; the query does not complete without it.
        tokio::spawn(background);
        let response = client.query(name.clone(), DNSClass::IN, query_type).await;
        response.or_raise(|| ResolverError::ClientQuery(name.clone(), query_type))
    }
    /// Sends one query for `name`/`query_type` to `server` over TCP.
    async fn tcp_query(
        &self,
        server: &OptName,
        name: &Name,
        query_type: RecordType,
    ) -> Result<DnsResponse, ResolverError> {
        // Port 0 lets the OS choose the local port.
        let bind_addr = self.source_address.map(|ip| SocketAddr::new(ip, 0));
        let (connect, handle) = TcpClientStream::new(
            server.into(),
            bind_addr,
            Some(self.timeout),
            TokioRuntimeProvider::new(),
        );
        let (mut client, background) = Client::new(connect, handle, None)
            .await
            .or_raise(|| ResolverError::ClientNew(server.clone()))?;
        if self.no_edns0 {
            client.disable_edns();
        } else {
            client.enable_edns();
        }
        // The background task drives the exchange; the query does not complete without it.
        tokio::spawn(background);
        let response = client.query(name.clone(), DNSClass::IN, query_type).await;
        response.or_raise(|| ResolverError::ClientQuery(name.clone(), query_type))
    }
}
impl DnsQuerier for DefaultDnsQuerier {
    /// Dispatches to the TCP or UDP transport depending on the `tcp` setting and
    /// converts the raw response into a [`QueryResult`].
    async fn query(
        &self,
        server: &OptName,
        name: &Name,
        query_type: RecordType,
    ) -> Result<QueryResult, ResolverError> {
        let outcome = if self.tcp {
            self.tcp_query(server, name, query_type).await
        } else {
            self.udp_query(server, name, query_type).await
        };
        outcome.map(QueryResult::from)
    }
}
/// Walks the delegation chain for a name, querying every name server it is referred to.
///
/// `R` performs the bootstrap/fallback lookups and `Q` performs the per-server queries;
/// both are generic so tests can substitute mocks.
#[derive(Debug)]
pub struct RecursiveResolver<'a, R = TokioNameResolver, Q = DefaultDnsQuerier> {
// Authoritative answers collected so far, keyed by the answering server.
results: RwLock<HashMap<OptName, FullResult>>,
#[debug(skip)]
name_resolver: R,
#[debug(skip)]
querier: Q,
// Parsed command line arguments, borrowed for the resolver's lifetime.
arguments: &'a Args,
// `None` when the corresponding cache is disabled.
positive_cache: Option<RwLock<Cache>>,
negative_cache: Option<RwLock<Cache>>,
}
impl<'a> RecursiveResolver<'a> {
    /// Builds a resolver backed by the system-configured tokio resolver and the
    /// default per-server querier, with caches enabled according to `args`.
    ///
    /// # Errors
    ///
    /// Returns [`ResolverError::BuildTokioResolver`] when the tokio resolver builder
    /// cannot be created.
    pub fn new(args: &'a Args) -> Result<Self, ResolverError> {
        let mut opts = ResolverOpts::default();
        opts.edns0 = !args.no_edns0;
        opts.timeout = args.timeout;
        opts.attempts = args.retries;
        opts.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;

        let connector = GenericConnector::new(TokioRuntimeProvider::new());
        let builder =
            TokioResolver::builder(connector).or_raise(|| ResolverError::BuildTokioResolver)?;
        let name_resolver = TokioNameResolver(builder.with_options(opts).build());

        let positive_cache = (!args.no_positive_cache).then(|| RwLock::new(HashSet::new()));
        let negative_cache = args.negative_cache.then(|| RwLock::new(HashSet::new()));
        Ok(Self {
            results: RwLock::new(HashMap::new()),
            name_resolver,
            querier: DefaultDnsQuerier::from_args(args),
            arguments: args,
            positive_cache,
            negative_cache,
        })
    }
}
impl<R: NameResolver, Q: DnsQuerier> RecursiveResolver<'_, R, Q> {
/// Determines the servers the trace starts from.
///
/// * If the configured server parses as an IP address, that single address is used.
/// * If it is `"."`, the root name servers are looked up and each one is resolved.
/// * Otherwise it is treated as a host name and resolved.
///
/// Looked-up addresses are filtered by the enabled address families.
///
/// # Errors
///
/// Fails when a lookup fails or when no usable address remains.
pub async fn init(&self) -> Result<Vec<OptName>, ResolverError> {
    let configured = self.arguments.server.as_str();
    let mut servers: Vec<OptName> = Vec::new();
    if let Ok(ip) = configured.parse::<IpAddr>() {
        servers.push(OptName {
            ip,
            name: None,
            zone: None,
        });
    } else if configured == "." {
        for ns in self.name_resolver.ns_lookup(".").await? {
            let ns_name = ns.to_string();
            let ips = self
                .name_resolver
                .lookup_ip(&ns_name)
                .await
                .or_raise(|| ResolverError::IpLookup(ns_name.clone()))?;
            servers.extend(
                ips.into_iter()
                    .filter(|ip| is_ip_allowed!(self, ip))
                    .map(|ip| OptName {
                        ip,
                        name: Some(ns_name.clone()),
                        zone: Some(".".to_owned()),
                    }),
            );
        }
    } else {
        let ips = self
            .name_resolver
            .lookup_ip(configured)
            .await
            .or_raise(|| ResolverError::IpLookup(configured.to_owned()))?;
        servers.extend(
            ips.into_iter()
                .filter(|ip| is_ip_allowed!(self, ip))
                .map(|ip| OptName {
                    ip,
                    name: Some(configured.to_owned()),
                    zone: None,
                }),
        );
    }
    if servers.is_empty() {
        bail!(ResolverError::NoIpForHostname(configured.to_owned()));
    }
    Ok(servers)
}
/// Queries `server` for `name`, prints one tree line for it and then recurses into
/// every server it is referred to.
///
/// `depth` is the current tree depth and `last` holds, per ancestor level, whether that
/// ancestor was the last of its siblings (used only for drawing the tree).
///
/// The future is boxed because the function awaits itself recursively.
///
/// # Errors
///
/// Query failures are printed and recorded in the negative cache, not returned.
/// An error is returned when storing a result fails or when a nested call fails.
pub fn do_recurse<'b>(
&'b self,
name: &'b Name,
server: &'b OptName,
depth: usize,
last: Vec<bool>,
) -> Pin<Box<dyn Future<Output = Result<(), ResolverError>> + 'b>> {
Box::pin(async move {
// A (server, name) pair found in either cache is not queried again.
if self.cache_get(&(server.ip, name.clone())) {
Self::print(depth, server, "(cached)", &last);
return Ok(());
}
// The first hop always asks for NS records; deeper hops use the requested type.
let query_type = if depth == 0 {
RecordType::NS
} else {
self.arguments.query_type
};
match self.querier.query(server, name, query_type).await {
Ok(response) => {
let mut next_servers: Option<Vec<OptName>> = None;
if response.authoritative {
let result = &response.answers;
Self::print(depth, server, "found authoritative answer", &last);
self.cache_set(true, (server.ip, name.clone()));
self.add_result(server.clone(), response.response_code, result)?;
// Only CNAMEs came back although another type was requested (`all` is also
// true for an empty answer section) and the authority section is not empty.
if self.arguments.query_type != RecordType::CNAME
&& result.iter().all(|r| r.record_type() == RecordType::CNAME)
&& !response.name_servers.is_empty()
{
// NOTE(review): `next_servers` is overwritten on every iteration, so only the
// servers derived for the last CNAME in the answer are kept — confirm intended.
for cname in response
.answers
.iter()
.filter_map(|r| r.data().as_cname())
{
next_servers = Some(
self.get_next_servers(
&response.name_servers,
&response.additionals,
server,
cname,
depth,
&last,
)
.await,
);
}
}
} else {
Self::print(depth, server, "", &last);
// At depth 0 a non-empty answer section is used as the delegation;
// otherwise the name server section is used.
let (records, additionals) = if depth == 0 && !response.answers.is_empty() {
(&response.answers, &response.additionals)
} else {
(&response.name_servers, &response.additionals)
};
next_servers = Some(
self.get_next_servers(records, additionals, server, name, depth, &last)
.await,
);
}
if let Some(next) = next_servers {
let len = next.len();
// Servers are visited in sorted order; the nested calls query the original `name`.
for (index, ns) in next.iter().sorted().enumerate() {
self.do_recurse(name, ns, depth + 1, {
// Mark whether this child is the last sibling at its level.
let mut new_last = last.clone();
new_last.push(index == (len - 1));
new_last
})
.await
.or_raise(|| ResolverError::DoRecurse(depth + 1))?;
}
}
}
Err(e) => {
// A failed query is recorded (when the negative cache is enabled) and printed
// together with its first child frame, if there is one.
self.cache_set(false, (server.ip, name.clone()));
Self::print(
depth,
server,
format!(
"{e} -> {}",
e.frame().children().first().map_or_else(
|| "unknown error".to_owned(),
std::string::ToString::to_string
)
),
&last,
);
}
}
Ok(())
})
}
/// Turns the NS records in `records` into the list of servers to query next.
///
/// For each NS record the matching A/AAAA records in `additionals` are used as glue.
/// When there is no usable glue, the name server is resolved through the name
/// resolver instead; if that yields nothing either, a "no ip found" line is printed.
/// Whenever at least one address is found, `(server.ip, name)` is added to the
/// positive cache. Addresses are filtered by the enabled address families.
async fn get_next_servers(
    &self,
    records: &[Record],
    additionals: &[Record],
    server: &OptName,
    name: &Name,
    depth: usize,
    last: &[bool],
) -> Vec<OptName> {
    let mut collected: Vec<OptName> = Vec::new();
    for record in records {
        let Some(ns) = record.data().as_ns() else {
            continue;
        };
        let zone = record.name().to_string();

        // First choice: glue addresses shipped in the additional section.
        let glue: Vec<OptName> = additionals
            .iter()
            .filter(|r| *r.name() == ns.0)
            .filter_map(|additional| match *additional.data() {
                RData::A(ref a) => Some((additional, IpAddr::V4(Into::<Ipv4Addr>::into(*a)))),
                RData::AAAA(ref a) => {
                    Some((additional, IpAddr::V6(Into::<Ipv6Addr>::into(*a))))
                }
                _ => None,
            })
            .filter(|&(_, ip)| is_ip_allowed!(self, ip))
            .map(|(additional, ip)| OptName {
                ip,
                name: Some(additional.name().to_string()),
                zone: Some(zone.clone()),
            })
            .collect();
        if !glue.is_empty() {
            collected.extend(glue);
            self.cache_set(true, (server.ip, name.clone()));
            continue;
        }

        // Fallback: resolve the name server ourselves; a failed lookup counts as "no address".
        let ns_name = ns.to_string();
        let resolved: Vec<OptName> = self
            .name_resolver
            .lookup_ip(&ns_name)
            .await
            .unwrap_or_default()
            .into_iter()
            .filter(|ip| is_ip_allowed!(self, ip))
            .map(|ip| OptName {
                ip,
                name: Some(ns_name.clone()),
                zone: Some(zone.clone()),
            })
            .collect();
        if resolved.is_empty() {
            // Placeholder entry with the unspecified address, only used for printing.
            Self::print(
                depth,
                &OptName {
                    ip: IpAddr::V4(Ipv4Addr::UNSPECIFIED),
                    name: Some(ns_name),
                    zone: Some(zone),
                },
                "no ip found",
                last,
            );
        } else {
            collected.extend(resolved);
            self.cache_set(true, (server.ip, name.clone()));
        }
    }
    collected
}
/// Prints a summary of all collected authoritative results.
///
/// Servers are listed in sorted order so that the output is stable between runs
/// (plain `HashMap` iteration order is not). For each server a line with the
/// response code is printed when it is not `NoError`, followed by one line per
/// record, sorted by the record's textual form.
///
/// # Errors
///
/// Returns [`ResolverError::ReadLock`] when the results lock is poisoned.
#[expect(clippy::print_stdout, reason = "print")]
pub fn show_overview(&self) -> Result<(), ResolverError> {
    // The read guard is a temporary of the `for` expression and lives for the whole loop.
    for (key, values) in self
        .results
        .read()
        .map_err(|_| ResolverError::ReadLock)?
        .iter()
        .sorted_by(|left, right| left.0.cmp(right.0))
    {
        let label = key.name.as_deref().unwrap_or_default();
        if values.response_code != ResponseCode::NoError {
            println!("{label} ({})\t{}", key.ip, values.response_code);
        }
        for record in values
            .records
            .iter()
            .sorted_by_cached_key(|r| format!("{r}"))
        {
            println!("{label} ({}) \t{record}", key.ip);
        }
    }
    Ok(())
}
/// Returns `true` when `key` is present in the positive or the negative cache.
///
/// A disabled cache or a poisoned lock counts as "not cached".
fn cache_get(&self, key: &CacheKey) -> bool {
    // Positive cache first; `any` stops at the first hit.
    [&self.positive_cache, &self.negative_cache]
        .into_iter()
        .flatten()
        .any(|lock| lock.read().is_ok_and(|entries| entries.contains(key)))
}
/// Inserts `key` into the positive cache (`positive == true`) or the negative cache.
///
/// Does nothing when the selected cache is disabled. A poisoned lock is reported on
/// stderr and otherwise ignored.
#[expect(clippy::print_stderr, reason = "non fatal error")]
fn cache_set(&self, positive: bool, key: CacheKey) {
let cache = if positive {
&self.positive_cache
} else {
&self.negative_cache
};
#[expect(clippy::pattern_type_mismatch, reason = "can't dereference guard")]
if let Some(locked_cache) = cache {
match locked_cache.write() {
Ok(mut c) => {
c.insert(key);
}
Err(error) => {
// Caching is best effort: report the failure and carry on.
eprintln!("cache set error {error}");
}
}
}
}
/// Stores the records of an authoritative answer under `server`.
///
/// The records are merged into the set already stored for that server and the stored
/// response code is replaced by `response_code`.
///
/// # Errors
///
/// Returns [`ResolverError::WriteLock`] when the results lock is poisoned.
#[expect(
clippy::significant_drop_tightening,
reason = "Scope is short enough and there should not be contentions"
)]
fn add_result(
&self,
server: OptName,
response_code: ResponseCode,
results: &[Record],
) -> Result<(), ResolverError> {
let mut res = self.results.write().map_err(|_| ResolverError::WriteLock)?;
// Creates an empty entry the first time a server is seen.
let full = res.entry(server).or_default();
full.response_code = response_code;
for result in results {
full.records.insert(result.clone());
}
Ok(())
}
/// Prints one line of the trace tree for `server`.
///
/// One column is drawn per ancestor level: a blank when that ancestor was the last of
/// its siblings according to `last` (missing entries count as "not last"), a bar
/// otherwise. Levels below the root are prefixed with a branch marker, and `rest` is
/// appended after the server when it is not empty.
#[expect(clippy::print_stdout, reason = "called print")]
fn print<S: fmt::Display>(depth: usize, server: &OptName, rest: S, last: &[bool]) {
    let mut prefix = String::new();
    for level in 0..depth {
        let ancestor_was_last = last.get(level).copied().unwrap_or(false);
        prefix.push_str(if ancestor_was_last { " " } else { " |" });
        // Separator between columns, but not after the final one.
        if level + 1 < depth {
            prefix.push_str(" ");
        }
    }
    if depth > 0 {
        prefix.push_str(r"\___ ");
    }
    let suffix = rest.to_string();
    if suffix.is_empty() {
        println!("{prefix}{server}");
    } else {
        println!("{prefix}{server} {suffix}");
    }
}
}
#[cfg(test)]
mod tests {
#![allow(clippy::expect_used, clippy::indexing_slicing, reason = "test")]
use super::*;
use crate::args::Args;
use hickory_proto::rr::{Name, RData, Record, RecordType, rdata};
use mockall::predicate;
use std::{
net::{IpAddr, Ipv4Addr},
str::FromStr as _,
time::Duration,
};
/// Baseline arguments for the tests: root server ".", A queries, IPv4 only,
/// positive cache enabled, negative cache disabled, UDP, EDNS disabled.
fn default_args() -> Args {
Args {
domain: "example.com".to_owned(),
no_positive_cache: false,
negative_cache: false,
no_edns0: true,
overview: false,
query_type: RecordType::A,
retries: 3,
server: ".".to_owned(),
timeout: Duration::from_secs(5),
source_address: None,
ipv6: false,
ipv4: true,
tcp: false,
}
}
/// Builds a resolver around the given mocks, enabling the caches from `args`
/// the same way `RecursiveResolver::new` does.
fn mock_resolver(
args: &Args,
name_resolver: MockNameResolver,
querier: MockDnsQuerier,
) -> RecursiveResolver<'_, MockNameResolver, MockDnsQuerier> {
RecursiveResolver {
results: RwLock::new(HashMap::new()),
name_resolver,
querier,
arguments: args,
positive_cache: (!args.no_positive_cache).then(|| RwLock::new(HashSet::new())),
negative_cache: args.negative_cache.then(|| RwLock::new(HashSet::new())),
}
}
fn delegation_response(ns_name: &str, ns_ip: Ipv4Addr, zone: &str) -> QueryResult {
let ns_name_parsed = Name::from_str(ns_name).expect("ns_name is a valid DNS name literal");
let zone_name = Name::from_str(zone).expect("zone is a valid DNS name literal");
let ns_record = Record::from_rdata(
zone_name,
3600,
RData::NS(rdata::NS(ns_name_parsed.clone())),
);
let glue_record = Record::from_rdata(ns_name_parsed, 3600, RData::A(rdata::A(ns_ip)));
QueryResult {
authoritative: false,
answers: vec![],
name_servers: vec![ns_record],
additionals: vec![glue_record],
response_code: ResponseCode::NoError,
}
}
fn authoritative_a_response(domain: &str, ip: Ipv4Addr) -> QueryResult {
let record = Record::from_rdata(
Name::from_str(domain).expect("domain is a valid DNS name literal"),
300,
RData::A(rdata::A(ip)),
);
QueryResult {
authoritative: true,
answers: vec![record],
name_servers: vec![],
additionals: vec![],
response_code: ResponseCode::NoError,
}
}
/// With the default arguments only the positive cache is created.
#[test]
fn recursive_resolver_new() {
let args = default_args();
let resolver = RecursiveResolver::new(&args)
.expect("resolver creation with default args should succeed");
assert_eq!(*resolver.arguments, args);
assert!(resolver.positive_cache.is_some());
assert!(resolver.negative_cache.is_none());
}
/// With both cache flags inverted only the negative cache is created.
#[test]
fn recursive_resolver_new_2() {
let args = Args {
no_positive_cache: true,
negative_cache: true,
..default_args()
};
let resolver = RecursiveResolver::new(&args)
.expect("resolver creation with default args should succeed");
assert_eq!(*resolver.arguments, args);
assert!(resolver.positive_cache.is_none());
assert!(resolver.negative_cache.is_some());
}
/// A server given as an IP literal is used directly, without name or zone.
#[tokio::test]
async fn recursive_resolver_init_with_ip() {
let args = Args {
server: "8.8.8.8".to_owned(),
..default_args()
};
let resolver =
RecursiveResolver::new(&args).expect("resolver creation with IP server should succeed");
let servers = resolver
.init()
.await
.expect("init with an IP server should succeed");
assert_eq!(servers.len(), 1);
assert_eq!(servers[0].ip, IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)));
assert!(servers[0].name.is_none());
}
/// For server "." the root NS names are looked up first and each one is then resolved.
#[tokio::test]
async fn init_with_dot_uses_ns_then_lookup_ip() {
let args = default_args();
let mut nr = MockNameResolver::new();
nr.expect_ns_lookup()
.with(predicate::eq("."))
.once()
.returning(|_| {
Ok(vec![
Name::from_str("a.root-servers.net.")
.expect("a.root-servers.net. is a valid DNS name"),
])
});
nr.expect_lookup_ip()
.with(predicate::eq("a.root-servers.net."))
.once()
.returning(|_| Ok(vec![IpAddr::from([198, 41, 0, 4])]));
let resolver = mock_resolver(&args, nr, MockDnsQuerier::new());
let servers = resolver.init().await.expect("init with dot should succeed");
assert_eq!(servers.len(), 1);
assert_eq!(servers[0].ip, IpAddr::from([198, 41, 0, 4]));
assert_eq!(servers[0].name, Some("a.root-servers.net.".to_owned()));
assert_eq!(servers[0].zone, Some(".".to_owned()));
}
/// With only IPv4 enabled, IPv6 addresses of the root servers are dropped.
#[tokio::test]
async fn init_with_dot_filters_ipv6_when_ipv4_only() {
let args = default_args();
let mut nr = MockNameResolver::new();
nr.expect_ns_lookup().once().returning(|_| {
Ok(vec![
Name::from_str("a.root-servers.net.")
.expect("a.root-servers.net. is a valid DNS name"),
])
});
nr.expect_lookup_ip().once().returning(|_| {
let ipv6: IpAddr = "2001:503:ba3e::2:30"
.parse()
.expect("2001:503:ba3e::2:30 is a valid IPv6 address");
Ok(vec![IpAddr::from([198, 41, 0, 4]), ipv6])
});
let resolver = mock_resolver(&args, nr, MockDnsQuerier::new());
let servers = resolver.init().await.expect("init with dot should succeed");
assert_eq!(servers.len(), 1);
assert!(servers[0].ip.is_ipv4());
}
/// A host name server is resolved; every address keeps the host name and has no zone.
#[tokio::test]
async fn init_with_hostname_resolves_ips() {
let args = Args {
server: "ns1.example.com".to_owned(),
..default_args()
};
let mut nr = MockNameResolver::new();
nr.expect_lookup_ip()
.with(predicate::eq("ns1.example.com"))
.once()
.returning(|_| {
Ok(vec![
IpAddr::from([192, 0, 2, 1]),
IpAddr::from([192, 0, 2, 2]),
])
});
let resolver = mock_resolver(&args, nr, MockDnsQuerier::new());
let servers = resolver
.init()
.await
.expect("init with hostname should succeed");
assert_eq!(servers.len(), 2);
assert!(
servers
.iter()
.all(|s| s.name == Some("ns1.example.com".to_owned()))
);
assert!(servers.iter().all(|s| s.zone.is_none()));
}
/// When filtering leaves no address (IPv6 only result, IPv4 only mode), init fails.
#[tokio::test]
async fn init_with_no_results_returns_error() {
let args = Args {
server: "ns1.example.com".to_owned(),
..default_args()
};
let mut nr = MockNameResolver::new();
nr.expect_lookup_ip().once().returning(|_| {
let ipv6: IpAddr = "2001:db8::1"
.parse()
.expect("2001:db8::1 is a valid IPv6 address");
Ok(vec![ipv6])
});
let resolver = mock_resolver(&args, nr, MockDnsQuerier::new());
let result = resolver.init().await;
assert!(result.is_err());
}
/// An authoritative answer is recorded in `results` for the answering server.
#[tokio::test]
async fn do_recurse_authoritative_answer_stored_in_results() {
let args = default_args();
let name = Name::from_str("example.com.").expect("example.com. is a valid DNS name");
let server = OptName {
ip: IpAddr::from([192, 0, 2, 1]),
name: Some("ns1.example.com.".to_owned()),
zone: None,
};
let mut q = MockDnsQuerier::new();
q.expect_query().once().returning(|_, _, _| {
Ok(authoritative_a_response(
"example.com.",
Ipv4Addr::new(93, 184, 216, 34),
))
});
let resolver = mock_resolver(&args, MockNameResolver::new(), q);
resolver
.do_recurse(&name, &server, 1, vec![true])
.await
.expect("do_recurse should succeed");
let results = resolver
.results
.read()
.expect("results lock should not be poisoned");
assert_eq!(results.len(), 1);
let result = results
.values()
.next()
.expect("results map should contain at least one entry");
assert_eq!(result.response_code, ResponseCode::NoError);
assert_eq!(result.records.len(), 1);
drop(results);
}
/// At depth 0 the query type is NS regardless of the configured query type.
#[tokio::test]
async fn do_recurse_uses_ns_at_depth_zero() {
let args = default_args();
let name = Name::from_str("example.com.").expect("example.com. is a valid DNS name");
let server = OptName {
ip: IpAddr::from([192, 0, 2, 1]),
name: Some("ns1.example.com.".to_owned()),
zone: None,
};
let mut q = MockDnsQuerier::new();
q.expect_query()
.withf(|_, _, qtype| *qtype == RecordType::NS)
.once()
.returning(|_, _, _| Ok(QueryResult::default()));
let resolver = mock_resolver(&args, MockNameResolver::new(), q);
resolver
.do_recurse(&name, &server, 0, vec![])
.await
.expect("do_recurse should succeed");
}
/// A referral with glue is followed to the delegated server, whose authoritative
/// answer is the only stored result.
#[tokio::test]
async fn do_recurse_follows_ns_delegation() {
let args = default_args();
let name = Name::from_str("example.com.").expect("example.com. is a valid DNS name");
let first_server = OptName {
ip: IpAddr::from([192, 0, 2, 1]),
name: Some("ns1.example.com.".to_owned()),
zone: None,
};
let mut q = MockDnsQuerier::new();
// First hop: referral to ns2 (192.0.2.2).
q.expect_query()
.withf(|server, _, _| server.ip == IpAddr::from([192, 0, 2, 1]))
.once()
.returning(|_, _, _| {
Ok(delegation_response(
"ns2.example.com.",
Ipv4Addr::new(192, 0, 2, 2),
"example.com.",
))
});
// Second hop: authoritative answer from ns2.
q.expect_query()
.withf(|server, _, _| server.ip == IpAddr::from([192, 0, 2, 2]))
.once()
.returning(|_, _, _| {
Ok(authoritative_a_response(
"example.com.",
Ipv4Addr::new(93, 184, 216, 34),
))
});
let resolver = mock_resolver(&args, MockNameResolver::new(), q);
resolver
.do_recurse(&name, &first_server, 1, vec![])
.await
.expect("do_recurse should succeed");
let results = resolver
.results
.read()
.expect("results lock should not be poisoned");
assert_eq!(results.len(), 1);
drop(results);
}
/// A (server, name) pair already in the positive cache is never queried.
#[tokio::test]
async fn do_recurse_skips_cached_servers() {
let args = default_args();
let name = Name::from_str("example.com.").expect("example.com. is a valid DNS name");
let server = OptName {
ip: IpAddr::from([192, 0, 2, 1]),
name: Some("ns1.example.com.".to_owned()),
zone: None,
};
let mut q = MockDnsQuerier::new();
q.expect_query().never();
let resolver = mock_resolver(&args, MockNameResolver::new(), q);
// Pre-populate the positive cache with the pair under test.
resolver
.positive_cache
.as_ref()
.expect("positive cache should be initialized")
.write()
.expect("positive cache lock should not be poisoned")
.insert((server.ip, name.clone()));
resolver
.do_recurse(&name, &server, 1, vec![])
.await
.expect("do_recurse should succeed");
}
/// A failing query is swallowed by `do_recurse` and recorded in the negative cache.
#[tokio::test]
async fn do_recurse_sets_negative_cache_on_error() {
let args = Args {
negative_cache: true,
..default_args()
};
let name = Name::from_str("example.com.").expect("example.com. is a valid DNS name");
let server = OptName {
ip: IpAddr::from([192, 0, 2, 1]),
name: Some("ns1.example.com.".to_owned()),
zone: None,
};
let mut q = MockDnsQuerier::new();
q.expect_query().once().returning({
let name = name.clone();
move |_, _, _| Err(ResolverError::ClientQuery(name.clone(), RecordType::A).into())
});
let resolver = mock_resolver(&args, MockNameResolver::new(), q);
resolver
.do_recurse(&name, &server, 1, vec![])
.await
.expect("do_recurse should succeed even when the query errors (error is printed, not propagated)");
let neg = resolver
.negative_cache
.as_ref()
.expect("negative cache should be initialized")
.read()
.expect("negative cache lock should not be poisoned");
assert!(neg.contains(&(server.ip, name.clone())));
drop(neg);
}
/// Glue from the additional section is used directly; the name resolver mock has no
/// expectations, so any lookup would fail the test.
#[tokio::test]
async fn get_next_servers_uses_glue_records_from_additionals() {
    let args = default_args();
    let name = Name::from_str("example.com.").expect("example.com. is a valid DNS name");
    let current_server = OptName {
        ip: IpAddr::from([192, 0, 2, 1]),
        name: None,
        zone: None,
    };
    let nr = MockNameResolver::new();
    let ns_name =
        Name::from_str("ns1.example.com.").expect("ns1.example.com. is a valid DNS name");
    let zone_name = Name::from_str("example.com.").expect("example.com. is a valid DNS name");
    let ns_record = Record::from_rdata(zone_name, 3600, RData::NS(rdata::NS(ns_name.clone())));
    let glue = Record::from_rdata(ns_name, 3600, RData::A(rdata::A(Ipv4Addr::new(1, 2, 3, 4))));
    let resolver = mock_resolver(&args, nr, MockDnsQuerier::new());
    let next = resolver
        .get_next_servers(&[ns_record], &[glue], &current_server, &name, 1, &[true])
        .await;
    assert_eq!(next.len(), 1);
    assert_eq!(next[0].ip, IpAddr::from([1, 2, 3, 4]));
}
/// Without glue the name server is resolved through the name resolver.
#[tokio::test]
async fn get_next_servers_falls_back_to_lookup_when_no_glue() {
    let args = default_args();
    let name = Name::from_str("example.com.").expect("example.com. is a valid DNS name");
    let current_server = OptName {
        ip: IpAddr::from([192, 0, 2, 1]),
        name: None,
        zone: None,
    };
    let mut nr = MockNameResolver::new();
    nr.expect_lookup_ip()
        .with(predicate::eq("ns1.example.com."))
        .once()
        .returning(|_| Ok(vec![IpAddr::from([5, 6, 7, 8])]));
    let ns_name =
        Name::from_str("ns1.example.com.").expect("ns1.example.com. is a valid DNS name");
    let zone_name = Name::from_str("example.com.").expect("example.com. is a valid DNS name");
    let ns_record = Record::from_rdata(zone_name, 3600, RData::NS(rdata::NS(ns_name)));
    let resolver = mock_resolver(&args, nr, MockDnsQuerier::new());
    let next = resolver
        .get_next_servers(&[ns_record], &[], &current_server, &name, 1, &[true])
        .await;
    assert_eq!(next.len(), 1);
    assert_eq!(next[0].ip, IpAddr::from([5, 6, 7, 8]));
}
/// The default `QueryResult` is an empty, non-authoritative `NoError` response.
#[test]
fn query_result_default_values() {
    let QueryResult {
        authoritative,
        answers,
        name_servers,
        additionals,
        response_code,
    } = QueryResult::default();
    assert!(!authoritative);
    assert_eq!(response_code, ResponseCode::NoError);
    assert!(answers.is_empty());
    assert!(name_servers.is_empty());
    assert!(additionals.is_empty());
}
}