use crate::{DnsError, DnsMessage, DnsName, DnsOpCode, DnsQuestion, DnsRecord, DnsType};
use fixed_buffer::FixedBuf;
use multimap::MultiMap;
use oorandom::Rand32;
use permit::Permit;
use prob_rate_limiter::ProbRateLimiter;
use std::cell::RefCell;
use std::convert::TryFrom;
use std::io::ErrorKind;
use std::net::{IpAddr, Ipv6Addr, SocketAddr, UdpSocket};
use std::time::{Duration, Instant};
thread_local! {
    /// Per-thread pseudo-random generator used to vary the order of static answers.
    ///
    /// It is seeded with the constant 0, so every thread starts from the same state.
    static RAND32: RefCell<Rand32> = RefCell::new(Rand32::new(0));
}
/// Validates `request` as a standard query and answers its first question using `handler`.
///
/// Only the first question is passed to `handler`; any further questions are ignored.
///
/// # Errors
/// - [`DnsError::NotARequest`] when the message is a response.
/// - [`DnsError::InvalidOpCode`] when its opcode is not `Query`.
/// - [`DnsError::NoQuestion`] when it carries no questions.
/// - Any error produced by `DnsMessage::answer_response`.
pub fn process_request(
    request: &DnsMessage,
    handler: &impl Fn(&DnsQuestion) -> Vec<DnsRecord>,
) -> Result<DnsMessage, DnsError> {
    let header = &request.header;
    if header.is_response {
        Err(DnsError::NotARequest)
    } else if header.op_code != DnsOpCode::Query {
        Err(DnsError::InvalidOpCode)
    } else {
        let first_question = request.questions.first().ok_or(DnsError::NoQuestion)?;
        request.answer_response(handler(first_question))
    }
}
/// Parses a raw request datagram, answers it with `handler`, and serializes the response.
///
/// # Errors
/// Returns an error when `bytes` does not parse as a DNS message, when the request is rejected
/// by [`process_request`], or when writing the response into the 512-byte output buffer fails.
pub fn process_datagram(
    bytes: &mut FixedBuf<512>,
    handler: &impl Fn(&DnsQuestion) -> Vec<DnsRecord>,
) -> Result<FixedBuf<512>, DnsError> {
    let request = DnsMessage::read(bytes)?;
    // `handler` is already a reference; pass it through instead of borrowing it again.
    let response = process_request(&request, handler)?;
    let mut out: FixedBuf<512> = FixedBuf::new();
    response.write(&mut out)?;
    Ok(out)
}
/// Serves DNS requests arriving on `sock` until `permit` is revoked.
///
/// Each request is processed with `handler`, and the response is sent back to the request's
/// source address. A request is dropped, with a line printed to stdout, in three cases:
/// - it is longer than 512 bytes;
/// - `response_bytes_rate_limiter` refuses it;
/// - it cannot be processed.
///
/// The socket's read timeout is set to 500 ms, so permit revocation is noticed even when no
/// requests arrive.
///
/// # Errors
/// Returns an error in these cases:
/// - the socket cannot be configured or its local address cannot be read;
/// - receiving fails with anything other than a timeout, an interruption, or a connection reset;
/// - a response cannot be sent in full.
// The `u32` conversion cannot fail for a 512-byte buffer, and an empty response is treated as
// impossible, so the panics below are not expected to fire.
#[allow(clippy::missing_panics_doc)]
pub fn serve_udp(
    permit: &Permit,
    sock: &UdpSocket,
    mut response_bytes_rate_limiter: ProbRateLimiter,
    handler: &impl Fn(&DnsQuestion) -> Vec<DnsRecord>,
) -> Result<(), String> {
    // RFC 1035 section 4.2.1: DNS messages carried over UDP are limited to 512 bytes.
    const MAX_MESSAGE_LEN: usize = 512;
    sock.set_read_timeout(Some(Duration::from_millis(500)))
        .map_err(|e| format!("error setting socket read timeout: {e}"))?;
    let local_addr = sock
        .local_addr()
        .map_err(|e| format!("error getting socket local address: {e}"))?;
    while !permit.is_revoked() {
        // When a datagram is too long for the buffer, `recv_from` may discard the excess bytes,
        // so it never reports more than the buffer's length. Receiving into a buffer one byte
        // larger than the limit is how an over-long request is detected, instead of being
        // silently truncated and parsed.
        let mut scratch = [0_u8; MAX_MESSAGE_LEN + 1];
        let (len, addr) = match sock.recv_from(&mut scratch) {
            Ok(received) => received,
            // Timeouts let the loop re-check the permit, and `Interrupted` is safe to retry.
            // On Windows, an ICMP "port unreachable" triggered by an earlier response shows up
            // as `ConnectionReset` on a later receive, and it must not stop the server.
            Err(e)
                if matches!(
                    e.kind(),
                    ErrorKind::WouldBlock
                        | ErrorKind::TimedOut
                        | ErrorKind::Interrupted
                        | ErrorKind::ConnectionReset
                ) =>
            {
                continue
            }
            Err(e) => {
                return Err(format!(
                    "error reading DNS server UDP socket {local_addr:?}: {e}"
                ))
            }
        };
        if len > MAX_MESSAGE_LEN {
            println!("dropping over-long request");
            continue;
        }
        let mut buf: FixedBuf<MAX_MESSAGE_LEN> = FixedBuf::new();
        buf.writable()[..len].copy_from_slice(&scratch[..len]);
        buf.wrote(len);
        if !response_bytes_rate_limiter.attempt(Instant::now()) {
            println!("dropping request");
            continue;
        }
        let out = match process_datagram(&mut buf, handler) {
            Ok(out) => out,
            Err(e) => {
                println!("dropping bad request: {e:?}");
                continue;
            }
        };
        if out.is_empty() {
            unreachable!();
        }
        // Cannot fail: `out` holds at most 512 bytes.
        response_bytes_rate_limiter.record(u32::try_from(out.len()).unwrap());
        let sent_len = sock
            .send_to(out.readable(), addr)
            .map_err(|e| format!("error sending response to {addr:?}: {e}"))?;
        if sent_len != out.len() {
            return Err(format!(
                "sent only {sent_len} bytes of {} byte response to {addr:?}",
                out.len()
            ));
        }
    }
    Ok(())
}
/// Configures and runs a UDP DNS server.
///
/// Create one with [`Builder::new`], [`Builder::new_port`], or [`Builder::new_random_port`].
/// Optionally adjust it with the `with_*` methods, then call [`Builder::serve`] or
/// [`Builder::serve_static`].
pub struct Builder {
    /// Bound socket on which requests are received and responses are sent.
    sock: UdpSocket,
    /// Serving stops once this permit is revoked; `None` means a default permit is used.
    permit: Option<Permit>,
    /// Limit on response bytes per second; `None` means the default of 1,000,000.
    max_response_bytes_per_second: Option<u32>,
}
impl Builder {
    /// Response byte budget per second used by [`Builder::serve`] when
    /// [`Builder::with_max_response_bytes_per_second`] was not called.
    const DEFAULT_MAX_RESPONSE_BYTES_PER_SECOND: u32 = 1_000_000;

