use crate::frequency::FrequencyTable;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
/// One node of a Huffman tree.
///
/// A `Leaf` holds a single byte `symbol` together with its `weight`. An
/// `Internal` node owns two boxed subtrees and carries a `weight` of its own
/// (this type does not check that weight against the children).
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Node {
    /// Terminal node for one symbol.
    Leaf { symbol: u8, weight: usize },
    /// Branch with a left and a right subtree.
    Internal {
        weight: usize,
        left: Box<Node>,
        right: Box<Node>,
    },
}

impl Node {
    /// Returns the weight stored on this node, whichever variant it is.
    pub fn weight(&self) -> usize {
        let (Node::Leaf { weight, .. } | Node::Internal { weight, .. }) = self;
        *weight
    }

    /// Returns the symbol of a leaf, or `None` for an internal node.
    pub fn symbol(&self) -> Option<u8> {
        if let Node::Leaf { symbol, .. } = self {
            Some(*symbol)
        } else {
            None
        }
    }

    /// Returns the left subtree of an internal node, or `None` for a leaf.
    pub fn left(&self) -> Option<&Node> {
        match self {
            Node::Leaf { .. } => None,
            Node::Internal { left: child, .. } => Some(child.as_ref()),
        }
    }

    /// Returns the right subtree of an internal node, or `None` for a leaf.
    pub fn right(&self) -> Option<&Node> {
        match self {
            Node::Leaf { .. } => None,
            Node::Internal { right: child, .. } => Some(child.as_ref()),
        }
    }
}
#[derive(Debug, Clone, Eq, PartialEq)]
struct HeapNode {
node: Node,
}
impl Ord for HeapNode {
fn cmp(&self, other: &Self) -> Ordering {
other.node.weight().cmp(&self.node.weight())
}
}
impl PartialOrd for HeapNode {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
/// A Huffman code tree over byte symbols.
#[derive(Debug, Clone)]
pub struct HuffmanTree {
    root: Node,
}

impl HuffmanTree {
    /// Builds a Huffman tree from the symbol counts in `ft`.
    ///
    /// Returns `None` when `ft.total()` is zero, since there is nothing to
    /// encode. Also returns `None` if `ft.iter()` yields no entries.
    ///
    /// The two lightest nodes are merged repeatedly until one root remains.
    /// If only one symbol is present, it is paired with a zero-weight padding
    /// leaf so that it still gets a 1-bit code. The padding symbol is `0`, or
    /// `1` when the real symbol is itself `0`, so the two leaves never share a
    /// symbol and [`HuffmanTree::code_lengths`] never reports a symbol twice.
    pub fn from_frequency_table(ft: &FrequencyTable) -> Option<Self> {
        if ft.total() == 0 {
            return None;
        }

        // Pushed one at a time (not collected) so ties are resolved in
        // insertion order exactly as before.
        let mut heap: BinaryHeap<HeapNode> = BinaryHeap::new();
        for (&symbol, &count) in ft.iter() {
            heap.push(HeapNode {
                node: Node::Leaf {
                    symbol,
                    weight: count,
                },
            });
        }

        if heap.len() == 1 {
            let only = heap.pop()?.node;
            // The padding leaf must not reuse the real symbol; otherwise the
            // tree would contain two leaves for the same byte.
            let pad_symbol = if only.symbol() == Some(0) { 1 } else { 0 };
            return Some(HuffmanTree {
                root: Node::Internal {
                    weight: only.weight(),
                    left: Box::new(only),
                    right: Box::new(Node::Leaf {
                        symbol: pad_symbol,
                        weight: 0,
                    }),
                },
            });
        }

        while heap.len() > 1 {
            let a = heap.pop().expect("loop condition guarantees >= 2 nodes").node;
            let b = heap.pop().expect("loop condition guarantees >= 2 nodes").node;
            let weight = a.weight() + b.weight();
            heap.push(HeapNode {
                node: Node::Internal {
                    weight,
                    left: Box::new(a),
                    right: Box::new(b),
                },
            });
        }

        // The heap is empty here only if `ft.iter()` yielded nothing despite a
        // non-zero total. Report that as `None` rather than panicking.
        let root = heap.pop()?.node;
        Some(HuffmanTree { root })
    }

    /// Wraps an existing tree. No validation is performed on `root`.
    pub fn from_root(root: Node) -> Self {
        HuffmanTree { root }
    }

    /// Returns the root node of the tree.
    pub fn root(&self) -> &Node {
        &self.root
    }

    /// Returns `(symbol, code_length)` for every leaf, in left-to-right leaf
    /// order.
    ///
    /// A leaf's code length is its depth below the root. A tree that is a
    /// single leaf reports length 1 for it instead of 0.
    ///
    /// # Panics
    ///
    /// Panics if any node lies more than 255 levels below the root, because
    /// its code length could not be represented in a `u8`.
    pub fn code_lengths(&self) -> Vec<(u8, u8)> {
        let mut lengths = Vec::new();
        Self::walk_lengths(&self.root, 0, &mut lengths);
        lengths
    }

    /// Depth-first walk (left subtree before right) that appends one
    /// `(symbol, length)` pair per leaf.
    fn walk_lengths(node: &Node, depth: u8, lengths: &mut Vec<(u8, u8)>) {
        match node {
            // `max(1)` gives a root-only tree a usable 1-bit length.
            Node::Leaf { symbol, .. } => lengths.push((*symbol, depth.max(1))),
            Node::Internal { left, right, .. } => {
                // Checked so that release builds cannot silently wrap and
                // report a bogus length for an over-deep tree.
                let child_depth = depth
                    .checked_add(1)
                    .expect("Huffman tree deeper than 255 levels: code length overflows u8");
                Self::walk_lengths(left, child_depth, lengths);
                Self::walk_lengths(right, child_depth, lengths);
            }
        }
    }
}