use std::{
net::{IpAddr, SocketAddr},
str::FromStr,
time::Duration,
};
use rand::Rng;
use serde::Serialize;
use thiserror::Error;
use trust_dns_resolver::{
config::{LookupIpStrategy, NameServerConfig, Protocol, ResolverConfig, ResolverOpts},
error::{ResolveError, ResolveErrorKind},
lookup_ip::LookupIp,
proto::{op::ResponseCode, rr::RecordType, xfer::DnsRequestOptions},
TokioAsyncResolver,
};
pub struct Resolver {
inner: TokioAsyncResolver,
}
impl Resolver {
fn new(inner: TokioAsyncResolver) -> Self {
Self { inner }
}
pub fn builder() -> ResolverBuilder {
ResolverBuilder::new()
}
#[tracing::instrument(skip(self), level = "debug")]
pub async fn lookup_address(&self, hostname: &str) -> Result<AddressResponse, ResolverError> {
let result = self.inner.lookup_ip(hostname).await;
match result {
Ok(items) => self.process_address_ok(items),
Err(error) => self.process_address_err(error),
}
}
fn process_address_ok(&self, items: LookupIp) -> Result<AddressResponse, ResolverError> {
let mut address_response = AddressResponse::default();
address_response.addresses.extend(items.iter());
for record in items.as_lookup().record_iter() {
address_response.text_records.push(format!("{}", record));
}
tracing::debug!(count = address_response.addresses.len(), "ok");
Ok(address_response)
}
fn process_address_err(&self, error: ResolveError) -> Result<AddressResponse, ResolverError> {
if let ResolveErrorKind::NoRecordsFound {
query: _,
soa: _,
negative_ttl: _,
response_code,
trusted: _,
} = error.kind()
{
tracing::debug!(response_code = response_code.to_str(), "err");
}
Err(error.into())
}
#[tracing::instrument(skip(self), level = "debug")]
pub async fn lookup_record(
&self,
record_type: &str,
hostname: &str,
) -> Result<Vec<String>, ResolverError> {
let record_type = Self::parse_record_type(record_type)?;
let response = self
.inner
.lookup(hostname, record_type, DnsRequestOptions::default())
.await?;
let mut text_records = Vec::new();
for record in response.record_iter() {
text_records.push(record.to_string())
}
Ok(text_records)
}
fn parse_record_type(record_type: &str) -> Result<RecordType, ResolverError> {
if let Ok(value) = record_type.parse::<u16>() {
return Ok(RecordType::from(value));
}
match RecordType::from_str(record_type) {
Ok(value) => Ok(value),
Err(_) => Err(ResolverError::Io(std::io::ErrorKind::InvalidInput.into())),
}
}
pub async fn clear_cache(&mut self) {
self.inner.clear_cache().await;
}
}
/// Builder for a [`Resolver`].
///
/// The only name servers configured by [`ResolverBuilder::build`] are the
/// DNS-over-HTTPS servers added with [`ResolverBuilder::with_doh_server`].
/// NOTE(review): with none added, lookups presumably have no server to query.
#[derive(Debug, Clone, Default)]
pub struct ResolverBuilder {
    /// Local address outgoing connections are bound to, if any.
    bind_address: Option<SocketAddr>,
    /// DoH servers as `(socket address, TLS server name)` pairs.
    doh_servers: Vec<(SocketAddr, String)>,
    /// Whether DNSSEC validation is requested.
    dnssec: bool,
}

impl ResolverBuilder {
    /// Creates a builder with no bind address, no servers and DNSSEC off.
    pub fn new() -> Self {
        Self::default()
    }

    /// Binds outgoing connections to `address`.
    pub fn with_bind_address(mut self, address: SocketAddr) -> Self {
        self.bind_address = Some(address);
        self
    }

    /// Adds a DNS-over-HTTPS server at `address`, whose TLS certificate is
    /// verified against `hostname`.
    pub fn with_doh_server(mut self, address: SocketAddr, hostname: &str) -> Self {
        self.doh_servers.push((address, hostname.to_string()));
        self
    }

    /// Enables or disables DNSSEC validation of responses.
    pub fn with_dnssec(mut self, value: bool) -> Self {
        self.dnssec = value;
        self
    }

