use serde::{Deserialize, Serialize};
use crate::VariableId;
/// Conditional probability table for one discrete variable given its parent variables.
///
/// `table` holds one row of `n_states` probabilities per parent configuration, laid out
/// row-major: the cell for `(state, parent_config)` lives at `parent_config * n_states + state`.
/// `counts` uses the same layout.
///
/// NOTE(review): field order is the serialization order of the derived serde impls — do not
/// reorder fields without considering previously serialized data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Cpt {
    /// Identifier of the variable this table describes.
    pub variable_id: VariableId,
    /// Identifiers of the parent variables.
    /// NOTE(review): presumably in the same order as `parent_sizes`; nothing here enforces it.
    pub parent_ids: Vec<VariableId>,
    /// Number of states of this variable, i.e. the length of each row of `table`.
    pub n_states: usize,
    /// Number of states of each parent. Their product (clamped to at least 1) is the number
    /// of parent configurations, i.e. the number of rows in `table`.
    pub parent_sizes: Vec<usize>,
    /// Probabilities, `n_states` per parent configuration, in the row-major layout above.
    pub table: Vec<f64>,
    /// Observation counts with the same layout as `table`; incremented by `observe` and
    /// converted into probabilities by `update_from_counts`. Private, but still serialized.
    counts: Vec<f64>,
    /// Additive pseudo-count applied to every cell by `update_from_counts`; the constructors
    /// initialize it to `1.0`.
    pub smoothing: f64,
}
impl Cpt {
    /// Creates a table for `variable_id` in which every conditional distribution is uniform.
    ///
    /// `parent_sizes[i]` is the number of states of the `i`-th parent. The table gets one row
    /// of `n_states` probabilities per parent configuration. All counts start at zero and the
    /// smoothing pseudo-count defaults to `1.0`.
    pub fn new(
        variable_id: VariableId,
        n_states: usize,
        parent_ids: Vec<VariableId>,
        parent_sizes: Vec<usize>,
    ) -> Self {
        let total_cells = n_states * Self::configs_for(&parent_sizes);
        let uniform = 1.0 / n_states as f64;
        Self {
            variable_id,
            parent_ids,
            n_states,
            parent_sizes,
            table: vec![uniform; total_cells],
            counts: vec![0.0; total_cells],
            smoothing: 1.0,
        }
    }

    /// Creates a parentless (prior) table with a uniform distribution over `n_states` states.
    pub fn prior(variable_id: VariableId, n_states: usize) -> Self {
        Self::new(variable_id, n_states, vec![], vec![])
    }

    /// Creates a table from explicit probabilities.
    ///
    /// `probs` is row-major: the entry for `(state, parent_config)` is at
    /// `parent_config * n_states + state`. Values are stored as given; they are neither
    /// checked nor normalized. Counts start at zero and smoothing defaults to `1.0`.
    ///
    /// # Panics
    ///
    /// Panics if `probs.len()` is not `n_states * n_configs`, where `n_configs` is the
    /// product of `parent_sizes` (clamped to at least 1).
    pub fn from_probs(
        variable_id: VariableId,
        n_states: usize,
        parent_ids: Vec<VariableId>,
        parent_sizes: Vec<usize>,
        probs: Vec<f64>,
    ) -> Self {
        let total_cells = n_states * Self::configs_for(&parent_sizes);
        assert_eq!(
            probs.len(),
            total_cells,
            "probs length mismatch: expected {}, got {}",
            total_cells,
            probs.len()
        );
        Self {
            variable_id,
            // Both vectors are already owned, so they are moved in rather than cloned.
            parent_ids,
            n_states,
            parent_sizes,
            table: probs,
            counts: vec![0.0; total_cells],
            smoothing: 1.0,
        }
    }

