Skip to main content

compcol/lzfse/
decoder.rs

//! Streaming LZFSE decoder.
//!
//! Buffers input until a whole block (magic + header + payload) can be
//! decoded, then drains the decoded payload into the caller's output slice
//! across as many `decode` calls as the caller needs.

7use alloc::vec::Vec;
8
9use crate::error::Error;
10use crate::lzfse::{lzfse_v2, lzvn};
11use crate::traits::{RawDecoder, RawProgress};
12
// Every LZFSE block begins with one of these 4-byte magics.

/// `bvx-`: uncompressed block (4-byte LE raw length, then the raw bytes).
const MAGIC_UNCOMPRESSED: [u8; 4] = *b"bvx-";
/// `bvxn`: LZVN block (LE raw length and LE payload length, then the payload).
const MAGIC_LZVN: [u8; 4] = *b"bvxn";
/// `bvx1`: LZFSE v1 block; this decoder rejects it as unsupported.
const MAGIC_V1: [u8; 4] = *b"bvx1";
/// `bvx2`: LZFSE v2 block (variable-length header followed by the payload).
const MAGIC_V2: [u8; 4] = *b"bvx2";
/// `bvx$`: end-of-stream marker.
const MAGIC_EOS: [u8; 4] = *b"bvx$";

20/// Streaming decoder state machine.
21pub struct Decoder {
22    /// Bytes the caller has fed us that we haven't yet consumed.
23    input_buf: Vec<u8>,
24    /// Decoded bytes pending delivery to the caller.
25    output_buf: Vec<u8>,
26    /// Read cursor into `output_buf`. We keep the buffer around so we don't
27    /// have to shift bytes on every partial drain; once `output_pos ==
28    /// output_buf.len()`, we clear both.
29    output_pos: usize,
30    /// State.
31    state: State,
32    /// Once we hit the end-of-stream marker (or have signalled it once), we
33    /// short-circuit further calls.
34    eos: bool,
35    /// Set on any decode error so callers don't accidentally resume.
36    poisoned: bool,
37}
38
39#[derive(Debug, Clone, Copy, PartialEq, Eq)]
40enum State {
41    /// Waiting for the next 4-byte block magic.
42    AwaitMagic,
43    /// Read magic; waiting for the block-specific header bytes.
44    AwaitHeader(BlockKind),
45    /// Header parsed; waiting for the rest of the payload, then decode + drain.
46    AwaitPayload {
47        kind: BlockKind,
48        /// For uncompressed blocks: bytes to copy. For LZVN: compressed bytes
49        /// to decode.
50        payload_len: usize,
51        /// For LZVN: expected decoded size from the header.
52        decoded_size: usize,
53    },
54    /// Stream is finished.
55    Done,
56}
57
58#[derive(Debug, Clone, Copy, PartialEq, Eq)]
59enum BlockKind {
60    Uncompressed,
61    Lzvn,
62    /// `bvx2` (LZFSE v2): FSE + LZ77. Decoded by [`lzfse_v2::decode_block`]
63    /// once the whole block (variable-length header + both payload streams)
64    /// is buffered.
65    V2,
66    /// `bvx1` (LZFSE v1, uncompressed-freq variant): not emitted by modern
67    /// encoders; returns [`Error::Unsupported`].
68    V1,
69}
70
71impl Decoder {
72    pub fn new() -> Self {
73        Self {
74            input_buf: Vec::new(),
75            output_buf: Vec::new(),
76            output_pos: 0,
77            state: State::AwaitMagic,
78            eos: false,
79            poisoned: false,
80        }
81    }
82
83    fn raw_decode_inner(&mut self, input: &[u8], output: &mut [u8]) -> Result<RawProgress, Error> {
84        if self.poisoned {
85            return Err(Error::Corrupt);
86        }
87        let mut consumed = 0usize;
88        let mut written = 0usize;
89
90        loop {
91            // 1. Drain any pending decoded output first.
92            if self.output_pos < self.output_buf.len() {
93                let want = (self.output_buf.len() - self.output_pos).min(output.len() - written);
94                output[written..written + want]
95                    .copy_from_slice(&self.output_buf[self.output_pos..self.output_pos + want]);
96                self.output_pos += want;
97                written += want;
98                if self.output_pos == self.output_buf.len() {
99                    // Fully drained; reset.
100                    self.output_buf.clear();
101                    self.output_pos = 0;
102                }
103                if written == output.len() {
104                    return Ok(RawProgress {
105                        consumed,
106                        written,
107                        done: false,
108                    });
109                }
110                // If we just drained and output still has room, loop to
111                // try to make more progress.
112            }
113
114            // 2. If we've already hit end-of-stream, signal done.
115            if self.eos {
116                return Ok(RawProgress {
117                    consumed,
118                    written,
119                    done: true,
120                });
121            }
122
123            // 3. Pull from caller's `input` into `input_buf`. We pull lazily:
124            //    only as much as the current state needs.
125            if consumed < input.len() {
126                self.input_buf.extend_from_slice(&input[consumed..]);
127                consumed = input.len();
128            }
129
130            // 4. Advance the state machine.
131            match self.state {
132                State::AwaitMagic => {
133                    if self.input_buf.len() < 4 {
134                        return Ok(RawProgress {
135                            consumed,
136                            written,
137                            done: false,
138                        });
139                    }
140                    let mut magic = [0u8; 4];
141                    magic.copy_from_slice(&self.input_buf[..4]);
142                    // Drop the magic.
143                    self.input_buf.drain(..4);
144                    match magic {
145                        MAGIC_EOS => {
146                            self.state = State::Done;
147                            self.eos = true;
148                            // loop to emit done on next iteration
149                        }
150                        MAGIC_UNCOMPRESSED => {
151                            self.state = State::AwaitHeader(BlockKind::Uncompressed);
152                        }
153                        MAGIC_LZVN => {
154                            self.state = State::AwaitHeader(BlockKind::Lzvn);
155                        }
156                        MAGIC_V1 => {
157                            self.state = State::AwaitHeader(BlockKind::V1);
158                        }
159                        MAGIC_V2 => {
160                            self.state = State::AwaitHeader(BlockKind::V2);
161                        }
162                        _ => {
163                            self.poisoned = true;
164                            return Err(Error::BadHeader);
165                        }
166                    }
167                }
168
169                State::AwaitHeader(kind) => match kind {
170                    BlockKind::Uncompressed => {
171                        // 4-byte LE n_raw_bytes.
172                        if self.input_buf.len() < 4 {
173                            return Ok(RawProgress {
174                                consumed,
175                                written,
176                                done: false,
177                            });
178                        }
179                        let n_raw = u32::from_le_bytes([
180                            self.input_buf[0],
181                            self.input_buf[1],
182                            self.input_buf[2],
183                            self.input_buf[3],
184                        ]) as usize;
185                        self.input_buf.drain(..4);
186                        self.state = State::AwaitPayload {
187                            kind: BlockKind::Uncompressed,
188                            payload_len: n_raw,
189                            decoded_size: n_raw,
190                        };
191                    }
192                    BlockKind::Lzvn => {
193                        // 8-byte header: n_raw_bytes (u32 LE) + n_payload_bytes (u32 LE).
194                        if self.input_buf.len() < 8 {
195                            return Ok(RawProgress {
196                                consumed,
197                                written,
198                                done: false,
199                            });
200                        }
201                        let n_raw = u32::from_le_bytes([
202                            self.input_buf[0],
203                            self.input_buf[1],
204                            self.input_buf[2],
205                            self.input_buf[3],
206                        ]) as usize;
207                        let n_payload = u32::from_le_bytes([
208                            self.input_buf[4],
209                            self.input_buf[5],
210                            self.input_buf[6],
211                            self.input_buf[7],
212                        ]) as usize;
213                        self.input_buf.drain(..8);
214                        self.state = State::AwaitPayload {
215                            kind: BlockKind::Lzvn,
216                            payload_len: n_payload,
217                            decoded_size: n_raw,
218                        };
219                    }
220                    BlockKind::V2 => {
221                        // The v2 header is variable-length (FSE frequency
222                        // tables follow the fixed packed fields). Buffer the
223                        // fixed 28 bytes (post-magic: n_raw + three u64 words)
224                        // first so we can read `header_size` and the payload
225                        // sizes, then arrange to buffer the whole block (header
226                        // + payload) before decoding it in one shot.
227                        let fixed = lzfse_v2::V2_HEADER_FIXED_BYTES;
228                        if self.input_buf.len() < fixed {
229                            return Ok(RawProgress {
230                                consumed,
231                                written,
232                                done: false,
233                            });
234                        }
235                        let header_size = match lzfse_v2::parse_header_size(&self.input_buf) {
236                            Ok(h) => h as usize,
237                            Err(e) => {
238                                self.poisoned = true;
239                                return Err(e);
240                            }
241                        };
242                        let n_payload = match lzfse_v2::parse_payload_size(&self.input_buf) {
243                            Ok(n) => n as usize,
244                            Err(e) => {
245                                self.poisoned = true;
246                                return Err(e);
247                            }
248                        };
249                        // `header_size` includes the 4-byte magic we already
250                        // dropped; remaining block bytes after the magic are
251                        // `header_size - 4 + n_payload`.
252                        let header_len = match header_size.checked_sub(4) {
253                            Some(h) if h >= fixed => h,
254                            _ => {
255                                self.poisoned = true;
256                                return Err(Error::Corrupt);
257                            }
258                        };
259                        let block_len = match header_len.checked_add(n_payload) {
260                            Some(b) => b,
261                            None => {
262                                self.poisoned = true;
263                                return Err(Error::Corrupt);
264                            }
265                        };
266                        self.state = State::AwaitPayload {
267                            kind: BlockKind::V2,
268                            payload_len: block_len,
269                            decoded_size: 0,
270                        };
271                    }
272                    BlockKind::V1 => {
273                        self.poisoned = true;
274                        return Err(Error::Unsupported);
275                    }
276                },
277
278                State::AwaitPayload {
279                    kind,
280                    payload_len,
281                    decoded_size,
282                } => {
283                    if self.input_buf.len() < payload_len {
284                        return Ok(RawProgress {
285                            consumed,
286                            written,
287                            done: false,
288                        });
289                    }
290                    match kind {
291                        BlockKind::Uncompressed => {
292                            // Copy payload_len bytes into output_buf for drain.
293                            self.output_buf
294                                .extend_from_slice(&self.input_buf[..payload_len]);
295                            self.input_buf.drain(..payload_len);
296                            self.state = State::AwaitMagic;
297                        }
298                        BlockKind::Lzvn => {
299                            // Decode in one shot into output_buf.
300                            //
301                            // Bound the capacity hint by what the payload could
302                            // plausibly produce so an attacker-controlled
303                            // `decoded_size` (n_raw_bytes) cannot force a huge
304                            // up-front allocation (DoS / OOM): a single 1-byte
305                            // LZVN opcode expands to at most ~16 output bytes.
306                            // `decode_block` still enforces the real output size
307                            // against `decoded_size`, so under-hinting only makes
308                            // the Vec grow as actual bytes are produced.
309                            let capacity_hint =
310                                decoded_size.min(payload_len.saturating_mul(16).saturating_add(64));
311                            let mut block_out = Vec::with_capacity(capacity_hint);
312                            if let Err(e) = lzvn::decode_block(
313                                &self.input_buf[..payload_len],
314                                payload_len,
315                                decoded_size,
316                                &mut block_out,
317                            ) {
318                                self.poisoned = true;
319                                return Err(e);
320                            }
321                            self.output_buf.append(&mut block_out);
322                            self.input_buf.drain(..payload_len);
323                            self.state = State::AwaitMagic;
324                        }
325                        BlockKind::V2 => {
326                            // The whole block (header + both payload streams)
327                            // is now buffered in `payload_len` bytes. Decode in
328                            // one shot. Bound the up-front output reservation by
329                            // a payload-derived hint (an FSE block can expand
330                            // more than LZVN, but is still bounded; the decoder
331                            // enforces the exact `n_raw_bytes` internally).
332                            let cap_hint = payload_len.saturating_mul(32).saturating_add(1 << 16);
333                            let mut block_out = Vec::new();
334                            match lzfse_v2::decode_block(
335                                &self.input_buf[..payload_len],
336                                &mut block_out,
337                                cap_hint,
338                            ) {
339                                Ok(consumed_block) => {
340                                    debug_assert_eq!(consumed_block, payload_len);
341                                }
342                                Err(e) => {
343                                    self.poisoned = true;
344                                    return Err(e);
345                                }
346                            }
347                            self.output_buf.append(&mut block_out);
348                            self.input_buf.drain(..payload_len);
349                            self.state = State::AwaitMagic;
350                        }
351                        BlockKind::V1 => {
352                            // Unreachable — header step would have errored.
353                            self.poisoned = true;
354                            return Err(Error::Unsupported);
355                        }
356                    }
357                }
358
359                State::Done => {
360                    self.eos = true;
361                    return Ok(RawProgress {
362                        consumed,
363                        written,
364                        done: true,
365                    });
366                }
367            }
368        }
369    }
370}
371
372impl Default for Decoder {
373    fn default() -> Self {
374        Self::new()
375    }
376}
377
378impl RawDecoder for Decoder {
379    fn raw_decode(&mut self, input: &[u8], output: &mut [u8]) -> Result<RawProgress, Error> {
380        self.raw_decode_inner(input, output)
381    }
382
383    fn raw_finish(&mut self, output: &mut [u8]) -> Result<RawProgress, Error> {
384        // finish drains any pending output and, if the stream has reached
385        // the end-of-stream marker, returns `done`. Otherwise we surface
386        // an UnexpectedEnd to signal truncation.
387        if self.poisoned {
388            return Err(Error::Corrupt);
389        }
390        let p = self.raw_decode_inner(&[], output)?;
391        if p.done {
392            return Ok(p);
393        }
394        // No more input is coming. If we haven't seen the EOS marker but
395        // we have nothing buffered and nothing pending, treat as
396        // unexpected-end. If we still have decoded bytes to drain, signal
397        // OutputFull-style (done=false, written>0).
398        if p.written > 0 || !self.output_buf.is_empty() {
399            return Ok(p);
400        }
401        if self.state == State::AwaitMagic && self.input_buf.is_empty() {
402            // No partial block in flight. Empty input followed by finish on
403            // a fresh decoder is fine — return StreamEnd.
404            self.eos = true;
405            return Ok(RawProgress {
406                consumed: 0,
407                written: 0,
408                done: true,
409            });
410        }
411        // Mid-block at EOI — truncated.
412        self.poisoned = true;
413        Err(Error::UnexpectedEnd)
414    }
415
416    fn raw_reset(&mut self) {
417        self.input_buf.clear();
418        self.output_buf.clear();
419        self.output_pos = 0;
420        self.state = State::AwaitMagic;
421        self.eos = false;
422        self.poisoned = false;
423    }
424}