1use super::mrf::Mrf;
28use crate::error::{SeqError, SeqResult};
29
/// Describes the discrete variables a [`JunctionTree`] is built over.
#[derive(Clone, Debug)]
pub struct JunctionTreeConfig {
    /// Number of variables; factors refer to variables by index `0..n_vars`.
    pub n_vars: usize,
    /// Number of states of each variable, indexed by variable id.
    /// `JunctionTree::build` rejects configs where this length differs from
    /// `n_vars` or where any entry is zero.
    pub cardinalities: Vec<usize>,
}
38
/// One clique of the junction tree together with its potential table.
#[derive(Clone, Debug)]
pub struct Clique {
    /// Ids of the variables in this clique (the builder stores them sorted
    /// in ascending order).
    pub vars: Vec<usize>,
    /// Log-space potential with one entry per joint configuration of `vars`,
    /// laid out row-major with the last variable varying fastest.
    pub potential: Vec<f64>,
}
51
/// An edge of the clique tree: the variables two adjacent cliques share and
/// the potential stored on that edge.
#[derive(Clone, Debug)]
struct Separator {
    /// Index of the first clique joined by this separator.
    clique_a: usize,
    /// Index of the second clique joined by this separator.
    clique_b: usize,
    /// Variables common to both cliques (may be empty).
    vars: Vec<usize>,
    /// Log-space potential over the joint configurations of `vars`.
    potential: Vec<f64>,
}
65
66#[derive(Debug, Clone)]
68pub struct JunctionTree {
69 cfg: JunctionTreeConfig,
70 cliques: Vec<Clique>,
71 separators: Vec<Separator>,
72 adjacency: Vec<Vec<(usize, usize)>>,
74 bfs_order: Vec<usize>,
77 parent: Vec<usize>,
79 parent_sep: Vec<usize>,
81}
82
/// Returns the number of joint configurations of `vars`, i.e. the product of
/// their cardinalities (`cards[v]` for each `v`). The product saturates at
/// `usize::MAX` instead of overflowing; an empty `vars` yields 1.
fn config_count(vars: &[usize], cards: &[usize]) -> usize {
    vars.iter()
        .fold(1usize, |total, &v| total.saturating_mul(cards[v]))
}
91
/// Decodes the flat table index `idx` into one state per variable of `vars`,
/// writing the state of `vars[k]` to `out[k]`.
///
/// The encoding is mixed-radix with the last variable varying fastest, so the
/// digits are peeled off from the back.
fn decode_index(mut idx: usize, vars: &[usize], cards: &[usize], out: &mut [usize]) {
    for (slot, &var) in out[..vars.len()].iter_mut().zip(vars.iter()).rev() {
        let radix = cards[var];
        *slot = idx % radix;
        idx /= radix;
    }
}
101
/// Computes the flat index, within a table over `sub_vars`, of the
/// assignment obtained by restricting `super_states` (one state per entry of
/// `super_vars`) to `sub_vars`.
///
/// `sub_vars` must appear in `super_vars` in the same relative order: the
/// search cursor only moves forward, and a missing variable runs it off the
/// end of `super_vars` (index panic).
fn project_index(
    super_vars: &[usize],
    super_states: &[usize],
    sub_vars: &[usize],
    cards: &[usize],
) -> usize {
    let mut cursor = 0usize;
    sub_vars.iter().fold(0usize, |flat, &var| {
        // Advance to the position of `var` inside the larger scope.
        while super_vars[cursor] != var {
            cursor += 1;
        }
        flat * cards[var] + super_states[cursor]
    })
}
122
/// Computes `ln(sum_i exp(xs[i]))` in a numerically stable way by factoring
/// out the maximum element.
///
/// Returns `-inf` for an empty slice or when every entry is `-inf`, and
/// `+inf` when any entry is `+inf`. The infinite-maximum cases are answered
/// directly because shifting by an infinite maximum would evaluate
/// `inf - inf` and poison the result with NaN.
fn log_sum_exp(xs: &[f64]) -> f64 {
    let max = xs
        .iter()
        .copied()
        .fold(f64::NEG_INFINITY, |m, x| if x > m { x } else { m });
    if max.is_infinite() {
        return max;
    }
    let shifted_sum = xs.iter().fold(0.0, |acc, &x| acc + (x - max).exp());
    max + shifted_sum.ln()
}
140
141impl JunctionTree {
142 pub fn build(cfg: &JunctionTreeConfig, factors: &[(Vec<usize>, Vec<f64>)]) -> SeqResult<Self> {
149 if cfg.n_vars == 0 {
150 return Err(SeqError::InvalidConfiguration(
151 "n_vars must be >= 1".to_string(),
152 ));
153 }
154 if cfg.cardinalities.len() != cfg.n_vars {
155 return Err(SeqError::ShapeMismatch {
156 expected: cfg.n_vars,
157 got: cfg.cardinalities.len(),
158 });
159 }
160 for &c in &cfg.cardinalities {
161 if c == 0 {
162 return Err(SeqError::InvalidConfiguration(
163 "every cardinality must be >= 1".to_string(),
164 ));
165 }
166 }
167 for (vars, table) in factors {
168 for &v in vars {
169 if v >= cfg.n_vars {
170 return Err(SeqError::IndexOutOfBounds {
171 index: v,
172 len: cfg.n_vars,
173 });
174 }
175 }
176 let expected = config_count(vars, &cfg.cardinalities);
177 if table.len() != expected {
178 return Err(SeqError::ShapeMismatch {
179 expected,
180 got: table.len(),
181 });
182 }
183 }
184
185 let cards = &cfg.cardinalities;
186 let n = cfg.n_vars;
187
188 let mut adj = vec![vec![false; n]; n];
190 for (vars, _) in factors {
191 for a in 0..vars.len() {
192 for b in (a + 1)..vars.len() {
193 let (u, w) = (vars[a], vars[b]);
194 if u != w {
195 adj[u][w] = true;
196 adj[w][u] = true;
197 }
198 }
199 }
200 }
201
202 let candidate_cliques = Self::eliminate_for_cliques(&adj, cards);
204
205 let maximal = Self::keep_maximal(candidate_cliques);
207
208 let (adjacency, separators) = Self::build_clique_tree(&maximal, cards);
210
211 let mut cliques: Vec<Clique> = maximal
213 .into_iter()
214 .map(|vars| {
215 let len = config_count(&vars, cards);
216 Clique {
217 vars,
218 potential: vec![0.0; len],
219 }
220 })
221 .collect();
222
223 for (vars, table) in factors {
225 let mut sorted = vars.clone();
226 sorted.sort_unstable();
227 sorted.dedup();
228 let target = cliques
229 .iter()
230 .position(|c| sorted.iter().all(|v| c.vars.contains(v)));
231 let target = match target {
232 Some(t) => t,
233 None => {
234 return Err(SeqError::GraphInvariantViolated(format!(
235 "factor scope {sorted:?} not contained in any clique"
236 )));
237 }
238 };
239 Self::multiply_factor_into_clique(&mut cliques[target], vars, table, cards);
240 }
241
242 let (bfs_order, parent, parent_sep) = Self::root_tree(cliques.len(), &adjacency);
244
245 Ok(Self {
246 cfg: cfg.clone(),
247 cliques,
248 separators,
249 adjacency,
250 bfs_order,
251 parent,
252 parent_sep,
253 })
254 }
255
256 pub fn from_mrf(mrf: &Mrf) -> SeqResult<Self> {
261 let cfg = JunctionTreeConfig {
262 n_vars: mrf.n_nodes,
263 cardinalities: vec![mrf.n_labels; mrf.n_nodes],
264 };
265 let nl = mrf.n_labels;
266 let l2 = nl * nl;
267 let mut factors: Vec<(Vec<usize>, Vec<f64>)> = Vec::new();
268 for i in 0..mrf.n_nodes {
269 let mut table = vec![0.0; nl];
270 for l in 0..nl {
271 table[l] = (-mrf.unary[i * nl + l]).exp();
272 }
273 factors.push((vec![i], table));
274 }
275 for (e_idx, &(u, v)) in mrf.edges.iter().enumerate() {
276 let (lo, hi) = if u < v { (u, v) } else { (v, u) };
277 let mut table = vec![0.0; l2];
278 for a in 0..nl {
280 for b in 0..nl {
281 let (lu, lv) = if u == lo { (a, b) } else { (b, a) };
283 table[a * nl + b] = (-mrf.pairwise[e_idx * l2 + lu * nl + lv]).exp();
284 }
285 }
286 factors.push((vec![lo, hi], table));
287 }
288 Self::build(&cfg, &factors)
289 }
290
/// Runs greedy variable elimination on the interaction graph `adj` and
/// returns one candidate clique (sorted variable ids) per eliminated
/// variable.
///
/// At each step the remaining variable that needs the fewest fill-in edges
/// is chosen, ties broken by smaller degree and then by smaller index. Its
/// remaining neighbours are connected pairwise before it is removed.
/// `cards` is accepted but not consulted.
fn eliminate_for_cliques(adj: &[Vec<bool>], cards: &[usize]) -> Vec<Vec<usize>> {
    /// Remaining neighbours of `v` in the working graph, ascending.
    fn live_neighbours(graph: &[Vec<bool>], live: &[bool], v: usize) -> Vec<usize> {
        (0..graph.len())
            .filter(|&u| live[u] && u != v && graph[v][u])
            .collect()
    }

    let _ = cards;
    let n = adj.len();
    let mut graph = adj.to_vec();
    let mut live = vec![true; n];
    let mut cliques: Vec<Vec<usize>> = Vec::new();

    for _ in 0..n {
        // (fill-in count, degree, variable) of the best candidate so far.
        let mut best: Option<(usize, usize, usize)> = None;
        for v in (0..n).filter(|&v| live[v]) {
            let nbrs = live_neighbours(&graph, &live, v);
            let degree = nbrs.len();
            let fill: usize = nbrs
                .iter()
                .enumerate()
                .map(|(i, &a)| nbrs[i + 1..].iter().filter(|&&b| !graph[a][b]).count())
                .sum();
            let improves = match best {
                None => true,
                Some((best_fill, best_deg, _)) => (fill, degree) < (best_fill, best_deg),
            };
            if improves {
                best = Some((fill, degree, v));
            }
        }
        let chosen = match best {
            Some((_, _, v)) => v,
            None => break,
        };

        let nbrs = live_neighbours(&graph, &live, chosen);
        let mut clique = nbrs.clone();
        clique.push(chosen);
        clique.sort_unstable();
        cliques.push(clique);

        // Eliminating `chosen` makes its neighbours mutually adjacent.
        for (i, &a) in nbrs.iter().enumerate() {
            for &b in &nbrs[i + 1..] {
                graph[a][b] = true;
                graph[b][a] = true;
            }
        }
        live[chosen] = false;
    }

    cliques
}
358
/// Drops every clique that is a subset of an already-kept clique.
///
/// Cliques are visited from largest to smallest (stable with respect to the
/// input order), so the result contains only the maximal ones and at most one
/// copy of each duplicate.
fn keep_maximal(mut cliques: Vec<Vec<usize>>) -> Vec<Vec<usize>> {
    cliques.sort_by(|a, b| b.len().cmp(&a.len()));
    let mut kept: Vec<Vec<usize>> = Vec::new();
    for candidate in cliques {
        let covered = kept
            .iter()
            .any(|big| candidate.iter().all(|v| big.contains(v)));
        if !covered {
            kept.push(candidate);
        }
    }
    kept
}
372
373 fn build_clique_tree(
378 cliques: &[Vec<usize>],
379 cards: &[usize],
380 ) -> (Vec<Vec<(usize, usize)>>, Vec<Separator>) {
381 let m = cliques.len();
382 let mut adjacency: Vec<Vec<(usize, usize)>> = vec![Vec::new(); m];
383 let mut separators: Vec<Separator> = Vec::new();
384 if m <= 1 {
385 return (adjacency, separators);
386 }
387
388 let mut edges: Vec<(usize, usize, usize)> = Vec::new();
390 for a in 0..m {
391 for b in (a + 1)..m {
392 let shared = shared_vars(&cliques[a], &cliques[b]);
393 edges.push((shared.len(), a, b));
394 }
395 }
396 edges.sort_by_key(|e| std::cmp::Reverse(e.0));
397
398 let mut parent: Vec<usize> = (0..m).collect();
400 fn find(parent: &mut [usize], x: usize) -> usize {
401 let mut r = x;
402 while parent[r] != r {
403 r = parent[r];
404 }
405 let mut c = x;
407 while parent[c] != r {
408 let next = parent[c];
409 parent[c] = r;
410 c = next;
411 }
412 r
413 }
414
415 for (_w, a, b) in edges {
416 let ra = find(&mut parent, a);
417 let rb = find(&mut parent, b);
418 if ra == rb {
419 continue;
420 }
421 parent[ra] = rb;
422 let shared = shared_vars(&cliques[a], &cliques[b]);
423 let len = config_count(&shared, cards);
424 let sep_idx = separators.len();
425 separators.push(Separator {
426 clique_a: a,
427 clique_b: b,
428 vars: shared,
429 potential: vec![0.0; len],
430 });
431 adjacency[a].push((b, sep_idx));
432 adjacency[b].push((a, sep_idx));
433 }
434
435 (adjacency, separators)
436 }
437
438 fn multiply_factor_into_clique(
442 clique: &mut Clique,
443 factor_vars: &[usize],
444 factor_table: &[f64],
445 cards: &[usize],
446 ) {
447 let len = clique.potential.len();
448 let mut states = vec![0usize; clique.vars.len()];
449 let positions: Vec<usize> = factor_vars
453 .iter()
454 .filter_map(|fv| clique.vars.binary_search(fv).ok())
455 .collect();
456 if positions.len() != factor_vars.len() {
457 return;
459 }
460 for idx in 0..len {
461 decode_index(idx, &clique.vars, cards, &mut states);
462 let mut fidx = 0usize;
464 for (k, &fv) in factor_vars.iter().enumerate() {
465 fidx = fidx * cards[fv] + states[positions[k]];
466 }
467 let val = factor_table[fidx];
468 clique.potential[idx] += if val > 0.0 {
469 val.ln()
470 } else {
471 f64::NEG_INFINITY
472 };
473 }
474 }
475
/// Roots the clique graph by breadth-first search.
///
/// Every not-yet-visited clique (in index order) starts a new search and
/// becomes a root. Returns `(visit order, parent, parent separator)`; roots
/// keep `usize::MAX` in the latter two.
fn root_tree(
    m: usize,
    adjacency: &[Vec<(usize, usize)>],
) -> (Vec<usize>, Vec<usize>, Vec<usize>) {
    let mut seen = vec![false; m];
    let mut parent = vec![usize::MAX; m];
    let mut parent_sep = vec![usize::MAX; m];
    // `order` doubles as the FIFO queue: entries before `head` are processed,
    // entries from `head` onwards are discovered but still pending.
    let mut order: Vec<usize> = Vec::with_capacity(m);
    let mut head = 0usize;

    for root in 0..m {
        if seen[root] {
            continue;
        }
        seen[root] = true;
        order.push(root);
        while head < order.len() {
            let current = order[head];
            head += 1;
            for &(nbr, sep) in &adjacency[current] {
                if seen[nbr] {
                    continue;
                }
                seen[nbr] = true;
                parent[nbr] = current;
                parent_sep[nbr] = sep;
                order.push(nbr);
            }
        }
    }

    (order, parent, parent_sep)
}
508
509 fn marginalise_to_separator(&self, clique_idx: usize, sep_vars: &[usize]) -> Vec<f64> {
512 let clique = &self.cliques[clique_idx];
513 let cards = &self.cfg.cardinalities;
514 let sep_len = config_count(sep_vars, cards);
515 let mut buckets: Vec<Vec<f64>> = vec![Vec::new(); sep_len];
517 let mut states = vec![0usize; clique.vars.len()];
518 for idx in 0..clique.potential.len() {
519 decode_index(idx, &clique.vars, cards, &mut states);
520 let sidx = project_index(&clique.vars, &states, sep_vars, cards);
521 buckets[sidx].push(clique.potential[idx]);
522 }
523 let mut out = vec![f64::NEG_INFINITY; sep_len];
524 for (s, bucket) in buckets.iter().enumerate() {
525 out[s] = log_sum_exp(bucket);
526 }
527 out
528 }
529
530 fn absorb_message_into_clique(&mut self, clique_idx: usize, sep_idx: usize, delta: &[f64]) {
533 let sep_vars = self.separators[sep_idx].vars.clone();
534 let cards = self.cfg.cardinalities.clone();
535 let clique_vars = self.cliques[clique_idx].vars.clone();
536 let mut states = vec![0usize; clique_vars.len()];
537 let len = self.cliques[clique_idx].potential.len();
538 for idx in 0..len {
539 decode_index(idx, &clique_vars, &cards, &mut states);
540 let sidx = project_index(&clique_vars, &states, &sep_vars, &cards);
541 self.cliques[clique_idx].potential[idx] += delta[sidx];
542 }
543 }
544
545 pub fn calibrate(&mut self) -> SeqResult<()> {
549 if self.cliques.is_empty() {
550 return Ok(());
551 }
552
553 let order = self.bfs_order.clone();
556 for &c in order.iter().rev() {
557 let p = self.parent[c];
558 if p == usize::MAX {
559 continue; }
561 let sep_idx = self.parent_sep[c];
562 let sep_vars = self.separators[sep_idx].vars.clone();
563 let new_sep = self.marginalise_to_separator(c, &sep_vars);
565 let old_sep = self.separators[sep_idx].potential.clone();
567 let delta: Vec<f64> = new_sep
568 .iter()
569 .zip(old_sep.iter())
570 .map(|(&a, &b)| safe_log_sub(a, b))
571 .collect();
572 self.absorb_message_into_clique(p, sep_idx, &delta);
573 self.separators[sep_idx].potential = new_sep;
574 }
575
576 for &c in order.iter() {
579 let children: Vec<(usize, usize)> = self.adjacency[c]
581 .iter()
582 .filter(|&&(nbr, _)| self.parent[nbr] == c)
583 .copied()
584 .collect();
585 for (child, sep_idx) in children {
586 let sep_vars = self.separators[sep_idx].vars.clone();
587 let new_sep = self.marginalise_to_separator(c, &sep_vars);
588 let old_sep = self.separators[sep_idx].potential.clone();
589 let delta: Vec<f64> = new_sep
590 .iter()
591 .zip(old_sep.iter())
592 .map(|(&a, &b)| safe_log_sub(a, b))
593 .collect();
594 self.absorb_message_into_clique(child, sep_idx, &delta);
595 self.separators[sep_idx].potential = new_sep;
596 }
597 }
598
599 Ok(())
600 }
601
602 pub fn marginal(&self, var: usize) -> SeqResult<Vec<f64>> {
605 if var >= self.cfg.n_vars {
606 return Err(SeqError::IndexOutOfBounds {
607 index: var,
608 len: self.cfg.n_vars,
609 });
610 }
611 let card = self.cfg.cardinalities[var];
612 let clique_idx = self
614 .cliques
615 .iter()
616 .position(|c| c.vars.contains(&var))
617 .ok_or_else(|| {
618 SeqError::GraphInvariantViolated(format!(
619 "variable {var} not present in any clique"
620 ))
621 })?;
622 let log_marg = self.marginalise_to_separator(clique_idx, &[var]);
623 debug_assert_eq!(log_marg.len(), card);
624 let logz = log_sum_exp(&log_marg);
626 let mut out = vec![0.0; card];
627 if logz == f64::NEG_INFINITY {
628 let u = 1.0 / card as f64;
630 for v in out.iter_mut() {
631 *v = u;
632 }
633 return Ok(out);
634 }
635 for l in 0..card {
636 out[l] = (log_marg[l] - logz).exp();
637 }
638 Ok(out)
639 }
640
641 pub fn clique_marginal(&self, clique_idx: usize) -> SeqResult<Vec<f64>> {
644 if clique_idx >= self.cliques.len() {
645 return Err(SeqError::IndexOutOfBounds {
646 index: clique_idx,
647 len: self.cliques.len(),
648 });
649 }
650 let pot = &self.cliques[clique_idx].potential;
651 let logz = log_sum_exp(pot);
652 let mut out = vec![0.0; pot.len()];
653 if logz == f64::NEG_INFINITY {
654 let u = 1.0 / pot.len().max(1) as f64;
655 for v in out.iter_mut() {
656 *v = u;
657 }
658 return Ok(out);
659 }
660 for (o, &p) in out.iter_mut().zip(pot.iter()) {
661 *o = (p - logz).exp();
662 }
663 Ok(out)
664 }
665
666 pub fn log_partition(&self) -> SeqResult<f64> {
672 if self.cliques.is_empty() {
673 return Err(SeqError::GraphInvariantViolated(
674 "junction tree has no cliques".to_string(),
675 ));
676 }
677 Ok(log_sum_exp(&self.cliques[0].potential))
678 }
679
680 pub fn n_cliques(&self) -> usize {
682 self.cliques.len()
683 }
684
685 pub fn n_separators(&self) -> usize {
687 self.separators.len()
688 }
689
690 pub fn cliques(&self) -> &[Clique] {
692 &self.cliques
693 }
694
695 pub fn separator_vars(&self, sep_idx: usize) -> SeqResult<&[usize]> {
698 if sep_idx >= self.separators.len() {
699 return Err(SeqError::IndexOutOfBounds {
700 index: sep_idx,
701 len: self.separators.len(),
702 });
703 }
704 Ok(&self.separators[sep_idx].vars)
705 }
706
707 pub fn separator_cliques(&self, sep_idx: usize) -> SeqResult<(usize, usize)> {
709 if sep_idx >= self.separators.len() {
710 return Err(SeqError::IndexOutOfBounds {
711 index: sep_idx,
712 len: self.separators.len(),
713 });
714 }
715 Ok((
716 self.separators[sep_idx].clique_a,
717 self.separators[sep_idx].clique_b,
718 ))
719 }
720}
721
/// Returns the elements common to `a` and `b` using a two-cursor merge.
/// Both inputs are expected to be sorted in ascending order; the output is
/// then sorted as well.
fn shared_vars(a: &[usize], b: &[usize]) -> Vec<usize> {
    let mut common = Vec::new();
    let (mut ia, mut ib) = (0usize, 0usize);
    while let (Some(&x), Some(&y)) = (a.get(ia), b.get(ib)) {
        if x < y {
            ia += 1;
        } else if x > y {
            ib += 1;
        } else {
            common.push(x);
            ia += 1;
            ib += 1;
        }
    }
    common
}
739
/// Log-space division `a - b` with guards for empty mass:
/// a numerator of `-inf` stays `-inf`, and a denominator of `-inf` leaves
/// the numerator unchanged instead of producing `+inf` or NaN.
fn safe_log_sub(a: f64, b: f64) -> f64 {
    if a == f64::NEG_INFINITY {
        f64::NEG_INFINITY
    } else if b == f64::NEG_INFINITY {
        a
    } else {
        a - b
    }
}
753
#[cfg(test)]
mod tests {
    use super::*;

