#[cfg(test)]
use alloc::vec::Vec;
use core::{convert::From, fmt};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::{
error::*,
op::{op_code::OpCode, response_code::ResponseCode},
serialize::binary::*,
};
/// Fixed-size header that starts every message.
///
/// Holds the message ID, the message type and op code, six single-bit flags, the
/// response code and the four section counts. Its wire form is always 12 bytes; see the
/// `BinEncodable` / `BinDecodable` impls below for the exact bit layout.
///
/// Note: the derived `PartialOrd` and `Hash` are field-order dependent, so reordering the
/// fields below changes comparison/hash behavior.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct Header {
    /// Message identifier; encoded as the first 16-bit word (bytes 0-1).
    id: u16,
    /// Query or response; encoded as the most significant bit of byte 2.
    message_type: MessageType,
    /// Operation code; encoded in bits 6-3 of byte 2.
    op_code: OpCode,
    /// AA flag; bit 2 of byte 2.
    authoritative: bool,
    /// TC (truncation) flag; bit 1 of byte 2.
    truncation: bool,
    /// RD flag; bit 0 of byte 2.
    recursion_desired: bool,
    /// RA flag; bit 7 of byte 3.
    recursion_available: bool,
    /// AD flag; bit 5 of byte 3.
    authentic_data: bool,
    /// CD flag; bit 4 of byte 3.
    checking_disabled: bool,
    /// Response code. Only its low bits (`ResponseCode::low`) are written into the low
    /// nibble of byte 3; higher bits can be folded in via `merge_response_code`.
    response_code: ResponseCode,
    /// Query-section entry count; 16-bit word at bytes 4-5.
    query_count: u16,
    /// Answer-section record count; 16-bit word at bytes 6-7.
    answer_count: u16,
    /// Name-server (authority) section record count; 16-bit word at bytes 8-9.
    name_server_count: u16,
    /// Additional-section record count; 16-bit word at bytes 10-11.
    additional_count: u16,
}
impl fmt::Display for Header {
    /// Renders a one-line summary of the header:
    /// `id:type:flags:rcode:opcode:answers/authorities/additionals`.
    ///
    /// The response code uses its `Debug` form; the query count is not included.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        write!(
            f,
            "{}:{}:{}:{:?}:{}:{}/{}/{}",
            self.id,
            self.message_type,
            self.flags(),
            self.response_code,
            self.op_code,
            self.answer_count,
            self.name_server_count,
            self.additional_count,
        )
    }
}
/// Whether a message is a query or a response.
///
/// Encoded as the most significant bit of header byte 2: a clear bit decodes to
/// `Query`, a set bit to `Response`.
///
/// Note: variant order matters for the derived `PartialOrd` (`Query < Response`).
#[derive(Debug, PartialEq, Eq, PartialOrd, Copy, Clone, Hash)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub enum MessageType {
    /// A query (QR bit clear).
    Query,
    /// A response (QR bit set).
    Response,
}
impl fmt::Display for MessageType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
let s = match self {
Self::Query => "QUERY",
Self::Response => "RESPONSE",
};
f.write_str(s)
}
}
/// The six single-bit header flags, detached from the rest of a [`Header`].
///
/// Obtained from `Header::flags`. The `Display` impl renders only the flags that are
/// set, as a comma-separated list of mnemonics (e.g. `RD,RA`).
///
/// `Debug` is derived so that this public type can be inspected and used in
/// `assert_eq!`, like the other public types in this module.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Flags {
    /// AA: authoritative answer.
    authoritative: bool,
    /// TC: message was truncated.
    truncation: bool,
    /// RD: recursion desired.
    recursion_desired: bool,
    /// RA: recursion available.
    recursion_available: bool,
    /// AD: authentic data.
    authentic_data: bool,
    /// CD: checking disabled.
    checking_disabled: bool,
}
impl fmt::Display for Flags {
    /// Writes the mnemonics of every set flag, joined by commas, in the fixed order
    /// RD, CD, TC, AA, RA, AD. Writes nothing when no flag is set.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        let labelled = [
            ("RD", self.recursion_desired),
            ("CD", self.checking_disabled),
            ("TC", self.truncation),
            ("AA", self.authoritative),
            ("RA", self.recursion_available),
            ("AD", self.authentic_data),
        ];

        // A separator is needed before every written label except the first one.
        let mut wrote_any = false;
        for &(label, is_set) in labelled.iter() {
            if !is_set {
                continue;
            }
            if wrote_any {
                f.write_str(",")?;
            }
            f.write_str(label)?;
            wrote_any = true;
        }
        Ok(())
    }
}
impl Default for Header {
fn default() -> Self {
Self::new()
}
}
impl Header {
    /// Creates a header with every field at its zero value: ID `0`,
    /// [`MessageType::Query`], [`OpCode::Query`], all flags cleared,
    /// [`ResponseCode::NoError`] and all four section counts `0`.
    pub const fn new() -> Self {
        Self {
            id: 0,
            message_type: MessageType::Query,
            op_code: OpCode::Query,
            authoritative: false,
            truncation: false,
            recursion_desired: false,
            recursion_available: false,
            authentic_data: false,
            checking_disabled: false,
            response_code: ResponseCode::NoError,
            query_count: 0,
            answer_count: 0,
            name_server_count: 0,
            additional_count: 0,
        }
    }

