compression_codecs/zstd/decoder.rs

1use crate::zstd::params::DParameter;
2use crate::{DecodeV2, DecodedSize};
3use compression_core::{
4    unshared::Unshared,
5    util::{PartialBuffer, WriteBuffer},
6};
7use libzstd::stream::raw::{Decoder, Operation};
8use std::convert::TryInto;
9use std::io;
10use std::io::Result;
11use zstd_safe::get_error_name;
12
/// Streaming zstd decompressor driven through the [`DecodeV2`] trait.
///
/// Tracks whether a complete zstd frame has been observed so that `finish`
/// can report a truncated stream instead of silently succeeding.
#[derive(Debug)]
pub struct ZstdDecoder {
    // The underlying zstd raw streaming decoder, wrapped in `Unshared`.
    decoder: Unshared<Decoder<'static>>,
    // Set once `decode` sees zstd report zero remaining bytes for a frame;
    // cleared by `reinit`. While `false`, `finish` returns `UnexpectedEof`.
    stream_ended: bool,
}
18
19impl Default for ZstdDecoder {
20    fn default() -> Self {
21        Self {
22            decoder: Unshared::new(Decoder::new().unwrap()),
23            stream_ended: false,
24        }
25    }
26}
27
28impl ZstdDecoder {
29    pub fn new() -> Self {
30        Self::default()
31    }
32
33    pub fn new_with_params(params: &[DParameter]) -> Self {
34        let mut decoder = Decoder::new().unwrap();
35        for param in params {
36            decoder.set_parameter(param.as_zstd()).unwrap();
37        }
38        Self {
39            decoder: Unshared::new(decoder),
40            stream_ended: false,
41        }
42    }
43
44    pub fn new_with_dict(dictionary: &[u8]) -> io::Result<Self> {
45        let decoder = Decoder::with_dictionary(dictionary)?;
46        Ok(Self {
47            decoder: Unshared::new(decoder),
48            stream_ended: false,
49        })
50    }
51
52    fn call_fn_on_out_buffer(
53        &mut self,
54        output: &mut WriteBuffer<'_>,
55        f: fn(&mut Decoder<'static>, &mut zstd_safe::OutBuffer<'_, [u8]>) -> io::Result<usize>,
56    ) -> io::Result<bool> {
57        let mut out_buf = zstd_safe::OutBuffer::around(output.initialize_unwritten());
58        let res = f(self.decoder.get_mut(), &mut out_buf);
59        let len = out_buf.as_slice().len();
60        output.advance(len);
61
62        res.map(|bytes_left| bytes_left == 0)
63    }
64}
65
66impl DecodeV2 for ZstdDecoder {
67    fn reinit(&mut self) -> Result<()> {
68        self.decoder.get_mut().reinit()?;
69        self.stream_ended = false;
70        Ok(())
71    }
72
73    fn decode(
74        &mut self,
75        input: &mut PartialBuffer<&[u8]>,
76        output: &mut WriteBuffer<'_>,
77    ) -> Result<bool> {
78        let status = self
79            .decoder
80            .get_mut()
81            .run_on_buffers(input.unwritten(), output.initialize_unwritten())?;
82        input.advance(status.bytes_read);
83        output.advance(status.bytes_written);
84
85        let finished = status.remaining == 0;
86        if finished {
87            self.stream_ended = true;
88        }
89        Ok(finished)
90    }
91
92    fn flush(&mut self, output: &mut WriteBuffer<'_>) -> Result<bool> {
93        // Note: stream_ended is not updated here because zstd's flush only flushes
94        // buffered output and doesn't indicate stream completion. Stream completion
95        // is detected in decode() when status.remaining == 0.
96        self.call_fn_on_out_buffer(output, |decoder, out_buf| decoder.flush(out_buf))
97    }
98
99    fn finish(&mut self, output: &mut WriteBuffer<'_>) -> Result<bool> {
100        self.call_fn_on_out_buffer(output, |decoder, out_buf| decoder.finish(out_buf, true))?;
101
102        if self.stream_ended {
103            Ok(true)
104        } else {
105            Err(io::Error::new(
106                io::ErrorKind::UnexpectedEof,
107                "zstd stream did not finish",
108            ))
109        }
110    }
111}
112
113impl DecodedSize for ZstdDecoder {
114    fn decoded_size(input: &[u8]) -> Result<u64> {
115        zstd_safe::find_frame_compressed_size(input)
116            .map_err(|error_code| io::Error::other(get_error_name(error_code)))
117            .and_then(|size| {
118                size.try_into()
119                    .map_err(|_| io::Error::from(io::ErrorKind::FileTooLarge))
120            })
121    }
122}