// mollendorff_forge/decision_trees/engine.rs

//! Decision Tree Engine
//!
//! Executes backward induction to find optimal decisions and expected values.
4
5use super::config::{Branch, DecisionTreeConfig, Node, NodeType};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9/// Result for a single node
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct NodeResult {
12    /// Node name
13    pub name: String,
14    /// Node type
15    pub node_type: NodeType,
16    /// Expected value at this node
17    pub expected_value: f64,
18    /// Optimal choice (for decision nodes)
19    pub optimal_choice: Option<String>,
20    /// Branch values
21    pub branch_values: HashMap<String, f64>,
22}
23
24/// Complete tree analysis result
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct TreeResult {
27    /// Tree name
28    pub name: String,
29    /// Expected value at root
30    pub root_expected_value: f64,
31    /// Node results
32    pub node_results: HashMap<String, NodeResult>,
33    /// Optimal decision path
34    pub optimal_path: Vec<String>,
35    /// Decision policy (what to choose at each decision node)
36    pub decision_policy: HashMap<String, String>,
37    /// Risk profile
38    pub risk_profile: RiskProfile,
39}
40
41/// Risk profile showing outcome distribution
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct RiskProfile {
44    /// Best possible outcome
45    pub best_case: f64,
46    /// Worst possible outcome
47    pub worst_case: f64,
48    /// Probability of positive outcome
49    pub probability_positive: f64,
50}
51
52impl TreeResult {
53    /// Export results to YAML format
54    #[must_use]
55    pub fn to_yaml(&self) -> String {
56        serde_yaml_ng::to_string(self).unwrap_or_else(|_| "# Error serializing results".to_string())
57    }
58
59    /// Export results to JSON format
60    ///
61    /// # Errors
62    ///
63    /// Returns an error if JSON serialization fails.
64    pub fn to_json(&self) -> Result<String, serde_json::Error> {
65        serde_json::to_string_pretty(self)
66    }
67}
68
69/// Decision Tree Engine
70pub struct DecisionTreeEngine {
71    config: DecisionTreeConfig,
72}
73
74impl DecisionTreeEngine {
75    /// Create a new decision tree engine
76    ///
77    /// # Errors
78    ///
79    /// Returns an error if the configuration is invalid.
80    pub fn new(config: DecisionTreeConfig) -> Result<Self, String> {
81        config.validate()?;
82        Ok(Self { config })
83    }
84
85    /// Analyze the decision tree using backward induction
86    ///
87    /// # Errors
88    ///
89    /// Returns an error if the tree has no root or a branch references a missing node.
90    pub fn analyze(&self) -> Result<TreeResult, String> {
91        let mut node_results = HashMap::new();
92        let mut all_terminal_values = Vec::new();
93
94        // Start backward induction from root
95        let root = self.config.root.as_ref().ok_or("No root node")?;
96        let root_result =
97            self.evaluate_node("root", root, &mut node_results, &mut all_terminal_values)?;
98
99        // Build optimal path
100        let optimal_path = self.build_optimal_path(&node_results);
101
102        // Build decision policy
103        let decision_policy = Self::build_decision_policy(&node_results);
104
105        // Calculate risk profile
106        let risk_profile = Self::calculate_risk_profile(&all_terminal_values);
107
108        Ok(TreeResult {
109            name: self.config.name.clone(),
110            root_expected_value: root_result.expected_value,
111            node_results,
112            optimal_path,
113            decision_policy,
114            risk_profile,
115        })
116    }
117
118    /// Evaluate a node recursively using backward induction
119    fn evaluate_node(
120        &self,
121        name: &str,
122        node: &Node,
123        results: &mut HashMap<String, NodeResult>,
124        all_terminal_values: &mut Vec<(f64, f64)>, // (value, probability)
125    ) -> Result<NodeResult, String> {
126        let mut branch_values = HashMap::new();
127
128        // Evaluate each branch
129        for (branch_name, branch) in &node.branches {
130            let branch_value =
131                self.evaluate_branch(branch, results, all_terminal_values, node.node_type)?;
132            branch_values.insert(branch_name.clone(), branch_value);
133        }
134
135        // Calculate expected value based on node type
136        let (expected_value, optimal_choice) = match node.node_type {
137            NodeType::Decision => {
138                // Decision node: choose maximum value branch
139                // Use alphabetical ordering as tie-breaker for deterministic results
140                let (best_name, best_value) = branch_values
141                    .iter()
142                    .max_by(|(name_a, a), (name_b, b)| {
143                        match a.partial_cmp(b).unwrap() {
144                            // When values are equal, prefer earlier alphabetically
145                            std::cmp::Ordering::Equal => name_b.cmp(name_a),
146                            other => other,
147                        }
148                    })
149                    .map(|(n, v)| (n.clone(), *v))
150                    .ok_or("No branches in decision node")?;
151                (best_value, Some(best_name))
152            },
153            NodeType::Chance => {
154                // Chance node: probability-weighted expected value
155                let ev: f64 = node
156                    .branches
157                    .iter()
158                    .map(|(branch_name, branch)| {
159                        branch.probability * branch_values.get(branch_name).unwrap_or(&0.0)
160                    })
161                    .sum();
162                (ev, None)
163            },
164            NodeType::Terminal => {
165                // Terminal nodes shouldn't have branches in typical usage
166                (0.0, None)
167            },
168        };
169
170        let result = NodeResult {
171            name: node.name.clone(),
172            node_type: node.node_type,
173            expected_value,
174            optimal_choice,
175            branch_values,
176        };
177
178        results.insert(name.to_string(), result.clone());
179        Ok(result)
180    }
181
182    /// Evaluate a branch value
183    fn evaluate_branch(
184        &self,
185        branch: &Branch,
186        results: &mut HashMap<String, NodeResult>,
187        all_terminal_values: &mut Vec<(f64, f64)>,
188        parent_type: NodeType,
189    ) -> Result<f64, String> {
190        let base_value = if let Some(value) = branch.value {
191            // Terminal branch - track for risk profile
192            let prob = if parent_type == NodeType::Chance {
193                branch.probability
194            } else {
195                1.0
196            };
197            all_terminal_values.push((value - branch.cost, prob));
198            value
199        } else if let Some(ref next) = branch.next {
200            // Continuation branch - recurse
201            let next_node = self
202                .config
203                .get_node(next)
204                .ok_or_else(|| format!("Node '{next}' not found"))?;
205            let next_result = self.evaluate_node(next, next_node, results, all_terminal_values)?;
206            next_result.expected_value
207        } else {
208            return Err("Branch has neither value nor next node".to_string());
209        };
210
211        // Subtract cost (for decision branches)
212        Ok(base_value - branch.cost)
213    }
214
215    /// Build the optimal decision path
216    fn build_optimal_path(&self, results: &HashMap<String, NodeResult>) -> Vec<String> {
217        let mut path = Vec::new();
218
219        if let Some(root_result) = results.get("root") {
220            self.trace_optimal_path("root", root_result, results, &mut path);
221        }
222
223        path
224    }
225
226    fn trace_optimal_path(
227        &self,
228        name: &str,
229        result: &NodeResult,
230        results: &HashMap<String, NodeResult>,
231        path: &mut Vec<String>,
232    ) {
233        match result.node_type {
234            NodeType::Decision => {
235                if let Some(ref choice) = result.optimal_choice {
236                    path.push(format!("{} → {}", result.name, choice));
237
238                    // Follow the chosen branch
239                    if let Some(root) = &self.config.root {
240                        if name == "root" {
241                            if let Some(branch) = root.branches.get(choice) {
242                                if let Some(ref next) = branch.next {
243                                    if let Some(next_result) = results.get(next) {
244                                        self.trace_optimal_path(next, next_result, results, path);
245                                    }
246                                }
247                            }
248                        }
249                    }
250
251                    if let Some(node) = self.config.get_node(name) {
252                        if let Some(branch) = node.branches.get(choice) {
253                            if let Some(ref next) = branch.next {
254                                if let Some(next_result) = results.get(next) {
255                                    self.trace_optimal_path(next, next_result, results, path);
256                                }
257                            }
258                        }
259                    }
260                }
261            },
262            NodeType::Chance => {
263                path.push(format!("{} → (await outcome)", result.name));
264                // For chance nodes, show all branches lead to
265                if let Some(node) = self.config.get_node(name) {
266                    for (branch_name, branch) in &node.branches {
267                        if let Some(ref next) = branch.next {
268                            if let Some(next_result) = results.get(next) {
269                                path.push(format!("  if {branch_name} →"));
270                                let mut sub_path = Vec::new();
271                                self.trace_optimal_path(next, next_result, results, &mut sub_path);
272                                for p in sub_path {
273                                    path.push(format!("    {p}"));
274                                }
275                            }
276                        }
277                    }
278                }
279            },
280            NodeType::Terminal => {
281                // End of path
282            },
283        }
284    }
285
286    /// Build decision policy
287    fn build_decision_policy(results: &HashMap<String, NodeResult>) -> HashMap<String, String> {
288        let mut policy = HashMap::new();
289
290        for (name, result) in results {
291            if result.node_type == NodeType::Decision {
292                if let Some(ref choice) = result.optimal_choice {
293                    policy.insert(name.clone(), choice.clone());
294                }
295            }
296        }
297
298        policy
299    }
300
301    /// Calculate risk profile from terminal values
302    fn calculate_risk_profile(terminal_values: &[(f64, f64)]) -> RiskProfile {
303        if terminal_values.is_empty() {
304            return RiskProfile {
305                best_case: 0.0,
306                worst_case: 0.0,
307                probability_positive: 0.0,
308            };
309        }
310
311        let best_case = terminal_values
312            .iter()
313            .map(|(v, _)| *v)
314            .fold(f64::NEG_INFINITY, f64::max);
315
316        let worst_case = terminal_values
317            .iter()
318            .map(|(v, _)| *v)
319            .fold(f64::INFINITY, f64::min);
320
321        // This is simplified - actual calculation would need path probabilities
322        let probability_positive = terminal_values
323            .iter()
324            .filter(|(v, _)| *v > 0.0)
325            .map(|(_, p)| *p)
326            .sum::<f64>()
327            .min(1.0);
328
329        RiskProfile {
330            best_case,
331            worst_case,
332            probability_positive,
333        }
334    }
335
336    /// Get the configuration
337    #[must_use]
338    pub const fn config(&self) -> &DecisionTreeConfig {
339        &self.config
340    }
341}
342
#[cfg(test)]
// Financial math: exact float comparison validated against Excel/Gnumeric/R
#[allow(clippy::float_cmp)]
mod engine_tests {
    use super::*;

