mod builder;
mod engine;
mod fast_forward;
mod hand_log;
mod reward_context;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use tracing::event;
use crate::arena::{Agent, GameState, action::AgentAction};
use super::{
ActionPicker, action_generator::ActionGenerator, action_validator::validate_actions,
get_regret_matcher_from_node,
};
pub use builder::CFRAgentBuilder;
pub use engine::CFRAgent;
/// Guard that owns a spawned tokio task and cancels it when the guard drops.
///
/// Dropping a bare `JoinHandle` detaches its task; wrapping the handle here ties
/// the task to this guard's scope instead (including early returns and unwinds).
pub(super) struct AbortOnDrop(pub(super) tokio::task::JoinHandle<()>);
impl Drop for AbortOnDrop {
    fn drop(&mut self) {
        // `abort` only requests cancellation; it does not wait for the task.
        let Self(task) = self;
        task.abort();
    }
}
#[async_trait::async_trait]
impl<T> Agent for CFRAgent<T>
where
    T: ActionGenerator + Send + 'static,
    T::Config: Send + Sync,
{
    /// Picks this agent's action for `game_state`.
    ///
    /// A queued forced action (`self.forced_action`) takes priority. It is
    /// consumed and reconciled against the currently valid actions. Otherwise
    /// the agent explores the actions at its target node and asks an
    /// `ActionPicker`, built from that node's regret matcher, for an action.
    ///
    /// Both paths take the generator's actions and filter them through
    /// `validate_actions`. On the forced path, `Call` and `AllIn` are returned
    /// unchanged. A `Fold` or `Bet` that cannot be matched falls back to the
    /// first validated action, or `Fold` if there is none.
    ///
    /// # Panics
    ///
    /// Panics in three cases:
    /// - `game_state` says another seat is to act;
    /// - this seat's hand holds neither exactly 2 nor at least 5 cards;
    /// - on the exploration path, no target node index is available.
    async fn act(&mut self, id: u128, game_state: &GameState) -> AgentAction {
        event!(tracing::Level::TRACE, ?id, "Agent acting");
        let player_idx = self.traversal_state.player_idx() as usize;
        assert!(
            game_state.round_data.to_act_idx == player_idx,
            "Agent should only be called when it's the player's turn"
        );
        // Hands appear to include board cards (the tests extend hole cards with
        // the board). So 2 cards means pre-flop, 5+ means a board is out, and
        // 3 or 4 cards is rejected.
        let hand_size = game_state.hands[player_idx].count();
        assert!(
            hand_size == 2 || hand_size >= 5,
            "Agent should only be called when it has at least 2 cards"
        );
        self.ensure_target_node();
        if let Some(force_action) = self.forced_action.take() {
            event!(
                tracing::Level::DEBUG,
                ?force_action,
                "Playing forced action"
            );
            // Use the same validated action set as the exploration path below.
            // The raw generator output can still contain raises the validator
            // rejects (e.g. once the per-round raise cap has been reached).
            let valid_actions = validate_actions(
                self.action_generator.gen_possible_actions(game_state),
                game_state,
            );
            match &force_action {
                AgentAction::Fold => {
                    if valid_actions.contains(&AgentAction::Fold) {
                        force_action
                    } else {
                        event!(
                            tracing::Level::WARN,
                            "Forced Fold action invalid, using first valid action"
                        );
                        valid_actions.first().cloned().unwrap_or(AgentAction::Fold)
                    }
                }
                // Call and AllIn are not checked against the valid set.
                AgentAction::AllIn | AgentAction::Call => force_action,
                AgentAction::Bet(amount) => {
                    // Bets are matched by action index, so a forced bet resolves
                    // to whichever valid action maps to the same index.
                    let forced_idx = self
                        .action_index_mapper
                        .action_to_idx(&force_action, game_state);
                    if let Some(valid_action) = valid_actions.iter().find(|a| {
                        self.action_index_mapper.action_to_idx(a, game_state) == forced_idx
                    }) {
                        valid_action.clone()
                    } else {
                        event!(
                            tracing::Level::WARN,
                            ?force_action,
                            forced_idx = forced_idx,
                            amount = amount,
                            current_bet = game_state.current_round_bet(),
                            min_raise = game_state.current_round_min_raise(),
                            "Forced Bet action index not valid, using first valid action"
                        );
                        valid_actions.first().cloned().unwrap_or(AgentAction::Fold)
                    }
                }
            }
        } else {
            self.ensure_regret_matcher();
            if self.depth == 0 {
                // Only the top-level agent installs a fresh, unset stop flag.
                // Presumably this keeps a stop requested during an earlier
                // decision from cutting this exploration short — confirm.
                self.stop = Arc::new(AtomicBool::new(false));
            }
            self.explore_all_actions(game_state).await;
            let possible_actions = validate_actions(
                self.action_generator.gen_possible_actions(game_state),
                game_state,
            );
            let target_node_idx = self
                .target_node_idx()
                .expect("target node should be set after ensure_target_node and exploration");
            self.cfr_state.with_node_data(target_node_idx, |node_data| {
                let regret_matcher = get_regret_matcher_from_node(node_data);
                let picker = ActionPicker::new(
                    &self.action_index_mapper,
                    &possible_actions,
                    regret_matcher,
                    game_state,
                );
                picker.pick_action(&mut rand::rng())
            })
        }
    }

    /// Returns the agent's configured name.
    fn name(&self) -> &str {
        self.name.as_ref()
    }

    /// Returns a historian that records into `self.hand_log`, under two
    /// conditions: this is the top-level agent (`depth == 0`), and its
    /// estimator reports that it needs history.
    ///
    /// Returns `None` otherwise, including when no hand log is set.
    fn historian(&self) -> Option<Box<dyn crate::arena::Historian>> {
        if self.depth == 0 && self.estimator.needs_history() {
            self.hand_log.as_ref().map(|log| {
                Box::new(hand_log::HandLogHistorian::new(log.clone()))
                    as Box<dyn crate::arena::Historian>
            })
        } else {
            None
        }
    }
}
/// Test-only hand-distribution estimator that always asks for history.
///
/// It records how many logged actions it was shown on each `estimate` call.
#[cfg(test)]
#[derive(Default, Clone)]
pub(crate) struct HistoryNeedingStub {
    /// One entry per `estimate` call: the number of actions in the supplied
    /// history (0 when none was passed). The vector sits behind an `Arc`, so
    /// clones of the stub share it.
    pub observed_counts: Arc<std::sync::Mutex<Vec<usize>>>,
}
#[cfg(test)]
#[async_trait::async_trait]
impl crate::arena::HandDistributionEstimator for HistoryNeedingStub {
    /// Records the size of the supplied history, then delegates to
    /// `KnownHandsEstimator` without forwarding that history.
    async fn estimate(
        &self,
        game_state: &GameState,
        history: Option<&crate::arena::GameLog<'_>>,
    ) -> crate::arena::OpponentRanges {
        let action_count = history.map_or(0, |log| log.actions.len());
        // The guard is a temporary, so the lock is released at the end of this
        // statement, before the `.await` below.
        self.observed_counts
            .lock()
            .expect("observed_counts poisoned")
            .push(action_count);
        let delegate = crate::arena::hand_estimator::KnownHandsEstimator;
        delegate.estimate(game_state, None).await
    }

