1use cfr::{Game, GameNode, IntoGameNode, PlayerNum, SolveMethod};
3use clap::Parser;
4
/// Uninhabited chance-information type.
///
/// Kuhn chance nodes are always emitted as `GameNode::Chance(None, ...)`, so no
/// value of this type is ever needed — and none can be constructed.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
enum Impossible {}
7
/// A betting decision in Kuhn poker.
///
/// In this encoding `Call` doubles as "check" when no raise is outstanding,
/// `Raise` doubles as "bet", and `Fold` is only offered when facing a raise.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
enum Action {
    /// Surrender the pot.
    Fold,
    /// Check, or match an outstanding raise.
    Call,
    /// Bet, or raise.
    Raise,
}
14
/// Builder-side representation of the Kuhn poker game tree, converted into
/// `cfr` nodes by its `IntoGameNode` impl.
enum Kuhn {
    /// Leaf node holding the payoff; in this file positive values favour
    /// Player One (e.g. Player Two folding to a raise yields `1.0`).
    Terminal(f64),
    /// Chance node: `(probability, child)` pairs, one per possible deal.
    Deal(Vec<(f64, Kuhn)>),
    /// Decision node: the acting player, their information set
    /// `(own card, whether the opponent has raised)`, and the
    /// `(action, child)` pairs available to them.
    Gambler(PlayerNum, (usize, bool), Vec<(Action, Kuhn)>),
}
20
21impl IntoGameNode for Kuhn {
22 type PlayerInfo = (usize, bool);
23 type Action = Action;
24 type ChanceInfo = Impossible;
25 type Outcomes = Vec<(f64, Kuhn)>;
26 type Actions = Vec<(Action, Kuhn)>;
27
28 fn into_game_node(self) -> GameNode<Self> {
29 match self {
30 Kuhn::Terminal(payoff) => GameNode::Terminal(payoff),
31 Kuhn::Deal(outcomes) => GameNode::Chance(None, outcomes),
32 Kuhn::Gambler(num, info, actions) => GameNode::Player(num, info, actions),
33 }
34 }
35}
36
37fn create_kuhn_one(num_cards: usize) -> Kuhn {
38 assert!(num_cards > 1);
39 let frac = 1.0 / num_cards as f64;
40 Kuhn::Deal(
41 (0..num_cards)
42 .map(|card| {
43 let next: Vec<_> = [
44 (Action::Call, create_kuhn_two(num_cards, card, false)),
45 (Action::Raise, create_kuhn_two(num_cards, card, true)),
46 ]
47 .into();
48 (frac, Kuhn::Gambler(PlayerNum::One, (card, false), next))
49 })
50 .collect(),
51 )
52}
53
54fn create_kuhn_two(num_cards: usize, first: usize, one_raised: bool) -> Kuhn {
55 let frac = 1.0 / (num_cards - 1) as f64;
56 let lose_cards = 0..first;
57 let win_cards = first + 1..num_cards;
58 Kuhn::Deal(if one_raised {
59 let lose_acts = lose_cards.map(|card| {
60 let next = [
61 (Action::Call, Kuhn::Terminal(2.0)),
62 (Action::Fold, Kuhn::Terminal(1.0)),
63 ];
64 (
65 frac,
66 Kuhn::Gambler(PlayerNum::Two, (card, true), next.into()),
67 )
68 });
69 let win_acts = win_cards.map(|card| {
70 let next = [
71 (Action::Call, Kuhn::Terminal(-2.0)),
72 (Action::Fold, Kuhn::Terminal(1.0)),
73 ];
74 (
75 frac,
76 Kuhn::Gambler(PlayerNum::Two, (card, true), next.into()),
77 )
78 });
79 lose_acts.chain(win_acts).collect()
80 } else {
81 let lose_acts = lose_cards.map(|card| {
82 let raise = [
83 (Action::Call, Kuhn::Terminal(2.0)),
84 (Action::Fold, Kuhn::Terminal(-1.0)),
85 ];
86 let next = [
87 (Action::Call, Kuhn::Terminal(1.0)),
88 (
89 Action::Raise,
90 Kuhn::Gambler(PlayerNum::One, (first, true), raise.into()),
91 ),
92 ];
93 (
94 frac,
95 Kuhn::Gambler(PlayerNum::Two, (card, false), next.into()),
96 )
97 });
98 let win_acts = win_cards.map(|card| {
99 let raise = [
100 (Action::Call, Kuhn::Terminal(-2.0)),
101 (Action::Fold, Kuhn::Terminal(-1.0)),
102 ];
103 let next = [
104 (Action::Call, Kuhn::Terminal(-1.0)),
105 (
106 Action::Raise,
107 Kuhn::Gambler(PlayerNum::One, (first, true), raise.into()),
108 ),
109 ];
110 (
111 frac,
112 Kuhn::Gambler(PlayerNum::Two, (card, false), next.into()),
113 )
114 });
115 lose_acts.chain(win_acts).collect()
116 })
117}
118
119fn create_kuhn(num_cards: usize) -> Game<(usize, bool), Action> {
120 Game::from_root(create_kuhn_one(num_cards)).unwrap()
121}
122
123#[derive(Debug, Parser)]
125struct Args {
126 #[arg(default_value_t = 3)]
128 num_cards: usize,
129
130 #[arg(short, long, default_value_t = 10000)]
132 iterations: u64,
133
134 #[arg(short, long, default_value_t = 0)]
136 parallel: usize,
137}
138
/// Display names for card ranks, ordered from lowest (`"2"`) to highest (`"A"`).
///
/// Smaller decks use the top `num_cards` entries of this table.
const CARDS: [&'static str; 13] = [
    "2", "3", "4", "5", "6", "7", "8", "9", "10", "J", "Q", "K", "A",
];
142
143fn print_card(card: usize, num_cards: usize) {
144 if num_cards > CARDS.len() {
145 print!("{card}");
146 } else {
147 print!("{}", CARDS[CARDS.len() - num_cards + card]);
148 }
149}
150
151fn print_card_strat<'a>(
152 mut card_strat: Vec<(usize, impl IntoIterator<Item = (&'a Action, f64)>)>,
153 num_cards: usize,
154) {
155 card_strat.sort_by_key(|&(card, _)| card);
156 for (card, actions) in card_strat {
157 print!("holding ");
158 print_card(card, num_cards);
159 print!(":");
160 for (action, prob) in actions {
161 if prob == 1.0 {
162 print!(" always {action:?}");
163 } else if prob > 0.0 {
164 print!(" {prob:.2} {action:?}");
165 }
166 }
167 println!();
168 }
169 println!();
170}
171
172fn main() {
173 let args = Args::parse();
174 let game = create_kuhn(args.num_cards);
175 let (mut strats, _) = game
176 .solve(
177 SolveMethod::External,
178 args.iterations,
179 0.0,
180 args.parallel,
181 None,
182 )
183 .unwrap();
184 strats.truncate(5e-3); let [player_one_strat, player_two_strat] = strats.as_named();
186 println!("Player One");
187 println!("==========");
188 let mut init = Vec::with_capacity(3);
189 let mut raised = Vec::with_capacity(3);
190 for (&(card, raise), actions) in player_one_strat {
191 if raise { &mut raised } else { &mut init }.push((card, actions));
192 }
193 println!("Initial Action");
194 println!("--------------");
195 print_card_strat(init, args.num_cards);
196 println!("If Player Two Raised");
197 println!("--------------------");
198 print_card_strat(raised, args.num_cards);
199 println!("Player Two");
200 println!("==========");
201 let mut called = Vec::with_capacity(3);
202 let mut raised = Vec::with_capacity(3);
203 for (&(card, raise), actions) in player_two_strat {
204 if raise { &mut raised } else { &mut called }.push((card, actions));
205 }
206 println!("If Player One Called");
207 println!("--------------------");
208 print_card_strat(called, args.num_cards);
209 println!("If Player One Raised");
210 println!("--------------------");
211 print_card_strat(raised, args.num_cards);
212}
213
// Tests that check the solver against the known one-parameter family of
// equilibria of 3-card Kuhn poker, whose value for Player One is -1/18.
#[cfg(test)]
mod tests {
    use super::Action;
    use cfr::{PlayerNum, RegretParams, SolveMethod, Strategies};
    use rand::{thread_rng, Rng};
    use rayon::iter::{IntoParallelIterator, ParallelIterator};

    /// Recovers the equilibrium parameter `alpha` from Player One's strategy.
    ///
    /// `alpha` uses the weighting of `create_equilibrium`, where at equilibrium:
    ///
    /// * raising with card 0 has probability `alpha / 3`;
    /// * calling a raise with card 1 has probability `1/3 + alpha / 3`;
    /// * raising with card 2 has probability `alpha`.
    ///
    /// Each information set therefore yields an estimate of `alpha / 3`, and
    /// their sum estimates `alpha`.
    ///
    /// Panics if one of those information sets lacks the inspected action.
    fn infer_alpha(strat: &Strategies<(usize, bool), Action>) -> f64 {
        let mut alpha = 0.0;
        let [one, _] = strat.as_named();
        for (info, actions) in one {
            match info {
                (0, false) => {
                    let (_, prob) = actions
                        .into_iter()
                        .find(|(act, _)| act == &&Action::Raise)
                        .unwrap();
                    alpha += prob;
                }
                (1, true) => {
                    let (_, prob) = actions
                        .into_iter()
                        .find(|(act, _)| act == &&Action::Call)
                        .unwrap();
                    // Subtract the alpha-independent 1/3 share of calling.
                    alpha += prob - 1.0 / 3.0;
                }
                (2, false) => {
                    let (_, prob) = actions
                        .into_iter()
                        .find(|(act, _)| act == &&Action::Raise)
                        .unwrap();
                    alpha += prob / 3.0;
                }
                _ => (),
            }
        }
        alpha
    }

    /// Builds the named strategy profile `[player_one, player_two]` of the
    /// equilibrium with parameter `alpha`.
    ///
    /// Action weights are unnormalized (e.g. `3.0 - alpha` vs `alpha`).
    /// NOTE(review): this presumes `from_named` normalizes them — confirm.
    ///
    /// Values slightly outside `[0, 1]` are clamped, to tolerate solver noise
    /// in `infer_alpha`. Values outside `[-0.01, 1.01]` panic.
    fn create_equilibrium(alpha: f64) -> [Vec<((usize, bool), Vec<(Action, f64)>)>; 2] {
        assert!(
            (-1e-2..=1.01).contains(&alpha),
            "alpha not in proper range: {}",
            alpha
        );
        let alpha = f64::min(f64::max(0.0, alpha), 1.0);
        let one = vec![
            (
                (0, false),
                vec![(Action::Call, 3.0 - alpha), (Action::Raise, alpha)],
            ),
            ((1, false), vec![(Action::Call, 1.0)]),
            (
                (2, false),
                vec![(Action::Call, 1.0 - alpha), (Action::Raise, alpha)],
            ),
            ((0, true), vec![(Action::Fold, 1.0)]),
            (
                (1, true),
                vec![(Action::Fold, 2.0 - alpha), (Action::Call, 1.0 + alpha)],
            ),
            ((2, true), vec![(Action::Call, 1.0)]),
        ];
        // Player Two's equilibrium strategy does not depend on alpha.
        let two = vec![
            ((0, false), vec![(Action::Call, 2.0), (Action::Raise, 1.0)]),
            ((1, false), vec![(Action::Call, 1.0)]),
            ((2, false), vec![(Action::Raise, 1.0)]),
            ((0, true), vec![(Action::Fold, 1.0)]),
            ((1, true), vec![(Action::Call, 1.0), (Action::Fold, 2.0)]),
            ((2, true), vec![(Action::Call, 1.0)]),
        ];
        [one, two]
    }

    /// Two members of the equilibrium family (alpha = 0.5 and 1.0) should both
    /// give Player One a utility of -1/18 with near-zero regret.
    #[test]
    fn test_equilibrium() {
        let game = super::create_kuhn(3);

        let eqm = game.from_named(create_equilibrium(0.5)).unwrap();
        let info = eqm.get_info();

        let util = info.player_utility(PlayerNum::One);
        assert!(
            (util + 1.0 / 18.0).abs() < 1e-3,
            "utility not close to -1/18: {}",
            util
        );

        let eqm_reg = info.regret();
        assert!(eqm_reg < 0.01, "equilibrium regret too large: {}", eqm_reg);

        let eqm = game.from_named(create_equilibrium(1.0)).unwrap();
        let info = eqm.get_info();

        let util = info.player_utility(PlayerNum::One);
        assert!(
            (util + 1.0 / 18.0).abs() < 1e-3,
            "utility not close to -1/18: {}",
            util
        );

        let eqm_reg = info.regret();
        assert!(eqm_reg < 0.01, "equilibrium regret too large: {}", eqm_reg);
    }

    /// An equilibrium Player One facing a deliberately bad Player Two should
    /// come out ahead.
    #[test]
    fn test_regret() {
        let game = super::create_kuhn(3);

        let [one, _] = create_equilibrium(0.5);
        // Player Two's play here is roughly inverted: it bets its worst cards,
        // checks its best, calls with the lowest card and folds the highest.
        let bad = vec![
            ((0, false), vec![(Action::Raise, 1.0)]),
            ((1, false), vec![(Action::Raise, 1.0)]),
            ((2, false), vec![(Action::Call, 1.0)]),
            ((0, true), vec![(Action::Call, 1.0)]),
            ((1, true), vec![(Action::Call, 1.0)]),
            ((2, true), vec![(Action::Fold, 1.0)]),
        ];
        let eqm = game.from_named([one, bad]).unwrap();
        let info = eqm.get_info();

        let util = info.player_utility(PlayerNum::One);
        assert!(util > 0.0, "utility not positive: {}", util);
    }

    /// Runs every solve method with one and two threads, in parallel.
    ///
    /// For each run it checks that:
    ///
    /// * the result is close to the equilibrium with the inferred alpha;
    /// * Player One's utility reaches -1/18;
    /// * the reported regret bound is below the 0.005 target;
    /// * the measured regret respects (twice) that bound.
    #[test]
    fn test_solve_three() {
        let owned_game = super::create_kuhn(3);
        let game = &owned_game;
        [
            SolveMethod::Full,
            SolveMethod::Sampled,
            SolveMethod::External,
        ]
        .into_par_iter()
        .flat_map(|method| [1, 2].into_par_iter().map(move |threads| (method, threads)))
        .for_each(|(method, threads)| {
            // NOTE(review): draws and discards random bytes from this thread's RNG;
            // the intent isn't visible here (possibly to initialize the thread-local
            // RNG before solving) — confirm.
            thread_rng().fill(&mut [0; 8]);
            let (mut strategies, bounds) = game
                .solve(
                    method,
                    10_000_000,
                    0.005,
                    threads,
                    Some(RegretParams::vanilla()),
                )
                .unwrap();
            strategies.truncate(1e-3);

            let alpha = infer_alpha(&strategies);
            let eqm = game.from_named(create_equilibrium(alpha)).unwrap();
            let [dist_one, dist_two] = strategies.distance(&eqm, 1.0);
            assert!(
                dist_one < 0.05,
                "first player strategy not close enough to alpha equilibrium: {} [{:?} {}]",
                dist_one,
                method,
                threads
            );
            assert!(
                dist_two < 0.05,
                "second player strategy not close enough to alpha equilibrium: {} [{:?} {}]",
                dist_two,
                method,
                threads
            );

            let info = strategies.get_info();
            let util = info.player_utility(PlayerNum::One);
            assert!(
                (util + 1.0 / 18.0).abs() < 1e-3,
                "utility not close to -1/18: {} [{:?} {}]",
                util,
                method,
                threads
            );

            let bound = bounds.regret_bound();
            assert!(
                bound < 0.005,
                "regret bound not small enough: {} [{:?} {}]",
                bound,
                method,
                threads
            );

            let regret = info.regret();

            // NOTE(review): the measured regret is only held to twice the reported
            // bound; presumably slack for truncation / sampling noise — confirm.
            let eff_bound = bound * 2.0;
            assert!(
                regret <= eff_bound,
                "regret not less than effective bound: {} > {} [{:?} {}]",
                regret,
                eff_bound,
                method,
                threads,
            );
        });
    }

    /// Each regret-discounting scheme should reach Player One's equilibrium
    /// utility of -1/18 within 10_000 full iterations.
    #[test]
    fn test_solve_three_discounts() {
        let owned_game = super::create_kuhn(3);
        let game = &owned_game;
        [
            ("vanilla", RegretParams::vanilla()),
            ("lcfr", RegretParams::lcfr()),
            ("cfr+", RegretParams::cfr_plus()),
            ("dcfr", RegretParams::dcfr()),
        ]
        .into_par_iter()
        .for_each(|(name, params)| {
            thread_rng().fill(&mut [0; 8]);
            let (mut strategies, _) = game
                .solve(SolveMethod::Full, 10_000, 0.0, 1, Some(params))
                .unwrap();
            strategies.truncate(1e-3);
            let info = strategies.get_info();
            let util = info.player_utility(PlayerNum::One);
            assert!(
                (util + 1.0 / 18.0).abs() < 1e-3,
                "utility not close to -1/18: {} [{}]",
                util,
                name
            );
        });
    }
}