numa 0.13.1

Portable DNS resolver in Rust — .numa local domains, ad blocking, developer overrides, DNS-over-HTTPS
Documentation
use std::net::SocketAddr;

use axum::body::Bytes;
use axum::extract::{Request, State};
use axum::response::{IntoResponse, Response};
use hyper::StatusCode;
use log::warn;

use crate::buffer::BytePacketBuffer;
use crate::ctx::{resolve_query, ServerCtx};
use crate::header::ResultCode;
use crate::packet::DnsPacket;
use crate::stats::Transport;

const MAX_DNS_MSG: usize = 4096;
const DOH_CONTENT_TYPE: &str = "application/dns-message";

pub async fn doh_post(State(state): State<super::proxy::DohState>, req: Request) -> Response {
    let host = super::proxy::extract_host(&req);
    if !is_doh_host(host.as_deref(), &state.ctx.proxy_tld) {
        return StatusCode::NOT_FOUND.into_response();
    }

    let content_type = req
        .headers()
        .get(hyper::header::CONTENT_TYPE)
        .and_then(|v| v.to_str().ok())
        .unwrap_or("");
    if !content_type.starts_with(DOH_CONTENT_TYPE) {
        return StatusCode::UNSUPPORTED_MEDIA_TYPE.into_response();
    }

    let body = match axum::body::to_bytes(req.into_body(), MAX_DNS_MSG).await {
        Ok(b) => b,
        Err(_) => {
            return (StatusCode::PAYLOAD_TOO_LARGE, "body exceeds 4096 bytes").into_response()
        }
    };

    if body.is_empty() {
        return (StatusCode::BAD_REQUEST, "empty body").into_response();
    }

    let src = state
        .remote_addr
        .unwrap_or_else(|| SocketAddr::from(([127, 0, 0, 1], 0)));

    resolve_doh(&body, src, &state.ctx).await
}

fn is_doh_host(host: Option<&str>, tld: &str) -> bool {
    let h = match host {
        Some(h) => h,
        None => return false,
    };
    let base = strip_port(h).unwrap_or(h);
    is_loopback_host(base) || is_tld_match(base, tld)
}

fn strip_port(h: &str) -> Option<&str> {
    if h.starts_with('[') {
        // [::1]:443 → [::1]
        let (base, port) = h.rsplit_once("]:")?;
        port.bytes()
            .all(|b| b.is_ascii_digit())
            .then(|| &h[..base.len() + 1])
    } else {
        let (base, port) = h.rsplit_once(':')?;
        // Bare IPv6 like "::1" has multiple colons — not a port suffix
        if base.contains(':') {
            return None;
        }
        port.bytes().all(|b| b.is_ascii_digit()).then_some(base)
    }
}

fn is_loopback_host(h: &str) -> bool {
    matches!(h, "127.0.0.1" | "::1" | "[::1]" | "localhost")
}

fn is_tld_match(h: &str, tld: &str) -> bool {
    h == tld
        || (h.len() == 2 * tld.len() + 1
            && h.starts_with(tld)
            && h.as_bytes().get(tld.len()) == Some(&b'.')
            && h.ends_with(tld))
}

async fn resolve_doh(
    dns_bytes: &[u8],
    src: SocketAddr,
    ctx: &std::sync::Arc<ServerCtx>,
) -> Response {
    let mut buffer = BytePacketBuffer::from_bytes(dns_bytes);
    let query = match DnsPacket::from_buffer(&mut buffer) {
        Ok(q) => q,
        Err(e) => {
            warn!("DoH: parse error from {}: {}", src, e);
            let query_id = u16::from_be_bytes([
                dns_bytes.first().copied().unwrap_or(0),
                dns_bytes.get(1).copied().unwrap_or(0),
            ]);
            let mut resp = DnsPacket::new();
            resp.header.id = query_id;
            resp.header.response = true;
            resp.header.rescode = ResultCode::FORMERR;
            return serialize_response(&resp);
        }
    };

    let query_id = query.header.id;
    let query_rd = query.header.recursion_desired;
    let questions = query.questions.clone();

    match resolve_query(query, dns_bytes, src, ctx, Transport::Doh).await {
        Ok((resp_buffer, _)) => {
            let min_ttl = extract_min_ttl(resp_buffer.filled());
            dns_response(resp_buffer.filled(), min_ttl)
        }
        Err(e) => {
            warn!("DoH: resolve error for {}: {}", src, e);
            let mut resp = DnsPacket::new();
            resp.header.id = query_id;
            resp.header.response = true;
            resp.header.recursion_desired = query_rd;
            resp.header.recursion_available = true;
            resp.header.rescode = ResultCode::SERVFAIL;
            resp.questions = questions;
            serialize_response(&resp)
        }
    }
}

