causal_hub/datasets/trajectory/categorical/evidence.rs
1use approx::relative_eq;
2use ndarray::prelude::*;
3use rayon::prelude::*;
4
5use crate::{
6 datasets::CatEv,
7 models::Labelled,
8 types::{EPSILON, Labels, Set, States},
9};
10
/// A type representing the evidence type for categorical trajectories.
///
/// Every variant describes what is known about a single variable (`event`)
/// over the time interval from `start_time` to `end_time`.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub enum CatTrjEvT {
    /// Certain positive interval evidence: the variable is observed in exactly one state.
    CertainPositiveInterval {
        /// The observed event, i.e. the index of the variable the evidence refers to.
        event: usize,
        /// The observed state, as an index into the states of the variable.
        state: usize,
        /// The start time of the observed interval.
        start_time: f64,
        /// The end time of the observed interval.
        end_time: f64,
    },
    /// Certain negative interval evidence: a set of states is ruled out.
    CertainNegativeInterval {
        /// The observed event, i.e. the index of the variable the evidence refers to.
        event: usize,
        /// The non-observed states, as indices into the states of the variable.
        not_states: Set<usize>,
        /// The start time of the non-observed interval.
        start_time: f64,
        /// The end time of the non-observed interval.
        end_time: f64,
    },
    /// Uncertain positive interval evidence: a distribution over the observed states.
    UncertainPositiveInterval {
        /// The observed event, i.e. the index of the variable the evidence refers to.
        event: usize,
        /// The distribution of the observed states, with one entry per state of the variable.
        p_states: Array1<f64>,
        /// The start time of the observed interval.
        start_time: f64,
        /// The end time of the observed interval.
        end_time: f64,
    },
    /// Uncertain negative interval evidence: a distribution over the non-observed states.
    UncertainNegativeInterval {
        /// The observed event, i.e. the index of the variable the evidence refers to.
        event: usize,
        /// The distribution of the non-observed states, with one entry per state of the variable.
        p_not_states: Array1<f64>,
        /// The start time of the non-observed interval.
        start_time: f64,
        /// The end time of the non-observed interval.
        end_time: f64,
    },
}
60
61impl CatTrjEvT {
62 /// Return the observed event of the evidence.
63 ///
64 /// # Returns
65 ///
66 /// The observed event of the evidence.
67 ///
68 pub const fn event(&self) -> usize {
69 match self {
70 Self::CertainPositiveInterval { event, .. }
71 | Self::CertainNegativeInterval { event, .. }
72 | Self::UncertainPositiveInterval { event, .. }
73 | Self::UncertainNegativeInterval { event, .. } => *event,
74 }
75 }
76
77 /// Returns the start time of the evidence.
78 ///
79 /// # Returns
80 ///
81 /// The start time of the evidence.
82 ///
83 pub const fn start_time(&self) -> f64 {
84 match self {
85 Self::CertainPositiveInterval { start_time, .. }
86 | Self::CertainNegativeInterval { start_time, .. }
87 | Self::UncertainPositiveInterval { start_time, .. }
88 | Self::UncertainNegativeInterval { start_time, .. } => *start_time,
89 }
90 }
91
92 /// Returns the end time of the evidence.
93 ///
94 /// # Returns
95 ///
96 /// The end time of the evidence.
97 ///
98 pub const fn end_time(&self) -> f64 {
99 match self {
100 Self::CertainPositiveInterval { end_time, .. }
101 | Self::CertainNegativeInterval { end_time, .. }
102 | Self::UncertainPositiveInterval { end_time, .. }
103 | Self::UncertainNegativeInterval { end_time, .. } => *end_time,
104 }
105 }
106
107 /// Checks if the evidence contains a given time.
108 ///
109 /// # Arguments
110 ///
111 /// * `time` - The time to check.
112 ///
113 /// # Returns
114 ///
115 /// `true` if the time is in [start_time, end_time), `false` otherwise.
116 ///
117 pub fn contains(&self, time: &f64) -> bool {
118 (self.start_time()..self.end_time()).contains(time)
119 }
120}
121
/// A type representing a collection of evidences for a categorical trajectory.
#[derive(Clone, Debug)]
pub struct CatTrjEv {
    // Labels of the variables, i.e. the keys of `states`.
    labels: Labels,
    // States of each variable, keyed by the variable label.
    states: States,
    // Number of states of each variable, in the same order as `states`.
    shape: Array1<usize>,
    // Evidences grouped by variable: `evidences[i]` holds the evidences of the
    // i-th variable, which the constructor sorts by starting time and merges
    // so that they do not overlap.
    evidences: Vec<Vec<CatTrjEvT>>,
}
130
// Exposes the variables labels of the trajectory evidence through `Labelled`.
impl Labelled for CatTrjEv {
    // Returns a reference to the stored labels, without cloning them.
    #[inline]
    fn labels(&self) -> &Labels {
        &self.labels
    }
}
137
138impl CatTrjEv {
139 /// Constructs a new `CatTrjEv` instance.
140 ///
141 /// # Arguments
142 ///
143 /// * `labels` - A set of labels for the variables.
144 /// * `states` - A map of states for each variable.
145 /// * `events` - A map of events for each variable.
146 ///
147 /// # Returns
148 ///
149 /// A new `CategoricalTrajectoryEvidence` instance.
150 ///
151 pub fn new<I>(mut states: States, values: I) -> Self
152 where
153 I: IntoIterator<Item = CatTrjEvT>,
154 {
155 // Get shortened variable type.
156 use CatTrjEvT as E;
157
158 // Get the sorted labels.
159 let mut labels = states.keys().cloned().collect();
160 // Get the shape of the states.
161 let mut shape = Array::from_iter(states.values().map(Set::len));
162 // Allocate evidences.
163 let mut evidences = vec![vec![]; states.len()];
164
165 // Fill the evidences.
166 values.into_iter().for_each(|e| {
167 // Get the event index.
168 let event = e.event();
169 // Push the value into the events.
170 evidences[event].push(e);
171 });
172
173 // Sort states, if necessary.
174 if !states.keys().is_sorted() || !states.values().all(|x| x.iter().is_sorted()) {
175 // Clone the states.
176 let mut new_states = states.clone();
177 // Sort the states.
178 new_states.sort_keys();
179 new_states.values_mut().for_each(Set::sort);
180
181 // Allocate new evidences.
182 let mut new_evidences = vec![vec![]; states.len()];
183
184 // Iterate over the values and insert them into the events map using sorted indices.
185 evidences.into_iter().flatten().for_each(|e| {
186 // Get the event index, starting time, and ending time.
187 let (start_time, end_time) = (e.start_time(), e.end_time());
188 // Get the event and states of the evidence.
189 let (event, states) = states
190 .get_index(e.event())
191 .expect("Failed to get label of evidence.");
192 // Sort the event index.
193 let (event, _, new_states) = new_states
194 .get_full(event)
195 .expect("Failed to get full state.");
196
197 // Sort the event states.
198 let e = match e {
199 E::CertainPositiveInterval { state, .. } => {
200 // Sort the variable states.
201 let state = new_states
202 .get_index_of(&states[state])
203 .expect("Failed to get index of state.");
204 // Construct the sorted evidence.
205 E::CertainPositiveInterval {
206 event,
207 state,
208 start_time,
209 end_time,
210 }
211 }
212 E::CertainNegativeInterval { not_states, .. } => {
213 // Sort the event states.
214 let not_states = not_states
215 .iter()
216 .map(|&state| {
217 new_states
218 .get_index_of(&states[state])
219 .expect("Failed to get index of state.")
220 })
221 .collect();
222 // Construct the sorted evidence.
223 E::CertainNegativeInterval {
224 event,
225 not_states,
226 start_time,
227 end_time,
228 }
229 }
230 E::UncertainPositiveInterval { p_states, .. } => {
231 // Allocate new event states.
232 let mut new_p_states = Array::zeros(p_states.len());
233 // Sort the event states.
234 p_states.indexed_iter().for_each(|(i, &p)| {
235 // Get sorted index.
236 let state = new_states
237 .get_index_of(&states[i])
238 .expect("Failed to get index of state.");
239 // Assign probability to sorted index.
240 new_p_states[state] = p;
241 });
242 // Substitute the sorted states.
243 let p_states = new_p_states;
244 // Construct the sorted evidence.
245 E::UncertainPositiveInterval {
246 event,
247 p_states,
248 start_time,
249 end_time,
250 }
251 }
252 E::UncertainNegativeInterval { p_not_states, .. } => {
253 // Allocate new event states.
254 let mut new_p_not_states = Array::zeros(p_not_states.len());
255 // Sort the event states.
256 p_not_states.indexed_iter().for_each(|(i, &p)| {
257 // Get sorted index.
258 let state = new_states
259 .get_index_of(&states[i])
260 .expect("Failed to get index of state.");
261 // Assign probability to sorted index.
262 new_p_not_states[state] = p;
263 });
264 // Substitute the sorted states.
265 let p_not_states = new_p_not_states;
266 // Construct the sorted evidence.
267 E::UncertainNegativeInterval {
268 event,
269 p_not_states,
270 start_time,
271 end_time,
272 }
273 }
274 };
275
276 // Push the value into the events.
277 new_evidences[event].push(e);
278 });
279
280 // Update the states.
281 states = new_states;
282 // Update the evidences.
283 evidences = new_evidences;
284 // Update the labels.
285 labels = states.keys().cloned().collect();
286 // Update the shape.
287 shape = states.values().map(Set::len).collect();
288 }
289
290 // Check and fix incoherent evidences.
291 evidences.iter_mut().zip(&shape).for_each(
292 |(e, shape): (&mut Vec<E>, &usize)| {
293 // Assert state, starting and ending times are coherent.
294 e.iter().for_each(|e| {
295 // Assert starting time must be positive and finite.
296 assert!(
297 e.start_time().is_finite() && e.start_time() >= 0.0,
298 "Starting time must be positive and finite."
299 );
300 // Assert ending time must be positive and finite.
301 assert!(
302 e.end_time().is_finite() && e.end_time() >= 0.0,
303 "Ending time must be positive and finite."
304 );
305 // Assert starting time is less or equal than ending time.
306 assert!(
307 e.start_time() <= e.end_time(),
308 "Starting time must be less or equal than ending time."
309 );
310 // Assert states distributions have the correct size.
311 assert!(
312 match e {
313 E::CertainPositiveInterval { .. } => true,
314 E::CertainNegativeInterval { .. } => true,
315 E::UncertainPositiveInterval { p_states, .. } => {
316 p_states.len() == *shape
317 }
318 E::UncertainNegativeInterval { p_not_states, .. } => {
319 p_not_states.len() == *shape
320 }
321 },
322 "States distributions must have the correct size."
323 );
324 // Assert states distributions are not negative.
325 assert!(
326 match e {
327 E::CertainPositiveInterval { .. } => true,
328 E::CertainNegativeInterval { .. } => true,
329 E::UncertainPositiveInterval { p_states, .. } => {
330 p_states.iter().all(|&x| x >= 0.)
331 }
332 E::UncertainNegativeInterval { p_not_states, .. } => {
333 p_not_states.iter().all(|&x| x >= 0.)
334 }
335 },
336 "States distributions must be non-negative."
337 );
338 // Assert states distributions sum to 1.
339 assert!(
340 match e {
341 E::CertainPositiveInterval { .. } => true,
342 E::CertainNegativeInterval { .. } => true,
343 E::UncertainPositiveInterval { p_states, .. } => {
344 relative_eq!(p_states.sum(), 1., epsilon = EPSILON)
345 }
346 E::UncertainNegativeInterval { p_not_states, .. } => {
347 relative_eq!(p_not_states.sum(), 1., epsilon = EPSILON)
348 }
349 },
350 "States distributions must sum to one."
351 );
352 });
353
354 // Sort the events by starting time.
355 e.sort_by(|a, b| {
356 a.start_time()
357 .partial_cmp(&b.start_time())
358 // Due to previous assertions, this should never fail.
359 .unwrap_or_else(|| unreachable!())
360 });
361
362 // Handle overlapping intervals.
363 *e = e.iter().fold(Vec::new(), |mut e: Vec<E>, e_j: &E| {
364 // Ii evence is empty ...
365 if e.is_empty() {
366 // ... push current evidence and exit.
367 e.push(e_j.clone());
368 return e;
369 }
370
371 // Get the last evidence.
372 let e_i: &E = e.last().unwrap();
373 // Assert intervals times are coherent.
374 assert!(
375 e_i.start_time() <= e_j.start_time(),
376 "Two evidences for the same variable must have non-increasing starting time: \n\
377 \t expected: e(i).start_time <= e(i+1).start_time, \n\
378 \t found: e(i).start_time > e(i+1).start_time, \n\
379 \t for: e(i).start_time == {} , \n\
380 \t and: e(i+1).start_time == {} .",
381 e_i.start_time(),
382 e_j.start_time()
383 );
384 // If the current evidence ends before the next one starts ...
385 if e_i.end_time() <= e_j.start_time() {
386 // ... push current evidence and exit.
387 e.push(e_j.clone());
388 return e;
389 }
390 // Otherwise, we have overlapping intervals,
391 // check if they are the same type of evidence.
392 match (e_i, e_j) {
393 // If they are the same type of evidence ...
394 (
395 E::CertainPositiveInterval { state: s_i, .. },
396 E::CertainPositiveInterval { state: s_j, .. },
397 ) => {
398 // Check if they are the same state.
399 if s_i == s_j {
400 // Get evidence event, state, start time and end time.
401 let (event, state, start_time) = (e_i.event(), *s_i, e_i.start_time());
402 // Set end time to the maximum of both.
403 let end_time = e_i.end_time().max(e_j.end_time());
404 // Set the last evidence end time to the maximum of both.
405 *e.last_mut().unwrap() = E::CertainPositiveInterval {
406 event,
407 state,
408 start_time,
409 end_time,
410 };
411 // Otherwise, merge the two certain evidences into an uncertain one.
412 } else {
413 // Construct uncertain positive evidence.
414 let mut p_states = Array::zeros(*shape);
415 // Set the state of the evidence with a weight proportion to the time.
416 p_states[*s_i] = e_i.end_time() - e_i.start_time();
417 p_states[*s_j] = e_j.end_time() - e_j.start_time();
418 // Normalize the states.
419 p_states /= p_states.sum();
420 // Get evidence event, states, start time and end time.
421 let event = e_i.event();
422 let start_time = e_i.start_time().min(e_j.start_time());
423 let end_time = e_i.end_time().max(e_j.end_time());
424 // Set the last evidence end time to the maximum of both.
425 *e.last_mut().unwrap() = E::UncertainPositiveInterval {
426 event,
427 p_states,
428 start_time,
429 end_time,
430 };
431 }
432 }
433 (
434 E::CertainNegativeInterval {
435 not_states: s_i, ..
436 },
437 E::CertainNegativeInterval {
438 not_states: s_j, ..
439 },
440 ) => {
441 // Check if they are the same states.
442 assert_eq!(
443 s_i, s_j,
444 "Overlapping negative evidence must have the same states."
445 );
446 // Get evidence event, not states, start time and end time.
447 let (event, not_states, start_time) =
448 (e_i.event(), s_i.clone(), e_i.start_time());
449 // Set end time to the maximum of both.
450 let end_time = e_i.end_time().max(e_j.end_time());
451 // Set the last evidence end time to the maximum of both.
452 *e.last_mut().unwrap() = E::CertainNegativeInterval {
453 event,
454 not_states,
455 start_time,
456 end_time,
457 };
458 }
459 (
460 E::UncertainPositiveInterval { p_states: s_i, .. },
461 E::UncertainPositiveInterval { p_states: s_j, .. },
462 ) => {
463 // Check if they are the same states.
464 assert!(
465 relative_eq!(s_i, s_j),
466 "Overlapping uncertain evidence must have the same states."
467 );
468 // Get evidence event, states, start time and end time.
469 let (event, p_states, start_time) =
470 (e_i.event(), s_i.clone(), e_i.start_time());
471 // Set end time to the maximum of both.
472 let end_time = e_i.end_time().max(e_j.end_time());
473 // Set the last evidence end time to the maximum of both.
474 *e.last_mut().unwrap() = E::UncertainPositiveInterval {
475 event,
476 p_states,
477 start_time,
478 end_time,
479 };
480 }
481 (
482 E::UncertainNegativeInterval {
483 p_not_states: s_i, ..
484 },
485 E::UncertainNegativeInterval {
486 p_not_states: s_j, ..
487 },
488 ) => {
489 // Check if they are the same states.
490 assert!(
491 relative_eq!(s_i, s_j),
492 "Overlapping uncertain evidence must have the same states."
493 );
494 // Get evidence event, not states, start time and end time.
495 let (event, p_not_states, start_time) =
496 (e_i.event(), s_i.clone(), e_i.start_time());
497 // Set end time to the maximum of both.
498 let end_time = e_i.end_time().max(e_j.end_time());
499 // Set the last evidence end time to the maximum of both.
500 *e.last_mut().unwrap() = E::UncertainNegativeInterval {
501 event,
502 p_not_states,
503 start_time,
504 end_time,
505 };
506 }
507 // If they are not the same type of evidence ...
508 _ => panic!("Overlapping evidence must have the same type"),
509 }
510
511 e
512 });
513
514 // Assert current ending time is less or equal than next starting time.
515 assert!(
516 e
517 .windows(2)
518 .all(|e| e[0].end_time() <= e[1].start_time()),
519 "Ending time must be less or equal than next starting time."
520 );
521 },
522 );
523
524 // Create a new categorical trajectory evidence instance.
525 Self {
526 labels,
527 states,
528 shape,
529 evidences,
530 }
531 }
532
533 /// Returns the states of the trajectory evidence.
534 ///
535 /// # Returns
536 ///
537 /// A reference to the states of the trajectory evidence.
538 ///
539 #[inline]
540 pub const fn states(&self) -> &States {
541 &self.states
542 }
543
544 /// Returns the shape of the trajectory evidence.
545 ///
546 /// # Returns
547 ///
548 /// A reference to the shape of the trajectory evidence.
549 ///
550 #[inline]
551 pub const fn shape(&self) -> &Array1<usize> {
552 &self.shape
553 }
554
555 /// Returns the evidences of the trajectory.
556 ///
557 /// # Returns
558 ///
559 /// A reference to the evidences of the trajectory.
560 ///
561 #[inline]
562 pub fn evidences(&self) -> &Vec<Vec<CatTrjEvT>> {
563 &self.evidences
564 }
565
566 /// Returns the evidences at time zero.
567 ///
568 /// # Returns
569 ///
570 /// The evidences at time zero.
571 ///
572 pub fn initial_evidence(&self) -> CatEv {
573 // Get the evidences at time zero.
574 let evidences = self.evidences.iter().filter_map(|e| {
575 // Get the first evidence, if any.
576 let e = e.iter().next().cloned();
577 // Check if the evidence is at time zero.
578 let e = e.filter(|e| relative_eq!(e.start_time(), 0.));
579 // Map the evidence to its variable.
580 e.map(|e| e.into())
581 });
582
583 // Clone the states.
584 let states = self.states.clone();
585
586 // Create a new categorical evidence instance.
587 CatEv::new(states, evidences)
588 }
589}
590
/// A collection of multivariate trajectories evidence.
#[derive(Clone, Debug)]
pub struct CatTrjsEv {
    // Labels shared by every trajectory evidence (taken from the first one).
    labels: Labels,
    // States shared by every trajectory evidence (taken from the first one).
    states: States,
    // Shape shared by every trajectory evidence (taken from the first one).
    shape: Array1<usize>,
    // The trajectory evidences, in the order they were provided.
    evidences: Vec<CatTrjEv>,
}
599
// Exposes the variables labels of the trajectories evidence through `Labelled`.
impl Labelled for CatTrjsEv {
    // Returns a reference to the stored labels, without cloning them.
    #[inline]
    fn labels(&self) -> &Labels {
        &self.labels
    }
}
606
607impl CatTrjsEv {
608 /// Constructs a new collection of trajectories evidence.
609 ///
610 /// # Arguments
611 ///
612 /// * `trajectories` - An iterator of `CatTrjEv` instances.
613 ///
614 /// # Panics
615 ///
616 /// Panics if:
617 ///
618 /// * The trajectories have different labels.
619 /// * The trajectories have different states.
620 /// * The trajectories have different shape.
621 ///
622 /// # Returns
623 ///
624 /// A new instance of `CategoricalTrajectoriesEvidence`.
625 ///
626 pub fn new<I>(evidences: I) -> Self
627 where
628 I: IntoIterator<Item = CatTrjEv>,
629 {
630 // Collect the trajectories into a vector.
631 let evidences: Vec<_> = evidences.into_iter().collect();
632
633 // Assert every trajectory has the same labels.
634 assert!(
635 evidences
636 .windows(2)
637 .all(|trjs| trjs[0].labels().eq(trjs[1].labels())),
638 "All trajectories must have the same labels."
639 );
640 // Assert every trajectory has the same states.
641 assert!(
642 evidences
643 .windows(2)
644 .all(|trjs| trjs[0].states().eq(trjs[1].states())),
645 "All trajectories must have the same states."
646 );
647 // Assert every trajectory has the same shape.
648 assert!(
649 evidences
650 .windows(2)
651 .all(|trjs| trjs[0].shape().eq(trjs[1].shape())),
652 "All trajectories must have the same shape."
653 );
654
655 // Get the labels, states and shape from the first trajectory.
656 let (labels, states, shape) = match evidences.first() {
657 None => (Labels::default(), States::default(), Array1::default((0,))),
658 Some(x) => (x.labels().clone(), x.states().clone(), x.shape().clone()),
659 };
660
661 Self {
662 labels,
663 states,
664 shape,
665 evidences,
666 }
667 }
668
669 /// Returns the states of the trajectories evidence.
670 ///
671 /// # Returns
672 ///
673 /// A reference to the states of the trajectories evidence.
674 ///
675 #[inline]
676 pub fn states(&self) -> &States {
677 &self.states
678 }
679
680 /// Returns the shape of the trajectories evidence.
681 ///
682 /// # Returns
683 ///
684 /// A reference to the shape of the trajectories evidence.
685 ///
686 #[inline]
687 pub fn shape(&self) -> &Array1<usize> {
688 &self.shape
689 }
690
691 /// Returns the evidences of the trajectories.
692 ///
693 /// # Returns
694 ///
695 /// A reference to the evidences of the trajectories.
696 ///
697 #[inline]
698 pub fn evidences(&self) -> &Vec<CatTrjEv> {
699 &self.evidences
700 }
701}
702
// Allows collecting an iterator of `CatTrjEv` directly into a `CatTrjsEv`.
impl FromIterator<CatTrjEv> for CatTrjsEv {
    // Delegates to `CatTrjsEv::new`, so the same consistency assertions apply.
    #[inline]
    fn from_iter<I: IntoIterator<Item = CatTrjEv>>(iter: I) -> Self {
        Self::new(iter)
    }
}
709
710impl FromParallelIterator<CatTrjEv> for CatTrjsEv {
711 #[inline]
712 fn from_par_iter<I: IntoParallelIterator<Item = CatTrjEv>>(iter: I) -> Self {
713 Self::new(iter.into_par_iter().collect::<Vec<_>>())
714 }
715}
716
// Allows iterating over a borrowed `CatTrjsEv`, e.g. `for e in &trjs_ev { ... }`.
impl<'a> IntoIterator for &'a CatTrjsEv {
    type IntoIter = std::slice::Iter<'a, CatTrjEv>;
    type Item = &'a CatTrjEv;

    // Yields references to the stored trajectory evidences, in order.
    #[inline]
    fn into_iter(self) -> Self::IntoIter {
        self.evidences.iter()
    }
}
726
// Allows parallel iteration over a borrowed `CatTrjsEv` through `par_iter()`.
impl<'a> IntoParallelRefIterator<'a> for CatTrjsEv {
    type Item = &'a CatTrjEv;
    type Iter = rayon::slice::Iter<'a, CatTrjEv>;

    // Yields references to the stored trajectory evidences as a rayon parallel iterator.
    #[inline]
    fn par_iter(&'a self) -> Self::Iter {
        self.evidences.par_iter()
    }
}