mod cmp;
mod inherent;
mod iter;
pub use self::iter::{CharacterStrings, Labels, Questions, Records};
use core::{
fmt::{Debug, Display},
ops::{Bound, Range, RangeBounds},
};
use byteorder::{ByteOrder, NetworkEndian};
use crate::{core::Type, seen::Seen};
/// Outcome of viewing a value in place.
///
/// On success this holds the viewed value together with the part of the input
/// range that remains after it, so parsers can be chained by feeding the
/// remaining range into the next view.
pub type Result<V, E> = core::result::Result<(V, Range<usize>), E>;
/// A type that can be viewed (parsed without copying) from a byte buffer.
///
/// Implementors provide [`View::view_range`]. [`View::view`] is a convenience
/// wrapper that accepts any [`RangeBounds`]. On success both return the value
/// together with the part of the input range left after it.
pub trait View<'s>: 's {
    /// Error reported when the bytes cannot be viewed as `Self`.
    type Error: Debug + Display;

    /// Views `Self` at the start of `range` within `source`.
    ///
    /// The bounds are normalised to a half-open `start..stop` range and then
    /// forwarded to [`View::view_range`]. An unbounded start becomes `0` and
    /// an unbounded end becomes `source.len()`.
    ///
    /// Turning an excluded start or an included end into a half-open bound
    /// saturates at `usize::MAX` instead of overflowing. Without this, inputs
    /// such as `..=usize::MAX` would panic here in debug builds or wrap to a
    /// bogus range in release builds. Saturating passes them on to
    /// `view_range`, whose bounds checks can report them as errors.
    fn view(source: &'s [u8], range: impl RangeBounds<usize>) -> Result<Self, Self::Error>
    where
        Self: Sized,
    {
        let start = match range.start_bound() {
            Bound::Included(&x) => x,
            Bound::Excluded(&x) => x.saturating_add(1),
            Bound::Unbounded => 0,
        };
        let stop = match range.end_bound() {
            Bound::Included(&x) => x.saturating_add(1),
            Bound::Excluded(&x) => x,
            Bound::Unbounded => source.len(),
        };
        Self::view_range(source, start..stop)
    }

    /// Views `Self` at the start of the half-open `range` within `source`.
    ///
    /// Returns the value and the remainder of `range` after it.
    fn view_range(source: &'s [u8], range: Range<usize>) -> Result<Self, Self::Error>
    where
        Self: Sized;
}
/// A parsed value that refers back into the buffer it was viewed from.
pub trait BorrowedView<'s> {
    /// The whole buffer this view was taken from.
    fn source(&self) -> &'s [u8];

    /// Byte offset of the view's first byte within [`BorrowedView::source`].
    fn offset(&self) -> usize;

    /// Number of bytes the view occupies, starting at
    /// [`BorrowedView::offset`].
    fn len(&self) -> usize;

    /// The bytes covered by this view.
    ///
    /// Panics if `offset` or `offset + len` lies outside `source`.
    fn as_bytes(&self) -> &'s [u8] {
        let tail = &self.source()[self.offset()..];
        &tail[..self.len()]
    }
}
/// Conversion from a borrowed view into an owned representation `Owned`.
///
/// NOTE(review): the method shares its name with `ToOwned::to_owned`. For
/// view types that also derive `Clone` (e.g. `Message`, `Record`), a plain
/// `x.to_owned()` call may be ambiguous when this trait is in scope.
/// Fully-qualified syntax may be needed; confirm at call sites.
pub trait ViewToOwned<Owned> {
    /// Builds the owned counterpart of this view.
    fn to_owned(&self) -> Owned;
}
/// A borrowed view of a complete DNS message within `source`.
///
/// Only the byte length of each section is stored, alongside the OPT
/// pseudo-record if the additional section contained one.
#[derive(Debug, Clone)]
pub struct Message<'s> {
    /// The buffer the message was viewed from.
    source: &'s [u8],
    /// Offset of the message (its header) within `source`.
    offset: usize,
    /// Length in bytes of the question section (QDCOUNT questions).
    qd_len: usize,
    /// Length in bytes of the answer section (ANCOUNT records).
    an_len: usize,
    /// Length in bytes of the authority section (NSCOUNT records).
    ns_len: usize,
    /// Length in bytes of the additional section (ARCOUNT records).
    ar_len: usize,
    /// The OPT record found in the additional section, if any.
    opt: Option<Extension<'s>>,
}
/// A borrowed view of the fixed 12-byte message header.
#[derive(Debug, Clone)]
pub struct Header<'s> {
    /// The buffer the header was viewed from.
    source: &'s [u8],
    /// Offset of the header's first byte within `source`.
    offset: usize,
}
/// A borrowed view of one entry in the question section.
#[derive(Debug, Clone)]
pub struct Question<'s> {
    /// The buffer the question was viewed from.
    source: &'s [u8],
    /// Offset of the question's owner name within `source`.
    offset: usize,
    /// Bytes occupied in place: the name as stored plus four fixed bytes.
    len: usize,
}
/// A borrowed view of one resource record.
///
/// Layout in `source`: owner name (`name_len` bytes), ten fixed bytes (the
/// last two being the 16-bit RDATA length), then `rdata_len` bytes of RDATA.
#[derive(Debug, Clone)]
pub struct Record<'s> {
    /// The buffer the record was viewed from.
    source: &'s [u8],
    /// Offset of the record's owner name within `source`.
    offset: usize,
    /// Bytes the owner name occupies in place (compression pointers are not
    /// expanded).
    name_len: usize,
    /// RDATA length, as read from the record's 16-bit length field.
    rdata_len: usize,
}
/// A borrowed view of a (possibly compressed) domain name.
#[derive(Debug, Clone)]
pub struct Name<'s> {
    /// The buffer the name was viewed from; pointers resolve against it.
    source: &'s [u8],
    /// Offset of the name's first label within `source`.
    offset: usize,
    /// Bytes occupied at `offset`. This runs through the null label, or
    /// through the first compression pointer if one is present.
    len: usize,
}
/// A borrowed view of one label: a length byte plus text, or a two-byte
/// compression pointer.
#[derive(Debug, Clone)]
pub struct Label<'s> {
    /// The buffer the label was viewed from.
    source: &'s [u8],
    /// Offset of the label's leading length/pointer byte within `source`.
    offset: usize,
    /// Total bytes of the label, including the leading byte.
    len: usize,
}
/// An OPT record from a message's additional section.
///
/// Instances are produced by `Extension::wrap`, which validates the record
/// (see `ExtensionError`).
#[derive(Debug, Clone)]
pub struct Extension<'s> {
    /// The underlying OPT record.
    pub inner: Record<'s>,
}
/// A borrowed view of a length-prefixed character string: one length byte
/// followed by that many bytes.
///
/// NOTE(review): unlike the sibling view types, this struct does not derive
/// `Debug`/`Clone`. Confirm whether a manual impl exists in another module or
/// whether the derive was simply omitted.
pub struct CharacterString<'s> {
    /// The buffer the string was viewed from.
    source: &'s [u8],
    /// Offset of the length byte within `source`.
    offset: usize,
    /// Total bytes, including the length byte.
    len: usize,
}
// Project macro (defined elsewhere); presumably generates error-trait
// boilerplate for the type below. TODO(review): confirm against its definition.
error!(BoundsError);
// A read does not fit inside the requested byte range. Produced by `assert`.
//
// `displaydoc::Display` renders `///` doc comments as the Display message, so
// these notes deliberately use `//` to leave the rendered text unchanged.
#[derive(Debug, displaydoc::Display)]
#[prefix_enum_doc_attributes]
pub enum BoundsError {
    // (start, end): the range ends before it starts.
    BackwardsRange(usize, usize),
    // (end, source length): the range extends past the end of the source.
    RangeOverflow(usize, usize),
    // (wanted, start, end): the range holds fewer than `wanted` bytes.
    ReadOverflow(usize, usize, usize),
}
// Project macro (defined elsewhere). The arguments appear to name the EDNS
// version field and the inner error. TODO(review): confirm semantics.
error!(MessageError(_edns_version, [inner: inner]));
// Failure to view a `Message`.
//
// Field 0 is the EDNS version (read from an OPT record's TTL) relevant to the
// error, or `None` if no OPT record had been seen. Field 1 says what went
// wrong. Notes use `//` because displaydoc renders `///` as the message.
#[derive(Debug, displaydoc::Display)]
pub struct MessageError(pub Option<u8>, pub MessageErrorKind);
// Project macro (defined elsewhere); the listed variants presumably wrap
// inner errors. TODO(review): confirm.
error!(MessageErrorKind, Header, Question, Record, Extension);
// What went wrong while viewing a `Message`. Notes use `//` because
// displaydoc renders `///` doc comments as the Display message.
#[derive(Debug, displaydoc::Display)]
pub enum MessageErrorKind {
    // The 12-byte header did not fit in the range.
    Header(BoundsError),
    // A question-section entry failed to parse.
    Question(QuestionError),
    // A resource record in any record section failed to parse.
    Record(RecordError),
    // `Extension::wrap` rejected the OPT record.
    Extension(ExtensionError),
    // An OPT record appeared in the answer or authority section.
    MisplacedOptRecord,
    // The additional section held more than one OPT record.
    MultipleOptRecords,
}
error!(QuestionError, Bounds, Name);
// Failure to view a `Question`. Notes use `//` because displaydoc renders
// `///` doc comments as the Display message.
#[derive(Debug, displaydoc::Display)]
#[prefix_enum_doc_attributes]
pub enum QuestionError {
    // The four fixed bytes after the name did not fit.
    Bounds(BoundsError),
    // The owner name failed to parse.
    Name(NameError),
}
error!(RecordError, Bounds, Name);
// Failure to view a `Record`. Notes use `//` because displaydoc renders
// `///` doc comments as the Display message.
#[derive(Debug, displaydoc::Display)]
#[prefix_enum_doc_attributes]
pub enum RecordError {
    // The fixed fields, RDATA length, or RDATA did not fit.
    Bounds(BoundsError),
    // The owner name failed to parse.
    Name(NameError),
}
error!(NameError, Label);
// Failure to view a `Name`. Notes use `//` because displaydoc renders `///`
// doc comments as the Display message.
#[derive(Debug, displaydoc::Display)]
#[prefix_enum_doc_attributes]
pub enum NameError {
    // A label (or pointer) failed to parse.
    Label(LabelError),
    // A compression pointer target was visited twice.
    PointerCycle,
    // The name's non-null label bytes exceeded 255.
    TooLong,
}
error!(LabelError, Bounds);
// Failure to view a `Label`. Notes use `//` because displaydoc renders `///`
// doc comments as the Display message.
#[derive(Debug, displaydoc::Display)]
#[prefix_enum_doc_attributes]
pub enum LabelError {
    // The length byte, label text, or second pointer byte did not fit.
    Bounds(BoundsError),
    // The leading byte was in the reserved range 64..=191.
    ReservedLength,
}
error!(ExtensionError);
// Reasons `Extension::wrap` (defined elsewhere) rejects an OPT record.
// The exact checks are not visible here. Notes use `//` because displaydoc
// renders `///` doc comments as the Display message.
#[derive(Debug, displaydoc::Display)]
#[prefix_enum_doc_attributes]
pub enum ExtensionError {
    // Presumably: the OPT record's owner name is not acceptable. TODO confirm.
    BadName,
    // Presumably: the EDNS version is not supported. TODO confirm.
    UnimplementedVersion,
}
impl<'s> View<'s> for Message<'s> {
    type Error = MessageError;

    /// Views a whole DNS message: the header followed by its four sections.
    ///
    /// Each section is walked using the counts in the header, and only its
    /// byte length is kept. An OPT record is accepted only in the additional
    /// section, and at most once; it becomes the message's [`Extension`].
    /// A misplaced OPT record reports its own EDNS version. Errors in the
    /// additional section carry the version of the OPT record seen so far, if
    /// any. All other errors carry `None`.
    fn view_range(source: &'s [u8], range: Range<usize>) -> Result<Self, Self::Error> {
        use self::MessageErrorKind as Kind;

        let (header, after_header) =
            Header::view(source, range.clone()).map_err(|e| MessageError(None, Kind::Header(e)))?;
        let (qdcount, ancount, nscount, arcount) = (
            header.qdcount(),
            header.ancount(),
            header.nscount(),
            header.arcount(),
        );

        // Question section.
        let qd_start = after_header.start;
        let mut cursor = after_header;
        for _ in 0..qdcount {
            cursor = Question::view(source, cursor)
                .map_err(|e| MessageError(None, Kind::Question(e)))?
                .1;
        }
        let qd_len = cursor.start - qd_start;

        // Answer section: OPT records are not allowed here.
        let an_start = cursor.start;
        for _ in 0..ancount {
            let (record, after) =
                Record::view(source, cursor).map_err(|e| MessageError(None, Kind::Record(e)))?;
            if record.r#type() == Type::OPT {
                let version = record.ttl().edns_version();
                return Err(MessageError(Some(version), Kind::MisplacedOptRecord));
            }
            cursor = after;
        }
        let an_len = cursor.start - an_start;

        // Authority section: OPT records are not allowed here either.
        let ns_start = cursor.start;
        for _ in 0..nscount {
            let (record, after) =
                Record::view(source, cursor).map_err(|e| MessageError(None, Kind::Record(e)))?;
            if record.r#type() == Type::OPT {
                let version = record.ttl().edns_version();
                return Err(MessageError(Some(version), Kind::MisplacedOptRecord));
            }
            cursor = after;
        }
        let ns_len = cursor.start - ns_start;

        // Additional section: at most one OPT record, kept as `opt`.
        let ar_start = cursor.start;
        let mut opt = None;
        let mut edns_version = None;
        for _ in 0..arcount {
            let (record, after) = Record::view(source, cursor)
                .map_err(|e| MessageError(edns_version, Kind::Record(e)))?;
            cursor = after;
            if record.r#type() != Type::OPT {
                continue;
            }
            if opt.is_some() {
                return Err(MessageError(edns_version, Kind::MultipleOptRecords));
            }
            edns_version = Some(record.ttl().edns_version());
            let extension = Extension::wrap(record)
                .map_err(|e| MessageError(edns_version, Kind::Extension(e)))?;
            opt = Some(extension);
        }
        let ar_len = cursor.start - ar_start;

        let message = Message {
            source,
            offset: range.start,
            qd_len,
            an_len,
            ns_len,
            ar_len,
            opt,
        };
        Ok((message, cursor))
    }
}
impl<'s> BorrowedView<'s> for Message<'s> {
fn source(&self) -> &'s [u8] {
self.source
}
fn offset(&self) -> usize {
self.offset
}
fn len(&self) -> usize {
self.header().len() + self.qd_len + self.an_len + self.ns_len + self.ar_len
}
}
impl<'s> View<'s> for Header<'s> {
    type Error = BoundsError;

    /// Views the fixed 12-byte header at the start of `range`. The returned
    /// range begins immediately after the header.
    fn view_range(source: &'s [u8], range: Range<usize>) -> Result<Self, Self::Error> {
        const HEADER_LEN: usize = 12;
        assert(source, &range, HEADER_LEN)?;
        let header = Header {
            source,
            offset: range.start,
        };
        let rest = range.start + HEADER_LEN..range.end;
        Ok((header, rest))
    }
}
impl<'s> BorrowedView<'s> for Header<'s> {
fn source(&self) -> &'s [u8] {
self.source
}
fn offset(&self) -> usize {
self.offset
}
fn len(&self) -> usize {
12
}
}
impl<'s> View<'s> for Question<'s> {
type Error = QuestionError;
fn view_range(source: &'s [u8], range: Range<usize>) -> Result<Self, Self::Error> {
let (_, mut rest) = Name::view(source, range.clone()).map_err(Self::Error::Name)?;
assert(source, &rest, 4).map_err(Self::Error::Bounds)?;
rest.start += 4;
Ok((
Question {
source,
offset: range.start,
len: rest.start - range.start,
},
rest,
))
}
}
impl<'s> BorrowedView<'s> for Question<'s> {
fn source(&self) -> &'s [u8] {
self.source
}
fn offset(&self) -> usize {
self.offset
}
fn len(&self) -> usize {
self.len
}
}
impl<'s> View<'s> for Record<'s> {
type Error = RecordError;
fn view_range(source: &'s [u8], range: Range<usize>) -> Result<Self, Self::Error> {
let (_, mut rest) = Name::view(source, range.clone()).map_err(Self::Error::Name)?;
let name_len = rest.start - range.start;
assert(source, &rest, 8).map_err(Self::Error::Bounds)?;
rest.start += 8;
let (rdlength, mut rest) = u16::view(source, rest).map_err(Self::Error::Bounds)?;
let rdata_len = usize::from(rdlength);
assert(source, &rest, rdata_len).map_err(Self::Error::Bounds)?;
rest.start += rdata_len;
Ok((
Record {
source,
offset: range.start,
name_len,
rdata_len,
},
rest,
))
}
}
impl<'s> BorrowedView<'s> for Record<'s> {
fn source(&self) -> &'s [u8] {
self.source
}
fn offset(&self) -> usize {
self.offset
}
fn len(&self) -> usize {
self.name_len + 10 + self.rdata_len
}
}
impl<'s> View<'s> for Name<'s> {
type Error = NameError;
fn view_range(source: &'s [u8], range: Range<usize>) -> Result<Self, Self::Error> {
let mut result = None;
let mut next = range.clone();
let mut seen = Seen::default();
let mut len = 0;
let rest = loop {
let (label, rest) = Label::view(source, next.clone()).map_err(Self::Error::Label)?;
if label.is_null() {
break result.unwrap_or(rest);
} else if let Some(offset) = label.pointer() {
if seen.see(offset) {
return Err(Self::Error::PointerCycle);
}
next = offset.into()..source.len();
result.get_or_insert_with(|| rest.clone());
} else {
len += label.len();
if len > 255 {
return Err(Self::Error::TooLong);
}
next = rest;
}
};
let offset = range.start;
let len = rest.start - offset;
Ok((
Name {
source,
offset,
len,
},
rest,
))
}
}
impl<'s> BorrowedView<'s> for Name<'s> {
fn source(&self) -> &'s [u8] {
self.source
}
fn offset(&self) -> usize {
self.offset
}
fn len(&self) -> usize {
self.len
}
}
impl<'s> View<'s> for Label<'s> {
type Error = LabelError;
fn view_range(source: &'s [u8], range: Range<usize>) -> Result<Self, Self::Error> {
let (length, mut rest) = u8::view(source, range.clone()).map_err(Self::Error::Bounds)?;
let len = if length < 64 {
usize::from(length)
} else if length < 192 {
return Err(Self::Error::ReservedLength);
} else {
1
};
assert(source, &rest, len).map_err(Self::Error::Bounds)?;
rest.start += len;
Ok((
Label {
source,
offset: range.start,
len: rest.start - range.start,
},
rest,
))
}
}
impl<'s> BorrowedView<'s> for Label<'s> {
fn source(&self) -> &'s [u8] {
self.source
}
fn offset(&self) -> usize {
self.offset
}
fn len(&self) -> usize {
self.len
}
}
impl<'s> BorrowedView<'s> for Extension<'s> {
    /// The buffer the underlying OPT record was viewed from.
    fn source(&self) -> &'s [u8] {
        let record = &self.inner;
        record.source()
    }

    /// Offset of the underlying OPT record.
    fn offset(&self) -> usize {
        let record = &self.inner;
        record.offset()
    }

    /// Length of the underlying OPT record.
    fn len(&self) -> usize {
        let record = &self.inner;
        record.len()
    }
}
impl<'s> View<'s> for CharacterString<'s> {
    type Error = BoundsError;

    /// Views a character string: one length byte followed by that many bytes.
    fn view_range(source: &'s [u8], range: Range<usize>) -> Result<Self, Self::Error> {
        let (length, text) = u8::view(source, range.clone())?;
        let text_len = usize::from(length);
        assert(source, &text, text_len)?;
        let end = text.start + text_len;
        let string = CharacterString {
            source,
            offset: range.start,
            len: end - range.start,
        };
        Ok((string, end..text.end))
    }
}
impl<'s> BorrowedView<'s> for CharacterString<'s> {
fn source(&self) -> &'s [u8] {
self.source
}
fn offset(&self) -> usize {
self.offset
}
fn len(&self) -> usize {
self.len
}
}
impl View<'_> for u32 {
    type Error = BoundsError;

    /// Reads a network-order (big-endian) `u32` from the front of `range`.
    fn view_range(source: &[u8], range: Range<usize>) -> Result<Self, Self::Error> {
        assert(source, &range, 4)?;
        let Range { start, end } = range;
        let value = NetworkEndian::read_u32(&source[start..start + 4]);
        Ok((value, start + 4..end))
    }
}
impl View<'_> for u16 {
    type Error = BoundsError;

    /// Reads a network-order (big-endian) `u16` from the front of `range`.
    fn view_range(source: &[u8], range: Range<usize>) -> Result<Self, Self::Error> {
        assert(source, &range, 2)?;
        let Range { start, end } = range;
        let value = NetworkEndian::read_u16(&source[start..start + 2]);
        Ok((value, start + 2..end))
    }
}
impl View<'_> for u8 {
    type Error = BoundsError;

    /// Reads the single byte at the front of `range`.
    fn view_range(source: &[u8], range: Range<usize>) -> Result<Self, Self::Error> {
        assert(source, &range, 1)?;
        let Range { start, end } = range;
        Ok((source[start], start + 1..end))
    }
}
/// Checks that `len` bytes can be read from the front of `range` in `source`.
///
/// # Errors
///
/// - `BackwardsRange(start, end)` if the range ends before it starts.
/// - `RangeOverflow(end, source.len())` if the range runs past the source.
/// - `ReadOverflow(len, start, end)` if the range holds fewer than `len` bytes.
pub(crate) fn assert(
    source: &[u8],
    range: &Range<usize>,
    len: usize,
) -> core::result::Result<(), BoundsError> {
    let (start, end) = (range.start, range.end);
    let available = source.len();
    if start > end {
        Err(BoundsError::BackwardsRange(start, end))
    } else if end > available {
        Err(BoundsError::RangeOverflow(end, available))
    } else if end - start < len {
        Err(BoundsError::ReadOverflow(len, start, end))
    } else {
        Ok(())
    }
}
// Unit tests for the view parsers and the `assert` bounds helper.
#[cfg(test)]
mod test {
    use core::iter::once;
    use arrayvec::ArrayVec;
    use assert_matches::assert_matches;
    use super::{Extension, Label, Message, Name, Question, Record, View};

