use std::io;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::time::Duration;
use dnssector::constants::{Class, Type};
use dnssector::*;
use rand::{seq::SliceRandom, Rng};
use crate::backend::sync::SyncBackend;
use crate::upstream_server::UpstreamServer;
/// DNS client that sends queries to a list of upstream servers through a
/// `SyncBackend`, over UDP or TCP.
#[derive(Clone, Debug)]
pub struct DNSClient {
/// Transport performing the UDP/TCP exchanges; it also carries the
/// upstream server timeout (see `set_timeout`).
backend: SyncBackend,
/// Upstream servers, tried in order until one returns an acceptable response.
upstream_servers: Vec<UpstreamServer>,
/// Local address handed to the backend when the upstream server address is IPv4.
local_v4_addr: SocketAddr,
/// Local address handed to the backend when the upstream server address is IPv6.
local_v6_addr: SocketAddr,
/// When `true`, queries are sent over TCP directly instead of trying UDP first.
force_tcp: bool,
}
impl DNSClient {
pub fn new(upstream_servers: Vec<UpstreamServer>) -> Self {
DNSClient {
backend: SyncBackend::new(Duration::new(6, 0)),
upstream_servers,
local_v4_addr: ([0; 4], 0).into(),
local_v6_addr: ([0; 16], 0).into(),
force_tcp: false,
}
}
#[cfg(unix)]
pub fn new_with_system_resolvers() -> Result<Self, io::Error> {
Ok(DNSClient::new(crate::system::default_resolvers()?))
}
/// Sets the upstream server timeout stored in the backend.
pub fn set_timeout(&mut self, timeout: Duration) {
    let backend = &mut self.backend;
    backend.upstream_server_timeout = timeout;
}
/// Sets the local address handed to the backend for exchanges with
/// upstream servers whose address is IPv4.
pub fn set_local_v4_addr<T: Into<SocketAddr>>(&mut self, addr: T) {
    let local_addr: SocketAddr = addr.into();
    self.local_v4_addr = local_addr;
}
/// Sets the local address handed to the backend for exchanges with
/// upstream servers whose address is IPv6.
pub fn set_local_v6_addr<T: Into<SocketAddr>>(&mut self, addr: T) {
    let local_addr: SocketAddr = addr.into();
    self.local_v6_addr = local_addr;
}
/// Chooses the transport policy: when `force_tcp` is `true`, queries go
/// straight to TCP; otherwise UDP is tried first and TCP is only used
/// when the UDP response has the truncation flag set.
pub fn force_tcp(&mut self, force_tcp: bool) {
    let enabled = force_tcp;
    self.force_tcp = enabled;
}
fn send_query_to_upstream_server(
&self,
upstream_server: &UpstreamServer,
query_tid: u16,
query_question: &Option<(Vec<u8>, u16, u16)>,
query: &[u8],
) -> Result<ParsedPacket, io::Error> {
let local_addr = match upstream_server.addr {
SocketAddr::V4(_) => &self.local_v4_addr,
SocketAddr::V6(_) => &self.local_v6_addr,
};
let response = if self.force_tcp {
self.backend
.dns_exchange_tcp(local_addr, upstream_server, query)?
} else {
self.backend
.dns_exchange_udp(local_addr, upstream_server, query)?
};
let mut parsed_response = DNSSector::new(response)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?
.parse()
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
if !self.force_tcp && parsed_response.flags() & DNS_FLAG_TC == DNS_FLAG_TC {
parsed_response = {
let response = self
.backend
.dns_exchange_tcp(local_addr, upstream_server, query)?;
DNSSector::new(response)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?
.parse()
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?
};
}
if parsed_response.tid() != query_tid || &parsed_response.question() != query_question {
return Err(io::Error::new(
io::ErrorKind::PermissionDenied,
"Unexpected response",
));
}
Ok(parsed_response)
}
/// Sends an already-parsed query to the upstream servers, in order, and
/// returns the first response that a server delivers successfully.
///
/// Per-server failures are discarded; the next server is simply tried.
///
/// # Errors
///
/// * `InvalidInput` ("No DNS question") if the packet has no question or
///   has the QR flag set.
/// * `InvalidInput` ("No response received from any servers") if every
///   upstream server failed (or none is configured).
fn query_from_parsed_query(
    &self,
    mut parsed_query: ParsedPacket,
) -> Result<ParsedPacket, io::Error> {
    let query_tid = parsed_query.tid();
    let query_question = parsed_query.question();
    let is_plain_query =
        query_question.is_some() && parsed_query.flags() & DNS_FLAG_QR == 0;
    if !is_plain_query {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "No DNS question",
        ));
    }
    let valid_query = parsed_query.into_packet();
    self.upstream_servers
        .iter()
        .find_map(|upstream_server| {
            self.send_query_to_upstream_server(
                upstream_server,
                query_tid,
                &query_question,
                &valid_query,
            )
            .ok()
        })
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "No response received from any servers",
            )
        })
}
/// Sends a raw, wire-format DNS query and returns the raw response.
///
/// When `tid_masking` is `true`, the query's transaction ID is replaced by
/// a random one before sending, and the original ID is restored in the
/// returned response.
///
/// # Errors
///
/// `InvalidInput` if `query` cannot be parsed; otherwise any error from
/// the upstream query.
pub fn query_raw(&self, query: &[u8], tid_masking: bool) -> Result<Vec<u8>, io::Error> {
    let mut parsed_query = DNSSector::new(query.to_vec())
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?
        .parse()
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;

    // Remember the caller's transaction ID only when it is being masked.
    let original_tid = if tid_masking {
        let caller_tid = parsed_query.tid();
        let masked_tid: u16 = rand::rng().random();
        parsed_query.set_tid(masked_tid);
        Some(caller_tid)
    } else {
        None
    };

    let mut parsed_response = self.query_from_parsed_query(parsed_query)?;
    if let Some(caller_tid) = original_tid {
        parsed_response.set_tid(caller_tid);
    }
    Ok(parsed_response.into_packet())
}
/// Resolves `name` to its IPv4 addresses with an `IN A` query.
///
/// The IPv4 addresses found in the answer section are returned in random
/// order. Answer records that do not yield an IPv4 address are skipped.
///
/// # Errors
///
/// `InvalidInput` if the query cannot be built; otherwise any error from
/// the upstream query.
pub fn query_a(&self, name: &str) -> Result<Vec<Ipv4Addr>, io::Error> {
    let rr_type = Type::from_string("A").unwrap();
    let rr_class = Class::from_string("IN").unwrap();
    let parsed_query = dnssector::gen::query(name.as_bytes(), rr_type, rr_class)
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
    let mut parsed_response = self.query_from_parsed_query(parsed_query)?;

    let mut addrs = Vec::new();
    let mut cursor = parsed_response.into_iter_answer();
    while let Some(rr) = cursor {
        if let Ok(IpAddr::V4(v4)) = rr.rr_ip() {
            addrs.push(v4);
        }
        cursor = rr.next();
    }

    addrs.shuffle(&mut rand::rng());
    Ok(addrs)
}
/// Resolves `name` to its IPv6 addresses with an `IN AAAA` query.
///
/// The IPv6 addresses found in the answer section are returned in random
/// order. Answer records that do not yield an IPv6 address are skipped.
///
/// # Errors
///
/// `InvalidInput` if the query cannot be built; otherwise any error from
/// the upstream query.
pub fn query_aaaa(&self, name: &str) -> Result<Vec<Ipv6Addr>, io::Error> {
    let rr_type = Type::from_string("AAAA").unwrap();
    let rr_class = Class::from_string("IN").unwrap();
    let parsed_query = dnssector::gen::query(name.as_bytes(), rr_type, rr_class)
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
    let mut parsed_response = self.query_from_parsed_query(parsed_query)?;

    let mut addrs = Vec::new();
    let mut cursor = parsed_response.into_iter_answer();
    while let Some(rr) = cursor {
        if let Ok(IpAddr::V6(v6)) = rr.rr_ip() {
            addrs.push(v6);
        }
        cursor = rr.next();
    }

    addrs.shuffle(&mut rand::rng());
    Ok(addrs)
}
pub fn query_addrs(&self, name: &str) -> Result<Vec<IpAddr>, io::Error> {
let ipv4_ips = self.query_a(name)?;
let ipv6_ips = self.query_aaaa(name)?;
let mut ips: Vec<_> = ipv4_ips
.into_iter()
.map(IpAddr::from)
.chain(ipv6_ips.into_iter().map(IpAddr::from))
.collect();
ips.shuffle(&mut rand::rng());
Ok(ips)
}
/// Queries the `IN TXT` records of `name`.
///
/// Each returned element is the content of one TXT record: the RDATA is
/// read as a sequence of length-prefixed strings, which are concatenated.
/// Answer records of another class or type are ignored.
///
/// # Errors
///
/// * `InvalidInput` if the query cannot be built.
/// * `InvalidInput` ("Invalid text record") if a length prefix runs past
///   the end of the record data.
/// * Any error from the upstream query.
pub fn query_txt(&self, name: &str) -> Result<Vec<Vec<u8>>, io::Error> {
    let rr_class = Class::from_string("IN").unwrap();
    let rr_type = Type::from_string("TXT").unwrap();
    let parsed_query = dnssector::gen::query(name.as_bytes(), rr_type, rr_class)
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
    let mut parsed_response = self.query_from_parsed_query(parsed_query)?;

    let mut txts: Vec<Vec<u8>> = Vec::new();
    let mut cursor = parsed_response.into_iter_answer();
    while let Some(rr) = cursor {
        let is_wanted = rr.rr_class() == rr_class.into() && rr.rr_type() == rr_type.into();
        if is_wanted {
            if let Ok(RawRRData::Data(data)) = rr.rr_rd() {
                let mut txt = Vec::new();
                let mut bytes = data.iter();
                // Each chunk is one length byte followed by that many bytes.
                while let Some(&len) = bytes.next() {
                    for _ in 0..len {
                        match bytes.next() {
                            Some(&byte) => txt.push(byte),
                            None => {
                                return Err(io::Error::new(
                                    io::ErrorKind::InvalidInput,
                                    "Invalid text record",
                                ))
                            }
                        }
                    }
                }
                txts.push(txt);
            }
        }
        cursor = rr.next();
    }
    Ok(txts)
}
pub fn query_ptr(&self, ip: &IpAddr) -> Result<Vec<String>, io::Error> {
let rr_class = Class::from_string("IN").unwrap();
let rr_type = Type::from_string("PTR").unwrap();
let rev_name = match ip {
IpAddr::V4(ip) => {
let mut octets = ip.octets();
octets.reverse();
format!(
"{}.{}.{}.{}.in-addr.arpa",
octets[0], octets[1], octets[2], octets[3]
)
}
IpAddr::V6(ip) => {
let mut octets = ip.octets();
octets.reverse();
let rev = octets
.iter()
.map(|x| x.to_string())
.collect::<Vec<_>>()
.join(".");
format!("{}.ip6.arpa", rev)
}
};
let parsed_query = dnssector::gen::query(rev_name.as_bytes(), rr_type, rr_class)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
let mut parsed_response = self.query_from_parsed_query(parsed_query)?;
let mut names: Vec<String> = vec![];
let mut it = parsed_response.into_iter_answer();
while let Some(item) = it {
if item.rr_class() != rr_class.into() || item.rr_type() != rr_type.into() {
it = item.next();
continue;
}
if let Ok(RawRRData::Data(data)) = item.rr_rd() {
let mut name = vec![];
let mut it = data.iter();
while let Some(&len) = it.next() {
if len != 0 && !name.is_empty() {
name.push(b'.');
}
for _ in 0..len {
name.push(*it.next().ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "Invalid text record")
})?)
}
}
if name.is_empty() {
name.push(b'.');
}
if let Ok(name) = String::from_utf8(name) {
match ip {
IpAddr::V4(ip) => {
if self.query_a(&name)?.contains(ip) {
names.push(name)
}
}
IpAddr::V6(ip) => {
if self.query_aaaa(&name)?.contains(ip) {
names.push(name)
}
}
};
}
}
it = item.next();
}
Ok(names)
}
/// Queries `name` for an arbitrary class and type, given by their textual
/// names, and returns the raw RDATA of every matching answer record.
///
/// Answer records whose class or type differs from the requested ones are
/// ignored, as are records whose data is not exposed as raw bytes.
///
/// # Errors
///
/// `InvalidInput` if `query_class` or `query_type` is not recognized or the
/// query cannot be built; otherwise any error from the upstream query.
pub fn query_rrs_data(
    &self,
    name: &str,
    query_class: &str,
    query_type: &str,
) -> Result<Vec<Vec<u8>>, io::Error> {
    let rr_class = Class::from_string(query_class)
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
    let rr_type = Type::from_string(query_type)
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
    let parsed_query = dnssector::gen::query(name.as_bytes(), rr_type, rr_class)
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
    let mut parsed_response = self.query_from_parsed_query(parsed_query)?;

    let mut raw_rrs = Vec::new();
    let mut cursor = parsed_response.into_iter_answer();
    while let Some(rr) = cursor {
        let is_wanted = rr.rr_class() == rr_class.into() && rr.rr_type() == rr_type.into();
        if is_wanted {
            if let Ok(RawRRData::Data(data)) = rr.rr_rd() {
                raw_rrs.push(data.to_vec());
            }
        }
        cursor = rr.next();
    }
    Ok(raw_rrs)
}
}
// Network test: resolves one.one.one.one and expects 1.1.1.1 among the results.
#[test]
fn test_query_a() {
    // Use the system resolvers when available, else 1.0.0.1 and 1.1.1.1.
    let upstream_servers = crate::system::default_resolvers().unwrap_or_else(|_| {
        ["1.0.0.1:53", "1.1.1.1:53"]
            .iter()
            .map(|addr| UpstreamServer::new(addr.parse::<SocketAddr>().unwrap()))
            .collect::<Vec<_>>()
    });
    let dns_client = DNSClient::new(upstream_servers);
    let resolved = dns_client.query_a("one.one.one.one").unwrap();
    let expected = Ipv4Addr::new(1, 1, 1, 1);
    assert!(resolved.contains(&expected));
}
// Network test: the reverse lookup of 1.1.1.1 must yield one.one.one.one first.
#[test]
fn test_query_ptr() {
    // Use the system resolvers when available, else 1.0.0.1 and 1.1.1.1.
    let upstream_servers = crate::system::default_resolvers().unwrap_or_else(|_| {
        ["1.0.0.1:53", "1.1.1.1:53"]
            .iter()
            .map(|addr| UpstreamServer::new(addr.parse::<SocketAddr>().unwrap()))
            .collect::<Vec<_>>()
    });
    let dns_client = DNSClient::new(upstream_servers);
    let target: IpAddr = "1.1.1.1".parse().unwrap();
    let r = dns_client.query_ptr(&target).unwrap();
    assert_eq!(r[0], "one.one.one.one");
}