    /// Always `true`, so agents hand this stub their action history.
    fn needs_history(&self) -> bool {
        true
    }
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use little_sorry::{PcfrPlusRegretMatcher, RegretMinimizer};
use rand::{SeedableRng, rngs::StdRng};
use crate::arena::GameStateBuilder;
use crate::arena::agent::CallingAgent;
use crate::arena::cfr::{
BasicCFRActionGenerator, ConfigurableActionConfig, ConfigurableActionGenerator,
IterationCount, MaxWidth, MostRestrictive, PerDepth, TraversalSet,
};
use crate::arena::{Agent, HoldemSimulationBuilder};
use super::super::{Budget, CFRState, NUM_ACTION_INDICES};
use super::fast_forward::*;
use super::*;
/// Creates a fresh `CFRState` rooted at a copy of `game_state`.
fn make_cfr_state(game_state: &GameState) -> CFRState {
    let root = game_state.clone();
    CFRState::new(root)
}
/// Builds a test budget from a per-depth iteration schedule.
///
/// `iters_per_depth[d]` becomes an `IterationCount` for depth `d`, wrapped in
/// `PerDepth`. The second argument to `PerDepth::new`, `IterationCount::new(1)`,
/// is presumably the fallback for depths past the schedule — confirm.
///
/// The result is combined via `MostRestrictive` with a `MaxWidth` of 1 at
/// every scheduled depth.
fn budget_for_schedule(iters_per_depth: &[usize]) -> Arc<dyn Budget> {
    let mut by_depth: Vec<Arc<dyn Budget>> = Vec::with_capacity(iters_per_depth.len());
    for &iterations in iters_per_depth {
        by_depth.push(Arc::new(IterationCount::new(iterations as u64)));
    }
    let width_limits = vec![1; iters_per_depth.len()];
    let iter_caps = Arc::new(PerDepth::new(by_depth, Arc::new(IterationCount::new(1))));
    let widths = Arc::new(MaxWidth::new(width_limits));
    Arc::new(MostRestrictive::new(vec![iter_caps, widths]))
}
/// Smoke test: a calling agent in seat 0 and a CFR agent in seat 1 play one
/// hand (50-chip stacks, blinds 5 / 2.5) through the simulation to completion.
#[tokio::test(flavor = "current_thread")]
async fn test_cfr_vs_non_cfr_agent() {
    let stacks: Vec<f32> = vec![50.0; 2];
    let game_state = GameStateBuilder::new()
        .stacks(stacks)
        .blinds(5.0, 2.5)
        .build()
        .unwrap();
    let cfr_state = make_cfr_state(&game_state);
    let traversal_set = TraversalSet::new(game_state.num_players);

    let cfr_seat_agent = CFRAgentBuilder::<BasicCFRActionGenerator>::new()
        .name("CFRAgent-player1")
        .player_idx(1)
        .cfr_state(cfr_state.clone())
        .traversal_set(traversal_set.clone())
        .budget(budget_for_schedule(&[1]))
        .action_gen_config(())
        .build();
    let calling_seat_agent = CallingAgent::new("CallingAgent-player0");
    // Vector position is the seat: 0 = calling agent, 1 = CFR agent.
    let agents: Vec<Box<dyn Agent>> =
        vec![Box::new(calling_seat_agent), Box::new(cfr_seat_agent)];

    let mut sim = HoldemSimulationBuilder::default()
        .game_state(game_state)
        .agents(agents)
        .cfr_context(cfr_state.clone(), traversal_set, true)
        .build()
        .unwrap();
    sim.run().await;
}
/// Once the per-round raise cap is reached, every action `act` returns must
/// belong to the `validate_actions` output, even though the raw generator
/// output still contains actions the validator drops.
#[tokio::test(flavor = "current_thread")]
async fn test_act_returns_only_validated_actions_when_cap_reached() {
    use crate::arena::cfr::action_validator::validate_actions;
    use crate::core::{Card, Hand, Suit, Value};

    let mut game_state = GameStateBuilder::new()
        .num_players_with_stack(2, 100.0)
        .blinds(10.0, 5.0)
        .max_raises_per_round(Some(2))
        .build()
        .unwrap();
    // Advance three rounds, then place bets of 5.0 and 10.0.
    for _ in 0..3 {
        game_state.advance_round();
    }
    game_state.do_bet(5.0, true).unwrap();
    game_state.do_bet(10.0, true).unwrap();

    let two_card_hand = |first: Card, second: Card| {
        let mut hand = Hand::default();
        hand.insert(first);
        hand.insert(second);
        hand
    };
    game_state.hands[0] = two_card_hand(
        Card::new(Value::Ace, Suit::Spade),
        Card::new(Value::King, Suit::Spade),
    );
    game_state.hands[1] = two_card_hand(
        Card::new(Value::Queen, Suit::Heart),
        Card::new(Value::Jack, Suit::Heart),
    );

    // Force the raise counter to the cap of 2 configured above.
    game_state.round_data.total_raise_count = 2;
    assert!(game_state.is_raise_capped());

    let mut agent = CFRAgentBuilder::<ConfigurableActionGenerator>::new()
        .name("CFRAgent-cap-test")
        .player_idx(game_state.to_act_idx())
        .cfr_state(make_cfr_state(&game_state))
        .traversal_set(TraversalSet::new(game_state.num_players))
        .budget(budget_for_schedule(&[1]))
        .action_gen_config(ConfigurableActionConfig::default())
        .build();

    let raw_actions = agent.action_generator.gen_possible_actions(&game_state);
    let validated_actions = validate_actions(raw_actions.clone(), &game_state);
    assert!(
        validated_actions.len() < raw_actions.len(),
        "validate_actions should filter raises once the cap is reached \
         (raw={raw_actions:?}, validated={validated_actions:?})"
    );

    for call_id in 0_u128..32 {
        let action = agent.act(call_id, &game_state).await;
        assert!(
            validated_actions.contains(&action),
            "CFRAgent returned {action:?}, not in validated set \
             {validated_actions:?} (raw set was {raw_actions:?})"
        );
    }
}
/// Building a CFR agent for a three-player table, without setting a budget,
/// succeeds.
#[test]
fn test_create_agent() {
    let game_state = GameStateBuilder::new()
        .num_players_with_stack(3, 100.0)
        .blinds(10.0, 5.0)
        .build()
        .unwrap();
    let traversal_set = TraversalSet::new(game_state.num_players);
    let builder = CFRAgentBuilder::<BasicCFRActionGenerator>::new()
        .name("CFRAgent-test")
        .player_idx(0)
        .cfr_state(make_cfr_state(&game_state))
        .traversal_set(traversal_set)
        .action_gen_config(());
    let _ = builder.build();
}
/// Two CFR agents sharing one `CFRState`, traversal set and 2/1 budget play a
/// heads-up hand to completion on a current-thread runtime.
#[tokio::test(flavor = "current_thread")]
async fn test_run_heads_up() {
    let stacks: Vec<f32> = vec![50.0; 2];
    let game_state = GameStateBuilder::new()
        .stacks(stacks)
        .blinds(5.0, 2.5)
        .build()
        .unwrap();
    let cfr_state = make_cfr_state(&game_state);
    let traversal_set = TraversalSet::new(game_state.num_players);
    let budget = budget_for_schedule(&[2, 1]);

    let mut agents: Vec<Box<dyn Agent>> = Vec::with_capacity(2);
    for seat in 0..2 {
        let agent = CFRAgentBuilder::<BasicCFRActionGenerator>::new()
            .name(format!("CFRAgent-test-{seat}"))
            .player_idx(seat)
            .cfr_state(cfr_state.clone())
            .traversal_set(traversal_set.clone())
            .budget(budget.clone())
            .action_gen_config(())
            .build();
        agents.push(Box::new(agent));
    }

    let mut sim = HoldemSimulationBuilder::default()
        .game_state(game_state)
        .agents(agents)
        .cfr_context(cfr_state.clone(), traversal_set, true)
        .build()
        .unwrap();
    sim.run().await;
}
/// Two agents built from clones of one `CFRState` both see its root node. Before
/// any play, neither sees child 0 under the root.
#[test]
fn test_shared_cfr_states_between_agents() {
    let game_state = GameStateBuilder::new()
        .num_players_with_stack(2, 100.0)
        .blinds(10.0, 5.0)
        .build()
        .unwrap();
    let cfr_state = make_cfr_state(&game_state);
    let traversal_set = TraversalSet::new(game_state.num_players);
    let agent0 = CFRAgentBuilder::<BasicCFRActionGenerator>::new()
        .name("Agent0")
        .player_idx(0)
        .cfr_state(cfr_state.clone())
        .traversal_set(traversal_set.clone())
        .action_gen_config(())
        .build();
    let agent1 = CFRAgentBuilder::<BasicCFRActionGenerator>::new()
        .name("Agent1")
        .player_idx(1)
        .cfr_state(cfr_state.clone())
        .traversal_set(traversal_set)
        .action_gen_config(())
        .build();

    let states = [agent0.cfr_state(), agent1.cfr_state()];
    for state in &states {
        assert!(state.get_node_data(0).is_some());
    }
    for state in &states {
        assert!(state.get_child(0, 0).is_none());
    }
}
/// Heads-up self-play between two CFR agents sharing one `CFRState`, traversal
/// set and 2/1 budget. It runs on a two-worker multi-threaded runtime and must
/// complete.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_run_heads_up_parallel() {
    let stacks: Vec<f32> = vec![50.0; 2];
    let game_state = GameStateBuilder::new()
        .stacks(stacks)
        .blinds(5.0, 2.5)
        .build()
        .unwrap();
    let cfr_state = make_cfr_state(&game_state);
    let traversal_set = TraversalSet::new(game_state.num_players);
    let budget = budget_for_schedule(&[2, 1]);

    let mut agents: Vec<Box<dyn Agent>> = Vec::with_capacity(2);
    for seat in 0..2 {
        let agent = CFRAgentBuilder::<BasicCFRActionGenerator>::new()
            .name(format!("CFRAgent-par-{seat}"))
            .player_idx(seat)
            .cfr_state(cfr_state.clone())
            .traversal_set(traversal_set.clone())
            .budget(budget.clone())
            .action_gen_config(())
            .build();
        agents.push(Box::new(agent));
    }

