flex_dns/
lib.rs

1#![no_std]
2#![feature(generic_const_exprs)]
3
4use crate::additional::DnsAdditionals;
5use crate::answer::DnsAnswers;
6
7use crate::header::DnsHeader;
8use crate::name::DnsName;
9use crate::name_servers::DnsNameServers;
10use crate::parse::Parse;
11use crate::question::DnsQuestions;
12
13pub mod header;
14pub mod name;
15pub mod characters;
16pub mod question;
17pub mod name_servers;
18pub mod additional;
19pub mod answer;
20pub mod rdata;
21pub mod buffer;
22mod parse;
23mod write;
24
25pub use buffer::{Buffer, MutBuffer};
26
/// Any error produced while reading or writing a DNS message.
///
/// Wraps either a problem with the DNS data itself ([`DnsError`]) or a failure
/// reported by the underlying buffer ([`BufferError`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DnsMessageError {
    /// An error about the DNS data (see [`DnsError`]).
    DnsError(DnsError),
    /// An error reported by a buffer operation (see [`BufferError`]).
    BufferError(BufferError),
}

impl From<DnsError> for DnsMessageError {
    fn from(e: DnsError) -> Self {
        DnsMessageError::DnsError(e)
    }
}

impl From<BufferError> for DnsMessageError {
    fn from(e: BufferError) -> Self {
        DnsMessageError::BufferError(e)
    }
}

/// Errors describing problems with the DNS data being parsed or written.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DnsError {
    /// The buffer is shorter than the fixed-size DNS header.
    MessageTooShort,
    InvalidHeader,
    InvalidQuestion,
    InvalidAnswer,
    InvalidAuthority,
    InvalidAdditional,
    PointerIntoTheFuture,
    PointerCycle,
    NameTooLong,
    LabelTooLong,
    CharacterStringTooLong,
    CharacterStringInvalidLength,
    RDataLongerThanMessage,
    /// Parsing needed more bytes than the buffer contains (e.g. a header that
    /// announces a question which is not present).
    UnexpectedEndOfBuffer,
    InvalidTxtRecord,
}

