1use alloc::{boxed::Box, fmt, vec::Vec};
11use core::{iter, mem, ops::Deref};
12
13#[cfg(feature = "serde")]
14use serde::{Deserialize, Serialize};
15use tracing::{debug, warn};
16
17use crate::{
18 error::*,
19 op::{Edns, Header, MessageType, OpCode, Query, ResponseCode},
20 rr::{Record, RecordType},
21 serialize::binary::{BinDecodable, BinDecoder, BinEncodable, BinEncoder, EncodeMode},
22 xfer::DnsResponse,
23};
24
/// A DNS message: header, query section, and the answer, name server (authority) and
/// additional record sections.
///
/// The EDNS data and any signature records are held in dedicated fields rather than in
/// `additionals`. When decoding, the OPT record (and, with the `__dnssec` feature, SIG/TSIG
/// records) is split out of the additional section. When encoding, these are written back at the
/// end of the additional section.
#[derive(Clone, Debug, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct Message {
    /// Header flags, identifiers and section counts.
    header: Header,
    /// Query (question) section.
    queries: Vec<Query>,
    /// Answer section.
    answers: Vec<Record>,
    /// Name server (authority) section.
    name_servers: Vec<Record>,
    /// Additional section, excluding the OPT and signature records.
    additionals: Vec<Record>,
    /// SIG(0) and/or TSIG records; emitted after everything else in the additional section.
    signature: Vec<Record>,
    /// EDNS data; emitted as an OPT record in the additional section.
    edns: Option<Edns>,
}
78
79impl Message {
80 pub fn new() -> Self {
82 Self {
83 header: Header::new(),
84 queries: Vec::new(),
85 answers: Vec::new(),
86 name_servers: Vec::new(),
87 additionals: Vec::new(),
88 signature: Vec::new(),
89 edns: None,
90 }
91 }
92
93 pub fn error_msg(id: u16, op_code: OpCode, response_code: ResponseCode) -> Self {
101 let mut message = Self::new();
102 message
103 .set_message_type(MessageType::Response)
104 .set_id(id)
105 .set_response_code(response_code)
106 .set_op_code(op_code);
107
108 message
109 }
110
111 pub fn truncate(&self) -> Self {
113 let mut header = self.header;
115 header.set_truncated(true);
116 header
117 .set_additional_count(0)
118 .set_answer_count(0)
119 .set_name_server_count(0);
120
121 let mut msg = Self::new();
122 msg.add_queries(self.queries().iter().cloned());
125 if let Some(edns) = self.extensions().clone() {
126 msg.set_edns(edns);
127 }
128 msg.set_header(header);
130
131 msg
133 }
134
135 pub fn set_header(&mut self, header: Header) -> &mut Self {
137 self.header = header;
138 self
139 }
140
141 pub fn set_id(&mut self, id: u16) -> &mut Self {
143 self.header.set_id(id);
144 self
145 }
146
147 pub fn set_message_type(&mut self, message_type: MessageType) -> &mut Self {
149 self.header.set_message_type(message_type);
150 self
151 }
152
153 pub fn set_op_code(&mut self, op_code: OpCode) -> &mut Self {
155 self.header.set_op_code(op_code);
156 self
157 }
158
159 pub fn set_authoritative(&mut self, authoritative: bool) -> &mut Self {
161 self.header.set_authoritative(authoritative);
162 self
163 }
164
165 pub fn set_truncated(&mut self, truncated: bool) -> &mut Self {
167 self.header.set_truncated(truncated);
168 self
169 }
170
171 pub fn set_recursion_desired(&mut self, recursion_desired: bool) -> &mut Self {
173 self.header.set_recursion_desired(recursion_desired);
174 self
175 }
176
177 pub fn set_recursion_available(&mut self, recursion_available: bool) -> &mut Self {
179 self.header.set_recursion_available(recursion_available);
180 self
181 }
182
183 pub fn set_authentic_data(&mut self, authentic_data: bool) -> &mut Self {
185 self.header.set_authentic_data(authentic_data);
186 self
187 }
188
189 pub fn set_checking_disabled(&mut self, checking_disabled: bool) -> &mut Self {
191 self.header.set_checking_disabled(checking_disabled);
192 self
193 }
194
195 pub fn set_response_code(&mut self, response_code: ResponseCode) -> &mut Self {
197 self.header.set_response_code(response_code);
198 self
199 }
200
201 pub fn set_query_count(&mut self, query_count: u16) -> &mut Self {
206 self.header.set_query_count(query_count);
207 self
208 }
209
210 pub fn set_answer_count(&mut self, answer_count: u16) -> &mut Self {
215 self.header.set_answer_count(answer_count);
216 self
217 }
218
219 pub fn set_name_server_count(&mut self, name_server_count: u16) -> &mut Self {
224 self.header.set_name_server_count(name_server_count);
225 self
226 }
227
228 pub fn set_additional_count(&mut self, additional_count: u16) -> &mut Self {
233 self.header.set_additional_count(additional_count);
234 self
235 }
236
237 pub fn add_query(&mut self, query: Query) -> &mut Self {
239 self.queries.push(query);
240 self
241 }
242
243 pub fn add_queries<Q, I>(&mut self, queries: Q) -> &mut Self
245 where
246 Q: IntoIterator<Item = Query, IntoIter = I>,
247 I: Iterator<Item = Query>,
248 {
249 for query in queries {
250 self.add_query(query);
251 }
252
253 self
254 }
255
256 pub fn add_answer(&mut self, record: Record) -> &mut Self {
258 self.answers.push(record);
259 self
260 }
261
262 pub fn add_answers<R, I>(&mut self, records: R) -> &mut Self
264 where
265 R: IntoIterator<Item = Record, IntoIter = I>,
266 I: Iterator<Item = Record>,
267 {
268 for record in records {
269 self.add_answer(record);
270 }
271
272 self
273 }
274
275 pub fn insert_answers(&mut self, records: Vec<Record>) {
281 assert!(self.answers.is_empty());
282 self.answers = records;
283 }
284
285 pub fn add_name_server(&mut self, record: Record) -> &mut Self {
287 self.name_servers.push(record);
288 self
289 }
290
291 pub fn add_name_servers<R, I>(&mut self, records: R) -> &mut Self
293 where
294 R: IntoIterator<Item = Record, IntoIter = I>,
295 I: Iterator<Item = Record>,
296 {
297 for record in records {
298 self.add_name_server(record);
299 }
300
301 self
302 }
303
304 pub fn insert_name_servers(&mut self, records: Vec<Record>) {
310 assert!(self.name_servers.is_empty());
311 self.name_servers = records;
312 }
313
314 pub fn add_additional(&mut self, record: Record) -> &mut Self {
316 self.additionals.push(record);
317 self
318 }
319
320 pub fn add_additionals<R, I>(&mut self, records: R) -> &mut Self
322 where
323 R: IntoIterator<Item = Record, IntoIter = I>,
324 I: Iterator<Item = Record>,
325 {
326 for record in records {
327 self.add_additional(record);
328 }
329
330 self
331 }
332
333 pub fn insert_additionals(&mut self, records: Vec<Record>) {
339 assert!(self.additionals.is_empty());
340 self.additionals = records;
341 }
342
343 pub fn set_edns(&mut self, edns: Edns) -> &mut Self {
345 self.edns = Some(edns);
346 self
347 }
348
349 #[cfg(feature = "__dnssec")]
353 pub fn add_sig0(&mut self, record: Record) -> &mut Self {
354 assert_eq!(RecordType::SIG, record.record_type());
355 self.signature.push(record);
356 self
357 }
358
359 #[cfg(feature = "__dnssec")]
363 pub fn add_tsig(&mut self, record: Record) -> &mut Self {
364 assert_eq!(RecordType::TSIG, record.record_type());
365 self.signature.push(record);
366 self
367 }
368
369 pub fn header(&self) -> &Header {
371 &self.header
372 }
373
374 pub fn id(&self) -> u16 {
376 self.header.id()
377 }
378
379 pub fn message_type(&self) -> MessageType {
381 self.header.message_type()
382 }
383
384 pub fn op_code(&self) -> OpCode {
386 self.header.op_code()
387 }
388
389 pub fn authoritative(&self) -> bool {
391 self.header.authoritative()
392 }
393
394 pub fn truncated(&self) -> bool {
396 self.header.truncated()
397 }
398
399 pub fn recursion_desired(&self) -> bool {
401 self.header.recursion_desired()
402 }
403
404 pub fn recursion_available(&self) -> bool {
406 self.header.recursion_available()
407 }
408
409 pub fn authentic_data(&self) -> bool {
411 self.header.authentic_data()
412 }
413
414 pub fn checking_disabled(&self) -> bool {
416 self.header.checking_disabled()
417 }
418
419 pub fn response_code(&self) -> ResponseCode {
424 self.header.response_code()
425 }
426
427 pub fn query(&self) -> Option<&Query> {
432 self.queries.first()
433 }
434
435 pub fn queries(&self) -> &[Query] {
439 &self.queries
440 }
441
442 pub fn queries_mut(&mut self) -> &mut Vec<Query> {
444 &mut self.queries
445 }
446
447 pub fn take_queries(&mut self) -> Vec<Query> {
449 mem::take(&mut self.queries)
450 }
451
452 pub fn answers(&self) -> &[Record] {
456 &self.answers
457 }
458
459 pub fn answers_mut(&mut self) -> &mut Vec<Record> {
461 &mut self.answers
462 }
463
464 pub fn take_answers(&mut self) -> Vec<Record> {
466 mem::take(&mut self.answers)
467 }
468
469 pub fn name_servers(&self) -> &[Record] {
475 &self.name_servers
476 }
477
478 pub fn name_servers_mut(&mut self) -> &mut Vec<Record> {
480 &mut self.name_servers
481 }
482
483 pub fn take_name_servers(&mut self) -> Vec<Record> {
485 mem::take(&mut self.name_servers)
486 }
487
488 pub fn additionals(&self) -> &[Record] {
493 &self.additionals
494 }
495
496 pub fn additionals_mut(&mut self) -> &mut Vec<Record> {
498 &mut self.additionals
499 }
500
501 pub fn take_additionals(&mut self) -> Vec<Record> {
503 mem::take(&mut self.additionals)
504 }
505
506 pub fn all_sections(&self) -> impl Iterator<Item = &Record> {
508 self.answers
509 .iter()
510 .chain(self.name_servers().iter())
511 .chain(self.additionals.iter())
512 }
513
514 #[deprecated(note = "Please use `extensions()`")]
544 pub fn edns(&self) -> Option<&Edns> {
545 self.edns.as_ref()
546 }
547
548 #[deprecated(
550 note = "Please use `extensions_mut()`. You can chain `.get_or_insert_with(Edns::new)` to recover original behavior of adding Edns if not present"
551 )]
552 pub fn edns_mut(&mut self) -> &mut Edns {
553 if self.edns.is_none() {
554 self.set_edns(Edns::new());
555 }
556 self.edns.as_mut().unwrap()
557 }
558
559 pub fn extensions(&self) -> &Option<Edns> {
561 &self.edns
562 }
563
564 pub fn extensions_mut(&mut self) -> &mut Option<Edns> {
566 &mut self.edns
567 }
568
569 pub fn max_payload(&self) -> u16 {
573 let max_size = self.edns.as_ref().map_or(512, Edns::max_payload);
574 if max_size < 512 { 512 } else { max_size }
575 }
576
577 pub fn version(&self) -> u8 {
581 self.edns.as_ref().map_or(0, Edns::version)
582 }
583
584 pub fn sig0(&self) -> &[Record] {
601 &self.signature
602 }
603
604 pub fn signature(&self) -> &[Record] {
619 &self.signature
620 }
621
622 pub fn take_signature(&mut self) -> Vec<Record> {
624 mem::take(&mut self.signature)
625 }
626
627 #[cfg(test)]
631 pub fn update_counts(&mut self) -> &mut Self {
632 self.header = update_header_counts(
633 &self.header,
634 self.truncated(),
635 HeaderCounts {
636 query_count: self.queries.len(),
637 answer_count: self.answers.len(),
638 nameserver_count: self.name_servers.len(),
639 additional_count: self.additionals.len(),
640 },
641 );
642 self
643 }
644
645 pub fn read_queries(decoder: &mut BinDecoder<'_>, count: usize) -> ProtoResult<Vec<Query>> {
647 let mut queries = Vec::with_capacity(count);
648 for _ in 0..count {
649 queries.push(Query::read(decoder)?);
650 }
651 Ok(queries)
652 }
653
654 #[cfg_attr(not(feature = "__dnssec"), allow(unused_mut))]
660 pub fn read_records(
661 decoder: &mut BinDecoder<'_>,
662 count: usize,
663 is_additional: bool,
664 ) -> ProtoResult<(Vec<Record>, Option<Edns>, Vec<Record>)> {
665 let mut records: Vec<Record> = Vec::with_capacity(count);
666 let mut edns: Option<Edns> = None;
667 let mut sigs: Vec<Record> = Vec::with_capacity(if is_additional { 1 } else { 0 });
668
669 let mut saw_sig0 = false;
671 let mut saw_tsig = false;
673 for _ in 0..count {
674 let record = Record::read(decoder)?;
675 if saw_tsig {
676 return Err("tsig must be final resource record".into());
677 } if !is_additional {
679 if saw_sig0 {
680 return Err("sig0 must be final resource record".into());
681 } records.push(record)
683 } else {
684 match record.record_type() {
685 #[cfg(feature = "__dnssec")]
686 RecordType::SIG => {
687 saw_sig0 = true;
688 sigs.push(record);
689 }
690 #[cfg(feature = "__dnssec")]
691 RecordType::TSIG => {
692 if saw_sig0 {
693 return Err("sig0 must be final resource record".into());
694 } saw_tsig = true;
696 sigs.push(record);
697 }
698 RecordType::OPT => {
699 if saw_sig0 {
700 return Err("sig0 must be final resource record".into());
701 } if edns.is_some() {
703 return Err("more than one edns record present".into());
704 }
705 edns = Some((&record).into());
706 }
707 _ => {
708 if saw_sig0 {
709 return Err("sig0 must be final resource record".into());
710 } records.push(record);
712 }
713 }
714 }
715 }
716
717 Ok((records, edns, sigs))
718 }
719
720 pub fn from_vec(buffer: &[u8]) -> ProtoResult<Self> {
722 let mut decoder = BinDecoder::new(buffer);
723 Self::read(&mut decoder)
724 }
725
726 pub fn to_vec(&self) -> Result<Vec<u8>, ProtoError> {
728 let mut buffer = Vec::with_capacity(512);
732 {
733 let mut encoder = BinEncoder::new(&mut buffer);
734 self.emit(&mut encoder)?;
735 }
736
737 Ok(buffer)
738 }
739
740 #[allow(clippy::match_single_binding)]
744 pub fn finalize(
745 &mut self,
746 finalizer: &dyn MessageFinalizer,
747 inception_time: u32,
748 ) -> ProtoResult<Option<MessageVerifier>> {
749 debug!("finalizing message: {:?}", self);
750 let (finals, verifier): (Vec<Record>, Option<MessageVerifier>) =
751 finalizer.finalize_message(self, inception_time)?;
752
753 for fin in finals {
755 match fin.record_type() {
756 #[cfg(feature = "__dnssec")]
758 RecordType::SIG => self.add_sig0(fin),
759 #[cfg(feature = "__dnssec")]
760 RecordType::TSIG => self.add_tsig(fin),
761 _ => self.add_additional(fin),
762 };
763 }
764
765 Ok(verifier)
766 }
767
768 pub fn into_parts(self) -> MessageParts {
770 self.into()
771 }
772}
773
774impl From<MessageParts> for Message {
775 fn from(msg: MessageParts) -> Self {
776 let MessageParts {
777 header,
778 queries,
779 answers,
780 name_servers,
781 additionals,
782 sig0,
783 edns,
784 } = msg;
785 Self {
786 header,
787 queries,
788 answers,
789 name_servers,
790 additionals,
791 signature: sig0,
792 edns,
793 }
794 }
795}
796
797impl Deref for Message {
798 type Target = Header;
799
800 fn deref(&self) -> &Self::Target {
801 &self.header
802 }
803}
804
/// The fields of a [`Message`], exposed publicly so they can be destructured and taken by value.
///
/// Produced by [`Message::into_parts`] and convertible back with `Message::from`.
#[derive(Clone, Debug, PartialEq, Eq, Default)]
pub struct MessageParts {
    /// Message header.
    pub header: Header,
    /// Query section.
    pub queries: Vec<Query>,
    /// Answer section.
    pub answers: Vec<Record>,
    /// Name server (authority) section.
    pub name_servers: Vec<Record>,
    /// Additional section, excluding the OPT and signature records.
    pub additionals: Vec<Record>,
    /// Signature records (SIG(0) and/or TSIG); maps to the message's signature section.
    pub sig0: Vec<Record>,
    /// EDNS data, if any.
    pub edns: Option<Edns>,
}
832
833impl From<Message> for MessageParts {
834 fn from(msg: Message) -> Self {
835 let Message {
836 header,
837 queries,
838 answers,
839 name_servers,
840 additionals,
841 signature,
842 edns,
843 } = msg;
844 Self {
845 header,
846 queries,
847 answers,
848 name_servers,
849 additionals,
850 sig0: signature,
851 edns,
852 }
853 }
854}
855
/// Section counts to write into a header with [`update_header_counts`].
#[derive(Clone, Copy, Debug)]
pub struct HeaderCounts {
    /// Number of queries.
    pub query_count: usize,
    /// Number of answer records.
    pub answer_count: usize,
    /// Number of name server (authority) records.
    pub nameserver_count: usize,
    /// Number of additional records, including any OPT and signature records that were emitted.
    pub additional_count: usize,
}
870
871pub fn update_header_counts(
873 current_header: &Header,
874 is_truncated: bool,
875 counts: HeaderCounts,
876) -> Header {
877 assert!(counts.query_count <= u16::MAX as usize);
878 assert!(counts.answer_count <= u16::MAX as usize);
879 assert!(counts.nameserver_count <= u16::MAX as usize);
880 assert!(counts.additional_count <= u16::MAX as usize);
881
882 let mut header = *current_header;
884 header
885 .set_query_count(counts.query_count as u16)
886 .set_answer_count(counts.answer_count as u16)
887 .set_name_server_count(counts.nameserver_count as u16)
888 .set_additional_count(counts.additional_count as u16)
889 .set_truncated(is_truncated);
890
891 header
892}
893
/// Callback returned by [`MessageFinalizer::finalize_message`].
///
/// It is given the raw bytes of a received message and yields a decoded `DnsResponse` or an
/// error. Presumably it verifies the response's signature; confirm against implementors.
pub type MessageVerifier = Box<dyn FnMut(&[u8]) -> ProtoResult<DnsResponse> + Send>;
896
897pub trait MessageFinalizer: Send + Sync + 'static {
902 fn finalize_message(
914 &self,
915 message: &Message,
916 current_time: u32,
917 ) -> ProtoResult<(Vec<Record>, Option<MessageVerifier>)>;
918
919 fn should_finalize_message(&self, message: &Message) -> bool {
922 [OpCode::Update, OpCode::Notify].contains(&message.op_code())
923 || message
924 .queries()
925 .iter()
926 .any(|q| [RecordType::AXFR, RecordType::IXFR].contains(&q.query_type()))
927 }
928}
929
930pub fn count_was_truncated(result: ProtoResult<usize>) -> ProtoResult<(usize, bool)> {
932 match result {
933 Ok(count) => Ok((count, false)),
934 Err(e) => match e.kind() {
935 ProtoErrorKind::NotAllRecordsWritten { count } => Ok((*count, true)),
936 _ => Err(e),
937 },
938 }
939}
940
/// A source of encodable items that can write itself into an encoder and report how many
/// items it wrote.
pub trait EmitAndCount {
    /// Emits all items into `encoder` and returns how many were written.
    fn emit(&mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<usize>;
}
946
947impl<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable> EmitAndCount for I {
948 fn emit(&mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<usize> {
949 encoder.emit_all(self)
950 }
951}
952
953#[allow(clippy::too_many_arguments)]
959pub fn emit_message_parts<Q, A, N, D>(
960 header: &Header,
961 queries: &mut Q,
962 answers: &mut A,
963 name_servers: &mut N,
964 additionals: &mut D,
965 edns: Option<&Edns>,
966 signature: &[Record],
967 encoder: &mut BinEncoder<'_>,
968) -> ProtoResult<Header>
969where
970 Q: EmitAndCount,
971 A: EmitAndCount,
972 N: EmitAndCount,
973 D: EmitAndCount,
974{
975 let include_signature = encoder.mode() != EncodeMode::Signing;
976 let place = encoder.place::<Header>()?;
977
978 let query_count = queries.emit(encoder)?;
979 let answer_count = count_was_truncated(answers.emit(encoder))?;
982 let nameserver_count = count_was_truncated(name_servers.emit(encoder))?;
983 let mut additional_count = count_was_truncated(additionals.emit(encoder))?;
984
985 if let Some(mut edns) = edns.cloned() {
986 edns.set_rcode_high(header.response_code().high());
988
989 let count = count_was_truncated(encoder.emit_all(iter::once(&Record::from(&edns))))?;
990 additional_count.0 += count.0;
991 additional_count.1 |= count.1;
992 } else if header.response_code().high() > 0 {
993 warn!(
994 "response code: {} for request: {} requires EDNS but none available",
995 header.response_code(),
996 header.id()
997 );
998 }
999
1000 if include_signature {
1004 let count = count_was_truncated(encoder.emit_all(signature.iter()))?;
1005 additional_count.0 += count.0;
1006 additional_count.1 |= count.1;
1007 }
1008
1009 let counts = HeaderCounts {
1010 query_count,
1011 answer_count: answer_count.0,
1012 nameserver_count: nameserver_count.0,
1013 additional_count: additional_count.0,
1014 };
1015 let was_truncated =
1016 header.truncated() || answer_count.1 || nameserver_count.1 || additional_count.1;
1017
1018 let final_header = update_header_counts(header, was_truncated, counts);
1019 place.replace(encoder, final_header)?;
1020 Ok(final_header)
1021}
1022
1023impl BinEncodable for Message {
1024 fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
1025 emit_message_parts(
1026 &self.header,
1027 &mut self.queries.iter(),
1028 &mut self.answers.iter(),
1029 &mut self.name_servers.iter(),
1030 &mut self.additionals.iter(),
1031 self.edns.as_ref(),
1032 &self.signature,
1033 encoder,
1034 )?;
1035
1036 Ok(())
1037 }
1038}
1039
1040impl<'r> BinDecodable<'r> for Message {
1041 fn read(decoder: &mut BinDecoder<'r>) -> ProtoResult<Self> {
1042 let mut header = Header::read(decoder)?;
1043
1044 let count = header.query_count() as usize;
1049 let mut queries = Vec::with_capacity(count);
1050 for _ in 0..count {
1051 queries.push(Query::read(decoder)?);
1052 }
1053
1054 let answer_count = header.answer_count() as usize;
1056 let name_server_count = header.name_server_count() as usize;
1057 let additional_count = header.additional_count() as usize;
1058
1059 let (answers, _, _) = Self::read_records(decoder, answer_count, false)?;
1060 let (name_servers, _, _) = Self::read_records(decoder, name_server_count, false)?;
1061 let (additionals, edns, signature) = Self::read_records(decoder, additional_count, true)?;
1062
1063 if let Some(edns) = &edns {
1065 let high_response_code = edns.rcode_high();
1066 header.merge_response_code(high_response_code);
1067 }
1068
1069 Ok(Self {
1070 header,
1071 queries,
1072 answers,
1073 name_servers,
1074 additionals,
1075 signature,
1076 edns,
1077 })
1078 }
1079}
1080
1081impl fmt::Display for Message {
1082 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
1083 let write_query = |slice, f: &mut fmt::Formatter<'_>| -> Result<(), fmt::Error> {
1084 for d in slice {
1085 writeln!(f, ";; {d}")?;
1086 }
1087
1088 Ok(())
1089 };
1090
1091 let write_slice = |slice, f: &mut fmt::Formatter<'_>| -> Result<(), fmt::Error> {
1092 for d in slice {
1093 writeln!(f, "{d}")?;
1094 }
1095
1096 Ok(())
1097 };
1098
1099 writeln!(f, "; header {header}", header = self.header())?;
1100
1101 if let Some(edns) = self.extensions() {
1102 writeln!(f, "; edns {edns}")?;
1103 }
1104
1105 writeln!(f, "; query")?;
1106 write_query(self.queries(), f)?;
1107
1108 if self.header().message_type() == MessageType::Response
1109 || self.header().op_code() == OpCode::Update
1110 {
1111 writeln!(f, "; answers {}", self.answer_count())?;
1112 write_slice(self.answers(), f)?;
1113 writeln!(f, "; nameservers {}", self.name_server_count())?;
1114 write_slice(self.name_servers(), f)?;
1115 writeln!(f, "; additionals {}", self.additional_count())?;
1116 write_slice(self.additionals(), f)?;
1117 }
1118
1119 Ok(())
1120 }
1121}
1122
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `message` and decodes the resulting bytes back into a new `Message`.
    fn get_message_after_emitting_and_reading(message: &Message) -> Message {
        let mut byte_vec: Vec<u8> = Vec::with_capacity(512);
        {
            let mut encoder = BinEncoder::new(&mut byte_vec);
            message.emit(&mut encoder).unwrap();
        }

        let mut decoder = BinDecoder::new(&byte_vec);
        Message::read(&mut decoder).unwrap()
    }

    /// Asserts that `message` survives an encode/decode round trip unchanged.
    fn test_emit_and_read(message: Message) {
        let got = get_message_after_emitting_and_reading(&message);
        assert_eq!(got, message);
    }

    #[test]
    fn test_emit_and_read_header() {
        let mut message = Message::new();
        message
            .set_id(10)
            .set_message_type(MessageType::Response)
            .set_op_code(OpCode::Update)
            .set_authoritative(true)
            .set_truncated(false)
            .set_recursion_desired(true)
            .set_recursion_available(true)
            .set_response_code(ResponseCode::ServFail);

        test_emit_and_read(message);
    }

    #[test]
    fn test_emit_and_read_query() {
        let mut message = Message::new();
        message
            .set_id(10)
            .set_message_type(MessageType::Response)
            .set_op_code(OpCode::Update)
            .set_authoritative(true)
            .set_truncated(true)
            .set_recursion_desired(true)
            .set_recursion_available(true)
            .set_response_code(ResponseCode::ServFail)
            .add_query(Query::new())
            .update_counts();

        test_emit_and_read(message);
    }

    #[test]
    fn test_emit_and_read_records() {
        let mut message = Message::new();
        message
            .set_id(10)
            .set_message_type(MessageType::Response)
            .set_op_code(OpCode::Update)
            .set_authoritative(true)
            .set_truncated(true)
            .set_recursion_desired(true)
            .set_recursion_available(true)
            .set_authentic_data(true)
            .set_checking_disabled(true)
            .set_response_code(ResponseCode::ServFail);

        message.add_answer(Record::stub());
        message.add_name_server(Record::stub());
        message.add_additional(Record::stub());
        message.update_counts();

        test_emit_and_read(message);
    }

    #[test]
    fn test_header_counts_correction_after_emit_read() {
        let mut message = Message::new();

        message
            .set_id(10)
            .set_message_type(MessageType::Response)
            .set_op_code(OpCode::Update)
            .set_authoritative(true)
            .set_truncated(true)
            .set_recursion_desired(true)
            .set_recursion_available(true)
            .set_authentic_data(true)
            .set_checking_disabled(true)
            .set_response_code(ResponseCode::ServFail);

        message.add_answer(Record::stub());
        message.add_name_server(Record::stub());
        message.add_additional(Record::stub());

        // Deliberately wrong header counts: encoding must replace them with the real ones.
        message.set_query_count(1);
        message.set_answer_count(5);
        message.set_name_server_count(5);
        let got = get_message_after_emitting_and_reading(&message);

        // No query was added, so the emitted query count is 0 despite the header saying 1.
        assert_eq!(got.query_count(), 0);
        assert_eq!(got.answer_count(), 1);
        assert_eq!(got.name_server_count(), 1);
        assert_eq!(got.additional_count(), 1);
    }

    #[test]
    fn test_legit_message() {
        #[rustfmt::skip]
        let buf: Vec<u8> = vec![
            0x10, 0x00, 0x81, 0x80, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00,
            0x03, b'w', b'w', b'w', 0x07, b'e', b'x', b'a', b'm', b'p', b'l', b'e',
            0x03, b'c', b'o', b'm', 0x00,
            0x00, 0x01, 0x00, 0x01,
            0xC0, 0x0C,
            0x00, 0x01, 0x00, 0x01,
            0x00, 0x00, 0x00, 0x02,
            0x00, 0x04,
            0x5D, 0xB8, 0xD7, 0x0E,
        ];

        let mut decoder = BinDecoder::new(&buf);
        let message = Message::read(&mut decoder).unwrap();

        assert_eq!(message.id(), 4_096);

        let mut buf: Vec<u8> = Vec::with_capacity(512);
        {
            let mut encoder = BinEncoder::new(&mut buf);
            message.emit(&mut encoder).unwrap();
        }

        let mut decoder = BinDecoder::new(&buf);
        let message = Message::read(&mut decoder).unwrap();

        assert_eq!(message.id(), 4_096);
    }

    #[test]
    fn rdata_zero_roundtrip() {
        let buf = &[
            160, 160, 0, 13, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0,
        ];

        assert!(Message::from_bytes(buf).is_err());
    }

    #[test]
    fn nsec_deserialization() {
        const CRASHING_MESSAGE: &[u8] = &[
            0, 0, 132, 0, 0, 0, 0, 1, 0, 0, 0, 1, 36, 49, 101, 48, 101, 101, 51, 100, 51, 45, 100,
            52, 50, 52, 45, 52, 102, 55, 56, 45, 57, 101, 52, 99, 45, 99, 51, 56, 51, 51, 55, 55,
            56, 48, 102, 50, 98, 5, 108, 111, 99, 97, 108, 0, 0, 1, 128, 1, 0, 0, 0, 120, 0, 4,
            192, 168, 1, 17, 36, 49, 101, 48, 101, 101, 51, 100, 51, 45, 100, 52, 50, 52, 45, 52,
            102, 55, 56, 45, 57, 101, 52, 99, 45, 99, 51, 56, 51, 51, 55, 55, 56, 48, 102, 50, 98,
            5, 108, 111, 99, 97, 108, 0, 0, 47, 128, 1, 0, 0, 0, 120, 0, 5, 192, 70, 0, 1, 64,
        ];

        Message::from_vec(CRASHING_MESSAGE).expect("failed to parse message");
    }

    #[test]
    fn prior_to_pointer() {
        const MESSAGE: &[u8] = include_bytes!("../../tests/test-data/fuzz-prior-to-pointer.rdata");
        let message = Message::from_bytes(MESSAGE).expect("failed to parse message");
        let encoded = message.to_bytes().unwrap();
        Message::from_bytes(&encoded).expect("failed to parse encoded message");
    }
}