use super::Error;
use core::convert::TryInto;
use core::fmt;
use core::hash;
use core::iter;
use core::mem;
use core::num::NonZeroUsize;
use memchr::Memchr;
/// A value that can be written into a caller-supplied byte buffer.
pub trait Serialize<'a> {
/// Returns the number of bytes `serialize` is expected to write for this value.
fn serialized_len(&self) -> usize;
/// Writes this value at the start of `bytes` and, on success, returns how many bytes were written.
fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error>;
}
/// A value that can be read in place from a `Cursor` over a byte buffer.
pub trait Deserialize<'a> {
/// Overwrites `self` with the value parsed at `cursor` and, on success, returns the cursor to continue reading from.
fn deserialize(&mut self, cursor: Cursor<'a>) -> Result<Cursor<'a>, Error>;
}
/// A read position inside a borrowed byte buffer.
#[derive(Debug, Copy, Clone)]
pub struct Cursor<'a> {
/// The whole buffer, including the bytes that lie before the current position.
bytes: &'a [u8],
/// Offset into `bytes` at which the unread data begins.
cursor: usize,
}
impl<'a> Cursor<'a> {
    /// Creates a cursor positioned at the first byte of `bytes`.
    pub fn new(bytes: &'a [u8]) -> Self {
        Self { bytes, cursor: 0 }
    }

