mtop_client/dns/
message.rs

1use crate::core::MtopError;
2use crate::dns::core::{RecordClass, RecordType};
3use crate::dns::name::Name;
4use crate::dns::rdata::RecordData;
5use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt};
6use std::fmt;
7use std::io::Seek;
8
/// 16-bit identifier stored in the header of a DNS message.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct MessageId(u16);
12
13impl MessageId {
14    pub fn random() -> Self {
15        Self(rand::random())
16    }
17
18    pub fn size(&self) -> usize {
19        2
20    }
21}
22
23impl From<u16> for MessageId {
24    fn from(value: u16) -> Self {
25        Self(value)
26    }
27}
28
29impl From<MessageId> for u16 {
30    fn from(value: MessageId) -> Self {
31        value.0
32    }
33}
34
35impl fmt::Display for MessageId {
36    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37        fmt::Display::fmt(&self.0, f)
38    }
39}
40
41#[derive(Debug, Clone, Eq, PartialEq)]
42pub struct Message {
43    id: MessageId,
44    flags: Flags,
45    questions: Vec<Question>,
46    answers: Vec<Record>,
47    authority: Vec<Record>,
48    extra: Vec<Record>,
49}
50
51impl Message {
52    pub fn new(id: MessageId, flags: Flags) -> Self {
53        Self {
54            id,
55            flags,
56            questions: Vec::new(),
57            answers: Vec::new(),
58            authority: Vec::new(),
59            extra: Vec::new(),
60        }
61    }
62
63    pub fn size(&self) -> usize {
64        self.id.size()
65            + self.flags.size()
66            + (2 * 4) // lengths of questions, answers, authority, extra
67            + self.questions.iter().map(|q| q.size()).sum::<usize>()
68            + self.answers.iter().map(|r| r.size()).sum::<usize>()
69            + self.authority.iter().map(|r| r.size()).sum::<usize>()
70            + self.extra.iter().map(|r| r.size()).sum::<usize>()
71    }
72
73    pub fn id(&self) -> MessageId {
74        self.id
75    }
76
77    pub fn flags(&self) -> Flags {
78        self.flags
79    }
80
81    pub fn set_flags(mut self, flags: Flags) -> Self {
82        self.flags = flags;
83        self
84    }
85
86    pub fn questions(&self) -> &[Question] {
87        &self.questions
88    }
89
90    pub fn add_question(mut self, q: Question) -> Self {
91        self.questions.push(q);
92        self
93    }
94
95    pub fn answers(&self) -> &[Record] {
96        &self.answers
97    }
98
99    pub fn add_answer(mut self, r: Record) -> Self {
100        self.answers.push(r);
101        self
102    }
103
104    pub fn authority(&self) -> &[Record] {
105        &self.authority
106    }
107
108    pub fn add_authority(mut self, r: Record) -> Self {
109        self.authority.push(r);
110        self
111    }
112
113    pub fn extra(&self) -> &[Record] {
114        &self.extra
115    }
116
117    pub fn add_extra(mut self, r: Record) -> Self {
118        self.extra.push(r);
119        self
120    }
121
122    fn header(&self) -> Header {
123        assert!(self.questions.len() < usize::from(u16::MAX));
124        assert!(self.answers.len() < usize::from(u16::MAX));
125        assert!(self.authority.len() < usize::from(u16::MAX));
126        assert!(self.extra.len() < usize::from(u16::MAX));
127
128        Header {
129            id: self.id,
130            flags: self.flags,
131            num_questions: self.questions.len() as u16,
132            num_answers: self.answers.len() as u16,
133            num_authority: self.authority.len() as u16,
134            num_extra: self.extra.len() as u16,
135        }
136    }
137
138    pub fn write_network_bytes<T>(&self, mut buf: T) -> Result<(), MtopError>
139    where
140        T: WriteBytesExt,
141    {
142        let header = self.header();
143        header.write_network_bytes(&mut buf)?;
144
145        for q in self.questions.iter() {
146            q.write_network_bytes(&mut buf)?;
147        }
148
149        for r in self.answers.iter() {
150            r.write_network_bytes(&mut buf)?;
151        }
152
153        for r in self.authority.iter() {
154            r.write_network_bytes(&mut buf)?;
155        }
156
157        for r in self.extra.iter() {
158            r.write_network_bytes(&mut buf)?;
159        }
160
161        Ok(())
162    }
163
164    pub fn read_network_bytes<T>(mut buf: T) -> Result<Self, MtopError>
165    where
166        T: ReadBytesExt + Seek,
167    {
168        let header = Header::read_network_bytes(&mut buf)?;
169
170        let mut questions = Vec::new();
171        for _ in 0..header.num_questions {
172            questions.push(Question::read_network_bytes(&mut buf)?);
173        }
174
175        let mut answers = Vec::new();
176        for _ in 0..header.num_answers {
177            answers.push(Record::read_network_bytes(&mut buf)?);
178        }
179
180        let mut authority = Vec::new();
181        for _ in 0..header.num_authority {
182            authority.push(Record::read_network_bytes(&mut buf)?);
183        }
184
185        let mut extra = Vec::new();
186        for _ in 0..header.num_extra {
187            extra.push(Record::read_network_bytes(&mut buf)?);
188        }
189
190        Ok(Self {
191            id: header.id,
192            flags: header.flags,
193            questions,
194            answers,
195            authority,
196            extra,
197        })
198    }
199}
200
201#[derive(Debug, Clone, Eq, PartialEq)]
202struct Header {
203    id: MessageId,
204    flags: Flags,
205    num_questions: u16,
206    num_answers: u16,
207    num_authority: u16,
208    num_extra: u16,
209}
210
211impl Header {
212    fn write_network_bytes<T>(&self, mut buf: T) -> Result<(), MtopError>
213    where
214        T: WriteBytesExt,
215    {
216        buf.write_u16::<NetworkEndian>(self.id.into())?;
217        buf.write_u16::<NetworkEndian>(self.flags.as_u16())?;
218        buf.write_u16::<NetworkEndian>(self.num_questions)?;
219        buf.write_u16::<NetworkEndian>(self.num_answers)?;
220        buf.write_u16::<NetworkEndian>(self.num_authority)?;
221        Ok(buf.write_u16::<NetworkEndian>(self.num_extra)?)
222    }
223
224    fn read_network_bytes<T>(mut buf: T) -> Result<Self, MtopError>
225    where
226        T: ReadBytesExt,
227    {
228        let id = MessageId::from(buf.read_u16::<NetworkEndian>()?);
229        let flags = Flags::try_from(buf.read_u16::<NetworkEndian>()?)?;
230        let num_questions = buf.read_u16::<NetworkEndian>()?;
231        let num_answers = buf.read_u16::<NetworkEndian>()?;
232        let num_authority = buf.read_u16::<NetworkEndian>()?;
233        let num_extra = buf.read_u16::<NetworkEndian>()?;
234
235        Ok(Header {
236            id,
237            flags,
238            num_questions,
239            num_answers,
240            num_authority,
241            num_extra,
242        })
243    }
244}
245
/// The 16 flag bits of a message header, stored as a single `u16`.
///
/// The default value has every bit cleared.
#[repr(transparent)]
#[derive(Clone, Copy, Default, PartialEq, Eq)]
pub struct Flags(u16);
249
250impl Flags {
251    const MASK_QR: u16 = 0b1000_0000_0000_0000; // query / response
252    const MASK_OP: u16 = 0b0111_1000_0000_0000; // 4 bits, op code
253    const MASK_AA: u16 = 0b0000_0100_0000_0000; // authoritative answer
254    const MASK_TC: u16 = 0b0000_0010_0000_0000; // truncated
255    const MASK_RD: u16 = 0b0000_0001_0000_0000; // recursion desired
256    const MASK_RA: u16 = 0b0000_0000_1000_0000; // recursion available
257    const MASK_RC: u16 = 0b0000_0000_0000_1111; // 4 bits, response code
258
259    const OFFSET_QR: usize = 15;
260    const OFFSET_OP: usize = 11;
261    const OFFSET_AA: usize = 10;
262    const OFFSET_TC: usize = 9;
263    const OFFSET_RD: usize = 8;
264    const OFFSET_RA: usize = 7;
265    const OFFSET_RC: usize = 0;
266
267    pub fn size(&self) -> usize {
268        2
269    }
270
271    pub fn is_query(&self) -> bool {
272        !(self.0 & Self::MASK_QR) > 0
273    }
274
275    pub fn set_query(self) -> Self {
276        Flags(self.0 & !Self::MASK_QR)
277    }
278
279    pub fn is_response(&self) -> bool {
280        self.0 & Self::MASK_QR > 0
281    }
282
283    pub fn set_response(self) -> Self {
284        Flags(self.0 | Self::MASK_QR)
285    }
286
287    pub fn get_op_code(&self) -> Operation {
288        Operation::try_from((self.0 & Self::MASK_OP) >> Self::OFFSET_OP).unwrap()
289    }
290
291    pub fn set_op_code(self, op: Operation) -> Self {
292        let op = (op as u16) << Self::OFFSET_OP;
293        Flags(self.0 | op)
294    }
295
296    pub fn is_authoritative(&self) -> bool {
297        self.0 & Self::MASK_AA > 0
298    }
299
300    pub fn set_authoritative(self) -> Self {
301        Flags(self.0 | Self::MASK_AA)
302    }
303
304    pub fn is_truncated(&self) -> bool {
305        self.0 & Self::MASK_TC > 0
306    }
307
308    pub fn set_truncated(self) -> Self {
309        Flags(self.0 | Self::MASK_TC)
310    }
311
312    pub fn is_recursion_desired(&self) -> bool {
313        self.0 & Self::MASK_RD > 0
314    }
315
316    pub fn set_recursion_desired(self) -> Self {
317        Flags(self.0 | Self::MASK_RD)
318    }
319
320    pub fn is_recursion_available(&self) -> bool {
321        self.0 & Self::MASK_RA > 0
322    }
323
324    pub fn set_recursion_available(self) -> Self {
325        Flags(self.0 | Self::MASK_RA)
326    }
327
328    pub fn get_response_code(&self) -> ResponseCode {
329        ResponseCode::try_from((self.0 & Self::MASK_RC) >> Self::OFFSET_RC).unwrap()
330    }
331
332    pub fn set_response_code(self, code: ResponseCode) -> Self {
333        let code = (code as u16) << Self::OFFSET_RC;
334        Flags(self.0 | code)
335    }
336
337    pub fn as_u16(&self) -> u16 {
338        self.0
339    }
340}
341
342impl TryFrom<u16> for Flags {
343    type Error = MtopError;
344
345    fn try_from(value: u16) -> Result<Self, Self::Error> {
346        // Ensure that operation and response code are valid values but
347        // otherwise use the value as is. The rest of the fields are on/off
348        // bits so any combination is valid even if they don't make sense.
349        let _op = Operation::try_from((value & Self::MASK_OP) >> Self::OFFSET_OP)?;
350        let _rc = ResponseCode::try_from((value & Self::MASK_RC) >> Self::OFFSET_RC)?;
351        Ok(Flags(value))
352    }
353}
354
355impl fmt::Debug for Flags {
356    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
357        let qr = (self.0 & Self::MASK_QR) >> Self::OFFSET_QR;
358        let op = Operation::try_from((self.0 & Self::MASK_OP) >> Self::OFFSET_OP).unwrap();
359        let aa = (self.0 & Self::MASK_AA) >> Self::OFFSET_AA;
360        let tc = (self.0 & Self::MASK_TC) >> Self::OFFSET_TC;
361        let rd = (self.0 & Self::MASK_RD) >> Self::OFFSET_RD;
362        let ra = (self.0 & Self::MASK_RA) >> Self::OFFSET_RA;
363        let rc = ResponseCode::try_from((self.0 & Self::MASK_RC) >> Self::OFFSET_RC).unwrap();
364
365        write!(
366            f,
367            "Flags{{qr = {qr}, op = {op:?}, aa = {aa}, tc = {tc}, rd = {rd}, ra = {ra}, rc = {rc:?}}}"
368        )
369    }
370}
371
/// Response code of a DNS message, with each variant's discriminant equal
/// to its numeric code.
///
/// Defaults to `NoError`.
#[repr(u16)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum ResponseCode {
    #[default]
    NoError = 0,
    FormatError = 1,
    ServerFailure = 2,
    NameError = 3,
    NotImplemented = 4,
    Refused = 5,
    YxDomain = 6,
    YxRrSet = 7,
    NxRrSet = 8,
    NotAuth = 9,
    NotZone = 10,
    // Note: 16 is the only code here that needs more than four bits.
    BadVersion = 16,
}
389
390impl fmt::Display for ResponseCode {
391    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
392        fmt::Debug::fmt(self, f)
393    }
394}
395
396impl TryFrom<u16> for ResponseCode {
397    type Error = MtopError;
398
399    fn try_from(value: u16) -> Result<Self, Self::Error> {
400        match value {
401            0 => Ok(ResponseCode::NoError),
402            1 => Ok(ResponseCode::FormatError),
403            2 => Ok(ResponseCode::ServerFailure),
404            3 => Ok(ResponseCode::NameError),
405            4 => Ok(ResponseCode::NotImplemented),
406            5 => Ok(ResponseCode::Refused),
407            6 => Ok(ResponseCode::YxDomain),
408            7 => Ok(ResponseCode::YxRrSet),
409            8 => Ok(ResponseCode::NxRrSet),
410            9 => Ok(ResponseCode::NotAuth),
411            10 => Ok(ResponseCode::NotZone),
412            16 => Ok(ResponseCode::BadVersion),
413            _ => Err(MtopError::runtime(format!(
414                "invalid or unsupported response code {}",
415                value
416            ))),
417        }
418    }
419}
420
/// Operation (op code) of a DNS message, with each variant's discriminant
/// equal to its numeric code.
///
/// Defaults to `Query`. There is no variant for the value 3.
#[repr(u16)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum Operation {
    #[default]
    Query = 0,
    IQuery = 1,
    Status = 2,
    Notify = 4,
    Update = 5,
}
431
432impl TryFrom<u16> for Operation {
433    type Error = MtopError;
434
435    fn try_from(value: u16) -> Result<Self, Self::Error> {
436        match value {
437            0 => Ok(Operation::Query),
438            1 => Ok(Operation::IQuery),
439            2 => Ok(Operation::Status),
440            4 => Ok(Operation::Notify),
441            5 => Ok(Operation::Update),
442            _ => Err(MtopError::runtime(format!(
443                "invalid or unsupported operation {}",
444                value
445            ))),
446        }
447    }
448}
449
450#[derive(Debug, Clone, Eq, PartialEq)]
451pub struct Question {
452    name: Name,
453    qtype: RecordType,
454    qclass: RecordClass,
455}
456
457impl Question {
458    pub fn new(name: Name, qtype: RecordType) -> Self {
459        Self {
460            name,
461            qtype,
462            qclass: RecordClass::INET,
463        }
464    }
465
466    pub fn size(&self) -> usize {
467        self.name.size() + self.qtype.size() + self.qclass.size()
468    }
469
470    pub fn set_qclass(mut self, qclass: RecordClass) -> Self {
471        self.qclass = qclass;
472        self
473    }
474
475    pub fn name(&self) -> &Name {
476        &self.name
477    }
478
479    pub fn qtype(&self) -> RecordType {
480        self.qtype
481    }
482
483    pub fn qclass(&self) -> RecordClass {
484        self.qclass
485    }
486
487    pub fn write_network_bytes<T>(&self, mut buf: T) -> Result<(), MtopError>
488    where
489        T: WriteBytesExt,
490    {
491        self.name.write_network_bytes(&mut buf)?;
492        buf.write_u16::<NetworkEndian>(self.qtype.into())?;
493        Ok(buf.write_u16::<NetworkEndian>(self.qclass.into())?)
494    }
495
496    pub fn read_network_bytes<T>(mut buf: T) -> Result<Self, MtopError>
497    where
498        T: ReadBytesExt + Seek,
499    {
500        let name = Name::read_network_bytes(&mut buf)?;
501        let qtype = RecordType::from(buf.read_u16::<NetworkEndian>()?);
502        let qclass = RecordClass::from(buf.read_u16::<NetworkEndian>()?);
503        Ok(Self { name, qtype, qclass })
504    }
505}
506
507#[derive(Debug, Clone, Eq, PartialEq)]
508pub struct Record {
509    name: Name,
510    rtype: RecordType,
511    rclass: RecordClass,
512    ttl: u32,
513    rdata: RecordData,
514}
515
516impl Record {
517    pub fn new(name: Name, rtype: RecordType, rclass: RecordClass, ttl: u32, rdata: RecordData) -> Self {
518        Self {
519            name,
520            rtype,
521            rclass,
522            ttl,
523            rdata,
524        }
525    }
526
527    pub fn size(&self) -> usize {
528        self.name.size()
529            + self.rtype.size()
530            + self.rclass.size()
531            + 4 // ttl
532            + 2 // rdata length
533            + self.rdata.size()
534    }
535
536    pub fn name(&self) -> &Name {
537        &self.name
538    }
539
540    pub fn rtype(&self) -> RecordType {
541        self.rtype
542    }
543
544    pub fn rclass(&self) -> RecordClass {
545        self.rclass
546    }
547
548    pub fn ttl(&self) -> u32 {
549        self.ttl
550    }
551
552    pub fn rdata(&self) -> &RecordData {
553        &self.rdata
554    }
555
556    pub fn write_network_bytes<T>(&self, mut buf: T) -> Result<(), MtopError>
557    where
558        T: WriteBytesExt,
559    {
560        // It shouldn't be possible for rdata to overflow u16 so if we do, that's a bug.
561        let size = self.rdata.size();
562        assert!(
563            size <= usize::from(u16::MAX),
564            "rdata length of {} bytes exceeds max of {} bytes",
565            size,
566            u16::MAX
567        );
568
569        self.name.write_network_bytes(&mut buf)?;
570        buf.write_u16::<NetworkEndian>(self.rtype.into())?;
571        buf.write_u16::<NetworkEndian>(self.rclass.into())?;
572        buf.write_u32::<NetworkEndian>(self.ttl)?;
573        buf.write_u16::<NetworkEndian>(size as u16)?;
574        self.rdata.write_network_bytes(&mut buf)
575    }
576
577    pub fn read_network_bytes<T>(mut buf: T) -> Result<Self, MtopError>
578    where
579        T: ReadBytesExt + Seek,
580    {
581        let name = Name::read_network_bytes(&mut buf)?;
582        let rtype = RecordType::from(buf.read_u16::<NetworkEndian>()?);
583        let rclass = RecordClass::from(buf.read_u16::<NetworkEndian>()?);
584        let ttl = buf.read_u32::<NetworkEndian>()?;
585        let rdata_len = buf.read_u16::<NetworkEndian>()?;
586        let rdata = RecordData::read_network_bytes(rtype, rdata_len, &mut buf)?;
587
588        Ok(Self {
589            name,
590            rtype,
591            rclass,
592            ttl,
593            rdata,
594        })
595    }
596}
597
#[cfg(test)]
mod test {
    use super::{Flags, Header, Message, MessageId, Operation, Question, Record, ResponseCode};
    use crate::dns::core::{RecordClass, RecordType};
    use crate::dns::name::Name;
    use crate::dns::rdata::{RecordData, RecordDataA, RecordDataSRV};
    use std::io::Cursor;
    use std::net::Ipv4Addr;
    use std::str::FromStr;

