use thiserror::Error;
/// Largest permitted length of one QNAME label, in octets (the low six bits
/// of a label length octet).
pub(crate) const MAX_LABEL_LEN: usize = 0x3f;
/// Largest permitted QNAME length in its decoded dotted-text form, in octets.
pub(crate) const MAX_QNAME_LEN: usize = 253;
/// Length of the fixed-size header that precedes the question section.
pub(crate) const DNS_HEADER_LEN: usize = 12;
/// QCLASS value for the IN class, the only class the parser accepts.
pub(crate) const QCLASS_IN: u16 = 0x0001;
/// Top two bits of a label length octet; when both are set the octet starts
/// a compression pointer, which the parser refuses to follow.
const POINTER_MASK: u8 = 0xc0;
/// Reasons a DNS query packet is rejected by this module's parsers.
///
/// Each variant's `Display` text comes from its `#[error]` attribute.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum DnsParseError {
    /// The packet ended before the 12-byte header, a label, or the
    /// QTYPE/QCLASS fields of a question could be read in full.
    #[error("dns packet too short")]
    TooShort,
    /// The header's QDCOUNT field is 0.
    #[error("dns QDCOUNT must be at least 1, got 0")]
    QdcountZero,
    /// QDCOUNT is greater than 1; returned by `parse_query`, which accepts
    /// exactly one question. Carries the received QDCOUNT.
    #[error("dns QDCOUNT must be exactly 1 on the proxy path, got {0}")]
    QdcountUnsupported(u16),
    /// A label length octet is above 63 without being a compression pointer.
    #[error("dns label exceeds 63 octets")]
    LabelOverflow,
    /// The decoded dotted QNAME is longer than 253 octets.
    #[error("dns QNAME exceeds 253 octets")]
    NameOverflow,
    /// A label length octet has both top bits set (a compression pointer);
    /// pointers are rejected rather than followed.
    #[error("dns QNAME pointer-compression rejected")]
    CompressionRejected,
    /// QCLASS is not IN (1). Carries the received QCLASS.
    #[error("dns qclass {0} not supported (only IN/1)")]
    UnsupportedClass(u16),
    /// A label contains a byte the parser does not accept. Carries the
    /// first offending byte.
    #[error("dns label contains invalid byte 0x{0:02x}")]
    InvalidLabelByte(u8),
}
/// Decoded view of one DNS question, tagged with header fields of the
/// message it was read from.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DnsQueryView {
    /// Transaction ID from the first two header bytes.
    pub txn_id: u16,
    /// Raw flags word from header bytes 2..4.
    pub flags: u16,
    /// QNAME as lowercase dotted text; empty for the root name.
    pub qname: String,
    /// QTYPE exactly as received.
    pub qtype: u16,
    /// QCLASS exactly as received.
    pub qclass: u16,
}

/// Per-question result produced by [`parse_query_multi`].
///
/// `#[non_exhaustive]` lets further outcome kinds be added later without
/// breaking exhaustive matches in downstream crates.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QuestionOutcome {
    /// The question decoded successfully.
    Parsed(DnsQueryView),
}

