Skip to main content

oxicuda_seq/grid_crf/
loopy_bp.rs

//! Loopy belief propagation (sum-product) on a 4-connected 2-D grid CRF.
//!
//! Each grid node `i` carries a unary **log-potential** vector
//! `unary[i*K + l]` (with `K = n_states`), and every edge shares one pairwise
//! log-potential matrix `pairwise[a*K + b]`. The joint distribution is
//!
//! ```text
//! P(x) ∝ exp( Σ_i unary[i, x_i] + Σ_{(i,j)∈E} pairwise[x_i, x_j] )
//! ```
//!
//! where every edge `(i, j)` is oriented with `i < j`: `i` is the left node of
//! a horizontal edge and the upper node of a vertical one. A *larger* unary
//! entry favours that label. A diagonally dominant pairwise matrix
//! (`pairwise[a,a] > pairwise[a,b]` for `a ≠ b`) is *attractive*: it encourages
//! neighbouring nodes to share a label.
//!
//! Messages are passed in log-space (log-sum-exp) with damping. This continues
//! until the largest message change drops below `tol` or `max_iter` sweeps
//! elapse. On a grid without cycles (a `1×N` or `N×1` chain) the fixed point is
//! **exact**. On a grid with cycles it is the standard loopy-BP approximation.
19
20use crate::error::{SeqError, SeqResult};
21use crate::hmm::forward_backward::logsumexp;
22
/// Tunable knobs for the loopy belief-propagation engine.
///
/// Start from [`Default`] and override selectively with struct-update syntax,
/// e.g. `LoopyBpConfig { damping: 1.0, ..LoopyBpConfig::default() }`.
#[derive(Debug, Clone, Copy)]
pub struct LoopyBpConfig {
    /// Upper bound on the number of synchronous sweeps over all messages.
    pub max_iter: usize,
    /// Inference stops once the largest absolute change of any log-message in
    /// a sweep is strictly smaller than this value.
    pub tol: f64,
    /// Step size in `(0, 1]`, applied as `new = (1−damp)·old + damp·update`.
    /// `1.0` gives plain (undamped) BP; smaller values damp more strongly.
    pub damping: f64,
}

