use std::cell::Cell;
use std::io;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs};
use std::time::{Duration, Instant};
use std::vec::IntoIter;
use log::info;
use crate::address::address_name;
use crate::config::DnsConfig;
use crate::message::{Message, Qr, Question, MESSAGE_LIMIT};
use crate::record::{Class, Ptr, Record, RecordType, A, AAAA};
use crate::socket::{DnsSocket, Error};
/// DNS resolver that sends queries over a single [`DnsSocket`] to the name
/// servers listed in its [`DnsConfig`].
pub struct DnsResolver {
/// Socket used to send queries and receive replies.
sock: DnsSocket,
/// Resolver settings (name servers, attempts, timeout, search list, ...).
config: DnsConfig,
/// Index into `config.name_servers` of the server handed out by the next
/// call to `next_nameserver`; kept in a `Cell` so it can advance through `&self`.
next_ns: Cell<usize>,
}
impl DnsResolver {
/// Creates a resolver using `config`, bound to an unspecified local address
/// on port 0.
///
/// The local address family is chosen by `bind_addr` from the configured
/// name servers.
///
/// # Errors
///
/// Returns any error reported while binding the socket.
pub fn new(config: DnsConfig) -> io::Result<DnsResolver> {
    let local_ip = bind_addr(&config.name_servers);
    let socket = DnsSocket::bind((local_ip, 0))?;

    DnsResolver::with_sock(socket, config)
}
/// Creates a resolver using `config`, with its socket bound to the given
/// local address.
///
/// # Errors
///
/// Returns any error reported while binding the socket.
pub fn bind<A: ToSocketAddrs>(addr: A, config: DnsConfig) -> io::Result<DnsResolver> {
    DnsResolver::with_sock(DnsSocket::bind(addr)?, config)
}
/// Assembles a resolver from an already bound socket and a configuration.
///
/// Name server rotation starts at index 0.
fn with_sock(sock: DnsSocket, config: DnsConfig) -> io::Result<DnsResolver> {
    let resolver = DnsResolver {
        sock,
        config,
        next_ns: Cell::new(0),
    };
    Ok(resolver)
}
pub fn resolve_addr(&self, addr: &IpAddr) -> io::Result<String> {
convert_error("failed to resolve address", || {
let mut out_msg = self.basic_message();
out_msg.question.push(Question::new(
address_name(addr),
RecordType::Ptr,
Class::Internet,
));
let mut buf = [0; MESSAGE_LIMIT];
let msg = self.send_message(&out_msg, &mut buf)?;
for rr in msg.answer.into_iter() {
if rr.r_type == RecordType::Ptr {
let ptr = rr.read_rdata::<Ptr>()?;
let mut name = ptr.name;
if name.ends_with('.') {
name.pop();
}
return Ok(name);
}
}
Err(Error::IoError(io::Error::new(
io::ErrorKind::Other,
"failed to resolve address: name not found",
)))
})
}
pub fn resolve_host(&self, host: &str) -> io::Result<ResolveHost> {
convert_error("failed to resolve host", || {
query_names(host, &self.config, |name| {
let mut err;
let mut res = Vec::new();
info!("attempting lookup of name \"{}\"", name);
if self.config.use_inet6 {
err = self
.resolve_host_v6(&name, |ip| res.push(IpAddr::V6(ip)))
.err();
if res.is_empty() {
err = err.or_else(|| {
self.resolve_host_v4(&name, |ip| {
res.push(IpAddr::V6(ip.to_ipv6_mapped()))
})
.err()
});
}
} else {
err = self
.resolve_host_v4(&name, |ip| res.push(IpAddr::V4(ip)))
.err();
err = err.or_else(|| {
self.resolve_host_v6(&name, |ip| res.push(IpAddr::V6(ip)))
.err()
});
}
if !res.is_empty() {
return Ok(ResolveHost(res.into_iter()));
}
if let Some(e) = err {
Err(e)
} else {
Err(Error::IoError(io::Error::new(
io::ErrorKind::Other,
"failed to resolve host: name not found",
)))
}
})
})
}
pub fn resolve_record<Rec: Record>(&self, name: &str) -> io::Result<Vec<Rec>> {
convert_error("failed to resolve record", || {
let r_ty = Rec::record_type();
let mut msg = self.basic_message();
msg.question
.push(Question::new(name.to_owned(), r_ty, Class::Internet));
let mut buf = [0; MESSAGE_LIMIT];
let reply = self.send_message(&msg, &mut buf)?;
let mut rec = Vec::new();
for rr in reply.answer.into_iter() {
if rr.r_type == r_ty {
rec.push(rr.read_rdata::<Rec>()?);
}
}
Ok(rec)
})
}
/// Sends an `A` query for `host` and passes every IPv4 address found in the
/// answer section to `f`, in answer order.
///
/// Records of other types are ignored. Returns an error if the query fails
/// or an `A` record cannot be decoded; addresses decoded before the failure
/// have already been passed to `f`.
fn resolve_host_v4<F>(&self, host: &str, mut f: F) -> Result<(), Error>
where
    F: FnMut(Ipv4Addr),
{
    let mut query = self.basic_message();
    let question = Question::new(host.to_owned(), RecordType::A, Class::Internet);
    query.question.push(question);

    let mut buf = [0; MESSAGE_LIMIT];
    let reply = self.send_message(&query, &mut buf)?;

    let a_records = reply
        .answer
        .into_iter()
        .filter(|rr| rr.r_type == RecordType::A);
    for rr in a_records {
        f(rr.read_rdata::<A>()?.address);
    }

    Ok(())
}
/// Sends an `AAAA` query for `host` and passes every IPv6 address found in
/// the answer section to `f`, in answer order.
///
/// Records of other types are ignored. Returns an error if the query fails
/// or an `AAAA` record cannot be decoded; addresses decoded before the
/// failure have already been passed to `f`.
fn resolve_host_v6<F>(&self, host: &str, mut f: F) -> Result<(), Error>
where
    F: FnMut(Ipv6Addr),
{
    let mut query = self.basic_message();
    let question = Question::new(host.to_owned(), RecordType::AAAA, Class::Internet);
    query.question.push(question);

    let mut buf = [0; MESSAGE_LIMIT];
    let reply = self.send_message(&query, &mut buf)?;

    let aaaa_records = reply
        .answer
        .into_iter()
        .filter(|rr| rr.r_type == RecordType::AAAA);
    for rr in aaaa_records {
        f(rr.read_rdata::<AAAA>()?.address);
    }

    Ok(())
}
/// Builds an empty outgoing message with the "recursion desired" header
/// flag set; callers add their questions to it.
fn basic_message(&self) -> Message {
    let mut query = Message::new();
    query.header.recursion_desired = true;
    query
}
#[allow(dropping_references)] pub fn send_message<'buf>(
&self,
out_msg: &Message,
buf: &'buf mut [u8],
) -> Result<Message<'buf>, Error> {
let mut last_err = None;
let buf_ptr = buf as *mut _;
drop(buf);
'retry: for retries in 0..self.config.attempts {
let ns_addr = if self.config.rotate {
self.next_nameserver()
} else {
let n = self.config.name_servers.len();
self.config.name_servers[retries as usize % n]
};
let mut timeout = self.config.timeout;
info!("resolver sending message to {}", ns_addr);
self.sock.send_message(out_msg, ns_addr)?;
loop {
self.sock.get().set_read_timeout(Some(timeout))?;
let buf = unsafe { &mut *buf_ptr };
let start = Instant::now();
match self.sock.recv_message(&ns_addr, buf) {
Ok(None) => {
let passed = start.elapsed();
if timeout < passed {
timeout = Duration::from_secs(0);
} else {
timeout -= passed;
}
}
Ok(Some(msg)) => {
if msg.header.id == out_msg.header.id && msg.header.qr == Qr::Response {
msg.get_error()?;
return Ok(msg);
}
}
Err(e) => {
if e.is_timeout() {
last_err = Some(e);
continue 'retry;
}
return Err(e);
}
}
}
}
Err(last_err.unwrap())
}
/// Returns the name server at the current rotation index and advances the
/// index, wrapping around at the end of the list.
///
/// # Panics
///
/// Panics if `config.name_servers` is empty (remainder by zero).
fn next_nameserver(&self) -> SocketAddr {
    let servers = &self.config.name_servers;
    let current = self.next_ns.get();
    self.next_ns.set((current + 1) % servers.len());
    servers[current]
}
}
/// Picks the unspecified local address to bind to, matching the address
/// family of the first name server.
///
/// Returns `::` when the first name server is IPv6, and `0.0.0.0` otherwise
/// (including when the list is empty).
fn bind_addr(name_servers: &[SocketAddr]) -> IpAddr {
    let first_is_v6 = match name_servers.first() {
        Some(first) => first.is_ipv6(),
        None => false,
    };

    if first_is_v6 {
        IpAddr::V6(Ipv6Addr::UNSPECIFIED)
    } else {
        IpAddr::V4(Ipv4Addr::UNSPECIFIED)
    }
}
fn convert_error<T, F>(desc: &str, f: F) -> io::Result<T>
where
F: FnOnce() -> Result<T, Error>,
{
match f() {
Ok(t) => Ok(t),
Err(Error::IoError(e)) => Err(e),
Err(e) => Err(io::Error::new(
io::ErrorKind::Other,
format!("{}: {}", desc, e),
)),
}
}
fn query_names<F, T>(name: &str, config: &DnsConfig, mut f: F) -> Result<T, Error>
where
F: FnMut(String) -> Result<T, Error>,
{
let use_search =
!name.ends_with('.') && name.chars().filter(|&c| c == '.').count() as u32 >= config.n_dots;
if use_search {
let mut err = None;
for name in with_suffixes(name, &config.search) {
match f(name) {
Ok(t) => return Ok(t),
Err(e) => err = Some(e),
}
}
if let Some(e) = err {
Err(e)
} else {
Err(Error::IoError(io::Error::new(
io::ErrorKind::Other,
"failed to resolve host: name not found",
)))
}
} else {
f(name.to_owned())
}
}
/// Builds the list of names to try for `host`: `host` joined to each suffix
/// with a `.`, in suffix order, followed by `host` itself.
fn with_suffixes(host: &str, suffixes: &[String]) -> Vec<String> {
    let mut names = Vec::with_capacity(suffixes.len() + 1);
    for suffix in suffixes {
        names.push(format!("{}.{}", host, suffix));
    }
    names.push(host.to_owned());
    names
}
/// Resolves `addr` to a host name using a resolver built from the default
/// configuration.
///
/// # Errors
///
/// Fails if the default configuration cannot be loaded, the resolver cannot
/// be created, or the lookup itself fails.
pub fn resolve_addr(addr: &IpAddr) -> io::Result<String> {
    let config = DnsConfig::load_default()?;
    let resolver = DnsResolver::new(config)?;
    resolver.resolve_addr(addr)
}
/// Resolves `host` to its IP addresses using a resolver built from the
/// default configuration.
///
/// # Errors
///
/// Fails if the default configuration cannot be loaded, the resolver cannot
/// be created, or the lookup itself fails.
pub fn resolve_host(host: &str) -> io::Result<ResolveHost> {
    let config = DnsConfig::load_default()?;
    let resolver = DnsResolver::new(config)?;
    resolver.resolve_host(host)
}
/// Iterator over the IP addresses produced by a host lookup.
///
/// Addresses are yielded in the order in which the lookup collected them.
pub struct ResolveHost(IntoIter<IpAddr>);

impl Iterator for ResolveHost {
    type Item = IpAddr;

    fn next(&mut self) -> Option<IpAddr> {
        self.0.next()
    }

    // Forward the inner iterator's exact remaining length; the default
    // implementation would report `(0, None)` and defeat preallocation in
    // callers such as `collect`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.0.size_hint()
    }
}