uefi_decompress/lib.rs
1#![no_std]
2use bitvec::{field::BitField, order::Msb0, slice::BitSlice, view::BitView};
3
/// Errors reported by the decompression routines in this crate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecompressError {
    /// The source buffer is too small to contain the 8-byte header, or the compressed size recorded in the header
    /// does not fit inside the source buffer.
    InvalidSrcSize,
    /// The destination buffer length does not match the original (uncompressed) size recorded in the header.
    InvalidDstSize,
    /// The compressed bitstream is truncated or could not be decoded.
    MalformedSrcData,
}
11
/// Supported Decompression Algorithms
///
/// The two variants share the same decoder and differ only in the width of the size field that precedes the
/// Position Set code lengths in every block.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecompressionAlgorithm {
    /// Standard UEFI decompression (4-bit Position Set size field).
    UefiDecompress,
    /// Tiano decompression (5-bit Position Set size field).
    TianoDecompress,
}
18
19/// Decompress the compressed data in `src` and store the output in `dst`, using the `algo` decompression algorithm.
20pub fn decompress_into_with_algo(
21 src: &[u8],
22 dst: &mut [u8],
23 algo: DecompressionAlgorithm,
24) -> Result<(), DecompressError> {
25 //sanity check the inputs
26 if src.len() < 8 {
27 Err(DecompressError::InvalidSrcSize)?;
28 }
29
30 let compressed_size = u32::from_le_bytes(src[0..4].try_into().unwrap()) as usize;
31 if compressed_size > src.len() {
32 Err(DecompressError::InvalidSrcSize)?;
33 }
34
35 let orig_size = u32::from_le_bytes(src[4..8].try_into().unwrap()) as usize;
36 if orig_size == 0 {
37 return Ok(());
38 }
39 if orig_size != dst.len() {
40 Err(DecompressError::InvalidDstSize)?;
41 }
42
43 //Create a code iterator that iterates through the `src` bitstream and returns `CodeSymbol` elements.
44 let mut dst_idx = 0;
45 for result in CodeIterator::new(&src[8..], algo) {
46 match result {
47 Ok(symbol) => match symbol {
48 CodeSymbol::OrigChar(char) => {
49 // symbol is an original character literal - copy it directly to the output buffer.
50 dst[dst_idx] = char;
51 dst_idx += 1;
52 }
53 CodeSymbol::StrPointer(offset, len) => {
54 // symbol is offset:len pair to be copied from a previously decompressed portion of the buffer.
55 let start = dst_idx
56 .checked_sub(offset)
57 .and_then(|x| x.checked_sub(1))
58 .ok_or(DecompressError::MalformedSrcData)?;
59
60 // note: this loop is used (instead of e.g. slice::copy_within or slice::copy_non_overlapping)
61 // because the offset:len window may overlap the current position. The "new" byte from the
62 // overlapping region needs to be copied instead of the original byte that existed at the start of
63 // the copy, which makes copy_within semantics inappropriate here.
64 for src in start..start + len {
65 dst[dst_idx] = dst[src];
66 dst_idx += 1;
67 if dst_idx == dst.len() {
68 break;
69 }
70 }
71 }
72 },
73 //CodeIterator encountered an error trying to produce the next symbol - return it to caller.
74 Err(err) => Err(err)?,
75 }
76
77 // Decompression is complete.
78 if dst_idx == dst.len() {
79 break;
80 }
81 }
82 Ok(())
83}
84
// A single decoded element of the compressed stream, as produced by CodeIterator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CodeSymbol {
    // A literal byte to be appended to the output.
    OrigChar(u8),
    // A back-reference (string position, length): copy `length` bytes starting `string position + 1` bytes behind
    // the current output position.
    StrPointer(usize, usize),
}
89
// Nomenclature: Char&Len set = 'C', Position set = 'P', Extra set = 'T'

// Char&Len set: number of symbols, bit width of its "array size" field, and the number of bits used to index the
// fixed lookup table (c_table).
const NC: usize = 510;
const CBIT: usize = 9;
const CTABLE_BITSIZE: usize = 12;

// Extra set: number of symbols and bit width of its "array size" field. PTABLE_BITSIZE is the number of bits used to
// index the fixed lookup table (pt_table) that is shared by the Extra set and the Position set.
const NT: usize = 19;
const TBIT: usize = 5;
const PTABLE_BITSIZE: usize = 8;

// Position set: maximum number of symbols (the bit width of its "array size" field is selected at runtime from the
// decompression algorithm).
const MAXNP: usize = 31;

