rsdns 0.23.0

DNS Client Library
Documentation
use crate::{
    clients::config::{ClientConfig, EDns, ProtocolStrategy, Recursion},
    constants::DNS_MESSAGE_BUFFER_MIN_LENGTH,
    errors::{Error, Result},
    message::{Flags, QueryWriter, reader::MessageReader},
    records::{Class, Opt, RecordSet, Type, data::RData},
};
use std::{
    io::{ErrorKind, Read, Write},
    net::{TcpStream, UdpSocket},
    time::{Duration, Instant},
};

const QUERY_BUFFER_SIZE: usize = 288;
type MsgBuf = arrayvec::ArrayVec<u8, QUERY_BUFFER_SIZE>;

struct ClientCtx<'a, 'b, 'c, 'd> {
    qname: &'a str,
    qtype: Type,
    qclass: Class,
    sock: &'b UdpSocket,
    config: &'c ClientConfig,
    msg_id: u16,
    msg: MsgBuf,
    buf: &'d mut [u8],
    start: Instant,
    query_start: Instant,
}

pub(crate) struct ClientImpl {
    config: ClientConfig,
    socket: UdpSocket,
    buf: Vec<u8>,
}

impl ClientImpl {
    pub fn new(config: ClientConfig) -> Result<Self> {
        let socket = UdpSocket::bind(config.bind_addr_)?;
        socket.connect(config.nameserver_)?;

        let buf = match config.buffer_size() {
            0 => Vec::new(),
            bs => Vec::with_capacity(bs),
        };

        Ok(Self {
            config,
            socket,
            buf,
        })
    }

    pub fn config(&self) -> &ClientConfig {
        &self.config
    }

    pub fn query_raw(
        &self,
        qname: &str,
        qtype: Type,
        qclass: Class,
        buf: &mut [u8],
    ) -> Result<usize> {
        if buf.len() < DNS_MESSAGE_BUFFER_MIN_LENGTH {
            return Err(Error::BufferTooShort(DNS_MESSAGE_BUFFER_MIN_LENGTH));
        }
        let now = Instant::now();
        let mut ctx = ClientCtx {
            qname,
            qtype,
            qclass,
            sock: &self.socket,
            config: &self.config,
            msg_id: 0,
            msg: MsgBuf::default(),
            buf,
            start: now,
            query_start: now,
        };
        ctx.prepare_message()?;
        ctx.query_raw()
    }

    pub fn query_rrset<D: RData>(&mut self, qname: &str, qclass: Class) -> Result<RecordSet<D>> {
        if self.config.buffer_size() == 0 {
            return Err(Error::BadParam("non-zero buffer_size is required"));
        }
        if !qclass.is_data_class() {
            return Err(Error::UnsupportedClass(qclass));
        }
        let mut buf = unsafe { self.take_buf() };
        let response_len = match self.query_raw(qname, D::RTYPE, qclass, &mut buf) {
            Ok(v) => v,
            Err(e) => {
                std::mem::swap(&mut self.buf, &mut buf);
                return Err(e);
            }
        };
        unsafe { buf.set_len(response_len) };
        let result = RecordSet::from_msg(&buf);
        std::mem::swap(&mut self.buf, &mut buf);
        result
    }

    unsafe fn take_buf(&mut self) -> Vec<u8> {
        let mut buf = std::mem::take(&mut self.buf);
        if buf.capacity() < self.config.buffer_size() {
            buf.reserve(self.config.buffer_size() - buf.capacity());
        }
        unsafe {
            buf.set_len(self.config.buffer_size());
        }
        buf
    }
}

impl ClientCtx<'_, '_, '_, '_> {
    #[inline]
    fn query_raw(&mut self) -> Result<usize> {
        match self.query_raw_impl() {
            Err(Error::IoError(v)) if v.kind() == ErrorKind::TimedOut => Err(Error::Timeout),
            Err(Error::IoError(v)) if v.kind() == ErrorKind::WouldBlock => Err(Error::Timeout),
            Ok(s) => Ok(s),
            Err(e) => Err(e),
        }
    }

    fn query_raw_impl(&mut self) -> Result<usize> {
        if self.udp_first() {
            let (size, flags) = self.udp_exchange()?;

            if flags.truncated() && self.tcp_allowed() {
                self.tcp_exchange()
            } else {
                Ok(size)
            }
        } else {
            self.tcp_exchange()
        }
    }