    /// Parse a name that is known to be valid.
    fn name(s: &str) -> Name {
        Name::from_str(s).unwrap()
    }

    /// Encoded response with one SRV question, one SRV answer and one extra
    /// A record. Shared by the message write and read tests.
    #[rustfmt::skip]
    fn message_bytes() -> Vec<u8> {
        vec![
            // Header
            255, 53, // ID
            128, 0,  // Flags: response, query op, no error
            0, 1,    // questions
            0, 1,    // answers
            0, 0,    // authority
            0, 1,    // extra

            // Question
            6,                                // length
            95, 99, 97, 99, 104, 101,         // "_cache"
            4,                                // length
            95, 116, 99, 112,                 // "_tcp"
            7,                                // length
            101, 120, 97, 109, 112, 108, 101, // "example"
            3,                                // length
            99, 111, 109,                     // "com"
            0,                                // root
            0, 33,                            // record type, SRV
            0, 1,                             // record class, INET

            // Answer
            6,                                // length
            95, 99, 97, 99, 104, 101,         // "_cache"
            4,                                // length
            95, 116, 99, 112,                 // "_tcp"
            7,                                // length
            101, 120, 97, 109, 112, 108, 101, // "example"
            3,                                // length
            99, 111, 109,                     // "com"
            0,                                // root
            0, 33,                            // record type, SRV
            0, 1,                             // record class, INET
            0, 0, 1, 44,                      // TTL
            0, 27,                            // rdata size
            0, 10,                            // priority
            0, 10,                            // weight
            43, 203,                          // port
            7,                                // length
            99, 97, 99, 104, 101, 48, 49,     // "cache01"
            7,                                // length
            101, 120, 97, 109, 112, 108, 101, // "example"
            3,                                // length
            99, 111, 109,                     // "com"
            0,                                // root

            // Extra
            7,                                // length
            99, 97, 99, 104, 101, 48, 49,     // "cache01"
            7,                                // length
            101, 120, 97, 109, 112, 108, 101, // "example"
            3,                                // length
            99, 111, 109,                     // "com"
            0,                                // root
            0, 1,                             // record type, A
            0, 1,                             // record class, INET
            0, 0, 0, 60,                      // TTL
            0, 4,                             // rdata size
            127, 0, 0, 100,                   // rdata, A address
        ]
    }

