#[cfg(test)]
use alloc::vec::Vec;
use core::{convert::From, fmt, ops::Deref};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::{
error::*,
op::{op_code::OpCode, response_code::ResponseCode},
serialize::binary::*,
};
/// Fixed-size message header: per-message metadata followed by four counters.
///
/// The `BinEncodable`/`BinDecodable` impls in this file write and read it as
/// 12 bytes: a 16-bit id, two flag bytes, then four 16-bit counters.
/// `Header` dereferences to [`Metadata`], so metadata fields can be read
/// directly on a `Header` value.
///
/// With the `serde` feature enabled both members are flattened, so their
/// fields are (de)serialized at this level rather than as nested objects.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct Header {
/// Identifier, message type, opcode, flag bits and response code.
#[cfg_attr(feature = "serde", serde(flatten))]
pub metadata: Metadata,
/// The `queries`, `answers`, `authorities` and `additionals` counters.
#[cfg_attr(feature = "serde", serde(flatten))]
pub counts: HeaderCounts,
}
impl BinEncodable for Header {
    /// Writes the header in wire order: the 16-bit id, two flag bytes, then
    /// the four 16-bit counters (`queries`, `answers`, `authorities`,
    /// `additionals`).
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by the encoder.
    fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
        encoder.emit_u16(self.id)?;

        // First flag byte, most significant bit first:
        // QR (1 bit) | opcode (4 bits) | AA | TC | RD
        let mut q_opcd_a_t_r = if let MessageType::Response = self.message_type {
            0x80
        } else {
            0x00
        };
        // The opcode field is four bits wide. Mask after shifting so that a
        // numeric opcode above 15 cannot spill into the QR bit (0x80); this
        // mirrors the `0b0111_1000` mask applied when decoding.
        // NOTE(review): `OpCode` is defined elsewhere — confirm whether its
        // `u8` conversion can actually yield values above 15.
        q_opcd_a_t_r |= (u8::from(self.op_code) << 3) & 0b0111_1000;
        q_opcd_a_t_r |= if self.authoritative { 0x4 } else { 0x0 };
        q_opcd_a_t_r |= if self.truncation { 0x2 } else { 0x0 };
        q_opcd_a_t_r |= if self.recursion_desired { 0x1 } else { 0x0 };
        encoder.emit(q_opcd_a_t_r)?;

        // Second flag byte, most significant bit first:
        // RA | Z (always written as 0) | AD | CD | low response-code bits
        let mut r_z_ad_cd_rcod = if self.recursion_available {
            0b1000_0000
        } else {
            0b0000_0000
        };
        r_z_ad_cd_rcod |= if self.authentic_data {
            0b0010_0000
        } else {
            0b0000_0000
        };
        r_z_ad_cd_rcod |= if self.checking_disabled {
            0b0001_0000
        } else {
            0b0000_0000
        };
        // Only the low part of the response code is stored in the header.
        r_z_ad_cd_rcod |= self.response_code.low();
        encoder.emit(r_z_ad_cd_rcod)?;

        encoder.emit_u16(self.counts.queries)?;
        encoder.emit_u16(self.counts.answers)?;
        encoder.emit_u16(self.counts.authorities)?;
        encoder.emit_u16(self.counts.additionals)?;

        Ok(())
    }
}
impl<'r> BinDecodable<'r> for Header {
    /// Reads a header in wire order: the 16-bit id, two flag bytes, then the
    /// four 16-bit counters.
    ///
    /// The bit at `0b0100_0000` of the second flag byte is ignored.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by the decoder.
    fn read(decoder: &mut BinDecoder<'r>) -> Result<Self, DecodeError> {
        // Masks for the first flag byte.
        const QR: u8 = 0b1000_0000;
        const OPCODE: u8 = 0b0111_1000;
        const AA: u8 = 0b0000_0100;
        const TC: u8 = 0b0000_0010;
        const RD: u8 = 0b0000_0001;
        // Masks for the second flag byte.
        const RA: u8 = 0b1000_0000;
        const AD: u8 = 0b0010_0000;
        const CD: u8 = 0b0001_0000;
        const RCODE_LOW: u8 = 0b0000_1111;

        // Every flag mask above is a single bit, so "non-zero" means "set".
        let is_set = |byte: u8, mask: u8| byte & mask != 0;

        let id = decoder.read_u16()?.unverified();

        let first = decoder.pop()?.unverified();
        let message_type = if is_set(first, QR) {
            MessageType::Response
        } else {
            MessageType::Query
        };
        let op_code = OpCode::from_u8((first & OPCODE) >> 3);

        let second = decoder.pop()?.unverified();
        let response_code = ResponseCode::from_low(second & RCODE_LOW);

        Ok(Self {
            metadata: Metadata {
                id,
                message_type,
                op_code,
                authoritative: is_set(first, AA),
                truncation: is_set(first, TC),
                recursion_desired: is_set(first, RD),
                recursion_available: is_set(second, RA),
                authentic_data: is_set(second, AD),
                checking_disabled: is_set(second, CD),
                response_code,
            },
            // Field initializers run top to bottom, which keeps the reads in
            // wire order.
            counts: HeaderCounts {
                queries: decoder.read_u16()?.unverified(),
                answers: decoder.read_u16()?.unverified(),
                authorities: decoder.read_u16()?.unverified(),
                additionals: decoder.read_u16()?.unverified(),
            },
        })
    }
}
impl EncodedSize for Header {
    /// Encoded length in bytes: id (2) + two flag bytes (2) + four 16-bit
    /// counters (8).
    const LEN: usize = 2 + 2 + 4 * 2;
}
/// Lets metadata fields be read directly on a `Header` (for example
/// `header.id`).
impl Deref for Header {
    type Target = Metadata;

