use super::iana::{Opcode, Rcode};
use super::octets::{
Compose, OctetsBuilder, Parse, ParseError, Parser, ShortBuf,
};
use core::convert::TryInto;
use core::{fmt, mem, str::FromStr};
/// The first four octets of a DNS message header.
///
/// These octets hold the message ID followed by the two octets that carry
/// the flag bits, the opcode and the rcode.
///
/// The type is `#[repr(transparent)]` over `[u8; 4]` so that the start of a
/// message buffer can be reinterpreted in place as a `Header` (see
/// [`for_message_slice`][Self::for_message_slice]).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[repr(transparent)]
pub struct Header {
    inner: [u8; 4],
}

impl Header {
    /// Creates a new header with all four octets set to zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a reference to the header at the start of a message slice.
    ///
    /// # Panics
    ///
    /// Panics if `s` is shorter than four octets.
    pub fn for_message_slice(s: &[u8]) -> &Header {
        assert!(s.len() >= mem::size_of::<Header>());
        // SAFETY: `Header` is `repr(transparent)` over `[u8; 4]`, so it has
        // the size of four octets and an alignment of 1. The assertion above
        // guarantees `s` holds at least that many octets, and the returned
        // reference borrows `s` for its whole lifetime.
        unsafe { &*(s.as_ptr() as *const Header) }
    }

    /// Returns a mutable reference to the header at the start of a message
    /// slice. Changes made through it are written directly into `s`.
    ///
    /// # Panics
    ///
    /// Panics if `s` is shorter than four octets.
    pub fn for_message_slice_mut(s: &mut [u8]) -> &mut Header {
        assert!(s.len() >= mem::size_of::<Header>());
        // SAFETY: layout and length are as in `for_message_slice`. The
        // pointer comes from `as_mut_ptr` on the unique borrow. Writing
        // through a pointer obtained from `as_ptr` (a shared reborrow)
        // would be undefined behaviour.
        unsafe { &mut *(s.as_mut_ptr() as *mut Header) }
    }

