Skip to main content

mollendorff_forge/bayesian/
config.rs

1//! Bayesian Network Configuration
2//!
3//! Handles parsing and validation of Bayesian network definitions.
4
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8/// Type of node in the Bayesian network
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
10#[serde(rename_all = "lowercase")]
11pub enum NodeType {
12    /// Discrete node with finite states
13    #[default]
14    Discrete,
15    /// Continuous node (Gaussian)
16    Continuous,
17}
18
19/// A node in the Bayesian network
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct BayesianNode {
22    /// Node type
23    #[serde(default, rename = "type")]
24    pub node_type: NodeType,
25    /// Possible states (for discrete nodes)
26    #[serde(default)]
27    pub states: Vec<String>,
28    /// Prior probabilities (for root nodes)
29    #[serde(default)]
30    pub prior: Vec<f64>,
31    /// Parent node names
32    #[serde(default)]
33    pub parents: Vec<String>,
34    /// Conditional probability table (CPT)
35    /// Keys are parent state combinations, values are probabilities for this node's states
36    #[serde(default)]
37    pub cpt: HashMap<String, Vec<f64>>,
38    /// For continuous nodes: mean
39    #[serde(default)]
40    pub mean: f64,
41    /// For continuous nodes: standard deviation
42    #[serde(default)]
43    pub std: f64,
44}
45
46impl BayesianNode {
47    /// Create a new discrete node with states
48    pub fn discrete(states: Vec<&str>) -> Self {
49        Self {
50            node_type: NodeType::Discrete,
51            states: states
52                .into_iter()
53                .map(std::string::ToString::to_string)
54                .collect(),
55            prior: Vec::new(),
56            parents: Vec::new(),
57            cpt: HashMap::new(),
58            mean: 0.0,
59            std: 1.0,
60        }
61    }
62
63    /// Create a new continuous node
64    #[must_use]
65    pub fn continuous(mean: f64, std: f64) -> Self {
66        Self {
67            node_type: NodeType::Continuous,
68            states: Vec::new(),
69            prior: Vec::new(),
70            parents: Vec::new(),
71            cpt: HashMap::new(),
72            mean,
73            std,
74        }
75    }
76
77    /// Set prior probabilities (for root nodes)
78    #[must_use]
79    pub fn with_prior(mut self, prior: Vec<f64>) -> Self {
80        self.prior = prior;
81        self
82    }
83
84    /// Set parent nodes
85    #[must_use]
86    pub fn with_parents(mut self, parents: Vec<&str>) -> Self {
87        self.parents = parents
88            .into_iter()
89            .map(std::string::ToString::to_string)
90            .collect();
91        self
92    }
93
94    /// Add CPT entry
95    #[must_use]
96    pub fn with_cpt_entry(mut self, parent_state: &str, probs: Vec<f64>) -> Self {
97        self.cpt.insert(parent_state.to_string(), probs);
98        self
99    }
100
101    /// Validate the node
102    ///
103    /// # Errors
104    ///
105    /// Returns an error if the node configuration is invalid (e.g., missing
106    /// states, incorrect prior/CPT dimensions, or probabilities not summing to 1).
107    pub fn validate(&self, name: &str) -> Result<(), String> {
108        match self.node_type {
109            NodeType::Discrete => self.validate_discrete(name),
110            NodeType::Continuous => self.validate_continuous(name),
111        }
112    }
113
114    fn validate_discrete(&self, name: &str) -> Result<(), String> {
115        if self.states.is_empty() {
116            return Err(format!("Node '{name}': discrete node must have states"));
117        }
118
119        // If root node, check prior
120        if self.parents.is_empty() {
121            if self.prior.is_empty() {
122                return Err(format!(
123                    "Node '{name}': root node must have prior probabilities"
124                ));
125            }
126            if self.prior.len() != self.states.len() {
127                return Err(format!(
128                    "Node '{}': prior length ({}) must match states ({})",
129                    name,
130                    self.prior.len(),
131                    self.states.len()
132                ));
133            }
134            let sum: f64 = self.prior.iter().sum();
135            if (sum - 1.0).abs() > 0.001 {
136                return Err(format!(
137                    "Node '{name}': prior probabilities must sum to 1.0, got {sum}"
138                ));
139            }
140        } else {
141            // Child node, check CPT
142            if self.cpt.is_empty() {
143                return Err(format!("Node '{name}': child node must have CPT"));
144            }
145            for (key, probs) in &self.cpt {
146                if probs.len() != self.states.len() {
147                    return Err(format!(
148                        "Node '{}': CPT entry '{}' length ({}) must match states ({})",
149                        name,
150                        key,
151                        probs.len(),
152                        self.states.len()
153                    ));
154                }
155                let sum: f64 = probs.iter().sum();
156                if (sum - 1.0).abs() > 0.001 {
157                    return Err(format!(
158                        "Node '{name}': CPT entry '{key}' must sum to 1.0, got {sum}"
159                    ));
160                }
161            }
162        }
163
164        Ok(())
165    }
166
167    fn validate_continuous(&self, name: &str) -> Result<(), String> {
168        if self.std <= 0.0 {
169            return Err(format!(
170                "Node '{name}': standard deviation must be positive"
171            ));
172        }
173        Ok(())
174    }
175
176    /// Check if this is a root node
177    #[must_use]
178    pub const fn is_root(&self) -> bool {
179        self.parents.is_empty()
180    }
181
182    /// Get probability for a state given parent state
183    #[must_use]
184    pub fn get_probability(&self, state_idx: usize, parent_state: Option<&str>) -> f64 {
185        if self.is_root() {
186            self.prior.get(state_idx).copied().unwrap_or(0.0)
187        } else if let Some(ps) = parent_state {
188            self.cpt
189                .get(ps)
190                .and_then(|probs| probs.get(state_idx))
191                .copied()
192                .unwrap_or(0.0)
193        } else {
194            0.0
195        }
196    }
197}
198
199/// Configuration for Bayesian network
200#[derive(Debug, Clone, Default, Serialize, Deserialize)]
201pub struct BayesianConfig {
202    /// Network name
203    #[serde(default)]
204    pub name: String,
205    /// Nodes by name
206    #[serde(default)]
207    pub nodes: HashMap<String, BayesianNode>,
208}
209
210impl BayesianConfig {
211    /// Create a new configuration
212    #[must_use]
213    pub fn new(name: &str) -> Self {
214        Self {
215            name: name.to_string(),
216            nodes: HashMap::new(),
217        }
218    }
219
220    /// Add a node
221    #[must_use]
222    pub fn with_node(mut self, name: &str, node: BayesianNode) -> Self {
223        self.nodes.insert(name.to_string(), node);
224        self
225    }
226
227    /// Validate the configuration
228    ///
229    /// # Errors
230    ///
231    /// Returns an error if the network is empty, any node is invalid,
232    /// parent references are missing, or the graph contains cycles.
233    pub fn validate(&self) -> Result<(), String> {
234        if self.nodes.is_empty() {
235            return Err("Network must have at least one node".to_string());
236        }
237
238        // Validate each node
239        for (name, node) in &self.nodes {
240            node.validate(name)?;
241
242            // Check parent references
243            for parent in &node.parents {
244                if !self.nodes.contains_key(parent) {
245                    return Err(format!(
246                        "Node '{name}' references non-existent parent '{parent}'"
247                    ));
248                }
249            }
250        }
251
252        // Check for cycles
253        self.check_cycles()?;
254
255        Ok(())
256    }
257
258    /// Check for cycles (must be a DAG)
259    fn check_cycles(&self) -> Result<(), String> {
260        let mut visited = std::collections::HashSet::new();
261        let mut stack = std::collections::HashSet::new();
262
263        for name in self.nodes.keys() {
264            self.dfs_cycle_check(name, &mut visited, &mut stack)?;
265        }
266
267        Ok(())
268    }
269
270    fn dfs_cycle_check(
271        &self,
272        name: &str,
273        visited: &mut std::collections::HashSet<String>,
274        stack: &mut std::collections::HashSet<String>,
275    ) -> Result<(), String> {
276        if stack.contains(name) {
277            return Err(format!("Cycle detected involving node '{name}'"));
278        }
279        if visited.contains(name) {
280            return Ok(());
281        }
282
283        visited.insert(name.to_string());
284        stack.insert(name.to_string());
285
286        if let Some(node) = self.nodes.get(name) {
287            for parent in &node.parents {
288                self.dfs_cycle_check(parent, visited, stack)?;
289            }
290        }
291
292        stack.remove(name);
293        Ok(())
294    }
295
296    /// Get topological order of nodes
297    #[must_use]
298    pub fn topological_order(&self) -> Vec<String> {
299        fn visit(
300            name: &str,
301            config: &BayesianConfig,
302            visited: &mut std::collections::HashSet<String>,
303            order: &mut Vec<String>,
304        ) {
305            if visited.contains(name) {
306                return;
307            }
308            visited.insert(name.to_string());
309
310            if let Some(node) = config.nodes.get(name) {
311                for parent in &node.parents {
312                    visit(parent, config, visited, order);
313                }
314            }
315
316            order.push(name.to_string());
317        }
318
319        let mut order = Vec::new();
320        let mut visited = std::collections::HashSet::new();
321
322        for name in self.nodes.keys() {
323            visit(name, self, &mut visited, &mut order);
324        }
325
326        order
327    }
328
329    /// Get root nodes
330    #[must_use]
331    pub fn root_nodes(&self) -> Vec<&str> {
332        self.nodes
333            .iter()
334            .filter(|(_, node)| node.is_root())
335            .map(|(name, _)| name.as_str())
336            .collect()
337    }
338}
339
#[cfg(test)]
mod config_tests {
    use super::*;

