Skip to main content

anno/joint/
inference.rs

1//! Belief propagation inference for joint entity analysis.
2//!
3//! Implements loopy belief propagation (sum-product) for approximate
4//! inference in the factor graph.
5//!
6//! # Algorithm
7//!
8//! ```text
9//! repeat until convergence:
10//!   for each factor f:
11//!     for each variable v in scope(f):
12//!       m_{f→v}(v) ∝ ∑_{scope(f)\v} ψ_f(scope(f)) ∏_{v'∈scope(f)\v} m_{v'→f}(v')
13//!
14//!   for each variable v:
15//!     for each factor f containing v:
16//!       m_{v→f}(v) ∝ ∏_{f'≠f} m_{f'→v}(v)
17//!
18//! marginal: p(v) ∝ ∏_f m_{f→v}(v)
19//! ```
20//!
21//! # References
22//!
23//! - Kschischang et al. (2001): "Factor Graphs and the Sum-Product Algorithm"
24//! - Murphy et al. (1999): "Loopy Belief Propagation for Approximate Inference"
25//! - Durrett & Klein (2014): "A Joint Model for Entity Analysis"
26
27use super::factors::Factor;
28use super::types::{AntecedentValue, Assignment, JointVariable, LinkValue, VariableId};
29use crate::EntityType;
30use std::collections::HashMap;
31
32// =============================================================================
33// Configuration
34// =============================================================================
35
/// Tuning knobs for loopy belief propagation.
#[derive(Debug, Clone)]
pub struct InferenceConfig {
    /// Upper bound on the number of message-passing sweeps.
    pub max_iterations: usize,
    /// A run stops early once the largest message change in a sweep is below this value.
    pub convergence_threshold: f64,
    /// Damping factor (0 = no damping, 1 = no update).
    pub damping: f64,
    /// Order in which messages are updated within a sweep.
    pub schedule: MessageSchedule,
}

impl Default for InferenceConfig {
    /// Five undamped parallel sweeps with a `1e-4` convergence threshold.
    fn default() -> Self {
        InferenceConfig {
            schedule: MessageSchedule::Parallel,
            damping: 0.0,
            convergence_threshold: 1e-4,
            max_iterations: 5,
        }
    }
}

/// Update ordering used by belief propagation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageSchedule {
    /// Compute every message of a sweep from the previous sweep's state, then store them all.
    Parallel,
    /// Store each message as soon as it is computed, so later updates in the same sweep
    /// see it (may converge faster).
    Sequential,
}
68
69// =============================================================================
70// Messages
71// =============================================================================
72
/// A message in belief propagation.
///
/// Messages are distributions over a variable's domain, stored in log space
/// for numerical stability. Freshly computed messages may be unnormalized;
/// call [`Message::normalize`] to rescale them.
#[derive(Debug, Clone)]
pub struct Message {
    /// Log probabilities for each domain value
    pub log_probs: Vec<f64>,
}

impl Message {
    /// Create a uniform message over `domain_size` values.
    ///
    /// A zero-sized domain yields an empty message.
    pub fn uniform(domain_size: usize) -> Self {
        if domain_size == 0 {
            return Self { log_probs: vec![] };
        }
        let log_prob = -(domain_size as f64).ln();
        Self {
            log_probs: vec![log_prob; domain_size],
        }
    }

    /// Create from raw (possibly unnormalized) log probabilities.
    pub fn from_log_probs(log_probs: Vec<f64>) -> Self {
        Self { log_probs }
    }

    /// Normalize to sum to 1 (in probability space).
    ///
    /// Empty messages, and messages whose log-normalizer is not finite
    /// (e.g. every entry is `-inf`), are left unchanged.
    pub fn normalize(&mut self) {
        if self.log_probs.is_empty() {
            return;
        }
        let log_sum = log_sum_exp(&self.log_probs);
        if log_sum.is_finite() {
            for lp in &mut self.log_probs {
                *lp -= log_sum;
            }
        }
    }

    /// Largest absolute per-entry difference from `other`, in log space.
    ///
    /// Entries are paired up to the shorter length. Identical entries,
    /// including matching infinities whose difference would be NaN, contribute
    /// zero. Any remaining NaN differences are ignored because `f64::max`
    /// discards NaN operands.
    pub fn max_change(&self, other: &Message) -> f64 {
        self.log_probs
            .iter()
            .zip(other.log_probs.iter())
            .map(|(a, b)| if a == b { 0.0 } else { (a - b).abs() })
            .fold(0.0, f64::max)
    }

