use approx::abs_diff_eq;
use crate::arena::{GameState, action::AgentAction};
/// Declared size of the action index space.
///
/// NOTE(review): the constants below only name indices 0..=14, and no mapping visible in
/// this file produces index 15 — confirm what that slot is reserved for.
pub const NUM_ACTION_INDICES: usize = 16;
/// Index returned for `AgentAction::Fold`.
pub const ACTION_IDX_FOLD: usize = 0;
/// Index returned for `AgentAction::Call`, and for a `Bet` whose amount equals the current round bet.
pub const ACTION_IDX_CALL: usize = 1;
/// First raise bucket; bets at or below the configured `min_bet` map here.
pub const ACTION_IDX_RAISE_MIN: usize = 2;
/// Last raise bucket; bets at or above the configured `max_bet` map here.
pub const ACTION_IDX_RAISE_MAX: usize = 13;
/// Index returned for `AgentAction::AllIn`, and for a `Bet` equal to the player's round bet plus stack.
pub const ACTION_IDX_ALL_IN: usize = 14;
/// Raise-size range that `ActionIndexMapper` uses to bucket bet amounts.
///
/// The values are stored exactly as supplied; nothing in this type validates that
/// `min_bet` is positive or that `min_bet <= max_bet`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ActionIndexMapperConfig {
/// Lower end of the raise range; bets at or below it map to `ACTION_IDX_RAISE_MIN`.
pub min_bet: f32,
/// Upper end of the raise range; bets at or above it map to `ACTION_IDX_RAISE_MAX`.
pub max_bet: f32,
}
impl ActionIndexMapperConfig {
pub fn new(min_bet: f32, max_bet: f32) -> Self {
Self { min_bet, max_bet }
}
pub fn from_game_state(game_state: &GameState) -> Self {
let (min_bet, max_bet) = compute_effective_range(game_state);
Self::new(min_bet, max_bet)
}
}
/// Computes the `(min_bet, max_bet)` range used to bucket raise sizes.
///
/// * `min_bet` is `game_state.big_blind`.
/// * `max_bet` is the second-largest entry of `game_state.starting_stacks`. When only
///   one stack exists that stack is used, and with no stacks at all the fallback is
///   `100 * min_bet`.
///
/// The returned `max_bet` is never smaller than `2 * min_bet`.
pub fn compute_effective_range(game_state: &GameState) -> (f32, f32) {
    let min_bet = game_state.big_blind;

    let mut stacks: Vec<f32> = game_state.starting_stacks.to_vec();
    // Largest first. `total_cmp` is a real total order, so a NaN stack cannot break the
    // sort's contract (a `partial_cmp` comparator that maps NaN to `Equal` is not a
    // total order, and the standard sort is allowed to panic on such comparators).
    // Elements that compare equal under `total_cmp` are bit-identical, so an unstable
    // sort is indistinguishable from a stable one here.
    stacks.sort_unstable_by(|a, b| b.total_cmp(a));

    let max_bet = match stacks.as_slice() {
        [_, second, ..] => *second,
        [only] => *only,
        [] => min_bet * 100.0,
    };

    // Keep the range non-degenerate: the upper bound is at least two minimum bets.
    // `f32::max` ignores a NaN operand, so a NaN candidate falls back to `2 * min_bet`.
    (min_bet, max_bet.max(min_bet * 2.0))
}
/// Translates `AgentAction` values into fixed action indices, and raise indices back
/// into bet sizes, using the raise range held in `config`.
#[derive(Debug, Clone)]
pub struct ActionIndexMapper {
/// Raise range (`min_bet` / `max_bet`) used when bucketing bet amounts.
config: ActionIndexMapperConfig,
}
impl ActionIndexMapper {
pub fn new(config: ActionIndexMapperConfig) -> Self {
Self { config }
}
pub fn from_game_state(game_state: &GameState) -> Self {
Self::new(ActionIndexMapperConfig::from_game_state(game_state))
}
pub fn config(&self) -> &ActionIndexMapperConfig {
&self.config
}
pub fn action_to_idx(&self, action: &AgentAction, game_state: &GameState) -> usize {
self.action_to_idx_raw(
action,
game_state.current_round_bet(),
game_state.current_round_current_player_bet(),
game_state.current_player_stack(),
)
}
pub fn action_to_idx_raw(
&self,
action: &AgentAction,
current_round_bet: f32,
current_player_bet: f32,
current_player_stack: f32,
) -> usize {
match action {
AgentAction::Fold => ACTION_IDX_FOLD,
AgentAction::Call => ACTION_IDX_CALL,
AgentAction::AllIn => ACTION_IDX_ALL_IN,
AgentAction::Bet(amount) => {
if abs_diff_eq!(*amount, current_round_bet) {
return ACTION_IDX_CALL;
}
let all_in_amount = current_player_bet + current_player_stack;
if abs_diff_eq!(*amount, all_in_amount) {
return ACTION_IDX_ALL_IN;
}
self.bet_to_index(*amount)
}
}
}
fn bet_to_index(&self, bet: f32) -> usize {
let min_bet = self.config.min_bet;
let max_bet = self.config.max_bet;
if bet <= min_bet {
return ACTION_IDX_RAISE_MIN;
}
if bet >= max_bet {
return ACTION_IDX_RAISE_MAX;
}
let log_min = min_bet.ln();
let log_max = max_bet.ln();
let log_bet = bet.ln();
let fraction = (log_bet - log_min) / (log_max - log_min);
let num_slots = (ACTION_IDX_RAISE_MAX - ACTION_IDX_RAISE_MIN) as f32;
let index = ACTION_IDX_RAISE_MIN + (fraction * num_slots).round() as usize;
index.clamp(ACTION_IDX_RAISE_MIN, ACTION_IDX_RAISE_MAX)
}
pub fn index_to_bet(&self, index: usize) -> Option<f32> {
match index {
ACTION_IDX_FOLD | ACTION_IDX_CALL | ACTION_IDX_ALL_IN => None,
idx if (ACTION_IDX_RAISE_MIN..=ACTION_IDX_RAISE_MAX).contains(&idx) => {
let min_bet = self.config.min_bet;
let max_bet = self.config.max_bet;
let log_min = min_bet.ln();
let log_max = max_bet.ln();
let num_slots = (ACTION_IDX_RAISE_MAX - ACTION_IDX_RAISE_MIN) as f32;
let fraction = (idx - ACTION_IDX_RAISE_MIN) as f32 / num_slots;
let log_bet = log_min + fraction * (log_max - log_min);
Some(log_bet.exp())
}
_ => None,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arena::GameStateBuilder;

    /// Two-player game in which both players start with `stack`; blinds are 10 / 5.
    fn heads_up_game(stack: f32) -> GameState {
        GameStateBuilder::new()
            .num_players_with_stack(2, stack)
            .blinds(10.0, 5.0)
            .build()
            .unwrap()
    }

    /// Game with explicit starting stacks and blinds.
    fn game_with_stacks(stacks: Vec<f32>, big_blind: f32, small_blind: f32) -> GameState {
        GameStateBuilder::new()
            .stacks(stacks)
            .blinds(big_blind, small_blind)
            .build()
            .unwrap()
    }

    /// Mapper with an explicit raise range.
    fn mapper_for(min_bet: f32, max_bet: f32) -> ActionIndexMapper {
        ActionIndexMapper::new(ActionIndexMapperConfig::new(min_bet, max_bet))
    }

    /// Whether `idx` lies inside the raise bucket range.
    fn is_raise_index(idx: usize) -> bool {
        (ACTION_IDX_RAISE_MIN..=ACTION_IDX_RAISE_MAX).contains(&idx)
    }

    #[test]
    fn test_fold_always_maps_to_zero() {
        let state = heads_up_game(100.0);
        let idx = mapper_for(10.0, 100.0).action_to_idx(&AgentAction::Fold, &state);
        assert_eq!(idx, ACTION_IDX_FOLD);
    }

    #[test]
    fn test_call_always_maps_to_one() {
        let state = heads_up_game(100.0);
        let idx = mapper_for(10.0, 100.0).action_to_idx(&AgentAction::Call, &state);
        assert_eq!(idx, ACTION_IDX_CALL);
    }

    #[test]
    fn test_all_in_always_maps_to_all_in_idx() {
        let state = heads_up_game(100.0);
        let idx = mapper_for(10.0, 100.0).action_to_idx(&AgentAction::AllIn, &state);
        assert_eq!(idx, ACTION_IDX_ALL_IN);
    }

    #[test]
    fn test_bet_matching_current_bet_maps_to_call() {
        let state = heads_up_game(100.0);
        let to_call = state.current_round_bet();
        let idx = mapper_for(10.0, 100.0).action_to_idx(&AgentAction::Bet(to_call), &state);
        assert_eq!(idx, ACTION_IDX_CALL);
    }

    #[test]
    fn test_bet_matching_all_in_maps_to_all_in_idx() {
        let state = heads_up_game(100.0);
        let shove = state.current_round_current_player_bet() + state.current_player_stack();
        let idx = mapper_for(10.0, 100.0).action_to_idx(&AgentAction::Bet(shove), &state);
        assert_eq!(idx, ACTION_IDX_ALL_IN);
    }

    #[test]
    fn test_raises_spread_across_indices() {
        let mapper = mapper_for(10.0, 1000.0);
        let state = heads_up_game(100.0);

        let small = mapper.action_to_idx(&AgentAction::Bet(20.0), &state);
        assert!(is_raise_index(small));

        let large = mapper.action_to_idx(&AgentAction::Bet(500.0), &state);
        assert!(is_raise_index(large));

        assert!(large > small);
    }

    #[test]
    fn test_min_bet_maps_to_raise_min() {
        let state = heads_up_game(100.0);
        let idx = mapper_for(10.0, 100.0).action_to_idx(&AgentAction::Bet(10.0), &state);
        assert_eq!(idx, ACTION_IDX_RAISE_MIN);
    }

    #[test]
    fn test_bet_below_min_maps_to_raise_min() {
        let state = heads_up_game(100.0);
        let idx = mapper_for(10.0, 100.0).action_to_idx(&AgentAction::Bet(5.0), &state);
        assert_eq!(idx, ACTION_IDX_RAISE_MIN);
    }

    #[test]
    fn test_bet_at_max_maps_to_raise_max() {
        // Deeper stacks so that a bet of 100 is not the player's whole stack.
        let state = heads_up_game(200.0);
        let idx = mapper_for(10.0, 100.0).action_to_idx(&AgentAction::Bet(100.0), &state);
        assert_eq!(idx, ACTION_IDX_RAISE_MAX);
    }

    #[test]
    fn test_bet_above_max_maps_to_raise_max() {
        let state = heads_up_game(200.0);
        let idx = mapper_for(10.0, 100.0).action_to_idx(&AgentAction::Bet(150.0), &state);
        assert_eq!(idx, ACTION_IDX_RAISE_MAX);
    }

    #[test]
    fn test_log_distribution_midpoint() {
        let mapper = mapper_for(10.0, 1000.0);
        // 100 is the geometric mean of 10 and 1000.
        let midpoint_idx = mapper.bet_to_index(100.0);
        let mid_idx = (ACTION_IDX_RAISE_MIN + ACTION_IDX_RAISE_MAX) / 2;
        let gap = (midpoint_idx as i32 - mid_idx as i32).abs();
        assert!(
            gap <= 1,
            "Geometric mean should map to middle index, got {} expected ~{}",
            midpoint_idx,
            mid_idx
        );
    }

    #[test]
    fn test_index_to_bet_roundtrip() {
        let mapper = mapper_for(10.0, 1000.0);
        for idx in ACTION_IDX_RAISE_MIN..=ACTION_IDX_RAISE_MAX {
            let bet = mapper.index_to_bet(idx).unwrap();
            let recovered_idx = mapper.bet_to_index(bet);
            assert_eq!(
                idx, recovered_idx,
                "Index {} -> bet {} -> index {}",
                idx, bet, recovered_idx
            );
        }
    }

    #[test]
    fn test_compute_effective_range() {
        let state = game_with_stacks(vec![100.0, 200.0, 150.0], 10.0, 5.0);
        let range = compute_effective_range(&state);
        assert_eq!(range.0, 10.0);
        assert_eq!(range.1, 150.0);
    }

    #[test]
    fn test_compute_effective_range_two_players() {
        let state = game_with_stacks(vec![100.0, 200.0], 10.0, 5.0);
        let range = compute_effective_range(&state);
        assert_eq!(range.0, 10.0);
        assert_eq!(range.1, 100.0);
    }

    #[test]
    fn test_config_from_game_state() {
        let state = game_with_stacks(vec![100.0, 200.0], 10.0, 5.0);
        let cfg = ActionIndexMapperConfig::from_game_state(&state);
        assert_eq!(cfg.min_bet, 10.0);
        assert_eq!(cfg.max_bet, 100.0);
    }

    #[test]
    fn test_mapper_from_game_state() {
        let state = game_with_stacks(vec![100.0, 200.0], 10.0, 5.0);
        let mapper = ActionIndexMapper::from_game_state(&state);
        let cfg = mapper.config();
        assert_eq!(cfg.min_bet, 10.0);
        assert_eq!(cfg.max_bet, 100.0);
    }

    #[test]
    fn test_small_pot_edge_case() {
        let state = game_with_stacks(vec![10.0, 10.0], 1.0, 0.5);
        let mapper = ActionIndexMapper::from_game_state(&state);
        let idx = mapper.action_to_idx(&AgentAction::Bet(5.0), &state);
        assert!(is_raise_index(idx));
    }

    #[test]
    fn test_all_in_close_to_min_raise() {
        let state = game_with_stacks(vec![12.0, 100.0], 10.0, 5.0);
        let mapper = ActionIndexMapper::from_game_state(&state);

        let explicit = mapper.action_to_idx(&AgentAction::AllIn, &state);
        assert_eq!(explicit, ACTION_IDX_ALL_IN);

        let shove = state.current_round_current_player_bet() + state.current_player_stack();
        let sized = mapper.action_to_idx(&AgentAction::Bet(shove), &state);
        assert_eq!(sized, ACTION_IDX_ALL_IN);
    }

    #[test]
    fn test_index_to_bet_returns_none_for_special_indices() {
        let mapper = mapper_for(10.0, 100.0);
        for special in [ACTION_IDX_FOLD, ACTION_IDX_CALL, ACTION_IDX_ALL_IN] {
            assert!(mapper.index_to_bet(special).is_none());
        }
    }

    #[test]
    fn test_index_to_bet_returns_none_for_out_of_range() {
        let mapper = mapper_for(10.0, 100.0);
        for outside in [16usize, 100] {
            assert!(mapper.index_to_bet(outside).is_none());
        }
    }

    #[test]
    fn test_num_action_indices_constant() {
        assert_eq!(NUM_ACTION_INDICES, 16);
    }
}