Skip to main content

mollendorff_forge/bayesian/
engine.rs

1//! Bayesian Network Engine
2//!
3//! High-level interface for Bayesian network inference.
4
5use super::config::BayesianConfig;
6use super::inference::BeliefPropagation;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10/// Query result for a variable
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct VariableResult {
13    /// Variable name
14    pub name: String,
15    /// State names
16    pub states: Vec<String>,
17    /// Probability for each state
18    pub probabilities: Vec<f64>,
19    /// Most likely state
20    pub most_likely: String,
21    /// Probability of most likely state
22    pub max_probability: f64,
23}
24
25impl VariableResult {
26    /// Get probability for a specific state
27    #[must_use]
28    pub fn get_probability(&self, state: &str) -> Option<f64> {
29        self.states
30            .iter()
31            .position(|s| s == state)
32            .map(|idx| self.probabilities[idx])
33    }
34}
35
36/// Complete Bayesian analysis result
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct BayesianResult {
39    /// Network name
40    pub name: String,
41    /// Query results
42    pub queries: HashMap<String, VariableResult>,
43    /// Evidence used
44    pub evidence: HashMap<String, String>,
45}
46
47impl BayesianResult {
48    /// Export results to YAML format
49    #[must_use]
50    pub fn to_yaml(&self) -> String {
51        serde_yaml_ng::to_string(self).unwrap_or_else(|_| "# Error serializing results".to_string())
52    }
53
54    /// Export results to JSON format
55    ///
56    /// # Errors
57    ///
58    /// Returns an error if JSON serialization fails.
59    pub fn to_json(&self) -> Result<String, serde_json::Error> {
60        serde_json::to_string_pretty(self)
61    }
62}
63
64/// Bayesian Network Engine
65pub struct BayesianEngine {
66    config: BayesianConfig,
67    bp: BeliefPropagation,
68}
69
70impl BayesianEngine {
71    /// Create a new Bayesian engine
72    ///
73    /// # Errors
74    ///
75    /// Returns an error if the configuration is invalid.
76    pub fn new(config: BayesianConfig) -> Result<Self, String> {
77        let bp = BeliefPropagation::new(config.clone())?;
78        Ok(Self { config, bp })
79    }
80
81    /// Query the marginal probability of a variable
82    ///
83    /// # Errors
84    ///
85    /// Returns an error if the target variable is not found in the network.
86    pub fn query(&self, target: &str) -> Result<VariableResult, String> {
87        let probs = self.bp.query(target)?;
88
89        let node = self
90            .config
91            .nodes
92            .get(target)
93            .ok_or_else(|| format!("Variable '{target}' not found"))?;
94
95        let (max_idx, max_prob) = probs
96            .iter()
97            .enumerate()
98            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
99            .map_or((0, 0.0), |(i, p)| (i, *p));
100
101        Ok(VariableResult {
102            name: target.to_string(),
103            states: node.states.clone(),
104            probabilities: probs,
105            most_likely: node.states.get(max_idx).cloned().unwrap_or_default(),
106            max_probability: max_prob,
107        })
108    }
109
110    /// Query with evidence
111    ///
112    /// # Errors
113    ///
114    /// Returns an error if the target or evidence variables are not found.
115    pub fn query_with_evidence(
116        &self,
117        target: &str,
118        evidence: &HashMap<String, &str>,
119    ) -> Result<VariableResult, String> {
120        // Convert evidence from state names to indices
121        let mut evidence_indices = HashMap::new();
122        for (var, state) in evidence {
123            let node = self
124                .config
125                .nodes
126                .get(var.as_str())
127                .ok_or_else(|| format!("Evidence variable '{var}' not found"))?;
128
129            let idx = node
130                .states
131                .iter()
132                .position(|s| s == state)
133                .ok_or_else(|| format!("State '{state}' not found for variable '{var}'"))?;
134
135            evidence_indices.insert(var.clone(), idx);
136        }
137
138        let probs = self.bp.query_with_evidence(target, &evidence_indices)?;
139
140        let node = self
141            .config
142            .nodes
143            .get(target)
144            .ok_or_else(|| format!("Variable '{target}' not found"))?;
145
146        let (max_idx, max_prob) = probs
147            .iter()
148            .enumerate()
149            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
150            .map_or((0, 0.0), |(i, p)| (i, *p));
151
152        Ok(VariableResult {
153            name: target.to_string(),
154            states: node.states.clone(),
155            probabilities: probs,
156            most_likely: node.states.get(max_idx).cloned().unwrap_or_default(),
157            max_probability: max_prob,
158        })
159    }
160
161    /// Query all variables
162    ///
163    /// # Errors
164    ///
165    /// Returns an error if querying any variable fails.
166    pub fn query_all(&self) -> Result<BayesianResult, String> {
167        let mut queries = HashMap::new();
168
169        for name in self.config.nodes.keys() {
170            let result = self.query(name)?;
171            queries.insert(name.clone(), result);
172        }
173
174        Ok(BayesianResult {
175            name: self.config.name.clone(),
176            queries,
177            evidence: HashMap::new(),
178        })
179    }
180
181    /// Query all variables with evidence
182    ///
183    /// # Errors
184    ///
185    /// Returns an error if querying any variable fails or evidence is invalid.
186    pub fn query_all_with_evidence(
187        &self,
188        evidence: &HashMap<String, &str>,
189    ) -> Result<BayesianResult, String> {
190        let mut queries = HashMap::new();
191
192        for name in self.config.nodes.keys() {
193            // Skip evidence variables (their probability is deterministic)
194            if evidence.contains_key(name) {
195                continue;
196            }
197
198            let result = self.query_with_evidence(name, evidence)?;
199            queries.insert(name.clone(), result);
200        }
201
202        // Convert evidence to string map
203        let evidence_str: HashMap<String, String> = evidence
204            .iter()
205            .map(|(k, v)| (k.clone(), (*v).to_string()))
206            .collect();
207
208        Ok(BayesianResult {
209            name: self.config.name.clone(),
210            queries,
211            evidence: evidence_str,
212        })
213    }
214
215    /// Get the most likely explanation (MPE) for all variables
216    ///
217    /// # Errors
218    ///
219    /// Returns an error if querying any variable fails.
220    pub fn most_likely_explanation(&self) -> Result<HashMap<String, String>, String> {
221        let mut explanation = HashMap::new();
222
223        for name in self.config.nodes.keys() {
224            let result = self.query(name)?;
225            explanation.insert(name.clone(), result.most_likely);
226        }
227
228        Ok(explanation)
229    }
230
231    /// Get the configuration
232    #[must_use]
233    pub const fn config(&self) -> &BayesianConfig {
234        &self.config
235    }
236}
237
#[cfg(test)]
mod engine_tests {
    use super::*;
    use crate::bayesian::config::BayesianNode;

