use core::any::Any;
use core::ops::Div;
use crate::endian::{read_u16_be, read_u32_be};
use crate::error::{CrafterError, Result};
use crate::field::Field;
use crate::packet::{IntoPacket, Layer, LayerContext, Packet};
mod constants;
mod dnssec;
mod edns;
mod name;
mod rdata;
mod record;
mod svcb;
use rdata::decode_record_data;
pub use dnssec::DnsTypeBitmaps;
pub use edns::{edns_option_code_name, EdnsOption};
pub use name::{decode_dns_name, decode_dns_name_typed, DnsName};
pub use rdata::DnsRecordData;
pub use record::{DnsQuestion, DnsRecord};
pub use svcb::{svcb_param_key_name, SvcParam, SvcParams};
pub use constants::{
DNS_CLASS_ANY, DNS_CLASS_CH, DNS_CLASS_HS, DNS_CLASS_IN, DNS_CLASS_NONE,
DNS_EDNS_DEFAULT_UDP_PAYLOAD_SIZE, DNS_EDNS_FLAG_DO, DNS_EDNS_OPTION_CLIENT_SUBNET,
DNS_EDNS_OPTION_COOKIE, DNS_EDNS_OPTION_DAU, DNS_EDNS_OPTION_DHU, DNS_EDNS_OPTION_EXPIRE,
DNS_EDNS_OPTION_EXTENDED_ERROR, DNS_EDNS_OPTION_N3U, DNS_EDNS_OPTION_NSID,
DNS_EDNS_OPTION_PADDING, DNS_EDNS_OPTION_TCP_KEEPALIVE, DNS_FLAG_AUTHENTIC_DATA,
DNS_FLAG_AUTHORITATIVE, DNS_FLAG_CHECKING_DISABLED, DNS_FLAG_QR_RESPONSE,
DNS_FLAG_RECURSION_AVAILABLE, DNS_FLAG_RECURSION_DESIRED, DNS_FLAG_TRUNCATED, DNS_HEADER_LEN,
DNS_OPCODE_DSO, DNS_OPCODE_IQUERY, DNS_OPCODE_NOTIFY, DNS_OPCODE_QUERY, DNS_OPCODE_STATUS,
DNS_OPCODE_UPDATE, DNS_PORT, DNS_RCODE_DSOTYPENI, DNS_RCODE_FORMERR, DNS_RCODE_NOERROR,
DNS_RCODE_NOTAUTH, DNS_RCODE_NOTIMP, DNS_RCODE_NOTZONE, DNS_RCODE_NXDOMAIN, DNS_RCODE_NXRRSET,
DNS_RCODE_REFUSED, DNS_RCODE_SERVFAIL, DNS_RCODE_YXDOMAIN, DNS_RCODE_YXRRSET,
DNS_SVCB_KEY_ALPN, DNS_SVCB_KEY_DOHPATH, DNS_SVCB_KEY_ECH, DNS_SVCB_KEY_IPV4HINT,
DNS_SVCB_KEY_IPV6HINT, DNS_SVCB_KEY_MANDATORY, DNS_SVCB_KEY_NO_DEFAULT_ALPN, DNS_SVCB_KEY_PORT,
DNS_TYPE_A, DNS_TYPE_AAAA, DNS_TYPE_CNAME, DNS_TYPE_DNSKEY, DNS_TYPE_DS, DNS_TYPE_HTTPS,
DNS_TYPE_MX, DNS_TYPE_NS, DNS_TYPE_NSEC, DNS_TYPE_NSEC3, DNS_TYPE_NSEC3PARAM, DNS_TYPE_OPT,
DNS_TYPE_PTR, DNS_TYPE_RRSIG, DNS_TYPE_SOA, DNS_TYPE_SRV, DNS_TYPE_SVCB, DNS_TYPE_TLSA,
DNS_TYPE_TXT,
};
// Name compression (RFC 1035 section 4.1.4): a length octet whose top two
// bits are both set (`octet & MASK == TAG`) begins a 2-byte pointer instead of
// a label.
const DNS_NAME_POINTER_MASK: u8 = 0xc0;
const DNS_NAME_POINTER_TAG: u8 = 0xc0;
// Wire-format size limits (RFC 1035 section 2.3.4): at most 63 octets per
// label and 255 octets for a complete encoded name.
const DNS_MAX_LABEL_LEN: usize = 63;
const DNS_MAX_NAME_WIRE_LEN: usize = 255;
// OPCODE occupies bits 11..=14 of the 16-bit header flags word.
const DNS_OPCODE_MASK: u16 = 0x7800;
const DNS_OPCODE_SHIFT: u16 = 11;
// RCODE occupies the low four bits of the header flags word.
const DNS_RCODE_MASK: u16 = 0x000f;
// EDNS(0) OPT pseudo-record TTL layout (RFC 6891 section 6.1.3): extended
// RCODE in bits 24..=31, VERSION in bits 16..=23, flags (DO + Z) in bits 0..=15.
// NOTE(review): not referenced in this chunk; presumably used by the
// edns/record submodules through `super::` — confirm.
const DNS_EDNS_EXTENDED_RCODE_SHIFT: u32 = 24;
const DNS_EDNS_VERSION_SHIFT: u32 = 16;
const DNS_EDNS_FLAGS_MASK: u32 = 0x0000_ffff;
// Expands to the four boilerplate `Layer` methods that expose a concrete layer
// as a boxed clone and as `dyn Any` (shared, mutable and owned forms).
// Presumably these back typed lookups such as `Packet::layer::<Dns>()`.
//
// Must be invoked inside an `impl Layer for T` block. The expansion calls
// `self.clone()` and boxes `self` as `dyn Any`, so `T: Clone + 'static` is
// required.
// NOTE(review): the `$type` argument is accepted but never used in the expansion.
macro_rules! impl_layer_object {
($type:ty) => {
// Boxed deep copy of the concrete layer, returned as a trait object.
fn clone_layer(&self) -> Box<dyn Layer> {
Box::new(self.clone())
}
fn as_any(&self) -> &dyn Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn Any {
self
}
// Consumes the boxed layer so callers can recover the owned concrete value.
fn into_any(self: Box<Self>) -> Box<dyn Any> {
self
}
};
}
// Implements `layer / rhs` stacking for `$type`. For any right-hand side
// implementing `IntoPacket`, the result is a `Packet` whose first layer is
// `self`, followed by `rhs` (via `Packet::concat`).
macro_rules! impl_layer_div {
($type:ty) => {
impl<R> Div<R> for $type
where
R: IntoPacket,
{
type Output = Packet;
fn div(self, rhs: R) -> Self::Output {
Packet::from_layer(self).concat(rhs)
}
}
};
}
/// DNS message layer: the header's ID and flags words plus the four message
/// sections.
///
/// The section counts (QDCOUNT/ANCOUNT/NSCOUNT/ARCOUNT) are not stored. They
/// are derived from the vector lengths when the message is encoded, and
/// decoding uses them only to decide how many entries to parse.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Dns {
/// Message ID (header bytes 0..2).
id: Field<u16>,
/// Header flags word (bytes 2..4): QR, OPCODE, AA, TC, RD, RA, AD, CD and RCODE bits.
flags: Field<u16>,
/// Question section entries, in wire order.
questions: Vec<DnsQuestion>,
/// Answer section records, in wire order.
answers: Vec<DnsRecord>,
/// Authority section records, in wire order.
authorities: Vec<DnsRecord>,
/// Additional section records, in wire order (an EDNS OPT pseudo-record, when present, lives here).
additionals: Vec<DnsRecord>,
}
impl Dns {
    /// Creates an empty message: ID 0, only the RD (recursion desired) flag
    /// set, and no questions or records.
    ///
    /// Both header words start as defaulted (not user-set) field values.
    pub fn new() -> Self {
        Self {
            id: Field::defaulted(0),
            flags: Field::defaulted(DNS_FLAG_RECURSION_DESIRED),
            questions: vec![],
            answers: vec![],
            authorities: vec![],
            additionals: vec![],
        }
    }

    /// Creates a message with one question for `name` and the given QTYPE.
    ///
    /// The question class is whatever [`DnsQuestion::new`] chooses.
    pub fn query(name: impl Into<DnsName>, question_type: u16) -> Self {
        let question = DnsQuestion::new(name, question_type);
        Self::new().question(question)
    }

