use crate::Result;
use crate::dns::Message;
use async_trait::async_trait;
use std::net::{IpAddr, SocketAddr};
/// Network identity of the peer that sent a request.
///
/// When built through the [`From<SocketAddr>`] conversion, `ip` and `port`
/// are copies of the two components of `addr`, kept alongside it so callers
/// can read them as plain fields.
#[derive(Debug, Clone)]
pub struct ClientInfo {
    /// Full socket address (IP and port) of the client.
    pub addr: SocketAddr,
    /// IP part of the client's address.
    pub ip: IpAddr,
    /// Port part of the client's address.
    pub port: u16,
}

impl From<SocketAddr> for ClientInfo {
    /// Builds a `ClientInfo` whose `ip` and `port` are taken from `addr`.
    fn from(addr: SocketAddr) -> Self {
        // `SocketAddr` is `Copy`, so reading its parts leaves `addr` usable.
        let (ip, port) = (addr.ip(), addr.port());
        ClientInfo { addr, ip, port }
    }
}
/// Transport over which a request arrived.
///
/// NOTE(review): this file only declares the variants and never inspects
/// them; the meanings below follow the conventional DNS abbreviations and
/// should be confirmed against the listeners that construct these values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Protocol {
/// Plain UDP.
Udp,
/// Plain TCP.
Tcp,
/// Presumably DNS over HTTPS.
DoH,
/// Presumably DNS over TLS.
DoT,
/// Presumably DNS over QUIC.
DoQ,
}
/// A DNS message bundled with metadata about where it came from.
///
/// This is the value passed by ownership to `RequestHandler::handle`.
#[derive(Debug)]
pub struct RequestContext {
/// The DNS message carried by this request.
pub message: Message,
/// Address details of the sender; `None` when no client address was
/// supplied (as is always the case for `RequestContext::new`).
pub client_info: Option<ClientInfo>,
/// Transport the message was associated with at construction time.
pub protocol: Protocol,
}
impl RequestContext {
    /// Creates a context that carries no client information.
    pub fn new(message: Message, protocol: Protocol) -> Self {
        // Same result as building the struct with `client_info: None`.
        Self::with_client(message, None, protocol)
    }

    /// Creates a context, recording the client's address when one is given.
    ///
    /// `client_addr` is converted into a [`ClientInfo`]; passing `None`
    /// leaves `client_info` unset.
    pub fn with_client(
        message: Message,
        client_addr: Option<SocketAddr>,
        protocol: Protocol,
    ) -> Self {
        let client_info = client_addr.map(ClientInfo::from);
        Self {
            message,
            client_info,
            protocol,
        }
    }

    /// Returns the client's IP address, or `None` if no client was recorded.
    pub fn client_ip(&self) -> Option<&IpAddr> {
        let info = self.client_info.as_ref()?;
        Some(&info.ip)
    }

    /// Returns the client's socket address, or `None` if no client was recorded.
    pub fn client_addr(&self) -> Option<&SocketAddr> {
        let info = self.client_info.as_ref()?;
        Some(&info.addr)
    }

    /// Consumes the context and yields only the message, dropping the rest.
    pub fn into_message(self) -> Message {
        let Self { message, .. } = self;
        message
    }

    /// Consumes the context and yields its parts as
    /// `(message, client_info, protocol)`.
    pub fn into_raw(self) -> (Message, Option<ClientInfo>, Protocol) {
        let Self {
            message,
            client_info,
            protocol,
        } = self;
        (message, client_info, protocol)
    }
}
/// Asynchronous handler that turns a [`RequestContext`] into a [`Message`].
///
/// Implementors must be `Send + Sync`, as required by the supertrait bounds.
#[async_trait]
pub trait RequestHandler: Send + Sync {
/// Handles one request, taking ownership of its context.
///
/// On success returns the resulting message; failures are reported through
/// the crate's `Result` alias. Which errors can occur is up to the
/// implementation.
async fn handle(&self, ctx: RequestContext) -> Result<Message>;
}
/// Minimal [`RequestHandler`] that answers every request by echoing it back.
///
/// The handler holds no state. `Default` is derived rather than written by
/// hand, since the manual impl only returned the unit value.
#[derive(Debug, Clone, Default)]
pub struct DefaultHandler;

#[async_trait]
impl RequestHandler for DefaultHandler {
    /// Turns the incoming message into its own response.
    ///
    /// The message is moved out of `ctx`, flagged as a response, marked as
    /// having recursion unavailable, and returned. This handler touches no
    /// other part of the message and ignores the client information and
    /// protocol carried by `ctx`. It never returns `Err`.
    async fn handle(&self, ctx: RequestContext) -> Result<Message> {
        // Reuse the request as the reply so its id and questions carry over.
        let mut response = ctx.message;
        response.set_response(true);
        response.set_recursion_available(false);
        Ok(response)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dns::{Question, RecordClass, RecordType};

    /// The default handler returns a message flagged as a response that keeps
    /// the request's id and question count.
    #[tokio::test]
    async fn test_default_handler() {
        let mut query = Message::new();
        query.set_id(1234);
        query.set_query(true);
        let question = Question::new("example.com".to_string(), RecordType::A, RecordClass::IN);
        query.add_question(question);

        let handler = DefaultHandler;
        let reply = handler
            .handle(RequestContext::new(query, Protocol::Udp))
            .await
            .unwrap();

        assert!(reply.is_response());
        assert_eq!(reply.id(), 1234);
        assert_eq!(reply.question_count(), 1);
    }

    /// The question's name and type come back unchanged in the reply.
    #[tokio::test]
    async fn test_handler_preserves_questions() {
        let mut query = Message::new();
        let question = Question::new("test.com".to_string(), RecordType::AAAA, RecordClass::IN);
        query.add_question(question);

        let handler = DefaultHandler;
        let reply = handler
            .handle(RequestContext::new(query, Protocol::Udp))
            .await
            .unwrap();

        let questions = reply.questions();
        assert_eq!(questions[0].qname(), "test.com");
        assert_eq!(questions[0].qtype(), RecordType::AAAA);
    }

    /// `with_client` records the address and exposes it through the accessors.
    #[tokio::test]
    async fn test_request_context_with_client() {
        let addr = "192.168.1.1:12345".parse::<SocketAddr>().unwrap();
        let expected_ip = "192.168.1.1".parse::<IpAddr>().unwrap();

        let ctx = RequestContext::with_client(Message::new(), Some(addr), Protocol::Udp);

        assert!(ctx.client_info.is_some());
        assert_eq!(ctx.client_ip(), Some(&expected_ip));
        assert_eq!(ctx.client_addr(), Some(&addr));
        assert_eq!(ctx.protocol, Protocol::Udp);
    }

    /// `new` leaves the client information unset.
    #[tokio::test]
    async fn test_request_context_without_client() {
        let ctx = RequestContext::new(Message::new(), Protocol::DoH);

        assert!(ctx.client_info.is_none());
        assert_eq!(ctx.client_ip(), None);
        assert_eq!(ctx.protocol, Protocol::DoH);
    }
}