use std::io::Write;
/// A writer adapter that forwards everything to an inner writer while keeping
/// a running total of the bytes written through it.
///
/// Bytes pushed into the inner writer by other means (for example through a
/// mutable reference to it) are not included in the total.
#[derive(Debug)]
pub struct CountingWriter<W> {
    /// The wrapped writer that receives the data.
    inner: W,
    /// Number of bytes written through this wrapper so far.
    bytes_written: u64,
}
impl<W> CountingWriter<W> {
#[must_use]
pub fn new(inner: W) -> Self {
Self {
inner,
bytes_written: 0,
}
}
#[must_use]
pub fn total_bytes(&self) -> u64 {
self.bytes_written
}
#[must_use]
pub fn into_inner(self) -> W {
self.inner
}
#[must_use]
pub fn get_ref(&self) -> &W {
&self.inner
}
pub fn get_mut(&mut self) -> &mut W {
&mut self.inner
}
}
impl<W: Write> Write for CountingWriter<W> {
    /// Forwards `buf` to the inner writer and adds the number of bytes it
    /// accepted to the running total.
    ///
    /// # Errors
    ///
    /// Returns whatever error the inner writer reports; the total is left
    /// unchanged in that case.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let accepted = self.inner.write(buf)?;
        self.bytes_written += accepted as u64;
        Ok(accepted)
    }

    /// Flushes the inner writer. The byte total is not affected.
    ///
    /// # Errors
    ///
    /// Returns whatever error the inner writer reports.
    fn flush(&mut self) -> std::io::Result<()> {
        self.inner.flush()
    }

    // `write_all` is deliberately NOT overridden. Delegating to
    // `self.inner.write_all` and adding `buf.len()` only on success loses
    // track of bytes the inner writer accepted before a mid-buffer failure
    // (e.g. a full fixed-size sink), so the total would undercount. The
    // trait's default `write_all` loops over `Self::write` above, which
    // records every accepted chunk, keeping the total exact on both the
    // success and the error path. `write_fmt` / `write!` go through
    // `write_all`, so they benefit as well.
}
#[cfg(test)]
// Tests unwrap freely: an unexpected I/O error should fail the test on the spot.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Consecutive `write_all` calls accumulate into the total and reach the sink.
    #[test]
    fn test_counting_writer_basic() {
        let mut sink: Vec<u8> = Vec::new();
        let mut counter = CountingWriter::new(&mut sink);

        counter.write_all(b"Hello").unwrap();
        assert_eq!(counter.total_bytes(), 5);

        counter.write_all(b", World!").unwrap();
        assert_eq!(counter.total_bytes(), 13);

        assert_eq!(sink, b"Hello, World!");
    }

    /// A single `write` reports and records the number of bytes accepted.
    #[test]
    fn test_counting_writer_write() {
        let mut sink: Vec<u8> = Vec::new();
        let mut counter = CountingWriter::new(&mut sink);

        let accepted = counter.write(b"test").unwrap();

        assert_eq!(accepted, 4);
        assert_eq!(counter.total_bytes(), 4);
    }

    /// Formatted output is counted by its rendered length.
    #[test]
    fn test_counting_writer_write_fmt() {
        let mut sink: Vec<u8> = Vec::new();
        let mut counter = CountingWriter::new(&mut sink);

        counter.write_fmt(format_args!("test {}", 42)).unwrap();

        assert_eq!(counter.total_bytes(), 7);
        assert_eq!(sink, b"test 42");
    }

    /// Flushing leaves the total untouched.
    #[test]
    fn test_counting_writer_flush() {
        let mut sink: Vec<u8> = Vec::new();
        let mut counter = CountingWriter::new(&mut sink);

        counter.write_all(b"data").unwrap();
        counter.flush().unwrap();

        assert_eq!(counter.total_bytes(), 4);
    }

    /// `into_inner` returns the wrapped writer with everything written to it.
    #[test]
    fn test_counting_writer_into_inner() {
        let mut counter = CountingWriter::new(Vec::<u8>::new());

        counter.write_all(b"test").unwrap();
        assert_eq!(counter.total_bytes(), 4);

        let recovered = counter.into_inner();
        assert_eq!(recovered, b"test");
    }

    /// `get_ref` exposes the wrapped writer's current contents.
    #[test]
    fn test_counting_writer_get_ref() {
        let mut counter = CountingWriter::new(Vec::<u8>::new());

        counter.write_all(b"test").unwrap();

        assert_eq!(counter.get_ref(), &b"test"[..]);
    }

    /// Data pushed through `get_mut` lands in the sink but is not counted.
    #[test]
    fn test_counting_writer_get_mut() {
        let mut counter = CountingWriter::new(Vec::<u8>::new());

        counter.write_all(b"test").unwrap();
        counter.get_mut().push(b'!');

        assert_eq!(counter.total_bytes(), 4);
        assert_eq!(counter.get_ref(), &b"test!"[..]);
    }

    /// A freshly constructed wrapper reports zero bytes.
    #[test]
    fn test_counting_writer_empty() {
        let counter = CountingWriter::new(Vec::<u8>::new());

        assert_eq!(counter.total_bytes(), 0);
    }

    /// Many small formatted writes add up one byte at a time.
    #[test]
    fn test_counting_writer_multiple_writes() {
        let mut sink: Vec<u8> = Vec::new();
        let mut counter = CountingWriter::new(&mut sink);

        (0..10).for_each(|i| write!(counter, "{i}").unwrap());

        assert_eq!(counter.total_bytes(), 10);
        assert_eq!(sink, b"0123456789");
    }

    /// The wrapper works over a `Cursor` backed by a pre-sized buffer.
    #[test]
    fn test_counting_writer_with_cursor() {
        let backing = Cursor::new(vec![0u8; 100]);
        let mut counter = CountingWriter::new(backing);

        counter.write_all(b"test data").unwrap();

        assert_eq!(counter.total_bytes(), 9);
    }

    /// When the inner writer accepts only part of a buffer, only that part is counted.
    #[test]
    fn test_counting_writer_partial_write() {
        /// Test double that accepts at most `max_write` bytes per `write` call.
        struct LimitedWriter {
            inner: Vec<u8>,
            max_write: usize,
        }

        impl Write for LimitedWriter {
            fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
                let accepted = std::cmp::min(buf.len(), self.max_write);
                let (head, _rest) = buf.split_at(accepted);
                self.inner.extend_from_slice(head);
                Ok(accepted)
            }

            fn flush(&mut self) -> std::io::Result<()> {
                Ok(())
            }
        }

        let mut counter = CountingWriter::new(LimitedWriter {
            inner: Vec::new(),
            max_write: 3,
        });

        let accepted = counter.write(b"hello").unwrap();

        assert_eq!(accepted, 3);
        assert_eq!(counter.total_bytes(), 3);
        assert_eq!(counter.get_ref().inner, b"hel");
    }
}