    /// Returns the four header octets as a slice.
    pub fn as_slice(&self) -> &[u8] {
        &self.inner
    }
}
impl Header {
    /// Returns the value of the ID field.
    ///
    /// The ID occupies the first two octets and is stored in big-endian
    /// (network) byte order. For outgoing queries the ID should be hard to
    /// guess. The method
    #[cfg_attr(
        feature = "random",
        doc = "[`set_random_id`][Self::set_random_id]"
    )]
    #[cfg_attr(not(feature = "random"), doc = "`set_random_id`")]
    /// (only available with the `random` feature) sets a random ID.
    pub fn id(self) -> u16 {
        u16::from_be_bytes(self.inner[..2].try_into().unwrap())
    }

    /// Sets the ID field, stored big-endian in the first two octets.
    pub fn set_id(&mut self, value: u16) {
        self.inner[..2].copy_from_slice(&value.to_be_bytes())
    }

    /// Sets the ID field to a random value.
    #[cfg(feature = "random")]
    pub fn set_random_id(&mut self) {
        self.set_id(::rand::random())
    }

    /// Returns whether the QR bit (most significant bit of octet 2) is set.
    pub fn qr(self) -> bool {
        self.get_bit(2, 7)
    }

    /// Sets or clears the QR bit.
    pub fn set_qr(&mut self, set: bool) {
        self.set_bit(2, 7, set)
    }

    /// Returns the opcode held in bits 6 to 3 of octet 2.
    pub fn opcode(self) -> Opcode {
        Opcode::from_int((self.inner[2] >> 3) & 0x0F)
    }

    /// Sets the opcode in bits 6 to 3 of octet 2.
    ///
    /// Only four bits are available, so any higher bits of the opcode's
    /// integer value are discarded, mirroring
    /// [`set_rcode`][Self::set_rcode]. QR, AA, TC and RD are left untouched.
    pub fn set_opcode(&mut self, opcode: Opcode) {
        // 0x87 keeps QR (bit 7) and AA/TC/RD (bits 2..0). Masking the value
        // to four bits stops an out-of-range opcode from overwriting QR.
        self.inner[2] =
            (self.inner[2] & 0x87) | ((opcode.to_int() & 0x0F) << 3);
    }

    /// Returns all flag bits at once.
    ///
    /// The Z bit is not part of [`Flags`]; use [`z`][Self::z] for it.
    pub fn flags(self) -> Flags {
        Flags {
            qr: self.qr(),
            aa: self.aa(),
            tc: self.tc(),
            rd: self.rd(),
            ra: self.ra(),
            ad: self.ad(),
            cd: self.cd(),
        }
    }

    /// Sets all flag bits from `flags`.
    ///
    /// The opcode, the Z bit and the rcode are not modified.
    pub fn set_flags(&mut self, flags: Flags) {
        self.set_qr(flags.qr);
        self.set_aa(flags.aa);
        self.set_tc(flags.tc);
        self.set_rd(flags.rd);
        self.set_ra(flags.ra);
        self.set_ad(flags.ad);
        self.set_cd(flags.cd);
    }

    /// Returns whether the AA bit (bit 2 of octet 2) is set.
    pub fn aa(self) -> bool {
        self.get_bit(2, 2)
    }

    /// Sets or clears the AA bit.
    pub fn set_aa(&mut self, set: bool) {
        self.set_bit(2, 2, set)
    }

    /// Returns whether the TC bit (bit 1 of octet 2) is set.
    pub fn tc(self) -> bool {
        self.get_bit(2, 1)
    }

    /// Sets or clears the TC bit.
    pub fn set_tc(&mut self, set: bool) {
        self.set_bit(2, 1, set)
    }

    /// Returns whether the RD bit (bit 0 of octet 2) is set.
    pub fn rd(self) -> bool {
        self.get_bit(2, 0)
    }

    /// Sets or clears the RD bit.
    pub fn set_rd(&mut self, set: bool) {
        self.set_bit(2, 0, set)
    }

    /// Returns whether the RA bit (most significant bit of octet 3) is set.
    pub fn ra(self) -> bool {
        self.get_bit(3, 7)
    }

    /// Sets or clears the RA bit.
    pub fn set_ra(&mut self, set: bool) {
        self.set_bit(3, 7, set)
    }

    /// Returns whether the Z bit (bit 6 of octet 3) is set.
    pub fn z(self) -> bool {
        self.get_bit(3, 6)
    }

    /// Sets or clears the Z bit.
    pub fn set_z(&mut self, set: bool) {
        self.set_bit(3, 6, set)
    }

    /// Returns whether the AD bit (bit 5 of octet 3) is set.
    pub fn ad(self) -> bool {
        self.get_bit(3, 5)
    }

    /// Sets or clears the AD bit.
    pub fn set_ad(&mut self, set: bool) {
        self.set_bit(3, 5, set)
    }

    /// Returns whether the CD bit (bit 4 of octet 3) is set.
    pub fn cd(self) -> bool {
        self.get_bit(3, 4)
    }

    /// Sets or clears the CD bit.
    pub fn set_cd(&mut self, set: bool) {
        self.set_bit(3, 4, set)
    }

    /// Returns the rcode held in the lower four bits of octet 3.
    pub fn rcode(self) -> Rcode {
        Rcode::from_int(self.inner[3] & 0x0F)
    }

    /// Sets the lower four bits of octet 3 to the rcode.
    ///
    /// Bits of the rcode's integer value above the lowest four are discarded.
    /// RA, Z, AD and CD are left untouched.
    pub fn set_rcode(&mut self, rcode: Rcode) {
        self.inner[3] = self.inner[3] & 0xF0 | (rcode.to_int() & 0x0F);
    }

    /// Returns bit `bit` (0 = least significant) of header octet `offset`.
    fn get_bit(self, offset: usize, bit: usize) -> bool {
        self.inner[offset] & (1 << bit) != 0
    }

    /// Sets or clears bit `bit` (0 = least significant) of octet `offset`.
    fn set_bit(&mut self, offset: usize, bit: usize, set: bool) {
        if set {
            self.inner[offset] |= 1 << bit
        } else {
            self.inner[offset] &= !(1 << bit)
        }
    }
}
/// The flag bits of a message header, one public field per flag.
///
/// The Z bit is not represented here.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Hash)]
pub struct Flags {
    /// The QR bit.
    pub qr: bool,
    /// The AA bit.
    pub aa: bool,
    /// The TC bit.
    pub tc: bool,
    /// The RD bit.
    pub rd: bool,
    /// The RA bit.
    pub ra: bool,
    /// The AD bit.
    pub ad: bool,
    /// The CD bit.
    pub cd: bool,
}

