#![allow(non_camel_case_types)]
use fray::{BitField, bitfield};
use custom_type::{OpCode, Rcode};
mod custom_type {
    //! Fieldless enums used as multi-bit fields in the `DNSFlags` test structs.
    //!
    //! Each type declares a 4-bit width with `u8` backing storage through
    //! `FieldType`, and converts to `u8` (`From`) and from `u8` (`TryFrom`).
    use fray::FieldType;

    /// Values accepted for the 4-bit `OPCODE` field.
    #[repr(u8)]
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum OpCode {
        Query = 0,
        IQuery = 1,
        Status = 2,
    }

    impl FieldType for OpCode {
        const SIZE: usize = 4;
        type BitsType = u8;
    }

    impl TryFrom<u8> for OpCode {
        type Error = ();

        /// Decodes a raw field value; values with no matching variant yield `Err(())`.
        fn try_from(raw: u8) -> Result<Self, Self::Error> {
            match raw {
                0 => Ok(OpCode::Query),
                1 => Ok(OpCode::IQuery),
                2 => Ok(OpCode::Status),
                _ => Err(()),
            }
        }
    }

    impl From<OpCode> for u8 {
        /// Returns the enum's explicit discriminant.
        fn from(op: OpCode) -> u8 {
            op as u8
        }
    }

    /// Values accepted for the 4-bit `RCODE` field.
    #[repr(u8)]
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum Rcode {
        NoError = 0,
        FormatError = 1,
        ServerFailure = 2,
        NameError = 3,
        NotImplemented = 4,
        Refused = 5,
    }

    impl FieldType for Rcode {
        const SIZE: usize = 4;
        type BitsType = u8;
    }

    impl TryFrom<u8> for Rcode {
        type Error = ();

        /// Decodes a raw field value; values with no matching variant yield `Err(())`.
        fn try_from(raw: u8) -> Result<Self, Self::Error> {
            match raw {
                0 => Ok(Rcode::NoError),
                1 => Ok(Rcode::FormatError),
                2 => Ok(Rcode::ServerFailure),
                3 => Ok(Rcode::NameError),
                4 => Ok(Rcode::NotImplemented),
                5 => Ok(Rcode::Refused),
                _ => Err(()),
            }
        }
    }

    impl From<Rcode> for u8 {
        /// Returns the enum's explicit discriminant.
        fn from(code: Rcode) -> u8 {
            code as u8
        }
    }
}
/// Packs the flags most-significant-bit first, spelling the three reserved
/// bits as individual `bool` fields, and checks the resulting `u16`.
#[test]
fn msb0() {
    #[bitfield(repr(u16), bitorder(msb0))]
    pub struct DNSFlags {
        QR: bool,
        OPCODE: OpCode,
        AA: bool,
        TC: bool,
        RD: bool,
        RA: bool,
        _z0: bool,
        _z1: bool,
        _z2: bool,
        RCODE: Rcode,
    }

    let mut flags = DNSFlags::new();
    flags
        .with::<QR>(true)
        .with::<RD>(true)
        .with::<RA>(true)
        .with::<RCODE>(Rcode::NameError);

    // MSB-first: QR is bit 15, RD bit 8, RA bit 7, RCODE fills bits 3..=0
    // (NameError = 3). Everything else stays zero. Equals 0x8183.
    let expected: u16 = (1 << 15) | (1 << 8) | (1 << 7) | 0b0011;
    assert_eq!(flags.into_inner(), expected);
}
/// Same layout as `msb0`, but the three reserved bits are declared as a single
/// `()` field whose width is set explicitly with `#[bits(3)]`; the packed value
/// must be identical.
#[test]
fn msb0_override_field_size() {
    #[bitfield(repr(u16), bitorder(msb0))]
    pub struct DNSFlags {
        QR: bool,
        OPCODE: OpCode,
        AA: bool,
        TC: bool,
        RD: bool,
        RA: bool,
        #[bits(3)]
        _z: (),
        RCODE: Rcode,
    }

    let mut flags = DNSFlags::new();
    flags
        .with::<QR>(true)
        .with::<RD>(true)
        .with::<RA>(true)
        .with::<RCODE>(Rcode::NameError);

    // MSB-first: QR is bit 15, RD bit 8, RA bit 7, RCODE fills bits 3..=0
    // (NameError = 3); the 3-bit `_z` padding stays zero. Equals 0x8183.
    let expected: u16 = (1 << 15) | (1 << 8) | (1 << 7) | 0b0011;
    assert_eq!(flags.into_inner(), expected);
}
mod units {
    //! Unit tests for `set` / `get` / `try_get` using the default bit order.
    //!
    //! No `bitorder(...)` is given here; the expected values below show fields
    //! are packed from the least-significant bit upwards in declaration order.
    use super::*;
    use fray::iterable::BitIterableContainer;

    // Bit positions implied by declaration order (as the tests below assert):
    // QR 0 | OPCODE 1..=4 | AA 5 | TC 6 | RD 7 | RA 8 | _z 9..=11 | RCODE 12..=15
    #[bitfield(repr(u16), derives(Clone, Copy), impls(debug))]
    pub struct DNSFlags {
        QR: bool,
        OPCODE: OpCode,
        AA: bool,
        TC: bool,
        RD: bool,
        RA: bool,
        #[bits(3)]
        _z: (),
        RCODE: Rcode,
    }

