1use arrayvec::ArrayVec;
2use buffer::with_buffer;
3use buffer::Buffer;
4use buffer::BufferRef;
5use itertools::Itertools;
6use libtw2_common::num::Cast;
7use std::error;
8use std::fmt;
9use std::fmt::Write;
10use std::slice;
11
12pub fn compress_into<'a, B: Buffer<'a>>(
20 input: &[u8],
21 buffer: B,
22) -> Result<&'a [u8], buffer::CapacityError> {
23 instances::TEEWORLDS.compress(input, buffer)
24}
25
26pub fn compress(input: &[u8]) -> Vec<u8> {
28 instances::TEEWORLDS.compress_into_vec(input)
29}
30
31pub fn decompress_into<'a, B: Buffer<'a>>(
39 input: &[u8],
40 buffer: B,
41) -> Result<&'a [u8], DecompressionError> {
42 instances::TEEWORLDS.decompress(input, buffer)
43}
44
45pub fn decompress(input: &[u8]) -> Result<Vec<u8>, InvalidInput> {
52 instances::TEEWORLDS.decompress_into_vec(input)
53}
54
// Ready-made `Huffman` tables. `instances::TEEWORLDS` is the table used by
// the free functions `compress`, `compress_into`, `decompress` and
// `decompress_into`.
#[doc(hidden)]
pub mod instances;
57
/// Symbol index reserved for the end-of-stream marker; byte values use
/// indices `0..EOF`.
const EOF: u16 = 256;
/// Number of leaf symbols: the 256 byte values plus `EOF`.
#[doc(hidden)]
pub const NUM_SYMBOLS: u16 = EOF + 1;
/// Node count of a binary tree built by pairwise merging of `NUM_SYMBOLS`
/// leaves (`NUM_SYMBOLS` leaves + `NUM_SYMBOLS - 1` internal nodes).
const NUM_NODES: usize = 2 * NUM_SYMBOLS as usize - 1;
/// Index of the tree root, which is the last node created during
/// construction.
const ROOT_IDX: u16 = NUM_NODES as u16 - 1;
/// Number of per-byte frequencies needed to build a table.
#[doc(hidden)]
pub const NUM_FREQUENCIES: usize = 256;
65
/// A Huffman code table over the 256 byte values plus an end-of-stream
/// symbol.
///
/// `nodes[..NUM_SYMBOLS]` are leaf slots, each storing the packed code of its
/// symbol (decoded with `Node::to_symbol_repr`). `nodes[NUM_SYMBOLS..]` are
/// the internal tree nodes, with the root at `ROOT_IDX`.
#[doc(hidden)]
pub struct Huffman {
    nodes: [Node; NUM_NODES],
}
70
/// Error returned by `Huffman::decompress` and `decompress_into`.
#[derive(Debug)]
pub enum DecompressionError {
    /// The output buffer filled up before the end-of-stream symbol was
    /// decoded. Malformed input that never reaches the end-of-stream symbol
    /// also ends up here, because decoding continues until the buffer is
    /// full.
    Capacity(buffer::CapacityError),
    /// The input is not a valid Huffman-compressed stream.
    InvalidInput,
}
81
82impl fmt::Display for DecompressionError {
83 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
84 use self::DecompressionError::*;
85 match self {
86 Capacity(_) => "output buffer too small",
87 InvalidInput => "input is not a valid huffman compression",
88 }
89 .fmt(f)
90 }
91}
92
93impl error::Error for DecompressionError {}
94
/// Error returned by `decompress` and `Huffman::decompress_into_vec` when the
/// input cannot be decoded.
///
/// This is a field-less marker type, so it is cheap to copy and compare.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct InvalidInput;
99
100impl fmt::Display for InvalidInput {
101 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
102 "input is not a valid huffman compression".fmt(f)
103 }
104}
105
106impl error::Error for InvalidInput {}
107
108impl From<InvalidInput> for DecompressionError {
109 fn from(InvalidInput: InvalidInput) -> DecompressionError {
110 DecompressionError::InvalidInput
111 }
112}
113
/// Borrowed view of the leaf slots of a `Huffman` table, one per symbol in
/// index order, as returned by `Huffman::repr`.
#[doc(hidden)]
#[derive(Clone)]
pub struct Repr<'a> {
    repr: &'a [Node],
}