    /// A factor in the shape `JunctionTree::build` accepts: (scope, table).
    type Factor = (Vec<usize>, Vec<f64>);

    /// Config whose variable count is the number of cardinalities given.
    fn cfg(cards: Vec<usize>) -> JunctionTreeConfig {
        let n_vars = cards.len();
        JunctionTreeConfig {
            n_vars,
            cardinalities: cards,
        }
    }

    /// Enumerates every joint assignment (last variable fastest) and hands
    /// the assignment plus the product of all factor values to `visit`.
    fn for_each_joint(cards: &[usize], factors: &[Factor], mut visit: impl FnMut(&[usize], f64)) {
        let total: usize = cards.iter().product();
        let mut states = vec![0usize; cards.len()];
        for joint in 0..total {
            let mut rem = joint;
            for (state, &card) in states.iter_mut().zip(cards).rev() {
                *state = rem % card;
                rem /= card;
            }
            let mut p = 1.0;
            for (vars, table) in factors {
                let idx = vars
                    .iter()
                    .fold(0usize, |acc, &v| acc * cards[v] + states[v]);
                p *= table[idx];
            }
            visit(&states, p);
        }
    }

    /// Exact marginal of `var` by summing over all joint assignments.
    fn brute_force_marginal(
        cards: &[usize],
        factors: &[(Vec<usize>, Vec<f64>)],
        var: usize,
    ) -> Vec<f64> {
        let mut marg = vec![0.0; cards[var]];
        for_each_joint(cards, factors, |states, p| marg[states[var]] += p);
        let s: f64 = marg.iter().sum();
        if s > 0.0 {
            marg.iter_mut().for_each(|m| *m /= s);
        }
        marg
    }

