causal_hub/datasets/table/categorical/
evidence.rs

1use approx::relative_eq;
2use ndarray::prelude::*;
3
4use crate::{
5    datasets::CatTrjEvT,
6    models::Labelled,
7    types::{Labels, Set, States},
8};
9
10/// Categorical evidence type.
11#[non_exhaustive]
12#[derive(Clone, Debug)]
13pub enum CatEvT {
14    /// Certain positive evidence.
15    CertainPositive {
16        /// The observed event of the evidence.
17        event: usize,
18        /// The state of the evidence.
19        state: usize,
20    },
21    /// Certain negative evidence.
22    CertainNegative {
23        /// The observed event of the evidence.
24        event: usize,
25        /// The states of the evidence.
26        not_states: Set<usize>,
27    },
28    /// Uncertain positive evidence.
29    UncertainPositive {
30        /// The observed event of the evidence.
31        event: usize,
32        /// The probabilities of the states.
33        p_states: Array1<f64>,
34    },
35    /// Uncertain negative evidence.
36    UncertainNegative {
37        /// The observed event of the evidence.
38        event: usize,
39        /// The probabilities of the states.
40        p_not_states: Array1<f64>,
41    },
42}
43
44impl From<CatTrjEvT> for CatEvT {
45    fn from(evidence: CatTrjEvT) -> Self {
46        // Get shortened variable types.
47        use CatEvT as U;
48        use CatTrjEvT as T;
49        // Match the evidence type discard the temporal information.
50        match evidence {
51            T::CertainPositiveInterval { event, state, .. } => U::CertainPositive { event, state },
52            T::CertainNegativeInterval {
53                event, not_states, ..
54            } => U::CertainNegative { event, not_states },
55            T::UncertainPositiveInterval {
56                event, p_states, ..
57            } => U::UncertainPositive { event, p_states },
58            T::UncertainNegativeInterval {
59                event,
60                p_not_states,
61                ..
62            } => U::UncertainNegative {
63                event,
64                p_not_states,
65            },
66        }
67    }
68}
69
70impl CatEvT {
71    /// Return the observed event of the evidence.
72    ///
73    /// # Returns
74    ///
75    /// The observed event of the evidence.
76    ///
77    pub const fn event(&self) -> usize {
78        match self {
79            Self::CertainPositive { event, .. }
80            | Self::CertainNegative { event, .. }
81            | Self::UncertainPositive { event, .. }
82            | Self::UncertainNegative { event, .. } => *event,
83        }
84    }
85}
86
87/// Categorical evidence structure.
88#[derive(Clone, Debug)]
89pub struct CatEv {
90    labels: Labels,
91    states: States,
92    shape: Array1<usize>,
93    evidences: Vec<Option<CatEvT>>,
94}
95
96impl Labelled for CatEv {
97    fn labels(&self) -> &Labels {
98        &self.labels
99    }
100}
101
102impl CatEv {
103    /// Creates a new categorical evidence structure.
104    ///
105    /// # Arguments
106    ///
107    /// * `states` - A collection of states, where each state is a tuple of a string and an iterator of strings.
108    /// * `values` - A collection of values, where each value is a categorical evidence type.
109    ///
110    /// # Returns
111    ///
112    /// A new categorical evidence structure.
113    ///
114    pub fn new<I>(mut states: States, values: I) -> Self
115    where
116        I: IntoIterator<Item = CatEvT>,
117    {
118        // Get shortened variable type.
119        use CatEvT as E;
120
121        // Get the sorted labels.
122        let mut labels = states.keys().cloned().collect();
123        // Get the shape of the states.
124        let mut shape = Array::from_iter(states.values().map(Set::len));
125        // Allocate evidences.
126        let mut evidences = vec![None; states.len()];
127
128        // Fill the evidences.
129        values.into_iter().for_each(|e| {
130            // Get the event of the evidence.
131            let event = e.event();
132            // Push the value into the variable events.
133            evidences[event] = Some(e);
134        });
135
136        // Sort states, if necessary.
137        if !states.keys().is_sorted() || !states.values().all(|x| x.iter().is_sorted()) {
138            // Clone the states.
139            let mut new_states = states.clone();
140            // Sort the states.
141            new_states.sort_keys();
142            new_states.values_mut().for_each(Set::sort);
143
144            // Allocate new evidences.
145            let mut new_evidences = vec![None; states.len()];
146
147            // Iterate over the values and insert them into the events map using sorted indices.
148            evidences.into_iter().flatten().for_each(|e| {
149                // Get the event and states of the evidence.
150                let (event, states) = states
151                    .get_index(e.event())
152                    .expect("Failed to get label of evidence.");
153                // Sort the event index.
154                let (event, _, new_states) = new_states
155                    .get_full(event)
156                    .expect("Failed to get full state.");
157
158                // Sort the variable states.
159                let e = match e {
160                    E::CertainPositive { state, .. } => {
161                        // Sort the variable states.
162                        let state = new_states
163                            .get_index_of(&states[state])
164                            .expect("Failed to get index of state.");
165                        // Construct the sorted evidence.
166                        E::CertainPositive { event, state }
167                    }
168                    E::CertainNegative { not_states, .. } => {
169                        // Sort the variable states.
170                        let not_states = not_states
171                            .iter()
172                            .map(|&state| {
173                                new_states
174                                    .get_index_of(&states[state])
175                                    .expect("Failed to get index of state.")
176                            })
177                            .collect();
178                        // Construct the sorted evidence.
179                        E::CertainNegative { event, not_states }
180                    }
181                    E::UncertainPositive { p_states, .. } => {
182                        // Allocate new variable states.
183                        let mut new_p_states = Array::zeros(p_states.len());
184                        // Sort the variable states.
185                        p_states.indexed_iter().for_each(|(i, &p)| {
186                            // Get sorted index.
187                            let state = new_states
188                                .get_index_of(&states[i])
189                                .expect("Failed to get index of state.");
190                            // Assign probability to sorted index.
191                            new_p_states[state] = p;
192                        });
193                        // Substitute the sorted states.
194                        let p_states = new_p_states;
195                        // Construct the sorted evidence.
196                        E::UncertainPositive { event, p_states }
197                    }
198                    E::UncertainNegative { p_not_states, .. } => {
199                        // Allocate new variable states.
200                        let mut new_p_not_states = Array::zeros(p_not_states.len());
201                        // Sort the variable states.
202                        p_not_states.indexed_iter().for_each(|(i, &p)| {
203                            // Get sorted index.
204                            let state = new_states
205                                .get_index_of(&states[i])
206                                .expect("Failed to get index of state.");
207                            // Assign probability to sorted index.
208                            new_p_not_states[state] = p;
209                        });
210                        // Substitute the sorted states.
211                        let p_not_states = new_p_not_states;
212                        // Construct the sorted evidence.
213                        E::UncertainNegative {
214                            event,
215                            p_not_states,
216                        }
217                    }
218                };
219
220                // Push the value into the variable events.
221                new_evidences[event] = Some(e);
222            });
223
224            // Update the states.
225            states = new_states;
226            // Update the evidences.
227            evidences = new_evidences;
228            // Update the labels.
229            labels = states.keys().cloned().collect();
230            // Update the shape.
231            shape = states.values().map(Set::len).collect();
232        }
233
234        // For each variable ...
235        for (i, e) in evidences.iter_mut().enumerate() {
236            // Assert states distributions have the correct size.
237            assert!(
238                e.as_ref().is_none_or(|e| match e {
239                    E::CertainPositive { .. } => true,
240                    E::CertainNegative { .. } => true,
241                    E::UncertainPositive { p_states, .. } => {
242                        p_states.len() == shape[i]
243                    }
244                    E::UncertainNegative { p_not_states, .. } => {
245                        p_not_states.len() == shape[i]
246                    }
247                }),
248                "Evidence states distributions must have the correct size."
249            );
250            // Assert states distributions are not negative.
251            assert!(
252                e.as_ref().is_none_or(|e| match e {
253                    E::CertainPositive { .. } => true,
254                    E::CertainNegative { .. } => true,
255                    E::UncertainPositive { p_states, .. } => {
256                        p_states.iter().all(|&x| x >= 0.)
257                    }
258                    E::UncertainNegative { p_not_states, .. } => {
259                        p_not_states.iter().all(|&x| x >= 0.)
260                    }
261                }),
262                "Evidence states distributions must be non-negative."
263            );
264            // Assert states distributions sum to 1.
265            assert!(
266                e.as_ref().is_none_or(|e| match e {
267                    E::CertainPositive { .. } => true,
268                    E::CertainNegative { .. } => true,
269                    E::UncertainPositive { p_states, .. } => {
270                        relative_eq!(p_states.sum(), 1.)
271                    }
272                    E::UncertainNegative { p_not_states, .. } => {
273                        relative_eq!(p_not_states.sum(), 1.)
274                    }
275                }),
276                "Evidence states distributions must sum to 1."
277            );
278        }
279
280        Self {
281            labels,
282            states,
283            shape,
284            evidences,
285        }
286    }
287
288    /// The states of the evidence.
289    ///
290    /// # Returns
291    ///
292    /// A reference to the states of the evidence.
293    ///
294    #[inline]
295    pub const fn states(&self) -> &States {
296        &self.states
297    }
298
299    /// The shape of the evidence.
300    ///
301    /// # Returns
302    ///
303    /// A reference to the shape of the evidence.
304    ///
305    #[inline]
306    pub const fn shape(&self) -> &Array1<usize> {
307        &self.shape
308    }
309
310    /// The evidences of the evidence.
311    ///
312    /// # Returns
313    ///
314    /// A reference to the evidences of the evidence.
315    ///
316    #[inline]
317    pub const fn evidences(&self) -> &Vec<Option<CatEvT>> {
318        &self.evidences
319    }
320}