    /// Shorthand for [`Dns::query`] with QTYPE `A`.
    pub fn a_query(name: impl Into<DnsName>) -> Self {
        Self::query(name, DNS_TYPE_A)
    }

    /// Shorthand for [`Dns::query`] with QTYPE `AAAA`.
    pub fn aaaa_query(name: impl Into<DnsName>) -> Self {
        Self::query(name, DNS_TYPE_AAAA)
    }

    /// Sets the 16-bit message ID as an explicit user value.
    pub fn id(mut self, id: u16) -> Self {
        self.id.set_user(id);
        self
    }

    /// Overwrites the entire 16-bit flags word.
    ///
    /// This is the escape hatch for bit patterns the dedicated setters do not
    /// cover.
    pub fn flags(mut self, flags: u16) -> Self {
        self.flags.set_user(flags);
        self
    }

    /// Sets or clears the QR bit, which marks the message as a response.
    pub fn response(self, enabled: bool) -> Self {
        self.set_flag(DNS_FLAG_QR_RESPONSE, enabled)
    }

    /// Sets or clears the AA (authoritative answer) bit.
    pub fn authoritative(self, enabled: bool) -> Self {
        self.set_flag(DNS_FLAG_AUTHORITATIVE, enabled)
    }

    /// Sets or clears the RD (recursion desired) bit.
    pub fn recursion_desired(self, enabled: bool) -> Self {
        self.set_flag(DNS_FLAG_RECURSION_DESIRED, enabled)
    }

    /// Short alias for [`Dns::recursion_desired`].
    pub fn rd(self, enabled: bool) -> Self {
        self.recursion_desired(enabled)
    }

    /// Sets or clears the RA (recursion available) bit.
    pub fn recursion_available(self, enabled: bool) -> Self {
        self.set_flag(DNS_FLAG_RECURSION_AVAILABLE, enabled)
    }

    /// Replaces the 4-bit OPCODE field and keeps every other flag bit.
    ///
    /// Only the low four bits of `opcode` are stored; higher bits are dropped.
    pub fn opcode(self, opcode: u8) -> Self {
        let shifted = u16::from(opcode) << DNS_OPCODE_SHIFT;
        let merged = (self.flags_value() & !DNS_OPCODE_MASK) | (shifted & DNS_OPCODE_MASK);
        self.flags(merged)
    }

    /// Replaces the 4-bit RCODE field and keeps every other flag bit.
    ///
    /// Only the low four bits of `rcode` are stored; higher bits are dropped.
    pub fn rcode(self, rcode: u8) -> Self {
        let low_bits = u16::from(rcode) & DNS_RCODE_MASK;
        let merged = (self.flags_value() & !DNS_RCODE_MASK) | low_bits;
        self.flags(merged)
    }

    /// Appends an entry to the question section.
    pub fn question(mut self, question: DnsQuestion) -> Self {
        self.questions.push(question);
        self
    }

    /// Appends a record to the answer section.
    pub fn answer(mut self, answer: DnsRecord) -> Self {
        self.answers.push(answer);
        self
    }

    /// Appends a record to the authority section.
    pub fn authority(mut self, authority: DnsRecord) -> Self {
        self.authorities.push(authority);
        self
    }

    /// Appends a record to the additional section.
    pub fn additional(mut self, additional: DnsRecord) -> Self {
        self.additionals.push(additional);
        self
    }

    /// Returns the message ID, or 0 when the field holds no value.
    pub fn id_value(&self) -> u16 {
        value_or_copy(&self.id, 0)
    }

    /// Returns the flags word, or just the RD bit when the field holds no value.
    pub fn flags_value(&self) -> u16 {
        value_or_copy(&self.flags, DNS_FLAG_RECURSION_DESIRED)
    }

    /// Returns `true` when the QR bit is set.
    pub fn is_response(&self) -> bool {
        (self.flags_value() & DNS_FLAG_QR_RESPONSE) != 0
    }

    /// Returns the 4-bit OPCODE field of the flags word.
    pub fn opcode_value(&self) -> u8 {
        let field_bits = self.flags_value() & DNS_OPCODE_MASK;
        (field_bits >> DNS_OPCODE_SHIFT) as u8
    }

    /// Returns the 4-bit RCODE field of the flags word.
    pub fn rcode_value(&self) -> u8 {
        (self.flags_value() & DNS_RCODE_MASK) as u8
    }

    /// Returns the question section.
    pub fn questions(&self) -> &[DnsQuestion] {
        &self.questions
    }

    /// Returns the answer section.
    pub fn answers(&self) -> &[DnsRecord] {
        &self.answers
    }

    /// Returns the authority section.
    pub fn authorities(&self) -> &[DnsRecord] {
        &self.authorities
    }

    /// Returns the additional section.
    pub fn additionals(&self) -> &[DnsRecord] {
        &self.additionals
    }

    /// Returns a copy with `bit` set (`enabled`) or cleared in the flags word.
    ///
    /// The result is always stored as a user value.
    fn set_flag(mut self, bit: u16, enabled: bool) -> Self {
        let current = self.flags_value();
        let updated = if enabled { current | bit } else { current & !bit };
        self.flags.set_user(updated);
        self
    }

    /// Computes the uncompressed wire size: the header plus every question
    /// and every record in all three record sections.
    fn encoded_message_len(&self) -> usize {
        let question_bytes: usize = self.questions.iter().map(DnsQuestion::encoded_len).sum();
        let record_bytes: usize = self
            .answers
            .iter()
            .chain(&self.authorities)
            .chain(&self.additionals)
            .map(DnsRecord::encoded_len)
            .sum();
        DNS_HEADER_LEN + question_bytes + record_bytes
    }

