Skip to main content

uefi_decompress/
lib.rs

1#![no_std]
2use bitvec::{field::BitField, order::Msb0, slice::BitSlice, view::BitView};
3
/// Decompress Error Definitions
///
/// Errors reported by the decompression entry points in this crate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecompressError {
    /// The source buffer is too small for its header or for the compressed size declared in that header.
    InvalidSrcSize,
    /// The destination buffer length does not match the decompressed size declared in the source header.
    InvalidDstSize,
    /// The compressed bitstream is truncated or otherwise cannot be decoded.
    MalformedSrcData,
}

impl core::fmt::Display for DecompressError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let msg = match self {
            DecompressError::InvalidSrcSize => "compressed source buffer has an invalid size",
            DecompressError::InvalidDstSize => "destination buffer size does not match the decompressed size",
            DecompressError::MalformedSrcData => "compressed source data is malformed",
        };
        f.write_str(msg)
    }
}
11
/// Supported Decompression Algorithms
///
/// In this implementation the two algorithms share the same bitstream format and differ only in the bit width of the
/// field that gives the size of the Position Set code-length array (4 bits for UEFI, 5 bits for Tiano).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecompressionAlgorithm {
    /// UEFI decompression (4-bit Position Set size field).
    UefiDecompress,
    /// Tiano decompression (5-bit Position Set size field).
    TianoDecompress,
}
18
19/// Decompress the compressed data in `src` and store the output in `dst`, using the `algo` decompression algorithm.
20pub fn decompress_into_with_algo(
21    src: &[u8],
22    dst: &mut [u8],
23    algo: DecompressionAlgorithm,
24) -> Result<(), DecompressError> {
25    //sanity check the inputs
26    if src.len() < 8 {
27        Err(DecompressError::InvalidSrcSize)?;
28    }
29
30    let compressed_size = u32::from_le_bytes(src[0..4].try_into().unwrap()) as usize;
31    if compressed_size > src.len() {
32        Err(DecompressError::InvalidSrcSize)?;
33    }
34
35    let orig_size = u32::from_le_bytes(src[4..8].try_into().unwrap()) as usize;
36    if orig_size == 0 {
37        return Ok(());
38    }
39    if orig_size != dst.len() {
40        Err(DecompressError::InvalidDstSize)?;
41    }
42
43    //Create a code iterator that iterates through the `src` bitstream and returns `CodeSymbol` elements.
44    let mut dst_idx = 0;
45    for result in CodeIterator::new(&src[8..], algo) {
46        match result {
47            Ok(symbol) => match symbol {
48                CodeSymbol::OrigChar(char) => {
49                    // symbol is an original character literal - copy it directly to the output buffer.
50                    dst[dst_idx] = char;
51                    dst_idx += 1;
52                }
53                CodeSymbol::StrPointer(offset, len) => {
54                    // symbol is offset:len pair to be copied from a previously decompressed portion of the buffer.
55                    let start = dst_idx
56                        .checked_sub(offset)
57                        .and_then(|x| x.checked_sub(1))
58                        .ok_or(DecompressError::MalformedSrcData)?;
59
60                    // note: this loop is used (instead of e.g. slice::copy_within or slice::copy_non_overlapping)
61                    // because the offset:len window may overlap the current position. The "new" byte from the
62                    // overlapping region needs to be copied instead of the original byte that existed at the start of
63                    // the copy, which makes copy_within semantics inappropriate here.
64                    for src in start..start + len {
65                        dst[dst_idx] = dst[src];
66                        dst_idx += 1;
67                        if dst_idx == dst.len() {
68                            break;
69                        }
70                    }
71                }
72            },
73            //CodeIterator encountered an error trying to produce the next symbol - return it to caller.
74            Err(err) => Err(err)?,
75        }
76
77        // Decompression is complete.
78        if dst_idx == dst.len() {
79            break;
80        }
81    }
82    Ok(())
83}
84
/// One decoded unit of the compressed bitstream.
enum CodeSymbol {
    /// A literal byte that is written to the output unchanged.
    OrigChar(u8),
    /// A back-reference `(offset, len)`: copy `len` bytes starting `offset + 1` bytes before the current output
    /// position.
    StrPointer(usize, usize),
}
89
// Set names used throughout this file: 'C' = Char&Len set, 'P' = Position set, 'T' = Extra set.

// Char&Len ('C') set: symbol count, width of its header fields, and width of its direct-lookup decode table.
const NC: usize = 510;
const CBIT: usize = 9;
const CTABLE_BITSIZE: usize = 12;

// Extra ('T') set: symbol count and width of its length-array size field.
const NT: usize = 19;
const TBIT: usize = 5;
// Width of the direct-lookup decode table shared by the Extra ('T') and Position ('P') sets.
const PTABLE_BITSIZE: usize = 8;

