Skip to main content

oxicuda_seq/mrf/
junction_tree.rs

1//! Junction-tree (clique-tree) exact inference for discrete pairwise/higher-order
2//! Markov Random Fields.
3//!
4//! Implements the Lauritzen-Spiegelhalter / Hugin algorithm:
5//!
6//! 1. **Moralise** the factor graph — connect every pair of variables that co-occur
7//!    in the scope of some factor (this also marries co-parents).
8//! 2. **Triangulate** the moral graph by eliminating variables in a heuristic order
9//!    (min-fill, breaking ties by min-degree).  The set of a variable together with
10//!    its still-living neighbours at elimination time is a candidate clique.
11//! 3. Collect the **maximal cliques** (drop any candidate that is a subset of another).
12//! 4. Build a **clique tree** as a maximum-weight spanning tree over the cliques,
13//!    where the weight of an edge between two cliques is the size of their separator
14//!    (their shared variables).  Maximising separator sizes guarantees the
15//!    running-intersection property.
16//! 5. **Assign** each input factor to a clique that contains its scope and multiply
17//!    it into that clique's potential.
18//! 6. **Calibrate** with a Hugin-style two-pass schedule (collect to a root, then
19//!    distribute from the root) so that every clique holds the joint marginal over
20//!    its variables (up to a global scale equal to the partition function `Z`).
21//!
22//! All potentials are stored in the **log domain** (`ln`-potentials, row-major over
23//! the clique's joint configuration in increasing variable order) so that the
24//! message passing never under/overflows.  A configuration value of `-inf`
25//! represents a zero-probability entry.
26
27use super::mrf::Mrf;
28use crate::error::{SeqError, SeqResult};
29
/// Configuration for a junction tree: number of variables and their cardinalities.
///
/// [`JunctionTree::build`] rejects a configuration with `n_vars == 0`, with
/// `cardinalities.len() != n_vars`, or with any cardinality equal to zero.
#[derive(Debug, Clone)]
pub struct JunctionTreeConfig {
    /// Number of discrete random variables; factor scopes refer to variables by
    /// index in `0..n_vars`.
    pub n_vars: usize,
    /// Cardinality (number of states) of each variable, indexed by variable id;
    /// length must equal `n_vars`.
    pub cardinalities: Vec<usize>,
}
38
/// A clique of the junction tree.
///
/// * `vars` — the variables in the clique, stored in **strictly increasing** order.
/// * `potential` — the log-potential table over the clique's joint configuration,
///   row-major in `vars` order (the last variable varies fastest).
///
/// The table has one entry per joint configuration of `vars` (the product of
/// their cardinalities); an entry of `-inf` encodes a zero-probability
/// configuration.
#[derive(Debug, Clone)]
pub struct Clique {
    /// Variables in the clique (sorted ascending).
    pub vars: Vec<usize>,
    /// Log-potential table over the clique's joint configuration (row-major).
    pub potential: Vec<f64>,
}
51
/// A separator between two adjacent cliques: the shared variables and a cached
/// log-potential table over them (used by Hugin message passing).
///
/// The table starts as all zeros (`ln 1`) when the clique tree is built and is
/// overwritten with the latest message each time one crosses the separator
/// during calibration.
#[derive(Debug, Clone)]
struct Separator {
    /// Index of the first incident clique.
    clique_a: usize,
    /// Index of the second incident clique.
    clique_b: usize,
    /// Shared variables (sorted ascending); may be empty when the two cliques
    /// have no variable in common.
    vars: Vec<usize>,
    /// Current log-potential table over the separator (row-major).
    potential: Vec<f64>,
}
65
/// A calibrated (or calibratable) junction tree for exact inference.
///
/// Construct it with [`JunctionTree::build`] or [`JunctionTree::from_mrf`], then
/// call [`JunctionTree::calibrate`] before querying marginals or `log Z`.
#[derive(Debug, Clone)]
pub struct JunctionTree {
    /// Copy of the configuration the tree was built from (variable count and
    /// cardinalities).
    cfg: JunctionTreeConfig,
    /// The maximal cliques together with their log-potential tables.
    cliques: Vec<Clique>,
    /// One separator per edge of the clique tree.
    separators: Vec<Separator>,
    /// Adjacency list over cliques; each entry is `(neighbour_clique, separator_idx)`.
    adjacency: Vec<Vec<(usize, usize)>>,
    /// A root clique and an ordering of cliques such that parents precede children
    /// (breadth-first from the root over the clique tree).  Used for the two passes.
    bfs_order: Vec<usize>,
    /// Parent clique of each clique in the rooted tree (`usize::MAX` for the root).
    parent: Vec<usize>,
    /// Separator index connecting each clique to its parent (`usize::MAX` for root).
    parent_sep: Vec<usize>,
}
82
/// Number of joint configurations of `vars`: the product of their cardinalities.
///
/// `cards` is indexed by variable id.  The product saturates at `usize::MAX`
/// rather than overflowing, and an empty `vars` yields 1 (the empty product).
fn config_count(vars: &[usize], cards: &[usize]) -> usize {
    vars.iter()
        .fold(1usize, |total, &var| total.saturating_mul(cards[var]))
}
91
/// Expand a row-major linear index over `vars` into one state per variable.
///
/// The last variable in `vars` varies fastest.  The state of `vars[k]` is written
/// to `out[k]`; `out` must hold at least `vars.len()` entries and any entries
/// beyond that are left untouched.
fn decode_index(mut idx: usize, vars: &[usize], cards: &[usize], out: &mut [usize]) {
    let slots = &mut out[..vars.len()];
    // Peel off mixed-radix digits, starting with the fastest-varying variable.
    for (slot, &var) in slots.iter_mut().zip(vars).rev() {
        let radix = cards[var];
        *slot = idx % radix;
        idx /= radix;
    }
}
101
/// Linear (row-major) index into a table over `sub_vars` for the configuration
/// described by `super_states`.
///
/// `super_states[k]` is the state of `super_vars[k]`.  Both variable lists must
/// be sorted ascending and `sub_vars` must be a subset of `super_vars`; a
/// variable missing from `super_vars` makes the scan run off the end and panic.
fn project_index(
    super_vars: &[usize],
    super_states: &[usize],
    sub_vars: &[usize],
    cards: &[usize],
) -> usize {
    // Both lists are sorted, so a single forward-only cursor finds every match.
    let mut cursor = 0usize;
    sub_vars.iter().fold(0usize, |acc, &wanted| {
        while super_vars[cursor] != wanted {
            cursor += 1;
        }
        acc * cards[wanted] + super_states[cursor]
    })
}
122
/// Numerically stable `ln(sum_i exp(xs[i]))`.
///
/// * Returns `-inf` for an empty slice or when no entry exceeds `-inf`
///   (an all-zero table in the linear domain).
/// * Returns `+inf` when some entry is `+inf` and none is NaN; shifting by an
///   infinite maximum would otherwise evaluate `exp(inf - inf)` and yield NaN.
/// * A NaN entry alongside at least one entry greater than `-inf` yields NaN.
fn log_sum_exp(xs: &[f64]) -> f64 {
    // `f64::max` ignores NaN operands, matching a plain `x > max` scan.
    let max = xs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if max == f64::NEG_INFINITY {
        return f64::NEG_INFINITY;
    }
    if max == f64::INFINITY {
        // The sum is dominated by the infinite term; keep propagating NaN inputs.
        return if xs.iter().any(|x| x.is_nan()) {
            f64::NAN
        } else {
            f64::INFINITY
        };
    }
    // Shift by the maximum so the largest exponent is exactly zero.
    let shifted_sum: f64 = xs.iter().map(|&x| (x - max).exp()).sum();
    max + shifted_sum.ln()
}
140
141impl JunctionTree {
142    /// Build a junction tree from a list of factors.
143    ///
144    /// Each factor is `(vars, table)` where `vars` is the (unsorted is allowed)
145    /// scope and `table` is a **linear-domain** non-negative potential table,
146    /// row-major over the factor's joint configuration in the *given* `vars`
147    /// order.  Tables are converted to the log domain internally.
148    pub fn build(cfg: &JunctionTreeConfig, factors: &[(Vec<usize>, Vec<f64>)]) -> SeqResult<Self> {
149        if cfg.n_vars == 0 {
150            return Err(SeqError::InvalidConfiguration(
151                "n_vars must be >= 1".to_string(),
152            ));
153        }
154        if cfg.cardinalities.len() != cfg.n_vars {
155            return Err(SeqError::ShapeMismatch {
156                expected: cfg.n_vars,
157                got: cfg.cardinalities.len(),
158            });
159        }
160        for &c in &cfg.cardinalities {
161            if c == 0 {
162                return Err(SeqError::InvalidConfiguration(
163                    "every cardinality must be >= 1".to_string(),
164                ));
165            }
166        }
167        for (vars, table) in factors {
168            for &v in vars {
169                if v >= cfg.n_vars {
170                    return Err(SeqError::IndexOutOfBounds {
171                        index: v,
172                        len: cfg.n_vars,
173                    });
174                }
175            }
176            let expected = config_count(vars, &cfg.cardinalities);
177            if table.len() != expected {
178                return Err(SeqError::ShapeMismatch {
179                    expected,
180                    got: table.len(),
181                });
182            }
183        }
184
185        let cards = &cfg.cardinalities;
186        let n = cfg.n_vars;
187
188        // --- Step 1: moral graph adjacency (symmetric boolean matrix). ---
189        let mut adj = vec![vec![false; n]; n];
190        for (vars, _) in factors {
191            for a in 0..vars.len() {
192                for b in (a + 1)..vars.len() {
193                    let (u, w) = (vars[a], vars[b]);
194                    if u != w {
195                        adj[u][w] = true;
196                        adj[w][u] = true;
197                    }
198                }
199            }
200        }
201
202        // --- Step 2+3: triangulation via heuristic elimination -> candidate cliques. ---
203        let candidate_cliques = Self::eliminate_for_cliques(&adj, cards);
204
205        // Keep only maximal cliques (drop subsets of another candidate).
206        let maximal = Self::keep_maximal(candidate_cliques);
207
208        // --- Step 4: build clique tree (maximum-weight spanning tree on |separator|). ---
209        let (adjacency, separators) = Self::build_clique_tree(&maximal, cards);
210
211        // Allocate clique potentials in the log domain, initialised to log(1)=0.
212        let mut cliques: Vec<Clique> = maximal
213            .into_iter()
214            .map(|vars| {
215                let len = config_count(&vars, cards);
216                Clique {
217                    vars,
218                    potential: vec![0.0; len],
219                }
220            })
221            .collect();
222
223        // --- Step 5: assign each factor to a containing clique and multiply in. ---
224        for (vars, table) in factors {
225            let mut sorted = vars.clone();
226            sorted.sort_unstable();
227            sorted.dedup();
228            let target = cliques
229                .iter()
230                .position(|c| sorted.iter().all(|v| c.vars.contains(v)));
231            let target = match target {
232                Some(t) => t,
233                None => {
234                    return Err(SeqError::GraphInvariantViolated(format!(
235                        "factor scope {sorted:?} not contained in any clique"
236                    )));
237                }
238            };
239            Self::multiply_factor_into_clique(&mut cliques[target], vars, table, cards);
240        }
241
242        // Root the tree at clique 0 and compute a BFS order with parent pointers.
243        let (bfs_order, parent, parent_sep) = Self::root_tree(cliques.len(), &adjacency);
244
245        Ok(Self {
246            cfg: cfg.clone(),
247            cliques,
248            separators,
249            adjacency,
250            bfs_order,
251            parent,
252            parent_sep,
253        })
254    }
255
256    /// Build a junction tree directly from a pairwise [`Mrf`].  Unary terms become
257    /// single-variable factors and each pairwise term a two-variable factor.  The
258    /// [`Mrf`] stores **energies** (probability ∝ `exp(-energy)`), so tables are
259    /// `exp(-unary)` and `exp(-pairwise)` respectively.
260    pub fn from_mrf(mrf: &Mrf) -> SeqResult<Self> {
261        let cfg = JunctionTreeConfig {
262            n_vars: mrf.n_nodes,
263            cardinalities: vec![mrf.n_labels; mrf.n_nodes],
264        };
265        let nl = mrf.n_labels;
266        let l2 = nl * nl;
267        let mut factors: Vec<(Vec<usize>, Vec<f64>)> = Vec::new();
268        for i in 0..mrf.n_nodes {
269            let mut table = vec![0.0; nl];
270            for l in 0..nl {
271                table[l] = (-mrf.unary[i * nl + l]).exp();
272            }
273            factors.push((vec![i], table));
274        }
275        for (e_idx, &(u, v)) in mrf.edges.iter().enumerate() {
276            let (lo, hi) = if u < v { (u, v) } else { (v, u) };
277            let mut table = vec![0.0; l2];
278            // Table is row-major over (lo, hi); the Mrf stores pairwise as (u, v).
279            for a in 0..nl {
280                for b in 0..nl {
281                    // a indexes `lo`, b indexes `hi`.
282                    let (lu, lv) = if u == lo { (a, b) } else { (b, a) };
283                    table[a * nl + b] = (-mrf.pairwise[e_idx * l2 + lu * nl + lv]).exp();
284                }
285            }
286            factors.push((vec![lo, hi], table));
287        }
288        Self::build(&cfg, &factors)
289    }
290
291    /// Heuristic elimination (min-fill, ties broken by min-degree) returning the
292    /// candidate cliques formed at each elimination step.
293    fn eliminate_for_cliques(adj: &[Vec<bool>], cards: &[usize]) -> Vec<Vec<usize>> {
294        let n = adj.len();
295        // Working copy of adjacency we will fill in / remove from.
296        let mut work = adj.to_vec();
297        let mut alive = vec![true; n];
298        let mut cliques: Vec<Vec<usize>> = Vec::new();
299
300        for _ in 0..n {
301            // Choose the next variable to eliminate.
302            let mut best_var = usize::MAX;
303            let mut best_fill = usize::MAX;
304            let mut best_deg = usize::MAX;
305            for v in 0..n {
306                if !alive[v] {
307                    continue;
308                }
309                let neighbours: Vec<usize> = (0..n)
310                    .filter(|&u| alive[u] && u != v && work[v][u])
311                    .collect();
312                let deg = neighbours.len();
313                // Count missing edges among neighbours (fill-in count).
314                let mut fill = 0usize;
315                for a in 0..neighbours.len() {
316                    for b in (a + 1)..neighbours.len() {
317                        if !work[neighbours[a]][neighbours[b]] {
318                            fill += 1;
319                        }
320                    }
321                }
322                if fill < best_fill || (fill == best_fill && deg < best_deg) {
323                    best_fill = fill;
324                    best_deg = deg;
325                    best_var = v;
326                }
327            }
328            if best_var == usize::MAX {
329                break;
330            }
331
332            // Form the candidate clique: best_var + its living neighbours.
333            let neighbours: Vec<usize> = (0..n)
334                .filter(|&u| alive[u] && u != best_var && work[best_var][u])
335                .collect();
336            let mut clique = Vec::with_capacity(neighbours.len() + 1);
337            clique.push(best_var);
338            clique.extend_from_slice(&neighbours);
339            clique.sort_unstable();
340            cliques.push(clique);
341
342            // Add fill edges to make the neighbourhood a clique.
343            for a in 0..neighbours.len() {
344                for b in (a + 1)..neighbours.len() {
345                    work[neighbours[a]][neighbours[b]] = true;
346                    work[neighbours[b]][neighbours[a]] = true;
347                }
348            }
349            // Eliminate the variable.
350            alive[best_var] = false;
351        }
352
353        // Cardinalities are only needed implicitly; isolated variables (no factor,
354        // no neighbour) still produce a singleton clique above, which is correct.
355        let _ = cards;
356        cliques
357    }
358
359    /// Drop any clique that is a subset of another, returning the maximal cliques.
360    fn keep_maximal(mut cliques: Vec<Vec<usize>>) -> Vec<Vec<usize>> {
361        // Sort by descending size so that supersets come first.
362        cliques.sort_by_key(|c| std::cmp::Reverse(c.len()));
363        let mut maximal: Vec<Vec<usize>> = Vec::new();
364        for c in cliques {
365            let is_subset = maximal.iter().any(|m| c.iter().all(|v| m.contains(v)));
366            if !is_subset {
367                maximal.push(c);
368            }
369        }
370        maximal
371    }
372
373    /// Build a maximum-weight spanning forest over the cliques (weight = separator
374    /// size) using Prim/Kruskal-style greedy selection; returns the adjacency list
375    /// and the separators.  Disconnected groups of cliques form a forest, which the
376    /// two-pass schedule handles by visiting each tree independently.
377    fn build_clique_tree(
378        cliques: &[Vec<usize>],
379        cards: &[usize],
380    ) -> (Vec<Vec<(usize, usize)>>, Vec<Separator>) {
381        let m = cliques.len();
382        let mut adjacency: Vec<Vec<(usize, usize)>> = vec![Vec::new(); m];
383        let mut separators: Vec<Separator> = Vec::new();
384        if m <= 1 {
385            return (adjacency, separators);
386        }
387
388        // Candidate edges with their separator size; Kruskal on descending weight.
389        let mut edges: Vec<(usize, usize, usize)> = Vec::new();
390        for a in 0..m {
391            for b in (a + 1)..m {
392                let shared = shared_vars(&cliques[a], &cliques[b]);
393                edges.push((shared.len(), a, b));
394            }
395        }
396        edges.sort_by_key(|e| std::cmp::Reverse(e.0));
397
398        // Union-find for cycle detection.
399        let mut parent: Vec<usize> = (0..m).collect();
400        fn find(parent: &mut [usize], x: usize) -> usize {
401            let mut r = x;
402            while parent[r] != r {
403                r = parent[r];
404            }
405            // Path compression.
406            let mut c = x;
407            while parent[c] != r {
408                let next = parent[c];
409                parent[c] = r;
410                c = next;
411            }
412            r
413        }
414
415        for (_w, a, b) in edges {
416            let ra = find(&mut parent, a);
417            let rb = find(&mut parent, b);
418            if ra == rb {
419                continue;
420            }
421            parent[ra] = rb;
422            let shared = shared_vars(&cliques[a], &cliques[b]);
423            let len = config_count(&shared, cards);
424            let sep_idx = separators.len();
425            separators.push(Separator {
426                clique_a: a,
427                clique_b: b,
428                vars: shared,
429                potential: vec![0.0; len],
430            });
431            adjacency[a].push((b, sep_idx));
432            adjacency[b].push((a, sep_idx));
433        }
434
435        (adjacency, separators)
436    }
437
438    /// Multiply (add in the log domain) a linear-domain factor table into a clique
439    /// potential.  `factor_vars` may be unsorted (matching the factor's table
440    /// layout); the clique potential is row-major over the sorted `clique.vars`.
441    fn multiply_factor_into_clique(
442        clique: &mut Clique,
443        factor_vars: &[usize],
444        factor_table: &[f64],
445        cards: &[usize],
446    ) {
447        let len = clique.potential.len();
448        let mut states = vec![0usize; clique.vars.len()];
449        // Pre-resolve each factor variable's position within the (sorted) clique
450        // variables; the factor's scope is always a subset of the clique by
451        // construction, so every lookup succeeds.
452        let positions: Vec<usize> = factor_vars
453            .iter()
454            .filter_map(|fv| clique.vars.binary_search(fv).ok())
455            .collect();
456        if positions.len() != factor_vars.len() {
457            // Defensive: scope not fully contained — leave the potential unchanged.
458            return;
459        }
460        for idx in 0..len {
461            decode_index(idx, &clique.vars, cards, &mut states);
462            // Compute the factor's linear index for this clique configuration.
463            let mut fidx = 0usize;
464            for (k, &fv) in factor_vars.iter().enumerate() {
465                fidx = fidx * cards[fv] + states[positions[k]];
466            }
467            let val = factor_table[fidx];
468            clique.potential[idx] += if val > 0.0 {
469                val.ln()
470            } else {
471                f64::NEG_INFINITY
472            };
473        }
474    }
475
476    /// Root the clique forest, returning a BFS visiting order (parents before
477    /// children), the parent of each clique, and the separator connecting it to its
478    /// parent.  Disconnected components are each rooted at their lowest-index clique.
479    fn root_tree(
480        m: usize,
481        adjacency: &[Vec<(usize, usize)>],
482    ) -> (Vec<usize>, Vec<usize>, Vec<usize>) {
483        let mut visited = vec![false; m];
484        let mut order = Vec::with_capacity(m);
485        let mut parent = vec![usize::MAX; m];
486        let mut parent_sep = vec![usize::MAX; m];
487        for start in 0..m {
488            if visited[start] {
489                continue;
490            }
491            visited[start] = true;
492            let mut queue = std::collections::VecDeque::new();
493            queue.push_back(start);
494            while let Some(c) = queue.pop_front() {
495                order.push(c);
496                for &(nbr, sep) in &adjacency[c] {
497                    if !visited[nbr] {
498                        visited[nbr] = true;
499                        parent[nbr] = c;
500                        parent_sep[nbr] = sep;
501                        queue.push_back(nbr);
502                    }
503                }
504            }
505        }
506        (order, parent, parent_sep)
507    }
508
509    /// Marginalise a clique's log-potential onto a separator's variables, returning
510    /// a fresh log-potential table over the separator (row-major in `sep_vars`).
511    fn marginalise_to_separator(&self, clique_idx: usize, sep_vars: &[usize]) -> Vec<f64> {
512        let clique = &self.cliques[clique_idx];
513        let cards = &self.cfg.cardinalities;
514        let sep_len = config_count(sep_vars, cards);
515        // Accumulate exp-domain mass per separator config via log-sum-exp.
516        let mut buckets: Vec<Vec<f64>> = vec![Vec::new(); sep_len];
517        let mut states = vec![0usize; clique.vars.len()];
518        for idx in 0..clique.potential.len() {
519            decode_index(idx, &clique.vars, cards, &mut states);
520            let sidx = project_index(&clique.vars, &states, sep_vars, cards);
521            buckets[sidx].push(clique.potential[idx]);
522        }
523        let mut out = vec![f64::NEG_INFINITY; sep_len];
524        for (s, bucket) in buckets.iter().enumerate() {
525            out[s] = log_sum_exp(bucket);
526        }
527        out
528    }
529
530    /// Multiply a separator's log-message (`new - old`) into a clique potential,
531    /// broadcasting over the clique configurations that share each separator config.
532    fn absorb_message_into_clique(&mut self, clique_idx: usize, sep_idx: usize, delta: &[f64]) {
533        let sep_vars = self.separators[sep_idx].vars.clone();
534        let cards = self.cfg.cardinalities.clone();
535        let clique_vars = self.cliques[clique_idx].vars.clone();
536        let mut states = vec![0usize; clique_vars.len()];
537        let len = self.cliques[clique_idx].potential.len();
538        for idx in 0..len {
539            decode_index(idx, &clique_vars, &cards, &mut states);
540            let sidx = project_index(&clique_vars, &states, &sep_vars, &cards);
541            self.cliques[clique_idx].potential[idx] += delta[sidx];
542        }
543    }
544
545    /// Run Hugin two-pass message passing (collect to root, distribute from root)
546    /// so that every clique stores its joint marginal (scaled by `Z`) and every
547    /// separator is consistent with both incident cliques.
548    pub fn calibrate(&mut self) -> SeqResult<()> {
549        if self.cliques.is_empty() {
550            return Ok(());
551        }
552
553        // Pass 1 — Collect: process cliques in reverse BFS order so that children
554        // send messages to their parent before the parent is processed.
555        let order = self.bfs_order.clone();
556        for &c in order.iter().rev() {
557            let p = self.parent[c];
558            if p == usize::MAX {
559                continue; // root has no parent to send to
560            }
561            let sep_idx = self.parent_sep[c];
562            let sep_vars = self.separators[sep_idx].vars.clone();
563            // New separator potential from the child clique.
564            let new_sep = self.marginalise_to_separator(c, &sep_vars);
565            // delta = new_sep - old_sep (log domain); update parent and separator.
566            let old_sep = self.separators[sep_idx].potential.clone();
567            let delta: Vec<f64> = new_sep
568                .iter()
569                .zip(old_sep.iter())
570                .map(|(&a, &b)| safe_log_sub(a, b))
571                .collect();
572            self.absorb_message_into_clique(p, sep_idx, &delta);
573            self.separators[sep_idx].potential = new_sep;
574        }
575
576        // Pass 2 — Distribute: process cliques in BFS order so that each parent
577        // sends to its children after it has been fully updated.
578        for &c in order.iter() {
579            // For each child of c, send a message c -> child.
580            let children: Vec<(usize, usize)> = self.adjacency[c]
581                .iter()
582                .filter(|&&(nbr, _)| self.parent[nbr] == c)
583                .copied()
584                .collect();
585            for (child, sep_idx) in children {
586                let sep_vars = self.separators[sep_idx].vars.clone();
587                let new_sep = self.marginalise_to_separator(c, &sep_vars);
588                let old_sep = self.separators[sep_idx].potential.clone();
589                let delta: Vec<f64> = new_sep
590                    .iter()
591                    .zip(old_sep.iter())
592                    .map(|(&a, &b)| safe_log_sub(a, b))
593                    .collect();
594                self.absorb_message_into_clique(child, sep_idx, &delta);
595                self.separators[sep_idx].potential = new_sep;
596            }
597        }
598
599        Ok(())
600    }
601
602    /// Marginal distribution over a single variable after calibration, normalised
603    /// to sum to 1.
604    pub fn marginal(&self, var: usize) -> SeqResult<Vec<f64>> {
605        if var >= self.cfg.n_vars {
606            return Err(SeqError::IndexOutOfBounds {
607                index: var,
608                len: self.cfg.n_vars,
609            });
610        }
611        let card = self.cfg.cardinalities[var];
612        // Find any clique containing the variable.
613        let clique_idx = self
614            .cliques
615            .iter()
616            .position(|c| c.vars.contains(&var))
617            .ok_or_else(|| {
618                SeqError::GraphInvariantViolated(format!(
619                    "variable {var} not present in any clique"
620                ))
621            })?;
622        let log_marg = self.marginalise_to_separator(clique_idx, &[var]);
623        debug_assert_eq!(log_marg.len(), card);
624        // Normalise in the log domain, then exponentiate.
625        let logz = log_sum_exp(&log_marg);
626        let mut out = vec![0.0; card];
627        if logz == f64::NEG_INFINITY {
628            // Degenerate (all-zero) potential — fall back to uniform.
629            let u = 1.0 / card as f64;
630            for v in out.iter_mut() {
631                *v = u;
632            }
633            return Ok(out);
634        }
635        for l in 0..card {
636            out[l] = (log_marg[l] - logz).exp();
637        }
638        Ok(out)
639    }
640
641    /// Joint marginal over the variables of clique `clique_idx`, normalised to sum
642    /// to 1.  The output is row-major over the clique's (sorted) variables.
643    pub fn clique_marginal(&self, clique_idx: usize) -> SeqResult<Vec<f64>> {
644        if clique_idx >= self.cliques.len() {
645            return Err(SeqError::IndexOutOfBounds {
646                index: clique_idx,
647                len: self.cliques.len(),
648            });
649        }
650        let pot = &self.cliques[clique_idx].potential;
651        let logz = log_sum_exp(pot);
652        let mut out = vec![0.0; pot.len()];
653        if logz == f64::NEG_INFINITY {
654            let u = 1.0 / pot.len().max(1) as f64;
655            for v in out.iter_mut() {
656                *v = u;
657            }
658            return Ok(out);
659        }
660        for (o, &p) in out.iter_mut().zip(pot.iter()) {
661            *o = (p - logz).exp();
662        }
663        Ok(out)
664    }
665
666    /// Log partition function `log Z` (log of the normalisation constant).
667    ///
668    /// After calibration every clique sums (in the linear domain) to `Z`, so we
669    /// take the log-sum-exp of any clique's log-potential.  Before calibration this
670    /// is generally **not** `Z`; callers should calibrate first.
671    pub fn log_partition(&self) -> SeqResult<f64> {
672        if self.cliques.is_empty() {
673            return Err(SeqError::GraphInvariantViolated(
674                "junction tree has no cliques".to_string(),
675            ));
676        }
677        Ok(log_sum_exp(&self.cliques[0].potential))
678    }
679
    /// Number of cliques in the tree (the maximal cliques found while building).
    pub fn n_cliques(&self) -> usize {
        self.cliques.len()
    }
684
    /// Number of separators in the tree, i.e. the number of clique-tree edges.
    pub fn n_separators(&self) -> usize {
        self.separators.len()
    }
689
    /// Read-only view of the cliques, in clique-index order (the index accepted
    /// by [`JunctionTree::clique_marginal`]).
    pub fn cliques(&self) -> &[Clique] {
        &self.cliques
    }
694
695    /// Variables of separator `sep_idx` (sorted); useful for tests checking the
696    /// running-intersection property.
697    pub fn separator_vars(&self, sep_idx: usize) -> SeqResult<&[usize]> {
698        if sep_idx >= self.separators.len() {
699            return Err(SeqError::IndexOutOfBounds {
700                index: sep_idx,
701                len: self.separators.len(),
702            });
703        }
704        Ok(&self.separators[sep_idx].vars)
705    }
706
707    /// The `(clique_a, clique_b)` incident to separator `sep_idx`.
708    pub fn separator_cliques(&self, sep_idx: usize) -> SeqResult<(usize, usize)> {
709        if sep_idx >= self.separators.len() {
710            return Err(SeqError::IndexOutOfBounds {
711                index: sep_idx,
712                len: self.separators.len(),
713            });
714        }
715        Ok((
716            self.separators[sep_idx].clique_a,
717            self.separators[sep_idx].clique_b,
718        ))
719    }
720}
721
/// Intersection of two ascending variable lists, returned in ascending order.
///
/// Both inputs must already be sorted; they are merged with a single linear scan.
fn shared_vars(a: &[usize], b: &[usize]) -> Vec<usize> {
    use std::cmp::Ordering;

    let mut common = Vec::new();
    let (mut i, mut j) = (0usize, 0usize);
    // Advance whichever side is behind; record a variable when both sides agree.
    while let (Some(&left), Some(&right)) = (a.get(i), b.get(j)) {
        match left.cmp(&right) {
            Ordering::Less => i += 1,
            Ordering::Greater => j += 1,
            Ordering::Equal => {
                common.push(left);
                i += 1;
                j += 1;
            }
        }
    }
    common
}
739
/// Log-domain ratio of a new separator entry `a` to the old entry `b`, with
/// `-inf` handling.
///
/// Despite the name this is **not** `log(exp(a) - exp(b))`: the Hugin update
/// multiplies by the *ratio* `new / old`, which is the plain difference `a - b`
/// of the logs.  Two cases avoid `inf - inf` arithmetic:
///
/// * `a == -inf` (the new entry is zero) gives `-inf`;
/// * `b == -inf` (the old entry is zero) gives `a` unchanged.
fn safe_log_sub(a: f64, b: f64) -> f64 {
    if a == f64::NEG_INFINITY || b == f64::NEG_INFINITY {
        // Either `a` is already -inf, or the old entry was zero and `a` passes
        // through untouched; both cases return `a`.
        a
    } else {
        a - b
    }
}
753
754#[cfg(test)]
755mod tests {
756    use super::*;
757
758    fn cfg(cards: Vec<usize>) -> JunctionTreeConfig {
759        JunctionTreeConfig {
760            n_vars: cards.len(),
761            cardinalities: cards,
762        }
763    }
764
765    /// Brute-force marginal of `var` by enumerating the full joint of the factors.
766    fn brute_force_marginal(
767        cards: &[usize],
768        factors: &[(Vec<usize>, Vec<f64>)],
769        var: usize,
770    ) -> Vec<f64> {
771        let n = cards.len();
772        let total: usize = cards.iter().product();
773        let mut marg = vec![0.0; cards[var]];
774        let mut states = vec![0usize; n];
775        for joint in 0..total {
776            let mut rem = joint;
777            for k in (0..n).rev() {
778                states[k] = rem % cards[k];
779                rem /= cards[k];
780            }
781            let mut p = 1.0;
782            for (vars, table) in factors {
783                let mut idx = 0usize;
784                for &v in vars {
785                    idx = idx * cards[v] + states[v];
786                }
787                p *= table[idx];
788            }
789            marg[states[var]] += p;
790        }
791        let s: f64 = marg.iter().sum();
792        if s > 0.0 {
793            for m in marg.iter_mut() {
794                *m /= s;
795            }
796        }
797        marg
798    }
799
800    /// Brute-force log partition function over the full joint.
801    fn brute_force_log_z(cards: &[usize], factors: &[(Vec<usize>, Vec<f64>)]) -> f64 {
802        let n = cards.len();
803        let total: usize = cards.iter().product();
804        let mut z = 0.0;
805        let mut states = vec![0usize; n];
806        for joint in 0..total {
807            let mut rem = joint;
808            for k in (0..n).rev() {
809                states[k] = rem % cards[k];
810                rem /= cards[k];
811            }
812            let mut p = 1.0;
813            for (vars, table) in factors {
814                let mut idx = 0usize;
815                for &v in vars {
816                    idx = idx * cards[v] + states[v];
817                }
818                p *= table[idx];
819            }
820            z += p;
821        }
822        z.ln()
823    }
824
825    #[test]
826    fn single_factor_one_clique() {
827        let c = cfg(vec![2, 2]);
828        let factors = vec![(vec![0, 1], vec![1.0, 2.0, 3.0, 4.0])];
829        let jt = JunctionTree::build(&c, &factors).expect("build");
830        assert_eq!(jt.n_cliques(), 1);
831        assert_eq!(jt.n_separators(), 0);
832    }
833
834    #[test]
835    fn single_var_factor_marginal_equals_normalised_potential() {
836        let c = cfg(vec![3]);
837        let factors = vec![(vec![0], vec![1.0, 2.0, 1.0])];
838        let mut jt = JunctionTree::build(&c, &factors).expect("build");
839        jt.calibrate().expect("cal");
840        let m = jt.marginal(0).expect("marg");
841        let expected = [0.25, 0.5, 0.25];
842        for (a, b) in m.iter().zip(expected.iter()) {
843            assert!((a - b).abs() < 1e-12, "{a} vs {b}");
844        }
845    }
846
847    #[test]
848    fn chain_marginals_match_brute_force() {
849        // Chain X0 - X1 - X2 with pairwise factors.
850        let c = cfg(vec![2, 2, 2]);
851        let f01 = (vec![0, 1], vec![1.0, 0.3, 0.4, 2.0]);
852        let f12 = (vec![1, 2], vec![1.5, 0.6, 0.2, 1.1]);
853        let f0 = (vec![0], vec![0.7, 1.3]);
854        let factors = vec![f0, f01, f12];
855        let mut jt = JunctionTree::build(&c, &factors).expect("build");
856        jt.calibrate().expect("cal");
857        for var in 0..3 {
858            let m = jt.marginal(var).expect("marg");
859            let bf = brute_force_marginal(&c.cardinalities, &factors, var);
860            for (a, b) in m.iter().zip(bf.iter()) {
861                assert!((a - b).abs() < 1e-6, "var {var}: {a} vs {b}");
862            }
863        }
864    }
865
866    #[test]
867    fn chain_marginal_sums_to_one() {
868        let c = cfg(vec![3, 2, 3]);
869        let f01 = (vec![0, 1], vec![1.0, 0.3, 0.4, 2.0, 0.5, 1.2]);
870        let f12 = (vec![1, 2], vec![1.5, 0.6, 0.2, 1.1, 0.9, 0.7]);
871        let factors = vec![f01, f12];
872        let mut jt = JunctionTree::build(&c, &factors).expect("build");
873        jt.calibrate().expect("cal");
874        for var in 0..3 {
875            let m = jt.marginal(var).expect("marg");
876            let s: f64 = m.iter().sum();
877            assert!((s - 1.0).abs() < 1e-9, "var {var} sum {s}");
878        }
879    }
880
881    #[test]
882    fn log_partition_matches_brute_force() {
883        let c = cfg(vec![2, 3, 2]);
884        let f01 = (vec![0, 1], vec![1.0, 0.3, 0.4, 2.0, 0.8, 0.5]);
885        let f12 = (vec![1, 2], vec![1.5, 0.6, 0.2, 1.1, 0.9, 0.7]);
886        let f2 = (vec![2], vec![1.2, 0.8]);
887        let factors = vec![f01, f12, f2];
888        let mut jt = JunctionTree::build(&c, &factors).expect("build");
889        jt.calibrate().expect("cal");
890        let lz = jt.log_partition().expect("logz");
891        let bf = brute_force_log_z(&c.cardinalities, &factors);
892        assert!((lz - bf).abs() < 1e-6, "logZ {lz} vs {bf}");
893    }
894
895    #[test]
896    fn independent_variables_product_marginals() {
897        // Two independent single-variable factors -> marginals are the normalised
898        // single potentials, independent of each other.
899        let c = cfg(vec![2, 3]);
900        let f0 = (vec![0], vec![1.0, 3.0]);
901        let f1 = (vec![1], vec![2.0, 2.0, 4.0]);
902        let factors = vec![f0, f1];
903        let mut jt = JunctionTree::build(&c, &factors).expect("build");
904        jt.calibrate().expect("cal");
905        let m0 = jt.marginal(0).expect("m0");
906        let m1 = jt.marginal(1).expect("m1");
907        assert!((m0[0] - 0.25).abs() < 1e-12);
908        assert!((m0[1] - 0.75).abs() < 1e-12);
909        assert!((m1[0] - 0.25).abs() < 1e-12);
910        assert!((m1[1] - 0.25).abs() < 1e-12);
911        assert!((m1[2] - 0.5).abs() < 1e-12);
912    }
913
914    #[test]
915    fn disconnected_factors_handled() {
916        // Factor over {0,1} and a separate factor over {2,3}: forest of two trees.
917        let c = cfg(vec![2, 2, 2, 2]);
918        let fa = (vec![0, 1], vec![1.0, 0.5, 0.5, 1.0]);
919        let fb = (vec![2, 3], vec![2.0, 0.1, 0.1, 2.0]);
920        let factors = vec![fa, fb];
921        let mut jt = JunctionTree::build(&c, &factors).expect("build");
922        jt.calibrate().expect("cal");
923        for var in 0..4 {
924            let m = jt.marginal(var).expect("marg");
925            let bf = brute_force_marginal(&c.cardinalities, &factors, var);
926            for (a, b) in m.iter().zip(bf.iter()) {
927                assert!((a - b).abs() < 1e-6, "var {var}: {a} vs {b}");
928            }
929        }
930    }
931
932    #[test]
933    fn calibrate_is_idempotent() {
934        let c = cfg(vec![2, 2, 2]);
935        let f01 = (vec![0, 1], vec![1.0, 0.3, 0.4, 2.0]);
936        let f12 = (vec![1, 2], vec![1.5, 0.6, 0.2, 1.1]);
937        let factors = vec![f01, f12];
938        let mut jt = JunctionTree::build(&c, &factors).expect("build");
939        jt.calibrate().expect("cal1");
940        let m_before: Vec<Vec<f64>> = (0..3).map(|v| jt.marginal(v).expect("m")).collect();
941        jt.calibrate().expect("cal2");
942        let m_after: Vec<Vec<f64>> = (0..3).map(|v| jt.marginal(v).expect("m")).collect();
943        for (a, b) in m_before.iter().zip(m_after.iter()) {
944            for (x, y) in a.iter().zip(b.iter()) {
945                assert!((x - y).abs() < 1e-9, "{x} vs {y}");
946            }
947        }
948    }
949
950    #[test]
951    fn running_intersection_on_chain() {
952        // Build a chain that induces cliques {0,1},{1,2},{2,3}; the variable 2
953        // shared by cliques {1,2} and {2,3} must appear in their separator.
954        let c = cfg(vec![2, 2, 2, 2]);
955        let f01 = (vec![0, 1], vec![1.0, 0.5, 0.5, 1.0]);
956        let f12 = (vec![1, 2], vec![1.0, 0.5, 0.5, 1.0]);
957        let f23 = (vec![2, 3], vec![1.0, 0.5, 0.5, 1.0]);
958        let factors = vec![f01, f12, f23];
959        let jt = JunctionTree::build(&c, &factors).expect("build");
960        // Every separator must be the intersection of its two incident cliques.
961        for s in 0..jt.n_separators() {
962            let (a, b) = jt.separator_cliques(s).expect("sep");
963            let inter = shared_vars(&jt.cliques()[a].vars, &jt.cliques()[b].vars);
964            assert_eq!(jt.separator_vars(s).expect("vars"), inter.as_slice());
965            assert!(
966                !inter.is_empty(),
967                "separator should be non-empty on a chain"
968            );
969        }
970    }
971
972    #[test]
973    fn n_cliques_sane_for_chain() {
974        let c = cfg(vec![2, 2, 2, 2]);
975        let f01 = (vec![0, 1], vec![1.0, 0.5, 0.5, 1.0]);
976        let f12 = (vec![1, 2], vec![1.0, 0.5, 0.5, 1.0]);
977        let f23 = (vec![2, 3], vec![1.0, 0.5, 0.5, 1.0]);
978        let jt = JunctionTree::build(&c, &[f01, f12, f23]).expect("build");
979        // A 4-node chain has exactly 3 maximal cliques of size 2.
980        assert_eq!(jt.n_cliques(), 3);
981        for cl in jt.cliques() {
982            assert_eq!(cl.vars.len(), 2);
983        }
984    }
985
986    #[test]
987    fn ternary_cardinalities_match_brute_force() {
988        let c = cfg(vec![3, 3]);
989        let f = (
990            vec![0, 1],
991            vec![1.0, 0.2, 0.5, 0.3, 2.0, 0.4, 0.6, 0.1, 1.5],
992        );
993        let factors = vec![f];
994        let mut jt = JunctionTree::build(&c, &factors).expect("build");
995        jt.calibrate().expect("cal");
996        for var in 0..2 {
997            let m = jt.marginal(var).expect("marg");
998            let bf = brute_force_marginal(&c.cardinalities, &factors, var);
999            for (a, b) in m.iter().zip(bf.iter()) {
1000                assert!((a - b).abs() < 1e-9, "var {var}: {a} vs {b}");
1001            }
1002        }
1003    }
1004
1005    #[test]
1006    fn triangle_three_var_factor_match_brute_force() {
1007        // A loop 0-1-2-0 induces one triangle clique {0,1,2}.
1008        let c = cfg(vec![2, 2, 2]);
1009        let f01 = (vec![0, 1], vec![1.0, 0.5, 0.5, 1.0]);
1010        let f12 = (vec![1, 2], vec![1.2, 0.3, 0.4, 0.9]);
1011        let f02 = (vec![0, 2], vec![0.7, 1.1, 1.3, 0.6]);
1012        let factors = vec![f01, f12, f02];
1013        let mut jt = JunctionTree::build(&c, &factors).expect("build");
1014        jt.calibrate().expect("cal");
1015        for var in 0..3 {
1016            let m = jt.marginal(var).expect("marg");
1017            let bf = brute_force_marginal(&c.cardinalities, &factors, var);
1018            for (a, b) in m.iter().zip(bf.iter()) {
1019                assert!((a - b).abs() < 1e-6, "var {var}: {a} vs {b}");
1020            }
1021        }
1022        // The triangle collapses into a single 3-variable clique.
1023        assert_eq!(jt.n_cliques(), 1);
1024        assert_eq!(jt.cliques()[0].vars, vec![0, 1, 2]);
1025    }
1026
1027    #[test]
1028    fn from_mrf_matches_direct_factors() {
1029        // Build an Mrf and check its junction-tree marginals against brute force on
1030        // the equivalent exp(-energy) factors.
1031        let m = Mrf::new(
1032            3,
1033            2,
1034            vec![(0, 1), (1, 2)],
1035            vec![0.1, 0.5, 0.2, 0.3, 0.0, 0.4],
1036            vec![0.0, 0.7, 0.7, 0.0, 0.0, 0.5, 0.5, 0.0],
1037        )
1038        .expect("mrf");
1039        let mut jt = JunctionTree::from_mrf(&m).expect("jt");
1040        jt.calibrate().expect("cal");
1041        // Equivalent factors.
1042        let nl = 2;
1043        let mut factors: Vec<(Vec<usize>, Vec<f64>)> = Vec::new();
1044        for i in 0..3 {
1045            let mut t = vec![0.0; nl];
1046            for l in 0..nl {
1047                t[l] = (-m.unary[i * nl + l]).exp();
1048            }
1049            factors.push((vec![i], t));
1050        }
1051        for (e, &(u, v)) in m.edges.iter().enumerate() {
1052            let mut t = vec![0.0; nl * nl];
1053            for a in 0..nl {
1054                for b in 0..nl {
1055                    t[a * nl + b] = (-m.pairwise[e * nl * nl + a * nl + b]).exp();
1056                }
1057            }
1058            factors.push((vec![u, v], t));
1059        }
1060        for var in 0..3 {
1061            let mm = jt.marginal(var).expect("marg");
1062            let bf = brute_force_marginal(&[nl; 3], &factors, var);
1063            for (a, b) in mm.iter().zip(bf.iter()) {
1064                assert!((a - b).abs() < 1e-6, "var {var}: {a} vs {b}");
1065            }
1066        }
1067    }
1068
1069    #[test]
1070    fn deterministic_build_and_calibrate() {
1071        let c = cfg(vec![2, 2, 2]);
1072        let f01 = (vec![0, 1], vec![1.0, 0.3, 0.4, 2.0]);
1073        let f12 = (vec![1, 2], vec![1.5, 0.6, 0.2, 1.1]);
1074        let factors = vec![f01, f12];
1075        let mut a = JunctionTree::build(&c, &factors).expect("a");
1076        let mut b = JunctionTree::build(&c, &factors).expect("b");
1077        a.calibrate().expect("ca");
1078        b.calibrate().expect("cb");
1079        for var in 0..3 {
1080            let ma = a.marginal(var).expect("ma");
1081            let mb = b.marginal(var).expect("mb");
1082            assert_eq!(ma, mb);
1083        }
1084    }
1085
1086    #[test]
1087    fn err_cardinality_mismatch_with_factor_table() {
1088        let c = cfg(vec![2, 2]);
1089        // Factor over {0,1} should have a 4-entry table, give 3.
1090        let factors = vec![(vec![0, 1], vec![1.0, 2.0, 3.0])];
1091        let r = JunctionTree::build(&c, &factors);
1092        assert!(matches!(r, Err(SeqError::ShapeMismatch { .. })));
1093    }
1094
1095    #[test]
1096    fn err_var_out_of_range_in_factor() {
1097        let c = cfg(vec![2, 2]);
1098        let factors = vec![(vec![0, 5], vec![1.0, 2.0, 3.0, 4.0])];
1099        let r = JunctionTree::build(&c, &factors);
1100        assert!(matches!(r, Err(SeqError::IndexOutOfBounds { .. })));
1101    }
1102
1103    #[test]
1104    fn err_empty_cardinalities_mismatch() {
1105        let c = JunctionTreeConfig {
1106            n_vars: 2,
1107            cardinalities: vec![2],
1108        };
1109        let r = JunctionTree::build(&c, &[]);
1110        assert!(matches!(r, Err(SeqError::ShapeMismatch { .. })));
1111    }
1112
1113    #[test]
1114    fn err_n_vars_zero() {
1115        let c = JunctionTreeConfig {
1116            n_vars: 0,
1117            cardinalities: vec![],
1118        };
1119        let r = JunctionTree::build(&c, &[]);
1120        assert!(matches!(r, Err(SeqError::InvalidConfiguration(_))));
1121    }
1122
1123    #[test]
1124    fn err_zero_cardinality() {
1125        let c = JunctionTreeConfig {
1126            n_vars: 2,
1127            cardinalities: vec![2, 0],
1128        };
1129        let r = JunctionTree::build(&c, &[]);
1130        assert!(matches!(r, Err(SeqError::InvalidConfiguration(_))));
1131    }
1132
1133    #[test]
1134    fn err_marginal_var_out_of_range() {
1135        let c = cfg(vec![2, 2]);
1136        let factors = vec![(vec![0, 1], vec![1.0, 1.0, 1.0, 1.0])];
1137        let jt = JunctionTree::build(&c, &factors).expect("build");
1138        let r = jt.marginal(5);
1139        assert!(matches!(r, Err(SeqError::IndexOutOfBounds { .. })));
1140    }
1141
1142    #[test]
1143    fn binary_vs_ternary_isolated_factors() {
1144        // One binary variable, one ternary variable, single factor each.
1145        let c = cfg(vec![2, 3]);
1146        let f0 = (vec![0], vec![3.0, 1.0]);
1147        let f1 = (vec![1], vec![1.0, 1.0, 2.0]);
1148        let mut jt = JunctionTree::build(&c, &[f0, f1]).expect("build");
1149        jt.calibrate().expect("cal");
1150        let m0 = jt.marginal(0).expect("m0");
1151        let m1 = jt.marginal(1).expect("m1");
1152        assert_eq!(m0.len(), 2);
1153        assert_eq!(m1.len(), 3);
1154        assert!((m0[0] - 0.75).abs() < 1e-12);
1155        assert!((m1[2] - 0.5).abs() < 1e-12);
1156    }
1157
1158    #[test]
1159    fn clique_marginal_normalises() {
1160        let c = cfg(vec![2, 2]);
1161        let factors = vec![(vec![0, 1], vec![1.0, 0.3, 0.4, 2.0])];
1162        let mut jt = JunctionTree::build(&c, &factors).expect("build");
1163        jt.calibrate().expect("cal");
1164        let cm = jt.clique_marginal(0).expect("cm");
1165        let s: f64 = cm.iter().sum();
1166        assert!((s - 1.0).abs() < 1e-12, "sum {s}");
1167    }
1168
1169    #[test]
1170    fn no_factors_uniform_marginals() {
1171        // No factors at all: every clique is a singleton with uniform potential.
1172        let c = cfg(vec![2, 3]);
1173        let mut jt = JunctionTree::build(&c, &[]).expect("build");
1174        jt.calibrate().expect("cal");
1175        let m0 = jt.marginal(0).expect("m0");
1176        let m1 = jt.marginal(1).expect("m1");
1177        for v in &m0 {
1178            assert!((v - 0.5).abs() < 1e-12);
1179        }
1180        for v in &m1 {
1181            assert!((v - 1.0 / 3.0).abs() < 1e-12);
1182        }
1183    }
1184}