    /// Appends the uncompressed wire form of the message to `out`.
    ///
    /// The output is the header (ID, flags, four big-endian section counts),
    /// then the questions, answers, authorities and additionals.
    ///
    /// # Errors
    ///
    /// Fails before writing anything if any section holds more than 65535
    /// entries. Otherwise it propagates the first question/record encoding
    /// error, which may leave `out` partially written.
    fn encode(&self, out: &mut Vec<u8>) -> Result<()> {
        let section_counts = [
            ("dns.qdcount", self.questions.len()),
            ("dns.ancount", self.answers.len()),
            ("dns.nscount", self.authorities.len()),
            ("dns.arcount", self.additionals.len()),
        ];
        for &(field, count) in section_counts.iter() {
            validate_count(field, count)?;
        }

        out.extend_from_slice(&self.id_value().to_be_bytes());
        out.extend_from_slice(&self.flags_value().to_be_bytes());
        // Every count was validated above, so the narrowing cast cannot truncate.
        for &(_, count) in section_counts.iter() {
            out.extend_from_slice(&(count as u16).to_be_bytes());
        }

        for question in &self.questions {
            question.encode(out)?;
        }
        for record in self
            .answers
            .iter()
            .chain(&self.authorities)
            .chain(&self.additionals)
        {
            record.encode(out)?;
        }
        Ok(())
    }
}
impl Default for Dns {
fn default() -> Self {
Self::new()
}
}
impl Layer for Dns {
    fn name(&self) -> &'static str {
        "Dns"
    }

    /// One-line description: ID in hex, direction, the first question (if
    /// any) and the number of answers.
    fn summary(&self) -> String {
        let direction = if self.is_response() { "response" } else { "query" };
        let question = match self.questions.first() {
            Some(first) => format!(
                " {} {}",
                first.name(),
                record_type_summary(first.question_type())
            ),
            None => String::new(),
        };
        format!(
            "Dns(id=0x{:04x}, {direction}{question}, answers={})",
            self.id_value(),
            self.answers.len()
        )
    }

    /// Header values for inspection: ID and flags in hex, then the four
    /// section counts derived from the section vectors.
    fn inspection_fields(&self) -> Vec<(&'static str, String)> {
        let counts = [
            ("qdcount", self.questions.len()),
            ("ancount", self.answers.len()),
            ("nscount", self.authorities.len()),
            ("arcount", self.additionals.len()),
        ];
        let mut fields = vec![
            ("id", format!("0x{:04x}", self.id_value())),
            ("flags", format!("0x{:04x}", self.flags_value())),
        ];
        fields.extend(
            counts
                .iter()
                .map(|&(label, count)| (label, count.to_string())),
        );
        fields
    }

    fn encoded_len(&self) -> usize {
        self.encoded_message_len()
    }

    /// Encodes the message. The layer context is not needed because DNS has
    /// no fields derived from surrounding layers.
    fn compile(&self, _ctx: &LayerContext<'_>, out: &mut Vec<u8>) -> Result<()> {
        self.encode(out)
    }

    impl_layer_object!(Dns);
}
// Enables `dns / next_layer` stacking. The result is a `Packet` that starts with
// this DNS layer and is followed by anything implementing `IntoPacket`.
impl_layer_div!(Dns);
/// Decodes `bytes` as a complete DNS message and pushes it onto `packet`.
///
/// # Errors
///
/// Returns any error from [`decode_dns`]; in that case `packet` is dropped.
pub(crate) fn append_dns_packet(packet: Packet, bytes: &[u8]) -> Result<Packet> {
    let dns = decode_dns(bytes)?;
    Ok(packet.push(dns))
}
fn decode_dns(bytes: &[u8]) -> Result<Dns> {
if bytes.len() < DNS_HEADER_LEN {
return Err(CrafterError::buffer_too_short(
"dns header",
DNS_HEADER_LEN,
bytes.len(),
));
}
let qdcount = read_u16_be(&bytes[4..6])? as usize;
let ancount = read_u16_be(&bytes[6..8])? as usize;
let nscount = read_u16_be(&bytes[8..10])? as usize;
let arcount = read_u16_be(&bytes[10..12])? as usize;
let mut offset = DNS_HEADER_LEN;
let mut questions = Vec::with_capacity(qdcount);
let mut answers = Vec::with_capacity(ancount);
let mut authorities = Vec::with_capacity(nscount);
let mut additionals = Vec::with_capacity(arcount);
for _ in 0..qdcount {
let (question, next_offset) = decode_question(bytes, offset)?;
questions.push(question);
offset = next_offset;
}
for _ in 0..ancount {
let (record, next_offset) = decode_record(bytes, offset)?;
answers.push(record);
offset = next_offset;
}
for _ in 0..nscount {
let (record, next_offset) = decode_record(bytes, offset)?;
authorities.push(record);
offset = next_offset;
}
for _ in 0..arcount {
let (record, next_offset) = decode_record(bytes, offset)?;
additionals.push(record);
offset = next_offset;
}
if offset != bytes.len() {
return Err(CrafterError::invalid_field_value(
"dns.length",
"DNS message has trailing bytes after declared records",
));
}
Ok(Dns {
id: Field::user(read_u16_be(&bytes[0..2])?),
flags: Field::user(read_u16_be(&bytes[2..4])?),
questions,
answers,
authorities,
additionals,
})
}
/// Decodes one question entry (QNAME, QTYPE, QCLASS) that starts at `offset`.
///
/// Returns the question and the offset just past its QCLASS field.
///
/// # Errors
///
/// Propagates name-decoding errors. Reports a `dns question` buffer-too-short
/// error when the 4-byte QTYPE/QCLASS tail does not fit in `bytes`.
fn decode_question(bytes: &[u8], offset: usize) -> Result<(DnsQuestion, usize)> {
    let (name, name_len) = decode_dns_name_typed(bytes, offset)?;
    let tail_start = offset + name_len;
    let tail_end = tail_start + 4;
    let Some(tail) = bytes.get(tail_start..tail_end) else {
        return Err(CrafterError::buffer_too_short(
            "dns question",
            tail_end,
            bytes.len(),
        ));
    };
    let question_type = read_u16_be(&tail[..2])?;
    let question_class = read_u16_be(&tail[2..])?;
    let question = DnsQuestion {
        name,
        question_type,
        question_class,
    };
    Ok((question, tail_end))
}
/// Decodes one resource record that starts at `offset`: the owner name, the
/// fixed TYPE/CLASS/TTL/RDLENGTH fields, then the RDATA.
///
/// The whole message buffer is passed to the RDATA decoder so that
/// compressed names inside RDATA can follow pointers. Returns the record and
/// the offset just past its RDATA.
///
/// # Errors
///
/// Propagates name and RDATA decoding errors. Reports buffer-too-short
/// (`dns record` or `dns rdata`) when the fixed fields or the declared RDATA
/// run past the end of `bytes`.
fn decode_record(bytes: &[u8], offset: usize) -> Result<(DnsRecord, usize)> {
    let (name, name_len) = decode_dns_name_typed(bytes, offset)?;
    let fixed_start = offset + name_len;
    let fixed_end = fixed_start + 10;
    let Some(fixed) = bytes.get(fixed_start..fixed_end) else {
        return Err(CrafterError::buffer_too_short(
            "dns record",
            fixed_end,
            bytes.len(),
        ));
    };
    let record_type = read_u16_be(&fixed[0..2])?;
    let class = read_u16_be(&fixed[2..4])?;
    let ttl = read_u32_be(&fixed[4..8])?;
    let rdlength = usize::from(read_u16_be(&fixed[8..10])?);

    let rdata_end = fixed_end + rdlength;
    if bytes.len() < rdata_end {
        return Err(CrafterError::buffer_too_short(
            "dns rdata",
            rdata_end,
            bytes.len(),
        ));
    }
    let data = decode_record_data(record_type, bytes, fixed_end, rdata_end)?;
    let record = DnsRecord {
        name,
        record_type,
        class,
        ttl,
        data,
    };
    Ok((record, rdata_end))
}
/// Checks that a section holds few enough entries for the 16-bit header count.
///
/// # Errors
///
/// Returns an invalid-field-value error for `field` when `count` exceeds 65535.
fn validate_count(field: &'static str, count: usize) -> Result<()> {
    if count <= usize::from(u16::MAX) {
        Ok(())
    } else {
        Err(CrafterError::invalid_field_value(
            field,
            "DNS section count exceeds 65535",
        ))
    }
}
/// Returns a copy of the value held by `field`, or `default` when it holds none.
fn value_or_copy<T: Copy>(field: &Field<T>, default: T) -> T {
    match field.value() {
        Some(value) => *value,
        None => default,
    }
}
/// Renders a TYPE codepoint for summaries. It uses the mnemonic when known
/// and the generic `TYPE<n>` form (RFC 3597) otherwise.
fn record_type_summary(record_type: u16) -> String {
    dns_type_name(record_type).map_or_else(|| format!("TYPE{record_type}"), String::from)
}
pub fn dns_type_name(record_type: u16) -> Option<&'static str> {
Some(match record_type {
DNS_TYPE_A => "A",
DNS_TYPE_NS => "NS",
DNS_TYPE_CNAME => "CNAME",
DNS_TYPE_SOA => "SOA",
DNS_TYPE_PTR => "PTR",
DNS_TYPE_MX => "MX",
DNS_TYPE_TXT => "TXT",
DNS_TYPE_AAAA => "AAAA",
DNS_TYPE_SRV => "SRV",
DNS_TYPE_OPT => "OPT",
DNS_TYPE_DS => "DS",
DNS_TYPE_RRSIG => "RRSIG",
DNS_TYPE_NSEC => "NSEC",
DNS_TYPE_DNSKEY => "DNSKEY",
DNS_TYPE_NSEC3 => "NSEC3",
DNS_TYPE_NSEC3PARAM => "NSEC3PARAM",
DNS_TYPE_TLSA => "TLSA",
DNS_TYPE_SVCB => "SVCB",
DNS_TYPE_HTTPS => "HTTPS",
_ => return None,
})
}
/// Encode/decode round-trip tests for the `Dns` layer carried over UDP/IPv4.
#[cfg(test)]
mod dns_tests {
use super::{
decode_dns_name, Dns, DnsName, DnsQuestion, DnsRecord, DnsRecordData, DNS_CLASS_IN,
DNS_FLAG_AUTHORITATIVE, DNS_FLAG_QR_RESPONSE, DNS_FLAG_RECURSION_DESIRED, DNS_TYPE_A,
DNS_TYPE_AAAA, DNS_TYPE_CNAME, DNS_TYPE_MX, DNS_TYPE_NS, DNS_TYPE_PTR, DNS_TYPE_SRV,
DNS_TYPE_TXT,
};
use crate::{Ipv4, NetworkLayer, Packet, Raw, Udp};
use core::net::{Ipv4Addr, Ipv6Addr};
/// An A query stacked on UDP puts the ID, the RD-only flags word and
/// QDCOUNT = 1 right after the 8-byte UDP header. It ends with the
/// uncompressed question (QTYPE 1, QCLASS 1).
#[test]
fn dns_a_query_encodes_header_question_and_udp_payload() {
let dns = Dns::a_query("example.com").id(0xbeef);
let packet = Udp::new().sport(53001).dport(53) / dns;
let compiled = packet.compile().unwrap();
assert_eq!(&compiled.as_bytes()[8..10], &0xbeefu16.to_be_bytes());
assert_eq!(
&compiled.as_bytes()[10..12],
&DNS_FLAG_RECURSION_DESIRED.to_be_bytes()
);
assert_eq!(&compiled.as_bytes()[12..14], &1u16.to_be_bytes());
assert!(compiled.as_bytes().ends_with(&[
7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o', b'm', 0, 0, 1, 0, 1,
]));
}
/// A response with A, AAAA and CNAME answers survives IPv4/UDP
/// compile -> decode -> recompile byte-for-byte.
#[test]
fn dns_response_records_roundtrip() {
let original = Dns::new()
.id(0x1234)
.response(true)
.authoritative(true)
.question(DnsQuestion::a("example.com."))
.answer(DnsRecord::a(
"example.com.",
Ipv4Addr::new(203, 0, 113, 10),
60,
))
.answer(DnsRecord::aaaa(
"example.com.",
Ipv6Addr::from([0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]),
60,
))
.answer(DnsRecord::cname("www.example.com.", "example.com.", 60));
let bytes = (Ipv4::new()
.src(Ipv4Addr::new(198, 51, 100, 53))
.dst(Ipv4Addr::new(192, 0, 2, 10))
/ Udp::new().sport(53).dport(53001)
/ original.clone())
.compile()
.unwrap();
let decoded = Packet::decode_from_l3(NetworkLayer::Ipv4, bytes.as_bytes()).unwrap();
let dns = decoded.layer::<Dns>().unwrap();
assert_eq!(dns.id_value(), 0x1234);
assert_eq!(
dns.flags_value() & (DNS_FLAG_QR_RESPONSE | DNS_FLAG_AUTHORITATIVE),
DNS_FLAG_QR_RESPONSE | DNS_FLAG_AUTHORITATIVE
);
assert_eq!(dns.questions()[0].name(), "example.com.");
assert_eq!(dns.answers().len(), 3);
assert_eq!(
dns.answers()[0].data(),
&DnsRecordData::A(Ipv4Addr::new(203, 0, 113, 10))
);
assert_eq!(decoded.compile().unwrap(), bytes);
}
/// A UDP datagram sent to port 53 decodes into a `Dns` layer. The dispatch
/// logic that selects DNS lives outside this file.
#[test]
fn dns_decode_uses_udp_port_context() {
let bytes = (Ipv4::new()
.src(Ipv4Addr::new(192, 0, 2, 10))
.dst(Ipv4Addr::new(198, 51, 100, 53))
/ Udp::new().sport(53001).dport(53)
/ Dns::aaaa_query("example.com").id(0x5678))
.compile()
.unwrap();
let decoded = Packet::decode_from_l3(NetworkLayer::Ipv4, bytes.as_bytes()).unwrap();
let dns = decoded.layer::<Dns>().unwrap();
assert_eq!(dns.questions()[0].question_type(), DNS_TYPE_AAAA);
assert_eq!(decoded.compile().unwrap(), bytes);
}
/// Builder inputs can be read back through the accessors. Question names
/// gain a trailing root dot.
#[test]
fn dns_builders_keep_common_values_visible() {
let query = Dns::new()
.id(7)
.rd(false)
.question(DnsQuestion::new("example.org", DNS_TYPE_A).qclass(DNS_CLASS_IN));
assert_eq!(query.id_value(), 7);
assert_eq!(query.flags_value() & DNS_FLAG_RECURSION_DESIRED, 0);
assert_eq!(query.questions()[0].name(), "example.org.");
assert_eq!(query.questions()[0].question_class(), DNS_CLASS_IN);
}
/// `decode_dns_name` follows a compression pointer (`0xc0 0x04`). The
/// consumed count covers only bytes at the starting offset, so the
/// pointer counts as 2.
#[test]
fn dns_compressed_name_decode_is_exposed() {
let message = [
3, b'w', b'w', b'w', 7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o', b'm',
0, 4, b'm', b'a', b'i', b'l', 0xc0, 4,
];
assert_eq!(
decode_dns_name(&message, 0).unwrap(),
("www.example.com.".to_string(), 17)
);
assert_eq!(
decode_dns_name(&message, 17).unwrap(),
("mail.example.com.".to_string(), 7)
);
}
/// An owner name with non-printable label bytes round-trips unchanged. Its
/// presentation form escapes those bytes as `\DDD`.
#[test]
fn non_text_owner_name_round_trips_through_a_compiled_packet() {
let owner = DnsName::from_labels([vec![0x00u8, 0xff], b"example".to_vec()]).unwrap();
assert!(!owner.is_text());
let record = DnsRecord::new(
owner.clone(),
DNS_TYPE_TXT,
DNS_CLASS_IN,
300,
DnsRecordData::txt(b"v"),
);
let original = Dns::new().id(0x4242).response(true).answer(record);
let bytes = (Ipv4::new()
.src(Ipv4Addr::new(198, 51, 100, 53))
.dst(Ipv4Addr::new(192, 0, 2, 10))
/ Udp::new().sport(53).dport(53001)
/ original)
.compile()
.unwrap();
let decoded = Packet::decode_from_l3(NetworkLayer::Ipv4, bytes.as_bytes()).unwrap();
let dns = decoded.layer::<Dns>().unwrap();
let answer = &dns.answers()[0];
assert_eq!(answer.name_labels(), owner.labels());
assert_eq!(answer.dns_name(), &owner);
assert_eq!(answer.name(), "\\000\\255.example.");
assert_eq!(decoded.compile().unwrap(), bytes);
}
/// A hand-built compressed response decodes into fully expanded names and
/// recompiles without compression pointers. It carries CNAME, MX and SRV
/// answers whose owner names and RDATA names point back to offset 12.
#[test]
fn dns_compression_decodes_then_recompiles_uncompressed() {
// Header: ID 0x1234, flags 0x8180, QDCOUNT 1, ANCOUNT 3, NSCOUNT 0, ARCOUNT 0.
let message: &[u8] = &[
0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00,
7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o', b'm', 0, 0x00, 0x01, 0x00, 0x01, 0xc0, 0x0c, 0x00, 0x05, 0x00, 0x01, 0x00, 0x00, 0x01, 0x2c, 0x00, 0x08, 5, b'a', b'l', b'i', b'a', b's', 0xc0, 0x0c, 0xc0, 0x0c, 0x00, 0x0f, 0x00, 0x01, 0x00, 0x00, 0x01, 0x2c, 0x00, 0x09, 0x00, 0x0a, 4, b'm', b'a', b'i', b'l', 0xc0, 0x0c, 4, b'_', b's', b'i', b'p', 4, b'_', b't', b'c', b'p', 0xc0, 0x0c, 0x00, 0x21, 0x00, 0x01, 0x00, 0x00, 0x01, 0x2c, 0x00, 0x0c, 0x00, 0x0a, 0x00, 0x05, 0x16, 0x2c, 3, b's', b'i', b'p', 0xc0, 0x0c,
];
let wire = (Ipv4::new()
.src(Ipv4Addr::new(198, 51, 100, 53))
.dst(Ipv4Addr::new(192, 0, 2, 10))
/ Udp::new().sport(53).dport(53001)
/ Raw::from_bytes(message))
.compile()
.unwrap();
let decoded = Packet::decode_from_l3(NetworkLayer::Ipv4, wire.as_bytes()).unwrap();
let dns = decoded.layer::<Dns>().unwrap();
assert_eq!(dns.questions()[0].name(), "example.com.");
assert_eq!(dns.answers().len(), 3);
assert_eq!(dns.answers()[0].name(), "example.com.");
assert_eq!(dns.answers()[0].record_type(), DNS_TYPE_CNAME);
match dns.answers()[0].data() {
DnsRecordData::Name(name) => assert_eq!(name.presentation(), "alias.example.com."),
other => panic!("expected CNAME name, got {other:?}"),
}
assert_eq!(dns.answers()[1].record_type(), DNS_TYPE_MX);
match dns.answers()[1].data() {
DnsRecordData::Mx {
preference,
exchange,
} => {
assert_eq!(*preference, 10);
assert_eq!(exchange.presentation(), "mail.example.com.");
}
other => panic!("expected MX, got {other:?}"),
}
assert_eq!(dns.answers()[2].name(), "_sip._tcp.example.com.");
assert_eq!(dns.answers()[2].record_type(), DNS_TYPE_SRV);
match dns.answers()[2].data() {
DnsRecordData::Srv {
priority,
weight,
port,
target,
} => {
assert_eq!((*priority, *weight, *port), (10, 5, 5676));
assert_eq!(target.presentation(), "sip.example.com.");
}
other => panic!("expected SRV, got {other:?}"),
}
let recompiled = (Ipv4::new()
.src(Ipv4Addr::new(198, 51, 100, 53))
.dst(Ipv4Addr::new(192, 0, 2, 10))
/ Udp::new().sport(53).dport(53001)
/ dns.clone())
.compile()
.unwrap();
// Skip the IPv4 header (20 bytes, assuming no options) and the 8-byte UDP header.
let dns_payload = &recompiled.as_bytes()[28..];
assert!(dns_payload.len() > message.len());
// NOTE(review): this scan flags ANY byte >= 0xc0, not only pointer octets
// (e.g. TTL or RDATA bytes). It is a sufficient check for this fixture only.
assert!(
!dns_payload.iter().any(|&b| b & 0xc0 == 0xc0),
"recompiled DNS payload must not contain a compression pointer",
);
let redecoded = Packet::decode_from_l3(NetworkLayer::Ipv4, recompiled.as_bytes()).unwrap();
let redns = redecoded.layer::<Dns>().unwrap();
assert_eq!(redns.answers().len(), 3);
assert_eq!(redns.questions()[0].name(), "example.com.");
assert_eq!(redns.answers()[0].data(), dns.answers()[0].data());
assert_eq!(redns.answers()[1].data(), dns.answers()[1].data());
assert_eq!(redns.answers()[2].data(), dns.answers()[2].data());
assert_eq!(redns.answers()[2].name(), "_sip._tcp.example.com.");
}
/// NS, CNAME and PTR records with name-valued RDATA survive
/// compile -> decode -> recompile byte-for-byte.
#[test]
fn name_records_round_trip_through_a_compiled_packet() {
let original = Dns::new()
.id(0x4e43)
.response(true)
.authoritative(true)
.question(DnsQuestion::new("example.com.", DNS_TYPE_NS).qclass(DNS_CLASS_IN))
.answer(DnsRecord::new(
"example.com.",
DNS_TYPE_NS,
DNS_CLASS_IN,
3600,
DnsRecordData::name("ns1.example.com."),
))
.answer(DnsRecord::cname("www.example.com.", "host.example.", 300))
.answer(DnsRecord::new(
"20.113.0.203.in-addr.arpa.",
DNS_TYPE_PTR,
DNS_CLASS_IN,
300,
DnsRecordData::name("host.example.com."),
));
let bytes = (Ipv4::new()
.src(Ipv4Addr::new(198, 51, 100, 53))
.dst(Ipv4Addr::new(192, 0, 2, 10))
/ Udp::new().sport(53).dport(53001)
/ original.clone())
.compile()
.unwrap();
let decoded = Packet::decode_from_l3(NetworkLayer::Ipv4, bytes.as_bytes()).unwrap();
let dns = decoded.layer::<Dns>().unwrap();
assert_eq!(dns.questions()[0].name(), "example.com.");
assert_eq!(dns.questions()[0].question_type(), DNS_TYPE_NS);
assert_eq!(dns.answers().len(), 3);
assert_eq!(dns.answers()[0].record_type(), DNS_TYPE_NS);
match dns.answers()[0].data() {
DnsRecordData::Name(name) => assert_eq!(name.presentation(), "ns1.example.com."),
other => panic!("expected NS name, got {other:?}"),
}
assert_eq!(dns.answers()[1].record_type(), DNS_TYPE_CNAME);
assert_eq!(dns.answers()[1].name(), "www.example.com.");
match dns.answers()[1].data() {
DnsRecordData::Name(name) => assert_eq!(name.presentation(), "host.example."),
other => panic!("expected CNAME name, got {other:?}"),
}
assert_eq!(dns.answers()[2].record_type(), DNS_TYPE_PTR);
assert_eq!(dns.answers()[2].name(), "20.113.0.203.in-addr.arpa.");
match dns.answers()[2].data() {
DnsRecordData::Name(name) => assert_eq!(name.presentation(), "host.example.com."),
other => panic!("expected PTR name, got {other:?}"),
}
assert_eq!(decoded.compile().unwrap(), bytes);
}
/// Compressed NS, CNAME and PTR owner names and RDATA decode into full
/// names and recompile without pointers.
#[test]
fn compressed_name_records_normalize_to_uncompressed_model() {
// Header: ID 0x4e43, flags 0x8400, QDCOUNT 1, ANCOUNT 3, NSCOUNT 0, ARCOUNT 0.
let message: &[u8] = &[
0x4e, 0x43, 0x84, 0x00, 0x00, 0x01, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00,
7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o', b'm', 0, 0x00, 0x02, 0x00, 0x01, 0xc0, 0x0c, 0x00, 0x02, 0x00, 0x01, 0x00, 0x00, 0x0e, 0x10, 0x00, 0x06, 3, b'n', b's', b'1', 0xc0, 0x0c, 3, b'w', b'w', b'w', 0xc0, 0x0c, 0x00, 0x05, 0x00, 0x01, 0x00, 0x00, 0x01, 0x2c, 0x00, 0x07, 4, b'h', b'o', b's', b't', 0xc0, 0x0c, 3, b'p', b't', b'r', 0xc0, 0x0c, 0x00, 0x0c, 0x00, 0x01, 0x00, 0x00, 0x01, 0x2c, 0x00, 0x07, 4, b'h', b'o', b's', b't', 0xc0, 0x0c,
];
let wire = (Ipv4::new()
.src(Ipv4Addr::new(198, 51, 100, 53))
.dst(Ipv4Addr::new(192, 0, 2, 10))
/ Udp::new().sport(53).dport(53001)
/ Raw::from_bytes(message))
.compile()
.unwrap();
let decoded = Packet::decode_from_l3(NetworkLayer::Ipv4, wire.as_bytes()).unwrap();
let dns = decoded.layer::<Dns>().unwrap();
assert_eq!(dns.questions()[0].name(), "example.com.");
assert_eq!(dns.answers().len(), 3);
assert_eq!(dns.answers()[0].name(), "example.com.");
assert_eq!(dns.answers()[0].record_type(), DNS_TYPE_NS);
match dns.answers()[0].data() {
DnsRecordData::Name(name) => assert_eq!(name.presentation(), "ns1.example.com."),
other => panic!("expected NS name, got {other:?}"),
}
assert_eq!(dns.answers()[1].name(), "www.example.com.");
assert_eq!(dns.answers()[1].record_type(), DNS_TYPE_CNAME);
match dns.answers()[1].data() {
DnsRecordData::Name(name) => assert_eq!(name.presentation(), "host.example.com."),
other => panic!("expected CNAME name, got {other:?}"),
}
assert_eq!(dns.answers()[2].name(), "ptr.example.com.");
assert_eq!(dns.answers()[2].record_type(), DNS_TYPE_PTR);
match dns.answers()[2].data() {
DnsRecordData::Name(name) => assert_eq!(name.presentation(), "host.example.com."),
other => panic!("expected PTR name, got {other:?}"),
}
let recompiled = (Ipv4::new()
.src(Ipv4Addr::new(198, 51, 100, 53))
.dst(Ipv4Addr::new(192, 0, 2, 10))
/ Udp::new().sport(53).dport(53001)
/ dns.clone())
.compile()
.unwrap();
// Skip the IPv4 header (20 bytes, assuming no options) and the 8-byte UDP header.
let dns_payload = &recompiled.as_bytes()[28..];
assert!(dns_payload.len() > message.len());
// NOTE(review): heuristic check; any byte >= 0xc0 would trip it, not only pointers.
assert!(
!dns_payload.iter().any(|&b| b & 0xc0 == 0xc0),
"recompiled DNS payload must not contain a compression pointer",
);
let redecoded = Packet::decode_from_l3(NetworkLayer::Ipv4, recompiled.as_bytes()).unwrap();
let redns = redecoded.layer::<Dns>().unwrap();
assert_eq!(redns.answers().len(), 3);
assert_eq!(redns.answers()[0].data(), dns.answers()[0].data());
assert_eq!(redns.answers()[1].data(), dns.answers()[1].data());
assert_eq!(redns.answers()[2].data(), dns.answers()[2].data());
}
/// Compiling a record whose RDATA variant does not match its TYPE (a CNAME
/// carrying A data) fails instead of emitting inconsistent bytes.
#[test]
fn dns_type_mismatch_is_rejected() {
let record = DnsRecord::new(
"example.com.",
DNS_TYPE_CNAME,
DNS_CLASS_IN,
60,
DnsRecordData::A(Ipv4Addr::new(203, 0, 113, 1)),
);
assert!(Packet::from_layer(Dns::new().answer(record))
.compile()
.is_err());
}
}
/// Tests covering header flag bits, OPCODE/RCODE packing, section counts,
/// question class/type values and type mnemonics.
#[cfg(test)]
mod dns_header_codepoints {
use super::{
dns_type_name, Dns, DnsQuestion, DnsRecord, DnsRecordData, DNS_CLASS_ANY, DNS_CLASS_CH,
DNS_CLASS_HS, DNS_CLASS_IN, DNS_CLASS_NONE, DNS_FLAG_AUTHENTIC_DATA,
DNS_FLAG_AUTHORITATIVE, DNS_FLAG_CHECKING_DISABLED, DNS_FLAG_QR_RESPONSE,
DNS_FLAG_RECURSION_AVAILABLE, DNS_FLAG_RECURSION_DESIRED, DNS_FLAG_TRUNCATED,
DNS_OPCODE_QUERY, DNS_OPCODE_STATUS, DNS_OPCODE_UPDATE, DNS_RCODE_NOERROR,
DNS_RCODE_NXDOMAIN, DNS_RCODE_REFUSED, DNS_TYPE_A, DNS_TYPE_AAAA, DNS_TYPE_HTTPS,
DNS_TYPE_MX, DNS_TYPE_NS, DNS_TYPE_OPT, DNS_TYPE_SOA, DNS_TYPE_SRV, DNS_TYPE_TXT,
};
use crate::{Ipv4, NetworkLayer, Packet, Udp};
use std::net::Ipv4Addr;
/// `response` and `authoritative` OR their bits into the default RD flag.
/// The compiled flags word follows the 8-byte UDP header and the 2-byte ID.
#[test]
fn existing_flag_helpers_compile_identically() {
let dns = Dns::new()
.id(0xbeef)
.response(true)
.authoritative(true)
.question(DnsQuestion::a("example.com."));
let compiled = (Udp::new().sport(53001).dport(53) / dns).compile().unwrap();
let expected_flags =
DNS_FLAG_QR_RESPONSE | DNS_FLAG_AUTHORITATIVE | DNS_FLAG_RECURSION_DESIRED;
assert_eq!(&compiled.as_bytes()[10..12], &expected_flags.to_be_bytes());
}
/// Setting OPCODE leaves QR, AA, RD and RCODE untouched.
#[test]
fn opcode_setter_preserves_unrelated_bits() {
let dns = Dns::new()
.response(true)
.authoritative(true)
.rcode(DNS_RCODE_NXDOMAIN)
.opcode(DNS_OPCODE_UPDATE);
assert_eq!(dns.opcode_value(), DNS_OPCODE_UPDATE);
assert!(dns.is_response());
assert_ne!(dns.flags_value() & DNS_FLAG_AUTHORITATIVE, 0);
assert_ne!(dns.flags_value() & DNS_FLAG_RECURSION_DESIRED, 0);
assert_eq!(dns.rcode_value(), DNS_RCODE_NXDOMAIN);
}
/// Setting RCODE leaves OPCODE and QR untouched.
#[test]
fn rcode_setter_preserves_unrelated_bits() {
let dns = Dns::new()
.opcode(DNS_OPCODE_STATUS)
.response(true)
.rcode(DNS_RCODE_REFUSED);
assert_eq!(dns.rcode_value(), DNS_RCODE_REFUSED);
assert_eq!(dns.opcode_value(), DNS_OPCODE_STATUS);
assert!(dns.is_response());
}
/// A fresh message reports OPCODE QUERY and RCODE NOERROR.
#[test]
fn opcode_and_rcode_defaults_are_query_noerror() {
let dns = Dns::new();
assert_eq!(dns.opcode_value(), DNS_OPCODE_QUERY);
assert_eq!(dns.rcode_value(), DNS_RCODE_NOERROR);
}
/// Any 4-bit OPCODE/RCODE value is stored as given. Wider inputs keep only
/// their low four bits.
#[test]
fn unknown_opcode_and_rcode_values_round_trip() {
let dns = Dns::new().opcode(0xf).rcode(0xf);
assert_eq!(dns.opcode_value(), 0xf);
assert_eq!(dns.rcode_value(), 0xf);
let truncated = Dns::new().opcode(0xff).rcode(0xff);
assert_eq!(truncated.opcode_value(), 0xf);
assert_eq!(truncated.rcode_value(), 0xf);
}
/// `flags` writes the raw word, and the OPCODE/RCODE accessors decode from it.
#[test]
fn raw_flags_remain_the_escape_hatch() {
let dns = Dns::new().flags(0xabcd);
assert_eq!(dns.flags_value(), 0xabcd);
assert_eq!(dns.opcode_value(), ((0xabcd & 0x7800) >> 11) as u8);
assert_eq!(dns.rcode_value(), (0xabcd & 0x000f) as u8);
}
/// AA, TC, RD, RA, AD and CD all survive compile -> decode -> recompile.
#[test]
fn all_named_header_flag_bits_survive_compile_and_decode() {
let all_named = DNS_FLAG_AUTHORITATIVE
| DNS_FLAG_TRUNCATED
| DNS_FLAG_RECURSION_DESIRED
| DNS_FLAG_RECURSION_AVAILABLE
| DNS_FLAG_AUTHENTIC_DATA
| DNS_FLAG_CHECKING_DISABLED;
let dns = Dns::new()
.flags(all_named)
.question(DnsQuestion::a("example.com."));
let bytes = (Ipv4::new()
.src(Ipv4Addr::new(203, 0, 113, 1))
.dst(Ipv4Addr::new(198, 51, 100, 1))
/ Udp::new().sport(53001).dport(53)
/ dns)
.compile()
.unwrap();
let decoded = Packet::decode_from_l3(NetworkLayer::Ipv4, bytes.as_bytes()).unwrap();
let header = decoded.layer::<Dns>().unwrap();
assert_eq!(header.flags_value(), all_named);
assert_ne!(header.flags_value() & DNS_FLAG_AUTHENTIC_DATA, 0);
assert_ne!(header.flags_value() & DNS_FLAG_CHECKING_DISABLED, 0);
assert_eq!(decoded.compile().unwrap(), bytes);
}
/// `recursion_available(true)` adds only RA on top of the default RD bit.
#[test]
fn recursion_available_setter_sets_only_its_bit() {
let dns = Dns::new().recursion_available(true);
assert_ne!(dns.flags_value() & DNS_FLAG_RECURSION_AVAILABLE, 0);
assert_ne!(dns.flags_value() & DNS_FLAG_RECURSION_DESIRED, 0);
assert_eq!(
dns.flags_value() & !(DNS_FLAG_RECURSION_AVAILABLE | DNS_FLAG_RECURSION_DESIRED),
0
);
}
/// The four header counts come from the section vectors and decode back
/// into the matching sections. Only section placement is checked; the
/// authority entry here is a CNAME record.
#[test]
fn section_counts_auto_fill_from_typed_vectors() {
let dns = Dns::new()
.response(true)
.question(DnsQuestion::a("example.com."))
.answer(DnsRecord::a(
"example.com.",
Ipv4Addr::new(192, 0, 2, 10),
60,
))
.authority(DnsRecord::cname("example.com.", "ns1.example.com.", 300))
.additional(DnsRecord::a(
"ns1.example.com.",
Ipv4Addr::new(192, 0, 2, 53),
300,
));
let bytes = (Ipv4::new()
.src(Ipv4Addr::new(203, 0, 113, 1))
.dst(Ipv4Addr::new(198, 51, 100, 1))
/ Udp::new().sport(53).dport(53001)
/ dns)
.compile()
.unwrap();
let decoded = Packet::decode_from_l3(NetworkLayer::Ipv4, bytes.as_bytes()).unwrap();
let header = decoded.layer::<Dns>().unwrap();
assert_eq!(header.questions().len(), 1);
assert_eq!(header.answers().len(), 1);
assert_eq!(header.authorities().len(), 1);
assert_eq!(header.additionals().len(), 1);
assert_eq!(decoded.compile().unwrap(), bytes);
}
/// Records stay in their original sections and order across decode and
/// recompile, including an OPT pseudo-record at the end of the additional
/// section.
#[test]
fn section_placement_survives_decode_and_recompile() {
let dns = Dns::new()
.id(0x5023)
.response(true)
.authoritative(true)
.question(DnsQuestion::a("example.com."))
.answer(DnsRecord::a(
"example.com.",
Ipv4Addr::new(192, 0, 2, 10),
3600,
))
.authority(DnsRecord::new(
"example.com.",
DNS_TYPE_NS,
DNS_CLASS_IN,
3600,
DnsRecordData::name("ns1.example.com."),
))
.additional(DnsRecord::a(
"ns1.example.com.",
Ipv4Addr::new(192, 0, 2, 53),
3600,
))
.additional(DnsRecord::opt(1232, 0, 0, false, Vec::new()));
let bytes = (Ipv4::new()
.src(Ipv4Addr::new(203, 0, 113, 1))
.dst(Ipv4Addr::new(198, 51, 100, 1))
/ Udp::new().sport(53).dport(53001)
/ dns)
.compile()
.unwrap();
let decoded = Packet::decode_from_l3(NetworkLayer::Ipv4, bytes.as_bytes()).unwrap();
let header = decoded.layer::<Dns>().unwrap();
assert_eq!(header.questions().len(), 1);
assert_eq!(header.answers().len(), 1);
assert_eq!(header.authorities().len(), 1);
assert_eq!(header.additionals().len(), 2);
assert_eq!(header.answers()[0].record_type(), DNS_TYPE_A);
assert_eq!(header.answers()[0].name(), "example.com.");
assert_eq!(header.authorities()[0].record_type(), DNS_TYPE_NS);
assert_eq!(
header.authorities()[0].data(),
&DnsRecordData::name("ns1.example.com.")
);
assert_eq!(header.additionals()[0].record_type(), DNS_TYPE_A);
assert!(!header.additionals()[0].is_opt());
assert_eq!(header.additionals()[0].name(), "ns1.example.com.");
assert_eq!(header.additionals()[1].record_type(), DNS_TYPE_OPT);
assert!(header.additionals()[1].is_opt());
assert_eq!(decoded.compile().unwrap(), bytes);
}
/// Several questions with assorted and private-use QTYPE/QCLASS values
/// round-trip in order.
#[test]
fn multi_question_class_and_type_values_round_trip() {
const QTYPE_ANY: u16 = 255;
const PRIVATE_QTYPE: u16 = 65280;
const PRIVATE_QCLASS: u16 = 65280;
let questions = [
(DNS_TYPE_A, DNS_CLASS_IN),
(DNS_TYPE_AAAA, DNS_CLASS_CH),
(DNS_TYPE_MX, DNS_CLASS_HS),
(DNS_TYPE_TXT, DNS_CLASS_NONE),
(QTYPE_ANY, DNS_CLASS_ANY),
(PRIVATE_QTYPE, PRIVATE_QCLASS),
];
let mut dns = Dns::new().rd(true);
for (index, (qtype, qclass)) in questions.iter().enumerate() {
let name = format!("q{index}.example.com.");
dns = dns.question(DnsQuestion::new(name, *qtype).qclass(*qclass));
}
let bytes = (Ipv4::new()
.src(Ipv4Addr::new(192, 0, 2, 10))
.dst(Ipv4Addr::new(198, 51, 100, 53))
/ Udp::new().sport(53001).dport(53)
/ dns)
.compile()
.unwrap();
let decoded = Packet::decode_from_l3(NetworkLayer::Ipv4, bytes.as_bytes()).unwrap();
let header = decoded.layer::<Dns>().unwrap();
assert_eq!(header.questions().len(), questions.len());
for (index, (qtype, qclass)) in questions.iter().enumerate() {
let question = &header.questions()[index];
assert_eq!(question.name(), format!("q{index}.example.com."));
assert_eq!(question.question_type(), *qtype);
assert_eq!(question.question_class(), *qclass);
}
assert_eq!(header.questions()[5].question_type(), PRIVATE_QTYPE);
assert_eq!(header.questions()[5].question_class(), PRIVATE_QCLASS);
assert_eq!(decoded.compile().unwrap(), bytes);
}
/// Known TYPE codepoints map to their mnemonics; an unassigned one maps to `None`.
#[test]
fn type_names_cover_source_backed_codepoints() {
assert_eq!(dns_type_name(DNS_TYPE_A), Some("A"));
assert_eq!(dns_type_name(DNS_TYPE_SOA), Some("SOA"));
assert_eq!(dns_type_name(DNS_TYPE_SRV), Some("SRV"));
assert_eq!(dns_type_name(DNS_TYPE_HTTPS), Some("HTTPS"));
assert_eq!(dns_type_name(60000), None);
}
}
#[cfg(test)]
mod dns_malformed {
    //! Negative-path tests: malformed RDATA and names must be rejected with the
    //! specific error variant and field/context string the decoder reports.
    use super::{
        decode_dns_name_typed, decode_record_data, DnsRecordData, DNS_MAX_LABEL_LEN, DNS_TYPE_A,
        DNS_TYPE_AAAA, DNS_TYPE_DNSKEY, DNS_TYPE_DS, DNS_TYPE_NSEC, DNS_TYPE_NSEC3, DNS_TYPE_RRSIG,
        DNS_TYPE_SOA, DNS_TYPE_SRV, DNS_TYPE_SVCB,
    };
    use crate::error::CrafterError;

