Skip to main content

mollendorff_forge/bayesian/
inference.rs

1//! Belief Propagation Inference
2//!
3//! Implements variable elimination for exact inference in Bayesian networks.
4//! Validated against pgmpy.
5
6use super::config::{BayesianConfig, BayesianNode, NodeType};
7use std::collections::HashMap;
8
9/// Factor (potential function) for inference
10#[derive(Debug, Clone)]
11pub struct Factor {
12    /// Variables in this factor
13    pub variables: Vec<String>,
14    /// Cardinalities of each variable
15    pub cardinalities: Vec<usize>,
16    /// Probability values (flattened)
17    pub values: Vec<f64>,
18}
19
20impl Factor {
21    /// Create a factor from a node's CPT
22    #[must_use]
23    pub fn from_node(name: &str, node: &BayesianNode, config: &BayesianConfig) -> Self {
24        let mut variables = vec![name.to_string()];
25        let mut cardinalities = vec![node.states.len()];
26
27        // Add parent variables
28        for parent in &node.parents {
29            if let Some(parent_node) = config.nodes.get(parent) {
30                variables.push(parent.clone());
31                cardinalities.push(parent_node.states.len());
32            }
33        }
34
35        // Build values array
36        let total_size: usize = cardinalities.iter().product();
37        let mut values = vec![0.0; total_size];
38
39        if node.is_root() {
40            // Root node: just prior
41            values.clone_from(&node.prior);
42        } else {
43            // Child node: build from CPT
44            // Get parent cardinality (assuming single parent for now)
45            if let Some(parent_node) = config.nodes.get(&node.parents[0]) {
46                let parent_card = parent_node.states.len();
47
48                for (i, val) in values.iter_mut().enumerate().take(total_size) {
49                    // Decode indices - last variable (parent) changes fastest
50                    let parent_idx = i % parent_card;
51                    let state_idx = i / parent_card;
52
53                    if parent_idx < parent_node.states.len() {
54                        let parent_state = &parent_node.states[parent_idx];
55                        if let Some(probs) = node.cpt.get(parent_state) {
56                            if state_idx < probs.len() {
57                                *val = probs[state_idx];
58                            }
59                        }
60                    }
61                }
62            }
63        }
64
65        Self {
66            variables,
67            cardinalities,
68            values,
69        }
70    }
71
72    /// Multiply two factors
73    #[must_use]
74    pub fn multiply(&self, other: &Self) -> Self {
75        // Find common and unique variables
76        let mut new_variables = self.variables.clone();
77        let mut new_cardinalities = self.cardinalities.clone();
78
79        let mut other_indices: Vec<Option<usize>> = vec![None; other.variables.len()];
80
81        for (i, var) in other.variables.iter().enumerate() {
82            if let Some(pos) = self.variables.iter().position(|v| v == var) {
83                other_indices[i] = Some(pos);
84            } else {
85                new_variables.push(var.clone());
86                new_cardinalities.push(other.cardinalities[i]);
87                other_indices[i] = Some(new_variables.len() - 1);
88            }
89        }
90
91        let total_size: usize = new_cardinalities.iter().product();
92        let mut new_values = vec![0.0; total_size];
93
94        // Compute product
95        for (i, val) in new_values.iter_mut().enumerate() {
96            let indices = Self::decode_index(i, &new_cardinalities);
97
98            // Get index into self
99            let self_idx =
100                Self::encode_index(&indices[..self.variables.len()], &self.cardinalities);
101
102            // Get index into other
103            let other_idx_vec: Vec<usize> = other_indices
104                .iter()
105                .filter_map(|&idx| idx.map(|j| indices[j]))
106                .collect();
107            let other_idx = Self::encode_index(&other_idx_vec, &other.cardinalities);
108
109            let self_val = self.values.get(self_idx).copied().unwrap_or(0.0);
110            let other_val = other.values.get(other_idx).copied().unwrap_or(0.0);
111
112            *val = self_val * other_val;
113        }
114
115        Self {
116            variables: new_variables,
117            cardinalities: new_cardinalities,
118            values: new_values,
119        }
120    }
121
122    /// Marginalize (sum out) a variable
123    #[must_use]
124    pub fn marginalize(&self, var: &str) -> Self {
125        let Some(var_idx) = self.variables.iter().position(|v| v == var) else {
126            return self.clone();
127        };
128
129        let new_variables: Vec<String> = self
130            .variables
131            .iter()
132            .enumerate()
133            .filter(|(i, _)| *i != var_idx)
134            .map(|(_, v)| v.clone())
135            .collect();
136
137        let new_cardinalities: Vec<usize> = self
138            .cardinalities
139            .iter()
140            .enumerate()
141            .filter(|(i, _)| *i != var_idx)
142            .map(|(_, c)| *c)
143            .collect();
144
145        if new_variables.is_empty() {
146            // Marginalizing the last variable
147            return Self {
148                variables: vec![],
149                cardinalities: vec![],
150                values: vec![self.values.iter().sum()],
151            };
152        }
153
154        let total_size: usize = new_cardinalities.iter().product();
155        let mut new_values = vec![0.0; total_size];
156
157        for i in 0..self.values.len() {
158            let indices = Self::decode_index(i, &self.cardinalities);
159
160            // Get new index (without marginalized variable)
161            let new_idx_vec: Vec<usize> = indices
162                .iter()
163                .enumerate()
164                .filter(|(j, _)| *j != var_idx)
165                .map(|(_, idx)| *idx)
166                .collect();
167
168            let new_idx = if new_idx_vec.is_empty() {
169                0
170            } else {
171                Self::encode_index(&new_idx_vec, &new_cardinalities)
172            };
173
174            new_values[new_idx] += self.values[i];
175        }
176
177        Self {
178            variables: new_variables,
179            cardinalities: new_cardinalities,
180            values: new_values,
181        }
182    }
183
184    /// Normalize the factor
185    pub fn normalize(&mut self) {
186        let sum: f64 = self.values.iter().sum();
187        if sum > 0.0 {
188            for v in &mut self.values {
189                *v /= sum;
190            }
191        }
192    }
193
194    /// Decode a flat index to multi-dimensional indices
195    fn decode_index(mut idx: usize, cardinalities: &[usize]) -> Vec<usize> {
196        let mut indices = vec![0; cardinalities.len()];
197        for i in (0..cardinalities.len()).rev() {
198            indices[i] = idx % cardinalities[i];
199            idx /= cardinalities[i];
200        }
201        indices
202    }
203
204    /// Encode multi-dimensional indices to a flat index
205    fn encode_index(indices: &[usize], cardinalities: &[usize]) -> usize {
206        let mut idx = 0;
207        let mut multiplier = 1;
208        for i in (0..indices.len()).rev() {
209            idx += indices[i] * multiplier;
210            multiplier *= cardinalities.get(i).copied().unwrap_or(1);
211        }
212        idx
213    }
214
215    /// Get probability for a specific assignment
216    #[must_use]
217    pub fn get_probability(&self, assignment: &HashMap<String, usize>) -> f64 {
218        let indices: Vec<usize> = self
219            .variables
220            .iter()
221            .map(|v| assignment.get(v).copied().unwrap_or(0))
222            .collect();
223        let idx = Self::encode_index(&indices, &self.cardinalities);
224        self.values.get(idx).copied().unwrap_or(0.0)
225    }
226}
227
228/// Belief Propagation (Variable Elimination) for exact inference
229pub struct BeliefPropagation {
230    config: BayesianConfig,
231    factors: Vec<Factor>,
232}
233
234impl BeliefPropagation {
235    /// Create a new belief propagation engine
236    ///
237    /// # Errors
238    ///
239    /// Returns an error if the configuration is invalid.
240    pub fn new(config: BayesianConfig) -> Result<Self, String> {
241        config.validate()?;
242
243        // Build initial factors from nodes
244        let mut factors = Vec::new();
245        for (name, node) in &config.nodes {
246            if node.node_type == NodeType::Discrete {
247                factors.push(Factor::from_node(name, node, &config));
248            }
249        }
250
251        Ok(Self { config, factors })
252    }
253
254    /// Query the marginal probability of a variable
255    ///
256    /// # Errors
257    ///
258    /// Returns an error if the target variable is not found or no factors remain.
259    pub fn query(&self, target: &str) -> Result<Vec<f64>, String> {
260        if !self.config.nodes.contains_key(target) {
261            return Err(format!("Variable '{target}' not found in network"));
262        }
263
264        // Variable elimination
265        let order = self.get_elimination_order(target);
266
267        let mut factors = self.factors.clone();
268
269        for var in order {
270            if var == target {
271                continue;
272            }
273
274            // Find factors containing this variable
275            let (containing, remaining): (Vec<_>, Vec<_>) = factors
276                .into_iter()
277                .partition(|f| f.variables.contains(&var));
278
279            if containing.is_empty() {
280                factors = remaining;
281                continue;
282            }
283
284            // Multiply containing factors
285            let mut product = containing[0].clone();
286            for f in containing.iter().skip(1) {
287                product = product.multiply(f);
288            }
289
290            // Marginalize
291            let marginal = product.marginalize(&var);
292
293            factors = remaining;
294            factors.push(marginal);
295        }
296
297        // Multiply remaining factors
298        if factors.is_empty() {
299            return Err("No factors remaining".to_string());
300        }
301
302        let mut result = factors[0].clone();
303        for f in factors.iter().skip(1) {
304            result = result.multiply(f);
305        }
306
307        // Normalize
308        result.normalize();
309
310        // Extract probabilities for target variable
311        // After variable elimination, result should only contain the target variable
312        if result.variables.len() == 1 && result.variables[0] == target {
313            // Simple case: result is already just the target variable
314            let sum: f64 = result.values.iter().sum();
315            if sum > 0.0 {
316                Ok(result.values.iter().map(|v| v / sum).collect())
317            } else {
318                Ok(result.values.clone())
319            }
320        } else {
321            // Complex case: marginalize out any remaining variables except target
322            let mut final_result = result.clone();
323            for var in &result.variables {
324                if var != target {
325                    final_result = final_result.marginalize(var);
326                }
327            }
328
329            // Extract probabilities
330            let sum: f64 = final_result.values.iter().sum();
331            if sum > 0.0 {
332                Ok(final_result.values.iter().map(|v| v / sum).collect())
333            } else {
334                Ok(final_result.values)
335            }
336        }
337    }
338
339    /// Query with evidence (observed values)
340    ///
341    /// # Errors
342    ///
343    /// Returns an error if the target variable is not found or no factors remain.
344    pub fn query_with_evidence(
345        &self,
346        target: &str,
347        evidence: &HashMap<String, usize>,
348    ) -> Result<Vec<f64>, String> {
349        if !self.config.nodes.contains_key(target) {
350            return Err(format!("Variable '{target}' not found in network"));
351        }
352
353        // Apply evidence to factors
354        let mut factors: Vec<Factor> = self
355            .factors
356            .iter()
357            .map(|f| Self::apply_evidence(f, evidence))
358            .collect();
359
360        // Variable elimination (excluding evidence variables)
361        let order = self.get_elimination_order(target);
362
363        for var in order {
364            if var == target || evidence.contains_key(&var) {
365                continue;
366            }
367
368            // Find factors containing this variable
369            let (containing, remaining): (Vec<_>, Vec<_>) = factors
370                .into_iter()
371                .partition(|f| f.variables.contains(&var));
372
373            if containing.is_empty() {
374                factors = remaining;
375                continue;
376            }
377
378            // Multiply containing factors
379            let mut product = containing[0].clone();
380            for f in containing.iter().skip(1) {
381                product = product.multiply(f);
382            }
383
384            // Marginalize
385            let marginal = product.marginalize(&var);
386
387            factors = remaining;
388            factors.push(marginal);
389        }
390
391        // Multiply remaining factors
392        if factors.is_empty() {
393            return Err("No factors remaining".to_string());
394        }
395
396        let mut result = factors[0].clone();
397        for f in factors.iter().skip(1) {
398            result = result.multiply(f);
399        }
400
401        // Normalize
402        result.normalize();
403
404        // Extract probabilities for target variable
405        // After variable elimination, result should only contain the target variable
406        if result.variables.len() == 1 && result.variables[0] == target {
407            // Simple case: result is already just the target variable
408            let sum: f64 = result.values.iter().sum();
409            if sum > 0.0 {
410                Ok(result.values.iter().map(|v| v / sum).collect())
411            } else {
412                Ok(result.values.clone())
413            }
414        } else {
415            // Complex case: marginalize out any remaining variables except target
416            let mut final_result = result.clone();
417            for var in &result.variables {
418                if var != target {
419                    final_result = final_result.marginalize(var);
420                }
421            }
422
423            // Extract probabilities
424            let sum: f64 = final_result.values.iter().sum();
425            if sum > 0.0 {
426                Ok(final_result.values.iter().map(|v| v / sum).collect())
427            } else {
428                Ok(final_result.values)
429            }
430        }
431    }
432
433    /// Apply evidence to a factor
434    fn apply_evidence(factor: &Factor, evidence: &HashMap<String, usize>) -> Factor {
435        let mut new_values = factor.values.clone();
436
437        for (i, val) in new_values.iter_mut().enumerate() {
438            let indices = Factor::decode_index(i, &factor.cardinalities);
439
440            for (var_idx, var) in factor.variables.iter().enumerate() {
441                if let Some(&ev_val) = evidence.get(var) {
442                    if indices[var_idx] != ev_val {
443                        *val = 0.0;
444                        break;
445                    }
446                }
447            }
448        }
449
450        Factor {
451            variables: factor.variables.clone(),
452            cardinalities: factor.cardinalities.clone(),
453            values: new_values,
454        }
455    }
456
457    /// Get elimination order (simple reverse topological)
458    fn get_elimination_order(&self, exclude: &str) -> Vec<String> {
459        let mut order = self.config.topological_order();
460        order.reverse();
461        order.retain(|v| v != exclude);
462        order
463    }
464
465    /// Get the configuration
466    #[must_use]
467    pub const fn config(&self) -> &BayesianConfig {
468        &self.config
469    }
470}
471
#[cfg(test)]
mod inference_tests {
    use super::*;

