1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
use crate::{DnsError, DnsMessage, DnsName, DnsOpCode, DnsQuestion, DnsRecord, DnsType};
use fixed_buffer::FixedBuf;
use multimap::MultiMap;
use oorandom::Rand32;
use permit::Permit;
use prob_rate_limiter::ProbRateLimiter;
use std::cell::RefCell;
use std::convert::TryFrom;
use std::io::ErrorKind;
use std::net::{IpAddr, Ipv6Addr, SocketAddr, UdpSocket};
use std::time::{Duration, Instant};

// Per-thread PRNG used by `Builder::serve_static` to pick how far to rotate the list of
// matching records, so repeated queries see the records in varying orders.
// NOTE(review): the seed is the constant 0, so every thread starts from the same state and the
// rotation sequence is predictable. That looks fine for spreading load across records; confirm
// nothing relies on it being unpredictable.
thread_local!(static RAND32: RefCell<Rand32> = RefCell::new(Rand32::new(0)));

/// Answers a DNS query message using `handler`.
///
/// Only the first question in `request` is passed to `handler`. The records it returns are
/// turned into the response by `DnsMessage::answer_response`.
///
/// # Errors
/// Returns `Err` when the request is malformed or the server is not configured to answer the
/// request:
/// - [`DnsError::NotARequest`] when `request` is itself a response,
/// - [`DnsError::InvalidOpCode`] when its op code is not [`DnsOpCode::Query`],
/// - [`DnsError::NoQuestion`] when it contains no questions,
/// - any error returned by `DnsMessage::answer_response`.
pub fn process_request(
    request: &DnsMessage,
    handler: &impl Fn(&DnsQuestion) -> Vec<DnsRecord>,
) -> Result<DnsMessage, DnsError> {
    if request.header.is_response {
        return Err(DnsError::NotARequest);
    }
    if request.header.op_code != DnsOpCode::Query {
        return Err(DnsError::InvalidOpCode);
    }
    // NOTE: We only answer the first question.
    let question = request.questions.first().ok_or(DnsError::NoQuestion)?;
    request.answer_response(handler(question))
}

/// Parses one DNS request datagram from `bytes`, answers it with `handler`, and returns the
/// serialized response.
///
/// `bytes` is handed to `DnsMessage::read`, which consumes the request from it.
///
/// # Errors
/// Returns `Err` when the request is malformed, when the server is not configured to answer the
/// request (see [`process_request`]), or when writing the response into the 512-byte output
/// buffer fails.
pub fn process_datagram(
    bytes: &mut FixedBuf<512>,
    handler: &impl Fn(&DnsQuestion) -> Vec<DnsRecord>,
) -> Result<FixedBuf<512>, DnsError> {
    let request = DnsMessage::read(bytes)?;
    let response = process_request(&request, handler)?;
    let mut out: FixedBuf<512> = FixedBuf::new();
    response.write(&mut out)?;
    Ok(out)
}