    // Fixed-capacity (4096-byte) buffer for assembling test inputs.
    // NOTE(review): the `A12` name does not reflect the capacity.
    type A12 = ArrayVec<u8, 4096>;

    // Views a captured sample message end to end, then checks the OPT
    // placement rules on hand-built messages.
    #[test]
    fn message() {
        use super::MessageError;
        use super::MessageErrorKind as Kind;
        let source = include_bytes!("../samples/daria.daz.cat.a.dns");
        let len = source.len();
        assert_matches!(
            Message::view(source, 0..len),
            Ok((
                Message {
                    source: _,
                    opt: Some(Extension {
                        inner: Record {
                            source: _,
                            offset: 141,
                            name_len: 1,
                            rdata_len: 0
                        }
                    }),
                    offset: 0,
                    qd_len: 19,
                    an_len: 16,
                    ns_len: 66,
                    ar_len: 39
                },
                ref x
            )) if *x == (len..len)
        );
        // Header with one question and one answer record.
        let header = b"\x13\x13\x00\x00\x00\x01\x00\x01\x00\x00\x00\x00";
        // Root name followed by four fixed bytes.
        let question = b"\0\x00\x02\x00\x01";
        // OPT records (type bytes 0x0029, root owner name, empty RDATA) whose
        // TTL fields encode EDNS versions 0 and 13 respectively.
        let v0 = b"\0\x00\x29\x10\x00\x00\x00\x00\x00\x00\x00";
        let v13 = b"\0\x00\x29\x10\x00\x00\x0D\x00\x00\x00\x00";
        // OPT in the answer section is rejected, carrying its own version.
        let source: A12 = header
            .iter()
            .copied()
            .chain(question.iter().copied())
            .chain(v13.iter().copied())
            .collect();
        assert_matches!(
            Message::view(&source, ..),
            Err(MessageError(Some(13), Kind::MisplacedOptRecord))
        );
        // One question and one authority record: OPT there is also rejected.
        let header = b"\x13\x13\x00\x00\x00\x01\x00\x00\x00\x01\x00\x00";
        let source: A12 = header
            .iter()
            .copied()
            .chain(question.iter().copied())
            .chain(v13.iter().copied())
            .collect();
        assert_matches!(
            Message::view(&source, ..),
            Err(MessageError(Some(13), Kind::MisplacedOptRecord))
        );
        // One question and two additional records, both OPT: the error
        // carries the first record's version (0).
        let header = b"\x13\x13\x00\x00\x00\x01\x00\x00\x00\x00\x00\x02";
        let source: A12 = header
            .iter()
            .copied()
            .chain(question.iter().copied())
            .chain(v0.iter().copied())
            .chain(v13.iter().copied())
            .collect();
        assert_matches!(
            Message::view(&source, ..),
            Err(MessageError(Some(0), Kind::MultipleOptRecords))
        );
    }

