use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::cpt::Cpt;
use crate::variable::{Variable, VariableId};
/// A discrete Bayesian network: a set of variables, the directed edges
/// between them, and one conditional probability table (CPT) per variable.
///
/// NOTE(review): the tests round-trip this type through `postcard`, which is
/// presumably sensitive to field order — confirm before reordering fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BayesianNetwork {
/// Variable definitions, keyed by their id.
variables: HashMap<VariableId, Variable>,
/// Outgoing edges: for each variable, the ids of its children in the order
/// the edges were added.
children: HashMap<VariableId, Vec<VariableId>>,
/// Incoming edges: for each variable, the ids of its parents in the order
/// the edges were added. This order is what `initialize_cpts` hands to
/// `Cpt::new`.
parents: HashMap<VariableId, Vec<VariableId>>,
/// One CPT per variable; populated by `initialize_cpts`.
cpts: HashMap<VariableId, Cpt>,
/// Variable ids in the order they were passed to `add_variable`; gives the
/// methods a deterministic iteration order (a `HashMap` has none).
order: Vec<VariableId>,
}
impl BayesianNetwork {
    /// Creates an empty network with no variables, edges, or CPTs.
    pub fn new() -> Self {
        Self {
            variables: HashMap::new(),
            children: HashMap::new(),
            parents: HashMap::new(),
            cpts: HashMap::new(),
            order: Vec::new(),
        }
    }

    /// Registers `var` in the network and returns its id.
    ///
    /// Adding a variable whose id is already registered replaces the stored
    /// definition but keeps the id's existing edges and its original position
    /// in the iteration order.
    pub fn add_variable(&mut self, var: Variable) -> VariableId {
        let id = var.id;
        // Only a first-time id joins `order`: a duplicate entry would make
        // `posterior` multiply that variable's CPT factor in twice.
        if self.variables.insert(id, var).is_none() {
            self.order.push(id);
        }
        self.children.entry(id).or_default();
        self.parents.entry(id).or_default();
        id
    }

    /// Adds the directed edge `parent -> child`.
    ///
    /// Adding an edge that already exists is a no-op, so a parent is never
    /// listed twice for the same child. Neither the existence of the two
    /// variables nor acyclicity is checked here.
    pub fn add_edge(&mut self, parent: VariableId, child: VariableId) {
        let child_list = self.children.entry(parent).or_default();
        if !child_list.contains(&child) {
            child_list.push(child);
        }
        let parent_list = self.parents.entry(child).or_default();
        if !parent_list.contains(&parent) {
            parent_list.push(parent);
        }
    }

    /// Builds a fresh CPT for every registered variable from the current
    /// graph structure, replacing any CPT stored before.
    ///
    /// Call this after all variables and edges have been added and before
    /// [`observe`](Self::observe) or [`posterior`](Self::posterior).
    ///
    /// # Panics
    ///
    /// Panics if an edge names a parent that was never registered with
    /// [`add_variable`](Self::add_variable).
    pub fn initialize_cpts(&mut self) {
        for &var_id in &self.order {
            // `Cpt::new` takes ownership of the parent list, hence the clone.
            let parent_ids = self.parents.get(&var_id).cloned().unwrap_or_default();
            let parent_sizes: Vec<usize> = parent_ids
                .iter()
                .map(|pid| self.variables[pid].n_states())
                .collect();
            let n_states = self.variables[&var_id].n_states();
            self.cpts
                .insert(var_id, Cpt::new(var_id, n_states, parent_ids, parent_sizes));
        }
    }

    /// Records one training sample given as `(variable, state)` pairs.
    ///
    /// For each observed variable that has a CPT, the sample is forwarded to
    /// `Cpt::observe` together with the observed states of its parents. A
    /// variable whose parents are not all present in `observations` is
    /// skipped for this sample: guessing a state for the missing parent would
    /// credit the count to the wrong parent configuration.
    ///
    /// If a variable appears more than once in `observations`, the last pair
    /// wins.
    pub fn observe(&mut self, observations: &[(VariableId, usize)]) {
        let obs_map: HashMap<VariableId, usize> = observations.iter().copied().collect();
        for var_id in &self.order {
            if let (Some(&state), Some(cpt)) = (obs_map.get(var_id), self.cpts.get_mut(var_id)) {
                // `None` as soon as any parent is unobserved.
                let parent_states: Option<Vec<usize>> = cpt
                    .parent_ids
                    .iter()
                    .map(|pid| obs_map.get(pid).copied())
                    .collect();
                if let Some(parent_states) = parent_states {
                    cpt.observe(state, &parent_states);
                }
            }
        }
    }

