use std::sync::Arc;
use thiserror::Error;
use crate::arena::{GameState, action::AgentAction, game_state::Round};
use super::super::{CFRState, TraversalState};
use super::{ActionGenerator, ActionVec};
#[derive(Debug, Error, PartialEq)]
pub enum ConfigurableActionConfigError {
#[error("raise_mult must be >= 1.0 (cannot raise less than min raise), got {0}")]
RaiseMultBelowOne(f32),
#[error("pot_mult must be non-negative, got {0}")]
PotMultNegative(f32),
}
/// Betting options offered to the acting player during a single round.
///
/// `current_bet` below is the highest bet in the current round, `min_raise`
/// the round's minimum raise, and `pot` the total pot.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct RoundActionConfig {
    /// Offer a bet equal to the current round bet (i.e. matching what is owed).
    pub call_enabled: bool,
    /// Multipliers of the minimum raise; each proposes `current_bet + min_raise * mult`.
    pub raise_mult: Vec<f32>,
    /// Multipliers of the total pot; each proposes `current_bet + pot * mult`.
    pub pot_mult: Vec<f32>,
    /// Propose a bet of `(stack + player_bet + current_bet - pot) / 2`.
    pub setup_shove: bool,
    /// Offer going all-in.
    pub all_in: bool,
}
impl Default for RoundActionConfig {
    /// Call and all-in enabled, min-raise multiples of 1.0 and 2.0,
    /// pot multiples of 0.5 and 1.0, and no setup shove.
    fn default() -> Self {
        let raise_mult = Vec::from([1.0, 2.0]);
        let pot_mult = Vec::from([0.5, 1.0]);
        Self {
            all_in: true,
            call_enabled: true,
            setup_shove: false,
            raise_mult,
            pot_mult,
        }
    }
}
impl RoundActionConfig {
    /// Checks that every multiplier in this round config is usable.
    ///
    /// `raise_mult` entries are checked first, then `pot_mult` entries; the
    /// first offending value is reported.
    ///
    /// # Errors
    ///
    /// * [`ConfigurableActionConfigError::RaiseMultBelowOne`] if a `raise_mult`
    ///   entry is below 1.0 or NaN.
    /// * [`ConfigurableActionConfigError::PotMultNegative`] if a `pot_mult`
    ///   entry is negative or NaN.
    pub fn validate(&self) -> Result<(), ConfigurableActionConfigError> {
        // Every comparison against NaN is false, so a plain `m < 1.0` would let
        // NaN through; reject it explicitly to uphold the documented bounds.
        if let Some(&mult) = self
            .raise_mult
            .iter()
            .find(|&&m| m.is_nan() || m < 1.0)
        {
            return Err(ConfigurableActionConfigError::RaiseMultBelowOne(mult));
        }
        if let Some(&mult) = self
            .pot_mult
            .iter()
            .find(|&&m| m.is_nan() || m < 0.0)
        {
            return Err(ConfigurableActionConfigError::PotMultNegative(mult));
        }
        Ok(())
    }
}
/// Action configuration with a fallback `default` plus optional per-street
/// overrides.
///
/// With the `serde` feature, a missing `default` deserializes to
/// `RoundActionConfig::default()`, and unset overrides are omitted when
/// serializing.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct ConfigurableActionConfig {
    /// Used for any round without an override of its own.
    #[cfg_attr(feature = "serde", serde(default))]
    pub default: RoundActionConfig,
    /// Override for the preflop street.
    #[cfg_attr(feature = "serde", serde(default, skip_serializing_if = "Option::is_none"))]
    pub preflop: Option<RoundActionConfig>,
    /// Override for the flop street.
    #[cfg_attr(feature = "serde", serde(default, skip_serializing_if = "Option::is_none"))]
    pub flop: Option<RoundActionConfig>,
    /// Override for the turn street.
    #[cfg_attr(feature = "serde", serde(default, skip_serializing_if = "Option::is_none"))]
    pub turn: Option<RoundActionConfig>,
    /// Override for the river street.
    #[cfg_attr(feature = "serde", serde(default, skip_serializing_if = "Option::is_none"))]
    pub river: Option<RoundActionConfig>,
}
impl ConfigurableActionConfig {
pub fn round_config(&self, round: Round) -> &RoundActionConfig {
match round {
Round::DealPreflop | Round::Preflop => self.preflop.as_ref().unwrap_or(&self.default),
Round::DealFlop | Round::Flop => self.flop.as_ref().unwrap_or(&self.default),
Round::DealTurn | Round::Turn => self.turn.as_ref().unwrap_or(&self.default),
Round::DealRiver | Round::River => self.river.as_ref().unwrap_or(&self.default),
Round::Starting | Round::Ante | Round::Showdown | Round::Complete => &self.default,
}
}
pub fn validate(&self) -> Result<(), ConfigurableActionConfigError> {
self.default.validate()?;
if let Some(ref cfg) = self.preflop {
cfg.validate()?;
}
if let Some(ref cfg) = self.flop {
cfg.validate()?;
}
if let Some(ref cfg) = self.turn {
cfg.validate()?;
}
if let Some(ref cfg) = self.river {
cfg.validate()?;
}
Ok(())
}
}
/// Action generator whose candidate actions are driven by a
/// reference-counted [`ConfigurableActionConfig`].
pub struct ConfigurableActionGenerator {
    /// CFR state supplied at construction.
    cfr_state: CFRState,
    /// Traversal state supplied at construction.
    traversal_state: TraversalState,
    /// Action configuration; held in an `Arc` so it can be shared.
    config: Arc<ConfigurableActionConfig>,
}
impl ConfigurableActionGenerator {
    /// Creates a generator that takes ownership of `config`, wrapping it in
    /// an `Arc`.
    pub fn new_with_config(
        cfr_state: CFRState,
        traversal_state: TraversalState,
        config: ConfigurableActionConfig,
    ) -> Self {
        let config = Arc::new(config);
        Self {
            cfr_state,
            traversal_state,
            config,
        }
    }