    /// Three-node chain shared by several tests:
    /// `economic_conditions -> company_revenue -> default_probability`.
    fn create_credit_risk_network() -> BayesianConfig {
        let economy = BayesianNode::discrete(vec!["good", "neutral", "bad"])
            .with_prior(vec![0.3, 0.5, 0.2]);

        let revenue = BayesianNode::discrete(vec!["high", "medium", "low"])
            .with_parents(vec!["economic_conditions"])
            .with_cpt_entry("good", vec![0.6, 0.3, 0.1])
            .with_cpt_entry("neutral", vec![0.3, 0.5, 0.2])
            .with_cpt_entry("bad", vec![0.1, 0.3, 0.6]);

        let default_risk = BayesianNode::discrete(vec!["low", "medium", "high"])
            .with_parents(vec!["company_revenue"])
            .with_cpt_entry("high", vec![0.8, 0.15, 0.05])
            .with_cpt_entry("medium", vec![0.4, 0.4, 0.2])
            .with_cpt_entry("low", vec![0.1, 0.3, 0.6]);

        BayesianConfig::new("Credit Risk")
            .with_node("economic_conditions", economy)
            .with_node("company_revenue", revenue)
            .with_node("default_probability", default_risk)
    }

    /// Position of `name` within `order`; panics if it is absent.
    fn index_of(order: &[String], name: &str) -> usize {
        order.iter().position(|n| n == name).unwrap()
    }

