Skip to main content

heatshrink/
encoder.rs

1use super::CodecError;
2use super::Finish;
3use super::Poll;
4use super::PollError;
5use super::SinkError;
6
/// States of the encoder state machine that `poll` steps through.
#[derive(Clone, Copy, Debug, PartialEq)]
enum HSEstate {
    /// The input buffer still has room; the only state in which `sink` accepts data.
    NotFull,
    /// The input buffer is full, or `finish` was called; ready to be indexed and searched.
    Filled,
    /// Searching for the longest back-reference at the current scan position.
    Search,
    /// Emitting the tag bit that distinguishes a literal from a back-reference.
    YieldTagBit,
    /// Emitting a literal byte.
    YieldLiteral,
    /// Emitting the index bits of a back-reference.
    YieldBrIndex,
    /// Emitting the length bits of a back-reference.
    YieldBrLength,
    /// Shifting the buffer contents down so that more input can be accepted.
    SaveBacklog,
    /// Writing out the final, partially filled output byte.
    FlushBits,
    /// Everything has been encoded and flushed.
    Done,
}
20
21/// Heatshrink encoder.
22///
23/// # Type parameters
24///
25/// - `W`   : base-2 log of the LZSS sliding window size.
26/// - `L`   : number of bits used to encode back-reference lengths (must be < W).
27/// - `BUF` : total input buffer size in bytes; **must equal `2 << W`**.
28///   (This redundant parameter is required because Rust stable does not
29///   yet support arithmetic on const generics in array sizes.)
30///
31/// Use the [`DefaultEncoder`](super::DefaultEncoder) type alias or construct
32/// via [`HeatshrinkEncoder::new`] when using the convenience dispatch in
33/// `heatshrink-bin`; do not set `BUF` manually — always use `2 << W`.
34///
35/// # Panics
36///
37/// [`new()`](HeatshrinkEncoder::new) panics if: `W < 4`, `L < 3`, `L >= W`,
38/// `W > 14`, or `BUF != 2 << W`.
39#[derive(Debug)]
40pub struct HeatshrinkEncoder<const W: usize, const L: usize, const BUF: usize> {
41    input_size: usize,
42    match_scan_index: usize,
43    match_length: usize,
44    match_position: usize,
45    outgoing_bits: u16,
46    outgoing_bits_count: u8,
47    is_finishing: bool,
48    current_byte: u8,
49    bit_index: u8,
50    state: HSEstate,
51    #[cfg(feature = "heatshrink-use-index")]
52    search_index: [u16; BUF],
53    input_buffer: [u8; BUF],
54}
55
56/// Compress `src` into `dst` using the default parameters (W=8, L=4).
57pub fn encode<'a>(src: &[u8], dst: &'a mut [u8]) -> Result<&'a [u8], CodecError> {
58    let mut enc = super::DefaultEncoder::new();
59    run_encode(&mut enc, src, dst)
60}
61
62/// Internal encode loop, generic over all encoder configurations.
63pub(crate) fn run_encode<'a, const W: usize, const L: usize, const BUF: usize>(
64    enc: &mut HeatshrinkEncoder<W, L, BUF>,
65    src: &[u8],
66    dst: &'a mut [u8],
67) -> Result<&'a [u8], CodecError> {
68    let mut total_input_size = 0;
69    let mut total_output_size = 0;
70
71    loop {
72        if total_input_size < src.len() {
73            match enc.sink(&src[total_input_size..]) {
74                Ok(n) => total_input_size += n,
75                Err(SinkError::Full) => {}
76                Err(SinkError::Misuse) => return Err(CodecError::Internal),
77            }
78        }
79
80        if total_input_size == src.len() {
81            enc.finish();
82        }
83
84        if total_output_size == dst.len() {
85            return Err(CodecError::OutputFull);
86        }
87
88        match enc.poll(&mut dst[total_output_size..]) {
89            Ok(Poll::More(n)) => {
90                total_output_size += n;
91                if total_output_size == dst.len() {
92                    return Err(CodecError::OutputFull);
93                }
94            }
95            Ok(Poll::Empty(n)) => {
96                total_output_size += n;
97                if total_input_size == src.len() {
98                    break;
99                }
100            }
101            Err(_) => return Err(CodecError::Internal),
102        }
103    }
104
105    Ok(&dst[..total_output_size])
106}
107
108impl<const W: usize, const L: usize, const BUF: usize> Default for HeatshrinkEncoder<W, L, BUF> {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
114impl<const W: usize, const L: usize, const BUF: usize> HeatshrinkEncoder<W, L, BUF> {
115    /// Create a new encoder instance.
116    ///
117    /// # Panics
118    ///
119    /// Panics if `W < 4`, `L < 3`, `L >= W`, `W > 15`, or `BUF != 2 << W`.
120    pub fn new() -> Self {
121        assert!(W >= 4, "W must be >= 4");
122        assert!(L >= 3, "L must be >= 3");
123        assert!(L < W, "L must be < W");
124        assert!(
125            W <= 15,
126            "W must be <= 15 (BUF = 2<<W, max index u16::MAX-1 = 65534 >= 2<<15)"
127        );
128        assert!(BUF == 2 << W, "BUF must equal 2 << W");
129
130        HeatshrinkEncoder {
131            input_size: 0,
132            match_scan_index: 0,
133            match_length: 0,
134            match_position: 0,
135            outgoing_bits: 0,
136            outgoing_bits_count: 0,
137            is_finishing: false,
138            current_byte: 0,
139            bit_index: 8,
140            state: HSEstate::NotFull,
141            #[cfg(feature = "heatshrink-use-index")]
142            search_index: [u16::MAX; BUF],
143            input_buffer: [0; BUF],
144        }
145    }
146
147    /// Reset the encoder to its initial state so it can be reused.
148    pub fn reset(&mut self) {
149        *self = Self::new();
150    }
151
152    /// Feed input data into the encoder.
153    pub fn sink(&mut self, input_buffer: &[u8]) -> Result<usize, SinkError> {
154        if self.is_finishing {
155            return Err(SinkError::Misuse);
156        }
157        if self.state != HSEstate::NotFull {
158            return Err(SinkError::Full);
159        }
160
161        let remaining_size = self.get_input_buffer_size() - self.input_size;
162        if remaining_size == 0 {
163            return Err(SinkError::Full);
164        }
165
166        let copy_size = remaining_size.min(input_buffer.len());
167        let write_offset = self.get_input_offset() + self.input_size;
168
169        self.input_buffer[write_offset..write_offset + copy_size]
170            .copy_from_slice(&input_buffer[..copy_size]);
171        self.input_size += copy_size;
172
173        if self.input_size == self.get_input_buffer_size() {
174            self.state = HSEstate::Filled;
175        }
176
177        Ok(copy_size)
178    }
179
180    /// Pull compressed output out of the encoder into `output_buffer`.
181    pub fn poll(&mut self, output_buffer: &mut [u8]) -> Result<Poll, PollError> {
182        if output_buffer.is_empty() {
183            return Err(PollError::Misuse);
184        }
185
186        let mut out_pos: usize = 0;
187
188        loop {
189            let previous_state = self.state;
190
191            match previous_state {
192                HSEstate::NotFull => return Ok(Poll::Empty(out_pos)),
193                HSEstate::Filled => {
194                    self.do_indexing();
195                    self.state = HSEstate::Search;
196                }
197                HSEstate::Search => {
198                    self.state = self.st_step_search();
199                }
200                HSEstate::YieldTagBit => {
201                    self.state = self.st_yield_tag_bit(output_buffer, &mut out_pos);
202                }
203                HSEstate::YieldLiteral => {
204                    self.state = self.st_yield_literal(output_buffer, &mut out_pos);
205                }
206                HSEstate::YieldBrIndex => {
207                    self.state = self.st_yield_br_index(output_buffer, &mut out_pos);
208                }
209                HSEstate::YieldBrLength => {
210                    self.state = self.st_yield_br_length(output_buffer, &mut out_pos);
211                }
212                HSEstate::SaveBacklog => {
213                    self.state = self.st_save_backlog();
214                }
215                HSEstate::FlushBits => {
216                    self.state = self.st_flush_bit_buffer(output_buffer, &mut out_pos);
217                    return Ok(Poll::Empty(out_pos));
218                }
219                HSEstate::Done => return Ok(Poll::Empty(out_pos)),
220            }
221
222            if self.state == previous_state && out_pos == output_buffer.len() {
223                return Ok(Poll::More(out_pos));
224            }
225        }
226    }
227
228    /// Signal that all input has been provided.
229    pub fn finish(&mut self) -> Finish {
230        self.is_finishing = true;
231        if self.state == HSEstate::NotFull {
232            self.state = HSEstate::Filled;
233        }
234        if self.state == HSEstate::Done {
235            Finish::Done
236        } else {
237            Finish::More
238        }
239    }
240
241    // ---- State machine helpers ----
242
243    #[inline]
244    fn st_step_search(&mut self) -> HSEstate {
245        let lookahead = if self.is_finishing {
246            1
247        } else {
248            self.get_lookahead_size()
249        };
250        if self.match_scan_index + lookahead > self.input_size {
251            return if self.is_finishing {
252                HSEstate::FlushBits
253            } else {
254                HSEstate::SaveBacklog
255            };
256        }
257
258        let end = self.get_input_offset() + self.match_scan_index;
259        let start = end - self.get_input_buffer_size();
260        let max_possible = if self.input_size < self.get_lookahead_size() + self.match_scan_index {
261            self.input_size - self.match_scan_index
262        } else {
263            self.get_lookahead_size()
264        };
265
266        match self.find_longest_match(start, end, max_possible) {
267            None => {
268                self.match_scan_index += 1;
269                self.match_length = 0;
270            }
271            Some((position, length)) => {
272                self.match_position = position;
273                self.match_length = length;
274                assert!(self.match_position <= 1 << W);
275            }
276        }
277        HSEstate::YieldTagBit
278    }
279
280    #[inline]
281    fn st_yield_tag_bit(&mut self, out: &mut [u8], pos: &mut usize) -> HSEstate {
282        if *pos < out.len() {
283            if self.match_length == 0 {
284                self.add_tag_bit(out, pos, 0x1);
285                HSEstate::YieldLiteral
286            } else {
287                self.add_tag_bit(out, pos, 0);
288                self.outgoing_bits = self.match_position as u16 - 1;
289                self.outgoing_bits_count = W as u8;
290                HSEstate::YieldBrIndex
291            }
292        } else {
293            HSEstate::YieldTagBit
294        }
295    }
296
297    #[inline]
298    fn st_yield_literal(&mut self, out: &mut [u8], pos: &mut usize) -> HSEstate {
299        if *pos < out.len() {
300            self.push_literal_byte(out, pos);
301            HSEstate::Search
302        } else {
303            HSEstate::YieldLiteral
304        }
305    }
306
307    #[inline]
308    fn st_yield_br_index(&mut self, out: &mut [u8], pos: &mut usize) -> HSEstate {
309        if *pos < out.len() {
310            if self.push_outgoing_bits(out, pos) > 0 {
311                HSEstate::YieldBrIndex
312            } else {
313                self.outgoing_bits = self.match_length as u16 - 1;
314                self.outgoing_bits_count = L as u8;
315                HSEstate::YieldBrLength
316            }
317        } else {
318            HSEstate::YieldBrIndex
319        }
320    }
321
322    #[inline]
323    fn st_yield_br_length(&mut self, out: &mut [u8], pos: &mut usize) -> HSEstate {
324        if *pos < out.len() {
325            if self.push_outgoing_bits(out, pos) > 0 {
326                HSEstate::YieldBrLength
327            } else {
328                self.match_scan_index += self.match_length;
329                self.match_length = 0;
330                HSEstate::Search
331            }
332        } else {
333            HSEstate::YieldBrLength
334        }
335    }
336
337    #[inline]
338    fn st_save_backlog(&mut self) -> HSEstate {
339        self.save_backlog();
340        HSEstate::NotFull
341    }
342
343    #[inline]
344    fn st_flush_bit_buffer(&self, out: &mut [u8], pos: &mut usize) -> HSEstate {
345        if self.bit_index == 8 {
346            HSEstate::Done
347        } else if *pos < out.len() {
348            out[*pos] = self.current_byte;
349            *pos += 1;
350            HSEstate::Done
351        } else {
352            HSEstate::FlushBits
353        }
354    }
355
356    #[inline]
357    fn add_tag_bit(&mut self, out: &mut [u8], pos: &mut usize, tag: u8) {
358        self.push_bits(1, tag, out, pos)
359    }
360
361    #[inline]
362    fn get_input_offset(&self) -> usize {
363        self.get_input_buffer_size()
364    }
365
366    #[inline]
367    fn get_input_buffer_size(&self) -> usize {
368        BUF / 2
369    }
370
371    #[inline]
372    fn get_lookahead_size(&self) -> usize {
373        1 << L
374    }
375
376    #[inline]
377    fn do_indexing(&mut self) {
378        #[cfg(feature = "heatshrink-use-index")]
379        {
380            // last[v] = most recent position where byte v was seen, u16::MAX = none.
381            // enumerate() lets the compiler elide bounds-checks on both slices.
382            let mut last: [u16; 256] = [u16::MAX; 256];
383            let end = self.get_input_offset() + self.input_size - 1;
384            self.input_buffer[..end]
385                .iter()
386                .zip(self.search_index[..end].iter_mut())
387                .enumerate()
388                .for_each(|(i, (&v, slot))| {
389                    let v = v as usize;
390                    *slot = last[v];
391                    last[v] = i as u16;
392                });
393        }
394    }
395
396    #[inline]
397    fn find_longest_match(
398        &self,
399        start: usize,
400        end: usize,
401        maxlen: usize,
402    ) -> Option<(usize, usize)> {
403        let mut match_maxlen: usize = 0;
404        let mut match_index: usize = 0;
405
406        // Pre-slice the buffer once so every inner access is relative to a
407        // slice whose length the compiler knows.  This lets LLVM prove that
408        // needle[len] and candidate[len] are always in-bounds (len < maxlen
409        // is the loop invariant) and eliminate the bounds-checks entirely —
410        // without any `unsafe`.
411        //
412        // window covers [start .. end+maxlen]: the window bytes that can be
413        // referenced as back-reference sources, plus the lookahead we compare
414        // against.  All candidate positions satisfy start <= pos < end, so
415        // pos - start + maxlen <= end - start + maxlen = window.len().
416        let window = &self.input_buffer[start..end + maxlen];
417        // needle is the lookahead we are trying to match, relative to window.
418        let needle_off = end - start; // offset of `end` inside `window`
419        let needle = &window[needle_off..needle_off + maxlen];
420
421        #[cfg(not(feature = "heatshrink-use-index"))]
422        {
423            let mut position = end - 1;
424            loop {
425                let cand_off = position - start;
426                let candidate = &window[cand_off..cand_off + maxlen];
427                if candidate[0] == needle[0] && candidate[match_maxlen] == needle[match_maxlen] {
428                    let mut len = 1;
429                    while len < maxlen {
430                        if candidate[len] != needle[len] {
431                            break;
432                        }
433                        len += 1;
434                    }
435                    if len > match_maxlen {
436                        match_maxlen = len;
437                        match_index = position;
438                        if len == maxlen {
439                            break;
440                        }
441                    }
442                }
443                if position == start {
444                    break;
445                }
446                position -= 1;
447            }
448        }
449
450        #[cfg(feature = "heatshrink-use-index")]
451        {
452            let mut position = self.search_index[end];
453            while position != u16::MAX {
454                let pos = position as usize;
455                if pos < start {
456                    break;
457                }
458                let cand_off = pos - start;
459                let candidate = &window[cand_off..cand_off + maxlen];
460                // Quick-reject on the current best length before the inner loop.
461                if candidate[match_maxlen] != needle[match_maxlen] {
462                    position = self.search_index[pos];
463                    continue;
464                }
465                let mut len = 1;
466                while len < maxlen {
467                    if candidate[len] != needle[len] {
468                        break;
469                    }
470                    len += 1;
471                }
472                if len > match_maxlen {
473                    match_maxlen = len;
474                    match_index = pos;
475                    if len == maxlen {
476                        break;
477                    }
478                }
479                position = self.search_index[pos];
480            }
481        }
482
483        let break_even_point: usize = (1 + W + L) / 8;
484        if match_maxlen > break_even_point {
485            Some((end - match_index, match_maxlen))
486        } else {
487            None
488        }
489    }
490
491    #[inline]
492    fn push_outgoing_bits(&mut self, out: &mut [u8], pos: &mut usize) -> u8 {
493        let (count, bits) = if self.outgoing_bits_count > 8 {
494            (
495                8u8,
496                (self.outgoing_bits >> (self.outgoing_bits_count - 8)) as u8,
497            )
498        } else {
499            (self.outgoing_bits_count, self.outgoing_bits as u8)
500        };
501        if count > 0 {
502            self.push_bits(count, bits, out, pos);
503            self.outgoing_bits_count -= count;
504        }
505        count
506    }
507
508    #[inline]
509    fn push_bits(&mut self, count: u8, bits: u8, out: &mut [u8], pos: &mut usize) {
510        debug_assert!(count > 0 && count <= 8);
511        // Fast path: pushing a full aligned byte (the common case for literals).
512        if count == 8 && self.bit_index == 8 {
513            out[*pos] = bits;
514            *pos += 1;
515            return;
516        }
517        if count >= self.bit_index {
518            let shift = count - self.bit_index;
519            let tmp_byte = self.current_byte | (bits >> shift);
520            out[*pos] = tmp_byte;
521            *pos += 1;
522            self.bit_index = 8 - shift;
523            self.current_byte = if shift == 0 {
524                0
525            } else {
526                bits << self.bit_index
527            };
528        } else {
529            self.bit_index -= count;
530            self.current_byte |= bits << self.bit_index;
531        }
532    }
533
534    #[inline]
535    fn push_literal_byte(&mut self, out: &mut [u8], pos: &mut usize) {
536        let byte = self.input_buffer[self.get_input_offset() + self.match_scan_index - 1];
537        self.push_bits(8, byte, out, pos);
538    }
539
540    #[inline]
541    fn save_backlog(&mut self) {
542        self.input_buffer.copy_within(self.match_scan_index.., 0);
543        self.input_size -= self.match_scan_index;
544        self.match_scan_index = 0;
545    }
546}