use std::io::{self, Write};
use crate::constants::*;
use crate::crc::crc;
use crate::encode::encode;
/// Buffering writer that compresses data into a chunked stream on `writer`.
///
/// Bytes passed to [`Write::write`] are collected in an in-memory buffer of
/// at most `block_size` bytes. When the buffer is full and more input
/// arrives, or when the writer is flushed, the buffer is compressed with
/// `crate::encode::encode` and emitted as one chunk. The output consists of:
///
/// - a one-time stream header (`MAGIC_CHUNK`) before the first chunk,
/// - the chunk type byte `CHUNK_TYPE_COMPRESSED_DATA`,
/// - a 3-byte little-endian length (compressed size + `CHECKSUM_SIZE`),
/// - a little-endian checksum from `crate::crc::crc` over the *uncompressed* block,
/// - the compressed bytes.
///
/// NOTE(review): this layout looks like the Snappy framing format; confirm
/// against `crate::constants` before relying on interoperability.
///
/// Pending data is flushed when the writer is dropped, but errors there are
/// ignored; call [`Write::flush`] explicitly to observe them.
pub struct Writer<W: Write> {
    /// Destination for the compressed stream.
    writer: W,
    /// Uncompressed bytes not yet emitted; never longer than `block_size`.
    buf: Vec<u8>,
    /// Maximum number of uncompressed bytes per chunk (clamped at construction).
    block_size: usize,
    /// Whether `MAGIC_CHUNK` has already been written to `writer`.
    wrote_header: bool,
}
impl<W: Write> Writer<W> {
    /// Creates a writer that compresses into `writer` using blocks of
    /// `DEFAULT_BLOCK_SIZE` uncompressed bytes.
    pub fn new(writer: W) -> Self {
        Self::with_block_size(writer, DEFAULT_BLOCK_SIZE)
    }

    /// Creates a writer that compresses into `writer`, buffering up to
    /// `block_size` uncompressed bytes per chunk.
    ///
    /// `block_size` is clamped into `MIN_BLOCK_SIZE..=MAX_BLOCK_SIZE`, so
    /// out-of-range values are silently adjusted rather than rejected.
    pub fn with_block_size(writer: W, block_size: usize) -> Self {
        Self {
            writer,
            buf: Vec::new(),
            block_size: block_size.clamp(MIN_BLOCK_SIZE, MAX_BLOCK_SIZE),
            wrote_header: false,
        }
    }

    /// Emits the stream header (`MAGIC_CHUNK`) the first time it is called
    /// after construction or [`reset`](Self::reset). Later calls are no-ops.
    fn write_header(&mut self) -> io::Result<()> {
        if self.wrote_header {
            return Ok(());
        }
        self.writer.write_all(MAGIC_CHUNK)?;
        self.wrote_header = true;
        Ok(())
    }