    /// Absolute tolerance used when comparing expected values.
    const TOLERANCE: f64 = 0.01;

    /// True when `actual` is within `TOLERANCE` of `expected`.
    fn is_close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// Three-stage R&D example: invest or not, technology outcome, then how
    /// to commercialize.
    fn create_rnd_tree() -> DecisionTreeConfig {
        let invest_decision = Node::decision("Invest in R&D?")
            .with_branch(
                "invest",
                Branch::continuation("tech_outcome").with_cost(2_000_000.0),
            )
            .with_branch("dont_invest", Branch::terminal(0.0));

        let tech_outcome = Node::chance("Technology works?")
            .with_branch(
                "success",
                Branch::continuation("commercialize").with_probability(0.60),
            )
            .with_branch("failure", Branch::terminal(0.0).with_probability(0.40));

        let commercialize = Node::decision("How to commercialize?")
            .with_branch("license", Branch::terminal(5_000_000.0))
            .with_branch(
                "manufacture",
                Branch::terminal(8_000_000.0).with_cost(3_000_000.0),
            );

        DecisionTreeConfig::new("R&D Investment")
            .with_root(invest_decision)
            .with_node("tech_outcome", tech_outcome)
            .with_node("commercialize", commercialize)
    }

    /// Build the R&D tree and run the full analysis on it.
    fn analyze_rnd_tree() -> TreeResult {
        DecisionTreeEngine::new(create_rnd_tree())
            .unwrap()
            .analyze()
            .unwrap()
    }

