use crate::core::{Class, Opcode, Rcode, Ttl, Type};
use crate::rdata::view::error::TypedRdataError;
use crate::rdata::view::Rdata;
use crate::rdata::ClassType;
use crate::view::{BorrowedView, ViewToOwned};
use crate::MAX_SUPPORTED_EDNS_VERSION;
use super::iter::{Labels, Questions, Records};
use super::ExtensionError;
use super::{CharacterString, Extension, Header, Label, Message, Name, Question, Record, View};
use byteorder::{ByteOrder, NetworkEndian};
use core::ops::Range;
impl Message<'_> {
pub fn header(&self) -> Header {
Header {
source: self.source,
offset: self.offset,
}
}
pub fn opt(&self) -> Option<&Extension> {
self.opt.as_ref()
}
pub fn rcode(&self) -> Rcode {
let basic_part = self.header().rcode();
let extended_part = self.opt().map_or(0, |x| x.extended());
Rcode::from_basic_extended(basic_part, extended_part).expect("guaranteed by Header impl")
}
pub fn qd(&self) -> Questions<'_> {
let start = self.offset + self.header().len();
let stop = start + self.qd_len;
Questions::new(self.source, start..stop)
}
pub fn an(&self) -> Records<'_> {
let start = self.offset + self.header().len() + self.qd_len;
let stop = start + self.an_len;
Records::new(self.source, start..stop)
}
pub fn ns(&self) -> Records<'_> {
let start = self.offset + self.header().len() + self.qd_len + self.an_len;
let stop = start + self.ns_len;
Records::new(self.source, start..stop)
}
pub fn ar(&self) -> Records<'_> {
let start = self.offset + self.header().len() + self.qd_len + self.an_len + self.ns_len;
let stop = start + self.ar_len;
Records::new(self.source, start..stop)
}
pub fn udp_limit(&self) -> u16 {
self.opt().map_or(512, |x| x.udp().max(512))
}
}
impl Header<'_> {
    /// Returns the ID field, read big-endian from header bytes 0-1.
    pub fn id(&self) -> u16 {
        NetworkEndian::read_u16(&self.source[self.offset..])
    }

    /// Returns the QR bit (bit 7 of header byte 2).
    pub fn qr(&self) -> bool {
        let flags = self.source[self.offset + 2];
        flags & 0b1000_0000 != 0
    }

    /// Returns the OPCODE field (bits 6-3 of header byte 2).
    pub fn opcode(&self) -> Opcode {
        let flags = self.source[self.offset + 2];
        let code = (flags & 0b0111_1000) >> 3;
        Opcode::new(code).unwrap()
    }

    /// Returns the AA bit (bit 2 of header byte 2).
    pub fn aa(&self) -> bool {
        let flags = self.source[self.offset + 2];
        flags & 0b0000_0100 != 0
    }

    /// Returns the TC bit (bit 1 of header byte 2).
    pub fn tc(&self) -> bool {
        let flags = self.source[self.offset + 2];
        flags & 0b0000_0010 != 0
    }

    /// Returns the RD bit (bit 0 of header byte 2).
    pub fn rd(&self) -> bool {
        let flags = self.source[self.offset + 2];
        flags & 0b0000_0001 != 0
    }

    /// Returns the RA bit (bit 7 of header byte 3).
    pub fn ra(&self) -> bool {
        let flags = self.source[self.offset + 3];
        flags & 0b1000_0000 != 0
    }

    /// Returns the 4-bit header RCODE (low nibble of header byte 3); callers combine it with
    /// any extended bits to form the full response code.
    fn rcode(&self) -> u8 {
        let flags = self.source[self.offset + 3];
        flags & 0b0000_1111
    }

    /// Returns QDCOUNT, read big-endian from header bytes 4-5.
    pub fn qdcount(&self) -> u16 {
        NetworkEndian::read_u16(&self.source[self.offset + 4..])
    }

    /// Returns ANCOUNT, read big-endian from header bytes 6-7.
    pub fn ancount(&self) -> u16 {
        NetworkEndian::read_u16(&self.source[self.offset + 6..])
    }

    /// Returns NSCOUNT, read big-endian from header bytes 8-9.
    pub fn nscount(&self) -> u16 {
        NetworkEndian::read_u16(&self.source[self.offset + 8..])
    }

    /// Returns ARCOUNT, read big-endian from header bytes 10-11.
    pub fn arcount(&self) -> u16 {
        NetworkEndian::read_u16(&self.source[self.offset + 10..])
    }
}
impl Question<'_> {
    /// Returns the QNAME, i.e. every byte of the question except the trailing
    /// four bytes that hold QTYPE and QCLASS.
    pub fn qname(&self) -> Name {
        let name_len = self.len - 4;
        Name {
            len: name_len,
            offset: self.offset,
            source: self.source,
        }
    }

    /// Returns the QTYPE field, stored big-endian four bytes before the end of the question.
    pub fn qtype(&self) -> Type {
        let qtype_at = self.offset + (self.len - 4);
        let raw = NetworkEndian::read_u16(&self.source[qtype_at..]);
        Type::new(raw)
    }

    /// Returns the QCLASS field, stored big-endian in the final two bytes of the question.
    pub fn qclass(&self) -> Class {
        let qclass_at = self.offset + (self.len - 2);
        let raw = NetworkEndian::read_u16(&self.source[qclass_at..]);
        Class::new(raw)
    }
}
impl Record<'_> {
    /// Returns the owner name, which occupies the first `name_len` bytes of the record.
    pub fn name(&self) -> Name {
        let (source, offset, len) = (self.source, self.offset, self.name_len);
        Name { source, offset, len }
    }

    /// Returns the TYPE field: two big-endian bytes right after the owner name.
    pub fn r#type(&self) -> Type {
        let type_at = self.offset + self.name_len;
        Type::new(NetworkEndian::read_u16(&self.source[type_at..]))
    }

    /// Returns the CLASS field: two big-endian bytes after TYPE.
    pub fn class(&self) -> Class {
        let class_at = self.offset + self.name_len + 2;
        Class::new(NetworkEndian::read_u16(&self.source[class_at..]))
    }

    /// Returns the TTL field: four big-endian bytes after CLASS.
    pub fn ttl(&self) -> Ttl {
        let ttl_at = self.offset + self.name_len + 4;
        Ttl::new(NetworkEndian::read_u32(&self.source[ttl_at..]))
    }

    /// Returns the RDLENGTH field: two big-endian bytes after TTL.
    pub fn rdlength(&self) -> u16 {
        let rdlength_at = self.offset + self.name_len + 8;
        NetworkEndian::read_u16(&self.source[rdlength_at..])
    }

    /// Returns an untyped view of this record's RDATA.
    pub fn rdata(&self) -> Rdata {
        Rdata::from(self)
    }

    /// Parses the RDATA as the borrowed view type of `CT`.
    ///
    /// # Errors
    /// - [`TypedRdataError::WrongClassOrType`] if this record's class or type is not `CT`'s.
    /// - [`TypedRdataError::View`] if the RDATA bytes fail to parse as `CT::View`.
    pub fn rdata_as_view<CT: ClassType>(&self) -> Result<CT::View<'_>, TypedRdataError> {
        if self.class() != CT::CLASS || self.r#type() != CT::TYPE {
            return Err(TypedRdataError::WrongClassOrType);
        }
        let (typed, _) = CT::View::view(self.source, self.rdata_range())
            .map_err(|e| TypedRdataError::View(e.into()))?;
        Ok(typed)
    }

    /// Parses the RDATA as `CT`'s view type and converts it into the owned type `Owned`.
    ///
    /// # Errors
    /// Same as [`Record::rdata_as_view`].
    pub fn rdata_as_owned<'s, CT: ClassType<View<'s>: ViewToOwned<Owned>>, Owned>(
        &'s self,
    ) -> Result<Owned, TypedRdataError> {
        if self.class() != CT::CLASS || self.r#type() != CT::TYPE {
            return Err(TypedRdataError::WrongClassOrType);
        }
        let (typed, _) = CT::View::view(self.source, self.rdata_range())
            .map_err(|e| TypedRdataError::View(e.into()))?;
        Ok(typed.to_owned())
    }

    /// Returns the byte range of the RDATA within `source`.
    pub fn rdata_range(&self) -> Range<usize> {
        // After the owner name come TYPE(2) + CLASS(2) + TTL(4) + RDLENGTH(2) = 10 bytes.
        let rdata_start = self.offset + self.name_len + 10;
        rdata_start..rdata_start + self.rdata_len
    }
}
impl<'s> Name<'s> {
    /// Builds a name spanning `source[range]` without validating the bytes.
    ///
    /// The caller must ensure `range` really covers a well-formed name; the
    /// range's end must not be before its start.
    pub(crate) fn new_unchecked(source: &'s [u8], range: Range<usize>) -> Self {
        let Range { start, end } = range;
        Self {
            source,
            offset: start,
            len: end - start,
        }
    }
}
impl Name<'_> {
    /// Returns `true` if this name is the root name, i.e. its first label is the null label.
    pub fn is_root(&self) -> bool {
        self.labels_with_null()
            .next()
            .expect("Labels guarantees that this is infallible")
            .is_null()
    }

    /// Returns `true` if `self` is equal to `other` or lies somewhere beneath it.
    ///
    /// Label sequences are compared from the root end via the labels' `PartialEq`.
    pub fn is_subdomain_of(&self, other: &Name<'_>) -> bool {
        self.subdomain_distance(other).is_some()
    }

    /// Returns how many labels `self` has beyond `other` if `self` is a subdomain of `other`
    /// (`Some(0)` when they are equal), or `None` otherwise.
    pub fn subdomain_distance(&self, other: &Name<'_>) -> Option<usize> {
        // Count each chain once; following compression pointers is not free, so avoid the
        // repeated traversals of counting, comparing, and counting again.
        let self_count = self.labels_with_null().count();
        let other_count = other.labels_with_null().count();
        // A name with fewer labels than `other` cannot be beneath it.
        let distance = self_count.checked_sub(other_count)?;
        // Drop the extra leading labels and require the remaining suffix to match exactly.
        self.labels_with_null()
            .skip(distance)
            .eq(other.labels_with_null())
            .then_some(distance)
    }
}
impl<'s> Name<'s> {
    /// Iterates over the labels of this name, including the terminating null label.
    pub fn labels_with_null(&self) -> Labels<'s> {
        let start = self.offset;
        let stop = self.offset + self.len;
        Labels::new(self.source, start..stop, true)
    }

    /// Iterates over the labels of this name, excluding the terminating null label.
    pub fn labels_not_null(&self) -> Labels<'s> {
        let start = self.offset;
        let stop = self.offset + self.len;
        Labels::new(self.source, start..stop, false)
    }

    /// Returns this name with its leftmost label removed, or `None` for the root name.
    pub fn parent(&self) -> Option<Name<'s>> {
        let second_label = self.labels_with_null().nth(1)?;
        let (result, _) =
            Name::view(self.source, second_label.offset..).expect("guaranteed by Self invariants");
        Some(result)
    }

    /// Keeps only the `labels_not_null` non-null labels closest to the root, dropping labels
    /// from the left. A name that already has that many labels or fewer is returned as a clone.
    pub fn truncate(&self, labels_not_null: usize) -> Name<'s> {
        let total_labels = self.labels_not_null().count();
        let skip_labels = total_labels.saturating_sub(labels_not_null);
        let mut result = self.clone();
        for _ in 0..skip_labels {
            result = result.parent().expect("guaranteed by Self invariants");
        }
        result
    }
}
impl<'s> Label<'s> {
    /// Returns the content bytes of this label (without the length octet), following
    /// compression pointers first.
    ///
    /// Returns `None` if this label, or any label reached through a pointer, fails to
    /// parse, or if the pointers form a loop.
    pub fn value(&self) -> Option<&'s [u8]> {
        let result = self.resolve()?;
        debug_assert!(result.source[result.offset] < 0x40);
        Some(&result.source[result.offset..][1..result.len])
    }

    /// Follows compression pointers starting at this label until reaching a non-pointer label.
    ///
    /// Returns `None` if any hop fails to parse or the pointers loop back on themselves.
    fn resolve(&self) -> Option<Self> {
        let mut label = Label::view(self.source, self.offset..).ok()?.0;
        // An acyclic pointer chain visits each offset of `source` at most once, so it can
        // never need more hops than there are bytes. Exceeding that bound means the pointers
        // form a cycle (e.g. `\xC0\x02\xC0\x00`); bail out instead of spinning forever on
        // untrusted input.
        for _ in 0..=self.source.len() {
            match label.pointer() {
                Some(offset) => label = Label::view(self.source, usize::from(offset)..).ok()?.0,
                None => return Some(label),
            }
        }
        None
    }

    /// Returns `true` if this is the one-byte null (root) label.
    pub fn is_null(&self) -> bool {
        self.len == 1
    }

    /// Returns the target offset if this label is a compression pointer (first byte >= 0xC0).
    pub fn pointer(&self) -> Option<u16> {
        if self.source[self.offset] < 0xC0 {
            return None;
        }
        // Both top bits are set here, so XOR with 0xC000 clears them and leaves the 14-bit offset.
        Some(NetworkEndian::read_u16(&self.source[self.offset..]) ^ 0xC000)
    }
}
impl<'s> Extension<'s> {
    /// Wraps an OPT record as an EDNS extension.
    ///
    /// # Errors
    /// - [`ExtensionError::BadName`] if the record's owner name is not the root.
    /// - [`ExtensionError::UnimplementedVersion`] if its version exceeds
    ///   `MAX_SUPPORTED_EDNS_VERSION`.
    pub fn wrap(inner: Record<'s>) -> Result<Self, ExtensionError> {
        if !inner.name().is_root() {
            return Err(ExtensionError::BadName);
        }
        let extension = Self { inner };
        if extension.version() <= MAX_SUPPORTED_EDNS_VERSION {
            Ok(extension)
        } else {
            Err(ExtensionError::UnimplementedVersion)
        }
    }

