#![allow(dead_code)]
use std::io::{self, Write};
/// Metadata for one chunk emitted by [`ChunkedWriter`], passed by reference to
/// the writer's callback right after the chunk's payload was written.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChunkInfo {
    /// Zero-based sequence number of the chunk.
    pub index: u64,
    /// Payload bytes in this chunk; between 1 and the writer's chunk size.
    pub payload_len: usize,
    /// Payload bytes written to the inner writer so far, including this chunk.
    pub cumulative_bytes: u64,
    /// `true` for a chunk emitted by [`ChunkedWriter::finish`] (which
    /// [`Write::flush`] and [`ChunkedWriter::into_inner`] also call); `false`
    /// for chunks emitted because more data arrived after them.
    pub is_final: bool,
}

/// A [`Write`] adapter that forwards bytes to `inner` in fixed-size chunks.
///
/// Bytes are buffered until `chunk_size` of them have been collected. Each chunk
/// is written to `inner` with one `write_all` call and then reported to
/// `callback` as a [`ChunkInfo`].
///
/// A full chunk is held back until more data arrives or the writer is
/// finished. That way the last chunk of the stream is always the one reported
/// with `is_final == true`, including when the total length is an exact
/// multiple of `chunk_size`.
pub struct ChunkedWriter<W, F> {
    /// Destination for chunk payloads.
    inner: W,
    /// Invoked once per chunk, after its payload was written to `inner`.
    callback: F,
    /// Accepted bytes not yet written to `inner` (never more than `chunk_size`).
    buffer: Vec<u8>,
    /// Maximum payload size of one chunk; always > 0.
    chunk_size: usize,
    /// Number of chunks emitted so far.
    chunk_count: u64,
    /// Payload bytes written to `inner` so far (excludes `buffer`).
    total_bytes: u64,
}

impl<W: Write, F: FnMut(&ChunkInfo)> ChunkedWriter<W, F> {
    /// Creates a writer that forwards data to `inner` in chunks of `chunk_size`
    /// bytes, calling `callback` once per emitted chunk.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is 0.
    pub fn new(inner: W, chunk_size: usize, callback: F) -> Self {
        assert!(chunk_size > 0, "chunk_size must be > 0");
        Self {
            inner,
            callback,
            buffer: Vec::with_capacity(chunk_size),
            chunk_size,
            chunk_count: 0,
            total_bytes: 0,
        }
    }

    /// Number of chunks emitted to the inner writer so far.
    #[must_use]
    pub fn chunk_count(&self) -> u64 {
        self.chunk_count
    }

    /// Payload bytes written to the inner writer so far.
    ///
    /// Bytes still waiting in the buffer are not counted.
    #[must_use]
    pub fn total_bytes(&self) -> u64 {
        self.total_bytes
    }

    /// The maximum payload size of a chunk, as passed to [`ChunkedWriter::new`].
    #[must_use]
    pub fn chunk_size(&self) -> usize {
        self.chunk_size
    }

    /// Number of accepted bytes not yet emitted as a chunk.
    ///
    /// This may equal [`ChunkedWriter::chunk_size`]: a full chunk is held back
    /// until more data arrives or the writer is finished.
    #[must_use]
    pub fn buffered_len(&self) -> usize {
        self.buffer.len()
    }

    /// Emits any buffered bytes as a chunk marked `is_final`, then flushes the
    /// inner writer.
    ///
    /// With an empty buffer no chunk is emitted. Writing again afterwards starts
    /// a new run of chunks whose indices continue from the current count.
    ///
    /// # Errors
    ///
    /// Returns any error from writing the chunk or flushing `inner`. On error
    /// the buffer and counters are left unchanged, so the call can be retried.
    /// `write_all` does not report how many bytes `inner` accepted before
    /// failing, so a retry may re-send some of them.
    pub fn finish(&mut self) -> io::Result<()> {
        if !self.buffer.is_empty() {
            self.flush_chunk(true)?;
        }
        self.inner.flush()
    }

    /// Finishes the stream (see [`ChunkedWriter::finish`]) and returns the
    /// inner writer.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`ChunkedWriter::finish`].
    pub fn into_inner(mut self) -> io::Result<W> {
        self.finish()?;
        Ok(self.inner)
    }

    /// Writes the whole buffer to `inner` as one chunk, updates the counters,
    /// reports the chunk to the callback, and empties the buffer.
    ///
    /// If `write_all` fails, it returns early: no counter changes, no callback,
    /// and the buffer is kept.
    fn flush_chunk(&mut self, is_final: bool) -> io::Result<()> {
        let payload_len = self.buffer.len();
        self.inner.write_all(&self.buffer)?;
        self.total_bytes += payload_len as u64;
        let info = ChunkInfo {
            index: self.chunk_count,
            payload_len,
            cumulative_bytes: self.total_bytes,
            is_final,
        };
        (self.callback)(&info);
        self.chunk_count += 1;
        self.buffer.clear();
        Ok(())
    }
}

