// skillratings 0.29.0
//
// Calculate a player's skill rating using algorithms like Elo, Glicko-2, TrueSkill and many more.
// Documentation
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;

use super::gaussian::Gaussian;

#[derive(Clone, Debug, Default, PartialEq)]
pub struct Variable {
    pub gaussian: Gaussian,
    messages: HashMap<usize, Gaussian>,
}

impl Variable {
    pub fn new() -> Self {
        Self {
            gaussian: Gaussian::default(),
            messages: HashMap::new(),
        }
    }

    fn set(&mut self, val: Gaussian) -> f64 {
        let delta = self.delta(val);

        self.gaussian.pi = val.pi;
        self.gaussian.tau = val.tau;

        delta
    }

    fn update_message(&mut self, factor_id: usize, message: Gaussian) -> f64 {
        let old_message = self.messages[&factor_id];
        let v = self.messages.entry(factor_id).or_default();
        *v = message;

        self.set(self.gaussian / old_message * message)
    }

    fn update_value(&mut self, factor_id: usize, val: Gaussian) -> f64 {
        let old_message = self.messages[&factor_id];
        let v = self.messages.entry(factor_id).or_default();
        *v = val * old_message / self.gaussian;

        self.set(val)
    }

    fn delta(&self, other: Gaussian) -> f64 {
        let pi_delta = (self.gaussian.pi - other.pi).abs();
        if pi_delta.is_infinite() {
            return 0.0;
        }

        (self.gaussian.tau - other.tau).abs().max(pi_delta.sqrt())
    }
}

/// Factor that pushes a fixed prior onto a single variable, widening the prior's
/// standard deviation by `dynamic` (combined via `hypot`) before sending it.
pub struct PriorFactor {
    id: usize,
    pub variable: Rc<RefCell<Variable>>,
    val: Gaussian,
    dynamic: f64,
}

impl PriorFactor {
    /// Builds the factor and creates a default message slot for `id` on `variable`.
    pub fn new(id: usize, variable: Rc<RefCell<Variable>>, val: Gaussian, dynamic: f64) -> Self {
        let factor = Self {
            id,
            variable,
            val,
            dynamic,
        };
        factor
            .variable
            .borrow_mut()
            .messages
            .entry(factor.id)
            .or_default();
        factor
    }

    /// Sends `val`, with its sigma replaced by `hypot(val.sigma(), dynamic)`, to the
    /// variable through `update_value` and returns the delta the variable reports.
    pub fn down(&self) -> f64 {
        let widened_sigma = self.val.sigma().hypot(self.dynamic);
        let prior = Gaussian::with_mu_sigma(self.val.mu(), widened_sigma);

        self.variable.borrow_mut().update_value(self.id, prior)
    }
}

/// Factor linking a `mean` variable and a `value` variable through a fixed `variance`.
pub struct LikelihoodFactor {
    id: usize,
    mean: Rc<RefCell<Variable>>,
    value: Rc<RefCell<Variable>>,
    variance: f64,
}

impl LikelihoodFactor {
    /// Builds the factor and creates default message slots for `id` on both variables
    /// (`mean` first, then `value`).
    pub fn new(
        id: usize,
        mean: Rc<RefCell<Variable>>,
        value: Rc<RefCell<Variable>>,
        variance: f64,
    ) -> Self {
        for var in [&mean, &value] {
            var.borrow_mut().messages.entry(id).or_default();
        }

        Self {
            id,
            mean,
            value,
            variance,
        }
    }

    /// Relays information from `mean` to `value`; returns the delta reported by `value`.
    pub fn down(&self) -> f64 {
        self.relay(&self.mean, &self.value)
    }

    /// Relays information from `value` to `mean`; returns the delta reported by `mean`.
    pub fn up(&self) -> f64 {
        self.relay(&self.value, &self.mean)
    }

    /// Divides this factor's stored message out of `from`, scales the quotient's `pi`
    /// and `tau` by [`calc_a`](Self::calc_a), and hands the result to
    /// `to.update_message`.
    fn relay(&self, from: &Rc<RefCell<Variable>>, to: &Rc<RefCell<Variable>>) -> f64 {
        // Scope the shared borrow of `from` so it is released before `to` is borrowed
        // mutably.
        let quotient = {
            let source = from.borrow();
            source.gaussian / source.messages[&self.id]
        };
        let scale = self.calc_a(quotient);
        let outgoing = Gaussian::with_pi_tau(scale * quotient.pi, scale * quotient.tau);

        to.borrow_mut().update_message(self.id, outgoing)
    }

