use std::collections::HashMap;
use std::convert::{TryFrom, TryInto};
use std::net::SocketAddr;
use std::fmt;
use std::ops::Index;
use std::time::Duration;
use std::time::Instant;
use crate::error::*;
use crate::internal_error;
use crate::query_private::QDnsReq;
use super::common::*;
/// Per-query options that tweak how a lookup is carried out.
///
/// The default value has every flag off and no timeout override set.
/// The setters return `&mut Self` so they can be chained.
#[derive(Debug, Clone, Default)]
pub struct QuerySetup {
    /// When `true`, the time taken by each query should be measured.
    pub(crate) measure_time: bool,
    /// When `true`, the hosts lookup is presumably skipped; confirm against the resolver.
    pub(crate) ign_hosts: bool,
    /// Optional timeout override. `None` means "use the resolver's own value".
    /// NOTE(review): units are not visible here (TODO confirm, likely seconds).
    pub(crate) timeout: Option<u32>,
}

impl QuerySetup {
    /// Enables or disables measurement of the query duration.
    pub fn set_measure_time(&mut self, flag: bool) -> &mut Self {
        self.measure_time = flag;
        self
    }

    /// Sets the `ign_hosts` flag.
    pub fn set_ign_hosts(&mut self, flag: bool) -> &mut Self {
        self.ign_hosts = flag;
        self
    }

    /// Overrides the timeout with `timeout`.
    ///
    /// A value of `0` is ignored and leaves any previous override untouched.
    /// Use [`QuerySetup::reset_override_timeout`] to clear an override.
    pub fn set_override_timeout(&mut self, timeout: u32) -> &mut Self {
        if timeout != 0 {
            self.timeout = Some(timeout);
        }
        self
    }

    /// Removes any timeout override.
    pub fn reset_override_timeout(&mut self) -> &mut Self {
        self.timeout = None;
        self
    }
}
/// Classification of a DNS response, derived from the header status bits.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum QDnsQueryRec {
    /// No error (`RESP_NOERROR`).
    Ok,
    /// Server failure (`RESP_SERVFAIL`).
    ServFail,
    /// Name does not exist (`RESP_NXDOMAIN`).
    NxDomain,
    /// Query refused (`RESP_REFUSED`).
    Refused,
    /// Query kind not implemented by the server (`RESP_NOT_IMPL`).
    NotImpl,
    /// Response was truncated (`TRUN_CATION`). This is checked before the response code.
    Truncated,
    /// Server reported a format error (`RESP_FORMERR`).
    FormError,
}
impl QDnsQueryRec {
    /// Decides whether the next nameserver should be queried after a response
    /// with this status.
    ///
    /// Only an NXDOMAIN carrying the authoritative-answer flag (`aa == true`)
    /// stops the search. Every other combination, including `Ok`, yields `true`.
    /// NOTE(review): confirm that continuing after `Ok` is intended.
    pub(crate) fn try_next_nameserver(&self, aa: bool) -> bool {
        match self {
            Self::NxDomain => !aa,
            Self::Truncated
            | Self::Ok
            | Self::Refused
            | Self::ServFail
            | Self::NotImpl
            | Self::FormError => true,
        }
    }

