use std::hint::unreachable_unchecked;
use crate::proto::ProtocolError;
// UTF-8 validation entry points.
//
// `FROM_UTF8_BASIC` is used where only pass/fail matters. `FROM_UTF8_COMPAT` is
// used where the failure details (`valid_up_to` / `error_len`) are inspected to
// tell a codepoint truncated at the end of the input apart from invalid bytes.
// Without the `simd` feature both fall back to `std::str::from_utf8`, whose
// error already carries those details.
#[cfg(not(feature = "simd"))]
const FROM_UTF8_BASIC: fn(&[u8]) -> Result<&str, std::str::Utf8Error> = std::str::from_utf8;
#[cfg(not(feature = "simd"))]
const FROM_UTF8_COMPAT: fn(&[u8]) -> Result<&str, std::str::Utf8Error> = std::str::from_utf8;

// With `simd`, simdutf8's `basic` API is used for pass/fail checks and its
// std-compatible `compat` API where error positions are needed.
#[cfg(feature = "simd")]
const FROM_UTF8_BASIC: fn(&[u8]) -> Result<&str, simdutf8::basic::Utf8Error> =
    simdutf8::basic::from_utf8;
#[cfg(feature = "simd")]
const FROM_UTF8_COMPAT: fn(&[u8]) -> Result<&str, simdutf8::compat::Utf8Error> =
    simdutf8::compat::from_utf8;
/// Validates that `input` is UTF-8 and returns it as a `&str` borrowing the
/// same bytes.
///
/// # Errors
///
/// Returns [`ProtocolError::InvalidUtf8`] if `input` contains an invalid or
/// truncated UTF-8 sequence.
#[inline]
pub fn parse_str(input: &[u8]) -> Result<&str, ProtocolError> {
    match FROM_UTF8_BASIC(input) {
        Ok(text) => Ok(text),
        Err(_) => Err(ProtocolError::InvalidUtf8),
    }
}
/// Incremental UTF-8 validator for a payload that arrives in several chunks.
///
/// A codepoint may be split across consecutive calls to [`Validator::feed`];
/// its leading bytes are buffered here until the remaining bytes arrive.
#[derive(Debug)]
pub(crate) struct Validator {
    /// Leading bytes of a codepoint cut off at the end of the previous chunk.
    /// Only the first `partial_codepoint_len` bytes are meaningful.
    partial_codepoint: [u8; 4],
    /// Number of buffered bytes in `partial_codepoint` (`0` = none). While
    /// non-zero, `partial_codepoint[0]` holds a multi-byte UTF-8 lead byte.
    partial_codepoint_len: usize,
}

impl Validator {
    /// Creates a validator with no buffered partial codepoint.
    #[cfg(any(feature = "client", feature = "server"))]
    pub fn new() -> Self {
        Self {
            partial_codepoint: [0; 4],
            partial_codepoint_len: 0,
        }
    }

    /// Returns the total length in bytes of the codepoint whose lead byte is
    /// buffered in `partial_codepoint[0]`.
    ///
    /// Only meaningful while `partial_codepoint_len > 0`.
    fn complete_codepoint_len(&self) -> usize {
        match self.partial_codepoint[0] {
            0b0000_0000..=0b0111_1111 => 1,
            0b1100_0000..=0b1101_1111 => 2,
            0b1110_0000..=0b1110_1111 => 3,
            0b1111_0000..=0b1111_0111 => 4,
            _ => {
                if cfg!(debug_assertions) {
                    unreachable!("buffered byte is not a UTF-8 lead byte");
                }
                // SAFETY: `partial_codepoint[0]` is only written when buffering
                // the unvalidated tail of a chunk for which `error_len()`
                // returned `None`. Such a tail is the start of a truncated but
                // otherwise valid sequence, so it begins with a multi-byte lead
                // byte that one of the arms above matches.
                unsafe { unreachable_unchecked() }
            }
        }
    }

    /// Discards any buffered partial codepoint.
    #[inline]
    pub fn reset(&mut self) {
        self.partial_codepoint_len = 0;
    }

    /// Validates the next chunk of the payload.
    ///
    /// Set `is_complete` for the final chunk. A codepoint cut off at the end of
    /// a non-final chunk is buffered and checked once its remaining bytes
    /// arrive in a later call. An empty non-final chunk is a no-op.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::InvalidUtf8`] if the bytes seen so far cannot
    /// be valid UTF-8, or if the final chunk ends in the middle of a codepoint.
    /// The validator is reset before the error is returned.
    pub fn feed(&mut self, input: &[u8], is_complete: bool) -> Result<(), ProtocolError> {
        let remaining_bytes = if self.partial_codepoint_len == 0 {
            input
        } else {
            if input.is_empty() && !is_complete {
                return Ok(());
            }
            let buffered_len = self.partial_codepoint_len;
            let missing_bytes = self.complete_codepoint_len() - buffered_len;
            let bytes_to_copy = input.len().min(missing_bytes);
            let codepoint_len_after_copy = buffered_len + bytes_to_copy;
            // In bounds: `codepoint_len_after_copy` never exceeds the codepoint
            // length (at most 4) and `bytes_to_copy <= input.len()`.
            self.partial_codepoint[buffered_len..codepoint_len_after_copy]
                .copy_from_slice(&input[..bytes_to_copy]);
            let codepoint = &self.partial_codepoint[..codepoint_len_after_copy];

            if bytes_to_copy == missing_bytes {
                // The codepoint is now complete and must be valid on its own.
                if FROM_UTF8_BASIC(codepoint).is_err() {
                    self.reset();
                    return Err(ProtocolError::InvalidUtf8);
                }
            } else {
                // `input` ran out first, so all of it is now buffered. Reject
                // the prefix only if it can no longer become a valid codepoint
                // (or if the payload ends here).
                let prefix_check = FROM_UTF8_COMPAT(codepoint).map(|_| ());
                if let Err(utf8_error) = prefix_check {
                    if utf8_error.error_len().is_some() || is_complete {
                        self.reset();
                        return Err(ProtocolError::InvalidUtf8);
                    }
                    self.partial_codepoint_len = codepoint_len_after_copy;
                    return Ok(());
                }
            }
            // The buffered codepoint is complete and valid; continue after it.
            self.reset();
            &input[bytes_to_copy..]
        };

        // No codepoint is buffered at this point: either none was, or it was
        // just completed and `reset` was called above.
        if is_complete {
            // Final chunk: a codepoint truncated at the end is an error too.
            return FROM_UTF8_BASIC(remaining_bytes)
                .map(|_| ())
                .map_err(|_| ProtocolError::InvalidUtf8);
        }

        match FROM_UTF8_COMPAT(remaining_bytes) {
            Ok(_) => Ok(()),
            Err(utf8_error) if utf8_error.error_len().is_some() => {
                Err(ProtocolError::InvalidUtf8)
            }
            Err(utf8_error) => {
                // The chunk ends with the first 1-3 bytes of a codepoint; buffer
                // them. `valid_up_to` is an offset into `remaining_bytes`, which
                // starts `bytes_to_copy` bytes into `input` when a buffered
                // codepoint was just completed, so the tail must be taken from
                // `remaining_bytes`, not from `input`.
                let tail = &remaining_bytes[utf8_error.valid_up_to()..];
                self.partial_codepoint[..tail.len()].copy_from_slice(tail);
                self.partial_codepoint_len = tail.len();
                Ok(())
            }
        }
    }
}