use std::cmp;
use std::io::{Error, ErrorKind, Read, Result};
use std::str;
use thiserror::Error;
/// Maximum number of bytes a single UTF-8 encoded code point can occupy.
///
/// Also used as the size of the reader's internal carry-over buffer and as the limit after which
/// a run of undecodable bytes is reported as invalid.
const MAX_UTF8_SEQUENCE_SIZE: usize = 4;
/// Error payload attached to an `std::io::Error` of kind `InvalidData` when the bytes being read
/// do not form valid UTF-8.
#[derive(Debug, Error)]
#[error("Invalid UTF-8 sequence")]
pub(crate) struct UTF8ReaderError;
/// A reader adapter that passes bytes through from `inner` while checking that they are UTF-8.
///
/// When a read ends in the middle of a multi-byte sequence, the bytes needed to finish that
/// sequence are fetched eagerly and parked in `buffer` until a later read hands them out.
#[derive(Debug)]
pub struct Utf8Reader<R> {
    /// The wrapped byte source.
    inner: R,
    /// Scratch space holding one UTF-8 sequence: the part already returned to the caller
    /// followed by the part still pending.
    buffer: [u8; MAX_UTF8_SEQUENCE_SIZE],
    /// Index of the first pending byte in `buffer`; bytes before it were already returned.
    buffer_start: usize,
    /// Index one past the last pending byte in `buffer`.
    buffer_end: usize,
}
impl<R> Utf8Reader<R> {
    /// Wraps `inner` in a reader whose carry-over buffer starts out empty.
    pub fn new(inner: R) -> Self {
        let buffer = [0; MAX_UTF8_SEQUENCE_SIZE];
        Self {
            buffer_end: 0,
            buffer_start: 0,
            buffer,
            inner,
        }
    }

    /// Returns the buffered bytes that have not been handed to a caller yet.
    fn buffer_to_read(&self) -> &[u8] {
        let pending = self.buffer_start..self.buffer_end;
        &self.buffer[pending]
    }

    /// Marks the next `amt` buffered bytes as consumed.
    fn advance(&mut self, amt: usize) {
        self.buffer_start = self.buffer_start + amt;
    }