    /// Returns `1 / (variance * gaussian.pi + 1)`, evaluated with a fused multiply-add.
    fn calc_a(&self, gaussian: Gaussian) -> f64 {
        let denominator = self.variance.mul_add(gaussian.pi, 1.0);
        denominator.recip()
    }
}

/// Factor relating a `sum` variable to the `terms` variables weighted by `coeffs`
/// (one coefficient per term).
pub struct SumFactor {
    id: usize,
    sum: Rc<RefCell<Variable>>,
    terms: Vec<Rc<RefCell<Variable>>>,
    coeffs: Vec<f64>,
}

impl SumFactor {
    /// Builds the factor and creates default message slots for `id` on `sum` and then
    /// on every term.
    pub fn new(
        id: usize,
        sum: Rc<RefCell<Variable>>,
        terms: Vec<Rc<RefCell<Variable>>>,
        coeffs: Vec<f64>,
    ) -> Self {
        std::iter::once(&sum).chain(terms.iter()).for_each(|var| {
            var.borrow_mut().messages.entry(id).or_default();
        });

        Self {
            id,
            sum,
            terms,
            coeffs,
        }
    }

    /// Sends an updated message to `sum`, combining all terms with `coeffs`.
    ///
    /// Returns the delta reported by `sum`.
    pub fn down(&self) -> f64 {
        let incoming = self.messages_of(&self.terms);
        self.update(&self.sum, &self.terms, &incoming, &self.coeffs)
    }

    /// Sends an updated message to `terms[index]`, using `sum` in that term's place.
    ///
    /// The re-solved coefficients are `1 / coeffs[index]` for the pivot position and
    /// `-coeffs[i] / coeffs[index]` elsewhere; if `coeffs[index]` is zero, every
    /// coefficient is `0.0`. Returns the delta reported by `terms[index]`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of bounds for `coeffs` or `terms`.
    pub fn up(&self, index: usize) -> f64 {
        let pivot = self.coeffs[index];
        let solved: Vec<f64> = self
            .coeffs
            .iter()
            .enumerate()
            .map(|(i, &c)| match (pivot == 0.0, i == index) {
                (true, _) => 0.0,
                (false, true) => pivot.recip(),
                (false, false) => -c / pivot,
            })
            .collect();

        let mut inputs = self.terms.clone();
        inputs[index] = Rc::clone(&self.sum);
        let incoming = self.messages_of(&inputs);

        self.update(&self.terms[index], &inputs, &incoming, &solved)
    }

    /// Number of terms attached to this factor.
    #[inline]
    pub const fn terms_len(&self) -> usize {
        self.terms.len()
    }

    /// Collects the message stored under this factor's id on each of `vars`, in order.
    fn messages_of(&self, vars: &[Rc<RefCell<Variable>>]) -> Vec<Gaussian> {
        vars.iter()
            .map(|var| var.borrow().messages[&self.id])
            .collect()
    }

    /// Divides each entry of `msgs` out of the matching entry of `vals`, combines the
    /// quotients with `coeffs`, and sends the result to `var` via `update_message`.
    ///
    /// `mu` accumulates `coeff * quotient.mu()`. `pi_inv` accumulates
    /// `coeff² / quotient.pi` and saturates to infinity once any quotient has
    /// `pi == 0.0`. The outgoing message is `pi = 1 / pi_inv`, `tau = pi * mu`. Only
    /// as many entries as the shortest of `vals`, `msgs` and `coeffs` are used.
    fn update(
        &self,
        var: &Rc<RefCell<Variable>>,
        vals: &[Rc<RefCell<Variable>>],
        msgs: &[Gaussian],
        coeffs: &[f64],
    ) -> f64 {
        let mut mu = 0.0;
        let mut pi_inv = 0.0_f64;

        for (val, (msg, &coeff)) in vals.iter().zip(msgs.iter().zip(coeffs)) {
            let quotient = val.borrow().gaussian / *msg;
            mu += coeff * quotient.mu();

            if !pi_inv.is_infinite() {
                pi_inv = if quotient.pi == 0.0 {
                    f64::INFINITY
                } else {
                    pi_inv + coeff.powi(2) / quotient.pi
                };
            }
        }

        let pi = pi_inv.recip();
        var.borrow_mut()
            .update_message(self.id, Gaussian::with_pi_tau(pi, pi * mu))
    }
}

/// Factor that rewrites a variable's value using caller-supplied `v_func` / `w_func`
/// correction functions and a `draw_margin`.
pub struct TruncateFactor {
    id: usize,
    variable: Rc<RefCell<Variable>>,
    v_func: Box<dyn Fn(f64, f64, f64) -> f64>,
    w_func: Box<dyn Fn(f64, f64, f64) -> f64>,
    draw_margin: f64,
}