/// Errors returned by operations on the underlying message buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BufferError {
    OutOfMemory,
    LengthOutOfBounds,
    InvalidLength,
    OffsetOutOfBounds,
}
71
/// Length in bytes of the fixed DNS header: six 16-bit fields
/// (ID, flags, and the question/answer/authority/additional counts).
const DNS_HEADER_SIZE: usize = 6 * core::mem::size_of::<u16>();
73
/// A DNS message stored in the buffer `B`.
///
/// * `PTR_STORAGE` — maximum number of name offsets remembered for reuse as
///   compression pointers while writing.
/// * `DNS_SECTION` — type-level marker for the section the message is
///   positioned at: 0 = questions, 1 = answers, 2 = name servers,
///   3 = additionals.
pub struct DnsMessage<const PTR_STORAGE: usize, const DNS_SECTION: usize, B> {
    /// Storage holding the wire-format bytes.
    buffer: B,
    /// Current read/write offset into `buffer`.
    position: usize,
    /// Offsets, measured from the start of `buffer`, of names written so far.
    /// Reading never needs these; writing uses them to emit compressed pointers.
    ptr_storage: [usize; PTR_STORAGE],
    /// Number of valid entries at the front of `ptr_storage`.
    ptr_len: usize,
}
87
88macro_rules! to_section_impl {
89    ($from:expr, $to:expr) => {
90        impl<
91            const PTR_STORAGE: usize,
92            B: Buffer,
93        > DnsMessage<PTR_STORAGE, { $from }, B> {
94            #[inline]
95            pub fn next_section(self) -> DnsMessage<PTR_STORAGE, { $to }, B> {
96                DnsMessage {
97                    buffer: self.buffer,
98                    position: self.position,
99                    ptr_storage: self.ptr_storage,
100                    ptr_len: self.ptr_len,
101                }
102            }
103        }
104    };
105}
106
107to_section_impl!(0, 1);
108to_section_impl!(1, 2);
109to_section_impl!(2, 3);
110
111impl<
112    const PTR_STORAGE: usize,
113    const SECTION: usize,
114    B: Buffer,
115> DnsMessage<PTR_STORAGE, SECTION, B> {
116    /// Creates a new DNS message with the given buffer.
117    #[inline(always)]
118    pub fn new(buffer: B) -> Result<Self, DnsMessageError> {
119        if buffer.len() < DNS_HEADER_SIZE {
120            return Err(DnsMessageError::DnsError(DnsError::MessageTooShort));
121        }
122
123        Ok(Self {
124            buffer,
125            position: DNS_HEADER_SIZE,
126            ptr_storage: [0; PTR_STORAGE],
127            ptr_len: 0,
128        })
129    }
130
131    /// Resets the message to the start of the buffer.
132    #[inline(always)]
133    pub fn reset(self) -> DnsMessage<PTR_STORAGE, 0, B> {
134        DnsMessage {
135            buffer: self.buffer,
136            position: 0,
137            ptr_storage: self.ptr_storage,
138            ptr_len: self.ptr_len,
139        }
140    }
141
142    /// Aborts the message and returns the buffer.
143    #[inline(always)]
144    pub fn abort(self) -> Result<B, DnsMessageError> {
145        Ok(self.buffer)
146    }
147
148    /// Returns the header of the message (read-only reference).
149    #[inline(always)]
150    pub fn header(&self) -> Result<&DnsHeader, DnsMessageError> {
151        if self.buffer.len() < DNS_HEADER_SIZE {
152            return Err(DnsMessageError::DnsError(DnsError::MessageTooShort));
153        }
154
155        Ok(DnsHeader::from_bytes(
156            self.buffer.read_bytes_at(0, DNS_HEADER_SIZE)?
157        ))
158    }
159
160    #[inline(always)]
161    pub(crate) fn bytes_and_position(&mut self) -> (&[u8], &mut usize) {
162        (self.buffer.bytes(), &mut self.position)
163    }
164}
165
166impl<
167    const PTR_STORAGE: usize,
168    const SECTION: usize,
169    B: MutBuffer + Buffer,
170> DnsMessage<PTR_STORAGE, SECTION, B> {
171    /// Creates a new DNS message with the given buffer.
172    #[inline(always)]
173    pub fn new_mut(mut buffer: B) -> Result<Self, DnsMessageError> {
174        if buffer.len() < DNS_HEADER_SIZE {
175            buffer.write_bytes(&[0; DNS_HEADER_SIZE])?;
176        }
177
178        Ok(Self {
179            buffer,
180            position: DNS_HEADER_SIZE,
181            ptr_storage: [0; PTR_STORAGE],
182            ptr_len: 0,
183        })
184    }
185
186    /// Returns the header of the message as a mutable reference.
187    #[inline(always)]
188    pub fn header_mut(&mut self) -> Result<&mut DnsHeader, DnsMessageError> {
189        self.position = core::cmp::max(self.position, DNS_HEADER_SIZE);
190        Ok(DnsHeader::from_bytes_mut(
191            self.buffer.read_bytes_at_mut(0, DNS_HEADER_SIZE)?
192        ))
193    }
194
195    #[inline(always)]
196    pub(crate) fn write_bytes(&mut self, bytes: &[u8]) -> Result<usize, DnsMessageError> {
197        self.position += bytes.len();
198        self.buffer.write_bytes(bytes)?;
199
200        Ok(bytes.len())
201    }
202
203    #[inline(always)]
204    pub(crate) fn truncate(&mut self) -> Result<(), DnsMessageError> {
205        self.buffer.truncate(self.position)?;
206
207        Ok(())
208    }
209
210    #[inline(always)]
211    pub(crate) fn write_placeholder<const SIZE: usize>(&mut self) -> Result<impl Fn(&mut Self, [u8; SIZE]) -> usize, DnsMessageError> {
212        let placeholder_pos = self.position;
213        self.position += SIZE;
214        self.buffer.write_bytes(&[0; SIZE])?;
215
216        Ok(move |message: &mut DnsMessage<PTR_STORAGE, SECTION, B>, bytes: [u8; SIZE]| {
217            message.buffer.write_array_at(placeholder_pos, bytes).unwrap();
218
219            SIZE
220        })
221    }
222
223    pub(crate) fn write_name(
224        &mut self,
225        name: DnsName,
226    ) -> Result<usize, DnsMessageError> {
227        // Try to find match
228        for &idx in &self.ptr_storage[..self.ptr_len] {
229            let mut i = idx;
230            let name_at_idx = DnsName::parse(self.buffer.bytes(), &mut i)?;
231            if name_at_idx == name {
232                return Ok(self.write_bytes(&(idx as u16 | 0b1100_0000_0000_0000).to_be_bytes())?);
233            }
234        }
235
236        // No match found, write name
237        let (first, rest) = name.split_first()?;
238        let original_position = self.position;
239        let mut bytes_written = 0;
240        bytes_written += self.write_bytes(&[first.len() as u8])?;
241        bytes_written += self.write_bytes(first)?;
242
243        if let Some(rest) = rest {
244            bytes_written += self.write_name(rest)?;
245        } else {
246            bytes_written += self.write_bytes(&[0])?; // Null terminator
247        }
248        if self.ptr_len < PTR_STORAGE {
249            // Store pointer for later, if we have space
250            // If we dont have space, we just write the name uncompressed
251            // in the future
252            self.ptr_storage[self.ptr_len] = original_position;
253            self.ptr_len += 1;
254        }
255
256        Ok(bytes_written)
257    }
258}
259
260impl<
261    const PTR_STORAGE: usize,
262    B: Buffer,
263> DnsMessage<PTR_STORAGE, 0, B> {
264    /// Read or write questions in the message.
265    #[inline(always)]
266    pub fn questions(self) -> DnsQuestions<PTR_STORAGE, B> {
267        DnsQuestions::new(self)
268    }
269
270    /// Completes and verifies the message and returns the buffer.
271    #[inline(always)]
272    pub fn complete(self) -> Result<(B, usize), DnsMessageError> {
273        // Read the full packet.
274        let questions = self.questions();
275        let message = questions.complete()?;
276        let answers = message.answers();
277        let message = answers.complete()?;
278        let name_servers = message.name_servers();
279        let message = name_servers.complete()?;
280        let additionals = message.additionals();
281        let message = additionals.complete()?;
282
283        Ok((message.buffer, message.position))
284    }
285}
286
287impl<
288    const PTR_STORAGE: usize,
289    B: Buffer,
290> DnsMessage<PTR_STORAGE, 1, B> {
291    /// Read or write answers in the message.
292    pub fn answers(self) -> DnsAnswers<PTR_STORAGE, B> {
293        DnsAnswers::new(self)
294    }
295
296    /// Completes and verifies the message and returns the buffer.
297    #[inline(always)]
298    pub fn complete(self) -> Result<(B, usize), DnsMessageError> {
299        // Read the full packet.
300        let answers = self.answers();
301        let message = answers.complete()?;
302        let name_servers = message.name_servers();
303        let message = name_servers.complete()?;
304        let additionals = message.additionals();
305        let message = additionals.complete()?;
306
307        Ok((message.buffer, message.position))
308    }
309}
310
311impl<
312    const PTR_STORAGE: usize,
313    B: Buffer,
314> DnsMessage<PTR_STORAGE, 2, B> {
315    /// Read or write name servers in the message.
316    pub fn name_servers(self) -> DnsNameServers<PTR_STORAGE, B> {
317        DnsNameServers::new(self)
318    }
319
320    /// Completes and verifies the message and returns the buffer.
321    #[inline(always)]
322    pub fn complete(self) -> Result<(B, usize), DnsMessageError> {
323        // Read the full packet.
324        let name_servers = self.name_servers();
325        let message = name_servers.complete()?;
326        let additionals = message.additionals();
327        let message = additionals.complete()?;
328
329        Ok((message.buffer, message.position))
330    }
331}
332
333impl<
334    const PTR_STORAGE: usize,
335    B: Buffer,
336> DnsMessage<PTR_STORAGE, 3, B> {
337    /// Read or write additionals in the message.
338    pub fn additionals(self) -> DnsAdditionals<PTR_STORAGE, B> {
339        DnsAdditionals::new(self)
340    }
341
342    /// Completes and verifies the message and returns the buffer.
343    #[inline(always)]
344    pub fn complete(self) -> Result<(B, usize), DnsMessageError> {
345        // Read the full packet.
346        let additionals = self.additionals();
347        let message = additionals.complete()?;
348
349        Ok((message.buffer, message.position))
350    }
351}
352
#[cfg(any(feature = "heapless", feature = "arrayvec", feature = "vec"))]
#[cfg(test)]
mod test {
    use super::*;

