use std::path::{Path, PathBuf};
use crate::arena::agent::AgentConfig;
use crate::arena::cfr::BudgetConfig;
use super::config::ComparisonConfig;
use super::error::{ComparisonError, Result};
use super::runner::ArenaComparison;
/// Fluent builder that collects agents and optional table settings and turns
/// them into an `ArenaComparison` via [`ComparisonBuilder::build`].
///
/// Every setting is stored as an `Option`; a `None` means "not set by the
/// caller", and `build` substitutes its own default for the numeric settings.
#[derive(Debug, Default)]
pub struct ComparisonBuilder {
/// `(name, config)` pairs in the order they were added to the builder.
agents: Vec<(String, AgentConfig)>,
/// Number of games; `build` substitutes 1000 when unset.
num_games: Option<usize>,
/// Players per table; `build` substitutes 3 when unset.
players_per_table: Option<usize>,
/// Big blind amount; `build` substitutes 10.0 when unset.
big_blind: Option<f32>,
/// Small blind amount; `build` substitutes 5.0 when unset.
small_blind: Option<f32>,
/// Minimum stack setting; `build` substitutes 100.0 when unset.
/// NOTE(review): the `_bb` suffix presumably means "in big blinds" — confirm.
min_stack_bb: Option<f32>,
/// Maximum stack setting; `build` substitutes 100.0 when unset.
/// NOTE(review): the `_bb` suffix presumably means "in big blinds" — confirm.
max_stack_bb: Option<f32>,
/// Ante amount; `build` substitutes 0.0 when unset.
ante: Option<f32>,
/// Optional output directory, handed to `ComparisonConfig` unchanged.
output_dir: Option<PathBuf>,
/// Optional seed, handed to `ComparisonConfig` unchanged.
seed: Option<u64>,
}
impl ComparisonBuilder {
    /// Creates a builder with no agents and every setting left unset.
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the number of games (`build` uses 1000 when this is never called).
    pub fn num_games(self, num_games: usize) -> Self {
        Self {
            num_games: Some(num_games),
            ..self
        }
    }

    /// Sets the players per table (`build` uses 3 when this is never called).
    pub fn players_per_table(self, players_per_table: usize) -> Self {
        Self {
            players_per_table: Some(players_per_table),
            ..self
        }
    }

    /// Sets the big blind (`build` uses 10.0 when this is never called).
    pub fn big_blind(self, big_blind: f32) -> Self {
        Self {
            big_blind: Some(big_blind),
            ..self
        }
    }

    /// Sets the small blind (`build` uses 5.0 when this is never called).
    pub fn small_blind(self, small_blind: f32) -> Self {
        Self {
            small_blind: Some(small_blind),
            ..self
        }
    }

    /// Sets the minimum stack setting (`build` uses 100.0 when this is never called).
    pub fn min_stack_bb(self, min_stack_bb: f32) -> Self {
        Self {
            min_stack_bb: Some(min_stack_bb),
            ..self
        }
    }

    /// Sets the maximum stack setting (`build` uses 100.0 when this is never called).
    pub fn max_stack_bb(self, max_stack_bb: f32) -> Self {
        Self {
            max_stack_bb: Some(max_stack_bb),
            ..self
        }
    }

    /// Sets the ante (`build` uses 0.0 when this is never called).
    pub fn ante(self, ante: f32) -> Self {
        Self {
            ante: Some(ante),
            ..self
        }
    }

    /// Sets the output directory; the path is copied into an owned `PathBuf`.
    pub fn output_dir<P: AsRef<Path>>(self, output_dir: P) -> Self {
        let owned = output_dir.as_ref().to_owned();
        Self {
            output_dir: Some(owned),
            ..self
        }
    }

    /// Sets the seed that is forwarded to the comparison configuration.
    pub fn seed(self, seed: u64) -> Self {
        Self {
            seed: Some(seed),
            ..self
        }
    }

    /// Appends one agent under an explicitly chosen name.
    pub fn add_agent(mut self, name: String, config: AgentConfig) -> Self {
        let entry = (name, config);
        self.agents.push(entry);
        self
    }

    /// Appends one agent, naming it after the config's own `name` or, when the
    /// config has none, `Agent<N>` where `N` is the number of agents already
    /// held by the builder.
    pub fn add_agent_config(self, config: AgentConfig) -> Self {
        let fallback = format!("Agent{}", self.agents.len());
        let name = get_agent_name(&config, &fallback);
        self.add_agent(name, config)
    }

    /// Appends every `(name, config)` pair, preserving the given order.
    pub fn add_agents(mut self, mut agents: Vec<(String, AgentConfig)>) -> Self {
        self.agents.append(&mut agents);
        self
    }

