use std::sync::atomic::{AtomicUsize, Ordering};
use async_trait::async_trait;
use tracing::{instrument, trace};
use crate::arena::{action::AgentAction, game_state::GameState};
use super::{Agent, AgentGenerator};
/// An agent identified by a display name; its decision logic lives in its `Agent` impl.
#[derive(Clone, Debug)]
pub struct FoldingAgent {
    /// Label returned by `Agent::name` and attached to trace spans.
    name: String,
}

impl FoldingAgent {
    /// Builds an agent carrying the given display name.
    ///
    /// Accepts anything convertible into a `String` (`&str`, `String`, ...).
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self { name }
    }
}

impl Default for FoldingAgent {
    /// Builds an agent named `FoldingAgent-<n>`, where `<n>` comes from a process-wide
    /// counter that starts at zero and is bumped on every call.
    fn default() -> Self {
        // Only uniqueness of the fetched value matters here, so relaxed ordering is used.
        static NEXT_INDEX: AtomicUsize = AtomicUsize::new(0);

        let idx = NEXT_INDEX.fetch_add(1, Ordering::Relaxed);
        Self::new(format!("FoldingAgent-{idx}"))
    }
}
#[async_trait]
impl Agent for FoldingAgent {
    /// Chooses this agent's action for the given game state.
    ///
    /// * When `num_active_players() + num_all_in_players()` equals one, returns
    ///   `AgentAction::Bet` carrying `current_round_bet()`.
    /// * Otherwise, returns `AgentAction::Fold` if `current_round_bet()` exceeds
    ///   `current_round_current_player_bet()`, and `AgentAction::Bet` carrying
    ///   `current_round_bet()` if it does not.
    ///
    /// The `_id` argument is not consulted.
    #[instrument(level = "trace", skip(self, game_state), fields(agent_name = %self.name))]
    async fn act(self: &mut FoldingAgent, _id: u128, game_state: &GameState) -> AgentAction {
        let players_in_hand = game_state.num_active_players() + game_state.num_all_in_players();
        let current_bet = game_state.current_round_bet();

        // Sole remaining player: answer with the current round bet.
        if players_in_hand == 1 {
            trace!(
                bet = current_bet,
                players_in_hand,
                "FoldingAgent claiming pot (last player)"
            );
            return AgentAction::Bet(current_bet);
        }

        // A strictly positive difference means there is something left to call.
        let to_call = current_bet - game_state.current_round_current_player_bet();
        if to_call > 0.0 {
            trace!(players_in_hand, to_call, "FoldingAgent folding");
            return AgentAction::Fold;
        }

        trace!(players_in_hand, "FoldingAgent checking (nothing to call)");
        AgentAction::Bet(current_bet)
    }

    /// Returns the display name this agent was constructed with.
    fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// Factory for [`FoldingAgent`]s, optionally forcing one fixed name on every agent it makes.
#[derive(Clone, Debug, Default)]
pub struct FoldingAgentGenerator {
    /// Fixed agent name; when `None`, a name is derived from the player index.
    name: Option<String>,
}

impl FoldingAgentGenerator {
    /// Creates a generator with no fixed name.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a generator whose agents all receive `name`.
    pub fn with_name(name: impl Into<String>) -> Self {
        let name = Some(name.into());
        Self { name }
    }

    /// Returns the configured name if there is one, otherwise `FoldingAgent-<player_idx>`.
    fn resolve_name(&self, player_idx: usize) -> String {
        match &self.name {
            Some(fixed) => fixed.clone(),
            None => format!("FoldingAgent-{player_idx}"),
        }
    }
}
impl AgentGenerator for FoldingAgentGenerator {
    /// Produces a boxed [`FoldingAgent`] for `player_idx`.
    ///
    /// The agent's name is the generator's fixed name when one was configured, otherwise
    /// `FoldingAgent-<player_idx>`. The game state argument is not consulted.
    fn generate(&self, player_idx: usize, _game_state: &GameState) -> Box<dyn Agent> {
        let agent_name = self.resolve_name(player_idx);
        Box::new(FoldingAgent::new(agent_name))
    }
}
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::arena::{
        GameStateBuilder, HoldemSimulationBuilder,
        game_state::{Round, RoundData},
    };
    use crate::core::PlayerBitSet;

