Skip to main content

zrip_decode/
streaming.rs

1#![forbid(unsafe_code)]
2
3use std::io::{self, Read};
4
5use crate::BlockDecodeWorkspace;
6use crate::literals::decode_literals_ws;
7use crate::sequences::{SequenceDecodeTables, parse_sequence_count, parse_sequence_tables_ws};
8use zrip_core::block::{BlockType, parse_block_header};
9use zrip_core::dict::Dictionary;
10use zrip_core::error::DecompressError;
11use zrip_core::frame::MAX_BLOCK_SIZE;
12use zrip_core::frame::header::parse_frame_header;
13use zrip_core::fse::{promote_ll_table, promote_ml_table, promote_of_table};
14use zrip_core::xxhash::Xxh64State;
15
16#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
17use zrip_core::simd::CpuTier;
18
/// Position of the decoder within the compressed stream.
///
/// `FrameDecoder::fill_output` loops over this state until a block yields
/// output bytes or an error is returned.
enum State {
    /// Expecting the start of a frame (regular or skippable). This is also
    /// the only state in which hitting end-of-input is treated as a clean
    /// end of stream by `read`.
    FrameHeader,
    /// Expecting a 3-byte block header.
    BlockHeader,
    /// A block header was parsed; the block payload is read next.
    BlockData {
        /// Kind of block announced by the header.
        block_type: BlockType,
        /// Size field from the block header, as `usize`.
        block_size: usize,
        /// Whether this is the final block of the current frame.
        last: bool,
    },
    /// The last block was decoded and a 4-byte content checksum follows.
    Checksum,
    /// End of input was reached at a frame boundary; `read` returns `Ok(0)`.
    Done,
}
30
/// Streaming zstd decompressor implementing [`Read`].
///
/// Wraps a reader of compressed data and yields decompressed bytes.
/// Supports multi-frame streams and skippable frames.
///
/// ```
/// use std::io::Read;
///
/// let data = b"hello, streaming world!".repeat(100);
/// let compressed = zrip::compress(&data, 1).unwrap();
///
/// let mut decoder = zrip::FrameDecoder::new(&compressed[..]);
/// let mut output = Vec::new();
/// decoder.read_to_end(&mut output).unwrap();
/// assert_eq!(output, data);
/// ```
pub struct FrameDecoder<R: Read> {
    // Source of compressed bytes.
    inner: R,
    // Where in the frame/block structure the decoder currently is.
    state: State,
    // Scratch buffer holding frame-header bytes or a compressed block payload.
    read_buf: Vec<u8>,
    // Decoded bytes of the most recently processed block.
    output_buf: Vec<u8>,
    // Index of the next byte in `output_buf` to hand to the caller.
    output_pos: usize,
    // Reusable per-block workspace (literal buffer, Huffman table state).
    ws: Box<BlockDecodeWorkspace>,
    // Sequence decoding tables; re-initialised at every frame header.
    seq_tables: SequenceDecodeTables,
    // Repeat offsets; start at [1, 4, 8] or at the dictionary's values.
    rep_offsets: [u32; 3],
    // Running content hash; `Some` only while the frame carries a checksum.
    hasher: Option<Xxh64State>,
    // Whether the current frame header announced a content checksum.
    content_checksum: bool,
    // Upper bound on the total number of decoded bytes.
    max_output: usize,
    // Decoded bytes produced so far; only `reset` sets it back to zero.
    bytes_output: usize,
    // Optional dictionary applied to every frame.
    dict: Option<Dictionary>,
}
62
63impl<R: Read> FrameDecoder<R> {
64    /// Creates a decoder with [`DEFAULT_DECOMPRESS_LIMIT`](zrip_core::DEFAULT_DECOMPRESS_LIMIT).
65    pub fn new(reader: R) -> Self {
66        Self::with_limit(reader, zrip_core::DEFAULT_DECOMPRESS_LIMIT)
67    }
68
69    /// Creates a decoder with an explicit output size limit.
70    pub fn with_limit(reader: R, max_output: usize) -> Self {
71        Self {
72            inner: reader,
73            state: State::FrameHeader,
74            read_buf: Vec::new(),
75            output_buf: Vec::new(),
76            output_pos: 0,
77            ws: Box::new(BlockDecodeWorkspace::new()),
78            seq_tables: SequenceDecodeTables::new_default(),
79            rep_offsets: [1, 4, 8],
80            hasher: None,
81            content_checksum: false,
82            max_output,
83            bytes_output: 0,
84            dict: None,
85        }
86    }
87
88    /// Creates a decoder with a dictionary and default output limit.
89    pub fn with_dict(reader: R, dict: Dictionary) -> Self {
90        Self::with_dict_and_limit(reader, dict, zrip_core::DEFAULT_DECOMPRESS_LIMIT)
91    }
92
93    /// Creates a decoder with a dictionary and explicit output limit.
94    pub fn with_dict_and_limit(reader: R, dict: Dictionary, max_output: usize) -> Self {
95        Self {
96            inner: reader,
97            state: State::FrameHeader,
98            read_buf: Vec::new(),
99            output_buf: Vec::new(),
100            output_pos: 0,
101            ws: Box::new(BlockDecodeWorkspace::new()),
102            seq_tables: SequenceDecodeTables::new_default(),
103            rep_offsets: [1, 4, 8],
104            hasher: None,
105            content_checksum: false,
106            max_output,
107            bytes_output: 0,
108            dict: Some(dict),
109        }
110    }
111
112    /// Consumes the decoder and returns the underlying reader.
113    pub fn into_inner(self) -> R {
114        self.inner
115    }
116
117    /// Installs a new reader for the next frame, keeping all internal
118    /// buffers allocated. Returns the previous reader.
119    pub fn reset(&mut self, new_reader: R) -> R {
120        let old = core::mem::replace(&mut self.inner, new_reader);
121        self.state = State::FrameHeader;
122        self.output_buf.clear();
123        self.output_pos = 0;
124        self.rep_offsets = [1, 4, 8];
125        self.seq_tables = SequenceDecodeTables::new_default();
126        self.ws.huf_valid = false;
127        self.hasher = None;
128        self.content_checksum = false;
129        self.bytes_output = 0;
130        old
131    }
132
133    fn fill_output(&mut self) -> io::Result<()> {
134        loop {
135            match self.state {
136                State::Done => return Ok(()),
137                State::FrameHeader => self.read_frame_header()?,
138                State::BlockHeader => self.read_block_header()?,
139                State::BlockData {
140                    block_type,
141                    block_size,
142                    last,
143                } => {
144                    self.read_block_data(block_type, block_size, last)?;
145                    if self.output_pos < self.output_buf.len() {
146                        return Ok(());
147                    }
148                }
149                State::Checksum => self.read_checksum()?,
150            }
151        }
152    }
153
154    fn read_frame_header(&mut self) -> io::Result<()> {
155        self.read_buf.resize(18, 0);
156        self.inner.read_exact(&mut self.read_buf[..5])?;
157
158        let magic = u32::from_le_bytes([
159            self.read_buf[0],
160            self.read_buf[1],
161            self.read_buf[2],
162            self.read_buf[3],
163        ]);
164
165        if (magic & 0xFFFF_FFF0) == 0x184D_2A50 {
166            self.inner.read_exact(&mut self.read_buf[5..9])?;
167            let skip_size = u32::from_le_bytes([
168                self.read_buf[5],
169                self.read_buf[6],
170                self.read_buf[7],
171                self.read_buf[8],
172            ]) as usize;
173            io::copy(
174                &mut self.inner.by_ref().take(skip_size as u64),
175                &mut io::sink(),
176            )?;
177            return Ok(());
178        }
179
180        let descriptor = self.read_buf[4];
181        let single_segment = (descriptor & 0x20) != 0;
182        let dict_id_flag = descriptor & 0x03;
183        let fcs_flag = (descriptor >> 6) & 0x03;
184
185        let mut hdr_len = 5usize;
186        if !single_segment {
187            hdr_len += 1;
188        }
189        hdr_len += match dict_id_flag {
190            0 => 0,
191            1 => 1,
192            2 => 2,
193            3 => 4,
194            _ => unreachable!(),
195        };
196        hdr_len += match fcs_flag {
197            0 if single_segment => 1,
198            0 => 0,
199            1 => 2,
200            2 => 4,
201            3 => 8,
202            _ => unreachable!(),
203        };
204
205        if hdr_len > 5 {
206            self.inner.read_exact(&mut self.read_buf[5..hdr_len])?;
207        }
208
209        let header = parse_frame_header(&self.read_buf[..hdr_len])
210            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
211
212        if let Some(frame_dict_id) = header.dict_id {
213            match &self.dict {
214                Some(d) if d.id() == frame_dict_id => {}
215                Some(d) => {
216                    return Err(io::Error::new(
217                        io::ErrorKind::InvalidData,
218                        DecompressError::DictMismatch {
219                            expected: frame_dict_id,
220                            got: d.id(),
221                        },
222                    ));
223                }
224                None => {
225                    return Err(io::Error::new(
226                        io::ErrorKind::InvalidData,
227                        DecompressError::DictRequired,
228                    ));
229                }
230            }
231        }
232
233        if let Some(fcs) = header.frame_content_size {
234            if fcs as usize > self.max_output {
235                return Err(io::Error::new(
236                    io::ErrorKind::InvalidData,
237                    DecompressError::OutputTooSmall,
238                ));
239            }
240        }
241
242        self.content_checksum = header.content_checksum;
243        self.hasher = if header.content_checksum {
244            Some(Xxh64State::new(0))
245        } else {
246            None
247        };
248
249        if let Some(ref d) = self.dict {
250            self.rep_offsets = *d.rep_offsets();
251            let mut st = SequenceDecodeTables::new_default();
252            if let Some((t, l)) = d.of_table() {
253                st.of_table = promote_of_table(t);
254                st.of_accuracy = l;
255            }
256            if let Some((t, l)) = d.ml_table() {
257                st.ml_table = promote_ml_table(t);
258                st.ml_accuracy = l;
259            }
260            if let Some((t, l)) = d.ll_table() {
261                st.ll_table = promote_ll_table(t);
262                st.ll_accuracy = l;
263            }
264            self.seq_tables = st;
265            self.ws.huf_valid = false;
266            if let Some((t, l)) = d.huf_table() {
267                self.ws.huf_table.clear();
268                self.ws.huf_table.extend_from_slice(t);
269                self.ws.huf_table_log = l;
270                self.ws.huf_valid = true;
271            }
272        } else {
273            self.rep_offsets = [1, 4, 8];
274            self.seq_tables = SequenceDecodeTables::new_default();
275            self.ws.huf_valid = false;
276        }
277
278        self.state = State::BlockHeader;
279        Ok(())
280    }
281
282    fn read_block_header(&mut self) -> io::Result<()> {
283        let mut hdr = [0u8; 3];
284        self.inner.read_exact(&mut hdr)?;
285        let block_header =
286            parse_block_header(&hdr).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
287
288        let block_size = block_header.block_size as usize;
289
290        match block_header.block_type {
291            BlockType::Raw | BlockType::Rle if block_size > MAX_BLOCK_SIZE => {
292                return Err(io::Error::new(
293                    io::ErrorKind::InvalidData,
294                    DecompressError::BlockTooLarge,
295                ));
296            }
297            _ => {}
298        }
299
300        self.state = State::BlockData {
301            block_type: block_header.block_type,
302            block_size,
303            last: block_header.last_block,
304        };
305        Ok(())
306    }
307
308    fn read_block_data(
309        &mut self,
310        block_type: BlockType,
311        block_size: usize,
312        last: bool,
313    ) -> io::Result<()> {
314        self.output_buf.clear();
315        self.output_pos = 0;
316
317        match block_type {
318            BlockType::Raw => {
319                self.output_buf.resize(block_size, 0);
320                self.inner.read_exact(&mut self.output_buf)?;
321            }
322            BlockType::Rle => {
323                let mut byte = [0u8; 1];
324                self.inner.read_exact(&mut byte)?;
325                self.output_buf.resize(block_size, byte[0]);
326            }
327            BlockType::Compressed => {
328                self.read_buf.resize(block_size, 0);
329                self.inner.read_exact(&mut self.read_buf[..block_size])?;
330                self.decode_compressed_block(block_size)?;
331            }
332        }
333
334        if let Some(ref mut hasher) = self.hasher {
335            hasher.update(&self.output_buf);
336        }
337        self.bytes_output += self.output_buf.len();
338        if self.bytes_output > self.max_output {
339            return Err(io::Error::new(
340                io::ErrorKind::InvalidData,
341                DecompressError::OutputTooSmall,
342            ));
343        }
344
345        self.state = if last {
346            if self.content_checksum {
347                State::Checksum
348            } else {
349                State::FrameHeader
350            }
351        } else {
352            State::BlockHeader
353        };
354
355        Ok(())
356    }
357
358    fn decode_compressed_block(&mut self, block_size: usize) -> io::Result<()> {
359        let dict_history: &[u8] = match &self.dict {
360            Some(d) => d.content(),
361            None => &[],
362        };
363        let block_data = &self.read_buf[..block_size];
364
365        let lit_consumed = decode_literals_ws(block_data, &mut self.ws)
366            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
367
368        let remaining = &block_data[lit_consumed..];
369
370        if remaining.is_empty() {
371            self.output_buf.extend_from_slice(&self.ws.literal_buf);
372            return Ok(());
373        }
374
375        let (num_sequences, seq_count_size) = parse_sequence_count(remaining)
376            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
377
378        if num_sequences == 0 {
379            self.output_buf.extend_from_slice(&self.ws.literal_buf);
380            return Ok(());
381        }
382
383        let table_data = &remaining[seq_count_size..];
384        let tables_consumed =
385            parse_sequence_tables_ws(table_data, &mut self.seq_tables, &mut self.ws)
386                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
387
388        let seq_data = &table_data[tables_consumed..];
389
390        #[cfg(target_arch = "x86_64")]
391        {
392            if zrip_core::simd::cpu_tier() >= CpuTier::Avx2 {
393                let before = self.output_buf.len();
394                crate::simd_decode::x86_64::decode::decode_execute_avx2_safe(
395                    seq_data,
396                    num_sequences,
397                    &self.seq_tables,
398                    &mut self.rep_offsets,
399                    &self.ws.literal_buf,
400                    &mut self.output_buf,
401                    dict_history,
402                )
403                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
404                if self.output_buf.len() - before > MAX_BLOCK_SIZE {
405                    return Err(io::Error::new(
406                        io::ErrorKind::InvalidData,
407                        DecompressError::BlockTooLarge,
408                    ));
409                }
410                return Ok(());
411            }
412        }
413
414        #[cfg(target_arch = "aarch64")]
415        {
416            if zrip_core::simd::cpu_tier() >= CpuTier::Neon {
417                let before = self.output_buf.len();
418                crate::simd_decode::aarch64::decode::decode_execute_neon_safe(
419                    seq_data,
420                    num_sequences,
421                    &self.seq_tables,
422                    &mut self.rep_offsets,
423                    &self.ws.literal_buf,
424                    &mut self.output_buf,
425                    dict_history,
426                )
427                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
428                if self.output_buf.len() - before > MAX_BLOCK_SIZE {
429                    return Err(io::Error::new(
430                        io::ErrorKind::InvalidData,
431                        DecompressError::BlockTooLarge,
432                    ));
433                }
434                return Ok(());
435            }
436        }
437
438        let before = self.output_buf.len();
439        crate::exec::decode_execute_sequences(
440            seq_data,
441            num_sequences,
442            &self.seq_tables,
443            &mut self.rep_offsets,
444            &self.ws.literal_buf,
445            &mut self.output_buf,
446            dict_history,
447        )
448        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
449        if self.output_buf.len() - before > MAX_BLOCK_SIZE {
450            return Err(io::Error::new(
451                io::ErrorKind::InvalidData,
452                DecompressError::BlockTooLarge,
453            ));
454        }
455        Ok(())
456    }
457
458    fn read_checksum(&mut self) -> io::Result<()> {
459        let mut buf = [0u8; 4];
460        self.inner.read_exact(&mut buf)?;
461        let stored = u32::from_le_bytes(buf);
462
463        if let Some(ref hasher) = self.hasher {
464            let hash = hasher.finish();
465            let expected = (hash & 0xFFFF_FFFF) as u32;
466            if expected != stored {
467                return Err(io::Error::new(
468                    io::ErrorKind::InvalidData,
469                    DecompressError::ChecksumMismatch {
470                        expected: stored,
471                        got: expected,
472                    },
473                ));
474            }
475        }
476
477        self.state = State::FrameHeader;
478        Ok(())
479    }
480}
481
482impl<R: Read> Read for FrameDecoder<R> {
483    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
484        if self.output_pos >= self.output_buf.len() {
485            if let State::Done = &self.state {
486                return Ok(0);
487            }
488
489            self.output_buf.clear();
490            self.output_pos = 0;
491
492            match self.fill_output() {
493                Ok(()) => {}
494                Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => match &self.state {
495                    State::FrameHeader => {
496                        self.state = State::Done;
497                        return Ok(0);
498                    }
499                    _ => return Err(e),
500                },
501                Err(e) => return Err(e),
502            }
503        }
504
505        let available = &self.output_buf[self.output_pos..];
506        let n = buf.len().min(available.len());
507        buf[..n].copy_from_slice(&available[..n]);
508        self.output_pos += n;
509        Ok(n)
510    }
511}