1use super::factors::Factor;
28use super::types::{AntecedentValue, Assignment, JointVariable, LinkValue, VariableId};
29use crate::EntityType;
30use std::collections::HashMap;
31
/// Tuning knobs for [`BeliefPropagation::run`].
#[derive(Clone, Debug)]
pub struct InferenceConfig {
    /// Upper bound on the number of message-passing rounds performed by `run`.
    pub max_iterations: usize,
    /// `run` stops early once the largest per-entry change between successive
    /// messages (in log space) falls below this value.
    pub convergence_threshold: f64,
    /// Weight given to the previous message when blending it with a freshly
    /// computed one in log space; `0.0` disables damping entirely.
    pub damping: f64,
    /// Order in which messages are updated within a round.
    pub schedule: MessageSchedule,
}

impl Default for InferenceConfig {
    /// Five undamped parallel rounds with a `1e-4` convergence threshold.
    fn default() -> Self {
        InferenceConfig {
            schedule: MessageSchedule::Parallel,
            damping: 0.0,
            convergence_threshold: 1e-4,
            max_iterations: 5,
        }
    }
}

/// Update order used for each round of message passing.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MessageSchedule {
    /// Every message of a round is computed from the previous round's values,
    /// then all of them are committed together.
    Parallel,
    /// Each message is committed as soon as it is computed, so later updates in
    /// the same round already see it.
    Sequential,
}
68
/// A belief-propagation message: one log-probability per domain value of the
/// variable it concerns.
#[derive(Debug, Clone)]
pub struct Message {
    /// Log-space weights, indexed by domain-value position.
    pub log_probs: Vec<f64>,
}

impl Message {
    /// Uniform distribution over `domain_size` values (empty when `domain_size` is 0).
    pub fn uniform(domain_size: usize) -> Self {
        if domain_size == 0 {
            return Self { log_probs: vec![] };
        }
        let log_prob = -(domain_size as f64).ln();
        Self {
            log_probs: vec![log_prob; domain_size],
        }
    }

    /// Wraps raw (not necessarily normalized) log-probabilities.
    pub fn from_log_probs(log_probs: Vec<f64>) -> Self {
        Self { log_probs }
    }

    /// Shifts the log-probabilities so they exponentiate to a sum of one.
    ///
    /// Left unchanged when the total mass is not finite (e.g. every entry is
    /// `-inf`), since there is no meaningful normalizer in that case.
    pub fn normalize(&mut self) {
        if self.log_probs.is_empty() {
            return;
        }
        let log_sum = log_sum_exp(&self.log_probs);
        if log_sum.is_finite() {
            for lp in &mut self.log_probs {
                *lp -= log_sum;
            }
        }
    }

    /// Largest absolute element-wise difference from `other`, in log space.
    ///
    /// Only the overlapping prefix is compared. NaN differences (e.g. `-inf`
    /// minus `-inf`) are ignored because `f64::max` returns the non-NaN operand.
    pub fn max_change(&self, other: &Message) -> f64 {
        self.log_probs
            .iter()
            .zip(other.log_probs.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0, f64::max)
    }

    /// Blends this message with `previous` in log space:
    /// `new = (1 - damping) * new + damping * previous`.
    ///
    /// A term whose weight is exactly zero is dropped rather than multiplied,
    /// because `0.0 * ±inf` is NaN and would poison the message whenever one
    /// side carries an infinite (hard-constraint) log-probability.
    pub fn damp(&mut self, previous: &Message, damping: f64) {
        let keep = 1.0 - damping;
        for (new, &old) in self.log_probs.iter_mut().zip(previous.log_probs.iter()) {
            let fresh = if keep == 0.0 { 0.0 } else { keep * *new };
            let stale = if damping == 0.0 { 0.0 } else { damping * old };
            *new = fresh + stale;
        }
    }

