Skip to main content

oxicuda_seq/mrf/
belief_prop.rs

1//! Loopy belief propagation on a pairwise MRF (sum-product + max-product).
2
3use super::mrf::Mrf;
4use crate::error::{SeqError, SeqResult};
5
/// BP configuration shared by the sum-product and max-product solvers.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BpConfig {
    /// Upper bound on the number of message-passing sweeps.  Must be greater
    /// than zero; the solvers reject `0` with an `InvalidConfiguration` error.
    pub max_iter: usize,
    /// Convergence threshold on the largest absolute change of any
    /// log-message entry between two consecutive sweeps.
    pub tol: f64,
    /// Weight of the *freshly computed* message in each update, applied in
    /// log space: `new = (1 - damping) * old + damping * computed`.
    /// `1.0` therefore means "no damping", while `0.0` freezes the messages
    /// at their initial value.
    pub damping: f64,
}
13
14impl Default for BpConfig {
15    fn default() -> Self {
16        Self {
17            max_iter: 50,
18            tol: 1e-5,
19            damping: 0.5,
20        }
21    }
22}
23
/// Result of BP marginal inference.
#[derive(Debug, Clone, PartialEq)]
pub struct BpResult {
    /// Approximate node marginals, flattened row-major: the probability of
    /// label `l` at node `i` is stored at `marginals[i * n_labels + l]`.
    pub marginals: Vec<f64>,
    /// Number of message-passing sweeps that were actually executed.
    pub iterations: usize,
    /// `true` when the largest message change of the last sweep fell below
    /// the configured tolerance; `false` when the sweep budget ran out first.
    pub converged: bool,
}
31
32/// Loopy sum-product BP on a pairwise MRF.  Computes approximate node marginals.
33///
34/// Internally uses *log-space* messages so values do not under/overflow on
35/// high-energy graphs.
36pub fn loopy_bp_marginals(mrf: &Mrf, cfg: &BpConfig) -> SeqResult<BpResult> {
37    if cfg.max_iter == 0 {
38        return Err(SeqError::InvalidConfiguration(
39            "max_iter must be > 0".to_string(),
40        ));
41    }
42    let nl = mrf.n_labels;
43    let l2 = nl * nl;
44    // Directed messages per (edge_idx, direction): u→v at idx*2, v→u at idx*2+1.
45    let n_messages = mrf.edges.len() * 2;
46    let mut log_msg = vec![0.0; n_messages * nl];
47    let mut new_log_msg = log_msg.clone();
48    let mut converged = false;
49    let mut iters = 0;
50
51    for it in 0..cfg.max_iter {
52        iters = it + 1;
53        for (e_idx, &(u, v)) in mrf.edges.iter().enumerate() {
54            for &(src, _dst, msg_idx, _opp_idx) in &[
55                (u, v, e_idx * 2, e_idx * 2 + 1),
56                (v, u, e_idx * 2 + 1, e_idx * 2),
57            ] {
58                // Build message[l_dst] = logsumexp_{l_src}(unary[src][l_src] +
59                //                       pairwise[edge][l_u, l_v] (oriented) +
60                //                       Σ_k≠e log_msg[k→src][l_src]).
61                let mut out = vec![f64::NEG_INFINITY; nl];
62                for l_dst in 0..nl {
63                    let mut terms = vec![0.0; nl];
64                    for l_src in 0..nl {
65                        let mut acc = -mrf.unary[src * nl + l_src];
66                        // pairwise oriented: edge is stored as (u, v); apply correct order.
67                        let psi = if src == u {
68                            mrf.pairwise[e_idx * l2 + l_src * nl + l_dst]
69                        } else {
70                            mrf.pairwise[e_idx * l2 + l_dst * nl + l_src]
71                        };
72                        acc -= psi;
73                        // Incoming messages from all neighbours of `src` except this edge.
74                        for (k_idx, &(uu, vv)) in mrf.edges.iter().enumerate() {
75                            if k_idx == e_idx {
76                                continue;
77                            }
78                            let in_msg = if uu == src {
79                                &log_msg[(k_idx * 2 + 1) * nl..]
80                            } else if vv == src {
81                                &log_msg[(k_idx * 2) * nl..]
82                            } else {
83                                continue;
84                            };
85                            acc += in_msg[l_src];
86                        }
87                        terms[l_src] = acc;
88                    }
89                    out[l_dst] = logsumexp_in(&terms);
90                }
91                // Normalise (log-domain)
92                let m = out.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
93                for v in out.iter_mut() {
94                    *v -= m;
95                }
96                // Damping + write
97                for l in 0..nl {
98                    new_log_msg[msg_idx * nl + l] =
99                        (1.0 - cfg.damping) * log_msg[msg_idx * nl + l] + cfg.damping * out[l];
100                }
101            }
102        }
103        let mut max_diff = 0.0_f64;
104        for k in 0..log_msg.len() {
105            let d = (new_log_msg[k] - log_msg[k]).abs();
106            if d > max_diff {
107                max_diff = d;
108            }
109        }
110        log_msg.copy_from_slice(&new_log_msg);
111        if max_diff < cfg.tol {
112            converged = true;
113            break;
114        }
115    }
116
117    // Compute marginals
118    let mut marginals = vec![0.0; mrf.n_nodes * nl];
119    for i in 0..mrf.n_nodes {
120        let mut log_b = vec![0.0; nl];
121        for l in 0..nl {
122            log_b[l] = -mrf.unary[i * nl + l];
123        }
124        for (e_idx, &(u, v)) in mrf.edges.iter().enumerate() {
125            if u == i {
126                for l in 0..nl {
127                    log_b[l] += log_msg[(e_idx * 2 + 1) * nl + l];
128                }
129            }
130            if v == i {
131                for l in 0..nl {
132                    log_b[l] += log_msg[(e_idx * 2) * nl + l];
133                }
134            }
135        }
136        let m = log_b.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
137        let mut s = 0.0;
138        let mut exps = vec![0.0; nl];
139        for l in 0..nl {
140            exps[l] = (log_b[l] - m).exp();
141            s += exps[l];
142        }
143        for l in 0..nl {
144            marginals[i * nl + l] = if s > 0.0 {
145                exps[l] / s
146            } else {
147                1.0 / nl as f64
148            };
149        }
150    }
151    Ok(BpResult {
152        marginals,
153        iterations: iters,
154        converged,
155    })
156}
157
158/// Loopy max-product BP for MAP inference.  Returns the MAP labelling.
159pub fn loopy_bp_map(mrf: &Mrf, cfg: &BpConfig) -> SeqResult<Vec<usize>> {
160    if cfg.max_iter == 0 {
161        return Err(SeqError::InvalidConfiguration(
162            "max_iter must be > 0".to_string(),
163        ));
164    }
165    let nl = mrf.n_labels;
166    let l2 = nl * nl;
167    let n_messages = mrf.edges.len() * 2;
168    let mut log_msg = vec![0.0; n_messages * nl];
169    let mut new_log_msg = log_msg.clone();
170
171    for _ in 0..cfg.max_iter {
172        for (e_idx, &(u, v)) in mrf.edges.iter().enumerate() {
173            for &(src, dst, msg_idx) in &[(u, v, e_idx * 2), (v, u, e_idx * 2 + 1)] {
174                let _ = dst;
175                let mut out = vec![f64::NEG_INFINITY; nl];
176                for l_dst in 0..nl {
177                    let mut best = f64::NEG_INFINITY;
178                    for l_src in 0..nl {
179                        let mut acc = -mrf.unary[src * nl + l_src];
180                        let psi = if src == u {
181                            mrf.pairwise[e_idx * l2 + l_src * nl + l_dst]
182                        } else {
183                            mrf.pairwise[e_idx * l2 + l_dst * nl + l_src]
184                        };
185                        acc -= psi;
186                        for (k_idx, &(uu, vv)) in mrf.edges.iter().enumerate() {
187                            if k_idx == e_idx {
188                                continue;
189                            }
190                            let in_msg = if uu == src {
191                                &log_msg[(k_idx * 2 + 1) * nl..]
192                            } else if vv == src {
193                                &log_msg[(k_idx * 2) * nl..]
194                            } else {
195                                continue;
196                            };
197                            acc += in_msg[l_src];
198                        }
199                        if acc > best {
200                            best = acc;
201                        }
202                    }
203                    out[l_dst] = best;
204                }
205                let m = out.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
206                for v in out.iter_mut() {
207                    *v -= m;
208                }
209                for l in 0..nl {
210                    new_log_msg[msg_idx * nl + l] =
211                        (1.0 - cfg.damping) * log_msg[msg_idx * nl + l] + cfg.damping * out[l];
212                }
213            }
214        }
215        log_msg.copy_from_slice(&new_log_msg);
216    }
217
218    // Decode
219    let mut labels = vec![0usize; mrf.n_nodes];
220    for i in 0..mrf.n_nodes {
221        let mut best_l = 0usize;
222        let mut best_v = f64::NEG_INFINITY;
223        for l in 0..nl {
224            let mut acc = -mrf.unary[i * nl + l];
225            for (e_idx, &(u, v)) in mrf.edges.iter().enumerate() {
226                if u == i {
227                    acc += log_msg[(e_idx * 2 + 1) * nl + l];
228                }
229                if v == i {
230                    acc += log_msg[(e_idx * 2) * nl + l];
231                }
232            }
233            if acc > best_v {
234                best_v = acc;
235                best_l = l;
236            }
237        }
238        labels[i] = best_l;
239    }
240    Ok(labels)
241}
242
/// Numerically stable `ln(Σ exp(x))` over `xs`.
///
/// The maximum is factored out before exponentiating so large magnitudes do
/// not overflow.  When that maximum is infinite it is returned directly:
/// `-inf` for an empty or all-`-inf` slice, and `+inf` when any entry is
/// `+inf` (subtracting it would otherwise evaluate `inf - inf = NaN`).
fn logsumexp_in(xs: &[f64]) -> f64 {
    let m = xs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if m.is_infinite() {
        return m;
    }
    let s: f64 = xs.iter().map(|x| (x - m).exp()).sum();
    m + s.ln()
}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the three-node, two-label chain `0 - 1 - 2` used by both tests,
    /// with a Potts-style pairwise table on each edge and the given unaries.
    fn chain_with_unary(unary: Vec<f64>) -> Mrf {
        let potts_twice = vec![0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0];
        Mrf::new(3, 2, vec![(0, 1), (1, 2)], unary, potts_twice).expect("ok")
    }

    /// Every node's marginal row must sum to one.
    #[test]
    fn bp_marginals_normalise() {
        let m = chain_with_unary(vec![0.0, 1.0, 1.0, 0.0, 0.0, 1.0]);
        let res = loopy_bp_marginals(&m, &BpConfig::default()).expect("ok");
        let width = m.n_labels;
        for node in 0..m.n_nodes {
            let row = &res.marginals[node * width..(node + 1) * width];
            let s = row.iter().sum::<f64>();
            assert!((s - 1.0).abs() < 1e-6, "row sum {s}");
        }
    }

    /// Smoke test: MAP decoding yields one label per node.
    #[test]
    fn bp_map_runs() {
        let m = chain_with_unary(vec![0.0, 5.0, 5.0, 0.0, 0.0, 5.0]);
        let labels = loopy_bp_map(&m, &BpConfig::default()).expect("ok");
        assert_eq!(labels.len(), 3);
    }
}