use async_trait::async_trait;
use domain::{
    base::{iana::Rcode, Dname, Rtype},
    rdata::{Aaaa, Mx, Txt, A},
    resolv::{
        stub::conf::{ResolvConf, ResolvOptions},
        StubResolver,
    },
};
use std::{
    error::Error,
    io::{self, ErrorKind},
    net::{IpAddr, Ipv4Addr, Ipv6Addr},
    str::FromStr,
    sync::Arc,
    time::Duration,
};
use viaspf::lookup::{Lookup, LookupError, LookupResult, Name};
/// The DNS lookup backend, selectable between two implementations.
pub enum Resolver {
/// Performs real DNS queries through a [`DomainResolver`].
Live(DomainResolver),
/// Delegates to an arbitrary [`Lookup`] implementation held behind a shared
/// `Arc`; presumably used to inject canned answers in tests — confirm with
/// callers.
Mock(Arc<Box<dyn Lookup>>),
}
/// A [`Lookup`] implementation backed by the `domain` crate's stub resolver.
pub struct DomainResolver {
/// The underlying stub resolver used for every query.
resolver: StubResolver,
}
impl DomainResolver {
    /// Builds a resolver whose per-query timeout is `timeout`.
    ///
    /// The configuration starts from `ResolvConf::default()`. Its options are
    /// replaced wholesale by `ResolvOptions::default()` with only `timeout`
    /// overridden. The configuration is then finalized before the
    /// `StubResolver` is created.
    pub fn new(timeout: Duration) -> Self {
        let mut config = ResolvConf {
            options: ResolvOptions {
                timeout,
                ..Default::default()
            },
            ..Default::default()
        };
        config.finalize();

        Self {
            resolver: StubResolver::from_conf(config),
        }
    }
}
#[async_trait]
impl Lookup for DomainResolver {
async fn lookup_a(&self, name: &Name) -> LookupResult<Vec<Ipv4Addr>> {
let name = to_dname(name)?;
self.resolver
.query((name, Rtype::A))
.await
.map_err(to_lookup_error)?
.answer()
.map_err(wrap_error)?
.limit_to::<A>()
.map(|record| record.map(|r| r.data().addr()))
.collect::<Result<Vec<_>, _>>()
.map_err(wrap_error)
}
async fn lookup_aaaa(&self, name: &Name) -> LookupResult<Vec<Ipv6Addr>> {
let name = to_dname(name)?;
self.resolver
.query((name, Rtype::Aaaa))
.await
.map_err(to_lookup_error)?
.answer()
.map_err(wrap_error)?
.limit_to::<Aaaa>()
.map(|record| record.map(|r| r.data().addr()))
.collect::<Result<Vec<_>, _>>()
.map_err(wrap_error)
}
async fn lookup_mx(&self, name: &Name) -> LookupResult<Vec<Name>> {
let name = to_dname(name)?;
let answer = self
.resolver
.query((name, Rtype::Mx))
.await
.map_err(to_lookup_error)?;
let mut mxs = answer
.answer()
.map_err(wrap_error)?
.limit_to::<Mx<_>>()
.map(|record| record.map(|r| r.into_data()))
.collect::<Result<Vec<_>, _>>()
.map_err(wrap_error)?;
mxs.sort_by_key(|mx| mx.preference());
mxs.into_iter()
.map(|mx| Name::new(&mx.exchange().to_string()))
.collect::<Result<Vec<_>, _>>()
.map_err(wrap_error)
}
async fn lookup_txt(&self, name: &Name) -> LookupResult<Vec<String>> {
let name = to_dname(name)?;
self.resolver
.query((name, Rtype::Txt))
.await
.map_err(to_lookup_error)?
.answer()
.map_err(wrap_error)?
.limit_to::<Txt<_>>()
.map(|record| {
record.map(|r| {
let text = r.into_data().text::<Vec<_>>();
String::from_utf8_lossy(&text).into_owned()
})
})
.collect::<Result<Vec<_>, _>>()
.map_err(wrap_error)
}
async fn lookup_ptr(&self, ip: IpAddr) -> LookupResult<Vec<Name>> {
self.resolver
.lookup_addr(ip)
.await
.map_err(to_lookup_error)?
.into_iter()
.map(|ptr| Name::new(&ptr.to_string()))
.collect::<Result<Vec<_>, _>>()
.map_err(wrap_error)
}
}
fn to_dname(name: &Name) -> LookupResult<Dname<Vec<u8>>> {
Dname::from_str(name.as_str()).map_err(wrap_error)
}
fn to_lookup_error(error: io::Error) -> LookupError {
match error.kind() {
ErrorKind::NotFound => LookupError::NoRecords,
ErrorKind::TimedOut => LookupError::Timeout,
_ => wrap_error(error),
}
}
fn wrap_error(error: impl Error + Send + Sync + 'static) -> LookupError {
LookupError::Dns(Some(error.into()))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Live smoke test. It resolves the A and AAAA records of a known domain,
    /// then runs a reverse (PTR) lookup on the first address of each family.
    #[tokio::test]
    #[ignore = "depends on live DNS records"]
    async fn domain_resolver_lookup_ok() {
        let resolver = DomainResolver::new(Duration::from_secs(30));
        let domain = Name::new("gluet.ch").unwrap();

        let v4_result = resolver.lookup_a(&domain).await;
        assert!(v4_result.is_ok());
        let first_v4: IpAddr = v4_result.unwrap().first().copied().unwrap().into();
        assert!(resolver.lookup_ptr(first_v4).await.is_ok());

        let v6_result = resolver.lookup_aaaa(&domain).await;
        assert!(v6_result.is_ok());
        let first_v6: IpAddr = v6_result.unwrap().first().copied().unwrap().into();
        assert!(resolver.lookup_ptr(first_v6).await.is_ok());
    }
}