    /// Element-wise product of two messages (a sum in log space), truncated to
    /// the shorter one.
    // Kept as part of the message algebra even though inference does not
    // currently call it; the allow silences the unused-code lint.
    #[allow(dead_code)]
    pub fn multiply(&self, other: &Message) -> Message {
        let log_probs: Vec<f64> = self
            .log_probs
            .iter()
            .zip(other.log_probs.iter())
            .map(|(a, b)| a + b)
            .collect();
        Message { log_probs }
    }
}

/// Numerically stable `ln(sum(exp(values)))`.
///
/// Returns `-inf` for an empty slice. If the maximum is infinite it is returned
/// directly: all-`-inf` input gives `-inf` and any `+inf` gives `+inf`.
pub fn log_sum_exp(values: &[f64]) -> f64 {
    if values.is_empty() {
        return f64::NEG_INFINITY;
    }

    let max_val = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    if max_val.is_infinite() {
        return max_val;
    }

    // Factor out the maximum so every exponent is <= 0 and cannot overflow.
    let sum: f64 = values.iter().map(|v| (v - max_val).exp()).sum();
    max_val + sum.ln()
}
156
/// Identifies a directed message by the string labels of its sender and
/// receiver (a factor or a variable).
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct MessageKey {
    /// Label of the node sending the message.
    pub from: String,
    /// Label of the node receiving the message.
    pub to: String,
}
169
170impl MessageKey {
171 pub fn factor_to_var(factor_idx: usize, var_id: &VariableId) -> Self {
173 Self {
174 from: format!("f{}", factor_idx),
175 to: format!("v{}_{:?}", var_id.mention_idx, var_id.var_type),
176 }
177 }
178
179 pub fn var_to_factor(var_id: &VariableId, factor_idx: usize) -> Self {
181 Self {
182 from: format!("v{}_{:?}", var_id.mention_idx, var_id.var_type),
183 to: format!("f{}", factor_idx),
184 }
185 }
186}
187
188#[derive(Debug, Clone, Default)]
190pub struct MessageStore {
191 messages: HashMap<MessageKey, Message>,
192}
193
194impl MessageStore {
195 pub fn get(&self, key: &MessageKey, domain_size: usize) -> Message {
197 self.messages
198 .get(key)
199 .cloned()
200 .unwrap_or_else(|| Message::uniform(domain_size))
201 }
202
203 pub fn set(&mut self, key: MessageKey, message: Message) {
205 self.messages.insert(key, message);
206 }
207}
208
209#[derive(Debug, Clone, Default)]
215pub struct Marginals {
216 pub distributions: HashMap<VariableId, Vec<f64>>,
218}
219
220impl Marginals {
221 pub fn argmax(&self, var_id: &VariableId) -> Option<usize> {
223 self.distributions.get(var_id).and_then(|probs| {
224 probs
225 .iter()
226 .enumerate()
227 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
228 .map(|(i, _)| i)
229 })
230 }
231
232 pub fn prob(&self, var_id: &VariableId, value_idx: usize) -> Option<f64> {
234 self.distributions
235 .get(var_id)
236 .and_then(|probs| probs.get(value_idx))
237 .map(|log_p| log_p.exp())
238 }
239
240 pub fn max_prob(&self, var_id: &VariableId) -> Option<f64> {
242 self.distributions
243 .get(var_id)
244 .and_then(|probs| {
245 probs
246 .iter()
247 .cloned()
248 .fold(None, |max, p| Some(max.map_or(p, |m: f64| m.max(p))))
249 })
250 .map(|log_p| log_p.exp())
251 }
252}
253
254#[derive(Debug, Clone)]
260pub struct DomainValue {
261 pub index: usize,
263 pub value: DomainValueType,
265}
266
267#[derive(Debug, Clone)]
269pub enum DomainValueType {
270 Antecedent(AntecedentValue),
272 SemanticType(EntityType),
274 EntityLink(LinkValue),
276}
277
278pub fn get_domain_values(var: &JointVariable) -> Vec<DomainValue> {
280 match var {
281 JointVariable::Antecedent { candidates, .. } => {
282 let mut values: Vec<DomainValue> = candidates
283 .iter()
284 .enumerate()
285 .map(|(i, &m)| DomainValue {
286 index: i,
287 value: DomainValueType::Antecedent(AntecedentValue::Mention(m)),
288 })
289 .collect();
290 values.push(DomainValue {
291 index: candidates.len(),
292 value: DomainValueType::Antecedent(AntecedentValue::NewCluster),
293 });
294 values
295 }
296 JointVariable::SemanticType { types, .. } => types
297 .iter()
298 .enumerate()
299 .map(|(i, t)| DomainValue {
300 index: i,
301 value: DomainValueType::SemanticType(t.clone()),
302 })
303 .collect(),
304 JointVariable::EntityLink { candidates, .. } => {
305 let mut values: Vec<DomainValue> = candidates
306 .iter()
307 .enumerate()
308 .map(|(i, kb_id)| DomainValue {
309 index: i,
310 value: DomainValueType::EntityLink(LinkValue::KbId(kb_id.clone())),
311 })
312 .collect();
313 values.push(DomainValue {
314 index: candidates.len(),
315 value: DomainValueType::EntityLink(LinkValue::Nil),
316 });
317 values
318 }
319 }
320}
321
322pub fn apply_domain_value(assignment: &mut Assignment, var: &JointVariable, value: &DomainValue) {
324 let mention_idx = match var {
325 JointVariable::Antecedent { mention_idx, .. } => *mention_idx,
326 JointVariable::SemanticType { mention_idx, .. } => *mention_idx,
327 JointVariable::EntityLink { mention_idx, .. } => *mention_idx,
328 };
329
330 match &value.value {
331 DomainValueType::Antecedent(a) => assignment.set_antecedent(mention_idx, *a),
332 DomainValueType::SemanticType(t) => assignment.set_type(mention_idx, t.clone()),
333 DomainValueType::EntityLink(l) => assignment.set_link(mention_idx, l.clone()),
334 }
335}
336
337pub struct BeliefPropagation {
343 factors: Vec<Box<dyn Factor>>,
345 variables: Vec<JointVariable>,
347 messages: MessageStore,
349 config: InferenceConfig,
351 var_by_id: HashMap<VariableId, usize>,
353}
354
355impl BeliefPropagation {
356 pub fn new(
358 factors: Vec<Box<dyn Factor>>,
359 variables: Vec<JointVariable>,
360 config: InferenceConfig,
361 ) -> Self {
362 let var_by_id: HashMap<VariableId, usize> = variables
363 .iter()
364 .enumerate()
365 .map(|(i, v)| (v.id(), i))
366 .collect();
367
368 Self {
369 factors,
370 variables,
371 messages: MessageStore::default(),
372 config,
373 var_by_id,
374 }
375 }
376
377 pub fn run(&mut self) -> Marginals {
379 self.initialize_messages();
381
382 let mut converged = false;
383 for _iter in 0..self.config.max_iterations {
384 let max_change = self.iterate();
385
386 if max_change < self.config.convergence_threshold {
387 converged = true;
388 break;
389 }
390 }
391
392 let _ = converged;
395
396 self.compute_marginals()
397 }
398
399 fn initialize_messages(&mut self) {
401 self.messages = MessageStore::default();
402 }
403
404 fn iterate(&mut self) -> f64 {
406 match self.config.schedule {
407 MessageSchedule::Parallel => self.iterate_parallel(),
408 MessageSchedule::Sequential => self.iterate_sequential(),
409 }
410 }
411
412 fn iterate_parallel(&mut self) -> f64 {
414 let mut max_change: f64 = 0.0;
415 let mut new_messages = Vec::new();
416
417 for (factor_idx, factor) in self.factors.iter().enumerate() {
419 for var_id in factor.scope() {
420 if let Some(&var_idx) = self.var_by_id.get(var_id) {
421 let var = &self.variables[var_idx];
422 let msg = self.compute_factor_to_var_message(factor_idx, factor.as_ref(), var);
423 let key = MessageKey::factor_to_var(factor_idx, var_id);
424 new_messages.push((key, msg, var.domain_size()));
425 }
426 }
427 }
428
429 for (var_idx, var) in self.variables.iter().enumerate() {
431 let var_id = var.id();
432 for (factor_idx, factor) in self.factors.iter().enumerate() {
433 if factor.scope().contains(&var_id) {
434 let msg = self.compute_var_to_factor_message(var_idx, var, factor_idx);
435 let key = MessageKey::var_to_factor(&var_id, factor_idx);
436 new_messages.push((key, msg, var.domain_size()));
437 }
438 }
439 }
440
441 for (key, mut new_msg, domain_size) in new_messages {
443 let old_msg = self.messages.get(&key, domain_size);
444 let change = new_msg.max_change(&old_msg);
445 max_change = max_change.max(change);
446
447 if self.config.damping > 0.0 {
448 new_msg.damp(&old_msg, self.config.damping);
449 }
450
451 new_msg.normalize();
452 self.messages.set(key, new_msg);
453 }
454
455 max_change
456 }
457
458 fn iterate_sequential(&mut self) -> f64 {
460 let mut max_change: f64 = 0.0;
461
462 for (factor_idx, factor) in self.factors.iter().enumerate() {
464 for var_id in factor.scope() {
465 if let Some(&var_idx) = self.var_by_id.get(var_id) {
466 let var = &self.variables[var_idx];
467 let mut msg =
468 self.compute_factor_to_var_message(factor_idx, factor.as_ref(), var);
469 let key = MessageKey::factor_to_var(factor_idx, var_id);
470 let domain_size = var.domain_size();
471
472 let old_msg = self.messages.get(&key, domain_size);
473 let change = msg.max_change(&old_msg);
474 max_change = max_change.max(change);
475
476 if self.config.damping > 0.0 {
477 msg.damp(&old_msg, self.config.damping);
478 }
479
480 msg.normalize();
481 self.messages.set(key, msg);
482 }
483 }
484 }
485
486 for (var_idx, var) in self.variables.iter().enumerate() {
488 let var_id = var.id();
489 for (factor_idx, factor) in self.factors.iter().enumerate() {
490 if factor.scope().contains(&var_id) {
491 let mut msg = self.compute_var_to_factor_message(var_idx, var, factor_idx);
492 let key = MessageKey::var_to_factor(&var_id, factor_idx);
493 let domain_size = var.domain_size();
494
495 let old_msg = self.messages.get(&key, domain_size);
496 let change = msg.max_change(&old_msg);
497 max_change = max_change.max(change);
498
499 if self.config.damping > 0.0 {
500 msg.damp(&old_msg, self.config.damping);
501 }
502
503 msg.normalize();
504 self.messages.set(key, msg);
505 }
506 }
507 }
508
509 max_change
510 }
511
512 fn compute_factor_to_var_message(
516 &self,
517 factor_idx: usize,
518 factor: &dyn Factor,
519 target_var: &JointVariable,
520 ) -> Message {
521 let target_var_id = target_var.id();
522 let target_domain = get_domain_values(target_var);
523
524 let other_var_ids: Vec<&VariableId> = factor
526 .scope()
527 .iter()
528 .filter(|v| *v != &target_var_id)
529 .collect();
530
531 if other_var_ids.is_empty() {
532 let log_probs: Vec<f64> = target_domain
534 .iter()
535 .map(|dv| {
536 let mut assignment = Assignment::default();
537 apply_domain_value(&mut assignment, target_var, dv);
538 factor.log_potential(&assignment)
539 })
540 .collect();
541 return Message::from_log_probs(log_probs);
542 }
543
544 let other_vars: Vec<(&VariableId, &JointVariable)> = other_var_ids
546 .iter()
547 .filter_map(|vid| {
548 self.var_by_id
549 .get(*vid)
550 .map(|&idx| (*vid, &self.variables[idx]))
551 })
552 .collect();
553
554 let mut log_probs = Vec::with_capacity(target_domain.len());
556
557 for target_value in &target_domain {
558 let mut sum_terms = Vec::new();
560
561 let other_domains: Vec<Vec<DomainValue>> = other_vars
563 .iter()
564 .map(|(_, v)| get_domain_values(v))
565 .collect();
566
567 if other_domains.is_empty() {
568 let mut assignment = Assignment::default();
570 apply_domain_value(&mut assignment, target_var, target_value);
571 log_probs.push(factor.log_potential(&assignment));
572 continue;
573 }
574
575 let mut indices = vec![0usize; other_domains.len()];
577 loop {
578 let mut assignment = Assignment::default();
580 apply_domain_value(&mut assignment, target_var, target_value);
581
582 let mut incoming_msg_log_sum = 0.0;
583 for (i, (var_id, var)) in other_vars.iter().enumerate() {
584 let domain_value = &other_domains[i][indices[i]];
585 apply_domain_value(&mut assignment, var, domain_value);
586
587 let key = MessageKey::var_to_factor(var_id, factor_idx);
589 let msg = self.messages.get(&key, var.domain_size());
590 if domain_value.index < msg.log_probs.len() {
591 incoming_msg_log_sum += msg.log_probs[domain_value.index];
592 }
593 }
594
595 let term = factor.log_potential(&assignment) + incoming_msg_log_sum;
597 sum_terms.push(term);
598
599 let mut carry = true;
601 for i in (0..indices.len()).rev() {
602 if carry {
603 indices[i] += 1;
604 if indices[i] >= other_domains[i].len() {
605 indices[i] = 0;
606 } else {
607 carry = false;
608 }
609 }
610 }
611 if carry {
612 break;
613 }
614 }
615
616 log_probs.push(log_sum_exp(&sum_terms));
617 }
618
619 Message::from_log_probs(log_probs)
620 }
621
622 fn compute_var_to_factor_message(
626 &self,
627 _var_idx: usize,
628 var: &JointVariable,
629 exclude_factor_idx: usize,
630 ) -> Message {
631 let var_id = var.id();
632 let domain_size = var.domain_size();
633 let mut log_probs = vec![0.0; domain_size];
634
635 for (factor_idx, factor) in self.factors.iter().enumerate() {
637 if factor_idx == exclude_factor_idx {
638 continue;
639 }
640 if !factor.scope().contains(&var_id) {
641 continue;
642 }
643
644 let key = MessageKey::factor_to_var(factor_idx, &var_id);
645 let msg = self.messages.get(&key, domain_size);
646
647 for (i, lp) in log_probs.iter_mut().enumerate() {
648 if i < msg.log_probs.len() {
649 *lp += msg.log_probs[i];
650 }
651 }
652 }
653
654 Message::from_log_probs(log_probs)
655 }
656
657 fn compute_marginals(&self) -> Marginals {
659 let mut marginals = Marginals::default();
660
661 for var in &self.variables {
662 let var_id = var.id();
663 let domain_size = var.domain_size();
664
665 let mut log_probs = vec![0.0; domain_size];
667
668 for (factor_idx, factor) in self.factors.iter().enumerate() {
669 if factor.scope().contains(&var_id) {
670 let key = MessageKey::factor_to_var(factor_idx, &var_id);
671 let msg = self.messages.get(&key, domain_size);
672 for (i, lp) in log_probs.iter_mut().enumerate() {
673 if i < msg.log_probs.len() {
674 *lp += msg.log_probs[i];
675 }
676 }
677 }
678 }
679
680 let log_sum = log_sum_exp(&log_probs);
682 if log_sum.is_finite() {
683 for lp in &mut log_probs {
684 *lp -= log_sum;
685 }
686 }
687
688 marginals.distributions.insert(var_id, log_probs);
689 }
690
691 marginals
692 }
693}
694
#[cfg(test)]
mod tests {
    use super::*;
    use crate::joint::factors::{
        CorefNerFactor, CorefNerWeights, UnaryCorefFactor, UnaryNerFactor,
    };
    use crate::joint::types::VariableType;

