use encoding_rs_io::DecodeReaderBytesBuilder;
use saphyr_parser::{BorrowedInput, BufferedInput, Input};
use std::cell::{Cell, RefCell};
use std::io::{self, BufReader, Error, Read};
use std::rc::Rc;
/// Type-erased byte source: any [`Read`] implementor stored behind a `Box`.
type DynReader<'a> = Box<dyn Read + 'a>;

/// A [`DynReader`] wrapped in a [`BufReader`].
type DynBufReader<'a> = BufReader<DynReader<'a>>;

/// Parser input backed by an arbitrary reader.
///
/// This is a newtype around a [`BufferedInput`] whose character source is a
/// [`ChunkedChars`] iterator over a boxed, buffered reader. Boxing the reader
/// keeps the concrete reader type out of this type's signature, leaving only
/// the lifetime `'a`.
pub struct ReaderInput<'a>(BufferedInput<ChunkedChars<DynBufReader<'a>>>);

impl<'a> ReaderInput<'a> {
    /// Wraps an already-constructed buffered input without touching it.
    #[inline]
    pub fn new(inner: BufferedInput<ChunkedChars<DynBufReader<'a>>>) -> Self {
        ReaderInput(inner)
    }
}
/// `Input` implementation for [`ReaderInput`].
///
/// Every method is a straight pass-through to the wrapped `BufferedInput`
/// (field `0`); this type adds no buffering or state of its own.
impl<'a> Input for ReaderInput<'a> {
    // --- buffer management -------------------------------------------------

    #[inline]
    fn lookahead(&mut self, count: usize) {
        self.0.lookahead(count);
    }

    #[inline]
    fn buflen(&self) -> usize {
        self.0.buflen()
    }

    #[inline]
    fn bufmaxlen(&self) -> usize {
        self.0.bufmaxlen()
    }

    // --- unbuffered reads --------------------------------------------------

    #[inline]
    fn raw_read_ch(&mut self) -> char {
        self.0.raw_read_ch()
    }

    #[inline]
    fn raw_read_non_breakz_ch(&mut self) -> Option<char> {
        self.0.raw_read_non_breakz_ch()
    }

    // --- consuming buffered characters -------------------------------------

    #[inline]
    fn skip(&mut self) {
        self.0.skip();
    }

    #[inline]
    fn skip_n(&mut self, count: usize) {
        self.0.skip_n(count);
    }

    // --- peeking at buffered characters ------------------------------------

    #[inline]
    fn peek(&self) -> char {
        self.0.peek()
    }

    #[inline]
    fn peek_nth(&self, n: usize) -> char {
        self.0.peek_nth(n)
    }
}
/// `BorrowedInput` implementation for [`ReaderInput`].
impl<'a> BorrowedInput<'a> for ReaderInput<'a> {
    /// Always returns `None`.
    ///
    /// `ReaderInput` holds a reader-fed character buffer rather than a
    /// `&'a str`, so it has no `'a`-lived text to hand out for any range.
    #[inline]
    fn slice_borrowed(&self, _start: usize, _end: usize) -> Option<&'a str> {
        None
    }
}
/// Shared slot in which [`ChunkedChars`] stores the error that stopped iteration.
///
/// `None` means no error has been recorded.
pub(crate) type ReaderInputError = Rc<RefCell<Option<Error>>>;

/// Shared running total of the bytes [`ChunkedChars`] has accepted so far.
pub(crate) type ReaderInputBytesRead = Rc<Cell<usize>>;

/// Iterator that decodes UTF-8 `char`s one at a time from a byte reader.
///
/// `Iterator::next` can only say "no more items", so failures are reported
/// out of band: the error is written to the shared `err` slot and `next`
/// returns `None`. Callers that need to tell a clean end of input from a
/// failure must inspect that slot.
///
/// Each call reads only the bytes of a single code point, so `R` should be
/// buffered if reads on the underlying source are expensive.
pub struct ChunkedChars<R: Read> {
    /// Optional cap on the total number of bytes accepted (`None` = unlimited).
    max_bytes: Option<usize>,
    /// Bytes accepted so far; shared so the creator can observe the count.
    bytes_read: ReaderInputBytesRead,
    /// The byte source; expected to yield UTF-8.
    reader: R,
    /// Where the error that ended iteration is recorded.
    pub(crate) err: ReaderInputError,
    /// Set once this iterator has recorded an error; afterwards `next` yields
    /// `None` without touching the reader again.
    failed: bool,
}

