hickory_server/server/
response_handler.rs

1// Copyright 2015-2021 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
use std::{io, net::SocketAddr};

use hickory_proto::{
    ProtoError,
    op::{Header, MessageType, ResponseCode},
    rr::Record,
    serialize::binary::BinEncodable,
};
use tracing::{debug, error, trace};

use crate::{
    authority::MessageResponse,
    proto::{
        BufDnsStreamHandle, DnsStreamHandle,
        serialize::binary::BinEncoder,
        xfer::{Protocol, SerialMessage},
    },
    server::ResponseInfo,
};
27
28/// A handler for send a response to a client
29#[async_trait::async_trait]
30pub trait ResponseHandler: Clone + Send + Sync + Unpin + 'static {
31    // TODO: add associated error type
32    //type Error;
33
34    /// Serializes and sends a message to the wrapped handle
35    ///
36    /// self is consumed as only one message should ever be sent in response to a Request
37    async fn send_response<'a>(
38        &mut self,
39        response: MessageResponse<
40            '_,
41            'a,
42            impl Iterator<Item = &'a Record> + Send + 'a,
43            impl Iterator<Item = &'a Record> + Send + 'a,
44            impl Iterator<Item = &'a Record> + Send + 'a,
45            impl Iterator<Item = &'a Record> + Send + 'a,
46        >,
47    ) -> io::Result<ResponseInfo>;
48}
49
50/// A handler for wrapping a BufStreamHandle, which will properly serialize the message and add the
51///  associated destination.
52#[derive(Clone)]
53pub struct ResponseHandle {
54    dst: SocketAddr,
55    stream_handle: BufDnsStreamHandle,
56    protocol: Protocol,
57}
58
59impl ResponseHandle {
60    /// Returns a new `ResponseHandle` for sending a response message
61    pub fn new(dst: SocketAddr, stream_handle: BufDnsStreamHandle, protocol: Protocol) -> Self {
62        Self {
63            dst,
64            stream_handle,
65            protocol,
66        }
67    }
68
69    /// Selects an appropriate maximum serialized size for the given response.
70    fn max_size_for_response<'a>(
71        &self,
72        response: &MessageResponse<
73            '_,
74            'a,
75            impl Iterator<Item = &'a Record> + Send + 'a,
76            impl Iterator<Item = &'a Record> + Send + 'a,
77            impl Iterator<Item = &'a Record> + Send + 'a,
78            impl Iterator<Item = &'a Record> + Send + 'a,
79        >,
80    ) -> u16 {
81        match self.protocol {
82            Protocol::Udp => {
83                // Use EDNS, if available.
84                if let Some(edns) = response.get_edns() {
85                    edns.max_payload()
86                } else {
87                    // No EDNS, use the recommended max from RFC6891.
88                    hickory_proto::udp::MAX_RECEIVE_BUFFER_SIZE as u16
89                }
90            }
91            _ => u16::MAX,
92        }
93    }
94}
95
96#[async_trait::async_trait]
97impl ResponseHandler for ResponseHandle {
98    /// Serializes and sends a message to to the wrapped handle
99    ///
100    /// self is consumed as only one message should ever be sent in response to a Request
101    async fn send_response<'a>(
102        &mut self,
103        response: MessageResponse<
104            '_,
105            'a,
106            impl Iterator<Item = &'a Record> + Send + 'a,
107            impl Iterator<Item = &'a Record> + Send + 'a,
108            impl Iterator<Item = &'a Record> + Send + 'a,
109            impl Iterator<Item = &'a Record> + Send + 'a,
110        >,
111    ) -> io::Result<ResponseInfo> {
112        let id = response.header().id();
113        debug!(
114            id,
115            response_code = %response.header().response_code(),
116            "sending response",
117        );
118        let mut buffer = Vec::with_capacity(512);
119        let encode_result = {
120            let mut encoder = BinEncoder::new(&mut buffer);
121
122            // Set an appropriate maximum on the encoder.
123            let max_size = self.max_size_for_response(&response);
124            trace!(
125                "setting response max size: {max_size} for protocol: {:?}",
126                self.protocol
127            );
128            encoder.set_max_size(max_size);
129
130            response.destructive_emit(&mut encoder)
131        };
132
133        let info = encode_result.or_else(|error| {
134            error!(%error, "error encoding message");
135            encode_fallback_servfail_response(id, &mut buffer)
136        })?;
137
138        self.stream_handle
139            .send(SerialMessage::new(buffer, self.dst))
140            .map_err(|_| io::Error::new(io::ErrorKind::Other, "unknown"))?;
141
142        Ok(info)
143    }
144}
145
146/// Clears the buffer, encodes a SERVFAIL response in it, and returns a matching
147/// ResponseInfo.
148pub(crate) fn encode_fallback_servfail_response(
149    id: u16,
150    buffer: &mut Vec<u8>,
151) -> Result<ResponseInfo, ProtoError> {
152    buffer.clear();
153    let mut encoder = BinEncoder::new(buffer);
154    encoder.set_max_size(512);
155    let mut header = Header::new();
156    header.set_id(id);
157    header.set_response_code(ResponseCode::ServFail);
158    header.emit(&mut encoder)?;
159
160    Ok(ResponseInfo::serve_failed())
161}