mod action_generator;
mod agent;
mod export;
mod gamestate_iterator_gen;
mod historian;
mod node;
mod state;
mod state_store;
pub use action_generator::{ActionGenerator, BasicCFRActionGenerator};
pub use agent::CFRAgent;
pub use export::{ExportFormat, export_cfr_state, export_to_dot, export_to_png, export_to_svg};
pub use gamestate_iterator_gen::{
FixedGameStateIteratorGen, GameStateIteratorGen, PerRoundFixedGameStateIteratorGen,
};
pub use historian::CFRHistorian;
pub use node::{Node, NodeData, PlayerData, TerminalData};
pub use state::{CFRState, TraversalState};
pub use state_store::StateStore;
#[cfg(test)]
mod tests {
    use crate::arena::cfr::{BasicCFRActionGenerator, FixedGameStateIteratorGen, state_store};
    use crate::arena::game_state::{Round, RoundData};
    use crate::arena::{Agent, GameState, HoldemSimulation, HoldemSimulationBuilder, test_util};
    use crate::core::{Hand, PlayerBitSet};

    use super::CFRAgent;

    /// River spot: player 1 holds a pair of tens against player 0's aces and
    /// kings, so it should refuse the 900 call and fold.
    #[test]
    fn test_should_fold_all_in() {
        let hand_zero = Hand::new_from_str("AsKsKcAcTh4d8d").unwrap();
        let hand_one = Hand::new_from_str("JdTcKcAcTh4d8d").unwrap();
        let game_state = build_from_hands(hand_zero, hand_one, Round::River);

        let sim = run(game_state, 10);

        // Player 1 never added chips beyond its original 100.
        assert_eq!(sim.game_state.player_bet[1], 100.0);
        // Player 0 collects the 1100 committed; player 1 keeps its 900 behind.
        assert_eq!(sim.game_state.stacks[0], 1100.0);
        assert_eq!(sim.game_state.stacks[1], 900.0);
    }

    /// River spot: player 1 holds trip kings against player 0's pair of tens,
    /// so it should call the remaining 900 and win the whole pot.
    #[test]
    fn test_should_go_all_in() {
        let hand_zero = Hand::new_from_str("JdTcKcAcTh4d8d").unwrap();
        let hand_one = Hand::new_from_str("KcKsKdAcTh4d8d").unwrap();
        let game_state = build_from_hands(hand_zero, hand_one, Round::River);

        let sim = run(game_state, 10);

        // Player 1 matched player 0's 1000 and took the full 2000.
        assert_eq!(sim.game_state.player_bet[1], 1000.0);
        assert_eq!(sim.game_state.stacks[1], 2000.0);
    }

    /// Turn spot: player 0 has trip aces, player 1 a pair of kings, so player 1
    /// should fold rather than call 900.
    ///
    /// NOTE(review): these two 7-card hands share 5 cards, so the board built by
    /// `build_from_hands` has 5 cards even though the round is `Round::Turn`.
    /// Confirm this is intended (a Turn board would normally have 4 cards).
    #[test]
    fn test_should_fold_with_one_round_to_go() {
        let hand_zero = Hand::new_from_str("AdAcAs5h9hJcKd").unwrap();
        let hand_one = Hand::new_from_str("Kc2cAs5h9hJcKd").unwrap();
        let game_state = build_from_hands(hand_zero, hand_one, Round::Turn);
        let result = run(game_state, 100);
        assert_eq!(result.game_state.player_bet[1], 100.0);
    }

    /// Flop spot: player 0 has quad aces against player 1's pair of tens, so
    /// player 1 should fold.
    #[test]
    fn test_should_fold_with_two_rounds_to_go() {
        let hand_zero = Hand::new_from_str("AsAhAdAcTh").unwrap();
        let hand_one = Hand::new_from_str("JsTcAdAcTh").unwrap();
        let game_state = build_from_hands(hand_zero, hand_one, Round::Flop);
        let result = run(game_state, 100);
        assert_eq!(result.game_state.player_bet[1], 100.0);
    }

    /// Preflop spot: pocket aces against 7-2 offsuit; player 1 should fold.
    #[test]
    fn test_should_fold_after_preflop() {
        let hand_zero = Hand::new_from_str("AsAh").unwrap();
        let hand_one = Hand::new_from_str("2s7h").unwrap();
        let game_state = build_from_hands(hand_zero, hand_one, Round::Preflop);
        let result = run(game_state, 100);
        assert_eq!(result.game_state.player_bet[1], 100.0);
    }

    /// Builds a two-player `GameState` in `round` where player 0 is all in and
    /// player 1 must decide whether to call.
    ///
    /// The board is the set of cards common to both hands (`hand_zero & hand_one`).
    /// Player 0 has 1000 committed (900 of it this round) and an empty stack.
    /// Player 1 has 100 committed, nothing this round, and 900 behind.
    fn build_from_hands(hand_zero: Hand, hand_one: Hand, round: Round) -> GameState {
        let board = (hand_zero & hand_one).iter().collect();
        let num_agents = 2;
        let stacks: Vec<f32> = vec![0.0, 900.0];
        let player_bet = vec![1000.0, 100.0];
        let player_bet_round = vec![900.0, 0.0];
        // NOTE(review): the `1` appears to select player 1 as the player to act — confirm
        // against `RoundData::new_with_bets`.
        let round_data =
            RoundData::new_with_bets(100.0, PlayerBitSet::new(num_agents), 1, player_bet_round);
        GameState::new(
            round,
            round_data,
            board,
            vec![hand_zero, hand_one],
            stacks,
            player_bet,
            5.0,
            0.0,
            0.0,
            0,
        )
    }

    /// Plays `game_state` to completion with one `CFRAgent` per player and
    /// returns the finished simulation.
    ///
    /// Every agent shares a clone of a single `StateStore`. Each agent uses a
    /// `FixedGameStateIteratorGen` configured with `num_hands`.
    ///
    /// # Panics
    ///
    /// Panics if the simulation fails to build, does not reach
    /// `Round::Complete`, or ends in a state rejected by
    /// `test_util::assert_valid_game_state`.
    fn run(game_state: GameState, num_hands: usize) -> HoldemSimulation {
        let mut state_store = state_store::StateStore::new();

        // Creating the per-player states needs `&mut state_store`, so finish it
        // before the agents below take their own clones of the store.
        let states: Vec<_> = (0..game_state.num_players)
            .map(|i| state_store.new_state(game_state.clone(), i))
            .collect();

        // `states` is not needed afterwards, so consume it instead of cloning
        // each state, and coerce straight to the trait object the builder takes.
        let agents = states
            .into_iter()
            .map(|(cfr_state, traversal_state)| {
                Box::new(
                    CFRAgent::<BasicCFRActionGenerator, FixedGameStateIteratorGen>::new(
                        state_store.clone(),
                        cfr_state,
                        traversal_state,
                        FixedGameStateIteratorGen::new(num_hands),
                    ),
                ) as Box<dyn Agent>
            })
            .collect();

        let mut rng = rand::rng();
        let mut sim = HoldemSimulationBuilder::default()
            .game_state(game_state)
            .agents(agents)
            .build()
            .unwrap();
        sim.run(&mut rng);

        assert_eq!(Round::Complete, sim.game_state.round);
        test_util::assert_valid_game_state(&sim.game_state);
        sim
    }
}