Skip to main content

mollendorff_forge/decision_trees/
config.rs

1//! Decision Tree Configuration
2//!
3//! Handles parsing and validation of decision tree structures from YAML.
4
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
/// Type of node in the decision tree.
///
/// Serialized in lowercase (`"decision"`, `"chance"`, `"terminal"`) because of
/// `rename_all = "lowercase"` below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum NodeType {
    /// Choice point (we control)
    Decision,
    /// Uncertainty (we don't control); `Node::validate` requires the branch
    /// probabilities of a node of this type to sum to 1.0.
    Chance,
    /// End state with value
    ///
    /// NOTE(review): `Node` has no value field and `Node::validate` still
    /// requires at least one branch, so how a terminal *node* carries its value
    /// is not visible in this file — confirm with the evaluator.
    Terminal,
}
19
/// A branch from a node.
///
/// A branch either ends the tree (`value` set and `next` unset — see
/// [`Branch::is_terminal`]) or continues at the node stored under the key
/// `next` in [`DecisionTreeConfig::nodes`].
///
/// NOTE(review): nothing in this file rejects a branch that sets both `value`
/// and `next`, or neither of them — confirm whether that is intended.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Branch {
    /// Cost incurred when taking this branch (for decision nodes).
    /// Defaults to `0.0` when absent from the serialized form.
    #[serde(default)]
    pub cost: f64,
    /// Probability of this outcome (for chance nodes, must sum to 1.0).
    /// Defaults to `0.0` when absent from the serialized form.
    #[serde(default)]
    pub probability: f64,
    /// Terminal value if this is an endpoint
    pub value: Option<f64>,
    /// Key of the next node (in `DecisionTreeConfig::nodes`) if not terminal
    pub next: Option<String>,
}
34
35impl Branch {
36    /// Create a terminal branch with a value
37    #[must_use]
38    pub const fn terminal(value: f64) -> Self {
39        Self {
40            cost: 0.0,
41            probability: 0.0,
42            value: Some(value),
43            next: None,
44        }
45    }
46
47    /// Create a continuation branch
48    #[must_use]
49    pub fn continuation(next: &str) -> Self {
50        Self {
51            cost: 0.0,
52            probability: 0.0,
53            value: None,
54            next: Some(next.to_string()),
55        }
56    }
57
58    /// Add a cost to this branch
59    #[must_use]
60    pub const fn with_cost(mut self, cost: f64) -> Self {
61        self.cost = cost;
62        self
63    }
64
65    /// Add a probability to this branch
66    #[must_use]
67    pub const fn with_probability(mut self, probability: f64) -> Self {
68        self.probability = probability;
69        self
70    }
71
72    /// Check if this branch is terminal
73    #[must_use]
74    pub const fn is_terminal(&self) -> bool {
75        self.value.is_some() && self.next.is_none()
76    }
77}
78
/// A node in the decision tree.
///
/// Non-root nodes are stored in [`DecisionTreeConfig::nodes`] and reached via
/// a branch's `next` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Node {
    /// Type of node; serialized under the key `type`
    #[serde(rename = "type")]
    pub node_type: NodeType,
    /// Human-readable name (empty string when absent); used in error messages
    #[serde(default)]
    pub name: String,
    /// Branches from this node, keyed by branch name
    pub branches: HashMap<String, Branch>,
}
91
92impl Node {
93    /// Create a new decision node
94    #[must_use]
95    pub fn decision(name: &str) -> Self {
96        Self {
97            node_type: NodeType::Decision,
98            name: name.to_string(),
99            branches: HashMap::new(),
100        }
101    }
102
103    /// Create a new chance node
104    #[must_use]
105    pub fn chance(name: &str) -> Self {
106        Self {
107            node_type: NodeType::Chance,
108            name: name.to_string(),
109            branches: HashMap::new(),
110        }
111    }
112
113    /// Add a branch to this node
114    #[must_use]
115    pub fn with_branch(mut self, name: &str, branch: Branch) -> Self {
116        self.branches.insert(name.to_string(), branch);
117        self
118    }
119
120    /// Validate node structure
121    ///
122    /// # Errors
123    ///
124    /// Returns an error if the node has no branches or chance node probabilities
125    /// do not sum to 1.0.
126    pub fn validate(&self) -> Result<(), String> {
127        const TOLERANCE: f64 = 0.001;
128
129        if self.branches.is_empty() {
130            return Err(format!("Node '{}' has no branches", self.name));
131        }
132
133        if self.node_type == NodeType::Chance {
134            let total_prob: f64 = self.branches.values().map(|b| b.probability).sum();
135            if (total_prob - 1.0).abs() > TOLERANCE {
136                return Err(format!(
137                    "Chance node '{}' probabilities must sum to 1.0, got {:.4}",
138                    self.name, total_prob
139                ));
140            }
141        }
142
143        Ok(())
144    }
145}
146
/// Configuration for a decision tree.
///
/// The tree is entered at `root`; every other node lives in `nodes` and is
/// referenced by key from a branch's `next` field.
/// [`DecisionTreeConfig::validate`] checks that those references resolve and
/// rejects cycles.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DecisionTreeConfig {
    /// Name of the decision tree (empty string when absent)
    #[serde(default)]
    pub name: String,
    /// Root node definition; `validate` fails when this is `None`
    pub root: Option<Node>,
    /// Additional nodes by name (empty map when absent); branch `next` keys
    /// refer to entries of this map
    #[serde(default)]
    pub nodes: HashMap<String, Node>,
}
159
160impl DecisionTreeConfig {
161    /// Create a new empty configuration
162    #[must_use]
163    pub fn new(name: &str) -> Self {
164        Self {
165            name: name.to_string(),
166            root: None,
167            nodes: HashMap::new(),
168        }
169    }
170
171    /// Set the root node
172    #[must_use]
173    pub fn with_root(mut self, root: Node) -> Self {
174        self.root = Some(root);
175        self
176    }
177
178    /// Add a named node
179    #[must_use]
180    pub fn with_node(mut self, name: &str, node: Node) -> Self {
181        self.nodes.insert(name.to_string(), node);
182        self
183    }
184
185    /// Validate the configuration
186    ///
187    /// # Errors
188    ///
189    /// Returns an error if the tree has no root, any node is invalid,
190    /// references are broken, or the graph contains cycles.
191    pub fn validate(&self) -> Result<(), String> {
192        let root = self.root.as_ref().ok_or("No root node defined")?;
193        root.validate()?;
194
195        // Validate all referenced nodes exist
196        self.validate_references(root)?;
197
198        // Validate all nodes
199        for (name, node) in &self.nodes {
200            node.validate().map_err(|e| format!("Node '{name}': {e}"))?;
201            self.validate_references(node)?;
202        }
203
204        // Check for cycles
205        self.check_cycles()?;
206
207        Ok(())
208    }
209
210    /// Validate that all referenced nodes exist
211    fn validate_references(&self, node: &Node) -> Result<(), String> {
212        for (branch_name, branch) in &node.branches {
213            if let Some(ref next) = branch.next {
214                if !self.nodes.contains_key(next) {
215                    return Err(format!(
216                        "Branch '{branch_name}' references non-existent node '{next}'"
217                    ));
218                }
219            }
220        }
221        Ok(())
222    }
223
224    /// Check for cycles in the tree (must be a DAG)
225    fn check_cycles(&self) -> Result<(), String> {
226        let mut visited = std::collections::HashSet::new();
227        let mut stack = std::collections::HashSet::new();
228
229        if let Some(ref root) = self.root {
230            self.dfs_cycle_check("root", root, &mut visited, &mut stack)?;
231        }
232
233        Ok(())
234    }
235
236    fn dfs_cycle_check(
237        &self,
238        name: &str,
239        node: &Node,
240        visited: &mut std::collections::HashSet<String>,
241        stack: &mut std::collections::HashSet<String>,
242    ) -> Result<(), String> {
243        if stack.contains(name) {
244            return Err(format!("Cycle detected involving node '{name}'"));
245        }
246        if visited.contains(name) {
247            return Ok(());
248        }
249
250        visited.insert(name.to_string());
251        stack.insert(name.to_string());
252
253        for branch in node.branches.values() {
254            if let Some(ref next) = branch.next {
255                if let Some(next_node) = self.nodes.get(next) {
256                    self.dfs_cycle_check(next, next_node, visited, stack)?;
257                }
258            }
259        }
260
261        stack.remove(name);
262        Ok(())
263    }
264
265    /// Get a node by name
266    #[must_use]
267    pub fn get_node(&self, name: &str) -> Option<&Node> {
268        self.nodes.get(name)
269    }
270}
271
#[cfg(test)]
mod config_tests {
    use super::*;

