use std::net::{Ipv4Addr, Ipv6Addr};
use bytes::Bytes;
use crate::codec::{
header::{Header, Rcode},
message::{Qtype, Query},
name::Name,
reader::Reader,
ttl::OPT_TYPE,
writer::Writer,
};
/// UDP payload size this server advertises in the CLASS field of every OPT
/// record it emits.
///
/// NOTE(review): 1232 matches the DNS Flag Day 2020 recommendation and was
/// presumably chosen to avoid IP fragmentation — confirm.
pub const SERVER_UDP_PAYLOAD_SIZE: u16 = 1232;

/// EDNS option code of the COOKIE option.
const EDNS_OPTION_COOKIE: u16 = 10;

/// Owner name used for every synthesised answer: a compression pointer (top
/// two bits set) to offset 12, where the echoed question name starts right
/// after the 12-byte header.
const OWNER_PTR: [u8; 2] = (0xC000_u16 | 12).to_be_bytes();

/// CLASS value written into every synthesised answer record (IN).
const CLASS_IN: u16 = 1;
#[derive(Debug, Clone)]
pub struct EdnsInfo {
pub udp_payload_size: u16,
pub cookie: Option<Bytes>,
}
impl EdnsInfo {
#[must_use]
pub fn scan(query: &Query) -> Option<Self> {
Self::scan_inner(query)
}
fn scan_inner(query: &Query) -> Option<Self> {
let raw = query.raw();
let mut reader = Reader::new(raw.clone());
let header = Header::read(&mut reader).ok()?;
for _ in 0..header.qdcount {
Name::skip_rr(&mut reader).ok()?;
reader.read_u16().ok()?; reader.read_u16().ok()?; }
let an_ns_count = (header.ancount as usize).saturating_add(header.nscount as usize);
for _ in 0..an_ns_count {
Name::skip_rr(&mut reader).ok()?;
reader.read_u16().ok()?; reader.read_u16().ok()?; reader.read_u32().ok()?; let rdlength = reader.read_u16().ok()? as usize;
reader.read_slice(rdlength).ok()?;
}
for _ in 0..header.arcount {
Name::skip_rr(&mut reader).ok()?;
let rr_type = reader.read_u16().ok()?;
let rr_class = reader.read_u16().ok()?; reader.read_u32().ok()?; let rdlength = reader.read_u16().ok()? as usize;
let rdata = reader.read_slice(rdlength).ok()?;
if rr_type == OPT_TYPE {
let udp_payload_size = rr_class;
let cookie = Self::extract_cookie(&rdata);
return Some(Self {
udp_payload_size,
cookie,
});
}
}
None
}
fn extract_cookie(rdata: &Bytes) -> Option<Bytes> {
let mut pos = 0usize;
let data = rdata;
while pos + 4 <= data.len() {
let code = u16::from_be_bytes([data[pos], data[pos + 1]]);
let length = u16::from_be_bytes([data[pos + 2], data[pos + 3]]) as usize;
pos += 4;
let end = pos.checked_add(length)?;
if end > data.len() {
return None; }
if code == EDNS_OPTION_COOKIE {
return Some(data.slice(pos..end));
}
pos = end;
}
None
}
}
/// How a query for a blocked name is answered.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BlockMode {
    /// Answer NXDOMAIN.
    NxDomain,
    /// Answer with a fixed address: `v4` for A queries, `v6` for AAAA queries.
    /// Other query types receive NOERROR with no answers.
    Address {
        /// Address returned for A queries.
        v4: Ipv4Addr,
        /// Address returned for AAAA queries.
        v6: Ipv6Addr,
    },
}

impl BlockMode {
    /// The "null IP" mode: [`BlockMode::Address`] with the unspecified
    /// addresses `0.0.0.0` and `::`.
    #[must_use]
    pub fn null_ip() -> Self {
        let v4 = Ipv4Addr::UNSPECIFIED;
        let v6 = Ipv6Addr::UNSPECIFIED;
        Self::Address { v4, v6 }
    }
}
/// Builds the DNS responses this server answers on its own: blocked names,
/// local records, errors and truncation notices.
///
/// Every response built from a query does the following:
/// * echoes the query ID, the RD bit and the question section byte-for-byte,
///   so DNS 0x20 case randomisation survives;
/// * sets QR and RA;
/// * ends with an OPT record when EDNS information is supplied.
pub struct Response;

impl Response {
    /// TYPE code of an A record.
    const TYPE_A: u16 = 1;
    /// TYPE code of an AAAA record.
    const TYPE_AAAA: u16 = 28;