    // Root name plus four fixed bytes consumes the whole input.
    #[test]
    fn question() {
        assert_matches!(
            Question::view(b"\0\xFF\xFE\xFF\xFE", 0..5),
            Ok((Question { source: _, offset: 0, len: 5 }, ref x)) if *x == (5..5)
        );
    }

    // Records with an RDATA length of 0 (trailing byte left over) and of 1
    // (trailing byte is the RDATA).
    #[test]
    fn record() {
        assert_matches!(
            Record::view(
                b"\0\xFF\xFE\xFF\xFE\x00\x00\x0E\x10\x00\x00\x00",
                0..12,
            ),
            Ok((
                Record {
                    source: _,
                    offset: 0,
                    name_len: 1,
                    rdata_len: 0
                },
                ref x
            )) if *x == (11..12)
        );
        assert_matches!(
            Record::view(
                b"\0\xFF\xFE\xFF\xFE\x00\x00\x0E\x10\x00\x01\x00",
                0..12,
            ),
            Ok((
                Record {
                    source: _,
                    offset: 0,
                    name_len: 1,
                    rdata_len: 1
                },
                ref x
            )) if *x == (12..12)
        );
    }

    // Uncompressed and compressed names, out-of-range pointers, pointer
    // cycles, and the length limit.
    #[test]
    fn name() {
        use super::{BoundsError, LabelError::Bounds, NameError::*};
        // Plain name ending in the null label.
        assert_matches!(
            Name::view(b"\x05daria\x03daz\x03cat\0", 0..15),
            Ok((Name { source: _, offset: 0, len: 15 }, ref x)) if *x == (15..15)
        );
        // Name ending in a pointer to offset 6; `len` stops after the pointer.
        assert_matches!(
            Name::view(
                b"\x05daria\x03daz\x03cat\0\x08charming\xC0\x06",
                15..26,
            ),
            Ok((
                Name {
                    source: _,
                    offset: 15,
                    len: 11
                },
                ref x
            )) if *x == (26..26)
        );
        // Pointer to a name that itself ends in a pointer.
        assert_matches!(
            Name::view(
                b"\x05daria\x03daz\x03cat\0\x08charming\xC0\x06\xC0\x0F",
                26..28,
            ),
            Ok((
                Name {
                    source: _,
                    offset: 26,
                    len: 2
                },
                ref x
            )) if *x == (28..28)
        );
        // Pointers at or past the end of the source fail bounds checks.
        assert_matches!(
            Name::view(b"\xC0\x03", 0..2),
            Err(Label(Bounds(BoundsError::BackwardsRange(3, 2))))
        );
        assert_matches!(
            Name::view(b"\xC0\x02", 0..2),
            Err(Label(Bounds(BoundsError::ReadOverflow(1, 2, 2))))
        );
        // Pointer into its own second byte, which then reads as a length of
        // 2 with no room left for the label text.
        assert_matches!(
            Name::view(b"\xC0\x01", 0..2),
            Err(Label(Bounds(BoundsError::ReadOverflow(1, 2, 2))))
        );
        // Pointer to a null label just beyond the viewed range.
        assert_matches!(
            Name::view(b"\xC0\x02\0", 0..2),
            Ok((Name { source: _, offset: 0, len: 2 }, ref x)) if *x == (2..2)
        );
        // Two-pointer and self-pointer cycles.
        assert_matches!(Name::view(b"\xC0\x02\xC0\x00", 0..2), Err(PointerCycle));
        assert_matches!(Name::view(b"\xC0\x00", 0..2), Err(PointerCycle));
        // 255 bytes of labels plus the null label is accepted...
        let source = [2; 255].iter().copied().chain(once(0)).collect::<A12>();
        assert_matches!(Name::view(&source, 0..256), Ok((Name { source: _, offset: 0, len: 256 }, ref x)) if *x == (256..256));
        // ...256 bytes of labels is not...
        let source = [3; 256].iter().copied().chain(once(0)).collect::<A12>();
        assert_matches!(Name::view(&source, 0..257), Err(TooLong));
        // ...and labels reached through a pointer count toward the limit.
        let source = [3; 128]
            .iter()
            .copied()
            .chain(once(0))
            .chain([3; 128].iter().copied())
            .chain([0xC0, 0x00].iter().copied())
            .collect::<A12>();
        assert_matches!(Name::view(&source, 129..259), Err(TooLong));
    }

