use core::{
borrow::{Borrow, BorrowMut},
cmp::Ordering,
fmt,
hash::{Hash, Hasher},
iter::FusedIterator,
ops::{Deref, DerefMut},
str::FromStr,
};
use crate::new::base::build::{BuildInMessage, NameCompressor};
use crate::new::base::parse::{ParseMessageBytes, SplitMessageBytes};
use crate::new::base::wire::{
AsBytes, BuildBytes, ParseBytes, ParseError, SplitBytes, TruncationError,
};
use crate::utils::dst::{UnsizedCopy, UnsizedCopyFrom};
/// A single DNS label in its wire format.
///
/// The underlying bytes are a length octet followed by that many content
/// bytes.  Labels produced by this module's parsers have a length octet
/// below 64.
#[derive(AsBytes, UnsizedCopy)]
#[repr(transparent)]
pub struct Label([u8]);
impl Label {
    /// The root label, which has no content bytes.
    pub const ROOT: &'static Self =
        // SAFETY: `[0]` is a zero length octet with no content bytes.
        unsafe { Self::from_bytes_unchecked(&[0]) };

    /// The wildcard label, `*`.
    pub const WILDCARD: &'static Self =
        // SAFETY: `[1, b'*']` is a length octet of 1 followed by one byte.
        unsafe { Self::from_bytes_unchecked(&[1, b'*']) };
}
impl Label {
    /// Reinterpret `bytes` as a [`Label`] without any validation.
    ///
    /// # Safety
    ///
    /// `bytes` must be a well-formed label: the same conditions that
    /// `SplitBytes::split_bytes` checks before calling this function
    /// (a length octet below 64 followed by exactly that many bytes).
    pub const unsafe fn from_bytes_unchecked(bytes: &[u8]) -> &Self {
        // SAFETY: `Label` is `#[repr(transparent)]` over `[u8]`, so both
        // reference types share layout and pointer metadata.
        unsafe { core::mem::transmute::<&[u8], &Self>(bytes) }
    }

    /// Mutable counterpart of [`Label::from_bytes_unchecked`].
    ///
    /// # Safety
    ///
    /// Same requirements as [`Label::from_bytes_unchecked`].
    pub unsafe fn from_bytes_unchecked_mut(bytes: &mut [u8]) -> &mut Self {
        // SAFETY: see `from_bytes_unchecked`.
        unsafe { core::mem::transmute::<&mut [u8], &mut Self>(bytes) }
    }
}
impl<'a> ParseMessageBytes<'a> for &'a Label {
fn parse_message_bytes(
contents: &'a [u8],
start: usize,
) -> Result<Self, ParseError> {
Self::parse_bytes(&contents[start..])
}
}
impl<'a> SplitMessageBytes<'a> for &'a Label {
fn split_message_bytes(
contents: &'a [u8],
start: usize,
) -> Result<(Self, usize), ParseError> {
Self::split_bytes(&contents[start..])
.map(|(this, rest)| (this, contents.len() - start - rest.len()))
}
}
impl BuildInMessage for Label {
    /// Copy this label's wire-format bytes into `contents` at `start`.
    ///
    /// The label is written verbatim; the compressor is not consulted.
    /// Returns the offset just past the written bytes, or
    /// [`TruncationError`] if they do not fit in `contents`.
    fn build_in_message(
        &self,
        contents: &mut [u8],
        start: usize,
        _compressor: &mut NameCompressor,
    ) -> Result<usize, TruncationError> {
        let encoded = &self.0;
        let end = start + encoded.len();
        let Some(target) = contents.get_mut(start..end) else {
            return Err(TruncationError);
        };
        target.copy_from_slice(encoded);
        Ok(end)
    }
}
impl<'a> SplitBytes<'a> for &'a Label {
    /// Split one wire-format label off the front of `bytes`.
    ///
    /// The first byte is the length octet.  It must be below 64 and be
    /// followed by at least that many bytes.  Returns the label (length
    /// octet included) and the bytes after it.
    fn split_bytes(bytes: &'a [u8]) -> Result<(Self, &'a [u8]), ParseError> {
        let (&size, _) = bytes.split_first().ok_or(ParseError)?;
        let total = 1 + usize::from(size);
        if size >= 64 || bytes.len() < total {
            return Err(ParseError);
        }
        let (label, rest) = bytes.split_at(total);
        // SAFETY: `label` is a length octet below 64 followed by exactly
        // that many bytes, i.e. a well-formed label.
        Ok((unsafe { Label::from_bytes_unchecked(label) }, rest))
    }
}
impl<'a> ParseBytes<'a> for &'a Label {
    /// Parse `bytes` as exactly one wire-format label.
    ///
    /// Fails with [`ParseError`] if the label is malformed or if any bytes
    /// remain after it.
    fn parse_bytes(bytes: &'a [u8]) -> Result<Self, ParseError> {
        let (label, leftover) = Self::split_bytes(bytes)?;
        if leftover.is_empty() {
            Ok(label)
        } else {
            Err(ParseError)
        }
    }
}
impl BuildBytes for Label {
    /// Write the wire-format bytes (length octet and contents) to the front
    /// of `bytes`, returning the unused remainder.
    fn build_bytes<'b>(
        &self,
        bytes: &'b mut [u8],
    ) -> Result<&'b mut [u8], TruncationError> {
        <[u8] as BuildBytes>::build_bytes(&self.0, bytes)
    }

    /// The number of bytes [`Self::build_bytes`] writes.
    fn built_bytes_size(&self) -> usize {
        let encoded: &[u8] = &self.0;
        encoded.len()
    }
}
impl Label {
    /// Whether this is the root label, i.e. one with no content bytes.
    pub const fn is_root(&self) -> bool {
        // Only the length octet is present when there is no content.
        let encoded_len = self.0.len();
        encoded_len == 1
    }

