// cfr 0.4.3
//
// Counterfactual regret minimization solver for two-player zero-sum incomplete-information games
// Documentation
use cfr::{Game, GameNode, IntoGameNode, PlayerNum};
use gambit_parser::{Error, ExtensiveFormGame, Node};
use num_traits::cast::ToPrimitive;
use std::collections::{HashMap, HashSet};
use std::io::Read;

/// A gambit node paired with the game-wide information needed to convert it into a `GameNode`
struct JoinedNode<'a> {
    /// The gambit node this wrapper converts
    node: &'a Node<'a>,
    /// Game-wide information shared by every node of the tree
    info: &'a GlobalInfo,
    /// Sum of player one's payoffs from the outcomes attached to the non-terminal nodes above
    /// this one
    cum_payoff: f64,
}

/// Information about the whole game, computed once by `get_global_info`
struct GlobalInfo {
    /// For each player (indexed by player number minus one), the name of every infoset keyed by
    /// its number
    infoset_names: [HashMap<u64, String>; 2],
    /// Player one's payoff for each outcome, keyed by outcome number
    outcomes: HashMap<u64, f64>,
    /// Midpoint of the smallest and largest per-terminal average of the two players' cumulative
    /// payoffs; this is subtracted from every terminal payoff
    sum: f64,
}

impl IntoGameNode for JoinedNode<'_> {
    type PlayerInfo = String;
    type Action = String;
    type ChanceInfo = u64;
    // NOTE these are vectors because we need to sort by action to preserve infosets
    type Outcomes = Vec<(f64, Self)>;
    type Actions = Vec<(String, Self)>;

    fn into_game_node(self) -> GameNode<Self> {
        match self.node {
            Node::Terminal(term) => {
                let node_payoff = self.info.outcomes.get(&term.outcome()).unwrap();
                GameNode::Terminal(self.cum_payoff + node_payoff - self.info.sum)
            }
            Node::Chance(chance) => {
                let node_payoff = if chance.outcome() == 0 {
                    0.0
                } else {
                    *self.info.outcomes.get(&chance.outcome()).unwrap()
                };
                let mut outcomes: Vec<_> = chance
                    .actions()
                    .iter()
                    .map(|(act, prob, node)| {
                        (
                            act.to_string(),
                            prob.to_f64().unwrap(),
                            JoinedNode {
                                node,
                                info: self.info,
                                cum_payoff: self.cum_payoff + node_payoff,
                            },
                        )
                    })
                    .collect();
                // NOTE also sort by probs in case actions have identical names, note this will
                // still error when the game is parsed
                outcomes.sort_unstable_by(|(act1, prob1, _), (act2, prob2, _)| {
                    (act1, prob1).partial_cmp(&(act2, prob2)).unwrap()
                });
                GameNode::Chance(
                    Some(chance.infoset()),
                    outcomes
                        .into_iter()
                        .map(|(_, prob, node)| (prob, node))
                        .collect(),
                )
            }
            Node::Player(player) => {
                let player_num = match player.player_num() {
                    1 => PlayerNum::One,
                    2 => PlayerNum::Two,
                    _ => panic!("internal error: invalid player number"),
                };
                let node_payoff = if player.outcome() == 0 {
                    0.0
                } else {
                    *self.info.outcomes.get(&player.outcome()).unwrap()
                };
                let infoset = self.info.infoset_names[player.player_num() - 1]
                    .get(&player.infoset())
                    .unwrap()
                    .to_string();

                let mut actions: Vec<_> = player
                    .actions()
                    .iter()
                    .map(|(act, node)| {
                        (
                            act.to_string(),
                            JoinedNode {
                                node,
                                info: self.info,
                                cum_payoff: self.cum_payoff + node_payoff,
                            },
                        )
                    })
                    .collect();
                actions.sort_unstable_by(|(act1, _), (act2, _)| act1.cmp(act2));
                GameNode::Player(player_num, infoset, actions)
            }
        }
    }
}

