Skip to main content

oxicuda_seq/crf/
general_graph.rs

1//! General-graph CRF training via loopy belief propagation.
2//!
3//! Supports arbitrary pairwise factor graphs (not just linear chains).
4//! Inference uses sum-product message passing; MAP uses max-product.
5//!
6//! Reference: Koller & Friedman 2009, "Probabilistic Graphical Models",
7//! Chapter 11 (Belief Propagation).
8//!
9//! Graph structure:
10//!   - `n_nodes` variable nodes, each with `n_labels` possible states.
11//!   - `edges: Vec<(usize, usize)>` — undirected pairwise edges.
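//!   - Node log-potentials log φ_i(y_i): `[n_nodes × n_labels]`, row-major
//!     (index `node * n_labels + label`).
//!   - Edge log-potentials log ψ_{ij}(y_i, y_j): `[n_edges × n_labels × n_labels]`
//!     (index `e * n_labels² + y_i * n_labels + y_j`, with `i`/`j` as stored in the edge).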
14//!
//! Loopy BP is approximate for graphs with cycles.
//! On trees and linear chains it is exact once the messages have converged
//! (i.e. given enough iterations for the configured damping and tolerance).
17
18use crate::error::{SeqError, SeqResult};
19
20// ─── logsumexp helper ────────────────────────────────────────────────────────
21
/// Numerically stable `log Σ exp(x)` over a slice of log-domain values.
///
/// The maximum is factored out before exponentiating so that large inputs do
/// not overflow.
///
/// Special cases:
/// - an empty slice, or one containing only `-∞`, yields `-∞` (log of zero);
/// - if any entry is `+∞` the result is `+∞` (instead of the `NaN` that
///   `∞ - ∞` would otherwise produce in the shifted exponent).
#[inline]
fn logsumexp(xs: &[f64]) -> f64 {
    // `f64::max` skips NaN operands, so `m` is never NaN here.
    let m = xs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if m.is_infinite() {
        // -∞: nothing contributes.  +∞: the sum is dominated by an infinite term.
        return m;
    }
    let s: f64 = xs.iter().map(|&x| (x - m).exp()).sum();
    m + s.ln()
}
31
32// ─── Graph CRF configuration ─────────────────────────────────────────────────
33
/// Settings that define the size of a general-graph CRF and control how its
/// belief-propagation loop runs.
#[derive(Clone, Debug)]
pub struct GraphCrfConfig {
    /// How many variable nodes the graph contains.
    pub n_nodes: usize,
    /// How many labels (states) every node can take.
    pub n_labels: usize,
    /// Upper bound on the number of BP sweeps over the edge list.
    pub max_iter: usize,
    /// Stop early once the largest absolute change of any message entry in a
    /// sweep falls below this value.
    pub tol: f64,
    /// Damping factor ∈ [0,1); each message is blended as
    /// `(1 - damping) * new + damping * old`.
    pub damping: f64,
}
48
/// A single undirected pairwise edge of the factor graph, joining node `i`
/// and node `j`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Edge {
    /// Index of the first endpoint.
    pub i: usize,
    /// Index of the second endpoint.
    pub j: usize,
}
55
56// ─── GeneralGraphCrf ──────────────────────────────────────────────────────────
57
58/// General-graph pairwise CRF with loopy belief propagation inference.
59#[derive(Debug, Clone)]
60pub struct GeneralGraphCrf {
61    config: GraphCrfConfig,
62    /// Node unary log-potentials `[n_nodes × n_labels]`.
63    pub node_potentials: Vec<f64>,
64    /// Edge pairwise log-potentials `[n_edges × n_labels × n_labels]`.
65    pub edge_potentials: Vec<f64>,
66    /// Edge list.
67    pub edges: Vec<Edge>,
68}
69
70impl GeneralGraphCrf {
71    // ── Construction ─────────────────────────────────────────────────────────
72
73    /// Create a new CRF with zero potentials and the given edges.
74    ///
75    /// Validates that node indices in `edges` are within `[0, n_nodes)`.
76    pub fn new(config: GraphCrfConfig, edges: Vec<Edge>) -> SeqResult<Self> {
77        if config.n_nodes == 0 {
78            return Err(SeqError::InvalidConfiguration("n_nodes must be > 0".into()));
79        }
80        if config.n_labels == 0 {
81            return Err(SeqError::InvalidConfiguration(
82                "n_labels must be > 0".into(),
83            ));
84        }
85        for &Edge { i, j } in &edges {
86            if i >= config.n_nodes {
87                return Err(SeqError::IndexOutOfBounds {
88                    index: i,
89                    len: config.n_nodes,
90                });
91            }
92            if j >= config.n_nodes {
93                return Err(SeqError::IndexOutOfBounds {
94                    index: j,
95                    len: config.n_nodes,
96                });
97            }
98        }
99        let n_nodes = config.n_nodes;
100        let n_labels = config.n_labels;
101        let n_edges = edges.len();
102        Ok(Self {
103            node_potentials: vec![0.0f64; n_nodes * n_labels],
104            edge_potentials: vec![0.0f64; n_edges * n_labels * n_labels],
105            edges,
106            config,
107        })
108    }
109
110    // ── Potential accessors ───────────────────────────────────────────────────
111
112    /// Set the unary log-potential for node `node`, label `lbl`.
113    pub fn set_node_potential(&mut self, node: usize, lbl: usize, val: f64) -> SeqResult<()> {
114        let n = self.config.n_labels;
115        if node >= self.config.n_nodes {
116            return Err(SeqError::IndexOutOfBounds {
117                index: node,
118                len: self.config.n_nodes,
119            });
120        }
121        if lbl >= n {
122            return Err(SeqError::IndexOutOfBounds { index: lbl, len: n });
123        }
124        self.node_potentials[node * n + lbl] = val;
125        Ok(())
126    }
127
128    /// Set the pairwise log-potential for edge `e_idx`, labels `(li, lj)`.
129    pub fn set_edge_potential(
130        &mut self,
131        e_idx: usize,
132        li: usize,
133        lj: usize,
134        val: f64,
135    ) -> SeqResult<()> {
136        let n = self.config.n_labels;
137        if e_idx >= self.edges.len() {
138            return Err(SeqError::IndexOutOfBounds {
139                index: e_idx,
140                len: self.edges.len(),
141            });
142        }
143        if li >= n || lj >= n {
144            return Err(SeqError::IndexOutOfBounds {
145                index: li.max(lj),
146                len: n,
147            });
148        }
149        self.edge_potentials[e_idx * n * n + li * n + lj] = val;
150        Ok(())
151    }
152
153    // ── Sum-product BP ────────────────────────────────────────────────────────
154
155    /// Run sum-product loopy belief propagation.
156    ///
157    /// Returns marginal log-beliefs `[n_nodes × n_labels]` (normalised so that
158    /// `Σ_label exp(beliefs[node * n_labels + label]) = 1`).
159    pub fn sum_product_marginals(&self) -> SeqResult<Vec<f64>> {
160        let n = self.config.n_labels;
161        let n_nodes = self.config.n_nodes;
162        let n_edges = self.edges.len();
163
164        // Messages: for each directed edge (i→j) and (j→i), a log-message of
165        // length n_labels.  Index: msg[e * 2 + dir][label] where dir=0 → i→j, dir=1 → j→i.
166        let mut msgs = vec![vec![0.0f64; n]; n_edges * 2];
167
168        let mut tmp = vec![0.0f64; n];
169
170        for _iter in 0..self.config.max_iter {
171            let mut max_delta = 0.0f64;
172
173            for e_idx in 0..n_edges {
174                let Edge { i, j } = self.edges[e_idx];
175                let ep_base = e_idx * n * n;
176
177                // ── Update msg i→j ────────────────────────────────────────────
178                // m_{i→j}(y_j) = log Σ_{y_i} [φ_i(y_i) * ψ_{ij}(y_i, y_j) *
179                //                Π_{k≠j} m_{k→i}(y_i)]
180                let new_i2j: Vec<f64> = (0..n)
181                    .map(|yj| {
182                        for yi in 0..n {
183                            // Aggregate incoming messages to i (exclude j→i = msgs[e*2+1])
184                            let mut incoming_i = self.node_potentials[i * n + yi];
185                            for (e2, &Edge { i: ei2, j: ej2 }) in self.edges.iter().enumerate() {
186                                if e2 == e_idx {
187                                    continue;
188                                }
189                                if ei2 == i {
190                                    // msg j2→i is in e2*2+1
191                                    incoming_i += msgs[e2 * 2 + 1][yi];
192                                } else if ej2 == i {
193                                    // msg i2→j2 direction, msg from e2's i side to i
194                                    incoming_i += msgs[e2 * 2][yi];
195                                }
196                            }
197                            tmp[yi] = incoming_i + self.edge_potentials[ep_base + yi * n + yj];
198                        }
199                        logsumexp(&tmp)
200                    })
201                    .collect();
202
203                // Normalise
204                let lse = logsumexp(&new_i2j);
205                let new_i2j: Vec<f64> = new_i2j.iter().map(|&v| v - lse).collect();
206
207                // ── Update msg j→i ────────────────────────────────────────────
208                let new_j2i: Vec<f64> = (0..n)
209                    .map(|yi| {
210                        for yj in 0..n {
211                            let mut incoming_j = self.node_potentials[j * n + yj];
212                            for (e2, &Edge { i: ei2, j: ej2 }) in self.edges.iter().enumerate() {
213                                if e2 == e_idx {
214                                    continue;
215                                }
216                                if ei2 == j {
217                                    incoming_j += msgs[e2 * 2 + 1][yj];
218                                } else if ej2 == j {
219                                    incoming_j += msgs[e2 * 2][yj];
220                                }
221                            }
222                            // ψ_{ij}(yi, yj) is symmetric in undirected graph;
223                            // use ψ(yj, yi) for j→i direction
224                            tmp[yj] = incoming_j + self.edge_potentials[ep_base + yj * n + yi];
225                        }
226                        logsumexp(&tmp)
227                    })
228                    .collect();
229
230                let lse2 = logsumexp(&new_j2i);
231                let new_j2i: Vec<f64> = new_j2i.iter().map(|&v| v - lse2).collect();
232
233                // ── Damping + convergence check ────────────────────────────────
234                let damp = self.config.damping;
235                for l in 0..n {
236                    let old_i2j = msgs[e_idx * 2][l];
237                    let old_j2i = msgs[e_idx * 2 + 1][l];
238                    let updated_i2j = (1.0 - damp) * new_i2j[l] + damp * old_i2j;
239                    let updated_j2i = (1.0 - damp) * new_j2i[l] + damp * old_j2i;
240                    max_delta = max_delta
241                        .max((updated_i2j - old_i2j).abs())
242                        .max((updated_j2i - old_j2i).abs());
243                    msgs[e_idx * 2][l] = updated_i2j;
244                    msgs[e_idx * 2 + 1][l] = updated_j2i;
245                }
246            }
247
248            if max_delta < self.config.tol {
249                break;
250            }
251        }
252
253        // ── Compute beliefs ───────────────────────────────────────────────────
254        let mut beliefs = vec![0.0f64; n_nodes * n];
255        for node in 0..n_nodes {
256            for l in 0..n {
257                let mut b = self.node_potentials[node * n + l];
258                for (e_idx, &Edge { i, j }) in self.edges.iter().enumerate() {
259                    if i == node {
260                        // Incoming message from j (stored as j→i = e_idx*2+1)
261                        b += msgs[e_idx * 2 + 1][l];
262                    } else if j == node {
263                        // Incoming message from i (stored as i→j = e_idx*2)
264                        b += msgs[e_idx * 2][l];
265                    }
266                }
267                beliefs[node * n + l] = b;
268            }
269            // Normalise
270            let lse = logsumexp(&beliefs[node * n..(node + 1) * n]);
271            for l in 0..n {
272                beliefs[node * n + l] -= lse;
273            }
274        }
275
276        Ok(beliefs)
277    }
278
279    // ── Max-product MAP ───────────────────────────────────────────────────────
280
281    /// Run max-product loopy BP to compute approximate MAP assignments.
282    ///
283    /// Returns `[n_nodes]` with the argmax label for each node.
284    pub fn map_decode(&self) -> SeqResult<Vec<usize>> {
285        let n = self.config.n_labels;
286        let n_nodes = self.config.n_nodes;
287        let n_edges = self.edges.len();
288
289        let mut msgs = vec![vec![0.0f64; n]; n_edges * 2];
290        let mut tmp = vec![0.0f64; n];
291
292        for _iter in 0..self.config.max_iter {
293            let mut max_delta = 0.0f64;
294
295            for e_idx in 0..n_edges {
296                let Edge { i, j } = self.edges[e_idx];
297                let ep_base = e_idx * n * n;
298
299                let new_i2j: Vec<f64> = (0..n)
300                    .map(|yj| {
301                        for yi in 0..n {
302                            let mut incoming_i = self.node_potentials[i * n + yi];
303                            for (e2, &Edge { i: ei2, j: ej2 }) in self.edges.iter().enumerate() {
304                                if e2 == e_idx {
305                                    continue;
306                                }
307                                if ei2 == i {
308                                    incoming_i += msgs[e2 * 2 + 1][yi];
309                                } else if ej2 == i {
310                                    incoming_i += msgs[e2 * 2][yi];
311                                }
312                            }
313                            tmp[yi] = incoming_i + self.edge_potentials[ep_base + yi * n + yj];
314                        }
315                        tmp.iter().cloned().fold(f64::NEG_INFINITY, f64::max)
316                    })
317                    .collect();
318
319                let new_j2i: Vec<f64> = (0..n)
320                    .map(|yi| {
321                        for yj in 0..n {
322                            let mut incoming_j = self.node_potentials[j * n + yj];
323                            for (e2, &Edge { i: ei2, j: ej2 }) in self.edges.iter().enumerate() {
324                                if e2 == e_idx {
325                                    continue;
326                                }
327                                if ei2 == j {
328                                    incoming_j += msgs[e2 * 2 + 1][yj];
329                                } else if ej2 == j {
330                                    incoming_j += msgs[e2 * 2][yj];
331                                }
332                            }
333                            tmp[yj] = incoming_j + self.edge_potentials[ep_base + yj * n + yi];
334                        }
335                        tmp.iter().cloned().fold(f64::NEG_INFINITY, f64::max)
336                    })
337                    .collect();
338
339                let damp = self.config.damping;
340                for l in 0..n {
341                    let old_i2j = msgs[e_idx * 2][l];
342                    let old_j2i = msgs[e_idx * 2 + 1][l];
343                    let updated_i2j = (1.0 - damp) * new_i2j[l] + damp * old_i2j;
344                    let updated_j2i = (1.0 - damp) * new_j2i[l] + damp * old_j2i;
345                    max_delta = max_delta
346                        .max((updated_i2j - old_i2j).abs())
347                        .max((updated_j2i - old_j2i).abs());
348                    msgs[e_idx * 2][l] = updated_i2j;
349                    msgs[e_idx * 2 + 1][l] = updated_j2i;
350                }
351            }
352
353            if max_delta < self.config.tol {
354                break;
355            }
356        }
357
358        // MAP: argmax over beliefs
359        let mut assignments = vec![0usize; n_nodes];
360        for node in 0..n_nodes {
361            let mut best_label = 0;
362            let mut best_b = f64::NEG_INFINITY;
363            let mut b_acc = self.node_potentials[node * n..node * n + n].to_vec();
364            for (e_idx, &Edge { i, j }) in self.edges.iter().enumerate() {
365                for l in 0..n {
366                    if i == node {
367                        b_acc[l] += msgs[e_idx * 2 + 1][l];
368                    } else if j == node {
369                        b_acc[l] += msgs[e_idx * 2][l];
370                    }
371                }
372            }
373            for l in 0..n {
374                if b_acc[l] > best_b {
375                    best_b = b_acc[l];
376                    best_label = l;
377                }
378            }
379            assignments[node] = best_label;
380        }
381        Ok(assignments)
382    }
383
384    /// Number of nodes.
385    pub fn n_nodes(&self) -> usize {
386        self.config.n_nodes
387    }
388
389    /// Number of edges.
390    pub fn n_edges(&self) -> usize {
391        self.edges.len()
392    }
393
394    /// Number of labels.
395    pub fn n_labels(&self) -> usize {
396        self.config.n_labels
397    }
398}
399
400// ─── Tests ────────────────────────────────────────────────────────────────────
401
#[cfg(test)]
mod tests {
    use super::*;

    /// Configuration shared by every test: 50 sweeps, tight tolerance, half damping.
    fn default_config(n_nodes: usize, n_labels: usize) -> GraphCrfConfig {
        GraphCrfConfig {
            max_iter: 50,
            tol: 1e-8,
            damping: 0.5,
            n_nodes,
            n_labels,
        }
    }

    /// Edges of a linear chain `0 – 1 – … – (n-1)`.
    fn chain_edges(n: usize) -> Vec<Edge> {
        let mut links = Vec::new();
        for head in 0..n - 1 {
            links.push(Edge {
                i: head,
                j: head + 1,
            });
        }
        links
    }

    #[test]
    fn construction_succeeds() {
        let built = GeneralGraphCrf::new(default_config(4, 3), chain_edges(4));
        assert!(built.is_ok());
    }

    #[test]
    fn n_nodes_zero_error() {
        let outcome = GeneralGraphCrf::new(default_config(0, 3), Vec::new());
        assert!(outcome.is_err(), "n_nodes=0 should return Err");
    }

    #[test]
    fn n_labels_zero_error() {
        let outcome = GeneralGraphCrf::new(default_config(3, 0), Vec::new());
        assert!(outcome.is_err(), "n_labels=0 should return Err");
    }

    #[test]
    fn invalid_edge_node_index_error() {
        // Endpoint 10 does not exist in a 3-node graph.
        let dangling = vec![Edge { i: 0, j: 10 }];
        let outcome = GeneralGraphCrf::new(default_config(3, 2), dangling);
        assert!(
            outcome.is_err(),
            "edge with out-of-range node should return Err"
        );
    }

    #[test]
    fn marginals_shape() {
        let model = GeneralGraphCrf::new(default_config(4, 3), chain_edges(4)).expect("new");
        let log_beliefs = model.sum_product_marginals().expect("marginals");
        assert_eq!(log_beliefs.len(), 4 * 3);
    }

    #[test]
    fn marginals_normalised() {
        let model = GeneralGraphCrf::new(default_config(3, 2), chain_edges(3)).expect("new");
        let log_beliefs = model.sum_product_marginals().expect("marginals");
        for node in 0..3 {
            let row = &log_beliefs[2 * node..2 * node + 2];
            let sum: f64 = row.iter().map(|&lb| lb.exp()).sum();
            assert!(
                (sum - 1.0).abs() < 1e-9,
                "node {node} marginals sum={sum} should be 1.0"
            );
        }
    }

    #[test]
    fn map_decode_shape() {
        let model = GeneralGraphCrf::new(default_config(5, 4), chain_edges(5)).expect("new");
        let labels = model.map_decode().expect("map_decode");
        assert_eq!(labels.len(), 5);
    }

    #[test]
    fn map_decode_valid_labels() {
        let model = GeneralGraphCrf::new(default_config(4, 3), chain_edges(4)).expect("new");
        let labels = model.map_decode().expect("map_decode");
        for l in labels.iter().copied() {
            assert!(l < 3, "map label {l} >= n_labels=3");
        }
    }

    #[test]
    fn strong_node_potential_drives_assignment() {
        // A node that overwhelmingly prefers label 1 must be decoded as label 1.
        let single_edge = vec![Edge { i: 0, j: 1 }];
        let mut model = GeneralGraphCrf::new(default_config(2, 2), single_edge).expect("new");
        model.set_node_potential(0, 0, -10.0).expect("set");
        model.set_node_potential(0, 1, 10.0).expect("set");
        let labels = model.map_decode().expect("map_decode");
        assert_eq!(labels[0], 1, "node 0 should be assigned label 1");
    }

    #[test]
    fn set_potential_out_of_range_error() {
        let mut model = GeneralGraphCrf::new(default_config(3, 2), Vec::new()).expect("new");
        // Node 5 does not exist in a 3-node graph.
        assert!(model.set_node_potential(5, 0, 1.0).is_err());
    }

    #[test]
    fn single_node_marginals() {
        // Without edges the marginals are determined by the unary potentials alone.
        let mut model = GeneralGraphCrf::new(default_config(1, 3), Vec::new()).expect("new");
        for (label, value) in [0.0, 1.0, 2.0].into_iter().enumerate() {
            model.set_node_potential(0, label, value).expect("set");
        }
        let log_beliefs = model.sum_product_marginals().expect("marginals");
        assert!(
            log_beliefs[2] > log_beliefs[1],
            "label 2 should have highest marginal"
        );
        assert!(
            log_beliefs[1] > log_beliefs[0],
            "label 1 should have higher marginal than 0"
        );
    }

    #[test]
    fn cycle_graph_no_panic() {
        // Triangle (3 nodes, 3 edges): genuinely loopy, so only check that it
        // runs and yields finite values.
        let triangle: Vec<Edge> = [(0, 1), (1, 2), (2, 0)]
            .iter()
            .map(|&(i, j)| Edge { i, j })
            .collect();
        let model = GeneralGraphCrf::new(default_config(3, 2), triangle).expect("new");
        let log_beliefs = model.sum_product_marginals().expect("cycle marginals");
        assert_eq!(log_beliefs.len(), 3 * 2);
        for b in log_beliefs.iter().copied() {
            assert!(b.is_finite(), "cycle belief should be finite, got {b}");
        }
    }
}
544}