use rustc_hash::FxHashMap;
use smallvec::SmallVec;
/// Numeric identifier assigned to a symbol by `PredictionSuffixTree::register_symbol`.
///
/// Ids are handed out in registration order starting at 0, so the alias width
/// (`u16`) bounds how many distinct symbols a tree can hold.
pub type SymbolId = u16;
/// Tuning parameters for a [`PredictionSuffixTree`].
#[derive(Debug, Clone)]
pub struct PSTConfig {
    /// Longest context (number of preceding symbols) recorded during training.
    pub max_depth: usize,
    /// Additive smoothing constant added to every symbol count when a
    /// probability is computed.
    pub smoothing: f64,
    /// Node budget for the tree.
    ///
    /// NOTE(review): nothing in this file enforces the budget; presumably it is
    /// consumed by pruning code elsewhere — confirm.
    pub max_nodes: usize,
}

impl Default for PSTConfig {
    /// Returns the stock configuration: depth 5, smoothing 0.01, 100 000 nodes.
    fn default() -> Self {
        const DEFAULT_MAX_DEPTH: usize = 5;
        const DEFAULT_SMOOTHING: f64 = 0.01;
        const DEFAULT_MAX_NODES: usize = 100_000;

        PSTConfig {
            max_nodes: DEFAULT_MAX_NODES,
            smoothing: DEFAULT_SMOOTHING,
            max_depth: DEFAULT_MAX_DEPTH,
        }
    }
}
/// One node of the prediction suffix tree: a context plus the statistics of
/// the symbols observed immediately after it.
#[derive(Debug)]
pub struct PSTNode {
/// Symbols on the path from the root to this node, in the order they were
/// walked (empty for the root).
pub context: SmallVec<[SymbolId; 4]>,
/// Number of times each symbol was recorded as following `context`.
pub(crate) counts: FxHashMap<SymbolId, u32>,
/// Number of observations recorded at this node; incremented together with
/// the matching entry in `counts`.
pub(crate) total: u32,
/// Child nodes keyed by the symbol that extends `context`; values are indices
/// into the owning tree's node vector.
pub(crate) children: FxHashMap<SymbolId, usize>,
/// Index of the parent node in the owning tree's node vector, `None` for the root.
pub(crate) parent: Option<usize>,
}
impl PSTNode {
    /// Creates a node for `context` with no observations and no children.
    ///
    /// `parent` is the index of the parent node in the owning tree, or `None`
    /// for the root.
    pub(crate) fn new(context: SmallVec<[SymbolId; 4]>, parent: Option<usize>) -> Self {
        Self {
            context,
            counts: FxHashMap::default(),
            total: 0,
            children: FxHashMap::default(),
            parent,
        }
    }

    /// Returns the additively smoothed probability of `symbol` at this node:
    /// `(count + smoothing) / (total + smoothing * alphabet_size)`.
    ///
    /// When the denominator is zero (no observations and no smoothing mass)
    /// the estimate is undefined; instead of propagating `NaN` this returns the
    /// uniform probability `1 / alphabet_size`, or `0.0` for an empty alphabet.
    pub fn probability(&self, symbol: SymbolId, alphabet_size: usize, smoothing: f64) -> f64 {
        let count = f64::from(self.counts.get(&symbol).copied().unwrap_or(0));
        let k = alphabet_size as f64;
        let denominator = smoothing.mul_add(k, f64::from(self.total));

        if denominator == 0.0 {
            // 0/0 case: uniform is the limit of the smoothed estimate as
            // `smoothing` approaches zero on a node without observations.
            return if alphabet_size == 0 { 0.0 } else { 1.0 / k };
        }
        (count + smoothing) / denominator
    }

