use std::collections::HashMap;
use std::fmt;
use shared::error::*;
/// Maximum length, in bytes, of a name's text form; enforced by `Name::new`
/// and when decoding a name from the wire.
///
/// NOTE(review): RFC 1035 caps the *wire* form at 255 octets, which for a
/// canonical (dot-terminated) name corresponds to 254 text bytes — confirm
/// whether 255 is intended here.
const NAME_LEN: usize = 0xFF;
/// A domain name held in its text (presentation) form, e.g. `"example.com."`.
///
/// Conversion to and from the DNS wire format is done by the methods in the
/// `impl Name` block of this module.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub(crate) struct Name {
    /// The name's text. Packing requires a trailing '.' (canonical form).
    pub(crate) data: String,
}
/// Renders the name exactly as stored in `data`.
///
/// The text is written verbatim; width/fill/alignment flags on the outer
/// format spec are not applied.
impl fmt::Display for Name {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.data)
    }
}
impl Name {
    /// Creates a name from its text form (e.g. `"example.com."`).
    ///
    /// The text is stored as-is; the trailing-dot (canonical) requirement is
    /// only enforced later by [`Name::pack`].
    ///
    /// # Errors
    ///
    /// Returns `Error::ErrCalcLen` if `data` is longer than `NAME_LEN` bytes.
    pub(crate) fn new(data: &str) -> Result<Self> {
        if data.len() > NAME_LEN {
            return Err(Error::ErrCalcLen);
        }
        Ok(Name {
            data: data.to_owned(),
        })
    }

    /// Appends the wire encoding of this name to `msg` and returns the buffer.
    ///
    /// The name is written as a sequence of length-prefixed labels followed by
    /// a zero byte; the root name `"."` is written as a single zero byte.
    ///
    /// When `compression` is `Some`, each suffix of the name that starts at a
    /// label boundary is looked up in the table. On a hit, a two-byte pointer
    /// (top two bits set, 14-bit offset) replaces the rest of the name. On a
    /// miss, the suffix is recorded with its offset measured from
    /// `compression_off`, provided that offset fits in 14 bits.
    ///
    /// # Errors
    ///
    /// - `Error::ErrCalcLen` if the name is longer than `NAME_LEN` bytes.
    /// - `Error::ErrNonCanonicalName` if the name is empty or lacks a trailing '.'.
    /// - `Error::ErrSegTooLong` if a label is 64 bytes or longer.
    /// - `Error::ErrZeroSegLen` if the name contains an empty label (e.g. `"a..b."`).
    ///
    /// On error, `msg` is consumed and not returned.
    pub(crate) fn pack(
        &self,
        mut msg: Vec<u8>,
        compression: &mut Option<HashMap<String, usize>>,
        compression_off: usize,
    ) -> Result<Vec<u8>> {
        let data = self.data.as_bytes();

        // `new` enforces this too, but `data` is crate-visible and may be set
        // directly, so check again before emitting anything.
        if data.len() > NAME_LEN {
            return Err(Error::ErrCalcLen);
        }
        if data.last() != Some(&b'.') {
            return Err(Error::ErrNonCanonicalName);
        }
        // The root name encodes as a single zero-length label.
        if data == b"." {
            msg.push(0);
            return Ok(msg);
        }

        // Byte offset where the current (not yet emitted) label starts.
        let mut begin = 0;
        for (i, &byte) in data.iter().enumerate() {
            if byte == b'.' {
                let len = i - begin;
                // The two high bits of a length byte mark pointers / reserved
                // forms, so a label length must fit in the low six bits.
                if len >= 1 << 6 {
                    return Err(Error::ErrSegTooLong);
                }
                if len == 0 {
                    return Err(Error::ErrZeroSegLen);
                }
                msg.push(len as u8);
                msg.extend_from_slice(&data[begin..i]);
                begin = i + 1;
                continue;
            }

            // Only suffixes that begin at a label boundary can be compressed;
            // `i == begin` is exactly "i is 0 or just after a '.'".
            if i != begin {
                continue;
            }
            let Some(table) = compression.as_mut() else {
                continue;
            };

            // `i` is a *byte* offset. It is 0 or directly follows an ASCII '.',
            // so it is always a char boundary. Slicing by bytes keeps the key
            // aligned with `data` even for non-ASCII names; skipping `i` chars
            // would not.
            let suffix = &self.data[i..];
            if let Some(&ptr) = table.get(suffix) {
                msg.push(((ptr >> 8) | 0xC0) as u8);
                msg.push((ptr & 0xFF) as u8);
                return Ok(msg);
            }
            // A pointer carries a 14-bit offset relative to `compression_off`;
            // only remember suffixes whose offset is representable.
            if let Some(ptr) = msg.len().checked_sub(compression_off) {
                if ptr <= 0x3FFF {
                    table.insert(suffix.to_owned(), ptr);
                }
            }
        }
        msg.push(0);
        Ok(msg)
    }