    /// Decodes all of `rdata` as `record_type`. Asserts the result is
    /// `BufferTooShort` and that its `context` equals `field`.
    fn assert_too_short(record_type: u16, rdata: &[u8], field: &str) {
        let outcome = decode_record_data(record_type, rdata, 0, rdata.len());
        match outcome {
            Err(CrafterError::BufferTooShort { context, .. }) => assert_eq!(
                context, field,
                "type {record_type:#x} expected buffer-too-short context {field}"
            ),
            unexpected => {
                panic!("type {record_type:#x} expected buffer-too-short {field}, got {unexpected:?}")
            }
        }
    }

    /// Decodes all of `rdata` as `record_type`. Asserts the result is
    /// `InvalidFieldValue` and that its `field` equals `field`.
    fn assert_invalid(record_type: u16, rdata: &[u8], field: &str) {
        let outcome = decode_record_data(record_type, rdata, 0, rdata.len());
        match outcome {
            Err(CrafterError::InvalidFieldValue { field: reported, .. }) => assert_eq!(
                reported, field,
                "type {record_type:#x} expected invalid-field-value {field}"
            ),
            unexpected => {
                panic!("type {record_type:#x} expected invalid-field-value {field}, got {unexpected:?}")
            }
        }
    }

    /// Asserts that decoding `wire` as a name at offset 0 fails with an
    /// `InvalidFieldValue` on `dns.name`. `what` names the expected failure
    /// in the panic message.
    fn assert_name_rejected(wire: &[u8], what: &str) {
        match decode_dns_name_typed(wire, 0) {
            Err(CrafterError::InvalidFieldValue { field, .. }) => assert_eq!(field, "dns.name"),
            other => panic!("expected {what} rejection, got {other:?}"),
        }
    }

