dns-server 0.2.6

A threaded DNS server.
Documentation
use crate::{DnsError, DnsMessage, DnsOpCode, DnsQuestion, DnsRecord, DnsType};
use fixed_buffer::FixedBuf;
use oorandom::Rand32;
use permit::Permit;
use prob_rate_limiter::ProbRateLimiter;
use std::cell::RefCell;
use std::collections::HashMap;
use std::convert::TryFrom;
use std::io::ErrorKind;
use std::net::{IpAddr, Ipv6Addr, SocketAddr, UdpSocket};
use std::time::{Duration, Instant};

thread_local! {
    /// Per-thread pseudo-random generator used by `Builder::serve_static` to pick how far to
    /// rotate the answer records. Every thread starts from the same fixed seed (0).
    static RAND32: RefCell<Rand32> = RefCell::new(Rand32::new(0));
}

/// Builds the response to a single DNS query message.
///
/// Only the first entry of `request.questions` is passed to `handler`; any further questions are
/// ignored. The records `handler` returns are handed to `DnsMessage::answer_response` to build the
/// reply.
///
/// # Errors
/// Returns `Err` when the request is malformed or the server is not configured to answer the
/// request:
/// - `DnsError::NotARequest` when the message has its response flag set,
/// - `DnsError::InvalidOpCode` when the opcode is anything other than `DnsOpCode::Query`,
/// - `DnsError::NoQuestion` when the message carries no questions,
/// - any error produced by `DnsMessage::answer_response`.
pub fn process_request(
    request: &DnsMessage,
    handler: &impl Fn(&DnsQuestion) -> Vec<DnsRecord>,
) -> Result<DnsMessage, DnsError> {
    let header = &request.header;
    if header.is_response {
        return Err(DnsError::NotARequest);
    }
    if header.op_code != DnsOpCode::Query {
        return Err(DnsError::InvalidOpCode);
    }
    // NOTE: Only the first question is answered; additional questions are ignored.
    let Some(first_question) = request.questions.first() else {
        return Err(DnsError::NoQuestion);
    };
    let answers = handler(first_question);
    request.answer_response(answers)
}

/// Decodes one DNS request from `bytes`, answers it with `process_request`, and encodes the reply
/// into a fresh 512-byte buffer.
///
/// # Errors
/// Returns `Err` when the request is malformed or the server is not configured to answer the
/// request. Failures from `DnsMessage::read`, `process_request`, and `DnsMessage::write` are
/// propagated with `?`.
pub fn process_datagram(
    bytes: &mut FixedBuf<512>,
    handler: &impl Fn(&DnsQuestion) -> Vec<DnsRecord>,
) -> Result<FixedBuf<512>, DnsError> {
    let parsed_request = DnsMessage::read(bytes)?;
    let reply = process_request(&parsed_request, handler)?;
    let mut encoded = FixedBuf::<512>::new();
    reply.write(&mut encoded)?;
    Ok(encoded)
}

/// Returns `true` for `recv_from` errors that concern a single datagram (or a read timeout), so
/// the serve loop should keep running instead of shutting the whole server down.
fn is_transient_recv_error(e: &std::io::Error) -> bool {
    // Raw Winsock `WSAEMSGSIZE` code: Windows reports a datagram larger than the receive buffer
    // as an error rather than silently truncating it.
    const WSAEMSGSIZE: i32 = 10040;
    matches!(
        e.kind(),
        // The read timeout expired (the reported kind differs by platform).
        ErrorKind::WouldBlock
            | ErrorKind::TimedOut
            // A signal interrupted the call.
            | ErrorKind::Interrupted
            // ICMP errors caused by an earlier `send_to` to some client can surface on a later
            // receive (e.g. `WSAECONNRESET` on Windows).
            | ErrorKind::ConnectionReset
            | ErrorKind::ConnectionRefused
    ) || (cfg!(windows) && e.raw_os_error() == Some(WSAEMSGSIZE))
}

