#[cfg(feature = "std")]
use alloc::vec::Vec;
use zstd_safe::{
CCtx, CParameter, CompressionLevel, InBuffer, OutBuffer, ResetDirective,
zstd_sys::ZSTD_EndDirective,
};
#[cfg(feature = "std")]
use crate::seek_table::Format;
use crate::{SEEKABLE_MAX_FRAME_SIZE, SeekTable, error::Result};
/// Upper bound on the number of decompressed bytes a single frame may hold, expressed as a
/// `u32` to match the per-frame counters kept by [`RawEncoder`].
// NOTE(review): the `as` cast assumes `SEEKABLE_MAX_FRAME_SIZE` fits in a `u32` — confirm
// against its definition.
const MAX_FRAME_SIZE: u32 = SEEKABLE_MAX_FRAME_SIZE as u32;
/// Rule deciding when the encoder closes the current frame and starts a new one.
///
/// Regardless of the policy, a frame never holds more than the crate-wide maximum number
/// of decompressed bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FrameSizePolicy {
    /// Close the frame once at least this many compressed bytes have been emitted for it.
    ///
    /// The size is checked between compression calls, and the compressor may hold data
    /// back internally, so a frame can end up larger than this target.
    Compressed(u32),
    /// Close the frame once it holds this many decompressed (input) bytes.
    Uncompressed(u32),
}

impl Default for FrameSizePolicy {
    /// Frames of 2 MiB (`0x200_000` bytes) of uncompressed input.
    fn default() -> Self {
        Self::Uncompressed(0x200_000)
    }
}
/// Outcome of a single compression call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CompressionProgress {
    /// Number of input bytes consumed.
    in_progress: usize,
    /// Number of compressed bytes written to the output buffer.
    out_progress: usize,
}

impl CompressionProgress {
    fn new(in_progress: usize, out_progress: usize) -> Self {
        Self {
            in_progress,
            out_progress,
        }
    }

    /// Number of input bytes consumed by the call.
    pub fn in_progress(&self) -> usize {
        self.in_progress
    }

    /// Number of bytes written to the output buffer by the call.
    pub fn out_progress(&self) -> usize {
        self.out_progress
    }
}
/// Outcome of a single attempt to end a frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EpilogueProgress {
    /// Number of bytes written to the output buffer.
    out_progress: usize,
    /// Bytes the compressor still reports as pending; zero once the frame is fully written.
    data_left: usize,
}

impl EpilogueProgress {
    fn new(out_progress: usize, data_left: usize) -> Self {
        Self {
            out_progress,
            data_left,
        }
    }

    /// Number of bytes written to the output buffer by the call.
    pub fn out_progress(&self) -> usize {
        self.out_progress
    }

    /// Bytes of the frame that still have to be written.
    ///
    /// Zero means the frame has been completely written. A non-zero value means the call
    /// must be repeated with more output space.
    pub fn data_left(&self) -> usize {
        self.data_left
    }
}
/// Builder-style configuration for [`RawEncoder`] and `Encoder`.
pub struct EncodeOptions<'a> {
    /// Compression context the encoder will own.
    cctx: CCtx<'a>,
    /// zstd compression level applied to the context.
    compression_level: CompressionLevel,
    /// Whether frames are written with a content checksum.
    checksum_flag: bool,
    /// When to close a frame and start the next one.
    frame_policy: FrameSizePolicy,
}

impl Default for EncodeOptions<'_> {
    /// Same as [`EncodeOptions::new`].
    fn default() -> Self {
        EncodeOptions::new()
    }
}
impl<'a> EncodeOptions<'a> {
    /// Creates default options around a newly created compression context.
    ///
    /// # Panics
    ///
    /// Panics if the context cannot be created; use [`EncodeOptions::try_new`] to handle
    /// that case instead.
    pub fn new() -> Self {
        let cctx = CCtx::create();
        Self::with_cctx(cctx)
    }

