use std::{
fmt,
hash::{Hash, Hasher},
str::FromStr,
};
use crate::codec::{Error, reader::Reader, writer::Writer};
/// Longest single label, in bytes: the largest value of the 6-bit label
/// length field (RFC 1035 §2.3.4 limits labels to 63 octets).
const MAX_LABEL_LEN: usize = 0x3F;
/// Longest encoded name, in bytes, counting every length byte and the
/// terminating root byte (RFC 1035 §2.3.4 limits names to 255 octets).
const MAX_NAME_WIRE_LEN: usize = 255;
/// Maximum number of compression pointers `Name::skip_rr` follows for one name.
const MAX_SKIP_HOPS: usize = 16;
/// Cap on the total label bytes `Name::skip_rr` accumulates for one name.
const MAX_SKIP_BYTES: usize = 512;
/// A domain name kept in normalized textual form.
///
/// The text lives in a boxed `str`, so there is no spare capacity. Equality,
/// hashing and `Display` all operate on this normalized text.
#[derive(Debug, Clone)]
pub struct Name {
    /// Normalized text: ASCII-lowercased labels, each followed by `.`.
    /// The root name is stored as `"."`.
    inner: Box<str>,
}
impl Name {
    /// Returns the normalized textual form of the name.
    ///
    /// Names built by `read_question` and `FromStr` store ASCII letters
    /// lowercased, with every label followed by `.`. The root name is `"."`.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.inner
    }

    /// Wraps a string that is already in normalized form, without re-validating it.
    fn from_normalized(s: String) -> Self {
        Self {
            inner: s.into_boxed_str(),
        }
    }

    /// Reads an uncompressed name from the question section.
    ///
    /// Labels are consumed until the zero-length root label. Each label is
    /// ASCII-lowercased and appended followed by `.`. A bare root label yields `"."`.
    ///
    /// NOTE(review): each wire byte is converted with `u8 as char`, so bytes
    /// `>= 0x80` become the Unicode scalar with the same value. That scalar
    /// takes two UTF-8 bytes in the stored string. Also, a literal `0x2E`
    /// byte inside a label cannot be distinguished from a label separator.
    /// Because `write` re-encodes from the stored string, such names do not
    /// round-trip byte-for-byte.
    ///
    /// # Errors
    ///
    /// * `Error::CompressionPointerInQuestion` if a length byte has both high
    ///   bits set (a compression pointer).
    /// * `Error::LabelTooLong` if a length byte exceeds `MAX_LABEL_LEN`. This
    ///   also covers the `0x40`/`0x80` prefixed label types.
    /// * `Error::NameTooLong` if the encoded name exceeds `MAX_NAME_WIRE_LEN`
    ///   bytes, counting length bytes and the root byte.
    /// * Any error from `reader.read_u8` / `reader.read_slice` when the input
    ///   ends early.
    pub fn read_question(reader: &mut Reader) -> Result<Self, Error> {
        let mut normalized = String::with_capacity(64);
        // Starts at 1 to account for the terminating zero-length root label.
        let mut wire_len: usize = 1;
        loop {
            let len_byte = reader.read_u8()?;
            if len_byte & 0xC0 == 0xC0 {
                return Err(Error::CompressionPointerInQuestion);
            }
            let label_len = len_byte as usize;
            if label_len == 0 {
                break;
            }
            if label_len > MAX_LABEL_LEN {
                return Err(Error::LabelTooLong(label_len));
            }
            wire_len = wire_len
                .checked_add(1 + label_len)
                .ok_or(Error::NameTooLong(usize::MAX))?;
            if wire_len > MAX_NAME_WIRE_LEN {
                return Err(Error::NameTooLong(wire_len));
            }
            let label_bytes = reader.read_slice(label_len)?;
            for &b in label_bytes.iter() {
                normalized.push(b.to_ascii_lowercase() as char);
            }
            normalized.push('.');
        }
        if normalized.is_empty() {
            normalized.push('.');
        }
        Ok(Self::from_normalized(normalized))
    }

    /// Writes the name in uncompressed wire format.
    ///
    /// Each non-empty label is written as a length byte followed by the
    /// label's bytes, then a single zero byte terminates the name. The root
    /// name `"."` therefore encodes to just `0x00`.
    ///
    /// NOTE(review): the length byte is `label.len() as u8`, i.e. the UTF-8
    /// byte length of the stored label. For names whose stored text contains
    /// non-ASCII characters (see `read_question`), this differs from the
    /// original wire length and can exceed 63.
    pub fn write(&self, writer: &mut Writer) {
        for label in self.inner.split('.').filter(|label| !label.is_empty()) {
            writer.write_u8(label.len() as u8);
            writer.write_slice(label.as_bytes());
        }
        writer.write_u8(0);
    }

    /// Advances `reader` past a possibly compressed name in a resource record,
    /// without decoding it.
    ///
    /// Compression pointers are followed through the full message buffer so the
    /// whole name can be validated. However, `reader` is advanced only over the
    /// bytes of this occurrence of the name: through the terminating zero byte,
    /// or through the first pointer, whichever comes first.
    ///
    /// Assumes `reader.position()` is an absolute offset into `reader.as_bytes()`,
    /// since pointer targets are resolved against that buffer. TODO confirm
    /// against `Reader`.
    ///
    /// Limits enforced while walking:
    /// * a pointer must target an offset inside the message and strictly before
    ///   the pointer itself;
    /// * at most `MAX_SKIP_HOPS` pointers are followed;
    /// * the uncompressed wire length (length bytes + label bytes + root byte)
    ///   may not exceed `MAX_NAME_WIRE_LEN`, counted exactly as `read_question`
    ///   counts it.
    ///
    /// On error, `reader` may already have been advanced past part of the name.
    ///
    /// # Errors
    ///
    /// * `Error::UnexpectedEof` if the message ends inside the name.
    /// * `Error::InvalidPointerTarget` for a pointer targeting itself, a later
    ///   offset, or an offset outside the message.
    /// * `Error::NameSkipLimitExceeded` after more than `MAX_SKIP_HOPS` pointers,
    ///   or more than `MAX_SKIP_BYTES` label bytes.
    /// * `Error::LabelTooLong` for a length byte with a `0x40` or `0x80` prefix.
    /// * `Error::NameTooLong` if the uncompressed name exceeds `MAX_NAME_WIRE_LEN`.
    /// * Any error from `reader.read_slice` while advancing the reader.
    pub fn skip_rr(reader: &mut Reader) -> Result<(), Error> {
        // Pointers can target any earlier offset, so walk a handle to the whole
        // message rather than the reader itself.
        let msg = reader.as_bytes().clone();
        let msg_len = msg.len();
        // Scan position in `msg`. It diverges from the reader after the first pointer.
        let mut cur_pos = reader.position();
        // Set once the reader has been advanced past the first pointer. From then
        // on only `cur_pos` moves.
        let mut fixed_reader = false;
        let mut hops: usize = 0;
        let mut total_label_bytes: usize = 0;
        // Uncompressed wire length so far: the root byte, plus one length byte and
        // the content of every label (same accounting as `read_question`).
        let mut wire_len: usize = 1;
        loop {
            let len_byte = msg.get(cur_pos).copied().ok_or(Error::UnexpectedEof {
                offset: cur_pos,
                needed: 1,
                available: msg_len.saturating_sub(cur_pos),
            })?;
            cur_pos += 1;

            if len_byte & 0xC0 == 0xC0 {
                let low_byte = msg.get(cur_pos).copied().ok_or(Error::UnexpectedEof {
                    offset: cur_pos,
                    needed: 1,
                    available: msg_len.saturating_sub(cur_pos),
                })?;
                cur_pos += 1;
                if !fixed_reader {
                    // The name occupies the reader's bytes only up to and
                    // including this first pointer.
                    reader.read_slice(cur_pos - reader.position())?;
                    fixed_reader = true;
                }
                let target = u16::from_be_bytes([len_byte & 0x3F, low_byte]) as usize;
                let pointer_start = cur_pos - 2;
                // Accept only strictly backward targets inside the message. This
                // rejects self- and forward pointers outright.
                if target >= msg_len || target >= pointer_start {
                    return Err(Error::InvalidPointerTarget {
                        target: target as u16,
                        msg_len,
                    });
                }
                // Backward-only pointers can still cycle: jump back, read labels,
                // then reach the same pointer again. The hop limit bounds that.
                hops += 1;
                if hops > MAX_SKIP_HOPS {
                    return Err(Error::NameSkipLimitExceeded);
                }
                cur_pos = target;
                continue;
            }

            // The 0b01 / 0b10 prefixes are not ordinary labels.
            if len_byte & 0xC0 != 0 {
                return Err(Error::LabelTooLong(len_byte as usize));
            }
            let label_len = len_byte as usize;
            if label_len == 0 {
                if !fixed_reader {
                    reader.read_slice(cur_pos - reader.position())?;
                }
                return Ok(());
            }
            if label_len > MAX_LABEL_LEN {
                return Err(Error::LabelTooLong(label_len));
            }
            // Independent budget on label bytes scanned. For a single name the
            // wire-length check below is the tighter limit, so this is a backstop.
            total_label_bytes = total_label_bytes.saturating_add(label_len);
            if total_label_bytes > MAX_SKIP_BYTES {
                return Err(Error::NameSkipLimitExceeded);
            }
            // Count the length byte as well as the content. For example, four
            // 63-byte labels are 4 * 64 + 1 = 257 bytes on the wire.
            wire_len = wire_len.saturating_add(1 + label_len);
            if wire_len > MAX_NAME_WIRE_LEN {
                return Err(Error::NameTooLong(wire_len));
            }
            cur_pos = cur_pos
                .checked_add(label_len)
                .ok_or(Error::NameSkipLimitExceeded)?;
            if cur_pos > msg_len {
                return Err(Error::UnexpectedEof {
                    offset: cur_pos - label_len,
                    needed: label_len,
                    available: msg_len.saturating_sub(cur_pos - label_len),
                });
            }
            if !fixed_reader {
                reader.read_slice(1 + label_len)?;
            }
        }
    }
}
/// Names compare by their normalized text. Because constructors lowercase
/// ASCII letters and canonicalize the trailing dot, `"Example.COM"` equals
/// `"example.com."`.
impl PartialEq for Name {
    fn eq(&self, other: &Self) -> bool {
        self.as_str() == other.as_str()
    }
}

