use std::io::{self, Write};
/// A [`Write`] adapter that forwards all output to an inner writer and keeps a
/// running total of the bytes the inner writer reported as written.
///
/// Only bytes the inner writer actually accepted are counted: a short write
/// (`Ok(n)` with `n < buf.len()`) adds `n`, and a write that returns `Err`
/// adds nothing.
#[derive(Debug)]
pub struct CountingWriter<W> {
    /// The wrapped writer that receives every forwarded byte.
    inner: W,
    /// Total number of bytes `inner` has accepted through this adapter.
    bytes: u64,
}

impl<W> CountingWriter<W> {
    /// Wraps `inner`, starting the byte count at zero.
    pub fn new(inner: W) -> Self {
        Self { inner, bytes: 0 }
    }

    /// Returns the total number of bytes written through this adapter so far.
    pub fn bytes(&self) -> u64 {
        self.bytes
    }

    /// Returns a shared reference to the wrapped writer.
    pub fn get_ref(&self) -> &W {
        &self.inner
    }

    /// Returns a mutable reference to the wrapped writer.
    ///
    /// Bytes written directly to the returned writer bypass the adapter and
    /// are **not** counted.
    pub fn get_mut(&mut self) -> &mut W {
        &mut self.inner
    }

    /// Consumes the adapter and returns the wrapped writer.
    pub fn into_inner(self) -> W {
        self.inner
    }
}

// `write_all` is intentionally not overridden: the provided default already
// loops over `write`, retries on `ErrorKind::Interrupted`, and fails with
// `ErrorKind::WriteZero` on `Ok(0)`. Every chunk it writes goes through
// `write` below, so it is counted exactly once.
impl<W: Write> Write for CountingWriter<W> {
    /// Forwards `buf` to the inner writer and counts the bytes it accepted.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let n = self.inner.write(buf)?;
        // `usize` is no wider than 64 bits on current Rust targets, so this is lossless.
        self.bytes += n as u64;
        Ok(n)
    }

    /// Forwards vectored writes so inner writers that support scatter/gather
    /// I/O keep that ability, counting the bytes the inner writer accepted.
    fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
        let n = self.inner.write_vectored(bufs)?;
        self.bytes += n as u64;
        Ok(n)
    }

    /// Flushes the inner writer; does not affect the byte count.
    fn flush(&mut self) -> io::Result<()> {
        self.inner.flush()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Sequential `write_all` calls accumulate the count and forward every byte.
    #[test]
    fn counts_bytes_across_writes() {
        let mut w = CountingWriter::new(Vec::new());
        w.write_all(b"hello").unwrap();
        assert_eq!(w.bytes(), 5);
        w.write_all(b", world").unwrap();
        assert_eq!(w.bytes(), 12);
        assert_eq!(w.inner, b"hello, world");
    }

    // Writing nothing must not count anything or touch the inner writer.
    #[test]
    fn empty_write_all_counts_nothing() {
        let mut w = CountingWriter::new(Vec::<u8>::new());
        w.write_all(b"").unwrap();
        assert_eq!(w.bytes(), 0);
        assert!(w.inner.is_empty());
    }

    // An inner writer that accepts at most three bytes per call forces
    // `write_all` to loop; every short write must be counted exactly once.
    #[test]
    fn partial_writes_increment_correctly() {
        struct Partial(Vec<u8>);
        impl Write for Partial {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                let n = buf.len().min(3);
                self.0.extend_from_slice(&buf[..n]);
                Ok(n)
            }
            fn flush(&mut self) -> io::Result<()> {
                Ok(())
            }
        }
        let mut w = CountingWriter::new(Partial(Vec::new()));
        w.write_all(b"hello, world").unwrap();
        assert_eq!(w.bytes(), 12);
        assert_eq!(w.inner.0, b"hello, world");
    }

    // `ErrorKind::Interrupted` is retried rather than surfaced, and the failed
    // attempt adds nothing to the count.
    #[test]
    fn write_all_retries_on_interrupted() {
        struct Flaky {
            data: Vec<u8>,
            interrupted: bool,
        }
        impl Write for Flaky {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                if !self.interrupted {
                    self.interrupted = true;
                    return Err(io::Error::from(io::ErrorKind::Interrupted));
                }
                self.data.extend_from_slice(buf);
                Ok(buf.len())
            }
            fn flush(&mut self) -> io::Result<()> {
                Ok(())
            }
        }
        let mut w = CountingWriter::new(Flaky {
            data: Vec::new(),
            interrupted: false,
        });
        w.write_all(b"abc").unwrap();
        assert_eq!(w.bytes(), 3);
        assert_eq!(w.inner.data, b"abc");
    }

    // An inner writer that accepts nothing makes `write_all` fail with
    // `WriteZero` instead of spinning forever; the count stays at zero.
    #[test]
    fn write_all_reports_write_zero() {
        struct Stalled;
        impl Write for Stalled {
            fn write(&mut self, _buf: &[u8]) -> io::Result<usize> {
                Ok(0)
            }
            fn flush(&mut self) -> io::Result<()> {
                Ok(())
            }
        }
        let mut w = CountingWriter::new(Stalled);
        let err = w.write_all(b"abc").unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::WriteZero);
        assert_eq!(w.bytes(), 0);
    }

    // `flush` is passed straight through to the inner writer.
    #[test]
    fn flush_is_forwarded() {
        struct Flushy(bool);
        impl Write for Flushy {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                Ok(buf.len())
            }
            fn flush(&mut self) -> io::Result<()> {
                self.0 = true;
                Ok(())
            }
        }
        let mut w = CountingWriter::new(Flushy(false));
        w.flush().unwrap();
        assert!(w.inner.0);
        assert_eq!(w.bytes(), 0);
    }
}