use bytes::Bytes;
use crate::codec::{
Error,
header::{Header, Opcode},
name::Name,
reader::Reader,
writer::Writer,
};
/// Largest message, in bytes, that [`Query::parse`] accepts.
///
/// Equal to `u16::MAX` (65535); anything longer is rejected with
/// `Error::MessageTooLong` before the header is even looked at.
pub const MAX_MESSAGE_LEN: usize = u16::MAX as usize;
/// DNS question type (QTYPE).
///
/// `A`, `AAAA` and `PTR` have dedicated variants; every other value is kept
/// verbatim in [`Qtype::Other`] so nothing is lost on a round-trip through
/// `u16`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Qtype {
    /// `A` record, wire value 1.
    A,
    /// `AAAA` record, wire value 28.
    Aaaa,
    /// `PTR` record, wire value 12.
    Ptr,
    /// Any other type, carrying its raw wire value.
    ///
    /// `Other(1)`, `Other(12)` and `Other(28)` are representable (and compare
    /// unequal to the dedicated variants), but `From<u16>` never produces them.
    Other(u16),
}
impl From<u16> for Qtype {
    fn from(v: u16) -> Self {
        match v {
            1 => Self::A,
            28 => Self::Aaaa,
            12 => Self::Ptr,
            other => Self::Other(other),
        }
    }
}
impl From<Qtype> for u16 {
    fn from(qt: Qtype) -> u16 {
        match qt {
            Qtype::A => 1,
            Qtype::Aaaa => 28,
            Qtype::Ptr => 12,
            Qtype::Other(v) => v,
        }
    }
}
impl Qtype {
    /// Returns the mnemonic for a QTYPE value this module knows, or `None`.
    ///
    /// The values that have dedicated variants (1, 12, 28) are listed too, so
    /// that a hand-built `Qtype::Other(1)` renders as `A` instead of `TYPE1`.
    fn well_known_name(value: u16) -> Option<&'static str> {
        Some(match value {
            1 => "A",
            2 => "NS",
            5 => "CNAME",
            6 => "SOA",
            12 => "PTR",
            15 => "MX",
            16 => "TXT",
            28 => "AAAA",
            33 => "SRV",
            35 => "NAPTR",
            39 => "DNAME",
            43 => "DS",
            46 => "RRSIG",
            47 => "NSEC",
            48 => "DNSKEY",
            50 => "NSEC3",
            52 => "TLSA",
            64 => "SVCB",
            65 => "HTTPS",
            251 => "IXFR",
            252 => "AXFR",
            255 => "ANY",
            257 => "CAA",
            _ => return None,
        })
    }
}
impl std::fmt::Display for Qtype {
    /// Writes the mnemonic for known types and the RFC 3597 generic form
    /// `TYPEnnn` for everything else.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Go through the wire value so equal wire values always render equally,
        // whichever variant carries them.
        let value = u16::from(*self);
        match Self::well_known_name(value) {
            Some(name) => f.write_str(name),
            None => write!(f, "TYPE{value}"),
        }
    }
}
/// DNS question class (QCLASS).
///
/// Only the Internet class has a dedicated variant; every other value is
/// preserved unchanged in [`Qclass::Other`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Qclass {
    /// The Internet class, wire value 1.
    In,
    /// Any other class, carrying its raw wire value.
    ///
    /// `From<u16>` maps 1 to [`Qclass::In`], never to `Other(1)`.
    Other(u16),
}
impl From<u16> for Qclass {
    /// Maps wire value 1 to [`Qclass::In`] and keeps every other value verbatim.
    fn from(raw: u16) -> Self {
        if raw == 1 {
            Self::In
        } else {
            Self::Other(raw)
        }
    }
}
impl From<Qclass> for u16 {
    /// Recovers the wire value of a class.
    fn from(class: Qclass) -> u16 {
        match class {
            Qclass::Other(raw) => raw,
            Qclass::In => 1,
        }
    }
}
/// One entry of a DNS question section: the queried name, type and class.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Question {
    /// Queried domain name (QNAME).
    pub name: Name,
    /// Queried record type (QTYPE).
    pub qtype: Qtype,
    /// Queried class (QCLASS).
    pub qclass: Qclass,
}
impl Question {
    /// Decodes a question starting at `reader`'s current position.
    ///
    /// The name is decoded with `Name::read_question`, followed by the 16-bit
    /// QTYPE and QCLASS fields.
    ///
    /// # Errors
    ///
    /// Propagates any error from reading the name or either 16-bit field.
    pub fn read(reader: &mut Reader) -> Result<Self, Error> {
        // Struct-literal fields are evaluated in the order written, so this
        // consumes QNAME, then QTYPE, then QCLASS, matching the wire layout.
        Ok(Self {
            name: Name::read_question(reader)?,
            qtype: reader.read_u16()?.into(),
            qclass: reader.read_u16()?.into(),
        })
    }