impl QuestionOutcome {
    /// Borrows the decoded question carried by this outcome, if any.
    ///
    /// Every current variant carries a view, so this presently always returns
    /// `Some`; the `Option` presumably leaves room for future variants that
    /// do not.
    pub fn as_view(&self) -> Option<&DnsQueryView> {
        let view = match self {
            Self::Parsed(inner) => inner,
        };
        Some(view)
    }
}
/// Parses a DNS query that must contain exactly one question.
///
/// The header's transaction ID and flags are copied into the returned view.
/// Bytes after the question are not inspected.
///
/// # Errors
///
/// Returns [`DnsParseError::TooShort`] if the packet is shorter than the
/// header, [`DnsParseError::QdcountZero`] if QDCOUNT is 0,
/// [`DnsParseError::QdcountUnsupported`] if QDCOUNT is greater than 1, and
/// otherwise any error raised while decoding the question itself.
pub fn parse_query(packet: &[u8]) -> Result<DnsQueryView, DnsParseError> {
    let Some(header) = packet.get(..DNS_HEADER_LEN) else {
        return Err(DnsParseError::TooShort);
    };
    let be16 = |offset: usize| u16::from_be_bytes([header[offset], header[offset + 1]]);
    let (txn_id, flags, qdcount) = (be16(0), be16(2), be16(4));
    match qdcount {
        0 => Err(DnsParseError::QdcountZero),
        1 => parse_one_question(packet, DNS_HEADER_LEN, txn_id, flags).map(|(view, _)| view),
        unsupported => Err(DnsParseError::QdcountUnsupported(unsupported)),
    }
}
/// Parses every question in a DNS query, in wire order.
///
/// Unlike [`parse_query`], any QDCOUNT of at least 1 is accepted. Each
/// question becomes a [`QuestionOutcome`] whose view carries the header's
/// transaction ID and flags.
///
/// # Errors
///
/// Returns [`DnsParseError::TooShort`] if the packet is shorter than the
/// header or ends before QDCOUNT questions have been read,
/// [`DnsParseError::QdcountZero`] if QDCOUNT is 0, and otherwise the first
/// error produced while decoding a question.
pub fn parse_query_multi(packet: &[u8]) -> Result<Vec<QuestionOutcome>, DnsParseError> {
    // Smallest encodable question: root name (1 octet) + QTYPE (2) + QCLASS (2).
    const MIN_QUESTION_LEN: usize = 5;

    if packet.len() < DNS_HEADER_LEN {
        return Err(DnsParseError::TooShort);
    }
    let txn_id = u16::from_be_bytes([packet[0], packet[1]]);
    let flags = u16::from_be_bytes([packet[2], packet[3]]);
    let qdcount = u16::from_be_bytes([packet[4], packet[5]]);
    if qdcount == 0 {
        return Err(DnsParseError::QdcountZero);
    }
    // QDCOUNT comes straight from the packet: a 12-byte message can claim
    // 65535 questions. Reserve no more slots than the remaining bytes could
    // possibly encode, so a tiny packet cannot force a large allocation
    // before it fails with `TooShort`.
    let max_questions = (packet.len() - DNS_HEADER_LEN) / MIN_QUESTION_LEN;
    let mut outcomes = Vec::with_capacity(usize::from(qdcount).min(max_questions));
    let mut idx = DNS_HEADER_LEN;
    for _ in 0..qdcount {
        let (view, next) = parse_one_question(packet, idx, txn_id, flags)?;
        outcomes.push(QuestionOutcome::Parsed(view));
        idx = next;
    }
    Ok(outcomes)
}
/// Decodes the question that begins at byte offset `start` of `packet`.
///
/// `txn_id` and `flags` are copied into the returned view unchanged. On
/// success the second tuple element is the offset of the first byte after
/// the question (past its QTYPE and QCLASS), so calls can be chained.
///
/// QNAME rules:
/// - labels are decoded into a lowercase, dot-separated string (the root
///   name yields an empty string);
/// - compression pointers are rejected rather than followed;
/// - each label is at most [`MAX_LABEL_LEN`] octets of printable ASCII
///   (`0x20..=0x7e`) excluding `.`;
/// - the decoded text may not exceed [`MAX_QNAME_LEN`] octets.
///
/// # Errors
///
/// [`DnsParseError::TooShort`] on truncation,
/// [`DnsParseError::CompressionRejected`], [`DnsParseError::LabelOverflow`],
/// [`DnsParseError::InvalidLabelByte`] (first offending byte),
/// [`DnsParseError::NameOverflow`], or [`DnsParseError::UnsupportedClass`]
/// when QCLASS is not [`QCLASS_IN`].
fn parse_one_question(
    packet: &[u8],
    start: usize,
    txn_id: u16,
    flags: u16,
) -> Result<(DnsQueryView, usize), DnsParseError> {
    /// Printable ASCII other than `.`. A dot inside a label would make the
    /// dotted `qname` ambiguous: the single label "a.b" would decode to the
    /// same string as the two labels "a" and "b".
    fn is_permitted_label_byte(b: u8) -> bool {
        (0x20..=0x7e).contains(&b) && b != b'.'
    }

    let mut idx = start;
    let mut qname = String::new();
    loop {
        let Some(&len_byte) = packet.get(idx) else {
            return Err(DnsParseError::TooShort);
        };
        idx += 1;
        if len_byte == 0 {
            break;
        }
        if (len_byte & POINTER_MASK) == POINTER_MASK {
            return Err(DnsParseError::CompressionRejected);
        }
        let label_len = usize::from(len_byte);
        if label_len > MAX_LABEL_LEN {
            return Err(DnsParseError::LabelOverflow);
        }
        let label = packet
            .get(idx..idx + label_len)
            .ok_or(DnsParseError::TooShort)?;
        // Validate the whole label before touching `qname`.
        if let Some(&bad) = label.iter().find(|&&b| !is_permitted_label_byte(b)) {
            return Err(DnsParseError::InvalidLabelByte(bad));
        }
        if !qname.is_empty() {
            qname.push('.');
        }
        qname.extend(label.iter().map(|&b| char::from(b.to_ascii_lowercase())));
        if qname.len() > MAX_QNAME_LEN {
            return Err(DnsParseError::NameOverflow);
        }
        idx += label_len;
    }
    let Some(&[qtype_hi, qtype_lo, qclass_hi, qclass_lo]) = packet.get(idx..idx + 4) else {
        return Err(DnsParseError::TooShort);
    };
    let qtype = u16::from_be_bytes([qtype_hi, qtype_lo]);
    let qclass = u16::from_be_bytes([qclass_hi, qclass_lo]);
    if qclass != QCLASS_IN {
        return Err(DnsParseError::UnsupportedClass(qclass));
    }
    let view = DnsQueryView {
        txn_id,
        flags,
        qname,
        qtype,
        qclass,
    };
    Ok((view, idx + 4))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 12-byte header: txn id 0x1234, flags 0x0100, the given
    /// QDCOUNT, and the remaining six bytes zeroed.
    fn header(qdcount: u16) -> Vec<u8> {
        let mut p = vec![0x12, 0x34, 0x01, 0x00];
        p.extend_from_slice(&qdcount.to_be_bytes());
        p.extend_from_slice(&[0x00; 6]);
        p
    }

    /// Appends one uncompressed question: QNAME labels, root octet, QTYPE, QCLASS.
    fn append_question(p: &mut Vec<u8>, qname: &str, qtype: u16, qclass: u16) {
        for label in qname.split('.') {
            let len = u8::try_from(label.len()).expect("test label fits in one length octet");
            p.push(len);
            p.extend_from_slice(label.as_bytes());
        }
        p.push(0);
        p.extend_from_slice(&qtype.to_be_bytes());
        p.extend_from_slice(&qclass.to_be_bytes());
    }

    /// Builds a full query whose QDCOUNT matches `questions.len()`.
    fn build_multi_query(questions: &[(&str, u16, u16)]) -> Vec<u8> {
        let qdcount = u16::try_from(questions.len()).expect("test QDCOUNT fits in u16");
        let mut p = header(qdcount);
        for &(name, qtype, qclass) in questions {
            append_question(&mut p, name, qtype, qclass);
        }
        p
    }

    /// Builds a single-question query.
    fn build_query(qname: &str, qtype: u16, qclass: u16) -> Vec<u8> {
        build_multi_query(&[(qname, qtype, qclass)])
    }

    #[test]
    fn parses_well_formed_a_query() {
        let pkt = build_query("api.example.com", 1, 1);
        let v = parse_query(&pkt).expect("parse ok");
        assert_eq!(v.txn_id, 0x1234);
        assert_eq!(v.flags, 0x0100);
        assert_eq!(v.qname, "api.example.com");
        assert_eq!(v.qtype, 1);
        assert_eq!(v.qclass, 1);
    }

    #[test]
    fn parses_aaaa_query() {
        let pkt = build_query("ipv6.example.com", 28, 1);
        let v = parse_query(&pkt).expect("parse ok");
        assert_eq!(v.qtype, 28);
        assert_eq!(v.qname, "ipv6.example.com");
    }

    #[test]
    fn lowercases_uppercase_qname() {
        let pkt = build_query("API.Example.COM", 1, 1);
        let v = parse_query(&pkt).expect("parse ok");
        assert_eq!(v.qname, "api.example.com");
    }

    #[test]
    fn parses_https_query_type_65() {
        let pkt = build_query("svc.example.com", 65, 1);
        let v = parse_query(&pkt).expect("parse ok");
        assert_eq!(v.qtype, 65);
    }

    #[test]
    fn rejects_truncated_header() {
        let pkt = vec![0x00; 6];
        assert_eq!(parse_query(&pkt), Err(DnsParseError::TooShort));
    }

    #[test]
    fn rejects_truncated_qname() {
        let mut pkt = header(1);
        pkt.push(5);
        pkt.extend_from_slice(b"abc");
        assert_eq!(parse_query(&pkt), Err(DnsParseError::TooShort));
    }

    #[test]
    fn rejects_truncated_qtype_qclass() {
        let mut pkt = header(1);
        pkt.push(0);
        assert_eq!(parse_query(&pkt), Err(DnsParseError::TooShort));
    }

    #[test]
    fn parse_query_rejects_qdcount_two() {
        let pkt = build_multi_query(&[("allowed.example.com", 1, 1), ("attacker.tld", 1, 1)]);
        assert_eq!(parse_query(&pkt), Err(DnsParseError::QdcountUnsupported(2)));
    }

    #[test]
    fn parse_query_rejects_qdcount_three() {
        let pkt = build_multi_query(&[
            ("api.example.com", 1, 1),
            ("svc.example.com", 28, 1),
            ("alt.example.com", 65, 1),
        ]);
        assert_eq!(parse_query(&pkt), Err(DnsParseError::QdcountUnsupported(3)));
    }

    #[test]
    fn parse_query_rejects_qdcount_max() {
        let mut pkt = header(u16::MAX);
        append_question(&mut pkt, "api.example.com", 1, 1);
        assert_eq!(
            parse_query(&pkt),
            Err(DnsParseError::QdcountUnsupported(u16::MAX))
        );
    }

    #[test]
    fn rejects_qdcount_zero_single() {
        let pkt = header(0);
        assert_eq!(parse_query(&pkt), Err(DnsParseError::QdcountZero));
    }

    #[test]
    fn rejects_qdcount_zero_multi() {
        let pkt = header(0);
        assert_eq!(parse_query_multi(&pkt), Err(DnsParseError::QdcountZero));
    }

    #[test]
    fn rejects_pointer_compression() {
        let mut pkt = header(1);
        pkt.extend_from_slice(&[0xc0, 0x0c]);
        pkt.extend_from_slice(&[0, 1, 0, 1]);
        assert_eq!(parse_query(&pkt), Err(DnsParseError::CompressionRejected));
    }

    #[test]
    fn rejects_oversized_label() {
        let mut pkt = header(1);
        pkt.push(64);
        pkt.extend(std::iter::repeat_n(b'a', 64));
        pkt.push(0);
        pkt.extend_from_slice(&[0, 1, 0, 1]);
        assert_eq!(parse_query(&pkt), Err(DnsParseError::LabelOverflow));
    }

    #[test]
    fn rejects_reserved_label_type_bits_as_overflow() {
        // Length octets with only one of the two top bits set are not
        // pointers; they read as lengths above 63.
        for len_byte in [0x40, 0x80] {
            let mut pkt = header(1);
            pkt.push(len_byte);
            pkt.extend_from_slice(&[0, 1, 0, 1]);
            assert_eq!(parse_query(&pkt), Err(DnsParseError::LabelOverflow));
        }
    }

    #[test]
    fn accepts_label_at_max_length() {
        let label = "b".repeat(MAX_LABEL_LEN);
        let v = parse_query(&build_query(&label, 1, 1)).expect("parse ok");
        assert_eq!(v.qname, label);
    }

    #[test]
    fn rejects_oversized_qname() {
        let mut pkt = header(1);
        for _ in 0..6 {
            pkt.push(50);
            pkt.extend(std::iter::repeat_n(b'a', 50));
        }
        pkt.push(0);
        pkt.extend_from_slice(&[0, 1, 0, 1]);
        assert_eq!(parse_query(&pkt), Err(DnsParseError::NameOverflow));
    }

    #[test]
    fn accepts_qname_at_max_length() {
        // 63 + 63 + 63 + 61 label octets plus 3 dots = 253.
        let name = [63, 63, 63, 61].map(|n| "a".repeat(n)).join(".");
        assert_eq!(name.len(), MAX_QNAME_LEN);
        let v = parse_query(&build_query(&name, 1, 1)).expect("parse ok");
        assert_eq!(v.qname, name);
    }

    #[test]
    fn rejects_non_in_class() {
        let pkt = build_query("api.example.com", 1, 3);
        assert_eq!(parse_query(&pkt), Err(DnsParseError::UnsupportedClass(3)));
    }

    #[test]
    fn rejects_invalid_label_byte() {
        let mut pkt = header(1);
        pkt.push(3);
        pkt.extend_from_slice(&[b'a', 0x00, b'b']);
        pkt.push(0);
        pkt.extend_from_slice(&[0, 1, 0, 1]);
        assert_eq!(parse_query(&pkt), Err(DnsParseError::InvalidLabelByte(0x00)));
    }

    #[test]
    fn parses_root_only_query_as_empty_name() {
        let mut pkt = header(1);
        pkt.push(0);
        pkt.extend_from_slice(&[0, 1, 0, 1]);
        let v = parse_query(&pkt).expect("parse ok");
        assert_eq!(v.qname, "");
    }

    #[test]
    fn parse_query_multi_returns_three_distinct_outcomes_for_qdcount_three() {
        let pkt = build_multi_query(&[
            ("api.example.com", 1, 1),
            ("svc.example.com", 28, 1),
            ("alt.example.com", 65, 1),
        ]);
        let outcomes = parse_query_multi(&pkt).expect("parse ok");
        assert_eq!(outcomes.len(), 3);
        let v0 = outcomes[0].as_view().expect("variant carries view");
        let v1 = outcomes[1].as_view().expect("variant carries view");
        let v2 = outcomes[2].as_view().expect("variant carries view");
        assert_eq!(v0.qname, "api.example.com");
        assert_eq!(v0.qtype, 1);
        assert_eq!(v1.qname, "svc.example.com");
        assert_eq!(v1.qtype, 28);
        assert_eq!(v2.qname, "alt.example.com");
        assert_eq!(v2.qtype, 65);
        assert_ne!(v0, v1);
        assert_ne!(v1, v2);
        assert_ne!(v0, v2);
    }

    #[test]
    fn parse_query_multi_truncation_in_second_question_returns_too_short() {
        let mut pkt = header(2);
        append_question(&mut pkt, "api.example.com", 1, 1);
        pkt.push(5);
        pkt.extend_from_slice(b"ab");
        assert_eq!(parse_query_multi(&pkt), Err(DnsParseError::TooShort));
    }

    #[test]
    fn parse_query_multi_huge_qdcount_on_short_packet_is_too_short() {
        let mut pkt = header(u16::MAX);
        append_question(&mut pkt, "api.example.com", 1, 1);
        assert_eq!(parse_query_multi(&pkt), Err(DnsParseError::TooShort));
    }

    #[test]
    fn parse_query_multi_unsupported_class_in_second_question() {
        let pkt = build_multi_query(&[("api.example.com", 1, 1), ("evil.example.com", 1, 3)]);
        assert_eq!(
            parse_query_multi(&pkt),
            Err(DnsParseError::UnsupportedClass(3))
        );
    }

    #[test]
    fn parse_query_and_parse_query_multi_agree_on_single_question_packets() {
        let pkt = build_query("api.example.com", 28, 1);
        let single = parse_query(&pkt).expect("single ok");
        let multi = parse_query_multi(&pkt).expect("multi ok");
        assert_eq!(multi.len(), 1);
        let v_multi = multi[0].as_view().expect("variant carries view");
        assert_eq!(&single, v_multi);
    }
}