    /// Absolute tolerance for probability comparisons. The expected values
    /// are exact up to floating-point rounding, so the bound is kept far
    /// below the smallest probability under test (0.01).
    const TOLERANCE: f64 = 1e-9;

    /// Two-node network: `rain` is a root with a prior, `sprinkler` depends
    /// on `rain` through a CPT keyed by rain's state.
    fn create_simple_network() -> BayesianConfig {
        // Rain -> Sprinkler
        BayesianConfig::new("Sprinkler")
            .with_node(
                "rain",
                BayesianNode::discrete(vec!["no", "yes"]).with_prior(vec![0.8, 0.2]),
            )
            .with_node(
                "sprinkler",
                BayesianNode::discrete(vec!["off", "on"])
                    .with_parents(vec!["rain"])
                    .with_cpt_entry("no", vec![0.6, 0.4])
                    .with_cpt_entry("yes", vec![0.99, 0.01]),
            )
    }

    #[test]
    fn test_prior_query() {
        let config = create_simple_network();
        let bp = BeliefPropagation::new(config).unwrap();

        let rain_probs = bp.query("rain").unwrap();
        assert_eq!(rain_probs.len(), 2, "rain has two states");
        assert!(
            (rain_probs[0] - 0.8).abs() < TOLERANCE,
            "P(rain=no) should be 0.8"
        );
        assert!(
            (rain_probs[1] - 0.2).abs() < TOLERANCE,
            "P(rain=yes) should be 0.2"
        );
    }

