rsdns 0.23.0

DNS Client Library
Documentation
use crate::{bytes::WCursor, constants::DOMAIN_NAME_MAX_LENGTH, Error, Result};

impl WCursor<'_> {
    #[inline]
    fn write_label(&mut self, label: &[u8]) -> Result<()> {
        super::check_label_bytes(label)?;
        if self.len() > label.len() {
            unsafe {
                self.u8_unchecked(label.len() as u8);
                self.bytes_unchecked(label);
            }
            Ok(())
        } else {
            Err(Error::BufferTooShort(self.pos() + label.len() + 1))
        }
    }

    pub fn write_domain_name(&mut self, name: &str) -> Result<usize> {
        self.write_domain_name_bytes(name.as_bytes())
    }

    pub fn write_domain_name_bytes(&mut self, name: &[u8]) -> Result<usize> {
        if name.is_empty() {
            return Err(Error::DomainNameLabelIsEmpty);
        }

        if name == b"." {
            self.u8(0)?;
            return Ok(1);
        }

        let start = self.pos();
        let len = name.len();

        let mut i = 0;
        let mut domain_start = None;

        for j in 0..len {
            let byte = unsafe { *name.get_unchecked(j) };
            if byte == b'.' {
                let label = unsafe { name.get_unchecked(i..j) };
                self.write_label(label)?;
                i = j + 1;
                domain_start = Some(i);
            }
        }

        match domain_start {
            Some(ds) if len - ds > 0 => {
                let label = unsafe { name.get_unchecked(ds..len) };
                self.write_label(label)?;
            }
            None => self.write_label(name)?,
            _ => {}
        };

        self.u8(0)?;

        let length = self.pos() - start;
        if length > DOMAIN_NAME_MAX_LENGTH {
            return Err(Error::DomainNameTooLong(length));
        }

        Ok(length)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        bytes::{Cursor, Reader},
        names::InlineName,
    };
    use std::str::FromStr;

    /// Encodes `name` into `buf` through a fresh cursor positioned at offset 0.
    fn encode(buf: &mut [u8], name: &str) -> Result<usize> {
        WCursor::new(buf).write_domain_name(name)
    }

    /// Parses a wire-format domain name back into an [`InlineName`].
    fn decode(wire: &[u8]) -> InlineName {
        let mut reader = Cursor::new(wire);
        reader.read().unwrap()
    }

    #[test]
    fn test_write_good_flow() {
        let cases: [(&str, &[u8]); 6] = [
            ("sub.example.com", b"\x03sub\x07example\x03com\x00"),
            ("example.com.", b"\x07example\x03com\x00"),
            ("_example.com.", b"\x08_example\x03com\x00"),
            ("com", b"\x03com\x00"),
            ("com.", b"\x03com\x00"),
            (".", b"\x00"),
        ];

        for &(name, wire) in cases.iter() {
            let mut buf = [0xFFu8; 64];
            let written = encode(&mut buf, name).unwrap();

            assert_eq!(written, wire.len());
            assert_eq!(&buf[..written], wire);

            // Round-trip: the encoded bytes must decode back to the same name.
            assert_eq!(
                decode(&buf[..written]),
                InlineName::from_str(name).unwrap()
            );
        }
    }

    #[test]
    fn test_write_domain_too_long() {
        let mut buf = [0xFFu8; 1024];

        let label = "a".repeat(63);
        let name = vec![label.as_str(); 4].join(".");
        assert_eq!(name.len(), 255);

        // 255 text bytes without a trailing dot encode to 257 bytes.
        assert!(matches!(
            encode(&mut buf, &name),
            Err(Error::DomainNameTooLong(s)) if s == name.len() + 2
        ));

        // 254 text bytes encode to 256 bytes: still one byte over the limit.
        assert!(matches!(
            encode(&mut buf, &name[..name.len() - 1]),
            Err(Error::DomainNameTooLong(s)) if s == name.len() + 1
        ));

        // 253 text bytes encode to exactly 255 bytes: the longest legal name.
        let fitting = &name[..name.len() - 2];
        let written = encode(&mut buf, fitting).unwrap();
        assert_eq!(written, 255);

        let decoded = decode(&buf[..written]);
        assert_eq!(decoded, InlineName::from_str(fitting).unwrap());
        assert_eq!(decoded.len(), 254);
    }

    #[test]
    fn test_write_malformed_label() {
        let empty_label_names = ["", "..", "example.com..", "example..com", "sub..example.com"];

        for &name in empty_label_names.iter() {
            let mut buf = [0xFFu8; 32];
            assert!(matches!(
                encode(&mut buf, name),
                Err(Error::DomainNameLabelIsEmpty)
            ));
        }

        let bad_first_char = [("-xample.com", b'-')];

        for &(name, ch) in bad_first_char.iter() {
            let mut buf = [0xFFu8; 32];
            assert!(matches!(
                encode(&mut buf, name),
                Err(Error::DomainNameLabelInvalidChar(
                    "domain name label first character is '-'",
                    v
                )) if v == ch
            ));
        }

        let bad_last_char = [("co-", b'-'), ("example-.com", b'-')];

        for &(name, ch) in bad_last_char.iter() {
            let mut buf = [0xFFu8; 32];
            assert!(matches!(
                encode(&mut buf, name),
                Err(Error::DomainNameLabelInvalidChar(
                    "domain name label last character is '-'",
                    v
                )) if v == ch
            ));
        }
    }
}