compression_codecs/zstd/
decoder.rs1use crate::zstd::params::DParameter;
2use crate::{DecodeV2, DecodedSize};
3use compression_core::{
4 unshared::Unshared,
5 util::{PartialBuffer, WriteBuffer},
6};
7use libzstd::stream::raw::{Decoder, Operation};
8use std::convert::TryInto;
9use std::io;
10use std::io::Result;
11use zstd_safe::get_error_name;
12
13#[derive(Debug)]
14pub struct ZstdDecoder {
15 decoder: Unshared<Decoder<'static>>,
16 stream_ended: bool,
17}
18
19impl Default for ZstdDecoder {
20 fn default() -> Self {
21 Self {
22 decoder: Unshared::new(Decoder::new().unwrap()),
23 stream_ended: false,
24 }
25 }
26}
27
28impl ZstdDecoder {
29 pub fn new() -> Self {
30 Self::default()
31 }
32
33 pub fn new_with_params(params: &[DParameter]) -> Self {
34 let mut decoder = Decoder::new().unwrap();
35 for param in params {
36 decoder.set_parameter(param.as_zstd()).unwrap();
37 }
38 Self {
39 decoder: Unshared::new(decoder),
40 stream_ended: false,
41 }
42 }
43
44 pub fn new_with_dict(dictionary: &[u8]) -> io::Result<Self> {
45 let decoder = Decoder::with_dictionary(dictionary)?;
46 Ok(Self {
47 decoder: Unshared::new(decoder),
48 stream_ended: false,
49 })
50 }
51
52 fn call_fn_on_out_buffer(
53 &mut self,
54 output: &mut WriteBuffer<'_>,
55 f: fn(&mut Decoder<'static>, &mut zstd_safe::OutBuffer<'_, [u8]>) -> io::Result<usize>,
56 ) -> io::Result<bool> {
57 let mut out_buf = zstd_safe::OutBuffer::around(output.initialize_unwritten());
58 let res = f(self.decoder.get_mut(), &mut out_buf);
59 let len = out_buf.as_slice().len();
60 output.advance(len);
61
62 res.map(|bytes_left| bytes_left == 0)
63 }
64}
65
66impl DecodeV2 for ZstdDecoder {
67 fn reinit(&mut self) -> Result<()> {
68 self.decoder.get_mut().reinit()?;
69 self.stream_ended = false;
70 Ok(())
71 }
72
73 fn decode(
74 &mut self,
75 input: &mut PartialBuffer<&[u8]>,
76 output: &mut WriteBuffer<'_>,
77 ) -> Result<bool> {
78 let status = self
79 .decoder
80 .get_mut()
81 .run_on_buffers(input.unwritten(), output.initialize_unwritten())?;
82 input.advance(status.bytes_read);
83 output.advance(status.bytes_written);
84
85 let finished = status.remaining == 0;
86 if finished {
87 self.stream_ended = true;
88 }
89 Ok(finished)
90 }
91
92 fn flush(&mut self, output: &mut WriteBuffer<'_>) -> Result<bool> {
93 self.call_fn_on_out_buffer(output, |decoder, out_buf| decoder.flush(out_buf))
97 }
98
99 fn finish(&mut self, output: &mut WriteBuffer<'_>) -> Result<bool> {
100 self.call_fn_on_out_buffer(output, |decoder, out_buf| decoder.finish(out_buf, true))?;
101
102 if self.stream_ended {
103 Ok(true)
104 } else {
105 Err(io::Error::new(
106 io::ErrorKind::UnexpectedEof,
107 "zstd stream did not finish",
108 ))
109 }
110 }
111}
112
113impl DecodedSize for ZstdDecoder {
114 fn decoded_size(input: &[u8]) -> Result<u64> {
115 zstd_safe::find_frame_compressed_size(input)
116 .map_err(|error_code| io::Error::other(get_error_name(error_code)))
117 .and_then(|size| {
118 size.try_into()
119 .map_err(|_| io::Error::from(io::ErrorKind::FileTooLarge))
120 })
121 }
122}