1use alloc::{boxed::Box, fmt, vec::Vec};
11use core::{iter, mem, ops::Deref};
12
13#[cfg(feature = "serde")]
14use serde::{Deserialize, Serialize};
15#[cfg(feature = "__dnssec")]
16use tracing::debug;
17use tracing::warn;
18
19#[cfg(feature = "__dnssec")]
20use crate::dnssec::{DnssecIter, rdata::DNSSECRData};
21#[cfg(any(feature = "std", feature = "no-std-rand"))]
22use crate::random;
23#[cfg(feature = "__dnssec")]
24use crate::rr::{TSigVerifier, TSigner};
25use crate::{
26 error::{ProtoError, ProtoResult},
27 op::{Edns, Header, HeaderCounts, MessageType, Metadata, OpCode, Query, ResponseCode},
28 rr::{RData, Record, RecordData, RecordType, rdata::TSIG},
29 serialize::binary::{BinDecodable, BinDecoder, BinEncodable, BinEncoder, DecodeError},
30};
31
/// A DNS message: header metadata, the query section, the answer/authority/additional record
/// sections, and the optional EDNS data and TSIG signature.
///
/// When decoded from the wire, an OPT record in the additional section is lifted into `edns`,
/// and (with the `__dnssec` feature) a TSIG record in the additional section is lifted into
/// `signature`; neither is kept in `additionals`. When encoded, the OPT record and then the
/// signature record are appended after the additional records.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct Message {
    /// Header fields other than the section counts (ID, message type, op code, flags and
    /// response code); the counts are computed when the message is encoded.
    pub metadata: Metadata,
    /// The query (question) section.
    pub queries: Vec<Query>,
    /// The answer section.
    pub answers: Vec<Record>,
    /// The authority section.
    pub authorities: Vec<Record>,
    /// The additional section, excluding the OPT and TSIG records (see `edns` / `signature`).
    pub additionals: Vec<Record>,
    /// The TSIG record that signs this message, if any; always emitted last.
    pub signature: Option<Box<Record<TSIG>>>,
    /// EDNS data, carried on the wire as an OPT record in the additional section.
    pub edns: Option<Edns>,
}
123
124impl Message {
125 #[cfg(any(feature = "std", feature = "no-std-rand"))]
127 pub fn query() -> Self {
128 Self::new(random(), MessageType::Query, OpCode::Query)
129 }
130
131 pub fn error_msg(id: u16, op_code: OpCode, response_code: ResponseCode) -> Self {
139 let mut message = Self::response(id, op_code);
140 message.metadata.response_code = response_code;
141 message
142 }
143
144 pub fn response(id: u16, op_code: OpCode) -> Self {
146 Self::new(id, MessageType::Response, op_code)
147 }
148
149 pub fn new(id: u16, message_type: MessageType, op_code: OpCode) -> Self {
151 Self {
152 metadata: Metadata::new(id, message_type, op_code),
153 queries: Vec::new(),
154 answers: Vec::new(),
155 authorities: Vec::new(),
156 additionals: Vec::new(),
157 signature: None,
158 edns: None,
159 }
160 }
161
162 pub fn truncate(&self) -> Self {
164 let mut metadata = self.metadata;
166 metadata.truncation = true;
167
168 let mut msg = Self::new(0, MessageType::Query, OpCode::Query);
169 msg.metadata = metadata;
170
171 msg.add_queries(self.queries.iter().cloned());
174 if let Some(edns) = self.edns.clone() {
175 msg.set_edns(edns);
176 }
177
178 msg
180 }
181
182 pub fn maybe_strip_dnssec_records(mut self, query_has_dnssec_ok: bool) -> Self {
195 if query_has_dnssec_ok {
196 return self;
197 }
198
199 let Some(query_type) = self.queries.first().map(|q| q.query_type()) else {
200 return self; };
202
203 let predicate = |record: &Record| {
204 let record_type = record.record_type();
205 record_type == query_type || !record_type.is_dnssec()
206 };
207
208 self.answers.retain(predicate);
209 self.authorities.retain(predicate);
210 self.additionals.retain(predicate);
211
212 self
213 }
214
215 pub fn add_query(&mut self, query: Query) -> &mut Self {
217 self.queries.push(query);
218 self
219 }
220
221 pub fn add_queries<Q, I>(&mut self, queries: Q) -> &mut Self
223 where
224 Q: IntoIterator<Item = Query, IntoIter = I>,
225 I: Iterator<Item = Query>,
226 {
227 for query in queries {
228 self.add_query(query);
229 }
230
231 self
232 }
233
234 pub fn add_answer(&mut self, record: Record) -> &mut Self {
236 self.answers.push(record);
237 self
238 }
239
240 pub fn add_answers<R, I>(&mut self, records: R) -> &mut Self
242 where
243 R: IntoIterator<Item = Record, IntoIter = I>,
244 I: Iterator<Item = Record>,
245 {
246 for record in records {
247 self.add_answer(record);
248 }
249
250 self
251 }
252
253 pub fn insert_answers(&mut self, records: Vec<Record>) {
259 assert!(self.answers.is_empty());
260 self.answers = records;
261 }
262
263 pub fn add_authority(&mut self, record: Record) -> &mut Self {
265 self.authorities.push(record);
266 self
267 }
268
269 pub fn add_authorities<R, I>(&mut self, records: R) -> &mut Self
271 where
272 R: IntoIterator<Item = Record, IntoIter = I>,
273 I: Iterator<Item = Record>,
274 {
275 for record in records {
276 self.add_authority(record);
277 }
278
279 self
280 }
281
282 pub fn insert_authorities(&mut self, records: Vec<Record>) {
288 assert!(self.authorities.is_empty());
289 self.authorities = records;
290 }
291
292 pub fn add_additional(&mut self, record: Record) -> &mut Self {
294 self.additionals.push(record);
295 self
296 }
297
298 pub fn add_additionals<R, I>(&mut self, records: R) -> &mut Self
300 where
301 R: IntoIterator<Item = Record, IntoIter = I>,
302 I: Iterator<Item = Record>,
303 {
304 for record in records {
305 self.add_additional(record);
306 }
307
308 self
309 }
310
311 pub fn insert_additionals(&mut self, records: Vec<Record>) {
317 assert!(self.additionals.is_empty());
318 self.additionals = records;
319 }
320
321 pub fn set_edns(&mut self, edns: Edns) -> &mut Self {
323 self.edns = Some(edns);
324 self
325 }
326
327 #[cfg(feature = "__dnssec")]
332 pub fn set_signature(&mut self, sig: Box<Record<TSIG>>) -> &mut Self {
333 self.signature = Some(sig);
334 self
335 }
336
337 pub fn into_response(mut self) -> Self {
339 self.metadata.message_type = MessageType::Response;
340 self
341 }
342
343 #[cfg(feature = "__dnssec")]
345 pub fn dnssec_answers(&self) -> DnssecIter<'_> {
346 DnssecIter::new(&self.answers)
347 }
348
349 pub fn take_all_sections(&mut self) -> impl Iterator<Item = Record> {
351 let (answers, authorities, additionals) = (
352 mem::take(&mut self.answers),
353 mem::take(&mut self.authorities),
354 mem::take(&mut self.additionals),
355 );
356 answers.into_iter().chain(authorities).chain(additionals)
357 }
358
359 pub fn all_sections(&self) -> impl Iterator<Item = &Record> {
361 self.answers
362 .iter()
363 .chain(self.authorities.iter())
364 .chain(self.additionals.iter())
365 }
366
367 pub fn max_payload(&self) -> u16 {
371 let max_size = self.edns.as_ref().map_or(512, Edns::max_payload);
372 if max_size < 512 { 512 } else { max_size }
373 }
374
375 pub fn version(&self) -> u8 {
379 self.edns.as_ref().map_or(0, Edns::version)
380 }
381
382 pub fn signature(&self) -> Option<&Record<TSIG>> {
386 self.signature.as_deref()
387 }
388
389 pub fn take_signature(&mut self) -> Option<Box<Record<TSIG>>> {
391 self.signature.take()
392 }
393
394 pub fn read_queries(decoder: &mut BinDecoder<'_>, count: usize) -> ProtoResult<Vec<Query>> {
396 let mut queries = Vec::with_capacity(count);
397 for _ in 0..count {
398 queries.push(Query::read(decoder)?);
399 }
400 Ok(queries)
401 }
402
403 #[cfg_attr(not(feature = "__dnssec"), allow(unused_mut))]
418 #[expect(clippy::type_complexity)]
419 pub fn read_records(
420 decoder: &mut BinDecoder<'_>,
421 count: usize,
422 is_additional: bool,
423 op: OpCode,
424 ) -> Result<(Vec<Record>, Option<Edns>, Option<Box<Record<TSIG>>>), DecodeError> {
425 let mut records: Vec<Record> = Vec::with_capacity(count);
426 let mut edns: Option<Edns> = None;
427 let mut sig = None;
428
429 for _ in 0..count {
430 let record = Record::read(decoder)?;
431 if op != OpCode::Update
432 && record.record_type() != RecordType::OPT
433 && record.data.is_update()
434 {
435 return Err(DecodeError::InvalidEmptyRecord);
436 }
437
438 if sig.is_some() {
440 return Err(DecodeError::RecordAfterSig);
441 }
442
443 if !is_additional
445 && matches!(
446 record.record_type(),
447 RecordType::OPT | RecordType::SIG | RecordType::TSIG
448 )
449 {
450 return Err(DecodeError::RecordNotInAdditionalSection(
451 record.record_type(),
452 ));
453 } else if !is_additional {
454 records.push(record);
455 continue;
456 }
457
458 match record.data {
459 #[cfg(feature = "__dnssec")]
460 RData::DNSSEC(DNSSECRData::SIG(_)) => {
461 warn!(
462 "message was SIG(0) signed, but support for SIG(0) message authentication was removed from hickory-dns"
463 );
464 records.push(record);
465 }
466 #[cfg(feature = "__dnssec")]
467 RData::TSIG(_) => {
468 sig = Some(Box::new(
469 record
470 .map(|data| match data {
471 RData::TSIG(tsig) => Some(tsig),
472 _ => None,
473 })
474 .unwrap(),
475 ))
476 }
477 RData::Update0(RecordType::OPT) | RData::OPT(_) => {
478 if edns.is_some() {
479 return Err(DecodeError::DuplicateEdns);
480 }
481 edns = Some((&record).into());
482 }
483 _ => {
484 records.push(record);
485 }
486 }
487 }
488
489 Ok((records, edns, sig))
490 }
491
492 pub fn from_vec(buffer: &[u8]) -> Result<Self, DecodeError> {
494 let mut decoder = BinDecoder::new(buffer);
495 Self::read(&mut decoder)
496 }
497
498 pub fn to_vec(&self) -> Result<Vec<u8>, ProtoError> {
500 let mut buffer = Vec::with_capacity(512);
504 {
505 let mut encoder = BinEncoder::new(&mut buffer);
506 self.emit(&mut encoder)?;
507 }
508
509 Ok(buffer)
510 }
511
512 #[cfg(feature = "__dnssec")]
516 pub fn finalize(
517 &mut self,
518 finalizer: &TSigner,
519 inception_time: u64,
520 ) -> ProtoResult<Option<TSigVerifier>> {
521 debug!("finalizing message: {:?}", self);
522
523 let (signature, verifier) = finalizer.sign_message(self, inception_time)?;
524 self.set_signature(signature);
525
526 Ok(verifier)
527 }
528}
529
530impl Deref for Message {
531 type Target = Metadata;
532
533 fn deref(&self) -> &Self::Target {
534 &self.metadata
535 }
536}
537
538fn count_was_truncated(result: ProtoResult<usize>) -> ProtoResult<(u16, bool)> {
540 let (count, truncated) = match result {
541 Ok(count) => (count, false),
542 Err(ProtoError::NotAllRecordsWritten { count }) => (count, true),
543 Err(e) => return Err(e),
544 };
545
546 match u16::try_from(count) {
547 Ok(count) => Ok((count, truncated)),
548 Err(_) => Err(ProtoError::Message(
549 "too many records to fit in header count",
550 )),
551 }
552}
553
554pub trait EmitAndCount {
556 fn emit(&mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<usize>;
558}
559
560impl<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable> EmitAndCount for I {
561 fn emit(&mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<usize> {
562 encoder.emit_all(self)
563 }
564}
565
566#[allow(clippy::too_many_arguments)]
572pub fn emit_message_parts<Q, A, N, D>(
573 metadata: &Metadata,
574 queries: &mut Q,
575 answers: &mut A,
576 authorities: &mut N,
577 additionals: &mut D,
578 edns: Option<&Edns>,
579 signature: Option<&Record<TSIG>>,
580 encoder: &mut BinEncoder<'_>,
581) -> ProtoResult<Header>
582where
583 Q: EmitAndCount,
584 A: EmitAndCount,
585 N: EmitAndCount,
586 D: EmitAndCount,
587{
588 let place = encoder.place::<Header>()?;
589
590 let query_count = queries.emit(encoder)?;
591 let answer_count = count_was_truncated(answers.emit(encoder))?;
594 let authority_count = count_was_truncated(authorities.emit(encoder))?;
595 let mut additional_count = count_was_truncated(additionals.emit(encoder))?;
596
597 if let Some(mut edns) = edns.cloned() {
598 edns.set_rcode_high(metadata.response_code.high());
600
601 let count = count_was_truncated(encoder.emit_all(iter::once(&Record::from(&edns))))?;
602 additional_count.0 += count.0;
603 additional_count.1 |= count.1;
604 } else if metadata.response_code.high() > 0 {
605 warn!(
606 "response code: {} for request: {} requires EDNS but none available",
607 metadata.response_code, metadata.id
608 );
609 }
610
611 let count = match signature {
615 Some(rec) => count_was_truncated(encoder.emit_all(iter::once(rec)))?,
616 None => (0, false),
617 };
618 additional_count.0 += count.0;
619 additional_count.1 |= count.1;
620
621 let counts = HeaderCounts {
622 queries: match u16::try_from(query_count) {
623 Ok(count) => count,
624 Err(_) => {
625 return Err(ProtoError::Message(
626 "too many queries to fit in header count",
627 ));
628 }
629 },
630 answers: answer_count.0,
631 authorities: authority_count.0,
632 additionals: additional_count.0,
633 };
634
635 let mut final_metadata = *metadata;
636 final_metadata.truncation =
637 metadata.truncation || answer_count.1 || authority_count.1 || additional_count.1;
638
639 let header = Header {
640 metadata: final_metadata,
641 counts,
642 };
643
644 place.replace(encoder, header)?;
645 Ok(header)
646}
647
648impl BinEncodable for Message {
649 fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
650 emit_message_parts(
651 &self.metadata,
652 &mut self.queries.iter(),
653 &mut self.answers.iter(),
654 &mut self.authorities.iter(),
655 &mut self.additionals.iter(),
656 self.edns.as_ref(),
657 self.signature.as_deref(),
658 encoder,
659 )?;
660
661 Ok(())
662 }
663}
664
665impl<'r> BinDecodable<'r> for Message {
666 fn read(decoder: &mut BinDecoder<'r>) -> Result<Self, DecodeError> {
667 let Header {
668 mut metadata,
669 counts,
670 } = Header::read(decoder)?;
671
672 let count = counts.queries as usize;
677 let mut queries = Vec::with_capacity(count);
678 for _ in 0..count {
679 queries.push(Query::read(decoder)?);
680 }
681
682 let (answers, _, _) =
683 Self::read_records(decoder, counts.answers as usize, false, metadata.op_code)?;
684 let (authorities, _, _) = Self::read_records(
685 decoder,
686 counts.authorities as usize,
687 false,
688 metadata.op_code,
689 )?;
690 let (additionals, edns, signature) =
691 Self::read_records(decoder, counts.additionals as usize, true, metadata.op_code)?;
692
693 if let Some(edns) = &edns {
695 let high_response_code = edns.rcode_high();
696 metadata.merge_response_code(high_response_code);
697 }
698
699 Ok(Self {
700 metadata,
701 queries,
702 answers,
703 authorities,
704 additionals,
705 signature,
706 edns,
707 })
708 }
709}
710
711impl fmt::Display for Message {
712 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
713 let write_query = |slice, f: &mut fmt::Formatter<'_>| -> Result<(), fmt::Error> {
714 for d in slice {
715 writeln!(f, ";; {d}")?;
716 }
717
718 Ok(())
719 };
720
721 let write_slice = |slice, f: &mut fmt::Formatter<'_>| -> Result<(), fmt::Error> {
722 for d in slice {
723 writeln!(f, "{d}")?;
724 }
725
726 Ok(())
727 };
728
729 writeln!(f, "; header {header}", header = self.metadata)?;
730
731 if let Some(edns) = &self.edns {
732 writeln!(f, "; edns {edns}")?;
733 }
734
735 writeln!(f, "; query")?;
736 write_query(&self.queries, f)?;
737
738 if self.metadata.message_type == MessageType::Response
739 || self.metadata.op_code == OpCode::Update
740 {
741 writeln!(f, "; answers {}", self.answers.len())?;
742 write_slice(&self.answers, f)?;
743 writeln!(f, "; authorities {}", self.authorities.len())?;
744 write_slice(&self.authorities, f)?;
745 writeln!(f, "; additionals {}", self.additionals.len())?;
746 write_slice(&self.additionals, f)?;
747 }
748
749 Ok(())
750 }
751}
752
#[cfg(test)]
mod tests {
    use super::*;

