use alloc::vec;
use alloc::vec::Vec;
use zstd_safe::{DCtx, InBuffer, OutBuffer, ResetDirective};
use crate::{
Error,
error::Result,
seek_table::SeekTable,
seekable::{OffsetFrom, Seekable},
};
/// Builder-style collection of settings used to construct a [`Decoder`].
///
/// Every setting except the source and the decompression context is optional;
/// unset values are resolved when the decoder is created in
/// [`Decoder::with_opts`].
pub struct DecodeOptions<'a, S> {
/// Decompression context handed over to the decoder.
dctx: DCtx<'a>,
/// Source the compressed data is read from.
src: S,
/// Seek table to use. When `None`, the decoder reads it from `src`.
seek_table: Option<SeekTable>,
/// Index of the frame at which decompression starts. Takes precedence over
/// `offset` when both are set.
lower_frame: Option<u32>,
/// Decompressed offset at which decompression starts (0 when unset).
offset: Option<u64>,
/// Index of the frame at whose end decompression stops. Takes precedence
/// over `offset_limit` when both are set.
upper_frame: Option<u32>,
/// Decompressed offset at which decompression stops (the total decompressed
/// size of the seek table when unset).
offset_limit: Option<u64>,
}
impl<'a, S> DecodeOptions<'a, S> {
    /// Creates options that read from `src`, using a decompression context
    /// obtained from `DCtx::create` and leaving every other setting unset.
    pub fn new(src: S) -> Self {
        let dctx = DCtx::create();
        Self::with_dctx(src, dctx)
    }

    /// Like [`Self::new`], but obtains the context from `DCtx::try_create`
    /// and returns `None` when that call yields no context.
    pub fn try_new(src: S) -> Option<Self> {
        DCtx::try_create().map(|dctx| Self::with_dctx(src, dctx))
    }

    /// Creates options that read from `src` and decompress with the given
    /// `dctx`. All optional settings start out unset.
    pub fn with_dctx(src: S, dctx: DCtx<'a>) -> Self {
        Self {
            src,
            dctx,
            seek_table: None,
            offset: None,
            offset_limit: None,
            lower_frame: None,
            upper_frame: None,
        }
    }

    /// Replaces the decompression context.
    pub fn dctx(self, dctx: DCtx<'a>) -> Self {
        Self { dctx, ..self }
    }

    /// Supplies the seek table, so it is not read from the source.
    pub fn seek_table(self, seek_table: SeekTable) -> Self {
        Self {
            seek_table: Some(seek_table),
            ..self
        }
    }

    /// Sets the index of the frame at which decompression starts.
    pub fn lower_frame(self, index: u32) -> Self {
        Self {
            lower_frame: Some(index),
            ..self
        }
    }

    /// Sets the index of the frame at whose end decompression stops.
    pub fn upper_frame(self, index: u32) -> Self {
        Self {
            upper_frame: Some(index),
            ..self
        }
    }

    /// Sets the decompressed offset at which decompression starts.
    pub fn offset(self, offset: u64) -> Self {
        Self {
            offset: Some(offset),
            ..self
        }
    }

    /// Sets the decompressed offset at which decompression stops.
    pub fn offset_limit(self, limit: u64) -> Self {
        Self {
            offset_limit: Some(limit),
            ..self
        }
    }
}
impl<'a, S: Seekable> DecodeOptions<'a, S> {
/// Consumes the options and builds a [`Decoder`] from them.
///
/// This simply forwards to [`Decoder::with_opts`] and returns whatever
/// error that constructor reports.
pub fn into_decoder(self) -> Result<Decoder<'a, S>> {
Decoder::with_opts(self)
}
}
/// Decompresses a byte range of seekable compressed data read from a source `S`.
///
/// The range is described by `offset` (next decompressed byte to return) and
/// `offset_limit` (position at which decompression stops).
pub struct Decoder<'a, S> {
/// Decompression context performing the actual decompression.
dctx: DCtx<'a>,
/// Seek table used to map frames and decompressed offsets to positions in `src`.
seek_table: SeekTable,
/// Source of the compressed data.
src: S,
/// Decompressed position the decompression context has reached so far.
decomp_pos: u64,
/// Decompressed position of the next byte handed to the caller.
offset: u64,
/// Decompressed position at which decompression stops.
offset_limit: u64,
/// Buffer holding compressed bytes read from `src`.
in_buf: Vec<u8>,
/// Index of the first byte in `in_buf` not yet consumed by the context.
in_buf_pos: usize,
/// Number of valid bytes in `in_buf`, as returned by the last read.
in_buf_limit: usize,
/// Scratch buffer receiving decompressed bytes that lie before `offset`
/// and are therefore discarded.
out_buf: Vec<u8>,
/// Compressed bytes consumed since the last reset; 0 makes the next
/// decompression call reposition `src` to the start of the target frame.
read_compressed: u64,
}
impl<'a, S: Seekable> Decoder<'a, S> {
    /// Creates a decoder for `src` with default [`DecodeOptions`].
    ///
    /// # Errors
    ///
    /// Returns whatever error [`Self::with_opts`] reports.
    pub fn new(src: S) -> Result<Self> {
        Self::with_opts(DecodeOptions::new(src))
    }