    // Null, text, maximum-length, reserved, and pointer labels, plus
    // truncated inputs.
    #[test]
    #[rustfmt::skip]
    fn label() {
        use super::{BoundsError, LabelError::*};
        assert_matches!(
            Label::view(b"\0", 0..1),
            Ok((Label { source: _, offset: 0, len: 1 }, ref x)) if *x == (1..1)
        );
        assert_matches!(
            Label::view(b"\x01\0", 0..2),
            Ok((Label { source: _, offset: 0, len: 2 }, ref x)) if *x == (2..2)
        );
        assert_matches!(
            Label::view(b"\0\0", 0..2),
            Ok((Label { source: _, offset: 0, len: 1 }, ref x)) if *x == (1..2)
        );
        assert_matches!(
            Label::view(&[63; 64], 0..64),
            Ok((Label { source: _, offset: 0, len: 64 }, ref x)) if *x == (64..64)
        );
        assert_matches!(Label::view(&[64; 65], 0..65), Err(ReservedLength));
        assert_matches!(Label::view(&[191; 192], 191..192), Err(ReservedLength));
        assert_matches!(
            Label::view(&[192, 0], 0..2),
            Ok((Label { source: _, offset: 0, len: 2 }, ref x)) if *x == (2..2)
        );
        assert_matches!(Label::view(b"\x01\0", 0..1), Err(Bounds(BoundsError::ReadOverflow(1, 1, 1))));
        assert_matches!(Label::view(b"\x01", 0..1), Err(Bounds(BoundsError::ReadOverflow(1, 1, 1))));
        assert_matches!(Label::view(b"", 0..0), Err(Bounds(BoundsError::ReadOverflow(1, 0, 0))));
    }

