1use crate::client::{Answer, Client, Server, Transport};
4use crate::error::Error;
5use crate::output::OutputFormat;
6use bytes::Bytes;
7use domain::base::iana::{Class, Rtype};
8use domain::base::message::Message;
9use domain::base::message_builder::MessageBuilder;
10use domain::base::name::{Name, ParsedName, ToName, UncertainName};
11use domain::base::rdata::RecordData;
12use domain::net::client::request::{ComposeRequest, RequestMessage};
13use domain::rdata::{AllRecordData, Ns, Soa};
14use domain::resolv::stub::conf::ResolvConf;
15use domain::resolv::stub::StubResolver;
16use std::collections::HashSet;
17use std::fmt;
18use std::net::{IpAddr, SocketAddr};
19use std::str::FromStr;
20use std::time::Duration;
21
// Command-line arguments of the `query` command.
//
// Plain `//` comments are used on purpose: `///` doc comments on clap-derived
// fields become `--help` text, which would change the program's output.
#[derive(Clone, Debug, clap::Args)]
pub struct Query {
    // Name to query, or an IP address; for an address the question name is
    // derived via `NameOrAddr::to_name` and the default type is PTR.
    #[arg(value_name = "QUERY_NAME_OR_ADDR")]
    qname: NameOrAddr,

    // Record type to query. When absent, `Query::qtype` picks PTR for
    // addresses and AAAA for names.
    #[arg(value_name = "QUERY_TYPE")]
    qtype: Option<Rtype>,

    // Server to send the query to, as an IP address or a host name. When
    // absent, the servers from the system resolver configuration are used.
    #[arg(short, long, value_name = "ADDR_OR_HOST")]
    server: Option<ServerName>,

    // Port of `--server`; only allowed together with `--server`. Defaults to
    // 853 with `--tls` and 53 otherwise.
    #[arg(short = 'p', long = "port", requires = "server")]
    port: Option<u16>,

    // Only use IPv4 addresses when `--server` is a host name.
    #[arg(short = '4', long, conflicts_with = "ipv6")]
    ipv4: bool,

    // Only use IPv6 addresses when `--server` is a host name.
    #[arg(short = '6', long, conflicts_with = "ipv4")]
    ipv6: bool,

    // Use TCP only (see `Query::transport` for precedence among transports).
    #[arg(short, long)]
    tcp: bool,

    // Use UDP only; takes precedence over `--tls` and `--tcp`.
    #[arg(short, long)]
    udp: bool,

    // Use TLS; requires `--server`, and a `--tls-hostname` when the server is
    // given as an address.
    #[arg(long)]
    tls: bool,

    // Host name for the TLS transport. Defaults to `--server` when that is a
    // host name.
    #[arg(long = "tls-hostname")]
    tls_hostname: Option<String>,

    // Request timeout in seconds (default 5 for servers not taken from the
    // system configuration).
    #[arg(long, value_name = "SECONDS")]
    timeout: Option<f32>,

    // Number of retries (default 2 for servers not taken from the system
    // configuration).
    #[arg(long)]
    retries: Option<u8>,

    // Advertised UDP payload size (default 1232 for servers not taken from
    // the system configuration).
    #[arg(long)]
    udp_payload_size: Option<u16>,

    // Set the AD bit in the request; paired with `--no-ad` via
    // `overrides_with`.
    #[arg(long, overrides_with = "_no_ad")]
    ad: bool,

    // Counterpart of `--ad`; only exists to override it.
    #[arg(long = "no-ad")]
    _no_ad: bool,

    // Set the CD bit in the request; paired with `--no-cd`.
    #[arg(long, overrides_with = "_no_cd")]
    cd: bool,

    // Counterpart of `--cd`; only exists to override it.
    #[arg(long = "no-cd")]
    _no_cd: bool,

    // Request the DO bit (`--do`); paired with `--no-do`.
    #[arg(long = "do", overrides_with = "_no_do")]
    dnssec_ok: bool,

    // Counterpart of `--do`; only exists to override it.
    #[arg(long = "no-do")]
    _no_do: bool,

    // RD is set by default; this flag only exists so it can override a
    // preceding `--no-rd`.
    #[arg(long, overrides_with = "no_rd")]
    _rd: bool,

    // Clear the RD bit in the request.
    #[arg(long = "no-rd")]
    no_rd: bool,

