embedded_heatshrink/
lib.rs

//!
//! This is a fairly faithful port of the heatshrink library to Rust.
//!
//! The streaming C-style API for encoding and decoding is implemented
//! by the `heatshrink_encoder` and `heatshrink_decoder` modules. The
//! heatshrink index, which relies on dynamic allocation, is included on
//! the assumption that this library may be used in a `no_std` context
//! with `alloc` support.
//!
10#![cfg_attr(not(any(feature = "std", test)), no_std)]
11// #![cfg(not(test))]
12// extern crate alloc;
13
14#[cfg(feature = "std")]
15extern crate std;
16
17#[cfg(not(feature = "std"))]
18extern crate alloc;
19
20#[cfg(feature = "std")]
21use std::io::{Read, Write};
22
23pub(crate) mod common;
24pub mod heatshrink_decoder;
25pub mod heatshrink_encoder;
26
27pub use heatshrink_decoder::*;
28pub use heatshrink_encoder::*;
29
/// Smallest supported window exponent (`window_sz2`): a window of `2^4` = 16 bytes.
pub const HEATSHRINK_MIN_WINDOW_BITS: u8 = 4;
/// Largest supported window exponent (`window_sz2`): a window of `2^15` = 32768 bytes.
///
/// NOTE(review): the original author noted there "may be some strangeness with 15", although
/// it passes the crate's round-trip tests; treat this upper bound as less battle-tested.
pub const HEATSHRINK_MAX_WINDOW_BITS: u8 = 15;
/// Smallest supported lookahead exponent (`lookahead_sz2`): `2^3` = 8 bytes.
///
/// As in the C heatshrink library, the lookahead exponent is expected to be strictly smaller
/// than the window exponent.
pub const HEATSHRINK_MIN_LOOKAHEAD_BITS: u8 = 3;
34
/// Create an encoder, Read from stdin, Sink and Poll through the encoder, and Write polled bytes to stdout.
///
/// `window_sz2` and `lookahead_sz2` are passed unchanged to [`HeatshrinkEncoder::new`].
/// Input is read in chunks of up to 1024 bytes. After each chunk is sunk, the encoder is
/// polled until it reports `Empty`, so compressed bytes are written incrementally. Once the
/// reader is exhausted, the encoder is finished and drained. `stdout` is then flushed, so a
/// buffering writer reports write errors here rather than discarding them on drop.
///
/// If the reader yields no bytes at all, nothing is written (the encoder is never finished).
///
/// # Panics
///
/// Panics in any of these cases:
/// - the encoder cannot be created for the given parameters;
/// - reading from `stdin` or writing to or flushing `stdout` fails;
/// - the encoder reports an error while sinking or polling.
#[cfg(feature = "std")]
pub fn encode(window_sz2: u8, lookahead_sz2: u8, stdin: &mut impl Read, stdout: &mut impl Write) {
    let mut encoder =
        HeatshrinkEncoder::new(window_sz2, lookahead_sz2).expect("Failed to create encoder");

    const WORK_SIZE_UNIT: usize = 1024;
    let mut buf = [0; WORK_SIZE_UNIT];
    // Scratch space that polled output is copied into before being written.
    let mut scratch = [0; WORK_SIZE_UNIT * 2];

    // Sink all bytes from the input, draining the encoder after every sink
    let mut saw_input = false;
    loop {
        let read_len = read_in(stdin, &mut buf);
        if read_len == 0 {
            break;
        }
        saw_input = true;

        let mut read_data = &buf[..read_len];
        while !read_data.is_empty() {
            match encoder.sink(read_data) {
                HSESinkRes::Ok(bytes_sunk) => read_data = &read_data[bytes_sunk..],
                _ => unreachable!(),
            }

            loop {
                match encoder.poll(&mut scratch) {
                    HSEPollRes::Empty(sz) => {
                        write_out(stdout, &scratch[..sz]);
                        break;
                    }
                    HSEPollRes::More(sz) => write_out(stdout, &scratch[..sz]),
                    HSEPollRes::ErrorMisuse | HSEPollRes::ErrorNull => unreachable!(),
                }
            }
        }
    }

    // Only finish an encoder that actually received data, so empty input writes nothing.
    if saw_input {
        // Poll out the remaining bytes
        loop {
            match encoder.finish() {
                HSEFinishRes::Done => break,
                HSEFinishRes::More => {}
                HSEFinishRes::ErrorNull => unreachable!(),
            }

            loop {
                match encoder.poll(&mut scratch) {
                    HSEPollRes::Empty(sz) => {
                        write_out(stdout, &scratch[..sz]);
                        break;
                    }
                    HSEPollRes::More(sz) => write_out(stdout, &scratch[..sz]),
                    HSEPollRes::ErrorMisuse | HSEPollRes::ErrorNull => unreachable!(),
                }
            }
        }
    }

    // Surface buffered-write failures now instead of losing them when the writer is dropped.
    stdout.flush().expect("Failed to flush stdout");
}
106
/// Create a decoder, Read from stdin, Sink and Poll through the decoder, and Write polled bytes to stdout.
///
/// `window_sz2` and `lookahead_sz2` should match the parameters the stream was encoded with.
/// They are passed, along with a 1024-byte input buffer size, to [`HeatshrinkDecoder::new`].
/// Input is read in chunks of up to 1024 bytes. After each chunk is sunk, the decoder is
/// polled until it reports `Empty`, so decompressed bytes are written incrementally. Once the
/// reader is exhausted, the decoder is finished and drained. `stdout` is then flushed, so a
/// buffering writer reports write errors here rather than discarding them on drop.
///
/// If the reader yields no bytes at all, nothing is written (the decoder is never finished).
///
/// # Panics
///
/// Panics in any of these cases:
/// - the decoder cannot be created for the given parameters;
/// - reading from `stdin` or writing to or flushing `stdout` fails;
/// - polling reports `HSDPollRes::ErrorUnknown` (NOTE(review): presumably raised for
///   malformed input — confirm).
#[cfg(feature = "std")]
pub fn decode(window_sz2: u8, lookahead_sz2: u8, stdin: &mut impl Read, stdout: &mut impl Write) {
    const WORK_SIZE_UNIT: usize = 1024;

    let mut decoder = HeatshrinkDecoder::new(WORK_SIZE_UNIT as u16, window_sz2, lookahead_sz2)
        .expect("Failed to create decoder");
    let mut buf = [0; WORK_SIZE_UNIT];
    // Scratch space that polled output is copied into before being written.
    let mut scratch = [0; WORK_SIZE_UNIT * 2];

    // Sink all bytes from the input, draining the decoder after every sink
    let mut saw_input = false;
    loop {
        let read_len = read_in(stdin, &mut buf);
        if read_len == 0 {
            break;
        }
        saw_input = true;

        let mut read_data = &buf[..read_len];
        while !read_data.is_empty() {
            match decoder.sink(read_data) {
                HSDSinkRes::Ok(bytes_sunk) => read_data = &read_data[bytes_sunk..],
                _ => unreachable!(),
            }

            loop {
                match decoder.poll(&mut scratch) {
                    HSDPollRes::Empty(sz) => {
                        write_out(stdout, &scratch[..sz]);
                        break;
                    }
                    HSDPollRes::More(sz) => write_out(stdout, &scratch[..sz]),
                    HSDPollRes::ErrorNull => unreachable!(),
                    HSDPollRes::ErrorUnknown => {
                        panic!("Error: Unknown");
                    }
                }
            }
        }
    }

    // Only finish a decoder that actually received data, so empty input writes nothing.
    if saw_input {
        // Poll out the remaining bytes
        loop {
            match decoder.finish() {
                HSDFinishRes::Done => break,
                HSDFinishRes::More => {}
                HSDFinishRes::ErrorNull => unreachable!(),
            }

            loop {
                match decoder.poll(&mut scratch) {
                    HSDPollRes::Empty(sz) => {
                        write_out(stdout, &scratch[..sz]);
                        break;
                    }
                    HSDPollRes::More(sz) => write_out(stdout, &scratch[..sz]),
                    HSDPollRes::ErrorNull => unreachable!(),
                    HSDPollRes::ErrorUnknown => {
                        panic!("Error: Unknown");
                    }
                }
            }
        }
    }

    // Surface buffered-write failures now instead of losing them when the writer is dropped.
    stdout.flush().expect("Failed to flush stdout");
}
184
185#[cfg(feature = "std")]
186#[inline]
/// Read up to `buf.len()` bytes from `stdin` into `buf`, returning how many were read
/// (0 means end of input).
///
/// `ErrorKind::Interrupted` is retried, since the `Read` contract defines it as non-fatal.
///
/// # Panics
///
/// Panics on any other read error.
fn read_in(stdin: &mut impl Read, buf: &mut [u8]) -> usize {
    loop {
        match stdin.read(buf) {
            Ok(read_len) => return read_len,
            // A signal interrupted the read before any data arrived; simply try again.
            Err(err) if err.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(err) => panic!("Failed to read from stdin: {:?}", err),
        }
    }
}
190
191#[cfg(feature = "std")]
192#[inline]
/// Write every byte of `data` to `stdout`.
///
/// # Panics
///
/// Panics if the underlying writer reports an error (`write_all` already retries
/// interrupted writes internally).
fn write_out(stdout: &mut impl Write, data: &[u8]) {
    if let Err(err) = stdout.write_all(data) {
        panic!("Failed to write to stdout: {:?}", err);
    }
}
196
#[cfg(test)]
mod tests {
    use rayon::prelude::*;
    use std::time::Instant;

