libtw2_huffman/
lib.rs

1use arrayvec::ArrayVec;
2use buffer::with_buffer;
3use buffer::Buffer;
4use buffer::BufferRef;
5use itertools::Itertools;
6use libtw2_common::num::Cast;
7use std::error;
8use std::fmt;
9use std::fmt::Write;
10use std::slice;
11
12/// Compresses some bytes using the Teeworlds-specific Huffman code into a
13/// given buffer.
14///
15/// # Errors
16///
17/// Returns an error if the given buffer does not have enough capacity for the
18/// compressed bytes.
19pub fn compress_into<'a, B: Buffer<'a>>(
20    input: &[u8],
21    buffer: B,
22) -> Result<&'a [u8], buffer::CapacityError> {
23    instances::TEEWORLDS.compress(input, buffer)
24}
25
26/// Compresses some bytes using the Teeworlds-specific Huffman code.
27pub fn compress(input: &[u8]) -> Vec<u8> {
28    instances::TEEWORLDS.compress_into_vec(input)
29}
30
31/// Decompresses some bytes using the Teeworlds-specific Huffman code into a
32/// given buffer.
33///
34/// # Errors
35///
36/// Returns an error if the given bytes aren't a valid compression, or if the
37/// given buffer does not have enough capacity for the uncompressed bytes.
38pub fn decompress_into<'a, B: Buffer<'a>>(
39    input: &[u8],
40    buffer: B,
41) -> Result<&'a [u8], DecompressionError> {
42    instances::TEEWORLDS.decompress(input, buffer)
43}
44
45/// Decompresses some bytes using the Teeworlds-specific Huffman code.
46///
47/// # Errors
48///
49/// Returns an error if the given bytes aren't a valid compression, or if the
50/// given buffer does not have enough capacity for the uncompressed bytes.
51pub fn decompress(input: &[u8]) -> Result<Vec<u8>, InvalidInput> {
52    instances::TEEWORLDS.decompress_into_vec(input)
53}
54
// Predefined Huffman codes (e.g. `instances::TEEWORLDS`, used by the
// top-level `compress`/`decompress` functions).
#[doc(hidden)]
pub mod instances;

// Symbol index marking the end of a compressed stream; it follows the 256
// byte symbols.
const EOF: u16 = 256;
// Number of symbols: the 256 byte values plus `EOF`.
#[doc(hidden)]
pub const NUM_SYMBOLS: u16 = EOF + 1;
// A full binary tree with `NUM_SYMBOLS` leaves has `NUM_SYMBOLS - 1`
// internal nodes.
const NUM_NODES: usize = NUM_SYMBOLS as usize * 2 - 1;
// The root is the last node created while building the tree.
const ROOT_IDX: u16 = NUM_NODES as u16 - 1;
// Number of frequencies passed to `Huffman::from_frequencies` (one per byte
// value; EOF's weight is fixed internally).
#[doc(hidden)]
pub const NUM_FREQUENCIES: usize = 256;
65
/// A Huffman code over the 256 byte values plus an EOF symbol.
#[doc(hidden)]
pub struct Huffman {
    // Nodes `0..NUM_SYMBOLS` are leaves that store each symbol's packed code
    // (see `SymbolRepr::to_node`); the remaining nodes are internal nodes
    // holding child indices, with the root at `ROOT_IDX`.
    nodes: [Node; NUM_NODES],
}
70
/// Error when decompressing into a given buffer.
///
/// Either `Capacity`, which means that the given buffer was too small, or
/// `InvalidInput`, which means that the given bytes weren't a valid
/// compression.
///
/// Note that `Huffman::decompress` maps every failure of the decoder to
/// `Capacity`. This includes invalid input that never reaches the EOF
/// symbol and therefore fills the whole buffer.
#[derive(Debug)]
pub enum DecompressionError {
    /// The output buffer ran out of space.
    Capacity(buffer::CapacityError),
    /// The input bytes are not a valid Huffman compression.
    InvalidInput,
}
81
82impl fmt::Display for DecompressionError {
83    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
84        use self::DecompressionError::*;
85        match self {
86            Capacity(_) => "output buffer too small",
87            InvalidInput => "input is not a valid huffman compression",
88        }
89        .fmt(f)
90    }
91}
92
93impl error::Error for DecompressionError {}
94
/// Error returned when the bytes given for decompression weren't a valid
/// compression.
#[derive(Debug)]
pub struct InvalidInput;

