use std::collections::VecDeque;
use std::convert::Infallible;
use std::io::{self, BufReader, Read};
/// A source of characters that can be consumed one at a time or matched
/// against an expected string.
pub trait Reader {
/// The error produced when the underlying input cannot be read.
type Error: std::error::Error;
/// Reads the next character.
///
/// Returns `Ok(None)` once the input is exhausted.
fn read_char(&mut self) -> Result<Option<char>, Self::Error>;
/// Attempts to consume `s` from the upcoming input.
///
/// Returns `Ok(true)` and advances past `s` when the upcoming input
/// matches; returns `Ok(false)` otherwise. In the implementations in this
/// file, `case_sensitive == false` means ASCII case is ignored during the
/// comparison, and a failed match consumes nothing.
fn try_read_string(&mut self, s: &str, case_sensitive: bool) -> Result<bool, Self::Error>;
/// Returns the length of `c` in the encoding this reader operates on.
///
/// The implementations in this file return the UTF-8 length in bytes.
fn len_of_char_in_current_encoding(&self, c: char) -> usize;
}
/// Conversion of an input value into a [`Reader`].
///
/// `'a` bounds how long the produced reader is valid, which lets borrowed
/// inputs such as `&'a str` be turned into readers that borrow from them.
pub trait IntoReader<'a> {
/// The reader type this value converts into.
type Reader: Reader + 'a;
/// Consumes `self` and returns the corresponding reader.
fn into_reader(self) -> Self::Reader;
}
/// Anything that already is a [`Reader`] converts into itself unchanged.
impl<'a, R> IntoReader<'a> for R
where
    R: Reader + 'a,
{
    type Reader = R;

    fn into_reader(self) -> R {
        self
    }
}
/// Forwards every [`Reader`] method to the boxed reader, so boxed (including
/// unsized) readers can be used wherever a reader is expected.
impl<R> Reader for Box<R>
where
    R: Reader + ?Sized,
{
    type Error = R::Error;

    fn read_char(&mut self) -> Result<Option<char>, Self::Error> {
        (**self).read_char()
    }

    fn try_read_string(&mut self, s: &str, case_sensitive: bool) -> Result<bool, Self::Error> {
        (**self).try_read_string(s, case_sensitive)
    }

    fn len_of_char_in_current_encoding(&self, c: char) -> usize {
        (**self).len_of_char_in_current_encoding(c)
    }
}
/// A reader over a borrowed string slice.
pub struct StringReader<'a> {
    /// The complete input this reader was created from.
    input: &'a str,
    /// Iterator over the characters that have not been consumed yet.
    cursor: std::str::Chars<'a>,
    /// Byte offset into `input` of the next unconsumed character.
    pos: usize,
}

impl<'a> StringReader<'a> {
    /// Creates a reader positioned at the start of `input`.
    fn new(input: &'a str) -> Self {
        Self {
            pos: 0,
            cursor: input.chars(),
            input,
        }
    }
}
impl<'a> Reader for StringReader<'a> {
    /// Reading from an in-memory string cannot fail.
    type Error = Infallible;

    /// Returns the next character and advances the byte offset by its UTF-8
    /// length, or returns `Ok(None)` when the input is exhausted.
    fn read_char(&mut self) -> Result<Option<char>, Self::Error> {
        let next = self.cursor.next();
        if let Some(c) = next {
            self.pos += c.len_utf8();
        }
        Ok(next)
    }

    /// Consumes `s1` if the input at the current offset equals it, optionally
    /// ignoring ASCII case. Nothing is consumed when there is no match.
    fn try_read_string(&mut self, s1: &str, case_sensitive: bool) -> Result<bool, Self::Error> {
        let end = self.pos + s1.len();
        // `get` yields `None` when the range runs past the end of the input
        // or does not fall on character boundaries; both count as no match.
        let matched = match self.input.get(self.pos..end) {
            Some(candidate) => {
                candidate == s1 || (!case_sensitive && candidate.eq_ignore_ascii_case(s1))
            }
            None => false,
        };
        if matched {
            self.pos = end;
            // Re-seat the character iterator just past the matched text.
            self.cursor = self.input[end..].chars();
        }
        Ok(matched)
    }

