use std::{io, net::SocketAddr};
use tracing::{debug, trace};
use trust_dns_proto::rr::Record;
use crate::server::Protocol;
use crate::{
authority::MessageResponse,
proto::{
serialize::binary::BinEncoder, xfer::SerialMessage, BufDnsStreamHandle, DnsStreamHandle,
},
server::ResponseInfo,
};
/// Sends a DNS response back to whoever issued the corresponding request.
///
/// The `Clone + Send + Sync + Unpin + 'static` bounds let a handler be cloned and moved
/// freely between async tasks.
/// NOTE(review): presumably one handler is created per incoming request — confirm against the
/// request dispatch code.
#[async_trait::async_trait]
pub trait ResponseHandler: Clone + Send + Sync + Unpin + 'static {
/// Serializes and sends `response`, returning a [`ResponseInfo`] describing what was sent.
///
/// `response` is taken by value, so implementations may consume its record iterators
/// while encoding. The iterators are required to be `Send` — presumably so the future
/// generated by `async_trait` is itself `Send`; verify if relaxing these bounds.
///
/// # Errors
///
/// Returns an [`io::Error`] if the response could not be delivered; the exact failure modes
/// are implementation-defined.
async fn send_response<'a>(
&mut self,
response: MessageResponse<
'_,
'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
>,
) -> io::Result<ResponseInfo>;
}
/// [`ResponseHandler`] that encodes each response into bytes and pushes it through a
/// [`BufDnsStreamHandle`] addressed to a fixed destination.
#[derive(Clone)]
pub struct ResponseHandle {
    /// Transport the request arrived on; decides the size cap applied when encoding.
    protocol: Protocol,
    /// Peer address every serialized response is sent to.
    dst: SocketAddr,
    /// Stream handle the serialized message is handed to for delivery.
    stream_handle: BufDnsStreamHandle,
}
impl ResponseHandle {
    /// Creates a handle that sends responses to `dst` through `stream_handle`.
    ///
    /// `protocol` should be the transport the request arrived on; it determines the maximum
    /// encoded size of responses sent through this handle.
    pub fn new(dst: SocketAddr, stream_handle: BufDnsStreamHandle, protocol: Protocol) -> Self {
        Self {
            dst,
            stream_handle,
            protocol,
        }
    }

    /// Selects the maximum serialized size, in bytes, for `response` on this handle's transport.
    ///
    /// - UDP with EDNS in the response: the EDNS payload size, raised to at least 512 bytes
    ///   (RFC 6891 §6.2.5: values lower than 512 MUST be treated as equal to 512).
    /// - UDP without EDNS: 512 bytes, the plain-DNS UDP message limit (RFC 1035 §4.2.1). A
    ///   requester that did not send an OPT record has not advertised a larger buffer.
    ///   NOTE(review): this assumes the response carries EDNS whenever the request did —
    ///   confirm in the request handlers that build responses.
    /// - Any other protocol: `u16::MAX`.
    fn max_size_for_response<'a>(
        &self,
        response: &MessageResponse<
            '_,
            'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
        >,
    ) -> u16 {
        // Smallest UDP payload every DNS implementation must accept (RFC 1035 / RFC 6891).
        const MIN_UDP_PAYLOAD: u16 = 512;

        match self.protocol {
            Protocol::Udp => {
                // `if let` works whether `get_edns` yields `Option<&Edns>` or `&Option<Edns>`.
                if let Some(edns) = response.get_edns() {
                    edns.max_payload().max(MIN_UDP_PAYLOAD)
                } else {
                    MIN_UDP_PAYLOAD
                }
            }
            _ => u16::MAX,
        }
    }
}
#[async_trait::async_trait]
impl ResponseHandler for ResponseHandle {
    /// Encodes `response`, capped at the size chosen by `max_size_for_response`, and sends the
    /// bytes to `self.dst` through the stream handle.
    ///
    /// Returns the [`ResponseInfo`] produced by encoding.
    ///
    /// # Errors
    ///
    /// Returns an [`io::Error`] of kind [`io::ErrorKind::Other`] if encoding fails or if the
    /// stream handle rejects the serialized message. The underlying error text is included in
    /// both cases.
    async fn send_response<'a>(
        &mut self,
        response: MessageResponse<
            '_,
            'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
        >,
    ) -> io::Result<ResponseInfo> {
        debug!(
            "response: {} response_code: {}",
            response.header().id(),
            response.header().response_code(),
        );

        let max_size = self.max_size_for_response(&response);
        trace!(
            "setting response max size: {max_size} for protocol: {:?}",
            self.protocol
        );

        let mut buffer = Vec::with_capacity(512);
        // The encoder mutably borrows `buffer`; scope it so the borrow ends before `buffer`
        // is moved into the outgoing message.
        let encode_result = {
            let mut encoder = BinEncoder::new(&mut buffer);
            encoder.set_max_size(max_size);
            response.destructive_emit(&mut encoder)
        };
        let info = encode_result.map_err(|e| {
            io::Error::new(io::ErrorKind::Other, format!("error encoding message: {e}"))
        })?;

        // Preserve the underlying cause so delivery failures are diagnosable.
        self.stream_handle
            .send(SerialMessage::new(buffer, self.dst))
            .map_err(|e| {
                io::Error::new(io::ErrorKind::Other, format!("error sending response: {e}"))
            })?;

        Ok(info)
    }
}