use core::any::Any;
use core::ops::Div;
use crate::endian::{read_u16_be, read_u32_be};
use crate::error::{CrafterError, Result};
use crate::field::Field;
use crate::packet::{IntoPacket, Layer, LayerContext, Packet};
use crate::protocols::transport::Udp;
mod constants;
mod dns_sd;
mod dnssec;
mod edns;
mod name;
mod rdata;
mod record;
mod svcb;
use rdata::decode_record_data;
pub use dns_sd::{
dns_sd_instance_name, dns_sd_instance_name_from_labels, dns_sd_service_enumeration_name,
dns_sd_service_name, dns_sd_service_name_from_labels, dns_sd_subtype_name,
dns_sd_subtype_name_from_labels, dns_sd_tcp_instance_name, dns_sd_tcp_service_name,
dns_sd_udp_instance_name, dns_sd_udp_service_name,
};
pub use dnssec::DnsTypeBitmaps;
pub use edns::{edns_option_code_name, EdnsOption};
pub use name::{decode_dns_name, decode_dns_name_typed, DnsName};
pub use rdata::DnsRecordData;
pub use record::{DnsQuestion, DnsRecord};
pub use svcb::{svcb_param_key_name, SvcParam, SvcParams};
pub use constants::{
DNS_CLASS_ANY, DNS_CLASS_CH, DNS_CLASS_HS, DNS_CLASS_IN, DNS_CLASS_NONE,
DNS_EDNS_DEFAULT_UDP_PAYLOAD_SIZE, DNS_EDNS_FLAG_DO, DNS_EDNS_OPTION_CLIENT_SUBNET,
DNS_EDNS_OPTION_COOKIE, DNS_EDNS_OPTION_DAU, DNS_EDNS_OPTION_DHU, DNS_EDNS_OPTION_EXPIRE,
DNS_EDNS_OPTION_EXTENDED_ERROR, DNS_EDNS_OPTION_N3U, DNS_EDNS_OPTION_NSID,
DNS_EDNS_OPTION_PADDING, DNS_EDNS_OPTION_TCP_KEEPALIVE, DNS_FLAG_AUTHENTIC_DATA,
DNS_FLAG_AUTHORITATIVE, DNS_FLAG_CHECKING_DISABLED, DNS_FLAG_QR_RESPONSE,
DNS_FLAG_RECURSION_AVAILABLE, DNS_FLAG_RECURSION_DESIRED, DNS_FLAG_TRUNCATED, DNS_HEADER_LEN,
DNS_OPCODE_DSO, DNS_OPCODE_IQUERY, DNS_OPCODE_NOTIFY, DNS_OPCODE_QUERY, DNS_OPCODE_STATUS,
DNS_OPCODE_UPDATE, DNS_PORT, DNS_RCODE_DSOTYPENI, DNS_RCODE_FORMERR, DNS_RCODE_NOERROR,
DNS_RCODE_NOTAUTH, DNS_RCODE_NOTIMP, DNS_RCODE_NOTZONE, DNS_RCODE_NXDOMAIN, DNS_RCODE_NXRRSET,
DNS_RCODE_REFUSED, DNS_RCODE_SERVFAIL, DNS_RCODE_YXDOMAIN, DNS_RCODE_YXRRSET,
DNS_SD_DEFAULT_DOMAIN, DNS_SD_SERVICE_ENUMERATION_NAME, DNS_SVCB_KEY_ALPN,
DNS_SVCB_KEY_DOHPATH, DNS_SVCB_KEY_ECH, DNS_SVCB_KEY_IPV4HINT, DNS_SVCB_KEY_IPV6HINT,
DNS_SVCB_KEY_MANDATORY, DNS_SVCB_KEY_NO_DEFAULT_ALPN, DNS_SVCB_KEY_PORT, DNS_TYPE_A,
DNS_TYPE_AAAA, DNS_TYPE_CNAME, DNS_TYPE_DNSKEY, DNS_TYPE_DS, DNS_TYPE_HTTPS, DNS_TYPE_MX,
DNS_TYPE_NS, DNS_TYPE_NSEC, DNS_TYPE_NSEC3, DNS_TYPE_NSEC3PARAM, DNS_TYPE_OPT, DNS_TYPE_PTR,
DNS_TYPE_RRSIG, DNS_TYPE_SOA, DNS_TYPE_SRV, DNS_TYPE_SVCB, DNS_TYPE_TLSA, DNS_TYPE_TXT,
MDNS_CLASS_BIT, MDNS_CLASS_MASK, MDNS_GOODBYE_TTL, MDNS_IPV4_ETHERNET_MULTICAST,
MDNS_IPV4_MULTICAST, MDNS_IPV6_ETHERNET_MULTICAST, MDNS_IPV6_LINK_LOCAL_MULTICAST,
MDNS_IPV6_MULTICAST, MDNS_PORT, MDNS_RESPONSE_HOP_LIMIT, MDNS_RESPONSE_TTL,
};
pub mod mdns {
    //! Multicast DNS (RFC 6762) and DNS-SD (RFC 6763) convenience builders.
    //!
    //! Everything here is a thin wrapper over [`Dns`], [`DnsRecord`], and the
    //! Ethernet/IPv4/IPv6/UDP layer builders, pre-filled with the mDNS port, multicast
    //! addresses, and TTL/hop-limit constants that this module re-exports.
    use core::net::{Ipv4Addr, Ipv6Addr};
    use crate::mac::MacAddr;
    use crate::packet::Packet;
    use crate::protocols::ip::v4::Ipv4;
    use crate::protocols::ip::v6::Ipv6;
    use crate::protocols::link::{Ethernet, ETHERTYPE_IPV4, ETHERTYPE_IPV6};
    use crate::protocols::transport::Udp;
    use super::{
        Dns, DnsName, DnsQuestion, DnsRecord, DnsRecordData, DNS_CLASS_IN, DNS_TYPE_PTR,
        DNS_TYPE_TXT,
    };
    pub use super::{
        dns_sd_instance_name, dns_sd_instance_name_from_labels, dns_sd_service_enumeration_name,
        dns_sd_service_name, dns_sd_service_name_from_labels, dns_sd_subtype_name,
        dns_sd_subtype_name_from_labels, dns_sd_tcp_instance_name, dns_sd_tcp_service_name,
        dns_sd_udp_instance_name, dns_sd_udp_service_name,
    };
    pub use super::{
        DNS_SD_DEFAULT_DOMAIN, DNS_SD_SERVICE_ENUMERATION_NAME, MDNS_CLASS_BIT, MDNS_CLASS_MASK,
        MDNS_GOODBYE_TTL, MDNS_IPV4_ETHERNET_MULTICAST, MDNS_IPV4_MULTICAST,
        MDNS_IPV6_ETHERNET_MULTICAST, MDNS_IPV6_LINK_LOCAL_MULTICAST, MDNS_IPV6_MULTICAST,
        MDNS_PORT, MDNS_RESPONSE_HOP_LIMIT, MDNS_RESPONSE_TTL,
    };

    /// Builds an mDNS query (id 0, flags 0 — no RD bit) carrying `question`.
    pub fn query(question: DnsQuestion) -> Dns {
        Dns::mdns_query(question)
    }

    /// Builds an mDNS query for `name` / `question_type` using `DnsQuestion::new`.
    pub fn query_for(name: impl Into<DnsName>, question_type: u16) -> Dns {
        Dns::mdns_query_for(name, question_type)
    }

    /// Builds an empty mDNS response with the QR and AA flags set.
    pub fn response() -> Dns {
        Dns::mdns_response()
    }

    /// Builds an mDNS response whose answer section is `records`, in order and unmodified.
    pub fn response_with_answers(records: impl IntoIterator<Item = DnsRecord>) -> Dns {
        Dns::mdns_response_with_answers(records)
    }

    /// Builds an mDNS query with `question` plus `known_answers` in the answer section
    /// (known-answer suppression, RFC 6762 §7.1).
    pub fn known_answer_query(
        question: DnsQuestion,
        known_answers: impl IntoIterator<Item = DnsRecord>,
    ) -> Dns {
        Dns::mdns_known_answer_query(question, known_answers)
    }

    /// Like [`known_answer_query`], but with the TC flag set to signal that more
    /// known answers follow in subsequent packets.
    pub fn continued_known_answer_query(
        question: DnsQuestion,
        known_answers: impl IntoIterator<Item = DnsRecord>,
    ) -> Dns {
        Dns::mdns_continued_known_answer_query(question, known_answers)
    }

    /// Builds an mDNS response with `answers` in the answer section and `additionals`
    /// in the additional section, both unmodified.
    pub fn response_with_additionals(
        answers: impl IntoIterator<Item = DnsRecord>,
        additionals: impl IntoIterator<Item = DnsRecord>,
    ) -> Dns {
        Dns::mdns_response_with_additionals(answers, additionals)
    }

    /// Builds an unsolicited mDNS announcement: every record goes into the answer
    /// section with its cache-flush bit set.
    pub fn announcement(records: impl IntoIterator<Item = DnsRecord>) -> Dns {
        Dns::mdns_announcement(records)
    }

    /// Builds an announcement (see [`announcement`]) and also appends `additionals`,
    /// each with its cache-flush bit set.
    pub fn announce_response(
        records: impl IntoIterator<Item = DnsRecord>,
        additionals: impl IntoIterator<Item = DnsRecord>,
    ) -> Dns {
        Dns::mdns_announce_response(records, additionals)
    }

    /// Builds an mDNS response whose answers are `records` converted with
    /// `DnsRecord::mdns_goodbye` (presumably TTL `MDNS_GOODBYE_TTL` — confirm in record.rs).
    pub fn goodbye_response(records: impl IntoIterator<Item = DnsRecord>) -> Dns {
        Dns::mdns_goodbye_response(records)
    }

    /// Builds a probe query for `name` with QTYPE ANY (255) and the QU bit set.
    pub fn probe(name: impl Into<DnsName>) -> Dns {
        Dns::mdns_probe(name)
    }

    /// Builds a probe query for `name` / `question_type` with the QU bit set.
    pub fn probe_for(name: impl Into<DnsName>, question_type: u16) -> Dns {
        Dns::mdns_probe_for(name, question_type)
    }

    /// Builds a probe query: `question` (QU bit forced on) plus the `proposed` records
    /// in the authority section, as used for simultaneous-probe tiebreaking.
    pub fn probe_with_authorities(
        question: DnsQuestion,
        proposed: impl IntoIterator<Item = DnsRecord>,
    ) -> Dns {
        Dns::mdns_probe_with_authorities(question, proposed)
    }

    /// Returns `record` with its mDNS cache-flush bit set.
    pub fn cache_flush(record: DnsRecord) -> DnsRecord {
        record.mdns_cache_flush(true)
    }

    /// Returns `record` with its mDNS cache-flush bit cleared (a shared record).
    pub fn shared_record(record: DnsRecord) -> DnsRecord {
        record.mdns_cache_flush(false)
    }

    /// Returns `record` converted with `DnsRecord::mdns_goodbye`.
    pub fn goodbye(record: DnsRecord) -> DnsRecord {
        record.mdns_goodbye()
    }

    /// Builds a DNS-SD browse PTR record (class IN): `service_type` -> `service_instance`.
    pub fn service_ptr(
        service_type: impl Into<DnsName>,
        service_instance: impl Into<DnsName>,
        ttl: u32,
    ) -> DnsRecord {
        DnsRecord::new(
            service_type,
            DNS_TYPE_PTR,
            DNS_CLASS_IN,
            ttl,
            DnsRecordData::name(service_instance),
        )
    }

    /// Builds a DNS-SD subtype PTR record (class IN): `subtype_name` -> `service_instance`.
    ///
    /// Wire-identical in shape to [`service_ptr`]; only the owner name differs.
    pub fn subtype_ptr(
        subtype_name: impl Into<DnsName>,
        service_instance: impl Into<DnsName>,
        ttl: u32,
    ) -> DnsRecord {
        service_ptr(subtype_name, service_instance, ttl)
    }

    /// Builds an SRV record for `service_instance` with priority 0 and weight 0.
    pub fn srv(
        service_instance: impl Into<DnsName>,
        target: impl Into<DnsName>,
        port: u16,
        ttl: u32,
    ) -> DnsRecord {
        srv_with_priority(service_instance, target, 0, 0, port, ttl)
    }

    /// Builds an SRV record for `service_instance` pointing at `target:port`.
    pub fn srv_with_priority(
        service_instance: impl Into<DnsName>,
        target: impl Into<DnsName>,
        priority: u16,
        weight: u16,
        port: u16,
        ttl: u32,
    ) -> DnsRecord {
        DnsRecord::srv(service_instance, ttl, priority, weight, port, target)
    }

    /// Builds a DNS-SD TXT record (class IN) for `service_instance`.
    ///
    /// Each item of `strings` becomes one character-string, in order. RFC 6763 §6.1
    /// forbids emitting a TXT record with zero strings ("MUST contain a single zero
    /// byte"), so an empty `strings` yields a record holding one zero-length string.
    pub fn txt(
        service_instance: impl Into<DnsName>,
        strings: impl IntoIterator<Item = impl AsRef<[u8]>>,
        ttl: u32,
    ) -> DnsRecord {
        let mut items = strings
            .into_iter()
            .map(|string| string.as_ref().to_vec())
            .peekable();
        let strings = if items.peek().is_some() {
            items.collect()
        } else {
            // No TXT data: emit the mandatory single empty string instead of empty RDATA.
            core::iter::once(Vec::<u8>::new()).collect()
        };
        DnsRecord::new(
            service_instance,
            DNS_TYPE_TXT,
            DNS_CLASS_IN,
            ttl,
            DnsRecordData::Txt(strings),
        )
    }

    /// Builds an A record for `host`.
    pub fn a(host: impl Into<DnsName>, address: Ipv4Addr, ttl: u32) -> DnsRecord {
        DnsRecord::a(host, address, ttl)
    }

    /// Builds an AAAA record for `host`.
    pub fn aaaa(host: impl Into<DnsName>, address: Ipv6Addr, ttl: u32) -> DnsRecord {
        DnsRecord::aaaa(host, address, ttl)
    }

    /// UDP header with both source and destination port set to `MDNS_PORT`.
    pub fn mdns_udp() -> Udp {
        Udp::new()
            .source_port(MDNS_PORT)
            .destination_port(MDNS_PORT)
    }

    /// Alias for [`mdns_udp`].
    pub fn udp() -> Udp {
        mdns_udp()
    }

    /// UDP header with caller-chosen ports, intended for unicast replies that do not
    /// use the mDNS port on both ends.
    pub fn udp_unicast_reply(source_port: u16, destination_port: u16) -> Udp {
        Udp::new()
            .source_port(source_port)
            .destination_port(destination_port)
    }

    /// IPv4 header from `source` to `MDNS_IPV4_MULTICAST` with TTL `MDNS_RESPONSE_TTL`.
    pub fn mdns_ipv4(source: Ipv4Addr) -> Ipv4 {
        Ipv4::new()
            .src(source)
            .dst(MDNS_IPV4_MULTICAST)
            .ttl(MDNS_RESPONSE_TTL)
    }

    /// Alias for [`mdns_ipv4`].
    pub fn ipv4_multicast(source: Ipv4Addr) -> Ipv4 {
        mdns_ipv4(source)
    }

    /// Alias for [`mdns_ipv4`].
    pub fn ipv4_response(source: Ipv4Addr) -> Ipv4 {
        mdns_ipv4(source)
    }

    /// IPv6 header from `source` to `MDNS_IPV6_LINK_LOCAL_MULTICAST` with hop limit
    /// `MDNS_RESPONSE_HOP_LIMIT`.
    pub fn mdns_ipv6(source: Ipv6Addr) -> Ipv6 {
        Ipv6::new()
            .src(source)
            .dst(MDNS_IPV6_LINK_LOCAL_MULTICAST)
            .hop_limit(MDNS_RESPONSE_HOP_LIMIT)
    }

    /// Alias for [`mdns_ipv6`].
    pub fn ipv6_multicast(source: Ipv6Addr) -> Ipv6 {
        mdns_ipv6(source)
    }

    /// Alias for [`mdns_ipv6`].
    pub fn ipv6_response(source: Ipv6Addr) -> Ipv6 {
        mdns_ipv6(source)
    }

    /// Ethernet header from `source` to `MDNS_IPV4_ETHERNET_MULTICAST`, EtherType IPv4.
    pub fn mdns_ethernet_ipv4(source: MacAddr) -> Ethernet {
        Ethernet::new()
            .src(source)
            .dst(MDNS_IPV4_ETHERNET_MULTICAST)
            .ethertype(ETHERTYPE_IPV4)
    }

    /// Alias for [`mdns_ethernet_ipv4`].
    pub fn ethernet_ipv4_multicast(source: MacAddr) -> Ethernet {
        mdns_ethernet_ipv4(source)
    }

    /// Ethernet header from `source` to `MDNS_IPV6_ETHERNET_MULTICAST`, EtherType IPv6.
    pub fn mdns_ethernet_ipv6(source: MacAddr) -> Ethernet {
        Ethernet::new()
            .src(source)
            .dst(MDNS_IPV6_ETHERNET_MULTICAST)
            .ethertype(ETHERTYPE_IPV6)
    }

    /// Alias for [`mdns_ethernet_ipv6`].
    pub fn ethernet_ipv6_multicast(source: MacAddr) -> Ethernet {
        mdns_ethernet_ipv6(source)
    }