    use crate::rr::rdata::A;
    #[cfg(feature = "std")]
    use crate::rr::rdata::OPT;
    #[cfg(feature = "std")]
    use crate::rr::rdata::opt::{ClientSubnet, EdnsCode, EdnsOption};
    #[cfg(feature = "__dnssec")]
    use crate::rr::rdata::{TSIG, tsig::TsigAlgorithm};
    use crate::rr::{Name, RData};
    #[cfg(feature = "std")]
    use crate::std::net::IpAddr;
    #[cfg(feature = "std")]
    use crate::std::string::ToString;

    #[test]
    fn test_emit_and_read_header() {
        let mut message = Message::response(10, OpCode::Update);
        message.metadata.authoritative = true;
        message.metadata.truncation = false;
        message.metadata.recursion_desired = true;
        message.metadata.recursion_available = true;
        message.metadata.response_code = ResponseCode::ServFail;

        test_emit_and_read(message);
    }

    #[test]
    fn test_emit_and_read_query() {
        let mut message = Message::response(10, OpCode::Update);
        message.metadata.authoritative = true;
        message.metadata.truncation = true;
        message.metadata.recursion_desired = true;
        message.metadata.recursion_available = true;
        message.metadata.response_code = ResponseCode::ServFail;
        message.add_query(Query::new());

        test_emit_and_read(message);
    }