impl Flags {
    /// Creates a value with every flag cleared.
    pub fn new() -> Self {
        Default::default()
    }
}

impl fmt::Display for Flags {
    /// Writes the mnemonics of all set flags, in header order and separated
    /// by single spaces. If no flag is set, nothing is written.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let table = [
            (self.qr, "QR"),
            (self.aa, "AA"),
            (self.tc, "TC"),
            (self.rd, "RD"),
            (self.ra, "RA"),
            (self.ad, "AD"),
            (self.cd, "CD"),
        ];
        let mut first = true;
        for (_, mnemonic) in table.iter().filter(|(is_set, _)| *is_set) {
            if !first {
                f.write_str(" ")?;
            }
            f.write_str(mnemonic)?;
            first = false;
        }
        Ok(())
    }
}
impl FromStr for Flags {
    type Err = FlagsFromStrError;

    /// Parses a whitespace-separated list of flag mnemonics.
    ///
    /// Each token must be one of `QR`, `AA`, `TC`, `RD`, `RA`, `AD` or `CD`,
    /// compared ASCII case-insensitively. Tokens may be separated by any run
    /// of ASCII whitespace (spaces, tabs, newlines), so output of the
    /// `Display` impl parses back unchanged. An empty or all-whitespace
    /// string yields flags with nothing set. Repeated tokens are accepted.
    ///
    /// # Errors
    ///
    /// Returns [`FlagsFromStrError`] if any token is not a known mnemonic.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mut flags = Flags::new();
        // `split_ascii_whitespace` never yields empty tokens, so runs of
        // separators and leading or trailing whitespace need no special case.
        for token in s.split_ascii_whitespace() {
            match token {
                t if t.eq_ignore_ascii_case("QR") => flags.qr = true,
                t if t.eq_ignore_ascii_case("AA") => flags.aa = true,
                t if t.eq_ignore_ascii_case("TC") => flags.tc = true,
                t if t.eq_ignore_ascii_case("RD") => flags.rd = true,
                t if t.eq_ignore_ascii_case("RA") => flags.ra = true,
                t if t.eq_ignore_ascii_case("AD") => flags.ad = true,
                t if t.eq_ignore_ascii_case("CD") => flags.cd = true,
                _ => return Err(FlagsFromStrError),
            }
        }
        Ok(flags)
    }
}
/// The four 16-bit section counts that follow the first four octets of a
/// message header.
///
/// The type is `#[repr(transparent)]` over `[u8; 8]` so that it can be
/// reinterpreted in place within a message buffer (see
/// [`for_message_slice`][Self::for_message_slice]).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[repr(transparent)]
pub struct HeaderCounts {
    inner: [u8; 8],
}