impl<R: Read> ChunkedChars<R> {
    /// Creates a decoder over `reader`.
    ///
    /// * `max_bytes` - upper bound on the total bytes accepted, or `None` for
    ///   no bound. The comparison uses the value already stored in
    ///   `bytes_read`, so a non-zero starting count eats into the budget.
    /// * `err` - shared slot that receives the error that stops iteration.
    /// * `bytes_read` - shared counter advanced by the encoded length of every
    ///   code point that passes the limit check.
    pub fn new(
        reader: R,
        max_bytes: Option<usize>,
        err: ReaderInputError,
        bytes_read: ReaderInputBytesRead,
    ) -> Self {
        Self {
            max_bytes,
            bytes_read,
            reader,
            err,
            failed: false,
        }
    }

    /// Records `error` in the shared slot, latches the failed state and
    /// returns `None` so error sites can `return self.fail(..)`.
    fn fail(&mut self, error: Error) -> Option<char> {
        self.err.replace(Some(error));
        self.failed = true;
        None
    }
}

impl<R: Read> Iterator for ChunkedChars<R> {
    type Item = char;

    /// Decodes and returns the next character.
    ///
    /// Returns `None` at a clean end of input (nothing is recorded in `err`),
    /// or after storing an error in `err` when:
    ///
    /// * the reader fails (`ErrorKind::Interrupted` is retried, not reported);
    /// * the leading byte cannot start a UTF-8 sequence (`InvalidData`);
    /// * the input ends inside a multi-byte sequence (`UnexpectedEof`);
    /// * accepting the code point would exceed `max_bytes` (`FileTooLarge`);
    /// * the collected bytes are not valid UTF-8 (`InvalidData`).
    ///
    /// Once an error has been recorded every later call returns `None`
    /// without reading. The `Iterator` contract lets callers poll again after
    /// `None`; without the latch such a call would resume decoding from the
    /// middle of the stream and could replace the original error.
    fn next(&mut self) -> Option<char> {
        if self.failed {
            return None;
        }

        let mut buf = [0u8; 4];

        // `read_exact` retries `Interrupted` itself and reports a clean end of
        // input as `UnexpectedEof`, which here simply means "no more chars".
        if let Err(e) = self.reader.read_exact(&mut buf[..1]) {
            return match e.kind() {
                io::ErrorKind::UnexpectedEof => None,
                _ => self.fail(e),
            };
        }

        // The leading byte determines the encoded length of the code point.
        let first = buf[0];
        let needed = if first < 0x80 {
            1
        } else if first & 0b1110_0000 == 0b1100_0000 {
            2
        } else if first & 0b1111_0000 == 0b1110_0000 {
            3
        } else if first & 0b1111_1000 == 0b1111_0000 {
            4
        } else {
            return self.fail(io::Error::new(
                io::ErrorKind::InvalidData,
                "invalid UTF-8 leading byte",
            ));
        };

        // Collect the continuation bytes. A manual loop (rather than
        // `read_exact`) keeps the dedicated mid-codepoint EOF message.
        let mut filled = 1;
        while filled < needed {
            match self.reader.read(&mut buf[filled..needed]) {
                Ok(0) => {
                    return self.fail(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        "unexpected EOF in middle of UTF-8 codepoint",
                    ));
                }
                Ok(n) => filled += n,
                // A transient interruption is not a failure: retry, exactly as
                // `read_exact` does for the leading byte.
                Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
                Err(e) => return self.fail(e),
            }
        }

        // Enforce the size limit before handing the character out. The counter
        // is only advanced for code points that fit.
        let new_total = self.bytes_read.get().saturating_add(needed);
        if let Some(limit) = self.max_bytes {
            if new_total > limit {
                return self.fail(io::Error::new(
                    io::ErrorKind::FileTooLarge,
                    format!("input size limit of {limit} bytes exceeded"),
                ));
            }
        }
        self.bytes_read.set(new_total);

