use self::conf::{
ResolvConf, ResolvOptions, SearchSuffix, ServerConf, Transport,
};
use crate::base::iana::Rcode;
use crate::base::message::Message;
use crate::base::message_builder::{AdditionalBuilder, MessageBuilder};
use crate::base::name::{ToName, ToRelativeName};
use crate::base::question::Question;
use crate::net::client::dgram_stream;
use crate::net::client::multi_stream;
use crate::net::client::protocol::{TcpConnect, UdpConnect};
use crate::net::client::redundant;
use crate::net::client::request::{
ComposeRequest, Error, RequestMessage, SendRequest,
};
use crate::resolv::lookup::addr::{lookup_addr, FoundAddrs};
use crate::resolv::lookup::host::{lookup_host, search_host, FoundHosts};
use crate::resolv::lookup::srv::{lookup_srv, FoundSrvs, SrvError};
use crate::resolv::resolver::{Resolver, SearchNames};
use bytes::Bytes;
use futures_util::stream::{FuturesUnordered, StreamExt};
use octseq::array::Array;
use std::boxed::Box;
use std::fmt::Debug;
use std::future::Future;
use std::net::IpAddr;
use std::pin::Pin;
use std::string::ToString;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::vec::Vec;
use std::{io, ops};
#[cfg(feature = "resolv-sync")]
use tokio::runtime;
use tokio::sync::Mutex;
use tokio::time::timeout;
pub mod conf;
/// A DNS stub resolver that sends queries to the configured upstream servers.
///
/// The network transport is not created when the resolver is constructed;
/// it is set up lazily on first use and then shared by all later queries.
#[derive(Debug)]
pub struct StubResolver {
    /// Shared transport combining the per-server connections (plus any added
    /// via `add_connection`); `None` until it is first needed.
    transport: Mutex<Option<redundant::Connection<RequestMessage<Vec<u8>>>>>,
    /// Resolver options (e.g. the query timeout, the search list, `use_vc`).
    options: ResolvOptions,
    /// Upstream servers used when the transport is set up.
    servers: Vec<ServerConf>,
}
impl StubResolver {
pub fn new() -> Self {
Self::from_conf(ResolvConf::default())
}
pub fn from_conf(conf: ResolvConf) -> Self {
StubResolver {
transport: None.into(),
options: conf.options,
servers: conf.servers,
}
}
pub fn options(&self) -> &ResolvOptions {
&self.options
}
pub async fn add_connection(
&self,
connection: Box<
dyn SendRequest<RequestMessage<Vec<u8>>> + Send + Sync,
>,
) {
self.get_transport()
.await
.expect("The 'redundant::Connection' task should not fail")
.add(connection)
.await
.expect("The 'redundant::Connection' task should not fail");
}
pub async fn query<N: ToName, Q: Into<Question<N>>>(
&self,
question: Q,
) -> Result<Answer, io::Error> {
Query::new(self)?
.run(Query::create_message(question.into()))
.await
}
async fn query_message(
&self,
message: QueryMessage,
) -> Result<Answer, io::Error> {
Query::new(self)?.run(message).await
}
async fn setup_transport<
CR: Clone + Debug + ComposeRequest + Send + Sync + 'static,
>(
&self,
) -> Result<redundant::Connection<CR>, Error> {
let (redun, transp) = redundant::Connection::new();
let redun_run_fut = transp.run();
tokio::spawn(async move {
redun_run_fut.await;
});
let fut_list_tcp = FuturesUnordered::new();
let fut_list_udp_tcp = FuturesUnordered::new();
for s in &self.servers {
if self.options.use_vc || matches!(s.transport, Transport::Tcp) {
let (conn, tran) =
multi_stream::Connection::new(TcpConnect::new(s.addr));
fut_list_tcp.push(tran.run());
redun.add(Box::new(conn)).await?;
} else {
let udp_connect = UdpConnect::new(s.addr);
let tcp_connect = TcpConnect::new(s.addr);
let (conn, tran) =
dgram_stream::Connection::new(udp_connect, tcp_connect);
fut_list_udp_tcp.push(tran.run());
redun.add(Box::new(conn)).await?;
}
}
tokio::spawn(async move {
run(fut_list_tcp, fut_list_udp_tcp).await;
});
Ok(redun)
}
async fn get_transport(
&self,
) -> Result<redundant::Connection<RequestMessage<Vec<u8>>>, Error> {
let mut opt_transport = self.transport.lock().await;
match &*opt_transport {
Some(transport) => Ok(transport.clone()),
None => {
let transport = self.setup_transport().await?;
*opt_transport = Some(transport.clone());
Ok(transport)
}
}
}
}
/// Drives two sets of connection transport futures until all have finished.
///
/// Both sets are polled concurrently; the returned future completes once
/// every future in `fut_list_tcp` and in `fut_list_udp_tcp` has completed.
/// The futures' outputs are discarded.
async fn run<TcpFut: Future, UdpTcpFut: Future>(
    fut_list_tcp: FuturesUnordered<TcpFut>,
    fut_list_udp_tcp: FuturesUnordered<UdpTcpFut>,
) {
    // A `for_each` drain over a `FuturesUnordered` resolves once the set is
    // exhausted (immediately if it is empty), so joining both drains waits
    // for everything.
    tokio::join!(
        fut_list_tcp.for_each(|_| async {}),
        fut_list_udp_tcp.for_each(|_| async {})
    );
}
impl StubResolver {
pub async fn lookup_addr(
&self,
addr: IpAddr,
) -> Result<FoundAddrs<Self>, io::Error> {
lookup_addr(self, addr).await
}
pub async fn lookup_host(
&self,
qname: impl ToName,
) -> Result<FoundHosts<Self>, io::Error> {
lookup_host(self, qname).await
}
pub async fn search_host(
&self,
qname: impl ToRelativeName,
) -> Result<FoundHosts<Self>, io::Error> {
search_host(self, qname).await
}
pub async fn lookup_srv(
&self,
service: impl ToRelativeName,
name: impl ToName,
fallback_port: u16,
) -> Result<Option<FoundSrvs>, SrvError> {
lookup_srv(self, service, name, fallback_port).await
}
}
#[cfg(feature = "resolv-sync")]
#[cfg_attr(docsrs, doc(cfg(feature = "resolv-sync")))]
impl StubResolver {
    /// Runs `op` with a default-configured resolver, blocking the current
    /// thread until the future it returns completes.
    ///
    /// # Errors
    ///
    /// If the runtime cannot be built, the `io::Error` is converted into `E`
    /// and returned; otherwise the future's own result is returned.
    pub fn run<R, T, E, F>(op: F) -> R::Output
    where
        R: Future<Output = Result<T, E>> + Send + 'static,
        E: From<io::Error>,
        F: FnOnce(StubResolver) -> R + Send + 'static,
    {
        Self::run_with_conf(Default::default(), op)
    }