    /// Apply damping with the previous message.
    ///
    /// Each entry becomes `(1 - damping) * new + damping * old` in log space.
    /// A `damping` of 0 or less (or NaN) leaves the message unchanged, and a
    /// value of 1 or more copies `previous`. Both endpoints are handled
    /// explicitly because the weighted sum would evaluate `0 * -inf = NaN`
    /// for impossible (`-inf`) entries. Entries are paired up to the shorter
    /// length.
    pub fn damp(&mut self, previous: &Message, damping: f64) {
        if damping.is_nan() || damping <= 0.0 {
            return;
        }
        let pairs = self.log_probs.iter_mut().zip(previous.log_probs.iter());
        if damping >= 1.0 {
            for (new, old) in pairs {
                *new = *old;
            }
        } else {
            for (new, old) in pairs {
                *new = (1.0 - damping) * *new + damping * *old;
            }
        }
    }

    /// Pointwise multiply two messages (sum in log space).
    ///
    /// Entries are paired up to the shorter length.
    // Part of the message API even though the engine itself does not call it.
    #[allow(dead_code)]
    pub fn multiply(&self, other: &Message) -> Message {
        let log_probs: Vec<f64> = self
            .log_probs
            .iter()
            .zip(other.log_probs.iter())
            .map(|(a, b)| a + b)
            .collect();
        Message { log_probs }
    }
}