    /// Returns the advertised UDP payload size, carried in the record's CLASS field.
    pub fn udp(&self) -> u16 {
        let class = self.inner.class();
        class.value()
    }

    /// Returns the extended RCODE bits: the top 8 bits of the TTL field.
    pub fn extended(&self) -> u8 {
        let ttl = self.inner.ttl().value();
        // `as u8` keeps only the low 8 bits of the shifted value, which is exactly the field.
        (ttl >> 24) as u8
    }

    /// Returns the EDNS version: bits 23-16 of the TTL field.
    pub fn version(&self) -> u8 {
        let ttl = self.inner.ttl().value();
        (ttl >> 16) as u8
    }

    /// Returns the DO flag: bit 15 of the TTL field.
    pub fn r#do(&self) -> bool {
        let ttl = self.inner.ttl().value();
        ttl & 0x8000 != 0
    }
}
impl<'s> CharacterString<'s> {
    /// Returns the string's content, skipping the leading length octet.
    pub fn value(&self) -> &[u8] {
        let content_start = self.offset + 1;
        let content_end = self.offset + self.len;
        &self.source[content_start..content_end]
    }
}
#[cfg(test)]
mod test {
    use super::super::{Label, Name, View};
    declare_any_error!(AnyError);
    #[test]
    #[rustfmt::skip]
    fn name() -> Result<(), AnyError> {
        let source = b"\0";
        let (name, _) = Name::view(source, 0..1)?;
        assert!(name.is_root());
        let source = b"\0\xC0\x00";
        let (name, _) = Name::view(source, 1..3)?;
        assert!(name.is_root());
        let daria_daz_cat = Name::view(b"\x03cat\0\x05daria\x03daz\xC0\x00", 5..)?.0;
        let daz_cat = Name::view(b"\x03daz\x03cat\0", ..)?.0;
        assert!(daria_daz_cat.is_subdomain_of(&daz_cat));

        // Walk towards the root with `expect` (not `?` inside a closure whose result is
        // discarded) so that a premature `None` fails the test instead of silently skipping
        // the remaining assertions.
        let parent = daria_daz_cat.parent().expect("daria.daz.cat. has a parent");
        assert_eq!(parent, "daz.cat.");
        let grandparent = parent.parent().expect("daz.cat. has a parent");
        assert_eq!(grandparent, "cat.");
        let root = grandparent.parent().expect("cat. has a parent");
        assert_eq!(root, ".");
        assert_eq!(root.parent(), None);
        let mut name = daria_daz_cat.clone();
        name = name.parent().expect("daria.daz.cat. has a parent");
        name = name.parent().expect("daz.cat. has a parent");
        name = name.parent().expect("cat. has a parent");
        assert_eq!(name, ".");

        assert_eq!(daria_daz_cat.truncate(4), "daria.daz.cat.");
        assert_eq!(daria_daz_cat.truncate(3), "daria.daz.cat.");
        assert_eq!(daria_daz_cat.truncate(2), "daz.cat.");
        assert_eq!(daria_daz_cat.truncate(1), "cat.");
        assert_eq!(daria_daz_cat.truncate(0), ".");
        let mut name = daria_daz_cat.truncate(4);
        name = name.truncate(3);
        name = name.truncate(2);
        name = name.truncate(1);
        name = name.truncate(0);
        assert_eq!(name, ".");

        let daz_cat = daria_daz_cat.truncate(2);
        assert_eq!(Name::view(b"\x03em0\x05daria\x03daz\x03cat\0", ..)?.0.subdomain_distance(&daz_cat), Some(2));
        assert_eq!(Name::view(b"\x05daria\x03daz\x03cat\0", ..)?.0.subdomain_distance(&daz_cat), Some(1));
        assert_eq!(Name::view(b"\x03daz\x03cat\0", ..)?.0.subdomain_distance(&daz_cat), Some(0));
        assert_eq!(Name::view(b"\x03cat\0", ..)?.0.subdomain_distance(&daz_cat), None);
        assert_eq!(Name::view(b"\0", ..)?.0.subdomain_distance(&daz_cat), None);
        assert_eq!(Name::view(b"\x03dog\0", ..)?.0.subdomain_distance(&daz_cat), None);
        Ok(())
    }
    #[test]
    #[rustfmt::skip]
    fn label() -> Result<(), AnyError> {
        let source = b"\0";
        let (label, _) = Label::view(source, 0..1)?;
        assert!(label.is_null());
        let source = b"\0\xC0\x00";
        let (label, _) = Label::view(source, 1..3)?;
        assert!(!label.is_null());
        assert!(Label::view(b"\0", ..)?.0.is_null());
        assert_eq!(Label::view(b"\0", ..)?.0.pointer(), None);
        assert_eq!(Label::view(b"\0", ..)?.0.value(), Some(&b""[..]));
        assert!(!Label::view(b"\x03cat", ..)?.0.is_null());
        assert_eq!(Label::view(b"\x03cat", ..)?.0.pointer(), None);
        assert_eq!(Label::view(b"\x03cat", ..)?.0.value(), Some(&b"cat"[..]));
        assert!(!Label::view(b"\xC0\x02\x03cat", ..)?.0.is_null());
        assert_eq!(Label::view(b"\xC0\x02\x03cat", ..)?.0.pointer(), Some(2));
        assert_eq!(Label::view(b"\xC0\x02\x03cat", ..)?.0.value(), Some(&b"cat"[..]));
        assert!(!Label::view(b"\xC0\x02\xC0\x04\x03cat", ..)?.0.is_null());
        assert_eq!(Label::view(b"\xC0\x02\xC0\x04\x03cat", ..)?.0.pointer(), Some(2));
        assert_eq!(Label::view(b"\xC0\x02\xC0\x04\x03cat", ..)?.0.value(), Some(&b"cat"[..]));
        Ok(())
    }
}