    #[test]
    fn test_config_validation() {
        let outcome = create_credit_risk_network().validate();
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_empty_network_rejected() {
        let outcome = BayesianConfig::new("Empty").validate();
        assert!(outcome.is_err());
    }

    #[test]
    fn test_missing_parent_rejected() {
        let orphan = BayesianNode::discrete(vec!["a", "b"])
            .with_parents(vec!["nonexistent"])
            .with_cpt_entry("x", vec![0.5, 0.5]);
        let config = BayesianConfig::new("Bad Ref").with_node("child", orphan);

        assert!(config.validate().is_err());
    }

    #[test]
    fn test_invalid_prior_sum_rejected() {
        let lopsided = BayesianNode::discrete(vec!["a", "b"]).with_prior(vec![0.3, 0.3]);
        let config = BayesianConfig::new("Bad Prior").with_node("node", lopsided);

        assert!(config.validate().is_err());
    }

    #[test]
    fn test_topological_order() {
        let order = create_credit_risk_network().topological_order();

        let economy = index_of(&order, "economic_conditions");
        let revenue = index_of(&order, "company_revenue");
        let risk = index_of(&order, "default_probability");

        // Each parent must be scheduled before its child.
        assert!(economy < revenue);
        assert!(revenue < risk);
    }

    #[test]
    fn test_root_nodes() {
        let config = create_credit_risk_network();
        let roots = config.root_nodes();

        assert_eq!(roots.len(), 1);
        assert!(roots.contains(&"economic_conditions"));
    }
}
427
428        assert_eq!(roots.len(), 1);
429        assert!(roots.contains(&"economic_conditions"));
430    }
431}