1use crate::error::{SeqError, SeqResult};
19
/// Numerically stable `ln(sum(exp(x)))` over `xs`.
///
/// Returns `-inf` for an empty slice or when every element is `-inf`, `+inf`
/// when any element is `+inf`, and NaN when any element is NaN. Shifting by a
/// non-finite maximum would otherwise compute `inf - inf = NaN` or silently
/// drop NaNs.
#[inline]
fn logsumexp(xs: &[f64]) -> f64 {
    // NaN-propagating max: once a NaN is seen it sticks, because `b > NaN` is false.
    let m = xs
        .iter()
        .copied()
        .fold(f64::NEG_INFINITY, |a, b| if b.is_nan() || b > a { b } else { a });
    if !m.is_finite() {
        return m;
    }
    let s: f64 = xs.iter().map(|&x| (x - m).exp()).sum();
    m + s.ln()
}
31
/// Size and belief-propagation settings for a `GeneralGraphCrf`.
#[derive(Clone, Debug)]
pub struct GraphCrfConfig {
    /// Number of nodes (variables) in the graph.
    pub n_nodes: usize,
    /// Number of labels every node can take.
    pub n_labels: usize,
    /// Maximum number of full message-passing sweeps over all edges.
    pub max_iter: usize,
    /// Convergence threshold: stop once no message entry changed by this much
    /// or more during a sweep.
    pub tol: f64,
    /// Share of the previous message kept on each update:
    /// `updated = (1 - damping) * computed + damping * previous`
    /// (`0.0` means no damping).
    pub damping: f64,
}
48
/// A pairwise connection between nodes `i` and `j`.
///
/// The orientation only fixes how the edge's potential table is indexed:
/// entry `(li, lj)` scores node `i` taking label `li` together with node `j`
/// taking label `lj`. `Hash` is derived so edges can key sets and maps
/// (e.g. to de-duplicate an edge list).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Edge {
    /// First endpoint; its label selects the row of the potential table.
    pub i: usize,
    /// Second endpoint; its label selects the column of the potential table.
    pub j: usize,
}
55
56#[derive(Debug, Clone)]
60pub struct GeneralGraphCrf {
61 config: GraphCrfConfig,
62 pub node_potentials: Vec<f64>,
64 pub edge_potentials: Vec<f64>,
66 pub edges: Vec<Edge>,
68}
69
70impl GeneralGraphCrf {
71 pub fn new(config: GraphCrfConfig, edges: Vec<Edge>) -> SeqResult<Self> {
77 if config.n_nodes == 0 {
78 return Err(SeqError::InvalidConfiguration("n_nodes must be > 0".into()));
79 }
80 if config.n_labels == 0 {
81 return Err(SeqError::InvalidConfiguration(
82 "n_labels must be > 0".into(),
83 ));
84 }
85 for &Edge { i, j } in &edges {
86 if i >= config.n_nodes {
87 return Err(SeqError::IndexOutOfBounds {
88 index: i,
89 len: config.n_nodes,
90 });
91 }
92 if j >= config.n_nodes {
93 return Err(SeqError::IndexOutOfBounds {
94 index: j,
95 len: config.n_nodes,
96 });
97 }
98 }
99 let n_nodes = config.n_nodes;
100 let n_labels = config.n_labels;
101 let n_edges = edges.len();
102 Ok(Self {
103 node_potentials: vec![0.0f64; n_nodes * n_labels],
104 edge_potentials: vec![0.0f64; n_edges * n_labels * n_labels],
105 edges,
106 config,
107 })
108 }
109
110 pub fn set_node_potential(&mut self, node: usize, lbl: usize, val: f64) -> SeqResult<()> {
114 let n = self.config.n_labels;
115 if node >= self.config.n_nodes {
116 return Err(SeqError::IndexOutOfBounds {
117 index: node,
118 len: self.config.n_nodes,
119 });
120 }
121 if lbl >= n {
122 return Err(SeqError::IndexOutOfBounds { index: lbl, len: n });
123 }
124 self.node_potentials[node * n + lbl] = val;
125 Ok(())
126 }
127
128 pub fn set_edge_potential(
130 &mut self,
131 e_idx: usize,
132 li: usize,
133 lj: usize,
134 val: f64,
135 ) -> SeqResult<()> {
136 let n = self.config.n_labels;
137 if e_idx >= self.edges.len() {
138 return Err(SeqError::IndexOutOfBounds {
139 index: e_idx,
140 len: self.edges.len(),
141 });
142 }
143 if li >= n || lj >= n {
144 return Err(SeqError::IndexOutOfBounds {
145 index: li.max(lj),
146 len: n,
147 });
148 }
149 self.edge_potentials[e_idx * n * n + li * n + lj] = val;
150 Ok(())
151 }
152
153 pub fn sum_product_marginals(&self) -> SeqResult<Vec<f64>> {
160 let n = self.config.n_labels;
161 let n_nodes = self.config.n_nodes;
162 let n_edges = self.edges.len();
163
164 let mut msgs = vec![vec![0.0f64; n]; n_edges * 2];
167
168 let mut tmp = vec![0.0f64; n];
169
170 for _iter in 0..self.config.max_iter {
171 let mut max_delta = 0.0f64;
172
173 for e_idx in 0..n_edges {
174 let Edge { i, j } = self.edges[e_idx];
175 let ep_base = e_idx * n * n;
176
177 let new_i2j: Vec<f64> = (0..n)
181 .map(|yj| {
182 for yi in 0..n {
183 let mut incoming_i = self.node_potentials[i * n + yi];
185 for (e2, &Edge { i: ei2, j: ej2 }) in self.edges.iter().enumerate() {
186 if e2 == e_idx {
187 continue;
188 }
189 if ei2 == i {
190 incoming_i += msgs[e2 * 2 + 1][yi];
192 } else if ej2 == i {
193 incoming_i += msgs[e2 * 2][yi];
195 }
196 }
197 tmp[yi] = incoming_i + self.edge_potentials[ep_base + yi * n + yj];
198 }
199 logsumexp(&tmp)
200 })
201 .collect();
202
203 let lse = logsumexp(&new_i2j);
205 let new_i2j: Vec<f64> = new_i2j.iter().map(|&v| v - lse).collect();
206
207 let new_j2i: Vec<f64> = (0..n)
209 .map(|yi| {
210 for yj in 0..n {
211 let mut incoming_j = self.node_potentials[j * n + yj];
212 for (e2, &Edge { i: ei2, j: ej2 }) in self.edges.iter().enumerate() {
213 if e2 == e_idx {
214 continue;
215 }
216 if ei2 == j {
217 incoming_j += msgs[e2 * 2 + 1][yj];
218 } else if ej2 == j {
219 incoming_j += msgs[e2 * 2][yj];
220 }
221 }
222 tmp[yj] = incoming_j + self.edge_potentials[ep_base + yj * n + yi];
225 }
226 logsumexp(&tmp)
227 })
228 .collect();
229
230 let lse2 = logsumexp(&new_j2i);
231 let new_j2i: Vec<f64> = new_j2i.iter().map(|&v| v - lse2).collect();
232
233 let damp = self.config.damping;
235 for l in 0..n {
236 let old_i2j = msgs[e_idx * 2][l];
237 let old_j2i = msgs[e_idx * 2 + 1][l];
238 let updated_i2j = (1.0 - damp) * new_i2j[l] + damp * old_i2j;
239 let updated_j2i = (1.0 - damp) * new_j2i[l] + damp * old_j2i;
240 max_delta = max_delta
241 .max((updated_i2j - old_i2j).abs())
242 .max((updated_j2i - old_j2i).abs());
243 msgs[e_idx * 2][l] = updated_i2j;
244 msgs[e_idx * 2 + 1][l] = updated_j2i;
245 }
246 }
247
248 if max_delta < self.config.tol {
249 break;
250 }
251 }
252
253 let mut beliefs = vec![0.0f64; n_nodes * n];
255 for node in 0..n_nodes {
256 for l in 0..n {
257 let mut b = self.node_potentials[node * n + l];
258 for (e_idx, &Edge { i, j }) in self.edges.iter().enumerate() {
259 if i == node {
260 b += msgs[e_idx * 2 + 1][l];
262 } else if j == node {
263 b += msgs[e_idx * 2][l];
265 }
266 }
267 beliefs[node * n + l] = b;
268 }
269 let lse = logsumexp(&beliefs[node * n..(node + 1) * n]);
271 for l in 0..n {
272 beliefs[node * n + l] -= lse;
273 }
274 }
275
276 Ok(beliefs)
277 }
278
279 pub fn map_decode(&self) -> SeqResult<Vec<usize>> {
285 let n = self.config.n_labels;
286 let n_nodes = self.config.n_nodes;
287 let n_edges = self.edges.len();
288
289 let mut msgs = vec![vec![0.0f64; n]; n_edges * 2];
290 let mut tmp = vec![0.0f64; n];
291
292 for _iter in 0..self.config.max_iter {
293 let mut max_delta = 0.0f64;
294
295 for e_idx in 0..n_edges {
296 let Edge { i, j } = self.edges[e_idx];
297 let ep_base = e_idx * n * n;
298
299 let new_i2j: Vec<f64> = (0..n)
300 .map(|yj| {
301 for yi in 0..n {
302 let mut incoming_i = self.node_potentials[i * n + yi];
303 for (e2, &Edge { i: ei2, j: ej2 }) in self.edges.iter().enumerate() {
304 if e2 == e_idx {
305 continue;
306 }
307 if ei2 == i {
308 incoming_i += msgs[e2 * 2 + 1][yi];
309 } else if ej2 == i {
310 incoming_i += msgs[e2 * 2][yi];
311 }
312 }
313 tmp[yi] = incoming_i + self.edge_potentials[ep_base + yi * n + yj];
314 }
315 tmp.iter().cloned().fold(f64::NEG_INFINITY, f64::max)
316 })
317 .collect();
318
319 let new_j2i: Vec<f64> = (0..n)
320 .map(|yi| {
321 for yj in 0..n {
322 let mut incoming_j = self.node_potentials[j * n + yj];
323 for (e2, &Edge { i: ei2, j: ej2 }) in self.edges.iter().enumerate() {
324 if e2 == e_idx {
325 continue;
326 }
327 if ei2 == j {
328 incoming_j += msgs[e2 * 2 + 1][yj];
329 } else if ej2 == j {
330 incoming_j += msgs[e2 * 2][yj];
331 }
332 }
333 tmp[yj] = incoming_j + self.edge_potentials[ep_base + yj * n + yi];
334 }
335 tmp.iter().cloned().fold(f64::NEG_INFINITY, f64::max)
336 })
337 .collect();
338
339 let damp = self.config.damping;
340 for l in 0..n {
341 let old_i2j = msgs[e_idx * 2][l];
342 let old_j2i = msgs[e_idx * 2 + 1][l];
343 let updated_i2j = (1.0 - damp) * new_i2j[l] + damp * old_i2j;
344 let updated_j2i = (1.0 - damp) * new_j2i[l] + damp * old_j2i;
345 max_delta = max_delta
346 .max((updated_i2j - old_i2j).abs())
347 .max((updated_j2i - old_j2i).abs());
348 msgs[e_idx * 2][l] = updated_i2j;
349 msgs[e_idx * 2 + 1][l] = updated_j2i;
350 }
351 }
352
353 if max_delta < self.config.tol {
354 break;
355 }
356 }
357
358 let mut assignments = vec![0usize; n_nodes];
360 for node in 0..n_nodes {
361 let mut best_label = 0;
362 let mut best_b = f64::NEG_INFINITY;
363 let mut b_acc = self.node_potentials[node * n..node * n + n].to_vec();
364 for (e_idx, &Edge { i, j }) in self.edges.iter().enumerate() {
365 for l in 0..n {
366 if i == node {
367 b_acc[l] += msgs[e_idx * 2 + 1][l];
368 } else if j == node {
369 b_acc[l] += msgs[e_idx * 2][l];
370 }
371 }
372 }
373 for l in 0..n {
374 if b_acc[l] > best_b {
375 best_b = b_acc[l];
376 best_label = l;
377 }
378 }
379 assignments[node] = best_label;
380 }
381 Ok(assignments)
382 }
383
384 pub fn n_nodes(&self) -> usize {
386 self.config.n_nodes
387 }
388
389 pub fn n_edges(&self) -> usize {
391 self.edges.len()
392 }
393
394 pub fn n_labels(&self) -> usize {
396 self.config.n_labels
397 }
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    fn default_config(n_nodes: usize, n_labels: usize) -> GraphCrfConfig {
        GraphCrfConfig {
            n_nodes,
            n_labels,
            max_iter: 50,
            tol: 1e-8,
            damping: 0.5,
        }
    }

    /// Edges `0-1, 1-2, ..., (n-2)-(n-1)`; empty for `n < 2` (no `usize` underflow).
    fn chain_edges(n: usize) -> Vec<Edge> {
        (1..n).map(|j| Edge { i: j - 1, j }).collect()
    }

    /// Every joint labelling of `crf` paired with its total log-score (sum of
    /// node and edge log-potentials). Exponential in the node count: tiny
    /// models only.
    fn enumerate_scores(crf: &GeneralGraphCrf) -> Vec<(Vec<usize>, f64)> {
        let (n_nodes, n) = (crf.n_nodes(), crf.n_labels());
        (0..n.pow(n_nodes as u32))
            .map(|code| {
                let labels: Vec<usize> =
                    (0..n_nodes).map(|v| (code / n.pow(v as u32)) % n).collect();
                let mut score = 0.0;
                for (v, &l) in labels.iter().enumerate() {
                    score += crf.node_potentials[v * n + l];
                }
                for (e, edge) in crf.edges.iter().enumerate() {
                    score += crf.edge_potentials[e * n * n + labels[edge.i] * n + labels[edge.j]];
                }
                (labels, score)
            })
            .collect()
    }

    /// Three-node chain with distinct node potentials and symmetric pairwise
    /// tables (`psi(a, b) == psi(b, a)`), so the model is the same whichever
    /// endpoint of an edge is called `i`.
    fn symmetric_chain() -> GeneralGraphCrf {
        let mut crf = GeneralGraphCrf::new(default_config(3, 2), chain_edges(3)).expect("new");
        let node = [[0.3, -0.2], [0.0, 0.5], [-0.4, 0.1]];
        for (v, row) in node.iter().enumerate() {
            for (l, &p) in row.iter().enumerate() {
                crf.set_node_potential(v, l, p).expect("set node");
            }
        }
        let pair = [[[1.0, -0.5], [-0.5, 0.2]], [[0.0, 0.7], [0.7, -0.3]]];
        for (e, table) in pair.iter().enumerate() {
            for (a, row) in table.iter().enumerate() {
                for (b, &p) in row.iter().enumerate() {
                    crf.set_edge_potential(e, a, b, p).expect("set edge");
                }
            }
        }
        crf
    }

    #[test]
    fn chain_edges_handles_small_n() {
        assert!(chain_edges(0).is_empty());
        assert!(chain_edges(1).is_empty());
        assert_eq!(chain_edges(3), vec![Edge { i: 0, j: 1 }, Edge { i: 1, j: 2 }]);
    }

    #[test]
    fn construction_succeeds() {
        let edges = chain_edges(4);
        let crf = GeneralGraphCrf::new(default_config(4, 3), edges);
        assert!(crf.is_ok());
    }

    #[test]
    fn n_nodes_zero_error() {
        let result = GeneralGraphCrf::new(default_config(0, 3), vec![]);
        assert!(result.is_err(), "n_nodes=0 should return Err");
    }

    #[test]
    fn n_labels_zero_error() {
        let result = GeneralGraphCrf::new(default_config(3, 0), vec![]);
        assert!(result.is_err(), "n_labels=0 should return Err");
    }

    #[test]
    fn invalid_edge_node_index_error() {
        let edges = vec![Edge { i: 0, j: 10 }];
        let result = GeneralGraphCrf::new(default_config(3, 2), edges);
        assert!(
            result.is_err(),
            "edge with out-of-range node should return Err"
        );
    }

    #[test]
    fn marginals_shape() {
        let edges = chain_edges(4);
        let crf = GeneralGraphCrf::new(default_config(4, 3), edges).expect("new");
        let beliefs = crf.sum_product_marginals().expect("marginals");
        assert_eq!(beliefs.len(), 4 * 3);
    }

    #[test]
    fn marginals_normalised() {
        let edges = chain_edges(3);
        let crf = GeneralGraphCrf::new(default_config(3, 2), edges).expect("new");
        let beliefs = crf.sum_product_marginals().expect("marginals");
        for node in 0..3 {
            let sum: f64 = beliefs[node * 2..(node + 1) * 2]
                .iter()
                .map(|&b| b.exp())
                .sum();
            assert!(
                (sum - 1.0).abs() < 1e-9,
                "node {node} marginals sum={sum} should be 1.0"
            );
        }
    }

    #[test]
    fn chain_marginals_match_brute_force() {
        let crf = symmetric_chain();
        let beliefs = crf.sum_product_marginals().expect("marginals");
        let scores = enumerate_scores(&crf);
        let all: Vec<f64> = scores.iter().map(|s| s.1).collect();
        let log_z = logsumexp(&all);
        for node in 0..3 {
            for label in 0..2 {
                let hits: Vec<f64> = scores
                    .iter()
                    .filter(|s| s.0[node] == label)
                    .map(|s| s.1)
                    .collect();
                let exact = logsumexp(&hits) - log_z;
                let got = beliefs[node * 2 + label];
                assert!(
                    (got - exact).abs() < 1e-6,
                    "node {node} label {label}: got {got}, exact {exact}"
                );
            }
        }
    }

    #[test]
    fn map_decode_shape() {
        let edges = chain_edges(5);
        let crf = GeneralGraphCrf::new(default_config(5, 4), edges).expect("new");
        let map = crf.map_decode().expect("map_decode");
        assert_eq!(map.len(), 5);
    }

    #[test]
    fn map_decode_valid_labels() {
        let edges = chain_edges(4);
        let crf = GeneralGraphCrf::new(default_config(4, 3), edges).expect("new");
        let map = crf.map_decode().expect("map_decode");
        for &l in &map {
            assert!(l < 3, "map label {l} >= n_labels=3");
        }
    }

    #[test]
    fn chain_map_matches_brute_force() {
        let crf = symmetric_chain();
        let map = crf.map_decode().expect("map_decode");
        let best = enumerate_scores(&crf)
            .into_iter()
            .max_by(|a, b| a.1.partial_cmp(&b.1).expect("scores are finite"))
            .map(|(labels, _)| labels)
            .expect("at least one labelling");
        assert_eq!(map, best);
    }

    #[test]
    fn strong_node_potential_drives_assignment() {
        let mut crf =
            GeneralGraphCrf::new(default_config(2, 2), vec![Edge { i: 0, j: 1 }]).expect("new");
        crf.set_node_potential(0, 0, -10.0).expect("set");
        crf.set_node_potential(0, 1, 10.0).expect("set");
        let map = crf.map_decode().expect("map_decode");
        assert_eq!(map[0], 1, "node 0 should be assigned label 1");
    }

    #[test]
    fn set_potential_out_of_range_error() {
        let mut crf = GeneralGraphCrf::new(default_config(3, 2), vec![]).expect("new");
        let result = crf.set_node_potential(5, 0, 1.0);
        assert!(result.is_err());
        assert!(crf.set_node_potential(0, 2, 1.0).is_err(), "label 2 of 2");
    }

    #[test]
    fn set_edge_potential_out_of_range_error() {
        let mut crf = GeneralGraphCrf::new(default_config(3, 2), chain_edges(3)).expect("new");
        assert!(crf.set_edge_potential(2, 0, 0, 1.0).is_err(), "edge index 2 of 2");
        assert!(crf.set_edge_potential(0, 2, 0, 1.0).is_err(), "li out of range");
        assert!(crf.set_edge_potential(0, 0, 2, 1.0).is_err(), "lj out of range");
        assert!(crf.set_edge_potential(1, 1, 1, 1.0).is_ok());
    }

    #[test]
    fn single_node_marginals() {
        let mut crf = GeneralGraphCrf::new(default_config(1, 3), vec![]).expect("new");
        crf.set_node_potential(0, 0, 0.0).expect("set");
        crf.set_node_potential(0, 1, 1.0).expect("set");
        crf.set_node_potential(0, 2, 2.0).expect("set");
        let beliefs = crf.sum_product_marginals().expect("marginals");
        assert!(
            beliefs[2] > beliefs[1],
            "label 2 should have highest marginal"
        );
        assert!(
            beliefs[1] > beliefs[0],
            "label 1 should have higher marginal than 0"
        );
    }

    #[test]
    fn cycle_graph_no_panic() {
        let edges = vec![
            Edge { i: 0, j: 1 },
            Edge { i: 1, j: 2 },
            Edge { i: 2, j: 0 },
        ];
        let crf = GeneralGraphCrf::new(default_config(3, 2), edges).expect("new");
        let beliefs = crf.sum_product_marginals().expect("cycle marginals");
        assert_eq!(beliefs.len(), 3 * 2);
        for &b in &beliefs {
            assert!(b.is_finite(), "cycle belief should be finite, got {b}");
        }
    }
}