1use std::io;
2use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
3use std::time::Duration;
4
5use dnssector::constants::{Class, Type};
6use dnssector::*;
7use rand::{seq::SliceRandom, Rng};
8
9use crate::backend::sync::SyncBackend;
10use crate::upstream_server::UpstreamServer;
11
/// DNS client that forwards queries to a list of upstream servers through a
/// `SyncBackend`, trying each server in order until one gives a valid answer.
#[derive(Clone, Debug)]
pub struct DNSClient {
    /// Transport used for the UDP/TCP exchanges with upstream servers.
    backend: SyncBackend,
    /// Upstream servers, tried in order by every query.
    upstream_servers: Vec<UpstreamServer>,
    /// Local address to send from when the upstream server is IPv4.
    local_v4_addr: SocketAddr,
    /// Local address to send from when the upstream server is IPv6.
    local_v6_addr: SocketAddr,
    /// When true, always use TCP; otherwise use UDP and retry over TCP only
    /// if the UDP response is truncated.
    force_tcp: bool,
}
20
21impl DNSClient {
22 pub fn new(upstream_servers: Vec<UpstreamServer>) -> Self {
23 DNSClient {
24 backend: SyncBackend::new(Duration::new(6, 0)),
25 upstream_servers,
26 local_v4_addr: ([0; 4], 0).into(),
27 local_v6_addr: ([0; 16], 0).into(),
28 force_tcp: false,
29 }
30 }
31
32 #[cfg(unix)]
33 pub fn new_with_system_resolvers() -> Result<Self, io::Error> {
34 Ok(DNSClient::new(crate::system::default_resolvers()?))
35 }
36
37 pub fn set_timeout(&mut self, timeout: Duration) {
38 self.backend.upstream_server_timeout = timeout
39 }
40
41 pub fn set_local_v4_addr<T: Into<SocketAddr>>(&mut self, addr: T) {
42 self.local_v4_addr = addr.into()
43 }
44
45 pub fn set_local_v6_addr<T: Into<SocketAddr>>(&mut self, addr: T) {
46 self.local_v6_addr = addr.into()
47 }
48
49 pub fn force_tcp(&mut self, force_tcp: bool) {
50 self.force_tcp = force_tcp;
51 }
52
53 fn send_query_to_upstream_server(
54 &self,
55 upstream_server: &UpstreamServer,
56 query_tid: u16,
57 query_question: &Option<(Vec<u8>, u16, u16)>,
58 query: &[u8],
59 ) -> Result<ParsedPacket, io::Error> {
60 let local_addr = match upstream_server.addr {
61 SocketAddr::V4(_) => &self.local_v4_addr,
62 SocketAddr::V6(_) => &self.local_v6_addr,
63 };
64 let response = if self.force_tcp {
65 self.backend
66 .dns_exchange_tcp(local_addr, upstream_server, query)?
67 } else {
68 self.backend
69 .dns_exchange_udp(local_addr, upstream_server, query)?
70 };
71 let mut parsed_response = DNSSector::new(response)
72 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?
73 .parse()
74 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
75 if !self.force_tcp && parsed_response.flags() & DNS_FLAG_TC == DNS_FLAG_TC {
76 parsed_response = {
77 let response = self
78 .backend
79 .dns_exchange_tcp(local_addr, upstream_server, query)?;
80 DNSSector::new(response)
81 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?
82 .parse()
83 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?
84 };
85 }
86 if parsed_response.tid() != query_tid || &parsed_response.question() != query_question {
87 return Err(io::Error::new(
88 io::ErrorKind::PermissionDenied,
89 "Unexpected response",
90 ));
91 }
92 Ok(parsed_response)
93 }
94
95 fn query_from_parsed_query(
96 &self,
97 mut parsed_query: ParsedPacket,
98 ) -> Result<ParsedPacket, io::Error> {
99 let query_tid = parsed_query.tid();
100 let query_question = parsed_query.question();
101 if query_question.is_none() || parsed_query.flags() & DNS_FLAG_QR != 0 {
102 return Err(io::Error::new(
103 io::ErrorKind::InvalidInput,
104 "No DNS question",
105 ));
106 }
107 let valid_query = parsed_query.into_packet();
108 for upstream_server in &self.upstream_servers {
109 if let Ok(parsed_response) = self.send_query_to_upstream_server(
110 upstream_server,
111 query_tid,
112 &query_question,
113 &valid_query,
114 ) {
115 return Ok(parsed_response);
116 }
117 }
118 Err(io::Error::new(
119 io::ErrorKind::InvalidInput,
120 "No response received from any servers",
121 ))
122 }
123
124 pub fn query_raw(&self, query: &[u8], tid_masking: bool) -> Result<Vec<u8>, io::Error> {
126 let mut parsed_query = DNSSector::new(query.to_vec())
127 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?
128 .parse()
129 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
130 let mut tid = 0;
131 if tid_masking {
132 tid = parsed_query.tid();
133 let mut rnd = rand::rng();
134 let masked_tid: u16 = rnd.random();
135 parsed_query.set_tid(masked_tid);
136 }
137 let mut parsed_response = self.query_from_parsed_query(parsed_query)?;
138 if tid_masking {
139 parsed_response.set_tid(tid);
140 }
141 let response = parsed_response.into_packet();
142 Ok(response)
143 }
144
145 pub fn query_a(&self, name: &str) -> Result<Vec<Ipv4Addr>, io::Error> {
147 let parsed_query = dnssector::gen::query(
148 name.as_bytes(),
149 Type::from_string("A").unwrap(),
150 Class::from_string("IN").unwrap(),
151 )
152 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
153 let mut parsed_response = self.query_from_parsed_query(parsed_query)?;
154 let mut ips = vec![];
155 {
156 let mut it = parsed_response.into_iter_answer();
157 while let Some(item) = it {
158 if let Ok(IpAddr::V4(addr)) = item.rr_ip() {
159 ips.push(addr);
160 }
161 it = item.next();
162 }
163 }
164 ips.shuffle(&mut rand::rng());
165 Ok(ips)
166 }
167
168 pub fn query_aaaa(&self, name: &str) -> Result<Vec<Ipv6Addr>, io::Error> {
170 let parsed_query = dnssector::gen::query(
171 name.as_bytes(),
172 Type::from_string("AAAA").unwrap(),
173 Class::from_string("IN").unwrap(),
174 )
175 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
176 let mut parsed_response = self.query_from_parsed_query(parsed_query)?;
177 let mut ips = vec![];
178 {
179 let mut it = parsed_response.into_iter_answer();
180 while let Some(item) = it {
181 if let Ok(IpAddr::V6(addr)) = item.rr_ip() {
182 ips.push(addr);
183 }
184 it = item.next();
185 }
186 }
187 ips.shuffle(&mut rand::rng());
188 Ok(ips)
189 }
190
191 pub fn query_addrs(&self, name: &str) -> Result<Vec<IpAddr>, io::Error> {
193 let ipv4_ips = self.query_a(name)?;
194 let ipv6_ips = self.query_aaaa(name)?;
195 let mut ips: Vec<_> = ipv4_ips
196 .into_iter()
197 .map(IpAddr::from)
198 .chain(ipv6_ips.into_iter().map(IpAddr::from))
199 .collect();
200 ips.shuffle(&mut rand::rng());
201 Ok(ips)
202 }
203
204 pub fn query_txt(&self, name: &str) -> Result<Vec<Vec<u8>>, io::Error> {
206 let rr_class = Class::from_string("IN").unwrap();
207 let rr_type = Type::from_string("TXT").unwrap();
208 let parsed_query = dnssector::gen::query(name.as_bytes(), rr_type, rr_class)
209 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
210 let mut parsed_response = self.query_from_parsed_query(parsed_query)?;
211 let mut txts: Vec<Vec<u8>> = vec![];
212
213 let mut it = parsed_response.into_iter_answer();
214 while let Some(item) = it {
215 if item.rr_class() != rr_class.into() || item.rr_type() != rr_type.into() {
216 it = item.next();
217 continue;
218 }
219 if let Ok(RawRRData::Data(data)) = item.rr_rd() {
220 let mut txt = vec![];
221 let mut it = data.iter();
222 while let Some(&len) = it.next() {
223 for _ in 0..len {
224 txt.push(*it.next().ok_or_else(|| {
225 io::Error::new(io::ErrorKind::InvalidInput, "Invalid text record")
226 })?)
227 }
228 }
229 txts.push(txt);
230 }
231 it = item.next();
232 }
233 Ok(txts)
234 }
235
236 pub fn query_ptr(&self, ip: &IpAddr) -> Result<Vec<String>, io::Error> {
238 let rr_class = Class::from_string("IN").unwrap();
239 let rr_type = Type::from_string("PTR").unwrap();
240 let rev_name = match ip {
241 IpAddr::V4(ip) => {
242 let mut octets = ip.octets();
243 octets.reverse();
244 format!(
245 "{}.{}.{}.{}.in-addr.arpa",
246 octets[0], octets[1], octets[2], octets[3]
247 )
248 }
249 IpAddr::V6(ip) => {
250 let mut octets = ip.octets();
251 octets.reverse();
252 let rev = octets
253 .iter()
254 .map(|x| x.to_string())
255 .collect::<Vec<_>>()
256 .join(".");
257 format!("{}.ip6.arpa", rev)
258 }
259 };
260 let parsed_query = dnssector::gen::query(rev_name.as_bytes(), rr_type, rr_class)
261 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
262 let mut parsed_response = self.query_from_parsed_query(parsed_query)?;
263 let mut names: Vec<String> = vec![];
264
265 let mut it = parsed_response.into_iter_answer();
266 while let Some(item) = it {
267 if item.rr_class() != rr_class.into() || item.rr_type() != rr_type.into() {
268 it = item.next();
269 continue;
270 }
271 if let Ok(RawRRData::Data(data)) = item.rr_rd() {
272 let mut name = vec![];
273 let mut it = data.iter();
274 while let Some(&len) = it.next() {
275 if len != 0 && !name.is_empty() {
276 name.push(b'.');
277 }
278 for _ in 0..len {
279 name.push(*it.next().ok_or_else(|| {
280 io::Error::new(io::ErrorKind::InvalidInput, "Invalid text record")
281 })?)
282 }
283 }
284 if name.is_empty() {
285 name.push(b'.');
286 }
287 if let Ok(name) = String::from_utf8(name) {
288 match ip {
289 IpAddr::V4(ip) => {
290 if self.query_a(&name)?.contains(ip) {
291 names.push(name)
292 }
293 }
294 IpAddr::V6(ip) => {
295 if self.query_aaaa(&name)?.contains(ip) {
296 names.push(name)
297 }
298 }
299 };
300 }
301 }
302 it = item.next();
303 }
304 Ok(names)
305 }
306
307 pub fn query_rrs_data(
309 &self,
310 name: &str,
311 query_class: &str,
312 query_type: &str,
313 ) -> Result<Vec<Vec<u8>>, io::Error> {
314 let rr_class = Class::from_string(query_class)
315 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
316 let rr_type = Type::from_string(query_type)
317 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
318 let parsed_query = dnssector::gen::query(name.as_bytes(), rr_type, rr_class)
319 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
320 let mut parsed_response = self.query_from_parsed_query(parsed_query)?;
321 let mut raw_rrs = vec![];
322
323 let mut it = parsed_response.into_iter_answer();
324 while let Some(item) = it {
325 if item.rr_class() != rr_class.into() || item.rr_type() != rr_type.into() {
326 it = item.next();
327 continue;
328 }
329 if let Ok(RawRRData::Data(data)) = item.rr_rd() {
330 raw_rrs.push(data.to_vec());
331 }
332 it = item.next();
333 }
334 Ok(raw_rrs)
335 }
336}
337
#[test]
fn test_query_a() {
    use std::str::FromStr;

    // Prefer the system resolvers; otherwise fall back to 1.0.0.1 then 1.1.1.1.
    let upstream_servers = crate::system::default_resolvers().unwrap_or_else(|_| {
        ["1.0.0.1:53", "1.1.1.1:53"]
            .iter()
            .map(|addr| UpstreamServer::new(SocketAddr::from_str(addr).unwrap()))
            .collect()
    });
    let client = DNSClient::new(upstream_servers);
    let addrs = client.query_a("one.one.one.one").unwrap();
    let expected = Ipv4Addr::new(1, 1, 1, 1);
    assert!(addrs.contains(&expected));
}
352
#[test]
fn test_query_ptr() {
    use std::str::FromStr;

    let upstream_servers = crate::system::default_resolvers().unwrap_or_else(|_| {
        vec![
            UpstreamServer::new(SocketAddr::from_str("1.0.0.1:53").unwrap()),
            UpstreamServer::new(SocketAddr::from_str("1.1.1.1:53").unwrap()),
        ]
    });
    let dns_client = DNSClient::new(upstream_servers);
    let names = dns_client
        .query_ptr(&IpAddr::from_str("1.1.1.1").unwrap())
        .unwrap();
    // Check membership rather than `names[0]`: an empty result (all names
    // filtered by forward-confirmation) should report the names, not panic on an
    // index, and several PTR records may come back in any order.
    assert!(
        names.iter().any(|name| name == "one.one.one.one"),
        "unexpected PTR names for 1.1.1.1: {:?}",
        names
    );
}