Skip to main content

palimpsest_dataflow/trace/implementations/
huffman_container.rs

1//! A slice container that Huffman encodes its contents.
2
3use std::collections::BTreeMap;
4use timely::container::PushInto;
5
6use crate::trace::implementations::{BatchContainer, OffsetList};
7
8use self::encoded::Encoded;
9use self::huffman::Huffman;
10use self::wrapper::Wrapped;
11
12/// A container that contains slices `[B]` as items.
13pub struct HuffmanContainer<B: Ord + Clone> {
14    /// Either encoded data or raw data.
15    inner: Result<(Huffman<B>, Vec<u8>), Vec<B>>,
16    /// Offsets that bound each contained slice.
17    ///
18    /// The length will be one greater than the number of contained items.
19    offsets: OffsetList,
20    /// Counts of the number of each pattern we've seen.
21    stats: BTreeMap<B, i64>,
22}
23
24impl<B: Ord + Clone> HuffmanContainer<B> {
25    /// Prints statistics about encoded containers.
26    pub fn print(&self) {
27        if let Ok((_huff, bytes)) = &self.inner {
28            println!(
29                "Bytes: {:?}, Symbols: {:?}",
30                bytes.len(),
31                self.stats.values().sum::<i64>()
32            );
33        }
34    }
35}
36
37impl<'a, B: Ord + Clone + 'static> PushInto<&'a Vec<B>> for HuffmanContainer<B> {
38    fn push_into(&mut self, item: &'a Vec<B>) {
39        for x in item.iter() {
40            *self.stats.entry(x.clone()).or_insert(0) += 1;
41        }
42        match &mut self.inner {
43            Ok((huffman, bytes)) => {
44                bytes.extend(huffman.encode(item.iter()));
45                self.offsets.push(bytes.len());
46            }
47            Err(raw) => {
48                raw.extend(item.iter().cloned());
49                self.offsets.push(raw.len());
50            }
51        }
52    }
53}
54
55impl<'a, B: Ord + Clone + 'static> PushInto<Wrapped<'a, B>> for HuffmanContainer<B> {
56    fn push_into(&mut self, item: Wrapped<'a, B>) {
57        match item.decode() {
58            Ok(decoded) => {
59                for x in decoded {
60                    *self.stats.entry(x.clone()).or_insert(0) += 1;
61                }
62            }
63            Err(symbols) => {
64                for x in symbols.iter() {
65                    *self.stats.entry(x.clone()).or_insert(0) += 1;
66                }
67            }
68        }
69        match (item.decode(), &mut self.inner) {
70            (Ok(decoded), Ok((huffman, bytes))) => {
71                bytes.extend(huffman.encode(decoded));
72                self.offsets.push(bytes.len());
73            }
74            (Ok(decoded), Err(raw)) => {
75                raw.extend(decoded.cloned());
76                self.offsets.push(raw.len());
77            }
78            (Err(symbols), Ok((huffman, bytes))) => {
79                bytes.extend(huffman.encode(symbols.iter()));
80                self.offsets.push(bytes.len());
81            }
82            (Err(symbols), Err(raw)) => {
83                raw.extend(symbols.iter().cloned());
84                self.offsets.push(raw.len());
85            }
86        }
87    }
88}
89
90impl<B: Ord + Clone + 'static> BatchContainer for HuffmanContainer<B> {
91    type Owned = Vec<B>;
92    type ReadItem<'a> = Wrapped<'a, B>;
93
94    fn into_owned<'a>(item: Self::ReadItem<'a>) -> Self::Owned {
95        match item.decode() {
96            Ok(decode) => decode.cloned().collect(),
97            Err(bytes) => bytes.to_vec(),
98        }
99    }
100    fn clone_onto<'a>(item: Self::ReadItem<'a>, other: &mut Self::Owned) {
101        other.clear();
102        match item.decode() {
103            Ok(decode) => other.extend(decode.cloned()),
104            Err(bytes) => other.extend_from_slice(bytes),
105        }
106    }
107    fn reborrow<'b, 'a: 'b>(item: Self::ReadItem<'a>) -> Self::ReadItem<'b> {
108        item
109    }
110
111    fn push_ref(&mut self, item: Self::ReadItem<'_>) {
112        self.push_into(item)
113    }
114    fn push_own(&mut self, item: &Self::Owned) {
115        self.push_into(item)
116    }
117
118    fn clear(&mut self) {
119        *self = Self::default();
120    }
121
122    fn with_capacity(size: usize) -> Self {
123        let mut offsets = OffsetList::with_capacity(size + 1);
124        offsets.push(0);
125        Self {
126            inner: Err(Vec::with_capacity(size)),
127            offsets,
128            stats: Default::default(),
129        }
130    }
131    fn merge_capacity(cont1: &Self, cont2: &Self) -> Self {
132        if cont1.len() > 0 {
133            cont1.print();
134        }
135        if cont2.len() > 0 {
136            cont2.print();
137        }
138
139        let mut counts = BTreeMap::default();
140        for (symbol, count) in cont1.stats.iter() {
141            *counts.entry(symbol.clone()).or_insert(0) += count;
142        }
143        for (symbol, count) in cont2.stats.iter() {
144            *counts.entry(symbol.clone()).or_insert(0) += count;
145        }
146
147        let bytes = Vec::with_capacity(counts.values().cloned().sum::<i64>() as usize);
148        let huffman = Huffman::create_from(counts);
149        let inner = Ok((huffman, bytes));
150        // : Err(Vec::with_capacity(length))
151
152        let length = cont1.offsets.len() + cont2.offsets.len() - 2;
153        let mut offsets = OffsetList::with_capacity(length + 1);
154        offsets.push(0);
155        Self {
156            inner,
157            offsets,
158            stats: Default::default(),
159        }
160    }
161    fn index(&self, index: usize) -> Self::ReadItem<'_> {
162        let lower = self.offsets.index(index);
163        let upper = self.offsets.index(index + 1);
164        match &self.inner {
165            Ok((huffman, bytes)) => Wrapped::encoded(Encoded::new(huffman, &bytes[lower..upper])),
166            Err(raw) => Wrapped::decoded(&raw[lower..upper]),
167        }
168    }
169    fn len(&self) -> usize {
170        self.offsets.len() - 1
171    }
172}
173/// Default implementation introduces a first offset.
174impl<B: Ord + Clone> Default for HuffmanContainer<B> {
175    fn default() -> Self {
176        let mut offsets = OffsetList::with_capacity(1);
177        offsets.push(0);
178        Self {
179            inner: Err(Vec::new()),
180            offsets,
181            stats: Default::default(),
182        }
183    }
184}
185
186mod wrapper {
187
188    use super::Encoded;
189
190    pub struct Wrapped<'a, B: Ord> {
191        pub(crate) inner: Result<Encoded<'a, B>, &'a [B]>,
192    }
193
194    impl<'a, B: Ord> Wrapped<'a, B> {
195        /// Returns either a decoding iterator, or just the bytes themselves.
196        pub fn decode(&'a self) -> Result<impl Iterator<Item = &'a B> + 'a, &'a [B]> {
197            match &self.inner {
198                Ok(encoded) => Ok(encoded.decode()),
199                Err(symbols) => Err(symbols),
200            }
201        }
202        /// A wrapper around an encoded sequence.
203        pub fn encoded(e: Encoded<'a, B>) -> Self {
204            Self { inner: Ok(e) }
205        }
206        /// A wrapper around a decoded sequence.
207        pub fn decoded(d: &'a [B]) -> Self {
208            Self { inner: Err(d) }
209        }
210    }
211
212    impl<'a, B: Ord> Copy for Wrapped<'a, B> {}
213    impl<'a, B: Ord> Clone for Wrapped<'a, B> {
214        fn clone(&self) -> Self {
215            *self
216        }
217    }
218
219    use std::cmp::Ordering;
220    impl<'a, 'b, B: Ord> PartialEq<Wrapped<'a, B>> for Wrapped<'b, B> {
221        fn eq(&self, other: &Wrapped<'a, B>) -> bool {
222            match (self.decode(), other.decode()) {
223                (Ok(decode1), Ok(decode2)) => decode1.eq(decode2),
224                (Ok(decode1), Err(bytes2)) => decode1.eq(bytes2.iter()),
225                (Err(bytes1), Ok(decode2)) => bytes1.iter().eq(decode2),
226                (Err(bytes1), Err(bytes2)) => bytes1.eq(bytes2),
227            }
228        }
229    }
230    impl<'a, B: Ord> Eq for Wrapped<'a, B> {}
231    impl<'a, 'b, B: Ord> PartialOrd<Wrapped<'a, B>> for Wrapped<'b, B> {
232        fn partial_cmp(&self, other: &Wrapped<'a, B>) -> Option<Ordering> {
233            match (self.decode(), other.decode()) {
234                (Ok(decode1), Ok(decode2)) => decode1.partial_cmp(decode2),
235                (Ok(decode1), Err(bytes2)) => decode1.partial_cmp(bytes2.iter()),
236                (Err(bytes1), Ok(decode2)) => bytes1.iter().partial_cmp(decode2),
237                (Err(bytes1), Err(bytes2)) => bytes1.partial_cmp(bytes2),
238            }
239        }
240    }
241    impl<'a, B: Ord> Ord for Wrapped<'a, B> {
242        fn cmp(&self, other: &Self) -> Ordering {
243            self.partial_cmp(other).unwrap()
244        }
245    }
246}
247
248/// Wrapper around a Huffman decoder and byte slices, decodeable to a byte sequence.
249mod encoded {
250
251    use super::Huffman;
252
253    /// Welcome to GATs!
254    pub struct Encoded<'a, B: Ord> {
255        /// Text that decorates the data.
256        huffman: &'a Huffman<B>,
257        /// The data itself.
258        bytes: &'a [u8],
259    }
260
261    impl<'a, B: Ord> Encoded<'a, B> {
262        /// Returns either a decoding iterator, or just the bytes themselves.
263        pub fn decode(&'a self) -> impl Iterator<Item = &'a B> + 'a {
264            self.huffman.decode(self.bytes.iter().cloned())
265        }
266        pub fn new(huffman: &'a Huffman<B>, bytes: &'a [u8]) -> Self {
267            Self { huffman, bytes }
268        }
269    }
270
271    impl<'a, B: Ord> Copy for Encoded<'a, B> {}
272    impl<'a, B: Ord> Clone for Encoded<'a, B> {
273        fn clone(&self) -> Self {
274            *self
275        }
276    }
277}
278
mod huffman {