    /// Returns `true` when the query should be retried over TCP, i.e. for
    /// `Truncated`, `ServFail` and `NxDomain`.
    pub(crate) fn should_try_tcp(&self) -> bool {
        matches!(self, Self::Truncated | Self::ServFail | Self::NxDomain)
    }
}
/// Maps the status bits of a response header to a [`QDnsQueryRec`].
///
/// The flags are tested in a fixed priority order: truncation first, then
/// NOERROR, FORMERR, NOTIMP, NXDOMAIN, REFUSED and SERVFAIL. The first flag
/// that `value` contains wins.
///
/// # Errors
///
/// Returns a `CDnsErrorType::DnsResponse` error (via `internal_error!`) when
/// none of the tested flags is present.
///
/// NOTE(review): assuming `StatusBits` is a `bitflags` type, `contains(x)` is
/// true when *all* bits of `x` are set in `value`. Two consequences follow:
/// - If `RESP_NOERROR` is the empty set (RCODE 0), it is contained in every
///   value and all later branches are unreachable.
/// - If the `RESP_*` constants are raw multi-bit RCODEs (e.g. NXDOMAIN = 3),
///   lower codes such as FORMERR (1) would match first.
///
/// Confirm against the `StatusBits` definition that each `RESP_*` flag is a
/// distinct, non-overlapping bit pattern. Otherwise, compare the masked RCODE
/// field for equality.
impl TryFrom<StatusBits> for QDnsQueryRec
{
type Error = CDnsError;
fn try_from(value: StatusBits) -> Result<Self, Self::Error>
{
// Truncation is reported ahead of any response code; presumably so the caller
// can retry over TCP (see `should_try_tcp`).
if value.contains(StatusBits::TRUN_CATION) == true
{
return Ok(QDnsQueryRec::Truncated);
}
else if value.contains(StatusBits::RESP_NOERROR) == true
{
return Ok(QDnsQueryRec::Ok);
}
else if value.contains(StatusBits::RESP_FORMERR) == true
{
return Ok(QDnsQueryRec::FormError);
}
else if value.contains(StatusBits::RESP_NOT_IMPL) == true
{
return Ok(QDnsQueryRec::NotImpl);
}
else if value.contains(StatusBits::RESP_NXDOMAIN) == true
{
return Ok(QDnsQueryRec::NxDomain);
}
else if value.contains(StatusBits::RESP_REFUSED) == true
{
return Ok(QDnsQueryRec::Refused);
}
else if value.contains(StatusBits::RESP_SERVFAIL) == true
{
return Ok(QDnsQueryRec::ServFail);
}
else
{
// No recognised flag: `internal_error!` produces the early `Err` return
// (the function has no other exit on this path).
internal_error!(CDnsErrorType::DnsResponse, "response status bits unknwon result: '{}'", value.bits());
};
}
}
impl fmt::Display for QDnsQueryRec {
    /// Writes the conventional upper-case name of the status.
    ///
    /// No trailing newline is emitted: `Display` output is meant to be embedded
    /// by the caller (e.g. `"Status: {}"`). Width, fill and alignment flags are
    /// honoured via [`fmt::Formatter::pad`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match self {
            Self::Ok => "OK",
            Self::ServFail => "SERVFAIL",
            Self::NxDomain => "NXDOMAIN",
            Self::Refused => "REFUSED",
            Self::NotImpl => "NOT IMPLEMENTED",
            Self::Truncated => "TRUNCATED",
            Self::FormError => "FORMAT ERROR",
        };
        f.pad(text)
    }
}
/// Outcomes of a batch of DNS requests, keyed by the [`QDnsReq`] that produced each one.
#[derive(Debug, Default)]
pub struct QDnsQueryResult {
    /// One entry per request: either the parsed query or the error it produced.
    queries: HashMap<QDnsReq, CDnsResult<QDnsQuery>>,
}
impl IntoIterator for QDnsQueryResult {
    type Item = (QDnsReq, CDnsResult<QDnsQuery>);
    type IntoIter = std::collections::hash_map::IntoIter<QDnsReq, CDnsResult<QDnsQuery>>;

    /// Consumes the result set, yielding every `(request, outcome)` pair in
    /// unspecified (hash-map) order.
    #[inline]
    fn into_iter(self) -> Self::IntoIter {
        let Self { queries } = self;
        queries.into_iter()
    }
}
impl Index<QDnsReq> for QDnsQueryResult {
    type Output = CDnsResult<QDnsQuery>;

    /// Returns the outcome recorded for `index`.
    ///
    /// # Panics
    ///
    /// Panics if no outcome was recorded for `index`. Use
    /// [`QDnsQueryResult::contains_dnsreq`] to check first.
    fn index(&self, index: QDnsReq) -> &Self::Output {
        self.queries
            .get(&index)
            .expect("QDnsQueryResult: no result recorded for the requested QDnsReq")
    }
}
impl QDnsQueryResult {
    /// Creates an empty result set that can hold `cap` entries without reallocating.
    pub(crate) fn with_capacity(cap: usize) -> Self {
        Self { queries: HashMap::with_capacity(cap) }
    }

