use std::collections::{BTreeSet, HashMap};
use std::convert::TryFrom;
use std::io::ErrorKind;
use std::net::{IpAddr, SocketAddr, TcpStream, ToSocketAddrs, UdpSocket};
use std::vec;
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
use crate::cfg_resolv_parser::{ConfigEntryTls, ResolveConfEntry, ResolveConfig, ResolveConfigFamily};
use crate::{error::*, write_error};
use crate::query::{QDnsQuery, QDnsQueryResult, QuerySetup};
use crate::sync::network::NetworkTap;
use crate::internal_error;
use crate::query_private::QDnsReq;
use crate::common::*;
use super::caches::CACHE;
use super::network::SocketTap;
#[cfg(feature = "use_sync_tls")]
use super::network::with_tls::{TcpHttpsConnection, TcpTlsConnection};
/// Address of a remote endpoint: either a ready socket address or a host
/// name plus port that still has to be looked up.
#[derive(Clone, Debug)]
pub enum QDnsSockerAddr
{
/// An IP address and port that can be used directly.
Ip(SocketAddr),
/// A host name and a port. The name is looked up with an A/AAAA request
/// when the value is converted through `ToSocketAddrs`.
Host(String, u16)
}
impl ToSocketAddrs for QDnsSockerAddr
{
type Iter = vec::IntoIter<SocketAddr>;
fn to_socket_addrs(&self) -> std::io::Result<Self::Iter>
{
match self
{
QDnsSockerAddr::Ip(socket_addr) =>
return Ok(vec![socket_addr.clone()].into_iter()),
QDnsSockerAddr::Host(host, port) =>
{
let qdns =
QDns
::make_a_aaaa_request(None, host, QuerySetup::default())
.map_err(|e|
std::io::Error::new(ErrorKind::NotFound,
CDnsSuperError::from(e))
)?;
let res =
qdns
.query()
.collect_ok()
.into_iter()
.map(|resp|
resp
.get_responses()
.iter()
.map(|r|
r.get_rdata().get_ip().map(|f| SocketAddr::new(f, *port))
)
.collect::<Vec<Option<SocketAddr>>>()
)
.flatten()
.filter(|p| p.is_some())
.map(|v| v.unwrap())
.collect::<Vec<SocketAddr>>();
return Ok( res.into_iter() );
}
}
}
}
impl QDnsSockerAddr {
    /// Parses an `address:port` string.
    ///
    /// If the whole string parses as a `SocketAddr` the result is `Ip`.
    /// Otherwise the string is split at the first `:`; the left part is kept
    /// verbatim as the host name and the right part must be a valid `u16`
    /// port, giving `Host`. No name lookup happens here.
    ///
    /// # Errors
    /// `ErrorKind::InvalidData` when there is no `:` separator or when the
    /// port is not a valid `u16`.
    pub fn resolve<D>(host: D) -> std::io::Result<Self>
    where
        D: AsRef<str>,
    {
        let text = host.as_ref();

        if let Ok(addr) = text.parse::<SocketAddr>() {
            return Ok(Self::Ip(addr));
        }

        let Some((domain, portno)) = text.split_once(':') else {
            return Err(std::io::Error::new(
                ErrorKind::InvalidData,
                "missing port number",
            ));
        };

        let port = portno
            .parse::<u16>()
            .map_err(|e| std::io::Error::new(ErrorKind::InvalidData, e.to_string()))?;

        Ok(Self::Host(domain.to_string(), port))
    }