    #[test]
    fn fixed_length_records_reject_wrong_rdlength() {
        // One byte short and one byte long for each fixed-size address record.
        let cases: [(u16, &[u8], &str); 4] = [
            (DNS_TYPE_A, &[192, 0, 2], "dns.a.rdlength"),
            (DNS_TYPE_A, &[192, 0, 2, 1, 9], "dns.a.rdlength"),
            (DNS_TYPE_AAAA, &[0u8; 15], "dns.aaaa.rdlength"),
            (DNS_TYPE_AAAA, &[0u8; 17], "dns.aaaa.rdlength"),
        ];
        for (record_type, rdata, field) in cases {
            assert_invalid(record_type, rdata, field);
        }
    }

    #[test]
    fn dnssec_fixed_headers_reject_truncation() {
        // Each all-zero input is one byte shorter than the record's fixed header
        // (DS/DNSKEY: 4 bytes, RRSIG: 18 bytes before the signer name).
        for (record_type, len, context) in [
            (DNS_TYPE_DS, 3usize, "dns.ds"),
            (DNS_TYPE_DNSKEY, 3, "dns.dnskey"),
            (DNS_TYPE_RRSIG, 17, "dns.rrsig"),
        ] {
            assert_too_short(record_type, &vec![0u8; len], context);
        }
    }

