Skip to main content

zrip_decode/
streaming.rs

1#![forbid(unsafe_code)]
2
3use std::io::{self, Read};
4
5use crate::BlockDecodeWorkspace;
6use crate::literals::decode_literals_ws;
7use crate::sequences::{SequenceDecodeTables, parse_sequence_count, parse_sequence_tables_ws};
8use zrip_core::block::{BlockType, parse_block_header};
9use zrip_core::error::DecompressError;
10use zrip_core::frame::MAX_BLOCK_SIZE;
11use zrip_core::frame::header::parse_frame_header;
12use zrip_core::xxhash::Xxh64State;
13
14#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
15use zrip_core::simd::CpuTier;
16
17enum State {
18    FrameHeader,
19    BlockHeader,
20    BlockData {
21        block_type: BlockType,
22        block_size: usize,
23        last: bool,
24    },
25    Checksum,
26    Done,
27}
28
29/// Streaming zstd decompressor implementing [`Read`].
30///
31/// Wraps a reader of compressed data and yields decompressed bytes.
32/// Supports multi-frame streams and skippable frames.
33///
34/// ```
35/// use std::io::Read;
36///
37/// let data = b"hello, streaming world!".repeat(100);
38/// let compressed = zrip::compress(&data, 1).unwrap();
39///
40/// let mut decoder = zrip::FrameDecoder::new(&compressed[..]);
41/// let mut output = Vec::new();
42/// decoder.read_to_end(&mut output).unwrap();
43/// assert_eq!(output, data);
44/// ```
45pub struct FrameDecoder<R: Read> {
46    inner: R,
47    state: State,
48    read_buf: Vec<u8>,
49    output_buf: Vec<u8>,
50    output_pos: usize,
51    ws: Box<BlockDecodeWorkspace>,
52    seq_tables: SequenceDecodeTables,
53    rep_offsets: [u32; 3],
54    hasher: Option<Xxh64State>,
55    content_checksum: bool,
56    max_output: usize,
57    bytes_output: usize,
58}
59
60impl<R: Read> FrameDecoder<R> {
61    /// Creates a decoder with [`DEFAULT_DECOMPRESS_LIMIT`](zrip_core::DEFAULT_DECOMPRESS_LIMIT).
62    pub fn new(reader: R) -> Self {
63        Self::with_limit(reader, zrip_core::DEFAULT_DECOMPRESS_LIMIT)
64    }
65
66    /// Creates a decoder with an explicit output size limit.
67    pub fn with_limit(reader: R, max_output: usize) -> Self {
68        Self {
69            inner: reader,
70            state: State::FrameHeader,
71            read_buf: Vec::new(),
72            output_buf: Vec::new(),
73            output_pos: 0,
74            ws: Box::new(BlockDecodeWorkspace::new()),
75            seq_tables: SequenceDecodeTables::new_default(),
76            rep_offsets: [1, 4, 8],
77            hasher: None,
78            content_checksum: false,
79            max_output,
80            bytes_output: 0,
81        }
82    }
83
84    /// Consumes the decoder and returns the underlying reader.
85    pub fn into_inner(self) -> R {
86        self.inner
87    }
88
89    fn fill_output(&mut self) -> io::Result<()> {
90        loop {
91            match self.state {
92                State::Done => return Ok(()),
93                State::FrameHeader => self.read_frame_header()?,
94                State::BlockHeader => self.read_block_header()?,
95                State::BlockData {
96                    block_type,
97                    block_size,
98                    last,
99                } => {
100                    self.read_block_data(block_type, block_size, last)?;
101                    if self.output_pos < self.output_buf.len() {
102                        return Ok(());
103                    }
104                }
105                State::Checksum => self.read_checksum()?,
106            }
107        }
108    }
109
110    fn read_frame_header(&mut self) -> io::Result<()> {
111        self.read_buf.resize(18, 0);
112        self.inner.read_exact(&mut self.read_buf[..5])?;
113
114        let magic = u32::from_le_bytes([
115            self.read_buf[0],
116            self.read_buf[1],
117            self.read_buf[2],
118            self.read_buf[3],
119        ]);
120
121        if (magic & 0xFFFFFFF0) == 0x184D2A50 {
122            self.inner.read_exact(&mut self.read_buf[5..9])?;
123            let skip_size = u32::from_le_bytes([
124                self.read_buf[5],
125                self.read_buf[6],
126                self.read_buf[7],
127                self.read_buf[8],
128            ]) as usize;
129            io::copy(
130                &mut self.inner.by_ref().take(skip_size as u64),
131                &mut io::sink(),
132            )?;
133            return Ok(());
134        }
135
136        let descriptor = self.read_buf[4];
137        let single_segment = (descriptor & 0x20) != 0;
138        let dict_id_flag = descriptor & 0x03;
139        let fcs_flag = (descriptor >> 6) & 0x03;
140
141        let mut hdr_len = 5usize;
142        if !single_segment {
143            hdr_len += 1;
144        }
145        hdr_len += match dict_id_flag {
146            0 => 0,
147            1 => 1,
148            2 => 2,
149            3 => 4,
150            _ => unreachable!(),
151        };
152        hdr_len += match fcs_flag {
153            0 if single_segment => 1,
154            0 => 0,
155            1 => 2,
156            2 => 4,
157            3 => 8,
158            _ => unreachable!(),
159        };
160
161        if hdr_len > 5 {
162            self.inner.read_exact(&mut self.read_buf[5..hdr_len])?;
163        }
164
165        let header = parse_frame_header(&self.read_buf[..hdr_len])
166            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
167
168        if let Some(fcs) = header.frame_content_size {
169            if fcs as usize > self.max_output {
170                return Err(io::Error::new(
171                    io::ErrorKind::InvalidData,
172                    DecompressError::OutputTooSmall,
173                ));
174            }
175        }
176
177        self.content_checksum = header.content_checksum;
178        self.hasher = if header.content_checksum {
179            Some(Xxh64State::new(0))
180        } else {
181            None
182        };
183        self.rep_offsets = [1, 4, 8];
184        self.ws.huf_valid = false;
185        self.state = State::BlockHeader;
186        Ok(())
187    }
188
189    fn read_block_header(&mut self) -> io::Result<()> {
190        let mut hdr = [0u8; 3];
191        self.inner.read_exact(&mut hdr)?;
192        let block_header =
193            parse_block_header(&hdr).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
194
195        let block_size = block_header.block_size as usize;
196
197        match block_header.block_type {
198            BlockType::Raw | BlockType::Rle if block_size > MAX_BLOCK_SIZE => {
199                return Err(io::Error::new(
200                    io::ErrorKind::InvalidData,
201                    DecompressError::CorruptSequences,
202                ));
203            }
204            _ => {}
205        }
206
207        self.state = State::BlockData {
208            block_type: block_header.block_type,
209            block_size,
210            last: block_header.last_block,
211        };
212        Ok(())
213    }
214
215    fn read_block_data(
216        &mut self,
217        block_type: BlockType,
218        block_size: usize,
219        last: bool,
220    ) -> io::Result<()> {
221        self.output_buf.clear();
222        self.output_pos = 0;
223
224        match block_type {
225            BlockType::Raw => {
226                self.output_buf.resize(block_size, 0);
227                self.inner.read_exact(&mut self.output_buf)?;
228            }
229            BlockType::Rle => {
230                let mut byte = [0u8; 1];
231                self.inner.read_exact(&mut byte)?;
232                self.output_buf.resize(block_size, byte[0]);
233            }
234            BlockType::Compressed => {
235                self.read_buf.resize(block_size, 0);
236                self.inner.read_exact(&mut self.read_buf[..block_size])?;
237                self.decode_compressed_block(block_size)?;
238            }
239        }
240
241        if let Some(ref mut hasher) = self.hasher {
242            hasher.update(&self.output_buf);
243        }
244        self.bytes_output += self.output_buf.len();
245        if self.bytes_output > self.max_output {
246            return Err(io::Error::new(
247                io::ErrorKind::InvalidData,
248                DecompressError::OutputTooSmall,
249            ));
250        }
251
252        self.state = if last {
253            if self.content_checksum {
254                State::Checksum
255            } else {
256                State::FrameHeader
257            }
258        } else {
259            State::BlockHeader
260        };
261
262        Ok(())
263    }
264
265    fn decode_compressed_block(&mut self, block_size: usize) -> io::Result<()> {
266        let block_data = &self.read_buf[..block_size];
267
268        let lit_consumed = decode_literals_ws(block_data, &mut self.ws)
269            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
270
271        let remaining = &block_data[lit_consumed..];
272
273        if remaining.is_empty() {
274            self.output_buf.extend_from_slice(&self.ws.literal_buf);
275            return Ok(());
276        }
277
278        let (num_sequences, seq_count_size) = parse_sequence_count(remaining)
279            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
280
281        if num_sequences == 0 {
282            self.output_buf.extend_from_slice(&self.ws.literal_buf);
283            return Ok(());
284        }
285
286        let table_data = &remaining[seq_count_size..];
287        let tables_consumed =
288            parse_sequence_tables_ws(table_data, &mut self.seq_tables, &mut self.ws)
289                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
290
291        let seq_data = &table_data[tables_consumed..];
292
293        #[cfg(target_arch = "x86_64")]
294        {
295            if zrip_core::simd::cpu_tier() >= CpuTier::Avx2 {
296                let before = self.output_buf.len();
297                crate::simd_decode::x86_64::decode::decode_execute_avx2_safe(
298                    seq_data,
299                    num_sequences,
300                    &self.seq_tables,
301                    &mut self.rep_offsets,
302                    &self.ws.literal_buf,
303                    &mut self.output_buf,
304                    &[],
305                )
306                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
307                if self.output_buf.len() - before > MAX_BLOCK_SIZE {
308                    return Err(io::Error::new(
309                        io::ErrorKind::InvalidData,
310                        DecompressError::CorruptSequences,
311                    ));
312                }
313                return Ok(());
314            }
315        }
316
317        #[cfg(target_arch = "aarch64")]
318        {
319            if zrip_core::simd::cpu_tier() >= CpuTier::Neon {
320                let before = self.output_buf.len();
321                crate::simd_decode::aarch64::decode::decode_execute_neon_safe(
322                    seq_data,
323                    num_sequences,
324                    &self.seq_tables,
325                    &mut self.rep_offsets,
326                    &self.ws.literal_buf,
327                    &mut self.output_buf,
328                    &[],
329                )
330                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
331                if self.output_buf.len() - before > MAX_BLOCK_SIZE {
332                    return Err(io::Error::new(
333                        io::ErrorKind::InvalidData,
334                        DecompressError::CorruptSequences,
335                    ));
336                }
337                return Ok(());
338            }
339        }
340
341        let before = self.output_buf.len();
342        crate::exec::decode_execute_sequences(
343            seq_data,
344            num_sequences,
345            &self.seq_tables,
346            &mut self.rep_offsets,
347            &self.ws.literal_buf,
348            &mut self.output_buf,
349            &[],
350        )
351        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
352        if self.output_buf.len() - before > MAX_BLOCK_SIZE {
353            return Err(io::Error::new(
354                io::ErrorKind::InvalidData,
355                DecompressError::CorruptSequences,
356            ));
357        }
358        Ok(())
359    }
360
361    fn read_checksum(&mut self) -> io::Result<()> {
362        let mut buf = [0u8; 4];
363        self.inner.read_exact(&mut buf)?;
364        let stored = u32::from_le_bytes(buf);
365
366        if let Some(ref hasher) = self.hasher {
367            let hash = hasher.finish();
368            let expected = (hash & 0xFFFFFFFF) as u32;
369            if expected != stored {
370                return Err(io::Error::new(
371                    io::ErrorKind::InvalidData,
372                    DecompressError::ChecksumMismatch {
373                        expected: stored,
374                        got: expected,
375                    },
376                ));
377            }
378        }
379
380        self.state = State::FrameHeader;
381        Ok(())
382    }
383}
384
385impl<R: Read> Read for FrameDecoder<R> {
386    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
387        if self.output_pos >= self.output_buf.len() {
388            match &self.state {
389                State::Done => return Ok(0),
390                _ => {}
391            }
392
393            self.output_buf.clear();
394            self.output_pos = 0;
395
396            match self.fill_output() {
397                Ok(()) => {}
398                Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => match &self.state {
399                    State::FrameHeader => {
400                        self.state = State::Done;
401                        return Ok(0);
402                    }
403                    _ => return Err(e),
404                },
405                Err(e) => return Err(e),
406            }
407        }
408
409        let available = &self.output_buf[self.output_pos..];
410        let n = buf.len().min(available.len());
411        buf[..n].copy_from_slice(&available[..n]);
412        self.output_pos += n;
413        Ok(n)
414    }
415}