use super::header::{Header, HeaderCounts, HeaderSection};
use super::iana::{Class, Rcode, Rtype};
use super::message_builder::{AdditionalBuilder, AnswerBuilder};
use super::name::ParsedDname;
use super::octets::{
OctetsBuilder, OctetsFrom, OctetsRef, Parse, ParseError, Parser, ShortBuf,
};
use super::opt::{Opt, OptRecord};
use super::question::Question;
use super::rdata::ParseRecordData;
use super::record::{AsRecord, ParsedRecord, Record};
use crate::rdata::rfc1035::Cname;
use core::marker::PhantomData;
use core::{fmt, mem};
/// A DNS message stored in an octets sequence.
///
/// The safe constructor `Message::from_octets` checks that the sequence is
/// at least as long as a complete header section.
#[derive(Copy, Clone)]
pub struct Message<Octets> {
    /// The raw message bytes.
    octets: Octets,
}
impl<Octets> Message<Octets> {
    /// Creates a message from an octets sequence.
    ///
    /// # Errors
    ///
    /// Returns `ShortBuf` if `octets` is shorter than a header section. No
    /// other validation of the content takes place.
    pub fn from_octets(octets: Octets) -> Result<Self, ShortBuf>
    where
        Octets: AsRef<[u8]>,
    {
        let header_len = mem::size_of::<HeaderSection>();
        if octets.as_ref().len() < header_len {
            return Err(ShortBuf);
        }
        // SAFETY: the check above guarantees a complete header section.
        Ok(unsafe { Self::from_octets_unchecked(octets) })
    }

    /// Creates a message without checking the length of `octets`.
    ///
    /// # Safety
    ///
    /// The caller must ensure `octets` is at least as long as a header
    /// section; e.g. `QuestionSection::new` unwraps advancing past it.
    pub(super) unsafe fn from_octets_unchecked(octets: Octets) -> Self {
        Self { octets }
    }

    /// Returns a reference to the underlying octets sequence.
    pub fn as_octets(&self) -> &Octets {
        &self.octets
    }

    /// Consumes the message and returns the underlying octets sequence.
    pub fn into_octets(self) -> Octets {
        self.octets
    }

    /// Returns the message bytes as a slice.
    pub fn as_slice(&self) -> &[u8]
    where
        Octets: AsRef<[u8]>,
    {
        self.octets.as_ref()
    }

    /// Returns the message bytes as a mutable slice.
    fn as_slice_mut(&mut self) -> &mut [u8]
    where
        Octets: AsMut<[u8]>,
    {
        self.octets.as_mut()
    }

    /// Returns a message borrowing this message's bytes as a slice.
    pub fn for_slice(&self) -> Message<&[u8]>
    where
        Octets: AsRef<[u8]>,
    {
        let slice: &[u8] = self.octets.as_ref();
        // SAFETY: `self` already upholds the length invariant and `slice`
        // views exactly the same bytes.
        unsafe { Message::from_octets_unchecked(slice) }
    }
}
impl<Octets: AsRef<[u8]>> Message<Octets> {
    /// Returns a copy of the message header.
    pub fn header(&self) -> Header {
        let bytes = self.as_slice();
        *Header::for_message_slice(bytes)
    }

    /// Returns a mutable reference to the header within the message bytes.
    pub fn header_mut(&mut self) -> &mut Header
    where
        Octets: AsMut<[u8]>,
    {
        let bytes = self.as_slice_mut();
        Header::for_message_slice_mut(bytes)
    }

    /// Returns a copy of the header's section counts (QDCOUNT and friends).
    pub fn header_counts(&self) -> HeaderCounts {
        let bytes = self.as_slice();
        *HeaderCounts::for_message_slice(bytes)
    }

    /// Returns a copy of the complete header section.
    pub fn header_section(&self) -> HeaderSection {
        let bytes = self.as_slice();
        *HeaderSection::for_message_slice(bytes)
    }

    /// Returns whether the header's response code is `NoError`.
    pub fn no_error(&self) -> bool {
        self.header().rcode() == Rcode::NoError
    }