    #[test]
    fn test_emit_and_read_records() {
        let mut message = Message::response(10, OpCode::Update);
        message.metadata.authoritative = true;
        message.metadata.truncation = true;
        message.metadata.recursion_desired = true;
        message.metadata.recursion_available = true;
        message.metadata.authentic_data = true;
        message.metadata.checking_disabled = true;
        message.metadata.response_code = ResponseCode::ServFail;

        message.add_answer(Record::stub());
        message.add_authority(Record::stub());
        message.add_additional(Record::stub());

        test_emit_and_read(message);
    }

    /// Encodes `message`, decodes the bytes again, and asserts the round trip is lossless.
    fn test_emit_and_read(message: Message) {
        let got = get_message_after_emitting_and_reading(message.clone());
        assert_eq!(got, message);
    }

    #[test]
    fn test_header_counts_correction_after_emit_read() {
        let mut message = Message::response(10, OpCode::Update);
        message.metadata.authoritative = true;
        message.metadata.truncation = true;
        message.metadata.recursion_desired = true;
        message.metadata.recursion_available = true;
        message.metadata.authentic_data = true;
        message.metadata.checking_disabled = true;
        message.metadata.response_code = ResponseCode::ServFail;

        message.add_answer(Record::stub());
        message.add_authority(Record::stub());
        message.add_additional(Record::stub());

        let got = get_message_after_emitting_and_reading(message);
        assert_eq!(got.queries.len(), 0);
        assert_eq!(got.answers.len(), 1);
        assert_eq!(got.authorities.len(), 1);
        assert_eq!(got.additionals.len(), 1);
    }

