1use super::CodecError;
2use super::Finish;
3use super::Poll;
4use super::PollError;
5use super::SinkError;
6
/// States of the encoder's `poll` state machine.
#[derive(Clone, Copy, Debug, PartialEq)]
enum HSEstate {
    /// Waiting for more input via `sink`.
    NotFull,
    /// The input window is ready; the search index is (re)built before searching.
    Filled,
    /// Looking for the longest earlier match at the current scan position.
    Search,
    /// Emitting the 1-bit literal / back-reference tag.
    YieldTagBit,
    /// Emitting an 8-bit literal byte.
    YieldLiteral,
    /// Emitting the back-reference distance (`W` bits).
    YieldBrIndex,
    /// Emitting the back-reference length (`L` bits).
    YieldBrLength,
    /// Shifting processed input into the backlog before accepting more input.
    SaveBacklog,
    /// Writing out the final, partially filled output byte.
    FlushBits,
    /// All output has been produced.
    Done,
}
20
21#[derive(Debug)]
40pub struct HeatshrinkEncoder<const W: usize, const L: usize, const BUF: usize> {
41 input_size: usize,
42 match_scan_index: usize,
43 match_length: usize,
44 match_position: usize,
45 outgoing_bits: u16,
46 outgoing_bits_count: u8,
47 is_finishing: bool,
48 current_byte: u8,
49 bit_index: u8,
50 state: HSEstate,
51 #[cfg(feature = "heatshrink-use-index")]
52 search_index: [u16; BUF],
53 input_buffer: [u8; BUF],
54}
55
56pub fn encode<'a>(src: &[u8], dst: &'a mut [u8]) -> Result<&'a [u8], CodecError> {
58 let mut enc = super::DefaultEncoder::new();
59 run_encode(&mut enc, src, dst)
60}
61
62pub(crate) fn run_encode<'a, const W: usize, const L: usize, const BUF: usize>(
64 enc: &mut HeatshrinkEncoder<W, L, BUF>,
65 src: &[u8],
66 dst: &'a mut [u8],
67) -> Result<&'a [u8], CodecError> {
68 let mut total_input_size = 0;
69 let mut total_output_size = 0;
70
71 loop {
72 if total_input_size < src.len() {
73 match enc.sink(&src[total_input_size..]) {
74 Ok(n) => total_input_size += n,
75 Err(SinkError::Full) => {}
76 Err(SinkError::Misuse) => return Err(CodecError::Internal),
77 }
78 }
79
80 if total_input_size == src.len() {
81 enc.finish();
82 }
83
84 if total_output_size == dst.len() {
85 return Err(CodecError::OutputFull);
86 }
87
88 match enc.poll(&mut dst[total_output_size..]) {
89 Ok(Poll::More(n)) => {
90 total_output_size += n;
91 if total_output_size == dst.len() {
92 return Err(CodecError::OutputFull);
93 }
94 }
95 Ok(Poll::Empty(n)) => {
96 total_output_size += n;
97 if total_input_size == src.len() {
98 break;
99 }
100 }
101 Err(_) => return Err(CodecError::Internal),
102 }
103 }
104
105 Ok(&dst[..total_output_size])
106}
107
108impl<const W: usize, const L: usize, const BUF: usize> Default for HeatshrinkEncoder<W, L, BUF> {
109 fn default() -> Self {
110 Self::new()
111 }
112}
113
impl<const W: usize, const L: usize, const BUF: usize> HeatshrinkEncoder<W, L, BUF> {
    /// Creates an encoder with zeroed buffers in the `NotFull` state, ready to
    /// accept input via `sink`.
    ///
    /// # Panics
    ///
    /// Panics unless `4 <= W <= 15`, `3 <= L < W` and `BUF == 2 << W`.
    pub fn new() -> Self {
        assert!(W >= 4, "W must be >= 4");
        assert!(L >= 3, "L must be >= 3");
        assert!(L < W, "L must be < W");
        // NOTE(review): the arithmetic in this message is off (2 << 15 == 65536,
        // which is not <= 65534). The bound presumably keeps buffer positions
        // storable in the u16 `search_index`, where `u16::MAX` is the
        // "no entry" sentinel. Confirm before relaxing it.
        assert!(
            W <= 15,
            "W must be <= 15 (BUF = 2<<W, max index u16::MAX-1 = 65534 >= 2<<15)"
        );
        assert!(BUF == 2 << W, "BUF must equal 2 << W");

        HeatshrinkEncoder {
            input_size: 0,
            match_scan_index: 0,
            match_length: 0,
            match_position: 0,
            outgoing_bits: 0,
            outgoing_bits_count: 0,
            is_finishing: false,
            current_byte: 0,
            // 8 free bit slots: the output byte under construction is empty.
            bit_index: 8,
            state: HSEstate::NotFull,
            // u16::MAX marks "no earlier occurrence" (see `find_longest_match`).
            #[cfg(feature = "heatshrink-use-index")]
            search_index: [u16::MAX; BUF],
            input_buffer: [0; BUF],
        }
    }