/// Iterator over a `Repr`, decoding each leaf slot into its `SymbolRepr`.
#[doc(hidden)]
#[derive(Clone)]
pub struct ReprIter<'a> {
    iter: slice::Iter<'a, Node>,
}
125
126impl<'a> IntoIterator for Repr<'a> {
127 type Item = SymbolRepr;
128 type IntoIter = ReprIter<'a>;
129 fn into_iter(self) -> ReprIter<'a> {
130 ReprIter {
131 iter: self.repr.iter(),
132 }
133 }
134}
135
136impl<'a> Iterator for ReprIter<'a> {
137 type Item = SymbolRepr;
138 fn next(&mut self) -> Option<SymbolRepr> {
139 self.iter.next().map(|n| n.to_symbol_repr())
140 }
141 fn size_hint(&self) -> (usize, Option<usize>) {
142 self.iter.size_hint()
143 }
144}
145
146impl<'a> ExactSizeIterator for ReprIter<'a> {
147 fn len(&self) -> usize {
148 self.iter.len()
149 }
150}
151
152impl<'a> DoubleEndedIterator for ReprIter<'a> {
153 fn next_back(&mut self) -> Option<SymbolRepr> {
154 self.iter.next_back().map(|n| n.to_symbol_repr())
155 }
156}
157
/// Iterator over the eight bits of a byte, least-significant bit first.
struct Bits {
    byte: u8,
    remaining_bits: u8,
}

impl Bits {
    /// Starts iterating over all eight bits of `byte`.
    fn new(byte: u8) -> Bits {
        Bits {
            byte,
            remaining_bits: 8,
        }
    }
}

impl Iterator for Bits {
    type Item = bool;

