dnsi/commands/
query.rs

1//! The query command of _dnsi._
2
3use crate::client::{Answer, Client, Server, Transport};
4use crate::error::Error;
5use crate::output::OutputFormat;
6use bytes::Bytes;
7use domain::base::iana::{Class, Rtype};
8use domain::base::message::Message;
9use domain::base::message_builder::MessageBuilder;
10use domain::base::name::{Name, ParsedName, ToName, UncertainName};
11use domain::base::rdata::RecordData;
12use domain::net::client::request::{ComposeRequest, RequestMessage};
13use domain::rdata::{AllRecordData, Ns, Soa};
14use domain::resolv::stub::conf::ResolvConf;
15use domain::resolv::stub::StubResolver;
16use std::collections::HashSet;
17use std::fmt;
18use std::net::{IpAddr, SocketAddr};
19use std::str::FromStr;
20use std::time::Duration;
21
22//------------ Query ---------------------------------------------------------
23
// Command-line arguments of the `query` command, parsed by clap through the
// `clap::Args` derive below.
//
// NOTE: the `///` comments on the fields are not just documentation; the
// clap derive picks them up as the help text of the respective argument, so
// edit them with the CLI output in mind.
//
// Fields whose name starts with an underscore are never read in this file.
// They only exist so that clap accepts the matching flag and lets it take
// part in `overrides_with` together with its counterpart.
#[derive(Clone, Debug, clap::Args)]
pub struct Query {
    /// The name of the resource records to look up
    #[arg(value_name = "QUERY_NAME_OR_ADDR")]
    qname: NameOrAddr,

    /// The record type to look up
    #[arg(value_name = "QUERY_TYPE")]
    qtype: Option<Rtype>,

    /// The server to send the query to. System servers used if missing
    #[arg(short, long, value_name = "ADDR_OR_HOST")]
    server: Option<ServerName>,

    /// The port of the server to send query to.
    #[arg(short = 'p', long = "port", requires = "server")]
    port: Option<u16>,

    /// Use only IPv4 for communication.
    #[arg(short = '4', long, conflicts_with = "ipv6")]
    ipv4: bool,

    /// Use only IPv6 for communication.
    #[arg(short = '6', long, conflicts_with = "ipv4")]
    ipv6: bool,

    /// Use only TCP.
    #[arg(short, long)]
    tcp: bool,

    /// Use only UDP.
    #[arg(short, long)]
    udp: bool,

    /// Use TLS.
    #[arg(long)]
    tls: bool,

    /// The name of the server for SNI and certificate verification.
    #[arg(long = "tls-hostname")]
    tls_hostname: Option<String>,

    /// Set the timeout for a query.
    #[arg(long, value_name = "SECONDS")]
    timeout: Option<f32>,

    /// Set the number of retries over UDP.
    #[arg(long)]
    retries: Option<u8>,

    /// Set the advertised UDP payload size.
    #[arg(long)]
    udp_payload_size: Option<u16>,

    // No need to set the AA flag in the request.
    /// Set the AD flag in the request.
    #[arg(long, overrides_with = "_no_ad")]
    ad: bool,

    /// Do not set the AD flag in the request.
    #[arg(long = "no-ad")]
    _no_ad: bool,

    /// Set the CD flag in the request.
    #[arg(long, overrides_with = "_no_cd")]
    cd: bool,

    /// Do not set the CD flag in the request.
    #[arg(long = "no-cd")]
    _no_cd: bool,

    /// Set the DO flag in the EDNS Opt record in the request.
    // Calling the field `do` would conflict with the keyword `do`.
    #[arg(long = "do", overrides_with = "_no_do")]
    dnssec_ok: bool,

    /// Do not set the DO flag in the request, avoid creating the EDNS Opt
    /// record.
    #[arg(long = "no-do")]
    _no_do: bool,

    // No need to set the RA flag in the request.
    /// Set the RD flag in the request.
    // Tricky: we want RD to default to true. The obvious approach, setting
    // default_value, fails in combination with overrides_with. The solution
    // is to test whether no_rd is false instead, which is why `_rd` itself
    // is never read.
    #[arg(long, overrides_with = "no_rd")]
    _rd: bool,

    /// Do not set the RD flag in the request.
    #[arg(long = "no-rd")]
    no_rd: bool,

    // No need to set the TC flag in the request.
    /// Disable all sanity checks.
    #[arg(long, short = 'f')]
    force: bool,

    /// Verify the answer against an authoritative server.
    #[arg(long)]
    verify: bool,

