use core::{
borrow::{Borrow, BorrowMut},
cmp::Ordering,
fmt,
hash::{Hash, Hasher},
ops::{Deref, DerefMut},
str::FromStr,
};
use crate::{
new::base::{
build::BuildInMessage,
parse::{ParseMessageBytes, SplitMessageBytes},
wire::{
AsBytes, BuildBytes, ParseBytes, ParseError, SplitBytes,
TruncationError,
},
},
utils::dst::{UnsizedCopy, UnsizedCopyFrom},
};
use super::{
CanonicalName, Label, LabelBuf, LabelIter, LabelParseError,
NameCompressor,
};
/// An absolute domain name in the uncompressed DNS wire format.
///
/// The bytes are a sequence of labels (a length byte of 1..=63 followed by
/// that many bytes) terminated by the root label (a single zero byte), at
/// most [`Name::MAX_SIZE`] bytes in total.
#[derive(AsBytes, BuildBytes, UnsizedCopy)]
#[repr(transparent)]
pub struct Name([u8]);

impl Name {
    /// The largest size a name may have in the wire format, in bytes.
    pub const MAX_SIZE: usize = 255;

    /// The root name, which consists solely of the root label.
    // SAFETY: a lone zero byte is the root label on its own, a valid name.
    pub const ROOT: &'static Self =
        unsafe { Self::from_bytes_unchecked(&[0u8]) };
}
impl Name {
    /// Reinterprets a byte slice as a [`Name`] without validating it.
    ///
    /// # Safety
    ///
    /// `bytes` must be a complete name in the uncompressed wire format:
    /// labels of 1..=63 content bytes, each prefixed by its length,
    /// terminated by a zero byte, and at most 255 bytes in total.
    pub const unsafe fn from_bytes_unchecked(bytes: &[u8]) -> &Self {
        // `Name` is `repr(transparent)` over `[u8]`, so both reference
        // types have the same layout (pointer plus length metadata).
        core::mem::transmute::<&[u8], &Self>(bytes)
    }

    /// Mutable counterpart of [`Name::from_bytes_unchecked`].
    ///
    /// # Safety
    ///
    /// Same requirements as [`Name::from_bytes_unchecked`].
    pub unsafe fn from_bytes_unchecked_mut(bytes: &mut [u8]) -> &mut Self {
        core::mem::transmute::<&mut [u8], &mut Self>(bytes)
    }
}
impl Name {
    /// The size of this name in the wire format, in bytes.
    // A valid name always holds at least the root label, so it is never
    // empty and an `is_empty()` method would be meaningless.
    #[allow(clippy::len_without_is_empty)]
    pub const fn len(&self) -> usize {
        self.as_bytes().len()
    }

    /// Whether this is the root name (only the root label).
    pub const fn is_root(&self) -> bool {
        // Any valid one-byte name must be the lone root label.
        self.len() == 1
    }

    /// The raw wire-format bytes of this name.
    pub const fn as_bytes(&self) -> &[u8] {
        &self.0
    }

    /// Iterates over the labels of this name, ending with the root label.
    pub const fn labels(&self) -> LabelIter<'_> {
        let bytes = self.as_bytes();
        // SAFETY: a `Name` always holds a valid sequence of labels.
        unsafe { LabelIter::new_unchecked(bytes) }
    }
}
impl CanonicalName for Name {
    /// Compares the raw wire-format bytes of two names, case-sensitively.
    fn cmp_composed(&self, other: &Self) -> Ordering {
        Ord::cmp(self.as_bytes(), other.as_bytes())
    }

    /// Compares the wire-format bytes of two names after ASCII-lowercasing.
    fn cmp_lowercase_composed(&self, other: &Self) -> Ordering {
        fn lowered(name: &Name) -> impl Iterator<Item = u8> + '_ {
            name.as_bytes().iter().map(u8::to_ascii_lowercase)
        }