    mod question {
        use crate::header::{DnsHeaderOpcode, DnsHeaderResponseCode};
        use crate::question::{DnsQClass, DnsQType, DnsQuestion};
        use super::*;

        #[cfg(feature = "heapless")]
        mod test_heapless {
            use heapless::Vec;
            use super::*;

            #[test]
            fn test_question_heapless() {
                test_question(Vec::<u8, 512>::new())
            }

            #[test]
            fn test_question_heapless_mut() {
                test_question(&mut Vec::<u8, 512>::new())
            }
        }

        #[cfg(feature = "arrayvec")]
        mod test_arrayvec {
            use arrayvec::ArrayVec;
            use super::*;

            #[test]
            fn test_question_arrayvec() {
                test_question(ArrayVec::<u8, 512>::new())
            }

            #[test]
            fn test_question_arrayvec_mut() {
                test_question(&mut ArrayVec::<u8, 512>::new())
            }

            #[test]
            fn query_google_com() {
                let buffer = ArrayVec::from([
                    0x00, 0x03, // ID
                    0x01, 0x00, // Flags
                    0x00, 0x01, // Question count
                    0x00, 0x00, // Answer count
                    0x00, 0x00, // Authority count
                    0x00, 0x00, // Additional count
                    0x06, b'g', b'o', b'o', b'g', b'l', b'e', // Name
                    0x03, b'c', b'o', b'm', // Name
                    0x00, // Name
                    0x00, 0x01, // Type
                    0x00, 0x01, // Class
                ]);
                let message: DnsMessage<8, 0, _> = DnsMessage::new(buffer).unwrap();
                let header = message.header().unwrap();
                assert_eq!(header.id(), 0x0003);
                assert_eq!(header.opcode(), DnsHeaderOpcode::Query);
                assert!(!header.authoritative_answer());
                assert!(!header.truncated());
                assert!(header.recursion_desired());
                assert!(!header.recursion_available());
                assert_eq!(header.response_code(), DnsHeaderResponseCode::NoError);
                let mut questions = message.questions();
                let mut question_iter = questions.iter().unwrap();
                let question = question_iter.next().unwrap().unwrap();
                assert_eq!(question.name, DnsName::new(b"\x06google\x03com\x00").unwrap());
                assert_eq!(question.qtype, DnsQType::A);
                assert_eq!(question.qclass, DnsQClass::IN);
                assert!(question_iter.next().is_none());
            }