    use super::*;

    /// Compress `input` with a fresh encoder, sinking at most `read_sz` bytes per chunk and
    /// polling into a `2 * read_sz`-byte scratch buffer. Returns the whole compressed stream.
    fn encode_all(input: &[u8], window_sz2: u8, lookahead_sz2: u8, read_sz: usize) -> Vec<u8> {
        assert!(read_sz > 0, "read_sz must be greater than 0");
        let mut encoder =
            HeatshrinkEncoder::new(window_sz2, lookahead_sz2).expect("Failed to create encoder");
        let mut compressed = vec![];
        let mut scratch: Vec<u8> = vec![0; read_sz * 2];

        // Sink the input in `read_sz`-byte chunks, draining the encoder after every sink
        for chunk in input.chunks(read_sz) {
            let mut read_data = chunk;
            while !read_data.is_empty() {
                match encoder.sink(read_data) {
                    HSESinkRes::Ok(bytes_sunk) => read_data = &read_data[bytes_sunk..],
                    _ => unreachable!(),
                }

                loop {
                    match encoder.poll(&mut scratch) {
                        HSEPollRes::Empty(sz) => {
                            compressed.extend(&scratch[..sz]);
                            break;
                        }
                        HSEPollRes::More(sz) => compressed.extend(&scratch[..sz]),
                        HSEPollRes::ErrorMisuse | HSEPollRes::ErrorNull => unreachable!(),
                    }
                }
            }
        }

        // Poll out the remaining bytes
        loop {
            match encoder.finish() {
                HSEFinishRes::Done => break,
                HSEFinishRes::More => {}
                HSEFinishRes::ErrorNull => unreachable!(),
            }

            loop {
                match encoder.poll(&mut scratch) {
                    HSEPollRes::Empty(sz) => {
                        compressed.extend(&scratch[..sz]);
                        break;
                    }
                    HSEPollRes::More(sz) => compressed.extend(&scratch[..sz]),
                    HSEPollRes::ErrorMisuse | HSEPollRes::ErrorNull => unreachable!(),
                }
            }
        }

        compressed
    }

