use std::{fmt, io, str};
use encoding_rs::Encoder;
use super::{buffer::DefaultBuffer, util, MalformedError, UnmappableError};
/// Minimum number of free bytes the output buffer must offer before each
/// encoding step.
///
/// `with_capacity` never creates a buffer smaller than this, `write` and
/// `write_str` reserve this much before encoding, and `write_str_inner` asserts
/// that the `BufferedWrite` implementation actually provided it.
/// NOTE(review): presumably sized so that any single encoded character (plus a
/// stateful-encoding escape sequence) always fits — confirm against the
/// worst-case per-character output of the supported encoders.
const MIN_BUF_SIZE: usize = 32;
/// A writer that transcodes UTF-8 input into another character encoding with an
/// `encoding_rs::Encoder`, placing the encoded bytes into the buffer `B`.
///
/// Problems found while encoding (unmappable characters, malformed or incomplete
/// UTF-8) are not returned by the call that finds them. They are stored in
/// `deferred_error` and reported by a later call. The current call returns the
/// number of bytes consumed up to that point.
pub struct EncodingWriter<B> {
    /// Destination that receives the encoded bytes.
    buffer: B,
    /// Encoder for the target encoding; may carry state between calls.
    encoder: Encoder,
    /// Error, or held-back partial UTF-8 sequence, for the next call to act on.
    deferred_error: Option<DefErr>,
}
impl<W: io::Write> EncodingWriter<DefaultBuffer<W>> {
    /// Creates an encoding writer over `writer` with a default-sized buffer:
    /// 512 bytes on ESP-IDF targets, 8 KiB everywhere else.
    pub fn new(writer: W, encoder: Encoder) -> Self {
        const DEFAULT_BUF_SIZE: usize = if cfg!(target_os = "espidf") { 512 } else { 8 * 1024 };
        Self::with_capacity(DEFAULT_BUF_SIZE, writer, encoder)
    }

    /// Creates an encoding writer over `writer` whose buffer holds `capacity`
    /// bytes. Requests smaller than `MIN_BUF_SIZE` are rounded up to it.
    pub fn with_capacity(capacity: usize, writer: W, encoder: Encoder) -> Self {
        let effective_capacity = capacity.max(MIN_BUF_SIZE);
        let buffer = DefaultBuffer::with_capacity(effective_capacity, writer);
        Self::with_buffer(buffer, encoder)
    }

    /// Borrows the underlying writer.
    pub fn writer_ref(&self) -> &W {
        self.buffer.get_ref()
    }

