use std::cmp::{Ordering, Reverse};
use std::collections::{BinaryHeap, HashMap};
use buffertk::{Packable, Unpackable, stack_pack};
use crate::Error;
/// A code over `u32` symbols: each known symbol maps to a `(code, bit count)`
/// pair, and a code maps back to its symbol.
pub trait Encoder {
/// Builds an encoder from the symbols that occur in `text`.
fn construct(text: &[u32]) -> Self;
/// Returns the code for symbol `t` together with its length in bits, or
/// `None` when the encoder has no code for `t`.
fn encode(&self, t: u32) -> Option<(u32, u8)>;
/// Returns the symbol whose code is `e`, or `None` when no symbol has that code.
///
/// `s` is presumably the bit count that `encode` returned alongside `e`;
/// both implementations in this file ignore it.
fn decode(&self, e: u32, s: u8) -> Option<u32>;
/// Returns the number of distinct symbols the encoder has a code for.
fn symbols(&self) -> usize;
}
/// Encoder in which every symbol gets a code of the same bit width: the
/// symbol's index in a sorted table of the distinct symbols.
#[derive(Clone, Debug, Default, prototk_derive::Message)]
pub struct FixedWidthEncoder {
// Distinct symbols in ascending order; a symbol's index in this table is its code.
#[prototk(1, uint32)]
chars: Vec<u32>,
}
impl Encoder for FixedWidthEncoder {
    /// Collects the distinct symbols of `text` into a sorted lookup table.
    ///
    /// # Panics
    ///
    /// Panics if there are more distinct symbols than fit in a `u32` index.
    fn construct(text: &[u32]) -> Self {
        let mut chars: Vec<u32> = text.to_owned();
        // Equal `u32`s are indistinguishable, so an unstable sort gives the same table.
        chars.sort_unstable();
        chars.dedup();
        chars.shrink_to_fit();
        assert!(chars.len() <= u32::MAX as usize);
        Self { chars }
    }

    /// Returns the table index of `t` and the shared code width, or `None`
    /// when `t` is not in the table.
    fn encode(&self, t: u32) -> Option<(u32, u8)> {
        let index = self.chars.binary_search(&t).ok()?;
        let position = u32::try_from(index).ok()?;
        // Width is log2 of the table size rounded up to a power of two; a
        // table of fewer than two symbols still uses one bit.
        let slots = self.chars.len().max(2).next_power_of_two();
        let bits = u8::try_from(slots.trailing_zeros()).ok()?;
        Some((position, bits))
    }

    /// Returns the symbol stored at index `v`, or `None` when `v` is out of
    /// range. The width argument is ignored.
    fn decode(&self, v: u32, _: u8) -> Option<u32> {
        usize::try_from(v)
            .ok()
            .and_then(|index| self.chars.get(index).copied())
    }

    /// Returns the number of distinct symbols in the table.
    fn symbols(&self) -> usize {
        self.chars.len()
    }
}
/// One node of the tree that `HuffmanEncoder::construct` builds by repeatedly
/// merging the two lowest-weight nodes.
struct Node {
/// Weight: the occurrence count for a leaf, the sum of the children's
/// weights for a merged node.
prob: f64,
/// The symbol for a leaf; `None` for a merged node.
sym: Option<u32>,
/// First child; `None` for a leaf.
lhs: Option<Box<Node>>,
/// Second child; `None` for a leaf.
rhs: Option<Box<Node>>,
}
impl Node {
fn append_symbols(&self, depth: u8, symbols: &mut Vec<(u8, u32)>) -> bool {
if depth == u8::MAX {
return false;
}
if let Some(sym) = self.sym.as_ref() {
symbols.push((depth, *sym));
}
if let Some(lhs) = self.lhs.as_ref()
&& !lhs.append_symbols(depth + 1, symbols)
{
return false;
}
if let Some(rhs) = self.rhs.as_ref()
&& !rhs.append_symbols(depth + 1, symbols)
{
return false;
}
true
}
}
// Nodes are ordered by weight alone. `f64::total_cmp` is a total order (it
// also ranks NaNs and distinguishes -0.0 from +0.0), which is what makes the
// `Eq` and `Ord` impls below legitimate for a float field.
impl Eq for Node {}