    /// Three-node chain shared by the tests:
    /// `economic_conditions` -> `company_revenue` -> `default_probability`.
    fn create_credit_risk_network() -> BayesianConfig {
        let economy = BayesianNode::discrete(vec!["good", "neutral", "bad"])
            .with_prior(vec![0.3, 0.5, 0.2]);

        let revenue = BayesianNode::discrete(vec!["high", "medium", "low"])
            .with_parents(vec!["economic_conditions"])
            .with_cpt_entry("good", vec![0.6, 0.3, 0.1])
            .with_cpt_entry("neutral", vec![0.3, 0.5, 0.2])
            .with_cpt_entry("bad", vec![0.1, 0.3, 0.6]);

        let default_risk = BayesianNode::discrete(vec!["low", "medium", "high"])
            .with_parents(vec!["company_revenue"])
            .with_cpt_entry("high", vec![0.8, 0.15, 0.05])
            .with_cpt_entry("medium", vec![0.4, 0.4, 0.2])
            .with_cpt_entry("low", vec![0.1, 0.3, 0.6]);

        BayesianConfig::new("Credit Risk")
            .with_node("economic_conditions", economy)
            .with_node("company_revenue", revenue)
            .with_node("default_probability", default_risk)
    }

    /// Builds an engine over the credit-risk network.
    fn credit_risk_engine() -> BayesianEngine {
        BayesianEngine::new(create_credit_risk_network()).unwrap()
    }

    /// Asserts that `actual` lies within 0.01 of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 0.01);
    }

    #[test]
    fn test_engine_creation() {
        assert!(BayesianEngine::new(create_credit_risk_network()).is_ok());
    }

    #[test]
    fn test_marginal_query() {
        let result = credit_risk_engine().query("economic_conditions").unwrap();

        assert_eq!(result.states.len(), 3);
        assert_close(result.probabilities[0], 0.3); // good
        assert_close(result.probabilities[1], 0.5); // neutral
        assert_close(result.probabilities[2], 0.2); // bad

        assert_eq!(result.most_likely, "neutral");
    }

    #[test]
    fn test_evidence_query() {
        let evidence = HashMap::from([("economic_conditions".to_string(), "bad")]);

        let result = credit_risk_engine()
            .query_with_evidence("company_revenue", &evidence)
            .unwrap();

        // P(revenue | economy=bad) = [0.1, 0.3, 0.6]
        assert_close(result.probabilities[0], 0.1); // high
        assert_close(result.probabilities[1], 0.3); // medium
        assert_close(result.probabilities[2], 0.6); // low

        assert_eq!(result.most_likely, "low");
    }

    #[test]
    fn test_query_all() {
        let result = credit_risk_engine().query_all().unwrap();

        assert_eq!(result.queries.len(), 3);
        for variable in ["economic_conditions", "company_revenue", "default_probability"] {
            assert!(result.queries.contains_key(variable));
        }
    }

    #[test]
    fn test_most_likely_explanation() {
        let mpe = credit_risk_engine().most_likely_explanation().unwrap();

        for variable in ["economic_conditions", "company_revenue", "default_probability"] {
            assert!(mpe.contains_key(variable));
        }

        // Most likely economy is neutral (0.5)
        assert_eq!(
            mpe.get("economic_conditions").map(String::as_str),
            Some("neutral")
        );
    }

    #[test]
    fn test_yaml_export() {
        let yaml = credit_risk_engine().query_all().unwrap().to_yaml();

        assert!(yaml.contains("queries:"));
        assert!(yaml.contains("economic_conditions"));
    }

    #[test]
    fn test_json_export() {
        let json = credit_risk_engine().query_all().unwrap().to_json().unwrap();

        assert!(json.contains("\"queries\""));
        assert!(json.contains("\"economic_conditions\""));
    }
}