    /// Decompress `input` with a fresh decoder whose internal input buffer holds
    /// `input_buffer_size` bytes. Input is sunk in chunks of at most `read_sz` bytes and
    /// polled into a `2 * read_sz`-byte scratch buffer. Returns the whole decompressed stream.
    fn decode_all(
        input: &[u8],
        input_buffer_size: usize,
        window_sz2: u8,
        lookahead_sz2: u8,
        read_sz: usize,
    ) -> Vec<u8> {
        assert!(read_sz > 0, "read_sz must be greater than 0");
        // The decoder takes a u16 size; refuse to silently truncate larger values.
        assert!(
            input_buffer_size <= u16::MAX as usize,
            "input_buffer_size must fit in a u16"
        );
        let mut decoder =
            HeatshrinkDecoder::new(input_buffer_size as u16, window_sz2, lookahead_sz2)
                .expect("Failed to create decoder");
        let mut decompressed = vec![];
        let mut scratch: Vec<u8> = vec![0; read_sz * 2];

        // Sink the input in `read_sz`-byte chunks, draining the decoder after every sink
        for chunk in input.chunks(read_sz) {
            let mut read_data = chunk;
            while !read_data.is_empty() {
                match decoder.sink(read_data) {
                    HSDSinkRes::Ok(bytes_sunk) => read_data = &read_data[bytes_sunk..],
                    _ => unreachable!(),
                }

                loop {
                    match decoder.poll(&mut scratch) {
                        HSDPollRes::Empty(sz) => {
                            decompressed.extend(&scratch[..sz]);
                            break;
                        }
                        HSDPollRes::More(sz) => decompressed.extend(&scratch[..sz]),
                        HSDPollRes::ErrorNull => unreachable!(),
                        e => panic!("Failed to poll data: {:?}", e),
                    }
                }
            }
        }

        // Poll out the remaining bytes
        loop {
            match decoder.finish() {
                HSDFinishRes::Done => break,
                HSDFinishRes::More => {}
                HSDFinishRes::ErrorNull => unreachable!(),
            }

            loop {
                match decoder.poll(&mut scratch) {
                    HSDPollRes::Empty(sz) => {
                        decompressed.extend(&scratch[..sz]);
                        break;
                    }
                    HSDPollRes::More(sz) => decompressed.extend(&scratch[..sz]),
                    HSDPollRes::ErrorNull => unreachable!(),
                    e => panic!("Failed to poll data: {:?}", e),
                }
            }
        }

        decompressed
    }