    #[test]
    fn soa_rejects_wrong_fixed_tail_length() {
        // MNAME "a." and a root RNAME, followed by a numeric tail one byte shorter
        // and then one byte longer than the 20 bytes SOA requires.
        for tail_len in [19usize, 21] {
            let mut rdata = vec![1u8, b'a', 0, 0];
            rdata.resize(rdata.len() + tail_len, 0);
            assert_invalid(DNS_TYPE_SOA, &rdata, "dns.soa.rdlength");
        }
    }

    #[test]
    fn srv_rejects_short_header_and_trailing_bytes() {
        // Five bytes cannot hold priority, weight and port (2 bytes each).
        assert_too_short(DNS_TYPE_SRV, &[0u8; 5], "dns.srv");
        // Priority 1, weight 2, port 0x1388, root target, then one stray byte.
        let trailing = [0u8, 1, 0, 2, 0x13, 0x88, 0, 0xff];
        assert_invalid(DNS_TYPE_SRV, &trailing, "dns.srv.target");
    }

    #[test]
    fn nsec3_rejects_hash_length_overrun() {
        // Alg 1, flags 0, iterations 10, empty salt, hash length 20; only 2 hash bytes follow.
        let rdata = [1u8, 0, 0, 10, 0, 20, 0x11, 0x22];
        assert_too_short(DNS_TYPE_NSEC3, &rdata, "dns.nsec3.hash");
    }

