use super::super::octets::{EmptyBuilder, OctetsBuilder, ShortBuf};
use super::dname::Dname;
use super::relative::{RelativeDname, RelativeDnameError};
use super::traits::{ToDname, ToRelativeDname};
#[cfg(feature = "bytes")]
use bytes::BytesMut;
use core::{fmt, ops};
#[cfg(feature = "std")]
use std::vec::Vec;
/// Builds a domain name step by step by appending data.
///
/// The octets of the name are kept in wire format inside an octets builder
/// of type `Builder`. While a label is being assembled, its length octet is
/// a placeholder that is filled in once the label is ended.
#[derive(Clone, Debug)]
pub struct DnameBuilder<Builder> {
    /// The octets of the name assembled so far.
    builder: Builder,

    /// Index of the length octet of the label currently being assembled,
    /// or `None` if no label is in progress.
    head: Option<usize>,
}
impl<Builder> DnameBuilder<Builder> {
    /// Creates a builder on top of `builder` without inspecting its content.
    ///
    /// No label is considered to be in progress afterwards.
    ///
    /// # Safety
    ///
    /// The caller must make sure that the octets already held by `builder`
    /// form a valid relative domain name in wire format: the other methods
    /// of this type append to that content and eventually turn it into a
    /// name without re-checking it.
    pub(super) unsafe fn from_builder_unchecked(builder: Builder) -> Self {
        Self { builder, head: None }
    }

    /// Creates a new, empty builder.
    pub fn new() -> Self
    where
        Builder: EmptyBuilder,
    {
        let empty = Builder::empty();
        // SAFETY: `EmptyBuilder::empty` is expected to yield a builder
        // without content, which is the empty relative name.
        unsafe { Self::from_builder_unchecked(empty) }
    }

    /// Creates a new, empty builder with room for `capacity` octets.
    pub fn with_capacity(capacity: usize) -> Self
    where
        Builder: EmptyBuilder,
    {
        let empty = Builder::with_capacity(capacity);
        // SAFETY: as in `new`, the fresh builder is expected to be empty.
        unsafe { Self::from_builder_unchecked(empty) }
    }

    /// Creates a builder on top of an existing octets builder.
    ///
    /// # Errors
    ///
    /// Fails if the content of `builder` is not a valid relative domain
    /// name according to `RelativeDname::check_slice`.
    pub fn from_builder(builder: Builder) -> Result<Self, RelativeDnameError>
    where
        Builder: OctetsBuilder + AsRef<[u8]>,
    {
        RelativeDname::check_slice(builder.as_ref())?;
        // SAFETY: the content was verified to be a relative name just above.
        let checked = unsafe { Self::from_builder_unchecked(builder) };
        Ok(checked)
    }
}
#[cfg(feature = "std")]
impl DnameBuilder<Vec<u8>> {
    /// Creates a new, empty builder backed by a `Vec<u8>`.
    pub fn new_vec() -> Self {
        DnameBuilder::new()
    }

    /// Creates a new, empty `Vec<u8>`-backed builder with room for
    /// `capacity` octets.
    pub fn vec_with_capacity(capacity: usize) -> Self {
        DnameBuilder::with_capacity(capacity)
    }
}
#[cfg(feature = "bytes")]
impl DnameBuilder<BytesMut> {
    /// Creates a new, empty builder backed by a `BytesMut`.
    pub fn new_bytes() -> Self {
        DnameBuilder::new()
    }

    /// Creates a new, empty `BytesMut`-backed builder with room for
    /// `capacity` octets.
    pub fn bytes_with_capacity(capacity: usize) -> Self {
        DnameBuilder::with_capacity(capacity)
    }
}
impl<Builder: OctetsBuilder> DnameBuilder<Builder> {
    /// Returns the number of octets held by the builder.
    ///
    /// While a label is in progress, this includes its placeholder length
    /// octet.
    pub fn len(&self) -> usize {
        let octets = &self.builder;
        octets.len()
    }

