// rsdns 0.23.0
//
// DNS Client Library
// Documentation
use zyn::TokenStream;

pub fn render(crate_name: &str) -> TokenStream {
    zyn::zyn! {
        use crate::{
            clients::config::{ProtocolStrategy, Recursion, ClientConfig, EDns},
            constants::DNS_MESSAGE_BUFFER_MIN_LENGTH,
            message::{reader::MessageReader, Flags, QueryWriter},
            records::{data::RData, Class, RecordSet, Opt, Type},
            Error, Result,
        };

        @if (crate_name == "tokio") {
            use tokio::{
                net::{TcpStream, UdpSocket},
                io::{AsyncReadExt, AsyncWriteExt},
                time::timeout
            };

            #[cfg(all(target_os = "linux", feature = "net-tokio", feature = "socket2"))]
            use std::os::unix::io::{IntoRawFd, FromRawFd};

            #[cfg(all(target_os = "linux", feature = "net-tokio", feature = "socket2"))]
            use tokio::net::TcpSocket;
        }
        @else if (crate_name == "async-std") {
            use async_std::{
                future::timeout,
                net::{TcpStream, UdpSocket},
                io::prelude::{ReadExt, WriteExt}
            };
        }
        @else if (crate_name == "smol") {
            use smol::{
                net::{TcpStream, UdpSocket},
                io::{AsyncReadExt, AsyncWriteExt},
            };
            use smol_timeout::TimeoutExt;
        }

        /// Capacity, in bytes, of the buffer an outgoing query message is encoded into.
        // NOTE(review): presumably sized for the 2-byte TCP length prefix plus the largest query
        // this client can build (header, maximum-length name, type/class, OPT record) — confirm.
        const QUERY_BUFFER_SIZE: usize = 288;
        /// Fixed-capacity inline buffer (no heap allocation) holding one encoded query.
        type MsgBuf = arrayvec::ArrayVec<u8, QUERY_BUFFER_SIZE>;

        /// DNS client talking to a single nameserver through a connected UDP socket.
        ///
        /// A TCP connection is opened per query only when the protocol strategy requires it.
        pub struct ClientImpl {
            /// Configuration the client was created with.
            config: ClientConfig,
            /// UDP socket connected to `config.nameserver_`, reused for every query.
            sock: UdpSocket,
            /// Reusable response buffer for `query_rrset`, sized by `config.buffer_size()`;
            /// it is moved out of the client for the duration of a query.
            buf: Vec<u8>,
        }

        impl ClientImpl {
            /// Creates a client from `config`, opening its UDP socket up front.
            ///
            /// The internal response buffer used by `query_rrset` is allocated here, zero-filled
            /// and `config.buffer_size()` bytes long (no allocation when the size is 0).
            ///
            /// # Errors
            ///
            /// Any error from creating, binding or connecting the UDP socket.
            pub async fn new(config: ClientConfig) -> Result<Self> {
                let sock = udp_socket(&config).await?;
                // Zero-initialised so the buffer can always be handed out as an initialised
                // `&mut [u8]`; `vec![0; 0]` does not allocate.
                let buf = vec![0u8; config.buffer_size()];
                Ok(Self { config, sock, buf })
            }

            /// Returns the configuration this client was created with.
            pub fn config(&self) -> &ClientConfig {
                &self.config
            }

            /// Sends a query for `qname`/`qtype`/`qclass` and writes the raw reply into `buf`.
            ///
            /// Returns the length of the reply, which occupies `buf[..len]`.
            ///
            /// # Errors
            ///
            /// `Error::BufferTooShort` if `buf` is shorter than `DNS_MESSAGE_BUFFER_MIN_LENGTH`
            /// (or than a TCP reply), `Error::Timeout` when the query lifetime elapses, plus any
            /// error from encoding the query or from the network exchange.
            pub async fn query_raw(&self, qname: &str, qtype: Type, qclass: Class, buf: &mut [u8]) -> Result<usize> {
                if buf.len() < DNS_MESSAGE_BUFFER_MIN_LENGTH {
                    return Err(Error::BufferTooShort(DNS_MESSAGE_BUFFER_MIN_LENGTH));
                }
                let mut ctx = ClientCtx {
                    qname,
                    qtype,
                    qclass,
                    sock: &self.sock,
                    config: &self.config,
                    msg_id: 0,
                    msg: MsgBuf::default(),
                    buf
                };
                ctx.prepare_message()?;
                ctx.query_raw().await
            }

            /// Queries `qname` for records of type `D::RTYPE` and parses the reply into a `RecordSet`.
            ///
            /// The reply is read into the client's internal buffer, so a non-zero `buffer_size`
            /// is required.
            ///
            /// # Errors
            ///
            /// `Error::BadParam` if `buffer_size` is zero, `Error::UnsupportedClass` if `qclass`
            /// is not a data class, plus any error from `query_raw` or `RecordSet::from_msg`.
            pub async fn query_rrset<D: RData>(&mut self, qname: &str, qclass: Class) -> Result<RecordSet<D>> {
                if self.config.buffer_size() == 0 {
                    return Err(Error::BadParam("non-zero buffer_size is required"));
                }
                if !qclass.is_data_class() {
                    return Err(Error::UnsupportedClass(qclass));
                }
                let mut buf = self.take_buf();
                let response = self.query_raw(qname, D::RTYPE, qclass, &mut buf).await;
                // `query_raw` never reports more bytes than `buf` holds, so the slice is in bounds.
                let result = response.and_then(|len| RecordSet::from_msg(&buf[..len]));
                // Return the (still full-length) buffer for the next query, whatever the outcome.
                self.buf = buf;
                result
            }

            /// Moves the internal buffer out of `self`, sized to exactly `buffer_size` bytes.
            ///
            /// The buffer is normally already that size, making `resize` a no-op. It only does
            /// work when an earlier `query_rrset` future was dropped mid-query and `self.buf`
            /// was left empty; in that case it allocates a fresh zero-filled buffer.
            fn take_buf(&mut self) -> Vec<u8> {
                let mut buf = std::mem::take(&mut self.buf);
                buf.resize(self.config.buffer_size(), 0);
                buf
            }
        }

        /// Per-query state shared by the UDP and TCP exchange paths.
        struct ClientCtx<'a, 'b, 'c, 'd> {
            /// Name being queried; also compared against the question echoed in UDP replies.
            qname: &'a str,
            /// Query type; also compared against the echoed question.
            qtype: Type,
            /// Query class; also compared against the echoed question.
            qclass: Class,
            /// The client's connected UDP socket.
            sock: &'b UdpSocket,
            /// Client configuration (protocol strategy, timeouts, EDNS, recursion, ...).
            config: &'c ClientConfig,
            /// ID of the query in `msg`; set by `prepare_message`, used to match UDP replies.
            msg_id: u16,
            /// Encoded query; the UDP path sends it without its first two bytes, TCP sends it whole.
            msg: MsgBuf,
            /// Caller-supplied buffer the reply is read into.
            buf: &'d mut [u8],
        }

        impl ClientCtx<'_, '_, '_, '_> {
            /// Performs the whole query exchange under the configured `query_lifetime` deadline.
            ///
            /// Returns the size of the reply written to `self.buf`. If the deadline elapses, the
            /// exchange (including UDP re-sends and any TCP fallback) is abandoned and
            /// `Error::Timeout` is returned.
            async fn query_raw(&mut self) -> Result<usize> {
                let query_lifetime = self.config.query_lifetime();

                let future = self.query_raw_impl();

                @if (crate_name == "tokio" || crate_name == "async-std") {
                    match timeout(query_lifetime, future).await {
                        Ok(res) => res,
                        Err(_) => Err(Error::Timeout),
                    }
                }
                @else if (crate_name == "smol") {
                    match future.timeout(query_lifetime).await {
                        Some(res) => res,
                        None => Err(Error::Timeout),
                    }
                }
            }

            /// Picks the transport according to the protocol strategy.
            ///
            /// UDP-first strategies retry over TCP when the UDP reply is truncated. Under `NoTcp`
            /// that retry is not allowed, and the truncated UDP reply is returned as-is.
            async fn query_raw_impl(&mut self) -> Result<usize> {
                if self.udp_first() {
                    let (size, flags) = self.udp_exchange_loop().await?;

                    if flags.truncated() && self.tcp_allowed() {
                        self.tcp_exchange().await
                    } else {
                        Ok(size)
                    }
                } else {
                    self.tcp_exchange().await
                }
            }

            /// Exchanges the query over a freshly opened TCP connection.
            ///
            /// The whole of `self.msg` is sent, including the two leading bytes the UDP path
            /// skips (presumably the DNS-over-TCP length prefix written by `QueryWriter`). The
            /// reply is read as a 2-byte big-endian length followed by that many bytes.
            ///
            /// # Errors
            ///
            /// `Error::BufferTooShort(len)` if the announced reply length exceeds `self.buf`,
            /// plus any connection or I/O error.
            async fn tcp_exchange(&mut self) -> Result<usize> {
                let mut sock = tcp_socket(self.config).await?;

                sock.write_all(&self.msg).await?;

                let mut response_size_buf = [0u8; 2];
                sock.read_exact(&mut response_size_buf).await?;

                let response_size = u16::from_be_bytes(response_size_buf) as usize;

                if response_size > self.buf.len() {
                    return Err(Error::BufferTooShort(response_size));
                }

                sock.read_exact(&mut self.buf[..response_size]).await?;

                Ok(response_size)
            }

            /// Sends the query over UDP and waits for a matching reply.
            ///
            /// When a `query_timeout` is configured, the query is re-sent every time that timeout
            /// elapses without a matching reply. When none is configured, it waits for a reply
            /// without re-sending. Either way this loop has no retry limit of its own; it is
            /// stopped by the `query_lifetime` deadline applied in `query_raw`.
            async fn udp_exchange_loop(&mut self) -> Result<(usize, Flags)> {
                loop {
                    self.sock.send(&self.msg[2..]).await?;

                    let query_timeout = self.config.query_timeout();

                    let future = self.udp_receive_loop();

                    if let Some(query_timeout) = query_timeout {
                        @if (crate_name == "tokio" || crate_name == "async-std") {
                            match timeout(query_timeout, future).await {
                                Ok(res) => return res,
                                Err(_) => continue,
                            };
                        }
                        @else if (crate_name == "smol") {
                            match future.timeout(query_timeout).await {
                                Some(res) => return res,
                                None => continue,
                            };
                        }
                    } else {
                        return future.await;
                    }
                }
            }

            /// Reads datagrams into `self.buf` until one is a reply to this query.
            ///
            /// Datagrams that fail to parse, carry a different message ID, or do not echo this
            /// query's question (name, type and class) are silently discarded. Returns the reply
            /// size and its header flags; socket errors are propagated.
            async fn udp_receive_loop(&mut self) -> Result<(usize, Flags)> {
                loop {
                    let size = self.sock.recv(self.buf).await?;

                    let response = &self.buf[..size];
                    let mut mr = match MessageReader::new(response) {
                        Ok(mr) => mr,
                        Err(_) => continue,
                    };
                    let header = match mr.header() {
                        Ok(h) => h,
                        Err(_) => continue,
                    };

                    if header.id != self.msg_id {
                        continue;
                    }

                    if let Ok(question) = mr.the_question()
                        && question.qtype == self.qtype
                        && question.qclass == self.qclass
                        && question.qname == self.qname
                    {
                        return Ok((size, header.flags));
                    }
                }
            }

            /// Encodes the query into `self.msg` and records its message ID in `self.msg_id`.
            ///
            /// With EDNS enabled, an OPT record is added. Its advertised UDP payload size is
            /// capped at the length of the reply buffer.
            fn prepare_message(&mut self) -> Result<()> {
                let opt = match self.config.edns_ {
                    EDns::On {
                        version,
                        udp_payload_size
                    } => {
                        // Never advertise more than the reply buffer can hold. `ups` is at most
                        // `udp_payload_size`, so the cast back cannot grow it.
                        let ups = (udp_payload_size as usize).min(self.buf.len());
                        Some(Opt::new(version, ups as u16))
                    },
                    EDns::Off => None,
                };
                // Give QueryWriter the full capacity as an initialised, zero-filled slice. Exposing
                // uninitialised bytes through `&mut [u8]` (e.g. via `set_len`) would be UB.
                self.msg.clear();
                let capacity = self.msg.capacity();
                self.msg.extend(std::iter::repeat(0u8).take(capacity));
                let mut qw = QueryWriter::new(&mut self.msg);
                self.msg_id = qw.message_id();
                let msg_len = qw.write(self.qname, self.qtype, self.qclass,
                                       self.config.recursion_ == Recursion::On, opt)?;
                // Keep only the bytes QueryWriter actually produced.
                self.msg.truncate(msg_len);
                Ok(())
            }

            /// Whether the query is attempted over UDP first (`Udp` and `NoTcp` strategies).
            #[inline]
            fn udp_first(&self) -> bool {
                match self.config.protocol_strategy_ {
                    ProtocolStrategy::Udp | ProtocolStrategy::NoTcp => true,
                    ProtocolStrategy::Tcp => false,
                }
            }

            /// Whether TCP may be used at all, i.e. the strategy is anything but `NoTcp`.
            #[inline]
            fn tcp_allowed(&self) -> bool {
                self.config.protocol_strategy_ != ProtocolStrategy::NoTcp
            }
        }

        @if (crate_name == "tokio") {
            /// Creates the client's UDP socket through socket2 so it can be pinned to an interface.
            ///
            /// Falls back to `udp_socket_simple` when no interface is configured. Otherwise the
            /// socket is bound to `config.interface_` (SO_BINDTODEVICE) and `config.bind_addr_`,
            /// then connected to `config.nameserver_`.
            #[cfg(all(target_os = "linux", feature = "net-tokio", feature = "socket2"))]
            async fn udp_socket2(config: &ClientConfig) -> Result<UdpSocket> {
                if config.interface_.is_empty() {
                    return udp_socket_simple(config).await;
                }

                // Copy of the interface name with a trailing NUL appended. This is best effort:
                // if the name already fills its capacity, the push fails and is ignored.
                let mut interface = config.interface_;
                interface.try_push(char::default()).ok();

                let sock = socket2::Socket::new(
                    socket2::Domain::for_address(config.nameserver_),
                    // `UdpSocket::from_std` requires the socket to be non-blocking.
                    socket2::Type::DGRAM.nonblocking().cloexec(),
                    Some(socket2::Protocol::UDP)
                )?;

                sock.bind_device(Some(interface.as_bytes()))?;
                sock.bind(&socket2::SockAddr::from(config.bind_addr_))?;
                sock.connect(&socket2::SockAddr::from(config.nameserver_))?;

                // socket2 converts straight into the std socket, transferring ownership of the
                // descriptor; no raw-fd round trip or `unsafe` is needed.
                Ok(UdpSocket::from_std(std::net::UdpSocket::from(sock))?)
            }

            /// Opens a TCP connection through socket2 so it can be pinned to a network interface.
            ///
            /// Without a configured interface this is just `tcp_socket_simple`. Otherwise the
            /// socket is bound to `config.interface_` (SO_BINDTODEVICE) and has TCP_NODELAY set.
            /// It is then connected to `config.nameserver_` through tokio's `TcpSocket`.
            #[cfg(all(target_os = "linux", feature = "net-tokio", feature = "socket2"))]
            async fn tcp_socket2(config: &ClientConfig) -> Result<TcpStream> {
                if config.interface_.is_empty() {
                    return tcp_socket_simple(config).await;
                }

                // Interface name plus a trailing NUL; a failed push (name at capacity) is ignored.
                let mut ifname = config.interface_;
                let _ = ifname.try_push(char::default());

                let domain = socket2::Domain::for_address(config.nameserver_);
                let ty = socket2::Type::STREAM.nonblocking().cloexec();
                let raw = socket2::Socket::new(domain, ty, Some(socket2::Protocol::TCP))?;

                // Socket options must be applied before the connection is established.
                raw.bind_device(Some(ifname.as_bytes()))?;
                raw.set_tcp_nodelay(true)?;

                // SAFETY: `into_raw_fd` hands over sole ownership of a valid, open TCP socket
                // descriptor, so `TcpSocket` becomes its only owner and closes it exactly once.
                let socket = unsafe { TcpSocket::from_raw_fd(raw.into_raw_fd()) };
                let stream = socket.connect(config.nameserver_).await?;
                Ok(stream)
            }
        }

        /// Binds a UDP socket to `config.bind_addr_` and connects it to `config.nameserver_`.
        #[inline(always)]
        async fn udp_socket_simple(config: &ClientConfig) -> Result<UdpSocket> {
            let socket = UdpSocket::bind(config.bind_addr_).await?;
            // Connecting restricts `send`/`recv` to the nameserver's address.
            socket.connect(config.nameserver_).await?;
            Ok(socket)
        }

        /// Opens a TCP connection to `config.nameserver_` with Nagle's algorithm disabled.
        #[inline(always)]
        async fn tcp_socket_simple(config: &ClientConfig) -> Result<TcpStream> {
            let stream = TcpStream::connect(config.nameserver_).await?;
            stream.set_nodelay(true)?;
            Ok(stream)
        }

        /// Opens the client's UDP socket.
        ///
        /// For tokio on Linux with the `socket2` feature, this goes through `udp_socket2`, which
        /// supports interface binding. Every other runtime or build configuration uses
        /// `udp_socket_simple`.
        #[inline(always)]
        async fn udp_socket(config: &ClientConfig) -> Result<UdpSocket> {
            @if (crate_name == "tokio") {
                cfg_if::cfg_if! {
                    if #[cfg(all(target_os = "linux", feature = "net-tokio", feature = "socket2"))] {
                        udp_socket2(config).await
                    } else {
                        udp_socket_simple(config).await
                    }
                }
            }
            @else {
                udp_socket_simple(config).await
            }
        }

        /// Opens a TCP connection to the configured nameserver.
        ///
        /// For tokio on Linux with the `socket2` feature, this goes through `tcp_socket2`, which
        /// supports interface binding. Every other runtime or build configuration uses
        /// `tcp_socket_simple`.
        #[inline(always)]
        async fn tcp_socket(config: &ClientConfig) -> Result<TcpStream> {
            @if (crate_name == "tokio") {
                cfg_if::cfg_if! {
                    if #[cfg(all(target_os = "linux", feature = "net-tokio", feature = "socket2"))] {
                        tcp_socket2(config).await
                    } else {
                        tcp_socket_simple(config).await
                    }
                }
            }
            @else {
                tcp_socket_simple(config).await
            }
        }
    }
    .into()
}