oxicuda_seq/grid_crf/
mean_field.rs1use super::grid_crf::GridCrf;
4use crate::error::{SeqError, SeqResult};
5
/// Settings that control the iteration performed by [`mean_field_inference`].
#[derive(Clone, Copy, Debug)]
pub struct MeanFieldConfig {
    /// Upper bound on the number of update sweeps. A value of zero is
    /// rejected by [`mean_field_inference`] as an invalid configuration.
    pub max_iter: usize,
    /// Early-stopping threshold: sweeping ends once the largest absolute
    /// change of any marginal entry within one sweep is below this value.
    pub tol: f64,
    /// Weight given to the freshly computed marginal when it is blended with
    /// the previous one; the previous value keeps weight `1 - damping`.
    pub damping: f64,
}
16
17impl Default for MeanFieldConfig {
18 fn default() -> Self {
19 Self {
20 max_iter: 30,
21 tol: 1e-4,
22 damping: 0.5,
23 }
24 }
25}
26
27pub fn mean_field_inference(crf: &GridCrf, cfg: &MeanFieldConfig) -> SeqResult<Vec<f64>> {
30 if cfg.max_iter == 0 {
31 return Err(SeqError::InvalidConfiguration(
32 "max_iter must be > 0".to_string(),
33 ));
34 }
35 let r_max = crf.n_rows;
36 let c_max = crf.n_cols;
37 let l_max = crf.n_labels;
38 let n_sites = r_max * c_max;
39
40 let mut q = vec![0.0f64; n_sites * l_max];
42 for r in 0..r_max {
43 for c in 0..c_max {
44 let base = (r * c_max + c) * l_max;
45 let mut neg = vec![0.0f64; l_max];
46 for l in 0..l_max {
47 neg[l] = -crf.unary[base + l];
48 }
49 let m = neg.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
50 let mut s = 0.0;
51 for l in 0..l_max {
52 neg[l] = (neg[l] - m).exp();
53 s += neg[l];
54 }
55 for l in 0..l_max {
56 q[base + l] = if s > 0.0 {
57 neg[l] / s
58 } else {
59 1.0 / l_max as f64
60 };
61 }
62 }
63 }
64
65 let mut q_new = q.clone();
66 for _it in 0..cfg.max_iter {
67 let mut max_change = 0.0_f64;
68 for r in 0..r_max {
69 for c in 0..c_max {
70 let base = (r * c_max + c) * l_max;
71 let mut log_q = vec![0.0f64; l_max];
73 for l in 0..l_max {
74 let mut acc = -crf.unary[base + l];
75 if r > 0 {
76 let nbr = ((r - 1) * c_max + c) * l_max;
77 for lp in 0..l_max {
78 acc -= q[nbr + lp] * crf.pairwise[l * l_max + lp];
79 }
80 }
81 if r + 1 < r_max {
82 let nbr = ((r + 1) * c_max + c) * l_max;
83 for lp in 0..l_max {
84 acc -= q[nbr + lp] * crf.pairwise[l * l_max + lp];
85 }
86 }
87 if c > 0 {
88 let nbr = (r * c_max + (c - 1)) * l_max;
89 for lp in 0..l_max {
90 acc -= q[nbr + lp] * crf.pairwise[l * l_max + lp];
91 }
92 }
93 if c + 1 < c_max {
94 let nbr = (r * c_max + (c + 1)) * l_max;
95 for lp in 0..l_max {
96 acc -= q[nbr + lp] * crf.pairwise[l * l_max + lp];
97 }
98 }
99 log_q[l] = acc;
100 }
101 let m = log_q.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
102 let mut sum = 0.0;
103 for l in 0..l_max {
104 log_q[l] = (log_q[l] - m).exp();
105 sum += log_q[l];
106 }
107 for l in 0..l_max {
108 let val = if sum > 0.0 {
109 log_q[l] / sum
110 } else {
111 1.0 / l_max as f64
112 };
113 let damped = (1.0 - cfg.damping) * q[base + l] + cfg.damping * val;
114 let change = (damped - q[base + l]).abs();
115 if change > max_change {
116 max_change = change;
117 }
118 q_new[base + l] = damped;
119 }
120 }
121 }
122 q.copy_from_slice(&q_new);
123 if max_change < cfg.tol {
124 break;
125 }
126 }
127 Ok(q)
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    const N_ROWS: usize = 3;
    const N_COLS: usize = 3;
    const N_LABELS: usize = 2;
    /// Offset of the bottom-right site, (2, 2), in the site-major layout.
    const IDX22: usize = (2 * N_COLS + 2) * N_LABELS;

    /// Builds a 3x3, two-label grid whose unary costs are flat (0.5) except at
    /// two opposite corners: the top-left site is cheaper under label 0 and
    /// the bottom-right site is cheaper under label 1. The pairwise table
    /// charges 0.5 whenever two neighbouring labels differ.
    fn biased_corner_grid() -> GridCrf {
        let mut unary = vec![0.5f64; N_ROWS * N_COLS * N_LABELS];
        unary[0] = 0.0;
        unary[1] = 1.0;
        unary[IDX22] = 1.0;
        unary[IDX22 + 1] = 0.0;
        let pairwise = vec![0.0, 0.5, 0.5, 0.0];
        GridCrf::new(N_ROWS, N_COLS, N_LABELS, unary, pairwise).expect("ok")
    }

    #[test]
    fn mean_field_runs_two_labels() {
        let g = biased_corner_grid();
        let q = mean_field_inference(&g, &MeanFieldConfig::default()).expect("ok");

        // One marginal per site, each a proper distribution over the labels.
        assert_eq!(q.len(), N_ROWS * N_COLS * N_LABELS);
        for (site, marginal) in q.chunks_exact(N_LABELS).enumerate() {
            let s: f64 = marginal.iter().sum();
            assert!((s - 1.0).abs() < 1e-6, "site {site} marginal sums to {s}");
            assert!(
                marginal.iter().all(|p| (0.0..=1.0).contains(p)),
                "site {site} has an entry outside [0, 1]: {marginal:?}"
            );
        }

        // Each corner's unary gap (1.0) is at least as large as the strongest
        // possible pull from its two neighbours (2 * 0.5), so each biased
        // corner must keep preferring its cheaper label.
        assert!(q[0] > q[1], "top-left should prefer label 0");
        assert!(
            q[IDX22 + 1] > q[IDX22],
            "bottom-right should prefer label 1"
        );
    }

    #[test]
    fn zero_max_iter_is_rejected() {
        let g = biased_corner_grid();
        let cfg = MeanFieldConfig {
            max_iter: 0,
            ..MeanFieldConfig::default()
        };
        assert!(mean_field_inference(&g, &cfg).is_err());
    }
}