    use std::collections::BTreeMap;
    use std::convert::TryInto;

    use self::decoder::Decoder;
    use self::encoder::Encoder;

    /// The longest code, in bits, that [`Encoder`] can emit.
    ///
    /// The encoder appends each code to a `u64` buffer that may already hold up to 7 pending
    /// bits, so any longer code would overflow that buffer.
    const MAX_CODE_BITS: usize = 64 - 7;

    /// Encoding and decoding state for Huffman codes.
    ///
    /// Every non-empty encoded sequence ends with a terminator: a single `1` bit followed by
    /// zero bits up to the next byte boundary. The decoder strips the terminator from the final
    /// byte. Padding is therefore never mistaken for symbols, and trailing symbols are never
    /// lost.
    pub struct Huffman<T: Ord> {
        /// Per-symbol codes. An entry `(bits, code)` means the low `bits` bits of `code` are
        /// written, most significant first.
        encode: BTreeMap<T, (usize, u64)>,
        /// Byte-indexed decoding table.
        decode: [Decode<T>; 256],
    }
    impl<T: Ord> Huffman<T> {
        /// Encodes the provided symbols as a sequence of bytes.
        ///
        /// An empty input encodes to no bytes. Any other input ends with the terminator
        /// described on [`Huffman`]. Each call yields an independently decodable sequence.
        ///
        /// The returned iterator panics on a symbol that was absent from the counts given to
        /// [`Huffman::create_from`].
        pub fn encode<'a, I>(&'a self, symbols: I) -> Encoder<'a, T, I::IntoIter>
        where
            I: IntoIterator<Item = &'a T>,
        {
            Encoder::new(&self.encode, symbols.into_iter())
        }