    /// Returns the encoder to its freshly constructed state, discarding any
    /// buffered input, pending output and the finishing flag.
    pub fn reset(&mut self) {
        *self = Self::new();
    }

    /// Copies as much of `input_buffer` as fits into the input window.
    ///
    /// Returns the number of bytes accepted. This may be less than
    /// `input_buffer.len()`, and is 0 for an empty slice. When the window
    /// becomes full the state switches to `Filled`, and `poll` must drain it
    /// before more input is accepted.
    ///
    /// # Errors
    ///
    /// * `SinkError::Misuse` if `finish` has already been called.
    /// * `SinkError::Full` if the encoder is not in the `NotFull` state or the
    ///   input window has no free space.
    pub fn sink(&mut self, input_buffer: &[u8]) -> Result<usize, SinkError> {
        if self.is_finishing {
            return Err(SinkError::Misuse);
        }
        // Any other state means `poll` is still working through buffered input.
        if self.state != HSEstate::NotFull {
            return Err(SinkError::Full);
        }

        let remaining_size = self.get_input_buffer_size() - self.input_size;
        // Defensive: a full window normally already switched the state to `Filled` below.
        if remaining_size == 0 {
            return Err(SinkError::Full);
        }

        let copy_size = remaining_size.min(input_buffer.len());
        let write_offset = self.get_input_offset() + self.input_size;

        self.input_buffer[write_offset..write_offset + copy_size]
            .copy_from_slice(&input_buffer[..copy_size]);
        self.input_size += copy_size;

        if self.input_size == self.get_input_buffer_size() {
            self.state = HSEstate::Filled;
        }

        Ok(copy_size)
    }

    /// Runs the compression state machine, writing compressed bytes into
    /// `output_buffer`.
    ///
    /// `n` in the result is the number of bytes written by this call.
    ///
    /// * `Poll::More(n)`: `output_buffer` is full and more output is pending.
    /// * `Poll::Empty(n)`: the encoder needs more input (`NotFull`), is `Done`,
    ///   or has just visited `FlushBits`.
    ///
    /// `Poll::Empty` alone does not mean all output was produced. If there was
    /// no room for the final partial byte, the state stays `FlushBits`. After
    /// finishing, keep polling until `finish` returns `Finish::Done`.
    ///
    /// # Errors
    ///
    /// `PollError::Misuse` if `output_buffer` is empty.
    pub fn poll(&mut self, output_buffer: &mut [u8]) -> Result<Poll, PollError> {
        if output_buffer.is_empty() {
            return Err(PollError::Misuse);
        }

        let mut out_pos: usize = 0;

        loop {
            let previous_state = self.state;

            match previous_state {
                HSEstate::NotFull => return Ok(Poll::Empty(out_pos)),
                HSEstate::Filled => {
                    self.do_indexing();
                    self.state = HSEstate::Search;
                }
                HSEstate::Search => {
                    self.state = self.st_step_search();
                }
                HSEstate::YieldTagBit => {
                    self.state = self.st_yield_tag_bit(output_buffer, &mut out_pos);
                }
                HSEstate::YieldLiteral => {
                    self.state = self.st_yield_literal(output_buffer, &mut out_pos);
                }
                HSEstate::YieldBrIndex => {
                    self.state = self.st_yield_br_index(output_buffer, &mut out_pos);
                }
                HSEstate::YieldBrLength => {
                    self.state = self.st_yield_br_length(output_buffer, &mut out_pos);
                }
                HSEstate::SaveBacklog => {
                    self.state = self.st_save_backlog();
                }
                HSEstate::FlushBits => {
                    // Reports `Empty` even if the byte did not fit and the state
                    // is still `FlushBits`; callers must check `finish()`.
                    self.state = self.st_flush_bit_buffer(output_buffer, &mut out_pos);
                    return Ok(Poll::Empty(out_pos));
                }
                HSEstate::Done => return Ok(Poll::Empty(out_pos)),
            }

            // A state that could not advance while the output buffer is full is
            // blocked on output space; hand control back for a fresh buffer.
            if self.state == previous_state && out_pos == output_buffer.len() {
                return Ok(Poll::More(out_pos));
            }
        }
    }