    /// Calls `Cpt::update_from_counts` on every CPT (presumably re-estimating
    /// the probabilities from the counts gathered by
    /// [`observe`](Self::observe)).
    pub fn update_cpts(&mut self) {
        for cpt in self.cpts.values_mut() {
            cpt.update_from_counts();
        }
    }

    /// Returns `P(query_var = query_state | evidence)`.
    ///
    /// Returns `0.0` when `query_state` is not a valid index into the
    /// posterior, which includes the case of an unknown `query_var`.
    ///
    /// # Panics
    ///
    /// See [`posterior`](Self::posterior).
    pub fn query(
        &self,
        query_var: VariableId,
        query_state: usize,
        evidence: &[(VariableId, usize)],
    ) -> f64 {
        self.posterior(query_var, evidence)
            .get(query_state)
            .copied()
            .unwrap_or(0.0)
    }

    /// Computes the posterior distribution of `query_var` given `evidence`,
    /// one entry per state of `query_var`.
    ///
    /// Variables that are neither queried nor observed are summed out by
    /// exact enumeration, so the cost grows with the product of their state
    /// counts. Evidence given for `query_var` itself is ignored. If the
    /// unnormalised mass is zero (or not finite) a uniform distribution is
    /// returned. An unknown `query_var` yields an empty vector.
    ///
    /// # Panics
    ///
    /// Panics if a registered variable has no CPT, i.e.
    /// [`initialize_cpts`](Self::initialize_cpts) was not called after the
    /// variable was added.
    pub fn posterior(&self, query_var: VariableId, evidence: &[(VariableId, usize)]) -> Vec<f64> {
        let query_n_states = match self.variables.get(&query_var) {
            Some(var) => var.n_states(),
            None => return Vec::new(),
        };

        // Full assignment of a state to every variable; starts as the
        // evidence, then hidden variables are added at state 0.
        let mut assignment: HashMap<VariableId, usize> = evidence.iter().copied().collect();
        let mut hidden: Vec<(VariableId, usize)> = Vec::new();
        for &var_id in &self.order {
            if var_id != query_var && !assignment.contains_key(&var_id) {
                hidden.push((var_id, self.variables[&var_id].n_states()));
                assignment.insert(var_id, 0);
            }
        }
        // A hidden variable without states admits no complete assignment, so
        // the joint mass is zero for every query state.
        let hidden_enumerable = hidden.iter().all(|&(_, n_states)| n_states > 0);

        let mut joint = vec![0.0f64; query_n_states];
        if hidden_enumerable {
            for (q_state, joint_prob) in joint.iter_mut().enumerate() {
                // Overrides any evidence supplied for the query variable.
                assignment.insert(query_var, q_state);
                let mut mass = 0.0;
                loop {
                    mass += self.joint_probability(&assignment);
                    // When the counter wraps, every hidden state is back at
                    // 0, ready for the next query state.
                    if !Self::advance_assignment(&hidden, &mut assignment) {
                        break;
                    }
                }
                *joint_prob = mass;
            }
        }

        let total: f64 = joint.iter().sum();
        // Any positive finite mass can be normalised; an absolute cutoff
        // would throw away valid posteriors for improbable evidence.
        if total > 0.0 && total.is_finite() {
            for p in &mut joint {
                *p /= total;
            }
        } else {
            let uniform = 1.0 / query_n_states as f64;
            joint.fill(uniform);
        }
        joint
    }

