use core::fmt;
use crate::io::{
self, Error, ErrorKind, IntoInnerError, IoSlice, Seek, SeekFrom, Write, DEFAULT_BUF_SIZE,
};
use crate::io::prelude::*;
/// A writer that collects bytes in an in-memory buffer before passing them
/// on to the wrapped writer `W`.
pub struct BufWriter<W: Write> {
// The wrapped writer. Stored in an `Option` so `into_inner` can move it out
// with `take()`; the accessors in this file unwrap it.
inner: Option<W>,
// Bytes accepted from callers that have not yet been written to `inner`.
buf: Vec<u8>,
// Set to `true` immediately before each call into `inner`'s write methods
// and reset to `false` afterwards; `Drop` skips its flush while it is set.
panicked: bool,
}
impl<W: Write> BufWriter<W> {
    /// Wraps `inner` in a `BufWriter` whose buffer is allocated with
    /// `DEFAULT_BUF_SIZE` bytes of capacity.
    pub fn new(inner: W) -> BufWriter<W> {
        Self::with_capacity(DEFAULT_BUF_SIZE, inner)
    }

    /// Wraps `inner` in a `BufWriter` whose buffer is allocated with at least
    /// `capacity` bytes of capacity.
    pub fn with_capacity(capacity: usize, inner: W) -> BufWriter<W> {
        let buf = Vec::with_capacity(capacity);
        Self { inner: Some(inner), buf, panicked: false }
    }

    /// Writes everything currently buffered to the inner writer.
    ///
    /// Bytes the inner writer accepted are removed from the front of the
    /// buffer on every exit path, so a failed flush leaves only the unwritten
    /// tail behind.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::WriteZero` if the inner writer reports writing zero
    /// bytes, or the inner writer's own error. `ErrorKind::Interrupted` is
    /// retried rather than returned.
    pub(super) fn flush_buf(&mut self) -> io::Result<()> {
        /// Counts how many bytes from the front of the buffer have been
        /// written and drains exactly that many when dropped.
        struct Pending<'a> {
            bytes: &'a mut Vec<u8>,
            flushed: usize,
        }

        impl<'a> Pending<'a> {
            fn new(bytes: &'a mut Vec<u8>) -> Self {
                Self { bytes, flushed: 0 }
            }

            /// The part of the buffer that still has to be written.
            fn remaining(&self) -> &[u8] {
                &self.bytes[self.flushed..]
            }

            /// Records that `amt` more bytes were written.
            fn consume(&mut self, amt: usize) {
                self.flushed += amt;
            }

            /// `true` once the written count has reached the buffer length.
            fn done(&self) -> bool {
                self.bytes.len() <= self.flushed
            }
        }

