use rand::{Rng, RngExt, SeedableRng, rngs::StdRng};
use crate::core::{CardBitSet, Deck};
use super::{
Agent, GameState, HoldemSimulation,
agent::FoldingAgent,
cfr::{CFRHistorian, CFRState, TraversalSet},
errors::HoldemSimulationError,
historian::Historian,
};
/// Builds the deck a new simulation deals from: every card that is neither
/// held in some player's hand nor showing on the board.
///
/// NOTE(review): this assumes `CardBitSet::default()` is the full card set;
/// the deal helpers in this file's tests rely on the same assumption.
fn build_deck(game_state: &GameState) -> Deck {
    let mut undealt = CardBitSet::default();

    // Mask out each hand's cards in one bit operation per hand.
    for held in &game_state.hands {
        let held_cards: CardBitSet = (*held).into();
        undealt &= !held_cards;
    }

    // Remove board cards explicitly too, so the result does not depend on
    // whether they were also recorded inside the hands.
    for &board_card in &game_state.board {
        undealt.remove(board_card);
    }

    undealt.into()
}
/// Creates `num_agents` default [`FoldingAgent`]s. The builder falls back to
/// these when no agents were supplied explicitly.
fn build_agents(num_agents: usize) -> Vec<Box<dyn Agent>> {
    let mut folders: Vec<Box<dyn Agent>> = Vec::with_capacity(num_agents);
    for _ in 0..num_agents {
        folders.push(Box::<FoldingAgent>::default());
    }
    folders
}
/// Builder for [`HoldemSimulation`].
///
/// A [`GameState`] must be supplied via
/// [`HoldemSimulationBuilder::game_state`]. Every other setting has a
/// fallback that is applied in [`HoldemSimulationBuilder::build_with_rng`].
///
/// Every setter takes the builder by value and returns it, so the type is
/// `#[must_use]`: dropping a builder without calling a `build*` method is
/// always a mistake.
#[must_use = "a HoldemSimulationBuilder does nothing until `build` or `build_with_rng` is called"]
pub struct HoldemSimulationBuilder {
    /// Agents to run. When `None`, one `FoldingAgent` is created per hand in
    /// the game state at build time.
    agents: Option<Vec<Box<dyn Agent>>>,
    /// Historians supplied directly by the caller. At build time, historians
    /// exposed by the agents, and the CFR historian if configured, are
    /// appended after these.
    historians: Vec<Box<dyn Historian>>,
    /// Required starting state. Building fails with `NeedGameState` without it.
    game_state: Option<GameState>,
    /// Deck to deal from. When `None`, it is derived from the game state by
    /// removing cards already in hands or on the board.
    deck: Option<Deck>,
    /// Passed through to the simulation unchanged (defaults to `true`).
    panic_on_historian_error: bool,
    /// State for the optional `CFRHistorian`. Only used when
    /// `cfr_traversal_set` is also set.
    cfr_state: Option<CFRState>,
    /// Traversal set for the optional `CFRHistorian`. Only used together with
    /// `cfr_state`.
    cfr_traversal_set: Option<TraversalSet>,
    /// Forwarded to `CFRHistorian::new` (defaults to `true`).
    cfr_allow_node_mutation: bool,
}
impl HoldemSimulationBuilder {
    /// Supplies the agents for the simulation.
    ///
    /// If this is never called, `build_with_rng` creates one `FoldingAgent`
    /// per hand in the game state.
    pub fn agents(self, agents: Vec<Box<dyn Agent>>) -> Self {
        Self {
            agents: Some(agents),
            ..self
        }
    }

    /// Sets the starting [`GameState`].
    ///
    /// This is required: building without it fails with
    /// [`HoldemSimulationError::NeedGameState`].
    pub fn game_state(self, game_state: GameState) -> Self {
        Self {
            game_state: Some(game_state),
            ..self
        }
    }

    /// Overrides the deck to deal from.
    ///
    /// Without it, the deck is every card not already in a hand or on the
    /// board.
    pub fn deck(self, deck: Deck) -> Self {
        Self {
            deck: Some(deck),
            ..self
        }
    }

    /// Replaces the caller-supplied historians.
    ///
    /// Historians exposed by the agents, and the CFR historian if configured,
    /// are still appended after these at build time.
    pub fn historians(self, historians: Vec<Box<dyn Historian>>) -> Self {
        Self {
            historians,
            ..self
        }
    }

