1use super::{ChanceInfoset, Node, Player, PlayerInfoset, PlayerNum};
3use std::mem;
4
5pub(super) fn expected(
9 node: &Node,
10 chance_info: &[impl ChanceInfoset],
11 strat_info: [&[impl AsRef<[f64]>]; 2],
12) -> f64 {
13 let mut queue = vec![(node, 1.0)];
14 let mut expected = 0.0;
15 while let Some((node, reach)) = queue.pop() {
16 match node {
17 Node::Terminal(payoff) => {
18 expected += reach * payoff;
19 }
20 Node::Chance(chance) => {
21 let probs = chance_info[chance.infoset].probs();
22 for (prob, next) in probs.iter().zip(chance.outcomes.iter()) {
23 queue.push((next, prob * reach));
24 }
25 }
26 Node::Player(player) => {
27 let probs = player.num.ind(&strat_info)[player.infoset].as_ref();
28 for (prob, next) in probs.iter().zip(player.actions.iter()) {
29 if prob > &0.0 {
30 queue.push((next, prob * reach));
31 }
32 }
33 }
34 }
35 }
36 expected
37}
38
39#[derive(Default, Debug)]
40struct DeviationInfo<'a> {
41 future_nodes: usize,
42 prob_nodes: Vec<(&'a Player, f64)>,
43 max_utility: f64,
44}
45
46fn optimal_deviations<const PLAYER_ONE: bool>(
47 start: &Node,
48 chance_info: &[impl ChanceInfoset],
49 player_info: &[impl PlayerInfoset],
50 strat_info: &[impl AsRef<[f64]>],
51) -> f64 {
52 let mut infosets: Box<[_]> = player_info
53 .iter()
54 .map(|_| DeviationInfo::default())
55 .collect();
56 let mut search_queue = vec![(start, 1.0)];
57 while let Some((node, reach)) = search_queue.pop() {
58 match node {
59 Node::Terminal(_) => (),
60 Node::Chance(chance) => {
61 let probs = chance_info[chance.infoset].probs();
62 for (prob, next) in probs.iter().zip(chance.outcomes.iter()) {
63 search_queue.push((next, prob * reach));
64 }
65 }
66 Node::Player(player) => match (player.num, PLAYER_ONE) {
67 (PlayerNum::One, true) | (PlayerNum::Two, false) => {
68 infosets[player.infoset].prob_nodes.push((player, reach));
69 if let Some(prev) = player_info[player.infoset].prev_infoset() {
70 infosets[prev].future_nodes += 1;
71 }
72 for next in player.actions.iter() {
73 search_queue.push((next, reach));
74 }
75 }
76 (PlayerNum::One, false) | (PlayerNum::Two, true) => {
77 let probs = strat_info[player.infoset].as_ref();
78 for (prob, next) in probs.iter().zip(player.actions.iter()) {
79 if prob > &0.0 {
80 search_queue.push((next, prob * reach));
81 }
82 }
83 }
84 },
85 }
86 }
87
88 let mut info_queue: Vec<_> = infosets
89 .iter()
90 .enumerate()
91 .filter(|(_, dev)| dev.future_nodes == 0 && !dev.prob_nodes.is_empty())
92 .map(|(info, _)| info)
93 .collect();
94 while let Some(info) = info_queue.pop() {
95 let nodes = mem::take(&mut infosets[info].prob_nodes);
97 let total_reach: f64 = nodes.iter().map(|(_, p)| p).sum();
98
99 if let Some(prev) = player_info[info].prev_infoset() {
101 let futs = &mut infosets[prev].future_nodes;
102 *futs -= nodes.len();
103 if futs == &mut 0 {
105 info_queue.push(prev);
106 }
107 }
108
109 let mut payoffs = vec![0.0; player_info[info].num_actions()];
111 for (player, prob) in nodes {
112 for (next, res) in player.actions.iter().zip(payoffs.iter_mut()) {
113 *res += next_infoset_search::<PLAYER_ONE>(
114 next,
115 &mut search_queue,
116 &infosets,
117 chance_info,
118 strat_info,
119 ) * prob;
120 }
121 }
122
123 infosets[info].max_utility = payoffs.into_iter().reduce(f64::max).unwrap() / total_reach;
125 }
126 next_infoset_search::<PLAYER_ONE>(start, &mut search_queue, &infosets, chance_info, strat_info)
127}
128
129fn next_infoset_search<'a, const PLAYER_ONE: bool>(
130 start: &'a Node,
131 search_queue: &mut Vec<(&'a Node, f64)>,
132 infosets: &[DeviationInfo],
133 chance_info: &[impl ChanceInfoset],
134 strat_info: &[impl AsRef<[f64]>],
135) -> f64 {
136 let mut res = 0.0;
137 search_queue.push((start, 1.0));
138 while let Some((node, reach)) = search_queue.pop() {
139 match node {
140 Node::Terminal(payoff) => {
141 if PLAYER_ONE {
142 res += payoff * reach;
143 } else {
144 res -= payoff * reach;
145 }
146 }
147 Node::Chance(chance) => {
148 let probs = chance_info[chance.infoset].probs();
149 for (prob, next) in probs.iter().zip(chance.outcomes.iter()) {
150 search_queue.push((next, prob * reach));
151 }
152 }
153 Node::Player(player) => match (player.num, PLAYER_ONE) {
154 (PlayerNum::One, true) | (PlayerNum::Two, false) => {
155 let info = &infosets[player.infoset];
156 debug_assert_eq!(info.prob_nodes.len(), 0);
157 debug_assert_eq!(info.future_nodes, 0);
158 res += info.max_utility * reach;
159 }
160 (PlayerNum::One, false) | (PlayerNum::Two, true) => {
161 let probs = strat_info[player.infoset].as_ref();
162 for (prob, next) in probs.iter().zip(player.actions.iter()) {
163 if prob > &0.0 {
164 search_queue.push((next, prob * reach));
165 }
166 }
167 }
168 },
169 }
170 }
171 res
172}
173
174pub(super) fn regret(
175 start: &Node,
176 chance_info: &[impl ChanceInfoset],
177 player_info: [&[impl PlayerInfoset]; 2],
178 strat_info: [&[impl AsRef<[f64]>]; 2],
179) -> (f64, [f64; 2]) {
180 let expected = expected(start, chance_info, strat_info);
181 let one = optimal_deviations::<true>(start, chance_info, player_info[0], strat_info[1]);
182 let two = optimal_deviations::<false>(start, chance_info, player_info[1], strat_info[0]);
183 (
184 expected,
185 [f64::max(one - expected, 0.0), f64::max(two + expected, 0.0)],
186 )
187}