impl<W: Write, F: FnMut(&ChunkInfo)> Write for ChunkedWriter<W, F> {
    /// Buffers `buf`, emitting a full chunk only once further bytes arrive.
    ///
    /// Per the `Write::write` contract, an error is returned only if no byte of
    /// `buf` was accepted. Suppose emitting a chunk fails after part of `buf`
    /// was already buffered. Then the count of those bytes is returned instead,
    /// and the error resurfaces on the next call, which retries the still-full
    /// buffer first.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let mut accepted = 0;
        while accepted < buf.len() {
            // Emit a previously filled chunk only now that more data exists, so
            // the stream's last chunk is always left for `finish` (is_final).
            if self.buffer.len() >= self.chunk_size {
                if let Err(err) = self.flush_chunk(false) {
                    return if accepted == 0 { Err(err) } else { Ok(accepted) };
                }
            }
            let room = self.chunk_size - self.buffer.len();
            let take = room.min(buf.len() - accepted);
            self.buffer
                .extend_from_slice(&buf[accepted..accepted + take]);
            accepted += take;
        }
        Ok(accepted)
    }

    /// Same as [`ChunkedWriter::finish`].
    ///
    /// Any buffered bytes are emitted as a chunk marked `is_final`, even when
    /// flushing mid-stream.
    fn flush(&mut self) -> io::Result<()> {
        self.finish()
    }
}
/// A [`Write`] adapter that hands data to `inner` in blocks of `alignment`
/// bytes.
///
/// Each full block is written as soon as it fills. A partial trailing block is
/// padded with `pad_byte` up to `alignment` bytes by
/// [`AlignedChunkWriter::finish`].
#[derive(Debug)]
pub struct AlignedChunkWriter<W> {
    /// Destination for the aligned blocks.
    inner: W,
    /// Block size in bytes; a power of two.
    alignment: usize,
    /// Byte used to pad the final partial block (0 unless overridden).
    pad_byte: u8,
    /// Bytes of the current, not yet written block.
    buffer: Vec<u8>,
    /// Number of blocks successfully written to `inner`.
    chunks_written: u64,
}

impl<W: Write> AlignedChunkWriter<W> {
    /// Creates a writer producing `alignment`-byte blocks padded with `0`.
    ///
    /// # Panics
    ///
    /// Panics if `alignment` is not a power of two (this includes 0).
    pub fn new(inner: W, alignment: usize) -> Self {
        // `is_power_of_two` is false for 0, so this also rejects a zero size.
        assert!(
            alignment.is_power_of_two(),
            "alignment must be a power of two"
        );
        Self {
            inner,
            alignment,
            pad_byte: 0,
            buffer: Vec::with_capacity(alignment),
            chunks_written: 0,
        }
    }

    /// Returns the writer configured to pad partial blocks with `byte`.
    #[must_use]
    pub fn with_pad_byte(mut self, byte: u8) -> Self {
        self.pad_byte = byte;
        self
    }

    /// Number of blocks written to the inner writer so far.
    #[must_use]
    pub fn chunks_written(&self) -> u64 {
        self.chunks_written
    }

    /// Pads any buffered partial block with the pad byte, writes it, and
    /// flushes the inner writer.
    ///
    /// # Errors
    ///
    /// Returns any error from writing the block or flushing `inner`. On a write
    /// error the (already padded) block stays buffered for a retry.
    pub fn finish(&mut self) -> io::Result<()> {
        if !self.buffer.is_empty() {
            // The buffer can already be a whole block if an earlier write to
            // `inner` failed; only round up when there is a partial tail.
            let tail = self.buffer.len() % self.alignment;
            if tail != 0 {
                let padded_len = self.buffer.len() + (self.alignment - tail);
                self.buffer.resize(padded_len, self.pad_byte);
            }
            self.write_block()?;
        }
        self.inner.flush()
    }

    /// Writes the buffered block to `inner`, counts it, and empties the buffer.
    ///
    /// If `write_all` fails, it returns early with the buffer and counter left
    /// unchanged.
    fn write_block(&mut self) -> io::Result<()> {
        self.inner.write_all(&self.buffer)?;
        self.chunks_written += 1;
        self.buffer.clear();
        Ok(())
    }
}

