use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Kind of a node in a decision tree.
///
/// Variants are (de)serialized using their lowercase names
/// (`"decision"`, `"chance"`, `"terminal"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum NodeType {
/// Node kind produced by `Node::decision`.
/// NOTE(review): presumably a point where the decision maker picks one branch — confirm against the evaluator.
Decision,
/// Node kind produced by `Node::chance`; `Node::validate` requires the
/// probabilities of its branches to sum to 1.0.
Chance,
/// NOTE(review): no constructor or variant-specific handling for this kind is visible in this file —
/// confirm how terminal nodes are meant to be used.
Terminal,
}
/// One named outgoing edge of a `Node` (stored in `Node::branches`).
///
/// A branch can carry a payoff (`value`), point to a follow-up node (`next`),
/// or both; `Branch::is_terminal` reports the "payoff only" case.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Branch {
/// Cost attached to this branch; falls back to `0.0` when absent from the serialized form.
/// NOTE(review): how the cost enters the evaluation is not shown in this file.
#[serde(default)]
pub cost: f64,
/// Probability of this branch; falls back to `0.0` when absent from the serialized form.
/// `Node::validate` sums this field over the branches of chance nodes.
#[serde(default)]
pub probability: f64,
/// Payoff of the branch, if it has one.
pub value: Option<f64>,
/// Key of the follow-up node in `DecisionTreeConfig::nodes`, if the branch continues.
pub next: Option<String>,
}
impl Branch {
    /// Shared constructor: a branch with zero cost and zero probability
    /// carrying the given payoff and follow-up reference.
    const fn bare(value: Option<f64>, next: Option<String>) -> Self {
        Self {
            cost: 0.0,
            probability: 0.0,
            value,
            next,
        }
    }

    /// Creates a branch that ends with the payoff `value` and has no
    /// follow-up node. Cost and probability start at `0.0`.
    #[must_use]
    pub const fn terminal(value: f64) -> Self {
        Self::bare(Some(value), None)
    }

    /// Creates a branch without a payoff that leads to the node keyed
    /// `next`. Cost and probability start at `0.0`.
    #[must_use]
    pub fn continuation(next: &str) -> Self {
        Self::bare(None, Some(String::from(next)))
    }

    /// Returns the branch with its cost replaced by `cost`.
    #[must_use]
    pub const fn with_cost(mut self, cost: f64) -> Self {
        self.cost = cost;
        self
    }

    /// Returns the branch with its probability replaced by `probability`.
    #[must_use]
    pub const fn with_probability(mut self, probability: f64) -> Self {
        self.probability = probability;
        self
    }

    /// Returns `true` only when the branch has a payoff and no follow-up
    /// node; a branch with both, or with neither, is not terminal.
    #[must_use]
    pub const fn is_terminal(&self) -> bool {
        // Any follow-up reference disqualifies the branch outright.
        if self.next.is_some() {
            return false;
        }
        self.value.is_some()
    }
}
/// A node of the decision tree together with its named outgoing branches.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Node {
/// Kind of the node; appears under the key `type` in the serialized form.
#[serde(rename = "type")]
pub node_type: NodeType,
/// Display name used in validation error messages; falls back to an empty
/// string when absent from the serialized form.
#[serde(default)]
pub name: String,
/// Outgoing branches keyed by branch name.
pub branches: HashMap<String, Branch>,
}
impl Node {
    /// Creates a decision node called `name` with no branches yet.
    #[must_use]
    pub fn decision(name: &str) -> Self {
        Self {
            node_type: NodeType::Decision,
            name: name.to_string(),
            branches: HashMap::new(),
        }
    }

    /// Creates a chance node called `name` with no branches yet.
    #[must_use]
    pub fn chance(name: &str) -> Self {
        Self {
            node_type: NodeType::Chance,
            name: name.to_string(),
            branches: HashMap::new(),
        }
    }

    /// Returns the node with `branch` stored under `name`, replacing any
    /// branch previously stored under that name.
    #[must_use]
    pub fn with_branch(mut self, name: &str, branch: Branch) -> Self {
        self.branches.insert(name.to_string(), branch);
        self
    }