    /// Builds the header for a response to the request header `header`.
    ///
    /// Copied from the request: the ID, op code, RD and CD bits. The message type becomes
    /// [`MessageType::Response`]; AA, TC, RA and AD are cleared; the response code is
    /// `ResponseCode::default()`; all four section counts start at `0`.
    pub fn response_from_request(header: &Self) -> Self {
        Self {
            id: header.id,
            message_type: MessageType::Response,
            op_code: header.op_code,
            authoritative: false,
            truncation: false,
            recursion_desired: header.recursion_desired,
            recursion_available: false,
            authentic_data: false,
            checking_disabled: header.checking_disabled,
            response_code: ResponseCode::default(),
            query_count: 0,
            answer_count: 0,
            name_server_count: 0,
            additional_count: 0,
        }
    }

    /// Size in bytes of an encoded header.
    ///
    /// Always 12: a 16-bit ID, two flag bytes and four 16-bit section counts, which is
    /// exactly what the `BinEncodable` impl writes. It is a `const fn` so that callers can
    /// use it in constant contexts (e.g. to size fixed buffers).
    #[inline]
    pub const fn len() -> usize {
        12
    }

    /// Sets the message ID. Returns `self` for chaining.
    pub fn set_id(&mut self, id: u16) -> &mut Self {
        self.id = id;
        self
    }

    /// Marks the header as a query or a response. Returns `self` for chaining.
    pub fn set_message_type(&mut self, message_type: MessageType) -> &mut Self {
        self.message_type = message_type;
        self
    }

    /// Sets the operation code. Returns `self` for chaining.
    pub fn set_op_code(&mut self, op_code: OpCode) -> &mut Self {
        self.op_code = op_code;
        self
    }

    /// Sets the AA (authoritative answer) flag. Returns `self` for chaining.
    pub fn set_authoritative(&mut self, authoritative: bool) -> &mut Self {
        self.authoritative = authoritative;
        self
    }

    /// Sets the TC (truncation) flag. Returns `self` for chaining.
    pub fn set_truncated(&mut self, truncated: bool) -> &mut Self {
        self.truncation = truncated;
        self
    }

    /// Sets the RD (recursion desired) flag. Returns `self` for chaining.
    pub fn set_recursion_desired(&mut self, recursion_desired: bool) -> &mut Self {
        self.recursion_desired = recursion_desired;
        self
    }

    /// Sets the RA (recursion available) flag. Returns `self` for chaining.
    pub fn set_recursion_available(&mut self, recursion_available: bool) -> &mut Self {
        self.recursion_available = recursion_available;
        self
    }

    /// Sets the AD (authentic data) flag. Returns `self` for chaining.
    pub fn set_authentic_data(&mut self, authentic_data: bool) -> &mut Self {
        self.authentic_data = authentic_data;
        self
    }

    /// Sets the CD (checking disabled) flag. Returns `self` for chaining.
    pub fn set_checking_disabled(&mut self, checking_disabled: bool) -> &mut Self {
        self.checking_disabled = checking_disabled;
        self
    }

    /// Returns a copy of the six single-bit flags as a [`Flags`] value.
    pub fn flags(&self) -> Flags {
        Flags {
            authoritative: self.authoritative,
            authentic_data: self.authentic_data,
            checking_disabled: self.checking_disabled,
            recursion_available: self.recursion_available,
            recursion_desired: self.recursion_desired,
            truncation: self.truncation,
        }
    }

