#[cfg(feature = "smol-async")]
use crate::smol_async::{query_raw_tcp, query_raw_udp};
#[cfg(feature = "std-async")]
use crate::std_async::{query_raw_tcp, query_raw_udp};
#[cfg(feature = "sync")]
use crate::sync::{query_raw_tcp, query_raw_udp};
#[cfg(feature = "tokio-async")]
use crate::tokio_async::{query_raw_tcp, query_raw_udp};
use crate::{err::as_io_error, reverse::reverse_dns_query, tcp::tcp_query};
use dnssector::constants::{Class, Type};
use dnssector::*;
use std::{
io::{self, Error, ErrorKind},
net::{IpAddr, SocketAddr},
time::Duration,
};
/// Per-exchange timeout used by [`DNSClient::new`] until changed with
/// [`DNSClient::set_timeout`]; it is applied to each UDP query and to the
/// TCP fallback query sent to an upstream server.
pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);
/// A small stub resolver that forwards queries to a fixed list of upstream
/// DNS servers, trying them in order until one of them answers successfully.
#[derive(Clone, Debug)]
pub struct DNSClient {
/// Timeout handed to the UDP and TCP transports for every upstream exchange.
upstream_server_timeout: Duration,
/// Upstream servers, tried in list order.
upstream_servers: Vec<SocketAddr>,
/// Local address handed to the UDP transport when the upstream is IPv4
/// (defaults to `0.0.0.0:0`).
local_v4_addr: SocketAddr,
/// Local address handed to the UDP transport when the upstream is IPv6
/// (defaults to `[::]:0`).
local_v6_addr: SocketAddr,
}
impl DNSClient {
    /// Creates a client that queries `upstream_servers` in order.
    ///
    /// The client starts with [`DEFAULT_TIMEOUT`] and unspecified local
    /// addresses (`0.0.0.0:0` for IPv4 upstreams, `[::]:0` for IPv6 ones).
    pub fn new(upstream_servers: Vec<SocketAddr>) -> Self {
        DNSClient {
            upstream_server_timeout: DEFAULT_TIMEOUT,
            upstream_servers,
            local_v4_addr: ([0; 4], 0).into(),
            local_v6_addr: ([0; 16], 0).into(),
        }
    }

    /// Sets the timeout applied to each individual upstream exchange
    /// (the UDP query and, if needed, the TCP retry).
    pub fn set_timeout(&mut self, timeout: Duration) {
        self.upstream_server_timeout = timeout
    }

    /// Sets the local address used for UDP queries sent to IPv4 upstreams.
    pub fn set_local_v4_addr<T: Into<SocketAddr>>(&mut self, addr: T) {
        self.local_v4_addr = addr.into()
    }

    /// Sets the local address used for UDP queries sent to IPv6 upstreams.
    pub fn set_local_v6_addr<T: Into<SocketAddr>>(&mut self, addr: T) {
        self.local_v6_addr = addr.into()
    }

    /// Resolves `name` to the addresses found in its `A` answer records.
    ///
    /// Non-ASCII labels are punycode-encoded before the query is built.
    ///
    /// # Errors
    ///
    /// `InvalidInput` if the name cannot be encoded or turned into a query;
    /// transport errors from the last upstream tried; `InvalidData` if the
    /// response cannot be parsed or no answer record holds an address.
    #[maybe_async::maybe_async]
    pub async fn query_a(&self, name: &str) -> io::Result<Vec<IpAddr>> {
        let name = encode_name(name)?;
        let query = dnssector::gen::query(name.as_bytes(), Type::A, Class::IN)
            .map_err(as_io_error(ErrorKind::InvalidInput))?;
        let response = self.query(query).await?;
        extract_ips(response)
    }

