mollendorff_forge/bayesian/
engine.rs1use super::config::BayesianConfig;
6use super::inference::BeliefPropagation;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
/// Distribution of a single variable, as returned by a [`BayesianEngine`] query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VariableResult {
    /// Name of the queried variable.
    pub name: String,
    /// State names of the variable, in the order configured on its node.
    pub states: Vec<String>,
    /// Probability of each state; entry `i` corresponds to `states[i]`.
    pub probabilities: Vec<f64>,
    /// State with the highest entry in `probabilities`.
    pub most_likely: String,
    /// Probability of `most_likely`.
    pub max_probability: f64,
}
24
25impl VariableResult {
26 #[must_use]
28 pub fn get_probability(&self, state: &str) -> Option<f64> {
29 self.states
30 .iter()
31 .position(|s| s == state)
32 .map(|idx| self.probabilities[idx])
33 }
34}
35
/// Per-variable query results for a whole network, together with the evidence used.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BayesianResult {
    /// Network name, copied from the engine's configuration.
    pub name: String,
    /// Query results keyed by variable name.
    pub queries: HashMap<String, VariableResult>,
    /// Observed evidence as variable name -> state name; empty for unconditioned queries.
    pub evidence: HashMap<String, String>,
}
46
47impl BayesianResult {
48 #[must_use]
50 pub fn to_yaml(&self) -> String {
51 serde_yaml_ng::to_string(self).unwrap_or_else(|_| "# Error serializing results".to_string())
52 }
53
54 pub fn to_json(&self) -> Result<String, serde_json::Error> {
60 serde_json::to_string_pretty(self)
61 }
62}
63
/// Query front-end for a Bayesian network whose variables have named states.
///
/// Holds the network configuration, which is used to translate between state
/// names and state indices. Alongside it is the belief-propagation state that
/// performs the actual inference.
pub struct BayesianEngine {
    /// Network definition; consulted for node lookups and state lists.
    config: BayesianConfig,
    /// Inference state, built from a clone of `config` when the engine is created.
    bp: BeliefPropagation,
}
69
70impl BayesianEngine {
71 pub fn new(config: BayesianConfig) -> Result<Self, String> {
77 let bp = BeliefPropagation::new(config.clone())?;
78 Ok(Self { config, bp })
79 }
80
81 pub fn query(&self, target: &str) -> Result<VariableResult, String> {
87 let probs = self.bp.query(target)?;
88
89 let node = self
90 .config
91 .nodes
92 .get(target)
93 .ok_or_else(|| format!("Variable '{target}' not found"))?;
94
95 let (max_idx, max_prob) = probs
96 .iter()
97 .enumerate()
98 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
99 .map_or((0, 0.0), |(i, p)| (i, *p));
100
101 Ok(VariableResult {
102 name: target.to_string(),
103 states: node.states.clone(),
104 probabilities: probs,
105 most_likely: node.states.get(max_idx).cloned().unwrap_or_default(),
106 max_probability: max_prob,
107 })
108 }
109
110 pub fn query_with_evidence(
116 &self,
117 target: &str,
118 evidence: &HashMap<String, &str>,
119 ) -> Result<VariableResult, String> {
120 let mut evidence_indices = HashMap::new();
122 for (var, state) in evidence {
123 let node = self
124 .config
125 .nodes
126 .get(var.as_str())
127 .ok_or_else(|| format!("Evidence variable '{var}' not found"))?;
128
129 let idx = node
130 .states
131 .iter()
132 .position(|s| s == state)
133 .ok_or_else(|| format!("State '{state}' not found for variable '{var}'"))?;
134
135 evidence_indices.insert(var.clone(), idx);
136 }
137
138 let probs = self.bp.query_with_evidence(target, &evidence_indices)?;
139
140 let node = self
141 .config
142 .nodes
143 .get(target)
144 .ok_or_else(|| format!("Variable '{target}' not found"))?;
145
146 let (max_idx, max_prob) = probs
147 .iter()
148 .enumerate()
149 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
150 .map_or((0, 0.0), |(i, p)| (i, *p));
151
152 Ok(VariableResult {
153 name: target.to_string(),
154 states: node.states.clone(),
155 probabilities: probs,
156 most_likely: node.states.get(max_idx).cloned().unwrap_or_default(),
157 max_probability: max_prob,
158 })
159 }
160
161 pub fn query_all(&self) -> Result<BayesianResult, String> {
167 let mut queries = HashMap::new();
168
169 for name in self.config.nodes.keys() {
170 let result = self.query(name)?;
171 queries.insert(name.clone(), result);
172 }
173
174 Ok(BayesianResult {
175 name: self.config.name.clone(),
176 queries,
177 evidence: HashMap::new(),
178 })
179 }
180
181 pub fn query_all_with_evidence(
187 &self,
188 evidence: &HashMap<String, &str>,
189 ) -> Result<BayesianResult, String> {
190 let mut queries = HashMap::new();
191
192 for name in self.config.nodes.keys() {
193 if evidence.contains_key(name) {
195 continue;
196 }
197
198 let result = self.query_with_evidence(name, evidence)?;
199 queries.insert(name.clone(), result);
200 }
201
202 let evidence_str: HashMap<String, String> = evidence
204 .iter()
205 .map(|(k, v)| (k.clone(), (*v).to_string()))
206 .collect();
207
208 Ok(BayesianResult {
209 name: self.config.name.clone(),
210 queries,
211 evidence: evidence_str,
212 })
213 }
214
215 pub fn most_likely_explanation(&self) -> Result<HashMap<String, String>, String> {
221 let mut explanation = HashMap::new();
222
223 for name in self.config.nodes.keys() {
224 let result = self.query(name)?;
225 explanation.insert(name.clone(), result.most_likely);
226 }
227
228 Ok(explanation)
229 }
230
231 #[must_use]
233 pub const fn config(&self) -> &BayesianConfig {
234 &self.config
235 }
236}
237
#[cfg(test)]
mod engine_tests {
    use super::*;
    use crate::bayesian::config::BayesianNode;