impl Eq for Name {}

/// Hashes exactly the text that `PartialEq` compares, so equal names always
/// hash equally.
impl Hash for Name {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.as_str().hash(state);
    }
}
/// Displays the normalized text: lowercase labels with a trailing dot, or
/// `"."` for the root.
impl fmt::Display for Name {
    /// Writes the text verbatim. `write_str` does not pad, so width, fill and
    /// precision flags have no effect on the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text: &str = &self.inner;
        f.write_str(text)
    }
}
impl FromStr for Name {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if s == "." || s.is_empty() {
return Ok(Self::from_normalized(".".to_string()));
}
let s_stripped = s.strip_suffix('.').unwrap_or(s);
let mut normalized = String::with_capacity(s.len() + 1);
let mut wire_len: usize = 1;
for label in s_stripped.split('.') {
if label.is_empty() {
return Err(Error::EmptyLabel);
}
let label_len = label.len();
if label_len > MAX_LABEL_LEN {
return Err(Error::LabelTooLong(label_len));
}
wire_len = wire_len
.checked_add(1 + label_len)
.ok_or(Error::NameTooLong(usize::MAX))?;
if wire_len > MAX_NAME_WIRE_LEN {
return Err(Error::NameTooLong(wire_len));
}
for c in label.chars() {
normalized.push(c.to_ascii_lowercase());
}
normalized.push('.');
}
Ok(Self::from_normalized(normalized))
}
}
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use bytes::Bytes;

    use super::*;
    use crate::codec::{reader::Reader, writer::Writer};

    // Encodes `name` with `Name::write` and returns the finished buffer.
    fn wire_encode(name: &Name) -> Bytes {
        let mut w = Writer::new();
        name.write(&mut w);
        w.finish()
    }

    // Wraps a static byte slice in a `Reader`.
    fn reader_from(bytes: &'static [u8]) -> Reader {
        Reader::new(Bytes::from_static(bytes))
    }

    // ----- Parsing from text (`FromStr`) -----

    // A name without a trailing dot gains one.
    #[test]
    fn parse_simple() {
        let n: Name = "example.com".parse().unwrap();
        assert_eq!(n.to_string(), "example.com.");
    }

    // An explicit trailing dot is accepted and not doubled.
    #[test]
    fn parse_with_trailing_dot() {
        let n: Name = "example.com.".parse().unwrap();
        assert_eq!(n.to_string(), "example.com.");
    }

    #[test]
    fn parse_root_dot() {
        let n: Name = ".".parse().unwrap();
        assert_eq!(n.to_string(), ".");
    }

    // The empty string is treated as the root name.
    #[test]
    fn parse_root_empty_str() {
        let n: Name = "".parse().unwrap();
        assert_eq!(n.to_string(), ".");
    }

    #[test]
    fn parse_single_label() {
        let n: Name = "localhost".parse().unwrap();
        assert_eq!(n.to_string(), "localhost.");
    }

    // ----- Case normalization, equality and hashing -----

    #[test]
    fn normalization_mixed_case() {
        let n: Name = "Example.COM".parse().unwrap();
        assert_eq!(n.to_string(), "example.com.");
    }

    #[test]
    fn normalization_uppercase_all() {
        let n: Name = "UPPER.CASE.LABELS".parse().unwrap();
        assert_eq!(n.to_string(), "upper.case.labels.");
    }

    // Case and trailing-dot differences in the input do not affect equality.
    #[test]
    fn eq_case_insensitive() {
        let a: Name = "Example.COM".parse().unwrap();
        let b: Name = "example.com".parse().unwrap();
        let c: Name = "example.com.".parse().unwrap();
        assert_eq!(a, b);
        assert_eq!(b, c);
        assert_eq!(a, c);
    }

    // `insert` returns false because an equal name is already in the set.
    #[test]
    fn hash_consistent_with_eq() {
        let a: Name = "Example.COM".parse().unwrap();
        let b: Name = "example.com.".parse().unwrap();
        let mut set = HashSet::new();
        set.insert(a.clone());
        assert!(!set.insert(b));
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn hashset_lookup_case_insensitive() {
        let mut set: HashSet<Name> = HashSet::new();
        set.insert("blocked.example.com.".parse().unwrap());
        let query: Name = "BLOCKED.EXAMPLE.COM".parse().unwrap();
        assert!(set.contains(&query));
    }

    // ----- Length limits when parsing text -----

    // 64 bytes is one over the per-label limit.
    #[test]
    fn label_too_long_from_str() {
        let long_label = "a".repeat(64);
        let err = Name::from_str(&long_label).unwrap_err();
        assert!(
            matches!(err, Error::LabelTooLong(64)),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn label_exactly_63_ok() {
        let label = "a".repeat(63);
        let n = Name::from_str(&label).unwrap();
        assert!(n.to_string().starts_with(&label));
    }

    // Four 63-byte labels encode to 4 * 64 + 1 = 257 bytes, over the 255 limit.
    #[test]
    fn name_too_long_from_str() {
        let label = "a".repeat(63);
        let long_name = format!("{label}.{label}.{label}.{label}");
        let err = Name::from_str(&long_name).unwrap_err();
        assert!(
            matches!(err, Error::NameTooLong(_)),
            "unexpected error: {err}"
        );
    }

    // Three 63-byte labels encode to 3 * 64 + 1 = 193 bytes. This is
    // comfortably under the limit, not the exact maximum.
    #[test]
    fn name_max_length_ok() {
        let label = "a".repeat(63);
        let name = format!("{label}.{label}.{label}");
        assert!(Name::from_str(&name).is_ok());
    }

    #[test]
    fn empty_label_in_middle_is_error() {
        let err = Name::from_str("foo..bar").unwrap_err();
        assert!(matches!(err, Error::EmptyLabel), "unexpected error: {err}");
    }

    // ----- Wire round trips through `write` and `read_question` -----

    #[test]
    fn wire_round_trip_simple() {
        let original: Name = "example.com".parse().unwrap();
        let wire = wire_encode(&original);
        let mut r = Reader::new(wire);
        let decoded = Name::read_question(&mut r).unwrap();
        assert_eq!(original, decoded);
    }

    // The root name encodes to a single zero byte.
    #[test]
    fn wire_round_trip_root() {
        let original: Name = ".".parse().unwrap();
        let wire = wire_encode(&original);
        assert_eq!(&wire[..], &[0x00]);
        let mut r = Reader::new(wire);
        let decoded = Name::read_question(&mut r).unwrap();
        assert_eq!(original, decoded);
    }

    // Layout: length byte 9, the 9 label bytes, then the root terminator.
    #[test]
    fn wire_round_trip_single_label() {
        let original: Name = "localhost".parse().unwrap();
        let wire = wire_encode(&original);
        assert_eq!(wire[0], 9);
        assert_eq!(&wire[1..10], b"localhost");
        assert_eq!(wire[10], 0);
        let mut r = Reader::new(wire);
        let decoded = Name::read_question(&mut r).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn wire_round_trip_multi_label() {
        let original: Name = "a.b.c.d".parse().unwrap();
        let wire = wire_encode(&original);
        let mut r = Reader::new(wire);
        let decoded = Name::read_question(&mut r).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn wire_round_trip_mixed_case_normalizes() {
        let original: Name = "UPPER.CASE".parse().unwrap();
        let wire = wire_encode(&original);
        let mut r = Reader::new(wire);
        let decoded = Name::read_question(&mut r).unwrap();
        assert_eq!(decoded.to_string(), "upper.case.");
    }

    // ----- `read_question` rejection cases -----

    // 0xC0 0x0C is a compression pointer; question names must be uncompressed.
    #[test]
    fn compression_pointer_in_question_rejected() {
        let mut r = reader_from(&[0xC0, 0x0C]);
        let err = Name::read_question(&mut r).unwrap_err();
        assert!(
            matches!(err, Error::CompressionPointerInQuestion),
            "unexpected error: {err}"
        );
    }

    // A pointer that follows a valid label is rejected too.
    #[test]
    fn compression_pointer_mid_question_rejected() {
        let mut r = reader_from(&[0x03, b'f', b'o', b'o', 0xC0, 0x0C]);
        let err = Name::read_question(&mut r).unwrap_err();
        assert!(
            matches!(err, Error::CompressionPointerInQuestion),
            "unexpected error: {err}"
        );
    }

    // A length byte of 64 is one over the per-label limit.
    #[test]
    fn wire_label_too_long_rejected() {
        let mut data = vec![64u8];
        data.extend_from_slice(&[b'a'; 64]);
        data.push(0);
        let mut r = Reader::new(Bytes::from(data));
        let err = Name::read_question(&mut r).unwrap_err();
        assert!(
            matches!(err, Error::LabelTooLong(64)),
            "unexpected error: {err}"
        );
    }

    // ----- `skip_rr` -----

    // "www.example.com." occupies 17 bytes. The trailing 0xFF must stay unread.
    #[test]
    fn skip_rr_simple_name_no_pointer() {
        let wire: &[u8] = &[
            0x03, b'w', b'w', b'w', 0x07, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 0x03, b'c',
            b'o', b'm', 0x00, 0xFF,
        ];
        let mut r = Reader::new(Bytes::from_static(wire));
        Name::skip_rr(&mut r).unwrap();
        assert_eq!(r.position(), 17);
    }

    #[test]
    fn skip_rr_root_name() {
        let wire: &[u8] = &[0x00, 0xFF];
        let mut r = Reader::new(Bytes::from_static(wire));
        Name::skip_rr(&mut r).unwrap();
        assert_eq!(r.position(), 1);
    }

    #[test]
    fn skip_rr_name_ending_in_pointer() {
        // 12 zero bytes of leading padding (the size of a DNS header), so the
        // offsets below look like a real message.
        let mut msg = vec![0u8; 12];
        // Offset 12: "com." (5 bytes, ending at offset 17).
        msg.extend_from_slice(&[0x03, b'c', b'o', b'm', 0x00]);
        // Offsets 17..20: filler.
        msg.extend_from_slice(&[0x00, 0x00, 0x00]);
        // Offset 20: label "example", then a pointer back to offset 12.
        msg.extend_from_slice(&[0x07, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 0xC0, 0x0C]);
        // Sentinel byte after the name; skipping must stop before it.
        msg.push(0xAB);
        let mut r = Reader::new(Bytes::from(msg));
        r.read_slice(20).unwrap();
        assert_eq!(r.position(), 20);
        Name::skip_rr(&mut r).unwrap();
        // The reader stops right after the 2-byte pointer (20 + 8 + 2), even
        // though validation followed the pointer to "com.".
        assert_eq!(r.position(), 30);
    }

    // A pointer at offset 12 that targets offset 12 (itself).
    #[test]
    fn skip_rr_pointer_loop_self_terminates() {
        let mut msg = vec![0u8; 12];
        msg.extend_from_slice(&[0xC0, 0x0C]);
        let mut r = Reader::new(Bytes::from(msg));
        r.read_slice(12).unwrap();
        let err = Name::skip_rr(&mut r).unwrap_err();
        assert!(
            matches!(
                err,
                Error::InvalidPointerTarget { .. } | Error::NameSkipLimitExceeded
            ),
            "expected pointer loop to return an error, got: {err}"
        );
    }

    // Offset 12 points to 14 and offset 14 points back to 12. Either error
    // kind is accepted.
    #[test]
    fn skip_rr_pointer_two_cycle_terminates() {
        let mut msg = vec![0u8; 12];
        msg.extend_from_slice(&[0xC0, 0x0E]);
        msg.extend_from_slice(&[0xC0, 0x0C]);
        let mut r = Reader::new(Bytes::from(msg));
        r.read_slice(12).unwrap();
        let err = Name::skip_rr(&mut r).unwrap_err();
        assert!(
            matches!(
                err,
                Error::InvalidPointerTarget { .. } | Error::NameSkipLimitExceeded
            ),
            "expected two-cycle loop to error, got: {err}"
        );
    }

    // The pointer at offset 12 targets 0x14 = 20, which lies after the pointer.
    #[test]
    fn skip_rr_forward_pointer_rejected() {
        let mut msg = vec![0u8; 12];
        msg.extend_from_slice(&[0xC0, 0x14]);
        let mut r = Reader::new(Bytes::from(msg));
        r.read_slice(12).unwrap();
        let err = Name::skip_rr(&mut r).unwrap_err();
        assert!(
            matches!(err, Error::InvalidPointerTarget { target: 20, .. }),
            "unexpected error: {err}"
        );
    }

    // Bytes 0xE7 0x0F give target 0x270F (9999), far beyond the 14-byte message.
    #[test]
    fn skip_rr_out_of_bounds_pointer_rejected() {
        let mut msg = vec![0u8; 12];
        msg.extend_from_slice(&[0xC0 | 0x27, 0x0F]);
        let mut r = Reader::new(Bytes::from(msg));
        r.read_slice(12).unwrap();
        let err = Name::skip_rr(&mut r).unwrap_err();
        assert!(
            matches!(err, Error::InvalidPointerTarget { .. }),
            "unexpected error: {err}"
        );
    }

    // The label claims 5 bytes but only 2 follow.
    #[test]
    fn skip_rr_truncated_label_content() {
        let wire: &[u8] = &[0x05, b'a', b'b'];
        let mut r = Reader::new(Bytes::from_static(wire));
        let err = Name::skip_rr(&mut r).unwrap_err();
        assert!(
            matches!(err, Error::UnexpectedEof { .. }),
            "unexpected error: {err}"
        );
    }

    // The first pointer byte is present but the second is missing.
    #[test]
    fn skip_rr_truncated_pointer() {
        let wire: &[u8] = &[0xC0];
        let mut r = Reader::new(Bytes::from_static(wire));
        let err = Name::skip_rr(&mut r).unwrap_err();
        assert!(
            matches!(err, Error::UnexpectedEof { .. }),
            "unexpected error: {err}"
        );
    }

    // ----- Robustness: degenerate input must error, never panic -----

    #[test]
    fn no_panic_empty_buffer() {
        let mut r = reader_from(&[]);
        assert!(Name::read_question(&mut r).is_err());
    }

    #[test]
    fn no_panic_skip_empty_buffer() {
        let mut r = reader_from(&[]);
        assert!(Name::skip_rr(&mut r).is_err());
    }

    // 0xFF has both high bits set, so the first byte already reads as a
    // pointer. The result is ignored; only the absence of a panic is checked.
    #[test]
    fn no_panic_all_ones() {
        let data = vec![0xFFu8; 512];
        let mut r = Reader::new(Bytes::from(data));
        let _ = Name::read_question(&mut r);
    }

    // Each 0xFF 0xFF pair decodes as a pointer to 0x3FFF, outside the 512-byte
    // buffer. Only the absence of a panic is checked.
    #[test]
    fn no_panic_skip_all_ones() {
        let data = vec![0xFFu8; 512];
        let mut r = Reader::new(Bytes::from(data));
        let _ = Name::skip_rr(&mut r);
    }
}