compression_codecs/zstd/decoder.rs

1use crate::{
2    zstd::{params::DParameter, OperationExt},
3    {DecodeV2, DecodedSize},
4};
5use compression_core::{
6    unshared::Unshared,
7    util::{PartialBuffer, WriteBuffer},
8};
9use libzstd::stream::raw::Decoder;
10use std::{
11    convert::TryInto,
12    io::{self, Result},
13};
14use zstd_safe::get_error_name;
15
/// Streaming zstd decoder that adapts `libzstd`'s raw `Decoder` to this
/// crate's `DecodeV2` interface.
#[derive(Debug)]
pub struct ZstdDecoder {
    /// The raw zstd decoder doing the actual decompression.
    decoder: Unshared<Decoder<'static>>,
    /// Set to `true` once `decode` reports that the stream finished and cleared
    /// again by `reinit`; `finish` returns an `UnexpectedEof` error while it is
    /// still `false`, which is how truncated input is detected.
    stream_ended: bool,
}
21
22impl Default for ZstdDecoder {
23    fn default() -> Self {
24        Self {
25            decoder: Unshared::new(Decoder::new().unwrap()),
26            stream_ended: false,
27        }
28    }
29}
30
31impl ZstdDecoder {
32    pub fn new() -> Self {
33        Self::default()
34    }
35
36    pub fn new_with_params(params: &[DParameter]) -> Self {
37        let mut decoder = Decoder::new().unwrap();
38        for param in params {
39            decoder.set_parameter(param.as_zstd()).unwrap();
40        }
41        Self {
42            decoder: Unshared::new(decoder),
43            stream_ended: false,
44        }
45    }
46
47    pub fn new_with_dict(dictionary: &[u8]) -> io::Result<Self> {
48        let decoder = Decoder::with_dictionary(dictionary)?;
49        Ok(Self {
50            decoder: Unshared::new(decoder),
51            stream_ended: false,
52        })
53    }
54}
55
56impl DecodeV2 for ZstdDecoder {
57    fn reinit(&mut self) -> Result<()> {
58        self.decoder.reinit()?;
59        self.stream_ended = false;
60        Ok(())
61    }
62
63    fn decode(
64        &mut self,
65        input: &mut PartialBuffer<&[u8]>,
66        output: &mut WriteBuffer<'_>,
67    ) -> Result<bool> {
68        let finished = self.decoder.run(input, output)?;
69        if finished {
70            self.stream_ended = true;
71        }
72        Ok(finished)
73    }
74
75    fn flush(&mut self, output: &mut WriteBuffer<'_>) -> Result<bool> {
76        // Note: stream_ended is not updated here because zstd's flush only flushes
77        // buffered output and doesn't indicate stream completion. Stream completion
78        // is detected in decode() when status.remaining == 0.
79        self.decoder.flush(output)
80    }
81
82    fn finish(&mut self, output: &mut WriteBuffer<'_>) -> Result<bool> {
83        self.decoder.finish(output)?;
84
85        if self.stream_ended {
86            Ok(true)
87        } else {
88            Err(io::Error::new(
89                io::ErrorKind::UnexpectedEof,
90                "zstd stream did not finish",
91            ))
92        }
93    }
94}
95
96impl DecodedSize for ZstdDecoder {
97    fn decoded_size(input: &[u8]) -> Result<u64> {
98        zstd_safe::find_frame_compressed_size(input)
99            .map_err(|error_code| io::Error::other(get_error_name(error_code)))
100            .and_then(|size| {
101                size.try_into()
102                    .map_err(|_| io::Error::from(io::ErrorKind::FileTooLarge))
103            })
104    }
105}