use std::mem;
use bytes::{BigEndian, BufMut, ByteOrder};
use ::iana::{Opcode, Rcode};
use super::compose::Compose;
use super::parse::{Parse, Parser, ShortBuf};
/// The first four bytes of a DNS message header.
///
/// The bytes are kept exactly as they appear on the wire: a 16 bit message
/// ID in bytes 0 and 1, followed by two bytes of flags, opcode and response
/// code.
///
/// The type is `repr(C)` around a byte array so that its size is exactly
/// four bytes and its alignment is one. The `for_message_slice` functions
/// rely on this when they reinterpret part of a byte slice as a `Header`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[repr(C)]
pub struct Header {
    /// The raw header bytes in wire order.
    inner: [u8; 4]
}

impl Header {
    /// Creates a new header with all bytes set to zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a header reference that views the start of a message slice.
    ///
    /// No data is copied; the returned reference borrows from `s`.
    ///
    /// # Panics
    ///
    /// Panics if `s` is shorter than the size of a `Header`.
    pub fn for_message_slice(s: &[u8]) -> &Header {
        assert!(s.len() >= mem::size_of::<Header>());
        // SAFETY: `Header` is a `repr(C)` wrapper around `[u8; 4]`, so it
        // has alignment 1 and every bit pattern is valid. The assertion
        // above guarantees that at least `size_of::<Header>()` bytes are
        // readable, and the returned lifetime is tied to `s`.
        unsafe { &*(s.as_ptr() as *const Header) }
    }

    /// Returns a mutable header reference that views the start of a
    /// message slice.
    ///
    /// Changes made through the returned reference are written directly
    /// into `s`.
    ///
    /// # Panics
    ///
    /// Panics if `s` is shorter than the size of a `Header`.
    pub fn for_message_slice_mut(s: &mut [u8]) -> &mut Header {
        assert!(s.len() >= mem::size_of::<Header>());
        // SAFETY: same layout argument as in `for_message_slice`. The
        // pointer must come from `as_mut_ptr` so that it carries write
        // permission; a pointer obtained through `as_ptr` is derived from
        // a shared borrow and must not be written through.
        unsafe { &mut *(s.as_mut_ptr() as *mut Header) }
    }

    /// Returns the underlying bytes of the header.
    pub fn as_slice(&self) -> &[u8] {
        &self.inner
    }
}
/// # Field Access
///
/// Getters take `self` by value, which is cheap because `Header` is `Copy`.
/// Bit numbers below count from the least significant bit of the byte.
impl Header {
    /// Returns the message ID, read in big-endian order from bytes 0 and 1.
    pub fn id(self) -> u16 {
        BigEndian::read_u16(&self.inner)
    }

    /// Stores `value` as the message ID in big-endian order.
    pub fn set_id(&mut self, value: u16) {
        BigEndian::write_u16(&mut self.inner, value)
    }

    /// Sets the message ID to a value obtained from `rand::random`.
    pub fn set_random_id(&mut self) {
        self.set_id(::rand::random())
    }

    /// Returns the QR flag (bit 7 of byte 2).
    pub fn qr(self) -> bool {
        self.get_bit(2, 7)
    }

    /// Sets or clears the QR flag.
    pub fn set_qr(&mut self, set: bool) {
        self.set_bit(2, 7, set)
    }

    /// Returns the opcode, taken from bits 3 to 6 of byte 2.
    pub fn opcode(self) -> Opcode {
        Opcode::from_int((self.inner[2] >> 3) & 0x0F)
    }

    /// Stores the opcode in bits 3 to 6 of byte 2.
    ///
    /// Only the low four bits of the opcode's integer value are used. The
    /// mask keeps an out-of-range value from spilling into bit 7, which
    /// holds the QR flag. This mirrors the masking done in `set_rcode`.
    pub fn set_opcode(&mut self, opcode: Opcode) {
        let bits = (opcode.to_int() & 0x0F) << 3;
        // 0x87 keeps bit 7 and bits 0 to 2 and clears the old opcode.
        self.inner[2] = (self.inner[2] & 0x87) | bits;
    }