    // Every error variant of the bounds helper, plus the passing cases.
    #[test]
    #[rustfmt::skip]
    fn assert() {
        use super::{assert, BoundsError};
        assert_matches!(assert(b"", &(0..0), 0), Ok(()));
        assert_matches!(assert(b"", &(0..0), 1), Err(BoundsError::ReadOverflow(1, 0, 0)));
        assert_matches!(assert(b"", &(0..1), 0), Err(BoundsError::RangeOverflow(1, 0)));
        assert_matches!(assert(b"", &(0..1), 1), Err(BoundsError::RangeOverflow(1, 0)));
        assert_matches!(assert(b"~", &(0..0), 0), Ok(()));
        assert_matches!(assert(b"~", &(0..0), 1), Err(BoundsError::ReadOverflow(1, 0, 0)));
        assert_matches!(assert(b"~", &(0..1), 0), Ok(()));
        assert_matches!(assert(b"~", &(0..1), 1), Ok(()));
        assert_matches!(assert(b"~", &(0..1), 2), Err(BoundsError::ReadOverflow(2, 0, 1)));
        assert_matches!(assert(b"~", &(1..1), 0), Ok(()));
        assert_matches!(assert(b"~", &(1..1), 1), Err(BoundsError::ReadOverflow(1, 1, 1)));
        assert_matches!(assert(b"~", &(1..0), 0), Err(BoundsError::BackwardsRange(1, 0)));
        assert_matches!(assert(b"<>", &(0..2), 2), Ok(()));
        assert_matches!(assert(b"<>", &(0..1), 2), Err(BoundsError::ReadOverflow(2, 0, 1)));
    }
}
// Micro-benchmarks for the view parsers. `#[bench]` and `extern crate test`
// need a nightly toolchain; the module is compiled only with the `bench`
// feature.
#[cfg(all(test, feature = "bench"))]
mod bench {
    extern crate test;
    use test::Bencher;
    use super::{Message, Name, Question, Record, View};

