extensive_form/
backward.rs1use crate::node::{NodeId, NodeType};
4use crate::tree::GameTree;
5use std::collections::HashMap;
6
7pub type Strategy = HashMap<NodeId, usize>;
9
10#[derive(Clone, Debug)]
12pub struct BackwardResult {
13 pub strategy: Strategy,
15 pub values: HashMap<NodeId, Vec<f64>>,
17}
18
19pub fn backward_induction(tree: &GameTree) -> BackwardResult {
23 let mut values: HashMap<NodeId, Vec<f64>> = HashMap::new();
24 let mut strategy: Strategy = HashMap::new();
25
26 let root = tree.root.expect("Game tree must have a root");
28
29 solve_node(tree, root, &mut values, &mut strategy);
31
32 BackwardResult { strategy, values }
33}
34
35fn solve_node(
36 tree: &GameTree,
37 node_id: NodeId,
38 values: &mut HashMap<NodeId, Vec<f64>>,
39 strategy: &mut Strategy,
40) -> Vec<f64> {
41 if let Some(v) = values.get(&node_id) {
43 return v.clone();
44 }
45
46 let node = tree.get_node(node_id).expect("Node must exist").clone();
47
48 let result = match &node.node_type {
49 NodeType::Terminal { payoffs } => payoffs.clone(),
50 NodeType::Decision { player } => {
51 let p = *player;
52 let mut best_action = 0;
53 let mut best_value = Vec::new();
54 let mut best_player_val = f64::NEG_INFINITY;
55
56 for (i, &child_id) in node.children.iter().enumerate() {
57 let child_val = solve_node(tree, child_id, values, strategy);
58 if child_val[p] > best_player_val {
59 best_player_val = child_val[p];
60 best_value = child_val;
61 best_action = i;
62 }
63 }
64
65 strategy.insert(node_id, best_action);
66 best_value
67 }
68 NodeType::Chance { probabilities } => {
69 let n_players = tree.num_players;
70 let mut expected = vec![0.0; n_players];
71
72 for (i, &prob) in probabilities.iter().enumerate() {
73 if i < node.children.len() {
74 let child_val = solve_node(tree, node.children[i], values, strategy);
75 for (p, ev) in expected.iter_mut().enumerate() {
76 if p < child_val.len() {
77 *ev += prob * child_val[p];
78 }
79 }
80 }
81 }
82
83 expected
84 }
85 };
86
87 values.insert(node_id, result.clone());
88 result
89}
90
91pub fn equilibrium_path(tree: &GameTree, result: &BackwardResult) -> Vec<NodeId> {
93 let mut path = Vec::new();
94 let mut current = tree.root;
95
96 while let Some(node_id) = current {
97 path.push(node_id);
98 let node = tree.get_node(node_id).unwrap();
99 if node.is_terminal() {
100 break;
101 }
102 match result.strategy.get(&node_id) {
103 Some(&action_idx) => {
104 current = node.children.get(action_idx).copied();
105 }
106 None => {
107 current = node.children.first().copied();
109 }
110 }
111 }
112
113 path
114}
115
116pub fn equilibrium_actions(tree: &GameTree, result: &BackwardResult) -> Vec<String> {
118 let path = equilibrium_path(tree, result);
119 let mut actions = Vec::new();
120 for &node_id in &path[1..] {
121 if let Some(node) = tree.get_node(node_id) {
122 if let Some(ref action) = node.incoming_action {
123 actions.push(action.clone());
124 }
125 }
126 }
127 actions
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tree::GameTree;

    /// Builds a one-move game in which player 0 chooses `L` or `R` at the root.
    /// Returns `(tree, root, left_leaf, right_leaf)`.
    fn single_choice(
        name: &'static str,
        root_label: &'static str,
        left_payoffs: Vec<f64>,
        right_payoffs: Vec<f64>,
    ) -> (GameTree, NodeId, NodeId, NodeId) {
        let mut tree = GameTree::new(name);
        let root = tree.add_decision(0, root_label);
        let left = tree.add_terminal(left_payoffs, "L");
        let right = tree.add_terminal(right_payoffs, "R");
        tree.add_action(root, "L", left);
        tree.add_action(root, "R", right);
        (tree, root, left, right)
    }

    /// Asserts that two payoffs agree up to floating-point noise.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-10,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_simple_backward() {
        let (tree, root, _, _) =
            single_choice("Simple", "Choose", vec![1.0, 3.0], vec![2.0, 1.0]);
        let result = backward_induction(&tree);

        // Player 0 earns 2.0 from R versus 1.0 from L.
        assert_eq!(result.strategy[&root], 1);
        assert_close(result.values[&root][0], 2.0);
    }

    #[test]
    fn test_ultimatum_backward() {
        let tree = GameTree::ultimatum_game(10.0);
        let result = backward_induction(&tree);
        let root = tree.root.unwrap();

        assert!(result.values[&root][0] > 0.0);
    }

    #[test]
    fn test_centipede_backward() {
        let tree = GameTree::centipede_game(3, 1.0, 1.0);
        let result = backward_induction(&tree);
        let root = tree.root.unwrap();

        assert_eq!(result.strategy[&root], 0);
    }

    #[test]
    fn test_equilibrium_path() {
        let (tree, root, left, _) =
            single_choice("Path", "Root", vec![5.0, 0.0], vec![0.0, 5.0]);
        let result = backward_induction(&tree);

        assert_eq!(equilibrium_path(&tree, &result), vec![root, left]);
    }

    #[test]
    fn test_multi_level_backward() {
        let mut tree = GameTree::new("Two Level");
        let root = tree.add_decision(0, "P0");
        let left = tree.add_decision(1, "P1-Left");
        let right = tree.add_decision(1, "P1-Right");
        let ll = tree.add_terminal(vec![3.0, 1.0], "LL");
        let lr = tree.add_terminal(vec![0.0, 2.0], "LR");
        let rl = tree.add_terminal(vec![2.0, 2.0], "RL");
        let rr = tree.add_terminal(vec![1.0, 3.0], "RR");

        for &(from, label, to) in &[
            (root, "L", left),
            (root, "R", right),
            (left, "L", ll),
            (left, "R", lr),
            (right, "L", rl),
            (right, "R", rr),
        ] {
            tree.add_action(from, label, to);
        }

        // Player 1 answers L with LR (0, 2) and R with RR (1, 3), so player 0
        // prefers R.
        let result = backward_induction(&tree);
        assert_eq!(result.strategy[&root], 1);
    }
}