    #[test]
    fn nsec_rejects_non_minimal_trailing_zero_bitmap() {
        // Root next-name, window 0, 2-byte bitmap whose final byte is all zero.
        assert_invalid(DNS_TYPE_NSEC, &[0u8, 0x00, 2, 0x40, 0x00], "dns.nsec.bitmap");
    }

    #[test]
    fn svcb_rejects_out_of_order_and_overrun_params() {
        // Priority 1, root target, then key 3 (empty) before key 1 (empty): keys descend.
        let descending_keys = [0u8, 1, 0, 0, 3, 0, 0, 0, 1, 0, 0];
        assert_invalid(DNS_TYPE_SVCB, &descending_keys, "dns.svcb.params");
        // Priority 1, root target, key 3 declares 8 value bytes but only 2 follow.
        let truncated_value = [0u8, 1, 0, 0, 3, 0, 8, 0, 0];
        assert_too_short(DNS_TYPE_SVCB, &truncated_value, "dns.svcb.params");
    }

    #[test]
    fn unknown_record_type_decodes_as_raw_not_rejected() {
        const UNKNOWN_TYPE: u16 = 0xfff0;
        let payload = [0xdeu8, 0xad, 0xbe, 0xef];
        let data = decode_record_data(UNKNOWN_TYPE, &payload, 0, payload.len()).unwrap();
        assert_eq!(data, DnsRecordData::Raw(payload.to_vec()));
        // A zero-length RDATA for an unknown type is also accepted as empty raw bytes.
        let empty = decode_record_data(UNKNOWN_TYPE, &[], 0, 0).unwrap();
        assert_eq!(empty, DnsRecordData::Raw(Vec::new()));
    }