    /// R&D example: investing costs 2M and leads to a 60/40 technology outcome;
    /// on success we choose between licensing (5M) and manufacturing (8M, 3M cost).
    fn rnd_investment_tree() -> DecisionTreeConfig {
        let root = Node::decision("Invest in R&D?")
            .with_branch(
                "invest",
                Branch::continuation("tech_outcome").with_cost(2_000_000.0),
            )
            .with_branch("dont_invest", Branch::terminal(0.0));

        let tech_outcome = Node::chance("Technology works?")
            .with_branch(
                "success",
                Branch::continuation("commercialize").with_probability(0.60),
            )
            .with_branch(
                "failure",
                Branch::terminal(-2_000_000.0).with_probability(0.40),
            );

        let commercialize = Node::decision("How to commercialize?")
            .with_branch("license", Branch::terminal(5_000_000.0))
            .with_branch(
                "manufacture",
                Branch::terminal(8_000_000.0).with_cost(3_000_000.0),
            );

        DecisionTreeConfig::new("R&D Investment")
            .with_root(root)
            .with_node("tech_outcome", tech_outcome)
            .with_node("commercialize", commercialize)
    }

    /// Assert that validating `tree` fails with a message containing `needle`.
    fn assert_rejected(tree: &DecisionTreeConfig, needle: &str) {
        match tree.validate() {
            Ok(()) => panic!("expected validation to fail with '{}'", needle),
            Err(message) => assert!(
                message.contains(needle),
                "error '{}' does not mention '{}'",
                message,
                needle
            ),
        }
    }