    /// Encodes `message` and returns the result of decoding those bytes.
    fn get_message_after_emitting_and_reading(message: Message) -> Message {
        let mut byte_vec: Vec<u8> = Vec::with_capacity(512);
        {
            let mut encoder = BinEncoder::new(&mut byte_vec);
            message.emit(&mut encoder).unwrap();
        }

        let mut decoder = BinDecoder::new(&byte_vec);

        Message::read(&mut decoder).unwrap()
    }

    #[test]
    fn test_legit_message() {
        #[rustfmt::skip]
        let buf: Vec<u8> = vec![
            0x10, 0x00, 0x81, 0x80, // id = 4096, flags
            0x00, 0x01, // 1 query
            0x00, 0x01, // 1 answer
            0x00, 0x00, 0x00, 0x00, // no authority or additional records
            // query: www.example.com, type 1, class 1
            0x03, b'w', b'w', b'w',
            0x07, b'e', b'x', b'a', b'm', b'p', b'l', b'e',
            0x03, b'c', b'o', b'm',
            0x00,
            0x00, 0x01, 0x00, 0x01,
            // answer: name pointer to offset 12, type 1, class 1, TTL 2, 4 bytes of data
            0xC0, 0x0C,
            0x00, 0x01, 0x00, 0x01,
            0x00, 0x00, 0x00, 0x02,
            0x00, 0x04,
            0x5D, 0xB8, 0xD7, 0x0E,
        ];

        let mut decoder = BinDecoder::new(&buf);
        let message = Message::read(&mut decoder).unwrap();

        assert_eq!(message.id, 4_096);

        let mut buf: Vec<u8> = Vec::with_capacity(512);
        {
            let mut encoder = BinEncoder::new(&mut buf);
            message.emit(&mut encoder).unwrap();
        }

        let mut decoder = BinDecoder::new(&buf);
        let message = Message::read(&mut decoder).unwrap();

        assert_eq!(message.id, 4_096);
    }