    /// Sets the `panic_on_historian_error` flag handed to the simulation
    /// (defaults to `true`).
    pub fn panic_on_historian_error(self, panic_on_historian_error: bool) -> Self {
        Self {
            panic_on_historian_error,
            ..self
        }
    }

    /// Attaches a CFR context.
    ///
    /// At build time a `CFRHistorian` is created from these values and
    /// appended after all other historians.
    pub fn cfr_context(
        self,
        cfr_state: CFRState,
        traversal_set: TraversalSet,
        allow_node_mutation: bool,
    ) -> Self {
        Self {
            cfr_state: Some(cfr_state),
            cfr_traversal_set: Some(traversal_set),
            cfr_allow_node_mutation: allow_node_mutation,
            ..self
        }
    }

    /// Builds the simulation with a `StdRng` seeded from `rand::rng()`.
    ///
    /// # Errors
    ///
    /// Returns [`HoldemSimulationError::NeedGameState`] if no game state was
    /// set.
    pub fn build(self) -> Result<HoldemSimulation, HoldemSimulationError> {
        let seeded = StdRng::from_rng(&mut rand::rng());
        self.build_with_rng(seeded)
    }

    /// Builds the simulation using `rng`.
    ///
    /// `rng` first draws the simulation's `u128` id and is then moved into
    /// the simulation. Fallbacks applied here:
    /// - no agents: one `FoldingAgent` per hand in the game state;
    /// - no deck: every card not already in a hand or on the board.
    ///
    /// Historian order is: caller-supplied historians, then agent-provided
    /// historians (in agent order), then the CFR historian when both a CFR
    /// state and traversal set were given.
    ///
    /// # Errors
    ///
    /// Returns [`HoldemSimulationError::NeedGameState`] if no game state was
    /// set.
    pub fn build_with_rng<R: Rng + Send + 'static>(
        self,
        mut rng: R,
    ) -> Result<HoldemSimulation, HoldemSimulationError> {
        let Self {
            agents,
            historians: explicit_historians,
            game_state,
            deck,
            panic_on_historian_error,
            cfr_state,
            cfr_traversal_set,
            cfr_allow_node_mutation,
        } = self;

        let Some(game_state) = game_state else {
            return Err(HoldemSimulationError::NeedGameState);
        };

        let agents = agents.unwrap_or_else(|| build_agents(game_state.hands.len()));

        let mut historians = explicit_historians;
        for agent in &agents {
            if let Some(agent_historian) = agent.historian() {
                historians.push(agent_historian);
            }
        }

        if let (Some(state), Some(traversal)) = (cfr_state, cfr_traversal_set) {
            let cfr_historian = CFRHistorian::new(&state, traversal, cfr_allow_node_mutation);
            historians.push(Box::new(cfr_historian));
        }

        let deck = deck.unwrap_or_else(|| build_deck(&game_state));
        let id: u128 = rng.random();

        Ok(HoldemSimulation {
            agents,
            game_state,
            deck,
            id,
            historians,
            panic_on_historian_error,
            rng: Box::new(rng),
        })
    }
}
impl Default for HoldemSimulationBuilder {
    /// Returns a builder with no game state, agents, deck, historians or CFR
    /// context.
    ///
    /// Both `panic_on_historian_error` and `cfr_allow_node_mutation` start
    /// out `true`.
    fn default() -> Self {
        Self {
            game_state: None,
            agents: None,
            deck: None,
            historians: Vec::new(),
            panic_on_historian_error: true,
            cfr_state: None,
            cfr_traversal_set: None,
            cfr_allow_node_mutation: true,
        }
    }
}
// Tests for `HoldemSimulationBuilder` and for simulations built with it.
#[cfg(test)]
mod tests {
    use rand::{SeedableRng, rngs::StdRng};
    use crate::{arena::action::AgentAction, arena::game_state::Round, core::Card};
    use super::*;
    use crate::arena::GameStateBuilder;

    /// Builds a `GameState` from explicit stacks, blinds, ante and dealer
    /// seat, panicking if `GameStateBuilder` rejects the configuration.
    fn test_game_state(
        stacks: Vec<f32>,
        big_blind: f32,
        small_blind: f32,
        ante: f32,
        dealer_idx: usize,
    ) -> GameState {
        GameStateBuilder::new()
            .stacks(stacks)
            .big_blind(big_blind)
            .small_blind(small_blind)
            .ante(ante)
            .dealer_idx(dealer_idx)
            .build()
            .unwrap()
    }

