use std::fmt;
/// A parsed DNS message: the six 16-bit header fields plus the decoded question section.
#[derive(Debug)]
pub struct DnsPacket {
    /// Transaction identifier copied from the header.
    pub transaction_id: u16,
    /// Raw, undecoded 16-bit flags word from the header.
    pub flags: u16,
    /// Number of questions announced by the header.
    pub questions: u16,
    /// Number of answer records announced by the header.
    pub answers: u16,
    /// Number of authority records announced by the header.
    pub authority_rrs: u16,
    /// Number of additional records announced by the header.
    pub additional_rrs: u16,
    /// Questions decoded from the payload.
    pub queries: Vec<DnsQuery>,
}

impl fmt::Display for DnsPacket {
    /// Writes a one-line summary of the packet.
    ///
    /// The transaction id and flags are printed as `0x`-prefixed upper-case hex, the counters
    /// as decimal, and the queries through their `Debug` representation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            transaction_id,
            flags,
            questions,
            answers,
            authority_rrs,
            additional_rrs,
            queries,
        } = self;
        write!(
            f,
            "DNS Packet: transaction_id={:#X}, flags={:#X}, questions={}, answers={}, authority_rrs={}, additional_rrs={}, queries={:?}",
            transaction_id, flags, questions, answers, authority_rrs, additional_rrs, queries
        )
    }
}
/// One entry of a DNS question section.
#[derive(Debug)]
pub struct DnsQuery {
    /// Queried name, with its labels joined by `.`.
    pub name: String,
    /// Raw 16-bit query type value.
    pub query_type: u16,
    /// Raw 16-bit query class value.
    pub query_class: u16,
}

impl fmt::Display for DnsQuery {
    /// Writes the name, type and class on a single line; type and class are printed as decimal.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            name,
            query_type,
            query_class,
        } = self;
        write!(
            f,
            "DNS Query: name={}, query_type={}, query_class={}",
            name, query_type, query_class
        )
    }
}
pub fn parse_dns_packet(payload: &[u8]) -> Result<DnsPacket, bool> {
if is_payload_too_short(payload, 12) {
return Err(false);
}
let header = match parse_dns_header(payload) {
Ok(header) => header,
Err(_) => {
return Err(false);
}
};
let queries = match parse_dns_queries(payload, 12, header.questions) {
Ok(queries) => queries,
Err(_) => {
return Err(false);
}
};
Ok(DnsPacket {
transaction_id: header.transaction_id,
flags: header.flags,
questions: header.questions,
answers: header.answers,
authority_rrs: header.authority_rrs,
additional_rrs: header.additional_rrs,
queries,
})
}
/// Decodes the fixed 12-byte DNS header at the start of `payload`.
///
/// The header consists of six consecutive big-endian 16-bit fields: transaction id, flags,
/// and the question, answer, authority and additional record counts.
///
/// # Errors
///
/// Returns `Err(false)` when fewer than 12 bytes are available or when
/// `are_record_counts_unreasonable` rejects the four record counts.
fn parse_dns_header(payload: &[u8]) -> Result<DnsHeader, bool> {
    const HEADER_LEN: usize = 12;

    if is_payload_too_short(payload, HEADER_LEN) {
        return Err(false);
    }

    // Reads the `index`-th 16-bit header field; the length check above keeps this in bounds.
    let field = |index: usize| u16::from_be_bytes([payload[2 * index], payload[2 * index + 1]]);

    let header = DnsHeader {
        transaction_id: field(0),
        flags: field(1),
        questions: field(2),
        answers: field(3),
        authority_rrs: field(4),
        additional_rrs: field(5),
    };

    let rejected = are_record_counts_unreasonable(
        header.questions,
        header.answers,
        header.authority_rrs,
        header.additional_rrs,
    );
    if rejected {
        Err(false)
    } else {
        Ok(header)
    }
}
/// Decodes `questions` consecutive entries of a DNS question section.
///
/// Each entry is a name (decoded by `parse_dns_name`) followed by two big-endian 16-bit
/// values: the query type and the query class. Decoding starts at `offset` within `payload`.
///
/// # Errors
///
/// Returns `Err(false)` when a name cannot be decoded, when fewer than four bytes follow a
/// name, or when a name is empty (queries for the root name are rejected).
fn parse_dns_queries(
    payload: &[u8],
    mut offset: usize,
    questions: u16,
) -> Result<Vec<DnsQuery>, bool> {
    // Number of bytes occupied by the type and class fields that follow every name.
    const TYPE_AND_CLASS_LEN: usize = 4;

    let mut queries = Vec::new();
    for _ in 0..questions {
        let (name, fields_start) = parse_dns_name(payload, offset).map_err(|_| false)?;

        // Checked arithmetic plus `get` turn an overflowing or out-of-range offset into an
        // error instead of a panic.
        let fields_end = fields_start.checked_add(TYPE_AND_CLASS_LEN).ok_or(false)?;
        let fields = payload.get(fields_start..fields_end).ok_or(false)?;

        if name.is_empty() {
            return Err(false);
        }

        queries.push(DnsQuery {
            name,
            query_type: u16::from_be_bytes([fields[0], fields[1]]),
            query_class: u16::from_be_bytes([fields[2], fields[3]]),
        });
        offset = fields_end;
    }
    Ok(queries)
}
/// Decodes an uncompressed DNS name that starts at `offset` within `payload`.
///
/// A name is a sequence of length-prefixed labels closed by a zero-length label. Label bytes
/// are converted with [`String::from_utf8_lossy`], so invalid UTF-8 is replaced rather than
/// rejected, and the labels are joined with `.`.
///
/// On success returns the name together with the offset of the first byte after the
/// terminating zero byte. The root name (a lone zero byte) decodes to an empty string.
///
/// # Errors
///
/// Returns `Err(false)` when
/// - the payload ends before the terminating zero-length label is reached,
/// - a length byte has either of its two top bits set, i.e. it is a compression pointer
///   (not supported here) or one of the reserved label types (RFC 1035 limits a label to 63
///   bytes, so such a byte is never a plain length),
/// - a label extends past the end of the payload.
fn parse_dns_name(payload: &[u8], mut offset: usize) -> Result<(String, usize), bool> {
    // Bits that must be clear in the length byte of an ordinary label.
    const LABEL_TYPE_MASK: usize = 0xC0;

    let mut labels = Vec::new();
    loop {
        // Running off the end before the zero-length label means the name is truncated.
        let length = usize::from(*payload.get(offset).ok_or(false)?);
        offset += 1;

        if length == 0 {
            break;
        }
        if length & LABEL_TYPE_MASK != 0 {
            return Err(false);
        }

        let label = payload.get(offset..offset + length).ok_or(false)?;
        labels.push(String::from_utf8_lossy(label).into_owned());
        offset += length;
    }
    Ok((labels.join("."), offset))
}
/// Reports whether `payload` holds fewer than `required_length` bytes.
fn is_payload_too_short(payload: &[u8], required_length: usize) -> bool {
    let available = payload.len();
    required_length > available
}
/// Reports whether any of the four header record counts exceeds 50.
///
/// A count of exactly 50 is still accepted.
fn are_record_counts_unreasonable(
    questions: u16,
    answers: u16,
    authority_rrs: u16,
    additional_rrs: u16,
) -> bool {
    // Largest count accepted for a single section.
    const MAX_RECORDS_PER_SECTION: u16 = 50;

    [questions, answers, authority_rrs, additional_rrs]
        .iter()
        .any(|&count| count > MAX_RECORDS_PER_SECTION)
}
/// The six 16-bit fields of the fixed DNS header, in the order they appear on the wire.
#[derive(Debug)]
struct DnsHeader {
    /// Payload bytes 0–1.
    transaction_id: u16,
    /// Payload bytes 2–3, kept as the raw flags word.
    flags: u16,
    /// Payload bytes 4–5: number of questions.
    questions: u16,
    /// Payload bytes 6–7: number of answer records.
    answers: u16,
    /// Payload bytes 8–9: number of authority records.
    authority_rrs: u16,
    /// Payload bytes 10–11: number of additional records.
    additional_rrs: u16,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Wire form of `www.google.com`, including the terminating zero byte.
    const WWW_GOOGLE_COM: [u8; 16] = [
        0x03, 0x77, 0x77, 0x77, 0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d,
        0x00,
    ];