    #[test]
    fn rdata_zero_roundtrip() {
        let buf = &[
            160, 160, 0, 13, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0,
        ];

        // Decoding this crafted message must be rejected rather than accepted.
        assert!(Message::from_bytes(buf).is_err());
    }

    #[test]
    fn nsec_deserialization() {
        // Regression input: this message must parse successfully.
        const CRASHING_MESSAGE: &[u8] = &[
            0, 0, 132, 0, 0, 0, 0, 1, 0, 0, 0, 1, 36, 49, 101, 48, 101, 101, 51, 100, 51, 45, 100,
            52, 50, 52, 45, 52, 102, 55, 56, 45, 57, 101, 52, 99, 45, 99, 51, 56, 51, 51, 55, 55,
            56, 48, 102, 50, 98, 5, 108, 111, 99, 97, 108, 0, 0, 1, 128, 1, 0, 0, 0, 120, 0, 4,
            192, 168, 1, 17, 36, 49, 101, 48, 101, 101, 51, 100, 51, 45, 100, 52, 50, 52, 45, 52,
            102, 55, 56, 45, 57, 101, 52, 99, 45, 99, 51, 56, 51, 51, 55, 55, 56, 48, 102, 50, 98,
            5, 108, 111, 99, 97, 108, 0, 0, 47, 128, 1, 0, 0, 0, 120, 0, 5, 192, 70, 0, 1, 64,
        ];

        Message::from_vec(CRASHING_MESSAGE).expect("failed to parse message");
    }