    fn len_of_char_in_current_encoding(&self, c: char) -> usize {
        c.len_utf8()
    }
}
/// A string slice is read through a [`StringReader`] borrowing it.
impl<'a> IntoReader<'a> for &'a str {
    type Reader = StringReader<'a>;

    fn into_reader(self) -> StringReader<'a> {
        StringReader::new(self)
    }
}
/// A borrowed `String` is read through a [`StringReader`] over its contents.
impl<'a> IntoReader<'a> for &'a String {
    type Reader = StringReader<'a>;

    fn into_reader(self) -> StringReader<'a> {
        let contents: &'a str = self;
        StringReader::new(contents)
    }
}
/// Size in bytes of the internal byte buffer of [`BufReadReader`].
const BUF_SIZE: usize = 8 * 1024;

/// A reader that decodes UTF-8 characters from any [`Read`] implementation.
///
/// Bytes are pulled from the underlying reader in chunks, decoded, and queued
/// as characters. Invalid UTF-8 is reported as an [`io::Error`] of kind
/// [`io::ErrorKind::InvalidData`] after the characters preceding it have been
/// queued.
pub struct BufReadReader<R: Read> {
    /// The underlying byte source.
    reader: R,
    /// Raw bytes obtained from `reader`.
    buffer: [u8; BUF_SIZE],
    /// Number of bytes in `buffer` that hold data (end of the filled region).
    read: usize,
    /// Index into `buffer` of the first byte that has not been decoded yet.
    pos: usize,
    /// Decoded characters that have not been handed out yet.
    chars: VecDeque<char>,
    /// A pending decoding error, to be reported once `chars` is drained.
    error: Option<io::Error>,
    /// Set once the underlying reader has reported end of input.
    eof: bool,
}

impl<R: Read> BufReadReader<R> {
    /// Creates a reader that decodes characters from `reader`.
    pub fn new(reader: R) -> Self {
        BufReadReader {
            reader,
            buffer: [0; BUF_SIZE],
            read: 0,
            pos: 0,
            chars: VecDeque::new(),
            error: None,
            eof: false,
        }
    }

    /// Decodes more input into `self.chars`.
    ///
    /// Must only be called while no error is pending and end of input has not
    /// been reached. After a successful return at least one of the following
    /// holds: characters were queued, `self.error` is set, or `self.eof` is
    /// set.
    ///
    /// A multi-byte sequence that is merely *incomplete* at the end of the
    /// buffered bytes is not an error yet: the remaining bytes may arrive with
    /// the next read (short reads, or a character straddling the buffer
    /// boundary). Such a tail is moved to the front of the buffer and
    /// completed with further input; it is only reported as invalid when the
    /// underlying reader is exhausted.
    ///
    /// # Errors
    ///
    /// Returns any error produced by the underlying reader. Invalid UTF-8 is
    /// not returned from here but stored in `self.error`.
    #[inline]
    fn read(&mut self) -> Result<(), io::Error> {
        debug_assert!(!self.eof);
        debug_assert!(self.error.is_none());

        if self.pos == self.read {
            // Everything buffered has been decoded: refill from the start.
            self.read = match self.reader.read(&mut self.buffer)? {
                0 => {
                    self.eof = true;
                    return Ok(());
                }
                n => n,
            };
            self.pos = 0;
        }

        loop {
            let unprocessed = &self.buffer[self.pos..self.read];
            let err = match std::str::from_utf8(unprocessed) {
                Ok(decoded) => {
                    self.chars.extend(decoded.chars());
                    self.pos = self.read;
                    return Ok(());
                }
                Err(err) => err,
            };

            let valid_len = err.valid_up_to();
            // SAFETY: `Utf8Error::valid_up_to` guarantees that the bytes up to
            // that index are valid UTF-8.
            let valid_str = unsafe { std::str::from_utf8_unchecked(&unprocessed[..valid_len]) };
            self.chars.extend(valid_str.chars());
            self.pos += valid_len;

            match err.error_len() {
                Some(error_len) => {
                    // Definitely invalid bytes: report them and resume
                    // decoding right after them on the next call.
                    self.error = Some(io::Error::new(io::ErrorKind::InvalidData, err));
                    self.pos += error_len;
                    return Ok(());
                }
                // Progress was made; the incomplete tail is handled by the
                // next call so that no I/O error can follow queued output.
                None if valid_len > 0 => return Ok(()),
                None => {}
            }

            // Only an incomplete sequence is buffered. Move it to the front
            // and try to complete it. The state is updated before reading so
            // that it stays consistent if the read fails.
            let tail_len = self.read - self.pos;
            self.buffer.copy_within(self.pos..self.read, 0);
            self.pos = 0;
            self.read = tail_len;
            match self.reader.read(&mut self.buffer[tail_len..])? {
                0 => {
                    // The input really ends in the middle of a character.
                    self.error = Some(io::Error::new(io::ErrorKind::InvalidData, err));
                    self.eof = true;
                    self.pos = self.read;
                    return Ok(());
                }
                n => self.read += n,
            }
        }
    }
}
impl<R: Read> Reader for BufReadReader<R> {
    type Error = io::Error;

