use super::config::{BayesianConfig, BayesianNode, NodeType};
use std::collections::HashMap;
/// A discrete factor (potential table) over a set of named variables.
///
/// `values` is stored in row-major order: the LAST entry of `variables`
/// varies fastest. For variables `[X, Y]` with cardinalities `[2, 3]`, the
/// entry for `(x, y)` lives at index `x * 3 + y`.
#[derive(Debug, Clone)]
pub struct Factor {
    /// Variable names, in storage order.
    pub variables: Vec<String>,
    /// Number of states of each variable, parallel to `variables`.
    pub cardinalities: Vec<usize>,
    /// Flattened table; expected to hold `cardinalities.iter().product()` entries.
    pub values: Vec<f64>,
}

impl Factor {
    /// Builds the factor for node `name` from its prior or conditional table.
    ///
    /// Variables are ordered `[name, parent_1, parent_2, ...]`. Parents that are
    /// missing from `config.nodes` are skipped. The value table always has
    /// exactly `product(cardinalities)` entries.
    ///
    /// * Root nodes (per `node.is_root()`) copy `node.prior`. The copy is
    ///   truncated or zero-padded to the node's state count.
    /// * Other nodes look up `node.cpt` by the state name of the FIRST parent.
    ///   The lookup returns one probability per state of `name`. Cells with no
    ///   matching CPT row or entry stay `0.0`.
    ///
    /// NOTE(review): with several parents, only the first parent's state selects
    /// the CPT row. The factor is therefore constant across the other parents.
    /// Confirm the intended multi-parent CPT key format.
    #[must_use]
    pub fn from_node(name: &str, node: &BayesianNode, config: &BayesianConfig) -> Self {
        let mut variables = vec![name.to_string()];
        let mut cardinalities = vec![node.states.len()];
        for parent in &node.parents {
            if let Some(parent_node) = config.nodes.get(parent) {
                variables.push(parent.clone());
                cardinalities.push(parent_node.states.len());
            }
        }
        let total_size: usize = cardinalities.iter().product();
        let mut values = vec![0.0; total_size];
        if node.is_root() {
            // Copy into the pre-sized table so `values.len()` matches the shape
            // even if the prior length is inconsistent with the state count.
            for (dst, &p) in values.iter_mut().zip(&node.prior) {
                *dst = p;
            }
        } else if let Some(parent_node) = node.parents.first().and_then(|p| config.nodes.get(p)) {
            // The first parent exists in `config`, so it was pushed right after
            // the node itself. Decoded position 0 is the node's state and
            // position 1 is the first parent's state, whatever parents follow.
            for (i, val) in values.iter_mut().enumerate() {
                let indices = Self::decode_index(i, &cardinalities);
                let (state_idx, parent_idx) = (indices[0], indices[1]);
                if let Some(&p) = parent_node
                    .states
                    .get(parent_idx)
                    .and_then(|parent_state| node.cpt.get(parent_state))
                    .and_then(|probs| probs.get(state_idx))
                {
                    *val = p;
                }
            }
        }
        Self {
            variables,
            cardinalities,
            values,
        }
    }

    /// Returns the pointwise product of `self` and `other`.
    ///
    /// The result ranges over `self.variables`, followed by each variable of
    /// `other` that is not already in `self`. Shared variables are matched by
    /// name and are assumed to have the same cardinality in both factors.
    /// Entries missing from either input table are treated as `0.0`.
    #[must_use]
    pub fn multiply(&self, other: &Self) -> Self {
        let mut new_variables = self.variables.clone();
        let mut new_cardinalities = self.cardinalities.clone();
        // Position in the combined factor of each of `other`'s variables.
        let mut other_positions = Vec::with_capacity(other.variables.len());
        for (i, var) in other.variables.iter().enumerate() {
            let pos = match self.variables.iter().position(|v| v == var) {
                Some(pos) => pos,
                None => {
                    new_variables.push(var.clone());
                    new_cardinalities.push(other.cardinalities[i]);
                    new_variables.len() - 1
                }
            };
            other_positions.push(pos);
        }
        let total_size: usize = new_cardinalities.iter().product();
        let self_len = self.variables.len();
        // Scratch buffer reused across cells instead of allocating per cell.
        let mut other_indices = Vec::with_capacity(other_positions.len());
        let mut new_values = vec![0.0; total_size];
        for (i, val) in new_values.iter_mut().enumerate() {
            let indices = Self::decode_index(i, &new_cardinalities);
            // `self`'s variables form the prefix of the combined variable list.
            let self_idx = Self::encode_index(&indices[..self_len], &self.cardinalities);
            other_indices.clear();
            other_indices.extend(other_positions.iter().map(|&pos| indices[pos]));
            let other_idx = Self::encode_index(&other_indices, &other.cardinalities);
            let self_val = self.values.get(self_idx).copied().unwrap_or(0.0);
            let other_val = other.values.get(other_idx).copied().unwrap_or(0.0);
            *val = self_val * other_val;
        }
        Self {
            variables: new_variables,
            cardinalities: new_cardinalities,
            values: new_values,
        }
    }