    /// Returns the AA flag (bit 2 of byte 2).
    pub fn aa(self) -> bool {
        self.get_bit(2, 2)
    }

    /// Sets or clears the AA flag.
    pub fn set_aa(&mut self, set: bool) {
        self.set_bit(2, 2, set)
    }

    /// Returns the TC flag (bit 1 of byte 2).
    pub fn tc(self) -> bool {
        self.get_bit(2, 1)
    }

    /// Sets or clears the TC flag.
    pub fn set_tc(&mut self, set: bool) {
        self.set_bit(2, 1, set)
    }

    /// Returns the RD flag (bit 0 of byte 2).
    pub fn rd(self) -> bool {
        self.get_bit(2, 0)
    }

    /// Sets or clears the RD flag.
    pub fn set_rd(&mut self, set: bool) {
        self.set_bit(2, 0, set)
    }

    /// Returns the RA flag (bit 7 of byte 3).
    pub fn ra(self) -> bool {
        self.get_bit(3, 7)
    }

    /// Sets or clears the RA flag.
    pub fn set_ra(&mut self, set: bool) {
        self.set_bit(3, 7, set)
    }

    /// Returns the Z flag (bit 6 of byte 3).
    pub fn z(self) -> bool {
        self.get_bit(3, 6)
    }

    /// Sets or clears the Z flag.
    pub fn set_z(&mut self, set: bool) {
        self.set_bit(3, 6, set)
    }

    /// Returns the AD flag (bit 5 of byte 3).
    pub fn ad(self) -> bool {
        self.get_bit(3, 5)
    }

    /// Sets or clears the AD flag.
    pub fn set_ad(&mut self, set: bool) {
        self.set_bit(3, 5, set)
    }

    /// Returns the CD flag (bit 4 of byte 3).
    pub fn cd(self) -> bool {
        self.get_bit(3, 4)
    }

    /// Sets or clears the CD flag.
    pub fn set_cd(&mut self, set: bool) {
        self.set_bit(3, 4, set)
    }

    /// Returns the response code, taken from the low four bits of byte 3.
    pub fn rcode(self) -> Rcode {
        Rcode::from_int(self.inner[3] & 0x0F)
    }

    /// Stores the response code in the low four bits of byte 3.
    ///
    /// Only the low four bits of the code's integer value are used; the
    /// flags in the upper half of the byte are left untouched.
    pub fn set_rcode(&mut self, rcode: Rcode) {
        self.inner[3] = (self.inner[3] & 0xF0) | (rcode.to_int() & 0x0F);
    }

    /// Returns whether bit `bit` of the byte at index `offset` is set.
    fn get_bit(self, offset: usize, bit: usize) -> bool {
        self.inner[offset] & (1 << bit) != 0
    }

    /// Sets or clears bit `bit` of the byte at index `offset`.
    fn set_bit(&mut self, offset: usize, bit: usize, set: bool) {
        let mask = 1 << bit;
        if set {
            self.inner[offset] |= mask
        } else {
            self.inner[offset] &= !mask
        }
    }
}
/// The four 16 bit counters that follow the `Header` in a message.
///
/// The bytes are kept exactly as they appear on the wire, each counter in
/// big-endian order.
///
/// The type is `repr(C)` around a byte array so that its size is exactly
/// eight bytes and its alignment is one. The `for_message_slice` functions
/// rely on this when they reinterpret part of a byte slice as a
/// `HeaderCounts`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[repr(C)]
pub struct HeaderCounts {
    /// The raw counter bytes in wire order.
    inner: [u8; 8]
}

