use alloc::{boxed::Box, fmt, vec::Vec};
use core::{iter, mem, ops::Deref};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "__dnssec")]
use tracing::debug;
use tracing::warn;
#[cfg(feature = "__dnssec")]
use crate::dnssec::{DnssecIter, rdata::DNSSECRData};
#[cfg(any(feature = "std", feature = "no-std-rand"))]
use crate::random;
#[cfg(feature = "__dnssec")]
use crate::rr::{TSigVerifier, TSigner};
use crate::{
error::{ProtoError, ProtoResult},
op::{Edns, Header, HeaderCounts, MessageType, Metadata, OpCode, Query, ResponseCode},
rr::{RData, Record, RecordData, RecordType, rdata::TSIG},
serialize::binary::{BinDecodable, BinDecoder, BinEncodable, BinEncoder, DecodeError},
};
/// A DNS message: header metadata plus the question, answer, authority and
/// additional sections.
///
/// When a message is decoded, the EDNS `OPT` record is not kept in
/// `additionals`; it is lifted into `edns`. With the `__dnssec` feature, a
/// `TSIG` record is likewise lifted into `signature`. On encode, both are
/// written back at the end of the additional section.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct Message {
    /// Header fields other than the section counts (ID, message type, opcode,
    /// flags, response code). The counts are derived from the sections.
    pub metadata: Metadata,
    /// The question section.
    pub queries: Vec<Query>,
    /// The answer section.
    pub answers: Vec<Record>,
    /// The authority section.
    pub authorities: Vec<Record>,
    /// The additional section, excluding the EDNS `OPT` and `TSIG` records.
    pub additionals: Vec<Record>,
    /// The TSIG record authenticating this message, if any.
    pub signature: Option<Box<Record<TSIG>>>,
    /// The EDNS information carried in the `OPT` pseudo-record, if any.
    pub edns: Option<Edns>,
}
impl Message {
#[cfg(any(feature = "std", feature = "no-std-rand"))]
pub fn query() -> Self {
Self::new(random(), MessageType::Query, OpCode::Query)
}
pub fn error_msg(id: u16, op_code: OpCode, response_code: ResponseCode) -> Self {
let mut message = Self::response(id, op_code);
message.metadata.response_code = response_code;
message
}
pub fn response(id: u16, op_code: OpCode) -> Self {
Self::new(id, MessageType::Response, op_code)
}
pub fn new(id: u16, message_type: MessageType, op_code: OpCode) -> Self {
Self {
metadata: Metadata::new(id, message_type, op_code),
queries: Vec::new(),
answers: Vec::new(),
authorities: Vec::new(),
additionals: Vec::new(),
signature: None,
edns: None,
}
}
pub fn truncate(&self) -> Self {
let mut metadata = self.metadata;
metadata.truncation = true;
let mut msg = Self::new(0, MessageType::Query, OpCode::Query);
msg.metadata = metadata;
msg.add_queries(self.queries.iter().cloned());
if let Some(edns) = self.edns.clone() {
msg.set_edns(edns);
}
msg
}
pub fn maybe_strip_dnssec_records(mut self, query_has_dnssec_ok: bool) -> Self {
if query_has_dnssec_ok {
return self;
}
let Some(query_type) = self.queries.first().map(|q| q.query_type()) else {
return self; };
let predicate = |record: &Record| {
let record_type = record.record_type();
record_type == query_type || !record_type.is_dnssec()
};
self.answers.retain(predicate);
self.authorities.retain(predicate);
self.additionals.retain(predicate);
self
}
pub fn add_query(&mut self, query: Query) -> &mut Self {
self.queries.push(query);
self
}
pub fn add_queries<Q, I>(&mut self, queries: Q) -> &mut Self
where
Q: IntoIterator<Item = Query, IntoIter = I>,
I: Iterator<Item = Query>,
{
for query in queries {
self.add_query(query);
}
self
}
pub fn add_answer(&mut self, record: Record) -> &mut Self {
self.answers.push(record);
self
}
pub fn add_answers<R, I>(&mut self, records: R) -> &mut Self
where
R: IntoIterator<Item = Record, IntoIter = I>,
I: Iterator<Item = Record>,
{
for record in records {
self.add_answer(record);
}
self
}
pub fn insert_answers(&mut self, records: Vec<Record>) {
assert!(self.answers.is_empty());
self.answers = records;
}
pub fn add_authority(&mut self, record: Record) -> &mut Self {
self.authorities.push(record);
self
}
pub fn add_authorities<R, I>(&mut self, records: R) -> &mut Self
where
R: IntoIterator<Item = Record, IntoIter = I>,
I: Iterator<Item = Record>,
{
for record in records {
self.add_authority(record);
}
self
}
pub fn insert_authorities(&mut self, records: Vec<Record>) {
assert!(self.authorities.is_empty());
self.authorities = records;
}
pub fn add_additional(&mut self, record: Record) -> &mut Self {
self.additionals.push(record);
self
}
pub fn add_additionals<R, I>(&mut self, records: R) -> &mut Self
where
R: IntoIterator<Item = Record, IntoIter = I>,
I: Iterator<Item = Record>,
{
for record in records {
self.add_additional(record);
}
self
}
pub fn insert_additionals(&mut self, records: Vec<Record>) {
assert!(self.additionals.is_empty());
self.additionals = records;
}
pub fn set_edns(&mut self, edns: Edns) -> &mut Self {
self.edns = Some(edns);
self
}
#[cfg(feature = "__dnssec")]
pub fn set_signature(&mut self, sig: Box<Record<TSIG>>) -> &mut Self {
self.signature = Some(sig);
self
}
pub fn into_response(mut self) -> Self {
self.metadata.message_type = MessageType::Response;
self
}
#[cfg(feature = "__dnssec")]
pub fn dnssec_answers(&self) -> DnssecIter<'_> {
DnssecIter::new(&self.answers)
}
pub fn take_all_sections(&mut self) -> impl Iterator<Item = Record> {
let (answers, authorities, additionals) = (
mem::take(&mut self.answers),
mem::take(&mut self.authorities),
mem::take(&mut self.additionals),
);
answers.into_iter().chain(authorities).chain(additionals)
}
pub fn all_sections(&self) -> impl Iterator<Item = &Record> {
self.answers
.iter()
.chain(self.authorities.iter())
.chain(self.additionals.iter())
}
pub fn max_payload(&self) -> u16 {
let max_size = self.edns.as_ref().map_or(512, Edns::max_payload);
if max_size < 512 { 512 } else { max_size }
}
pub fn version(&self) -> u8 {
self.edns.as_ref().map_or(0, Edns::version)
}
pub fn signature(&self) -> Option<&Record<TSIG>> {
self.signature.as_deref()
}
pub fn take_signature(&mut self) -> Option<Box<Record<TSIG>>> {
self.signature.take()
}
pub fn read_queries(decoder: &mut BinDecoder<'_>, count: usize) -> ProtoResult<Vec<Query>> {
let mut queries = Vec::with_capacity(count);
for _ in 0..count {
queries.push(Query::read(decoder)?);
}
Ok(queries)
}
#[cfg_attr(not(feature = "__dnssec"), allow(unused_mut))]
#[expect(clippy::type_complexity)]
pub fn read_records(
decoder: &mut BinDecoder<'_>,
count: usize,
is_additional: bool,
op: OpCode,
) -> Result<(Vec<Record>, Option<Edns>, Option<Box<Record<TSIG>>>), DecodeError> {
let mut records: Vec<Record> = Vec::with_capacity(count);
let mut edns: Option<Edns> = None;
let mut sig = None;
for _ in 0..count {
let record = Record::read(decoder)?;
if op != OpCode::Update
&& record.record_type() != RecordType::OPT
&& record.data.is_update()
{
return Err(DecodeError::InvalidEmptyRecord);
}
if sig.is_some() {
return Err(DecodeError::RecordAfterSig);
}
if !is_additional
&& matches!(
record.record_type(),
RecordType::OPT | RecordType::SIG | RecordType::TSIG
)
{
return Err(DecodeError::RecordNotInAdditionalSection(
record.record_type(),
));
} else if !is_additional {
records.push(record);
continue;
}
match record.data {
#[cfg(feature = "__dnssec")]
RData::DNSSEC(DNSSECRData::SIG(_)) => {
warn!(
"message was SIG(0) signed, but support for SIG(0) message authentication was removed from hickory-dns"
);
records.push(record);
}
#[cfg(feature = "__dnssec")]
RData::TSIG(_) => {
sig = Some(Box::new(
record
.map(|data| match data {
RData::TSIG(tsig) => Some(tsig),
_ => None,
})
.unwrap(),
))
}
RData::Update0(RecordType::OPT) | RData::OPT(_) => {
if edns.is_some() {
return Err(DecodeError::DuplicateEdns);
}
edns = Some((&record).into());
}
_ => {
records.push(record);
}
}
}
Ok((records, edns, sig))
}
pub fn from_vec(buffer: &[u8]) -> Result<Self, DecodeError> {
let mut decoder = BinDecoder::new(buffer);
Self::read(&mut decoder)
}
pub fn to_vec(&self) -> Result<Vec<u8>, ProtoError> {
let mut buffer = Vec::with_capacity(512);
{
let mut encoder = BinEncoder::new(&mut buffer);
self.emit(&mut encoder)?;
}
Ok(buffer)
}
#[cfg(feature = "__dnssec")]
pub fn finalize(
&mut self,
finalizer: &TSigner,
inception_time: u64,
) -> ProtoResult<Option<TSigVerifier>> {
debug!("finalizing message: {:?}", self);
let (signature, verifier) = finalizer.sign_message(self, inception_time)?;
self.set_signature(signature);
Ok(verifier)
}
}
/// Lets header fields be read directly on a `Message` (for example
/// `message.id`) without spelling out `.metadata`.
impl Deref for Message {
    type Target = Metadata;