    /// Steps a 9-player hand (default agents) one `run_round` at a time and
    /// checks stacks and bets after the ante and blind steps.
    #[tokio::test(flavor = "current_thread")]
    async fn test_single_step_agent() {
        let stacks = vec![100.0; 9];
        let game_state = test_game_state(stacks, 10.0, 5.0, 1.0, 0);
        let mut sim = HoldemSimulationBuilder::default()
            .game_state(game_state)
            .build_with_rng(StdRng::seed_from_u64(420))
            .unwrap();
        assert_eq!(100.0, sim.game_state.stacks[1]);
        assert_eq!(100.0, sim.game_state.stacks[2]);
        // First step: seats 1 and 2 still hold their full stacks.
        sim.run_round().await;
        assert_eq!(100.0, sim.game_state.stacks[1]);
        assert_eq!(100.0, sim.game_state.stacks[2]);
        // Second step: every seat is down by the 1.0 ante.
        sim.run_round().await;
        for i in 0..9 {
            assert_eq!(99.0, sim.game_state.stacks[i]);
        }
        // Two more steps: seat 1 has 6.0 in (1.0 ante + 5.0 small blind) and
        // seat 2 has 11.0 in (1.0 ante + 10.0 big blind).
        sim.run_round().await;
        sim.run_round().await;
        assert_eq!(6.0, sim.game_state.player_bet[1]);
        assert_eq!(11.0, sim.game_state.player_bet[2]);
    }

    /// Scripts a five-player hand directly on `GameState`, then lets the
    /// simulation finish it and checks every seat's winnings and final stack.
    ///
    /// Hole cards: seat 0 KsKh, seat 1 AsAc, seat 2 AdAh, seat 3 6d4d,
    /// seat 4 9d9s. Board: 6c 2d 3d 8h 8s. The unequal stacks force
    /// several all-ins, so the expected numbers exercise side pots.
    #[tokio::test]
    async fn test_simulation_complex_showdown() {
        let stacks = vec![102.0, 7.0, 12.0, 102.0, 202.0];
        let mut game_state = test_game_state(stacks, 10.0, 5.0, 2.0, 0);
        // Tracks which cards are still undealt; the deal helpers assert that
        // no card is dealt twice.
        let mut deck = CardBitSet::default();
        game_state.advance_round();
        // Five forced 2.0 bets, one per seat, matching the 2.0 ante above.
        // NOTE(review): the `true` flag presumably marks a forced bet; confirm
        // against `GameState::do_bet`.
        game_state.do_bet(2.0, true).unwrap();
        game_state.do_bet(2.0, true).unwrap();
        game_state.do_bet(2.0, true).unwrap();
        game_state.do_bet(2.0, true).unwrap();
        game_state.do_bet(2.0, true).unwrap();
        game_state.advance_round();
        deal_hand_card(0, "Ks", &mut deck, &mut game_state);
        deal_hand_card(0, "Kh", &mut deck, &mut game_state);
        deal_hand_card(1, "As", &mut deck, &mut game_state);
        deal_hand_card(1, "Ac", &mut deck, &mut game_state);
        deal_hand_card(2, "Ad", &mut deck, &mut game_state);
        deal_hand_card(2, "Ah", &mut deck, &mut game_state);
        deal_hand_card(3, "6d", &mut deck, &mut game_state);
        deal_hand_card(3, "4d", &mut deck, &mut game_state);
        deal_hand_card(4, "9d", &mut deck, &mut game_state);
        deal_hand_card(4, "9s", &mut deck, &mut game_state);
        game_state.advance_round();
        // Forced 5.0 / 10.0 bets (the blinds passed above), one fold, then
        // two 10.0 calls.
        game_state.do_bet(5.0, true).unwrap();
        game_state.do_bet(10.0, true).unwrap();
        game_state.fold();
        game_state.do_bet(10.0, false).unwrap();
        game_state.do_bet(10.0, false).unwrap();
        game_state.advance_round();
        deal_community_card("6c", &mut deck, &mut game_state);
        deal_community_card("2d", &mut deck, &mut game_state);
        deal_community_card("3d", &mut deck, &mut game_state);
        game_state.advance_round();
        assert_eq!(game_state.num_active_players(), 2);
        // Both remaining active players put in 90.0; afterwards only one is
        // still active.
        game_state.do_bet(90.0, false).unwrap();
        game_state.do_bet(90.0, false).unwrap();
        game_state.advance_round();
        assert_eq!(game_state.num_active_players(), 1);
        deal_community_card("8h", &mut deck, &mut game_state);
        game_state.advance_round();
        // The lone active player bets 0.0 (a check).
        game_state.do_bet(0.0, false).unwrap();
        game_state.advance_round();
        assert_eq!(game_state.num_active_players(), 1);
        deal_community_card("8s", &mut deck, &mut game_state);
        game_state.advance_round();
        // The lone active player bets 100.0, after which nobody is active.
        game_state.do_bet(100.0, false).unwrap();
        game_state.advance_round();
        assert_eq!(game_state.num_active_players(), 0);
        // Hand the scripted state to a simulation and let it finish the hand.
        let mut sim = HoldemSimulationBuilder::default()
            .game_state(game_state)
            .build()
            .unwrap();
        sim.run().await;
        assert_eq!(Round::Complete, sim.game_state.round);
        assert_eq!(180.0, sim.game_state.player_winnings[0]);
        assert_eq!(15.0, sim.game_state.player_winnings[1]);
        assert_eq!(30.0, sim.game_state.player_winnings[2]);
        assert_eq!(0.0, sim.game_state.player_winnings[3]);
        assert_eq!(100.0, sim.game_state.player_winnings[4]);
        assert_eq!(180.0, sim.game_state.stacks[0]);
        assert_eq!(15.0, sim.game_state.stacks[1]);
        assert_eq!(30.0, sim.game_state.stacks[2]);
        // Seat 3 folded and is down only its 2.0 ante (102.0 -> 100.0).
        assert_eq!(100.0, sim.game_state.stacks[3]);
        assert_eq!(100.0, sim.game_state.stacks[4]);
    }