    /// Stacks IPv4 (see [`mdns_ipv4`]) / mDNS UDP / `dns`.
    pub fn mdns_ipv4_packet(source: Ipv4Addr, dns: Dns) -> Packet {
        mdns_ipv4(source) / mdns_udp() / dns
    }

    /// Alias for [`mdns_ipv4_packet`].
    pub fn ipv4_packet(source: Ipv4Addr, dns: Dns) -> Packet {
        mdns_ipv4_packet(source, dns)
    }

    /// Stacks IPv6 (see [`mdns_ipv6`]) / mDNS UDP / `dns`.
    pub fn mdns_ipv6_packet(source: Ipv6Addr, dns: Dns) -> Packet {
        mdns_ipv6(source) / mdns_udp() / dns
    }

    /// Alias for [`mdns_ipv6_packet`].
    pub fn ipv6_packet(source: Ipv6Addr, dns: Dns) -> Packet {
        mdns_ipv6_packet(source, dns)
    }

    /// Stacks Ethernet (IPv4 multicast MAC) in front of [`mdns_ipv4_packet`].
    pub fn mdns_ethernet_ipv4_packet(source_mac: MacAddr, source_ip: Ipv4Addr, dns: Dns) -> Packet {
        mdns_ethernet_ipv4(source_mac) / mdns_ipv4_packet(source_ip, dns)
    }

    /// Alias for [`mdns_ethernet_ipv4_packet`].
    pub fn ethernet_ipv4_packet(source_mac: MacAddr, source_ip: Ipv4Addr, dns: Dns) -> Packet {
        mdns_ethernet_ipv4_packet(source_mac, source_ip, dns)
    }

    /// Stacks Ethernet (IPv6 multicast MAC) in front of [`mdns_ipv6_packet`].
    pub fn mdns_ethernet_ipv6_packet(source_mac: MacAddr, source_ip: Ipv6Addr, dns: Dns) -> Packet {
        mdns_ethernet_ipv6(source_mac) / mdns_ipv6_packet(source_ip, dns)
    }

    /// Alias for [`mdns_ethernet_ipv6_packet`].
    pub fn ethernet_ipv6_packet(source_mac: MacAddr, source_ip: Ipv6Addr, dns: Dns) -> Packet {
        mdns_ethernet_ipv6_packet(source_mac, source_ip, dns)
    }
}
// Name compression (RFC 1035 §4.1.4): a length byte whose two high bits are both set
// marks a compression pointer rather than a label length.
const DNS_NAME_POINTER_MASK: u8 = 0b1100_0000;
const DNS_NAME_POINTER_TAG: u8 = 0b1100_0000;
// RFC 1035 §2.3.4 size limits: 63 octets per label, 255 octets per wire-format name.
const DNS_MAX_LABEL_LEN: usize = 63;
const DNS_MAX_NAME_WIRE_LEN: usize = 255;
// QTYPE "*" / ANY (RFC 1035 §3.2.3); used as the default mDNS probe question type.
const DNS_QTYPE_ANY: u16 = 255;
// Header flags word layout: OPCODE occupies bits 11..=14, RCODE the low 4 bits.
const DNS_OPCODE_MASK: u16 = 0b0111_1000_0000_0000;
const DNS_OPCODE_SHIFT: u16 = 11;
const DNS_RCODE_MASK: u16 = 0b0000_0000_0000_1111;
// EDNS OPT TTL layout (RFC 6891 §6.1.3): extended RCODE (8 bits) | VERSION (8 bits) |
// DO + Z flags (16 bits).
const DNS_EDNS_EXTENDED_RCODE_SHIFT: u32 = 24;
const DNS_EDNS_VERSION_SHIFT: u32 = 16;
const DNS_EDNS_FLAGS_MASK: u32 = 0x0000_FFFF;
/// Expands, inside an `impl Layer for T` block, to the four object-safety helpers that
/// every `Clone + 'static` layer implements identically: boxed cloning plus `Any`
/// access by reference, by mutable reference, and by owned box.
///
/// The type argument is accepted for call-site readability only; the expansion uses `Self`.
macro_rules! impl_layer_object {
    ($layer:ty) => {
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn as_any_mut(&mut self) -> &mut dyn Any {
            self
        }
        fn into_any(self: Box<Self>) -> Box<dyn Any> {
            self
        }
        fn clone_layer(&self) -> Box<dyn Layer> {
            let copy: Self = Clone::clone(self);
            Box::new(copy)
        }
    };
}
/// Implements `layer / rhs` for a layer type: the layer becomes the first layer of a new
/// `Packet`, and anything convertible into a packet is concatenated after it.
macro_rules! impl_layer_div {
    ($layer:ty) => {
        impl<Rhs: IntoPacket> Div<Rhs> for $layer {
            type Output = Packet;
            fn div(self, rhs: Rhs) -> Packet {
                let head = Packet::from_layer(self);
                head.concat(rhs)
            }
        }
    };
}
/// A DNS message layer: the 12-byte header plus the question, answer, authority and
/// additional sections.
///
/// The same type is used for multicast DNS. `multicast_dns` only changes how the layer is
/// summarized and inspected; it is not encoded and is ignored by `PartialEq`.
#[derive(Debug, Clone)]
pub struct Dns {
    /// Transaction ID (first two header bytes on the wire).
    id: Field<u16>,
    /// Raw 16-bit header flags word (QR, OPCODE, AA/TC/RD/RA/AD/CD bits and the 4-bit RCODE).
    flags: Field<u16>,
    /// Question section; its length is the encoded QDCOUNT.
    questions: Vec<DnsQuestion>,
    /// Answer section; its length is the encoded ANCOUNT.
    answers: Vec<DnsRecord>,
    /// Authority section; its length is the encoded NSCOUNT.
    authorities: Vec<DnsRecord>,
    /// Additional section; its length is the encoded ARCOUNT.
    additionals: Vec<DnsRecord>,
    /// True for messages built by the mDNS constructors or decoded from UDP on the mDNS
    /// port; forces the mDNS summary/inspection view.
    multicast_dns: bool,
}
/// Two messages are equal when their header fields and all four sections match.
///
/// The `multicast_dns` context flag is deliberately excluded: it only affects presentation,
/// so a decoded message compares equal to the same message built by hand.
impl PartialEq for Dns {
    fn eq(&self, other: &Self) -> bool {
        let lhs = (
            &self.id,
            &self.flags,
            &self.questions,
            &self.answers,
            &self.authorities,
            &self.additionals,
        );
        let rhs = (
            &other.id,
            &other.flags,
            &other.questions,
            &other.answers,
            &other.authorities,
            &other.additionals,
        );
        lhs == rhs
    }
}
impl Eq for Dns {}
impl Dns {
    /// Creates an empty standard DNS message: id 0, flags RD, no records, no mDNS context.
    pub fn new() -> Self {
        Self {
            id: Field::defaulted(0),
            flags: Field::defaulted(DNS_FLAG_RECURSION_DESIRED),
            questions: Vec::new(),
            answers: Vec::new(),
            authorities: Vec::new(),
            additionals: Vec::new(),
            multicast_dns: false,
        }
    }

    /// Creates an empty message with id 0 and default `flags`, marked as mDNS so it is
    /// summarized as mDNS even before any class bits are set.
    fn mdns_message(flags: u16) -> Self {
        Self {
            id: Field::defaulted(0),
            flags: Field::defaulted(flags),
            questions: Vec::new(),
            answers: Vec::new(),
            authorities: Vec::new(),
            additionals: Vec::new(),
            multicast_dns: true,
        }
    }

    /// Builds a standard (RD) query with one question for `name` / `question_type`.
    pub fn query(name: impl Into<DnsName>, question_type: u16) -> Self {
        Self::new().question(DnsQuestion::new(name, question_type))
    }

    /// Builds a standard A query for `name`.
    pub fn a_query(name: impl Into<DnsName>) -> Self {
        Self::query(name, DNS_TYPE_A)
    }

    /// Builds a standard AAAA query for `name`.
    pub fn aaaa_query(name: impl Into<DnsName>) -> Self {
        Self::query(name, DNS_TYPE_AAAA)
    }

    /// Builds an mDNS query (id 0, flags 0) carrying `question`.
    pub fn mdns_query(question: DnsQuestion) -> Self {
        Self::mdns_message(0).question(question)
    }

    /// Builds an mDNS query for `name` / `question_type`.
    pub fn mdns_query_for(name: impl Into<DnsName>, question_type: u16) -> Self {
        Self::mdns_query(DnsQuestion::new(name, question_type))
    }

    /// Builds an empty mDNS response with QR and AA set.
    pub fn mdns_response() -> Self {
        Self::mdns_message(DNS_FLAG_QR_RESPONSE | DNS_FLAG_AUTHORITATIVE)
    }

    /// Builds an mDNS response with `records` appended, unmodified, to the answer section.
    pub fn mdns_response_with_answers(records: impl IntoIterator<Item = DnsRecord>) -> Self {
        records
            .into_iter()
            .fold(Self::mdns_response(), |dns, record| dns.answer(record))
    }

    /// Builds an mDNS query with `question` and `known_answers` in the answer section.
    pub fn mdns_known_answer_query(
        question: DnsQuestion,
        known_answers: impl IntoIterator<Item = DnsRecord>,
    ) -> Self {
        Self::mdns_query(question).mdns_known_answers(known_answers)
    }

    /// Like [`Dns::mdns_known_answer_query`], but with TC set to signal that more known
    /// answers follow (RFC 6762 §7.2).
    pub fn mdns_continued_known_answer_query(
        question: DnsQuestion,
        known_answers: impl IntoIterator<Item = DnsRecord>,
    ) -> Self {
        // OR the bit in rather than replacing the whole flags word, so the query's other
        // flags survive if the mDNS query defaults ever change.
        Self::mdns_known_answer_query(question, known_answers).set_flag(DNS_FLAG_TRUNCATED, true)
    }

    /// Builds an mDNS response with `answers` and then `additionals`, both unmodified.
    pub fn mdns_response_with_additionals(
        answers: impl IntoIterator<Item = DnsRecord>,
        additionals: impl IntoIterator<Item = DnsRecord>,
    ) -> Self {
        Self::mdns_response_with_answers(answers).mdns_additional_records(additionals)
    }

    /// Builds an mDNS announcement: every record is answered with cache-flush set.
    pub fn mdns_announcement(records: impl IntoIterator<Item = DnsRecord>) -> Self {
        records
            .into_iter()
            .fold(Self::mdns_response(), |dns, record| {
                dns.answer(record.mdns_cache_flush(true))
            })
    }

    /// Builds an announcement and appends `additionals`, each with cache-flush set.
    pub fn mdns_announce_response(
        records: impl IntoIterator<Item = DnsRecord>,
        additionals: impl IntoIterator<Item = DnsRecord>,
    ) -> Self {
        additionals
            .into_iter()
            .fold(Self::mdns_announcement(records), |dns, record| {
                dns.additional(record.mdns_cache_flush(true))
            })
    }

    /// Builds an mDNS response whose answers are `records` passed through
    /// `DnsRecord::mdns_goodbye`.
    pub fn mdns_goodbye_response(records: impl IntoIterator<Item = DnsRecord>) -> Self {
        records
            .into_iter()
            .fold(Self::mdns_response(), |dns, record| {
                dns.answer(record.mdns_goodbye())
            })
    }

    /// Builds a probe for `name` with QTYPE ANY and the QU bit set.
    pub fn mdns_probe(name: impl Into<DnsName>) -> Self {
        Self::mdns_probe_for(name, DNS_QTYPE_ANY)
    }

    /// Builds a probe for `name` / `question_type` with the QU bit set and no authorities.
    pub fn mdns_probe_for(name: impl Into<DnsName>, question_type: u16) -> Self {
        Self::mdns_probe_with_authorities(
            DnsQuestion::new(name, question_type).mdns_qu(true),
            core::iter::empty(),
        )
    }

    /// Builds a probe: flags 0, `question` with QU forced on, and `proposed` records in
    /// the authority section.
    pub fn mdns_probe_with_authorities(
        question: DnsQuestion,
        proposed: impl IntoIterator<Item = DnsRecord>,
    ) -> Self {
        proposed.into_iter().fold(
            Self::mdns_message(0).question(question.mdns_qu(true)),
            |dns, record| dns.authority(record),
        )
    }

    /// Sets the transaction ID.
    pub fn id(mut self, id: u16) -> Self {
        self.id.set_user(id);
        self
    }

    /// Replaces the entire 16-bit flags word.
    pub fn flags(mut self, flags: u16) -> Self {
        self.flags.set_user(flags);
        self
    }

    /// Sets or clears the QR (response) bit.
    pub fn response(self, enabled: bool) -> Self {
        self.set_flag(DNS_FLAG_QR_RESPONSE, enabled)
    }

    /// Sets or clears the AA (authoritative answer) bit.
    pub fn authoritative(self, enabled: bool) -> Self {
        self.set_flag(DNS_FLAG_AUTHORITATIVE, enabled)
    }

    /// Sets or clears the RD (recursion desired) bit.
    pub fn recursion_desired(self, enabled: bool) -> Self {
        self.set_flag(DNS_FLAG_RECURSION_DESIRED, enabled)
    }

    /// Short alias for [`Dns::recursion_desired`].
    pub fn rd(self, enabled: bool) -> Self {
        self.recursion_desired(enabled)
    }

    /// Sets or clears the RA (recursion available) bit.
    pub fn recursion_available(self, enabled: bool) -> Self {
        self.set_flag(DNS_FLAG_RECURSION_AVAILABLE, enabled)
    }

    /// Sets the 4-bit OPCODE; bits of `opcode` above the low four are discarded.
    pub fn opcode(mut self, opcode: u8) -> Self {
        let field = ((opcode as u16) << DNS_OPCODE_SHIFT) & DNS_OPCODE_MASK;
        let flags = (self.flags_value() & !DNS_OPCODE_MASK) | field;
        self.flags.set_user(flags);
        self
    }

    /// Sets the 4-bit header RCODE; bits of `rcode` above the low four are discarded
    /// (extended RCODEs live in the EDNS OPT record, not here).
    pub fn rcode(mut self, rcode: u8) -> Self {
        let flags = (self.flags_value() & !DNS_RCODE_MASK) | ((rcode as u16) & DNS_RCODE_MASK);
        self.flags.set_user(flags);
        self
    }

    /// Appends a question.
    pub fn question(mut self, question: DnsQuestion) -> Self {
        self.questions.push(question);
        self
    }

    /// Appends an answer record.
    pub fn answer(mut self, answer: DnsRecord) -> Self {
        self.answers.push(answer);
        self
    }

    /// Appends an mDNS known answer (it lives in the answer section of a query).
    pub fn mdns_known_answer(self, known_answer: DnsRecord) -> Self {
        self.answer(known_answer)
    }

    /// Appends several mDNS known answers, in order.
    pub fn mdns_known_answers(self, known_answers: impl IntoIterator<Item = DnsRecord>) -> Self {
        known_answers
            .into_iter()
            .fold(self, |dns, record| dns.mdns_known_answer(record))
    }

    /// Appends an authority record.
    pub fn authority(mut self, authority: DnsRecord) -> Self {
        self.authorities.push(authority);
        self
    }

    /// Appends an additional record.
    pub fn additional(mut self, additional: DnsRecord) -> Self {
        self.additionals.push(additional);
        self
    }

    /// Appends an mDNS additional record (identical to [`Dns::additional`]).
    pub fn mdns_additional_record(self, additional: DnsRecord) -> Self {
        self.additional(additional)
    }

    /// Appends several mDNS additional records, in order.
    pub fn mdns_additional_records(self, additionals: impl IntoIterator<Item = DnsRecord>) -> Self {
        additionals
            .into_iter()
            .fold(self, |dns, record| dns.mdns_additional_record(record))
    }

    /// Transaction ID, or 0 if the field holds no value.
    pub fn id_value(&self) -> u16 {
        value_or_copy(&self.id, 0)
    }

    /// Flags word, or RD if the field holds no value.
    pub fn flags_value(&self) -> u16 {
        value_or_copy(&self.flags, DNS_FLAG_RECURSION_DESIRED)
    }

    /// True when the QR bit is set.
    pub fn is_response(&self) -> bool {
        self.flags_value() & DNS_FLAG_QR_RESPONSE != 0
    }

    /// The 4-bit OPCODE from the flags word.
    pub fn opcode_value(&self) -> u8 {
        ((self.flags_value() & DNS_OPCODE_MASK) >> DNS_OPCODE_SHIFT) as u8
    }

    /// The 4-bit header RCODE from the flags word.
    pub fn rcode_value(&self) -> u8 {
        (self.flags_value() & DNS_RCODE_MASK) as u8
    }

    /// The question section.
    pub fn questions(&self) -> &[DnsQuestion] {
        &self.questions
    }

    /// The answer section.
    pub fn answers(&self) -> &[DnsRecord] {
        &self.answers
    }

    /// The authority section.
    pub fn authorities(&self) -> &[DnsRecord] {
        &self.authorities
    }

    /// The additional section.
    pub fn additionals(&self) -> &[DnsRecord] {
        &self.additionals
    }

    /// Sets (`enabled`) or clears `bit` in the current flags word.
    fn set_flag(mut self, bit: u16, enabled: bool) -> Self {
        let mut flags = self.flags_value();
        if enabled {
            flags |= bit;
        } else {
            flags &= !bit;
        }
        self.flags.set_user(flags);
        self
    }

    /// Marks the message as carried over mDNS (used when decoding from the mDNS port).
    fn with_mdns_context(mut self) -> Self {
        self.multicast_dns = true;
        self
    }

