// mollendorff_forge/decision_trees/config.rs
use std::collections::{HashMap, HashSet};

use serde::{Deserialize, Serialize};

8#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
10#[serde(rename_all = "lowercase")]
11pub enum NodeType {
12 Decision,
14 Chance,
16 Terminal,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct Branch {
23 #[serde(default)]
25 pub cost: f64,
26 #[serde(default)]
28 pub probability: f64,
29 pub value: Option<f64>,
31 pub next: Option<String>,
33}
34
35impl Branch {
36 #[must_use]
38 pub const fn terminal(value: f64) -> Self {
39 Self {
40 cost: 0.0,
41 probability: 0.0,
42 value: Some(value),
43 next: None,
44 }
45 }
46
47 #[must_use]
49 pub fn continuation(next: &str) -> Self {
50 Self {
51 cost: 0.0,
52 probability: 0.0,
53 value: None,
54 next: Some(next.to_string()),
55 }
56 }
57
58 #[must_use]
60 pub const fn with_cost(mut self, cost: f64) -> Self {
61 self.cost = cost;
62 self
63 }
64
65 #[must_use]
67 pub const fn with_probability(mut self, probability: f64) -> Self {
68 self.probability = probability;
69 self
70 }
71
72 #[must_use]
74 pub const fn is_terminal(&self) -> bool {
75 self.value.is_some() && self.next.is_none()
76 }
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct Node {
82 #[serde(rename = "type")]
84 pub node_type: NodeType,
85 #[serde(default)]
87 pub name: String,
88 pub branches: HashMap<String, Branch>,
90}
91
92impl Node {
93 #[must_use]
95 pub fn decision(name: &str) -> Self {
96 Self {
97 node_type: NodeType::Decision,
98 name: name.to_string(),
99 branches: HashMap::new(),
100 }
101 }
102
103 #[must_use]
105 pub fn chance(name: &str) -> Self {
106 Self {
107 node_type: NodeType::Chance,
108 name: name.to_string(),
109 branches: HashMap::new(),
110 }
111 }
112
113 #[must_use]
115 pub fn with_branch(mut self, name: &str, branch: Branch) -> Self {
116 self.branches.insert(name.to_string(), branch);
117 self
118 }
119
120 pub fn validate(&self) -> Result<(), String> {
127 const TOLERANCE: f64 = 0.001;
128
129 if self.branches.is_empty() {
130 return Err(format!("Node '{}' has no branches", self.name));
131 }
132
133 if self.node_type == NodeType::Chance {
134 let total_prob: f64 = self.branches.values().map(|b| b.probability).sum();
135 if (total_prob - 1.0).abs() > TOLERANCE {
136 return Err(format!(
137 "Chance node '{}' probabilities must sum to 1.0, got {:.4}",
138 self.name, total_prob
139 ));
140 }
141 }
142
143 Ok(())
144 }
145}
146
147#[derive(Debug, Clone, Default, Serialize, Deserialize)]
149pub struct DecisionTreeConfig {
150 #[serde(default)]
152 pub name: String,
153 pub root: Option<Node>,
155 #[serde(default)]
157 pub nodes: HashMap<String, Node>,
158}
159
160impl DecisionTreeConfig {
161 #[must_use]
163 pub fn new(name: &str) -> Self {
164 Self {
165 name: name.to_string(),
166 root: None,
167 nodes: HashMap::new(),
168 }
169 }
170
171 #[must_use]
173 pub fn with_root(mut self, root: Node) -> Self {
174 self.root = Some(root);
175 self
176 }
177
178 #[must_use]
180 pub fn with_node(mut self, name: &str, node: Node) -> Self {
181 self.nodes.insert(name.to_string(), node);
182 self
183 }
184
185 pub fn validate(&self) -> Result<(), String> {
192 let root = self.root.as_ref().ok_or("No root node defined")?;
193 root.validate()?;
194
195 self.validate_references(root)?;
197
198 for (name, node) in &self.nodes {
200 node.validate().map_err(|e| format!("Node '{name}': {e}"))?;
201 self.validate_references(node)?;
202 }
203
204 self.check_cycles()?;
206
207 Ok(())
208 }
209
210 fn validate_references(&self, node: &Node) -> Result<(), String> {
212 for (branch_name, branch) in &node.branches {
213 if let Some(ref next) = branch.next {
214 if !self.nodes.contains_key(next) {
215 return Err(format!(
216 "Branch '{branch_name}' references non-existent node '{next}'"
217 ));
218 }
219 }
220 }
221 Ok(())
222 }
223
224 fn check_cycles(&self) -> Result<(), String> {
226 let mut visited = std::collections::HashSet::new();
227 let mut stack = std::collections::HashSet::new();
228
229 if let Some(ref root) = self.root {
230 self.dfs_cycle_check("root", root, &mut visited, &mut stack)?;
231 }
232
233 Ok(())
234 }
235
236 fn dfs_cycle_check(
237 &self,
238 name: &str,
239 node: &Node,
240 visited: &mut std::collections::HashSet<String>,
241 stack: &mut std::collections::HashSet<String>,
242 ) -> Result<(), String> {
243 if stack.contains(name) {
244 return Err(format!("Cycle detected involving node '{name}'"));
245 }
246 if visited.contains(name) {
247 return Ok(());
248 }
249
250 visited.insert(name.to_string());
251 stack.insert(name.to_string());
252
253 for branch in node.branches.values() {
254 if let Some(ref next) = branch.next {
255 if let Some(next_node) = self.nodes.get(next) {
256 self.dfs_cycle_check(next, next_node, visited, stack)?;
257 }
258 }
259 }
260
261 stack.remove(name);
262 Ok(())
263 }
264
265 #[must_use]
267 pub fn get_node(&self, name: &str) -> Option<&Node> {
268 self.nodes.get(name)
269 }
270}
271
#[cfg(test)]
mod config_tests {
    use super::*;

    /// Two-stage R&D example: invest (cost 2M) into a 60/40 technology outcome
    /// (failure pays -2M); on success choose to license (5M) or manufacture
    /// (8M, cost 3M).
    fn create_rnd_tree() -> DecisionTreeConfig {
        let invest_decision = Node::decision("Invest in R&D?")
            .with_branch("invest", Branch::continuation("tech_outcome").with_cost(2_000_000.0))
            .with_branch("dont_invest", Branch::terminal(0.0));

        let tech_outcome = Node::chance("Technology works?")
            .with_branch("success", Branch::continuation("commercialize").with_probability(0.60))
            .with_branch("failure", Branch::terminal(-2_000_000.0).with_probability(0.40));

        let commercialize = Node::decision("How to commercialize?")
            .with_branch("license", Branch::terminal(5_000_000.0))
            .with_branch("manufacture", Branch::terminal(8_000_000.0).with_cost(3_000_000.0));

        DecisionTreeConfig::new("R&D Investment")
            .with_root(invest_decision)
            .with_node("tech_outcome", tech_outcome)
            .with_node("commercialize", commercialize)
    }

    /// Validates a tree that is expected to be invalid and returns its message.
    fn validation_error(tree: &DecisionTreeConfig) -> String {
        tree.validate()
            .expect_err("tree was expected to fail validation")
    }

    #[test]
    fn test_tree_config_validation() {
        create_rnd_tree()
            .validate()
            .expect("the R&D example tree should be valid");
    }

    #[test]
    fn test_missing_root_rejected() {
        let message = validation_error(&DecisionTreeConfig::new("Empty"));
        assert!(message.contains("No root node"));
    }

    #[test]
    fn test_invalid_reference_rejected() {
        let start = Node::decision("Start").with_branch("go", Branch::continuation("nonexistent"));
        let tree = DecisionTreeConfig::new("Bad Ref").with_root(start);

        let message = validation_error(&tree);
        assert!(message.contains("non-existent node"));
    }

    #[test]
    fn test_chance_probabilities_must_sum_to_one() {
        let coin = Node::chance("Coin flip")
            .with_branch("heads", Branch::terminal(100.0).with_probability(0.5))
            .with_branch("tails", Branch::terminal(0.0).with_probability(0.3));
        let tree = DecisionTreeConfig::new("Bad Probs").with_root(coin);

        let message = validation_error(&tree);
        assert!(message.contains("sum to 1.0"));
    }

    #[test]
    fn test_cycle_detection() {
        let entry = Node::decision("A").with_branch("go", Branch::continuation("b"));
        let self_loop = Node::decision("B").with_branch("back", Branch::continuation("b"));
        let tree = DecisionTreeConfig::new("Cycle")
            .with_root(entry)
            .with_node("b", self_loop);

        let message = validation_error(&tree);
        assert!(message.contains("Cycle"));
    }
}