use std::borrow::Cow;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use little_sorry::{PcfrPlusRegretMatcher, RegretMinimizer};
use rand::SeedableRng;
use rand::rngs::StdRng;
use smallvec::SmallVec;
use tracing::event;
use crate::arena::hand_estimator::sample_world;
use crate::arena::{
Agent, GameState, HoldemSimulationBuilder, action::AgentAction, game_state::Round,
};
use super::super::{
ActionIndexMapper, Budget, CFRState, ExplorationStats, InFlightLimiter, NUM_ACTION_INDICES,
NextStep, NodeData, PlayerData, TraversalSet, TraversalState, action_bit_set::ActionBitSet,
action_generator::ActionGenerator, action_validator::validate_actions,
};
/// Why the wave loop in `explore_all_actions` stopped; reported as the
/// `stop_cause` field of the `cfr_diag` trace event.
#[derive(Clone, Copy, Debug)]
enum StopCause {
    /// The shared stop flag was observed set (before or after a wave).
    Deadline,
    /// The budget answered `NextStep::Stop` or `NextStep::Pass`.
    BudgetStop,
    /// The budget asked for a timer while one was already armed.
    BudgetStartTimer,
    /// A single fast-forward wave was completed.
    FastForward,
    /// Only one distinct action index was available, so nothing was explored.
    SingleAction,
    /// The current strategy stopped changing across consecutive waves.
    StableStrategy,
}