        lowered(self).cmp(lowered(other))
    }
}
impl<'a> ParseBytes<'a> for &'a Name {
fn parse_bytes(bytes: &'a [u8]) -> Result<Self, ParseError> {
match Self::split_bytes(bytes) {
Ok((this, &[])) => Ok(this),
_ => Err(ParseError),
}
}
}
impl<'a> SplitBytes<'a> for &'a Name {
    /// Splits an uncompressed name off the front of `bytes`.
    ///
    /// Walks label by label until the root label is found. Fails on a
    /// truncated label, a length byte outside 1..=63 (including compression
    /// pointers), or a name longer than [`Name::MAX_SIZE`] bytes.
    fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> {
        let mut cursor = 0usize;

        // The root label must start before offset 255 so the whole name
        // fits within `Name::MAX_SIZE` bytes.
        while cursor < Name::MAX_SIZE {
            match bytes.get(cursor).copied() {
                Some(0) => {
                    let (name, rest) = bytes.split_at(cursor + 1);
                    // SAFETY: every label up to here was validated and the
                    // name is at most 255 bytes long, ending in the root.
                    let name = unsafe { Name::from_bytes_unchecked(name) };
                    return Ok((name, rest));
                }
                Some(len @ 1..=63)
                    if bytes.len() - cursor - 1 >= usize::from(len) =>
                {
                    cursor += 1 + usize::from(len);
                }
                _ => return Err(ParseError),
            }
        }

        Err(ParseError)
    }
}
impl BuildInMessage for Name {
    /// Writes this name into `contents` at `start`, compressing if possible.
    ///
    /// The compressor is consulted with everything written so far. If it
    /// returns some leading bytes plus an address, those bytes are written
    /// followed by a two-byte compression pointer. Otherwise the full name
    /// is copied verbatim. Returns the offset just past the written name,
    /// or [`TruncationError`] if `contents` is too small.
    fn build_in_message(
        &self,
        contents: &mut [u8],
        start: usize,
        compressor: &mut NameCompressor,
    ) -> Result<usize, TruncationError> {
        let Some((rest, addr)) =
            compressor.compress_name(&contents[..start], self)
        else {
            // No usable earlier occurrence: write the name uncompressed.
            let end = start + self.len();
            let target =
                contents.get_mut(start..end).ok_or(TruncationError)?;
            target.copy_from_slice(self.as_bytes());
            return Ok(end);
        };

        let end = start + rest.len() + 2;
        let target = contents.get_mut(start..end).ok_or(TruncationError)?;
        let (head, tail) = target.split_at_mut(rest.len());
        head.copy_from_slice(rest);
        // 0xC000 marks a pointer. The extra 12 mirrors the `- 12` applied
        // when following pointers during parsing in this file, presumably
        // for the message header that precedes `contents`.
        tail.copy_from_slice(&(addr + 0xC00C).to_be_bytes());
        Ok(end)
    }
}
#[cfg(feature = "alloc")]
impl Clone for alloc::boxed::Box<Name> {
    /// Copies the boxed name into a fresh allocation.
    fn clone(&self) -> Self {
        let name: &Name = self;
        name.unsized_copy_into()
    }
}
impl PartialEq for Name {
    /// Two names are equal if their wire bytes match, ignoring ASCII case.
    fn eq(&self, other: &Self) -> bool {
        self.as_bytes().eq_ignore_ascii_case(other.as_bytes())
    }
}

impl Eq for Name {}