    /// Returns `P(state | parent_config)`.
    ///
    /// Returns `0.0` when `state >= n_states` or `parent_config` is out of range. An invalid
    /// state must not be allowed to read a cell belonging to a different configuration.
    pub fn probability(&self, state: usize, parent_config: usize) -> f64 {
        self.cell_index(state, parent_config)
            .and_then(|idx| self.table.get(idx))
            .copied()
            .unwrap_or(0.0)
    }

    /// Returns the probabilities of every state for `parent_config` (one entry per state).
    ///
    /// An out-of-range `parent_config` yields an empty slice instead of panicking.
    pub fn distribution(&self, parent_config: usize) -> &[f64] {
        let len = self.table.len();
        // Clamp both ends: an unclamped `start` past the end would make the slice panic.
        let start = parent_config.saturating_mul(self.n_states).min(len);
        let end = start.saturating_add(self.n_states).min(len);
        &self.table[start..end]
    }

    /// Maps one state per parent to a flat parent-configuration index.
    ///
    /// The encoding is mixed-radix with the last parent varying fastest. For
    /// `parent_sizes == [3, 4]`, the states `[a, b]` map to `a * 4 + b`. An empty slice maps
    /// to configuration `0`. Individual states are not range-checked.
    ///
    /// # Panics
    ///
    /// Panics if `parent_states` has more entries than `parent_sizes`.
    pub fn parent_config_index(&self, parent_states: &[usize]) -> usize {
        let mut idx = 0;
        let mut stride = 1;
        for (i, &state) in parent_states.iter().enumerate().rev() {
            idx += state * stride;
            stride *= self.parent_sizes[i];
        }
        idx
    }

    /// Records one observation of `state` under the given parent states.
    ///
    /// Observations whose `state >= n_states`, or whose configuration is out of range, are
    /// ignored rather than counted against a neighbouring configuration. Call
    /// [`Cpt::update_from_counts`] to fold the counts into `table`.
    pub fn observe(&mut self, state: usize, parent_states: &[usize]) {
        let config = self.parent_config_index(parent_states);
        let idx = self.cell_index(state, config);
        if let Some(count) = idx.and_then(|i| self.counts.get_mut(i)) {
            *count += 1.0;
        }
    }

    /// Recomputes `table` from the accumulated counts using additive smoothing.
    ///
    /// Each cell becomes
    /// `P(s | c) = (count[c][s] + smoothing) / (sum_s count[c][s] + smoothing * n_states)`.
    ///
    /// A configuration whose denominator is not positive is set to the uniform distribution
    /// instead of producing `NaN`. This happens, for example, with `smoothing == 0.0` and no
    /// observations.
    pub fn update_from_counts(&mut self) {
        let n = self.n_states;
        if n == 0 {
            // No cells to update; also avoids `chunks_exact_mut(0)`, which panics.
            return;
        }
        let smoothing = self.smoothing;
        let pseudo_total = smoothing * n as f64;
        let uniform = 1.0 / n as f64;
        let n_configs = self.n_configs();
        let rows = self
            .table
            .chunks_exact_mut(n)
            .zip(self.counts.chunks_exact(n))
            .take(n_configs);
        for (row, counts) in rows {
            let total = counts.iter().sum::<f64>() + pseudo_total;
            if total > 0.0 {
                for (p, &c) in row.iter_mut().zip(counts) {
                    *p = (c + smoothing) / total;
                }
            } else {
                row.fill(uniform);
            }
        }
    }

    /// Number of parent configurations: the product of `parent_sizes`, clamped to at least 1.
    pub fn n_configs(&self) -> usize {
        Self::configs_for(&self.parent_sizes)
    }

    /// Number of cells stored in `table`.
    pub fn n_cells(&self) -> usize {
        self.table.len()
    }

    /// Clears all observation counts; `table` is left unchanged.
    pub fn reset_counts(&mut self) {
        self.counts.fill(0.0);
    }

