cfr 0.1.0

Counterfactual regret minimization solver for two-player zero-sum incomplete-information games
Documentation
use rand::thread_rng;
use rand_distr::{Distribution, WeightedAliasIndex};
use std::cell::RefCell;

#[derive(Debug)]
pub struct SampledChance {
    index: WeightedAliasIndex<f64>,
    cached: RefCell<usize>,
}

impl SampledChance {
    pub fn new(probs: &[f64]) -> Self {
        SampledChance {
            index: WeightedAliasIndex::new(probs.to_vec()).unwrap(),
            cached: RefCell::new(0),
        }
    }

    pub fn sample(&self) -> usize {
        let mut cached = self.cached.borrow_mut();
        if *cached == 0 {
            let res = self.index.sample(&mut thread_rng());
            *cached = res + 1;
            res
        } else {
            *cached - 1
        }
    }

    pub fn reset(&mut self) {
        *self.cached.get_mut() = 0;
    }
}

#[derive(Debug)]
pub struct RegretInfoset<const AVG: bool> {
    pub cum_regret: Box<[f64]>,
    pub cum_strat: Box<[f64]>,
    pub strat: Box<[f64]>,
}

impl<const AVG: bool> RegretInfoset<AVG> {
    pub fn new(num_actions: usize) -> RegretInfoset<AVG> {
        RegretInfoset {
            cum_regret: vec![0.0; num_actions].into_boxed_slice(),
            cum_strat: vec![0.0; num_actions].into_boxed_slice(),
            strat: vec![1.0 / num_actions as f64; num_actions].into_boxed_slice(),
        }
    }

    pub fn regret_match(&mut self) {
        let norm: f64 = self.cum_regret.iter().filter(|v| v > &&0.0).sum();
        if norm > 0.0 {
            for (reg, val) in self.cum_regret.iter().zip(self.strat.iter_mut()) {
                *val = if reg > &0.0 { reg / norm } else { 0.0 }
            }
        } else if AVG {
            self.strat.fill(1.0 / self.strat.len() as f64);
        } else {
            let (ind, _) = self
                .cum_regret
                .iter()
                .enumerate()
                .max_by(|(_, l), (_, r)| l.partial_cmp(r).unwrap())
                .unwrap();
            self.strat.fill(0.0);
            self.strat[ind] = 1.0;
        }
    }

    pub fn cum_regret(&self) -> f64 {
        f64::max(
            self.cum_regret
                .iter()
                .copied()
                .reduce(f64::max)
                .unwrap_or(0.0),
            0.0,
        )
    }

    pub fn into_avg_strat(mut self) -> Box<[f64]> {
        let norm: f64 = self.cum_strat.iter().sum();
        if norm == 0.0 {
            self.cum_strat.fill(1.0 / self.cum_strat.len() as f64);
        } else {
            for prob in self.cum_strat.iter_mut() {
                *prob /= norm;
            }
        }
        self.cum_strat
    }
}