/// Serves DNS queries on `sock` until `permit` is revoked.
///
/// Each received datagram is answered with `process_datagram` using `handler`. A request is
/// dropped (with a message printed to stdout) when it is longer than 512 bytes, when
/// `response_bytes_rate_limiter` rejects it, or when it cannot be processed. The socket's read
/// timeout is set to 500 ms so the loop re-checks `permit` at least that often.
///
/// # Errors
/// Returns `Err` when socket setup fails, or when receiving fails with an error that is not
/// specific to a single datagram. A failure to send an individual response is printed and the
/// loop continues, because such failures typically depend on that client's address.
///
/// # Panics
/// Panics if `process_datagram` returns an empty response, or one whose length does not fit in a
/// `u32`. Neither is expected, since responses are held in a 512-byte buffer and always include a
/// header.
pub fn serve_udp(
    permit: &Permit,
    sock: &UdpSocket,
    mut response_bytes_rate_limiter: ProbRateLimiter,
    handler: &impl Fn(&DnsQuestion) -> Vec<DnsRecord>,
) -> Result<(), String> {
    // > DNS messages carried by UDP are restricted to 512 bytes (not counting the IP
    // > or UDP headers).  Longer messages are truncated and the TC bit is set in
    // > the header.
    // https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.1
    const MAX_UDP_MESSAGE_LEN: usize = 512;

    sock.set_read_timeout(Some(Duration::from_millis(500)))
        .map_err(|e| format!("error setting socket read timeout: {e}"))?;
    let local_addr = sock
        .local_addr()
        .map_err(|e| format!("error getting socket local address: {e}"))?;
    while !permit.is_revoked() {
        // `recv_from` never reports more bytes than fit in the buffer; the excess of a longer
        // datagram is discarded. Receiving into a buffer one byte larger than the limit is what
        // lets us tell a 512-byte request from an over-long one.
        let mut datagram = [0_u8; MAX_UDP_MESSAGE_LEN + 1];
        let (len, peer_addr) = match sock.recv_from(&mut datagram) {
            Ok(received) => received,
            Err(e) if is_transient_recv_error(&e) => continue,
            Err(e) => {
                return Err(format!(
                    "error reading DNS server UDP socket {local_addr:?}: {e}"
                ))
            }
        };
        let mut buf: FixedBuf<512> = FixedBuf::new();
        if len > buf.writable().len() {
            println!("dropping over-long request");
            continue;
        }
        buf.writable()[..len].copy_from_slice(&datagram[..len]);
        buf.wrote(len);

        if !response_bytes_rate_limiter.attempt(Instant::now()) {
            println!("dropping request");
            continue;
        }
        let out = match process_datagram(&mut buf, handler) {
            Ok(out) => out,
            Err(e) => {
                println!("dropping bad request: {e:?}");
                continue;
            }
        };
        if out.is_empty() {
            unreachable!();
        }
        // `out` holds at most 512 bytes, so its length always fits in a `u32`.
        let out_len = u32::try_from(out.len()).expect("DNS response length fits in u32");
        response_bytes_rate_limiter.record(out_len);

        // A send failure is usually specific to this peer (e.g. a spoofed broadcast or
        // unroutable source address), so report it and keep serving other clients instead of
        // letting one request stop the server. If the socket itself is broken, the next
        // `recv_from` is expected to fail and end the loop.
        match sock.send_to(out.readable(), peer_addr) {
            Ok(sent_len) if sent_len == out.len() => {}
            Ok(sent_len) => println!(
                "sent only {sent_len} bytes of {} byte response to {peer_addr:?}",
                out.len()
            ),
            Err(e) => println!("error sending response to {peer_addr:?}: {e}"),
        }
    }
    Ok(())
}

/// Configures and runs a UDP DNS server.
///
/// Create one with [`Builder::new`], [`Builder::new_port`], or [`Builder::new_random_port`],
/// optionally set a [`Permit`] and a response-bytes rate limit, then call [`Builder::serve`] or
/// [`Builder::serve_static`].
pub struct Builder {
    /// The serve loop runs until this permit is revoked; `Permit::default()` is used when unset.
    permit: Option<Permit>,
    /// Socket the server receives requests on and sends responses from.
    sock: UdpSocket,
    /// Passed to `ProbRateLimiter::new` to limit response bytes per second; 1,000,000 when unset.
    max_response_bytes_per_second: Option<u32>,
}
impl Builder {
    /// Creates a builder that serves on the already-bound `sock`.
    #[must_use]
    pub fn new(sock: UdpSocket) -> Self {
        Self {
            permit: None,
            sock,
            max_response_bytes_per_second: None,
        }
    }

