mtop_client/dns/
name.rs

1use crate::core::MtopError;
2use byteorder::{ReadBytesExt, WriteBytesExt};
3use std::fmt;
4use std::io::{Read, Seek, SeekFrom};
5use std::str::FromStr;
6
/// A DNS domain name stored as its sequence of labels, e.g. `www`, `example`, `com`.
///
/// A name is either fully qualified (anchored at the root and written with a trailing
/// `.`) or relative. Labels are lowercased when a name is parsed, so two names that
/// differ only in ASCII case compare equal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Name {
    /// Labels from left to right, not including the root. Each label is at most 63
    /// bytes of ASCII letters, digits, hyphens, or underscores.
    labels: Vec<Vec<u8>>,
    /// Whether the name ends at the root (its text form has a trailing `.`).
    is_fqdn: bool,
}
12
13impl Name {
14    const MAX_LENGTH: usize = 255;
15    const MAX_LABEL_LENGTH: usize = 63;
16    const MAX_POINTERS: u32 = 64;
17
18    pub fn root() -> Self {
19        Name {
20            labels: Vec::new(),
21            is_fqdn: true,
22        }
23    }
24
25    pub fn size(&self) -> usize {
26        self.labels.iter().map(|l| l.len()).sum::<usize>() + self.labels.len() + 1
27    }
28
29    pub fn is_root(&self) -> bool {
30        self.labels.is_empty() && self.is_fqdn
31    }
32
33    pub fn is_fqdn(&self) -> bool {
34        self.is_fqdn
35    }
36
37    pub fn to_fqdn(mut self) -> Self {
38        self.is_fqdn = true;
39        self
40    }
41
42    pub fn append(mut self, other: Name) -> Self {
43        if self.is_fqdn {
44            return self;
45        }
46
47        self.labels.extend(other.labels);
48        Self {
49            labels: self.labels,
50            is_fqdn: other.is_fqdn,
51        }
52    }
53
54    pub fn write_network_bytes<T>(&self, mut out: T) -> Result<(), MtopError>
55    where
56        T: WriteBytesExt,
57    {
58        // We convert all incoming Names to fully qualified names. If we missed doing
59        // that, it's a bug and we should panic here. Encoded names all end with the
60        // root so trying to encode something that doesn't makes no sense.
61        assert!(self.is_fqdn, "only fully qualified domains can be encoded");
62
63        for label in self.labels.iter() {
64            out.write_u8(label.len() as u8)?;
65            out.write_all(label)?;
66        }
67
68        Ok(out.write_u8(0)?)
69    }
70
71    pub fn read_network_bytes<T>(mut inp: T) -> Result<Self, MtopError>
72    where
73        T: ReadBytesExt + Seek,
74    {
75        let mut name = Vec::new();
76        Self::read_inner(&mut inp, &mut name)?;
77        Self::from_bytes(&name)
78    }
79
80    /// Read the DNS message format bytes for a `Name` and write them to `out`,
81    /// following any pointers (how names are compressed in DNS messages). The
82    /// bytes written to `out` are ASCII characters of the text representation
83    /// of the name, e.g. `"example.com.".as_bytes()`.
84    fn read_inner<T>(inp: &mut T, out: &mut Vec<u8>) -> Result<(), MtopError>
85    where
86        T: ReadBytesExt + Seek,
87    {
88        let mut pointers = 0;
89        let mut position = None;
90        loop {
91            // To avoid loops from badly behaved servers, only follow a fixed number of
92            // pointers when trying to resolve a single name. This number is picked to be
93            // much higher than most names will use but still finite.
94            if pointers > Self::MAX_POINTERS {
95                return Err(MtopError::runtime(format!(
96                    "reached max number of pointers ({}) while reading name",
97                    Self::MAX_POINTERS
98                )));
99            }
100
101            let len = inp.read_u8()?;
102            // If the length isn't a length but actually a pointer to another name
103            // or label within the message, seek to that position within the message
104            // and read the name from there. After resolving all pointers and reading
105            // labels, reset the stream back to immediately after the first pointer.
106            if Self::is_compressed_label(len) {
107                let offset = Self::get_offset(len, inp.read_u8()?);
108                if position.is_none() {
109                    position = Some(inp.stream_position()?);
110                }
111
112                inp.seek(SeekFrom::Start(u64::from(offset)))?;
113                pointers += 1;
114            } else if Self::is_standard_label(len) {
115                // If the length is a length, read the next label (segment) of the name
116                // returning early once we read the "root" label (`.`) signified by a
117                // length of 0.
118                if Self::read_label_into(inp, len, out)? {
119                    if let Some(p) = position {
120                        // If we followed a pointer to different part of the message while
121                        // parsing this name, seek to the position immediately after the
122                        // pointer now that we've finished parsing this name.
123                        inp.seek(SeekFrom::Start(p))?;
124                    }
125
126                    return Ok(());
127                }
128            } else {
129                // Binary labels are deprecated (RFC 6891) and there are (currently) no other
130                // types of labels that we should expect. Return an error to make this obvious.
131                return Err(MtopError::runtime(format!("unsupported Name label type: {}", len)));
132            }
133        }
134    }
135
136    /// If `len` doesn't indicate this is the root label, read the next name label into
137    /// `out` followed by a `.` and return false, true if next label was the root.
138    fn read_label_into<T>(inp: &mut T, len: u8, out: &mut Vec<u8>) -> Result<bool, MtopError>
139    where
140        T: ReadBytesExt + Seek,
141    {
142        if len == 0 {
143            return Ok(true);
144        }
145
146        // Only six bits of the length are supposed to be used to encode the
147        // length of a label so 63 is the max length but double check just in
148        // case one of the pointer bits was set for some reason.
149        if usize::from(len) > Self::MAX_LABEL_LENGTH {
150            return Err(MtopError::runtime(format!(
151                "max size for label would be exceeded reading {} bytes",
152                len,
153            )));
154        }
155
156        if usize::from(len) + out.len() + 1 > Self::MAX_LENGTH {
157            return Err(MtopError::runtime(format!(
158                "max size for name would be exceeded adding {} bytes to {}",
159                len,
160                out.len()
161            )));
162        }
163
164        let mut handle = inp.take(u64::from(len));
165        let n = handle.read_to_end(out)?;
166        if n != usize::from(len) {
167            return Err(MtopError::runtime(format!(
168                "short read for Name segment. expected {} got {}",
169                len, n
170            )));
171        }
172
173        out.push(b'.');
174        Ok(false)
175    }
176
177    fn is_standard_label(len: u8) -> bool {
178        len & 0b1100_0000 == 0
179    }
180
181    fn is_compressed_label(len: u8) -> bool {
182        // The top two bits of the length byte of a name label (section) are used
183        // to indicate the name is actually an offset in the DNS message to a previous
184        // name to avoid duplicating the same names over and over.
185        len & 0b1100_0000 == 192
186    }
187
188    fn get_offset(len: u8, next: u8) -> u16 {
189        let pointer = u16::from(len & 0b0011_1111) << 8;
190        pointer | u16::from(next)
191    }
192
193    /// Construct a new `Name` from the bytes of a string representation of a domain
194    /// name, e.g. `"example.com.".as_bytes()`. This method validates the total length
195    /// of the name, the length of each label, and the characters used for each label.
196    /// An empty byte slice or a byte slice with only the ASCII `.` character are treated
197    /// as the root domain.
198    fn from_bytes(bytes: &[u8]) -> Result<Self, MtopError> {
199        if bytes.is_empty() || bytes == b"." {
200            return Ok(Self::root());
201        }
202
203        // Trim any trailing dot from the domain which indicates that it is fully qualified
204        // and make a note of it. This is required since we're splitting the domain into each
205        // individual label and need to know how to construct the correct string form afterward.
206        let (bytes, is_fqdn) = match bytes.strip_suffix(b".") {
207            Some(stripped) => (stripped, true),
208            None => (bytes, false),
209        };
210
211        // We add 2 to the length in bytes of the name since we need to account for the
212        // byte used for the length of the first label when in encoded form (as opposed
213        // to the text form using ASCII bytes) and the dot used to encode the root.
214        if bytes.len() + 2 > Self::MAX_LENGTH {
215            return Err(MtopError::configuration(format!(
216                "Name too long; max {} bytes, got {}",
217                Self::MAX_LENGTH,
218                bytes.len()
219            )));
220        }
221
222        // Start with space for 4 labels which covers the size of typical domains.
223        let mut labels = Vec::with_capacity(4);
224
225        for label in bytes.split(|&b| b == b'.') {
226            let label_len = label.len();
227
228            if label_len > Self::MAX_LABEL_LENGTH {
229                return Err(MtopError::configuration(format!(
230                    "label too long; max {} bytes, got {}",
231                    Self::MAX_LABEL_LENGTH,
232                    label_len,
233                )));
234            }
235
236            for (i, b) in label.iter().enumerate() {
237                let c = char::from(*b);
238                if i == 0 && c != '_' && !c.is_ascii_alphanumeric() {
239                    return Err(MtopError::configuration(format!(
240                        "label must begin with ASCII letter, number, or underscore; got {}",
241                        c
242                    )));
243                } else if i == label_len - 1 && !c.is_ascii_alphanumeric() {
244                    return Err(MtopError::configuration(format!(
245                        "label must end with ASCII letter or number; got {}",
246                        c
247                    )));
248                } else if c != '-' && c != '_' && !c.is_ascii_alphanumeric() {
249                    return Err(MtopError::configuration(format!(
250                        "label must be ASCII letter, number, hyphen, or underscore; got {}",
251                        c
252                    )));
253                }
254            }
255
256            labels.push(label.to_ascii_lowercase());
257        }
258
259        Ok(Self { labels, is_fqdn })
260    }
261}
262
263impl fmt::Display for Name {
264    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
265        let len = self.labels.len();
266        for (i, l) in self.labels.iter().enumerate() {
267            // SAFETY: We only allow construction of Name instances with ASCII labels
268            unsafe { str::from_utf8_unchecked(l) }.fmt(f)?;
269            if i != len - 1 {
270                ".".fmt(f)?;
271            }
272        }
273        if self.is_fqdn {
274            ".".fmt(f)?;
275        }
276
277        Ok(())
278    }
279}
280
281impl FromStr for Name {
282    type Err = MtopError;
283
284    fn from_str(s: &str) -> Result<Self, Self::Err> {
285        Self::from_bytes(s.as_bytes())
286    }
287}
288
// Unit tests for `Name`: parsing and validating the text form, formatting, appending,
// size calculation, and reading/writing the DNS wire format (including compression
// pointers).
#[cfg(test)]
mod test {
    use super::Name;
    use std::io::Cursor;
    use std::str::FromStr;