impl HeaderCounts {
    /// Creates a new value with all counters set to zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a reference to the counters inside a message slice.
    ///
    /// `message` is expected to start at the beginning of the message; the
    /// counters are taken from the bytes directly after the `Header`. No
    /// data is copied.
    ///
    /// # Panics
    ///
    /// Panics if `message` is shorter than a complete `HeaderSection`.
    pub fn for_message_slice(message: &[u8]) -> &Self {
        assert!(message.len() >= mem::size_of::<HeaderSection>());
        let counts = &message[mem::size_of::<Header>()..];
        // SAFETY: `HeaderCounts` is a `repr(C)` wrapper around `[u8; 8]`,
        // so it has alignment 1 and every bit pattern is valid. The
        // assertion above guarantees that `counts` holds at least
        // `size_of::<HeaderCounts>()` bytes, and the returned lifetime is
        // tied to `message`.
        unsafe { &*(counts.as_ptr() as *const HeaderCounts) }
    }

    /// Returns a mutable reference to the counters inside a message slice.
    ///
    /// Changes made through the returned reference are written directly
    /// into `message`.
    ///
    /// # Panics
    ///
    /// Panics if `message` is shorter than a complete `HeaderSection`.
    pub fn for_message_slice_mut(message: &mut [u8]) -> &mut Self {
        assert!(message.len() >= mem::size_of::<HeaderSection>());
        let counts = &mut message[mem::size_of::<Header>()..];
        // SAFETY: same layout argument as in `for_message_slice`. The
        // pointer must come from `as_mut_ptr` so that it carries write
        // permission; a pointer obtained through `as_ptr` is derived from
        // a shared borrow and must not be written through.
        unsafe { &mut *(counts.as_mut_ptr() as *mut HeaderCounts) }
    }

    /// Returns the underlying bytes of the counters.
    pub fn as_slice(&self) -> &[u8] {
        &self.inner
    }

    /// Returns the underlying bytes of the counters for modification.
    pub fn as_slice_mut(&mut self) -> &mut [u8] {
        &mut self.inner
    }

    /// Overwrites all four counters with those of `counts`.
    pub fn set(&mut self, counts: HeaderCounts) {
        self.as_slice_mut().copy_from_slice(counts.as_slice())
    }
}
/// # Field Access
///
/// Each counter is a big-endian `u16`. The second group of names
/// (`zocount`, `prcount`, `upcount`, `adcount`) are aliases that read and
/// write the same four counters under different names.
impl HeaderCounts {
    /// Returns the counter stored in bytes 0 and 1.
    pub fn qdcount(self) -> u16 {
        self.get_u16(0)
    }

    /// Sets the counter stored in bytes 0 and 1.
    pub fn set_qdcount(&mut self, value: u16) {
        self.set_u16(0, value)
    }

    /// Increases the counter in bytes 0 and 1 by one.
    ///
    /// # Panics
    ///
    /// Panics if the counter already is at `u16::MAX`.
    pub fn inc_qdcount(&mut self) {
        self.inc_u16(0)
    }

    /// Returns the counter stored in bytes 2 and 3.
    pub fn ancount(self) -> u16 {
        self.get_u16(2)
    }

    /// Sets the counter stored in bytes 2 and 3.
    pub fn set_ancount(&mut self, value: u16) {
        self.set_u16(2, value)
    }

    /// Increases the counter in bytes 2 and 3 by one.
    ///
    /// # Panics
    ///
    /// Panics if the counter already is at `u16::MAX`.
    pub fn inc_ancount(&mut self) {
        self.inc_u16(2)
    }

    /// Returns the counter stored in bytes 4 and 5.
    pub fn nscount(self) -> u16 {
        self.get_u16(4)
    }

    /// Sets the counter stored in bytes 4 and 5.
    pub fn set_nscount(&mut self, value: u16) {
        self.set_u16(4, value)
    }

    /// Increases the counter in bytes 4 and 5 by one.
    ///
    /// # Panics
    ///
    /// Panics if the counter already is at `u16::MAX`.
    pub fn inc_nscount(&mut self) {
        self.inc_u16(4)
    }

    /// Returns the counter stored in bytes 6 and 7.
    pub fn arcount(self) -> u16 {
        self.get_u16(6)
    }

    /// Sets the counter stored in bytes 6 and 7.
    pub fn set_arcount(&mut self, value: u16) {
        self.set_u16(6, value)
    }