    /// Like `run`, but the resolver is created from `conf`.
    ///
    /// A new current-thread tokio runtime with all drivers enabled is built
    /// for this call.
    ///
    /// # Errors
    ///
    /// If the runtime cannot be built, the `io::Error` is converted into `E`
    /// and returned; otherwise the future's own result is returned.
    pub fn run_with_conf<R, T, E, F>(conf: ResolvConf, op: F) -> R::Output
    where
        R: Future<Output = Result<T, E>> + Send + 'static,
        E: From<io::Error>,
        F: FnOnce(StubResolver) -> R + Send + 'static,
    {
        let rt = runtime::Builder::new_current_thread().enable_all().build()?;
        rt.block_on(op(Self::from_conf(conf)))
    }
}
impl Default for StubResolver {
fn default() -> Self {
Self::new()
}
}
impl Resolver for StubResolver {
    type Octets = Bytes;
    type Answer = Answer;
    type Query<'a> =
        Pin<Box<dyn Future<Output = Result<Answer, io::Error>> + Send + 'a>>;

    /// Builds a recursive query for `question` and returns the boxed future
    /// that runs it.
    fn query<'a, N, Q>(&'a self, question: Q) -> Self::Query<'a>
    where
        N: ToName,
        Q: Into<Question<N>>,
    {
        // The message is built eagerly so the returned future does not need
        // to hold on to `question` (and thus does not depend on `N: Send`).
        Box::pin(self.query_message(Query::create_message(question.into())))
    }
}
impl SearchNames for StubResolver {
    type Name = SearchSuffix;
    type Iter<'a> = SearchIter<'a>;