    /// Returns whether the header's response code is anything but `NoError`.
    pub fn is_error(&self) -> bool {
        !self.no_error()
    }
}
impl<Octets> Message<Octets>
where
    for<'a> &'a Octets: OctetsRef,
{
    /// Returns the question section.
    pub fn question(&self) -> QuestionSection<&Octets> {
        QuestionSection::new(self.as_octets())
    }

    /// Returns the zone section of an UPDATE message.
    ///
    /// This is the question section under its UPDATE name.
    pub fn zone(&self) -> QuestionSection<&Octets> {
        self.question()
    }

    /// Returns the answer section.
    ///
    /// # Errors
    ///
    /// Fails if the question section can't be parsed.
    pub fn answer(&self) -> Result<RecordSection<&Octets>, ParseError> {
        let question = self.question();
        question.next_section()
    }

    /// Returns the prerequisite section of an UPDATE message.
    ///
    /// This is the answer section under its UPDATE name.
    pub fn prerequisite(&self) -> Result<RecordSection<&Octets>, ParseError> {
        self.answer()
    }

    /// Returns the authority section.
    ///
    /// # Errors
    ///
    /// Fails if any preceding section can't be parsed.
    pub fn authority(&self) -> Result<RecordSection<&Octets>, ParseError> {
        let answer = self.answer()?;
        // The answer section always has a successor, so this is `Some`.
        Ok(answer.next_section()?.unwrap())
    }

    /// Returns the update section of an UPDATE message.
    ///
    /// This is the authority section under its UPDATE name.
    pub fn update(&self) -> Result<RecordSection<&Octets>, ParseError> {
        self.authority()
    }

    /// Returns the additional section.
    ///
    /// # Errors
    ///
    /// Fails if any preceding section can't be parsed.
    pub fn additional(&self) -> Result<RecordSection<&Octets>, ParseError> {
        let authority = self.authority()?;
        // The authority section always has a successor, so this is `Some`.
        Ok(authority.next_section()?.unwrap())
    }

    /// Returns all four sections, locating each one in a single pass.
    ///
    /// # Errors
    ///
    /// Fails if any section but the last can't be parsed.
    #[allow(clippy::type_complexity)]
    pub fn sections(
        &self,
    ) -> Result<
        (
            QuestionSection<&Octets>,
            RecordSection<&Octets>,
            RecordSection<&Octets>,
            RecordSection<&Octets>,
        ),
        ParseError,
    > {
        // The section types are `Copy`, so moving on to the next section
        // leaves the current one intact for the returned tuple.
        let zone = self.question();
        let an = zone.next_section()?;
        let ns = an.next_section()?.unwrap();
        let ar = ns.next_section()?.unwrap();
        Ok((zone, an, ns, ar))
    }

    /// Returns an iterator over all records of the three record sections.
    pub fn iter(&self) -> MessageIter<&Octets> {
        IntoIterator::into_iter(self)
    }
}
impl<Octets> Message<Octets>
where
    Octets: AsRef<[u8]>,
    for<'a> &'a Octets: OctetsRef,
{
    /// Returns whether this message is an answer to `query`.
    ///
    /// This requires the QR bit to be set, the message ID and QDCOUNT to
    /// match those of `query`, and both question sections to compare equal.
    pub fn is_answer<Other>(&self, query: &Message<Other>) -> bool
    where
        Other: AsRef<[u8]>,
        for<'o> &'o Other: OctetsRef,
    {
        if !self.header().qr()
            || self.header().id() != query.header().id()
            || self.header_counts().qdcount()
                != query.header_counts().qdcount()
        {
            false
        } else {
            self.question() == query.question()
        }
    }

    /// Returns the first question of the message.
    ///
    /// Returns `None` if there is no question or the first one fails to
    /// parse.
    pub fn first_question(&self) -> Option<Question<ParsedDname<&Octets>>> {
        match self.question().next() {
            None | Some(Err(..)) => None,
            Some(Ok(question)) => Some(question),
        }
    }

    /// Returns the message's only question.
    ///
    /// # Errors
    ///
    /// Returns a form error if QDCOUNT is zero or greater than one, or the
    /// parse error if the question can't be parsed.
    pub fn sole_question(
        &self,
    ) -> Result<Question<ParsedDname<&Octets>>, ParseError> {
        match self.header_counts().qdcount() {
            0 => return Err(ParseError::form_error("no question")),
            1 => {}
            _ => return Err(ParseError::form_error("multiple questions")),
        }
        // With a count of one the iterator yields exactly one item (the
        // question or its parse error), so this can't panic.
        self.question().next().unwrap()
    }

    /// Returns the query type of the first question, if there is one.
    pub fn qtype(&self) -> Option<Rtype> {
        self.first_question().map(|x| x.qtype())
    }

    /// Returns whether the answer section contains a record of type `Data`.
    ///
    /// Returns `false` if the answer section can't be located or if a
    /// record fails to parse before a matching one is found.
    pub fn contains_answer<'s, Data>(&'s self) -> bool
    where
        Data: ParseRecordData<&'s Octets>,
    {
        let answer = match self.answer() {
            Ok(answer) => answer,
            Err(..) => return false,
        };
        // `next` also yields `Some(Err(..))` for a record that fails to
        // parse; that must not be mistaken for a matching record.
        answer
            .limit_to::<Data>()
            .next()
            .map_or(false, |item| item.is_ok())
    }

    /// Resolves the first question's QNAME through the answer's CNAMEs.
    ///
    /// Starting with the QNAME, the CNAME record owned by the current name
    /// is looked up in the answer section and its target becomes the new
    /// current name. The first name that owns no CNAME record is returned.
    /// Iteration over the answer section stops at the first record that
    /// fails to parse.
    ///
    /// Returns `None` if there is no question, the answer section can't be
    /// located, or the chain doesn't end within ANCOUNT + 1 steps (as with
    /// a CNAME loop).
    pub fn canonical_name(&self) -> Option<ParsedDname<&Octets>> {
        let question = match self.first_question() {
            None => return None,
            Some(question) => question,
        };
        let mut name = question.into_qname();
        let answer = match self.answer() {
            Ok(answer) => answer.limit_to::<Cname<_>>(),
            Err(_) => return None,
        };
        // A loop-free chain has at most ANCOUNT links, so ANCOUNT + 1
        // lookups suffice. The inclusive range avoids computing
        // `ancount + 1`, which overflows `u16` for a wire-supplied ANCOUNT
        // of 65535.
        for _ in 0..=self.header_counts().ancount() {
            let next = answer
                .clone()
                .filter_map(Result::ok)
                .find(|record| *record.owner() == name);
            match next {
                Some(record) => name = *record.data().cname(),
                None => return Some(name),
            }
        }
        None
    }

    /// Returns the first OPT record of the additional section.
    ///
    /// Returns `None` if the additional section can't be located, contains
    /// no OPT record, or a record fails to parse before one is found.
    pub fn opt(&self) -> Option<OptRecord<<&Octets as OctetsRef>::Range>> {
        match self.additional() {
            Ok(section) => match section.limit_to::<Opt<_>>().next() {
                Some(Ok(rr)) => Some(OptRecord::from(rr)),
                _ => None,
            },
            Err(_) => None,
        }
    }

    /// Returns the last record of the additional section as type `Data`.
    ///
    /// Returns `None` if the additional section is empty or can't be
    /// located, if any record fails to parse, or if `into_record` yields no
    /// `Data` value for the last record (presumably because its type
    /// differs).
    pub fn get_last_additional<'s, Data: ParseRecordData<&'s Octets>>(
        &'s self,
    ) -> Option<Record<ParsedDname<&'s Octets>, Data>> {
        let mut section = match self.additional() {
            Ok(section) => section,
            Err(_) => return None,
        };
        // Advance until only the last record remains.
        loop {
            match section.count {
                Err(_) => return None,
                Ok(0) => return None,
                Ok(1) => break,
                _ => {}
            }
            let _ = section.next();
        }
        let record = match ParsedRecord::parse(&mut section.parser) {
            Ok(record) => record,
            Err(_) => return None,
        };
        let record = match record.into_record() {
            Ok(Some(record)) => record,
            _ => return None,
        };
        Some(record)
    }

    /// Removes the last record of the additional section from the count.
    ///
    /// Only ARCOUNT in the header is decremented; the record's bytes remain
    /// in the underlying octets.
    /// NOTE(review): the behavior when ARCOUNT is already zero depends on
    /// `HeaderCounts::dec_arcount` — confirm whether it panics.
    pub fn remove_last_additional(&mut self)
    where
        Octets: AsMut<[u8]>,
    {
        HeaderCounts::for_message_slice_mut(self.octets.as_mut())
            .dec_arcount();
    }

    /// Copies records of this message into a message builder.
    ///
    /// Each record of the answer, authority, and additional sections is
    /// handed to `op`. Whatever `op` returns is pushed to the same section
    /// of `target`. The builder is returned positioned at its additional
    /// section.
    ///
    /// # Errors
    ///
    /// Returns `CopyRecordsError::Parse` if a section or record fails to
    /// parse. If pushing to `target` fails with `ShortBuf`, returns
    /// `CopyRecordsError::ShortBuf`.
    pub fn copy_records<'s, R, F, T, O>(
        &'s self,
        target: T,
        mut op: F,
    ) -> Result<AdditionalBuilder<O>, CopyRecordsError>
    where
        for<'a> &'a Octets: OctetsRef,
        R: AsRecord + 's,
        F: FnMut(ParsedRecord<&'s Octets>) -> Option<R>,
        T: Into<AnswerBuilder<O>>,
        O: OctetsBuilder + AsMut<[u8]>,
    {
        let mut source = self.answer()?;
        let mut target = target.into();
        for rr in &mut source {
            let rr = rr?;
            if let Some(rr) = op(rr) {
                target.push(rr)?;
            }
        }

        let mut source = source.next_section()?.unwrap();
        let mut target = target.authority();
        for rr in &mut source {
            let rr = rr?;
            if let Some(rr) = op(rr) {
                target.push(rr)?;
            }
        }

        let source = source.next_section()?.unwrap();
        let mut target = target.additional();
        for rr in source {
            let rr = rr?;
            if let Some(rr) = op(rr) {
                target.push(rr)?;
            }
        }

        Ok(target)
    }
}
impl<Octets> AsRef<Octets> for Message<Octets> {
    /// Returns the underlying octets sequence.
    fn as_ref(&self) -> &Octets {
        self.as_octets()
    }
}
impl<Octets: AsRef<[u8]>> AsRef<[u8]> for Message<Octets> {
    /// Returns the raw message bytes.
    fn as_ref(&self) -> &[u8] {
        self.as_slice()
    }
}
impl<Octets, SrcOctets> OctetsFrom<Message<SrcOctets>> for Message<Octets>
where
    Octets: OctetsFrom<SrcOctets>,
{
    /// Converts a message into one backed by a different octets type.
    ///
    /// # Errors
    ///
    /// Propagates the `ShortBuf` error of the octets conversion.
    fn octets_from(source: Message<SrcOctets>) -> Result<Self, ShortBuf> {
        let octets = Octets::octets_from(source.into_octets())?;
        // SAFETY: `source` upheld the length invariant; this relies on the
        // octets conversion preserving the bytes (assumed, not checked).
        Ok(unsafe { Self::from_octets_unchecked(octets) })
    }
}
impl<'a, Octets> IntoIterator for &'a Message<Octets>
where
    for<'s> &'s Octets: OctetsRef,
{
    type Item = Result<(ParsedRecord<&'a Octets>, Section), ParseError>;
    type IntoIter = MessageIter<&'a Octets>;

    /// Iterates over the records of the answer, authority, and additional
    /// sections. The iterator is empty if the answer section can't be
    /// located.
    fn into_iter(self) -> Self::IntoIter {
        let inner = self.answer().ok();
        MessageIter { inner }
    }
}
/// The question section of a message, iterable as a sequence of questions.
#[derive(Copy, Clone, Debug)]
pub struct QuestionSection<Ref> {
    /// Parser positioned at the next question to read.
    parser: Parser<Ref>,
    /// Questions still to be read, or the error that stopped parsing.
    count: Result<u16, ParseError>,
}
impl<Ref: OctetsRef> QuestionSection<Ref> {
    /// Creates the question section of the message in `octets`.
    ///
    /// The parser is positioned right after the header and the number of
    /// questions is taken from QDCOUNT.
    ///
    /// # Panics
    ///
    /// Panics (via `unwrap`) if the parser can't advance past the header.
    fn new(octets: Ref) -> Self {
        let mut parser = Parser::from_ref(octets);
        parser.advance(mem::size_of::<HeaderSection>()).unwrap();
        let qdcount =
            HeaderCounts::for_message_slice(parser.as_slice()).qdcount();
        QuestionSection {
            parser,
            count: Ok(qdcount),
        }
    }

    /// Returns the parser's current position within the message.
    pub fn pos(&self) -> usize {
        self.parser.pos()
    }

    /// Skips the remaining questions and returns the answer section.
    ///
    /// # Errors
    ///
    /// Returns the parse error if any question fails to parse.
    pub fn answer(mut self) -> Result<RecordSection<Ref>, ParseError> {
        // Drain the remaining questions to move the parser past them.
        for _ in self.by_ref() {}
        if let Err(err) = self.count {
            return Err(err);
        }
        Ok(RecordSection::new(self.parser, Section::first()))
    }

    /// Returns the section following this one, i.e. the answer section.
    pub fn next_section(self) -> Result<RecordSection<Ref>, ParseError> {
        self.answer()
    }
}
impl<Ref: OctetsRef> Iterator for QuestionSection<Ref> {
    type Item = Result<Question<ParsedDname<Ref>>, ParseError>;

    /// Parses the next question.
    ///
    /// After a parse error the error is stored, and every later call
    /// returns `None`.
    fn next(&mut self) -> Option<Self::Item> {
        let remaining = match self.count {
            Ok(remaining) if remaining > 0 => remaining,
            _ => return None,
        };
        let item: Self::Item = Question::parse(&mut self.parser);
        self.count = match &item {
            Ok(_) => Ok(remaining - 1),
            Err(err) => Err(*err),
        };
        Some(item)
    }
}
impl<Ref, Other> PartialEq<QuestionSection<Other>> for QuestionSection<Ref>
where
    Ref: OctetsRef,
    Other: OctetsRef,
{
    /// Compares the remaining questions of both sections pairwise.
    ///
    /// The sections are equal only if both have the same number of
    /// questions left, all of them parse, and each pair compares equal.
    fn eq(&self, other: &QuestionSection<Other>) -> bool {
        let (mut lhs, mut rhs) = (*self, *other);
        loop {
            let (left, right) = match (lhs.next(), rhs.next()) {
                (None, None) => return true,
                (Some(Ok(left)), Some(Ok(right))) => (left, right),
                _ => return false,
            };
            if left != right {
                return false;
            }
        }
    }
}
/// Identifies one of the three record sections of a message.
///
/// Variants are declared in the order the sections appear in a message, so
/// the derived ordering follows message order. `Ord` and `Hash` are derived
/// too, so sections can be sorted and used as map or set keys.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum Section {
    /// The answer section (prerequisite section in UPDATE messages).
    Answer,
    /// The authority section (update section in UPDATE messages).
    Authority,
    /// The additional section.
    Additional,
}
impl Section {
pub fn first() -> Self {
Section::Answer
}
fn count(self, counts: HeaderCounts) -> u16 {
match self {
Section::Answer => counts.ancount(),
Section::Authority => counts.nscount(),
Section::Additional => counts.arcount(),
}
}
pub(crate) fn next_section(self) -> Option<Self> {
match self {
Section::Answer => Some(Section::Authority),
Section::Authority => Some(Section::Additional),
Section::Additional => None,
}
}
}
/// One record section of a message, iterable as a sequence of records.
#[derive(Copy, Clone, Debug)]
pub struct RecordSection<Ref> {
    /// Parser positioned at the next record to read.
    parser: Parser<Ref>,
    /// Which section this is.
    section: Section,
    /// Records still to be read, or the error that stopped parsing.
    count: Result<u16, ParseError>,
}
impl<Ref: OctetsRef> RecordSection<Ref> {
    /// Creates the record section `section`.
    ///
    /// The record count is read from the header of the message underlying
    /// `parser`. Callers in this module pass a parser positioned at the
    /// start of the section.
    fn new(parser: Parser<Ref>, section: Section) -> Self {
        let counts = *HeaderCounts::for_message_slice(parser.as_slice());
        let count = Ok(section.count(counts));
        RecordSection {
            parser,
            section,
            count,
        }
    }

