// causal_hub/models/bayesian_network/categorical/potential.rs
use std::ops::{Div, DivAssign, Mul, MulAssign};

use approx::{AbsDiffEq, RelativeEq};
use itertools::Itertools;
use ndarray::prelude::*;

use crate::{
    datasets::{CatEv, CatEvT},
    models::{CPD, CatCPD, Labelled, Phi},
    types::{Labels, Set, States},
};

13#[derive(Clone, Debug)]
15pub struct CatPhi {
16 labels: Labels,
17 states: States,
18 shape: Array1<usize>,
19 parameters: ArrayD<f64>,
20}
21
22impl Labelled for CatPhi {
23 #[inline]
24 fn labels(&self) -> &Labels {
25 &self.labels
26 }
27}
28
29impl PartialEq for CatPhi {
30 fn eq(&self, other: &Self) -> bool {
31 self.labels.eq(&other.labels)
32 && self.states.eq(&other.states)
33 && self.shape.eq(&other.shape)
34 && self.parameters.eq(&other.parameters)
35 }
36}
37
38impl AbsDiffEq for CatPhi {
39 type Epsilon = f64;
40
41 fn default_epsilon() -> Self::Epsilon {
42 Self::Epsilon::default_epsilon()
43 }
44
45 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
46 self.labels.eq(&other.labels)
47 && self.states.eq(&other.states)
48 && self.shape.eq(&other.shape)
49 && self.parameters.abs_diff_eq(&other.parameters, epsilon)
50 }
51}
52
53impl RelativeEq for CatPhi {
54 fn default_max_relative() -> Self::Epsilon {
55 Self::Epsilon::default_max_relative()
56 }
57
58 fn relative_eq(
59 &self,
60 other: &Self,
61 epsilon: Self::Epsilon,
62 max_relative: Self::Epsilon,
63 ) -> bool {
64 self.labels.eq(&other.labels)
65 && self.states.eq(&other.states)
66 && self.shape.eq(&other.shape)
67 && self
68 .parameters
69 .relative_eq(&other.parameters, epsilon, max_relative)
70 }
71}
72
73impl MulAssign<&CatPhi> for CatPhi {
74 fn mul_assign(&mut self, rhs: &CatPhi) {
75 let mut states = self.states.clone();
77 states.extend(rhs.states.clone());
78 states.sort_keys();
80
81 let mut lhs_axes: Vec<_> = (0..self.states.len()).collect();
83 lhs_axes.sort_by_key(|&i| self.states.get_index(i).unwrap().0);
84 let mut lhs_parameters = self.parameters.clone().permuted_axes(lhs_axes);
85 let lhs_axes = states.keys().enumerate();
87 let lhs_axes = lhs_axes.filter_map(|(i, k)| (!self.states.contains_key(k)).then_some(i));
88 let lhs_axes: Vec<_> = lhs_axes.sorted().collect();
89 lhs_axes.into_iter().for_each(|i| {
91 lhs_parameters.insert_axis_inplace(Axis(i));
92 });
93
94 let mut rhs_axes: Vec<_> = (0..rhs.states.len()).collect();
96 rhs_axes.sort_by_key(|&i| rhs.states.get_index(i).unwrap().0);
97 let mut rhs_parameters = rhs.parameters.clone().permuted_axes(rhs_axes);
98 let rhs_axes = states.keys().enumerate();
100 let rhs_axes = rhs_axes.filter_map(|(i, k)| (!rhs.states.contains_key(k)).then_some(i));
101 let rhs_axes: Vec<_> = rhs_axes.sorted().collect();
102 rhs_axes.into_iter().for_each(|i| {
104 rhs_parameters.insert_axis_inplace(Axis(i));
105 });
106
107 let parameters = lhs_parameters * rhs_parameters;
109
110 let labels: Labels = states.keys().cloned().collect();
112 let shape = Array::from_iter(states.values().map(Set::len));
114
115 self.states = states;
117 self.labels = labels;
118 self.shape = shape;
119 self.parameters = parameters;
120 }
121}
122
123impl Mul<&CatPhi> for &CatPhi {
124 type Output = CatPhi;
125
126 #[inline]
127 fn mul(self, rhs: &CatPhi) -> Self::Output {
128 let mut lhs = self.clone();
129 lhs *= rhs;
130 lhs
131 }
132}
133
134impl DivAssign<&CatPhi> for CatPhi {
135 fn div_assign(&mut self, rhs: &CatPhi) {
136 assert!(
138 rhs.states.keys().all(|k| self.states.contains_key(k)),
139 "Failed to divide potentials: \n\
140 \t expected: RHS states to be a subset of LHS states , \n\
141 \t found: LHS states = {:?} , \n\
142 \t RHS states = {:?} .",
143 self.states,
144 rhs.states,
145 );
146
147 let rhs_parameters = &rhs.parameters + f64::MIN_POSITIVE;
149
150 let mut rhs_axes: Vec<_> = (0..rhs.states.len()).collect();
152 rhs_axes.sort_by_key(|&i| rhs.states.get_index(i).unwrap().0);
153 let mut rhs_parameters = rhs_parameters.permuted_axes(rhs_axes);
154 let rhs_axes = self.states.keys().enumerate();
156 let rhs_axes = rhs_axes.filter_map(|(i, k)| (!rhs.states.contains_key(k)).then_some(i));
157 let rhs_axes: Vec<_> = rhs_axes.sorted().collect();
158 rhs_axes.into_iter().for_each(|i| {
160 rhs_parameters.insert_axis_inplace(Axis(i));
161 });
162
163 self.parameters /= &rhs_parameters;
165 }
166}
167
168impl Div<&CatPhi> for &CatPhi {
169 type Output = CatPhi;
170
171 #[inline]
172 fn div(self, rhs: &CatPhi) -> Self::Output {
173 let mut lhs = self.clone();
174 lhs /= rhs;
175 lhs
176 }
177}
178
179impl Phi for CatPhi {
180 type CPD = CatCPD;
181 type Parameters = ArrayD<f64>;
182 type Evidence = CatEv;
183
184 #[inline]
185 fn parameters(&self) -> &Self::Parameters {
186 &self.parameters
187 }
188
189 fn parameters_size(&self) -> usize {
190 self.parameters.len()
191 }
192
193 fn condition(&self, e: &Self::Evidence) -> Self {
194 assert_eq!(
196 e.states(),
197 self.states(),
198 "Failed to condition on evidence: \n\
199 \t expected: evidence states to match potential states , \n\
200 \t found: potential states = {:?} , \n\
201 \t evidence states = {:?} .",
202 self.states(),
203 e.states(),
204 );
205
206 let e = e.evidences().iter().flatten();
208 let e = e.cloned().map(|e| match e {
210 CatEvT::CertainPositive { event, state } => (event, state),
211 _ => panic!(
212 "Failed to condition on evidence: \n
213 \t expected: CertainPositive , \n\
214 \t found: {:?} .",
215 e
216 ),
217 });
218
219 let mut states = self.states.clone();
221 let mut parameters = self.parameters.clone();
222
223 e.rev().for_each(|(event, state)| {
225 parameters.index_axis_inplace(Axis(event), state);
226 states.shift_remove_index(event);
227 });
228
229 Self::new(states, parameters)
231 }
232
233 fn marginalize(&self, x: &Set<usize>) -> Self {
234 if x.is_empty() {
236 return self.clone();
237 }
238
239 x.iter().for_each(|&x| {
241 assert!(
242 x < self.labels.len(),
243 "Variable index out of bounds: \n\
244 \t expected: x < {} , \n\
245 \t found: x == {} .",
246 self.labels.len(),
247 x,
248 );
249 });
250
251 let states = self.states.clone();
253 let mut parameters = self.parameters.clone();
254
255 let states = states.into_iter().enumerate();
257 let states = states.filter_map(|(i, s)| (!x.contains(&i)).then_some(s));
258 let states = states.collect();
259
260 x.iter().sorted().rev().for_each(|&i| {
262 parameters = parameters.sum_axis(Axis(i));
263 });
264
265 Self::new(states, parameters)
267 }
268
269 #[inline]
270 fn normalize(&self) -> Self {
271 let mut parameters = self.parameters.clone();
273 parameters /= parameters.sum();
275 Self::new(self.states.clone(), parameters)
277 }
278
279 fn from_cpd(cpd: Self::CPD) -> Self {
280 let mut states = cpd.conditioning_states().clone();
282 states.extend(cpd.states().clone());
283 let shape: Vec<_> = states.values().map(Set::len).collect();
285 let parameters = cpd.parameters().clone();
287 let parameters = parameters
288 .into_dyn()
289 .into_shape_with_order(shape)
290 .expect("Failed to reshape parameters.");
291
292 let mut axes: Vec<_> = (0..states.len()).collect();
294 axes.sort_by_key(|&i| states.get_index(i).unwrap().0);
295 states.sort_keys();
297 let parameters = parameters.permuted_axes(axes);
299
300 Self::new(states, parameters)
302 }
303
304 fn into_cpd(self, x: &Set<usize>, z: &Set<usize>) -> Self::CPD {
305 assert!(
307 x.is_disjoint(z),
308 "Variables and conditioning variables must be disjoint."
309 );
310 assert!(
312 (x | z).iter().sorted().cloned().eq(0..self.labels.len()),
313 "Variables and conditioning variables must cover all potential variables."
314 );
315
316 let states_x: States = x
318 .iter()
319 .map(|&i| {
320 self.states
321 .get_index(i)
322 .map(|(k, v)| (k.clone(), v.clone()))
323 .unwrap()
324 })
325 .collect();
326 let states_z: States = z
327 .iter()
328 .map(|&i| {
329 self.states
330 .get_index(i)
331 .map(|(k, v)| (k.clone(), v.clone()))
332 .unwrap()
333 })
334 .collect();
335
336 let axes: Vec<_> = z.iter().chain(x).cloned().collect();
338 let parameters = self.parameters.permuted_axes(axes);
340 let shape: (usize, usize) = (
342 states_z.values().map(Set::len).product(),
343 states_x.values().map(Set::len).product(),
344 );
345 let mut parameters = parameters
347 .into_shape_clone(shape)
348 .expect("Failed to reshape parameters.");
349
350 parameters /= ¶meters.sum_axis(Axis(1)).insert_axis(Axis(1));
352
353 CatCPD::new(states_x, states_z, parameters)
355 }
356}
357
358impl CatPhi {
359 pub fn new(mut states: States, mut parameters: ArrayD<f64>) -> Self {
371 let mut labels: Labels = states.keys().cloned().collect();
373 let mut shape = Array::from_iter(states.values().map(Set::len));
375 assert_eq!(
377 parameters.shape(),
378 shape.as_slice().unwrap(),
379 "Parameters shape does not match states shape: \n\
380 \t expected: {:?} , \n\
381 \t found: {:?} .",
382 shape,
383 parameters.shape(),
384 );
385
386 if !states.keys().is_sorted() {
388 let mut axes: Vec<_> = (0..states.len()).collect();
390 axes.sort_by_key(|&i| states.get_index(i).unwrap().0);
391 states.sort_keys();
393 parameters = parameters.permuted_axes(axes);
395 labels = states.keys().cloned().collect();
397 shape = states.values().map(Set::len).collect();
399 }
400
401 Self {
402 labels,
403 states,
404 shape,
405 parameters,
406 }
407 }
408
409 #[inline]
416 pub const fn states(&self) -> &States {
417 &self.states
418 }
419
420 #[inline]
427 pub const fn shape(&self) -> &Array1<usize> {
428 &self.shape
429 }
430}