    /// Creates a builder that serves on the already-bound `sock`.
    ///
    /// The builder starts with no permit and the default response rate limit.
    #[must_use]
    pub fn new(sock: UdpSocket) -> Self {
        Self {
            permit: None,
            sock,
            max_response_bytes_per_second: None,
        }
    }

    /// Creates a builder whose socket is bound to `port` on the IPv6 unspecified address `[::]`.
    ///
    /// # Errors
    /// Returns an error when the socket cannot be bound.
    pub fn new_port(port: u16) -> Result<Self, String> {
        let sock = UdpSocket::bind(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port))
            .map_err(|e| format!("error binding to UDP port {port}: {e}"))?;
        Ok(Self::new(sock))
    }

    /// Creates a builder whose socket is bound to an OS-chosen port on `[::]`.
    ///
    /// Also returns the socket's local address, which reveals the chosen port.
    ///
    /// # Errors
    /// Returns an error when the socket cannot be bound or its local address cannot be read.
    pub fn new_random_port() -> Result<(Self, SocketAddr), String> {
        let sock = UdpSocket::bind(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0))
            .map_err(|e| format!("error binding to random UDP port: {e}"))?;
        let addr = sock
            .local_addr()
            .map_err(|e| format!("error getting socket local address: {e}"))?;
        Ok((Self::new(sock), addr))
    }

    /// Sets the permit; the server stops once it is revoked.
    #[must_use]
    pub fn with_permit(mut self, permit: Permit) -> Self {
        self.permit = Some(permit);
        self
    }

    /// Sets the maximum number of response bytes sent per second.
    #[must_use]
    pub fn with_max_response_bytes_per_second(mut self, n: u32) -> Self {
        self.max_response_bytes_per_second = Some(n);
        self
    }

    /// Serves requests on the socket, answering each with `handler`, until the permit is revoked.
    ///
    /// If [`Builder::with_permit`] was not called, `Permit::default()` is used.
    /// NOTE(review): that default permit is presumably never revoked, in which case serving
    /// continues until an error occurs. Confirm this.
    ///
    /// # Errors
    /// Returns any error reported by [`serve_udp`].
    pub fn serve(self, handler: &impl Fn(&DnsQuestion) -> Vec<DnsRecord>) -> Result<(), String> {
        let permit = self.permit.unwrap_or_default();
        let limiter = ProbRateLimiter::new(
            self.max_response_bytes_per_second
                .unwrap_or(Self::DEFAULT_MAX_RESPONSE_BYTES_PER_SECOND),
        );
        serve_udp(&permit, &self.sock, limiter, handler)
    }

    /// Serves a fixed set of `records`.
    ///
    /// A question is answered with the records whose name equals the question's name and whose
    /// type equals the question's type. When the question's type is `ANY`, all records with
    /// that name are returned. The answer list is rotated by a random amount, so the record
    /// that comes first varies between responses.
    ///
    /// # Errors
    /// Returns any error reported by [`Builder::serve`].
    pub fn serve_static(self, records: &[DnsRecord]) -> Result<(), String> {
        let name_to_records: MultiMap<&DnsName, &DnsRecord> =
            records.iter().map(|x| (x.name(), x)).collect();
        let handler = move |q: &DnsQuestion| {
            let Some(record_refs) = name_to_records.get_vec(&q.name) else {
                return Vec::new();
            };
            // Filter before cloning, so records of other types are never copied.
            let mut records: Vec<DnsRecord> = record_refs
                .iter()
                .copied()
                .filter(|r| q.typ == DnsType::ANY || r.typ() == q.typ)
                .cloned()
                .collect();
            if !records.is_empty() {
                let range = 0..(u32::try_from(records.len()).unwrap_or(u32::MAX));
                let k = RAND32.with_borrow_mut(|r| r.rand_range(range)) as usize;
                records.rotate_right(k);
            }
            records
        };
        self.serve(&handler)
    }
}