    /// Resolves `name` to the addresses found in its `AAAA` answer records.
    ///
    /// Non-ASCII labels are punycode-encoded before the query is built.
    ///
    /// # Errors
    ///
    /// Same as [`DNSClient::query_a`].
    #[maybe_async::maybe_async]
    pub async fn query_aaaa(&self, name: &str) -> io::Result<Vec<IpAddr>> {
        let name = encode_name(name)?;
        let query = dnssector::gen::query(name.as_bytes(), Type::AAAA, Class::IN)
            .map_err(as_io_error(ErrorKind::InvalidInput))?;
        let response = self.query(query).await?;
        extract_ips(response)
    }

    /// Performs a reverse lookup of `ip` and returns the first name in the
    /// `PTR` answer section, with punycode labels decoded to Unicode.
    ///
    /// # Errors
    ///
    /// `NotFound` if the answer section is empty; otherwise query-building,
    /// transport, or decoding errors.
    #[maybe_async::maybe_async]
    pub async fn query_ptr(&self, ip: IpAddr) -> io::Result<String> {
        let in_addr = reverse_dns_query(ip);
        let query = dnssector::gen::query(&in_addr, Type::PTR, Class::IN)
            .map_err(as_io_error(ErrorKind::InvalidInput))?;
        let response = self.query(query).await?;
        // `extract_names` already reports an empty answer as `NotFound`; the
        // fallback keeps this path panic-free should that ever change.
        extract_names(response)?
            .into_iter()
            .next()
            .ok_or_else(|| ErrorKind::NotFound.into())
    }

    /// Returns the name servers listed in the `NS` answer section for
    /// `domain`, or an empty list when there are none.
    ///
    /// Like the address queries, non-ASCII labels are punycode-encoded first.
    ///
    /// # Errors
    ///
    /// Query-building, transport, or decoding errors (but not `NotFound`).
    #[maybe_async::maybe_async]
    pub async fn query_ns(&self, domain: &str) -> io::Result<Vec<String>> {
        // Encode IDN labels exactly like `query_a`/`query_aaaa` do; ASCII
        // names pass through unchanged.
        let domain = encode_name(domain)?;
        let query = dnssector::gen::query(domain.as_bytes(), Type::NS, Class::IN)
            .map_err(as_io_error(ErrorKind::InvalidInput))?;
        let response = self.query(query).await?;
        match extract_names(response) {
            Err(e) if e.kind() == ErrorKind::NotFound => Ok(Vec::new()),
            other => other,
        }
    }

    /// Sends `packet` to each upstream server in order and returns the first
    /// successful response, or the error reported by the last server.
    ///
    /// # Errors
    ///
    /// `InvalidInput` if no upstream servers are configured; otherwise the
    /// last server's error.
    #[maybe_async::maybe_async]
    async fn query(&self, packet: ParsedPacket) -> io::Result<ParsedPacket> {
        // NS and PTR answers carry domain names in RDATA, which servers may
        // compress (RFC 1035 §4.1.4). `parse_tlv_name` only understands plain
        // labels, so such responses are uncompressed before being parsed.
        let is_compressed = matches!(
            packet.qtype_qclass(),
            Some((rr_type, _class))
                if rr_type == Type::NS as u16 || rr_type == Type::PTR as u16
        );
        let raw_packet = packet.into_packet();
        let mut last_error = None;
        for upstream in &self.upstream_servers {
            match self
                .query_upstream(upstream, &raw_packet, is_compressed)
                .await
            {
                Ok(response) => return Ok(response),
                // Fall through to the next server; only the final failure is
                // reported to the caller.
                Err(e) => last_error = Some(e),
            }
        }
        Err(last_error.unwrap_or_else(|| {
            Error::new(
                ErrorKind::InvalidInput,
                "no upstream DNS servers configured",
            )
        }))
    }