    /// Like [`EncodeOptions::new`], but returns `None` if the compression context could
    /// not be created.
    pub fn try_new() -> Option<Self> {
        CCtx::try_create().map(Self::with_cctx)
    }

    /// Creates default options around an existing compression context.
    ///
    /// Defaults: [`FrameSizePolicy::default`], checksums disabled, and
    /// `CompressionLevel::default()`.
    pub fn with_cctx(cctx: CCtx<'a>) -> Self {
        Self {
            compression_level: CompressionLevel::default(),
            checksum_flag: false,
            frame_policy: FrameSizePolicy::default(),
            cctx,
        }
    }

    /// Replaces the compression context.
    pub fn cctx(self, cctx: CCtx<'a>) -> Self {
        Self { cctx, ..self }
    }

    /// Sets the rule deciding when frames are closed.
    pub fn frame_size_policy(self, policy: FrameSizePolicy) -> Self {
        Self {
            frame_policy: policy,
            ..self
        }
    }

    /// Enables or disables per-frame content checksums.
    pub fn checksum_flag(self, flag: bool) -> Self {
        Self {
            checksum_flag: flag,
            ..self
        }
    }

    /// Sets the zstd compression level.
    pub fn compression_level(self, level: CompressionLevel) -> Self {
        Self {
            compression_level: level,
            ..self
        }
    }