    #[test]
    fn test_name_from_str_error_max_length_fqdn() {
        // Four labels whose encoded size is one byte over the 255 byte limit.
        let parts = [
            "a".repeat(Name::MAX_LABEL_LENGTH),
            "b".repeat(Name::MAX_LABEL_LENGTH),
            "c".repeat(Name::MAX_LABEL_LENGTH),
            "d".repeat(Name::MAX_LABEL_LENGTH - 1),
        ];

        let complete = {
            let mut s = parts.join(".");
            s.push('.');
            s
        };

        // We're testing a name with:
        // 1b + 63b + 1b + 63b + 1b + 63b + 1b + 62b + 1b = 256b
        // One byte for the length of each label and one byte for the root.
        let res = Name::from_str(&complete);
        assert!(res.is_err());
    }

    #[test]
    fn test_name_from_str_error_max_length_not_fqdn() {
        // Same labels as the FQDN case above, but without the trailing dot.
        let parts = [
            "a".repeat(Name::MAX_LABEL_LENGTH),
            "b".repeat(Name::MAX_LABEL_LENGTH),
            "c".repeat(Name::MAX_LABEL_LENGTH),
            "d".repeat(Name::MAX_LABEL_LENGTH - 1),
        ];

        let complete = parts.join(".");

        // We're testing a name with:
        // 1b + 63b + 1b + 63b + 1b + 63b + 1b + 62b = 255b
        // that _omits_ the root from the end of the Name. We still need to
        // validate that the Name is under the length limit including the root.
        let res = Name::from_str(&complete);
        assert!(res.is_err());
    }