// Upper bound on the Position ('P') set size; the size read at runtime depends on the selected algorithm.
const MAXNP: usize = 31;

// Capacity of the code-length array shared by the 'T' and 'P' sets: the larger of NT and MAXNP.
const NPT: usize = if NT > MAXNP { NT } else { MAXNP };
106
107struct CodeIterator<'a> {
108    src: &'a BitSlice<u8, Msb0>,
109    src_index: usize,
110    is_error: bool,
111    remaining_block_size: usize,
112    left: [u16; 2 * NC - 1],
113    right: [u16; 2 * NC - 1],
114    c_len: [u8; NC],
115    pt_len: [u8; NPT],
116    c_table: [u16; 1 << CTABLE_BITSIZE],
117    pt_table: [u16; 1 << PTABLE_BITSIZE],
118    p_bit: usize,
119}
120
121impl<'a> CodeIterator<'a> {
122    // initialize a new CodeIterator instance for the given source and algorithm
123    fn new(src: &'a [u8], algo: DecompressionAlgorithm) -> Self {
124        Self {
125            src: src.view_bits::<Msb0>(),
126            src_index: 0,
127            is_error: false,
128            remaining_block_size: 0,
129            left: [0u16; 2 * NC - 1],
130            right: [0u16; 2 * NC - 1],
131            c_len: [0u8; NC],
132            pt_len: [0u8; NPT],
133            c_table: [0u16; 4096],
134            pt_table: [0u16; 256],
135            p_bit: match algo {
136                DecompressionAlgorithm::UefiDecompress => 4,
137                DecompressionAlgorithm::TianoDecompress => 5,
138            },
139        }
140    }
141
142    // advances the source bitstream by `count` bits.
143    fn pop_bits(&mut self, count: usize) -> Result<&BitSlice<u8, Msb0>, DecompressError> {
144        if let Some(bitslice) = self.src.get(self.src_index..self.src_index + count) {
145            self.src_index += count;
146            Ok(bitslice)
147        } else {
148            Err(DecompressError::MalformedSrcData)
149        }
150    }
151
152    // returns the next `count` bits of the source bitstream without advancing it.
153    fn peek_bits(&self, count: usize) -> Result<&BitSlice<u8, Msb0>, DecompressError> {
154        if let Some(bitslice) = self.src.get(self.src_index..self.src_index + count) {
155            Ok(bitslice)
156        } else {
157            Err(DecompressError::MalformedSrcData)
158        }
159    }
160
161    // Reads the code lengths for the Extra Set or Position Set Huffman codes for the current block.
162    //
163    // The code lengths are preceded by a `num_bits`-sized field that gives the length of the array.
164    //
165    // This is then followed by an encoded set of lengths which use a variable number of bits:
166    // - If the code length is less than 7, it is encoded as a 3-bit binary number.
167    // - If the code length is 7 or greater, it is encoded as a series of '1b' followed by a terminating '0b'.
168    //   The code length is therefore equal to "count of 1s" + 4.
169    //   Example: "4" is coded as '100b', "7" is coded as '1110b', and "12" is coded as `111111110b`
170    //
171    // If the 'extra' flag is set, then after the third length element in the bitstream, there is a 2-bit field
172    // indicating the number of additional zero lengths that follow. For example, the following array of lengths
173    // [2,9,0,0,5,7] would be encoded with the following bit stream (num_bits size field not shown).
174    // 010 111110 10 101 1110
175    //            ^
176    //            this is the `extra` field added to generate the 2 "zero" lengths
177    // If the extra flag is not set, the same array of lengths would be encoded with the following bitstream
178    // 010 111110 000 000 101 1110
179    //
180    // The resulting code length array will be stored in self.pt_len.
181    //
182    // Once the code length array is generated, it is fed to the the Self::build_huffman_table() routine
183    // to generate the resulting Huffman code table, which will be stored in self.pt_table.
184    //
185    // Refer to UEFI Specification 2.10, section 19.2.3.1.
186    //
187    fn read_pt_len(&mut self, num_symbols: usize, num_bits: usize, extra: bool) -> Result<(), DecompressError> {
188        assert!(num_symbols <= NPT);
189
190        // Read Set Length Array size
191        let count = self.pop_bits(num_bits)?.load_be::<usize>();
192        if count == 0 {
193            // this represents the only Huffman code used.
194            let char_c = self.pop_bits(num_bits)?.load_be::<u16>();
195            self.pt_table.fill(char_c);
196            self.pt_len[..num_symbols].fill(0);
197            Ok(())
198        } else {
199            let mut idx = 0;
200            while idx < count && idx < NPT {
201                // if a code length is less than 7, it is encoded as 3-bit value. Otherwise it is encoded by a series of
202                // 1s followed by a terminating zero. The number of 1s = code length - 4.
203                let mut code_len = self.pop_bits(3)?.load_be::<u8>();
204                if code_len == 7 {
205                    loop {
206                        let bit = self.pop_bits(1)?[0];
207                        if bit {
208                            //current bit is one.
209                            code_len += 1;
210                        } else {
211                            break;
212                        }
213                    }
214                }
215                self.pt_len[idx] = code_len;
216                idx += 1;
217
218                // if 'extra' is set, then after the third length of the code length concatenation, a 2-bit value is
219                // used to indicate the number of consecutive zero lengths immediately after the third length.
220                if extra && idx == 3 {
221                    let zero_count = self.pop_bits(2)?.load_be::<usize>();
222                    self.pt_len[idx..idx + zero_count].fill(0);
223                    idx += zero_count;
224                }
225            }
226            if idx > num_symbols {
227                Err(DecompressError::MalformedSrcData)?;
228            }
229            // zero the rest of the table.
230            self.pt_len[idx..num_symbols].fill(0);
231
232            //convert the resulting code length array (self.pt_len) into a Huffman coding table (self.pt_table)
233            Self::build_huffman_table(
234                num_symbols,
235                &self.pt_len,
236                PTABLE_BITSIZE,
237                &mut self.pt_table,
238                &mut self.left,
239                &mut self.right,
240            )
241        }
242    }
243
244    // Read the code lengths for the Char&Length set Huffman code for the current block.
245    //
246    // The code lengths are preceded by a 9-bit field that gives the length of the array.
247    //
248    // This is then followed by an encoded set of lengths which use a variable number of bits. The set of lengths is
249    // double-encoded:
250    //
251    //  1: If a code length is not zero, then it is encoded as "code length + 2";
252    //     If a code length is zero, then the number of consecutive zero lengths starting from this code length is
253    //     counted:
254    //    - if the count is equal to or less than 2, then the code "0" is used for each zero length;
255    //    - if the count is greater than 2 and less than 19, then the code "1" followed by a 4-bit value of "count - 3"
256    //      is used for these consecutive zero lengths;
257    //    - if the count is equal to 19, then it is treated as "1 + 18," and a code "0" and a code "1" followed by a
258    //      4-bit value of "15" are used for these consecutive zero lengths;
259    //    - if the count is greater than 19, then the code "2" followed by a 9-bit value of "count - 20" is used for
260    //      these consecutive zero lengths.
261    //  2: The resulting bitstring symbols are the "extra set", and are encoded using Huffman coding. The tables derived
262    //     from execution of the read_pt_len() function on the extra set can be used to decode these symbols.
263    //
264    // To decode the table, the above process is reversed. First, the Huffman coded "extra set" symbols are decoded,
265    // then the resulting symbols are converted into a code length by reversing the step 1 above.
266    //
267    // The resulting code length array will be stored in self.c_len.
268    //
269    // Once the code length array is generated, it is fed to the the Self::build_huffman_table() routine
270    // to generate the resulting Huffman code table, which will be stored in self.c_table.
271    //
272    // Refer to UEFI Specification 2.10, section 19.2.3.1.
273    //
274    // NOTE: this routine requires that the current contents of self.pt_len, self.pt_table, self.left, and self.right
275    // are initialized to match the "Extra Set" by executing read_pt_len() to decode the Extra Set Code Length Array.
276    //
277    fn read_c_len(&mut self) -> Result<(), DecompressError> {
278        // Read Set Length Array Size
279        let count = self.pop_bits(CBIT)?.load_be::<usize>();
280
281        if count == 0 {
282            // this represents the only Huffman code used
283            let symbol = self.pop_bits(CBIT)?.load_be::<u16>();
284            self.c_len.fill(0);
285            self.c_table.fill(symbol);
286            Ok(())
287        } else {
288            // iterate over all the symbols in the array.
289            let mut idx = 0;
290            while idx < count {
291                // read the next symbol. First, read the first PTABLE_BITSIZE bits of the symbol.
292                let mut symbol = self.pt_table[self.peek_bits(PTABLE_BITSIZE)?.load_be::<usize>()];
293                // if the symbol is less than NT, then it can be used as-is
294                if symbol as usize >= NT {
295                    // symbol is larger than NT. Read bits from the stream and traverse the left/right tree until a leaf
296                    // node (less than NT) is reached.
297                    let mut mask_idx = PTABLE_BITSIZE;
298                    loop {
299                        let bit_buff = self.peek_bits(mask_idx + 1)?;
300                        if bit_buff[mask_idx] {
301                            symbol = self.right[symbol as usize];
302                        } else {
303                            symbol = self.left[symbol as usize];
304                        }
305                        mask_idx += 1;
306                        if (symbol as usize) < NT {
307                            break;
308                        }
309                    }
310                }
311
312                //now that we know the symbol, advance the bitstream by the symbol bitlength.
313                self.pop_bits(self.pt_len[symbol as usize] as usize)?;
314
315                if symbol <= 2 {
316                    // if the symbol is 2 or less, it encodes 1 or more zero length symbols
317                    if symbol == 0 {
318                        // a single zero length
319                        symbol = 1;
320                    } else if symbol == 1 {
321                        // '1' followed by a 4-bit value of count - 3 zero lengths follow.
322                        symbol = self.pop_bits(4)?.load_be::<u16>() + 3;
323                    } else if symbol == 2 {
324                        // '2' followed by a 9-bit value of count - 20 zero lengths follow.
325                        symbol = self.pop_bits(CBIT)?.load_be::<u16>() + 20;
326                    }
327
328                    //"symbol" now contains the consecutive number of zero-length symbols starting at the current idx.
329                    //update the c_len table entries corresponding to these symbols and advance the index.
330                    for _ in 0..symbol {
331                        if idx >= self.c_len.len() {
332                            Err(DecompressError::MalformedSrcData)?;
333                        }
334                        self.c_len[idx] = 0;
335                        idx += 1;
336                    }
337                } else {
338                    // otherwise, the symbol encodes 'code length +2'. store it in c_len and advance the index.
339                    if idx >= self.c_len.len() {
340                        Err(DecompressError::MalformedSrcData)?;
341                    }
342                    self.c_len[idx] = (symbol - 2) as u8;
343                    idx += 1;
344                }
345            }
346            // all valid symbols processed, zero the rest of c_len.
347            self.c_len[idx..NC].fill(0);
348
349            //convert the resulting code length array (self.c_len) into a Huffman coding table (self.c_table)
350            Self::build_huffman_table(
351                NC,
352                &self.c_len,
353                CTABLE_BITSIZE,
354                &mut self.c_table,
355                &mut self.left,
356                &mut self.right,
357            )
358        }
359    }
360
361    // Decodes a "position" value from the current bitstream according to the Position Set encoding.
362    //
363    // A String Position is a value that indicates the distance between the current position and the target string. The
364    // String Position value is defined as "Current Position - Starting Position of the target string - 1." The String
365    // Position value ranges from 0 to 8190 (so 8192 is the "sliding window" size, and this range should be ensured by
366    // the compressor). The lengths of the String Position values (in binary form) form a value set ranging from 0 to 13
367    // (it is assumed that value 0 has length of 0). This value set is the Position Set for Huffman Coding. The full
368    // representation of a String Position value is composed of two consecutive parts: one is the Huffman code for the
369    // value length; the other is the actual String Position value of "length - 1" bits (excluding the highest bit since
370    // the highest bit is always "1"). For example, String Position value 18 is represented as: Huffman code for "5"
371    // followed by "0010." If the value length is 0 or 1, then no value is appended to the Huffman code.
372    //
373    // NOTE: this routine requires that the current contents of self.pt_len, self.pt_table, self.left, and self.right
374    // are initialized to match the "Position Set" by executing read_pt_len() to decode the Position Set Code Length
375    // Array.
376    fn decode_position(&mut self) -> Result<usize, DecompressError> {
377        //First, read the first PTABLE_BITSIZE bits of the position symbol.
378        let bit_buffer = self.peek_bits(PTABLE_BITSIZE)?;
379        let mut val = self.pt_table[bit_buffer.load_be::<usize>()] as usize;
380
381        // if the symbol is less than NT, then it can be used as-is
382        if val >= MAXNP {
383            // symbol is larger than NT. Read bits from the stream and traverse the left/right tree until a leaf
384            // node (less than NT) is reached.
385            let mut mask_idx = PTABLE_BITSIZE;
386            loop {
387                let bit_buffer = self.peek_bits(mask_idx + 1)?;
388                if bit_buffer[mask_idx] {
389                    val = self.right[val] as usize;
390                } else {
391                    val = self.left[val] as usize;
392                }
393
394                mask_idx += 1;
395
396                if val < MAXNP {
397                    break;
398                }
399            }
400        }
401        self.pop_bits(self.pt_len[val] as usize)?;
402
403        // if val is <= 1, then it directly encodes the position
404        if val > 1 {
405            // otherwise, (val - 1) encodes the bit length of an integer that encodes the position.
406            val = (1 << (val - 1)) + self.pop_bits(val - 1)?.load_be::<usize>();
407        }
408
409        Ok(val)
410    }
411
412    // Constructs a Huffman decode table + tree.
413    //
414    // input parameters:
415    // num_symbols: number of symbols in the Huffman symbol set
416    // bit_lengths: a table describing the code length for each symbol (indexed by the symbol)
417    // table_bits: the number of bits to be used for fixed symbol lookup. Symbols with an encoded bitlength longer than
418    //             this parameter will require traversing the secondary tree to fully decode.
419    //
420    //  modifies:
421    //  table: the fixed decode table (see description below)
422    //  left: the "left" nodes of the secondary decoder tree.
423    //  right: the right" nodes of the secondary decoder tree.
424    //
425    // This routine takes as input the bit_lengths table representing the canonical Huffman encoding over the output
426    // symbols. It then generates 3 different table structures in the slices given as input:
427    // - table: this table consists of two sets of entries.
428    //    - fixed lookup entries - this consists of fixed entries for all symbols where the length of the encoded
429    //      bitstring is less than or equal to the table_bits. For a given symbol, all entries that have that symbol as
430    //      a prefix are set to the decoded value of the symbol. For example, assume that the bitstring `100b` is the
431    //      encoded representation of the value 0xB - in that case, all of the entries of the table that start with
432    //      `100xxxxxxxxxb` (i.e. indexes 0x800 to 0x9FF) would be set to 0xB.
433    //    - tree lookup root entry - if the length of the encoded symbol is longer than the table bits, then the unique
434    //      prefix of that entry points to the index of the root of a secondary decode tree encoded in the left & right
435    //      array structures. "Leaf" elements of the tree occupy the first `num_symbol` entries in the left and right
436    //      arrays, and correspond to literal final symbols. "Node" elements of the tree occupy the entries higher than
437    //      `num_symbol` in the left and and right arrays and point to other nodes or leaves.
438    //
439    //      To decode the final symbol for an encoded bitstring that is longer than table_size bits, first locate the
440    //      locate the entry within the table that corresponds to the root index in the left/right trees. Then, starting
441    //      with the bit immediately following the first table_size bits of the encoded symbol, read bits from the
442    //      encoded symbol. For each bit, if it is a 1, retrieve the next index from the `right` array, otherwise if it
443    //      is a 0, retrieve the next index from the `left`. If the retrieved index is less than `num_symbol`, then it
444    //      is the final decoded symbol. Otherwise, it is the index into the left or right tree for the next bit.
445    //
446    //      Note: if all possible symbols can be encoded within the fixed table width, then the secondary lookup is not
447    //      needed.
448    //
449    // - left & right - the secondary decode tree as described above.
450    //
451    // Note: This implementation shares the "left & right" tables between the Char&Len symbol Set decode and the
452    // Position Set decode; the portions of left & right used by each decode are disjoint. Care is taken to ensure that
453    // constructing a table only modifies left & right indices associated with that table.
454    fn build_huffman_table(
455        num_symbols: usize,
456        bit_lengths: &[u8],
457        table_bits: usize,
458        table: &mut [u16],
459        left: &mut [u16],
460        right: &mut [u16],
461    ) -> Result<(), DecompressError> {
462        assert!(table_bits <= 16);
463
464        // calculate the number of symbols for each bit length.
465        let mut count = [0u16; 17];
466        for idx in 0..num_symbols {
467            if bit_lengths[idx] > 16 {
468                Err(DecompressError::MalformedSrcData)?;
469            }
470            count[bit_lengths[idx] as usize] += 1;
471        }
472
473        // Determine the start index for each bit length. This determines the start index within the fixed size decode
474        // table for all symbols of a given bit length.
475        let mut start = [0u16; 18];
476        for idx in 1..=16 {
477            let word_of_start = start[idx];
478            let word_of_count = count[idx] << (16 - idx);
479            start[idx + 1] = word_of_start.wrapping_add(word_of_count);
480        }
481        if start[17] != 0 {
482            Err(DecompressError::MalformedSrcData)?;
483        }
484
485        // extended_bits is the number bits in the symbol exceeding the bit length for fixed entries in the table.
486        let extended_bits = 16 - table_bits;
487
488        // Determine weight of each length (the number of entries that a given symbol length will consume in the table).
489        let mut weight = [0; 17];
490        for idx in 1..=table_bits {
491            start[idx] >>= extended_bits;
492            weight[idx] = 1 << (table_bits - idx);
493        }
494
495        for (idx, w) in weight.iter_mut().enumerate().skip(table_bits + 1) {
496            *w = 1 << (16 - idx)
497        }
498
499        // zero unused table entries.
500        let idx = start[table_bits + 1] >> extended_bits;
501        if idx != 0 {
502            let idx_3 = 1 << table_bits;
503            if idx < idx_3 {
504                table[idx as usize..idx_3 as usize].fill(0);
505            }
506        }
507
508        // Private helper structure used in the implementation below to simplify construction of the secondary tree.
509        enum TablePointer {
510            Table(usize),
511            Left(usize),
512            Right(usize),
513        }
514        impl TablePointer {
515            fn set(&self, table: &mut [u16], left: &mut [u16], right: &mut [u16], val: u16) {
516                match self {
517                    TablePointer::Table(idx) => table[*idx] = val,
518                    TablePointer::Left(idx) => left[*idx] = val,
519                    TablePointer::Right(idx) => right[*idx] = val,
520                }
521            }
522
523            fn get(&self, table: &mut [u16], left: &mut [u16], right: &mut [u16]) -> u16 {
524                match self {
525                    TablePointer::Table(idx) => table[*idx],
526                    TablePointer::Left(idx) => left[*idx],
527                    TablePointer::Right(idx) => right[*idx],
528                }
529            }
530        }
531
532        // tracks the next available node
533        let mut next_avail_node = num_symbols;
534        // mask used to check the bit for left vs. right construction
535        let mask = 1 << (15 - table_bits);
536
537        // iterate over all symbols in the alphabet to generate the table.
538        for (char, sym_bit_len) in bit_lengths.iter().enumerate().take(num_symbols) {
539            let sym_bit_len = *sym_bit_len as usize;
540
541            // if the symbol length is zero, it is unused.
542            if sym_bit_len == 0 {
543                continue;
544            }
545
546            // max symbol length is fixed at 16 by spec, so encountering a larger symbol length is an error.
547            if sym_bit_len > 16 {
548                Err(DecompressError::MalformedSrcData)?;
549            }
550
551            // get the next code.
552            let next_code = start[sym_bit_len].wrapping_add(weight[sym_bit_len]);
553
554            if sym_bit_len <= table_bits {
555                // the symbol is short enough that tree construction is not needed.
556
557                // verify start and next sanity.
558                if start[sym_bit_len] >= next_code || next_code > 1 << table_bits {
559                    Err(DecompressError::MalformedSrcData)?;
560                }
561
562                // fill in all the elements in the table for which this symbol is a prefix.
563                for idx in start[sym_bit_len]..next_code {
564                    table[idx as usize] = char.try_into().expect("symbol count too large");
565                }
566            } else {
567                // the symbol is long enough that tree construction is required.
568                let mut symbol_bitstring = start[sym_bit_len];
569                let mut pointer = TablePointer::Table((symbol_bitstring >> extended_bits) as usize);
570                let mut idx = sym_bit_len - table_bits;
571
572                // traverse the tree using the extended bits in the symbol bitstring to select nodes
573                while idx != 0 {
574                    if pointer.get(table, left, right) == 0 && next_avail_node < (2 * NC - 1) {
575                        pointer.set(table, left, right, next_avail_node.try_into().expect("symbol count too large"));
576                        right[next_avail_node] = 0;
577                        left[next_avail_node] = 0;
578                        next_avail_node += 1;
579                    }
580
581                    if pointer.get(table, left, right) < (2 * NC - 1) as u16 {
582                        if symbol_bitstring & mask != 0 {
583                            pointer = TablePointer::Right(pointer.get(table, left, right) as usize);
584                        } else {
585                            pointer = TablePointer::Left(pointer.get(table, left, right) as usize);
586                        }
587                    }
588
589                    symbol_bitstring <<= 1;
590                    idx -= 1;
591                }
592                // set the final node to the decoded symbol.
593                pointer.set(table, left, right, char.try_into().expect("symbol count too large"));
594            }
595
596            //update the start index for this bit length
597            start[sym_bit_len] = next_code;
598        }
599        Ok(())
600    }
601}
602
603impl Iterator for CodeIterator<'_> {
604    type Item = Result<CodeSymbol, DecompressError>;
605
606    // Returns the next CodeSymbol from the bitstream.
607    fn next(&mut self) -> Option<Self::Item> {
608        if self.is_error {
609            return None;
610        }
611        if self.remaining_block_size == 0 {
612            //Starting a new block - re-initialize block state.
613
614            //Read new block size.
615            self.remaining_block_size = match self.pop_bits(16) {
616                Ok(bits) => bits.load_be::<u16>() as usize,
617                Err(err) => {
618                    self.is_error = true;
619                    return Some(Err(err));
620                }
621            };
622
623            // Read in Extra Set Array and generate Huffman code mapping table for extra set used to decode Char&Len set.
624            if let Err(err) = self.read_pt_len(NT, TBIT, true) {
625                self.is_error = true;
626                return Some(Err(err));
627            }
628
629            // Read in Char&Len Set Array and generate Huffman code mapping table for Char&Len set.
630            if let Err(err) = self.read_c_len() {
631                self.is_error = true;
632                return Some(Err(err));
633            }
634
635            // Read in the Position Set Array and generate Huffman code mapping table for the Position set.
636            if let Err(err) = self.read_pt_len(MAXNP, self.p_bit, false) {
637                self.is_error = true;
638                return Some(Err(err));
639            }
640        }
641        self.remaining_block_size -= 1;
642
643        // Decode the next Char&Len symbol. First, find the index in the c_table by peeking the next 12 bits.
644        let bit_buff = match self.peek_bits(CTABLE_BITSIZE) {
645            Ok(buff) => buff,
646            Err(err) => {
647                self.is_error = true;
648                return Some(Err(err));
649            }
650        };
651        let mut decode_idx = self.c_table[bit_buff.load_be::<usize>()] as usize;
652
653        // If the index is larger than NC, then reconstruct the symbol by traversing the secondary decode tree.
654        // see read_c_len() for details of how this is done.
655        if decode_idx >= NC {
656            let mut mask_idx = CTABLE_BITSIZE;
657            loop {
658                let bit_buff = match self.peek_bits(mask_idx + 1) {
659                    Ok(buff) => buff,
660                    Err(err) => {
661                        self.is_error = true;
662                        return Some(Err(err));
663                    }
664                };
665                if bit_buff[mask_idx] {
666                    decode_idx = self.right[decode_idx] as usize;
667                } else {
668                    decode_idx = self.left[decode_idx] as usize;
669                }
670                mask_idx += 1;
671                if decode_idx < NC {
672                    break;
673                };
674            }
675        }
676        //decode_idx the current symbol. Advance the bitstream by the bitlength of the current symbol.
677        if let Err(err) = self.pop_bits(self.c_len[decode_idx] as usize) {
678            self.is_error = true;
679            return Some(Err(err));
680        }
681
682        //convert the symbol to the appropriate CodeSymbol
683        if decode_idx < 256 {
684            // symbols from 0-255 are byte literals.
685            Some(Ok(CodeSymbol::OrigChar(decode_idx as u8)))
686        } else {
687            // symbols greater than 255 are string lengths.
688            let len = decode_idx - (0x100 - 3);
689
690            // string lengths are followed by an encoded string position; invoke decode_position() to decode it.
691            let pos = match self.decode_position() {
692                Ok(pos) => pos,
693                Err(err) => {
694                    self.is_error = true;
695                    return Some(Err(err));
696                }
697            };
698
699            Some(Ok(CodeSymbol::StrPointer(pos, len)))
700        }
701    }
702}
703
#[cfg(test)]
mod test {
    extern crate std;
    use std::{fs, println, vec, vec::Vec};