    /// Compress `input` and decompress the result with the same window/lookahead parameters.
    /// Returns `(compressed, decompressed)`.
    fn roundtrip(
        input: &[u8],
        window_sz2: u8,
        lookahead_sz2: u8,
        in_read_sz: usize,
        out_read_sz: usize,
        out_buffer_sz: usize,
    ) -> (Vec<u8>, Vec<u8>) {
        let compressed = encode_all(input, window_sz2, lookahead_sz2, in_read_sz);
        let decompressed = decode_all(
            &compressed,
            out_buffer_sz,
            window_sz2,
            lookahead_sz2,
            out_read_sz,
        );
        (compressed, decompressed)
    }

    #[test]
    fn end2end_sanity_mock() {
        let input_data: Vec<u8> = (0..100).flat_map(|x| vec![x; 10]).collect();
        println!(
            "Input {} bytes: {:02X?}",
            input_data.len(),
            input_data.as_slice()
        );

        // Encode
        let compressed = encode_all(&input_data, 8, 4, 16);

        println!(
            "Wrote {} bytes: {:02X?}",
            compressed.len(),
            compressed.as_slice()
        );

        // Decode
        let decompressed = decode_all(&compressed, 100, 8, 4, 16);

        println!(
            "Read {} bytes: {:02X?}",
            decompressed.len(),
            decompressed.as_slice()
        );

        // Check: the round trip must reproduce the input exactly, with no missing or extra bytes
        assert_eq!(
            decompressed.len(),
            input_data.len(),
            "decompressed length differs from input length"
        );
        for (i, (expected, actual)) in input_data.iter().zip(&decompressed).enumerate() {
            assert_eq!(expected, actual, "{}: {} == {}", i, expected, actual);
        }
    }

    /// Configuration used to track the compression configurations
    #[derive(Debug, Clone, Copy)]
    #[allow(dead_code)] // fields are only read through the derived Debug impl
    struct RoundtripConfig {
        window_sz2: u8,
        lookahead_sz2: u8,
        in_read_sz: usize,
        out_read_sz: usize,
        out_buffer_sz: usize,
        file_name: &'static str,
        compressed_size: usize,
        compression_ratio: f32,
        compression_time_us: usize,
    }