    #[test]
    fn test_name_from_str_error_max_label() {
        // First label is one byte longer than a label may be.
        let parts = ["a".repeat(Name::MAX_LABEL_LENGTH + 1), "com.".to_owned()];
        let res = Name::from_str(&parts.join("."));
        assert!(res.is_err());
    }

    #[test]
    fn test_name_from_str_error_bad_label_start() {
        // Labels may not begin with a hyphen.
        let res = Name::from_str("-example.com.");
        assert!(res.is_err());
    }

    #[test]
    fn test_name_from_str_error_bad_label_end() {
        // Labels may not end with a hyphen.
        let res = Name::from_str("example-.com.");
        assert!(res.is_err());
    }

    #[test]
    fn test_name_from_str_error_bad_label_char() {
        // `%` is not an allowed label character.
        let res = Name::from_str("exa%mple.com.");
        assert!(res.is_err());
    }

    #[test]
    fn test_name_from_str_success_not_fqdn() {
        let name = Name::from_str("example.com").unwrap();
        assert!(!name.is_root());
        assert!(!name.is_fqdn());
    }

    #[test]
    fn test_name_from_str_success_fqdn() {
        let name = Name::from_str("example.com.").unwrap();
        assert!(!name.is_root());
        assert!(name.is_fqdn());
    }

