embedded_huffman/
lib.rs

#![no_std]

//! This is a library for streaming sensor data through an Encoder that
//! writes its frequency table and Huffman coding to NAND pages on a sensor.
//! The Encoder includes some metadata to allow flushing bytes directly to
//! NAND, including page headers every N bytes to indicate how many bits are
//! encoded in the page. The Decoder knows how to read these pages, inflate
//! the Huffman coding from the frequency table, and decode the sensor data.
//!
//! The Encoder and Decoder use the previous 1MiB of frequencies
//! to build the frequency table used for the next 1MiB of data.
//!
//! This allows the sensor data to change over time, and the
//! encoder and decoder to adapt to changes in the sensor data.
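//!
//! A minimal usage sketch for the encoding side (illustrative only; the
//! `write_nand_page` function, `sensor_samples` iterator, and `NandError`
//! type are placeholders for whatever the target provides):
//!
//! ```ignore
//! let flush: WritePageFutureFn<NandError> = Box::new(|page| {
//!     Box::pin(async move { write_nand_page(page).await })
//! });
//! let mut writer = BufferedPageWriter::new(2048, flush);
//! let mut encoder = Encoder::new(2048, 64);
//! for byte in sensor_samples() {
//!     encoder.sink(byte, &mut writer).await?;
//! }
//! encoder.flush(&mut writer).await?;
//! ```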

extern crate alloc;
#[cfg(any(feature = "std", test))]
extern crate std;

use alloc::boxed::Box;
use alloc::collections::VecDeque;
use alloc::vec::Vec;
use core::future::Future;
use core::pin::Pin;
use core::ptr::copy_nonoverlapping;

mod tree;
use tree::*;

/// The Huffman code is built from bytes, so the symbol count is 2^8
const WORD_SIZE: usize = 8;
const SYMBOL_COUNT: usize = 1 << WORD_SIZE;

// Single-bit masks indexed by a 0/1 bit value, precomputed so the hot
// bit-packing loops can avoid shifting at runtime
static PRESHIFTED7: [u8; 2] = [0b0000_0000, 0b1000_0000];
static PRESHIFTED6: [u8; 2] = [0b0000_0000, 0b0100_0000];
static PRESHIFTED5: [u8; 2] = [0b0000_0000, 0b0010_0000];
static PRESHIFTED4: [u8; 2] = [0b0000_0000, 0b0001_0000];
static PRESHIFTED3: [u8; 2] = [0b0000_0000, 0b0000_1000];
static PRESHIFTED2: [u8; 2] = [0b0000_0000, 0b0000_0100];
static PRESHIFTED1: [u8; 2] = [0b0000_0000, 0b0000_0010];
static PRESHIFTED0: [u8; 2] = [0b0000_0000, 0b0000_0001];
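
// For example (sketch): packing the code bits [1, 0, 1, 1, 0, 0, 1, 0]
// MSB-first with these tables gives
//     PRESHIFTED7[1] | PRESHIFTED6[0] | PRESHIFTED5[1] | PRESHIFTED4[1]
//   | PRESHIFTED3[0] | PRESHIFTED2[0] | PRESHIFTED1[1] | PRESHIFTED0[0]
//   = 0b1011_0010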

/// A huffman encoder that writes successive tables and data to pages
pub struct Encoder {
    page_size: usize,
    page_threshold: usize,
    page_threshold_limit: usize,
    page_count: usize,
    state: EncodeState,
    word_batch: Vec<u8>,
    weights: [u32; SYMBOL_COUNT],
    code_table: Vec<CodeEntry>,
    visit_deque: VecDeque<(Node, Vec<usize>)>,
    done: bool,
    #[cfg(feature = "ratio")]
    bytes_in: usize,
    #[cfg(feature = "ratio")]
    bytes_out: usize,
}

/// The state while sinking bytes into the encoder
#[derive(Debug, Copy, Clone)]
enum EncodeState {
    /// The encoder is gathering enough data for the initial table
    Init,
    /// The encoder is writing the frequency table to NAND.
    /// The huffman tree is built from the frequencies.
    Table,
    /// The encoder is writing encoded data to NAND.
    /// The data is encoded with the code table derived from the most recent frequency table.
    Data,
    /// The encoder is in an error state. This is terminal.
    Error,
}
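
// State flow (sketch of what `Encoder::crank` below implements): the encoder
// starts in Init and gathers one raw page of bytes, then alternates
//
//     Init -> Table -> Data -> Data -> ... -> Table -> Data -> ...
//
// writing a fresh frequency table every `page_threshold` data pages (the
// threshold starts at 1 and doubles until it reaches `page_threshold_limit`).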

impl Encoder {
    /// Create an encoder that writes into pages of `page_size` bytes
    /// and rebuilds the frequency table every `page_threshold` pages.
    pub fn new(page_size: usize, page_threshold: usize) -> Encoder {
        Encoder {
            page_size,
            page_threshold: 1, // starts at 1 and doubles until page_threshold_limit
            page_threshold_limit: page_threshold,
            page_count: 0,
            state: EncodeState::Init,
            word_batch: Vec::with_capacity(page_size),
            weights: [1; SYMBOL_COUNT], // always assume each symbol is present at least once
            code_table: Vec::with_capacity(SYMBOL_COUNT),
            visit_deque: VecDeque::with_capacity(SYMBOL_COUNT * 2 - 1),
            done: false,
            #[cfg(feature = "ratio")]
            bytes_in: 0,
            #[cfg(feature = "ratio")]
            bytes_out: 0,
        }
    }

    /// Prepare the encoder for a new round of encoding.
    /// This keeps allocations, so it is cheaper than an `Encoder::new` call.
    pub fn reset(&mut self) {
        self.page_count = 0;
        self.page_threshold = 1;
        self.state = EncodeState::Init;
        self.word_batch.clear();
        self.weights.fill(1);
        self.code_table.clear();
        self.done = false;
        #[cfg(feature = "ratio")]
        {
            self.bytes_in = 0;
            self.bytes_out = 0;
        }
    }

    /// Get the ratio of bytes in to bytes out
    #[cfg(feature = "ratio")]
    pub fn ratio(&self) -> f32 {
        self.bytes_in as f32 / self.bytes_out as f32
    }

    #[inline(always)]
    /// Check whether `bytes` more bytes fit in the current batch without filling the page
    pub fn batch_fits(&self, bytes: usize) -> bool {
        self.word_batch.len() + bytes < self.page_size
    }

    #[inline(always)]
    /// Sink a batch of bytes without checking whether the batch fits
    ///
    /// # Safety
    /// The caller must ensure the batch fits (see [`Encoder::batch_fits`]),
    /// otherwise this writes past the capacity of the internal word batch.
    pub unsafe fn batch_sink(&mut self, bytes: &[u8]) {
        // Optimized for Cortex-M4 where there is no branch prediction
        #[cfg(feature = "ratio")]
        {
            self.bytes_in += bytes.len();
        }
        // SAFETY:
        // weights is 256 elements long and is indexed by a u8 [0..256]
        // word_batch has capacity for a page and the caller guarantees the batch fits
        // we can't use get_unchecked_mut because we have uninitialized memory
        let weights_ptr = self.weights.as_mut_ptr();
        let word_batch_ptr = self.word_batch.as_mut_ptr();
        for byte in bytes {
            *weights_ptr.add(*byte as usize) += 1;
        }
        let idx = self.word_batch.len();
        let dst = word_batch_ptr.add(idx);
        let src = bytes.as_ptr();
        copy_nonoverlapping(src, dst, bytes.len());
        self.word_batch.set_len(idx + bytes.len());
    }
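
    // Example (sketch) of the intended batch pattern: check `batch_fits`
    // first and fall back to the per-byte `sink` path when the chunk would
    // overflow the current batch.
    //
    //     if encoder.batch_fits(chunk.len()) {
    //         // SAFETY: just checked that the chunk fits
    //         unsafe { encoder.batch_sink(chunk) };
    //     } else {
    //         for &b in chunk {
    //             encoder.sink(b, &mut writer).await?;
    //         }
    //     }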

    /// Put a byte into the encoder.
    /// Returns `true` once the writer reports that no more pages can be written.
    #[inline(always)]
    pub async fn sink<E>(&mut self, byte: u8, output: &mut impl PageWriter<E>) -> Result<bool, E> {
        self.crank(output, Some(byte)).await
    }

    /// Finish encoding and flush any remaining bytes
    pub async fn flush<E>(&mut self, output: &mut impl PageWriter<E>) -> Result<bool, E> {
        self.crank(output, None).await
    }

    #[inline(always)]
    async fn crank<E>(
        &mut self,
        output: &mut impl PageWriter<E>,
        byte: Option<u8>,
    ) -> Result<bool, E> {
        let finish = if let Some(byte) = byte {
            // Optimized for Cortex-M4 where there is no branch prediction
            #[cfg(feature = "ratio")]
            {
                self.bytes_in += 1;
            }
            // SAFETY:
            // weights is 256 elements long and is indexed by a u8 [0..256]
            // word_batch stays below page_size (its capacity) because the
            // state machine below drains it whenever it reaches page_size
            // we can't use get_unchecked_mut because we have uninitialized memory
            unsafe {
                let weights_ptr = self.weights.as_mut_ptr();
                let word_batch_ptr = self.word_batch.as_mut_ptr();
                *weights_ptr.add(byte as usize) += 1;
                let idx = self.word_batch.len();
                *word_batch_ptr.add(idx) = byte;
                self.word_batch.set_len(idx + 1);
            }

            // Hot path for encoding data
            if matches!(self.state, EncodeState::Data) & (self.word_batch.len() < self.page_size) {
                return Ok(self.done);
            }

            false
        } else {
            true
        };

        loop {
            match self.state {
                EncodeState::Error => {
                    unreachable!()
                }
                EncodeState::Init => {
                    if self.word_batch.len() == self.page_size || finish {
                        self.state = EncodeState::Table;
                        output.reset();
                    } else {
                        return Ok(self.done);
                    }
                }
                EncodeState::Table => {
                    // The table is SYMBOL_COUNT * 4 bytes long, [0 -> count, 1 -> count, 2 -> count, ...]
                    for &weight in &self.weights {
                        output.write_u32le(weight);
                    }

                    // Write out bytes
                    #[cfg(feature = "ratio")]
                    {
                        self.bytes_out += self.page_size;
                    }
                    match output.flush().await {
                        Ok(done) => {
                            self.done |= done;
                        }
                        Err(e) => {
                            self.state = EncodeState::Error;
                            return Err(e);
                        }
                    }

                    // Build the Huffman tree from the symbol frequencies
                    let root = build_tree(&self.weights);

                    // Create the bit representation of the tree
                    self.code_table.clear();
                    self.code_table.resize(SYMBOL_COUNT, Default::default());
                    build_code_table(root, &mut self.code_table, &mut self.visit_deque);

                    // Always assume each symbol is present at least once
                    self.weights.fill(1);

                    // Ready to start encoding data
                    self.state = EncodeState::Data;
                }
                EncodeState::Data => {
                    // Batching is only done for performance reasons ... hypothetically
                    if (self.word_batch.len() < self.page_size) & !finish {
                        return Ok(self.done);
                    }

                    // Write all of the bits in the current batch using the current code table
                    let mut drain = None;
                    for (idx, byte) in self.word_batch.iter().enumerate() {
                        let code = unsafe { self.code_table.get_unchecked(*byte as usize) };

                        // If the word does not fit in the current page, then we need to advance pages
                        if !output.write_code(code) {
                            drain = Some(idx);
                            break;
                        }
                    }

                    if let Some(drain) = drain {
                        // Remove words that are emitted
                        self.word_batch.drain(..drain);

                        // Move to the next page, writing out this page
                        #[cfg(feature = "ratio")]
                        {
                            self.bytes_out += self.page_size;
                        }
                        match output.flush().await {
                            Ok(done) => {
                                self.done |= done;
                            }
                            Err(e) => {
                                self.state = EncodeState::Error;
                                return Err(e);
                            }
                        }

                        // Update the page count and check if we need to rebuild the table
                        self.page_count += 1;
                        if self.page_count > self.page_threshold {
                            self.page_count = 0;
                            self.page_threshold =
                                self.page_threshold_limit.min(self.page_threshold * 2);
                            self.state = EncodeState::Table;
                        } else {
                            self.state = EncodeState::Data;
                        }
                    } else {
                        self.word_batch.clear();

                        // We are done once all of the words have been sunk, so emit the final page
                        if finish {
                            #[cfg(feature = "ratio")]
                            {
                                self.bytes_out += self.page_size;
                            }
                            return output.flush().await.and_then(|done| {
                                self.done |= done;
                                Ok(self.done)
                            });
                        }
                    }
                }
            }
        }
    }
}

/// A writer that packs bits and bytes into fixed-size pages and flushes them to NAND
#[allow(async_fn_in_trait)]
pub trait PageWriter<E> {
    /// Flush the current page and start a new one.
    /// Returns `true` once the backing store reports that no more pages can be written.
    async fn flush(&mut self) -> Result<bool, E>;

    /// Current bit position in the page
    fn position(&self) -> usize;
    /// Reset the page, reserving room for the page header
    fn reset(&mut self);
    /// Go back and fill in the page header
    fn write_header(&mut self, header: u32);
    /// Append a `u32` to the page as little-endian bytes
    fn write_u32le(&mut self, value: u32);
    /// Append the bits for a code entry, returning `false` if it does not fit in the current page
    fn write_code(&mut self, code: &CodeEntry) -> bool;
}

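/// Page layout produced by this writer, as consumed by [`Decoder`]:
///
/// ```text
/// bytes 0..4   u32 little-endian header: number of valid bits in the page,
///              including the 32 header bits themselves
/// bytes 4..    either the 256 x u32 little-endian frequency table (table pages)
///              or Huffman-coded data, packed MSB-first within each byte (data pages)
/// ```
///
/// A page whose first four bytes are all `0xFF` (an erased NAND page) is read
/// by the [`Decoder`] as the end of the stream.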
pub struct BufferedPageWriter<E> {
    /// Position in the current page in bits
    bits_written: usize,
    /// Pre-calculated page size in bits
    page_size: usize,
    /// Pending bits stored as 0/1 `usize` values (used to index the PRESHIFTED tables)
    /// that still need to be packed into bytes
    bits: Vec<usize>,
    /// A page of bytes that need to be flushed to NAND
    bytes: Vec<u8>,
    /// A function that takes a ref to the page and writes it to NAND
    flush_page: WritePageFutureFn<E>,
    /// True once the backing store reports that no more pages can be written
    done: bool,
}

/// A function that takes a reference to the page and writes it to NAND.
/// The future returns `true` if there are more pages that could be written.
pub type WritePageFutureFn<E> =
    Box<dyn for<'a> Fn(&'a [u8]) -> Pin<Box<dyn Future<Output = Result<bool, E>> + 'a>>>;

impl<E> BufferedPageWriter<E> {
    pub fn new(page_size: usize, flush: WritePageFutureFn<E>) -> BufferedPageWriter<E> {
        // Allocate a buffer for the page.
        // The tail of a page may be flushed without ever being written; the
        // header records how many bits are valid, so the decoder ignores it.
        let mut buf: Vec<u8> = Vec::with_capacity(page_size);
        unsafe { buf.set_len(page_size) };
        BufferedPageWriter {
            page_size: 8 * page_size,
            bits: Vec::with_capacity(8 * 2),
            bytes: buf,
            bits_written: 0,
            flush_page: flush,
            done: false,
        }
    }
}

impl<E> PageWriter<E> for BufferedPageWriter<E> {
    /// current bit position in the page
    #[inline(always)]
    fn position(&self) -> usize {
        self.bits_written
    }

    /// reset the page, reserving 32 bits for the 4-byte page header
    fn reset(&mut self) {
        self.bits_written = 32;
    }

    /// go back and fill in the number of bits in the page
    #[inline(always)]
    fn write_header(&mut self, header: u32) {
        let bits_written_header = header.to_le_bytes();
        // Safe version:
        // self.bytes[..bits_written_header.len()].copy_from_slice(&bits_written_header);
        unsafe {
            copy_nonoverlapping(
                bits_written_header.as_ptr(),
                self.bytes.as_mut_ptr(),
                bits_written_header.len(),
            );
        }
    }

    /// append a u32 value to the page
    /// this will copy as a block of bytes ignoring pending bits
    #[inline(always)]
    fn write_u32le(&mut self, value: u32) {
        debug_assert!(self.bits.is_empty());
        let bytes = value.to_le_bytes();
        let offset = self.bits_written / 8;
        // Safe version:
        // self.bytes[offset..offset + bytes.len()].copy_from_slice(&bytes);
        unsafe {
            copy_nonoverlapping(
                bytes.as_ptr(),
                self.bytes.as_mut_ptr().add(offset),
                bytes.len(),
            );
        }
        self.bits_written += bytes.len() * 8;
    }

    /// write the symbols for an entry if there is enough room in the current page
    #[inline(always)]
    fn write_code(&mut self, code: &CodeEntry) -> bool {
        // Do not overwrite the page
        let position = self.position();
        let pending = self.bits.len();
        debug_assert!(pending < 8); // expecting less than a byte pending otherwise it should have been flushed
        if position + pending + code.bits.len() > self.page_size {
            // Flush any pending bits as this page is full
            if pending > 0 {
                // Pad the bits with zeros
                let padding = 8 - pending;
                self.bits.extend((0..padding).map(|_| 0));

                // Convert the bits to bytes
                for byte_bits in self.bits.chunks_exact(8) {
                    let byte_bits: &[usize; 8] = unsafe { byte_bits.try_into().unwrap_unchecked() };
                    let byte = unsafe {
                        PRESHIFTED7.get_unchecked(byte_bits[0])
                            | PRESHIFTED6.get_unchecked(byte_bits[1])
                            | PRESHIFTED5.get_unchecked(byte_bits[2])
                            | PRESHIFTED4.get_unchecked(byte_bits[3])
                            | PRESHIFTED3.get_unchecked(byte_bits[4])
                            | PRESHIFTED2.get_unchecked(byte_bits[5])
                            | PRESHIFTED1.get_unchecked(byte_bits[6])
                            | PRESHIFTED0.get_unchecked(byte_bits[7])
                    };
                    let offset = self.bits_written / 8;
                    #[cfg(test)]
                    {
                        self.bytes[offset] = byte;
                    }
                    #[cfg(not(test))]
                    {
                        unsafe {
                            *self.bytes.get_unchecked_mut(offset) = byte;
                        }
                    }
                    self.bits_written += 8;
                }

                // Bookkeeping
                self.bits_written -= padding;
                self.bits.clear();
            }

            return false;
        }

        // Extend bits into the page
        self.bits.extend(code.bits.iter());

        // Convert the bits to bytes
        let mut drained = 0;
        for byte_bits in self.bits.chunks_exact(8) {
            drained += 8;
            let byte_bits: &[usize; 8] = unsafe { byte_bits.try_into().unwrap_unchecked() };
            let byte = unsafe {
                PRESHIFTED7.get_unchecked(byte_bits[0])
                    | PRESHIFTED6.get_unchecked(byte_bits[1])
                    | PRESHIFTED5.get_unchecked(byte_bits[2])
                    | PRESHIFTED4.get_unchecked(byte_bits[3])
                    | PRESHIFTED3.get_unchecked(byte_bits[4])
                    | PRESHIFTED2.get_unchecked(byte_bits[5])
                    | PRESHIFTED1.get_unchecked(byte_bits[6])
                    | PRESHIFTED0.get_unchecked(byte_bits[7])
            };
            let offset = self.bits_written / 8;
            #[cfg(test)]
            {
                self.bytes[offset] = byte;
            }
            #[cfg(not(test))]
            {
                unsafe {
                    *self.bytes.get_unchecked_mut(offset) = byte;
                }
            }
            self.bits_written += 8;
        }

        // Remove emitted bits, leaving bits buffered
        self.bits.drain(..drained);

        true
    }

    /// flush the current page to nand and reset the page
    async fn flush(&mut self) -> Result<bool, E> {
        // Flush the remaining bits
        debug_assert!(self.bits.len() < 8); // expecting less than a byte pending otherwise it should have been flushed
        if !self.bits.is_empty() {
            // Pad the bits with zeros
            let padding = 8 - self.bits.len();
            self.bits.extend((0..padding).map(|_| 0));

            // Convert the bits to bytes
            for byte_bits in self.bits.chunks_exact(8) {
                let byte_bits: &[usize; 8] = unsafe { byte_bits.try_into().unwrap_unchecked() };
                let byte = unsafe {
                    PRESHIFTED7.get_unchecked(byte_bits[0])
                        | PRESHIFTED6.get_unchecked(byte_bits[1])
                        | PRESHIFTED5.get_unchecked(byte_bits[2])
                        | PRESHIFTED4.get_unchecked(byte_bits[3])
                        | PRESHIFTED3.get_unchecked(byte_bits[4])
                        | PRESHIFTED2.get_unchecked(byte_bits[5])
                        | PRESHIFTED1.get_unchecked(byte_bits[6])
                        | PRESHIFTED0.get_unchecked(byte_bits[7])
                };
                let offset = self.bits_written / 8;
                #[cfg(test)]
                {
                    self.bytes[offset] = byte;
                }
                #[cfg(not(test))]
                {
                    unsafe {
                        *self.bytes.get_unchecked_mut(offset) = byte;
                    }
                }
                self.bits_written += 8;
            }

            // Bookkeeping
            self.bits_written -= padding;
            self.bits.clear();
        }

        // The reader is going to look at the header to know how many bits to decode
        self.write_header(self.bits_written as u32);

        // Flush the bytes to NAND
        self.done |= !(*self.flush_page)(&self.bytes).await?;

        // Start the next page on a clean slate
        self.reset();

        Ok(self.done)
    }
}

/// A huffman decoder that reads successive tables and data from pages
pub struct Decoder {
    page_size: usize,
    page_threshold: usize,
    page_threshold_limit: usize,
    page_count: usize,
    state: DecodeState,
    decoder_trie: Option<CodeLookupTrie>,
    decoded_bytes: Vec<u8>,
    emitted_idx: usize,
}

/// The state while draining bytes from the decoder
#[derive(Debug, Copy, Clone)]
enum DecodeState {
    /// The decoder is reading the frequency table from NAND
    Table,
    /// The decoder is reading the encoded data from NAND
    Data,
    /// There are no more pages to read
    Done,
    /// There was an error in the encoder or malformed data
    Error,
}

impl Decoder {
    /// Create a decoder that reads pages of `page_size` bytes.
    /// Every `page_threshold` pages, the decoder will rebuild the huffman tree.
    pub fn new(page_size: usize, page_threshold: usize) -> Decoder {
        Decoder {
            page_size,
            page_threshold: 1,
            page_threshold_limit: page_threshold,
            page_count: 0,
            state: DecodeState::Table,
            decoder_trie: None,
            decoded_bytes: Vec::with_capacity(page_size),
            emitted_idx: 0,
        }
    }

    /// Prepare the decoder for a new round of decoding.
    /// This keeps allocations, so it is cheaper than a `Decoder::new` call.
    pub fn reset(&mut self) {
        self.page_count = 0;
        self.page_threshold = 1;
        self.state = DecodeState::Table;
        self.decoder_trie = None;
        self.decoded_bytes.clear();
        self.emitted_idx = 0;
    }

    /// Drain a byte from the decoder
    pub async fn drain<E>(
        &mut self,
        input: &mut impl PageReader<E>,
    ) -> Result<Option<u8>, DecompressionError<E>> {
        // If there are already decoded bytes in the buffer, return them
        if self.emitted_idx < self.decoded_bytes.len() {
            let byte = if cfg!(test) {
                self.decoded_bytes[self.emitted_idx]
            } else {
                unsafe { *self.decoded_bytes.get_unchecked(self.emitted_idx) }
            };
            self.emitted_idx += 1;
            return Ok(Some(byte));
        }

        loop {
            match self.state {
                DecodeState::Done => {
                    return Ok(None);
                }
                DecodeState::Error => {
                    return Err(DecompressionError::Bad);
                }
                DecodeState::Table => {
                    // Read the page and check if this is the last page
                    let page = input.read_page().await?;
                    if page[..4] == [0xFF; 4] {
                        self.state = DecodeState::Done;
                        return Ok(None);
                    }

                    // Memcpy the weights from the page into the weights array.
                    // The table is stored little-endian, so this byte copy assumes a little-endian host.
                    let mut weights = [0u32; SYMBOL_COUNT];
                    debug_assert!(page.len() >= SYMBOL_COUNT * 4 + 4); // +4 for the header
                    unsafe {
                        let weights_sz = SYMBOL_COUNT * 4; // u32s
                        let page_weights_ptr = page.as_ptr().add(4); // skip the header
                        let weights_ptr = weights.as_mut_ptr() as *mut u8;
                        core::ptr::copy_nonoverlapping(page_weights_ptr, weights_ptr, weights_sz);
                    }

                    // Build the tree from the weights
                    let root = build_tree(&weights);
                    self.decoder_trie = Some(CodeLookupTrie::new(root));

                    // Ready to start decoding data
                    self.state = DecodeState::Data;
                }
                DecodeState::Data => {
                    // Update bookkeeping
                    self.emitted_idx = 0;
                    self.decoded_bytes.clear();

                    // Read the page and check if this is the last page
                    let page = input.read_page().await?;
                    if page[..4] == [0xFF; 4] {
                        self.state = DecodeState::Done;
                        return Ok(None);
                    }

                    // Push page bits through the trie to get symbols
                    let symbol_lookup = self.decoder_trie.as_mut().unwrap();
                    let bits_written = u32::from_le_bytes(page[..4].try_into().unwrap());

                    // The number of bits written to a page must be valid
                    if !(32..=self.page_size * 8).contains(&(bits_written as usize)) {
                        self.state = DecodeState::Error;
                        return Err(DecompressionError::Bad);
                    }

                    let bytes_written = ((bits_written + 7) / 8) as usize;
                    let page_bytes = &page[4..bytes_written];

                    if !page_bytes.is_empty() {
                        let mut bits_read = 32;
                        let full_bytes = page_bytes.len() - 1;
                        for &byte in &page_bytes[..full_bytes] {
                            for i in (0..8).rev() {
                                let bit = (byte >> i) & 1;
                                if let Some(symbol) = symbol_lookup.next(bit) {
                                    self.decoded_bytes.push(symbol);
                                }
                            }
                        }
                        bits_read += full_bytes as u32 * 8;

                        // Process final byte which may be partial
                        if let Some(&last_byte) = page_bytes.last() {
                            let remaining_bits = (bits_written - bits_read) as usize;
                            for i in (0..8).rev().take(remaining_bits) {
                                let bit = (last_byte >> i) & 1;
                                if let Some(symbol) = symbol_lookup.next(bit) {
                                    self.decoded_bytes.push(symbol);
                                }
                            }
                        }
                    }

                    // If we read enough pages, the tree will be on the next page
                    self.page_count += 1;
                    if self.page_count > self.page_threshold {
                        self.page_count = 0;
                        self.page_threshold =
                            self.page_threshold_limit.min(self.page_threshold * 2);
                        self.state = DecodeState::Table;
                    } else {
                        self.state = DecodeState::Data;
                    }

                    // Emit the first byte from this page
                    // If there are already decoded bytes in the buffer, return them
                    if self.emitted_idx < self.decoded_bytes.len() {
                        let byte = if cfg!(test) {
                            self.decoded_bytes[self.emitted_idx]
                        } else {
                            unsafe { *self.decoded_bytes.get_unchecked(self.emitted_idx) }
                        };
                        self.emitted_idx += 1;
                        return Ok(Some(byte));
                    }
                }
            }
        }
    }
}
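
// Example (sketch): draining decoded bytes until the stream is exhausted.
// `read_nand_page`, `handle_sample`, and `NandError` are placeholders for
// whatever the target provides.
//
//     let read: ReadPageFutureFn<NandError> = Box::new(|page| {
//         Box::pin(async move { read_nand_page(page).await })
//     });
//     let mut reader = BufferedPageReader::new(2048, read);
//     let mut decoder = Decoder::new(2048, 64);
//     while let Some(byte) = decoder.drain(&mut reader).await? {
//         handle_sample(byte);
//     }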

/// A reader that reads pages from NAND
#[allow(async_fn_in_trait)]
pub trait PageReader<E> {
    /// Fetch the next page of bytes and provide a reference to them
    async fn read_page(&mut self) -> Result<&[u8], E>;
    /// Reset the reader to start a new round of reading pages
    fn reset(&mut self);
}

pub struct BufferedPageReader<E> {
    /// A page of bytes that was loaded from NAND
    bytes: Vec<u8>,
    /// A function that fills a buffer with a page from NAND
    read_page: ReadPageFutureFn<E>,
    /// True once the backing store reports there are no more pages to read
    done: bool,
}

/// A function that takes a mutable reference to the page and fills it with bytes from NAND
/// The future returns true if there are more pages that could be read
pub type ReadPageFutureFn<E> =
    Box<dyn for<'a> Fn(&'a mut [u8]) -> Pin<Box<dyn Future<Output = Result<bool, E>> + 'a>>>;

impl<E> BufferedPageReader<E> {
    pub fn new(page_size: usize, read_page: ReadPageFutureFn<E>) -> BufferedPageReader<E> {
        // Allocate a page buffer; it is filled by `read_page` (or with 0xFF once
        // the reader is done) before any bytes are read from it
        let mut bytes = Vec::with_capacity(page_size);
        unsafe { bytes.set_len(page_size) };
        BufferedPageReader {
            bytes,
            read_page,
            done: false,
        }
    }
}

impl<E> PageReader<E> for BufferedPageReader<E> {
    /// Fetch page bytes from NAND and provide a reference to them
    async fn read_page(&mut self) -> Result<&[u8], E> {
        if self.done {
            self.bytes.fill(0xFF);
            return Ok(&self.bytes);
        }
        self.done |= !(*self.read_page)(&mut self.bytes).await?;
        Ok(&self.bytes)
    }

    /// Reset the reader to start a new round of reading pages
    fn reset(&mut self) {
        self.done = false;
    }
}

/// Decompression errors can be malformed data or the error from the FutureFn
#[derive(Debug, Clone, Copy)]
pub enum DecompressionError<E> {
    /// The data is malformed, or `drain` was called again after a previous `Bad`
    Bad,
    /// The error from the FutureFn
    Load(E),
}
impl<E> From<E> for DecompressionError<E> {
    fn from(err: E) -> Self {
        DecompressionError::Load(err)
    }
}
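
// Example (sketch): distinguishing malformed data from a transport error
// when draining (`reader` is any PageReader implementation).
//
//     match decoder.drain(&mut reader).await {
//         Ok(Some(byte)) => { /* consume the decoded byte */ }
//         Ok(None) => { /* end of the stream */ }
//         Err(DecompressionError::Bad) => { /* corrupt data; stop decoding */ }
//         Err(DecompressionError::Load(e)) => { /* propagate the NAND error */ }
//     }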

#[cfg(test)]
mod tests {
    use super::*;
    use core::cell::RefCell;
    use std::prelude::v1::*;
    use std::rc::Rc;
    use std::vec;
    use std::vec::Vec;

    #[test]
    fn test_std_vec() {
        let mut vec = Vec::new();
        vec.push(1);
        vec.push(2);
        vec.push(3);
        assert_eq!(vec, vec![1, 2, 3]);
    }

    #[test]
    fn test_flush_fn() {
        let mut buf = [1, 3, 5, 7];
        let flush: WritePageFutureFn<()> = Box::new(|page: &[u8]| {
            Box::pin(async move {
                std::dbg!("flush", page.len());
                Ok(true)
            })
        });

        smol::block_on(async {
            (*flush)(&mut buf).await.unwrap();
            assert_eq!(buf, [1, 3, 5, 7]);
        });
    }

    #[test]
    fn test_page_writer_advance() {
        let flush: WritePageFutureFn<()> = Box::new(|page| {
            Box::pin(async move {
                std::dbg!("flush", page.len());
                Ok(true)
            })
        });
        let mut wtr = BufferedPageWriter::new(2048, flush);
        smol::block_on(async {
            wtr.flush().await.unwrap();
        });
    }

    #[test]
    fn test_compress_simple() {
        let flush: WritePageFutureFn<()> = Box::new(|page| {
            Box::pin(async move {
                std::dbg!("flush", page.len());
                Ok(true)
            })
        });
        let mut wtr = BufferedPageWriter::new(2048, flush);
        let mut encoder = Encoder::new(2048, 4);

        smol::block_on(async {
            for value in 0..2048 {
                encoder.sink(value as u8, &mut wtr).await.unwrap();
            }
        });
    }

    #[test]
    fn test_compress_multi_page() {
        let flush: WritePageFutureFn<()> = Box::new(|_page| Box::pin(async move { Ok(true) }));
        let mut wtr = BufferedPageWriter::new(2048, flush);
        let mut encoder = Encoder::new(2048, 4);

        smol::block_on(async {
            for value in 0..2048 * 3 {
                encoder.sink(value as u8, &mut wtr).await.unwrap();
            }
            encoder.flush(&mut wtr).await.unwrap();
        });

        #[cfg(feature = "ratio")]
        {
            std::dbg!(
                encoder.bytes_in,
                encoder.bytes_out,
                encoder.bytes_in as f32 / encoder.bytes_out as f32
            );
        }
    }

    #[test]
    fn test_roundtrip() {
        let buf: Vec<u8> = Vec::new();
        let buf = Rc::new(RefCell::new(buf));
        let wtr_buf = buf.clone();
        let rdr_buf = buf.clone();
        let flush_page: WritePageFutureFn<()> = Box::new(move |page| {
            let buf = wtr_buf.clone();
            Box::pin(async move {
                let mut buf = buf.borrow_mut();
                buf.extend_from_slice(page);
                Ok(true)
            })
        });
        const PAGE_SIZE: usize = 2048;
        const PAGE_THRESHOLD: usize = 4;
        let mut wtr = BufferedPageWriter::new(PAGE_SIZE, flush_page);
        let mut encoder = Encoder::new(PAGE_SIZE, PAGE_THRESHOLD);
        let read_page: ReadPageFutureFn<()> = Box::new(move |page| {
            let buf = rdr_buf.clone();
            Box::pin(async move {
                let mut buf = buf.borrow_mut();
                assert!(buf.len() % PAGE_SIZE == 0);
                if buf.is_empty() {
                    page.fill(0xFF);
                    Ok(false)
                } else {
                    let drained = buf.drain(..PAGE_SIZE);
                    page[..drained.len()]
                        .iter_mut()
                        .zip(drained)
                        .for_each(|(p, b)| *p = b);
                    Ok(true)
                }
            })
        });
        let mut rdr = BufferedPageReader::new(PAGE_SIZE, read_page);
        let mut decoder = Decoder::new(PAGE_SIZE, PAGE_THRESHOLD);

        // We need to test
        // * no data
        // * less than a page
        // * exactly a page
        // * multiple pages
        // * exactly the page threshold
        // * multiple tables
        // * highly compressible data

        let bad_rand = (0..100)
            .map(|i| vec![0; i])
            .collect::<Vec<_>>()
            .into_iter()
            .map(|v| {
                let ptr = v.as_ptr();
                (ptr, v.len())
            })
            .fold(0, |acc, (ptr, len)| acc + ptr as usize * len * 31)
            % 9999991
            + 5123457;
        std::dbg!(bad_rand);

        let test_cases: Vec<Vec<u8>> = vec![
            vec![],
            (0..10).collect::<Vec<_>>(),
            (0..2048).map(|i| i as u8).collect::<Vec<_>>(),
            (0..2048 * 3).map(|i| i as u8).collect::<Vec<_>>(),
            (0..2048 * 4).map(|i| i as u8).collect::<Vec<_>>(),
            (0..1024 * 1024).map(|i| i as u8).collect::<Vec<_>>(),
            (0..bad_rand).map(|i| (31 * i) as u8).collect::<Vec<_>>(),
            (0..bad_rand)
                .map(|i| ((31 * i) % 16) as u8)
                .collect::<Vec<_>>(),
        ];

        #[cfg(feature = "ratio")]
        let mut compression_ratios = Vec::new();
        for (test_case, test_data) in test_cases.into_iter().enumerate() {
            std::dbg!(test_case);

            // Reset the encoder, writer, and decoder
            buf.borrow_mut().clear();
            encoder.reset();
            wtr.reset();
            decoder.reset();
            rdr.reset();

            // Write bytes to the encoder
            smol::block_on(async {
                for value in &test_data {
                    encoder.sink(*value, &mut wtr).await.unwrap();
                }
                encoder.flush(&mut wtr).await.unwrap();
            });

            std::dbg!(buf.borrow().len());
            let num_pages = (buf.borrow().len() + PAGE_SIZE - 1) / PAGE_SIZE;
            for page in 0..num_pages {
                let header_offset = page * PAGE_SIZE;
                let header = u32::from_le_bytes(
                    buf.borrow()[header_offset..header_offset + 4]
                        .try_into()
                        .unwrap(),
                );
                std::dbg!(header);
            }

            #[cfg(feature = "ratio")]
            {
                compression_ratios.push((
                    encoder.bytes_in as f32 / encoder.bytes_out as f32,
                    humanize_bytes::humanize_bytes_binary!(encoder.bytes_in),
                    humanize_bytes::humanize_bytes_binary!(encoder.bytes_out),
                ));
            }

            // Read bytes from the decoder
            smol::block_on(async {
                let mut idx = 0;
                while let Some(byte) = decoder.drain(&mut rdr).await.unwrap() {
                    assert_eq!(
                        byte, test_data[idx],
                        "test case {} byte {} mismatch",
                        test_case, idx
                    );
                    idx += 1;
                }
                assert_eq!(idx, test_data.len());
            });
        }

        #[cfg(feature = "ratio")]
        {
            std::dbg!(compression_ratios);
        }
    }
}