use crate::dns::Message;
use crate::error::{Error, Result};
use crate::server::{RequestHandler, Server, ServerConfig, TlsConfig};
use axum::{
Router,
body::Bytes,
extract::{Query as AxumQuery, State},
http::{HeaderMap, StatusCode, header},
response::{IntoResponse, Response},
routing::post,
};
#[cfg(feature = "doh")]
use axum_server::bind_rustls as axum_bind_rustls;
#[cfg(feature = "doh")]
use axum_server::tls_rustls::RustlsConfig as AxumRustlsConfig;
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{debug, info, trace};
/// DNS-over-HTTPS (DoH) server that exposes a [`RequestHandler`] over HTTP.
pub struct DohServer {
    /// Handler that answers every decoded DNS query; it is handed to the
    /// HTTP routes as shared axum state.
    handler: Arc<dyn RequestHandler>,
    /// Address the HTTP listener binds to.
    addr: String,
    /// HTTP path on which DNS queries are accepted.
    path: String,
    /// TLS settings. Only read when the `doh` feature is enabled, hence the
    /// leading underscore to silence the unused-field lint otherwise.
    _tls_config: TlsConfig,
}
impl DohServer {
    /// Creates a DoH server bound to `addr` that dispatches every decoded
    /// query to `handler`.
    ///
    /// `tls_config` is used to terminate TLS when the `doh` feature is
    /// enabled. Queries are served on `/dns-query` unless overridden with
    /// [`DohServer::with_path`].
    pub fn new(
        addr: impl Into<String>,
        tls_config: TlsConfig,
        handler: Arc<dyn RequestHandler>,
    ) -> Self {
        Self {
            addr: addr.into(),
            _tls_config: tls_config,
            handler,
            path: "/dns-query".to_string(),
        }
    }

    /// Overrides the HTTP path on which DNS queries are accepted.
    ///
    /// A missing leading `/` is added, and an empty path becomes `/`. axum's
    /// router panics when a route path does not start with `/`, so an
    /// unnormalized path taken from configuration would otherwise crash
    /// [`DohServer::run`].
    pub fn with_path(mut self, path: impl Into<String>) -> Self {
        let path = path.into();
        self.path = if path.starts_with('/') {
            path
        } else {
            format!("/{path}")
        };
        self
    }