    /// Returns whether the builder holds no octets at all.
    pub fn is_empty(&self) -> bool {
        let octets = &self.builder;
        octets.is_empty()
    }
}
impl<Builder: OctetsBuilder + AsMut<[u8]>> DnameBuilder<Builder> {
pub fn in_label(&self) -> bool {
self.head.is_some()
}
pub fn push(&mut self, ch: u8) -> Result<(), PushError> {
let len = self.len();
if len >= 254 {
return Err(PushError::LongName);
}
if let Some(head) = self.head {
if len - head > 63 {
return Err(PushError::LongLabel);
}
self.builder.append_slice(&[ch])?;
} else {
self.head = Some(len);
self.builder.append_slice(&[0, ch])?;
}
Ok(())
}
pub fn append_slice(&mut self, slice: &[u8]) -> Result<(), PushError> {
if slice.is_empty() {
return Ok(());
}
if let Some(head) = self.head {
if slice.len() > 63 - (self.len() - head) {
return Err(PushError::LongLabel);
}
} else {
if slice.len() > 63 {
return Err(PushError::LongLabel);
}
if self.len() + slice.len() > 254 {
return Err(PushError::LongName);
}
self.head = Some(self.len());
self.builder.append_slice(&[0])?;
}
self.builder.append_slice(slice)?;
Ok(())
}
pub fn end_label(&mut self) {
if let Some(head) = self.head {
let len = self.len() - head - 1;
self.builder.as_mut()[head] = len as u8;
self.head = None;
}
}
pub fn append_label(&mut self, label: &[u8]) -> Result<(), PushError> {
let head = self.head;
self.end_label();
if let Err(err) = self.append_slice(label) {
self.head = head;
return Err(err);
}
self.end_label();
Ok(())
}
pub fn append_name<N: ToRelativeDname>(
&mut self,
name: &N,
) -> Result<(), PushNameError> {
let head = self.head.take();
self.end_label();
if self.len() + name.len() > 254 {
self.head = head;
return Err(PushNameError::LongName);
}
for label in name.iter_labels() {
label.build(&mut self.builder)?
}
Ok(())
}
pub fn append_chars<C: IntoIterator<Item = char>>(
&mut self,
chars: C,
) -> Result<(), FromStrError> {
let mut chars = chars.into_iter();
while let Some(ch) = chars.next() {
match ch {
'.' => {
if !self.in_label() {
return Err(FromStrError::EmptyLabel);
}
self.end_label();
}
'\\' => {
let in_label = self.in_label();
self.push(parse_escape(&mut chars, in_label)?)?;
}
' '..='-' | '/'..='[' | ']'..='~' => self.push(ch as u8)?,
_ => return Err(FromStrError::IllegalCharacter(ch)),
}
}
Ok(())
}
pub fn finish(mut self) -> RelativeDname<Builder::Octets> {
self.end_label();
unsafe { RelativeDname::from_octets_unchecked(self.builder.freeze()) }
}
pub fn into_dname(mut self) -> Result<Dname<Builder::Octets>, PushError> {
self.end_label();
self.builder.append_slice(&[0])?;
Ok(unsafe { Dname::from_octets_unchecked(self.builder.freeze()) })
}
pub fn append_origin<N: ToDname>(
mut self,
origin: &N,
) -> Result<Dname<Builder::Octets>, PushNameError> {
self.end_label();
if self.len() + origin.len() > 255 {
return Err(PushNameError::LongName);
}
for label in origin.iter_labels() {
label.build(&mut self.builder)?
}
Ok(unsafe { Dname::from_octets_unchecked(self.builder.freeze()) })
}
}
impl<Builder: EmptyBuilder> Default for DnameBuilder<Builder> {
fn default() -> Self {
Self::new()
}
}
impl<Builder: AsRef<[u8]>> ops::Deref for DnameBuilder<Builder> {
    type Target = [u8];

