use std::sync::Arc;
use crate::arena::{GameState, action::AgentAction};
use super::super::{CFRState, TraversalState};
use super::{ActionGenerator, ActionVec};
/// Action generator that offers a small, fixed menu of actions for the player
/// to act: fold (only when facing a bet), check/call, a minimum raise, two
/// pot-fraction raises (about 33% and 66% of the pot) and all-in.
///
/// It has no configuration; it only carries the CFR and traversal state it was
/// constructed with.
pub struct SimpleActionGenerator {
    /// CFR state supplied at construction; returned by `ActionGenerator::cfr_state`.
    cfr_state: CFRState,
    /// Traversal state supplied at construction; returned by `ActionGenerator::traversal_state`.
    traversal_state: TraversalState,
}
impl SimpleActionGenerator {
pub fn new(cfr_state: CFRState, traversal_state: TraversalState) -> Self {
SimpleActionGenerator {
cfr_state,
traversal_state,
}
}
}
impl ActionGenerator for SimpleActionGenerator {
    /// This generator has no tunable options.
    type Config = ();

    /// Trait constructor. The (empty) config is ignored and construction is
    /// delegated to the inherent [`SimpleActionGenerator::new`].
    fn new(cfr_state: CFRState, traversal_state: TraversalState, _config: Arc<()>) -> Self {
        SimpleActionGenerator::new(cfr_state, traversal_state)
    }

    fn config(&self) -> &Self::Config {
        // `()` is a constant expression, so this reference is promoted to `'static`.
        &()
    }

    fn cfr_state(&self) -> &CFRState {
        &self.cfr_state
    }

    fn traversal_state(&self) -> &TraversalState {
        &self.traversal_state
    }