    /// Encoded header shared by the header write and read tests.
    #[rustfmt::skip]
    fn header_bytes() -> Vec<u8> {
        vec![
            255, 53, // ID
            1, 0,    // Flags, recursion desired
            0, 1,    // questions
            0, 2,    // answers
            0, 3,    // authority
            0, 4,    // extra
        ]
    }

    /// Encoded AAAA question for "example.com." shared by the question tests.
    #[rustfmt::skip]
    fn question_bytes() -> Vec<u8> {
        vec![
            7,                                // length
            101, 120, 97, 109, 112, 108, 101, // "example"
            3,                                // length
            99, 111, 109,                     // "com"
            0,                                // root
            0, 28,                            // AAAA record
            0, 1,                             // INET class
        ]
    }

    /// Encoded A record for "www.example.com." shared by the record tests.
    #[rustfmt::skip]
    fn record_bytes() -> Vec<u8> {
        vec![
            3,                                // length
            119, 119, 119,                    // "www"
            7,                                // length
            101, 120, 97, 109, 112, 108, 101, // "example"
            3,                                // length
            99, 111, 109,                     // "com"
            0,                                // root
            0, 1,                             // record type, A
            0, 1,                             // record class, INET
            0, 0, 1, 44,                      // TTL
            0, 4,                             // rdata size
            127, 0, 0, 100,                   // rdata, A address
        ]
    }