impl fmt::Display for InvalidInput {
    /// Writes the fixed error message, honouring width/alignment flags like
    /// a plain string would.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.pad("input is not a valid huffman compression")
    }
}

impl error::Error for InvalidInput {}
107
108impl From<InvalidInput> for DecompressionError {
109    fn from(InvalidInput: InvalidInput) -> DecompressionError {
110        DecompressionError::InvalidInput
111    }
112}
113
/// Borrowed view of the leaf nodes of a `Huffman` tree, i.e. the codes of
/// all `NUM_SYMBOLS` symbols in symbol order (see `Huffman::repr`).
#[doc(hidden)]
#[derive(Clone)]
pub struct Repr<'a> {
    repr: &'a [Node],
}
119
/// Iterator over the symbol codes of a `Repr`, decoding each leaf node into
/// a `SymbolRepr`.
#[doc(hidden)]
#[derive(Clone)]
pub struct ReprIter<'a> {
    iter: slice::Iter<'a, Node>,
}
125
126impl<'a> IntoIterator for Repr<'a> {
127    type Item = SymbolRepr;
128    type IntoIter = ReprIter<'a>;
129    fn into_iter(self) -> ReprIter<'a> {
130        ReprIter {
131            iter: self.repr.iter(),
132        }
133    }
134}
135
136impl<'a> Iterator for ReprIter<'a> {
137    type Item = SymbolRepr;
138    fn next(&mut self) -> Option<SymbolRepr> {
139        self.iter.next().map(|n| n.to_symbol_repr())
140    }
141    fn size_hint(&self) -> (usize, Option<usize>) {
142        self.iter.size_hint()
143    }
144}
145
146impl<'a> ExactSizeIterator for ReprIter<'a> {
147    fn len(&self) -> usize {
148        self.iter.len()
149    }
150}
151
152impl<'a> DoubleEndedIterator for ReprIter<'a> {
153    fn next_back(&mut self) -> Option<SymbolRepr> {
154        self.iter.next_back().map(|n| n.to_symbol_repr())
155    }
156}
157
/// Iterator over the eight bits of a byte, least significant bit first.
struct Bits {
    byte: u8,
    remaining_bits: u8,
}

impl Bits {
    /// Starts iterating over all eight bits of `byte`.
    fn new(byte: u8) -> Bits {
        Bits {
            byte,
            remaining_bits: 8,
        }
    }
}