    /// True when any question carries the QU bit or any record carries the cache-flush
    /// bit, i.e. the message uses mDNS class-field semantics.
    fn has_mdns_class_bits(&self) -> bool {
        self.questions
            .iter()
            .any(DnsQuestion::mdns_unicast_response_preferred_value)
            || dns_records(self).any(|record| {
                // An OPT pseudo-record reuses CLASS for the EDNS UDP payload size
                // (RFC 6891 §6.1.2); a size >= 32768 is not an mDNS cache-flush bit.
                record.record_type() != DNS_TYPE_OPT && record.mdns_cache_flush_value()
            })
    }

    /// Whether summaries should use the mDNS presentation.
    fn should_show_mdns(&self) -> bool {
        self.multicast_dns || self.has_mdns_class_bits()
    }

    /// One-line mDNS summary. It leads with the first record for responses (falling back
    /// to the first question), then the QU/cache-flush details and the section counts.
    fn mdns_summary(&self) -> String {
        let direction = if self.is_response() {
            "response"
        } else {
            "query"
        };
        let question = self.questions.first();
        let record = if self.is_response() {
            first_dns_record(self)
        } else {
            None
        };
        let focus = record
            .map(|record| format!(" {}", record_summary(record)))
            .or_else(|| question.map(|question| format!(" {}", question_summary(question))))
            .unwrap_or_default();
        let mut details = format!(
            "mDNS(id=0x{:04x}, flags=0x{:04x}, {direction}{focus}",
            self.id_value(),
            self.flags_value()
        );
        if let Some(question) = question {
            details.push_str(&format!(
                ", unicast_response={}",
                question.mdns_unicast_response_preferred_value()
            ));
        }
        if let Some(record) = record {
            details.push_str(&format!(
                ", cache_flush={}",
                record.mdns_cache_flush_value()
            ));
        }
        if !self.questions.is_empty() {
            details.push_str(&format!(
                ", unicast_response_questions={}",
                mdns_unicast_response_question_count(self)
            ));
        }
        if dns_records(self).next().is_some() {
            details.push_str(&format!(
                ", cache_flush_records={}",
                mdns_cache_flush_record_count(self)
            ));
        }
        details.push_str(&format!(
            ", questions={}, answers={}, authorities={}, additionals={})",
            self.questions.len(),
            self.answers.len(),
            self.authorities.len(),
            self.additionals.len()
        ));
        details
    }

    /// Header length plus the `encoded_len` of every question and record.
    fn encoded_message_len(&self) -> usize {
        DNS_HEADER_LEN
            + self
                .questions
                .iter()
                .map(DnsQuestion::encoded_len)
                .sum::<usize>()
            + self
                .answers
                .iter()
                .map(DnsRecord::encoded_len)
                .sum::<usize>()
            + self
                .authorities
                .iter()
                .map(DnsRecord::encoded_len)
                .sum::<usize>()
            + self
                .additionals
                .iter()
                .map(DnsRecord::encoded_len)
                .sum::<usize>()
    }

    /// Appends the wire form (header, then the four sections in order) to `out`.
    ///
    /// # Errors
    /// Fails if any section holds more than 65535 entries, or if a question or record
    /// fails to encode.
    fn encode(&self, out: &mut Vec<u8>) -> Result<()> {
        validate_count("dns.qdcount", self.questions.len())?;
        validate_count("dns.ancount", self.answers.len())?;
        validate_count("dns.nscount", self.authorities.len())?;
        validate_count("dns.arcount", self.additionals.len())?;
        out.extend_from_slice(&self.id_value().to_be_bytes());
        out.extend_from_slice(&self.flags_value().to_be_bytes());
        // The `as u16` casts are lossless: every count was validated above.
        out.extend_from_slice(&(self.questions.len() as u16).to_be_bytes());
        out.extend_from_slice(&(self.answers.len() as u16).to_be_bytes());
        out.extend_from_slice(&(self.authorities.len() as u16).to_be_bytes());
        out.extend_from_slice(&(self.additionals.len() as u16).to_be_bytes());
        for question in &self.questions {
            question.encode(out)?;
        }
        for answer in &self.answers {
            answer.encode(out)?;
        }
        for authority in &self.authorities {
            authority.encode(out)?;
        }
        for additional in &self.additionals {
            additional.encode(out)?;
        }
        Ok(())
    }
}
impl Default for Dns {
fn default() -> Self {
Self::new()
}
}
impl Layer for Dns {
fn name(&self) -> &'static str {
"Dns"
}
fn summary(&self) -> String {
if self.should_show_mdns() {
return self.mdns_summary();
}
let direction = if self.is_response() {
"response"
} else {
"query"
};
let question = self
.questions
.first()
.map(|question| {
format!(
" {} {}",
question.name(),
record_type_summary(question.question_type())
)
})
.unwrap_or_default();
format!(
"Dns(id=0x{:04x}, {direction}{question}, answers={})",
self.id_value(),
self.answers.len()
)
}
fn inspection_fields(&self) -> Vec<(&'static str, String)> {
let mut fields = vec![
("id", format!("0x{:04x}", self.id_value())),
("flags", format!("0x{:04x}", self.flags_value())),
("qdcount", self.questions.len().to_string()),
("ancount", self.answers.len().to_string()),
("nscount", self.authorities.len().to_string()),
("arcount", self.additionals.len().to_string()),
];
if self.should_show_mdns() {
fields.push(("mDNS", "true".to_string()));
fields.push((
"mDNS_shape",
if self.is_response() {
"response".to_string()
} else {
"query".to_string()
},
));
if let Some(question) = self.questions.first() {
fields.push(("first_question", question_inspection_summary(question)));
}
if let Some(answer) = self.answers.first() {
fields.push(("first_answer", record_inspection_summary(answer)));
} else if let Some(record) = first_dns_record(self) {
fields.push(("first_record", record_inspection_summary(record)));
}
fields.push((
"unicast_response_questions",
mdns_unicast_response_question_count(self).to_string(),
));
fields.push((
"cache_flush_records",
mdns_cache_flush_record_count(self).to_string(),
));
let dns_sd_names = dns_sd_names_summary(self);
if !dns_sd_names.is_empty() {
fields.push(("dns_sd_names", dns_sd_names));
}
}
fields
}
fn encoded_len(&self) -> usize {
self.encoded_message_len()
}
fn compile(&self, _ctx: &LayerContext<'_>, out: &mut Vec<u8>) -> Result<()> {
self.encode(out)
}
impl_layer_object!(Dns);
}
// Enables `dns / rhs`: starts a `Packet` from this layer and concatenates `rhs` after it.
impl_layer_div!(Dns);
/// Decodes `bytes` as a DNS message and pushes it onto `packet` as the next layer.
///
/// If `packet` already has a UDP layer using `MDNS_PORT` on either side, the decoded
/// message is tagged with mDNS context so it is summarized as mDNS.
///
/// # Errors
/// Propagates any error from decoding the DNS message.
pub(crate) fn append_dns_packet(packet: Packet, bytes: &[u8]) -> Result<Packet> {
    let mut dns = decode_dns(bytes)?;
    if packet_has_mdns_transport(&packet) {
        dns = dns.with_mdns_context();
    }
    Ok(packet.push(dns))
}
/// Decodes a complete DNS message from `bytes`.
///
/// # Errors
/// - buffer-too-short if `bytes` is shorter than the 12-byte header;
/// - any error from decoding a question or record;
/// - invalid `dns.length` if bytes remain after the declared sections.
fn decode_dns(bytes: &[u8]) -> Result<Dns> {
    // Smallest possible wire encodings: a root name (one zero byte) followed by the fixed
    // fields (QTYPE+QCLASS for a question; TYPE+CLASS+TTL+RDLENGTH for a record).
    const MIN_QUESTION_WIRE_LEN: usize = 1 + 4;
    const MIN_RECORD_WIRE_LEN: usize = 1 + 10;

    if bytes.len() < DNS_HEADER_LEN {
        return Err(CrafterError::buffer_too_short(
            "dns header",
            DNS_HEADER_LEN,
            bytes.len(),
        ));
    }
    let qdcount = read_u16_be(&bytes[4..6])? as usize;
    let ancount = read_u16_be(&bytes[6..8])? as usize;
    let nscount = read_u16_be(&bytes[8..10])? as usize;
    let arcount = read_u16_be(&bytes[10..12])? as usize;

    // The section counts come straight from the untrusted header. Cap each up-front
    // reservation by how many minimal entries the message body could hold, so a 12-byte
    // header claiming 4 x 65535 entries cannot force large allocations. The loops still
    // decode exactly `count` entries; a count larger than the data is left to the
    // per-entry decoders to reject.
    let body_len = bytes.len() - DNS_HEADER_LEN;
    let mut offset = DNS_HEADER_LEN;
    let mut questions = Vec::with_capacity(qdcount.min(body_len / MIN_QUESTION_WIRE_LEN));
    let mut answers = Vec::with_capacity(ancount.min(body_len / MIN_RECORD_WIRE_LEN));
    let mut authorities = Vec::with_capacity(nscount.min(body_len / MIN_RECORD_WIRE_LEN));
    let mut additionals = Vec::with_capacity(arcount.min(body_len / MIN_RECORD_WIRE_LEN));
    for _ in 0..qdcount {
        let (question, next_offset) = decode_question(bytes, offset)?;
        questions.push(question);
        offset = next_offset;
    }
    for _ in 0..ancount {
        let (record, next_offset) = decode_record(bytes, offset)?;
        answers.push(record);
        offset = next_offset;
    }
    for _ in 0..nscount {
        let (record, next_offset) = decode_record(bytes, offset)?;
        authorities.push(record);
        offset = next_offset;
    }
    for _ in 0..arcount {
        let (record, next_offset) = decode_record(bytes, offset)?;
        additionals.push(record);
        offset = next_offset;
    }
    if offset != bytes.len() {
        return Err(CrafterError::invalid_field_value(
            "dns.length",
            "DNS message has trailing bytes after declared records",
        ));
    }
    Ok(Dns {
        id: Field::user(read_u16_be(&bytes[0..2])?),
        flags: Field::user(read_u16_be(&bytes[2..4])?),
        questions,
        answers,
        authorities,
        additionals,
        multicast_dns: false,
    })
}
/// Decodes one question starting at `offset` and returns it with the offset just past it.
///
/// # Errors
/// Propagates name-decoding errors; fails with buffer-too-short if the 4 bytes of
/// QTYPE/QCLASS after the name are missing.
fn decode_question(bytes: &[u8], offset: usize) -> Result<(DnsQuestion, usize)> {
    let (name, name_len) = decode_dns_name_typed(bytes, offset)?;
    let fixed_start = offset + name_len;
    let fixed_end = fixed_start + 4;
    if bytes.len() < fixed_end {
        return Err(CrafterError::buffer_too_short(
            "dns question",
            fixed_end,
            bytes.len(),
        ));
    }
    let fixed = &bytes[fixed_start..fixed_end];
    let question_type = read_u16_be(&fixed[0..2])?;
    let question_class = read_u16_be(&fixed[2..4])?;
    let question = DnsQuestion {
        name,
        question_type,
        question_class,
    };
    Ok((question, fixed_end))
}
/// Decodes one resource record starting at `offset` and returns it with the offset just
/// past its RDATA.
///
/// # Errors
/// Propagates name/RDATA decoding errors; fails with buffer-too-short if the 10-byte
/// fixed part or the declared RDATA extends past the end of `bytes`.
fn decode_record(bytes: &[u8], offset: usize) -> Result<(DnsRecord, usize)> {
    let (name, name_len) = decode_dns_name_typed(bytes, offset)?;
    let fixed_start = offset + name_len;
    let fixed_end = fixed_start + 10;
    if bytes.len() < fixed_end {
        return Err(CrafterError::buffer_too_short(
            "dns record",
            fixed_end,
            bytes.len(),
        ));
    }
    // Fixed part layout: TYPE(2) CLASS(2) TTL(4) RDLENGTH(2).
    let fixed = &bytes[fixed_start..fixed_end];
    let record_type = read_u16_be(&fixed[0..2])?;
    let class = read_u16_be(&fixed[2..4])?;
    let ttl = read_u32_be(&fixed[4..8])?;
    let rdlength = usize::from(read_u16_be(&fixed[8..10])?);
    let rdata_start = fixed_end;
    let rdata_end = rdata_start + rdlength;
    if bytes.len() < rdata_end {
        return Err(CrafterError::buffer_too_short(
            "dns rdata",
            rdata_end,
            bytes.len(),
        ));
    }
    // The full message is passed so RDATA names can follow compression pointers.
    let data = decode_record_data(record_type, bytes, rdata_start, rdata_end)?;
    let record = DnsRecord {
        name,
        record_type,
        class,
        ttl,
        data,
    };
    Ok((record, rdata_end))
}
/// Checks that a section length fits the 16-bit header count field.
///
/// # Errors
/// Returns an invalid-field-value error naming `field` when `count` exceeds 65535.
fn validate_count(field: &'static str, count: usize) -> Result<()> {
    match u16::try_from(count) {
        Ok(_) => Ok(()),
        Err(_) => Err(CrafterError::invalid_field_value(
            field,
            "DNS section count exceeds 65535",
        )),
    }
}
/// Copies the field's current value out, or returns `default` when it holds none.
fn value_or_copy<T: Copy>(field: &Field<T>, default: T) -> T {
    match field.value() {
        Some(&current) => current,
        None => default,
    }
}
/// Renders a type code as its mnemonic (e.g. `PTR`), or `TYPE<n>` (RFC 3597 style) when
/// the code has no known name.
fn record_type_summary(record_type: u16) -> String {
    dns_type_name(record_type)
        .map_or_else(|| format!("TYPE{record_type}"), |name| name.to_owned())
}
/// True when `packet` has a UDP layer whose source or destination port is `MDNS_PORT`.
fn packet_has_mdns_transport(packet: &Packet) -> bool {
    let Some(udp) = packet.layer::<Udp>() else {
        return false;
    };
    [udp.source_port_value(), udp.destination_port_value()].contains(&MDNS_PORT)
}
/// Iterates every resource record: answers, then authorities, then additionals.
fn dns_records(dns: &Dns) -> impl Iterator<Item = &DnsRecord> {
    let answers = dns.answers.iter();
    answers.chain(&dns.authorities).chain(&dns.additionals)
}
fn first_dns_record(dns: &Dns) -> Option<&DnsRecord> {
dns_records(dns).next()
}
fn mdns_unicast_response_question_count(dns: &Dns) -> usize {
dns.questions
.iter()
.filter(|question| question.mdns_unicast_response_preferred_value())
.count()
}
fn mdns_cache_flush_record_count(dns: &Dns) -> usize {
dns_records(dns)
.filter(|record| record.mdns_cache_flush_value())
.count()
}
/// `<name> <TYPE>` for a question, e.g. `_ipp._tcp.local. PTR`.
fn question_summary(question: &DnsQuestion) -> String {
    let owner = question.name();
    let kind = record_type_summary(question.question_type());
    format!("{} {}", owner, kind)
}
/// Question summary plus the raw class, the class with the mDNS QU bit removed, and
/// the QU flag itself.
fn question_inspection_summary(question: &DnsQuestion) -> String {
    let headline = question_summary(question);
    let raw_class = question.question_class();
    let base_class = question.mdns_base_question_class();
    let unicast = question.mdns_unicast_response_preferred_value();
    format!(
        "{} class=0x{:04x} base_class=0x{:04x} unicast_response={}",
        headline, raw_class, base_class, unicast
    )
}
/// `<name> <TYPE><data>` for a record, where `<data>` comes from `record_data_summary`.
fn record_summary(record: &DnsRecord) -> String {
    let owner = record.name();
    let kind = record_type_summary(record.record_type());
    let payload = record_data_summary(record.data());
    format!("{} {}{}", owner, kind, payload)
}
/// Record summary plus the raw class, the class with the cache-flush bit removed, the
/// TTL, and the cache-flush flag itself.
fn record_inspection_summary(record: &DnsRecord) -> String {
    let headline = record_summary(record);
    let raw_class = record.class();
    let base_class = record.mdns_base_class();
    let ttl = record.ttl();
    let flush = record.mdns_cache_flush_value();
    format!(
        "{} class=0x{:04x} base_class=0x{:04x} ttl={} cache_flush={}",
        headline, raw_class, base_class, ttl, flush
    )
}
/// Short suffix describing record data: ` -> target` for address/name/SRV data,
/// string or byte counts for TXT/raw data, and nothing for other kinds.
fn record_data_summary(data: &DnsRecordData) -> String {
    match data {
        DnsRecordData::Txt(strings) => format!(" strings={}", strings.len()),
        DnsRecordData::Raw(bytes) => format!(" raw_len={}", bytes.len()),
        DnsRecordData::Name(name) => format!(" -> {}", name.presentation()),
        DnsRecordData::Srv { target, port, .. } => {
            let host = target.presentation();
            format!(" -> {}:{port}", host)
        }
        DnsRecordData::A(address) => format!(" -> {address}"),
        DnsRecordData::Aaaa(address) => format!(" -> {address}"),
        _ => String::new(),
    }
}
/// Comma-separated, de-duplicated DNS-SD-looking names found in questions, record owner
/// names, and PTR/SRV-style targets, in first-seen order.
fn dns_sd_names_summary(dns: &Dns) -> String {
    let mut collected = Vec::new();
    dns.questions
        .iter()
        .for_each(|question| push_dns_sd_name(&mut collected, question.dns_name()));
    dns_records(dns).for_each(|record| push_record_dns_sd_names(&mut collected, record));
    collected.join(", ")
}
/// Collects DNS-SD names from a record: its owner name first, then its name-valued
/// data (a `Name` target or an SRV target) if it has one.
fn push_record_dns_sd_names(names: &mut Vec<String>, record: &DnsRecord) {
    let linked = match record.data() {
        DnsRecordData::Name(name) => Some(name),
        DnsRecordData::Srv { target, .. } => Some(target),
        _ => None,
    };
    push_dns_sd_name(names, record.dns_name());
    if let Some(target) = linked {
        push_dns_sd_name(names, target);
    }
}
/// Appends `name`'s presentation form to `names` if it looks like a DNS-SD name and is
/// not already present.
fn push_dns_sd_name(names: &mut Vec<String>, name: &DnsName) {
    let presentation = name.presentation();
    let is_new = names.iter().all(|existing| existing.as_str() != presentation);
    if is_dns_sd_service_name(presentation) && is_new {
        names.push(presentation.to_string());
    }
}
/// Heuristic: a presentation-form name is treated as DNS-SD when any dot-separated
/// label begins with an underscore (e.g. `_ipp._tcp.local.`).
fn is_dns_sd_service_name(name: &str) -> bool {
    name.split('.').any(|label| label.starts_with('_'))
}
/// Returns the conventional mnemonic for a DNS TYPE/QTYPE code, or `None` when the code
/// is not one this crate names.
///
/// Includes QTYPE 255 (`ANY`, RFC 1035 §3.2.3), which the mDNS probe builders use as
/// their default question type. Without it, probes would render as `TYPE255`.
pub fn dns_type_name(record_type: u16) -> Option<&'static str> {
    let mnemonic = match record_type {
        DNS_TYPE_A => "A",
        DNS_TYPE_NS => "NS",
        DNS_TYPE_CNAME => "CNAME",
        DNS_TYPE_SOA => "SOA",
        DNS_TYPE_PTR => "PTR",
        DNS_TYPE_MX => "MX",
        DNS_TYPE_TXT => "TXT",
        DNS_TYPE_AAAA => "AAAA",
        DNS_TYPE_SRV => "SRV",
        DNS_TYPE_OPT => "OPT",
        DNS_TYPE_DS => "DS",
        DNS_TYPE_RRSIG => "RRSIG",
        DNS_TYPE_NSEC => "NSEC",
        DNS_TYPE_DNSKEY => "DNSKEY",
        DNS_TYPE_NSEC3 => "NSEC3",
        DNS_TYPE_NSEC3PARAM => "NSEC3PARAM",
        DNS_TYPE_TLSA => "TLSA",
        DNS_TYPE_SVCB => "SVCB",
        DNS_TYPE_HTTPS => "HTTPS",
        DNS_QTYPE_ANY => "ANY",
        _ => return None,
    };
    Some(mnemonic)
}
#[cfg(test)]
mod mdns_summary_show_tests {
    use core::net::Ipv4Addr;

