use crate::{
Error, Result,
constants::{DOMAIN_NAME_LABEL_MAX_LENGTH, DOMAIN_NAME_MAX_LENGTH},
};
pub fn check_label_bytes(label: &[u8]) -> Result<()> {
if label.is_empty() {
return Err(Error::DomainNameLabelIsEmpty);
}
let len = label.len();
if len > DOMAIN_NAME_LABEL_MAX_LENGTH {
return Err(Error::DomainNameLabelTooLong(len));
}
for b in label.iter() {
if !(b.is_ascii_alphanumeric() || *b == b'-' || *b == b'_') {
return Err(Error::DomainNameLabelInvalidChar(
"domain name label invalid character",
*b,
));
}
}
unsafe {
let fc = label.get_unchecked(0);
if *fc == b'-' {
return Err(Error::DomainNameLabelInvalidChar(
"domain name label first character is '-'",
*fc,
));
}
let lc = label.get_unchecked(len - 1);
if *lc == b'-' {
return Err(Error::DomainNameLabelInvalidChar(
"domain name label last character is '-'",
*lc,
));
}
}
Ok(())
}
/// Validates a single domain name label given as a string slice.
///
/// Thin wrapper that hands the bytes of `label` to [`check_label_bytes`] and
/// returns its result unchanged, so the accepted labels and the reported
/// errors are exactly those of the byte-level check.
#[inline(always)]
pub fn check_label(label: &str) -> Result<()> {
    let raw = label.as_bytes();
    check_label_bytes(raw)
}
/// Validates a dot-separated domain name given as raw bytes.
///
/// The lone root name `.` is accepted as is. Otherwise the name is split on
/// `.`, a single trailing dot is tolerated, and every label must pass
/// [`check_label_bytes`]. Finally the overall length is checked against
/// `DOMAIN_NAME_MAX_LENGTH`.
///
/// # Errors
///
/// - `Error::DomainNameLabelIsEmpty` when `name` is empty or contains an
///   empty label (a leading dot or two consecutive dots);
/// - any error returned by [`check_label_bytes`] for the first (leftmost)
///   label that fails validation;
/// - `Error::DomainNameTooLong` (carrying the computed length) when the name
///   is too long. This is only reported once every label has been accepted.
pub fn check_name_bytes(name: &[u8]) -> Result<()> {
    if name.is_empty() {
        return Err(Error::DomainNameLabelIsEmpty);
    }
    if name == b"." {
        return Ok(());
    }

    // A single trailing dot is allowed; drop it so it does not show up as a
    // phantom empty label. Any further trailing dot still yields an empty
    // label and is rejected by the loop below.
    let labels = name.strip_suffix(b".").unwrap_or(name);
    for label in labels.split(|&b| b == b'.') {
        check_label_bytes(label)?;
    }

    // Length of the name without its optional trailing dot, plus two. This is
    // the same value as the previous `len + 1` (trailing dot) / `len + 2`
    // (no trailing dot) computation.
    // NOTE(review): presumably the encoded (wire-format) size, i.e. one
    // length octet per label plus the terminating root octet; confirm.
    let full_length = labels.len() + 2;
    if full_length > DOMAIN_NAME_MAX_LENGTH {
        return Err(Error::DomainNameTooLong(full_length));
    }

    Ok(())
}
/// Validates a dot-separated domain name given as a string slice.
///
/// Thin wrapper that hands the bytes of `name` to [`check_name_bytes`] and
/// returns its result unchanged, so the accepted names and the reported
/// errors are exactly those of the byte-level check.
#[inline(always)]
pub fn check_name(name: &str) -> Result<()> {
    let raw = name.as_bytes();
    check_name_bytes(raw)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Outcome type returned by every checker exercised in this module.
    type Outcome = std::result::Result<(), Error>;

    // Messages carried by `Error::DomainNameLabelInvalidChar`.
    const INVALID_CHAR: &str = "domain name label invalid character";
    const LEADING_DASH: &str = "domain name label first character is '-'";
    const TRAILING_DASH: &str = "domain name label last character is '-'";

    /// True when `outcome` is the "empty label" error.
    fn is_empty_label(outcome: Outcome) -> bool {
        matches!(outcome, Err(Error::DomainNameLabelIsEmpty))
    }

    /// Builds a predicate that accepts only an invalid-character error with
    /// exactly this message and offending byte.
    fn rejects_byte(message: &'static str, byte: u8) -> impl Fn(Outcome) -> bool {
        move |outcome: Outcome| {
            matches!(
                outcome,
                Err(Error::DomainNameLabelInvalidChar(m, b)) if m == message && b == byte
            )
        }
    }

    /// Asserts `accept` on the byte-level label checker, then on the `&str` one.
    fn both_label_checks(input: &[u8], accept: impl Fn(Outcome) -> bool) {
        assert!(accept(check_label_bytes(input)));
        let text = std::str::from_utf8(input).unwrap();
        assert!(accept(check_label(text)));
    }

    /// Asserts `accept` on the byte-level name checker, then on the `&str` one.
    fn both_name_checks(input: &[u8], accept: impl Fn(Outcome) -> bool) {
        assert!(accept(check_name_bytes(input)));
        let text = std::str::from_utf8(input).unwrap();
        assert!(accept(check_name(text)));
    }

    #[test]
    fn test_check_label() {
        // Empty label (byte-level entry point only).
        assert!(is_empty_label(check_label_bytes(b"")));

        // Hyphen at either end of the label.
        both_label_checks(b"-abel", rejects_byte(LEADING_DASH, b'-'));
        both_label_checks(b"label-", rejects_byte(TRAILING_DASH, b'-'));

        // Bytes outside the allowed character set.
        let bad_bytes: &[(&[u8], u8)] = &[(b"la.el", b'.'), (b"\tabel", b'\t')];
        for (input, byte) in bad_bytes {
            both_label_checks(input, rejects_byte(INVALID_CHAR, *byte));
        }

        // A 64-byte label is rejected and the error reports its length.
        let oversized = "a".repeat(64);
        let expected_len = oversized.len();
        both_label_checks(oversized.as_bytes(), |outcome: Outcome| {
            matches!(outcome, Err(Error::DomainNameLabelTooLong(n)) if n == expected_len)
        });

        // Well-formed labels, including a 63-byte one.
        let longest = "a".repeat(63);
        let accepted: &[&[u8]] = &[b"label", b"labe1", b"1abel", longest.as_bytes()];
        for input in accepted {
            both_label_checks(input, |outcome: Outcome| outcome.is_ok());
        }
    }

    #[test]
    fn test_check_name() {
        // Well-formed names: root, single label, trailing dot, digits,
        // inner hyphens, mixed case and underscores.
        let accepted: &[&[u8]] = &[
            b".",
            b"com",
            b"3om",
            b"example.com",
            b"exampl0.com.",
            b"3xample2.com",
            b"exam-3le.com",
            b"su--b.exAmp1e.com",
            b"_example.com",
            b"example_.com",
        ];
        for input in accepted {
            both_name_checks(input, |outcome: Outcome| outcome.is_ok());
        }

        // Empty name or an empty label somewhere in the name.
        let with_empty_label: &[&[u8]] = &[
            b"",
            b"..",
            b"example.com..",
            b"example..com",
            b"sub..example.com",
        ];
        for input in with_empty_label {
            both_name_checks(input, is_empty_label);
        }

        // Hyphen at the start of a label.
        both_name_checks(b"-xample.com", rejects_byte(LEADING_DASH, b'-'));

        // Hyphen at the end of a label.
        let trailing_dash: &[&[u8]] = &[b"co-", b"example-.com"];
        for input in trailing_dash {
            both_name_checks(input, rejects_byte(TRAILING_DASH, b'-'));
        }

        // Byte outside the allowed character set.
        both_name_checks(b"examp|e.com.", rejects_byte(INVALID_CHAR, b'|'));

        // Length boundary: three 63-byte labels plus a 61-byte label give a
        // 253-byte name, accepted with or without a trailing dot.
        let dn_253 = format!(
            "{a}.{a}.{a}.{b}",
            a = "a".repeat(63),
            b = "b".repeat(61)
        );
        let dn_253_dotted = format!("{}.", dn_253);
        both_name_checks(dn_253.as_bytes(), |outcome: Outcome| outcome.is_ok());
        both_name_checks(dn_253_dotted.as_bytes(), |outcome: Outcome| outcome.is_ok());

        // One more byte is too long; the error reports the name length + 2.
        let dn_254 = format!("{}b", dn_253);
        let reported = dn_254.len() + 2;
        let res = check_name(dn_254.as_str());
        assert!(matches!(res, Err(Error::DomainNameTooLong(s)) if s == reported));
        let res = check_name_bytes(dn_254.as_bytes());
        assert!(matches!(res, Err(Error::DomainNameTooLong(s)) if s == reported));
    }
}