use crate::dns::Message;
use crate::server::{RequestHandler, Server, ServerConfig};
use crate::{Error, Result};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tracing::{debug, error, info, trace};
/// DNS-over-TCP server: a bound listener, the handler that answers queries, and its config.
pub struct TcpServer {
    /// Listener bound in [`TcpServer::new`]; `run` accepts connections from it.
    listener: TcpListener,
    /// Request handler; `run` clones this `Arc` into every spawned connection task.
    handler: Arc<dyn RequestHandler>,
    /// Server configuration; `run` reads `max_tcp_size` from it to cap incoming message length.
    config: ServerConfig,
}
impl TcpServer {
pub async fn new(config: ServerConfig, handler: Arc<dyn RequestHandler>) -> Result<Self> {
let addr = config
.tcp_addr
.ok_or_else(|| Error::Config("TCP address not configured".to_string()))?;
let listener = TcpListener::bind(addr).await.map_err(Error::Io)?;
info!("TCP server listening on {}", addr);
Ok(Self {
listener,
handler,
config,
})
}
pub async fn run(&self) -> Result<()> {
info!("TCP server started");
loop {
match self.listener.accept().await {
Ok((stream, peer_addr)) => {
debug!("Accepted connection from {}", peer_addr);
let handler = Arc::clone(&self.handler);
let max_size = self.config.max_tcp_size;
tokio::spawn(async move {
if let Err(e) =
Self::handle_connection(stream, peer_addr, handler, max_size).await
{
error!("Error handling connection from {}: {}", peer_addr, e);
}
});
}
Err(e) => {
error!("Error accepting TCP connection: {}", e);
}
}
}
}
async fn handle_connection(
mut stream: TcpStream,
peer_addr: std::net::SocketAddr,
handler: Arc<dyn RequestHandler>,
max_size: usize,
) -> Result<()> {
let mut len_buf = [0u8; 2];
stream.read_exact(&mut len_buf).await.map_err(Error::Io)?;
let msg_len = u16::from_be_bytes(len_buf) as usize;
if msg_len > max_size {
return Err(Error::DnsProtocol(format!(
"Message too large: {} > {}",
msg_len, max_size
)));
}
trace!("Reading {} bytes", msg_len);
let mut buf = vec![0u8; msg_len];
stream.read_exact(&mut buf).await.map_err(Error::Io)?;
let request = Self::parse_request(&buf)?;
debug!(
peer = %peer_addr,
question = ?request.questions(),
"Processing query ID {} with {} questions",
request.id(),
request.question_count()
);
let ctx = crate::server::RequestContext::with_client(
request,
Some(peer_addr),
crate::server::Protocol::Tcp,
);
let response = handler.handle(ctx).await?;
trace!(
"Sending response ID {} with {} answers",
response.id(),
response.answer_count()
);
let response_data = Self::serialize_response(&response)?;
let len = response_data.len() as u16;
stream
.write_all(&len.to_be_bytes())
.await
.map_err(Error::Io)?;
stream.write_all(&response_data).await.map_err(Error::Io)?;
stream.flush().await.map_err(Error::Io)?;
Ok(())
}
fn parse_request(data: &[u8]) -> Result<Message> {
crate::dns::wire::parse_message(data)
}
fn serialize_response(message: &Message) -> Result<Vec<u8>> {
crate::dns::wire::serialize_message(message)
}
}
#[async_trait::async_trait]
impl Server for TcpServer {
async fn from_config(config: ServerConfig) -> Result<Self> {
let handler = config
.handler
.clone()
.ok_or_else(|| Error::Config("Handler not configured".to_string()))?;
Self::new(config, handler).await
}
async fn run(self) -> Result<()> {
TcpServer::run(&self).await
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dns::wire;
    use crate::dns::{Question, RecordClass, RecordType};
    use crate::server::DefaultHandler;
    use tokio::net::{TcpListener, TcpStream};

    #[tokio::test]
    async fn test_tcp_server_creation() {
        let config = ServerConfig::default().with_tcp_addr("127.0.0.1:0".parse().unwrap());
        let handler = Arc::new(DefaultHandler);
        let server = TcpServer::new(config, handler).await;
        assert!(server.is_ok());
    }

    // Purely synchronous: no async runtime is needed.
    #[test]
    fn test_parse_request_and_serialize_response() {
        let mut req = Message::new();
        req.set_id(0x42);
        req.set_query(true);
        req.add_question(Question::new(
            "example.test".to_string(),
            RecordType::A,
            RecordClass::IN,
        ));
        let data = wire::serialize_message(&req).expect("serialize request");
        let parsed = TcpServer::parse_request(&data).expect("parse request");
        assert_eq!(parsed.id(), 0x42);
        assert_eq!(parsed.question_count(), 1);
        let mut resp = parsed.clone();
        resp.set_response(true);
        let resp_data = TcpServer::serialize_response(&resp).expect("serialize response");
        // A DNS header alone is 12 bytes.
        assert!(resp_data.len() >= 12);
    }

    #[tokio::test]
    async fn test_handle_connection_roundtrip() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        // The task reports whether the server side succeeded, so a handler or framing error
        // fails the test instead of being discarded.
        let server_task = tokio::spawn(async move {
            let (stream, _peer) = listener.accept().await.expect("accept connection");
            let handler = Arc::new(DefaultHandler);
            TcpServer::handle_connection(
                stream,
                "127.0.0.1:12345".parse().unwrap(),
                handler,
                64 * 1024,
            )
            .await
            .is_ok()
        });
        let mut client = TcpStream::connect(addr).await.unwrap();
        let mut req = Message::new();
        req.set_id(0x99);
        req.set_query(true);
        req.add_question(Question::new(
            "roundtrip.test".to_string(),
            RecordType::AAAA,
            RecordClass::IN,
        ));
        let req_data = wire::serialize_message(&req).expect("serialize client request");
        // A single-question query is far below u16::MAX bytes, so this cast cannot truncate.
        let len = req_data.len() as u16;
        client
            .write_all(&len.to_be_bytes())
            .await
            .expect("write length prefix");
        client.write_all(&req_data).await.expect("write request");
        let mut len_buf = [0u8; 2];
        client.read_exact(&mut len_buf).await.unwrap();
        let resp_len = u16::from_be_bytes(len_buf) as usize;
        let mut resp_buf = vec![0u8; resp_len];
        client.read_exact(&mut resp_buf).await.unwrap();
        let response = wire::parse_message(&resp_buf).expect("parse response");
        assert!(response.is_response());
        assert_eq!(response.id(), 0x99);
        assert!(
            server_task.await.expect("server task panicked"),
            "handle_connection returned an error"
        );
    }

    #[tokio::test]
    async fn test_handle_connection_rejects_oversized_message() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server_task = tokio::spawn(async move {
            let (stream, peer) = listener.accept().await.expect("accept connection");
            let handler = Arc::new(DefaultHandler);
            TcpServer::handle_connection(stream, peer, handler, 16)
                .await
                .is_err()
        });
        let mut client = TcpStream::connect(addr).await.unwrap();
        // Announce 512 bytes to a server capped at 16. It must reject the message from the
        // length prefix alone, without waiting for a body that never arrives.
        client
            .write_all(&512u16.to_be_bytes())
            .await
            .expect("write length prefix");
        assert!(
            server_task.await.expect("server task panicked"),
            "oversized message was not rejected"
        );
    }
}