use crate::message::{MessageType, OpCode, RCode};
/// Evaluates to `true` when bit `$idx` of the integer expression `$bits` is
/// set, and to `false` otherwise.
///
/// Bit indices are counted from the least-significant bit (index 0).
macro_rules! get_bit {
    ($bits:expr, $idx:literal) => {
        (($bits >> $idx) & 1) == 1
    };
}
cfg_any_client! {
    // Writes the boolean held in the variable `$on` into bit `$idx` of the
    // place expression `$word`: the bit is raised when `$on` is true and
    // cleared otherwise. Indices count from the least-significant bit.
    //
    // NOTE(review): the `#[cfg(test)]` setters on `Flags` expand this macro
    // too; confirm that test builds always satisfy `cfg_any_client!`.
    macro_rules! set_bit {
        ($word:expr, $idx:literal, $on:ident) => {
            if $on {
                $word |= 1 << $idx;
            } else {
                $word &= !(1 << $idx);
            }
        };
    }
}
/// A 16-bit flags word with typed accessors for its individual fields.
///
/// The value is a plain `Copy` wrapper around the raw `u16`; equality,
/// ordering and hashing are all derived from that single field, and the
/// default value has every bit cleared.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct Flags {
    /// The raw flags word; bit 0 is the least-significant bit.
    bits: u16,
}
/// Field accessors for the flags word.
///
/// Bit positions are counted from the least-significant bit (bit 0):
///
/// | bits    | field                 |
/// |---------|-----------------------|
/// | 15      | message type          |
/// | 11..=14 | opcode                |
/// | 10      | authoritative answer  |
/// | 9       | truncated             |
/// | 8       | recursion desired     |
/// | 7       | recursion available   |
/// | 4..=6   | Z                     |
/// | 0..=3   | response code         |
///
/// Every setter replaces only its own field and leaves all other bits
/// untouched.
impl Flags {
    /// Creates a flags word with every bit cleared.
    pub fn new() -> Flags {
        Flags { bits: 0 }
    }

    /// Returns the message type stored in bit 15.
    ///
    /// The bit is mapped through `MessageType`'s conversion from `bool`.
    pub fn message_type(self) -> MessageType {
        (get_bit!(self.bits, 15)).into()
    }

    /// Stores `message_type` in bit 15, using its conversion into `bool`.
    #[cfg(test)]
    pub(crate) fn set_message_type(&mut self, message_type: MessageType) -> &mut Self {
        let value: bool = message_type.into();
        set_bit!(self.bits, 15, value);
        self
    }

    /// Returns the opcode stored in bits 11..=14.
    #[inline]
    pub fn opcode(self) -> OpCode {
        let bits = ((self.bits & 0b0111_1000_0000_0000) >> 11) as u8;
        bits.into()
    }

    /// Replaces the opcode field (bits 11..=14) with `opcode`.
    ///
    /// Only the low four bits of the opcode value are stored, so an
    /// out-of-range value can never spill into bit 15.
    #[cfg(test)]
    pub(crate) fn set_opcode(&mut self, opcode: OpCode) -> &mut Self {
        const OPCODE_MASK: u16 = 0b0111_1000_0000_0000;
        let field = ((opcode.value() as u16) << 11) & OPCODE_MASK;
        self.bits = (self.bits & !OPCODE_MASK) | field;
        self
    }

    /// Returns `true` when the authoritative-answer bit (bit 10) is set.
    pub fn authoritative_answer(self) -> bool {
        get_bit!(self.bits, 10)
    }

    /// Sets or clears the authoritative-answer bit (bit 10).
    #[cfg(test)]
    pub(crate) fn set_authoritative_answer(&mut self, value: bool) -> &mut Self {
        set_bit!(self.bits, 10, value);
        self
    }

    /// Returns `true` when the truncation bit (bit 9) is set.
    pub fn truncated(self) -> bool {
        get_bit!(self.bits, 9)
    }

