use std::io;
pub use zstd_safe::{CParameter, DParameter, InBuffer, OutBuffer};
use crate::dict::{DecoderDictionary, EncoderDictionary};
use crate::map_error_code;
/// A streaming transformation that moves bytes from an input buffer into an
/// output buffer.
///
/// Implemented in this module by [`NoOp`] (verbatim copy), [`Decoder`]
/// (zstd decompression) and [`Encoder`] (zstd compression).
pub trait Operation {
    /// Processes data from `input`, writing results into `output`.
    ///
    /// Implementations report progress by advancing `input.pos` and
    /// `output.pos`. The meaning of the returned `usize` is
    /// implementation-defined: the zstd-backed implementations forward the
    /// value returned by zstd, while `NoOp` always returns 0.
    fn run(&mut self, input: &mut InBuffer<'_>, output: &mut OutBuffer<'_>) -> io::Result<usize>;

    /// Convenience wrapper around [`Operation::run`] that works on plain slices.
    ///
    /// Each slice is wrapped in a fresh buffer starting at position 0, so the
    /// final positions are exactly the byte counts read and written.
    fn run_on_buffers(&mut self, input: &[u8], output: &mut [u8]) -> io::Result<Status> {
        let mut src = InBuffer::around(input);
        let mut dst = OutBuffer::around(output);

        let remaining = self.run(&mut src, &mut dst)?;
        let bytes_read = src.pos;
        let bytes_written = dst.pos;

        Ok(Status {
            remaining,
            bytes_read,
            bytes_written,
        })
    }

    /// Flushes any internally buffered data into `output`.
    ///
    /// The default implementation writes nothing and returns `Ok(0)`.
    fn flush(&mut self, output: &mut OutBuffer<'_>) -> io::Result<usize> {
        // Named (not `_output`) so the parameter reads well in generated docs.
        let _ = output;
        Ok(0)
    }

    /// Resets the operation so it can be reused for a new stream.
    ///
    /// The default implementation does nothing.
    fn reinit(&mut self) -> io::Result<()> {
        Ok(())
    }

    /// Finishes the current stream, writing any trailing data into `output`.
    ///
    /// `finished_frame` tells the operation whether the caller believes the
    /// current frame is complete. The default implementation ignores both
    /// arguments and returns `Ok(0)`.
    fn finish(&mut self, output: &mut OutBuffer<'_>, finished_frame: bool) -> io::Result<usize> {
        let _ = (output, finished_frame);
        Ok(0)
    }
}
/// An [`Operation`] that copies input bytes to the output unchanged.
///
/// Stateless, so it is trivially `Copy` and `Default`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NoOp;
impl Operation for NoOp {
    /// Copies bytes verbatim from `input` to `output`.
    ///
    /// Copies as many bytes as fit: the smaller of the unread input and the
    /// free output space. Both positions advance by that amount, and the
    /// method always returns `Ok(0)`.
    fn run(&mut self, input: &mut InBuffer<'_>, output: &mut OutBuffer<'_>) -> io::Result<usize> {
        let pending = &input.src[input.pos..];
        let free = &mut output.dst[output.pos..];

        let count = pending.len().min(free.len());
        free[..count].copy_from_slice(&pending[..count]);

        input.pos += count;
        output.pos += count;
        Ok(0)
    }
}
/// Outcome of a single [`Operation::run_on_buffers`] call.
///
/// Derives the common value-type traits so callers can print, compare and
/// copy it freely.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Status {
    /// The value returned by the underlying [`Operation::run`] call.
    pub remaining: usize,
    /// Number of bytes consumed from the input slice.
    pub bytes_read: usize,
    /// Number of bytes written into the output slice.
    pub bytes_written: usize,
}
pub struct Decoder<'a> {
context: zstd_safe::DCtx<'a>,
}
impl Decoder<'static> {
pub fn new() -> io::Result<Self> {
Self::with_dictionary(&[])
}
pub fn with_dictionary(dictionary: &[u8]) -> io::Result<Self> {
let mut context = zstd_safe::DCtx::create();
context.init();
context
.load_dictionary(dictionary)
.map_err(map_error_code)?;
Ok(Decoder { context })
}
}
impl<'a> Decoder<'a> {
pub fn with_prepared_dictionary<'b>(
dictionary: &DecoderDictionary<'b>,
) -> io::Result<Self>
where
'b: 'a,
{
let mut context = zstd_safe::DCtx::create();
context
.ref_ddict(dictionary.as_ddict())
.map_err(map_error_code)?;
Ok(Decoder { context })
}
pub fn set_parameter(&mut self, parameter: DParameter) -> io::Result<()> {
self.context
.set_parameter(parameter)
.map_err(map_error_code)?;
Ok(())
}
}
impl Operation for Decoder<'_> {
    /// Decompresses from `input` into `output` via `DCtx::decompress_stream`,
    /// returning zstd's result value unchanged.
    fn run(&mut self, input: &mut InBuffer<'_>, output: &mut OutBuffer<'_>) -> io::Result<usize> {
        let outcome = self.context.decompress_stream(output, input);
        outcome.map_err(map_error_code)
    }

    /// Checks that the stream ended on a frame boundary.
    ///
    /// Never writes to `output`. Returns `Ok(0)` when `finished_frame` is
    /// true; otherwise returns an `UnexpectedEof` error.
    fn finish(&mut self, _output: &mut OutBuffer<'_>, finished_frame: bool) -> io::Result<usize> {
        if !finished_frame {
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "incomplete frame",
            ));
        }
        Ok(0)
    }

    /// Resets the decompression context via `DCtx::reset`.
    fn reinit(&mut self) -> io::Result<()> {
        self.context
            .reset()
            .map(|_| ())
            .map_err(map_error_code)
    }
}
/// Streaming zstd compressor built on a `zstd_safe::CCtx`.
pub struct Encoder<'a> {
    context: zstd_safe::CCtx<'a>,
}