    /// Deals `card_str` to seat `idx`.
    ///
    /// Asserts the card is still in `deck`, removes it from `deck`, and
    /// inserts it into that seat's hand.
    fn deal_hand_card(
        idx: usize,
        card_str: &str,
        deck: &mut CardBitSet,
        game_state: &mut GameState,
    ) {
        let card = Card::try_from(card_str).unwrap();
        assert!(deck.contains(card));
        deck.remove(card);
        game_state.hands[idx].insert(card);
    }

    /// Deals `card_str` as a community card.
    ///
    /// Asserts the card is still in `deck` and removes it. It is then
    /// inserted into every seat's hand and pushed onto the board.
    fn deal_community_card(card_str: &str, deck: &mut CardBitSet, game_state: &mut GameState) {
        let card = Card::try_from(card_str).unwrap();
        assert!(deck.contains(card));
        deck.remove(card);
        for h in &mut game_state.hands {
            h.insert(card);
        }
        game_state.board.push(card);
    }

    /// Test agent that always answers `AgentAction::Bet(bet_amount)`,
    /// ignoring the game state.
    #[derive(Clone)]
    struct InvalidBetAgent {
        name: String,
        bet_amount: f32,
    }

    #[async_trait::async_trait]
    impl crate::arena::Agent for InvalidBetAgent {
        async fn act(&mut self, _id: u128, _game_state: &GameState) -> AgentAction {
            AgentAction::Bet(self.bet_amount)
        }
        fn name(&self) -> &str {
            &self.name
        }
    }

    /// Three agents that always bet 1.0 (below the 10.0 big blind) must not
    /// stall the simulation: the hand still reaches `Round::Complete`.
    // NOTE(review): despite its name, this only checks completion; it does
    // not assert that a fold was actually recorded.
    #[tokio::test(flavor = "current_thread")]
    async fn test_invalid_bet_triggers_fold() {
        let stacks = vec![100.0; 3];
        let game_state = test_game_state(stacks, 10.0, 5.0, 0.0, 0);
        let invalid_agent = InvalidBetAgent {
            name: "InvalidBetAgent".to_string(),
            bet_amount: 1.0,
        };
        let mut sim = HoldemSimulationBuilder::default()
            .game_state(game_state)
            .panic_on_historian_error(false)
            .agents(vec![
                Box::new(invalid_agent.clone()),
                Box::new(invalid_agent.clone()),
                Box::new(invalid_agent.clone()),
            ])
            .build_with_rng(StdRng::seed_from_u64(42))
            .unwrap();
        sim.run().await;
        assert_eq!(Round::Complete, sim.game_state.round);
    }

    /// Without `.agents(...)`, the builder creates one agent per seat
    /// (5 stacks -> 5 agents).
    #[test]
    fn test_num_agents() {
        let stacks = vec![100.0; 5];
        let game_state = test_game_state(stacks, 10.0, 5.0, 0.0, 0);
        let sim = HoldemSimulationBuilder::default()
            .game_state(game_state)
            .build()
            .unwrap();
        assert_eq!(5, sim.num_agents());
    }

