use std::{error::Error, fmt};
use errors::DnsQueryParseError;
use crate::utils::{dns_class::DnsClass, dns_types::DnsType};
mod errors;
/// One decoded query entry: a name plus its 16-bit type and class codes.
#[derive(PartialEq, Debug)]
pub struct DnsQuery {
    /// Dot-joined name, e.g. `www.google.com`.
    pub name: String,
    /// Query type, built from the big-endian `u16` that follows the name.
    pub qtype: DnsType,
    /// Query class, built from the big-endian `u16` that follows the type.
    pub qclass: DnsClass,
}
impl DnsQuery {
    /// Parses a single query entry from `bytes`, starting at `*offset`.
    ///
    /// An entry is a name (length-prefixed labels ending in a zero byte, decoded by
    /// `parse_name`). The name is followed by a big-endian `u16` type and a
    /// big-endian `u16` class.
    ///
    /// On success `*offset` is advanced to the first byte after the entry. On failure
    /// `*offset` is left unchanged, so the caller's cursor never points into the
    /// middle of an entry.
    ///
    /// # Errors
    ///
    /// Returns an error if the name cannot be parsed, or if fewer than 4 bytes follow
    /// the name for the type/class fields.
    pub fn from_bytes(bytes: &[u8], offset: &mut usize) -> Result<Self, Box<dyn Error>> {
        let (name, cursor) = parse_name(bytes, *offset)?;
        check_dns_query_size(bytes, cursor, 4)?;
        // The size check above guarantees cursor..cursor + 4 is in bounds.
        let qtype = DnsType::new(u16::from_be_bytes([bytes[cursor], bytes[cursor + 1]]));
        let qclass = DnsClass::new(u16::from_be_bytes([bytes[cursor + 2], bytes[cursor + 3]]));
        // Commit the new position only once the whole entry has been read.
        *offset = cursor + 4;
        Ok(DnsQuery {
            name,
            qtype,
            qclass,
        })
    }
}
impl fmt::Display for DnsQuery {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"DnsQuery {{ name: {}, qtype: {}, qclass: {} }}",
self.name, self.qtype, self.qclass
)
}
}
/// Ensures that at least `required_size` bytes are available in `bytes` starting at
/// `offset`.
///
/// # Errors
///
/// Returns an "Insufficient data" error when fewer than `required_size` bytes remain
/// from `offset` onward. An `offset` past the end of the buffer, or an
/// `offset + required_size` that would overflow `usize`, is reported as an error
/// rather than panicking.
fn check_dns_query_size(
    bytes: &[u8],
    offset: usize,
    required_size: usize,
) -> Result<(), Box<dyn Error>> {
    // checked_add: a huge offset must not wrap around and appear to fit.
    let fits = matches!(offset.checked_add(required_size), Some(end) if end <= bytes.len());
    if fits {
        return Ok(());
    }
    Err(format!(
        "Insufficient data: required {} more bytes at offset {}, but only {} bytes available",
        required_size,
        offset,
        // saturating_sub: offset may already be past the end of the buffer.
        bytes.len().saturating_sub(offset)
    )
    .into())
}
/// An ordered list of query entries parsed back-to-back from one buffer.
#[derive(PartialEq, Debug)]
pub struct DnsQueries {
    /// Entries in the order they appear in the input.
    pub queries: Vec<DnsQuery>,
}
impl DnsQueries {
    /// Parses `count` consecutive query entries starting at the beginning of `bytes`.
    ///
    /// # Errors
    ///
    /// Returns an error if the buffer runs out before `count` entries have been read,
    /// or if any individual entry is malformed.
    pub fn from_bytes(bytes: &[u8], count: u16) -> Result<Self, Box<dyn Error>> {
        // The smallest possible entry is a root name (a single zero byte) followed by
        // 4 bytes of type/class.
        const MIN_QUERY_LEN: usize = 5;
        // `count` presumably comes from a header we do not control. Never reserve more
        // slots than the buffer could actually hold.
        let capacity = usize::from(count).min(bytes.len() / MIN_QUERY_LEN);
        let mut queries = Vec::with_capacity(capacity);
        let mut offset = 0;
        for _ in 0..count {
            check_dns_query_size(bytes, offset, 1)?;
            queries.push(DnsQuery::from_bytes(bytes, &mut offset)?);
        }
        Ok(DnsQueries { queries })
    }
}
impl fmt::Display for DnsQueries {
    /// Renders every entry followed by a comma inside brackets, e.g.
    /// `DnsQueries { queries: [ A, B,] }`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("DnsQueries { queries: [")?;
        self.queries
            .iter()
            .try_for_each(|entry| write!(f, " {},", entry))?;
        f.write_str("] }")
    }
}
/// Decodes a name stored as length-prefixed labels, starting at `offset`.
///
/// Each label is one length byte followed by that many bytes, which must be valid
/// UTF-8. A zero length byte ends the name. Labels are joined with `.`; a lone zero
/// byte yields an empty name.
///
/// Returns the decoded name and the offset of the first byte after the terminator.
///
/// NOTE(review): every non-zero byte is treated as a plain label length, so DNS
/// compression pointers (top two bits set) are not recognised. Confirm that inputs
/// never contain them.
///
/// # Errors
///
/// - `DnsQueryParseError::OutOfBoundParse` if the buffer ends before the terminator
///   or in the middle of a label.
/// - A UTF-8 error (converted via `?`) if a label is not valid UTF-8.
fn parse_name(bytes: &[u8], mut offset: usize) -> Result<(String, usize), DnsQueryParseError> {
    let mut parts = Vec::new();
    loop {
        let label_len = match bytes.get(offset) {
            Some(&byte) => byte as usize,
            None => return Err(DnsQueryParseError::OutOfBoundParse),
        };
        // Step past the length byte whether or not it is the terminator.
        offset += 1;
        if label_len == 0 {
            break;
        }
        let label_end = offset + label_len;
        let raw = bytes
            .get(offset..label_end)
            .ok_or(DnsQueryParseError::OutOfBoundParse)?;
        parts.push(String::from_utf8(raw.to_vec())?);
        offset = label_end;
    }
    Ok((parts.join("."), offset))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_name() {
        let data = vec![
            0x03, 0x77, 0x77, 0x77, 0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x03, 0x63, 0x6f,
            0x6d, 0x00,
        ];
        let (name, offset) = parse_name(&data, 0).unwrap();
        assert_eq!(name, "www.google.com");
        assert_eq!(offset, 16);
    }

    #[test]
    fn test_parse_name_root() {
        // A lone terminator is the empty (root) name.
        let (name, offset) = parse_name(&[0x00], 0).unwrap();
        assert_eq!(name, "");
        assert_eq!(offset, 1);
    }

    #[test]
    fn test_parse_name_invalid_utf8() {
        let data = vec![0x02, 0xFF, 0xFF, 0x00];
        let result = parse_name(&data, 0);
        assert!(
            matches!(result, Err(DnsQueryParseError::Utf8Error(_))),
            "Expected Utf8Error, but got {:?}",
            result
        );
    }

    #[test]
    fn test_parse_name_truncated_label() {
        // The label claims 5 bytes but only 2 follow.
        let data = vec![0x05, b'a', b'b'];
        let result = parse_name(&data, 0);
        assert!(
            matches!(result, Err(DnsQueryParseError::OutOfBoundParse)),
            "Expected OutOfBoundParse, but got {:?}",
            result
        );
    }

    #[test]
    fn test_parse_name_missing_terminator() {
        let data = vec![0x03, b'c', b'o', b'm'];
        let result = parse_name(&data, 0);
        assert!(
            matches!(result, Err(DnsQueryParseError::OutOfBoundParse)),
            "Expected OutOfBoundParse, but got {:?}",
            result
        );
    }

    #[test]
    fn test_dns_query_from_bytes() {
        let data = vec![
            3, b'w', b'w', b'w', 6, b'g', b'o', b'o', b'g', b'l', b'e', 3, b'c', b'o', b'm', 0, 0,
            1, 0, 1,
        ];
        let mut offset = 0;
        let query = DnsQuery::from_bytes(&data, &mut offset).unwrap();
        assert_eq!(query.name, "www.google.com");
        assert_eq!(query.qtype, DnsType(1));
        assert_eq!(query.qclass, DnsClass(1));
        assert_eq!(offset, 20);
    }

    #[test]
    fn test_dns_query_from_bytes_truncated_type_class() {
        // The name is complete but only 2 of the 4 type/class bytes are present.
        let data = vec![3, b'c', b'o', b'm', 0, 0, 1];
        let mut offset = 0;
        assert!(DnsQuery::from_bytes(&data, &mut offset).is_err());
    }

    #[test]
    fn test_dns_queries_from_bytes() {
        let data = vec![
            3, b'w', b'w', b'w', 6, b'g', b'o', b'o', b'g', b'l', b'e', 3, b'c', b'o', b'm', 0, 0,
            1, 0, 1, 3, b'f', b'o', b'o', 3, b'b', b'a', b'r', 3, b'c', b'o', b'm', 0, 0, 2, 0, 1,
        ];
        let queries = DnsQueries::from_bytes(&data, 2).unwrap();
        assert_eq!(queries.queries.len(), 2);
        assert_eq!(queries.queries[0].name, "www.google.com");
        assert_eq!(queries.queries[0].qtype, DnsType(1));
        assert_eq!(queries.queries[0].qclass, DnsClass(1));
        assert_eq!(queries.queries[1].name, "foo.bar.com");
        assert_eq!(queries.queries[1].qtype, DnsType(2));
        assert_eq!(queries.queries[1].qclass, DnsClass(1));
    }

    #[test]
    fn test_dns_queries_count_exceeds_data() {
        // Only one complete entry is present, but two are requested.
        let data = vec![3, b'c', b'o', b'm', 0, 0, 1, 0, 1];
        assert!(DnsQueries::from_bytes(&data, 2).is_err());
    }

    #[test]
    fn test_dns_queries_zero_count() {
        let queries = DnsQueries::from_bytes(&[], 0).unwrap();
        assert!(queries.queries.is_empty());
    }
}