/// Walks the tree under `root` and gathers the information shared by every node
///
/// This collects the payoffs attached to each outcome, a name for every player infoset
/// (infosets without a name are named after their number), and the constant used to shift the
/// payoffs so the game becomes zero sum.
///
/// # Panics
///
/// This assumes the game was already validated and panics instead of returning errors.
/// It panics if a payoff can't be converted to a float or an outcome doesn't have exactly two
/// payoffs, if the number of an unnamed infoset is already used as the name of another
/// infoset of the same player, if a terminal payoff isn't finite, or if the game isn't
/// constant sum.
fn get_global_info(root: &Node<'_>) -> GlobalInfo {
    /// Adds the payoffs in `extra` onto the running totals in `totals`
    fn accumulate(totals: &mut [f64; 2], extra: &[f64; 2]) {
        for (total, pay) in totals.iter_mut().zip(extra) {
            *total += *pay
        }
    }

    // first pass: record the payoffs of every outcome and the infosets of every player
    let mut infoset_names: [HashMap<_, _>; 2] = Default::default();
    let mut infosets: [HashSet<_>; 2] = Default::default();
    let mut outcomes: HashMap<_, [f64; 2]> = HashMap::new();
    let mut pending = vec![root];
    while let Some(node) = pending.pop() {
        match node {
            Node::Terminal(terminal) => {
                let pair: [f64; 2] = terminal
                    .outcome_payoffs()
                    .iter()
                    .map(|pay| pay.to_f64().unwrap())
                    .collect::<Vec<f64>>()
                    .try_into()
                    .unwrap();
                outcomes.insert(terminal.outcome(), pair);
            }
            Node::Chance(chance) => {
                if let Some(pays) = chance.outcome_payoffs() {
                    let pair: [f64; 2] = pays
                        .iter()
                        .map(|pay| pay.to_f64().unwrap())
                        .collect::<Vec<f64>>()
                        .try_into()
                        .unwrap();
                    outcomes.insert(chance.outcome(), pair);
                }
                for (_, _, next) in chance.actions().iter() {
                    pending.push(next);
                }
            }
            Node::Player(player) => {
                if let Some(pays) = player.outcome_payoffs() {
                    let pair: [f64; 2] = pays
                        .iter()
                        .map(|pay| pay.to_f64().unwrap())
                        .collect::<Vec<f64>>()
                        .try_into()
                        .unwrap();
                    outcomes.insert(player.outcome(), pair);
                }
                let slot = player.player_num() - 1;
                if let Some(name) = player.infoset_name() {
                    // the first name seen for an infoset wins
                    infoset_names[slot]
                        .entry(player.infoset())
                        .or_insert_with(|| name.to_string());
                }
                infosets[slot].insert(player.infoset());
                for (_, next) in player.actions().iter() {
                    pending.push(next);
                }
            }
        }
    }

    // infosets that were never given a name are named after their number
    for (named, all_seen) in infoset_names.iter_mut().zip(infosets) {
        let fallback: Vec<_> = {
            let taken: HashSet<_> = named.values().collect();
            all_seen
                .into_iter()
                .filter(|id| !named.contains_key(id))
                .map(|id| {
                    let label = id.to_string();
                    if taken.contains(&label) {
                        panic!("some infosets had no names, but the string of their number was used as the name of another infoset : https://github.com/erikbrinkman/cfr#duplicate-infosets");
                    }
                    (id, label)
                })
                .collect()
        };
        named.extend(fallback);
    }

    // second pass: accumulate payoffs down to every terminal node to find the spread of the
    // averaged payoffs and the spread of player one's payoffs
    let mut min = f64::INFINITY;
    let mut max = f64::NEG_INFINITY;
    let mut one_min = f64::INFINITY;
    let mut one_max = f64::NEG_INFINITY;
    let mut pending = vec![(root, [0.0; 2])];
    while let Some((node, mut cum_pays)) = pending.pop() {
        match node {
            Node::Terminal(terminal) => {
                accumulate(&mut cum_pays, outcomes.get(&terminal.outcome()).unwrap());
                let [one, two] = cum_pays;
                // average of the two payoffs, i.e. half of what they sum to
                let sum = one + (two - one) / 2.0;
                assert!(
                    sum.is_finite(),
                    "received non-finite payoffs in gambit format; make sure payoffs fit in a double"
                );
                min = min.min(sum);
                max = max.max(sum);
                one_min = one_min.min(one);
                one_max = one_max.max(one);
            }
            Node::Chance(chance) => {
                let id = chance.outcome();
                if id != 0 {
                    accumulate(&mut cum_pays, outcomes.get(&id).unwrap());
                }
                for (_, _, next) in chance.actions().iter() {
                    pending.push((next, cum_pays));
                }
            }
            Node::Player(player) => {
                let id = player.outcome();
                if id != 0 {
                    accumulate(&mut cum_pays, outcomes.get(&id).unwrap());
                }
                for (_, next) in player.actions().iter() {
                    pending.push((next, cum_pays));
                }
            }
        }
    }
    // the averages may only vary by a small fraction of the range of player one's payoffs
    if (max - min) * 1000.0 > (one_max - one_min) {
        panic!(
            "gambit file wasn't constant sum : https://github.com/erikbrinkman/cfr#constant-sum"
        );
    }

    // only player one's payoffs are kept, shifted later by the midpoint of the averages
    GlobalInfo {
        infoset_names,
        outcomes: outcomes
            .into_iter()
            .map(|(id, pays)| (id, pays[0]))
            .collect(),
        sum: min + (max - min) / 2.0,
    }
}