// Length of the code-length array shared by the Extra set and the Position set: the larger of NT and MAXNP.
const NPT: usize = if NT > MAXNP { NT } else { MAXNP };
106
// Decoder state for one compressed stream. Iterating over a CodeIterator yields the decoded CodeSymbol elements
// (or the first error encountered).
struct CodeIterator<'a> {
    // The compressed bitstream, viewed most-significant-bit first.
    src: &'a BitSlice<u8, Msb0>,
    // Bit offset into `src` of the next unread bit.
    src_index: usize,
    // Set once an error has been returned; afterwards the iterator only yields `None`.
    is_error: bool,
    // Number of symbols still to be decoded from the current block. Zero means a new block header must be read.
    remaining_block_size: usize,
    // Secondary Huffman decode tree (child selected when the next bit is 0). Shared by all symbol sets.
    left: [u16; 2 * NC - 1],
    // Secondary Huffman decode tree (child selected when the next bit is 1). Shared by all symbol sets.
    right: [u16; 2 * NC - 1],
    // Code length of every Char&Len symbol in the current block.
    c_len: [u8; NC],
    // Code length of every Extra set / Position set symbol (whichever was read most recently).
    pt_len: [u8; NPT],
    // Fixed lookup table for the Char&Len set, indexed by the next CTABLE_BITSIZE bits of the stream.
    c_table: [u16; 1 << CTABLE_BITSIZE],
    // Fixed lookup table for the Extra set / Position set, indexed by the next PTABLE_BITSIZE bits of the stream.
    pt_table: [u16; 1 << PTABLE_BITSIZE],
    // Bit width of the size field that precedes the Position Set code lengths (depends on the algorithm).
    p_bit: usize,
}
120
121impl<'a> CodeIterator<'a> {
122 // initialize a new CodeIterator instance for the given source and algorithm
123 fn new(src: &'a [u8], algo: DecompressionAlgorithm) -> Self {
124 Self {
125 src: src.view_bits::<Msb0>(),
126 src_index: 0,
127 is_error: false,
128 remaining_block_size: 0,
129 left: [0u16; 2 * NC - 1],
130 right: [0u16; 2 * NC - 1],
131 c_len: [0u8; NC],
132 pt_len: [0u8; NPT],
133 c_table: [0u16; 4096],
134 pt_table: [0u16; 256],
135 p_bit: match algo {
136 DecompressionAlgorithm::UefiDecompress => 4,
137 DecompressionAlgorithm::TianoDecompress => 5,
138 },
139 }
140 }
141
142 // advances the source bitstream by `count` bits.
143 fn pop_bits(&mut self, count: usize) -> Result<&BitSlice<u8, Msb0>, DecompressError> {
144 if let Some(bitslice) = self.src.get(self.src_index..self.src_index + count) {
145 self.src_index += count;
146 Ok(bitslice)
147 } else {
148 Err(DecompressError::MalformedSrcData)
149 }
150 }
151
152 // returns the next `count` bits of the source bitstream without advancing it.
153 fn peek_bits(&self, count: usize) -> Result<&BitSlice<u8, Msb0>, DecompressError> {
154 if let Some(bitslice) = self.src.get(self.src_index..self.src_index + count) {
155 Ok(bitslice)
156 } else {
157 Err(DecompressError::MalformedSrcData)
158 }
159 }
160
161 // Reads the code lengths for the Extra Set or Position Set Huffman codes for the current block.
162 //
163 // The code lengths are preceded by a `num_bits`-sized field that gives the length of the array.
164 //
165 // This is then followed by an encoded set of lengths which use a variable number of bits:
166 // - If the code length is less than 7, it is encoded as a 3-bit binary number.
167 // - If the code length is 7 or greater, it is encoded as a series of '1b' followed by a terminating '0b'.
168 // The code length is therefore equal to "count of 1s" + 4.
169 // Example: "4" is coded as '100b', "7" is coded as '1110b', and "12" is coded as `111111110b`
170 //
171 // If the 'extra' flag is set, then after the third length element in the bitstream, there is a 2-bit field
172 // indicating the number of additional zero lengths that follow. For example, the following array of lengths
173 // [2,9,0,0,5,7] would be encoded with the following bit stream (num_bits size field not shown).
174 // 010 111110 10 101 1110
175 // ^
176 // this is the `extra` field added to generate the 2 "zero" lengths
177 // If the extra flag is not set, the same array of lengths would be encoded with the following bitstream
178 // 010 111110 000 000 101 1110
179 //
180 // The resulting code length array will be stored in self.pt_len.
181 //
182 // Once the code length array is generated, it is fed to the the Self::build_huffman_table() routine
183 // to generate the resulting Huffman code table, which will be stored in self.pt_table.
184 //
185 // Refer to UEFI Specification 2.10, section 19.2.3.1.
186 //
187 fn read_pt_len(&mut self, num_symbols: usize, num_bits: usize, extra: bool) -> Result<(), DecompressError> {
188 assert!(num_symbols <= NPT);
189
190 // Read Set Length Array size
191 let count = self.pop_bits(num_bits)?.load_be::<usize>();
192 if count == 0 {
193 // this represents the only Huffman code used.
194 let char_c = self.pop_bits(num_bits)?.load_be::<u16>();
195 self.pt_table.fill(char_c);
196 self.pt_len[..num_symbols].fill(0);
197 Ok(())
198 } else {
199 let mut idx = 0;
200 while idx < count && idx < NPT {
201 // if a code length is less than 7, it is encoded as 3-bit value. Otherwise it is encoded by a series of
202 // 1s followed by a terminating zero. The number of 1s = code length - 4.
203 let mut code_len = self.pop_bits(3)?.load_be::<u8>();
204 if code_len == 7 {
205 loop {
206 let bit = self.pop_bits(1)?[0];
207 if bit {
208 //current bit is one.
209 code_len += 1;
210 } else {
211 break;
212 }
213 }
214 }
215 self.pt_len[idx] = code_len;
216 idx += 1;
217
218 // if 'extra' is set, then after the third length of the code length concatenation, a 2-bit value is
219 // used to indicate the number of consecutive zero lengths immediately after the third length.
220 if extra && idx == 3 {
221 let zero_count = self.pop_bits(2)?.load_be::<usize>();
222 self.pt_len[idx..idx + zero_count].fill(0);
223 idx += zero_count;
224 }
225 }
226 if idx > num_symbols {
227 Err(DecompressError::MalformedSrcData)?;
228 }
229 // zero the rest of the table.
230 self.pt_len[idx..num_symbols].fill(0);
231
232 //convert the resulting code length array (self.pt_len) into a Huffman coding table (self.pt_table)
233 Self::build_huffman_table(
234 num_symbols,
235 &self.pt_len,
236 PTABLE_BITSIZE,
237 &mut self.pt_table,
238 &mut self.left,
239 &mut self.right,
240 )
241 }
242 }
243
244 // Read the code lengths for the Char&Length set Huffman code for the current block.
245 //
246 // The code lengths are preceded by a 9-bit field that gives the length of the array.
247 //
248 // This is then followed by an encoded set of lengths which use a variable number of bits. The set of lengths is
249 // double-encoded:
250 //
251 // 1: If a code length is not zero, then it is encoded as "code length + 2";
252 // If a code length is zero, then the number of consecutive zero lengths starting from this code length is
253 // counted:
254 // - if the count is equal to or less than 2, then the code "0" is used for each zero length;
255 // - if the count is greater than 2 and less than 19, then the code "1" followed by a 4-bit value of "count - 3"
256 // is used for these consecutive zero lengths;
257 // - if the count is equal to 19, then it is treated as "1 + 18," and a code "0" and a code "1" followed by a
258 // 4-bit value of "15" are used for these consecutive zero lengths;
259 // - if the count is greater than 19, then the code "2" followed by a 9-bit value of "count - 20" is used for
260 // these consecutive zero lengths.
261 // 2: The resulting bitstring symbols are the "extra set", and are encoded using Huffman coding. The tables derived
262 // from execution of the read_pt_len() function on the extra set can be used to decode these symbols.
263 //
264 // To decode the table, the above process is reversed. First, the Huffman coded "extra set" symbols are decoded,
265 // then the resulting symbols are converted into a code length by reversing the step 1 above.
266 //
267 // The resulting code length array will be stored in self.c_len.
268 //
269 // Once the code length array is generated, it is fed to the the Self::build_huffman_table() routine
270 // to generate the resulting Huffman code table, which will be stored in self.c_table.
271 //
272 // Refer to UEFI Specification 2.10, section 19.2.3.1.
273 //
274 // NOTE: this routine requires that the current contents of self.pt_len, self.pt_table, self.left, and self.right
275 // are initialized to match the "Extra Set" by executing read_pt_len() to decode the Extra Set Code Length Array.
276 //
277 fn read_c_len(&mut self) -> Result<(), DecompressError> {
278 // Read Set Length Array Size
279 let count = self.pop_bits(CBIT)?.load_be::<usize>();
280
281 if count == 0 {
282 // this represents the only Huffman code used
283 let symbol = self.pop_bits(CBIT)?.load_be::<u16>();
284 self.c_len.fill(0);
285 self.c_table.fill(symbol);
286 Ok(())
287 } else {
288 // iterate over all the symbols in the array.
289 let mut idx = 0;
290 while idx < count {
291 // read the next symbol. First, read the first PTABLE_BITSIZE bits of the symbol.
292 let mut symbol = self.pt_table[self.peek_bits(PTABLE_BITSIZE)?.load_be::<usize>()];
293 // if the symbol is less than NT, then it can be used as-is
294 if symbol as usize >= NT {
295 // symbol is larger than NT. Read bits from the stream and traverse the left/right tree until a leaf
296 // node (less than NT) is reached.
297 let mut mask_idx = PTABLE_BITSIZE;
298 loop {
299 let bit_buff = self.peek_bits(mask_idx + 1)?;
300 if bit_buff[mask_idx] {
301 symbol = self.right[symbol as usize];
302 } else {
303 symbol = self.left[symbol as usize];
304 }
305 mask_idx += 1;
306 if (symbol as usize) < NT {
307 break;
308 }
309 }
310 }
311
312 //now that we know the symbol, advance the bitstream by the symbol bitlength.
313 self.pop_bits(self.pt_len[symbol as usize] as usize)?;
314
315 if symbol <= 2 {
316 // if the symbol is 2 or less, it encodes 1 or more zero length symbols
317 if symbol == 0 {
318 // a single zero length
319 symbol = 1;
320 } else if symbol == 1 {
321 // '1' followed by a 4-bit value of count - 3 zero lengths follow.
322 symbol = self.pop_bits(4)?.load_be::<u16>() + 3;
323 } else if symbol == 2 {
324 // '2' followed by a 9-bit value of count - 20 zero lengths follow.
325 symbol = self.pop_bits(CBIT)?.load_be::<u16>() + 20;
326 }
327
328 //"symbol" now contains the consecutive number of zero-length symbols starting at the current idx.
329 //update the c_len table entries corresponding to these symbols and advance the index.
330 for _ in 0..symbol {
331 if idx >= self.c_len.len() {
332 Err(DecompressError::MalformedSrcData)?;
333 }
334 self.c_len[idx] = 0;
335 idx += 1;
336 }
337 } else {
338 // otherwise, the symbol encodes 'code length +2'. store it in c_len and advance the index.
339 if idx >= self.c_len.len() {
340 Err(DecompressError::MalformedSrcData)?;
341 }
342 self.c_len[idx] = (symbol - 2) as u8;
343 idx += 1;
344 }
345 }
346 // all valid symbols processed, zero the rest of c_len.
347 self.c_len[idx..NC].fill(0);
348
349 //convert the resulting code length array (self.c_len) into a Huffman coding table (self.c_table)
350 Self::build_huffman_table(
351 NC,
352 &self.c_len,
353 CTABLE_BITSIZE,
354 &mut self.c_table,
355 &mut self.left,
356 &mut self.right,
357 )
358 }
359 }
360
361 // Decodes a "position" value from the current bitstream according to the Position Set encoding.
362 //
363 // A String Position is a value that indicates the distance between the current position and the target string. The
364 // String Position value is defined as "Current Position - Starting Position of the target string - 1." The String
365 // Position value ranges from 0 to 8190 (so 8192 is the "sliding window" size, and this range should be ensured by
366 // the compressor). The lengths of the String Position values (in binary form) form a value set ranging from 0 to 13
367 // (it is assumed that value 0 has length of 0). This value set is the Position Set for Huffman Coding. The full
368 // representation of a String Position value is composed of two consecutive parts: one is the Huffman code for the
369 // value length; the other is the actual String Position value of "length - 1" bits (excluding the highest bit since
370 // the highest bit is always "1"). For example, String Position value 18 is represented as: Huffman code for "5"
371 // followed by "0010." If the value length is 0 or 1, then no value is appended to the Huffman code.
372 //
373 // NOTE: this routine requires that the current contents of self.pt_len, self.pt_table, self.left, and self.right
374 // are initialized to match the "Position Set" by executing read_pt_len() to decode the Position Set Code Length
375 // Array.
376 fn decode_position(&mut self) -> Result<usize, DecompressError> {
377 //First, read the first PTABLE_BITSIZE bits of the position symbol.
378 let bit_buffer = self.peek_bits(PTABLE_BITSIZE)?;
379 let mut val = self.pt_table[bit_buffer.load_be::<usize>()] as usize;
380
381 // if the symbol is less than NT, then it can be used as-is
382 if val >= MAXNP {
383 // symbol is larger than NT. Read bits from the stream and traverse the left/right tree until a leaf
384 // node (less than NT) is reached.
385 let mut mask_idx = PTABLE_BITSIZE;
386 loop {
387 let bit_buffer = self.peek_bits(mask_idx + 1)?;
388 if bit_buffer[mask_idx] {
389 val = self.right[val] as usize;
390 } else {
391 val = self.left[val] as usize;
392 }
393
394 mask_idx += 1;
395
396 if val < MAXNP {
397 break;
398 }
399 }
400 }
401 self.pop_bits(self.pt_len[val] as usize)?;
402
403 // if val is <= 1, then it directly encodes the position
404 if val > 1 {
405 // otherwise, (val - 1) encodes the bit length of an integer that encodes the position.
406 val = (1 << (val - 1)) + self.pop_bits(val - 1)?.load_be::<usize>();
407 }
408
409 Ok(val)
410 }
411
    // Constructs a Huffman decode table + tree.
    //
    // input parameters:
    //  num_symbols: number of symbols in the Huffman symbol set
    //  bit_lengths: a table describing the code length for each symbol (indexed by the symbol)
    //  table_bits: the number of bits to be used for fixed symbol lookup. Symbols with an encoded bitlength longer than
    //    this parameter will require traversing the secondary tree to fully decode.
    //
    // modifies:
    //  table: the fixed decode table (see description below)
    //  left: the "left" nodes of the secondary decoder tree.
    //  right: the "right" nodes of the secondary decoder tree.
    //
    // returns:
    //  Ok(()) on success, or MalformedSrcData if a code length exceeds 16 or the set of code lengths fails the
    //  consistency checks performed below.
    //
    // panics:
    //  if table_bits is greater than 16.
    //
    // This routine takes as input the bit_lengths table representing the canonical Huffman encoding over the output
    // symbols. It then generates 3 different table structures in the slices given as input:
    // - table: this table consists of two sets of entries.
    //    - fixed lookup entries - this consists of fixed entries for all symbols where the length of the encoded
    //      bitstring is less than or equal to the table_bits. For a given symbol, all entries that have that symbol as
    //      a prefix are set to the decoded value of the symbol. For example, assume that the bitstring `100b` is the
    //      encoded representation of the value 0xB - in that case, all of the entries of the table that start with
    //      `100xxxxxxxxxb` (i.e. indexes 0x800 to 0x9FF) would be set to 0xB.
    //    - tree lookup root entry - if the length of the encoded symbol is longer than the table bits, then the unique
    //      prefix of that entry points to the index of the root of a secondary decode tree encoded in the left & right
    //      array structures. "Leaf" values of the tree are less than `num_symbols` and correspond to literal final
    //      symbols. "Node" values of the tree are `num_symbols` or greater and are indexes into the left and right
    //      arrays, whose entries point to other nodes or leaves.
    //
    //      To decode the final symbol for an encoded bitstring that is longer than table_bits bits, first locate the
    //      entry within the table that corresponds to the root index in the left/right trees. Then, starting
    //      with the bit immediately following the first table_bits bits of the encoded symbol, read bits from the
    //      encoded symbol. For each bit, if it is a 1, retrieve the next index from the `right` array, otherwise if it
    //      is a 0, retrieve the next index from the `left`. If the retrieved index is less than `num_symbols`, then it
    //      is the final decoded symbol. Otherwise, it is the index into the left or right tree for the next bit.
    //
    //      Note: if all possible symbols can be encoded within the fixed table width, then the secondary lookup is not
    //      needed.
    //
    // - left & right - the secondary decode tree as described above.
    //
    // Note: This implementation shares the "left & right" tables between the Char&Len symbol Set decode and the
    // Position Set decode; each decode allocates its nodes starting at its own `num_symbols` index.
    fn build_huffman_table(
        num_symbols: usize,
        bit_lengths: &[u8],
        table_bits: usize,
        table: &mut [u16],
        left: &mut [u16],
        right: &mut [u16],
    ) -> Result<(), DecompressError> {
        assert!(table_bits <= 16);

        // calculate the number of symbols for each bit length. Lengths above 16 cannot be represented and are
        // rejected here, before they are used as an index into `count`.
        let mut count = [0u16; 17];
        for idx in 0..num_symbols {
            if bit_lengths[idx] > 16 {
                Err(DecompressError::MalformedSrcData)?;
            }
            count[bit_lengths[idx] as usize] += 1;
        }

        // Determine the start index for each bit length. This determines the start index within the fixed size decode
        // table for all symbols of a given bit length. The values are 16-bit codes left-aligned in a u16: each
        // symbol of length `idx` advances the running start by 1 << (16 - idx), with wrapping arithmetic.
        let mut start = [0u16; 18];
        for idx in 1..=16 {
            let word_of_start = start[idx];
            let word_of_count = count[idx] << (16 - idx);
            start[idx + 1] = word_of_start.wrapping_add(word_of_count);
        }
        // After accounting for every length the running start must have wrapped back to exactly zero; anything
        // else is treated as malformed input.
        if start[17] != 0 {
            Err(DecompressError::MalformedSrcData)?;
        }

        // extended_bits is the number bits in the symbol exceeding the bit length for fixed entries in the table.
        let extended_bits = 16 - table_bits;

        // Determine weight of each length (the number of entries that a given symbol length will consume in the table).
        // For lengths that fit in the fixed table, the start values are also rescaled from 16-bit codes to table
        // indexes.
        let mut weight = [0; 17];
        for idx in 1..=table_bits {
            start[idx] >>= extended_bits;
            weight[idx] = 1 << (table_bits - idx);
        }

        // Lengths longer than table_bits keep their 16-bit scale: both start and weight stay left-aligned in a u16.
        for (idx, w) in weight.iter_mut().enumerate().skip(table_bits + 1) {
            *w = 1 << (16 - idx)
        }

        // zero unused table entries: everything from the first index not covered by a code of at most table_bits
        // bits up to the end of the table. The tree construction below treats a zero entry as "no node yet".
        let idx = start[table_bits + 1] >> extended_bits;
        if idx != 0 {
            let idx_3 = 1 << table_bits;
            if idx < idx_3 {
                table[idx as usize..idx_3 as usize].fill(0);
            }
        }

        // Private helper structure used in the implementation below to simplify construction of the secondary tree.
        // It names one slot - in the fixed table, the left array or the right array - so that the slot currently
        // being followed can be read and written uniformly.
        enum TablePointer {
            Table(usize),
            Left(usize),
            Right(usize),
        }
        impl TablePointer {
            // writes `val` into the slot this pointer refers to.
            fn set(&self, table: &mut [u16], left: &mut [u16], right: &mut [u16], val: u16) {
                match self {
                    TablePointer::Table(idx) => table[*idx] = val,
                    TablePointer::Left(idx) => left[*idx] = val,
                    TablePointer::Right(idx) => right[*idx] = val,
                }
            }

            // reads the value stored in the slot this pointer refers to.
            fn get(&self, table: &mut [u16], left: &mut [u16], right: &mut [u16]) -> u16 {
                match self {
                    TablePointer::Table(idx) => table[*idx],
                    TablePointer::Left(idx) => left[*idx],
                    TablePointer::Right(idx) => right[*idx],
                }
            }
        }

        // tracks the next available node. Node indexes start at num_symbols so that they can be told apart from
        // symbol values.
        let mut next_avail_node = num_symbols;
        // mask used to check the bit for left vs. right construction: it selects the first bit after the
        // table_bits prefix of a left-aligned 16-bit code.
        let mask = 1 << (15 - table_bits);

        // iterate over all symbols in the alphabet to generate the table.
        for (char, sym_bit_len) in bit_lengths.iter().enumerate().take(num_symbols) {
            let sym_bit_len = *sym_bit_len as usize;

            // if the symbol length is zero, it is unused.
            if sym_bit_len == 0 {
                continue;
            }

            // max symbol length is fixed at 16 by spec, so encountering a larger symbol length is an error.
            if sym_bit_len > 16 {
                Err(DecompressError::MalformedSrcData)?;
            }

            // get the next code: the start value for the symbol that follows this one at the same length.
            let next_code = start[sym_bit_len].wrapping_add(weight[sym_bit_len]);

            if sym_bit_len <= table_bits {
                // the symbol is short enough that tree construction is not needed.

                // verify start and next sanity: the range must be non-empty and must stay inside the table.
                if start[sym_bit_len] >= next_code || next_code > 1 << table_bits {
                    Err(DecompressError::MalformedSrcData)?;
                }

                // fill in all the elements in the table for which this symbol is a prefix.
                for idx in start[sym_bit_len]..next_code {
                    table[idx as usize] = char.try_into().expect("symbol count too large");
                }
            } else {
                // the symbol is long enough that tree construction is required. Start at the table entry selected
                // by the first table_bits bits of the code, then descend one level per remaining bit.
                let mut symbol_bitstring = start[sym_bit_len];
                let mut pointer = TablePointer::Table((symbol_bitstring >> extended_bits) as usize);
                let mut idx = sym_bit_len - table_bits;

                // traverse the tree using the extended bits in the symbol bitstring to select nodes
                while idx != 0 {
                    // an empty slot gets a freshly allocated node (with both children cleared), as long as the
                    // left/right arrays still have room.
                    if pointer.get(table, left, right) == 0 && next_avail_node < (2 * NC - 1) {
                        pointer.set(table, left, right, next_avail_node.try_into().expect("symbol count too large"));
                        right[next_avail_node] = 0;
                        left[next_avail_node] = 0;
                        next_avail_node += 1;
                    }

                    // follow the slot into the left or right child according to the current bit. The bound check
                    // keeps the index inside the left/right arrays.
                    if pointer.get(table, left, right) < (2 * NC - 1) as u16 {
                        if symbol_bitstring & mask != 0 {
                            pointer = TablePointer::Right(pointer.get(table, left, right) as usize);
                        } else {
                            pointer = TablePointer::Left(pointer.get(table, left, right) as usize);
                        }
                    }

                    symbol_bitstring <<= 1;
                    idx -= 1;
                }
                // set the final node to the decoded symbol.
                pointer.set(table, left, right, char.try_into().expect("symbol count too large"));
            }

            //update the start index for this bit length
            start[sym_bit_len] = next_code;
        }
        Ok(())
    }
601}
602
603impl Iterator for CodeIterator<'_> {
604 type Item = Result<CodeSymbol, DecompressError>;
605
606 // Returns the next CodeSymbol from the bitstream.
607 fn next(&mut self) -> Option<Self::Item> {
608 if self.is_error {
609 return None;
610 }
611 if self.remaining_block_size == 0 {
612 //Starting a new block - re-initialize block state.
613
614 //Read new block size.
615 self.remaining_block_size = match self.pop_bits(16) {
616 Ok(bits) => bits.load_be::<u16>() as usize,
617 Err(err) => {
618 self.is_error = true;
619 return Some(Err(err));
620 }
621 };
622
623 // Read in Extra Set Array and generate Huffman code mapping table for extra set used to decode Char&Len set.
624 if let Err(err) = self.read_pt_len(NT, TBIT, true) {
625 self.is_error = true;
626 return Some(Err(err));
627 }
628
629 // Read in Char&Len Set Array and generate Huffman code mapping table for Char&Len set.
630 if let Err(err) = self.read_c_len() {
631 self.is_error = true;
632 return Some(Err(err));
633 }
634
635 // Read in the Position Set Array and generate Huffman code mapping table for the Position set.
636 if let Err(err) = self.read_pt_len(MAXNP, self.p_bit, false) {
637 self.is_error = true;
638 return Some(Err(err));
639 }
640 }
641 self.remaining_block_size -= 1;
642
643 // Decode the next Char&Len symbol. First, find the index in the c_table by peeking the next 12 bits.
644 let bit_buff = match self.peek_bits(CTABLE_BITSIZE) {
645 Ok(buff) => buff,
646 Err(err) => {
647 self.is_error = true;
648 return Some(Err(err));
649 }
650 };
651 let mut decode_idx = self.c_table[bit_buff.load_be::<usize>()] as usize;
652
653 // If the index is larger than NC, then reconstruct the symbol by traversing the secondary decode tree.
654 // see read_c_len() for details of how this is done.
655 if decode_idx >= NC {
656 let mut mask_idx = CTABLE_BITSIZE;
657 loop {
658 let bit_buff = match self.peek_bits(mask_idx + 1) {
659 Ok(buff) => buff,
660 Err(err) => {
661 self.is_error = true;
662 return Some(Err(err));
663 }
664 };
665 if bit_buff[mask_idx] {
666 decode_idx = self.right[decode_idx] as usize;
667 } else {
668 decode_idx = self.left[decode_idx] as usize;
669 }
670 mask_idx += 1;
671 if decode_idx < NC {
672 break;
673 };
674 }
675 }
676 //decode_idx the current symbol. Advance the bitstream by the bitlength of the current symbol.
677 if let Err(err) = self.pop_bits(self.c_len[decode_idx] as usize) {
678 self.is_error = true;
679 return Some(Err(err));
680 }
681
682 //convert the symbol to the appropriate CodeSymbol
683 if decode_idx < 256 {
684 // symbols from 0-255 are byte literals.
685 Some(Ok(CodeSymbol::OrigChar(decode_idx as u8)))
686 } else {
687 // symbols greater than 255 are string lengths.
688 let len = decode_idx - (0x100 - 3);
689
690 // string lengths are followed by an encoded string position; invoke decode_position() to decode it.
691 let pos = match self.decode_position() {
692 Ok(pos) => pos,
693 Err(err) => {
694 self.is_error = true;
695 return Some(Err(err));
696 }
697 };
698
699 Some(Ok(CodeSymbol::StrPointer(pos, len)))
700 }
701 }
702}
703
#[cfg(test)]
mod test {
    extern crate std;
    use std::{fs, println, vec, vec::Vec};