    /// Builds the candidate actions for the player to act in `game_state`.
    ///
    /// The round config is chosen with
    /// [`ConfigurableActionConfig::round_config`]. `Bet` amounts are the
    /// player's total bet for the round, so `Bet(current_bet)` is the call.
    /// Actions are emitted in this order:
    ///
    /// 1. `Fold`, only when the current round bet exceeds the player's bet.
    /// 2. `Bet(current_bet)` when `call_enabled`.
    /// 3. Sized bets from `raise_mult`, then `pot_mult`, then the setup shove.
    ///    Each is kept only if it is at least `current_bet + min_raise`, is
    ///    strictly below the player's all-in total (`player_bet + stack`), and
    ///    is not within 0.01 of an amount already emitted (including the call).
    /// 4. `AllIn` when `all_in` is set and the all-in total exceeds the
    ///    current round bet.
    pub(crate) fn gen_actions_from_config(
        config: &ConfigurableActionConfig,
        game_state: &GameState,
    ) -> ActionVec {
        // Amounts closer than this are treated as the same bet size.
        const DEDUP_TOLERANCE: f32 = 0.01;

        let current_bet = game_state.current_round_bet();
        let player_bet = game_state.current_round_current_player_bet();
        let stack = game_state.current_player_stack();
        let pot = game_state.total_pot;
        let min_raise = game_state.current_round_min_raise();
        let round_config = config.round_config(game_state.round);

        let to_call = current_bet - player_bet;
        let all_in_amount = player_bet + stack;
        let min_valid_raise = current_bet + min_raise;

        let mut actions = ActionVec::new();
        let mut emitted: Vec<f32> = Vec::new();

        if to_call > 0.0 {
            actions.push(AgentAction::Fold);
        }

        if round_config.call_enabled {
            actions.push(AgentAction::Bet(current_bet));
            emitted.push(current_bet);
        }

        let raise_sizes = round_config
            .raise_mult
            .iter()
            .map(|&mult| current_bet + min_raise * mult);
        let pot_sizes = round_config
            .pot_mult
            .iter()
            .map(|&mult| current_bet + pot * mult);
        let setup_size = if round_config.setup_shove {
            Some((stack + player_bet + current_bet - pot) / 2.0)
        } else {
            None
        };

        for amount in raise_sizes.chain(pot_sizes).chain(setup_size) {
            if amount >= min_valid_raise
                && amount < all_in_amount
                && !emitted
                    .iter()
                    .any(|&prev| (prev - amount).abs() < DEDUP_TOLERANCE)
            {
                actions.push(AgentAction::Bet(amount));
                emitted.push(amount);
            }
        }

        if round_config.all_in && all_in_amount > current_bet {
            actions.push(AgentAction::AllIn);
        }

        actions
    }
}
impl ActionGenerator for ConfigurableActionGenerator {
    type Config = ConfigurableActionConfig;