    /// Asserts that a log-space distribution exponentiates to probabilities
    /// summing to one.
    fn assert_log_normalized(log_probs: &[f64]) {
        let total: f64 = log_probs.iter().map(|lp| lp.exp()).sum();
        assert!((total - 1.0).abs() < 1e-9, "probabilities sum to {}", total);
    }

    #[test]
    fn test_log_sum_exp() {
        // ln(e^1 + e^2 + e^3) = 3 + ln(1 + e^-1 + e^-2) ~= 3.407606
        let values = vec![1.0, 2.0, 3.0];
        let result = log_sum_exp(&values);
        assert!((result - 3.407_606).abs() < 1e-5);
    }

    #[test]
    fn test_log_sum_exp_empty() {
        let values: Vec<f64> = vec![];
        let result = log_sum_exp(&values);
        assert!(result.is_infinite() && result < 0.0);
    }

    #[test]
    fn test_log_sum_exp_single() {
        let values = vec![5.0];
        let result = log_sum_exp(&values);
        assert!((result - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_log_sum_exp_all_neg_infinity() {
        let values = vec![f64::NEG_INFINITY, f64::NEG_INFINITY];
        assert_eq!(log_sum_exp(&values), f64::NEG_INFINITY);
    }

    #[test]
    fn test_message_uniform() {
        assert!(Message::uniform(0).log_probs.is_empty());
        let msg = Message::uniform(4);
        assert_eq!(msg.log_probs.len(), 4);
        assert_log_normalized(&msg.log_probs);
    }

    #[test]
    fn test_message_normalize() {
        let mut msg = Message {
            log_probs: vec![0.0, 0.0, 0.0],
        };
        msg.normalize();
        let expected = -(3.0_f64).ln();
        for lp in &msg.log_probs {
            assert!((lp - expected).abs() < 1e-6);
        }
    }

    #[test]
    fn test_message_max_change() {
        let msg1 = Message {
            log_probs: vec![0.0, 0.0],
        };
        let msg2 = Message {
            log_probs: vec![0.5, -0.5],
        };
        assert!((msg1.max_change(&msg2) - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_message_multiply() {
        let msg1 = Message {
            log_probs: vec![1.0, 2.0],
        };
        let msg2 = Message {
            log_probs: vec![0.5, 0.5],
        };
        let result = msg1.multiply(&msg2);
        assert!((result.log_probs[0] - 1.5).abs() < 1e-10);
        assert!((result.log_probs[1] - 2.5).abs() < 1e-10);
    }

    #[test]
    fn test_marginals_argmax() {
        let mut marginals = Marginals::default();
        let var_id = VariableId {
            mention_idx: 0,
            var_type: VariableType::SemanticType,
        };
        marginals
            .distributions
            .insert(var_id, vec![-1.0, 0.0, -2.0]);

        assert_eq!(marginals.argmax(&var_id), Some(1));
    }

    #[test]
    fn test_marginals_prob() {
        let mut marginals = Marginals::default();
        let var_id = VariableId {
            mention_idx: 0,
            var_type: VariableType::SemanticType,
        };
        // A proper log-space distribution putting all mass on index 0.
        marginals
            .distributions
            .insert(var_id, vec![0.0, f64::NEG_INFINITY]);

        let prob = marginals.prob(&var_id, 0).expect("index 0 is in range");
        assert!((prob - 1.0).abs() < 1e-10);
        assert_eq!(marginals.prob(&var_id, 1), Some(0.0));
        assert_eq!(marginals.prob(&var_id, 2), None);
    }

    #[test]
    fn test_marginals_max_prob() {
        let mut marginals = Marginals::default();
        let var_id = VariableId {
            mention_idx: 0,
            var_type: VariableType::SemanticType,
        };
        marginals
            .distributions
            .insert(var_id, vec![0.25_f64.ln(), 0.75_f64.ln()]);

        let max_prob = marginals.max_prob(&var_id).expect("variable is present");
        assert!((max_prob - 0.75).abs() < 1e-12);

        let unknown = VariableId {
            mention_idx: 1,
            var_type: VariableType::SemanticType,
        };
        assert_eq!(marginals.max_prob(&unknown), None);
    }

    #[test]
    fn test_belief_propagation_empty() {
        let factors: Vec<Box<dyn Factor>> = vec![];
        let variables: Vec<JointVariable> = vec![];
        let config = InferenceConfig::default();

        let mut bp = BeliefPropagation::new(factors, variables, config);
        let marginals = bp.run();

        assert!(marginals.distributions.is_empty());
    }

    #[test]
    fn test_belief_propagation_unary_only() {
        let variables = vec![JointVariable::SemanticType {
            mention_idx: 0,
            types: vec![EntityType::Person, EntityType::Organization],
        }];

        let factors: Vec<Box<dyn Factor>> = vec![Box::new(UnaryNerFactor::new(
            0,
            vec![(EntityType::Person, 1.0), (EntityType::Organization, 0.0)],
        ))];

        let config = InferenceConfig::default();
        let mut bp = BeliefPropagation::new(factors, variables, config);
        let marginals = bp.run();

        let var_id = VariableId {
            mention_idx: 0,
            var_type: VariableType::SemanticType,
        };

        // Person carries the larger unary score.
        assert_eq!(marginals.argmax(&var_id), Some(0));
        assert_log_normalized(&marginals.distributions[&var_id]);
    }

    #[test]
    fn test_belief_propagation_binary_factor() {
        // Mention 1 may corefer with mention 0; both carry a semantic type.
        let variables = vec![
            JointVariable::Antecedent {
                mention_idx: 1,
                candidates: vec![0],
            },
            JointVariable::SemanticType {
                mention_idx: 0,
                types: vec![EntityType::Person, EntityType::Organization],
            },
            JointVariable::SemanticType {
                mention_idx: 1,
                types: vec![EntityType::Person, EntityType::Organization],
            },
        ];

        let factors: Vec<Box<dyn Factor>> = vec![
            // Strong evidence that mention 0 is a Person.
            Box::new(UnaryNerFactor::new(
                0,
                vec![(EntityType::Person, 2.0), (EntityType::Organization, 0.0)],
            )),
            // No preference between types for mention 1.
            Box::new(UnaryNerFactor::new(
                1,
                vec![(EntityType::Person, 0.1), (EntityType::Organization, 0.1)],
            )),
            // Mention 1 prefers linking to mention 0 over starting a new cluster.
            Box::new(UnaryCorefFactor::new(
                1,
                vec![
                    (AntecedentValue::Mention(0), 1.0),
                    (AntecedentValue::NewCluster, -1.0),
                ],
            )),
            Box::new(CorefNerFactor::new(1, 0, CorefNerWeights::default())),
        ];

        let config = InferenceConfig {
            max_iterations: 10,
            ..Default::default()
        };
        let mut bp = BeliefPropagation::new(factors, variables, config);
        let marginals = bp.run();

        let var_id_0 = VariableId {
            mention_idx: 0,
            var_type: VariableType::SemanticType,
        };
        assert_eq!(marginals.argmax(&var_id_0), Some(0));
        assert_log_normalized(&marginals.distributions[&var_id_0]);

        let var_id_1 = VariableId {
            mention_idx: 1,
            var_type: VariableType::SemanticType,
        };
        let probs = marginals
            .distributions
            .get(&var_id_1)
            .expect("marginal for mention 1's type");
        assert_log_normalized(probs);
    }

    #[test]
    fn test_message_key_serialization() {
        let var_id = VariableId {
            mention_idx: 5,
            var_type: VariableType::Antecedent,
        };
        let forward = MessageKey::factor_to_var(3, &var_id);
        assert!(forward.from.contains("f3"));
        assert!(forward.to.contains("5"));

        // The reverse direction swaps the same two labels.
        let backward = MessageKey::var_to_factor(&var_id, 3);
        assert_eq!(backward.from, forward.to);
        assert_eq!(backward.to, forward.from);
        assert_ne!(forward, backward);
    }

    #[test]
    fn test_domain_values_antecedent() {
        let var = JointVariable::Antecedent {
            mention_idx: 2,
            candidates: vec![0, 1],
        };
        let domain = get_domain_values(&var);
        // Two candidate mentions plus the trailing NewCluster option.
        assert_eq!(domain.len(), 3);
        assert!(domain.iter().enumerate().all(|(i, dv)| dv.index == i));
        assert!(matches!(
            domain[2].value,
            DomainValueType::Antecedent(AntecedentValue::NewCluster)
        ));
    }

    #[test]
    fn test_domain_values_type() {
        let var = JointVariable::SemanticType {
            mention_idx: 0,
            types: vec![
                EntityType::Person,
                EntityType::Organization,
                EntityType::Location,
            ],
        };
        let domain = get_domain_values(&var);
        assert_eq!(domain.len(), 3);
        assert!(domain.iter().enumerate().all(|(i, dv)| dv.index == i));
    }

    #[test]
    fn test_domain_values_link() {
        let var = JointVariable::EntityLink {
            mention_idx: 0,
            candidates: vec!["Q42".to_string(), "Q937".to_string()],
        };
        let domain = get_domain_values(&var);
        // Two KB candidates plus the trailing Nil option.
        assert_eq!(domain.len(), 3);
        assert!(domain.iter().enumerate().all(|(i, dv)| dv.index == i));
        assert!(matches!(
            domain[2].value,
            DomainValueType::EntityLink(LinkValue::Nil)
        ));
    }

    #[test]
    fn test_sequential_schedule() {
        let variables = vec![JointVariable::SemanticType {
            mention_idx: 0,
            types: vec![EntityType::Person],
        }];

        let factors: Vec<Box<dyn Factor>> = vec![Box::new(UnaryNerFactor::new(
            0,
            vec![(EntityType::Person, 1.0)],
        ))];

        let config = InferenceConfig {
            schedule: MessageSchedule::Sequential,
            ..Default::default()
        };
        let mut bp = BeliefPropagation::new(factors, variables, config);
        let marginals = bp.run();

        assert!(!marginals.distributions.is_empty());
        let var_id = VariableId {
            mention_idx: 0,
            var_type: VariableType::SemanticType,
        };
        assert_log_normalized(&marginals.distributions[&var_id]);
    }

    #[test]
    fn test_schedules_agree_on_unary_model() {
        let run_with = |schedule: MessageSchedule| {
            let variables = vec![JointVariable::SemanticType {
                mention_idx: 0,
                types: vec![EntityType::Person, EntityType::Organization],
            }];
            let factors: Vec<Box<dyn Factor>> = vec![Box::new(UnaryNerFactor::new(
                0,
                vec![(EntityType::Person, 1.0), (EntityType::Organization, 0.0)],
            ))];
            let config = InferenceConfig {
                schedule,
                ..Default::default()
            };
            BeliefPropagation::new(factors, variables, config).run()
        };

        let var_id = VariableId {
            mention_idx: 0,
            var_type: VariableType::SemanticType,
        };
        let parallel = run_with(MessageSchedule::Parallel);
        let sequential = run_with(MessageSchedule::Sequential);
        let p = &parallel.distributions[&var_id];
        let s = &sequential.distributions[&var_id];
        assert_eq!(p.len(), s.len());
        for (a, b) in p.iter().zip(s.iter()) {
            assert!((a - b).abs() < 1e-9);
        }
    }
}