use std::net::SocketAddr;
use tracing::{debug, error, trace};
use crate::{
net::{
BufDnsStreamHandle, DnsStreamHandle, NetError, udp::MAX_RECEIVE_BUFFER_SIZE, xfer::Protocol,
},
proto::{
ProtoError,
op::{Header, HeaderCounts, MessageType, Metadata, OpCode, ResponseCode, SerialMessage},
rr::Record,
serialize::binary::BinEncodable,
serialize::binary::BinEncoder,
},
server::ResponseInfo,
zone_handler::MessageResponse,
};
/// A handler that delivers a fully built response message to its recipient.
///
/// Implementors must be `Clone + Send + Sync + Unpin + 'static`, as required by the
/// supertrait bounds below.
#[async_trait::async_trait]
pub trait ResponseHandler: Clone + Send + Sync + Unpin + 'static {
/// Sends `response`, consuming it.
///
/// The four iterator type parameters of [`MessageResponse`] each yield `&'a Record`;
/// presumably they correspond to the record sections of the message (confirm against
/// the `MessageResponse` definition).
///
/// # Errors
///
/// Returns a [`NetError`] if the response could not be sent.
async fn send_response<'a>(
&mut self,
response: MessageResponse<
'_,
'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
>,
) -> Result<ResponseInfo, NetError>;
}
/// A [`ResponseHandler`] that encodes responses and sends them to a single destination
/// address through a [`BufDnsStreamHandle`].
#[derive(Clone)]
pub struct ResponseHandle {
/// Address attached to every outgoing [`SerialMessage`].
dst: SocketAddr,
/// Handle the encoded message bytes are sent through.
stream_handle: BufDnsStreamHandle,
/// Protocol the response is sent over; used to pick the maximum encoded size.
protocol: Protocol,
}
impl ResponseHandle {
    /// Builds a handle that sends responses to `dst` through `stream_handle`, using
    /// `protocol` to decide how large an encoded response may be.
    pub fn new(dst: SocketAddr, stream_handle: BufDnsStreamHandle, protocol: Protocol) -> Self {
        Self {
            dst,
            stream_handle,
            protocol,
        }
    }

    /// Returns the largest encoded size, in bytes, permitted for `response`.
    ///
    /// * Any protocol other than UDP: `u16::MAX`.
    /// * UDP with EDNS present on the response: the EDNS `max_payload()` value.
    /// * UDP without EDNS: `MAX_RECEIVE_BUFFER_SIZE` cast to `u16`.
    fn max_size_for_response<'a>(
        &self,
        response: &MessageResponse<
            '_,
            'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
        >,
    ) -> u16 {
        // Only UDP is size-constrained here; everything else gets the full u16 range.
        if !matches!(self.protocol, Protocol::Udp) {
            return u16::MAX;
        }

        match response.edns() {
            Some(edns) => edns.max_payload(),
            None => MAX_RECEIVE_BUFFER_SIZE as u16,
        }
    }
}
#[async_trait::async_trait]
impl ResponseHandler for ResponseHandle {
    /// Encodes `response` and sends the resulting bytes to this handle's destination.
    ///
    /// The encoder is capped at the size returned by `max_size_for_response`. If encoding
    /// fails, the error is logged and the buffer is replaced with the header-only
    /// SERVFAIL message produced by `encode_fallback_servfail_response`, which is then
    /// sent instead.
    ///
    /// # Errors
    ///
    /// Returns a [`NetError`] if the fallback encoding also fails or if the stream
    /// handle rejects the message.
    async fn send_response<'a>(
        &mut self,
        response: MessageResponse<
            '_,
            'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
        >,
    ) -> Result<ResponseInfo, NetError> {
        // Grab the id up front: `response` is consumed by encoding, and the id is still
        // needed afterwards for the fallback message.
        let id = response.metadata().id;
        debug!(
            id,
            response_code = %response.metadata().response_code,
            "sending response",
        );

        let max_size = self.max_size_for_response(&response);
        trace!(
            "setting response max size: {max_size} for protocol: {:?}",
            self.protocol
        );

        let mut buffer = Vec::with_capacity(512);
        let encoded = {
            // Scope the encoder so its mutable borrow of `buffer` ends before the
            // buffer is reused below.
            let mut encoder = BinEncoder::new(&mut buffer);
            encoder.set_max_size(max_size);
            response.destructive_emit(&mut encoder)
        };

        let info = match encoded {
            Ok(info) => info,
            Err(error) => {
                error!(%error, "error encoding message");
                encode_fallback_servfail_response(id, &mut buffer)?
            }
        };

        self.stream_handle
            .send(SerialMessage::new(buffer, self.dst))?;
        Ok(info)
    }
}
/// Overwrites `buffer` with a header-only SERVFAIL response carrying the given `id`.
///
/// Any existing contents of `buffer` are discarded. The emitted header is a
/// `MessageType::Response` with `OpCode::Query`, response code `ServFail`, and default
/// (zeroed) section counts. The encoder is limited to 512 bytes.
///
/// # Errors
///
/// Returns a [`ProtoError`] if emitting the header fails.
pub(crate) fn encode_fallback_servfail_response(
    id: u16,
    buffer: &mut Vec<u8>,
) -> Result<ResponseInfo, ProtoError> {
    let header = Header {
        metadata: {
            let mut servfail = Metadata::new(id, MessageType::Response, OpCode::Query);
            servfail.response_code = ResponseCode::ServFail;
            servfail
        },
        counts: HeaderCounts::default(),
    };

    // Start from an empty buffer so no partially encoded bytes remain.
    buffer.clear();
    let mut encoder = BinEncoder::new(buffer);
    encoder.set_max_size(512);
    header.emit(&mut encoder)?;

    Ok(ResponseInfo::from(header))
}