use crate::alloc::collections::BTreeMap;
use crate::alloc::vec::Vec;
/// Per-context symbol statistics.
///
/// One node holds what has been recorded after a single context: a counter
/// for every possible byte value plus the number of recordings.
#[derive(Debug, Clone)]
pub struct ContextNode {
    /// Occurrence counter for each byte value; the byte is the index.
    pub count: [u32; 256],
    /// Number of symbols recorded in this node.
    pub total: u32,
}
impl Default for ContextNode {
fn default() -> Self {
Self {
count: [0u32; 256],
total: 0,
}
}
}
/// Statistics table for one context order.
///
/// Starts out empty (via `Default`) and is filled as contexts are observed.
#[derive(Debug, Default, Clone)]
struct NgramState {
    /// Maps a context (the bytes preceding a symbol) to its statistics.
    nodes: BTreeMap<Vec<u8>, ContextNode>,
}
/// Mixes byte predictions from several fixed-order context models.
pub struct ContextMixer {
    /// Context length, in bytes, used by the model at the same index in `models`.
    orders: Vec<usize>,
    /// One statistics table per entry of `orders`.
    models: Vec<NgramState>,
    /// Node count above which an update triggers pruning.
    max_nodes: usize,
    /// Number of context nodes currently stored across all models.
    current_nodes: usize,
    /// Pruning removes up to `max_nodes / prune_divisor` nodes per pass.
    prune_divisor: usize,
}
impl ContextMixer {
    /// Node limit used until `with_node_limit` overrides it.
    const DEFAULT_MAX_NODES: usize = 8192;
    /// Prune divisor used until `with_prune_divisor` overrides it.
    const DEFAULT_PRUNE_DIVISOR: usize = 20;

    /// Creates a mixer with one empty model per entry of `orders`.
    ///
    /// Each entry is the number of trailing history bytes that model uses as
    /// its context; an order of `0` uses the empty context.
    pub fn new(orders: Vec<usize>) -> Self {
        let models = orders.iter().map(|_| NgramState::default()).collect();
        Self {
            orders,
            models,
            max_nodes: Self::DEFAULT_MAX_NODES,
            current_nodes: 0,
            prune_divisor: Self::DEFAULT_PRUNE_DIVISOR,
        }
    }

    /// Sets the node count above which `update` starts pruning.
    pub fn with_node_limit(mut self, limit: usize) -> Self {
        self.max_nodes = limit;
        self
    }

    /// Sets the divisor that sizes a pruning pass (`limit / divisor` nodes).
    ///
    /// A divisor of `0` is treated as `1` when pruning instead of causing a
    /// division by zero.
    pub fn with_prune_divisor(mut self, divisor: usize) -> Self {
        self.prune_divisor = divisor;
        self
    }

    /// Records that `next` followed `history` in every model whose order fits.
    ///
    /// Models whose order exceeds `history.len()` are skipped. Counters
    /// saturate instead of overflowing. When the number of stored nodes
    /// exceeds the node limit afterwards, a pruning pass runs.
    pub fn update(&mut self, history: &[u8], next: u8) {
        let symbol = usize::from(next);
        let h_len = history.len();
        // `models` and `orders` always have the same length (see `new`).
        for (model, &order) in self.models.iter_mut().zip(self.orders.iter()) {
            if h_len < order {
                continue;
            }
            let context = &history[h_len - order..];
            // Look up by slice first so the key is only allocated when a new
            // context is actually inserted.
            if let Some(node) = model.nodes.get_mut(context) {
                node.count[symbol] = node.count[symbol].saturating_add(1);
                node.total = node.total.saturating_add(1);
            } else {
                let mut node = ContextNode::default();
                node.count[symbol] = 1;
                node.total = 1;
                model.nodes.insert(context.to_vec(), node);
                self.current_nodes += 1;
            }
        }
        if self.current_nodes > self.max_nodes {
            self.prune();
        }
    }

    /// Frees nodes after the node limit has been exceeded.
    ///
    /// Soft pass: walking the models in order, removes nodes seen fewer than
    /// twice until `max_nodes / prune_divisor` (at least one) were removed.
    /// Hard pass: if the store is still more than 10% over the limit, whole
    /// models are cleared in order until the count is back within the limit.
    fn prune(&mut self) {
        // Guard against a zero divisor and make sure the soft pass always
        // tries to remove at least one node, whatever the configuration.
        let divisor = self.prune_divisor.max(1);
        let mut budget = (self.max_nodes / divisor).max(1);

        for model in &mut self.models {
            let doomed: Vec<Vec<u8>> = model
                .nodes
                .iter()
                .filter(|(_, node)| node.total < 2)
                .take(budget)
                .map(|(key, _)| key.clone())
                .collect();
            for key in &doomed {
                model.nodes.remove(key);
            }
            budget -= doomed.len();
            self.current_nodes -= doomed.len();
            if budget == 0 {
                break;
            }
        }

        // Same value as `max_nodes * 11 / 10`, computed without overflow.
        let hard_limit = self.max_nodes.saturating_add(self.max_nodes / 10);
        if self.current_nodes > hard_limit {
            for model in &mut self.models {
                self.current_nodes -= model.nodes.len();
                model.nodes.clear();
                if self.current_nodes <= self.max_nodes {
                    break;
                }
            }
        }
    }

    /// Returns a mixed score for every byte value that may follow `history`.
    ///
    /// Every model that has a non-empty node for its context contributes that
    /// node's relative frequencies scaled by 65536 and weighted by
    /// `(order + 1)^2`; the weighted sum is divided by the total weight. All
    /// arithmetic is integer arithmetic, so results are truncated.
    ///
    /// Returns `None` when no model has statistics for the given history.
    #[inline]
    pub fn predict(&self, history: &[u8]) -> Option<[u64; 256]> {
        let mut mixed_probs = [0u64; 256];
        let mut total_weight = 0u64;
        let h_len = history.len();

        for (model, &order) in self.models.iter().zip(self.orders.iter()) {
            if h_len < order {
                continue;
            }
            let context = &history[h_len - order..];
            let node = match model.nodes.get(context) {
                Some(node) if node.total > 0 => node,
                _ => continue,
            };
            // Longer contexts are trusted more: quadratic weight in the order.
            let base = (order + 1) as u64;
            let weight = base * base;
            let multiplier = 65536 * weight;
            let node_total = u64::from(node.total);
            for (p, &count) in mixed_probs.iter_mut().zip(node.count.iter()) {
                *p += u64::from(count) * multiplier / node_total;
            }
            total_weight += weight;
        }

        if total_weight == 0 {
            return None;
        }
        for prob in mixed_probs.iter_mut() {
            *prob /= total_weight;
        }
        Some(mixed_probs)
    }
}