Skip to main content

oxicuda_seq/grid_crf/
mean_field.rs

1//! Mean-field variational inference for grid CRFs.
2
3use super::grid_crf::GridCrf;
4use crate::error::{SeqError, SeqResult};
5
/// Tuning knobs for [`mean_field_inference`].
///
/// Every marginal is refreshed with the damped rule
/// `q_new = (1−damp) q_old + damp q_update`; for `damping` in `[0, 1]` this is a
/// convex combination of the previous and the freshly computed marginals.
#[derive(Debug, Clone, Copy)]
pub struct MeanFieldConfig {
    /// Upper bound on the number of mean-field sweeps.
    pub max_iter: usize,
    /// Early-stopping threshold: iteration ends once the largest absolute
    /// change of any entry of `q` during a sweep drops below this value.
    pub tol: f64,
    /// Weight given to the freshly computed marginals in the damped update
    /// `q_new = (1−damp) q_old + damp q_update`.
    pub damping: f64,
}

impl Default for MeanFieldConfig {
    /// 30 sweeps, tolerance `1e-4`, damping `0.5`.
    fn default() -> Self {
        let max_iter = 30;
        let tol = 1e-4;
        let damping = 0.5;
        MeanFieldConfig {
            max_iter,
            tol,
            damping,
        }
    }
}
26
27/// Run mean-field variational inference; returns the posterior `q` of shape
28/// `(n_rows × n_cols × n_labels)`.
29pub fn mean_field_inference(crf: &GridCrf, cfg: &MeanFieldConfig) -> SeqResult<Vec<f64>> {
30    if cfg.max_iter == 0 {
31        return Err(SeqError::InvalidConfiguration(
32            "max_iter must be > 0".to_string(),
33        ));
34    }
35    let r_max = crf.n_rows;
36    let c_max = crf.n_cols;
37    let l_max = crf.n_labels;
38    let n_sites = r_max * c_max;
39
40    // Initialise q with softmin of unary energies.
41    let mut q = vec![0.0f64; n_sites * l_max];
42    for r in 0..r_max {
43        for c in 0..c_max {
44            let base = (r * c_max + c) * l_max;
45            let mut neg = vec![0.0f64; l_max];
46            for l in 0..l_max {
47                neg[l] = -crf.unary[base + l];
48            }
49            let m = neg.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
50            let mut s = 0.0;
51            for l in 0..l_max {
52                neg[l] = (neg[l] - m).exp();
53                s += neg[l];
54            }
55            for l in 0..l_max {
56                q[base + l] = if s > 0.0 {
57                    neg[l] / s
58                } else {
59                    1.0 / l_max as f64
60                };
61            }
62        }
63    }
64
65    let mut q_new = q.clone();
66    for _it in 0..cfg.max_iter {
67        let mut max_change = 0.0_f64;
68        for r in 0..r_max {
69            for c in 0..c_max {
70                let base = (r * c_max + c) * l_max;
71                // For each label l at site (r,c), compute log_q_new = -φ - Σ_nbrs ψ-expectation
72                let mut log_q = vec![0.0f64; l_max];
73                for l in 0..l_max {
74                    let mut acc = -crf.unary[base + l];
75                    if r > 0 {
76                        let nbr = ((r - 1) * c_max + c) * l_max;
77                        for lp in 0..l_max {
78                            acc -= q[nbr + lp] * crf.pairwise[l * l_max + lp];
79                        }
80                    }
81                    if r + 1 < r_max {
82                        let nbr = ((r + 1) * c_max + c) * l_max;
83                        for lp in 0..l_max {
84                            acc -= q[nbr + lp] * crf.pairwise[l * l_max + lp];
85                        }
86                    }
87                    if c > 0 {
88                        let nbr = (r * c_max + (c - 1)) * l_max;
89                        for lp in 0..l_max {
90                            acc -= q[nbr + lp] * crf.pairwise[l * l_max + lp];
91                        }
92                    }
93                    if c + 1 < c_max {
94                        let nbr = (r * c_max + (c + 1)) * l_max;
95                        for lp in 0..l_max {
96                            acc -= q[nbr + lp] * crf.pairwise[l * l_max + lp];
97                        }
98                    }
99                    log_q[l] = acc;
100                }
101                let m = log_q.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
102                let mut sum = 0.0;
103                for l in 0..l_max {
104                    log_q[l] = (log_q[l] - m).exp();
105                    sum += log_q[l];
106                }
107                for l in 0..l_max {
108                    let val = if sum > 0.0 {
109                        log_q[l] / sum
110                    } else {
111                        1.0 / l_max as f64
112                    };
113                    let damped = (1.0 - cfg.damping) * q[base + l] + cfg.damping * val;
114                    let change = (damped - q[base + l]).abs();
115                    if change > max_change {
116                        max_change = change;
117                    }
118                    q_new[base + l] = damped;
119                }
120            }
121        }
122        q.copy_from_slice(&q_new);
123        if max_change < cfg.tol {
124            break;
125        }
126    }
127    Ok(q)
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    const N_ROWS: usize = 3;
    const N_COLS: usize = 3;
    const N_LABELS: usize = 2;

    /// 3×3 grid with two labels.
    ///
    /// - Unary energy is 0.5 everywhere except two corners: (0,0) prefers
    ///   label 0 and (2,2) prefers label 1.
    /// - Pairwise is a smoothing term: 0 for equal labels, 0.5 for different ones.
    fn corner_crf() -> GridCrf {
        let mut unary = vec![0.5f64; N_ROWS * N_COLS * N_LABELS];
        unary[0] = 0.0; // label 0 at (0,0)
        unary[1] = 1.0; // label 1 at (0,0)
        let idx22 = (2 * N_COLS + 2) * N_LABELS;
        unary[idx22] = 1.0;
        unary[idx22 + 1] = 0.0;
        let pairwise = vec![0.0, 0.5, 0.5, 0.0];
        GridCrf::new(N_ROWS, N_COLS, N_LABELS, unary, pairwise).expect("ok")
    }

    #[test]
    fn mean_field_runs_two_labels() {
        let g = corner_crf();
        let q = mean_field_inference(&g, &MeanFieldConfig::default()).expect("ok");
        assert_eq!(q.len(), N_ROWS * N_COLS * N_LABELS);
        // Every site's marginal is a valid probability vector.
        for site in q.chunks_exact(N_LABELS) {
            let s: f64 = site.iter().sum();
            assert!((s - 1.0).abs() < 1e-6, "row sum {s}");
            assert!(
                site.iter().all(|p| (0.0..=1.0).contains(p)),
                "q out of range: {site:?}"
            );
        }
        // The corners keep the label their unary energy favours.
        let idx22 = (2 * N_COLS + 2) * N_LABELS;
        assert!(q[0] > 0.5, "q(0,0)[0] = {}", q[0]);
        assert!(q[idx22 + 1] > 0.5, "q(2,2)[1] = {}", q[idx22 + 1]);
    }

    #[test]
    fn mean_field_rejects_zero_iterations() {
        let g = corner_crf();
        let cfg = MeanFieldConfig {
            max_iter: 0,
            ..MeanFieldConfig::default()
        };
        assert!(mean_field_inference(&g, &cfg).is_err());
    }
}