impl PartialOrd for Name {
    /// Delegates to the total order defined by [`Ord`].
    fn partial_cmp(&self, that: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, that))
    }
}
impl Ord for Name {
    /// Orders names label by label, starting from the root.
    ///
    /// Bytes are compared ASCII-case-insensitively, consistent with
    /// [`PartialEq`]. The differing labels are compared with `Label`'s own
    /// `Ord`, presumably a case-insensitive bytewise comparison giving the
    /// RFC 4034 canonical order (TODO confirm). A name that is a
    /// label-aligned suffix of another (e.g. `b.` vs `a.b.`) sorts first.
    fn cmp(&self, that: &Self) -> Ordering {
        // Count how many trailing bytes the two names share.
        let mismatch = core::iter::zip(
            self.as_bytes().iter().rev().map(u8::to_ascii_lowercase),
            that.as_bytes().iter().rev().map(u8::to_ascii_lowercase),
        )
        .position(|(a, b)| a != b);

        let suffix_len = match mismatch {
            Some(suffix_len) => suffix_len,
            // Same length and all bytes match: the names are equal.
            None if self.len() == that.len() => return Ordering::Equal,
            // The shorter name's bytes are a byte-suffix of the longer one.
            // That does not imply a label-aligned suffix: label content may
            // contain bytes that look like length bytes (e.g. `b.` vs
            // `a\001b.`). So still walk the labels to find where they differ.
            None => self.len().min(that.len()),
        };

        let (mut lhs, mut rhs) = (self.labels(), that.labels());
        // SAFETY: every valid name contains at least the root label.
        let mut prev = unsafe {
            (lhs.next().unwrap_unchecked(), rhs.next().unwrap_unchecked())
        };

        // Advance whichever iterator has more bytes left until both stand
        // at the same offset inside the shared suffix. The labels just
        // before that point are the rightmost pair that can differ.
        loop {
            let (llen, rlen) = (lhs.remaining().len(), rhs.remaining().len());
            if llen == rlen && llen <= suffix_len {
                // If those labels are equal, the shorter name is a
                // label-aligned suffix of the longer one and sorts first.
                break prev
                    .0
                    .cmp(prev.1)
                    .then_with(|| self.len().cmp(&that.len()));
            } else if llen > rlen {
                // SAFETY: `llen > rlen >= 0`, so `lhs` has a label left.
                prev.0 = unsafe { lhs.next().unwrap_unchecked() };
            } else {
                // SAFETY: either `rlen > llen` or `rlen == llen > suffix_len`;
                // both imply `rlen > 0`, so `rhs` has a label left.
                prev.1 = unsafe { rhs.next().unwrap_unchecked() };
            }
        }
    }
}
impl Hash for Name {
    /// Hashes the ASCII-lowercased wire bytes, consistent with [`PartialEq`].
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.as_bytes()
            .iter()
            .map(u8::to_ascii_lowercase)
            .for_each(|lowered| state.write_u8(lowered));
    }
}
impl fmt::Display for Name {
    /// Formats the name in the zonefile presentation format.
    ///
    /// Labels are joined with `.`. [`Name::labels`] ends with the root
    /// label. Since non-root names are expected to print as `example.com.`,
    /// that label presumably renders as an empty string (TODO confirm
    /// against `Label`'s `Display`). On its own, the root name would then
    /// print as nothing at all, so it is written as `.` explicitly, which
    /// is the conventional presentation form.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.is_root() {
            return f.write_str(".");
        }

        let mut first = true;
        self.labels().try_for_each(|label| {
            if !first {
                f.write_str(".")?;
            } else {
                first = false;
            }

            label.fmt(f)
        })
    }
}
impl fmt::Debug for Name {
    /// Wraps the presentation-format rendering in `Name(...)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("Name({})", self))
    }
}
#[cfg(feature = "serde")]
impl serde::Serialize for Name {
    /// Serializes the name as a `Name` newtype struct.
    ///
    /// Human-readable formats get the presentation-format string. Other
    /// formats get the uncompressed wire bytes, serialized with
    /// `serialize_bytes`. A plain `&[u8]` would serialize as a *sequence*
    /// of integers, which the `NameBuf` deserializer (which requests bytes)
    /// cannot read back in formats such as CBOR or MessagePack.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use std::string::ToString;

        // Serializes a byte slice as a byte string rather than a sequence.
        struct WireBytes<'b>(&'b [u8]);

        impl serde::Serialize for WireBytes<'_> {
            fn serialize<Ser>(
                &self,
                serializer: Ser,
            ) -> Result<Ser::Ok, Ser::Error>
            where
                Ser: serde::Serializer,
            {
                serializer.serialize_bytes(self.0)
            }
        }

        if serializer.is_human_readable() {
            serializer.serialize_newtype_struct("Name", &self.to_string())
        } else {
            serializer
                .serialize_newtype_struct("Name", &WireBytes(self.as_bytes()))
        }
    }
}
/// An owned domain name stored inline in a fixed 255-byte buffer.
///
/// Only the first `size` bytes of `buffer` are meaningful; they are
/// exposed as a [`Name`] through `Deref`.
// NOTE(review): `repr(C)` pins the field order (`size` first); presumably
// something relies on this layout — TODO confirm before reordering.
#[derive(Clone)]
#[repr(C)] pub struct NameBuf {
    // Number of bytes of `buffer` in use.
    size: u8,
    // Storage for the wire-format name; bytes past `size` are unused.
    buffer: [u8; 255],
}
impl NameBuf {
    /// A buffer with no bytes in use.
    ///
    /// This is not a valid name (not even the root); it is only a starting
    /// point for appending labels.
    const fn empty() -> Self {
        Self {
            buffer: [0; 255],
            size: 0,
        }
    }