impl HeaderCounts {
    /// Creates a new value with all counts set to zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a reference to the counts inside a message slice.
    ///
    /// The counts are the eight octets directly after the four-octet
    /// [`Header`].
    ///
    /// # Panics
    ///
    /// Panics if `message` is shorter than a full [`HeaderSection`]
    /// (twelve octets).
    pub fn for_message_slice(message: &[u8]) -> &Self {
        assert!(message.len() >= mem::size_of::<HeaderSection>());
        // SAFETY: `HeaderCounts` is `repr(transparent)` over `[u8; 8]`
        // (alignment 1). The assertion guarantees the eight octets after the
        // four-octet header are in bounds, and the result borrows `message`.
        unsafe {
            &*((message[mem::size_of::<Header>()..].as_ptr())
                as *const HeaderCounts)
        }
    }

    /// Returns a mutable reference to the counts inside a message slice.
    /// Changes made through it are written directly into `message`.
    ///
    /// # Panics
    ///
    /// Panics if `message` is shorter than a full [`HeaderSection`]
    /// (twelve octets).
    pub fn for_message_slice_mut(message: &mut [u8]) -> &mut Self {
        assert!(message.len() >= mem::size_of::<HeaderSection>());
        // SAFETY: layout and bounds are as in `for_message_slice`. The
        // pointer comes from `as_mut_ptr` on the unique borrow, which makes
        // writes through the returned reference permitted.
        unsafe {
            &mut *((message[mem::size_of::<Header>()..].as_mut_ptr())
                as *mut HeaderCounts)
        }
    }

    /// Returns the eight count octets as a slice.
    pub fn as_slice(&self) -> &[u8] {
        &self.inner
    }

    /// Returns the eight count octets as a mutable slice.
    pub fn as_slice_mut(&mut self) -> &mut [u8] {
        &mut self.inner
    }

    /// Overwrites all four counts with those from `counts`.
    pub fn set(&mut self, counts: HeaderCounts) {
        self.as_slice_mut().copy_from_slice(counts.as_slice())
    }
}
impl HeaderCounts {
    // Byte offsets of the four big-endian counters within `inner`.
    const QDCOUNT: usize = 0;
    const ANCOUNT: usize = 2;
    const NSCOUNT: usize = 4;
    const ARCOUNT: usize = 6;

    /// Returns the QDCOUNT field.
    pub fn qdcount(self) -> u16 {
        self.get_u16(Self::QDCOUNT)
    }

    /// Sets the QDCOUNT field.
    pub fn set_qdcount(&mut self, value: u16) {
        self.set_u16(Self::QDCOUNT, value)
    }

    /// Adds one to QDCOUNT. Returns `Err(ShortBuf)` and leaves the value
    /// unchanged if it is already `u16::MAX`.
    pub fn inc_qdcount(&mut self) -> Result<(), ShortBuf> {
        self.inc_u16(Self::QDCOUNT)
    }

    /// Subtracts one from QDCOUNT. Panics if it is zero.
    pub fn dec_qdcount(&mut self) {
        self.dec_u16(Self::QDCOUNT)
    }

    /// Returns the ANCOUNT field.
    pub fn ancount(self) -> u16 {
        self.get_u16(Self::ANCOUNT)
    }

    /// Sets the ANCOUNT field.
    pub fn set_ancount(&mut self, value: u16) {
        self.set_u16(Self::ANCOUNT, value)
    }

    /// Adds one to ANCOUNT. Returns `Err(ShortBuf)` and leaves the value
    /// unchanged if it is already `u16::MAX`.
    pub fn inc_ancount(&mut self) -> Result<(), ShortBuf> {
        self.inc_u16(Self::ANCOUNT)
    }

    /// Subtracts one from ANCOUNT. Panics if it is zero.
    pub fn dec_ancount(&mut self) {
        self.dec_u16(Self::ANCOUNT)
    }

    /// Returns the NSCOUNT field.
    pub fn nscount(self) -> u16 {
        self.get_u16(Self::NSCOUNT)
    }

    /// Sets the NSCOUNT field.
    pub fn set_nscount(&mut self, value: u16) {
        self.set_u16(Self::NSCOUNT, value)
    }

