use super::pruning::PruningStrategy;
use super::tree::{PredictionSuffixTree, SymbolId};
/// Incrementally trains a [`PredictionSuffixTree`] from a stream of symbols.
///
/// The learner keeps a sliding window of the most recently observed symbols.
/// On each update it credits the new symbol to every suffix of that window, up
/// to the tree's maximum depth. It also asks the tree to prune itself on a
/// fixed update schedule and whenever the tree grows past its node budget.
#[derive(Debug)]
pub struct OnlinePSTLearner {
    /// Recently observed symbols, oldest first; holds at most `max_context`
    /// entries between calls to `update`.
    context_buffer: Vec<SymbolId>,
    /// Maximum number of symbols retained in `context_buffer`.
    max_context: usize,
    /// Number of updates between scheduled prunes; `0` disables scheduled pruning.
    prune_interval: u64,
    /// Strategy handed to `PredictionSuffixTree::prune`.
    pruning_strategy: PruningStrategy,
    /// Total number of `update` calls performed so far.
    updates: u64,
}
impl OnlinePSTLearner {
    /// Creates a learner whose history window holds at most `max_context` symbols.
    ///
    /// Scheduled pruning initially runs every 10 000 updates (override it with
    /// [`Self::with_prune_interval`]) using `PruningStrategy::default()`.
    pub fn new(max_context: usize) -> Self {
        // One slot of headroom: `update` pushes the new symbol before evicting the oldest.
        let context_buffer = Vec::with_capacity(max_context + 1);
        Self {
            context_buffer,
            max_context,
            prune_interval: 10_000,
            pruning_strategy: PruningStrategy::default(),
            updates: 0,
        }
    }

    /// Returns the learner with scheduled pruning set to run every `interval` updates.
    ///
    /// An `interval` of `0` turns scheduled pruning off. The node-budget check
    /// performed by [`Self::update`] still applies.
    pub const fn with_prune_interval(mut self, interval: u64) -> Self {
        self.prune_interval = interval;
        self
    }

    /// Feeds one observed `symbol` into `pst`.
    ///
    /// For every context length `k` from `0` up to
    /// `min(history length, pst.max_depth())`:
    /// - the last `k` symbols of the history (oldest first) are followed as a
    ///   path starting at node 0, and any missing nodes are created;
    /// - the node reached gets its count for `symbol` and its `total` incremented.
    ///
    /// Length 0 therefore updates node 0 itself
    /// (NOTE(review): node 0 is assumed to be the root — confirm in `tree`).
    ///
    /// Afterwards:
    /// - `symbol` is appended to the history window, evicting the oldest entry
    ///   if the window exceeds `max_context`;
    /// - the update counter is incremented, and the tree is pruned when the
    ///   schedule is due;
    /// - independently, if the tree reports more than `pst.max_nodes()` nodes,
    ///   it is pruned and then compacted.
    ///
    /// # Panics
    ///
    /// Indexes `pst.nodes[0]` directly, so it panics if the tree has no node at
    /// index 0.
    pub fn update(&mut self, pst: &mut PredictionSuffixTree, symbol: SymbolId) {
        let history = &self.context_buffer;
        let deepest = history.len().min(pst.max_depth());
        for ctx_len in 0..=deepest {
            // `ctx_len <= history.len()`, so this subtraction cannot underflow.
            let suffix = &history[history.len() - ctx_len..];
            let leaf = Self::descend_or_insert(pst, suffix);
            let node = &mut pst.nodes[leaf];
            *node.counts.entry(symbol).or_insert(0) += 1;
            node.total += 1;
        }

        // Slide the history window forward by one symbol.
        self.context_buffer.push(symbol);
        if self.context_buffer.len() > self.max_context {
            self.context_buffer.remove(0);
        }

        self.updates += 1;
        let scheduled =
            self.prune_interval != 0 && self.updates.is_multiple_of(self.prune_interval);
        if scheduled {
            pst.prune(&self.pruning_strategy);
        }

        // Regardless of the schedule, react when the tree exceeds its node budget.
        let over_budget = pst.node_count() > pst.max_nodes();
        if over_budget {
            pst.prune(&self.pruning_strategy);
            pst.compact();
        }
    }