    use crate::{decompress_into_with_algo, DecompressionAlgorithm};

    // Expands to the absolute path of a file in the crate's test resource directory.
    macro_rules! test_collateral {
        ($fname:expr) => {
            concat!(env!("CARGO_MANIFEST_DIR"), "/resources/test/", $fname)
        };
    }

    // Reads a whole test file into memory, panicking with the offending path if it cannot be read.
    fn read_collateral(path: &str) -> Vec<u8> {
        match fs::read(path) {
            Ok(bytes) => bytes,
            Err(err) => panic!("failed to read test file {}: {:?}", path, err),
        }
    }

    // Decompresses `compressed` with `algo` and checks the output byte-for-byte against `expected`.
    fn assert_decompresses_to(compressed: &[u8], expected: &[u8], algo: DecompressionAlgorithm) {
        let mut actual = vec![0u8; expected.len()];

        decompress_into_with_algo(compressed, &mut actual, algo).unwrap();

        assert_eq!(actual.len(), expected.len());
        for (idx, (test, reference)) in actual.iter().zip(expected).enumerate() {
            assert!(test == reference, "mismatch at idx: {:}, expected {:#x} != {:#x} actual", idx, reference, test);
        }
    }

    #[test]
    fn uefi_decompress_should_produce_expected_buffer() {
        let compressed = read_collateral(test_collateral!("uefi_compressed.bin"));
        let uncompressed = read_collateral(test_collateral!("uefi_uncompressed.bin"));

        assert_decompresses_to(&compressed, &uncompressed, DecompressionAlgorithm::UefiDecompress);
    }