    /// Borrows the embedded [`Metadata`].
    fn deref(&self) -> &Self::Target {
        let Self { metadata, .. } = self;
        metadata
    }
}
/// The non-counter part of a [`Header`]: identifier, message type, opcode,
/// flag bits and response code.
///
/// Marked `#[non_exhaustive]`, so code outside this crate cannot build it with
/// a struct literal; use [`Metadata::new`] or
/// [`Metadata::response_from_request`] instead.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct Metadata {
/// 16-bit message identifier; the first two bytes of the encoded header.
pub id: u16,
/// Query or response; encoded as the top bit (`0x80`) of the first flag byte.
pub message_type: MessageType,
/// Operation code; encoded in bits `0b0111_1000` of the first flag byte.
pub op_code: OpCode,
/// "AA" flag; bit `0x04` of the first flag byte.
pub authoritative: bool,
/// "TC" flag; bit `0x02` of the first flag byte.
pub truncation: bool,
/// "RD" flag; bit `0x01` of the first flag byte.
pub recursion_desired: bool,
/// "RA" flag; bit `0x80` of the second flag byte.
pub recursion_available: bool,
/// "AD" flag; bit `0x20` of the second flag byte.
pub authentic_data: bool,
/// "CD" flag; bit `0x10` of the second flag byte.
pub checking_disabled: bool,
/// Response code; only its low part (`ResponseCode::low`) is written into
/// the header, in the low four bits of the second flag byte.
pub response_code: ResponseCode,
}
impl Metadata {
    /// Creates metadata with the given id, message type and opcode.
    ///
    /// Every flag starts cleared and the response code is
    /// `ResponseCode::NoError`.
    pub const fn new(id: u16, message_type: MessageType, op_code: OpCode) -> Self {
        Self {
            id,
            message_type,
            op_code,
            response_code: ResponseCode::NoError,
            // All flag bits start out unset.
            authoritative: false,
            truncation: false,
            recursion_desired: false,
            recursion_available: false,
            authentic_data: false,
            checking_disabled: false,
        }
    }

    /// Builds response metadata from the metadata of a request.
    ///
    /// Copies `id`, `op_code`, `recursion_desired` and `checking_disabled`
    /// from `req`, sets the message type to `Response`, clears the remaining
    /// flags and uses `ResponseCode::default()` as the response code.
    pub fn response_from_request(req: &Self) -> Self {
        Self {
            recursion_desired: req.recursion_desired,
            checking_disabled: req.checking_disabled,
            response_code: ResponseCode::default(),
            // `new` supplies the id/type/opcode and leaves the other flags
            // cleared.
            ..Self::new(req.id, MessageType::Response, req.op_code)
        }
    }

    /// Returns a copy of the six boolean flags as a [`Flags`] value.
    pub fn flags(&self) -> Flags {
        let Self {
            authoritative,
            truncation,
            recursion_desired,
            recursion_available,
            authentic_data,
            checking_disabled,
            ..
        } = *self;

        Flags {
            authoritative,
            truncation,
            recursion_desired,
            recursion_available,
            authentic_data,
            checking_disabled,
        }
    }