    /// Whether this is the wildcard label, `*`.
    pub const fn is_wildcard(&self) -> bool {
        // `*` has no case variants, so an exact byte match suffices.
        match self.0 {
            [1, b'*'] => true,
            _ => false,
        }
    }

    /// The wire-format bytes: the length octet followed by the contents.
    pub const fn as_bytes(&self) -> &[u8] {
        &self.0
    }

    /// The content bytes, without the leading length octet.
    pub fn contents(&self) -> &[u8] {
        let (_length_octet, rest) = self.0.split_at(1);
        rest
    }
}
impl AsRef<[u8]> for Label {
    /// Borrow the wire-format bytes (length octet included).
    fn as_ref(&self) -> &[u8] {
        self.as_bytes()
    }
}
impl<'a> From<&'a Label> for &'a [u8] {
    /// Borrow the wire-format bytes (length octet included).
    fn from(value: &'a Label) -> Self {
        value.as_bytes()
    }
}
#[cfg(feature = "alloc")]
impl Clone for alloc::boxed::Box<Label> {
    /// Clone a boxed label by copying its bytes into a new allocation.
    fn clone(&self) -> Self {
        let label: &Label = self;
        label.unsized_copy_into()
    }
}
impl PartialEq for Label {
    /// Compare two labels for equality, ignoring ASCII case.
    fn eq(&self, other: &Self) -> bool {
        // Length octets of well-formed labels are below 64, where ASCII case
        // folding changes nothing, so comparing the whole encodings is the
        // same as comparing lengths plus case-folded contents.
        self.as_bytes().eq_ignore_ascii_case(other.as_bytes())
    }
}
impl Eq for Label {}
impl PartialOrd for Label {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Label {
    /// Order two labels by their contents, ignoring ASCII case.
    ///
    /// Contents are compared as unsigned octets, with ASCII uppercase letters
    /// treated as lowercase.  The first differing byte decides; if one label
    /// is a prefix of the other, the shorter one is lesser.  This matches the
    /// label comparison used for DNSSEC canonical name ordering (RFC 4034,
    /// section 6.1).
    ///
    /// The length octet is deliberately excluded.  Comparing it first would
    /// order labels by length (placing `b` before `ab`).  Equal orderings
    /// still coincide with [`PartialEq`], since equal contents imply equal
    /// lengths.
    fn cmp(&self, other: &Self) -> Ordering {
        let this = self.contents().iter().map(u8::to_ascii_lowercase);
        let that = other.contents().iter().map(u8::to_ascii_lowercase);
        this.cmp(that)
    }
}
impl Hash for Label {
    /// Hash the ASCII-lowercased wire-format bytes (length octet included).
    ///
    /// Labels that differ only in ASCII case hash identically, consistent
    /// with the case-insensitive equality.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.as_bytes()
            .iter()
            .map(u8::to_ascii_lowercase)
            .for_each(|folded| state.write_u8(folded));
    }
}
impl fmt::Display for Label {
    /// Print the label in zonefile text format.
    ///
    /// The wildcard label prints as `*`.  Letters, digits, `-` and `_` print
    /// as-is.  Other printable ASCII characters are preceded by a backslash.
    /// All remaining bytes print as `\DDD` (three zero-padded decimal
    /// digits).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.is_wildcard() {
            return f.write_str("*");
        }
        for &byte in self.contents() {
            match byte {
                b'-' | b'_' => write!(f, "{}", byte as char)?,
                b if b.is_ascii_alphanumeric() => write!(f, "{}", b as char)?,
                b if b.is_ascii_graphic() => write!(f, "\\{}", b as char)?,
                b => write!(f, "\\{:03}", b)?,
            }
        }
        Ok(())
    }
}
impl fmt::Debug for Label {
    /// Format as `Label(...)`, wrapping the zonefile text form.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Label(")?;
        write!(f, "{}", self)?;
        f.write_str(")")
    }
}
#[cfg(feature = "serde")]
impl serde::Serialize for Label {
    /// Serialize this label as a newtype struct named `Label`.
    ///
    /// Human-readable formats receive the zonefile text form (the `Display`
    /// output).  Other formats receive the wire-format bytes, *including*
    /// the leading length octet.  That is the input `LabelBuf`'s
    /// deserializer passes to `LabelBuf::parse_bytes`.  Previously only the
    /// contents were emitted, so binary round-trips failed.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use std::string::ToString;