    /// Returns the abstracted action set for the player to act in `game_state`.
    ///
    /// Entries, in this order, each included only under the stated condition:
    /// - `Fold`: there is an outstanding amount to call.
    /// - `Bet(current_bet)`: always; this is a check when nothing is owed, otherwise a call.
    /// - `Bet(current_bet + min_raise)`: the min raise is positive and the amount is
    ///   strictly below the player's all-in amount.
    /// - `Bet(current_bet + pot * f)` for `f` in `[0.33, 0.66]`: the amount strictly
    ///   exceeds the min-raise amount and every smaller pot-fraction amount, and is
    ///   strictly below the all-in amount. No two emitted bets share an amount.
    /// - `AllIn`: the player's total chips for the round exceed the current bet.
    fn gen_possible_actions(&self, game_state: &GameState) -> ActionVec {
        // Upper bound: fold + call + min raise + two pot-fraction raises + all-in.
        let mut actions = ActionVec::with_capacity(6);

        let current_bet = game_state.current_round_bet();
        let player_bet = game_state.current_round_current_player_bet();
        let stack = game_state.current_player_stack();
        let pot = game_state.total_pot;
        let min_raise = game_state.current_round_min_raise();

        let to_call = current_bet - player_bet;
        // Total the player would have in this round after shoving their stack.
        let all_in_amount = player_bet + stack;
        let min_raise_amount = current_bet + min_raise;

        if to_call > 0.0 {
            actions.push(AgentAction::Fold);
        }

        // Check / call. NOTE(review): when `stack < to_call` this relies on the
        // engine capping the bet at the player's stack; confirm in `do_bet`.
        actions.push(AgentAction::Bet(current_bet));

        if min_raise_amount > current_bet && min_raise_amount < all_in_amount {
            actions.push(AgentAction::Bet(min_raise_amount));
        }

        // Pot-fraction raise sizes, smallest first. Each candidate must strictly
        // exceed the largest size considered so far, starting at the min-raise amount.
        // This enforces the minimum-raise rule and also stops a fraction whose amount
        // equals the min raise (or a smaller fraction) from being emitted a second
        // time. The caller would presumably treat such a duplicate as a separate action.
        let mut floor = min_raise_amount;
        for fraction in [0.33, 0.66] {
            let amount = current_bet + pot * fraction;
            if amount > floor && amount < all_in_amount {
                actions.push(AgentAction::Bet(amount));
            }
            floor = floor.max(amount);
        }

        if all_in_amount > current_bet {
            actions.push(AgentAction::AllIn);
        }

        actions
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arena::GameStateBuilder;

    /// Wraps a clone of `game_state` in a generator rooted at traversal node 0.
    fn create_simple_generator(game_state: &GameState) -> SimpleActionGenerator {
        let cfr_state = CFRState::new(game_state.clone());
        let traversal_state = TraversalState::new_root(0);
        SimpleActionGenerator::new(cfr_state, traversal_state)
    }

    #[test]
    fn test_simple_gen_actions_with_bet_facing() {
        let mut game_state = GameStateBuilder::new()
            .stacks(vec![500.0; 2])
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        game_state.advance_round();
        game_state.do_bet(30.0, false).unwrap();

        let actions = create_simple_generator(&game_state).gen_possible_actions(&game_state);

        assert!(actions.contains(&AgentAction::Fold));
        assert!(actions.iter().any(|a| matches!(a, AgentAction::Bet(_))));
        assert!(actions.contains(&AgentAction::AllIn));
    }

    #[test]
    fn test_simple_no_fold_when_checking() {
        let mut game_state = GameStateBuilder::new()
            .stacks(vec![100.0; 2])
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        game_state.advance_round();

        let actions = create_simple_generator(&game_state).gen_possible_actions(&game_state);

        assert!(!actions.contains(&AgentAction::Fold));
    }

    /// Applies every generated action to its own clone of `game_state` and
    /// panics with a diagnostic message if the engine rejects any of them.
    fn verify_all_actions_valid(game_state: &GameState) {
        let generator = create_simple_generator(game_state);
        for action in &generator.gen_possible_actions(game_state) {
            let mut scratch = game_state.clone();
            let result = match action {
                AgentAction::Fold => {
                    scratch.fold();
                    Ok(())
                }
                AgentAction::Bet(amount) => scratch.do_bet(*amount, false).map(|_| ()),
                AgentAction::Call => {
                    let target = scratch.current_round_bet();
                    scratch.do_bet(target, false).map(|_| ())
                }
                AgentAction::AllIn => {
                    let target = scratch.current_round_current_player_bet()
                        + scratch.current_player_stack();
                    scratch.do_bet(target, false).map(|_| ())
                }
            };
            assert!(
                result.is_ok(),
                "Action {:?} should be valid but got error: {:?}\n\
                 Game state: current_bet={}, min_raise={}, player_bet={}, stack={}, pot={}",
                action,
                result.err(),
                game_state.current_round_bet(),
                game_state.current_round_min_raise(),
                game_state.current_round_current_player_bet(),
                game_state.current_player_stack(),
                game_state.total_pot
            );
        }
    }

    #[test]
    fn test_simple_all_actions_valid_preflop_sb() {
        let game_state = GameStateBuilder::new()
            .stacks(vec![100.0; 2])
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        verify_all_actions_valid(&game_state);
    }

    #[test]
    fn test_simple_all_actions_valid_preflop_bb() {
        let mut game_state = GameStateBuilder::new()
            .stacks(vec![100.0; 2])
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        game_state.do_bet(10.0, false).unwrap();
        verify_all_actions_valid(&game_state);
    }

    #[test]
    fn test_simple_all_actions_valid_flop_first_to_act() {
        let mut game_state = GameStateBuilder::new()
            .stacks(vec![100.0; 2])
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        game_state.advance_round();
        verify_all_actions_valid(&game_state);
    }

    #[test]
    fn test_simple_all_actions_valid_facing_bet() {
        let mut game_state = GameStateBuilder::new()
            .stacks(vec![100.0; 2])
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        game_state.advance_round();
        game_state.do_bet(20.0, false).unwrap();
        verify_all_actions_valid(&game_state);
    }

    #[test]
    fn test_simple_all_actions_valid_facing_raise() {
        let mut game_state = GameStateBuilder::new()
            .stacks(vec![200.0; 2])
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        game_state.advance_round();
        for amount in [20.0, 50.0] {
            game_state.do_bet(amount, false).unwrap();
        }
        verify_all_actions_valid(&game_state);
    }

    #[test]
    fn test_simple_all_actions_valid_small_stack() {
        let mut game_state = GameStateBuilder::new()
            .stacks(vec![30.0, 100.0])
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        game_state.advance_round();
        game_state.do_bet(15.0, false).unwrap();
        verify_all_actions_valid(&game_state);
    }

    #[test]
    fn test_simple_all_actions_valid_large_pot() {
        let mut game_state = GameStateBuilder::new()
            .stacks(vec![100.0; 2])
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        for amount in [10.0, 30.0, 30.0] {
            game_state.do_bet(amount, false).unwrap();
        }
        game_state.advance_round();
        verify_all_actions_valid(&game_state);
    }

    #[test]
    fn test_simple_all_actions_valid_tiny_pot() {
        let mut game_state = GameStateBuilder::new()
            .stacks(vec![1000.0; 2])
            .blinds(2.0, 1.0)
            .build()
            .unwrap();
        game_state.advance_round();
        verify_all_actions_valid(&game_state);
    }

    #[test]
    fn test_simple_all_actions_valid_after_multiple_raises() {
        let mut game_state = GameStateBuilder::new()
            .stacks(vec![500.0; 2])
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        game_state.advance_round();
        for amount in [20.0, 50.0, 110.0] {
            game_state.do_bet(amount, false).unwrap();
        }
        verify_all_actions_valid(&game_state);
    }

    #[test]
    fn test_simple_all_actions_valid_three_players() {
        let mut game_state = GameStateBuilder::new()
            .stacks(vec![100.0; 3])
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        game_state.advance_round();
        for _ in 0..2 {
            game_state.do_bet(15.0, false).unwrap();
        }
        verify_all_actions_valid(&game_state);
    }

    #[test]
    fn test_simple_all_actions_valid_river() {
        let mut game_state = GameStateBuilder::new()
            .stacks(vec![100.0; 2])
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        // Flop, turn, river.
        for _ in 0..3 {
            game_state.advance_round();
        }
        verify_all_actions_valid(&game_state);
    }
}