1#![cfg_attr(not(any(feature = "std", test)), no_std)]
11#[cfg(feature = "std")]
15extern crate std;
16
17#[cfg(not(feature = "std"))]
18extern crate alloc;
19
20#[cfg(feature = "std")]
21use std::io::{Read, Write};
22
23pub(crate) mod common;
24pub mod heatshrink_decoder;
25pub mod heatshrink_encoder;
26
27pub use heatshrink_decoder::*;
28pub use heatshrink_encoder::*;
29
/// Smallest window size accepted by this crate, expressed as a base-2 logarithm
/// (`2^4` = 16-byte window).
pub const HEATSHRINK_MIN_WINDOW_BITS: u8 = 4;

/// Largest window size accepted by this crate, expressed as a base-2 logarithm
/// (`2^15` = 32 KiB window).
pub const HEATSHRINK_MAX_WINDOW_BITS: u8 = 15;

/// Smallest lookahead size accepted by this crate, expressed as a base-2 logarithm
/// (`2^3` = 8-byte lookahead).
pub const HEATSHRINK_MIN_LOOKAHEAD_BITS: u8 = 3;
34
35#[cfg(feature = "std")]
37pub fn encode(window_sz2: u8, lookahead_sz2: u8, stdin: &mut impl Read, stdout: &mut impl Write) {
38 let mut encoder =
39 HeatshrinkEncoder::new(window_sz2, lookahead_sz2).expect("Failed to create encoder");
40
41 const WORK_SIZE_UNIT: usize = 1024;
42 let mut buf = [0; WORK_SIZE_UNIT];
43 let mut scratch = [0; WORK_SIZE_UNIT * 2];
44
45 let mut not_empty = false;
47 loop {
48 let read_len = read_in(stdin, &mut buf);
49 not_empty |= read_len > 0;
50 if read_len == 0 {
51 break;
52 }
53 let mut read_data = &buf[..read_len];
54 while !read_data.is_empty() {
55 let sink_res = encoder.sink(read_data);
56 match sink_res {
57 HSESinkRes::Ok(bytes_sunk) => {
58 read_data = &read_data[bytes_sunk..];
59 }
60 _ => unreachable!(),
61 }
62
63 loop {
64 match encoder.poll(&mut scratch) {
65 HSEPollRes::Empty(sz) => {
66 write_out(stdout, &scratch[..sz]);
67 break;
68 }
69 HSEPollRes::More(sz) => {
70 write_out(stdout, &scratch[..sz]);
71 }
72 HSEPollRes::ErrorMisuse | HSEPollRes::ErrorNull => unreachable!(),
73 }
74 }
75 }
76 }
77
78 if !not_empty {
79 return;
80 }
81
82 loop {
84 match encoder.finish() {
85 HSEFinishRes::Done => {
86 break;
87 }
88 HSEFinishRes::More => {}
89 HSEFinishRes::ErrorNull => unreachable!(),
90 }
91
92 loop {
93 match encoder.poll(&mut scratch) {
94 HSEPollRes::Empty(sz) => {
95 write_out(stdout, &scratch[..sz]);
96 break;
97 }
98 HSEPollRes::More(sz) => {
99 write_out(stdout, &scratch[..sz]);
100 }
101 HSEPollRes::ErrorMisuse | HSEPollRes::ErrorNull => unreachable!(),
102 }
103 }
104 }
105}
106
107#[cfg(feature = "std")]
109pub fn decode(window_sz2: u8, lookahead_sz2: u8, stdin: &mut impl Read, stdout: &mut impl Write) {
110 const WORK_SIZE_UNIT: usize = 1024;
111
112 let mut decoder = HeatshrinkDecoder::new(WORK_SIZE_UNIT as u16, window_sz2, lookahead_sz2)
113 .expect("Failed to create decoder");
114 let mut buf = [0; WORK_SIZE_UNIT];
115 let mut scratch = [0; WORK_SIZE_UNIT * 2];
116
117 let mut not_empty = false;
119 loop {
120 let read_len = read_in(stdin, &mut buf);
121 not_empty |= read_len > 0;
122 if read_len == 0 {
123 break;
124 }
125 let mut read_data = &buf[..read_len];
126 while !read_data.is_empty() {
127 let sink_res = decoder.sink(read_data);
128 match sink_res {
129 HSDSinkRes::Ok(bytes_sunk) => {
130 read_data = &read_data[bytes_sunk..];
131 }
132 _ => unreachable!(),
133 }
134
135 loop {
136 match decoder.poll(&mut scratch) {
137 HSDPollRes::Empty(sz) => {
138 write_out(stdout, &scratch[..sz]);
139 break;
140 }
141 HSDPollRes::More(sz) => {
142 write_out(stdout, &scratch[..sz]);
143 }
144 HSDPollRes::ErrorNull => unreachable!(),
145 HSDPollRes::ErrorUnknown => {
146 panic!("Error: Unknown");
147 }
148 }
149 }
150 }
151 }
152
153 if !not_empty {
154 return;
155 }
156
157 loop {
159 match decoder.finish() {
160 HSDFinishRes::Done => {
161 break;
162 }
163 HSDFinishRes::More => {}
164 HSDFinishRes::ErrorNull => unreachable!(),
165 }
166
167 loop {
168 match decoder.poll(&mut scratch) {
169 HSDPollRes::Empty(sz) => {
170 write_out(stdout, &scratch[..sz]);
171 break;
172 }
173 HSDPollRes::More(sz) => {
174 write_out(stdout, &scratch[..sz]);
175 }
176 HSDPollRes::ErrorNull => unreachable!(),
177 HSDPollRes::ErrorUnknown => {
178 panic!("Error: Unknown");
179 }
180 }
181 }
182 }
183}
184
185#[cfg(feature = "std")]
/// Reads the next chunk of input into `buf` and returns the number of bytes read.
///
/// A return value of `0` means the reader is at end of input (or `buf` is empty). A read that
/// fails with `ErrorKind::Interrupted` is retried transparently, as the `Read::read` contract
/// recommends. Any other error is fatal.
///
/// # Panics
///
/// Panics with `"Failed to read from stdin: <error>"` if the reader returns an error other
/// than `Interrupted`.
fn read_in(stdin: &mut impl Read, buf: &mut [u8]) -> usize {
    loop {
        match stdin.read(buf) {
            Ok(read_len) => return read_len,
            // A signal interrupted the call before any data was transferred. This is not a
            // failure; the read just has to be issued again.
            Err(err) if err.kind() == std::io::ErrorKind::Interrupted => continue,
            // Same message format that `Result::expect` would produce.
            Err(err) => panic!("Failed to read from stdin: {:?}", err),
        }
    }
}
190
191#[cfg(feature = "std")]
/// Writes the whole of `data` to `stdout`.
///
/// `write_all` already loops over short writes and retries `Interrupted` errors, so a single
/// call is enough.
///
/// # Panics
///
/// Panics with `"Failed to write to stdout: <error>"` if the writer reports any other error.
fn write_out(stdout: &mut impl Write, data: &[u8]) {
    if let Err(err) = stdout.write_all(data) {
        panic!("Failed to write to stdout: {:?}", err);
    }
}
196
#[cfg(test)]
mod tests {
    use rayon::prelude::*;
    use std::time::Instant;

