use std::fs;
use std::path::Path;
use clap::Args;
use rand::rngs::StdRng;
use rand::{RngExt, SeedableRng};
use tracing::{info, warn};
use rs_poker::arena::agent::{Agent, AgentConfig, AgentConfigError, ConfigAgentBuilder};
use rs_poker::arena::cfr::{CFRState, TraversalSet};
use rs_poker::arena::game_state::GameState;
use rs_poker::arena::historian::{OpenHandHistoryHistorian, SharedStatsStorage};
use rs_poker::arena::{GameStateBuilder, Historian, HoldemSimulationBuilder, seeded_rng};
use crate::tui::app::{self, App};
use crate::tui::event::{EventHandler, SimError, SimMessage};
use crate::tui::hand_store::HandStore;
use crate::tui::state::{GameResult, SeatStats, ending_round_from_stats};
use crate::tui::{TuiFlags, run_blocking_tui_loop};
const MAX_CONSECUTIVE_FAILURES: usize = 100;
#[derive(Debug, thiserror::Error)]
pub enum GenerateError {
#[error("--min-players must be >= 2")]
MinPlayersTooFew,
#[error("--min-players must be <= --max-players")]
MinPlayersExceedsMax,
#[error("stack sizes must be > 0")]
InvalidStackSize,
#[error("--min-stack-bb must be <= --max-stack-bb")]
MinStackExceedsMax,
#[error("blinds must be > 0")]
InvalidBlinds,
#[error("--small-blind must be < --big-blind")]
SmallBlindExceedsBig,
#[error("failed to read agents directory '{path}': {source}")]
ReadAgentsDir {
path: String,
source: std::io::Error,
},
#[error("no valid agent configs found in '{0}'")]
NoConfigs(String),
#[error("too many consecutive failures ({0}), aborting generation")]
TooManyFailures(usize),
#[error("TUI error: {0}")]
TuiError(#[from] std::io::Error),
}
#[derive(Args, Debug, Clone)]
pub struct GenerateArgs {
agents_dir: std::path::PathBuf,
#[arg(short = 'o', long = "output")]
output: std::path::PathBuf,
#[arg(short = 'n', long = "num-games", default_value_t = 0)]
num_games: usize,
#[arg(long = "min-players", default_value_t = 2)]
min_players: usize,
#[arg(long = "max-players", default_value_t = 3)]
max_players: usize,
#[arg(long = "big-blind", default_value_t = 10.0)]
big_blind: f32,
#[arg(long = "small-blind", default_value_t = 5.0)]
small_blind: f32,
#[arg(long = "min-stack-bb", default_value_t = 100.0)]
min_stack_bb: f32,
#[arg(long = "max-stack-bb", default_value_t = 300.0)]
max_stack_bb: f32,
#[arg(short = 's', long = "seed")]
seed: Option<u64>,
#[command(flatten)]
tui: TuiFlags,
}
impl GenerateArgs {
fn validate(&self) -> Result<(), GenerateError> {
if self.min_players < 2 {
return Err(GenerateError::MinPlayersTooFew);
}
if self.min_players > self.max_players {
return Err(GenerateError::MinPlayersExceedsMax);
}
if self.min_stack_bb <= 0.0 || self.max_stack_bb <= 0.0 {
return Err(GenerateError::InvalidStackSize);
}
if self.min_stack_bb > self.max_stack_bb {
return Err(GenerateError::MinStackExceedsMax);
}
if self.small_blind <= 0.0 || self.big_blind <= 0.0 {
return Err(GenerateError::InvalidBlinds);
}
if self.small_blind >= self.big_blind {
return Err(GenerateError::SmallBlindExceedsBig);
}
Ok(())
}
}
fn load_configs(dir: &Path) -> Result<Vec<AgentConfig>, GenerateError> {
let mut configs = Vec::new();
let entries = fs::read_dir(dir).map_err(|e| GenerateError::ReadAgentsDir {
path: dir.display().to_string(),
source: e,
})?;
let mut dir_entries: Vec<_> = entries.filter_map(|e| e.ok()).collect();
dir_entries.sort_by_key(|e| e.path());
for entry in dir_entries {
let path = entry.path();
if path.extension().and_then(|s| s.to_str()) != Some("json") {
continue;
}
match ConfigAgentBuilder::from_file(&path) {
Ok(builder) => {
info!("Loaded config: {}", path.display());
configs.push(builder.config().clone());
}
Err(e) => {
warn!("Skipping '{}': {}", path.display(), e);
}
}
}
Ok(configs)
}
struct GameSetup {
game_state: GameState,
agents: Vec<Box<dyn Agent>>,
cfr_context: Option<(CFRState, TraversalSet)>,
num_players: usize,
}
struct GenerationContext<'a> {
rng: StdRng,
args: &'a GenerateArgs,
configs: &'a [AgentConfig],
consecutive_failures: usize,
games_completed: usize,
}
impl<'a> GenerationContext<'a> {
fn new(args: &'a GenerateArgs, configs: &'a [AgentConfig]) -> Self {
let rng = seeded_rng(args.seed);
Self {
rng,
args,
configs,
consecutive_failures: 0,
games_completed: 0,
}
}
fn is_done(&self) -> bool {
self.args.num_games > 0 && self.games_completed >= self.args.num_games
}
fn next_game(&mut self) -> Result<Option<GameSetup>, GenerateError> {
let num_players = self
.rng
.random_range(self.args.min_players..=self.args.max_players);
let stacks: Vec<f32> = (0..num_players)
.map(|_| {
self.rng
.random_range(self.args.min_stack_bb..=self.args.max_stack_bb)
* self.args.big_blind
})
.collect();
let dealer_idx = self.rng.random_range(0..num_players);
let selected_configs: Vec<&AgentConfig> = (0..num_players)
.map(|_| &self.configs[self.rng.random_range(0..self.configs.len())])
.collect();
let game_state = match GameStateBuilder::new()
.stacks(stacks)
.big_blind(self.args.big_blind)
.small_blind(self.args.small_blind)
.dealer_idx(dealer_idx)
.build()
{
Ok(gs) => gs,
Err(e) => {
warn!("Failed to build game state: {}", e);
return self.record_failure("game state build");
}
};
let cfr_context = AgentConfig::maybe_shared_cfr_context(
selected_configs.iter().copied(),
&game_state,
num_players,
);
let agents_result: Result<Vec<_>, AgentConfigError> = selected_configs
.iter()
.enumerate()
.map(|(idx, config)| {
let mut builder = ConfigAgentBuilder::new((*config).clone())?.player_idx(idx);
if let Some((ref cfr_state, ref ts)) = cfr_context {
builder = builder.cfr_context(cfr_state.clone(), ts.clone());
}
builder = builder.game_state(game_state.clone());
builder = builder.rng_seed(self.rng.random::<u64>());
Ok(builder.build())
})
.collect();
let agents = match agents_result {
Ok(a) => a,
Err(e) => {
warn!("Failed to build agents: {}", e);
return self.record_failure("agent build");
}
};
Ok(Some(GameSetup {
game_state,
agents,
cfr_context,
num_players,
}))
}
fn record_success(&mut self) {
self.games_completed += 1;
self.consecutive_failures = 0;
}
fn record_failure(&mut self, _phase: &str) -> Result<Option<GameSetup>, GenerateError> {
self.consecutive_failures += 1;
if self.consecutive_failures >= MAX_CONSECUTIVE_FAILURES {
return Err(GenerateError::TooManyFailures(self.consecutive_failures));
}
Ok(None)
}
}
struct GamePlan {
historians: Vec<Box<dyn Historian>>,
on_complete: Box<dyn FnOnce(usize) -> std::ops::ControlFlow<()> + Send>,
}
async fn run_generation_loop<F>(
args: &GenerateArgs,
configs: &[AgentConfig],
warn_on_build_failure: bool,
mut setup_game: F,
) -> Result<usize, GenerateError>
where
F: FnMut(&GameSetup) -> GamePlan,
{
let mut ctx = GenerationContext::new(args, configs);
loop {
if ctx.is_done() {
break;
}
let setup = match ctx.next_game()? {
Some(s) => s,
None => continue,
};
let GamePlan {
historians,
on_complete,
} = setup_game(&setup);
let sub_rng = StdRng::from_rng(&mut ctx.rng);
let sim_result = {
let mut builder = HoldemSimulationBuilder::default()
.game_state(setup.game_state)
.agents(setup.agents)
.historians(historians);
if let Some((cfr_state, traversal_set)) = setup.cfr_context {
builder = builder.cfr_context(cfr_state, traversal_set, true);
}
builder.build_with_rng(sub_rng)
};
let mut sim = match sim_result {
Ok(s) => s,
Err(e) => {
if warn_on_build_failure {
warn!("Failed to build simulation: {}", e);
}
ctx.record_failure("sim build")?;
continue;
}
};
sim.run().await;
drop(sim);
ctx.record_success();
if on_complete(ctx.games_completed).is_break() {
return Ok(ctx.games_completed);
}
}
Ok(ctx.games_completed)
}
async fn run_generation(args: &GenerateArgs, configs: &[AgentConfig]) -> Result<(), GenerateError> {
let report_interval = if args.num_games == 0 {
10
} else {
(args.num_games / 10).max(1)
};
let total = run_generation_loop(args, configs, true, |_setup| GamePlan {
historians: vec![Box::new(OpenHandHistoryHistorian::new(args.output.clone()))],
on_complete: Box::new(move |games_completed| {
if games_completed.is_multiple_of(report_interval) {
info!("Generated {} hands...", games_completed);
}
std::ops::ControlFlow::Continue(())
}),
})
.await?;
info!("Done. Generated {} hands total.", total);
Ok(())
}
async fn run_generation_background(
args: GenerateArgs,
configs: Vec<AgentConfig>,
tx: std::sync::mpsc::SyncSender<SimMessage<GameResult>>,
hand_store: HandStore,
) {
let inner_tx = tx.clone();
let join =
tokio::spawn(
async move { run_generation_inner(&args, &configs, &inner_tx, &hand_store).await },
)
.await;
match join {
Ok(Ok(())) => {
let _ = tx.send(SimMessage::Completed);
}
Ok(Err(GenerateError::TooManyFailures(n))) => {
let _ = tx.send(SimMessage::Error(SimError::TooManyFailures {
consecutive_failures: n,
}));
}
Ok(Err(_)) | Err(_) => {
let _ = tx.send(SimMessage::Error(SimError::Panic));
}
}
}
async fn run_generation_inner(
args: &GenerateArgs,
configs: &[AgentConfig],
tx: &std::sync::mpsc::SyncSender<SimMessage<GameResult>>,
hand_store: &HandStore,
) -> Result<(), GenerateError> {
run_generation_loop(args, configs, false, |setup| {
let num_players = setup.num_players;
let agent_names: Vec<String> = setup.agents.iter().map(|a| a.name().to_string()).collect();
let pre_offset = std::fs::metadata(&args.output)
.map(|m| m.len())
.unwrap_or(0);
let ohh_historian = OpenHandHistoryHistorian::new(args.output.clone());
let stats_storage = SharedStatsStorage::new(num_players);
let stats_historian = stats_storage.historian();
let big_blind = args.big_blind;
let hand_store = hand_store.clone();
let tx = tx.clone();
GamePlan {
historians: vec![Box::new(ohh_historian), Box::new(stats_historian)],
on_complete: Box::new(move |_games_completed| {
hand_store.push_offset(pre_offset);
let stats_snap = stats_storage.snapshot();
let ending_round = ending_round_from_stats(&stats_snap, num_players);
let profits: Vec<f32> = (0..num_players)
.map(|i| stats_snap.total_profit[i])
.collect();
let seat_stats: Vec<SeatStats> = (0..num_players)
.map(|i| SeatStats::from_storage(&stats_snap, i))
.collect();
drop(stats_snap);
let game_result = GameResult {
agent_names,
profits,
ending_round,
seat_stats,
big_blind,
};
if tx.send(SimMessage::GameResult(game_result)).is_err() {
std::ops::ControlFlow::Break(())
} else {
std::ops::ControlFlow::Continue(())
}
}),
}
})
.await?;
Ok(())
}
async fn run_generation_with_tui(
args: GenerateArgs,
configs: Vec<AgentConfig>,
) -> Result<(), GenerateError> {
let (tx, rx) = std::sync::mpsc::sync_channel::<SimMessage<GameResult>>(1024);
let games_target = if args.num_games > 0 {
Some(args.num_games)
} else {
None
};
let hand_store = HandStore::new(args.output.clone());
let bg_args = args.clone();
let bg_configs = configs.clone();
let bg_hand_store = hand_store.clone();
let bg_handle = tokio::spawn(async move {
run_generation_background(bg_args, bg_configs, tx, bg_hand_store).await;
});
run_blocking_tui_loop(
move || {
let handler = EventHandler::new(rx, std::time::Duration::from_millis(33));
let mut tui_app = App::new(games_target);
tui_app.hand_store = hand_store;
app::run_app(&mut tui_app, &handler)
},
bg_handle,
|| {},
)
.await?;
Ok(())
}
pub async fn run(
args: GenerateArgs,
default_budget: &rs_poker::arena::cfr::BudgetConfig,
) -> Result<(), GenerateError> {
let mut configs = load_configs(&args.agents_dir)?;
if configs.is_empty() {
return Err(GenerateError::NoConfigs(
args.agents_dir.display().to_string(),
));
}
info!("Loaded {} agent config(s)", configs.len());
crate::budget::override_budgets(&mut configs, default_budget);
args.validate()?;
if args.tui.should_use_tui() {
run_generation_with_tui(args, configs).await
} else {
run_generation(&args, &configs).await
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tui::state::RoundLabel;
use rs_poker::arena::historian::StatsStorage;
fn make_stats(num_players: usize) -> StatsStorage {
StatsStorage::new_with_num_players(num_players)
}
#[test]
fn test_ending_round_preflop() {
let stats = make_stats(2);
assert_eq!(ending_round_from_stats(&stats, 2), RoundLabel::Preflop);
}
#[test]
fn test_ending_round_flop() {
let mut stats = make_stats(2);
stats.flop_completes[0] = 1;
assert_eq!(ending_round_from_stats(&stats, 2), RoundLabel::Flop);
}
#[test]
fn test_ending_round_turn() {
let mut stats = make_stats(2);
stats.flop_completes[0] = 1;
stats.turn_completes[1] = 1;
assert_eq!(ending_round_from_stats(&stats, 2), RoundLabel::Turn);
}
#[test]
fn test_ending_round_river() {
let mut stats = make_stats(2);
stats.flop_completes[0] = 1;
stats.turn_completes[0] = 1;
stats.river_completes[0] = 1;
assert_eq!(ending_round_from_stats(&stats, 2), RoundLabel::River);
}
#[test]
fn test_ending_round_showdown() {
let mut stats = make_stats(2);
stats.showdown_count[0] = 1;
assert_eq!(ending_round_from_stats(&stats, 2), RoundLabel::Showdown);
}
#[test]
fn test_ending_round_ignores_players_beyond_num() {
let mut stats = make_stats(3);
stats.showdown_count[2] = 1;
assert_eq!(ending_round_from_stats(&stats, 2), RoundLabel::Preflop);
}
}