use std::{
fmt, io,
io::{Error, ErrorKind, IoSlice, Seek, SeekFrom, Write},
mem,
mem::MaybeUninit,
ptr,
};
/// A [`Write`] adapter that collects small writes in an inline, fixed-size
/// buffer of `N` bytes before forwarding them to the wrapped writer.
///
/// The buffer is a plain array field, so it is stored inside the
/// `StackBufWriter` value itself and needs no separate allocation.
///
/// Invariant: `start <= end <= N`, and `buf[start..end]` holds initialized
/// bytes that have been accepted but not yet handed to `inner`. When the
/// buffer drains completely both indices are reset to `0`.
pub struct StackBufWriter<W: Write, const N: usize = 8192> {
    /// The wrapped writer that ultimately receives the bytes.
    inner: W,
    /// Backing storage; only `buf[start..end]` is relied upon as initialized.
    buf: [MaybeUninit<u8>; N],
    /// Index of the first buffered byte not yet written to `inner`.
    start: usize,
    /// One past the last buffered byte; new data is appended here.
    end: usize,
    /// `true` while a call into `inner` is in flight. If that call panics the
    /// flag stays set, which lets the `Drop` impl skip its flush.
    panicked: bool,
}

impl<W: Write, const N: usize> StackBufWriter<W, N> {
    /// Creates a writer with an empty `N`-byte buffer in front of `inner`.
    pub fn new(inner: W) -> StackBufWriter<W, N> {
        StackBufWriter {
            inner,
            // `MaybeUninit<u8>` is `Copy` and needs no initialization, so the
            // safe array-repeat expression is enough; no `unsafe` required.
            buf: [MaybeUninit::uninit(); N],
            start: 0,
            end: 0,
            panicked: false,
        }
    }

    /// Reinterprets a slice of `MaybeUninit<u8>` as initialized bytes.
    ///
    /// Takes the slice rather than `&self` so callers can keep borrowing the
    /// other fields of the writer independently.
    ///
    /// # Safety
    ///
    /// Every element of `bytes` must have been initialized.
    #[inline]
    unsafe fn assume_init_bytes(bytes: &[MaybeUninit<u8>]) -> &[u8] {
        // SAFETY: `MaybeUninit<u8>` has the same layout as `u8`, and the
        // caller guarantees that every element is initialized.
        unsafe { &*(bytes as *const [MaybeUninit<u8>] as *const [u8]) }
    }

    /// Writes the buffered bytes to `inner` until the buffer is empty.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::WriteZero` if `inner` accepts zero bytes, or the
    /// first non-`Interrupted` error reported by `inner`. Bytes written
    /// before the failure are removed from the buffer; the rest stay queued.
    fn flush_buf(&mut self) -> io::Result<()> {
        // Tracks how much of the pending region has been written and commits
        // that progress to `start`/`end` when dropped, so the bookkeeping is
        // applied on every exit path (success, early error return, unwind).
        struct BufGuard<'a> {
            buffer: &'a [u8],
            start: &'a mut usize,
            end: &'a mut usize,
            written: usize,
        }

        impl<'a> BufGuard<'a> {
            fn new(buffer: &'a [u8], start: &'a mut usize, end: &'a mut usize) -> Self {
                Self {
                    buffer,
                    start,
                    end,
                    written: 0,
                }
            }

            // The part of the pending region not yet accepted by `inner`.
            fn remaining(&self) -> &[u8] {
                &self.buffer[self.written..]
            }

            fn consume(&mut self, amt: usize) {
                self.written += amt;
            }

            fn done(&self) -> bool {
                self.written >= self.buffer.len()
            }
        }