    /// Serves DoH until the underlying HTTP server stops.
    ///
    /// GET (`?dns=<base64url>`) and POST (`application/dns-message` body)
    /// requests on the configured path are routed to the request handler.
    /// - With the `doh` feature, the listener terminates TLS using the
    ///   configured [`TlsConfig`].
    /// - Without it, the server speaks plain HTTP and logs a warning.
    ///
    /// # Errors
    ///
    /// - [`Error::Config`] if the bind address is not a valid socket address
    ///   (TLS builds).
    /// - Any error produced while building the TLS configuration (TLS
    ///   builds).
    /// - [`Error::Io`] if the listener cannot be bound (plain-HTTP builds).
    /// - [`Error::Other`] if the HTTP server exits with an error.
    pub async fn run(self) -> Result<()> {
        let app = Router::new()
            .route(&self.path, post(handle_post_query).get(handle_get_query))
            .with_state(Arc::clone(&self.handler));

        #[cfg(feature = "doh")]
        {
            // Validate the address before doing the comparatively expensive
            // TLS setup.
            let bind_addr: std::net::SocketAddr = self
                .addr
                .parse()
                .map_err(|e| Error::Config(format!("Invalid bind address: {}", e)))?;
            let tls_config = self._tls_config.build_server_config()?;
            let axum_tls = AxumRustlsConfig::from_config(tls_config);
            info!(
                "DoH server listening with TLS on {} (path: {})",
                bind_addr, self.path
            );
            axum_bind_rustls(bind_addr, axum_tls)
                .serve(app.into_make_service())
                .await
                .map_err(|e| Error::Other(format!("Server error: {}", e)))?;
        }

        #[cfg(not(feature = "doh"))]
        {
            tracing::warn!(
                "DoH server running without TLS; enable the `doh` feature for production TLS support"
            );
            let listener = tokio::net::TcpListener::bind(&self.addr)
                .await
                .map_err(Error::Io)?;
            info!(
                "DoH server listening on {} (path: {})",
                self.addr, self.path
            );
            axum::serve(listener, app)
                .await
                .map_err(|e| Error::Other(format!("Server error: {}", e)))?;
        }

        Ok(())
    }
}
#[async_trait::async_trait]
impl Server for DohServer {
async fn from_config(config: ServerConfig) -> crate::Result<Self> {
let addr = config
.tcp_addr
.ok_or_else(|| Error::Config("TCP address not configured for DoH".to_string()))?
.to_string();
let tls_config = config
.tls_config
.ok_or_else(|| Error::Config("TLS config not configured for DoH".to_string()))?;
let handler = config
.handler
.ok_or_else(|| Error::Config("Handler not configured".to_string()))?;
let mut server = Self::new(addr, tls_config, handler);
if let Some(path) = config.doh_path {
server = server.with_path(path);
}
Ok(server)
}
async fn run(self) -> crate::Result<()> {
DohServer::run(self).await
}
}
async fn handle_get_query(
State(handler): State<Arc<dyn RequestHandler>>,
AxumQuery(params): AxumQuery<HashMap<String, String>>,
_headers: HeaderMap,
) -> Response {
debug!("Handling DoH GET request");
let dns_param = match params.get("dns") {
Some(param) => param,
None => {
return (
StatusCode::BAD_REQUEST,
"Missing 'dns' query parameter. Usage: /dns-query?dns=<base64url-encoded-query>",
)
.into_response();
}
};
trace!(dns_param, "DoH GET query parameters");
let dns_data = match URL_SAFE_NO_PAD.decode(dns_param.as_bytes()) {
Ok(data) => data,
Err(e) => {
return (
StatusCode::BAD_REQUEST,
format!("Invalid base64url encoding: {}", e),
)
.into_response();
}
};
trace!("Decoded DNS query: {} bytes", dns_data.len());
let request = match parse_dns_message(&dns_data) {
Ok(msg) => msg,
Err(e) => {
return (
StatusCode::BAD_REQUEST,
format!("Invalid DNS message: {}", e),
)
.into_response();
}
};
debug!(
question = ?request.questions(),
"Processing query ID {} with {} questions",
request.id(),
request.question_count()
);
let ctx = crate::server::RequestContext::new(request, crate::server::Protocol::DoH);
let response = match handler.handle(ctx).await {
Ok(resp) => resp,
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Query processing failed: {}", e),
)
.into_response();
}
};
let response_data = match serialize_dns_message(&response) {
Ok(data) => data,
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Response serialization failed: {}", e),
)
.into_response();
}
};
debug!("DoH GET handler processed query successfully");
trace!("Sending DoH response: {} bytes", response_data.len());
(
StatusCode::OK,
[(header::CONTENT_TYPE, "application/dns-message")],
response_data,
)
.into_response()
}
async fn handle_post_query(
State(handler): State<Arc<dyn RequestHandler>>,
headers: HeaderMap,
body: Bytes,
) -> Response {
debug!("Handling DoH POST request");
if let Some(content_type) = headers.get(header::CONTENT_TYPE) {
if content_type != "application/dns-message" {
return (
StatusCode::UNSUPPORTED_MEDIA_TYPE,
"Content-Type must be application/dns-message",
)
.into_response();
}
} else {
return (StatusCode::BAD_REQUEST, "Content-Type header required").into_response();
}
trace!(content_length = body.len(), "DoH POST body length");
let request = match parse_dns_message(&body) {
Ok(msg) => {
trace!(bytes = body.len(), "Parsed DNS POST query");
msg
}
Err(e) => {
return (
StatusCode::BAD_REQUEST,
format!("Invalid DNS message: {}", e),
)
.into_response();
}
};
debug!(
question = ?request.questions(),
"Processing query ID {} with {} questions",
request.id(),
request.question_count()
);
let ctx = crate::server::RequestContext::new(request, crate::server::Protocol::DoH);
let response = match handler.handle(ctx).await {
Ok(resp) => {
debug!("DoH POST handler processed query successfully");
resp
}
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Query processing failed: {}", e),
)
.into_response();
}
};
let response_data = match serialize_dns_message(&response) {
Ok(data) => {
trace!(bytes = data.len(), "Serialized DoH POST response");
data
}
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Response serialization failed: {}", e),
)
.into_response();
}
};
(
StatusCode::OK,
[(header::CONTENT_TYPE, "application/dns-message")],
response_data,
)
.into_response()
}
/// Decodes a DNS message from its wire format.
///
/// Thin wrapper over `crate::dns::wire::parse_message`. The decoder's result,
/// including any error, is returned unchanged.
fn parse_dns_message(data: &[u8]) -> Result<Message> {
    use crate::dns::wire::parse_message as decode_wire;
    decode_wire(data)
}
/// Encodes `message` into DNS wire format.
///
/// Thin wrapper over `crate::dns::wire::serialize_message`. The encoder's
/// result, including any error, is returned unchanged.
fn serialize_dns_message(message: &Message) -> Result<Vec<u8>> {
    use crate::dns::wire::serialize_message as encode_wire;
    encode_wire(message)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server::{RequestContext, RequestHandler};
    use async_trait::async_trait;
    use axum::body::Bytes as AxumBytes;
    use axum::body::to_bytes;
    use axum::http::header::CONTENT_TYPE;
    use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
    use std::collections::HashMap;

    /// Maximum response body size buffered by these tests.
    const BODY_LIMIT: usize = 64 * 1024;

    /// Handler that returns the query itself, marked via `set_response(true)`.
    struct TestHandler;

    #[async_trait]
    impl RequestHandler for TestHandler {
        async fn handle(&self, ctx: RequestContext) -> crate::Result<Message> {
            let mut request = ctx.into_message();
            request.set_response(true);
            Ok(request)
        }
    }

    /// Handler that always fails, used to exercise the 500 paths.
    struct TestHandlerErr;

    #[async_trait]
    impl RequestHandler for TestHandlerErr {
        async fn handle(&self, _ctx: RequestContext) -> crate::Result<Message> {
            Err(crate::Error::Plugin("handler failure".to_string()))
        }
    }

    /// Serializes a query message with the given ID into wire format.
    fn query_wire(id: u16) -> Vec<u8> {
        let mut req = Message::new();
        req.set_id(id);
        req.set_query(true);
        crate::dns::wire::serialize_message(&req).unwrap()
    }

    /// Builds GET query parameters carrying `wire` as unpadded base64url.
    fn get_params(wire: &[u8]) -> HashMap<String, String> {
        HashMap::from([("dns".to_string(), URL_SAFE_NO_PAD.encode(wire))])
    }

    /// Request headers declaring an `application/dns-message` body.
    fn dns_message_headers() -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(CONTENT_TYPE, "application/dns-message".parse().unwrap());
        headers
    }

    #[test]
    fn test_parse_dns_message_placeholder() {
        assert!(parse_dns_message(&[0u8; 12]).is_ok());
    }

    #[test]
    fn test_serialize_dns_message_placeholder() {
        let wire = serialize_dns_message(&Message::new()).unwrap();
        assert_eq!(wire.len(), 12);
    }

    #[test]
    fn test_base64url_encoding_decoding() {
        let original_data = [
            0x00u8, 0x01, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ];
        let encoded = URL_SAFE_NO_PAD.encode(original_data);
        let decoded = URL_SAFE_NO_PAD.decode(encoded.as_bytes()).unwrap();
        assert_eq!(original_data.to_vec(), decoded);
    }

    #[tokio::test]
    async fn test_handle_get_query_success() {
        let resp = handle_get_query(
            State(Arc::new(TestHandler)),
            AxumQuery(get_params(&query_wire(0x1234))),
            HeaderMap::new(),
        )
        .await;
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get(CONTENT_TYPE).unwrap(),
            "application/dns-message"
        );
        let body = to_bytes(resp.into_body(), BODY_LIMIT).await.unwrap();
        let parsed = crate::dns::wire::parse_message(&body).unwrap();
        assert!(parsed.is_response());
        assert_eq!(parsed.id(), 0x1234);
    }

    #[tokio::test]
    async fn test_handle_post_query_success() {
        let resp = handle_post_query(
            State(Arc::new(TestHandler)),
            dns_message_headers(),
            AxumBytes::from(query_wire(0x9a)),
        )
        .await;
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get(CONTENT_TYPE).unwrap(),
            "application/dns-message"
        );
        let body = to_bytes(resp.into_body(), BODY_LIMIT).await.unwrap();
        let parsed = crate::dns::wire::parse_message(&body).unwrap();
        assert!(parsed.is_response());
        assert_eq!(parsed.id(), 0x9a);
    }

    #[tokio::test]
    async fn test_handle_post_query_missing_content_type() {
        let resp = handle_post_query(
            State(Arc::new(TestHandler)),
            HeaderMap::new(),
            AxumBytes::from(query_wire(0x55)),
        )
        .await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_handle_post_query_unsupported_media_type() {
        let mut headers = HeaderMap::new();
        headers.insert(CONTENT_TYPE, "text/plain".parse().unwrap());
        let resp = handle_post_query(
            State(Arc::new(TestHandler)),
            headers,
            AxumBytes::from(query_wire(0x66)),
        )
        .await;
        assert_eq!(resp.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }

    #[tokio::test]
    async fn test_handle_post_query_invalid_dns_message() {
        let resp = handle_post_query(
            State(Arc::new(TestHandler)),
            dns_message_headers(),
            AxumBytes::from_static(&[1, 2, 3]),
        )
        .await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_handle_get_query_missing_param() {
        let resp = handle_get_query(
            State(Arc::new(TestHandler)),
            AxumQuery(HashMap::new()),
            HeaderMap::new(),
        )
        .await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_handle_get_query_invalid_base64() {
        let params = HashMap::from([("dns".to_string(), "!!not_base64!!".to_string())]);
        let resp = handle_get_query(
            State(Arc::new(TestHandler)),
            AxumQuery(params),
            HeaderMap::new(),
        )
        .await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_handle_get_query_invalid_dns_message() {
        let resp = handle_get_query(
            State(Arc::new(TestHandler)),
            AxumQuery(get_params(&[1u8, 2, 3])),
            HeaderMap::new(),
        )
        .await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_handler_error_get_and_post_return_internal() {
        let wire = query_wire(0x77);

        let resp_get = handle_get_query(
            State(Arc::new(TestHandlerErr)),
            AxumQuery(get_params(&wire)),
            HeaderMap::new(),
        )
        .await;
        assert_eq!(resp_get.status(), StatusCode::INTERNAL_SERVER_ERROR);

        let resp_post = handle_post_query(
            State(Arc::new(TestHandlerErr)),
            dns_message_headers(),
            AxumBytes::from(wire),
        )
        .await;
        assert_eq!(resp_post.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }
}