    use crate::{NetworkLayer, Packet};

    use super::{mdns, Dns, DnsQuestion, DNS_SD_DEFAULT_DOMAIN, DNS_TYPE_A, DNS_TYPE_PTR};

    /// Asserts that `text` contains `needle`, printing both on failure.
    fn assert_has(text: &str, needle: &str) {
        assert!(text.contains(needle), "missing {needle:?} in:\n{text}");
    }

    #[test]
    fn mdns_summary_show_query_exposes_shape_and_unicast_response() -> crate::Result<()> {
        let ipp = mdns::dns_sd_tcp_service_name("ipp", DNS_SD_DEFAULT_DOMAIN)?;
        let query = mdns::query(DnsQuestion::new(ipp, DNS_TYPE_PTR).mdns_qu(true));
        let packet = mdns::mdns_ipv4_packet(Ipv4Addr::new(192, 0, 2, 10), query);

        let summary = packet.summary();
        for needle in [
            "mDNS(id=0x0000, flags=0x0000, query _ipp._tcp.local. PTR",
            "unicast_response=true",
            "answers=0",
        ] {
            assert_has(&summary, needle);
        }

        let show = packet.show();
        for needle in [
            "id: 0x0000",
            "flags: 0x0000",
            "mDNS: true",
            "mDNS_shape: query",
            "first_question: _ipp._tcp.local. PTR class=0x8001 base_class=0x0001 unicast_response=true",
            "dns_sd_names: _ipp._tcp.local.",
        ] {
            assert_has(&show, needle);
        }
        Ok(())
    }

    #[test]
    fn mdns_summary_show_response_exposes_cache_flush_and_service_names() -> crate::Result<()> {
        let ipp = mdns::dns_sd_tcp_service_name("ipp", DNS_SD_DEFAULT_DOMAIN)?;
        let instance =
            mdns::dns_sd_tcp_instance_name("Office\\032Printer", "ipp", DNS_SD_DEFAULT_DOMAIN)?;
        let ptr = mdns::service_ptr(ipp, instance.clone(), 4500).mdns_cache_flush(true);
        let reply = mdns::response_with_answers([ptr]);
        let packet = mdns::mdns_ipv4_packet(Ipv4Addr::new(192, 0, 2, 20), reply);

        let summary = packet.summary();
        for needle in [
            "mDNS(id=0x0000, flags=0x8400, response _ipp._tcp.local. PTR",
            "Office\\032Printer._ipp._tcp.local.",
            "cache_flush=true",
            "answers=1",
        ] {
            assert_has(&summary, needle);
        }

        let show = packet.show();
        for needle in [
            "mDNS_shape: response",
            "first_answer: _ipp._tcp.local. PTR",
            "cache_flush=true",
            "cache_flush_records: 1",
            "dns_sd_names: _ipp._tcp.local., Office\\032Printer._ipp._tcp.local.",
            "flags: 0x8400",
        ] {
            assert_has(&show, needle);
        }
        assert_eq!(instance.presentation(), "Office\\032Printer._ipp._tcp.local.");
        Ok(())
    }

    #[test]
    fn mdns_summary_show_decoded_udp_5353_keeps_context() -> crate::Result<()> {
        let query = mdns::query_for("printer.local.", DNS_TYPE_A);
        let packet = mdns::mdns_ipv4_packet(Ipv4Addr::new(192, 0, 2, 30), query);
        let wire = packet.compile()?;

        let decoded = Packet::decode_from_l3(NetworkLayer::Ipv4, wire.as_bytes())?;
        assert_eq!(decoded.compile()?, wire);
        assert_has(&decoded.summary(), "mDNS(id=0x0000");
        let show = decoded.show();
        assert_has(&show, "mDNS: true");
        assert_has(&show, "mDNS_shape: query");
        Ok(())
    }

    #[test]
    fn mdns_summary_show_standard_dns_summary_stays_dns() {
        let query = Dns::a_query("example.com.").id(0x1234);
        let summary = Packet::from_layer(query).summary();
        assert_eq!(summary, "Dns(id=0x1234, query example.com. A, answers=0)");
    }
}
#[cfg(test)]
mod mdns_message_builders_tests {
    use core::net::Ipv4Addr;
    use crate::Udp;
    use super::{
        decode_dns, mdns, Dns, DnsQuestion, DnsRecord, DNS_CLASS_IN, DNS_FLAG_AUTHORITATIVE,
        DNS_FLAG_QR_RESPONSE, DNS_FLAG_RECURSION_DESIRED, DNS_QTYPE_ANY, DNS_TYPE_A, DNS_TYPE_AAAA,
        MDNS_CLASS_BIT, MDNS_PORT,
    };

    /// Size of the UDP header that precedes the DNS message in compiled bytes.
    const UDP_HEADER_LEN: usize = 8;

    /// Compiles `dns` behind a UDP header with `MDNS_PORT` on both ends and
    /// returns only the DNS message bytes (the UDP header is stripped).
    fn compile_mdns_payload(dns: Dns) -> crate::Result<Vec<u8>> {
        let compiled = (Udp::new().sport(MDNS_PORT).dport(MDNS_PORT) / dns).compile()?;
        Ok(compiled.as_bytes()[UDP_HEADER_LEN..].to_vec())
    }

    /// `mdns::query` builds an ID-0, flags-0 (RD clear) query holding only the
    /// given question, and every header counter on the wire reflects that.
    #[test]
    fn mdns_message_builders_query_defaults_are_nonrecursive_and_compile() -> crate::Result<()> {
        let question = DnsQuestion::new("printer.local.", DNS_TYPE_A).mdns_qu(true);
        let dns = mdns::query(question.clone());
        assert_eq!(dns.id_value(), 0);
        assert_eq!(dns.flags_value(), 0);
        assert!(!dns.is_response());
        assert_eq!(dns.questions(), &[question]);
        assert!(dns.answers().is_empty());
        assert!(dns.authorities().is_empty());
        assert!(dns.additionals().is_empty());

        let payload = compile_mdns_payload(dns)?;
        // RFC 1035 §4.1.1 header layout: ID, flags, QDCOUNT, ANCOUNT, NSCOUNT, ARCOUNT.
        assert_eq!(&payload[0..2], &0u16.to_be_bytes());
        assert_eq!(&payload[2..4], &0u16.to_be_bytes());
        assert_eq!(&payload[4..6], &1u16.to_be_bytes());
        assert_eq!(&payload[6..8], &0u16.to_be_bytes());
        assert_eq!(&payload[8..10], &0u16.to_be_bytes());
        assert_eq!(&payload[10..12], &0u16.to_be_bytes());
        Ok(())
    }

    /// `mdns::response_with_answers` builds an ID-0, QR+AA response whose only
    /// populated section is the answer section; cache-flush survives the wire.
    #[test]
    fn mdns_message_builders_response_defaults_are_authoritative() -> crate::Result<()> {
        let response_flags = DNS_FLAG_QR_RESPONSE | DNS_FLAG_AUTHORITATIVE;
        let answer = DnsRecord::a("printer.local.", Ipv4Addr::new(192, 0, 2, 10), 120)
            .mdns_cache_flush(true);
        let dns = mdns::response_with_answers(core::iter::once(answer.clone()));
        assert_eq!(dns.id_value(), 0);
        assert_eq!(dns.flags_value(), response_flags);
        assert!(dns.is_response());
        assert!(dns.questions().is_empty());
        assert_eq!(dns.answers(), &[answer]);
        assert!(dns.authorities().is_empty());
        // The builder must not add any records beyond the supplied answers.
        assert!(dns.additionals().is_empty());

        let payload = compile_mdns_payload(dns)?;
        let decoded = decode_dns(payload.as_slice())?;
        assert_eq!(decoded.flags_value(), response_flags);
        assert_eq!(decoded.answers().len(), 1);
        assert!(decoded.answers()[0].mdns_cache_flush_value());
        assert!(decoded.additionals().is_empty());
        Ok(())
    }

    /// Probes ask with the QU bit set and carry proposed records in the
    /// authority section; an explicit plain-IN class still gets the QU bit.
    #[test]
    fn mdns_message_builders_probe_uses_qu_and_authority_records() {
        let qu_in = MDNS_CLASS_BIT | DNS_CLASS_IN;

        let any_probe = Dns::mdns_probe("printer.local.");
        assert_eq!(any_probe.flags_value(), 0);
        assert_eq!(any_probe.questions()[0].question_type(), DNS_QTYPE_ANY);
        assert_eq!(any_probe.questions()[0].question_class(), qu_in);
        assert!(any_probe.questions()[0].mdns_unicast_response_preferred_value());

        let a_probe = mdns::probe_for("printer.local.", DNS_TYPE_A);
        assert_eq!(a_probe.questions()[0].question_type(), DNS_TYPE_A);
        assert_eq!(a_probe.questions()[0].question_class(), qu_in);
        assert!(a_probe.questions()[0].mdns_unicast_response_preferred_value());

        let proposed = DnsRecord::a("printer.local.", Ipv4Addr::new(192, 0, 2, 11), 120)
            .mdns_cache_flush(true);
        let typed_probe = mdns::probe_with_authorities(
            DnsQuestion::new("printer.local.", DNS_TYPE_AAAA).qclass(DNS_CLASS_IN),
            core::iter::once(proposed.clone()),
        );
        assert_eq!(typed_probe.questions()[0].question_type(), DNS_TYPE_AAAA);
        assert_eq!(typed_probe.questions()[0].question_class(), qu_in);
        assert!(typed_probe.answers().is_empty());
        assert_eq!(typed_probe.authorities(), &[proposed]);
        assert!(typed_probe.additionals().is_empty());
    }

    /// Caller-supplied raw ID and flags override builder defaults and survive
    /// a compile/decode cycle, even when they contradict the builder's intent.
    #[test]
    fn mdns_message_builders_raw_id_and_flags_overrides_are_preserved() -> crate::Result<()> {
        let dns = Dns::mdns_query_for("raw.local.", DNS_TYPE_A)
            .id(0x4567)
            .flags(0xabcd);
        assert_eq!(dns.id_value(), 0x4567);
        assert_eq!(dns.flags_value(), 0xabcd);

        let payload = compile_mdns_payload(dns)?;
        let decoded = decode_dns(payload.as_slice())?;
        assert_eq!(decoded.id_value(), 0x4567);
        assert_eq!(decoded.flags_value(), 0xabcd);

        // Overriding the flags word on a response clears QR as well.
        let response_override = mdns::response().flags(DNS_FLAG_RECURSION_DESIRED);
        assert_eq!(response_override.flags_value(), DNS_FLAG_RECURSION_DESIRED);
        assert!(!response_override.is_response());
        Ok(())
    }
}
#[cfg(test)]
mod mdns_known_answer_helpers_tests {
    use core::net::Ipv4Addr;
    use crate::Packet;
    use super::{
        decode_dns, mdns, Dns, DnsQuestion, DnsRecord, DNS_CLASS_IN, DNS_FLAG_AUTHORITATIVE,
        DNS_FLAG_QR_RESPONSE, DNS_FLAG_TRUNCATED, MDNS_CLASS_BIT,
    };

    /// Compiles `dns`, decodes the bytes back, and compiles the decoded copy.
    ///
    /// Returns `(decoded, first_wire, second_wire)` so a test can check both
    /// field values and byte-for-byte stability of the encoding.
    fn round_trip_dns(dns: Dns) -> crate::Result<(Dns, Vec<u8>, Vec<u8>)> {
        let first_wire = Packet::from_layer(dns).compile()?.as_bytes().to_vec();
        let decoded = decode_dns(first_wire.as_slice())?;
        let second_wire = Packet::from_layer(decoded.clone())
            .compile()?
            .as_bytes()
            .to_vec();
        Ok((decoded, first_wire, second_wire))
    }

    /// Known answers, both seeded and chained, stay in the answer section of a
    /// query and keep their individual TTL and cache-flush bit.
    #[test]
    fn mdns_known_answer_query_helpers_keep_records_in_answer_section() -> crate::Result<()> {
        let flagged_in = MDNS_CLASS_BIT | DNS_CLASS_IN;
        let qu_question = DnsQuestion::a("printer.local.").mdns_qu(true);
        let flushed = DnsRecord::a("printer.local.", Ipv4Addr::new(192, 0, 2, 10), 119)
            .mdns_cache_flush(true);
        let shared = DnsRecord::a("printer.local.", Ipv4Addr::new(192, 0, 2, 11), 42)
            .mdns_cache_flush(false);
        let query = mdns::known_answer_query(qu_question.clone(), [flushed.clone()])
            .mdns_known_answer(shared);

        assert_eq!(query.id_value(), 0);
        assert_eq!(query.flags_value(), 0);
        assert_eq!(query.questions(), &[qu_question]);
        assert_eq!(query.answers().len(), 2);
        assert!(query.authorities().is_empty());
        assert!(query.additionals().is_empty());

        let (decoded, first_wire, second_wire) = round_trip_dns(query)?;
        assert_eq!(first_wire, second_wire);
        assert_eq!(decoded.questions()[0].question_class(), flagged_in);

        let answers = decoded.answers();
        assert_eq!(answers[0], flushed);
        assert_eq!(answers[0].ttl(), 119);
        assert_eq!(answers[0].class(), flagged_in);
        assert!(answers[0].mdns_cache_flush_value());
        assert_eq!(answers[1].ttl(), 42);
        assert_eq!(answers[1].class(), DNS_CLASS_IN);
        assert!(!answers[1].mdns_cache_flush_value());
        assert!(decoded.additionals().is_empty());
        Ok(())
    }

    /// A continued known-answer query sets only the TC flag and leaves the
    /// records in the answer section.
    #[test]
    fn mdns_known_answer_continued_query_sets_tc_without_moving_answers() -> crate::Result<()> {
        let record = DnsRecord::a("printer.local.", Ipv4Addr::new(192, 0, 2, 12), 120)
            .mdns_cache_flush(true);
        let continued =
            mdns::continued_known_answer_query(DnsQuestion::a("printer.local."), [record.clone()]);

        let (decoded, first_wire, second_wire) = round_trip_dns(continued)?;
        assert_eq!(first_wire, second_wire);
        assert_eq!(decoded.flags_value(), DNS_FLAG_TRUNCATED);
        assert_eq!(decoded.answers(), &[record]);
        assert!(decoded.authorities().is_empty());
        assert!(decoded.additionals().is_empty());
        Ok(())
    }