    /// Asserts that `actual` is within 0.01 of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 0.01,
            "expected {expected}, got {actual}"
        );
    }

    /// Three-node chain: economic_conditions -> company_revenue -> default_probability.
    fn create_credit_risk_network() -> BayesianConfig {
        BayesianConfig::new("Credit Risk")
            .with_node(
                "economic_conditions",
                BayesianNode::discrete(vec!["good", "neutral", "bad"])
                    .with_prior(vec![0.3, 0.5, 0.2]),
            )
            .with_node(
                "company_revenue",
                BayesianNode::discrete(vec!["high", "medium", "low"])
                    .with_parents(vec!["economic_conditions"])
                    .with_cpt_entry("good", vec![0.6, 0.3, 0.1])
                    .with_cpt_entry("neutral", vec![0.3, 0.5, 0.2])
                    .with_cpt_entry("bad", vec![0.1, 0.3, 0.6]),
            )
            .with_node(
                "default_probability",
                BayesianNode::discrete(vec!["low", "medium", "high"])
                    .with_parents(vec!["company_revenue"])
                    .with_cpt_entry("high", vec![0.8, 0.15, 0.05])
                    .with_cpt_entry("medium", vec![0.4, 0.4, 0.2])
                    .with_cpt_entry("low", vec![0.1, 0.3, 0.6]),
            )
    }

    fn credit_risk_engine() -> BayesianEngine {
        BayesianEngine::new(create_credit_risk_network()).expect("credit risk network is valid")
    }

    fn bad_economy_evidence() -> HashMap<String, &'static str> {
        let mut evidence = HashMap::new();
        evidence.insert("economic_conditions".to_string(), "bad");
        evidence
    }

    #[test]
    fn test_engine_creation() {
        let config = create_credit_risk_network();
        let engine = BayesianEngine::new(config);
        assert!(engine.is_ok());
    }

    #[test]
    fn test_marginal_query() {
        let engine = credit_risk_engine();

        let result = engine.query("economic_conditions").unwrap();

        assert_eq!(result.states.len(), 3);
        assert_close(result.probabilities[0], 0.3);
        assert_close(result.probabilities[1], 0.5);
        assert_close(result.probabilities[2], 0.2);
        assert_eq!(result.most_likely, "neutral");
        assert_close(result.max_probability, 0.5);
    }

    #[test]
    fn test_query_unknown_variable_fails() {
        let engine = credit_risk_engine();
        assert!(engine.query("interest_rates").is_err());
    }

    #[test]
    fn test_get_probability() {
        let engine = credit_risk_engine();
        let result = engine.query("economic_conditions").unwrap();

        assert_close(result.get_probability("good").unwrap(), 0.3);
        assert_close(result.get_probability("bad").unwrap(), 0.2);
        assert!(result.get_probability("terrible").is_none());
    }

    #[test]
    fn test_evidence_query() {
        let engine = credit_risk_engine();
        let evidence = bad_economy_evidence();

        let result = engine
            .query_with_evidence("company_revenue", &evidence)
            .unwrap();

        assert_close(result.probabilities[0], 0.1);
        assert_close(result.probabilities[1], 0.3);
        assert_close(result.probabilities[2], 0.6);
        assert_eq!(result.most_likely, "low");
    }

    #[test]
    fn test_evidence_with_unknown_state_is_rejected() {
        let engine = credit_risk_engine();
        let mut evidence = HashMap::new();
        evidence.insert("economic_conditions".to_string(), "terrible");

        let err = engine
            .query_with_evidence("company_revenue", &evidence)
            .unwrap_err();

        assert!(err.contains("terrible"), "unexpected error: {err}");
    }

    #[test]
    fn test_evidence_with_unknown_variable_is_rejected() {
        let engine = credit_risk_engine();
        let mut evidence = HashMap::new();
        evidence.insert("interest_rates".to_string(), "high");

        let err = engine
            .query_with_evidence("company_revenue", &evidence)
            .unwrap_err();

        assert!(err.contains("interest_rates"), "unexpected error: {err}");
    }

    #[test]
    fn test_query_all() {
        let engine = credit_risk_engine();

        let result = engine.query_all().unwrap();

        assert_eq!(result.queries.len(), 3);
        assert!(result.queries.contains_key("economic_conditions"));
        assert!(result.queries.contains_key("company_revenue"));
        assert!(result.queries.contains_key("default_probability"));
        assert!(result.evidence.is_empty());
    }

    #[test]
    fn test_query_all_with_evidence() {
        let engine = credit_risk_engine();
        let evidence = bad_economy_evidence();

        let result = engine.query_all_with_evidence(&evidence).unwrap();

        // Observed variables are not re-queried, but the evidence is echoed back.
        assert_eq!(result.queries.len(), 2);
        assert!(!result.queries.contains_key("economic_conditions"));
        assert_eq!(
            result.evidence.get("economic_conditions").map(String::as_str),
            Some("bad")
        );

        let revenue = &result.queries["company_revenue"];
        assert_eq!(revenue.most_likely, "low");

        // P(default | bad economy) = sum over revenue of P(default | revenue) * P(revenue | bad).
        let default_dist = &result.queries["default_probability"];
        assert_close(default_dist.probabilities[0], 0.26);
        assert_close(default_dist.probabilities[1], 0.315);
        assert_close(default_dist.probabilities[2], 0.425);
        assert_eq!(default_dist.most_likely, "high");
    }

    #[test]
    fn test_most_likely_explanation() {
        let engine = credit_risk_engine();

        let mpe = engine.most_likely_explanation().unwrap();

        assert!(mpe.contains_key("economic_conditions"));
        assert!(mpe.contains_key("company_revenue"));
        assert!(mpe.contains_key("default_probability"));

        assert_eq!(mpe.get("economic_conditions"), Some(&"neutral".to_string()));
    }

    #[test]
    fn test_yaml_export() {
        let engine = credit_risk_engine();
        let result = engine.query_all().unwrap();
        let yaml = result.to_yaml();

        assert!(yaml.contains("queries:"));
        assert!(yaml.contains("economic_conditions"));
    }

    #[test]
    fn test_json_export() {
        let engine = credit_risk_engine();
        let result = engine.query_all().unwrap();
        let json = result.to_json().unwrap();

        assert!(json.contains("\"queries\""));
        assert!(json.contains("\"economic_conditions\""));
    }
}