    /// Increases the counter in bytes 6 and 7 by one.
    ///
    /// # Panics
    ///
    /// Panics if the counter already is at `u16::MAX`.
    pub fn inc_arcount(&mut self) {
        self.inc_u16(6)
    }

    /// Alias for `qdcount`.
    pub fn zocount(self) -> u16 {
        self.qdcount()
    }

    /// Alias for `set_qdcount`.
    pub fn set_zocount(&mut self, value: u16) {
        self.set_qdcount(value)
    }

    /// Alias for `ancount`.
    pub fn prcount(self) -> u16 {
        self.ancount()
    }

    /// Alias for `set_ancount`.
    pub fn set_prcount(&mut self, value: u16) {
        self.set_ancount(value)
    }

    /// Alias for `nscount`.
    pub fn upcount(self) -> u16 {
        self.nscount()
    }

    /// Alias for `set_nscount`.
    pub fn set_upcount(&mut self, value: u16) {
        self.set_nscount(value)
    }

    /// Alias for `arcount`.
    pub fn adcount(self) -> u16 {
        self.arcount()
    }

    /// Alias for `set_arcount`.
    pub fn set_adcount(&mut self, value: u16) {
        self.set_arcount(value)
    }

    /// Reads the big-endian `u16` that starts at byte index `offset`.
    fn get_u16(self, offset: usize) -> u16 {
        BigEndian::read_u16(&self.inner[offset..offset + 2])
    }

    /// Writes `value` as a big-endian `u16` starting at byte index `offset`.
    fn set_u16(&mut self, offset: usize, value: u16) {
        BigEndian::write_u16(&mut self.inner[offset..offset + 2], value)
    }

    /// Adds one to the counter at `offset`, panicking instead of wrapping.
    fn inc_u16(&mut self, offset: usize) {
        let count = self.get_u16(offset);
        assert!(count < ::std::u16::MAX);
        self.set_u16(offset, count + 1);
    }
}
/// The complete twelve byte header section of a DNS message.
///
/// The bytes are kept exactly as they appear on the wire: the four bytes
/// of the `Header` followed by the eight bytes of the `HeaderCounts`.
///
/// The type is `repr(C)` around a byte array so that its size is exactly
/// twelve bytes and its alignment is one. The `for_message_slice` functions
/// rely on this when they reinterpret part of a byte slice as a
/// `HeaderSection`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[repr(C)]
pub struct HeaderSection {
    /// The raw header section bytes in wire order.
    inner: [u8; 12]
}

impl HeaderSection {
    /// Creates a new header section with all bytes set to zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a header section reference that views the start of a
    /// message slice.
    ///
    /// No data is copied; the returned reference borrows from `s`.
    ///
    /// # Panics
    ///
    /// Panics if `s` is shorter than the size of a `HeaderSection`.
    pub fn for_message_slice(s: &[u8]) -> &HeaderSection {
        assert!(s.len() >= mem::size_of::<HeaderSection>());
        // SAFETY: `HeaderSection` is a `repr(C)` wrapper around
        // `[u8; 12]`, so it has alignment 1 and every bit pattern is
        // valid. The assertion above guarantees that at least
        // `size_of::<HeaderSection>()` bytes are readable, and the
        // returned lifetime is tied to `s`.
        unsafe { &*(s.as_ptr() as *const HeaderSection) }
    }

    /// Returns a mutable header section reference that views the start of
    /// a message slice.
    ///
    /// Changes made through the returned reference are written directly
    /// into `s`.
    ///
    /// # Panics
    ///
    /// Panics if `s` is shorter than the size of a `HeaderSection`.
    pub fn for_message_slice_mut(s: &mut [u8]) -> &mut HeaderSection {
        assert!(s.len() >= mem::size_of::<HeaderSection>());
        // SAFETY: same layout argument as in `for_message_slice`. The
        // pointer must come from `as_mut_ptr` so that it carries write
        // permission; a pointer obtained through `as_ptr` is derived from
        // a shared borrow and must not be written through.
        unsafe { &mut *(s.as_mut_ptr() as *mut HeaderSection) }
    }

