mod action_bit_set;
mod action_generator;
mod action_index_mapper;
mod action_picker;
mod action_validator;
mod agent;
mod budget;
mod budget_config;
mod exploration;
mod export;
mod historian;
mod node;
mod node_arena;
mod state;
mod traversal_state;
pub use action_generator::{
ActionGenerator, BasicCFRActionGenerator, ConfigurableActionConfig,
ConfigurableActionConfigError, ConfigurableActionGenerator, PositionCharts,
PreflopChartActionConfig, PreflopChartActionGenerator, PreflopChartConfig,
PreflopChartConfigError, RoundActionConfig, SimpleActionGenerator,
};
pub use action_index_mapper::{
ACTION_IDX_ALL_IN, ACTION_IDX_CALL, ACTION_IDX_FOLD, ACTION_IDX_RAISE_MAX,
ACTION_IDX_RAISE_MIN, ActionIndexMapper, ActionIndexMapperConfig, NUM_ACTION_INDICES,
};
pub use action_picker::{ActionPicker, get_regret_matcher_from_node};
pub use action_validator::validate_actions;
pub use agent::{CFRAgent, CFRAgentBuilder};
pub use budget::{
Budget, Deadline, ExplorationStats, IterationCount, MaxWidth, MostRestrictive, NextStep,
NodeCount, PerDepth, RegretBelow,
};
pub use budget_config::{BudgetConfig, BudgetItem};
pub use exploration::{InFlightLimiter, build_default_limiter, default_limiter_permits};
pub use export::{ExportFormat, export_cfr_state, export_to_dot, export_to_png, export_to_svg};
pub use historian::CFRHistorian;
pub use node::{Node, NodeData, PlayerData, TerminalData};
pub use state::CFRState;
pub use traversal_state::{TraversalSet, TraversalState};
#[cfg(test)]
mod tests {
use std::vec;
use std::sync::Arc;
use crate::arena::cfr::{
ACTION_IDX_FOLD, BasicCFRActionGenerator, Budget, IterationCount, MaxWidth,
MostRestrictive, NUM_ACTION_INDICES, PerDepth, TraversalSet,
};
use crate::arena::game_state::{Round, RoundData};
use crate::arena::{
Agent, GameState, GameStateBuilder, HoldemSimulation, HoldemSimulationBuilder, test_util,
};
use crate::core::{Hand, PlayerBitSet};
use super::CFRAgentBuilder;
#[tokio::test(flavor = "current_thread")]
async fn test_should_fold_all_in() {
let num_agents = 2;
let hand_zero = Hand::new_from_str("AsKsKcAcTh4d8d").unwrap();
let hand_one = Hand::new_from_str("JdTcKcAcTh4d8d").unwrap();
let board = (hand_zero & hand_one).iter().collect::<Vec<_>>();
let stacks: Vec<f32> = vec![0.0, 900.0];
let player_bet = vec![1000.0, 100.0];
let player_bet_round = vec![900.0, 0.0];
let round_data =
RoundData::new_with_bets(100.0, PlayerBitSet::new(num_agents), 1, player_bet_round);
let game_state = GameStateBuilder::new()
.round(Round::River)
.round_data(round_data)
.board(board)
.hands(vec![hand_zero, hand_one])
.stacks(stacks)
.player_bet(player_bet)
.big_blind(5.0)
.small_blind(0.0)
.build()
.unwrap();
let (sim, _cfr) = run(game_state, 5000).await;
assert_eq!(sim.game_state.player_bet[1], 100.0);
assert_eq!(sim.game_state.stacks[0], 1100.0);
assert_eq!(sim.game_state.stacks[1], 900.0);
}
#[tokio::test(flavor = "current_thread")]
async fn test_should_go_all_in() {
let num_agents = 2;
let hand_zero = Hand::new_from_str("JdTcKcAcTh4d8d").unwrap();
let hand_one = Hand::new_from_str("KcKsKdAcTh4d8d").unwrap();
let board = (hand_zero & hand_one).iter().collect::<Vec<_>>();
let stacks: Vec<f32> = vec![0.0, 900.0];
let player_bet = vec![1000.0, 100.0];
let player_bet_round = vec![900.0, 0.0];
let round_data =
RoundData::new_with_bets(100.0, PlayerBitSet::new(num_agents), 1, player_bet_round);
let game_state = GameStateBuilder::new()
.round(Round::River)
.round_data(round_data)
.board(board)
.hands(vec![hand_zero, hand_one])
.stacks(stacks)
.player_bet(player_bet)
.big_blind(5.0)
.small_blind(0.0)
.build()
.unwrap();
let (sim, _cfr) = run(game_state, 50000).await;
assert_eq!(sim.game_state.player_bet[1], 1000.0);
assert_eq!(sim.game_state.stacks[1], 2000.0);
}
/// Player 1 should fold on the turn, with one street still to come.
///
/// Shared board `As5h9hKd`. Player 0 holds `AdAc` (trip aces) and is all-in.
/// Player 1 holds `Kc2c` (a pair of kings).
#[tokio::test(flavor = "current_thread")]
async fn test_should_fold_with_one_round_to_go() {
    let spot = build_from_hands(
        Hand::new_from_str("AdAcAs5h9hKd").unwrap(),
        Hand::new_from_str("Kc2cAs5h9hKd").unwrap(),
        Round::Turn,
    );
    let (finished, _cfr) = run(spot, 200).await;
    // An unchanged 100 contribution means player 1 did not call.
    assert_eq!(finished.game_state.player_bet[1], 100.0);
}
/// Player 1 should fold on the flop, with two streets still to come.
///
/// Shared board `AdAcTh`. Player 0 holds `AsAh` (four aces) and is all-in.
/// Player 1 holds `JsTc` (a pair of tens).
#[tokio::test(flavor = "current_thread")]
async fn test_should_fold_with_two_rounds_to_go() {
    let spot = build_from_hands(
        Hand::new_from_str("AsAhAdAcTh").unwrap(),
        Hand::new_from_str("JsTcAdAcTh").unwrap(),
        Round::Flop,
    );
    let (finished, _cfr) = run(spot, 200).await;
    // An unchanged 100 contribution means player 1 did not call.
    assert_eq!(finished.game_state.player_bet[1], 100.0);
}
/// Preflop, `2s7h` facing an all-in from `AsAh` should learn to fold.
///
/// Instead of checking the single simulated outcome, this inspects the
/// trained tree. The fold probability at the first trained decision node must
/// exceed 0.9.
#[tokio::test(flavor = "current_thread")]
async fn test_should_fold_after_preflop() {
    let aces = Hand::new_from_str("AsAh").unwrap();
    let seven_deuce = Hand::new_from_str("2s7h").unwrap();
    let spot = build_from_hands(aces, seven_deuce, Round::Preflop);
    // Keep the simulation bound (not `_`) so it lives to the end of the test.
    let (_sim, cfr_state) = run(spot, 50000).await;
    let fold_weight = decision_fold_weight(&cfr_state);
    assert!(
        fold_weight > 0.9,
        "player 1 should converge to folding; fold weight = {fold_weight}"
    );
}
/// Builds a two-player game in `round` where player 0 is already all-in.
///
/// The board is the set of cards both hands have in common. Player 0 has
/// bet 1000 in total (900 of it in the current round) and has an empty stack.
/// Player 1 has bet 100 in total, nothing in the current round, and has 900
/// behind. Blinds are 5.0 big / 0.0 small.
///
/// NOTE(review): the `1` passed to `RoundData::new_with_bets` presumably
/// selects player 1 as the next to act; confirm against `RoundData`.
fn build_from_hands(hand_zero: Hand, hand_one: Hand, round: Round) -> GameState {
    let shared_cards: Vec<_> = (hand_zero & hand_one).iter().collect();
    let round_data = RoundData::new_with_bets(
        100.0,
        PlayerBitSet::new(2),
        1,
        vec![900.0, 0.0],
    );
    let builder = GameStateBuilder::new()
        .round(round)
        .round_data(round_data)
        .board(shared_cards)
        .hands(vec![hand_zero, hand_one])
        .stacks(vec![0.0_f32, 900.0])
        .player_bet(vec![1000.0, 100.0])
        .big_blind(5.0)
        .small_blind(0.0);
    builder.build().unwrap()
}
/// Builds a fresh `CFRState` from a clone of `game_state`.
///
/// The caller keeps ownership of the original state.
fn make_cfr_state(game_state: &GameState) -> super::CFRState {
    let root = game_state.clone();
    super::CFRState::new(root)
}
/// Builds the search budget used by these tests from a per-depth schedule.
///
/// - Entry `d` of `iters_per_depth` becomes an `IterationCount` cap in a
///   `PerDepth` budget.
/// - An `IterationCount` of 1 is passed as `PerDepth`'s fallback. This is
///   presumably used for depths past the end of the schedule; confirm in
///   `PerDepth`.
/// - A `MaxWidth` of 1 is set for every scheduled depth.
/// - The two budgets are combined with `MostRestrictive`.
fn budget_for_schedule(iters_per_depth: &[usize]) -> Arc<dyn Budget> {
    let mut by_depth: Vec<Arc<dyn Budget>> = Vec::with_capacity(iters_per_depth.len());
    for &iterations in iters_per_depth {
        by_depth.push(Arc::new(IterationCount::new(iterations as u64)));
    }
    let widths = Arc::new(MaxWidth::new(vec![1; iters_per_depth.len()]));
    let iter_caps = Arc::new(PerDepth::new(by_depth, Arc::new(IterationCount::new(1))));
    Arc::new(MostRestrictive::new(vec![iter_caps, widths]))
}
/// Plays `game_state` to completion with one CFR agent per seat.
///
/// Every agent shares a single `CFRState` and `TraversalSet`. Each agent uses
/// the per-depth iteration schedule `[num_hands, 1]` from
/// `budget_for_schedule`. Despite its name, `num_hands` is that schedule's
/// first entry; exactly one hand is simulated.
///
/// The simulation RNG is seeded with 42, so runs are reproducible. The
/// function asserts that the hand reached `Round::Complete` and that the
/// final state is valid. It returns the simulation together with the shared
/// CFR state so callers can inspect the trained tree.
async fn run(game_state: GameState, num_hands: usize) -> (HoldemSimulation, super::CFRState) {
    use rand::{SeedableRng, rngs::StdRng};

    let cfr_state = make_cfr_state(&game_state);
    let traversal_set = TraversalSet::new(game_state.num_players);
    let budget = budget_for_schedule(&[num_hands, 1]);

    let mut dyn_agents: Vec<Box<dyn Agent>> = Vec::new();
    for idx in 0..game_state.num_players {
        let agent = CFRAgentBuilder::<BasicCFRActionGenerator>::new()
            .name(format!("CFRAgent-run-{idx}"))
            .player_idx(idx)
            .cfr_state(cfr_state.clone())
            .traversal_set(traversal_set.clone())
            .budget(budget.clone())
            .action_gen_config(())
            .build();
        dyn_agents.push(Box::new(agent));
    }

    let seeded_rng = StdRng::seed_from_u64(42);
    let mut sim = HoldemSimulationBuilder::default()
        .game_state(game_state)
        .agents(dyn_agents)
        .cfr_context(cfr_state.clone(), traversal_set, true)
        .build_with_rng(seeded_rng)
        .unwrap();
    sim.run().await;

    assert_eq!(Round::Complete, sim.game_state.round);
    test_util::assert_valid_game_state(&sim.game_state);
    (sim, cfr_state)
}
/// Returns the current-strategy fold probability at the first trained
/// decision node of `cfr_state`.
///
/// Nodes are scanned in index order. A node qualifies when `node_avg_regret`
/// returns `Some` and `node_current_strategy_into` reports that it filled
/// the strategy buffer.
///
/// # Panics
///
/// Panics if no node qualifies.
fn decision_fold_weight(cfr_state: &super::CFRState) -> f32 {
    (0..cfr_state.node_count())
        .filter(|&node_idx| cfr_state.node_avg_regret(node_idx).is_some())
        .find_map(|node_idx| {
            let mut probabilities = [0.0f32; NUM_ACTION_INDICES];
            if cfr_state.node_current_strategy_into(node_idx, &mut probabilities) {
                Some(probabilities[ACTION_IDX_FOLD])
            } else {
                None
            }
        })
        .unwrap_or_else(|| panic!("no trained decision node found in CFR tree"))
}
/// Diagnostic for the river all-in spot also used by `test_should_go_all_in`.
///
/// Prints the betting amounts player 1 faces. It then prints the actions that
/// `BasicCFRActionGenerator` produces with a root `TraversalState` for player 1,
/// and the index `ActionIndexMapper` assigns to each one.
///
/// It asserts nothing beyond these calls not panicking. The output (visible
/// with `cargo test -- --nocapture`) is meant for debugging how the all-in is
/// represented.
#[test]
fn test_debug_all_in_convergence() {
    use crate::arena::cfr::action_generator::ActionGenerator;
    use crate::arena::cfr::{ActionIndexMapper, ActionIndexMapperConfig, TraversalState};

    let hand_zero = Hand::new_from_str("JdTcKcAcTh4d8d").unwrap();
    let hand_one = Hand::new_from_str("KcKsKdAcTh4d8d").unwrap();
    // Identical stacks/bets setup to the other river tests; reuse the shared helper.
    let game_state = build_from_hands(hand_zero, hand_one, Round::River);

    println!("current_round_bet = {}", game_state.current_round_bet());
    println!(
        "Player 1 current_round_player_bet = {}",
        game_state.current_round_player_bet(1)
    );
    println!("Player 1 stack = {}", game_state.stacks[1]);
    println!(
        "Player 1 all-in amount = {}",
        game_state.current_round_current_player_bet() + game_state.current_player_stack()
    );

    let cfr_state = make_cfr_state(&game_state);
    let traversal_set = TraversalSet::new(game_state.num_players);
    let agent = CFRAgentBuilder::<BasicCFRActionGenerator>::new()
        .name("CFRAgent-debug")
        .player_idx(1)
        .cfr_state(cfr_state.clone())
        .traversal_set(traversal_set)
        .action_gen_config(())
        .build();

    let action_gen =
        BasicCFRActionGenerator::new(agent.cfr_state().clone(), TraversalState::new_root(1));
    let actions = action_gen.gen_possible_actions(&game_state);
    println!("Valid actions: {:?}", actions);

    let mapper = ActionIndexMapper::new(ActionIndexMapperConfig::from_game_state(&game_state));
    for action in &actions {
        let idx = mapper.action_to_idx(action, &game_state);
        println!("Action {:?} maps to index {}", action, idx);
    }
}
#[tokio::test(flavor = "current_thread")]
async fn test_cfr_agent_from_starting_round() {
use rand::{SeedableRng, rngs::StdRng};
let game_state = GameStateBuilder::new()
.num_players_with_stack(2, 100.0)
.blinds(10.0, 5.0)
.build()
.unwrap();
let cfr_state = make_cfr_state(&game_state);
let traversal_set = TraversalSet::new(game_state.num_players);
let budget = budget_for_schedule(&[2, 1]);
let agents: Vec<_> = (0..2)
.map(|idx| {
Box::new(
CFRAgentBuilder::<BasicCFRActionGenerator>::new()
.name(format!("CFRAgent-starting-{idx}"))
.player_idx(idx)
.cfr_state(cfr_state.clone())
.traversal_set(traversal_set.clone())
.budget(budget.clone())
.action_gen_config(())
.build(),
)
})
.collect();
let dyn_agents = agents.into_iter().map(|a| a as Box<dyn Agent>).collect();
let mut sim = HoldemSimulationBuilder::default()
.game_state(game_state)
.agents(dyn_agents)
.cfr_context(cfr_state.clone(), traversal_set, true)
.build_with_rng(StdRng::seed_from_u64(42))
.unwrap();
sim.run().await;
assert_eq!(Round::Complete, sim.game_state.round);
test_util::assert_valid_game_state(&sim.game_state);
}
/// Plays one complete two-player hand: a CFR agent in seat 0 against a
/// `CallingAgent` in seat 1.
///
/// Each seat starts with 100 chips, the game uses `blinds(10.0, 5.0)`, and
/// the CFR agent searches with the `[2, 1]` iteration schedule. The hand must
/// reach `Round::Complete` with a valid final state.
#[tokio::test(flavor = "current_thread")]
async fn test_cfr_vs_calling_from_starting_round() {
    use crate::arena::agent::CallingAgent;
    use rand::{SeedableRng, rngs::StdRng};

    let game_state = GameStateBuilder::new()
        .num_players_with_stack(2, 100.0)
        .blinds(10.0, 5.0)
        .build()
        .unwrap();
    let cfr_state = make_cfr_state(&game_state);
    let traversal_set = TraversalSet::new(game_state.num_players);
    let search_budget = budget_for_schedule(&[2, 1]);

    let seat_zero = CFRAgentBuilder::<BasicCFRActionGenerator>::new()
        .name("CFRAgent")
        .player_idx(0)
        .cfr_state(cfr_state.clone())
        .traversal_set(traversal_set.clone())
        .budget(search_budget)
        .action_gen_config(())
        .build();
    let seat_one = CallingAgent::new("CallingAgent");
    let agents: Vec<Box<dyn Agent>> = vec![Box::new(seat_zero), Box::new(seat_one)];

    let mut sim = HoldemSimulationBuilder::default()
        .game_state(game_state)
        .agents(agents)
        .cfr_context(cfr_state.clone(), traversal_set, true)
        .build_with_rng(StdRng::seed_from_u64(42))
        .unwrap();
    sim.run().await;

    assert_eq!(Round::Complete, sim.game_state.round);
    test_util::assert_valid_game_state(&sim.game_state);
}
/// Plays five independent two-player hands back to back with CFR agents.
///
/// Despite the name, every game builds its own `CFRState`, `TraversalSet`
/// and agents. Only the `[2, 1]` budget `Arc` is reused across games. Game `i`
/// seeds its RNG with `42 + i`, and each game must reach `Round::Complete`
/// with a valid final state.
#[tokio::test(flavor = "current_thread")]
async fn test_multiple_games_same_cfr_agent() {
    use rand::{SeedableRng, rngs::StdRng};

    let shared_budget = budget_for_schedule(&[2, 1]);
    for game_idx in 0..5 {
        let game_state = GameStateBuilder::new()
            .num_players_with_stack(2, 100.0)
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        let cfr_state = make_cfr_state(&game_state);
        let traversal_set = TraversalSet::new(game_state.num_players);

        let mut agents: Vec<Box<dyn Agent>> = Vec::new();
        for idx in 0..2 {
            agents.push(Box::new(
                CFRAgentBuilder::<BasicCFRActionGenerator>::new()
                    .name(format!("CFRAgent-game{game_idx}-p{idx}"))
                    .player_idx(idx)
                    .cfr_state(cfr_state.clone())
                    .traversal_set(traversal_set.clone())
                    .budget(shared_budget.clone())
                    .action_gen_config(())
                    .build(),
            ));
        }

        let seed = 42 + game_idx as u64;
        let mut sim = HoldemSimulationBuilder::default()
            .game_state(game_state)
            .agents(agents)
            .cfr_context(cfr_state.clone(), traversal_set, true)
            .build_with_rng(StdRng::seed_from_u64(seed))
            .unwrap();
        sim.run().await;

        assert_eq!(
            Round::Complete,
            sim.game_state.round,
            "Game {game_idx} should complete"
        );
        test_util::assert_valid_game_state(&sim.game_state);
    }
}
/// Plays one complete two-player hand with CFR agents that use
/// `ConfigurableActionGenerator`.
///
/// The `default` round config enables call and all-in, uses
/// `raise_mult = [4.0]` and `pot_mult = [0.5, 1.0]`, and disables setup
/// shoves. Every per-street override is `None`, presumably so all streets
/// fall back to `default`. Agents are built with `allow_node_mutation(true)`.
/// Each seat starts with 500 chips, and the hand must reach
/// `Round::Complete` with a valid final state.
#[tokio::test(flavor = "current_thread")]
async fn test_cfr_with_configurable_action_generator() {
    use crate::arena::cfr::{
        ConfigurableActionConfig, ConfigurableActionGenerator, RoundActionConfig,
    };
    use rand::{SeedableRng, rngs::StdRng};

    let game_state = GameStateBuilder::new()
        .num_players_with_stack(2, 500.0)
        .blinds(10.0, 5.0)
        .build()
        .unwrap();

    let every_street = RoundActionConfig {
        call_enabled: true,
        raise_mult: vec![4.0],
        pot_mult: vec![0.5, 1.0],
        setup_shove: false,
        all_in: true,
    };
    let action_config = ConfigurableActionConfig {
        default: every_street,
        preflop: None,
        flop: None,
        turn: None,
        river: None,
    };

    let cfr_state = make_cfr_state(&game_state);
    let traversal_set = TraversalSet::new(game_state.num_players);
    let budget = budget_for_schedule(&[2, 1]);

    let mut dyn_agents: Vec<Box<dyn Agent>> = Vec::new();
    for idx in 0..2 {
        let agent = CFRAgentBuilder::<ConfigurableActionGenerator>::new()
            .name(format!("CFRAgent-configurable-{idx}"))
            .player_idx(idx)
            .cfr_state(cfr_state.clone())
            .traversal_set(traversal_set.clone())
            .budget(budget.clone())
            .action_gen_config(action_config.clone())
            .allow_node_mutation(true)
            .build();
        dyn_agents.push(Box::new(agent));
    }

    let mut sim = HoldemSimulationBuilder::default()
        .game_state(game_state)
        .agents(dyn_agents)
        .cfr_context(cfr_state.clone(), traversal_set, true)
        .build_with_rng(StdRng::seed_from_u64(42))
        .unwrap();
    sim.run().await;

    assert_eq!(Round::Complete, sim.game_state.round);
    test_util::assert_valid_game_state(&sim.game_state);
}
}