    #[test]
    fn test_message_write_network_bytes() {
        let question = Question::new(name("_cache._tcp.example.com."), RecordType::SRV);

        let srv = RecordDataSRV::new(10, 10, 11211, name("cache01.example.com."));
        let answer = Record::new(
            name("_cache._tcp.example.com."),
            RecordType::SRV,
            RecordClass::INET,
            300,
            RecordData::SRV(srv),
        );

        let addr = RecordDataA::new(Ipv4Addr::new(127, 0, 0, 100));
        let extra = Record::new(
            name("cache01.example.com."),
            RecordType::A,
            RecordClass::INET,
            60,
            RecordData::A(addr),
        );

        let flags = Flags::default()
            .set_response()
            .set_op_code(Operation::Query)
            .set_response_code(ResponseCode::NoError);
        let message = Message::new(MessageId::from(65333), flags)
            .add_question(question)
            .add_answer(answer)
            .add_extra(extra);

        let mut cur = Cursor::new(Vec::new());
        message.write_network_bytes(&mut cur).unwrap();

        assert_eq!(message_bytes(), cur.into_inner());
    }

    #[test]
    fn test_message_read_network_bytes() {
        let message = Message::read_network_bytes(Cursor::new(message_bytes())).unwrap();

        assert_eq!(MessageId::from(65333), message.id());
        let expected_flags = Flags::default()
            .set_response()
            .set_response_code(ResponseCode::NoError)
            .set_op_code(Operation::Query);
        assert_eq!(expected_flags, message.flags());

        let question = &message.questions()[0];
        assert_eq!("_cache._tcp.example.com.", question.name().to_string());
        assert_eq!(RecordType::SRV, question.qtype());
        assert_eq!(RecordClass::INET, question.qclass());

        let answer = &message.answers()[0];
        assert_eq!("_cache._tcp.example.com.", answer.name().to_string());
        assert_eq!(RecordType::SRV, answer.rtype());
        assert_eq!(RecordClass::INET, answer.rclass());
        assert_eq!(300, answer.ttl());

        match answer.rdata() {
            RecordData::SRV(rd) => {
                assert_eq!(10, rd.weight());
                assert_eq!(10, rd.priority());
                assert_eq!(11211, rd.port());
                assert_eq!("cache01.example.com.", rd.target().to_string());
            }
            other => panic!("unexpected record data type: {:?}", other),
        }

        let extra = &message.extra()[0];
        assert_eq!("cache01.example.com.", extra.name().to_string());
        assert_eq!(RecordType::A, extra.rtype());
        assert_eq!(RecordClass::INET, extra.rclass());
        assert_eq!(60, extra.ttl());

        match extra.rdata() {
            RecordData::A(rd) => assert_eq!(Ipv4Addr::new(127, 0, 0, 100), rd.addr()),
            other => panic!("unexpected record data type: {:?}", other),
        }
    }

