// multi_skill/systems/true_skill/nodes.rs
use super::normal::{Gaussian, ONE, ZERO};
use std::cell::RefCell;
use std::rc::{Rc, Weak};
4
5pub type Message = Gaussian;
6
/// A node that can refresh the messages it writes onto its edges.
pub trait TreeNode {
    /// Recomputes this node's outgoing messages from the messages currently
    /// stored on its edges.
    fn infer(&mut self);
}
10
11pub trait ValueNode: TreeNode {
12 fn add_edge(&mut self) -> Weak<RefCell<(Message, Message)>>;
13}
14
15pub trait FuncNode: TreeNode {
16 fn new(neighbours: &mut [&mut dyn ValueNode]) -> Self;
17}
18
19#[derive(Clone)]
20pub struct ProdNode {
21 edges: Vec<Rc<RefCell<(Message, Message)>>>,
22}
23
24#[derive(Clone)]
25pub struct LeqNode {
26 eps: f64,
27 edge: Rc<RefCell<(Message, Message)>>,
28}
29
30#[derive(Clone)]
31pub struct GreaterNode {
32 eps: f64,
33 edge: Rc<RefCell<(Message, Message)>>,
34}
35
36#[derive(Clone)]
37pub struct SumNode {
38 out_edge: Weak<RefCell<(Message, Message)>>,
39 sum_edges: Vec<Weak<RefCell<(Message, Message)>>>,
40}
41
42impl TreeNode for ProdNode {
43 fn infer(&mut self) {
44 fn get_prefix_prods(from: &[Rc<RefCell<(Message, Message)>>]) -> Vec<Message> {
45 let mut prefix_prods = Vec::with_capacity(from.len() + 1);
46 prefix_prods.push(ONE);
47
48 for val in from {
49 let (ref val, _) = *val.borrow();
50 prefix_prods.push(prefix_prods.last().unwrap() * val);
51 }
52
53 prefix_prods
54 }
55
56 let prefix_prods = get_prefix_prods(self.edges.as_slice());
57
58 self.edges.reverse();
59 let mut suffix_prods = get_prefix_prods(self.edges.as_slice());
60 self.edges.reverse();
61 suffix_prods.reverse();
62 let suffix_prods = suffix_prods;
63
64 for i in 0..self.edges.len() {
65 RefCell::borrow_mut(&self.edges[i]).1 = &prefix_prods[i] * &suffix_prods[i + 1];
66 }
67 }
68}
69
70impl ValueNode for ProdNode {
71 fn add_edge(&mut self) -> Weak<RefCell<(Message, Message)>> {
72 self.edges.push(Rc::new(RefCell::new((ONE, ZERO))));
73 Rc::downgrade(&self.edges.last().unwrap())
74 }
75}
76
77impl ProdNode {
78 pub fn get_edges_mut(&mut self) -> &mut Vec<Rc<RefCell<(Message, Message)>>> {
79 &mut self.edges
80 }
81
82 pub fn get_edges(&self) -> &Vec<Rc<RefCell<(Message, Message)>>> {
83 &self.edges
84 }
85
86 pub fn new() -> Self {
87 ProdNode { edges: Vec::new() }
88 }
89}
90
91impl TreeNode for LeqNode {
92 fn infer(&mut self) {
93 let ans;
94 {
95 ans = RefCell::borrow(&self.edge).0.leq_eps(self.eps);
96 }
97 RefCell::borrow_mut(&self.edge).1 = ans;
98 }
99}
100
101impl ValueNode for LeqNode {
102 fn add_edge(&mut self) -> Weak<RefCell<(Message, Message)>> {
103 Rc::downgrade(&self.edge)
104 }
105}
106
107impl LeqNode {
108 pub fn new(eps: f64) -> LeqNode {
109 LeqNode {
110 eps,
111 edge: Rc::new(RefCell::new((ZERO, ZERO))),
112 }
113 }
114}
115
116impl TreeNode for GreaterNode {
117 fn infer(&mut self) {
118 let ans;
119 {
120 ans = RefCell::borrow(&self.edge).0.greater_eps(self.eps);
121 }
122 RefCell::borrow_mut(&self.edge).1 = ans;
123 }
124}
125
126impl ValueNode for GreaterNode {
127 fn add_edge(&mut self) -> Weak<RefCell<(Message, Message)>> {
128 Rc::downgrade(&self.edge)
129 }
130}
131
132impl GreaterNode {
133 pub fn new(eps: f64) -> GreaterNode {
134 GreaterNode {
135 eps,
136 edge: Rc::new(RefCell::new((ZERO, ZERO))),
137 }
138 }
139}
140
141impl FuncNode for SumNode {
142 fn new(neighbours: &mut [&mut dyn ValueNode]) -> Self {
143 assert!(!neighbours.is_empty());
144
145 let sum_edges: Vec<_> = neighbours
146 .iter_mut()
147 .skip(1)
148 .map(|nb| nb.add_edge())
149 .collect();
150
151 SumNode {
152 out_edge: neighbours[0].add_edge(),
153 sum_edges,
154 }
155 }
156}
157
158impl TreeNode for SumNode {
159 fn infer(&mut self) {
160 fn get_prefix_sums(from: &[Weak<RefCell<(Message, Message)>>]) -> Vec<Message> {
161 let mut prefix_sums = Vec::with_capacity(from.len() + 1);
162 prefix_sums.push(ZERO);
163
164 for val in from {
165 let val = val.upgrade().unwrap();
166 let (_, ref val) = *val.borrow();
167 prefix_sums.push(prefix_sums.last().unwrap() + val);
168 }
169
170 prefix_sums
171 }
172
173 let prefix_sums = get_prefix_sums(self.sum_edges.as_slice());
174 self.sum_edges.reverse();
175 let mut suffix_sums = get_prefix_sums(self.sum_edges.as_slice());
176 self.sum_edges.reverse();
177 suffix_sums.reverse();
178 let suffix_sums = suffix_sums;
179
180 RefCell::borrow_mut(&self.out_edge.upgrade().unwrap()).0 =
181 prefix_sums.last().unwrap().clone();
182
183 for i in 0..self.sum_edges.len() {
184 RefCell::borrow_mut(&self.sum_edges[i].upgrade().unwrap()).0 =
185 &RefCell::borrow(&self.out_edge.upgrade().unwrap()).1
186 - &prefix_sums[i]
187 - &suffix_sums[i + 1];
188 }
189 }
190}