    use crate::decompress_into_with_algo;

    /// Expands to the absolute path of a file under `<crate root>/resources/test/`.
    macro_rules! test_collateral {
        ($fname:expr) => {
            concat!(env!("CARGO_MANIFEST_DIR"), "/resources/test/", $fname)
        };
    }

    /// Reads a whole test collateral file into memory, panicking with the offending path on failure.
    fn read_collateral(path: &str) -> Vec<u8> {
        fs::read(path).unwrap_or_else(|err| panic!("failed to read test file {}: {}", path, err))
    }

    /// Decompresses the file at `compressed_path` with `algo` and asserts that the output matches the
    /// file at `uncompressed_path` byte-for-byte.
    ///
    /// The destination buffer is sized from the reference file because `decompress_into_with_algo`
    /// rejects a `dst` whose length differs from the original size recorded in the header.
    fn assert_decompresses_to(
        compressed_path: &str,
        uncompressed_path: &str,
        algo: crate::DecompressionAlgorithm,
    ) {
        let compressed = read_collateral(compressed_path);
        let expected = read_collateral(uncompressed_path);
        let mut actual = vec![0u8; expected.len()];

        decompress_into_with_algo(&compressed, &mut actual, algo).expect("decompression failed");

        // Report the first differing byte rather than a bare "buffers differ".
        if let Some(idx) = actual.iter().zip(&expected).position(|(a, e)| a != e) {
            panic!("mismatch at idx: {:}, expected {:#x} != {:#x} actual", idx, expected[idx], actual[idx]);
        }
    }