    #[test]
    fn tiano_decompress_should_produce_expected_buffer() {
        let compressed = read_collateral(test_collateral!("tiano_compressed.bin"));
        let uncompressed = read_collateral(test_collateral!("tiano_uncompressed.bin"));

        assert_decompresses_to(&compressed, &uncompressed, DecompressionAlgorithm::TianoDecompress);
    }

    #[test]
    fn decompress_with_original_size_of_zero_should_return_zero_sized_buffer() {
        // Setup a compressed buffer where the original size is zero but the compressed size is non-zero.
        // This is represented by a 16-byte buffer where the first byte is 0x08 (indicating compressed size is 8).
        let mut compressed_buffer = [0x0; 16];
        compressed_buffer[0] = 0x08;

        let mut uefi_uncompressed: Vec<u8> = Vec::new();
        let uefi_result =
            decompress_into_with_algo(&compressed_buffer, &mut uefi_uncompressed, DecompressionAlgorithm::UefiDecompress);
        assert!(uefi_result.is_ok());
        assert_eq!(uefi_uncompressed.len(), 0);

        let mut tiano_uncompressed: Vec<u8> = Vec::new();
        let tiano_result = decompress_into_with_algo(
            &compressed_buffer,
            &mut tiano_uncompressed,
            DecompressionAlgorithm::TianoDecompress,
        );
        assert!(tiano_result.is_ok());
        assert_eq!(tiano_uncompressed.len(), 0);
    }