    /// Gives access to the octets assembled so far.
    fn deref(&self) -> &Self::Target {
        AsRef::<[u8]>::as_ref(&self.builder)
    }
}
impl<Builder: AsRef<[u8]>> AsRef<[u8]> for DnameBuilder<Builder> {
    /// Returns the octets assembled so far.
    fn as_ref(&self) -> &[u8] {
        let octets: &[u8] = self.builder.as_ref();
        octets
    }
}
/// Parses an escape sequence in a presentation-format domain name.
///
/// `chars` must be positioned right after the backslash. The function
/// returns the octet the sequence stands for. Two forms are accepted:
///
/// * `\DDD` with exactly three decimal digits and a value of at most 255,
/// * `\X` for any other ASCII character `X`, which stands for `X` itself.
///
/// `in_label` tells whether a label is already in progress. `\[` at the
/// start of a label is rejected as a binary label; within a label it simply
/// stands for `[`.
///
/// # Errors
///
/// * `LabelFromStrError::UnexpectedEnd` if `chars` runs out,
/// * `LabelFromStrError::IllegalEscape` if a `\DDD` sequence contains a
///   non-digit or exceeds 255,
/// * `LabelFromStrError::BinaryLabel` for `\[` at the start of a label,
/// * `LabelFromStrError::IllegalCharacter` if the escaped character is not
///   ASCII and therefore does not fit into a single octet.
pub(super) fn parse_escape<C>(
    chars: &mut C,
    in_label: bool,
) -> Result<u8, LabelFromStrError>
where
    C: Iterator<Item = char>,
{
    let ch = chars.next().ok_or(LabelFromStrError::UnexpectedEnd)?;
    if let Some(first) = ch.to_digit(10) {
        // Decimal escape: the first digit is followed by exactly two more.
        let mut value = first;
        for _ in 0..2 {
            let digit = chars
                .next()
                .ok_or(LabelFromStrError::UnexpectedEnd)?
                .to_digit(10)
                .ok_or(LabelFromStrError::IllegalEscape)?;
            value = value * 10 + digit;
        }
        if value > 255 {
            return Err(LabelFromStrError::IllegalEscape);
        }
        Ok(value as u8)
    } else if ch == '[' {
        if in_label {
            Ok(b'[')
        } else {
            Err(LabelFromStrError::BinaryLabel)
        }
    } else if ch.is_ascii() {
        Ok(ch as u8)
    } else {
        // A cast would silently truncate the code point (e.g. U+20AC to
        // 0xAC), so non-ASCII characters are rejected instead.
        Err(LabelFromStrError::IllegalCharacter(ch))
    }
}
/// An error happened while appending data to a `DnameBuilder`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PushError {
    /// A label would have exceeded 63 octets.
    LongLabel,

    /// The name would have exceeded its length limit.
    LongName,

    /// The underlying octets builder reported `ShortBuf`.
    ShortBuf,
}

impl From<ShortBuf> for PushError {
    fn from(_: ShortBuf) -> Self {
        Self::ShortBuf
    }
}

impl fmt::Display for PushError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let msg = match self {
            PushError::LongLabel => "long label",
            PushError::LongName => "long domain name",
            PushError::ShortBuf => return ShortBuf.fmt(f),
        };
        f.write_str(msg)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for PushError {}
/// An error happened while appending a complete name to a `DnameBuilder`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PushNameError {
    /// The resulting name would have exceeded its length limit.
    LongName,

    /// The underlying octets builder reported `ShortBuf`.
    ShortBuf,
}

impl From<ShortBuf> for PushNameError {
    fn from(_: ShortBuf) -> Self {
        Self::ShortBuf
    }
}

impl fmt::Display for PushNameError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let msg = match self {
            PushNameError::LongName => "long domain name",
            PushNameError::ShortBuf => return ShortBuf.fmt(f),
        };
        f.write_str(msg)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for PushNameError {}