        if serializer.is_human_readable() {
            serializer.serialize_newtype_struct("Label", &self.to_string())
        } else {
            // NOTE(review): serde serializes `[u8]` as a sequence rather
            // than via `serialize_bytes`; formats that distinguish the two
            // may not reach the deserializer's `visit_bytes`.  Confirm.
            serializer.serialize_newtype_struct("Label", self.as_bytes())
        }
    }
}
/// An owned buffer holding a single label in its wire format.
///
/// `data[0]` is the label's length octet, and the following `data[0]` bytes
/// are its contents.
#[derive(Clone)]
#[repr(transparent)]
pub struct LabelBuf {
    data: [u8; 64],
}
impl LabelBuf {
    /// Copy `label` into a new, zero-initialized buffer.
    pub fn copy_from(label: &Label) -> Self {
        let encoded = label.as_bytes();
        let mut data = [0u8; 64];
        let (prefix, _unused) = data.split_at_mut(encoded.len());
        prefix.copy_from_slice(encoded);
        Self { data }
    }
}
impl UnsizedCopyFrom for LabelBuf {
type Source = Label;
fn unsized_copy_from(value: &Self::Source) -> Self {
Self::copy_from(value)
}
}
impl ParseMessageBytes<'_> for LabelBuf {
fn parse_message_bytes(
contents: &'_ [u8],
start: usize,
) -> Result<Self, ParseError> {
Self::parse_bytes(&contents[start..])
}
}
impl SplitMessageBytes<'_> for LabelBuf {
fn split_message_bytes(
contents: &'_ [u8],
start: usize,
) -> Result<(Self, usize), ParseError> {
Self::split_bytes(&contents[start..])
.map(|(this, rest)| (this, contents.len() - start - rest.len()))
}
}
impl BuildInMessage for LabelBuf {
fn build_in_message(
&self,
contents: &mut [u8],
start: usize,
compressor: &mut NameCompressor,
) -> Result<usize, TruncationError> {
Label::build_in_message(self, contents, start, compressor)
}
}
impl ParseBytes<'_> for LabelBuf {
fn parse_bytes(bytes: &[u8]) -> Result<Self, ParseError> {
<&Label>::parse_bytes(bytes).map(Self::copy_from)
}
}
impl SplitBytes<'_> for LabelBuf {
fn split_bytes(bytes: &'_ [u8]) -> Result<(Self, &'_ [u8]), ParseError> {
<&Label>::split_bytes(bytes)
.map(|(label, rest)| (Self::copy_from(label), rest))
}
}
impl BuildBytes for LabelBuf {
fn build_bytes<'b>(
&self,
bytes: &'b mut [u8],
) -> Result<&'b mut [u8], TruncationError> {
(**self).build_bytes(bytes)
}
fn built_bytes_size(&self) -> usize {
(**self).built_bytes_size()
}
}
impl Deref for LabelBuf {
    type Target = Label;