impl Iterator for Bits {
    type Item = bool;
    fn next(&mut self) -> Option<bool> {
        match self.remaining_bits {
            0 => None,
            _ => {
                // Take the lowest bit, then shift the next one into place.
                let lowest = (self.byte & 1) == 1;
                self.byte >>= 1;
                self.remaining_bits -= 1;
                Some(lowest)
            }
        }
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.len();
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for Bits {
    fn len(&self) -> usize {
        usize::from(self.remaining_bits)
    }
}
193
/// Weight of one subtree while the Huffman tree is being built.
#[derive(Clone, Copy, Debug)]
struct Frequency {
    /// Combined (saturating) frequency of all symbols in the subtree.
    frequency: u32,
    /// Index of the subtree's root in the node list under construction.
    node_idx: u16,
}
199
200impl Huffman {
201    pub fn from_frequencies(frequencies: &[u32]) -> Huffman {
202        assert!(frequencies.len() == 256);
203        let array = unsafe { &*(frequencies as *const _ as *const _) };
204        Huffman::from_frequencies_array(array)
205    }
206    pub fn from_frequencies_array(frequencies: &[u32; 256]) -> Huffman {
207        let mut frequencies: ArrayVec<[_; 512]> = frequencies
208            .iter()
209            .cloned()
210            .enumerate()
211            .map(|(i, f)| Frequency {
212                frequency: f,
213                node_idx: i.assert_u16(),
214            })
215            .collect();
216        frequencies.push(Frequency {
217            frequency: 1,
218            node_idx: EOF,
219        });
220
221        let mut nodes: ArrayVec<[_; 1024]> = (0..NUM_SYMBOLS).map(|_| NODE_SENTINEL).collect();
222
223        while frequencies.len() > 1 {
224            // Sort in reverse (upper to lower)!
225            frequencies.sort_by(|a, b| b.frequency.cmp(&a.frequency));
226
227            // `frequencies.len() > 1`, so these always succeed.
228            let freq1 = frequencies.pop().unwrap();
229            let freq2 = frequencies.pop().unwrap();
230
231            // Combine the nodes into one.
232            let node = Node {
233                children: [freq1.node_idx, freq2.node_idx],
234            };
235            let node_idx = nodes.len().assert_u16();
236            let node_freq = Frequency {
237                frequency: freq1.frequency.saturating_add(freq2.frequency),
238                node_idx: node_idx,
239            };
240
241            nodes.push(node);
242            frequencies.push(node_freq);
243        }
244
245        // We use a `top` variable as virtual extension of `stack` in order to
246        // have less `unwrap`s.
247        let mut stack: ArrayVec<[u16; 24]> = ArrayVec::new();
248        let mut top = ROOT_IDX;
249
250        let mut bits = 0;
251        let mut first = true;
252
253        // Use a depth-first traversal of the tree, exploring the left children
254        // of each node first, in order to set the bit patterns of the leaves.
255        loop {
256            // On first iteration, don't try to go up the tree.
257            if !first {
258                if let Some(t) = stack.pop() {
259                    top = t;
260                } else {
261                    break;
262                }
263                let b = 1 << stack.len().assert_u8();
264                if bits & b != 0 {
265                    bits &= !b;
266                    continue;
267                }
268                bits |= b;
269                stack.push(top);
270                top = nodes[top.usize()].children[1];
271            }
272            first = false;
273
274            while top >= NUM_SYMBOLS {
275                stack.push(top);
276                top = nodes[top.usize()].children[0];
277            }
278
279            nodes[top.usize()] = SymbolRepr {
280                bits: bits,
281                num_bits: stack.len().assert_u8(),
282            }
283            .to_node();
284        }
285
286        let mut result = Huffman {
287            nodes: [NODE_SENTINEL; NUM_NODES],
288        };
289        assert!(result.nodes.iter_mut().set_from(nodes.iter().cloned()) == NUM_NODES);
290        result
291    }
292    fn compressed_bit_len(&self, input: &[u8]) -> usize {
293        input
294            .iter()
295            .map(|&b| self.symbol_bit_length(b.u16()))
296            .fold(0, |s, a| s + a.usize())
297            + self.symbol_bit_length(EOF).usize()
298    }
299    pub fn compressed_len(&self, input: &[u8]) -> usize {
300        (self.compressed_bit_len(input) + 7) / 8
301    }
302    /// This function returns the number of bytes the reference implementation
303    /// uses to compress the input bytes.
304    ///
305    /// This might differ by 1 from `compressed_len` in case the compressed bit
306    /// stream would perfectly fit into bytes.
307    pub fn compressed_len_bug(&self, input: &[u8]) -> usize {
308        self.compressed_bit_len(input) / 8 + 1
309    }
310    pub fn compress<'a, B: Buffer<'a>>(
311        &self,
312        input: &[u8],
313        buffer: B,
314    ) -> Result<&'a [u8], buffer::CapacityError> {
315        with_buffer(buffer, |b| self.compress_impl(input, b, false))
316    }
317    pub fn compress_into_vec(&self, input: &[u8]) -> Vec<u8> {
318        // At most 3 bytes per symbol, i.e. input byte. Plus EOF symbol.
319        let mut result = Vec::with_capacity(input.len() * 3 + 3);
320        self.compress(input, &mut result).unwrap();
321        result.shrink_to_fit();
322        result
323    }
324    pub fn compress_bug<'a, B: Buffer<'a>>(
325        &self,
326        input: &[u8],
327        buffer: B,
328    ) -> Result<&'a [u8], buffer::CapacityError> {
329        with_buffer(buffer, |b| self.compress_impl(input, b, true))
330    }
331    fn compress_impl<'d, 's>(
332        &self,
333        input: &[u8],
334        mut buffer: BufferRef<'d, 's>,
335        bug: bool,
336    ) -> Result<&'d [u8], buffer::CapacityError> {
337        unsafe {
338            let len = self
339                .compress_impl_unsafe(input, buffer.uninitialized_mut(), bug)
340                .map_err(|()| buffer::CapacityError)?;
341            buffer.advance(len);
342            Ok(buffer.initialized())
343        }
344    }
345    fn compress_impl_unsafe(
346        &self,
347        input: &[u8],
348        buffer: &mut [u8],
349        bug: bool,
350    ) -> Result<usize, ()> {
351        let mut len = 0;
352        let mut output = buffer.into_iter();
353        let mut output_byte = 0;
354        let mut num_output_bits = 0;
355        for s in input.into_iter().map(|b| b.u16()).chain(Some(EOF)) {
356            let symbol = self.get_node(s).unwrap_err();
357            let mut bits_written = 0;
358            if symbol.num_bits >= 8 - num_output_bits {
359                output_byte |= (symbol.bits << num_output_bits) as u8;
360                *output.next().ok_or(())? = output_byte;
361                len += 1;
362                bits_written += 8 - num_output_bits;
363                while symbol.num_bits - bits_written >= 8 {
364                    output_byte = (symbol.bits >> bits_written) as u8;
365                    *output.next().ok_or(())? = output_byte;
366                    len += 1;
367                    bits_written += 8;
368                }
369                num_output_bits = 0;
370                output_byte = 0;
371            }
372            output_byte |= ((symbol.bits >> bits_written) << num_output_bits) as u8;
373            num_output_bits += symbol.num_bits - bits_written;
374        }
375        if num_output_bits > 0 || bug {
376            *output.next().ok_or(())? = output_byte;
377            len += 1;
378        }
379        Ok(len)
380    }
381
382    pub fn decompress<'a, B: Buffer<'a>>(
383        &self,
384        input: &[u8],
385        buffer: B,
386    ) -> Result<&'a [u8], DecompressionError> {
387        with_buffer(buffer, |b| self.decompress_impl(input, b))
388    }
389    pub fn decompress_into_vec(&self, input: &[u8]) -> Result<Vec<u8>, InvalidInput> {
390        // At most one output byte per input bit.
391        let mut result = Vec::with_capacity(input.len() * 8);
392        match self.decompress(input, &mut result) {
393            Ok(_) => {}
394            Err(DecompressionError::InvalidInput) => return Err(InvalidInput),
395            // If the buffer does not suffice, it means we have a runaway
396            // decompression.
397            Err(DecompressionError::Capacity(buffer::CapacityError)) => return Err(InvalidInput),
398        }
399        result.shrink_to_fit();
400        Ok(result)
401    }
402    fn decompress_impl<'d, 's>(
403        &self,
404        input: &[u8],
405        mut buffer: BufferRef<'d, 's>,
406    ) -> Result<&'d [u8], DecompressionError> {
407        unsafe {
408            let len = self
409                .decompress_unsafe(input, buffer.uninitialized_mut())
410                .map_err(|()| DecompressionError::Capacity(buffer::CapacityError))?;
411            buffer.advance(len);
412            Ok(buffer.initialized())
413        }
414    }
415    fn decompress_unsafe(&self, input: &[u8], buffer: &mut [u8]) -> Result<usize, ()> {
416        let mut len = 0;
417        {
418            let mut input = input.into_iter();
419            let mut output = buffer.into_iter();
420            let root = self.get_node(ROOT_IDX).unwrap();
421            let mut node = root;
422            'outer: loop {
423                let &byte = input.next().unwrap_or(&0);
424                for bit in Bits::new(byte) {
425                    let new_idx = node.children[bit as usize];
426                    if let Ok(n) = self.get_node(new_idx) {
427                        node = n;
428                    } else {
429                        if new_idx == EOF {
430                            break 'outer;
431                        }
432                        *output.next().ok_or(())? = new_idx.assert_u8();
433                        len += 1;
434                        node = root;
435                    }
436                }
437            }
438        }
439        Ok(len)
440    }
441    fn symbol_bit_length(&self, idx: u16) -> u32 {
442        self.get_node(idx).unwrap_err().num_bits()
443    }
444    fn get_node(&self, idx: u16) -> Result<Node, SymbolRepr> {
445        let n = self.nodes[idx.usize()];
446        if idx >= NUM_SYMBOLS {
447            Ok(n)
448        } else {
449            Err(n.to_symbol_repr())
450        }
451    }
452    pub fn repr(&self) -> Repr {
453        Repr {
454            repr: &self.nodes[..NUM_SYMBOLS.usize()],
455        }
456    }
457}
458
/// A node of the Huffman tree.
///
/// For internal nodes, `children[0]` and `children[1]` are the node indices
/// reached by a `0` bit and a `1` bit respectively. Leaf nodes reuse the same
/// storage to hold a packed `SymbolRepr` (see `Node::to_symbol_repr` and
/// `SymbolRepr::to_node`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct Node {
    children: [u16; 2],
}