/// An error happened while parsing a label from its presentation format.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LabelFromStrError {
    /// The input ended unexpectedly, e.g. inside an escape sequence.
    UnexpectedEnd,

    /// A binary label (`\[` at the start of a label) was encountered.
    BinaryLabel,

    /// The label length limit was exceeded.
    LongLabel,

    /// An illegal escape sequence was encountered.
    IllegalEscape,

    /// An illegal character was encountered.
    IllegalCharacter(char),
}

impl fmt::Display for LabelFromStrError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let msg = match self {
            LabelFromStrError::UnexpectedEnd => "unexpected end of input",
            LabelFromStrError::BinaryLabel => "a binary label was encountered",
            LabelFromStrError::LongLabel => "label length limit exceeded",
            LabelFromStrError::IllegalEscape => "illegal escape sequence",
            LabelFromStrError::IllegalCharacter(ch) => {
                return write!(f, "illegal character '{}'", ch)
            }
        };
        f.write_str(msg)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for LabelFromStrError {}
/// An error happened while parsing a domain name from its presentation
/// format.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum FromStrError {
    /// The input ended unexpectedly.
    UnexpectedEnd,

    /// An empty label was encountered.
    EmptyLabel,

    /// A binary label was encountered.
    BinaryLabel,

    /// The label length limit was exceeded.
    LongLabel,

    /// An illegal escape sequence was encountered.
    IllegalEscape,

    /// An illegal character was encountered.
    IllegalCharacter(char),

    /// The name length limit was exceeded.
    LongName,

    /// The underlying octets builder reported `ShortBuf`.
    ShortBuf,
}

impl From<PushError> for FromStrError {
    fn from(err: PushError) -> FromStrError {
        match err {
            PushError::LongLabel => Self::LongLabel,
            PushError::LongName => Self::LongName,
            PushError::ShortBuf => Self::ShortBuf,
        }
    }
}

impl From<PushNameError> for FromStrError {
    fn from(err: PushNameError) -> FromStrError {
        match err {
            PushNameError::LongName => Self::LongName,
            PushNameError::ShortBuf => Self::ShortBuf,
        }
    }
}

impl From<LabelFromStrError> for FromStrError {
    fn from(err: LabelFromStrError) -> FromStrError {
        match err {
            LabelFromStrError::UnexpectedEnd => Self::UnexpectedEnd,
            LabelFromStrError::BinaryLabel => Self::BinaryLabel,
            LabelFromStrError::LongLabel => Self::LongLabel,
            LabelFromStrError::IllegalEscape => Self::IllegalEscape,
            LabelFromStrError::IllegalCharacter(ch) => Self::IllegalCharacter(ch),
        }
    }
}

impl fmt::Display for FromStrError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let msg = match self {
            FromStrError::UnexpectedEnd => "unexpected end of input",
            FromStrError::EmptyLabel => "an empty label was encountered",
            FromStrError::BinaryLabel => "a binary label was encountered",
            FromStrError::LongLabel => "label length limit exceeded",
            FromStrError::IllegalEscape => "illegal escape sequence",
            FromStrError::IllegalCharacter(ch) => {
                return write!(f, "illegal character '{}'", ch)
            }
            FromStrError::LongName => "long domain name",
            FromStrError::ShortBuf => return ShortBuf.fmt(f),
        };
        f.write_str(msg)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for FromStrError {}
#[cfg(test)]
#[cfg(feature = "std")]
mod test {
    use super::*;

    /// The relative name `www.example.com` in wire format.
    const WWW_EXAMPLE_COM: &[u8] = b"\x03www\x07example\x03com";

