use std::io::{self, Write};
/// A buffered writer that records I/O errors instead of returning them from every write.
///
/// Bytes are collected in an in-memory buffer and passed to the wrapped [`Write`] when the
/// buffer runs full or when [`flush_defer_err`](Self::flush_defer_err) is called. An error
/// returned by the wrapped writer is stored and can be retrieved later with
/// [`check_io_error`](Self::check_io_error). While an error is stored, further data is
/// discarded instead of being written.
pub struct DeferredWriter<'a> {
    /// Destination that receives the buffered bytes.
    write: Box<dyn Write + 'a>,
    /// Bytes accepted so far that have not been passed to `write` yet.
    buf: Vec<u8>,
    /// Error returned by `write` that has not been retrieved by the caller yet.
    io_error: Option<io::Error>,
    /// Set to `true` right before calling into `write` and reset right after, so it is still
    /// `true` if such a call unwinds.
    panicked: bool,
}

impl<'a> DeferredWriter<'a> {
    /// Initial capacity of the internal buffer in bytes (16 KiB).
    const DEFAULT_CHUNK_SIZE: usize = 16 << 10;

    /// Creates a writer that forwards its data to `write`, boxing it internally.
    pub fn from_write(write: impl Write + 'a) -> Self {
        Self::from_boxed_dyn_write(Box::new(write))
    }

    /// Creates a writer that forwards its data to an already boxed [`Write`].
    #[inline(never)]
    pub fn from_boxed_dyn_write(write: Box<dyn Write + 'a>) -> Self {
        DeferredWriter {
            write,
            buf: Vec::with_capacity(Self::DEFAULT_CHUNK_SIZE),
            io_error: None,
            panicked: false,
        }
    }

    /// Passes all buffered bytes to the underlying writer and empties the buffer.
    ///
    /// An error reported by the underlying writer is stored instead of returned. If an error
    /// is already stored, the buffered bytes are discarded without being written. This does
    /// not call `flush` on the underlying writer.
    pub fn flush_defer_err(&mut self) {
        if self.io_error.is_none() {
            // Stays `true` if `write_all` unwinds.
            self.panicked = true;
            if let Err(err) = self.write.write_all(&self.buf) {
                self.io_error = Some(err);
            }
            self.panicked = false;
        }
        self.buf.clear();
    }

    /// Appends `buf` to the output, storing any I/O error instead of returning it.
    ///
    /// If `buf` fits into the buffer's spare capacity it is only copied; otherwise the
    /// buffered data is written out to make room (see `write_all_defer_err_cold`).
    #[inline]
    pub fn write_all_defer_err(&mut self, buf: &[u8]) {
        let old_len = self.buf.len();
        // A `Vec`'s length never exceeds its capacity, so this subtraction cannot underflow.
        if buf.len() <= self.buf.capacity() - old_len {
            // SAFETY: the check above guarantees `buf.len()` bytes of allocated spare capacity
            // starting at offset `old_len`. `buf` is a shared borrow that cannot alias the
            // exclusively borrowed internal buffer, and the copy initializes every byte that
            // `set_len` exposes.
            unsafe {
                self.buf
                    .as_mut_ptr()
                    .add(old_len)
                    .copy_from_nonoverlapping(buf.as_ptr(), buf.len());
                self.buf.set_len(old_len + buf.len());
            }
        } else {
            self.write_all_defer_err_cold(buf);
        }
    }