    /// Creates a decoder from the given options.
    ///
    /// The seek table is taken from the options or, when absent, read from the
    /// source. A configured lower/upper frame takes precedence over a
    /// configured offset/offset limit.
    ///
    /// # Errors
    ///
    /// Fails if the seek table cannot be read from the source, if a frame
    /// index is rejected by the seek table, or if the resulting offset or
    /// offset limit exceeds the decompressed size recorded in the seek table.
    pub fn with_opts(mut opts: DecodeOptions<'a, S>) -> Result<Self> {
        let seek_table = match opts.seek_table {
            Some(table) => table,
            None => SeekTable::from_seekable(&mut opts.src)?,
        };

        let offset = if let Some(index) = opts.lower_frame {
            seek_table.frame_start_decomp(index)?
        } else {
            opts.offset.unwrap_or(0)
        };
        Self::check_offset(offset, &seek_table)?;

        let offset_limit = if let Some(index) = opts.upper_frame {
            seek_table.frame_end_decomp(index)?
        } else {
            opts.offset_limit
                .unwrap_or_else(|| seek_table.size_decomp())
        };
        Self::check_offset(offset_limit, &seek_table)?;

        Ok(Self {
            dctx: opts.dctx,
            seek_table,
            src: opts.src,
            decomp_pos: 0,
            offset,
            offset_limit,
            in_buf: vec![0; DCtx::in_size()],
            in_buf_pos: 0,
            in_buf_limit: 0,
            out_buf: vec![0; DCtx::out_size()],
            read_compressed: 0,
        })
    }