    /// Loads agent configs from the `.json` files in `dir` and appends them.
    ///
    /// # Errors
    ///
    /// Returns whatever error the module-level `load_agents_from_dir` helper
    /// reports; the builder is consumed in that case.
    pub fn load_agents_from_dir<P: AsRef<Path>>(self, dir: P) -> Result<Self> {
        load_agents_from_dir(dir.as_ref()).map(|loaded| self.add_agents(loaded))
    }

    /// Calls `fill_default_budget(default)` on every agent config added so far.
    ///
    /// Agents added after this call are not touched.
    pub fn fill_default_budget(mut self, default: &BudgetConfig) -> Self {
        self.agents.iter_mut().for_each(|(_, cfg)| {
            cfg.fill_default_budget(default);
        });
        self
    }

    /// Applies defaults, validates everything and produces the comparison.
    ///
    /// # Errors
    ///
    /// * `ComparisonError::MissingConfig` when no agent has been added.
    /// * Any error from `ComparisonConfig::validate` (checked first).
    /// * Any error from an agent's `validate` (checked in insertion order).
    pub fn build(self) -> Result<ArenaComparison> {
        let Self {
            agents,
            num_games,
            players_per_table,
            big_blind,
            small_blind,
            min_stack_bb,
            max_stack_bb,
            ante,
            output_dir,
            seed,
        } = self;

        if agents.is_empty() {
            let message =
                "No agents configured. Use add_agent(), add_agent_config(), or load_agents_from_dir()";
            return Err(ComparisonError::MissingConfig(message.to_string()));
        }

        let config = ComparisonConfig {
            num_games: num_games.unwrap_or(1000),
            players_per_table: players_per_table.unwrap_or(3),
            big_blind: big_blind.unwrap_or(10.0),
            small_blind: small_blind.unwrap_or(5.0),
            min_stack_bb: min_stack_bb.unwrap_or(100.0),
            max_stack_bb: max_stack_bb.unwrap_or(100.0),
            ante: ante.unwrap_or(0.0),
            output_dir,
            seed,
        };

        // Table-level settings are checked before the individual agents.
        config.validate(agents.len())?;
        for cfg in agents.iter().map(|(_, cfg)| cfg) {
            cfg.validate()?;
        }

        Ok(ArenaComparison::new(config, agents))
    }
}
/// Returns the name stored in `config`, or `fallback_name` when the config
/// carries no name.
///
/// The match lists every variant explicitly, so adding a new `AgentConfig`
/// variant fails to compile here until it is handled.
fn get_agent_name(config: &AgentConfig, fallback_name: &str) -> String {
    let configured = match config {
        AgentConfig::AllIn { name, .. }
        | AgentConfig::Calling { name, .. }
        | AgentConfig::Folding { name, .. }
        | AgentConfig::Random { name, .. }
        | AgentConfig::RandomPotControl { name, .. }
        | AgentConfig::CfrBasic { name, .. }
        | AgentConfig::CfrSimple { name, .. }
        | AgentConfig::CfrConfigurable { name, .. }
        | AgentConfig::CfrPreflopChart { name, .. } => name,
    };
    configured.as_deref().unwrap_or(fallback_name).to_owned()
}
/// Loads every `*.json` agent config found directly inside `dir`.
///
/// Each file is parsed and validated by `load_agent_config`; files that fail
/// are logged with `tracing::warn!` and skipped. An agent is named after the
/// `name` in its config, falling back to the file stem (or `"unknown"` when
/// the stem is not valid UTF-8).
///
/// The returned list is sorted by agent name. Candidate files are sorted by
/// path before loading, so agents that share a name keep a reproducible
/// relative order instead of depending on the platform-specific order in
/// which `read_dir` yields entries.
///
/// # Errors
///
/// * `ComparisonError::Io` (`NotFound`) when `dir` does not exist.
/// * `ComparisonError::Io` (`InvalidInput`) when `dir` is not a directory.
/// * Any I/O error raised while listing the directory.
/// * `ComparisonError::NoAgentsFound` when no file yields a valid config.
fn load_agents_from_dir(dir: &Path) -> Result<Vec<(String, AgentConfig)>> {
    if !dir.exists() {
        return Err(ComparisonError::Io(std::io::Error::new(
            std::io::ErrorKind::NotFound,
            format!("Directory does not exist: {}", dir.display()),
        )));
    }
    if !dir.is_dir() {
        return Err(ComparisonError::Io(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("Path is not a directory: {}", dir.display()),
        )));
    }

    // Gather the candidate files first so they can be put into a stable order.
    let mut json_paths = Vec::new();
    for entry in std::fs::read_dir(dir)? {
        let path = entry?.path();
        if path.extension().and_then(|ext| ext.to_str()) == Some("json") {
            json_paths.push(path);
        }
    }
    // `read_dir` order is platform and filesystem dependent; sorting the paths
    // makes the load order (and tie order after the name sort) deterministic.
    json_paths.sort();

    let mut agents = Vec::with_capacity(json_paths.len());
    for path in &json_paths {
        let fallback_name = path
            .file_stem()
            .and_then(|stem| stem.to_str())
            .unwrap_or("unknown");
        match load_agent_config(path) {
            Ok(config) => {
                let agent_name = get_agent_name(&config, fallback_name);
                agents.push((agent_name, config));
            }
            Err(e) => {
                // Best effort: one broken file must not abort the whole load.
                tracing::warn!("Skipping invalid config file {}: {}", path.display(), e);
            }
        }
    }

    if agents.is_empty() {
        return Err(ComparisonError::NoAgentsFound(dir.display().to_string()));
    }

    // Stable sort: equal names keep the path order established above.
    agents.sort_by(|a, b| a.0.cmp(&b.0));
    Ok(agents)
}
/// Reads the file at `path`, parses it as a JSON `AgentConfig` and validates it.
///
/// # Errors
///
/// * The I/O error from reading the file.
/// * `ComparisonError::ParseConfig` (carrying the path and the serde error)
///   when the contents are not a valid `AgentConfig`.
/// * Whatever `AgentConfig::validate` reports for the parsed config.
///
/// NOTE(review): validation runs at load time, i.e. before a later
/// `fill_default_budget` on the builder can touch this config — confirm that
/// configs relying on a default budget are still expected to pass here.
fn load_agent_config(path: &Path) -> Result<AgentConfig> {
    let raw = std::fs::read_to_string(path)?;
    let parsed = match serde_json::from_str::<AgentConfig>(&raw) {
        Ok(cfg) => cfg,
        Err(source) => {
            return Err(ComparisonError::ParseConfig {
                path: path.display().to_string(),
                source,
            });
        }
    };
    parsed.validate()?;
    Ok(parsed)
}
#[cfg(test)]
mod tests {
    use super::super::error::ComparisonConfigError;
    use super::*;