            #[test]
            fn query_google_com_and_garbage() {
                let buffer = ArrayVec::from([
                    0x00, 0x03, // ID
                    0x01, 0x00, // Flags
                    0x00, 0x01, // Question count
                    0x00, 0x00, // Answer count
                    0x00, 0x00, // Authority count
                    0x00, 0x00, // Additional count
                    0x06, b'g', b'o', b'o', b'g', b'l', b'e', // Name
                    0x03, b'c', b'o', b'm', // Name
                    0x00, // Name
                    0x00, 0x01, // Type
                    0x00, 0x01, // Class
                    0x15, 0x16, 0x17, 0x18, // Garbage
                ]);
                let message: DnsMessage<8, 0, _> = DnsMessage::new(buffer).unwrap();
                let header = message.header().unwrap();
                assert_eq!(header.id(), 0x0003);
                assert_eq!(header.opcode(), DnsHeaderOpcode::Query);
                assert!(!header.authoritative_answer());
                assert!(!header.truncated());
                assert!(header.recursion_desired());
                assert!(!header.recursion_available());
                assert_eq!(header.response_code(), DnsHeaderResponseCode::NoError);
                let mut questions = message.questions();
                let mut question_iter = questions.iter().unwrap();
                let question = question_iter.next().unwrap().unwrap();
                assert_eq!(question.name, DnsName::new(b"\x06google\x03com\x00").unwrap());
                assert_eq!(question.qtype, DnsQType::A);
                assert_eq!(question.qclass, DnsQClass::IN);
                assert!(question_iter.next().is_none());
                let message = questions.complete().unwrap();
                let (buffer, pos) = message.complete().unwrap();
                // Trailing bytes after the last section are left untouched.
                assert_eq!(buffer[pos..], [0x15, 0x16, 0x17, 0x18]);
            }

