use approx::assert_abs_diff_eq;
use crate::arena::game_state::Round;
use super::action::AgentAction;
use super::{GameState, game_state::RoundData};
use crate::arena::action::Action;
use crate::arena::historian::HistoryRecord;
/// Asserts that every player still flagged in `needs_action` has the same bet.
///
/// The bets of the flagged players are compared against the largest of them.
/// The tolerance is one part in 100,000 of that maximum, or `f32::EPSILON`
/// when the maximum is exactly zero. Nothing is checked when no player is
/// flagged.
///
/// # Panics
///
/// Panics if a flagged player's bet differs from the maximum by more than the
/// tolerance.
pub fn assert_valid_round_data(round_data: &RoundData) {
    let active_bets: Vec<f32> = round_data
        .player_bet
        .iter()
        .enumerate()
        .filter(|&(seat, _)| round_data.needs_action.get(seat))
        .map(|(_, wager)| *wager)
        .collect();

    // With no flagged players there is no maximum and nothing to compare.
    let max = match active_bets.iter().copied().reduce(f32::max) {
        Some(largest) => largest,
        None => return,
    };

    // Relative tolerance, falling back to machine epsilon for an all-zero round.
    let epsilon = if max == 0.0 {
        f32::EPSILON
    } else {
        max / 100_000.0
    };

    for bet in active_bets {
        assert_abs_diff_eq!(bet, max, epsilon = epsilon);
    }
}
pub fn assert_valid_game_state(game_state: &GameState) {
assert_eq!(Round::Complete, game_state.round);
let should_have_bets = game_state.ante + game_state.small_blind + game_state.big_blind > 0.0;
let total_bet = game_state.player_bet.iter().copied().sum();
if should_have_bets {
let any_above_zero = game_state.player_bet.iter().any(|bet| *bet > 0.0);
assert!(
any_above_zero,
"At least one player should have a bet, game_state: {:?}",
game_state.player_bet
);
assert_ne!(0.0, total_bet);
}
let epsilon = total_bet / 100_000.0;
assert_abs_diff_eq!(total_bet, game_state.total_pot, epsilon = epsilon);
let total_winning: f32 = game_state.player_winnings.iter().copied().sum();
assert_abs_diff_eq!(total_winning, total_bet, epsilon = epsilon);
assert_abs_diff_eq!(total_winning, game_state.total_pot, epsilon = epsilon);
assert!(game_state.dealer_idx < game_state.num_players);
assert!(game_state.board.len() <= 5);
assert!(game_state.small_blind <= game_state.big_blind);
validate_board_cards(game_state);
validate_player_states(game_state);
validate_betting_structure(game_state);
validate_winnings_distribution(game_state);
validate_deck_integrity(game_state);
validate_dealer_positioning(game_state);
validate_ante_structure(game_state);
validate_stack_integrity(game_state);
validate_side_pot_distribution(game_state);
for idx in 0..game_state.num_players {
if !game_state.player_active.get(idx) && !game_state.player_all_in.get(idx) {
assert_abs_diff_eq!(game_state.player_winnings[idx], 0.0, epsilon = f32::EPSILON);
}
}
}
pub fn assert_valid_history(history_storage: &[HistoryRecord]) {
assert!(!history_storage.is_empty());
assert!(
matches!(history_storage[0].action, Action::GameStart(_)),
"First action should be GameStart, but was: {:?}",
history_storage[0].action
);
assert_advances_to_complete(history_storage);
assert_round_contains_valid_player_actions(history_storage);
assert_no_player_actions_after_fold(history_storage);
validate_betting_sequence(history_storage);
validate_round_progression(history_storage);
}
/// Asserts that the history contains exactly one
/// `Action::RoundAdvance(Round::Complete)` record.
fn assert_advances_to_complete(history_storage: &[HistoryRecord]) {
    let completions = history_storage
        .iter()
        .filter(|record| matches!(record.action, Action::RoundAdvance(Round::Complete)))
        .count();
    assert_eq!(1, completions);
}
/// Placeholder for per-round player-action validation.
///
/// For each of the four betting rounds (preflop, flop, turn, river) this
/// looks for a matching `Action::RoundAdvance` record. It currently performs
/// no assertions and therefore never panics.
fn assert_round_contains_valid_player_actions(history_storage: &[HistoryRecord]) {
    for round in &[Round::Preflop, Round::Flop, Round::Turn, Round::River] {
        let round_was_reached = history_storage.iter().any(|record| {
            matches!(&record.action, Action::RoundAdvance(found_round) if found_round == round)
        });
        if !round_was_reached {
            continue;
        }
        // TODO(review): nothing is verified for rounds that were reached yet.
    }
}
/// Asserts that a player records no further `Action::PlayedAction` after
/// folding.
///
/// Every fold in the history is examined in order; for each one, the records
/// that follow it are scanned for played actions carrying the same player
/// index.
///
/// # Panics
///
/// Panics if any player has a played action after their fold.
fn assert_no_player_actions_after_fold(history_storage: &[HistoryRecord]) {
    for (fold_pos, record) in history_storage.iter().enumerate() {
        // Only fold records are of interest here.
        let folder = match &record.action {
            Action::PlayedAction(played) if played.action == AgentAction::Fold => played.idx,
            _ => continue,
        };

        let later_actions = history_storage[fold_pos + 1..]
            .iter()
            .filter(|later| {
                matches!(&later.action, Action::PlayedAction(played) if played.idx == folder)
            })
            .count();
        assert_eq!(0, later_actions);
    }
}
/// Asserts that the board holds 0, 3, 4 or 5 cards and that no card appears
/// on it twice.
fn validate_board_cards(game_state: &GameState) {
    let board = &game_state.board;
    let card_count = board.len();
    assert!(
        matches!(card_count, 0 | 3 | 4 | 5),
        "Invalid board card count: {}. Texas Hold'em boards have 0, 3, 4, or 5 cards",
        card_count
    );

    // `insert` returns false when the card was already present.
    let mut seen_cards = std::collections::HashSet::with_capacity(card_count);
    for card in board {
        let first_occurrence = seen_cards.insert(card);
        assert!(first_occurrence, "Duplicate card {:?} found on board", card);
    }
}
/// Asserts per-player invariants for every seat index below `num_players`.
///
/// * A player that is neither active nor all-in must have zero winnings.
/// * Bets and winnings must be non-negative.
///
/// # Panics
///
/// Panics on the first seat that violates one of these invariants.
fn validate_player_states(game_state: &GameState) {
    for idx in 0..game_state.num_players {
        let is_active = game_state.player_active.get(idx);
        let is_all_in = game_state.player_all_in.get(idx);
        let bet = game_state.player_bet[idx];
        let winnings = game_state.player_winnings[idx];

        // Players who are out of the hand cannot collect anything.
        if !is_active && !is_all_in {
            assert_abs_diff_eq!(winnings, 0.0, epsilon = f32::EPSILON);
        }

        assert!(bet >= 0.0, "Player {} has negative bet: {}", idx, bet);
        assert!(
            winnings >= 0.0,
            "Player {} has negative winnings: {}",
            idx,
            winnings
        );
    }
}
/// Asserts that the small blind does not exceed the big blind.
///
/// The check is skipped when neither blind is positive.
fn validate_betting_structure(game_state: &GameState) {
    let (small, big) = (game_state.small_blind, game_state.big_blind);

    let blinds_in_play = small > 0.0 || big > 0.0;
    if !blinds_in_play {
        return;
    }

    assert!(
        small <= big,
        "Small blind {} cannot be larger than big blind {}",
        small,
        big
    );
}
/// Asserts that winnings were handed out consistently with the bets.
///
/// * Total winnings equal total bets within one part in 100,000 of the bets.
/// * If anything was bet, at least one player has positive winnings.
/// * Every player with positive winnings is either active or all-in.
fn validate_winnings_distribution(game_state: &GameState) {
    let total_winnings: f32 = game_state.player_winnings.iter().sum();
    let total_bets: f32 = game_state.player_bet.iter().sum();
    assert_abs_diff_eq!(total_winnings, total_bets, epsilon = total_bets / 100_000.0);

    if total_bets > 0.0 {
        assert!(
            game_state.player_winnings.iter().any(|&w| w > 0.0),
            "Someone must win if there was betting action"
        );
    }

    let winners = game_state
        .player_winnings
        .iter()
        .copied()
        .enumerate()
        .filter(|&(_, amount)| amount > 0.0);
    for (idx, winnings) in winners {
        let (is_active, is_all_in) = (
            game_state.player_active.get(idx),
            game_state.player_all_in.get(idx),
        );
        assert!(
            is_active || is_all_in,
            "Winner player {} must be either active or all-in, winnings: {}",
            idx,
            winnings
        );
    }
}
/// Asserts that the cards in play form a consistent subset of one deck.
///
/// * The board contains no duplicate cards.
/// * A card in a player's hand that is not on the board must not already
///   have been seen on the board or in an earlier player's hand.
/// * Every hand contains every board card.
/// * No more than 52 distinct cards have been dealt in total.
fn validate_deck_integrity(game_state: &GameState) {
    use std::collections::HashSet;

    let board_set: HashSet<_> = game_state.board.iter().copied().collect();
    assert_eq!(
        board_set.len(),
        game_state.board.len(),
        "Duplicate cards found on the board"
    );

    // Every distinct card seen so far, seeded with the community cards.
    let mut dealt_cards = board_set.clone();

    for (player_idx, hand) in game_state.hands.iter().enumerate() {
        // Cards in the hand that are not on the board.
        let hand_set: HashSet<_> = hand.iter().collect();
        for hole_card in hand_set.difference(&board_set) {
            assert!(
                dealt_cards.insert(*hole_card),
                "Duplicate hole card {:?} found for player {}",
                hole_card,
                player_idx
            );
        }

        for board_card in &game_state.board {
            assert!(
                hand.contains(board_card),
                "Player {} missing community card {:?}",
                player_idx,
                board_card
            );
        }
    }

    let dealt_total = dealt_cards.len();
    assert!(
        dealt_total <= 52,
        "Too many cards dealt: {}. Standard deck has 52 cards",
        dealt_total
    );
}
/// Asserts that the dealer index refers to an existing seat.
///
/// Only the bounds of `dealer_idx` are verified; blind positions relative to
/// the dealer are not checked.
///
/// # Panics
///
/// Panics if `dealer_idx >= num_players`.
fn validate_dealer_positioning(game_state: &GameState) {
    assert!(
        game_state.dealer_idx < game_state.num_players,
        "Dealer index {} is out of bounds for {} players",
        game_state.dealer_idx,
        game_state.num_players
    );
}
/// Asserts that the ante is sane.
///
/// * The ante must be non-negative (a NaN ante also fails this check).
/// * When both the ante and the big blind are positive, the ante must not
///   exceed the big blind.
///
/// # Panics
///
/// Panics if either condition is violated.
fn validate_ante_structure(game_state: &GameState) {
    let ante = game_state.ante;

    // This must run unconditionally: guarding it behind `ante > 0.0` would
    // make it impossible for the assertion to ever fail.
    assert!(ante >= 0.0, "Ante cannot be negative: {}", ante);

    if ante > 0.0 && game_state.big_blind > 0.0 {
        assert!(
            ante <= game_state.big_blind,
            "Ante {} should not exceed big blind {}",
            ante,
            game_state.big_blind
        );
    }
}
/// Asserts that no current stack and no starting stack is negative.
fn validate_stack_integrity(game_state: &GameState) {
    game_state
        .stacks
        .iter()
        .enumerate()
        .for_each(|(idx, &stack)| {
            assert!(stack >= 0.0, "Player {} has negative stack: {}", idx, stack);
        });

    game_state
        .starting_stacks
        .iter()
        .enumerate()
        .for_each(|(idx, &stack)| {
            assert!(
                stack >= 0.0,
                "Player {} had negative starting stack: {}",
                idx,
                stack
            );
        });
}
/// Asserts that nobody wins money without having put any in.
///
/// When the ante is exactly zero, every player with positive winnings must
/// also have a positive bet. With a non-zero ante the check is skipped.
fn validate_side_pot_distribution(game_state: &GameState) {
    let no_antes = game_state.ante == 0.0;

    for (winner_idx, &winnings) in game_state.player_winnings.iter().enumerate() {
        if winnings <= 0.0 {
            continue;
        }

        let paid_nothing = game_state.player_bet[winner_idx] <= 0.0;
        assert!(
            !(no_antes && paid_nothing),
            "Player {} won {} but had no contribution to the pot",
            winner_idx, winnings
        );
    }
}
/// Asserts that no betting round contains too many bets.
///
/// The history is replayed while tracking, per round, how many
/// `AgentAction::Bet` actions were recorded and which player indices have
/// acted so far. Once more than two distinct players have acted in a round,
/// the number of bets in that round must not exceed three.
///
/// Played actions recorded before the first `Action::RoundAdvance` are
/// attributed to `Round::Preflop`. The bet counter of a round is reset to
/// zero each time the history advances into that round (except for
/// `Round::Complete`, which is never tracked).
///
/// # Panics
///
/// Panics if a round with more than two acting players has more than three
/// bets.
fn validate_betting_sequence(history_storage: &[HistoryRecord]) {
    use std::collections::{HashMap, HashSet};

    let mut round_raise_counts: HashMap<Round, usize> = HashMap::new();
    let mut active_players_per_round: HashMap<Round, HashSet<usize>> = HashMap::new();
    let mut current_round = Round::Preflop;

    for record in history_storage {
        match &record.action {
            Action::RoundAdvance(round) => {
                current_round = *round;
                if *round != Round::Complete {
                    round_raise_counts.insert(*round, 0);
                }
            }
            Action::PlayedAction(action) => {
                // Register the actor and read the number of distinct actors
                // from the same map entry, avoiding a second lookup.
                let num_active = {
                    let actors = active_players_per_round.entry(current_round).or_default();
                    actors.insert(action.idx);
                    actors.len()
                };

                match action.action {
                    AgentAction::Bet(_) => {
                        let raise_count = round_raise_counts.entry(current_round).or_insert(0);
                        *raise_count += 1;

                        if num_active > 2 {
                            assert!(
                                *raise_count <= 3,
                                "Too many raises in round {:?}: {}. Maximum 3 raises allowed with 3+ players",
                                current_round,
                                *raise_count
                            );
                        }
                    }
                    // These actions do not count towards the bet limit.
                    AgentAction::Fold | AgentAction::Call | AgentAction::AllIn => {}
                }
            }
            // Other history records carry no betting information.
            _ => {}
        }
    }
}
/// Asserts that the recorded `Action::RoundAdvance` records follow a legal
/// order.
///
/// Each consecutive pair of advances must be either the next step of the
/// sequence Starting → Ante → DealPreflop → Preflop → DealFlop → Flop →
/// DealTurn → Turn → DealRiver → River → Showdown → Complete, or a jump
/// straight to `Round::Complete` from any round before Showdown. The last
/// advance must be `Round::Complete`.
///
/// A history without any round advances passes without further checks.
///
/// # Panics
///
/// Panics on the first illegal transition, or if the final advance is not
/// `Round::Complete`.
fn validate_round_progression(history_storage: &[HistoryRecord]) {
    let round_advances: Vec<Round> = history_storage
        .iter()
        .filter_map(|record| match &record.action {
            Action::RoundAdvance(round) => Some(*round),
            _ => None,
        })
        .collect();

    if round_advances.is_empty() {
        return;
    }

    // Check every adjacent (previous, next) pair of advances.
    for pair in round_advances.windows(2) {
        let (prev, round) = (pair[0], pair[1]);
        let is_legal = matches!(
            (prev, round),
            // Regular step-by-step progression.
            (Round::Starting, Round::Ante)
                | (Round::Ante, Round::DealPreflop)
                | (Round::DealPreflop, Round::Preflop)
                | (Round::Preflop, Round::DealFlop)
                | (Round::DealFlop, Round::Flop)
                | (Round::Flop, Round::DealTurn)
                | (Round::DealTurn, Round::Turn)
                | (Round::Turn, Round::DealRiver)
                | (Round::DealRiver, Round::River)
                | (Round::River, Round::Showdown)
                | (Round::Showdown, Round::Complete)
                // Early finish: jump directly to Complete.
                | (
                    Round::Starting
                        | Round::Ante
                        | Round::DealPreflop
                        | Round::Preflop
                        | Round::DealFlop
                        | Round::Flop
                        | Round::DealTurn
                        | Round::Turn
                        | Round::DealRiver
                        | Round::River,
                    Round::Complete
                )
        );
        assert!(
            is_legal,
            "Invalid round progression: {:?} -> {:?}",
            prev, round
        );
    }

    assert_eq!(
        round_advances.last(),
        Some(&Round::Complete),
        "Game should end with Complete round, but ended with: {:?}",
        round_advances.last()
    );
}