    /// Copies a [`Name`] into a new inline buffer.
    pub fn copy_from(name: &Name) -> Self {
        let bytes = name.as_bytes();
        let mut this = Self::empty();
        // Names are at most 255 bytes, so this fits and `as u8` is lossless.
        this.buffer[..bytes.len()].copy_from_slice(bytes);
        this.size = bytes.len() as u8;
        this
    }
}
impl UnsizedCopyFrom for NameBuf {
type Source = Name;
fn unsized_copy_from(value: &Self::Source) -> Self {
Self::copy_from(value)
}
}
impl<'a> SplitMessageBytes<'a> for NameBuf {
    /// Parses a possibly-compressed name at `start` in a message body.
    ///
    /// Returns the decompressed name and the offset just past the name's
    /// own bytes (not past any pointer targets). Each compression pointer
    /// must point strictly before the previous segment, which rules out
    /// pointer loops.
    fn split_message_bytes(
        contents: &'a [u8],
        start: usize,
    ) -> Result<(Self, usize), ParseError> {
        let mut name = Self::empty();

        let first = contents.get(start..).ok_or(ParseError)?;
        let (mut next, rest) = parse_segment(first, &mut name)?;
        let end = contents.len() - rest.len();

        let mut floor = start;
        while let Some(ptr) = next {
            // Pointers count from the message start; `contents` begins 12
            // bytes later.
            let target = usize::from(ptr).checked_sub(12).ok_or(ParseError)?;
            if target >= floor {
                return Err(ParseError);
            }
            let segment = contents.get(target..).ok_or(ParseError)?;
            next = parse_segment(segment, &mut name)?.0;
            floor = target;
        }

        Ok((name, end))
    }
}
impl<'a> ParseMessageBytes<'a> for NameBuf {
fn parse_message_bytes(
contents: &'a [u8],
start: usize,
) -> Result<Self, ParseError> {
let mut buffer = Self::empty();
let bytes = contents.get(start..).ok_or(ParseError)?;
let (mut pointer, rest) = parse_segment(bytes, &mut buffer)?;
if !rest.is_empty() {
return Err(ParseError);
}
let mut old_start = start;
while let Some(start) = pointer.map(usize::from) {
let start = start.checked_sub(12).ok_or(ParseError)?;
if start >= old_start {
return Err(ParseError);
}
let bytes = contents.get(start..).ok_or(ParseError)?;
(pointer, _) = parse_segment(bytes, &mut buffer)?;
old_start = start;
continue;
}
Ok(buffer)
}
}
/// Parses one contiguous run of labels, appending them to `buffer`.
///
/// Stops at the root label (returning `None`) or at a compression pointer
/// (returning its 14-bit offset). Also returns the bytes after the segment.
/// Fails on truncation, on a name that would exceed 255 bytes, or on a
/// leading byte in 0x40..=0xBF.
fn parse_segment<'a>(
    mut bytes: &'a [u8],
    buffer: &mut NameBuf,
) -> Result<(Option<u16>, &'a [u8]), ParseError> {
    loop {
        let Some((&head, tail)) = bytes.split_first() else {
            return Err(ParseError);
        };

        match head {
            // The root label: this segment (and the name) is complete.
            0 => {
                buffer.append_bytes(&[0u8]);
                return Ok((None, tail));
            }
            // An ordinary label with `head` content bytes.
            1..=63 => {
                let label_size = 1 + usize::from(head);
                // Need the whole label, plus room for it and the root byte.
                if bytes.len() < label_size || 255 - buffer.size < 2 + head {
                    return Err(ParseError);
                }
                let (label, rest) = bytes.split_at(label_size);
                buffer.append_bytes(label);
                bytes = rest;
            }
            // A two-byte compression pointer.
            0xC0..=0xFF => {
                let Some((&lo, rest)) = tail.split_first() else {
                    return Err(ParseError);
                };
                let pointer = u16::from_be_bytes([head, lo]);
                return Ok((Some(pointer & 0x3FFF), rest));
            }
            // Any other leading byte (0x40..=0xBF) is rejected.
            _ => return Err(ParseError),
        }
    }
}
impl<'a> SplitBytes<'a> for NameBuf {
    /// Splits an uncompressed name off the front of `bytes` and copies it.
    fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> {
        let (name, rest) = <&Name>::split_bytes(bytes)?;
        Ok((Self::copy_from(name), rest))
    }
}
impl<'a> ParseBytes<'a> for NameBuf {
    /// Parses an uncompressed name filling all of `bytes` and copies it.
    fn parse_bytes(bytes: &'a [u8]) -> Result<Self, ParseError> {
        let name = <&Name>::parse_bytes(bytes)?;
        Ok(Self::copy_from(name))
    }
}
impl BuildBytes for NameBuf {
fn build_bytes<'b>(
&self,
bytes: &'b mut [u8],
) -> Result<&'b mut [u8], TruncationError> {
(**self).build_bytes(bytes)
}
fn built_bytes_size(&self) -> usize {
(**self).built_bytes_size()
}
}
impl NameBuf {
    /// Appends raw bytes after the bytes in use.
    ///
    /// Panics if they do not fit in the 255-byte buffer; callers check the
    /// remaining space first.
    fn append_bytes(&mut self, bytes: &[u8]) {
        let from = usize::from(self.size);
        let to = from + bytes.len();
        self.buffer[from..to].copy_from_slice(bytes);
        self.size += bytes.len() as u8;
    }