    /// Adds one to NSCOUNT. Returns `Err(ShortBuf)` and leaves the value
    /// unchanged if it is already `u16::MAX`.
    pub fn inc_nscount(&mut self) -> Result<(), ShortBuf> {
        self.inc_u16(Self::NSCOUNT)
    }

    /// Subtracts one from NSCOUNT. Panics if it is zero.
    pub fn dec_nscount(&mut self) {
        self.dec_u16(Self::NSCOUNT)
    }

    /// Returns the ARCOUNT field.
    pub fn arcount(self) -> u16 {
        self.get_u16(Self::ARCOUNT)
    }

    /// Sets the ARCOUNT field.
    pub fn set_arcount(&mut self, value: u16) {
        self.set_u16(Self::ARCOUNT, value)
    }

    /// Adds one to ARCOUNT. Returns `Err(ShortBuf)` and leaves the value
    /// unchanged if it is already `u16::MAX`.
    pub fn inc_arcount(&mut self) -> Result<(), ShortBuf> {
        self.inc_u16(Self::ARCOUNT)
    }

    /// Subtracts one from ARCOUNT. Panics if it is zero.
    pub fn dec_arcount(&mut self) {
        self.dec_u16(Self::ARCOUNT)
    }

    /// UPDATE-message alias for [`qdcount`][Self::qdcount] (ZOCOUNT).
    pub fn zocount(self) -> u16 {
        self.qdcount()
    }

    /// UPDATE-message alias for [`set_qdcount`][Self::set_qdcount].
    pub fn set_zocount(&mut self, value: u16) {
        self.set_qdcount(value)
    }

    /// UPDATE-message alias for [`ancount`][Self::ancount] (PRCOUNT).
    pub fn prcount(self) -> u16 {
        self.ancount()
    }

    /// UPDATE-message alias for [`set_ancount`][Self::set_ancount].
    pub fn set_prcount(&mut self, value: u16) {
        self.set_ancount(value)
    }

    /// UPDATE-message alias for [`nscount`][Self::nscount] (UPCOUNT).
    pub fn upcount(self) -> u16 {
        self.nscount()
    }

    /// UPDATE-message alias for [`set_nscount`][Self::set_nscount].
    pub fn set_upcount(&mut self, value: u16) {
        self.set_nscount(value)
    }

    /// UPDATE-message alias for [`arcount`][Self::arcount] (ADCOUNT).
    pub fn adcount(self) -> u16 {
        self.arcount()
    }

    /// UPDATE-message alias for [`set_arcount`][Self::set_arcount].
    pub fn set_adcount(&mut self, value: u16) {
        self.set_arcount(value)
    }

    /// Adds one to the counter at `offset`, refusing to wrap around.
    fn inc_u16(&mut self, offset: usize) -> Result<(), ShortBuf> {
        let next = self.get_u16(offset).checked_add(1).ok_or(ShortBuf)?;
        self.set_u16(offset, next);
        Ok(())
    }

    /// Subtracts one from the counter at `offset`, panicking at zero.
    fn dec_u16(&mut self, offset: usize) {
        let count = self.get_u16(offset);
        assert!(count > 0);
        self.set_u16(offset, count - 1);
    }

    /// Reads the big-endian `u16` starting at byte `offset`.
    fn get_u16(self, offset: usize) -> u16 {
        u16::from_be_bytes([self.inner[offset], self.inner[offset + 1]])
    }

    /// Writes `value` big-endian starting at byte `offset`.
    fn set_u16(&mut self, offset: usize, value: u16) {
        let [high, low] = value.to_be_bytes();
        self.inner[offset] = high;
        self.inner[offset + 1] = low;
    }
}
/// The complete twelve-octet header section of a DNS message.
///
/// It consists of a [`Header`] followed by [`HeaderCounts`]. The type is
/// `#[repr(transparent)]` over `[u8; 12]` so that the start of a message
/// buffer can be reinterpreted in place (see
/// [`for_message_slice`][Self::for_message_slice]).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[repr(transparent)]
pub struct HeaderSection {
    inner: [u8; 12],
}