    use super::*;

    /// Number of best/worst configurations printed by each summary of the parameter sweep.
    const REPORT_LEN: usize = 50;

    /// Polls `encoder` until it has no pending output, appending every chunk to `out`.
    fn drain_encoder(encoder: &mut HeatshrinkEncoder, scratch: &mut [u8], out: &mut Vec<u8>) {
        loop {
            match encoder.poll(scratch) {
                HSEPollRes::Empty(sz) => {
                    out.extend_from_slice(&scratch[..sz]);
                    return;
                }
                HSEPollRes::More(sz) => out.extend_from_slice(&scratch[..sz]),
                HSEPollRes::ErrorMisuse | HSEPollRes::ErrorNull => unreachable!(),
            }
        }
    }

    /// Polls `decoder` until it has no pending output, appending every chunk to `out`.
    ///
    /// Panics if the decoder reports `ErrorUnknown`.
    fn drain_decoder(decoder: &mut HeatshrinkDecoder, scratch: &mut [u8], out: &mut Vec<u8>) {
        loop {
            match decoder.poll(scratch) {
                HSDPollRes::Empty(sz) => {
                    out.extend_from_slice(&scratch[..sz]);
                    return;
                }
                HSDPollRes::More(sz) => out.extend_from_slice(&scratch[..sz]),
                HSDPollRes::ErrorNull => unreachable!(),
                e @ HSDPollRes::ErrorUnknown => panic!("Failed to poll data: {:?}", e),
            }
        }
    }

