hickory_server/authority/
message_request.rs1use std::iter::once;
9
10use crate::proto::{
11 error::*,
12 op::{
13 message::{self, EmitAndCount},
14 Edns, Header, LowerQuery, Message, MessageType, OpCode, ResponseCode,
15 },
16 rr::Record,
17 serialize::binary::{BinDecodable, BinDecoder, BinEncodable, BinEncoder},
18};
19
20#[derive(Debug, PartialEq)]
22pub struct MessageRequest {
23 header: Header,
24 query: WireQuery,
25 answers: Vec<Record>,
26 name_servers: Vec<Record>,
27 additionals: Vec<Record>,
28 sig0: Vec<Record>,
29 edns: Option<Edns>,
30}
31
32impl MessageRequest {
33 pub fn header(&self) -> &Header {
35 &self.header
36 }
37
38 pub fn id(&self) -> u16 {
40 self.header.id()
41 }
42
43 pub fn message_type(&self) -> MessageType {
45 self.header.message_type()
46 }
47
48 pub fn op_code(&self) -> OpCode {
50 self.header.op_code()
51 }
52
53 pub fn authoritative(&self) -> bool {
55 self.header.authoritative()
56 }
57
58 pub fn truncated(&self) -> bool {
60 self.header.truncated()
61 }
62
63 pub fn recursion_desired(&self) -> bool {
65 self.header.recursion_desired()
66 }
67
68 pub fn recursion_available(&self) -> bool {
70 self.header.recursion_available()
71 }
72
73 pub fn authentic_data(&self) -> bool {
75 self.header.authentic_data()
76 }
77
78 pub fn checking_disabled(&self) -> bool {
80 self.header.checking_disabled()
81 }
82
83 pub fn response_code(&self) -> ResponseCode {
88 self.header.response_code()
89 }
90
91 pub fn query(&self) -> &LowerQuery {
95 &self.query.query
96 }
97
98 pub fn answers(&self) -> &[Record] {
102 &self.answers
103 }
104
105 pub fn name_servers(&self) -> &[Record] {
111 &self.name_servers
112 }
113
114 pub fn additionals(&self) -> &[Record] {
119 &self.additionals
120 }
121
122 pub fn edns(&self) -> Option<&Edns> {
152 self.edns.as_ref()
153 }
154
155 pub fn sig0(&self) -> &[Record] {
157 &self.sig0
158 }
159
160 pub fn max_payload(&self) -> u16 {
164 let max_size = self.edns.as_ref().map_or(512, Edns::max_payload);
165 if max_size < 512 {
166 512
167 } else {
168 max_size
169 }
170 }
171
172 pub fn version(&self) -> u8 {
176 self.edns.as_ref().map_or(0, Edns::version)
177 }
178
179 pub(crate) fn raw_query(&self) -> &WireQuery {
181 &self.query
182 }
183}
184
185impl<'q> BinDecodable<'q> for MessageRequest {
186 fn read(decoder: &mut BinDecoder<'q>) -> ProtoResult<Self> {
189 let mut header = Header::read(decoder)?;
190
191 let mut try_parse_rest = move || {
192 let query_count = header.query_count() as usize;
194 let answer_count = header.answer_count() as usize;
195 let name_server_count = header.name_server_count() as usize;
196 let additional_count = header.additional_count() as usize;
197
198 let queries = Queries::read(decoder, query_count)?;
199 let query = queries.try_into_query()?;
200 let (answers, _, _) = Message::read_records(decoder, answer_count, false)?;
201 let (name_servers, _, _) = Message::read_records(decoder, name_server_count, false)?;
202 let (additionals, edns, sig0) = Message::read_records(decoder, additional_count, true)?;
203
204 if let Some(edns) = &edns {
206 let high_response_code = edns.rcode_high();
207 header.merge_response_code(high_response_code);
208 }
209
210 Ok(Self {
211 header,
212 query,
213 answers,
214 name_servers,
215 additionals,
216 sig0,
217 edns,
218 })
219 };
220
221 match try_parse_rest() {
222 Ok(message) => Ok(message),
223 Err(e) => Err(ProtoErrorKind::FormError {
224 header,
225 error: Box::new(e),
226 }
227 .into()),
228 }
229 }
230}
231
232#[derive(Debug, PartialEq, Eq)]
234pub struct Queries {
235 queries: Vec<LowerQuery>,
236 original: Box<[u8]>,
237}
238
239impl Queries {
240 fn read_queries(decoder: &mut BinDecoder<'_>, count: usize) -> ProtoResult<Vec<LowerQuery>> {
241 let mut queries = Vec::with_capacity(count);
242 for _ in 0..count {
243 queries.push(LowerQuery::read(decoder)?);
244 }
245 Ok(queries)
246 }
247
248 pub fn read(decoder: &mut BinDecoder<'_>, num_queries: usize) -> ProtoResult<Self> {
250 let queries_start = decoder.index();
251 let queries = Self::read_queries(decoder, num_queries)?;
252 let original = decoder
253 .slice_from(queries_start)?
254 .to_vec()
255 .into_boxed_slice();
256
257 Ok(Self { queries, original })
258 }
259
260 pub fn len(&self) -> usize {
262 self.queries.len()
263 }
264
265 pub fn is_empty(&self) -> bool {
267 self.queries.is_empty()
268 }
269
270 pub fn as_bytes(&self) -> &[u8] {
272 self.original.as_ref()
273 }
274
275 pub(crate) fn as_emit_and_count(&self) -> QueriesEmitAndCount<'_> {
276 QueriesEmitAndCount {
277 length: self.queries.len(),
278 first_query: self.queries.first(),
281 cached_serialized: self.original.as_ref(),
282 }
283 }
284
285 pub(crate) fn try_into_query(mut self) -> Result<WireQuery, ProtoError> {
287 let count = self.queries.len();
288 if count == 1 {
289 let query = self.queries.pop().expect("should have been at least one");
290
291 Ok(WireQuery {
292 query,
293 original: self.original,
294 })
295 } else {
296 Err(ProtoErrorKind::BadQueryCount(count).into())
297 }
298 }
299}
300
301#[derive(Debug, PartialEq)]
303pub(crate) struct WireQuery {
304 query: LowerQuery,
305 original: Box<[u8]>,
306}
307
308impl WireQuery {
309 pub(crate) fn as_emit_and_count(&self) -> QueriesEmitAndCount<'_> {
310 QueriesEmitAndCount {
311 length: 1,
312 first_query: Some(&self.query),
313 cached_serialized: self.original.as_ref(),
314 }
315 }
316}
317
318pub(crate) struct QueriesEmitAndCount<'q> {
319 length: usize,
321 first_query: Option<&'q LowerQuery>,
323 cached_serialized: &'q [u8],
325}
326
327impl EmitAndCount for QueriesEmitAndCount<'_> {
328 fn emit(&mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<usize> {
329 let original_offset = encoder.offset();
330 encoder.emit_vec(self.cached_serialized)?;
331 if !encoder.is_canonical_names() {
332 if let Some(query) = self.first_query {
333 encoder.store_label_pointer(
334 original_offset,
335 original_offset + query.original().name().len(),
336 )
337 }
338 }
339 Ok(self.length)
340 }
341}
342
343impl BinEncodable for MessageRequest {
344 fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
345 message::emit_message_parts(
346 &self.header,
347 &mut once(&self.query.query),
350 &mut self.answers.iter(),
351 &mut self.name_servers.iter(),
352 &mut self.additionals.iter(),
353 self.edns.as_ref(),
354 &self.sig0,
355 encoder,
356 )?;
357
358 Ok(())
359 }
360}
361
362pub trait UpdateRequest {
364 fn id(&self) -> u16;
366
367 fn zone(&self) -> &LowerQuery;
369
370 fn prerequisites(&self) -> &[Record];
372
373 fn updates(&self) -> &[Record];
375
376 fn additionals(&self) -> &[Record];
378
379 fn sig0(&self) -> &[Record];
381}
382
383impl UpdateRequest for MessageRequest {
384 fn id(&self) -> u16 {
385 Self::id(self)
386 }
387
388 fn zone(&self) -> &LowerQuery {
389 self.query()
390 }
391
392 fn prerequisites(&self) -> &[Record] {
393 self.answers()
394 }
395
396 fn updates(&self) -> &[Record] {
397 self.name_servers()
398 }
399
400 fn additionals(&self) -> &[Record] {
401 self.additionals()
402 }
403
404 fn sig0(&self) -> &[Record] {
405 self.sig0()
406 }
407}