use rand::RngExt;
use tracing::event;
use crate::arena::{GameState, action::AgentAction};
use super::{ActionIndexMapper, NodeData, action_bit_set::ActionBitSet};
use little_sorry::{PcfrPlusRegretMatcher, RegretMinimizer};
/// Capacity of the fixed-size buffer inside `DedupedActions`: the largest
/// number of distinct action indices a single candidate list may produce.
const MAX_DEDUPED_ACTIONS: usize = 16;
/// Placeholder action referenced by every unused slot of the `DedupedActions`
/// buffer, so the array can be fully initialised before any real action is
/// stored. The accessors only read slots below `len`.
static DEDUP_FALLBACK_ACTION: AgentAction = AgentAction::Fold;
/// Fixed-capacity, stack-allocated list of the distinct action indices found
/// in a slice of candidate actions.
///
/// Several concrete actions can map to the same index; only the first action
/// seen for each index is kept, and input order is preserved.
struct DedupedActions<'a> {
    /// `(action index, first action that mapped to it)`. Only
    /// `entries[..len]` is meaningful; the remaining slots hold the
    /// placeholder `(0, &DEDUP_FALLBACK_ACTION)`.
    ///
    /// The index is stored as `usize`: next to a reference the tuple is two
    /// pointer-words wide either way, so a narrower integer would save nothing
    /// and would need a lossy cast.
    entries: [(usize, &'a AgentAction); MAX_DEDUPED_ACTIONS],
    /// Number of populated slots at the front of `entries`.
    len: usize,
}

impl<'a> DedupedActions<'a> {
    /// Collects the first action for every distinct index that `mapper`
    /// assigns to the elements of `actions` under `game_state`.
    ///
    /// # Panics
    ///
    /// Panics if `actions` maps to more than [`MAX_DEDUPED_ACTIONS`] distinct
    /// indices.
    fn from_slice(
        actions: &'a [AgentAction],
        mapper: &ActionIndexMapper,
        game_state: &GameState,
    ) -> Self {
        let mut seen = ActionBitSet::new();
        let mut entries: [(usize, &'a AgentAction); MAX_DEDUPED_ACTIONS] =
            [(0, &DEDUP_FALLBACK_ACTION); MAX_DEDUPED_ACTIONS];
        let mut len = 0usize;

        for action in actions {
            let idx = mapper.action_to_idx(action, game_state);
            if !seen.insert(idx) {
                // A previous action already claimed this index.
                continue;
            }
            // Checked in every build profile: overflowing the buffer would
            // panic on the write below anyway, just with a less useful message.
            assert!(
                len < MAX_DEDUPED_ACTIONS,
                "action generator produced more than {MAX_DEDUPED_ACTIONS} distinct actions"
            );
            entries[len] = (idx, action);
            len += 1;
        }

        Self { entries, len }
    }

    /// Number of distinct action indices collected.
    #[inline]
    fn len(&self) -> usize {
        self.len
    }

    /// Returns the `i`-th `(action index, action)` pair in input order.
    ///
    /// `i` must be less than [`Self::len`]; this is only checked in debug
    /// builds.
    #[inline]
    fn get(&self, i: usize) -> (usize, &'a AgentAction) {
        debug_assert!(i < self.len);
        self.entries[i]
    }
}
/// Chooses one action from a list of candidate actions, optionally guided by
/// the weights of a regret matcher.
///
/// Candidates that map to the same action index are treated as a single
/// option, represented by the first of them.
pub struct ActionPicker<'a> {
/// Maps a concrete action to its action index for the given game state.
mapper: &'a ActionIndexMapper,
/// Candidate actions to choose from; expected to be non-empty.
possible_actions: &'a [AgentAction],
/// Source of per-index weights. When `None`, sampling is uniform over the
/// distinct indices and the "best" action is the first candidate.
regret_matcher: Option<&'a PcfrPlusRegretMatcher>,
/// Game state handed to `mapper` whenever an action index is computed.
game_state: &'a GameState,
}
impl<'a> ActionPicker<'a> {
    /// Creates a picker over `possible_actions` for `game_state`.
    ///
    /// `possible_actions` must not be empty. This is only checked in debug
    /// builds here, but every pick method panics on an empty list.
    pub fn new(
        mapper: &'a ActionIndexMapper,
        possible_actions: &'a [AgentAction],
        regret_matcher: Option<&'a PcfrPlusRegretMatcher>,
        game_state: &'a GameState,
    ) -> Self {
        debug_assert!(
            !possible_actions.is_empty(),
            "possible_actions should always contain at least one action"
        );
        Self {
            mapper,
            possible_actions,
            regret_matcher,
            game_state,
        }
    }

    /// Samples one action at random.
    ///
    /// * Without a regret matcher the choice is uniform over the distinct
    ///   action indices.
    /// * With a regret matcher each distinct index is chosen with probability
    ///   proportional to its weight. Negative, NaN and missing weights count
    ///   as zero, and a zero-weight action is never selected.
    /// * If the candidates' weights sum to (almost) zero the choice falls back
    ///   to uniform over the distinct indices.
    ///
    /// # Panics
    ///
    /// Panics if `possible_actions` is empty, or if it maps to more distinct
    /// indices than the dedup buffer can hold.
    pub fn pick_action<R: rand::Rng>(&self, rng: &mut R) -> AgentAction {
        let Some(matcher) = self.regret_matcher else {
            return self.pick_uniform_reservoir(rng);
        };

        let valid = DedupedActions::from_slice(self.possible_actions, self.mapper, self.game_state);
        let len = valid.len();
        debug_assert!(len > 0, "ActionPicker must have at least one valid action");

        let weights = matcher.best_weight();
        // Usable weight of an index: missing entries count as 0, and `max`
        // maps both negative values and NaN to 0.
        let weight_of = |idx: usize| weights.get(idx).copied().unwrap_or(0.0).max(0.0);

        let mut total_weight: f32 = 0.0;
        for i in 0..len {
            total_weight += weight_of(valid.get(i).0);
        }

        if total_weight < 1e-10 {
            let chosen_idx = rng.random_range(0..len);
            event!(
                tracing::Level::DEBUG,
                chosen_idx = chosen_idx,
                "All weights zero, using uniform random"
            );
            return valid.get(chosen_idx).1.clone();
        }

        let random_value: f32 = rng.random::<f32>() * total_weight;
        let mut cumulative = 0.0f32;
        // Last action that carried probability mass; used if rounding keeps
        // the draw from landing in any bucket.
        let mut last_positive: Option<&AgentAction> = None;
        for i in 0..len {
            let (action_idx, action) = valid.get(i);
            let weight = weight_of(action_idx);
            if weight <= 0.0 {
                // Skipping matters when the draw is exactly 0.0: otherwise
                // `0.0 <= 0.0` would select an action with no probability mass.
                continue;
            }
            cumulative += weight;
            last_positive = Some(action);
            if random_value <= cumulative {
                event!(
                    tracing::Level::DEBUG,
                    action_idx = action_idx,
                    total_weight = total_weight,
                    "Selected action from regret matcher"
                );
                return action.clone();
            }
        }

        last_positive
            .unwrap_or_else(|| valid.get(len - 1).1)
            .clone()
    }

    /// Uniformly samples one action among the distinct action indices in a
    /// single pass, without buffering the distinct actions.
    fn pick_uniform_reservoir<R: rand::Rng>(&self, rng: &mut R) -> AgentAction {
        let mut seen = ActionBitSet::new();
        let mut count: u32 = 0;
        let mut chosen: Option<&AgentAction> = None;

        for action in self.possible_actions {
            let idx = self.mapper.action_to_idx(action, self.game_state);
            if seen.insert(idx) {
                count += 1;
                // The `count`-th distinct action replaces the current pick
                // with probability 1/count, which leaves every distinct action
                // equally likely once the loop ends.
                if rng.random_range(0..count) == 0 {
                    chosen = Some(action);
                }
            }
        }

        event!(
            tracing::Level::DEBUG,
            count = count,
            "No regret matcher, using uniform random"
        );
        chosen
            .expect("possible_actions must contain at least one action")
            .clone()
    }

    /// Deterministically returns the candidate with the highest weight.
    ///
    /// Ties go to the earliest candidate. Without a regret matcher the first
    /// candidate is returned. A missing weight counts as 0, and a NaN weight
    /// ranks below every other weight.
    ///
    /// # Panics
    ///
    /// Panics if `possible_actions` is empty.
    pub fn pick_best_action(&self) -> AgentAction {
        let Some(matcher) = self.regret_matcher else {
            // With nothing to rank by, the first candidate is the answer.
            return self
                .possible_actions
                .first()
                .expect("possible_actions must contain at least one action")
                .clone();
        };

        let weights = matcher.best_weight();
        let mut seen = ActionBitSet::new();
        let mut best: Option<(f32, &AgentAction)> = None;

        for action in self.possible_actions {
            let idx = self.mapper.action_to_idx(action, self.game_state);
            if !seen.insert(idx) {
                continue;
            }
            let raw = weights.get(idx).copied().unwrap_or(0.0);
            // NaN never compares greater than anything, so rank it lowest
            // instead of letting it block later comparisons.
            let weight = if raw.is_nan() { f32::NEG_INFINITY } else { raw };
            // The first distinct action always seeds `best`, so a non-empty
            // candidate list yields a result whatever the weights are.
            let improves = match best {
                None => true,
                Some((best_weight, _)) => weight > best_weight,
            };
            if improves {
                best = Some((weight, action));
            }
        }

        best.expect("possible_actions must contain at least one action")
            .1
            .clone()
    }
}
/// Borrows the regret matcher stored in a player node.
///
/// Returns `None` when `node_data` is not a `NodeData::Player` variant or when
/// the player node has no regret matcher set.
pub fn get_regret_matcher_from_node(node_data: &NodeData) -> Option<&PcfrPlusRegretMatcher> {
    let NodeData::Player(player) = node_data else {
        return None;
    };
    let stored = player.regret_matcher.as_ref();
    stored.map(|matcher| matcher.as_ref())
}
// Unit tests for `ActionPicker` and for the action-to-index mapping it relies
// on. Sampling tests use a fixed RNG seed, so their outcomes are reproducible.
#[cfg(test)]
mod tests {
use super::*;
use crate::arena::GameStateBuilder;
use crate::arena::cfr::{
ACTION_IDX_ALL_IN, ACTION_IDX_RAISE_MIN, ActionIndexMapperConfig, NUM_ACTION_INDICES,
};
// Builds a two-player game state with 100.0 stacks, configured with
// `blinds(10.0, 5.0)`.
fn create_test_game_state() -> GameState {
GameStateBuilder::new()
.num_players_with_stack(2, 100.0)
.blinds(10.0, 5.0)
.build()
.unwrap()
}
// Builds the mapper shared by most tests.
// NOTE(review): the meaning of the two config arguments (10.0, 100.0) is not
// visible here — confirm against `ActionIndexMapperConfig::new`.
fn create_mapper() -> ActionIndexMapper {
ActionIndexMapper::new(ActionIndexMapperConfig::new(10.0, 100.0))
}
// Returns a `StdRng` seeded with 42 so sampling tests are deterministic.
fn create_seeded_rng() -> rand::rngs::StdRng {
use rand::SeedableRng;
rand::rngs::StdRng::seed_from_u64(42)
}
// Without a regret matcher, the sampled action must be one of the candidates.
#[test]
fn test_pick_action_uniform_no_regret_matcher() {
let game_state = create_test_game_state();
let mapper = create_mapper();
let actions = vec![
AgentAction::Fold,
AgentAction::Bet(10.0),
AgentAction::AllIn,
];
let mut rng = create_seeded_rng();
let picker = ActionPicker::new(&mapper, &actions, None, &game_state);
let picked = picker.pick_action(&mut rng);
assert!(
actions.contains(&picked),
"Picked action {:?} should be in valid actions",
picked
);
}
// Only index 0 (fold) receives a positive reward, so fold must be sampled in
// more than half of 100 draws.
#[test]
fn test_pick_action_with_regret_matcher() {
let game_state = create_test_game_state();
let mapper = create_mapper();
let actions = vec![
AgentAction::Fold,
AgentAction::Bet(10.0),
AgentAction::AllIn,
];
let mut rng = create_seeded_rng();
let mut matcher = PcfrPlusRegretMatcher::new(NUM_ACTION_INDICES);
let mut rewards = vec![0.0; NUM_ACTION_INDICES];
rewards[0] = 100.0; rewards[1] = 0.0; rewards[ACTION_IDX_ALL_IN] = 0.0; matcher.update_regret(&rewards);
let picker = ActionPicker::new(&mapper, &actions, Some(&matcher), &game_state);
let mut fold_count = 0;
for _ in 0..100 {
let picked = picker.pick_action(&mut rng);
if picked == AgentAction::Fold {
fold_count += 1;
}
}
// 100 draws, so the raw count doubles as a percentage in the message.
assert!(
fold_count > 50,
"Fold should be chosen more often when it has high weight, got {}%",
fold_count
);
}
// Without a regret matcher, the "best" action is the first candidate.
#[test]
fn test_pick_best_action_no_regret_matcher() {
let game_state = create_test_game_state();
let mapper = create_mapper();
let actions = vec![
AgentAction::Fold,
AgentAction::Bet(10.0),
AgentAction::AllIn,
];
let picker = ActionPicker::new(&mapper, &actions, None, &game_state);
let picked = picker.pick_best_action();
assert_eq!(picked, AgentAction::Fold);
}
// The all-in index gets the largest reward, so `pick_best_action` must
// return `AllIn`.
#[test]
fn test_pick_best_action_with_regret_matcher() {
let game_state = create_test_game_state();
let mapper = create_mapper();
let actions = vec![
AgentAction::Fold,
AgentAction::Bet(10.0),
AgentAction::AllIn,
];
let mut matcher = PcfrPlusRegretMatcher::new(NUM_ACTION_INDICES);
let mut rewards = vec![0.0; NUM_ACTION_INDICES];
rewards[0] = 10.0; rewards[1] = 20.0; rewards[ACTION_IDX_ALL_IN] = 100.0; matcher.update_regret(&rewards);
let picker = ActionPicker::new(&mapper, &actions, Some(&matcher), &game_state);
let picked = picker.pick_best_action();
assert_eq!(picked, AgentAction::AllIn);
}
// Large rewards sit on index 0 and the all-in index, but neither Fold nor
// AllIn is a candidate: the only candidate, `Bet(10.0)`, must be returned.
#[test]
fn test_filters_to_valid_actions_only() {
let game_state = create_test_game_state();
let mapper = create_mapper();
let actions = vec![AgentAction::Bet(10.0)];
let mut matcher = PcfrPlusRegretMatcher::new(NUM_ACTION_INDICES);
let mut rewards = vec![0.0; NUM_ACTION_INDICES];
rewards[0] = 1000.0; rewards[1] = 1.0; rewards[ACTION_IDX_ALL_IN] = 1000.0; matcher.update_regret(&rewards);
let picker = ActionPicker::new(&mapper, &actions, Some(&matcher), &game_state);
let mut rng = create_seeded_rng();
let picked = picker.pick_action(&mut rng);
assert_eq!(picked, AgentAction::Bet(10.0));
}
// A freshly constructed matcher (no regret updates) must still yield one of
// the candidates.
#[test]
fn test_handles_zero_weights() {
let game_state = create_test_game_state();
let mapper = create_mapper();
let actions = vec![
AgentAction::Fold,
AgentAction::Bet(10.0),
AgentAction::AllIn,
];
let mut rng = create_seeded_rng();
let matcher = PcfrPlusRegretMatcher::new(NUM_ACTION_INDICES);
let picker = ActionPicker::new(&mapper, &actions, Some(&matcher), &game_state);
let picked = picker.pick_action(&mut rng);
assert!(
actions.contains(&picked),
"Picked action {:?} should be in valid actions",
picked
);
}
// `pick_best_action` takes no RNG: repeated calls must return the same
// action (fold, the only index with a positive reward).
#[test]
fn test_pick_best_action_deterministic() {
let game_state = create_test_game_state();
let mapper = create_mapper();
let actions = vec![
AgentAction::Fold,
AgentAction::Bet(50.0),
AgentAction::AllIn,
];
let mut matcher = PcfrPlusRegretMatcher::new(NUM_ACTION_INDICES);
let mut rewards = vec![0.0; NUM_ACTION_INDICES];
rewards[0] = 1000.0; matcher.update_regret(&rewards);
let picker = ActionPicker::new(&mapper, &actions, Some(&matcher), &game_state);
for _ in 0..10 {
let picked = picker.pick_best_action();
assert_eq!(
picked,
AgentAction::Fold,
"Best action should always be fold"
);
}
}
// Bets of 15, 50 and 90 must each map into index range 2..=50, and the index
// must grow strictly with the bet size.
#[test]
fn test_different_bet_amounts_map_correctly() {
let game_state = create_test_game_state();
let mapper = create_mapper();
let small_bet = AgentAction::Bet(15.0);
let medium_bet = AgentAction::Bet(50.0);
let large_bet = AgentAction::Bet(90.0);
let small_idx = mapper.action_to_idx(&small_bet, &game_state);
let medium_idx = mapper.action_to_idx(&medium_bet, &game_state);
let large_idx = mapper.action_to_idx(&large_bet, &game_state);
assert!(
(2..=50).contains(&small_idx),
"Small bet index {} should be in range 2-50",
small_idx
);
assert!(
(2..=50).contains(&medium_idx),
"Medium bet index {} should be in range 2-50",
medium_idx
);
assert!(
(2..=50).contains(&large_idx),
"Large bet index {} should be in range 2-50",
large_idx
);
assert!(
small_idx < medium_idx,
"Small bet index {} should be less than medium {}",
small_idx,
medium_idx
);
assert!(
medium_idx < large_idx,
"Medium bet index {} should be less than large {}",
medium_idx,
large_idx
);
}
// Fold and AllIn must map to their dedicated index constants.
#[test]
fn test_fold_and_allin_always_same_index() {
use crate::arena::cfr::{ACTION_IDX_ALL_IN, ACTION_IDX_FOLD};
let game_state = create_test_game_state();
let mapper = create_mapper();
let fold_idx = mapper.action_to_idx(&AgentAction::Fold, &game_state);
assert_eq!(
fold_idx, ACTION_IDX_FOLD,
"Fold should always map to index 0"
);
let allin_idx = mapper.action_to_idx(&AgentAction::AllIn, &game_state);
assert_eq!(
allin_idx, ACTION_IDX_ALL_IN,
"AllIn should always map to ACTION_IDX_ALL_IN"
);
}
// With candidates Fold and Bet(10.0), where the bet is rewarded and fold is
// penalised, the bet must be sampled in more than 60 of 100 draws.
#[test]
fn test_pick_action_with_only_two_valid_actions() {
let game_state = create_test_game_state();
let mapper = create_mapper();
let actions = vec![AgentAction::Fold, AgentAction::Bet(10.0)];
let mut matcher = PcfrPlusRegretMatcher::new(NUM_ACTION_INDICES);
let mut rewards = vec![0.0; NUM_ACTION_INDICES];
// NOTE(review): index 1 and `bet_idx` both receive 50.0; if Bet(10.0) maps to
// index 1 here, one of the two assignments is redundant.
rewards[0] = -50.0; rewards[1] = 50.0;
let bet_idx = mapper.action_to_idx(&AgentAction::Bet(10.0), &game_state);
rewards[bet_idx] = 50.0;
matcher.update_regret(&rewards);
let picker = ActionPicker::new(&mapper, &actions, Some(&matcher), &game_state);
let mut rng = create_seeded_rng();
let mut bet_count = 0;
for _ in 0..100 {
let picked = picker.pick_action(&mut rng);
if matches!(picked, AgentAction::Bet(_)) {
bet_count += 1;
}
}
assert!(
bet_count > 60,
"Bet should be chosen more often when it has higher weight, got {}%",
bet_count
);
}
// All three candidates receive the same reward; the first candidate (Fold)
// must win the tie.
#[test]
fn test_pick_best_handles_ties() {
let game_state = create_test_game_state();
let mapper = create_mapper();
let actions = vec![
AgentAction::Fold,
AgentAction::Bet(50.0),
AgentAction::AllIn,
];
let mut matcher = PcfrPlusRegretMatcher::new(NUM_ACTION_INDICES);
let mut rewards = vec![0.0; NUM_ACTION_INDICES];
rewards[0] = 10.0; let bet_idx = mapper.action_to_idx(&AgentAction::Bet(50.0), &game_state);
rewards[bet_idx] = 10.0; rewards[ACTION_IDX_ALL_IN] = 10.0; matcher.update_regret(&rewards);
let picker = ActionPicker::new(&mapper, &actions, Some(&matcher), &game_state);
let picked = picker.pick_best_action();
assert_eq!(
picked,
AgentAction::Fold,
"On ties, should return first action with highest weight"
);
}
// Two bet sizes that share one action index must count as a single option.
// With equal rewards on that index and on Bet(10.0)'s index, bets above 50
// should make up roughly half of the samples rather than two thirds.
#[test]
fn test_pick_action_dedupes_index_collisions() {
let game_state = create_test_game_state();
let mapper = create_mapper();
let bet_a = AgentAction::Bet(60.0);
let bet_b = AgentAction::Bet(63.0);
let bet_a_idx = mapper.action_to_idx(&bet_a, &game_state);
let bet_b_idx = mapper.action_to_idx(&bet_b, &game_state);
// Precondition of the test: the two sizes really do collide.
assert_eq!(
bet_a_idx, bet_b_idx,
"test setup: pick two bet sizes that collide on a raise slot"
);
let call_idx = mapper.action_to_idx(&AgentAction::Bet(10.0), &game_state);
let actions = vec![AgentAction::Bet(10.0), bet_a.clone(), bet_b.clone()];
let mut matcher = PcfrPlusRegretMatcher::new(NUM_ACTION_INDICES);
let mut rewards = vec![0.0; NUM_ACTION_INDICES];
rewards[call_idx] = 50.0;
rewards[bet_a_idx] = 50.0;
matcher.update_regret(&rewards);
let picker = ActionPicker::new(&mapper, &actions, Some(&matcher), &game_state);
let mut rng = create_seeded_rng();
let mut raise_count = 0;
let iterations = 2000;
for _ in 0..iterations {
if matches!(picker.pick_action(&mut rng), AgentAction::Bet(x) if x > 50.0) {
raise_count += 1;
}
}
// Integer percentage of samples that were one of the colliding bets.
let raise_pct = (raise_count * 100) / iterations;
assert!(
(40..=60).contains(&raise_pct),
"raise picked {raise_pct}% of the time; expected ~50% after dedupe (2000 samples)"
);
}
// A lone candidate must be returned even though its own index was given a
// negative reward and a non-candidate index a large positive one.
#[test]
fn test_single_action_always_picked() {
let game_state = create_test_game_state();
let mapper = create_mapper();
let mut rng = create_seeded_rng();
let actions = vec![AgentAction::AllIn];
let mut matcher = PcfrPlusRegretMatcher::new(NUM_ACTION_INDICES);
let mut rewards = vec![0.0; NUM_ACTION_INDICES];
rewards[0] = 1000.0; rewards[ACTION_IDX_ALL_IN] = -1000.0; matcher.update_regret(&rewards);
let picker = ActionPicker::new(&mapper, &actions, Some(&matcher), &game_state);
for _ in 0..10 {
let picked = picker.pick_action(&mut rng);
assert_eq!(
picked,
AgentAction::AllIn,
"Must pick the only available action"
);
}
}
// After 100 identical updates in which Bet(10.0)'s index earns 900, fold
// earns 0 and every other index earns -100, the bet must be sampled in more
// than 900 of 1000 draws.
#[test]
fn test_convergence_with_clear_winner() {
let game_state = create_test_game_state();
let mapper = create_mapper();
let actions = vec![AgentAction::Fold, AgentAction::Bet(10.0)];
let mut matcher = PcfrPlusRegretMatcher::new(NUM_ACTION_INDICES);
let call_reward = 900.0;
let fold_reward = 0.0;
let invalid_penalty = -100.0;
let bet_idx = mapper.action_to_idx(&AgentAction::Bet(10.0), &game_state);
for _ in 0..100 {
let mut rewards = vec![invalid_penalty; NUM_ACTION_INDICES];
rewards[0] = fold_reward;
rewards[bet_idx] = call_reward;
matcher.update_regret(&rewards);
}
let picker = ActionPicker::new(&mapper, &actions, Some(&matcher), &game_state);
let mut rng = create_seeded_rng();
let mut call_count = 0;
for _ in 0..1000 {
let picked = picker.pick_action(&mut rng);
if matches!(picked, AgentAction::Bet(_)) {
call_count += 1;
}
}
// 1000 draws, so count / 10 is the percentage shown in the message.
assert!(
call_count > 900,
"Call should be chosen >90% of the time with clear reward advantage, got {}%",
call_count / 10
);
}
// With identical rewards for fold and Bet(10.0) over 100 updates, the bet
// share of 1000 draws must fall between 30% and 70%.
#[test]
fn test_convergence_equal_rewards() {
let game_state = create_test_game_state();
let mapper = create_mapper();
let actions = vec![AgentAction::Fold, AgentAction::Bet(10.0)];
let mut matcher = PcfrPlusRegretMatcher::new(NUM_ACTION_INDICES);
let bet_idx = mapper.action_to_idx(&AgentAction::Bet(10.0), &game_state);
for _ in 0..100 {
let mut rewards = vec![-100.0; NUM_ACTION_INDICES];
rewards[0] = 50.0;
rewards[bet_idx] = 50.0;
matcher.update_regret(&rewards);
}
let picker = ActionPicker::new(&mapper, &actions, Some(&matcher), &game_state);
let mut rng = create_seeded_rng();
let mut call_count = 0;
for _ in 0..1000 {
let picked = picker.pick_action(&mut rng);
if matches!(picked, AgentAction::Bet(_)) {
call_count += 1;
}
}
assert!(
(300..=700).contains(&call_count),
"With equal rewards, should be roughly 50/50, got {}%",
call_count / 10
);
}
// Applies the same reward vector ten times, printing the matcher's weights
// after each update, then checks the final weights: index 1 must outweigh
// index 0 by more than 2x, and the indices in
// ACTION_IDX_RAISE_MIN..ACTION_IDX_ALL_IN must sum to less than 0.01.
#[test]
fn test_debug_weights_after_iterations() {
let game_state = create_test_game_state();
let mapper = create_mapper();
let mut matcher = PcfrPlusRegretMatcher::new(NUM_ACTION_INDICES);
// NOTE(review): `bet_idx` is only printed; the rewards below are written to
// the hard-coded indices 0 and 1.
let bet_idx = mapper.action_to_idx(&AgentAction::Bet(10.0), &game_state);
println!("bet_idx for Bet(10.0) = {}", bet_idx);
for i in 0..10 {
let mut rewards = vec![-100.0; NUM_ACTION_INDICES];
rewards[0] = 0.0; rewards[1] = 900.0;
matcher.update_regret(&rewards);
let weights = matcher.best_weight();
let fold_weight = weights[0];
let call_weight = weights[1];
let invalid_weight: f32 = weights[ACTION_IDX_RAISE_MIN..ACTION_IDX_ALL_IN]
.iter()
.sum();
let allin_weight = weights[ACTION_IDX_ALL_IN];
println!(
"Iteration {}: fold={:.4}, call={:.4}, invalid_sum={:.4}, allin={:.4}",
i + 1,
fold_weight,
call_weight,
invalid_weight,
allin_weight
);
}
let final_weights = matcher.best_weight();
let fold_weight = final_weights[0];
let call_weight = final_weights[1];
assert!(
call_weight > fold_weight * 2.0,
"Call weight ({}) should be much higher than fold weight ({})",
call_weight,
fold_weight
);
let invalid_weight: f32 = final_weights[ACTION_IDX_RAISE_MIN..ACTION_IDX_ALL_IN]
.iter()
.sum();
assert!(
invalid_weight < 0.01,
"Invalid actions should have near-zero total weight, got {}",
invalid_weight
);
}
// Checks index mapping in a hand-built river state with stacks [0.0, 900.0]:
// Fold, Bet(900.0) and AllIn must map to the fold, call and all-in indices.
#[test]
fn test_all_in_scenario_action_mapping() {
use crate::arena::GameStateBuilder;
use crate::arena::cfr::{ACTION_IDX_ALL_IN, ACTION_IDX_CALL, ACTION_IDX_FOLD};
use crate::arena::game_state::{Round, RoundData};
use crate::core::PlayerBitSet;
let stacks: Vec<f32> = vec![0.0, 900.0];
let player_bet = vec![1000.0, 100.0];
let player_bet_round = vec![900.0, 0.0];
let round_data = RoundData::new_with_bets(100.0, PlayerBitSet::new(2), 1, player_bet_round);
let game_state = GameStateBuilder::new()
.round(Round::River)
.round_data(round_data)
.stacks(stacks)
.player_bet(player_bet)
.big_blind(5.0)
.small_blind(0.0)
.build()
.unwrap();
// Unlike the other tests, the mapper is derived from the game state itself.
let mapper = ActionIndexMapper::from_game_state(&game_state);
let current_bet = game_state.current_round_bet();
println!("current_round_bet = {}", current_bet);
let fold_idx = mapper.action_to_idx(&AgentAction::Fold, &game_state);
assert_eq!(fold_idx, ACTION_IDX_FOLD, "Fold should map to index 0");
let call_idx = mapper.action_to_idx(&AgentAction::Bet(900.0), &game_state);
println!("Bet(900.0) maps to index {}", call_idx);
assert_eq!(
call_idx, ACTION_IDX_CALL,
"Bet(900.0) matching current_round_bet should map to CALL (index 1)"
);
let allin_idx = mapper.action_to_idx(&AgentAction::AllIn, &game_state);
assert_eq!(
allin_idx, ACTION_IDX_ALL_IN,
"AllIn should map to ACTION_IDX_ALL_IN"
);
// NOTE(review): the two values below are printed for inspection only; nothing
// is asserted about `player_allin` or the index of Bet(1000.0).
let player_allin =
game_state.current_round_current_player_bet() + game_state.current_player_stack();
println!("Player 1 all-in amount = {}", player_allin);
let bet_1000_idx = mapper.action_to_idx(&AgentAction::Bet(1000.0), &game_state);
println!("Bet(1000.0) maps to index {}", bet_1000_idx);
}
}