    /// A complete single-question query is decoded field by field.
    #[test]
    fn test_parse_dns_packet() {
        let dns_payload: [u8; 32] = [
            // Header: id, flags, one question, no other records.
            0xdd, 0xc7, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            // Name: www.google.com
            0x03, 0x77, 0x77, 0x77, 0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x03, 0x63, 0x6f,
            0x6d, 0x00,
            // Type 1, class 1.
            0x00, 0x01, 0x00, 0x01,
        ];

        let packet =
            parse_dns_packet(&dns_payload).unwrap_or_else(|_| panic!("Expected DNS packet"));

        assert_eq!(packet.transaction_id, 0xddc7);
        assert_eq!(packet.flags, 0x0100);
        assert_eq!(packet.questions, 1);
        assert_eq!(packet.answers, 0);
        assert_eq!(packet.authority_rrs, 0);
        assert_eq!(packet.additional_rrs, 0);
        assert_eq!(packet.queries.len(), 1);

        let query = &packet.queries[0];
        assert_eq!(query.name, "www.google.com");
        assert_eq!(query.query_type, 1);
        assert_eq!(query.query_class, 1);
    }

    /// A payload shorter than the header is rejected with `Err(false)`.
    #[test]
    fn test_parse_dns_packet_short_payload() {
        let short_payload = [0xdd, 0xc7, 0x01, 0x00, 0x00, 0x01, 0x00];

        let outcome = parse_dns_packet(&short_payload);

        assert!(
            outcome.is_err(),
            "Expected invalid DNS packet due to short payload"
        );
        assert!(!outcome.unwrap_err());
    }

    /// A terminated name is decoded and the returned offset points just past it.
    #[test]
    fn test_parse_dns_name() {
        let (name, offset) = parse_dns_name(&WWW_GOOGLE_COM, 0).unwrap();

        assert_eq!(name, "www.google.com");
        assert_eq!(offset, WWW_GOOGLE_COM.len());
    }

    /// A label that claims more bytes than the payload holds is an error.
    #[test]
    fn test_parse_dns_name_invalid() {
        let dns_payload = [0x10, 0x77, 0x77, 0x77];

        assert!(parse_dns_name(&dns_payload, 0).is_err());
    }
}