    /// Sets or clears the truncation bit (bit 9).
    #[cfg(test)]
    pub(crate) fn set_truncated(&mut self, value: bool) -> &mut Self {
        set_bit!(self.bits, 9, value);
        self
    }

    /// Returns `true` when the recursion-desired bit (bit 8) is set.
    pub fn recursion_desired(self) -> bool {
        get_bit!(self.bits, 8)
    }

    cfg_any_client! {
        /// Sets or clears the recursion-desired bit (bit 8).
        pub(crate) fn set_recursion_desired(&mut self, value: bool) -> &mut Self {
            set_bit!(self.bits, 8, value);
            self
        }
    }

    /// Returns `true` when the recursion-available bit (bit 7) is set.
    pub fn recursion_available(self) -> bool {
        get_bit!(self.bits, 7)
    }

    /// Sets or clears the recursion-available bit (bit 7).
    #[cfg(test)]
    pub(crate) fn set_recursion_available(&mut self, value: bool) -> &mut Self {
        set_bit!(self.bits, 7, value);
        self
    }

    /// Returns the three-bit Z field (bits 4..=6) as a value in `0..=7`.
    #[cfg(test)]
    pub(crate) fn z(self) -> u8 {
        // Mask after shifting: without it bits 7..=11 (recursion available,
        // recursion desired, truncated, authoritative answer and the low
        // opcode bit) would leak into the result.
        ((self.bits >> 4) & 0b0000_0111) as u8
    }

    /// Replaces the Z field (bits 4..=6) with the low three bits of `value`.
    #[cfg(test)]
    pub(crate) fn set_z(&mut self, value: u8) -> &mut Self {
        const Z_MASK: u16 = 0b0000_0000_0111_0000;
        // Clear the old field first so a smaller value can overwrite a
        // larger one instead of being OR-ed on top of it.
        self.bits = (self.bits & !Z_MASK) | ((u16::from(value) << 4) & Z_MASK);
        self
    }

    /// Returns the response code stored in bits 0..=3.
    pub fn response_code(self) -> RCode {
        let bits = self.bits & 0b0000_0000_0000_1111;
        bits.into()
    }