    /// Sums `var` out of the factor.
    ///
    /// If `var` is not one of the factor's variables, returns an unchanged clone.
    /// Summing out the last remaining variable yields a scalar factor. That
    /// factor has no variables and a single value: the total of all entries.
    #[must_use]
    pub fn marginalize(&self, var: &str) -> Self {
        let Some(var_idx) = self.variables.iter().position(|v| v == var) else {
            return self.clone();
        };
        let mut new_variables = self.variables.clone();
        new_variables.remove(var_idx);
        let mut new_cardinalities = self.cardinalities.clone();
        new_cardinalities.remove(var_idx);
        if new_variables.is_empty() {
            return Self {
                variables: vec![],
                cardinalities: vec![],
                values: vec![self.values.iter().sum()],
            };
        }
        let total_size: usize = new_cardinalities.iter().product();
        let mut new_values = vec![0.0; total_size];
        for (i, &value) in self.values.iter().enumerate() {
            // Drop the summed-out axis, then re-encode against the smaller shape.
            let mut indices = Self::decode_index(i, &self.cardinalities);
            indices.remove(var_idx);
            new_values[Self::encode_index(&indices, &new_cardinalities)] += value;
        }
        Self {
            variables: new_variables,
            cardinalities: new_cardinalities,
            values: new_values,
        }
    }

    /// Scales the values in place so they sum to 1.
    ///
    /// Does nothing when the sum is not strictly positive, for example an
    /// all-zero table.
    pub fn normalize(&mut self) {
        let sum: f64 = self.values.iter().sum();
        if sum > 0.0 {
            for v in &mut self.values {
                *v /= sum;
            }
        }
    }

    /// Converts a flat row-major index into per-variable state indices.
    ///
    /// The last variable varies fastest. Panics on a zero cardinality,
    /// because the decoding divides by each cardinality.
    fn decode_index(mut idx: usize, cardinalities: &[usize]) -> Vec<usize> {
        let mut indices = vec![0; cardinalities.len()];
        for i in (0..cardinalities.len()).rev() {
            indices[i] = idx % cardinalities[i];
            idx /= cardinalities[i];
        }
        indices
    }

    /// Converts per-variable state indices into a flat row-major index.
    ///
    /// This is the inverse of `decode_index`. Missing cardinalities are
    /// treated as 1.
    fn encode_index(indices: &[usize], cardinalities: &[usize]) -> usize {
        let mut idx = 0;
        let mut multiplier = 1;
        for i in (0..indices.len()).rev() {
            idx += indices[i] * multiplier;
            multiplier *= cardinalities.get(i).copied().unwrap_or(1);
        }
        idx
    }

    /// Looks up the table entry for a full assignment of state indices by variable name.
    ///
    /// Variables absent from `assignment` default to state 0. Returns `0.0` when
    /// the encoded index falls outside the table. Individual state indices are
    /// not range-checked.
    #[must_use]
    pub fn get_probability(&self, assignment: &HashMap<String, usize>) -> f64 {
        let indices: Vec<usize> = self
            .variables
            .iter()
            .map(|v| assignment.get(v).copied().unwrap_or(0))
            .collect();
        let idx = Self::encode_index(&indices, &self.cardinalities);
        self.values.get(idx).copied().unwrap_or(0.0)
    }
}
/// Exact inference over the discrete nodes of a `BayesianConfig` by variable elimination.
///
/// One `Factor` is built per discrete node at construction time. Queries work
/// on clones of these factors and never mutate `self`.
pub struct BeliefPropagation {
    config: BayesianConfig,
    factors: Vec<Factor>,
}

impl BeliefPropagation {
    /// Validates `config` and builds one factor per discrete node.
    ///
    /// # Errors
    /// Propagates the error returned by `config.validate()`.
    pub fn new(config: BayesianConfig) -> Result<Self, String> {
        config.validate()?;
        let mut factors = Vec::with_capacity(config.nodes.len());
        for (name, node) in &config.nodes {
            if node.node_type == NodeType::Discrete {
                factors.push(Factor::from_node(name, node, &config));
            }
        }
        Ok(Self { config, factors })
    }