    // Allow AXFR/IXFR query types, which are rejected by default.
    #[arg(long, short = 'f')]
    force: bool,

    // Additionally query the zone's authoritative servers and compare their
    // ANSWER section with the one received.
    #[arg(long)]
    verify: bool,

    // Output format used to print the answer.
    #[arg(long = "format", default_value = "dig")]
    output_format: OutputFormat,
}
130
131impl Query {
134 pub fn execute(self) -> Result<(), Error> {
135 #[allow(clippy::collapsible_if)] if !self.force {
137 let qtype = self.qtype();
138 if qtype == Rtype::AXFR || qtype == Rtype::IXFR {
139 return Err(
140 "AXFR and IXFR query types invoke zone transfer which \
141 may result in a sequence\n\
142 of responses but only the first is shown \
143 by the 'query' command.\n\
144 Please use the 'xfr' command for zone transfer.\n\
145 (Use --force to query anyway.)"
146 .into(),
147 );
148 }
149 }
150
151 tokio::runtime::Builder::new_multi_thread()
152 .enable_all()
153 .build()
154 .unwrap()
155 .block_on(self.async_execute())
156 }
157
158 pub async fn async_execute(mut self) -> Result<(), Error> {
159 let client = match self.server {
160 Some(ServerName::Name(ref host)) => {
161 if self.tls_hostname.is_none() {
162 self.tls_hostname = Some(host.to_string());
163 }
164 self.host_server(host).await?
165 }
166 Some(ServerName::Addr(addr)) => {
167 if self.tls && self.tls_hostname.is_none() {
168 return Err(
169 "--tls-hostname is required for TLS transport".into(),
170 );
171 }
172 self.addr_server(addr)
173 }
174 None => {
175 if self.tls {
176 return Err(
177 "--server is required for TLS transport".into()
178 );
179 }
180 self.system_server()
181 }
182 };
183
184 let answer = client.request(self.create_request()).await?;
185 self.output_format.print(&answer)?;
186 if self.verify {
187 let auth_answer = self.auth_answer().await?;
188 if let Some(diff) =
189 Self::diff_answers(auth_answer.message(), answer.message())?
190 {
191 println!("\n;; Authoritative ANSWER does not match.");
192 println!(
193 ";; Difference of ANSWER with authoritative server {}:",
194 auth_answer.stats().server_addr
195 );
196 self.output_diff(diff);
197 } else {
198 println!("\n;; Authoritative ANSWER matches.");
199 }
200 }
201 Ok(())
202 }
203}
204
205impl Query {
208 fn timeout(&self) -> Duration {
209 Duration::from_secs_f32(self.timeout.unwrap_or(5.))
210 }
211
212 fn retries(&self) -> u8 {
213 self.retries.unwrap_or(2)
214 }
215
216 fn udp_payload_size(&self) -> u16 {
217 self.udp_payload_size.unwrap_or(1232)
218 }
219}
220
221impl Query {
224 async fn host_server(
226 &self,
227 server: &UncertainName<Vec<u8>>,
228 ) -> Result<Client, Error> {
229 let resolver = StubResolver::default();
230 let answer = match server {
231 UncertainName::Absolute(name) => resolver.lookup_host(name).await,
232 UncertainName::Relative(name) => resolver.search_host(name).await,
233 }
234 .map_err(|err| err.to_string())?;
235
236 let mut servers = Vec::new();
237 for addr in answer.iter() {
238 if (addr.is_ipv4() && self.ipv6) || (addr.is_ipv6() && self.ipv4)
239 {
240 continue;
241 }
242 servers.push(Server {
243 addr: SocketAddr::new(
244 addr,
245 self.port.unwrap_or({
246 if self.tls {
247 853
248 } else {
249 53
250 }
251 }),
252 ),
253 transport: self.transport(),
254 timeout: self.timeout(),
255 retries: self.retries.unwrap_or(2),
256 udp_payload_size: self.udp_payload_size.unwrap_or(1232),
257 tls_hostname: self.tls_hostname.clone(),
258 });
259 }
260 Ok(Client::with_servers(servers))
261 }
262
263 fn addr_server(&self, addr: IpAddr) -> Client {
265 Client::with_servers(vec![Server {
266 addr: SocketAddr::new(
267 addr,
268 self.port.unwrap_or(if self.tls { 853 } else { 53 }),
269 ),
270 transport: self.transport(),
271 timeout: self.timeout(),
272 retries: self.retries(),
273 udp_payload_size: self.udp_payload_size(),
274 tls_hostname: self.tls_hostname.clone(),
275 }])
276 }
277
278 fn system_server(&self) -> Client {
280 let conf = ResolvConf::default();
281 Client::with_servers(
282 conf.servers
283 .iter()
284 .map(|server| Server {
285 addr: server.addr,
286 transport: self.transport(),
287 timeout: server.request_timeout,
288 retries: u8::try_from(conf.options.attempts).unwrap_or(2),
289 udp_payload_size: server.udp_payload_size,
290 tls_hostname: None,
291 })
292 .collect(),
293 )
294 }
295
296 fn transport(&self) -> Transport {
297 if self.udp {
298 Transport::Udp
299 } else if self.tls {
300 Transport::Tls
301 } else if self.tcp {
302 Transport::Tcp
303 } else {
304 Transport::UdpTcp
305 }
306 }
307}
308
309impl Query {
312 fn create_request(&self) -> RequestMessage<Vec<u8>> {
314 let mut res = MessageBuilder::new_vec();
315
316 res.header_mut().set_ad(self.ad);
317 res.header_mut().set_cd(self.cd);
318 res.header_mut().set_rd(!self.no_rd);
319
320 let mut res = res.question();
321 res.push((&self.qname.to_name(), self.qtype())).unwrap();
322
323 let mut req = RequestMessage::new(res);
324 if self.dnssec_ok {
325 req.set_dnssec_ok(true);
327 }
328 req
329 }
330}
331
332impl Query {
334 async fn auth_answer(&self) -> Result<Answer, Error> {
335 let servers = {
336 let resolver = StubResolver::new();
337 let apex = self.get_apex(&resolver).await?;
338 let ns_set = self.get_ns_set(&apex, &resolver).await?;
339 self.get_ns_addrs(&ns_set, &resolver).await?
340 };
341 Client::with_servers(servers)
342 .query((self.qname.to_name(), self.qtype()))
343 .await
344 }
345
346 async fn get_apex(
348 &self,
349 resolv: &StubResolver,
350 ) -> Result<Name<Vec<u8>>, Error> {
351 let qname = self.qname.to_name();
353 let response = resolv.query((&qname, Rtype::SOA)).await?;
354
355 let mut answer = response.answer()?.limit_to_in::<Soa<_>>();
359 if let Some(soa) = answer.next() {
360 let soa = soa?;
361 if *soa.owner() == qname {
362 return Ok(qname.clone());
363 }
364 }
367
368 let mut authority =
369 answer.next_section()?.unwrap().limit_to_in::<Soa<_>>();
370 if let Some(soa) = authority.next() {
371 let soa = soa?;
372 return Ok(soa.owner().to_name());
373 }
374
375 Err("no SOA record".into())
376 }
377
378 async fn get_ns_set(
380 &self,
381 apex: &Name<Vec<u8>>,
382 resolv: &StubResolver,
383 ) -> Result<Vec<Name<Vec<u8>>>, Error> {
384 let response = resolv.query((apex, Rtype::NS)).await?;
385 let mut res = Vec::new();
386 for record in response.answer()?.limit_to_in::<Ns<_>>() {
387 let record = record?;
388 if *record.owner() != apex {
389 continue;
390 }
391 res.push(record.data().nsdname().to_name());
392 }
393
394 Ok(res)
398 }
399
400 async fn get_ns_addrs(
402 &self,
403 ns_set: &[Name<Vec<u8>>],
404 resolv: &StubResolver,
405 ) -> Result<Vec<Server>, Error> {
406 let mut res = HashSet::new();
407 for ns in ns_set {
408 for addr in resolv.lookup_host(ns).await?.iter() {
409 res.insert(addr);
410 }
411 }
412 Ok(res
413 .into_iter()
414 .map(|addr| Server {
415 addr: SocketAddr::new(addr, 53),
416 transport: Transport::UdpTcp,
417 timeout: self.timeout(),
418 retries: self.retries(),
419 udp_payload_size: self.udp_payload_size(),
420 tls_hostname: None,
421 })
422 .collect())
423 }
424
425 #[allow(clippy::mutable_key_type)]
430 fn diff_answers(
431 left: &Message<Bytes>,
432 right: &Message<Bytes>,
433 ) -> Result<Option<Vec<DiffItem>>, Error> {
434 let left = left
436 .answer()?
437 .into_records::<AllRecordData<_, _>>()
438 .filter_map(Result::ok)
439 .map(|record| {
440 let class = record.class();
441 let (name, data) = record.into_owner_and_data();
442 (name, class, data)
443 })
444 .collect::<HashSet<_>>();
445
446 let right = right
447 .answer()?
448 .into_records::<AllRecordData<_, _>>()
449 .filter_map(Result::ok)
450 .map(|record| {
451 let class = record.class();
452 let (name, data) = record.into_owner_and_data();
453 (name, class, data)
454 })
455 .collect::<HashSet<_>>();
456
457 let mut diff = left
458 .intersection(&right)
459 .cloned()
460 .map(|item| (Action::Unchanged, item))
461 .collect::<Vec<_>>();
462 let size = diff.len();
463
464 diff.extend(
465 left.difference(&right)
466 .cloned()
467 .map(|item| (Action::Removed, item)),
468 );
469
470 diff.extend(
471 right
472 .difference(&left)
473 .cloned()
474 .map(|item| (Action::Added, item)),
475 );
476
477 diff.sort_by(|left, right| left.1.cmp(&right.1));
478
479 if size == diff.len() {
480 Ok(None)
481 } else {
482 Ok(Some(diff))
483 }
484 }
485
486 fn output_diff(&self, diff: Vec<DiffItem>) {
488 for item in diff {
489 println!(
490 "{}{} {} {} {}",
491 item.0,
492 item.1 .0,
493 item.1 .1,
494 item.1 .2.rtype(),
495 item.1 .2
496 );
497 }
498 }
499
500 fn qtype(&self) -> Rtype {
501 match self.qtype {
502 Some(qtype) => qtype,
503 None => match self.qname {
504 NameOrAddr::Addr(_) => Rtype::PTR,
505 NameOrAddr::Name(_) => Rtype::AAAA,
506 },
507 }
508 }
509}
510
511#[derive(Clone, Debug)]
514enum ServerName {
515 Name(UncertainName<Vec<u8>>),
516 Addr(IpAddr),
517}
518
519impl FromStr for ServerName {
520 type Err = &'static str;
521
522 fn from_str(s: &str) -> Result<Self, Self::Err> {
523 if let Ok(addr) = IpAddr::from_str(s) {
524 Ok(ServerName::Addr(addr))
525 } else {
526 UncertainName::from_str(s)
527 .map(Self::Name)
528 .map_err(|_| "illegal host name")
529 }
530 }
531}
532
533#[derive(Clone, Debug)]
536enum NameOrAddr {
537 Name(Name<Vec<u8>>),
538 Addr(IpAddr),
539}
540
541impl NameOrAddr {
542 fn to_name(&self) -> Name<Vec<u8>> {
543 match &self {
544 NameOrAddr::Name(host) => host.clone(),
545 NameOrAddr::Addr(addr) => {
546 Name::<Vec<u8>>::reverse_from_addr(*addr).unwrap()
547 }
548 }
549 }
550}
551
552impl FromStr for NameOrAddr {
553 type Err = &'static str;
554
555 fn from_str(s: &str) -> Result<Self, Self::Err> {
556 if let Ok(addr) = IpAddr::from_str(s) {
557 Ok(NameOrAddr::Addr(addr))
558 } else {
559 Name::from_str(s)
560 .map(Self::Name)
561 .map_err(|_| "illegal host name")
562 }
563 }
564}
565
/// Kind of change a record represents in an answer diff.
#[derive(Clone, Copy, Debug)]
enum Action {
    /// The record only appears on the right-hand side of the comparison.
    Added,
    /// The record only appears on the left-hand side of the comparison.
    Removed,
    /// The record appears on both sides.
    Unchanged,
}

impl fmt::Display for Action {
    /// Writes the two-character line prefix used by the diff output.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(match *self {
            Self::Added => "+ ",
            Self::Removed => "- ",
            // Same width as the other markers so the record columns line up.
            Self::Unchanged => "  ",
        })
    }
}
584
585type DiffItem = (
588 Action,
589 (
590 ParsedName<Bytes>,
591 Class,
592 AllRecordData<Bytes, ParsedName<Bytes>>,
593 ),
594);