/// Log-sum-exp trick for numerical stability.
///
/// Returns `log(sum(exp(v)))`. An empty slice yields `-inf`. If the largest
/// value is infinite, that value is returned directly. That covers `+inf`, and
/// also the all-`-inf` case, where shifting by the maximum would produce NaN.
pub fn log_sum_exp(values: &[f64]) -> f64 {
    if values.is_empty() {
        return f64::NEG_INFINITY;
    }

    let max_val = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    if max_val.is_infinite() {
        return max_val;
    }

    let sum: f64 = values.iter().map(|v| (v - max_val).exp()).sum();
    max_val + sum.ln()
}
156
157// =============================================================================
158// Message Store
159// =============================================================================
160
161/// Key for a message (from source to target).
162#[derive(Debug, Clone, PartialEq, Eq, Hash)]
163pub struct MessageKey {
164    /// Source (factor index or variable ID serialized)
165    pub from: String,
166    /// Target (variable ID or factor index serialized)
167    pub to: String,
168}
169
170impl MessageKey {
171    /// Create factor-to-variable key.
172    pub fn factor_to_var(factor_idx: usize, var_id: &VariableId) -> Self {
173        Self {
174            from: format!("f{}", factor_idx),
175            to: format!("v{}_{:?}", var_id.mention_idx, var_id.var_type),
176        }
177    }
178
179    /// Create variable-to-factor key.
180    pub fn var_to_factor(var_id: &VariableId, factor_idx: usize) -> Self {
181        Self {
182            from: format!("v{}_{:?}", var_id.mention_idx, var_id.var_type),
183            to: format!("f{}", factor_idx),
184        }
185    }
186}
187
188/// Storage for all messages in belief propagation.
189#[derive(Debug, Clone, Default)]
190pub struct MessageStore {
191    messages: HashMap<MessageKey, Message>,
192}
193
194impl MessageStore {
195    /// Get message, or return uniform if not set.
196    pub fn get(&self, key: &MessageKey, domain_size: usize) -> Message {
197        self.messages
198            .get(key)
199            .cloned()
200            .unwrap_or_else(|| Message::uniform(domain_size))
201    }
202
203    /// Set message.
204    pub fn set(&mut self, key: MessageKey, message: Message) {
205        self.messages.insert(key, message);
206    }
207}
208
209// =============================================================================
210// Marginals
211// =============================================================================
212
213/// Computed marginal distributions.
214#[derive(Debug, Clone, Default)]
215pub struct Marginals {
216    /// Marginals per variable: var_id → log_probs
217    pub distributions: HashMap<VariableId, Vec<f64>>,
218}
219
220impl Marginals {
221    /// Get most likely value index for a variable.
222    pub fn argmax(&self, var_id: &VariableId) -> Option<usize> {
223        self.distributions.get(var_id).and_then(|probs| {
224            probs
225                .iter()
226                .enumerate()
227                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
228                .map(|(i, _)| i)
229        })
230    }
231
232    /// Get probability of a specific value.
233    pub fn prob(&self, var_id: &VariableId, value_idx: usize) -> Option<f64> {
234        self.distributions
235            .get(var_id)
236            .and_then(|probs| probs.get(value_idx))
237            .map(|log_p| log_p.exp())
238    }
239
240    /// Get max probability for a variable.
241    pub fn max_prob(&self, var_id: &VariableId) -> Option<f64> {
242        self.distributions
243            .get(var_id)
244            .and_then(|probs| {
245                probs
246                    .iter()
247                    .cloned()
248                    .fold(None, |max, p| Some(max.map_or(p, |m: f64| m.max(p))))
249            })
250            .map(|log_p| log_p.exp())
251    }
252}
253
254// =============================================================================
255// Domain Enumeration
256// =============================================================================
257
258/// Helper for enumerating variable domains.
259#[derive(Debug, Clone)]
260pub struct DomainValue {
261    /// Index in the domain
262    pub index: usize,
263    /// The actual value
264    pub value: DomainValueType,
265}
266
267/// Concrete domain value types.
268#[derive(Debug, Clone)]
269pub enum DomainValueType {
270    /// Antecedent value
271    Antecedent(AntecedentValue),
272    /// Entity type
273    SemanticType(EntityType),
274    /// Link value
275    EntityLink(LinkValue),
276}
277
278/// Get domain values for a variable.
279pub fn get_domain_values(var: &JointVariable) -> Vec<DomainValue> {
280    match var {
281        JointVariable::Antecedent { candidates, .. } => {
282            let mut values: Vec<DomainValue> = candidates
283                .iter()
284                .enumerate()
285                .map(|(i, &m)| DomainValue {
286                    index: i,
287                    value: DomainValueType::Antecedent(AntecedentValue::Mention(m)),
288                })
289                .collect();
290            values.push(DomainValue {
291                index: candidates.len(),
292                value: DomainValueType::Antecedent(AntecedentValue::NewCluster),
293            });
294            values
295        }
296        JointVariable::SemanticType { types, .. } => types
297            .iter()
298            .enumerate()
299            .map(|(i, t)| DomainValue {
300                index: i,
301                value: DomainValueType::SemanticType(t.clone()),
302            })
303            .collect(),
304        JointVariable::EntityLink { candidates, .. } => {
305            let mut values: Vec<DomainValue> = candidates
306                .iter()
307                .enumerate()
308                .map(|(i, kb_id)| DomainValue {
309                    index: i,
310                    value: DomainValueType::EntityLink(LinkValue::KbId(kb_id.clone())),
311                })
312                .collect();
313            values.push(DomainValue {
314                index: candidates.len(),
315                value: DomainValueType::EntityLink(LinkValue::Nil),
316            });
317            values
318        }
319    }
320}
321
322/// Apply a domain value to an assignment.
323pub fn apply_domain_value(assignment: &mut Assignment, var: &JointVariable, value: &DomainValue) {
324    let mention_idx = match var {
325        JointVariable::Antecedent { mention_idx, .. } => *mention_idx,
326        JointVariable::SemanticType { mention_idx, .. } => *mention_idx,
327        JointVariable::EntityLink { mention_idx, .. } => *mention_idx,
328    };
329
330    match &value.value {
331        DomainValueType::Antecedent(a) => assignment.set_antecedent(mention_idx, *a),
332        DomainValueType::SemanticType(t) => assignment.set_type(mention_idx, t.clone()),
333        DomainValueType::EntityLink(l) => assignment.set_link(mention_idx, l.clone()),
334    }
335}
336
337// =============================================================================
338// Belief Propagation
339// =============================================================================
340
341/// Belief propagation inference engine.
342pub struct BeliefPropagation {
343    /// Factor graph factors
344    factors: Vec<Box<dyn Factor>>,
345    /// Variables
346    variables: Vec<JointVariable>,
347    /// Message store
348    messages: MessageStore,
349    /// Configuration
350    config: InferenceConfig,
351    /// Variable lookup by ID
352    var_by_id: HashMap<VariableId, usize>,
353}
354
355impl BeliefPropagation {
356    /// Create a new belief propagation engine.
357    pub fn new(
358        factors: Vec<Box<dyn Factor>>,
359        variables: Vec<JointVariable>,
360        config: InferenceConfig,
361    ) -> Self {
362        let var_by_id: HashMap<VariableId, usize> = variables
363            .iter()
364            .enumerate()
365            .map(|(i, v)| (v.id(), i))
366            .collect();
367
368        Self {
369            factors,
370            variables,
371            messages: MessageStore::default(),
372            config,
373            var_by_id,
374        }
375    }
376
377    /// Run belief propagation to compute marginals.
378    pub fn run(&mut self) -> Marginals {
379        // Initialize messages to uniform
380        self.initialize_messages();
381
382        let mut converged = false;
383        for _iter in 0..self.config.max_iterations {
384            let max_change = self.iterate();
385
386            if max_change < self.config.convergence_threshold {
387                converged = true;
388                break;
389            }
390        }
391
392        // Convergence check - algorithm continues regardless since BP typically
393        // provides reasonable approximations after a few iterations
394        let _ = converged;
395
396        self.compute_marginals()
397    }
398
399    /// Initialize messages to uniform distributions.
400    fn initialize_messages(&mut self) {
401        self.messages = MessageStore::default();
402    }
403
404    /// Run one iteration of message passing.
405    fn iterate(&mut self) -> f64 {
406        match self.config.schedule {
407            MessageSchedule::Parallel => self.iterate_parallel(),
408            MessageSchedule::Sequential => self.iterate_sequential(),
409        }
410    }
411
412    /// Parallel message update.
413    fn iterate_parallel(&mut self) -> f64 {
414        let mut max_change: f64 = 0.0;
415        let mut new_messages = Vec::new();
416
417        // Factor-to-variable messages
418        for (factor_idx, factor) in self.factors.iter().enumerate() {
419            for var_id in factor.scope() {
420                if let Some(&var_idx) = self.var_by_id.get(var_id) {
421                    let var = &self.variables[var_idx];
422                    let msg = self.compute_factor_to_var_message(factor_idx, factor.as_ref(), var);
423                    let key = MessageKey::factor_to_var(factor_idx, var_id);
424                    new_messages.push((key, msg, var.domain_size()));
425                }
426            }
427        }
428
429        // Variable-to-factor messages
430        for (var_idx, var) in self.variables.iter().enumerate() {
431            let var_id = var.id();
432            for (factor_idx, factor) in self.factors.iter().enumerate() {
433                if factor.scope().contains(&var_id) {
434                    let msg = self.compute_var_to_factor_message(var_idx, var, factor_idx);
435                    let key = MessageKey::var_to_factor(&var_id, factor_idx);
436                    new_messages.push((key, msg, var.domain_size()));
437                }
438            }
439        }
440
441        // Apply new messages
442        for (key, mut new_msg, domain_size) in new_messages {
443            let old_msg = self.messages.get(&key, domain_size);
444            let change = new_msg.max_change(&old_msg);
445            max_change = max_change.max(change);
446
447            if self.config.damping > 0.0 {
448                new_msg.damp(&old_msg, self.config.damping);
449            }
450
451            new_msg.normalize();
452            self.messages.set(key, new_msg);
453        }
454
455        max_change
456    }
457
458    /// Sequential message update.
459    fn iterate_sequential(&mut self) -> f64 {
460        let mut max_change: f64 = 0.0;
461
462        // Factor-to-variable messages
463        for (factor_idx, factor) in self.factors.iter().enumerate() {
464            for var_id in factor.scope() {
465                if let Some(&var_idx) = self.var_by_id.get(var_id) {
466                    let var = &self.variables[var_idx];
467                    let mut msg =
468                        self.compute_factor_to_var_message(factor_idx, factor.as_ref(), var);
469                    let key = MessageKey::factor_to_var(factor_idx, var_id);
470                    let domain_size = var.domain_size();
471
472                    let old_msg = self.messages.get(&key, domain_size);
473                    let change = msg.max_change(&old_msg);
474                    max_change = max_change.max(change);
475
476                    if self.config.damping > 0.0 {
477                        msg.damp(&old_msg, self.config.damping);
478                    }
479
480                    msg.normalize();
481                    self.messages.set(key, msg);
482                }
483            }
484        }
485
486        // Variable-to-factor messages
487        for (var_idx, var) in self.variables.iter().enumerate() {
488            let var_id = var.id();
489            for (factor_idx, factor) in self.factors.iter().enumerate() {
490                if factor.scope().contains(&var_id) {
491                    let mut msg = self.compute_var_to_factor_message(var_idx, var, factor_idx);
492                    let key = MessageKey::var_to_factor(&var_id, factor_idx);
493                    let domain_size = var.domain_size();
494
495                    let old_msg = self.messages.get(&key, domain_size);
496                    let change = msg.max_change(&old_msg);
497                    max_change = max_change.max(change);
498
499                    if self.config.damping > 0.0 {
500                        msg.damp(&old_msg, self.config.damping);
501                    }
502
503                    msg.normalize();
504                    self.messages.set(key, msg);
505                }
506            }
507        }
508
509        max_change
510    }
511
512    /// Compute factor-to-variable message.
513    ///
514    /// m_{f→v}(v) ∝ ∑_{scope(f)\v} ψ_f(scope(f)) ∏_{v'∈scope(f)\v} m_{v'→f}(v')
515    fn compute_factor_to_var_message(
516        &self,
517        factor_idx: usize,
518        factor: &dyn Factor,
519        target_var: &JointVariable,
520    ) -> Message {
521        let target_var_id = target_var.id();
522        let target_domain = get_domain_values(target_var);
523
524        // Get other variables in factor scope
525        let other_var_ids: Vec<&VariableId> = factor
526            .scope()
527            .iter()
528            .filter(|v| *v != &target_var_id)
529            .collect();
530
531        if other_var_ids.is_empty() {
532            // Unary factor: just evaluate factor potential
533            let log_probs: Vec<f64> = target_domain
534                .iter()
535                .map(|dv| {
536                    let mut assignment = Assignment::default();
537                    apply_domain_value(&mut assignment, target_var, dv);
538                    factor.log_potential(&assignment)
539                })
540                .collect();
541            return Message::from_log_probs(log_probs);
542        }
543
544        // Get other variables
545        let other_vars: Vec<(&VariableId, &JointVariable)> = other_var_ids
546            .iter()
547            .filter_map(|vid| {
548                self.var_by_id
549                    .get(*vid)
550                    .map(|&idx| (*vid, &self.variables[idx]))
551            })
552            .collect();
553
554        // Compute message by marginalizing over other variables
555        let mut log_probs = Vec::with_capacity(target_domain.len());
556
557        for target_value in &target_domain {
558            // Sum over all assignments to other variables
559            let mut sum_terms = Vec::new();
560
561            // Enumerate all combinations of other variable assignments
562            let other_domains: Vec<Vec<DomainValue>> = other_vars
563                .iter()
564                .map(|(_, v)| get_domain_values(v))
565                .collect();
566
567            if other_domains.is_empty() {
568                // No other variables (shouldn't happen after the check above)
569                let mut assignment = Assignment::default();
570                apply_domain_value(&mut assignment, target_var, target_value);
571                log_probs.push(factor.log_potential(&assignment));
572                continue;
573            }
574
575            // Iterate over Cartesian product of other domains
576            let mut indices = vec![0usize; other_domains.len()];
577            loop {
578                // Build assignment
579                let mut assignment = Assignment::default();
580                apply_domain_value(&mut assignment, target_var, target_value);
581
582                let mut incoming_msg_log_sum = 0.0;
583                for (i, (var_id, var)) in other_vars.iter().enumerate() {
584                    let domain_value = &other_domains[i][indices[i]];
585                    apply_domain_value(&mut assignment, var, domain_value);
586
587                    // Get incoming message from this variable
588                    let key = MessageKey::var_to_factor(var_id, factor_idx);
589                    let msg = self.messages.get(&key, var.domain_size());
590                    if domain_value.index < msg.log_probs.len() {
591                        incoming_msg_log_sum += msg.log_probs[domain_value.index];
592                    }
593                }
594
595                // Factor potential + incoming messages
596                let term = factor.log_potential(&assignment) + incoming_msg_log_sum;
597                sum_terms.push(term);
598
599                // Advance to next combination
600                let mut carry = true;
601                for i in (0..indices.len()).rev() {
602                    if carry {
603                        indices[i] += 1;
604                        if indices[i] >= other_domains[i].len() {
605                            indices[i] = 0;
606                        } else {
607                            carry = false;
608                        }
609                    }
610                }
611                if carry {
612                    break;
613                }
614            }
615
616            log_probs.push(log_sum_exp(&sum_terms));
617        }
618
619        Message::from_log_probs(log_probs)
620    }
621
622    /// Compute variable-to-factor message.
623    ///
624    /// m_{v→f}(v) ∝ ∏_{f'≠f} m_{f'→v}(v)
625    fn compute_var_to_factor_message(
626        &self,
627        _var_idx: usize,
628        var: &JointVariable,
629        exclude_factor_idx: usize,
630    ) -> Message {
631        let var_id = var.id();
632        let domain_size = var.domain_size();
633        let mut log_probs = vec![0.0; domain_size];
634
635        // Product of all incoming factor messages except the excluded one
636        for (factor_idx, factor) in self.factors.iter().enumerate() {
637            if factor_idx == exclude_factor_idx {
638                continue;
639            }
640            if !factor.scope().contains(&var_id) {
641                continue;
642            }
643
644            let key = MessageKey::factor_to_var(factor_idx, &var_id);
645            let msg = self.messages.get(&key, domain_size);
646
647            for (i, lp) in log_probs.iter_mut().enumerate() {
648                if i < msg.log_probs.len() {
649                    *lp += msg.log_probs[i];
650                }
651            }
652        }
653
654        Message::from_log_probs(log_probs)
655    }
656
657    /// Compute final marginals from converged messages.
658    fn compute_marginals(&self) -> Marginals {
659        let mut marginals = Marginals::default();
660
661        for var in &self.variables {
662            let var_id = var.id();
663            let domain_size = var.domain_size();
664
665            // Marginal = product of all incoming factor messages
666            let mut log_probs = vec![0.0; domain_size];
667
668            for (factor_idx, factor) in self.factors.iter().enumerate() {
669                if factor.scope().contains(&var_id) {
670                    let key = MessageKey::factor_to_var(factor_idx, &var_id);
671                    let msg = self.messages.get(&key, domain_size);
672                    for (i, lp) in log_probs.iter_mut().enumerate() {
673                        if i < msg.log_probs.len() {
674                            *lp += msg.log_probs[i];
675                        }
676                    }
677                }
678            }
679
680            // Normalize
681            let log_sum = log_sum_exp(&log_probs);
682            if log_sum.is_finite() {
683                for lp in &mut log_probs {
684                    *lp -= log_sum;
685                }
686            }
687
688            marginals.distributions.insert(var_id, log_probs);
689        }
690
691        marginals
692    }
693}
694
695// =============================================================================
696// Tests
697// =============================================================================
698
#[cfg(test)]
mod tests {
    use super::*;
    use crate::joint::factors::{
        CorefNerFactor, CorefNerWeights, UnaryCorefFactor, UnaryNerFactor,
    };
    use crate::joint::types::VariableType;

