use crate::{args::Args, opt_name::OptName};
use derive_more::{Debug, Display};
use exn::{Result, ResultExt as _, bail};
use hickory_client::client::{Client, ClientHandle as _};
use hickory_proto::{
op::ResponseCode,
rr::{DNSClass, Name, RData, Record, RecordType},
runtime::TokioRuntimeProvider,
tcp::TcpClientStream,
udp::UdpClientStream,
xfer::DnsResponse,
};
use hickory_resolver::{
TokioResolver,
config::{LookupIpStrategy, ResolverOpts},
name_server::GenericConnector,
};
use itertools::Itertools as _;
use std::{
collections::{BTreeSet, HashMap, HashSet},
fmt,
future::Future,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
pin::Pin,
sync::RwLock,
};
/// Error type used by the resolvers and queriers in this file.
///
/// Wraps a plain message string describing the step that failed
/// (for example `"ns lookup"` or `"client query"`).
#[derive(Debug, Display)]
pub struct ResolverError(String);
// Marker impl so the type can be used wherever `std::error::Error` is required.
impl std::error::Error for ResolverError {}
/// Decides whether an address passes the IPv4/IPv6 command line filter.
///
/// `$self` must expose `arguments.ipv4` and `arguments.ipv6` booleans and
/// `$ip` must offer `is_ipv4()` / `is_ipv6()`. An address is accepted when its
/// family was explicitly requested, or when neither family flag is set.
macro_rules! is_ip_allowed {
    ($self:expr, $ip:expr) => {
        ($self.arguments.ipv4 && $ip.is_ipv4())
            || ($self.arguments.ipv6 && $ip.is_ipv6())
            || !($self.arguments.ipv4 || $self.arguments.ipv6)
    };
}
/// Key of one cache entry: the IP of the server that was queried and the name asked for.
type CacheKey = (IpAddr, Name);
/// Set of already-seen keys; the same shape backs the positive and the negative cache.
type Cache = HashSet<CacheKey>;
/// What has been collected for a single server (filled in by `add_result`).
#[derive(Clone, Debug, Default)]
struct FullResult {
// Records returned by the server; a set, so repeated inserts do not duplicate entries.
records: BTreeSet<Record>,
// Response code of the most recent result stored for the server.
response_code: ResponseCode,
}
/// Name resolution service used to find the starting servers (`init`) and to
/// resolve name servers that came without glue records (`get_next_servers`).
///
/// A mock implementation is generated for tests via `mockall::automock`.
#[cfg_attr(test, mockall::automock)]
pub trait NameResolver: Send + Sync {
/// Looks up the NS records of `name` and returns the name server names.
async fn ns_lookup(&self, name: &str) -> Result<Vec<Name>, ResolverError>;
/// Looks up the IP addresses of `name`.
async fn lookup_ip(&self, name: &str) -> Result<Vec<IpAddr>, ResolverError>;
}
/// The parts of a DNS response that the recursion logic looks at.
///
/// Built from a `DnsResponse` through the `From` impl in this file.
#[derive(Clone, Debug)]
pub struct QueryResult {
/// Whether the response was flagged authoritative.
pub authoritative: bool,
/// Records from the answer section.
pub answers: Vec<Record>,
/// Records from the name server (authority) section.
pub name_servers: Vec<Record>,
/// Records from the additional section (searched for glue A/AAAA records).
pub additionals: Vec<Record>,
/// Response code reported by the server.
pub response_code: ResponseCode,
}
impl Default for QueryResult {
    /// An empty, non-authoritative result with response code `NoError`.
    fn default() -> Self {
        let (answers, name_servers, additionals) = (Vec::new(), Vec::new(), Vec::new());
        Self {
            response_code: ResponseCode::NoError,
            authoritative: false,
            answers,
            name_servers,
            additionals,
        }
    }
}
impl From<DnsResponse> for QueryResult {
    /// Copies the authoritative flag, the three record sections and the
    /// response code out of a raw DNS response.
    fn from(resp: DnsResponse) -> Self {
        let authoritative = resp.authoritative();
        let response_code = resp.response_code();
        // Each section is exposed as a borrowed slice, so clone it into an owned vector.
        let answers = resp.answers().to_owned();
        let name_servers = resp.name_servers().to_owned();
        let additionals = resp.additionals().to_owned();
        Self {
            authoritative,
            answers,
            name_servers,
            additionals,
            response_code,
        }
    }
}
/// Sends a single DNS query to one specific server.
///
/// A mock implementation is generated for tests via `mockall::automock`.
#[cfg_attr(test, mockall::automock)]
pub trait DnsQuerier: Send + Sync {
/// Asks `server` for records of type `query_type` for `name`.
async fn query(
&self,
server: &OptName,
name: &Name,
query_type: RecordType,
) -> Result<QueryResult, ResolverError>;
}
/// [`NameResolver`] backed by a hickory `TokioResolver`.
pub struct TokioNameResolver(TokioResolver);

