1use super::mrf::Mrf;
4use crate::error::{SeqError, SeqResult};
5
/// Tuning parameters shared by the loopy belief-propagation solvers in this file.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BpConfig {
    /// Upper bound on the number of message-passing sweeps. The solvers in this file
    /// reject `0` with `SeqError::InvalidConfiguration`.
    pub max_iter: usize,
    /// Convergence threshold: message passing is considered converged once the largest
    /// absolute change of any log-message entry during a sweep is below this value.
    pub tol: f64,
    /// Blend factor applied to every message update:
    /// `updated = (1 - damping) * previous + damping * freshly_computed`.
    /// Note the direction: `1.0` takes the fresh message unchanged, while `0.0` keeps the
    /// previous message and therefore never moves away from the initial all-zero messages.
    pub damping: f64,
}
13
14impl Default for BpConfig {
15 fn default() -> Self {
16 Self {
17 max_iter: 50,
18 tol: 1e-5,
19 damping: 0.5,
20 }
21 }
22}
23
/// Outcome of a sum-product belief-propagation run (see `loopy_bp_marginals`).
#[derive(Debug, Clone, PartialEq)]
pub struct BpResult {
    /// Flattened per-node label distributions: entry `i * n_labels + l` is the estimated
    /// probability of label `l` at node `i`.
    pub marginals: Vec<f64>,
    /// Number of message-passing sweeps that were actually executed.
    pub iterations: usize,
    /// `true` when the largest message change fell below the configured tolerance before
    /// the sweep limit was reached.
    pub converged: bool,
}
31
32pub fn loopy_bp_marginals(mrf: &Mrf, cfg: &BpConfig) -> SeqResult<BpResult> {
37 if cfg.max_iter == 0 {
38 return Err(SeqError::InvalidConfiguration(
39 "max_iter must be > 0".to_string(),
40 ));
41 }
42 let nl = mrf.n_labels;
43 let l2 = nl * nl;
44 let n_messages = mrf.edges.len() * 2;
46 let mut log_msg = vec![0.0; n_messages * nl];
47 let mut new_log_msg = log_msg.clone();
48 let mut converged = false;
49 let mut iters = 0;
50
51 for it in 0..cfg.max_iter {
52 iters = it + 1;
53 for (e_idx, &(u, v)) in mrf.edges.iter().enumerate() {
54 for &(src, _dst, msg_idx, _opp_idx) in &[
55 (u, v, e_idx * 2, e_idx * 2 + 1),
56 (v, u, e_idx * 2 + 1, e_idx * 2),
57 ] {
58 let mut out = vec![f64::NEG_INFINITY; nl];
62 for l_dst in 0..nl {
63 let mut terms = vec![0.0; nl];
64 for l_src in 0..nl {
65 let mut acc = -mrf.unary[src * nl + l_src];
66 let psi = if src == u {
68 mrf.pairwise[e_idx * l2 + l_src * nl + l_dst]
69 } else {
70 mrf.pairwise[e_idx * l2 + l_dst * nl + l_src]
71 };
72 acc -= psi;
73 for (k_idx, &(uu, vv)) in mrf.edges.iter().enumerate() {
75 if k_idx == e_idx {
76 continue;
77 }
78 let in_msg = if uu == src {
79 &log_msg[(k_idx * 2 + 1) * nl..]
80 } else if vv == src {
81 &log_msg[(k_idx * 2) * nl..]
82 } else {
83 continue;
84 };
85 acc += in_msg[l_src];
86 }
87 terms[l_src] = acc;
88 }
89 out[l_dst] = logsumexp_in(&terms);
90 }
91 let m = out.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
93 for v in out.iter_mut() {
94 *v -= m;
95 }
96 for l in 0..nl {
98 new_log_msg[msg_idx * nl + l] =
99 (1.0 - cfg.damping) * log_msg[msg_idx * nl + l] + cfg.damping * out[l];
100 }
101 }
102 }
103 let mut max_diff = 0.0_f64;
104 for k in 0..log_msg.len() {
105 let d = (new_log_msg[k] - log_msg[k]).abs();
106 if d > max_diff {
107 max_diff = d;
108 }
109 }
110 log_msg.copy_from_slice(&new_log_msg);
111 if max_diff < cfg.tol {
112 converged = true;
113 break;
114 }
115 }
116
117 let mut marginals = vec![0.0; mrf.n_nodes * nl];
119 for i in 0..mrf.n_nodes {
120 let mut log_b = vec![0.0; nl];
121 for l in 0..nl {
122 log_b[l] = -mrf.unary[i * nl + l];
123 }
124 for (e_idx, &(u, v)) in mrf.edges.iter().enumerate() {
125 if u == i {
126 for l in 0..nl {
127 log_b[l] += log_msg[(e_idx * 2 + 1) * nl + l];
128 }
129 }
130 if v == i {
131 for l in 0..nl {
132 log_b[l] += log_msg[(e_idx * 2) * nl + l];
133 }
134 }
135 }
136 let m = log_b.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
137 let mut s = 0.0;
138 let mut exps = vec![0.0; nl];
139 for l in 0..nl {
140 exps[l] = (log_b[l] - m).exp();
141 s += exps[l];
142 }
143 for l in 0..nl {
144 marginals[i * nl + l] = if s > 0.0 {
145 exps[l] / s
146 } else {
147 1.0 / nl as f64
148 };
149 }
150 }
151 Ok(BpResult {
152 marginals,
153 iterations: iters,
154 converged,
155 })
156}
157
158pub fn loopy_bp_map(mrf: &Mrf, cfg: &BpConfig) -> SeqResult<Vec<usize>> {
160 if cfg.max_iter == 0 {
161 return Err(SeqError::InvalidConfiguration(
162 "max_iter must be > 0".to_string(),
163 ));
164 }
165 let nl = mrf.n_labels;
166 let l2 = nl * nl;
167 let n_messages = mrf.edges.len() * 2;
168 let mut log_msg = vec![0.0; n_messages * nl];
169 let mut new_log_msg = log_msg.clone();
170
171 for _ in 0..cfg.max_iter {
172 for (e_idx, &(u, v)) in mrf.edges.iter().enumerate() {
173 for &(src, dst, msg_idx) in &[(u, v, e_idx * 2), (v, u, e_idx * 2 + 1)] {
174 let _ = dst;
175 let mut out = vec![f64::NEG_INFINITY; nl];
176 for l_dst in 0..nl {
177 let mut best = f64::NEG_INFINITY;
178 for l_src in 0..nl {
179 let mut acc = -mrf.unary[src * nl + l_src];
180 let psi = if src == u {
181 mrf.pairwise[e_idx * l2 + l_src * nl + l_dst]
182 } else {
183 mrf.pairwise[e_idx * l2 + l_dst * nl + l_src]
184 };
185 acc -= psi;
186 for (k_idx, &(uu, vv)) in mrf.edges.iter().enumerate() {
187 if k_idx == e_idx {
188 continue;
189 }
190 let in_msg = if uu == src {
191 &log_msg[(k_idx * 2 + 1) * nl..]
192 } else if vv == src {
193 &log_msg[(k_idx * 2) * nl..]
194 } else {
195 continue;
196 };
197 acc += in_msg[l_src];
198 }
199 if acc > best {
200 best = acc;
201 }
202 }
203 out[l_dst] = best;
204 }
205 let m = out.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
206 for v in out.iter_mut() {
207 *v -= m;
208 }
209 for l in 0..nl {
210 new_log_msg[msg_idx * nl + l] =
211 (1.0 - cfg.damping) * log_msg[msg_idx * nl + l] + cfg.damping * out[l];
212 }
213 }
214 }
215 log_msg.copy_from_slice(&new_log_msg);
216 }
217
218 let mut labels = vec![0usize; mrf.n_nodes];
220 for i in 0..mrf.n_nodes {
221 let mut best_l = 0usize;
222 let mut best_v = f64::NEG_INFINITY;
223 for l in 0..nl {
224 let mut acc = -mrf.unary[i * nl + l];
225 for (e_idx, &(u, v)) in mrf.edges.iter().enumerate() {
226 if u == i {
227 acc += log_msg[(e_idx * 2 + 1) * nl + l];
228 }
229 if v == i {
230 acc += log_msg[(e_idx * 2) * nl + l];
231 }
232 }
233 if acc > best_v {
234 best_v = acc;
235 best_l = l;
236 }
237 }
238 labels[i] = best_l;
239 }
240 Ok(labels)
241}
242
/// Numerically stable `ln(sum_i exp(xs[i]))`.
///
/// The largest entry is factored out before exponentiating so large inputs do not
/// overflow. An empty slice or a slice whose maximum is `-inf` yields `-inf`; a slice
/// containing `+inf` yields `+inf`.
fn logsumexp_in(xs: &[f64]) -> f64 {
    let peak = xs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    // With an infinite maximum the shifted exponent would be `inf - inf = NaN`; the result
    // is simply that maximum (`-inf`: nothing contributes, `+inf`: one term dominates).
    if peak.is_infinite() {
        return peak;
    }
    let shifted_sum: f64 = xs.iter().map(|&x| (x - peak).exp()).sum();
    peak + shifted_sum.ln()
}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Three-node chain `0 - 1 - 2` with two labels and the same pairwise table on both
    /// edges: cost 0 when the endpoint labels agree, 1 otherwise.
    fn chain(unary: Vec<f64>) -> Mrf {
        Mrf::new(
            3,
            2,
            vec![(0, 1), (1, 2)],
            unary,
            vec![0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0],
        )
        .expect("ok")
    }

    #[test]
    fn bp_marginals_normalise() {
        let m = chain(vec![0.0, 1.0, 1.0, 0.0, 0.0, 1.0]);
        let cfg = BpConfig::default();
        let res = loopy_bp_marginals(&m, &cfg).expect("ok");
        assert_eq!(res.marginals.len(), m.n_nodes * m.n_labels);
        assert!(res.iterations >= 1 && res.iterations <= cfg.max_iter);
        // A tree-structured model settles well within the default sweep budget.
        assert!(res.converged, "stopped after {} sweeps", res.iterations);
        for i in 0..m.n_nodes {
            let row = &res.marginals[i * m.n_labels..(i + 1) * m.n_labels];
            assert!(row.iter().all(|p| (0.0..=1.0).contains(p)), "row {row:?}");
            let s: f64 = row.iter().sum();
            assert!((s - 1.0).abs() < 1e-6, "row sum {s}");
        }
    }

    #[test]
    fn bp_map_runs() {
        let m = chain(vec![0.0, 5.0, 5.0, 0.0, 0.0, 5.0]);
        let labels = loopy_bp_map(&m, &BpConfig::default()).expect("ok");
        assert_eq!(labels.len(), 3);
        // Enumerating all eight labellings gives a unique minimum energy of 2 at [0, 1, 0];
        // max-product is exact on a tree.
        assert_eq!(labels, vec![0, 1, 0]);
    }

    #[test]
    fn zero_max_iter_is_rejected() {
        let m = chain(vec![0.0, 1.0, 1.0, 0.0, 0.0, 1.0]);
        let cfg = BpConfig {
            max_iter: 0,
            ..BpConfig::default()
        };
        assert!(loopy_bp_marginals(&m, &cfg).is_err());
        assert!(loopy_bp_map(&m, &cfg).is_err());
    }
}