use crate::{
Error, Result,
bytes::{Cursor, Reader},
constants::DOMAIN_NAME_MAX_LENGTH,
names::InlineName,
};
use std::{
cmp::Ordering,
fmt::{self, Display, Formatter},
hash::{Hash, Hasher},
str::FromStr,
};
/// An owned domain name held in presentation (text) form.
///
/// Names built with `Name::root`, the validating `Name::from(&str)` constructor, or the
/// label-appending methods end in a `.`; `Name::new()` and `Default` yield an empty name
/// (length 0), which is distinct from the root name `"."`.
#[derive(Clone, Debug, Default)]
pub struct Name {
    /// Textual form of the name, e.g. `"example.com."`.
    name: String,
}
impl Name {
    /// Creates an empty name.
    ///
    /// The result has length 0 and is *not* the root name; use [`Name::root`] for `"."`.
    #[inline]
    pub fn new() -> Self {
        Self {
            name: String::new(),
        }
    }

    /// Creates the root name, stored as `"."`.
    pub fn root() -> Self {
        Self {
            name: String::from("."),
        }
    }

    /// Validates `s` and builds a name from it, appending a trailing `.` when `s` does not
    /// already end with one.
    ///
    /// Note: as an inherent associated function, this takes precedence over the `From`
    /// trait impls whenever `Name::from(..)` / `Self::from(..)` is written.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `super::check_name`.
    fn from(s: &str) -> Result<Self> {
        super::check_name(s)?;
        // Reserve room for the possible trailing dot up front.
        let mut name = String::with_capacity(s.len() + 1);
        name.push_str(s);
        // `ends_with` replaces an unchecked read of the last byte, so an empty `s` can no
        // longer cause undefined behaviour even if validation were to let it through.
        if !name.ends_with('.') {
            name.push('.');
        }
        Ok(Self { name })
    }

    /// Returns the name as text.
    #[inline]
    pub fn as_str(&self) -> &str {
        &self.name
    }

    /// Returns the length of the textual name in bytes, including any trailing `.`.
    #[inline]
    pub fn len(&self) -> usize {
        self.name.len()
    }

    /// Returns `true` when the name holds no text at all (as produced by [`Name::new`]).
    ///
    /// The root name `"."` is not empty.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.name.is_empty()
    }

    /// Removes all text, leaving an empty name.
    #[inline]
    pub fn clear(&mut self) {
        self.name.clear();
    }

    /// Validates a raw label and appends it to the name followed by a `.`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `super::check_label_bytes`, and returns
    /// `Error::DomainNameTooLong(new_len)` when the result would exceed
    /// `DOMAIN_NAME_MAX_LENGTH`. On error the name is left unchanged.
    pub(crate) fn append_label_bytes(&mut self, label: &[u8]) -> Result<()> {
        super::check_label_bytes(label)?;
        // SAFETY: relies on `check_label_bytes` rejecting every byte sequence that is not
        // valid UTF-8. NOTE(review): presumably it only accepts ASCII label characters —
        // confirm that guarantee in `check_label_bytes` itself.
        let label = unsafe { std::str::from_utf8_unchecked(label) };
        self.push_label(label)
    }

    /// Validates a textual label and appends it to the name followed by a `.`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `super::check_label`, and returns
    /// `Error::DomainNameTooLong(new_len)` when the result would exceed
    /// `DOMAIN_NAME_MAX_LENGTH`. On error the name is left unchanged.
    pub(crate) fn append_label(&mut self, label: &str) -> Result<()> {
        super::check_label(label)?;
        self.push_label(label)
    }

    /// Appends an already-validated `label` plus a terminating `.`, enforcing the
    /// overall length limit before mutating anything.
    ///
    /// NOTE(review): appending to the root name (`"."`) produces `".label."`; callers
    /// presumably only append to empty or non-root names — confirm.
    fn push_label(&mut self, label: &str) -> Result<()> {
        let new_len = self.name.len() + label.len() + 1;
        if new_len > DOMAIN_NAME_MAX_LENGTH {
            return Err(Error::DomainNameTooLong(new_len));
        }
        self.name.push_str(label);
        self.name.push('.');
        Ok(())
    }

    /// Replaces the contents with the root name `"."`.
    pub fn set_root(&mut self) {
        self.name.clear();
        self.name.push('.');
    }
}
impl TryFrom<&str> for Name {
type Error = Error;
fn try_from(value: &str) -> Result<Self> {
Self::from(value)
}
}
impl FromStr for Name {
type Err = Error;
fn from_str(s: &str) -> Result<Self> {
Self::from(s)
}
}
impl AsRef<str> for Name {
fn as_ref(&self) -> &str {
&self.name
}
}
/// Two names are equal when their text matches, ignoring ASCII case.
impl PartialEq for Name {
    fn eq(&self, other: &Self) -> bool {
        self.name.eq_ignore_ascii_case(&other.name)
    }
}

