causal_hub/random/evidence.rs

1use itertools::Itertools;
2use rand::{Rng, seq::index::sample};
3
4use crate::datasets::{CatTrj, CatTrjEv, CatTrjEvT, CatTrjs, CatTrjsEv, Dataset};
5
6/// A struct representing a random evidence generator.
/// Random evidence generator.
///
/// Holds a mutable borrow of a random number generator and a shared borrow of
/// a dataset, together with the probability `p` of selecting an evidence.
pub struct RngEv<'a, R, D> {
    /// Source of randomness, borrowed mutably for the generator's lifetime.
    rng: &'a mut R,
    /// Dataset the evidence is drawn from.
    dataset: &'a D,
    /// Probability of selecting an evidence; validated to lie in `[0, 1]`.
    p: f64,
}

impl<'a, R, D> RngEv<'a, R, D> {
    /// Builds a new `RngEv`.
    ///
    /// # Arguments
    ///
    /// * `rng` - Mutable reference to the random number generator to draw from.
    /// * `dataset` - Reference to the dataset the evidence is sampled from.
    /// * `p` - Probability of selecting an evidence.
    ///
    /// # Panics
    ///
    /// Panics if `p` lies outside `[0, 1]`, including when `p` is `NaN`.
    ///
    /// # Returns
    ///
    /// The configured `RngEv`.
    pub fn new(rng: &'a mut R, dataset: &'a D, p: f64) -> Self {
        // `RangeInclusive::contains` is false for NaN, so NaN is rejected as well.
        let in_unit_interval = (0.0..=1.0).contains(&p);
        if !in_unit_interval {
            panic!("Probability must be in [0, 1]");
        }

        RngEv {
            p,
            dataset,
            rng,
        }
    }
}
36
37impl<R: Rng> RngEv<'_, R, CatTrj> {
38    /// Generates random evidence from the trajectory.
39    ///
40    /// # Returns
41    ///
42    /// A `CatTrjEv` instance containing the random evidence.
43    ///
44    pub fn random(&mut self) -> CatTrjEv {
45        // Get shortened variable type.
46        use CatTrjEvT as E;
47
48        // Get times.
49        let times = self.dataset.times();
50        // Get events.
51        let events = self.dataset.values().rows();
52        // Zip times and events.
53        let times_events = times.into_iter().zip(events);
54
55        // Iterate over (time, event) pairs.
56        let evidence = times_events
57            .tuple_windows()
58            .filter_map(|((&start_time, v), (&end_time, _))| {
59                // Choose if the event is selected.
60                if !self.rng.random_bool(self.p) {
61                    // If the event is not selected, skip it.
62                    return None;
63                }
64                // Select how many events to select.
65                let n = self.rng.random_range(1..=v.len());
66                // Sample the events.
67                let evidence = sample(self.rng, v.len(), n).into_iter().map(move |index| {
68                    // Get label and state.
69                    let (event, state) = (index, v[index] as usize);
70                    // Create the evidence.
71                    E::CertainPositiveInterval {
72                        event,
73                        state,
74                        start_time,
75                        end_time,
76                    }
77                });
78                // Return the evidences.
79                Some(evidence)
80            })
81            .flatten();
82
83        // Collect the evidence.
84        CatTrjEv::new(self.dataset.states().clone(), evidence)
85    }
86}
87
88impl<R: Rng> RngEv<'_, R, CatTrjs> {
89    /// Generates random evidence from the trajectories.
90    ///
91    /// # Returns
92    ///
93    /// A `CatTrjsEv` instance containing the random evidence.
94    ///
95    pub fn random(&mut self) -> CatTrjsEv {
96        self.dataset
97            .values()
98            .iter()
99            .map(|trj| RngEv::new(&mut self.rng, trj, self.p).random())
100            .collect()
101    }
102}