        /// Decodes bytes produced by [`Huffman::encode`] back into the encoded symbols.
        pub fn decode<I>(&self, bytes: I) -> Decoder<'_, T, I::IntoIter>
        where
            I: IntoIterator<Item = u8>,
        {
            Decoder::new(&self.decode, bytes.into_iter())
        }

        /// Builds a canonical Huffman code from symbol frequencies.
        ///
        /// # Panics
        ///
        /// Panics if the frequencies are so skewed that some code would exceed
        /// [`MAX_CODE_BITS`] bits.
        pub fn create_from(counts: BTreeMap<T, i64>) -> Self
        where
            T: Clone,
        {
            let mut encode = BTreeMap::new();
            let mut decode = Decode::map();

            // With at most one distinct symbol, Huffman's construction would give the lone
            // symbol a zero-bit code, which neither encodes nor decodes anything. Use the
            // one-bit code `0` instead, and let `1` decode to the same symbol so that the
            // table has no holes.
            if counts.len() < 2 {
                if let Some(symbol) = counts.keys().next() {
                    encode.insert(symbol.clone(), (1, 0));
                    Self::insert_decode(&mut decode, symbol, 1, 0);
                    Self::insert_decode(&mut decode, symbol, 1, 1 << 63);
                }
                return Huffman { encode, decode };
            }

            // Repeatedly merge the two lightest subtrees. `BinaryHeap` is a max-heap, so
            // counts are negated to pop the smallest first.
            let mut heap = std::collections::BinaryHeap::new();
            for (item, count) in counts {
                heap.push((-count, Node::Leaf(item)));
            }
            let mut tree = Vec::with_capacity(2 * heap.len() - 1);
            while heap.len() > 1 {
                let (count1, least1) = heap.pop().unwrap();
                let (count2, least2) = heap.pop().unwrap();
                let fork = Node::Fork(tree.len(), tree.len() + 1);
                tree.push(least1);
                tree.push(least2);
                heap.push((count1 + count2, fork));
            }
            tree.push(heap.pop().unwrap().1);

            // Each symbol's depth in the tree is the length of its code.
            let mut levels = Vec::with_capacity(1 + tree.len() / 2);
            let mut todo = vec![(tree.last().unwrap(), 0)];
            while let Some((node, level)) = todo.pop() {
                match node {
                    Node::Leaf(sym) => levels.push((level, sym)),
                    Node::Fork(l, r) => {
                        todo.push((&tree[*l], level + 1));
                        todo.push((&tree[*r], level + 1));
                    }
                }
            }
            // Stable sort: symbols of equal length keep their traversal order.
            levels.sort_by_key(|&(level, _)| level);

            let deepest = levels.last().map_or(0, |&(level, _)| level);
            assert!(
                deepest <= MAX_CODE_BITS,
                "Huffman code length {} exceeds the supported maximum of {} bits",
                deepest,
                MAX_CODE_BITS
            );

            // Assign canonical codes: consecutive within a length, shifted when it grows.
            let mut code: u64 = 0;
            let mut prev_level = 0;
            for (level, sym) in levels {
                if prev_level != level {
                    code <<= level - prev_level;
                    prev_level = level;
                }
                encode.insert(sym.clone(), (level, code));
                // There are at least two leaves, so `level >= 1` and the shift is below 64.
                Self::insert_decode(&mut decode, sym, level, code << (64 - level));
                code += 1;
            }

            for (index, entry) in decode.iter().enumerate() {
                if entry.any_void() {
                    panic!("VOID FOUND: {:?}", index);
                }
            }

            Huffman { encode, decode }
        }