    /// Builds the response for a query whose name is blocked.
    ///
    /// * [`BlockMode::NxDomain`]: NXDOMAIN with no answers.
    /// * [`BlockMode::Address`]:
    ///   * A/AAAA queries get NOERROR with one record carrying the configured
    ///     address and `ttl`;
    ///   * any other query type gets NOERROR with no answers (NODATA).
    #[must_use]
    pub fn block(query: &Query, mode: &BlockMode, ttl: u32, edns: Option<&EdnsInfo>) -> Bytes {
        match mode {
            BlockMode::NxDomain => Self::build(query, Rcode::NxDomain, false, &[], edns),
            BlockMode::Address { v4, v6 } => match query.question().qtype {
                Qtype::A => Self::address_answer(query, Self::TYPE_A, ttl, &v4.octets(), edns),
                Qtype::Aaaa => {
                    Self::address_answer(query, Self::TYPE_AAAA, ttl, &v6.octets(), edns)
                }
                Qtype::Other(_) => Self::build(query, Rcode::NoError, false, &[], edns),
            },
        }
    }

    /// Builds an authoritative (AA=1) NOERROR response carrying `records`.
    /// Each record is owned by the question name and uses the given `ttl`.
    #[must_use]
    pub fn local(
        query: &Query,
        records: &[LocalRecord<'_>],
        ttl: u32,
        edns: Option<&EdnsInfo>,
    ) -> Bytes {
        let answers: Vec<AnswerRr<'_>> = records
            .iter()
            .map(|r| AnswerRr {
                rtype: r.rtype,
                ttl,
                rdata: r.rdata,
            })
            .collect();
        Self::build_authoritative(query, Rcode::NoError, &answers, edns)
    }

    /// Builds an authoritative (AA=1) NOERROR response with no answers.
    #[must_use]
    pub fn local_nodata(query: &Query, edns: Option<&EdnsInfo>) -> Bytes {
        Self::build_authoritative(query, Rcode::NoError, &[], edns)
    }

    /// Builds a non-authoritative response with `rcode` and no answers
    /// (e.g. SERVFAIL or REFUSED).
    #[must_use]
    pub fn error_response(query: &Query, rcode: Rcode, edns: Option<&EdnsInfo>) -> Bytes {
        Self::build(query, rcode, false, &[], edns)
    }

    /// Builds a header-only FORMERR response from just a message ID.
    ///
    /// Only `id`, QR and the FORMERR rcode are set, and all section counts
    /// are zero, so the result is exactly 12 bytes.
    #[must_use]
    pub fn formerr(id: u16) -> Bytes {
        let mut w = Writer::with_capacity(12);
        Header::new(id)
            .with_qr(true)
            .with_rcode(Rcode::FormErr)
            .write(&mut w);
        w.finish()
    }

    /// Builds a TC=1 NOERROR response that echoes the question and carries
    /// no answers.
    #[must_use]
    pub fn truncated(query: &Query, edns: Option<&EdnsInfo>) -> Bytes {
        Self::assemble(query, Rcode::NoError, false, true, &[], edns)
    }

    /// Non-authoritative NOERROR response carrying exactly one answer record.
    fn address_answer(
        query: &Query,
        rtype: u16,
        ttl: u32,
        rdata: &[u8],
        edns: Option<&EdnsInfo>,
    ) -> Bytes {
        Self::build(query, Rcode::NoError, false, &[AnswerRr { rtype, ttl, rdata }], edns)
    }

    /// Non-truncated response; `aa` selects the Authoritative Answer bit.
    fn build(
        query: &Query,
        rcode: Rcode,
        aa: bool,
        answers: &[AnswerRr<'_>],
        edns: Option<&EdnsInfo>,
    ) -> Bytes {
        Self::assemble(query, rcode, aa, false, answers, edns)
    }

    /// Non-truncated response with the Authoritative Answer bit set.
    fn build_authoritative(
        query: &Query,
        rcode: Rcode,
        answers: &[AnswerRr<'_>],
        edns: Option<&EdnsInfo>,
    ) -> Bytes {
        Self::build(query, rcode, true, answers, edns)
    }

    /// Writes a complete response message, in order:
    /// 1. the header;
    /// 2. the question section, echoed verbatim;
    /// 3. the answer records;
    /// 4. an optional OPT record.
    ///
    /// Records whose RDATA is longer than the 16-bit RDLENGTH can describe
    /// are skipped. This avoids emitting a corrupt message, and ANCOUNT
    /// always equals the number of records actually written.
    fn assemble(
        query: &Query,
        rcode: Rcode,
        aa: bool,
        tc: bool,
        answers: &[AnswerRr<'_>],
        edns: Option<&EdnsInfo>,
    ) -> Bytes {
        let encodable = |rr: &&AnswerRr<'_>| u16::try_from(rr.rdata.len()).is_ok();
        let ancount =
            u16::try_from(answers.iter().filter(encodable).count()).unwrap_or(u16::MAX);
        let arcount = u16::from(edns.is_some());

        let mut w = Writer::with_capacity(512);
        let mut hdr = Header::new(query.header().id)
            .with_qr(true)
            .with_tc(tc)
            .with_rd(query.header().rd())
            .with_ra(true)
            .with_rcode(rcode)
            .with_qdcount(1)
            .with_ancount(ancount)
            .with_arcount(arcount);
        if aa {
            hdr.set_aa(true);
        }
        hdr.write(&mut w);
        w.write_slice(&query.question_wire());
        for rr in answers.iter().filter(encodable).take(usize::from(ancount)) {
            Self::write_answer_rr(&mut w, rr);
        }
        if let Some(edns) = edns {
            Self::write_opt(&mut w, edns);
        }
        w.finish()
    }

    /// Appends one answer record whose owner is the question name (written
    /// as a compression pointer) and whose class is IN.
    fn write_answer_rr(w: &mut Writer, rr: &AnswerRr<'_>) {
        let rdlength = u16::try_from(rr.rdata.len())
            .expect("assemble only passes records whose RDATA fits RDLENGTH");
        w.write_slice(&OWNER_PTR);
        w.write_u16(rr.rtype);
        w.write_u16(CLASS_IN);
        w.write_u32(rr.ttl);
        w.write_u16(rdlength);
        w.write_slice(rr.rdata);
    }

    /// Appends the OPT pseudo-record advertising [`SERVER_UDP_PAYLOAD_SIZE`].
    /// The client's COOKIE option is echoed back when one was received.
    ///
    /// NOTE(review): the TTL field (extended RCODE, version, flags) is always
    /// written as 0, so a DO bit set in the query is not copied back as
    /// RFC 3225 asks. `EdnsInfo` does not currently record that bit.
    fn write_opt(w: &mut Writer, edns: &EdnsInfo) {
        w.write_u8(0x00); // owner name: root
        w.write_u16(OPT_TYPE);
        w.write_u16(SERVER_UDP_PAYLOAD_SIZE); // CLASS carries the UDP payload size
        w.write_u32(0); // extended RCODE, version and flags all zero
        // Both the option length and the RDLENGTH around it (4 more bytes for
        // the option header) are 16-bit. An oversized cookie is therefore
        // dropped rather than wrapping the lengths and corrupting the message.
        let encoded = edns.cookie.as_ref().and_then(|cookie| {
            let option_len = u16::try_from(cookie.len()).ok()?;
            let rdlength = option_len.checked_add(4)?;
            Some((cookie, option_len, rdlength))
        });
        match encoded {
            Some((cookie, option_len, rdlength)) => {
                w.write_u16(rdlength);
                w.write_u16(EDNS_OPTION_COOKIE);
                w.write_u16(option_len);
                w.write_slice(cookie);
            }
            None => {
                w.write_u16(0); // RDLENGTH 0: no options
            }
        }
    }
}
/// One record to place in the answer section of a synthesised response.
/// Owner name and class are not stored here; the writer supplies them.
struct AnswerRr<'a> {
    /// Record TYPE code (e.g. 1 = A, 28 = AAAA).
    rtype: u16,
    /// TTL value written into the record.
    ttl: u32,
    /// Wire-format RDATA, copied verbatim.
    rdata: &'a [u8],
}
/// A locally configured record: its TYPE code and wire-format RDATA.
/// Owner name and TTL are supplied when the response is built.
pub struct LocalRecord<'a> {
    /// Record TYPE code (e.g. 1 = A, 28 = AAAA).
    pub rtype: u16,
    /// Wire-format RDATA, copied verbatim into the answer.
    pub rdata: &'a [u8],
}
// Unit tests: queries are assembled byte-by-byte on the wire, and the
// responses are inspected at the byte level.
#[cfg(test)]
mod tests {
use super::*;
use crate::codec::{
header::Header, message::Query, name::Name, reader::Reader, writer::Writer,
};
use bytes::Bytes;
use std::net::{Ipv4Addr, Ipv6Addr};
/// Builds a query with QDCOUNT=1, one question of class IN, and nothing else.
fn build_query(id: u16, rd: bool, name: &str, qtype: u16) -> Bytes {
build_query_raw(id, rd, name, qtype, &[])
}
/// Like `build_query`, but appends `extra` raw bytes after the question.
/// The header counts are not adjusted for `extra`.
fn build_query_raw(id: u16, rd: bool, name: &str, qtype: u16, extra: &[u8]) -> Bytes {
let mut w = Writer::with_capacity(128);
Header::new(id).with_rd(rd).with_qdcount(1).write(&mut w);
let n: Name = name.parse().expect("valid name in test helper");
n.write(&mut w);
w.write_u16(qtype);
w.write_u16(1u16); w.write_slice(extra); // QCLASS = IN, then the caller's trailing bytes
w.finish()
}
/// Builds a query with ARCOUNT=1 whose additional section is a single OPT
/// record. It advertises `udp_payload_size` and optionally carries a COOKIE
/// option holding `cookie`.
fn build_query_with_opt(
id: u16,
rd: bool,
name: &str,
qtype: u16,
udp_payload_size: u16,
cookie: Option<&[u8]>,
) -> Bytes {
let mut opt_bytes = Vec::new();
opt_bytes.push(0x00); opt_bytes.extend_from_slice(&41u16.to_be_bytes()); opt_bytes.extend_from_slice(&udp_payload_size.to_be_bytes()); opt_bytes.extend_from_slice(&0u32.to_be_bytes()); // owner = root, TYPE = 41 (OPT), CLASS = payload size, TTL = 0
if let Some(c) = cookie {
// RDLENGTH covers the 4-byte option header plus the cookie data.
let rdlength: u16 = 4 + c.len() as u16;
opt_bytes.extend_from_slice(&rdlength.to_be_bytes());
opt_bytes.extend_from_slice(&EDNS_OPTION_COOKIE.to_be_bytes());
opt_bytes.extend_from_slice(&(c.len() as u16).to_be_bytes());
opt_bytes.extend_from_slice(c);
} else {
opt_bytes.extend_from_slice(&0u16.to_be_bytes()); } // RDLENGTH = 0: no options
let mut w = Writer::with_capacity(128);
Header::new(id)
.with_rd(rd)
.with_qdcount(1)
.with_arcount(1)
.write(&mut w);
let n: Name = name.parse().expect("valid name in test helper");
n.write(&mut w);
w.write_u16(qtype);
w.write_u16(1u16); w.write_slice(&opt_bytes); // QCLASS = IN, then the OPT record
w.finish()
}
/// Parses the 12-byte header at the start of `resp`.
fn parse_response_header(resp: &Bytes) -> Header {
let mut r = Reader::new(resp.clone());
Header::read(&mut r).expect("valid response header")
}
/// Returns (TYPE, CLASS, TTL, RDATA) of the first answer record.
/// Assumes exactly one question and an answer owner that is a 2-byte
/// compression pointer.
fn read_first_answer(resp: &Bytes) -> (u16, u16, u32, Bytes) {
let mut r = Reader::new(resp.clone());
r.read_slice(12).unwrap();
Name::skip_rr(&mut r).unwrap();
r.read_u16().unwrap(); r.read_u16().unwrap(); let _ptr = r.read_u16().unwrap(); // QTYPE, QCLASS, then the owner pointer
let rtype = r.read_u16().unwrap();
let class = r.read_u16().unwrap();
let ttl = r.read_u32().unwrap();
let rdlength = r.read_u16().unwrap() as usize;
let rdata = r.read_slice(rdlength).unwrap();
(rtype, class, ttl, rdata)
}
/// Skips the header, the single question and all answer records. Then
/// returns (TYPE, CLASS, TTL, RDATA) of the first additional record whose
/// TYPE is OPT, or `None` if there is none.
fn read_opt_rr(resp: &Bytes) -> Option<(u16, u16, u32, Bytes)> {
let hdr = parse_response_header(resp);
let mut r = Reader::new(resp.clone());
r.read_slice(12).unwrap(); Name::skip_rr(&mut r).unwrap(); // header, then QNAME
r.read_u16().unwrap();
r.read_u16().unwrap();
for _ in 0..hdr.ancount {
Name::skip_rr(&mut r).unwrap();
r.read_u16().unwrap(); r.read_u16().unwrap(); r.read_u32().unwrap(); let rdlen = r.read_u16().unwrap() as usize; // TYPE, CLASS, TTL, RDLENGTH
r.read_slice(rdlen).unwrap();
}
for _ in 0..hdr.arcount {
Name::skip_rr(&mut r).unwrap();
let rtype = r.read_u16().unwrap();
let class = r.read_u16().unwrap();
let ttl = r.read_u32().unwrap();
let rdlen = r.read_u16().unwrap() as usize;
let rdata = r.read_slice(rdlen).unwrap();
if rtype == OPT_TYPE {
return Some((rtype, class, ttl, rdata));
}
}
None
}
#[test]
fn edns_scan_no_opt_returns_none() {
let raw = build_query(0x1234, true, "example.com", 1);
let query = Query::try_from(raw).unwrap();
assert!(EdnsInfo::scan(&query).is_none());
}
#[test]
fn edns_scan_opt_without_cookie() {
let raw = build_query_with_opt(0x1234, true, "example.com", 1, 4096, None);
let query = Query::try_from(raw).unwrap();
let info = EdnsInfo::scan(&query).expect("should find OPT");
assert_eq!(info.udp_payload_size, 4096);
assert!(info.cookie.is_none());
}
#[test]
fn edns_scan_opt_with_cookie() {
let cookie_bytes: &[u8] = &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];
let raw =
build_query_with_opt(0x5678, true, "ads.example.com", 1, 1232, Some(cookie_bytes));
let query = Query::try_from(raw).unwrap();
let info = EdnsInfo::scan(&query).expect("should find OPT with cookie");
assert_eq!(info.udp_payload_size, 1232);
let got_cookie = info.cookie.expect("cookie must be present");
assert_eq!(&got_cookie[..], cookie_bytes);
}
// ARCOUNT=1, but the additional record stops right after its owner name.
#[test]
fn edns_scan_no_panic_on_malformed_additional() {
let mut w = Writer::with_capacity(64);
Header::new(0x1111)
.with_rd(true)
.with_qdcount(1)
.with_arcount(1)
.write(&mut w);
let n: Name = "example.com".parse().unwrap();
n.write(&mut w);
w.write_u16(1u16);
w.write_u16(1u16);
w.write_u8(0x00); let raw = w.finish(); // root owner only; TYPE/CLASS/TTL/RDLENGTH missing
let query = Query::try_from(raw).unwrap();
let result = EdnsInfo::scan(&query);
assert!(result.is_none());
}
// ARCOUNT=5 over 20 zero bytes: the records run out partway through the section.
#[test]
fn edns_scan_no_panic_on_all_zeros_additional() {
let mut w = Writer::with_capacity(64);
Header::new(0x2222)
.with_rd(true)
.with_qdcount(1)
.with_arcount(5)
.write(&mut w);
let n: Name = "example.com".parse().unwrap();
n.write(&mut w);
w.write_u16(1u16);
w.write_u16(1u16);
w.write_slice(&[0u8; 20]);
let raw = w.finish();
let query = Query::try_from(raw).unwrap();
let _ = EdnsInfo::scan(&query); } // only checks that scanning does not panic
#[test]
fn block_nxdomain_a_query() {
let raw = build_query(0xABCD, true, "ads.example.com", 1);
let query = Query::try_from(raw).unwrap();
let resp = Response::block(&query, &BlockMode::NxDomain, 60, None);
let hdr = parse_response_header(&resp);
assert!(hdr.qr(), "QR must be set");
assert_eq!(hdr.id, 0xABCD, "ID must match");
assert!(hdr.rd(), "RD must be copied");
assert!(hdr.ra(), "RA must be set");
assert_eq!(hdr.rcode(), Rcode::NxDomain);
assert_eq!(hdr.qdcount, 1, "QDCOUNT must be 1");
assert_eq!(hdr.ancount, 0, "ANCOUNT must be 0 for NXDOMAIN");
assert_eq!(hdr.arcount, 0, "ARCOUNT must be 0 (no EDNS)");
}
// qtype 15 (MX): NXDOMAIN mode ignores the query type.
#[test]
fn block_nxdomain_any_qtype() {
let raw = build_query(0x1111, false, "blocked.example", 15);
let query = Query::try_from(raw).unwrap();
let resp = Response::block(&query, &BlockMode::NxDomain, 60, None);
let hdr = parse_response_header(&resp);
assert_eq!(hdr.rcode(), Rcode::NxDomain);
assert_eq!(hdr.ancount, 0);
}
#[test]
fn block_address_a_query_returns_configured_ip() {
let v4 = Ipv4Addr::new(127, 0, 0, 1);
let v6 = Ipv6Addr::UNSPECIFIED;
let mode = BlockMode::Address { v4, v6 };
let raw = build_query(0x1234, true, "blocked.example.com", 1);
let query = Query::try_from(raw).unwrap();
let resp = Response::block(&query, &mode, 300, None);
let hdr = parse_response_header(&resp);
assert_eq!(hdr.rcode(), Rcode::NoError);
assert_eq!(hdr.ancount, 1);
let (rtype, class, ttl, rdata) = read_first_answer(&resp);
assert_eq!(rtype, 1, "TYPE A");
assert_eq!(class, 1, "CLASS IN");
assert_eq!(ttl, 300);
assert_eq!(&rdata[..], &v4.octets());
}
#[test]
fn block_null_ip_a_query() {
let mode = BlockMode::null_ip();
let raw = build_query(0x2345, false, "tracker.example", 1);
let query = Query::try_from(raw).unwrap();
let resp = Response::block(&query, &mode, 60, None);
let hdr = parse_response_header(&resp);
assert_eq!(hdr.rcode(), Rcode::NoError);
assert_eq!(hdr.ancount, 1);
let (rtype, _, ttl, rdata) = read_first_answer(&resp);
assert_eq!(rtype, 1);
assert_eq!(ttl, 60);
assert_eq!(&rdata[..], &[0u8, 0, 0, 0]); } // 0.0.0.0
#[test]
fn block_address_aaaa_query_returns_configured_ip() {
let v4 = Ipv4Addr::UNSPECIFIED;
let v6: Ipv6Addr = "2001:db8::1".parse().unwrap();
let mode = BlockMode::Address { v4, v6 };
let raw = build_query(0x5678, true, "blocked.example.com", 28);
let query = Query::try_from(raw).unwrap();
let resp = Response::block(&query, &mode, 120, None);
let hdr = parse_response_header(&resp);
assert_eq!(hdr.rcode(), Rcode::NoError);
assert_eq!(hdr.ancount, 1);
let (rtype, class, ttl, rdata) = read_first_answer(&resp);
assert_eq!(rtype, 28, "TYPE AAAA");
assert_eq!(class, 1, "CLASS IN");
assert_eq!(ttl, 120);
assert_eq!(&rdata[..], &v6.octets());
}
#[test]
fn block_null_ip_aaaa_query() {
let mode = BlockMode::null_ip();
let raw = build_query(0x3456, false, "tracker.example", 28);
let query = Query::try_from(raw).unwrap();
let resp = Response::block(&query, &mode, 60, None);
let hdr = parse_response_header(&resp);
assert_eq!(hdr.rcode(), Rcode::NoError);
assert_eq!(hdr.ancount, 1);
let (_, _, _, rdata) = read_first_answer(&resp);
assert_eq!(&rdata[..], &[0u8; 16]); } // ::
#[test]
fn block_address_mx_qtype_is_nodata() {
let mode = BlockMode::null_ip();
let raw = build_query(0x4567, true, "blocked.example", 15); let query = Query::try_from(raw).unwrap(); // qtype 15 = MX
let resp = Response::block(&query, &mode, 60, None);
let hdr = parse_response_header(&resp);
assert_eq!(
hdr.rcode(),
Rcode::NoError,
"Address mode non-A/AAAA → NODATA"
);
assert_eq!(hdr.ancount, 0, "NODATA must have 0 answers");
}
#[test]
fn block_address_txt_qtype_is_nodata() {
let mode = BlockMode::null_ip();
let raw = build_query(0x5678, false, "blocked.example", 16); let query = Query::try_from(raw).unwrap(); // qtype 16 = TXT
let resp = Response::block(&query, &mode, 60, None);
let hdr = parse_response_header(&resp);
assert_eq!(hdr.rcode(), Rcode::NoError);
assert_eq!(hdr.ancount, 0);
}
#[test]
fn block_nxdomain_mx_qtype() {
let raw = build_query(0x6789, false, "blocked.example", 15);
let query = Query::try_from(raw).unwrap();
let resp = Response::block(&query, &BlockMode::NxDomain, 60, None);
let hdr = parse_response_header(&resp);
assert_eq!(hdr.rcode(), Rcode::NxDomain);
assert_eq!(hdr.ancount, 0);
}
#[test]
fn local_a_record() {
let raw = build_query(0xBEEF, true, "local.home", 1);
let query = Query::try_from(raw).unwrap();
let addr = Ipv4Addr::new(192, 168, 1, 1);
let rdata = addr.octets();
let records = [LocalRecord {
rtype: 1,
rdata: &rdata,
}];
let resp = Response::local(&query, &records, 3600, None);
let hdr = parse_response_header(&resp);
assert!(hdr.aa(), "local response must have AA=1");
assert_eq!(hdr.rcode(), Rcode::NoError);
assert_eq!(hdr.ancount, 1);
let (rtype, _, ttl, rdata_got) = read_first_answer(&resp);
assert_eq!(rtype, 1);
assert_eq!(ttl, 3600);
assert_eq!(&rdata_got[..], &addr.octets());
}
#[test]
fn local_aaaa_record() {
let raw = build_query(0xCAFE, true, "local.home", 28);
let query = Query::try_from(raw).unwrap();
let addr: Ipv6Addr = "fd00::1".parse().unwrap();
let rdata = addr.octets();
let records = [LocalRecord {
rtype: 28,
rdata: &rdata,
}];
let resp = Response::local(&query, &records, 3600, None);
let hdr = parse_response_header(&resp);
assert!(hdr.aa());
assert_eq!(hdr.ancount, 1);
let (rtype, _, _, rdata_got) = read_first_answer(&resp);
assert_eq!(rtype, 28);
assert_eq!(&rdata_got[..], &addr.octets());
}
#[test]
fn local_nodata_authoritative() {
let raw = build_query(0xDEAD, true, "local.home", 28);
let query = Query::try_from(raw).unwrap();
let resp = Response::local_nodata(&query, None);
let hdr = parse_response_header(&resp);
assert!(hdr.aa(), "NODATA must be authoritative");
assert_eq!(hdr.rcode(), Rcode::NoError);
assert_eq!(hdr.ancount, 0);
}
#[test]
fn error_response_servfail() {
let raw = build_query(0xF00D, true, "fail.example.com", 1);
let query = Query::try_from(raw).unwrap();
let resp = Response::error_response(&query, Rcode::ServFail, None);
let hdr = parse_response_header(&resp);
assert!(hdr.qr());
assert_eq!(hdr.id, 0xF00D);
assert!(hdr.rd(), "RD must be copied");
assert!(hdr.ra(), "RA must be set");
assert_eq!(hdr.rcode(), Rcode::ServFail);
assert_eq!(hdr.qdcount, 1);
assert_eq!(hdr.ancount, 0);
}
#[test]
fn error_response_refused() {
let raw = build_query(0x1111, false, "example.com", 1);
let query = Query::try_from(raw).unwrap();
let resp = Response::error_response(&query, Rcode::Refused, None);
let hdr = parse_response_header(&resp);
assert_eq!(hdr.rcode(), Rcode::Refused);
assert!(!hdr.rd(), "RD=0 copied from query");
assert_eq!(hdr.id, 0x1111);
}
#[test]
fn formerr_id_only() {
let resp = Response::formerr(0xDEAD);
let hdr = parse_response_header(&resp);
assert!(hdr.qr(), "QR must be set");
assert_eq!(hdr.id, 0xDEAD, "ID must match");
assert_eq!(hdr.rcode(), Rcode::FormErr);
assert_eq!(hdr.qdcount, 0, "FORMERR from id has no question");
assert_eq!(hdr.ancount, 0);
assert_eq!(hdr.arcount, 0);
assert!(!hdr.rd(), "RD must be 0 (no query to copy from)");
assert!(!hdr.ra(), "RA must be 0 for FORMERR-from-id");
}
#[test]
fn formerr_minimum_length() {
let resp = Response::formerr(0x0000);
assert_eq!(resp.len(), 12, "FORMERR from id must be exactly 12 bytes");
}
// Question written with mixed-case labels (DNS 0x20) must come back byte-identical.
#[test]
fn question_wire_bytes_preserved_exactly() {
let mut w = Writer::with_capacity(64);
Header::new(0xABCD)
.with_rd(true)
.with_qdcount(1)
.write(&mut w);
w.write_u8(7);
w.write_slice(b"eXaMpLe");
w.write_u8(3);
w.write_slice(b"CoM");
w.write_u8(0);
w.write_u16(1u16); w.write_u16(1u16); let raw = w.finish(); // QTYPE = A, QCLASS = IN
let query = Query::try_from(raw.clone()).unwrap();
let question_start = 12usize;
let question_end = query.question_end();
let original_question_bytes = &raw[question_start..question_end];
let mode = BlockMode::null_ip();
let resp = Response::block(&query, &mode, 60, None);
let resp_question_bytes = &resp[question_start..question_end];
assert_eq!(
resp_question_bytes, original_question_bytes,
"question bytes in response must be byte-identical to query question bytes \
(DNS 0x20 case preservation)"
);
}
#[test]
fn edns_query_without_cookie_gets_opt_in_response() {
let raw = build_query_with_opt(0x9999, true, "blocked.example.com", 1, 4096, None);
let query = Query::try_from(raw).unwrap();
let edns = EdnsInfo::scan(&query).expect("OPT must be found");
let resp = Response::block(&query, &BlockMode::null_ip(), 60, Some(&edns));
let hdr = parse_response_header(&resp);
assert_eq!(hdr.arcount, 1, "ARCOUNT must be 1 with EDNS echo");
let (rtype, class, ttl, rdata) = read_opt_rr(&resp).expect("OPT must be in response");
assert_eq!(rtype, OPT_TYPE, "TYPE must be OPT (41)");
assert_eq!(
class, SERVER_UDP_PAYLOAD_SIZE,
"CLASS must be server UDP payload size"
);
assert_eq!(ttl, 0, "OPT TTL must be 0");
assert_eq!(
rdata.len(),
0,
"no COOKIE in response when query had no cookie"
);
}
#[test]
fn edns_query_with_cookie_reflected_in_response() {
let client_cookie: &[u8] = &[0xDE, 0xAD, 0xBE, 0xEF, 0xCA, 0xFE, 0xBA, 0xBE];
let raw = build_query_with_opt(
0xAAAA,
true,
"blocked.example.com",
1,
1232,
Some(client_cookie),
);
let query = Query::try_from(raw).unwrap();
let edns = EdnsInfo::scan(&query).expect("OPT must be found");
let resp = Response::block(&query, &BlockMode::null_ip(), 60, Some(&edns));
let hdr = parse_response_header(&resp);
assert_eq!(hdr.arcount, 1);
let (_, _, _, opt_rdata) = read_opt_rr(&resp).expect("OPT must be in response");
// OPT RDATA layout: OPTION-CODE (2 bytes), OPTION-LENGTH (2 bytes), then the data.
assert!(opt_rdata.len() >= 4 + client_cookie.len());
let opt_code = u16::from_be_bytes([opt_rdata[0], opt_rdata[1]]);
let opt_len = u16::from_be_bytes([opt_rdata[2], opt_rdata[3]]) as usize;
let opt_data = &opt_rdata[4..4 + opt_len];
assert_eq!(
opt_code, EDNS_OPTION_COOKIE,
"OPTION-CODE must be 10 (COOKIE)"
);
assert_eq!(opt_data, client_cookie, "cookie must be reflected verbatim");
}
#[test]
fn non_edns_query_has_no_opt_in_response() {
let raw = build_query(0xBBBB, true, "blocked.example.com", 1);
let query = Query::try_from(raw).unwrap();
let resp = Response::block(&query, &BlockMode::null_ip(), 60, None);
let hdr = parse_response_header(&resp);
assert_eq!(hdr.arcount, 0, "no EDNS in query → no OPT in response");
assert!(read_opt_rr(&resp).is_none());
}
#[test]
fn edns_block_nxdomain_includes_opt() {
let raw = build_query_with_opt(0xCCCC, true, "blocked.example", 1, 512, None);
let query = Query::try_from(raw).unwrap();
let edns = EdnsInfo::scan(&query).unwrap();
let resp = Response::block(&query, &BlockMode::NxDomain, 60, Some(&edns));
let hdr = parse_response_header(&resp);
assert_eq!(hdr.rcode(), Rcode::NxDomain);
assert_eq!(hdr.arcount, 1);
assert!(read_opt_rr(&resp).is_some());
}
#[test]
fn edns_servfail_includes_opt() {
let raw = build_query_with_opt(0xDDDD, true, "fail.example", 1, 1232, None);
let query = Query::try_from(raw).unwrap();
let edns = EdnsInfo::scan(&query).unwrap();
let resp = Response::error_response(&query, Rcode::ServFail, Some(&edns));
let hdr = parse_response_header(&resp);
assert_eq!(hdr.rcode(), Rcode::ServFail);
assert_eq!(hdr.arcount, 1);
assert!(read_opt_rr(&resp).is_some());
}
#[test]
fn block_response_copies_rd_flag() {
let raw_rd1 = build_query(0x1234, true, "example.com", 1);
let q1 = Query::try_from(raw_rd1).unwrap();
let resp1 = Response::block(&q1, &BlockMode::null_ip(), 60, None);
assert!(
parse_response_header(&resp1).rd(),
"RD must be copied (RD=1)"
);
let raw_rd0 = build_query(0x5678, false, "example.com", 1);
let q0 = Query::try_from(raw_rd0).unwrap();
let resp0 = Response::block(&q0, &BlockMode::null_ip(), 60, None);
assert!(
!parse_response_header(&resp0).rd(),
"RD must be copied (RD=0)"
);
}
#[test]
fn block_response_always_sets_qr_and_ra() {
let raw = build_query(0x9999, false, "example.com", 1);
let query = Query::try_from(raw).unwrap();
let resp = Response::block(&query, &BlockMode::NxDomain, 60, None);
let hdr = parse_response_header(&resp);
assert!(hdr.qr(), "QR must be 1");
assert!(hdr.ra(), "RA must be 1");
}
#[test]
fn local_response_sets_aa_flag() {
let raw = build_query(0xAAAA, true, "local.home", 1);
let query = Query::try_from(raw).unwrap();
let rdata = Ipv4Addr::new(10, 0, 0, 1).octets();
let records = [LocalRecord {
rtype: 1,
rdata: &rdata,
}];
let resp = Response::local(&query, &records, 60, None);
assert!(parse_response_header(&resp).aa(), "local must set AA");
}
// The answer owner starts right after the question and must be the pointer 0xC00C (offset 12).
#[test]
fn answer_owner_is_compression_pointer() {
let mode = BlockMode::null_ip();
let raw = build_query(0x1234, true, "example.com", 1);
let query = Query::try_from(raw).unwrap();
let resp = Response::block(&query, &mode, 60, None);
let qend = query.question_end();
assert_eq!(resp[qend], 0xC0, "first byte of owner must be 0xC0");
assert_eq!(resp[qend + 1], 0x0C, "second byte of owner must be 0x0C");
}
#[test]
fn round_trip_block_response_is_parseable() {
let raw = build_query(0x4321, true, "tracker.example.com", 1);
let query = Query::try_from(raw).unwrap();
let resp = Response::block(&query, &BlockMode::null_ip(), 300, None);
let hdr = parse_response_header(&resp);
assert!(hdr.qr());
assert_eq!(hdr.id, 0x4321);
assert_eq!(hdr.qdcount, 1);
assert_eq!(hdr.ancount, 1);
}
#[test]
fn truncated_sets_tc_qr_ra_and_no_answers() {
let raw = build_query(0xBEEF, true, "example.com", 1);
let query = Query::try_from(raw).unwrap();
let resp = Response::truncated(&query, None);
let hdr = parse_response_header(&resp);
assert!(hdr.qr(), "QR must be set");
assert!(hdr.tc(), "TC must be set");
assert!(hdr.ra(), "RA must be set");
assert!(hdr.rd(), "RD must be copied from query");
assert_eq!(hdr.id, 0xBEEF, "ID must match");
assert_eq!(hdr.rcode(), Rcode::NoError, "RCODE must be NOERROR");
assert_eq!(hdr.qdcount, 1, "QDCOUNT must be 1");
assert_eq!(hdr.ancount, 0, "ANCOUNT must be 0");
assert_eq!(hdr.arcount, 0, "ARCOUNT must be 0 (no EDNS)");
}
#[test]
fn truncated_echoes_question() {
let raw = build_query(0x1234, false, "truncate.test", 1);
let query = Query::try_from(raw.clone()).unwrap();
let resp = Response::truncated(&query, None);
let question_start = 12usize;
let question_end = query.question_end();
assert_eq!(
&resp[question_start..question_end],
&raw[question_start..question_end],
"question section must be echoed verbatim"
);
}
#[test]
fn truncated_with_edns_includes_opt() {
let raw = build_query_with_opt(0xCAFE, true, "large.example.com", 1, 4096, None);
let query = Query::try_from(raw).unwrap();
let edns = EdnsInfo::scan(&query).expect("OPT must be found");
let resp = Response::truncated(&query, Some(&edns));
let hdr = parse_response_header(&resp);
assert!(hdr.tc(), "TC must be set");
assert_eq!(hdr.arcount, 1, "ARCOUNT must be 1 with EDNS");
assert!(read_opt_rr(&resp).is_some(), "OPT must be present");
}
#[test]
fn null_ip_returns_unspecified_addresses() {
match BlockMode::null_ip() {
BlockMode::Address { v4, v6 } => {
assert_eq!(v4, Ipv4Addr::UNSPECIFIED);
assert_eq!(v6, Ipv6Addr::UNSPECIFIED);
}
_ => panic!("null_ip must return Address variant"),
}
}
}