    /// Additional records, seeded and chained, stay in the additional section
    /// in insertion order with per-record TTL and cache-flush intact.
    #[test]
    fn mdns_known_answer_additional_helpers_keep_section() -> crate::Result<()> {
        let host = "printer.local.";
        let answer = DnsRecord::a(host, Ipv4Addr::new(192, 0, 2, 20), 120).mdns_cache_flush(true);
        let seeded = DnsRecord::a(host, Ipv4Addr::new(192, 0, 2, 21), 121).mdns_cache_flush(true);
        let appended =
            DnsRecord::a(host, Ipv4Addr::new(192, 0, 2, 22), 122).mdns_cache_flush(false);
        let response = mdns::response_with_additionals([answer.clone()], [seeded.clone()])
            .mdns_additional_record(appended.clone());

        let (decoded, first_wire, second_wire) = round_trip_dns(response)?;
        assert_eq!(first_wire, second_wire);
        assert_eq!(
            decoded.flags_value(),
            DNS_FLAG_QR_RESPONSE | DNS_FLAG_AUTHORITATIVE
        );
        assert!(decoded.questions().is_empty());
        assert_eq!(decoded.answers(), &[answer]);
        assert!(decoded.authorities().is_empty());

        let extras = decoded.additionals();
        assert_eq!(extras, &[seeded, appended]);
        assert_eq!(extras[0].ttl(), 121);
        assert_eq!(extras[0].class(), MDNS_CLASS_BIT | DNS_CLASS_IN);
        assert!(extras[0].mdns_cache_flush_value());
        assert_eq!(extras[1].ttl(), 122);
        assert_eq!(extras[1].class(), DNS_CLASS_IN);
        assert!(!extras[1].mdns_cache_flush_value());
        Ok(())
    }
}
#[cfg(test)]
mod mdns_goodbye_probe_announce_tests {
    use core::net::{Ipv4Addr, Ipv6Addr};
    use crate::Packet;
    use super::{
        decode_dns, mdns, Dns, DnsQuestion, DnsRecord, DnsRecordData, DNS_CLASS_IN,
        DNS_FLAG_AUTHORITATIVE, DNS_FLAG_QR_RESPONSE, DNS_QTYPE_ANY, DNS_TYPE_A, DNS_TYPE_AAAA,
        MDNS_CLASS_BIT, MDNS_GOODBYE_TTL,
    };

    /// Arbitrary non-IN class with bit 15 clear.
    const ODD_CLASS: u16 = 0x1234;
    /// Arbitrary non-IN class with bit 15 already set.
    const HIGH_CLASS: u16 = 0xf234;

    /// Compiles `dns`, decodes the bytes, and recompiles the decoded message,
    /// returning `(decoded, first_wire, second_wire)`.
    fn round_trip_dns(dns: Dns) -> crate::Result<(Dns, Vec<u8>, Vec<u8>)> {
        let first_wire = Packet::from_layer(dns).compile()?.as_bytes().to_vec();
        let decoded = decode_dns(first_wire.as_slice())?;
        let second_wire = Packet::from_layer(decoded.clone())
            .compile()?
            .as_bytes()
            .to_vec();
        Ok((decoded, first_wire, second_wire))
    }

    /// Default goodbye, probe and announcement helpers produce the expected
    /// TTLs, class bits, flags and section layout.
    #[test]
    fn mdns_goodbye_probe_announce_default_helpers_build_expected_shapes() -> crate::Result<()> {
        let host = "printer.local.";
        let flagged_in = MDNS_CLASS_BIT | DNS_CLASS_IN;
        let response_flags = DNS_FLAG_QR_RESPONSE | DNS_FLAG_AUTHORITATIVE;

        // cache_flush: sets MDNS_CLASS_BIT and leaves the TTL alone.
        let flushed = mdns::cache_flush(DnsRecord::a(host, Ipv4Addr::new(192, 0, 2, 29), 119));
        assert_eq!(flushed.ttl(), 119);
        assert_eq!(flushed.class(), flagged_in);
        assert!(flushed.mdns_cache_flush_value());

        // goodbye: forces MDNS_GOODBYE_TTL and keeps the cache-flush bit.
        let farewell = mdns::goodbye(
            DnsRecord::a(host, Ipv4Addr::new(192, 0, 2, 30), 120).mdns_cache_flush(true),
        );
        assert_eq!(farewell.ttl(), MDNS_GOODBYE_TTL);
        assert_eq!(farewell.class(), flagged_in);
        assert!(farewell.mdns_cache_flush_value());

        let v6_record = DnsRecord::aaaa(
            host,
            Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0x30),
            4500,
        )
        .mdns_cache_flush(true);
        let farewell_message = mdns::goodbye_response([v6_record]);
        assert_eq!(farewell_message.flags_value(), response_flags);
        assert!(farewell_message.questions().is_empty());
        assert_eq!(farewell_message.answers().len(), 1);
        assert_eq!(farewell_message.answers()[0].ttl(), MDNS_GOODBYE_TTL);
        assert!(farewell_message.answers()[0].mdns_cache_flush_value());

        // probe: a single QU ANY question and nothing else.
        let probe = mdns::probe(host);
        assert_eq!(probe.flags_value(), 0);
        assert_eq!(probe.questions().len(), 1);
        let probe_question = &probe.questions()[0];
        assert_eq!(probe_question.question_type(), DNS_QTYPE_ANY);
        assert_eq!(probe_question.question_class(), flagged_in);
        assert!(probe_question.mdns_unicast_response_preferred_value());
        assert!(probe.answers().is_empty());
        assert!(probe.authorities().is_empty());
        assert!(probe.additionals().is_empty());

        // announcement: keeps the TTL and sets cache-flush.
        let announcement =
            mdns::announcement([DnsRecord::a(host, Ipv4Addr::new(192, 0, 2, 31), 120)]);
        assert_eq!(announcement.flags_value(), response_flags);
        let announced = &announcement.answers()[0];
        assert_eq!(announced.ttl(), 120);
        assert_eq!(announced.class(), flagged_in);
        assert!(announced.mdns_cache_flush_value());

        // announce_response: cache-flush applied to answers and additionals alike.
        let announce_response = mdns::announce_response(
            [DnsRecord::a(host, Ipv4Addr::new(192, 0, 2, 32), 121)],
            [DnsRecord::aaaa(
                host,
                Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0x32),
                122,
            )],
        );
        assert!(announce_response.questions().is_empty());
        assert_eq!(announce_response.answers()[0].ttl(), 121);
        assert!(announce_response.answers()[0].mdns_cache_flush_value());
        assert_eq!(announce_response.additionals()[0].ttl(), 122);
        assert!(announce_response.additionals()[0].mdns_cache_flush_value());

        let (decoded, first_wire, second_wire) = round_trip_dns(announce_response)?;
        assert_eq!(first_wire, second_wire);
        assert!(decoded.answers()[0].mdns_cache_flush_value());
        assert!(decoded.additionals()[0].mdns_cache_flush_value());
        Ok(())
    }

    /// Unusual classes, TTLs, IDs and flags supplied by the caller are kept
    /// as-is by the helpers, except for the bits each helper is defined to set.
    #[test]
    fn mdns_goodbye_probe_announce_unusual_overrides_preserve_raw_values() -> crate::Result<()> {
        let owner = "odd.local.";

        let proposed = DnsRecord::new(
            owner,
            DNS_TYPE_A,
            HIGH_CLASS,
            0x0102_0304,
            DnsRecordData::A(Ipv4Addr::new(192, 0, 2, 40)),
        );
        let probe = mdns::probe_with_authorities(
            DnsQuestion::new(owner, DNS_TYPE_AAAA).qclass(ODD_CLASS),
            [proposed.clone()],
        )
        .id(0xbeef)
        .flags(0x1234);
        assert_eq!(probe.id_value(), 0xbeef);
        assert_eq!(probe.flags_value(), 0x1234);
        let question = &probe.questions()[0];
        assert_eq!(question.question_class(), MDNS_CLASS_BIT | ODD_CLASS);
        assert_eq!(question.mdns_base_question_class(), ODD_CLASS);
        assert_eq!(probe.authorities(), &[proposed]);
        let authority = &probe.authorities()[0];
        assert_eq!(authority.class(), HIGH_CLASS);
        assert_eq!(authority.ttl(), 0x0102_0304);

        let announcement = mdns::announcement([DnsRecord::new(
            owner,
            DNS_TYPE_A,
            ODD_CLASS,
            0xaabb_ccdd,
            DnsRecordData::A(Ipv4Addr::new(192, 0, 2, 41)),
        )]);
        let announced = &announcement.answers()[0];
        assert_eq!(announced.ttl(), 0xaabb_ccdd);
        assert_eq!(announced.class(), MDNS_CLASS_BIT | ODD_CLASS);
        assert_eq!(announced.mdns_base_class(), ODD_CLASS);

        let farewell = mdns::goodbye(DnsRecord::new(
            owner,
            DNS_TYPE_A,
            HIGH_CLASS,
            0xffff_ffff,
            DnsRecordData::A(Ipv4Addr::new(192, 0, 2, 42)),
        ));
        assert_eq!(farewell.ttl(), MDNS_GOODBYE_TTL);
        assert_eq!(farewell.class(), HIGH_CLASS);
        assert!(farewell.mdns_cache_flush_value());

        let shared = mdns::shared_record(DnsRecord::new(
            owner,
            DNS_TYPE_A,
            HIGH_CLASS,
            77,
            DnsRecordData::A(Ipv4Addr::new(192, 0, 2, 43)),
        ));
        assert_eq!(shared.class(), HIGH_CLASS & !MDNS_CLASS_BIT);
        assert_eq!(shared.ttl(), 77);
        assert!(!shared.mdns_cache_flush_value());

        let (decoded, first_wire, second_wire) = round_trip_dns(probe)?;
        assert_eq!(first_wire, second_wire);
        assert_eq!(decoded.id_value(), 0xbeef);
        assert_eq!(decoded.flags_value(), 0x1234);
        assert_eq!(decoded.authorities()[0].ttl(), 0x0102_0304);
        assert_eq!(decoded.authorities()[0].class(), HIGH_CLASS);
        Ok(())
    }
}
#[cfg(test)]
mod mdns_transport_builders_tests {
    use core::net::{Ipv4Addr, Ipv6Addr};
    use crate::{
        Ethernet, Ipv4, Ipv6, LinkType, MacAddr, NetworkLayer, Packet, ProtocolRegistry, Udp,
    };
    use super::{
        append_dns_packet, mdns, Dns, DNS_TYPE_A, DNS_TYPE_AAAA, MDNS_IPV4_ETHERNET_MULTICAST,
        MDNS_IPV4_MULTICAST, MDNS_IPV6_ETHERNET_MULTICAST, MDNS_IPV6_LINK_LOCAL_MULTICAST,
        MDNS_PORT, MDNS_RESPONSE_HOP_LIMIT, MDNS_RESPONSE_TTL,
    };

    /// Builds a registry that hands UDP `MDNS_PORT` payloads to the DNS decoder.
    fn mdns_registry() -> ProtocolRegistry {
        let mut registry = ProtocolRegistry::new();
        registry.bind_udp_port(MDNS_PORT, |pkt, bytes| append_dns_packet(pkt, bytes));
        registry
    }

    /// Asserts that both UDP ports are `MDNS_PORT`.
    fn assert_mdns_ports(udp: &Udp) {
        assert_eq!(udp.source_port_value(), MDNS_PORT);
        assert_eq!(udp.destination_port_value(), MDNS_PORT);
    }

    /// Asserts that `actual` has the same header ID, flags and questions as `expected`.
    fn assert_same_query(actual: &Dns, expected: &Dns) {
        assert_eq!(actual.id_value(), expected.id_value());
        assert_eq!(actual.flags_value(), expected.flags_value());
        assert_eq!(actual.questions(), expected.questions());
    }

    /// `mdns_ipv4_packet` targets the IPv4 mDNS group with the mDNS TTL and
    /// port, and the packet decodes back through the registry unchanged.
    #[test]
    fn mdns_transport_builders_ipv4_packet_compiles_and_decodes() -> crate::Result<()> {
        let sender = Ipv4Addr::new(192, 0, 2, 10);
        let query = mdns::query_for("printer.local.", DNS_TYPE_A);
        let built = mdns::mdns_ipv4_packet(sender, query.clone());

        let ip = built.layer::<Ipv4>().unwrap();
        assert_eq!(ip.source(), sender);
        assert_eq!(ip.destination(), MDNS_IPV4_MULTICAST);
        assert_eq!(ip.ttl_value(), MDNS_RESPONSE_TTL);
        assert_mdns_ports(built.layer::<Udp>().unwrap());

        let wire = built.compile()?;
        let parsed = Packet::decode_from_l3_with_registry(
            &mdns_registry(),
            NetworkLayer::Ipv4,
            wire.as_bytes(),
        )?;
        let parsed_ip = parsed.layer::<Ipv4>().unwrap();
        assert_eq!(parsed_ip.source(), sender);
        assert_eq!(parsed_ip.destination(), MDNS_IPV4_MULTICAST);
        assert_eq!(parsed_ip.ttl_value(), MDNS_RESPONSE_TTL);
        assert_mdns_ports(parsed.layer::<Udp>().unwrap());
        assert_same_query(parsed.layer::<Dns>().unwrap(), &query);
        Ok(())
    }

    /// `mdns_ethernet_ipv6_packet` targets the IPv6 mDNS link-local group and
    /// its Ethernet multicast MAC, and decodes back from the link layer.
    #[test]
    fn mdns_transport_builders_ipv6_ethernet_packet_compiles_and_decodes() -> crate::Result<()> {
        let mac = MacAddr::new([0x02, 0x00, 0x5e, 0x10, 0x00, 0x08]);
        let link_local = Ipv6Addr::new(0xfe80, 0, 0, 0, 0x0200, 0x5eff, 0xfe10, 0x0008);
        let query = mdns::query_for("printer.local.", DNS_TYPE_AAAA);
        let built = mdns::mdns_ethernet_ipv6_packet(mac, link_local, query.clone());

        let eth = built.layer::<Ethernet>().unwrap();
        assert_eq!(eth.source(), Some(mac));
        assert_eq!(eth.destination(), Some(MDNS_IPV6_ETHERNET_MULTICAST));
        let ip = built.layer::<Ipv6>().unwrap();
        assert_eq!(ip.source(), link_local);
        assert_eq!(ip.destination(), MDNS_IPV6_LINK_LOCAL_MULTICAST);
        assert_eq!(ip.hop_limit_value(), MDNS_RESPONSE_HOP_LIMIT);

        let wire = built.compile()?;
        let parsed = Packet::decode_from_link_with_registry(
            &mdns_registry(),
            LinkType::Ethernet,
            wire.as_bytes(),
        )?;
        let parsed_eth = parsed.layer::<Ethernet>().unwrap();
        assert_eq!(parsed_eth.source(), Some(mac));
        assert_eq!(parsed_eth.destination(), Some(MDNS_IPV6_ETHERNET_MULTICAST));
        let parsed_ip = parsed.layer::<Ipv6>().unwrap();
        assert_eq!(parsed_ip.source(), link_local);
        assert_eq!(parsed_ip.destination(), MDNS_IPV6_LINK_LOCAL_MULTICAST);
        assert_eq!(parsed_ip.hop_limit_value(), MDNS_RESPONSE_HOP_LIMIT);
        assert_mdns_ports(parsed.layer::<Udp>().unwrap());
        assert_same_query(parsed.layer::<Dns>().unwrap(), &query);
        Ok(())
    }

    /// Builder setters called after the mDNS transport helpers win over their
    /// defaults; without overrides the Ethernet helper uses the mDNS group MAC.
    #[test]
    fn mdns_transport_builders_preserve_caller_overrides() {
        let custom_udp = mdns::mdns_udp().sport(53000).dport(53001);
        assert_eq!(custom_udp.source_port_value(), 53000);
        assert_eq!(custom_udp.destination_port_value(), 53001);

        let v4_src = Ipv4Addr::new(192, 0, 2, 11);
        let v4_dst = Ipv4Addr::new(198, 51, 100, 20);
        let custom_v4 = mdns::mdns_ipv4(Ipv4Addr::new(192, 0, 2, 10))
            .src(v4_src)
            .dst(v4_dst)
            .ttl(42);
        assert_eq!(custom_v4.source(), v4_src);
        assert_eq!(custom_v4.destination(), v4_dst);
        assert_eq!(custom_v4.ttl_value(), 42);

        let v6_src = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let v6_dst = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 2);
        let custom_v6 = mdns::mdns_ipv6(Ipv6Addr::LOCALHOST)
            .src(v6_src)
            .dst(v6_dst)
            .hop_limit(43);
        assert_eq!(custom_v6.source(), v6_src);
        assert_eq!(custom_v6.destination(), v6_dst);
        assert_eq!(custom_v6.hop_limit_value(), 43);

        let base_mac = MacAddr::new([0x02, 0x00, 0x5e, 0x10, 0x00, 0x01]);
        let alt_src = MacAddr::new([0x02, 0x00, 0x5e, 0x10, 0x00, 0x02]);
        let alt_dst = MacAddr::new([0x02, 0x00, 0x5e, 0x10, 0x00, 0x03]);
        let custom_eth = mdns::mdns_ethernet_ipv4(base_mac).src(alt_src).dst(alt_dst);
        assert_eq!(custom_eth.source(), Some(alt_src));
        assert_eq!(custom_eth.destination(), Some(alt_dst));