    /// Compresses `input` in memory, feeding the encoder `read_sz` bytes at a time and polling
    /// into a scratch buffer of `2 * read_sz` bytes.
    ///
    /// Panics if `read_sz` is 0 or the encoder cannot be created.
    fn encode_all(input: &[u8], window_sz2: u8, lookahead_sz2: u8, read_sz: usize) -> Vec<u8> {
        assert!(read_sz > 0, "read_sz must be greater than 0");
        let mut encoder =
            HeatshrinkEncoder::new(window_sz2, lookahead_sz2).expect("Failed to create encoder");
        let mut compressed = Vec::new();
        let mut scratch = vec![0u8; read_sz * 2];

        for chunk in input.chunks(read_sz) {
            // `sink` may accept only part of the chunk; drain after every call.
            let mut pending = chunk;
            while !pending.is_empty() {
                match encoder.sink(pending) {
                    HSESinkRes::Ok(bytes_sunk) => pending = &pending[bytes_sunk..],
                    _ => unreachable!(),
                }
                drain_encoder(&mut encoder, &mut scratch, &mut compressed);
            }
        }

        loop {
            match encoder.finish() {
                HSEFinishRes::Done => break,
                HSEFinishRes::More => drain_encoder(&mut encoder, &mut scratch, &mut compressed),
                HSEFinishRes::ErrorNull => unreachable!(),
            }
        }

        compressed
    }

    /// Decompresses `input` in memory using a decoder with an `input_buffer_size`-byte input
    /// buffer, feeding it `read_sz` bytes at a time and polling into a `2 * read_sz` scratch
    /// buffer.
    ///
    /// Panics if `read_sz` is 0, `input_buffer_size` does not fit in a `u16`, the decoder cannot
    /// be created, or decoding fails.
    fn decode_all(
        input: &[u8],
        input_buffer_size: usize,
        window_sz2: u8,
        lookahead_sz2: u8,
        read_sz: usize,
    ) -> Vec<u8> {
        assert!(read_sz > 0, "read_sz must be greater than 0");
        assert!(
            input_buffer_size <= u16::MAX as usize,
            "input_buffer_size must fit in a u16"
        );
        let mut decoder =
            HeatshrinkDecoder::new(input_buffer_size as u16, window_sz2, lookahead_sz2)
                .expect("Failed to create decoder");
        let mut decompressed = Vec::new();
        let mut scratch = vec![0u8; read_sz * 2];

        for chunk in input.chunks(read_sz) {
            // `sink` may accept only part of the chunk; drain after every call.
            let mut pending = chunk;
            while !pending.is_empty() {
                match decoder.sink(pending) {
                    HSDSinkRes::Ok(bytes_sunk) => pending = &pending[bytes_sunk..],
                    _ => unreachable!(),
                }
                drain_decoder(&mut decoder, &mut scratch, &mut decompressed);
            }
        }

        loop {
            match decoder.finish() {
                HSDFinishRes::Done => break,
                HSDFinishRes::More => {
                    drain_decoder(&mut decoder, &mut scratch, &mut decompressed)
                }
                HSDFinishRes::ErrorNull => unreachable!(),
            }
        }

        decompressed
    }

