hickory_server/authority/
message_response.rs1use crate::{
9 authority::{
10 message_request::{MessageRequest, QueriesEmitAndCount},
11 Queries,
12 },
13 proto::{
14 error::*,
15 op::{
16 message::{self, EmitAndCount},
17 Edns, Header, ResponseCode,
18 },
19 rr::Record,
20 serialize::binary::BinEncoder,
21 },
22 server::ResponseInfo,
23};
24
25use super::message_request::WireQuery;
26
27#[derive(Debug)]
29pub struct MessageResponse<'q, 'a, Answers, NameServers, Soa, Additionals>
30where
31 Answers: Iterator<Item = &'a Record> + Send + 'a,
32 NameServers: Iterator<Item = &'a Record> + Send + 'a,
33 Soa: Iterator<Item = &'a Record> + Send + 'a,
34 Additionals: Iterator<Item = &'a Record> + Send + 'a,
35{
36 header: Header,
37 query: Option<&'q WireQuery>,
38 answers: Answers,
39 name_servers: NameServers,
40 soa: Soa,
41 additionals: Additionals,
42 sig0: Vec<Record>,
43 edns: Option<Edns>,
44}
45
46enum EmptyOrQueries<'q> {
47 Empty,
48 Queries(QueriesEmitAndCount<'q>),
49}
50
51impl<'q> From<Option<&'q Queries>> for EmptyOrQueries<'q> {
52 fn from(option: Option<&'q Queries>) -> Self {
53 option.map_or(EmptyOrQueries::Empty, |q| {
54 EmptyOrQueries::Queries(q.as_emit_and_count())
55 })
56 }
57}
58
59impl<'q> From<Option<&'q WireQuery>> for EmptyOrQueries<'q> {
60 fn from(option: Option<&'q WireQuery>) -> Self {
61 option.map_or(EmptyOrQueries::Empty, |q| {
62 EmptyOrQueries::Queries(q.as_emit_and_count())
63 })
64 }
65}
66
67impl EmitAndCount for EmptyOrQueries<'_> {
68 fn emit(&mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<usize> {
69 match self {
70 EmptyOrQueries::Empty => Ok(0),
71 EmptyOrQueries::Queries(q) => q.emit(encoder),
72 }
73 }
74}
75
76impl<'a, A, N, S, D> MessageResponse<'_, 'a, A, N, S, D>
77where
78 A: Iterator<Item = &'a Record> + Send + 'a,
79 N: Iterator<Item = &'a Record> + Send + 'a,
80 S: Iterator<Item = &'a Record> + Send + 'a,
81 D: Iterator<Item = &'a Record> + Send + 'a,
82{
83 pub fn header(&self) -> &Header {
85 &self.header
86 }
87
88 pub fn header_mut(&mut self) -> &mut Header {
90 &mut self.header
91 }
92
93 pub fn set_edns(&mut self, edns: Edns) -> &mut Self {
95 self.edns = Some(edns);
96 self
97 }
98
99 pub fn get_edns(&self) -> &Option<Edns> {
101 &self.edns
102 }
103
104 pub fn destructive_emit(mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<ResponseInfo> {
106 let mut name_servers = self.name_servers.chain(self.soa);
108
109 message::emit_message_parts(
110 &self.header,
111 &mut EmptyOrQueries::from(self.query),
112 &mut self.answers,
113 &mut name_servers,
114 &mut self.additionals,
115 self.edns.as_ref(),
116 &self.sig0,
117 encoder,
118 )
119 .map(Into::into)
120 }
121}
122
123pub struct MessageResponseBuilder<'q> {
125 query: Option<&'q WireQuery>,
126 sig0: Option<Vec<Record>>,
127 edns: Option<Edns>,
128}
129
130impl<'q> MessageResponseBuilder<'q> {
131 pub(crate) fn new(query: Option<&'q WireQuery>) -> Self {
137 MessageResponseBuilder {
138 query,
139 sig0: None,
140 edns: None,
141 }
142 }
143
144 pub fn from_message_request(message: &'q MessageRequest) -> Self {
150 Self::new(Some(message.raw_query()))
151 }
152
153 pub fn edns(&mut self, edns: Edns) -> &mut Self {
155 self.edns = Some(edns);
156 self
157 }
158
159 pub fn build<'a, A, N, S, D>(
165 self,
166 header: Header,
167 answers: A,
168 name_servers: N,
169 soa: S,
170 additionals: D,
171 ) -> MessageResponse<'q, 'a, A::IntoIter, N::IntoIter, S::IntoIter, D::IntoIter>
172 where
173 A: IntoIterator<Item = &'a Record> + Send + 'a,
174 A::IntoIter: Send,
175 N: IntoIterator<Item = &'a Record> + Send + 'a,
176 N::IntoIter: Send,
177 S: IntoIterator<Item = &'a Record> + Send + 'a,
178 S::IntoIter: Send,
179 D: IntoIterator<Item = &'a Record> + Send + 'a,
180 D::IntoIter: Send,
181 {
182 MessageResponse {
183 header,
184 query: self.query,
185 answers: answers.into_iter(),
186 name_servers: name_servers.into_iter(),
187 soa: soa.into_iter(),
188 additionals: additionals.into_iter(),
189 sig0: self.sig0.unwrap_or_default(),
190 edns: self.edns,
191 }
192 }
193
194 pub fn build_no_records<'a>(
196 self,
197 header: Header,
198 ) -> MessageResponse<
199 'q,
200 'a,
201 impl Iterator<Item = &'a Record> + Send + 'a,
202 impl Iterator<Item = &'a Record> + Send + 'a,
203 impl Iterator<Item = &'a Record> + Send + 'a,
204 impl Iterator<Item = &'a Record> + Send + 'a,
205 > {
206 MessageResponse {
207 header,
208 query: self.query,
209 answers: Box::new(None.into_iter()),
210 name_servers: Box::new(None.into_iter()),
211 soa: Box::new(None.into_iter()),
212 additionals: Box::new(None.into_iter()),
213 sig0: self.sig0.unwrap_or_default(),
214 edns: self.edns,
215 }
216 }
217
218 pub fn error_msg<'a>(
226 self,
227 request_header: &Header,
228 response_code: ResponseCode,
229 ) -> MessageResponse<
230 'q,
231 'a,
232 impl Iterator<Item = &'a Record> + Send + 'a,
233 impl Iterator<Item = &'a Record> + Send + 'a,
234 impl Iterator<Item = &'a Record> + Send + 'a,
235 impl Iterator<Item = &'a Record> + Send + 'a,
236 > {
237 let mut header = Header::response_from_request(request_header);
238 header.set_response_code(response_code);
239
240 MessageResponse {
241 header,
242 query: self.query,
243 answers: Box::new(None.into_iter()),
244 name_servers: Box::new(None.into_iter()),
245 soa: Box::new(None.into_iter()),
246 additionals: Box::new(None.into_iter()),
247 sig0: self.sig0.unwrap_or_default(),
248 edns: self.edns,
249 }
250 }
251}
252
#[cfg(test)]
mod tests {
    use std::iter;
    use std::net::Ipv4Addr;
    use std::str::FromStr;