        // Full validation (bad continuation bytes, overlong forms, surrogates)
        // is delegated to the standard library.
        match std::str::from_utf8(&buf[..needed]) {
            Ok(s) => s.chars().next(),
            Err(e) => self.fail(io::Error::new(io::ErrorKind::InvalidData, e)),
        }
    }
}
pub fn buffered_input_from_reader_with_limit<'a, R: Read + 'a>(
reader: R,
max_bytes: Option<usize>,
) -> (ReaderInput<'a>, ReaderInputError, ReaderInputBytesRead) {
let error: ReaderInputError = Rc::new(RefCell::new(None));
let bytes_read: ReaderInputBytesRead = Rc::new(Cell::new(0));
let input = buffered_input_from_reader_with_limit_shared(
reader,
max_bytes,
error.clone(),
bytes_read.clone(),
);
(input, error, bytes_read)
}
pub fn buffered_input_from_reader_with_limit_shared<'a, R: Read + 'a>(
reader: R,
max_bytes: Option<usize>,
error: ReaderInputError,
bytes_read: ReaderInputBytesRead,
) -> ReaderInput<'a> {
let decoder = DecodeReaderBytesBuilder::new()
.encoding(None) .bom_override(true)
.build(reader);
let br = BufReader::new(Box::new(decoder) as DynReader<'a>);
let char_iter = ChunkedChars::new(br, max_bytes, error, bytes_read);
ReaderInput::new(BufferedInput::new(char_iter))
}
#[cfg(test)]
mod tests {
    use crate::buffered_input::{ReaderInput, buffered_input_from_reader_with_limit};
    use saphyr_parser::{Event, Parser};
    use std::io::{Cursor, Read};

    /// Builds an unlimited [`ReaderInput`], discarding the shared error and
    /// byte-count handles.
    pub fn buffered_input_from_reader<'a, R: Read + 'a>(reader: R) -> ReaderInput<'a> {
        let (input, _error, _bytes_read) = buffered_input_from_reader_with_limit(reader, None);
        input
    }

    /// Runs the parser until it finishes or reports its first error, keeping
    /// only stream, document, sequence and scalar events.
    fn gather_core_events<'a>(parser: Parser<'a, ReaderInput<'a>>) -> Vec<Event<'a>> {
        parser
            .map_while(Result::ok)
            .map(|(event, _span)| event)
            .filter(|event| {
                matches!(
                    event,
                    Event::SequenceStart(_, _)
                        | Event::SequenceEnd
                        | Event::Scalar(..)
                        | Event::StreamStart
                        | Event::StreamEnd
                        | Event::DocumentStart(_)
                        | Event::DocumentEnd
                )
            })
            .collect()
    }

    /// Asserts that `events` contains both a sequence start and a sequence end.
    fn assert_sequence_delimiters(events: &[Event<'_>]) {
        let has_start = events
            .iter()
            .any(|e| matches!(e, Event::SequenceStart(_, _)));
        assert!(has_start, "no SequenceStart in events: {:?}", events);

        let has_end = events.iter().any(|e| matches!(e, Event::SequenceEnd));
        assert!(has_end, "no SequenceEnd in events: {:?}", events);
    }

    /// Collects the text of every scalar event, in order of appearance.
    fn scalar_values(events: &[Event<'_>]) -> Vec<String> {
        let mut values = Vec::new();
        for event in events {
            if let Event::Scalar(v, _, _, _) = event {
                values.push(v.to_string());
            }
        }
        values
    }

    #[test]
    fn buffered_input_handles_utf16le_bom() {
        // BOM followed by "---\n[1, 2]\n", encoded as little-endian UTF-16.
        let bytes: Vec<u8> = "\u{FEFF}---\n[1, 2]\n"
            .encode_utf16()
            .flat_map(|unit| unit.to_le_bytes())
            .collect();

        let input = buffered_input_from_reader(Cursor::new(bytes));
        let events = gather_core_events(Parser::new(input));

        assert_sequence_delimiters(&events);

        let scalars = scalar_values(&events);
        let found_both = ["1", "2"]
            .iter()
            .all(|want| scalars.iter().any(|got| got == want));
        assert!(
            found_both,
            "expected scalar elements '1' and '2', got {:?}",
            scalars
        );
    }

    #[test]
    fn buffered_input_handles_utf8_basic() {
        let yaml = "---\n[foo, bar]\n";

        let input = buffered_input_from_reader(Cursor::new(yaml.as_bytes()));
        let events = gather_core_events(Parser::new(input));

        assert_sequence_delimiters(&events);

        let scalars = scalar_values(&events);
        let found_both = ["foo", "bar"]
            .iter()
            .all(|want| scalars.iter().any(|got| got == want));
        assert!(
            found_both,
            "expected scalar elements 'foo' and 'bar', got {:?}",
            scalars
        );
    }
}