            #[test]
            fn multiple_questions_compression() {
                let buffer: ArrayVec<u8, 512> = ArrayVec::new();
                let mut message: DnsMessage<8, 0, _> = DnsMessage::new_mut(buffer).unwrap();
                let header = message.header_mut().unwrap();
                header.set_id(0x1234);
                header.set_opcode(DnsHeaderOpcode::Query);
                header.set_authoritative_answer(false);
                header.set_truncated(false);
                header.set_recursion_desired(false);
                header.set_recursion_available(false);
                header.set_response_code(DnsHeaderResponseCode::NoError);
                let mut questions = message.questions();
                questions.append(DnsQuestion {
                    name: DnsName::new(b"\x03www\x07example\x03com\x00").unwrap(),
                    qtype: DnsQType::A,
                    qclass: DnsQClass::IN,
                }).unwrap();
                questions.append(DnsQuestion {
                    name: DnsName::new(b"\x03www\x07example\x03com\x00").unwrap(),
                    qtype: DnsQType::AAAA,
                    qclass: DnsQClass::IN,
                }).unwrap();
                questions.append(DnsQuestion {
                    name: DnsName::new(b"\x03www\x07example\x03com\x00").unwrap(),
                    qtype: DnsQType::MX,
                    qclass: DnsQClass::IN,
                }).unwrap();
                questions.append(DnsQuestion {
                    name: DnsName::new(b"\x03www\x08examples\x03com\x00").unwrap(),
                    qtype: DnsQType::TXT,
                    qclass: DnsQClass::IN,
                }).unwrap();
                questions.append(DnsQuestion {
                    name: DnsName::new(b"\x08examples\x03com\x00").unwrap(),
                    qtype: DnsQType::CERT,
                    qclass: DnsQClass::IN,
                }).unwrap();
                let message = questions.complete().unwrap();
                let buffer = message.abort().unwrap();

                assert_eq!(
                    buffer.as_slice(),
                    [
                        0x12, 0x34, // ID
                        0b0000_0000, 0b0000_0000, // Flags
                        0x00, 0x05, // Question count
                        0x00, 0x00, // Answer count
                        0x00, 0x00, // Authority count
                        0x00, 0x00, // Additional count
                        0x03, b'w', b'w', b'w', // Name
                        0x07, b'e', b'x', b'a', b'm', b'p', b'l', b'e', // Name
                        0x03, b'c', b'o', b'm', // Name
                        0x00, // Name
                        0x00, 0x01, // Type
                        0x00, 0x01, // Class
                        0xC0, 0x0C, // Name Pointer (0x0C = 12)
                        0x00, 0x1C, // Type
                        0x00, 0x01, // Class
                        0xC0, 0x0C, // Name Pointer (0x0C = 12)
                        0x00, 0x0F, // Type
                        0x00, 0x01, // Class
                        0x03, b'w', b'w', b'w', // Name
                        0x08, b'e', b'x', b'a', b'm', b'p', b'l', b'e', b's', // Name
                        0xC0, 0x18, // Name Pointer (0x18 = 24)
                        0x00, 0x10, // Type
                        0x00, 0x01, // Class
                        0xC0, 0x31, // Name Pointer (0x31 = 49, the "examples" label)
                        0x00, 0x25, // Type
                        0x00, 0x01, // Class
                    ].as_slice()
                );

                // Decode the message again and check that it is the same
                let message: DnsMessage<8, 0, _> = DnsMessage::new(buffer).unwrap();
                let header = message.header().unwrap();
                assert_eq!(header.id(), 0x1234);
                assert_eq!(header.opcode(), DnsHeaderOpcode::Query);
                assert!(!header.authoritative_answer());
                assert!(!header.truncated());
                assert!(!header.recursion_desired());
                assert!(!header.recursion_available());
                assert_eq!(header.response_code(), DnsHeaderResponseCode::NoError);
                let mut questions = message.questions();
                let mut question_iter = questions.iter().unwrap();
                let question = question_iter.next().unwrap().unwrap();
                assert_eq!(question.name, DnsName::new(b"\x03www\x07example\x03com\x00").unwrap());
                assert_eq!(question.qtype, DnsQType::A);
                assert_eq!(question.qclass, DnsQClass::IN);
                let question = question_iter.next().unwrap().unwrap();
                assert_eq!(question.name, DnsName::new(b"\x03www\x07example\x03com\x00").unwrap());
                assert_eq!(question.qtype, DnsQType::AAAA);
                assert_eq!(question.qclass, DnsQClass::IN);
                let question = question_iter.next().unwrap().unwrap();
                assert_eq!(question.name, DnsName::new(b"\x03www\x07example\x03com\x00").unwrap());
                assert_eq!(question.qtype, DnsQType::MX);
                assert_eq!(question.qclass, DnsQClass::IN);
                let question = question_iter.next().unwrap().unwrap();
                assert_eq!(question.name, DnsName::new(b"\x03www\x08examples\x03com\x00").unwrap());
                assert_eq!(question.qtype, DnsQType::TXT);
                assert_eq!(question.qclass, DnsQClass::IN);
                let question = question_iter.next().unwrap().unwrap();
                assert_eq!(question.name, DnsName::new(b"\x08examples\x03com\x00").unwrap());
                assert_eq!(question.qtype, DnsQType::CERT);
                assert_eq!(question.qclass, DnsQClass::IN);
                assert!(question_iter.next().is_none());
            }
        }