    /// Marks the input as complete; later `sink` calls fail with `Misuse`.
    ///
    /// A partially filled input window is promoted to `Filled` so that `poll`
    /// compresses it. Returns `Finish::Done` once the state machine has
    /// reached `Done`, and `Finish::More` otherwise. Calling it repeatedly only
    /// re-reports that status.
    pub fn finish(&mut self) -> Finish {
        self.is_finishing = true;
        if self.state == HSEstate::NotFull {
            self.state = HSEstate::Filled;
        }
        if self.state == HSEstate::Done {
            Finish::Done
        } else {
            Finish::More
        }
    }

    /// Finds the longest earlier match for the bytes at `match_scan_index`.
    ///
    /// On a match, records its distance and length. If no match beats the
    /// break-even point, advances the scan position by one literal byte. If
    /// too little input remains to search, moves to `FlushBits` (finishing) or
    /// `SaveBacklog` (streaming) instead.
    #[inline]
    fn st_step_search(&mut self) -> HSEstate {
        // While streaming, stop once fewer than a full lookahead of bytes remain
        // (presumably so matches are not truncated by input that has not arrived
        // yet). Once finishing, scan right up to the last byte.
        let lookahead = if self.is_finishing {
            1
        } else {
            self.get_lookahead_size()
        };
        if self.match_scan_index + lookahead > self.input_size {
            return if self.is_finishing {
                HSEstate::FlushBits
            } else {
                HSEstate::SaveBacklog
            };
        }

        // `end` is the buffer position of the bytes to encode. `start` is one
        // window (2^W bytes) earlier and is the oldest position a match may
        // begin at.
        let end = self.get_input_offset() + self.match_scan_index;
        let start = end - self.get_input_buffer_size();
        // Cap the match length at both the lookahead size and the buffered input.
        let max_possible = if self.input_size < self.get_lookahead_size() + self.match_scan_index {
            self.input_size - self.match_scan_index
        } else {
            self.get_lookahead_size()
        };

        match self.find_longest_match(start, end, max_possible) {
            None => {
                // Literal: step past the byte now; `push_literal_byte` reads it
                // back at `match_scan_index - 1`.
                self.match_scan_index += 1;
                self.match_length = 0;
            }
            Some((position, length)) => {
                self.match_position = position;
                self.match_length = length;
                assert!(self.match_position <= 1 << W);
            }
        }
        HSEstate::YieldTagBit
    }

    /// Emits the tag bit: 1 for a literal, 0 for a back-reference.
    ///
    /// For a back-reference, also loads the `W`-bit distance (stored minus
    /// one) into the outgoing bit buffer. Like every emitting state, it only
    /// acts when the output has room, because `push_bits` writes at most one
    /// byte per call.
    #[inline]
    fn st_yield_tag_bit(&mut self, out: &mut [u8], pos: &mut usize) -> HSEstate {
        if *pos < out.len() {
            if self.match_length == 0 {
                self.add_tag_bit(out, pos, 0x1);
                HSEstate::YieldLiteral
            } else {
                self.add_tag_bit(out, pos, 0);
                self.outgoing_bits = self.match_position as u16 - 1;
                self.outgoing_bits_count = W as u8;
                HSEstate::YieldBrIndex
            }
        } else {
            HSEstate::YieldTagBit
        }
    }

    /// Emits the 8-bit literal for the byte the search just stepped past.
    #[inline]
    fn st_yield_literal(&mut self, out: &mut [u8], pos: &mut usize) -> HSEstate {
        if *pos < out.len() {
            self.push_literal_byte(out, pos);
            HSEstate::Search
        } else {
            HSEstate::YieldLiteral
        }
    }

    /// Emits the back-reference distance, up to 8 bits per call.
    ///
    /// When the distance is exhausted, loads the `L`-bit length (stored minus
    /// one).
    #[inline]
    fn st_yield_br_index(&mut self, out: &mut [u8], pos: &mut usize) -> HSEstate {
        if *pos < out.len() {
            if self.push_outgoing_bits(out, pos) > 0 {
                HSEstate::YieldBrIndex
            } else {
                self.outgoing_bits = self.match_length as u16 - 1;
                self.outgoing_bits_count = L as u8;
                HSEstate::YieldBrLength
            }
        } else {
            HSEstate::YieldBrIndex
        }
    }