    #[test]
    fn test_backward_induction() {
        let outcome = analyze_rnd_tree();

        // commercialize: max($5M, $8M - $3M) = $5M; the tie goes to "license"
        let commercialize = outcome.node_results.get("commercialize").unwrap();
        assert_eq!(commercialize.expected_value, 5_000_000.0);
        assert_eq!(commercialize.optimal_choice, Some("license".to_string()));

        // tech_outcome: 0.6 × $5M + 0.4 × $0 = $3M
        let tech = outcome.node_results.get("tech_outcome").unwrap();
        assert!(is_close(tech.expected_value, 3_000_000.0));

        // root: max($3M - $2M, $0) = $1M (invest)
        assert!(is_close(outcome.root_expected_value, 1_000_000.0));
    }

    #[test]
    fn test_decision_policy() {
        let policy = analyze_rnd_tree().decision_policy;

        assert_eq!(policy.get("root"), Some(&"invest".to_string()));
        assert_eq!(policy.get("commercialize"), Some(&"license".to_string()));
    }

    /// Roundtrip validation - matches SciPy/NumPy backward induction
    #[test]
    fn test_scipy_numpy_equivalence() {
        // Reference calculation in Python:
        //   license_value = 5_000_000
        //   manufacture_value = 8_000_000 - 3_000_000  # net of cost
        //   commercialize_ev = max(license_value, manufacture_value)  # $5,000,000
        //
        //   p_success, p_failure = 0.60, 0.40
        //   failure_value = 0
        //   tech_ev = p_success * commercialize_ev + p_failure * failure_value  # $3,000,000
        //
        //   invest_cost = 2_000_000
        //   invest_ev = tech_ev - invest_cost  # $1,000,000
        //   no_invest_ev = 0
        //   root_ev = max(invest_ev, no_invest_ev)  # $1,000,000

        let outcome = analyze_rnd_tree();

        // Must agree with the Python result above.
        assert!(
            is_close(outcome.root_expected_value, 1_000_000.0),
            "Root EV should be $1M, got {}",
            outcome.root_expected_value
        );
    }

    #[test]
    fn test_simple_coin_flip() {
        let coin = Node::chance("Flip coin")
            .with_branch("heads", Branch::terminal(100.0).with_probability(0.5))
            .with_branch("tails", Branch::terminal(0.0).with_probability(0.5));
        let config = DecisionTreeConfig::new("Coin Flip").with_root(coin);

        let outcome = DecisionTreeEngine::new(config).unwrap().analyze().unwrap();

        // EV = 0.5 * 100 + 0.5 * 0 = 50
        assert!(
            is_close(outcome.root_expected_value, 50.0),
            "Expected 50, got {}",
            outcome.root_expected_value
        );
    }

    #[test]
    fn test_yaml_export() {
        let yaml = analyze_rnd_tree().to_yaml();

        assert!(yaml.contains("root_expected_value"));
        assert!(yaml.contains("decision_policy"));
    }
}