    /// Creates a builder whose socket is bound to `port` on the IPv6 unspecified address (`[::]`).
    ///
    /// NOTE: Whether this socket also receives IPv4 traffic depends on the platform's default for
    /// `IPV6_V6ONLY` (for example, it defaults to on under Windows).
    ///
    /// # Errors
    /// Returns `Err` when it failed to allocate a socket or bind it to the specified port.
    pub fn new_port(port: u16) -> Result<Self, String> {
        let sock = UdpSocket::bind(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port))
            .map_err(|e| format!("error binding to UDP port {port}: {e}"))?;
        Ok(Self::new(sock))
    }

    /// Creates a builder whose socket is bound to an OS-assigned port on `[::]`, and returns it
    /// together with the socket's local address (so callers can learn the chosen port).
    ///
    /// # Errors
    /// Returns `Err` when it failed to allocate a socket, bind it to an available port, or read
    /// back its local address.
    pub fn new_random_port() -> Result<(Self, SocketAddr), String> {
        let sock = UdpSocket::bind(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0))
            .map_err(|e| format!("error binding to random UDP port: {e}"))?;
        let addr = sock
            .local_addr()
            .map_err(|e| format!("error getting socket local address: {e}"))?;
        Ok((Self::new(sock), addr))
    }

    /// Sets the permit whose revocation stops the server.
    #[must_use]
    pub fn with_permit(mut self, permit: Permit) -> Self {
        self.permit = Some(permit);
        self
    }

    /// Sets the response-bytes-per-second limit given to the rate limiter.
    #[must_use]
    pub fn with_max_response_bytes_per_second(mut self, n: u32) -> Self {
        self.max_response_bytes_per_second = Some(n);
        self
    }

    /// Runs the server, answering each query's first question with `handler`, until the permit
    /// is revoked.
    ///
    /// Uses `Permit::default()` when no permit was set, and a limit of 1,000,000 response bytes
    /// per second when none was set.
    ///
    /// # Errors
    /// Returns `Err` when socket operations fail.
    pub fn serve(self, handler: &impl Fn(&DnsQuestion) -> Vec<DnsRecord>) -> Result<(), String> {
        let permit = self.permit.unwrap_or_default();
        let max_response_bytes_per_second = self.max_response_bytes_per_second.unwrap_or(1_000_000);
        let limiter = ProbRateLimiter::new(max_response_bytes_per_second);
        serve_udp(&permit, &self.sock, limiter, handler)
    }

    /// Runs the server, answering queries from the fixed set `records`.
    ///
    /// Names are compared after ASCII-lowercasing. A question of type `DnsType::ANY` receives every
    /// record with the matching name; any other type receives only records of that type. A
    /// non-empty answer list is rotated by a random offset so successive responses start with
    /// different records.
    ///
    /// # Errors
    /// Returns `Err` when socket operations fail.
    pub fn serve_static(self, records: &[DnsRecord]) -> Result<(), String> {
        // Index the records once, keyed by lowercased name, so each query is one map lookup.
        let mut name_to_records: HashMap<String, Vec<&DnsRecord>> = HashMap::new();
        for record in records {
            name_to_records
                .entry(record.name().inner().to_ascii_lowercase())
                .or_default()
                .push(record);
        }
        let handler = move |q: &DnsQuestion| -> Vec<DnsRecord> {
            let key = q.name.inner().to_ascii_lowercase();
            let Some(record_refs) = name_to_records.get(&key) else {
                return Vec::new();
            };
            // Filter before cloning so only the records actually returned are copied.
            let mut records: Vec<DnsRecord> = record_refs
                .iter()
                .filter(|r| q.typ == DnsType::ANY || r.typ() == q.typ)
                .map(|r| (*r).clone())
                .collect();
            if !records.is_empty() {
                let range = 0..(u32::try_from(records.len()).unwrap_or(u32::MAX));
                let k = RAND32.with_borrow_mut(|r| r.rand_range(range)) as usize;
                records.rotate_right(k);
            }
            records
        };
        self.serve(&handler)
    }
}