    #[test]
    fn test_name_from_str_success_root_empty() {
        // An empty string is parsed as the root.
        let name = Name::from_str("").unwrap();
        assert!(name.is_root());
        assert!(name.is_fqdn());
    }

    #[test]
    fn test_name_from_str_success_root_dot() {
        let name = Name::from_str(".").unwrap();
        assert!(name.is_root());
        assert!(name.is_fqdn());
    }

    #[test]
    fn test_name_to_string_not_fqdn() {
        let name = Name::from_str("example.com").unwrap();
        assert_eq!("example.com", name.to_string());
        assert!(!name.is_fqdn());
    }

    #[test]
    fn test_name_to_string_fqdn() {
        let name = Name::from_str("example.com.").unwrap();
        assert_eq!("example.com.", name.to_string());
        assert!(name.is_fqdn());
    }

    #[test]
    fn test_name_to_string_root() {
        let name = Name::root();
        assert_eq!(".", name.to_string());
        assert!(name.is_fqdn());
    }

    #[test]
    fn test_name_to_fqdn_not_fqdn() {
        let name = Name::from_str("example.com").unwrap();
        assert!(!name.is_fqdn());

        let fqdn = name.to_fqdn();
        assert!(fqdn.is_fqdn());
    }

    #[test]
    fn test_name_to_fqdn_already_fqdn() {
        let name = Name::from_str("example.com.").unwrap();
        assert!(name.is_fqdn());

        let fqdn = name.to_fqdn();
        assert!(fqdn.is_fqdn());
    }

    #[test]
    fn test_name_append_already_fqdn() {
        // Appending to an FQDN leaves it unchanged; `name2` is discarded.
        let name1 = Name::from_str("example.com.").unwrap();
        let name2 = Name::from_str("example.net.").unwrap();
        let combined = name1.clone().append(name2);

        assert_eq!(name1, combined);
        assert!(combined.is_fqdn());
    }

    #[test]
    fn test_name_append_with_non_fqdn() {
        let name1 = Name::from_str("www").unwrap();
        let name2 = Name::from_str("example").unwrap();
        let combined = name1.clone().append(name2);

        assert_eq!(Name::from_str("www.example").unwrap(), combined);
        assert!(!combined.is_fqdn());
    }

    #[test]
    fn test_name_append_with_fqdn() {
        // The result takes on the FQDN-ness of the appended name.
        let name1 = Name::from_str("www").unwrap();
        let name2 = Name::from_str("example.net.").unwrap();
        let combined = name1.clone().append(name2);

        assert_eq!(Name::from_str("www.example.net.").unwrap(), combined);
        assert!(combined.is_fqdn());
    }

    #[test]
    fn test_name_append_with_root() {
        let name = Name::from_str("example.com").unwrap();
        let combined = name.clone().append(Name::root());

        assert_eq!(Name::from_str("example.com.").unwrap(), combined);
        assert!(combined.is_fqdn());
    }

