use crate::dns::Message;
use crate::server::{RequestHandler, Server, ServerConfig};
use crate::{Error, Result};
use std::sync::Arc;
use tokio::net::UdpSocket;
use tracing::{debug, error, info, trace, warn};
/// DNS server that receives queries over a single bound UDP socket.
///
/// Each received datagram is processed on its own Tokio task, and the reply is
/// sent back through the shared socket.
pub struct UdpServer {
    /// Bound socket; shared (via `Arc`) with per-request tasks so they can reply.
    socket: Arc<UdpSocket>,
    /// Turns each parsed request into a response message.
    handler: Arc<dyn RequestHandler>,
    /// Server settings; `max_udp_size` sizes the receive buffer in `run`.
    config: ServerConfig,
}
impl UdpServer {
    /// Binds a UDP socket to `config.udp_addr` and builds a server around it.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Config`] when `config.udp_addr` is not set, and
    /// [`Error::Io`] when the socket cannot be bound.
    pub async fn new(config: ServerConfig, handler: Arc<dyn RequestHandler>) -> Result<Self> {
        let addr = config
            .udp_addr
            .ok_or_else(|| Error::Config("UDP address not configured".to_string()))?;
        let socket = UdpSocket::bind(addr).await.map_err(Error::Io)?;
        // Log the address the OS actually bound: when the configured port is 0 the
        // configured address does not reveal which ephemeral port was assigned.
        match socket.local_addr() {
            Ok(bound) => info!("UDP server listening on {}", bound),
            Err(_) => info!("UDP server listening on {}", addr),
        }
        Ok(Self {
            socket: Arc::new(socket),
            handler,
            config,
        })
    }

    /// Returns the local address the server's socket is bound to.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Io`] if the OS cannot report the socket's address.
    pub fn local_addr(&self) -> Result<std::net::SocketAddr> {
        self.socket.local_addr().map_err(Error::Io)
    }

    /// Receives datagrams in an endless loop, spawning one Tokio task per request.
    ///
    /// The receive buffer is `config.max_udp_size` bytes. Per Tokio's `recv_from`
    /// contract, bytes of a longer datagram may be discarded. Receive errors are
    /// logged and the loop continues. Per-request failures are logged at `warn`
    /// level from inside the spawned task. This method does not return on its own.
    pub async fn run(&self) -> Result<()> {
        let mut buf = vec![0u8; self.config.max_udp_size];
        info!("UDP server started");
        loop {
            let (len, peer_addr) = match self.socket.recv_from(&mut buf).await {
                Ok(received) => received,
                Err(e) => {
                    error!("Error receiving UDP packet: {}", e);
                    continue;
                }
            };
            trace!("Received {} bytes from {}", len, peer_addr);
            // Copy the datagram out so `buf` can be reused for the next receive
            // while the spawned task owns its own bytes.
            let request_data = buf[..len].to_vec();
            let handler = Arc::clone(&self.handler);
            let socket = Arc::clone(&self.socket);
            tokio::spawn(async move {
                if let Err(e) =
                    Self::handle_request(&request_data, peer_addr, handler, socket).await
                {
                    warn!("Error handling request from {}: {}", peer_addr, e);
                }
            });
        }
    }

    /// Parses one datagram, runs it through `handler`, and sends the reply to `peer_addr`.
    ///
    /// The response ID is overwritten with the request ID before sending.
    /// Datagrams that are themselves DNS responses (QR bit set) are dropped without
    /// a reply.
    ///
    /// # Errors
    ///
    /// Propagates errors from parsing, from the handler, or from serialization.
    /// Returns [`Error::Io`] if sending fails. No reply is sent in any error case.
    async fn handle_request(
        request_data: &[u8],
        peer_addr: std::net::SocketAddr,
        handler: Arc<dyn RequestHandler>,
        socket: Arc<UdpSocket>,
    ) -> Result<()> {
        let request = Self::parse_request(request_data)?;
        // Never answer a response. Replying to responses lets two servers, or a
        // spoofed source address, bounce packets back and forth indefinitely.
        if request.is_response() {
            debug!(
                "Ignoring DNS response ID {} received from {}",
                request.id(),
                peer_addr
            );
            return Ok(());
        }
        debug!(
            peer = %peer_addr,
            question = ?request.questions(),
            "Processing query ID {} with {} questions from {}",
            request.id(),
            request.question_count(),
            peer_addr
        );
        // Capture the ID before `request` is moved into the context, so the reply
        // can be matched to the query regardless of what the handler sets.
        let req_id = request.id();
        let ctx = crate::server::RequestContext::with_client(
            request,
            Some(peer_addr),
            crate::server::Protocol::Udp,
        );
        let mut response = handler.handle(ctx).await?;
        response.set_id(req_id);
        trace!(
            "Sending response ID {} with {} answers to {}",
            response.id(),
            response.answer_count(),
            peer_addr
        );
        let response_data = Self::serialize_response(&response)?;
        socket
            .send_to(&response_data, peer_addr)
            .await
            .map_err(Error::Io)?;
        Ok(())
    }

