use std::collections::{BTreeSet, HashMap};
use std::convert::TryFrom;
use std::marker::PhantomData;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
use async_recursion::async_recursion;
use crate::a_sync::caches::CachesController;
#[cfg(feature = "built_in_async")]
use crate::a_sync::interface::MutexedCaches;
#[cfg(not(feature = "built_in_async"))]
use crate::a_sync::interface::MutexedCaches;
#[cfg(feature = "built_in_async")]
use crate::a_sync::{IoInterf, SocketBase};
#[cfg(feature = "built_in_async")]
use crate::a_sync::TokioInterf;
use crate::a_sync::SocketTaps;
use crate::cfg_resolv_parser::{ConfigEntryTls, ResolveConfEntry, ResolveConfigFamily};
use crate::common::{CDdnsGlobals, DnsRequestAnswer, DnsRequestHeader};
use crate::{error::*, DnsResponsePayload, QDnsQueryResult};
use crate::query::QDnsQuery;
use crate::{write_error, internal_error};
use crate::query_private::QDnsReq;
use super::network::{NetworkTapType, SocketTap};
use super::{QDnsName, QType, QuerySetup, ResolveConfig};
/// A DNS query: a list of requests plus everything needed to execute them.
///
/// This definition is compiled when the `built_in_async` feature is enabled and
/// provides default type arguments (`SocketBase` / `IoInterf`).
///
/// Type parameters:
/// * `LOC` - type parameter forwarded to `NetworkTapType<LOC>` / `SocketTap<LOC>`.
/// * `TAP` - socket factory; its associated functions create the UDP/TCP/TLS taps.
/// * `MC` - parameter of the shared `CachesController`.
#[cfg(feature = "built_in_async")]
#[derive(Clone, Debug)]
pub struct QDns<LOC: Sync + Send = SocketBase, TAP: SocketTaps<LOC> = SocketBase, MC: MutexedCaches = IoInterf>
{
/// Resolver configuration (address family, lookup order, option flags,
/// default timeout and the list of nameservers).
resolvers: Arc<ResolveConfig>,
/// Requests still waiting for an answer; answered entries are removed
/// while the query is executed.
ordered_req_list: Vec<QDnsReq>,
/// Per-query options (time measurement, timeout override, ignoring the hosts list).
opts: QuerySetup,
/// Shared cache that supplies the resolver list and the host list.
cache: Arc<CachesController<MC>>,
/// Marker only: `TAP` is used through its associated functions, never stored.
_tap: PhantomData<TAP>,
/// Marker only: `LOC` appears in the socket types, never stored.
_loc: PhantomData<LOC>
}
#[cfg(feature = "use_async_tokio")]
impl QDns<SocketBase>
{
    /// Creates a query without any requests, using the built-in Tokio types.
    ///
    /// Thin wrapper around the generic `make_empty`; see it for the meaning of
    /// the arguments and for the errors that may be returned.
    #[inline]
    pub async
    fn builtin_make_empty(resolvers: Option<Arc<ResolveConfig>>, opts: QuerySetup, cache: Arc<CachesController<TokioInterf>>) -> CDnsResult<QDns<SocketBase, SocketBase, TokioInterf>>
    {
        QDns::<_>::make_empty(resolvers, opts, cache).await
    }

    /// Appends one request of type `qtype` for `req_name` to this query.
    ///
    /// Same behaviour as the generic `add_request`, to which it delegates.
    ///
    /// # Errors
    /// Returns the error produced while converting `req_name` into a request.
    #[inline]
    pub
    fn buildin_add_request<R>(&mut self, qtype: QType, req_name: R) -> CDnsResult<()>
    where
        R: TryInto<QDnsName, Error = CDnsError>,
    {
        self.add_request(qtype, req_name)
    }

