1use super::CodecError;
2use super::Finish;
3use super::Poll;
4use super::PollError;
5use super::SinkError;
6
/// States of the decoder's bit-stream state machine.
///
/// `Eq` is derived alongside `PartialEq` because equality on this field-less
/// enum is total; this lets the type be used wherever a full equivalence
/// relation is required.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum HSDstate {
    /// Waiting for the 1-bit tag that selects literal vs. back-reference.
    TagBit,
    /// A literal was announced; waiting for its 8 bits and for output space.
    YieldLiteral,
    /// Reading the bits of the back-reference index above the low 8 bits.
    BackrefIndexMsb,
    /// Reading the low (up to 8) bits of the back-reference index.
    BackrefIndexLsb,
    /// Reading the back-reference length bits.
    BackrefCountLsb,
    /// Copying back-reference bytes from the window to the output.
    YieldBackref,
}
16
/// Incremental decoder for a heatshrink-compressed byte stream.
///
/// Const parameters (as constrained by the constructor's assertions):
/// * `W`   - number of bits used for a back-reference index; the window holds
///   `1 << W` bytes.
/// * `L`   - number of bits used for a back-reference length.
/// * `I`   - capacity, in bytes, of the internal input buffer.
/// * `WIN` - window size in bytes; must equal `1 << W`.
#[derive(Debug)]
pub struct HeatshrinkDecoder<const W: usize, const L: usize, const I: usize, const WIN: usize> {
    /// Number of valid bytes currently stored in `input_buffer`.
    input_size: usize,
    /// Index of the next unread byte in `input_buffer`.
    input_index: usize,
    /// Distance (in bytes) back from `head_index` of the current back-reference.
    output_index: usize,
    /// Count of bytes decoded so far; masked with `WIN - 1` to address the window.
    head_index: usize,
    /// Bytes of the current back-reference that still have to be emitted.
    output_count: u16,
    /// Partially consumed input byte; its low `bit_index` bits are still unread.
    current_byte: u8,
    /// Number of unread bits left in `current_byte`.
    bit_index: u8,
    /// Current state of the decoding state machine.
    state: HSDstate,
    /// Staging area for compressed input handed over through `sink`.
    input_buffer: [u8; I],
    /// Circular window of the most recently decoded bytes.
    output_buffer: [u8; WIN],
}
48
49pub fn decode<'a>(src: &[u8], dst: &'a mut [u8]) -> Result<&'a [u8], CodecError> {
51 let mut dec = super::DefaultDecoder::new();
52 run_decode(&mut dec, src, dst)
53}
54
55pub(crate) fn run_decode<'a, const W: usize, const L: usize, const I: usize, const WIN: usize>(
57 dec: &mut HeatshrinkDecoder<W, L, I, WIN>,
58 src: &[u8],
59 dst: &'a mut [u8],
60) -> Result<&'a [u8], CodecError> {
61 let mut total_input_size = 0;
62 let mut total_output_size = 0;
63
64 while total_input_size < src.len() {
65 match dec.sink(&src[total_input_size..]) {
66 Ok(n) => total_input_size += n,
67 Err(SinkError::Full) => {}
68 Err(SinkError::Misuse) => return Err(CodecError::Internal),
69 }
70
71 if total_output_size == dst.len() {
72 return Err(CodecError::OutputFull);
73 }
74
75 match dec.poll(&mut dst[total_output_size..]) {
76 Ok(Poll::More(_)) => return Err(CodecError::OutputFull),
77 Ok(Poll::Empty(n)) => total_output_size += n,
78 Err(_) => return Err(CodecError::Internal),
79 }
80
81 if total_input_size == src.len() {
82 match dec.finish() {
83 Finish::Done => {}
84 Finish::More => return Err(CodecError::OutputFull),
85 }
86 }
87 }
88
89 Ok(&dst[..total_output_size])
90}
91
92impl<const W: usize, const L: usize, const I: usize, const WIN: usize> Default
93 for HeatshrinkDecoder<W, L, I, WIN>
94{
95 fn default() -> Self {
96 Self::new()
97 }
98}
99
100impl<const W: usize, const L: usize, const I: usize, const WIN: usize>
101 HeatshrinkDecoder<W, L, I, WIN>
102{
103 pub fn new() -> Self {
110 assert!(W >= 4, "W must be >= 4");
111 assert!(L >= 3, "L must be >= 3");
112 assert!(L < W, "L must be < W");
113 assert!(W <= 15, "W must be <= 15 (search_index uses Option<u16>)");
114 assert!(I >= 1, "I must be >= 1");
115 assert!(WIN == 1 << W, "WIN must equal 1 << W");
116
117 HeatshrinkDecoder {
118 input_size: 0,
119 input_index: 0,
120 output_count: 0,
121 output_index: 0,
122 head_index: 0,
123 current_byte: 0,
124 bit_index: 0,
125 state: HSDstate::TagBit,
126 input_buffer: [0; I],
127 output_buffer: [0; WIN],
128 }
129 }
130
131 pub fn reset(&mut self) {
133 *self = Self::new();
134 }
135
136 pub fn sink(&mut self, input_buffer: &[u8]) -> Result<usize, SinkError> {
138 let unconsumed = self.input_size - self.input_index;
145 if self.input_index > 0 && unconsumed > 0 {
146 self.input_buffer
147 .copy_within(self.input_index..self.input_size, 0);
148 }
149 self.input_size = unconsumed;
150 self.input_index = 0;
151
152 let remaining_size = self.input_buffer.len() - self.input_size;
153 if remaining_size == 0 {
154 return Err(SinkError::Full);
155 }
156
157 let copy_size = remaining_size.min(input_buffer.len());
158 self.input_buffer[self.input_size..self.input_size + copy_size]
159 .copy_from_slice(&input_buffer[..copy_size]);
160 self.input_size += copy_size;
161
162 if self.bit_index == 0 {
163 self.current_byte = self.input_buffer[self.input_index];
164 self.input_index += 1;
165 self.bit_index = 8;
166 }
167
168 Ok(copy_size)
169 }
170
171 pub fn poll(&mut self, output_buffer: &mut [u8]) -> Result<Poll, PollError> {
173 if output_buffer.is_empty() {
174 return Err(PollError::Misuse);
175 }
176
177 let mut out_pos: usize = 0;
178
179 loop {
180 let previous_state = self.state;
181
182 match previous_state {
183 HSDstate::TagBit => {
184 self.state = self.st_tag_bit();
185 }
186 HSDstate::YieldLiteral => {
187 self.state = self.st_yield_literal(output_buffer, &mut out_pos);
188 }
189 HSDstate::BackrefIndexMsb => {
190 self.state = self.st_backref_index_msb();
191 }
192 HSDstate::BackrefIndexLsb => {
193 self.state = self.st_backref_index_lsb();
194 }
195 HSDstate::BackrefCountLsb => {
196 self.state = self.st_backref_count_lsb();
197 }
198 HSDstate::YieldBackref => {
199 self.state = self.st_yield_backref(output_buffer, &mut out_pos);
200 }
201 }
202
203 if self.state == previous_state {
204 return if out_pos < output_buffer.len() {
205 Ok(Poll::Empty(out_pos))
206 } else {
207 Ok(Poll::More(out_pos))
208 };
209 }
210 }
211 }
212
213 pub fn finish(&self) -> Finish {
215 if self.input_size == 0 {
216 Finish::Done
217 } else {
218 Finish::More
219 }
220 }
221
222 #[inline]
225 fn st_tag_bit(&mut self) -> HSDstate {
226 match self.get_bits(1) {
227 None => HSDstate::TagBit,
228 Some(0) => {
229 self.output_index = 0;
230 if W > 8 {
231 HSDstate::BackrefIndexMsb
232 } else {
233 HSDstate::BackrefIndexLsb
234 }
235 }
236 Some(_) => HSDstate::YieldLiteral,
237 }
238 }
239
240 #[inline]
241 fn st_yield_literal(&mut self, out: &mut [u8], pos: &mut usize) -> HSDstate {
242 if *pos < out.len() {
243 match self.get_bits(8) {
244 None => HSDstate::YieldLiteral,
245 Some(c) => {
246 let c = c as u8;
247 self.output_buffer[self.head_index % WIN] = c;
248 self.head_index += 1;
249 out[*pos] = c;
250 *pos += 1;
251 HSDstate::TagBit
252 }
253 }
254 } else {
255 HSDstate::YieldLiteral
256 }
257 }
258
259 #[inline]
261 fn st_backref_index_msb(&mut self) -> HSDstate {
262 match self.get_bits((W - 8) as u8) {
263 None => HSDstate::BackrefIndexMsb,
264 Some(x) => {
265 self.output_index = (x as usize) << 8;
266 HSDstate::BackrefIndexLsb
267 }
268 }
269 }
270
271 #[inline]
272 fn st_backref_index_lsb(&mut self) -> HSDstate {
273 let lsb_bits = W.min(8) as u8;
278 match self.get_bits(lsb_bits) {
279 None => HSDstate::BackrefIndexLsb,
280 Some(x) => {
281 self.output_index |= x as usize;
282 self.output_index += 1;
283 self.output_count = 0;
284 HSDstate::BackrefCountLsb
285 }
286 }
287 }
288
289 #[inline]
290 fn st_backref_count_lsb(&mut self) -> HSDstate {
291 match self.get_bits(L as u8) {
292 None => HSDstate::BackrefCountLsb,
293 Some(x) => {
294 self.output_count = x + 1;
295 HSDstate::YieldBackref
296 }
297 }
298 }
299
300 #[inline]
301 fn st_yield_backref(&mut self, out: &mut [u8], pos: &mut usize) -> HSDstate {
302 if *pos == out.len() {
303 return HSDstate::YieldBackref;
304 }
305
306 let output_index = self.output_index;
307 let count = (out.len() - *pos).min(self.output_count as usize);
308
309 if output_index > self.head_index {
313 let zero_count = count.min(output_index - self.head_index);
314 let limit = self.head_index + zero_count;
315 while self.head_index < limit {
316 out[*pos] = 0;
317 *pos += 1;
318 self.output_buffer[self.head_index & (WIN - 1)] = 0;
319 self.head_index += 1;
320 }
321 self.output_count -= zero_count as u16;
322 if self.output_count == 0 {
323 return HSDstate::TagBit;
324 }
325 if *pos == out.len() {
326 return HSDstate::YieldBackref;
327 }
328 }
329
330 let count = (out.len() - *pos).min(self.output_count as usize);
332
333 if output_index >= count {
334 let src_start = (self.head_index - output_index) & (WIN - 1);
341 let dst_start = self.head_index & (WIN - 1);
342
343 let src_end = src_start + count;
345 let dst_end = dst_start + count;
346
347 if src_end <= WIN && dst_end <= WIN {
348 self.output_buffer
350 .copy_within(src_start..src_start + count, dst_start);
351 } else {
352 let limit = self.head_index + count;
356 let mut h = self.head_index;
357 while h < limit {
358 let s = (h - output_index) & (WIN - 1);
359 let d = h & (WIN - 1);
360 self.output_buffer[d] = self.output_buffer[s];
361 h += 1;
362 }
363 }
364
365 if dst_end <= WIN {
369 out[*pos..*pos + count]
370 .copy_from_slice(&self.output_buffer[dst_start..dst_start + count]);
371 } else {
372 let first = WIN - dst_start;
373 let second = count - first;
374 out[*pos..*pos + first].copy_from_slice(&self.output_buffer[dst_start..WIN]);
375 out[*pos + first..*pos + count].copy_from_slice(&self.output_buffer[..second]);
376 }
377 *pos += count;
378 self.head_index += count;
379 } else {
380 let limit = self.head_index + count;
385 while self.head_index < limit {
386 let c = self.output_buffer[(self.head_index - output_index) & (WIN - 1)];
387 out[*pos] = c;
388 *pos += 1;
389 self.output_buffer[self.head_index & (WIN - 1)] = c;
390 self.head_index += 1;
391 }
392 }
393
394 self.output_count -= count as u16;
395 if self.output_count == 0 {
396 HSDstate::TagBit
397 } else {
398 HSDstate::YieldBackref
399 }
400 }
401
402 fn get_bits(&mut self, count: u8) -> Option<u16> {
408 debug_assert!(count > 0 && count <= 15);
409
410 let available = (self.input_size - self.input_index) * 8 + self.bit_index as usize;
411 if available < count as usize {
412 return None;
413 }
414
415 let mut acc = (self.current_byte as u32) & ((1 << self.bit_index) - 1);
418 let mut bits = self.bit_index;
419
420 while bits < count {
421 self.current_byte = self.input_buffer[self.input_index];
422 self.input_index += 1;
423 acc = (acc << 8) | self.current_byte as u32;
424 bits += 8;
425 }
426
427 let remaining = bits - count;
428 let result = (acc >> remaining) & ((1u32 << count) - 1);
429
430 if remaining == 0 {
435 if self.input_index < self.input_size {
436 self.current_byte = self.input_buffer[self.input_index];
437 self.input_index += 1;
438 self.bit_index = 8;
439 if self.input_index == self.input_size {
441 self.input_index = 0;
442 self.input_size = 0;
443 }
444 } else {
445 self.input_index = 0;
446 self.input_size = 0;
447 self.bit_index = 0;
448 self.current_byte = 0;
449 }
450 } else {
451 self.bit_index = remaining;
452 self.current_byte = (acc & ((1 << remaining) - 1)) as u8;
453 if self.input_index == self.input_size {
454 self.input_index = 0;
455 self.input_size = 0;
456 }
457 }
458
459 Some(result as u16)
460 }
461}