    let mut sim = HoldemSimulationBuilder::default()
        .game_state(game_state)
        .agents(agents)
        .cfr_context(cfr_state.clone(), traversal_set, true)
        .build()
        .unwrap();
    sim.run().await;
}
/// After a heads-up hand on a two-worker runtime, the shared `CFRState` must
/// hold more than its initial root node.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_parallel_builds_cfr_tree() {
    let stacks: Vec<f32> = vec![50.0; 2];
    let game_state = GameStateBuilder::new()
        .stacks(stacks)
        .blinds(5.0, 2.5)
        .build()
        .unwrap();
    let cfr_state = make_cfr_state(&game_state);
    let traversal_set = TraversalSet::new(game_state.num_players);
    let budget = budget_for_schedule(&[2, 1]);

    let mut agents: Vec<Box<dyn Agent>> = Vec::with_capacity(2);
    for seat in 0..2 {
        let agent = CFRAgentBuilder::<BasicCFRActionGenerator>::new()
            .name(format!("CFRAgent-par-{seat}"))
            .player_idx(seat)
            .cfr_state(cfr_state.clone())
            .traversal_set(traversal_set.clone())
            .budget(budget.clone())
            .action_gen_config(())
            .build();
        agents.push(Box::new(agent));
    }

    let mut sim = HoldemSimulationBuilder::default()
        .game_state(game_state)
        .agents(agents)
        .cfr_context(cfr_state.clone(), traversal_set, true)
        .build()
        .unwrap();
    sim.run().await;

    let nodes_after_play = cfr_state.node_count();
    assert!(
        nodes_after_play > 1,
        "CFR tree should have grown during simulation"
    );
}
/// A forked traversal set starts at the parent's position, but moving the fork
/// must leave the parent's traversal untouched.
#[test]
fn test_sub_simulation_traversal_isolation() {
    let game_state = GameStateBuilder::new()
        .num_players_with_stack(2, 100.0)
        .blinds(10.0, 5.0)
        .build()
        .unwrap();
    let traversal_set = TraversalSet::new(game_state.num_players);
    let (initial_node, initial_child) = {
        let initial = traversal_set.get(0);
        (initial.node_idx(), initial.chosen_child_idx())
    };

    // The fork begins where the parent is.
    let forked = traversal_set.fork();
    let sub_traversal = forked.get(0);
    assert_eq!(sub_traversal.node_idx(), initial_node);
    assert_eq!(sub_traversal.chosen_child_idx(), initial_child);

    // Moving the fork is visible through the fork...
    sub_traversal.move_to(5, 3);
    let moved = forked.get(0);
    assert_eq!(moved.node_idx(), 5);
    assert_eq!(moved.chosen_child_idx(), 3);

    // ...but not through the parent.
    let parent = traversal_set.get(0);
    assert_eq!(parent.node_idx(), initial_node);
    assert_eq!(parent.chosen_child_idx(), initial_child);
}
/// River spot on a 4s 4d 9d Ah Jd board:
/// - Seat 0 holds 9h Qs (two pair) and is all-in for 700.
/// - The agent, seat 1, holds 7s Ks (only the board's pair of fours). It has
///   400 committed, 300 behind, and faces a 500 bet this round.
///
/// With a 24/3/1 iteration schedule the agent must choose `Fold`, and the fold
/// weight at its target node must exceed 0.99.
#[tokio::test(flavor = "current_thread")]
async fn test_river_should_fold_king_high_vs_pair() {
    use crate::arena::cfr::action_generator::{
        ConfigurableActionConfig, ConfigurableActionGenerator,
    };
    use crate::arena::game_state::Round;
    use crate::core::{Card, Hand, PlayerBitSet, Suit, Value};
    let board_cards = vec![
        Card::new(Value::Four, Suit::Spade),
        Card::new(Value::Four, Suit::Diamond),
        Card::new(Value::Nine, Suit::Diamond),
        Card::new(Value::Ace, Suit::Heart),
        Card::new(Value::Jack, Suit::Diamond),
    ];
    let p0_hole = vec![
        Card::new(Value::Nine, Suit::Heart),
        Card::new(Value::Queen, Suit::Spade),
    ];
    let p1_hole = vec![
        Card::new(Value::Seven, Suit::Spade),
        Card::new(Value::King, Suit::Spade),
    ];
    // The hole-card vectors are not used again, so they are moved, not cloned.
    let mut p0_hand = Hand::new_with_cards(p0_hole);
    p0_hand.extend(board_cards.iter().copied());
    let mut p1_hand = Hand::new_with_cards(p1_hole);
    p1_hand.extend(board_cards.iter().copied());

    // Seat 0 has put in its whole 700; seat 1 has 400 in and 300 behind.
    let stacks = vec![0.0, 300.0];
    let starting_stacks = vec![700.0, 700.0];
    let player_bet = vec![700.0, 400.0];
    // This round, seat 0 has bet 500 and seat 1 nothing yet.
    let round_player_bet = vec![500.0, 0.0];
    // NOTE(review): seat 0 is removed from the active set, presumably because
    // it is all-in — confirm.
    let mut active = PlayerBitSet::new(2);
    active.disable(0);
    let round_data =
        crate::arena::game_state::RoundData::new_with_bets(10.0, active, 1, round_player_bet);
    let mut game_state = GameStateBuilder::new()
        .round(Round::River)
        .round_data(round_data)
        .stacks(stacks)
        .player_bet(player_bet)
        .big_blind(10.0)
        .small_blind(5.0)
        .hands(vec![p0_hand, p1_hand])
        .board(board_cards)
        .build()
        .unwrap();
    game_state.starting_stacks = starting_stacks.into();

    // Sanity-check the constructed spot before exploring it.
    assert_eq!(game_state.round, Round::River);
    assert_eq!(game_state.to_act_idx(), 1);
    assert_eq!(game_state.current_round_bet(), 500.0);
    assert_eq!(game_state.current_player_stack(), 300.0);

    let cfr_state = make_cfr_state(&game_state);
    let traversal_set = TraversalSet::new(game_state.num_players);
    let action_config = ConfigurableActionConfig::default();
    let budget = budget_for_schedule(&[24, 3, 1]);
    let mut agent = CFRAgentBuilder::<ConfigurableActionGenerator>::new()
        .name("TestCFRAgent")
        .player_idx(1)
        .cfr_state(cfr_state.clone())
        .traversal_set(traversal_set.clone())
        .budget(budget)
        .action_gen_config(action_config)
        .build();

    // Diagnostic output: the generated actions and their action indices.
    let possible_actions = agent.action_generator.gen_possible_actions(&game_state);
    println!("Possible actions: {:?}", possible_actions);
    for action in &possible_actions {
        let idx = agent.action_index_mapper.action_to_idx(action, &game_state);
        println!(" {:?} -> index {}", action, idx);
    }

    let chosen_action = agent.act(0, &game_state).await;
    println!("Chosen action: {:?}", chosen_action);
    assert!(
        matches!(chosen_action, AgentAction::Fold),
        "Agent should fold K-high facing all-in, but chose {:?}. \
         With a fresh regret matcher, 24 iterations of exploration \
         should overwhelmingly prefer fold over call.",
        chosen_action
    );

    let target_node_idx = agent.target_node_idx().unwrap();
    agent
        .cfr_state
        .with_node_data(target_node_idx, |node_data| {
            let matcher = get_regret_matcher_from_node(node_data).unwrap();
            let weights = matcher.best_weight();
            // NOTE(review): index 0 = Fold, index 1 = Call is assumed from the
            // index mapping printed above — confirm.
            let fold_weight = weights[0];
            let call_weight = weights[1];
            println!(
                "Weights after exploration: fold={:.6}, call={:.6}",
                fold_weight, call_weight
            );
            assert!(
                fold_weight > 0.99,
                "Fold weight should be >0.99, got {:.6}. Call weight: {:.6}",
                fold_weight,
                call_weight
            );
        });
}
/// River spot: seat 1 (7s Ks) faces a 500 bet from seat 0 (9h Qs), who is
/// all-in for 700 with the stronger hand.
///
/// The budget comes from an empty schedule (`budget_for_schedule(&[])`), and
/// the agent must still choose `Fold`.
/// NOTE(review): the test name suggests the empty schedule sends the agent down
/// the fast-forward evaluation path at depth zero — confirm in `CFRAgent`.
#[tokio::test(flavor = "current_thread")]
async fn test_river_fold_via_fast_forward_at_depth_zero() {
    use crate::arena::cfr::action_generator::{
        ConfigurableActionConfig, ConfigurableActionGenerator,
    };
    use crate::arena::game_state::Round;
    use crate::core::{Card, Hand, PlayerBitSet, Suit, Value};
    // Board 4s 4d 9d Ah Jd: seat 0 makes two pair (nines and fours); seat 1
    // only has the board's pair of fours.
    let board_cards = vec![
        Card::new(Value::Four, Suit::Spade),
        Card::new(Value::Four, Suit::Diamond),
        Card::new(Value::Nine, Suit::Diamond),
        Card::new(Value::Ace, Suit::Heart),
        Card::new(Value::Jack, Suit::Diamond),
    ];
    let p0_hole = vec![
        Card::new(Value::Nine, Suit::Heart),
        Card::new(Value::Queen, Suit::Spade),
    ];
    let p1_hole = vec![
        Card::new(Value::Seven, Suit::Spade),
        Card::new(Value::King, Suit::Spade),
    ];
    // Each hand is the hole cards plus the shared board.
    let mut p0_hand = Hand::new_with_cards(p0_hole);
    p0_hand.extend(board_cards.iter().copied());
    let mut p1_hand = Hand::new_with_cards(p1_hole);
    p1_hand.extend(board_cards.iter().copied());
    // Seat 0 has put in its whole 700; seat 1 has 400 in and 300 behind.
    let stacks = vec![0.0, 300.0];
    let starting_stacks = vec![700.0, 700.0];
    let player_bet = vec![700.0, 400.0];
    // This round, seat 0 has bet 500 and seat 1 nothing yet.
    let round_player_bet = vec![500.0, 0.0];
    // NOTE(review): seat 0 is removed from the active set, presumably because
    // it is all-in — confirm.
    let mut active = PlayerBitSet::new(2);
    active.disable(0);
    // Seat 1 (third argument) is to act.
    let round_data =
        crate::arena::game_state::RoundData::new_with_bets(10.0, active, 1, round_player_bet);
    let mut game_state = GameStateBuilder::new()
        .round(Round::River)
        .round_data(round_data)
        .stacks(stacks)
        .player_bet(player_bet)
        .big_blind(10.0)
        .small_blind(5.0)
        .hands(vec![p0_hand, p1_hand])
        .board(board_cards)
        .build()
        .unwrap();
    // Set after `build()`: the starting stacks and the pot (700 + 400 committed).
    game_state.starting_stacks = starting_stacks.into();
    game_state.total_pot = 1100.0;
    let cfr_state = make_cfr_state(&game_state);
    let traversal_set = TraversalSet::new(game_state.num_players);
    // Empty schedule: no per-depth iteration entries at all.
    let budget = budget_for_schedule(&[]);
    let mut agent = CFRAgentBuilder::<ConfigurableActionGenerator>::new()
        .name("CFR-ff")
        .player_idx(1)
        .cfr_state(cfr_state.clone())
        .traversal_set(traversal_set)
        .budget(budget)
        .action_gen_config(ConfigurableActionConfig::default())
        .build();
    let chosen = agent.act(0, &game_state).await;
    assert!(
        matches!(chosen, AgentAction::Fold),
        "Agent should fold via fast-forward path, got {:?}",
        chosen
    );
}
/// Exercises `PcfrPlusRegretMatcher` directly, with no game tree.
///
/// Rewards per update:
/// - index 0 (fold) earns -604;
/// - index 1 (call) and every other index earn -1463, so the invalid-action
///   penalty equals the call reward.
///
/// After 24 updates, fold must hold more than 99% of the weight and call less
/// than 1%.
#[test]
fn test_pcfr_fold_vs_call_with_penalty_equal_to_call() {
    let fold_reward = -604.0_f32;
    let call_reward = -1463.0_f32;
    let invalid_penalty = -1463.0_f32;
    // The reward vector is identical on every update, so build it once.
    let rewards: Vec<f32> = (0..NUM_ACTION_INDICES)
        .map(|action_idx| match action_idx {
            0 => fold_reward,
            1 => call_reward,
            _ => invalid_penalty,
        })
        .collect();
    let mut matcher = PcfrPlusRegretMatcher::new(NUM_ACTION_INDICES);
    for _ in 0..24 {
        matcher.update_regret(&rewards);
    }

    let weights = matcher.best_weight();
    let (fold_weight, call_weight) = (weights[0], weights[1]);
    println!(
        "PCFR+ after 24 iterations: fold={:.6}, call={:.6}",
        fold_weight, call_weight
    );
    assert!(
        fold_weight > 0.99,
        "Fold should have >99% weight, got {:.4}%",
        fold_weight * 100.0
    );
    assert!(
        call_weight < 0.01,
        "Call should have <1% weight, got {:.4}%",
        call_weight * 100.0
    );
}
/// Fast-forward reward check on the river.
///
/// Seat 1 (7s Ks, one pair from the board's fours) faces seat 0 (9h Qs, two
/// pair), who is all-in for 700. The test asserts three things:
/// - calling loses more than 699 (roughly the whole 700 starting stack);
/// - folding loses exactly the 400 already committed;
/// - folding beats calling.
#[test]
fn fast_forward_river_call_loses_with_worse_hand() {
    use crate::arena::game_state::{Round, RoundData};
    use crate::core::{Card, Hand, PlayerBitSet, Suit, Value};
    let board_cards = vec![
        Card::new(Value::Four, Suit::Spade),
        Card::new(Value::Four, Suit::Diamond),
        Card::new(Value::Nine, Suit::Diamond),
        Card::new(Value::Ace, Suit::Heart),
        Card::new(Value::Jack, Suit::Diamond),
    ];
    let p0_hole = vec![
        Card::new(Value::Nine, Suit::Heart),
        Card::new(Value::Queen, Suit::Spade),
    ];
    let p1_hole = vec![
        Card::new(Value::Seven, Suit::Spade),
        Card::new(Value::King, Suit::Spade),
    ];
    // Each hand is the hole cards plus the shared board.
    let mut p0_hand = Hand::new_with_cards(p0_hole);
    p0_hand.extend(board_cards.iter().copied());
    let mut p1_hand = Hand::new_with_cards(p1_hole);
    p1_hand.extend(board_cards.iter().copied());
    // Seat 0 has put in its whole 700; seat 1 has 400 in and 300 behind.
    let stacks = vec![0.0, 300.0];
    let starting_stacks = vec![700.0, 700.0];
    let player_bet = vec![700.0, 400.0];
    // This round, seat 0 has bet 500 and seat 1 nothing yet.
    let round_player_bet = vec![500.0, 0.0];
    // NOTE(review): seat 0 is removed from the active set, presumably because
    // it is all-in — confirm.
    let mut active = PlayerBitSet::new(2);
    active.disable(0);
    let round_data = RoundData::new_with_bets(10.0, active, 1, round_player_bet);
    let mut game_state = GameStateBuilder::new()
        .round(Round::River)
        .round_data(round_data)
        .stacks(stacks)
        .player_bet(player_bet)
        .big_blind(10.0)
        .small_blind(5.0)
        .hands(vec![p0_hand, p1_hand])
        .board(board_cards)
        .build()
        .unwrap();
    game_state.starting_stacks = starting_stacks.into();
    // 700 + 400 committed chips.
    game_state.total_pot = 1100.0;
    // All-in mask: a fresh two-seat set with seat 1 disabled, leaving only
    // seat 0 flagged. Assumes `PlayerBitSet::new` starts with every seat set —
    // confirm.
    game_state.player_all_in = {
        let mut pbs = PlayerBitSet::new(2);
        pbs.disable(1);
        pbs
    };
    // One RNG is shared by both branches. The fold branch sees whatever state
    // the call branch leaves behind, so statement order matters here.
    let mut rng = StdRng::seed_from_u64(7);
    // Branch 1: seat 1 calls and the hand runs to showdown.
    let mut call_state = game_state.clone();
    fast_forward_apply_action(&mut call_state, &AgentAction::Call);
    fast_forward_run_to_showdown(&mut call_state, &mut rng);
    fast_forward_distribute_pot(&mut call_state);
    let call_reward = call_state.player_reward(1);
    assert!(
        call_reward < -699.0,
        "Calling with the losing hand should cost ~700 stack, got {}",
        call_reward
    );
    // Branch 2: seat 1 folds from the same starting state.
    let mut fold_state = game_state.clone();
    fast_forward_apply_action(&mut fold_state, &AgentAction::Fold);
    fast_forward_run_to_showdown(&mut fold_state, &mut rng);
    fast_forward_distribute_pot(&mut fold_state);
    let fold_reward = fold_state.player_reward(1);
    assert!(
        (fold_reward - (-400.0)).abs() < 0.01,
        "Folding should cost exactly the 400 already committed, got {}",
        fold_reward
    );
    assert!(fold_reward > call_reward);
}
/// Fast-forward showdown where both seats play the board (As Ad Kh Qc Jd).
///
/// Neither 2-3 holding improves on it, so the pot is split and seat 0's reward
/// must be about 0.
#[test]
fn fast_forward_split_pot_on_tie() {
    use crate::arena::game_state::{Round, RoundData};
    use crate::core::{Card, Hand, PlayerBitSet, Suit, Value};
    let board_cards = vec![
        Card::new(Value::Ace, Suit::Spade),
        Card::new(Value::Ace, Suit::Diamond),
        Card::new(Value::King, Suit::Heart),
        Card::new(Value::Queen, Suit::Club),
        Card::new(Value::Jack, Suit::Diamond),
    ];
    let p0_hole = vec![
        Card::new(Value::Two, Suit::Club),
        Card::new(Value::Three, Suit::Club),
    ];
    let p1_hole = vec![
        Card::new(Value::Two, Suit::Heart),
        Card::new(Value::Three, Suit::Spade),
    ];
    // Each hand is the hole cards plus the shared board.
    let mut p0 = Hand::new_with_cards(p0_hole);
    p0.extend(board_cards.iter().copied());
    let mut p1 = Hand::new_with_cards(p1_hole);
    p1.extend(board_cards.iter().copied());
    // Both seats active, seat 0 to act, no bets yet this round.
    let round_data = RoundData::new_with_bets(10.0, PlayerBitSet::new(2), 0, vec![0.0, 0.0]);
    // Each seat has 500 committed and 500 behind, out of a 1000 starting stack.
    let mut gs = GameStateBuilder::new()
        .round(Round::River)
        .round_data(round_data)
        .stacks(vec![500.0, 500.0])
        .player_bet(vec![500.0, 500.0])
        .big_blind(10.0)
        .small_blind(5.0)
        .hands(vec![p0, p1])
        .board(board_cards)
        .build()
        .unwrap();
    gs.starting_stacks = vec![1000.0, 1000.0].into();
    gs.total_pot = 1000.0;
    let mut rng = StdRng::seed_from_u64(11);
    // Ordered pipeline: apply the call, run to showdown, then pay out the pot.
    fast_forward_apply_action(&mut gs, &AgentAction::Call);
    fast_forward_run_to_showdown(&mut gs, &mut rng);
    fast_forward_distribute_pot(&mut gs);
    let reward = gs.player_reward(0);
    assert!(
        reward.abs() < 0.01,
        "split should yield ~0 reward, got {}",
        reward
    );
}
/// Starting on the flop with three board cards, fast-forwarding a call through
/// showdown must leave a complete five-card board.
#[test]
fn fast_forward_from_flop_deals_turn_and_river() {
    use crate::arena::game_state::{Round, RoundData};
    use crate::core::{Card, Hand, PlayerBitSet, Suit, Value};
    // Only the flop is out.
    let board_cards = vec![
        Card::new(Value::Two, Suit::Spade),
        Card::new(Value::Seven, Suit::Diamond),
        Card::new(Value::Jack, Suit::Heart),
    ];
    let p0_hole = vec![
        Card::new(Value::Ace, Suit::Spade),
        Card::new(Value::Ace, Suit::Diamond),
    ];
    let p1_hole = vec![
        Card::new(Value::King, Suit::Spade),
        Card::new(Value::Queen, Suit::Diamond),
    ];
    // Each hand is the hole cards plus the three flop cards.
    let mut p0 = Hand::new_with_cards(p0_hole);
    p0.extend(board_cards.iter().copied());
    let mut p1 = Hand::new_with_cards(p1_hole);
    p1.extend(board_cards.iter().copied());
    // Both seats active, seat 0 to act, no bets yet this round.
    let round_data = RoundData::new_with_bets(10.0, PlayerBitSet::new(2), 0, vec![0.0, 0.0]);
    let mut gs = GameStateBuilder::new()
        .round(Round::Flop)
        .round_data(round_data)
        .stacks(vec![100.0, 100.0])
        .player_bet(vec![10.0, 10.0])
        .big_blind(10.0)
        .small_blind(5.0)
        .hands(vec![p0, p1])
        .board(board_cards)
        .build()
        .unwrap();
    gs.starting_stacks = vec![110.0, 110.0].into();
    gs.total_pot = 20.0;
    let mut rng = StdRng::seed_from_u64(3);
    // Ordered pipeline: apply the call, run to showdown, then pay out the pot.
    fast_forward_apply_action(&mut gs, &AgentAction::Call);
    fast_forward_run_to_showdown(&mut gs, &mut rng);
    fast_forward_distribute_pot(&mut gs);
    // Turn and river must have been dealt.
    assert_eq!(gs.board.len(), 5);
}
/// River spot: seat 1 (7s Ks) faces a 500 bet from seat 0 (9h Qs), who is
/// all-in for 700 with the stronger hand. With a 24/3/1 schedule the agent
/// must still choose `Fold`.
///
/// NOTE(review): "RBP" is taken to mean regret-based pruning, judging by the
/// names and messages in these tests — confirm.
#[tokio::test(flavor = "current_thread")]
async fn test_rbp_preserves_fold_decision() {
    use crate::arena::cfr::action_generator::{
        ConfigurableActionConfig, ConfigurableActionGenerator,
    };
    use crate::arena::game_state::Round;
    use crate::core::{Card, Hand, PlayerBitSet, Suit, Value};
    let board_cards = vec![
        Card::new(Value::Four, Suit::Spade),
        Card::new(Value::Four, Suit::Diamond),
        Card::new(Value::Nine, Suit::Diamond),
        Card::new(Value::Ace, Suit::Heart),
        Card::new(Value::Jack, Suit::Diamond),
    ];
    let p0_hole = vec![
        Card::new(Value::Nine, Suit::Heart),
        Card::new(Value::Queen, Suit::Spade),
    ];
    let p1_hole = vec![
        Card::new(Value::Seven, Suit::Spade),
        Card::new(Value::King, Suit::Spade),
    ];
    // Each hand is the hole cards plus the shared board.
    let mut p0_hand = Hand::new_with_cards(p0_hole);
    p0_hand.extend(board_cards.iter().copied());
    let mut p1_hand = Hand::new_with_cards(p1_hole);
    p1_hand.extend(board_cards.iter().copied());
    // Seat 0 has put in its whole 700; seat 1 has 400 in and 300 behind.
    let stacks = vec![0.0, 300.0];
    let starting_stacks = vec![700.0, 700.0];
    let player_bet = vec![700.0, 400.0];
    // This round, seat 0 has bet 500 and seat 1 nothing yet.
    let round_player_bet = vec![500.0, 0.0];
    // NOTE(review): seat 0 is removed from the active set, presumably because
    // it is all-in — confirm.
    let mut active = PlayerBitSet::new(2);
    active.disable(0);
    let round_data =
        crate::arena::game_state::RoundData::new_with_bets(10.0, active, 1, round_player_bet);
    let mut game_state = GameStateBuilder::new()
        .round(Round::River)
        .round_data(round_data)
        .stacks(stacks)
        .player_bet(player_bet)
        .big_blind(10.0)
        .small_blind(5.0)
        .hands(vec![p0_hand, p1_hand])
        .board(board_cards)
        .build()
        .unwrap();
    game_state.starting_stacks = starting_stacks.into();
    let cfr_state = make_cfr_state(&game_state);
    let traversal_set = TraversalSet::new(game_state.num_players);
    // 24 iterations at depth 0, 3 at depth 1, 1 at depth 2.
    let budget = budget_for_schedule(&[24, 3, 1]);
    let mut agent = CFRAgentBuilder::<ConfigurableActionGenerator>::new()
        .name("CFR-RBP-test")
        .player_idx(1)
        .cfr_state(cfr_state.clone())
        .traversal_set(traversal_set)
        .budget(budget)
        .action_gen_config(ConfigurableActionConfig::default())
        .build();
    let chosen = agent.act(0, &game_state).await;
    assert!(
        matches!(chosen, AgentAction::Fold),
        "RBP should preserve the fold decision for K-high vs pair, got {:?}",
        chosen
    );
}
/// Same river spot as the fold tests: seat 1 (7s Ks) faces seat 0 (9h Qs),
/// all-in for 700.
///
/// After one `act` with a 24/3/1 schedule, the target node's pruning info must
/// report at least 6 updates and at most 3 still-active actions.
#[tokio::test(flavor = "current_thread")]
async fn test_rbp_reduces_active_actions() {
    use crate::arena::cfr::action_generator::{
        ConfigurableActionConfig, ConfigurableActionGenerator,
    };
    use crate::arena::game_state::Round;
    use crate::core::{Card, Hand, PlayerBitSet, Suit, Value};
    let board_cards = vec![
        Card::new(Value::Four, Suit::Spade),
        Card::new(Value::Four, Suit::Diamond),
        Card::new(Value::Nine, Suit::Diamond),
        Card::new(Value::Ace, Suit::Heart),
        Card::new(Value::Jack, Suit::Diamond),
    ];
    let p0_hole = vec![
        Card::new(Value::Nine, Suit::Heart),
        Card::new(Value::Queen, Suit::Spade),
    ];
    let p1_hole = vec![
        Card::new(Value::Seven, Suit::Spade),
        Card::new(Value::King, Suit::Spade),
    ];
    // Each hand is the hole cards plus the shared board.
    let mut p0_hand = Hand::new_with_cards(p0_hole);
    p0_hand.extend(board_cards.iter().copied());
    let mut p1_hand = Hand::new_with_cards(p1_hole);
    p1_hand.extend(board_cards.iter().copied());
    // Seat 0 has put in its whole 700; seat 1 has 400 in and 300 behind.
    let stacks = vec![0.0, 300.0];
    let starting_stacks = vec![700.0, 700.0];
    let player_bet = vec![700.0, 400.0];
    // This round, seat 0 has bet 500 and seat 1 nothing yet.
    let round_player_bet = vec![500.0, 0.0];
    // NOTE(review): seat 0 is removed from the active set, presumably because
    // it is all-in — confirm.
    let mut active = PlayerBitSet::new(2);
    active.disable(0);
    let round_data =
        crate::arena::game_state::RoundData::new_with_bets(10.0, active, 1, round_player_bet);
    let mut game_state = GameStateBuilder::new()
        .round(Round::River)
        .round_data(round_data)
        .stacks(stacks)
        .player_bet(player_bet)
        .big_blind(10.0)
        .small_blind(5.0)
        .hands(vec![p0_hand, p1_hand])
        .board(board_cards)
        .build()
        .unwrap();
    game_state.starting_stacks = starting_stacks.into();
    let cfr_state = make_cfr_state(&game_state);
    let traversal_set = TraversalSet::new(game_state.num_players);
    let budget = budget_for_schedule(&[24, 3, 1]);
    let mut agent = CFRAgentBuilder::<ConfigurableActionGenerator>::new()
        .name("CFR-RBP-sparse")
        .player_idx(1)
        .cfr_state(cfr_state.clone())
        .traversal_set(traversal_set)
        .budget(budget)
        .action_gen_config(ConfigurableActionConfig::default())
        .build();
    // The chosen action is irrelevant here; only the node's pruning state is
    // inspected afterwards.
    let _ = agent.act(0, &game_state).await;
    let target_node_idx = agent.target_node_idx().unwrap();
    // Returns (set of still-active actions, number of updates) for the node.
    let (active_set, num_updates) = agent.cfr_state.get_pruning_info(target_node_idx);
    println!(
        "After exploration: {} active actions, {} updates",
        active_set.count(),
        num_updates
    );
    assert!(
        num_updates >= 6,
        "Expected >= 6 updates (warmup + a few prunes), got {}",
        num_updates
    );
    assert!(
        active_set.count() <= 3,
        "Expected <= 3 active actions after pruning, got {}. \
         In this lopsided fold-vs-call scenario, most actions should be pruned.",
        active_set.count()
    );
}
/// Variance comparison for an all-in flop spot: As Ks vs Qh Jh, 500 committed
/// by each seat, nothing behind.
///
/// Over 50 seeds, two estimates of seat 0's reward are compared:
/// - multi-sample: `fast_forward_sample_flop_enumerate_runout`;
/// - single-sample: one random runout to showdown.
///
/// The multi-sample estimate must have a lower standard deviation, and the two
/// means must agree within 200 chips.
#[test]
fn test_flop_sample_variance_vs_single_sample() {
    use crate::arena::game_state::Round;
    use crate::core::{Card, Hand, Suit, Value};

    // Population mean and standard deviation of `vals`, accumulated in f64.
    fn mean_and_std_dev(vals: &[f32]) -> (f64, f64) {
        let n = vals.len() as f64;
        let mean = vals.iter().map(|&v| v as f64).sum::<f64>() / n;
        let var = vals.iter().map(|&v| (v as f64 - mean).powi(2)).sum::<f64>() / n;
        (mean, var.sqrt())
    }

    let p0_hole = vec![
        Card::new(Value::Ace, Suit::Spade),
        Card::new(Value::King, Suit::Spade),
    ];
    let p1_hole = vec![
        Card::new(Value::Queen, Suit::Heart),
        Card::new(Value::Jack, Suit::Heart),
    ];
    let p0_hand = Hand::new_with_cards(p0_hole);
    let p1_hand = Hand::new_with_cards(p1_hole);
    let mut game_state = GameStateBuilder::new()
        .round(Round::DealFlop)
        .stacks(vec![0.0, 0.0])
        .player_bet(vec![500.0, 500.0])
        .big_blind(10.0)
        .small_blind(5.0)
        .hands(vec![p0_hand, p1_hand])
        .build()
        .unwrap();
    game_state.starting_stacks = vec![500.0, 500.0].into();

    // `StdRng` and `SeedableRng` come from the tests module's imports.
    let num_trials = 50;
    let multi_results: Vec<f32> = (0..num_trials)
        .map(|seed| {
            let mut rng = StdRng::seed_from_u64(seed);
            fast_forward_sample_flop_enumerate_runout(&game_state, 0, &mut rng)
        })
        .collect();
    let single_results: Vec<f32> = (0..num_trials)
        .map(|seed| {
            let mut gs = game_state.clone();
            let mut rng = StdRng::seed_from_u64(seed);
            fast_forward_run_to_showdown(&mut gs, &mut rng);
            fast_forward_distribute_pot(&mut gs);
            gs.player_reward(0)
        })
        .collect();

    let (multi_mean, multi_std) = mean_and_std_dev(&multi_results);
    let (single_mean, single_std) = mean_and_std_dev(&single_results);
    println!(
        "Multi-sample flop (k={}): mean={:.2}, std={:.2}",
        FLOP_SAMPLES, multi_mean, multi_std
    );
    println!(
        "Single-sample runout: mean={:.2}, std={:.2}",
        single_mean, single_std
    );
    println!(
        "Variance reduction: {:.1}x",
        if multi_std > 0.0 {
            single_std / multi_std
        } else {
            f64::INFINITY
        }
    );
    assert!(
        multi_std < single_std,
        "Multi-sample flop should have lower variance than single-sample. \
         multi_std={:.2}, single_std={:.2}",
        multi_std,
        single_std
    );
    assert!(
        (multi_mean - single_mean).abs() < 200.0,
        "Means should be broadly similar: multi={:.2}, single={:.2}",
        multi_mean,
        single_mean
    );
}
/// All-in flop spot: As Ks vs 7d 2c, 500 committed by each seat, nothing
/// behind. Averaged over 20 seeds, `fast_forward_sample_flop_enumerate_runout`
/// must give seat 0 a reward above `MIN_AVG_REWARD` (50 chips).
#[test]
fn test_flop_sample_dominated_hand() {
    use crate::arena::game_state::Round;
    use crate::core::{Card, Hand, Suit, Value};

    // Threshold on the average reward. The failure message reports this value
    // rather than claiming merely "positive EV".
    const MIN_AVG_REWARD: f64 = 50.0;

    let p0_hole = vec![
        Card::new(Value::Ace, Suit::Spade),
        Card::new(Value::King, Suit::Spade),
    ];
    let p1_hole = vec![
        Card::new(Value::Seven, Suit::Diamond),
        Card::new(Value::Two, Suit::Club),
    ];
    let p0_hand = Hand::new_with_cards(p0_hole);
    let p1_hand = Hand::new_with_cards(p1_hole);
    let mut game_state = GameStateBuilder::new()
        .round(Round::DealFlop)
        .stacks(vec![0.0, 0.0])
        .player_bet(vec![500.0, 500.0])
        .big_blind(10.0)
        .small_blind(5.0)
        .hands(vec![p0_hand, p1_hand])
        .build()
        .unwrap();
    game_state.starting_stacks = vec![500.0, 500.0].into();

    // `StdRng` and `SeedableRng` come from the tests module's imports.
    let num_trials = 20;
    let total_reward: f64 = (0..num_trials)
        .map(|seed| {
            // Seeds 100..120; the other flop tests in this module start at 0.
            let mut rng = StdRng::seed_from_u64(seed + 100);
            fast_forward_sample_flop_enumerate_runout(&game_state, 0, &mut rng) as f64
        })
        .sum();
    let avg_reward = total_reward / num_trials as f64;
    println!("AKs vs 72o avg reward for AKs: {:.2}", avg_reward);
    assert!(
        avg_reward > MIN_AVG_REWARD,
        "AKs should have EV above {:.2} vs 72o, got {:.2}",
        MIN_AVG_REWARD,
        avg_reward
    );
}
#[test]
fn test_flop_sample_count_comparison() {
    use crate::arena::game_state::Round;
    use crate::core::{Card, Hand, Suit, Value};
    use rand::SeedableRng;
    use std::time::Instant;
    let p0_hole = vec![
        Card::new(Value::Ace, Suit::Spade),
        Card::new(Value::King, Suit::Spade),
    ];
    let p1_hole = vec![
        Card::new(Value::Queen, Suit::Heart),
        Card::new(Value::Jack, Suit::Heart),
    ];
    let p0_hand = Hand::new_with_cards(p0_hole);
    let p1_hand = Hand::new_with_cards(p1_hole);
    // All-in before the flop: 500 committed each, nothing behind.
    let mut game_state = GameStateBuilder::new()
        .round(Round::DealFlop)
        .stacks(vec![0.0, 0.0])
        .player_bet(vec![500.0, 500.0])
        .big_blind(10.0)
        .small_blind(5.0)
        .hands(vec![p0_hand, p1_hand])
        .build()
        .unwrap();
    game_state.starting_stacks = vec![500.0, 500.0].into();
    let num_trials = 100;
    /// Population standard deviation (divides by `n`, not `n - 1`).
    fn std_dev(vals: &[f32]) -> f64 {
        let n = vals.len() as f64;
        let mean = vals.iter().map(|&v| v as f64).sum::<f64>() / n;
        let var = vals.iter().map(|&v| (v as f64 - mean).powi(2)).sum::<f64>() / n;
        var.sqrt()
    }
    println!(
        "\n{:<8} {:>10} {:>10} {:>12} {:>10}",
        "k", "mean", "std", "var_red", "time_us"
    );
    println!("{}", "-".repeat(56));
    // Baseline: one full random runout to showdown per trial.
    let start = Instant::now();
    let single_results: Vec<f32> = (0..num_trials)
        .map(|seed| {
            let mut gs = game_state.clone();
            let mut rng = StdRng::seed_from_u64(seed);
            fast_forward_run_to_showdown(&mut gs, &mut rng);
            fast_forward_distribute_pot(&mut gs);
            gs.player_reward(0)
        })
        .collect();
    let single_time = start.elapsed();
    let single_std = std_dev(&single_results);
    let single_mean: f64 =
        single_results.iter().map(|&v| v as f64).sum::<f64>() / num_trials as f64;
    println!(
        "{:<8} {:>10.2} {:>10.2} {:>12} {:>10}",
        "1-samp",
        single_mean,
        single_std,
        "baseline",
        single_time.as_micros()
    );
    // Std of the k=3 sweep entry. It is kept for the final assertion so the
    // same seeded trials are not run a second time.
    let mut k3_std: Option<f64> = None;
    for k in [1, 2, 3, 5, 8, 13] {
        let start = Instant::now();
        let results: Vec<f32> = (0..num_trials)
            .map(|seed| {
                let mut rng = StdRng::seed_from_u64(seed);
                fast_forward_sample_flop_enumerate_runout_n(&game_state, 0, &mut rng, k)
            })
            .collect();
        let elapsed = start.elapsed();
        let multi_std = std_dev(&results);
        let multi_mean: f64 =
            results.iter().map(|&v| v as f64).sum::<f64>() / num_trials as f64;
        let var_reduction = if multi_std > 0.0 {
            single_std / multi_std
        } else {
            f64::INFINITY
        };
        println!(
            "{:<8} {:>10.2} {:>10.2} {:>11.1}x {:>10}",
            format!("k={k}"),
            multi_mean,
            multi_std,
            var_reduction,
            elapsed.as_micros()
        );
        if k == 3 {
            k3_std = Some(multi_std);
        }
    }
    let k3_std = k3_std.expect("k=3 is part of the sweep above");
    assert!(
        k3_std < single_std,
        "k=3 should reduce variance vs single sample"
    );
}
/// Snapshot of the fields of one `cfr_diag` tracing event, as extracted by
/// `CapturingDiagLayer`. Fields the event did not carry keep their `Default`
/// value. `Debug` is derived so captured events can be printed in assertion
/// failure messages.
#[derive(Clone, Debug, Default)]
struct CapturedEvent {
    depth: u64,
    stop_cause: String,
    final_iterations: u64,
    final_elapsed_us: u64,
    timer_armed: bool,
    actions_considered: u64,
    regret_series: String,
}
/// Test-only `tracing_subscriber` layer that records every `cfr_diag` event
/// it sees into a shared, lockable buffer.
#[derive(Default)]
struct CapturingDiagLayer {
    /// Captured events in arrival order; shared with the test via `events()`.
    events: std::sync::Arc<std::sync::Mutex<Vec<CapturedEvent>>>,
}
impl CapturingDiagLayer {
    /// Creates a layer whose capture buffer starts out empty.
    fn new() -> Self {
        Self {
            events: Default::default(),
        }
    }