    /// Exact log partition function by summing over all joint assignments.
    fn brute_force_log_z(cards: &[usize], factors: &[(Vec<usize>, Vec<f64>)]) -> f64 {
        let mut z = 0.0;
        for_each_joint(cards, factors, |_, p| z += p);
        z.ln()
    }

    /// Checks every variable's marginal against brute force within `tol`.
    fn assert_marginals_match(jt: &JunctionTree, cards: &[usize], factors: &[Factor], tol: f64) {
        for var in 0..cards.len() {
            let m = jt.marginal(var).expect("marg");
            let bf = brute_force_marginal(cards, factors, var);
            for (a, b) in m.iter().zip(bf.iter()) {
                assert!((a - b).abs() < tol, "var {var}: {a} vs {b}");
            }
        }
    }

    /// Builds and calibrates a tree, panicking with the usual labels.
    fn calibrated(c: &JunctionTreeConfig, factors: &[Factor]) -> JunctionTree {
        let mut jt = JunctionTree::build(c, factors).expect("build");
        jt.calibrate().expect("cal");
        jt
    }

    /// Pairwise factors 0-1, 1-2, 2-3 with identical symmetric tables.
    fn four_var_chain() -> Vec<Factor> {
        [(0usize, 1usize), (1, 2), (2, 3)]
            .iter()
            .map(|&(u, v)| (vec![u, v], vec![1.0, 0.5, 0.5, 1.0]))
            .collect()
    }

