hickory_server/server/
response_handler.rs1use std::{io, net::SocketAddr};
9
10use hickory_proto::{
11 ProtoError,
12 op::{Header, ResponseCode},
13 rr::Record,
14 serialize::binary::BinEncodable,
15};
16use tracing::{debug, error, trace};
17
18use crate::{
19 authority::MessageResponse,
20 proto::{
21 BufDnsStreamHandle, DnsStreamHandle,
22 serialize::binary::BinEncoder,
23 xfer::{Protocol, SerialMessage},
24 },
25 server::ResponseInfo,
26};
27
28#[async_trait::async_trait]
30pub trait ResponseHandler: Clone + Send + Sync + Unpin + 'static {
31 async fn send_response<'a>(
38 &mut self,
39 response: MessageResponse<
40 '_,
41 'a,
42 impl Iterator<Item = &'a Record> + Send + 'a,
43 impl Iterator<Item = &'a Record> + Send + 'a,
44 impl Iterator<Item = &'a Record> + Send + 'a,
45 impl Iterator<Item = &'a Record> + Send + 'a,
46 >,
47 ) -> io::Result<ResponseInfo>;
48}
49
50#[derive(Clone)]
53pub struct ResponseHandle {
54 dst: SocketAddr,
55 stream_handle: BufDnsStreamHandle,
56 protocol: Protocol,
57}
58
59impl ResponseHandle {
60 pub fn new(dst: SocketAddr, stream_handle: BufDnsStreamHandle, protocol: Protocol) -> Self {
62 Self {
63 dst,
64 stream_handle,
65 protocol,
66 }
67 }
68
69 fn max_size_for_response<'a>(
71 &self,
72 response: &MessageResponse<
73 '_,
74 'a,
75 impl Iterator<Item = &'a Record> + Send + 'a,
76 impl Iterator<Item = &'a Record> + Send + 'a,
77 impl Iterator<Item = &'a Record> + Send + 'a,
78 impl Iterator<Item = &'a Record> + Send + 'a,
79 >,
80 ) -> u16 {
81 match self.protocol {
82 Protocol::Udp => {
83 if let Some(edns) = response.get_edns() {
85 edns.max_payload()
86 } else {
87 hickory_proto::udp::MAX_RECEIVE_BUFFER_SIZE as u16
89 }
90 }
91 _ => u16::MAX,
92 }
93 }
94}
95
96#[async_trait::async_trait]
97impl ResponseHandler for ResponseHandle {
98 async fn send_response<'a>(
102 &mut self,
103 response: MessageResponse<
104 '_,
105 'a,
106 impl Iterator<Item = &'a Record> + Send + 'a,
107 impl Iterator<Item = &'a Record> + Send + 'a,
108 impl Iterator<Item = &'a Record> + Send + 'a,
109 impl Iterator<Item = &'a Record> + Send + 'a,
110 >,
111 ) -> io::Result<ResponseInfo> {
112 let id = response.header().id();
113 debug!(
114 id,
115 response_code = %response.header().response_code(),
116 "sending response",
117 );
118 let mut buffer = Vec::with_capacity(512);
119 let encode_result = {
120 let mut encoder = BinEncoder::new(&mut buffer);
121
122 let max_size = self.max_size_for_response(&response);
124 trace!(
125 "setting response max size: {max_size} for protocol: {:?}",
126 self.protocol
127 );
128 encoder.set_max_size(max_size);
129
130 response.destructive_emit(&mut encoder)
131 };
132
133 let info = encode_result.or_else(|error| {
134 error!(%error, "error encoding message");
135 encode_fallback_servfail_response(id, &mut buffer)
136 })?;
137
138 self.stream_handle
139 .send(SerialMessage::new(buffer, self.dst))
140 .map_err(|_| io::Error::new(io::ErrorKind::Other, "unknown"))?;
141
142 Ok(info)
143 }
144}
145
146pub(crate) fn encode_fallback_servfail_response(
149 id: u16,
150 buffer: &mut Vec<u8>,
151) -> Result<ResponseInfo, ProtoError> {
152 buffer.clear();
153 let mut encoder = BinEncoder::new(buffer);
154 encoder.set_max_size(512);
155 let mut header = Header::new();
156 header.set_id(id);
157 header.set_response_code(ResponseCode::ServFail);
158 header.emit(&mut encoder)?;
159
160 Ok(ResponseInfo::serve_failed())
161}