impl Encoder<'static> {
    /// Creates an encoder at compression `level` by passing an empty
    /// dictionary to [`Encoder::with_dictionary`].
    pub fn new(level: i32) -> io::Result<Self> {
        Self::with_dictionary(level, &[])
    }

    /// Creates an encoder at compression `level` and loads `dictionary`.
    ///
    /// # Errors
    /// Returns the zstd error, mapped to `io::Error`, if setting the level or
    /// loading the dictionary fails.
    pub fn with_dictionary(level: i32, dictionary: &[u8]) -> io::Result<Self> {
        let mut cctx = zstd_safe::CCtx::create();

        let level_param = CParameter::CompressionLevel(level);
        cctx.set_parameter(level_param).map_err(map_error_code)?;
        cctx.load_dictionary(dictionary).map_err(map_error_code)?;

        Ok(Self { context: cctx })
    }
}
impl<'a> Encoder<'a> {
pub fn with_prepared_dictionary<'b>(
dictionary: &EncoderDictionary<'b>,
) -> io::Result<Self>
where
'b: 'a,
{
let mut context = zstd_safe::CCtx::create();
context
.ref_cdict(dictionary.as_cdict())
.map_err(map_error_code)?;
Ok(Encoder { context })
}
pub fn set_parameter(&mut self, parameter: CParameter) -> io::Result<()> {
self.context
.set_parameter(parameter)
.map_err(map_error_code)?;
Ok(())
}
}
impl Operation for Encoder<'_> {
    /// Compresses from `input` into `output` via `CCtx::compress_stream`,
    /// returning zstd's result value unchanged.
    fn run(&mut self, input: &mut InBuffer<'_>, output: &mut OutBuffer<'_>) -> io::Result<usize> {
        let outcome = self.context.compress_stream(output, input);
        outcome.map_err(map_error_code)
    }

    /// Flushes buffered compressed data into `output` via `CCtx::flush_stream`.
    fn flush(&mut self, output: &mut OutBuffer<'_>) -> io::Result<usize> {
        let outcome = self.context.flush_stream(output);
        outcome.map_err(map_error_code)
    }

    /// Resets the session (not the parameters) so a new frame can start.
    fn reinit(&mut self) -> io::Result<()> {
        let directive = zstd_safe::ResetDirective::ZSTD_reset_session_only;
        self.context
            .reset(directive)
            .map(|_| ())
            .map_err(map_error_code)
    }

    /// Ends the current frame via `CCtx::end_stream`, ignoring `_finished_frame`.
    ///
    /// Returns `end_stream`'s value. Per the zstd manual, a non-zero value
    /// means more output remains to be flushed.
    fn finish(&mut self, output: &mut OutBuffer<'_>, _finished_frame: bool) -> io::Result<usize> {
        let outcome = self.context.end_stream(output);
        outcome.map_err(map_error_code)
    }
}
#[cfg(test)]
mod tests {
    use super::{Decoder, Encoder, InBuffer, Operation, OutBuffer};

    /// Round-trips a short payload through `Encoder` and then `Decoder`.
    #[test]
    fn test_cycle() {
        let mut encoder = Encoder::new(1).unwrap();
        let mut decoder = Decoder::new().unwrap();

        // Step 1: compress the whole payload, then close the frame.
        let mut input = InBuffer::around(b"AbcdefAbcdefabcdef");
        let mut output = [0u8; 128];
        let mut output = OutBuffer::around(&mut output);
        loop {
            encoder.run(&mut input, &mut output).unwrap();
            if input.pos == input.src.len() {
                break;
            }
        }
        // A non-zero result means the frame epilogue did not fit and the
        // compressed data would be truncated. Fail loudly instead of feeding a
        // partial frame to the decoder.
        assert_eq!(
            encoder.finish(&mut output, true).unwrap(),
            0,
            "frame was not fully flushed into the output buffer"
        );
        let initial_data = input.src;

        // Step 2: decompress and compare with the original payload.
        let mut input = InBuffer::around(output.as_slice());
        let mut output = [0u8; 128];
        let mut output = OutBuffer::around(&mut output);
        let last_hint = loop {
            let hint = decoder.run(&mut input, &mut output).unwrap();
            if input.pos == input.src.len() {
                break hint;
            }
        };
        // zstd reports 0 once a frame is fully decoded and flushed.
        assert_eq!(last_hint, 0, "decoder did not reach the end of the frame");
        assert_eq!(initial_data, output.as_slice());
    }
}