    /// Pairwise factors 0-1 and 1-2 over three binary variables.
    fn three_var_chain() -> Vec<Factor> {
        vec![
            (vec![0, 1], vec![1.0, 0.3, 0.4, 2.0]),
            (vec![1, 2], vec![1.5, 0.6, 0.2, 1.1]),
        ]
    }

    #[test]
    fn single_factor_one_clique() {
        let factors = vec![(vec![0, 1], vec![1.0, 2.0, 3.0, 4.0])];
        let jt = JunctionTree::build(&cfg(vec![2, 2]), &factors).expect("build");
        assert_eq!(jt.n_cliques(), 1);
        assert_eq!(jt.n_separators(), 0);
    }

    #[test]
    fn single_var_factor_marginal_equals_normalised_potential() {
        let factors = vec![(vec![0], vec![1.0, 2.0, 1.0])];
        let jt = calibrated(&cfg(vec![3]), &factors);
        let m = jt.marginal(0).expect("marg");
        for (a, b) in m.iter().zip([0.25, 0.5, 0.25].iter()) {
            assert!((a - b).abs() < 1e-12, "{a} vs {b}");
        }
    }

    #[test]
    fn chain_marginals_match_brute_force() {
        let c = cfg(vec![2, 2, 2]);
        let mut factors: Vec<Factor> = vec![(vec![0], vec![0.7, 1.3])];
        factors.extend(three_var_chain());
        let jt = calibrated(&c, &factors);
        assert_marginals_match(&jt, &c.cardinalities, &factors, 1e-6);
    }