    /// Product of every variable's CPT entry under a complete `assignment`.
    /// A variable missing from `assignment` is read as state 0.
    fn joint_probability(&self, assignment: &HashMap<VariableId, usize>) -> f64 {
        let state_of = |id: &VariableId| assignment.get(id).copied().unwrap_or(0);
        let mut prob = 1.0;
        for var_id in &self.order {
            let cpt = self
                .cpts
                .get(var_id)
                .expect("variable has no CPT; call initialize_cpts() after building the network");
            let parent_states: Vec<usize> = cpt.parent_ids.iter().map(&state_of).collect();
            let config = cpt.parent_config_index(&parent_states);
            prob *= cpt.probability(state_of(var_id), config);
        }
        prob
    }

    /// Steps the hidden variables in `assignment` to the next combination,
    /// treating them as digits of a mixed-radix counter. Returns `false` once
    /// the counter wraps around to all zeros.
    fn advance_assignment(
        hidden: &[(VariableId, usize)],
        assignment: &mut HashMap<VariableId, usize>,
    ) -> bool {
        for &(id, n_states) in hidden {
            let state = assignment.entry(id).or_insert(0);
            *state += 1;
            if *state < n_states {
                return true;
            }
            // Carry into the next hidden variable.
            *state = 0;
        }
        false
    }

    /// Returns the most probable state of `query_var` given `evidence`,
    /// together with its posterior probability.
    ///
    /// Among equally probable states the one with the highest index is
    /// returned. Returns `(0, 0.0)` when the posterior is empty.
    ///
    /// # Panics
    ///
    /// See [`posterior`](Self::posterior).
    pub fn map_estimate(
        &self,
        query_var: VariableId,
        evidence: &[(VariableId, usize)],
    ) -> (usize, f64) {
        self.posterior(query_var, evidence)
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.total_cmp(b.1))
            .map(|(i, &p)| (i, p))
            .unwrap_or((0, 0.0))
    }

    /// Looks up a variable definition by id.
    pub fn variable(&self, id: VariableId) -> Option<&Variable> {
        self.variables.get(&id)
    }

    /// Looks up the CPT of `var_id`; `None` if no CPT has been built for it.
    pub fn cpt(&self, var_id: VariableId) -> Option<&Cpt> {
        self.cpts.get(&var_id)
    }

    /// Number of registered variables.
    pub fn n_variables(&self) -> usize {
        self.variables.len()
    }

    /// Number of directed edges.
    pub fn n_edges(&self) -> usize {
        self.children.values().map(Vec::len).sum()
    }

    /// Calls `Cpt::reset_counts` on every CPT (presumably discarding the
    /// counts gathered by [`observe`](Self::observe)).
    pub fn reset_counts(&mut self) {
        for cpt in self.cpts.values_mut() {
            cpt.reset_counts();
        }
    }

    /// One-line summary: variable count, edge count, and total CPT cells.
    pub fn stats(&self) -> String {
        let total_cells: usize = self.cpts.values().map(|c| c.n_cells()).sum();
        format!(
            "{} variables, {} edges, {} CPT cells",
            self.n_variables(),
            self.n_edges(),
            total_cells
        )
    }
}
impl Default for BayesianNetwork {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Variable;

    /// Builds the two-node network `task -> success` and trains it on 20
    /// samples. The assertions below treat state 0 of `success` as "success".
    fn build_simple_network() -> BayesianNetwork {
        let mut net = BayesianNetwork::new();
        let task = net.add_variable(Variable::new(0, "task", vec!["easy".into(), "hard".into()]));
        let success = net.add_variable(Variable::binary(1, "success"));
        net.add_edge(task, success);
        net.initialize_cpts();

        // (task state, success state, repetitions)
        let samples = [(0, 0, 9), (0, 1, 1), (1, 0, 3), (1, 1, 7)];
        for (task_state, success_state, reps) in samples {
            for _ in 0..reps {
                net.observe(&[(task, task_state), (success, success_state)]);
            }
        }
        net.update_cpts();
        net
    }

    #[test]
    fn network_structure() {
        let net = build_simple_network();
        assert_eq!(net.n_variables(), 2);
        assert_eq!(net.n_edges(), 1);
    }