    /// A `GameState` built without an explicit raise cap reports `Some(3)`
    /// for `max_raises_per_round` once inside the simulation.
    #[test]
    fn test_max_raises_default_is_three() {
        let stacks = vec![100.0; 2];
        let game_state = test_game_state(stacks, 10.0, 5.0, 0.0, 0);
        let sim = HoldemSimulationBuilder::default()
            .game_state(game_state)
            .build()
            .unwrap();
        assert_eq!(Some(3), sim.game_state.max_raises_per_round);
    }

    /// An explicit `None` raise cap passes through the simulation builder
    /// unchanged.
    // NOTE(review): only the stored value is checked, not that unlimited
    // raising actually works during play.
    #[test]
    fn test_max_raises_none_allows_unlimited() {
        let game_state = GameStateBuilder::new()
            .num_players_with_stack(2, 100.0)
            .blinds(10.0, 5.0)
            .max_raises_per_round(None)
            .build()
            .unwrap();
        let sim = HoldemSimulationBuilder::default()
            .game_state(game_state)
            .build()
            .unwrap();
        assert_eq!(None, sim.game_state.max_raises_per_round);
    }

    /// Test agent that always bets `current_round_bet()` plus
    /// `current_round_min_raise()`.
    #[derive(Clone)]
    struct RaisingAgent {
        name: String,
    }

    #[async_trait::async_trait]
    impl crate::arena::Agent for RaisingAgent {
        async fn act(&mut self, _id: u128, game_state: &GameState) -> AgentAction {
            let current_bet = game_state.current_round_bet();
            let min_raise = game_state.current_round_min_raise();
            AgentAction::Bet(current_bet + min_raise)
        }
        fn name(&self) -> &str {
            &self.name
        }
    }

    /// With a cap of 2 raises and two always-raising agents, at least one
    /// action in the first four `run_round` steps is recorded as
    /// `Action::FailedAction`.
    #[tokio::test(flavor = "current_thread")]
    async fn test_max_raises_converts_raise_to_call() {
        use crate::arena::action::Action;
        use crate::arena::historian::VecHistorian;
        let game_state = GameStateBuilder::new()
            .num_players_with_stack(2, 1000.0)
            .blinds(10.0, 5.0)
            .max_raises_per_round(Some(2))
            .build()
            .unwrap();
        let hist = Box::new(VecHistorian::default());
        // Shared handle to the historian's records, kept after `hist` is
        // moved into the builder.
        let records = hist.get_storage();
        let raiser = RaisingAgent {
            name: "Raiser".to_string(),
        };
        let mut sim = HoldemSimulationBuilder::default()
            .game_state(game_state)
            .agents(vec![Box::new(raiser.clone()), Box::new(raiser.clone())])
            .historians(vec![hist])
            .build_with_rng(StdRng::seed_from_u64(42))
            .unwrap();
        sim.run_round().await;
        sim.run_round().await;
        sim.run_round().await;
        sim.run_round().await;
        let failed_actions: Vec<_> = records
            .lock()
            .unwrap()
            .iter()
            .filter(|r| matches!(r.action, Action::FailedAction(_)))
            .cloned()
            .collect();
        assert!(
            !failed_actions.is_empty(),
            "Expected some raises to be converted to calls"
        );
    }