impl HeaderSection {
    /// Creates a new header section with all octets set to zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a reference to the header section at the start of a message
    /// slice.
    ///
    /// # Panics
    ///
    /// Panics if `s` is shorter than twelve octets.
    pub fn for_message_slice(s: &[u8]) -> &HeaderSection {
        assert!(s.len() >= mem::size_of::<HeaderSection>());
        // SAFETY: `HeaderSection` is `repr(transparent)` over `[u8; 12]`
        // (alignment 1), and the assertion guarantees twelve in-bounds
        // octets. The result borrows `s`.
        unsafe { &*(s.as_ptr() as *const HeaderSection) }
    }

    /// Returns a mutable reference to the header section at the start of a
    /// message slice. Changes made through it are written directly into `s`.
    ///
    /// # Panics
    ///
    /// Panics if `s` is shorter than twelve octets.
    pub fn for_message_slice_mut(s: &mut [u8]) -> &mut HeaderSection {
        assert!(s.len() >= mem::size_of::<HeaderSection>());
        // SAFETY: layout and length are as in `for_message_slice`. The
        // pointer comes from `as_mut_ptr` on the unique borrow, which makes
        // writes through the returned reference permitted.
        unsafe { &mut *(s.as_mut_ptr() as *mut HeaderSection) }
    }

    /// Returns the twelve header octets as a slice.
    pub fn as_slice(&self) -> &[u8] {
        &self.inner
    }
}
impl HeaderSection {
    /// Returns the ID/flags part (the first four octets) of the section.
    pub fn header(&self) -> &Header {
        Header::for_message_slice(self.as_slice())
    }

    /// Returns the ID/flags part of the section for modification.
    pub fn header_mut(&mut self) -> &mut Header {
        let octets: &mut [u8] = &mut self.inner;
        Header::for_message_slice_mut(octets)
    }

    /// Returns the section counts (the last eight octets).
    pub fn counts(&self) -> &HeaderCounts {
        HeaderCounts::for_message_slice(self.as_slice())
    }

    /// Returns the section counts for modification.
    pub fn counts_mut(&mut self) -> &mut HeaderCounts {
        let octets: &mut [u8] = &mut self.inner;
        HeaderCounts::for_message_slice_mut(octets)
    }
}
impl AsRef<Header> for HeaderSection {
fn as_ref(&self) -> &Header {
self.header()
}
}
impl AsMut<Header> for HeaderSection {
fn as_mut(&mut self) -> &mut Header {
self.header_mut()
}
}
impl AsRef<HeaderCounts> for HeaderSection {
fn as_ref(&self) -> &HeaderCounts {
self.counts()
}
}
impl AsMut<HeaderCounts> for HeaderSection {
fn as_mut(&mut self) -> &mut HeaderCounts {
self.counts_mut()
}
}
impl<Ref: AsRef<[u8]>> Parse<Ref> for HeaderSection {
fn parse(parser: &mut Parser<Ref>) -> Result<Self, ParseError> {
let mut res = Self::default();
parser.parse_buf(&mut res.inner)?;
Ok(res)
}
fn skip(parser: &mut Parser<Ref>) -> Result<(), ParseError> {
parser.advance(12)
}
}
impl Compose for HeaderSection {
    /// Appends the twelve header octets to `target`.
    fn compose<T: OctetsBuilder>(
        &self,
        target: &mut T,
    ) -> Result<(), ShortBuf> {
        let octets = self.as_slice();
        target.append_slice(octets)
    }
}
/// The error returned when parsing [`Flags`] from a string encounters a
/// token that is not a known flag mnemonic.
#[derive(Debug)]
pub struct FlagsFromStrError;

impl fmt::Display for FlagsFromStrError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("illegal flags token")
    }
}

#[cfg(feature = "std")]
impl std::error::Error for FlagsFromStrError {}
#[cfg(test)]
mod test {
    use super::*;
    use crate::base::iana::{Opcode, Rcode};