        #[cfg(feature = "vec")]
        mod test_alloc {
            extern crate alloc;

            use alloc::vec::Vec;
            use super::*;

            #[test]
            fn test_question_vec() {
                test_question(Vec::<u8>::new())
            }

            #[test]
            fn test_question_vec_mut() {
                test_question(&mut Vec::<u8>::new())
            }
        }

        /// Writes a single-question message into `buffer`, checks the exact
        /// wire bytes, then decodes it again and checks every field.
        fn test_question<B: Buffer + MutBuffer>(buffer: B) {
            let mut message: DnsMessage<8, 0, _> = DnsMessage::new_mut(buffer).unwrap();
            let header = message.header_mut().unwrap();
            header.set_id(0x1234);
            header.set_opcode(DnsHeaderOpcode::Query);
            header.set_authoritative_answer(false);
            header.set_truncated(false);
            header.set_recursion_desired(false);
            header.set_recursion_available(false);
            header.set_response_code(DnsHeaderResponseCode::NoError);
            let mut questions = message.questions();
            questions.append(DnsQuestion {
                name: DnsName::new(b"\x03www\x07example\x03com\x00").unwrap(),
                qtype: DnsQType::A,
                qclass: DnsQClass::IN,
            }).unwrap();
            let message = questions.complete().unwrap();
            let buffer = message.abort().unwrap();

            assert_eq!(buffer.bytes(), [
                0x12, 0x34, // ID
                0b0000_0000, 0b0000_0000, // Flags
                0x00, 0x01, // Question count
                0x00, 0x00, // Answer count
                0x00, 0x00, // Authority count
                0x00, 0x00, // Additional count
                0x03, b'w', b'w', b'w', // Name
                0x07, b'e', b'x', b'a', b'm', b'p', b'l', b'e', // Name
                0x03, b'c', b'o', b'm', // Name
                0x00, // Name
                0x00, 0x01, // Type
                0x00, 0x01, // Class
            ].as_slice());

            // Decode
            let message: DnsMessage<8, 0, _> = DnsMessage::new(buffer).unwrap();
            let header = message.header().unwrap();
            assert_eq!(header.id(), 0x1234);
            assert_eq!(header.opcode(), DnsHeaderOpcode::Query);
            assert!(!header.authoritative_answer());
            assert!(!header.truncated());
            assert!(!header.recursion_desired());
            assert!(!header.recursion_available());
            assert_eq!(header.response_code(), DnsHeaderResponseCode::NoError);
            let mut questions = message.questions();
            let mut question_iter = questions.iter().unwrap();
            let question = question_iter.next().unwrap().unwrap();
            assert_eq!(question.name, DnsName::new(b"\x03www\x07example\x03com\x00").unwrap());
            assert_eq!(question.qtype, DnsQType::A);
            assert_eq!(question.qclass, DnsQClass::IN);
            assert!(question_iter.next().is_none());
        }
    }

    #[cfg(feature = "arrayvec")]
    mod answer {
        use arrayvec::ArrayVec;
        use crate::answer::{DnsAClass, DnsAnswer};
        use crate::header::{DnsHeaderOpcode, DnsHeaderResponseCode};
        use crate::rdata::{A, DnsAType};
        use super::*;