    #[test]
    fn test_name_append_multiple() {
        let name1 = Name::from_str("dev").unwrap();
        let name2 = Name::from_str("www").unwrap();
        let name3 = Name::from_str("example.com").unwrap();

        let combined = name1.append(name2).append(name3).append(Name::root());
        assert_eq!(Name::from_str("dev.www.example.com.").unwrap(), combined);
        assert!(combined.is_fqdn());
    }

    #[test]
    fn test_name_size_root() {
        // Just the zero byte for the root label.
        let name = Name::root();
        assert_eq!(1, name.size());
    }

    #[test]
    fn test_name_size_non_root() {
        // 1 + 7 ("example") + 1 + 3 ("com") + 1 (root) = 13
        let name = Name::from_str("example.com.").unwrap();
        assert_eq!(13, name.size());
    }

    #[test]
    fn test_name_size_non_root_fqdn() {
        // Starts from a relative (non-FQDN) name and converts it to an FQDN below.
        let name = Name::from_str("example.com").unwrap();
        assert!(!name.is_fqdn());
        assert_eq!(13, name.size());

        // The purpose of the .size() method is to figure out how many bytes this
        // name would be when serialized to binary message format. Only FQDN can be
        // serialized so we expect the size to be the same between the non-FQDN and
        // FQDN version of this name.
        let name = name.to_fqdn();
        assert!(name.is_fqdn());
        assert_eq!(13, name.size());
    }

    #[test]
    fn test_name_equal_same_case() {
        let name1 = Name::from_str("example.com.").unwrap();
        let name2 = Name::from_str("example.com.").unwrap();

        assert_eq!(name1, name2);
    }

    #[test]
    fn test_name_equal_different_case() {
        // Labels are lowercased when parsed, so equality ignores ASCII case.
        let name1 = Name::from_str("example.com.").unwrap();
        let name2 = Name::from_str("EXAMPLE.cOm.").unwrap();

        assert_eq!(name1, name2);
    }

    #[test]
    fn test_name_equal_different_fqdn() {
        let name1 = Name::from_str("example.com").unwrap();
        let name2 = Name::from_str("example.com.").unwrap();

        assert_ne!(name1, name2);
    }

    #[test]
    fn test_name_write_network_bytes_root() {
        let mut cur = Cursor::new(Vec::new());
        let name = Name::root();
        name.write_network_bytes(&mut cur).unwrap();
        let buf = cur.into_inner();

        assert_eq!(vec![0], buf);
    }

    #[rustfmt::skip]
    #[test]
    fn test_name_write_network_bytes_not_root() {
        let mut cur = Cursor::new(Vec::new());
        let name = Name::from_str("example.com.").unwrap();
        name.write_network_bytes(&mut cur).unwrap();
        let buf = cur.into_inner();

        assert_eq!(
            vec![
                7,                                // length
                101, 120, 97, 109, 112, 108, 101, // "example"
                3,                                // length
                99, 111, 109,                     // "com"
                0,                                // root
            ],
            buf,
        );
    }

    // Encoding a relative name is a programming error and panics.
    #[should_panic]
    #[test]
    fn test_name_write_network_bytes_not_fqdn() {
        let mut cur = Cursor::new(Vec::new());
        let name = Name::from_str("example.com").unwrap();
        let _ = name.write_network_bytes(&mut cur);
    }

    #[rustfmt::skip]
    #[test]
    fn test_name_read_network_bytes_no_pointer() {
        let cur = Cursor::new(vec![
            7,                                // length
            101, 120, 97, 109, 112, 108, 101, // "example"
            3,                                // length
            99, 111, 109,                     // "com"
            0,                                // root
        ]);

        let name = Name::read_network_bytes(cur).unwrap();
        assert_eq!("example.com.", name.to_string());
        assert!(name.is_fqdn());
    }

    #[rustfmt::skip]
    #[test]
    fn test_name_read_network_bytes_bad_label_type() {
        let cur = Cursor::new(vec![
            64, // length, deprecated binary labels from RFC 2673
            0,  // count of binary labels
        ]);

        let res = Name::read_network_bytes(cur);
        assert!(res.is_err());
    }