    #[test]
    fn test_read_records_unsigned() {
        let records = vec![
            Record::from_rdata(
                Name::from_labels(vec!["example", "com"]).unwrap(),
                300,
                RData::A(A::new(127, 0, 0, 1)),
            ),
            Record::from_rdata(
                Name::from_labels(vec!["www", "example", "com"]).unwrap(),
                300,
                RData::A(A::new(127, 0, 0, 1)),
            ),
        ];
        let result = encode_and_read_records(records.clone(), false);
        let (output_records, edns, signature) = result.unwrap();
        assert_eq!(output_records.len(), records.len());
        assert!(edns.is_none());
        assert!(signature.is_none());
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_read_records_edns() {
        let records = vec![
            Record::from_rdata(
                Name::from_labels(vec!["example", "com"]).unwrap(),
                300,
                RData::A(A::new(127, 0, 0, 1)),
            ),
            Record::from_rdata(
                Name::new(),
                0,
                RData::OPT(OPT::new(vec![(
                    EdnsCode::Subnet,
                    EdnsOption::Subnet(ClientSubnet::new(IpAddr::from([127, 0, 0, 1]), 0, 24)),
                )])),
            ),
        ];
        let result = encode_and_read_records(records, true);
        let (output_records, edns, signature) = result.unwrap();
        assert_eq!(output_records.len(), 1);
        assert!(edns.is_some());
        assert!(signature.is_none());
    }

    #[cfg(feature = "__dnssec")]
    #[test]
    fn test_read_records_tsig() {
        let records = vec![
            Record::from_rdata(
                Name::from_labels(vec!["example", "com"]).unwrap(),
                300,
                RData::A(A::new(127, 0, 0, 1)),
            ),
            Record::from_rdata(
                Name::from_labels(vec!["tsig", "example", "com"]).unwrap(),
                0,
                fake_tsig(),
            ),
        ];
        let result = encode_and_read_records(records, true);
        let (output_records, edns, signature) = result.unwrap();
        assert_eq!(output_records.len(), 1);
        assert!(edns.is_none());
        assert!(signature.is_some());
    }

    #[cfg(all(feature = "std", feature = "__dnssec"))]
    #[test]
    fn test_read_records_edns_tsig() {
        let records = vec![
            Record::from_rdata(
                Name::from_labels(vec!["example", "com"]).unwrap(),
                300,
                RData::A(A::new(127, 0, 0, 1)),
            ),
            Record::from_rdata(
                Name::new(),
                0,
                RData::OPT(OPT::new(vec![(
                    EdnsCode::Subnet,
                    EdnsOption::Subnet(ClientSubnet::new(IpAddr::from([127, 0, 0, 1]), 0, 24)),
                )])),
            ),
            Record::from_rdata(
                Name::from_labels(vec!["tsig", "example", "com"]).unwrap(),
                0,
                fake_tsig(),
            ),
        ];

        let result = encode_and_read_records(records, true);
        let (output_records, edns, signature) = result.unwrap();
        assert_eq!(output_records.len(), 1);
        assert!(edns.is_some());
        assert!(signature.is_some());
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_read_records_unsigned_multiple_edns() {
        let opt_record = Record::from_rdata(
            Name::new(),
            0,
            RData::OPT(OPT::new(vec![(
                EdnsCode::Subnet,
                EdnsOption::Subnet(ClientSubnet::new(IpAddr::from([127, 0, 0, 1]), 0, 24)),
            )])),
        );
        let error = encode_and_read_records(
            vec![
                opt_record.clone(),
                Record::from_rdata(
                    Name::from_labels(vec!["example", "com"]).unwrap(),
                    300,
                    RData::A(A::new(127, 0, 0, 1)),
                ),
                opt_record.clone(),
            ],
            true,
        )
        .unwrap_err();
        assert!(error.to_string().contains("more than one EDNS record"));
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_read_records_opt_not_additional() {
        let opt_record = Record::from_rdata(
            Name::new(),
            0,
            RData::OPT(OPT::new(vec![(
                EdnsCode::Subnet,
                EdnsOption::Subnet(ClientSubnet::new(IpAddr::from([127, 0, 0, 1]), 0, 24)),
            )])),
        );
        let err = encode_and_read_records(
            vec![
                opt_record.clone(),
                Record::from_rdata(
                    Name::from_labels(vec!["example", "com"]).unwrap(),
                    300,
                    RData::A(A::new(127, 0, 0, 1)),
                ),
            ],
            false,
        )
        .unwrap_err();
        assert!(
            err.to_string()
                .contains("record type OPT only allowed in additional")
        );
    }

    #[cfg(all(feature = "std", feature = "__dnssec"))]
    #[test]
    fn test_read_records_signed_multiple_edns() {
        let opt_record = Record::from_rdata(
            Name::new(),
            0,
            RData::OPT(OPT::new(vec![(
                EdnsCode::Subnet,
                EdnsOption::Subnet(ClientSubnet::new(IpAddr::from([127, 0, 0, 1]), 0, 24)),
            )])),
        );
        let error = encode_and_read_records(
            vec![
                opt_record.clone(),
                Record::from_rdata(
                    Name::from_labels(vec!["example", "com"]).unwrap(),
                    300,
                    RData::A(A::new(127, 0, 0, 1)),
                ),
                opt_record.clone(),
                Record::from_rdata(
                    Name::from_labels(vec!["tsig", "example", "com"]).unwrap(),
                    0,
                    fake_tsig(),
                ),
            ],
            true,
        )
        .unwrap_err();
        assert!(error.to_string().contains("more than one EDNS record"));
    }

    #[cfg(all(feature = "std", feature = "__dnssec"))]
    #[test]
    fn test_read_records_tsig_not_additional() {
        let err = encode_and_read_records(
            vec![
                Record::from_rdata(
                    Name::from_labels(vec!["example", "com"]).unwrap(),
                    300,
                    RData::A(A::new(127, 0, 0, 1)),
                ),
                Record::from_rdata(
                    Name::from_labels(vec!["tsig", "example", "com"]).unwrap(),
                    0,
                    fake_tsig(),
                ),
            ],
            false,
        )
        .unwrap_err();
        assert!(
            err.to_string()
                .contains("record type TSIG only allowed in additional")
        );
    }

    #[cfg(all(feature = "std", feature = "__dnssec"))]
    #[test]
    fn test_read_records_tsig_not_last() {
        let a_record = Record::from_rdata(
            Name::from_labels(vec!["example", "com"]).unwrap(),
            300,
            RData::A(A::new(127, 0, 0, 1)),
        );
        let error = encode_and_read_records(
            vec![
                a_record.clone(),
                Record::from_rdata(
                    Name::from_labels(vec!["tsig", "example", "com"]).unwrap(),
                    0,
                    fake_tsig(),
                ),
                a_record.clone(),
            ],
            true,
        )
        .unwrap_err()
        .to_string();
        assert!(error.contains("record after TSIG or SIG(0)"));
    }

    // NOTE(review): despite its name, this test builds its record with `fake_tsig()`, so it
    // exercises the TSIG path. `read_records` keeps a real SIG record as an ordinary record
    // and does not treat it as terminal.
    #[cfg(all(feature = "std", feature = "__dnssec"))]
    #[test]
    fn test_read_records_sig0_not_last() {
        let a_record = Record::from_rdata(
            Name::from_labels(vec!["example", "com"]).unwrap(),
            300,
            RData::A(A::new(127, 0, 0, 1)),
        );
        let error = encode_and_read_records(
            vec![
                a_record.clone(),
                Record::from_rdata(
                    Name::from_labels(vec!["sig0", "example", "com"]).unwrap(),
                    0,
                    fake_tsig(),
                ),
                a_record.clone(),
            ],
            true,
        )
        .unwrap_err()
        .to_string();
        assert!(error.contains("record after TSIG or SIG(0)"));
    }

    #[cfg(all(feature = "std", feature = "__dnssec"))]
    #[test]
    fn test_read_records_multiple_tsig() {
        let tsig_record = Record::from_rdata(
            Name::from_labels(vec!["tsig", "example", "com"]).unwrap(),
            0,
            fake_tsig(),
        );
        let error = encode_and_read_records(
            vec![
                Record::from_rdata(
                    Name::from_labels(vec!["example", "com"]).unwrap(),
                    300,
                    RData::A(A::new(127, 0, 0, 1)),
                ),
                tsig_record.clone(),
                tsig_record.clone(),
            ],
            true,
        )
        .unwrap_err()
        .to_string();
        assert!(error.contains("record after TSIG or SIG(0)"));
    }

    // NOTE(review): like `test_read_records_sig0_not_last`, this uses `fake_tsig()`, so it
    // covers the TSIG path rather than a real SIG(0) record.
    #[cfg(all(feature = "std", feature = "__dnssec"))]
    #[test]
    fn test_read_records_multiple_sig0() {
        let sig0_record = Record::from_rdata(
            Name::from_labels(vec!["sig0", "example", "com"]).unwrap(),
            0,
            fake_tsig(),
        );
        let error = encode_and_read_records(
            vec![
                Record::from_rdata(
                    Name::from_labels(vec!["example", "com"]).unwrap(),
                    300,
                    RData::A(A::new(127, 0, 0, 1)),
                ),
                sig0_record.clone(),
                sig0_record.clone(),
            ],
            true,
        )
        .unwrap_err()
        .to_string();
        assert!(error.contains("record after TSIG or SIG(0)"));
    }

    /// Encodes `records` back to back and decodes them as one section of a query message.
    #[expect(clippy::type_complexity)]
    fn encode_and_read_records(
        records: Vec<Record>,
        is_additional: bool,
    ) -> ProtoResult<(Vec<Record>, Option<Edns>, Option<Box<Record<TSIG>>>)> {
        let mut bytes = Vec::new();
        let mut encoder = BinEncoder::new(&mut bytes);
        encoder.emit_all(records.iter())?;
        Ok(Message::read_records(
            &mut BinDecoder::new(&bytes),
            records.len(),
            is_additional,
            OpCode::Query,
        )?)
    }

    /// Builds TSIG record data with placeholder (all-zero / empty) field values.
    #[cfg(feature = "__dnssec")]
    fn fake_tsig() -> RData {
        RData::TSIG(TSIG::new(
            TsigAlgorithm::HmacSha256,
            0,
            0,
            vec![],
            0,
            None,
            vec![],
        ))
    }
}