/// Serves DNS requests received on `sock` until `permit` is revoked.
///
/// Each datagram is handled as follows:
/// - It is dropped if it is longer than 512 bytes.
/// - It is dropped if `response_bytes_rate_limiter` rejects it.
/// - Otherwise it is passed to [`process_datagram`]; requests that fail there are dropped.
/// - A successful response is charged to the rate limiter by its length in bytes and sent back
///   to the datagram's source address.
///
/// Every drop is logged with `println!`.
///
/// The socket's read timeout is set to 500 ms, so the loop re-checks `permit` at least that
/// often even when no requests arrive.
///
/// # Errors
/// Returns `Err` when:
/// - setting the read timeout or reading the local address fails,
/// - receiving fails with an error other than a timeout, an interruption, a connection reset,
///   or (on Windows) an oversized datagram,
/// - the OS sends only part of a response.
///
/// A failure to send a single response is logged and does not stop the server.
#[allow(clippy::missing_panics_doc)]
pub fn serve_udp(
    permit: &Permit,
    sock: &UdpSocket,
    mut response_bytes_rate_limiter: ProbRateLimiter,
    handler: &impl Fn(&DnsQuestion) -> Vec<DnsRecord>,
) -> Result<(), String> {
    // > DNS messages carried by UDP are restricted to 512 bytes (not counting the IP
    // > or UDP headers).  Longer messages are truncated and the TC bit is set in
    // > the header.
    // https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.1
    const MAX_UDP_MESSAGE_LEN: usize = 512;
    // Winsock error code for "message too large for the receive buffer".
    const WSAEMSGSIZE: i32 = 10040;
    sock.set_read_timeout(Some(Duration::from_millis(500)))
        .map_err(|e| format!("error setting socket read timeout: {e}"))?;
    let local_addr = sock
        .local_addr()
        .map_err(|e| format!("error getting socket local address: {e}"))?;
    // `recv_from` never reports more bytes than the buffer it is given can hold; longer
    // datagrams are truncated. Receiving into a buffer one byte larger than the limit lets us
    // tell an over-long datagram (it fills the extra byte) from one that fits exactly.
    let mut scratch = [0_u8; MAX_UDP_MESSAGE_LEN + 1];
    while !permit.is_revoked() {
        let (len, peer_addr) = match sock.recv_from(&mut scratch) {
            Ok(received) => received,
            // Timeouts and interruptions just let the loop re-check `permit`.
            // `ConnectionReset` is how Windows reports an ICMP port-unreachable caused by an
            // earlier `send_to`; it concerns that peer, not this socket.
            Err(e)
                if matches!(
                    e.kind(),
                    ErrorKind::WouldBlock
                        | ErrorKind::TimedOut
                        | ErrorKind::Interrupted
                        | ErrorKind::ConnectionReset
                ) =>
            {
                continue
            }
            // Windows fails the receive (instead of truncating) when a datagram is larger than
            // the buffer. That is an over-long request, not a socket failure.
            Err(e) if cfg!(windows) && e.raw_os_error() == Some(WSAEMSGSIZE) => {
                println!("dropping over-long request");
                continue;
            }
            Err(e) => {
                return Err(format!(
                    "error reading DNS server UDP socket {local_addr:?}: {e}"
                ))
            }
        };
        if len > MAX_UDP_MESSAGE_LEN {
            println!("dropping over-long request");
            continue;
        }
        let mut buf: FixedBuf<512> = FixedBuf::new();
        buf.writable()[..len].copy_from_slice(&scratch[..len]);
        buf.wrote(len);
        if !response_bytes_rate_limiter.attempt(Instant::now()) {
            println!("dropping request");
            continue;
        }
        let out = match process_datagram(&mut buf, handler) {
            Ok(out) => out,
            Err(e) => {
                println!("dropping bad request: {e:?}");
                continue;
            }
        };
        if out.is_empty() {
            unreachable!("process_datagram returned an empty response");
        }
        // `out` holds at most 512 bytes, so its length always fits in a `u32`.
        response_bytes_rate_limiter.record(u32::try_from(out.len()).unwrap());
        let sent_len = match sock.send_to(out.readable(), peer_addr) {
            Ok(sent_len) => sent_len,
            Err(e) => {
                // The destination is the request's unauthenticated source address, so a send
                // failure says nothing about the health of our socket. Log it and keep serving
                // rather than letting one unreachable (or spoofed) peer stop the server.
                // Fatal socket errors still surface through `recv_from` above.
                println!("error sending response to {peer_addr:?}: {e}");
                continue;
            }
        };
        if sent_len != out.len() {
            return Err(format!(
                "sent only {sent_len} bytes of {} byte response to {peer_addr:?}",
                out.len()
            ));
        }
    }
    Ok(())
}

/// Configures and runs a UDP DNS server.
///
/// Create one with [`Builder::new`], [`Builder::new_port`], or [`Builder::new_random_port`],
/// optionally configure it with the `with_*` methods, then call [`Builder::serve`] or
/// [`Builder::serve_static`].
pub struct Builder {
    // Permit that keeps the serve loop running; `None` makes `serve` use `Permit::default()`.
    permit: Option<Permit>,
    // Socket that requests are received on and responses are sent from.
    sock: UdpSocket,
    // Response-bytes-per-second budget for the rate limiter; `None` makes `serve` use 1_000_000.
    max_response_bytes_per_second: Option<u32>,
}
impl Builder {
    /// Creates a builder that will serve on `sock`.
    ///
    /// No permit or response rate limit is set; see [`Builder::serve`] for the defaults it
    /// applies.
    #[must_use]
    pub fn new(sock: UdpSocket) -> Self {
        Self {
            permit: None,
            sock,
            max_response_bytes_per_second: None,
        }
    }