    #[test]
    fn fuzz_testing_should_fail_gracefully() {
        const FUZZ_COUNT: usize = 100;
        // Fixed seed: every run corrupts the same sequence of bytes, so a failure can be reproduced.
        const FUZZ_SEED: u64 = 0x9E37_79B9_7F4A_7C15;

        let compressed_buffer = read_collateral(test_collateral!("uefi_compressed.bin"));
        let uncompressed_len = read_collateral(test_collateral!("uefi_uncompressed.bin")).len();
        assert!(!compressed_buffer.is_empty(), "compressed test file is empty");

        let mut rng_state = FUZZ_SEED;
        for _ in 0..FUZZ_COUNT {
            // xorshift64 step: a small deterministic pseudo-random sequence is all that is needed here.
            rng_state ^= rng_state << 13;
            rng_state ^= rng_state >> 7;
            rng_state ^= rng_state << 17;
            let fuzz_idx = (rng_state % compressed_buffer.len() as u64) as usize;

            let mut fuzz_buffer = compressed_buffer.clone();
            let before = fuzz_buffer[fuzz_idx];
            fuzz_buffer[fuzz_idx] ^= 0xff;
            println!("fuzz_idx: {:} before: {:#x} after: {:#x}", fuzz_idx, before, fuzz_buffer[fuzz_idx]);

            let mut test_buffer = vec![0u8; uncompressed_len];

            //note: not all corruption can be successfully detected. most of the time (but not all) this will return an Err.
            //the goal of the test is to ensure failure doesn't panic, not that bad data is always caught.
            let _ = decompress_into_with_algo(&fuzz_buffer, &mut test_buffer, DecompressionAlgorithm::UefiDecompress);
        }
    }
}