    /// Checks the node in isolation.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when:
    /// - the node has no branches;
    /// - the node is a chance node and its branch probabilities do not sum
    ///   to 1.0 (within a tolerance of 0.001);
    /// - the node is a chance node and any single branch probability is NaN
    ///   or lies outside `0.0..=1.0` (within the same tolerance).
    pub fn validate(&self) -> Result<(), String> {
        const TOLERANCE: f64 = 0.001;

        if self.branches.is_empty() {
            return Err(format!("Node '{}' has no branches", self.name));
        }
        // Probabilities are only meaningful on chance nodes.
        if self.node_type != NodeType::Chance {
            return Ok(());
        }

        let total_prob: f64 = self.branches.values().map(|b| b.probability).sum();
        if (total_prob - 1.0).abs() > TOLERANCE {
            return Err(format!(
                "Chance node '{}' probabilities must sum to 1.0, got {:.4}",
                self.name, total_prob
            ));
        }

        // The sum check alone is not enough: a NaN total compares false
        // against the tolerance, and values such as 1.5 and -0.5 add up to
        // exactly 1.0. `contains` is false for NaN, so it is caught here too.
        // The smallest offending branch name is reported so the message does
        // not depend on hash-map iteration order.
        let allowed = -TOLERANCE..=1.0 + TOLERANCE;
        let offender = self
            .branches
            .iter()
            .filter(|(_, branch)| !allowed.contains(&branch.probability))
            .min_by_key(|(branch_name, _)| *branch_name);
        if let Some((branch_name, branch)) = offender {
            return Err(format!(
                "Chance node '{}' branch '{}' has invalid probability {} (must be between 0.0 and 1.0)",
                self.name, branch_name, branch.probability
            ));
        }

        Ok(())
    }
}
/// Complete description of a decision tree: a root node plus the named nodes
/// that branches can continue to.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DecisionTreeConfig {
/// Name of the tree; falls back to an empty string when absent from the serialized form.
#[serde(default)]
pub name: String,
/// Entry node of the tree; `validate` fails when it is `None`.
pub root: Option<Node>,
/// Non-root nodes keyed by the name that `Branch::next` refers to; falls back
/// to an empty map when absent from the serialized form.
#[serde(default)]
pub nodes: HashMap<String, Node>,
}
impl DecisionTreeConfig {
    /// Creates an empty configuration called `name` with no root and no nodes.
    #[must_use]
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            root: None,
            nodes: HashMap::new(),
        }
    }

    /// Returns the configuration with `root` as its root node.
    #[must_use]
    pub fn with_root(mut self, root: Node) -> Self {
        self.root = Some(root);
        self
    }

    /// Returns the configuration with `node` stored under `name`, replacing
    /// any node previously stored under that name.
    #[must_use]
    pub fn with_node(mut self, name: &str, node: Node) -> Self {
        self.nodes.insert(name.to_string(), node);
        self
    }

    /// Validates the whole tree.
    ///
    /// The root and every entry of `nodes` are validated individually, every
    /// `Branch::next` must name an entry of `nodes`, and the part of the
    /// graph reachable from the root must be free of cycles.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message describing the first problem found.
    pub fn validate(&self) -> Result<(), String> {
        let root = self.root.as_ref().ok_or("No root node defined")?;
        root.validate()?;
        self.validate_references(root)?;
        for (name, node) in &self.nodes {
            node.validate().map_err(|e| format!("Node '{name}': {e}"))?;
            self.validate_references(node)?;
        }
        self.check_cycles()
    }

    /// Ensures every follow-up reference of `node` names an entry of `nodes`.
    fn validate_references(&self, node: &Node) -> Result<(), String> {
        for (branch_name, branch) in &node.branches {
            if let Some(next) = branch.next.as_deref() {
                if !self.nodes.contains_key(next) {
                    return Err(format!(
                        "Branch '{branch_name}' references non-existent node '{next}'"
                    ));
                }
            }
        }
        Ok(())
    }

    /// Depth-first search for cycles among the nodes reachable from the root.
    ///
    /// The root itself is never registered in the bookkeeping sets: branches
    /// can only refer to keys of `nodes`, so the root cannot be part of a
    /// cycle, and giving it a synthetic name would clash with a legitimate
    /// node stored under that same key.
    fn check_cycles(&self) -> Result<(), String> {
        let mut visited = std::collections::HashSet::new();
        let mut stack = std::collections::HashSet::new();
        if let Some(root) = self.root.as_ref() {
            self.visit_successors(root, &mut visited, &mut stack)?;
        }
        Ok(())
    }

    /// Runs the cycle check on every existing node that `node` branches to.
    fn visit_successors<'a>(
        &'a self,
        node: &'a Node,
        visited: &mut std::collections::HashSet<&'a str>,
        stack: &mut std::collections::HashSet<&'a str>,
    ) -> Result<(), String> {
        for next in node.branches.values().filter_map(|b| b.next.as_deref()) {
            // Dangling references are reported by `validate_references`.
            if let Some(next_node) = self.nodes.get(next) {
                self.dfs_cycle_check(next, next_node, visited, stack)?;
            }
        }
        Ok(())
    }

    /// Visits the node stored under `name`.
    ///
    /// `stack` holds the keys on the current DFS path (meeting one again
    /// means a cycle); `visited` holds keys already fully explored or in
    /// progress, so shared sub-trees are only walked once.
    fn dfs_cycle_check<'a>(
        &'a self,
        name: &'a str,
        node: &'a Node,
        visited: &mut std::collections::HashSet<&'a str>,
        stack: &mut std::collections::HashSet<&'a str>,
    ) -> Result<(), String> {
        // Must be tested before `visited`: everything on the stack is also visited.
        if stack.contains(name) {
            return Err(format!("Cycle detected involving node '{name}'"));
        }
        if !visited.insert(name) {
            return Ok(());
        }
        stack.insert(name);
        self.visit_successors(node, visited, stack)?;
        stack.remove(name);
        Ok(())
    }

    /// Looks up the non-root node stored under `name`.
    #[must_use]
    pub fn get_node(&self, name: &str) -> Option<&Node> {
        self.nodes.get(name)
    }
}
#[cfg(test)]
mod config_tests {
    use super::*;