    /// Slow path of `write_all_defer_err`, taken when `buf` exceeds the spare capacity.
    ///
    /// Must only be called in that situation: the split below relies on `buf` being longer
    /// than the spare capacity.
    #[inline(never)]
    #[cold]
    fn write_all_defer_err_cold(&mut self, buf: &[u8]) {
        let capacity = self.buf.capacity();
        if buf.len() < capacity {
            // Top the buffer up to exactly its capacity, write it out, and keep the rest
            // buffered. The rest is shorter than `capacity`, so no reallocation happens.
            let (head, tail) = buf.split_at(capacity - self.buf.len());
            self.buf.extend_from_slice(head);
            self.flush_defer_err();
            self.buf.extend_from_slice(tail);
        } else {
            // Input at least as large as the whole buffer: write out what is buffered, then
            // hand the input to the underlying writer without copying it.
            self.flush_defer_err();
            if self.io_error.is_none() {
                // Stays `true` if `write_all` unwinds.
                self.panicked = true;
                if let Err(err) = self.write.write_all(buf) {
                    self.io_error = Some(err);
                }
                self.panicked = false;
            }
        }
    }

    /// Returns a pointer to the start of the buffer's unused capacity if at least `len` bytes
    /// are available there, or a null pointer otherwise.
    ///
    /// The buffer's length is not changed. After writing through the pointer, commit the
    /// bytes with [`advance_unchecked`](Self::advance_unchecked).
    #[inline]
    pub fn buf_write_ptr(&mut self, len: usize) -> *mut u8 {
        let old_len = self.buf.len();
        // Compare against the remaining space instead of computing `old_len + len`, which
        // could overflow for large `len` and wrongly report that the request fits.
        if len <= self.buf.capacity() - old_len {
            // SAFETY: `old_len <= capacity`, so the offset stays within the allocation (or one
            // past its end).
            unsafe { self.buf.as_mut_ptr().add(old_len) }
        } else {
            std::ptr::null_mut()
        }
    }

    /// Marks the next `len` bytes of the buffer's spare capacity as written.
    ///
    /// # Safety
    ///
    /// `len` must not exceed the buffer's spare capacity, and those `len` bytes must have
    /// been initialized, e.g. through a non-null pointer returned by
    /// [`buf_write_ptr`](Self::buf_write_ptr) for a request of at least `len` bytes.
    #[inline]
    pub unsafe fn advance_unchecked(&mut self, len: usize) {
        let old_len = self.buf.len();
        debug_assert!(len <= self.buf.capacity() - old_len);
        // SAFETY: the caller guarantees the new length is within capacity and initialized.
        unsafe { self.buf.set_len(old_len + len) }
    }

    /// Returns the stored I/O error, if any, and clears it.
    ///
    /// Once the error has been taken, later writes are passed to the underlying writer again.
    #[inline]
    pub fn check_io_error(&mut self) -> io::Result<()> {
        self.io_error.take().map_or(Ok(()), Err)
    }
}
/// [`Write`] implementation that buffers data and reports write errors only on `flush`.
impl<'a> Write for DeferredWriter<'a> {
    /// Accepts all of `buf` and always reports it as written.
    ///
    /// Errors from the underlying writer are stored and surface on the next `flush` or
    /// `check_io_error` call.
    #[inline]
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.write_all_defer_err(buf);
        Ok(buf.len())
    }

    /// Writes out the buffered data, reports any stored error, and then flushes the
    /// underlying writer so that the data is not left in a buffer further down.
    #[inline]
    fn flush(&mut self) -> io::Result<()> {
        self.flush_defer_err();
        self.check_io_error()?;
        // Guard the call like the other calls into the underlying writer: `panicked` stays
        // `true` if the inner `flush` unwinds.
        self.panicked = true;
        let flushed = self.write.flush();
        self.panicked = false;
        flushed
    }

    /// Accepts all of `buf` and always returns `Ok(())`; see `write` for how errors surface.
    #[inline]
    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
        self.write_all_defer_err(buf);
        Ok(())
    }
}
/// Writes out any remaining buffered bytes when the writer is dropped.
impl<'a> Drop for DeferredWriter<'a> {
    fn drop(&mut self) {
        // `panicked` is only still set here if a call into the underlying writer unwound;
        // in that case the writer is not touched again.
        if self.panicked {
            return;
        }
        // An error from this final write is stored in the writer being dropped and is
        // therefore lost.
        self.flush_defer_err();
    }
}