    /// Returns another handle to the capture buffer. The layer keeps its own
    /// handle, so the test can still read events after the layer has been
    /// moved into a subscriber.
    fn events(&self) -> std::sync::Arc<std::sync::Mutex<Vec<CapturedEvent>>> {
        std::sync::Arc::clone(&self.events)
    }
}
impl<S: tracing::Subscriber> tracing_subscriber::Layer<S> for CapturingDiagLayer {
    /// Records each event whose target is exactly `cfr_diag` as a
    /// `CapturedEvent` and appends it to the shared buffer. Events with any
    /// other target are ignored.
    fn on_event(
        &self,
        event: &tracing::Event<'_>,
        _ctx: tracing_subscriber::layer::Context<'_, S>,
    ) {
        if event.metadata().target() != "cfr_diag" {
            return;
        }
        let mut captured = CapturedEvent::default();

        /// Field visitor that copies known `cfr_diag` fields into a `CapturedEvent`.
        struct V<'a>(&'a mut CapturedEvent);

        impl V<'_> {
            /// Returns the numeric slot for a tracked counter field, if any.
            fn count_slot(&mut self, name: &str) -> Option<&mut u64> {
                match name {
                    "depth" => Some(&mut self.0.depth),
                    "final_iterations" => Some(&mut self.0.final_iterations),
                    "final_elapsed_us" => Some(&mut self.0.final_elapsed_us),
                    "actions_considered" => Some(&mut self.0.actions_considered),
                    _ => None,
                }
            }