    /// Replaces the response-code field (bits 0..=3) with `rcode`.
    ///
    /// Only the low four bits of the code are stored; the field has no room
    /// for larger values, and masking keeps them from corrupting the
    /// neighbouring flags.
    #[cfg(test)]
    pub(crate) fn set_response_code(&mut self, rcode: RCode) -> &mut Self {
        const RCODE_MASK: u16 = 0b0000_0000_0000_1111;
        self.bits = (self.bits & !RCODE_MASK) | (rcode.value() & RCODE_MASK);
        self
    }
}
/// Debug output is the raw flags word in binary with a `0b` prefix.
impl std::fmt::Debug for Flags {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("{:#b}", self.bits))
    }
}
impl std::convert::From<u16> for Flags {
#[inline]
fn from(flags: u16) -> Flags {
Flags { bits: flags }
}
}
impl std::convert::From<Flags> for u16 {
#[inline]
fn from(f: Flags) -> u16 {
f.bits
}
}
// Unit tests for the `Flags` accessors: each field is checked both through
// its getter/setter pair and against the raw `u16` representation.
#[cfg(test)]
mod tests {
    use super::*;
    // Getter signature shared by all single-bit flags.
    type FlagGet = fn(Flags) -> bool;
    // Setter signature shared by all single-bit flags.
    type FlagSet = fn(&mut Flags, bool) -> &mut Flags;
    // Drives one boolean flag through clear -> set -> clear, asserting after
    // each step both the getter result and that the raw word equals exactly
    // `mask` (when set) or zero (when clear).
    fn test_bool_flag(get: FlagGet, set: FlagSet, mask: u16) {
        let mut f = Flags::default();
        assert_eq!(u16::from(f), 0);
        assert!(!get(f));
        set(&mut f, true);
        assert!(get(f));
        assert_eq!(u16::from(f), mask);
        set(&mut f, false);
        assert!(!get(f));
        assert_eq!(u16::from(f), 0);
    }
    // Runs the round-trip check for bits 10, 9, 8 and 7.
    #[test]
    fn test_bool_flags() {
        test_bool_flag(
            Flags::authoritative_answer,
            Flags::set_authoritative_answer,
            0b0000_0100_0000_0000,
        );
        test_bool_flag(
            Flags::truncated,
            Flags::set_truncated,
            0b0000_0010_0000_0000,
        );
        test_bool_flag(
            Flags::recursion_desired,
            Flags::set_recursion_desired,
            0b0000_0001_0000_0000,
        );
        test_bool_flag(
            Flags::recursion_available,
            Flags::set_recursion_available,
            0b0000_0000_1000_0000,
        );
    }
    // Default flags report `Query`; the message type can be switched to
    // `Response` and back again.
    #[test]
    fn test_message_flags() {
        let mut f = Flags::default();
        assert_eq!(f.message_type(), MessageType::Query);
        f.set_message_type(MessageType::Response);
        assert_eq!(f.message_type(), MessageType::Response);
        f.set_message_type(MessageType::Query);
        assert_eq!(f.message_type(), MessageType::Query);
    }
    #[test]
    fn test_opcode() {
        // Every listed opcode must decode from hand-built raw bits and
        // survive a set/get round trip through the setter.
        for opcode in OpCode::VALUES {
            let f = Flags {
                bits: (opcode.value() as u16) << 11,
            };
            assert_eq!(f.opcode(), opcode);
            let mut f = Flags::default();
            assert_eq!(u16::from(f), 0);
            f.set_opcode(opcode);
            assert_eq!(f.opcode(), opcode);
            assert_eq!(
                (u16::from(f) & 0b0111_1000_0000_0000) >> 11,
                opcode.value() as u16
            );
        }
        // Four-bit values that are not in `OpCode::VALUES` must still compare
        // equal to their raw number after decoding.
        for i in 0..16 {
            if !OpCode::VALUES.iter().any(|oc| oc.value() as u16 == i) {
                let f = Flags { bits: i << 11 };
                assert_eq!(f.opcode(), i as u8);
            }
        }
    }
    #[test]
    fn test_response_code() {
        for rcode in RCode::VALUES {
            // Codes above 15 do not fit in the four-bit field and are skipped.
            if rcode > 15 {
                continue;
            }
            let f = Flags {
                bits: rcode.value(),
            };
            assert_eq!(f.response_code(), rcode);
            let mut f = Flags::default();
            assert_eq!(u16::from(f), 0);
            f.set_response_code(rcode);
            assert_eq!(f.response_code(), rcode);
            assert_eq!(u16::from(f) & 0b0000_0000_0000_1111, rcode.value());
        }
        // Four-bit values missing from `RCode::VALUES` must decode to a code
        // for which `is_defined()` is false.
        for i in 0..16 {
            if !RCode::VALUES.iter().any(|rc| *rc == i) {
                let f = Flags { bits: i };
                assert!(!f.response_code().is_defined());
            }
        }
    }
    #[test]
    fn test_z() {
        // In-range values (0..8) round-trip through raw bits and the setter.
        for i in 0..8 {
            let f = Flags { bits: i << 4 };
            assert_eq!(f.z(), i as u8);
            let mut f = Flags::default();
            assert_eq!(f.z(), 0);
            f.set_z(i as u8);
            assert_eq!(f.z(), i as u8);
        }
        // Larger values are reduced to their low three bits, and nothing
        // outside bits 4..=6 of the raw word is written.
        for i in 8..256 {
            let mut f = Flags::default();
            assert_eq!(f.z(), 0);
            f.set_z(i as u8);
            assert_eq!(f.z(), (i % 8) as u8);
            assert_eq!(u16::from(f), ((i % 8) << 4) as u16);
        }
    }
}