    /// Decodes a wire-format DNS message via `crate::dns::wire::parse_message`.
    fn parse_request(data: &[u8]) -> Result<Message> {
        crate::dns::wire::parse_message(data)
    }

    /// Encodes a DNS message to wire format via `crate::dns::wire::serialize_message`.
    fn serialize_response(message: &Message) -> Result<Vec<u8>> {
        crate::dns::wire::serialize_message(message)
    }
}
#[async_trait::async_trait]
impl Server for UdpServer {
async fn from_config(config: ServerConfig) -> Result<Self> {
let handler = config
.handler
.clone()
.ok_or_else(|| Error::Config("Handler not configured".to_string()))?;
Self::new(config, handler).await
}
async fn run(self) -> Result<()> {
UdpServer::run(&self).await
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dns::wire;
    use crate::dns::{Question, RecordClass, RecordType};
    use crate::server::DefaultHandler;

    /// Config that binds UDP to an OS-assigned port on IPv4 loopback.
    fn loopback_config() -> ServerConfig {
        ServerConfig::default().with_udp_addr("127.0.0.1:0".parse().unwrap())
    }

    #[tokio::test]
    async fn test_udp_server_creation() {
        let handler = Arc::new(DefaultHandler);
        UdpServer::new(loopback_config(), handler)
            .await
            .expect("binding to 127.0.0.1:0 should succeed");
    }

    #[tokio::test]
    async fn test_udp_server_local_addr() {
        let handler = Arc::new(DefaultHandler);
        let server = UdpServer::new(loopback_config(), handler).await.unwrap();
        let addr = server
            .local_addr()
            .expect("a bound socket should report its local address");
        assert_eq!(addr.ip(), std::net::Ipv4Addr::LOCALHOST);
        // Port 0 asks the OS for an ephemeral port; the bound port must be concrete.
        assert_ne!(addr.port(), 0);
    }

    #[tokio::test]
    async fn test_udp_server_creation_without_udp_addr() {
        let config = ServerConfig::new(None, None);
        let handler = Arc::new(DefaultHandler);
        match UdpServer::new(config, handler).await {
            Err(Error::Config(_)) => {}
            Err(other) => panic!("Expected Config error, got {:?}", other),
            Ok(_) => panic!("Expected Config error, got Ok"),
        }
    }

    #[test]
    fn test_parse_request_with_real_dns_message() {
        let mut req = Message::new();
        req.set_id(0x42);
        req.set_query(true);
        req.add_question(Question::new(
            "example.test".to_string(),
            RecordType::A,
            RecordClass::IN,
        ));
        let data = wire::serialize_message(&req).expect("serialize request");
        let parsed = UdpServer::parse_request(&data).expect("parse request");
        assert_eq!(parsed.id(), 0x42);
        assert_eq!(parsed.question_count(), 1);
        assert!(!parsed.is_response());
    }

    #[test]
    fn test_serialize_response_with_real_dns_message() {
        let mut resp = Message::new();
        resp.set_id(0x99);
        resp.set_response(true);
        resp.add_question(Question::new(
            "example.test".to_string(),
            RecordType::A,
            RecordClass::IN,
        ));
        let data = UdpServer::serialize_response(&resp).expect("serialize response");
        // A DNS message always carries at least the 12-byte header.
        assert!(data.len() >= 12);
        let parsed = wire::parse_message(&data).expect("parse serialized response");
        assert_eq!(parsed.id(), 0x99);
        assert!(parsed.is_response());
        assert_eq!(parsed.question_count(), 1);
    }

    #[test]
    fn test_parse_request_placeholder() {
        // An all-zero 12-byte header is a valid, empty message.
        let data = vec![0u8; 12];
        let message = UdpServer::parse_request(&data);
        assert!(message.is_ok());
    }

    #[test]
    fn test_serialize_response_placeholder() {
        let message = Message::new();
        let data = UdpServer::serialize_response(&message).expect("serialize empty message");
        // An empty message serializes to just the 12-byte header.
        assert_eq!(data.len(), 12);
    }

    #[test]
    fn test_parse_request_with_invalid_data() {
        // Shorter than the 12-byte header, so parsing must fail.
        let data = vec![0u8; 5];
        let message = UdpServer::parse_request(&data);
        assert!(message.is_err());
    }

    #[test]
    fn test_serialize_response_with_complex_message() {
        let mut resp = Message::new();
        resp.set_id(0x1234);
        resp.set_response(true);
        resp.set_recursion_available(true);
        resp.add_question(Question::new(
            "complex.example.test".to_string(),
            RecordType::AAAA,
            RecordClass::IN,
        ));
        let data = UdpServer::serialize_response(&resp).expect("serialize complex response");
        assert!(data.len() > 12);
        let parsed = wire::parse_message(&data).expect("parse complex response");
        assert_eq!(parsed.id(), 0x1234);
        assert!(parsed.is_response());
        assert!(parsed.recursion_available());
    }
}