    #[rustfmt::skip]
    #[test]
    fn test_name_read_network_bytes_bad_label_type_after_single_pointer() {
        let mut cur = Cursor::new(vec![
            7,                                // length
            101, 120, 97, 109, 112, 108, 101, // "example"
            64,                               // length, deprecated binary labels from RFC 2673
            0,                                // count of binary labels
            3,                                // length
            119, 119, 119,                    // "www"
            192, 0,                           // pointer to offset 0
        ]);

        // Start at offset 10 ("www"). The pointer leads back to "example", which is
        // followed by the unsupported binary label.
        cur.set_position(10);

        let res = Name::read_network_bytes(&mut cur);
        assert!(res.is_err());
    }

    #[rustfmt::skip]
    #[test]
    fn test_name_read_network_bytes_single_pointer() {
        let mut cur = Cursor::new(vec![
            7,                                // length
            101, 120, 97, 109, 112, 108, 101, // "example"
            3,                                // length
            99, 111, 109,                     // "com"
            0,                                // root
            3,                                // length
            119, 119, 119,                    // "www"
            192, 0,                           // pointer to offset 0
        ]);

        // Start at offset 13 ("www"), which points back to "example.com." at offset 0.
        cur.set_position(13);

        let name = Name::read_network_bytes(&mut cur).unwrap();
        assert_eq!("www.example.com.", name.to_string());
        assert!(name.is_fqdn());
        // The cursor is left immediately after the two pointer bytes.
        assert_eq!(19, cur.position());
    }

    #[rustfmt::skip]
    #[test]
    fn test_name_read_network_bytes_multiple_pointer() {
        let mut cur = Cursor::new(vec![
            7,                                // length
            101, 120, 97, 109, 112, 108, 101, // "example"
            3,                                // length
            99, 111, 109,                     // "com"
            0,                                // root
            3,                                // length
            119, 119, 119,                    // "www"
            192, 0,                           // pointer to offset 0
            3,                                // length
            100, 101, 118,                    // "dev"
            192, 13,                          // pointer to offset 13, "www"
        ]);

        // Start at offset 19 ("dev"). Its pointer chains through "www" to "example.com.".
        cur.set_position(19);

        let name = Name::read_network_bytes(&mut cur).unwrap();
        assert_eq!("dev.www.example.com.", name.to_string());
        assert!(name.is_fqdn());
        // Only the position after the first pointer followed matters.
        assert_eq!(25, cur.position());
    }

    #[rustfmt::skip]
    #[test]
    fn test_name_read_network_bytes_multiple_pointer_multiple_name() {
        let mut cur = Cursor::new(vec![
            7,                                // length
            101, 120, 97, 109, 112, 108, 101, // "example"
            3,                                // length
            99, 111, 109,                     // "com"
            0,                                // root
            3,                                // length
            119, 119, 119,                    // "www"
            192, 0,                           // pointer to offset 0
            3,                                // length
            100, 101, 118,                    // "dev"
            192, 13,                          // pointer to offset 13, "www"
        ]);

        // Read three consecutive names from the start of the buffer. Each read must leave
        // the cursor at the start of the next name.
        let name1 = Name::read_network_bytes(&mut cur).unwrap();
        assert_eq!("example.com.", name1.to_string());
        assert!(name1.is_fqdn());

        let name2 = Name::read_network_bytes(&mut cur).unwrap();
        assert_eq!("www.example.com.", name2.to_string());
        assert!(name2.is_fqdn());

        let name3 = Name::read_network_bytes(&mut cur).unwrap();
        assert_eq!("dev.www.example.com.", name3.to_string());
        assert!(name3.is_fqdn());

        assert_eq!(25, cur.position());
    }

    #[test]
    fn test_name_read_network_bytes_pointer_loop() {
        // Two pointers referencing each other; reading must stop at the pointer limit
        // rather than loop forever.
        let mut cur = Cursor::new(vec![
            192, 2, // pointer to offset 2
            192, 0, // pointer to offset 0
        ]);

        let res = Name::read_network_bytes(&mut cur);
        assert!(res.is_err());
    }
}