causal_hub/random/evidence.rs
1use itertools::Itertools;
2use rand::{Rng, seq::index::sample};
3
4use crate::datasets::{CatTrj, CatTrjEv, CatTrjEvT, CatTrjs, CatTrjsEv, Dataset};
5
/// Random evidence generator.
///
/// Holds a mutable borrow of a random number generator, a shared borrow of
/// the dataset that evidence is drawn from, and the selection probability `p`.
pub struct RngEv<'a, R, D> {
    /// Source of randomness, borrowed mutably for the generator's lifetime.
    rng: &'a mut R,
    /// Dataset the evidence is drawn from.
    dataset: &'a D,
    /// Probability of selecting an evidence.
    p: f64,
}
12
13impl<'a, R, D> RngEv<'a, R, D> {
14 /// Creates a new `RngEv` instance.
15 ///
16 /// # Arguments
17 ///
18 /// * `rng` - A mutable reference to a random number generator.
19 /// * `dataset` - A reference to the dataset.
20 /// * `p` - The probability of selecting an evidence.
21 ///
22 /// # Panics
23 ///
24 /// Panics if the probability is not in [0, 1].
25 ///
26 /// # Returns
27 ///
28 /// A new `RngEv` instance.
29 pub fn new(rng: &'a mut R, dataset: &'a D, p: f64) -> Self {
30 // Assert that the probability is in [0, 1].
31 assert!((0.0..=1.0).contains(&p), "Probability must be in [0, 1]");
32
33 Self { rng, dataset, p }
34 }
35}
36
37impl<R: Rng> RngEv<'_, R, CatTrj> {
38 /// Generates random evidence from the trajectory.
39 ///
40 /// # Returns
41 ///
42 /// A `CatTrjEv` instance containing the random evidence.
43 ///
44 pub fn random(&mut self) -> CatTrjEv {
45 // Get shortened variable type.
46 use CatTrjEvT as E;
47
48 // Get times.
49 let times = self.dataset.times();
50 // Get events.
51 let events = self.dataset.values().rows();
52 // Zip times and events.
53 let times_events = times.into_iter().zip(events);
54
55 // Iterate over (time, event) pairs.
56 let evidence = times_events
57 .tuple_windows()
58 .filter_map(|((&start_time, v), (&end_time, _))| {
59 // Choose if the event is selected.
60 if !self.rng.random_bool(self.p) {
61 // If the event is not selected, skip it.
62 return None;
63 }
64 // Select how many events to select.
65 let n = self.rng.random_range(1..=v.len());
66 // Sample the events.
67 let evidence = sample(self.rng, v.len(), n).into_iter().map(move |index| {
68 // Get label and state.
69 let (event, state) = (index, v[index] as usize);
70 // Create the evidence.
71 E::CertainPositiveInterval {
72 event,
73 state,
74 start_time,
75 end_time,
76 }
77 });
78 // Return the evidences.
79 Some(evidence)
80 })
81 .flatten();
82
83 // Collect the evidence.
84 CatTrjEv::new(self.dataset.states().clone(), evidence)
85 }
86}
87
88impl<R: Rng> RngEv<'_, R, CatTrjs> {
89 /// Generates random evidence from the trajectories.
90 ///
91 /// # Returns
92 ///
93 /// A `CatTrjsEv` instance containing the random evidence.
94 ///
95 pub fn random(&mut self) -> CatTrjsEv {
96 self.dataset
97 .values()
98 .iter()
99 .map(|trj| RngEv::new(&mut self.rng, trj, self.p).random())
100 .collect()
101 }
102}