hickory_proto/op/
message.rs

1// Copyright 2015-2023 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! Basic protocol message for DNS
9
10use alloc::{boxed::Box, fmt, vec::Vec};
11use core::{iter, mem, ops::Deref};
12
13#[cfg(feature = "serde")]
14use serde::{Deserialize, Serialize};
15use tracing::{debug, warn};
16
17use crate::{
18    error::*,
19    op::{Edns, Header, MessageType, OpCode, Query, ResponseCode},
20    rr::{Record, RecordType},
21    serialize::binary::{BinDecodable, BinDecoder, BinEncodable, BinEncoder, EncodeMode},
22    xfer::DnsResponse,
23};
24
/// The basic request and response data structure, used for all DNS protocols.
///
/// [RFC 1035, DOMAIN NAMES - IMPLEMENTATION AND SPECIFICATION, November 1987](https://tools.ietf.org/html/rfc1035)
///
/// ```text
/// 4.1. Format
///
/// All communications inside of the domain protocol are carried in a single
/// format called a message.  The top level format of message is divided
/// into 5 sections (some of which are empty in certain cases) shown below:
///
///     +--------------------------+
///     |        Header            |
///     +--------------------------+
///     |  Question / Zone         | the question for the name server
///     +--------------------------+
///     |   Answer  / Prerequisite | RRs answering the question
///     +--------------------------+
///     | Authority / Update       | RRs pointing toward an authority
///     +--------------------------+
///     |      Additional          | RRs holding additional information
///     +--------------------------+
///
/// The header section is always present.  The header includes fields that
/// specify which of the remaining sections are present, and also specify
/// whether the message is a query or a response, a standard query or some
/// other opcode, etc.
///
/// The names of the sections after the header are derived from their use in
/// standard queries.  The question section contains fields that describe a
/// question to a name server.  These fields are a query type (QTYPE), a
/// query class (QCLASS), and a query domain name (QNAME).  The last three
/// sections have the same format: a possibly empty list of concatenated
/// resource records (RRs).  The answer section contains RRs that answer the
/// question; the authority section contains RRs that point toward an
/// authoritative name server; the additional records section contains RRs
/// which relate to the query, but are not strictly answers for the
/// question.
/// ```
///
/// By default Message is a Query. Use the Message::as_update() to create an update, or
///  Message::new_update()
// Field declaration order determines the derived `Debug` output and the serde field order;
// do not reorder fields casually.
#[derive(Clone, Debug, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct Message {
    /// Header: id, flags, op code, response code and section counts.
    header: Header,
    /// Question (zone) section.
    queries: Vec<Query>,
    /// Answer (prerequisite) section.
    answers: Vec<Record>,
    /// Authority (update) section.
    name_servers: Vec<Record>,
    /// Additional section. When decoded, the OPT record is stored in `edns` instead, and
    /// SIG/TSIG records (with the `__dnssec` feature) are stored in `signature`.
    additionals: Vec<Record>,
    /// SIG(0) and/or TSIG records; these must be the final records of the message.
    signature: Vec<Record>,
    /// EDNS data carried by the OPT pseudo-record, if present.
    edns: Option<Edns>,
}
78
79impl Message {
80    /// Returns a new "empty" Message
81    pub fn new() -> Self {
82        Self {
83            header: Header::new(),
84            queries: Vec::new(),
85            answers: Vec::new(),
86            name_servers: Vec::new(),
87            additionals: Vec::new(),
88            signature: Vec::new(),
89            edns: None,
90        }
91    }
92
93    /// Returns a Message constructed with error details to return to a client
94    ///
95    /// # Arguments
96    ///
97    /// * `id` - message id should match the request message id
98    /// * `op_code` - operation of the request
99    /// * `response_code` - the error code for the response
100    pub fn error_msg(id: u16, op_code: OpCode, response_code: ResponseCode) -> Self {
101        let mut message = Self::new();
102        message
103            .set_message_type(MessageType::Response)
104            .set_id(id)
105            .set_response_code(response_code)
106            .set_op_code(op_code);
107
108        message
109    }
110
111    /// Truncates a Message, this blindly removes all response fields and sets truncated to `true`
112    pub fn truncate(&self) -> Self {
113        // copy header
114        let mut header = self.header;
115        header.set_truncated(true);
116        header
117            .set_additional_count(0)
118            .set_answer_count(0)
119            .set_name_server_count(0);
120
121        let mut msg = Self::new();
122        // drops additional/answer/nameservers/signature
123        // adds query/OPT
124        msg.add_queries(self.queries().iter().cloned());
125        if let Some(edns) = self.extensions().clone() {
126            msg.set_edns(edns);
127        }
128        // set header
129        msg.set_header(header);
130
131        // TODO, perhaps just quickly add a few response records here? that we know would fit?
132        msg
133    }
134
135    /// Sets the `Header` with provided
136    pub fn set_header(&mut self, header: Header) -> &mut Self {
137        self.header = header;
138        self
139    }
140
141    /// see `Header::set_id`
142    pub fn set_id(&mut self, id: u16) -> &mut Self {
143        self.header.set_id(id);
144        self
145    }
146
147    /// see `Header::set_message_type`
148    pub fn set_message_type(&mut self, message_type: MessageType) -> &mut Self {
149        self.header.set_message_type(message_type);
150        self
151    }
152
153    /// see `Header::set_op_code`
154    pub fn set_op_code(&mut self, op_code: OpCode) -> &mut Self {
155        self.header.set_op_code(op_code);
156        self
157    }
158
159    /// see `Header::set_authoritative`
160    pub fn set_authoritative(&mut self, authoritative: bool) -> &mut Self {
161        self.header.set_authoritative(authoritative);
162        self
163    }
164
165    /// see `Header::set_truncated`
166    pub fn set_truncated(&mut self, truncated: bool) -> &mut Self {
167        self.header.set_truncated(truncated);
168        self
169    }
170
171    /// see `Header::set_recursion_desired`
172    pub fn set_recursion_desired(&mut self, recursion_desired: bool) -> &mut Self {
173        self.header.set_recursion_desired(recursion_desired);
174        self
175    }
176
177    /// see `Header::set_recursion_available`
178    pub fn set_recursion_available(&mut self, recursion_available: bool) -> &mut Self {
179        self.header.set_recursion_available(recursion_available);
180        self
181    }
182
183    /// see `Header::set_authentic_data`
184    pub fn set_authentic_data(&mut self, authentic_data: bool) -> &mut Self {
185        self.header.set_authentic_data(authentic_data);
186        self
187    }
188
189    /// see `Header::set_checking_disabled`
190    pub fn set_checking_disabled(&mut self, checking_disabled: bool) -> &mut Self {
191        self.header.set_checking_disabled(checking_disabled);
192        self
193    }
194
195    /// see `Header::set_response_code`
196    pub fn set_response_code(&mut self, response_code: ResponseCode) -> &mut Self {
197        self.header.set_response_code(response_code);
198        self
199    }
200
201    /// see `Header::set_query_count`
202    ///
203    /// this count will be ignored during serialization,
204    /// where the length of the associated records will be used instead.
205    pub fn set_query_count(&mut self, query_count: u16) -> &mut Self {
206        self.header.set_query_count(query_count);
207        self
208    }
209
210    /// see `Header::set_answer_count`
211    ///
212    /// this count will be ignored during serialization,
213    /// where the length of the associated records will be used instead.
214    pub fn set_answer_count(&mut self, answer_count: u16) -> &mut Self {
215        self.header.set_answer_count(answer_count);
216        self
217    }
218
219    /// see `Header::set_name_server_count`
220    ///
221    /// this count will be ignored during serialization,
222    /// where the length of the associated records will be used instead.
223    pub fn set_name_server_count(&mut self, name_server_count: u16) -> &mut Self {
224        self.header.set_name_server_count(name_server_count);
225        self
226    }
227
228    /// see `Header::set_additional_count`
229    ///
230    /// this count will be ignored during serialization,
231    /// where the length of the associated records will be used instead.
232    pub fn set_additional_count(&mut self, additional_count: u16) -> &mut Self {
233        self.header.set_additional_count(additional_count);
234        self
235    }
236
237    /// Add a query to the Message, either the query response from the server, or the request Query.
238    pub fn add_query(&mut self, query: Query) -> &mut Self {
239        self.queries.push(query);
240        self
241    }
242
243    /// Adds an iterator over a set of Queries to be added to the message
244    pub fn add_queries<Q, I>(&mut self, queries: Q) -> &mut Self
245    where
246        Q: IntoIterator<Item = Query, IntoIter = I>,
247        I: Iterator<Item = Query>,
248    {
249        for query in queries {
250            self.add_query(query);
251        }
252
253        self
254    }
255
256    /// Add an answer to the Message
257    pub fn add_answer(&mut self, record: Record) -> &mut Self {
258        self.answers.push(record);
259        self
260    }
261
262    /// Add all the records from the iterator to the answers section of the Message
263    pub fn add_answers<R, I>(&mut self, records: R) -> &mut Self
264    where
265        R: IntoIterator<Item = Record, IntoIter = I>,
266        I: Iterator<Item = Record>,
267    {
268        for record in records {
269            self.add_answer(record);
270        }
271
272        self
273    }
274
275    /// Sets the answers to the specified set of Records.
276    ///
277    /// # Panics
278    ///
279    /// Will panic if answer records are already associated to the message.
280    pub fn insert_answers(&mut self, records: Vec<Record>) {
281        assert!(self.answers.is_empty());
282        self.answers = records;
283    }
284
285    /// Add a name server record to the Message
286    pub fn add_name_server(&mut self, record: Record) -> &mut Self {
287        self.name_servers.push(record);
288        self
289    }
290
291    /// Add all the records in the Iterator to the name server section of the message
292    pub fn add_name_servers<R, I>(&mut self, records: R) -> &mut Self
293    where
294        R: IntoIterator<Item = Record, IntoIter = I>,
295        I: Iterator<Item = Record>,
296    {
297        for record in records {
298            self.add_name_server(record);
299        }
300
301        self
302    }
303
304    /// Sets the name_servers to the specified set of Records.
305    ///
306    /// # Panics
307    ///
308    /// Will panic if name_servers records are already associated to the message.
309    pub fn insert_name_servers(&mut self, records: Vec<Record>) {
310        assert!(self.name_servers.is_empty());
311        self.name_servers = records;
312    }
313
314    /// Add an additional Record to the message
315    pub fn add_additional(&mut self, record: Record) -> &mut Self {
316        self.additionals.push(record);
317        self
318    }
319
320    /// Add all the records from the iterator to the additionals section of the Message
321    pub fn add_additionals<R, I>(&mut self, records: R) -> &mut Self
322    where
323        R: IntoIterator<Item = Record, IntoIter = I>,
324        I: Iterator<Item = Record>,
325    {
326        for record in records {
327            self.add_additional(record);
328        }
329
330        self
331    }
332
333    /// Sets the additional to the specified set of Records.
334    ///
335    /// # Panics
336    ///
337    /// Will panic if additional records are already associated to the message.
338    pub fn insert_additionals(&mut self, records: Vec<Record>) {
339        assert!(self.additionals.is_empty());
340        self.additionals = records;
341    }
342
343    /// Add the EDNS section to the Message
344    pub fn set_edns(&mut self, edns: Edns) -> &mut Self {
345        self.edns = Some(edns);
346        self
347    }
348
349    /// Add a SIG0 record, i.e. sign this message
350    ///
351    /// This must be used only after all records have been associated. Generally this will be handled by the client and not need to be used directly
352    #[cfg(feature = "__dnssec")]
353    pub fn add_sig0(&mut self, record: Record) -> &mut Self {
354        assert_eq!(RecordType::SIG, record.record_type());
355        self.signature.push(record);
356        self
357    }
358
359    /// Add a TSIG record, i.e. authenticate this message
360    ///
361    /// This must be used only after all records have been associated. Generally this will be handled by the client and not need to be used directly
362    #[cfg(feature = "__dnssec")]
363    pub fn add_tsig(&mut self, record: Record) -> &mut Self {
364        assert_eq!(RecordType::TSIG, record.record_type());
365        self.signature.push(record);
366        self
367    }
368
369    /// Gets the header of the Message
370    pub fn header(&self) -> &Header {
371        &self.header
372    }
373
374    /// see `Header::id()`
375    pub fn id(&self) -> u16 {
376        self.header.id()
377    }
378
379    /// see `Header::message_type()`
380    pub fn message_type(&self) -> MessageType {
381        self.header.message_type()
382    }
383
384    /// see `Header::op_code()`
385    pub fn op_code(&self) -> OpCode {
386        self.header.op_code()
387    }
388
389    /// see `Header::authoritative()`
390    pub fn authoritative(&self) -> bool {
391        self.header.authoritative()
392    }
393
394    /// see `Header::truncated()`
395    pub fn truncated(&self) -> bool {
396        self.header.truncated()
397    }
398
399    /// see `Header::recursion_desired()`
400    pub fn recursion_desired(&self) -> bool {
401        self.header.recursion_desired()
402    }
403
404    /// see `Header::recursion_available()`
405    pub fn recursion_available(&self) -> bool {
406        self.header.recursion_available()
407    }
408
409    /// see `Header::authentic_data()`
410    pub fn authentic_data(&self) -> bool {
411        self.header.authentic_data()
412    }
413
414    /// see `Header::checking_disabled()`
415    pub fn checking_disabled(&self) -> bool {
416        self.header.checking_disabled()
417    }
418
419    /// # Return value
420    ///
421    /// The `ResponseCode`, if this is an EDNS message then this will join the section from the OPT
422    ///  record to create the EDNS `ResponseCode`
423    pub fn response_code(&self) -> ResponseCode {
424        self.header.response_code()
425    }
426
427    /// Returns the query from this Message.
428    ///
429    /// In almost all cases, a Message will only contain one query. This is a convenience function to get the single query.
430    /// See the alternative `queries*` methods for the raw set of queries in the Message
431    pub fn query(&self) -> Option<&Query> {
432        self.queries.first()
433    }
434
435    /// ```text
436    /// Question        Carries the query name and other query parameters.
437    /// ```
438    pub fn queries(&self) -> &[Query] {
439        &self.queries
440    }
441
442    /// Provides mutable access to `queries`
443    pub fn queries_mut(&mut self) -> &mut Vec<Query> {
444        &mut self.queries
445    }
446
447    /// Removes all the answers from the Message
448    pub fn take_queries(&mut self) -> Vec<Query> {
449        mem::take(&mut self.queries)
450    }
451
452    /// ```text
453    /// Answer          Carries RRs which directly answer the query.
454    /// ```
455    pub fn answers(&self) -> &[Record] {
456        &self.answers
457    }
458
459    /// Provides mutable access to `answers`
460    pub fn answers_mut(&mut self) -> &mut Vec<Record> {
461        &mut self.answers
462    }
463
464    /// Removes all the answers from the Message
465    pub fn take_answers(&mut self) -> Vec<Record> {
466        mem::take(&mut self.answers)
467    }
468
469    /// ```text
470    /// Authority       Carries RRs which describe other authoritative servers.
471    ///                 May optionally carry the SOA RR for the authoritative
472    ///                 data in the answer section.
473    /// ```
474    pub fn name_servers(&self) -> &[Record] {
475        &self.name_servers
476    }
477
478    /// Provides mutable access to `name_servers`
479    pub fn name_servers_mut(&mut self) -> &mut Vec<Record> {
480        &mut self.name_servers
481    }
482
483    /// Remove the name servers from the Message
484    pub fn take_name_servers(&mut self) -> Vec<Record> {
485        mem::take(&mut self.name_servers)
486    }
487
488    /// ```text
489    /// Additional      Carries RRs which may be helpful in using the RRs in the
490    ///                 other sections.
491    /// ```
492    pub fn additionals(&self) -> &[Record] {
493        &self.additionals
494    }
495
496    /// Provides mutable access to `additionals`
497    pub fn additionals_mut(&mut self) -> &mut Vec<Record> {
498        &mut self.additionals
499    }
500
501    /// Remove the additional Records from the Message
502    pub fn take_additionals(&mut self) -> Vec<Record> {
503        mem::take(&mut self.additionals)
504    }
505
506    /// All sections chained
507    pub fn all_sections(&self) -> impl Iterator<Item = &Record> {
508        self.answers
509            .iter()
510            .chain(self.name_servers().iter())
511            .chain(self.additionals.iter())
512    }
513
514    /// [RFC 6891, EDNS(0) Extensions, April 2013](https://tools.ietf.org/html/rfc6891#section-6.1.1)
515    ///
516    /// ```text
517    /// 6.1.1.  Basic Elements
518    ///
519    ///  An OPT pseudo-RR (sometimes called a meta-RR) MAY be added to the
520    ///  additional data section of a request.
521    ///
522    ///  The OPT RR has RR type 41.
523    ///
524    ///  If an OPT record is present in a received request, compliant
525    ///  responders MUST include an OPT record in their respective responses.
526    ///
527    ///  An OPT record does not carry any DNS data.  It is used only to
528    ///  contain control information pertaining to the question-and-answer
529    ///  sequence of a specific transaction.  OPT RRs MUST NOT be cached,
530    ///  forwarded, or stored in or loaded from Zone Files.
531    ///
532    ///  The OPT RR MAY be placed anywhere within the additional data section.
533    ///  When an OPT RR is included within any DNS message, it MUST be the
534    ///  only OPT RR in that message.  If a query message with more than one
535    ///  OPT RR is received, a FORMERR (RCODE=1) MUST be returned.  The
536    ///  placement flexibility for the OPT RR does not override the need for
537    ///  the TSIG or SIG(0) RRs to be the last in the additional section
538    ///  whenever they are present.
539    /// ```
540    /// # Return value
541    ///
542    /// Optionally returns a reference to EDNS section
543    #[deprecated(note = "Please use `extensions()`")]
544    pub fn edns(&self) -> Option<&Edns> {
545        self.edns.as_ref()
546    }
547
548    /// Optionally returns mutable reference to EDNS section
549    #[deprecated(
550        note = "Please use `extensions_mut()`. You can chain `.get_or_insert_with(Edns::new)` to recover original behavior of adding Edns if not present"
551    )]
552    pub fn edns_mut(&mut self) -> &mut Edns {
553        if self.edns.is_none() {
554            self.set_edns(Edns::new());
555        }
556        self.edns.as_mut().unwrap()
557    }
558
559    /// Returns reference of Edns section
560    pub fn extensions(&self) -> &Option<Edns> {
561        &self.edns
562    }
563
564    /// Returns mutable reference of Edns section
565    pub fn extensions_mut(&mut self) -> &mut Option<Edns> {
566        &mut self.edns
567    }
568
569    /// # Return value
570    ///
571    /// the max payload value as it's defined in the EDNS section.
572    pub fn max_payload(&self) -> u16 {
573        let max_size = self.edns.as_ref().map_or(512, Edns::max_payload);
574        if max_size < 512 { 512 } else { max_size }
575    }
576
577    /// # Return value
578    ///
579    /// the version as defined in the EDNS record
580    pub fn version(&self) -> u8 {
581        self.edns.as_ref().map_or(0, Edns::version)
582    }
583
584    /// [RFC 2535, Domain Name System Security Extensions, March 1999](https://tools.ietf.org/html/rfc2535#section-4)
585    ///
586    /// ```text
587    /// A DNS request may be optionally signed by including one or more SIGs
588    ///  at the end of the query. Such SIGs are identified by having a "type
589    ///  covered" field of zero. They sign the preceding DNS request message
590    ///  including DNS header but not including the IP header or any request
591    ///  SIGs at the end and before the request RR counts have been adjusted
592    ///  for the inclusions of any request SIG(s).
593    /// ```
594    ///
595    /// # Return value
596    ///
597    /// The sig0 and tsig, i.e. signed record, for verifying the sending and package integrity
598    // comportment change: can now return TSIG instead of SIG0. Maybe should get deprecated in
599    // favor of signature() which have more correct naming ?
600    pub fn sig0(&self) -> &[Record] {
601        &self.signature
602    }
603
604    /// [RFC 2535, Domain Name System Security Extensions, March 1999](https://tools.ietf.org/html/rfc2535#section-4)
605    ///
606    /// ```text
607    /// A DNS request may be optionally signed by including one or more SIGs
608    ///  at the end of the query. Such SIGs are identified by having a "type
609    ///  covered" field of zero. They sign the preceding DNS request message
610    ///  including DNS header but not including the IP header or any request
611    ///  SIGs at the end and before the request RR counts have been adjusted
612    ///  for the inclusions of any request SIG(s).
613    /// ```
614    ///
615    /// # Return value
616    ///
617    /// The sig0 and tsig, i.e. signed record, for verifying the sending and package integrity
618    pub fn signature(&self) -> &[Record] {
619        &self.signature
620    }
621
622    /// Remove signatures from the Message
623    pub fn take_signature(&mut self) -> Vec<Record> {
624        mem::take(&mut self.signature)
625    }
626
627    // TODO: only necessary in tests, should it be removed?
628    /// this is necessary to match the counts in the header from the record sections
629    ///  this happens implicitly on write_to, so no need to call before write_to
630    #[cfg(test)]
631    pub fn update_counts(&mut self) -> &mut Self {
632        self.header = update_header_counts(
633            &self.header,
634            self.truncated(),
635            HeaderCounts {
636                query_count: self.queries.len(),
637                answer_count: self.answers.len(),
638                nameserver_count: self.name_servers.len(),
639                additional_count: self.additionals.len(),
640            },
641        );
642        self
643    }
644
645    /// Attempts to read the specified number of `Query`s
646    pub fn read_queries(decoder: &mut BinDecoder<'_>, count: usize) -> ProtoResult<Vec<Query>> {
647        let mut queries = Vec::with_capacity(count);
648        for _ in 0..count {
649            queries.push(Query::read(decoder)?);
650        }
651        Ok(queries)
652    }
653
654    /// Attempts to read the specified number of records
655    ///
656    /// # Returns
657    ///
658    /// This returns a tuple of first standard Records, then a possibly associated Edns, and then finally any optionally associated SIG0 and TSIG records.
659    #[cfg_attr(not(feature = "__dnssec"), allow(unused_mut))]
660    pub fn read_records(
661        decoder: &mut BinDecoder<'_>,
662        count: usize,
663        is_additional: bool,
664    ) -> ProtoResult<(Vec<Record>, Option<Edns>, Vec<Record>)> {
665        let mut records: Vec<Record> = Vec::with_capacity(count);
666        let mut edns: Option<Edns> = None;
667        let mut sigs: Vec<Record> = Vec::with_capacity(if is_additional { 1 } else { 0 });
668
669        // sig0 must be last, once this is set, disable.
670        let mut saw_sig0 = false;
671        // tsig must be last, once this is set, disable.
672        let mut saw_tsig = false;
673        for _ in 0..count {
674            let record = Record::read(decoder)?;
675            if saw_tsig {
676                return Err("tsig must be final resource record".into());
677            } // TSIG must be last and multiple TSIG records are not allowed
678            if !is_additional {
679                if saw_sig0 {
680                    return Err("sig0 must be final resource record".into());
681                } // SIG0 must be last
682                records.push(record)
683            } else {
684                match record.record_type() {
685                    #[cfg(feature = "__dnssec")]
686                    RecordType::SIG => {
687                        saw_sig0 = true;
688                        sigs.push(record);
689                    }
690                    #[cfg(feature = "__dnssec")]
691                    RecordType::TSIG => {
692                        if saw_sig0 {
693                            return Err("sig0 must be final resource record".into());
694                        } // SIG0 must be last
695                        saw_tsig = true;
696                        sigs.push(record);
697                    }
698                    RecordType::OPT => {
699                        if saw_sig0 {
700                            return Err("sig0 must be final resource record".into());
701                        } // SIG0 must be last
702                        if edns.is_some() {
703                            return Err("more than one edns record present".into());
704                        }
705                        edns = Some((&record).into());
706                    }
707                    _ => {
708                        if saw_sig0 {
709                            return Err("sig0 must be final resource record".into());
710                        } // SIG0 must be last
711                        records.push(record);
712                    }
713                }
714            }
715        }
716
717        Ok((records, edns, sigs))
718    }
719
720    /// Decodes a message from the buffer.
721    pub fn from_vec(buffer: &[u8]) -> ProtoResult<Self> {
722        let mut decoder = BinDecoder::new(buffer);
723        Self::read(&mut decoder)
724    }
725
726    /// Encodes the Message into a buffer
727    pub fn to_vec(&self) -> Result<Vec<u8>, ProtoError> {
728        // TODO: this feels like the right place to verify the max packet size of the message,
729        //  will need to update the header for truncation and the lengths if we send less than the
730        //  full response. This needs to conform with the EDNS settings of the server...
731        let mut buffer = Vec::with_capacity(512);
732        {
733            let mut encoder = BinEncoder::new(&mut buffer);
734            self.emit(&mut encoder)?;
735        }
736
737        Ok(buffer)
738    }
739
740    /// Finalize the message prior to sending.
741    ///
742    /// Subsequent to calling this, the Message should not change.
743    #[allow(clippy::match_single_binding)]
744    pub fn finalize(
745        &mut self,
746        finalizer: &dyn MessageFinalizer,
747        inception_time: u32,
748    ) -> ProtoResult<Option<MessageVerifier>> {
749        debug!("finalizing message: {:?}", self);
750        let (finals, verifier): (Vec<Record>, Option<MessageVerifier>) =
751            finalizer.finalize_message(self, inception_time)?;
752
753        // append all records to message
754        for fin in finals {
755            match fin.record_type() {
756                // SIG0's are special, and come at the very end of the message
757                #[cfg(feature = "__dnssec")]
758                RecordType::SIG => self.add_sig0(fin),
759                #[cfg(feature = "__dnssec")]
760                RecordType::TSIG => self.add_tsig(fin),
761                _ => self.add_additional(fin),
762            };
763        }
764
765        Ok(verifier)
766    }
767
768    /// Consumes `Message` and returns into components
769    pub fn into_parts(self) -> MessageParts {
770        self.into()
771    }
772}
773
774impl From<MessageParts> for Message {
775    fn from(msg: MessageParts) -> Self {
776        let MessageParts {
777            header,
778            queries,
779            answers,
780            name_servers,
781            additionals,
782            sig0,
783            edns,
784        } = msg;
785        Self {
786            header,
787            queries,
788            answers,
789            name_servers,
790            additionals,
791            signature: sig0,
792            edns,
793        }
794    }
795}
796
797impl Deref for Message {
798    type Target = Header;
799
800    fn deref(&self) -> &Self::Target {
801        &self.header
802    }
803}
804
/// Consumes `Message` giving public access to fields in `Message` so they can be
/// destructured and taken by value
///
/// Obtained from `Message::into_parts` (or `MessageParts::from`), and convertible back
/// with `Message::from`.
///
/// # Examples
///
/// ```rust
/// use hickory_proto::op::{Message, MessageParts};
///
///  let msg = Message::new();
///  let MessageParts { queries, .. } = msg.into_parts();
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Default)]
pub struct MessageParts {
    /// message header
    pub header: Header,
    /// message queries
    pub queries: Vec<Query>,
    /// message answers
    pub answers: Vec<Record>,
    /// message name_servers
    pub name_servers: Vec<Record>,
    /// message additional records
    pub additionals: Vec<Record>,
    /// sig0 or tsig
    // this can now contain TSIG too. It should probably be renamed to reflect that, but it's a
    // breaking change
    pub sig0: Vec<Record>,
    /// optional edns records
    pub edns: Option<Edns>,
}
832
833impl From<Message> for MessageParts {
834    fn from(msg: Message) -> Self {
835        let Message {
836            header,
837            queries,
838            answers,
839            name_servers,
840            additionals,
841            signature,
842            edns,
843        } = msg;
844        Self {
845            header,
846            queries,
847            answers,
848            name_servers,
849            additionals,
850            sig0: signature,
851            edns,
852        }
853    }
854}
855
/// Tracks the counts of the records in the Message.
///
/// This is only used internally during serialization. Each count must fit in a `u16`
/// when written to the header; `update_header_counts` panics otherwise.
#[derive(Clone, Copy, Debug)]
pub struct HeaderCounts {
    /// The number of queries in the Message
    pub query_count: usize,
    /// The number of answers in the Message
    pub answer_count: usize,
    /// The number of nameservers or authorities in the Message
    pub nameserver_count: usize,
    /// The number of additional records in the Message
    pub additional_count: usize,
}
870
871/// Returns a new Header with accurate counts for each Message section
872pub fn update_header_counts(
873    current_header: &Header,
874    is_truncated: bool,
875    counts: HeaderCounts,
876) -> Header {
877    assert!(counts.query_count <= u16::MAX as usize);
878    assert!(counts.answer_count <= u16::MAX as usize);
879    assert!(counts.nameserver_count <= u16::MAX as usize);
880    assert!(counts.additional_count <= u16::MAX as usize);
881
882    // TODO: should the function just take by value?
883    let mut header = *current_header;
884    header
885        .set_query_count(counts.query_count as u16)
886        .set_answer_count(counts.answer_count as u16)
887        .set_name_server_count(counts.nameserver_count as u16)
888        .set_additional_count(counts.additional_count as u16)
889        .set_truncated(is_truncated);
890
891    header
892}
893
/// Alias for a function verifying if a message is properly signed
///
/// A value of this type may be returned alongside the signature records by
/// [`MessageFinalizer::finalize_message`]. It is handed a byte slice (presumably the raw
/// response to check — NOTE(review): confirm against callers) and yields either a decoded
/// [`DnsResponse`] or an error. Being `FnMut + Send`, it may keep mutable state between calls
/// and can be moved to another thread.
pub type MessageVerifier = Box<dyn FnMut(&[u8]) -> ProtoResult<DnsResponse> + Send>;
896
897/// A trait for performing final amendments to a Message before it is sent.
898///
899/// An example of this is a SIG0 signer, which needs the final form of the message,
900///  but then needs to attach additional data to the body of the message.
901pub trait MessageFinalizer: Send + Sync + 'static {
902    /// The message taken in should be processed and then return [`Record`]s which should be
903    ///  appended to the additional section of the message.
904    ///
905    /// # Arguments
906    ///
907    /// * `message` - message to process
908    /// * `current_time` - the current time as specified by the system, it's not recommended to read the current time as that makes testing complicated.
909    ///
910    /// # Return
911    ///
912    /// A vector to append to the additionals section of the message, sorted in the order as they should appear in the message.
913    fn finalize_message(
914        &self,
915        message: &Message,
916        current_time: u32,
917    ) -> ProtoResult<(Vec<Record>, Option<MessageVerifier>)>;
918
919    /// Return whether the message requires further processing before being sent
920    /// By default, returns true for AXFR and IXFR queries, and Update and Notify messages
921    fn should_finalize_message(&self, message: &Message) -> bool {
922        [OpCode::Update, OpCode::Notify].contains(&message.op_code())
923            || message
924                .queries()
925                .iter()
926                .any(|q| [RecordType::AXFR, RecordType::IXFR].contains(&q.query_type()))
927    }
928}
929
930/// Returns the count written and a boolean if it was truncated
931pub fn count_was_truncated(result: ProtoResult<usize>) -> ProtoResult<(usize, bool)> {
932    match result {
933        Ok(count) => Ok((count, false)),
934        Err(e) => match e.kind() {
935            ProtoErrorKind::NotAllRecordsWritten { count } => Ok((*count, true)),
936            _ => Err(e),
937        },
938    }
939}
940
/// A trait that defines types which can be emitted as a set, with the associated count returned.
///
/// Any iterator over references to [`BinEncodable`] items implements this through the blanket
/// impl that follows. [`emit_message_parts`] uses it for the query, answer, name server and
/// additional sections.
pub trait EmitAndCount {
    /// Emit self to the encoder and return the count of items
    ///
    /// Takes `&mut self` so that an iterator can be drained in place.
    fn emit(&mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<usize>;
}
946
947impl<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable> EmitAndCount for I {
948    fn emit(&mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<usize> {
949        encoder.emit_all(self)
950    }
951}
952
953/// Emits the different sections of a message properly
954///
955/// # Return
956///
957/// In the case of a successful emit, the final header (updated counts, etc) is returned for help with logging, etc.
958#[allow(clippy::too_many_arguments)]
959pub fn emit_message_parts<Q, A, N, D>(
960    header: &Header,
961    queries: &mut Q,
962    answers: &mut A,
963    name_servers: &mut N,
964    additionals: &mut D,
965    edns: Option<&Edns>,
966    signature: &[Record],
967    encoder: &mut BinEncoder<'_>,
968) -> ProtoResult<Header>
969where
970    Q: EmitAndCount,
971    A: EmitAndCount,
972    N: EmitAndCount,
973    D: EmitAndCount,
974{
975    let include_signature = encoder.mode() != EncodeMode::Signing;
976    let place = encoder.place::<Header>()?;
977
978    let query_count = queries.emit(encoder)?;
979    // TODO: need to do something on max records
980    //  return offset of last emitted record.
981    let answer_count = count_was_truncated(answers.emit(encoder))?;
982    let nameserver_count = count_was_truncated(name_servers.emit(encoder))?;
983    let mut additional_count = count_was_truncated(additionals.emit(encoder))?;
984
985    if let Some(mut edns) = edns.cloned() {
986        // need to commit the error code
987        edns.set_rcode_high(header.response_code().high());
988
989        let count = count_was_truncated(encoder.emit_all(iter::once(&Record::from(&edns))))?;
990        additional_count.0 += count.0;
991        additional_count.1 |= count.1;
992    } else if header.response_code().high() > 0 {
993        warn!(
994            "response code: {} for request: {} requires EDNS but none available",
995            header.response_code(),
996            header.id()
997        );
998    }
999
1000    // this is a little hacky, but if we are Verifying a signature, i.e. the original Message
1001    //  then the SIG0 records should not be encoded and the edns record (if it exists) is already
1002    //  part of the additionals section.
1003    if include_signature {
1004        let count = count_was_truncated(encoder.emit_all(signature.iter()))?;
1005        additional_count.0 += count.0;
1006        additional_count.1 |= count.1;
1007    }
1008
1009    let counts = HeaderCounts {
1010        query_count,
1011        answer_count: answer_count.0,
1012        nameserver_count: nameserver_count.0,
1013        additional_count: additional_count.0,
1014    };
1015    let was_truncated =
1016        header.truncated() || answer_count.1 || nameserver_count.1 || additional_count.1;
1017
1018    let final_header = update_header_counts(header, was_truncated, counts);
1019    place.replace(encoder, final_header)?;
1020    Ok(final_header)
1021}
1022
1023impl BinEncodable for Message {
1024    fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
1025        emit_message_parts(
1026            &self.header,
1027            &mut self.queries.iter(),
1028            &mut self.answers.iter(),
1029            &mut self.name_servers.iter(),
1030            &mut self.additionals.iter(),
1031            self.edns.as_ref(),
1032            &self.signature,
1033            encoder,
1034        )?;
1035
1036        Ok(())
1037    }
1038}
1039
1040impl<'r> BinDecodable<'r> for Message {
1041    fn read(decoder: &mut BinDecoder<'r>) -> ProtoResult<Self> {
1042        let mut header = Header::read(decoder)?;
1043
1044        // TODO: return just header, and in the case of the rest of message getting an error.
1045        //  this could improve error detection while decoding.
1046
1047        // get the questions
1048        let count = header.query_count() as usize;
1049        let mut queries = Vec::with_capacity(count);
1050        for _ in 0..count {
1051            queries.push(Query::read(decoder)?);
1052        }
1053
1054        // get all counts before header moves
1055        let answer_count = header.answer_count() as usize;
1056        let name_server_count = header.name_server_count() as usize;
1057        let additional_count = header.additional_count() as usize;
1058
1059        let (answers, _, _) = Self::read_records(decoder, answer_count, false)?;
1060        let (name_servers, _, _) = Self::read_records(decoder, name_server_count, false)?;
1061        let (additionals, edns, signature) = Self::read_records(decoder, additional_count, true)?;
1062
1063        // need to grab error code from EDNS (which might have a higher value)
1064        if let Some(edns) = &edns {
1065            let high_response_code = edns.rcode_high();
1066            header.merge_response_code(high_response_code);
1067        }
1068
1069        Ok(Self {
1070            header,
1071            queries,
1072            answers,
1073            name_servers,
1074            additionals,
1075            signature,
1076            edns,
1077        })
1078    }
1079}
1080
1081impl fmt::Display for Message {
1082    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
1083        let write_query = |slice, f: &mut fmt::Formatter<'_>| -> Result<(), fmt::Error> {
1084            for d in slice {
1085                writeln!(f, ";; {d}")?;
1086            }
1087
1088            Ok(())
1089        };
1090
1091        let write_slice = |slice, f: &mut fmt::Formatter<'_>| -> Result<(), fmt::Error> {
1092            for d in slice {
1093                writeln!(f, "{d}")?;
1094            }
1095
1096            Ok(())
1097        };
1098
1099        writeln!(f, "; header {header}", header = self.header())?;
1100
1101        if let Some(edns) = self.extensions() {
1102            writeln!(f, "; edns {edns}")?;
1103        }
1104
1105        writeln!(f, "; query")?;
1106        write_query(self.queries(), f)?;
1107
1108        if self.header().message_type() == MessageType::Response
1109            || self.header().op_code() == OpCode::Update
1110        {
1111            writeln!(f, "; answers {}", self.answer_count())?;
1112            write_slice(self.answers(), f)?;
1113            writeln!(f, "; nameservers {}", self.name_server_count())?;
1114            write_slice(self.name_servers(), f)?;
1115            writeln!(f, "; additionals {}", self.additional_count())?;
1116            write_slice(self.additionals(), f)?;
1117        }
1118
1119        Ok(())
1120    }
1121}
1122
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_emit_and_read_header() {
        let mut message = Message::new();
        message
            .set_id(10)
            .set_message_type(MessageType::Response)
            .set_op_code(OpCode::Update)
            .set_authoritative(true)
            .set_truncated(false)
            .set_recursion_desired(true)
            .set_recursion_available(true)
            .set_response_code(ResponseCode::ServFail);

        test_emit_and_read(message);
    }

    #[test]
    fn test_emit_and_read_query() {
        let mut message = Message::new();
        message
            .set_id(10)
            .set_message_type(MessageType::Response)
            .set_op_code(OpCode::Update)
            .set_authoritative(true)
            .set_truncated(true)
            .set_recursion_desired(true)
            .set_recursion_available(true)
            .set_response_code(ResponseCode::ServFail)
            .add_query(Query::new())
            .update_counts(); // we're not testing the query parsing, just message

        test_emit_and_read(message);
    }

    #[test]
    fn test_emit_and_read_records() {
        let mut message = Message::new();
        message
            .set_id(10)
            .set_message_type(MessageType::Response)
            .set_op_code(OpCode::Update)
            .set_authoritative(true)
            .set_truncated(true)
            .set_recursion_desired(true)
            .set_recursion_available(true)
            .set_authentic_data(true)
            .set_checking_disabled(true)
            .set_response_code(ResponseCode::ServFail);

        message.add_answer(Record::stub());
        message.add_name_server(Record::stub());
        message.add_additional(Record::stub());
        message.update_counts();

        test_emit_and_read(message);
    }

    /// Encodes `message`, decodes the bytes back, and asserts that the round trip is lossless.
    fn test_emit_and_read(message: Message) {
        let got = emit_and_read(&message);

        assert_eq!(got, message);
    }

    #[test]
    fn test_header_counts_correction_after_emit_read() {
        let mut message = Message::new();

        message
            .set_id(10)
            .set_message_type(MessageType::Response)
            .set_op_code(OpCode::Update)
            .set_authoritative(true)
            .set_truncated(true)
            .set_recursion_desired(true)
            .set_recursion_available(true)
            .set_authentic_data(true)
            .set_checking_disabled(true)
            .set_response_code(ResponseCode::ServFail);

        message.add_answer(Record::stub());
        message.add_name_server(Record::stub());
        message.add_additional(Record::stub());

        // at here, we don't call update_counts and we even set wrong counts (the additional
        // count is left untouched), because we are trying to test whether the counts in the
        // header are correct after the message is emitted and read.
        message.set_query_count(1);
        message.set_answer_count(5);
        message.set_name_server_count(5);

        let got = emit_and_read(&message);

        // make comparison
        assert_eq!(got.query_count(), 0);
        assert_eq!(got.answer_count(), 1);
        assert_eq!(got.name_server_count(), 1);
        assert_eq!(got.additional_count(), 1);
    }

    /// Encodes `message` with a fresh [`BinEncoder`] and decodes the bytes into a new
    /// [`Message`], panicking if either step fails.
    fn emit_and_read(message: &Message) -> Message {
        let mut byte_vec: Vec<u8> = Vec::with_capacity(512);
        {
            let mut encoder = BinEncoder::new(&mut byte_vec);
            message.emit(&mut encoder).unwrap();
        }

        let mut decoder = BinDecoder::new(&byte_vec);

        Message::read(&mut decoder).unwrap()
    }

    #[test]
    fn test_legit_message() {
        #[rustfmt::skip]
        let buf: Vec<u8> = vec![
            0x10, 0x00, 0x81,
            0x80, // id = 4096, response, op=query, recursion_desired, recursion_available, no_error
            0x00, 0x01, 0x00, 0x01, // 1 query, 1 answer,
            0x00, 0x00, 0x00, 0x00, // 0 nameservers, 0 additional record
            0x03, b'w', b'w', b'w', // query --- www.example.com
            0x07, b'e', b'x', b'a', //
            b'm', b'p', b'l', b'e', //
            0x03, b'c', b'o', b'm', //
            0x00,                   // 0 = endname
            0x00, 0x01, 0x00, 0x01, // RecordType = A, Class = IN
            0xC0, 0x0C,             // name pointer to www.example.com
            0x00, 0x01, 0x00, 0x01, // RecordType = A, Class = IN
            0x00, 0x00, 0x00, 0x02, // TTL = 2 seconds
            0x00, 0x04,             // record length = 4 (ipv4 address)
            0x5D, 0xB8, 0xD7, 0x0E, // address = 93.184.215.14
        ];

        let mut decoder = BinDecoder::new(&buf);
        let message = Message::read(&mut decoder).unwrap();

        assert_eq!(message.id(), 4_096);

        let mut buf: Vec<u8> = Vec::with_capacity(512);
        {
            let mut encoder = BinEncoder::new(&mut buf);
            message.emit(&mut encoder).unwrap();
        }

        let mut decoder = BinDecoder::new(&buf);
        let message = Message::read(&mut decoder).unwrap();

        assert_eq!(message.id(), 4_096);
    }

    #[test]
    fn rdata_zero_roundtrip() {
        let buf = &[
            160, 160, 0, 13, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0,
        ];

        assert!(Message::from_bytes(buf).is_err());
    }

    #[test]
    fn nsec_deserialization() {
        const CRASHING_MESSAGE: &[u8] = &[
            0, 0, 132, 0, 0, 0, 0, 1, 0, 0, 0, 1, 36, 49, 101, 48, 101, 101, 51, 100, 51, 45, 100,
            52, 50, 52, 45, 52, 102, 55, 56, 45, 57, 101, 52, 99, 45, 99, 51, 56, 51, 51, 55, 55,
            56, 48, 102, 50, 98, 5, 108, 111, 99, 97, 108, 0, 0, 1, 128, 1, 0, 0, 0, 120, 0, 4,
            192, 168, 1, 17, 36, 49, 101, 48, 101, 101, 51, 100, 51, 45, 100, 52, 50, 52, 45, 52,
            102, 55, 56, 45, 57, 101, 52, 99, 45, 99, 51, 56, 51, 51, 55, 55, 56, 48, 102, 50, 98,
            5, 108, 111, 99, 97, 108, 0, 0, 47, 128, 1, 0, 0, 0, 120, 0, 5, 192, 70, 0, 1, 64,
        ];

        Message::from_vec(CRASHING_MESSAGE).expect("failed to parse message");
    }

    #[test]
    fn prior_to_pointer() {
        const MESSAGE: &[u8] = include_bytes!("../../tests/test-data/fuzz-prior-to-pointer.rdata");
        let message = Message::from_bytes(MESSAGE).expect("failed to parse message");
        let encoded = message.to_bytes().unwrap();
        Message::from_bytes(&encoded).expect("failed to parse encoded message");
    }
}