    #[test]
    fn test_marginal_query() {
        let config = create_simple_network();
        let bp = BeliefPropagation::new(config).unwrap();

        let sprinkler_probs = bp.query("sprinkler").unwrap();
        assert_eq!(sprinkler_probs.len(), 2, "sprinkler has two states");

        // P(sprinkler=on) = P(sprinkler=on|rain=no)*P(rain=no) + P(sprinkler=on|rain=yes)*P(rain=yes)
        //                 = 0.4 * 0.8 + 0.01 * 0.2 = 0.32 + 0.002 = 0.322
        let expected_on = 0.4f64.mul_add(0.8, 0.01 * 0.2);
        assert!(
            (sprinkler_probs[1] - expected_on).abs() < TOLERANCE,
            "P(sprinkler=on) should be {}, got {}",
            expected_on,
            sprinkler_probs[1]
        );

        // The two states must account for all of the probability mass.
        let total: f64 = sprinkler_probs.iter().sum();
        assert!(
            (total - 1.0).abs() < TOLERANCE,
            "sprinkler marginal should sum to 1, got {total}"
        );
    }

    #[test]
    fn test_evidence_query() {
        let config = create_simple_network();
        let bp = BeliefPropagation::new(config).unwrap();

        // Query P(sprinkler | rain=yes)
        let mut evidence = HashMap::new();
        evidence.insert("rain".to_string(), 1); // yes

        let probs = bp.query_with_evidence("sprinkler", &evidence).unwrap();
        assert_eq!(probs.len(), 2, "sprinkler has two states");

        // P(sprinkler=on | rain=yes) = 0.01
        assert!(
            (probs[1] - 0.01).abs() < TOLERANCE,
            "P(sprinkler=on | rain=yes) should be 0.01, got {}",
            probs[1]
        );

        // P(sprinkler=off | rain=yes) = 0.99
        assert!(
            (probs[0] - 0.99).abs() < TOLERANCE,
            "P(sprinkler=off | rain=yes) should be 0.99, got {}",
            probs[0]
        );
    }
}
545}