    #[test]
    fn chain_marginal_sums_to_one() {
        let c = cfg(vec![3, 2, 3]);
        let factors = vec![
            (vec![0, 1], vec![1.0, 0.3, 0.4, 2.0, 0.5, 1.2]),
            (vec![1, 2], vec![1.5, 0.6, 0.2, 1.1, 0.9, 0.7]),
        ];
        let jt = calibrated(&c, &factors);
        for var in 0..3 {
            let s: f64 = jt.marginal(var).expect("marg").iter().sum();
            assert!((s - 1.0).abs() < 1e-9, "var {var} sum {s}");
        }
    }

    #[test]
    fn log_partition_matches_brute_force() {
        let c = cfg(vec![2, 3, 2]);
        let factors = vec![
            (vec![0, 1], vec![1.0, 0.3, 0.4, 2.0, 0.8, 0.5]),
            (vec![1, 2], vec![1.5, 0.6, 0.2, 1.1, 0.9, 0.7]),
            (vec![2], vec![1.2, 0.8]),
        ];
        let jt = calibrated(&c, &factors);
        let lz = jt.log_partition().expect("logz");
        let bf = brute_force_log_z(&c.cardinalities, &factors);
        assert!((lz - bf).abs() < 1e-6, "logZ {lz} vs {bf}");
    }

