causal_hub/datasets/table/categorical/evidence.rs
1use approx::relative_eq;
2use ndarray::prelude::*;
3
4use crate::{
5 datasets::CatTrjEvT,
6 models::Labelled,
7 types::{Labels, Set, States},
8};
9
10/// Categorical evidence type.
11#[non_exhaustive]
12#[derive(Clone, Debug)]
13pub enum CatEvT {
14 /// Certain positive evidence.
15 CertainPositive {
16 /// The observed event of the evidence.
17 event: usize,
18 /// The state of the evidence.
19 state: usize,
20 },
21 /// Certain negative evidence.
22 CertainNegative {
23 /// The observed event of the evidence.
24 event: usize,
25 /// The states of the evidence.
26 not_states: Set<usize>,
27 },
28 /// Uncertain positive evidence.
29 UncertainPositive {
30 /// The observed event of the evidence.
31 event: usize,
32 /// The probabilities of the states.
33 p_states: Array1<f64>,
34 },
35 /// Uncertain negative evidence.
36 UncertainNegative {
37 /// The observed event of the evidence.
38 event: usize,
39 /// The probabilities of the states.
40 p_not_states: Array1<f64>,
41 },
42}
43
44impl From<CatTrjEvT> for CatEvT {
45 fn from(evidence: CatTrjEvT) -> Self {
46 // Get shortened variable types.
47 use CatEvT as U;
48 use CatTrjEvT as T;
49 // Match the evidence type discard the temporal information.
50 match evidence {
51 T::CertainPositiveInterval { event, state, .. } => U::CertainPositive { event, state },
52 T::CertainNegativeInterval {
53 event, not_states, ..
54 } => U::CertainNegative { event, not_states },
55 T::UncertainPositiveInterval {
56 event, p_states, ..
57 } => U::UncertainPositive { event, p_states },
58 T::UncertainNegativeInterval {
59 event,
60 p_not_states,
61 ..
62 } => U::UncertainNegative {
63 event,
64 p_not_states,
65 },
66 }
67 }
68}
69
70impl CatEvT {
71 /// Return the observed event of the evidence.
72 ///
73 /// # Returns
74 ///
75 /// The observed event of the evidence.
76 ///
77 pub const fn event(&self) -> usize {
78 match self {
79 Self::CertainPositive { event, .. }
80 | Self::CertainNegative { event, .. }
81 | Self::UncertainPositive { event, .. }
82 | Self::UncertainNegative { event, .. } => *event,
83 }
84 }
85}
86
87/// Categorical evidence structure.
88#[derive(Clone, Debug)]
89pub struct CatEv {
90 labels: Labels,
91 states: States,
92 shape: Array1<usize>,
93 evidences: Vec<Option<CatEvT>>,
94}
95
96impl Labelled for CatEv {
97 fn labels(&self) -> &Labels {
98 &self.labels
99 }
100}
101
102impl CatEv {
103 /// Creates a new categorical evidence structure.
104 ///
105 /// # Arguments
106 ///
107 /// * `states` - A collection of states, where each state is a tuple of a string and an iterator of strings.
108 /// * `values` - A collection of values, where each value is a categorical evidence type.
109 ///
110 /// # Returns
111 ///
112 /// A new categorical evidence structure.
113 ///
114 pub fn new<I>(mut states: States, values: I) -> Self
115 where
116 I: IntoIterator<Item = CatEvT>,
117 {
118 // Get shortened variable type.
119 use CatEvT as E;
120
121 // Get the sorted labels.
122 let mut labels = states.keys().cloned().collect();
123 // Get the shape of the states.
124 let mut shape = Array::from_iter(states.values().map(Set::len));
125 // Allocate evidences.
126 let mut evidences = vec![None; states.len()];
127
128 // Fill the evidences.
129 values.into_iter().for_each(|e| {
130 // Get the event of the evidence.
131 let event = e.event();
132 // Push the value into the variable events.
133 evidences[event] = Some(e);
134 });
135
136 // Sort states, if necessary.
137 if !states.keys().is_sorted() || !states.values().all(|x| x.iter().is_sorted()) {
138 // Clone the states.
139 let mut new_states = states.clone();
140 // Sort the states.
141 new_states.sort_keys();
142 new_states.values_mut().for_each(Set::sort);
143
144 // Allocate new evidences.
145 let mut new_evidences = vec![None; states.len()];
146
147 // Iterate over the values and insert them into the events map using sorted indices.
148 evidences.into_iter().flatten().for_each(|e| {
149 // Get the event and states of the evidence.
150 let (event, states) = states
151 .get_index(e.event())
152 .expect("Failed to get label of evidence.");
153 // Sort the event index.
154 let (event, _, new_states) = new_states
155 .get_full(event)
156 .expect("Failed to get full state.");
157
158 // Sort the variable states.
159 let e = match e {
160 E::CertainPositive { state, .. } => {
161 // Sort the variable states.
162 let state = new_states
163 .get_index_of(&states[state])
164 .expect("Failed to get index of state.");
165 // Construct the sorted evidence.
166 E::CertainPositive { event, state }
167 }
168 E::CertainNegative { not_states, .. } => {
169 // Sort the variable states.
170 let not_states = not_states
171 .iter()
172 .map(|&state| {
173 new_states
174 .get_index_of(&states[state])
175 .expect("Failed to get index of state.")
176 })
177 .collect();
178 // Construct the sorted evidence.
179 E::CertainNegative { event, not_states }
180 }
181 E::UncertainPositive { p_states, .. } => {
182 // Allocate new variable states.
183 let mut new_p_states = Array::zeros(p_states.len());
184 // Sort the variable states.
185 p_states.indexed_iter().for_each(|(i, &p)| {
186 // Get sorted index.
187 let state = new_states
188 .get_index_of(&states[i])
189 .expect("Failed to get index of state.");
190 // Assign probability to sorted index.
191 new_p_states[state] = p;
192 });
193 // Substitute the sorted states.
194 let p_states = new_p_states;
195 // Construct the sorted evidence.
196 E::UncertainPositive { event, p_states }
197 }
198 E::UncertainNegative { p_not_states, .. } => {
199 // Allocate new variable states.
200 let mut new_p_not_states = Array::zeros(p_not_states.len());
201 // Sort the variable states.
202 p_not_states.indexed_iter().for_each(|(i, &p)| {
203 // Get sorted index.
204 let state = new_states
205 .get_index_of(&states[i])
206 .expect("Failed to get index of state.");
207 // Assign probability to sorted index.
208 new_p_not_states[state] = p;
209 });
210 // Substitute the sorted states.
211 let p_not_states = new_p_not_states;
212 // Construct the sorted evidence.
213 E::UncertainNegative {
214 event,
215 p_not_states,
216 }
217 }
218 };
219
220 // Push the value into the variable events.
221 new_evidences[event] = Some(e);
222 });
223
224 // Update the states.
225 states = new_states;
226 // Update the evidences.
227 evidences = new_evidences;
228 // Update the labels.
229 labels = states.keys().cloned().collect();
230 // Update the shape.
231 shape = states.values().map(Set::len).collect();
232 }
233
234 // For each variable ...
235 for (i, e) in evidences.iter_mut().enumerate() {
236 // Assert states distributions have the correct size.
237 assert!(
238 e.as_ref().is_none_or(|e| match e {
239 E::CertainPositive { .. } => true,
240 E::CertainNegative { .. } => true,
241 E::UncertainPositive { p_states, .. } => {
242 p_states.len() == shape[i]
243 }
244 E::UncertainNegative { p_not_states, .. } => {
245 p_not_states.len() == shape[i]
246 }
247 }),
248 "Evidence states distributions must have the correct size."
249 );
250 // Assert states distributions are not negative.
251 assert!(
252 e.as_ref().is_none_or(|e| match e {
253 E::CertainPositive { .. } => true,
254 E::CertainNegative { .. } => true,
255 E::UncertainPositive { p_states, .. } => {
256 p_states.iter().all(|&x| x >= 0.)
257 }
258 E::UncertainNegative { p_not_states, .. } => {
259 p_not_states.iter().all(|&x| x >= 0.)
260 }
261 }),
262 "Evidence states distributions must be non-negative."
263 );
264 // Assert states distributions sum to 1.
265 assert!(
266 e.as_ref().is_none_or(|e| match e {
267 E::CertainPositive { .. } => true,
268 E::CertainNegative { .. } => true,
269 E::UncertainPositive { p_states, .. } => {
270 relative_eq!(p_states.sum(), 1.)
271 }
272 E::UncertainNegative { p_not_states, .. } => {
273 relative_eq!(p_not_states.sum(), 1.)
274 }
275 }),
276 "Evidence states distributions must sum to 1."
277 );
278 }
279
280 Self {
281 labels,
282 states,
283 shape,
284 evidences,
285 }
286 }
287
288 /// The states of the evidence.
289 ///
290 /// # Returns
291 ///
292 /// A reference to the states of the evidence.
293 ///
294 #[inline]
295 pub const fn states(&self) -> &States {
296 &self.states
297 }
298
299 /// The shape of the evidence.
300 ///
301 /// # Returns
302 ///
303 /// A reference to the shape of the evidence.
304 ///
305 #[inline]
306 pub const fn shape(&self) -> &Array1<usize> {
307 &self.shape
308 }
309
310 /// The evidences of the evidence.
311 ///
312 /// # Returns
313 ///
314 /// A reference to the evidences of the evidence.
315 ///
316 #[inline]
317 pub const fn evidences(&self) -> &Vec<Option<CatEvT>> {
318 &self.evidences
319 }
320}