    /// Queries a single upstream server.
    ///
    /// The query goes out over UDP from the local address matching the
    /// upstream's IP family. If the response has the TC (truncated) flag set,
    /// the same query is repeated over TCP and that response is returned.
    #[maybe_async::maybe_async]
    async fn query_upstream(
        &self,
        upstream: &SocketAddr,
        packet: &[u8],
        is_compressed_response: bool,
    ) -> io::Result<ParsedPacket> {
        let local_addr = match upstream {
            SocketAddr::V4(_) => &self.local_v4_addr,
            SocketAddr::V6(_) => &self.local_v6_addr,
        };
        let raw_response =
            query_raw_udp(local_addr, upstream, packet, self.upstream_server_timeout).await?;
        let response = parse_response(raw_response, is_compressed_response)?;
        if response.flags() & DNS_FLAG_TC != DNS_FLAG_TC {
            return Ok(response);
        }
        // Truncated over UDP: retry over TCP, which carries the full answer.
        let tcp_packet = tcp_query(packet);
        let raw_response =
            query_raw_tcp(upstream, &tcp_packet, self.upstream_server_timeout).await?;
        parse_response(raw_response, is_compressed_response)
    }
}
/// Parses a raw DNS response, first expanding compressed names when
/// `is_compressed` is set.
///
/// Every failure (uncompression, sector construction, parsing) is reported
/// as `ErrorKind::InvalidData`.
fn parse_response(raw: Vec<u8>, is_compressed: bool) -> io::Result<ParsedPacket> {
    let bytes = if is_compressed {
        Compress::uncompress(&raw).map_err(as_io_error(ErrorKind::InvalidData))?
    } else {
        raw
    };
    let sector = DNSSector::new(bytes).map_err(as_io_error(ErrorKind::InvalidData))?;
    sector.parse().map_err(as_io_error(ErrorKind::InvalidData))
}
/// Collects the IP address of every answer record that carries one.
///
/// Records whose address cannot be read are skipped. Only when *no* record
/// yields an address is the first such failure reported, as
/// `ErrorKind::InvalidData`. An empty answer section gives an empty list.
fn extract_ips(mut packet: ParsedPacket) -> io::Result<Vec<IpAddr>> {
    let mut found = Vec::new();
    let mut first_error = None;
    let mut cursor = packet.into_iter_answer();
    while let Some(record) = cursor {
        match record.rr_ip() {
            Ok(ip) => found.push(ip),
            Err(e) => {
                if first_error.is_none() {
                    first_error = Some(e);
                }
            }
        }
        cursor = record.next();
    }
    match first_error {
        Some(e) if found.is_empty() => Err(Error::new(ErrorKind::InvalidData, e)),
        _ => Ok(found),
    }
}
/// Reads the domain name stored in the RDATA of each answer record and
/// decodes it (including punycode labels) into a `String`.
///
/// # Errors
///
/// `ErrorKind::NotFound` when the answer section is empty; otherwise any
/// error from `decode_name`.
fn extract_names(mut packet: ParsedPacket) -> io::Result<Vec<String>> {
    let mut raw_names = Vec::new();
    let mut cursor = packet.into_iter_answer();
    while let Some(record) = cursor {
        // Skip the fixed RR header preceding the RDATA in this slice.
        raw_names.push(parse_tlv_name(&record.rdata_slice()[DNS_RR_HEADER_SIZE..]));
        cursor = record.next();
    }
    if raw_names.is_empty() {
        Err(ErrorKind::NotFound.into())
    } else {
        raw_names.iter().map(|raw| decode_name(raw)).collect()
    }
}
/// Turns length-prefixed DNS labels (`\x03www\x07example\x03com\x00`) into a
/// dotted byte string (`www.example.com`).
///
/// Processing stops at the first NUL byte, wherever it appears, or at the
/// end of the input. A label whose length exceeds the remaining bytes is
/// silently truncated.
fn parse_tlv_name(raw: &[u8]) -> Vec<u8> {
    let end = raw.iter().position(|&b| b == 0).unwrap_or(raw.len());
    let mut name = Vec::with_capacity(raw.len());
    let mut label_left: u8 = 0;
    for (pos, &byte) in raw[..end].iter().enumerate() {
        if label_left > 0 {
            name.push(byte);
            label_left -= 1;
            continue;
        }
        // `byte` is a label length; every label but the first is dot-separated.
        if pos != 0 {
            name.push(b'.');
        }
        label_left = byte;
    }
    name
}
fn encode_name(name: &str) -> io::Result<String> {
let parts: io::Result<Vec<String>> = name
.split('.')
.map(|part| {
if part.is_ascii() {
Ok(part.to_string())
} else {
unic_idna_punycode::encode_str(part)
.map(|s| "xn--".to_string() + &s)
.ok_or_else(|| ErrorKind::InvalidInput.into())
}
})
.collect();
let parts = parts?;
let ret = parts.join(".");
Ok(ret)
}
/// Converts a dotted name (as produced by `parse_tlv_name`) into a `String`,
/// decoding punycode (ACE) labels back to Unicode.
///
/// Labels starting with the ACE prefix `xn--` are punycode-decoded. The
/// prefix is matched case-insensitively, as IDNA requires. All other labels
/// must be valid UTF-8 and are kept verbatim.
///
/// # Errors
///
/// `ErrorKind::InvalidData` when a label is not valid UTF-8 or not valid
/// punycode.
fn decode_name(name: &[u8]) -> io::Result<String> {
    const ACE_PREFIX: &[u8] = b"xn--";
    let parts: io::Result<Vec<String>> = name
        .split(|ch| *ch == b'.')
        .map(|part| -> io::Result<String> {
            // DNS names are case-insensitive, so `XN--` / `Xn--` also mark
            // an ACE label.
            let has_ace_prefix = part.len() >= ACE_PREFIX.len()
                && part[..ACE_PREFIX.len()].eq_ignore_ascii_case(ACE_PREFIX);
            if has_ace_prefix {
                let code = String::from_utf8(part[ACE_PREFIX.len()..].to_vec())
                    .map_err(as_io_error(ErrorKind::InvalidData))?;
                unic_idna_punycode::decode_to_string(&code)
                    .ok_or_else(|| ErrorKind::InvalidData.into())
            } else {
                String::from_utf8(part.to_vec()).map_err(as_io_error(ErrorKind::InvalidData))
            }
        })
        .collect();
    Ok(parts?.join("."))
}
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(not(feature = "sync"))]
    use std::future::Future;
    use std::{
        net::{Ipv4Addr, Ipv6Addr},
        str::FromStr,
    };

    // Fixtures for the network-backed tests below (they talk to live
    // resolvers, so their outcome depends on external DNS data).
    const EXAMPLE_FQDN: &str = "one.one.one.one";
    const EXAMPLE_DOMAIN: &str = "one.one.one";
    const EXAMPLE_DOMAIN_NS: &str = "ns.cloudflare.com";
    const EXAMPLE_IPV4: IpAddr = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
    const EXAMPLE_IPV6: IpAddr =
        IpAddr::V6(Ipv6Addr::new(0x2606, 0x4700, 0x4700, 0, 0, 0, 0, 0x1111));
    const EXAMPLE_IDN: &str = "日本.icom.museum";
    const EXAMPLE_IDN_PUNYCODE: &str = "xn--wgv71a.icom.museum";
    const EXAMPLE_IDN_IP: IpAddr = IpAddr::V4(Ipv4Addr::new(81, 201, 190, 55));

    // One executor per async backend; the `block_on!` macro below hides the
    // difference between async and sync builds.
    #[cfg(feature = "std-async")]
    fn block_on<F: Future>(future: F) -> F::Output {
        use async_std::task;
        task::block_on(future)
    }
    #[cfg(feature = "smol-async")]
    fn block_on<F: Future>(future: F) -> F::Output {
        smol::block_on(future)
    }
    #[cfg(feature = "tokio-async")]
    fn block_on<F: Future>(future: F) -> F::Output {
        use tokio::runtime;
        let rt = runtime::Builder::new_current_thread()
            .enable_time()
            .enable_io()
            .build()
            .unwrap();
        rt.block_on(future)
    }
    #[cfg(not(feature = "sync"))]
    macro_rules! block_on {
        ($b:expr) => {
            block_on(async move { $b.await })
        };
    }
    #[cfg(feature = "sync")]
    macro_rules! block_on {
        ($b:expr) => {
            $b
        };
    }

    fn dns_servers() -> Vec<SocketAddr> {
        vec![
            SocketAddr::from_str("1.0.0.1:53").unwrap(),
            SocketAddr::from_str("1.1.1.1:53").unwrap(),
        ]
    }

    // Servers expected to be unable to answer within the tiny timeout used
    // by `query_timeout`.
    fn slow_dns_servers() -> Vec<SocketAddr> {
        vec![
            SocketAddr::from_str("109.75.41.201:53").unwrap(),
            SocketAddr::from_str("124.99.9.4:53").unwrap(),
        ]
    }

    #[test]
    fn query_a() {
        let dns_client = DNSClient::new(dns_servers());
        let r = block_on!(dns_client.query_a(EXAMPLE_FQDN)).unwrap();
        let expected = EXAMPLE_IPV4;
        assert!(r.contains(&expected), "Expected {} got {:?}", expected, r);
    }

    #[test]
    fn query_timeout() {
        let mut dns_client = DNSClient::new(slow_dns_servers());
        dns_client.set_timeout(Duration::from_millis(1));
        let r = block_on!(dns_client.query_a(EXAMPLE_FQDN));
        assert!(
            matches!(&r, Err(e) if e.kind() == ErrorKind::TimedOut || e.kind() == ErrorKind::WouldBlock),
            "Expected timeout got {:?}",
            r,
        );
    }

    #[test]
    fn query_utf8() {
        let dns_client = DNSClient::new(dns_servers());
        let jp_res = block_on!(dns_client.query_a(EXAMPLE_IDN)).unwrap();
        let expected = EXAMPLE_IDN_IP;
        assert!(
            jp_res.contains(&expected),
            "Expected {} for {} got {:?}",
            expected,
            EXAMPLE_IDN,
            jp_res
        );
    }

    #[test]
    fn query_aaaa() {
        let dns_client = DNSClient::new(dns_servers());
        let r = block_on!(dns_client.query_aaaa(EXAMPLE_FQDN)).unwrap();
        let expected = EXAMPLE_IPV6;
        assert!(r.contains(&expected), "Expected {} got {:?}", expected, r);
    }

    #[test]
    fn query_ptr_ipv4() {
        let dns_client = DNSClient::new(dns_servers());
        let r = block_on!(dns_client.query_ptr(EXAMPLE_IPV4)).unwrap();
        let expected = EXAMPLE_FQDN;
        assert!(r == expected, "Expected {} got {:?}", expected, r);
    }

    #[test]
    fn query_ptr_ipv6() {
        let dns_client = DNSClient::new(dns_servers());
        let r = block_on!(dns_client.query_ptr(EXAMPLE_IPV6)).unwrap();
        let expected = EXAMPLE_FQDN;
        assert!(r == expected, "Expected {} got {:?}", expected, r);
    }

    // Offline: exercises only the punycode decoding used for PTR answers.
    #[test]
    fn query_ptr_utf8() {
        let r = decode_name(EXAMPLE_IDN_PUNYCODE.as_bytes()).unwrap();
        let expected = EXAMPLE_IDN;
        assert!(r == expected, "Expected {} got {:?}", expected, r);
    }

    // Offline: length-prefixed labels are joined with dots and parsing stops
    // at the terminating zero-length label.
    #[test]
    fn parse_tlv_name_joins_labels() {
        let raw = b"\x03ns1\x0acloudflare\x03com\x00";
        assert_eq!(parse_tlv_name(raw), b"ns1.cloudflare.com".to_vec());
        assert_eq!(parse_tlv_name(b"\x00"), Vec::<u8>::new());
    }

    #[test]
    fn query_ns() {
        let dns_client = DNSClient::new(dns_servers());
        let r = block_on!(dns_client.query_ns(EXAMPLE_DOMAIN)).unwrap();
        assert!(
            r.iter().any(|n| n.ends_with(EXAMPLE_DOMAIN_NS)),
            "Expected {} got {:?}",
            EXAMPLE_DOMAIN_NS,
            r
        );
    }
}