Skip to main content

heatshrink/
decoder.rs

1use super::CodecError;
2use super::Finish;
3use super::Poll;
4use super::PollError;
5use super::SinkError;
6
/// States of the decoder's bit-level state machine, advanced by `poll`.
#[derive(Clone, Copy, Debug, PartialEq)]
enum HSDstate {
    /// Read the 1-bit tag: `0` starts a back-reference, anything else a literal.
    TagBit,
    /// Read an 8-bit literal and emit it.
    YieldLiteral,
    /// Read the upper `W - 8` bits of the back-reference index (only used when `W > 8`).
    BackrefIndexMsb,
    /// Read the low `min(W, 8)` bits of the back-reference index.
    BackrefIndexLsb,
    /// Read the `L`-bit back-reference length.
    BackrefCountLsb,
    /// Copy the referenced bytes out of the window.
    YieldBackref,
}
16
17/// Heatshrink decoder.
18///
19/// # Type parameters
20///
21/// - `W`   : base-2 log of the LZSS sliding window size (must match the encoder).
22/// - `L`   : number of bits for back-reference lengths (must match the encoder).
23/// - `I`   : streaming input buffer size in bytes (>= 1, tunable for RAM/throughput).
24/// - `WIN` : output (window) buffer size in bytes; **must equal `1 << W`**.
25///   (Redundant parameter required because Rust stable does not yet
26///   support arithmetic on const generics in array sizes.)
27///
28/// Use [`DefaultDecoder`](super::DefaultDecoder) or the dispatch helpers in
29/// `heatshrink-bin` rather than setting `WIN` manually.
30///
31/// # Panics
32///
33/// [`new()`](HeatshrinkDecoder::new) panics if: `W < 4`, `L < 3`, `L >= W`,
34/// `W > 14`, `I < 1`, or `WIN != 1 << W`.
35#[derive(Debug)]
36pub struct HeatshrinkDecoder<const W: usize, const L: usize, const I: usize, const WIN: usize> {
37    input_size: usize,
38    input_index: usize,
39    output_index: usize,
40    head_index: usize,
41    output_count: u16,
42    current_byte: u8,
43    bit_index: u8,
44    state: HSDstate,
45    input_buffer: [u8; I],
46    output_buffer: [u8; WIN],
47}
48
49/// Decompress `src` into `dst` using the default parameters (W=8, L=4, I=32).
50pub fn decode<'a>(src: &[u8], dst: &'a mut [u8]) -> Result<&'a [u8], CodecError> {
51    let mut dec = super::DefaultDecoder::new();
52    run_decode(&mut dec, src, dst)
53}
54
55/// Internal decode loop, generic over all decoder configurations.
56pub(crate) fn run_decode<'a, const W: usize, const L: usize, const I: usize, const WIN: usize>(
57    dec: &mut HeatshrinkDecoder<W, L, I, WIN>,
58    src: &[u8],
59    dst: &'a mut [u8],
60) -> Result<&'a [u8], CodecError> {
61    let mut total_input_size = 0;
62    let mut total_output_size = 0;
63
64    while total_input_size < src.len() {
65        match dec.sink(&src[total_input_size..]) {
66            Ok(n) => total_input_size += n,
67            Err(SinkError::Full) => {}
68            Err(SinkError::Misuse) => return Err(CodecError::Internal),
69        }
70
71        if total_output_size == dst.len() {
72            return Err(CodecError::OutputFull);
73        }
74
75        match dec.poll(&mut dst[total_output_size..]) {
76            Ok(Poll::More(_)) => return Err(CodecError::OutputFull),
77            Ok(Poll::Empty(n)) => total_output_size += n,
78            Err(_) => return Err(CodecError::Internal),
79        }
80
81        if total_input_size == src.len() {
82            match dec.finish() {
83                Finish::Done => {}
84                Finish::More => return Err(CodecError::OutputFull),
85            }
86        }
87    }
88
89    Ok(&dst[..total_output_size])
90}
91
92impl<const W: usize, const L: usize, const I: usize, const WIN: usize> Default
93    for HeatshrinkDecoder<W, L, I, WIN>
94{
95    fn default() -> Self {
96        Self::new()
97    }
98}
99
100impl<const W: usize, const L: usize, const I: usize, const WIN: usize>
101    HeatshrinkDecoder<W, L, I, WIN>
102{
103    /// Create a new decoder instance.
104    ///
105    /// # Panics
106    ///
107    /// Panics if `W < 4`, `L < 3`, `L >= W`, `W > 15`, `I < 1`,
108    /// or `WIN != 1 << W`.
109    pub fn new() -> Self {
110        assert!(W >= 4, "W must be >= 4");
111        assert!(L >= 3, "L must be >= 3");
112        assert!(L < W, "L must be < W");
113        assert!(W <= 15, "W must be <= 15 (search_index uses Option<u16>)");
114        assert!(I >= 1, "I must be >= 1");
115        assert!(WIN == 1 << W, "WIN must equal 1 << W");
116
117        HeatshrinkDecoder {
118            input_size: 0,
119            input_index: 0,
120            output_count: 0,
121            output_index: 0,
122            head_index: 0,
123            current_byte: 0,
124            bit_index: 0,
125            state: HSDstate::TagBit,
126            input_buffer: [0; I],
127            output_buffer: [0; WIN],
128        }
129    }
130
131    /// Reset the decoder to its initial state so it can be reused.
132    pub fn reset(&mut self) {
133        *self = Self::new();
134    }
135
136    /// Feed compressed data into the decoder.
137    pub fn sink(&mut self, input_buffer: &[u8]) -> Result<usize, SinkError> {
138        // Compact: slide unconsumed bytes to the front so the whole buffer
139        // is available for new data.  input_index bytes at the start have
140        // already been consumed by get_bits and can be overwritten.
141        // We must preserve current_byte / bit_index which hold up to 8 bits
142        // that were pre-loaded from the last consumed byte — those bits are
143        // NOT in input_buffer[input_index..] and must not be disturbed.
144        let unconsumed = self.input_size - self.input_index;
145        if self.input_index > 0 && unconsumed > 0 {
146            self.input_buffer
147                .copy_within(self.input_index..self.input_size, 0);
148        }
149        self.input_size = unconsumed;
150        self.input_index = 0;
151
152        let remaining_size = self.input_buffer.len() - self.input_size;
153        if remaining_size == 0 {
154            return Err(SinkError::Full);
155        }
156
157        let copy_size = remaining_size.min(input_buffer.len());
158        self.input_buffer[self.input_size..self.input_size + copy_size]
159            .copy_from_slice(&input_buffer[..copy_size]);
160        self.input_size += copy_size;
161
162        if self.bit_index == 0 {
163            self.current_byte = self.input_buffer[self.input_index];
164            self.input_index += 1;
165            self.bit_index = 8;
166        }
167
168        Ok(copy_size)
169    }
170
171    /// Pull decompressed output out of the decoder into `output_buffer`.
172    pub fn poll(&mut self, output_buffer: &mut [u8]) -> Result<Poll, PollError> {
173        if output_buffer.is_empty() {
174            return Err(PollError::Misuse);
175        }
176
177        let mut out_pos: usize = 0;
178
179        loop {
180            let previous_state = self.state;
181
182            match previous_state {
183                HSDstate::TagBit => {
184                    self.state = self.st_tag_bit();
185                }
186                HSDstate::YieldLiteral => {
187                    self.state = self.st_yield_literal(output_buffer, &mut out_pos);
188                }
189                HSDstate::BackrefIndexMsb => {
190                    self.state = self.st_backref_index_msb();
191                }
192                HSDstate::BackrefIndexLsb => {
193                    self.state = self.st_backref_index_lsb();
194                }
195                HSDstate::BackrefCountLsb => {
196                    self.state = self.st_backref_count_lsb();
197                }
198                HSDstate::YieldBackref => {
199                    self.state = self.st_yield_backref(output_buffer, &mut out_pos);
200                }
201            }
202
203            if self.state == previous_state {
204                return if out_pos < output_buffer.len() {
205                    Ok(Poll::Empty(out_pos))
206                } else {
207                    Ok(Poll::More(out_pos))
208                };
209            }
210        }
211    }
212
213    /// Signal end of input.
214    pub fn finish(&self) -> Finish {
215        if self.input_size == 0 {
216            Finish::Done
217        } else {
218            Finish::More
219        }
220    }
221
222    // ---- State machine helpers ----
223
224    #[inline]
225    fn st_tag_bit(&mut self) -> HSDstate {
226        match self.get_bits(1) {
227            None => HSDstate::TagBit,
228            Some(0) => {
229                self.output_index = 0;
230                if W > 8 {
231                    HSDstate::BackrefIndexMsb
232                } else {
233                    HSDstate::BackrefIndexLsb
234                }
235            }
236            Some(_) => HSDstate::YieldLiteral,
237        }
238    }
239
240    #[inline]
241    fn st_yield_literal(&mut self, out: &mut [u8], pos: &mut usize) -> HSDstate {
242        if *pos < out.len() {
243            match self.get_bits(8) {
244                None => HSDstate::YieldLiteral,
245                Some(c) => {
246                    let c = c as u8;
247                    self.output_buffer[self.head_index % WIN] = c;
248                    self.head_index += 1;
249                    out[*pos] = c;
250                    *pos += 1;
251                    HSDstate::TagBit
252                }
253            }
254        } else {
255            HSDstate::YieldLiteral
256        }
257    }
258
259    /// Only reached when W > 8: reads the (W - 8) most-significant index bits.
260    #[inline]
261    fn st_backref_index_msb(&mut self) -> HSDstate {
262        match self.get_bits((W - 8) as u8) {
263            None => HSDstate::BackrefIndexMsb,
264            Some(x) => {
265                self.output_index = (x as usize) << 8;
266                HSDstate::BackrefIndexLsb
267            }
268        }
269    }
270
271    #[inline]
272    fn st_backref_index_lsb(&mut self) -> HSDstate {
273        // When W <= 8 we arrive here directly (no MSB state) and the encoder
274        // wrote exactly W bits for the index.  When W > 8 we arrive after
275        // st_backref_index_msb which already consumed (W-8) bits, so we need
276        // the remaining 8 bits.  In both cases the right count is min(W, 8).
277        let lsb_bits = W.min(8) as u8;
278        match self.get_bits(lsb_bits) {
279            None => HSDstate::BackrefIndexLsb,
280            Some(x) => {
281                self.output_index |= x as usize;
282                self.output_index += 1;
283                self.output_count = 0;
284                HSDstate::BackrefCountLsb
285            }
286        }
287    }
288
289    #[inline]
290    fn st_backref_count_lsb(&mut self) -> HSDstate {
291        match self.get_bits(L as u8) {
292            None => HSDstate::BackrefCountLsb,
293            Some(x) => {
294                self.output_count = x + 1;
295                HSDstate::YieldBackref
296            }
297        }
298    }
299
300    #[inline]
301    fn st_yield_backref(&mut self, out: &mut [u8], pos: &mut usize) -> HSDstate {
302        if *pos == out.len() {
303            return HSDstate::YieldBackref;
304        }
305
306        let output_index = self.output_index;
307        let count = (out.len() - *pos).min(self.output_count as usize);
308
309        // Prologue: back-reference points before the start of the stream
310        // (output_index > head_index).  Emit zeroes — rare, only at the very
311        // beginning of decoding.
312        if output_index > self.head_index {
313            let zero_count = count.min(output_index - self.head_index);
314            let limit = self.head_index + zero_count;
315            while self.head_index < limit {
316                out[*pos] = 0;
317                *pos += 1;
318                self.output_buffer[self.head_index & (WIN - 1)] = 0;
319                self.head_index += 1;
320            }
321            self.output_count -= zero_count as u16;
322            if self.output_count == 0 {
323                return HSDstate::TagBit;
324            }
325            if *pos == out.len() {
326                return HSDstate::YieldBackref;
327            }
328        }
329
330        // How many bytes remain to emit in this call.
331        let count = (out.len() - *pos).min(self.output_count as usize);
332
333        if output_index >= count {
334            // ── Fast path: no self-referential overlap ────────────────────────
335            // The source window [src_start .. src_start+count] does not overlap
336            // with the destination window [dst_start .. dst_start+count] in the
337            // circular buffer (they are at least `count` bytes apart).
338            // We copy in at most two contiguous chunks to handle wrap-around,
339            // then bulk-copy the same data into the caller's output buffer.
340            let src_start = (self.head_index - output_index) & (WIN - 1);
341            let dst_start = self.head_index & (WIN - 1);
342
343            // --- copy into the circular output_buffer (1 or 2 chunks) ---
344            let src_end = src_start + count;
345            let dst_end = dst_start + count;
346
347            if src_end <= WIN && dst_end <= WIN {
348                // Neither source nor destination wraps: single copy.
349                self.output_buffer
350                    .copy_within(src_start..src_start + count, dst_start);
351            } else {
352                // At least one side wraps: copy byte by byte into the ring but
353                // still avoid the byte-by-byte loop below by handling
354                // the ring copy first, then using a slice copy for the caller.
355                let limit = self.head_index + count;
356                let mut h = self.head_index;
357                while h < limit {
358                    let s = (h - output_index) & (WIN - 1);
359                    let d = h & (WIN - 1);
360                    self.output_buffer[d] = self.output_buffer[s];
361                    h += 1;
362                }
363            }
364
365            // --- bulk copy from the (now updated) ring to the caller's buffer --
366            // dst_start is where we just wrote `count` bytes (possibly wrapping).
367            // Re-read from the ring to the output slice in 1 or 2 chunks.
368            if dst_end <= WIN {
369                out[*pos..*pos + count]
370                    .copy_from_slice(&self.output_buffer[dst_start..dst_start + count]);
371            } else {
372                let first = WIN - dst_start;
373                let second = count - first;
374                out[*pos..*pos + first].copy_from_slice(&self.output_buffer[dst_start..WIN]);
375                out[*pos + first..*pos + count].copy_from_slice(&self.output_buffer[..second]);
376            }
377            *pos += count;
378            self.head_index += count;
379        } else {
380            // ── Slow path: self-referential match (output_index < count) ──────
381            // Each byte written may immediately become the source of the next
382            // read (e.g. run-length: output_index=1, count=50 → "aaaa…").
383            // Must stay byte-by-byte.
384            let limit = self.head_index + count;
385            while self.head_index < limit {
386                let c = self.output_buffer[(self.head_index - output_index) & (WIN - 1)];
387                out[*pos] = c;
388                *pos += 1;
389                self.output_buffer[self.head_index & (WIN - 1)] = c;
390                self.head_index += 1;
391            }
392        }
393
394        self.output_count -= count as u16;
395        if self.output_count == 0 {
396            HSDstate::TagBit
397        } else {
398            HSDstate::YieldBackref
399        }
400    }
401
402    /// Read the next `count` bits from the input buffer.
403    ///
404    /// Supports up to 15 bits (the maximum needed is `L` for back-reference
405    /// lengths, and `L < W <= 15`).  Returns `None` if not enough bits are
406    /// available.
407    fn get_bits(&mut self, count: u8) -> Option<u16> {
408        debug_assert!(count > 0 && count <= 15);
409
410        let available = (self.input_size - self.input_index) * 8 + self.bit_index as usize;
411        if available < count as usize {
412            return None;
413        }
414
415        // u32 accumulator covers the worst case: bit_index=1, count=15,
416        // which may require loading 2 extra bytes.
417        let mut acc = (self.current_byte as u32) & ((1 << self.bit_index) - 1);
418        let mut bits = self.bit_index;
419
420        while bits < count {
421            self.current_byte = self.input_buffer[self.input_index];
422            self.input_index += 1;
423            acc = (acc << 8) | self.current_byte as u32;
424            bits += 8;
425        }
426
427        let remaining = bits - count;
428        let result = (acc >> remaining) & ((1u32 << count) - 1);
429
430        // Maintain the invariant: bit_index == 0 only when the buffer is fully
431        // consumed.  If remaining == 0 and bytes remain, pre-load the next one
432        // (same as sink() does on first fill), so sink() never wrongly
433        // overwrites an unconsumed byte.
434        if remaining == 0 {
435            if self.input_index < self.input_size {
436                self.current_byte = self.input_buffer[self.input_index];
437                self.input_index += 1;
438                self.bit_index = 8;
439                // If we just consumed the last byte, reset so sink() can refill.
440                if self.input_index == self.input_size {
441                    self.input_index = 0;
442                    self.input_size = 0;
443                }
444            } else {
445                self.input_index = 0;
446                self.input_size = 0;
447                self.bit_index = 0;
448                self.current_byte = 0;
449            }
450        } else {
451            self.bit_index = remaining;
452            self.current_byte = (acc & ((1 << remaining) - 1)) as u8;
453            if self.input_index == self.input_size {
454                self.input_index = 0;
455                self.input_size = 0;
456            }
457        }
458
459        Some(result as u16)
460    }
461}