    /// Creates a generator that shares the already reference-counted `config`.
    fn new(
        cfr_state: CFRState,
        traversal_state: TraversalState,
        config: Arc<Self::Config>,
    ) -> Self {
        Self {
            config,
            cfr_state,
            traversal_state,
        }
    }

    /// Borrows the configuration this generator was built with.
    fn config(&self) -> &Self::Config {
        self.config.as_ref()
    }

    /// Borrows the CFR state supplied at construction.
    fn cfr_state(&self) -> &CFRState {
        &self.cfr_state
    }

    /// Borrows the traversal state supplied at construction.
    fn traversal_state(&self) -> &TraversalState {
        &self.traversal_state
    }

    /// Delegates to `gen_actions_from_config` using this generator's config.
    fn gen_possible_actions(&self, game_state: &GameState) -> ActionVec {
        let config: &ConfigurableActionConfig = self.config.as_ref();
        Self::gen_actions_from_config(config, game_state)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arena::GameStateBuilder;

    /// Builds a generator rooted at `game_state` that uses `config`.
    fn create_configurable_generator(
        game_state: &GameState,
        config: ConfigurableActionConfig,
    ) -> ConfigurableActionGenerator {
        ConfigurableActionGenerator::new_with_config(
            CFRState::new(game_state.clone()),
            TraversalState::new_root(0),
            config,
        )
    }

    /// Explicit baseline config: call, 1x/2x min-raise, 0.5x/1x pot and
    /// all-in enabled, no setup shove, and no per-street overrides.
    fn default_configurable_config() -> ConfigurableActionConfig {
        ConfigurableActionConfig {
            default: RoundActionConfig {
                call_enabled: true,
                raise_mult: vec![1.0, 2.0],
                pot_mult: vec![0.5, 1.0],
                setup_shove: false,
                all_in: true,
            },
            preflop: None,
            flop: None,
            turn: None,
            river: None,
        }
    }

    #[test]
    fn test_validate_raise_mult_below_one() {
        let cfg = RoundActionConfig {
            raise_mult: vec![0.5],
            ..RoundActionConfig::default()
        };
        assert_eq!(
            cfg.validate().unwrap_err(),
            ConfigurableActionConfigError::RaiseMultBelowOne(0.5)
        );
    }

    #[test]
    fn test_validate_pot_mult_negative() {
        let cfg = RoundActionConfig {
            pot_mult: vec![-0.1],
            ..RoundActionConfig::default()
        };
        assert_eq!(
            cfg.validate().unwrap_err(),
            ConfigurableActionConfigError::PotMultNegative(-0.1)
        );
    }

    #[test]
    fn test_validate_nested_round_override() {
        let cfg = ConfigurableActionConfig {
            preflop: Some(RoundActionConfig {
                raise_mult: vec![0.9],
                ..RoundActionConfig::default()
            }),
            ..ConfigurableActionConfig::default()
        };
        assert!(matches!(
            cfg.validate().unwrap_err(),
            ConfigurableActionConfigError::RaiseMultBelowOne(_)
        ));
    }

    #[test]
    fn test_configurable_gen_actions_basic() {
        let stacks = vec![500.0; 2];
        let mut game_state = GameStateBuilder::new()
            .stacks(stacks)
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        game_state.advance_round();
        game_state.do_bet(30.0, false).unwrap();
        let action_gen = create_configurable_generator(&game_state, default_configurable_config());
        let actions = action_gen.gen_possible_actions(&game_state);
        assert!(actions.contains(&AgentAction::Fold));
        assert!(actions.iter().any(|a| matches!(a, AgentAction::Bet(_))));
        assert!(actions.contains(&AgentAction::AllIn));
    }

    #[test]
    fn test_configurable_no_fold_when_checking() {
        let stacks = vec![100.0; 2];
        let mut game_state = GameStateBuilder::new()
            .stacks(stacks)
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        game_state.advance_round();
        let action_gen = create_configurable_generator(&game_state, default_configurable_config());
        let actions = action_gen.gen_possible_actions(&game_state);
        assert!(!actions.contains(&AgentAction::Fold));
    }

    #[test]
    fn test_configurable_with_setup_shove() {
        let stacks = vec![500.0; 2];
        let mut game_state = GameStateBuilder::new()
            .stacks(stacks)
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        game_state.advance_round();
        let config = ConfigurableActionConfig {
            default: RoundActionConfig {
                call_enabled: true,
                raise_mult: vec![1.0],
                pot_mult: vec![0.5],
                setup_shove: true,
                all_in: true,
            },
            ..ConfigurableActionConfig::default()
        };
        let action_gen = create_configurable_generator(&game_state, config);
        let actions = action_gen.gen_possible_actions(&game_state);
        assert!(!actions.is_empty());
    }

    #[test]
    fn test_configurable_per_round_config() {
        let stacks = vec![500.0; 2];
        let game_state_preflop = GameStateBuilder::new()
            .stacks(stacks)
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        let config = ConfigurableActionConfig {
            default: RoundActionConfig {
                call_enabled: true,
                raise_mult: vec![1.0],
                pot_mult: vec![0.5, 1.0],
                setup_shove: false,
                all_in: true,
            },
            preflop: Some(RoundActionConfig {
                call_enabled: true,
                raise_mult: vec![2.0, 2.5, 3.0],
                pot_mult: vec![],
                setup_shove: false,
                all_in: true,
            }),
            ..ConfigurableActionConfig::default()
        };
        let action_gen = create_configurable_generator(&game_state_preflop, config);
        let preflop_actions = action_gen.gen_possible_actions(&game_state_preflop);
        assert!(!preflop_actions.is_empty());
    }

    /// Applies every generated action to a clone of `game_state` and asserts
    /// the game engine accepts each one.
    fn verify_configurable_actions_valid(game_state: &GameState, config: ConfigurableActionConfig) {
        let action_gen = create_configurable_generator(game_state, config);
        let actions = action_gen.gen_possible_actions(game_state);
        for action in &actions {
            let mut gs_copy = game_state.clone();
            let result = match action {
                AgentAction::Fold => {
                    gs_copy.fold();
                    Ok(())
                }
                AgentAction::Bet(amount) => gs_copy.do_bet(*amount, false).map(|_| ()),
                AgentAction::Call => {
                    let call_amount = gs_copy.current_round_bet();
                    gs_copy.do_bet(call_amount, false).map(|_| ())
                }
                AgentAction::AllIn => {
                    let all_in_amount =
                        gs_copy.current_round_current_player_bet() + gs_copy.current_player_stack();
                    gs_copy.do_bet(all_in_amount, false).map(|_| ())
                }
            };
            assert!(
                result.is_ok(),
                "Action {:?} should be valid but got error: {:?}\n\
                 Game state: current_bet={}, min_raise={}, player_bet={}, stack={}, pot={}",
                action,
                result.err(),
                game_state.current_round_bet(),
                game_state.current_round_min_raise(),
                game_state.current_round_current_player_bet(),
                game_state.current_player_stack(),
                game_state.total_pot
            );
        }
    }

    #[test]
    fn test_configurable_all_actions_valid_preflop() {
        let stacks = vec![100.0; 2];
        let game_state = GameStateBuilder::new()
            .stacks(stacks)
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        verify_configurable_actions_valid(&game_state, default_configurable_config());
    }

    #[test]
    fn test_configurable_all_actions_valid_flop() {
        let stacks = vec![100.0; 2];
        let mut game_state = GameStateBuilder::new()
            .stacks(stacks)
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        game_state.advance_round();
        verify_configurable_actions_valid(&game_state, default_configurable_config());
    }

    #[test]
    fn test_configurable_all_actions_valid_facing_bet() {
        let stacks = vec![500.0; 2];
        let mut game_state = GameStateBuilder::new()
            .stacks(stacks)
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        game_state.advance_round();
        game_state.do_bet(30.0, false).unwrap();
        verify_configurable_actions_valid(&game_state, default_configurable_config());
    }

    #[test]
    fn test_configurable_call_disabled() {
        let stacks = vec![500.0; 2];
        let mut game_state = GameStateBuilder::new()
            .stacks(stacks)
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        game_state.advance_round();
        game_state.do_bet(30.0, false).unwrap();
        let config = ConfigurableActionConfig {
            default: RoundActionConfig {
                call_enabled: false,
                raise_mult: vec![1.0],
                pot_mult: vec![],
                setup_shove: false,
                all_in: true,
            },
            ..ConfigurableActionConfig::default()
        };
        let action_gen = create_configurable_generator(&game_state, config);
        let actions = action_gen.gen_possible_actions(&game_state);
        assert!(!actions.is_empty());
        assert!(
            actions
                .iter()
                .any(|a| matches!(a, AgentAction::Bet(_) | AgentAction::AllIn))
        );
        // With calling disabled, no bet matching the current round bet may be offered.
        let current_bet = game_state.current_round_bet();
        assert!(
            !actions.contains(&AgentAction::Bet(current_bet)),
            "call should not be offered when call_enabled is false"
        );
    }
}