    /// Builder holding one unnamed folding agent followed by one unnamed
    /// calling agent (they end up named `Agent0` and `Agent1`).
    fn fold_and_call() -> ComparisonBuilder {
        ComparisonBuilder::new()
            .add_agent_config(AgentConfig::Folding { name: None })
            .add_agent_config(AgentConfig::Calling { name: None })
    }

    /// With nothing configured, `build` falls back to the documented defaults.
    #[test]
    fn test_builder_defaults() {
        let built = fold_and_call()
            .add_agent_config(AgentConfig::AllIn { name: None })
            .build()
            .unwrap();
        let cfg = built.config();

        assert_eq!(cfg.num_games, 1000);
        assert_eq!(cfg.players_per_table, 3);
        assert_eq!(cfg.big_blind, 10.0);
        assert_eq!(cfg.small_blind, 5.0);
    }

    /// Explicit settings override every default they correspond to.
    #[test]
    fn test_builder_custom_config() {
        let built = fold_and_call()
            .num_games(500)
            .players_per_table(2)
            .big_blind(20.0)
            .small_blind(10.0)
            .min_stack_bb(50.0)
            .max_stack_bb(150.0)
            .seed(42)
            .build()
            .unwrap();
        let cfg = built.config();

        assert_eq!(cfg.num_games, 500);
        assert_eq!(cfg.players_per_table, 2);
        assert_eq!(cfg.big_blind, 20.0);
        assert_eq!(cfg.small_blind, 10.0);
        assert_eq!(cfg.min_stack_bb, 50.0);
        assert_eq!(cfg.max_stack_bb, 150.0);
        assert_eq!(cfg.seed, Some(42));
    }

    /// Building without any agent is rejected as a missing configuration.
    #[test]
    fn test_builder_no_agents_error() {
        let outcome = ComparisonBuilder::new().build();
        assert!(outcome.is_err());

        let err = outcome.unwrap_err();
        assert!(matches!(err, ComparisonError::MissingConfig(_)));
    }

    /// A table size of one player fails configuration validation.
    #[test]
    fn test_builder_validation_error() {
        let outcome = fold_and_call().players_per_table(1).build();
        assert!(outcome.is_err());

        let err = outcome.unwrap_err();
        assert!(matches!(
            err,
            ComparisonError::InvalidConfig(ComparisonConfigError::PlayersPerTableTooSmall(1))
        ));
    }

    /// A name stored in the config wins over the fallback.
    #[test]
    fn test_get_agent_name_with_name() {
        let named = AgentConfig::AllIn {
            name: Some("MyAgent".to_string()),
        };
        assert_eq!(get_agent_name(&named, "fallback"), "MyAgent");
    }

    /// Without a stored name the fallback is returned.
    #[test]
    fn test_get_agent_name_without_name() {
        let unnamed = AgentConfig::AllIn { name: None };
        assert_eq!(get_agent_name(&unnamed, "fallback"), "fallback");
    }
}