    /// Returns an iterator over the search suffixes in this resolver's
    /// options, starting from the first one.
    fn search_iter<'a>(&'a self) -> Self::Iter<'a> {
        SearchIter {
            pos: 0,
            resolver: self,
        }
    }
}
/// State for a single query run against a [`StubResolver`].
pub struct Query<'a> {
    /// Resolver providing the shared transport and the options (timeout).
    resolver: &'a StubResolver,
    /// EDNS flag: starts `true` and is cleared after a FORMERR answer.
    /// NOTE(review): the request-building code in this module
    /// (`create_message`, `run_query`) never reads this flag.
    edns: Arc<AtomicBool>,
    /// Result reported when no direct answer is returned. It starts as a
    /// `TimedOut` error ("all timed out"). A SERVFAIL answer overwrites it.
    /// A non-timeout error replaces it only while it still holds an error.
    error: Result<Answer, io::Error>,
}
impl<'a> Query<'a> {
    /// Creates the state for a new query against `resolver`.
    ///
    /// EDNS starts out enabled and the fallback result is an
    /// `io::ErrorKind::TimedOut` error. This never returns `Err` at present.
    pub fn new(resolver: &'a StubResolver) -> Result<Self, io::Error> {
        let edns = Arc::new(AtomicBool::new(true));
        let error = Err(io::Error::new(
            io::ErrorKind::TimedOut,
            "all timed out",
        ));
        Ok(Self {
            resolver,
            edns,
            error,
        })
    }

    /// Sends `message` and returns the outcome.
    ///
    /// The outcome depends on the response:
    /// * Any rcode other than FORMERR or SERVFAIL: the answer is returned
    ///   as `Ok`.
    /// * FORMERR while EDNS is still enabled: EDNS is disabled and the query
    ///   is sent once more. A FORMERR after that is returned as `Ok`.
    /// * SERVFAIL: the answer is recorded as the fallback result and
    ///   returned as `Ok`.
    /// * Transport error: the error is returned. If it was a timeout, the
    ///   initial "all timed out" error is returned instead.
    ///
    /// NOTE(review): `run_query` sends the same bytes whatever the EDNS flag
    /// says, so the FORMERR retry currently repeats an identical request.
    pub async fn run(
        mut self,
        mut message: QueryMessage,
    ) -> Result<Answer, io::Error> {
        loop {
            match self.run_query(&mut message).await {
                Err(err) => self.update_error(err),
                Ok(answer) => {
                    let rcode = answer.header().rcode();
                    if rcode == Rcode::FORMERR && self.does_edns() {
                        self.disable_edns();
                        continue;
                    }
                    if rcode != Rcode::SERVFAIL {
                        return Ok(answer);
                    }
                    self.update_error_servfail(answer);
                }
            }
            break self.error;
        }
    }

    /// Builds a query message for `question`.
    ///
    /// The message has the RD (recursion desired) flag set and `question` as
    /// its only question. The returned builder is positioned at the
    /// additional section.
    fn create_message(question: Question<impl ToName>) -> QueryMessage {
        let mut builder = MessageBuilder::from_target(Default::default())
            .expect("MessageBuilder should not fail");
        builder.header_mut().set_rd(true);
        let mut questions = builder.question();
        questions.push(question).expect("push should not fail");
        questions.additional()
    }