    /// Builds an address from a separate host and port.
    ///
    /// `host` becomes `Ip` when it parses as an `IpAddr`, otherwise it is
    /// stored verbatim as `Host`. This function never returns `Err`; the
    /// `Result` is part of the existing signature.
    pub fn resolve_port<D>(host: D, port: u16) -> std::io::Result<Self>
    where
        D: AsRef<str>,
    {
        let name = host.as_ref();

        Ok(match name.parse::<IpAddr>() {
            Ok(ip) => Self::Ip(SocketAddr::new(ip, port)),
            Err(_) => Self::Host(name.to_string(), port),
        })
    }
}
/// A set of DNS requests together with the resolver configuration and the
/// options used to execute them.
#[derive(Clone, Debug)]
pub struct QDns
{
/// Shared resolver configuration, either supplied by the caller or taken
/// from the global `CACHE`.
resolvers: Arc<ResolveConfig>,
/// Requests still to be answered, in insertion order; entries are removed
/// while `query` runs as answers are obtained.
ordered_req_list: Vec<QDnsReq>,
/// Per-query options (read in this file: `timeout`, `measure_time`,
/// `ign_hosts`).
opts: QuerySetup,
}
impl QDns
{
/// Creates an instance with no requests queued.
///
/// # Arguments
/// * `resolvers` - resolver configuration to use. When `None`, the list is
///   obtained from the global `CACHE`.
/// * `opts` - options applied when the query is executed.
///
/// # Errors
/// Propagates the error of `CACHE.clone_resolve_list()`. This can only
/// happen when `resolvers` is `None`.
pub fn make_empty(resolvers: Option<Arc<ResolveConfig>>, opts: QuerySetup) -> CDnsResult<Self> {
    // Touch the cache only when it is needed. Passing the cache lookup as an
    // argument to `unwrap_or` would evaluate it (and propagate its failure)
    // even when the caller supplied a configuration.
    let resolvers = match resolvers {
        Some(custom) => custom,
        None => CACHE.clone_resolve_list()?,
    };

    Ok(Self {
        resolvers,
        ordered_req_list: Vec::new(),
        opts,
    })
}
pub
fn add_request<R>(&mut self, qtype: QType, req_name: R) -> CDnsResult<()>
where R: TryInto<QDnsName, Error = CDnsError>
{
let qr = QDnsReq::new_into(req_name, qtype)?;
self.ordered_req_list.push(qr);
return Ok(());
}
/// Creates an instance preloaded with the address lookups (A and/or AAAA)
/// for `req_name_ref`.
///
/// Which record types are requested, and in which order, is chosen from the
/// `family` field of the resolver configuration.
///
/// # Arguments
/// * `resolvers_opt` - resolver configuration to use. When `None`, the list
///   is obtained from the global `CACHE`.
/// * `req_name_ref` - name to look up.
/// * `opts` - options applied when the query is executed.
///
/// # Errors
/// Propagates the error of converting the name into a `QDnsName` and, only
/// when `resolvers_opt` is `None`, the error of
/// `CACHE.clone_resolve_list()`.
pub fn make_a_aaaa_request<R: AsRef<str>>(
    resolvers_opt: Option<Arc<ResolveConfig>>,
    req_name_ref: R,
    opts: QuerySetup,
) -> CDnsResult<Self> {
    let req_n = QDnsName::try_from(req_name_ref.as_ref())?;

    // Touch the cache only when it is needed. Passing the cache lookup as an
    // argument to `unwrap_or` would evaluate it (and propagate its failure)
    // even when the caller supplied a configuration.
    let resolvers = match resolvers_opt {
        Some(custom) => custom,
        None => CACHE.clone_resolve_list()?,
    };

    let reqs: Vec<QDnsReq> = match resolvers.family {
        ResolveConfigFamily::INET6_INET4 => vec![
            QDnsReq::new(req_n.clone(), QType::AAAA),
            QDnsReq::new(req_n, QType::A),
        ],
        ResolveConfigFamily::INET6 => vec![QDnsReq::new(req_n, QType::AAAA)],
        ResolveConfigFamily::INET4 => vec![QDnsReq::new(req_n, QType::A)],
        // `INET4_INET6` and every other value: A first, then AAAA.
        _ => vec![
            QDnsReq::new(req_n.clone(), QType::A),
            QDnsReq::new(req_n, QType::AAAA),
        ],
    };

    Ok(Self {
        resolvers,
        ordered_req_list: reqs,
        opts,
    })
}
/// Executes all queued requests and returns the collected results,
/// consuming the instance.
///
/// The order of sources follows `resolvers.lookup`:
/// * file first - the local host list is searched; requests it answered are
///   dropped, and the remainder is sent to the nameservers if `is_bind()`
///   allows it. A failing file lookup is logged and treated as empty.
/// * otherwise - the nameservers are queried; requests contained in that
///   result are dropped, and the remainder is searched in the local host
///   list if `is_file()` allows it. A failing file lookup is logged.
pub fn query(mut self) -> QDnsQueryResult {
    // A start timestamp is taken only when time measurement was requested.
    let started: Option<Instant> = if self.opts.measure_time {
        Some(Instant::now())
    } else {
        None
    };
    let now = started.as_ref();

    if self.resolvers.lookup.is_file_first() {
        let mut merged = match self.lookup_file(now) {
            Ok(local) => {
                if !local.is_empty() {
                    self.ordered_req_list
                        .retain(|req| !local.contains_dnsreq(req));
                }
                local
            }
            Err(e) => {
                write_error!(e);
                QDnsQueryResult::default()
            }
        };

        if !self.ordered_req_list.is_empty() && self.resolvers.lookup.is_bind() {
            let remote = self.process_request(now);
            merged.extend(remote);
        }

        return merged;
    }

    let mut remote = self.process_request(now);

    if !remote.is_empty() {
        self.ordered_req_list
            .retain(|req| !remote.contains_dnsreq(req));
    }

    if !self.ordered_req_list.is_empty() && self.resolvers.lookup.is_file() {
        match self.lookup_file(now) {
            Ok(local) => {
                remote.extend(local);
            }
            Err(e) => {
                write_error!(e);
            }
        }
    }

    remote
}
/// Returns the timeout to use, in whole seconds: the per-query override
/// from `opts` when present, otherwise the value from the resolver
/// configuration.
fn get_timeout(&self) -> Duration {
    let secs = match self.opts.timeout {
        Some(custom) => custom as u64,
        None => self.resolvers.timeout as u64,
    };

    Duration::from_secs(secs)
}
/// Tries to answer the queued requests from the cached host list.
///
/// Only `A`, `AAAA` (searched by name) and `PTR` (searched by IP) requests
/// are considered; any request that cannot be answered locally is skipped
/// without an error. When `opts.ign_hosts` is set, the host list is not
/// consulted at all and the result is empty.
///
/// # Arguments
/// * `now` - optional start timestamp forwarded to `QDnsQuery::from_local`.
///
/// # Errors
/// Propagates the error of `CACHE.clone_host_list()`.
fn lookup_file(&mut self, now: Option<&Instant>) -> CDnsResult<QDnsQueryResult> {
    let mut found = QDnsQueryResult::default();

    if self.opts.ign_hosts {
        return Ok(found);
    }

    let hosts = CACHE.clone_host_list()?;

    for req in self.ordered_req_list.iter() {
        let qtype = *req.get_type();

        match qtype {
            QType::A | QType::AAAA => {
                let fqdn = String::from(req.get_req_name());
                let entry = hosts.search_by_fqdn(req.get_type(), fqdn.as_str());

                if let Some(entry) = entry {
                    if let Some(payload) = DnsResponsePayload::new_local(qtype, entry) {
                        found.push(req.clone(), Ok(QDnsQuery::from_local(payload, now)));
                    }
                }
            }
            QType::PTR => {
                // The requested name must convert to an IP address.
                if let Ok(ip) = IpAddr::try_from(req.get_req_name()) {
                    let entry = hosts.search_by_ip(&ip);

                    if let Some(entry) = entry {
                        if let Some(payload) = DnsResponsePayload::new_local(qtype, entry) {
                            found.push(req.clone(), Ok(QDnsQuery::from_local(payload, now)));
                        }
                    }
                }
            }
            _ => {}
        }
    }

    Ok(found)
}
/// Creates the transport used to talk to `resolver`.
///
/// The transport is selected in this order:
/// 1. TLS, when the resolver entry is marked `ConfigEntryTls::Tls`;
/// 2. HTTPS, when it is marked `ConfigEntryTls::Https`;
/// 3. TCP, when the configuration forces TCP or `force_tcp` is set;
/// 4. UDP otherwise.
///
/// # Errors
/// Without the `use_sync_tls` feature, the TLS and HTTPS cases report
/// `CDnsErrorType::SocketNotSupported`. Otherwise the result of the
/// respective `NetworkTap` constructor is returned.
#[inline]
fn create_socket(
    &self,
    force_tcp: bool,
    nonblk_flag: bool,
    resolver: Arc<ResolveConfEntry>,
) -> CDnsResult<Box<dyn SocketTap>> {
    let tls_kind = resolver.get_tls_type();

    match tls_kind {
        ConfigEntryTls::Tls => {
            #[cfg(feature = "use_sync_tls")]
            return NetworkTap::<TcpTlsConnection>::new_tls(
                resolver,
                self.get_timeout(),
                nonblk_flag,
                CDdnsGlobals::get_tcp_conn_timeout(),
            );

            #[cfg(not(feature = "use_sync_tls"))]
            internal_error!(CDnsErrorType::SocketNotSupported,
                "socket not supported: '{}'", resolver.get_tls_type());
        }
        ConfigEntryTls::Https => {
            #[cfg(feature = "use_sync_tls")]
            return NetworkTap::<TcpHttpsConnection>::new_https(
                resolver,
                self.get_timeout(),
                nonblk_flag,
                CDdnsGlobals::get_tcp_conn_timeout(),
            );

            #[cfg(not(feature = "use_sync_tls"))]
            internal_error!(CDnsErrorType::SocketNotSupported,
                "socket not supported: '{}'", resolver.get_tls_type());
        }
        _ if self.resolvers.option_flags.is_force_tcp() || force_tcp => {
            return NetworkTap::<TcpStream>::new_tcp(
                resolver,
                nonblk_flag,
                self.get_timeout(),
                CDdnsGlobals::get_tcp_conn_timeout(),
            );
        }
        _ => {
            return NetworkTap::<UdpSocket>::new_udp(resolver, nonblk_flag, self.get_timeout());
        }
    }
}
/// Sends the queued requests to the configured nameservers and collects
/// the outcome of each.
///
/// Two strategies exist, selected by `option_flags.is_no_parallel()`:
///
/// * sequential - every request is tried against the resolvers one by one
///   via `query_exec_seq`. The first response that does not ask for the
///   next nameserver (`should_check_next_ns()` is false) is recorded and
///   ends the search for that request. If no resolver gives such a
///   response, the last "check next" response is recorded, or the first
///   error when there was none. When there are no resolvers, nothing is
///   recorded for the request.
/// * pipelined - for every resolver, all still-pending requests are sent at
///   once via `query_exec_pipelined`. Every returned outcome is recorded;
///   a request is removed from the pending list once it has an `Ok`
///   response that does not ask for the next nameserver. A resolver whose
///   whole batch fails is logged and skipped.
///
/// # Arguments
/// * `now` - optional start timestamp forwarded to the executors.
fn process_request(&mut self, now: Option<&Instant>) -> QDnsQueryResult {
    let mut responses: QDnsQueryResult =
        QDnsQueryResult::with_capacity(self.ordered_req_list.len());

    if self.resolvers.option_flags.is_no_parallel() {
        for req in self.ordered_req_list.iter() {
            // Outcome to report when no resolver gives a final answer.
            let mut fallback: Option<CDnsResult<QDnsQuery>> = None;

            for resolver in self.resolvers.get_resolvers_iter() {
                match self.query_exec_seq(now, resolver.clone(), req, None) {
                    Ok(resp) => {
                        if resp.should_check_next_ns() {
                            fallback = Some(Ok(resp));
                        } else {
                            responses.push(req.clone(), Ok(resp));
                            // A final answer supersedes anything seen earlier.
                            fallback = None;
                            break;
                        }
                    }
                    Err(e) => {
                        if fallback.is_none() {
                            fallback = Some(Err(e));
                        }
                    }
                }
            }

            // `fallback` is `None` after a final answer (already recorded
            // above) or when there were no resolvers. Unwrapping it
            // unconditionally would panic in both cases.
            if let Some(outcome) = fallback {
                responses.push(req.clone(), outcome);
            }
        }
    } else {
        for resolver in self.resolvers.get_resolvers_iter() {
            if self.ordered_req_list.is_empty() {
                break;
            }

            match self.query_exec_pipelined(now, resolver.clone(), None) {
                Ok(batch) => {
                    for (answered_req, outcome) in batch {
                        if let Ok(ref resp) = outcome {
                            if !resp.should_check_next_ns() {
                                self.ordered_req_list
                                    .retain(|pending| pending != &answered_req);
                            }
                        }

                        responses.push(answered_req, outcome);
                    }
                }
                Err(e) => {
                    write_error!(e);
                }
            }
        }
    }

    responses
}
fn query_exec_pipelined(
&self,
now: Option<&Instant>,
resolver: Arc<ResolveConfEntry>,
requery: Option<HashMap<DnsRequestHeader, QDnsReq>>,
) -> CDnsResult<QDnsQueryResult>
{
let force_tcp = self.resolvers.option_flags.is_force_tcp() || requery.is_some();
let mut query_headers: HashMap<DnsRequestHeader, QDnsReq> =
if let Some(requer) = requery
{
let pkts_ids = requer.iter().map(|q| q.0.get_id()).collect::<BTreeSet<u16>>();
requer
.into_iter()
.map(
|(mut qrr, qdr)|
{
loop
{
qrr.regenerate_id();
if pkts_ids.contains(&qrr.get_id()) == false
{
break;
}
}
(qrr, qdr)
})
.collect::<HashMap<DnsRequestHeader, QDnsReq>>()
}
else
{
let mut pkts_ids: BTreeSet<u16> = BTreeSet::new();
self
.ordered_req_list
.iter()
.map(
|query|
{
let mut drh_res = DnsRequestHeader::try_from(query);
loop
{
if let Ok(ref mut drh) = drh_res
{
if pkts_ids.contains(&drh.get_id()) == true
{
drh.regenerate_id();
continue;
}
else
{
pkts_ids.insert(drh.get_id());
break;
}
}
else
{
break;
}
}
drh_res.map(|dh| (dh, query.clone()))
}
)
.collect::<CDnsResult<HashMap<DnsRequestHeader, QDnsReq>>>()?
};
let tap =
self.create_socket(force_tcp, false, resolver.clone())?;
for qh in query_headers.iter()
{
let pkt = qh.0.to_bytes(tap.should_append_len())?;
tap.send(pkt.as_slice())?;
}
let mut resp = QDnsQueryResult::with_capacity(self.ordered_req_list.len());
let mut requery: HashMap<DnsRequestHeader, QDnsReq> = HashMap::new();
loop
{
if query_headers.len() == 0
{
break;
}
if tap.poll_read()? == false
{
break;
}
let ans = tap.recv()?;
let Some((query_header, qdnsreq)) =
query_headers.remove_entry(&ans.req_header)
else
{
internal_error!(CDnsErrorType::IoError,
"can not find response with request: {}", ans.req_header);
};
ans.verify(&query_header)?;
let qdns_resp =
QDnsQuery::from_response(tap.get_remote_addr(), ans, now);
if let Ok(ref qdns) = qdns_resp
{
if qdns.get_status().should_try_tcp() == true && force_tcp == false
{
requery.insert(query_header, qdnsreq);
}
else
{
resp.push(qdnsreq, qdns_resp);
}
}
else
{
resp.push(qdnsreq, qdns_resp);
}
}
if requery.is_empty() == false
{
let res = self.query_exec_pipelined(now, resolver, Some(requery))?;
resp.extend(res);
}
return Ok(resp);
}
/// Sends a single request to one resolver and waits for its answer.
///
/// # Arguments
/// * `now` - optional start timestamp forwarded to
///   `QDnsQuery::from_response`.
/// * `resolver` - the nameserver entry to contact.
/// * `query` - the request to send.
/// * `requery` - `None` for the first attempt. `Some(header)` to re-send
///   that header with a fresh ID; this also requests TCP from
///   `create_socket`.
///
/// # Returns
/// The parsed response. When the response status asks for TCP
/// (`should_try_tcp()`) and TCP was not yet forced, the request is repeated
/// once with TCP forced and the result of that attempt is returned. A
/// response obtained with TCP already forced is returned as is, even if it
/// still asks for TCP, so the retry cannot recurse without bound.
///
/// # Errors
/// Any failure while building or serializing the header, creating the
/// socket, sending, receiving, verifying or parsing is returned
/// immediately, without a retry.
fn query_exec_seq(
    &self,
    now: Option<&Instant>,
    resolver: Arc<ResolveConfEntry>,
    query: &QDnsReq,
    requery: Option<DnsRequestHeader>,
) -> CDnsResult<QDnsQuery> {
    let force_tcp = self.resolvers.option_flags.is_force_tcp() || requery.is_some();

    let query_header = match requery {
        Some(mut previous) => {
            previous.regenerate_id();
            previous
        }
        None => DnsRequestHeader::try_from(query)?,
    };

    // The socket lives only inside this block, so it is closed before a
    // possible retry opens a new one.
    let resp = {
        let tap = self.create_socket(force_tcp, false, resolver.clone())?;

        let pkt = query_header.to_bytes(tap.should_append_len())?;
        tap.send(pkt.as_slice())?;

        let ans = tap.recv()?;
        ans.verify(&query_header)?;

        let parsed = QDnsQuery::from_response(tap.get_remote_addr(), ans, now)?;
        parsed
    };

    // Retry only once: if TCP was already forced there is nothing better to
    // fall back to, so the response is handed to the caller unchanged.
    if force_tcp || !resp.status.should_try_tcp() {
        return Ok(resp);
    }

    self.query_exec_seq(now, resolver, query, Some(query_header))
}
}
/// Tests for the helpers and the synchronous query path.
///
/// NOTE(review): apart from the two helper tests, these appear to depend on
/// the system resolver configuration, `/etc/hosts` and network access —
/// confirm before running in a sandbox.
#[cfg(test)]
mod tests {
    use std::net::IpAddr;
    use std::time::Instant;
    use crate::{common::{byte2hexchar, ip2pkt, RecordPTR, RecordReader}, sync::{query::QDns}, QDnsQueryRec, QType, QuerySetup};