    /// Consumes the options and builds a [`RawEncoder`].
    ///
    /// # Errors
    ///
    /// Fails if the compression context rejects the configured parameters.
    pub fn into_raw_encoder(self) -> Result<RawEncoder<'a>> {
        RawEncoder::with_opts(self)
    }

    /// Consumes the options and builds an `Encoder` writing to `writer`.
    ///
    /// # Errors
    ///
    /// Fails if the compression context rejects the configured parameters.
    #[cfg(feature = "std")]
    #[cfg_attr(docsrs, doc(cfg(feature = "std")))]
    pub fn into_encoder<W>(self, writer: W) -> Result<Encoder<'a, W>> {
        Encoder::with_opts(writer, self)
    }
}
/// Low-level seekable encoder that compresses into caller-provided buffers.
///
/// Input is split into zstd frames according to a [`FrameSizePolicy`], and each completed
/// frame's compressed and decompressed size is recorded in a [`SeekTable`].
pub struct RawEncoder<'a> {
    /// Compression context used for every frame.
    cctx: CCtx<'a>,
    /// Sizes of all frames completed so far.
    seek_table: SeekTable,
    /// Rule deciding when the current frame is complete.
    frame_policy: FrameSizePolicy,
    /// Input (decompressed) bytes consumed into the current frame.
    frame_d_size: u32,
    /// Compressed bytes emitted for the current frame so far.
    frame_c_size: u32,
}
impl<'a> RawEncoder<'a> {
pub fn with_opts(mut opts: EncodeOptions<'a>) -> Result<Self> {
opts.cctx
.set_parameter(CParameter::CompressionLevel(opts.compression_level))?;
opts.cctx
.set_parameter(CParameter::ChecksumFlag(opts.checksum_flag))?;
Ok(Self {
cctx: opts.cctx,
frame_policy: opts.frame_policy,
frame_c_size: 0,
frame_d_size: 0,
seek_table: SeekTable::new(),
})
}
pub fn compress_with_prefix<'b: 'a>(
&mut self,
input: &[u8],
output: &mut [u8],
prefix: Option<&'b [u8]>,
) -> Result<CompressionProgress> {
if self.is_frame_complete() {
let mut out_progress = 0;
while out_progress < output.len() {
let progress = self.end_frame(&mut output[out_progress..])?;
out_progress += progress.out_progress;
if progress.data_left == 0 {
break;
}
}
Ok(CompressionProgress::new(0, out_progress))
} else {
let limit = input.len().min(self.remaining_frame_size());
let mut in_buf = InBuffer::around(&input[..limit]);
let mut out_buf = OutBuffer::around(output);
if let Some(pref) = prefix {
if self.frame_d_size == 0 {
self.cctx.ref_prefix(pref)?;
}
}
while in_buf.pos() < limit && out_buf.pos() < out_buf.capacity() {
self.cctx.compress_stream2(
&mut out_buf,
&mut in_buf,
ZSTD_EndDirective::ZSTD_e_continue,
)?;
}
self.frame_c_size += out_buf.pos() as u32;
self.frame_d_size += in_buf.pos() as u32;
Ok(CompressionProgress::new(in_buf.pos(), out_buf.pos()))
}
}
}
impl RawEncoder<'_> {
    /// Creates a raw encoder with default [`EncodeOptions`].
    ///
    /// # Errors
    ///
    /// Fails if the default parameters cannot be applied to the compression context.
    pub fn new() -> Result<Self> {
        Self::with_opts(EncodeOptions::new())
    }

    /// Compresses `input` into `output` without a prefix.
    ///
    /// Same as `compress_with_prefix` with `prefix` set to `None`.
    ///
    /// # Errors
    ///
    /// Propagates errors from the compression context or from ending a frame.
    pub fn compress(&mut self, input: &[u8], output: &mut [u8]) -> Result<CompressionProgress> {
        self.compress_with_prefix(input, output, None)
    }

    /// Ends the current frame, writing its remaining compressed data into `output`.
    ///
    /// If `output` fills up before the frame is fully written, the returned
    /// [`EpilogueProgress::data_left`] is non-zero. Call this again with fresh output space
    /// until it reports zero. At that point the frame has been logged in the seek table,
    /// and the per-frame counters and compression session have been reset.
    ///
    /// # Errors
    ///
    /// Fails if the compression context reports an error or the seek table rejects the
    /// frame.
    pub fn end_frame(&mut self, output: &mut [u8]) -> Result<EpilogueProgress> {
        let mut empty_buf = InBuffer::around(&[]);
        let mut out_buf = OutBuffer::around(output);
        loop {
            let prev_out_pos = out_buf.pos();
            // `n` is the compressor's count of bytes still waiting to be flushed.
            let n = self.cctx.compress_stream2(
                &mut out_buf,
                &mut empty_buf,
                ZSTD_EndDirective::ZSTD_e_end,
            )?;
            self.frame_c_size += (out_buf.pos() - prev_out_pos) as u32;
            if n == 0 {
                break;
            }
            if out_buf.pos() == out_buf.capacity() {
                return Ok(EpilogueProgress::new(out_buf.pos(), n));
            }
        }
        self.seek_table
            .log_frame(self.frame_c_size, self.frame_d_size)?;
        self.reset_frame();
        Ok(EpilogueProgress::new(out_buf.pos(), 0))
    }

    /// Seek table of all frames completed so far.
    pub fn seek_table(&self) -> &SeekTable {
        &self.seek_table
    }

    /// Consumes the encoder, returning its seek table.
    pub fn into_seek_table(self) -> SeekTable {
        self.seek_table
    }

    /// Discards the current frame.
    ///
    /// Zeroes the per-frame counters and performs a session-only reset of the compression
    /// context. Compression parameters are kept and the seek table is left untouched.
    ///
    /// # Panics
    ///
    /// Panics if the session-only reset fails, which zstd does not do.
    pub fn reset_frame(&mut self) {
        self.frame_c_size = 0;
        self.frame_d_size = 0;
        self.cctx
            .reset(ResetDirective::SessionOnly)
            .expect("Resetting session never fails");
    }

    /// Clears the seek table, forgetting every frame logged so far.
    pub fn reset_seek_table(&mut self) {
        self.seek_table = SeekTable::new();
    }

    /// Input bytes the current frame can still accept.
    ///
    /// Only called when the frame is not complete, so the subtraction cannot underflow.
    fn remaining_frame_size(&self) -> usize {
        let n = match self.frame_policy {
            FrameSizePolicy::Compressed(_) => MAX_FRAME_SIZE - self.frame_d_size,
            FrameSizePolicy::Uncompressed(limit) => MAX_FRAME_SIZE.min(limit) - self.frame_d_size,
        };
        n.try_into().expect("Remaining frame size fits in usize")
    }

    /// Whether the current frame has reached its size budget and must be ended.
    fn is_frame_complete(&self) -> bool {
        // A frame that has neither consumed input nor emitted output has not started, so
        // there is nothing to end. Without this guard, a zero-sized policy
        // (`Compressed(0)` / `Uncompressed(0)`) reports every fresh frame as complete.
        // `compress` would then emit empty frames forever without consuming input, and
        // `Encoder` would loop endlessly. With the guard, `Uncompressed(0)` simply makes no
        // progress.
        if self.frame_d_size == 0 && self.frame_c_size == 0 {
            return false;
        }
        match self.frame_policy {
            FrameSizePolicy::Compressed(size) => {
                size <= self.frame_c_size || MAX_FRAME_SIZE <= self.frame_d_size
            }
            FrameSizePolicy::Uncompressed(limit) => MAX_FRAME_SIZE.min(limit) <= self.frame_d_size,
        }
    }
}
/// Streaming seekable encoder that writes its output to `writer`.
///
/// Compressed bytes are staged in an internal buffer. They are written to `writer` when the
/// buffer fills, on `flush`, and on `finish`. The seek table is only appended by
/// `finish` / `finish_format`.
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
pub struct Encoder<'a, W> {
    /// Underlying buffer-to-buffer encoder.
    raw: RawEncoder<'a>,
    /// Staging buffer for compressed bytes not yet handed to `writer`.
    out_buf: Vec<u8>,
    /// Number of valid bytes currently staged in `out_buf`.
    out_buf_pos: usize,
    /// Total bytes handed to `writer` so far.
    written_compressed: u64,
    /// Destination of the compressed stream.
    writer: W,
}
#[cfg(feature = "std")]
impl<'a, W> Encoder<'a, W> {
    /// Creates an encoder writing to `writer` with default [`EncodeOptions`].
    ///
    /// # Errors
    ///
    /// Fails if the default parameters cannot be applied to the compression context.
    pub fn new(writer: W) -> Result<Self> {
        Self::with_opts(writer, EncodeOptions::new())
    }

