use core::fmt::{self, Write};
use either::Either;
use super::error::{ProtoError, not_enough_read_data};
/// A sequence of character-strings (TXT data).
///
/// The strings are held either as a borrowed slice of `&str`, or as a byte
/// range of length-prefixed segments that are decoded lazily by
/// [`Txt::strings`].
#[derive(Clone, Copy, Debug, Default)]
pub struct Txt<'container, 'innards> {
    repr: Repr<'container, 'innards>,
}
impl<'container, 'innards> From<&'container [&'innards str]> for Txt<'container, 'innards> {
fn from(strings: &'container [&'innards str]) -> Self {
Self::from_strings(strings)
}
}
/// Backing storage for [`Txt`].
#[derive(Clone, Copy, Debug)]
enum Repr<'container, 'innards> {
    /// Length-prefixed segments located in `original[start..end]`; the whole
    /// `original` buffer is retained alongside the range.
    BytesStrings {
        original: &'innards [u8],
        start: usize,
        end: usize,
    },
    /// A slice of strings supplied directly by the caller.
    Strings(&'container [&'innards str]),
}
impl Default for Repr<'_, '_> {
fn default() -> Self {
Self::Strings(&[])
}
}
impl<'container, 'innards> Txt<'container, 'innards> {
    /// Creates a `Txt` that borrows an existing slice of strings.
    #[inline]
    pub const fn from_strings(strings: &'container [&'innards str]) -> Self {
        Self {
            repr: Repr::Strings(strings),
        }
    }

    /// Creates a `Txt` over all of `src`, treated as a run of length-prefixed
    /// segments that are decoded lazily by [`Txt::strings`].
    #[inline]
    pub const fn from_bytes(src: &'innards [u8]) -> Self {
        let whole_len = src.len();
        Self::from_bytes_in(src, 0, whole_len)
    }

    /// Creates a `Txt` over the segments in `original[start..end]`, keeping a
    /// reference to the whole `original` buffer alongside the range.
    ///
    /// The range is not validated here; [`Txt::repr`] panics if it is invalid.
    #[inline]
    pub(super) const fn from_bytes_in(original: &'innards [u8], start: usize, end: usize) -> Self {
        let repr = Repr::BytesStrings {
            original,
            start,
            end,
        };
        Self { repr }
    }

    /// Returns an iterator over the individual strings.
    ///
    /// For byte-backed data each segment is decoded on demand, so items are
    /// `Result`s.
    #[inline]
    pub const fn strings(&self) -> Strings<'container, 'innards> {
        // `Repr` is `Copy`, so the fields can be bound by value directly.
        Strings {
            repr: match self.repr {
                Repr::Strings(strings) => StringsRepr::Strings {
                    strings,
                    position: 0,
                },
                Repr::BytesStrings {
                    original,
                    start,
                    end,
                } => StringsRepr::Bytes {
                    original,
                    position: start,
                    end,
                },
            },
        }
    }

    /// Exposes the underlying storage: the string slice (`Left`), or the raw
    /// undecoded bytes of the `start..end` range (`Right`).
    ///
    /// # Panics
    ///
    /// Panics if the byte range given to `from_bytes_in` is not a valid range
    /// into its buffer.
    #[inline]
    pub fn repr(&self) -> Either<&'container [&'innards str], &'innards [u8]> {
        match self.repr {
            Repr::Strings(strings) => Either::Left(strings),
            Repr::BytesStrings {
                original,
                start,
                end,
            } => Either::Right(&original[start..end]),
        }
    }
}
/// A single character-string yielded by [`Strings`].
///
/// It borrows either a `&str` or a byte range of the original buffer. Its
/// `Display` output escapes byte-backed content that is not plain printable
/// ASCII.
#[derive(Debug, Clone, Copy)]
pub struct Str<'a> {
    repr: StrRepr<'a>,
}
/// Formats the string for display.
///
/// A string created with [`Str::new`] is written verbatim. A byte-backed string
/// is written one byte at a time:
/// - `"` and `\` are prefixed with a backslash;
/// - other printable ASCII (`' '..='~'`) is written as-is;
/// - every remaining byte is written as a `\DDD` decimal escape.
impl fmt::Display for Str<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.repr {
            StrRepr::String(s) => f.write_str(s),
            StrRepr::Bytes { .. } => {
                for &byte in self.as_bytes() {
                    match byte {
                        b'"' | b'\\' => {
                            f.write_char('\\')?;
                            f.write_char(char::from(byte))?;
                        }
                        b' '..=b'~' => f.write_char(char::from(byte))?,
                        // `escape_bytes` only produces `\` and ASCII digits, so
                        // each byte maps straight to a `char`. No UTF-8
                        // validation (and no panic path) is needed.
                        _ => {
                            for ascii in escape_bytes(byte) {
                                f.write_char(char::from(ascii))?;
                            }
                        }
                    }
                }
                Ok(())
            }
        }
    }
}
/// Backing storage for [`Str`].
#[derive(Debug, Clone, Copy)]
enum StrRepr<'a> {
    /// The bytes `original[start..start + length]`.
    Bytes {
        original: &'a [u8],
        start: usize,
        length: usize,
    },
    /// A borrowed string slice.
    String(&'a str),
}
impl<'a> Str<'a> {
    /// Wraps a borrowed string slice.
    #[inline]
    pub const fn new(s: &'a str) -> Self {
        Self {
            repr: StrRepr::String(s),
        }
    }

    /// Refers to `length` bytes of `original` beginning at `start`, without copying.
    fn from_bytes(original: &'a [u8], start: usize, length: usize) -> Self {
        let repr = StrRepr::Bytes {
            original,
            start,
            length,
        };
        Self { repr }
    }