    /// Appends a label in its wire form (length byte plus contents).
    fn append_label(&mut self, label: &Label) {
        let wire = label.as_bytes();
        self.append_bytes(wire);
    }
}
impl Deref for NameBuf {
    type Target = Name;

    /// Views the bytes in use as a [`Name`].
    fn deref(&self) -> &Self::Target {
        let used = &self.buffer[..usize::from(self.size)];
        // SAFETY: outside of in-progress parsing, the bytes in use form a
        // valid wire-format name.
        unsafe { Name::from_bytes_unchecked(used) }
    }
}

impl DerefMut for NameBuf {
    /// Views the bytes in use as a mutable [`Name`].
    fn deref_mut(&mut self) -> &mut Self::Target {
        let used = &mut self.buffer[..usize::from(self.size)];
        // SAFETY: as for `deref`.
        unsafe { Name::from_bytes_unchecked_mut(used) }
    }
}
impl Borrow<Name> for NameBuf {
fn borrow(&self) -> &Name {
self
}
}
impl BorrowMut<Name> for NameBuf {
fn borrow_mut(&mut self) -> &mut Name {
self
}
}
impl AsRef<Name> for NameBuf {
fn as_ref(&self) -> &Name {
self
}
}
impl AsMut<Name> for NameBuf {
fn as_mut(&mut self) -> &mut Name {
self
}
}
impl PartialEq for NameBuf {
fn eq(&self, that: &Self) -> bool {
**self == **that
}
}
impl Eq for NameBuf {}
impl PartialOrd for NameBuf {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for NameBuf {
fn cmp(&self, other: &Self) -> Ordering {
(**self).cmp(&**other)
}
}
impl Hash for NameBuf {
fn hash<H: Hasher>(&self, state: &mut H) {
(**self).hash(state)
}
}
impl fmt::Display for NameBuf {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
(**self).fmt(f)
}
}
impl fmt::Debug for NameBuf {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
(**self).fmt(f)
}
}
impl NameBuf {
pub fn parse_str(mut s: &[u8]) -> Result<(Self, &[u8]), NameParseError> {
let mut this = Self::empty();
loop {
let (label, rest) = LabelBuf::parse_str(s)?;
if 255 - this.size < 1 + label.as_bytes().len() as u8 {
return Err(NameParseError::Overlong);
}
this.append_label(&label);
match *rest {
[b' ' | b'\n' | b'\r' | b'\t', ..] | [] => {
s = rest;
break;
}
[b'.', ref rest @ ..] => s = rest,
_ => return Err(NameParseError::InvalidChar),
}
}
this.append_label(Label::ROOT);
Ok((this, s))
}
}
impl FromStr for NameBuf {
    type Err = NameParseError;