    /// Returns the whole underlying buffer, regardless of the current position.
    pub fn original(&self) -> &'a [u8] {
        self.bytes
    }

    /// Returns the bytes from the current position to the end of the buffer.
    ///
    /// The result is empty when the position is at or beyond the end of the
    /// buffer. [`Cursor::at`] does not validate its argument, so a position
    /// beyond the end is possible and must not cause a panic here.
    pub fn remaining(&self) -> &'a [u8] {
        self.bytes.get(self.cursor..).unwrap_or(&[])
    }

    /// Returns the number of unread bytes (zero when the position is at or
    /// beyond the end of the buffer).
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        self.bytes.len().saturating_sub(self.cursor)
    }

    /// Returns a cursor over the same buffer positioned at `pos`.
    ///
    /// `pos` is not checked against the buffer length.
    pub fn at(&self, pos: usize) -> Self {
        Self {
            bytes: self.bytes,
            cursor: pos,
        }
    }

    /// Moves the position forward by `n` bytes.
    ///
    /// Advancing by zero always succeeds and leaves the cursor unchanged.
    ///
    /// # Errors
    ///
    /// Returns `Error::NotEnoughReadBytes` when fewer than `n` bytes remain,
    /// including when `cursor + n` would overflow `usize`.
    pub fn advance(mut self, n: usize) -> Result<Self, Error> {
        if n == 0 {
            return Ok(self);
        }
        match self.cursor.checked_add(n) {
            Some(target) if target <= self.bytes.len() => {
                self.cursor = target;
                Ok(self)
            }
            // Either past the end of the buffer or not representable at all.
            _ => Err(self.read_error(n)),
        }
    }

    /// Builds the error reported when reading `n` bytes from the current
    /// position is not possible.
    ///
    /// `tried_to_read` is the absolute end offset of the attempted read and
    /// `available` is the total buffer length.
    fn read_error(&self, n: usize) -> Error {
        // Clamp to 1 so that a zero-length read at offset 0 cannot make the
        // `NonZeroUsize` construction fail.
        let wanted = self.cursor.saturating_add(n).max(1);
        Error::NotEnoughReadBytes {
            tried_to_read: NonZeroUsize::new(wanted).expect("clamped to at least 1"),
            available: self.bytes.len(),
        }
    }
}
impl Serialize<'_> for () {
fn serialized_len(&self) -> usize {
0
}
fn serialize(&self, _bytes: &mut [u8]) -> Result<usize, Error> {
Ok(0)
}
}
impl<'a> Deserialize<'a> for () {
fn deserialize(&mut self, bytes: Cursor<'a>) -> Result<Cursor<'a>, Error> {
Ok(bytes)
}
}
/// A name made of segments, backed either by a byte buffer or by a dotted string.
#[derive(Clone, Copy)]
pub struct Label<'a> {
/// Where the name's data comes from; see `Repr`.
repr: Repr<'a>,
}
/// Backing storage of a `Label`.
#[derive(Clone, Copy)]
enum Repr<'a> {
/// A name located inside a larger byte buffer.
Bytes {
/// The whole buffer; pointer segments are resolved as offsets into it.
original: &'a [u8],
/// Offset in `original` of the first byte of the name.
start: usize,
/// Offset in `original` one past the last byte of the name.
end: usize,
},
/// A name given as text whose segments are separated by `.`.
String {
/// The dotted text.
string: &'a str,
},
}
impl Default for Label<'_> {
    /// Builds a byte-backed label whose buffer is a single zero byte, i.e.
    /// just the terminating empty segment.
    fn default() -> Self {
        let terminator_only = Repr::Bytes {
            original: &[0],
            start: 0,
            end: 1,
        };
        Self {
            repr: terminator_only,
        }
    }
}
impl<'a> PartialEq<Label<'a>> for Label<'_> {
    /// Two labels are equal when their segment sequences are equal, whatever
    /// their backing representation.
    fn eq(&self, other: &Label<'a>) -> bool {
        let mine = self.segments();
        let theirs = other.segments();
        Iterator::eq(mine, theirs)
    }
}
// Equality is defined by comparing `LabelSegment` sequences, and `LabelSegment` derives `Eq`.
impl Eq for Label<'_> {}
impl<'a> PartialOrd<Label<'a>> for Label<'_> {
    /// Orders labels lexicographically by their segment sequences.
    fn partial_cmp(&self, other: &Label<'a>) -> Option<core::cmp::Ordering> {
        let mine = self.segments();
        let theirs = other.segments();
        Iterator::partial_cmp(mine, theirs)
    }
}
impl Ord for Label<'_> {
    /// Total order: lexicographic comparison of the two segment sequences.
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        let mine = self.segments();
        let theirs = other.segments();
        Iterator::cmp(mine, theirs)
    }
}
impl hash::Hash for Label<'_> {
    /// Feeds every segment, in order, into `state`, so the hash depends on the
    /// segment sequence rather than on the backing representation.
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        self.segments()
            .for_each(|segment| hash::Hash::hash(&segment, state));
    }
}
impl<'a> Label<'a> {
    /// Iterates over the raw segments of this label.
    ///
    /// For a byte-backed label, segments are decoded starting at the label's
    /// offset. The iterator ends after the first segment that is not a string
    /// (the terminator or a pointer, which is yielded but not followed) or at
    /// the first decoding error. For a string-backed label, every non-empty
    /// dot-separated part is yielded, followed by [`LabelSegment::Empty`].
    pub fn segments(&self) -> impl Iterator<Item = LabelSegment<'a>> {
        match self.repr {
            Repr::Bytes {
                original, start, ..
            } => Either::A(parse_bytes(original, start)),
            Repr::String { string } => Either::B(parse_string(string)),
        }
    }

    /// Iterates over the textual parts of this label, following pointer
    /// segments into the original buffer for byte-backed labels.
    ///
    /// Iteration stops at the terminating empty segment, at the first segment
    /// that fails to decode, at a pointer whose target lies outside the
    /// buffer, or once more pointers have been followed than the buffer has
    /// bytes (which can only happen when the pointers form a cycle). The
    /// current implementation only ever yields `Ok` items.
    pub fn names(&self) -> impl Iterator<Item = Result<&'a str, &'a [u8]>> {
        match self.repr {
            Repr::String { string } => {
                Either::A(parse_string(string).filter_map(|seg| seg.as_str().map(Ok)))
            }
            Repr::Bytes {
                original, start, ..
            } => {
                let mut cursor = Cursor {
                    bytes: original,
                    cursor: start,
                };
                // An acyclic chain of pointers lands on each offset at most
                // once, so it can never need more jumps than there are bytes.
                // The budget is shared across calls so that a cycle containing
                // string segments terminates as well.
                let mut jumps_left = original.len();
                Either::B(iter::from_fn(move || loop {
                    let mut ls: LabelSegment<'_> = LabelSegment::Empty;
                    cursor = ls.deserialize(cursor).ok()?;
                    match ls {
                        LabelSegment::Empty => return None,
                        LabelSegment::Pointer(pos) => {
                            let target = usize::from(pos);
                            // The pointer value comes from the buffer itself
                            // and may point anywhere; never read past the end.
                            if target >= original.len() {
                                return None;
                            }
                            jumps_left = jumps_left.checked_sub(1)?;
                            cursor = cursor.at(target);
                        }
                        LabelSegment::String(label) => return Some(Ok(label)),
                    }
                }))
            }
        }
    }
}
impl<'a> From<&'a str> for Label<'a> {
fn from(string: &'a str) -> Self {
Self {
repr: Repr::String { string },
}
}
}
impl fmt::Debug for Label<'_> {
    /// Formats as a tuple struct named `Label` whose single field is the
    /// label's `Display` text (written without surrounding quotes).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        /// Adapter whose `Debug` output is the wrapped label's `Display` output.
        struct ViaDisplay<'x>(&'x Label<'x>);

        impl fmt::Debug for ViaDisplay<'_> {
            fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
                fmt::Display::fmt(self.0, out)
            }
        }

        let rendered = ViaDisplay(self);
        let mut tuple = f.debug_tuple("Label");
        tuple.field(&rendered);
        tuple.finish()
    }
}
impl fmt::Display for Label<'_> {
    /// Writes the label's names joined by `.`; a name reported as `Err` is
    /// rendered as `???`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        for (index, name) in self.names().enumerate() {
            // Separator goes between names, never before the first one.
            if index != 0 {
                f.write_str(".")?;
            }
            let text = match name {
                Ok(text) => text,
                Err(_) => "???",
            };
            f.write_str(text)?;
        }
        Ok(())
    }
}
impl<'a> Serialize<'a> for Label<'a> {
    /// Length of the serialized label: the recorded byte span for a
    /// byte-backed label, otherwise the (saturating) sum of the segment
    /// lengths.
    fn serialized_len(&self) -> usize {
        match self.repr {
            Repr::Bytes { start, end, .. } => end - start,
            Repr::String { .. } => self.segments().fold(0, |total, segment| {
                total.saturating_add(segment.serialized_len())
            }),
        }
    }

    /// Writes the label at the start of `bytes` and returns the number of
    /// bytes written.
    ///
    /// A byte-backed label is copied verbatim from its original buffer; a
    /// string-backed label is written segment by segment.
    ///
    /// # Panics
    ///
    /// Panics when `bytes` is too short to hold the output.
    fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error> {
        match self.repr {
            Repr::Bytes {
                original,
                start,
                end,
            } => {
                let count = end - start;
                let destination = &mut bytes[..count];
                destination.copy_from_slice(&original[start..end]);
                Ok(count)
            }
            Repr::String { .. } => {
                let mut written = 0;
                for segment in self.segments() {
                    written += segment.serialize(&mut bytes[written..])?;
                }
                Ok(written)
            }
        }
    }
}
impl<'a> Deserialize<'a> for Label<'a> {
    /// Records the span of the label that starts at `cursor` without decoding
    /// its segments, and returns the cursor positioned just past it.
    ///
    /// The scan skips length-prefixed segments until it meets a zero byte
    /// (one byte long) or a byte with either of its two top bits set (treated
    /// as the first byte of a two-byte segment). Pointers are not followed.
    ///
    /// `self` is only modified on success.
    ///
    /// # Errors
    ///
    /// Returns `Error::NotEnoughReadBytes` when the buffer ends before the
    /// label does. `tried_to_read` is the absolute end offset of the read that
    /// failed and `available` is the buffer length.
    fn deserialize(&mut self, cursor: Cursor<'a>) -> Result<Cursor<'a>, Error> {
        let original = cursor.original();
        let start = cursor.cursor;
        let mut end = start;
        loop {
            let len = match original.get(end) {
                Some(0) => {
                    end += 1;
                    break;
                }
                Some(ptr) if ptr & PTR_MASK != 0 => {
                    end += 2;
                    break;
                }
                Some(len_char) => usize::from(*len_char),
                None => {
                    // `end` is already an absolute offset; reading the byte at
                    // `end` means reading up to offset `end + 1`, which is
                    // never zero.
                    return Err(Error::NotEnoughReadBytes {
                        tried_to_read: NonZeroUsize::new(end.saturating_add(1))
                            .expect("end + 1 is at least 1"),
                        available: original.len(),
                    });
                }
            };
            end += len + 1;
        }
        // Validate the span before storing it, so a failed call cannot leave
        // an out-of-bounds `end` behind in `self`.
        let next = cursor.advance(end - start)?;
        self.repr = Repr::Bytes {
            original,
            start,
            end,
        };
        Ok(next)
    }
}
/// Decodes consecutive segments from `bytes`, starting at offset `position`.
///
/// Every string segment is yielded. The first segment that is not a string
/// is yielded too and ends the iteration. A segment that fails to decode
/// yields nothing.
fn parse_bytes(bytes: &[u8], position: usize) -> impl Iterator<Item = LabelSegment<'_>> + '_ {
    let mut reader = Cursor {
        bytes,
        cursor: position,
    };
    let mut finished = false;
    iter::from_fn(move || {
        if finished {
            return None;
        }
        let mut segment = LabelSegment::default();
        match segment.deserialize(reader) {
            Ok(next) => reader = next,
            Err(_) => return None,
        }
        // Only string segments let the iteration continue.
        finished = segment.as_str().is_none();
        Some(segment)
    })
}
fn parse_string(str: &str) -> impl Iterator<Item = LabelSegment<'_>> + '_ {
let dot = Memchr::new(b'.', str.as_bytes());
let mut last_index = 0;
let dot = dot.chain(Some(str.len()));
dot.filter_map(move |index| {
let item = &str[last_index..index];
last_index = index.saturating_add(1);
if item.is_empty() {
None
} else {
Some(LabelSegment::String(item))
}
})
.chain(Some(LabelSegment::Empty))
}
/// One component of a name.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum LabelSegment<'a> {
/// The terminator; serialized as a single zero byte.
Empty,
/// A piece of text; serialized as a length byte followed by the text, at most `MAX_STR_LEN` bytes long.
String(&'a str),
/// An offset into the original buffer; serialized as two big-endian bytes with the two top bits set.
Pointer(u16),
}
/// Largest length a string segment may have: the bits of a length byte not covered by `PTR_MASK` (63).
const MAX_STR_LEN: usize = !PTR_MASK as usize;
/// The two top bits of a segment's first byte; a byte with both set starts a pointer segment.
const PTR_MASK: u8 = 0b1100_0000;
impl<'a> LabelSegment<'a> {
    /// Returns the text of a string segment, or `None` for the other variants.
    fn as_str(&self) -> Option<&'a str> {
        if let Self::String(text) = *self {
            Some(text)
        } else {
            None
        }
    }
}
impl Default for LabelSegment<'_> {
fn default() -> Self {
Self::Empty
}
}
impl<'a> Serialize<'a> for LabelSegment<'a> {
    /// Number of bytes the segment occupies when serialized: 1 for the
    /// terminator, 2 for a pointer, and one length byte plus the text for a
    /// string.
    fn serialized_len(&self) -> usize {
        match self {
            Self::Empty => 1,
            Self::Pointer(_) => 2,
            Self::String(s) => 1 + s.len(),
        }
    }

    /// Writes the segment at the start of `bytes` and returns the number of
    /// bytes written.
    ///
    /// # Errors
    ///
    /// Returns `Error::NameTooLong` for a string longer than `MAX_STR_LEN`
    /// bytes.
    ///
    /// # Panics
    ///
    /// Panics when `bytes` is too short to hold the segment.
    fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error> {
        match self {
            Self::Empty => {
                bytes[0] = 0;
                Ok(1)
            }
            Self::Pointer(ptr) => {
                let [mut b1, b2] = ptr.to_be_bytes();
                // NOTE(review): the mask overwrites the two top bits of the
                // offset, so values of 0x4000 and above are not preserved.
                // Confirm callers never construct such pointers.
                b1 |= PTR_MASK;
                bytes[0] = b1;
                bytes[1] = b2;
                Ok(2)
            }
            Self::String(s) => {
                let len = s.len();
                if len > MAX_STR_LEN {
                    return Err(Error::NameTooLong(len));
                }
                // The output is the length byte plus `len` bytes of text, so
                // the buffer must be strictly longer than the text.
                if len >= bytes.len() {
                    panic!("not enough bytes to serialize string");
                }
                // Cannot truncate: `len` was checked against `MAX_STR_LEN`.
                bytes[0] = len as u8;
                bytes[1..=len].copy_from_slice(s.as_bytes());
                Ok(len + 1)
            }
        }
    }
}
impl<'a> Deserialize<'a> for LabelSegment<'a> {
    /// Decodes the segment at `cursor` into `self` and returns the cursor
    /// positioned just past it.
    ///
    /// The first byte selects the kind: zero is the terminator, a byte with
    /// both top bits set starts a two-byte pointer, and anything else is the
    /// length of the text that follows.
    ///
    /// # Errors
    ///
    /// * `Error::NotEnoughReadBytes` when the buffer ends before the segment
    ///   does (truncated input is reported, never a panic).
    /// * `Error::NameTooLong` when the length byte exceeds `MAX_STR_LEN`.
    /// * The error converted from `simdutf8` when the text is not valid UTF-8.
    fn deserialize(&mut self, cursor: Cursor<'a>) -> Result<Cursor<'a>, Error> {
        let remaining = cursor.remaining();
        let b1 = *remaining.first().ok_or_else(|| cursor.read_error(1))?;
        if b1 == 0 {
            *self = Self::Empty;
            cursor.advance(1)
        } else if b1 & PTR_MASK == PTR_MASK {
            // Checked access: a pointer cut off after its first byte must be
            // reported as an error.
            let [hi, lo]: [u8; 2] = remaining
                .get(..2)
                .and_then(|pair| pair.try_into().ok())
                .ok_or_else(|| cursor.read_error(2))?;
            let ptr = u16::from_be_bytes([hi & !PTR_MASK, lo]);
            *self = Self::Pointer(ptr);
            cursor.advance(2)
        } else {
            let len = usize::from(b1);
            if len > MAX_STR_LEN {
                return Err(Error::NameTooLong(len));
            }
            // Checked access: the length byte may promise more text than the
            // buffer actually holds.
            let bytes = remaining
                .get(1..=len)
                .ok_or_else(|| cursor.read_error(len + 1))?;
            let s = simdutf8::compat::from_utf8(bytes)?;
            *self = Self::String(s);
            cursor.advance(len + 1)
        }
    }
}
/// Implements `Serialize` and `Deserialize` for the listed integer types
/// using their big-endian byte representation.
macro_rules! serialize_num {
    ($($num_ty: ident),*) => {
        $(
            impl<'a> Serialize<'a> for $num_ty {
                /// The size of the integer type in bytes.
                fn serialized_len(&self) -> usize {
                    mem::size_of::<Self>()
                }

                /// Writes the big-endian bytes at the start of `bytes`.
                ///
                /// Panics when `bytes` is shorter than the integer.
                fn serialize(&self, bytes: &mut [u8]) -> Result<usize, Error> {
                    const WIDTH: usize = mem::size_of::<$num_ty>();
                    if bytes.len() < WIDTH {
                        panic!("Not enough space to serialize a {}", stringify!($num_ty));
                    }
                    let encoded = self.to_be_bytes();
                    bytes[..WIDTH].copy_from_slice(&encoded);
                    Ok(WIDTH)
                }
            }

            impl<'a> Deserialize<'a> for $num_ty {
                /// Reads a big-endian integer at the cursor, or reports a
                /// read error when too few bytes remain.
                fn deserialize(&mut self, bytes: Cursor<'a>) -> Result<Cursor<'a>, Error> {
                    const WIDTH: usize = mem::size_of::<$num_ty>();
                    if bytes.len() < WIDTH {
                        return Err(bytes.read_error(WIDTH));
                    }
                    let mut raw = [0; WIDTH];
                    raw.copy_from_slice(&bytes.remaining()[..WIDTH]);
                    *self = <$num_ty>::from_be_bytes(raw);
                    bytes.advance(WIDTH)
                }
            }
        )*
    }
}

serialize_num! {
    u8, u16, u32, u64,
    i8, i16, i32, i64
}
/// One of two iterator types yielding the same item type, so that a function
/// can return either of them behind a single `impl Iterator`.
enum Either<L, R> {
    /// The first alternative.
    A(L),
    /// The second alternative.
    B(R),
}

impl<L, R> Iterator for Either<L, R>
where
    L: Iterator,
    R: Iterator<Item = L::Item>,
{
    type Item = L::Item;

    /// Pulls the next item from whichever iterator is wrapped.
    fn next(&mut self) -> Option<Self::Item> {
        match *self {
            Either::A(ref mut inner) => inner.next(),
            Either::B(ref mut inner) => inner.next(),
        }
    }

    /// Reports the wrapped iterator's own size hint.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match *self {
            Either::A(ref inner) => inner.size_hint(),
            Either::B(ref inner) => inner.size_hint(),
        }
    }

    /// Delegates to the wrapped iterator's `fold`, so any specialised
    /// implementation it has is used.
    fn fold<Acc, F>(self, init: Acc, f: F) -> Acc
    where
        Self: Sized,
        F: FnMut(Acc, Self::Item) -> Acc,
    {
        match self {
            Either::A(inner) => inner.fold(init, f),
            Either::B(inner) => inner.fold(init, f),
        }
    }
}