    /// Returns the smoothed probability of every symbol that has been observed
    /// at this node.
    ///
    /// Symbols that were never observed here are not included, so the returned
    /// values do not sum to one whenever `smoothing` assigns mass to unseen
    /// symbols. The map is empty for a node without observations.
    pub fn distribution(&self, alphabet_size: usize, smoothing: f64) -> FxHashMap<SymbolId, f64> {
        // The denominator is the same for every symbol; compute it once.
        let denominator = smoothing.mul_add(alphabet_size as f64, f64::from(self.total));

        self.counts
            .iter()
            .map(|(&sym, &count)| (sym, (f64::from(count) + smoothing) / denominator))
            .collect()
    }
}
/// Variable-order context model: a tree of contexts stored in a flat vector,
/// together with the symbol table that maps names to `SymbolId`s.
#[derive(Debug)]
pub struct PredictionSuffixTree {
/// All nodes of the tree; index 0 is the root (the empty context) and nodes
/// refer to each other by index into this vector.
pub(crate) nodes: Vec<PSTNode>,
/// Lookup from symbol name to its id.
symbol_to_id: FxHashMap<String, SymbolId>,
/// Symbol names indexed by id; its length is the alphabet size.
id_to_symbol: Vec<String>,
/// Settings supplied at construction time.
config: PSTConfig,
}
impl PredictionSuffixTree {
pub fn new(config: PSTConfig) -> Self {
let root = PSTNode::new(SmallVec::new(), None);
Self {
nodes: vec![root],
symbol_to_id: FxHashMap::default(),
id_to_symbol: Vec::new(),
config,
}
}
pub fn register_symbol(&mut self, name: &str) -> SymbolId {
if let Some(&id) = self.symbol_to_id.get(name) {
return id;
}
let id = self.id_to_symbol.len() as SymbolId;
self.id_to_symbol.push(name.to_string());
self.symbol_to_id.insert(name.to_string(), id);
id
}
pub fn symbol_id(&self, name: &str) -> Option<SymbolId> {
self.symbol_to_id.get(name).copied()
}
pub const fn alphabet_size(&self) -> usize {
self.id_to_symbol.len()
}
pub fn train(&mut self, sequence: &[SymbolId]) {
let max_depth = self.config.max_depth;
for i in 0..sequence.len() {
let next_symbol = sequence[i];
let max_ctx_len = max_depth.min(i);
for ctx_len in 0..=max_ctx_len {
let ctx_start = i - ctx_len;
let context = &sequence[ctx_start..i];
self.update_node(context, next_symbol);
}
}
}
fn update_node(&mut self, context: &[SymbolId], next_symbol: SymbolId) {
let mut current = 0;
for &sym in context {
let next = if let Some(&child_idx) = self.nodes[current].children.get(&sym) {
child_idx
} else {
let mut child_ctx: SmallVec<[SymbolId; 4]> = self.nodes[current].context.clone();
child_ctx.push(sym);
let child_idx = self.nodes.len();
let child = PSTNode::new(child_ctx, Some(current));
self.nodes.push(child);
self.nodes[current].children.insert(sym, child_idx);
child_idx
};
current = next;
}
*self.nodes[current].counts.entry(next_symbol).or_insert(0) += 1;
self.nodes[current].total += 1;
}
pub fn predict(&self, context: &[SymbolId]) -> FxHashMap<SymbolId, f64> {
let alphabet_size = self.alphabet_size();
if alphabet_size == 0 {
return FxHashMap::default();
}
let node_idx = self.find_longest_context(context);
self.nodes[node_idx].distribution(alphabet_size, self.config.smoothing)
}
pub fn predict_symbol(&self, context: &[SymbolId], symbol: SymbolId) -> f64 {
let alphabet_size = self.alphabet_size();
if alphabet_size == 0 {
return 0.0;
}
let node_idx = self.find_longest_context(context);
self.nodes[node_idx].probability(symbol, alphabet_size, self.config.smoothing)
}
fn find_longest_context(&self, context: &[SymbolId]) -> usize {
let mut current = 0; for &sym in context {
if let Some(&child_idx) = self.nodes[current].children.get(&sym) {
current = child_idx;
} else {
break;
}
}
current
}
pub const fn node_count(&self) -> usize {
self.nodes.len()
}
pub const fn config(&self) -> &PSTConfig {
&self.config
}
pub const fn max_depth(&self) -> usize {
self.config.max_depth
}
pub const fn smoothing(&self) -> f64 {
self.config.smoothing
}
pub const fn max_nodes(&self) -> usize {
self.config.max_nodes
}
pub fn compact(&mut self) {
let mut reachable = vec![false; self.nodes.len()];
let mut queue = std::collections::VecDeque::new();
queue.push_back(0usize);
reachable[0] = true;
while let Some(idx) = queue.pop_front() {
for &child_idx in self.nodes[idx].children.values() {
if child_idx < self.nodes.len() && !reachable[child_idx] {
reachable[child_idx] = true;
queue.push_back(child_idx);
}
}
}
let mut old_to_new = vec![usize::MAX; self.nodes.len()];
let mut new_idx = 0usize;
for (old_idx, &is_reachable) in reachable.iter().enumerate() {
if is_reachable {
old_to_new[old_idx] = new_idx;
new_idx += 1;
}
}
if new_idx == self.nodes.len() {
return;
}
let old_nodes = std::mem::take(&mut self.nodes);
self.nodes = Vec::with_capacity(new_idx);
for (old_idx, node) in old_nodes.into_iter().enumerate() {
if !reachable[old_idx] {
continue;
}
let mut new_node = node;
new_node.parent = new_node.parent.and_then(|p| {
let mapped = old_to_new[p];
if mapped == usize::MAX {
None
} else {
Some(mapped)
}
});
let old_children = std::mem::take(&mut new_node.children);
for (sym, child_old) in old_children {
let mapped = old_to_new[child_old];
if mapped != usize::MAX {
new_node.children.insert(sym, mapped);
}
}
self.nodes.push(new_node);
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an empty tree with the given depth and smoothing; every other
    /// setting keeps its default.
    fn tree_with(max_depth: usize, smoothing: f64) -> PredictionSuffixTree {
        PredictionSuffixTree::new(PSTConfig {
            max_depth,
            smoothing,
            ..Default::default()
        })
    }

    /// An untrained tree must not prefer one registered symbol over another.
    #[test]
    fn test_empty_prediction_uniform() {
        let mut pst = PredictionSuffixTree::new(PSTConfig::default());
        let a = pst.register_symbol("A");
        let b = pst.register_symbol("B");

        let dist = pst.predict(&[]);
        assert!(dist.is_empty() || dist.values().all(|&p| p > 0.0));

        let gap = pst.predict_symbol(&[], a) - pst.predict_symbol(&[], b);
        assert!(gap.abs() < 1e-10);
    }

    /// After a strictly alternating sequence each symbol predicts the other.
    #[test]
    fn test_simple_sequence_learning() {
        let mut pst = tree_with(3, 0.01);
        let a = pst.register_symbol("A");
        let b = pst.register_symbol("B");

        // Five repetitions of the pair "A B".
        let sequence = [a, b].repeat(5);
        pst.train(&sequence);

        let p_b_given_a = pst.predict_symbol(&[a], b);
        let p_a_given_a = pst.predict_symbol(&[a], a);
        assert!(
            p_b_given_a > p_a_given_a,
            "P(B|A) = {p_b_given_a} should be > P(A|A) = {p_a_given_a}"
        );

        let p_a_given_b = pst.predict_symbol(&[b], a);
        let p_b_given_b = pst.predict_symbol(&[b], b);
        assert!(
            p_a_given_b > p_b_given_b,
            "P(A|B) = {p_a_given_b} should be > P(B|B) = {p_b_given_b}"
        );
    }

    /// Contexts of length two and three both identify the successor in a
    /// cyclic three-symbol sequence.
    #[test]
    fn test_variable_order_context() {
        let mut pst = tree_with(3, 0.01);
        let a = pst.register_symbol("A");
        let b = pst.register_symbol("B");
        let c = pst.register_symbol("C");

        // Three repetitions of "A B C".
        let seq = [a, b, c].repeat(3);
        pst.train(&seq);

        let p_c = pst.predict_symbol(&[a, b], c);
        assert!(p_c > 0.5, "P(C|A,B) = {p_c} should be > 0.5");

        let p_c_long = pst.predict_symbol(&[c, a, b], c);
        assert!(p_c_long > 0.5, "P(C|C,A,B) = {p_c_long} should be > 0.5");
    }

    /// Symbols never seen after a context still get non-zero probability.
    #[test]
    fn test_smoothing() {
        let mut pst = tree_with(2, 0.1);
        let a = pst.register_symbol("A");
        let b = pst.register_symbol("B");
        let c = pst.register_symbol("C");

        pst.train(&[a; 3]);

        let p_b = pst.predict_symbol(&[a], b);
        let p_c = pst.predict_symbol(&[a], c);
        assert!(p_b > 0.0, "Smoothed P(B|A) should be > 0");
        assert!(p_c > 0.0, "Smoothed P(C|A) should be > 0");
    }

    /// Registering the same name twice yields one id and one alphabet entry.
    #[test]
    fn test_register_symbol_idempotent() {
        let mut pst = PredictionSuffixTree::new(PSTConfig::default());
        let first = pst.register_symbol("X");
        let second = pst.register_symbol("X");

        assert_eq!(first, second);
        assert_eq!(pst.alphabet_size(), 1);
    }
}