    /// Parses a whole string as a presentation-format name.
    ///
    /// Any leftover input after the name (including trailing whitespace) is
    /// reported as [`NameParseError::InvalidChar`].
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let (name, leftover) = Self::parse_str(s.as_bytes())?;
        if leftover.is_empty() {
            Ok(name)
        } else {
            Err(NameParseError::InvalidChar)
        }
    }
}
#[cfg(feature = "serde")]
impl serde::Serialize for NameBuf {
    /// Serializes exactly like the contained [`Name`].
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        <Name as serde::Serialize>::serialize(self, serializer)
    }
}
#[cfg(feature = "serde")]
impl<'a> serde::Deserialize<'a> for NameBuf {
    /// Deserializes a name from a `Name` newtype struct.
    ///
    /// Human-readable formats must carry a presentation-format string,
    /// parsed via `FromStr`. Other formats must carry the uncompressed wire
    /// bytes, parsed via `ParseBytes`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'a>,
    {
        // Reads the presentation-format string inside the newtype.
        struct TextVisitor;

        impl serde::de::Visitor<'_> for TextVisitor {
            type Value = NameBuf;

            fn expecting(
                &self,
                f: &mut fmt::Formatter<'_>,
            ) -> fmt::Result {
                f.write_str("a domain name, in the DNS zonefile format")
            }

            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                v.parse::<NameBuf>().map_err(E::custom)
            }
        }

        // Unwraps the `Name` newtype and then expects a string.
        struct TextNewtype;

        impl<'de> serde::de::Visitor<'de> for TextNewtype {
            type Value = NameBuf;

            fn expecting(
                &self,
                f: &mut fmt::Formatter<'_>,
            ) -> fmt::Result {
                f.write_str("an absolute domain name")
            }

            fn visit_newtype_struct<D>(
                self,
                deserializer: D,
            ) -> Result<Self::Value, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                deserializer.deserialize_str(TextVisitor)
            }
        }

        // Reads the wire-format bytes inside the newtype.
        struct WireVisitor;

        impl serde::de::Visitor<'_> for WireVisitor {
            type Value = NameBuf;

            fn expecting(
                &self,
                f: &mut fmt::Formatter<'_>,
            ) -> fmt::Result {
                f.write_str("a domain name, in the DNS wire format")
            }

            fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                NameBuf::parse_bytes(v).map_err(|_| {
                    E::custom("misformatted domain name for the DNS wire format")
                })
            }
        }

        // Unwraps the `Name` newtype and then expects bytes.
        struct WireNewtype;

        impl<'de> serde::de::Visitor<'de> for WireNewtype {
            type Value = NameBuf;

            fn expecting(
                &self,
                f: &mut fmt::Formatter<'_>,
            ) -> fmt::Result {
                f.write_str("an absolute domain name")
            }

            fn visit_newtype_struct<D>(
                self,
                deserializer: D,
            ) -> Result<Self::Value, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                deserializer.deserialize_bytes(WireVisitor)
            }
        }

        if deserializer.is_human_readable() {
            deserializer.deserialize_newtype_struct("Name", TextNewtype)
        } else {
            deserializer.deserialize_newtype_struct("Name", WireNewtype)
        }
    }
}
#[cfg(feature = "serde")]
impl<'a> serde::Deserialize<'a> for std::boxed::Box<Name> {
    /// Deserializes into a [`NameBuf`], then copies it onto the heap.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'a>,
    {
        let buf = NameBuf::deserialize(deserializer)?;
        Ok(buf.unsized_copy_into())
    }
}
/// An error in parsing a domain name from the zonefile presentation format.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum NameParseError {
    /// The name would exceed 255 bytes in the wire format.
    Overlong,

    /// A label was followed by a character that cannot continue or end a
    /// name, or (via `FromStr`) input remained after the name.
    InvalidChar,

    /// A label could not be parsed.
    Label(LabelParseError),
}
impl From<LabelParseError> for NameParseError {
fn from(value: LabelParseError) -> Self {
Self::Label(value)
}
}
// Available only with the `std` feature; the message comes from the
// `Display` impl and there is no underlying source error.
#[cfg(feature = "std")]
impl std::error::Error for NameParseError {}
impl fmt::Display for NameParseError {
    /// Writes a short human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Self::Overlong => "the domain name was too long",
            Self::InvalidChar | Self::Label(LabelParseError::InvalidChar) => {
                "the domain name contained an invalid character"
            }
            Self::Label(LabelParseError::Overlong) => "a label was too long",
            Self::Label(LabelParseError::Empty) => "a label was empty",
            Self::Label(LabelParseError::PartialEscape) => {
                "a label contained an incomplete escape"
            }
            Self::Label(LabelParseError::InvalidEscape) => {
                "a label contained an invalid escape"
            }
        };
        f.write_str(message)
    }
}