hickory_server/authority/message_response.rs

// Copyright 2015-2021 Benjamin Fry <benjaminfry@me.com>
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// https://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

8use crate::{
9    authority::{
10        message_request::{MessageRequest, QueriesEmitAndCount},
11        Queries,
12    },
13    proto::{
14        error::*,
15        op::{
16            message::{self, EmitAndCount},
17            Edns, Header, ResponseCode,
18        },
19        rr::Record,
20        serialize::binary::BinEncoder,
21    },
22    server::ResponseInfo,
23};
24
25use super::message_request::WireQuery;
26
27/// A EncodableMessage with borrowed data for Responses in the Server
28#[derive(Debug)]
29pub struct MessageResponse<'q, 'a, Answers, NameServers, Soa, Additionals>
30where
31    Answers: Iterator<Item = &'a Record> + Send + 'a,
32    NameServers: Iterator<Item = &'a Record> + Send + 'a,
33    Soa: Iterator<Item = &'a Record> + Send + 'a,
34    Additionals: Iterator<Item = &'a Record> + Send + 'a,
35{
36    header: Header,
37    query: Option<&'q WireQuery>,
38    answers: Answers,
39    name_servers: NameServers,
40    soa: Soa,
41    additionals: Additionals,
42    sig0: Vec<Record>,
43    edns: Option<Edns>,
44}
45
46enum EmptyOrQueries<'q> {
47    Empty,
48    Queries(QueriesEmitAndCount<'q>),
49}
50
51impl<'q> From<Option<&'q Queries>> for EmptyOrQueries<'q> {
52    fn from(option: Option<&'q Queries>) -> Self {
53        option.map_or(EmptyOrQueries::Empty, |q| {
54            EmptyOrQueries::Queries(q.as_emit_and_count())
55        })
56    }
57}
58
59impl<'q> From<Option<&'q WireQuery>> for EmptyOrQueries<'q> {
60    fn from(option: Option<&'q WireQuery>) -> Self {
61        option.map_or(EmptyOrQueries::Empty, |q| {
62            EmptyOrQueries::Queries(q.as_emit_and_count())
63        })
64    }
65}
66
67impl EmitAndCount for EmptyOrQueries<'_> {
68    fn emit(&mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<usize> {
69        match self {
70            EmptyOrQueries::Empty => Ok(0),
71            EmptyOrQueries::Queries(q) => q.emit(encoder),
72        }
73    }
74}
75
76impl<'a, A, N, S, D> MessageResponse<'_, 'a, A, N, S, D>
77where
78    A: Iterator<Item = &'a Record> + Send + 'a,
79    N: Iterator<Item = &'a Record> + Send + 'a,
80    S: Iterator<Item = &'a Record> + Send + 'a,
81    D: Iterator<Item = &'a Record> + Send + 'a,
82{
83    /// Returns the header of the message
84    pub fn header(&self) -> &Header {
85        &self.header
86    }
87
88    /// Get a mutable reference to the header
89    pub fn header_mut(&mut self) -> &mut Header {
90        &mut self.header
91    }
92
93    /// Set the EDNS options for the Response
94    pub fn set_edns(&mut self, edns: Edns) -> &mut Self {
95        self.edns = Some(edns);
96        self
97    }
98
99    /// Gets a reference to the EDNS options for the Response.
100    pub fn get_edns(&self) -> &Option<Edns> {
101        &self.edns
102    }
103
104    /// Consumes self, and emits to the encoder.
105    pub fn destructive_emit(mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<ResponseInfo> {
106        // soa records are part of the nameserver section
107        let mut name_servers = self.name_servers.chain(self.soa);
108
109        message::emit_message_parts(
110            &self.header,
111            &mut EmptyOrQueries::from(self.query),
112            &mut self.answers,
113            &mut name_servers,
114            &mut self.additionals,
115            self.edns.as_ref(),
116            &self.sig0,
117            encoder,
118        )
119        .map(Into::into)
120    }
121}
122
123/// A builder for MessageResponses
124pub struct MessageResponseBuilder<'q> {
125    query: Option<&'q WireQuery>,
126    sig0: Option<Vec<Record>>,
127    edns: Option<Edns>,
128}
129
130impl<'q> MessageResponseBuilder<'q> {
131    /// Constructs a new response builder
132    ///
133    /// # Arguments
134    ///
135    /// * `query` - any optional query (from the Request) to associate with the Response
136    pub(crate) fn new(query: Option<&'q WireQuery>) -> Self {
137        MessageResponseBuilder {
138            query,
139            sig0: None,
140            edns: None,
141        }
142    }
143
144    /// Constructs a new response builder
145    ///
146    /// # Arguments
147    ///
148    /// * `message` - original request message to associate with the response
149    pub fn from_message_request(message: &'q MessageRequest) -> Self {
150        Self::new(Some(message.raw_query()))
151    }
152
153    /// Associate EDNS with the Response
154    pub fn edns(&mut self, edns: Edns) -> &mut Self {
155        self.edns = Some(edns);
156        self
157    }
158
159    /// Constructs the new MessageResponse with associated Header
160    ///
161    /// # Arguments
162    ///
163    /// * `header` - set of [Header]s for the Message
164    pub fn build<'a, A, N, S, D>(
165        self,
166        header: Header,
167        answers: A,
168        name_servers: N,
169        soa: S,
170        additionals: D,
171    ) -> MessageResponse<'q, 'a, A::IntoIter, N::IntoIter, S::IntoIter, D::IntoIter>
172    where
173        A: IntoIterator<Item = &'a Record> + Send + 'a,
174        A::IntoIter: Send,
175        N: IntoIterator<Item = &'a Record> + Send + 'a,
176        N::IntoIter: Send,
177        S: IntoIterator<Item = &'a Record> + Send + 'a,
178        S::IntoIter: Send,
179        D: IntoIterator<Item = &'a Record> + Send + 'a,
180        D::IntoIter: Send,
181    {
182        MessageResponse {
183            header,
184            query: self.query,
185            answers: answers.into_iter(),
186            name_servers: name_servers.into_iter(),
187            soa: soa.into_iter(),
188            additionals: additionals.into_iter(),
189            sig0: self.sig0.unwrap_or_default(),
190            edns: self.edns,
191        }
192    }
193
194    /// Construct a Response with no associated records
195    pub fn build_no_records<'a>(
196        self,
197        header: Header,
198    ) -> MessageResponse<
199        'q,
200        'a,
201        impl Iterator<Item = &'a Record> + Send + 'a,
202        impl Iterator<Item = &'a Record> + Send + 'a,
203        impl Iterator<Item = &'a Record> + Send + 'a,
204        impl Iterator<Item = &'a Record> + Send + 'a,
205    > {
206        MessageResponse {
207            header,
208            query: self.query,
209            answers: Box::new(None.into_iter()),
210            name_servers: Box::new(None.into_iter()),
211            soa: Box::new(None.into_iter()),
212            additionals: Box::new(None.into_iter()),
213            sig0: self.sig0.unwrap_or_default(),
214            edns: self.edns,
215        }
216    }
217
218    /// Constructs a new error MessageResponse with associated settings
219    ///
220    /// # Arguments
221    ///
222    /// * `id` - request id to which this is a response
223    /// * `op_code` - operation for which this is a response
224    /// * `response_code` - the type of error
225    pub fn error_msg<'a>(
226        self,
227        request_header: &Header,
228        response_code: ResponseCode,
229    ) -> MessageResponse<
230        'q,
231        'a,
232        impl Iterator<Item = &'a Record> + Send + 'a,
233        impl Iterator<Item = &'a Record> + Send + 'a,
234        impl Iterator<Item = &'a Record> + Send + 'a,
235        impl Iterator<Item = &'a Record> + Send + 'a,
236    > {
237        let mut header = Header::response_from_request(request_header);
238        header.set_response_code(response_code);
239
240        MessageResponse {
241            header,
242            query: self.query,
243            answers: Box::new(None.into_iter()),
244            name_servers: Box::new(None.into_iter()),
245            soa: Box::new(None.into_iter()),
246            additionals: Box::new(None.into_iter()),
247            sig0: self.sig0.unwrap_or_default(),
248            edns: self.edns,
249        }
250    }
251}
252
#[cfg(test)]
mod tests {
    use std::iter;
    use std::net::Ipv4Addr;
    use std::str::FromStr;