    /// Returns the raw, unescaped bytes of this string.
    ///
    /// # Panics
    ///
    /// Panics if a byte-backed range does not fit inside its buffer.
    pub fn as_bytes(&self) -> &'a [u8] {
        match self.repr {
            StrRepr::String(text) => text.as_bytes(),
            StrRepr::Bytes {
                original,
                start,
                length,
            } => {
                let stop = start + length;
                &original[start..stop]
            }
        }
    }
}
/// Cursor state for [`Strings`]: which backing store is being walked and how far.
#[derive(Debug, Clone)]
enum StringsRepr<'container, 'innards> {
    /// Walking length-prefixed segments of `original` up to `end`.
    Bytes {
        original: &'innards [u8],
        /// Offset of the next segment's length byte.
        position: usize,
        /// Exclusive upper bound of the segment range.
        end: usize,
    },
    /// Walking a borrowed slice of strings.
    Strings {
        strings: &'container [&'innards str],
        /// Index of the next string to yield.
        position: usize,
    },
}

/// Iterator over the strings of a [`Txt`], returned by [`Txt::strings`].
///
/// Byte-backed data is decoded one segment per call to `next`, so every item
/// is a `Result`.
// `Debug` follows the Rust API guideline that public types implement it;
// `Clone` lets callers snapshot iteration at the current position.
#[derive(Debug, Clone)]
pub struct Strings<'container, 'innards> {
    repr: StringsRepr<'container, 'innards>,
}
impl<'innards> Iterator for Strings<'_, 'innards> {
    type Item = Result<Str<'innards>, ProtoError>;

    /// Yields the next string.
    ///
    /// For byte-backed data, a malformed segment yields one `Err` and then ends
    /// iteration (the cursor jumps to `end`).
    fn next(&mut self) -> Option<Self::Item> {
        match &mut self.repr {
            StringsRepr::Bytes {
                original,
                position,
                end,
            } => {
                if *position >= *end {
                    return None;
                }
                match decode_txt_segment(*original, *position, *end) {
                    Ok((segment, next_position)) => {
                        *position = next_position;
                        Some(Ok(segment))
                    }
                    Err(e) => {
                        // Stop after the first error rather than resynchronising.
                        *position = *end;
                        Some(Err(e))
                    }
                }
            }
            StringsRepr::Strings { strings, position } => {
                let string = *strings.get(*position)?;
                *position += 1;
                Some(Ok(Str::new(string)))
            }
        }
    }

    /// Reports the remaining item count: exact for a string slice, and a range
    /// for byte-backed data.
    ///
    /// For bytes, every call before `end` yields an item. Each segment also
    /// consumes at least its one-byte length prefix, so at most `end - position`
    /// items remain.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match &self.repr {
            StringsRepr::Bytes { position, end, .. } => {
                if *position >= *end {
                    (0, Some(0))
                } else {
                    (1, Some(*end - *position))
                }
            }
            StringsRepr::Strings { strings, position } => {
                let remaining = strings.len().saturating_sub(*position);
                (remaining, Some(remaining))
            }
        }
    }
}

// Positions only ever move forward (or jump to `end` on error), so once `next`
// returns `None` it keeps doing so.
impl core::iter::FusedIterator for Strings<'_, '_> {}
/// Decodes one length-prefixed segment whose length byte sits at `offset` in `msg`.
///
/// The byte at `offset` holds the content length `n`, and the content is the
/// `n` bytes that follow it. Both the prefix and the content must lie inside
/// `msg` and before the exclusive bound `end`.
///
/// Returns the segment together with the offset just past it. Content made
/// only of printable ASCII other than `"` and `\` is returned as a
/// string-backed [`Str`]. Anything else is returned as a byte-backed [`Str`]
/// referring into `msg`.
///
/// # Errors
///
/// - Returns the error built by `not_enough_read_data` if the prefix or the
///   content extends past `msg.len()` or `end`.
/// - Otherwise converts the `simdutf8` UTF-8 error. This is not expected in
///   practice, because that path is only reached for ASCII content.
fn decode_txt_segment(
    msg: &[u8],
    offset: usize,
    end: usize,
) -> Result<(Str<'_>, usize), ProtoError> {
    // `offset >= msg.len()` is the overflow-free form of `offset + 1 > msg.len()`.
    if offset >= msg.len() || offset >= end {
        return Err(not_enough_read_data(1, 0));
    }
    let length = usize::from(msg[offset]);
    let content_start = offset + 1;
    let content_end = content_start + length;
    // NOTE(review): below, the second argument is the shortfall past the bound,
    // while the prefix check above passes 0. Confirm which meaning
    // `not_enough_read_data` expects.
    if content_end > msg.len() {
        return Err(not_enough_read_data(length, content_end - msg.len()));
    }
    if content_end > end {
        return Err(not_enough_read_data(length, content_end - end));
    }

    let content = &msg[content_start..content_end];
    // A quote, a backslash, or anything outside printable ASCII forces the
    // byte-backed form, whose `Display` escapes it. `any` stops at the first hit.
    let needs_escape = |&b: &u8| matches!(b, b'"' | b'\\') || !(b' '..=b'~').contains(&b);
    if content.iter().any(needs_escape) {
        return Ok((Str::from_bytes(msg, content_start, length), content_end));
    }
    simdutf8::compat::from_utf8(content)
        .map(|s| (Str::new(s), content_end))
        .map_err(Into::into)
}
/// Encodes `b` as a backslash followed by exactly three decimal digits
/// (`\DDD`); for example `7` becomes `\007` and `255` becomes `\255`.
#[inline]
const fn escape_bytes(b: u8) -> [u8; 4] {
    let hundreds = b / 100;
    let tens = (b / 10) % 10;
    let ones = b % 10;
    [b'\\', b'0' + hundreds, b'0' + tens, b'0' + ones]
}