fn extract_min_ttl(wire: &[u8]) -> u32 {
    crate::wire::scan_ttl_offsets(wire)
        .ok()
        .and_then(|meta| crate::wire::min_ttl_from_wire(wire, &meta))
        .unwrap_or(0)
}

fn dns_response(wire: &[u8], min_ttl: u32) -> Response {
    (
        StatusCode::OK,
        [
            (hyper::header::CONTENT_TYPE, DOH_CONTENT_TYPE),
            (
                hyper::header::CACHE_CONTROL,
                &format!("max-age={}", min_ttl),
            ),
        ],
        Bytes::copy_from_slice(wire),
    )
        .into_response()
}

fn serialize_response(pkt: &DnsPacket) -> Response {
    let mut buf = BytePacketBuffer::new();
    match pkt.write(&mut buf) {
        Ok(_) => dns_response(buf.filled(), 0),
        Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::buffer::BytePacketBuffer;
    use crate::header::ResultCode;
    use crate::packet::DnsPacket;
    use crate::record::DnsRecord;

    #[test]
    fn is_doh_host_matches_tld() {
        assert!(is_doh_host(Some("numa"), "numa"));
        assert!(is_doh_host(Some("numa.numa"), "numa"));
        assert!(is_doh_host(Some("127.0.0.1"), "numa"));
        assert!(is_doh_host(Some("127.0.0.1:443"), "numa"));
        assert!(is_doh_host(Some("::1"), "numa"));
        assert!(is_doh_host(Some("[::1]"), "numa"));
        assert!(is_doh_host(Some("[::1]:443"), "numa"));
        assert!(is_doh_host(Some("localhost"), "numa"));
        assert!(is_doh_host(Some("localhost:443"), "numa"));
        assert!(!is_doh_host(Some("foo.numa"), "numa"));
        assert!(!is_doh_host(None, "numa"));
    }

    #[test]
    fn extract_min_ttl_from_response() {
        let mut pkt = DnsPacket::new();
        pkt.header.response = true;
        pkt.answers.push(DnsRecord::A {
            domain: "example.com".to_string(),
            addr: std::net::Ipv4Addr::new(1, 2, 3, 4),
            ttl: 300,
        });
        pkt.answers.push(DnsRecord::A {
            domain: "example.com".to_string(),
            addr: std::net::Ipv4Addr::new(5, 6, 7, 8),
            ttl: 60,
        });
        let mut buf = BytePacketBuffer::new();
        pkt.write(&mut buf).unwrap();
        assert_eq!(extract_min_ttl(buf.filled()), 60);
    }

    #[test]
    fn extract_min_ttl_no_answers() {
        let mut pkt = DnsPacket::new();
        pkt.header.response = true;
        let mut buf = BytePacketBuffer::new();
        pkt.write(&mut buf).unwrap();
        assert_eq!(extract_min_ttl(buf.filled()), 0);
    }

    #[test]
    fn serialize_formerr_response() {
        let mut pkt = DnsPacket::new();
        pkt.header.id = 0xABCD;
        pkt.header.response = true;
        pkt.header.rescode = ResultCode::FORMERR;
        let resp = serialize_response(&pkt);
        assert_eq!(resp.status(), StatusCode::OK);
    }
}