use std::sync::Arc;
use crate::arena::agent::{
AllInAgent, CallingAgent, FoldingAgent, RandomAgent, RandomPotControlAgent,
};
use crate::arena::cfr::{
BasicCFRActionGenerator, BudgetConfig, CFRAgentBuilder, CFRState, ConfigurableActionConfig,
ConfigurableActionConfigError, ConfigurableActionGenerator, PreflopChartActionConfig,
PreflopChartActionGenerator, PreflopChartConfig, PreflopChartConfigError,
SimpleActionGenerator, TraversalSet,
};
use crate::arena::hand_estimator::{
HandDistributionEstimator, KnownHandsEstimator, UniformRandomEstimator,
};
use crate::arena::{Agent, GameState};
use serde::{Deserialize, Serialize};
use std::{io::ErrorKind, path::Path};
use thiserror::Error;
/// Declarative description of an agent, typically loaded from JSON.
///
/// Serialized as an internally tagged object: the `"type"` field holds the
/// variant name in snake_case (e.g. `{"type":"all_in"}`) and the remaining
/// keys are the variant's fields. Every variant accepts an optional `name`;
/// when it is absent, `ConfigAgentBuilder::build` derives one of the form
/// `"{AgentKind}-{player_idx}"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum AgentConfig {
    /// Builds an `AllInAgent` (default name `AllInAgent-{player_idx}`).
    AllIn {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Builds a `CallingAgent` (default name `CallingAgent-{player_idx}`).
    Calling {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Builds a `FoldingAgent` (default name `FoldingAgent-{player_idx}`).
    Folding {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Builds a `RandomAgent`; seeded when `ConfigAgentBuilder::rng_seed` is set.
    // NOTE(review): `rename_all = "snake_case"` already maps this variant to
    // "random", so this alias duplicates the primary tag.
    #[serde(alias = "random")]
    Random {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        /// Fold probabilities, each checked to lie in `[0.0, 1.0]`.
        /// Defaults to `[0.25, 0.30, 0.50]`.
        #[serde(default = "default_percent_fold")]
        percent_fold: Vec<f64>,
        /// Call probabilities, each checked to lie in `[0.0, 1.0]`.
        /// Defaults to `[0.5, 0.6, 0.45]`.
        #[serde(default = "default_percent_call")]
        percent_call: Vec<f64>,
    },
    /// Builds a `RandomPotControlAgent`; seeded when `ConfigAgentBuilder::rng_seed` is set.
    RandomPotControl {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        /// Required (no serde default); each value is checked to lie in `[0.0, 1.0]`.
        percent_call: Vec<f64>,
    },
    /// CFR agent using `BasicCFRActionGenerator` (default name `CFRAgent-{player_idx}`).
    CfrBasic {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        /// Search budget; an unset budget becomes `BudgetConfig::default()` at build time.
        #[serde(default)]
        exploration: CfrExploration,
        /// Hand-distribution estimator; defaults to `known`.
        #[serde(default)]
        hand_estimator: EstimatorConfig,
    },
    /// CFR agent using `SimpleActionGenerator` (default name `CFRSimpleAgent-{player_idx}`).
    CfrSimple {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        /// Search budget; an unset budget becomes `BudgetConfig::default()` at build time.
        #[serde(default)]
        exploration: CfrExploration,
        /// Hand-distribution estimator; defaults to `known`.
        #[serde(default)]
        hand_estimator: EstimatorConfig,
    },
    /// CFR agent using `ConfigurableActionGenerator`
    /// (default name `CFRConfigurableAgent-{player_idx}`).
    CfrConfigurable {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        /// Search budget; an unset budget becomes `BudgetConfig::default()` at build time.
        #[serde(default)]
        exploration: CfrExploration,
        /// Hand-distribution estimator; defaults to `known`.
        #[serde(default)]
        hand_estimator: EstimatorConfig,
        /// Required action-generation settings; validated by `AgentConfig::validate`.
        action_config: Box<ConfigurableActionConfig>,
    },
    /// CFR agent using `PreflopChartActionGenerator`
    /// (default name `CFRPreflopChartAgent-{player_idx}`).
    CfrPreflopChart {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        /// Search budget; an unset budget becomes `BudgetConfig::default()` at build time.
        #[serde(default)]
        exploration: CfrExploration,
        /// Hand-distribution estimator; defaults to `known`.
        #[serde(default)]
        hand_estimator: EstimatorConfig,
        /// Preflop chart. Defaults to the `"6max_gto"` preset, which
        /// `PreflopChartConfigOption::resolve` currently rejects, so in practice
        /// an inline chart must be supplied for validation to pass.
        #[serde(default)]
        preflop_config: PreflopChartConfigOption,
        /// Postflop action settings; `None` uses `ConfigurableActionConfig::default()`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        postflop_config: Option<Box<ConfigurableActionConfig>>,
    },
}
/// Source of the preflop chart for `AgentConfig::CfrPreflopChart`.
///
/// In JSON, a string selects a named preset and an object is parsed as an
/// inline [`PreflopChartConfig`]. `Serialize` is derived (untagged, so each
/// variant serializes as its bare payload); `Deserialize` is implemented by hand.
#[derive(Debug, Clone, Serialize)]
#[serde(untagged)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum PreflopChartConfigOption {
    /// Named preset; see `PreflopChartConfigOption::resolve` for availability.
    Preset(String),
    /// Fully specified chart.
    Inline(PreflopChartConfig),
}
impl<'de> Deserialize<'de> for PreflopChartConfigOption {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = serde_json::Value::deserialize(deserializer)?;
if let serde_json::Value::String(s) = &value {
return Ok(PreflopChartConfigOption::Preset(s.clone()));
}
serde_json::from_value::<PreflopChartConfig>(value)
.map(PreflopChartConfigOption::Inline)
.map_err(serde::de::Error::custom)
}
}
impl Default for PreflopChartConfigOption {
fn default() -> Self {
PreflopChartConfigOption::Preset("6max_gto".to_string())
}
}
impl PreflopChartConfigOption {
    /// Produces a concrete [`PreflopChartConfig`].
    ///
    /// # Errors
    ///
    /// No presets are bundled, so every [`PreflopChartConfigOption::Preset`]
    /// yields [`AgentConfigError::PreflopChartPresetUnavailable`]; inline
    /// configurations are returned as a clone.
    pub fn resolve(&self) -> Result<PreflopChartConfig, AgentConfigError> {
        match self {
            Self::Inline(chart) => Ok(chart.clone()),
            Self::Preset(preset_name) => Err(AgentConfigError::PreflopChartPresetUnavailable(
                preset_name.to_owned(),
            )),
        }
    }
}
/// Serde default for `AgentConfig::Random::percent_fold`.
fn default_percent_fold() -> Vec<f64> {
    Vec::from([0.25, 0.30, 0.50])
}
/// Serde default for `AgentConfig::Random::percent_call`.
fn default_percent_call() -> Vec<f64> {
    Vec::from([0.5, 0.6, 0.45])
}
/// Exploration settings shared by all CFR agent variants.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct CfrExploration {
    /// Search budget. `None` means "not specified": `AgentConfig::fill_default_budget`
    /// may fill it in, otherwise `BudgetConfig::default()` is used when the agent is built.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub budget: Option<BudgetConfig>,
}
/// Selects the hand-distribution estimator a CFR agent uses.
///
/// Serialized as a snake_case string: `"known"` or `"uniform"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[serde(rename_all = "snake_case")]
pub enum EstimatorConfig {
    /// Uses `KnownHandsEstimator`; this is the default.
    #[default]
    Known,
    /// Uses `UniformRandomEstimator`.
    Uniform,
}
impl EstimatorConfig {
    /// Instantiates the selected estimator behind a shared trait object.
    pub fn build(&self) -> Arc<dyn HandDistributionEstimator> {
        let estimator: Arc<dyn HandDistributionEstimator> = match *self {
            Self::Uniform => Arc::new(UniformRandomEstimator),
            Self::Known => Arc::new(KnownHandsEstimator),
        };
        estimator
    }
}
/// Errors produced while loading, parsing, or validating an [`AgentConfig`].
#[derive(Debug, Error)]
pub enum AgentConfigError {
    /// A random-agent probability lies outside `[0.0, 1.0]` (NaN is also rejected,
    /// since it is not contained in that range).
    #[error("Invalid probability value: {0} (must be between 0.0 and 1.0)")]
    InvalidProbability(f64),
    /// The JSON input could not be parsed into an [`AgentConfig`].
    #[error("JSON parsing error: {0}")]
    JsonError(#[from] serde_json::Error),
    /// Reading a config file failed.
    #[error("File I/O error: {0}")]
    IoError(#[from] std::io::Error),
    /// `ConfigurableActionConfig::validate` rejected an action config.
    #[error("invalid configurable action config: {0}")]
    ConfigurableActionConfig(#[from] ConfigurableActionConfigError),
    /// `PreflopChartConfig::validate` rejected an inline preflop chart.
    #[error("invalid preflop chart config: {0}")]
    PreflopChartConfig(#[from] PreflopChartConfigError),
    /// A preflop chart was requested by preset name; no presets are currently resolvable.
    #[error(
        "preflop chart preset '{0}' is not available; use inline configuration instead. \
        See examples/configs/preflop_6max_rfi.json for an example"
    )]
    PreflopChartPresetUnavailable(String),
}
impl AgentConfig {
pub fn is_cfr(&self) -> bool {
matches!(
self,
AgentConfig::CfrBasic { .. }
| AgentConfig::CfrSimple { .. }
| AgentConfig::CfrConfigurable { .. }
| AgentConfig::CfrPreflopChart { .. }
)
}
pub fn maybe_shared_cfr_context<'a, I>(
configs: I,
game_state: &GameState,
num_players: usize,
) -> Option<(CFRState, TraversalSet)>
where
I: IntoIterator<Item = &'a AgentConfig>,
{
let has_cfr = configs.into_iter().any(|c| c.is_cfr());
has_cfr.then(|| {
(
CFRState::new(game_state.clone()),
TraversalSet::new(num_players),
)
})
}
pub fn fill_default_budget(&mut self, default: &BudgetConfig) {
let target = match self {
AgentConfig::CfrBasic { exploration, .. }
| AgentConfig::CfrSimple { exploration, .. }
| AgentConfig::CfrConfigurable { exploration, .. }
| AgentConfig::CfrPreflopChart { exploration, .. } => &mut exploration.budget,
_ => return,
};
if target.is_none() {
*target = Some(default.clone());
}
}
pub fn validate(&self) -> Result<(), AgentConfigError> {
match self {
AgentConfig::Random {
percent_fold,
percent_call,
..
} => {
validate_probabilities(percent_fold)?;
validate_probabilities(percent_call)?;
}
AgentConfig::RandomPotControl { percent_call, .. } => {
validate_probabilities(percent_call)?;
}
AgentConfig::CfrConfigurable { action_config, .. } => {
action_config.validate()?;
}
AgentConfig::CfrPreflopChart { preflop_config, .. } => {
let resolved = preflop_config.resolve()?;
resolved.validate()?;
}
_ => {}
}
Ok(())
}
}
/// Ensures every value in `probs` lies in the closed interval `[0.0, 1.0]`.
///
/// Reports the first offending value; NaN counts as out of range because it is
/// not contained in the interval.
fn validate_probabilities(probs: &[f64]) -> Result<(), AgentConfigError> {
    let first_invalid = probs
        .iter()
        .copied()
        .find(|candidate| !(0.0..=1.0).contains(candidate));
    match first_invalid {
        Some(bad) => Err(AgentConfigError::InvalidProbability(bad)),
        None => Ok(()),
    }
}
/// Fallback agent name of the form `"{agent_kind}-{player_idx}"`, e.g. `CallingAgent-2`.
fn default_agent_name(agent_kind: &str, player_idx: usize) -> String {
    let mut label = String::from(agent_kind);
    label.push('-');
    label.push_str(&player_idx.to_string());
    label
}
/// Returns the configured name if present, otherwise the kind/index fallback name.
fn resolve_agent_name(name: &Option<String>, agent_kind: &str, player_idx: usize) -> String {
    match name {
        Some(explicit) => explicit.clone(),
        None => default_agent_name(agent_kind, player_idx),
    }
}
/// Builder that turns a validated [`AgentConfig`] into a boxed [`Agent`].
///
/// Cloning the builder lets one config produce agents for several seats.
#[derive(Debug, Clone)]
pub struct ConfigAgentBuilder {
    /// Agent description; validated when the builder is constructed.
    config: AgentConfig,
    /// Seat index; must be set before building.
    player_idx: Option<usize>,
    /// Game state recorded by the `game_state` setter; the build step itself does not read it.
    game_state: Option<GameState>,
    /// CFR state for CFR variants, either supplied explicitly or created from the game state.
    cfr_state: Option<CFRState>,
    /// Traversal set paired with `cfr_state`.
    traversal_set: Option<TraversalSet>,
    /// Optional RNG seed, used only by the random agent variants.
    rng_seed: Option<u64>,
}
impl ConfigAgentBuilder {
pub fn new(config: AgentConfig) -> Result<Self, AgentConfigError> {
config.validate()?;
Ok(Self {
config,
player_idx: None,
game_state: None,
cfr_state: None,
traversal_set: None,
rng_seed: None,
})
}
pub fn player_idx(mut self, idx: usize) -> Self {
self.player_idx = Some(idx);
self
}
pub fn game_state(mut self, game_state: GameState) -> Self {
if self.config.is_cfr() && self.cfr_state.is_none() {
self.cfr_state = Some(CFRState::new(game_state.clone()));
self.traversal_set = Some(TraversalSet::new(game_state.num_players));
}
self.game_state = Some(game_state);
self
}
pub fn cfr_context(mut self, cfr_state: CFRState, traversal_set: TraversalSet) -> Self {
self.cfr_state = Some(cfr_state);
self.traversal_set = Some(traversal_set);
self
}
pub fn rng_seed(mut self, seed: u64) -> Self {
self.rng_seed = Some(seed);
self
}
pub fn config(&self) -> &AgentConfig {
&self.config
}
pub fn from_json(json: &str) -> Result<Self, AgentConfigError> {
let config: AgentConfig = serde_json::from_str(json)?;
Self::new(config)
}
pub fn from_file(path: impl AsRef<Path>) -> Result<Self, AgentConfigError> {
let json = std::fs::read_to_string(path)?;
Self::from_json(&json)
}
pub fn from_str_or_file(input: &str) -> Result<Self, AgentConfigError> {
match Self::from_file(input) {
Ok(builder) => Ok(builder),
Err(AgentConfigError::IoError(err)) if err.kind() == ErrorKind::NotFound => {
Self::from_json(input)
}
Err(err) => Err(err),
}
}
pub fn build(self) -> Box<dyn Agent> {
let player_idx = self.player_idx.expect("player_idx is required");
match &self.config {
AgentConfig::AllIn { name } => Box::new(AllInAgent::new(resolve_agent_name(
name,
"AllInAgent",
player_idx,
))),
AgentConfig::Calling { name } => Box::new(CallingAgent::new(resolve_agent_name(
name,
"CallingAgent",
player_idx,
))),
AgentConfig::Folding { name } => Box::new(FoldingAgent::new(resolve_agent_name(
name,
"FoldingAgent",
player_idx,
))),
AgentConfig::Random {
name,
percent_fold,
percent_call,
} => {
let agent_name = resolve_agent_name(name, "RandomAgent", player_idx);
if let Some(seed) = self.rng_seed {
Box::new(RandomAgent::new_with_seed(
agent_name,
percent_fold.clone(),
percent_call.clone(),
seed,
))
} else {
Box::new(RandomAgent::new(
agent_name,
percent_fold.clone(),
percent_call.clone(),
))
}
}
AgentConfig::RandomPotControl { name, percent_call } => {
let agent_name = resolve_agent_name(name, "RandomPotControlAgent", player_idx);
if let Some(seed) = self.rng_seed {
Box::new(RandomPotControlAgent::new_with_seed(
agent_name,
percent_call.clone(),
seed,
))
} else {
Box::new(RandomPotControlAgent::new(agent_name, percent_call.clone()))
}
}
AgentConfig::CfrBasic {
name,
exploration,
hand_estimator,
} => {
let (cfr_state, traversal_set) = self.resolve_cfr_context();
let budget = exploration.budget.clone().unwrap_or_default().build();
let builder = CFRAgentBuilder::<BasicCFRActionGenerator>::new()
.name(resolve_agent_name(name, "CFRAgent", player_idx))
.player_idx(player_idx)
.cfr_state(cfr_state)
.traversal_set(traversal_set)
.budget(budget)
.estimator(hand_estimator.build())
.action_gen_config(());
Box::new(builder.build())
}
AgentConfig::CfrSimple {
name,
exploration,
hand_estimator,
} => {
let (cfr_state, traversal_set) = self.resolve_cfr_context();
let budget = exploration.budget.clone().unwrap_or_default().build();
let builder = CFRAgentBuilder::<SimpleActionGenerator>::new()
.name(resolve_agent_name(name, "CFRSimpleAgent", player_idx))
.player_idx(player_idx)
.cfr_state(cfr_state)
.traversal_set(traversal_set)
.budget(budget)
.estimator(hand_estimator.build())
.action_gen_config(());
Box::new(builder.build())
}
AgentConfig::CfrConfigurable {
name,
exploration,
hand_estimator,
action_config,
} => {
let (cfr_state, traversal_set) = self.resolve_cfr_context();
let budget = exploration.budget.clone().unwrap_or_default().build();
let builder = CFRAgentBuilder::<ConfigurableActionGenerator>::new()
.name(resolve_agent_name(name, "CFRConfigurableAgent", player_idx))
.player_idx(player_idx)
.cfr_state(cfr_state)
.traversal_set(traversal_set)
.budget(budget)
.estimator(hand_estimator.build())
.action_gen_config(action_config.as_ref().clone());
Box::new(builder.build())
}
AgentConfig::CfrPreflopChart {
name,
exploration,
hand_estimator,
preflop_config,
postflop_config,
} => {
let resolved_preflop_config = preflop_config
.resolve()
.expect("Invalid preflop config - should have been validated");
let (cfr_state, traversal_set) = self.resolve_cfr_context();
let action_config = PreflopChartActionConfig {
preflop_config: resolved_preflop_config,
postflop_config: postflop_config
.as_ref()
.map(|c| c.as_ref().clone())
.unwrap_or_default(),
};
let budget = exploration.budget.clone().unwrap_or_default().build();
let builder = CFRAgentBuilder::<PreflopChartActionGenerator>::new()
.name(resolve_agent_name(name, "CFRPreflopChartAgent", player_idx))
.player_idx(player_idx)
.cfr_state(cfr_state)
.traversal_set(traversal_set)
.budget(budget)
.estimator(hand_estimator.build())
.action_gen_config(action_config);
Box::new(builder.build())
}
}
}
fn resolve_cfr_context(&self) -> (CFRState, TraversalSet) {
let cfr_state = self
.cfr_state
.clone()
.expect("cfr_context() or game_state() is required for CFR agents");
let traversal_set = self
.traversal_set
.clone()
.expect("cfr_context() or game_state() is required for CFR agents");
(cfr_state, traversal_set)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arena::GameStateBuilder;
    use crate::arena::cfr::BudgetItem;

    // ---- Serialization of the simple agent kinds -------------------------

    #[test]
    fn test_serialize_all_in() {
        let config = AgentConfig::AllIn { name: None };
        let json = serde_json::to_string(&config).unwrap();
        assert_eq!(json, r#"{"type":"all_in"}"#);
    }

    #[test]
    fn test_serialize_calling() {
        let config = AgentConfig::Calling { name: None };
        let json = serde_json::to_string(&config).unwrap();
        assert_eq!(json, r#"{"type":"calling"}"#);
    }

    #[test]
    fn test_serialize_folding() {
        let config = AgentConfig::Folding { name: None };
        let json = serde_json::to_string(&config).unwrap();
        assert_eq!(json, r#"{"type":"folding"}"#);
    }

    #[test]
    fn test_deserialize_all_in() {
        let json = r#"{"type":"all_in"}"#;
        let config: AgentConfig = serde_json::from_str(json).unwrap();
        match config {
            AgentConfig::AllIn { name } => assert!(name.is_none()),
            _ => panic!("Expected AllIn variant"),
        }
    }

    #[test]
    fn test_deserialize_calling() {
        let json = r#"{"type":"calling"}"#;
        let config: AgentConfig = serde_json::from_str(json).unwrap();
        match config {
            AgentConfig::Calling { name } => assert!(name.is_none()),
            _ => panic!("Expected Calling variant"),
        }
    }

    #[test]
    fn test_deserialize_folding() {
        let json = r#"{"type":"folding"}"#;
        let config: AgentConfig = serde_json::from_str(json).unwrap();
        match config {
            AgentConfig::Folding { name } => assert!(name.is_none()),
            _ => panic!("Expected Folding variant"),
        }
    }

    // ---- Random agents -----------------------------------------------------

    #[test]
    fn test_deserialize_random_with_defaults() {
        let json = r#"{"type":"random"}"#;
        let config: AgentConfig = serde_json::from_str(json).unwrap();
        match config {
            AgentConfig::Random {
                name,
                percent_fold,
                percent_call,
            } => {
                assert!(name.is_none());
                assert_eq!(percent_fold, vec![0.25, 0.30, 0.50]);
                assert_eq!(percent_call, vec![0.5, 0.6, 0.45]);
            }
            _ => panic!("Expected Random variant"),
        }
    }

    #[test]
    fn test_deserialize_random_with_params() {
        let json = r#"{"type":"random","percent_fold":[0.1,0.2],"percent_call":[0.6,0.7]}"#;
        let config: AgentConfig = serde_json::from_str(json).unwrap();
        match config {
            AgentConfig::Random {
                name,
                percent_fold,
                percent_call,
            } => {
                assert!(name.is_none());
                assert_eq!(percent_fold, vec![0.1, 0.2]);
                assert_eq!(percent_call, vec![0.6, 0.7]);
            }
            _ => panic!("Expected Random variant"),
        }
    }

    #[test]
    fn test_deserialize_random_pot_control() {
        let json = r#"{"type":"random_pot_control","percent_call":[0.5,0.3]}"#;
        let config: AgentConfig = serde_json::from_str(json).unwrap();
        match config {
            AgentConfig::RandomPotControl { name, percent_call } => {
                assert!(name.is_none());
                assert_eq!(percent_call, vec![0.5, 0.3]);
            }
            _ => panic!("Expected RandomPotControl variant"),
        }
    }

    // ---- Validation --------------------------------------------------------

    #[test]
    fn test_validate_invalid_probability_too_high() {
        let config = AgentConfig::Random {
            name: None,
            percent_fold: vec![1.5],
            percent_call: vec![0.5],
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validate_invalid_probability_negative() {
        let config = AgentConfig::Random {
            name: None,
            percent_fold: vec![0.25],
            percent_call: vec![-0.1],
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validate_valid_config() {
        let config = AgentConfig::Random {
            name: None,
            percent_fold: vec![0.25, 0.30],
            percent_call: vec![0.5, 0.6],
        };
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validate_edge_cases() {
        let config = AgentConfig::Random {
            name: None,
            percent_fold: vec![0.0, 1.0],
            percent_call: vec![0.0, 1.0],
        };
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_round_trip_serialization() {
        let configs = vec![
            AgentConfig::AllIn { name: None },
            AgentConfig::Calling { name: None },
            AgentConfig::Folding { name: None },
            AgentConfig::Random {
                name: None,
                percent_fold: vec![0.2],
                percent_call: vec![0.5],
            },
            AgentConfig::RandomPotControl {
                name: None,
                percent_call: vec![0.4, 0.3],
            },
        ];
        for config in configs {
            let json = serde_json::to_string(&config).unwrap();
            let deserialized: AgentConfig = serde_json::from_str(&json).unwrap();
            // `AgentConfig` does not implement `PartialEq`, so compare Debug output.
            assert_eq!(format!("{:?}", config), format!("{:?}", deserialized));
        }
    }

    // ---- ConfigAgentBuilder ------------------------------------------------

    #[test]
    fn test_create_from_config() {
        let config = AgentConfig::AllIn { name: None };
        let _agent = ConfigAgentBuilder::new(config)
            .unwrap()
            .player_idx(0)
            .build();
    }

    #[test]
    fn test_from_json() {
        let json = r#"{"type":"calling"}"#;
        let generator = ConfigAgentBuilder::from_json(json).unwrap();
        assert!(matches!(generator.config, AgentConfig::Calling { .. }));
    }

    #[test]
    fn test_from_json_with_params() {
        let json = r#"{"type":"random","percent_fold":[0.2],"percent_call":[0.5]}"#;
        let generator = ConfigAgentBuilder::from_json(json).unwrap();
        match generator.config {
            AgentConfig::Random {
                name,
                percent_fold,
                percent_call,
            } => {
                assert!(name.is_none());
                assert_eq!(percent_fold, vec![0.2]);
                assert_eq!(percent_call, vec![0.5]);
            }
            _ => panic!("Expected Random variant"),
        }
    }

    #[test]
    fn test_validation_on_construction() {
        let json = r#"{"type":"random","percent_fold":[1.5],"percent_call":[0.5]}"#;
        assert!(ConfigAgentBuilder::from_json(json).is_err());
    }

    #[test]
    fn test_build_multiple_agents() {
        let config = AgentConfig::Random {
            name: None,
            percent_fold: vec![0.25],
            percent_call: vec![0.5],
        };
        let builder = ConfigAgentBuilder::new(config).unwrap();
        let _agent1 = builder.clone().player_idx(0).build();
        let _agent2 = builder.player_idx(1).build();
    }

    #[test]
    fn test_from_str_or_file_inline_json() {
        let json = r#"{"type":"all_in"}"#;
        let generator = ConfigAgentBuilder::from_str_or_file(json).unwrap();
        assert!(matches!(generator.config, AgentConfig::AllIn { .. }));
    }

    // ---- CFR configs -------------------------------------------------------

    #[test]
    fn test_cfr_basic_config() {
        let json = r#"{
"type": "cfr_basic",
"exploration": {
"budget": [
{ "type": "per_depth_iterations", "counts": [10, 5, 1] }
]
}
}"#;
        let config: AgentConfig = serde_json::from_str(json).unwrap();
        match config {
            AgentConfig::CfrBasic {
                name, exploration, ..
            } => {
                assert!(name.is_none());
                assert_eq!(
                    exploration.budget,
                    Some(BudgetConfig(vec![BudgetItem::PerDepthIterations {
                        counts: vec![10, 5, 1],
                        fallback: 1,
                    }]))
                );
            }
            _ => panic!("Expected CfrBasic variant"),
        }
    }

    #[test]
    fn test_cfr_basic_defaults() {
        let json = r#"{"type":"cfr_basic"}"#;
        let config: AgentConfig = serde_json::from_str(json).unwrap();
        match config {
            AgentConfig::CfrBasic {
                name, exploration, ..
            } => {
                assert!(name.is_none());
                assert_eq!(exploration, CfrExploration::default());
                assert!(exploration.budget.is_none());
            }
            _ => panic!("Expected CfrBasic variant"),
        }
    }

    #[test]
    fn cfr_configurable_parses_exploration() {
        let json = r#"{
"type": "cfr_configurable",
"name": "X",
"exploration": {
"budget": [
{ "type": "per_depth_iterations", "counts": [100,5,3,1] }
]
},
"action_config": {
"preflop": {"call_enabled": true,"raise_mult":[2.5],"pot_mult":[],"setup_shove":false,"all_in":true},
"flop": {"call_enabled": true,"raise_mult":[],"pot_mult":[0.33,1.0],"setup_shove":false,"all_in":false},
"turn": {"call_enabled": true,"raise_mult":[],"pot_mult":[0.67,1.0],"setup_shove":true,"all_in":false},
"river": {"call_enabled": true,"raise_mult":[],"pot_mult":[0.67,1.5],"setup_shove":true,"all_in":true}
}
}"#;
        let cfg: AgentConfig = serde_json::from_str(json).unwrap();
        match cfg {
            AgentConfig::CfrConfigurable { exploration, .. } => {
                assert_eq!(
                    exploration.budget,
                    Some(BudgetConfig(vec![BudgetItem::PerDepthIterations {
                        counts: vec![100, 5, 3, 1],
                        fallback: 1,
                    }]))
                );
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn cfr_preflop_chart_default_preset_is_rejected() {
        // Without an inline chart the preflop config defaults to the "6max_gto"
        // preset, which cannot be resolved, so validation must fail.
        let cfg: AgentConfig = serde_json::from_str(r#"{"type":"cfr_preflop_chart"}"#).unwrap();
        match cfg.validate() {
            Err(AgentConfigError::PreflopChartPresetUnavailable(preset)) => {
                assert_eq!(preset, "6max_gto");
            }
            other => panic!("expected PreflopChartPresetUnavailable, got {other:?}"),
        }
    }

    #[test]
    fn test_cfr_agent_builder() {
        let config = AgentConfig::CfrBasic {
            name: None,
            exploration: CfrExploration {
                budget: Some(BudgetConfig(vec![BudgetItem::PerDepthIterations {
                    counts: vec![5, 1],
                    fallback: 1,
                }])),
            },
            hand_estimator: EstimatorConfig::default(),
        };
        let game_state = GameStateBuilder::new()
            .num_players_with_stack(2, 100.0)
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        let _agent = ConfigAgentBuilder::new(config)
            .unwrap()
            .player_idx(0)
            .game_state(game_state)
            .build();
    }

    #[test]
    fn test_cfr_agent_builder_depth_based() {
        let config = AgentConfig::CfrBasic {
            name: Some("TestCFR".to_string()),
            exploration: CfrExploration {
                budget: Some(BudgetConfig(vec![BudgetItem::PerDepthIterations {
                    counts: vec![3, 2, 1],
                    fallback: 1,
                }])),
            },
            hand_estimator: EstimatorConfig::default(),
        };
        let game_state = GameStateBuilder::new()
            .num_players_with_stack(3, 100.0)
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        let agent = ConfigAgentBuilder::new(config)
            .unwrap()
            .player_idx(1)
            .game_state(game_state)
            .build();
        assert_eq!(agent.name(), "TestCFR");
    }

    #[test]
    fn cfr_configurable_example_loads_and_builds() {
        // Path is relative to the crate root, which is cargo's cwd for tests.
        let json = std::fs::read_to_string("examples/configs/cfr_configurable.json")
            .expect("example config should be readable");
        let config: AgentConfig =
            serde_json::from_str(&json).expect("example config should deserialize");
        assert!(matches!(&config, AgentConfig::CfrConfigurable { .. }));
        let game_state = GameStateBuilder::new()
            .num_players_with_stack(2, 100.0)
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        let _agent = ConfigAgentBuilder::new(config)
            .unwrap()
            .player_idx(0)
            .game_state(game_state)
            .build();
    }

    // ---- Naming helpers ----------------------------------------------------

    #[test]
    fn test_default_agent_name_format() {
        assert_eq!(
            default_agent_name("TestAgent", 5),
            "TestAgent-5",
            "agent name should be 'TestAgent-5'"
        );
    }

    #[test]
    fn test_validate_random_pot_control_match_arm() {
        let config = AgentConfig::RandomPotControl {
            name: Some("Test".to_string()),
            percent_call: vec![0.5, 0.6],
        };
        let result = config.validate();
        assert!(
            result.is_ok(),
            "Valid RandomPotControl should pass validation"
        );
        let invalid_config = AgentConfig::RandomPotControl {
            name: Some("Test".to_string()),
            percent_call: vec![1.5],
        };
        let result = invalid_config.validate();
        assert!(
            result.is_err(),
            "Invalid RandomPotControl should fail validation"
        );
    }

    #[test]
    fn test_from_str_or_file_match_guard() {
        let json_input = r#"{"type": "all_in"}"#;
        let result = ConfigAgentBuilder::from_str_or_file(json_input);
        assert!(
            result.is_ok(),
            "Valid JSON should parse when file not found"
        );
        let nonexistent = "/nonexistent/path/to/config.json";
        let result = ConfigAgentBuilder::from_str_or_file(nonexistent);
        assert!(
            result.is_err(),
            "Non-existent file with invalid JSON should fail"
        );
    }

    // ---- Exploration budgets -----------------------------------------------

    #[test]
    fn cfr_exploration_default_is_empty_budget() {
        let e = CfrExploration::default();
        assert!(e.budget.is_none());
    }

    #[test]
    fn cfr_exploration_round_trips() {
        let json = r#"{
"budget": [
{ "type": "deadline", "millis": 250 },
{ "type": "per_depth_iterations", "counts": [50,40,30,20,10,5] }
]
}"#;
        let e: CfrExploration = serde_json::from_str(json).unwrap();
        assert_eq!(
            e.budget,
            Some(BudgetConfig(vec![
                BudgetItem::Deadline { millis: 250 },
                BudgetItem::PerDepthIterations {
                    counts: vec![50, 40, 30, 20, 10, 5],
                    fallback: 1,
                },
            ]))
        );
    }

    #[test]
    fn fill_default_budget_fills_none_and_preserves_some() {
        let mut none_cfg = AgentConfig::CfrBasic {
            name: None,
            exploration: CfrExploration::default(),
            hand_estimator: EstimatorConfig::default(),
        };
        let custom = BudgetConfig(vec![BudgetItem::IterationCount { max: 7 }]);
        none_cfg.fill_default_budget(&custom);
        match &none_cfg {
            AgentConfig::CfrBasic { exploration, .. } => {
                assert_eq!(exploration.budget.as_ref(), Some(&custom));
            }
            _ => unreachable!(),
        }
        let explicit = BudgetConfig(vec![BudgetItem::Deadline { millis: 42 }]);
        let mut some_cfg = AgentConfig::CfrBasic {
            name: None,
            exploration: CfrExploration {
                budget: Some(explicit.clone()),
            },
            hand_estimator: EstimatorConfig::default(),
        };
        some_cfg.fill_default_budget(&custom);
        match &some_cfg {
            AgentConfig::CfrBasic { exploration, .. } => {
                assert_eq!(exploration.budget.as_ref(), Some(&explicit));
            }
            _ => unreachable!(),
        }
        let mut non_cfr = AgentConfig::AllIn { name: None };
        non_cfr.fill_default_budget(&custom);
        assert!(matches!(non_cfr, AgentConfig::AllIn { name: None }));
    }

    #[test]
    fn test_resolve_agent_name_with_none() {
        let name = resolve_agent_name(&None, "TestKind", 3);
        assert_eq!(name, "TestKind-3", "Should use default name format");
    }

    #[test]
    fn test_resolve_agent_name_with_some() {
        let name = resolve_agent_name(&Some("CustomName".to_string()), "TestKind", 3);
        assert_eq!(name, "CustomName", "Should use provided name");
    }

    // ---- Hand estimators ---------------------------------------------------

    #[test]
    fn estimator_config_defaults_to_known() {
        assert_eq!(EstimatorConfig::default(), EstimatorConfig::Known);
    }

    #[test]
    fn estimator_config_round_trips_json() {
        let parsed: EstimatorConfig = serde_json::from_str("\"uniform\"").unwrap();
        assert_eq!(parsed, EstimatorConfig::Uniform);
        assert_eq!(
            serde_json::to_string(&EstimatorConfig::Known).unwrap(),
            "\"known\""
        );
        for config in [EstimatorConfig::Known, EstimatorConfig::Uniform] {
            let json = serde_json::to_string(&config).unwrap();
            let back: EstimatorConfig = serde_json::from_str(&json).unwrap();
            assert_eq!(back, config);
        }
    }

    #[tokio::test]
    async fn estimator_config_builds_estimator() {
        use crate::arena::hand_estimator::HandDistribution;
        use crate::core::{Card, Hand};
        let mut game_state = GameStateBuilder::new()
            .num_players_with_stack(2, 100.0)
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        game_state.hands[0] = Hand::new_with_cards(vec![Card::from(0), Card::from(1)]);
        game_state.hands[1] = Hand::new_with_cards(vec![Card::from(2), Card::from(3)]);
        game_state.round_data.to_act_idx = 0;
        let known = EstimatorConfig::Known
            .build()
            .estimate(&game_state, None)
            .await;
        assert!(matches!(known.get(1), Some(HandDistribution::PointMass(_))));
        let uniform = EstimatorConfig::Uniform
            .build()
            .estimate(&game_state, None)
            .await;
        assert!(matches!(
            uniform.get(1),
            Some(HandDistribution::Weighted(_))
        ));
    }

    #[test]
    fn cfr_configurable_parses_hand_estimator() {
        let json = r#"{
"type": "cfr_configurable",
"hand_estimator": "uniform",
"action_config": {}
}"#;
        let cfg: AgentConfig = serde_json::from_str(json).unwrap();
        match cfg {
            AgentConfig::CfrConfigurable { hand_estimator, .. } => {
                assert_eq!(hand_estimator, EstimatorConfig::Uniform);
            }
            other => panic!("expected CfrConfigurable, got {other:?}"),
        }
    }

    #[test]
    fn cfr_configurable_defaults_hand_estimator_to_known() {
        let json = r#"{ "type": "cfr_configurable", "action_config": {} }"#;
        let cfg: AgentConfig = serde_json::from_str(json).unwrap();
        match cfg {
            AgentConfig::CfrConfigurable { hand_estimator, .. } => {
                assert_eq!(hand_estimator, EstimatorConfig::Known);
            }
            other => panic!("expected CfrConfigurable, got {other:?}"),
        }
    }
}