        /// Inserts `symbol`, whose code is the top `bits` bits of `code`, into `map`.
        ///
        /// A code of at most 8 bits fills every slot whose index begins with that code.
        /// Longer codes descend into a nested table keyed by the code's first byte.
        fn insert_decode(map: &mut [Decode<T>; 256], symbol: &T, bits: usize, code: u64)
        where
            T: Clone,
        {
            let byte: u8 = (code >> 56).try_into().unwrap();
            if bits <= 8 {
                for off in 0..(1 << (8 - bits)) {
                    map[(byte as usize) + off] = Decode::Symbol(symbol.clone(), bits);
                }
            } else {
                if let Decode::Void = &map[byte as usize] {
                    map[byte as usize] = Decode::Further(Box::new(Decode::map()));
                }
                if let Decode::Further(next_map) = &mut map[byte as usize] {
                    Self::insert_decode(next_map, symbol, bits - 8, code << 8);
                }
            }
        }
    }
    /// Tree structure for Huffman bit length determination.
    #[derive(Eq, PartialEq, Ord, PartialOrd, Debug)]
    enum Node<T> {
        Leaf(T),
        Fork(usize, usize),
    }

    /// One entry of a byte-indexed decoding table.
    #[derive(Eq, PartialEq, Ord, PartialOrd, Debug, Default)]
    pub enum Decode<T> {
        /// An as-yet unfilled slot.
        #[default]
        Void,
        /// The symbol, and the number of bits (within this table's byte) its code consumes.
        Symbol(T, usize),
        /// A nested table for codes longer than the bits this table covers.
        Further(Box<[Decode<T>; 256]>),
    }

    impl<T> Decode<T> {
        /// Tests whether the map contains any unfilled slots.
        ///
        /// A correctly built map has none. A map with unfilled slots cannot decode some
        /// input byte sequences.
        fn any_void(&self) -> bool {
            match self {
                Decode::Void => true,
                Decode::Symbol(_, _) => false,
                Decode::Further(map) => map.iter().any(|m| m.any_void()),
            }
        }
        /// Creates a new table with every slot unfilled.
        fn map() -> [Decode<T>; 256] {
            let mut vec = Vec::with_capacity(256);
            for _ in 0..256 {
                vec.push(Decode::Void);
            }
            vec.try_into().ok().unwrap()
        }
    }

    /// A table-driven Huffman decoder, written as an iterator.
    mod decoder {

        use std::iter::Peekable;

        use super::Decode;

        #[derive(Clone)]
        pub struct Decoder<'a, T, I: Iterator<Item = u8>> {
            decode: &'a [Decode<T>; 256],
            /// Input bytes. One byte of lookahead identifies the final, terminated byte.
            bytes: Peekable<I>,
            /// Undecoded bits, right-aligned. Only the low `pending_bits` bits are non-zero.
            pending_byte: u16,
            pending_bits: usize,
        }

        impl<'a, T, I: Iterator<Item = u8>> Decoder<'a, T, I> {
            pub fn new(decode: &'a [Decode<T>; 256], bytes: I) -> Self {
                Self {
                    decode,
                    bytes: bytes.peekable(),
                    pending_byte: 0,
                    pending_bits: 0,
                }
            }

            /// Appends the next input byte to the pending bits, if there is one.
            ///
            /// The final input byte ends with the encoder's terminator (a `1` bit, then zero
            /// padding). Those bits are discarded, and only the data bits before them are kept.
            /// Must only be called with fewer than 8 pending bits.
            fn restock(&mut self) {
                if let Some(byte) = self.bytes.next() {
                    if self.bytes.peek().is_some() {
                        self.pending_byte = (self.pending_byte << 8) | u16::from(byte);
                        self.pending_bits += 8;
                    } else if byte != 0 {
                        let padding = byte.trailing_zeros() as usize + 1;
                        let data_bits = 8 - padding;
                        self.pending_byte =
                            (self.pending_byte << data_bits) | (u16::from(byte) >> padding);
                        self.pending_bits += data_bits;
                    }
                    // A zero final byte carries no terminator. It was not produced by the
                    // encoder, and it contributes no bits.
                }
            }

            /// Discards the `bits` most significant pending bits.
            fn consume(&mut self, bits: usize) {
                self.pending_bits -= bits;
                self.pending_byte &= (1 << self.pending_bits) - 1;
            }
        }

        impl<'a, T, I: Iterator<Item = u8>> Iterator for Decoder<'a, T, I> {
            type Item = &'a T;
            fn next(&mut self) -> Option<&'a T> {
                let mut map = self.decode;
                loop {
                    if self.pending_bits < 8 {
                        self.restock();
                    }
                    // Having fewer than 8 bits after restocking means the input is exhausted.
                    // Those bits are left-aligned (zero-filled) to index the table, and each
                    // match is checked to fit within the bits actually present.
                    let index = if self.pending_bits >= 8 {
                        usize::from(self.pending_byte >> (self.pending_bits - 8))
                    } else if self.pending_bits > 0 {
                        usize::from(self.pending_byte << (8 - self.pending_bits))
                    } else {
                        return None;
                    };
                    match &map[index] {
                        Decode::Void => panic!("invalid decoding map"),
                        Decode::Symbol(symbol, bits) => {
                            if *bits > self.pending_bits {
                                // Truncated input: the code needs bits we do not have.
                                return None;
                            }
                            self.consume(*bits);
                            return Some(symbol);
                        }
                        Decode::Further(next_map) => {
                            if self.pending_bits < 8 {
                                return None;
                            }
                            self.consume(8);
                            map = next_map;
                        }
                    }
                }
            }
        }
    }

    /// A table-driven Huffman encoder, written as an iterator.
    mod encoder {

        use std::collections::BTreeMap;

        #[derive(Copy, Clone)]
        pub struct Encoder<'a, T, I> {
            encode: &'a BTreeMap<T, (usize, u64)>,
            symbols: I,
            /// Bits not yet shipped, right-aligned.
            pending_byte: u64,
            pending_bits: usize,
            /// Whether any symbol has been read, meaning a terminator is owed.
            started: bool,
            /// Whether the input is exhausted and the terminator (if owed) has been shipped.
            finished: bool,
        }

        impl<'a, T, I> Encoder<'a, T, I> {
            pub fn new(encode: &'a BTreeMap<T, (usize, u64)>, symbols: I) -> Self {
                Self {
                    encode,
                    symbols,
                    pending_byte: 0,
                    pending_bits: 0,
                    started: false,
                    finished: false,
                }
            }
        }

        impl<'a, T: Ord, I> Iterator for Encoder<'a, T, I>
        where
            I: Iterator<Item = &'a T>,
        {
            type Item = u8;
            fn next(&mut self) -> Option<u8> {
                // Restock `pending_byte` from `self.symbols` until a whole byte is available.
                while self.pending_bits < 8 {
                    if self.finished {
                        return None;
                    }
                    match self.symbols.next() {
                        Some(symbol) => {
                            let &(bits, code) = self
                                .encode
                                .get(symbol)
                                .expect("symbol missing from Huffman code");
                            self.pending_byte = (self.pending_byte << bits) | code;
                            self.pending_bits += bits;
                            self.started = true;
                        }
                        None => {
                            self.finished = true;
                            if !self.started {
                                // An empty input encodes to no bytes at all.
                                return None;
                            }
                            // Ship the remaining (< 8) data bits followed by the terminator:
                            // a single `1` bit, then zeros up to the byte boundary.
                            let byte = ((self.pending_byte << 1) | 1) << (7 - self.pending_bits);
                            self.pending_byte = 0;
                            self.pending_bits = 0;
                            return Some(byte as u8);
                        }
                    }
                }

                let byte = self.pending_byte >> (self.pending_bits - 8);
                self.pending_bits -= 8;
                self.pending_byte &= (1 << self.pending_bits) - 1;
                Some(byte as u8)
            }
        }
    }
}