    /// Computes the marginal distribution `P(target)`, indexed by `target`'s state index.
    ///
    /// This is `query_with_evidence` with no evidence.
    ///
    /// # Errors
    /// Returns `Err` if `target` is not in the network or has no discrete factor.
    /// Also returns `Err` if the network has no discrete factors at all.
    pub fn query(&self, target: &str) -> Result<Vec<f64>, String> {
        self.query_with_evidence(target, &HashMap::new())
    }

    /// Computes the posterior `P(target | evidence)`, indexed by `target`'s state index.
    ///
    /// `evidence` maps variable names to observed state indices. The result is
    /// normalised to sum to 1. If the total mass is not positive (for example,
    /// evidence with zero probability), the unnormalised values are returned
    /// as-is.
    ///
    /// # Errors
    /// Returns `Err` in any of these cases:
    /// * `target` is not in the network.
    /// * An evidence variable is not in the network.
    /// * A discrete evidence variable has an out-of-range state index.
    /// * `target` has no discrete factor.
    /// * There are no discrete factors at all.
    pub fn query_with_evidence(
        &self,
        target: &str,
        evidence: &HashMap<String, usize>,
    ) -> Result<Vec<f64>, String> {
        if !self.config.nodes.contains_key(target) {
            return Err(format!("Variable '{target}' not found in network"));
        }
        self.check_evidence(evidence)?;
        let mut factors: Vec<Factor> = self
            .factors
            .iter()
            .map(|f| Self::apply_evidence(f, evidence))
            .collect();
        // Evidence variables are summed out like any other variable. This is
        // valid because `apply_evidence` has already zeroed every entry that is
        // inconsistent with the observation. It also keeps the final product small.
        for var in self.get_elimination_order(target) {
            factors = Self::eliminate(factors, &var);
        }
        let joint = Self::product(factors).ok_or_else(|| "No factors remaining".to_string())?;
        Self::marginal_of(joint, target)
    }

    /// Rejects evidence on unknown variables, and on discrete variables with out-of-range states.
    fn check_evidence(&self, evidence: &HashMap<String, usize>) -> Result<(), String> {
        for (var, &state) in evidence {
            let Some(node) = self.config.nodes.get(var) else {
                return Err(format!("Evidence variable '{var}' not found in network"));
            };
            let n_states = node.states.len();
            if node.node_type == NodeType::Discrete && state >= n_states {
                return Err(format!(
                    "Evidence state {state} out of range for variable '{var}' with {n_states} states"
                ));
            }
        }
        Ok(())
    }

    /// Multiplies every factor that mentions `var`, then sums `var` out of the product.
    ///
    /// Factors that do not mention `var` are passed through unchanged.
    fn eliminate(factors: Vec<Factor>, var: &str) -> Vec<Factor> {
        let (containing, mut remaining): (Vec<Factor>, Vec<Factor>) = factors
            .into_iter()
            .partition(|f| f.variables.iter().any(|v| v == var));
        if let Some(product) = Self::product(containing) {
            remaining.push(product.marginalize(var));
        }
        remaining
    }

    /// Multiplies all `factors` left to right; returns `None` if the list is empty.
    fn product(factors: Vec<Factor>) -> Option<Factor> {
        let mut iter = factors.into_iter();
        let first = iter.next()?;
        Some(iter.fold(first, |acc, f| acc.multiply(&f)))
    }

    /// Sums every variable except `target` out of `joint` and normalises the result.
    fn marginal_of(joint: Factor, target: &str) -> Result<Vec<f64>, String> {
        if !joint.variables.iter().any(|v| v == target) {
            // Happens e.g. for a non-discrete target, which has no factor.
            return Err(format!(
                "Variable '{target}' is not covered by any discrete factor"
            ));
        }
        let others: Vec<String> = joint
            .variables
            .iter()
            .filter(|v| v.as_str() != target)
            .cloned()
            .collect();
        let mut marginal = joint;
        for var in &others {
            marginal = marginal.marginalize(var);
        }
        marginal.normalize();
        Ok(marginal.values)
    }

    /// Returns a copy of `factor` with entries that disagree with `evidence` set to `0.0`.
    ///
    /// Evidence on variables the factor does not mention is ignored.
    fn apply_evidence(factor: &Factor, evidence: &HashMap<String, usize>) -> Factor {
        // (axis position, observed state) for each evidence variable in this factor.
        let observed: Vec<(usize, usize)> = factor
            .variables
            .iter()
            .enumerate()
            .filter_map(|(pos, var)| evidence.get(var).map(|&state| (pos, state)))
            .collect();
        let mut reduced = factor.clone();
        if observed.is_empty() {
            return reduced;
        }
        for (i, val) in reduced.values.iter_mut().enumerate() {
            let indices = Factor::decode_index(i, &factor.cardinalities);
            if observed.iter().any(|&(pos, state)| indices[pos] != state) {
                *val = 0.0;
            }
        }
        reduced
    }