    /// Setting a bool field is idempotent and clearing it restores zero.
    #[test]
    fn set_bool_basic() {
        let mut dns_flags = DNSFlags::new();
        dns_flags.set::<QR>(true);
        assert_eq!(dns_flags.into_inner(), 0b1u16);
        dns_flags.set::<QR>(true);
        assert_eq!(dns_flags.into_inner(), 0b1u16);
        dns_flags.set::<QR>(false);
        assert_eq!(dns_flags.into_inner(), 0b0u16);
        dns_flags.set::<QR>(false);
        assert_eq!(dns_flags.into_inner(), 0b0u16);
        dns_flags.set::<RA>(true);
        assert_eq!(dns_flags.into_inner(), 0b1_0000_0000u16);
        dns_flags.set::<RA>(true);
        assert_eq!(dns_flags.into_inner(), 0b1_0000_0000u16);
        dns_flags.set::<RA>(false);
        assert_eq!(dns_flags.into_inner(), 0b0u16);
        dns_flags.set::<RA>(false);
        assert_eq!(dns_flags.into_inner(), 0b0u16);
    }

    /// Writing one bool field must not disturb adjacent bool fields.
    #[test]
    fn set_bool_dont_overlap() {
        let mut dns_flags = DNSFlags::new();
        dns_flags.set::<TC>(true);
        dns_flags.set::<RD>(true);
        dns_flags.set::<RA>(true);
        assert_eq!(dns_flags.into_inner(), 0b1_1100_0000u16);
        dns_flags.set::<RD>(false);
        assert_eq!(dns_flags.into_inner(), 0b1_0100_0000u16);
        dns_flags.set::<RD>(true);
        assert_eq!(dns_flags.into_inner(), 0b1_1100_0000u16);
    }

    /// Bool fields read back the bit at their position in the raw value.
    #[test]
    fn get_bool() {
        let dns_flags = DNSFlags::from(BitIterableContainer::from(0b1u16));
        assert!(dns_flags.get::<QR>());
        let dns_flags = DNSFlags::new();
        assert!(!dns_flags.get::<RA>());
        assert!(!dns_flags.get::<QR>());
        let dns_flags = DNSFlags::from(BitIterableContainer::from(0b1_0000_0000u16));
        assert!(dns_flags.get::<RA>());
    }

    /// A custom-typed field is written at its offset (OPCODE starts at bit 1).
    #[test]
    fn set_custom_type() {
        let mut dns_flags = DNSFlags::new();
        dns_flags.set::<OPCODE>(OpCode::IQuery);
        assert_eq!(dns_flags.into_inner(), 0b10u16);
        dns_flags.set::<OPCODE>(OpCode::Status);
        assert_eq!(dns_flags.into_inner(), 0b100u16);
        dns_flags.set::<OPCODE>(OpCode::Query);
        assert_eq!(dns_flags.into_inner(), 0b0u16);
    }

    /// Overwriting a multi-bit custom field must leave the bool fields on
    /// either side of it (QR below, AA above) untouched.
    #[test]
    fn set_custom_type_dont_overlap() {
        let mut dns_flags = DNSFlags::new();
        dns_flags.set::<QR>(true);
        dns_flags.set::<AA>(true);
        dns_flags.set::<OPCODE>(OpCode::Status);
        assert_eq!(dns_flags.into_inner(), 0b10_0101u16);
        dns_flags.set::<OPCODE>(OpCode::Query);
        assert_eq!(dns_flags.into_inner(), 0b10_0001u16);
        assert!(dns_flags.get::<QR>());
        assert!(dns_flags.get::<AA>());
    }

    /// A custom-typed field decodes from its offset in the raw value.
    #[test]
    fn get_custom_type() {
        let dns_flags = DNSFlags::from(BitIterableContainer::from(0b10u16));
        assert_eq!(dns_flags.try_get::<OPCODE>(), Ok(OpCode::IQuery));
        let dns_flags = DNSFlags::from(BitIterableContainer::from(0b100u16));
        assert_eq!(dns_flags.try_get::<OPCODE>(), Ok(OpCode::Status));
        let dns_flags = DNSFlags::new();
        assert_eq!(dns_flags.try_get::<OPCODE>(), Ok(OpCode::Query));
    }

    /// A raw value with no matching variant must not decode successfully.
    #[test]
    fn get_custom_type_invalid() {
        // OPCODE occupies bits 1..=4; a field value of 3 has no `OpCode` variant,
        // so `OpCode::try_from(3)` returns `Err(())`.
        let dns_flags = DNSFlags::from(BitIterableContainer::from(0b110u16));
        assert!(dns_flags.try_get::<OPCODE>().is_err());
    }

    /// The topmost field (RCODE, bits 12..=15) round-trips through set/try_get.
    #[test]
    fn set_get_high_custom_type() {
        let mut dns_flags = DNSFlags::new();
        dns_flags.set::<RCODE>(Rcode::Refused);
        assert_eq!(dns_flags.into_inner(), 0x5000u16);
        assert_eq!(dns_flags.try_get::<RCODE>(), Ok(Rcode::Refused));
        dns_flags.set::<RCODE>(Rcode::NoError);
        assert_eq!(dns_flags.into_inner(), 0u16);
    }
}