    /// Select the output format.
    #[arg(long = "format", default_value = "dig")]
    output_format: OutputFormat,
}
130
131/// # Executing the command
132///
133impl Query {
134    pub fn execute(self) -> Result<(), Error> {
135        #[allow(clippy::collapsible_if)] // There may be more later ...
136        if !self.force {
137            let qtype = self.qtype();
138            if qtype == Rtype::AXFR || qtype == Rtype::IXFR {
139                return Err(
140                    "AXFR and IXFR query types invoke zone transfer which \
141                     may result in a sequence\n\
142                     of responses but only the first is shown \
143                     by the 'query' command.\n\
144                     Please use the 'xfr' command for zone transfer.\n\
145                     (Use --force to query anyway.)"
146                        .into(),
147                );
148            }
149        }
150
151        tokio::runtime::Builder::new_multi_thread()
152            .enable_all()
153            .build()
154            .unwrap()
155            .block_on(self.async_execute())
156    }
157
158    pub async fn async_execute(mut self) -> Result<(), Error> {
159        let client = match self.server {
160            Some(ServerName::Name(ref host)) => {
161                if self.tls_hostname.is_none() {
162                    self.tls_hostname = Some(host.to_string());
163                }
164                self.host_server(host).await?
165            }
166            Some(ServerName::Addr(addr)) => {
167                if self.tls && self.tls_hostname.is_none() {
168                    return Err(
169                        "--tls-hostname is required for TLS transport".into(),
170                    );
171                }
172                self.addr_server(addr)
173            }
174            None => {
175                if self.tls {
176                    return Err(
177                        "--server is required for TLS transport".into()
178                    );
179                }
180                self.system_server()
181            }
182        };
183
184        let answer = client.request(self.create_request()).await?;
185        self.output_format.print(&answer)?;
186        if self.verify {
187            let auth_answer = self.auth_answer().await?;
188            if let Some(diff) =
189                Self::diff_answers(auth_answer.message(), answer.message())?
190            {
191                println!("\n;; Authoritative ANSWER does not match.");
192                println!(
193                    ";; Difference of ANSWER with authoritative server {}:",
194                    auth_answer.stats().server_addr
195                );
196                self.output_diff(diff);
197            } else {
198                println!("\n;; Authoritative ANSWER matches.");
199            }
200        }
201        Ok(())
202    }
203}
204
205/// # Configuration
206///
207impl Query {
208    fn timeout(&self) -> Duration {
209        Duration::from_secs_f32(self.timeout.unwrap_or(5.))
210    }
211
212    fn retries(&self) -> u8 {
213        self.retries.unwrap_or(2)
214    }
215
216    fn udp_payload_size(&self) -> u16 {
217        self.udp_payload_size.unwrap_or(1232)
218    }
219}
220
221/// # Resolving the server set
222///
223impl Query {
224    /// Resolves a provided server name.
225    async fn host_server(
226        &self,
227        server: &UncertainName<Vec<u8>>,
228    ) -> Result<Client, Error> {
229        let resolver = StubResolver::default();
230        let answer = match server {
231            UncertainName::Absolute(name) => resolver.lookup_host(name).await,
232            UncertainName::Relative(name) => resolver.search_host(name).await,
233        }
234        .map_err(|err| err.to_string())?;
235
236        let mut servers = Vec::new();
237        for addr in answer.iter() {
238            if (addr.is_ipv4() && self.ipv6) || (addr.is_ipv6() && self.ipv4)
239            {
240                continue;
241            }
242            servers.push(Server {
243                addr: SocketAddr::new(
244                    addr,
245                    self.port.unwrap_or({
246                        if self.tls {
247                            853
248                        } else {
249                            53
250                        }
251                    }),
252                ),
253                transport: self.transport(),
254                timeout: self.timeout(),
255                retries: self.retries.unwrap_or(2),
256                udp_payload_size: self.udp_payload_size.unwrap_or(1232),
257                tls_hostname: self.tls_hostname.clone(),
258            });
259        }
260        Ok(Client::with_servers(servers))
261    }
262
263    /// Resolves a provided server name.
264    fn addr_server(&self, addr: IpAddr) -> Client {
265        Client::with_servers(vec![Server {
266            addr: SocketAddr::new(
267                addr,
268                self.port.unwrap_or(if self.tls { 853 } else { 53 }),
269            ),
270            transport: self.transport(),
271            timeout: self.timeout(),
272            retries: self.retries(),
273            udp_payload_size: self.udp_payload_size(),
274            tls_hostname: self.tls_hostname.clone(),
275        }])
276    }
277
278    /// Creates a client based on the system defaults.
279    fn system_server(&self) -> Client {
280        let conf = ResolvConf::default();
281        Client::with_servers(
282            conf.servers
283                .iter()
284                .map(|server| Server {
285                    addr: server.addr,
286                    transport: self.transport(),
287                    timeout: server.request_timeout,
288                    retries: u8::try_from(conf.options.attempts).unwrap_or(2),
289                    udp_payload_size: server.udp_payload_size,
290                    tls_hostname: None,
291                })
292                .collect(),
293        )
294    }
295
296    fn transport(&self) -> Transport {
297        if self.udp {
298            Transport::Udp
299        } else if self.tls {
300            Transport::Tls
301        } else if self.tcp {
302            Transport::Tcp
303        } else {
304            Transport::UdpTcp
305        }
306    }
307}
308
309/// # Create the actual query
310///
311impl Query {
312    /// Creates a new request message.
313    fn create_request(&self) -> RequestMessage<Vec<u8>> {
314        let mut res = MessageBuilder::new_vec();
315
316        res.header_mut().set_ad(self.ad);
317        res.header_mut().set_cd(self.cd);
318        res.header_mut().set_rd(!self.no_rd);
319
320        let mut res = res.question();
321        res.push((&self.qname.to_name(), self.qtype())).unwrap();
322
323        let mut req = RequestMessage::new(res);
324        if self.dnssec_ok {
325            // Avoid touching the EDNS Opt record unless we need to set DO.
326            req.set_dnssec_ok(true);
327        }
328        req
329    }
330}
331
332/// # Get an authoritative answer
333impl Query {
334    async fn auth_answer(&self) -> Result<Answer, Error> {
335        let servers = {
336            let resolver = StubResolver::new();
337            let apex = self.get_apex(&resolver).await?;
338            let ns_set = self.get_ns_set(&apex, &resolver).await?;
339            self.get_ns_addrs(&ns_set, &resolver).await?
340        };
341        Client::with_servers(servers)
342            .query((self.qname.to_name(), self.qtype()))
343            .await
344    }
345
346    /// Tries to determine the apex of the zone the requested records live in.
347    async fn get_apex(
348        &self,
349        resolv: &StubResolver,
350    ) -> Result<Name<Vec<u8>>, Error> {
351        // Ask for the SOA record for the qname.
352        let qname = self.qname.to_name();
353        let response = resolv.query((&qname, Rtype::SOA)).await?;
354
355        // The SOA record is in the answer section if the qname is the apex
356        // or in the authority section with the apex as the owner name
357        // otherwise.
358        let mut answer = response.answer()?.limit_to_in::<Soa<_>>();
359        if let Some(soa) = answer.next() {
360            let soa = soa?;
361            if *soa.owner() == qname {
362                return Ok(qname.clone());
363            }
364            // Strange SOA in the answer section, let’s continue with
365            // the authority section.
366        }
367
368        let mut authority =
369            answer.next_section()?.unwrap().limit_to_in::<Soa<_>>();
370        if let Some(soa) = authority.next() {
371            let soa = soa?;
372            return Ok(soa.owner().to_name());
373        }
374
375        Err("no SOA record".into())
376    }
377
378    /// Tries to find the NS set for the given apex name.
379    async fn get_ns_set(
380        &self,
381        apex: &Name<Vec<u8>>,
382        resolv: &StubResolver,
383    ) -> Result<Vec<Name<Vec<u8>>>, Error> {
384        let response = resolv.query((apex, Rtype::NS)).await?;
385        let mut res = Vec::new();
386        for record in response.answer()?.limit_to_in::<Ns<_>>() {
387            let record = record?;
388            if *record.owner() != apex {
389                continue;
390            }
391            res.push(record.data().nsdname().to_name());
392        }
393
394        // We could technically get the A and AAAA records from the additional
395        // section, but we’re going to ask anyway, so: meh.
396
397        Ok(res)
398    }
399
400    /// Tries to get all the addresses for all the name servers.
401    async fn get_ns_addrs(
402        &self,
403        ns_set: &[Name<Vec<u8>>],
404        resolv: &StubResolver,
405    ) -> Result<Vec<Server>, Error> {
406        let mut res = HashSet::new();
407        for ns in ns_set {
408            for addr in resolv.lookup_host(ns).await?.iter() {
409                res.insert(addr);
410            }
411        }
412        Ok(res
413            .into_iter()
414            .map(|addr| Server {
415                addr: SocketAddr::new(addr, 53),
416                transport: Transport::UdpTcp,
417                timeout: self.timeout(),
418                retries: self.retries(),
419                udp_payload_size: self.udp_payload_size(),
420                tls_hostname: None,
421            })
422            .collect())
423    }
424
425    /// Produces a diff between two answer sections.
426    ///
427    /// Returns `Ok(None)` if the two answer sections are identical apart from
428    /// the TTLs.
429    #[allow(clippy::mutable_key_type)]
430    fn diff_answers(
431        left: &Message<Bytes>,
432        right: &Message<Bytes>,
433    ) -> Result<Option<Vec<DiffItem>>, Error> {
434        // Put all the answers into a two hashsets.
435        let left = left
436            .answer()?
437            .into_records::<AllRecordData<_, _>>()
438            .filter_map(Result::ok)
439            .map(|record| {
440                let class = record.class();
441                let (name, data) = record.into_owner_and_data();
442                (name, class, data)
443            })
444            .collect::<HashSet<_>>();
445
446        let right = right
447            .answer()?
448            .into_records::<AllRecordData<_, _>>()
449            .filter_map(Result::ok)
450            .map(|record| {
451                let class = record.class();
452                let (name, data) = record.into_owner_and_data();
453                (name, class, data)
454            })
455            .collect::<HashSet<_>>();
456
457        let mut diff = left
458            .intersection(&right)
459            .cloned()
460            .map(|item| (Action::Unchanged, item))
461            .collect::<Vec<_>>();
462        let size = diff.len();
463
464        diff.extend(
465            left.difference(&right)
466                .cloned()
467                .map(|item| (Action::Removed, item)),
468        );
469
470        diff.extend(
471            right
472                .difference(&left)
473                .cloned()
474                .map(|item| (Action::Added, item)),
475        );
476
477        diff.sort_by(|left, right| left.1.cmp(&right.1));
478
479        if size == diff.len() {
480            Ok(None)
481        } else {
482            Ok(Some(diff))
483        }
484    }
485
486    /// Prints the content of a diff.
487    fn output_diff(&self, diff: Vec<DiffItem>) {
488        for item in diff {
489            println!(
490                "{}{} {} {} {}",
491                item.0,
492                item.1 .0,
493                item.1 .1,
494                item.1 .2.rtype(),
495                item.1 .2
496            );
497        }
498    }
499
500    fn qtype(&self) -> Rtype {
501        match self.qtype {
502            Some(qtype) => qtype,
503            None => match self.qname {
504                NameOrAddr::Addr(_) => Rtype::PTR,
505                NameOrAddr::Name(_) => Rtype::AAAA,
506            },
507        }
508    }
509}
510
511//------------ ServerName ---------------------------------------------------
512
/// The value of the `--server` option.
///
/// The server is given either as a host name that still has to be resolved
/// or directly as an IP address.
#[derive(Clone, Debug)]
enum ServerName {
    /// A host name, either absolute or relative.
    Name(UncertainName<Vec<u8>>),

