// rustdns 0.4.0
//
// A DNS parsing library.
// Documentation
//! Various traits to help parsing of DNS messages.

use crate::bail;
use crate::types::{Class, Type};
use byteorder::{ReadBytesExt, BE};
use num_traits::FromPrimitive;
use std::convert::TryInto;
use std::io;
use std::io::Cursor;
use std::io::SeekFrom;

/// Extension of [`io::Seek`] that reports how much of the stream is left.
pub trait SeekExt: io::Seek {
    /// Returns the number of bytes remaining to be consumed.
    /// This is used as a way to check for malformed input.
    ///
    /// The stream position is the same after the call as before it. If the
    /// current position is at or beyond the end of the stream, `0` is returned.
    ///
    /// # Errors
    ///
    /// Returns any error produced by the underlying seek operations.
    fn remaining(&mut self) -> io::Result<u64> {
        let pos = self.stream_position()?;
        let len = self.seek(SeekFrom::End(0))?;

        // Put the stream back where we found it.
        self.seek(SeekFrom::Start(pos))?;

        // Seeking past the end of a stream is allowed, so `pos` may exceed
        // `len`; saturate rather than underflow.
        Ok(len.saturating_sub(pos))
    }
}

/// Specialised implementation for in-memory byte slices, which can answer
/// from the slice length and cursor position without performing any seeks.
impl<'a> SeekExt for Cursor<&'a [u8]> {
    /// Returns the number of bytes between the cursor position and the end of
    /// the underlying slice, or `0` if the position is at or beyond the end.
    fn remaining(&mut self) -> io::Result<u64> {
        // Work in u64 (the cursor's native position type) so that the
        // position is never truncated on targets where usize is 32 bits.
        let len: u64 = match self.get_ref().len().try_into() {
            Ok(len) => len,
            Err(_) => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "buffer length does not fit in u64",
                ))
            }
        };

        // `Cursor::set_position` may place the cursor past the end of the
        // slice; saturate rather than underflow.
        Ok(len.saturating_sub(self.position()))
    }
}

/// Extension trait for carving a bounded sub-cursor out of an existing cursor.
pub trait CursorExt<T> {
    /// Returns a new cursor that is bounded to the `start..end` range of the original cursor.
    ///
    /// The returned cursor contains all values with `start <= x < end`. It is empty if
    /// `start >= end`.
    ///
    /// Similar to `Take`, but allows the start-end range to be specified instead of just the
    /// next N values.
    fn sub_cursor(&mut self, start: usize, end: usize) -> io::Result<std::io::Cursor<T>>;
}

impl<'a> CursorExt<&'a [u8]> for Cursor<&'a [u8]> {
    /// Builds a fresh cursor over bytes `start..end` of the underlying slice.
    ///
    /// Both offsets index the whole underlying slice and are independent of
    /// this cursor's current position, which is left untouched. Offsets that
    /// fall outside the slice are pulled back to its length, and an `end`
    /// before `start` produces an empty cursor. This implementation never
    /// returns an error.
    fn sub_cursor(&mut self, start: usize, end: usize) -> io::Result<Cursor<&'a [u8]>> {
        let data: &'a [u8] = *self.get_ref();
        let limit = data.len();

        // Lower bound can be at most the slice length; the upper bound must
        // sit between the (adjusted) lower bound and the slice length.
        let lo = start.min(limit);
        let hi = end.max(lo).min(limit);

        Ok(Cursor::new(&data[lo..hi]))
    }
}

/// Blanket implementation: every reader that is also seekable (sized or not)
/// picks up the `DNSReadExt` methods automatically.
impl<R> DNSReadExt for R where R: io::Read + io::Seek + ?Sized {}