    fn deref(&self) -> &Metadata {
        let Self { metadata, .. } = self;
        metadata
    }
}
/// Converts the result of emitting one section into a `(count, truncated)`
/// pair.
///
/// A `NotAllRecordsWritten` error is not fatal here. It yields the partial
/// count with `truncated == true`. Any other error is propagated, and a count
/// that does not fit in a `u16` header field becomes an error.
fn count_was_truncated(result: ProtoResult<usize>) -> ProtoResult<(u16, bool)> {
    let (written, truncated) = match result {
        Ok(written) => (written, false),
        Err(ProtoError::NotAllRecordsWritten { count }) => (count, true),
        Err(other) => return Err(other),
    };

    u16::try_from(written)
        .map(|header_count| (header_count, truncated))
        .map_err(|_| ProtoError::Message("too many records to fit in header count"))
}
/// A source of items that can be written into a [`BinEncoder`] while
/// reporting how many were written.
///
/// `emit_message_parts` uses this for every message section. It treats a
/// `ProtoError::NotAllRecordsWritten { count }` error from the answer,
/// authority or additional section as a partial (truncated) write rather than
/// a failure.
pub trait EmitAndCount {
    /// Emits the items into `encoder` and returns the number written.
    fn emit(&mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<usize>;
}
impl<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable> EmitAndCount for I {
fn emit(&mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<usize> {
encoder.emit_all(self)
}
}
/// Encodes a DNS message from its parts and back-patches the header.
///
/// Space for the [`Header`] is reserved first. After every section has been
/// written, the real header (section counts plus the final truncation flag) is
/// written into that space and also returned.
///
/// When `edns` is present, its high response-code bits are set from
/// `metadata.response_code` before it is encoded. When it is absent but the
/// response code needs those bits, a warning is logged. The OPT record and
/// `signature` are both counted in the additional section.
///
/// If the answer, authority or additional sections are only partially
/// written, the truncation (TC) flag is set in the resulting header. A
/// partially written query section is an error.
///
/// # Errors
///
/// Returns an error if encoding fails, or if any section count (including the
/// OPT and signature records added to the additional count) does not fit in a
/// `u16`.
#[allow(clippy::too_many_arguments)]
pub fn emit_message_parts<Q, A, N, D>(
    metadata: &Metadata,
    queries: &mut Q,
    answers: &mut A,
    authorities: &mut N,
    additionals: &mut D,
    edns: Option<&Edns>,
    signature: Option<&Record<TSIG>>,
    encoder: &mut BinEncoder<'_>,
) -> ProtoResult<Header>
where
    Q: EmitAndCount,
    A: EmitAndCount,
    N: EmitAndCount,
    D: EmitAndCount,
{
    // Folds an extra `(count, truncated)` pair into a running total. The
    // addition is checked so the `u16` header count can neither wrap (release)
    // nor panic (debug).
    fn add_count(
        (total, truncated): (u16, bool),
        (extra, extra_truncated): (u16, bool),
    ) -> ProtoResult<(u16, bool)> {
        let total = total.checked_add(extra).ok_or(ProtoError::Message(
            "too many records to fit in header count",
        ))?;
        Ok((total, truncated || extra_truncated))
    }

    let place = encoder.place::<Header>()?;
    let query_count = queries.emit(encoder)?;
    let answer_count = count_was_truncated(answers.emit(encoder))?;
    let authority_count = count_was_truncated(authorities.emit(encoder))?;
    let mut additional_count = count_was_truncated(additionals.emit(encoder))?;

    if let Some(mut edns) = edns.cloned() {
        // The high bits of an extended response code travel in the OPT record.
        edns.set_rcode_high(metadata.response_code.high());
        let count = count_was_truncated(encoder.emit_all(iter::once(&Record::from(&edns))))?;
        additional_count = add_count(additional_count, count)?;
    } else if metadata.response_code.high() > 0 {
        warn!(
            "response code: {} for request: {} requires EDNS but none available",
            metadata.response_code, metadata.id
        );
    }

    // The signature is emitted last and counted in the additional section.
    let count = match signature {
        Some(rec) => count_was_truncated(encoder.emit_all(iter::once(rec)))?,
        None => (0, false),
    };
    additional_count = add_count(additional_count, count)?;

    let counts = HeaderCounts {
        queries: u16::try_from(query_count)
            .map_err(|_| ProtoError::Message("too many queries to fit in header count"))?,
        answers: answer_count.0,
        authorities: authority_count.0,
        additionals: additional_count.0,
    };

    let mut final_metadata = *metadata;
    final_metadata.truncation =
        metadata.truncation || answer_count.1 || authority_count.1 || additional_count.1;

    let header = Header {
        metadata: final_metadata,
        counts,
    };
    place.replace(encoder, header)?;
    Ok(header)
}
impl BinEncodable for Message {
    /// Serializes the whole message through `emit_message_parts`, discarding
    /// the header it returns.
    fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
        let Self {
            metadata,
            queries,
            answers,
            authorities,
            additionals,
            signature,
            edns,
        } = self;