impl TruncateFactor {
    /// Builds the factor and creates a default message slot for `id` on `variable`.
    pub fn new(
        id: usize,
        variable: Rc<RefCell<Variable>>,
        v_func: Box<dyn Fn(f64, f64, f64) -> f64>,
        w_func: Box<dyn Fn(f64, f64, f64) -> f64>,
        draw_margin: f64,
    ) -> Self {
        let factor = Self {
            id,
            variable,
            v_func,
            w_func,
            draw_margin,
        };
        factor
            .variable
            .borrow_mut()
            .messages
            .entry(factor.id)
            .or_default();
        factor
    }

    /// Divides this factor's message out of the variable. It then evaluates
    /// `v = v_func(tau, draw_margin * pi, sqrt(pi))` and, with the same arguments,
    /// `w = w_func(...)`, where `pi` and `tau` come from the quotient.
    ///
    /// The variable is set (via `update_value`) to `pi / (1 - w)` and
    /// `(tau + sqrt(pi) * v) / (1 - w)`. Returns the delta the variable reports.
    pub fn up(&self) -> f64 {
        let quotient = {
            let var = self.variable.borrow();
            var.gaussian / var.messages[&self.id]
        };

        let sqrt_pi = quotient.pi.sqrt();
        let margin_arg = self.draw_margin * quotient.pi;
        // Tuple fields are evaluated left to right, so `v_func` still runs before `w_func`.
        let (v, w) = (
            (self.v_func)(quotient.tau, margin_arg, sqrt_pi),
            (self.w_func)(quotient.tau, margin_arg, sqrt_pi),
        );
        let scale = 1.0 - w;

        let corrected = Gaussian::with_pi_tau(
            quotient.pi / scale,
            sqrt_pi.mul_add(v, quotient.tau) / scale,
        );
        self.variable.borrow_mut().update_value(self.id, corrected)
    }
}

// Unit tests for `Variable` and `SumFactor` in this module.
#[cfg(test)]
mod tests {
    use super::*;

    /// `delta` short-circuits to 0.0 when the difference in `pi` is infinite.
    #[test]
    fn test_delta_inf() {
        let mut v1 = Variable::new();

        // Give v1 an infinite `pi`, so comparing against a finite `pi` below yields an
        // infinite pi difference.
        v1.set(Gaussian::with_pi_tau(f64::INFINITY, 1.0));

        assert!(v1.delta(Gaussian::with_pi_tau(0.0, 0.0)) < f64::EPSILON);
    }

    /// `SumFactor::up` with a zero coefficient (every re-solved coefficient becomes 0.0)
    /// completes without panicking and leaves the factor's id and coefficients untouched.
    #[test]
    fn test_sum_factor() {
        let mut v1 = Variable::new();
        let mut v2 = Variable::new();

        v1.set(Gaussian::with_pi_tau(f64::INFINITY, 1.0));
        v2.set(Gaussian::with_pi_tau(0.0, 1.0));

        // The factor receives clones; message slots for id 0 are created on the clones only.
        let sm1 = SumFactor::new(
            0,
            Rc::new(RefCell::new(v1.clone())),
            vec![Rc::new(RefCell::new(v2.clone()))],
            vec![0.0],
        );

        sm1.up(0);

        assert_eq!(sm1.id, 0);
        assert_eq!(sm1.coeffs, vec![0.0]);
    }

    /// `update` panics when the target variable has no message slot for the factor's id.
    #[test]
    #[should_panic(expected = "no entry found for key")]
    fn test_no_update() {
        let mut v1 = Variable::new();
        let mut v2 = Variable::new();

        v1.set(Gaussian::with_pi_tau(f64::INFINITY, 1.0));
        v2.set(Gaussian::with_pi_tau(0.0, 1.0));

        let sm1 = SumFactor::new(
            0,
            Rc::new(RefCell::new(v1.clone())),
            vec![Rc::new(RefCell::new(v2.clone()))],
            vec![0.0],
        );

        // Same setup as `test_sum_factor`; this part must not panic.
        sm1.up(0);

        assert_eq!(sm1.id, 0);
        assert_eq!(sm1.coeffs, vec![0.0]);

        // `v1` itself was never registered with the factor (only its clone was), so the
        // target has no message for id 0. `Variable::update_message` then indexes a
        // missing key and panics. The extra coefficients are ignored because `update`
        // zips `vals`, `msgs` and `coeffs`.
        sm1.update(
            &Rc::new(RefCell::new(v1)),
            &[Rc::new(RefCell::new(v2))],
            &[Gaussian::with_pi_tau(0.0, 1.0)],
            &[0.0, 1.0, 2.0, 3.0],
        );
    }
}