    #[test]
    fn test_header_write_network_bytes() {
        let h = Header {
            id: MessageId::from(65333),
            flags: Flags::default().set_recursion_desired(),
            num_questions: 1,
            num_answers: 2,
            num_authority: 3,
            num_extra: 4,
        };

        let mut cur = Cursor::new(Vec::new());
        h.write_network_bytes(&mut cur).unwrap();

        assert_eq!(header_bytes(), cur.into_inner());
    }

    #[test]
    fn test_header_read_network_bytes() {
        let h = Header::read_network_bytes(Cursor::new(header_bytes())).unwrap();

        assert_eq!(MessageId::from(65333), h.id);
        assert_eq!(Flags::default().set_recursion_desired(), h.flags);
        assert_eq!(1, h.num_questions);
        assert_eq!(2, h.num_answers);
        assert_eq!(3, h.num_authority);
        assert_eq!(4, h.num_extra);
    }

    #[test]
    fn test_flags() {
        // Each setter is checked on its own against the matching getter.
        assert!(Flags::default().set_query().is_query());
        assert!(Flags::default().set_response().is_response());
        assert_eq!(
            Operation::Notify,
            Flags::default().set_op_code(Operation::Notify).get_op_code()
        );
        assert!(Flags::default().set_authoritative().is_authoritative());
        assert!(Flags::default().set_truncated().is_truncated());
        assert!(Flags::default().set_recursion_desired().is_recursion_desired());
        assert!(Flags::default().set_recursion_available().is_recursion_available());
        assert_eq!(
            ResponseCode::ServerFailure,
            Flags::default()
                .set_response_code(ResponseCode::ServerFailure)
                .get_response_code()
        );

        // Several setters chained together.
        let combined = Flags::default()
            .set_query()
            .set_recursion_desired()
            .set_op_code(Operation::Query);
        assert!(combined.is_query());
        assert!(combined.is_recursion_desired());
        assert_eq!(Operation::Query, combined.get_op_code());
    }