        let stock_eth = mdns::mdns_ethernet_ipv4(base_mac);
        assert_eq!(stock_eth.source(), Some(base_mac));
        assert_eq!(stock_eth.destination(), Some(MDNS_IPV4_ETHERNET_MULTICAST));
    }

    /// The short alias builders produce the same defaults as the primary
    /// `mdns_*` builders.
    #[test]
    fn mdns_transport_alias_builders_match_primary_defaults() {
        let v4 = Ipv4Addr::new(192, 0, 2, 60);
        let v6 = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0x60);
        let mac = MacAddr::new([0x02, 0x00, 0x5e, 0x10, 0x00, 0x60]);

        let plain_udp = mdns::udp();
        assert_eq!(plain_udp.source_port_value(), MDNS_PORT);
        assert_eq!(plain_udp.destination_port_value(), MDNS_PORT);
        let reply_udp = mdns::udp_unicast_reply(MDNS_PORT, 62000);
        assert_eq!(reply_udp.source_port_value(), MDNS_PORT);
        assert_eq!(reply_udp.destination_port_value(), 62000);

        let v4_multicast = mdns::ipv4_multicast(v4);
        assert_eq!(v4_multicast.source(), v4);
        assert_eq!(v4_multicast.destination(), MDNS_IPV4_MULTICAST);
        assert_eq!(v4_multicast.ttl_value(), MDNS_RESPONSE_TTL);
        let v4_response = mdns::ipv4_response(v4);
        assert_eq!(v4_response.destination(), MDNS_IPV4_MULTICAST);
        assert_eq!(v4_response.ttl_value(), MDNS_RESPONSE_TTL);

        let v6_multicast = mdns::ipv6_multicast(v6);
        assert_eq!(v6_multicast.source(), v6);
        assert_eq!(v6_multicast.destination(), MDNS_IPV6_LINK_LOCAL_MULTICAST);
        assert_eq!(v6_multicast.hop_limit_value(), MDNS_RESPONSE_HOP_LIMIT);
        let v6_response = mdns::ipv6_response(v6);
        assert_eq!(v6_response.destination(), MDNS_IPV6_LINK_LOCAL_MULTICAST);
        assert_eq!(v6_response.hop_limit_value(), MDNS_RESPONSE_HOP_LIMIT);

        let eth_v4 = mdns::ethernet_ipv4_multicast(mac);
        assert_eq!(eth_v4.source(), Some(mac));
        assert_eq!(eth_v4.destination(), Some(MDNS_IPV4_ETHERNET_MULTICAST));
        let eth_v6 = mdns::ethernet_ipv6_multicast(mac);
        assert_eq!(eth_v6.source(), Some(mac));
        assert_eq!(eth_v6.destination(), Some(MDNS_IPV6_ETHERNET_MULTICAST));

        let query = mdns::query_for("printer.local.", DNS_TYPE_A);
        assert!(mdns::ipv4_packet(v4, query.clone()).layer::<Ipv4>().is_some());
        assert!(mdns::ipv6_packet(v6, query.clone()).layer::<Ipv6>().is_some());
        assert!(mdns::ethernet_ipv4_packet(mac, v4, query.clone())
            .layer::<Ethernet>()
            .is_some());
        assert!(mdns::ethernet_ipv6_packet(mac, v6, query)
            .layer::<Ethernet>()
            .is_some());
    }
}
#[cfg(test)]
mod bonjour_record_builders_tests {
    use core::net::{Ipv4Addr, Ipv6Addr};
    use crate::{NetworkLayer, Packet, ProtocolRegistry};
    use super::{
        append_dns_packet, mdns, Dns, DnsRecordData, DNS_SD_DEFAULT_DOMAIN, DNS_TYPE_A,
        DNS_TYPE_AAAA, DNS_TYPE_PTR, DNS_TYPE_SRV, DNS_TYPE_TXT, MDNS_PORT,
    };

    /// Returns a registry that decodes UDP `MDNS_PORT` payloads as DNS.
    fn mdns_registry() -> ProtocolRegistry {
        let mut registry = ProtocolRegistry::new();
        registry.bind_udp_port(MDNS_PORT, |pkt, bytes| append_dns_packet(pkt, bytes));
        registry
    }

    /// A full DNS-SD advertisement (service PTR, subtype PTR, SRV and TXT
    /// answers plus A/AAAA additionals) survives compile -> registry decode ->
    /// compile with every record, TTL and cache-flush bit intact.
    #[test]
    fn bonjour_record_builders_complete_service_answer_set_round_trips() -> crate::Result<()> {
        let ipp_type = mdns::dns_sd_tcp_service_name("ipp", DNS_SD_DEFAULT_DOMAIN)?;
        let instance =
            mdns::dns_sd_tcp_instance_name("Office\\032Printer", "ipp", DNS_SD_DEFAULT_DOMAIN)?;
        let printer_subtype =
            mdns::dns_sd_subtype_name("printer", "ipp", "tcp", DNS_SD_DEFAULT_DOMAIN)?;
        let host = "office-printer.local.";
        // The last entry contains 0x00 and 0xff, i.e. bytes that are not text.
        let txt_entries: Vec<Vec<u8>> = vec![
            b"txtvers=1".to_vec(),
            b"rp=printers/office".to_vec(),
            vec![b'k', b'=', 0x00, 0xff],
        ];

        let service_ptr = mdns::service_ptr(ipp_type, instance.clone(), 4500);
        let subtype_ptr = mdns::subtype_ptr(printer_subtype, instance.clone(), 4501);
        let srv = mdns::srv_with_priority(instance.clone(), host, 0, 10, 631, 120)
            .mdns_cache_flush(true);
        let txt = mdns::txt(instance, txt_entries.clone(), 121).mdns_cache_flush(true);
        let a_record = mdns::a(host, Ipv4Addr::new(192, 0, 2, 55), 122).mdns_cache_flush(true);
        let aaaa_record =
            mdns::aaaa(host, Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0x55), 123)
                .mdns_cache_flush(true);

        let advertisement = mdns::response_with_answers([
            service_ptr.clone(),
            subtype_ptr.clone(),
            srv.clone(),
            txt.clone(),
        ])
        .additional(a_record.clone())
        .additional(aaaa_record.clone());

        let wire =
            mdns::mdns_ipv4_packet(Ipv4Addr::new(192, 0, 2, 10), advertisement).compile()?;
        let parsed = Packet::decode_from_l3_with_registry(
            &mdns_registry(),
            NetworkLayer::Ipv4,
            wire.as_bytes(),
        )?;
        let dns = parsed.layer::<Dns>().unwrap();
        assert_eq!(dns.answers(), &[service_ptr, subtype_ptr, srv, txt]);
        assert_eq!(dns.additionals(), &[a_record, aaaa_record]);

        // (record type, TTL, cache-flush) for each record, in wire order.
        let answer_shape = [
            (DNS_TYPE_PTR, 4500, false),
            (DNS_TYPE_PTR, 4501, false),
            (DNS_TYPE_SRV, 120, true),
            (DNS_TYPE_TXT, 121, true),
        ];
        for (record, (rtype, ttl, flush)) in dns.answers().iter().zip(answer_shape) {
            assert_eq!(record.record_type(), rtype);
            assert_eq!(record.ttl(), ttl);
            assert_eq!(record.mdns_cache_flush_value(), flush);
        }
        assert_eq!(dns.answers()[3].data(), &DnsRecordData::Txt(txt_entries));

        let additional_shape = [(DNS_TYPE_A, 122, true), (DNS_TYPE_AAAA, 123, true)];
        for (record, (rtype, ttl, flush)) in dns.additionals().iter().zip(additional_shape) {
            assert_eq!(record.record_type(), rtype);
            assert_eq!(record.ttl(), ttl);
            assert_eq!(record.mdns_cache_flush_value(), flush);
        }

        assert_eq!(parsed.compile()?, wire);
        Ok(())
    }
}
#[cfg(test)]
mod dns_tests {
use super::{
decode_dns_name, Dns, DnsName, DnsQuestion, DnsRecord, DnsRecordData, DNS_CLASS_IN,
DNS_FLAG_AUTHORITATIVE, DNS_FLAG_QR_RESPONSE, DNS_FLAG_RECURSION_DESIRED, DNS_TYPE_A,
DNS_TYPE_AAAA, DNS_TYPE_CNAME, DNS_TYPE_MX, DNS_TYPE_NS, DNS_TYPE_PTR, DNS_TYPE_SRV,
DNS_TYPE_TXT,
};
use crate::{Ipv4, NetworkLayer, Packet, Raw, Udp};
use core::net::{Ipv4Addr, Ipv6Addr};
/// `Dns::a_query` emits the chosen ID, the RD flag, one question, and a
/// trailing uncompressed QNAME followed by QTYPE=1 / QCLASS=1.
#[test]
fn dns_a_query_encodes_header_question_and_udp_payload() {
    let query = Dns::a_query("example.com").id(0xbeef);
    let wire = (Udp::new().sport(53001).dport(53) / query).compile().unwrap();
    let bytes = wire.as_bytes();
    // Bytes 0..8 are the UDP header; the DNS header follows.
    assert_eq!(&bytes[8..10], &0xbeefu16.to_be_bytes());
    assert_eq!(&bytes[10..12], &DNS_FLAG_RECURSION_DESIRED.to_be_bytes());
    assert_eq!(&bytes[12..14], &1u16.to_be_bytes());
    assert!(bytes.ends_with(b"\x07example\x03com\x00\x00\x01\x00\x01"));
}
/// A QR+AA response with A, AAAA and CNAME answers decodes from an IPv4/UDP
/// packet with its fields intact and recompiles to identical bytes.
#[test]
fn dns_response_records_roundtrip() {
    let owner = "example.com.";
    let v4 = Ipv4Addr::new(203, 0, 113, 10);
    let v6 = Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1);
    let response = Dns::new()
        .id(0x1234)
        .response(true)
        .authoritative(true)
        .question(DnsQuestion::a(owner))
        .answer(DnsRecord::a(owner, v4, 60))
        .answer(DnsRecord::aaaa(owner, v6, 60))
        .answer(DnsRecord::cname("www.example.com.", owner, 60));

    let server_to_client = Ipv4::new()
        .src(Ipv4Addr::new(198, 51, 100, 53))
        .dst(Ipv4Addr::new(192, 0, 2, 10))
        / Udp::new().sport(53).dport(53001)
        / response;
    let wire = server_to_client.compile().unwrap();