    #[test]
    fn independent_variables_product_marginals() {
        let factors = vec![
            (vec![0], vec![1.0, 3.0]),
            (vec![1], vec![2.0, 2.0, 4.0]),
        ];
        let jt = calibrated(&cfg(vec![2, 3]), &factors);
        let m0 = jt.marginal(0).expect("m0");
        let m1 = jt.marginal(1).expect("m1");
        assert!((m0[0] - 0.25).abs() < 1e-12);
        assert!((m0[1] - 0.75).abs() < 1e-12);
        assert!((m1[0] - 0.25).abs() < 1e-12);
        assert!((m1[1] - 0.25).abs() < 1e-12);
        assert!((m1[2] - 0.5).abs() < 1e-12);
    }

    #[test]
    fn disconnected_factors_handled() {
        let c = cfg(vec![2, 2, 2, 2]);
        let factors = vec![
            (vec![0, 1], vec![1.0, 0.5, 0.5, 1.0]),
            (vec![2, 3], vec![2.0, 0.1, 0.1, 2.0]),
        ];
        let jt = calibrated(&c, &factors);
        assert_marginals_match(&jt, &c.cardinalities, &factors, 1e-6);
    }

    #[test]
    fn calibrate_is_idempotent() {
        let factors = three_var_chain();
        let mut jt = JunctionTree::build(&cfg(vec![2, 2, 2]), &factors).expect("build");
        let snapshot = |jt: &JunctionTree| -> Vec<Vec<f64>> {
            (0..3).map(|v| jt.marginal(v).expect("m")).collect()
        };
        jt.calibrate().expect("cal1");
        let m_before = snapshot(&jt);
        jt.calibrate().expect("cal2");
        let m_after = snapshot(&jt);
        for (a, b) in m_before.iter().zip(m_after.iter()) {
            for (x, y) in a.iter().zip(b.iter()) {
                assert!((x - y).abs() < 1e-9, "{x} vs {y}");
            }
        }
    }