    #[test]
    fn build() {
        let mut builder = DnameBuilder::new_vec();

        // "www": one pushed octet followed by a slice.
        builder.push(b'w').unwrap();
        builder.append_slice(b"ww").unwrap();
        builder.end_label();

        // "example": slices and single octets interleaved.
        builder.append_slice(b"exa").unwrap();
        for &octet in b"mp".iter() {
            builder.push(octet).unwrap();
        }
        builder.append_slice(b"le").unwrap();
        builder.end_label();

        // "com": left open, so `finish` has to close it.
        builder.append_slice(b"com").unwrap();

        assert_eq!(builder.finish().as_slice(), WWW_EXAMPLE_COM);
    }

    #[test]
    fn build_by_label() {
        let labels: [&[u8]; 3] = [b"www", b"example", b"com"];
        let mut builder = DnameBuilder::new_vec();
        for &label in labels.iter() {
            builder.append_label(label).unwrap();
        }
        assert_eq!(builder.finish().as_slice(), WWW_EXAMPLE_COM);
    }

    #[test]
    fn build_mixed() {
        let mut builder = DnameBuilder::new_vec();
        builder.push(b'w').unwrap();
        builder.append_slice(b"ww").unwrap();
        // `append_label` closes the open label before adding its own.
        builder.append_label(b"example").unwrap();
        builder.append_slice(b"com").unwrap();
        assert_eq!(builder.finish().as_slice(), WWW_EXAMPLE_COM);
    }

    #[test]
    fn name_limit() {
        let mut builder = DnameBuilder::new_vec();
        // Each nine-octet label takes ten octets on the wire: 250 in total.
        for _ in 0..25 {
            builder.append_label(b"123456789").unwrap();
        }

        assert_eq!(builder.append_label(b"12345"), Err(PushError::LongName));
        let mut probe = builder.clone();
        assert_eq!(probe.append_label(b"1234"), Ok(()));

        assert_eq!(builder.append_slice(b"12345"), Err(PushError::LongName));
        let mut probe = builder.clone();
        assert_eq!(probe.append_slice(b"1234"), Ok(()));

        assert_eq!(builder.append_slice(b"12"), Ok(()));
        assert_eq!(builder.push(b'3'), Ok(()));
        assert_eq!(builder.push(b'4'), Err(PushError::LongName));
    }

    #[test]
    fn label_limit() {
        let mut builder = DnameBuilder::new_vec();
        builder.append_label(&[0u8; 63][..]).unwrap();

        // Labels longer than 63 octets are refused outright.
        assert_eq!(
            builder.append_label(&[0u8; 64][..]),
            Err(PushError::LongLabel)
        );
        assert_eq!(
            builder.append_label(&[0u8; 164][..]),
            Err(PushError::LongLabel)
        );

        // Start a 60-octet label and fill it up to the limit.
        builder.append_slice(&[0u8; 60][..]).unwrap();
        let mut probe = builder.clone();
        probe.append_label(b"123").unwrap();
        assert_eq!(builder.append_slice(b"1234"), Err(PushError::LongLabel));
        builder.append_slice(b"12").unwrap();
        builder.push(b'3').unwrap();
        assert_eq!(builder.push(b'4'), Err(PushError::LongLabel));
    }

    #[test]
    fn finish() {
        let mut builder = DnameBuilder::new_vec();
        builder.append_label(b"www").unwrap();
        builder.append_label(b"example").unwrap();
        // The last label is still open when `finish` is called.
        builder.append_slice(b"com").unwrap();
        assert_eq!(builder.finish().as_slice(), WWW_EXAMPLE_COM);
    }

    #[test]
    fn into_dname() {
        let mut builder = DnameBuilder::new_vec();
        builder.append_label(b"www").unwrap();
        builder.append_label(b"example").unwrap();
        builder.append_slice(b"com").unwrap();
        // `into_dname` closes the open label and adds the root label.
        assert_eq!(
            builder.into_dname().unwrap().as_slice(),
            b"\x03www\x07example\x03com\x00"
        );
    }
}