    let decoded = Packet::decode_from_l3(NetworkLayer::Ipv4, wire.as_bytes()).unwrap();
    let dns = decoded.layer::<Dns>().unwrap();
    let qr_aa = DNS_FLAG_QR_RESPONSE | DNS_FLAG_AUTHORITATIVE;
    assert_eq!(dns.id_value(), 0x1234);
    assert_eq!(dns.flags_value() & qr_aa, qr_aa);
    assert_eq!(dns.questions()[0].name(), owner);
    assert_eq!(dns.answers().len(), 3);
    assert_eq!(dns.answers()[0].data(), &DnsRecordData::A(v4));
    assert_eq!(decoded.compile().unwrap(), wire);
}
/// Traffic to UDP port 53 is decoded as DNS without an explicit registry,
/// and the decoded packet recompiles to identical bytes.
#[test]
fn dns_decode_uses_udp_port_context() {
    let query = Dns::aaaa_query("example.com").id(0x5678);
    let client_to_server = Ipv4::new()
        .src(Ipv4Addr::new(192, 0, 2, 10))
        .dst(Ipv4Addr::new(198, 51, 100, 53))
        / Udp::new().sport(53001).dport(53)
        / query;
    let wire = client_to_server.compile().unwrap();

    let decoded = Packet::decode_from_l3(NetworkLayer::Ipv4, wire.as_bytes()).unwrap();
    let question_type = decoded.layer::<Dns>().unwrap().questions()[0].question_type();
    assert_eq!(question_type, DNS_TYPE_AAAA);
    assert_eq!(decoded.compile().unwrap(), wire);
}
/// Builder inputs are readable through the accessors: explicit ID, RD
/// cleared, question name given a trailing dot, explicit class kept.
#[test]
fn dns_builders_keep_common_values_visible() {
    let question = DnsQuestion::new("example.org", DNS_TYPE_A).qclass(DNS_CLASS_IN);
    let query = Dns::new().id(7).rd(false).question(question);
    assert_eq!(query.id_value(), 7);
    assert_eq!(query.flags_value() & DNS_FLAG_RECURSION_DESIRED, 0);
    let first = &query.questions()[0];
    assert_eq!(first.name(), "example.org.");
    assert_eq!(first.question_class(), DNS_CLASS_IN);
}
/// `decode_dns_name` returns the presentation name and the number of bytes
/// consumed at the given offset, following compression pointers.
#[test]
fn dns_compressed_name_decode_is_exposed() {
    // Offset 0: "www.example.com." spelled out (17 bytes).
    // Offset 17: "mail" followed by a pointer to offset 4 ("example.com."), 7 bytes.
    let message = [
        3, b'w', b'w', b'w', 7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o', b'm',
        0, 4, b'm', b'a', b'i', b'l', 0xc0, 4,
    ];
    let cases = [(0, "www.example.com.", 17), (17, "mail.example.com.", 7)];
    for (offset, name, consumed) in cases {
        assert_eq!(
            decode_dns_name(&message, offset).unwrap(),
            (name.to_string(), consumed)
        );
    }
}
/// An owner name whose first label contains non-text bytes (0x00, 0xff)
/// keeps its raw labels through compile and decode. Its presentation form
/// escapes those bytes as `\DDD`.
#[test]
fn non_text_owner_name_round_trips_through_a_compiled_packet() {
    let owner = DnsName::from_labels([vec![0x00u8, 0xff], b"example".to_vec()]).unwrap();
    assert!(!owner.is_text());
    let txt_answer = DnsRecord::new(
        owner.clone(),
        DNS_TYPE_TXT,
        DNS_CLASS_IN,
        300,
        DnsRecordData::txt(b"v"),
    );
    let response = Dns::new().id(0x4242).response(true).answer(txt_answer);
    let wire = (Ipv4::new()
        .src(Ipv4Addr::new(198, 51, 100, 53))
        .dst(Ipv4Addr::new(192, 0, 2, 10))
        / Udp::new().sport(53).dport(53001)
        / response)
        .compile()
        .unwrap();

    let decoded = Packet::decode_from_l3(NetworkLayer::Ipv4, wire.as_bytes()).unwrap();
    let answer = &decoded.layer::<Dns>().unwrap().answers()[0];
    assert_eq!(answer.name_labels(), owner.labels());
    assert_eq!(answer.dns_name(), &owner);
    assert_eq!(answer.name(), "\\000\\255.example.");
    assert_eq!(decoded.compile().unwrap(), wire);
}
/// A response whose names use compression pointers decodes into full names.
/// Recompiling it emits a longer, pointer-free payload that decodes to the
/// same data.
#[test]
fn dns_compression_decodes_then_recompiles_uncompressed() {
    // Every answer name points back at "example.com." at offset 0x0c.
    let message: &[u8] = &[
        // Header: ID 0x1234, flags 0x8180, QDCOUNT 1, ANCOUNT 3, NSCOUNT 0, ARCOUNT 0.
        0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00,
        // Question: example.com. type 1 class 1 (name begins at offset 0x0c).
        7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o', b'm', 0,
        0x00, 0x01, 0x00, 0x01,
        // Answer 1: <ptr> type 5 (CNAME) class 1 TTL 300 RDLENGTH 8 -> alias.<ptr>
        0xc0, 0x0c, 0x00, 0x05, 0x00, 0x01, 0x00, 0x00, 0x01, 0x2c, 0x00, 0x08,
        5, b'a', b'l', b'i', b'a', b's', 0xc0, 0x0c,
        // Answer 2: <ptr> type 15 (MX) class 1 TTL 300 RDLENGTH 9 -> pref 10, mail.<ptr>
        0xc0, 0x0c, 0x00, 0x0f, 0x00, 0x01, 0x00, 0x00, 0x01, 0x2c, 0x00, 0x09,
        0x00, 0x0a, 4, b'm', b'a', b'i', b'l', 0xc0, 0x0c,
        // Answer 3: _sip._tcp.<ptr> type 33 (SRV) class 1 TTL 300 RDLENGTH 12
        //           -> priority 10, weight 5, port 5676, sip.<ptr>
        4, b'_', b's', b'i', b'p', 4, b'_', b't', b'c', b'p', 0xc0, 0x0c,
        0x00, 0x21, 0x00, 0x01, 0x00, 0x00, 0x01, 0x2c, 0x00, 0x0c,
        0x00, 0x0a, 0x00, 0x05, 0x16, 0x2c, 3, b's', b'i', b'p', 0xc0, 0x0c,
    ];
    let server = Ipv4Addr::new(198, 51, 100, 53);
    let client = Ipv4Addr::new(192, 0, 2, 10);

    let wire = (Ipv4::new().src(server).dst(client)
        / Udp::new().sport(53).dport(53001)
        / Raw::from_bytes(message))
        .compile()
        .unwrap();
    let decoded = Packet::decode_from_l3(NetworkLayer::Ipv4, wire.as_bytes()).unwrap();
    let dns = decoded.layer::<Dns>().unwrap();
    assert_eq!(dns.questions()[0].name(), "example.com.");
    assert_eq!(dns.answers().len(), 3);

    let cname = &dns.answers()[0];
    assert_eq!(cname.name(), "example.com.");
    assert_eq!(cname.record_type(), DNS_TYPE_CNAME);
    match cname.data() {
        DnsRecordData::Name(alias) => assert_eq!(alias.presentation(), "alias.example.com."),
        other => panic!("expected CNAME name, got {other:?}"),
    }

    let mx = &dns.answers()[1];
    assert_eq!(mx.record_type(), DNS_TYPE_MX);
    match mx.data() {
        DnsRecordData::Mx {
            preference,
            exchange,
        } => {
            assert_eq!(*preference, 10);
            assert_eq!(exchange.presentation(), "mail.example.com.");
        }
        other => panic!("expected MX, got {other:?}"),
    }

    let srv = &dns.answers()[2];
    assert_eq!(srv.name(), "_sip._tcp.example.com.");
    assert_eq!(srv.record_type(), DNS_TYPE_SRV);
    match srv.data() {
        DnsRecordData::Srv {
            priority,
            weight,
            port,
            target,
        } => {
            assert_eq!((*priority, *weight, *port), (10, 5, 5676));
            assert_eq!(target.presentation(), "sip.example.com.");
        }
        other => panic!("expected SRV, got {other:?}"),
    }

    // IPv4 header (20 bytes, no options) + UDP header (8 bytes) precede the DNS payload.
    const IPV4_UDP_HEADER_LEN: usize = 20 + 8;
    let rewire = (Ipv4::new().src(server).dst(client)
        / Udp::new().sport(53).dport(53001)
        / dns.clone())
        .compile()
        .unwrap();
    let payload = &rewire.as_bytes()[IPV4_UDP_HEADER_LEN..];
    assert!(payload.len() > message.len());
    // Heuristic: no byte in this particular payload carries the 0b11 pointer tag.
    assert!(
        payload.iter().all(|&byte| byte & 0xc0 != 0xc0),
        "recompiled DNS payload must not contain a compression pointer",
    );

    let redecoded = Packet::decode_from_l3(NetworkLayer::Ipv4, rewire.as_bytes()).unwrap();
    let redns = redecoded.layer::<Dns>().unwrap();
    assert_eq!(redns.answers().len(), 3);
    assert_eq!(redns.questions()[0].name(), "example.com.");
    for (again, first) in redns.answers().iter().zip(dns.answers()) {
        assert_eq!(again.data(), first.data());
    }
    assert_eq!(redns.answers()[2].name(), "_sip._tcp.example.com.");
}
/// NS, CNAME and PTR records, whose RDATA is a domain name, survive an
/// IPv4/UDP compile/decode cycle and recompile to identical bytes.
#[test]
fn name_records_round_trip_through_a_compiled_packet() {
    let zone = "example.com.";
    let response = Dns::new()
        .id(0x4e43)
        .response(true)
        .authoritative(true)
        .question(DnsQuestion::new(zone, DNS_TYPE_NS).qclass(DNS_CLASS_IN))
        .answer(DnsRecord::new(
            zone,
            DNS_TYPE_NS,
            DNS_CLASS_IN,
            3600,
            DnsRecordData::name("ns1.example.com."),
        ))
        .answer(DnsRecord::cname("www.example.com.", "host.example.", 300))
        .answer(DnsRecord::new(
            "20.113.0.203.in-addr.arpa.",
            DNS_TYPE_PTR,
            DNS_CLASS_IN,
            300,
            DnsRecordData::name("host.example.com."),
        ));
    let wire = (Ipv4::new()
        .src(Ipv4Addr::new(198, 51, 100, 53))
        .dst(Ipv4Addr::new(192, 0, 2, 10))
        / Udp::new().sport(53).dport(53001)
        / response)
        .compile()
        .unwrap();

    let decoded = Packet::decode_from_l3(NetworkLayer::Ipv4, wire.as_bytes()).unwrap();
    let dns = decoded.layer::<Dns>().unwrap();
    let question = &dns.questions()[0];
    assert_eq!(question.name(), zone);
    assert_eq!(question.question_type(), DNS_TYPE_NS);
    assert_eq!(dns.answers().len(), 3);

    // (record type, owner name to check if any, RDATA target, label for panics)
    let expectations = [
        (DNS_TYPE_NS, None, "ns1.example.com.", "NS"),
        (DNS_TYPE_CNAME, Some("www.example.com."), "host.example.", "CNAME"),
        (
            DNS_TYPE_PTR,
            Some("20.113.0.203.in-addr.arpa."),
            "host.example.com.",
            "PTR",
        ),
    ];
    for (record, (rtype, owner, target, label)) in dns.answers().iter().zip(expectations) {
        assert_eq!(record.record_type(), rtype);
        if let Some(owner) = owner {
            assert_eq!(record.name(), owner);
        }
        match record.data() {
            DnsRecordData::Name(name) => assert_eq!(name.presentation(), target),
            other => panic!("expected {label} name, got {other:?}"),
        }
    }
    assert_eq!(decoded.compile().unwrap(), wire);
}
#[test]
fn compressed_name_records_normalize_to_uncompressed_model() {
    // Hand-built response that uses compression pointers (0xc0 0x0c = "jump to
    // offset 12", i.e. the question name) for every owner name and every RDATA
    // name suffix.
    let message: &[u8] = &[
        // Header: id 0x4e43, flags 0x8400 (QR, AA), QDCOUNT 1, ANCOUNT 3,
        // NSCOUNT 0, ARCOUNT 0.
        0x4e, 0x43, 0x84, 0x00, 0x00, 0x01, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00,
        // Question (offset 12): example.com. NS IN.
        7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o', b'm', 0,
        0x00, 0x02, 0x00, 0x01,
        // Answer 1: owner = pointer, NS IN, TTL 3600, RDLENGTH 6,
        // RDATA = "ns1" + pointer.
        0xc0, 0x0c, 0x00, 0x02, 0x00, 0x01, 0x00, 0x00, 0x0e, 0x10, 0x00, 0x06,
        3, b'n', b's', b'1', 0xc0, 0x0c,
        // Answer 2: owner = "www" + pointer, CNAME IN, TTL 300, RDLENGTH 7,
        // RDATA = "host" + pointer.
        3, b'w', b'w', b'w', 0xc0, 0x0c, 0x00, 0x05, 0x00, 0x01, 0x00, 0x00, 0x01, 0x2c,
        0x00, 0x07, 4, b'h', b'o', b's', b't', 0xc0, 0x0c,
        // Answer 3: owner = "ptr" + pointer, PTR IN, TTL 300, RDLENGTH 7,
        // RDATA = "host" + pointer.
        3, b'p', b't', b'r', 0xc0, 0x0c, 0x00, 0x0c, 0x00, 0x01, 0x00, 0x00, 0x01, 0x2c,
        0x00, 0x07, 4, b'h', b'o', b's', b't', 0xc0, 0x0c,
    ];
    let wire = (Ipv4::new()
        .src(Ipv4Addr::new(198, 51, 100, 53))
        .dst(Ipv4Addr::new(192, 0, 2, 10))
        / Udp::new().sport(53).dport(53001)
        / Raw::from_bytes(message))
    .compile()
    .unwrap();
    let decoded = Packet::decode_from_l3(NetworkLayer::Ipv4, wire.as_bytes()).unwrap();
    let dns = decoded.layer::<Dns>().unwrap();

    // Decoding expands every pointer into a full presentation name.
    assert_eq!(dns.questions()[0].name(), "example.com.");
    assert_eq!(dns.answers().len(), 3);
    assert_eq!(dns.answers()[0].name(), "example.com.");
    assert_eq!(dns.answers()[0].record_type(), DNS_TYPE_NS);
    match dns.answers()[0].data() {
        DnsRecordData::Name(name) => assert_eq!(name.presentation(), "ns1.example.com."),
        other => panic!("expected NS name, got {other:?}"),
    }
    assert_eq!(dns.answers()[1].name(), "www.example.com.");
    assert_eq!(dns.answers()[1].record_type(), DNS_TYPE_CNAME);
    match dns.answers()[1].data() {
        DnsRecordData::Name(name) => assert_eq!(name.presentation(), "host.example.com."),
        other => panic!("expected CNAME name, got {other:?}"),
    }
    assert_eq!(dns.answers()[2].name(), "ptr.example.com.");
    assert_eq!(dns.answers()[2].record_type(), DNS_TYPE_PTR);
    match dns.answers()[2].data() {
        DnsRecordData::Name(name) => assert_eq!(name.presentation(), "host.example.com."),
        other => panic!("expected PTR name, got {other:?}"),
    }

    // Re-encoding the typed model writes every name out in full.
    let recompiled = (Ipv4::new()
        .src(Ipv4Addr::new(198, 51, 100, 53))
        .dst(Ipv4Addr::new(192, 0, 2, 10))
        / Udp::new().sport(53).dport(53001)
        / dns.clone())
    .compile()
    .unwrap();
    // Skip the IPv4 header (assumed 20 bytes, no options) and the 8-byte UDP header.
    let dns_payload = &recompiled.as_bytes()[28..];
    assert!(dns_payload.len() > message.len());
    // Heuristic: any octet with both top bits set could be a pointer. This only
    // holds because none of this fixture's id/flags/type/class/TTL/RDLENGTH
    // octets are >= 0xc0; revisit if the fixture values change.
    assert!(
        !dns_payload.iter().any(|&b| b & 0xc0 == 0xc0),
        "recompiled DNS payload must not contain a compression pointer",
    );

    // The uncompressed encoding must decode back to the same model, including
    // the owner names that were pure pointers on the original wire.
    let redecoded = Packet::decode_from_l3(NetworkLayer::Ipv4, recompiled.as_bytes()).unwrap();
    let redns = redecoded.layer::<Dns>().unwrap();
    assert_eq!(redns.questions()[0].name(), "example.com.");
    assert_eq!(redns.answers().len(), 3);
    let owners = ["example.com.", "www.example.com.", "ptr.example.com."];
    for (index, owner) in owners.iter().enumerate() {
        let record = &redns.answers()[index];
        assert_eq!(record.name(), *owner);
        assert_eq!(record.record_type(), dns.answers()[index].record_type());
        assert_eq!(record.data(), dns.answers()[index].data());
    }
}
#[test]
fn dns_type_mismatch_is_rejected() {
    // Declared type CNAME, but the typed payload is an IPv4 address: the
    // record type and its RDATA disagree, so compiling must fail.
    let mismatched = DnsRecord::new(
        "example.com.",
        DNS_TYPE_CNAME,
        DNS_CLASS_IN,
        60,
        DnsRecordData::A(Ipv4Addr::new(203, 0, 113, 1)),
    );
    let packet = Packet::from_layer(Dns::new().answer(mismatched));
    let outcome = packet.compile();
    assert!(outcome.is_err());
}
}
#[cfg(test)]
mod dns_header_codepoints {
    //! DNS header codepoints: flag bits, OPCODE/RCODE packing, section counts,
    //! and question class/type values carried through compile and decode.
    use super::{
        dns_type_name, Dns, DnsQuestion, DnsRecord, DnsRecordData, DNS_CLASS_ANY, DNS_CLASS_CH,
        DNS_CLASS_HS, DNS_CLASS_IN, DNS_CLASS_NONE, DNS_FLAG_AUTHENTIC_DATA,
        DNS_FLAG_AUTHORITATIVE, DNS_FLAG_CHECKING_DISABLED, DNS_FLAG_QR_RESPONSE,
        DNS_FLAG_RECURSION_AVAILABLE, DNS_FLAG_RECURSION_DESIRED, DNS_FLAG_TRUNCATED,
        DNS_OPCODE_QUERY, DNS_OPCODE_STATUS, DNS_OPCODE_UPDATE, DNS_RCODE_NOERROR,
        DNS_RCODE_NXDOMAIN, DNS_RCODE_REFUSED, DNS_TYPE_A, DNS_TYPE_AAAA, DNS_TYPE_HTTPS,
        DNS_TYPE_MX, DNS_TYPE_NS, DNS_TYPE_OPT, DNS_TYPE_SOA, DNS_TYPE_SRV, DNS_TYPE_TXT,
    };
    use crate::{Ipv4, NetworkLayer, Packet, Udp};
    use std::net::Ipv4Addr;

    /// Offset of the DNS message inside the IPv4/UDP packets built below.
    /// Assumes `Ipv4::new()` emits a 20-byte header (no options), followed by
    /// the 8-byte UDP header.
    const DNS_OFFSET_IN_IPV4_UDP: usize = 28;

    /// `response`/`authoritative` add QR and AA on top of the default RD bit.
    /// The flags word lands at DNS bytes 2..4, i.e. UDP datagram bytes 10..12.
    #[test]
    fn existing_flag_helpers_compile_identically() {
        let dns = Dns::new()
            .id(0xbeef)
            .response(true)
            .authoritative(true)
            .question(DnsQuestion::a("example.com."));
        let compiled = (Udp::new().sport(53001).dport(53) / dns).compile().unwrap();
        let expected_flags =
            DNS_FLAG_QR_RESPONSE | DNS_FLAG_AUTHORITATIVE | DNS_FLAG_RECURSION_DESIRED;
        assert_eq!(&compiled.as_bytes()[10..12], &expected_flags.to_be_bytes());
    }

    /// Setting OPCODE last keeps QR, AA, the default RD bit and RCODE intact.
    #[test]
    fn opcode_setter_preserves_unrelated_bits() {
        let dns = Dns::new()
            .response(true)
            .authoritative(true)
            .rcode(DNS_RCODE_NXDOMAIN)
            .opcode(DNS_OPCODE_UPDATE);
        assert_eq!(dns.opcode_value(), DNS_OPCODE_UPDATE);
        assert!(dns.is_response());
        assert_ne!(dns.flags_value() & DNS_FLAG_AUTHORITATIVE, 0);
        assert_ne!(dns.flags_value() & DNS_FLAG_RECURSION_DESIRED, 0);
        assert_eq!(dns.rcode_value(), DNS_RCODE_NXDOMAIN);
    }

    /// Setting RCODE last keeps the previously chosen OPCODE and QR intact.
    #[test]
    fn rcode_setter_preserves_unrelated_bits() {
        let dns = Dns::new()
            .opcode(DNS_OPCODE_STATUS)
            .response(true)
            .rcode(DNS_RCODE_REFUSED);
        assert_eq!(dns.rcode_value(), DNS_RCODE_REFUSED);
        assert_eq!(dns.opcode_value(), DNS_OPCODE_STATUS);
        assert!(dns.is_response());
    }

    /// A fresh message is a standard QUERY with RCODE NOERROR.
    #[test]
    fn opcode_and_rcode_defaults_are_query_noerror() {
        let dns = Dns::new();
        assert_eq!(dns.opcode_value(), DNS_OPCODE_QUERY);
        assert_eq!(dns.rcode_value(), DNS_RCODE_NOERROR);
    }

    /// Unassigned 4-bit values are stored as-is; wider inputs read back as 0xf.
    #[test]
    fn unknown_opcode_and_rcode_values_round_trip() {
        let dns = Dns::new().opcode(0xf).rcode(0xf);
        assert_eq!(dns.opcode_value(), 0xf);
        assert_eq!(dns.rcode_value(), 0xf);
        let truncated = Dns::new().opcode(0xff).rcode(0xff);
        assert_eq!(truncated.opcode_value(), 0xf);
        assert_eq!(truncated.rcode_value(), 0xf);
    }

    /// `flags` writes the raw 16-bit word. OPCODE is read from mask 0x7800 and
    /// RCODE from mask 0x000f.
    #[test]
    fn raw_flags_remain_the_escape_hatch() {
        let dns = Dns::new().flags(0xabcd);
        assert_eq!(dns.flags_value(), 0xabcd);
        assert_eq!(dns.opcode_value(), ((0xabcd & 0x7800) >> 11) as u8);
        assert_eq!(dns.rcode_value(), (0xabcd & 0x000f) as u8);
    }

    /// Every named header flag, QR included, survives compile -> decode -> recompile.
    #[test]
    fn all_named_header_flag_bits_survive_compile_and_decode() {
        let all_named = DNS_FLAG_QR_RESPONSE
            | DNS_FLAG_AUTHORITATIVE
            | DNS_FLAG_TRUNCATED
            | DNS_FLAG_RECURSION_DESIRED
            | DNS_FLAG_RECURSION_AVAILABLE
            | DNS_FLAG_AUTHENTIC_DATA
            | DNS_FLAG_CHECKING_DISABLED;
        let dns = Dns::new()
            .flags(all_named)
            .question(DnsQuestion::a("example.com."));
        let bytes = (Ipv4::new()
            .src(Ipv4Addr::new(203, 0, 113, 1))
            .dst(Ipv4Addr::new(198, 51, 100, 1))
            / Udp::new().sport(53001).dport(53)
            / dns)
            .compile()
            .unwrap();
        let decoded = Packet::decode_from_l3(NetworkLayer::Ipv4, bytes.as_bytes()).unwrap();
        let header = decoded.layer::<Dns>().unwrap();
        assert_eq!(header.flags_value(), all_named);
        assert!(header.is_response());
        assert_ne!(header.flags_value() & DNS_FLAG_AUTHENTIC_DATA, 0);
        assert_ne!(header.flags_value() & DNS_FLAG_CHECKING_DISABLED, 0);
        assert_eq!(decoded.compile().unwrap(), bytes);
    }

    /// RA sets only its own bit; RD stays set because it is on by default.
    #[test]
    fn recursion_available_setter_sets_only_its_bit() {
        let dns = Dns::new().recursion_available(true);
        assert_ne!(dns.flags_value() & DNS_FLAG_RECURSION_AVAILABLE, 0);
        assert_ne!(dns.flags_value() & DNS_FLAG_RECURSION_DESIRED, 0);
        assert_eq!(
            dns.flags_value() & !(DNS_FLAG_RECURSION_AVAILABLE | DNS_FLAG_RECURSION_DESIRED),
            0
        );
    }