    #[test]
    fn uefi_decompress_should_produce_expected_buffer() {
        assert_decompresses_to(
            test_collateral!("uefi_compressed.bin"),
            test_collateral!("uefi_uncompressed.bin"),
            crate::DecompressionAlgorithm::UefiDecompress,
        );
    }

    #[test]
    fn tiano_decompress_should_produce_expected_buffer() {
        assert_decompresses_to(
            test_collateral!("tiano_compressed.bin"),
            test_collateral!("tiano_uncompressed.bin"),
            crate::DecompressionAlgorithm::TianoDecompress,
        );
    }

    #[test]
    fn decompress_with_original_size_of_zero_should_return_zero_sized_buffer() {
        // 16-byte stream whose header declares a compressed size of 8 (byte 0 = 0x08) and an
        // original size of 0 (bytes 4..8 are zero). This must succeed with an empty destination.
        let mut compressed_buffer = [0x0; 16];
        compressed_buffer[0] = 0x08;

        let mut uefi_uncompressed: Vec<u8> = Vec::new();
        assert!(decompress_into_with_algo(
            &compressed_buffer,
            &mut uefi_uncompressed,
            crate::DecompressionAlgorithm::UefiDecompress
        )
        .is_ok());
        assert_eq!(uefi_uncompressed.len(), 0);

        let mut tiano_uncompressed: Vec<u8> = Vec::new();
        assert!(decompress_into_with_algo(
            &compressed_buffer,
            &mut tiano_uncompressed,
            crate::DecompressionAlgorithm::TianoDecompress
        )
        .is_ok());
        assert_eq!(tiano_uncompressed.len(), 0);
    }

