1use crate::{
9 authority::{Queries, message_request::MessageRequest},
10 proto::{
11 ProtoError,
12 op::{Edns, Header, ResponseCode, message},
13 rr::Record,
14 serialize::binary::BinEncoder,
15 },
16 server::ResponseInfo,
17};
18
19#[derive(Debug)]
21pub struct MessageResponse<'q, 'a, Answers, NameServers, Soa, Additionals>
22where
23 Answers: Iterator<Item = &'a Record> + Send + 'a,
24 NameServers: Iterator<Item = &'a Record> + Send + 'a,
25 Soa: Iterator<Item = &'a Record> + Send + 'a,
26 Additionals: Iterator<Item = &'a Record> + Send + 'a,
27{
28 header: Header,
29 queries: &'q Queries,
30 answers: Answers,
31 name_servers: NameServers,
32 soa: Soa,
33 additionals: Additionals,
34 sig0: Vec<Record>,
35 edns: Option<Edns>,
36}
37
38impl<'a, A, N, S, D> MessageResponse<'_, 'a, A, N, S, D>
39where
40 A: Iterator<Item = &'a Record> + Send + 'a,
41 N: Iterator<Item = &'a Record> + Send + 'a,
42 S: Iterator<Item = &'a Record> + Send + 'a,
43 D: Iterator<Item = &'a Record> + Send + 'a,
44{
45 pub fn header(&self) -> &Header {
47 &self.header
48 }
49
50 pub fn header_mut(&mut self) -> &mut Header {
52 &mut self.header
53 }
54
55 pub fn set_edns(&mut self, edns: Edns) -> &mut Self {
57 self.edns = Some(edns);
58 self
59 }
60
61 pub fn get_edns(&self) -> &Option<Edns> {
63 &self.edns
64 }
65
66 pub fn destructive_emit(
68 mut self,
69 encoder: &mut BinEncoder<'_>,
70 ) -> Result<ResponseInfo, ProtoError> {
71 let mut name_servers = self.name_servers.chain(self.soa);
73
74 message::emit_message_parts(
75 &self.header,
76 &mut self.queries.as_emit_and_count(),
77 &mut self.answers,
78 &mut name_servers,
79 &mut self.additionals,
80 self.edns.as_ref(),
81 &self.sig0,
82 encoder,
83 )
84 .map(Into::into)
85 }
86}
87
88pub struct MessageResponseBuilder<'q> {
90 queries: &'q Queries,
91 sig0: Option<Vec<Record>>,
92 edns: Option<Edns>,
93}
94
95impl<'q> MessageResponseBuilder<'q> {
96 pub(crate) fn new(queries: &'q Queries) -> Self {
102 MessageResponseBuilder {
103 queries,
104 sig0: None,
105 edns: None,
106 }
107 }
108
109 pub fn from_message_request(message: &'q MessageRequest) -> Self {
115 Self::new(message.raw_queries())
116 }
117
118 pub fn edns(&mut self, edns: Edns) -> &mut Self {
120 self.edns = Some(edns);
121 self
122 }
123
124 pub fn build<'a, A, N, S, D>(
130 self,
131 header: Header,
132 answers: A,
133 name_servers: N,
134 soa: S,
135 additionals: D,
136 ) -> MessageResponse<'q, 'a, A::IntoIter, N::IntoIter, S::IntoIter, D::IntoIter>
137 where
138 A: IntoIterator<Item = &'a Record> + Send + 'a,
139 A::IntoIter: Send,
140 N: IntoIterator<Item = &'a Record> + Send + 'a,
141 N::IntoIter: Send,
142 S: IntoIterator<Item = &'a Record> + Send + 'a,
143 S::IntoIter: Send,
144 D: IntoIterator<Item = &'a Record> + Send + 'a,
145 D::IntoIter: Send,
146 {
147 MessageResponse {
148 header,
149 queries: self.queries,
150 answers: answers.into_iter(),
151 name_servers: name_servers.into_iter(),
152 soa: soa.into_iter(),
153 additionals: additionals.into_iter(),
154 sig0: self.sig0.unwrap_or_default(),
155 edns: self.edns,
156 }
157 }
158
159 pub fn build_no_records<'a>(
161 self,
162 header: Header,
163 ) -> MessageResponse<
164 'q,
165 'a,
166 impl Iterator<Item = &'a Record> + Send + 'a,
167 impl Iterator<Item = &'a Record> + Send + 'a,
168 impl Iterator<Item = &'a Record> + Send + 'a,
169 impl Iterator<Item = &'a Record> + Send + 'a,
170 > {
171 MessageResponse {
172 header,
173 queries: self.queries,
174 answers: Box::new(None.into_iter()),
175 name_servers: Box::new(None.into_iter()),
176 soa: Box::new(None.into_iter()),
177 additionals: Box::new(None.into_iter()),
178 sig0: self.sig0.unwrap_or_default(),
179 edns: self.edns,
180 }
181 }
182
183 pub fn error_msg<'a>(
191 self,
192 request_header: &Header,
193 response_code: ResponseCode,
194 ) -> MessageResponse<
195 'q,
196 'a,
197 impl Iterator<Item = &'a Record> + Send + 'a,
198 impl Iterator<Item = &'a Record> + Send + 'a,
199 impl Iterator<Item = &'a Record> + Send + 'a,
200 impl Iterator<Item = &'a Record> + Send + 'a,
201 > {
202 let mut header = Header::response_from_request(request_header);
203 header.set_response_code(response_code);
204
205 MessageResponse {
206 header,
207 queries: self.queries,
208 answers: Box::new(None.into_iter()),
209 name_servers: Box::new(None.into_iter()),
210 soa: Box::new(None.into_iter()),
211 additionals: Box::new(None.into_iter()),
212 sig0: self.sig0.unwrap_or_default(),
213 edns: self.edns,
214 }
215 }
216}
217
#[cfg(test)]
mod tests {
    use std::iter;
    use std::net::Ipv4Addr;
    use std::str::FromStr;