impl std::fmt::Display for StopCause {
    /// Writes the snake_case label used in diagnostics.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match *self {
            Self::Deadline => "deadline",
            Self::BudgetStop => "budget_stop",
            Self::BudgetStartTimer => "budget_start_timer",
            Self::FastForward => "fast_forward",
            Self::SingleAction => "single_action",
            Self::StableStrategy => "stable_strategy",
        })
    }
}
/// Completed waves required before the strategy-stability early exit is considered.
const EARLY_EXIT_MIN_ITERS: usize = 4;
/// Consecutive stable waves (L1 change below `EARLY_EXIT_EPSILON`) needed to stop early.
const EARLY_EXIT_STABLE_ITERS: u32 = 3;
/// Strict upper bound on the L1 distance between successive current strategies for a
/// wave to count as stable.
const EARLY_EXIT_EPSILON: f32 = 1e-3;
use super::builder::CFRAgentBuilder;
use super::fast_forward::{
fast_forward_advance_betting, fast_forward_apply_action, fast_forward_distribute_pot,
fast_forward_enumerate_showdowns, fast_forward_run_to_showdown,
fast_forward_sample_flop_enumerate_runout,
};
use super::hand_log::{HandLog, HandLogHistorian};
use super::reward_context::ComputeRewardContext;
/// Writes the per-slot mean `sums[i] / counts[i]` into `out`, using `penalty`
/// for every slot that received no samples (`counts[i] == 0`).
///
/// Only the first `min(out.len(), sums.len(), counts.len())` slots are written;
/// any remaining entries of `out` are left untouched.
pub(super) fn wave_mean_into(out: &mut [f32], sums: &[f32], counts: &[u32], penalty: f32) {
    let per_slot = sums.iter().copied().zip(counts.iter().copied());
    for (slot, (total, samples)) in out.iter_mut().zip(per_slot) {
        *slot = match samples {
            0 => penalty,
            n => total / n as f32,
        };
    }
}
/// A CFR agent that estimates the value of each of its candidate actions
/// (through recursive sub-simulations or fast-forward evaluation) and feeds
/// the results into the regret matcher of a node in the shared [`CFRState`].
pub struct CFRAgent<T>
where
    T: ActionGenerator,
{
    /// Agent display name; sub-simulation agents are built as `"CFRAgent-sub"`.
    pub(super) name: Cow<'static, str>,
    /// Traversal set for this agent; forked once per recursive sub-simulation.
    pub(super) traversal_set: TraversalSet,
    /// Cursor into the CFR tree: `(node_idx, chosen_child_idx, player_idx)`.
    pub(super) traversal_state: TraversalState,
    /// CFR tree shared (by clone) with every sub-simulation agent.
    pub(super) cfr_state: CFRState,
    /// Source of the candidate actions explored at each decision.
    pub(super) action_generator: T,
    /// Generator configuration, shared via `Arc` with sub-simulation agents.
    pub(super) action_gen_config: Arc<T::Config>,
    /// Maps a concrete action to its regret-matcher index.
    pub(super) action_index_mapper: ActionIndexMapper,
    /// Set on the exploring seat's agent in a sub-simulation; presumably the
    /// action it plays instead of exploring (confirm in the `Agent` impl).
    pub(super) forced_action: Option<AgentAction>,
    /// Sub-simulation nesting depth; `0` is the root agent.
    pub(super) depth: usize,
    /// Passed to `CFRState::ensure_child` and to sub-simulation CFR contexts.
    pub(super) allow_node_mutation: bool,
    /// Permits gating which reward computations may run as spawned tasks.
    pub(super) limiter: InFlightLimiter,
    /// Decides, wave by wave, whether and how to keep exploring.
    pub(super) budget: Arc<dyn Budget>,
    /// Cooperative stop flag, polled between waves; set by `spawn_stop_timer`.
    pub(super) stop: Arc<AtomicBool>,
    /// Hand-distribution estimator whose output drives `sample_world`.
    pub(super) estimator: Arc<dyn crate::arena::HandDistributionEstimator>,
    /// Hand history; must be present when `estimator.needs_history()` is true.
    pub(super) hand_log: Option<HandLog>,
}
/// Spawns a tokio task that sleeps for `duration` and then sets `stop` to `true`.
///
/// The task's handle is wrapped in the returned `AbortOnDrop` guard; judging by
/// its name, dropping the guard aborts the timer (confirm in `super::AbortOnDrop`).
/// Must be called from inside a tokio runtime, since it uses `tokio::spawn`.
pub(super) fn spawn_stop_timer(
    duration: std::time::Duration,
    stop: Arc<AtomicBool>,
) -> super::AbortOnDrop {
    let timer = async move {
        tokio::time::sleep(duration).await;
        // The visible readers only poll this flag, so a relaxed store suffices.
        stop.store(true, Ordering::Relaxed);
    };
    let handle = tokio::spawn(timer);
    super::AbortOnDrop(handle)
}
// Accessors plus the exploration engine. The `Send + 'static` / `Send + Sync`
// bounds are presumably required because reward computations are moved into
// `tokio` tasks in `explore_all_actions`.
impl<T> CFRAgent<T>
where
    T: ActionGenerator + Send + 'static,
    T::Config: Send + Sync,
{
    /// Shared CFR tree this agent reads and updates.
    pub fn cfr_state(&self) -> &CFRState {
        &self.cfr_state
    }

    /// Traversal set this agent was built with.
    pub fn traversal_set(&self) -> &TraversalSet {
        &self.traversal_set
    }

    /// Whether this agent may create or replace nodes in the CFR tree.
    pub fn allow_node_mutation(&self) -> bool {
        self.allow_node_mutation
    }

    /// Estimates the exploring player's reward for taking `action` in `game_state`.
    ///
    /// Uses the fast-forward evaluator when `ctx.fast_forward` is set, otherwise a
    /// full recursive sub-simulation.
    async fn compute_reward(
        game_state: &GameState,
        action: &AgentAction,
        ctx: &ComputeRewardContext<T>,
    ) -> f32 {
        if ctx.fast_forward {
            let player_idx = ctx.traversal_state.player_idx() as usize;
            Self::compute_reward_fast_forward(game_state, action, player_idx)
        } else {
            Self::compute_reward_recursive(game_state, action, ctx).await
        }
    }

    /// Plays the rest of the hand from a clone of `game_state` with one CFR
    /// sub-agent per seat, forcing the exploring seat to take `action`, and
    /// returns that seat's `player_reward`.
    ///
    /// # Panics
    /// Panics if the simulation builder fails (`unwrap`). In debug builds, also
    /// panics if the sub-simulation changed this agent's tree position.
    async fn compute_reward_recursive(
        game_state: &GameState,
        action: &AgentAction,
        ctx: &ComputeRewardContext<T>,
    ) -> f32 {
        let num_agents = game_state.num_players;
        // The node/child indices are only read by the debug-build position check below.
        let (_before_node_idx, _before_child_idx, player_idx) = ctx.traversal_state.get_all();
        event!(
            tracing::Level::TRACE,
            num_agents,
            ?action,
            player_idx = player_idx,
            "Computing reward via sub-simulation"
        );
        // All sub-agents share one forked traversal set, one level deeper.
        let forked_traversal_set = ctx.traversal_set.fork();
        let sub_depth = ctx.depth + 1;
        let action_config = ctx.action_gen_config.clone();
        let cached_mapper_config = *ctx.action_index_mapper.config();
        let shared_cfr_state = ctx.cfr_state.clone();
        let child_log: Option<HandLog> = ctx.hand_log.as_ref().map(|l| l.spawn_child());
        let mut agents: Vec<Box<dyn Agent>> = Vec::with_capacity(num_agents);
        for i in 0..num_agents {
            // Every seat shares the tree, limiter, budget, stop flag and estimator.
            let mut builder = CFRAgentBuilder::<T>::new()
                .name("CFRAgent-sub")
                .player_idx(i)
                .cfr_state(shared_cfr_state.clone())
                .mapper_config(cached_mapper_config)
                .action_gen_config_arc(action_config.clone())
                .traversal_set(forked_traversal_set.clone())
                .depth(sub_depth)
                .limiter(ctx.limiter.clone())
                .budget(ctx.budget.clone())
                .stop_flag(ctx.stop.clone())
                .estimator(ctx.estimator.clone());
            if let Some(ref cl) = child_log {
                builder = builder.hand_log(cl.clone());
            }
            // Only the exploring seat is pinned to the action being evaluated.
            if i == player_idx as usize {
                builder = builder.forced_action((*action).clone());
            }
            agents.push(Box::new(builder.build()) as Box<dyn Agent>);
        }
        let sub_sim_rng = StdRng::from_rng(&mut rand::rng());
        let mut sim_builder = HoldemSimulationBuilder::default()
            .game_state(game_state.clone())
            .agents(agents)
            .cfr_context(
                shared_cfr_state,
                forked_traversal_set,
                ctx.allow_node_mutation,
            );
        // When a hand log is in use, record the sub-simulation into the child log.
        if let Some(cl) = child_log {
            sim_builder = sim_builder.historians(vec![
                Box::new(HandLogHistorian::new(cl)) as Box<dyn crate::arena::Historian>
            ]);
        }
        let mut sim = sim_builder.build_with_rng(sub_sim_rng).unwrap();
        sim.run().await;
        // The sub-simulation must not move this agent's cursor in the tree.
        #[cfg(debug_assertions)]
        {
            let (after_node_idx, after_child_idx) = ctx.traversal_state.get_position();
            assert_eq!(
                _before_node_idx, after_node_idx,
                "Node index should be the same after exploration"
            );
            assert_eq!(
                _before_child_idx, after_child_idx,
                "Child index should be the same after exploration"
            );
        }
        sim.game_state.player_reward(player_idx as usize)
    }

    /// Cheap reward estimate: applies `action` to a clone of `game_state` and
    /// resolves the hand without consulting any agent.
    ///
    /// - At most one contender left: run to showdown, distribute the pot, read the reward.
    /// - Otherwise, advance betting, then branch on the resulting round:
    ///   - `Showdown`, `Complete`, `DealTurn` or `DealRiver`: enumerate the
    ///     remaining board cards (0, 2 or 1).
    ///   - `DealFlop`: sample the flop and enumerate the runout.
    ///   - Any other round: run to showdown.
    fn compute_reward_fast_forward(
        game_state: &GameState,
        action: &AgentAction,
        player_idx: usize,
    ) -> f32 {
        let mut rng = rand::rng();
        let mut gs = game_state.clone();
        fast_forward_apply_action(&mut gs, action);
        let contenders = gs.player_active.count() + gs.player_all_in.count();
        if contenders <= 1 {
            fast_forward_run_to_showdown(&mut gs, &mut rng);
            fast_forward_distribute_pot(&mut gs);
            return gs.player_reward(player_idx);
        }
        fast_forward_advance_betting(&mut gs);
        // The `DealFlop` value only has to exceed 2 to select the sampled-flop path below.
        let cards_needed = match gs.round {
            Round::Showdown | Round::Complete => 0,
            Round::DealFlop => 3,
            Round::DealTurn => 2,
            Round::DealRiver => 1,
            _ => {
                fast_forward_run_to_showdown(&mut gs, &mut rng);
                fast_forward_distribute_pot(&mut gs);
                return gs.player_reward(player_idx);
            }
        };
        if cards_needed <= 2 {
            fast_forward_enumerate_showdowns(&gs, player_idx, cards_needed)
        } else {
            fast_forward_sample_flop_enumerate_runout(&gs, player_idx, &mut rng)
        }
    }

    /// Index of the existing child at the agent's current (node, child)
    /// position, or `None` if it has not been created.
    pub(super) fn target_node_idx(&self) -> Option<usize> {
        let (from_node_idx, from_child_idx) = self.traversal_state.get_position();
        self.cfr_state.get_child(from_node_idx, from_child_idx)
    }

    /// Returns the child node at the current position, asking the tree to
    /// ensure it exists as a `Player` node (no regret matcher yet) for this
    /// agent's player. Creation is subject to `allow_node_mutation`; exact
    /// semantics live in `CFRState::ensure_child`.
    pub(super) fn ensure_target_node(&self) -> usize {
        let (node_idx, chosen_child_idx, player_idx) = self.traversal_state.get_all();
        let expected_data = NodeData::Player(PlayerData {
            regret_matcher: None,
            player_idx,
        });
        self.cfr_state.ensure_child(
            node_idx,
            chosen_child_idx,
            expected_data,
            self.allow_node_mutation,
        )
    }

    /// Ensures the target node exists and, if it is a `Player` node without a
    /// regret matcher, attaches a fresh `PcfrPlusRegretMatcher` sized to
    /// `NUM_ACTION_INDICES`.
    ///
    /// # Panics
    /// Panics if `CFRState::update_node` returns an error.
    pub(super) fn ensure_regret_matcher(&mut self) {
        let target_node_idx = self.ensure_target_node();
        self.cfr_state
            .update_node(target_node_idx, |data| {
                if let NodeData::Player(player_data) = data
                    && player_data.regret_matcher.is_none()
                {
                    let regret_matcher = Box::new(PcfrPlusRegretMatcher::new(NUM_ACTION_INDICES));
                    player_data.regret_matcher = Some(regret_matcher);
                }
            })
            .unwrap();
    }

    /// Feeds one reward vector (indexed by action index) into the regret
    /// matcher at `target_node_idx`.
    ///
    /// - `Player` node without a matcher: the update is silently skipped.
    /// - Node of any other type (logged as a concurrent type change): it is
    ///   replaced by a `Player` node whose new matcher is seeded with this update.
    ///
    /// # Panics
    /// Panics if `CFRState::update_node` returns an error.
    fn update_regret_at_node(&self, target_node_idx: usize, rewards: &[f32]) {
        self.cfr_state
            .update_node(target_node_idx, |data| {
                if let NodeData::Player(player_data) = data {
                    if let Some(regret_matcher) = player_data.regret_matcher.as_mut() {
                        regret_matcher.update_regret(rewards);
                    }
                } else {
                    event!(
                        tracing::Level::DEBUG,
                        target_node_idx,
                        found_type = %data,
                        "Concurrent node type change detected — restoring Player"
                    );
                    let mut regret_matcher =
                        Box::new(PcfrPlusRegretMatcher::new(NUM_ACTION_INDICES));
                    regret_matcher.update_regret(rewards);
                    *data = NodeData::Player(PlayerData {
                        regret_matcher: Some(regret_matcher),
                        player_idx: self.traversal_state.player_idx(),
                    });
                }
            })
            .unwrap();
    }

    /// Runs CFR exploration waves at the agent's target node until the budget,
    /// the stop flag, or strategy stability ends the loop.
    ///
    /// Each wave:
    /// 1. samples one world from the estimated ranges (skipped for fast-forward waves);
    /// 2. evaluates every non-skipped action `wave_width` times, spawning tokio
    ///    tasks when allowed and otherwise computing inline;
    /// 3. averages the rewards per action index and updates the node's regrets.
    ///
    /// Skipping comes from regret-based pruning and from a dynamic
    /// low-probability threshold.
    ///
    /// Returns immediately when there are no valid actions. With exactly one
    /// distinct action, it only emits a `cfr_diag` event. When diagnostics are
    /// enabled, a summary `cfr_diag` TRACE event is emitted at the end.
    ///
    /// # Panics
    /// - The target node does not exist (`target_node_idx().unwrap()`).
    /// - The estimator needs history but `hand_log` is `None`.
    /// - A spawned reward task panics (re-raised) or fails to join.
    pub async fn explore_all_actions(&mut self, game_state: &GameState) {
        // Validate generated actions and keep only the first action per
        // regret-matcher index, so each index is sampled at most once per pass.
        let raw_actions = self.action_generator.gen_possible_actions(game_state);
        let validated_actions = validate_actions(raw_actions, game_state);
        let mut seen_indices = ActionBitSet::new();
        let indexed_actions: SmallVec<[(AgentAction, usize); 8]> = validated_actions
            .into_iter()
            .filter_map(|a| {
                let idx = self.action_index_mapper.action_to_idx(&a, game_state);
                if seen_indices.insert(idx) {
                    Some((a, idx))
                } else {
                    None
                }
            })
            .collect();
        if indexed_actions.is_empty() {
            return;
        }
        // A single choice needs no exploration: emit the diagnostic record (if
        // enabled) and return without touching regrets.
        if indexed_actions.len() == 1 {
            if tracing::event_enabled!(target: "cfr_diag", tracing::Level::TRACE) {
                let nodes = self.cfr_state.node_count() as u64;
                let empty: &[f32] = &[];
                tracing::event!(
                    target: "cfr_diag",
                    tracing::Level::TRACE,
                    depth = self.depth as u64,
                    stop_cause = %StopCause::SingleAction,
                    final_iterations = 0u64,
                    final_elapsed_us = 0u64,
                    nodes_touched_start = nodes,
                    nodes_touched_end = nodes,
                    timer_armed = false,
                    actions_considered = 1u64,
                    regret_series = ?empty,
                );
            }
            return;
        }
        // Reward given to every action index left unsampled in a wave: the
        // negated starting stack of the exploring player.
        let invalid_action_penalty =
            -(game_state.starting_stacks[self.traversal_state.player_idx() as usize]);
        // Presumably created earlier by `ensure_target_node` / `ensure_regret_matcher`.
        let target_node_idx = self.target_node_idx().unwrap();
        let estimator = self.estimator.clone();
        // History-aware estimators get the hand log. The root (depth 0) passes
        // `freeze()`'s result; sub-agents pass a clone.
        let exploration_log: Option<HandLog> = if estimator.needs_history() {
            let log = self
                .hand_log
                .as_ref()
                .expect("needs_history agent must have a hand_log");
            Some(if self.depth == 0 {
                log.freeze()
            } else {
                log.clone()
            })
        } else {
            None
        };
        let history_actions: Option<Vec<crate::arena::action::Action>> =
            exploration_log.as_ref().map(|l| l.to_actions());
        let game_log = history_actions
            .as_ref()
            .map(|a| crate::arena::GameLog { actions: a });
        // Estimated once; every non-fast-forward wave samples its world from this.
        let ranges = estimator.estimate(game_state, game_log.as_ref()).await;
        // Per-wave accumulators, indexed by action index.
        let mut sums = [0.0f32; NUM_ACTION_INDICES];
        let mut counts = [0u32; NUM_ACTION_INDICES];
        let mut rewards = [0.0f32; NUM_ACTION_INDICES];
        // Node update count required before pruning may apply.
        const PRUNE_WARMUP: usize = 3;
        // Waves where iter_idx is a multiple of this (including the first) ignore
        // pruning and then refresh the active set.
        const REPROBE_INTERVAL: usize = 4;
        // Dynamic threshold is DYNAMIC_THRESHOLD_C / sqrt(max(iter_idx, PRUNE_WARMUP)).
        const DYNAMIC_THRESHOLD_C: f32 = 0.01;
        // Only agents shallower than this depth may spawn tasks (still permit-gated).
        const SPAWN_FRONTIER_DEPTH: usize = 3;
        let (initial_active, initial_updates) = self.cfr_state.get_pruning_info(target_node_idx);
        let can_prune = indexed_actions.len() > 2 && initial_updates >= PRUNE_WARMUP;
        let mut active_actions = initial_active;
        // Despite the name, this counts total updates, starting from the node's existing count.
        let mut updates_since_warmup = initial_updates;
        let started = std::time::Instant::now();
        let mut latest_avg_regret: Option<f32> = None;
        let mut iter_idx: u64 = 0;
        // Early-exit state: current-strategy snapshots from consecutive waves.
        let mut early_exit_prev_strategy = [0.0f32; NUM_ACTION_INDICES];
        let mut early_exit_curr_strategy = [0.0f32; NUM_ACTION_INDICES];
        let mut early_exit_stable_count: u32 = 0;
        let mut early_exit_has_prev = false;
        // Sub-agents (depth > 0) are treated as already running under the root's timer.
        let mut timer_armed = self.depth > 0;
        // Holds the timer guard (if any) until this function returns.
        let mut _timer_guard: Option<super::AbortOnDrop> = None;
        // Diagnostics are only collected when the `cfr_diag` TRACE target is enabled.
        let diag_on = tracing::event_enabled!(target: "cfr_diag", tracing::Level::TRACE);
        let diag_nodes_touched_start: u64 = if diag_on {
            self.cfr_state.node_count() as u64
        } else {
            0
        };
        let mut diag_regret_series: Vec<f32> = if diag_on {
            Vec::with_capacity(32)
        } else {
            Vec::new()
        };
        let mut diag_stop_cause: StopCause = StopCause::BudgetStop;
        loop {
            // Progress snapshot handed to the budget.
            let stats = ExplorationStats {
                elapsed: started.elapsed(),
                iterations: iter_idx,
                nodes_touched: self.cfr_state.node_count() as u64,
                depth: self.depth,
                avg_regret: latest_avg_regret,
                timer_armed,
            };
            if self.stop.load(Ordering::Relaxed) {
                if diag_on {
                    diag_stop_cause = StopCause::Deadline;
                }
                break;
            }
            // The budget chooses one of: stop, arm the root timer (then re-ask),
            // a wave of `width` samples per action, or one fast-forward wave.
            let (wave_width, fast_forward) = match self.budget.next_step(&stats) {
                NextStep::Stop | NextStep::Pass => {
                    if diag_on {
                        diag_stop_cause = StopCause::BudgetStop;
                    }
                    break;
                }
                NextStep::StartTimer { duration } if !timer_armed => {
                    debug_assert!(
                        self.depth == 0,
                        "NextStep::StartTimer should only arrive at the root (depth 0). \
                         Sub-agents inherit `timer_armed = true`, so a StartTimer here means \
                         a budget returned StartTimer at depth > 0, which is unsupported by \
                         the current engine. If you want per-depth timers, the engine needs \
                         to grow per-depth timer slots."
                    );
                    _timer_guard = Some(spawn_stop_timer(duration, self.stop.clone()));
                    timer_armed = true;
                    continue;
                }
                // A second timer request is treated as a stop.
                NextStep::StartTimer { .. } => {
                    if diag_on {
                        diag_stop_cause = StopCause::BudgetStartTimer;
                    }
                    break;
                }
                NextStep::Wave { width } => (width, false),
                NextStep::FastForward => (1, true),
            };
            let ctx = ComputeRewardContext::<T> {
                traversal_set: self.traversal_set.clone(),
                traversal_state: self.traversal_state.clone(),
                cfr_state: self.cfr_state.clone(),
                action_gen_config: self.action_gen_config.clone(),
                action_index_mapper: self.action_index_mapper.clone(),
                limiter: self.limiter.clone(),
                budget: self.budget.clone(),
                stop: self.stop.clone(),
                depth: self.depth,
                fast_forward,
                allow_node_mutation: self.allow_node_mutation,
                estimator: self.estimator.clone(),
                hand_log: exploration_log.clone(),
            };
            // Regular waves evaluate against one world drawn by `sample_world`;
            // fast-forward waves use `game_state` unchanged.
            // NOTE(review): confirm that skipping the world sample for fast-forward is intended.
            let wave_state: Option<GameState> = if fast_forward {
                None
            } else {
                let mut rng = rand::rng();
                Some(sample_world(&ranges, game_state, &mut rng))
            };
            let sampled_arc: Option<std::sync::Arc<GameState>> =
                wave_state.map(std::sync::Arc::new);
            let effective_gs: &GameState = sampled_arc.as_deref().unwrap_or(game_state);
            // Pruning needs more than two actions and at least PRUNE_WARMUP updates.
            let prune_this_iter =
                indexed_actions.len() > 2 && (can_prune || updates_since_warmup >= PRUNE_WARMUP);
            let is_reprobe = iter_idx.is_multiple_of(REPROBE_INTERVAL as u64);
            let skip_pruned = prune_this_iter && !is_reprobe;
            // On pruning waves, also build the set of indices whose current
            // strategy probability reaches the dynamic threshold.
            let dyn_thresh_set: Option<ActionBitSet> = if skip_pruned && DYNAMIC_THRESHOLD_C > 0.0 {
                let mut dyn_strategy = [0.0f32; NUM_ACTION_INDICES];
                if self
                    .cfr_state
                    .node_current_strategy_into(target_node_idx, &mut dyn_strategy)
                {
                    let denom = (iter_idx as f32).max(PRUNE_WARMUP as f32).sqrt();
                    let threshold = DYNAMIC_THRESHOLD_C / denom;
                    let mut set = ActionBitSet::new();
                    for (i, &p) in dyn_strategy.iter().enumerate() {
                        if p >= threshold {
                            set.insert(i);
                        }
                    }
                    Some(set)
                } else {
                    None
                }
            } else {
                None
            };
            sums.fill(0.0);
            counts.fill(0);
            let mut set: tokio::task::JoinSet<(usize, f32)> = tokio::task::JoinSet::new();
            let mut inline: Vec<(usize, f32)> = Vec::new();
            let spawn_here = self.depth < SPAWN_FRONTIER_DEPTH;
            // Owned handle to this wave's state for spawned tasks; only built when spawning is allowed.
            let gs_arc = spawn_here.then(|| match &sampled_arc {
                Some(arc) => arc.clone(),
                None => std::sync::Arc::new(game_state.clone()),
            });
            for _sample in 0..wave_width {
                for (action, reward_idx) in &indexed_actions {
                    let reward_idx = *reward_idx;
                    if skip_pruned && !active_actions.contains(reward_idx) {
                        event!(
                            tracing::Level::TRACE,
                            action_idx = reward_idx,
                            wave = iter_idx,
                            "RBP: skipping pruned action"
                        );
                        continue;
                    }
                    if let Some(dyn_set) = &dyn_thresh_set
                        && !dyn_set.contains(reward_idx)
                    {
                        event!(
                            tracing::Level::TRACE,
                            action_idx = reward_idx,
                            wave = iter_idx,
                            "DynThresh: skipping low-probability action"
                        );
                        continue;
                    }
                    debug_assert!(
                        reward_idx < sums.len(),
                        "Action index {} should be less than number of potential actions {}",
                        reward_idx,
                        sums.len()
                    );
                    let action = action.clone();
                    // Spawn when spawning is allowed and a limiter permit is free;
                    // otherwise fall through and compute inline on this task.
                    if let Some(gs_arc) = &gs_arc
                        && let Ok(permit) = ctx.limiter.clone().try_acquire_owned()
                    {
                        let ctx = ctx.clone();
                        let gs = gs_arc.clone();
                        set.spawn(async move {
                            // Keep the permit for the task's lifetime; it is dropped when the task ends.
                            let _permit = permit;
                            let r = CFRAgent::<T>::compute_reward(&gs, &action, &ctx).await;
                            (reward_idx, r)
                        });
                        continue;
                    }
                    let r = Self::compute_reward(effective_gs, &action, &ctx).await;
                    inline.push((reward_idx, r));
                }
            }
            for (idx, r) in inline.drain(..) {
                sums[idx] += r;
                counts[idx] += 1;
            }
            // Join every spawned task; a task panic is re-raised on this task.
            while let Some(joined) = set.join_next().await {
                match joined {
                    Ok((idx, r)) => {
                        sums[idx] += r;
                        counts[idx] += 1;
                    }
                    Err(join_err) => {
                        if join_err.is_panic() {
                            std::panic::resume_unwind(join_err.into_panic());
                        } else {
                            panic!("CFR exploration task failed to join: {join_err}");
                        }
                    }
                }
            }
            // A stop observed after the wave discards its samples without updating regrets.
            if self.stop.load(Ordering::Relaxed) {
                if diag_on {
                    diag_stop_cause = StopCause::Deadline;
                }
                break;
            }
            // Indices with no samples this wave receive `invalid_action_penalty`.
            // NOTE(review): these include valid actions skipped by pruning or the
            // dynamic threshold. If every action was skipped, all slots get the
            // penalty and the update still runs. Confirm both are intended.
            wave_mean_into(&mut rewards, &sums, &counts, invalid_action_penalty);
            self.update_regret_at_node(target_node_idx, &rewards);
            updates_since_warmup += 1;
            iter_idx += 1;
            latest_avg_regret = self.cfr_state.node_avg_regret(target_node_idx);
            if diag_on && let Some(r) = latest_avg_regret {
                diag_regret_series.push(r);
            }
            // Early exit: once at least EARLY_EXIT_MIN_ITERS waves are done, stop
            // after EARLY_EXIT_STABLE_ITERS consecutive waves whose L1 strategy
            // change is below EARLY_EXIT_EPSILON.
            if self
                .cfr_state
                .node_current_strategy_into(target_node_idx, &mut early_exit_curr_strategy)
            {
                if early_exit_has_prev && (iter_idx as usize) >= EARLY_EXIT_MIN_ITERS {
                    let mut l1 = 0.0f32;
                    for (a, b) in early_exit_curr_strategy
                        .iter()
                        .zip(early_exit_prev_strategy.iter())
                    {
                        l1 += (a - b).abs();
                    }
                    if l1 < EARLY_EXIT_EPSILON {
                        early_exit_stable_count += 1;
                        if early_exit_stable_count >= EARLY_EXIT_STABLE_ITERS {
                            if diag_on {
                                diag_stop_cause = StopCause::StableStrategy;
                            }
                            break;
                        }
                    } else {
                        early_exit_stable_count = 0;
                    }
                }
                early_exit_prev_strategy.copy_from_slice(&early_exit_curr_strategy);
                early_exit_has_prev = true;
            }
            // After a reprobe wave, refresh the set of unpruned action indices.
            if is_reprobe
                && indexed_actions.len() > 2
                && (can_prune || updates_since_warmup >= PRUNE_WARMUP)
            {
                let (new_active, _) = self.cfr_state.get_pruning_info(target_node_idx);
                active_actions = new_active;
            }
            // Fast-forward exploration is a single wave.
            if fast_forward {
                if diag_on {
                    diag_stop_cause = StopCause::FastForward;
                }
                break;
            }
        }
        if diag_on {
            let elapsed_us = started.elapsed().as_micros() as u64;
            let nodes_touched_end = self.cfr_state.node_count() as u64;
            tracing::event!(
                target: "cfr_diag",
                tracing::Level::TRACE,
                depth = self.depth as u64,
                stop_cause = %diag_stop_cause,
                final_iterations = iter_idx,
                final_elapsed_us = elapsed_us,
                nodes_touched_start = diag_nodes_touched_start,
                nodes_touched_end = nodes_touched_end,
                timer_armed = timer_armed,
                actions_considered = indexed_actions.len() as u64,
                regret_series = ?diag_regret_series.as_slice(),
            );
        }
    }
}
#[cfg(test)]
mod wave_tests {
    use super::wave_mean_into;

    /// Runs `wave_mean_into` over three slots and returns the filled buffer.
    fn mean3(sums: [f32; 3], counts: [u32; 3], penalty: f32) -> [f32; 3] {
        let mut out = [0.0f32; 3];
        wave_mean_into(&mut out, &sums, &counts, penalty);
        out
    }

    #[test]
    fn wave_mean_averages_only_sampled_slots() {
        // The middle slot has no samples and must fall back to the penalty.
        assert_eq!(
            mean3([3.0, 0.0, 8.0], [2, 0, 2], -100.0_f32),
            [1.5, -100.0, 4.0]
        );
    }

    #[test]
    fn wave_mean_single_sample_equals_sample() {
        // A single sample is its own mean; the unsampled slot gets the penalty.
        assert_eq!(
            mean3([5.0, -2.0, 0.0], [1, 1, 0], -7.0_f32),
            [5.0, -2.0, -7.0]
        );
    }
}