impl Default for LoopyBpConfig {
    /// 200 sweeps, tolerance `1e-9`, damping `0.5`.
    fn default() -> Self {
        let max_iter = 200;
        let tol = 1e-9;
        let damping = 0.5;
        Self {
            max_iter,
            tol,
            damping,
        }
    }
}
44
/// Outcome of a detailed inference run: node beliefs plus convergence info.
#[derive(Debug, Clone)]
pub struct LoopyBpResult {
    /// Flattened beliefs of length `H*W*n_states`. Entry `i*n_states + l` is
    /// the probability of label `l` at node `i`, and each node's block of
    /// `n_states` entries sums to one.
    pub marginals: Vec<f64>,
    /// How many sweeps were executed (never more than `max_iter`).
    pub iterations: usize,
    /// `true` when the largest message change dropped below `tol` before the
    /// sweep budget ran out.
    pub converged: bool,
}
55
56/// Loopy belief-propagation engine for a fixed grid topology.
57///
58/// The grid geometry (edge list and per-node incident-message bookkeeping) is
59/// built once in [`LoopyBp::new`] and reused across [`LoopyBp::infer`] calls.
60#[derive(Debug, Clone)]
61pub struct LoopyBp {
62    height: usize,
63    width: usize,
64    n_states: usize,
65    config: LoopyBpConfig,
66    /// Undirected edges `(u, v)` with `u < v`.
67    edges: Vec<(usize, usize)>,
68    /// For each node, the `(incoming_slot, edge_index)` of every incident edge.
69    incident: Vec<Vec<(usize, usize)>>,
70}
71
72impl LoopyBp {
73    /// Build an inference engine for a `height × width` grid with `n_states`
74    /// labels per node.
75    ///
76    /// # Errors
77    /// * [`SeqError::InvalidConfiguration`] if any dimension is `0` or
78    ///   `config.max_iter == 0`.
79    /// * [`SeqError::InvalidParameter`] if `config.damping` is outside `(0, 1]`.
80    pub fn new(
81        height: usize,
82        width: usize,
83        n_states: usize,
84        config: LoopyBpConfig,
85    ) -> SeqResult<Self> {
86        if height == 0 || width == 0 || n_states == 0 {
87            return Err(SeqError::InvalidConfiguration(
88                "height, width and n_states must all be > 0".to_string(),
89            ));
90        }
91        if config.max_iter == 0 {
92            return Err(SeqError::InvalidConfiguration(
93                "max_iter must be > 0".to_string(),
94            ));
95        }
96        if config.damping <= 0.0 || config.damping > 1.0 {
97            return Err(SeqError::InvalidParameter {
98                name: "damping".to_string(),
99                value: config.damping,
100            });
101        }
102
103        // 4-connected edge list, each stored with the smaller node first.
104        let mut edges = Vec::new();
105        for r in 0..height {
106            for c in 0..width {
107                let node = r * width + c;
108                if c + 1 < width {
109                    edges.push((node, node + 1)); // horizontal
110                }
111                if r + 1 < height {
112                    edges.push((node, node + width)); // vertical
113                }
114            }
115        }
116
117        // Directed message slots: u→v at 2e, v→u at 2e+1. The message arriving
118        // *at* a node from edge e is the one pointing toward it.
119        let n_nodes = height * width;
120        let mut incident: Vec<Vec<(usize, usize)>> = vec![Vec::new(); n_nodes];
121        for (e, &(u, v)) in edges.iter().enumerate() {
122            incident[u].push((2 * e + 1, e)); // v→u arrives at u
123            incident[v].push((2 * e, e)); // u→v arrives at v
124        }
125
126        Ok(Self {
127            height,
128            width,
129            n_states,
130            config,
131            edges,
132            incident,
133        })
134    }
135
136    /// Grid height.
137    pub fn height(&self) -> usize {
138        self.height
139    }
140
141    /// Grid width.
142    pub fn width(&self) -> usize {
143        self.width
144    }
145
146    /// Number of labels per node.
147    pub fn n_states(&self) -> usize {
148        self.n_states
149    }
150
151    /// Run sum-product BP and return the per-node marginals `[H*W*n_states]`,
152    /// normalised so each node's distribution sums to 1.
153    ///
154    /// # Errors
155    /// [`SeqError::ShapeMismatch`] if `unary` or `pairwise` has the wrong length.
156    pub fn infer(&self, unary: &[f64], pairwise: &[f64]) -> SeqResult<Vec<f64>> {
157        Ok(self.infer_detailed(unary, pairwise)?.marginals)
158    }
159
160    /// Run sum-product BP, returning marginals together with the sweep count and
161    /// whether convergence was reached. See [`LoopyBp::infer`] for the error
162    /// conditions.
163    pub fn infer_detailed(&self, unary: &[f64], pairwise: &[f64]) -> SeqResult<LoopyBpResult> {
164        let k = self.n_states;
165        let n_nodes = self.height * self.width;
166        if unary.len() != n_nodes * k {
167            return Err(SeqError::ShapeMismatch {
168                expected: n_nodes * k,
169                got: unary.len(),
170            });
171        }
172        if pairwise.len() != k * k {
173            return Err(SeqError::ShapeMismatch {
174                expected: k * k,
175                got: pairwise.len(),
176            });
177        }
178
179        let damp = self.config.damping;
180        let n_slots = self.edges.len() * 2;
181        let mut log_msg = vec![0.0f64; n_slots * k];
182        let mut new_log_msg = log_msg.clone();
183        let mut terms = vec![0.0f64; k];
184        let mut out = vec![0.0f64; k];
185        let mut converged = false;
186        let mut iterations = 0usize;
187
188        for it in 0..self.config.max_iter {
189            iterations = it + 1;
190            for (e, &(u, v)) in self.edges.iter().enumerate() {
191                // Both directed messages along this edge.
192                for &(src, dst, out_slot) in &[(u, v, 2 * e), (v, u, 2 * e + 1)] {
193                    let _ = dst;
194                    for l_dst in 0..k {
195                        for l_src in 0..k {
196                            // Oriented pairwise: edge is stored as (u, v).
197                            let psi = if src == u {
198                                pairwise[l_src * k + l_dst]
199                            } else {
200                                pairwise[l_dst * k + l_src]
201                            };
202                            let mut acc = unary[src * k + l_src] + psi;
203                            // Product of incoming messages from every neighbour
204                            // of `src` except the one on this edge.
205                            for &(in_slot, in_edge) in &self.incident[src] {
206                                if in_edge == e {
207                                    continue;
208                                }
209                                acc += log_msg[in_slot * k + l_src];
210                            }
211                            terms[l_src] = acc;
212                        }
213                        out[l_dst] = logsumexp(&terms);
214                    }
215                    // Normalise in the log domain for numerical stability.
216                    let m = out.iter().copied().fold(f64::NEG_INFINITY, f64::max);
217                    if m > f64::NEG_INFINITY {
218                        for val in out.iter_mut() {
219                            *val -= m;
220                        }
221                    }
222                    for l in 0..k {
223                        let base = out_slot * k + l;
224                        new_log_msg[base] = (1.0 - damp) * log_msg[base] + damp * out[l];
225                    }
226                }
227            }
228
229            let mut max_diff = 0.0f64;
230            for idx in 0..log_msg.len() {
231                let d = (new_log_msg[idx] - log_msg[idx]).abs();
232                if d > max_diff {
233                    max_diff = d;
234                }
235            }
236            log_msg.copy_from_slice(&new_log_msg);
237            if max_diff < self.config.tol {
238                converged = true;
239                break;
240            }
241        }
242
243        let marginals = self.node_marginals(unary, &log_msg);
244        Ok(LoopyBpResult {
245            marginals,
246            iterations,
247            converged,
248        })
249    }
250
251    /// Combine the unary potential of each node with its incoming messages and
252    /// normalise to a proper distribution.
253    fn node_marginals(&self, unary: &[f64], log_msg: &[f64]) -> Vec<f64> {
254        let k = self.n_states;
255        let n_nodes = self.height * self.width;
256        let mut marginals = vec![0.0f64; n_nodes * k];
257        let mut log_b = vec![0.0f64; k];
258        for i in 0..n_nodes {
259            for l in 0..k {
260                log_b[l] = unary[i * k + l];
261            }
262            for &(in_slot, _e) in &self.incident[i] {
263                for l in 0..k {
264                    log_b[l] += log_msg[in_slot * k + l];
265                }
266            }
267            let m = log_b.iter().copied().fold(f64::NEG_INFINITY, f64::max);
268            let mut s = 0.0;
269            for l in 0..k {
270                let val = (log_b[l] - m).exp();
271                marginals[i * k + l] = val;
272                s += val;
273            }
274            for l in 0..k {
275                marginals[i * k + l] = if s > 0.0 {
276                    marginals[i * k + l] / s
277                } else {
278                    1.0 / k as f64
279                };
280            }
281        }
282        marginals
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Exact chain marginals by brute-force enumeration of all `k^n` labellings.
    ///
    /// The chain edges are `(t, t+1)`, and the pairwise table is indexed
    /// `pairwise[x_t * k + x_{t+1}]`.
    fn brute_force_chain_marginals(
        unary: &[f64],
        pairwise: &[f64],
        n: usize,
        k: usize,
    ) -> Vec<f64> {
        let mut marg = vec![0.0f64; n * k];
        let mut z = 0.0f64;
        let total = k.pow(n as u32);
        let mut labels = vec![0usize; n];
        for code in 0..total {
            let mut x = code;
            for t in 0..n {
                labels[t] = x % k;
                x /= k;
            }
            let mut logp = 0.0f64;
            for t in 0..n {
                logp += unary[t * k + labels[t]];
            }
            for t in 0..n - 1 {
                logp += pairwise[labels[t] * k + labels[t + 1]];
            }
            let p = logp.exp();
            z += p;
            for t in 0..n {
                marg[t * k + labels[t]] += p;
            }
        }
        for v in marg.iter_mut() {
            *v /= z;
        }
        marg
    }

    /// Run undamped BP on an `h × w` grid that is a 4-node chain (`1×4` or
    /// `4×1`) and compare the result with brute-force enumeration.
    fn assert_chain_matches_exact(h: usize, w: usize) {
        let n = h * w;
        assert_eq!(n, 4, "helper is written for a 4-node chain");
        let k = 2;
        let unary = vec![
            0.3, -0.1, // node 0
            -0.4, 0.2, // node 1
            0.5, 0.0, // node 2
            -0.2, 0.6, // node 3
        ];
        // Asymmetric on purpose to exercise oriented pairwise handling.
        let pairwise = vec![0.7, -0.2, -0.3, 0.5];
        let bp = LoopyBp::new(
            h,
            w,
            k,
            LoopyBpConfig {
                max_iter: 500,
                tol: 1e-12,
                damping: 1.0,
            },
        )
        .expect("new");
        let got = bp.infer(&unary, &pairwise).expect("infer");
        let exact = brute_force_chain_marginals(&unary, &pairwise, n, k);
        for idx in 0..n * k {
            assert!(
                (got[idx] - exact[idx]).abs() < 1e-6,
                "{h}x{w} idx {idx}: bp {} vs exact {}",
                got[idx],
                exact[idx]
            );
        }
    }

    #[test]
    fn chain_matches_exact_marginals() {
        assert_chain_matches_exact(1, 4);
    }

    #[test]
    fn vertical_chain_matches_exact_marginals() {
        // Same chain laid out top-to-bottom: exercises the vertical edges.
        assert_chain_matches_exact(4, 1);
    }

    #[test]
    fn single_node_grid_is_softmax_of_unary() {
        // No edges: the marginal is just softmax(unary).
        let bp = LoopyBp::new(1, 1, 2, LoopyBpConfig::default()).expect("new");
        let unary = [0.0, 3.0f64.ln()];
        let res = bp.infer_detailed(&unary, &[0.0; 4]).expect("infer");
        assert!(res.converged);
        assert_eq!(res.iterations, 1);
        assert!((res.marginals[0] - 0.25).abs() < 1e-12, "{}", res.marginals[0]);
        assert!((res.marginals[1] - 0.75).abs() < 1e-12, "{}", res.marginals[1]);
    }

    #[test]
    fn uniform_potentials_give_uniform_marginals() {
        let (h, w, k) = (2, 3, 3);
        let bp = LoopyBp::new(h, w, k, LoopyBpConfig::default()).expect("new");
        let unary = vec![0.0f64; h * w * k];
        let pairwise = vec![0.0f64; k * k];
        let marg = bp.infer(&unary, &pairwise).expect("infer");
        for &m in &marg {
            assert!((m - 1.0 / k as f64).abs() < 1e-9, "got {m}");
        }
    }

    #[test]
    fn strong_unary_propagates_through_attractive_pairwise() {
        let (h, w, k) = (3, 3, 2);
        let bp = LoopyBp::new(h, w, k, LoopyBpConfig::default()).expect("new");
        let mut unary = vec![0.0f64; h * w * k];
        let center = w + 1; // node (1,1)
        unary[center * k] = 4.0; // strongly favours label 0
        // Attractive (Potts) pairwise: same label rewarded.
        let beta = 0.8;
        let pairwise = vec![beta, 0.0, 0.0, beta];
        let marg = bp.infer(&unary, &pairwise).expect("infer");
        // Centre is pinned to label 0.
        assert!(marg[center * k] > 0.9, "centre p0 = {}", marg[center * k]);
        // A direct neighbour is pulled toward label 0 (above the 0.5 prior).
        let nbr = center - w; // (0,1), directly above the centre
        assert_ne!(nbr, center);
        assert!(marg[nbr * k] > 0.5, "neighbour p0 = {}", marg[nbr * k]);
        // The neighbour is more affected than a far corner.
        let corner = 2 * w; // (2,0)
        assert!(
            marg[nbr * k] >= marg[corner * k] - 1e-9,
            "neighbour {} vs corner {}",
            marg[nbr * k],
            marg[corner * k]
        );
    }

    #[test]
    fn marginals_normalised_and_bounded() {
        let (h, w, k) = (2, 2, 3);
        let bp = LoopyBp::new(h, w, k, LoopyBpConfig::default()).expect("new");
        let unary = vec![
            0.2, -0.3, 0.1, //
            -0.5, 0.4, 0.0, //
            0.3, 0.3, -0.2, //
            0.0, -0.1, 0.5, //
        ];
        let pairwise = vec![0.5, 0.1, 0.0, 0.1, 0.5, 0.1, 0.0, 0.1, 0.5];
        let marg = bp.infer(&unary, &pairwise).expect("infer");
        for i in 0..h * w {
            let mut s = 0.0;
            for l in 0..k {
                let v = marg[i * k + l];
                assert!((0.0..=1.0).contains(&v), "marginal out of range: {v}");
                s += v;
            }
            assert!((s - 1.0).abs() < 1e-9, "node {i} sum {s}");
        }
    }

    #[test]
    fn converges_on_small_grid() {
        let (h, w, k) = (3, 3, 2);
        let bp = LoopyBp::new(h, w, k, LoopyBpConfig::default()).expect("new");
        let mut unary = vec![0.0f64; h * w * k];
        for i in 0..h * w {
            unary[i * k] = 0.1 * (i as f64).cos();
            unary[i * k + 1] = -0.1 * (i as f64).sin();
        }
        let pairwise = vec![0.3, 0.0, 0.0, 0.3]; // weak attractive
        let res = bp.infer_detailed(&unary, &pairwise).expect("infer");
        assert!(
            res.converged,
            "did not converge in {} sweeps",
            res.iterations
        );
        for i in 0..h * w {
            let s: f64 = res.marginals[i * k..(i + 1) * k].iter().sum();
            assert!((s - 1.0).abs() < 1e-6, "node {i} sum {s}");
        }
    }

    #[test]
    fn invalid_dims_and_params_error() {
        assert!(LoopyBp::new(0, 3, 2, LoopyBpConfig::default()).is_err());
        assert!(LoopyBp::new(3, 0, 2, LoopyBpConfig::default()).is_err());
        assert!(LoopyBp::new(3, 3, 0, LoopyBpConfig::default()).is_err());
        assert!(LoopyBp::new(
            2,
            2,
            2,
            LoopyBpConfig {
                max_iter: 0,
                ..LoopyBpConfig::default()
            }
        )
        .is_err());
        assert!(
            LoopyBp::new(
                2,
                2,
                2,
                LoopyBpConfig {
                    damping: 1.5,
                    ..LoopyBpConfig::default()
                }
            )
            .is_err()
        );
        assert!(LoopyBp::new(
            2,
            2,
            2,
            LoopyBpConfig {
                damping: 0.0,
                ..LoopyBpConfig::default()
            }
        )
        .is_err());
        // The upper boundary of (0, 1] is accepted.
        assert!(LoopyBp::new(
            2,
            2,
            2,
            LoopyBpConfig {
                damping: 1.0,
                ..LoopyBpConfig::default()
            }
        )
        .is_ok());
        let bp = LoopyBp::new(2, 2, 2, LoopyBpConfig::default()).expect("new");
        // Wrong unary length (needs 2*2*2 = 8).
        assert!(bp.infer(&[0.0; 3], &[0.0; 4]).is_err());
        // Wrong pairwise length (needs 2*2 = 4).
        assert!(bp.infer(&[0.0; 8], &[0.0; 3]).is_err());
    }
}