    /// Decompresses data into `buf`, starting at the current offset, and
    /// returns the number of bytes written.
    ///
    /// Decompression stops when `buf` is full or the offset limit is reached;
    /// the offset advances by the number of bytes returned. If `prefix` is
    /// given it is referenced by the decompression context before the first
    /// frame and again after every completed frame.
    ///
    /// If the source runs out of data before the offset limit is reached
    /// (e.g. truncated input), the bytes decompressed so far are returned
    /// instead of waiting for more input; this can be 0.
    ///
    /// # Errors
    ///
    /// Fails if the source cannot be positioned or read, if the seek table
    /// rejects the frame lookup, or if the decompression context reports an
    /// error.
    // The only panic is the `expect` on the session reset below, which carries
    // the message that resetting a session never fails.
    #[allow(clippy::missing_panics_doc)]
    pub fn decompress_with_prefix<'b: 'a>(
        &mut self,
        buf: &mut [u8],
        prefix: Option<&'b [u8]>,
    ) -> Result<usize> {
        if self.read_compressed == 0 {
            // Fresh (or reset) decoder: position the source at the start of
            // the frame that contains the current offset.
            let frame_idx = self.seek_table.frame_index_decomp(self.offset);
            let start_pos = self.seek_table.frame_start_comp(frame_idx)?;
            self.src.set_offset(OffsetFrom::Start(start_pos))?;
            self.decomp_pos = self.seek_table.frame_start_decomp(frame_idx)?;
            if let Some(pref) = prefix {
                self.dctx.ref_prefix(pref)?;
            }
            // Drop stale buffered input so the loop below reads from `src`.
            self.in_buf_pos = 0;
            self.in_buf_limit = 0;
        }

        let mut output_progress = 0;
        while self.offset < self.offset_limit && output_progress < buf.len() {
            // Refill the input buffer once it is fully consumed and remember
            // whether the source had nothing left to give.
            let mut src_exhausted = false;
            if self.in_buf_pos == self.in_buf_limit {
                self.in_buf_limit = self.src.read(&mut self.in_buf)?;
                self.in_buf_pos = 0;
                src_exhausted = self.in_buf_limit == 0;
            }

            let mut in_buffer = InBuffer::around(&self.in_buf[self.in_buf_pos..self.in_buf_limit]);
            let mut out_buffer = if self.decomp_pos < self.offset {
                // Bytes before the requested offset are decompressed into the
                // scratch buffer and discarded. The cast cannot truncate: the
                // value is capped at `out_buf.len()`.
                let limit = (self.offset - self.decomp_pos).min(self.out_buf.len() as u64) as usize;
                OutBuffer::around(&mut self.out_buf[..limit])
            } else {
                let remaining: usize = (self.offset_limit - self.decomp_pos)
                    .try_into()
                    .unwrap_or(usize::MAX);
                // Saturate: `remaining` may be `usize::MAX`, and a plain
                // addition would overflow once some output was produced.
                let limit = buf.len().min(output_progress.saturating_add(remaining));
                OutBuffer::around(&mut buf[output_progress..limit])
            };

            let in_len = self.in_buf_limit - self.in_buf_pos;
            // With an exhausted source, call the context once with empty input
            // so output it still buffers internally can be flushed.
            let mut flush_pending = src_exhausted;
            while (flush_pending || in_buffer.pos() < in_len)
                && out_buffer.pos() < out_buffer.capacity()
            {
                flush_pending = false;
                let n = self
                    .dctx
                    .decompress_stream(&mut out_buffer, &mut in_buffer)?;
                if n == 0 {
                    // A frame is complete; the prefix has to be referenced
                    // again for the next one.
                    if let Some(pref) = prefix {
                        self.dctx
                            .reset(ResetDirective::SessionOnly)
                            .expect("Resetting session never fails");
                        self.dctx.ref_prefix(pref)?;
                    }
                }
            }

            let consumed = in_buffer.pos();
            let produced = out_buffer.pos();
            self.decomp_pos += produced as u64;
            self.in_buf_pos += consumed;
            self.read_compressed += consumed as u64;
            if self.decomp_pos > self.offset {
                // Only reachable when the output went into `buf`.
                self.offset += produced as u64;
                output_progress += produced;
            }

            // No input left and nothing flushed: no further progress is
            // possible, so stop instead of spinning forever.
            if src_exhausted && produced == 0 {
                break;
            }
        }

        Ok(output_progress)
    }
}
impl<S: Seekable> Decoder<'_, S> {
    /// Decompresses into `buf` without a prefix and returns the number of
    /// bytes written; forwards to `decompress_with_prefix`.
    pub fn decompress(&mut self, buf: &mut [u8]) -> Result<usize> {
        self.decompress_with_prefix(buf, None)
    }

    /// Resets the decoder so it decompresses everything from the beginning:
    /// the offset becomes 0 and the limit the full decompressed size.
    pub fn reset(&mut self) {
        self.reset_dctx();
        let full_size = self.seek_table.size_decomp();
        self.offset = 0;
        self.offset_limit = full_size;
    }

    /// Clears the compressed-bytes counter and resets the session of the
    /// decompression context.
    fn reset_dctx(&mut self) {
        self.read_compressed = 0;
        let outcome = self.dctx.reset(ResetDirective::SessionOnly);
        outcome.expect("Resetting session never fails");
    }

    /// Moves the offset to the start of frame `index` and returns that
    /// decompressed position.
    pub fn set_lower_frame(&mut self, index: u32) -> Result<u64> {
        let start = self.seek_table.frame_start_decomp(index)?;
        self.set_offset(start).map(|()| start)
    }

    /// Moves the offset limit to the end of frame `index` and returns that
    /// decompressed position.
    pub fn set_upper_frame(&mut self, index: u32) -> Result<u64> {
        let end = self.seek_table.frame_end_decomp(index)?;
        self.set_offset_limit(end).map(|()| end)
    }

    /// Sets the decompressed offset at which the next decompression starts.
    ///
    /// The decompression state is kept only when the new offset lies in the
    /// same frame as the current one and does not move backwards.
    pub fn set_offset(&mut self, offset: u64) -> Result<()> {
        Self::check_offset(offset, &self.seek_table)?;

        let frame_before = self.seek_table.frame_index_decomp(self.offset);
        let frame_after = self.seek_table.frame_index_decomp(offset);
        let moves_backwards = offset < self.offset;
        if frame_before != frame_after || moves_backwards {
            self.reset_dctx();
        }

        self.offset = offset;
        Ok(())
    }

    /// Sets the decompressed offset at which decompression stops.
    pub fn set_offset_limit(&mut self, limit: u64) -> Result<()> {
        Self::check_offset(limit, &self.seek_table)?;
        self.offset_limit = limit;
        Ok(())
    }

    /// Accepts offsets up to and including the decompressed size of
    /// `seek_table`; anything larger is an out-of-range error.
    fn check_offset(offset: u64, seek_table: &SeekTable) -> Result<()> {
        if offset <= seek_table.size_decomp() {
            return Ok(());
        }
        Err(Error::offset_out_of_range())
    }

    /// Number of compressed bytes consumed since the last reset.
    pub fn read_compressed(&self) -> u64 {
        self.read_compressed
    }

    /// The seek table used by this decoder.
    pub fn seek_table(&self) -> &SeekTable {
        &self.seek_table
    }

    /// Decompressed offset of the next byte that will be returned.
    pub fn offset(&self) -> u64 {
        self.offset
    }

    /// Decompressed offset at which decompression stops.
    pub fn offset_limit(&self) -> u64 {
        self.offset_limit
    }
}
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl<S: Seekable> std::io::Read for Decoder<'_, S> {
    /// Reads decompressed bytes into `buf` via `decompress`, wrapping any
    /// decoder error in a `std::io::Error`.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        match self.decompress(buf) {
            Ok(written) => Ok(written),
            Err(err) => Err(std::io::Error::other(err)),
        }
    }
}
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl<S: Seekable> std::io::Seek for Decoder<'_, S> {
    /// Moves the decompressed offset of the decoder and returns the new
    /// offset.
    ///
    /// `SeekFrom::End` is relative to the decompressed size recorded in the
    /// seek table and `SeekFrom::Current` to the current offset.
    ///
    /// # Errors
    ///
    /// Returns an out-of-range error (wrapped in `std::io::Error`) when the
    /// target lies before the start or past the end of the decompressed data,
    /// or when the offset arithmetic overflows.
    fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
        use std::io::{self, SeekFrom};

        // Built lazily so that successful seeks never allocate an error.
        let out_of_range = || io::Error::other(Error::offset_out_of_range());

        let target = match pos {
            SeekFrom::Start(offset) => offset,
            // Positions past the end are never valid for this decoder.
            SeekFrom::End(n) if n > 0 => return Err(out_of_range()),
            SeekFrom::End(n) => self
                .seek_table()
                .size_decomp()
                .checked_add_signed(n)
                .ok_or_else(out_of_range)?,
            SeekFrom::Current(n) => self
                .offset
                .checked_add_signed(n)
                .ok_or_else(out_of_range)?,
        };

        self.set_offset(target).map_err(io::Error::other)?;
        Ok(target)
    }
}
// Unit tests for `DecodeOptions` and `Decoder`, run against in-memory data
// produced from the crate's shared test `INPUT`.
#[cfg(test)]
mod tests {
use crate::{BytesWrapper, EncodeOptions, FrameSizePolicy, tests::INPUT};
use super::*;
// Compresses `INPUT` with a raw encoder using the given frame size policy
// (the default policy when `None`) and returns the compressed frames followed
// by the serialized seek table.
fn new_seekable(frame_size_policy: Option<FrameSizePolicy>) -> Vec<u8> {
let mut seekable = vec![];
let mut encoder = EncodeOptions::new()
.frame_size_policy(frame_size_policy.unwrap_or_default())
.into_raw_encoder()
.unwrap();
let mut buf = vec![0; INPUT.len()];
let mut in_progress = 0;
let mut out_progress = 0;
// Feed the input until the encoder has consumed all of it.
while in_progress < INPUT.len() {
let progress = encoder
.compress(&INPUT.as_bytes()[in_progress..], &mut buf)
.unwrap();
seekable.extend(&buf[..progress.out_progress()]);
in_progress += progress.in_progress();
out_progress += progress.out_progress();
}
assert_eq!(in_progress, INPUT.len());
// End the current frame, looping until the encoder reports no data left.
loop {
let prog = encoder.end_frame(&mut buf).unwrap();
seekable.extend(&buf[..prog.out_progress()]);
out_progress += prog.out_progress();
if prog.data_left() == 0 {
break;
}
}
assert_eq!(out_progress, seekable.len());
// Append the serialized seek table after the compressed frames.
let mut ser = encoder.into_seek_table().into_serializer();
loop {
let n = ser.write_into(&mut buf);
if n == 0 {
break;
}
seekable.extend(&buf[..n]);
}
assert_eq!(out_progress + ser.encoded_len(), seekable.len());
seekable
}
// Options at the boundary of the valid range build a decoder; options one
// past the boundary, or a source without a usable seek table, do not.
#[test]
fn options() {
let seekable = new_seekable(None);
let mut seekable = BytesWrapper::new(&seekable);
let st = SeekTable::from_seekable(&mut seekable).unwrap();
let oks = [
DecodeOptions::new(seekable.clone()),
DecodeOptions::new(seekable.clone()).lower_frame(st.num_frames() - 1),
DecodeOptions::new(seekable.clone()).upper_frame(st.num_frames() - 1),
DecodeOptions::new(seekable.clone()).offset(st.size_decomp()),
DecodeOptions::new(seekable.clone()).offset_limit(st.size_decomp()),
DecodeOptions::new(BytesWrapper::new(&[0, 128])).seek_table(st.clone()),
];
let errs = [
DecodeOptions::new(BytesWrapper::new(&[0, 128])),
DecodeOptions::new(seekable.clone()).lower_frame(st.num_frames()),
DecodeOptions::new(seekable.clone()).upper_frame(st.num_frames()),
DecodeOptions::new(seekable.clone()).offset(st.size_decomp() + 1),
DecodeOptions::new(seekable.clone()).offset_limit(st.size_decomp() + 1),
];
for opts in oks {
assert!(opts.into_decoder().is_ok());
}
for opts in errs {
assert!(opts.into_decoder().is_err());
}
}
// A full decompression yields `INPUT`, a further call yields nothing, and
// after `reset` the full input can be decompressed again.
#[test]
fn decompress_and_reset() {
let seekable = new_seekable(None);
let mut decoder = Decoder::new(BytesWrapper::new(&seekable)).unwrap();
let mut output = vec![0; INPUT.len()];
let n = decoder.decompress(&mut output).unwrap();
assert_eq!(n, output.len());
assert_eq!(INPUT.as_bytes(), output);
let n = decoder.decompress(&mut output).unwrap();
assert_eq!(n, 0);
decoder.reset();
let n = decoder.decompress(&mut output).unwrap();
assert_eq!(n, output.len());
assert_eq!(INPUT.as_bytes(), output);
}
// The upper frame is inclusive: frames 0 through 5 give six frames of data.
#[test]
fn decompress_until_upper_frame() {
let frame_size = INPUT.len() / 7;
let seekable = new_seekable(Some(FrameSizePolicy::Uncompressed(frame_size as u32)));
let mut decoder = Decoder::new(BytesWrapper::new(&seekable)).unwrap();
decoder.set_lower_frame(0).unwrap();
decoder.set_upper_frame(5).unwrap();
let len = frame_size * 6;
let mut output = vec![0; len];
let n = decoder.decompress(&mut output).unwrap();
assert_eq!(n, len);
assert_eq!(&INPUT.as_bytes()[..n], &output);
}
// Decompressing from frame 5 through frame 9 yields the tail of `INPUT`.
// NOTE(review): frame index 9 is presumably a shorter remainder frame, since
// `INPUT.len()` is divided by 9 to get the frame size; confirm against `INPUT`.
#[test]
fn decompress_last_frames() {
let frame_size = INPUT.len() / 9;
let seekable = new_seekable(Some(FrameSizePolicy::Uncompressed(frame_size as u32)));
let mut decoder = Decoder::new(BytesWrapper::new(&seekable)).unwrap();
decoder.set_lower_frame(5).unwrap();
decoder.set_upper_frame(9).unwrap();
let len = INPUT.len() - frame_size * 5;
let mut output = vec![0; len];
let n = decoder.decompress(&mut output).unwrap();
assert_eq!(n, len);
assert_eq!(&INPUT.as_bytes()[INPUT.len() - n..], &output);
}
// NOTE: despite the name, this test puts the upper frame (8) below the lower
// frame (9) and expects that nothing is decompressed.
#[test]
fn upper_frame_greater_than_lower_frame() {
let frame_size = INPUT.len() / 13;
let seekable = new_seekable(Some(FrameSizePolicy::Uncompressed(frame_size as u32)));
let mut decoder = Decoder::new(BytesWrapper::new(&seekable)).unwrap();
decoder.set_lower_frame(9).unwrap();
decoder.set_upper_frame(8).unwrap();
let mut output = vec![0; INPUT.len()];
let n = decoder.decompress(&mut output).unwrap();
assert_eq!(0, n);
}
// `reset` after a partial decompression restarts from the beginning.
#[test]
fn reset_decompression() {
let seekable = new_seekable(None);
let mut decoder = Decoder::new(BytesWrapper::new(&seekable)).unwrap();
decoder.decompress(&mut [0; 128]).unwrap();
decoder.reset();
let mut output = vec![0; INPUT.len()];
let n = decoder.decompress(&mut output).unwrap();
assert_eq!(n, INPUT.len());
assert_eq!(INPUT.as_bytes(), output);
}
// After decompressing frames 23..=29, widening the range to all frames lets
// the same decoder decompress the complete input.
#[test]
fn decompress_everything_after_partly_decompression() {
let frame_size = INPUT.len() / 32;
let seekable = new_seekable(Some(FrameSizePolicy::Uncompressed(frame_size as u32)));
let mut decoder = Decoder::new(BytesWrapper::new(&seekable)).unwrap();
decoder.set_lower_frame(23).unwrap();
decoder.set_upper_frame(29).unwrap();
let mut output = vec![0; INPUT.len()];
let n = decoder.decompress(&mut output).unwrap();
assert_eq!(n, frame_size * 30 - frame_size * 23);
assert_eq!(
INPUT.as_bytes()[frame_size * 23..frame_size * 30],
output[..n]
);
decoder.set_lower_frame(0).unwrap();
decoder
.set_upper_frame(decoder.seek_table().num_frames() - 1)
.unwrap();
let n = decoder.decompress(&mut output).unwrap();
assert_eq!(n, INPUT.len());
assert_eq!(INPUT.as_bytes(), output);
}
// The last frame index is accepted; one past it is a "frame index too large"
// error for both the lower and the upper frame.
#[test]
fn set_frame_boundaries() {
let seekable = new_seekable(None);
let mut decoder = Decoder::new(BytesWrapper::new(&seekable)).unwrap();
let num_frames = decoder.seek_table().num_frames();
assert!(decoder.set_lower_frame(num_frames - 1).is_ok());
assert!(decoder.set_upper_frame(num_frames - 1).is_ok());
assert!(
decoder
.set_lower_frame(num_frames)
.unwrap_err()
.is_frame_index_too_large()
);
assert!(
decoder
.set_upper_frame(num_frames)
.unwrap_err()
.is_frame_index_too_large()
);
}
// The decompressed size is a valid offset and limit; one byte more is an
// "offset out of range" error.
#[test]
fn set_offset_boundaries() {
let seekable = new_seekable(None);
let mut decoder = Decoder::new(BytesWrapper::new(&seekable)).unwrap();
let mut offset = decoder.seek_table().size_decomp();
assert!(decoder.set_offset(offset).is_ok());
assert!(decoder.set_offset_limit(offset).is_ok());
offset += 1;
assert!(
decoder
.set_offset(offset)
.unwrap_err()
.is_offset_out_of_range()
);
assert!(
decoder
.set_offset_limit(offset)
.unwrap_err()
.is_offset_out_of_range()
);
}
// Byte-granular offsets: decompress the middle third, then move the offset
// back to 3 while keeping the limit, then reset and decompress everything.
#[test]
fn decompress_within_offset_boundaries() {
let frame_size = INPUT.len() / 34;
let seekable = new_seekable(Some(FrameSizePolicy::Uncompressed(frame_size as u32)));
let mut decoder = Decoder::new(BytesWrapper::new(&seekable)).unwrap();
let offset = INPUT.len() / 3;
let offset_limit = 2 * offset;
decoder.set_offset(offset as u64).unwrap();
decoder.set_offset_limit(offset_limit as u64).unwrap();
let mut output = vec![0; INPUT.len()];
let n = decoder.decompress(&mut output).unwrap();
assert_eq!(n, offset_limit - offset);
assert_eq!(INPUT.as_bytes()[offset..offset_limit], output[..n]);
decoder.set_offset(3).unwrap();
let n = decoder.decompress(&mut output).unwrap();
assert_eq!(n, offset_limit - 3);
assert_eq!(INPUT.as_bytes()[3..offset_limit], output[..n]);
decoder.reset();
assert_eq!(decoder.offset(), 0);
assert_eq!(decoder.offset_limit(), decoder.seek_table().size_decomp());
assert_eq!(decoder.read_compressed(), 0);
let n = decoder.decompress(&mut output).unwrap();
assert_eq!(n, INPUT.len());
assert_eq!(INPUT.as_bytes(), output);
}
// Exercises the `std::io::Seek` implementation with `Start`, `End` and
// `Current` positions while an offset limit of 51 frames is in place.
#[cfg(feature = "std")]
#[test]
#[allow(clippy::cast_sign_loss)]
fn seek_decoder() {
use std::io::{Seek, SeekFrom};
let frame_size = INPUT.len() / 52;
let seekable = new_seekable(Some(FrameSizePolicy::Uncompressed(frame_size as u32)));
let mut decoder = Decoder::new(BytesWrapper::new(&seekable)).unwrap();
let seek_pos = frame_size * 13;
let end = frame_size * 51;
decoder.set_offset_limit(end as u64).unwrap();
decoder.seek(SeekFrom::Start(seek_pos as u64)).unwrap();
assert_eq!(decoder.offset(), seek_pos as u64);
let mut output = vec![0; INPUT.len()];
let n = decoder.decompress(&mut output).unwrap();
assert_ne!(decoder.read_compressed(), 0);
assert_eq!(n, end - seek_pos);
assert_eq!(INPUT.as_bytes()[seek_pos..end], output[..n]);
assert_eq!(decoder.offset(), end as u64);
// Seeking relative to the end lands in a different frame, which resets the
// compressed-bytes counter.
let seek_pos = -((2 * frame_size) as i64);
let start = (INPUT.len() as i64 + seek_pos) as usize;
assert_ne!(decoder.read_compressed(), 0);
decoder.seek(SeekFrom::End(seek_pos)).unwrap();
assert_eq!(
decoder.offset(),
(INPUT.len() as u64).wrapping_add_signed(seek_pos)
);
assert_eq!(decoder.read_compressed(), 0);
let n = decoder.decompress(&mut output).unwrap();
assert_eq!(n, end - start);
assert_eq!(INPUT.as_bytes()[start..end], output[..n]);
// Relative seeks forwards and backwards from a known position.
decoder.seek(SeekFrom::Start(69)).unwrap();
decoder.seek(SeekFrom::Current(10)).unwrap();
assert_eq!(decoder.offset(), 79);
let n = decoder.decompress(&mut output).unwrap();
assert_eq!(n, end - 79);
assert_eq!(INPUT.as_bytes()[79..end], output[..n]);
decoder.seek(SeekFrom::Start(69)).unwrap();
decoder.seek(SeekFrom::Current(-10)).unwrap();
assert_eq!(decoder.offset(), 59);
let n = decoder.decompress(&mut output).unwrap();
assert_eq!(n, end - 59);
assert_eq!(INPUT.as_bytes()[59..end], output[..n]);
}
// Moving the offset forwards within the current frame keeps the decompression
// state (`read_compressed` stays non-zero); moving it into another frame
// resets it.
#[cfg(feature = "std")]
#[test]
fn set_offset_within_frame_continues_decompression() {
use std::io::Read;
let seekable = new_seekable(Some(FrameSizePolicy::Uncompressed(100)));
let mut decoder = Decoder::new(BytesWrapper::new(&seekable)).unwrap();
assert_eq!(decoder.read_compressed(), 0);
decoder.set_offset(10).unwrap();
decoder.read_exact(&mut [0; 10]).unwrap();
assert_ne!(decoder.read_compressed(), 0);
decoder.set_offset(30).unwrap();
assert_eq!(decoder.offset(), 30);
assert_ne!(decoder.read_compressed(), 0);
let mut output = vec![0; INPUT.len()];
let n = decoder.decompress(&mut output).unwrap();
assert_eq!(n, INPUT.len() - 30);
assert_eq!(INPUT.as_bytes()[30..], output[..n]);
decoder.set_offset(101).unwrap();
assert_eq!(decoder.offset(), 101);
assert_eq!(decoder.read_compressed(), 0);
let n = decoder.decompress(&mut output).unwrap();
assert_eq!(n, INPUT.len() - 101);
assert_eq!(INPUT.as_bytes()[101..], output[..n]);
}
}