    /// View the length octet and the contents it covers as a [`Label`].
    fn deref(&self) -> &Self::Target {
        let encoded_len = 1 + usize::from(self.data[0]);
        let (label, _unused) = self.data.split_at(encoded_len);
        // SAFETY: the buffer holds a label copied or parsed by this module,
        // so `data[0]` is a valid length octet for the bytes that follow.
        unsafe { Label::from_bytes_unchecked(label) }
    }
}
impl DerefMut for LabelBuf {
    /// Mutable counterpart of `deref`; covers the same bytes.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let encoded_len = 1 + usize::from(self.data[0]);
        let (label, _unused) = self.data.split_at_mut(encoded_len);
        // SAFETY: as in `deref`.
        unsafe { Label::from_bytes_unchecked_mut(label) }
    }
}
// `LabelBuf`'s `Eq`, `Ord` and `Hash` all delegate to `Label`, so borrowing
// as `Label` upholds the consistency `Borrow` requires.
impl Borrow<Label> for LabelBuf {
    fn borrow(&self) -> &Label {
        &**self
    }
}
impl BorrowMut<Label> for LabelBuf {
    fn borrow_mut(&mut self) -> &mut Label {
        &mut **self
    }
}
impl AsRef<Label> for LabelBuf {
    fn as_ref(&self) -> &Label {
        &**self
    }
}
impl AsMut<Label> for LabelBuf {
    fn as_mut(&mut self) -> &mut Label {
        &mut **self
    }
}
impl PartialEq for LabelBuf {
fn eq(&self, that: &Self) -> bool {
**self == **that
}
}
impl Eq for LabelBuf {}
impl PartialOrd for LabelBuf {
fn partial_cmp(&self, that: &Self) -> Option<Ordering> {
Some(self.cmp(that))
}
}
impl Ord for LabelBuf {
fn cmp(&self, that: &Self) -> Ordering {
(**self).cmp(&**that)
}
}
impl Hash for LabelBuf {
fn hash<H: Hasher>(&self, state: &mut H) {
(**self).hash(state)
}
}
impl LabelBuf {
    /// Parse one label from the front of `s`, in zonefile text format.
    ///
    /// Letters, digits, `-` and `_` are taken literally.  A backslash starts
    /// an escape:
    /// - `\DDD` (exactly three decimal digits, value at most 255) denotes
    ///   that byte value;
    /// - `\X`, for any other printable ASCII `X`, denotes `X` itself.
    ///
    /// Parsing stops, without consuming the terminator, at `.`, a space,
    /// `\n`, `\r`, `\t`, or the end of input.  A leading `*` is parsed as
    /// the wildcard label regardless of what follows it.
    ///
    /// On success, returns the label and the unconsumed rest of `s`.
    ///
    /// # Errors
    ///
    /// - [`LabelParseError::Overlong`] if the label exceeds 63 bytes.
    /// - [`LabelParseError::Empty`] if no label bytes precede the terminator.
    /// - [`LabelParseError::InvalidChar`] on any other unescaped character.
    /// - [`LabelParseError::PartialEscape`] if the input ends inside an
    ///   escape (including a trailing lone backslash).
    /// - [`LabelParseError::InvalidEscape`] on a malformed escape.
    pub fn parse_str(mut s: &[u8]) -> Result<(Self, &[u8]), LabelParseError> {
        if let &[b'*', ref rest @ ..] = s {
            return Ok((Self::copy_from(Label::WILDCARD), rest));
        }

        // `data[0]` counts the content bytes written so far.
        let mut this = Self { data: [0u8; 64] };
        loop {
            let full = s;
            let &[b, ref rest @ ..] = s else { break };
            s = rest;
            let value = if b.is_ascii_alphanumeric() || b"-_".contains(&b) {
                b
            } else if b == b'\\' {
                match *s {
                    // `\DDD`: the three digits start at the byte right after
                    // the backslash, and all three are consumed.
                    [d0, d1, d2, ref rest @ ..] if d0.is_ascii_digit() => {
                        if !d1.is_ascii_digit() || !d2.is_ascii_digit() {
                            return Err(LabelParseError::InvalidEscape);
                        }
                        let code = u16::from(d0 - b'0') * 100
                            + u16::from(d1 - b'0') * 10
                            + u16::from(d2 - b'0');
                        if code > 0xFF {
                            return Err(LabelParseError::InvalidEscape);
                        }
                        s = rest;
                        // Lossless: `code` was just checked to fit a `u8`.
                        code as u8
                    }
                    // Fewer than three bytes remain for a decimal escape.
                    [d0, ..] if d0.is_ascii_digit() => {
                        return Err(LabelParseError::PartialEscape);
                    }
                    // `\X`: a literal printable character.
                    [c, ref rest @ ..] if c.is_ascii_graphic() => {
                        s = rest;
                        c
                    }
                    [_, ..] => return Err(LabelParseError::InvalidEscape),
                    // A backslash at the very end of the input.
                    [] => return Err(LabelParseError::PartialEscape),
                }
            } else if b". \n\r\t".contains(&b) {
                // End of the label; leave the terminator for the caller.
                s = full;
                break;
            } else {
                return Err(LabelParseError::InvalidChar);
            };

            let off = this.data[0] as usize + 1;
            this.data[0] += 1;
            let ptr =
                this.data.get_mut(off).ok_or(LabelParseError::Overlong)?;
            *ptr = value;
        }
        if this.data[0] == 0 {
            return Err(LabelParseError::Empty);
        }
        Ok((this, s))
    }
}
impl FromStr for LabelBuf {
    type Err = LabelParseError;