    /// Creates an address lookup (A and/or AAAA) for `req_name` using the
    /// built-in Tokio types.
    ///
    /// Thin wrapper around the generic `make_a_aaaa_request`.
    pub async
    fn buildin_make_a_aaaa_request<R: AsRef<str>>(resolvers_opt: Option<Arc<ResolveConfig>>, req_name: R,
        opts: QuerySetup, cache: Arc<CachesController<TokioInterf>>) -> CDnsResult<Self>
    {
        QDns::<_>::make_a_aaaa_request(resolvers_opt, req_name, opts, cache).await
    }
}
/// A DNS query: a list of requests plus everything needed to execute them.
///
/// This definition is compiled when the `built_in_async` feature is disabled,
/// so every type argument has to be supplied by the user of the crate.
///
/// Type parameters:
/// * `LOC` - type parameter forwarded to `NetworkTapType<LOC>` / `SocketTap<LOC>`.
/// * `TAP` - socket factory; its associated functions create the UDP/TCP/TLS taps.
/// * `MC` - parameter of the shared `CachesController`.
#[cfg(not(feature = "built_in_async"))]
#[derive(Clone, Debug)]
pub struct QDns<LOC: Sync + Send, TAP: SocketTaps<LOC>, MC: MutexedCaches>
{
/// Resolver configuration (address family, lookup order, option flags,
/// default timeout and the list of nameservers).
resolvers: Arc<ResolveConfig>,
/// Requests still waiting for an answer; answered entries are removed
/// while the query is executed.
ordered_req_list: Vec<QDnsReq>,
/// Per-query options (time measurement, timeout override, ignoring the hosts list).
opts: QuerySetup,
/// Shared cache that supplies the resolver list and the host list.
cache: Arc<CachesController<MC>>,
/// Marker only: `TAP` is used through its associated functions, never stored.
_tap: PhantomData<TAP>,
/// Marker only: `LOC` appears in the socket types, never stored.
_loc: PhantomData<LOC>
}
impl<LOC: Sync + Send, TAP: SocketTaps<LOC>, MC: MutexedCaches> QDns<LOC, TAP, MC>
{
/// Creates a query that contains no requests yet.
///
/// # Arguments
/// * `resolvers` - resolver configuration to use. When `None`, the list is
///   obtained from `cache`.
/// * `opts` - per-query options.
/// * `cache` - shared cache controller.
///
/// # Errors
/// Returns the error of `cache.clone_resolve_list()` when `resolvers` is
/// `None` and the list cannot be obtained. The cache is not consulted at all
/// when the caller supplies its own resolvers, so a cache failure can no
/// longer reject an otherwise complete request.
pub async
fn make_empty(resolvers: Option<Arc<ResolveConfig>>, opts: QuerySetup, cache: Arc<CachesController<MC>>) -> CDnsResult<QDns<LOC, TAP, MC>>
{
    // `Option::unwrap_or` evaluates its argument eagerly, which would query
    // the cache (and propagate its error) even for `Some(..)`.
    let resolvers =
        match resolvers
        {
            Some(custom) => custom,
            None => cache.clone_resolve_list().await?,
        };

    Ok(
        Self
        {
            resolvers,
            ordered_req_list: Vec::new(),
            opts,
            cache,
            _tap: PhantomData,
            _loc: PhantomData
        }
    )
}
/// Appends one request of type `qtype` for `req_name` to the end of the
/// request list.
///
/// # Errors
/// Returns the error produced by `QDnsReq::new_into` when `req_name` cannot
/// be converted; the request list is left untouched in that case.
pub
fn add_request<R>(&mut self, qtype: QType, req_name: R) -> CDnsResult<()>
where
    R: TryInto<QDnsName, Error = CDnsError>
{
    self.ordered_req_list.push(QDnsReq::new_into(req_name, qtype)?);

    Ok(())
}
pub async
fn make_a_aaaa_request<R: AsRef<str>>(resolvers_opt: Option<Arc<ResolveConfig>>, req_name_ref: R,
opts: QuerySetup, cache: Arc<CachesController<MC>>) -> CDnsResult<Self>
{
let req_n = QDnsName::try_from(req_name_ref.as_ref())?;
let resolvers = resolvers_opt.unwrap_or(cache.clone_resolve_list().await?);
let reqs: Vec<QDnsReq> =
match resolvers.family
{
ResolveConfigFamily::INET4_INET6 =>
{
vec![
QDnsReq::new(req_n.clone(), QType::A),
QDnsReq::new(req_n, QType::AAAA),
]
},
ResolveConfigFamily::INET6_INET4 =>
{
vec![
QDnsReq::new(req_n.clone(), QType::AAAA),
QDnsReq::new(req_n, QType::A),
]
},
ResolveConfigFamily::INET6 =>
{
vec![
QDnsReq::new(req_n, QType::AAAA),
]
},
ResolveConfigFamily::INET4 =>
{
vec![
QDnsReq::new(req_n, QType::A),
]
}
_ =>
{
vec![
QDnsReq::new(req_n.clone(), QType::A),
QDnsReq::new(req_n, QType::AAAA),
]
}
};
return Ok(
Self
{
resolvers: resolvers,
ordered_req_list: reqs,
opts: opts,
cache: cache,
_tap: PhantomData,
_loc: PhantomData
}
);
}
/// Executes all collected requests and returns whatever could be resolved.
///
/// Two sources are used, in the order given by the `lookup` setting of the
/// resolver configuration: the local host list (`lookup_file`) and the
/// nameservers (`process_request`). Requests answered by the first source are
/// removed from the pending list before the second one is asked.
///
/// A failure of the host-list lookup is reported through `write_error!` and
/// otherwise ignored; it never aborts the query.
pub async
fn query(mut self) -> QDnsQueryResult
{
    // Start time, taken only when measuring was requested.
    let now = self.opts.measure_time.then(Instant::now);
    let started = now.as_ref();

    if self.resolvers.lookup.is_file_first()
    {
        // Host list first ...
        let mut answers =
            match self.lookup_file(started).await
            {
                Ok(from_file) =>
                {
                    if !from_file.is_empty()
                    {
                        self.ordered_req_list.retain(|pending| !from_file.contains_dnsreq(pending));
                    }

                    from_file
                },
                Err(e) =>
                {
                    write_error!(e);

                    QDnsQueryResult::default()
                }
            };

        // ... then the nameservers for whatever is left, if allowed.
        let ask_nameservers =
            !self.ordered_req_list.is_empty() && self.resolvers.lookup.is_bind();

        if ask_nameservers
        {
            let from_dns = self.process_request(started).await;
            answers.extend(from_dns);
        }

        return answers;
    }

    // Nameservers first ...
    let mut answers = self.process_request(started).await;

    if !answers.is_empty()
    {
        self.ordered_req_list.retain(|pending| !answers.contains_dnsreq(pending));
    }

    // ... then the host list for whatever is left, if allowed.
    let ask_file =
        !self.ordered_req_list.is_empty() && self.resolvers.lookup.is_file();

    if ask_file
    {
        match self.lookup_file(started).await
        {
            Ok(from_file) =>
            {
                answers.extend(from_file);
            },
            Err(e) =>
            {
                write_error!(e);
            }
        }
    }

    answers
}
/// Returns the timeout for socket operations.
///
/// The per-query override from the options wins; otherwise the timeout of
/// the resolver configuration is used. Both values are taken as seconds.
fn get_timeout(&self) -> Duration
{
    let secs =
        match self.opts.timeout
        {
            Some(per_query) => per_query as u64,
            None => self.resolvers.timeout as u64,
        };

    Duration::from_secs(secs)
}
/// Tries to answer the pending requests from the cached host list.
///
/// Only `A`, `AAAA` and `PTR` requests are considered; everything else is
/// skipped. Requests without a matching entry are skipped as well, so the
/// result may contain fewer items than there are pending requests. The
/// pending list itself is not modified here.
///
/// Returns an empty result without touching the cache when the `ign_hosts`
/// option is set.
///
/// # Errors
/// Returns the error of `cache.clone_host_list()`.
async
fn lookup_file(&mut self, now: Option<&Instant>) -> CDnsResult<QDnsQueryResult>
{
    let mut local_answers: QDnsQueryResult = QDnsQueryResult::default();

    if self.opts.ign_hosts
    {
        return Ok(local_answers);
    }

    let hosts = self.cache.clone_host_list().await?;

    for req in &self.ordered_req_list
    {
        let qtype = *req.get_type();

        match qtype
        {
            QType::A | QType::AAAA =>
            {
                // forward lookup: name -> host entry
                let fqdn = String::from(req.get_req_name());

                if let Some(entry) = hosts.search_by_fqdn(req.get_type(), fqdn.as_str())
                {
                    if let Some(drp) = DnsResponsePayload::new_local(qtype, entry)
                    {
                        local_answers.push(req.clone(), Ok(QDnsQuery::from_local(drp, now)));
                    }
                }
            },
            QType::PTR =>
            {
                // reverse lookup: the request name has to convert to an IP address
                if let Ok(ip) = IpAddr::try_from(req.get_req_name())
                {
                    if let Some(entry) = hosts.search_by_ip(&ip)
                    {
                        if let Some(drp) = DnsResponsePayload::new_local(qtype, entry)
                        {
                            local_answers.push(req.clone(), Ok(QDnsQuery::from_local(drp, now)));
                        }
                    }
                }
            },
            _ => {}
        }
    }

    Ok(local_answers)
}
/// Creates a (not yet connected) socket tap for `resolver`.
///
/// Selection order:
/// 1. the resolver entry requests TLS - a TLS socket, or an error
///    (`SocketNotSupported`) when the crate was built without the
///    `use_async_tokio_tls` feature;
/// 2. TCP is forced by the resolver options or by `force_tcp` - a TCP socket;
/// 3. otherwise - a UDP socket.
///
/// Every socket is created with the timeout returned by `get_timeout`.
#[inline]
fn create_socket(&self, force_tcp: bool, resolver: Arc<ResolveConfEntry>) -> CDnsResult<Box<NetworkTapType<LOC>>>
{
    let timeout = self.get_timeout();

    if resolver.get_tls_type() != ConfigEntryTls::None
    {
        #[cfg(feature = "use_async_tokio_tls")]
        return TAP::new_tls_socket(resolver, timeout);

        #[cfg(not(feature = "use_async_tokio_tls"))]
        internal_error!(CDnsErrorType::SocketNotSupported, "compiled without TLS support");
    }

    let use_tcp = self.resolvers.option_flags.is_force_tcp() || force_tcp;

    if use_tcp
    {
        TAP::new_tcp_socket(resolver, timeout)
    }
    else
    {
        TAP::new_udp_socket(resolver, timeout)
    }
}
/// Sends the pending requests to the configured nameservers and collects
/// the results.
///
/// Two strategies exist, selected by the `no_parallel` resolver option:
///
/// * sequential - every request is sent on its own (`query_exec_seq`) to the
///   resolvers in order until one gives an answer that does not ask to
///   check the next nameserver. If no resolver gives such an answer, the
///   last "check next" answer is recorded, or - when there was none - the
///   first error. Nothing is recorded for a request when the resolver list
///   is empty.
/// * pipelined - all pending requests are sent together to one resolver
///   (`query_exec_pipelined`). Requests that got a final answer are removed
///   from the pending list and the remaining ones are sent to the next
///   resolver. Every received result is recorded. A resolver that fails as a
///   whole is reported through `write_error!` and skipped.
///
/// The function never fails; per-request errors are stored in the result.
async
fn process_request(&mut self, now: Option<&Instant>) -> QDnsQueryResult
{
    let mut responses: QDnsQueryResult = QDnsQueryResult::with_capacity(self.ordered_req_list.len());

    if self.resolvers.option_flags.is_no_parallel() == true
    {
        for req in self.ordered_req_list.iter()
        {
            // Best non-final outcome seen so far for this request.
            let mut last_resp: Option<CDnsResult<QDnsQuery>> = None;

            for resolver in self.resolvers.get_resolvers_iter()
            {
                match self.query_exec_seq(now, resolver.clone(), req, None).await
                {
                    Ok(resp) =>
                    {
                        if resp.should_check_next_ns() == true
                        {
                            last_resp = Some(Ok(resp));

                            continue;
                        }
                        else
                        {
                            // Final answer: record it and forget any
                            // earlier non-final outcome.
                            responses.push(req.clone(), Ok(resp));
                            last_resp = None;

                            break;
                        }
                    },
                    Err(e) =>
                    {
                        // Keep only the first error and never overwrite a
                        // "check next" answer with an error.
                        if last_resp.is_none() == true
                        {
                            last_resp = Some(Err(e));
                        }

                        continue;
                    }
                }
            }

            // `last_resp` is `None` after a final answer (already recorded)
            // and when there were no resolvers at all; unwrapping it
            // unconditionally would panic in both situations.
            if let Some(fallback) = last_resp
            {
                responses.push(req.clone(), fallback);
            }
        }
    }
    else
    {
        for resolver in self.resolvers.get_resolvers_iter()
        {
            if self.ordered_req_list.is_empty() == true
            {
                break;
            }

            match self.query_exec_pipelined(now, resolver.clone(), None).await
            {
                Ok(batch) =>
                {
                    for (answered_req, outcome) in batch
                    {
                        if let Ok(ref answer) = outcome
                        {
                            if answer.should_check_next_ns() == false
                            {
                                // Final answer: do not ask the next resolver.
                                self
                                    .ordered_req_list
                                    .retain(|req_item| req_item != &answered_req);
                            }
                        }

                        responses.push(answered_req, outcome);
                    }
                },
                Err(e) =>
                {
                    write_error!(e);

                    continue;
                }
            }
        }
    }

    return responses;
}
/// Sends a batch of requests to one resolver over a single socket and reads
/// answers until every sent header has been matched.
///
/// * `now` - optional start time, forwarded to `QDnsQuery::from_response`.
/// * `resolver` - the nameserver to talk to.
/// * `requery` - `None` for a normal run, in which case the batch is built
///   from `ordered_req_list`. `Some(map)` is passed by the recursive call
///   below: it holds the requests whose answer asked for a TCP retry, and
///   forces TCP for this run.
///
/// Returns the collected results. Any error raised with `?` (socket creation,
/// connect, send, read, answer verification, an answer that matches no sent
/// header) aborts the whole batch, and results collected so far are dropped.
#[async_recursion]
async
fn query_exec_pipelined(
&self,
now: Option<&Instant>,
resolver: Arc<ResolveConfEntry>,
requery: Option<HashMap<DnsRequestHeader<'async_recursion>, QDnsReq>>,
) -> CDnsResult<QDnsQueryResult>
{
// A retry run (`requery` is `Some`) always goes over TCP.
let force_tcp = self.resolvers.option_flags.is_force_tcp() || requery.is_some();
let mut query_headers: HashMap<DnsRequestHeader, QDnsReq> =
if let Some(requer) = requery
{
// Retry: give every header a new id that differs from all ids used in the previous run.
// NOTE(review): new ids are compared with the old ids only, not with each other - confirm that is intended.
let pkts_ids = requer.iter().map(|q| q.0.get_id()).collect::<BTreeSet<u16>>();
requer
.into_iter()
.map(
|(mut qrr, qdr)|
{
loop
{
qrr.regenerate_id();
if pkts_ids.contains(&qrr.get_id()) == false
{
break;
}
}
(qrr, qdr)
})
.collect::<HashMap<DnsRequestHeader, QDnsReq>>()
}
else
{
// Normal run: build one header per pending request, regenerating ids until each is unique within the batch.
let mut pkts_ids: BTreeSet<u16> = BTreeSet::new();
self
.ordered_req_list
.iter()
.map(
|query|
{
let mut drh_res = DnsRequestHeader::try_from(query);
loop
{
if let Ok(ref mut drh) = drh_res
{
if pkts_ids.contains(&drh.get_id()) == true
{
drh.regenerate_id();
continue;
}
else
{
pkts_ids.insert(drh.get_id());
break;
}
}
else
{
// Header could not be built; the error is propagated by `collect` below.
break;
}
}
drh_res.map(|dh| (dh, query.clone()))
}
)
.collect::<CDnsResult<HashMap<DnsRequestHeader, QDnsReq>>>()?
};
let mut tap = self.create_socket(force_tcp, resolver.clone())?;
tap.connect(CDdnsGlobals::get_tcp_conn_timeout()).await?;
// Send all requests before reading any answer.
for qh in query_headers.iter()
{
let pkt = qh.0.to_bytes(tap.should_append_len())?;
tap.send(pkt.as_slice()).await?;
}
let mut resp: QDnsQueryResult = QDnsQueryResult::with_capacity(self.ordered_req_list.len());
// Requests that have to be repeated over TCP.
let mut requery: HashMap<DnsRequestHeader, QDnsReq> = HashMap::new();
// Read until every sent header has been matched with an answer.
loop
{
if query_headers.len() == 0
{
break;
}
tap.poll_read().await?;
let ans = self.read_response(tap.as_mut()).await?;
// Match the answer to the request it belongs to; an unknown answer is an error.
let Some((query_header, qdnsreq)) = query_headers.remove_entry(&ans.req_header)
else
{
internal_error!(CDnsErrorType::IoError, "can not find response with request: {}", ans.req_header);
};
ans.verify(&query_header)?;
let qdns_resp = QDnsQuery::from_response(tap.get_remote_addr(), ans, now);
if let Ok(ref qdns) = qdns_resp
{
// Schedule a TCP retry only when this run was not already forced to TCP.
if qdns.get_status().should_try_tcp() == true && force_tcp == false
{
requery.insert(query_header, qdnsreq);
}
else
{
resp.push(qdnsreq, qdns_resp);
}
}
else
{
// A per-request failure is stored as the result of that request.
resp.push(qdnsreq, qdns_resp);
}
}
if requery.is_empty() == false
{
// Repeat the requests collected above; passing `Some(..)` forces TCP in the nested run.
let res = self.query_exec_pipelined(now, resolver, Some(requery)).await?;
resp.extend(res);
}
return Ok(resp);
}
/// Sends a single request to one resolver and returns its answer.
///
/// * `now` - optional start time, forwarded to `QDnsQuery::from_response`.
/// * `resolver` - the nameserver to talk to.
/// * `query` - the request to send.
/// * `requery` - `None` for the first attempt. `Some(header)` is passed by
///   the recursive call below: the header is reused with a regenerated id
///   and the attempt is forced to TCP.
///
/// When the answer's status asks for a TCP retry and this attempt was not
/// already forced to TCP, the request is repeated once over TCP. An answer
/// received while TCP is forced is always returned as is, so a server that
/// keeps asking for a retry cannot cause unbounded recursion. This matches
/// the `force_tcp == false` guard in `query_exec_pipelined`.
///
/// # Errors
/// Any failure (socket creation, header construction, connect, send, read,
/// verification, response conversion) is returned immediately; errors do not
/// trigger a TCP retry.
#[async_recursion]
async
fn query_exec_seq(
    &self,
    now: Option<&Instant>,
    resolver: Arc<ResolveConfEntry>,
    query: &QDnsReq,
    requery: Option<DnsRequestHeader<'async_recursion>>,
) -> CDnsResult<QDnsQuery>
{
    // A retry (`requery` is `Some`) always goes over TCP.
    let force_tcp = self.resolvers.option_flags.is_force_tcp() || requery.is_some();

    let mut tap = self.create_socket(force_tcp, resolver.clone())?;

    let query_header =
        match requery
        {
            Some(mut header) =>
            {
                header.regenerate_id();

                header
            },
            None => DnsRequestHeader::try_from(query)?,
        };

    tap.connect(CDdnsGlobals::get_tcp_conn_timeout()).await?;

    let pkt = query_header.to_bytes(tap.should_append_len())?;
    tap.send(pkt.as_slice()).await?;

    let ans = self.read_response(tap.as_mut()).await?;
    ans.verify(&query_header)?;

    let resp = QDnsQuery::from_response(tap.get_remote_addr(), ans, now)?;

    // Retry only when asked to and only if TCP has not been used yet.
    if resp.status.should_try_tcp() == false || force_tcp == true
    {
        return Ok(resp);
    }

    return
        self.query_exec_seq(now, resolver, query, Some(query_header)).await;
}
/// Reads one DNS message from `socktap` and parses it.
///
/// * Non-TCP taps: a single `recv` into a 1457 byte buffer; only the bytes
///   actually received are handed to the parser.
/// * TCP taps: the message is preceded by a two-byte big-endian length.
///   Both the prefix and the message body are read with repeated `recv`
///   calls until complete, because a stream may deliver them in pieces.
///
/// # Errors
/// * errors returned by `recv` or by `DnsRequestAnswer::parse`;
/// * `IoError` when the TCP peer announces an empty message or stops
///   sending before the prefix/body is complete.
async
fn read_response(&self, socktap: &mut dyn SocketTap<LOC>) -> CDnsResult<DnsRequestAnswer<'_>>
{
    if socktap.is_tcp() == false
    {
        let mut rcvbuf = vec![0_u8; 1457];

        let n = socktap.recv(rcvbuf.as_mut_slice()).await?;

        // Do not let the parser see the unused, zero-filled tail.
        rcvbuf.truncate(n);

        return DnsRequestAnswer::parse(&rcvbuf);
    }

    // --- TCP: length prefix ---
    let mut pkg_pen: [u8; 2] = [0, 0];
    let mut filled = 0;

    while filled < pkg_pen.len()
    {
        let n = socktap.recv(&mut pkg_pen[filled..]).await?;

        if n == 0
        {
            if filled == 0
            {
                internal_error!(CDnsErrorType::IoError, "tcp received zero len message!");
            }

            internal_error!(CDnsErrorType::IoError, "tcp expected 2 bytes to be read!");
        }

        filled += n;
    }

    let ln = u16::from_be_bytes(pkg_pen);

    if ln == 0
    {
        internal_error!(CDnsErrorType::IoError, "tcp received zero len message!");
    }

    // --- TCP: message body ---
    let mut rcvbuf = vec![0_u8; ln as usize];
    let mut filled = 0;

    while filled < rcvbuf.len()
    {
        let n = socktap.recv(&mut rcvbuf[filled..]).await?;

        if n == 0
        {
            if filled == 0
            {
                internal_error!(CDnsErrorType::IoError, "tcp received zero len message!");
            }

            internal_error!(CDnsErrorType::IoError, "tcp received zero len message again!");
        }

        filled += n;
    }

    return DnsRequestAnswer::parse(&rcvbuf);
}
}
#[cfg(feature = "use_async_tokio")]
#[cfg(test)]
mod tests
{
    use std::{net::IpAddr, sync::Arc};
    use crate::{a_sync::{query::QDns, CachesController}, common::{byte2hexchar, ip2pkt, RecordPTR, RecordReader}, QDnsQueryRec, QType, QuerySetup};