/// The code of one symbol: `num_bits` bits stored in `bits`, least
/// significant bit first (bit 0 is the first bit of the code).
#[doc(hidden)]
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct SymbolRepr {
    bits: u32, // u24
    num_bits: u8,
}

/// Placeholder for nodes that have not been filled in yet.
const NODE_SENTINEL: Node = Node { children: [!0, !0] };

impl Node {
    /// Unpacks a leaf node into its symbol code.
    ///
    /// Layout: the high byte of `children[0]` is `num_bits`, its low byte is
    /// bits 16..24 of the code, and `children[1]` is bits 0..16.
    fn to_symbol_repr(self) -> SymbolRepr {
        SymbolRepr {
            bits: ((self.children[0] & 0xff) as u32) << 16 | self.children[1] as u32,
            num_bits: (self.children[0] >> 8) as u8,
        }
    }
}

impl SymbolRepr {
    /// Packs the code into a leaf node; the inverse of
    /// `Node::to_symbol_repr`.
    ///
    /// # Panics
    ///
    /// Panics if `bits` does not fit into 24 bits.
    fn to_node(self) -> Node {
        assert!(self.bits >> 24 == 0);
        Node {
            children: [
                (self.num_bits as u16) << 8 | (self.bits >> 16) as u16,
                self.bits as u16,
            ],
        }
    }
    /// Length of the code in bits.
    pub fn num_bits(self) -> u32 {
        u32::from(self.num_bits)
    }
    /// Returns bit `idx` of the code (bit 0 comes first in the stream).
    ///
    /// Positions at or beyond 32 lie outside the code storage and read as
    /// `0` instead of overflowing the shift. `num_bits` is a `u8`, so a
    /// decoded placeholder node can claim up to 255 bits.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= self.num_bits()`.
    pub fn bit(self, idx: u32) -> bool {
        assert!(idx < self.num_bits());
        self.bits.checked_shr(idx).map_or(false, |b| (b & 1) != 0)
    }
}

impl fmt::Debug for SymbolRepr {
    /// Writes the code as a string of `0`/`1`, first bit first.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        for i in 0..self.num_bits() {
            f.write_char(if self.bit(i) { '1' } else { '0' })?;
        }
        Ok(())
    }
}

impl fmt::Display for SymbolRepr {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Debug::fmt(self, f)
    }
}
515
#[cfg(test)]
mod test {
    use super::{Node, SymbolRepr};
    use quickcheck::quickcheck;

    quickcheck! {
        // Any node survives the trip through `SymbolRepr` and back.
        fn roundtrip_node(pair: (u16, u16)) -> bool {
            let node = Node { children: [pair.0, pair.1] };
            node.to_symbol_repr().to_node() == node
        }

        // Any code that fits in 24 bits survives the trip through `Node`
        // and back; wider values are masked since `to_node` rejects them.
        fn roundtrip_symbol(pair: (u32, u8)) -> bool {
            let bits = pair.0 & 0x00ff_ffff;
            let symbol = SymbolRepr { bits: bits, num_bits: pair.1 };
            symbol.to_node().to_symbol_repr() == symbol
        }
    }
}