    use crate::proto::op::{Header, Message};
    use crate::proto::rr::{DNSClass, Name, RData, Record, RecordType};
    use crate::proto::serialize::binary::BinEncoder;

    use super::*;

    /// Builds the single A record that both tests repeat in their sections.
    fn sample_record() -> Record {
        let mut record = Record::new();
        record
            .set_record_type(RecordType::A)
            .set_name(Name::from_str("www.example.com.").unwrap())
            .set_data(Some(RData::A(Ipv4Addr::new(93, 184, 215, 14).into())))
            .set_dns_class(DNSClass::NONE);
        record
    }

    /// Emits `message` through an encoder capped at 512 bytes and decodes the
    /// bytes that were written.
    fn emit_and_decode<'a, A, N, S, D>(message: MessageResponse<'_, 'a, A, N, S, D>) -> Message
    where
        A: Iterator<Item = &'a Record> + Send + 'a,
        N: Iterator<Item = &'a Record> + Send + 'a,
        S: Iterator<Item = &'a Record> + Send + 'a,
        D: Iterator<Item = &'a Record> + Send + 'a,
    {
        let mut buf = Vec::with_capacity(512);
        {
            let mut encoder = BinEncoder::new(&mut buf);
            encoder.set_max_size(512);

            message
                .destructive_emit(&mut encoder)
                .expect("failed to encode");
        }

        Message::from_vec(&buf).expect("failed to decode")
    }

    #[test]
    fn test_truncation_ridiculous_number_answers() {
        let answer = sample_record();

        let response = emit_and_decode(MessageResponse {
            header: Header::new(),
            query: None,
            answers: iter::repeat(&answer),
            name_servers: iter::once(&answer),
            soa: iter::once(&answer),
            additionals: iter::once(&answer),
            sig0: vec![],
            edns: None,
        });

        assert!(response.header().truncated());
        assert!(response.answer_count() > 1);
        // The endless answer section fills the buffer, so no name server
        // record is expected in the output.
        assert_eq!(response.name_server_count(), 0);
    }

    #[test]
    fn test_truncation_ridiculous_number_nameservers() {
        let answer = sample_record();

        let response = emit_and_decode(MessageResponse {
            header: Header::new(),
            query: None,
            answers: iter::empty(),
            name_servers: iter::repeat(&answer),
            soa: iter::repeat(&answer),
            additionals: iter::repeat(&answer),
            sig0: vec![],
            edns: None,
        });

        assert!(response.header().truncated());
        assert_eq!(response.answer_count(), 0);
        assert!(response.name_server_count() > 1);
    }
}