/// Extensions to `io::Read` (plus `io::Seek`) to add some DNS specific types.
pub trait DNSReadExt: io::Read + io::Seek {
    /// Reads a puny encoded domain name from the current stream position.
    ///
    /// Used for extracting an encoded ASCII domain name from a DNS message.
    /// Returns the Unicode domain name, with every label followed by a `.`;
    /// the root domain is returned as `"."`.
    ///
    /// Compression pointers are followed. Each pointer must point strictly
    /// before the start of the name (or name fragment) currently being read,
    /// which rules out pointer loops. On success the stream is left just
    /// after the name as it appears at the original position: after the
    /// terminating zero byte, or after the first compression pointer if one
    /// was followed.
    ///
    /// # Errors
    ///
    /// Will return a io::Error(InvalidData) if the read domain name is invalid, or
    /// a more general io::Error on any other read failure. The stream position
    /// is unspecified after an error.
    fn read_qname(&mut self) -> io::Result<String> {
        let mut qname = String::new();

        // Any compression pointer must target an offset strictly below this
        // value. It starts at the beginning of the name and moves to the
        // target of each pointer we follow, so targets strictly decrease and
        // the loop below always terminates.
        let mut limit = self.stream_position()?;

        // Where to resume once the name is complete: the position just after
        // the first compression pointer, if we followed one.
        let mut resume = None;

        // Read each label one at a time, to build up the full domain name.
        loop {
            // Length (or pointer marker) of the next label.
            let len = self.read_u8()?;
            if len == 0 {
                // Terminating root label. Only emit a bare "." if no labels
                // were read; otherwise the name already ends in '.'.
                if qname.is_empty() {
                    qname.push('.'); // Root domain
                }
                break;
            }

            match len & 0xC0 {
                // No compression
                0x00 => {
                    // The top two bits are clear, so a label is at most 63
                    // bytes; a stack buffer avoids an allocation per label.
                    let mut buf = [0u8; 63];
                    let raw = &mut buf[..usize::from(len)];
                    self.read_exact(raw)?;

                    // Really this is meant to be ASCII, but we read as utf8
                    // (as that what Rust provides).
                    let label = match std::str::from_utf8(raw) {
                        Err(e) => bail!(InvalidData, "invalid label: {}", e),
                        Ok(s) => s,
                    };

                    if !label.is_ascii() {
                        bail!(InvalidData, "invalid label '{:}': not valid ascii", label);
                    }

                    // Now puny decode this label returning its original unicode.
                    let label = match idna::domain_to_unicode(label) {
                        (label, Err(e)) => bail!(InvalidData, "invalid label '{:}': {}", label, e),
                        (label, Ok(_)) => label,
                    };

                    qname.push_str(&label);
                    qname.push('.');
                }

                // Compression
                0xC0 => {
                    // Read the 14 bit pointer.
                    let b2 = self.read_u8()?;
                    let ptr = u64::from((u16::from(len & 0x3F) << 8) | u16::from(b2));

                    // Make sure we don't get into a loop.
                    if ptr >= limit {
                        bail!(
                            InvalidData,
                            "invalid compressed pointer pointing to future bytes"
                        );
                    }

                    // We are going to jump backwards, so record where we
                    // currently are (first jump only). So we can reset it later.
                    if resume.is_none() {
                        resume = Some(self.stream_position()?);
                    }

                    // Jump and keep reading labels from the pointed-to name.
                    self.seek(SeekFrom::Start(ptr))?;
                    limit = ptr;
                }

                // Unknown
                _ => bail!(
                    InvalidData,
                    "unsupported compression type {0:b}",
                    len & 0xC0
                ),
            }
        }

        // Reset ourselves if we followed any pointers.
        if let Some(pos) = resume {
            self.seek(SeekFrom::Start(pos))?;
        }

        Ok(qname)
    }

    /// Reads a DNS Type, encoded as a big-endian `u16`.
    ///
    /// # Errors
    ///
    /// Will return a io::Error(InvalidData) if the value does not map to a
    /// known `Type`, or a more general io::Error on any other read failure.
    fn read_type(&mut self) -> io::Result<Type> {
        let r#type = self.read_u16::<BE>()?;
        let r#type = match FromPrimitive::from_u16(r#type) {
            Some(t) => t,
            None => bail!(InvalidData, "invalid Type({})", r#type),
        };

        Ok(r#type)
    }

    /// Reads a DNS Class, encoded as a big-endian `u16`.
    ///
    /// # Errors
    ///
    /// Will return a io::Error(InvalidData) if the value does not map to a
    /// known `Class`, or a more general io::Error on any other read failure.
    fn read_class(&mut self) -> io::Result<Class> {
        let class = self.read_u16::<BE>()?;
        let class = match FromPrimitive::from_u16(class) {
            Some(t) => t,
            None => bail!(InvalidData, "invalid Class({})", class),
        };

        Ok(class)
    }
}