    /// Emits the back-reference length, up to 8 bits per call.
    ///
    /// When the length is exhausted, advances the scan position past the
    /// matched bytes.
    #[inline]
    fn st_yield_br_length(&mut self, out: &mut [u8], pos: &mut usize) -> HSEstate {
        if *pos < out.len() {
            if self.push_outgoing_bits(out, pos) > 0 {
                HSEstate::YieldBrLength
            } else {
                self.match_scan_index += self.match_length;
                self.match_length = 0;
                HSEstate::Search
            }
        } else {
            HSEstate::YieldBrLength
        }
    }

    /// Moves processed input into the backlog, then waits for more input.
    #[inline]
    fn st_save_backlog(&mut self) -> HSEstate {
        self.save_backlog();
        HSEstate::NotFull
    }

    /// Writes the final partially filled byte, if there is one.
    ///
    /// Stays in `FlushBits` when the output has no room for that byte.
    #[inline]
    fn st_flush_bit_buffer(&self, out: &mut [u8], pos: &mut usize) -> HSEstate {
        if self.bit_index == 8 {
            // Nothing buffered: the stream already ends on a byte boundary.
            HSEstate::Done
        } else if *pos < out.len() {
            out[*pos] = self.current_byte;
            *pos += 1;
            HSEstate::Done
        } else {
            HSEstate::FlushBits
        }
    }

    /// Emits a single tag bit.
    #[inline]
    fn add_tag_bit(&mut self, out: &mut [u8], pos: &mut usize, tag: u8) {
        self.push_bits(1, tag, out, pos)
    }

    /// Buffer position where the input half starts (the lower half is the backlog).
    #[inline]
    fn get_input_offset(&self) -> usize {
        self.get_input_buffer_size()
    }

    /// Size of each buffer half, i.e. the window size `2^W` given `BUF == 2 << W`.
    #[inline]
    fn get_input_buffer_size(&self) -> usize {
        BUF / 2
    }

    /// Maximum match length, `2^L` bytes.
    #[inline]
    fn get_lookahead_size(&self) -> usize {
        1 << L
    }

    /// Rebuilds `search_index` so each entry holds the position of the previous
    /// occurrence of the same byte value, or `u16::MAX` if there is none.
    ///
    /// Does nothing unless the `heatshrink-use-index` feature is enabled.
    #[inline]
    fn do_indexing(&mut self) {
        #[cfg(feature = "heatshrink-use-index")]
        {
            // Most recent position seen for each byte value.
            let mut last: [u16; 256] = [u16::MAX; 256];
            // NOTE(review): this stops one byte short of the buffered input, so
            // the last input position keeps a stale entry. That entry is only
            // read when finishing on the very last byte, where a 1-byte match
            // can never beat the break-even point. Output looks unaffected;
            // confirm before changing the break-even rule.
            let end = self.get_input_offset() + self.input_size - 1;
            self.input_buffer[..end]
                .iter()
                .zip(self.search_index[..end].iter_mut())
                .enumerate()
                .for_each(|(i, (&v, slot))| {
                    let v = v as usize;
                    *slot = last[v];
                    last[v] = i as u16;
                });
        }
    }

    /// Searches positions `start..end` for the longest prefix match, up to
    /// `maxlen` bytes, of the bytes starting at `end`.
    ///
    /// Returns `Some((distance, length))` with `distance = end - match_pos`.
    /// Returns `None` when the best match is not long enough to pay for a
    /// back-reference.
    #[inline]
    fn find_longest_match(
        &self,
        start: usize,
        end: usize,
        maxlen: usize,
    ) -> Option<(usize, usize)> {
        let mut match_maxlen: usize = 0;
        let mut match_index: usize = 0;

        // Everything the search may touch: candidates from `start`, plus the
        // `maxlen`-byte needle at `end`. A candidate slice may run into the
        // needle itself (overlapping match).
        let window = &self.input_buffer[start..end + maxlen];
        let needle_off = end - start;
        let needle = &window[needle_off..needle_off + maxlen];

        #[cfg(not(feature = "heatshrink-use-index"))]
        {
            // Scan backwards from the nearest position. `len > match_maxlen` is
            // strict, so among equal lengths the closest match wins.
            let mut position = end - 1;
            loop {
                let cand_off = position - start;
                let candidate = &window[cand_off..cand_off + maxlen];
                // Matching at index `match_maxlen` is required to beat the
                // current best, so test it before the full comparison.
                if candidate[0] == needle[0] && candidate[match_maxlen] == needle[match_maxlen] {
                    let mut len = 1;
                    while len < maxlen {
                        if candidate[len] != needle[len] {
                            break;
                        }
                        len += 1;
                    }
                    if len > match_maxlen {
                        match_maxlen = len;
                        match_index = position;
                        if len == maxlen {
                            break;
                        }
                    }
                }
                if position == start {
                    break;
                }
                position -= 1;
            }
        }

        #[cfg(feature = "heatshrink-use-index")]
        {
            // Walk the chain of earlier positions with the same byte value,
            // nearest first, until it leaves the window or ends (u16::MAX).
            let mut position = self.search_index[end];
            while position != u16::MAX {
                let pos = position as usize;
                if pos < start {
                    break;
                }
                let cand_off = pos - start;
                let candidate = &window[cand_off..cand_off + maxlen];
                // Cannot beat the current best; skip to the next occurrence.
                if candidate[match_maxlen] != needle[match_maxlen] {
                    position = self.search_index[pos];
                    continue;
                }
                let mut len = 1;
                while len < maxlen {
                    if candidate[len] != needle[len] {
                        break;
                    }
                    len += 1;
                }
                if len > match_maxlen {
                    match_maxlen = len;
                    match_index = pos;
                    if len == maxlen {
                        break;
                    }
                }
                position = self.search_index[pos];
            }
        }

        // A back-reference costs 1 + W + L bits. Only use it when the match is
        // longer than that cost expressed in whole bytes.
        let break_even_point: usize = (1 + W + L) / 8;
        if match_maxlen > break_even_point {
            Some((end - match_index, match_maxlen))
        } else {
            None
        }
    }