    /// Builds the three-node R&D investment tree used as the valid fixture:
    /// a root decision, a chance node for the technology outcome, and a
    /// decision on how to commercialize.
    fn create_rnd_tree() -> DecisionTreeConfig {
        let invest_decision = Node::decision("Invest in R&D?")
            .with_branch(
                "invest",
                Branch::continuation("tech_outcome").with_cost(2_000_000.0),
            )
            .with_branch("dont_invest", Branch::terminal(0.0));

        let tech_outcome = Node::chance("Technology works?")
            .with_branch(
                "success",
                Branch::continuation("commercialize").with_probability(0.60),
            )
            .with_branch(
                "failure",
                Branch::terminal(-2_000_000.0).with_probability(0.40),
            );

        let commercialize = Node::decision("How to commercialize?")
            .with_branch("license", Branch::terminal(5_000_000.0))
            .with_branch(
                "manufacture",
                Branch::terminal(8_000_000.0).with_cost(3_000_000.0),
            );

        DecisionTreeConfig::new("R&D Investment")
            .with_root(invest_decision)
            .with_node("tech_outcome", tech_outcome)
            .with_node("commercialize", commercialize)
    }

    /// Validates `tree`, asserts that validation failed, and hands back the
    /// error message for further inspection.
    fn rejection_message(tree: &DecisionTreeConfig) -> String {
        let outcome = tree.validate();
        assert!(outcome.is_err());
        outcome.unwrap_err()
    }

    /// The well-formed fixture passes validation.
    #[test]
    fn test_tree_config_validation() {
        assert!(create_rnd_tree().validate().is_ok());
    }

    /// A configuration without a root node is rejected.
    #[test]
    fn test_missing_root_rejected() {
        let rootless = DecisionTreeConfig::new("Empty");
        assert!(rejection_message(&rootless).contains("No root node"));
    }

    /// A branch pointing at an unknown node key is rejected.
    #[test]
    fn test_invalid_reference_rejected() {
        let start = Node::decision("Start").with_branch("go", Branch::continuation("nonexistent"));
        let dangling = DecisionTreeConfig::new("Bad Ref").with_root(start);
        assert!(rejection_message(&dangling).contains("non-existent node"));
    }

    /// Chance-node probabilities adding up to 0.8 are rejected.
    #[test]
    fn test_chance_probabilities_must_sum_to_one() {
        let coin = Node::chance("Coin flip")
            .with_branch("heads", Branch::terminal(100.0).with_probability(0.5))
            .with_branch("tails", Branch::terminal(0.0).with_probability(0.3));
        let unbalanced = DecisionTreeConfig::new("Bad Probs").with_root(coin);
        assert!(rejection_message(&unbalanced).contains("sum to 1.0"));
    }

    /// A node whose branch leads back to itself is reported as a cycle.
    #[test]
    fn test_cycle_detection() {
        let entry = Node::decision("A").with_branch("go", Branch::continuation("b"));
        let looping = Node::decision("B").with_branch("back", Branch::continuation("b"));
        let cyclic = DecisionTreeConfig::new("Cycle")
            .with_root(entry)
            .with_node("b", looping);
        assert!(rejection_message(&cyclic).contains("Cycle"));
    }
}