        emit_message_parts(
            metadata,
            &mut queries.iter(),
            &mut answers.iter(),
            &mut authorities.iter(),
            &mut additionals.iter(),
            edns.as_ref(),
            signature.as_deref(),
            encoder,
        )
        .map(|_header| ())
    }
}
impl<'r> BinDecodable<'r> for Message {
    /// Decodes a complete message: the header, then the question, answer,
    /// authority and additional sections.
    ///
    /// The EDNS `OPT` record and (with `__dnssec`) a `TSIG` record are lifted
    /// out of the additional section into `edns` and `signature`. When EDNS is
    /// present, its high response-code bits are merged into the header's
    /// response code.
    fn read(decoder: &mut BinDecoder<'r>) -> Result<Self, DecodeError> {
        /// Smallest possible encoded question: one-byte root QNAME + QTYPE + QCLASS.
        const MIN_QUERY_LEN: usize = 5;

        let Header {
            mut metadata,
            counts,
        } = Header::read(decoder)?;

        // The header count is untrusted. Cap the preallocation at what the
        // remaining bytes could possibly hold, so a tiny packet claiming
        // 65535 questions cannot force a large allocation.
        let count = counts.queries as usize;
        let mut queries = Vec::with_capacity(count.min(decoder.len() / MIN_QUERY_LEN));
        for _ in 0..count {
            queries.push(Query::read(decoder)?);
        }

        let (answers, _, _) =
            Self::read_records(decoder, counts.answers as usize, false, metadata.op_code)?;
        let (authorities, _, _) = Self::read_records(
            decoder,
            counts.authorities as usize,
            false,
            metadata.op_code,
        )?;
        let (additionals, edns, signature) =
            Self::read_records(decoder, counts.additionals as usize, true, metadata.op_code)?;

        // Extended response codes carry their high bits in the OPT record.
        if let Some(edns) = &edns {
            let high_response_code = edns.rcode_high();
            metadata.merge_response_code(high_response_code);
        }

        Ok(Self {
            metadata,
            queries,
            answers,
            authorities,
            additionals,
            signature,
            edns,
        })
    }
}
impl fmt::Display for Message {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
let write_query = |slice, f: &mut fmt::Formatter<'_>| -> Result<(), fmt::Error> {
for d in slice {
writeln!(f, ";; {d}")?;
}
Ok(())
};
let write_slice = |slice, f: &mut fmt::Formatter<'_>| -> Result<(), fmt::Error> {
for d in slice {
writeln!(f, "{d}")?;
}
Ok(())
};
writeln!(f, "; header {header}", header = self.metadata)?;
if let Some(edns) = &self.edns {
writeln!(f, "; edns {edns}")?;
}
writeln!(f, "; query")?;
write_query(&self.queries, f)?;
if self.metadata.message_type == MessageType::Response
|| self.metadata.op_code == OpCode::Update
{
writeln!(f, "; answers {}", self.answers.len())?;
write_slice(&self.answers, f)?;
writeln!(f, "; authorities {}", self.authorities.len())?;
write_slice(&self.authorities, f)?;
writeln!(f, "; additionals {}", self.additionals.len())?;
write_slice(&self.additionals, f)?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    // Round-trip, decoding and section-validation tests for `Message`.
    use super::*;
    use crate::rr::rdata::A;
    #[cfg(feature = "std")]
    use crate::rr::rdata::OPT;
    #[cfg(feature = "std")]
    use crate::rr::rdata::opt::{ClientSubnet, EdnsCode, EdnsOption};
    #[cfg(feature = "__dnssec")]
    use crate::rr::rdata::{TSIG, tsig::TsigAlgorithm};
    use crate::rr::{Name, RData};
    #[cfg(feature = "std")]
    use crate::std::net::IpAddr;
    #[cfg(feature = "std")]
    use crate::std::string::ToString;

    /// Header flags and response code survive an emit/read round-trip.
    #[test]
    fn test_emit_and_read_header() {
        let mut message = Message::response(10, OpCode::Update);
        message.metadata.authoritative = true;
        message.metadata.truncation = false;
        message.metadata.recursion_desired = true;
        message.metadata.recursion_available = true;
        message.metadata.response_code = ResponseCode::ServFail;
        test_emit_and_read(message);
    }

    /// A message with a single (default) query survives a round-trip.
    #[test]
    fn test_emit_and_read_query() {
        let mut message = Message::response(10, OpCode::Update);
        message.metadata.authoritative = true;
        message.metadata.truncation = true;
        message.metadata.recursion_desired = true;
        message.metadata.recursion_available = true;
        message.metadata.response_code = ResponseCode::ServFail;
        message.add_query(Query::new());
        test_emit_and_read(message);
    }

    /// One stub record in each record section survives a round-trip. The
    /// `Update` opcode allows the stub's update-style empty record data.
    #[test]
    fn test_emit_and_read_records() {
        let mut message = Message::response(10, OpCode::Update);
        message.metadata.authoritative = true;
        message.metadata.truncation = true;
        message.metadata.recursion_desired = true;
        message.metadata.recursion_available = true;
        message.metadata.authentic_data = true;
        message.metadata.checking_disabled = true;
        message.metadata.response_code = ResponseCode::ServFail;
        message.add_answer(Record::stub());
        message.add_authority(Record::stub());
        message.add_additional(Record::stub());
        test_emit_and_read(message);
    }

    /// Encodes `message`, decodes the bytes and asserts the result is equal.
    #[cfg(test)]
    fn test_emit_and_read(message: Message) {
        let mut byte_vec: Vec<u8> = Vec::with_capacity(512);
        {
            let mut encoder = BinEncoder::new(&mut byte_vec);
            message.emit(&mut encoder).unwrap();
        }
        let mut decoder = BinDecoder::new(&byte_vec);
        let got = Message::read(&mut decoder).unwrap();
        assert_eq!(got, message);
    }

    /// Section sizes after decoding match what was emitted.
    #[test]
    fn test_header_counts_correction_after_emit_read() {
        let mut message = Message::response(10, OpCode::Update);
        message.metadata.authoritative = true;
        message.metadata.truncation = true;
        message.metadata.recursion_desired = true;
        message.metadata.recursion_available = true;
        message.metadata.authentic_data = true;
        message.metadata.checking_disabled = true;
        message.metadata.response_code = ResponseCode::ServFail;
        message.add_answer(Record::stub());
        message.add_authority(Record::stub());
        message.add_additional(Record::stub());
        let got = get_message_after_emitting_and_reading(message);
        assert_eq!(got.queries.len(), 0);
        assert_eq!(got.answers.len(), 1);
        assert_eq!(got.authorities.len(), 1);
        assert_eq!(got.additionals.len(), 1);
    }

    /// Encodes `message` and returns whatever decoding those bytes produces.
    #[cfg(test)]
    fn get_message_after_emitting_and_reading(message: Message) -> Message {
        let mut byte_vec: Vec<u8> = Vec::with_capacity(512);
        {
            let mut encoder = BinEncoder::new(&mut byte_vec);
            message.emit(&mut encoder).unwrap();
        }
        let mut decoder = BinDecoder::new(&byte_vec);
        Message::read(&mut decoder).unwrap()
    }

    /// A hand-built response (one question for www.example.com, one A answer)
    /// decodes, and re-encoding then decoding it keeps the ID.
    #[test]
    fn test_legit_message() {
        // Header starts with ID 0x1000 (= 4096); then the question and a
        // compressed-name (0xC0 0x0C) A record.
        #[rustfmt::skip]
        let buf: Vec<u8> = vec![
            0x10, 0x00, 0x81,
            0x80, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x03, b'w', b'w', b'w', 0x07, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 0x03, b'c', b'o', b'm', 0x00, 0x00, 0x01, 0x00, 0x01, 0xC0, 0x0C, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x04, 0x5D, 0xB8, 0xD7, 0x0E, ];
        let mut decoder = BinDecoder::new(&buf);
        let message = Message::read(&mut decoder).unwrap();
        assert_eq!(message.id, 4_096);
        let mut buf: Vec<u8> = Vec::with_capacity(512);
        {
            let mut encoder = BinEncoder::new(&mut buf);
            message.emit(&mut encoder).unwrap();
        }
        let mut decoder = BinDecoder::new(&buf);
        let message = Message::read(&mut decoder).unwrap();
        assert_eq!(message.id, 4_096);
    }

    /// This malformed byte sequence must be rejected with an error (not a panic).
    #[test]
    fn rdata_zero_roundtrip() {
        let buf = &[
            160, 160, 0, 13, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0,
        ];
        assert!(Message::from_bytes(buf).is_err());
    }

    /// Regression input (named `CRASHING_MESSAGE`) containing an NSEC-typed
    /// (47) record. It must decode successfully.
    #[test]
    fn nsec_deserialization() {
        const CRASHING_MESSAGE: &[u8] = &[
            0, 0, 132, 0, 0, 0, 0, 1, 0, 0, 0, 1, 36, 49, 101, 48, 101, 101, 51, 100, 51, 45, 100,
            52, 50, 52, 45, 52, 102, 55, 56, 45, 57, 101, 52, 99, 45, 99, 51, 56, 51, 51, 55, 55,
            56, 48, 102, 50, 98, 5, 108, 111, 99, 97, 108, 0, 0, 1, 128, 1, 0, 0, 0, 120, 0, 4,
            192, 168, 1, 17, 36, 49, 101, 48, 101, 101, 51, 100, 51, 45, 100, 52, 50, 52, 45, 52,
            102, 55, 56, 45, 57, 101, 52, 99, 45, 99, 51, 56, 51, 51, 55, 55, 56, 48, 102, 50, 98,
            5, 108, 111, 99, 97, 108, 0, 0, 47, 128, 1, 0, 0, 0, 120, 0, 5, 192, 70, 0, 1, 64,
        ];
        Message::from_vec(CRASHING_MESSAGE).expect("failed to parse message");
    }

    /// Plain records in a non-additional section come back unchanged, with no
    /// EDNS or signature.
    #[test]
    fn test_read_records_unsigned() {
        let records = vec![
            Record::from_rdata(
                Name::from_labels(vec!["example", "com"]).unwrap(),
                300,
                RData::A(A::new(127, 0, 0, 1)),
            ),
            Record::from_rdata(
                Name::from_labels(vec!["www", "example", "com"]).unwrap(),
                300,
                RData::A(A::new(127, 0, 0, 1)),
            ),
        ];
        let result = encode_and_read_records(records.clone(), false);
        let (output_records, edns, signature) = result.unwrap();
        assert_eq!(output_records.len(), records.len());
        assert!(edns.is_none());
        assert!(signature.is_none());
    }

    /// An OPT record in the additional section is lifted into `edns`.
    #[cfg(feature = "std")]
    #[test]
    fn test_read_records_edns() {
        let records = vec![
            Record::from_rdata(
                Name::from_labels(vec!["example", "com"]).unwrap(),
                300,
                RData::A(A::new(127, 0, 0, 1)),
            ),
            Record::from_rdata(
                Name::new(),
                0,
                RData::OPT(OPT::new(vec![(
                    EdnsCode::Subnet,
                    EdnsOption::Subnet(ClientSubnet::new(IpAddr::from([127, 0, 0, 1]), 0, 24)),
                )])),
            ),
        ];
        let result = encode_and_read_records(records, true);
        let (output_records, edns, signature) = result.unwrap();
        assert_eq!(output_records.len(), 1); assert!(edns.is_some());
        assert!(signature.is_none());
    }

    /// A trailing TSIG record in the additional section is lifted into the
    /// signature.
    #[cfg(feature = "__dnssec")]
    #[test]
    fn test_read_records_tsig() {
        let records = vec![
            Record::from_rdata(
                Name::from_labels(vec!["example", "com"]).unwrap(),
                300,
                RData::A(A::new(127, 0, 0, 1)),
            ),
            Record::from_rdata(
                Name::from_labels(vec!["tsig", "example", "com"]).unwrap(),
                0,
                fake_tsig(),
            ),
        ];
        let result = encode_and_read_records(records, true);
        let (output_records, edns, signature) = result.unwrap();
        assert_eq!(output_records.len(), 1); assert!(edns.is_none());
        assert!(signature.is_some());
    }

    /// OPT followed by a trailing TSIG: both are lifted out of the records.
    #[cfg(all(feature = "std", feature = "__dnssec"))]
    #[test]
    fn test_read_records_edns_tsig() {
        let records = vec![
            Record::from_rdata(
                Name::from_labels(vec!["example", "com"]).unwrap(),
                300,
                RData::A(A::new(127, 0, 0, 1)),
            ),
            Record::from_rdata(
                Name::new(),
                0,
                RData::OPT(OPT::new(vec![(
                    EdnsCode::Subnet,
                    EdnsOption::Subnet(ClientSubnet::new(IpAddr::from([127, 0, 0, 1]), 0, 24)),
                )])),
            ),
            Record::from_rdata(
                Name::from_labels(vec!["tsig", "example", "com"]).unwrap(),
                0,
                fake_tsig(),
            ),
        ];
        let result = encode_and_read_records(records, true);
        assert!(result.is_ok());
        let (output_records, edns, signature) = result.unwrap();
        assert_eq!(output_records.len(), 1); assert!(edns.is_some());
        assert!(signature.is_some());
    }

    /// Two OPT records in one additional section are rejected.
    #[cfg(feature = "std")]
    #[test]
    fn test_read_records_unsigned_multiple_edns() {
        let opt_record = Record::from_rdata(
            Name::new(),
            0,
            RData::OPT(OPT::new(vec![(
                EdnsCode::Subnet,
                EdnsOption::Subnet(ClientSubnet::new(IpAddr::from([127, 0, 0, 1]), 0, 24)),
            )])),
        );
        let error = encode_and_read_records(
            vec![
                opt_record.clone(),
                Record::from_rdata(
                    Name::from_labels(vec!["example", "com"]).unwrap(),
                    300,
                    RData::A(A::new(127, 0, 0, 1)),
                ),
                opt_record.clone(),
            ],
            true,
        )
        .unwrap_err();
        assert!(error.to_string().contains("more than one EDNS record"));
    }

    /// An OPT record outside the additional section is rejected.
    #[cfg(feature = "std")]
    #[test]
    fn test_read_records_opt_not_additional() {
        let opt_record = Record::from_rdata(
            Name::new(),
            0,
            RData::OPT(OPT::new(vec![(
                EdnsCode::Subnet,
                EdnsOption::Subnet(ClientSubnet::new(IpAddr::from([127, 0, 0, 1]), 0, 24)),
            )])),
        );
        let err = encode_and_read_records(
            vec![
                opt_record.clone(),
                Record::from_rdata(
                    Name::from_labels(vec!["example", "com"]).unwrap(),
                    300,
                    RData::A(A::new(127, 0, 0, 1)),
                ),
            ],
            false,
        )
        .unwrap_err();
        assert!(
            err.to_string()
                .contains("record type OPT only allowed in additional")
        );
    }

    /// Duplicate OPT records are rejected even when a TSIG follows them.
    #[cfg(all(feature = "std", feature = "__dnssec"))]
    #[test]
    fn test_read_records_signed_multiple_edns() {
        let opt_record = Record::from_rdata(
            Name::new(),
            0,
            RData::OPT(OPT::new(vec![(
                EdnsCode::Subnet,
                EdnsOption::Subnet(ClientSubnet::new(IpAddr::from([127, 0, 0, 1]), 0, 24)),
            )])),
        );
        let error = encode_and_read_records(
            vec![
                opt_record.clone(),
                Record::from_rdata(
                    Name::from_labels(vec!["example", "com"]).unwrap(),
                    300,
                    RData::A(A::new(127, 0, 0, 1)),
                ),
                opt_record.clone(),
                Record::from_rdata(
                    Name::from_labels(vec!["tsig", "example", "com"]).unwrap(),
                    0,
                    fake_tsig(),
                ),
            ],
            true,
        )
        .unwrap_err();
        assert!(error.to_string().contains("more than one EDNS record"));
    }

    /// A TSIG record outside the additional section is rejected.
    #[cfg(all(feature = "std", feature = "__dnssec"))]
    #[test]
    fn test_read_records_tsig_not_additional() {
        let err = encode_and_read_records(
            vec![
                Record::from_rdata(
                    Name::from_labels(vec!["example", "com"]).unwrap(),
                    300,
                    RData::A(A::new(127, 0, 0, 1)),
                ),
                Record::from_rdata(
                    Name::from_labels(vec!["tsig", "example", "com"]).unwrap(),
                    0,
                    fake_tsig(),
                ),
            ],
            false,
        )
        .unwrap_err();
        assert!(
            err.to_string()
                .contains("record type TSIG only allowed in additional")
        );
    }

    /// Any record after a TSIG is rejected.
    #[cfg(all(feature = "std", feature = "__dnssec"))]
    #[test]
    fn test_read_records_tsig_not_last() {
        let a_record = Record::from_rdata(
            Name::from_labels(vec!["example", "com"]).unwrap(),
            300,
            RData::A(A::new(127, 0, 0, 1)),
        );
        let error = encode_and_read_records(
            vec![
                a_record.clone(),
                Record::from_rdata(
                    Name::from_labels(vec!["tsig", "example", "com"]).unwrap(),
                    0,
                    fake_tsig(),
                ),
                a_record.clone(),
            ],
            true,
        )
        .unwrap_err()
        .to_string();
        assert!(error.contains("record after TSIG or SIG(0)"));
    }

    /// NOTE(review): despite its name, this test builds its "sig0" record with
    /// `fake_tsig()`, so it exercises the TSIG path, not SIG(0). Confirm
    /// whether a real SIG record was intended.
    #[cfg(all(feature = "std", feature = "__dnssec"))]
    #[test]
    fn test_read_records_sig0_not_last() {
        let a_record = Record::from_rdata(
            Name::from_labels(vec!["example", "com"]).unwrap(),
            300,
            RData::A(A::new(127, 0, 0, 1)),
        );
        let error = encode_and_read_records(
            vec![
                a_record.clone(),
                Record::from_rdata(
                    Name::from_labels(vec!["sig0", "example", "com"]).unwrap(),
                    0,
                    fake_tsig(),
                ),
                a_record.clone(),
            ],
            true,
        )
        .unwrap_err()
        .to_string();
        assert!(error.contains("record after TSIG or SIG(0)"));
    }

    /// A second TSIG after the first is rejected as a record after the signature.
    #[cfg(all(feature = "std", feature = "__dnssec"))]
    #[test]
    fn test_read_records_multiple_tsig() {
        let tsig_record = Record::from_rdata(
            Name::from_labels(vec!["tsig", "example", "com"]).unwrap(),
            0,
            fake_tsig(),
        );
        let error = encode_and_read_records(
            vec![
                Record::from_rdata(
                    Name::from_labels(vec!["example", "com"]).unwrap(),
                    300,
                    RData::A(A::new(127, 0, 0, 1)),
                ),
                tsig_record.clone(),
                tsig_record.clone(),
            ],
            true,
        )
        .unwrap_err()
        .to_string();
        assert!(error.contains("record after TSIG or SIG(0)"));
    }

    /// NOTE(review): like `test_read_records_sig0_not_last`, this uses
    /// `fake_tsig()`, so it duplicates `test_read_records_multiple_tsig`
    /// rather than testing SIG(0) records.
    #[cfg(all(feature = "std", feature = "__dnssec"))]
    #[test]
    fn test_read_records_multiple_sig0() {
        let sig0_record = Record::from_rdata(
            Name::from_labels(vec!["sig0", "example", "com"]).unwrap(),
            0,
            fake_tsig(),
        );
        let error = encode_and_read_records(
            vec![
                Record::from_rdata(
                    Name::from_labels(vec!["example", "com"]).unwrap(),
                    300,
                    RData::A(A::new(127, 0, 0, 1)),
                ),
                sig0_record.clone(),
                sig0_record.clone(),
            ],
            true,
        )
        .unwrap_err()
        .to_string();
        assert!(error.contains("record after TSIG or SIG(0)"));
    }

    /// Encodes `records` back-to-back and decodes them as one section via
    /// `Message::read_records` with `OpCode::Query`.
    #[expect(clippy::type_complexity)]
    fn encode_and_read_records(
        records: Vec<Record>,
        is_additional: bool,
    ) -> ProtoResult<(Vec<Record>, Option<Edns>, Option<Box<Record<TSIG>>>)> {
        let mut bytes = Vec::new();
        let mut encoder = BinEncoder::new(&mut bytes);
        encoder.emit_all(records.iter())?;
        Ok(Message::read_records(
            &mut BinDecoder::new(&bytes),
            records.len(),
            is_additional,
            OpCode::Query,
        )?)
    }

    /// Placeholder TSIG rdata (HMAC-SHA256) with zero/empty field values.
    #[cfg(feature = "__dnssec")]
    fn fake_tsig() -> RData {
        RData::TSIG(TSIG::new(
            TsigAlgorithm::HmacSha256,
            0,
            0,
            vec![],
            0,
            None,
            vec![],
        ))
    }
}