impl PartialEq for Node {
    /// Two nodes are equal when their weights compare equal under the total order.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Ord for Node {
    /// Compares nodes by weight only; symbols and children are not consulted.
    fn cmp(&self, other: &Self) -> Ordering {
        f64::total_cmp(&self.prob, &other.prob)
    }
}

impl PartialOrd for Node {
    /// Always `Some`, delegating to [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// Serialised form of one symbol's code: the symbol, its code value, and the
/// code's length.
#[derive(Clone, Debug, Default, prototk_derive::Message)]
pub struct CodeBookEntry {
// The symbol this entry describes.
#[prototk(1, uint32)]
symbol: u32,
// The code value assigned to `symbol`.
#[prototk(2, uint32)]
code: u32,
// Length of `code` in bits; widened from the `u8` the encoder keeps in memory.
#[prototk(3, uint32)]
len: u32,
}
/// Serialised form of a `HuffmanEncoder`: one entry per symbol.
#[derive(Clone, Debug, Default, prototk_derive::Message)]
pub struct CodeBook {
// One entry for every symbol the encoder has a code for.
#[prototk(1, message)]
code_book: Vec<CodeBookEntry>,
}
/// Variable-length encoder whose code lengths come from a Huffman tree built
/// over symbol frequencies.
#[derive(Clone, Debug, Default)]
pub struct HuffmanEncoder {
/// Symbol -> (code value, code length in bits).
encode: HashMap<u32, (u32, u8)>,
/// Code value -> symbol; the inverse of `encode` without the length.
decode: HashMap<u32, u32>,
}
impl HuffmanEncoder {
fn code_book(&self) -> CodeBook {
let mut code_book = CodeBook::default();
for (symbol, (code, len)) in self.encode.iter() {
code_book.code_book.push(CodeBookEntry {
symbol: *symbol,
code: *code,
len: *len as u32,
});
}
code_book
}
}
impl Encoder for HuffmanEncoder {
fn construct(text: &[u32]) -> Self {
if text.is_empty() {
return Self {
encode: HashMap::new(),
decode: HashMap::new(),
};
}
let mut probabilities = HashMap::new();
for t in text.iter() {
*probabilities.entry(*t).or_insert(0) += 1;
}
let mut heap = BinaryHeap::new();
for (sym, prob) in probabilities.into_iter() {
heap.push(Reverse(Node {
prob: prob as f64,
sym: Some(sym),
lhs: None,
rhs: None,
}));
}
while heap.len() >= 2 {
let lhs = heap.pop().unwrap().0;
let rhs = heap.pop().unwrap().0;
heap.push(Reverse(Node {
prob: lhs.prob + rhs.prob,
sym: None,
lhs: Some(Box::new(lhs)),
rhs: Some(Box::new(rhs)),
}));
}
assert_eq!(1, heap.len());
let tree = heap.pop().unwrap().0;
let mut symbols = vec![];
if let Some(sym) = tree.sym.as_ref() {
symbols.push((1u8, *sym));
} else {
tree.append_symbols(0, &mut symbols);
}
symbols.sort();
let mut encode = HashMap::new();
let mut decode = HashMap::new();
let mut code = 0u32;
let mut prev_len = 1u8;
for (len, sym) in symbols.into_iter() {
code <<= len - prev_len;
let flipped = code.reverse_bits() >> (32 - len);
encode.insert(sym, (flipped, len));
decode.insert(flipped, sym);
code += 1;
prev_len = len;
}
Self { encode, decode }
}
fn encode(&self, t: u32) -> Option<(u32, u8)> {
self.encode.get(&t).copied()
}
fn decode(&self, v: u32, _: u8) -> Option<u32> {
self.decode.get(&v).copied()
}
fn symbols(&self) -> usize {
self.encode.len()
}
}
impl Packable for HuffmanEncoder {
    /// Returns the packed size of the encoder's code book.
    fn pack_sz(&self) -> usize {
        stack_pack(self.code_book()).pack_sz()
    }