    #[test]
    fn learned_cpt() {
        let net = build_simple_network();
        let success_cpt = net.cpt(1).unwrap();

        // The expected values match add-one smoothed estimates,
        // (9 + 1) / (10 + 2) and (3 + 1) / (10 + 2) — presumably what `Cpt` does.
        let p_easy = success_cpt.probability(0, 0);
        assert!((p_easy - 0.833).abs() < 0.01, "P(success|easy) = {p_easy}");

        let p_hard = success_cpt.probability(0, 1);
        assert!((p_hard - 0.333).abs() < 0.01, "P(success|hard) = {p_hard}");
    }

    #[test]
    fn query_conditional() {
        let net = build_simple_network();

        let p_easy = net.query(1, 0, &[(0, 0)]);
        assert!(p_easy > 0.7, "P(success|easy) = {p_easy}");

        let p_hard = net.query(1, 0, &[(0, 1)]);
        assert!(p_hard < 0.5, "P(success|hard) = {p_hard}");
    }

    #[test]
    fn map_estimate() {
        let net = build_simple_network();

        let (easy_state, easy_prob) = net.map_estimate(1, &[(0, 0)]);
        assert_eq!(easy_state, 0, "MAP for easy should be success");
        assert!(easy_prob > 0.7);

        let (hard_state, _) = net.map_estimate(1, &[(0, 1)]);
        assert_eq!(hard_state, 1, "MAP for hard should be failure");
    }

    #[test]
    fn three_variable_network() {
        let mut net = BayesianNetwork::new();
        let task = net.add_variable(Variable::new(
            0,
            "task",
            vec!["pick".into(), "clean".into()],
        ));
        let region = net.add_variable(Variable::new(
            1,
            "region",
            vec!["kitchen".into(), "bathroom".into()],
        ));
        let success = net.add_variable(Variable::binary(2, "success"));
        net.add_edge(task, success);
        net.add_edge(region, success);
        net.initialize_cpts();

        // (task state, region state, success state, repetitions)
        let samples = [(0, 0, 0, 8), (1, 1, 0, 8), (0, 1, 0, 2), (0, 1, 1, 6)];
        for (task_state, region_state, success_state, reps) in samples {
            for _ in 0..reps {
                net.observe(&[
                    (task, task_state),
                    (region, region_state),
                    (success, success_state),
                ]);
            }
        }
        net.update_cpts();

        let p = net.query(2, 0, &[(0, 0), (1, 0)]);
        assert!(p > 0.7, "P(success | pick, kitchen) = {p}");

        let p = net.query(2, 0, &[(0, 0), (1, 1)]);
        assert!(p < 0.5, "P(success | pick, bathroom) = {p}");
    }

    #[test]
    fn posterior_sums_to_one() {
        let net = build_simple_network();
        let sum: f64 = net.posterior(1, &[(0, 0)]).iter().sum();
        assert!((sum - 1.0).abs() < 1e-6, "posterior sum = {sum}");
    }

    #[test]
    fn serialization_roundtrip() {
        let original = build_simple_network();
        let encoded = postcard::to_allocvec(&original).unwrap();
        let restored: BayesianNetwork = postcard::from_bytes(&encoded).unwrap();
        assert_eq!(restored.n_variables(), 2);

        // The decoded network must answer queries like the one it came from.
        let p_orig = original.query(1, 0, &[(0, 0)]);
        let p_restored = restored.query(1, 0, &[(0, 0)]);
        assert!((p_orig - p_restored).abs() < 1e-6);
    }

    #[test]
    fn reset_and_relearn() {
        let mut net = build_simple_network();
        let p_before = net.query(1, 0, &[(0, 0)]);

        // Retrain from scratch with the easy-task outcome ratio inverted.
        net.reset_counts();
        for _ in 0..9 {
            net.observe(&[(0, 0), (1, 1)]);
        }
        net.observe(&[(0, 0), (1, 0)]);
        net.update_cpts();

        let p_after = net.query(1, 0, &[(0, 0)]);
        assert!(
            p_after < p_before,
            "after re-learning, P(success|easy) should be lower"
        );
    }
}