    /// Sets the response code. Returns `self` for chaining.
    pub fn set_response_code(&mut self, response_code: ResponseCode) -> &mut Self {
        self.response_code = response_code;
        self
    }

    /// Rebuilds the response code from `high_response_code` and the low bits currently
    /// stored in this header (`self.response_code.low()`).
    ///
    /// NOTE(review): presumably used after decoding to fold in extended response-code bits
    /// carried outside the 12-byte header (e.g. an EDNS OPT record); confirm with callers.
    #[doc(hidden)]
    pub fn merge_response_code(&mut self, high_response_code: u8) {
        self.response_code = ResponseCode::from(high_response_code, self.response_code.low());
    }

    /// Sets the query-section entry count. Returns `self` for chaining.
    pub fn set_query_count(&mut self, query_count: u16) -> &mut Self {
        self.query_count = query_count;
        self
    }

    /// Sets the answer-section record count. Returns `self` for chaining.
    pub fn set_answer_count(&mut self, answer_count: u16) -> &mut Self {
        self.answer_count = answer_count;
        self
    }

    /// Sets the name-server (authority) section record count. Returns `self` for chaining.
    pub fn set_name_server_count(&mut self, name_server_count: u16) -> &mut Self {
        self.name_server_count = name_server_count;
        self
    }

    /// Sets the additional-section record count. Returns `self` for chaining.
    pub fn set_additional_count(&mut self, additional_count: u16) -> &mut Self {
        self.additional_count = additional_count;
        self
    }

    /// The message ID.
    pub fn id(&self) -> u16 {
        self.id
    }

    /// Whether this is a query or a response.
    pub fn message_type(&self) -> MessageType {
        self.message_type
    }

    /// The operation code.
    pub fn op_code(&self) -> OpCode {
        self.op_code
    }

    /// The AA (authoritative answer) flag.
    pub fn authoritative(&self) -> bool {
        self.authoritative
    }

    /// The TC (truncation) flag.
    pub fn truncated(&self) -> bool {
        self.truncation
    }

    /// The RD (recursion desired) flag.
    pub fn recursion_desired(&self) -> bool {
        self.recursion_desired
    }

    /// The RA (recursion available) flag.
    pub fn recursion_available(&self) -> bool {
        self.recursion_available
    }

    /// The AD (authentic data) flag.
    pub fn authentic_data(&self) -> bool {
        self.authentic_data
    }

    /// The CD (checking disabled) flag.
    pub fn checking_disabled(&self) -> bool {
        self.checking_disabled
    }

    /// The response code.
    pub fn response_code(&self) -> ResponseCode {
        self.response_code
    }

    /// Number of entries in the query section.
    pub fn query_count(&self) -> u16 {
        self.query_count
    }

    /// Number of records in the answer section.
    pub fn answer_count(&self) -> u16 {
        self.answer_count
    }

    /// Number of records in the name-server (authority) section.
    pub fn name_server_count(&self) -> u16 {
        self.name_server_count
    }