    #[test]
    fn running_intersection_on_chain() {
        let factors = four_var_chain();
        let jt = JunctionTree::build(&cfg(vec![2, 2, 2, 2]), &factors).expect("build");
        for s in 0..jt.n_separators() {
            let (a, b) = jt.separator_cliques(s).expect("sep");
            let inter = shared_vars(&jt.cliques()[a].vars, &jt.cliques()[b].vars);
            assert_eq!(jt.separator_vars(s).expect("vars"), inter.as_slice());
            assert!(
                !inter.is_empty(),
                "separator should be non-empty on a chain"
            );
        }
    }

    #[test]
    fn n_cliques_sane_for_chain() {
        let jt = JunctionTree::build(&cfg(vec![2, 2, 2, 2]), &four_var_chain()).expect("build");
        assert_eq!(jt.n_cliques(), 3);
        for cl in jt.cliques() {
            assert_eq!(cl.vars.len(), 2);
        }
    }

    #[test]
    fn ternary_cardinalities_match_brute_force() {
        let c = cfg(vec![3, 3]);
        let factors = vec![(
            vec![0, 1],
            vec![1.0, 0.2, 0.5, 0.3, 2.0, 0.4, 0.6, 0.1, 1.5],
        )];
        let jt = calibrated(&c, &factors);
        assert_marginals_match(&jt, &c.cardinalities, &factors, 1e-9);
    }

    #[test]
    fn triangle_three_var_factor_match_brute_force() {
        let c = cfg(vec![2, 2, 2]);
        let factors = vec![
            (vec![0, 1], vec![1.0, 0.5, 0.5, 1.0]),
            (vec![1, 2], vec![1.2, 0.3, 0.4, 0.9]),
            (vec![0, 2], vec![0.7, 1.1, 1.3, 0.6]),
        ];
        let jt = calibrated(&c, &factors);
        assert_marginals_match(&jt, &c.cardinalities, &factors, 1e-6);
        // A triangle collapses into a single three-variable clique.
        assert_eq!(jt.n_cliques(), 1);
        assert_eq!(jt.cliques()[0].vars, vec![0, 1, 2]);
    }