    /// Yields the lowest unread bit and shifts it out of the byte.
    fn next(&mut self) -> Option<bool> {
        match self.remaining_bits {
            0 => None,
            _ => {
                let lowest = (self.byte & 1) == 1;
                self.byte >>= 1;
                self.remaining_bits -= 1;
                Some(lowest)
            }
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = usize::from(self.remaining_bits);
        (left, Some(left))
    }
}

impl ExactSizeIterator for Bits {
    /// Number of bits not yet yielded.
    fn len(&self) -> usize {
        usize::from(self.remaining_bits)
    }
}
193
/// Entry of the work list used while building the tree: the weight of a
/// subtree (saturating at `u32::MAX` when merged) and the index of its root
/// node.
#[derive(Clone, Copy, Debug)]
struct Frequency {
    frequency: u32,
    node_idx: u16,
}
199
200impl Huffman {
201 pub fn from_frequencies(frequencies: &[u32]) -> Huffman {
202 assert!(frequencies.len() == 256);
203 let array = unsafe { &*(frequencies as *const _ as *const _) };
204 Huffman::from_frequencies_array(array)
205 }
206 pub fn from_frequencies_array(frequencies: &[u32; 256]) -> Huffman {
207 let mut frequencies: ArrayVec<[_; 512]> = frequencies
208 .iter()
209 .cloned()
210 .enumerate()
211 .map(|(i, f)| Frequency {
212 frequency: f,
213 node_idx: i.assert_u16(),
214 })
215 .collect();
216 frequencies.push(Frequency {
217 frequency: 1,
218 node_idx: EOF,
219 });
220
221 let mut nodes: ArrayVec<[_; 1024]> = (0..NUM_SYMBOLS).map(|_| NODE_SENTINEL).collect();
222
223 while frequencies.len() > 1 {
224 frequencies.sort_by(|a, b| b.frequency.cmp(&a.frequency));
226
227 let freq1 = frequencies.pop().unwrap();
229 let freq2 = frequencies.pop().unwrap();
230
231 let node = Node {
233 children: [freq1.node_idx, freq2.node_idx],
234 };
235 let node_idx = nodes.len().assert_u16();
236 let node_freq = Frequency {
237 frequency: freq1.frequency.saturating_add(freq2.frequency),
238 node_idx: node_idx,
239 };
240
241 nodes.push(node);
242 frequencies.push(node_freq);
243 }
244
245 let mut stack: ArrayVec<[u16; 24]> = ArrayVec::new();
248 let mut top = ROOT_IDX;
249
250 let mut bits = 0;
251 let mut first = true;
252
253 loop {
256 if !first {
258 if let Some(t) = stack.pop() {
259 top = t;
260 } else {
261 break;
262 }
263 let b = 1 << stack.len().assert_u8();
264 if bits & b != 0 {
265 bits &= !b;
266 continue;
267 }
268 bits |= b;
269 stack.push(top);
270 top = nodes[top.usize()].children[1];
271 }
272 first = false;
273
274 while top >= NUM_SYMBOLS {
275 stack.push(top);
276 top = nodes[top.usize()].children[0];
277 }
278
279 nodes[top.usize()] = SymbolRepr {
280 bits: bits,
281 num_bits: stack.len().assert_u8(),
282 }
283 .to_node();
284 }
285
286 let mut result = Huffman {
287 nodes: [NODE_SENTINEL; NUM_NODES],
288 };
289 assert!(result.nodes.iter_mut().set_from(nodes.iter().cloned()) == NUM_NODES);
290 result
291 }
292 fn compressed_bit_len(&self, input: &[u8]) -> usize {
293 input
294 .iter()
295 .map(|&b| self.symbol_bit_length(b.u16()))
296 .fold(0, |s, a| s + a.usize())
297 + self.symbol_bit_length(EOF).usize()
298 }
299 pub fn compressed_len(&self, input: &[u8]) -> usize {
300 (self.compressed_bit_len(input) + 7) / 8
301 }
302 pub fn compressed_len_bug(&self, input: &[u8]) -> usize {
308 self.compressed_bit_len(input) / 8 + 1
309 }
310 pub fn compress<'a, B: Buffer<'a>>(
311 &self,
312 input: &[u8],
313 buffer: B,
314 ) -> Result<&'a [u8], buffer::CapacityError> {
315 with_buffer(buffer, |b| self.compress_impl(input, b, false))
316 }
317 pub fn compress_into_vec(&self, input: &[u8]) -> Vec<u8> {
318 let mut result = Vec::with_capacity(input.len() * 3 + 3);
320 self.compress(input, &mut result).unwrap();
321 result.shrink_to_fit();
322 result
323 }
324 pub fn compress_bug<'a, B: Buffer<'a>>(
325 &self,
326 input: &[u8],
327 buffer: B,
328 ) -> Result<&'a [u8], buffer::CapacityError> {
329 with_buffer(buffer, |b| self.compress_impl(input, b, true))
330 }
331 fn compress_impl<'d, 's>(
332 &self,
333 input: &[u8],
334 mut buffer: BufferRef<'d, 's>,
335 bug: bool,
336 ) -> Result<&'d [u8], buffer::CapacityError> {
337 unsafe {
338 let len = self
339 .compress_impl_unsafe(input, buffer.uninitialized_mut(), bug)
340 .map_err(|()| buffer::CapacityError)?;
341 buffer.advance(len);
342 Ok(buffer.initialized())
343 }
344 }
345 fn compress_impl_unsafe(
346 &self,
347 input: &[u8],
348 buffer: &mut [u8],
349 bug: bool,
350 ) -> Result<usize, ()> {
351 let mut len = 0;
352 let mut output = buffer.into_iter();
353 let mut output_byte = 0;
354 let mut num_output_bits = 0;
355 for s in input.into_iter().map(|b| b.u16()).chain(Some(EOF)) {
356 let symbol = self.get_node(s).unwrap_err();
357 let mut bits_written = 0;
358 if symbol.num_bits >= 8 - num_output_bits {
359 output_byte |= (symbol.bits << num_output_bits) as u8;
360 *output.next().ok_or(())? = output_byte;
361 len += 1;
362 bits_written += 8 - num_output_bits;
363 while symbol.num_bits - bits_written >= 8 {
364 output_byte = (symbol.bits >> bits_written) as u8;
365 *output.next().ok_or(())? = output_byte;
366 len += 1;
367 bits_written += 8;
368 }
369 num_output_bits = 0;
370 output_byte = 0;
371 }
372 output_byte |= ((symbol.bits >> bits_written) << num_output_bits) as u8;
373 num_output_bits += symbol.num_bits - bits_written;
374 }
375 if num_output_bits > 0 || bug {
376 *output.next().ok_or(())? = output_byte;
377 len += 1;
378 }
379 Ok(len)
380 }
381
382 pub fn decompress<'a, B: Buffer<'a>>(
383 &self,
384 input: &[u8],
385 buffer: B,
386 ) -> Result<&'a [u8], DecompressionError> {
387 with_buffer(buffer, |b| self.decompress_impl(input, b))
388 }
389 pub fn decompress_into_vec(&self, input: &[u8]) -> Result<Vec<u8>, InvalidInput> {
390 let mut result = Vec::with_capacity(input.len() * 8);
392 match self.decompress(input, &mut result) {
393 Ok(_) => {}
394 Err(DecompressionError::InvalidInput) => return Err(InvalidInput),
395 Err(DecompressionError::Capacity(buffer::CapacityError)) => return Err(InvalidInput),
398 }
399 result.shrink_to_fit();
400 Ok(result)
401 }
402 fn decompress_impl<'d, 's>(
403 &self,
404 input: &[u8],
405 mut buffer: BufferRef<'d, 's>,
406 ) -> Result<&'d [u8], DecompressionError> {
407 unsafe {
408 let len = self
409 .decompress_unsafe(input, buffer.uninitialized_mut())
410 .map_err(|()| DecompressionError::Capacity(buffer::CapacityError))?;
411 buffer.advance(len);
412 Ok(buffer.initialized())
413 }
414 }
415 fn decompress_unsafe(&self, input: &[u8], buffer: &mut [u8]) -> Result<usize, ()> {
416 let mut len = 0;
417 {
418 let mut input = input.into_iter();
419 let mut output = buffer.into_iter();
420 let root = self.get_node(ROOT_IDX).unwrap();
421 let mut node = root;
422 'outer: loop {
423 let &byte = input.next().unwrap_or(&0);
424 for bit in Bits::new(byte) {
425 let new_idx = node.children[bit as usize];
426 if let Ok(n) = self.get_node(new_idx) {
427 node = n;
428 } else {
429 if new_idx == EOF {
430 break 'outer;
431 }
432 *output.next().ok_or(())? = new_idx.assert_u8();
433 len += 1;
434 node = root;
435 }
436 }
437 }
438 }
439 Ok(len)
440 }
441 fn symbol_bit_length(&self, idx: u16) -> u32 {
442 self.get_node(idx).unwrap_err().num_bits()
443 }
444 fn get_node(&self, idx: u16) -> Result<Node, SymbolRepr> {
445 let n = self.nodes[idx.usize()];
446 if idx >= NUM_SYMBOLS {
447 Ok(n)
448 } else {
449 Err(n.to_symbol_repr())
450 }
451 }
452 pub fn repr(&self) -> Repr {
453 Repr {
454 repr: &self.nodes[..NUM_SYMBOLS.usize()],
455 }
456 }
457}
458
/// A tree node. `children[0]` and `children[1]` are the node indices reached
/// on a 0 bit and a 1 bit, respectively. Leaf slots reuse the same two `u16`s
/// to store a packed `SymbolRepr` (see `Node::to_symbol_repr`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct Node {
    children: [u16; 2],
}
463
/// The Huffman code of a single symbol.
///
/// Only the low `num_bits` bits of `bits` are meaningful. Bit 0 is the first
/// bit emitted, i.e. the branch taken at the root.
#[doc(hidden)]
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct SymbolRepr {
    bits: u32,
    num_bits: u8,
}
470
471const NODE_SENTINEL: Node = Node { children: [!0, !0] };
472
473impl Node {
474 fn to_symbol_repr(self) -> SymbolRepr {
475 SymbolRepr {
476 bits: ((self.children[0] & 0xff) as u32) << 16 | self.children[1] as u32,
477 num_bits: (self.children[0] >> 8) as u8,
478 }
479 }
480}
481
482impl SymbolRepr {
483 fn to_node(self) -> Node {
484 assert!(self.bits >> 24 == 0);
485 Node {
486 children: [
487 (self.num_bits as u16) << 8 | (self.bits >> 16) as u16,
488 self.bits as u16,
489 ],
490 }
491 }
492 pub fn num_bits(self) -> u32 {
493 self.num_bits.u32()
494 }
495 pub fn bit(self, idx: u32) -> bool {
496 assert!(idx < self.num_bits());
497 ((self.bits >> idx) & 1) != 0
498 }
499}
500
501impl fmt::Debug for SymbolRepr {
502 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
503 for i in 0..self.num_bits() {
504 f.write_char(if self.bit(i) { '1' } else { '0' })?;
505 }
506 Ok(())
507 }
508}
509
510impl fmt::Display for SymbolRepr {
511 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
512 fmt::Debug::fmt(self, f)
513 }
514}
515
#[cfg(test)]
mod test {
    use super::{Node, SymbolRepr};
    use quickcheck::quickcheck;

    quickcheck! {
        // Any node value survives a trip through the packed-code form.
        fn roundtrip_node(v: (u16, u16)) -> bool {
            let node = Node { children: [v.0, v.1] };
            node == node.to_symbol_repr().to_node()
        }

        // Any code whose bits fit in 24 bits survives a trip through a node.
        fn roundtrip_symbol(v: (u32, u8)) -> bool {
            let symbol = SymbolRepr { bits: v.0 & 0x00ff_ffff, num_bits: v.1 };
            symbol == symbol.to_node().to_symbol_repr()
        }
    }
}