    fn tcp_exchange(&mut self) -> Result<usize> {
        let mut sock = TcpStream::connect_timeout(&self.config.nameserver_, self.lifetime_left()?)?;

        Self::set_timeout_tcp(&sock, self.lifetime_left()?)?;
        sock.write_all(&self.msg)?;

        Self::set_timeout_tcp(&sock, self.lifetime_left()?)?;
        let mut response_size_buf = [0u8; 2];
        sock.read_exact(&mut response_size_buf)?;

        let response_size = u16::from_be_bytes(response_size_buf) as usize;
        if response_size > self.buf.len() {
            return Err(Error::BufferTooShort(response_size));
        }

        Self::set_timeout_tcp(&sock, self.lifetime_left()?)?;
        sock.read_exact(&mut self.buf[..response_size])?;

        Ok(response_size)
    }

    fn udp_exchange(&mut self) -> Result<(usize, Flags)> {
        loop {
            self.query_start = Instant::now();
            Self::set_timeout_udp(self.sock, self.query_left()?)?;

            self.sock.send(&self.msg[2..])?;

            match self.udp_receive_loop() {
                Ok(v) => break Ok(v),
                Err(Error::IoError(v)) if v.kind() == ErrorKind::WouldBlock => continue,
                Err(Error::IoError(v)) if v.kind() == ErrorKind::TimedOut => continue,
                Err(v) => break Err(v),
            }
        }
    }

    fn udp_receive_loop(&mut self) -> Result<(usize, Flags)> {
        loop {
            Self::set_timeout_udp(self.sock, self.query_left()?)?;

            let size = self.sock.recv(self.buf)?;

            let response = &self.buf[..size];
            let mut mr = {
                match MessageReader::new(response) {
                    Ok(mr) => mr,
                    Err(_) => continue,
                }
            };

            let header = match mr.header() {
                Ok(h) => h,
                Err(_) => continue,
            };

            if header.id != self.msg_id {
                continue;
            }

            if let Ok(question) = mr.the_question()
                && question.qtype == self.qtype
                && question.qclass == self.qclass
                && question.qname == self.qname
            {
                return Ok((size, header.flags));
            }
        }
    }

    fn prepare_message(&mut self) -> Result<()> {
        let opt = match self.config.edns_ {
            EDns::On {
                version,
                udp_payload_size,
            } => {
                let ups = (udp_payload_size as usize).min(self.buf.len());
                Some(Opt::new(version, ups as u16))
            }
            EDns::Off => None,
        };

        unsafe {
            self.msg.set_len(self.msg.capacity());
        }

        let recursion = self.config.recursion_ == Recursion::On;
        let mut qw = QueryWriter::new(&mut self.msg);

        self.msg_id = qw.message_id();
        let msg_len = qw.write(self.qname, self.qtype, self.qclass, recursion, opt)?;

        unsafe {
            self.msg.set_len(msg_len);
        }

        Ok(())
    }

    #[inline]
    fn udp_first(&self) -> bool {
        match self.config.protocol_strategy_ {
            ProtocolStrategy::Udp | ProtocolStrategy::NoTcp => true,
            ProtocolStrategy::Tcp => false,
        }
    }

    #[inline]
    fn tcp_allowed(&self) -> bool {
        self.config.protocol_strategy_ != ProtocolStrategy::NoTcp
    }

    fn set_timeout_udp(sock: &UdpSocket, timeout: Duration) -> Result<()> {
        let timeout = Some(timeout);
        sock.set_read_timeout(timeout)?;
        sock.set_write_timeout(timeout)?;
        Ok(())
    }

    fn set_timeout_tcp(sock: &TcpStream, timeout: Duration) -> Result<()> {
        let timeout = Some(timeout);
        sock.set_read_timeout(timeout)?;
        sock.set_write_timeout(timeout)?;
        Ok(())
    }

    fn lifetime_left(&self) -> Result<Duration> {
        let elapsed = self.start.elapsed();
        if elapsed >= self.config.query_lifetime_ {
            return Err(Error::Timeout);
        }
        Ok(self.config.query_lifetime_ - elapsed)
    }

    fn query_left(&self) -> Result<Duration> {
        let lifetime_left = self.lifetime_left()?;

        let timeout = match self.config.query_timeout_ {
            Some(duration) => duration,
            _ => self.config.query_lifetime_,
        };

        let elapsed = self.query_start.elapsed();

        let time_left = if elapsed < timeout {
            timeout - elapsed
        } else {
            Duration::from_millis(0)
        };

        Ok(time_left.min(lifetime_left))
    }
}