    #[test]
    #[cfg(feature = "std")]
    fn for_slice() {
        use std::vec::Vec;
        let header = b"\x01\x02\x00\x00\x12\x34\x56\x78\x9a\xbc\xde\xf0";
        let mut vec = Vec::from(&header[..]);
        assert_eq!(
            Header::for_message_slice(header).as_slice(),
            b"\x01\x02\x00\x00"
        );
        assert_eq!(
            Header::for_message_slice_mut(vec.as_mut()).as_slice(),
            b"\x01\x02\x00\x00"
        );
        assert_eq!(
            HeaderCounts::for_message_slice(header).as_slice(),
            b"\x12\x34\x56\x78\x9a\xbc\xde\xf0"
        );
        assert_eq!(
            HeaderCounts::for_message_slice_mut(vec.as_mut()).as_slice(),
            b"\x12\x34\x56\x78\x9a\xbc\xde\xf0"
        );
        assert_eq!(
            HeaderSection::for_message_slice(header).as_slice(),
            header
        );
        assert_eq!(
            HeaderSection::for_message_slice_mut(vec.as_mut()).as_slice(),
            header
        );
    }

    /// Writes made through the `_mut` slice views must land in the caller's
    /// buffer. Running this under Miri also checks the unsafe casts.
    #[test]
    fn for_slice_mut_writes_through() {
        let mut buf = [0u8; 12];
        Header::for_message_slice_mut(&mut buf).set_id(0x1234);
        HeaderCounts::for_message_slice_mut(&mut buf).set_arcount(0xabcd);
        assert_eq!(&buf[..2], b"\x12\x34");
        assert_eq!(&buf[10..], b"\xab\xcd");
        HeaderSection::for_message_slice_mut(&mut buf)
            .header_mut()
            .set_qr(true);
        assert_eq!(buf[2], 0x80);
        HeaderSection::for_message_slice_mut(&mut buf)
            .counts_mut()
            .set_qdcount(0x0102);
        assert_eq!(&buf[4..6], b"\x01\x02");
    }

    #[test]
    #[should_panic]
    fn short_header() {
        Header::for_message_slice(b"134");
    }

    #[test]
    #[should_panic]
    fn short_header_counts() {
        HeaderCounts::for_message_slice(b"12345678");
    }

    #[test]
    #[should_panic]
    fn short_header_section() {
        HeaderSection::for_message_slice(b"1234");
    }

    /// For each value, checks that a fresh header reports `$default`, then
    /// sets the value and checks that it reads back unchanged.
    macro_rules! test_field {
        ($get:ident, $set:ident, $default:expr, $($value:expr),*) => {
            $({
                let mut h = Header::new();
                assert_eq!(h.$get(), $default);
                h.$set($value);
                assert_eq!(h.$get(), $value);
            })*
        }
    }

    #[test]
    // The macro compares boolean getters with `assert_eq!` on purpose so
    // that every field type goes through the same code path.
    #[allow(clippy::bool_assert_comparison)]
    fn header() {
        test_field!(id, set_id, 0, 0x1234);
        test_field!(qr, set_qr, false, true, false);
        test_field!(opcode, set_opcode, Opcode::Query, Opcode::Notify);
        test_field!(
            flags,
            set_flags,
            Flags::new(),
            Flags {
                qr: true,
                ..Default::default()
            }
        );
        test_field!(aa, set_aa, false, true, false);
        test_field!(tc, set_tc, false, true, false);
        test_field!(rd, set_rd, false, true, false);
        test_field!(ra, set_ra, false, true, false);
        test_field!(z, set_z, false, true, false);
        test_field!(ad, set_ad, false, true, false);
        test_field!(cd, set_cd, false, true, false);
        test_field!(rcode, set_rcode, Rcode::NoError, Rcode::Refused);
    }