    /// Returns the next decoded character.
    ///
    /// Queued characters are handed out first, then a pending decoding error
    /// (once), then `Ok(None)` at end of input.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying reader fails or the input contains
    /// invalid UTF-8.
    fn read_char(&mut self) -> Result<Option<char>, Self::Error> {
        if let Some(c) = self.chars.pop_front() {
            return Ok(Some(c));
        }
        if let Some(error) = self.error.take() {
            return Err(error);
        }
        if self.eof {
            return Ok(None);
        }

        // Nothing buffered: decode more input and look again.
        self.read()?;

        if let Some(c) = self.chars.pop_front() {
            return Ok(Some(c));
        }
        if let Some(error) = self.error.take() {
            return Err(error);
        }
        debug_assert!(self.eof);
        Ok(None)
    }

    /// Consumes `s1` if the upcoming characters equal it, optionally ignoring
    /// ASCII case. Nothing is consumed when there is no match.
    ///
    /// Returns `Ok(false)` without consuming anything when a decoding error
    /// or the end of input occurs before enough characters are available;
    /// a pending decoding error is left in place for `read_char` to report.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying reader fails.
    fn try_read_string(&mut self, s1: &str, case_sensitive: bool) -> Result<bool, Self::Error> {
        debug_assert!(!s1.contains('\r'));
        debug_assert!(!s1.contains('\n'));
        debug_assert!(s1.len() <= self.buffer.len());

        // The queue holds characters, so it must be measured and drained in
        // characters; `s1.len()` is a byte count and is larger than the
        // character count whenever `s1` contains non-ASCII text.
        let needed = s1.chars().count();

        while self.chars.len() < needed {
            if self.error.is_some() || self.eof {
                return Ok(false);
            }
            self.read()?;
        }

        let matched = self
            .chars
            .iter()
            .zip(s1.chars())
            .all(|(&actual, expected)| {
                if case_sensitive {
                    actual == expected
                } else {
                    actual.eq_ignore_ascii_case(&expected)
                }
            });

        if matched {
            self.chars.drain(..needed);
        }
        Ok(matched)
    }

    fn len_of_char_in_current_encoding(&self, c: char) -> usize {
        c.len_utf8()
    }
}
/// A standard-library `BufReader` is read through a [`BufReadReader`]
/// wrapping it.
impl<'a, R> IntoReader<'a> for BufReader<R>
where
    R: Read + 'a,
{
    type Reader = BufReadReader<Self>;

    fn into_reader(self) -> BufReadReader<Self> {
        BufReadReader::new(self)
    }
}
#[cfg(test)]
mod tests {
use std::io::{BufReader, ErrorKind};
use std::str::Utf8Error;
use super::{IntoReader, Reader};
// Checks that invalid UTF-8 is reported as an error at the position where it
// occurs and that reading continues with the bytes following it.
#[test]
fn buf_read_reader_invalid_utf8() {
// Input: a space, then 0xC3 (lead byte of a two-byte sequence) followed by
// 0x28 ('('), which is not a continuation byte.
let mut reader = BufReader::new(b" \xc3\x28" as &[u8]).into_reader();
// The valid character before the invalid byte is still delivered.
assert_eq!(reader.read_char().unwrap(), Some(' '));
// The invalid byte surfaces as an `InvalidData` error wrapping a `Utf8Error`.
let error = reader.read_char().unwrap_err();
assert!(matches!(error.kind(), ErrorKind::InvalidData));
error.into_inner().unwrap().downcast::<Utf8Error>().unwrap();
// Decoding resumes after the invalid byte, then the input ends.
assert_eq!(reader.read_char().unwrap(), Some('('));
assert_eq!(reader.read_char().unwrap(), None);
}
}