    #[test]
    fn name_decoder_rejects_reserved_marker_and_length_overrun() {
        // 0x40 carries the reserved 0b01 label-type prefix.
        assert_name_rejected(&[0x40], "reserved-marker");

        // Four maximal labels plus the root terminator: 4 * (1 + 63) + 1 = 257 octets,
        // which exceeds the 255-octet limit on a whole name.
        let mut overrun = Vec::with_capacity(4 * (DNS_MAX_LABEL_LEN + 1) + 1);
        for _ in 0..4 {
            overrun.push(DNS_MAX_LABEL_LEN as u8);
            overrun.resize(overrun.len() + DNS_MAX_LABEL_LEN, b'a');
        }
        overrun.push(0);
        assert_name_rejected(&overrun, "full-name overrun");
    }

    #[test]
    fn label_at_63_octet_boundary_is_accepted() {
        // A single label of exactly DNS_MAX_LABEL_LEN octets, then the root terminator.
        let label = vec![b'a'; DNS_MAX_LABEL_LEN];
        let mut wire = vec![DNS_MAX_LABEL_LEN as u8];
        wire.extend_from_slice(&label);
        wire.push(0);
        let (name, consumed) = decode_dns_name_typed(&wire, 0).unwrap();
        assert_eq!(consumed, wire.len());
        assert_eq!(name.labels(), &[label]);
    }
}
#[cfg(test)]
mod dns_golden_bytes {
    //! Byte-exact checks against checked-in packet captures.
    use super::{Dns, DnsQuestion, DNS_TYPE_A};
    use crate::{Ipv4, LinkType, NetworkLayer, Packet, Udp};
    use core::net::Ipv4Addr;

    /// Reference IPv4/UDP query for `example.com.`, type A, DNS id 0xbeef.
    const DNS_QUERY_FIXTURE: &[u8] = fixture_bytes!("bytes/ipv4-udp-dns-query-example-com.bin");

    #[test]
    fn dns_query_matches_golden_bytes() {
        // Rebuild the captured query layer by layer; the compiled bytes must match exactly.
        let ip = Ipv4::new()
            .src(Ipv4Addr::new(192, 0, 2, 10))
            .dst(Ipv4Addr::new(198, 51, 100, 53))
            .id(0x1237)
            .ttl(61);
        let udp = Udp::new().sport(53001).dport(53);
        let dns = Dns::new()
            .id(0xbeef)
            .question(DnsQuestion::new("example.com.", DNS_TYPE_A));
        let compiled = (ip / udp / dns).compile().unwrap();
        assert_eq!(compiled.as_bytes(), DNS_QUERY_FIXTURE);
    }

    #[test]
    fn dns_query_fixture_decodes_to_typed_layer() {
        // Decoding from L3 must surface a typed Dns layer and re-encode losslessly.
        let decoded = Packet::decode_from_l3(NetworkLayer::Ipv4, DNS_QUERY_FIXTURE).unwrap();
        let dns = decoded.layer::<Dns>().unwrap();
        assert_eq!(dns.id_value(), 0xbeef);
        let first = &dns.questions()[0];
        assert_eq!(first.name(), "example.com.");
        assert_eq!(first.question_type(), DNS_TYPE_A);
        assert_eq!(decoded.compile().unwrap().as_bytes(), DNS_QUERY_FIXTURE);
    }

    #[test]
    fn non_dns_udp_payload_stays_raw_even_when_decoding_from_link() {
        // A UDP payload in this capture is not DNS, so no Dns layer may be inferred for it.
        let decoded = Packet::decode_from_link(
            LinkType::Ethernet,
            fixture_bytes!("bytes/ethernet-vlan-ipv4-udp-raw.bin"),
        )
        .unwrap();
        assert!(decoded.layer::<Dns>().is_none());
    }
}