        impl Drop for BufGuard<'_> {
            fn drop(&mut self) {
                *self.start += self.written;
                if *self.start >= *self.end {
                    debug_assert_eq!(*self.start, *self.end);
                    // Fully drained: rewind so the whole array is spare again.
                    *self.start = 0;
                    *self.end = 0;
                }
            }
        }

        // SAFETY: `buf[start..end]` is initialized per the struct invariant.
        let pending = unsafe { Self::assume_init_bytes(&self.buf[self.start..self.end]) };
        let mut guard = BufGuard::new(pending, &mut self.start, &mut self.end);

        while !guard.done() {
            self.panicked = true;
            let r = self.inner.write(guard.remaining());
            self.panicked = false;

            match r {
                Ok(0) => {
                    return Err(Error::new(
                        ErrorKind::WriteZero,
                        "failed to write the buffered data",
                    ));
                }
                Ok(n) => guard.consume(n),
                // Interrupted writes are simply retried.
                Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
                Err(e) => return Err(e),
            }
        }
        Ok(())
    }

    /// Returns a shared reference to the wrapped writer.
    pub fn get_ref(&self) -> &W {
        &self.inner
    }

    /// Returns a mutable reference to the wrapped writer.
    ///
    /// Writing to it directly bypasses the buffer, so such bytes reach the
    /// destination ahead of anything still buffered.
    pub fn get_mut(&mut self) -> &mut W {
        &mut self.inner
    }

    /// Returns the bytes currently buffered and not yet written to `inner`.
    pub fn buffer(&self) -> &[u8] {
        // SAFETY: `buf[start..end]` is initialized per the struct invariant.
        unsafe { Self::assume_init_bytes(&self.buf[self.start..self.end]) }
    }

    /// Returns the total size of the internal buffer in bytes (`N`).
    pub fn capacity(&self) -> usize {
        self.buf.len()
    }

    /// Flushes the buffer and returns the wrapped writer.
    ///
    /// # Errors
    ///
    /// If flushing fails, the error is returned together with the writer
    /// itself so that no buffered data is lost.
    pub fn into_inner(mut self) -> Result<W, (io::Error, StackBufWriter<W, N>)> {
        match self.flush_buf() {
            Err(e) => Err((e, self)),
            Ok(()) => {
                // SAFETY: `self` is forgotten immediately afterwards, so
                // `inner` is never used or dropped a second time. The other
                // fields own no resources.
                let inner = unsafe { ptr::read(&self.inner) };
                mem::forget(self);
                Ok(inner)
            }
        }
    }

    /// Slow path of `write`: makes room by flushing if needed, then either
    /// buffers `buf` or, when it is at least as large as the whole buffer,
    /// passes it straight to `inner` with a single `write` call.
    #[cold]
    #[inline(never)]
    fn write_cold(&mut self, buf: &[u8]) -> io::Result<usize> {
        if buf.len() > self.spare_capacity() {
            self.flush_buf()?;
        }

        if buf.len() >= self.buf.len() {
            // Reaching here implies the buffer is empty, so ordering holds.
            self.panicked = true;
            let r = self.get_mut().write(buf);
            self.panicked = false;
            r
        } else {
            // SAFETY: either `buf` already fit, or the successful flush above
            // emptied the buffer and `buf.len() < capacity`.
            unsafe {
                self.write_to_buffer_unchecked(buf);
            }
            Ok(buf.len())
        }
    }

    /// Slow path of `write_all`: same as [`Self::write_cold`] but oversized
    /// input is forwarded with `write_all`, so it is written in full.
    #[cold]
    #[inline(never)]
    fn write_all_cold(&mut self, buf: &[u8]) -> io::Result<()> {
        if buf.len() > self.spare_capacity() {
            self.flush_buf()?;
        }

        if buf.len() >= self.buf.len() {
            // Reaching here implies the buffer is empty, so ordering holds.
            self.panicked = true;
            let r = self.get_mut().write_all(buf);
            self.panicked = false;
            r
        } else {
            // SAFETY: either `buf` already fit, or the successful flush above
            // emptied the buffer and `buf.len() < capacity`.
            unsafe {
                self.write_to_buffer_unchecked(buf);
            }
            Ok(())
        }
    }

    /// Appends `buf` to the internal buffer without checking for room.
    ///
    /// # Safety
    ///
    /// `buf.len()` must not exceed `self.spare_capacity()`.
    #[inline]
    unsafe fn write_to_buffer_unchecked(&mut self, buf: &[u8]) {
        debug_assert!(buf.len() <= self.spare_capacity());
        let buf_len = buf.len();
        // SAFETY: the caller guarantees `end + buf_len <= N`, so the
        // destination range lies inside `self.buf`. `self` is borrowed
        // exclusively, so `buf` cannot alias it. A raw pointer is used on
        // purpose: the tail of the array may be uninitialized, so no `&mut
        // [u8]` may be formed over it.
        unsafe {
            let dst = self.buf.as_mut_ptr().cast::<u8>().add(self.end);
            ptr::copy_nonoverlapping(buf.as_ptr(), dst, buf_len);
        }
        self.end += buf_len;
    }

    /// Number of bytes that can still be appended after `end`.
    #[inline]
    fn spare_capacity(&self) -> usize {
        self.buf.len() - self.end
    }
}
impl<W: Write, const N: usize> Write for StackBufWriter<W, N> {
    /// Buffers `buf` when it is strictly smaller than the spare room;
    /// anything else is handled by the out-of-line `write_cold`.
    #[inline]
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let len = buf.len();
        if len >= self.spare_capacity() {
            return self.write_cold(buf);
        }
        // SAFETY: `len` is strictly less than the spare capacity.
        unsafe { self.write_to_buffer_unchecked(buf) };
        Ok(len)
    }

    /// Writes from several slices at once.
    ///
    /// If the wrapped writer reports vectored support, the slices are treated
    /// as one logical write: buffered together when their combined length is
    /// below the buffer size, otherwise forwarded with one `write_vectored`
    /// call. If it does not, the first non-empty slice is handled like a
    /// plain `write`, and following slices are buffered for as long as they
    /// fit.
    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
        let capacity = self.buf.len();

        if self.inner.is_write_vectored() {
            // Saturating so an absurd total cannot overflow; a saturated sum
            // is necessarily >= capacity and takes the pass-through route.
            let mut combined_len = 0usize;
            for slice in bufs {
                combined_len = combined_len.saturating_add(slice.len());
            }

            if combined_len > self.spare_capacity() {
                self.flush_buf()?;
            }

            if combined_len < capacity {
                for slice in bufs {
                    // SAFETY: the slices total `combined_len`, which fits in
                    // the spare room (checked directly, or ensured by the
                    // flush above emptying the buffer).
                    unsafe { self.write_to_buffer_unchecked(slice) };
                }
                return Ok(combined_len);
            }

            self.panicked = true;
            let outcome = self.inner.write_vectored(bufs);
            self.panicked = false;
            return outcome;
        }

        let mut rest = bufs.iter();
        let first = match rest.by_ref().find(|slice| !slice.is_empty()) {
            Some(slice) => slice,
            None => return Ok(0),
        };

        if first.len() > self.spare_capacity() {
            self.flush_buf()?;
        }

        if first.len() >= capacity {
            // Too big to buffer: hand just this slice to the inner writer.
            self.panicked = true;
            let outcome = self.inner.write(first);
            self.panicked = false;
            return outcome;
        }

        // SAFETY: `first` fits (checked directly, or the buffer was just
        // emptied and `first.len() < capacity`).
        unsafe { self.write_to_buffer_unchecked(first) };
        let mut accepted = first.len();
        debug_assert!(accepted != 0);

        for slice in rest {
            if slice.len() > self.spare_capacity() {
                break;
            }
            // SAFETY: the check above guarantees the slice fits.
            unsafe { self.write_to_buffer_unchecked(slice) };
            accepted += slice.len();
        }
        Ok(accepted)
    }

    /// Always `true`: `write_vectored` above does real work for multiple
    /// slices regardless of the wrapped writer.
    fn is_write_vectored(&self) -> bool {
        true
    }

    /// Drains the internal buffer into the wrapped writer, then flushes the
    /// wrapped writer itself.
    fn flush(&mut self) -> io::Result<()> {
        self.flush_buf()?;
        self.inner.flush()
    }

    /// Buffers `buf` when it is strictly smaller than the spare room;
    /// anything else is handled by the out-of-line `write_all_cold`.
    #[inline]
    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
        if buf.len() >= self.spare_capacity() {
            return self.write_all_cold(buf);
        }
        // SAFETY: `buf.len()` is strictly less than the spare capacity.
        unsafe { self.write_to_buffer_unchecked(buf) };
        Ok(())
    }
}
/// Debug output shows the wrapped writer and the buffer fill level as
/// `<buffered bytes>/<capacity>`; the buffered bytes themselves are omitted.
impl<W: Write, const N: usize> fmt::Debug for StackBufWriter<W, N>
where
    W: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Label the output with this type's real name (it previously said
        // "BufWriter", which pointed readers at the wrong type).
        f.debug_struct("StackBufWriter")
            .field("writer", &self.inner)
            .field(
                "buffer",
                &format_args!("{}/{}", self.end - self.start, self.buf.len()),
            )
            .finish()
    }
}
impl<W: Write + Seek, const N: usize> Seek for StackBufWriter<W, N> {
    /// Seeks the wrapped writer to `pos`.
    ///
    /// Buffered bytes are written out first so they land at the position
    /// they were written for; if that flush fails, no seek is attempted and
    /// the flush error is returned.
    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
        match self.flush_buf() {
            Ok(()) => self.inner.seek(pos),
            Err(err) => Err(err),
        }
    }
}
impl<W: Write, const N: usize> Drop for StackBufWriter<W, N> {
    /// Makes a best-effort attempt to write out buffered bytes.
    ///
    /// The flush is skipped when `panicked` is set, i.e. when a previous call
    /// into the wrapped writer never returned normally. Any flush error is
    /// discarded because `drop` has no way to report it.
    fn drop(&mut self) {
        if self.panicked {
            return;
        }
        let _ignored = self.flush_buf();
    }
}