causal_hub/models/bayesian_network/categorical/
potential.rs

1use std::ops::{Div, DivAssign, Mul, MulAssign};
2
3use approx::{AbsDiffEq, RelativeEq};
4use itertools::Itertools;
5use ndarray::prelude::*;
6
7use crate::{
8    datasets::{CatEv, CatEvT},
9    models::{CPD, CatCPD, Labelled, Phi},
10    types::{Labels, Set, States},
11};
12
/// A categorical potential.
///
/// Stores the labels and states of the variables together with an
/// n-dimensional array of parameters whose shape matches the number of
/// states of each variable.
#[derive(Clone, Debug)]
pub struct CatPhi {
    // Labels of the variables, collected from the keys of `states`.
    labels: Labels,
    // Map from each variable label to the set of its states.
    states: States,
    // Number of states of each variable, in the same order as `states`.
    shape: Array1<usize>,
    // Parameters of the potential; their shape is asserted to equal `shape`
    // when the potential is built through `new`.
    parameters: ArrayD<f64>,
}
21
impl Labelled for CatPhi {
    /// Returns a reference to the labels of the variables of the potential.
    #[inline]
    fn labels(&self) -> &Labels {
        &self.labels
    }
}
28
29impl PartialEq for CatPhi {
30    fn eq(&self, other: &Self) -> bool {
31        self.labels.eq(&other.labels)
32            && self.states.eq(&other.states)
33            && self.shape.eq(&other.shape)
34            && self.parameters.eq(&other.parameters)
35    }
36}
37
38impl AbsDiffEq for CatPhi {
39    type Epsilon = f64;
40
41    fn default_epsilon() -> Self::Epsilon {
42        Self::Epsilon::default_epsilon()
43    }
44
45    fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
46        self.labels.eq(&other.labels)
47            && self.states.eq(&other.states)
48            && self.shape.eq(&other.shape)
49            && self.parameters.abs_diff_eq(&other.parameters, epsilon)
50    }
51}
52
53impl RelativeEq for CatPhi {
54    fn default_max_relative() -> Self::Epsilon {
55        Self::Epsilon::default_max_relative()
56    }
57
58    fn relative_eq(
59        &self,
60        other: &Self,
61        epsilon: Self::Epsilon,
62        max_relative: Self::Epsilon,
63    ) -> bool {
64        self.labels.eq(&other.labels)
65            && self.states.eq(&other.states)
66            && self.shape.eq(&other.shape)
67            && self
68                .parameters
69                .relative_eq(&other.parameters, epsilon, max_relative)
70    }
71}
72
73impl MulAssign<&CatPhi> for CatPhi {
74    fn mul_assign(&mut self, rhs: &CatPhi) {
75        // Get the union of the states.
76        let mut states = self.states.clone();
77        states.extend(rhs.states.clone());
78        // Sort the states by labels.
79        states.sort_keys();
80
81        // Order LHS axes w.r.t. new states.
82        let mut lhs_axes: Vec<_> = (0..self.states.len()).collect();
83        lhs_axes.sort_by_key(|&i| self.states.get_index(i).unwrap().0);
84        let mut lhs_parameters = self.parameters.clone().permuted_axes(lhs_axes);
85        // Get the axes to insert for LHS broadcasting.
86        let lhs_axes = states.keys().enumerate();
87        let lhs_axes = lhs_axes.filter_map(|(i, k)| (!self.states.contains_key(k)).then_some(i));
88        let lhs_axes: Vec<_> = lhs_axes.sorted().collect();
89        // Insert axes in sorted order for LHS broadcasting.
90        lhs_axes.into_iter().for_each(|i| {
91            lhs_parameters.insert_axis_inplace(Axis(i));
92        });
93
94        // Order RHS axes w.r.t. new states.
95        let mut rhs_axes: Vec<_> = (0..rhs.states.len()).collect();
96        rhs_axes.sort_by_key(|&i| rhs.states.get_index(i).unwrap().0);
97        let mut rhs_parameters = rhs.parameters.clone().permuted_axes(rhs_axes);
98        // Get the axes to insert for RHS broadcasting.
99        let rhs_axes = states.keys().enumerate();
100        let rhs_axes = rhs_axes.filter_map(|(i, k)| (!rhs.states.contains_key(k)).then_some(i));
101        let rhs_axes: Vec<_> = rhs_axes.sorted().collect();
102        // Insert axes in sorted order for RHS broadcasting.
103        rhs_axes.into_iter().for_each(|i| {
104            rhs_parameters.insert_axis_inplace(Axis(i));
105        });
106
107        // Perform element-wise multiplication.
108        let parameters = lhs_parameters * rhs_parameters;
109
110        // Get new labels.
111        let labels: Labels = states.keys().cloned().collect();
112        // Get new shape.
113        let shape = Array::from_iter(states.values().map(Set::len));
114
115        // Update self.
116        self.states = states;
117        self.labels = labels;
118        self.shape = shape;
119        self.parameters = parameters;
120    }
121}
122
123impl Mul<&CatPhi> for &CatPhi {
124    type Output = CatPhi;
125
126    #[inline]
127    fn mul(self, rhs: &CatPhi) -> Self::Output {
128        let mut lhs = self.clone();
129        lhs *= rhs;
130        lhs
131    }
132}
133
134impl DivAssign<&CatPhi> for CatPhi {
135    fn div_assign(&mut self, rhs: &CatPhi) {
136        // Assert that RHS states are a subset of LHS states.
137        assert!(
138            rhs.states.keys().all(|k| self.states.contains_key(k)),
139            "Failed to divide potentials: \n\
140            \t expected:    RHS states to be a subset of LHS states , \n\
141            \t found:       LHS states = {:?} , \n\
142            \t              RHS states = {:?} .",
143            self.states,
144            rhs.states,
145        );
146
147        // Add a small constant to ensure 0 / 0 = 0.
148        let rhs_parameters = &rhs.parameters + f64::MIN_POSITIVE;
149
150        // Order RHS axes w.r.t. new states.
151        let mut rhs_axes: Vec<_> = (0..rhs.states.len()).collect();
152        rhs_axes.sort_by_key(|&i| rhs.states.get_index(i).unwrap().0);
153        let mut rhs_parameters = rhs_parameters.permuted_axes(rhs_axes);
154        // Get the axes to insert for RHS broadcasting.
155        let rhs_axes = self.states.keys().enumerate();
156        let rhs_axes = rhs_axes.filter_map(|(i, k)| (!rhs.states.contains_key(k)).then_some(i));
157        let rhs_axes: Vec<_> = rhs_axes.sorted().collect();
158        // Insert axes in sorted order for RHS broadcasting.
159        rhs_axes.into_iter().for_each(|i| {
160            rhs_parameters.insert_axis_inplace(Axis(i));
161        });
162
163        // Perform element-wise division with 0 / 0 = 0.
164        self.parameters /= &rhs_parameters;
165    }
166}
167
168impl Div<&CatPhi> for &CatPhi {
169    type Output = CatPhi;
170
171    #[inline]
172    fn div(self, rhs: &CatPhi) -> Self::Output {
173        let mut lhs = self.clone();
174        lhs /= rhs;
175        lhs
176    }
177}
178
179impl Phi for CatPhi {
180    type CPD = CatCPD;
181    type Parameters = ArrayD<f64>;
182    type Evidence = CatEv;
183
    /// Returns a reference to the parameters of the potential.
    #[inline]
    fn parameters(&self) -> &Self::Parameters {
        &self.parameters
    }
188
    /// Returns the number of parameters, i.e. the total number of elements
    /// of the parameters array.
    fn parameters_size(&self) -> usize {
        self.parameters.len()
    }
192
193    fn condition(&self, e: &Self::Evidence) -> Self {
194        // Assert that the evidence states match the potential states.
195        assert_eq!(
196            e.states(),
197            self.states(),
198            "Failed to condition on evidence: \n\
199            \t expected:    evidence states to match potential states , \n\
200            \t found:       potential states = {:?} , \n\
201            \t              evidence  states = {:?} .",
202            self.states(),
203            e.states(),
204        );
205
206        // Get the evidence and remove nones.
207        let e = e.evidences().iter().flatten();
208        // Assert that the evidence is certain and positive.
209        let e = e.cloned().map(|e| match e {
210            CatEvT::CertainPositive { event, state } => (event, state),
211            _ => panic!(
212                "Failed to condition on evidence: \n
213                \t expected:    CertainPositive , \n\
214                \t found:       {:?} .",
215                e
216            ),
217        });
218
219        // Get states and parameters.
220        let mut states = self.states.clone();
221        let mut parameters = self.parameters.clone();
222
223        // Condition in reverse order to avoid axis shifting.
224        e.rev().for_each(|(event, state)| {
225            parameters.index_axis_inplace(Axis(event), state);
226            states.shift_remove_index(event);
227        });
228
229        // Return self.
230        Self::new(states, parameters)
231    }
232
233    fn marginalize(&self, x: &Set<usize>) -> Self {
234        // Base case: if no variables to marginalize, return self.
235        if x.is_empty() {
236            return self.clone();
237        }
238
239        // Assert X is a subset of the variables.
240        x.iter().for_each(|&x| {
241            assert!(
242                x < self.labels.len(),
243                "Variable index out of bounds: \n\
244                \t expected:    x <  {} , \n\
245                \t found:       x == {} .",
246                self.labels.len(),
247                x,
248            );
249        });
250
251        // Get the states and the parameters.
252        let states = self.states.clone();
253        let mut parameters = self.parameters.clone();
254
255        // Filter the states.
256        let states = states.into_iter().enumerate();
257        let states = states.filter_map(|(i, s)| (!x.contains(&i)).then_some(s));
258        let states = states.collect();
259
260        // Sum over the axes in reverse order to avoid shifting.
261        x.iter().sorted().rev().for_each(|&i| {
262            parameters = parameters.sum_axis(Axis(i));
263        });
264
265        // Return the new potential.
266        Self::new(states, parameters)
267    }
268
269    #[inline]
270    fn normalize(&self) -> Self {
271        // Get the parameters.
272        let mut parameters = self.parameters.clone();
273        // Normalize the parameters.
274        parameters /= parameters.sum();
275        // Return the new potential.
276        Self::new(self.states.clone(), parameters)
277    }
278
279    fn from_cpd(cpd: Self::CPD) -> Self {
280        // Merge conditioning states and states in this order.
281        let mut states = cpd.conditioning_states().clone();
282        states.extend(cpd.states().clone());
283        // Get n-dimensional shape.
284        let shape: Vec<_> = states.values().map(Set::len).collect();
285        // Reshape the parameters to match the new shape.
286        let parameters = cpd.parameters().clone();
287        let parameters = parameters
288            .into_dyn()
289            .into_shape_with_order(shape)
290            .expect("Failed to reshape parameters.");
291
292        // Get the new axes order w.r.t. sorted labels.
293        let mut axes: Vec<_> = (0..states.len()).collect();
294        axes.sort_by_key(|&i| states.get_index(i).unwrap().0);
295        // Sort the states by labels.
296        states.sort_keys();
297        // Swap axes to match the new order.
298        let parameters = parameters.permuted_axes(axes);
299
300        // Return the new potential.
301        Self::new(states, parameters)
302    }
303
304    fn into_cpd(self, x: &Set<usize>, z: &Set<usize>) -> Self::CPD {
305        // Assert that X and Z are disjoint.
306        assert!(
307            x.is_disjoint(z),
308            "Variables and conditioning variables must be disjoint."
309        );
310        // Assert that X and Z cover all variables.
311        assert!(
312            (x | z).iter().sorted().cloned().eq(0..self.labels.len()),
313            "Variables and conditioning variables must cover all potential variables."
314        );
315
316        // Split states into states and conditioning states.
317        let states_x: States = x
318            .iter()
319            .map(|&i| {
320                self.states
321                    .get_index(i)
322                    .map(|(k, v)| (k.clone(), v.clone()))
323                    .unwrap()
324            })
325            .collect();
326        let states_z: States = z
327            .iter()
328            .map(|&i| {
329                self.states
330                    .get_index(i)
331                    .map(|(k, v)| (k.clone(), v.clone()))
332                    .unwrap()
333            })
334            .collect();
335
336        // Get new axes order.
337        let axes: Vec<_> = z.iter().chain(x).cloned().collect();
338        // Permute parameters to match the new order.
339        let parameters = self.parameters.permuted_axes(axes);
340        // Get the new 2D shape.
341        let shape: (usize, usize) = (
342            states_z.values().map(Set::len).product(),
343            states_x.values().map(Set::len).product(),
344        );
345        // Reshape the parameters to the new 2D shape.
346        let mut parameters = parameters
347            .into_shape_clone(shape)
348            .expect("Failed to reshape parameters.");
349
350        // Normalize the parameters.
351        parameters /= &parameters.sum_axis(Axis(1)).insert_axis(Axis(1));
352
353        // Create the new CPD.
354        CatCPD::new(states_x, states_z, parameters)
355    }
356}
357
358impl CatPhi {
359    /// Creates a new categorical potential.
360    ///
361    /// # Arguments
362    ///
363    /// * `states` - A map from variable names to their possible states.
364    /// * `parameters` - A multi-dimensional array of parameters.
365    ///
366    /// # Returns
367    ///
368    /// A new categorical potential instance.
369    ///
370    pub fn new(mut states: States, mut parameters: ArrayD<f64>) -> Self {
371        // Get labels.
372        let mut labels: Labels = states.keys().cloned().collect();
373        // Get shape.
374        let mut shape = Array::from_iter(states.values().map(Set::len));
375        // Assert parameters shape matches states shape.
376        assert_eq!(
377            parameters.shape(),
378            shape.as_slice().unwrap(),
379            "Parameters shape does not match states shape: \n\
380            \t expected:    {:?} , \n\
381            \t found:       {:?} .",
382            shape,
383            parameters.shape(),
384        );
385
386        // Sort states if not sorted and permute parameters accordingly.
387        if !states.keys().is_sorted() {
388            // Get the new axes order w.r.t. sorted labels.
389            let mut axes: Vec<_> = (0..states.len()).collect();
390            axes.sort_by_key(|&i| states.get_index(i).unwrap().0);
391            // Sort the states by labels.
392            states.sort_keys();
393            // Permute the parameters to match the new order.
394            parameters = parameters.permuted_axes(axes);
395            // Update the labels.
396            labels = states.keys().cloned().collect();
397            // Update the shape.
398            shape = states.values().map(Set::len).collect();
399        }
400
401        Self {
402            labels,
403            states,
404            shape,
405            parameters,
406        }
407    }
408
    /// States of the potential.
    ///
    /// # Returns
    ///
    /// A reference to the map from each variable label to its set of states.
    ///
    #[inline]
    pub const fn states(&self) -> &States {
        &self.states
    }
419
    /// Shape of the potential.
    ///
    /// # Returns
    ///
    /// A reference to the array holding the number of states of each
    /// variable, in the same order as the states.
    ///
    #[inline]
    pub const fn shape(&self) -> &Array1<usize> {
        &self.shape
    }
430}