use std::io::Write;
use crate::error::{Error, Result};
use super::seek_table::{FrameEntry, SeekTable};
/// Compression level used by [`WriterConfig::default`].
pub const DEFAULT_LEVEL: i32 = 3;
/// Frame size threshold, in bytes, used by [`WriterConfig::default`] (2 MiB).
pub const DEFAULT_FRAME_SIZE: usize = 2 << 20;
/// Largest number of bytes a single frame may hold, compressed or uncompressed.
///
/// Frame sizes are recorded in the seek table as `u32` values, hence the cap.
pub const MAX_FRAME_SIZE: usize = u32::MAX as usize;
/// Tunable settings for a `Writer`.
///
/// Derives `Debug` and `Clone` so a configuration can be logged and reused
/// for several writers, matching the derives on `WriterStats`.
#[derive(Debug, Clone)]
pub struct WriterConfig {
    /// Compression level handed to zstd for every frame.
    pub level: i32,
    /// Number of buffered bytes at which a frame is cut. Must lie in
    /// `1..=MAX_FRAME_SIZE`; `Writer::new` panics otherwise.
    pub frame_size: usize,
    /// When `true`, each seek-table entry carries a 32-bit checksum of the
    /// frame's uncompressed bytes.
    pub checksums: bool,
    /// When `true`, frames are cut right after a `\n` so that no record spans
    /// two frames, and `finish` terminates a trailing partial record with `\n`.
    pub record_aligned: bool,
}
impl Default for WriterConfig {
fn default() -> Self {
Self {
level: DEFAULT_LEVEL,
frame_size: DEFAULT_FRAME_SIZE,
checksums: true,
record_aligned: true,
}
}
}
/// Counters accumulated by a `Writer` and returned from `Writer::finish`.
#[derive(Debug, Clone, Default)]
pub struct WriterStats {
    /// Number of frames written to the output.
    pub frames: u64,
    /// Number of `\n` bytes seen in the input, plus one if `finish` had to
    /// terminate a trailing partial record.
    pub records: u64,
    /// Total bytes handed to the compressor.
    pub uncompressed_bytes: u64,
    /// Total compressed frame bytes written (the seek table is not included).
    pub compressed_bytes: u64,
    /// Number of frames that had to be cut mid-record because no newline was
    /// found before the maximum frame size was reached.
    pub oversized_frames: u64,
    /// Value reported by the seek table when it was written out by `finish`.
    pub seek_table_bytes: u64,
}
/// Buffers input bytes, cuts them into independently compressed frames, and
/// records every frame in a seek table that is written out by `finish`.
pub struct Writer<W>
where
    W: Write,
{
    /// Destination for compressed frames and, at the end, the seek table.
    out: W,
    /// Settings supplied at construction time.
    cfg: WriterConfig,
    /// Input bytes accepted by `write_block` that have not been flushed yet.
    pending: Vec<u8>,
    /// One entry per frame written so far.
    table: SeekTable,
    /// Running counters, returned to the caller by `finish`.
    stats: WriterStats,
    /// Set once the "splitting mid-record" warning has been logged, so that
    /// it is emitted at most once per writer.
    warned_oversized: bool,
}
impl<W: Write> Writer<W> {
/// Creates a writer that emits frames to `out` according to `cfg`.
///
/// A 4 MiB input buffer is reserved up front, and the seek table is created
/// with checksums enabled or disabled to match `cfg.checksums`.
///
/// # Panics
///
/// Panics if `cfg.frame_size` is zero or greater than `MAX_FRAME_SIZE`.
pub fn new(out: W, cfg: WriterConfig) -> Self {
    let frame_size_ok = (1..=MAX_FRAME_SIZE).contains(&cfg.frame_size);
    assert!(
        frame_size_ok,
        "frame size {} outside [1, MAX_FRAME_SIZE={}]",
        cfg.frame_size,
        MAX_FRAME_SIZE
    );
    let pending = Vec::with_capacity(4 << 20);
    let table = SeekTable::new(cfg.checksums);
    Self {
        out,
        cfg,
        pending,
        table,
        stats: WriterStats::default(),
        warned_oversized: false,
    }
}
/// Appends `block` to the internal buffer and writes out every frame that
/// has become complete as a result.
///
/// Each `\n` in `block` is counted as one record. Blocks need not end on a
/// record boundary; an empty block is a no-op.
///
/// # Errors
///
/// Propagates any error raised while flushing a frame.
pub fn write_block(&mut self, block: &[u8]) -> Result<()> {
    if !block.is_empty() {
        let newline_count = memchr::memchr_iter(b'\n', block).count();
        self.stats.records += newline_count as u64;
        self.pending.extend_from_slice(block);
        // Keep cutting frames until the buffer no longer holds a full one.
        loop {
            if !self.try_flush_frame()? {
                break;
            }
        }
    }
    Ok(())
}
/// Flushes at most one frame from the front of the pending buffer.
///
/// Returns `Ok(true)` when a frame was written (the caller should call again)
/// and `Ok(false)` when more input is needed before a frame can be cut.
///
/// * Without record alignment, exactly `cfg.frame_size` bytes are flushed.
/// * With record alignment, the frame ends just after the first `\n` at or
///   beyond byte `frame_size - 1`, so it holds at least `frame_size` bytes
///   and ends on a record boundary. Only the first `MAX_FRAME_SIZE` buffered
///   bytes are eligible: if they contain no such newline, exactly
///   `MAX_FRAME_SIZE` bytes are flushed mid-record, `oversized_frames` is
///   incremented, and a warning is logged once per writer.
///
/// # Errors
///
/// Propagates any error from `flush_prefix`.
fn try_flush_frame(&mut self) -> Result<bool> {
    let buffered = self.pending.len();
    if buffered < self.cfg.frame_size {
        return Ok(false);
    }
    if !self.cfg.record_aligned {
        self.flush_prefix(self.cfg.frame_size)?;
        return Ok(true);
    }
    // A frame can never exceed MAX_FRAME_SIZE, so a newline beyond that point
    // is useless as a boundary. Bounding the search keeps `flush_prefix` from
    // rejecting the cut when a single write pushes the buffer past the cap.
    let window_end = buffered.min(MAX_FRAME_SIZE);
    // `Writer::new` guarantees 1 <= frame_size <= MAX_FRAME_SIZE, so the
    // subtraction cannot underflow and `search_start < window_end` holds.
    let search_start = self.cfg.frame_size - 1;
    if let Some(rel) = memchr::memchr(b'\n', &self.pending[search_start..window_end]) {
        self.flush_prefix(search_start + rel + 1)?;
        return Ok(true);
    }
    if buffered < MAX_FRAME_SIZE {
        // The current record may still end within the cap; wait for more input.
        return Ok(false);
    }
    // No record boundary fits into a maximum-size frame: split mid-record.
    self.stats.oversized_frames += 1;
    if !self.warned_oversized {
        tracing::warn!(
            frame_size = self.cfg.frame_size,
            max_frame_size = MAX_FRAME_SIZE,
            "record exceeds max frame size; splitting mid-record"
        );
        self.warned_oversized = true;
    }
    self.flush_prefix(MAX_FRAME_SIZE)?;
    Ok(true)
}
pub fn finish(mut self) -> Result<WriterStats> {
if !self.pending.is_empty() {
let ends_with_nl = self.pending.last().copied() == Some(b'\n');
if self.cfg.record_aligned && !ends_with_nl {
self.pending.push(b'\n');
self.stats.records += 1;
}
let size = self.pending.len();
self.flush_prefix(size)?;
}
let table_bytes = self.table.write_to(&mut self.out).map_err(Error::Io)?;
self.stats.seek_table_bytes = table_bytes as u64;
self.out.flush().map_err(Error::Io)?;
Ok(self.stats)
}
fn flush_prefix(&mut self, up_to: usize) -> Result<()> {
if up_to == 0 {
return Ok(());
}
if up_to > MAX_FRAME_SIZE {
return Err(Error::Input(format!(
"internal: attempted to flush {up_to}-byte frame exceeding \
MAX_FRAME_SIZE {MAX_FRAME_SIZE}"
)));
}
let chunk = &self.pending[..up_to];
let checksum = if self.cfg.checksums {
let xxh = xxhash_rust_like(chunk);
Some((xxh & 0xFFFF_FFFF) as u32)
} else {
None
};
let compressed = zstd::bulk::compress(chunk, self.cfg.level).map_err(Error::Io)?;
let compressed_size = compressed.len();
if compressed_size > MAX_FRAME_SIZE {
return Err(Error::Input(format!(
"compressed frame size {compressed_size} exceeds u32 cap"
)));
}
self.out.write_all(&compressed).map_err(Error::Io)?;
self.table.push(FrameEntry {
compressed_size: compressed_size as u32,
decompressed_size: up_to as u32,
checksum,
});
self.stats.frames += 1;
self.stats.uncompressed_bytes += up_to as u64;
self.stats.compressed_bytes += compressed_size as u64;
self.pending.drain(..up_to);
Ok(())
}
}
/// Computes a 64-bit hash of `data`.
///
/// Despite the name, this is not xxHash: starting from the 64-bit FNV offset
/// basis, each byte is XORed into the state, which is then multiplied
/// (wrapping) by the 64-bit FNV prime — i.e. the FNV-1a construction.
fn xxhash_rust_like(data: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;
    data.iter().fold(OFFSET_BASIS, |state, &byte| {
        (state ^ u64::from(byte)).wrapping_mul(PRIME)
    })
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::{Cursor, Read};
/// Decompresses `bytes` with `zstd::stream::decode_all`, panicking if
/// decoding fails.
fn decompress_all(bytes: &[u8]) -> Vec<u8> {
    let decoded = zstd::stream::decode_all(bytes);
    decoded.unwrap()
}
/// Five newline-terminated records written one block at a time must decode
/// back to the original bytes and be counted as five records.
#[test]
fn simple_write_roundtrips_via_zstd_decoder() {
    const LINES: [&str; 5] = ["alpha\n", "bravo\n", "charlie\n", "delta\n", "echo\n"];

    let mut sink = Vec::new();
    let cfg = WriterConfig {
        level: 3,
        frame_size: 64,
        checksums: true,
        record_aligned: true,
    };
    let mut writer = Writer::new(&mut sink, cfg);
    for line in LINES {
        writer.write_block(line.as_bytes()).unwrap();
    }
    // `finish` consumes the writer, which releases the borrow on `sink`.
    let stats = writer.finish().unwrap();
    assert!(stats.frames >= 1);
    assert_eq!(stats.records, 5);

    let decoded = decompress_all(&sink);
    assert_eq!(decoded, b"alpha\nbravo\ncharlie\ndelta\necho\n");
}
/// With record alignment on, every frame listed in the seek table must
/// decompress to data ending in `\n` whose length matches the table entry.
#[test]
fn frames_are_record_aligned() {
    let mut sink = Vec::new();
    let cfg = WriterConfig {
        level: 3,
        frame_size: 32,
        checksums: false,
        record_aligned: true,
    };
    let mut writer = Writer::new(&mut sink, cfg);
    for i in 0..20 {
        let record = format!("record_{i:02}\n");
        writer.write_block(record.as_bytes()).unwrap();
    }
    let _stats = writer.finish().unwrap();

    let mut reader = Cursor::new(&sink);
    let table = SeekTable::read_from_tail(&mut reader).unwrap();
    // Rewind so the frames can be read back in order from the start.
    reader.set_position(0);
    for entry in &table.entries {
        let compressed_len = entry.compressed_size as usize;
        let mut frame = vec![0u8; compressed_len];
        reader.read_exact(&mut frame).unwrap();
        let decoded = decompress_all(&frame);
        assert!(
            decoded.last() == Some(&b'\n'),
            "frame not record-aligned: {decoded:?}"
        );
        assert_eq!(decoded.len(), entry.decompressed_size as usize);
    }
}
/// A record split across two `write_block` calls is counted once and
/// decodes as a single contiguous record.
#[test]
fn writer_block_mode_splits_records_correctly() {
    let mut sink = Vec::new();
    let mut writer = Writer::new(&mut sink, WriterConfig::default());
    for block in [&b"one\ntwo\npar"[..], &b"tial\nthree\n"[..]] {
        writer.write_block(block).unwrap();
    }
    let stats = writer.finish().unwrap();
    assert_eq!(stats.records, 4, "stats: {stats:?}");

    let decoded = decompress_all(&sink);
    assert_eq!(decoded, b"one\ntwo\npartial\nthree\n");
}
/// Finishing a writer that never received data must still emit a seek table
/// that parses and reports zero frames.
#[test]
fn finish_on_empty_writer_produces_parseable_file() {
    let mut sink = Vec::new();
    let stats = Writer::new(&mut sink, WriterConfig::default())
        .finish()
        .unwrap();
    assert_eq!(stats.frames, 0);

    let table = SeekTable::read_from_tail(Cursor::new(&sink)).unwrap();
    assert_eq!(table.num_frames(), 0);
}
}