impl<W: Write> Write for AlignedChunkWriter<W> {
    /// Buffers `buf`, writing each block to `inner` as soon as it is full.
    ///
    /// Per the `Write::write` contract, an error is returned only if no byte of
    /// `buf` was accepted. Suppose writing a block fails after part of `buf`
    /// was buffered. Then the count of those bytes is returned, and the still
    /// full block is retried (and its error surfaced) on the next call.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let mut accepted = 0;
        while accepted < buf.len() {
            let room = self.alignment - self.buffer.len();
            let take = room.min(buf.len() - accepted);
            self.buffer
                .extend_from_slice(&buf[accepted..accepted + take]);
            accepted += take;
            if self.buffer.len() >= self.alignment {
                if let Err(err) = self.write_block() {
                    return if accepted == 0 { Err(err) } else { Ok(accepted) };
                }
            }
        }
        Ok(accepted)
    }

    /// Same as [`AlignedChunkWriter::finish`].
    ///
    /// Flushing mid-stream pads the current partial block, so padding bytes end
    /// up between the data written before and after the flush.
    fn flush(&mut self) -> io::Result<()> {
        self.finish()
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `ChunkedWriter` and `AlignedChunkWriter`.
    use super::*;

    /// Pushes `data` through a finished `ChunkedWriter` of the given chunk size
    /// and returns the produced bytes plus every reported `ChunkInfo`.
    fn run_chunked(data: &[u8], chunk_size: usize) -> (Vec<u8>, Vec<ChunkInfo>) {
        let mut sink = Vec::new();
        let mut seen = Vec::new();
        {
            let mut writer =
                ChunkedWriter::new(&mut sink, chunk_size, |info: &ChunkInfo| {
                    seen.push(info.clone())
                });
            writer.write_all(data).unwrap();
            writer.finish().unwrap();
        }
        (sink, seen)
    }

    /// Pushes `data` through a finished `AlignedChunkWriter`, optionally with a
    /// custom pad byte, and returns the produced bytes.
    fn run_aligned(data: &[u8], alignment: usize, pad_byte: Option<u8>) -> Vec<u8> {
        let mut sink = Vec::new();
        {
            let writer = AlignedChunkWriter::new(&mut sink, alignment);
            let mut writer = match pad_byte {
                Some(byte) => writer.with_pad_byte(byte),
                None => writer,
            };
            writer.write_all(data).unwrap();
            writer.finish().unwrap();
        }
        sink
    }

    #[test]
    fn test_chunked_basic() {
        let (out, infos) = run_chunked(b"abcdefghij", 4);
        assert_eq!(out, b"abcdefghij");
        let lens: Vec<usize> = infos.iter().map(|info| info.payload_len).collect();
        assert_eq!(lens, [4, 4, 2]);
        assert!(infos[2].is_final);
    }

    #[test]
    fn test_chunked_exact_multiple() {
        let (out, infos) = run_chunked(b"12345", 5);
        assert_eq!(out, b"12345");
        assert_eq!(infos.len(), 1);
    }

    #[test]
    fn test_chunked_empty() {
        let mut sink = Vec::new();
        let mut calls = 0u64;
        {
            let mut writer = ChunkedWriter::new(&mut sink, 10, |_: &ChunkInfo| calls += 1);
            writer.finish().unwrap();
        }
        assert!(sink.is_empty());
        assert_eq!(calls, 0);
    }

    #[test]
    fn test_chunk_info_fields() {
        let (_, infos) = run_chunked(b"abcdef", 3);
        let summary: Vec<(u64, u64)> = infos
            .iter()
            .take(2)
            .map(|info| (info.index, info.cumulative_bytes))
            .collect();
        assert_eq!(summary, [(0, 3), (1, 6)]);
    }

    #[test]
    fn test_chunked_total_bytes() {
        let mut sink = Vec::new();
        let mut writer = ChunkedWriter::new(&mut sink, 8, |_: &ChunkInfo| {});
        writer.write_all(b"hello world").unwrap();
        writer.finish().unwrap();
        assert_eq!(writer.total_bytes(), 11);
    }

    #[test]
    fn test_chunked_chunk_count() {
        let mut sink = Vec::new();
        let mut writer = ChunkedWriter::new(&mut sink, 4, |_: &ChunkInfo| {});
        writer.write_all(b"0123456789ab").unwrap();
        writer.finish().unwrap();
        assert_eq!(writer.chunk_count(), 3);
    }

    #[test]
    fn test_chunked_buffered_len() {
        let mut sink = Vec::new();
        let mut writer = ChunkedWriter::new(&mut sink, 10, |_: &ChunkInfo| {});
        writer.write_all(b"abc").unwrap();
        assert_eq!(writer.buffered_len(), 3);
    }

    #[test]
    #[should_panic(expected = "chunk_size must be > 0")]
    fn test_chunked_zero_size_panics() {
        let _ = ChunkedWriter::new(Vec::new(), 0, |_: &ChunkInfo| {});
    }

    #[test]
    fn test_aligned_basic() {
        let out = run_aligned(b"hello", 8, None);
        assert_eq!(out.len(), 8);
        let (data, padding) = out.split_at(5);
        assert_eq!(data, b"hello");
        assert!(padding.iter().all(|&byte| byte == 0));
    }

    #[test]
    fn test_aligned_exact() {
        assert_eq!(run_aligned(b"abcd", 4, None), b"abcd");
    }

    #[test]
    fn test_aligned_pad_byte() {
        let out = run_aligned(b"hi", 8, Some(0xFF));
        assert_eq!(out.len(), 8);
        let (data, padding) = out.split_at(2);
        assert_eq!(data, b"hi");
        assert!(padding.iter().all(|&byte| byte == 0xFF));
    }

    #[test]
    fn test_aligned_chunks_written() {
        let mut sink = Vec::new();
        let mut writer = AlignedChunkWriter::new(&mut sink, 4);
        writer.write_all(b"12345678").unwrap();
        assert_eq!(writer.chunks_written(), 2);
    }

    #[test]
    #[should_panic(expected = "alignment must be a power of two")]
    fn test_aligned_non_power_of_two_panics() {
        let _ = AlignedChunkWriter::new(Vec::new(), 3);
    }
}