    /// Creates an encoder writing to `writer`, configured by `opts`.
    ///
    /// The internal staging buffer is sized with `CCtx::out_size()`.
    ///
    /// # Errors
    ///
    /// Fails if the compression context rejects the configured parameters.
    pub fn with_opts(writer: W, opts: EncodeOptions<'a>) -> Result<Self> {
        let raw = opts.into_raw_encoder()?;
        let out_buf = alloc::vec![0; CCtx::out_size()];
        Ok(Self {
            raw,
            out_buf,
            out_buf_pos: 0,
            writer,
            written_compressed: 0,
        })
    }
}
#[cfg(feature = "std")]
impl<W> Encoder<'_, W> {
    /// Seek table of all frames completed so far.
    pub fn seek_table(&self) -> &SeekTable {
        self.raw.seek_table()
    }

    /// Total number of compressed bytes handed to the writer so far.
    ///
    /// Bytes still staged in the internal buffer are not included.
    pub fn written_compressed(&self) -> u64 {
        self.written_compressed
    }

    /// Consumes the encoder, returning its seek table.
    ///
    /// Staged output is not written to the writer.
    pub fn into_seek_table(self) -> SeekTable {
        let raw = self.raw;
        raw.into_seek_table()
    }
}
#[cfg(feature = "std")]
impl<'a, W: std::io::Write> Encoder<'a, W> {
    /// Compresses `buf`, staging output and writing it out whenever the staging buffer fills.
    ///
    /// `prefix` is passed to every inner compression call. It is therefore referenced for
    /// each frame that starts during this call.
    ///
    /// Returns the number of bytes of `buf` consumed. This can be less than `buf.len()` if a
    /// compression step makes no progress at all.
    ///
    /// # Errors
    ///
    /// Propagates compression errors and I/O errors from the writer.
    pub fn compress_with_prefix<'b: 'a>(
        &mut self,
        buf: &[u8],
        prefix: Option<&'b [u8]>,
    ) -> Result<usize> {
        let mut consumed = 0;
        loop {
            if consumed >= buf.len() {
                break;
            }
            let step = self.raw.compress_with_prefix(
                &buf[consumed..],
                &mut self.out_buf[self.out_buf_pos..],
                prefix,
            )?;
            let (read, wrote) = (step.in_progress(), step.out_progress());
            // Neither side moved: stop rather than spin.
            if read == 0 && wrote == 0 {
                break;
            }
            self.out_buf_pos += wrote;
            self.flush_out_buf(false)?;
            consumed += read;
        }
        Ok(consumed)
    }
}
#[cfg(feature = "std")]
impl<W: std::io::Write> Encoder<'_, W> {
    /// Compresses `buf` without a prefix.
    ///
    /// Returns the number of bytes of `buf` consumed.
    ///
    /// # Errors
    ///
    /// Propagates compression errors and I/O errors from the writer.
    pub fn compress(&mut self, buf: &[u8]) -> Result<usize> {
        self.compress_with_prefix(buf, None)
    }

    /// Ends the current frame and returns the number of bytes its epilogue produced.
    ///
    /// The epilogue is staged in the internal buffer. It only reaches the writer once that
    /// buffer fills, or on `flush` / `finish`. Calling this when the current frame holds no
    /// data writes an empty frame.
    ///
    /// # Errors
    ///
    /// Propagates compression, seek-table, and writer errors.
    pub fn end_frame(&mut self) -> Result<usize> {
        let mut progress = 0;
        loop {
            let prog = self.raw.end_frame(&mut self.out_buf[self.out_buf_pos..])?;
            self.out_buf_pos += prog.out_progress;
            self.flush_out_buf(false)?;
            progress += prog.out_progress;
            if prog.data_left == 0 {
                return Ok(progress);
            }
        }
    }

    /// Finishes the stream with the seek table in [`Format::Foot`] layout.
    ///
    /// Returns the total number of compressed bytes written to the writer.
    ///
    /// # Errors
    ///
    /// Propagates compression, seek-table, and writer errors.
    pub fn finish(self) -> Result<u64> {
        self.finish_format(Format::Foot)
    }

    /// Finishes the stream with the seek table in the given `format`.
    ///
    /// Ends the current frame if it holds data, or if no frame has been written yet, so
    /// the output always has at least one frame. Then writes the serialized seek table,
    /// writes out all staged bytes, and flushes the writer. Returns the total number of
    /// compressed bytes written to the writer.
    ///
    /// # Errors
    ///
    /// Propagates compression, seek-table, and writer errors.
    pub fn finish_format(mut self, format: Format) -> Result<u64> {
        // Ending an untouched frame would append an empty frame (and seek-table entry),
        // e.g. when the caller already called `end_frame` right before `finish`.
        let frame_pending = self.raw.frame_d_size > 0 || self.raw.frame_c_size > 0;
        if frame_pending || self.raw.seek_table.num_frames() == 0 {
            self.end_frame()?;
        }
        // `self.raw` is moved out below, so the staging-buffer writes are inlined rather
        // than going through `flush_out_buf`.
        let mut ser = self.raw.into_seek_table().into_format_serializer(format);
        loop {
            let n = ser.write_into(&mut self.out_buf[self.out_buf_pos..]);
            if n == 0 {
                self.writer.write_all(&self.out_buf[..self.out_buf_pos])?;
                self.written_compressed += self.out_buf_pos as u64;
                self.writer.flush()?;
                return Ok(self.written_compressed);
            }
            self.out_buf_pos += n;
            if self.out_buf_pos == self.out_buf.len() {
                self.writer.write_all(&self.out_buf[..self.out_buf_pos])?;
                self.written_compressed += self.out_buf_pos as u64;
                self.out_buf_pos = 0;
            }
        }
    }

    /// Writes the staged bytes to the writer when the staging buffer is full, or
    /// unconditionally when `force` is set.
    #[inline]
    fn flush_out_buf(&mut self, force: bool) -> Result<()> {
        if self.out_buf_pos == self.out_buf.len() || force {
            self.writer.write_all(&self.out_buf[..self.out_buf_pos])?;
            self.written_compressed += self.out_buf_pos as u64;
            self.out_buf_pos = 0;
        }
        Ok(())
    }
}
#[cfg(feature = "std")]
impl<W: std::io::Write> std::io::Write for Encoder<'_, W> {
    /// Compresses `buf`, returning how many of its bytes were consumed.
    ///
    /// Encoder errors are wrapped as `std::io::Error::other`.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        match self.compress(buf) {
            Ok(consumed) => Ok(consumed),
            Err(err) => Err(std::io::Error::other(err)),
        }
    }

    /// Writes all staged bytes to the writer and flushes it.
    ///
    /// Only the encoder's own staging buffer is written out. The compression context is
    /// not asked to flush, so data it still holds internally is not included.
    fn flush(&mut self) -> std::io::Result<()> {
        if let Err(err) = self.flush_out_buf(true) {
            return Err(std::io::Error::other(err));
        }
        self.writer.flush()
    }
}
#[cfg(test)]
mod tests {
    use alloc::vec;

