use std::{mem, ops};
use std::marker::PhantomData;
use bytes::Bytes;
use ::iana::{Rcode, Rtype};
use ::rdata::Cname;
use super::message_builder::{MessageBuilder, AdditionalBuilder, RecordSectionBuilder};
use super::header::{Header, HeaderCounts, HeaderSection};
use super::name::{ParsedDname, ParsedDnameError, ToDname};
use super::parse::{Parse, Parser, ShortBuf};
use super::question::Question;
use super::rdata::{ParseRecordData, RecordData};
use super::record::{ParsedRecord, Record, RecordParseError};
/// A DNS message backed by a `Bytes` buffer.
///
/// Values built via `from_bytes` are guaranteed to be at least as long as a
/// message header. The question and record sections are only parsed when
/// they are accessed.
#[derive(Clone, Debug)]
pub struct Message {
/// The raw octets of the whole message, header included.
bytes:Bytes,
}
impl Message {
    /// Wraps `bytes` as a message.
    ///
    /// Only the length is checked: the buffer must be at least as long as a
    /// message header. Nothing beyond that is validated here.
    ///
    /// # Errors
    ///
    /// Returns `ShortBuf` if `bytes` is shorter than a header.
    pub fn from_bytes(bytes: Bytes) -> Result<Self, ShortBuf> {
        let header_len = mem::size_of::<HeaderSection>();
        if bytes.len() >= header_len {
            Ok(Message { bytes: bytes })
        } else {
            Err(ShortBuf)
        }
    }

    /// Wraps `bytes` as a message without checking its length.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that `bytes` holds at least a full message
    /// header. Other methods rely on it; for example, `question` advances
    /// past the header with `unwrap`.
    pub(super) unsafe fn from_bytes_unchecked(bytes: Bytes) -> Self {
        Message { bytes: bytes }
    }

    /// Returns a reference to the underlying bytes value.
    pub fn as_bytes(&self) -> &Bytes {
        let bytes = &self.bytes;
        bytes
    }

    /// Returns the message octets as a slice.
    pub fn as_slice(&self) -> &[u8] {
        let bytes = &self.bytes;
        bytes.as_ref()
    }
}
impl Message {
    /// Returns the header of the message, read from the start of the buffer.
    pub fn header(&self) -> &Header {
        let octets = self.as_slice();
        Header::for_message_slice(octets)
    }

    /// Returns the section counts stored in the message header.
    pub fn header_counts(&self) -> &HeaderCounts {
        let octets = self.as_slice();
        HeaderCounts::for_message_slice(octets)
    }

    /// Returns whether the header's response code is `NoError`.
    pub fn no_error(&self) -> bool {
        let rcode = self.header().rcode();
        rcode == Rcode::NoError
    }

    /// Returns whether the header's response code is anything but `NoError`.
    pub fn is_error(&self) -> bool {
        !self.no_error()
    }
}
impl Message {
    /// Returns an iterator over the question section.
    pub fn question(&self) -> QuestionSection {
        QuestionSection::new(self.bytes.clone())
    }

    /// Returns the zone section, i.e. the question section under its
    /// RFC 2136 (UPDATE) name.
    pub fn zone(&self) -> QuestionSection { self.question() }

    /// Returns the answer section.
    ///
    /// # Errors
    ///
    /// Fails if a question in the question section cannot be parsed, since
    /// the answer section's start can only be found by skipping them.
    pub fn answer(&self) -> Result<RecordSection, ParsedDnameError> {
        self.question().next_section()
    }

    /// Returns the prerequisite section, i.e. the answer section under its
    /// RFC 2136 (UPDATE) name.
    pub fn prerequisite(&self) -> Result<RecordSection, ParsedDnameError> {
        self.answer()
    }

    /// Returns the authority section.
    ///
    /// # Errors
    ///
    /// Fails if the question or answer section cannot be parsed.
    pub fn authority(&self) -> Result<RecordSection, ParsedDnameError> {
        Ok(self.answer()?
            .next_section()?
            .expect("the answer section is always followed by authority"))
    }

    /// Returns the update section, i.e. the authority section under its
    /// RFC 2136 (UPDATE) name.
    pub fn update(&self) -> Result<RecordSection, ParsedDnameError> {
        self.authority()
    }

