// multi_skill/systems/true_skill/mod.rs

1mod nodes;
2mod normal;
3
4use super::util::{Player, Rating, RatingSystem};
5
6use nodes::{FuncNode, GreaterNode, LeqNode, ProdNode, SumNode, TreeNode, ValueNode};
7use normal::Gaussian;
8
9use std::cell::RefCell;
10use std::rc::Rc;
11
12type TSMessage = nodes::Message;
13type TSPlayer<'a> = (&'a mut Player, Gaussian);
14type TSTeam<'a> = Vec<TSPlayer<'a>>;
15type TSContestPlace<'a> = Vec<TSTeam<'a>>;
16type TSContest<'a> = Vec<TSContestPlace<'a>>;
17
/// TrueSkill-based rating system (`TrueSkillSPb`).
///
/// Each round builds a graph of message-passing nodes over the standings and
/// iterates it until the team-level messages change by less than `convergence_eps`.
#[derive(Debug)]
pub struct TrueSkillSPb {
    /// Tie margin. Used as the `LeqNode` threshold for teams sharing a place and,
    /// doubled, as the `GreaterNode` margin between consecutive places.
    pub eps: f64,
    /// Performance standard deviation; each round it is divided by
    /// `sqrt(contest_weight)`.
    pub beta: f64,
    /// Threshold on the largest message change that ends the inference loop.
    pub convergence_eps: f64,
    /// Noise passed to `with_noise` on every player's posterior before a round;
    /// defines how much sigma grows per round.
    pub sig_drift: f64,
}
30
31impl Default for TrueSkillSPb {
32    fn default() -> Self {
33        Self {
34            eps: 1.,
35            beta: 175.,
36            convergence_eps: 1e-4,
37            sig_drift: 35.,
38        }
39    }
40}
41
/// Builds a `Vec<Vec<K>>` with the same two-level shape as `places`, every slot
/// holding a clone of `default` (e.g. one fresh node per team in every place).
fn gen_team_message<T, K: Clone>(places: &[Vec<T>], default: &K) -> Vec<Vec<K>> {
    let mut shaped = Vec::with_capacity(places.len());
    for place in places {
        shaped.push(vec![default.clone(); place.len()]);
    }
    shaped
}
48
/// Builds a `Vec<Vec<Vec<K>>>` with the same three-level shape as `places`, every
/// slot holding a clone of `default` (e.g. one fresh node per player).
fn gen_player_message<T, K: Clone>(places: &[Vec<Vec<T>>], default: &K) -> Vec<Vec<Vec<K>>> {
    let mut shaped = Vec::with_capacity(places.len());
    for place in places {
        let mut teams = Vec::with_capacity(place.len());
        for team in place {
            teams.push(vec![default.clone(); team.len()]);
        }
        shaped.push(teams);
    }
    shaped
}
60
61fn infer1(who: &mut Vec<impl TreeNode>) {
62    for item in who {
63        item.infer();
64    }
65}
66
67fn infer2(who: &mut Vec<Vec<impl TreeNode>>) {
68    for item in who {
69        infer1(item);
70    }
71}
72
73fn infer3(who: &mut Vec<Vec<Vec<impl TreeNode>>>) {
74    for item in who {
75        infer2(item);
76    }
77}
78
79fn infer_ld(ld: &mut Vec<impl TreeNode>, l: &mut Vec<impl TreeNode>) {
80    for i in 0..ld.len() {
81        l[i].infer();
82        ld[i].infer();
83    }
84    l.last_mut().unwrap().infer();
85    for i in (0..ld.len()).rev() {
86        ld[i].infer();
87        l[i].infer();
88    }
89}
90
91fn check_convergence(
92    a: &[Rc<RefCell<(TSMessage, TSMessage)>>],
93    b: &[(TSMessage, TSMessage)],
94) -> f64 {
95    if a.len() != b.len() {
96        return std::f64::INFINITY;
97    }
98
99    a.iter()
100        .map(|ai| ai.borrow())
101        .zip(b.iter())
102        .flat_map(|(ai, bi)| {
103            vec![
104                ai.0.mu - bi.0.mu,
105                ai.0.sigma - bi.0.sigma,
106                ai.1.mu - bi.1.mu,
107                ai.1.sigma - bi.1.sigma,
108            ]
109        })
110        .map(f64::abs)
111        .max_by(|x, y| x.partial_cmp(y).expect("Difference became NaN"))
112        .unwrap_or(0.)
113}
114
impl TrueSkillSPb {
    /// Runs iterative message passing for one contest and writes the results back.
    ///
    /// `contest` is indexed as `contest[place][team][player]`; each entry pairs a
    /// player with the Gaussian fed into inference. On return that Gaussian is
    /// overwritten with the value read back from the player's seed edge, and
    /// `Player::update_rating` is called with it. An empty contest is a no-op.
    ///
    /// `contest_weight` scales the performance noise: `sig_perf = beta / sqrt(contest_weight)`.
    ///
    /// NOTE(review): the convergence loop has no iteration cap; it stops only when
    /// `check_convergence(..) >= self.convergence_eps` becomes false.
    fn inference(&self, contest_weight: f64, contest: &mut TSContest) {
        if contest.is_empty() {
            return;
        }

        // could be optimized, written that way for simplicity
        // TODO: invent better variable names
        //
        // Node legend (indices are [place][team][player] as in `contest`):
        //   s    - one ProdNode per player; the edge returned by `add_edge` has its `.0`
        //          message seeded with the player's input Gaussian and is read back at the end.
        //   perf - one ProdNode per player; the `.1` message of its last edge (set right
        //          after the `sp` SumNode is built) is fixed to N(0, sig_perf).
        //   p    - one ProdNode per player, joined with `s` and `perf` by a SumNode in `sp`.
        //   t    - one ProdNode per team, joined with that team's `p` nodes by a SumNode in `pt`.
        //   u    - one LeqNode(eps) per team, joined with `l[i]` and `t[i][j]` by a SumNode in `tul`.
        //   l    - one ProdNode per place.
        //   d    - one GreaterNode(2 * eps) per pair of adjacent places, joined with those two
        //          `l` nodes by a SumNode in `ld`.
        //   players  - (place, team, player, seed edge) for reading results back.
        //   conv     - one tracked edge per team, compared against `old_conv` each round.
        //   old_conv - snapshot of `conv` from the previous round; starts empty, so the
        //              first `check_convergence` call returns infinity.
        let sig_perf = self.beta / contest_weight.sqrt();
        let mut s = gen_player_message(contest, &ProdNode::new());
        let mut perf = gen_player_message(contest, &ProdNode::new());
        let mut p = gen_player_message(contest, &ProdNode::new());
        let mut t = gen_team_message(contest, &ProdNode::new());
        let mut u = gen_team_message(contest, &LeqNode::new(self.eps));
        let mut l = vec![ProdNode::new(); contest.len()];
        // contest is non-empty here, so `contest.len() - 1` cannot underflow.
        let mut d = vec![GreaterNode::new(2. * self.eps); contest.len() - 1];
        let mut sp = vec![];
        let mut pt = vec![];
        let mut tul = vec![];
        let mut ld = vec![];
        let mut players = vec![];
        let mut conv = vec![];
        let mut old_conv = vec![];

        // Build the graph.
        for i in 0..contest.len() {
            for j in 0..contest[i].len() {
                for k in 0..contest[i][j].len() {
                    // `new_edge` is upgraded before use, so it is presumably a weak handle
                    // whose strong owner is the `s` node.
                    let new_edge = s[i][j][k].add_edge();

                    new_edge.upgrade().unwrap().borrow_mut().0 = contest[i][j][k].1.clone();

                    sp.push(SumNode::new(&mut [
                        &mut p[i][j][k],
                        &mut s[i][j][k],
                        &mut perf[i][j][k],
                    ]));
                    // Pin the `.1` message on perf's newest edge to N(0, sig_perf)
                    // (presumably the per-player performance noise).
                    RefCell::borrow_mut(perf[i][j][k].get_edges_mut().last_mut().unwrap()).1 =
                        Gaussian {
                            mu: 0.,
                            sigma: sig_perf,
                        };

                    players.push((i, j, k, new_edge));
                }

                // Join the team node with all of its players' `p` nodes.
                let mut tt: Vec<&mut dyn ValueNode> = vec![&mut t[i][j]];
                tt.extend(p[i][j].iter_mut().map(|pp| pp as &mut dyn ValueNode));

                pt.push(SumNode::new(&mut tt));
                tul.push(SumNode::new(&mut [&mut l[i], &mut t[i][j], &mut u[i][j]]));
                // Track the team node's most recently added edge (presumably the one
                // created by the `tul` SumNode just above) for the convergence test.
                conv.push(t[i][j].get_edges().last().unwrap().clone());
            }

            if i != 0 {
                // The slice pattern borrows the two adjacent place nodes mutably at once.
                // The `_` arm is unreachable: `i - 1..=i` always spans two elements.
                match &mut l[i - 1..=i] {
                    [a, b] => {
                        ld.push(SumNode::new(&mut [a, b, &mut d[i - 1]]));
                    }
                    _ => panic!("Must have 0 < i < l.len()"),
                };
            }
        }

        // Initial propagation, in construction order from the `s` nodes toward the
        // team-level `t` / `u` nodes.
        infer3(&mut s);
        infer1(&mut sp);
        infer3(&mut p);
        infer1(&mut pt);
        infer2(&mut t);
        infer1(&mut tul);
        infer2(&mut u);
        infer1(&mut tul);

        //let mut rounds = 0;

        // Iterate the place-level chain (l / ld / d) and the team-level u / tul nodes
        // until no tracked team message moves by `convergence_eps` or more.
        while check_convergence(&conv, &old_conv) >= self.convergence_eps {
            old_conv.clear();
            for item in &conv {
                old_conv.push(RefCell::borrow(item).clone());
            }
            //rounds += 1;

            infer_ld(&mut ld, &mut l);
            infer1(&mut d);
            infer_ld(&mut ld, &mut l);
            infer1(&mut tul);
            infer2(&mut u);
            infer1(&mut tul);
        }

        //eprintln!("Rounds until convergence: {}", rounds);

        // Propagate back in reverse construction order, ending at the `s` nodes.
        infer2(&mut t);
        infer1(&mut pt);
        infer3(&mut p);
        infer1(&mut sp);
        infer3(&mut s);

        // Read each player's seed edge back, overwrite the contest entry with the
        // product of its two messages, and push the result into the player's rating.
        // NOTE(review): the meaning of the `0.` second argument is defined by
        // `Player::update_rating`, which is not visible here.
        for (i, j, k, mess) in players {
            let val = mess.upgrade().unwrap();
            let (prior, performance) = &*val.borrow();
            let (player, gaussian) = &mut contest[i][j][k];

            *gaussian = prior * performance;
            player.update_rating(
                Rating {
                    mu: gaussian.mu,
                    sig: gaussian.sigma,
                },
                0.,
            );
        }
    }
}
228
229impl RatingSystem for TrueSkillSPb {
230    fn round_update(&self, contest_weight: f64, standings: Vec<(&mut Player, usize, usize)>) {
231        let mut contest = TSContest::new();
232
233        for i in 1..standings.len() {
234            assert!(standings[i - 1].1 <= standings[i].1);
235        }
236
237        let mut prev = usize::MAX;
238        for (user, lo, _hi) in standings {
239            if lo != prev {
240                contest.push(vec![]);
241            }
242            let noised = user.approx_posterior.with_noise(self.sig_drift);
243            let gaussian = Gaussian {
244                mu: noised.mu,
245                sigma: noised.sig,
246            };
247            contest.last_mut().unwrap().push(vec![(user, gaussian)]);
248            prev = lo;
249        }
250
251        // do inference
252        self.inference(contest_weight, &mut contest);
253    }
254}