    use crate::proto::op::{Header, Message};
    use crate::proto::rr::{DNSClass, Name, RData, Record, RecordType};
    use crate::proto::serialize::binary::BinEncoder;

    use super::*;

    /// The `www.example.com.` A record reused in every section of the test responses.
    fn example_a_record() -> Record {
        Record::new()
            .set_record_type(RecordType::A)
            .set_name(Name::from_str("www.example.com.").unwrap())
            .set_data(Some(RData::A(Ipv4Addr::new(93, 184, 215, 14).into())))
            .set_dns_class(DNSClass::NONE)
            .clone()
    }

    /// Emits `response` through an encoder capped at 512 bytes, then decodes the output.
    fn encode_capped_and_decode<'a, A, N, S, D>(
        response: MessageResponse<'_, 'a, A, N, S, D>,
    ) -> Message
    where
        A: Iterator<Item = &'a Record> + Send + 'a,
        N: Iterator<Item = &'a Record> + Send + 'a,
        S: Iterator<Item = &'a Record> + Send + 'a,
        D: Iterator<Item = &'a Record> + Send + 'a,
    {
        let mut buf = Vec::with_capacity(512);
        {
            let mut encoder = BinEncoder::new(&mut buf);
            encoder.set_max_size(512);

            response
                .destructive_emit(&mut encoder)
                .expect("failed to encode");
        }

        Message::from_vec(&buf).expect("failed to decode")
    }

    #[test]
    fn test_truncation_ridiculous_number_answers() {
        let record = example_a_record();
        let response = encode_capped_and_decode(MessageResponse {
            header: Header::new(),
            query: None,
            answers: iter::repeat(&record),
            name_servers: iter::once(&record),
            soa: iter::once(&record),
            additionals: iter::once(&record),
            sig0: vec![],
            edns: None,
        });

        assert!(response.header().truncated());
        assert!(response.answer_count() > 1);
        // The answers alone overflow the limit, so the name server section is never written.
        assert_eq!(response.name_server_count(), 0);
    }

    #[test]
    fn test_truncation_ridiculous_number_nameservers() {
        let record = example_a_record();
        let response = encode_capped_and_decode(MessageResponse {
            header: Header::new(),
            query: None,
            answers: iter::empty(),
            name_servers: iter::repeat(&record),
            soa: iter::repeat(&record),
            additionals: iter::repeat(&record),
            sig0: vec![],
            edns: None,
        });

        assert!(response.header().truncated());
        assert_eq!(response.answer_count(), 0);
        assert!(response.name_server_count() > 1);
    }
}