    /// Default query options with time measurement switched on.
    fn measured_setup() -> QuerySetup {
        let mut setup = QuerySetup::default();
        setup.set_measure_time(true);
        setup
    }

    /// `ip2pkt` encodes 8.8.8.8 as the reverse-lookup name in wire format.
    #[test]
    fn test_ip2pkt() {
        use std::net::Ipv4Addr;

        let addr = IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8));

        let started = Instant::now();
        let encoded = ip2pkt(&addr);
        let took = started.elapsed();
        println!("Elapsed: {:.2?}", took);

        let expected = b"\x01\x38\x01\x38\x01\x38\x01\x38\x07\x69\x6e\x2d\x61\x64\x64\x72\x04\x61\x72\x70\x61\x00";
        assert_eq!(encoded.as_slice(), expected);
    }

    /// `byte2hexchar` maps a nibble value to its lowercase hex ASCII byte.
    #[test]
    fn test_byte2hexchar() {
        for (nibble, ascii) in [(1, 0x31), (9, 0x39), (10, b'a'), (15, b'f')] {
            assert_eq!(byte2hexchar(nibble), ascii);
        }
    }

    /// PTR lookup of 8.8.8.8 is expected to return `dns.google`.
    #[test]
    fn reverse_lookup_test() {
        let target: IpAddr = "8.8.8.8".parse().unwrap();
        let setup = measured_setup();

        let started = Instant::now();
        let mut request = QDns::make_empty(None, setup).unwrap();
        request.add_request(QType::PTR, target).unwrap();
        let outcome = request.query();
        let took = started.elapsed();

        println!("Elapsed: {:.2?}", took);
        println!("{}", outcome);

        assert!(!outcome.is_empty());

        let answers = outcome.collect_ok();
        let first = &answers[0];

        assert_eq!(first.status, QDnsQueryRec::Ok);
        assert_eq!(first.resp.len(), 1);
        assert_eq!(first.resp[0].rdata, RecordPTR::wrap(RecordPTR{ fqdn: "dns.google".to_string() }));
    }

    /// PTR lookup of 127.0.0.1 is expected to be answered from `/etc/hosts`.
    #[test]
    fn reverse_lookup_hosts_test() {
        let target: IpAddr = "127.0.0.1".parse().unwrap();

        let started = Instant::now();
        let setup = measured_setup();
        let mut request = QDns::make_empty(None, setup).unwrap();
        request.add_request(QType::PTR, target).unwrap();
        let outcome = request.query();
        let took = started.elapsed();

        println!("Elapsed: {:.2?}", took);
        println!("{}", outcome);

        assert!(!outcome.is_empty());

        let answers = outcome.collect_ok();
        let first = &answers[0];

        assert_eq!(first.server.as_str(), "/etc/hosts");
        assert_eq!(first.status, QDnsQueryRec::Ok);
        assert_eq!(first.resp.len(), 1);
        assert_eq!(first.resp[0].rdata, RecordPTR::wrap(RecordPTR{ fqdn: "localhost".to_string() }));
    }

    /// Runs an A/AAAA lookup for `dns.google` and prints the result.
    #[test]
    fn reverse_lookup_a() {
        let request = QDns::make_a_aaaa_request(None, "dns.google", measured_setup()).unwrap();

        let started = Instant::now();
        let outcome = request.query();
        let took = started.elapsed();

        println!("Elapsed: {:.2?}", took);
        println!("{}", outcome);
    }

    /// Looks up an internationalized name and checks both its encoded and
    /// decoded form.
    #[cfg(feature = "enable_IDN_support")]
    #[test]
    fn reverse_lookup_a_idn() {
        let request =
            QDns::make_a_aaaa_request(None, "законипорядок.бел", measured_setup()).unwrap();

        let started = Instant::now();
        let outcome = request.query();
        let took = started.elapsed();

        println!("Elapsed: {:.2?}", took);
        println!("{}", outcome);

        let answered: Vec<crate::QDnsQuery> = outcome.collect_ok_with_answers();
        let first_record = &answered[0].get_responses()[0];

        let name = first_record.name.clone();
        let idn_name = first_record.get_full_domain_name_with_idn_decode().unwrap();

        println!("{} -> {}", name, idn_name);

        assert_eq!(name, "xn--80aihfjcshcbin9q.xn--90ais");
        assert_eq!(idn_name, "законипорядок.бел");
    }

    /// Requests a TXT record from `truncate-zyxw11.go.dnscheck.tools` and
    /// prints the result.
    #[test]
    fn truncation_test_1() {
        let mut request = QDns::make_empty(None, measured_setup()).unwrap();
        request.add_request(QType::TXT, "truncate-zyxw11.go.dnscheck.tools").unwrap();

        let started = Instant::now();
        let outcome = request.query();
        let took = started.elapsed();

        println!("Elapsed: {:.2?}", took);
        println!("{}", outcome);
    }

    /// Sends a mixed batch of requests and prints answers and errors
    /// separately.
    #[test]
    fn truncation_test_2() {
        let mut request = QDns::make_empty(None, measured_setup()).unwrap();
        request.add_request(QType::A, "dnscheck.tools").unwrap();
        request.add_request(QType::A, "localhost").unwrap();
        request.add_request(QType::A, "unknownnonexistentdomain.com").unwrap();
        request.add_request(QType::TXT, "example.com").unwrap();

        let started = Instant::now();
        let outcome = request.query();
        let took = started.elapsed();

        println!("Elapsed: {:.2?}", took);

        let (answers, failures) = outcome.collect_split();

        println!("ANSWERS");
        for answer in answers {
            println!("-----\n{} \n{}", answer.0, answer.1.unwrap());
        }

        println!("ERRORS");
        for failure in failures {
            println!("-----\n{} \n{}", failure.0, failure.1.err().unwrap());
        }
    }
}