        impl Drop for Pending<'_> {
            fn drop(&mut self) {
                if self.flushed == 0 {
                    return;
                }
                self.bytes.drain(..self.flushed);
            }
        }

        let mut pending = Pending::new(&mut self.buf);
        let sink = self.inner.as_mut().unwrap();
        loop {
            if pending.done() {
                return Ok(());
            }

            // The flag is raised only for the duration of the inner write, so
            // that `Drop` can tell when that call unwound.
            self.panicked = true;
            let outcome = sink.write(pending.remaining());
            self.panicked = false;

            let accepted = match outcome {
                Ok(0) => {
                    return Err(Error::new(
                        ErrorKind::WriteZero,
                        "failed to write the buffered data",
                    ));
                }
                Ok(n) => n,
                Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
                Err(e) => return Err(e),
            };
            pending.consume(accepted);
        }
    }

    /// Copies as much of `buf` as fits into the buffer's spare capacity and
    /// returns the number of bytes copied. Never writes to the inner writer.
    pub(super) fn write_to_buf(&mut self, buf: &[u8]) -> usize {
        let spare = self.buf.capacity() - self.buf.len();
        let (head, _) = buf.split_at(spare.min(buf.len()));
        self.buf.extend_from_slice(head);
        head.len()
    }

    /// Returns a shared reference to the wrapped writer.
    pub fn get_ref(&self) -> &W {
        let writer = self.inner.as_ref();
        writer.unwrap()
    }

    /// Returns a mutable reference to the wrapped writer.
    pub fn get_mut(&mut self) -> &mut W {
        let writer = self.inner.as_mut();
        writer.unwrap()
    }

    /// Returns the bytes that are buffered but not yet written.
    pub fn buffer(&self) -> &[u8] {
        self.buf.as_slice()
    }

    /// Returns the capacity of the internal buffer.
    pub fn capacity(&self) -> usize {
        self.buf.capacity()
    }

    /// Flushes the buffer and returns the wrapped writer.
    ///
    /// # Errors
    ///
    /// If the flush fails, the error is returned together with this
    /// `BufWriter` inside an `IntoInnerError`.
    pub fn into_inner(mut self) -> Result<W, IntoInnerError<BufWriter<W>>> {
        if let Err(e) = self.flush_buf() {
            return Err(IntoInnerError::new(self, e));
        }
        // Taking the writer leaves `None` behind, which makes `Drop` skip its flush.
        Ok(self.inner.take().unwrap())
    }
}
impl<W: Write> Write for BufWriter<W> {
    /// Buffers `buf`, or hands it straight to the inner writer when it is at
    /// least as large as the buffer's capacity.
    ///
    /// Already-buffered data is flushed first when `buf` would not fit next
    /// to it. Returns the number of bytes accepted.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        if self.buf.len() + buf.len() > self.buf.capacity() {
            self.flush_buf()?;
        }
        if buf.len() >= self.buf.capacity() {
            // Raise the flag around the direct write so `Drop` does not flush
            // if the inner writer panics.
            self.panicked = true;
            let r = self.get_mut().write(buf);
            self.panicked = false;
            r
        } else {
            self.buf.extend_from_slice(buf);
            Ok(buf.len())
        }
    }

    /// Like `write`, but large inputs are forwarded with the inner writer's
    /// `write_all`, so the whole of `buf` is either buffered or written.
    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
        if self.buf.len() + buf.len() > self.buf.capacity() {
            self.flush_buf()?;
        }
        if buf.len() >= self.buf.capacity() {
            // Same panic bookkeeping as in `write`.
            self.panicked = true;
            let r = self.get_mut().write_all(buf);
            self.panicked = false;
            r
        } else {
            self.buf.extend_from_slice(buf);
            Ok(())
        }
    }

    /// Buffers all of `bufs`, or forwards them to the inner writer's
    /// `write_vectored` when their combined length is at least the buffer's
    /// capacity.
    ///
    /// Already-buffered data is flushed first when the slices would not fit
    /// next to it. Returns the number of bytes accepted.
    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
        // The slices may alias the same memory, so their combined length can
        // exceed `usize::MAX`. Saturate instead of overflowing: a saturated
        // total is still "too large to buffer" and takes the direct path.
        let total_len = bufs.iter().fold(0usize, |acc, b| acc.saturating_add(b.len()));
        if self.buf.len().saturating_add(total_len) > self.buf.capacity() {
            self.flush_buf()?;
        }
        if total_len >= self.buf.capacity() {
            // Same panic bookkeeping as in `write`.
            self.panicked = true;
            let r = self.get_mut().write_vectored(bufs);
            self.panicked = false;
            r
        } else {
            bufs.iter().for_each(|b| self.buf.extend_from_slice(b));
            Ok(total_len)
        }
    }

    /// Reports whatever the inner writer reports for `is_write_vectored`.
    fn is_write_vectored(&self) -> bool {
        self.get_ref().is_write_vectored()
    }

    /// Writes out the buffered data, then flushes the inner writer.
    fn flush(&mut self) -> io::Result<()> {
        self.flush_buf().and_then(|()| self.get_mut().flush())
    }
}
impl<W: Write + fmt::Debug> fmt::Debug for BufWriter<W> {
    /// Formats as a `BufWriter` struct showing the inner writer and the
    /// buffer usage as `len/capacity`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let writer: &W = self.inner.as_ref().unwrap();
        let mut out = f.debug_struct("BufWriter");
        out.field("writer", writer);
        out.field("buffer", &format_args!("{}/{}", self.buf.len(), self.buf.capacity()));
        out.finish()
    }
}
impl<W: Write + Seek> Seek for BufWriter<W> {
    /// Seeks the inner writer to `pos`.
    ///
    /// Buffered data is written out first; if that fails, the error is
    /// returned and the inner writer is not moved.
    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
        self.flush_buf().and_then(|()| self.get_mut().seek(pos))
    }
}
impl<W: Write> Drop for BufWriter<W> {
    /// Makes a final attempt to write out buffered data.
    ///
    /// Nothing is done when the inner writer has already been taken or when
    /// the `panicked` flag is still set. Any error from the flush is discarded.
    fn drop(&mut self) {
        if self.panicked || self.inner.is_none() {
            return;
        }
        let _ignored = self.flush_buf();
    }
}