    // Whole sample message.
    #[bench]
    fn message(bencher: &mut Bencher) {
        let source = include_bytes!("../samples/daria.daz.cat.a.dns");
        let len = source.len();
        bencher.iter(|| Message::view(source, 0..len));
    }

    // Root-name question.
    #[bench]
    fn question(bencher: &mut Bencher) {
        bencher.iter(|| Question::view(b"\0\xFF\xFE\xFF\xFE", 0..5));
    }

    // Root-name record with empty RDATA.
    #[bench]
    fn record(bencher: &mut Bencher) {
        bencher.iter(|| Record::view(b"\0\xFF\xFE\xFF\xFE\x00\x00\x0E\x10\x00\x00", 0..11));
    }

    // Root name only.
    #[bench]
    fn name0(bencher: &mut Bencher) {
        bencher.iter(|| Name::view(b"\0", 0..1));
    }

    // Three uncompressed labels.
    #[bench]
    fn name1(bencher: &mut Bencher) {
        bencher.iter(|| Name::view(b"\x05daria\x03daz\x03cat\0", 0..15));
    }

    // One label, then a pointer back to offset 0.
    #[bench]
    fn name2(bencher: &mut Bencher) {
        bencher.iter(|| Name::view(b"\x05daria\x03daz\x03cat\0\x08charming\xC0\x00", 15..26));
    }

    // One label, then pointer `\xFF\xFF`. The buffer is padded so that the
    // pointed-to labels start at offset 16383 (0x3FFF).
    // NOTE(review): relies on `alloc` being available at the crate root.
    #[bench]
    fn name3(bencher: &mut Bencher) {
        let mut source = alloc::vec![];
        source.extend(b"\x08charming\xFF\xFF");
        source.resize(16383, b'\0');
        source.extend(b"\x03daz\x03cat\0");
        bencher.iter(|| Name::view(&source, 0..11));
    }
}