use super::mrf::Mrf;
use crate::error::{SeqError, SeqResult};
/// Tuning knobs for the loopy belief-propagation routines in this module.
///
/// Start from [`BpConfig::default`] and override individual fields with struct-update
/// syntax, e.g. `BpConfig { max_iter: 200, ..BpConfig::default() }`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BpConfig {
    /// Maximum number of synchronous message-passing sweeps. Must be greater than zero;
    /// `0` is rejected with `SeqError::InvalidConfiguration`.
    pub max_iter: usize,
    /// Convergence threshold on the largest absolute change of any log-message between two
    /// consecutive sweeps; iteration may stop early once the change drops below it.
    pub tol: f64,
    /// Weight given to the newly computed message when blending it with the previous one:
    /// `new = (1 - damping) * old + damping * computed`. `1.0` applies updates undamped and
    /// smaller values move messages more slowly. Note that this weights the *new* message,
    /// the reverse of the convention in which the damping factor weights the old message.
    pub damping: f64,
}
impl Default for BpConfig {
    /// 50 sweeps, tolerance `1e-5`, and an even `0.5` blend of old and new messages.
    fn default() -> Self {
        Self {
            max_iter: 50,
            tol: 1e-5,
            damping: 0.5,
        }
    }
}
/// Output of [`loopy_bp_marginals`].
#[derive(Debug, Clone, PartialEq)]
pub struct BpResult {
    /// Approximate marginals as a row-major `n_nodes * n_labels` vector: row `i` holds
    /// node `i`'s distribution over labels and sums to 1.
    pub marginals: Vec<f64>,
    /// Number of message-passing sweeps actually performed (at least 1).
    pub iterations: usize,
    /// Whether the largest message change in the last sweep performed fell below the
    /// configured tolerance (`BpConfig::tol`).
    pub converged: bool,
}
/// Approximate per-node marginals of `mrf` with damped, synchronous loopy sum-product BP.
///
/// `mrf.unary` (indexed `[node][label]`) and `mrf.pairwise` (indexed
/// `[edge][label of u][label of v]`) are read as energies: they are negated to obtain
/// log-potentials, so lower energy means higher probability. Messages live in log space,
/// start at `0` (uniform), and are shifted after each update so their largest entry is `0`.
/// Every sweep recomputes all messages from the previous sweep's values and blends them:
/// `new = (1 - damping) * old + damping * computed`.
///
/// # Returns
///
/// A [`BpResult`] whose `marginals` is a row-major `n_nodes * n_labels` vector (each row
/// sums to 1; a node whose belief cannot be normalised falls back to a uniform row),
/// `iterations` is the number of sweeps performed, and `converged` is `true` when the
/// largest message change in the last sweep was below `cfg.tol`.
///
/// # Errors
///
/// Returns [`SeqError::InvalidConfiguration`] if `cfg.max_iter` is `0` or if
/// `cfg.damping` is not in `(0, 1]`.
pub fn loopy_bp_marginals(mrf: &Mrf, cfg: &BpConfig) -> SeqResult<BpResult> {
    if cfg.max_iter == 0 {
        return Err(SeqError::InvalidConfiguration(
            "max_iter must be > 0".to_string(),
        ));
    }
    // `damping` weights the freshly computed message. Zero (or NaN) would freeze every
    // message at its initial value and then report convergence after a single sweep.
    let damping_ok = cfg.damping > 0.0 && cfg.damping <= 1.0;
    if !damping_ok {
        return Err(SeqError::InvalidConfiguration(
            "damping must be in (0, 1]".to_string(),
        ));
    }
    let nl = mrf.n_labels;
    let l2 = nl * nl;
    // Two directed messages per edge `e = (u, v)`: slot `2e` is u -> v, slot `2e + 1` is v -> u.
    let n_messages = mrf.edges.len() * 2;
    let mut log_msg = vec![0.0; n_messages * nl];
    let mut new_log_msg = log_msg.clone();
    // Scratch buffers reused for every message instead of reallocating per label.
    let mut cavity = vec![0.0; nl];
    let mut terms = vec![0.0; nl];
    let mut out = vec![0.0; nl];
    let mut converged = false;
    let mut iters = 0;
    for it in 0..cfg.max_iter {
        iters = it + 1;
        for (e_idx, &(u, v)) in mrf.edges.iter().enumerate() {
            for &(src, src_is_u, msg_idx) in &[(u, true, e_idx * 2), (v, false, e_idx * 2 + 1)] {
                // Cavity log-potential of `src`: its unary term plus the messages it receives
                // over every edge except `e_idx`. It does not depend on the destination
                // label, so it is built once per message rather than per (l_dst, l_src) pair.
                for (l_src, c) in cavity.iter_mut().enumerate() {
                    *c = -mrf.unary[src * nl + l_src];
                }
                for (k_idx, &(uu, vv)) in mrf.edges.iter().enumerate() {
                    if k_idx == e_idx {
                        continue;
                    }
                    let in_idx = if uu == src {
                        k_idx * 2 + 1
                    } else if vv == src {
                        k_idx * 2
                    } else {
                        continue;
                    };
                    let in_msg = &log_msg[in_idx * nl..(in_idx + 1) * nl];
                    for (c, &msg_l) in cavity.iter_mut().zip(in_msg) {
                        *c += msg_l;
                    }
                }
                // Sum-product: marginalise the source label out in log space.
                for (l_dst, o) in out.iter_mut().enumerate() {
                    for (l_src, t) in terms.iter_mut().enumerate() {
                        let psi = if src_is_u {
                            mrf.pairwise[e_idx * l2 + l_src * nl + l_dst]
                        } else {
                            mrf.pairwise[e_idx * l2 + l_dst * nl + l_src]
                        };
                        *t = cavity[l_src] - psi;
                    }
                    *o = logsumexp_in(&terms);
                }
                // Shift so the largest entry is 0, then blend with the previous message.
                let m = out.iter().copied().fold(f64::NEG_INFINITY, f64::max);
                let base = msg_idx * nl;
                for (l, &o) in out.iter().enumerate() {
                    new_log_msg[base + l] =
                        (1.0 - cfg.damping) * log_msg[base + l] + cfg.damping * (o - m);
                }
            }
        }
        let mut max_diff = 0.0_f64;
        for (&prev, &next) in log_msg.iter().zip(&new_log_msg) {
            // Equal values (including stably infinite ones, whose difference would be NaN)
            // count as unchanged; any other NaN is kept so the run is never reported as
            // converged while messages are NaN.
            let d = if prev == next { 0.0 } else { (next - prev).abs() };
            if d.is_nan() || d > max_diff {
                max_diff = d;
            }
        }
        log_msg.copy_from_slice(&new_log_msg);
        if max_diff < cfg.tol {
            converged = true;
            break;
        }
    }
    let mut marginals = vec![0.0; mrf.n_nodes * nl];
    let mut log_b = vec![0.0; nl];
    let mut exps = vec![0.0; nl];
    for i in 0..mrf.n_nodes {
        // Belief = unary log-potential plus every message arriving at node `i`.
        for (l, b) in log_b.iter_mut().enumerate() {
            *b = -mrf.unary[i * nl + l];
        }
        for (e_idx, &(u, v)) in mrf.edges.iter().enumerate() {
            if u == i {
                let base = (e_idx * 2 + 1) * nl;
                for (b, &msg_l) in log_b.iter_mut().zip(&log_msg[base..base + nl]) {
                    *b += msg_l;
                }
            }
            if v == i {
                let base = e_idx * 2 * nl;
                for (b, &msg_l) in log_b.iter_mut().zip(&log_msg[base..base + nl]) {
                    *b += msg_l;
                }
            }
        }
        // Max-shifted exponentiation keeps the normalisation numerically stable.
        let m = log_b.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let mut s = 0.0;
        for (x, &b) in exps.iter_mut().zip(&log_b) {
            *x = (b - m).exp();
            s += *x;
        }
        let row = &mut marginals[i * nl..(i + 1) * nl];
        for (p, &x) in row.iter_mut().zip(&exps) {
            *p = if s > 0.0 { x / s } else { 1.0 / nl as f64 };
        }
    }
    Ok(BpResult {
        marginals,
        iterations: iters,
        converged,
    })
}
pub fn loopy_bp_map(mrf: &Mrf, cfg: &BpConfig) -> SeqResult<Vec<usize>> {
if cfg.max_iter == 0 {
return Err(SeqError::InvalidConfiguration(
"max_iter must be > 0".to_string(),
));
}
let nl = mrf.n_labels;
let l2 = nl * nl;
let n_messages = mrf.edges.len() * 2;
let mut log_msg = vec![0.0; n_messages * nl];
let mut new_log_msg = log_msg.clone();
for _ in 0..cfg.max_iter {
for (e_idx, &(u, v)) in mrf.edges.iter().enumerate() {
for &(src, dst, msg_idx) in &[(u, v, e_idx * 2), (v, u, e_idx * 2 + 1)] {
let _ = dst;
let mut out = vec![f64::NEG_INFINITY; nl];
for l_dst in 0..nl {
let mut best = f64::NEG_INFINITY;
for l_src in 0..nl {
let mut acc = -mrf.unary[src * nl + l_src];
let psi = if src == u {
mrf.pairwise[e_idx * l2 + l_src * nl + l_dst]
} else {
mrf.pairwise[e_idx * l2 + l_dst * nl + l_src]
};
acc -= psi;
for (k_idx, &(uu, vv)) in mrf.edges.iter().enumerate() {
if k_idx == e_idx {
continue;
}
let in_msg = if uu == src {
&log_msg[(k_idx * 2 + 1) * nl..]
} else if vv == src {
&log_msg[(k_idx * 2) * nl..]
} else {
continue;
};
acc += in_msg[l_src];
}
if acc > best {
best = acc;
}
}
out[l_dst] = best;
}
let m = out.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
for v in out.iter_mut() {
*v -= m;
}
for l in 0..nl {
new_log_msg[msg_idx * nl + l] =
(1.0 - cfg.damping) * log_msg[msg_idx * nl + l] + cfg.damping * out[l];
}
}
}
log_msg.copy_from_slice(&new_log_msg);
}
let mut labels = vec![0usize; mrf.n_nodes];
for i in 0..mrf.n_nodes {
let mut best_l = 0usize;
let mut best_v = f64::NEG_INFINITY;
for l in 0..nl {
let mut acc = -mrf.unary[i * nl + l];
for (e_idx, &(u, v)) in mrf.edges.iter().enumerate() {
if u == i {
acc += log_msg[(e_idx * 2 + 1) * nl + l];
}
if v == i {
acc += log_msg[(e_idx * 2) * nl + l];
}
}
if acc > best_v {
best_v = acc;
best_l = l;
}
}
labels[i] = best_l;
}
Ok(labels)
}
/// Numerically stable `ln(Σ exp(x))` over `xs`.
///
/// The largest element is factored out before exponentiating so large magnitudes do not
/// overflow. An empty slice, or one whose maximum is `-∞`, yields `-∞` (the log of an empty
/// sum). A maximum of `+∞` yields `+∞` rather than the NaN that `∞ - ∞` would produce.
fn logsumexp_in(xs: &[f64]) -> f64 {
    let m = xs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if m.is_infinite() {
        return m;
    }
    let s: f64 = xs.iter().map(|&x| (x - m).exp()).sum();
    m + s.ln()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Exact marginals of a small MRF by enumerating every joint labeling.
    /// Uses the module's conventions: energies, row-major unary, `[edge][l_u][l_v]` pairwise.
    fn brute_force_marginals(
        n_nodes: usize,
        nl: usize,
        edges: &[(usize, usize)],
        unary: &[f64],
        pairwise: &[f64],
    ) -> Vec<f64> {
        let mut marg = vec![0.0; n_nodes * nl];
        let mut z = 0.0;
        let mut labels = vec![0usize; n_nodes];
        for code in 0..nl.pow(n_nodes as u32) {
            let mut c = code;
            for lab in labels.iter_mut() {
                *lab = c % nl;
                c /= nl;
            }
            let mut energy: f64 = labels
                .iter()
                .enumerate()
                .map(|(i, &l)| unary[i * nl + l])
                .sum();
            for (e, &(u, v)) in edges.iter().enumerate() {
                energy += pairwise[e * nl * nl + labels[u] * nl + labels[v]];
            }
            let w = (-energy).exp();
            z += w;
            for (i, &l) in labels.iter().enumerate() {
                marg[i * nl + l] += w;
            }
        }
        marg.iter_mut().for_each(|p| *p /= z);
        marg
    }

    #[test]
    fn bp_marginals_normalise() {
        let m = Mrf::new(
            3,
            2,
            vec![(0, 1), (1, 2)],
            vec![0.0, 1.0, 1.0, 0.0, 0.0, 1.0],
            vec![0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0],
        )
        .expect("ok");
        let res = loopy_bp_marginals(&m, &BpConfig::default()).expect("ok");
        for i in 0..m.n_nodes {
            let s: f64 = res.marginals[i * m.n_labels..(i + 1) * m.n_labels]
                .iter()
                .sum();
            assert!((s - 1.0).abs() < 1e-6, "row sum {s}");
        }
    }

    #[test]
    fn bp_marginals_exact_on_chain() {
        // BP is exact on trees, so undamped BP on a chain must match brute force.
        // The asymmetric pairwise table also checks the u->v / v->u orientation handling.
        let edges: Vec<(usize, usize)> = vec![(0, 1), (1, 2)];
        let unary: Vec<f64> = vec![0.3, 1.2, 0.8, 0.1, 0.0, 0.7];
        let pairwise: Vec<f64> = vec![0.0, 0.9, 0.4, 0.2, 0.5, 0.0, 1.1, 0.3];
        let m = Mrf::new(3, 2, edges.clone(), unary.clone(), pairwise.clone()).expect("ok");
        let cfg = BpConfig {
            max_iter: 100,
            tol: 1e-12,
            damping: 1.0,
        };
        let res = loopy_bp_marginals(&m, &cfg).expect("ok");
        assert!(res.converged);
        let exact = brute_force_marginals(3, 2, &edges, &unary, &pairwise);
        assert_eq!(res.marginals.len(), exact.len());
        for (got, want) in res.marginals.iter().zip(&exact) {
            assert!((got - want).abs() < 1e-9, "got {got}, want {want}");
        }
    }

    #[test]
    fn bp_map_runs() {
        let m = Mrf::new(
            3,
            2,
            vec![(0, 1), (1, 2)],
            vec![0.0, 5.0, 5.0, 0.0, 0.0, 5.0],
            vec![0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0],
        )
        .expect("ok");
        let labels = loopy_bp_map(&m, &BpConfig::default()).expect("ok");
        assert_eq!(labels.len(), 3);
    }

    #[test]
    fn bp_map_recovers_chain_optimum() {
        // Strong unaries (cost 5) outweigh the two Potts disagreements (cost 1 each),
        // so the minimum-energy labeling is [0, 1, 0].
        let m = Mrf::new(
            3,
            2,
            vec![(0, 1), (1, 2)],
            vec![0.0, 5.0, 5.0, 0.0, 0.0, 5.0],
            vec![0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0],
        )
        .expect("ok");
        let labels = loopy_bp_map(&m, &BpConfig::default()).expect("ok");
        assert_eq!(labels, vec![0, 1, 0]);
    }

    #[test]
    fn zero_max_iter_is_rejected() {
        let m = Mrf::new(2, 2, vec![(0, 1)], vec![0.0; 4], vec![0.0; 4]).expect("ok");
        let cfg = BpConfig {
            max_iter: 0,
            ..BpConfig::default()
        };
        assert!(loopy_bp_marginals(&m, &cfg).is_err());
        assert!(loopy_bp_map(&m, &cfg).is_err());
    }
}