    /// Returns the additional section.
    ///
    /// # Errors
    ///
    /// Fails if the question, answer, or authority section cannot be parsed.
    pub fn additional(&self) -> Result<RecordSection, ParsedDnameError> {
        Ok(self.authority()?
            .next_section()?
            .expect("the authority section is always followed by additional"))
    }

    /// Returns all four sections at once.
    ///
    /// Each record section is returned positioned at its first record. This
    /// walks the message only once, unlike calling the individual accessors.
    ///
    /// # Errors
    ///
    /// Fails if the question, answer, or authority section cannot be parsed.
    pub fn sections(&self) -> Result<(QuestionSection, RecordSection,
                                      RecordSection, RecordSection),
                                     ParsedDnameError> {
        let question = self.question();
        let answer = question.clone().next_section()?;
        let authority = answer.clone().next_section()?
            .expect("the answer section is always followed by authority");
        let additional = authority.clone().next_section()?
            .expect("the authority section is always followed by additional");
        Ok((question, answer, authority, additional))
    }

    /// Returns an iterator over all records of the answer, authority, and
    /// additional sections.
    ///
    /// The iterator is empty if the question section cannot be parsed.
    pub fn iter(&self) -> MessageIterator {
        MessageIterator { inner: self.answer().ok() }
    }

    /// Copies the records of all three record sections into `target`.
    ///
    /// Every record, or its parse error, is passed to `op`. Whatever `op`
    /// returns as `Some` is pushed into the matching section of `target`.
    /// Each section is parsed only once.
    ///
    /// # Errors
    ///
    /// Fails if the question, answer, or authority section cannot be parsed.
    /// Records copied before the failure are discarded with `target`.
    ///
    /// # Panics
    ///
    /// Panics if pushing a record into `target` fails, since the `push`
    /// result is unwrapped. Presumably this happens when the builder runs out
    /// of space; verify against `RecordSectionBuilder::push`.
    pub fn copy_records<N, D, R, F>(&self, target: MessageBuilder, op: F) -> Result<AdditionalBuilder, ParsedDnameError>
    where N: ToDname, D: RecordData, R: Into<Record<N, D>>, F: FnMut(Result<ParsedRecord, ParsedDnameError>) -> Option<R> + Copy
    {
        let mut target = target.answer();
        let mut answer = self.answer()?;
        // Drain the section in place so `next_section` below has nothing
        // left to skip and does not parse these records a second time.
        answer.by_ref().filter_map(op).for_each(|rr| target.push(rr).unwrap());

        let mut target = target.authority();
        let mut authority = answer.next_section()?
            .expect("the answer section is always followed by authority");
        authority.by_ref().filter_map(op).for_each(|rr| target.push(rr).unwrap());

        let mut target = target.additional();
        let additional = authority.next_section()?
            .expect("the authority section is always followed by additional");
        additional.filter_map(op).for_each(|rr| target.push(rr).unwrap());
        Ok(target)
    }
}
impl Message {
    /// Returns whether this message is a response to `query`.
    ///
    /// This requires the QR bit to be set, the message IDs and question
    /// counts to be equal, and both question sections to compare equal item
    /// by item.
    pub fn is_answer(&self, query: &Message) -> bool {
        if !self.header().qr()
                || self.header().id() != query.header().id()
                || self.header_counts().qdcount()
                        != query.header_counts().qdcount() {
            false
        }
        else { self.question().eq(query.question()) }
    }

    /// Returns the first question, or `None` if there is none or it fails to
    /// parse.
    pub fn first_question(&self) -> Option<Question<ParsedDname>> {
        self.question().next().and_then(Result::ok)
    }

    /// Returns the query type of the first question, if there is one.
    pub fn qtype(&self) -> Option<Rtype> {
        self.first_question().map(|x| x.qtype())
    }

    /// Returns whether the answer section holds at least one record whose
    /// data successfully parses as `D`.
    ///
    /// Records that fail to parse do not count as answers.
    pub fn contains_answer<D: ParseRecordData>(&self) -> bool {
        match self.answer() {
            Ok(answer) => answer.limit_to::<D>().any(|record| record.is_ok()),
            Err(..) => false
        }
    }