    /// Builds the [`Resolver`].
    ///
    /// Uses a 10 s timeout, a single attempt, EDNS0, dual-stack (IPv4 and
    /// IPv6) address lookups, a 128-entry cache, and ignores the hosts file.
    ///
    /// # Panics
    ///
    /// Panics if the underlying [`TokioAsyncResolver`] cannot be constructed.
    pub fn build(&self) -> Resolver {
        let mut opts = ResolverOpts::default();
        opts.timeout = Duration::from_secs(10);
        opts.attempts = 1;
        opts.edns0 = true;
        // Apply the flag set through `with_dnssec`.
        opts.validate = self.dnssec;
        opts.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
        opts.cache_size = 128;
        opts.use_hosts_file = false;
        // Keep intermediate records (e.g. CNAMEs) in lookup results.
        opts.preserve_intermediates = true;

        let mut config = ResolverConfig::new();
        for (address, hostname) in &self.doh_servers {
            config.add_name_server(NameServerConfig {
                socket_addr: *address,
                protocol: Protocol::Https,
                tls_dns_name: Some(hostname.clone()),
                trust_nx_responses: false,
                tls_config: None,
                bind_addr: self.bind_address,
            });
        }

        // `build` exposes an infallible signature, so surface a failure here
        // with a descriptive panic rather than a bare unwrap.
        let inner = TokioAsyncResolver::tokio(config, opts)
            .expect("failed to construct TokioAsyncResolver from builder configuration");
        Resolver::new(inner)
    }
}
/// Result of a successful address lookup.
#[derive(Debug, Clone, Default, Serialize)]
pub struct AddressResponse {
    /// Resolved IP addresses.
    addresses: Vec<IpAddr>,
    /// Textual (`Display`) form of each record in the lookup.
    text_records: Vec<String>,
}

impl AddressResponse {
    /// Returns the resolved IP addresses.
    pub fn addresses(&self) -> &[IpAddr] {
        &self.addresses
    }

    /// Returns the textual form of every record in the lookup.
    pub fn text_records(&self) -> &[String] {
        &self.text_records
    }
}
#[derive(Error, Debug)]
pub enum ResolverError {
#[error("non-existent domain")]
NoName,
#[error("no records for given record type")]
NoRecord,
#[error("negative response: {0}")]
Other(&'static str),
#[error(transparent)]
Io(#[from] std::io::Error),
#[error(transparent)]
OtherInternal(ResolveError),
}
impl From<ResolveError> for ResolverError {
fn from(error: ResolveError) -> Self {
match error.kind() {
ResolveErrorKind::NoRecordsFound {
query: _,
soa: _,
negative_ttl: _,
response_code: ResponseCode::NXDomain,
trusted: _,
} => Self::NoName,
ResolveErrorKind::NoRecordsFound {
query: _,
soa: _,
negative_ttl: _,
response_code: ResponseCode::NoError,
trusted: _,
} => Self::NoRecord,
ResolveErrorKind::NoRecordsFound {
query: _,
soa: _,
negative_ttl: _,
response_code,
trusted: _,
} => Self::Other(response_code.to_str()),
_ => Self::OtherInternal(error),
}
}
}
/// Returns a random `.net` domain whose single label is 20 to 50 ASCII
/// alphanumeric characters long.
pub fn random_domain() -> String {
    let mut rng = rand::thread_rng();
    let label_len: usize = rng.gen_range(20..=50);
    let label: String = rng
        .sample_iter(rand::distributions::Alphanumeric)
        .take(label_len)
        .map(char::from)
        .collect();
    label + ".net"
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_random_domain() {
        // The label length is random, so check several draws against the
        // contract instead of a single sample.
        for _ in 0..32 {
            let result = random_domain();
            let label = result
                .strip_suffix(".net")
                .expect("random_domain must end with \".net\"");
            assert!(
                (20..=50).contains(&label.len()),
                "unexpected label length in {}",
                result
            );
            assert!(
                label.chars().all(|c| c.is_ascii_alphanumeric()),
                "non-alphanumeric label in {}",
                result
            );
        }
    }

    #[test_log::test(tokio::test)]
    #[ignore = "external resources"]
    async fn test_resolver() {
        let resolver = ResolverBuilder::new()
            .with_doh_server("1.1.1.1:443".parse().unwrap(), "cloudflare-dns.com")
            .with_doh_server("8.8.8.8:443".parse().unwrap(), "dns.google")
            .build();

        // `expect` reports the actual error if the lookup fails.
        let lookup = resolver
            .lookup_address("www.icanhascheezburger.com")
            .await
            .expect("lookup of a known host should succeed");
        assert!(!lookup.addresses().is_empty());
        assert!(!lookup.text_records().is_empty());

        match resolver.lookup_address(&random_domain()).await {
            Err(ResolverError::NoName) => {}
            Err(other) => panic!("expected NoName, got error: {:?}", other),
            Ok(response) => panic!(
                "expected NoName, got {} address(es)",
                response.addresses().len()
            ),
        }
    }
}