use super::tree::PredictionSuffixTree;
/// Criterion used by [`PredictionSuffixTree::prune`] to decide whether a leaf
/// context adds enough information over its parent to be kept.
#[derive(Debug, Clone)]
pub enum PruningStrategy {
    /// Prune a leaf when the KL divergence of its smoothed next-symbol
    /// distribution from its parent's is below `threshold`.
    KLDivergence {
        /// Divergence (natural-log units) a leaf must reach to survive.
        threshold: f64,
    },
    /// Prune a leaf when its entropy is not sufficiently lower than its
    /// parent's.
    Entropy {
        /// Entropy drop (parent minus leaf, natural-log units) a leaf must
        /// reach to survive.
        min_reduction: f64,
    },
}

impl Default for PruningStrategy {
    /// KL-divergence pruning with a threshold of `0.05`.
    fn default() -> Self {
        const DEFAULT_KL_THRESHOLD: f64 = 0.05;
        PruningStrategy::KLDivergence {
            threshold: DEFAULT_KL_THRESHOLD,
        }
    }
}
impl PredictionSuffixTree {
    /// Unlinks leaf contexts that add too little information over their parent.
    ///
    /// Nodes are scanned from the highest index down to index 1 (index 0 is
    /// never considered). A node is a candidate when it currently has no
    /// children and has a parent. Depending on `strategy` it is pruned when
    ///
    /// * [`PruningStrategy::KLDivergence`]: the KL divergence of its smoothed
    ///   next-symbol distribution from the parent's is below `threshold`;
    /// * [`PruningStrategy::Entropy`]: `parent_entropy - child_entropy` is
    ///   below `min_reduction`.
    ///
    /// Pruning removes the node's entry from its parent's `children` map,
    /// keyed by the last symbol of the node's context. The node itself stays
    /// in `self.nodes`, so indices held elsewhere remain valid.
    ///
    /// Does nothing when the alphabet is empty.
    pub fn prune(&mut self, strategy: &PruningStrategy) {
        let alphabet_size = self.alphabet_size();
        if alphabet_size == 0 {
            return;
        }
        let smoothing = self.smoothing();

        for idx in (1..self.nodes.len()).rev() {
            let node = &self.nodes[idx];
            if !node.children.is_empty() {
                continue;
            }
            let parent_idx = match node.parent {
                Some(p) => p,
                None => continue,
            };
            let parent = &self.nodes[parent_idx];

            let should_prune = match strategy {
                PruningStrategy::KLDivergence { threshold } => {
                    kl_divergence(node, parent, alphabet_size, smoothing) < *threshold
                }
                PruningStrategy::Entropy { min_reduction } => {
                    let child_entropy = conditional_entropy(node, alphabet_size, smoothing);
                    let parent_entropy = conditional_entropy(parent, alphabet_size, smoothing);
                    parent_entropy - child_entropy < *min_reduction
                }
            };
            if !should_prune {
                continue;
            }

            // Unlink right away instead of collecting indices for a second
            // pass: a parent whose children have all been pruned must look
            // like a leaf by the time the reverse scan reaches it, otherwise
            // the pass never cascades upward and callers would have to call
            // `prune` repeatedly to reach a fixed point.
            // NOTE(review): cascading relies on children being stored at
            // higher indices than their parents - confirm against the tree
            // construction code. If that ordering does not hold the pass is
            // still safe; it just prunes current leaves only.
            let sym = node.context.last().copied();
            if let Some(sym) = sym {
                self.nodes[parent_idx].children.remove(&sym);
            }
        }
    }
}
/// Kullback-Leibler divergence `D(child || parent)` between the smoothed
/// next-symbol distributions of two nodes, in natural-log units.
///
/// Symbols with a recorded count in either node are evaluated through
/// `PSTNode::probability`. The remaining `alphabet_size - |union|` symbols are
/// folded in as one closed-form term, because they all share the same
/// probability within a node: `smoothing / (smoothing * alphabet_size + total)`.
/// NOTE(review): that closed form presumes `probability` applies additive
/// smoothing of this shape to zero-count symbols - confirm against `PSTNode`.
///
/// Terms where either probability is not strictly positive are skipped, so
/// the result is always finite.
fn kl_divergence(
    child: &super::tree::PSTNode,
    parent: &super::tree::PSTNode,
    alphabet_size: usize,
    smoothing: f64,
) -> f64 {
    let alpha = smoothing;
    let k = alphabet_size as f64;

    // Union of observed symbols: every child key, then the parent keys the
    // child has not seen. A map lookup replaces the quadratic `Vec::contains`
    // scan while keeping the same visiting order.
    let parent_only = parent
        .counts
        .keys()
        .copied()
        .filter(|sym| !child.counts.contains_key(sym));

    let mut kl = 0.0;
    let mut observed: usize = 0;
    for sym in child.counts.keys().copied().chain(parent_only) {
        observed += 1;
        let p = child.probability(sym, alphabet_size, smoothing);
        let q = parent.probability(sym, alphabet_size, smoothing);
        if p > 0.0 && q > 0.0 {
            kl += p * (p / q).ln();
        }
    }

    // Only symbols absent from BOTH nodes belong in the closed-form term, so
    // subtract the size of the union. Subtracting the larger of the two key
    // counts would re-count symbols already handled in the loop above
    // whenever the key sets differ.
    let unseen_count = alphabet_size.saturating_sub(observed);
    if unseen_count > 0 {
        let p_unseen = alpha / alpha.mul_add(k, child.total as f64);
        let q_unseen = alpha / alpha.mul_add(k, parent.total as f64);
        if p_unseen > 0.0 && q_unseen > 0.0 {
            kl += unseen_count as f64 * p_unseen * (p_unseen / q_unseen).ln();
        }
    }
    kl
}
/// Shannon entropy (natural-log units) of a node's smoothed next-symbol
/// distribution.
///
/// Symbols with a recorded count contribute through `PSTNode::probability`.
/// The `alphabet_size - counts.len()` symbols without a count are added as a
/// single term using `smoothing / (smoothing * alphabet_size + total)` as
/// their shared probability. Non-positive probabilities contribute nothing.
fn conditional_entropy(node: &super::tree::PSTNode, alphabet_size: usize, smoothing: f64) -> f64 {
    // Contribution of the symbols this node has actually counted.
    let seen_part = node
        .counts
        .keys()
        .map(|&sym_id| node.probability(sym_id, alphabet_size, smoothing))
        .filter(|&p| p > 0.0)
        .fold(0.0, |acc, p| acc - p * p.ln());

    let missing = alphabet_size.saturating_sub(node.counts.len());
    if missing == 0 {
        return seen_part;
    }

    // Every uncounted symbol has the same smoothed mass, so add them in bulk.
    let p_unseen = smoothing / smoothing.mul_add(alphabet_size as f64, node.total as f64);
    if p_unseen > 0.0 {
        seen_part - missing as f64 * p_unseen * p_unseen.ln()
    } else {
        seen_part
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tree::PSTConfig;

    /// KL pruning on a tree trained with every two-symbol pair equally often
    /// must not grow the tree.
    #[test]
    fn test_identical_distributions_prune() {
        let config = PSTConfig {
            max_depth: 2,
            smoothing: 0.01,
            ..Default::default()
        };
        let mut tree = PredictionSuffixTree::new(config);
        let a = tree.register_symbol("A");
        let b = tree.register_symbol("B");

        for _ in 0..50 {
            for pair in &[[a, a], [a, b], [b, a], [b, b]] {
                tree.train(pair);
            }
        }

        let count_before = tree.node_count();
        let strategy = PruningStrategy::KLDivergence { threshold: 0.1 };
        tree.prune(&strategy);
        let count_after = tree.node_count();

        assert!(
            count_after <= count_before,
            "Pruning should not increase node count: {count_before} -> {count_after}"
        );
    }

    /// After KL pruning with a tiny threshold, trained depth-1 contexts are
    /// still present in the node list.
    #[test]
    fn test_different_distributions_kept() {
        let config = PSTConfig {
            max_depth: 2,
            smoothing: 0.001,
            ..Default::default()
        };
        let mut tree = PredictionSuffixTree::new(config);
        let a = tree.register_symbol("A");
        let b = tree.register_symbol("B");

        for _ in 0..100 {
            for pair in &[[a, b], [b, a]] {
                tree.train(pair);
            }
        }

        let strategy = PruningStrategy::KLDivergence { threshold: 0.001 };
        tree.prune(&strategy);

        let has_depth1_nodes = tree
            .nodes
            .iter()
            .filter(|node| node.context.len() == 1)
            .any(|node| node.total > 0);
        assert!(
            has_depth1_nodes,
            "Informative depth-1 nodes should survive pruning"
        );
    }

    /// Entropy pruning on a single repeated-symbol sequence must not grow the
    /// tree.
    #[test]
    fn test_entropy_pruning() {
        let config = PSTConfig {
            max_depth: 3,
            smoothing: 0.01,
            ..Default::default()
        };
        let mut tree = PredictionSuffixTree::new(config);
        let a = tree.register_symbol("A");
        let _b = tree.register_symbol("B");

        tree.train(&[a, a, a, a, a, a, a, a]);

        let nodes_before = tree.node_count();
        let strategy = PruningStrategy::Entropy { min_reduction: 0.1 };
        tree.prune(&strategy);
        let nodes_after = tree.node_count();

        assert!(
            nodes_after <= nodes_before,
            "Entropy pruning should remove uninformative nodes"
        );
    }
}