    /// Encodes this question onto `writer` in wire order: QNAME, QTYPE, QCLASS.
    pub fn write(&self, writer: &mut Writer) {
        let Self {
            name,
            qtype,
            qclass,
        } = self;
        name.write(writer);
        for field in [u16::from(*qtype), u16::from(*qclass)] {
            writer.write_u16(field);
        }
    }
}
#[derive(Debug)]
pub struct ParseError {
pub id: Option<u16>,
pub kind: Error,
}
impl ParseError {
fn without_id(kind: Error) -> Self {
Self { id: None, kind }
}
fn with_id(id: u16, kind: Error) -> Self {
Self { id: Some(id), kind }
}
#[must_use]
pub fn reject_action(&self) -> RejectAction {
match (self.id, &self.kind) {
(None, _) => RejectAction::Drop,
(Some(_), Error::NotARequest) => RejectAction::Drop,
(Some(id), Error::UnsupportedOpcode(_)) => RejectAction::NotImp(id),
(Some(id), _) => RejectAction::FormErr(id),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RejectAction {
Drop,
FormErr(u16),
NotImp(u16),
}
impl std::fmt::Display for ParseError {
    /// Formats as `DNS parse error (id=0x1234): <kind>`, using `id=unknown`
    /// when no transaction ID is available.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let Some(id) = self.id {
            write!(f, "DNS parse error (id={id:#06x}): {}", self.kind)
        } else {
            write!(f, "DNS parse error (id=unknown): {}", self.kind)
        }
    }
}
impl std::error::Error for ParseError {
    /// Exposes the wrapped codec error as the cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = &self.kind;
        Some(cause)
    }
}
#[derive(Debug, Clone)]
pub struct Query {
raw: Bytes,
header: Header,
question: Question,
question_end: usize,
}
impl Query {
pub fn parse(raw: Bytes) -> Result<Self, ParseError> {
if raw.len() > MAX_MESSAGE_LEN {
return Err(ParseError::without_id(Error::MessageTooLong(raw.len())));
}
let mut reader = Reader::new(raw.clone());
let header = Header::read(&mut reader).map_err(ParseError::without_id)?;
if header.qr() {
return Err(ParseError::with_id(header.id, Error::NotARequest));
}
if header.opcode() != Opcode::Query {
return Err(ParseError::with_id(
header.id,
Error::UnsupportedOpcode(u8::from(header.opcode())),
));
}
if header.qdcount != 1 {
return Err(ParseError::with_id(
header.id,
Error::InvalidQuestionCount(header.qdcount),
));
}
let question =
Question::read(&mut reader).map_err(|e| ParseError::with_id(header.id, e))?;
let question_end = reader.position();
Ok(Self {
raw,
header,
question,
question_end,
})
}
#[must_use]
pub fn raw(&self) -> &Bytes {
&self.raw
}
#[must_use]
pub fn header(&self) -> &Header {
&self.header
}
#[must_use]
pub fn question(&self) -> &Question {
&self.question
}
#[must_use]
pub fn question_end(&self) -> usize {
self.question_end
}
#[must_use]
pub fn question_wire(&self) -> Bytes {
self.raw.slice(12..self.question_end)
}
}
impl TryFrom<Bytes> for Query {
type Error = ParseError;
fn try_from(raw: Bytes) -> Result<Self, Self::Error> {
Query::parse(raw)
}
}
impl TryFrom<&[u8]> for Query {
type Error = ParseError;
fn try_from(raw: &[u8]) -> Result<Self, Self::Error> {
Query::parse(Bytes::copy_from_slice(raw))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::{header::Header, name::Name, writer::Writer};

    /// Builds a request for `name`. The question body is written only when the
    /// effective QDCOUNT is 1, so any other override yields a header-only message.
    fn build_query(
        id: u16,
        rd: bool,
        name: &str,
        qtype: u16,
        qclass: u16,
        qdcount_override: Option<u16>,
    ) -> Bytes {
        let mut w = Writer::with_capacity(64);
        let qdcount = qdcount_override.unwrap_or(1);
        let hdr = Header::new(id).with_rd(rd).with_qdcount(qdcount);
        hdr.write(&mut w);
        if qdcount == 1 {
            let n: Name = name.parse().expect("valid name in test helper");
            n.write(&mut w);
            w.write_u16(qtype);
            w.write_u16(qclass);
        }
        w.finish()
    }

    /// Builds a message whose header claims `qdcount` questions but which always
    /// carries exactly one `name` / A / IN question body.
    fn build_query_with_bad_qdcount(id: u16, name: &str, qdcount: u16) -> Bytes {
        let mut w = Writer::with_capacity(64);
        let hdr = Header::new(id).with_qdcount(qdcount);
        hdr.write(&mut w);
        let n: Name = name.parse().unwrap();
        n.write(&mut w);
        w.write_u16(1u16); // QTYPE = A
        w.write_u16(1u16); // QCLASS = IN
        w.finish()
    }

    #[test]
    fn qtype_a_round_trips() {
        assert_eq!(Qtype::from(1u16), Qtype::A);
        assert_eq!(u16::from(Qtype::A), 1u16);
    }

    #[test]
    fn qtype_aaaa_round_trips() {
        assert_eq!(Qtype::from(28u16), Qtype::Aaaa);
        assert_eq!(u16::from(Qtype::Aaaa), 28u16);
    }

    #[test]
    fn qtype_ptr_round_trips_and_displays() {
        assert_eq!(Qtype::from(12u16), Qtype::Ptr);
        assert_eq!(u16::from(Qtype::Ptr), 12u16);
        assert_eq!(Qtype::Ptr.to_string(), "PTR");
        assert_ne!(Qtype::from(12u16), Qtype::Other(12));
    }

    #[test]
    fn qtype_other_preserved() {
        assert_eq!(Qtype::from(255u16), Qtype::Other(255));
        assert_eq!(u16::from(Qtype::Other(255)), 255u16);
        assert_eq!(Qtype::from(15u16), Qtype::Other(15));
        assert_eq!(u16::from(Qtype::Other(15)), 15u16);
    }

    #[test]
    fn qtype_display_uses_mnemonic_and_rfc3597_generic() {
        assert_eq!(Qtype::A.to_string(), "A");
        assert_eq!(Qtype::Aaaa.to_string(), "AAAA");
        assert_eq!(Qtype::Other(65).to_string(), "HTTPS");
        assert_eq!(Qtype::Other(64).to_string(), "SVCB");
        assert_eq!(Qtype::Other(5).to_string(), "CNAME");
        assert_eq!(Qtype::Other(15).to_string(), "MX");
        assert_eq!(Qtype::Other(1000).to_string(), "TYPE1000");
    }

    #[test]
    fn qtype_all_u16_round_trip() {
        for v in 0u16..=65535 {
            let qt = Qtype::from(v);
            let back = u16::from(qt);
            assert_eq!(back, v, "Qtype u16 round-trip failed for {v}");
        }
    }

    #[test]
    fn qclass_in_round_trips() {
        assert_eq!(Qclass::from(1u16), Qclass::In);
        assert_eq!(u16::from(Qclass::In), 1u16);
    }

    #[test]
    fn qclass_other_preserved() {
        assert_eq!(Qclass::from(3u16), Qclass::Other(3));
        assert_eq!(u16::from(Qclass::Other(3)), 3u16);
    }

    #[test]
    fn qclass_all_u16_round_trip() {
        for v in 0u16..=65535 {
            let qc = Qclass::from(v);
            let back = u16::from(qc);
            assert_eq!(back, v, "Qclass u16 round-trip failed for {v}");
        }
    }

    #[test]
    fn parse_valid_a_query() {
        let raw = build_query(0x1234, true, "example.com", 1, 1, None);
        let q = Query::try_from(raw).expect("valid A query should parse");
        assert_eq!(q.header().id, 0x1234, "id mismatch");
        assert!(!q.header().qr(), "QR should be 0 (query)");
        assert!(q.header().rd(), "RD should be set");
        assert_eq!(q.header().qdcount, 1);
        let question = q.question();
        assert_eq!(question.name.to_string(), "example.com.", "QNAME mismatch");
        assert_eq!(question.qtype, Qtype::A, "QTYPE should be A");
        assert_eq!(question.qclass, Qclass::In, "QCLASS should be IN");
    }

    #[test]
    fn parse_valid_aaaa_query() {
        let raw = build_query(0xABCD, true, "example.com", 28, 1, None);
        let q = Query::try_from(raw).expect("valid AAAA query should parse");
        assert_eq!(q.header().id, 0xABCD);
        assert!(!q.header().qr());
        assert!(q.header().rd());
        assert_eq!(q.question().qtype, Qtype::Aaaa, "QTYPE should be AAAA");
        assert_eq!(q.question().qclass, Qclass::In);
        assert_eq!(q.question().name.to_string(), "example.com.");
    }

    #[test]
    fn parse_query_with_trailing_bytes_accepted() {
        let mut raw = build_query(0x0001, false, "example.com", 1, 1, None).to_vec();
        // Minimal EDNS(0) OPT pseudo-RR: root owner name, TYPE=41 (OPT),
        // CLASS=0x1000, TTL=0, RDLENGTH=0.
        raw.extend_from_slice(&[
            0x00, 0x00, 0x29, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ]);
        let bytes = Bytes::from(raw);
        let q = Query::try_from(bytes.clone()).expect("trailing bytes must not cause error");
        assert_eq!(
            q.raw().len(),
            bytes.len(),
            "raw must hold the full datagram"
        );
        assert_eq!(q.question().name.to_string(), "example.com.");
        assert_eq!(q.question().qtype, Qtype::A);
    }

    #[test]
    fn parse_raw_field_is_full_datagram() {
        let raw = build_query(0x5678, true, "test.example", 1, 1, None);
        let expected_len = raw.len();
        let q = Query::try_from(raw).unwrap();
        assert_eq!(q.raw().len(), expected_len);
    }

    #[test]
    fn question_wire_matches_reencoded_question_and_excludes_trailing_bytes() {
        let mut raw = build_query(0x2468, true, "example.com", 28, 1, None).to_vec();
        // build_query emits exactly header + question, so this is where the question ends.
        let expected_end = raw.len();
        raw.extend_from_slice(&[0xDE, 0xAD]);
        let q = Query::try_from(raw.as_slice()).expect("query with trailing bytes must parse");
        assert_eq!(q.question_end(), expected_end);
        let mut w = Writer::new();
        q.question().write(&mut w);
        assert_eq!(q.question_wire(), w.finish());
    }

    #[test]
    fn qdcount_zero_rejected_with_id() {
        let raw = build_query_with_bad_qdcount(0x1111, "example.com", 0);
        let err = Query::try_from(raw).expect_err("QDCOUNT=0 must fail");
        assert!(
            matches!(err.kind, Error::InvalidQuestionCount(0)),
            "unexpected error kind: {:?}",
            err.kind
        );
        assert_eq!(err.id, Some(0x1111), "id must be Some when header was read");
    }

    #[test]
    fn qdcount_two_rejected_with_id() {
        let raw = build_query_with_bad_qdcount(0x2222, "example.com", 2);
        let err = Query::try_from(raw).expect_err("QDCOUNT=2 must fail");
        assert!(
            matches!(err.kind, Error::InvalidQuestionCount(2)),
            "unexpected error kind: {:?}",
            err.kind
        );
        assert_eq!(err.id, Some(0x2222));
    }

    /// Builds a one-question `name` / A / IN message with the given QR bit and opcode.
    fn build_kind(id: u16, qr: bool, opcode: Opcode, name: &str) -> Bytes {
        let mut w = Writer::with_capacity(64);
        Header::new(id)
            .with_qr(qr)
            .with_opcode(opcode)
            .with_qdcount(1)
            .write(&mut w);
        let n: Name = name.parse().unwrap();
        n.write(&mut w);
        w.write_u16(1); // QTYPE = A
        w.write_u16(1); // QCLASS = IN
        w.finish()
    }

    #[test]
    fn response_packet_rejected_as_not_a_request() {
        let raw = build_kind(0x3333, true, Opcode::Query, "example.com");
        let err = Query::try_from(raw).expect_err("QR=1 must be rejected");
        assert!(
            matches!(err.kind, Error::NotARequest),
            "unexpected error kind: {:?}",
            err.kind
        );
        assert_eq!(err.reject_action(), RejectAction::Drop);
    }

    #[test]
    fn unsupported_opcode_rejected_with_notimp() {
        let raw = build_kind(0x4444, false, Opcode::Other(5), "example.com");
        let err = Query::try_from(raw).expect_err("non-QUERY opcode must be rejected");
        assert!(
            matches!(err.kind, Error::UnsupportedOpcode(5)),
            "unexpected error kind: {:?}",
            err.kind
        );
        assert_eq!(err.reject_action(), RejectAction::NotImp(0x4444));
    }

    #[test]
    fn standard_query_opcode_still_parses() {
        let raw = build_kind(0x5555, false, Opcode::Query, "example.com");
        let q = Query::try_from(raw).expect("standard query must parse");
        assert_eq!(q.header().id, 0x5555);
    }

    #[test]
    fn malformed_query_still_maps_to_formerr() {
        let raw = build_query_with_bad_qdcount(0x6666, "example.com", 0);
        let err = Query::try_from(raw).expect_err("QDCOUNT=0 must fail");
        assert_eq!(err.reject_action(), RejectAction::FormErr(0x6666));
    }

    #[test]
    fn compression_pointer_in_question_rejected_with_id() {
        let mut w = Writer::with_capacity(16);
        Header::new(0x3333).with_qdcount(1).write(&mut w);
        // Compression pointer (0b11 prefix) to offset 12, i.e. the question itself.
        w.write_u8(0xC0);
        w.write_u8(0x0C);
        let raw = w.finish();
        let err = Query::try_from(raw).expect_err("compression pointer must fail");
        assert!(
            matches!(err.kind, Error::CompressionPointerInQuestion),
            "unexpected error kind: {:?}",
            err.kind
        );
        assert_eq!(err.id, Some(0x3333), "id must be Some");
    }

    #[test]
    fn truncated_question_name_rejected_with_id() {
        let mut w = Writer::with_capacity(16);
        Header::new(0x4444).with_qdcount(1).write(&mut w);
        // Label length byte promising 7 bytes that never arrive.
        w.write_u8(7);
        let raw = w.finish();
        let err = Query::try_from(raw).expect_err("truncated question must fail");
        assert!(
            matches!(err.kind, Error::UnexpectedEof { .. }),
            "unexpected error kind: {:?}",
            err.kind
        );
        assert_eq!(err.id, Some(0x4444), "id must be Some");
    }

    #[test]
    fn truncated_question_qtype_rejected_with_id() {
        let mut w = Writer::with_capacity(32);
        Header::new(0x5555).with_qdcount(1).write(&mut w);
        let name: Name = "example.com".parse().unwrap();
        name.write(&mut w);
        // Only one of QTYPE's two bytes.
        w.write_u8(0x00);
        let raw = w.finish();
        let err = Query::try_from(raw).expect_err("truncated QTYPE must fail");
        assert!(
            matches!(err.kind, Error::UnexpectedEof { .. }),
            "unexpected error kind: {:?}",
            err.kind
        );
        assert_eq!(err.id, Some(0x5555));
    }

    #[test]
    fn truncated_question_qclass_rejected_with_id() {
        let mut w = Writer::with_capacity(32);
        Header::new(0x6666).with_qdcount(1).write(&mut w);
        let name: Name = "example.com".parse().unwrap();
        name.write(&mut w);
        // Complete QTYPE, then only one of QCLASS's two bytes.
        w.write_u16(1u16);
        w.write_u8(0x00);
        let raw = w.finish();
        let err = Query::try_from(raw).expect_err("truncated QCLASS must fail");
        assert!(
            matches!(err.kind, Error::UnexpectedEof { .. }),
            "unexpected error kind: {:?}",
            err.kind
        );
        assert_eq!(err.id, Some(0x6666));
    }

    #[test]
    fn label_too_long_in_question_rejected_with_id() {
        let mut w = Writer::with_capacity(32);
        Header::new(0x7777).with_qdcount(1).write(&mut w);
        // A 64-byte label: one more than the 63-byte maximum.
        w.write_u8(64);
        let label_bytes = vec![b'a'; 64];
        w.write_slice(&label_bytes);
        // Root label terminating the name.
        w.write_u8(0);
        let raw = w.finish();
        let err = Query::try_from(raw).expect_err("label too long must fail");
        assert!(
            matches!(err.kind, Error::LabelTooLong(64)),
            "unexpected error kind: {:?}",
            err.kind
        );
        assert_eq!(err.id, Some(0x7777));
    }

    #[test]
    fn buffer_shorter_than_12_bytes_id_is_none() {
        for n in 0..12usize {
            let raw = Bytes::from(vec![0xAAu8; n]);
            let err = Query::try_from(raw).expect_err("short buffer must fail");
            assert!(
                matches!(err.kind, Error::MessageTooShort(_)),
                "n={n}: unexpected error kind: {:?}",
                err.kind
            );
            assert_eq!(
                err.id, None,
                "n={n}: id must be None when header is unreadable"
            );
            assert_eq!(
                err.reject_action(),
                RejectAction::Drop,
                "n={n}: an error without an id can only be dropped"
            );
        }
    }

    #[test]
    fn oversized_buffer_rejected_with_id_none() {
        let base = build_query(0x9999, false, "example.com", 1, 1, None);
        let mut raw = base.to_vec();
        raw.resize(65536, 0u8);
        let err = Query::try_from(raw.as_slice()).expect_err("oversized buffer must fail");
        assert!(
            matches!(err.kind, Error::MessageTooLong(65536)),
            "unexpected error kind: {:?}",
            err.kind
        );
        assert_eq!(err.id, None, "id must be None for oversized message");
        assert_eq!(err.reject_action(), RejectAction::Drop);
    }

    #[test]
    fn message_at_max_len_accepted() {
        let base = build_query(0x0001, false, "example.com", 1, 1, None);
        let mut raw = base.to_vec();
        raw.resize(65535, 0u8);
        // The zero padding is trailing data after the question, which parse ignores,
        // so a valid query padded to exactly the limit must succeed outright.
        let q = Query::try_from(raw.as_slice())
            .expect("a valid 65535-byte message must parse");
        assert_eq!(q.raw().len(), 65535);
        assert_eq!(q.question().name.to_string(), "example.com.");
    }

    #[test]
    fn no_panic_empty_input() {
        let _ = Query::try_from(Bytes::new());
    }

    #[test]
    fn no_panic_all_zeros_12_bytes() {
        let raw = Bytes::from(vec![0u8; 12]);
        let result = Query::try_from(raw);
        assert!(result.is_err());
    }

    #[test]
    fn no_panic_all_zeros_100_bytes() {
        let raw = Bytes::from(vec![0u8; 100]);
        let _ = Query::try_from(raw);
    }

    #[test]
    fn no_panic_all_ones_100_bytes() {
        let data = vec![0xFFu8; 100];
        let _ = Query::try_from(data.as_slice());
    }

    #[test]
    fn no_panic_random_ish_bytes() {
        let data: Vec<u8> = (0u8..=255).cycle().take(512).collect();
        let _ = Query::try_from(data.as_slice());
    }

    #[test]
    fn round_trip_a_query() {
        let id = 0xBEEF;
        let name_str = "www.example.com";
        let qtype_val = 1u16; // A
        let qclass_val = 1u16; // IN
        let mut w = Writer::with_capacity(64);
        let hdr = Header::new(id).with_qdcount(1).with_rd(true);
        hdr.write(&mut w);
        let name: Name = name_str.parse().unwrap();
        name.write(&mut w);
        w.write_u16(qtype_val);
        w.write_u16(qclass_val);
        let raw = w.finish();
        let q = Query::try_from(raw).expect("round-trip must succeed");
        assert_eq!(q.header().id, id);
        assert_eq!(q.header().qdcount, 1);
        assert!(q.header().rd());
        assert!(!q.header().qr());
        assert_eq!(q.question().name.to_string(), "www.example.com.");
        assert_eq!(q.question().qtype, Qtype::A);
        assert_eq!(q.question().qclass, Qclass::In);
    }

    #[test]
    fn round_trip_aaaa_query() {
        let mut w = Writer::with_capacity(64);
        let hdr = Header::new(0x1111).with_qdcount(1).with_rd(true);
        hdr.write(&mut w);
        let name: Name = "ipv6.example.com".parse().unwrap();
        name.write(&mut w);
        w.write_u16(28u16); // QTYPE = AAAA
        w.write_u16(1u16); // QCLASS = IN
        let raw = w.finish();
        let q = Query::try_from(raw).unwrap();
        assert_eq!(q.header().id, 0x1111);
        assert_eq!(q.question().qtype, Qtype::Aaaa);
        assert_eq!(q.question().name.to_string(), "ipv6.example.com.");
    }

    #[test]
    fn question_write_read_round_trip() {
        let original = Question {
            name: "sub.domain.test".parse().unwrap(),
            qtype: Qtype::Aaaa,
            qclass: Qclass::In,
        };
        let mut w = Writer::new();
        original.write(&mut w);
        let bytes = w.finish();
        let mut reader = Reader::new(bytes);
        let decoded = Question::read(&mut reader).unwrap();
        assert_eq!(decoded.name, original.name);
        assert_eq!(decoded.qtype, original.qtype);
        assert_eq!(decoded.qclass, original.qclass);
    }

    #[test]
    fn parse_error_display_with_id() {
        let e = ParseError::with_id(0xABCD, Error::InvalidQuestionCount(0));
        let s = e.to_string();
        assert!(
            s.contains("0xabcd") || s.contains("0xABCD") || s.contains("abcd"),
            "display should include id: {s}"
        );
    }

    #[test]
    fn parse_error_display_without_id() {
        let e = ParseError::without_id(Error::MessageTooShort(5));
        let s = e.to_string();
        assert!(
            s.contains("unknown"),
            "display should indicate unknown id: {s}"
        );
    }

    #[test]
    fn try_from_slice_copies_and_parses() {
        let raw = build_query(0x1234, true, "slice.test", 1, 1, None);
        let slice: &[u8] = &raw[..];
        let q = Query::try_from(slice).expect("TryFrom<&[u8]> must work");
        assert_eq!(q.header().id, 0x1234);
        assert_eq!(q.question().name.to_string(), "slice.test.");
    }
}