    /// Returns the parser's current position within the message.
    pub fn pos(&self) -> usize {
        self.parser.pos()
    }

    /// Returns an iterator over only the records of type `Data`.
    pub fn limit_to<Data: ParseRecordData<Ref>>(
        self,
    ) -> RecordIter<Ref, Data> {
        RecordIter::new(self, false)
    }

    /// Returns an iterator over only the class IN records of type `Data`.
    pub fn limit_to_in<Data: ParseRecordData<Ref>>(
        self,
    ) -> RecordIter<Ref, Data> {
        RecordIter::new(self, true)
    }

    /// Skips the remaining records and returns the following section.
    ///
    /// Returns `Ok(None)` if this is the additional section.
    ///
    /// # Errors
    ///
    /// Returns the parse error if skipping a record fails (or failed
    /// earlier).
    pub fn next_section(mut self) -> Result<Option<Self>, ParseError> {
        let following = match self.section.next_section() {
            None => return Ok(None),
            Some(following) => following,
        };
        while self.skip_next().is_some() {}
        if let Err(err) = self.count {
            return Err(err);
        }
        Ok(Some(RecordSection::new(self.parser, following)))
    }

    /// Skips over the next record without fully parsing it.
    ///
    /// Returns `None` when no records are left or after an earlier error.
    fn skip_next(&mut self) -> Option<Result<(), ParseError>> {
        let remaining = match self.count {
            Ok(remaining) if remaining > 0 => remaining,
            _ => return None,
        };
        match ParsedRecord::skip(&mut self.parser) {
            Ok(_) => {
                self.count = Ok(remaining - 1);
                Some(Ok(()))
            }
            Err(err) => {
                self.count = Err(err);
                Some(Err(err))
            }
        }
    }
}
impl<Ref: OctetsRef> Iterator for RecordSection<Ref> {
    type Item = Result<ParsedRecord<Ref>, ParseError>;