    #[test]
    fn end2end_sanity_param_sweep() {
        // How many best/worst results to print for each metric
        const SHOWN: usize = 50;

        // Compress several different types of files from B to KB to MB
        let text_data = include_bytes!("heatshrink_encoder.rs");
        let real_medium_size_data = include_bytes!("../tsz-compressed-data.bin");
        let data: Vec<(&'static str, &[u8])> = vec![
            ("heatshrink_encoder.rs", text_data),
            ("tsz-compressed-data.bin", real_medium_size_data),
        ];

        // Use all possible window sizes, each with every lookahead smaller than the window
        let window_lookahead_pairs = (HEATSHRINK_MIN_WINDOW_BITS..=HEATSHRINK_MAX_WINDOW_BITS)
            .flat_map(|window_sz2| {
                (HEATSHRINK_MIN_LOOKAHEAD_BITS..window_sz2)
                    .map(move |lookahead_sz2| (window_sz2, lookahead_sz2))
            });

        // Use several different read and buffer sizes
        let read_buffer_sizes = [1, 2, 512];
        let read_size_pairs = read_buffer_sizes
            .iter()
            .flat_map(|&read_sz| {
                read_buffer_sizes
                    .iter()
                    .map(move |&buf_sz| (read_sz, buf_sz))
            })
            .collect::<Vec<_>>();

        // Use several different input buffer sizes to stress different code paths
        let input_buffer_sizes = [1, 64, 8192];

        // Enumerate every permutation up front so rayon can run them in parallel below
        let mut configurations = vec![];
        for (window_sz2, lookahead_sz2) in window_lookahead_pairs {
            for &(in_read_sz, out_read_sz) in &read_size_pairs {
                for &out_buffer_sz in &input_buffer_sizes {
                    for file in &data {
                        configurations.push((
                            window_sz2,
                            lookahead_sz2,
                            in_read_sz,
                            out_read_sz,
                            out_buffer_sz,
                            file,
                        ));
                    }
                }
            }
        }

        println!("Running {} configurations", configurations.len());
        let sweep_start = Instant::now();

        let results: Vec<RoundtripConfig> = configurations
            .into_par_iter()
            .map(
                |(window_sz2, lookahead_sz2, in_read_sz, out_read_sz, out_buffer_sz, data)| {
                    // Run the roundtrip configuration several times to get an average time
                    const ITERS: usize = 5;
                    let mut compression_ratio = 0.0;
                    let mut elapsed_us = 0;
                    let mut compressed_len = 0;
                    for i in 0..ITERS {
                        let started = Instant::now();
                        let (compressed, decompressed) = roundtrip(
                            data.1,
                            window_sz2,
                            lookahead_sz2,
                            in_read_sz,
                            out_read_sz,
                            out_buffer_sz,
                        );
                        elapsed_us += started.elapsed().as_micros();
                        compression_ratio = data.1.len() as f32 / compressed.len() as f32;
                        // Compression must be deterministic across repeated runs
                        if i == 0 {
                            compressed_len = compressed.len();
                        }
                        assert_eq!(compressed_len, compressed.len());
                        assert_eq!(data.1, decompressed.as_slice());
                    }
                    let config = RoundtripConfig {
                        window_sz2,
                        lookahead_sz2,
                        in_read_sz,
                        out_read_sz,
                        out_buffer_sz,
                        file_name: data.0,
                        compressed_size: compressed_len,
                        compression_ratio,
                        compression_time_us: elapsed_us as usize / ITERS,
                    };
                    println!("{:?}", config);
                    config
                },
            )
            .collect();

        // Only print out results for real data
        let mut results = results
            .into_iter()
            .filter(|r| r.file_name == "tsz-compressed-data.bin")
            .collect::<Vec<_>>();
        // Never slice past either end if the sweep produced fewer results than SHOWN
        let shown = SHOWN.min(results.len());

        // Print the worst and best compression ratios (ascending sort: worst first)
        results.sort_by(|a, b| {
            a.compression_ratio
                .partial_cmp(&b.compression_ratio)
                .expect("compression ratios of non-empty inputs are never NaN")
        });
        println!("Bottom compression ratios:");
        for result in &results[..shown] {
            println!("WORST RATIO: {:?}", result);
        }
        println!("Top compression ratios:");
        for result in &results[results.len() - shown..] {
            println!("BEST RATIO: {:?}", result);
        }

        // Print the slowest and fastest compression times (ascending sort: fastest first)
        results.sort_by_key(|r| r.compression_time_us);
        println!("Bottom compression times:");
        for result in &results[results.len() - shown..] {
            println!("WORST TIME: {:?}", result);
        }
        println!("Top compression times:");
        for result in &results[..shown] {
            println!("BEST TIME: {:?}", result);
        }

        println!("Completed permutations in {:?}", sweep_start.elapsed());
    }

    #[test]
    fn fuzz() {
        // Fuzzing is implemented by ./fuzz.sh. It receives a "debug" argument when this test
        // is built with debug assertions, and must exit with status 0.
        let mut command = std::process::Command::new("./fuzz.sh");
        if cfg!(debug_assertions) {
            command.arg("debug");
        }
        let status = command.status().expect("Fuzz failed");

        assert!(status.success());
    }
}