    /// Rebuilds `response_code` from `high_response_code` and the low part of
    /// the response code currently stored.
    #[doc(hidden)]
    pub fn merge_response_code(&mut self, high_response_code: u8) {
        let low_response_code = self.response_code.low();
        self.response_code = ResponseCode::from(high_response_code, low_response_code);
    }
}
impl fmt::Display for Metadata {
    /// Formats as `id:message_type:flags:response_code:op_code`, with the
    /// response code rendered through its `Debug` impl.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            id,
            message_type,
            op_code,
            response_code,
            ..
        } = self;
        let flags = self.flags();

        write!(
            f,
            "{id}:{message_type}:{flags}:{code:?}:{op_code}",
            id = id,
            message_type = message_type,
            flags = flags,
            code = response_code,
            op_code = op_code,
        )
    }
}
/// The four 16-bit counters that make up the last eight bytes of an encoded
/// [`Header`].
///
/// They are written and read in declaration order: `queries`, `answers`,
/// `authorities`, `additionals`. `Default` yields all zeros.
#[derive(Clone, Copy, Debug, Default, PartialEq, PartialOrd, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub struct HeaderCounts {
/// Query count; first counter on the wire.
pub queries: u16,
/// Answer count; second counter on the wire.
pub answers: u16,
/// Authority count; third counter on the wire.
pub authorities: u16,
/// Additional count; fourth counter on the wire.
pub additionals: u16,
}
/// Whether a message is a query or a response.
///
/// In the encoded header this is the top bit of the first flag byte: set for
/// `Response`, clear for `Query`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Copy, Clone, Hash)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub enum MessageType {
    /// Displayed as `QUERY`.
    Query,
    /// Displayed as `RESPONSE`.
    Response,
}

impl fmt::Display for MessageType {
    /// Writes the upper-case name of the variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            MessageType::Query => f.write_str("QUERY"),
            MessageType::Response => f.write_str("RESPONSE"),
        }
    }
}
/// Snapshot of the six boolean header flags, used for display.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct Flags {
    /// Displayed as `AA`.
    authoritative: bool,
    /// Displayed as `TC`.
    truncation: bool,
    /// Displayed as `RD`.
    recursion_desired: bool,
    /// Displayed as `RA`.
    recursion_available: bool,
    /// Displayed as `AD`.
    authentic_data: bool,
    /// Displayed as `CD`.
    checking_disabled: bool,
}

impl fmt::Display for Flags {
    /// Writes the abbreviations of the flags that are set, comma-separated,
    /// in the fixed order `RD,CD,TC,AA,RA,AD`. Writes nothing when no flag is
    /// set.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const SEPARATOR: &str = ",";

        let labelled = [
            (self.recursion_desired, "RD"),
            (self.checking_disabled, "CD"),
            (self.truncation, "TC"),
            (self.authoritative, "AA"),
            (self.recursion_available, "RA"),
            (self.authentic_data, "AD"),
        ];

        // A separator goes before every label except the first one written.
        let mut needs_separator = false;
        for &(is_set, label) in labelled.iter() {
            if !is_set {
                continue;
            }
            if needs_separator {
                f.write_str(SEPARATOR)?;
            }
            f.write_str(label)?;
            needs_separator = true;
        }

        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Wire image of `sample_header()`: id, two flag bytes, four counters.
    fn wire_bytes() -> Vec<u8> {
        vec![
            0x01, 0x10, 0xAA, 0x83, 0x88, 0x77, 0x66, 0x55, 0x44, 0x33, 0x22, 0x11,
        ]
    }

    /// Header value that corresponds to `wire_bytes()`.
    fn sample_header() -> Header {
        let metadata = Metadata {
            id: 0x0110,
            message_type: MessageType::Response,
            op_code: OpCode::Update,
            authoritative: false,
            truncation: true,
            recursion_desired: false,
            recursion_available: true,
            authentic_data: false,
            checking_disabled: false,
            response_code: ResponseCode::NXDomain,
        };
        let counts = HeaderCounts {
            queries: 0x8877,
            answers: 0x6655,
            authorities: 0x4433,
            additionals: 0x2211,
        };
        Header { metadata, counts }
    }

    #[test]
    fn test_parse() {
        let wire = wire_bytes();
        let mut decoder = BinDecoder::new(&wire);

        let got = Header::read(&mut decoder).unwrap();

        assert_eq!(got, sample_header());
    }

    #[test]
    fn test_write() {
        let header = sample_header();
        let mut bytes = Vec::with_capacity(512);
        {
            // Scope the encoder so its borrow of `bytes` ends before the
            // comparison below.
            let mut encoder = BinEncoder::new(&mut bytes);
            header.emit(&mut encoder).unwrap();
        }

        assert_eq!(bytes, wire_bytes());
    }
}