    /// Parse a string holding exactly one label in zonefile text format.
    ///
    /// Any input left over after the label (for example a `.` separator) is
    /// reported as [`LabelParseError::InvalidChar`].
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let (label, leftover) = Self::parse_str(s.as_bytes())?;
        if leftover.is_empty() {
            Ok(label)
        } else {
            Err(LabelParseError::InvalidChar)
        }
    }
}
#[cfg(feature = "serde")]
impl serde::Serialize for LabelBuf {
    /// Serialize exactly as the contained [`Label`] would be serialized.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let label: &Label = self;
        label.serialize(serializer)
    }
}
#[cfg(feature = "serde")]
impl<'a> serde::Deserialize<'a> for LabelBuf {
/// Deserialize a label wrapped in a newtype struct named `Label`.
///
/// Human-readable formats are expected to carry the zonefile text form,
/// which is parsed with `FromStr`.  Other formats are expected to carry the
/// wire-format bytes (length octet included), which are parsed with
/// `LabelBuf::parse_bytes`.
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'a>,
{
if deserializer.is_human_readable() {
// Inner visitor: turns the string payload into a `LabelBuf`.
struct V;
impl serde::de::Visitor<'_> for V {
type Value = LabelBuf;
fn expecting(
&self,
f: &mut fmt::Formatter<'_>,
) -> fmt::Result {
f.write_str("a label, in the DNS zonefile format")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
// Delegates to `LabelBuf`'s `FromStr`, surfacing its error text.
v.parse().map_err(|err| E::custom(err))
}
}
// Outer visitor: unwraps the `Label` newtype and reads a string.
struct NV;
impl<'a> serde::de::Visitor<'a> for NV {
type Value = LabelBuf;
fn expecting(
&self,
f: &mut fmt::Formatter<'_>,
) -> fmt::Result {
f.write_str("a DNS label")
}
fn visit_newtype_struct<D>(
self,
deserializer: D,
) -> Result<Self::Value, D::Error>
where
D: serde::Deserializer<'a>,
{
deserializer.deserialize_str(V)
}
}
deserializer.deserialize_newtype_struct("Label", NV)
} else {
// Inner visitor: turns the wire-format bytes into a `LabelBuf`.
// NOTE(review): only `visit_bytes` is implemented; a format that hands
// a `[u8]` payload over as a sequence would call `visit_seq` and fail.
// Confirm against the formats in use.
struct V;
impl serde::de::Visitor<'_> for V {
type Value = LabelBuf;
fn expecting(
&self,
f: &mut fmt::Formatter<'_>,
) -> fmt::Result {
f.write_str("a label, in the DNS wire format")
}
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
// `parse_bytes` requires exactly one label, length octet included.
LabelBuf::parse_bytes(v).map_err(|_| {
E::custom(
"misformatted label for the DNS wire format",
)
})
}
}
// Outer visitor: unwraps the `Label` newtype and reads bytes.
struct NV;
impl<'a> serde::de::Visitor<'a> for NV {
type Value = LabelBuf;
fn expecting(
&self,
f: &mut fmt::Formatter<'_>,
) -> fmt::Result {
f.write_str("a DNS label")
}
fn visit_newtype_struct<D>(
self,
deserializer: D,
) -> Result<Self::Value, D::Error>
where
D: serde::Deserializer<'a>,
{
deserializer.deserialize_bytes(V)
}
}
deserializer.deserialize_newtype_struct("Label", NV)
}
}
}
#[cfg(feature = "serde")]
impl<'a> serde::Deserialize<'a> for std::boxed::Box<Label> {
    /// Deserialize through [`LabelBuf`] and copy the label onto the heap.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'a>,
    {
        let buf = LabelBuf::deserialize(deserializer)?;
        Ok(buf.unsized_copy_into())
    }
}
/// An iterator over a sequence of wire-format labels.
#[derive(Clone)]
pub struct LabelIter<'a> {
    /// The bytes of the labels not yet yielded.
    bytes: &'a [u8],
}
impl<'a> LabelIter<'a> {
    /// Construct a [`LabelIter`] over `bytes`.
    ///
    /// # Safety
    ///
    /// `bytes` must be a concatenation of zero or more well-formed labels
    /// (each accepted by `SplitBytes::split_bytes`).  Iteration parses
    /// them with `unwrap_unchecked`, so malformed input is undefined
    /// behavior.
    pub const unsafe fn new_unchecked(bytes: &'a [u8]) -> Self {
        LabelIter { bytes }
    }
}
impl<'a> LabelIter<'a> {
pub const fn remaining(&self) -> &'a [u8] {
self.bytes
}
pub const fn is_empty(&self) -> bool {
self.bytes.is_empty()
}
}
impl<'a> Iterator for LabelIter<'a> {
    type Item = &'a Label;

