use encoding_rs::{CoderResult, Encoder, Encoding, UTF_8};
use thiserror::Error;
/// Error returned when a byte sequence is not valid UTF-8.
///
/// It is a unit type: it carries no detail about where or why validation failed.
#[derive(Error, Debug, Eq, PartialEq, Copy, Clone)]
#[error("Invalid UTF-8")]
pub struct Utf8Error;
/// Scratch space that encoded bytes are written into before being handed on.
enum Buffer {
    /// Fixed-size inline storage; using it needs no heap allocation.
    // NOTE(review): 63 bytes presumably leaves room for the enum tag within 64 bytes — confirm.
    Stack([u8; 63]),
    /// Heap-backed storage; once a `Buffer` is in this state it is reused as-is.
    Heap(Vec<u8>),
}

impl Buffer {
    /// Content length (in bytes) from which inline storage is replaced by a heap buffer.
    const CONTENT_WRITE_LENGTH_LONG_ENOUGH_TO_USE_LARGER_BUFFER: usize = 1 << 20;

    /// Size in bytes of a freshly allocated heap buffer.
    const DEFAULT_HEAP_BUFFER_SIZE: usize = 4096;

    /// Returns the scratch slice to use for content that is `content_len` bytes long.
    ///
    /// An existing heap buffer is always returned unchanged. Inline storage is
    /// swapped for a zeroed heap buffer of `DEFAULT_HEAP_BUFFER_SIZE` bytes when
    /// `content_len` reaches `CONTENT_WRITE_LENGTH_LONG_ENOUGH_TO_USE_LARGER_BUFFER`;
    /// otherwise the inline storage itself is returned.
    fn buffer_for_length(&mut self, content_len: usize) -> &mut [u8] {
        let is_long_write =
            content_len >= Self::CONTENT_WRITE_LENGTH_LONG_ENOUGH_TO_USE_LARGER_BUFFER;

        // Promote first, then borrow: only inline storage is ever replaced.
        if is_long_write && matches!(*self, Self::Stack(_)) {
            *self = Self::Heap(vec![0; Self::DEFAULT_HEAP_BUFFER_SIZE]);
        }

        match self {
            Self::Stack(inline) => inline.as_mut_slice(),
            Self::Heap(heap) => heap.as_mut_slice(),
        }
    }
}
/// Converts UTF-8 text into another encoding and hands the encoded bytes to a
/// callback piece by piece.
pub(crate) struct TextEncoder {
    /// `encoding_rs` encoder for the target encoding.
    encoder: Encoder,
    /// Scratch space the encoder writes into.
    buffer: Buffer,
}

impl TextEncoder {
    /// Creates an encoder for `encoding`, starting with inline scratch storage.
    ///
    /// Debug builds assert that `encoding` is not UTF-8 and that it is
    /// ASCII-compatible.
    #[inline]
    pub fn new(encoding: &'static Encoding) -> Self {
        debug_assert!(encoding != UTF_8);
        debug_assert!(encoding.is_ascii_compatible());
        let encoder = encoding.new_encoder();
        Self {
            encoder,
            buffer: Buffer::Stack([0; 63]),
        }
    }

    /// Encodes `content` and passes the resulting bytes to `output_handler`,
    /// possibly over several calls.
    ///
    /// ASCII runs are forwarded directly as they appear in `content`; everything
    /// else goes through the encoder and the scratch buffer. The handler is never
    /// called with an empty slice.
    #[inline(never)]
    pub fn encode(&mut self, mut content: &str, output_handler: &mut dyn FnMut(&[u8])) {
        loop {
            debug_assert!(!self.encoder.has_pending_state());

            // Forward the leading ASCII run without involving the encoder.
            let ascii_len = Encoding::ascii_valid_up_to(content.as_bytes());
            if let Some((ascii_prefix, tail)) = content.split_at_checked(ascii_len) {
                if !ascii_prefix.is_empty() {
                    output_handler(ascii_prefix.as_bytes());
                }
                if tail.is_empty() {
                    return;
                }
                content = tail;
            }

            // Encode as much of the rest as fits into the scratch buffer.
            let scratch = self.buffer.buffer_for_length(content.len());
            let (status, consumed, produced, _) =
                self.encoder.encode_from_utf8(content, scratch, false);

            if produced != 0 && produced <= scratch.len() {
                output_handler(&scratch[..produced]);
            }

            // Stop once nothing is left (or the reported offset is unusable).
            match content.get(consumed..) {
                Some(rest) if !rest.is_empty() => content = rest,
                _ => return,
            }

            match status {
                CoderResult::InputEmpty => {
                    debug_assert!(content.is_empty());
                    return;
                }
                CoderResult::OutputFull => {
                    // Progress was made: just go around again with the same buffer.
                    if produced == 0 {
                        if scratch.len() >= Buffer::DEFAULT_HEAP_BUFFER_SIZE {
                            // Even the large buffer yielded nothing. Release builds
                            // give up here and drop the remaining content.
                            debug_assert!(false, "encoding_rs stalled");
                            return;
                        }
                        // The buffer was too small for any output: retry with a
                        // heap buffer.
                        self.buffer = Buffer::Heap(vec![0; Buffer::DEFAULT_HEAP_BUFFER_SIZE]);
                    }
                }
            }
        }
    }
}
/// Returns `true` when the two most significant bits of `b` are `10`, i.e. `b`
/// is in `0x80..=0xBF` (the UTF-8 continuation-byte range).
const fn is_continuation_byte(b: u8) -> bool {
    b & 0b1100_0000 == 0b1000_0000
}
/// Counts the leading `1` bits of `b`.
///
/// For the lead byte of a multi-byte UTF-8 sequence this equals the sequence
/// length in bytes. Note that ASCII bytes yield `0` and continuation bytes
/// yield `1`, so the result is only a width when `b` is a multi-byte lead byte.
const fn utf8_width(b: u8) -> u8 {
    // Leading ones of `b` are the leading zeros of its complement.
    (!b).leading_zeros() as u8
}
/// Reassembles UTF-8 text from byte writes that may cut a multi-byte sequence
/// into several fragments.
pub(crate) struct IncompleteUtf8Resync {
    /// Bytes of a sequence that was still incomplete when the previous write ended.
    char_bytes: [u8; 4],
    /// Number of leading bytes of `char_bytes` in use; `0` means nothing is pending.
    char_len: u8,
}