    /// Returns the underlying bytes of the header section.
    pub fn as_slice(&self) -> &[u8] {
        &self.inner
    }
}
/// # Access to Header and Counts
///
/// These methods hand out views into the section's own bytes; nothing is
/// copied.
impl HeaderSection {
    /// Returns a reference to the `Header` part of the section.
    pub fn header(&self) -> &Header {
        let bytes: &[u8] = &self.inner;
        Header::for_message_slice(bytes)
    }

    /// Returns a mutable reference to the `Header` part of the section.
    pub fn header_mut(&mut self) -> &mut Header {
        let bytes: &mut [u8] = &mut self.inner;
        Header::for_message_slice_mut(bytes)
    }

    /// Returns a reference to the `HeaderCounts` part of the section.
    pub fn counts(&self) -> &HeaderCounts {
        let bytes: &[u8] = &self.inner;
        HeaderCounts::for_message_slice(bytes)
    }

    /// Returns a mutable reference to the `HeaderCounts` part of the
    /// section.
    pub fn counts_mut(&mut self) -> &mut HeaderCounts {
        let bytes: &mut [u8] = &mut self.inner;
        HeaderCounts::for_message_slice_mut(bytes)
    }
}
impl Parse for HeaderSection {
type Err = ShortBuf;
fn parse(parser: &mut Parser) -> Result<Self, Self::Err> {
let mut res = Self::default();
parser.parse_buf(&mut res.inner)?;
Ok(res)
}
fn skip(parser: &mut Parser) -> Result<(), Self::Err> {
parser.advance(12)
}
}
/// Composing support: a header section is written as its twelve raw bytes.
impl Compose for HeaderSection {
    /// Returns the number of bytes `compose` writes, which is always twelve.
    fn compose_len(&self) -> usize {
        self.inner.len()
    }

