1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
use crate::LogProbVector;
use ndarray::{Array, ArrayD, Axis, Dimension, RemoveAxis};
/// A single variable of the network, together with its message-passing state.
#[derive(Debug)]
struct Node {
/// One `(parent id, message)` entry per parent, in the order the parents were listed when the
/// node was added. The message starts as `LogProbVector::uniform` and is overwritten by `step`
/// with the latest pi message sent by that parent.
parents: Vec<(usize, LogProbVector)>,
/// One `(child id, message)` entry per child, appended when the child is added to the network.
/// The message starts as `LogProbVector::uniform` and is overwritten by `step` with the latest
/// lambda message sent by that child.
children: Vec<(usize, LogProbVector)>,
/// Log-probabilities `p(x | parents)` as processed by `crate::math::normalize_log_probas`.
/// Axis 0 ranges over the values of this node, axis `i + 1` over the values of `parents[i]`.
log_probas: ArrayD<f32>,
/// Index of the observed value of this node, or `None` when the node is not observed.
evidence: Option<usize>,
/// Cached result of `compute_lambda`; `None` when not computed yet or invalidated.
lambda: Option<LogProbVector>,
/// Cached result of `compute_pi`; `None` when not computed yet or invalidated.
pi: Option<LogProbVector>,
}
impl Node {
fn evidence_vec(&self) -> LogProbVector {
if let Some(id) = self.evidence {
LogProbVector::deterministic(self.log_probas.shape()[0], id)
} else {
LogProbVector::uniform(self.log_probas.shape()[0])
}
}
fn compute_lambda(&self) -> LogProbVector {
self.children
.iter()
.fold(self.evidence_vec(), |mut curr_ev, &(_, ref lambda)| {
curr_ev.prod(lambda);
curr_ev
})
}
fn compute_and_cache_lambda(&mut self) {
let lambda = self.compute_lambda();
self.lambda = Some(lambda.clone());
}
fn get_or_compute_lambda(&mut self) -> LogProbVector {
if self.lambda.is_none() {
self.compute_and_cache_lambda();
}
self.lambda.clone().unwrap()
}
fn compute_pi(&self) -> LogProbVector {
let mut pi = self.log_probas.clone();
for (_, ref pi_msg) in self.parents.iter().rev() {
pi = crate::math::log_contract(
pi.view(),
pi_msg.log_probabilities(),
Axis(pi.ndim() - 1),
);
}
// sanity check
assert!(pi.ndim() == 1);
LogProbVector::from_log_probabilities(pi.into_shape((self.log_probas.shape()[0],)).unwrap())
}
fn compute_and_cache_pi(&mut self) {
let pi = self.compute_pi();
self.pi = Some(pi.clone());
}
fn get_or_compute_pi(&mut self) -> LogProbVector {
if self.pi.is_none() {
self.compute_and_cache_pi();
}
self.pi.clone().unwrap()
}
}
/// Representation of a Bayesian Network
///
/// Once built by adding the nodes one by one, you can use it for inference
/// computation on the graph given some evidence.
///
/// Nodes are identified by the `usize` id returned when they are added, which is their
/// insertion index. The default value is an empty network, identical to `BayesNet::new()`.
#[derive(Debug, Default)]
pub struct BayesNet {
    /// Nodes of the network, indexed by node id.
    nodes: Vec<Node>,
}
impl BayesNet {
/// Create a new empty Bayesian Network
pub fn new() -> BayesNet {
BayesNet { nodes: Vec::new() }
}
/// Add a new node to the network
///
/// You need to specify the list of its parents, and an array of probabilities representing
/// `p(x | parents)`. If the parents are `(p1, ... pk)`, the shape of the array should thus be
/// `(N, N_p1, ... N_pk)`, where `N` is the number of possible values of the current variable and
/// `N_pi` is the number of values of parent `pi`.
///
/// If the node has no parents, the probabilities must be one-dimensional and represent a prior.
///
/// All values of probabilities should be finite, but the probabilities array does not need to be
/// normalized, as it will be during the construction process.
///
/// The probabilities are converted with the natural logarithm and handed over to
/// `add_node_from_log_probabilities`, whose return value (the id of the new node) is returned.
pub fn add_node_from_probabilities<D: Dimension + RemoveAxis>(
    &mut self,
    parents: &[usize],
    probabilities: Array<f32, D>,
) -> usize {
    let log_probabilities = probabilities.mapv(f32::ln);
    self.add_node_from_log_probabilities(parents, log_probabilities)
}
/// Add a new node to the network from log-probabilities
///
/// Same as `add_node_from_probabilities`, but the input is in the form of log-probabilities, for
/// greater precision.
///
/// All values of log-probas should be strictly smaller than `+inf`. `-inf` is valid and
/// represents a probability of 0. The probabilities array does not need to be normalized, as it
/// will be during the construction process. For example, the log-vector `[0.0, -inf]` will
/// represent a vector of probabilities of `[1.0, 0.0]`.
///
/// Log-probabilities are interpreted as computed with the natural logarithm (base e).
///
/// Returns the id of the newly inserted node, to be used when listing parents or evidence.
///
/// # Panics
///
/// Panics if:
///
/// - the number of dimensions of the array is not `parents.len() + 1`,
/// - a parent id does not refer to a node already in the network,
/// - the same parent is listed more than once,
/// - the size of axis `i + 1` differs from the number of values of `parents[i]`.
pub fn add_node_from_log_probabilities<D: Dimension + RemoveAxis>(
    &mut self,
    parents: &[usize],
    mut log_probabilities: Array<f32, D>,
) -> usize {
    let id = self.nodes.len();

    // sanity checks
    let shape = log_probabilities.shape();
    assert!(
        shape.len() == parents.len() + 1,
        "Dimensions of log_probas array does not match number of parents"
    );
    for (i, (&val, &parent)) in shape.iter().skip(1).zip(parents.iter()).enumerate() {
        assert!(
            parent < id,
            "Parent {} does not exist: the network only contains {} node(s).",
            parent,
            id
        );
        // Messages are routed by node id, so a parent listed twice could not be told apart
        // from itself and message passing would fail later on. Reject it right away.
        assert!(
            !parents[..i].contains(&parent),
            "Parent {} is listed more than once.",
            parent
        );
        let parent_n_val = self.nodes[parent].log_probas.shape()[0];
        if parent_n_val != val {
            panic!("Dimension {} of log_probas array does not match its associated parent number of element: got {} but parent {} has {}.", i+1, val, parent, parent_n_val);
        }
    }

    // the shapes match, proceed to insert the node
    crate::math::normalize_log_probas(log_probabilities.view_mut());

    // Register the new edge on both ends, each side starting with a uniform message.
    let mut parent_msgs = Vec::with_capacity(parents.len());
    for &p in parents {
        let size = self.nodes[p].log_probas.shape()[0];
        self.nodes[p]
            .children
            .push((id, LogProbVector::uniform(size)));
        parent_msgs.push((p, LogProbVector::uniform(size)));
    }

    self.nodes.push(Node {
        parents: parent_msgs,
        children: Vec::new(),
        log_probas: log_probabilities.into_dyn(),
        evidence: None,
        lambda: None,
        pi: None,
    });
    id
}
/// Sets the evidence for the network
///
/// Input is interpreted as a list of `(node_id, node_value)`. Any evidence set previously is
/// discarded first, so nodes absent from the list end up unobserved. Out-of-range values are
/// not checked, but will result into a probability of `0`.
///
/// # Panics
///
/// Panics if a `node_id` does not refer to a node of the network.
pub fn set_evidence(&mut self, evidence: &[(usize, usize)]) {
    // Start from a clean slate: no node is observed.
    self.nodes.iter_mut().for_each(|node| node.evidence = None);
    for (node_id, value) in evidence.iter().copied() {
        self.nodes[node_id].evidence = Some(value);
    }
}
/// Resets the internal state of the inference algorithm, to begin a new inference
pub fn reset_state(&mut self) {
for node in &mut self.nodes {
for &mut (_, ref mut msg) in &mut node.children {
msg.reset();
}
for &mut (_, ref mut msg) in &mut node.parents {
msg.reset();
}
node.lambda = None;
node.pi = None;
}
}
/// Compute the current state belief of each node according to the current internal messages
///
/// The returned vector holds one renormalized entry per node, in node-id order. The network
/// itself is not modified.
pub fn beliefs(&self) -> Vec<LogProbVector> {
    let mut beliefs = Vec::with_capacity(self.nodes.len());
    for node in &self.nodes {
        // Use the cached values when available, otherwise compute them on the fly.
        let mut belief = match &node.lambda {
            Some(cached) => cached.clone(),
            None => node.compute_lambda(),
        };
        let pi = match &node.pi {
            Some(cached) => cached.clone(),
            None => node.compute_pi(),
        };
        belief.prod(&pi);
        belief.renormalize();
        beliefs.push(belief);
    }
    beliefs
}
/// Compute one step of the Loopy Belief Propagation Algorithm
///
/// The algorithm can be run for any number of steps. it is up to you to decide when to stop.
///
/// A classic stopping criterion is when the yielded beliefs stop significantly changing.
pub fn step(&mut self) {
// At the start of the algorithm, we assume all present cached values for lambda and pi are valid for
// the currently stored messages. We will then compute the new messages and invalidate the caches.
// Compute the new messages and store them into thes two big vectors, once this done we will replace
// them into the graph.
// Their layout is (from, to, content). We pre-allocate the correct capacity.
let mut pi_msgs: Vec<(usize, usize, LogProbVector)> =
Vec::with_capacity(self.nodes.iter().map(|n| n.children.len()).sum());
let mut lambda_msgs: Vec<(usize, usize, LogProbVector)> =
Vec::with_capacity(self.nodes.iter().map(|n| n.parents.len()).sum());
for (id, node) in self.nodes.iter_mut().enumerate() {
// compute the pi messages:
let mut pi = node.get_or_compute_pi();
pi.prod(&node.evidence_vec());
for &(child_id, _) in &node.children {
let mut msg = node
.children
.iter()
.filter(|&&(cid, _)| cid != child_id)
.fold(pi.clone(), |mut acc, (_, ref v)| {
acc.prod(v);
acc
});
msg.renormalize();
pi_msgs.push((id, child_id, msg));
}
// compute the lambda messages:
let lambda = node.get_or_compute_lambda();
for &(parent_id, _) in &node.parents {
let acc = node
.parents
.iter()
.enumerate()
.rev()
.filter(|&(_, &(pid, _))| pid != parent_id)
.fold(node.log_probas.clone(), |acc, (axid, &(_, ref v))| {
crate::math::log_contract(acc.view(), v.log_probabilities(), Axis(axid + 1))
});
let acc =
crate::math::log_contract(acc.view(), lambda.log_probabilities(), Axis(0));
assert!(acc.ndim() == 1);
let shape = (acc.len(),);
let mut msg = LogProbVector::from_log_probabilities(acc.into_shape(shape).unwrap());
msg.renormalize();
lambda_msgs.push((id, parent_id, msg));
}
// invalidate the cached lambda & pi
node.lambda = None;
node.pi = None;
}
// Finally, store the msgs in their new place
for (from, to, msg) in pi_msgs {
if let Some(&mut (_, ref mut place)) = self.nodes[to]
.parents
.iter_mut()
.find(|&&mut (parent_id, _)| parent_id == from)
{
*place = msg;
} else {
panic!(
"Message from {} to {} who doesn't recognize its parent?!",
from, to
);
}
}
for (from, to, msg) in lambda_msgs {
if let Some(&mut (_, ref mut place)) = self.nodes[to]
.children
.iter_mut()
.find(|&&mut (child_id, _)| child_id == from)
{
*place = msg;
} else {
panic!(
"Message from {} to {} who doesn't recognize its child?!",
from, to
);
}
}
}
}