    /// `ip2pkt` must encode 8.8.8.8 as the labels `8.8.8.8.in-addr.arpa`.
    #[tokio::test]
    async fn test_ip2pkt()
    {
        use tokio::time::Instant;
        use std::net::{IpAddr, Ipv4Addr};

        let addr = IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8));
        let expected = b"\x01\x38\x01\x38\x01\x38\x01\x38\x07\x69\x6e\x2d\x61\x64\x64\x72\x04\x61\x72\x70\x61\x00";

        let started = Instant::now();
        let encoded = ip2pkt(&addr);
        let took = started.elapsed();

        println!("Elapsed: {:.2?}", took);

        assert_eq!(encoded.as_slice(), expected);
    }

    /// `byte2hexchar` must return the ASCII code of the lowercase hex digit.
    #[tokio::test]
    async fn test_byte2hexchar()
    {
        assert_eq!(byte2hexchar(1), b'1');
        assert_eq!(byte2hexchar(9), b'9');
        assert_eq!(byte2hexchar(10), b'a');
        assert_eq!(byte2hexchar(15), b'f');
    }

    /// PTR lookup of 8.8.8.8; expects a single `dns.google` record.
    /// NOTE(review): presumably needs working network access - confirm.
    #[tokio::test]
    async fn reverse_lookup_test()
    {
        use tokio::time::Instant;

        let cache = Arc::new(CachesController::new().await.unwrap());
        let target: IpAddr = "8.8.8.8".parse().unwrap();

        let mut setup = QuerySetup::default();
        setup.set_measure_time(true);

        let started = Instant::now();

        let mut request =
            QDns::builtin_make_empty(None, setup, cache).await.unwrap();
        request.add_request(QType::PTR, target).unwrap();

        let res = request.query().await;

        println!("Elapsed: {:.2?}", started.elapsed());
        println!("{}", res);

        assert!(!res.is_empty());

        let answers = res.collect_ok();
        let first = &answers[0];

        assert_eq!(first.status, QDnsQueryRec::Ok);
        assert_eq!(first.resp.len(), 1);
        assert_eq!(first.resp[0].rdata, RecordPTR::wrap(RecordPTR{ fqdn: "dns.google".to_string() }));
    }

    /// PTR lookup of 127.0.0.1; expects the answer to come from `/etc/hosts`
    /// and to name `localhost`.
    #[tokio::test]
    async fn reverse_lookup_hosts_test()
    {
        use tokio::time::Instant;

        let cache = Arc::new(CachesController::new().await.unwrap());
        let target: IpAddr = "127.0.0.1".parse().unwrap();

        let started = Instant::now();

        let mut setup = QuerySetup::default();
        setup.set_measure_time(true);

        let mut request =
            QDns::builtin_make_empty(None, setup, cache).await.unwrap();
        request.add_request(QType::PTR, target).unwrap();

        let res = request.query().await;

        println!("Elapsed: {:.2?}", started.elapsed());
        println!("{}", res);

        assert!(!res.is_empty());

        let answers = res.collect_ok();
        let first = &answers[0];

        assert_eq!(first.server.as_str(), "/etc/hosts");
        assert_eq!(first.status, QDnsQueryRec::Ok);
        assert_eq!(first.resp.len(), 1);
        assert_eq!(first.resp[0].rdata, RecordPTR::wrap(RecordPTR{ fqdn: "localhost".to_string() }));
    }

    /// Address (A/AAAA) lookup of `dns.google`; only prints the result.
    #[tokio::test]
    async fn reverse_lookup_a()
    {
        use tokio::time::Instant;

        let cache = Arc::new(CachesController::new().await.unwrap());

        let mut setup = QuerySetup::default();
        setup.set_measure_time(true);

        let request =
            QDns::buildin_make_a_aaaa_request(None, "dns.google", setup, cache).await.unwrap();

        let started = Instant::now();
        let res = request.query().await;

        println!("Elapsed: {:.2?}", started.elapsed());
        println!("{}", res);
    }

    /// Address lookup of an internationalized name; checks the punycode
    /// form in the answer and its decoding back to Unicode.
    #[tokio::test]
    async fn reverse_lookup_a_idn()
    {
        use std::time::Instant;

        let cache = Arc::new(CachesController::new().await.unwrap());

        let mut setup = QuerySetup::default();
        setup.set_measure_time(true);

        let request =
            QDns::buildin_make_a_aaaa_request(None, "законипорядок.бел", setup, cache).await.unwrap();

        let started = Instant::now();
        let res = request.query().await;

        println!("Elapsed: {:.2?}", started.elapsed());
        println!("{}", res);

        let ok: Vec<crate::QDnsQuery> = res.collect_ok();
        let first_record = &ok[0].get_responses()[0];

        let name = first_record.name.clone();
        let idn_name = first_record.get_full_domain_name_with_idn_decode().unwrap();

        println!("{} -> {}", name, idn_name);

        assert_eq!(name, "xn--80aihfjcshcbin9q.xn--90ais");
        assert_eq!(idn_name, "законипорядок.бел");
    }
}