    /// Total probability mass of a log-space distribution.
    fn total_mass(log_probs: &[f64]) -> f64 {
        log_probs.iter().map(|lp| lp.exp()).sum()
    }

    /// ID of the semantic-type variable for `mention_idx`.
    fn type_var(mention_idx: usize) -> VariableId {
        VariableId {
            mention_idx,
            var_type: VariableType::SemanticType,
        }
    }

    /// A single Person/Organization type variable with a unary prior favouring Person.
    fn unary_person_model() -> (Vec<Box<dyn Factor>>, Vec<JointVariable>) {
        let variables = vec![JointVariable::SemanticType {
            mention_idx: 0,
            types: vec![EntityType::Person, EntityType::Organization],
        }];
        let factors: Vec<Box<dyn Factor>> = vec![Box::new(UnaryNerFactor::new(
            0,
            vec![(EntityType::Person, 1.0), (EntityType::Organization, 0.0)],
        ))];
        (factors, variables)
    }

    #[test]
    fn test_log_sum_exp() {
        let values = vec![1.0, 2.0, 3.0];
        let result = log_sum_exp(&values);
        // log(e^1 + e^2 + e^3) ≈ 3.407
        assert!((result - 3.407).abs() < 0.01);
    }

    #[test]
    fn test_log_sum_exp_empty() {
        let values: Vec<f64> = vec![];
        let result = log_sum_exp(&values);
        assert!(result.is_infinite() && result < 0.0);
    }