    /// QDCOUNT/ANCOUNT/NSCOUNT/ARCOUNT are derived from the typed section
    /// vectors. The check is made both on the wire and after decoding.
    #[test]
    fn section_counts_auto_fill_from_typed_vectors() {
        let dns = Dns::new()
            .response(true)
            .question(DnsQuestion::a("example.com."))
            .answer(DnsRecord::a(
                "example.com.",
                Ipv4Addr::new(192, 0, 2, 10),
                60,
            ))
            .authority(DnsRecord::cname("example.com.", "ns1.example.com.", 300))
            .additional(DnsRecord::a(
                "ns1.example.com.",
                Ipv4Addr::new(192, 0, 2, 53),
                300,
            ));
        let bytes = (Ipv4::new()
            .src(Ipv4Addr::new(203, 0, 113, 1))
            .dst(Ipv4Addr::new(198, 51, 100, 1))
            / Udp::new().sport(53).dport(53001)
            / dns)
            .compile()
            .unwrap();
        // The four big-endian u16 section counts occupy DNS header bytes 4..12.
        assert_eq!(
            &bytes.as_bytes()[DNS_OFFSET_IN_IPV4_UDP + 4..DNS_OFFSET_IN_IPV4_UDP + 12],
            &[0u8, 1, 0, 1, 0, 1, 0, 1]
        );
        let decoded = Packet::decode_from_l3(NetworkLayer::Ipv4, bytes.as_bytes()).unwrap();
        let header = decoded.layer::<Dns>().unwrap();
        assert_eq!(header.questions().len(), 1);
        assert_eq!(header.answers().len(), 1);
        assert_eq!(header.authorities().len(), 1);
        assert_eq!(header.additionals().len(), 1);
        assert_eq!(decoded.compile().unwrap(), bytes);
    }

    /// Records stay in the section they were added to. The OPT pseudo-record
    /// stays after the ordinary A glue record in the additional section.
    #[test]
    fn section_placement_survives_decode_and_recompile() {
        let dns = Dns::new()
            .id(0x5023)
            .response(true)
            .authoritative(true)
            .question(DnsQuestion::a("example.com."))
            .answer(DnsRecord::a(
                "example.com.",
                Ipv4Addr::new(192, 0, 2, 10),
                3600,
            ))
            .authority(DnsRecord::new(
                "example.com.",
                DNS_TYPE_NS,
                DNS_CLASS_IN,
                3600,
                DnsRecordData::name("ns1.example.com."),
            ))
            .additional(DnsRecord::a(
                "ns1.example.com.",
                Ipv4Addr::new(192, 0, 2, 53),
                3600,
            ))
            .additional(DnsRecord::opt(1232, 0, 0, false, Vec::new()));
        let bytes = (Ipv4::new()
            .src(Ipv4Addr::new(203, 0, 113, 1))
            .dst(Ipv4Addr::new(198, 51, 100, 1))
            / Udp::new().sport(53).dport(53001)
            / dns)
            .compile()
            .unwrap();
        let decoded = Packet::decode_from_l3(NetworkLayer::Ipv4, bytes.as_bytes()).unwrap();
        let header = decoded.layer::<Dns>().unwrap();
        assert_eq!(header.questions().len(), 1);
        assert_eq!(header.answers().len(), 1);
        assert_eq!(header.authorities().len(), 1);
        assert_eq!(header.additionals().len(), 2);
        assert_eq!(header.answers()[0].record_type(), DNS_TYPE_A);
        assert_eq!(header.answers()[0].name(), "example.com.");
        assert_eq!(header.authorities()[0].record_type(), DNS_TYPE_NS);
        assert_eq!(
            header.authorities()[0].data(),
            &DnsRecordData::name("ns1.example.com.")
        );
        assert_eq!(header.additionals()[0].record_type(), DNS_TYPE_A);
        assert!(!header.additionals()[0].is_opt());
        assert_eq!(header.additionals()[0].name(), "ns1.example.com.");
        assert_eq!(header.additionals()[1].record_type(), DNS_TYPE_OPT);
        assert!(header.additionals()[1].is_opt());
        assert_eq!(decoded.compile().unwrap(), bytes);
    }

    /// Several questions with assigned, meta (ANY) and private-use type/class
    /// values keep their order and exact codepoints.
    #[test]
    fn multi_question_class_and_type_values_round_trip() {
        const QTYPE_ANY: u16 = 255;
        const PRIVATE_QTYPE: u16 = 65280;
        const PRIVATE_QCLASS: u16 = 65280;
        let questions = [
            (DNS_TYPE_A, DNS_CLASS_IN),
            (DNS_TYPE_AAAA, DNS_CLASS_CH),
            (DNS_TYPE_MX, DNS_CLASS_HS),
            (DNS_TYPE_TXT, DNS_CLASS_NONE),
            (QTYPE_ANY, DNS_CLASS_ANY),
            (PRIVATE_QTYPE, PRIVATE_QCLASS),
        ];
        let mut dns = Dns::new().rd(true);
        for (index, (qtype, qclass)) in questions.iter().enumerate() {
            let name = format!("q{index}.example.com.");
            dns = dns.question(DnsQuestion::new(name, *qtype).qclass(*qclass));
        }
        let bytes = (Ipv4::new()
            .src(Ipv4Addr::new(192, 0, 2, 10))
            .dst(Ipv4Addr::new(198, 51, 100, 53))
            / Udp::new().sport(53001).dport(53)
            / dns)
            .compile()
            .unwrap();
        let decoded = Packet::decode_from_l3(NetworkLayer::Ipv4, bytes.as_bytes()).unwrap();
        let header = decoded.layer::<Dns>().unwrap();
        assert_eq!(header.questions().len(), questions.len());
        for (index, (qtype, qclass)) in questions.iter().enumerate() {
            let question = &header.questions()[index];
            assert_eq!(question.name(), format!("q{index}.example.com."));
            assert_eq!(question.question_type(), *qtype);
            assert_eq!(question.question_class(), *qclass);
        }
        assert_eq!(header.questions()[5].question_type(), PRIVATE_QTYPE);
        assert_eq!(header.questions()[5].question_class(), PRIVATE_QCLASS);
        assert_eq!(decoded.compile().unwrap(), bytes);
    }

    /// Known type codepoints map to their mnemonics; unknown ones map to `None`.
    #[test]
    fn type_names_cover_source_backed_codepoints() {
        assert_eq!(dns_type_name(DNS_TYPE_A), Some("A"));
        assert_eq!(dns_type_name(DNS_TYPE_SOA), Some("SOA"));
        assert_eq!(dns_type_name(DNS_TYPE_SRV), Some("SRV"));
        assert_eq!(dns_type_name(DNS_TYPE_HTTPS), Some("HTTPS"));
        assert_eq!(dns_type_name(60000), None);
    }
}
#[cfg(test)]
mod dns_malformed {
    //! Malformed RDATA and domain names fed straight into the decoders. Each
    //! case pins both the error variant and the field/context it reports.
    use super::{
        decode_dns_name_typed, decode_record_data, DnsRecordData, DNS_MAX_LABEL_LEN, DNS_TYPE_A,
        DNS_TYPE_AAAA, DNS_TYPE_DNSKEY, DNS_TYPE_DS, DNS_TYPE_NSEC, DNS_TYPE_NSEC3, DNS_TYPE_RRSIG,
        DNS_TYPE_SOA, DNS_TYPE_SRV, DNS_TYPE_SVCB,
    };
    use crate::error::CrafterError;

    /// Decodes `rdata` as `record_type` and requires a `BufferTooShort` error
    /// whose context equals `field`. Any other outcome panics.
    fn assert_too_short(record_type: u16, rdata: &[u8], field: &str) {
        let context = match decode_record_data(record_type, rdata, 0, rdata.len()) {
            Err(CrafterError::BufferTooShort { context, .. }) => context,
            unexpected => {
                panic!("type {record_type:#x} expected buffer-too-short {field}, got {unexpected:?}")
            }
        };
        assert_eq!(
            context, field,
            "type {record_type:#x} expected buffer-too-short context {field}"
        );
    }

    /// Decodes `rdata` as `record_type` and requires an `InvalidFieldValue`
    /// error naming `field`. Any other outcome panics.
    fn assert_invalid(record_type: u16, rdata: &[u8], field: &str) {
        let reported = match decode_record_data(record_type, rdata, 0, rdata.len()) {
            Err(CrafterError::InvalidFieldValue { field: got, .. }) => got,
            unexpected => {
                panic!("type {record_type:#x} expected invalid-field-value {field}, got {unexpected:?}")
            }
        };
        assert_eq!(
            reported, field,
            "type {record_type:#x} expected invalid-field-value {field}"
        );
    }

    /// One maximal label in wire form: a `DNS_MAX_LABEL_LEN` length octet
    /// followed by that many `a` octets.
    fn max_length_label() -> Vec<u8> {
        let mut label = vec![DNS_MAX_LABEL_LEN as u8];
        label.extend_from_slice(&[b'a'; DNS_MAX_LABEL_LEN]);
        label
    }

    #[test]
    fn fixed_length_records_reject_wrong_rdlength() {
        // A must be exactly 4 octets and AAAA exactly 16: one short and one
        // long input for each.
        let cases: [(u16, &[u8], &str); 4] = [
            (DNS_TYPE_A, &[192, 0, 2], "dns.a.rdlength"),
            (DNS_TYPE_A, &[192, 0, 2, 1, 9], "dns.a.rdlength"),
            (DNS_TYPE_AAAA, &[0; 15], "dns.aaaa.rdlength"),
            (DNS_TYPE_AAAA, &[0; 17], "dns.aaaa.rdlength"),
        ];
        for (record_type, rdata, field) in cases {
            assert_invalid(record_type, rdata, field);
        }
    }

    #[test]
    fn dnssec_fixed_headers_reject_truncation() {
        // Each input is one octet shorter than the type's fixed header.
        let cases = [
            (DNS_TYPE_DS, 3, "dns.ds"),
            (DNS_TYPE_DNSKEY, 3, "dns.dnskey"),
            (DNS_TYPE_RRSIG, 17, "dns.rrsig"),
        ];
        for (record_type, rdata_len, context) in cases {
            assert_too_short(record_type, &vec![0u8; rdata_len], context);
        }
    }

    #[test]
    fn soa_rejects_wrong_fixed_tail_length() {
        // MNAME "a." and a root RNAME, followed by a fixed tail one octet short
        // of, then one octet past, the 20 octets of five 32-bit fields.
        for tail_len in [19usize, 21] {
            let mut rdata = vec![1u8, b'a', 0, 0];
            rdata.resize(rdata.len() + tail_len, 0);
            assert_invalid(DNS_TYPE_SOA, &rdata, "dns.soa.rdlength");
        }
    }

    #[test]
    fn srv_rejects_short_header_and_trailing_bytes() {
        // Priority/weight/port need six octets; five is too few.
        assert_too_short(DNS_TYPE_SRV, &[0; 5], "dns.srv");
        // Valid header and a root target, followed by one stray octet.
        let trailing: [u8; 8] = [0, 1, 0, 2, 0x13, 0x88, 0, 0xff];
        assert_invalid(DNS_TYPE_SRV, &trailing, "dns.srv.target");
    }

    #[test]
    fn nsec3_rejects_hash_length_overrun() {
        // algorithm 1, flags 0, iterations 10, empty salt, then a hash length
        // of 20 with only two hash octets present.
        let rdata: [u8; 8] = [1, 0, 0, 10, 0, 20, 0x11, 0x22];
        assert_too_short(DNS_TYPE_NSEC3, &rdata, "dns.nsec3.hash");
    }

    #[test]
    fn nsec_rejects_non_minimal_trailing_zero_bitmap() {
        // Root next-name, window 0, bitmap length 2, with a trailing all-zero octet.
        let rdata: [u8; 5] = [0, 0x00, 2, 0x40, 0x00];
        assert_invalid(DNS_TYPE_NSEC, &rdata, "dns.nsec.bitmap");
    }

    #[test]
    fn svcb_rejects_out_of_order_and_overrun_params() {
        // Priority 1 and a root target, then key 3 before key 1 (both empty values).
        let out_of_order: [u8; 11] = [0, 1, 0, 0, 3, 0, 0, 0, 1, 0, 0];
        assert_invalid(DNS_TYPE_SVCB, &out_of_order, "dns.svcb.params");
        // Priority 1 and a root target, then key 3 claiming 8 value octets with only 2 left.
        let overrun: [u8; 9] = [0, 1, 0, 0, 3, 0, 8, 0, 0];
        assert_too_short(DNS_TYPE_SVCB, &overrun, "dns.svcb.params");
    }

    #[test]
    fn unknown_record_type_decodes_as_raw_not_rejected() {
        // 0xfff0 lies in the private-use type range, which the decoder has no
        // typed form for.
        const PRIVATE_USE_TYPE: u16 = 0xfff0;
        let payload = [0xde, 0xad, 0xbe, 0xef];
        assert_eq!(
            decode_record_data(PRIVATE_USE_TYPE, &payload, 0, payload.len()).unwrap(),
            DnsRecordData::Raw(payload.to_vec())
        );
        assert_eq!(
            decode_record_data(PRIVATE_USE_TYPE, &[], 0, 0).unwrap(),
            DnsRecordData::Raw(Vec::new())
        );
    }

    #[test]
    fn name_decoder_rejects_reserved_marker_and_length_overrun() {
        // 0x40 uses the reserved 0b01 label-type prefix.
        let reserved = match decode_dns_name_typed(&[0x40], 0) {
            Err(CrafterError::InvalidFieldValue { field, .. }) => field,
            unexpected => panic!("expected reserved-marker rejection, got {unexpected:?}"),
        };
        assert_eq!(reserved, "dns.name");

        // Four maximal labels plus the root octet: 4 * 64 + 1 = 257 octets,
        // beyond the 255-octet limit on an encoded name.
        let mut overrun = max_length_label().repeat(4);
        overrun.push(0);
        let too_long = match decode_dns_name_typed(&overrun, 0) {
            Err(CrafterError::InvalidFieldValue { field, .. }) => field,
            unexpected => panic!("expected full-name overrun rejection, got {unexpected:?}"),
        };
        assert_eq!(too_long, "dns.name");
    }

    #[test]
    fn label_at_63_octet_boundary_is_accepted() {
        let mut wire = max_length_label();
        wire.push(0);
        let (name, used) = decode_dns_name_typed(&wire, 0).unwrap();
        assert_eq!(used, wire.len());
        assert_eq!(name.labels(), &[vec![b'a'; DNS_MAX_LABEL_LEN]]);
    }
}
#[cfg(test)]
mod dns_golden_bytes {
    //! Byte-exact fixture checks for a DNS query framed in IPv4/UDP, plus a
    //! negative check that non-DNS UDP payloads are not misclassified.
    use super::{Dns, DnsQuestion, DNS_TYPE_A};
    use crate::{Ipv4, LinkType, Packet, Udp};
    use core::net::Ipv4Addr;

    /// IPv4/UDP/DNS query for `example.com.` type A. The exact builder inputs
    /// are pinned by `dns_query_matches_golden_bytes`.
    const DNS_QUERY_FIXTURE: &[u8] = fixture_bytes!("bytes/ipv4-udp-dns-query-example-com.bin");

    /// Building the query from typed layers reproduces the fixture byte for byte.
    #[test]
    fn dns_query_matches_golden_bytes() {
        let bytes = (Ipv4::new()
            .src(Ipv4Addr::new(192, 0, 2, 10))
            .dst(Ipv4Addr::new(198, 51, 100, 53))
            .id(0x1237)
            .ttl(61)
            / Udp::new().sport(53001).dport(53)
            / Dns::new()
                .id(0xbeef)
                .question(DnsQuestion::new("example.com.", DNS_TYPE_A)))
            .compile()
            .unwrap();
        assert_eq!(bytes.as_bytes(), DNS_QUERY_FIXTURE);
    }

    /// Decoding the fixture yields the typed DNS layer, and recompiling it
    /// reproduces the fixture.
    #[test]
    fn dns_query_fixture_decodes_to_typed_layer() {
        let decoded = Packet::decode_from_l3(crate::NetworkLayer::Ipv4, DNS_QUERY_FIXTURE).unwrap();
        let dns = decoded.layer::<Dns>().unwrap();
        assert_eq!(dns.id_value(), 0xbeef);
        assert_eq!(dns.questions().len(), 1);
        assert_eq!(dns.questions()[0].name(), "example.com.");
        assert_eq!(dns.questions()[0].question_type(), DNS_TYPE_A);
        assert_eq!(decoded.compile().unwrap().as_bytes(), DNS_QUERY_FIXTURE);
    }

    /// A UDP payload on non-DNS ports is left undecoded when parsing from the
    /// link layer.
    #[test]
    fn non_dns_udp_payload_stays_raw_even_when_decoding_from_link() {
        let raw_fixture = fixture_bytes!("bytes/ethernet-vlan-ipv4-udp-raw.bin");
        let decoded = Packet::decode_from_link(LinkType::Ethernet, raw_fixture).unwrap();
        // Guard against a vacuous pass. If decoding stopped before UDP (e.g. at
        // the VLAN tag), the absence of a Dns layer would prove nothing.
        assert!(
            decoded.layer::<Udp>().is_some(),
            "link-layer decode must reach the UDP layer"
        );
        assert!(decoded.layer::<Dns>().is_none());
    }
}