    #[test]
    fn counts() {
        let mut c = HeaderCounts {
            inner: [1, 2, 3, 4, 5, 6, 7, 8],
        };
        assert_eq!(c.qdcount(), 0x0102);
        assert_eq!(c.ancount(), 0x0304);
        assert_eq!(c.nscount(), 0x0506);
        assert_eq!(c.arcount(), 0x0708);
        c.inc_qdcount().unwrap();
        c.inc_ancount().unwrap();
        c.inc_nscount().unwrap();
        c.inc_arcount().unwrap();
        assert_eq!(c.inner, [1, 3, 3, 5, 5, 7, 7, 9]);
        c.set_qdcount(0x0807);
        c.set_ancount(0x0605);
        c.set_nscount(0x0403);
        c.set_arcount(0x0201);
        assert_eq!(c.inner, [8, 7, 6, 5, 4, 3, 2, 1]);
    }

    #[test]
    fn update_counts() {
        let mut c = HeaderCounts {
            inner: [1, 2, 3, 4, 5, 6, 7, 8],
        };
        assert_eq!(c.zocount(), 0x0102);
        assert_eq!(c.prcount(), 0x0304);
        assert_eq!(c.upcount(), 0x0506);
        assert_eq!(c.adcount(), 0x0708);
        c.set_zocount(0x0807);
        c.set_prcount(0x0605);
        c.set_upcount(0x0403);
        c.set_adcount(0x0201);
        assert_eq!(c.inner, [8, 7, 6, 5, 4, 3, 2, 1]);
    }

    #[test]
    fn inc_qdcount() {
        let mut c = HeaderCounts {
            inner: [0xff, 0xfe, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff],
        };
        assert!(c.inc_qdcount().is_ok());
        assert!(c.inc_qdcount().is_err());
    }

    #[test]
    fn inc_ancount() {
        let mut c = HeaderCounts {
            inner: [0xff, 0xff, 0xff, 0xfe, 0xff, 0xff, 0xff, 0xff],
        };
        assert!(c.inc_ancount().is_ok());
        assert!(c.inc_ancount().is_err());
    }

    #[test]
    fn inc_nscount() {
        let mut c = HeaderCounts {
            inner: [0xff, 0xff, 0xff, 0xff, 0xff, 0xfe, 0xff, 0xff],
        };
        assert!(c.inc_nscount().is_ok());
        assert!(c.inc_nscount().is_err());
    }

    #[test]
    fn inc_arcount() {
        let mut c = HeaderCounts {
            inner: [0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe],
        };
        assert!(c.inc_arcount().is_ok());
        assert!(c.inc_arcount().is_err());
    }

    #[cfg(feature = "std")]
    #[test]
    fn flags_display() {
        let f = Flags::new();
        assert_eq!(format!("{}", f), "");
        let f = Flags {
            qr: true,
            aa: true,
            tc: true,
            rd: true,
            ra: true,
            ad: true,
            cd: true,
        };
        assert_eq!(format!("{}", f), "QR AA TC RD RA AD CD");
        let mut f = Flags::new();
        f.rd = true;
        f.cd = true;
        assert_eq!(format!("{}", f), "RD CD");
    }

    #[cfg(feature = "std")]
    #[test]
    fn flags_from_str() {
        let f1 = Flags::from_str("").unwrap();
        let f2 = Flags::new();
        assert_eq!(f1, f2);
        let f1 = Flags::from_str("QR AA TC RD RA AD CD").unwrap();
        let f2 = Flags {
            qr: true,
            aa: true,
            tc: true,
            rd: true,
            ra: true,
            ad: true,
            cd: true,
        };
        assert_eq!(f1, f2);
        let f1 = Flags::from_str("tC Aa CD rd").unwrap();
        let f2 = Flags {
            aa: true,
            tc: true,
            rd: true,
            cd: true,
            ..Default::default()
        };
        assert_eq!(f1, f2);
        let f1 = Flags::from_str("XXXX");
        assert!(f1.is_err());
    }
}