    /// Resolves the CNAME chain in the answer section.
    ///
    /// Starts at the qname of the first question and repeatedly follows a
    /// CNAME record whose owner is the current name. Returns the last name
    /// reached. CNAME records that fail to parse are skipped.
    ///
    /// Returns `None` in three cases: there is no parseable first question,
    /// the answer section cannot be located, or the CNAME records form a
    /// loop.
    pub fn canonical_name(&self) -> Option<ParsedDname> {
        let question = match self.first_question() {
            None => return None,
            Some(question) => question
        };
        let mut name = question.qname().clone();
        let answer = match self.answer() {
            Ok(answer) => answer.limit_to::<Cname<ParsedDname>>(),
            Err(_) => return None,
        };
        // In a loop-free chain every hop consumes a CNAME record with a new
        // owner, so there can be at most ANCOUNT hops. Needing one more
        // means the records form a loop (e.g. `a CNAME b`, `b CNAME a`),
        // which would otherwise be followed forever.
        let max_hops = self.header_counts().ancount();
        let mut hops: u16 = 0;
        loop {
            let next = answer.clone()
                .filter_map(Result::ok)
                .find(|record| *record.owner() == name);
            let record = match next {
                Some(record) => record,
                None => return Some(name),
            };
            if hops == max_hops {
                return None
            }
            hops += 1;
            name = record.data().cname().clone();
        }
    }
}
/// Lets a message be used wherever a `&Bytes` of its octets is expected.
impl ops::Deref for Message {
    type Target = Bytes;

    fn deref(&self) -> &Bytes {
        &self.bytes
    }
}
/// Identity conversion, so generic code taking `AsRef<Message>` accepts a
/// `Message` directly.
impl AsRef<Message> for Message {
fn as_ref(&self) -> &Message {
self
}
}
/// Borrows the message's underlying bytes value.
impl AsRef<Bytes> for Message {
    fn as_ref(&self) -> &Bytes {
        &self.bytes
    }
}
/// Borrows the message octets as a plain byte slice.
impl AsRef<[u8]> for Message {
    fn as_ref(&self) -> &[u8] {
        &self.bytes[..]
    }
}
/// Iterator over the records of the answer, authority, and additional
/// sections, yielding each record together with the section it came from.
pub struct MessageIterator {
/// The section currently being read; `None` once iteration has ended.
inner: Option<RecordSection>,
}
impl Iterator for MessageIterator {
    type Item = (Result<ParsedRecord, ParsedDnameError>, Section);

    /// Yields the next record of the current section.
    ///
    /// When the current section is exhausted, moves on to the following one.
    /// Iteration ends after the additional section, or as soon as the next
    /// section cannot be located because of a parse error.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            match self.inner {
                None => return None,
                Some(ref mut current) => {
                    if let Some(record) = current.next() {
                        return Some((record, current.section))
                    }
                }
            }
            // The current section has nothing left: advance to the next one.
            // An error or running past the additional section yields None.
            self.inner = self.inner.take()
                .and_then(|done| done.next_section().unwrap_or(None));
        }
    }
}
/// Iterator over the questions of a message's question section.
#[derive(Clone, Debug)]
pub struct QuestionSection {
/// Parser positioned at the next question to read.
parser: Parser,
/// Number of questions still to read, or the error that stopped parsing.
count: Result<u16, ParsedDnameError>
}
impl QuestionSection {
    /// Creates the question section for the message in `bytes`.
    ///
    /// The parser is placed right after the header, and the question count
    /// is taken from the header's QDCOUNT.
    ///
    /// # Panics
    ///
    /// Panics if `bytes` is shorter than a message header.
    fn new(bytes: Bytes) -> Self {
        let mut parser = Parser::from_bytes(bytes);
        parser.advance(mem::size_of::<HeaderSection>()).unwrap();
        let qdcount = HeaderCounts::for_message_slice(parser.as_slice()).qdcount();
        QuestionSection { parser: parser, count: Ok(qdcount) }
    }