    /// Finishes encoding, flushes the buffer, and hands back the writer.
    ///
    /// Returns `Ok(writer)` when nothing remains to report. Otherwise returns the
    /// writer together with an iterator yielding, in this order:
    /// 1. the error from flushing the buffer, if any;
    /// 2. the bytes still left in the buffer, or the error that prevented
    ///    recovering them;
    /// 3. everything `finish` reported.
    pub fn unwrap_writer(
        self,
    ) -> Result<W, (W, impl Iterator<Item = io::Result<Vec<u8>>> + fmt::Debug)> {
        let (mut buffer, finish_report) = match self.finish() {
            Ok(buffer) => (buffer, None),
            Err((buffer, report)) => (buffer, Some(report)),
        };

        let mut leftovers = Vec::new();
        if let Err(flush_error) = buffer.flush_buffer() {
            leftovers.push(Err(flush_error));
        }

        let (writer, unwritten) = buffer.into_parts();
        match unwritten {
            Ok(bytes) => {
                if !bytes.is_empty() {
                    leftovers.push(Ok(bytes));
                }
            }
            Err(e) => leftovers.push(Err(io::Error::new(io::ErrorKind::Other, e))),
        }

        // `Option<impl Iterator>` flattened: appends `finish`'s report when present.
        leftovers.extend(finish_report.into_iter().flatten());

        if leftovers.is_empty() {
            Ok(writer)
        } else {
            Err((writer, leftovers.into_iter()))
        }
    }
}
impl<B: BufferedWrite> EncodingWriter<B> {
    /// Creates an encoding writer that encodes into `buffer` using `encoder`.
    pub fn with_buffer(buffer: B, encoder: Encoder) -> Self {
        Self {
            buffer,
            encoder,
            deferred_error: None,
        }
    }
    /// Borrows the output buffer.
    pub fn buffer_ref(&self) -> &B {
        &self.buffer
    }
    /// Borrows the encoder.
    pub fn encoder_ref(&self) -> &Encoder {
        &self.encoder
    }
    /// Finalizes the encoder and returns the buffer.
    ///
    /// Emits the encoder's end-of-stream output (e.g. the escape back to ASCII
    /// for ISO-2022-JP) into the buffer.
    ///
    /// Returns `Ok(buffer)` when nothing went wrong. Otherwise returns the buffer
    /// plus an iterator yielding, in order:
    /// - a pending deferred error, or a buffer-reservation error;
    /// - the end-of-stream bytes as `Ok(bytes)`, when an earlier error kept them
    ///   out of the buffer;
    /// - an `UnmappableError`, if the encoder reported one while finalizing.
    pub fn finish(
        mut self,
    ) -> Result<B, (B, impl Iterator<Item = io::Result<Vec<u8>>> + fmt::Debug)> {
        let mut seq = Vec::new();
        // Upper bound on what the encoder can emit when flushed with no input.
        let max_remainder_len = self
            .encoder
            .max_buffer_length_from_utf8_without_replacement(0)
            .unwrap();
        // A held-back incomplete UTF-8 sequence can no longer be completed, so
        // `realize_any_deferred_error` reports it as malformed here.
        if let Err(e) = self.realize_any_deferred_error() {
            seq.push(Err(e));
        } else if let Err(e) = self.buffer.try_reserve(max_remainder_len, None) {
            seq.push(Err(e));
        }
        let result = if seq.is_empty() {
            // Normal path: encode the end-of-stream output straight into the buffer.
            let unfilled = self.buffer.unfilled();
            assert!(
                unfilled.len() >= max_remainder_len,
                "illegal `BufferedWrite` implementation"
            );
            let (result, _consumed, written) = self
                .encoder
                .encode_from_utf8_without_replacement("", unfilled, true);
            self.buffer.advance(written);
            result
        } else {
            // Something already failed. Still finalize the encoder, but return the
            // trailing bytes to the caller instead of writing them to the buffer.
            let mut dst = vec![0; max_remainder_len];
            let (result, _consumed, written) = self
                .encoder
                .encode_from_utf8_without_replacement("", &mut dst, true);
            dst.truncate(written);
            if !dst.is_empty() {
                seq.push(Ok(dst));
            }
            result
        };
        if let encoding_rs::EncoderResult::Unmappable(c) = result {
            seq.push(Err(UnmappableError::new(c).wrap()));
        } else {
            debug_assert!(matches!(result, encoding_rs::EncoderResult::InputEmpty));
        }
        if seq.is_empty() {
            Ok(self.buffer)
        } else {
            Err((self.buffer, seq.into_iter()))
        }
    }
    /// Encodes as much of `buf` as fits in the buffer's free space and returns
    /// the number of bytes consumed.
    ///
    /// An empty `buf` only reports a pending deferred error. A held-back
    /// incomplete UTF-8 sequence is left alone in that case and `Ok(0)` is
    /// returned. A non-empty `buf` first reports any pending deferred error,
    /// including a held-back incomplete sequence: a whole `str` cannot complete
    /// it, so it is reported as malformed.
    pub fn write_str(&mut self, buf: &str) -> io::Result<usize> {
        if buf.is_empty() {
            self.realize_deferred_error_except_incomplete_utf8()?;
            return Ok(0);
        } else {
            self.realize_any_deferred_error()?;
        }
        // Provide the free space that `write_str_inner` asserts on.
        self.buffer.try_reserve(MIN_BUF_SIZE, None)?;
        Ok(self.write_str_inner(buf))
    }
    /// Returns a writer that copies raw bytes into the buffer without encoding them.
    pub fn passthrough(&mut self) -> PassthroughWriter<'_, B> {
        PassthroughWriter(self)
    }
    /// Returns an `io::Write` adapter that hands unmappable characters to `handler`
    /// instead of reporting them as errors.
    ///
    /// `handler` receives the error together with a `PassthroughWriter`, which it
    /// can use to write a substitute (e.g. a `&#N;` numeric character reference).
    /// An error returned by `handler` is propagated to the caller. The handler runs
    /// at the start of the adapter's next `write`, `flush`, `write_fmt`, or
    /// formatted chunk after the encoder deferred the unmappable character.
    pub fn with_unmappable_handler<'a>(
        &'a mut self,
        handler: impl FnMut(UnmappableError, &mut PassthroughWriter<B>) -> io::Result<()> + 'a,
    ) -> impl io::Write + 'a {
        struct WithUnmappableHandlerWriter<'a, B, H>(&'a mut EncodingWriter<B>, H);
        impl<B: BufferedWrite, H> WithUnmappableHandlerWriter<'_, B, H>
        where
            H: FnMut(UnmappableError, &mut PassthroughWriter<B>) -> io::Result<()>,
        {
            // Peeks before taking, so that other deferred errors stay in place
            // for the inner writer to report.
            fn handle_deferred_unmappable_error(&mut self) -> io::Result<()> {
                match self.0.deferred_error {
                    Some(DefErr::Unmappable(..)) => match self.0.deferred_error.take() {
                        Some(DefErr::Unmappable(e)) => (self.1)(e, &mut self.0.passthrough()),
                        _ => unreachable!(),
                    },
                    _ => Ok(()),
                }
            }
        }
        impl<B: BufferedWrite, H> WriteFmtAdapter for WithUnmappableHandlerWriter<'_, B, H>
        where
            H: FnMut(UnmappableError, &mut PassthroughWriter<B>) -> io::Result<()>,
        {
            fn write_str_io(&mut self, buf: &str) -> io::Result<usize> {
                self.handle_deferred_unmappable_error()?;
                self.0.write_str_io(buf)
            }
        }
        impl<B: BufferedWrite, H> io::Write for WithUnmappableHandlerWriter<'_, B, H>
        where
            H: FnMut(UnmappableError, &mut PassthroughWriter<B>) -> io::Result<()>,
        {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                self.handle_deferred_unmappable_error()?;
                self.0.write(buf)
            }
            fn flush(&mut self) -> io::Result<()> {
                self.handle_deferred_unmappable_error()?;
                self.0.flush()
            }
            // NOTE(review): unlike `EncodingWriter::write_fmt`, this does not call
            // `realize_deferred_error_except_incomplete_utf8` first. A pending
            // malformed-UTF-8 error is therefore not reported when the formatted
            // output is empty. Confirm whether that is intended.
            fn write_fmt(&mut self, f: fmt::Arguments<'_>) -> io::Result<()> {
                self.handle_deferred_unmappable_error()?;
                write_fmt_impl(self, f)
            }
        }
        WithUnmappableHandlerWriter(self, handler)
    }
    /// Encodes a prefix of the non-empty `buf` into the buffer's free space.
    ///
    /// Returns how many bytes of `buf` were consumed; this always lands on a char
    /// boundary. Callers must already have cleared any deferred error and reserved
    /// `MIN_BUF_SIZE` bytes. When the encoder stops at an unmappable character,
    /// that character is counted as consumed and the error is deferred to the
    /// next call.
    fn write_str_inner(&mut self, buf: &str) -> usize {
        debug_assert!(!buf.is_empty());
        debug_assert!(self.deferred_error.is_none());
        let unfilled = self.buffer.unfilled();
        assert!(
            unfilled.len() >= MIN_BUF_SIZE,
            "illegal `BufferedWrite` implementation"
        );
        // `last = false`: more input may follow, so the encoder keeps its state.
        let (result, consumed, written) = self
            .encoder
            .encode_from_utf8_without_replacement(buf, unfilled, false);
        self.buffer.advance(written);
        debug_assert_ne!(consumed, 0);
        debug_assert!(buf.is_char_boundary(consumed), "encoder broke contract");
        if let encoding_rs::EncoderResult::Unmappable(c) = result {
            self.deferred_error = Some(DefErr::Unmappable(UnmappableError::new(c)));
        }
        consumed
    }
    /// Takes the deferred error, if any, and returns it as an `io::Error`.
    /// A held-back incomplete UTF-8 sequence is reported as a `MalformedError`.
    fn realize_any_deferred_error(&mut self) -> io::Result<()> {
        match self.deferred_error.take() {
            None => Ok(()),
            Some(DefErr::Unmappable(e)) => Err(e.wrap()),
            Some(DefErr::MalformedUtf8(e)) => Err(e.wrap()),
            Some(DefErr::IncompleteUtf8(..)) => Err(MalformedError::new().wrap()),
        }
    }
    /// Like `realize_any_deferred_error`, but leaves a held-back incomplete UTF-8
    /// sequence in place, since a later `write` may still complete it.
    fn realize_deferred_error_except_incomplete_utf8(&mut self) -> io::Result<()> {
        match self.deferred_error {
            None | Some(DefErr::IncompleteUtf8(..)) => Ok(()),
            _ => self.realize_any_deferred_error(),
        }
    }
}
/// Lets `write_fmt_impl` feed formatted text directly into the encoder.
impl<B: BufferedWrite> WriteFmtAdapter for EncodingWriter<B> {
    fn write_str_io(&mut self, buf: &str) -> io::Result<usize> {
        // Delegates to the inherent `EncodingWriter::write_str`.
        Self::write_str(self, buf)
    }
}
impl<B: BufferedWrite> io::Write for EncodingWriter<B> {
    /// Encodes the UTF-8 bytes in `buf` and returns how many of them were consumed.
    ///
    /// `buf` need not end on a character boundary. A trailing partial UTF-8
    /// sequence is held back, counted as consumed, and completed by a later call.
    /// Malformed UTF-8 and unmappable characters are consumed and their error is
    /// deferred to the next call, so the count for the preceding bytes is accurate.
    ///
    /// # Errors
    ///
    /// - a deferred malformed-UTF-8 or unmappable-character error from an
    ///   earlier call;
    /// - any error from reserving space in the buffer;
    /// - a `MalformedError` when a partial sequence held back by an earlier call
    ///   is not continued by the first byte of `buf`. The held-back bytes are
    ///   dropped and nothing from `buf` is consumed, so retrying with the same
    ///   `buf` encodes it normally.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.realize_deferred_error_except_incomplete_utf8()?;
        if buf.is_empty() {
            return Ok(0);
        }
        // Provide the free space that `write_str_inner` asserts on.
        self.buffer.try_reserve(MIN_BUF_SIZE, None)?;
        match self.deferred_error.take() {
            None => match str_from_utf8_up_to_error(buf) {
                Ok(s) => Ok(self.write_str_inner(s)),
                // `buf` starts with an invalid sequence: consume it and report
                // the error on the next call.
                Err(Some(error_len)) => {
                    self.deferred_error = Some(DefErr::MalformedUtf8(MalformedError::new()));
                    Ok(error_len)
                }
                // `buf` is only the beginning of a multi-byte character: hold it back.
                Err(None) => {
                    let mut bs = util::MiniBuffer::default();
                    let len = bs.fill_from_slice(buf);
                    debug_assert!(len == bs.len());
                    assert!(len < 4 && len == buf.len());
                    self.deferred_error = Some(DefErr::IncompleteUtf8(bs));
                    Ok(len)
                }
            },
            // Both variants were realized (returned as errors) at the top of this function.
            Some(DefErr::Unmappable(..)) | Some(DefErr::MalformedUtf8(..)) => unreachable!(),
            Some(DefErr::IncompleteUtf8(mut bs)) => {
                // Append as much of `buf` as fits to the held-back bytes and re-check.
                let old_len = bs.len();
                let new_len = old_len + bs.fill_from_slice(buf);
                debug_assert!(old_len < new_len && new_len == bs.len());
                let consumed = match str_from_utf8_up_to_error(bs.as_ref()) {
                    Ok(s) => self.write_str_inner(s),
                    // The invalid sequence consists only of held-back bytes: the
                    // first byte of `buf` does not continue it (e.g. `[0xc3]`
                    // followed by `b"A"` gives `error_len == old_len`). Earlier calls
                    // already acknowledged those bytes. Consuming `error_len - old_len`
                    // (zero) bytes would violate the assertion below, so report the
                    // error now and leave `buf` untouched.
                    Err(Some(error_len)) if error_len <= old_len => {
                        return Err(MalformedError::new().wrap());
                    }
                    // The invalid sequence also covers bytes of `buf`: consume
                    // them and defer the error as usual.
                    Err(Some(error_len)) => {
                        self.deferred_error = Some(DefErr::MalformedUtf8(MalformedError::new()));
                        error_len
                    }
                    // Still incomplete: all of `buf` joined the held-back bytes.
                    Err(None) => {
                        assert!(new_len < 4 && new_len == old_len + buf.len());
                        self.deferred_error = Some(DefErr::IncompleteUtf8(bs));
                        new_len
                    }
                };
                assert!(old_len < consumed);
                Ok(consumed - old_len)
            }
        }
    }

    /// Reports pending deferred errors, except a held-back incomplete UTF-8
    /// sequence, and then flushes the buffer.
    fn flush(&mut self) -> io::Result<()> {
        self.realize_deferred_error_except_incomplete_utf8()?;
        self.buffer.flush()
    }

    /// Reports pending deferred errors, except a held-back incomplete UTF-8
    /// sequence, and then encodes the formatted text.
    fn write_fmt(&mut self, f: fmt::Arguments<'_>) -> io::Result<()> {
        self.realize_deferred_error_except_incomplete_utf8()?;
        write_fmt_impl(self, f)
    }
}
impl<B: fmt::Debug> fmt::Debug for EncodingWriter<B> {
    /// Formats the writer as a struct. The encoder is shown only through the
    /// encoding it targets; its internal state is not exposed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Debug stand-in for `Encoder` that reveals only `encoding()`.
        struct EncoderSummary<'e>(&'e Encoder);

        impl fmt::Debug for EncoderSummary<'_> {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                let encoding = self.0.encoding();
                let mut summary = f.debug_struct("Encoder");
                summary.field("encoding()", encoding);
                summary.finish_non_exhaustive()
            }
        }

        let encoder = EncoderSummary(&self.encoder);
        let mut out = f.debug_struct("EncodingWriter");
        out.field("buffer", &self.buffer);
        out.field("encoder", &encoder);
        out.field("deferred_error", &self.deferred_error);
        out.finish()
    }
}
/// Something one call recorded for a later call to report or resolve.
#[derive(Debug)]
enum DefErr {
    /// The encoder consumed a character the target encoding cannot represent.
    Unmappable(UnmappableError),
    /// Input bytes were not valid UTF-8; the offending bytes were already consumed.
    MalformedUtf8(MalformedError),
    /// A trailing partial UTF-8 sequence (fewer than 4 bytes), held back until a
    /// later `write` supplies the rest.
    IncompleteUtf8(util::MiniBuffer),
}
/// Writer that skips encoding and copies bytes directly into an `EncodingWriter`'s
/// buffer. Obtained from `EncodingWriter::passthrough`.
#[derive(Debug)]
pub struct PassthroughWriter<'a, B>(&'a mut EncodingWriter<B>);
impl<B> PassthroughWriter<'_, B> {
    /// Borrows the `EncodingWriter` this passthrough writer writes into.
    ///
    /// Useful, for example, to inspect the encoder's state from inside an
    /// unmappable-character handler.
    pub fn encoding_writer_ref(&self) -> &EncodingWriter<B> {
        &*self.0
    }
}
impl<B: BufferedWrite> io::Write for PassthroughWriter<'_, B> {
    /// Copies `buf` into the underlying buffer verbatim, without encoding it.
    ///
    /// Any deferred error of the `EncodingWriter` is reported first. For an empty
    /// `buf`, a held-back incomplete UTF-8 sequence is tolerated and `Ok(0)` is
    /// returned. For a non-empty `buf` it is reported as malformed, because raw
    /// bytes cannot complete it.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let inner = &mut *self.0;
        if !buf.is_empty() {
            inner.realize_any_deferred_error()?;
            return io::Write::write(&mut inner.buffer, buf);
        }
        inner.realize_deferred_error_except_incomplete_utf8()?;
        Ok(0)
    }

    /// Flushes through the owning `EncodingWriter`, so its pending deferred errors
    /// are reported exactly as a direct `flush` would report them.
    fn flush(&mut self) -> io::Result<()> {
        io::Write::flush(&mut *self.0)
    }
}
/// Returns the longest valid UTF-8 prefix of `v`, or explains why there is none.
///
/// - `Ok(s)`: `s` is all of `v` if `v` is valid UTF-8. Otherwise it is the
///   non-empty valid prefix before the first invalid or incomplete sequence.
/// - `Err(Some(n))`: `v` starts with an invalid sequence of `n` bytes.
/// - `Err(None)`: `v` is only the beginning of a multi-byte character, cut off
///   by the end of the slice.
fn str_from_utf8_up_to_error(v: &[u8]) -> Result<&str, Option<usize>> {
    let err = match str::from_utf8(v) {
        Ok(whole) => return Ok(whole),
        Err(err) => err,
    };
    match err.valid_up_to() {
        0 => Err(err.error_len()),
        valid => {
            // SAFETY: `Utf8Error::valid_up_to` guarantees that `v[..valid]` is
            // well-formed UTF-8, which is the precondition of `from_utf8_unchecked`.
            Ok(unsafe { str::from_utf8_unchecked(&v[..valid]) })
        }
    }
}
/// Minimal "write a `str`, report how much was consumed" interface that
/// `write_fmt_impl` uses to implement `io::Write::write_fmt`.
trait WriteFmtAdapter {
    /// Writes a prefix of `buf` and returns its length in bytes.
    ///
    /// `write_fmt_impl` treats `Ok(0)` for a non-empty `buf` as a `WriteZero`
    /// failure, retries on `Interrupted`, and expects the returned length to be a
    /// char boundary of `buf`.
    fn write_str_io(&mut self, buf: &str) -> io::Result<usize>;
}
/// Writes formatted text into `writer`, mirroring the semantics of
/// `io::Write::write_fmt`.
///
/// Each formatted piece is fed to `write_str_io` until fully consumed. Calls that
/// fail with `Interrupted` are retried. The first real I/O error is returned as is.
/// A formatter error that did not come from I/O becomes an `Other`
/// "formatter error".
fn write_fmt_impl(writer: &mut impl WriteFmtAdapter, f: fmt::Arguments<'_>) -> io::Result<()> {
    // Bridges `fmt::Write` onto a `WriteFmtAdapter`, stashing the I/O error that
    // caused a `fmt::Error` so it can be returned afterwards.
    struct Bridge<'a, T> {
        target: &'a mut T,
        failure: io::Result<()>,
    }

    impl<T: WriteFmtAdapter> Bridge<'_, T> {
        // Records `e` as the cause and aborts formatting.
        fn fail(&mut self, e: io::Error) -> fmt::Result {
            self.failure = Err(e);
            Err(fmt::Error)
        }
    }

    impl<T: WriteFmtAdapter> fmt::Write for Bridge<'_, T> {
        fn write_str(&mut self, s: &str) -> fmt::Result {
            let mut rest = s;
            while !rest.is_empty() {
                let n = match self.target.write_str_io(rest) {
                    Ok(0) => {
                        return self.fail(io::Error::new(
                            io::ErrorKind::WriteZero,
                            "failed to write whole buffer",
                        ));
                    }
                    Ok(n) => n,
                    Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                    Err(e) => return self.fail(e),
                };
                match rest.get(n..) {
                    Some(tail) => rest = tail,
                    None => {
                        debug_assert!(false, "unreachable");
                        return self.fail(io::Error::new(
                            io::ErrorKind::Other,
                            "encoder returned invalid string index",
                        ));
                    }
                }
            }
            Ok(())
        }
    }

    let mut bridge = Bridge {
        target: writer,
        failure: Ok(()),
    };
    if fmt::write(&mut bridge, f).is_ok() {
        return Ok(());
    }
    match bridge.failure {
        Err(e) => Err(e),
        Ok(()) => Err(io::Error::new(io::ErrorKind::Other, "formatter error")),
    }
}
/// An `io::Write` sink whose spare capacity can be written into directly.
/// `EncodingWriter` uses it as its output buffer.
///
/// Implementations must honor `try_reserve`: after it returns `Ok(())`,
/// `unfilled()` must be at least `minimum` bytes long. `EncodingWriter` asserts
/// this and panics with "illegal `BufferedWrite` implementation" otherwise.
pub trait BufferedWrite: io::Write {
    /// Returns the writable, not-yet-filled part of the buffer.
    fn unfilled(&mut self) -> &mut [u8];
    /// Marks the first `n` bytes of `unfilled()` as filled; callers pass the
    /// number of bytes they just wrote there.
    fn advance(&mut self, n: usize);
    /// Makes at least `minimum` bytes available in `unfilled()`. How room is made
    /// (flushing, growing, ...) is up to the implementation.
    /// NOTE(review): `size_hint` is always `None` at the call sites in this file;
    /// its intended meaning is not visible here.
    fn try_reserve(&mut self, minimum: usize, size_hint: Option<usize>) -> io::Result<()>;
}
#[cfg(test)]
mod tests {
    use std::io::Write;
    use super::{EncodingWriter, MalformedError, UnmappableError};
    // An unmappable character at the end of the input is consumed and deferred.
    // `write_all`/`write!` succeed, and the error surfaces from whichever call comes
    // next (`finish`, `flush`, or an empty `write`). After that, the writer finishes
    // cleanly with only the mappable prefix encoded.
    #[test]
    fn trailing_unmappable_char() {
        let mut writer = EncodingWriter::new(Vec::new(), encoding_rs::SHIFT_JIS.new_encoder());
        assert!(matches!(writer.write_all("Oh🥺".as_bytes()), Ok(())));
        match writer.finish() {
            Err((buffer, mut iter)) => {
                assert_eq!(buffer.get_ref(), b"");
                assert_eq!(buffer.buffer(), b"Oh");
                let e = iter.next().unwrap().unwrap_err();
                assert_eq!(UnmappableError::wrapped_in(&e).unwrap().value(), '🥺');
                assert!(iter.next().is_none());
            }
            ret => panic!("assertion failed: {:?}", ret),
        };
        let mut writer = EncodingWriter::new(Vec::new(), encoding_rs::EUC_JP.new_encoder());
        assert!(matches!(writer.write_all("Oh🥺".as_bytes()), Ok(())));
        match writer.flush() {
            Err(e) => {
                assert_eq!(UnmappableError::wrapped_in(&e).unwrap().value(), '🥺');
            }
            ret => panic!("assertion failed: {:?}", ret),
        };
        match writer.finish() {
            Ok(buffer) => {
                assert_eq!(buffer.get_ref(), b"");
                assert_eq!(buffer.buffer(), b"Oh");
            }
            ret => panic!("assertion failed: {:?}", ret),
        };
        let mut writer = EncodingWriter::new(Vec::new(), encoding_rs::ISO_2022_JP.new_encoder());
        assert!(matches!(writer.write_all("Oh🥺".as_bytes()), Ok(())));
        match writer.write(b"") {
            Err(e) => {
                assert_eq!(UnmappableError::wrapped_in(&e).unwrap().value(), '🥺');
            }
            ret => panic!("assertion failed: {:?}", ret),
        };
        match writer.finish() {
            Ok(buffer) => {
                assert_eq!(buffer.get_ref(), b"");
                assert_eq!(buffer.buffer(), b"Oh");
            }
            ret => panic!("assertion failed: {:?}", ret),
        };
        let mut writer = EncodingWriter::new(Vec::new(), encoding_rs::EUC_KR.new_encoder());
        assert!(matches!(write!(writer, "Oh🥺"), Ok(())));
        match writer.finish() {
            Err((buffer, mut iter)) => {
                assert_eq!(buffer.get_ref(), b"");
                assert_eq!(buffer.buffer(), b"Oh");
                let e = iter.next().unwrap().unwrap_err();
                assert_eq!(UnmappableError::wrapped_in(&e).unwrap().value(), '🥺');
                assert!(iter.next().is_none());
            }
            ret => panic!("assertion failed: {:?}", ret),
        };
    }
    // Malformed bytes are consumed, and their error is reported by the following call.
    // An incomplete trailing sequence is completed across calls ([0xc3] + [0x97]
    // encodes to 0xd7). Empty writes and `flush` tolerate a held-back sequence;
    // `write_str`, passthrough writes, and `finish` report it as malformed.
    #[test]
    fn malformed_utf8() {
        let mut writer = EncodingWriter::new(Vec::new(), encoding_rs::ISO_8859_2.new_encoder());
        let src = [b'A', b'B', 0x80, 0x9f, b'C', b'D'];
        assert!(matches!(writer.write(&src[0..]), Ok(2)));
        assert!(matches!(writer.write(&src[2..]), Ok(1)));
        assert!(match writer.write(&src[3..]) {
            Err(e) => MalformedError::wrapped_in(&e).is_some(),
            _ => false,
        });
        assert!(matches!(writer.write(&src[3..]), Ok(1)));
        assert!(match writer.write(&src[4..]) {
            Err(e) => MalformedError::wrapped_in(&e).is_some(),
            _ => false,
        });
        assert!(matches!(writer.write(&src[4..]), Ok(2)));
        assert!(matches!(writer.write(&[0xc3]), Ok(1)));
        assert!(matches!(writer.write(&[0x97]), Ok(1)));
        assert!(matches!(writer.write(&[0xc3]), Ok(1)));
        assert!(matches!(writer.write(&[]), Ok(0)));
        assert!(matches!(writer.write_str(""), Ok(0)));
        assert!(matches!(writer.passthrough().write(&[]), Ok(0)));
        match writer.finish() {
            Err((buffer, mut iter)) => {
                assert_eq!(buffer.get_ref(), &[]);
                assert_eq!(buffer.buffer(), &[b'A', b'B', b'C', b'D', 0xd7]);
                let e = iter.next().unwrap().unwrap_err();
                assert!(MalformedError::wrapped_in(&e).is_some());
                assert!(iter.next().is_none());
            }
            ret => panic!("assertion failed: {:?}", ret),
        }
        let mut writer = EncodingWriter::new(Vec::new(), encoding_rs::ISO_8859_3.new_encoder());
        assert!(matches!(writer.write(&[0xc3]), Ok(1)));
        assert!(match writer.write_str("A") {
            Err(e) => MalformedError::wrapped_in(&e).is_some(),
            _ => false,
        });
        let mut writer = EncodingWriter::new(Vec::new(), encoding_rs::ISO_8859_4.new_encoder());
        assert!(matches!(writer.write(&[0xc3]), Ok(1)));
        assert!(match writer.passthrough().write(&[0xd7]) {
            Err(e) => MalformedError::wrapped_in(&e).is_some(),
            _ => false,
        });
        let mut writer = EncodingWriter::new(Vec::new(), encoding_rs::ISO_8859_7.new_encoder());
        assert!(matches!(writer.write(&[0xc3]), Ok(1)));
        assert!(writer.flush().is_ok());
        let mut writer = EncodingWriter::new(Vec::new(), encoding_rs::ISO_8859_15.new_encoder());
        assert!(matches!(
            writer
                .with_unmappable_handler(|_, _| unreachable!())
                .write(&[0xc3]),
            Ok(1),
        ));
        assert!(writer.flush().is_ok());
    }
    // An error returned by the unmappable handler propagates out of `write!`,
    // and output stops right before the first unmappable character.
    #[test]
    fn propagate_error_from_handler() {
        use std::{error, fmt, io};
        let mut writer = EncodingWriter::new(Vec::new(), encoding_rs::BIG5.new_encoder());
        {
            let mut writer = writer.with_unmappable_handler(|e, _| Err(e.wrap()));
            let ret = write!(writer, "Boo!👻 Boo!👻");
            writer.flush().unwrap();
            assert!(ret
                .unwrap_err()
                .get_ref()
                .unwrap()
                .downcast_ref::<UnmappableError>()
                .is_some());
        }
        assert_eq!(writer.writer_ref(), b"Boo!");
        #[derive(Debug)]
        struct AdHocError;
        impl fmt::Display for AdHocError {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                fmt::Debug::fmt(self, f)
            }
        }
        impl error::Error for AdHocError {}
        let mut writer = EncodingWriter::new(Vec::new(), encoding_rs::BIG5.new_encoder());
        {
            let mut writer = writer.with_unmappable_handler(|_, _| {
                Err(io::Error::new(io::ErrorKind::Other, AdHocError))
            });
            let ret = write!(writer, "Boo!👻 Boo!👻");
            writer.flush().unwrap();
            assert!(ret
                .unwrap_err()
                .get_ref()
                .unwrap()
                .downcast_ref::<AdHocError>()
                .is_some());
        }
        assert_eq!(writer.writer_ref(), b"Boo!");
    }
    // When `finish` sees a pending error, the encoder's end-of-stream escape
    // (ESC ( B) comes back through the iterator instead of being buffered.
    #[test]
    fn finish_iso_2022_jp() {
        let mut writer = EncodingWriter::new(Vec::new(), encoding_rs::ISO_2022_JP.new_encoder());
        write!(writer, "あ").unwrap();
        assert_eq!(writer.write(&[0xff]).unwrap(), 1);
        match writer.finish() {
            Err((buffer, mut iter)) => {
                assert_eq!(buffer.get_ref(), &[]);
                assert_eq!(buffer.buffer(), &[27, 36, 66, 36, 34]);
                let e = iter.next().unwrap().unwrap_err();
                assert!(MalformedError::wrapped_in(&e).is_some());
                assert_eq!(iter.next().unwrap().unwrap(), &[27, 40, 66]);
                assert!(iter.next().is_none());
            }
            ret => panic!("assertion failed: {:?}", ret),
        }
    }
    // Writing `&#N;` substitutes through the passthrough writer reproduces exactly
    // what encoding_rs's replacing `encode_from_utf8_to_vec` emits. The handler
    // always runs with no pending encoder state.
    #[test]
    fn iso_2022_jp_at_unmappable() {
        let text = "🐀Lorem ip🐁sum恥の多い🐂生涯をdolor sit 🐃amet送って🐄来ました🐅";
        let expected = {
            let mut dst = Vec::with_capacity(text.len() * 2);
            let mut encoder = encoding_rs::ISO_2022_JP.new_encoder();
            let (result, ..) = encoder.encode_from_utf8_to_vec(text, &mut dst, false);
            assert!(matches!(result, encoding_rs::CoderResult::InputEmpty));
            dst
        };
        let mut writer = EncodingWriter::new(Vec::new(), encoding_rs::ISO_2022_JP.new_encoder());
        {
            let mut writer = writer.with_unmappable_handler(|e, w| {
                assert!(!w.encoding_writer_ref().encoder_ref().has_pending_state());
                write!(w, "&#{};", u32::from(e.value()))
            });
            write!(writer, "{}", text).unwrap();
            writer.flush().unwrap();
        }
        assert_eq!(writer.writer_ref(), &expected);
    }
}