    /// Compresses the buffered bytes into a single compressed-data chunk and
    /// writes it to the inner writer. The stream header is written first if
    /// it has not been written yet.
    ///
    /// Does nothing when the buffer is empty. The buffer is cleared only after
    /// the whole chunk has been written.
    ///
    /// # Errors
    ///
    /// - `InvalidData` if the chunk length exceeds `MAX_CHUNK_SIZE`. The header
    ///   may already have been written at that point, and the buffer is kept.
    /// - Any I/O error from the inner writer. A failure partway through may
    ///   leave a truncated chunk in the inner writer while the buffer is kept.
    fn flush_block(&mut self) -> io::Result<()> {
        if self.buf.is_empty() {
            return Ok(());
        }
        self.write_header()?;

        let payload = encode(&self.buf);
        let checksum = crc(&self.buf);
        // The length field covers the checksum as well as the payload.
        let frame_len = payload.len() + CHECKSUM_SIZE;
        if frame_len > MAX_CHUNK_SIZE {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "compressed block too large",
            ));
        }

        self.writer.write_all(&[CHUNK_TYPE_COMPRESSED_DATA])?;
        // 3-byte little-endian length field: the low three bytes of `frame_len`.
        self.writer.write_all(&frame_len.to_le_bytes()[..3])?;
        self.writer.write_all(&checksum.to_le_bytes())?;
        self.writer.write_all(&payload)?;

        self.buf.clear();
        Ok(())
    }

    /// Swaps in a new inner writer and returns the previous one.
    ///
    /// Buffered data that has not been flushed is discarded and is not
    /// written to either writer. The stream header will be emitted again
    /// before the next chunk goes to the new writer.
    pub fn reset(&mut self, writer: W) -> W {
        let previous = std::mem::replace(&mut self.writer, writer);
        self.buf.clear();
        self.wrote_header = false;
        previous
    }

    /// Returns a shared reference to the inner writer.
    pub fn get_ref(&self) -> &W {
        &self.writer
    }

    /// Returns a mutable reference to the inner writer.
    ///
    /// Writing through this reference bypasses the chunk buffer and can
    /// interleave raw bytes with the compressed stream.
    pub fn get_mut(&mut self) -> &mut W {
        &mut self.writer
    }
}
impl<W: Write> Write for Writer<W> {
    /// Buffers `buf`, compressing and emitting a block whenever the internal
    /// buffer is already full and more input remains.
    ///
    /// A full buffer is flushed only when more input arrives or on
    /// [`flush`](Write::flush). A call that exactly fills the buffer therefore
    /// returns without writing to the inner writer.
    ///
    /// # Errors
    ///
    /// Returns `Err` only if no bytes of `buf` were consumed. The
    /// [`Write::write`] contract says an error means nothing was written, and
    /// `write_all` relies on that to retry. If a block flush fails after part
    /// of `buf` is already buffered, the count consumed so far is returned as
    /// `Ok(n)` instead. Returning `Err` there would make callers resubmit,
    /// and so duplicate, bytes that are already buffered. The failed block
    /// stays buffered, so the error resurfaces on the next call.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let mut written = 0;
        while written < buf.len() {
            let space_in_buf = self.block_size - self.buf.len();
            if space_in_buf == 0 {
                if let Err(err) = self.flush_block() {
                    // Bytes already copied from `buf` are held in our buffer;
                    // report them instead of claiming nothing was consumed.
                    return if written > 0 { Ok(written) } else { Err(err) };
                }
                continue;
            }
            let to_write = (buf.len() - written).min(space_in_buf);
            self.buf
                .extend_from_slice(&buf[written..written + to_write]);
            written += to_write;
        }
        Ok(written)
    }

    /// Compresses and writes any buffered data, then flushes the inner writer.
    fn flush(&mut self) -> io::Result<()> {
        self.flush_block()?;
        self.writer.flush()
    }
}
impl<W: Write> Drop for Writer<W> {
    /// Best-effort flush of any buffered data when the writer goes away.
    ///
    /// `drop` has no way to report failure, so any error is deliberately
    /// discarded. Callers that need to see flush errors should call
    /// [`Write::flush`] explicitly before the writer goes out of scope.
    fn drop(&mut self) {
        self.flush().ok();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Streams `parts` through a default-sized `Writer`, flushing explicitly
    /// before the writer is dropped, and returns the compressed bytes.
    fn compress_parts(parts: &[&[u8]]) -> Vec<u8> {
        let mut out = Vec::new();
        {
            let mut writer = Writer::new(&mut out);
            for &part in parts {
                writer.write_all(part).unwrap();
            }
            writer.flush().unwrap();
        }
        out
    }

    /// Checks that `stream` starts with the magic header and contains more than just the header.
    fn assert_framed(stream: &[u8]) {
        let header_len = MAGIC_CHUNK.len();
        assert!(stream.len() > header_len);
        assert_eq!(&stream[..header_len], MAGIC_CHUNK);
    }

    #[test]
    fn test_writer_basic() {
        assert_framed(&compress_parts(&[&b"Hello, World!"[..]]));
    }

    #[test]
    fn test_writer_empty() {
        let mut out = Vec::<u8>::new();
        // Construct and immediately drop a writer that never received data.
        drop(Writer::new(&mut out));
        assert!(out.is_empty());
    }

    #[test]
    fn test_writer_multiple_writes() {
        assert_framed(&compress_parts(&[&b"Hello, "[..], &b"World!"[..]]));
    }

    #[test]
    fn test_writer_large_data() {
        let data = vec![b'A'; 100_000];
        let out = compress_parts(&[&data[..]]);
        assert!(out.len() < data.len() / 2);
    }
}