    /// Decodes the name starting at `off` in `msg`, following compression
    /// pointers. Equivalent to `unpack_compressed(msg, off, true)`.
    pub(crate) fn unpack(&mut self, msg: &[u8], off: usize) -> Result<usize> {
        self.unpack_compressed(msg, off, true)
    }

    /// Decodes the wire-format name starting at `off` in `msg` into `self.data`.
    ///
    /// Each label is appended to the text followed by a '.'. An empty name
    /// decodes to `"."`. When `allow_compression` is false, any compression
    /// pointer is rejected.
    ///
    /// Returns the offset just past this name's bytes at `off`. That is past
    /// the first pointer if one was followed, otherwise past the terminating
    /// zero byte.
    ///
    /// # Errors
    ///
    /// - `Error::ErrBaseLen` if `msg` ends before a label-length or pointer byte.
    /// - `Error::ErrCalcLen` if a label runs past the end of `msg`, or the
    ///   decoded text exceeds `NAME_LEN` bytes.
    /// - `Error::ErrCompressedSrv` if a pointer is found and `allow_compression`
    ///   is false.
    /// - `Error::ErrInvalidPtr` if a pointer's second byte is missing.
    /// - `Error::ErrTooManyPtr` if more than 10 pointers are followed.
    /// - `Error::ErrReserved` for the reserved 0x40/0x80 label prefixes.
    /// - A conversion of `FromUtf8Error` (via `?`) if a label is not valid UTF-8.
    ///
    /// On error, `self` is left unchanged.
    pub(crate) fn unpack_compressed(
        &mut self,
        msg: &[u8],
        off: usize,
        allow_compression: bool,
    ) -> Result<usize> {
        // Offset currently being read; jumps around when pointers are followed.
        let mut curr_off = off;
        // Offset just past this name in the original position; pointer targets
        // belong to other names and don't count toward this one's length.
        let mut new_off = off;
        // Number of pointers followed so far (bounded to stop pointer loops).
        let mut ptr = 0;
        let mut name = String::new();
        loop {
            if curr_off >= msg.len() {
                return Err(Error::ErrBaseLen);
            }
            let c = msg[curr_off];
            curr_off += 1;
            match c & 0xC0 {
                // Literal label; a zero length terminates the name.
                0x00 => {
                    if c == 0x00 {
                        break;
                    }
                    let end_off = curr_off + c as usize;
                    if end_off > msg.len() {
                        return Err(Error::ErrCalcLen);
                    }
                    name.push_str(String::from_utf8(msg[curr_off..end_off].to_vec())?.as_str());
                    name.push('.');
                    curr_off = end_off;
                }
                // Compression pointer: 14-bit offset split across two bytes.
                0xC0 => {
                    if !allow_compression {
                        return Err(Error::ErrCompressedSrv);
                    }
                    if curr_off >= msg.len() {
                        return Err(Error::ErrInvalidPtr);
                    }
                    let c1 = msg[curr_off];
                    curr_off += 1;
                    if ptr == 0 {
                        new_off = curr_off;
                    }
                    ptr += 1;
                    if ptr > 10 {
                        return Err(Error::ErrTooManyPtr);
                    }
                    curr_off = ((c ^ 0xC0) as usize) << 8 | (c1 as usize);
                }
                // Prefixes 0x40 and 0x80 are reserved.
                _ => {
                    return Err(Error::ErrReserved);
                }
            }
        }
        if name.is_empty() {
            name.push('.');
        }
        if name.len() > NAME_LEN {
            return Err(Error::ErrCalcLen);
        }
        self.data = name;
        if ptr == 0 {
            new_off = curr_off;
        }
        Ok(new_off)
    }

    /// Returns the offset just past the name at `off` without decoding it.
    ///
    /// Scanning stops at the terminating zero byte or at the first compression
    /// pointer; pointers are not followed. The second byte of a pointer is not
    /// bounds-checked here, so for a truncated message the returned offset can
    /// be `msg.len() + 1`. Callers must validate it before indexing.
    ///
    /// # Errors
    ///
    /// - `Error::ErrBaseLen` if `msg` ends before a label-length byte.
    /// - `Error::ErrCalcLen` if a label runs past the end of `msg`.
    /// - `Error::ErrReserved` for the reserved 0x40/0x80 label prefixes.
    pub(crate) fn skip(msg: &[u8], off: usize) -> Result<usize> {
        let mut new_off = off;
        loop {
            if new_off >= msg.len() {
                return Err(Error::ErrBaseLen);
            }
            let c = msg[new_off];
            new_off += 1;
            match c & 0xC0 {
                0x00 => {
                    if c == 0x00 {
                        break;
                    }
                    new_off += c as usize;
                    if new_off > msg.len() {
                        return Err(Error::ErrCalcLen);
                    }
                }
                0xC0 => {
                    // Pointers are two bytes and end the name.
                    new_off += 1;
                    break;
                }
                _ => {
                    return Err(Error::ErrReserved);
                }
            }
        }
        Ok(new_off)
    }
}