dns_server/
server.rs

1use crate::{DnsError, DnsMessage, DnsOpCode, DnsQuestion, DnsRecord, DnsType};
2use fixed_buffer::FixedBuf;
3use oorandom::Rand32;
4use permit::Permit;
5use prob_rate_limiter::ProbRateLimiter;
6use std::cell::RefCell;
7use std::collections::HashMap;
8use std::convert::TryFrom;
9use std::io::ErrorKind;
10use std::net::{IpAddr, Ipv6Addr, SocketAddr, UdpSocket};
11use std::time::{Duration, Instant};
12
13thread_local!(static RAND32: RefCell<Rand32> = RefCell::new(Rand32::new(0)));
14
15/// # Errors
16/// Returns `Err` when the request is malformed or the server is not configured to answer the
17/// request.
18pub fn process_request(
19    request: &DnsMessage,
20    handler: &impl Fn(&DnsQuestion) -> Vec<DnsRecord>,
21) -> Result<DnsMessage, DnsError> {
22    if request.header.is_response {
23        return Err(DnsError::NotARequest);
24    }
25    if request.header.op_code != DnsOpCode::Query {
26        return Err(DnsError::InvalidOpCode);
27    }
28    // NOTE: We only answer the first question.
29    let question = request.questions.first().ok_or(DnsError::NoQuestion)?;
30    // u16::try_from(self.questions.len()).map_err(|_| ProcessError::TooManyQuestions)?,
31    let records = handler(question);
32    request.answer_response(records)
33}
34
35/// # Errors
36/// Returns `Err` when the request is malformed or the server is not configured to answer the
37/// request.
38#[allow(clippy::implicit_hasher)]
39pub fn process_datagram(
40    bytes: &mut FixedBuf<512>,
41    handler: &impl Fn(&DnsQuestion) -> Vec<DnsRecord>,
42) -> Result<FixedBuf<512>, DnsError> {
43    //println!("process_datagram: bytes = {:?}", bytes.readable());
44    let request = DnsMessage::read(bytes)?;
45    //println!("process_datagram: request = {:?}", request);
46    let response = process_request(&request, &handler)?;
47    //println!("process_datagram: response = {:?}", response);
48    let mut out: FixedBuf<512> = FixedBuf::new();
49    response.write(&mut out)?;
50    //println!("process_datagram: out = {:?}", out.readable());
51    Ok(out)
52}
53
54/// # Errors
55/// Returns `Err` when socket operations fail.
56#[allow(clippy::missing_panics_doc)]
57pub fn serve_udp(
58    permit: &Permit,
59    sock: &UdpSocket,
60    mut response_bytes_rate_limiter: ProbRateLimiter,
61    handler: &impl Fn(&DnsQuestion) -> Vec<DnsRecord>,
62) -> Result<(), String> {
63    sock.set_read_timeout(Some(Duration::from_millis(500)))
64        .map_err(|e| format!("error setting socket read timeout: {e}"))?;
65    let addr = sock
66        .local_addr()
67        .map_err(|e| format!("error getting socket local address: {e}"))?;
68    while !permit.is_revoked() {
69        // > DNS messages carried by UDP are restricted to 512 bytes (not counting the IP
70        // > or UDP headers).  Longer messages are truncated and the TC bit is set in
71        // > the header.
72        // https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.1
73        let mut buf: FixedBuf<512> = FixedBuf::new();
74        let addr = match sock.recv_from(buf.writable()) {
75            Ok((len, _)) if len > buf.writable().len() => {
76                println!("dropping over-long request");
77                continue;
78            }
79            Ok((len, addr)) => {
80                buf.wrote(len);
81                addr
82            }
83            Err(e) if e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut => {
84                continue
85            }
86            Err(e) => return Err(format!("error reading DNS server UDP socket {addr:?}: {e}")),
87        };
88        let now = Instant::now();
89        if !response_bytes_rate_limiter.attempt(now) {
90            println!("dropping request");
91            continue;
92        }
93        let out = match process_datagram(&mut buf, handler) {
94            Ok(buf) => buf,
95            Err(e) => {
96                println!("dropping bad request: {e:?}");
97                continue;
98            }
99        };
100        if out.is_empty() {
101            unreachable!();
102        }
103        response_bytes_rate_limiter.record(u32::try_from(out.len()).unwrap());
104        let sent_len = sock
105            .send_to(out.readable(), addr)
106            .map_err(|e| format!("error sending response to {addr:?}: {e}"))?;
107        if sent_len != out.len() {
108            return Err(format!(
109                "sent only {sent_len} bytes of {} byte response to {addr:?}",
110                out.len()
111            ));
112        }
113    }
114    Ok(())
115}
116
117pub struct Builder {
118    permit: Option<Permit>,
119    sock: UdpSocket,
120    max_response_bytes_per_second: Option<u32>,
121}
122impl Builder {
123    #[must_use]
124    pub fn new(sock: UdpSocket) -> Self {
125        Self {
126            permit: None,
127            sock,
128            max_response_bytes_per_second: None,
129        }
130    }
131
132    /// # Errors
133    /// Returns `Err` when it failed to allocate a socket or bind it to the specified port.
134    pub fn new_port(port: u16) -> Result<Self, String> {
135        let sock = UdpSocket::bind(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port))
136            .map_err(|e| format!("error binding to UDP port {port}: {e}"))?;
137        Ok(Self::new(sock))
138    }
139
140    /// # Errors
141    /// Returns `Err` when it failed to allocate a socket or bind it an available port.
142    pub fn new_random_port() -> Result<(Self, SocketAddr), String> {
143        let sock = UdpSocket::bind(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0))
144            .map_err(|e| format!("error binding to random UDP port: {e}"))?;
145        let addr = sock
146            .local_addr()
147            .map_err(|e| format!("error getting socket local address: {e}"))?;
148        Ok((Self::new(sock), addr))
149    }
150
151    #[must_use]
152    pub fn with_permit(mut self, permit: Permit) -> Self {
153        self.permit = Some(permit);
154        self
155    }
156
157    #[must_use]
158    pub fn with_max_response_bytes_per_second(mut self, n: u32) -> Self {
159        self.max_response_bytes_per_second = Some(n);
160        self
161    }
162
163    /// # Errors
164    /// Returns `Err` when socket operations fail.
165    pub fn serve(self, handler: &impl Fn(&DnsQuestion) -> Vec<DnsRecord>) -> Result<(), String> {
166        let permit = self.permit.unwrap_or_default();
167        let max_response_bytes_per_second = self.max_response_bytes_per_second.unwrap_or(1_000_000);
168        let limiter = ProbRateLimiter::new(max_response_bytes_per_second);
169        serve_udp(&permit, &self.sock, limiter, handler)
170    }
171
172    /// # Errors
173    /// Returns `Err` when socket operations fail.
174    pub fn serve_static(self, records: &[DnsRecord]) -> Result<(), String> {
175        let mut name_to_records: HashMap<String, Vec<&DnsRecord>> = HashMap::default();
176        for record in records {
177            let key = record.name().inner().to_ascii_lowercase();
178            if let Some(v) = name_to_records.get_mut(&key) {
179                v.push(record);
180            } else {
181                name_to_records.insert(key, vec![&record]);
182            }
183        }
184        let handler = move |q: &DnsQuestion| {
185            let key = q.name.inner().to_ascii_lowercase();
186            let Some(record_refs) = name_to_records.get(&key) else {
187                return Vec::new();
188            };
189            let mut records: Vec<DnsRecord> = record_refs.iter().map(|r| (*r).clone()).collect();
190            if q.typ != DnsType::ANY {
191                records.retain(|r| r.typ() == q.typ);
192            }
193            if !records.is_empty() {
194                let range = 0..(u32::try_from(records.len()).unwrap_or(u32::MAX));
195                let k = RAND32.with_borrow_mut(|r| r.rand_range(range)) as usize;
196                records.rotate_right(k);
197            }
198            records
199        };
200        self.serve(&handler)
201    }
202}