mod auto;
mod gambit;
mod json;
use cfr::{PlayerNum, RegretParams, SolveMethod};
use clap::{Parser, ValueEnum};
use serde::Serialize;
use std::borrow::Borrow;
use std::collections::HashMap;
use std::fs::File;
use std::io;
use std::io::BufReader;
// Solving algorithm selectable on the command line via `--method`.
//
// `main` maps each variant onto the `cfr::SolveMethod` variant of the same name.
// Plain `//` comments are used on purpose: clap's derive macros turn `///` doc
// comments into help text, which would change the program's `--help` output.
#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
enum Method {
// Selects `SolveMethod::Full`.
Full,
// Selects `SolveMethod::Sampled`.
Sampled,
// Selects `SolveMethod::External`; this is the default declared on `Args::method`.
External,
}
// Format of the game description, selectable via `--input-format`.
//
// Plain `//` comments are used on purpose: clap's derive macros turn `///` doc
// comments into help text, which would change the program's `--help` output.
#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
enum InputFormat {
// Let the program choose. When reading a named file, `main` first looks at the
// `.json` / `.efg` suffix and otherwise falls back to `auto::from_reader`;
// standard input always goes to `auto::from_reader`.
Auto,
// Always parse with `gambit::from_reader`.
Gambit,
// Always parse with `json::from_reader`.
Json,
}
// Regret parameter preset selectable via `--discount`.
//
// Each variant is translated by `Discount::into_params` into the `RegretParams`
// constructor named in the comment next to it.
// Plain `//` comments are used on purpose: clap's derive macros turn `///` doc
// comments into help text, which would change the program's `--help` output.
#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
enum Discount {
// `RegretParams::vanilla()`
Vanilla,
// `RegretParams::lcfr()`
Lcfr,
// `RegretParams::cfr_plus()`
CfrPlus,
// `RegretParams::dcfr()`; this is the default declared on `Args::discount`.
Dcfr,
// `RegretParams::dcfr_prune()`
DcfrPrune,
}
impl Discount {
    /// Translate the command-line preset into the solver's [`RegretParams`].
    ///
    /// Every variant corresponds to exactly one `RegretParams` constructor; the
    /// match is kept exhaustive so that adding a variant forces a decision here.
    fn into_params(self) -> RegretParams {
        match self {
            Self::Vanilla => RegretParams::vanilla(),
            Self::Lcfr => RegretParams::lcfr(),
            Self::CfrPlus => RegretParams::cfr_plus(),
            Self::Dcfr => RegretParams::dcfr(),
            Self::DcfrPrune => RegretParams::dcfr_prune(),
        }
    }
}
// Command-line arguments, parsed by clap's derive API.
//
// Plain `//` comments are used on purpose: clap's derive macros turn `///` doc
// comments into help text, which would change the program's `--help` output.
#[derive(Parser, Debug)]
#[clap(author, version, about)]
struct Args {
// `-c` / `--clip-threshold` (default 0.0). Handed to `truncate` on a copy of
// the solved strategies; `main` keeps that copy only if its regret is lower.
#[clap(short, long, value_parser, default_value_t = 0.0)]
clip_threshold: f64,
// `-r` / `--max-regret` (default 0.0). Passed through to `game.solve`.
// NOTE(review): presumably the regret at which solving may stop early — confirm
// against the `cfr` crate.
#[clap(short = 'r', long, value_parser, default_value_t = 0.0)]
max_regret: f64,
// `-t` / `--max-iters` (default 1000). `main` replaces 0 with `u64::MAX`.
#[clap(short = 't', long, value_parser, default_value_t = 1000)]
max_iters: u64,
// `-p` / `--parallel` (default 0). Passed through to `game.solve`.
// NOTE(review): presumably a thread count where 0 lets the solver decide — confirm.
#[clap(short, long, value_parser, default_value_t = 0)]
parallel: usize,
// `-m` / `--method` (default `external`). Which solving method to run.
#[clap(short, long, value_enum, default_value_t = Method::External)]
method: Method,
// `--input-format` (default `auto`). Long flag only, so it has no short form.
#[clap(long, value_enum, default_value_t = InputFormat::Auto)]
input_format: InputFormat,
// `-d` / `--discount` (default `dcfr`). Regret parameter preset.
#[clap(short, long, value_enum, default_value_t = Discount::Dcfr)]
discount: Discount,
// `-i` / `--input` (default `-`). Path of the game file; `-` means standard input.
#[clap(short, long, value_parser, default_value = "-")]
input: String,
// `-o` / `--output` (default `-`). Path for the JSON result; `-` means standard output.
#[clap(short, long, value_parser, default_value = "-")]
output: String,
}
/// Serializable strategy for one player.
///
/// The outer map is keyed by information-set name; each inner map goes from
/// action name to the `f64` weight attached to that action. The `From`
/// implementation only stores actions whose weight is strictly positive.
#[derive(Serialize)]
struct Strategy(HashMap<String, HashMap<String, f64>>);
impl<I, A, S, T, N> From<I> for Strategy
where
    I: IntoIterator<Item = (S, A)>,
    A: IntoIterator<Item = (T, N)>,
    S: AsRef<str>,
    T: AsRef<str>,
    N: Borrow<f64>,
{
    /// Build a [`Strategy`] from `(infoset name, actions)` pairs, where `actions`
    /// yields `(action name, weight)` pairs.
    ///
    /// Actions whose weight is not strictly greater than zero are left out, so an
    /// information set may end up with an empty action map.
    ///
    /// # Panics
    ///
    /// Panics if the same information-set name is produced more than once.
    fn from(named_strategies: I) -> Self {
        let mut by_infoset: HashMap<String, HashMap<String, f64>> = HashMap::new();
        for (infoset, actions) in named_strategies {
            let key = infoset.as_ref().to_owned();
            let weights: HashMap<String, f64> = actions
                .into_iter()
                .filter_map(|(action, prob)| {
                    let weight: f64 = *prob.borrow();
                    if weight > 0.0 {
                        Some((action.as_ref().to_owned(), weight))
                    } else {
                        None
                    }
                })
                .collect();
            // `insert` returns the displaced value, so `Some` means a repeated name.
            let displaced = by_infoset.insert(key, weights);
            assert!(
                displaced.is_none(),
                "internal error: found duplicate infosets"
            );
        }
        Strategy(by_infoset)
    }
}
/// JSON document written by `main` once solving has finished.
///
/// All numeric fields are read from the strategies' info object in `main`.
#[derive(Serialize)]
struct Output {
/// Overall regret, taken from `info.regret()`.
regret: f64,
/// Player one's utility: `info.player_utility(PlayerNum::One)` plus the offset
/// returned alongside the game by the reader.
player_one_utility: f64,
/// Player two's utility: `info.player_utility(PlayerNum::Two)` minus that same offset.
player_two_utility: f64,
/// Player one's regret, taken from `info.player_regret(PlayerNum::One)`.
player_one_regret: f64,
/// Player two's regret, taken from `info.player_regret(PlayerNum::Two)`.
player_two_regret: f64,
/// Player one's strategy, keyed by information set and then by action.
player_one_strategy: Strategy,
/// Player two's strategy, keyed by information set and then by action.
player_two_strategy: Strategy,
}
/// Entry point: parse the command line, load a game, solve it, and write the
/// result as JSON.
///
/// Steps, in order:
/// 1. Read the game from `--input` (`-` is standard input) with the parser
///    chosen by `--input-format`.
/// 2. Run `game.solve` with the requested method, limits and discount preset.
/// 3. Truncate a copy of the strategies with `--clip-threshold` and keep the
///    copy only if its regret is strictly lower than the untruncated one.
/// 4. Serialize an [`Output`] to `--output` (`-` is standard output).
///
/// # Panics
///
/// Panics if the input file cannot be opened or the output file cannot be
/// created (the message names the path and the OS error), if `game.solve`
/// returns an error, or if serializing or flushing the output fails.
fn main() {
    let args = Args::parse();

    // The reader hands back the game together with a scalar that is added to
    // player one's utility and subtracted from player two's further down.
    // NOTE(review): presumably the constant removed to make the game zero-sum — confirm.
    let (game, sum) = if args.input == "-" {
        let mut inp = io::stdin().lock();
        match args.input_format {
            InputFormat::Json => json::from_reader(&mut inp),
            InputFormat::Gambit => gambit::from_reader(&mut inp),
            InputFormat::Auto => auto::from_reader(&mut inp),
        }
    } else {
        // The path comes straight from the user, so say which file failed and why
        // instead of surfacing a bare `unwrap` message.
        let file = File::open(&args.input).unwrap_or_else(|err| {
            panic!("failed to open input file '{}': {}", args.input, err)
        });
        let mut inp = BufReader::new(file);
        match args.input_format {
            InputFormat::Json => json::from_reader(&mut inp),
            // With `auto`, a recognised file suffix picks the parser directly.
            InputFormat::Auto if args.input.ends_with(".json") => json::from_reader(&mut inp),
            InputFormat::Gambit => gambit::from_reader(&mut inp),
            InputFormat::Auto if args.input.ends_with(".efg") => gambit::from_reader(&mut inp),
            InputFormat::Auto => auto::from_reader(&mut inp),
        }
    };

    // An iteration limit of zero means "no limit".
    let max_iters = if args.max_iters == 0 {
        u64::MAX
    } else {
        args.max_iters
    };
    let method = match args.method {
        Method::Full => SolveMethod::Full,
        Method::Sampled => SolveMethod::Sampled,
        Method::External => SolveMethod::External,
    };
    let (mut strategies, _) = game
        .solve(
            method,
            max_iters,
            args.max_regret,
            args.parallel,
            Some(args.discount.into_params()),
        )
        .unwrap();

    // Try a truncated copy of the solution and adopt it only when it is an
    // improvement in terms of regret.
    let mut info = strategies.get_info();
    let mut pruned_strats = strategies.clone();
    pruned_strats.truncate(args.clip_threshold);
    let pruned_info = pruned_strats.get_info();
    if pruned_info.regret() < info.regret() {
        strategies = pruned_strats;
        info = pruned_info;
    }

    let [one, two] = strategies.as_named();
    let out = Output {
        regret: info.regret(),
        player_one_utility: info.player_utility(PlayerNum::One) + sum,
        player_two_utility: info.player_utility(PlayerNum::Two) - sum,
        player_one_regret: info.player_regret(PlayerNum::One),
        player_two_regret: info.player_regret(PlayerNum::Two),
        player_one_strategy: one.into(),
        player_two_strategy: two.into(),
    };

    // Lock stdout once, or open the requested file, and put either behind a
    // single buffer so the serializer's many small writes are batched.
    let sink: Box<dyn io::Write> = if args.output == "-" {
        Box::new(io::stdout().lock())
    } else {
        let file = File::create(&args.output).unwrap_or_else(|err| {
            panic!("failed to create output file '{}': {}", args.output, err)
        });
        Box::new(file)
    };
    let mut writer = io::BufWriter::new(sink);
    serde_json::to_writer(&mut writer, &out).unwrap();
    // Dropping a `BufWriter` discards flush errors, so flush explicitly.
    io::Write::flush(&mut writer).unwrap();
}
#[cfg(test)]
mod tests {
    use super::Args;
    use clap::CommandFactory;

    /// Builds the clap command generated for [`Args`] and runs clap's own
    /// `debug_assert` consistency check on it.
    #[test]
    fn test_cli() {
        let command = <Args as CommandFactory>::command();
        command.debug_assert();
    }
}