    #[test]
    fn test_question_write_network_bytes() {
        let q = Question::new(name("example.com."), RecordType::AAAA);
        let size = q.size();

        let mut cur = Cursor::new(Vec::new());
        q.write_network_bytes(&mut cur).unwrap();
        let buf = cur.into_inner();

        assert_eq!(question_bytes(), buf);
        assert_eq!(size, buf.len());
    }

    #[test]
    fn test_question_read_network_bytes() {
        let bytes = question_bytes();
        let size = bytes.len();

        let q = Question::read_network_bytes(Cursor::new(bytes)).unwrap();

        assert_eq!("example.com.", q.name().to_string());
        assert_eq!(RecordType::AAAA, q.qtype());
        assert_eq!(RecordClass::INET, q.qclass());
        assert_eq!(size, q.size());
    }

    #[test]
    fn test_record_write_network_bytes() {
        let rr = Record::new(
            name("www.example.com."),
            RecordType::A,
            RecordClass::INET,
            300,
            RecordData::A(RecordDataA::new(Ipv4Addr::new(127, 0, 0, 100))),
        );
        let size = rr.size();

        let mut cur = Cursor::new(Vec::new());
        rr.write_network_bytes(&mut cur).unwrap();
        let buf = cur.into_inner();

        assert_eq!(record_bytes(), buf);
        assert_eq!(size, buf.len());
    }

    #[test]
    fn test_record_read_network_bytes() {
        let bytes = record_bytes();
        let size = bytes.len();

        let rr = Record::read_network_bytes(Cursor::new(bytes)).unwrap();

        assert_eq!("www.example.com.", rr.name().to_string());
        assert_eq!(RecordType::A, rr.rtype());
        assert_eq!(RecordClass::INET, rr.rclass());
        assert_eq!(300, rr.ttl());
        match rr.rdata() {
            RecordData::A(rd) => assert_eq!(Ipv4Addr::new(127, 0, 0, 100), rd.addr()),
            other => panic!("unexpected rdata type: {:?}", other),
        }
        assert_eq!(size, rr.size());
    }
}