mod auto;
mod gambit;
mod json;
use cfr::PlayerNum;
use clap::{Parser, ValueEnum};
use serde::Serialize;
use std::borrow::Borrow;
use std::collections::HashMap;
use std::fs::File;
use std::io;
// Solver selected with `--method` (default: `sampled`). `main` calls the matching
// `solve_full` / `solve_sampled` / `solve_external` on the loaded game.
// Plain `//` comments are used here because `///` on `ValueEnum` variants would
// become clap help text.
#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
enum Method {
    // Runs `game.solve_full(max_iters, max_regret)`.
    Full,
    // Runs `game.solve_sampled(max_iters, max_regret)`.
    Sampled,
    // Runs `game.solve_external(max_iters, max_regret)`.
    External,
}
// Input game format selected with `--input-format` (default: `auto`).
// Plain `//` comments are used here because `///` on `ValueEnum` variants would
// become clap help text.
#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
enum InputFormat {
    // For a named input file, a `.json` suffix selects the JSON reader and an `.efg`
    // suffix the Gambit reader; any other file name, and stdin, go to `auto::from_reader`.
    Auto,
    // Always read with `gambit::from_reader`.
    Gambit,
    // Always read with `json::from_reader`.
    Json,
}
// Command-line arguments. The struct is documented with `//` because
// `#[clap(about)]` already takes the program description from the crate metadata;
// the `///` comments on the fields become the per-flag `--help` text.
#[derive(Parser, Debug)]
#[clap(author, version, about)]
struct Args {
    /// Pruning threshold applied to the solved strategies; the pruned strategies are
    /// reported only if they have strictly lower regret than the unpruned ones
    #[clap(short = 'p', long, value_parser, default_value_t = 0.0)]
    prune_threshold: f64,
    /// Regret target passed to the solver
    #[clap(short = 'r', long, value_parser, default_value_t = 0.0)]
    max_regret: f64,
    /// Maximum number of solver iterations (0 for no limit)
    #[clap(short = 't', long, value_parser, default_value_t = 1000)]
    max_iters: u64,
    /// Solver variant to run
    #[clap(short, long, value_enum, default_value_t = Method::Sampled)]
    method: Method,
    /// Format of the input game; `auto` uses a `.json` or `.efg` file extension when
    /// present and otherwise falls back to automatic detection
    #[clap(long, value_enum, default_value_t = InputFormat::Auto)]
    input_format: InputFormat,
    /// Input game file, or `-` to read from stdin
    #[clap(short, long, value_parser, default_value = "-")]
    input: String,
    /// Output JSON file, or `-` to write to stdout
    #[clap(short, long, value_parser, default_value = "-")]
    output: String,
}
/// One player's strategy as written to the JSON report: a map from infoset name to a
/// map from action name to that action's weight (presumably its probability).
///
/// The `From` conversion keeps only actions with a strictly positive weight and
/// panics if an infoset name appears twice.
#[derive(Serialize)]
struct Strategy(HashMap<String, HashMap<String, f64>>);
impl<I, A, S, T, N> From<I> for Strategy
where
    I: IntoIterator<Item = (S, A)>,
    A: IntoIterator<Item = (T, N)>,
    S: AsRef<str>,
    T: AsRef<str>,
    N: Borrow<f64>,
{
    /// Builds a strategy table from `(infoset, actions)` pairs, where each action is an
    /// `(action_name, weight)` pair.
    ///
    /// Only actions with a weight strictly greater than zero are kept, so zero, negative,
    /// and NaN weights are dropped.
    ///
    /// # Panics
    ///
    /// Panics if the same infoset name occurs more than once in `named_strategies`.
    fn from(named_strategies: I) -> Self {
        let mut table = HashMap::new();
        for (infoset, actions) in named_strategies {
            let support: HashMap<String, f64> = actions
                .into_iter()
                .filter_map(|(action, weight)| {
                    let weight = *weight.borrow();
                    if weight > 0.0 {
                        Some((action.as_ref().to_owned(), weight))
                    } else {
                        None
                    }
                })
                .collect();
            // A previous entry under the same key means the caller produced duplicate
            // infosets, which is a bug upstream rather than a user error.
            let displaced = table.insert(infoset.as_ref().to_owned(), support);
            assert!(
                displaced.is_none(),
                "internal error: found duplicate infosets"
            );
        }
        Strategy(table)
    }
}
/// JSON report written by `main` to `--output`.
#[derive(Serialize)]
struct Output {
    /// Player one's utility under the reported strategies plus the `sum` value returned
    /// by the game reader alongside the game.
    /// NOTE(review): `sum` is presumably a utility offset removed during loading — confirm.
    expected_one_utility: f64,
    /// Player one's regret under the reported strategies.
    player_one_regret: f64,
    /// Player two's regret under the reported strategies.
    player_two_regret: f64,
    /// Overall regret of the reported strategy profile.
    regret: f64,
    /// Player one's strategy; actions with non-positive weight are omitted.
    player_one_strategy: Strategy,
    /// Player two's strategy; actions with non-positive weight are omitted.
    player_two_strategy: Strategy,
}
fn main() {
let args = Args::parse();
let (game, sum) = if args.input == "-" {
match args.input_format {
InputFormat::Json => json::from_reader(io::stdin()),
InputFormat::Gambit => gambit::from_reader(io::stdin()),
InputFormat::Auto => auto::from_reader(io::stdin()),
}
} else {
match args.input_format {
InputFormat::Json => json::from_reader(File::open(args.input).unwrap()),
InputFormat::Auto if args.input.ends_with(".json") => {
json::from_reader(File::open(args.input).unwrap())
}
InputFormat::Gambit => gambit::from_reader(File::open(args.input).unwrap()),
InputFormat::Auto if args.input.ends_with(".efg") => {
gambit::from_reader(File::open(args.input).unwrap())
}
InputFormat::Auto => auto::from_reader(File::open(args.input).unwrap()),
}
};
let max_iters = if args.max_iters == 0 {
u64::MAX
} else {
args.max_iters
};
let (mut strategies, _) = match args.method {
Method::Full => game.solve_full(max_iters, args.max_regret),
Method::Sampled => game.solve_sampled(max_iters, args.max_regret),
Method::External => game.solve_external(max_iters, args.max_regret),
};
let mut info = strategies.get_info();
let mut pruned_strats = strategies.clone();
pruned_strats.truncate(args.prune_threshold);
let pruned_info = pruned_strats.get_info();
if pruned_info.regret() < info.regret() {
strategies = pruned_strats;
info = pruned_info;
}
let [one, two] = strategies.as_named();
let out = Output {
expected_one_utility: info.player_utility(PlayerNum::One) + sum,
player_one_regret: info.player_regret(PlayerNum::One),
player_two_regret: info.player_regret(PlayerNum::Two),
regret: info.regret(),
player_one_strategy: one.into(),
player_two_strategy: two.into(),
};
if args.output == "-" {
serde_json::to_writer(io::stdout(), &out).unwrap();
} else {
serde_json::to_writer(File::create(args.output).unwrap(), &out).unwrap();
};
}