    /// Yield the next label, or `None` once all labels have been yielded.
    fn next(&mut self) -> Option<Self::Item> {
        if self.bytes.is_empty() {
            None
        } else {
            // SAFETY: a `LabelIter` can only be built through the `unsafe`
            // `new_unchecked`, whose caller vouches that the bytes are a
            // sequence of valid labels.  Each call consumes exactly one
            // label, so the remaining bytes start with a valid label.
            let (label, rest) = unsafe {
                <&Label>::split_bytes(self.bytes).unwrap_unchecked()
            };
            self.bytes = rest;
            Some(label)
        }
    }
}
// Once `bytes` is empty, `next` keeps returning `None`.
impl FusedIterator for LabelIter<'_> {}
impl fmt::Debug for LabelIter<'_> {
    /// Format as `LabelIter([...])`, listing the labels not yet yielded.
    ///
    /// A clone is iterated, so `self` is left untouched.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        struct Pending<'a>(LabelIter<'a>);
        impl fmt::Debug for Pending<'_> {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.debug_list().entries(self.0.clone()).finish()
            }
        }
        f.debug_tuple("LabelIter")
            .field(&Pending(self.clone()))
            .finish()
    }
}
/// An error from parsing a label in zonefile text format.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum LabelParseError {
    /// The label would exceed 63 content bytes.
    Overlong,
    /// No label bytes were found.
    Empty,
    /// A character not permitted in a label was found (`FromStr` also
    /// reports trailing input after the label this way).
    InvalidChar,
    /// The input ended partway through an escape sequence.
    PartialEscape,
    /// An escape sequence was malformed.
    InvalidEscape,
}
#[cfg(feature = "std")]
impl std::error::Error for LabelParseError {}
impl fmt::Display for LabelParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            LabelParseError::Overlong => "the label was too large",
            LabelParseError::Empty => "the label was empty",
            LabelParseError::InvalidChar => {
                "the label contained an invalid character"
            }
            LabelParseError::PartialEscape => {
                "the label contained an incomplete escape"
            }
            LabelParseError::InvalidEscape => {
                "the label contained an invalid escape"
            }
        };
        f.write_str(message)
    }
}