    use crate::tests::INPUT;

    use super::*;

    /// After resetting the frame and the seek table, re-encoding the same data must
    /// reproduce an identical seek table.
    #[test]
    fn raw_encoder_reset() {
        let mut encoder = RawEncoder::new().unwrap();
        let mut buf = vec![0; 1024];
        encoder.compress(b"Hello", &mut buf).unwrap();
        encoder.end_frame(&mut buf).unwrap();
        assert_eq!(encoder.seek_table().num_frames(), 1);
        let first_st = encoder.seek_table().clone();

        // Start a frame, then throw it away together with the seek table.
        encoder.compress(b"Bye", &mut [0; 128]).unwrap();
        encoder.reset_frame();
        encoder.reset_seek_table();
        assert_eq!(encoder.seek_table().num_frames(), 0);

        encoder.compress(b"Hello", &mut buf).unwrap();
        encoder.end_frame(&mut buf).unwrap();
        assert_eq!(encoder.seek_table().num_frames(), 1);
        // `assert_eq!` rather than `debug_assert_eq!`, so the check also runs in
        // release-mode test builds.
        assert_eq!(&first_st, encoder.seek_table());
    }

    /// With the checksum flag set, every frame header must advertise a content checksum.
    #[test]
    fn checksum() {
        let mut seekable = vec![];
        let mut encoder = EncodeOptions::new()
            .checksum_flag(true)
            .frame_size_policy(FrameSizePolicy::Uncompressed(INPUT.len() as u32 / 3))
            .into_raw_encoder()
            .unwrap();
        let mut buf = vec![0; INPUT.len()];
        let mut in_progress = 0;
        while in_progress < INPUT.len() {
            let progress = encoder
                .compress(&INPUT.as_bytes()[in_progress..], &mut buf)
                .unwrap();
            seekable.extend(&buf[..progress.out_progress]);
            in_progress += progress.in_progress;
        }
        loop {
            let prog = encoder.end_frame(&mut buf).unwrap();
            seekable.extend(&buf[..prog.out_progress]);
            if prog.data_left == 0 {
                break;
            }
        }

        let num_frames = encoder.seek_table().num_frames();
        let st = encoder.into_seek_table();
        for i in 0..num_frames {
            let start_pos = st.frame_start_comp(i).unwrap();
            // Frame header descriptor follows the 4-byte magic number; bit 2 is
            // Content_Checksum_flag.
            let descriptor: u8 = seekable[start_pos as usize + 4];
            assert!(descriptor & 0x4 > 0);
        }
    }
}