    /// Parses the next record of the section.
    ///
    /// After a parse error the error is stored, and every later call
    /// returns `None`.
    fn next(&mut self) -> Option<Self::Item> {
        let remaining = match self.count {
            Ok(remaining) if remaining > 0 => remaining,
            _ => return None,
        };
        let item: Self::Item = ParsedRecord::parse(&mut self.parser);
        self.count = match &item {
            Ok(_) => Ok(remaining - 1),
            Err(err) => Err(*err),
        };
        Some(item)
    }
}
/// An iterator over all records of a message.
///
/// Created by `Message::iter` or by iterating over `&Message`. Yields each
/// record of the answer, authority, and additional sections together with
/// the section it was found in. Derives `Clone` and `Debug` like the other
/// iterator types of this module.
#[derive(Clone, Debug)]
pub struct MessageIter<Ref> {
    /// The section currently being iterated, or `None` once iteration has
    /// ended.
    inner: Option<RecordSection<Ref>>,
}
impl<Ref: OctetsRef> Iterator for MessageIter<Ref> {
    type Item = Result<(ParsedRecord<Ref>, Section), ParseError>;

    /// Returns the next record together with the section it belongs to.
    ///
    /// Each error is returned exactly once. After that the iterator is
    /// exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let inner = self.inner.as_mut()?;
            match inner.next() {
                Some(Ok(record)) => return Some(Ok((record, inner.section))),
                Some(Err(err)) => {
                    // Stop here. Otherwise moving on via `next_section`
                    // would report the stored error a second time.
                    self.inner = None;
                    return Some(Err(err));
                }
                None => {}
            }
            // The current section is exhausted; move on to the next one.
            // `take` leaves `None` behind, so a missing next section or a
            // failure to locate it ends iteration.
            let current = self.inner.take()?;
            match current.next_section() {
                Ok(following) => self.inner = following,
                Err(err) => return Some(Err(err)),
            }
        }
    }
}
/// An iterator over the records of a section that have type `Data`.
#[derive(Copy, Clone, Debug)]
pub struct RecordIter<Ref, Data> {
    /// The section being iterated.
    section: RecordSection<Ref>,
    /// If `true`, records whose class isn't IN are skipped.
    in_only: bool,
    /// Ties the iterator to the record data type it produces.
    marker: PhantomData<Data>,
}
impl<Ref: OctetsRef, Data: ParseRecordData<Ref>> RecordIter<Ref, Data> {
    /// Wraps `section`, optionally skipping records whose class isn't IN.
    fn new(section: RecordSection<Ref>, in_only: bool) -> Self {
        Self {
            section,
            in_only,
            marker: PhantomData,
        }
    }

