1use crate::error::{SeqError, SeqResult};
21use crate::hmm::forward_backward::logsumexp;
22
/// Tuning knobs for [`LoopyBp`] inference.
///
/// Beware the direction of `damping`: it is the weight of the *freshly computed*
/// message, not of the previous one. Each sweep stores
/// `(1 - damping) * old_log_msg + damping * fresh_log_msg`, so `damping = 1.0`
/// means "no damping at all" and smaller values move messages more slowly.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LoopyBpConfig {
    /// Maximum number of message-passing sweeps; must be > 0.
    pub max_iter: usize,
    /// Convergence threshold: inference stops once the largest absolute change of
    /// any log-message entry between two consecutive sweeps is below this value.
    pub tol: f64,
    /// Weight of the new message in the damped update; must lie in `(0, 1]`.
    pub damping: f64,
}

impl Default for LoopyBpConfig {
    /// Up to 200 sweeps, `tol = 1e-9`, and an even 50/50 blend of old and new
    /// messages (`damping = 0.5`).
    fn default() -> Self {
        Self {
            max_iter: 200,
            tol: 1e-9,
            damping: 0.5,
        }
    }
}
44
/// Outcome of [`LoopyBp::infer_detailed`]: the marginals plus convergence info.
#[derive(Debug, Clone, PartialEq)]
pub struct LoopyBpResult {
    /// Node marginals in row-major node order, entry `node * n_states + label`.
    /// Each node's `n_states` entries form a probability distribution (sum to 1).
    pub marginals: Vec<f64>,
    /// Number of sweeps actually performed (between 1 and `max_iter`).
    pub iterations: usize,
    /// `true` if the convergence tolerance was met before `max_iter` ran out.
    pub converged: bool,
}
55
/// Loopy belief propagation (sum-product, computed in the log domain) on a
/// `height` × `width` 4-connected grid of discrete variables with `n_states`
/// labels each.
///
/// Nodes are numbered row-major: `node = row * width + col`.
#[derive(Debug, Clone)]
pub struct LoopyBp {
    height: usize,
    width: usize,
    n_states: usize,
    config: LoopyBpConfig,
    /// Undirected grid edges `(u, v)` where `u` is the left/upper endpoint.
    /// Edge `e` owns two directed message slots: `2 * e` carries u -> v and
    /// `2 * e + 1` carries v -> u.
    edges: Vec<(usize, usize)>,
    /// For each node, the `(slot, edge index)` pairs of the messages flowing
    /// *into* that node.
    incident: Vec<Vec<(usize, usize)>>,
}
71
72impl LoopyBp {
73 pub fn new(
81 height: usize,
82 width: usize,
83 n_states: usize,
84 config: LoopyBpConfig,
85 ) -> SeqResult<Self> {
86 if height == 0 || width == 0 || n_states == 0 {
87 return Err(SeqError::InvalidConfiguration(
88 "height, width and n_states must all be > 0".to_string(),
89 ));
90 }
91 if config.max_iter == 0 {
92 return Err(SeqError::InvalidConfiguration(
93 "max_iter must be > 0".to_string(),
94 ));
95 }
96 if config.damping <= 0.0 || config.damping > 1.0 {
97 return Err(SeqError::InvalidParameter {
98 name: "damping".to_string(),
99 value: config.damping,
100 });
101 }
102
103 let mut edges = Vec::new();
105 for r in 0..height {
106 for c in 0..width {
107 let node = r * width + c;
108 if c + 1 < width {
109 edges.push((node, node + 1)); }
111 if r + 1 < height {
112 edges.push((node, node + width)); }
114 }
115 }
116
117 let n_nodes = height * width;
120 let mut incident: Vec<Vec<(usize, usize)>> = vec![Vec::new(); n_nodes];
121 for (e, &(u, v)) in edges.iter().enumerate() {
122 incident[u].push((2 * e + 1, e)); incident[v].push((2 * e, e)); }
125
126 Ok(Self {
127 height,
128 width,
129 n_states,
130 config,
131 edges,
132 incident,
133 })
134 }
135
136 pub fn height(&self) -> usize {
138 self.height
139 }
140
141 pub fn width(&self) -> usize {
143 self.width
144 }
145
146 pub fn n_states(&self) -> usize {
148 self.n_states
149 }
150
151 pub fn infer(&self, unary: &[f64], pairwise: &[f64]) -> SeqResult<Vec<f64>> {
157 Ok(self.infer_detailed(unary, pairwise)?.marginals)
158 }
159
160 pub fn infer_detailed(&self, unary: &[f64], pairwise: &[f64]) -> SeqResult<LoopyBpResult> {
164 let k = self.n_states;
165 let n_nodes = self.height * self.width;
166 if unary.len() != n_nodes * k {
167 return Err(SeqError::ShapeMismatch {
168 expected: n_nodes * k,
169 got: unary.len(),
170 });
171 }
172 if pairwise.len() != k * k {
173 return Err(SeqError::ShapeMismatch {
174 expected: k * k,
175 got: pairwise.len(),
176 });
177 }
178
179 let damp = self.config.damping;
180 let n_slots = self.edges.len() * 2;
181 let mut log_msg = vec![0.0f64; n_slots * k];
182 let mut new_log_msg = log_msg.clone();
183 let mut terms = vec![0.0f64; k];
184 let mut out = vec![0.0f64; k];
185 let mut converged = false;
186 let mut iterations = 0usize;
187
188 for it in 0..self.config.max_iter {
189 iterations = it + 1;
190 for (e, &(u, v)) in self.edges.iter().enumerate() {
191 for &(src, dst, out_slot) in &[(u, v, 2 * e), (v, u, 2 * e + 1)] {
193 let _ = dst;
194 for l_dst in 0..k {
195 for l_src in 0..k {
196 let psi = if src == u {
198 pairwise[l_src * k + l_dst]
199 } else {
200 pairwise[l_dst * k + l_src]
201 };
202 let mut acc = unary[src * k + l_src] + psi;
203 for &(in_slot, in_edge) in &self.incident[src] {
206 if in_edge == e {
207 continue;
208 }
209 acc += log_msg[in_slot * k + l_src];
210 }
211 terms[l_src] = acc;
212 }
213 out[l_dst] = logsumexp(&terms);
214 }
215 let m = out.iter().copied().fold(f64::NEG_INFINITY, f64::max);
217 if m > f64::NEG_INFINITY {
218 for val in out.iter_mut() {
219 *val -= m;
220 }
221 }
222 for l in 0..k {
223 let base = out_slot * k + l;
224 new_log_msg[base] = (1.0 - damp) * log_msg[base] + damp * out[l];
225 }
226 }
227 }
228
229 let mut max_diff = 0.0f64;
230 for idx in 0..log_msg.len() {
231 let d = (new_log_msg[idx] - log_msg[idx]).abs();
232 if d > max_diff {
233 max_diff = d;
234 }
235 }
236 log_msg.copy_from_slice(&new_log_msg);
237 if max_diff < self.config.tol {
238 converged = true;
239 break;
240 }
241 }
242
243 let marginals = self.node_marginals(unary, &log_msg);
244 Ok(LoopyBpResult {
245 marginals,
246 iterations,
247 converged,
248 })
249 }
250
251 fn node_marginals(&self, unary: &[f64], log_msg: &[f64]) -> Vec<f64> {
254 let k = self.n_states;
255 let n_nodes = self.height * self.width;
256 let mut marginals = vec![0.0f64; n_nodes * k];
257 let mut log_b = vec![0.0f64; k];
258 for i in 0..n_nodes {
259 for l in 0..k {
260 log_b[l] = unary[i * k + l];
261 }
262 for &(in_slot, _e) in &self.incident[i] {
263 for l in 0..k {
264 log_b[l] += log_msg[in_slot * k + l];
265 }
266 }
267 let m = log_b.iter().copied().fold(f64::NEG_INFINITY, f64::max);
268 let mut s = 0.0;
269 for l in 0..k {
270 let val = (log_b[l] - m).exp();
271 marginals[i * k + l] = val;
272 s += val;
273 }
274 for l in 0..k {
275 marginals[i * k + l] = if s > 0.0 {
276 marginals[i * k + l] / s
277 } else {
278 1.0 / k as f64
279 };
280 }
281 }
282 marginals
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Exact chain marginals by enumerating all `k^n` labelings (tiny `n` only).
    ///
    /// `pairwise[a * k + b]` scores node `t` at label `a` next to node `t + 1` at
    /// label `b`, the same left/right orientation `LoopyBp` uses on a 1 × n grid.
    fn brute_force_chain_marginals(
        unary: &[f64],
        pairwise: &[f64],
        n: usize,
        k: usize,
    ) -> Vec<f64> {
        let mut marg = vec![0.0f64; n * k];
        let mut z = 0.0f64;
        let total = k.pow(n as u32);
        let mut labels = vec![0usize; n];
        for code in 0..total {
            // Decode `code` as an n-digit base-k number: one label per node.
            let mut x = code;
            for label in labels.iter_mut() {
                *label = x % k;
                x /= k;
            }
            let mut logp: f64 = labels
                .iter()
                .enumerate()
                .map(|(t, &l)| unary[t * k + l])
                .sum();
            logp += labels
                .windows(2)
                .map(|pair| pairwise[pair[0] * k + pair[1]])
                .sum::<f64>();
            let p = logp.exp();
            z += p;
            for (t, &l) in labels.iter().enumerate() {
                marg[t * k + l] += p;
            }
        }
        for v in marg.iter_mut() {
            *v /= z;
        }
        marg
    }

    #[test]
    fn chain_matches_exact_marginals() {
        // A 1 × n grid is a tree, so undamped BP must reproduce exact marginals.
        let n = 4;
        let k = 2;
        let unary = vec![
            0.3, -0.1, // node 0
            -0.4, 0.2, // node 1
            0.5, 0.0, // node 2
            -0.2, 0.6, // node 3
        ];
        let pairwise = vec![0.7, -0.2, -0.3, 0.5];
        let bp = LoopyBp::new(
            1,
            n,
            k,
            LoopyBpConfig {
                max_iter: 500,
                tol: 1e-12,
                damping: 1.0,
            },
        )
        .expect("new");
        let got = bp.infer(&unary, &pairwise).expect("infer");
        let exact = brute_force_chain_marginals(&unary, &pairwise, n, k);
        for idx in 0..n * k {
            assert!(
                (got[idx] - exact[idx]).abs() < 1e-6,
                "idx {idx}: bp {} vs exact {}",
                got[idx],
                exact[idx]
            );
        }
    }

    #[test]
    fn uniform_potentials_give_uniform_marginals() {
        let (h, w, k) = (2, 3, 3);
        let bp = LoopyBp::new(h, w, k, LoopyBpConfig::default()).expect("new");
        let unary = vec![0.0f64; h * w * k];
        let pairwise = vec![0.0f64; k * k];
        let marg = bp.infer(&unary, &pairwise).expect("infer");
        for &m in &marg {
            assert!((m - 1.0 / k as f64).abs() < 1e-9, "got {m}");
        }
    }

    #[test]
    fn strong_unary_propagates_through_attractive_pairwise() {
        let (h, w, k) = (3, 3, 2);
        let bp = LoopyBp::new(h, w, k, LoopyBpConfig::default()).expect("new");
        let mut unary = vec![0.0f64; h * w * k];
        // Centre node (row 1, col 1) strongly prefers label 0.
        let center = w + 1;
        unary[center * k] = 4.0;
        // Attractive coupling: equal neighbouring labels are rewarded.
        let beta = 0.8;
        let pairwise = vec![beta, 0.0, 0.0, beta];
        let marg = bp.infer(&unary, &pairwise).expect("infer");
        assert!(marg[center * k] > 0.9, "centre p0 = {}", marg[center * k]);
        // Right-hand neighbour of the centre (row 1, col 2). This must not be
        // `center` itself, or the two checks below would be vacuous.
        let nbr = center + 1;
        assert!(marg[nbr * k] > 0.5, "neighbour p0 = {}", marg[nbr * k]);
        // Bottom-left corner (row 2, col 0): only diagonally related to the centre,
        // so it should be pulled towards label 0 no more than a direct neighbour.
        let corner = 2 * w;
        assert!(
            marg[nbr * k] >= marg[corner * k] - 1e-9,
            "neighbour {} vs corner {}",
            marg[nbr * k],
            marg[corner * k]
        );
    }

    #[test]
    fn marginals_normalised_and_bounded() {
        let (h, w, k) = (2, 2, 3);
        let bp = LoopyBp::new(h, w, k, LoopyBpConfig::default()).expect("new");
        let unary = vec![
            0.2, -0.3, 0.1, // node 0
            -0.5, 0.4, 0.0, // node 1
            0.3, 0.3, -0.2, // node 2
            0.0, -0.1, 0.5, // node 3
        ];
        let pairwise = vec![0.5, 0.1, 0.0, 0.1, 0.5, 0.1, 0.0, 0.1, 0.5];
        let marg = bp.infer(&unary, &pairwise).expect("infer");
        for i in 0..h * w {
            let mut s = 0.0;
            for l in 0..k {
                let v = marg[i * k + l];
                assert!((0.0..=1.0).contains(&v), "marginal out of range: {v}");
                s += v;
            }
            assert!((s - 1.0).abs() < 1e-9, "node {i} sum {s}");
        }
    }

    #[test]
    fn converges_on_small_grid() {
        let (h, w, k) = (3, 3, 2);
        let bp = LoopyBp::new(h, w, k, LoopyBpConfig::default()).expect("new");
        let mut unary = vec![0.0f64; h * w * k];
        for i in 0..h * w {
            unary[i * k] = 0.1 * (i as f64).cos();
            unary[i * k + 1] = -0.1 * (i as f64).sin();
        }
        // Weak coupling keeps loopy BP well inside its convergent regime.
        let pairwise = vec![0.3, 0.0, 0.0, 0.3];
        let res = bp.infer_detailed(&unary, &pairwise).expect("infer");
        assert!(
            res.converged,
            "did not converge in {} sweeps",
            res.iterations
        );
        for i in 0..h * w {
            let s: f64 = res.marginals[i * k..(i + 1) * k].iter().sum();
            assert!((s - 1.0).abs() < 1e-6, "node {i} sum {s}");
        }
    }

    #[test]
    fn invalid_dims_and_params_error() {
        assert!(LoopyBp::new(0, 3, 2, LoopyBpConfig::default()).is_err());
        assert!(LoopyBp::new(3, 3, 0, LoopyBpConfig::default()).is_err());
        assert!(
            LoopyBp::new(
                2,
                2,
                2,
                LoopyBpConfig {
                    damping: 1.5,
                    ..LoopyBpConfig::default()
                }
            )
            .is_err()
        );
        let bp = LoopyBp::new(2, 2, 2, LoopyBpConfig::default()).expect("new");
        // `unary` too short for a 2 × 2 grid with 2 states.
        assert!(bp.infer(&[0.0; 3], &[0.0; 4]).is_err());
        // `pairwise` is not 2 × 2.
        assert!(bp.infer(&[0.0; 8], &[0.0; 3]).is_err());
    }
}