use crate::error::{Error, Result};
use crate::server::{RequestHandler, Server, ServerConfig, TlsConfig};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::TlsAcceptor;
use tracing::{debug, error, info, trace, warn};
/// DNS-over-TLS (DoT) server.
///
/// Accepts TCP connections on `addr`, terminates TLS with an acceptor built from
/// `tls_config`, and hands each length-prefixed DNS query to `handler`.
pub struct DotServer {
    /// Address passed to `TcpListener::bind` when the server is run.
    addr: String,
    /// TLS settings from which the server-side TLS configuration is built at startup.
    tls_config: TlsConfig,
    /// Request handler shared (via `Arc`) by every per-connection task.
    handler: Arc<dyn RequestHandler>,
}
impl DotServer {
pub fn new(
addr: impl Into<String>,
tls_config: TlsConfig,
handler: Arc<dyn RequestHandler>,
) -> Self {
Self {
addr: addr.into(),
tls_config,
handler,
}
}
pub async fn run(self) -> Result<()> {
let listener = TcpListener::bind(&self.addr).await.map_err(Error::Io)?;
info!("DoT server listening on {}", self.addr);
let tls_config = self.tls_config.build_server_config()?;
let acceptor = TlsAcceptor::from(tls_config);
loop {
let (stream, peer_addr) = match listener.accept().await {
Ok(conn) => conn,
Err(e) => {
error!("Failed to accept connection: {}", e);
continue;
}
};
debug!("DoT connection from {}", peer_addr);
let acceptor = acceptor.clone();
let handler = Arc::clone(&self.handler);
tokio::spawn(async move {
if let Err(e) = Self::handle_connection(stream, acceptor, handler).await {
error!("Error handling DoT connection from {}: {}", peer_addr, e);
}
});
}
}
async fn handle_connection(
stream: TcpStream,
acceptor: TlsAcceptor,
handler: Arc<dyn RequestHandler>,
) -> Result<()> {
let peer_addr = stream.peer_addr().ok();
let mut tls_stream = acceptor
.accept(stream)
.await
.map_err(|e| Error::Other(format!("TLS handshake failed: {}", e)))?;
debug!(peer = ?peer_addr, "TLS handshake succeeded for DoT connection");
loop {
let mut len_buf = [0u8; 2];
match tls_stream.read_exact(&mut len_buf).await {
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
debug!("DoT client closed connection");
break;
}
Err(e) => {
return Err(Error::Io(e));
}
}
let msg_len = u16::from_be_bytes(len_buf) as usize;
if msg_len == 0 || msg_len > 65535 {
warn!("Invalid DoT message length: {}", msg_len);
break;
}
let mut buf = vec![0u8; msg_len];
trace!(peer = ?peer_addr, len = msg_len, "Reading DoT message");
tls_stream.read_exact(&mut buf).await.map_err(Error::Io)?;
let request = Self::parse_request(&buf)?;
debug!(
peer = ?peer_addr,
question = ?request.questions(),
"Processing DoT query ID {} with {} questions",
request.id(),
request.question_count()
);
let ctx = crate::server::RequestContext::new(request, crate::server::Protocol::DoT);
let response = handler.handle(ctx).await?;
let response_data = Self::serialize_response(&response)?;
trace!(peer = ?peer_addr, id = response.id(), answers = response.answer_count(), "Sending DoT response");
let response_len = response_data.len() as u16;
tls_stream
.write_all(&response_len.to_be_bytes())
.await
.map_err(Error::Io)?;
tls_stream
.write_all(&response_data)
.await
.map_err(Error::Io)?;
tls_stream.flush().await.map_err(Error::Io)?;
}
Ok(())
}
fn parse_request(data: &[u8]) -> Result<crate::dns::Message> {
crate::dns::wire::parse_message(data)
}
fn serialize_response(message: &crate::dns::Message) -> Result<Vec<u8>> {
crate::dns::wire::serialize_message(message)
}
}
#[async_trait::async_trait]
impl Server for DotServer {
async fn from_config(config: ServerConfig) -> Result<Self> {
let addr = config
.tcp_addr
.ok_or_else(|| Error::Config("TCP address not configured for DoT".to_string()))?
.to_string();
let tls_config = config
.tls_config
.ok_or_else(|| Error::Config("TLS config not configured for DoT".to_string()))?;
let handler = config
.handler
.ok_or_else(|| Error::Config("Handler not configured".to_string()))?;
Ok(Self::new(addr, tls_config, handler))
}
async fn run(self) -> Result<()> {
DotServer::run(self).await
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Twelve zero bytes form a DNS header whose section counts are all zero,
    /// so parsing should succeed.
    #[test]
    fn test_parse_request() {
        let data = [0u8; 12];
        let result = DotServer::parse_request(&data);
        assert!(result.is_ok());
    }

    /// An empty message serializes to just its 12-byte header.
    #[test]
    fn test_serialize_response() {
        let message = crate::dns::Message::new();
        let result = DotServer::serialize_response(&message);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 12);
    }

    /// An empty buffer cannot hold a DNS header and must be rejected.
    /// This check is fully synchronous, so it needs no async runtime.
    #[test]
    fn test_parse_request_invalid() {
        let data: [u8; 0] = [];
        let result = DotServer::parse_request(&data);
        assert!(result.is_err());
    }

    /// `run` must surface an unbindable address as `Error::Io` instead of serving.
    #[tokio::test]
    async fn test_run_invalid_bind_address() {
        use rcgen::generate_simple_self_signed;
        use std::io::Write;
        use tempfile::NamedTempFile;

        // Write a throwaway self-signed certificate and key to temp files so that a
        // valid `TlsConfig` can be constructed. The files live until the end of the test.
        let cert = generate_simple_self_signed(vec!["localhost".into()]).unwrap();
        let cert_pem = cert.cert.pem();
        let key_pem = cert.signing_key.serialize_pem();

        let mut cert_file = NamedTempFile::new().unwrap();
        cert_file.write_all(cert_pem.as_bytes()).unwrap();
        let cert_path = cert_file.path().to_path_buf();

        let mut key_file = NamedTempFile::new().unwrap();
        key_file.write_all(key_pem.as_bytes()).unwrap();
        let key_path = key_file.path().to_path_buf();

        let tls = crate::server::TlsConfig::from_files(cert_path, key_path).unwrap();

        /// Echoes the request back; never reached because binding fails first.
        struct DummyHandler;

        #[async_trait::async_trait]
        impl crate::server::RequestHandler for DummyHandler {
            async fn handle(
                &self,
                ctx: crate::server::RequestContext,
            ) -> crate::Result<crate::dns::Message> {
                let req = ctx.into_message();
                Ok(req)
            }
        }

        let server = DotServer::new("not-a-valid-addr", tls, Arc::new(DummyHandler));
        let res = server.run().await;
        assert!(res.is_err());
        match res.unwrap_err() {
            Error::Io(_) => {}
            other => panic!("expected Io error, got: {:?}", other),
        }
    }
}