    /// Returns the underlying record section, dropping the type filter.
    pub fn unwrap(self) -> RecordSection<Ref> {
        let RecordIter { section, .. } = self;
        section
    }

    /// Skips the remaining records and returns the following section.
    pub fn next_section(
        self,
    ) -> Result<Option<RecordSection<Ref>>, ParseError> {
        self.unwrap().next_section()
    }
}
impl<Ref, Data> Iterator for RecordIter<Ref, Data>
where
    Ref: OctetsRef,
    Data: ParseRecordData<Ref>,
{
    type Item = Result<Record<ParsedDname<Ref>, Data>, ParseError>;

    /// Returns the next record that converts into `Data`.
    ///
    /// Records that `into_record` turns into `Ok(None)` (presumably those
    /// of another type) are skipped. When `in_only` is set, records of a
    /// class other than IN are also skipped.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let parsed = match self.section.next()? {
                Ok(parsed) => parsed,
                Err(err) => return Some(Err(err)),
            };
            let wrong_class = self.in_only && parsed.class() != Class::In;
            if wrong_class {
                continue;
            }
            if let Some(item) = parsed.into_record().transpose() {
                return Some(item);
            }
        }
    }
}
/// An error returned by `Message::copy_records`.
#[derive(Copy, Clone, Debug)]
pub enum CopyRecordsError {
    /// Parsing the source message failed.
    Parse(ParseError),
    /// Pushing a record to the target builder failed with `ShortBuf`.
    ShortBuf,
}
impl From<ParseError> for CopyRecordsError {
fn from(err: ParseError) -> Self {
CopyRecordsError::Parse(err)
}
}
impl From<ShortBuf> for CopyRecordsError {
fn from(_: ShortBuf) -> Self {
CopyRecordsError::ShortBuf
}
}
impl fmt::Display for CopyRecordsError {
    /// Formats the error by delegating to the wrapped error's `Display`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            CopyRecordsError::Parse(err) => fmt::Display::fmt(err, f),
            CopyRecordsError::ShortBuf => fmt::Display::fmt(&ShortBuf, f),
        }
    }
}
// Lets `CopyRecordsError` be used as a `std::error::Error` when the `std`
// feature is enabled; relies on its `Debug` derive and `Display` impl.
#[cfg(feature = "std")]
impl std::error::Error for CopyRecordsError {}
#[cfg(test)]
mod test {
use super::*;
#[cfg(feature = "std")]
use crate::base::message_builder::MessageBuilder;
#[cfg(feature = "std")]
use crate::base::name::Dname;
#[cfg(feature = "std")]
use crate::rdata::{AllRecordData, Ns};
#[cfg(feature = "std")]
use std::vec::Vec;
// Builds a message with one CNAME record in the answer section and one NS
// record in the authority section; question and additional are empty.
#[cfg(feature = "std")]
fn get_test_message() -> Message<Vec<u8>> {
let msg = MessageBuilder::new_vec();
let mut msg = msg.answer();
msg.push((
Dname::vec_from_str("foo.example.com.").unwrap(),
86000,
Cname::new(Dname::vec_from_str("baz.example.com.").unwrap()),
))
.unwrap();
let mut msg = msg.authority();
msg.push((
Dname::vec_from_str("bar.example.com.").unwrap(),
86000,
Ns::new(Dname::vec_from_str("baz.example.com.").unwrap()),
))
.unwrap();
msg.into_message()
}
// `from_octets` rejects 11 octets and accepts 12, i.e. the header section
// is 12 octets long.
#[test]
fn short_message() {
assert!(Message::from_octets(&[0u8; 11]).is_err());
assert!(Message::from_octets(&[0u8; 12]).is_ok());
}
// Exercises CNAME chain resolution: no answers, an out-of-order chain, a
// CNAME loop, and a loop with an unrelated A record appended.
#[test]
#[cfg(feature = "std")]
fn canonical_name() {
use crate::rdata::A;
let mut msg = MessageBuilder::new_vec().question();
msg.push((Dname::vec_from_str("example.com.").unwrap(), Rtype::A))
.unwrap();
// Without answers the canonical name is the QNAME itself.
let msg_ref = msg.as_message();
assert_eq!(
Dname::vec_from_str("example.com.").unwrap(),
msg_ref.canonical_name().unwrap()
);
// Chain example.com -> foo -> bar -> baz, pushed out of order so the
// lookup can't rely on record order.
let mut msg = msg.answer();
msg.push((
Dname::vec_from_str("bar.example.com.").unwrap(),
86000,
Cname::new(Dname::vec_from_str("baz.example.com.").unwrap()),
))
.unwrap();
msg.push((
Dname::vec_from_str("example.com.").unwrap(),
86000,
Cname::new(Dname::vec_from_str("foo.example.com.").unwrap()),
))
.unwrap();
msg.push((
Dname::vec_from_str("foo.example.com.").unwrap(),
86000,
Cname::new(Dname::vec_from_str("bar.example.com.").unwrap()),
))
.unwrap();
let msg_ref = msg.as_message();
assert_eq!(
Dname::vec_from_str("baz.example.com.").unwrap(),
msg_ref.canonical_name().unwrap()
);
// baz -> foo closes the loop foo -> bar -> baz -> foo: no result.
msg.push((
Dname::vec_from_str("baz.example.com").unwrap(),
86000,
Cname::new(Dname::vec_from_str("foo.example.com").unwrap()),
))
.unwrap();
assert!(msg.as_message().canonical_name().is_none());
// A non-CNAME record for a name in the loop doesn't break the loop.
msg.push((
Dname::vec_from_str("baz.example.com").unwrap(),
86000,
A::from_octets(127, 0, 0, 1),
))
.unwrap();
assert!(msg.as_message().canonical_name().is_none());
}
// The message iterator reports the section each record comes from.
#[test]
#[cfg(feature = "std")]
fn message_iterator() {
let msg = get_test_message();
let mut iter = msg.iter();
let (_rr, section) = iter.next().unwrap().unwrap();
assert_eq!(Section::Answer, section);
let (_rr, section) = iter.next().unwrap().unwrap();
assert_eq!(Section::Authority, section);
}
// Copies only CNAME records: the answer's CNAME is kept while the
// authority's NS record is filtered out by the closure.
#[test]
#[cfg(feature = "std")]
fn copy_records() {
let msg = get_test_message();
let msg = msg.for_slice();
let target = MessageBuilder::new_vec().question();
let res = msg.copy_records(target.answer(), |rr| {
if let Ok(Some(rr)) =
rr.into_record::<AllRecordData<_, ParsedDname<_>>>()
{
if rr.rtype() == Rtype::Cname {
return Some(rr);
}
}
None
});
assert!(res.is_ok());
if let Ok(target) = res {
let msg = target.into_message();
assert_eq!(1, msg.header_counts().ancount());
assert_eq!(0, msg.header_counts().arcount());
}
}
}