impl NameResolver for TokioNameResolver {
    /// Resolves the NS records of `name` and returns the name server names.
    async fn ns_lookup(&self, name: &str) -> Result<Vec<Name>, ResolverError> {
        let lookup = self
            .0
            .ns_lookup(name)
            .await
            .or_raise(|| ResolverError("ns lookup".to_owned()))?;
        let mut names: Vec<Name> = Vec::new();
        for ns in lookup.iter() {
            names.push(ns.0.clone());
        }
        Ok(names)
    }

    /// Resolves `name` to its IP addresses.
    async fn lookup_ip(&self, name: &str) -> Result<Vec<IpAddr>, ResolverError> {
        let lookup = self
            .0
            .lookup_ip(name)
            .await
            .or_raise(|| ResolverError("ip lookup".to_owned()))?;
        let addresses: Vec<IpAddr> = lookup.iter().collect();
        Ok(addresses)
    }
}
/// [`DnsQuerier`] that talks to servers directly over UDP or TCP.
pub struct DefaultDnsQuerier {
// Use TCP instead of UDP for every query.
tcp: bool,
// Timeout handed to the UDP/TCP client stream.
timeout: std::time::Duration,
// Optional local address to bind to; the port is left at 0.
source_address: Option<IpAddr>,
// When set, EDNS is disabled on the client; otherwise it is enabled.
no_edns0: bool,
}
impl DefaultDnsQuerier {
const fn from_args(args: &Args) -> Self {
Self {
tcp: args.tcp,
timeout: args.timeout,
source_address: args.source_address,
no_edns0: args.no_edns0,
}
}
async fn udp_query(
&self,
server: &OptName,
name: &Name,
query_type: RecordType,
) -> Result<DnsResponse, ResolverError> {
let stream = UdpClientStream::builder(server.into(), TokioRuntimeProvider::new())
.with_timeout(Some(self.timeout))
.with_bind_addr(self.source_address.map(|ip| SocketAddr::new(ip, 0)))
.build();
let (mut client, bg) = Client::connect(stream)
.await
.or_raise(|| ResolverError("client connect".to_owned()))?;
if self.no_edns0 {
client.disable_edns();
} else {
client.enable_edns();
}
tokio::spawn(bg);
client
.query(name.clone(), DNSClass::IN, query_type)
.await
.or_raise(|| ResolverError("client query".to_owned()))
}
async fn tcp_query(
&self,
server: &OptName,
name: &Name,
query_type: RecordType,
) -> Result<DnsResponse, ResolverError> {
let (stream, sender) = TcpClientStream::new(
server.into(),
self.source_address.map(|ip| SocketAddr::new(ip, 0)),
Some(self.timeout),
TokioRuntimeProvider::new(),
);
let (mut client, bg) = Client::new(stream, sender, None)
.await
.or_raise(|| ResolverError("client new".to_owned()))?;
if self.no_edns0 {
client.disable_edns();
} else {
client.enable_edns();
}
tokio::spawn(bg);
client
.query(name.clone(), DNSClass::IN, query_type)
.await
.or_raise(|| ResolverError("client query".to_owned()))
}
}
impl DnsQuerier for DefaultDnsQuerier {
    /// Queries `server` over the configured transport (TCP when `tcp` is set,
    /// UDP otherwise) and converts the response into a [`QueryResult`].
    async fn query(
        &self,
        server: &OptName,
        name: &Name,
        query_type: RecordType,
    ) -> Result<QueryResult, ResolverError> {
        let outcome = if self.tcp {
            self.tcp_query(server, name, query_type).await
        } else {
            self.udp_query(server, name, query_type).await
        };
        outcome.map(QueryResult::from)
    }
}
/// Follows DNS delegations server by server, printing each step and
/// collecting the authoritative answers it finds.
///
/// `R` resolves names to addresses and `Q` performs the individual queries;
/// both are generic so tests can substitute mocks.
#[derive(Debug)]
pub struct RecursiveResolver<'a, R = TokioNameResolver, Q = DefaultDnsQuerier> {
// Answers gathered per server; written by `add_result`, read by `show_overview`.
results: RwLock<HashMap<OptName, FullResult>>,
#[debug(skip)]
name_resolver: R,
#[debug(skip)]
querier: Q,
// Borrowed command line arguments.
arguments: &'a Args,
// (server IP, name) pairs already handled successfully; `None` when disabled.
positive_cache: Option<RwLock<Cache>>,
// (server IP, name) pairs whose query failed; `None` when disabled.
negative_cache: Option<RwLock<Cache>>,
}
impl<'a> RecursiveResolver<'a> {
    /// Builds a resolver with the default Tokio-based name resolver and querier.
    ///
    /// The positive cache is created unless `no_positive_cache` is set; the
    /// negative cache is created only when `negative_cache` is set.
    ///
    /// # Errors
    /// Returns a `ResolverError` when the underlying Tokio resolver cannot be built.
    pub fn new(args: &'a Args) -> Result<Self, ResolverError> {
        let mut opts = ResolverOpts::default();
        opts.edns0 = !args.no_edns0;
        opts.timeout = args.timeout;
        opts.attempts = args.retries;
        opts.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;

        let connector = GenericConnector::new(TokioRuntimeProvider::new());
        let tokio_resolver = TokioResolver::builder(connector)
            .or_raise(|| ResolverError("build tokio resolver".to_owned()))?
            .with_options(opts)
            .build();

        let positive_cache = if args.no_positive_cache {
            None
        } else {
            Some(RwLock::new(HashSet::new()))
        };
        let negative_cache = if args.negative_cache {
            Some(RwLock::new(HashSet::new()))
        } else {
            None
        };

        Ok(Self {
            results: RwLock::new(HashMap::new()),
            name_resolver: TokioNameResolver(tokio_resolver),
            querier: DefaultDnsQuerier::from_args(args),
            arguments: args,
            positive_cache,
            negative_cache,
        })
    }
}
impl<R: NameResolver, Q: DnsQuerier> RecursiveResolver<'_, R, Q> {
/// Turns the configured `server` argument into the initial list of servers.
///
/// * An IP literal is used directly (no name, no zone).
/// * `"."` looks up the root name servers and resolves each of them.
/// * Anything else is treated as a host name and resolved.
///
/// Resolved addresses are filtered with `is_ip_allowed!`.
///
/// # Errors
/// Returns a `ResolverError` when a lookup fails or when no address is left
/// after filtering.
pub async fn init(&self) -> Result<Vec<OptName>, ResolverError> {
    let configured: &str = &self.arguments.server;
    let mut servers: Vec<OptName> = Vec::new();

    if let Ok(ip) = configured.parse::<IpAddr>() {
        servers.push(OptName {
            ip,
            name: None,
            zone: None,
        });
    } else if configured == "." {
        for ns in self.name_resolver.ns_lookup(".").await? {
            let ns_name = ns.to_string();
            let ips = self
                .name_resolver
                .lookup_ip(&ns_name)
                .await
                .or_raise(|| ResolverError("ip lookup".to_owned()))?;
            servers.extend(
                ips.into_iter()
                    .filter(|ip| is_ip_allowed!(self, ip))
                    .map(|ip| OptName {
                        ip,
                        name: Some(ns_name.clone()),
                        zone: Some(".".to_owned()),
                    }),
            );
        }
    } else {
        let ips = self
            .name_resolver
            .lookup_ip(configured)
            .await
            .or_raise(|| ResolverError("ip lookup".to_owned()))?;
        servers.extend(
            ips.into_iter()
                .filter(|ip| is_ip_allowed!(self, ip))
                .map(|ip| OptName {
                    ip,
                    name: Some(configured.to_owned()),
                    zone: None,
                }),
        );
    }

    if servers.is_empty() {
        bail!(ResolverError(format!(
            "no IP address found for hostname: {}",
            self.arguments.server
        )));
    }
    Ok(servers)
}
/// Queries `server` for `name` and recurses into every server it is referred to.
///
/// * `depth` is the current nesting level; at depth 0 an NS query is sent,
///   deeper levels use the configured query type.
/// * `last` holds, per ancestor level, whether that branch was the last
///   sibling; it is only used to draw the tree.
///
/// A query failure is printed and recorded in the negative cache but is not
/// returned as an error.
///
/// # Errors
/// Returns a `ResolverError` when storing a result fails or when a nested
/// call returns an error.
pub fn do_recurse<'b>(
    &'b self,
    name: &'b Name,
    server: &'b OptName,
    depth: usize,
    last: Vec<bool>,
) -> Pin<Box<dyn Future<Output = Result<(), ResolverError>> + 'b>> {
    Box::pin(async move {
        if self.cache_get(&(server.ip, name.clone())) {
            Self::print(depth, server, "(cached)", &last);
            return Ok(());
        }

        let query_type = match depth {
            0 => RecordType::NS,
            _ => self.arguments.query_type,
        };

        let response = match self.querier.query(server, name, query_type).await {
            Ok(response) => response,
            Err(e) => {
                self.cache_set(false, (server.ip, name.clone()));
                Self::print(depth, server, format!("resolution error: {e}"), &last);
                return Ok(());
            }
        };

        let next_servers: Option<Vec<OptName>> = if response.authoritative {
            let answers = &response.answers;
            Self::print(depth, server, "found authoritative answer", &last);
            self.cache_set(true, (server.ip, name.clone()));
            self.add_result(server.clone(), response.response_code, answers)?;

            // Only chase CNAME targets when CNAME was not what was asked for,
            // the answer consists of CNAMEs only, and name servers were listed.
            let follow_cnames = self.arguments.query_type != RecordType::CNAME
                && answers.iter().all(|r| r.record_type() == RecordType::CNAME)
                && !response.name_servers.is_empty();

            let mut via_cname = None;
            if follow_cnames {
                // Each iteration replaces the previous value, so only the
                // servers found for the last CNAME are kept.
                for cname in answers.iter().filter_map(|r| r.data().as_cname()) {
                    via_cname = Some(
                        self.get_next_servers(
                            &response.name_servers,
                            &response.additionals,
                            server,
                            cname,
                            depth,
                            &last,
                        )
                        .await,
                    );
                }
            }
            via_cname
        } else {
            Self::print(depth, server, "", &last);
            // At depth 0 a non-empty answer section is used as the NS list;
            // otherwise the name server section is used.
            let records = if depth == 0 && !response.answers.is_empty() {
                &response.answers
            } else {
                &response.name_servers
            };
            Some(
                self.get_next_servers(records, &response.additionals, server, name, depth, &last)
                    .await,
            )
        };

        let Some(next) = next_servers else {
            return Ok(());
        };
        let total = next.len();
        for (index, ns) in next.iter().sorted().enumerate() {
            let mut branch_last = last.clone();
            branch_last.push(index + 1 == total);
            self.do_recurse(name, ns, depth + 1, branch_last)
                .await
                .or_raise(|| ResolverError(format!("do recurse depth {}", depth + 1)))?;
        }
        Ok(())
    })
}
async fn get_next_servers(
&self,
records: &[Record],
additionals: &[Record],
server: &OptName,
name: &Name,
depth: usize,
last: &[bool],
) -> Vec<OptName> {
let mut next_servers: Vec<OptName> = vec![];
for record in records {
let Some(ns) = record.data().as_ns() else {
continue;
};
let before_len = next_servers.len();
next_servers.append(
&mut additionals
.iter()
.filter(|r| *r.name() == ns.0)
.filter_map(|additional| match *additional.data() {
RData::A(ref a) => {
Some((additional, IpAddr::from(Into::<Ipv4Addr>::into(*a))))
}
RData::AAAA(ref a) => {
Some((additional, IpAddr::from(Into::<Ipv6Addr>::into(*a))))
}
_ => None,
})
.filter(|&(_, ip)| is_ip_allowed!(self, ip))
.map(|(additional, ip)| OptName {
ip,
name: Some(additional.name().to_string()),
zone: Some(record.name().to_string()),
})
.collect(),
);
if next_servers.len() == before_len {
let ns_s = ns.to_string();
let before_resolve_len = next_servers.len();
if let Ok(ips) = self.name_resolver.lookup_ip(&ns_s).await {
next_servers.append(
&mut ips
.into_iter()
.filter(|ip| is_ip_allowed!(self, ip))
.map(|ip| OptName {
ip,
name: Some(ns.to_string()),
zone: Some(record.name().to_string()),
})
.collect(),
);
}
if next_servers.len() > before_resolve_len {
self.cache_set(true, (server.ip, name.clone()));
} else {
Self::print(
depth,
&OptName {
ip: [0, 0, 0, 0].into(),
name: Some(ns.to_string()),
zone: Some(record.name().to_string()),
},
"no ip found",
last,
);
}
} else {
self.cache_set(true, (server.ip, name.clone()));
}
}
next_servers
}
/// Prints a summary of everything collected in `results`, one line per record.
///
/// Servers are listed in the order defined by `OptName`'s `Ord` impl, so the
/// output is stable between runs (plain `HashMap` iteration order is
/// unspecified). For a server whose stored response code is not `NoError`,
/// an extra line with that code is printed first. Records of a server are
/// ordered by their textual form.
///
/// # Errors
/// Returns a `ResolverError` when the results lock is poisoned.
#[expect(clippy::print_stdout, reason = "print")]
pub fn show_overview(&self) -> Result<(), ResolverError> {
    for (key, values) in self
        .results
        .read()
        .map_err(|e| ResolverError(format!("get read lock: {e:?}")))?
        .iter()
        // Sort by server so the overview does not depend on hash order.
        .sorted_by(|a, b| a.0.cmp(b.0))
    {
        let server_name = key.name.as_deref().unwrap_or_default();
        if values.response_code != ResponseCode::NoError {
            println!("{} ({})\t{}", server_name, key.ip, values.response_code);
        }
        for record in values
            .records
            .iter()
            .sorted_by_cached_key(|r| format!("{r}"))
        {
            println!("{} ({}) \t{record}", server_name, key.ip);
        }
    }
    Ok(())
}
fn cache_get(&self, key: &CacheKey) -> bool {
self.positive_cache
.as_ref()
.is_some_and(|o| o.read().ok().as_ref().and_then(|r| r.get(key)).is_some())
|| self
.negative_cache
.as_ref()
.is_some_and(|o| o.read().ok().as_ref().and_then(|r| r.get(key)).is_some())
}
/// Records `key` in the positive cache (`positive == true`) or in the
/// negative cache. Does nothing when the selected cache is disabled; a lock
/// failure is reported on stderr and otherwise ignored.
#[expect(clippy::print_stderr, reason = "non fatal error")]
fn cache_set(&self, positive: bool, key: CacheKey) {
    let selected = if positive {
        self.positive_cache.as_ref()
    } else {
        self.negative_cache.as_ref()
    };
    let Some(lock) = selected else {
        return;
    };
    match lock.write() {
        Ok(mut entries) => {
            entries.insert(key);
        }
        Err(error) => eprintln!("cache set error {error}"),
    }
}
/// Stores `results` for `server`, creating the entry if needed.
///
/// The stored response code is overwritten with `response_code`; the records
/// are added to the server's existing set.
///
/// # Errors
/// Returns a `ResolverError` when the results lock is poisoned.
#[expect(
    clippy::significant_drop_tightening,
    reason = "Scope is short enough and there should not be contentions"
)]
fn add_result(
    &self,
    server: OptName,
    response_code: ResponseCode,
    results: &[Record],
) -> Result<(), ResolverError> {
    let mut by_server = self
        .results
        .write()
        .map_err(|e| ResolverError(format!("get write lock: {e:?}")))?;
    let entry = by_server.entry(server).or_default();
    entry.response_code = response_code;
    entry.records.extend(results.iter().cloned());
    Ok(())
}
/// Prints one line of the resolution tree.
///
/// For each of the `depth` ancestor levels a column is drawn: blank when that
/// level's entry in `last` is `true`, a vertical bar otherwise (missing
/// entries count as `false`). The line then shows `server`, followed by
/// `rest` when it is not empty.
#[expect(clippy::print_stdout, reason = "called print")]
fn print<S: fmt::Display>(depth: usize, server: &OptName, rest: S, last: &[bool]) {
    let mut prefix = String::new();
    for level in 0..depth {
        let is_last = last.get(level).copied().unwrap_or(false);
        prefix.push_str(if is_last { " " } else { " |" });
        // Separator between columns, but not after the final one.
        if level + 1 < depth {
            prefix.push_str(" ");
        }
    }
    if depth > 0 {
        prefix.push_str(r"\___ ");
    }
    let suffix = rest.to_string();
    if suffix.is_empty() {
        println!("{prefix}{server}");
    } else {
        println!("{prefix}{server} {suffix}");
    }
}
}
#[cfg(test)]
mod tests {
#![allow(
clippy::expect_used,
clippy::unwrap_used,
clippy::indexing_slicing,
reason = "test"
)]
use super::*;
use crate::args::Args;
use hickory_proto::rr::{Name, RData, Record, RecordType, rdata};
use mockall::predicate;
use std::{
net::{IpAddr, Ipv4Addr},
str::FromStr as _,
time::Duration,
};
/// Baseline arguments for the tests: resolve `example.com` (type A) starting
/// at the root, IPv4 only, UDP, positive cache on, negative cache off.
fn default_args() -> Args {
    Args {
        // What to resolve and where to start.
        domain: String::from("example.com"),
        server: String::from("."),
        query_type: RecordType::A,
        // Transport settings.
        tcp: false,
        timeout: Duration::from_secs(5),
        retries: 3,
        source_address: None,
        no_edns0: true,
        // Address family filter.
        ipv4: true,
        ipv6: false,
        // Caching and output.
        no_positive_cache: false,
        negative_cache: false,
        overview: false,
    }
}
/// Builds a `RecursiveResolver` around the given mocks, creating the caches
/// from `args` as `RecursiveResolver::new` does.
fn mock_resolver(
    args: &Args,
    name_resolver: MockNameResolver,
    querier: MockDnsQuerier,
) -> RecursiveResolver<'_, MockNameResolver, MockDnsQuerier> {
    let positive_cache = if args.no_positive_cache {
        None
    } else {
        Some(RwLock::new(HashSet::new()))
    };
    let negative_cache = if args.negative_cache {
        Some(RwLock::new(HashSet::new()))
    } else {
        None
    };
    RecursiveResolver {
        results: RwLock::new(HashMap::new()),
        name_resolver,
        querier,
        arguments: args,
        positive_cache,
        negative_cache,
    }
}
/// Non-authoritative referral: one NS record for `zone` pointing at `ns_name`,
/// plus a glue A record mapping `ns_name` to `ns_ip`.
fn delegation_response(ns_name: &str, ns_ip: Ipv4Addr, zone: &str) -> QueryResult {
    let server_name = Name::from_str(ns_name).unwrap();
    let zone_name = Name::from_str(zone).unwrap();
    let glue = Record::from_rdata(server_name.clone(), 3600, RData::A(rdata::A(ns_ip)));
    let referral = Record::from_rdata(zone_name, 3600, RData::NS(rdata::NS(server_name)));
    QueryResult {
        authoritative: false,
        answers: Vec::new(),
        name_servers: vec![referral],
        additionals: vec![glue],
        response_code: ResponseCode::NoError,
    }
}
/// Authoritative answer holding a single A record (`domain` -> `ip`, TTL 300).
fn authoritative_a_response(domain: &str, ip: Ipv4Addr) -> QueryResult {
    let owner = Name::from_str(domain).unwrap();
    let answer = Record::from_rdata(owner, 300, RData::A(rdata::A(ip)));
    QueryResult {
        authoritative: true,
        answers: vec![answer],
        name_servers: Vec::new(),
        additionals: Vec::new(),
        response_code: ResponseCode::NoError,
    }
}
/// With default arguments only the positive cache is created.
#[test]
fn recursive_resolver_new() {
    let args = default_args();
    let built = RecursiveResolver::new(&args).unwrap();
    assert_eq!(*built.arguments, args);
    let (positive, negative) = (&built.positive_cache, &built.negative_cache);
    assert!(positive.is_some());
    assert!(negative.is_none());
}
/// Inverting both cache flags yields only the negative cache.
#[test]
fn recursive_resolver_new_2() {
    let mut args = default_args();
    args.no_positive_cache = true;
    args.negative_cache = true;
    let built = RecursiveResolver::new(&args).unwrap();
    assert_eq!(*built.arguments, args);
    let (positive, negative) = (&built.positive_cache, &built.negative_cache);
    assert!(positive.is_none());
    assert!(negative.is_some());
}
#[tokio::test]
async fn recursive_resolver_init_with_ip() {
let args = Args {
server: "8.8.8.8".to_owned(),
..default_args()
};
let resolver = RecursiveResolver::new(&args).unwrap();
let result = resolver.init().await;
assert!(result.is_ok());
let servers = result.unwrap();
assert_eq!(servers.len(), 1);
assert_eq!(servers[0].ip, IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)));
assert!(servers[0].name.is_none());
}
/// `"."` triggers an NS lookup for the root followed by an IP lookup per name server.
#[tokio::test]
async fn init_with_dot_uses_ns_then_lookup_ip() {
    let args = default_args();
    let root_ip = IpAddr::from([198, 41, 0, 4]);

    let mut names = MockNameResolver::new();
    names
        .expect_ns_lookup()
        .with(predicate::eq("."))
        .once()
        .returning(|_| Ok(vec![Name::from_str("a.root-servers.net.").unwrap()]));
    names
        .expect_lookup_ip()
        .with(predicate::eq("a.root-servers.net."))
        .once()
        .returning(|_| Ok(vec![IpAddr::from([198, 41, 0, 4])]));

    let resolver = mock_resolver(&args, names, MockDnsQuerier::new());
    let servers = resolver.init().await.unwrap();
    assert_eq!(servers.len(), 1);
    let first = &servers[0];
    assert_eq!(first.ip, root_ip);
    assert_eq!(first.name, Some("a.root-servers.net.".to_owned()));
    assert_eq!(first.zone, Some(".".to_owned()));
}
/// With only the IPv4 flag set, IPv6 addresses of root servers are dropped.
#[tokio::test]
async fn init_with_dot_filters_ipv6_when_ipv4_only() {
    let args = default_args();

    let mut names = MockNameResolver::new();
    names
        .expect_ns_lookup()
        .once()
        .returning(|_| Ok(vec![Name::from_str("a.root-servers.net.").unwrap()]));
    names.expect_lookup_ip().once().returning(|_| {
        let v4 = IpAddr::from([198, 41, 0, 4]);
        let v6: IpAddr = "2001:503:ba3e::2:30".parse().unwrap();
        Ok(vec![v4, v6])
    });

    let resolver = mock_resolver(&args, names, MockDnsQuerier::new());
    let servers = resolver.init().await.unwrap();
    assert_eq!(servers.len(), 1);
    assert!(servers[0].ip.is_ipv4());
}
#[tokio::test]
async fn init_with_hostname_resolves_ips() {
let args = Args {
server: "ns1.example.com".to_owned(),
..default_args()
};
let mut nr = MockNameResolver::new();
nr.expect_lookup_ip()
.with(predicate::eq("ns1.example.com"))
.once()
.returning(|_| {
Ok(vec![
IpAddr::from([192, 0, 2, 1]),
IpAddr::from([192, 0, 2, 2]),
])
});
let resolver = mock_resolver(&args, nr, MockDnsQuerier::new());
let servers = resolver.init().await.unwrap();
assert_eq!(servers.len(), 2);
assert!(
servers
.iter()
.all(|s| s.name == Some("ns1.example.com".to_owned()))
);
assert!(servers.iter().all(|s| s.zone.is_none()));
}
#[tokio::test]
async fn init_with_no_results_returns_error() {
let args = Args {
server: "ns1.example.com".to_owned(),
..default_args()
};
let mut nr = MockNameResolver::new();
nr.expect_lookup_ip().once().returning(|_| {
let ipv6: IpAddr = "2001:db8::1".parse().unwrap();
Ok(vec![ipv6])
});
let resolver = mock_resolver(&args, nr, MockDnsQuerier::new());
let result = resolver.init().await;
assert!(result.is_err());
}
/// An authoritative answer ends up in `results` with its records and response code.
#[tokio::test]
async fn do_recurse_authoritative_answer_stored_in_results() {
    let args = default_args();
    let name = Name::from_str("example.com.").unwrap();
    let server = OptName {
        ip: IpAddr::from([192, 0, 2, 1]),
        name: Some("ns1.example.com.".to_owned()),
        zone: None,
    };

    let mut querier = MockDnsQuerier::new();
    querier.expect_query().once().returning(|_, _, _| {
        let address = Ipv4Addr::new(93, 184, 216, 34);
        Ok(authoritative_a_response("example.com.", address))
    });

    let resolver = mock_resolver(&args, MockNameResolver::new(), querier);
    let recursion = resolver.do_recurse(&name, &server, 1, vec![true]);
    recursion.await.unwrap();

    let stored = resolver.results.read().unwrap();
    assert_eq!(stored.len(), 1);
    let entry = stored.values().next().unwrap();
    assert_eq!(entry.response_code, ResponseCode::NoError);
    assert_eq!(entry.records.len(), 1);
    drop(stored);
}
/// At depth 0 the query type is NS regardless of the configured type.
#[tokio::test]
async fn do_recurse_uses_ns_at_depth_zero() {
    let args = default_args();
    let name = Name::from_str("example.com.").unwrap();
    let server = OptName {
        ip: IpAddr::from([192, 0, 2, 1]),
        name: Some("ns1.example.com.".to_owned()),
        zone: None,
    };

    let mut querier = MockDnsQuerier::new();
    querier
        .expect_query()
        .withf(|_, _, qtype| *qtype == RecordType::NS)
        .once()
        .returning(|_, _, _| Ok(QueryResult::default()));

    let resolver = mock_resolver(&args, MockNameResolver::new(), querier);
    let recursion = resolver.do_recurse(&name, &server, 0, Vec::new());
    recursion.await.unwrap();
}
/// A referral from the first server leads to a query against the delegated
/// server, whose authoritative answer is stored.
#[tokio::test]
async fn do_recurse_follows_ns_delegation() {
    let args = default_args();
    let name = Name::from_str("example.com.").unwrap();
    let first_server = OptName {
        ip: IpAddr::from([192, 0, 2, 1]),
        name: Some("ns1.example.com.".to_owned()),
        zone: None,
    };

    let mut querier = MockDnsQuerier::new();
    // First hop: 192.0.2.1 refers to ns2.example.com. at 192.0.2.2.
    querier
        .expect_query()
        .withf(|server, _, _| server.ip == IpAddr::from([192, 0, 2, 1]))
        .once()
        .returning(|_, _, _| {
            let glue_ip = Ipv4Addr::new(192, 0, 2, 2);
            Ok(delegation_response("ns2.example.com.", glue_ip, "example.com."))
        });
    // Second hop: 192.0.2.2 answers authoritatively.
    querier
        .expect_query()
        .withf(|server, _, _| server.ip == IpAddr::from([192, 0, 2, 2]))
        .once()
        .returning(|_, _, _| {
            let address = Ipv4Addr::new(93, 184, 216, 34);
            Ok(authoritative_a_response("example.com.", address))
        });

    let resolver = mock_resolver(&args, MockNameResolver::new(), querier);
    let recursion = resolver.do_recurse(&name, &first_server, 1, Vec::new());
    recursion.await.unwrap();

    let stored = resolver.results.read().unwrap();
    assert_eq!(stored.len(), 1);
    drop(stored);
}
/// A (server, name) pair already in the positive cache is not queried again.
#[tokio::test]
async fn do_recurse_skips_cached_servers() {
    let args = default_args();
    let name = Name::from_str("example.com.").unwrap();
    let server = OptName {
        ip: IpAddr::from([192, 0, 2, 1]),
        name: Some("ns1.example.com.".to_owned()),
        zone: None,
    };

    let mut querier = MockDnsQuerier::new();
    querier.expect_query().never();

    let resolver = mock_resolver(&args, MockNameResolver::new(), querier);
    // Pre-populate the positive cache with the pair about to be requested.
    let cache = resolver.positive_cache.as_ref().unwrap();
    cache.write().unwrap().insert((server.ip, name.clone()));

    let recursion = resolver.do_recurse(&name, &server, 1, Vec::new());
    recursion.await.unwrap();
}
/// A failing query is not an error for the caller but lands in the negative cache.
#[tokio::test]
async fn do_recurse_sets_negative_cache_on_error() {
    let mut args = default_args();
    args.negative_cache = true;
    let name = Name::from_str("example.com.").unwrap();
    let server = OptName {
        ip: IpAddr::from([192, 0, 2, 1]),
        name: Some("ns1.example.com.".to_owned()),
        zone: None,
    };

    let mut querier = MockDnsQuerier::new();
    querier
        .expect_query()
        .once()
        .returning(|_, _, _| Err(ResolverError("simulated timeout".to_owned()).into()));

    let resolver = mock_resolver(&args, MockNameResolver::new(), querier);
    let recursion = resolver.do_recurse(&name, &server, 1, Vec::new());
    recursion.await.unwrap();

    let negative = resolver.negative_cache.as_ref().unwrap().read().unwrap();
    assert!(negative.contains(&(server.ip, name)));
    drop(negative);
}
/// When the additional section carries an A record for the name server, that
/// glue address is used and the name resolver is never consulted (the mock
/// has no expectations).
#[tokio::test]
async fn get_next_servers_uses_glue_records_from_additionals() {
    let args = default_args();
    let name = Name::from_str("example.com.").unwrap();
    let current_server = OptName {
        ip: IpAddr::from([192, 0, 2, 1]),
        name: None,
        zone: None,
    };
    let nr = MockNameResolver::new();
    let ns_name = Name::from_str("ns1.example.com.").unwrap();
    let zone_name = Name::from_str("example.com.").unwrap();
    let ns_record = Record::from_rdata(zone_name, 3600, RData::NS(rdata::NS(ns_name.clone())));
    let glue = Record::from_rdata(ns_name, 3600, RData::A(rdata::A(Ipv4Addr::new(1, 2, 3, 4))));
    let resolver = mock_resolver(&args, nr, MockDnsQuerier::new());
    // The server argument is a plain borrow of `current_server`.
    let next = resolver
        .get_next_servers(&[ns_record], &[glue], &current_server, &name, 1, &[true])
        .await;
    assert_eq!(next.len(), 1);
    assert_eq!(next[0].ip, IpAddr::from([1, 2, 3, 4]));
}
/// Without glue in the additional section, the name server's name is resolved
/// through the name resolver and the resulting address is used.
#[tokio::test]
async fn get_next_servers_falls_back_to_lookup_when_no_glue() {
    let args = default_args();
    let name = Name::from_str("example.com.").unwrap();
    let current_server = OptName {
        ip: IpAddr::from([192, 0, 2, 1]),
        name: None,
        zone: None,
    };
    let mut nr = MockNameResolver::new();
    nr.expect_lookup_ip()
        .with(predicate::eq("ns1.example.com."))
        .once()
        .returning(|_| Ok(vec![IpAddr::from([5, 6, 7, 8])]));
    let ns_name = Name::from_str("ns1.example.com.").unwrap();
    let zone_name = Name::from_str("example.com.").unwrap();
    let ns_record = Record::from_rdata(zone_name, 3600, RData::NS(rdata::NS(ns_name)));
    let resolver = mock_resolver(&args, nr, MockDnsQuerier::new());
    // Empty additionals force the fallback; the server argument is a plain
    // borrow of `current_server`.
    let next = resolver
        .get_next_servers(&[ns_record], &[], &current_server, &name, 1, &[true])
        .await;
    assert_eq!(next.len(), 1);
    assert_eq!(next[0].ip, IpAddr::from([5, 6, 7, 8]));
}
/// The default `QueryResult` is non-authoritative, `NoError`, and has no records.
#[test]
fn query_result_default_values() {
    let QueryResult {
        authoritative,
        answers,
        name_servers,
        additionals,
        response_code,
    } = QueryResult::default();
    assert!(!authoritative);
    assert_eq!(response_code, ResponseCode::NoError);
    assert!(answers.is_empty());
    assert!(name_servers.is_empty());
    assert!(additionals.is_empty());
}
}