    /// Copies as many pending bytes as fit into `buf`, consumes them, and returns how many
    /// were copied.
    fn read_from_buffer(&mut self, buf: &mut [u8]) -> Result<usize> {
        let pending = self.buffer_to_read();
        let copied = slice_copy(buf, pending);
        self.advance(copied);
        Ok(copied)
    }
}
impl<R> Utf8Reader<R>
where
    R: Read,
{
    /// Reads from the wrapped reader into `buf` and validates what arrived.
    ///
    /// If the data ends in an unfinished UTF-8 sequence, that tail is copied to the front of the
    /// internal buffer and the rest of the sequence is fetched right away, so it can be served
    /// by later reads. Returns the number of bytes placed in `buf`.
    ///
    /// # Errors
    ///
    /// Propagates errors from the wrapped reader and from the UTF-8 checks.
    fn read_from_inner(&mut self, buf: &mut [u8]) -> Result<usize> {
        let filled = self.inner.read(buf)?;
        let dangling = ending_incomplete_utf8_sequence(&buf[..filled])?;
        slice_copy(&mut self.buffer, dangling);
        self.read_into_buffer_until_utf8(dangling.len())
            .map(|()| filled)
    }

    /// Extends the first `start_index` bytes of the internal buffer until they form valid UTF-8,
    /// then marks everything from `start_index` onwards as pending.
    ///
    /// The pending range is only updated when the completion succeeds.
    fn read_into_buffer_until_utf8(&mut self, start_index: usize) -> Result<()> {
        self.buffer_end = read_until_utf8(&mut self.inner, &mut self.buffer, start_index)?;
        self.buffer_start = start_index;
        Ok(())
    }
}
impl<R> Read for Utf8Reader<R>
where
    R: Read,
{
    /// Reads bytes into `buf`, serving pending carry-over bytes before touching the wrapped
    /// reader.
    ///
    /// A single call returns either pending bytes or freshly read bytes, never both, so the
    /// call may fill less of `buf` than is available (as the `Read` contract allows).
    ///
    /// # Errors
    ///
    /// Returns whatever error the read from the wrapped reader reports. No error is returned
    /// by a call that delivered pending bytes.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        let read_from_buffer = self.read_from_buffer(buf)?;

        // Stop here once pending bytes were handed out: if the wrapped reader failed afterwards
        // (even with a retryable `Interrupted`), the `Err` would hide bytes that were already
        // copied into `buf` and consumed from the carry-over buffer.
        //
        // Also stop while bytes are still pending (for example when `buf` is empty), because
        // `read_from_inner` resets the pending range.
        if read_from_buffer > 0 || !self.buffer_to_read().is_empty() {
            return Ok(read_from_buffer);
        }

        self.read_from_inner(buf)
    }
}
fn read_until_utf8(
reader: &mut impl Read,
buffer: &mut [u8],
mut current_index: usize,
) -> Result<usize> {
while str::from_utf8(&buffer[..current_index]).is_err() {
if current_index >= MAX_UTF8_SEQUENCE_SIZE
|| reader.read(&mut buffer[current_index..current_index + 1])? == 0
{
return Err(Error::new(ErrorKind::InvalidData, UTF8ReaderError));
}
current_index += 1;
}
Ok(current_index)
}
/// Returns the length of the longest prefix of `bytes` that is valid UTF-8.
///
/// This is `bytes.len()` when the whole slice is valid.
fn utf8_up_to(bytes: &[u8]) -> usize {
    str::from_utf8(bytes).map_or_else(|err| err.valid_up_to(), |text| text.len())
}
fn ending_incomplete_utf8_sequence(bytes: &[u8]) -> Result<&[u8]> {
let valid_up_to = utf8_up_to(bytes);
let invalid_portion = &bytes[valid_up_to..];
if invalid_portion.len() >= MAX_UTF8_SEQUENCE_SIZE {
Err(Error::new(ErrorKind::InvalidData, UTF8ReaderError))
} else {
Ok(invalid_portion)
}
}
/// Copies elements from the front of `src` to the front of `dst`, stopping at the end of the
/// shorter slice, and returns how many elements were copied.
///
/// Elements of `dst` past the copied range are left untouched.
fn slice_copy<T>(dst: &mut [T], src: &[T]) -> usize
where
    T: Copy,
{
    let count = cmp::min(src.len(), dst.len());
    let (target, _) = dst.split_at_mut(count);
    let (source, _) = src.split_at(count);
    target.copy_from_slice(source);
    count
}
// Unit and property tests that drive `Utf8Reader` through its `Read` implementation, using
// in-memory `Cursor`s as the wrapped reader.
#[cfg(test)]
mod tests {
use super::*;
use proptest::prelude::*;
use core::str;
use std::io::Cursor;
// Empty input yields empty output without an error.
#[test]
fn test_read_empty() {
let mut empty_reader = Utf8Reader::new(Cursor::new(b""));
let mut buf = vec![];
empty_reader
.read_to_end(&mut buf)
.expect("read_to_end errored");
assert_eq!(buf, b"");
}
// Pure ASCII passes through unchanged.
#[test]
fn test_read_ascii_simple() {
let mut reader = Utf8Reader::new(Cursor::new(b"Hello, world!"));
let mut buf = vec![];
reader.read_to_end(&mut buf).expect("read_to_end errored");
assert_eq!(buf, b"Hello, world!");
}
// Multi-byte UTF-8 text passes through unchanged.
#[test]
fn test_read_utf8_simple() {
const HELLO_WORLD: &str = "你好世界!";
let mut reader = Utf8Reader::new(Cursor::new(HELLO_WORLD.as_bytes()));
let mut buf = vec![];
reader.read_to_end(&mut buf).expect("read_to_end errored");
assert_eq!(buf, HELLO_WORLD.as_bytes());
}
// Reading a four-byte character one byte at a time still returns every byte, in order, with
// exactly one byte per call.
#[test]
fn small_reads_splitting_sequence() {
let mut reader = Utf8Reader::new(Cursor::new("🙂".as_bytes()));
let mut buf = [0; MAX_UTF8_SEQUENCE_SIZE];
for i in 0..MAX_UTF8_SEQUENCE_SIZE {
let bytes_read = reader.read(&mut buf[i..i + 1]).expect("read errored");
assert_eq!(bytes_read, 1, "bytes read");
}
assert_eq!(&buf[..], "🙂".as_bytes());
}
// A byte that can never appear in UTF-8 (0xFF) makes the read fail.
#[test]
fn invalid_utf8_sequence() {
let mut reader = Utf8Reader::new(Cursor::new([0b11111111]));
let mut buf = [0; 1];
reader.read(&mut buf).expect_err("read should have errored");
}
// Valid text followed by a character whose last byte is missing fails once the input ends.
#[test]
fn invalid_utf8_sequence_at_end_of_reader() {
let mut read_buffer = Vec::from(b"Hello, world!");
let invalid_sequence = &"🙂".as_bytes()[..'🙂'.len_utf8() - 1];
read_buffer.extend(invalid_sequence);
let mut reader = Utf8Reader::new(Cursor::new(&read_buffer));
reader
.read_to_end(&mut vec![])
.expect_err("read should have errored");
}
proptest! {
// Any valid string round-trips through the reader byte for byte.
#[test]
fn read_arbitrary_string(s in any::<String>()) {
let mut buf = vec![];
let mut reader = Utf8Reader::new(Cursor::new(&s));
reader.read_to_end(&mut buf).unwrap();
assert_eq!(str::from_utf8(&buf).unwrap(), s);
}
// Any byte vector that is not valid UTF-8 makes `read_to_end` fail; valid inputs are discarded
// by `prop_assume!`.
#[test]
fn dont_read_arbitrary_nonstring(s in any::<Vec<u8>>()) {
prop_assume!(str::from_utf8(&s).is_err());
let mut buf = vec![];
let mut reader = Utf8Reader::new(Cursor::new(&s));
reader.read_to_end(&mut buf).unwrap_err();
}
}
}