use std::{fmt, io, slice, str};
use encoding_rs::Decoder;
use super::{util, MalformedError};
/// An [`io::Read`] adapter that pulls encoded bytes from a buffered reader and
/// yields them re-encoded as UTF-8 using an `encoding_rs` [`Decoder`].
pub struct DecodingReader<R> {
/// Source of the raw, still-encoded bytes.
reader: BufReadWithFallbackBuffer<R>,
/// The active decoder; set to `None` once `close_decoder` has taken and finished it.
decoder: Option<Decoder>,
/// Decoded output that did not fit into the caller's buffer; `read_impl` drains this
/// before doing anything else.
fallback_buf: util::MiniBuffer,
/// A malformed-input error that was detected together with some valid output and is
/// therefore reported on a later read instead.
deferred_error: Option<MalformedError>,
}
impl<R: io::BufRead> DecodingReader<R> {
pub fn new(reader: R, decoder: Decoder) -> Self {
Self {
reader: reader.into(),
decoder: Some(decoder),
fallback_buf: Default::default(),
deferred_error: None,
}
}
/// Returns a shared reference to the wrapped source reader.
pub fn reader_ref(&self) -> &R {
self.reader.get_ref()
}
/// Returns a shared reference to the decoder, or `None` once the decoder has been
/// closed (it is taken out of `self` by `close_decoder`).
pub fn decoder_ref(&self) -> Option<&Decoder> {
self.decoder.as_ref()
}
pub fn take_reader(self) -> (R, DecodingReader<impl io::BufRead>) {
let (reader, remainder) = self.reader.take_inner();
(
reader,
DecodingReader {
reader: remainder,
decoder: self.decoder,
fallback_buf: self.fallback_buf,
deferred_error: self.deferred_error,
},
)
}
/// Returns a view of this reader that reads with `LOSSY = false` and `FUSED = false`:
/// malformed input is reported as an error, and reaching the end of the currently
/// available input yields `Ok(0)` without closing the decoder.
pub fn unfused(&mut self) -> impl io::Read + '_ {
VariantReader::<'_, _, false, false> { inner: self }
}
/// Returns a view of this reader that reads with `LOSSY = true` and `FUSED = true`:
/// decoding goes through `Decoder::decode_to_utf8` (the replacing variant) instead of
/// failing, and the decoder is closed when the input runs out.
pub fn lossy(&mut self) -> impl io::Read + '_ {
VariantReader::<'_, _, true, true> { inner: self }
}
/// Returns a view of this reader that reads with `LOSSY = true` and `FUSED = false`:
/// decoding goes through `Decoder::decode_to_utf8` (the replacing variant), and
/// reaching the end of the currently available input yields `Ok(0)` without closing
/// the decoder.
pub fn lossy_unfused(&mut self) -> impl io::Read + '_ {
VariantReader::<'_, _, true, false> { inner: self }
}
/// Shared implementation behind every `read` flavour of this reader.
///
/// * `LOSSY` selects `Decoder::decode_to_utf8` (no malformed-input errors) over
///   `Decoder::decode_to_utf8_without_replacement` (malformed input becomes an error).
/// * `FUSED` makes the reader close its decoder when the source has no more bytes or
///   a decode step produced no output; otherwise `Ok(0)` is returned in those cases
///   and the decoder is kept.
///
/// Returns the number of UTF-8 bytes written to `buf`.
fn read_impl<const LOSSY: bool, const FUSED: bool>(
    &mut self,
    buf: &mut [u8],
) -> io::Result<usize> {
    // Output left over from an earlier call always goes out first.
    if !self.fallback_buf.is_empty() {
        return Ok(self.fallback_buf.read_to_slice(buf));
    }

    // Next, settle an error that was held back while valid output was delivered.
    if let Some(pending) = self.deferred_error.take() {
        if !LOSSY {
            return Err(pending.wrap());
        }
        // A lossy reader turns the held-back error into one replacement character.
        const REPL: &[u8] = "\u{FFFD}".as_bytes();
        let emitted = if buf.len() < REPL.len() {
            // Too small for the whole character: stage it and hand out what fits.
            self.fallback_buf.fill_from_slice(REPL);
            self.fallback_buf.read_to_slice(buf)
        } else {
            buf[..REPL.len()].copy_from_slice(REPL);
            REPL.len()
        };
        return Ok(emitted);
    }

    // Nothing to do without a decoder or without room for output.
    if self.decoder.is_none() || buf.is_empty() {
        return Ok(0);
    }

    debug_assert!(self.fallback_buf.is_empty());
    debug_assert!(self.deferred_error.is_none());
    debug_assert!(self.decoder.is_some() && !buf.is_empty());

    let src = self.reader.fill_buf()?;
    if src.is_empty() {
        if FUSED {
            return self.close_decoder::<LOSSY>(buf);
        }
        return Ok(0);
    }

    let decoder = self.decoder.as_mut().unwrap();
    let written = if LOSSY {
        let (_, consumed, produced) =
            decode_with_fallback_buf(buf, &mut self.fallback_buf, |dst| {
                // The fourth tuple element of `decode_to_utf8` is not needed here.
                let (status, read, wrote, _) = decoder.decode_to_utf8(src, dst, false);
                (status, read, wrote)
            });
        self.reader.consume(consumed);
        produced
    } else {
        let (status, consumed, produced) =
            decode_with_fallback_buf(buf, &mut self.fallback_buf, |dst| {
                decoder.decode_to_utf8_without_replacement(src, dst, false)
            });
        self.reader.consume(consumed);
        if matches!(status, encoding_rs::DecoderResult::Malformed(..)) {
            // With nothing decoded there is no output to deliver first, so fail now.
            if produced == 0 {
                return Err(MalformedError::new().wrap());
            }
            // Otherwise deliver the valid output and report the error on a later read.
            self.deferred_error = Some(MalformedError::new());
        }
        produced
    };

    debug_assert!(self.check_utf8_guarantee(&buf[..written]).is_ok());

    if FUSED && written == 0 {
        return self.close_decoder::<LOSSY>(buf);
    }
    Ok(written)
}
/// Takes the decoder out of `self` and runs one final decode step with empty input
/// and the `last` flag set, writing whatever that produces into `buf`.
///
/// With `LOSSY = false`, a malformed result becomes an immediate error when no output
/// was produced, or a deferred error when some output was. After this call
/// `self.decoder` is `None`.
fn close_decoder<const LOSSY: bool>(&mut self, buf: &mut [u8]) -> io::Result<usize> {
    debug_assert!(self.fallback_buf.is_empty());
    debug_assert!(self.deferred_error.is_none());
    debug_assert!(self.decoder.is_some() && !buf.is_empty());

    let mut decoder = self.decoder.take().unwrap();
    let written = if LOSSY {
        let (_, _, produced) = decode_with_fallback_buf(buf, &mut self.fallback_buf, |dst| {
            // The fourth tuple element of `decode_to_utf8` is not needed here.
            let (status, read, wrote, _) = decoder.decode_to_utf8(&[], dst, true);
            (status, read, wrote)
        });
        produced
    } else {
        let (status, _, produced) =
            decode_with_fallback_buf(buf, &mut self.fallback_buf, |dst| {
                decoder.decode_to_utf8_without_replacement(&[], dst, true)
            });
        if matches!(status, encoding_rs::DecoderResult::Malformed(..)) {
            if produced == 0 {
                return Err(MalformedError::new().wrap());
            }
            // Deliver the output first; the error surfaces on a later read.
            self.deferred_error = Some(MalformedError::new());
        }
        produced
    };

    debug_assert!(self.check_utf8_guarantee(&buf[..written]).is_ok());
    Ok(written)
}
/// Checks that `buf_written` followed by the bytes still held in `fallback_buf` forms
/// valid UTF-8.
///
/// Only called from `debug_assert!`s in this file.
fn check_utf8_guarantee(&self, buf_written: &[u8]) -> Result<(), str::Utf8Error> {
    if self.fallback_buf.is_empty() {
        return str::from_utf8(buf_written).map(|_| ());
    }
    // A character may be split between the caller's buffer and the fallback buffer,
    // so the two parts have to be validated as one sequence.
    let pending: &[u8] = self.fallback_buf.as_ref();
    let combined = [buf_written, pending].concat();
    str::from_utf8(&combined).map(|_| ())
}
}
impl<R: io::BufRead> ReadToStringAdapter for DecodingReader<R> {
    /// Returns `true` when the fallback buffer is empty or its contents are valid
    /// UTF-8 on their own.
    fn has_read_valid_utf8(&self) -> bool {
        if self.fallback_buf.is_empty() {
            return true;
        }
        let pending: &[u8] = self.fallback_buf.as_ref();
        str::from_utf8(pending).is_ok()
    }
}
/// The default reading mode: `LOSSY = false` (malformed input is an error) and
/// `FUSED = true` (the decoder is closed when the input runs out).
impl<R: io::BufRead> io::Read for DecodingReader<R> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.read_impl::<false, true>(buf)
}
// Overridden to route through `read_to_string_impl` instead of the trait's default.
fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
read_to_string_impl(self, buf)
}
}
impl<R: fmt::Debug> fmt::Debug for DecodingReader<R> {
    /// Formats every field of the reader; the decoder is shown through a small
    /// wrapper that prints only its encoding.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        /// Debug view of a `Decoder` that prints its `encoding()` and nothing else.
        struct DecoderSummary<'a>(&'a Decoder);

        impl fmt::Debug for DecoderSummary<'_> {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                let Self(decoder) = self;
                let mut out = f.debug_struct("Decoder");
                out.field("encoding()", decoder.encoding());
                out.finish_non_exhaustive()
            }
        }

        let decoder = self.decoder.as_ref().map(DecoderSummary);
        let mut out = f.debug_struct("DecodingReader");
        out.field("reader", &self.reader);
        out.field("decoder", &decoder);
        out.field("fallback_buf", &self.fallback_buf);
        out.field("deferred_error", &self.deferred_error);
        out.finish()
    }
}
/// A borrowed view of a [`DecodingReader`] whose `io::Read` implementation forwards
/// to `read_impl` with the `LOSSY` / `FUSED` flags fixed at the type level.
struct VariantReader<'a, R, const LOSSY: bool, const FUSED: bool> {
/// The reader whose state (decoder, buffers, deferred error) is used and updated.
inner: &'a mut DecodingReader<R>,
}
impl<R: io::BufRead, const LOSSY: bool, const FUSED: bool> ReadToStringAdapter
for VariantReader<'_, R, LOSSY, FUSED>
{
// The answer depends only on the wrapped reader's state, not on the variant flags.
fn has_read_valid_utf8(&self) -> bool {
self.inner.has_read_valid_utf8()
}
}
impl<R: io::BufRead, const LOSSY: bool, const FUSED: bool> io::Read
for VariantReader<'_, R, LOSSY, FUSED>
{
// Forwards to the shared implementation with this variant's flags.
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.read_impl::<LOSSY, FUSED>(buf)
}
// Overridden to route through `read_to_string_impl` instead of the trait's default.
fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
read_to_string_impl(self, buf)
}
}
/// Runs `decode` against `dst_buf`, or against `fallback_buf` when `dst_buf` is not
/// larger than the fallback buffer's free space.
///
/// `decode` receives the output slice and returns `(result, consumed, written)`.
/// When the fallback buffer is used, the decoded bytes are staged there and as many
/// as fit are then copied to `dst_buf`; the returned `written` is the number of bytes
/// that actually reached `dst_buf`. `fallback_buf` must be empty on entry.
fn decode_with_fallback_buf<T>(
    dst_buf: &mut [u8],
    fallback_buf: &mut util::MiniBuffer,
    mut decode: impl FnMut(&mut [u8]) -> (T, usize, usize),
) -> (T, usize, usize) {
    debug_assert!(fallback_buf.is_empty());

    // Large enough destination: decode straight into it.
    if dst_buf.len() > fallback_buf.unfilled().len() {
        return decode(dst_buf);
    }

    // Small destination: decode into the fallback buffer, then copy out what fits.
    let (result, consumed, staged) = decode(fallback_buf.unfilled());
    let written = if staged > 0 {
        fallback_buf.advance(staged);
        fallback_buf.read_to_slice(dst_buf)
    } else {
        0
    };
    (result, consumed, written)
}
/// Readers that `read_to_string_impl` can append to a `String`.
trait ReadToStringAdapter: io::Read {
/// Queried by `read_to_string_impl` after reading to decide whether the appended
/// bytes are kept. The implementation in this file reports whether its fallback
/// buffer is empty or holds valid UTF-8.
fn has_read_valid_utf8(&self) -> bool;
}
/// Reads `reader` to its end and appends the bytes to `buf` without re-validating
/// them as UTF-8.
///
/// Skipping validation is only sound if the appended bytes start and end on character
/// boundaries. `has_read_valid_utf8` is therefore consulted both before and after
/// reading.
///
/// # Errors
///
/// * `InvalidData` if the reader reports, before anything is read, that its pending
///   output is not valid UTF-8 (for example because an earlier small `read` split a
///   multi-byte character). Nothing is read and `buf` is left unchanged.
/// * Any error from `read_to_end`. Bytes read before the error are kept in `buf` as
///   long as the reader still reports a valid state.
fn read_to_string_impl(
    reader: &mut impl ReadToStringAdapter,
    buf: &mut String,
) -> io::Result<usize> {
    /// Truncates the vector back to `len` on drop, so that bytes which were never
    /// confirmed as valid do not stay in the `String` (also on panic).
    struct PanicGuard<'a> {
        len: usize,
        inner: &'a mut Vec<u8>,
    }
    impl Drop for PanicGuard<'_> {
        fn drop(&mut self) {
            self.inner.truncate(self.len);
        }
    }

    // Pending output that begins in the middle of a character would be appended
    // verbatim and corrupt the `String`; the check after reading cannot see this
    // because the pending bytes have been drained by then.
    if !reader.has_read_valid_utf8() {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "stream did not contain valid UTF-8",
        ));
    }

    let mut g = PanicGuard {
        len: buf.len(),
        // SAFETY: the vector is only appended to. The guard truncates it back to the
        // last length known to hold valid UTF-8 unless the reader confirms below that
        // everything appended is valid, and the check above ensures the appended
        // bytes start on a character boundary.
        inner: unsafe { buf.as_mut_vec() },
    };
    let ret = reader.read_to_end(g.inner);
    if reader.has_read_valid_utf8() {
        // Keep everything that was appended, even if `ret` is an error.
        g.len = g.inner.len();
        ret
    } else {
        ret?;
        debug_assert!(false, "unreachable");
        Err(io::Error::new(
            io::ErrorKind::Other,
            "failed to read to string unexpectedly",
        ))
    }
}
/// Wraps a buffered reader together with a small side buffer that `fill_buf` uses
/// when the reader's own buffer holds only a small chunk.
#[derive(Debug, Default)]
struct BufReadWithFallbackBuffer<R> {
/// The wrapped reader.
inner: R,
/// Bytes moved out of `inner`; while non-empty, `fill_buf` returns these.
fallback_buf: util::MiniBuffer,
}
impl<R: io::BufRead> From<R> for BufReadWithFallbackBuffer<R> {
fn from(value: R) -> Self {
Self {
inner: value,
fallback_buf: Default::default(),
}
}
}
impl<R: io::BufRead> BufReadWithFallbackBuffer<R> {
/// Returns a shared reference to the wrapped reader.
fn get_ref(&self) -> &R {
&self.inner
}
fn take_inner(self) -> (R, BufReadWithFallbackBuffer<io::Empty>) {
(
self.inner,
BufReadWithFallbackBuffer {
inner: io::empty(),
fallback_buf: self.fallback_buf,
},
)
}
/// Returns the bytes currently available without consuming them.
///
/// The result is either the fallback buffer's contents or the inner reader's own
/// buffer, never a mix of both. Pair with `consume`.
fn fill_buf(&mut self) -> io::Result<&[u8]> {
// Bytes already held in the fallback buffer take priority. `fill_from_reader` is
// defined in `util`; presumably it pulls further bytes from `inner` into the buffer.
if !self.fallback_buf.is_empty() {
self.fallback_buf.fill_from_reader(&mut self.inner)?;
return Ok(self.fallback_buf.as_ref());
}
{
let buf = self.inner.fill_buf()?;
// Hand out the inner buffer directly when it is empty or holds more than the
// fallback buffer has room for.
if buf.is_empty() || buf.len() > self.fallback_buf.unfilled().len() {
// SAFETY: the pointer and length come from the slice `self.inner.fill_buf()` just
// returned, and the rebuilt slice is returned with the lifetime of `&mut self`, so
// `self.inner` cannot be modified while it is alive. Rebuilding it only detaches
// the borrow so that the path below may borrow `self.inner` again.
return Ok(unsafe { slice::from_raw_parts(buf.as_ptr(), buf.len()) });
}
}
// The inner buffer holds only a small chunk: go through the fallback buffer instead.
self.fallback_buf.fill_from_reader(&mut self.inner)?;
Ok(self.fallback_buf.as_ref())
}
/// Marks `amt` bytes previously returned by `fill_buf` as consumed.
///
/// Bytes are removed from the fallback buffer first; whatever part of `amt` exceeds
/// its length is forwarded to the inner reader's `consume`.
fn consume(&mut self, amt: usize) {
    let from_fallback = self.fallback_buf.len().min(amt);
    if from_fallback != 0 {
        self.fallback_buf.remove_front(from_fallback);
    }
    let from_inner = amt - from_fallback;
    self.inner.consume(from_inner);
}
}
#[cfg(test)]
mod tests {
use std::io::Read;
use super::DecodingReader;
/// Exercises every reader variant on input that ends with a dangling `0xe0` byte,
/// which the assertions below expect to be treated as malformed Shift_JIS.
#[test]
fn trailing_malformed_bytes() {
use encoding_rs::SHIFT_JIS as Enc;
let src: &[u8] = &[b'h', b'e', b'l', b'l', b'o', 0xe0];
// Default reader via `read_to_string`: the call fails but keeps "hello";
// afterwards every read reports end of input.
{
let mut reader = DecodingReader::new(src, Enc.new_decoder());
let mut dst = String::new();
assert!(reader.read_to_string(&mut dst).is_err());
assert_eq!(dst, "hello");
assert!(matches!(reader.read_to_string(&mut dst), Ok(0)));
assert!(matches!(reader.read_to_string(&mut dst), Ok(0)));
assert!(matches!(reader.read(&mut [0; 64]), Ok(0)));
assert!(matches!(reader.read(&mut [0; 64]), Ok(0)));
assert_eq!(dst, "hello");
assert!(matches!(
reader.take_reader().1.lossy().read_to_string(&mut dst),
Ok(0)
));
assert_eq!(dst, "hello");
}
// Default reader via `read`: the valid prefix is returned first, the error on
// the following call, then end of input.
{
let mut reader = DecodingReader::new(src, Enc.new_decoder());
let mut dst = [0; 64];
assert!(matches!(reader.read(&mut dst), Ok(5)));
assert_eq!(&dst[..5], b"hello");
assert!(reader.read(&mut dst[5..]).is_err());
assert!(matches!(reader.read(&mut dst[5..]), Ok(0)));
assert!(matches!(reader.read(&mut dst[5..]), Ok(0)));
assert_eq!(&dst[..5], b"hello");
assert!(matches!(
reader.take_reader().1.lossy().read(&mut dst[5..]),
Ok(0)
));
assert_eq!(&dst[..5], b"hello");
}
// Lossy reader: the dangling byte becomes U+FFFD (5 + 3 bytes in total).
{
let mut reader = DecodingReader::new(src, Enc.new_decoder());
let mut dst = String::new();
assert!(matches!(reader.lossy().read_to_string(&mut dst), Ok(8)));
assert_eq!(dst, "hello\u{FFFD}");
assert!(matches!(
reader.take_reader().1.lossy().read_to_string(&mut dst),
Ok(0)
));
assert_eq!(dst, "hello\u{FFFD}");
}
// Unfused reader: no error and no replacement while the decoder stays open;
// the pending byte only surfaces once a fused lossy reader finishes the stream.
{
let mut reader = DecodingReader::new(src, Enc.new_decoder());
let mut dst = String::new();
assert!(matches!(reader.unfused().read_to_string(&mut dst), Ok(5)));
assert_eq!(dst, "hello");
assert!(matches!(reader.unfused().read_to_string(&mut dst), Ok(0)));
assert!(matches!(reader.unfused().read_to_string(&mut dst), Ok(0)));
assert!(matches!(reader.unfused().read(&mut [0; 64]), Ok(0)));
assert!(matches!(reader.unfused().read(&mut [0; 64]), Ok(0)));
assert_eq!(dst, "hello");
assert!(matches!(
reader.take_reader().1.lossy().read_to_string(&mut dst),
Ok(3)
));
assert_eq!(dst, "hello\u{FFFD}");
}
// Lossy unfused reader: same expectations as the unfused case above.
{
let mut reader = DecodingReader::new(src, Enc.new_decoder());
let mut dst = String::new();
assert!(matches!(
reader.lossy_unfused().read_to_string(&mut dst),
Ok(5)
));
assert_eq!(dst, "hello");
assert!(matches!(
reader.lossy_unfused().read_to_string(&mut dst),
Ok(0)
));
assert!(matches!(
reader.lossy_unfused().read_to_string(&mut dst),
Ok(0)
));
assert!(matches!(reader.lossy_unfused().read(&mut [0; 64]), Ok(0)));
assert!(matches!(reader.lossy_unfused().read(&mut [0; 64]), Ok(0)));
assert_eq!(dst, "hello");
assert!(matches!(
reader.take_reader().1.lossy().read_to_string(&mut dst),
Ok(3)
));
assert_eq!(dst, "hello\u{FFFD}");
}
}
}