    /// A generator without a fixed name labels the agent by player index, and that agent
    /// answers a freshly built two-player state with `Bet(0.0)`.
    #[tokio::test(flavor = "current_thread")]
    async fn test_folding_generator_creates_named_folder() {
        let state = GameStateBuilder::new()
            .num_players_with_stack(2, 100.0)
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        let generator = FoldingAgentGenerator::default();

        let mut folder = generator.generate(0, &state);
        assert_eq!(folder.name(), "FoldingAgent-0");

        let action = folder.act(0, &state).await;
        if !matches!(action, AgentAction::Bet(0.0)) {
            panic!("Expected Bet(0.0) action (check), got {:?}", action);
        }
    }

    /// With the round bet at 20 and one seat having put in only 10, the agent folds.
    #[tokio::test(flavor = "current_thread")]
    async fn test_folding_agent_folds_when_facing_bet() {
        // NOTE(review): presumably seat 1 is the player to act here — confirm against
        // `RoundData::new`'s last argument.
        let mut facing_raise = RoundData::new(2, 10.0, PlayerBitSet::new(2), 1);
        facing_raise.bet = 20.0;
        facing_raise.player_bet[0] = 20.0;
        facing_raise.player_bet[1] = 10.0;

        let state = GameStateBuilder::new()
            .round(Round::Preflop)
            .round_data(facing_raise)
            .stacks(vec![100.0; 2])
            .big_blind(10.0)
            .small_blind(5.0)
            .build()
            .unwrap();

        let mut folder = FoldingAgent::new("TestFolder");
        let action = folder.act(0, &state).await;
        if !matches!(action, AgentAction::Fold) {
            panic!("Expected Fold action, got {:?}", action);
        }
    }

    /// A generator built with `with_name` hands that exact name to the agent it creates.
    #[test]
    fn test_folding_generator_uses_custom_name() {
        let state = GameStateBuilder::new()
            .num_players_with_stack(2, 40.0)
            .blinds(10.0, 5.0)
            .build()
            .unwrap();

        let named = FoldingAgentGenerator::with_name("FolderZ").generate(0, &state);
        assert_eq!(named.name(), "FolderZ");
    }

    /// Two folding agents run through a full seeded simulation: the hand completes with one
    /// active player, 15 in total bets, and the whole 15 awarded to seat 1.
    #[tokio::test(flavor = "current_thread")]
    async fn test_folding_agents() {
        let starting_state = GameStateBuilder::new()
            .stacks(vec![100.0; 2])
            .blinds(10.0, 5.0)
            .build()
            .unwrap();

        let mut sim = HoldemSimulationBuilder::default()
            .game_state(starting_state)
            .agents(vec![
                Box::new(FoldingAgent::new("FoldingAgent-0")),
                Box::new(FoldingAgent::new("FoldingAgent-1")),
            ])
            .build_with_rng(StdRng::seed_from_u64(420))
            .unwrap();
        sim.run().await;

        let final_state = &sim.game_state;
        assert_eq!(final_state.num_active_players(), 1);
        assert_eq!(final_state.round, Round::Complete);

        let total_bet: f32 = final_state.player_bet.iter().sum();
        let total_winnings: f32 = final_state.player_winnings.iter().sum();
        assert_abs_diff_eq!(15.0_f32, total_bet);
        assert_abs_diff_eq!(15.0_f32, total_winnings);
        assert_abs_diff_eq!(15.0_f32, final_state.player_winnings[1]);
    }

    /// With both seats already at the round bet of 20, the agent answers `Bet(20.0)`
    /// instead of folding.
    #[tokio::test(flavor = "current_thread")]
    async fn test_folding_agent_checks_when_bet_matched() {
        let mut matched = RoundData::new(2, 20.0, PlayerBitSet::new(2), 1);
        matched.bet = 20.0;
        matched.player_bet[0] = 20.0;
        matched.player_bet[1] = 20.0;

        let state = GameStateBuilder::new()
            .round(Round::Preflop)
            .round_data(matched)
            .stacks(vec![100.0; 2])
            .big_blind(10.0)
            .small_blind(5.0)
            .build()
            .unwrap();

        let mut folder = FoldingAgent::new("TestFolder");
        let action = folder.act(0, &state).await;
        let AgentAction::Bet(amount) = action else {
            panic!("Expected Bet action (check), got {:?}", action);
        };
        assert_eq!(
            amount, 20.0,
            "Should check/call at current bet level when nothing to call"
        );
    }
}