causal_hub/datasets/trajectory/categorical/
evidence.rs

1use approx::relative_eq;
2use ndarray::prelude::*;
3use rayon::prelude::*;
4
5use crate::{
6    datasets::CatEv,
7    models::Labelled,
8    types::{EPSILON, Labels, Set, States},
9};
10
11/// A type representing the evidence type for categorical trajectories.
12#[non_exhaustive]
13#[derive(Clone, Debug)]
14pub enum CatTrjEvT {
15    /// Certain positive interval evidence.
16    CertainPositiveInterval {
17        /// The observed event.
18        event: usize,
19        /// The observed state.
20        state: usize,
21        /// The start time of the observed interval.
22        start_time: f64,
23        /// The end time of the observed interval.
24        end_time: f64,
25    },
26    /// Certain negative interval evidence.
27    CertainNegativeInterval {
28        /// The observed event.
29        event: usize,
30        /// The non-observed states.
31        not_states: Set<usize>,
32        /// The start time of the non-observed interval.
33        start_time: f64,
34        /// The end time of the non-observed interval.
35        end_time: f64,
36    },
37    /// Uncertain positive interval evidence.
38    UncertainPositiveInterval {
39        /// The observed event.
40        event: usize,
41        /// The distribution of the observed states.
42        p_states: Array1<f64>,
43        /// The start time of the observed interval.
44        start_time: f64,
45        /// The end time of the observed interval.
46        end_time: f64,
47    },
48    /// Uncertain negative interval evidence.
49    UncertainNegativeInterval {
50        /// The observed event.
51        event: usize,
52        /// The distribution of the non-observed states.
53        p_not_states: Array1<f64>,
54        /// The start time of the non-observed interval.
55        start_time: f64,
56        /// The end time of the non-observed interval.
57        end_time: f64,
58    },
59}
60
61impl CatTrjEvT {
62    /// Return the observed event of the evidence.
63    ///
64    /// # Returns
65    ///
66    /// The observed event of the evidence.
67    ///
68    pub const fn event(&self) -> usize {
69        match self {
70            Self::CertainPositiveInterval { event, .. }
71            | Self::CertainNegativeInterval { event, .. }
72            | Self::UncertainPositiveInterval { event, .. }
73            | Self::UncertainNegativeInterval { event, .. } => *event,
74        }
75    }
76
77    /// Returns the start time of the evidence.
78    ///
79    /// # Returns
80    ///
81    /// The start time of the evidence.
82    ///
83    pub const fn start_time(&self) -> f64 {
84        match self {
85            Self::CertainPositiveInterval { start_time, .. }
86            | Self::CertainNegativeInterval { start_time, .. }
87            | Self::UncertainPositiveInterval { start_time, .. }
88            | Self::UncertainNegativeInterval { start_time, .. } => *start_time,
89        }
90    }
91
92    /// Returns the end time of the evidence.
93    ///
94    /// # Returns
95    ///
96    /// The end time of the evidence.
97    ///
98    pub const fn end_time(&self) -> f64 {
99        match self {
100            Self::CertainPositiveInterval { end_time, .. }
101            | Self::CertainNegativeInterval { end_time, .. }
102            | Self::UncertainPositiveInterval { end_time, .. }
103            | Self::UncertainNegativeInterval { end_time, .. } => *end_time,
104        }
105    }
106
107    /// Checks if the evidence contains a given time.
108    ///
109    /// # Arguments
110    ///
111    /// * `time` - The time to check.
112    ///
113    /// # Returns
114    ///
115    /// `true` if the time is in [start_time, end_time), `false` otherwise.
116    ///
117    pub fn contains(&self, time: &f64) -> bool {
118        (self.start_time()..self.end_time()).contains(time)
119    }
120}
121
122/// A type representing a collection of evidences for a categorical trajectory.
123#[derive(Clone, Debug)]
124pub struct CatTrjEv {
125    labels: Labels,
126    states: States,
127    shape: Array1<usize>,
128    evidences: Vec<Vec<CatTrjEvT>>,
129}
130
131impl Labelled for CatTrjEv {
132    #[inline]
133    fn labels(&self) -> &Labels {
134        &self.labels
135    }
136}
137
138impl CatTrjEv {
139    /// Constructs a new `CatTrjEv` instance.
140    ///
141    /// # Arguments
142    ///
143    /// * `labels` - A set of labels for the variables.
144    /// * `states` - A map of states for each variable.
145    /// * `events` - A map of events for each variable.
146    ///
147    /// # Returns
148    ///
149    /// A new `CategoricalTrajectoryEvidence` instance.
150    ///
151    pub fn new<I>(mut states: States, values: I) -> Self
152    where
153        I: IntoIterator<Item = CatTrjEvT>,
154    {
155        // Get shortened variable type.
156        use CatTrjEvT as E;
157
158        // Get the sorted labels.
159        let mut labels = states.keys().cloned().collect();
160        // Get the shape of the states.
161        let mut shape = Array::from_iter(states.values().map(Set::len));
162        // Allocate evidences.
163        let mut evidences = vec![vec![]; states.len()];
164
165        // Fill the evidences.
166        values.into_iter().for_each(|e| {
167            // Get the event index.
168            let event = e.event();
169            // Push the value into the events.
170            evidences[event].push(e);
171        });
172
173        // Sort states, if necessary.
174        if !states.keys().is_sorted() || !states.values().all(|x| x.iter().is_sorted()) {
175            // Clone the states.
176            let mut new_states = states.clone();
177            // Sort the states.
178            new_states.sort_keys();
179            new_states.values_mut().for_each(Set::sort);
180
181            // Allocate new evidences.
182            let mut new_evidences = vec![vec![]; states.len()];
183
184            // Iterate over the values and insert them into the events map using sorted indices.
185            evidences.into_iter().flatten().for_each(|e| {
186                // Get the event index, starting time, and ending time.
187                let (start_time, end_time) = (e.start_time(), e.end_time());
188                // Get the event and states of the evidence.
189                let (event, states) = states
190                    .get_index(e.event())
191                    .expect("Failed to get label of evidence.");
192                // Sort the event index.
193                let (event, _, new_states) = new_states
194                    .get_full(event)
195                    .expect("Failed to get full state.");
196
197                // Sort the event states.
198                let e = match e {
199                    E::CertainPositiveInterval { state, .. } => {
200                        // Sort the variable states.
201                        let state = new_states
202                            .get_index_of(&states[state])
203                            .expect("Failed to get index of state.");
204                        // Construct the sorted evidence.
205                        E::CertainPositiveInterval {
206                            event,
207                            state,
208                            start_time,
209                            end_time,
210                        }
211                    }
212                    E::CertainNegativeInterval { not_states, .. } => {
213                        // Sort the event states.
214                        let not_states = not_states
215                            .iter()
216                            .map(|&state| {
217                                new_states
218                                    .get_index_of(&states[state])
219                                    .expect("Failed to get index of state.")
220                            })
221                            .collect();
222                        // Construct the sorted evidence.
223                        E::CertainNegativeInterval {
224                            event,
225                            not_states,
226                            start_time,
227                            end_time,
228                        }
229                    }
230                    E::UncertainPositiveInterval { p_states, .. } => {
231                        // Allocate new event states.
232                        let mut new_p_states = Array::zeros(p_states.len());
233                        // Sort the event states.
234                        p_states.indexed_iter().for_each(|(i, &p)| {
235                            // Get sorted index.
236                            let state = new_states
237                                .get_index_of(&states[i])
238                                .expect("Failed to get index of state.");
239                            // Assign probability to sorted index.
240                            new_p_states[state] = p;
241                        });
242                        // Substitute the sorted states.
243                        let p_states = new_p_states;
244                        // Construct the sorted evidence.
245                        E::UncertainPositiveInterval {
246                            event,
247                            p_states,
248                            start_time,
249                            end_time,
250                        }
251                    }
252                    E::UncertainNegativeInterval { p_not_states, .. } => {
253                        // Allocate new event states.
254                        let mut new_p_not_states = Array::zeros(p_not_states.len());
255                        // Sort the event states.
256                        p_not_states.indexed_iter().for_each(|(i, &p)| {
257                            // Get sorted index.
258                            let state = new_states
259                                .get_index_of(&states[i])
260                                .expect("Failed to get index of state.");
261                            // Assign probability to sorted index.
262                            new_p_not_states[state] = p;
263                        });
264                        // Substitute the sorted states.
265                        let p_not_states = new_p_not_states;
266                        // Construct the sorted evidence.
267                        E::UncertainNegativeInterval {
268                            event,
269                            p_not_states,
270                            start_time,
271                            end_time,
272                        }
273                    }
274                };
275
276                // Push the value into the events.
277                new_evidences[event].push(e);
278            });
279
280            // Update the states.
281            states = new_states;
282            // Update the evidences.
283            evidences = new_evidences;
284            // Update the labels.
285            labels = states.keys().cloned().collect();
286            // Update the shape.
287            shape = states.values().map(Set::len).collect();
288        }
289
290        // Check and fix incoherent evidences.
291        evidences.iter_mut().zip(&shape).for_each(
292            |(e, shape): (&mut Vec<E>, &usize)| {
293                // Assert state, starting and ending times are coherent.
294                e.iter().for_each(|e| {
295                    // Assert starting time must be positive and finite.
296                    assert!(
297                        e.start_time().is_finite() && e.start_time() >= 0.0,
298                        "Starting time must be positive and finite."
299                    );
300                    // Assert ending time must be positive and finite.
301                    assert!(
302                        e.end_time().is_finite() && e.end_time() >= 0.0,
303                        "Ending time must be positive and finite."
304                    );
305                    // Assert starting time is less or equal than ending time.
306                    assert!(
307                        e.start_time() <= e.end_time(),
308                        "Starting time must be less or equal than ending time."
309                    );
310                    // Assert states distributions have the correct size.
311                    assert!(
312                        match e {
313                            E::CertainPositiveInterval { .. } => true,
314                            E::CertainNegativeInterval { .. } => true,
315                            E::UncertainPositiveInterval { p_states, .. } => {
316                                p_states.len() == *shape
317                            }
318                            E::UncertainNegativeInterval { p_not_states, .. } => {
319                                p_not_states.len() == *shape
320                            }
321                        },
322                        "States distributions must have the correct size."
323                    );
324                    // Assert states distributions are not negative.
325                    assert!(
326                        match e {
327                            E::CertainPositiveInterval { .. } => true,
328                            E::CertainNegativeInterval { .. } => true,
329                            E::UncertainPositiveInterval { p_states, .. } => {
330                                p_states.iter().all(|&x| x >= 0.)
331                            }
332                            E::UncertainNegativeInterval { p_not_states, .. } => {
333                                p_not_states.iter().all(|&x| x >= 0.)
334                            }
335                        },
336                        "States distributions must be non-negative."
337                    );
338                    // Assert states distributions sum to 1.
339                    assert!(
340                        match e {
341                            E::CertainPositiveInterval { .. } => true,
342                            E::CertainNegativeInterval { .. } => true,
343                            E::UncertainPositiveInterval { p_states, .. } => {
344                                relative_eq!(p_states.sum(), 1., epsilon = EPSILON)
345                            }
346                            E::UncertainNegativeInterval { p_not_states, .. } => {
347                                relative_eq!(p_not_states.sum(), 1., epsilon = EPSILON)
348                            }
349                        },
350                        "States distributions must sum to one."
351                    );
352                });
353
354                // Sort the events by starting time.
355                e.sort_by(|a, b| {
356                    a.start_time()
357                        .partial_cmp(&b.start_time())
358                        // Due to previous assertions, this should never fail.
359                        .unwrap_or_else(|| unreachable!())
360                });
361
362                // Handle overlapping intervals.
363                *e = e.iter().fold(Vec::new(), |mut e: Vec<E>, e_j: &E| {
364                    // Ii evence is empty ...
365                    if e.is_empty() {
366                        // ... push current evidence and exit.
367                        e.push(e_j.clone());
368                        return e;
369                    }
370
371                    // Get the last evidence.
372                    let e_i: &E = e.last().unwrap();
373                    // Assert intervals times are coherent.
374                    assert!(
375                        e_i.start_time() <= e_j.start_time(),
376                        "Two evidences for the same variable must have non-increasing starting time: \n\
377                        \t expected: e(i).start_time <= e(i+1).start_time, \n\
378                        \t found:    e(i).start_time >  e(i+1).start_time, \n\
379                        \t for:      e(i).start_time == {} , \n\
380                        \t and:    e(i+1).start_time == {} .",
381                        e_i.start_time(),
382                        e_j.start_time()
383                    );
384                    // If the current evidence ends before the next one starts ...
385                    if e_i.end_time() <= e_j.start_time() {
386                        // ... push current evidence and exit.
387                        e.push(e_j.clone());
388                        return e;
389                    }
390                    // Otherwise, we have overlapping intervals,
391                    // check if they are the same type of evidence.
392                    match (e_i, e_j) {
393                        // If they are the same type of evidence ...
394                        (
395                            E::CertainPositiveInterval { state: s_i, .. },
396                            E::CertainPositiveInterval { state: s_j, .. },
397                        ) => {
398                            // Check if they are the same state.
399                            if s_i == s_j {
400                                // Get evidence event, state, start time and end time.
401                                let (event, state, start_time) = (e_i.event(), *s_i, e_i.start_time());
402                                // Set end time to the maximum of both.
403                                let end_time = e_i.end_time().max(e_j.end_time());
404                                // Set the last evidence end time to the maximum of both.
405                                *e.last_mut().unwrap() = E::CertainPositiveInterval {
406                                    event,
407                                    state,
408                                    start_time,
409                                    end_time,
410                                };
411                            // Otherwise, merge the two certain evidences into an uncertain one.
412                            } else {
413                                // Construct uncertain positive evidence.
414                                let mut p_states = Array::zeros(*shape);
415                                // Set the state of the evidence with a weight proportion to the time.
416                                p_states[*s_i] = e_i.end_time() - e_i.start_time();
417                                p_states[*s_j] = e_j.end_time() - e_j.start_time();
418                                // Normalize the states.
419                                p_states /= p_states.sum();
420                                // Get evidence event, states, start time and end time.
421                                let event = e_i.event();
422                                let start_time = e_i.start_time().min(e_j.start_time());
423                                let end_time = e_i.end_time().max(e_j.end_time());
424                                // Set the last evidence end time to the maximum of both.
425                                *e.last_mut().unwrap() = E::UncertainPositiveInterval {
426                                    event,
427                                    p_states,
428                                    start_time,
429                                    end_time,
430                                };
431                            }
432                        }
433                        (
434                            E::CertainNegativeInterval {
435                                not_states: s_i, ..
436                            },
437                            E::CertainNegativeInterval {
438                                not_states: s_j, ..
439                            },
440                        ) => {
441                            // Check if they are the same states.
442                            assert_eq!(
443                                s_i, s_j,
444                                "Overlapping negative evidence must have the same states."
445                            );
446                            // Get evidence event, not states, start time and end time.
447                            let (event, not_states, start_time) =
448                                (e_i.event(), s_i.clone(), e_i.start_time());
449                            // Set end time to the maximum of both.
450                            let end_time = e_i.end_time().max(e_j.end_time());
451                            // Set the last evidence end time to the maximum of both.
452                            *e.last_mut().unwrap() = E::CertainNegativeInterval {
453                                event,
454                                not_states,
455                                start_time,
456                                end_time,
457                            };
458                        }
459                        (
460                            E::UncertainPositiveInterval { p_states: s_i, .. },
461                            E::UncertainPositiveInterval { p_states: s_j, .. },
462                        ) => {
463                            // Check if they are the same states.
464                            assert!(
465                                relative_eq!(s_i, s_j),
466                                "Overlapping uncertain evidence must have the same states."
467                            );
468                            // Get evidence event, states, start time and end time.
469                            let (event, p_states, start_time) =
470                                (e_i.event(), s_i.clone(), e_i.start_time());
471                            // Set end time to the maximum of both.
472                            let end_time = e_i.end_time().max(e_j.end_time());
473                            // Set the last evidence end time to the maximum of both.
474                            *e.last_mut().unwrap() = E::UncertainPositiveInterval {
475                                event,
476                                p_states,
477                                start_time,
478                                end_time,
479                            };
480                        }
481                        (
482                            E::UncertainNegativeInterval {
483                                p_not_states: s_i, ..
484                            },
485                            E::UncertainNegativeInterval {
486                                p_not_states: s_j, ..
487                            },
488                        ) => {
489                            // Check if they are the same states.
490                            assert!(
491                                relative_eq!(s_i, s_j),
492                                "Overlapping uncertain evidence must have the same states."
493                            );
494                            // Get evidence event, not states, start time and end time.
495                            let (event, p_not_states, start_time) =
496                                (e_i.event(), s_i.clone(), e_i.start_time());
497                            // Set end time to the maximum of both.
498                            let end_time = e_i.end_time().max(e_j.end_time());
499                            // Set the last evidence end time to the maximum of both.
500                            *e.last_mut().unwrap() = E::UncertainNegativeInterval {
501                                event,
502                                p_not_states,
503                                start_time,
504                                end_time,
505                            };
506                        }
507                        // If they are not the same type of evidence ...
508                        _ => panic!("Overlapping evidence must have the same type"),
509                    }
510
511                    e
512                });
513
514                // Assert current ending time is less or equal than next starting time.
515                assert!(
516                    e
517                        .windows(2)
518                        .all(|e| e[0].end_time() <= e[1].start_time()),
519                    "Ending time must be less or equal than next starting time."
520                );
521            },
522        );
523
524        // Create a new categorical trajectory evidence instance.
525        Self {
526            labels,
527            states,
528            shape,
529            evidences,
530        }
531    }
532
533    /// Returns the states of the trajectory evidence.
534    ///
535    /// # Returns
536    ///
537    /// A reference to the states of the trajectory evidence.
538    ///
539    #[inline]
540    pub const fn states(&self) -> &States {
541        &self.states
542    }
543
544    /// Returns the shape of the trajectory evidence.
545    ///
546    /// # Returns
547    ///
548    /// A reference to the shape of the trajectory evidence.
549    ///
550    #[inline]
551    pub const fn shape(&self) -> &Array1<usize> {
552        &self.shape
553    }
554
555    /// Returns the evidences of the trajectory.
556    ///
557    /// # Returns
558    ///
559    /// A reference to the evidences of the trajectory.
560    ///
561    #[inline]
562    pub fn evidences(&self) -> &Vec<Vec<CatTrjEvT>> {
563        &self.evidences
564    }
565
566    /// Returns the evidences at time zero.
567    ///
568    /// # Returns
569    ///
570    /// The evidences at time zero.
571    ///
572    pub fn initial_evidence(&self) -> CatEv {
573        // Get the evidences at time zero.
574        let evidences = self.evidences.iter().filter_map(|e| {
575            // Get the first evidence, if any.
576            let e = e.iter().next().cloned();
577            // Check if the evidence is at time zero.
578            let e = e.filter(|e| relative_eq!(e.start_time(), 0.));
579            // Map the evidence to its variable.
580            e.map(|e| e.into())
581        });
582
583        // Clone the states.
584        let states = self.states.clone();
585
586        // Create a new categorical evidence instance.
587        CatEv::new(states, evidences)
588    }
589}
590
591/// A collection of multivariate trajectories evidence.
592#[derive(Clone, Debug)]
593pub struct CatTrjsEv {
594    labels: Labels,
595    states: States,
596    shape: Array1<usize>,
597    evidences: Vec<CatTrjEv>,
598}
599
600impl Labelled for CatTrjsEv {
601    #[inline]
602    fn labels(&self) -> &Labels {
603        &self.labels
604    }
605}
606
607impl CatTrjsEv {
608    /// Constructs a new collection of trajectories evidence.
609    ///
610    /// # Arguments
611    ///
612    /// * `trajectories` - An iterator of `CatTrjEv` instances.
613    ///
614    /// # Panics
615    ///
616    /// Panics if:
617    ///
618    /// * The trajectories have different labels.
619    /// * The trajectories have different states.
620    /// * The trajectories have different shape.
621    ///
622    /// # Returns
623    ///
624    /// A new instance of `CategoricalTrajectoriesEvidence`.
625    ///
626    pub fn new<I>(evidences: I) -> Self
627    where
628        I: IntoIterator<Item = CatTrjEv>,
629    {
630        // Collect the trajectories into a vector.
631        let evidences: Vec<_> = evidences.into_iter().collect();
632
633        // Assert every trajectory has the same labels.
634        assert!(
635            evidences
636                .windows(2)
637                .all(|trjs| trjs[0].labels().eq(trjs[1].labels())),
638            "All trajectories must have the same labels."
639        );
640        // Assert every trajectory has the same states.
641        assert!(
642            evidences
643                .windows(2)
644                .all(|trjs| trjs[0].states().eq(trjs[1].states())),
645            "All trajectories must have the same states."
646        );
647        // Assert every trajectory has the same shape.
648        assert!(
649            evidences
650                .windows(2)
651                .all(|trjs| trjs[0].shape().eq(trjs[1].shape())),
652            "All trajectories must have the same shape."
653        );
654
655        // Get the labels, states and shape from the first trajectory.
656        let (labels, states, shape) = match evidences.first() {
657            None => (Labels::default(), States::default(), Array1::default((0,))),
658            Some(x) => (x.labels().clone(), x.states().clone(), x.shape().clone()),
659        };
660
661        Self {
662            labels,
663            states,
664            shape,
665            evidences,
666        }
667    }
668
669    /// Returns the states of the trajectories evidence.
670    ///
671    /// # Returns
672    ///
673    /// A reference to the states of the trajectories evidence.
674    ///
675    #[inline]
676    pub fn states(&self) -> &States {
677        &self.states
678    }
679
680    /// Returns the shape of the trajectories evidence.
681    ///
682    /// # Returns
683    ///
684    /// A reference to the shape of the trajectories evidence.
685    ///
686    #[inline]
687    pub fn shape(&self) -> &Array1<usize> {
688        &self.shape
689    }
690
691    /// Returns the evidences of the trajectories.
692    ///
693    /// # Returns
694    ///
695    /// A reference to the evidences of the trajectories.
696    ///
697    #[inline]
698    pub fn evidences(&self) -> &Vec<CatTrjEv> {
699        &self.evidences
700    }
701}
702
703impl FromIterator<CatTrjEv> for CatTrjsEv {
704    #[inline]
705    fn from_iter<I: IntoIterator<Item = CatTrjEv>>(iter: I) -> Self {
706        Self::new(iter)
707    }
708}
709
710impl FromParallelIterator<CatTrjEv> for CatTrjsEv {
711    #[inline]
712    fn from_par_iter<I: IntoParallelIterator<Item = CatTrjEv>>(iter: I) -> Self {
713        Self::new(iter.into_par_iter().collect::<Vec<_>>())
714    }
715}
716
717impl<'a> IntoIterator for &'a CatTrjsEv {
718    type IntoIter = std::slice::Iter<'a, CatTrjEv>;
719    type Item = &'a CatTrjEv;
720
721    #[inline]
722    fn into_iter(self) -> Self::IntoIter {
723        self.evidences.iter()
724    }
725}
726
727impl<'a> IntoParallelRefIterator<'a> for CatTrjsEv {
728    type Item = &'a CatTrjEv;
729    type Iter = rayon::slice::Iter<'a, CatTrjEv>;
730
731    #[inline]
732    fn par_iter(&'a self) -> Self::Iter {
733        self.evidences.par_iter()
734    }
735}