    use crate::proto::op::{Header, Message};
    use crate::proto::rr::{DNSClass, Name, RData, Record};
    use crate::proto::serialize::binary::BinEncoder;

    use super::*;

    /// The `A` record repeated to overflow a 512-byte message in the truncation tests.
    fn flood_record() -> Record {
        Record::from_rdata(
            Name::from_str("www.example.com.").unwrap(),
            0,
            RData::A(Ipv4Addr::new(93, 184, 215, 14).into()),
        )
        .set_dns_class(DNSClass::NONE)
        .clone()
    }

    /// Emits `response` through an encoder capped at 512 bytes, then decodes
    /// the written bytes back into a [`Message`].
    fn emit_capped_and_decode<'a, A, N, S, D>(
        response: MessageResponse<'_, 'a, A, N, S, D>,
    ) -> Message
    where
        A: Iterator<Item = &'a Record> + Send + 'a,
        N: Iterator<Item = &'a Record> + Send + 'a,
        S: Iterator<Item = &'a Record> + Send + 'a,
        D: Iterator<Item = &'a Record> + Send + 'a,
    {
        let mut bytes = Vec::with_capacity(512);
        {
            let mut encoder = BinEncoder::new(&mut bytes);
            encoder.set_max_size(512);
            response
                .destructive_emit(&mut encoder)
                .expect("failed to encode");
        }
        Message::from_vec(&bytes).expect("failed to decode")
    }

    #[test]
    fn test_truncation_ridiculous_number_answers() {
        let record = flood_record();
        let decoded = emit_capped_and_decode(MessageResponse {
            header: Header::new(),
            queries: &Queries::empty(),
            answers: iter::repeat(&record),
            name_servers: iter::once(&record),
            soa: iter::once(&record),
            additionals: iter::once(&record),
            sig0: vec![],
            edns: None,
        });

        assert!(decoded.header().truncated());
        assert!(decoded.answer_count() > 1);
        assert_eq!(decoded.name_server_count(), 0);
    }

    #[test]
    fn test_truncation_ridiculous_number_nameservers() {
        let record = flood_record();
        let decoded = emit_capped_and_decode(MessageResponse {
            header: Header::new(),
            queries: &Queries::empty(),
            answers: iter::empty(),
            name_servers: iter::repeat(&record),
            soa: iter::repeat(&record),
            additionals: iter::repeat(&record),
            sig0: vec![],
            edns: None,
        });

        assert!(decoded.header().truncated());
        assert_eq!(decoded.answer_count(), 0);
        assert!(decoded.name_server_count() > 1);
    }

    #[test]
    fn bad_length_of_named_pointers() {
        use hickory_proto::serialize::binary::BinDecodable;

        // A request whose query section contains a malformed compression pointer.
        const REQUEST_BYTES: &[u8] = &[
            0x08u8, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xc0, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ];

        let request = MessageRequest::from_bytes(REQUEST_BYTES).unwrap();
        eprintln!("queries: {:?}", request.queries());

        let mut out = Vec::with_capacity(512);
        let mut encoder = BinEncoder::new(&mut out);
        let response_header = Header::response_from_request(request.header());

        MessageResponseBuilder::new(request.raw_queries())
            .build_no_records(response_header)
            .destructive_emit(&mut encoder)
            .unwrap();
    }
}