impl IncompleteUtf8Resync {
    /// Creates a resynchronizer with no pending bytes.
    pub const fn new() -> Self {
        Self {
            char_len: 0,
            char_bytes: [0; 4],
        }
    }

    /// Extracts the next piece of valid UTF-8 from `content`.
    ///
    /// Returns that piece together with the bytes of `content` that have not
    /// been examined yet. The returned text may borrow either from `content`
    /// or from the internal pending-bytes buffer.
    ///
    /// # Errors
    ///
    /// Returns [`Utf8Error`] when the bytes seen so far cannot be valid UTF-8.
    /// A sequence that is merely cut off at the end of `content` is not an
    /// error; it is stored and continued by the next call.
    pub fn utf8_bytes_to_slice<'buf, 'src: 'buf>(
        &'buf mut self,
        mut content: &'src [u8],
    ) -> Result<(&'buf str, &'src [u8]), Utf8Error> {
        if self.char_len == 0 {
            // Nothing pending: validate the whole input in one go.
            let err = match std::str::from_utf8(content) {
                Ok(text) => return Ok((text, b"")),
                Err(err) => err,
            };
            // A known error length means genuinely invalid bytes, not a
            // sequence truncated at the end of the input.
            if err.error_len().is_some() {
                return Err(Utf8Error);
            }
            let (valid, truncated) = content
                .split_at_checked(err.valid_up_to())
                .ok_or(Utf8Error)?;
            // Stash the truncated tail so the next call can complete it.
            let stash = self
                .char_bytes
                .get_mut(..truncated.len())
                .ok_or(Utf8Error)?;
            stash.copy_from_slice(truncated);
            self.char_len = truncated.len() as u8;
            let valid = std::str::from_utf8(valid).map_err(|_| Utf8Error)?;
            return Ok((valid, b""));
        }

        // A sequence is pending: append continuation bytes while there is room.
        let mut must_emit_now = false;
        while let Some((&byte, rest)) = content.split_first() {
            let slot = if is_continuation_byte(byte) {
                self.char_bytes.get_mut(self.char_len as usize)
            } else {
                None
            };
            let Some(slot) = slot else {
                // Not a continuation byte, or no room left in the buffer: the
                // pending bytes must be validated as they are.
                must_emit_now = true;
                break;
            };
            *slot = byte;
            self.char_len += 1;
            content = rest;
        }

        let has_full_width = self.char_len >= utf8_width(self.char_bytes[0]);
        if !(must_emit_now || has_full_width) {
            // The input ran out before the sequence was complete.
            debug_assert!(content.is_empty());
            return Ok(("", b""));
        }

        let char_buf = self
            .char_bytes
            .get(..self.char_len as usize)
            .ok_or(Utf8Error)?;
        // Reset before validating so a failed sequence does not stay pending.
        self.char_len = 0;
        let ch = std::str::from_utf8(char_buf).map_err(|_| Utf8Error)?;
        Ok((ch, content))
    }

    /// Drops any pending incomplete sequence.
    ///
    /// Returns `true` if there was one, `false` if nothing was pending.
    pub fn discard_incomplete(&mut self) -> bool {
        let had_pending = self.char_len > 0;
        if had_pending {
            self.char_len = 0;
        }
        had_pending
    }

    /// Feeds `content` through the resynchronizer, calling `flush` with every
    /// non-empty piece of valid text in order.
    ///
    /// # Errors
    ///
    /// Returns [`Utf8Error`] as soon as invalid UTF-8 is detected; pieces
    /// flushed before that point have already been delivered.
    pub fn write_utf8_chunk(
        &mut self,
        mut content: &[u8],
        mut flush: impl FnMut(&str),
    ) -> Result<(), Utf8Error> {
        loop {
            if content.is_empty() {
                return Ok(());
            }
            let (text, unexamined) = self.utf8_bytes_to_slice(content)?;
            if !text.is_empty() {
                flush(text);
            }
            content = unexamined;
        }
    }
}
/// Checks that `is_continuation_byte` and `utf8_width` classify every byte of a
/// sample that mixes 4-, 2- and 3-byte UTF-8 sequences.
#[test]
fn chars() {
    // Written with escapes so the sample cannot be damaged by re-encoding the
    // source file: one 4-byte char, one 2-byte char, then seven 3-byte chars.
    let sample = "\u{1F408}\u{B0}\u{6587}\u{5B57}\u{5316}\u{3051}\u{3057}\u{306A}\u{3044}";
    let boundaries = sample
        .as_bytes()
        .iter()
        .map(|&ch| {
            if is_continuation_byte(ch) {
                '.'
            } else {
                // Lead bytes are rendered as their sequence length.
                (b'0' + utf8_width(ch)) as char
            }
        })
        .collect::<String>();
    assert_eq!("4...2.3..3..3..3..3..3..3..", boundaries);
}