    /// Creates a builder whose socket is bound to `port` on the IPv6 unspecified address `[::]`.
    ///
    /// # Errors
    /// Returns `Err` when it failed to allocate a socket or bind it to the specified port.
    pub fn new_port(port: u16) -> Result<Self, String> {
        let sock = UdpSocket::bind(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port))
            .map_err(|e| format!("error binding to UDP port {port}: {e}"))?;
        Ok(Self::new(sock))
    }

    /// Creates a builder whose socket is bound to an OS-chosen port on `[::]`.
    ///
    /// Returns the builder together with the socket's local address, so the caller can learn
    /// which port was chosen.
    ///
    /// # Errors
    /// Returns `Err` when it failed to allocate a socket, bind it to an available port, or read
    /// back its local address.
    pub fn new_random_port() -> Result<(Self, SocketAddr), String> {
        let sock = UdpSocket::bind(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0))
            .map_err(|e| format!("error binding to random UDP port: {e}"))?;
        let addr = sock
            .local_addr()
            .map_err(|e| format!("error getting socket local address: {e}"))?;
        Ok((Self::new(sock), addr))
    }

    /// Sets the permit that controls the server's lifetime: the serve loop stops once the
    /// permit is revoked.
    #[must_use]
    pub fn with_permit(mut self, permit: Permit) -> Self {
        self.permit = Some(permit);
        self
    }

    /// Sets the response-bytes-per-second budget given to the server's rate limiter.
    ///
    /// When not set, [`Builder::serve`] uses 1,000,000.
    #[must_use]
    pub fn with_max_response_bytes_per_second(mut self, n: u32) -> Self {
        self.max_response_bytes_per_second = Some(n);
        self
    }

    /// Serves requests on the builder's socket, answering each question with `handler`, until
    /// the permit is revoked.
    ///
    /// Uses `Permit::default()` when no permit was set, and a budget of 1,000,000 response
    /// bytes per second when no limit was set.
    ///
    /// # Errors
    /// Returns `Err` when socket operations fail.
    pub fn serve(self, handler: &impl Fn(&DnsQuestion) -> Vec<DnsRecord>) -> Result<(), String> {
        let permit = self.permit.unwrap_or_default();
        let max_response_bytes_per_second = self.max_response_bytes_per_second.unwrap_or(1_000_000);
        let limiter = ProbRateLimiter::new(max_response_bytes_per_second);
        serve_udp(&permit, &self.sock, limiter, handler)
    }

    /// Serves a fixed set of `records`.
    ///
    /// Each question is answered with the records whose name equals the question's name. For
    /// questions other than `ANY`, only records of the requested type are included. The
    /// matching records are rotated by a random amount so that repeated queries receive them in
    /// varying orders.
    ///
    /// # Errors
    /// Returns `Err` when socket operations fail.
    pub fn serve_static(self, records: &[DnsRecord]) -> Result<(), String> {
        let name_to_records: MultiMap<&DnsName, &DnsRecord> =
            records.iter().map(|x| (x.name(), x)).collect();
        let handler = move |q: &DnsQuestion| {
            let Some(record_refs) = name_to_records.get_vec(&q.name) else {
                return Vec::new();
            };
            // Filter before cloning so only the records we actually return are copied.
            let mut records: Vec<DnsRecord> = record_refs
                .iter()
                .copied()
                .filter(|r| q.typ == DnsType::ANY || r.typ() == q.typ)
                .cloned()
                .collect();
            if !records.is_empty() {
                let range = 0..(u32::try_from(records.len()).unwrap_or(u32::MAX));
                let k = RAND32.with_borrow_mut(|r| r.rand_range(range)) as usize;
                records.rotate_right(k);
            }
            records
        };
        self.serve(&handler)
    }
}