    #[test]
    fn from_mrf_matches_direct_factors() {
        let m = Mrf::new(
            3,
            2,
            vec![(0, 1), (1, 2)],
            vec![0.1, 0.5, 0.2, 0.3, 0.0, 0.4],
            vec![0.0, 0.7, 0.7, 0.0, 0.0, 0.5, 0.5, 0.0],
        )
        .expect("mrf");
        let mut jt = JunctionTree::from_mrf(&m).expect("jt");
        jt.calibrate().expect("cal");

        // Rebuild the same factors by hand from the MRF's energy tables.
        let nl = 2;
        let mut factors: Vec<(Vec<usize>, Vec<f64>)> = Vec::new();
        for i in 0..3 {
            let t: Vec<f64> = (0..nl).map(|l| (-m.unary[i * nl + l]).exp()).collect();
            factors.push((vec![i], t));
        }
        for (e, &(u, v)) in m.edges.iter().enumerate() {
            let t: Vec<f64> = (0..nl * nl)
                .map(|ab| (-m.pairwise[e * nl * nl + ab]).exp())
                .collect();
            factors.push((vec![u, v], t));
        }

        assert_marginals_match(&jt, &[nl; 3], &factors, 1e-6);
    }

    #[test]
    fn deterministic_build_and_calibrate() {
        let c = cfg(vec![2, 2, 2]);
        let factors = three_var_chain();
        let mut a = JunctionTree::build(&c, &factors).expect("a");
        let mut b = JunctionTree::build(&c, &factors).expect("b");
        a.calibrate().expect("ca");
        b.calibrate().expect("cb");
        for var in 0..3 {
            let ma = a.marginal(var).expect("ma");
            let mb = b.marginal(var).expect("mb");
            assert_eq!(ma, mb);
        }
    }

    #[test]
    fn err_cardinality_mismatch_with_factor_table() {
        // Table has 3 entries where a 2x2 scope needs 4.
        let factors = vec![(vec![0, 1], vec![1.0, 2.0, 3.0])];
        let r = JunctionTree::build(&cfg(vec![2, 2]), &factors);
        assert!(matches!(r, Err(SeqError::ShapeMismatch { .. })));
    }

    #[test]
    fn err_var_out_of_range_in_factor() {
        let factors = vec![(vec![0, 5], vec![1.0, 2.0, 3.0, 4.0])];
        let r = JunctionTree::build(&cfg(vec![2, 2]), &factors);
        assert!(matches!(r, Err(SeqError::IndexOutOfBounds { .. })));
    }

    #[test]
    fn err_empty_cardinalities_mismatch() {
        let c = JunctionTreeConfig {
            n_vars: 2,
            cardinalities: vec![2],
        };
        let r = JunctionTree::build(&c, &[]);
        assert!(matches!(r, Err(SeqError::ShapeMismatch { .. })));
    }

    #[test]
    fn err_n_vars_zero() {
        let c = JunctionTreeConfig {
            n_vars: 0,
            cardinalities: vec![],
        };
        let r = JunctionTree::build(&c, &[]);
        assert!(matches!(r, Err(SeqError::InvalidConfiguration(_))));
    }

    #[test]
    fn err_zero_cardinality() {
        let c = JunctionTreeConfig {
            n_vars: 2,
            cardinalities: vec![2, 0],
        };
        let r = JunctionTree::build(&c, &[]);
        assert!(matches!(r, Err(SeqError::InvalidConfiguration(_))));
    }

    #[test]
    fn err_marginal_var_out_of_range() {
        let factors = vec![(vec![0, 1], vec![1.0, 1.0, 1.0, 1.0])];
        let jt = JunctionTree::build(&cfg(vec![2, 2]), &factors).expect("build");
        let r = jt.marginal(5);
        assert!(matches!(r, Err(SeqError::IndexOutOfBounds { .. })));
    }

    #[test]
    fn binary_vs_ternary_isolated_factors() {
        let c = cfg(vec![2, 3]);
        let f0 = (vec![0], vec![3.0, 1.0]);
        let f1 = (vec![1], vec![1.0, 1.0, 2.0]);
        let mut jt = JunctionTree::build(&c, &[f0, f1]).expect("build");
        jt.calibrate().expect("cal");
        let m0 = jt.marginal(0).expect("m0");
        let m1 = jt.marginal(1).expect("m1");
        assert_eq!(m0.len(), 2);
        assert_eq!(m1.len(), 3);
        assert!((m0[0] - 0.75).abs() < 1e-12);
        assert!((m1[2] - 0.5).abs() < 1e-12);
    }

    #[test]
    fn clique_marginal_normalises() {
        let factors = vec![(vec![0, 1], vec![1.0, 0.3, 0.4, 2.0])];
        let jt = calibrated(&cfg(vec![2, 2]), &factors);
        let s: f64 = jt.clique_marginal(0).expect("cm").iter().sum();
        assert!((s - 1.0).abs() < 1e-12, "sum {s}");
    }

    #[test]
    fn no_factors_uniform_marginals() {
        let c = cfg(vec![2, 3]);
        let mut jt = JunctionTree::build(&c, &[]).expect("build");
        jt.calibrate().expect("cal");
        let m0 = jt.marginal(0).expect("m0");
        let m1 = jt.marginal(1).expect("m1");
        for v in &m0 {
            assert!((v - 0.5).abs() < 1e-12);
        }
        for v in &m1 {
            assert!((v - 1.0 / 3.0).abs() < 1e-12);
        }
    }
}