    /// Emits up to 8 of the remaining `outgoing_bits_count` bits of
    /// `outgoing_bits`, most significant first.
    ///
    /// Returns the number of bits emitted, which is 0 once the value is
    /// exhausted.
    #[inline]
    fn push_outgoing_bits(&mut self, out: &mut [u8], pos: &mut usize) -> u8 {
        // NOTE(review): when 8 or fewer bits remain, `bits` still carries
        // already-emitted higher bits above `count`, and `push_bits` does not
        // mask them. As far as I can tell they land on the identical bits of
        // this value that were just written, so the OR is harmless. Masking to
        // `count` bits would remove that dependency.
        let (count, bits) = if self.outgoing_bits_count > 8 {
            (
                8u8,
                (self.outgoing_bits >> (self.outgoing_bits_count - 8)) as u8,
            )
        } else {
            (self.outgoing_bits_count, self.outgoing_bits as u8)
        };
        if count > 0 {
            self.push_bits(count, bits, out, pos);
            self.outgoing_bits_count -= count;
        }
        count
    }

    /// Appends `count` (1..=8) bits from the low end of `bits`, MSB first, to
    /// the output bit stream.
    ///
    /// Bits of `bits` above `count` are not masked off. `bit_index` is the
    /// number of still-free low bits in `current_byte` (8 = empty). A completed
    /// byte is written to `out[*pos]`. At most one byte is written per call,
    /// so callers must ensure `*pos < out.len()`.
    #[inline]
    fn push_bits(&mut self, count: u8, bits: u8, out: &mut [u8], pos: &mut usize) {
        debug_assert!(count > 0 && count <= 8);
        if count == 8 && self.bit_index == 8 {
            // Byte-aligned whole byte: write it straight through.
            out[*pos] = bits;
            *pos += 1;
            return;
        }
        if count >= self.bit_index {
            // The new bits complete the current byte. Write it out, then keep
            // the `shift` leftover low bits at the top of a fresh byte.
            let shift = count - self.bit_index;
            let tmp_byte = self.current_byte | (bits >> shift);
            out[*pos] = tmp_byte;
            *pos += 1;
            self.bit_index = 8 - shift;
            self.current_byte = if shift == 0 {
                0
            } else {
                bits << self.bit_index
            };
        } else {
            // Still room: place the bits just below the ones already present.
            self.bit_index -= count;
            self.current_byte |= bits << self.bit_index;
        }
    }

    /// Emits the literal byte the scan just stepped past (`match_scan_index - 1`).
    #[inline]
    fn push_literal_byte(&mut self, out: &mut [u8], pos: &mut usize) {
        let byte = self.input_buffer[self.get_input_offset() + self.match_scan_index - 1];
        self.push_bits(8, byte, out, pos);
    }

    /// Shifts the buffer left by `match_scan_index`.
    ///
    /// Afterwards the `2^W` bytes preceding the unprocessed input form the
    /// backlog, and the unprocessed input starts at the input offset.
    #[inline]
    fn save_backlog(&mut self) {
        self.input_buffer.copy_within(self.match_scan_index.., 0);
        self.input_size -= self.match_scan_index;
        self.match_scan_index = 0;
    }
}