use super::action_generator::ActionVec;
use crate::arena::{GameState, action::AgentAction};
/// Tolerance used when comparing chip amounts stored as `f32`.
///
/// The filters in this module treat two amounts as equal when the absolute value of their
/// difference is below this constant.
const EPSILON: f32 = 0.01;
/// Runs every action filter in this module over `actions` and returns the surviving actions.
///
/// The filters are applied in this fixed order, each one consuming the previous output:
/// 1. `filter_invalid_fold` removes `Fold` when the player has nothing to call.
/// 2. `filter_call_equivalents` removes bets sized at the current round bet when `Call` is offered.
/// 3. `filter_all_in_equivalents` removes bets sized at the player's all-in level when `AllIn`
///    is offered.
/// 4. `filter_duplicate_bets` keeps only the first of any bets with (nearly) equal amounts.
/// 5. `filter_raises_when_capped` removes bets other than the current round bet once
///    `GameState::is_raise_capped` reports the cap has been reached.
pub fn validate_actions(actions: impl Into<ActionVec>, game_state: &GameState) -> ActionVec {
    let fold_checked = filter_invalid_fold(actions, game_state);
    let call_deduped = filter_call_equivalents(fold_checked, game_state);
    let all_in_deduped = filter_all_in_equivalents(call_deduped, game_state);
    let unique_bets = filter_duplicate_bets(all_in_deduped);
    filter_raises_when_capped(unique_bets, game_state)
}
pub fn filter_invalid_fold(actions: impl Into<ActionVec>, game_state: &GameState) -> ActionVec {
let actions = actions.into();
let to_call = game_state.current_round_bet() - game_state.current_round_current_player_bet();
if to_call <= 0.0 {
actions
.into_iter()
.filter(|a| !matches!(a, AgentAction::Fold))
.collect()
} else {
actions
}
}
/// Removes bets that duplicate `Call` when `Call` is already among the actions.
///
/// Any `Bet(amount)` whose amount is within `EPSILON` of the round's current bet
/// (`GameState::current_round_bet`) is dropped. Every other action is kept in order.
/// If `actions` contains no `Call`, the input is returned unchanged.
pub fn filter_call_equivalents(actions: impl Into<ActionVec>, game_state: &GameState) -> ActionVec {
    let actions: ActionVec = actions.into();
    let call_total = game_state.current_round_bet();
    let call_offered = actions.iter().any(|action| matches!(action, AgentAction::Call));
    if !call_offered {
        return actions;
    }
    actions
        .into_iter()
        .filter(|action| match action {
            AgentAction::Bet(amount) => (amount - call_total).abs() >= EPSILON,
            _ => true,
        })
        .collect()
}
pub fn filter_all_in_equivalents(
actions: impl Into<ActionVec>,
game_state: &GameState,
) -> ActionVec {
let actions = actions.into();
let all_in_amount =
game_state.current_round_current_player_bet() + game_state.current_player_stack();
let has_all_in = actions.iter().any(|a| matches!(a, AgentAction::AllIn));
if has_all_in {
actions
.into_iter()
.filter(|a| {
if let AgentAction::Bet(amount) = a {
(amount - all_in_amount).abs() >= EPSILON
} else {
true
}
})
.collect()
} else {
actions
}
}
/// Keeps only the first occurrence of each bet size, in original order.
///
/// A `Bet(amount)` is dropped when an earlier kept bet's amount is within `EPSILON` of it.
/// Non-`Bet` actions are never removed. This is a linear scan over the kept amounts, which is
/// adequate for the handful of actions an agent generates.
pub fn filter_duplicate_bets(actions: impl Into<ActionVec>) -> ActionVec {
    let actions: ActionVec = actions.into();
    let mut kept_amounts: Vec<f32> = Vec::new();
    actions
        .into_iter()
        .filter(|action| match action {
            AgentAction::Bet(amount) => {
                let repeat = kept_amounts
                    .iter()
                    .any(|earlier| (earlier - amount).abs() < EPSILON);
                if !repeat {
                    kept_amounts.push(*amount);
                }
                !repeat
            }
            _ => true,
        })
        .collect()
}
/// Once the round's raise cap is reached, removes every bet that differs from the current bet.
///
/// When `GameState::is_raise_capped` returns `false`, the actions are returned untouched.
/// Otherwise a `Bet(amount)` survives only if it is within `EPSILON` of the round's current
/// bet. `AllIn` and every other non-`Bet` action are always kept.
pub fn filter_raises_when_capped(
    actions: impl Into<ActionVec>,
    game_state: &GameState,
) -> ActionVec {
    let actions: ActionVec = actions.into();
    if game_state.is_raise_capped() {
        let current_bet = game_state.current_round_bet();
        actions
            .into_iter()
            .filter(|action| match action {
                AgentAction::Bet(amount) => (amount - current_bet).abs() < EPSILON,
                // AllIn is deliberately kept even when capped, as is every non-bet action.
                _ => true,
            })
            .collect()
    } else {
        actions
    }
}
// Unit tests covering each filter in isolation plus the combined `validate_actions` pipeline.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arena::GameStateBuilder;

    /// Builds a two-player game with 100.0 chips each and blinds configured as `(10.0, 5.0)`.
    fn create_test_game_state() -> GameState {
        GameStateBuilder::new()
            .num_players_with_stack(2, 100.0)
            .blinds(10.0, 5.0)
            .build()
            .unwrap()
    }

    /// Builds a two-player game with 500.0 chips each, advances one round, then calls
    /// `do_bet(bet_amount, false)`.
    ///
    /// NOTE(review): the helper name implies `advance_round()` lands on the flop, and the
    /// `false` flag presumably marks a voluntary (non-forced) bet — confirm against `GameState`.
    fn create_flop_game_state_with_bet(bet_amount: f32) -> GameState {
        let mut game_state = GameStateBuilder::new()
            .num_players_with_stack(2, 500.0)
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        // Advance one round, then place the requested bet.
        game_state.advance_round(); game_state.do_bet(bet_amount, false).unwrap();
        game_state
    }

    // After `advance_round()` the test expects nothing to call, so Fold must be removed while
    // Call and AllIn survive.
    #[test]
    fn test_fold_removed_when_nothing_to_call() {
        let mut game_state = GameStateBuilder::new()
            .num_players_with_stack(2, 100.0)
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        game_state.advance_round();
        let actions = vec![AgentAction::Fold, AgentAction::Call, AgentAction::AllIn];
        let filtered = filter_invalid_fold(actions, &game_state);
        assert!(!filtered.contains(&AgentAction::Fold));
        assert!(filtered.contains(&AgentAction::Call));
        assert!(filtered.contains(&AgentAction::AllIn));
    }

    // Facing a 30.0 bet, the player owes chips, so Fold stays available.
    // (Assumes `do_bet` hands the action to the opponent.)
    #[test]
    fn test_fold_kept_when_facing_bet() {
        let game_state = create_flop_game_state_with_bet(30.0);
        let actions = vec![AgentAction::Fold, AgentAction::Call, AgentAction::AllIn];
        let filtered = filter_invalid_fold(actions, &game_state);
        assert!(filtered.contains(&AgentAction::Fold));
    }

    // A bet sized at the current round bet duplicates Call and is removed; Bet(50.0) stays.
    #[test]
    fn test_bet_equal_to_call_removed_when_call_exists() {
        let game_state = create_flop_game_state_with_bet(30.0);
        let call_amount = game_state.current_round_bet();
        let actions = vec![
            AgentAction::Call,
            AgentAction::Bet(call_amount),
            AgentAction::Bet(50.0),
        ];
        let filtered = filter_call_equivalents(actions, &game_state);
        assert!(filtered.contains(&AgentAction::Call));
        assert!(filtered.iter().any(|a| matches!(a, AgentAction::Bet(50.0))));
        assert_eq!(filtered.len(), 2);
    }

    // Without a Call action, a call-sized bet is not a duplicate and must be kept.
    #[test]
    fn test_bet_equal_to_call_kept_when_no_call() {
        let game_state = create_flop_game_state_with_bet(30.0);
        let call_amount = game_state.current_round_bet();
        let actions = vec![AgentAction::Fold, AgentAction::Bet(call_amount)];
        let filtered = filter_call_equivalents(actions, &game_state);
        assert_eq!(filtered.len(), 2);
    }

    // A bet at the all-in level (round contribution + stack) duplicates AllIn and is removed.
    #[test]
    fn test_bet_equal_to_all_in_removed_when_all_in_exists() {
        let mut game_state = GameStateBuilder::new()
            .num_players_with_stack(2, 100.0)
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        game_state.advance_round();
        let all_in =
            game_state.current_round_current_player_bet() + game_state.current_player_stack();
        let actions = vec![
            AgentAction::AllIn,
            AgentAction::Bet(all_in),
            AgentAction::Bet(50.0),
        ];
        let filtered = filter_all_in_equivalents(actions, &game_state);
        assert!(filtered.contains(&AgentAction::AllIn));
        assert!(filtered.iter().any(|a| matches!(a, AgentAction::Bet(50.0))));
        assert_eq!(filtered.len(), 2);
    }

    // A bet smaller than the all-in level is a distinct option and must be kept.
    #[test]
    fn test_bet_less_than_all_in_kept() {
        let mut game_state = GameStateBuilder::new()
            .num_players_with_stack(2, 100.0)
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        game_state.advance_round();
        let actions = vec![AgentAction::AllIn, AgentAction::Bet(50.0)];
        let filtered = filter_all_in_equivalents(actions, &game_state);
        assert!(filtered.contains(&AgentAction::AllIn));
        assert!(filtered.iter().any(|a| matches!(a, AgentAction::Bet(50.0))));
        assert_eq!(filtered.len(), 2);
    }

    // Only the first Bet(30.0) survives; Fold is untouched.
    // Expected survivors: Fold, Bet(30.0), Bet(50.0).
    #[test]
    fn test_duplicate_bet_amounts_removed() {
        let actions = vec![
            AgentAction::Fold,
            AgentAction::Bet(30.0),
            AgentAction::Bet(30.0),
            AgentAction::Bet(50.0),
        ];
        let filtered = filter_duplicate_bets(actions);
        assert!(filtered.contains(&AgentAction::Fold));
        let bet_30_count = filtered
            .iter()
            .filter(|a| matches!(a, AgentAction::Bet(x) if (*x - 30.0).abs() < EPSILON))
            .count();
        assert_eq!(bet_30_count, 1);
        assert_eq!(filtered.len(), 3); }

    // Distinct bet sizes are all kept.
    #[test]
    fn test_different_bet_amounts_kept() {
        let actions = vec![
            AgentAction::Bet(20.0),
            AgentAction::Bet(30.0),
            AgentAction::Bet(50.0),
        ];
        let filtered = filter_duplicate_bets(actions);
        assert_eq!(filtered.len(), 3);
    }

    // With 3 of 3 raises used, the only bet that survives is the one sized at the current
    // round bet (a call). Bet(60.0) is removed, while Fold and AllIn are always kept.
    #[test]
    fn test_raises_removed_when_capped() {
        let mut game_state = create_flop_game_state_with_bet(30.0);
        game_state.round_data.total_raise_count = 3;
        game_state.max_raises_per_round = Some(3);
        let current_bet = game_state.current_round_bet();
        let actions = vec![
            AgentAction::Fold,
            AgentAction::Bet(current_bet), AgentAction::Bet(60.0), AgentAction::AllIn,
        ];
        let filtered = filter_raises_when_capped(actions, &game_state);
        assert!(filtered.contains(&AgentAction::Fold));
        assert!(filtered.contains(&AgentAction::AllIn));
        assert!(
            filtered
                .iter()
                .any(|a| matches!(a, AgentAction::Bet(x) if (*x - current_bet).abs() < EPSILON))
        );
        assert!(
            !filtered
                .iter()
                .any(|a| matches!(a, AgentAction::Bet(x) if (*x - 60.0).abs() < EPSILON))
        );
    }

    // AllIn remains available even after the raise cap is reached.
    #[test]
    fn test_all_in_kept_when_capped() {
        let mut game_state = create_flop_game_state_with_bet(30.0);
        game_state.round_data.total_raise_count = 3;
        game_state.max_raises_per_round = Some(3);
        let actions = vec![AgentAction::Fold, AgentAction::AllIn];
        let filtered = filter_raises_when_capped(actions, &game_state);
        assert!(filtered.contains(&AgentAction::AllIn));
    }

    // 2 of 3 raises used: not capped, so nothing is removed.
    #[test]
    fn test_raises_kept_when_not_capped() {
        let mut game_state = create_flop_game_state_with_bet(30.0);
        game_state.round_data.total_raise_count = 2;
        game_state.max_raises_per_round = Some(3);
        let actions = vec![
            AgentAction::Fold,
            AgentAction::Bet(60.0),
            AgentAction::AllIn,
        ];
        let filtered = filter_raises_when_capped(actions, &game_state);
        assert_eq!(filtered.len(), 3);
    }

    // No raise limit configured: even a high raise count leaves all actions in place.
    #[test]
    fn test_raises_kept_when_no_limit() {
        let mut game_state = create_flop_game_state_with_bet(30.0);
        game_state.round_data.total_raise_count = 10;
        game_state.max_raises_per_round = None;
        let actions = vec![
            AgentAction::Fold,
            AgentAction::Bet(60.0),
            AgentAction::AllIn,
        ];
        let filtered = filter_raises_when_capped(actions, &game_state);
        assert_eq!(filtered.len(), 3);
    }

    // End-to-end: Fold, Call and AllIn survive. The call-sized bet is removed as a Call
    // duplicate, and the two Bet(50.0) entries collapse to one. Raises are not capped
    // (2 of 3 used).
    #[test]
    fn test_full_validation_pipeline() {
        let mut game_state = create_flop_game_state_with_bet(30.0);
        game_state.round_data.total_raise_count = 2;
        game_state.max_raises_per_round = Some(3);
        let current_bet = game_state.current_round_bet();
        let all_in =
            game_state.current_round_current_player_bet() + game_state.current_player_stack();
        let actions = vec![
            AgentAction::Fold,
            AgentAction::Call,
            AgentAction::Bet(current_bet), AgentAction::Bet(50.0),
            AgentAction::Bet(50.0), AgentAction::Bet(all_in), AgentAction::AllIn,
        ];
        let filtered = validate_actions(actions, &game_state);
        assert!(filtered.contains(&AgentAction::Fold));
        assert!(filtered.contains(&AgentAction::Call));
        assert!(filtered.contains(&AgentAction::AllIn));
        let bet_50_count = filtered
            .iter()
            .filter(|a| matches!(a, AgentAction::Bet(x) if (*x - 50.0).abs() < EPSILON))
            .count();
        assert_eq!(bet_50_count, 1);
        assert!(
            !filtered
                .iter()
                .any(|a| matches!(a, AgentAction::Bet(x) if (*x - current_bet).abs() < EPSILON))
        );
    }

    // With 50.0-chip stacks, a bet at the all-in level is dropped in favour of AllIn;
    // Call is untouched.
    #[test]
    fn test_all_in_equals_sizing() {
        let mut game_state = GameStateBuilder::new()
            .num_players_with_stack(2, 50.0)
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        game_state.advance_round();
        let all_in =
            game_state.current_round_current_player_bet() + game_state.current_player_stack();
        let actions = vec![
            AgentAction::Call,
            AgentAction::Bet(all_in), AgentAction::AllIn,
        ];
        let filtered = filter_all_in_equivalents(actions, &game_state);
        assert!(filtered.contains(&AgentAction::AllIn));
        assert_eq!(filtered.len(), 2);
    }

    // An empty action list passes through the whole pipeline as empty.
    #[test]
    fn test_empty_actions_after_filtering() {
        let game_state = create_test_game_state();
        let actions = vec![];
        let filtered = validate_actions(actions, &game_state);
        assert!(filtered.is_empty());
    }

    // Short-stacked second player (10.0 chips), three rounds advanced, then a series of bets.
    // NOTE(review): the `true` flag on the first two `do_bet` calls presumably marks forced bets
    // — confirm. No filter in this module ever removes `Call`, so `!filtered.is_empty()` holds
    // for any game state. This test effectively only checks that the pipeline does not panic.
    #[test]
    fn test_player_already_all_in() {
        let mut game_state = GameStateBuilder::new()
            .stacks(vec![100.0, 10.0])
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        game_state.advance_round(); game_state.advance_round(); game_state.advance_round();
        game_state.do_bet(5.0, true).unwrap(); game_state.do_bet(10.0, true).unwrap(); game_state.do_bet(10.0, false).unwrap();
        game_state.do_bet(10.0, false).unwrap();
        let actions = vec![AgentAction::Fold, AgentAction::Call];
        let filtered = validate_actions(actions, &game_state);
        assert!(!filtered.is_empty());
    }
}