    /// Appends the raw bytes of the header section to `buf`.
    fn compose<B: BufMut>(&self, buf: &mut B) {
        let bytes: &[u8] = &self.inner;
        buf.put_slice(bytes);
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use iana::{Opcode, Rcode};

    // The views created from a message slice expose exactly the bytes of
    // their part of the header section.
    #[test]
    fn for_slice() {
        let header = b"\x01\x02\x00\x00\x12\x34\x56\x78\x9a\xbc\xde\xf0";
        let mut vec = Vec::from(&header[..]);
        assert_eq!(Header::for_message_slice(header).as_slice(),
                   b"\x01\x02\x00\x00");
        assert_eq!(Header::for_message_slice_mut(vec.as_mut()).as_slice(),
                   b"\x01\x02\x00\x00");
        assert_eq!(HeaderCounts::for_message_slice(header).as_slice(),
                   b"\x12\x34\x56\x78\x9a\xbc\xde\xf0");
        assert_eq!(HeaderCounts::for_message_slice_mut(vec.as_mut()).as_slice(),
                   b"\x12\x34\x56\x78\x9a\xbc\xde\xf0");
        assert_eq!(HeaderSection::for_message_slice(header).as_slice(),
                   header);
        assert_eq!(HeaderSection::for_message_slice_mut(vec.as_mut())
                                 .as_slice(),
                   header);
    }

    // Writes made through the mutable views must land in the underlying
    // buffer, each in its own part of the header section.
    #[test]
    fn for_slice_mut_writes_through() {
        let mut vec = vec![0u8; 12];
        Header::for_message_slice_mut(&mut vec).set_id(0x1234);
        HeaderCounts::for_message_slice_mut(&mut vec).set_arcount(0x5678);
        assert_eq!(
            &vec[..],
            &[0x12u8, 0x34, 0, 0, 0, 0, 0, 0, 0, 0, 0x56, 0x78][..]
        );
        HeaderSection::for_message_slice_mut(&mut vec)
            .header_mut()
            .set_id(0);
        assert_eq!(&vec[..2], &[0u8, 0][..]);
    }

    #[test]
    #[should_panic]
    fn short_header() {
        Header::for_message_slice(b"134");
    }

    #[test]
    #[should_panic]
    fn short_header_counts() {
        HeaderCounts::for_message_slice(b"12345678");
    }

    #[test]
    #[should_panic]
    fn short_header_section() {
        HeaderSection::for_message_slice(b"1234");
    }

    // Checks a getter/setter pair of `Header`.
    //
    // A single header is created, its default is checked, and then every
    // value is set and read back in order on that same header. Reusing the
    // header means a sequence such as `true, false` really tests clearing a
    // bit that was set before.
    macro_rules! test_field {
        ($get:ident, $set:ident, $default:expr, $($value:expr),*) => {{
            let mut h = Header::new();
            assert_eq!(h.$get(), $default);
            $(
                h.$set($value);
                assert_eq!(h.$get(), $value);
            )*
        }}
    }

    #[test]
    fn header() {
        test_field!(id, set_id, 0, 0x1234);
        test_field!(qr, set_qr, false, true, false);
        test_field!(opcode, set_opcode, Opcode::Query, Opcode::Notify);
        test_field!(aa, set_aa, false, true, false);
        test_field!(tc, set_tc, false, true, false);
        test_field!(rd, set_rd, false, true, false);
        test_field!(ra, set_ra, false, true, false);
        test_field!(z, set_z, false, true, false);
        test_field!(ad, set_ad, false, true, false);
        test_field!(cd, set_cd, false, true, false);
        test_field!(rcode, set_rcode, Rcode::NoError, Rcode::Refused);
    }

    #[test]
    fn counts() {
        let mut c = HeaderCounts { inner: [ 1, 2, 3, 4, 5, 6, 7, 8 ] };
        assert_eq!(c.qdcount(), 0x0102);
        assert_eq!(c.ancount(), 0x0304);
        assert_eq!(c.nscount(), 0x0506);
        assert_eq!(c.arcount(), 0x0708);
        c.inc_qdcount();
        c.inc_ancount();
        c.inc_nscount();
        c.inc_arcount();
        assert_eq!(c.inner, [ 1, 3, 3, 5, 5, 7, 7, 9 ]);
        c.set_qdcount(0x0807);
        c.set_ancount(0x0605);
        c.set_nscount(0x0403);
        c.set_arcount(0x0201);
        assert_eq!(c.inner, [ 8, 7, 6, 5, 4, 3, 2, 1 ]);
    }

    // The alias names read and write the same four counters.
    #[test]
    fn update_counts() {
        let mut c = HeaderCounts { inner: [ 1, 2, 3, 4, 5, 6, 7, 8 ] };
        assert_eq!(c.zocount(), 0x0102);
        assert_eq!(c.prcount(), 0x0304);
        assert_eq!(c.upcount(), 0x0506);
        assert_eq!(c.adcount(), 0x0708);
        c.set_zocount(0x0807);
        c.set_prcount(0x0605);
        c.set_upcount(0x0403);
        c.set_adcount(0x0201);
        assert_eq!(c.inner, [ 8, 7, 6, 5, 4, 3, 2, 1 ]);
    }

    // Incrementing a counter that already is at its maximum must panic
    // rather than wrap around.
    #[test]
    #[should_panic]
    fn bad_inc_qdcount() {
        let mut c = HeaderCounts { inner: [0xff; 8] };
        c.inc_qdcount()
    }

    #[test]
    #[should_panic]
    fn bad_inc_ancount() {
        let mut c = HeaderCounts { inner: [0xff; 8] };
        c.inc_ancount()
    }

    #[test]
    #[should_panic]
    fn bad_inc_nscount() {
        let mut c = HeaderCounts { inner: [0xff; 8] };
        c.inc_nscount()
    }

    #[test]
    #[should_panic]
    fn bad_inc_arcount() {
        let mut c = HeaderCounts { inner: [0xff; 8] };
        c.inc_arcount()
    }
}