1use crate::{DnsError, DnsMessage, DnsOpCode, DnsQuestion, DnsRecord, DnsType};
2use fixed_buffer::FixedBuf;
3use oorandom::Rand32;
4use permit::Permit;
5use prob_rate_limiter::ProbRateLimiter;
6use std::cell::RefCell;
7use std::collections::HashMap;
8use std::convert::TryFrom;
9use std::io::ErrorKind;
10use std::net::{IpAddr, Ipv6Addr, SocketAddr, UdpSocket};
11use std::time::{Duration, Instant};
12
13thread_local!(static RAND32: RefCell<Rand32> = RefCell::new(Rand32::new(0)));
14
15pub fn process_request(
19 request: &DnsMessage,
20 handler: &impl Fn(&DnsQuestion) -> Vec<DnsRecord>,
21) -> Result<DnsMessage, DnsError> {
22 if request.header.is_response {
23 return Err(DnsError::NotARequest);
24 }
25 if request.header.op_code != DnsOpCode::Query {
26 return Err(DnsError::InvalidOpCode);
27 }
28 let question = request.questions.first().ok_or(DnsError::NoQuestion)?;
30 let records = handler(question);
32 request.answer_response(records)
33}
34
35#[allow(clippy::implicit_hasher)]
39pub fn process_datagram(
40 bytes: &mut FixedBuf<512>,
41 handler: &impl Fn(&DnsQuestion) -> Vec<DnsRecord>,
42) -> Result<FixedBuf<512>, DnsError> {
43 let request = DnsMessage::read(bytes)?;
45 let response = process_request(&request, &handler)?;
47 let mut out: FixedBuf<512> = FixedBuf::new();
49 response.write(&mut out)?;
50 Ok(out)
52}
53
54#[allow(clippy::missing_panics_doc)]
57pub fn serve_udp(
58 permit: &Permit,
59 sock: &UdpSocket,
60 mut response_bytes_rate_limiter: ProbRateLimiter,
61 handler: &impl Fn(&DnsQuestion) -> Vec<DnsRecord>,
62) -> Result<(), String> {
63 sock.set_read_timeout(Some(Duration::from_millis(500)))
64 .map_err(|e| format!("error setting socket read timeout: {e}"))?;
65 let addr = sock
66 .local_addr()
67 .map_err(|e| format!("error getting socket local address: {e}"))?;
68 while !permit.is_revoked() {
69 let mut buf: FixedBuf<512> = FixedBuf::new();
74 let addr = match sock.recv_from(buf.writable()) {
75 Ok((len, _)) if len > buf.writable().len() => {
76 println!("dropping over-long request");
77 continue;
78 }
79 Ok((len, addr)) => {
80 buf.wrote(len);
81 addr
82 }
83 Err(e) if e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut => {
84 continue
85 }
86 Err(e) => return Err(format!("error reading DNS server UDP socket {addr:?}: {e}")),
87 };
88 let now = Instant::now();
89 if !response_bytes_rate_limiter.attempt(now) {
90 println!("dropping request");
91 continue;
92 }
93 let out = match process_datagram(&mut buf, handler) {
94 Ok(buf) => buf,
95 Err(e) => {
96 println!("dropping bad request: {e:?}");
97 continue;
98 }
99 };
100 if out.is_empty() {
101 unreachable!();
102 }
103 response_bytes_rate_limiter.record(u32::try_from(out.len()).unwrap());
104 let sent_len = sock
105 .send_to(out.readable(), addr)
106 .map_err(|e| format!("error sending response to {addr:?}: {e}"))?;
107 if sent_len != out.len() {
108 return Err(format!(
109 "sent only {sent_len} bytes of {} byte response to {addr:?}",
110 out.len()
111 ));
112 }
113 }
114 Ok(())
115}
116
117pub struct Builder {
118 permit: Option<Permit>,
119 sock: UdpSocket,
120 max_response_bytes_per_second: Option<u32>,
121}
122impl Builder {
123 #[must_use]
124 pub fn new(sock: UdpSocket) -> Self {
125 Self {
126 permit: None,
127 sock,
128 max_response_bytes_per_second: None,
129 }
130 }
131
132 pub fn new_port(port: u16) -> Result<Self, String> {
135 let sock = UdpSocket::bind(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port))
136 .map_err(|e| format!("error binding to UDP port {port}: {e}"))?;
137 Ok(Self::new(sock))
138 }
139
140 pub fn new_random_port() -> Result<(Self, SocketAddr), String> {
143 let sock = UdpSocket::bind(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0))
144 .map_err(|e| format!("error binding to random UDP port: {e}"))?;
145 let addr = sock
146 .local_addr()
147 .map_err(|e| format!("error getting socket local address: {e}"))?;
148 Ok((Self::new(sock), addr))
149 }
150
151 #[must_use]
152 pub fn with_permit(mut self, permit: Permit) -> Self {
153 self.permit = Some(permit);
154 self
155 }
156
157 #[must_use]
158 pub fn with_max_response_bytes_per_second(mut self, n: u32) -> Self {
159 self.max_response_bytes_per_second = Some(n);
160 self
161 }
162
163 pub fn serve(self, handler: &impl Fn(&DnsQuestion) -> Vec<DnsRecord>) -> Result<(), String> {
166 let permit = self.permit.unwrap_or_default();
167 let max_response_bytes_per_second = self.max_response_bytes_per_second.unwrap_or(1_000_000);
168 let limiter = ProbRateLimiter::new(max_response_bytes_per_second);
169 serve_udp(&permit, &self.sock, limiter, handler)
170 }
171
172 pub fn serve_static(self, records: &[DnsRecord]) -> Result<(), String> {
175 let mut name_to_records: HashMap<String, Vec<&DnsRecord>> = HashMap::default();
176 for record in records {
177 let key = record.name().inner().to_ascii_lowercase();
178 if let Some(v) = name_to_records.get_mut(&key) {
179 v.push(record);
180 } else {
181 name_to_records.insert(key, vec![&record]);
182 }
183 }
184 let handler = move |q: &DnsQuestion| {
185 let key = q.name.inner().to_ascii_lowercase();
186 let Some(record_refs) = name_to_records.get(&key) else {
187 return Vec::new();
188 };
189 let mut records: Vec<DnsRecord> = record_refs.iter().map(|r| (*r).clone()).collect();
190 if q.typ != DnsType::ANY {
191 records.retain(|r| r.typ() == q.typ);
192 }
193 if !records.is_empty() {
194 let range = 0..(u32::try_from(records.len()).unwrap_or(u32::MAX));
195 let k = RAND32.with_borrow_mut(|r| r.rand_range(range)) as usize;
196 records.rotate_right(k);
197 }
198 records
199 };
200 self.serve(&handler)
201 }
202}