use super::question::Question;
use super::record::ResourceRecord;
use super::types::{OpCode, ResponseCode};
use std::fmt;
/// An in-memory DNS message: header fields plus the four record sections.
///
/// All fields are private; they are read and modified through the accessor
/// methods implemented on `Message`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Message {
/// Message identifier (printed as `ID` by the `Display` impl).
id: u16,
/// Query/response flag: `false` for a query, `true` for a response.
qr: bool,
/// Operation code of the message.
opcode: OpCode,
/// Authoritative-answer flag (`AA`).
aa: bool,
/// Truncation flag (`TC`).
tc: bool,
/// Recursion-desired flag (`RD`).
rd: bool,
/// Recursion-available flag (`RA`).
ra: bool,
/// Response code of the message.
rcode: ResponseCode,
/// Entries of the question section.
questions: Vec<Question>,
/// Records of the answer section.
answers: Vec<ResourceRecord>,
/// Records of the authority section.
authority: Vec<ResourceRecord>,
/// Records of the additional section.
additional: Vec<ResourceRecord>,
}
impl Message {
pub fn new() -> Self {
Self {
id: 0,
qr: false,
opcode: OpCode::Query,
aa: false,
tc: false,
rd: true,
ra: false,
rcode: ResponseCode::NoError,
questions: Vec::new(),
answers: Vec::new(),
authority: Vec::new(),
additional: Vec::new(),
}
}
pub fn id(&self) -> u16 {
self.id
}
pub fn set_id(&mut self, id: u16) {
self.id = id;
}
pub fn is_response(&self) -> bool {
self.qr
}
pub fn set_query(&mut self, is_query: bool) {
self.qr = !is_query;
}
pub fn set_response(&mut self, is_response: bool) {
self.qr = is_response;
}
pub fn opcode(&self) -> OpCode {
self.opcode
}
pub fn set_opcode(&mut self, opcode: OpCode) {
self.opcode = opcode;
}
pub fn is_authoritative(&self) -> bool {
self.aa
}
pub fn set_authoritative(&mut self, aa: bool) {
self.aa = aa;
}
pub fn is_truncated(&self) -> bool {
self.tc
}
pub fn set_truncated(&mut self, tc: bool) {
self.tc = tc;
}
pub fn recursion_desired(&self) -> bool {
self.rd
}
pub fn set_recursion_desired(&mut self, rd: bool) {
self.rd = rd;
}
pub fn recursion_available(&self) -> bool {
self.ra
}
pub fn set_recursion_available(&mut self, ra: bool) {
self.ra = ra;
}
pub fn response_code(&self) -> ResponseCode {
self.rcode
}
pub fn set_response_code(&mut self, rcode: ResponseCode) {
self.rcode = rcode;
}
pub fn questions(&self) -> &[Question] {
&self.questions
}
pub fn questions_mut(&mut self) -> &mut Vec<Question> {
&mut self.questions
}
pub fn add_question(&mut self, question: Question) {
self.questions.push(question);
}
pub fn answers(&self) -> &[ResourceRecord] {
&self.answers
}
pub fn answers_mut(&mut self) -> &mut Vec<ResourceRecord> {
&mut self.answers
}
pub fn add_answer(&mut self, answer: ResourceRecord) {
self.answers.push(answer);
}
pub fn authority(&self) -> &[ResourceRecord] {
&self.authority
}
pub fn authority_mut(&mut self) -> &mut Vec<ResourceRecord> {
&mut self.authority
}
pub fn add_authority(&mut self, authority: ResourceRecord) {
self.authority.push(authority);
}
pub fn additional(&self) -> &[ResourceRecord] {
&self.additional
}
pub fn additional_mut(&mut self) -> &mut Vec<ResourceRecord> {
&mut self.additional
}
pub fn add_additional(&mut self, additional: ResourceRecord) {
self.additional.push(additional);
}
pub fn question_count(&self) -> usize {
self.questions.len()
}
pub fn answer_count(&self) -> usize {
self.answers.len()
}
pub fn authority_count(&self) -> usize {
self.authority.len()
}
pub fn additional_count(&self) -> usize {
self.additional.len()
}
pub fn clear_questions(&mut self) {
self.questions.clear();
}
pub fn clear_answers(&mut self) {
self.answers.clear();
}
pub fn clear_authority(&mut self) {
self.authority.clear();
}
pub fn clear_additional(&mut self) {
self.additional.clear();
}
}
impl Default for Message {
fn default() -> Self {
Self::new()
}
}
impl fmt::Display for Message {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, ";; DNS Message")?;
writeln!(f, ";; ID: {}", self.id)?;
writeln!(
f,
";; QR: {} ({})",
self.qr,
if self.qr { "Response" } else { "Query" }
)?;
writeln!(f, ";; OPCODE: {:?}", self.opcode)?;
writeln!(f, ";; RCODE: {}", self.rcode)?;
writeln!(
f,
";; Flags: AA={} TC={} RD={} RA={}",
self.aa, self.tc, self.rd, self.ra
)?;
writeln!(
f,
";; Counts: QUESTION={} ANSWER={} AUTHORITY={} ADDITIONAL={}",
self.question_count(),
self.answer_count(),
self.authority_count(),
self.additional_count()
)?;
if !self.questions.is_empty() {
writeln!(f, "\n;; QUESTION SECTION:")?;
for question in &self.questions {
writeln!(f, "{}", question)?;
}
}
if !self.answers.is_empty() {
writeln!(f, "\n;; ANSWER SECTION:")?;
for answer in &self.answers {
writeln!(f, "{}", answer)?;
}
}
if !self.authority.is_empty() {
writeln!(f, "\n;; AUTHORITY SECTION:")?;
for auth in &self.authority {
writeln!(f, "{}", auth)?;
}
}
if !self.additional.is_empty() {
writeln!(f, "\n;; ADDITIONAL SECTION:")?;
for add in &self.additional {
writeln!(f, "{}", add)?;
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dns::{RData, RecordClass, RecordType};
    use std::net::Ipv4Addr;

    /// Builds an `A`/`IN` question for `name`.
    fn a_question(name: &str) -> Question {
        Question::new(name.to_string(), RecordType::A, RecordClass::IN)
    }

    /// Builds the `example.com` `A` record (TTL 3600, 192.0.2.1) shared by
    /// several tests.
    fn example_record() -> ResourceRecord {
        ResourceRecord::new(
            "example.com".to_string(),
            RecordType::A,
            RecordClass::IN,
            3600,
            RData::A(Ipv4Addr::new(192, 0, 2, 1)),
        )
    }

    /// A freshly created message carries the default header values.
    #[test]
    fn test_message_creation() {
        let msg = Message::new();

        assert_eq!(msg.id(), 0);
        assert!(!msg.is_response());
        assert_eq!(msg.opcode(), OpCode::Query);
        assert_eq!(msg.response_code(), ResponseCode::NoError);
        assert!(msg.recursion_desired());
        assert!(!msg.recursion_available());
    }

    /// The identifier round-trips through its setter and getter.
    #[test]
    fn test_message_id() {
        let mut msg = Message::new();
        msg.set_id(12345);

        assert_eq!(msg.id(), 12345);
    }

    /// Every header flag reflects the value most recently set.
    #[test]
    fn test_message_flags() {
        let mut msg = Message::new();

        msg.set_response(true);
        assert!(msg.is_response());

        msg.set_authoritative(true);
        assert!(msg.is_authoritative());

        msg.set_truncated(true);
        assert!(msg.is_truncated());

        msg.set_recursion_desired(false);
        assert!(!msg.recursion_desired());

        msg.set_recursion_available(true);
        assert!(msg.recursion_available());
    }

    /// Adding a question grows the question section and keeps its name.
    #[test]
    fn test_add_question() {
        let mut msg = Message::new();
        msg.add_question(a_question("example.com"));

        assert_eq!(msg.question_count(), 1);
        assert_eq!(msg.questions()[0].qname(), "example.com");
    }

    /// Adding an answer grows the answer section and keeps its name.
    #[test]
    fn test_add_answer() {
        let mut msg = Message::new();
        msg.add_answer(example_record());

        assert_eq!(msg.answer_count(), 1);
        assert_eq!(msg.answers()[0].name(), "example.com");
    }

    /// Clearing a populated section leaves it empty.
    #[test]
    fn test_clear_sections() {
        let mut msg = Message::new();
        msg.add_question(a_question("example.com"));
        msg.add_answer(example_record());
        assert_eq!(msg.question_count(), 1);
        assert_eq!(msg.answer_count(), 1);

        msg.clear_questions();
        msg.clear_answers();

        assert_eq!(msg.question_count(), 0);
        assert_eq!(msg.answer_count(), 0);
    }

    /// The response code can be overwritten repeatedly.
    #[test]
    fn test_response_code() {
        let mut msg = Message::new();

        msg.set_response_code(ResponseCode::NXDomain);
        assert_eq!(msg.response_code(), ResponseCode::NXDomain);

        msg.set_response_code(ResponseCode::ServFail);
        assert_eq!(msg.response_code(), ResponseCode::ServFail);
    }

    /// The opcode round-trips through its setter and getter.
    #[test]
    fn test_opcode() {
        let mut msg = Message::new();
        msg.set_opcode(OpCode::Update);

        assert_eq!(msg.opcode(), OpCode::Update);
    }

    /// Assembles a typical recursive query and checks the result.
    #[test]
    fn test_complete_query_message() {
        let mut msg = Message::new();
        msg.set_id(1234);
        msg.set_query(true);
        msg.set_recursion_desired(true);
        msg.add_question(a_question("www.example.com"));

        assert_eq!(msg.id(), 1234);
        assert!(!msg.is_response());
        assert!(msg.recursion_desired());
        assert_eq!(msg.question_count(), 1);
    }

    /// Assembles a typical authoritative response and checks the result.
    #[test]
    fn test_complete_response_message() {
        let mut msg = Message::new();
        msg.set_id(1234);
        msg.set_response(true);
        msg.set_authoritative(true);
        msg.set_recursion_available(true);
        msg.set_response_code(ResponseCode::NoError);
        msg.add_question(a_question("www.example.com"));

        let www_record = ResourceRecord::new(
            "www.example.com".to_string(),
            RecordType::A,
            RecordClass::IN,
            300,
            RData::A(Ipv4Addr::new(93, 184, 216, 34)),
        );
        msg.add_answer(www_record);

        assert_eq!(msg.id(), 1234);
        assert!(msg.is_response());
        assert!(msg.is_authoritative());
        assert!(msg.recursion_available());
        assert_eq!(msg.response_code(), ResponseCode::NoError);
        assert_eq!(msg.question_count(), 1);
        assert_eq!(msg.answer_count(), 1);
    }
}