    /// Skips any remaining questions and returns the answer section.
    ///
    /// # Errors
    ///
    /// Returns the first question parse error encountered.
    pub fn answer(mut self) -> Result<RecordSection, ParsedDnameError> {
        for question in self.by_ref() {
            question?;
        }
        match self.count {
            Err(err) => Err(err),
            Ok(..) => Ok(RecordSection::new(self.parser, Section::first())),
        }
    }

    /// Alias for `answer`, the section following the questions.
    pub fn next_section(self) -> Result<RecordSection, ParsedDnameError> {
        self.answer()
    }
}
impl Iterator for QuestionSection {
    type Item = Result<Question<ParsedDname>, ParsedDnameError>;

    /// Parses the next question.
    ///
    /// A parse error is returned once and also stored in `count`, so every
    /// later call returns `None`.
    fn next(&mut self) -> Option<Self::Item> {
        let remaining = match self.count {
            Ok(remaining) if remaining > 0 => remaining,
            _ => return None,
        };
        let parsed = Question::parse(&mut self.parser);
        self.count = match parsed {
            Ok(_) => Ok(remaining - 1),
            Err(err) => Err(err),
        };
        Some(parsed)
    }
}
/// The record sections of a message, in wire order.
///
/// The derived `PartialOrd` follows declaration order, so
/// `Answer < Authority < Additional`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd)]
pub enum Section {
Answer,
Authority,
Additional
}
impl Section {
    /// Returns the first record section of a message, the answer section.
    pub fn first() -> Self {
        Section::Answer
    }

    /// Returns how many records this section holds according to `counts`.
    fn count(self, counts: HeaderCounts) -> u16 {
        match self {
            Section::Additional => counts.arcount(),
            Section::Authority => counts.nscount(),
            Section::Answer => counts.ancount(),
        }
    }

    /// Returns the section following this one, or `None` after the
    /// additional section.
    fn next_section(self) -> Option<Self> {
        match self {
            Section::Additional => None,
            Section::Authority => Some(Section::Additional),
            Section::Answer => Some(Section::Authority),
        }
    }
}
/// Iterator over the records of one record section of a message.
#[derive(Clone, Debug)]
pub struct RecordSection {
/// Parser positioned at the next record to read.
parser: Parser,
/// Which section this is.
section: Section,
/// Number of records still to read, or the error that stopped parsing.
count: Result<u16, ParsedDnameError>
}
impl RecordSection {
    /// Creates the iterator for `section`, starting at `parser`'s position.
    ///
    /// The number of records is taken from the matching header count.
    fn new(parser: Parser, section: Section) -> Self {
        RecordSection {
            count: Ok(section.count(
                *HeaderCounts::for_message_slice(parser.as_slice())
            )),
            section,
            parser,
        }
    }

    /// Turns this section into an iterator over records whose data parses
    /// as `D`.
    pub fn limit_to<D: ParseRecordData>(self) -> RecordIter<D> {
        RecordIter::new(self)
    }

    /// Skips any remaining records and returns the following section.
    ///
    /// Returns `Ok(None)` for the additional section, which has no successor.
    /// The remaining records of the additional section are not checked.
    ///
    /// # Errors
    ///
    /// Returns the first record parse error encountered in this section.
    pub fn next_section(mut self)
                        -> Result<Option<Self>, ParsedDnameError> {
        let section = match self.section.next_section() {
            Some(section) => section,
            None => return Ok(None)
        };
        // Parse past the records not yet consumed to find where the next
        // section starts.
        for record in &mut self {
            record?;
        }
        match self.count {
            Ok(..) => Ok(Some(RecordSection::new(self.parser, section))),
            Err(err) => Err(err)
        }
    }
}
impl Iterator for RecordSection {
    type Item = Result<ParsedRecord, ParsedDnameError>;

