kuhn_poker/
kuhn_poker.rs

1//! An example implementation of IntoGameNode for Kuhn Poker
2use cfr::{Game, GameNode, IntoGameNode, PlayerNum, SolveMethod};
3use clap::Parser;
4
/// Chance-node info type for [`Kuhn`]. It has no variants, so no value of it
/// can ever exist: every `Kuhn::Deal` becomes a chance node with `None` info.
///
/// The derives exist only so the type can serve as `IntoGameNode::ChanceInfo`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
enum Impossible {}
7
/// A betting move available at a player decision node.
///
/// The `Debug` form of each variant is printed verbatim in the strategy report.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
enum Action {
    /// Give up the hand.
    Fold,
    /// Check when nothing is pending, or match the opponent's raise.
    Call,
    /// Put in an extra bet, doubling the stake of a showdown.
    Raise,
}
14
15enum Kuhn {
16    Terminal(f64),
17    Deal(Vec<(f64, Kuhn)>),
18    Gambler(PlayerNum, (usize, bool), Vec<(Action, Kuhn)>),
19}
20
21impl IntoGameNode for Kuhn {
22    type PlayerInfo = (usize, bool);
23    type Action = Action;
24    type ChanceInfo = Impossible;
25    type Outcomes = Vec<(f64, Kuhn)>;
26    type Actions = Vec<(Action, Kuhn)>;
27
28    fn into_game_node(self) -> GameNode<Self> {
29        match self {
30            Kuhn::Terminal(payoff) => GameNode::Terminal(payoff),
31            Kuhn::Deal(outcomes) => GameNode::Chance(None, outcomes),
32            Kuhn::Gambler(num, info, actions) => GameNode::Player(num, info, actions),
33        }
34    }
35}
36
37fn create_kuhn_one(num_cards: usize) -> Kuhn {
38    assert!(num_cards > 1);
39    let frac = 1.0 / num_cards as f64;
40    Kuhn::Deal(
41        (0..num_cards)
42            .map(|card| {
43                let next: Vec<_> = [
44                    (Action::Call, create_kuhn_two(num_cards, card, false)),
45                    (Action::Raise, create_kuhn_two(num_cards, card, true)),
46                ]
47                .into();
48                (frac, Kuhn::Gambler(PlayerNum::One, (card, false), next))
49            })
50            .collect(),
51    )
52}
53
54fn create_kuhn_two(num_cards: usize, first: usize, one_raised: bool) -> Kuhn {
55    let frac = 1.0 / (num_cards - 1) as f64;
56    let lose_cards = 0..first;
57    let win_cards = first + 1..num_cards;
58    Kuhn::Deal(if one_raised {
59        let lose_acts = lose_cards.map(|card| {
60            let next = [
61                (Action::Call, Kuhn::Terminal(2.0)),
62                (Action::Fold, Kuhn::Terminal(1.0)),
63            ];
64            (
65                frac,
66                Kuhn::Gambler(PlayerNum::Two, (card, true), next.into()),
67            )
68        });
69        let win_acts = win_cards.map(|card| {
70            let next = [
71                (Action::Call, Kuhn::Terminal(-2.0)),
72                (Action::Fold, Kuhn::Terminal(1.0)),
73            ];
74            (
75                frac,
76                Kuhn::Gambler(PlayerNum::Two, (card, true), next.into()),
77            )
78        });
79        lose_acts.chain(win_acts).collect()
80    } else {
81        let lose_acts = lose_cards.map(|card| {
82            let raise = [
83                (Action::Call, Kuhn::Terminal(2.0)),
84                (Action::Fold, Kuhn::Terminal(-1.0)),
85            ];
86            let next = [
87                (Action::Call, Kuhn::Terminal(1.0)),
88                (
89                    Action::Raise,
90                    Kuhn::Gambler(PlayerNum::One, (first, true), raise.into()),
91                ),
92            ];
93            (
94                frac,
95                Kuhn::Gambler(PlayerNum::Two, (card, false), next.into()),
96            )
97        });
98        let win_acts = win_cards.map(|card| {
99            let raise = [
100                (Action::Call, Kuhn::Terminal(-2.0)),
101                (Action::Fold, Kuhn::Terminal(-1.0)),
102            ];
103            let next = [
104                (Action::Call, Kuhn::Terminal(-1.0)),
105                (
106                    Action::Raise,
107                    Kuhn::Gambler(PlayerNum::One, (first, true), raise.into()),
108                ),
109            ];
110            (
111                frac,
112                Kuhn::Gambler(PlayerNum::Two, (card, false), next.into()),
113            )
114        });
115        lose_acts.chain(win_acts).collect()
116    })
117}
118
119fn create_kuhn(num_cards: usize) -> Game<(usize, bool), Action> {
120    Game::from_root(create_kuhn_one(num_cards)).unwrap()
121}
122
123/// Use cfr to find a kuhn poker strategy
124#[derive(Debug, Parser)]
125struct Args {
126    /// The number of cards used for Kuhn Poker
127    #[arg(default_value_t = 3)]
128    num_cards: usize,
129
130    /// The number of iterations to run
131    #[arg(short, long, default_value_t = 10000)]
132    iterations: u64,
133
134    /// The number of threads to use for solving
135    #[arg(short, long, default_value_t = 0)]
136    parallel: usize,
137}
138
/// Display names of the 13 standard ranks, lowest first.
///
/// When the deck has at most 13 cards, `print_card` labels the cards with the
/// highest `num_cards` of these names.
const CARDS: [&str; 13] = [
    "2", "3", "4", "5", "6", "7", "8", "9", "10", //
    "J", "Q", "K", "A",
];
142
143fn print_card(card: usize, num_cards: usize) {
144    if num_cards > CARDS.len() {
145        print!("{card}");
146    } else {
147        print!("{}", CARDS[CARDS.len() - num_cards + card]);
148    }
149}
150
151fn print_card_strat<'a>(
152    mut card_strat: Vec<(usize, impl IntoIterator<Item = (&'a Action, f64)>)>,
153    num_cards: usize,
154) {
155    card_strat.sort_by_key(|&(card, _)| card);
156    for (card, actions) in card_strat {
157        print!("holding ");
158        print_card(card, num_cards);
159        print!(":");
160        for (action, prob) in actions {
161            if prob == 1.0 {
162                print!(" always {action:?}");
163            } else if prob > 0.0 {
164                print!(" {prob:.2} {action:?}");
165            }
166        }
167        println!();
168    }
169    println!();
170}
171
172fn main() {
173    let args = Args::parse();
174    let game = create_kuhn(args.num_cards);
175    let (mut strats, _) = game
176        .solve(
177            SolveMethod::External,
178            args.iterations,
179            0.0,
180            args.parallel,
181            None,
182        )
183        .unwrap();
184    strats.truncate(5e-3); // not visible
185    let [player_one_strat, player_two_strat] = strats.as_named();
186    println!("Player One");
187    println!("==========");
188    let mut init = Vec::with_capacity(3);
189    let mut raised = Vec::with_capacity(3);
190    for (&(card, raise), actions) in player_one_strat {
191        if raise { &mut raised } else { &mut init }.push((card, actions));
192    }
193    println!("Initial Action");
194    println!("--------------");
195    print_card_strat(init, args.num_cards);
196    println!("If Player Two Raised");
197    println!("--------------------");
198    print_card_strat(raised, args.num_cards);
199    println!("Player Two");
200    println!("==========");
201    let mut called = Vec::with_capacity(3);
202    let mut raised = Vec::with_capacity(3);
203    for (&(card, raise), actions) in player_two_strat {
204        if raise { &mut raised } else { &mut called }.push((card, actions));
205    }
206    println!("If Player One Called");
207    println!("--------------------");
208    print_card_strat(called, args.num_cards);
209    println!("If Player One Raised");
210    println!("--------------------");
211    print_card_strat(raised, args.num_cards);
212}
213
#[cfg(test)]
mod tests {
    use super::Action;
    use cfr::{PlayerNum, RegretParams, SolveMethod, Strategies};
    use rand::{thread_rng, Rng};
    use rayon::iter::{IntoParallelIterator, ParallelIterator};

    /// Probability that `actions` assigns to `target`.
    ///
    /// An action missing from the distribution counts as probability zero.
    /// This way, a strategy whose negligible probabilities were dropped (for
    /// example by `truncate`) does not make the lookup panic.
    fn action_prob<'a>(
        actions: impl IntoIterator<Item = (&'a Action, f64)>,
        target: Action,
    ) -> f64 {
        actions
            .into_iter()
            .find(|&(act, _)| *act == target)
            .map_or(0.0, |(_, prob)| prob)
    }

    /// Estimates the equilibrium parameter `alpha` from player one's strategy.
    ///
    /// `alpha` is scaled to `[0, 1]` as in `create_equilibrium`. Three of
    /// player one's info sets depend on `alpha`. Each yields an estimate of
    /// `alpha / 3`, and the function returns their sum.
    fn infer_alpha(strat: &Strategies<(usize, bool), Action>) -> f64 {
        let [one, _] = strat.as_named();
        let mut alpha = 0.0;
        for (info, actions) in one {
            alpha += match info {
                // Lowest card, first action: raise with probability alpha / 3.
                (0, false) => action_prob(actions, Action::Raise),
                // Middle card facing a raise: call with probability (1 + alpha) / 3.
                (1, true) => action_prob(actions, Action::Call) - 1.0 / 3.0,
                // Highest card, first action: raise with probability alpha.
                (2, false) => action_prob(actions, Action::Raise) / 3.0,
                _ => 0.0,
            };
        }
        alpha
    }

    /// Builds the named strategy pair for the three-card Kuhn equilibrium
    /// family parameterised by `alpha` in `[0, 1]`.
    ///
    /// Values up to 0.01 outside that range are accepted and clamped, to
    /// tolerate estimation noise from `infer_alpha`.
    ///
    /// The action weights are unnormalised. For example, `Call 3 - alpha,
    /// Raise alpha` means raising with probability `alpha / 3`. Presumably
    /// `from_named` normalises them — confirm in `cfr`.
    ///
    /// # Panics
    /// If `alpha` is outside `[-0.01, 1.01]`.
    fn create_equilibrium(alpha: f64) -> [Vec<((usize, bool), Vec<(Action, f64)>)>; 2] {
        assert!(
            (-1e-2..=1.01).contains(&alpha),
            "alpha not in proper range: {}",
            alpha
        );
        let alpha = f64::min(f64::max(0.0, alpha), 1.0);
        let one = vec![
            (
                (0, false),
                vec![(Action::Call, 3.0 - alpha), (Action::Raise, alpha)],
            ),
            ((1, false), vec![(Action::Call, 1.0)]),
            (
                (2, false),
                vec![(Action::Call, 1.0 - alpha), (Action::Raise, alpha)],
            ),
            ((0, true), vec![(Action::Fold, 1.0)]),
            (
                (1, true),
                vec![(Action::Fold, 2.0 - alpha), (Action::Call, 1.0 + alpha)],
            ),
            ((2, true), vec![(Action::Call, 1.0)]),
        ];
        let two = vec![
            ((0, false), vec![(Action::Call, 2.0), (Action::Raise, 1.0)]),
            ((1, false), vec![(Action::Call, 1.0)]),
            ((2, false), vec![(Action::Raise, 1.0)]),
            ((0, true), vec![(Action::Fold, 1.0)]),
            ((1, true), vec![(Action::Call, 1.0), (Action::Fold, 2.0)]),
            ((2, true), vec![(Action::Call, 1.0)]),
        ];
        [one, two]
    }

    /// Members of the equilibrium family achieve the game value (-1/18 for
    /// player one) with near-zero regret.
    #[test]
    fn test_equilibrium() {
        let game = super::create_kuhn(3);

        for alpha in [0.5, 1.0] {
            let eqm = game.from_named(create_equilibrium(alpha)).unwrap();
            let info = eqm.get_info();

            let util = info.player_utility(PlayerNum::One);
            assert!(
                (util + 1.0 / 18.0).abs() < 1e-3,
                "utility not close to -1/18: {} [alpha {}]",
                util,
                alpha
            );

            let eqm_reg = info.regret();
            assert!(
                eqm_reg < 0.01,
                "equilibrium regret too large: {} [alpha {}]",
                eqm_reg,
                alpha
            );
        }
    }

    /// An equilibrium player one earns a positive utility against a
    /// deliberately poor player two.
    #[test]
    fn test_regret() {
        let game = super::create_kuhn(3);

        let [one, _] = create_equilibrium(0.5);
        // Poor player-two play, e.g. calls a raise with the lowest card and
        // folds to a raise with the highest.
        let bad = vec![
            ((0, false), vec![(Action::Raise, 1.0)]),
            ((1, false), vec![(Action::Raise, 1.0)]),
            ((2, false), vec![(Action::Call, 1.0)]),
            ((0, true), vec![(Action::Call, 1.0)]),
            ((1, true), vec![(Action::Call, 1.0)]),
            ((2, true), vec![(Action::Fold, 1.0)]),
        ];
        let strats = game.from_named([one, bad]).unwrap();
        let info = strats.get_info();

        let util = info.player_utility(PlayerNum::One);
        assert!(util > 0.0, "utility not positive: {}", util);
    }

    /// Every solve method, single- and multi-threaded, converges to a member
    /// of the equilibrium family within its reported regret bound.
    #[test]
    fn test_solve_three() {
        let owned_game = super::create_kuhn(3);
        let game = &owned_game;
        // test all methods with multi threading
        [
            SolveMethod::Full,
            SolveMethod::Sampled,
            SolveMethod::External,
        ]
        .into_par_iter()
        .flat_map(|method| [1, 2].into_par_iter().map(move |threads| (method, threads)))
        .for_each(|(method, threads)| {
            // NOTE(review): draws into a throwaway buffer; presumably this
            // touches this worker thread's RNG before solving — confirm it is
            // still needed.
            thread_rng().fill(&mut [0; 8]);
            let (mut strategies, bounds) = game
                .solve(
                    method,
                    10_000_000,
                    0.005,
                    threads,
                    Some(RegretParams::vanilla()),
                )
                .unwrap();
            strategies.truncate(1e-3);

            let alpha = infer_alpha(&strategies);
            let eqm = game.from_named(create_equilibrium(alpha)).unwrap();
            let [dist_one, dist_two] = strategies.distance(&eqm, 1.0);
            assert!(
                dist_one < 0.05,
                "first player strategy not close enough to alpha equilibrium: {} [{:?} {}]",
                dist_one,
                method,
                threads
            );
            assert!(
                dist_two < 0.05,
                "second player strategy not close enough to alpha equilibrium: {} [{:?} {}]",
                dist_two,
                method,
                threads
            );

            let info = strategies.get_info();
            let util = info.player_utility(PlayerNum::One);
            assert!(
                (util + 1.0 / 18.0).abs() < 1e-3,
                "utility not close to -1/18: {} [{:?} {}]",
                util,
                method,
                threads
            );

            let bound = bounds.regret_bound();
            assert!(
                bound < 0.005,
                "regret bound not small enough: {} [{:?} {}]",
                bound,
                method,
                threads
            );

            let regret = info.regret();

            // NOTE with the sampled versions, the bound can be a bit higher
            let eff_bound = bound * 2.0;
            assert!(
                regret <= eff_bound,
                "regret not less than effective bound: {} > {} [{:?} {}]",
                regret,
                eff_bound,
                method,
                threads,
            );
        });
    }

    /// Every regret-discounting scheme reaches the game value.
    // NOTE discounts don't have accurate regret thresholds
    #[test]
    fn test_solve_three_discounts() {
        let owned_game = super::create_kuhn(3);
        let game = &owned_game;
        [
            ("vanilla", RegretParams::vanilla()),
            ("lcfr", RegretParams::lcfr()),
            ("cfr+", RegretParams::cfr_plus()),
            ("dcfr", RegretParams::dcfr()),
        ]
        .into_par_iter()
        .for_each(|(name, params)| {
            // NOTE(review): see the matching note in `test_solve_three`.
            thread_rng().fill(&mut [0; 8]);
            let (mut strategies, _) = game
                .solve(SolveMethod::Full, 10_000, 0.0, 1, Some(params))
                .unwrap();
            strategies.truncate(1e-3);
            let info = strategies.get_info();
            let util = info.player_utility(PlayerNum::One);
            assert!(
                (util + 1.0 / 18.0).abs() < 1e-3,
                "utility not close to -1/18: {} [{}]",
                util,
                name
            );
        });
    }
}