    #[test]
    fn malformed_header_should_return_error() {
        // Too short to hold the 8-byte (compressed size, original size) header.
        assert!(decompress_into_with_algo(
            &[0u8; 7],
            &mut [0u8; 1],
            crate::DecompressionAlgorithm::UefiDecompress
        )
        .is_err());

        // Header claims 0xff compressed bytes, but only 8 bytes are present.
        let oversized = [0xff, 0, 0, 0, 1, 0, 0, 0];
        assert!(decompress_into_with_algo(
            &oversized,
            &mut [0u8; 1],
            crate::DecompressionAlgorithm::UefiDecompress
        )
        .is_err());

        // Header claims an original size of 4, but the destination holds only 3 bytes.
        let mismatched = [8, 0, 0, 0, 4, 0, 0, 0];
        assert!(decompress_into_with_algo(
            &mismatched,
            &mut [0u8; 3],
            crate::DecompressionAlgorithm::TianoDecompress
        )
        .is_err());
    }

    /// Minimal xorshift64 generator (Marsaglia, 2003).
    ///
    /// It is deterministic and dependency-free, so every fuzz run applies the same corruptions and
    /// any panic it uncovers is reproducible. The state must be seeded non-zero.
    struct XorShift64(u64);

    impl XorShift64 {
        fn next_u64(&mut self) -> u64 {
            let mut x = self.0;
            x ^= x << 13;
            x ^= x >> 7;
            x ^= x << 17;
            self.0 = x;
            x
        }
    }