    /// Parses the next record of the section.
    ///
    /// A parse error is returned once and also stored in `count`, so every
    /// later call returns `None`.
    fn next(&mut self) -> Option<Self::Item> {
        let remaining = match self.count {
            Ok(remaining) if remaining > 0 => remaining,
            _ => return None,
        };
        let parsed = ParsedRecord::parse(&mut self.parser);
        self.count = match parsed {
            Ok(_) => Ok(remaining - 1),
            Err(err) => Err(err),
        };
        Some(parsed)
    }
}
/// Iterator over the records of a section whose data parses as `D`.
///
/// Records for which `into_record` yields `None` are skipped; presumably
/// these are records of a type other than `D`.
#[derive(Clone, Debug)]
pub struct RecordIter<D: ParseRecordData> {
/// The underlying iterator over all records of the section.
section: RecordSection,
/// Records the data type `D` without storing a value of it.
marker: PhantomData<D>
}
impl<D: ParseRecordData> RecordIter<D> {
    /// Wraps `section` so that only records with data of type `D` are
    /// yielded.
    fn new(section: RecordSection) -> Self {
        RecordIter { marker: PhantomData, section: section }
    }

    /// Gives back the underlying, unfiltered section iterator.
    pub fn unwrap(self) -> RecordSection {
        let RecordIter { section, .. } = self;
        section
    }

    /// Skips any remaining records and returns the following section.
    ///
    /// See `RecordSection::next_section`.
    pub fn next_section(self)
                        -> Result<Option<RecordSection>, ParsedDnameError> {
        self.unwrap().next_section()
    }
}
impl<D: ParseRecordData> Iterator for RecordIter<D> {
    type Item = Result<Record<ParsedDname, D>,
                       RecordParseError<ParsedDnameError, D::Err>>;

    /// Returns the next record whose data converts to `D`.
    ///
    /// Records for which `into_record` yields `None` are skipped. Errors,
    /// whether in the record's name or in its data, are passed through.
    fn next(&mut self) -> Option<Self::Item> {
        while let Some(parsed) = self.section.next() {
            let raw = match parsed {
                Ok(raw) => raw,
                Err(err) => return Some(Err(RecordParseError::Name(err))),
            };
            match raw.into_record::<D>() {
                Ok(None) => continue,
                Ok(Some(record)) => return Some(Ok(record)),
                Err(err) => return Some(Err(err)),
            }
        }
        None
    }
}
#[cfg(test)]
mod test {
    use std::str::FromStr;
    use super::*;
    use bits::name::*;
    use bits::message_builder::*;
    use rdata::Ns;
    use bits::rdata::UnknownRecordData;

    /// Builds a message with one CNAME record in the answer section and one
    /// NS record in the authority section.
    fn get_test_message() -> Message {
        let mut answer = MessageBuilder::with_capacity(512).answer();
        let cname_target = Dname::from_str("baz.example.com.").unwrap();
        answer.push((Dname::from_str("foo.example.com.").unwrap(), 86000,
                     Cname::new(cname_target))).unwrap();

        let mut authority = answer.authority();
        let ns_target = Dname::from_str("baz.example.com.").unwrap();
        authority.push((Dname::from_str("bar.example.com.").unwrap(), 86000,
                        Ns::new(ns_target))).unwrap();

        Message::from_bytes(authority.finish().into()).unwrap()
    }

    #[test]
    fn short_message() {
        let too_short = Bytes::from_static(&[0u8; 11]);
        let header_only = Bytes::from_static(&[0u8; 12]);
        assert!(Message::from_bytes(too_short).is_err());
        assert!(Message::from_bytes(header_only).is_ok());
    }

    #[test]
    fn message_iterator() {
        let msg = get_test_message();
        let mut iter = msg.iter();
        let expected = [Section::Answer, Section::Authority];
        for &want in expected.iter() {
            let item = iter.next();
            assert!(item.is_some());
            let (rr, section) = item.unwrap();
            assert_eq!(want, section);
            assert!(rr.is_ok());
        }
    }

    #[test]
    fn copy_records() {
        let msg = get_test_message();
        let target = MessageBuilder::with_capacity(512);
        let res = msg.copy_records(target, |rec| {
            let rr = match rec {
                Ok(rr) => rr,
                Err(_) => return None,
            };
            match rr.into_record::<UnknownRecordData>() {
                Ok(Some(rr)) => {
                    if rr.rtype() == Rtype::Cname { Some(rr) } else { None }
                }
                _ => None,
            }
        });
        assert!(res.is_ok());
        if let Ok(target) = res {
            let copied = target.freeze();
            assert_eq!(1, copied.header_counts().ancount());
            assert_eq!(0, copied.header_counts().arcount());
        }
    }
}