            /// Returns the string slot for a tracked text field, if any.
            fn text_slot(&mut self, name: &str) -> Option<&mut String> {
                match name {
                    "stop_cause" => Some(&mut self.0.stop_cause),
                    "regret_series" => Some(&mut self.0.regret_series),
                    _ => None,
                }
            }
        }

        impl tracing::field::Visit for V<'_> {
            fn record_u64(&mut self, f: &tracing::field::Field, value: u64) {
                if let Some(slot) = self.count_slot(f.name()) {
                    *slot = value;
                }
            }

            fn record_i64(&mut self, f: &tracing::field::Field, value: i64) {
                // The default `record_i64` forwards to `record_debug`, which would
                // silently drop signed counters. Store non-negative values;
                // negative ones cannot be a count and are ignored.
                if let Ok(value) = u64::try_from(value) {
                    if let Some(slot) = self.count_slot(f.name()) {
                        *slot = value;
                    }
                }
            }

            fn record_bool(&mut self, f: &tracing::field::Field, value: bool) {
                if f.name() == "timer_armed" {
                    self.0.timer_armed = value;
                }
            }

            fn record_str(&mut self, f: &tracing::field::Field, value: &str) {
                // The default `record_str` forwards to `record_debug`, which would
                // wrap the text in quotes. Store the raw string instead.
                if let Some(slot) = self.text_slot(f.name()) {
                    *slot = value.to_owned();
                }
            }

            fn record_debug(&mut self, f: &tracing::field::Field, value: &dyn std::fmt::Debug) {
                if let Some(slot) = self.text_slot(f.name()) {
                    *slot = format!("{value:?}");
                }
            }
        }

        event.record(&mut V(&mut captured));
        self.events.lock().unwrap().push(captured);
    }
}
#[tokio::test(flavor = "current_thread")]
async fn diag_event_records_iteration_bound_stop() {
    use tracing_subscriber::layer::SubscriberExt;
    // Send only `cfr_diag` events (down to TRACE) into the capturing layer.
    let capture = CapturingDiagLayer::new();
    let sink = capture.events();
    let diag_filter = tracing_subscriber::filter::Targets::new()
        .with_target("cfr_diag", tracing::Level::TRACE);
    let subscriber = tracing_subscriber::registry().with(diag_filter).with(capture);
    // The guard must live for the whole test so `act` emits into this subscriber.
    let _dispatch_guard = tracing::subscriber::set_default(subscriber);

    let (game_state, cfr_state, traversal_set) = setup_tiny_heads_up();
    let schedule_budget = budget_for_schedule(&[5]);
    let mut agent = CFRAgentBuilder::<ConfigurableActionGenerator>::new()
        .name("CFRAgent-iter-bound")
        .player_idx(game_state.to_act_idx())
        .cfr_state(cfr_state)
        .traversal_set(traversal_set)
        .action_gen_config(ConfigurableActionConfig::default())
        .budget(schedule_budget)
        .build();
    let _ = agent.act(0, &game_state).await;

    let captured = sink.lock().unwrap();
    let root = captured
        .iter()
        .find(|ev| ev.depth == 0)
        .expect("expected at least one depth=0 event");
    assert_eq!(root.stop_cause, "budget_stop");
    assert_eq!(root.final_iterations, 5);
    // Assumes `regret_series` is a Debug-rendered list: "[]" when empty,
    // otherwise one more entry than there are commas.
    let series_len = match root.regret_series.as_str() {
        "[]" => 0,
        rendered => rendered.matches(',').count() + 1,
    };
    assert_eq!(
        series_len, 5,
        "expected 5 entries in regret_series, got '{}'",
        root.regret_series
    );
}
#[tokio::test(flavor = "current_thread")]
async fn diag_event_records_deadline_stop() {
    use crate::arena::cfr::{Deadline, MostRestrictive};
    use tracing_subscriber::layer::SubscriberExt;
    // Capture `cfr_diag` events for the lifetime of the guard below.
    let capture = CapturingDiagLayer::new();
    let sink = capture.events();
    let diag_filter = tracing_subscriber::filter::Targets::new()
        .with_target("cfr_diag", tracing::Level::TRACE);
    let subscriber = tracing_subscriber::registry().with(diag_filter).with(capture);
    let _dispatch_guard = tracing::subscriber::set_default(subscriber);

    let (game_state, cfr_state, traversal_set) = setup_tiny_heads_up();
    // A 1ms deadline combined with width-1 waves; the test expects the
    // deadline to be what stops the root.
    let budget: Arc<dyn Budget> = Arc::new(MostRestrictive::new(vec![
        Arc::new(Deadline::new(std::time::Duration::from_millis(1))),
        Arc::new(MaxWidth::new(vec![1, 1, 1])),
    ]));
    let mut agent = CFRAgentBuilder::<ConfigurableActionGenerator>::new()
        .name("CFRAgent-deadline")
        .player_idx(game_state.to_act_idx())
        .cfr_state(cfr_state)
        .traversal_set(traversal_set)
        .action_gen_config(ConfigurableActionConfig::default())
        .budget(budget)
        .build();
    let _ = agent.act(0, &game_state).await;

    let captured = sink.lock().unwrap();
    let root = captured
        .iter()
        .find(|ev| ev.depth == 0)
        .expect("expected at least one depth=0 event");
    assert_eq!(root.stop_cause, "deadline");
    let iterations = root.final_iterations;
    assert!(
        iterations < 1_000_000,
        "deadline should have stopped well before any theoretical cap, got {}",
        iterations
    );
}
#[tokio::test(flavor = "current_thread")]
async fn diag_event_emitted_at_every_depth() {
    use std::collections::BTreeSet;
    use tracing_subscriber::layer::SubscriberExt;
    let capture = CapturingDiagLayer::new();
    let sink = capture.events();
    let diag_filter = tracing_subscriber::filter::Targets::new()
        .with_target("cfr_diag", tracing::Level::TRACE);
    let subscriber = tracing_subscriber::registry().with(diag_filter).with(capture);
    let _dispatch_guard = tracing::subscriber::set_default(subscriber);

    let (game_state, cfr_state, traversal_set) = setup_tiny_heads_up();
    // Two waves at the root, one at the next depth, so sub-agents get to run.
    let schedule_budget = budget_for_schedule(&[2, 1]);
    let mut agent = CFRAgentBuilder::<ConfigurableActionGenerator>::new()
        .name("CFRAgent-perdepth")
        .player_idx(game_state.to_act_idx())
        .cfr_state(cfr_state)
        .traversal_set(traversal_set)
        .action_gen_config(ConfigurableActionConfig::default())
        .budget(schedule_budget)
        .build();
    let _ = agent.act(0, &game_state).await;

    let captured = sink.lock().unwrap();
    let mut depths_seen = BTreeSet::<u64>::new();
    depths_seen.extend(captured.iter().map(|ev| ev.depth));
    assert!(depths_seen.contains(&0), "expected a depth=0 event");
    assert!(
        depths_seen.contains(&1),
        "expected at least one depth=1 event from recursive sub-agents; saw depths {:?}",
        depths_seen
    );
}
#[tokio::test(flavor = "current_thread")]
async fn diag_event_records_stable_strategy_stop() {
    use tracing_subscriber::layer::SubscriberExt;
    let capture = CapturingDiagLayer::new();
    let sink = capture.events();
    let diag_filter = tracing_subscriber::filter::Targets::new()
        .with_target("cfr_diag", tracing::Level::TRACE);
    let subscriber = tracing_subscriber::registry().with(diag_filter).with(capture);
    let _dispatch_guard = tracing::subscriber::set_default(subscriber);

    let (game_state, cfr_state, traversal_set) = setup_tiny_heads_up();
    // Give the root a large wave allowance so that a stable-strategy stop can
    // trigger before the iteration cap does.
    let schedule_budget = budget_for_schedule(&[4096, 1]);
    let mut agent = CFRAgentBuilder::<ConfigurableActionGenerator>::new()
        .name("CFRAgent-stable")
        .player_idx(game_state.to_act_idx())
        .cfr_state(cfr_state)
        .traversal_set(traversal_set)
        .action_gen_config(ConfigurableActionConfig::default())
        .budget(schedule_budget)
        .build();
    let _ = agent.act(0, &game_state).await;

    let captured = sink.lock().unwrap();
    assert!(
        captured.iter().any(|ev| ev.stop_cause == "stable_strategy"),
        "expected at least one stable_strategy event across all depths; saw causes: {:?}",
        captured.iter().map(|ev| &ev.stop_cause).collect::<Vec<_>>()
    );
}
/// Builds a small heads-up spot shared by several agent tests in this module.
///
/// The game has two players with 100 chips each and `.blinds(10.0, 5.0)`. It
/// is advanced three rounds, then bets of 5 and 10 are placed (presumably the
/// blind postings, since the amounts match). Seat 0 is dealt A♠K♠ and seat 1
/// Q♥J♥.
///
/// Returns the game state, a `CFRState` built from it, and a `TraversalSet`
/// sized to the number of players.
fn setup_tiny_heads_up() -> (GameState, CFRState, TraversalSet) {
    use crate::core::{Card, Hand, Suit, Value};
    let mut game_state = GameStateBuilder::new()
        .num_players_with_stack(2, 100.0)
        .blinds(10.0, 5.0)
        .build()
        .unwrap();
    game_state.advance_round();
    game_state.advance_round();
    game_state.advance_round();
    game_state.do_bet(5.0, true).unwrap();
    game_state.do_bet(10.0, true).unwrap();
    let mut hand0 = Hand::default();
    hand0.insert(Card::new(Value::Ace, Suit::Spade));
    hand0.insert(Card::new(Value::King, Suit::Spade));
    let mut hand1 = Hand::default();
    hand1.insert(Card::new(Value::Queen, Suit::Heart));
    hand1.insert(Card::new(Value::Jack, Suit::Heart));
    game_state.hands[0] = hand0;
    game_state.hands[1] = hand1;
    let cfr_state = make_cfr_state(&game_state);
    let traversal_set = TraversalSet::new(game_state.num_players);
    (game_state, cfr_state, traversal_set)
}
#[tokio::test(flavor = "current_thread")]
async fn act_respects_iteration_budget() {
    let (game_state, cfr_state, traversal_set) = setup_tiny_heads_up();
    // A single-iteration budget: `act` must still produce an action.
    let one_iteration: Arc<dyn Budget> = Arc::new(IterationCount::new(1));
    let mut agent = CFRAgentBuilder::<ConfigurableActionGenerator>::new()
        .name("CFRAgent-budget")
        .player_idx(game_state.to_act_idx())
        .cfr_state(cfr_state)
        .traversal_set(traversal_set)
        .action_gen_config(ConfigurableActionConfig::default())
        .budget(one_iteration)
        .build();
    let chosen = agent.act(0, &game_state).await;
    // Exhaustive on purpose: adding an `AgentAction` variant should force this
    // test to be revisited.
    match chosen {
        AgentAction::Fold | AgentAction::Call | AgentAction::Bet(_) | AgentAction::AllIn => {}
    }
}
#[tokio::test(flavor = "current_thread")]
async fn engine_populates_avg_regret_after_updates() {
    use crate::arena::cfr::{ExplorationStats, NextStep};
    use std::sync::{Arc, Mutex};

    /// Budget that logs the `avg_regret` it is shown on every consultation.
    /// The root (depth 0) may run width-1 waves until 8 iterations; every other
    /// depth stops after 1.
    #[derive(Clone)]
    struct RecordingBudget {
        seen: Arc<Mutex<Vec<Option<f32>>>>,
    }

    impl Budget for RecordingBudget {
        fn next_step(&self, stats: &ExplorationStats) -> NextStep {
            self.seen.lock().unwrap().push(stats.avg_regret);
            let wave_cap = if stats.depth == 0 { 8 } else { 1 };
            if stats.iterations >= wave_cap {
                NextStep::Stop
            } else {
                NextStep::Wave { width: 1 }
            }
        }
    }

    let (game_state, cfr_state, traversal_set) = setup_tiny_heads_up();
    let regret_log = Arc::new(Mutex::new(Vec::new()));
    let budget: Arc<dyn Budget> = Arc::new(RecordingBudget {
        seen: Arc::clone(&regret_log),
    });
    let mut agent = CFRAgentBuilder::<ConfigurableActionGenerator>::new()
        .name("CFRAgent-record")
        .player_idx(game_state.to_act_idx())
        .cfr_state(cfr_state)
        .traversal_set(traversal_set)
        .action_gen_config(ConfigurableActionConfig::default())
        .budget(budget)
        .build();
    let _ = agent.act(0, &game_state).await;

    let seen = regret_log.lock().unwrap();
    assert!(
        !seen.is_empty(),
        "the budget must be consulted at least once"
    );
    assert_eq!(
        seen[0], None,
        "the first budget check happens before any completed update"
    );
    assert!(
        seen.iter().any(Option::is_some),
        "avg_regret must be populated once the root node has been updated"
    );
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn multi_thread_exploration_is_sound() {
use crate::core::{Card, Hand, Suit, Value};
for run in 0u64..10 {
let mut game_state = GameStateBuilder::new()
.num_players_with_stack(2, 100.0)
.blinds(10.0, 5.0)
.build()
.unwrap();
game_state.advance_round();
game_state.advance_round();
game_state.advance_round();
game_state.do_bet(5.0, true).unwrap();
game_state.do_bet(10.0, true).unwrap();
let mut hand0 = Hand::default();
hand0.insert(Card::new(Value::Ace, Suit::Spade));
hand0.insert(Card::new(Value::King, Suit::Spade));
let mut hand1 = Hand::default();
hand1.insert(Card::new(Value::Queen, Suit::Heart));
hand1.insert(Card::new(Value::Jack, Suit::Heart));
game_state.hands[0] = hand0;
game_state.hands[1] = hand1;
let cfr_state = make_cfr_state(&game_state);
let traversal_set = TraversalSet::new(game_state.num_players);
let budget = budget_for_schedule(&[4, 1]);
let mut agent = CFRAgentBuilder::<ConfigurableActionGenerator>::new()
.name(format!("CFRAgent-mt-{run}"))
.player_idx(game_state.to_act_idx())
.cfr_state(cfr_state)
.traversal_set(traversal_set)
.budget(budget)
.action_gen_config(ConfigurableActionConfig::default())
.build();
let action = agent.act(0, &game_state).await;
match action {
AgentAction::Fold
| AgentAction::Call
| AgentAction::Bet(_)
| AgentAction::AllIn => {}
}
}
}
#[tokio::test(flavor = "current_thread")]
async fn wave_loop_runs_exactly_budget_waves_at_m_one() {
    let (game_state, cfr_state, traversal_set) = setup_tiny_heads_up();
    // Five iterations, each wave capped at width 1.
    let five_narrow_waves: Arc<dyn Budget> = Arc::new(MostRestrictive::new(vec![
        Arc::new(IterationCount::new(5)),
        Arc::new(MaxWidth::new(vec![1])),
    ]));
    let mut agent = CFRAgentBuilder::<ConfigurableActionGenerator>::new()
        .name("CFRAgent-m1")
        .player_idx(game_state.to_act_idx())
        .cfr_state(cfr_state.clone())
        .traversal_set(traversal_set)
        .budget(five_narrow_waves)
        .action_gen_config(ConfigurableActionConfig::default())
        .build();
    agent.ensure_target_node();
    agent.ensure_regret_matcher();
    let root_idx = agent.target_node_idx().unwrap();
    agent.explore_all_actions(&game_state).await;
    // The test's own clone of `cfr_state` sees the agent's updates.
    let num_updates = cfr_state.get_pruning_info(root_idx).1;
    assert_eq!(
        num_updates, 5,
        "five waves at M=1 must produce exactly five matcher updates"
    );
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn wave_loop_completes_with_size_one_limiter() {
    use tokio::sync::Semaphore;
    let (game_state, cfr_state, traversal_set) = setup_tiny_heads_up();
    let three_narrow_waves: Arc<dyn Budget> = Arc::new(MostRestrictive::new(vec![
        Arc::new(IterationCount::new(3)),
        Arc::new(MaxWidth::new(vec![1])),
    ]));
    // A single-permit limiter, which the assertion message calls the inline path.
    let single_permit: Arc<Semaphore> = Arc::new(Semaphore::new(1));
    let mut agent = CFRAgentBuilder::<ConfigurableActionGenerator>::new()
        .name("CFRAgent-inline")
        .player_idx(game_state.to_act_idx())
        .cfr_state(cfr_state.clone())
        .traversal_set(traversal_set)
        .budget(three_narrow_waves)
        .limiter(single_permit)
        .action_gen_config(ConfigurableActionConfig::default())
        .build();
    agent.ensure_target_node();
    agent.ensure_regret_matcher();
    let root_idx = agent.target_node_idx().unwrap();
    agent.explore_all_actions(&game_state).await;
    let num_updates = cfr_state.get_pruning_info(root_idx).1;
    assert_eq!(
        num_updates, 3,
        "the wave loop must complete every budgeted wave even on the inline path"
    );
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn act_returns_within_act_deadline() {
    use crate::arena::cfr::Deadline;
    use std::time::{Duration, Instant};
    let (game_state, cfr_state, traversal_set) = setup_tiny_heads_up();
    // The only terminating limit here is the 150ms deadline; MaxWidth just
    // caps wave width.
    let deadline_budget: Arc<dyn Budget> = Arc::new(MostRestrictive::new(vec![
        Arc::new(Deadline::new(Duration::from_millis(150))),
        Arc::new(MaxWidth::new(vec![1, 1, 1])),
    ]));
    let mut agent = CFRAgentBuilder::<ConfigurableActionGenerator>::new()
        .name("CFRAgent-deadline")
        .player_idx(game_state.to_act_idx())
        .cfr_state(cfr_state)
        .traversal_set(traversal_set)
        .budget(deadline_budget)
        .action_gen_config(ConfigurableActionConfig::default())
        .build();
    let started = Instant::now();
    let action = agent.act(0, &game_state).await;
    let elapsed = started.elapsed();
    assert!(
        elapsed < Duration::from_secs(2),
        "act must stop ~at the 150ms deadline (unbounded budget would loop forever), \
         but took {elapsed:?}"
    );
    match action {
        AgentAction::Fold | AgentAction::Call | AgentAction::Bet(_) | AgentAction::AllIn => {}
    }
}
/// Explores a fixed turn spot once and returns the root node's `best_weight`
/// vector as raw `f32` bit patterns, so two runs can be compared bit-for-bit.
///
/// Board: A♥ K♦ Q♣ J♠. Seat 0 holds A♠A♦ and seat 1 holds 9♥2♣, with 50
/// committed each, 100 behind, and a 100 pot. The agent explores for seat 1
/// with a one-iteration budget.
async fn run_turn_enumeration_once() -> Vec<u32> {
    use crate::arena::game_state::{Round, RoundData};
    use crate::core::{Card, Hand, PlayerBitSet, Suit, Value};
    let board_cards = vec![
        Card::new(Value::Ace, Suit::Heart),
        Card::new(Value::King, Suit::Diamond),
        Card::new(Value::Queen, Suit::Club),
        Card::new(Value::Jack, Suit::Spade),
    ];
    let p0_hole = vec![
        Card::new(Value::Ace, Suit::Spade),
        Card::new(Value::Ace, Suit::Diamond),
    ];
    let p1_hole = vec![
        Card::new(Value::Nine, Suit::Heart),
        Card::new(Value::Two, Suit::Club),
    ];
    // Each hand holds hole cards plus the shared board.
    let mut p0 = Hand::new_with_cards(p0_hole);
    p0.extend(board_cards.iter().copied());
    let mut p1 = Hand::new_with_cards(p1_hole);
    p1.extend(board_cards.iter().copied());
    // NOTE(review): the `1` argument is presumably the seat to act, matching
    // `.player_idx(1)` below; confirm against `RoundData::new_with_bets`.
    let round_data = RoundData::new_with_bets(10.0, PlayerBitSet::new(2), 1, vec![0.0, 0.0]);
    let mut game_state = GameStateBuilder::new()
        .round(Round::Turn)
        .round_data(round_data)
        .stacks(vec![100.0, 100.0])
        .player_bet(vec![50.0, 50.0])
        .big_blind(10.0)
        .small_blind(5.0)
        .hands(vec![p0, p1])
        .board(board_cards)
        .build()
        .unwrap();
    game_state.starting_stacks = vec![150.0, 150.0].into();
    game_state.total_pot = 100.0;
    let cfr_state = make_cfr_state(&game_state);
    let traversal_set = TraversalSet::new(game_state.num_players);
    let budget: Arc<dyn Budget> = Arc::new(MostRestrictive::new(vec![
        Arc::new(IterationCount::new(1)),
        Arc::new(MaxWidth::new(vec![])),
    ]));
    let mut agent = CFRAgentBuilder::<BasicCFRActionGenerator>::new()
        .name("CFR-enum-stability")
        .player_idx(1)
        .cfr_state(cfr_state.clone())
        .traversal_set(traversal_set)
        .budget(budget)
        .action_gen_config(())
        .build();
    agent.ensure_target_node();
    agent.ensure_regret_matcher();
    let target_node_idx = agent.target_node_idx().unwrap();
    agent.explore_all_actions(&game_state).await;
    cfr_state.with_node_data(target_node_idx, |node_data| {
        let matcher = get_regret_matcher_from_node(node_data).unwrap();
        matcher.best_weight().iter().map(|w| w.to_bits()).collect()
    })
}
#[tokio::test(flavor = "current_thread")]
async fn enumeration_path_is_value_stable_without_seed() {
    // Two independent runs of the same turn spot, each with no RNG seed set.
    let first_run = run_turn_enumeration_once().await;
    let second_run = run_turn_enumeration_once().await;
    assert_eq!(
        first_run, second_run,
        "the ≤2-card enumeration path is RNG-free, so best_weight vectors \
         must be bit-identical across two independent runs with no seed"
    );
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn no_act_deadline_completes_budgeted_waves() {
    let (game_state, cfr_state, traversal_set) = setup_tiny_heads_up();
    // No deadline anywhere in the budget: only the iteration cap ends the loop.
    let capped_only: Arc<dyn Budget> = Arc::new(MostRestrictive::new(vec![
        Arc::new(IterationCount::new(5)),
        Arc::new(MaxWidth::new(vec![1])),
    ]));
    let mut agent = CFRAgentBuilder::<ConfigurableActionGenerator>::new()
        .name("CFRAgent-no-deadline")
        .player_idx(game_state.to_act_idx())
        .cfr_state(cfr_state.clone())
        .traversal_set(traversal_set)
        .budget(capped_only)
        .action_gen_config(ConfigurableActionConfig::default())
        .build();
    agent.ensure_target_node();
    agent.ensure_regret_matcher();
    let root_idx = agent.target_node_idx().unwrap();
    agent.explore_all_actions(&game_state).await;
    let num_updates = cfr_state.get_pruning_info(root_idx).1;
    assert_eq!(
        num_updates, 5,
        "without a deadline the wave loop runs to its iteration cap: \
         all five budgeted waves complete and apply their updates"
    );
}
/// Smoke test: a full exploration with `UniformRandomEstimator` plugged in
/// must complete. Nothing is asserted beyond finishing without panicking.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn explore_runs_with_uniform_estimator() {
    use crate::arena::hand_estimator::UniformRandomEstimator;
    let (game_state, cfr_state, traversal_set) = setup_tiny_heads_up();
    let three_level_budget = budget_for_schedule(&[4, 2, 1]);
    let mut agent = CFRAgentBuilder::<BasicCFRActionGenerator>::new()
        .name("uniform-test")
        .player_idx(game_state.to_act_idx())
        .cfr_state(cfr_state)
        .traversal_set(traversal_set)
        .budget(three_level_budget)
        .action_gen_config(())
        .estimator(Arc::new(UniformRandomEstimator))
        .build();
    agent.ensure_target_node();
    agent.ensure_regret_matcher();
    agent.explore_all_actions(&game_state).await;
}
#[tokio::test]
async fn historian_populates_log_in_real_cfr_sim() {
    use std::sync::Arc;
    // Fresh two-player game; no rounds are advanced before the simulation.
    let game_state = crate::arena::GameStateBuilder::default()
        .num_players_with_stack(2, 100.0)
        .blinds(10.0, 5.0)
        .build()
        .unwrap();
    let cfr_state = make_cfr_state(&game_state);
    let traversal_set = TraversalSet::new(game_state.num_players);
    // Seat 0 uses a history-needing estimator, so it is expected to own a log.
    let recorder = CFRAgentBuilder::<ConfigurableActionGenerator>::new()
        .name("hist")
        .player_idx(0)
        .cfr_state(cfr_state.clone())
        .traversal_set(traversal_set.clone())
        .action_gen_config(ConfigurableActionConfig::default())
        .budget(budget_for_schedule(&[1]))
        .estimator(Arc::new(HistoryNeedingStub::default()))
        .build();
    // Keep our own handle to the log before the agent is moved into the sim.
    let hand_log = recorder
        .hand_log
        .clone()
        .expect("needs_history agent has a hand_log");
    let opponent = CFRAgentBuilder::<ConfigurableActionGenerator>::new()
        .name("p1")
        .player_idx(1)
        .cfr_state(cfr_state.clone())
        .traversal_set(traversal_set.clone())
        .action_gen_config(ConfigurableActionConfig::default())
        .budget(budget_for_schedule(&[1]))
        .build();
    let seats: Vec<Box<dyn Agent>> = vec![Box::new(recorder), Box::new(opponent)];
    let mut sim = HoldemSimulationBuilder::default()
        .game_state(game_state)
        .agents(seats)
        .cfr_context(cfr_state, traversal_set, true)
        .build()
        .unwrap();
    sim.run().await;
    let recorded = hand_log.to_actions();
    assert!(
        !recorded.is_empty(),
        "agent historian must be collected and populate the log even when cfr_context is set"
    );
}
#[tokio::test]
async fn historian_present_only_when_estimator_needs_history() {
    use crate::arena::Agent;
    use std::sync::Arc;
    let game_state = crate::arena::GameStateBuilder::default()
        .num_players_with_stack(2, 100.0)
        .big_blind(2.0)
        .build()
        .unwrap();
    let cfr_state = make_cfr_state(&game_state);
    let traversal_set = TraversalSet::new(game_state.num_players);

    // Default estimator: no historian expected.
    let plain = CFRAgentBuilder::<BasicCFRActionGenerator>::new()
        .name("no-hist")
        .player_idx(0)
        .cfr_state(cfr_state.clone())
        .traversal_set(traversal_set.clone())
        .action_gen_config(())
        .build();
    assert!(plain.historian().is_none());

    // History-needing estimator: a historian must be exposed.
    let with_history = CFRAgentBuilder::<BasicCFRActionGenerator>::new()
        .name("hist")
        .player_idx(0)
        .cfr_state(cfr_state)
        .traversal_set(traversal_set)
        .action_gen_config(())
        .estimator(Arc::new(HistoryNeedingStub::default()))
        .build();
    assert!(with_history.historian().is_some());
}
#[tokio::test]
async fn estimate_receives_current_hand_log() {
    use crate::arena::action::{Action, GameStartPayload};
    use crate::arena::game_state::Round;
    use std::sync::Arc;
    let (game_state, cfr_state, traversal_set) = setup_tiny_heads_up();
    let stub = HistoryNeedingStub::default();
    // Shared handle to what the stub records per estimate call; presumably the
    // length of each GameLog it receives (see the assertion messages).
    let observed = stub.observed_counts.clone();
    let mut agent = CFRAgentBuilder::<BasicCFRActionGenerator>::new()
        .name("hist-feed")
        .player_idx(game_state.to_act_idx())
        .cfr_state(cfr_state)
        .traversal_set(traversal_set)
        .budget(budget_for_schedule(&[2, 1]))
        .action_gen_config(())
        .estimator(Arc::new(stub))
        .build();

    // Seed a three-action "real hand" prefix before exploring.
    let prefix_log = agent
        .hand_log
        .as_ref()
        .expect("needs_history agent has a hand_log");
    let prefix = [
        Action::GameStart(GameStartPayload {
            ante: 0.0,
            small_blind: 5.0,
            big_blind: 10.0,
        }),
        Action::RoundAdvance(Round::Preflop),
        Action::DealCommunity(crate::core::Card::from(10u8)),
    ];
    for action in prefix {
        prefix_log.record(action);
    }

    agent.ensure_target_node();
    agent.ensure_regret_matcher();
    agent.explore_all_actions(&game_state).await;
    let counts = observed.lock().unwrap();
    assert!(
        counts.contains(&3),
        "expected the root estimate to receive a 3-action GameLog; saw {counts:?}"
    );
    assert!(
        counts.iter().all(|&c| c >= 3),
        "every estimate must see at least the real-hand prefix (full path); saw {counts:?}"
    );
}
#[tokio::test(flavor = "current_thread")]
async fn sub_sim_continues_prefix_without_relogging_setup() {
    use crate::arena::action::Action;
    use std::sync::Arc;
    let (game_state, cfr_state, traversal_set) = setup_tiny_heads_up();
    let stub = HistoryNeedingStub::default();
    let observed = stub.observed_counts.clone();
    let mut agent = CFRAgentBuilder::<ConfigurableActionGenerator>::new()
        .name("cont")
        .player_idx(game_state.to_act_idx())
        .cfr_state(cfr_state)
        .traversal_set(traversal_set)
        .budget(budget_for_schedule(&[2, 1]))
        .action_gen_config(ConfigurableActionConfig::default())
        .estimator(Arc::new(stub))
        .build();

    // Two-action prefix with no GameStart; sub-sims should extend it rather
    // than re-log setup actions.
    let prefix_log = agent.hand_log.as_ref().unwrap();
    let prefix = [
        Action::RoundAdvance(crate::arena::game_state::Round::River),
        Action::DealCommunity(crate::core::Card::from(5u8)),
    ];
    for action in prefix {
        prefix_log.record(action);
    }

    agent.ensure_target_node();
    agent.ensure_regret_matcher();
    agent.explore_all_actions(&game_state).await;
    let counts = observed.lock().unwrap();
    assert!(
        counts.contains(&2),
        "root estimate should see the 2-action prefix; saw {counts:?}"
    );
    assert!(
        counts.iter().any(|&c| c > 2),
        "a sub-sim estimate should see the prefix plus its own line; saw {counts:?}"
    );
    assert!(
        counts.iter().all(|&c| c >= 2),
        "every estimate must include the real-hand prefix; saw {counts:?}"
    );
}
#[test]
fn concurrent_children_have_independent_tails() {
    use crate::arena::action::Action;
    use crate::arena::game_state::Round;
    use crate::core::Card;
    // Shared prefix: one action recorded on the root, which is then frozen.
    let root = super::hand_log::HandLog::new();
    root.record(Action::RoundAdvance(Round::Flop));
    let frozen = root.freeze();
    // Two children of the same frozen prefix each append a different card.
    let left = frozen.spawn_child();
    let right = frozen.spawn_child();
    left.record(Action::DealCommunity(Card::from(1u8)));
    right.record(Action::DealCommunity(Card::from(2u8)));
    // Each child sees the prefix plus only its own card.
    let expected_for = |card: u8| {
        vec![
            Action::RoundAdvance(Round::Flop),
            Action::DealCommunity(Card::from(card)),
        ]
    };
    assert_eq!(left.to_actions(), expected_for(1u8));
    assert_eq!(right.to_actions(), expected_for(2u8));
}
#[test]
fn default_estimator_has_no_hand_log_and_no_historian() {
    use crate::arena::Agent;
    let game_state = crate::arena::GameStateBuilder::default()
        .num_players_with_stack(2, 100.0)
        .big_blind(2.0)
        .build()
        .unwrap();
    let cfr_state = make_cfr_state(&game_state);
    let traversal_set = TraversalSet::new(game_state.num_players);
    // No `.estimator(..)` call: the builder's default estimator is used.
    let default_agent = CFRAgentBuilder::<BasicCFRActionGenerator>::new()
        .name("default")
        .player_idx(0)
        .cfr_state(cfr_state)
        .traversal_set(traversal_set)
        .action_gen_config(())
        .build();
    let has_log = default_agent.hand_log.is_some();
    assert!(!has_log, "default estimator must carry no log");
    let has_historian = default_agent.historian().is_some();
    assert!(!has_historian, "no historian on the default fast path");
}
}