    /// Packs the encoder's code book into `buf`.
    fn pack(&self, buf: &mut [u8]) {
        stack_pack(self.code_book()).into_slice(buf);
    }
}
impl<'a> Unpackable<'a> for HuffmanEncoder {
type Error = Error;
fn unpack<'b: 'a>(buf: &'b [u8]) -> Result<(Self, &'b [u8]), Self::Error> {
let (code_book, buf) = CodeBook::unpack(buf).map_err(|_| Error::InvalidEncoder)?;
let mut encode = HashMap::new();
let mut decode = HashMap::new();
for cbe in code_book.code_book.into_iter() {
if encode.contains_key(&cbe.symbol) {
return Err(Error::InvalidEncoder);
}
if cbe.len > u8::MAX as u32 {
return Err(Error::InvalidEncoder);
}
let len = cbe.len as u8;
if decode.contains_key(&cbe.code) {
return Err(Error::InvalidEncoder);
}
encode.insert(cbe.symbol, (cbe.code, len));
decode.insert(cbe.code, cbe.symbol);
}
let this = HuffmanEncoder { encode, decode };
Ok((this, buf))
}
}
#[cfg(test)]
mod tests {
    use super::{Encoder, FixedWidthEncoder, HuffmanEncoder};

    /// Turns text into the stream of `u32` symbols the encoders operate on.
    fn symbols_of(text: &str) -> Vec<u32> {
        text.chars().map(u32::from).collect()
    }

    /// Asserts that every symbol of `text` survives an encode/decode round trip.
    fn assert_round_trips<E: Encoder>(encoder: &E, text: &[u32]) {
        for &symbol in text {
            let (code, bits) = encoder.encode(symbol).unwrap();
            assert_eq!(Some(symbol), encoder.decode(code, bits));
        }
    }

    /// Asserts that the encoder has no code for any of the characters in `absent`.
    fn assert_unknown<E: Encoder>(encoder: &E, absent: &str) {
        for symbol in symbols_of(absent) {
            assert_eq!(None, encoder.encode(symbol));
        }
    }

    #[test]
    fn fixed_width_empty() {
        let encoder = FixedWidthEncoder::construct(&[]);
        assert_eq!(None, encoder.encode(0u32));
    }

    #[test]
    fn fixed_width_0() {
        let encoder = FixedWidthEncoder::construct(&[0u32]);
        assert_eq!(Some((0, 1)), encoder.encode(0u32));
        assert_eq!(None, encoder.encode(1u32));
    }

    #[test]
    fn fixed_width_0_1() {
        let encoder = FixedWidthEncoder::construct(&[0u32, 1u32]);
        assert_eq!(Some((0, 1)), encoder.encode(0u32));
        assert_eq!(Some((1, 1)), encoder.encode(1u32));
    }

    #[test]
    fn fixed_width_chars() {
        let chars = symbols_of("AaBbCcDdEeFfNnBananaMississippi");
        let encoder = FixedWidthEncoder::construct(&chars);
        // Eighteen distinct symbols round up to 32 slots, i.e. five bits each;
        // upper-case letters sort before lower-case ones.
        let expected = [('A', 0u32), ('B', 1), ('C', 2), ('a', 8), ('b', 9), ('c', 10)];
        for (symbol, position) in expected {
            assert_eq!(Some((position, 5u8)), encoder.encode(u32::from(symbol)));
        }
        assert_round_trips(&encoder, &chars);
        assert_unknown(&encoder, "qz");
    }

    #[test]
    fn huffman_chars() {
        let chars = symbols_of("BananaMississippi");
        let encoder = HuffmanEncoder::construct(&chars);
        // (symbol, code value, code length in bits)
        let expected = [
            ('i', 0u32, 2u8),
            ('s', 2, 2),
            ('a', 1, 3),
            ('n', 5, 3),
            ('p', 3, 3),
            ('B', 7, 4),
            ('M', 15, 4),
        ];
        for (symbol, code, bits) in expected {
            assert_eq!(Some((code, bits)), encoder.encode(u32::from(symbol)));
        }
        assert_round_trips(&encoder, &chars);
        assert_unknown(&encoder, "qz");
    }
}