    #[test]
    fn test_tree_config_validation() {
        assert!(rnd_investment_tree().validate().is_ok());
    }

    #[test]
    fn test_missing_root_rejected() {
        assert_rejected(&DecisionTreeConfig::new("Empty"), "No root node");
    }

    #[test]
    fn test_invalid_reference_rejected() {
        let start = Node::decision("Start").with_branch("go", Branch::continuation("nonexistent"));
        let tree = DecisionTreeConfig::new("Bad Ref").with_root(start);

        assert_rejected(&tree, "non-existent node");
    }

    #[test]
    fn test_chance_probabilities_must_sum_to_one() {
        let coin = Node::chance("Coin flip")
            .with_branch("heads", Branch::terminal(100.0).with_probability(0.5))
            .with_branch("tails", Branch::terminal(0.0).with_probability(0.3));
        let tree = DecisionTreeConfig::new("Bad Probs").with_root(coin);

        assert_rejected(&tree, "sum to 1.0");
    }

    #[test]
    fn test_cycle_detection() {
        let a = Node::decision("A").with_branch("go", Branch::continuation("b"));
        let b = Node::decision("B").with_branch("back", Branch::continuation("b"));
        let tree = DecisionTreeConfig::new("Cycle").with_root(a).with_node("b", b);

        assert_rejected(&tree, "Cycle");
    }
}