    /// Encodes then decodes `input` with the given parameters, returning
    /// `(compressed, decompressed)`.
    fn roundtrip(
        input: &[u8],
        window_sz2: u8,
        lookahead_sz2: u8,
        in_read_sz: usize,
        out_read_sz: usize,
        out_buffer_sz: usize,
    ) -> (Vec<u8>, Vec<u8>) {
        let compressed = encode_all(input, window_sz2, lookahead_sz2, in_read_sz);
        let decompressed = decode_all(
            &compressed,
            out_buffer_sz,
            window_sz2,
            lookahead_sz2,
            out_read_sz,
        );
        (compressed, decompressed)
    }

    #[test]
    fn end2end_sanity_mock() {
        // 1000 bytes: every value in 0..100 repeated ten times.
        let input_data: Vec<u8> = (0..100).flat_map(|x| vec![x; 10]).collect();
        println!(
            "Input {} bytes: {:02X?}",
            input_data.len(),
            input_data.as_slice()
        );

        let compressed = encode_all(&input_data, 8, 4, 16);
        println!(
            "Wrote {} bytes: {:02X?}",
            compressed.len(),
            compressed.as_slice()
        );

        let decompressed = decode_all(&compressed, 100, 8, 4, 16);
        println!(
            "Read {} bytes: {:02X?}",
            decompressed.len(),
            decompressed.as_slice()
        );

        // Report the first differing byte, if any, before the length check; that gives a more
        // precise failure message than a whole-vector diff.
        for (i, (expected, actual)) in input_data.iter().zip(&decompressed).enumerate() {
            assert_eq!(expected, actual, "{}: {} == {}", i, expected, actual);
        }
        // Truncated or over-long output is a roundtrip failure too.
        assert_eq!(
            decompressed.len(),
            input_data.len(),
            "decompressed length differs from input length"
        );
    }

    /// Outcome of one parameter-sweep configuration, reported through its `Debug` output.
    #[derive(Debug, Clone, Copy)]
    // Most fields are only read via `Debug`, which the dead-code lint does not count as a use.
    #[allow(dead_code)]
    struct RoundtripConfig {
        window_sz2: u8,
        lookahead_sz2: u8,
        in_read_sz: usize,
        out_read_sz: usize,
        out_buffer_sz: usize,
        file_name: &'static str,
        compressed_size: usize,
        /// `input length / compressed length`.
        compression_ratio: f32,
        /// Average wall-clock time of one full encode + decode roundtrip, in microseconds.
        compression_time_us: usize,
    }