    /// Shared config-count rule: the product of the parent cardinalities, clamped to at
    /// least 1. An empty product is 1, and a zero-size parent is also clamped to 1.
    fn configs_for(parent_sizes: &[usize]) -> usize {
        parent_sizes.iter().product::<usize>().max(1)
    }

    /// Flat index of `(state, parent_config)` in `table`/`counts`.
    ///
    /// Returns `None` if `state` is not a valid state or the index overflows `usize`. The
    /// caller still bounds-checks the result against the vector it indexes.
    fn cell_index(&self, state: usize, parent_config: usize) -> Option<usize> {
        if state >= self.n_states {
            return None;
        }
        parent_config.checked_mul(self.n_states)?.checked_add(state)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` is within `tol` of `expected` (absolute tolerance).
    fn close(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    #[test]
    fn prior_uniform() {
        let cpt = Cpt::prior(0, 3);
        let dist = cpt.distribution(0);
        assert_eq!(dist.len(), 3);
        assert!(dist.iter().all(|&p| close(p, 1.0 / 3.0, 1e-6)));
    }

    #[test]
    fn conditional_with_one_parent() {
        // One row per parent state: P(child | parent = 0), (= 1), (= 2).
        let probs = vec![0.8, 0.2, 0.5, 0.5, 0.3, 0.7];
        let cpt = Cpt::from_probs(0, 2, vec![1], vec![3], probs);
        assert_eq!(cpt.n_configs(), 3);
        assert!(close(cpt.probability(0, 0), 0.8, 1e-6));
        assert!(close(cpt.probability(0, 2), 0.3, 1e-6));
    }

    #[test]
    fn parent_config_index() {
        let cpt = Cpt::new(0, 2, vec![1, 2], vec![3, 4]);
        assert_eq!(cpt.n_configs(), 12);
        // Mixed radix with the last parent varying fastest: index = a * 4 + b.
        let cases = [([0, 0], 0), ([0, 1], 1), ([1, 0], 4), ([2, 3], 11)];
        for (states, expected) in cases {
            assert_eq!(cpt.parent_config_index(&states), expected);
        }
    }

    #[test]
    fn learn_from_observations() {
        let mut cpt = Cpt::prior(0, 2);
        (0..8).for_each(|_| cpt.observe(0, &[]));
        (0..2).for_each(|_| cpt.observe(1, &[]));
        cpt.update_from_counts();
        // With smoothing 1.0: (8 + 1) / (10 + 2) = 0.75.
        let p = cpt.probability(0, 0);
        assert!((p - 0.75).abs() < 0.01, "P(success) = {p}");
    }

    #[test]
    fn learn_conditional() {
        let mut cpt = Cpt::new(0, 2, vec![1], vec![2]);
        let mut record = |state: usize, parent: usize, times: usize| {
            for _ in 0..times {
                cpt.observe(state, &[parent]);
            }
        };
        record(0, 0, 9);
        record(1, 0, 1);
        record(0, 1, 3);
        record(1, 1, 7);
        cpt.update_from_counts();

        // (9 + 1) / (10 + 2) ~= 0.833
        let p_nav = cpt.probability(0, 0);
        assert!((p_nav - 0.833).abs() < 0.01, "P(s|nav) = {p_nav}");
        // (3 + 1) / (10 + 2) ~= 0.333
        let p_take = cpt.probability(0, 1);
        assert!((p_take - 0.333).abs() < 0.01, "P(s|take) = {p_take}");
    }

    #[test]
    fn serialization() {
        let mut cpt = Cpt::prior(0, 3);
        for _ in 0..2 {
            cpt.observe(0, &[]);
        }
        cpt.update_from_counts();

        let encoded = postcard::to_allocvec(&cpt).unwrap();
        let restored: Cpt = postcard::from_bytes(&encoded).unwrap();
        assert_eq!(restored.n_states, 3);
        assert!(close(restored.probability(0, 0), cpt.probability(0, 0), 1e-6));
    }
}