use scirs2_core::ndarray::{Array1, ArrayD};
use std::collections::HashMap;
use crate::error::{PgmError, Result};
use crate::graph::FactorGraph;
use crate::message_passing::MessagePassingAlgorithm;
use super::config::{LoopyBpConfig, LoopyBpResult};
use super::cycle::CycleDetector;
use super::energy::bethe_free_energy;
use super::types::{LbpConvergenceMonitor, LbpIterStats, LogMessage, UpdateSchedule};
/// Container for the messages exchanged between variables and factors,
/// held in one map per direction.
#[derive(Clone, Debug, Default)]
pub(super) struct LogMessageStore {
/// Variable → factor messages, keyed by `(variable name, factor id)`.
pub(super) var_to_factor: HashMap<(String, String), LogMessage>,
/// Factor → variable messages, keyed by `(factor id, variable name)`.
pub(super) factor_to_var: HashMap<(String, String), LogMessage>,
}
impl LogMessageStore {
pub(super) fn new() -> Self {
Self::default()
}
pub(super) fn get_vtf(&self, var: &str, fac: &str) -> Option<&LogMessage> {
self.var_to_factor.get(&(var.to_string(), fac.to_string()))
}
pub(super) fn set_vtf(&mut self, var: String, fac: String, msg: LogMessage) {
self.var_to_factor.insert((var, fac), msg);
}
pub(super) fn get_ftv(&self, fac: &str, var: &str) -> Option<&LogMessage> {
self.factor_to_var.get(&(fac.to_string(), var.to_string()))
}
pub(super) fn set_ftv(&mut self, fac: String, var: String, msg: LogMessage) {
self.factor_to_var.insert((fac, var), msg);
}
}
/// Loopy belief propagation runner; all behaviour is driven by `config`.
pub struct LoopyBeliefPropagation {
/// Settings read during a run (update schedule, damping, tolerance and
/// whether to compute the Bethe free energy).
pub config: LoopyBpConfig,
}
impl LoopyBeliefPropagation {
pub fn new(config: LoopyBpConfig) -> Self {
Self { config }
}
/// Runs the configured update schedule on `graph` and assembles the full
/// result: variable beliefs, factor beliefs, the convergence monitor, the
/// optional Bethe free energy and the cycle analysis.
///
/// # Errors
///
/// Propagates any error returned by the schedule runner or by the belief
/// computations.
pub fn run_full(&self, graph: &FactorGraph) -> Result<LoopyBpResult> {
    // The cycle analysis is computed first, before any message exists.
    let cycle_analysis = CycleDetector::new(graph).analyse();

    let mut store = self.initialise_messages(graph);
    let mut convergence = LbpConvergenceMonitor::new();

    // Dispatch to the runner matching the configured schedule.
    match self.config.schedule {
        UpdateSchedule::Synchronous => {
            self.run_synchronous(graph, &mut store, &mut convergence)?;
        }
        UpdateSchedule::Sequential => {
            self.run_sequential(graph, &mut store, &mut convergence)?;
        }
        UpdateSchedule::Residual => {
            self.run_residual(graph, &mut store, &mut convergence)?;
        }
    }

    let beliefs = self.compute_variable_beliefs(graph, &store)?;
    let factor_beliefs = self.compute_factor_beliefs(graph, &store)?;

    // The Bethe free energy is only evaluated when explicitly requested.
    let bethe = self
        .config
        .compute_bethe
        .then(|| bethe_free_energy(graph, &beliefs, &factor_beliefs));

    Ok(LoopyBpResult {
        cycle_analysis,
        bethe,
        convergence,
        factor_beliefs,
        beliefs,
    })
}
/// Builds the initial message store for `graph`.
///
/// For every (variable, adjacent factor) pair the variable → factor message
/// starts uniform. The factor → variable message starts as the factor
/// marginalised onto the variable, falling back to a uniform message when
/// the factor is missing or the marginalisation fails. A variable that
/// cannot be looked up is given cardinality 2.
pub(super) fn initialise_messages(&self, graph: &FactorGraph) -> LogMessageStore {
    let mut store = LogMessageStore::new();
    for var_name in graph.variable_names() {
        let card = graph.get_variable(var_name).map_or(2, |v| v.cardinality);
        let fac_ids = match graph.get_adjacent_factors(var_name) {
            Some(ids) => ids,
            // No adjacency entry for this variable: nothing to initialise.
            None => continue,
        };
        for fac_id in fac_ids {
            let outgoing = LogMessage::uniform(var_name, card);
            store.set_vtf(var_name.clone(), fac_id.clone(), outgoing);

            // Marginalisation errors are deliberately swallowed here; the
            // uniform message is the best-effort fallback.
            let incoming = graph
                .get_factor(fac_id)
                .and_then(|factor| self.marginalise_factor_to_var(factor, var_name).ok())
                .unwrap_or_else(|| LogMessage::uniform(var_name, card));
            store.set_ftv(fac_id.clone(), var_name.clone(), incoming);
        }
    }
    store
}
/// Computes the message from variable `var` to factor `target_fac`.
///
/// The message is the element-wise sum, in log space, of every stored
/// factor → variable message arriving at `var` from factors other than
/// `target_fac`, followed by `log_normalise`. Adjacent factors with no
/// stored message are skipped.
///
/// # Errors
///
/// Returns `PgmError::VariableNotFound` when `var` is not in `graph`.
pub(super) fn compute_vtf_message(
    &self,
    graph: &FactorGraph,
    messages: &LogMessageStore,
    var: &str,
    target_fac: &str,
) -> Result<LogMessage> {
    let card = graph
        .get_variable(var)
        .map(|v| v.cardinality)
        .ok_or_else(|| PgmError::VariableNotFound(var.to_string()))?;

    // Start from all zeros, the additive identity in log space.
    let mut accumulated = Array1::<f64>::zeros(card);
    for fac_id in graph.get_adjacent_factors(var).into_iter().flatten() {
        if fac_id != target_fac {
            if let Some(ftv) = messages.get_ftv(fac_id, var) {
                accumulated += &ftv.log_values;
            }
        }
    }

    let mut msg = LogMessage {
        variable: var.to_owned(),
        log_values: accumulated,
    };
    msg.log_normalise();
    Ok(msg)
}
/// Computes the message from factor `fac_id` to variable `target_var`.
///
/// For every cell of the factor table, the log potential is added to the
/// stored variable → factor messages of all non-target variables evaluated
/// at that cell's assignment. Cells sharing the same target state are then
/// combined with a pairwise log-sum-exp, and the result is passed through
/// `log_normalise`.
///
/// Potentials that are not greater than `1e-300` are treated as `-700.0`
/// in log space. A variable without a stored message contributes nothing;
/// a message too short for the indexed state contributes `-700.0`.
///
/// # Errors
///
/// * `PgmError::FactorNotFound` when `fac_id` is not in `graph`.
/// * `PgmError::VariableNotFound` when `target_var` is not one of the
///   factor's variables.
pub(super) fn compute_ftv_message(
    &self,
    graph: &FactorGraph,
    messages: &LogMessageStore,
    fac_id: &str,
    target_var: &str,
) -> Result<LogMessage> {
    // Log-space stand-in for (near-)zero potentials and missing entries.
    const LOG_FLOOR: f64 = -700.0;
    // Potentials at or below this threshold are floored instead of logged.
    const MIN_POTENTIAL: f64 = 1e-300;

    let factor = graph
        .get_factor(fac_id)
        .ok_or_else(|| PgmError::FactorNotFound(fac_id.to_string()))?;
    let target_idx = factor
        .variables
        .iter()
        .position(|v| v == target_var)
        .ok_or_else(|| {
            PgmError::VariableNotFound(format!(
                "Variable '{}' not in factor '{}'",
                target_var, fac_id
            ))
        })?;
    let shape = factor.values.shape();
    let target_card = shape[target_idx];
    let total_size: usize = shape.iter().product();

    // Resolve every incoming message once, up front. Each store lookup
    // builds an owned `(String, String)` key, so doing it per table cell
    // would allocate on every iteration of the hot loop. Entries for the
    // target variable are `None` so they are skipped below.
    let incoming: Vec<Option<&LogMessage>> = factor
        .variables
        .iter()
        .map(|var_name| {
            if var_name == target_var {
                None
            } else {
                messages.get_vtf(var_name, fac_id)
            }
        })
        .collect();

    // Accumulate straight into the per-state result, visiting cells in
    // increasing linear index so the summation order is deterministic.
    let mut result = vec![f64::NEG_INFINITY; target_card];
    for lin_idx in 0..total_size {
        let assignment = linear_to_assignment(lin_idx, shape);
        let phi = factor.values[assignment.as_slice()];
        let mut log_val = if phi > MIN_POTENTIAL {
            phi.ln()
        } else {
            LOG_FLOOR
        };
        for (dim, maybe_vtf) in incoming.iter().enumerate() {
            if let Some(vtf) = maybe_vtf {
                let val_idx = assignment[dim];
                log_val += vtf.log_values.get(val_idx).copied().unwrap_or(LOG_FLOOR);
            }
        }

        let t_val = assignment[target_idx];
        let cur = result[t_val];
        result[t_val] = if cur == f64::NEG_INFINITY {
            // First contribution for this state: nothing to combine with.
            log_val
        } else {
            // Pairwise log-sum-exp, shifted by the max for stability.
            let m = cur.max(log_val);
            m + ((cur - m).exp() + (log_val - m).exp()).ln()
        };
    }

    let mut msg = LogMessage {
        variable: target_var.to_string(),
        log_values: Array1::from(result),
    };
    msg.log_normalise();
    Ok(msg)
}
pub(super) fn apply_updates_and_track(
&self,
messages: &mut LogMessageStore,
new_vtf: HashMap<(String, String), LogMessage>,
new_ftv: HashMap<(String, String), LogMessage>,
iteration: usize,
) -> LbpIterStats {
let mut max_residual = 0.0_f64;
let mut sum_residual = 0.0_f64;
let mut count = 0usize;
let mut active = 0usize;
for ((var, fac), new_msg) in new_vtf {
let old = messages.get_vtf(&var, &fac).cloned();
let residual = old
.as_ref()
.map(|o| new_msg.residual_linf(o))
.unwrap_or(1.0);
let lambda = self.config.damping.effective_lambda(residual);
let final_msg = if let Some(o) = &old {
new_msg.damp(o, lambda)
} else {
new_msg
};
max_residual = max_residual.max(residual);
sum_residual += residual;
count += 1;
if residual >= self.config.tolerance {
active += 1;
}
messages.set_vtf(var, fac, final_msg);
}
for ((fac, var), new_msg) in new_ftv {
let old = messages.get_ftv(&fac, &var).cloned();
let residual = old
.as_ref()
.map(|o| new_msg.residual_linf(o))
.unwrap_or(1.0);
let lambda = self.config.damping.effective_lambda(residual);
let final_msg = if let Some(o) = &old {
new_msg.damp(o, lambda)
} else {
new_msg
};
max_residual = max_residual.max(residual);
sum_residual += residual;
count += 1;
if residual >= self.config.tolerance {
active += 1;
}
messages.set_ftv(fac, var, final_msg);
}
let mean_residual = if count > 0 {
sum_residual / count as f64
} else {
0.0
};
LbpIterStats {
iteration,
max_residual,
mean_residual,
active_messages: active,
}
}
}
impl MessagePassingAlgorithm for LoopyBeliefPropagation {
    /// Runs `run_full` and returns only the variable beliefs, each converted
    /// to a dynamic-dimensional array.
    fn run(
        &self,
        graph: &FactorGraph,
    ) -> std::result::Result<HashMap<String, ArrayD<f64>>, crate::error::PgmError> {
        let result = self.run_full(graph)?;
        let mut beliefs_dyn: HashMap<String, ArrayD<f64>> = HashMap::new();
        for (name, belief) in result.beliefs {
            beliefs_dyn.insert(name, belief.into_dyn());
        }
        Ok(beliefs_dyn)
    }

    /// Fixed identifier of this algorithm.
    fn name(&self) -> &str {
        "LoopyBeliefPropagation"
    }
}
/// Decodes a linear index into one index per axis of `shape`, with the
/// last axis varying fastest.
///
/// Any part of `lin` beyond the product of `shape` is discarded, and an
/// empty `shape` yields an empty vector.
///
/// # Panics
///
/// Panics if any extent in `shape` is zero (remainder by zero).
pub(super) fn linear_to_assignment(mut lin: usize, shape: &[usize]) -> Vec<usize> {
    // Peel off digits from the fastest-varying (last) axis first, then flip
    // them back into axis order.
    let mut digits: Vec<usize> = shape
        .iter()
        .rev()
        .map(|&extent| {
            let digit = lin % extent;
            lin /= extent;
            digit
        })
        .collect();
    digits.reverse();
    digits
}