/// Parses a gambit extensive-form game from `raw` and converts it into a `Game`
///
/// On success this returns the game together with the constant that was subtracted from the
/// terminal payoffs.
///
/// # Errors
///
/// Returns the parser's error if `raw` isn't a valid gambit file, indicating that another
/// format should be tried.
///
/// # Panics
///
/// Panics if the game doesn't have exactly two players, if gathering the global information
/// fails, or if a compact game representation can't be extracted from the tree.
pub fn from_str(raw: &str) -> Result<(Game<String, String>, f64), Error<'_>> {
    let gambit = ExtensiveFormGame::try_from(raw)?;
    let num_players = gambit.player_names().len();
    assert!(
        num_players == 2,
        "game file has {} players, but `cfr` only supports two player games",
        num_players
    );
    let root = gambit.root();
    let info = get_global_info(root);
    let start = JoinedNode {
        node: root,
        info: &info,
        cum_payoff: 0.0,
    };
    let game = Game::from_root(start)
        .expect("couldn't extract a compact game representation due to problems with the structure : https://github.com/erikbrinkman/cfr#game-error");
    Ok((game, info.sum))
}

/// Reads a gambit game definition from `reader` and converts it with `from_str`
///
/// Returns the game together with the constant that was subtracted from the terminal payoffs.
///
/// # Panics
///
/// Panics if `reader` fails or doesn't yield valid UTF-8, if the contents can't be parsed as a
/// gambit game, or for any of the reasons that `from_str` panics.
pub fn from_reader(reader: &mut impl Read) -> (Game<String, String>, f64) {
    let mut buff = String::new();
    // reading is the one step here that depends on the outside world, so say what failed
    // instead of panicking with a bare `unwrap` message
    reader
        .read_to_string(&mut buff)
        .expect("couldn't read gambit game definition as utf-8 text");
    from_str(&buff).expect(
        "couldn't parse gambit game definition : https://github.com/erikbrinkman/cfr#gambit-error",
    )
}

#[cfg(test)]
mod tests {
    /// A truncated file is reported as a parse failure by `from_reader`
    #[test]
    #[should_panic(expected = "couldn't parse gambit game definition")]
    fn test_parse() {
        super::from_reader(&mut "EFG 2 F".as_bytes());
    }

    /// Payoffs too large for a double are rejected
    #[test]
    #[should_panic(expected = "received non-finite payoffs in gambit format")]
    fn test_finite_payoffs() {
        // a one followed by a thousand zeros is far beyond `f64::MAX` (roughly 1.8e308);
        // generating it keeps the magnitude readable instead of hand-typing the digits
        let huge = format!("1{}", "0".repeat(1000));
        let efg = format!(r#"EFG 2 R "" {{ "" "" }} t "" 1 {{ {} 0 }}"#, huge);
        super::from_str(&efg).unwrap();
    }

    /// Games with more than two players are rejected
    #[test]
    #[should_panic(expected = "game file has 3 players")]
    fn test_three_players() {
        super::from_str(r#"EFG 2 R "" { "" "" "" } t "" 1 { 0 0 0 }"#).unwrap();
    }

    /// Terminal payoffs that sum to different values are rejected
    #[test]
    #[should_panic(expected = "gambit file wasn't constant sum")]
    fn test_constant_sum_error() {
        super::from_str(
            r#"EFG 2 R "" { "" "" } c "" 1 { "" 1/2 "" 1/2 } 0 t "" 1 { 0 0 } t "" 2 { 1 1 }"#,
        )
        .unwrap();
    }

    /// An unnamed infoset whose number matches the name of another infoset is rejected
    #[test]
    #[should_panic(
        expected = "some infosets had no names, but the string of their number was used as the name of another infoset"
    )]
    fn test_infoset_names() {
        super::from_str(
            r#"EFG 2 R "" { "" "" } p "" 1 1 "2" { "" "" } 0 t "" 1 { 0 0 } p "" 1 2 { "" } 0 t "" 1 { 0 0 }"#,
        )
        .unwrap();
    }

    /// A tree that can't be turned into a compact game is rejected
    #[test]
    #[should_panic(
        expected = "couldn't extract a compact game representation due to problems with the structure"
    )]
    fn test_repr() {
        super::from_str(
            r#"EFG 2 R "" { "" "" } p "" 1 1 { "" "" } 0 t "" 1 { 0 0 } p "" 1 1 { "" "" } 0 t "" 1 { 0 0 } t "" 1 { 0 0 }"#,
        )
        .unwrap();
    }

    /// The returned constant is the average of the two players' payoffs
    #[test]
    fn test_constant_sum() {
        let (_, sum) = super::from_str(r#"EFG 2 R "" { "" "" } t "" 2 { 1 1 }"#).unwrap();
        assert_eq!(sum, 1.0);
    }
}