    /// Sends the current contents of `message` once over the resolver's
    /// shared transport and waits up to the configured timeout for a reply.
    ///
    /// Errors from building the request, obtaining the transport, or
    /// receiving the response are turned into `io::Error::other`. An expired
    /// timeout is propagated via tokio's `Elapsed` to `io::Error` conversion.
    async fn run_query(
        &mut self,
        message: &mut QueryMessage,
    ) -> Result<Answer, io::Error> {
        let wire = message.as_target().to_vec();
        let msg = Message::from_octets(wire)
            .expect("Message::from_octets should not fail");
        let request = RequestMessage::new(msg)
            .map_err(|err| io::Error::other(err.to_string()))?;

        let transport = self
            .resolver
            .get_transport()
            .await
            .map_err(|err| io::Error::other(err.to_string()))?;
        let mut pending = transport.send_request(request);

        let limit = self.resolver.options().timeout;
        let outcome = timeout(limit, pending.get_response()).await?;
        let reply =
            outcome.map_err(|err| io::Error::other(err.to_string()))?;
        Ok(Answer { message: reply })
    }

    /// Records a transport error as the fallback result.
    ///
    /// Timeouts are ignored, so the initial "all timed out" error stands. A
    /// recorded SERVFAIL answer is never replaced by an error.
    fn update_error(&mut self, err: io::Error) {
        let keep_current =
            err.kind() == io::ErrorKind::TimedOut || self.error.is_ok();
        if !keep_current {
            self.error = Err(err);
        }
    }

    /// Records a SERVFAIL answer as the fallback result, replacing whatever
    /// was recorded before.
    fn update_error_servfail(&mut self, answer: Answer) {
        self.error = Ok(answer);
    }

    /// Returns whether EDNS is still considered enabled for this query.
    pub fn does_edns(&self) -> bool {
        self.edns.load(Ordering::Relaxed)
    }

    /// Marks EDNS as disabled for this query.
    pub fn disable_edns(&self) {
        self.edns.store(false, Ordering::Relaxed);
    }
}
/// Builder used to assemble outgoing queries.
///
/// It is a message builder positioned at the additional section, writing
/// into an `octseq` `Array<512>` buffer.
pub(super) type QueryMessage = AdditionalBuilder<Array<512>>;
/// A DNS response message returned for a query.
#[derive(Clone)]
pub struct Answer {
    /// The response message as received from the transport.
    message: Message<Bytes>,
}
impl Answer {
pub fn is_final(&self) -> bool {
(self.message.header().rcode() == Rcode::NOERROR
|| self.message.header().rcode() == Rcode::NXDOMAIN)
&& !self.message.header().tc()
}
pub fn is_truncated(&self) -> bool {
self.message.header().tc()
}
pub fn into_message(self) -> Message<Bytes> {
self.message
}
}
impl From<Message<Bytes>> for Answer {
fn from(message: Message<Bytes>) -> Self {
Answer { message }
}
}
impl ops::Deref for Answer {
type Target = Message<Bytes>;
fn deref(&self) -> &Self::Target {
&self.message
}
}
impl AsRef<Message<Bytes>> for Answer {
fn as_ref(&self) -> &Message<Bytes> {
&self.message
}
}
/// Iterator over the search suffixes configured in a resolver's options.
#[derive(Clone, Debug)]
pub struct SearchIter<'a> {
    /// Resolver whose `options().search` list is iterated.
    resolver: &'a StubResolver,
    /// Index of the next suffix to yield.
    pos: usize,
}
impl Iterator for SearchIter<'_> {
    type Item = SearchSuffix;

    /// Yields a clone of the next search suffix, or `None` once the
    /// resolver's search list is exhausted. The position is not advanced
    /// past the end.
    fn next(&mut self) -> Option<Self::Item> {
        let suffix = self.resolver.options().search.get(self.pos)?.clone();
        self.pos += 1;
        Some(suffix)
    }
}