    /// Records `resp` as the outcome of `req`, overwriting any earlier outcome
    /// for the same request.
    pub(crate) fn push(&mut self, req: QDnsReq, resp: CDnsResult<QDnsQuery>) {
        self.queries.insert(req, resp);
    }

    /// Returns `true` when no outcome has been recorded.
    pub fn is_empty(&self) -> bool {
        self.queries.is_empty()
    }

    /// Returns `true` when an outcome (success or error) is recorded for `req`.
    pub fn contains_dnsreq(&self, req: &QDnsReq) -> bool {
        self.queries.contains_key(req)
    }

    /// Moves every entry of `other` into `self`. Entries from `other` overwrite
    /// existing entries for the same request.
    pub fn extend(&mut self, other: Self) {
        self.queries.extend(other.queries);
    }

    /// Borrows all `(request, outcome)` pairs, in unspecified order.
    pub fn list_results(&self) -> std::collections::hash_map::Iter<'_, QDnsReq, Result<QDnsQuery, CDnsError>> {
        self.queries.iter()
    }

    /// Returns every successful query, discarding failed requests.
    ///
    /// # Errors
    ///
    /// Returns a `CDnsErrorType::DnsNotAvailable` error ("network error") when
    /// no request succeeded, including when the result set is empty.
    pub fn get_result(self) -> CDnsResult<Vec<QDnsQuery>> {
        let ok = self.collect_ok();
        if ok.is_empty() {
            internal_error!(CDnsErrorType::DnsNotAvailable, "network error");
        }
        Ok(ok)
    }

    /// Returns every successful query, or an error if any request failed.
    ///
    /// # Errors
    ///
    /// Returns the first error met while iterating. Iteration order is
    /// unspecified, so when several requests failed, which error is returned is
    /// not deterministic.
    pub fn get_ok_or_error(self) -> CDnsResult<Vec<QDnsQuery>> {
        self.queries
            .into_iter()
            .map(|(_, outcome)| outcome)
            .collect()
    }

    /// Returns the successful queries, silently dropping errors.
    pub fn collect_ok(self) -> Vec<QDnsQuery> {
        self.queries
            .into_iter()
            .filter_map(|(_, outcome)| outcome.ok())
            .collect()
    }

    /// Returns the successful queries that carry at least one answer record.
    pub fn collect_ok_with_answers(self) -> Vec<QDnsQuery> {
        self.queries
            .into_iter()
            .filter_map(|(_, outcome)| outcome.ok())
            .filter(|query| !query.resp.is_empty())
            .collect()
    }

    /// Splits the entries into `(successes, failures)`, keeping the request
    /// alongside each outcome.
    pub fn collect_split(self) -> (Vec<(QDnsReq, Result<QDnsQuery, CDnsError>)>, Vec<(QDnsReq, Result<QDnsQuery, CDnsError>)>) {
        self.queries
            .into_iter()
            .partition(|(_, outcome)| outcome.is_ok())
    }
}
impl fmt::Display for QDnsQueryResult {
    /// Writes every recorded outcome, or `No DNS server available` when the
    /// result set is empty.
    ///
    /// Successful entries use `QDnsQuery`'s own `Display`, which ends with
    /// newlines. Error entries are terminated with a newline as well, so
    /// consecutive entries never run together on one line.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.is_empty() {
            return write!(f, "No DNS server available");
        }

        for (req, qr) in self.list_results() {
            match qr {
                Ok(r) => write!(f, "{}", r)?,
                Err(e) => writeln!(f, "request: {}, error: {}", req, e)?,
            }
        }
        Ok(())
    }
}
/// One answer to a DNS request, together with metadata about its origin.
///
/// Built either from a network response (`from_response`) or from locally
/// resolved records (`from_local`).
#[derive(Clone, Debug)]
pub struct QDnsQuery {
    /// Time elapsed since the caller-supplied start instant, if one was given.
    pub elapsed: Option<Duration>,
    /// Where the answer came from: the server's socket address, or
    /// `HOST_CFG_PATH` for local answers.
    pub server: String,
    /// Authoritative-answer flag (always `true` for local answers).
    pub aa: bool,
    /// Records taken from the answer's `authoratives` list (presumably the
    /// authority section; TODO confirm).
    pub authoratives: Vec<DnsResponsePayload>,
    /// Answer records.
    pub resp: Vec<DnsResponsePayload>,
    /// Classified response status.
    pub status: QDnsQueryRec,
}
impl fmt::Display for QDnsQuery {
    /// Writes a multi-line, human-readable report in this order: source, optional
    /// timing, authority flag, authority records, status, then answer records.
    /// Each non-empty record list is followed by a blank line.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Source: {} ", self.server)?;
        if let Some(took) = self.elapsed.as_ref() {
            write!(f, "{:.2?} ", took)?;
        }