    /// Walks `path` downward from node 0 and returns the index of the node at its end.
    ///
    /// Any missing step gets a new child inserted. A new child's context is its
    /// parent's context with the step's symbol appended, and it records the
    /// parent's index.
    fn descend_or_insert(pst: &mut PredictionSuffixTree, path: &[SymbolId]) -> usize {
        path.iter().fold(0_usize, |parent, &step| {
            if let Some(existing) = pst.nodes[parent].children.get(&step).copied() {
                return existing;
            }
            let mut child_ctx: smallvec::SmallVec<[SymbolId; 4]> =
                pst.nodes[parent].context.clone();
            child_ctx.push(step);
            let fresh = pst.nodes.len();
            pst.nodes
                .push(super::tree::PSTNode::new(child_ctx, Some(parent)));
            pst.nodes[parent].children.insert(step, fresh);
            fresh
        })
    }

    /// Returns the current history window, oldest symbol first.
    pub fn current_context(&self) -> &[SymbolId] {
        self.context_buffer.as_slice()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tree::PSTConfig;

    /// After a long run of alternating `A, B`, the tree should strongly predict
    /// `B` following `A`.
    #[test]
    fn test_online_learning() {
        let mut pst = PredictionSuffixTree::new(PSTConfig {
            max_depth: 3,
            smoothing: 0.01,
            ..Default::default()
        });
        let a = pst.register_symbol("A");
        let b = pst.register_symbol("B");
        let mut learner = OnlinePSTLearner::new(3).with_prune_interval(0);
        for _ in 0..50 {
            learner.update(&mut pst, a);
            learner.update(&mut pst, b);
        }
        let p_b = pst.predict_symbol(&[a], b);
        assert!(p_b > 0.6, "P(B|A) = {p_b} should be > 0.6");
    }

    /// After many updates the window is exactly `max_context` long and holds
    /// the latest symbols. A bare `len() <= 3` check would also accept a
    /// learner that never records any history.
    #[test]
    fn test_context_buffer_bounded() {
        let mut pst = PredictionSuffixTree::new(PSTConfig {
            max_depth: 3,
            smoothing: 0.01,
            ..Default::default()
        });
        let a = pst.register_symbol("A");
        let mut learner = OnlinePSTLearner::new(3).with_prune_interval(0);
        for _ in 0..100 {
            learner.update(&mut pst, a);
        }
        assert!(learner.current_context().len() <= 3);
        assert_eq!(learner.current_context(), [a, a, a].as_slice());
    }

    /// The window keeps the most recent symbols, oldest first, evicting from the front.
    #[test]
    fn test_context_buffer_keeps_most_recent_in_order() {
        let mut pst = PredictionSuffixTree::new(PSTConfig {
            max_depth: 3,
            smoothing: 0.01,
            ..Default::default()
        });
        let a = pst.register_symbol("A");
        let b = pst.register_symbol("B");
        let mut learner = OnlinePSTLearner::new(3).with_prune_interval(0);
        for sym in [a, b, a, b] {
            learner.update(&mut pst, sym);
        }
        assert_eq!(learner.current_context(), [b, a, b].as_slice());
    }

    /// With `max_context == 0` no history is retained at all.
    #[test]
    fn test_zero_max_context_keeps_no_history() {
        let mut pst = PredictionSuffixTree::new(PSTConfig {
            max_depth: 3,
            smoothing: 0.01,
            ..Default::default()
        });
        let a = pst.register_symbol("A");
        let mut learner = OnlinePSTLearner::new(0).with_prune_interval(0);
        learner.update(&mut pst, a);
        learner.update(&mut pst, a);
        assert!(learner.current_context().is_empty());
    }
}