    #[test]
    fn fuzz_testing_should_fail_gracefully() {
        const FUZZ_COUNT: usize = 100;
        // A fixed seed replaces the former wall-clock index. Successive microsecond timestamps in a
        // tight loop clustered the corruptions together and made failures impossible to replay.
        const FUZZ_SEED: u64 = 0x9E37_79B9_7F4A_7C15;

        let mut fuzz_buffer = read_collateral(test_collateral!("uefi_compressed.bin"));
        let uncompressed_len = read_collateral(test_collateral!("uefi_uncompressed.bin")).len();
        assert!(!fuzz_buffer.is_empty(), "fuzz input must not be empty");

        let mut test_buffer = vec![0u8; uncompressed_len];
        let mut rng = XorShift64(FUZZ_SEED);

        for iteration in 0..FUZZ_COUNT {
            let fuzz_idx = (rng.next_u64() % fuzz_buffer.len() as u64) as usize;
            // A zero mask would leave the byte unchanged, so fall back to flipping every bit.
            let mask = match rng.next_u64() as u8 {
                0 => 0xff,
                m => m,
            };

            let original = fuzz_buffer[fuzz_idx];
            fuzz_buffer[fuzz_idx] ^= mask;
            println!(
                "iteration {}: fuzz_idx: {:} before: {:#x} after: {:#x}",
                iteration, fuzz_idx, original, fuzz_buffer[fuzz_idx]
            );

            // Not all corruption is detectable; most of the time (but not always) this returns an
            // Err. The goal is that failure never panics, not that bad data is always caught.
            let _ = decompress_into_with_algo(
                &fuzz_buffer,
                &mut test_buffer,
                crate::DecompressionAlgorithm::UefiDecompress,
            );

            // Restore the byte so each iteration corrupts exactly one byte of the pristine input.
            fuzz_buffer[fuzz_idx] = original;
        }
    }
}