// mollendorff_forge/bayesian/config.rs

use std::collections::{HashMap, HashSet};

use serde::{Deserialize, Serialize};

8#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
10#[serde(rename_all = "lowercase")]
11pub enum NodeType {
12 #[default]
14 Discrete,
15 Continuous,
17}
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct BayesianNode {
22 #[serde(default, rename = "type")]
24 pub node_type: NodeType,
25 #[serde(default)]
27 pub states: Vec<String>,
28 #[serde(default)]
30 pub prior: Vec<f64>,
31 #[serde(default)]
33 pub parents: Vec<String>,
34 #[serde(default)]
37 pub cpt: HashMap<String, Vec<f64>>,
38 #[serde(default)]
40 pub mean: f64,
41 #[serde(default)]
43 pub std: f64,
44}
46impl BayesianNode {
47 pub fn discrete(states: Vec<&str>) -> Self {
49 Self {
50 node_type: NodeType::Discrete,
51 states: states
52 .into_iter()
53 .map(std::string::ToString::to_string)
54 .collect(),
55 prior: Vec::new(),
56 parents: Vec::new(),
57 cpt: HashMap::new(),
58 mean: 0.0,
59 std: 1.0,
60 }
61 }
62
63 #[must_use]
65 pub fn continuous(mean: f64, std: f64) -> Self {
66 Self {
67 node_type: NodeType::Continuous,
68 states: Vec::new(),
69 prior: Vec::new(),
70 parents: Vec::new(),
71 cpt: HashMap::new(),
72 mean,
73 std,
74 }
75 }
76
77 #[must_use]
79 pub fn with_prior(mut self, prior: Vec<f64>) -> Self {
80 self.prior = prior;
81 self
82 }
83
84 #[must_use]
86 pub fn with_parents(mut self, parents: Vec<&str>) -> Self {
87 self.parents = parents
88 .into_iter()
89 .map(std::string::ToString::to_string)
90 .collect();
91 self
92 }
93
94 #[must_use]
96 pub fn with_cpt_entry(mut self, parent_state: &str, probs: Vec<f64>) -> Self {
97 self.cpt.insert(parent_state.to_string(), probs);
98 self
99 }
100
101 pub fn validate(&self, name: &str) -> Result<(), String> {
108 match self.node_type {
109 NodeType::Discrete => self.validate_discrete(name),
110 NodeType::Continuous => self.validate_continuous(name),
111 }
112 }
113
114 fn validate_discrete(&self, name: &str) -> Result<(), String> {
115 if self.states.is_empty() {
116 return Err(format!("Node '{name}': discrete node must have states"));
117 }
118
119 if self.parents.is_empty() {
121 if self.prior.is_empty() {
122 return Err(format!(
123 "Node '{name}': root node must have prior probabilities"
124 ));
125 }
126 if self.prior.len() != self.states.len() {
127 return Err(format!(
128 "Node '{}': prior length ({}) must match states ({})",
129 name,
130 self.prior.len(),
131 self.states.len()
132 ));
133 }
134 let sum: f64 = self.prior.iter().sum();
135 if (sum - 1.0).abs() > 0.001 {
136 return Err(format!(
137 "Node '{name}': prior probabilities must sum to 1.0, got {sum}"
138 ));
139 }
140 } else {
141 if self.cpt.is_empty() {
143 return Err(format!("Node '{name}': child node must have CPT"));
144 }
145 for (key, probs) in &self.cpt {
146 if probs.len() != self.states.len() {
147 return Err(format!(
148 "Node '{}': CPT entry '{}' length ({}) must match states ({})",
149 name,
150 key,
151 probs.len(),
152 self.states.len()
153 ));
154 }
155 let sum: f64 = probs.iter().sum();
156 if (sum - 1.0).abs() > 0.001 {
157 return Err(format!(
158 "Node '{name}': CPT entry '{key}' must sum to 1.0, got {sum}"
159 ));
160 }
161 }
162 }
163
164 Ok(())
165 }
166
167 fn validate_continuous(&self, name: &str) -> Result<(), String> {
168 if self.std <= 0.0 {
169 return Err(format!(
170 "Node '{name}': standard deviation must be positive"
171 ));
172 }
173 Ok(())
174 }
175
176 #[must_use]
178 pub const fn is_root(&self) -> bool {
179 self.parents.is_empty()
180 }
181
182 #[must_use]
184 pub fn get_probability(&self, state_idx: usize, parent_state: Option<&str>) -> f64 {
185 if self.is_root() {
186 self.prior.get(state_idx).copied().unwrap_or(0.0)
187 } else if let Some(ps) = parent_state {
188 self.cpt
189 .get(ps)
190 .and_then(|probs| probs.get(state_idx))
191 .copied()
192 .unwrap_or(0.0)
193 } else {
194 0.0
195 }
196 }
197}
199#[derive(Debug, Clone, Default, Serialize, Deserialize)]
201pub struct BayesianConfig {
202 #[serde(default)]
204 pub name: String,
205 #[serde(default)]
207 pub nodes: HashMap<String, BayesianNode>,
208}
210impl BayesianConfig {
211 #[must_use]
213 pub fn new(name: &str) -> Self {
214 Self {
215 name: name.to_string(),
216 nodes: HashMap::new(),
217 }
218 }
219
220 #[must_use]
222 pub fn with_node(mut self, name: &str, node: BayesianNode) -> Self {
223 self.nodes.insert(name.to_string(), node);
224 self
225 }
226
227 pub fn validate(&self) -> Result<(), String> {
234 if self.nodes.is_empty() {
235 return Err("Network must have at least one node".to_string());
236 }
237
238 for (name, node) in &self.nodes {
240 node.validate(name)?;
241
242 for parent in &node.parents {
244 if !self.nodes.contains_key(parent) {
245 return Err(format!(
246 "Node '{name}' references non-existent parent '{parent}'"
247 ));
248 }
249 }
250 }
251
252 self.check_cycles()?;
254
255 Ok(())
256 }
257
258 fn check_cycles(&self) -> Result<(), String> {
260 let mut visited = std::collections::HashSet::new();
261 let mut stack = std::collections::HashSet::new();
262
263 for name in self.nodes.keys() {
264 self.dfs_cycle_check(name, &mut visited, &mut stack)?;
265 }
266
267 Ok(())
268 }
269
270 fn dfs_cycle_check(
271 &self,
272 name: &str,
273 visited: &mut std::collections::HashSet<String>,
274 stack: &mut std::collections::HashSet<String>,
275 ) -> Result<(), String> {
276 if stack.contains(name) {
277 return Err(format!("Cycle detected involving node '{name}'"));
278 }
279 if visited.contains(name) {
280 return Ok(());
281 }
282
283 visited.insert(name.to_string());
284 stack.insert(name.to_string());
285
286 if let Some(node) = self.nodes.get(name) {
287 for parent in &node.parents {
288 self.dfs_cycle_check(parent, visited, stack)?;
289 }
290 }
291
292 stack.remove(name);
293 Ok(())
294 }
295
296 #[must_use]
298 pub fn topological_order(&self) -> Vec<String> {
299 fn visit(
300 name: &str,
301 config: &BayesianConfig,
302 visited: &mut std::collections::HashSet<String>,
303 order: &mut Vec<String>,
304 ) {
305 if visited.contains(name) {
306 return;
307 }
308 visited.insert(name.to_string());
309
310 if let Some(node) = config.nodes.get(name) {
311 for parent in &node.parents {
312 visit(parent, config, visited, order);
313 }
314 }
315
316 order.push(name.to_string());
317 }
318
319 let mut order = Vec::new();
320 let mut visited = std::collections::HashSet::new();
321
322 for name in self.nodes.keys() {
323 visit(name, self, &mut visited, &mut order);
324 }
325
326 order
327 }
328
329 #[must_use]
331 pub fn root_nodes(&self) -> Vec<&str> {
332 self.nodes
333 .iter()
334 .filter(|(_, node)| node.is_root())
335 .map(|(name, _)| name.as_str())
336 .collect()
337 }
338}
#[cfg(test)]
mod config_tests {
    use super::*;

