// cfr/regret.rs

//! Private module for computing the expected value of a strategy profile and each player's regret.
2use super::{ChanceInfoset, Node, Player, PlayerInfoset, PlayerNum};
3use std::mem;

// NOTE: Some of these functions could be written to use thread pools, but it's not clear that this
// is a large bottleneck, so it's not worth the complexity.

8pub(super) fn expected(
9    node: &Node,
10    chance_info: &[impl ChanceInfoset],
11    strat_info: [&[impl AsRef<[f64]>]; 2],
12) -> f64 {
13    let mut queue = vec![(node, 1.0)];
14    let mut expected = 0.0;
15    while let Some((node, reach)) = queue.pop() {
16        match node {
17            Node::Terminal(payoff) => {
18                expected += reach * payoff;
19            }
20            Node::Chance(chance) => {
21                let probs = chance_info[chance.infoset].probs();
22                for (prob, next) in probs.iter().zip(chance.outcomes.iter()) {
23                    queue.push((next, prob * reach));
24                }
25            }
26            Node::Player(player) => {
27                let probs = player.num.ind(&strat_info)[player.infoset].as_ref();
28                for (prob, next) in probs.iter().zip(player.actions.iter()) {
29                    if prob > &0.0 {
30                        queue.push((next, prob * reach));
31                    }
32                }
33            }
34        }
35    }
36    expected
37}
38
39#[derive(Default, Debug)]
40struct DeviationInfo<'a> {
41    future_nodes: usize,
42    prob_nodes: Vec<(&'a Player, f64)>,
43    max_utility: f64,
44}
45
46fn optimal_deviations<const PLAYER_ONE: bool>(
47    start: &Node,
48    chance_info: &[impl ChanceInfoset],
49    player_info: &[impl PlayerInfoset],
50    strat_info: &[impl AsRef<[f64]>],
51) -> f64 {
52    let mut infosets: Box<[_]> = player_info
53        .iter()
54        .map(|_| DeviationInfo::default())
55        .collect();
56    let mut search_queue = vec![(start, 1.0)];
57    while let Some((node, reach)) = search_queue.pop() {
58        match node {
59            Node::Terminal(_) => (),
60            Node::Chance(chance) => {
61                let probs = chance_info[chance.infoset].probs();
62                for (prob, next) in probs.iter().zip(chance.outcomes.iter()) {
63                    search_queue.push((next, prob * reach));
64                }
65            }
66            Node::Player(player) => match (player.num, PLAYER_ONE) {
67                (PlayerNum::One, true) | (PlayerNum::Two, false) => {
68                    infosets[player.infoset].prob_nodes.push((player, reach));
69                    if let Some(prev) = player_info[player.infoset].prev_infoset() {
70                        infosets[prev].future_nodes += 1;
71                    }
72                    for next in player.actions.iter() {
73                        search_queue.push((next, reach));
74                    }
75                }
76                (PlayerNum::One, false) | (PlayerNum::Two, true) => {
77                    let probs = strat_info[player.infoset].as_ref();
78                    for (prob, next) in probs.iter().zip(player.actions.iter()) {
79                        if prob > &0.0 {
80                            search_queue.push((next, prob * reach));
81                        }
82                    }
83                }
84            },
85        }
86    }
87
88    let mut info_queue: Vec<_> = infosets
89        .iter()
90        .enumerate()
91        .filter(|(_, dev)| dev.future_nodes == 0 && !dev.prob_nodes.is_empty())
92        .map(|(info, _)| info)
93        .collect();
94    while let Some(info) = info_queue.pop() {
95        // get iteration nodes and compute total probability of reach for normalization
96        let nodes = mem::take(&mut infosets[info].prob_nodes);
97        let total_reach: f64 = nodes.iter().map(|(_, p)| p).sum();
98
99        // check if finishing this infoset will allow us to evaluate a new infoset
100        if let Some(prev) = player_info[info].prev_infoset() {
101            let futs = &mut infosets[prev].future_nodes;
102            *futs -= nodes.len();
103            // if so, add it to the queue
104            if futs == &mut 0 {
105                info_queue.push(prev);
106            }
107        }
108
109        // get the expected payoff of each action
110        let mut payoffs = vec![0.0; player_info[info].num_actions()];
111        for (player, prob) in nodes {
112            for (next, res) in player.actions.iter().zip(payoffs.iter_mut()) {
113                *res += next_infoset_search::<PLAYER_ONE>(
114                    next,
115                    &mut search_queue,
116                    &infosets,
117                    chance_info,
118                    strat_info,
119                ) * prob;
120            }
121        }
122
123        // set the max utility of playing to reach an infoset
124        infosets[info].max_utility = payoffs.into_iter().reduce(f64::max).unwrap() / total_reach;
125    }
126    next_infoset_search::<PLAYER_ONE>(start, &mut search_queue, &infosets, chance_info, strat_info)
127}
128
129fn next_infoset_search<'a, const PLAYER_ONE: bool>(
130    start: &'a Node,
131    search_queue: &mut Vec<(&'a Node, f64)>,
132    infosets: &[DeviationInfo],
133    chance_info: &[impl ChanceInfoset],
134    strat_info: &[impl AsRef<[f64]>],
135) -> f64 {
136    let mut res = 0.0;
137    search_queue.push((start, 1.0));
138    while let Some((node, reach)) = search_queue.pop() {
139        match node {
140            Node::Terminal(payoff) => {
141                if PLAYER_ONE {
142                    res += payoff * reach;
143                } else {
144                    res -= payoff * reach;
145                }
146            }
147            Node::Chance(chance) => {
148                let probs = chance_info[chance.infoset].probs();
149                for (prob, next) in probs.iter().zip(chance.outcomes.iter()) {
150                    search_queue.push((next, prob * reach));
151                }
152            }
153            Node::Player(player) => match (player.num, PLAYER_ONE) {
154                (PlayerNum::One, true) | (PlayerNum::Two, false) => {
155                    let info = &infosets[player.infoset];
156                    debug_assert_eq!(info.prob_nodes.len(), 0);
157                    debug_assert_eq!(info.future_nodes, 0);
158                    res += info.max_utility * reach;
159                }
160                (PlayerNum::One, false) | (PlayerNum::Two, true) => {
161                    let probs = strat_info[player.infoset].as_ref();
162                    for (prob, next) in probs.iter().zip(player.actions.iter()) {
163                        if prob > &0.0 {
164                            search_queue.push((next, prob * reach));
165                        }
166                    }
167                }
168            },
169        }
170    }
171    res
172}
173
174pub(super) fn regret(
175    start: &Node,
176    chance_info: &[impl ChanceInfoset],
177    player_info: [&[impl PlayerInfoset]; 2],
178    strat_info: [&[impl AsRef<[f64]>]; 2],
179) -> (f64, [f64; 2]) {
180    let expected = expected(start, chance_info, strat_info);
181    let one = optimal_deviations::<true>(start, chance_info, player_info[0], strat_info[1]);
182    let two = optimal_deviations::<false>(start, chance_info, player_info[1], strat_info[0]);
183    (
184        expected,
185        [f64::max(one - expected, 0.0), f64::max(two + expected, 0.0)],
186    )
187}