    /// Elimination order: the reverse of `config.topological_order()`, with `exclude` removed.
    ///
    /// Presumably this eliminates children before parents; that depends on
    /// `topological_order` listing parents first (not verified here).
    fn get_elimination_order(&self, exclude: &str) -> Vec<String> {
        let mut order = self.config.topological_order();
        order.reverse();
        order.retain(|v| v != exclude);
        order
    }

    /// Returns the network configuration this engine was built from.
    #[must_use]
    pub const fn config(&self) -> &BayesianConfig {
        &self.config
    }
}
#[cfg(test)]
mod inference_tests {
    use super::*;

    /// Absolute tolerance for probability comparisons. The network is tiny, so
    /// floating-point error stays far below this bound.
    const EPS: f64 = 1e-9;

    /// Two-node network `rain -> sprinkler`.
    ///
    /// The probabilities are:
    /// * P(rain) = [0.8, 0.2]
    /// * P(sprinkler | rain=no) = [0.6, 0.4]
    /// * P(sprinkler | rain=yes) = [0.99, 0.01]
    fn create_simple_network() -> BayesianConfig {
        BayesianConfig::new("Sprinkler")
            .with_node(
                "rain",
                BayesianNode::discrete(vec!["no", "yes"]).with_prior(vec![0.8, 0.2]),
            )
            .with_node(
                "sprinkler",
                BayesianNode::discrete(vec!["off", "on"])
                    .with_parents(vec!["rain"])
                    .with_cpt_entry("no", vec![0.6, 0.4])
                    .with_cpt_entry("yes", vec![0.99, 0.01]),
            )
    }

    /// Asserts that `actual` matches `expected` element-wise, within `EPS`.
    fn assert_probs(actual: &[f64], expected: &[f64]) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: {actual:?} vs {expected:?}"
        );
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (a - e).abs() < EPS,
                "entry {i}: expected {e}, got {a} (full result {actual:?})"
            );
        }
    }

    #[test]
    fn test_prior_query() {
        let bp = BeliefPropagation::new(create_simple_network()).unwrap();
        assert_probs(&bp.query("rain").unwrap(), &[0.8, 0.2]);
    }

    #[test]
    fn test_marginal_query() {
        let bp = BeliefPropagation::new(create_simple_network()).unwrap();
        // P(on) = P(on | no) P(no) + P(on | yes) P(yes)
        let expected_on = 0.4f64.mul_add(0.8, 0.01 * 0.2);
        assert_probs(
            &bp.query("sprinkler").unwrap(),
            &[1.0 - expected_on, expected_on],
        );
    }

    #[test]
    fn test_evidence_query() {
        let bp = BeliefPropagation::new(create_simple_network()).unwrap();
        let mut evidence = HashMap::new();
        evidence.insert("rain".to_string(), 1);
        let probs = bp.query_with_evidence("sprinkler", &evidence).unwrap();
        assert_probs(&probs, &[0.99, 0.01]);
    }

    #[test]
    fn test_evidence_on_other_root_state() {
        let bp = BeliefPropagation::new(create_simple_network()).unwrap();
        let mut evidence = HashMap::new();
        evidence.insert("rain".to_string(), 0);
        let probs = bp.query_with_evidence("sprinkler", &evidence).unwrap();
        assert_probs(&probs, &[0.6, 0.4]);
    }

    #[test]
    fn test_diagnostic_query() {
        // Evidence on the child updates the parent by Bayes' rule:
        // P(rain=yes | sprinkler=on) = 0.2*0.01 / (0.8*0.4 + 0.2*0.01).
        let bp = BeliefPropagation::new(create_simple_network()).unwrap();
        let mut evidence = HashMap::new();
        evidence.insert("sprinkler".to_string(), 1);
        let p_on = 0.4f64.mul_add(0.8, 0.01 * 0.2);
        let p_yes = 0.01 * 0.2 / p_on;
        let probs = bp.query_with_evidence("rain", &evidence).unwrap();
        assert_probs(&probs, &[1.0 - p_yes, p_yes]);
    }

    #[test]
    fn test_unknown_target_is_error() {
        let bp = BeliefPropagation::new(create_simple_network()).unwrap();
        assert!(bp.query("cloudy").is_err());
        assert!(bp.query_with_evidence("cloudy", &HashMap::new()).is_err());
    }
}