        #[test]
        fn single_answer() {
            let buffer: ArrayVec<u8, 512> = ArrayVec::new();
            let mut message: DnsMessage<8, 0, _> = DnsMessage::new_mut(buffer).unwrap();
            let header = message.header_mut().unwrap();
            header.set_id(0x1234);
            header.set_opcode(DnsHeaderOpcode::Query);
            header.set_authoritative_answer(false);
            header.set_truncated(false);
            header.set_recursion_desired(false);
            header.set_recursion_available(false);
            header.set_response_code(DnsHeaderResponseCode::NoError);
            let message = message.questions().complete().unwrap();
            let message = {
                let mut answers = message.answers();
                answers.append(DnsAnswer {
                    name: DnsName::new(b"\x03www\x07example\x03com\x00").unwrap(),
                    aclass: DnsAClass::IN,
                    ttl: 0x12345678,
                    rdata: DnsAType::A(A { address: [127, 0, 0, 1] }),
                    cache_flush: false,
                }).unwrap();
                answers.complete().unwrap()
            };
            let buffer = message.abort().unwrap();

            assert_eq!(
                buffer.as_slice(),
                [
                    0x12, 0x34, // ID
                    0b0000_0000, 0b0000_0000, // Flags
                    0x00, 0x00, // Question count
                    0x00, 0x01, // Answer count
                    0x00, 0x00, // Authority count
                    0x00, 0x00, // Additional count
                    0x03, b'w', b'w', b'w', // Name
                    0x07, b'e', b'x', b'a', b'm', b'p', b'l', b'e', // Name
                    0x03, b'c', b'o', b'm', // Name
                    0x00, // Name
                    0x00, 0x01, // Type
                    0x00, 0x01, // Class
                    0x12, 0x34, 0x56, 0x78, // TTL
                    0x00, 0x04, // Data length
                    127, 0, 0, 1, // Data
                ].as_slice()
            );

            // Decode the message again and check that it is the same
            let message: DnsMessage<8, 1, _> = DnsMessage::new(buffer).unwrap();
            let header = message.header().unwrap();
            assert_eq!(header.id(), 0x1234);
            assert_eq!(header.opcode(), DnsHeaderOpcode::Query);
            assert!(!header.authoritative_answer());
            assert!(!header.truncated());
            assert!(!header.recursion_desired());
            assert!(!header.recursion_available());
            assert_eq!(header.response_code(), DnsHeaderResponseCode::NoError);
            let mut answers = message.answers();
            let mut answer_iter = answers.iter().unwrap();
            let answer = answer_iter.next().unwrap().unwrap();
            assert_eq!(answer.name, DnsName::new(b"\x03www\x07example\x03com\x00").unwrap());
            assert_eq!(answer.ttl, 0x12345678);
            assert_eq!(answer.into_parsed().unwrap().rdata, DnsAType::A(A { address: [127, 0, 0, 1] }));
            assert!(answer_iter.next().is_none());
        }
    }

    #[cfg(feature = "arrayvec")]
    mod error {
        use arrayvec::ArrayVec;
        use crate::header::{DnsHeaderOpcode, DnsHeaderResponseCode};
        use super::*;

        #[test]
        fn truncated() {
            let buffer: ArrayVec<u8, 12> = ArrayVec::from([
                0x12, 0x34, // ID
                0b0000_0000, 0b0000_0000, // Flags
                0x00, 0x01, // Question count
                0x00, 0x00, // Answer count
                0x00, 0x00, // Authority count
                0x00, 0x00, // Additional count
                // Premature end of message
            ]);
            let message: DnsMessage<8, 0, _> = DnsMessage::new(buffer).unwrap();
            let header = message.header().unwrap();
            assert_eq!(header.id(), 0x1234);
            assert_eq!(header.opcode(), DnsHeaderOpcode::Query);
            assert!(!header.authoritative_answer());
            assert!(!header.truncated());
            assert!(!header.recursion_desired());
            assert!(!header.recursion_available());
            assert_eq!(header.response_code(), DnsHeaderResponseCode::NoError);
            let mut questions = message.questions();
            let mut question_iter = questions.iter().unwrap();
            // The header announces one question but no question bytes follow.
            assert_eq!(
                question_iter.next(),
                Some(Err(DnsMessageError::DnsError(DnsError::UnexpectedEndOfBuffer)))
            );
        }
    }
}