compression_codecs/zstd/
decoder.rs1use crate::{
2 zstd::{params::DParameter, OperationExt},
3 {DecodeV2, DecodedSize},
4};
5use compression_core::{
6 unshared::Unshared,
7 util::{PartialBuffer, WriteBuffer},
8};
9use libzstd::stream::raw::Decoder;
10use std::{
11 convert::TryInto,
12 io::{self, Result},
13};
14use zstd_safe::get_error_name;
15
16#[derive(Debug)]
17pub struct ZstdDecoder {
18 decoder: Unshared<Decoder<'static>>,
19 stream_ended: bool,
20}
21
22impl Default for ZstdDecoder {
23 fn default() -> Self {
24 Self {
25 decoder: Unshared::new(Decoder::new().unwrap()),
26 stream_ended: false,
27 }
28 }
29}
30
31impl ZstdDecoder {
32 pub fn new() -> Self {
33 Self::default()
34 }
35
36 pub fn new_with_params(params: &[DParameter]) -> Self {
37 let mut decoder = Decoder::new().unwrap();
38 for param in params {
39 decoder.set_parameter(param.as_zstd()).unwrap();
40 }
41 Self {
42 decoder: Unshared::new(decoder),
43 stream_ended: false,
44 }
45 }
46
47 pub fn new_with_dict(dictionary: &[u8]) -> io::Result<Self> {
48 let decoder = Decoder::with_dictionary(dictionary)?;
49 Ok(Self {
50 decoder: Unshared::new(decoder),
51 stream_ended: false,
52 })
53 }
54}
55
56impl DecodeV2 for ZstdDecoder {
57 fn reinit(&mut self) -> Result<()> {
58 self.decoder.reinit()?;
59 self.stream_ended = false;
60 Ok(())
61 }
62
63 fn decode(
64 &mut self,
65 input: &mut PartialBuffer<&[u8]>,
66 output: &mut WriteBuffer<'_>,
67 ) -> Result<bool> {
68 let finished = self.decoder.run(input, output)?;
69 if finished {
70 self.stream_ended = true;
71 }
72 Ok(finished)
73 }
74
75 fn flush(&mut self, output: &mut WriteBuffer<'_>) -> Result<bool> {
76 self.decoder.flush(output)
80 }
81
82 fn finish(&mut self, output: &mut WriteBuffer<'_>) -> Result<bool> {
83 self.decoder.finish(output)?;
84
85 if self.stream_ended {
86 Ok(true)
87 } else {
88 Err(io::Error::new(
89 io::ErrorKind::UnexpectedEof,
90 "zstd stream did not finish",
91 ))
92 }
93 }
94}
95
96impl DecodedSize for ZstdDecoder {
97 fn decoded_size(input: &[u8]) -> Result<u64> {
98 zstd_safe::find_frame_compressed_size(input)
99 .map_err(|error_code| io::Error::other(get_error_name(error_code)))
100 .and_then(|size| {
101 size.try_into()
102 .map_err(|_| io::Error::from(io::ErrorKind::FileTooLarge))
103 })
104 }
105}