impl PartialOrd for Name {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Orders names by their ASCII-lowercased bytes, compared lexicographically left to right.
/// When one name is a prefix of the other, the shorter sorts first. This is consistent with
/// the case-insensitive equality above.
///
/// Note: this is a plain textual comparison, not the label-wise, right-to-left canonical
/// ordering of RFC 4034 §6.1.
impl Ord for Name {
    fn cmp(&self, other: &Self) -> Ordering {
        // `Iterator::cmp` gives exactly "compare element-wise, then by length" without the
        // manual index loop and its unchecked (unsafe) indexing.
        let lhs = self.name.bytes().map(|b| b.to_ascii_lowercase());
        let rhs = other.name.bytes().map(|b| b.to_ascii_lowercase());
        lhs.cmp(rhs)
    }
}
impl PartialEq<&str> for Name {
fn eq(&self, other: &&str) -> bool {
let l_is_root = self.name.as_bytes() == b".";
let r_is_root = *other == ".";
match (l_is_root, r_is_root) {
(true, true) => return true,
(false, false) => {}
_ => return false,
}
let mut bytes = self.name.as_bytes();
if !bytes.is_empty() && !other.ends_with('.') {
bytes = &bytes[..bytes.len() - 1];
}
bytes.eq_ignore_ascii_case(other.as_bytes())
}
}
impl Eq for Name {}

impl Hash for Name {
    /// Hashes the ASCII-lowercased bytes of the name, so names that compare equal under
    /// the case-insensitive `PartialEq` also hash equally.
    ///
    /// A `0xff` terminator follows the bytes, the same convention `str` uses via
    /// `Hasher::write_str`. This keeps the hashed data prefix-free, so the tuples
    /// `("a.", "b.")` and `("a.b.", "")` feed different byte streams to the hasher.
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Lowercase through a small stack buffer so the hasher receives a few slice
        // writes instead of one `write_u8` call per byte.
        let mut buf = [0u8; 64];
        for chunk in self.name.as_bytes().chunks(buf.len()) {
            let lowered = &mut buf[..chunk.len()];
            lowered.copy_from_slice(chunk);
            lowered.make_ascii_lowercase();
            state.write(lowered);
        }
        state.write_u8(0xff);
    }
}
impl Display for Name {
    /// Writes the name exactly as stored, honouring width, fill, alignment and precision
    /// flags (e.g. `{:>20}`).
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let text: &str = &self.name;
        f.pad(text)
    }
}
impl From<Name> for String {
fn from(name: Name) -> Self {
name.name
}
}
impl From<InlineName> for Name {
    /// Copies the text of an inline name into a heap-backed `Name`.
    ///
    /// The text is copied as-is; no validation happens here.
    fn from(name: InlineName) -> Self {
        let text = name.as_str().to_owned();
        Self { name: text }
    }
}