    /// Three-node chain:
    /// `economic_conditions -> company_revenue -> default_probability`.
    fn create_credit_risk_network() -> BayesianConfig {
        let economy = BayesianNode::discrete(vec!["good", "neutral", "bad"])
            .with_prior(vec![0.3, 0.5, 0.2]);

        let revenue = BayesianNode::discrete(vec!["high", "medium", "low"])
            .with_parents(vec!["economic_conditions"])
            .with_cpt_entry("good", vec![0.6, 0.3, 0.1])
            .with_cpt_entry("neutral", vec![0.3, 0.5, 0.2])
            .with_cpt_entry("bad", vec![0.1, 0.3, 0.6]);

        let default_risk = BayesianNode::discrete(vec!["low", "medium", "high"])
            .with_parents(vec!["company_revenue"])
            .with_cpt_entry("high", vec![0.8, 0.15, 0.05])
            .with_cpt_entry("medium", vec![0.4, 0.4, 0.2])
            .with_cpt_entry("low", vec![0.1, 0.3, 0.6]);

        BayesianConfig::new("Credit Risk")
            .with_node("economic_conditions", economy)
            .with_node("company_revenue", revenue)
            .with_node("default_probability", default_risk)
    }

    /// Index of `name` within `order`; panics if it is absent.
    fn position_of(order: &[String], name: &str) -> usize {
        order.iter().position(|n| n == name).unwrap()
    }

    #[test]
    fn test_config_validation() {
        assert!(create_credit_risk_network().validate().is_ok());
    }

    #[test]
    fn test_empty_network_rejected() {
        let empty = BayesianConfig::new("Empty");
        assert!(empty.validate().is_err());
    }

    #[test]
    fn test_missing_parent_rejected() {
        let orphan = BayesianNode::discrete(vec!["a", "b"])
            .with_parents(vec!["nonexistent"])
            .with_cpt_entry("x", vec![0.5, 0.5]);
        let config = BayesianConfig::new("Bad Ref").with_node("child", orphan);

        assert!(config.validate().is_err());
    }

    #[test]
    fn test_invalid_prior_sum_rejected() {
        let lopsided = BayesianNode::discrete(vec!["a", "b"]).with_prior(vec![0.3, 0.3]);
        let config = BayesianConfig::new("Bad Prior").with_node("node", lopsided);

        assert!(config.validate().is_err());
    }

    #[test]
    fn test_topological_order() {
        let order = create_credit_risk_network().topological_order();
        let [economy, revenue, default_risk] = [
            "economic_conditions",
            "company_revenue",
            "default_probability",
        ]
        .map(|name| position_of(&order, name));

        assert!(economy < revenue && revenue < default_risk);
    }

    #[test]
    fn test_root_nodes() {
        let config = create_credit_risk_network();
        assert_eq!(config.root_nodes(), vec!["economic_conditions"]);
    }
}