use self::conf::{
ResolvConf, ResolvOptions, SearchSuffix, ServerConf, Transport,
};
use crate::base::iana::Rcode;
use crate::base::message::Message;
use crate::base::message_builder::{
AdditionalBuilder, MessageBuilder, StreamTarget,
};
use crate::base::name::{ToDname, ToRelativeDname};
use crate::base::octets::Octets512;
use crate::base::question::Question;
use crate::resolv::lookup::addr::{lookup_addr, FoundAddrs};
use crate::resolv::lookup::host::{lookup_host, search_host, FoundHosts};
use crate::resolv::lookup::srv::{lookup_srv, FoundSrvs, SrvError};
use crate::resolv::resolver::{Resolver, SearchNames};
use bytes::Bytes;
use std::boxed::Box;
use std::future::Future;
use std::net::{IpAddr, SocketAddr};
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::vec::Vec;
use std::{io, ops};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpStream, UdpSocket};
#[cfg(feature = "resolv-sync")]
use tokio::runtime;
use tokio::time::timeout;
pub mod conf;
const RETRY_RANDOM_PORT: usize = 10;
#[derive(Clone, Debug)]
pub struct StubResolver {
preferred: ServerList,
stream: ServerList,
options: ResolvOptions,
}
impl StubResolver {
pub fn new() -> Self {
Self::from_conf(ResolvConf::default())
}
pub fn from_conf(conf: ResolvConf) -> Self {
StubResolver {
preferred: ServerList::from_conf(&conf, |s| {
s.transport.is_preferred()
}),
stream: ServerList::from_conf(&conf, |s| s.transport.is_stream()),
options: conf.options,
}
}
pub fn options(&self) -> &ResolvOptions {
&self.options
}
pub async fn query<N: ToDname, Q: Into<Question<N>>>(
&self,
question: Q,
) -> Result<Answer, io::Error> {
Query::new(self)?
.run(Query::create_message(question.into()))
.await
}
async fn query_message(
&self,
message: QueryMessage,
) -> Result<Answer, io::Error> {
Query::new(self)?.run(message).await
}
}
impl StubResolver {
pub async fn lookup_addr(
&self,
addr: IpAddr,
) -> Result<FoundAddrs<&Self>, io::Error> {
lookup_addr(&self, addr).await
}
pub async fn lookup_host(
&self,
qname: impl ToDname,
) -> Result<FoundHosts<&Self>, io::Error> {
lookup_host(&self, qname).await
}
pub async fn search_host(
&self,
qname: impl ToRelativeDname,
) -> Result<FoundHosts<&Self>, io::Error> {
search_host(&self, qname).await
}
pub async fn lookup_srv(
&self,
service: impl ToRelativeDname,
name: impl ToDname,
fallback_port: u16,
) -> Result<Option<FoundSrvs>, SrvError> {
lookup_srv(&self, service, name, fallback_port).await
}
}
#[cfg(feature = "resolv-sync")]
#[cfg_attr(docsrs, doc(cfg(feature = "resolv-sync")))]
impl StubResolver {
pub fn run<R, F>(op: F) -> R::Output
where
R: Future + Send + 'static,
R::Output: Send + 'static,
F: FnOnce(StubResolver) -> R + Send + 'static,
{
Self::run_with_conf(ResolvConf::default(), op)
}
pub fn run_with_conf<R, F>(conf: ResolvConf, op: F) -> R::Output
where
R: Future + Send + 'static,
R::Output: Send + 'static,
F: FnOnce(StubResolver) -> R + Send + 'static,
{
let resolver = Self::from_conf(conf);
let runtime = runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
runtime.block_on(op(resolver))
}
}
impl Default for StubResolver {
fn default() -> Self {
Self::new()
}
}
impl<'a> Resolver for &'a StubResolver {
type Octets = Bytes;
type Answer = Answer;
type Query =
Pin<Box<dyn Future<Output = Result<Answer, io::Error>> + Send + 'a>>;
fn query<N, Q>(&self, question: Q) -> Self::Query
where
N: ToDname,
Q: Into<Question<N>>,
{
let message = Query::create_message(question.into());
Box::pin(self.query_message(message))
}
}
impl<'a> SearchNames for &'a StubResolver {
type Name = SearchSuffix;
type Iter = SearchIter<'a>;
fn search_iter(&self) -> Self::Iter {
SearchIter {
resolver: self,
pos: 0,
}
}
}
pub struct Query<'a> {
resolver: &'a StubResolver,
preferred: bool,
attempt: usize,
counter: ServerListCounter,
error: Result<Answer, io::Error>,
}
impl<'a> Query<'a> {
pub fn new(resolver: &'a StubResolver) -> Result<Self, io::Error> {
let (preferred, counter) =
if resolver.options().use_vc || resolver.preferred.is_empty() {
if resolver.stream.is_empty() {
return Err(io::Error::new(
io::ErrorKind::NotFound,
"no servers available",
));
}
(false, resolver.stream.counter(resolver.options().rotate))
} else {
(true, resolver.preferred.counter(resolver.options().rotate))
};
Ok(Query {
resolver,
preferred,
attempt: 0,
counter,
error: Err(io::Error::new(
io::ErrorKind::TimedOut,
"all timed out",
)),
})
}
pub async fn run(
mut self,
mut message: QueryMessage,
) -> Result<Answer, io::Error> {
loop {
match self.run_query(&mut message).await {
Ok(answer) => {
if answer.header().rcode() == Rcode::FormErr
&& self.current_server().does_edns()
{
self.current_server().disable_edns();
continue;
} else if answer.header().rcode() == Rcode::ServFail {
self.update_error_servfail(answer);
} else if answer.header().tc()
&& self.preferred
&& !self.resolver.options().ign_tc
{
if self.switch_to_stream() {
continue;
} else {
return Ok(answer);
}
} else {
return Ok(answer);
}
}
Err(err) => self.update_error(err),
}
if !self.next_server() {
return self.error;
}
}
}
fn create_message(question: Question<impl ToDname>) -> QueryMessage {
let mut message = MessageBuilder::from_target(
StreamTarget::new(Octets512::new()).unwrap(),
)
.unwrap();
message.header_mut().set_rd(true);
let mut message = message.question();
message.push(question).unwrap();
message.additional()
}
async fn run_query(
&mut self,
message: &mut QueryMessage,
) -> Result<Answer, io::Error> {
let server = self.current_server();
server.prepare_message(message);
server.query(message).await
}
fn current_server(&self) -> &ServerInfo {
let list = if self.preferred {
&self.resolver.preferred
} else {
&self.resolver.stream
};
self.counter.info(list)
}
fn update_error(&mut self, err: io::Error) {
if err.kind() != io::ErrorKind::TimedOut && self.error.is_err() {
self.error = Err(err)
}
}
fn update_error_servfail(&mut self, answer: Answer) {
self.error = Ok(answer)
}
fn switch_to_stream(&mut self) -> bool {
if !self.preferred {
return false;
}
self.preferred = false;
self.attempt = 0;
self.counter =
self.resolver.stream.counter(self.resolver.options().rotate);
true
}
fn next_server(&mut self) -> bool {
if self.counter.next() {
return true;
}
self.attempt += 1;
if self.attempt >= self.resolver.options().attempts {
return false;
}
self.counter = if self.preferred {
self.resolver
.preferred
.counter(self.resolver.options().rotate)
} else {
self.resolver.stream.counter(self.resolver.options().rotate)
};
true
}
}
pub(super) type QueryMessage = AdditionalBuilder<StreamTarget<Octets512>>;
#[derive(Clone)]
pub struct Answer {
message: Message<Bytes>,
}
impl Answer {
pub fn is_final(&self) -> bool {
(self.message.header().rcode() == Rcode::NoError
|| self.message.header().rcode() == Rcode::NXDomain)
&& !self.message.header().tc()
}
pub fn is_truncated(&self) -> bool {
self.message.header().tc()
}
pub fn into_message(self) -> Message<Bytes> {
self.message
}
}
impl From<Message<Bytes>> for Answer {
fn from(message: Message<Bytes>) -> Self {
Answer { message }
}
}
#[derive(Clone, Debug)]
struct ServerInfo {
conf: ServerConf,
edns: Arc<AtomicBool>,
}
impl ServerInfo {
pub fn does_edns(&self) -> bool {
self.edns.load(Ordering::Relaxed)
}
pub fn disable_edns(&self) {
self.edns.store(false, Ordering::Relaxed);
}
pub fn prepare_message(&self, query: &mut QueryMessage) {
query.rewind();
if self.does_edns() {
query
.opt(|opt| {
opt.set_udp_payload_size(self.conf.udp_payload_size);
Ok(())
})
.unwrap();
}
}
pub async fn query(
&self,
query: &QueryMessage,
) -> Result<Answer, io::Error> {
let res = match self.conf.transport {
Transport::Udp => {
timeout(
self.conf.request_timeout,
Self::udp_query(
query,
self.conf.addr,
self.conf.recv_size,
),
)
.await
}
Transport::Tcp => {
timeout(
self.conf.request_timeout,
Self::tcp_query(query, self.conf.addr),
)
.await
}
};
match res {
Ok(Ok(answer)) => Ok(answer),
Ok(Err(err)) => Err(err),
Err(_) => Err(io::Error::new(
io::ErrorKind::TimedOut,
"request timed out",
)),
}
}
pub async fn tcp_query(
query: &QueryMessage,
addr: SocketAddr,
) -> Result<Answer, io::Error> {
let mut sock = TcpStream::connect(&addr).await?;
sock.write_all(query.as_target().as_stream_slice()).await?;
loop {
let mut buf = Vec::new();
let len = sock.read_u16().await? as u64;
AsyncReadExt::take(&mut sock, len)
.read_to_end(&mut buf)
.await?;
if let Ok(answer) = Message::from_octets(buf.into()) {
if answer.is_answer(&query.as_message()) {
return Ok(answer.into());
}
} else {
return Err(io::Error::new(
io::ErrorKind::Other,
"short buf",
));
}
}
}
pub async fn udp_query(
query: &QueryMessage,
addr: SocketAddr,
recv_size: usize,
) -> Result<Answer, io::Error> {
let sock = Self::udp_bind(addr.is_ipv4()).await?;
sock.connect(addr).await?;
let sent = sock.send(query.as_target().as_dgram_slice()).await?;
if sent != query.as_target().as_dgram_slice().len() {
return Err(io::Error::new(
io::ErrorKind::Other,
"short UDP send",
));
}
loop {
let mut buf = vec![0; recv_size]; let len = sock.recv(&mut buf).await?;
buf.truncate(len);
let answer = match Message::from_octets(buf.into()) {
Ok(answer) => answer,
Err(_) => continue,
};
if !answer.is_answer(&query.as_message()) {
continue;
}
return Ok(answer.into());
}
}
async fn udp_bind(v4: bool) -> Result<UdpSocket, io::Error> {
let mut i = 0;
loop {
let local: SocketAddr = if v4 {
([0u8; 4], 0).into()
} else {
([0u16; 8], 0).into()
};
match UdpSocket::bind(&local).await {
Ok(sock) => return Ok(sock),
Err(err) => {
if i == RETRY_RANDOM_PORT {
return Err(err);
} else {
i += 1
}
}
}
}
}
}
impl From<ServerConf> for ServerInfo {
fn from(conf: ServerConf) -> Self {
ServerInfo {
conf,
edns: Arc::new(AtomicBool::new(true)),
}
}
}
impl<'a> From<&'a ServerConf> for ServerInfo {
fn from(conf: &'a ServerConf) -> Self {
conf.clone().into()
}
}
#[derive(Clone, Debug)]
struct ServerList {
servers: Vec<ServerInfo>,
start: Arc<AtomicUsize>,
}
impl ServerList {
pub fn from_conf<F>(conf: &ResolvConf, filter: F) -> Self
where
F: Fn(&ServerConf) -> bool,
{
ServerList {
servers: {
conf.servers
.iter()
.filter(|f| filter(f))
.map(Into::into)
.collect()
},
start: Arc::new(AtomicUsize::new(0)),
}
}
pub fn is_empty(&self) -> bool {
self.servers.is_empty()
}
pub fn counter(&self, rotate: bool) -> ServerListCounter {
let res = ServerListCounter::new(self);
if rotate {
self.rotate()
}
res
}
pub fn iter(&self) -> ServerListIter {
ServerListIter::new(self)
}
pub fn rotate(&self) {
self.start.fetch_add(1, Ordering::SeqCst);
}
}
impl<'a> IntoIterator for &'a ServerList {
type Item = &'a ServerInfo;
type IntoIter = ServerListIter<'a>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl ops::Deref for ServerList {
type Target = [ServerInfo];
fn deref(&self) -> &Self::Target {
self.servers.as_ref()
}
}
#[derive(Clone, Debug)]
struct ServerListCounter {
cur: usize,
end: usize,
}
impl ServerListCounter {
fn new(list: &ServerList) -> Self {
if list.servers.is_empty() {
return ServerListCounter { cur: 0, end: 0 };
}
let start = list.start.load(Ordering::Relaxed) % list.servers.len();
ServerListCounter {
cur: start,
end: start + list.servers.len(),
}
}
#[allow(clippy::should_implement_trait)]
pub fn next(&mut self) -> bool {
let next = self.cur + 1;
if next < self.end {
self.cur = next;
true
} else {
false
}
}
pub fn info<'a>(&self, list: &'a ServerList) -> &'a ServerInfo {
&list[self.cur % list.servers.len()]
}
}
#[derive(Clone, Debug)]
struct ServerListIter<'a> {
servers: &'a ServerList,
counter: ServerListCounter,
}
impl<'a> ServerListIter<'a> {
fn new(list: &'a ServerList) -> Self {
ServerListIter {
servers: list,
counter: ServerListCounter::new(list),
}
}
}
impl<'a> Iterator for ServerListIter<'a> {
type Item = &'a ServerInfo;
fn next(&mut self) -> Option<Self::Item> {
if self.counter.next() {
Some(self.counter.info(self.servers))
} else {
None
}
}
}
impl ops::Deref for Answer {
type Target = Message<Bytes>;
fn deref(&self) -> &Self::Target {
&self.message
}
}
impl AsRef<Message<Bytes>> for Answer {
fn as_ref(&self) -> &Message<Bytes> {
&self.message
}
}
#[derive(Clone, Debug)]
pub struct SearchIter<'a> {
resolver: &'a StubResolver,
pos: usize,
}
impl<'a> Iterator for SearchIter<'a> {
type Item = SearchSuffix;
fn next(&mut self) -> Option<Self::Item> {
if let Some(res) = self.resolver.options().search.get(self.pos) {
self.pos += 1;
Some(res.clone())
} else {
None
}
}
}