    /// Once the raise cap (3) is already reached, an all-in larger than the
    /// current bet is recorded as a `PlayedAction`, not a `FailedAction`.
    #[tokio::test(flavor = "current_thread")]
    async fn test_max_raises_all_in_always_allowed() {
        use crate::arena::action::Action;
        use crate::arena::historian::VecHistorian;
        let mut game_state = GameStateBuilder::new()
            .num_players_with_stack(2, 100.0)
            .blinds(10.0, 5.0)
            .max_raises_per_round(Some(3))
            .build()
            .unwrap();
        // Advance the state four times and post forced 5.0 / 10.0 bets.
        game_state.advance_round();
        game_state.advance_round();
        game_state.advance_round();
        game_state.advance_round();
        game_state.do_bet(5.0, true).unwrap();
        game_state.do_bet(10.0, true).unwrap();
        // Pretend the cap has already been hit this round.
        game_state.round_data.total_raise_count = 3;
        let hist = Box::new(VecHistorian::default());
        let records = hist.get_storage();
        let mut sim = HoldemSimulationBuilder::default()
            .game_state(game_state)
            .agents(vec![
                Box::<crate::arena::agent::FoldingAgent>::default(),
                Box::<crate::arena::agent::FoldingAgent>::default(),
            ])
            .historians(vec![hist])
            .build()
            .unwrap();
        // All-in = chips already in this round + remaining stack.
        let all_in_bet = sim.game_state.current_round_current_player_bet()
            + sim.game_state.current_player_stack();
        assert!(
            all_in_bet > sim.game_state.current_round_bet(),
            "All-in should be a raise"
        );
        sim.run_agent_action(AgentAction::Bet(all_in_bet)).await;
        let played: Vec<_> = records
            .lock()
            .unwrap()
            .iter()
            .filter(|r| matches!(r.action, Action::PlayedAction(_)))
            .cloned()
            .collect();
        let failed: Vec<_> = records
            .lock()
            .unwrap()
            .iter()
            .filter(|r| matches!(r.action, Action::FailedAction(_)))
            .cloned()
            .collect();
        assert_eq!(played.len(), 1, "All-in should be recorded as PlayedAction");
        assert!(
            failed.is_empty(),
            "All-in should NOT be recorded as FailedAction"
        );
    }

    /// After five `run_round` steps with a cap of 2 and two raisers,
    /// `total_raise_count` is back to 0.
    #[tokio::test(flavor = "current_thread")]
    async fn test_max_raises_resets_each_round() {
        let game_state = GameStateBuilder::new()
            .num_players_with_stack(2, 500.0)
            .blinds(10.0, 5.0)
            .max_raises_per_round(Some(2))
            .build()
            .unwrap();
        let raiser = RaisingAgent {
            name: "Raiser".to_string(),
        };
        let mut sim = HoldemSimulationBuilder::default()
            .game_state(game_state)
            .agents(vec![Box::new(raiser.clone()), Box::new(raiser.clone())])
            .build_with_rng(StdRng::seed_from_u64(42))
            .unwrap();
        sim.run_round().await;
        sim.run_round().await;
        sim.run_round().await;
        sim.run_round().await;
        sim.run_round().await;
        assert_eq!(
            sim.game_state.round_data.total_raise_count, 0,
            "Raise count should reset at the start of each betting round"
        );
    }

    /// With the cap (2) already reached, a `Bet(20.0)` produces exactly one
    /// `FailedAction`. Its original action is a `Bet`, and its resulting
    /// action is a `Bet` equal to `result.starting_bet`.
    #[tokio::test(flavor = "current_thread")]
    async fn test_max_raises_records_failed_action() {
        use crate::arena::action::Action;
        use crate::arena::historian::VecHistorian;
        let mut game_state = GameStateBuilder::new()
            .num_players_with_stack(2, 1000.0)
            .blinds(10.0, 5.0)
            .max_raises_per_round(Some(2))
            .build()
            .unwrap();
        // Same setup as the all-in test: four advances, then forced
        // 5.0 / 10.0 bets.
        game_state.advance_round();
        game_state.advance_round();
        game_state.advance_round();
        game_state.advance_round();
        game_state.do_bet(5.0, true).unwrap();
        game_state.do_bet(10.0, true).unwrap();
        // Pretend the cap has already been hit this round.
        game_state.round_data.total_raise_count = 2;
        let hist = Box::new(VecHistorian::default());
        let records = hist.get_storage();
        let raiser = RaisingAgent {
            name: "Raiser".to_string(),
        };
        let mut sim = HoldemSimulationBuilder::default()
            .game_state(game_state)
            .agents(vec![Box::new(raiser.clone()), Box::new(raiser.clone())])
            .historians(vec![hist])
            .build()
            .unwrap();
        sim.run_agent_action(AgentAction::Bet(20.0)).await;
        let failed_actions: Vec<_> = records
            .lock()
            .unwrap()
            .iter()
            .filter_map(|r| {
                if let Action::FailedAction(payload) = &r.action {
                    Some(payload.clone())
                } else {
                    None
                }
            })
            .collect();
        assert_eq!(
            failed_actions.len(),
            1,
            "Should have exactly one failed action"
        );
        assert!(
            matches!(failed_actions[0].action, AgentAction::Bet(_)),
            "Original action should be a Bet"
        );
        if let AgentAction::Bet(amount) = failed_actions[0].result.action {
            assert_eq!(
                amount, failed_actions[0].result.starting_bet,
                "Result should be a call at the current bet level"
            );
        } else {
            panic!("Result action should be a Bet");
        }
    }
}