Skip to main content

multi_skill/systems/true_skill/
nodes.rs

1use super::normal::{Gaussian, ONE, ZERO};
2use std::cell::RefCell;
3use std::rc::{Rc, Weak};
4
5pub type Message = Gaussian;
6
/// Behaviour shared by every node of the graph: the ability to run one
/// inference step.
pub trait TreeNode {
    /// Runs this node's inference step, updating the messages the node is
    /// responsible for.
    fn infer(&mut self);
}
10
11pub trait ValueNode: TreeNode {
12    fn add_edge(&mut self) -> Weak<RefCell<(Message, Message)>>;
13}
14
15pub trait FuncNode: TreeNode {
16    fn new(neighbours: &mut [&mut dyn ValueNode]) -> Self;
17}
18
19#[derive(Clone)]
20pub struct ProdNode {
21    edges: Vec<Rc<RefCell<(Message, Message)>>>,
22}
23
24#[derive(Clone)]
25pub struct LeqNode {
26    eps: f64,
27    edge: Rc<RefCell<(Message, Message)>>,
28}
29
30#[derive(Clone)]
31pub struct GreaterNode {
32    eps: f64,
33    edge: Rc<RefCell<(Message, Message)>>,
34}
35
36#[derive(Clone)]
37pub struct SumNode {
38    out_edge: Weak<RefCell<(Message, Message)>>,
39    sum_edges: Vec<Weak<RefCell<(Message, Message)>>>,
40}
41
42impl TreeNode for ProdNode {
43    fn infer(&mut self) {
44        fn get_prefix_prods(from: &[Rc<RefCell<(Message, Message)>>]) -> Vec<Message> {
45            let mut prefix_prods = Vec::with_capacity(from.len() + 1);
46            prefix_prods.push(ONE);
47
48            for val in from {
49                let (ref val, _) = *val.borrow();
50                prefix_prods.push(prefix_prods.last().unwrap() * val);
51            }
52
53            prefix_prods
54        }
55
56        let prefix_prods = get_prefix_prods(self.edges.as_slice());
57
58        self.edges.reverse();
59        let mut suffix_prods = get_prefix_prods(self.edges.as_slice());
60        self.edges.reverse();
61        suffix_prods.reverse();
62        let suffix_prods = suffix_prods;
63
64        for i in 0..self.edges.len() {
65            RefCell::borrow_mut(&self.edges[i]).1 = &prefix_prods[i] * &suffix_prods[i + 1];
66        }
67    }
68}
69
70impl ValueNode for ProdNode {
71    fn add_edge(&mut self) -> Weak<RefCell<(Message, Message)>> {
72        self.edges.push(Rc::new(RefCell::new((ONE, ZERO))));
73        Rc::downgrade(&self.edges.last().unwrap())
74    }
75}
76
77impl ProdNode {
78    pub fn get_edges_mut(&mut self) -> &mut Vec<Rc<RefCell<(Message, Message)>>> {
79        &mut self.edges
80    }
81
82    pub fn get_edges(&self) -> &Vec<Rc<RefCell<(Message, Message)>>> {
83        &self.edges
84    }
85
86    pub fn new() -> Self {
87        ProdNode { edges: Vec::new() }
88    }
89}
90
91impl TreeNode for LeqNode {
92    fn infer(&mut self) {
93        let ans;
94        {
95            ans = RefCell::borrow(&self.edge).0.leq_eps(self.eps);
96        }
97        RefCell::borrow_mut(&self.edge).1 = ans;
98    }
99}
100
101impl ValueNode for LeqNode {
102    fn add_edge(&mut self) -> Weak<RefCell<(Message, Message)>> {
103        Rc::downgrade(&self.edge)
104    }
105}
106
107impl LeqNode {
108    pub fn new(eps: f64) -> LeqNode {
109        LeqNode {
110            eps,
111            edge: Rc::new(RefCell::new((ZERO, ZERO))),
112        }
113    }
114}
115
116impl TreeNode for GreaterNode {
117    fn infer(&mut self) {
118        let ans;
119        {
120            ans = RefCell::borrow(&self.edge).0.greater_eps(self.eps);
121        }
122        RefCell::borrow_mut(&self.edge).1 = ans;
123    }
124}
125
126impl ValueNode for GreaterNode {
127    fn add_edge(&mut self) -> Weak<RefCell<(Message, Message)>> {
128        Rc::downgrade(&self.edge)
129    }
130}
131
132impl GreaterNode {
133    pub fn new(eps: f64) -> GreaterNode {
134        GreaterNode {
135            eps,
136            edge: Rc::new(RefCell::new((ZERO, ZERO))),
137        }
138    }
139}
140
141impl FuncNode for SumNode {
142    fn new(neighbours: &mut [&mut dyn ValueNode]) -> Self {
143        assert!(!neighbours.is_empty());
144
145        let sum_edges: Vec<_> = neighbours
146            .iter_mut()
147            .skip(1)
148            .map(|nb| nb.add_edge())
149            .collect();
150
151        SumNode {
152            out_edge: neighbours[0].add_edge(),
153            sum_edges,
154        }
155    }
156}
157
158impl TreeNode for SumNode {
159    fn infer(&mut self) {
160        fn get_prefix_sums(from: &[Weak<RefCell<(Message, Message)>>]) -> Vec<Message> {
161            let mut prefix_sums = Vec::with_capacity(from.len() + 1);
162            prefix_sums.push(ZERO);
163
164            for val in from {
165                let val = val.upgrade().unwrap();
166                let (_, ref val) = *val.borrow();
167                prefix_sums.push(prefix_sums.last().unwrap() + val);
168            }
169
170            prefix_sums
171        }
172
173        let prefix_sums = get_prefix_sums(self.sum_edges.as_slice());
174        self.sum_edges.reverse();
175        let mut suffix_sums = get_prefix_sums(self.sum_edges.as_slice());
176        self.sum_edges.reverse();
177        suffix_sums.reverse();
178        let suffix_sums = suffix_sums;
179
180        RefCell::borrow_mut(&self.out_edge.upgrade().unwrap()).0 =
181            prefix_sums.last().unwrap().clone();
182
183        for i in 0..self.sum_edges.len() {
184            RefCell::borrow_mut(&self.sum_edges[i].upgrade().unwrap()).0 =
185                &RefCell::borrow(&self.out_edge.upgrade().unwrap()).1
186                    - &prefix_sums[i]
187                    - &suffix_sums[i + 1];
188        }
189    }
190}