    #[test]
    fn end2end_sanity_param_sweep() {
        let text_data = include_bytes!("heatshrink_encoder.rs");
        let real_medium_size_data = include_bytes!("../tsz-compressed-data.bin");
        let inputs: Vec<(&'static str, &[u8])> = vec![
            ("heatshrink_encoder.rs", text_data),
            ("tsz-compressed-data.bin", real_medium_size_data),
        ];

        // Every (window, lookahead) pair whose lookahead is strictly smaller than the window.
        let window_lookahead_pairs = (HEATSHRINK_MIN_WINDOW_BITS..=HEATSHRINK_MAX_WINDOW_BITS)
            .flat_map(|window_sz2| {
                (HEATSHRINK_MIN_LOOKAHEAD_BITS..window_sz2)
                    .map(move |lookahead_sz2| (window_sz2, lookahead_sz2))
            });

        // Every combination of encoder/decoder chunk sizes, from byte-at-a-time to bulk.
        let read_buffer_sizes: [usize; 3] = [1, 2, 512];
        let read_size_pairs: Vec<(usize, usize)> = read_buffer_sizes
            .iter()
            .flat_map(|&read_sz| {
                read_buffer_sizes
                    .iter()
                    .map(move |&buf_sz| (read_sz, buf_sz))
            })
            .collect();

        let input_buffer_sizes: [usize; 3] = [1, 64, 8192];

        let mut configurations = vec![];
        for (window_sz2, lookahead_sz2) in window_lookahead_pairs {
            for &(in_read_sz, out_read_sz) in &read_size_pairs {
                for &out_buffer_sz in &input_buffer_sizes {
                    for input in &inputs {
                        configurations.push((
                            window_sz2,
                            lookahead_sz2,
                            in_read_sz,
                            out_read_sz,
                            out_buffer_sz,
                            input,
                        ));
                    }
                }
            }
        }

        println!("Running {} configurations", configurations.len());
        let sweep_start = Instant::now();

        let results: Vec<RoundtripConfig> = configurations
            .into_par_iter()
            .map(
                |(
                    window_sz2,
                    lookahead_sz2,
                    in_read_sz,
                    out_read_sz,
                    out_buffer_sz,
                    &(file_name, input),
                )| {
                    const ITERS: usize = 5;
                    let mut total_elapsed_us: u128 = 0;
                    let mut compressed_len: Option<usize> = None;
                    for _ in 0..ITERS {
                        let start = Instant::now();
                        let (compressed, decompressed) = roundtrip(
                            input,
                            window_sz2,
                            lookahead_sz2,
                            in_read_sz,
                            out_read_sz,
                            out_buffer_sz,
                        );
                        total_elapsed_us += start.elapsed().as_micros();
                        // Encoding is deterministic: every iteration must yield the same size.
                        let first_len = *compressed_len.get_or_insert(compressed.len());
                        assert_eq!(first_len, compressed.len());
                        assert_eq!(input, decompressed.as_slice());
                    }
                    let compressed_size = compressed_len.expect("ITERS is non-zero");
                    let config = RoundtripConfig {
                        window_sz2,
                        lookahead_sz2,
                        in_read_sz,
                        out_read_sz,
                        out_buffer_sz,
                        file_name,
                        compressed_size,
                        compression_ratio: input.len() as f32 / compressed_size as f32,
                        compression_time_us: (total_elapsed_us / ITERS as u128) as usize,
                    };
                    println!("{:?}", config);
                    config
                },
            )
            .collect();

        let mut results: Vec<RoundtripConfig> = results
            .into_iter()
            .filter(|r| r.file_name == "tsz-compressed-data.bin")
            .collect();
        // Guard the slicing below in case the sweep ever shrinks below REPORT_LEN entries.
        let report_len = REPORT_LEN.min(results.len());
        let tail_start = results.len() - report_len;

        results.sort_by(|a, b| {
            a.compression_ratio
                .partial_cmp(&b.compression_ratio)
                .expect("compression ratios are never NaN for non-empty inputs")
        });
        println!("Bottom {} compression ratios:", report_len);
        for result in &results[..report_len] {
            println!("WORST RATIO: {:?}", result);
        }
        println!("Top {} compression ratios:", report_len);
        for result in &results[tail_start..] {
            println!("BEST RATIO: {:?}", result);
        }

        // Ascending by time: the tail holds the slowest configurations, the head the fastest.
        results.sort_by_key(|r| r.compression_time_us);
        println!("Slowest {} compression times:", report_len);
        for result in &results[tail_start..] {
            println!("WORST TIME: {:?}", result);
        }
        println!("Fastest {} compression times:", report_len);
        for result in &results[..report_len] {
            println!("BEST TIME: {:?}", result);
        }

        println!("Completed permutations in {:?}", sweep_start.elapsed());
    }

    /// Runs the external `./fuzz.sh` script, passing `debug` when built with debug assertions.
    #[test]
    fn fuzz() {
        let mut command = std::process::Command::new("./fuzz.sh");
        if cfg!(debug_assertions) {
            command.arg("debug");
        }
        let status = command.status().expect("Fuzz failed");

        assert!(status.success());
    }
}