    /// An IP address.
    Addr(IpAddr),
}
518
519impl FromStr for ServerName {
520    type Err = &'static str;
521
522    fn from_str(s: &str) -> Result<Self, Self::Err> {
523        if let Ok(addr) = IpAddr::from_str(s) {
524            Ok(ServerName::Addr(addr))
525        } else {
526            UncertainName::from_str(s)
527                .map(Self::Name)
528                .map_err(|_| "illegal host name")
529        }
530    }
531}
532
533//------------ NameOrAddr ----------------------------------------------------
534
/// The query name as given on the command line.
///
/// An IP address is accepted in place of a domain name; `to_name` then
/// turns it into the name for the reverse lookup.
#[derive(Clone, Debug)]
enum NameOrAddr {
    /// A domain name.
    Name(Name<Vec<u8>>),

    /// An IP address.
    Addr(IpAddr),
}
540
541impl NameOrAddr {
542    fn to_name(&self) -> Name<Vec<u8>> {
543        match &self {
544            NameOrAddr::Name(host) => host.clone(),
545            NameOrAddr::Addr(addr) => {
546                Name::<Vec<u8>>::reverse_from_addr(*addr).unwrap()
547            }
548        }
549    }
550}
551
552impl FromStr for NameOrAddr {
553    type Err = &'static str;
554
555    fn from_str(s: &str) -> Result<Self, Self::Err> {
556        if let Ok(addr) = IpAddr::from_str(s) {
557            Ok(NameOrAddr::Addr(addr))
558        } else {
559            Name::from_str(s)
560                .map(Self::Name)
561                .map_err(|_| "illegal host name")
562        }
563    }
564}
565
566//------------ Action --------------------------------------------------------
567
/// What happened to a record in a diff of two answer sections.
#[derive(Clone, Copy, Debug)]
enum Action {
    /// The record is only present on the right-hand side of the diff.
    Added,

    /// The record is only present on the left-hand side of the diff.
    Removed,

    /// The record is present on both sides of the diff.
    Unchanged,
}
574
575impl fmt::Display for Action {
576    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
577        f.write_str(match *self {
578            Self::Added => "+ ",
579            Self::Removed => "- ",
580            Self::Unchanged => "  ",
581        })
582    }
583}
584
585//----------- DiffItem -------------------------------------------------------
586
/// A single entry of a diff between two answer sections.
///
/// The first element says what happened to the record, the second is the
/// record itself as owner name, class and record data. The TTL is not part
/// of an item.
type DiffItem = (
    Action,
    (
        ParsedName<Bytes>,
        Class,
        AllRecordData<Bytes, ParsedName<Bytes>>,
    ),
);