    #[test]
    fn test_log_sum_exp_single() {
        let values = vec![5.0];
        let result = log_sum_exp(&values);
        assert!((result - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_message_uniform() {
        let msg = Message::uniform(4);
        assert_eq!(msg.log_probs.len(), 4);
        for lp in &msg.log_probs {
            assert!((lp + 4.0_f64.ln()).abs() < 1e-12);
        }
        assert!(Message::uniform(0).log_probs.is_empty());
    }

    #[test]
    fn test_message_normalize() {
        let mut msg = Message {
            log_probs: vec![0.0, 0.0, 0.0],
        };
        msg.normalize();
        // Each should be log(1/3)
        let expected = -(3.0_f64).ln();
        for lp in &msg.log_probs {
            assert!((lp - expected).abs() < 1e-6);
        }
    }

    #[test]
    fn test_message_normalize_zero_mass_is_noop() {
        let mut msg = Message::from_log_probs(vec![f64::NEG_INFINITY; 2]);
        msg.normalize();
        assert!(msg.log_probs.iter().all(|lp| *lp == f64::NEG_INFINITY));
    }

    #[test]
    fn test_message_max_change() {
        let msg1 = Message {
            log_probs: vec![0.0, 0.0],
        };
        let msg2 = Message {
            log_probs: vec![0.5, -0.5],
        };
        assert!((msg1.max_change(&msg2) - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_message_multiply() {
        let msg1 = Message {
            log_probs: vec![1.0, 2.0],
        };
        let msg2 = Message {
            log_probs: vec![0.5, 0.5],
        };
        let result = msg1.multiply(&msg2);
        assert!((result.log_probs[0] - 1.5).abs() < 1e-10);
        assert!((result.log_probs[1] - 2.5).abs() < 1e-10);
    }

    #[test]
    fn test_message_store_defaults_to_uniform_and_roundtrips() {
        let mut store = MessageStore::default();
        let key = MessageKey {
            from: "f0".to_string(),
            to: "v0_SemanticType".to_string(),
        };

        let unset = store.get(&key, 2);
        assert_eq!(unset.log_probs.len(), 2);
        for lp in &unset.log_probs {
            assert!((lp + 2.0_f64.ln()).abs() < 1e-12);
        }

        store.set(key.clone(), Message::from_log_probs(vec![-0.1, -2.0]));
        assert_eq!(store.get(&key, 2).log_probs, vec![-0.1, -2.0]);
    }

    #[test]
    fn test_marginals_argmax() {
        let mut marginals = Marginals::default();
        let var_id = type_var(0);
        marginals
            .distributions
            .insert(var_id, vec![-1.0, 0.0, -2.0]);

        assert_eq!(marginals.argmax(&var_id), Some(1));
        assert_eq!(marginals.argmax(&type_var(1)), None);
    }

    #[test]
    fn test_marginals_prob() {
        let mut marginals = Marginals::default();
        let var_id = type_var(0);
        marginals.distributions.insert(var_id, vec![0.0, 0.0]);

        // exp(0) = 1
        let prob = marginals.prob(&var_id, 0);
        assert!(prob.is_some());
        assert!((prob.unwrap() - 1.0).abs() < 1e-10);
        assert_eq!(marginals.prob(&var_id, 2), None);
    }

    #[test]
    fn test_marginals_max_prob() {
        let mut marginals = Marginals::default();
        let var_id = type_var(0);
        marginals
            .distributions
            .insert(var_id, vec![-1.0, 0.0, -2.0]);

        assert_eq!(marginals.max_prob(&var_id), Some(1.0));
        assert_eq!(marginals.max_prob(&type_var(1)), None);
    }

    #[test]
    fn test_belief_propagation_empty() {
        let factors: Vec<Box<dyn Factor>> = vec![];
        let variables: Vec<JointVariable> = vec![];
        let config = InferenceConfig::default();

        let mut bp = BeliefPropagation::new(factors, variables, config);
        let marginals = bp.run();

        assert!(marginals.distributions.is_empty());
    }

    #[test]
    fn test_belief_propagation_unary_only() {
        let (factors, variables) = unary_person_model();
        let config = InferenceConfig::default();
        let mut bp = BeliefPropagation::new(factors, variables, config);
        let marginals = bp.run();

        let var_id = type_var(0);

        // Person should have higher probability
        assert_eq!(marginals.argmax(&var_id), Some(0)); // Person is index 0

        // Marginals are normalized distributions.
        let probs = marginals
            .distributions
            .get(&var_id)
            .expect("marginal for mention 0 type");
        assert!((total_mass(probs) - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_belief_propagation_binary_factor() {
        // Two mentions with coref+NER factor encouraging consistent types
        let variables = vec![
            JointVariable::Antecedent {
                mention_idx: 1,
                candidates: vec![0],
            },
            JointVariable::SemanticType {
                mention_idx: 0,
                types: vec![EntityType::Person, EntityType::Organization],
            },
            JointVariable::SemanticType {
                mention_idx: 1,
                types: vec![EntityType::Person, EntityType::Organization],
            },
        ];

        let factors: Vec<Box<dyn Factor>> = vec![
            // Unary: mention 0 is likely Person
            Box::new(UnaryNerFactor::new(
                0,
                vec![(EntityType::Person, 2.0), (EntityType::Organization, 0.0)],
            )),
            // Unary: mention 1 type prior (weak)
            Box::new(UnaryNerFactor::new(
                1,
                vec![(EntityType::Person, 0.1), (EntityType::Organization, 0.1)],
            )),
            // Unary: mention 1 antecedent
            Box::new(UnaryCorefFactor::new(
                1,
                vec![
                    (AntecedentValue::Mention(0), 1.0),
                    (AntecedentValue::NewCluster, -1.0),
                ],
            )),
            // Binary: coref+NER consistency
            Box::new(CorefNerFactor::new(1, 0, CorefNerWeights::default())),
        ];

        let config = InferenceConfig {
            max_iterations: 10,
            ..Default::default()
        };
        let mut bp = BeliefPropagation::new(factors, variables, config);
        let marginals = bp.run();

        // Mention 0 should be Person
        assert_eq!(marginals.argmax(&type_var(0)), Some(0));

        // Mention 1's type marginal must exist and be a proper distribution.
        // Whether it favours Person depends on CorefNerWeights::default(),
        // which is not pinned here.
        let probs = marginals
            .distributions
            .get(&type_var(1))
            .expect("marginal for mention 1 type");
        assert!((total_mass(probs) - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_message_key_serialization() {
        let var_id = VariableId {
            mention_idx: 5,
            var_type: VariableType::Antecedent,
        };
        let key = MessageKey::factor_to_var(3, &var_id);
        assert_eq!(key.from, "f3");
        assert!(key.to.starts_with("v5_"));

        // The reverse direction swaps the same two labels.
        let reverse = MessageKey::var_to_factor(&var_id, 3);
        assert_eq!(reverse.from, key.to);
        assert_eq!(reverse.to, key.from);
    }

    #[test]
    fn test_domain_values_antecedent() {
        let var = JointVariable::Antecedent {
            mention_idx: 2,
            candidates: vec![0, 1],
        };
        let domain = get_domain_values(&var);
        assert_eq!(domain.len(), 3); // 2 candidates + NewCluster
        assert!(domain.iter().enumerate().all(|(i, dv)| dv.index == i));
    }

    #[test]
    fn test_domain_values_type() {
        let var = JointVariable::SemanticType {
            mention_idx: 0,
            types: vec![
                EntityType::Person,
                EntityType::Organization,
                EntityType::Location,
            ],
        };
        let domain = get_domain_values(&var);
        assert_eq!(domain.len(), 3);
    }

    #[test]
    fn test_domain_values_link() {
        let var = JointVariable::EntityLink {
            mention_idx: 0,
            candidates: vec!["Q42".to_string(), "Q937".to_string()],
        };
        let domain = get_domain_values(&var);
        assert_eq!(domain.len(), 3); // 2 candidates + NIL
    }

    #[test]
    fn test_sequential_schedule() {
        let variables = vec![JointVariable::SemanticType {
            mention_idx: 0,
            types: vec![EntityType::Person],
        }];

        let factors: Vec<Box<dyn Factor>> = vec![Box::new(UnaryNerFactor::new(
            0,
            vec![(EntityType::Person, 1.0)],
        ))];

        let config = InferenceConfig {
            schedule: MessageSchedule::Sequential,
            ..Default::default()
        };
        let mut bp = BeliefPropagation::new(factors, variables, config);
        let marginals = bp.run();

        assert!(!marginals.distributions.is_empty());
    }

    #[test]
    fn test_schedules_agree_on_unary_model() {
        // With a single unary factor both schedules compute identical messages.
        let run_with = |schedule: MessageSchedule| {
            let (factors, variables) = unary_person_model();
            let config = InferenceConfig {
                schedule,
                ..Default::default()
            };
            BeliefPropagation::new(factors, variables, config).run()
        };

        let parallel = run_with(MessageSchedule::Parallel);
        let sequential = run_with(MessageSchedule::Sequential);

        let p = &parallel.distributions[&type_var(0)];
        let s = &sequential.distributions[&type_var(0)];
        assert_eq!(p.len(), s.len());
        for (a, b) in p.iter().zip(s) {
            assert!((a - b).abs() < 1e-12);
        }
    }
}