        let answer_kind = if self.aa {
            "Authoritative answer"
        } else {
            "Non-Authoritative answer"
        };
        writeln!(f, "{}", answer_kind)?;

        writeln!(f, "Authoritatives: {}", self.authoratives.len())?;
        if !self.authoratives.is_empty() {
            self.authoratives
                .iter()
                .try_for_each(|rec| writeln!(f, "{}", rec))?;
            writeln!(f)?;
        }

        writeln!(f, "Status: {}", self.status)?;

        writeln!(f, "Answers: {}", self.resp.len())?;
        if !self.resp.is_empty() {
            self.resp
                .iter()
                .try_for_each(|rec| writeln!(f, "{}", rec))?;
            writeln!(f)?;
        }

        Ok(())
    }
}
impl IntoIterator for QDnsQuery
{
type Item = DnsResponsePayload;
type IntoIter = std::vec::IntoIter<Self::Item>;
fn into_iter(self) -> Self::IntoIter
{
self.resp.into_iter()
}
}
impl QDnsQuery {
    /// Returns `true` when the status is [`QDnsQueryRec::Ok`].
    pub fn is_ok(&self) -> bool {
        matches!(self.status, QDnsQueryRec::Ok)
    }

    /// Returns the authoritative-answer flag.
    pub fn is_authorative(&self) -> bool {
        self.aa
    }

    /// Returns the measured duration, if any was recorded.
    pub fn get_elapsed_time(&self) -> Option<&Duration> {
        self.elapsed.as_ref()
    }

    /// Returns the textual source of the answer.
    pub fn get_server(&self) -> &String {
        &self.server
    }

    /// Borrows the authority records.
    pub fn get_authoratives(&self) -> &[DnsResponsePayload] {
        &self.authoratives
    }

    /// Borrows the answer records.
    pub fn get_responses(&self) -> &[DnsResponsePayload] {
        &self.resp
    }

    /// Consumes the query and returns only its answer records.
    pub fn move_responses(self) -> Vec<DnsResponsePayload> {
        self.resp
    }

    /// Returns the classified response status.
    pub fn get_status(&self) -> QDnsQueryRec {
        self.status
    }

    /// Reports whether the next nameserver should be tried. This delegates to
    /// `QDnsQueryRec::try_next_nameserver`, passing this answer's `aa` flag.
    pub fn should_check_next_ns(&self) -> bool {
        let authoritative = self.aa;
        self.status.try_next_nameserver(authoritative)
    }
}
impl QDnsQuery
{
pub
fn from_local(req_pl: Vec<DnsResponsePayload>, now: Option<&Instant>) -> QDnsQuery
{
let elapsed = now.map(|n| n.elapsed());
return
Self
{
elapsed: elapsed,
server: HOST_CFG_PATH.to_string(),
aa: true,
authoratives: Vec::new(),
status: QDnsQueryRec::Ok,
resp: req_pl
};
}
pub
fn from_response(server: &SocketAddr, ans: DnsRequestAnswer, now: Option<&Instant>) -> CDnsResult<Self>
{
return Ok(
Self
{
elapsed: now.map_or(None, |n| Some(n.elapsed())),
server: server.to_string(),
aa: ans.req_header.header.status.contains(StatusBits::AUTH_ANSWER),
authoratives: ans.authoratives,
status: ans.req_header.header.status.try_into()?,
resp: ans.response,
}
);
}
}