//! Backward-induction engine for decision trees.
use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use super::config::{Branch, DecisionTreeConfig, Node, NodeType};

9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct NodeResult {
12 pub name: String,
14 pub node_type: NodeType,
16 pub expected_value: f64,
18 pub optimal_choice: Option<String>,
20 pub branch_values: HashMap<String, f64>,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct TreeResult {
27 pub name: String,
29 pub root_expected_value: f64,
31 pub node_results: HashMap<String, NodeResult>,
33 pub optimal_path: Vec<String>,
35 pub decision_policy: HashMap<String, String>,
37 pub risk_profile: RiskProfile,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct RiskProfile {
44 pub best_case: f64,
46 pub worst_case: f64,
48 pub probability_positive: f64,
50}
51
52impl TreeResult {
53 #[must_use]
55 pub fn to_yaml(&self) -> String {
56 serde_yaml_ng::to_string(self).unwrap_or_else(|_| "# Error serializing results".to_string())
57 }
58
59 pub fn to_json(&self) -> Result<String, serde_json::Error> {
65 serde_json::to_string_pretty(self)
66 }
67}
68
69pub struct DecisionTreeEngine {
71 config: DecisionTreeConfig,
72}
73
74impl DecisionTreeEngine {
75 pub fn new(config: DecisionTreeConfig) -> Result<Self, String> {
81 config.validate()?;
82 Ok(Self { config })
83 }
84
85 pub fn analyze(&self) -> Result<TreeResult, String> {
91 let mut node_results = HashMap::new();
92 let mut all_terminal_values = Vec::new();
93
94 let root = self.config.root.as_ref().ok_or("No root node")?;
96 let root_result =
97 self.evaluate_node("root", root, &mut node_results, &mut all_terminal_values)?;
98
99 let optimal_path = self.build_optimal_path(&node_results);
101
102 let decision_policy = Self::build_decision_policy(&node_results);
104
105 let risk_profile = Self::calculate_risk_profile(&all_terminal_values);
107
108 Ok(TreeResult {
109 name: self.config.name.clone(),
110 root_expected_value: root_result.expected_value,
111 node_results,
112 optimal_path,
113 decision_policy,
114 risk_profile,
115 })
116 }
117
118 fn evaluate_node(
120 &self,
121 name: &str,
122 node: &Node,
123 results: &mut HashMap<String, NodeResult>,
124 all_terminal_values: &mut Vec<(f64, f64)>, ) -> Result<NodeResult, String> {
126 let mut branch_values = HashMap::new();
127
128 for (branch_name, branch) in &node.branches {
130 let branch_value =
131 self.evaluate_branch(branch, results, all_terminal_values, node.node_type)?;
132 branch_values.insert(branch_name.clone(), branch_value);
133 }
134
135 let (expected_value, optimal_choice) = match node.node_type {
137 NodeType::Decision => {
138 let (best_name, best_value) = branch_values
141 .iter()
142 .max_by(|(name_a, a), (name_b, b)| {
143 match a.partial_cmp(b).unwrap() {
144 std::cmp::Ordering::Equal => name_b.cmp(name_a),
146 other => other,
147 }
148 })
149 .map(|(n, v)| (n.clone(), *v))
150 .ok_or("No branches in decision node")?;
151 (best_value, Some(best_name))
152 },
153 NodeType::Chance => {
154 let ev: f64 = node
156 .branches
157 .iter()
158 .map(|(branch_name, branch)| {
159 branch.probability * branch_values.get(branch_name).unwrap_or(&0.0)
160 })
161 .sum();
162 (ev, None)
163 },
164 NodeType::Terminal => {
165 (0.0, None)
167 },
168 };
169
170 let result = NodeResult {
171 name: node.name.clone(),
172 node_type: node.node_type,
173 expected_value,
174 optimal_choice,
175 branch_values,
176 };
177
178 results.insert(name.to_string(), result.clone());
179 Ok(result)
180 }
181
182 fn evaluate_branch(
184 &self,
185 branch: &Branch,
186 results: &mut HashMap<String, NodeResult>,
187 all_terminal_values: &mut Vec<(f64, f64)>,
188 parent_type: NodeType,
189 ) -> Result<f64, String> {
190 let base_value = if let Some(value) = branch.value {
191 let prob = if parent_type == NodeType::Chance {
193 branch.probability
194 } else {
195 1.0
196 };
197 all_terminal_values.push((value - branch.cost, prob));
198 value
199 } else if let Some(ref next) = branch.next {
200 let next_node = self
202 .config
203 .get_node(next)
204 .ok_or_else(|| format!("Node '{next}' not found"))?;
205 let next_result = self.evaluate_node(next, next_node, results, all_terminal_values)?;
206 next_result.expected_value
207 } else {
208 return Err("Branch has neither value nor next node".to_string());
209 };
210
211 Ok(base_value - branch.cost)
213 }
214
215 fn build_optimal_path(&self, results: &HashMap<String, NodeResult>) -> Vec<String> {
217 let mut path = Vec::new();
218
219 if let Some(root_result) = results.get("root") {
220 self.trace_optimal_path("root", root_result, results, &mut path);
221 }
222
223 path
224 }
225
226 fn trace_optimal_path(
227 &self,
228 name: &str,
229 result: &NodeResult,
230 results: &HashMap<String, NodeResult>,
231 path: &mut Vec<String>,
232 ) {
233 match result.node_type {
234 NodeType::Decision => {
235 if let Some(ref choice) = result.optimal_choice {
236 path.push(format!("{} → {}", result.name, choice));
237
238 if let Some(root) = &self.config.root {
240 if name == "root" {
241 if let Some(branch) = root.branches.get(choice) {
242 if let Some(ref next) = branch.next {
243 if let Some(next_result) = results.get(next) {
244 self.trace_optimal_path(next, next_result, results, path);
245 }
246 }
247 }
248 }
249 }
250
251 if let Some(node) = self.config.get_node(name) {
252 if let Some(branch) = node.branches.get(choice) {
253 if let Some(ref next) = branch.next {
254 if let Some(next_result) = results.get(next) {
255 self.trace_optimal_path(next, next_result, results, path);
256 }
257 }
258 }
259 }
260 }
261 },
262 NodeType::Chance => {
263 path.push(format!("{} → (await outcome)", result.name));
264 if let Some(node) = self.config.get_node(name) {
266 for (branch_name, branch) in &node.branches {
267 if let Some(ref next) = branch.next {
268 if let Some(next_result) = results.get(next) {
269 path.push(format!(" if {branch_name} →"));
270 let mut sub_path = Vec::new();
271 self.trace_optimal_path(next, next_result, results, &mut sub_path);
272 for p in sub_path {
273 path.push(format!(" {p}"));
274 }
275 }
276 }
277 }
278 }
279 },
280 NodeType::Terminal => {
281 },
283 }
284 }
285
286 fn build_decision_policy(results: &HashMap<String, NodeResult>) -> HashMap<String, String> {
288 let mut policy = HashMap::new();
289
290 for (name, result) in results {
291 if result.node_type == NodeType::Decision {
292 if let Some(ref choice) = result.optimal_choice {
293 policy.insert(name.clone(), choice.clone());
294 }
295 }
296 }
297
298 policy
299 }
300
301 fn calculate_risk_profile(terminal_values: &[(f64, f64)]) -> RiskProfile {
303 if terminal_values.is_empty() {
304 return RiskProfile {
305 best_case: 0.0,
306 worst_case: 0.0,
307 probability_positive: 0.0,
308 };
309 }
310
311 let best_case = terminal_values
312 .iter()
313 .map(|(v, _)| *v)
314 .fold(f64::NEG_INFINITY, f64::max);
315
316 let worst_case = terminal_values
317 .iter()
318 .map(|(v, _)| *v)
319 .fold(f64::INFINITY, f64::min);
320
321 let probability_positive = terminal_values
323 .iter()
324 .filter(|(v, _)| *v > 0.0)
325 .map(|(_, p)| *p)
326 .sum::<f64>()
327 .min(1.0);
328
329 RiskProfile {
330 best_case,
331 worst_case,
332 probability_positive,
333 }
334 }
335
336 #[must_use]
338 pub const fn config(&self) -> &DecisionTreeConfig {
339 &self.config
340 }
341}
342
#[cfg(test)]
// Some expected values below are exact by construction, so exact float
// comparison is intended.
#[allow(clippy::float_cmp)]
mod engine_tests {
    use super::*;

    /// Classic R&D example: invest 2M for a 60% chance of a working technology,
    /// which can then be licensed (5M) or manufactured (8M revenue, 3M cost).
    fn create_rnd_tree() -> DecisionTreeConfig {
        let invest_decision = Node::decision("Invest in R&D?")
            .with_branch(
                "invest",
                Branch::continuation("tech_outcome").with_cost(2_000_000.0),
            )
            .with_branch("dont_invest", Branch::terminal(0.0));

        let tech_outcome = Node::chance("Technology works?")
            .with_branch(
                "success",
                Branch::continuation("commercialize").with_probability(0.60),
            )
            .with_branch("failure", Branch::terminal(0.0).with_probability(0.40));

        let commercialize = Node::decision("How to commercialize?")
            .with_branch("license", Branch::terminal(5_000_000.0))
            .with_branch(
                "manufacture",
                Branch::terminal(8_000_000.0).with_cost(3_000_000.0),
            );

        DecisionTreeConfig::new("R&D Investment")
            .with_root(invest_decision)
            .with_node("tech_outcome", tech_outcome)
            .with_node("commercialize", commercialize)
    }

    /// Validates and analyses `config`, panicking on any error.
    fn solve(config: DecisionTreeConfig) -> TreeResult {
        DecisionTreeEngine::new(config).unwrap().analyze().unwrap()
    }

    #[test]
    fn test_backward_induction() {
        let result = solve(create_rnd_tree());

        let commercialize = &result.node_results["commercialize"];
        assert_eq!(commercialize.expected_value, 5_000_000.0);
        assert_eq!(commercialize.optimal_choice.as_deref(), Some("license"));

        let tech = &result.node_results["tech_outcome"];
        assert!((tech.expected_value - 3_000_000.0).abs() < 0.01);

        assert!((result.root_expected_value - 1_000_000.0).abs() < 0.01);
    }

    #[test]
    fn test_decision_policy() {
        let policy = solve(create_rnd_tree()).decision_policy;

        assert_eq!(policy.get("root").map(String::as_str), Some("invest"));
        assert_eq!(
            policy.get("commercialize").map(String::as_str),
            Some("license")
        );
    }

    #[test]
    fn test_scipy_numpy_equivalence() {
        let root_ev = solve(create_rnd_tree()).root_expected_value;

        assert!(
            (root_ev - 1_000_000.0).abs() < 0.01,
            "Root EV should be $1M, got {}",
            root_ev
        );
    }

    #[test]
    fn test_simple_coin_flip() {
        let coin = Node::chance("Flip coin")
            .with_branch("heads", Branch::terminal(100.0).with_probability(0.5))
            .with_branch("tails", Branch::terminal(0.0).with_probability(0.5));
        let root_ev = solve(DecisionTreeConfig::new("Coin Flip").with_root(coin))
            .root_expected_value;

        assert!(
            (root_ev - 50.0).abs() < 0.01,
            "Expected 50, got {}",
            root_ev
        );
    }

    #[test]
    fn test_yaml_export() {
        let yaml = solve(create_rnd_tree()).to_yaml();

        assert!(yaml.contains("root_expected_value"));
        assert!(yaml.contains("decision_policy"));
    }
}