    /// Number of records in the additional section.
    pub fn additional_count(&self) -> u16 {
        self.additional_count
    }
}
impl BinEncodable for Header {
fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
encoder.reserve(12)?;
encoder.emit_u16(self.id)?;
let mut q_opcd_a_t_r: u8 = if let MessageType::Response = self.message_type {
0x80
} else {
0x00
};
q_opcd_a_t_r |= u8::from(self.op_code) << 3;
q_opcd_a_t_r |= if self.authoritative { 0x4 } else { 0x0 };
q_opcd_a_t_r |= if self.truncation { 0x2 } else { 0x0 };
q_opcd_a_t_r |= if self.recursion_desired { 0x1 } else { 0x0 };
encoder.emit(q_opcd_a_t_r)?;
let mut r_z_ad_cd_rcod: u8 = if self.recursion_available {
0b1000_0000
} else {
0b0000_0000
};
r_z_ad_cd_rcod |= if self.authentic_data {
0b0010_0000
} else {
0b0000_0000
};
r_z_ad_cd_rcod |= if self.checking_disabled {
0b0001_0000
} else {
0b0000_0000
};
r_z_ad_cd_rcod |= self.response_code.low();
encoder.emit(r_z_ad_cd_rcod)?;
encoder.emit_u16(self.query_count)?;
encoder.emit_u16(self.answer_count)?;
encoder.emit_u16(self.name_server_count)?;
encoder.emit_u16(self.additional_count)?;
Ok(())
}
}
impl<'r> BinDecodable<'r> for Header {
fn read(decoder: &mut BinDecoder<'r>) -> ProtoResult<Self> {
let id = decoder.read_u16()?.unverified();
let q_opcd_a_t_r = decoder.pop()?.unverified();
let message_type = if (0b1000_0000 & q_opcd_a_t_r) == 0b1000_0000 {
MessageType::Response
} else {
MessageType::Query
};
let op_code = OpCode::from_u8((0b0111_1000 & q_opcd_a_t_r) >> 3);
let authoritative = (0b0000_0100 & q_opcd_a_t_r) == 0b0000_0100;
let truncation = (0b0000_0010 & q_opcd_a_t_r) == 0b0000_0010;
let recursion_desired = (0b0000_0001 & q_opcd_a_t_r) == 0b0000_0001;
let r_z_ad_cd_rcod = decoder.pop()?.unverified();
let recursion_available = (0b1000_0000 & r_z_ad_cd_rcod) == 0b1000_0000;
let authentic_data = (0b0010_0000 & r_z_ad_cd_rcod) == 0b0010_0000;
let checking_disabled = (0b0001_0000 & r_z_ad_cd_rcod) == 0b0001_0000;
let response_code: u8 = 0b0000_1111 & r_z_ad_cd_rcod;
let response_code = ResponseCode::from_low(response_code);
let query_count =
decoder.read_u16()?.unverified();
let answer_count =
decoder.read_u16()?.unverified();
let name_server_count =
decoder.read_u16()?.unverified();
let additional_count =
decoder.read_u16()?.unverified();
Ok(Self {
id,
message_type,
op_code,
authoritative,
truncation,
recursion_desired,
recursion_available,
authentic_data,
checking_disabled,
response_code,
query_count,
answer_count,
name_server_count,
additional_count,
})
}
}
/// Decodes a known 12-byte header and checks every field.
#[test]
fn test_parse() {
    // 0xAA = QR | OPCODE 5 (Update) | TC; 0x83 = RA | RCODE 3 (NXDomain).
    let wire = vec![
        0x01, 0x10, 0xAA, 0x83, 0x88, 0x77, 0x66, 0x55, 0x44, 0x33, 0x22, 0x11,
    ];

    let mut expected = Header::new();
    expected
        .set_id(0x0110)
        .set_message_type(MessageType::Response)
        .set_op_code(OpCode::Update)
        .set_authoritative(false)
        .set_truncated(true)
        .set_recursion_desired(false)
        .set_recursion_available(true)
        .set_authentic_data(false)
        .set_checking_disabled(false)
        .set_response_code(ResponseCode::NXDomain)
        .set_query_count(0x8877)
        .set_answer_count(0x6655)
        .set_name_server_count(0x4433)
        .set_additional_count(0x2211);

    let mut decoder = BinDecoder::new(&wire);
    let decoded = Header::read(&mut decoder).unwrap();
    assert_eq!(decoded, expected);
}
/// Encodes a known header and checks the exact 12 output bytes.
#[test]
fn test_write() {
    let mut header = Header::new();
    header
        .set_id(0x0110)
        .set_message_type(MessageType::Response)
        .set_op_code(OpCode::Update)
        .set_authoritative(false)
        .set_truncated(true)
        .set_recursion_desired(false)
        .set_recursion_available(true)
        .set_authentic_data(false)
        .set_checking_disabled(false)
        .set_response_code(ResponseCode::NXDomain)
        .set_query_count(0x8877)
        .set_answer_count(0x6655)
        .set_name_server_count(0x4433)
        .set_additional_count(0x2211);

    // 0xAA = QR | OPCODE 5 (Update) | TC; 0x83 = RA | RCODE 3 (NXDomain).
    let expected: Vec<u8> = vec![
        0x01, 0x10, 0xAA, 0x83, 0x88, 0x77, 0x66, 0x55, 0x44, 0x33, 0x22, 0x11,
    ];

    let mut wire = Vec::with_capacity(512);
    // Scope the encoder so its mutable borrow of `wire` ends before the comparison.
    {
        let mut encoder = BinEncoder::new(&mut wire);
        header.emit(&mut encoder).unwrap();
    }
    assert_eq!(wire, expected);
}