impl From<&InlineName> for Name {
    /// Copies the text of a borrowed inline name into a heap-backed `Name`.
    ///
    /// The text is copied as-is; no validation happens here.
    fn from(name: &InlineName) -> Self {
        Self {
            name: String::from(name.as_str()),
        }
    }
}
impl super::private::DNameBase for Name {
#[inline(always)]
fn as_str(&self) -> &str {
self.as_str()
}
#[inline(always)]
fn len(&self) -> usize {
self.len()
}
#[inline(always)]
fn is_empty(&self) -> bool {
self.is_empty()
}
#[inline(always)]
fn clear(&mut self) {
self.clear()
}
#[inline(always)]
fn append_label_bytes(&mut self, label: &[u8]) -> Result<()> {
self.append_label_bytes(label)
}
#[inline(always)]
fn append_label(&mut self, label: &str) -> Result<()> {
self.append_label(label)
}
#[inline(always)]
fn set_root(&mut self) {
self.set_root()
}
#[inline(always)]
fn from_cursor(c: &mut Cursor<'_>) -> Result<Self> {
c.read()
}
}
impl super::DName for Name {}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_new() {
        let dn = Name::new();
        assert!(dn.is_empty());
        assert_eq!(dn.len(), 0);
    }

    #[test]
    fn test_default() {
        let dn: Name = Default::default();
        assert!(dn.is_empty());
        assert_eq!(dn.len(), 0);
    }

    #[test]
    fn test_from() {
        // Variable names encode the character count of the string they hold.
        let label_63 = "a".repeat(63);
        let label_60 = "b".repeat(60);
        // 63 + 63 + 63 + 60 label characters plus three separating dots.
        let dn_252 = [
            label_63.as_str(),
            label_63.as_str(),
            label_63.as_str(),
            label_60.as_str(),
        ]
        .join(".");
        // The same name written in fully-qualified form.
        let dn_253 = dn_252.clone() + ".";
        // Four 63-character labels plus three dots: too long once the dot is appended.
        let dn_255 = [
            label_63.as_str(),
            label_63.as_str(),
            label_63.as_str(),
            label_63.as_str(),
        ]
        .join(".");
        assert_eq!(dn_252.len(), 252);
        assert_eq!(dn_253.len(), 253);
        assert_eq!(dn_255.len(), 255);

        let success_cases = &[
            "3om",
            "com",
            "example.com",
            "sub.example.com",
            "3ub.example.com",
            ".",
            "example.com.",
            "3xample.com.",
            "EXAMPLE.com",
            "EXAMPLE.COM",
            "EXAMPLE.COM.",
            dn_252.as_str(),
            dn_253.as_str(),
        ];
        for sc in success_cases {
            let dn = Name::from(sc).unwrap();
            let expected = if sc.ends_with('.') {
                sc.to_string()
            } else {
                format!("{sc}.")
            };
            assert_eq!(dn.as_str(), &expected);
            assert_eq!(dn.len(), expected.len());
        }

        let failure_cases = &[
            "",
            "..",
            "3c-",
            "co-",
            "example..com",
            "sub..example.com",
            "example-.com",
            "-xample.com",
            "examp|e.com",
            "exa\u{203C}ple.com",
            dn_255.as_str(),
        ];
        for fc in failure_cases {
            assert!(Name::from(fc).is_err())
        }
    }

    #[test]
    fn test_len() {
        let mut dn = Name::new();
        assert_eq!(dn.len(), 0);
        dn.append_label("example").unwrap();
        assert_eq!(dn.len(), 8);
        dn.append_label("com").unwrap();
        assert_eq!(dn.len(), 12);
    }

    #[test]
    fn test_append_label_too_long() {
        let l_63 = "a".repeat(63);
        let l_62 = "b".repeat(62);
        let mut dn = Name::new();
        dn.append_label(&l_63).unwrap();
        assert_eq!(dn.len(), 64);
        dn.append_label(&l_63).unwrap();
        assert_eq!(dn.len(), 128);
        dn.append_label(&l_63).unwrap();
        assert_eq!(dn.len(), 192);
        {
            let mut dn = dn.clone();
            dn.append_label("small").unwrap();
            let res = dn.append_label(&l_63);
            assert!(
                matches!(res, Err(Error::DomainNameTooLong(s)) if s == dn.len() + l_63.len() + 1)
            );
        }
        let res = dn.clone().append_label(&l_63);
        assert!(matches!(res, Err(Error::DomainNameTooLong(s)) if s == dn.len() + l_63.len() + 1));
        dn.append_label(&l_62).unwrap();
        assert_eq!(dn.len(), 255);
    }

    #[test]
    fn test_eq() {
        let dn1 = Name::from("example.com").unwrap();
        let dn2 = Name::from("EXAMPLE.COM").unwrap();
        let dn3 = Name::from("eXaMpLe.cOm").unwrap();
        assert_eq!(dn1, dn2);
        assert_eq!(dn1, dn3);
        assert_eq!(dn2, dn3);
    }

    #[test]
    fn test_neq() {
        let dn1 = Name::from("example.com").unwrap();
        let dn2 = Name::from("sub.example.com").unwrap();
        let dn3 = Name::from("Sub.examp1e.com").unwrap();
        assert_ne!(dn1, dn2);
        assert_ne!(dn1, dn3);
        assert_ne!(dn2, dn3);
    }

    #[test]
    fn test_eq_str() {
        let dn1 = Name::from("example.com").unwrap();
        let dn2 = Name::from("EXAMPLE.COM").unwrap();
        assert_eq!(dn1, "EXAMPLE.COM.");
        assert_eq!(dn1, "EXAMPLE.COM");
        assert_eq!(dn1, "eXaMpLe.cOm.");
        assert_eq!(dn2, "eXaMpLe.cOm");
        assert_eq!(dn2, "example.com");
        assert_eq!(dn2, "eXaMpLe.cOm.");
        assert_eq!(Name::from("sub.example.com").unwrap(), "sub.example.com.");
        assert_eq!(Name::from("sub.example.com.").unwrap(), "sub.example.com");
        assert_eq!(Name::new(), "");
        assert_eq!(Name::root(), ".");
    }

    #[test]
    fn test_neq_str() {
        let dn1 = Name::from("example.com").unwrap();
        let dn2 = Name::from("sub.example.com").unwrap();
        assert_ne!(dn1, "sub.example.com");
        assert_ne!(dn1, "sub.example.com.");
        assert_ne!(dn1, "Sub.examp1e.com");
        assert_ne!(dn1, "Sub.examp1e.com.");
        assert_ne!(dn2, "Sub.examp1e.com");
        assert_ne!(dn2, "Sub.examp1e.com.");
        assert_ne!(Name::new(), ".");
        assert_ne!(Name::root(), "");
    }

    #[test]
    fn test_hash() {
        let dn = Name::from("example.com").unwrap();
        let mut s = HashSet::new();
        s.insert(dn);
        assert!(s.contains(&Name::from("example.com.").unwrap()));
        assert!(s.contains(&Name::from("eXaMpLe.COM").unwrap()));
        assert!(s.contains(&Name::from("EXAMPLE.COM").unwrap()));
        assert!(!s.contains(&Name::from("suB.Example.com.").unwrap()));
    }

    #[test]
    fn test_ord() {
        let dn1 = Name::from("example.com").unwrap();
        let dn2 = Name::from("ExaMplE.com").unwrap();
        let dn3 = Name::from("Sub.example.com").unwrap();
        assert_eq!(Ordering::Equal, dn1.cmp(&dn2));
        assert_eq!(Ordering::Less, dn1.cmp(&dn3));
        assert_eq!(Ordering::Greater, dn3.cmp(&dn1));
        assert_eq!(Ordering::Equal, Name::root().cmp(&Name::root()));
        assert_eq!(Ordering::Equal, Name::new().cmp(&Name::new()));
        assert_eq!(Ordering::Less, Name::new().cmp(&Name::root()));
    }

    #[test]
    fn test_partial_ord() {
        let dn1 = Name::from("example.com").unwrap();
        let dn2 = Name::from("ExaMplE.com").unwrap();
        let dn3 = Name::from("Sub.example.com").unwrap();
        assert_eq!(Some(Ordering::Equal), dn1.partial_cmp(&dn2));
        assert_eq!(Some(Ordering::Less), dn1.partial_cmp(&dn3));
        assert_eq!(Some(Ordering::Greater), dn3.partial_cmp(&dn1));
        assert_eq!(
            Some(Ordering::Equal),
            Name::root().partial_cmp(&Name::root())
        );
        assert_eq!(Some(Ordering::Equal), Name::new().partial_cmp(&Name::new()));
    }

    #[test]
    fn test_display() {
        let dn = Name::from("example.com").unwrap();
        assert_eq!(dn.to_string(), "example.com.");
        assert_eq!(format!("[{dn:>14}]"), "[  example.com.]");
        assert_eq!(Name::root().to_string(), ".");
        assert_eq!(Name::new().to_string(), "");
    }

    #[test]
    fn test_into_string() {
        let s: String = Name::from("Example.COM").unwrap().into();
        assert_eq!(s, "Example.COM.");
    }
}