causal_hub/estimators/parameters/raw.rs

1use std::ops::Deref;
2
3use itertools::Itertools;
4use ndarray::{Zip, prelude::*};
5use rand::{Rng, SeedableRng, seq::SliceRandom};
6use rand_distr::{Distribution, weighted::WeightedIndex};
7use rayon::prelude::*;
8
9use crate::{
10    datasets::{CatTrj, CatTrjEv, CatTrjEvT, CatTrjs, CatTrjsEv, CatType},
11    estimators::{BE, CIMEstimator, ParCIMEstimator},
12    models::{CatCIM, Labelled},
13    types::{Labels, Set},
14};
15
16// TODO: This must be refactored to be stateless.
17
/// A struct representing a raw estimator.
///
/// This estimator is used to find an initial guess of the parameters with the given evidence.
/// Its purpose is to provide a starting point for the other estimators, like EM.
///
/// The evidence is completed into a fully observed dataset when the estimator is
/// constructed; parameters are then estimated on that completed dataset.
///
#[derive(Debug)]
pub struct RAWE<'a, R, E, D> {
    /// Random number generator used while completing the evidence.
    rng: &'a mut R,
    /// The (possibly uncertain) evidence to complete.
    evidence: &'a E,
    /// The completed dataset; `None` only transiently while the estimator is being built.
    dataset: Option<D>,
}
29
30impl<R, E, D> Deref for RAWE<'_, R, E, D> {
31    type Target = D;
32
33    fn deref(&self) -> &Self::Target {
34        self.dataset.as_ref().unwrap()
35    }
36}
37
38impl<R, E, D> Labelled for RAWE<'_, R, E, D>
39where
40    D: Labelled,
41{
42    #[inline]
43    fn labels(&self) -> &Labels {
44        self.dataset.as_ref().unwrap().labels()
45    }
46}
47
48impl<'a, R: Rng + SeedableRng> RAWE<'a, R, CatTrjEv, CatTrj> {
49    /// Constructs a new raw estimator from the evidence.
50    ///
51    /// # Arguments
52    ///
53    /// * `evidence` - A reference to the evidence to fill.
54    ///
55    /// # Returns
56    ///
57    /// A new `RAWE` instance.
58    ///
59    pub fn par_new(rng: &'a mut R, evidence: &'a CatTrjEv) -> Self {
60        // Initialize the estimator.
61        let mut estimator = Self {
62            rng,
63            evidence,
64            dataset: None,
65        };
66
67        // Fill the evidence with the raw estimator.
68        estimator.dataset = Some(estimator.par_fill());
69
70        estimator
71    }
72
73    /// Sample uncertain evidence.
74    /// TODO: Taken from importance sampling, deduplicate.
75    fn sample_evidence(&mut self) -> CatTrjEv {
76        // Get shortened variable type.
77        use CatTrjEvT as E;
78
79        // Sample the evidence for each variable.
80        let certain_evidence = self
81            .evidence
82            // Flatten the evidence.
83            .evidences()
84            .iter()
85            // Map (label, [evidence]) to (label, evidence) pairs.
86            .flatten()
87            .flat_map(|e| {
88                // Get the variable index, starting time, and ending time.
89                let (event, start_time, end_time) = (e.event(), e.start_time(), e.end_time());
90                // Sample the evidence.
91                let e = match e {
92                    E::UncertainPositiveInterval { p_states, .. } => {
93                        // Construct the sampler.
94                        let state = WeightedIndex::new(p_states).unwrap();
95                        // Sample the state.
96                        let state = state.sample(self.rng);
97                        // Return the sample.
98                        E::CertainPositiveInterval {
99                            event,
100                            state,
101                            start_time,
102                            end_time,
103                        }
104                    }
105                    E::UncertainNegativeInterval { p_not_states, .. } => {
106                        // Allocate the not states.
107                        let mut not_states: Set<_> = (0..p_not_states.len()).collect();
108                        // Repeat until only a subset of the not states are sampled.
109                        while not_states.len() == p_not_states.len() {
110                            // Sample the not states.
111                            not_states = p_not_states
112                                .indexed_iter()
113                                // For each (state, p_not_state) pair ...
114                                .filter_map(|(i, &p_i)| {
115                                    // ... with p_i probability, retain the state.
116                                    Some(i).filter(|_| self.rng.random_bool(p_i))
117                                })
118                                .collect();
119                        }
120                        // Return the sample and weight.
121                        E::CertainNegativeInterval {
122                            event,
123                            not_states,
124                            start_time,
125                            end_time,
126                        }
127                    }
128                    _ => e.clone(), // Due to evidence sampling.
129                };
130
131                // Return the certain evidence.
132                Some(e)
133            });
134
135        // Collect the certain evidence.
136        CatTrjEv::new(self.evidence.states().clone(), certain_evidence)
137    }
138
139    /// Fills the evidence with the raw estimator.
140    ///
141    /// # Arguments
142    ///
143    /// * `evidence` - A reference to the evidence to fill.
144    ///
145    /// # Returns
146    ///
147    /// A new `CatTrj` instance.
148    ///
149    fn par_fill(&mut self) -> CatTrj {
150        // Short the evidence name.
151        use CatTrjEvT as E;
152        // Set missing placeholder.
153        const M: CatType = CatType::MAX;
154
155        // Get labels and states.
156        let states = self.evidence.states().clone();
157
158        // Get the ending time of the last event.
159        let end_time = self
160            .evidence
161            .evidences()
162            .iter()
163            // Get the ending time of each event.
164            .flatten()
165            .map(|e| e.end_time())
166            // Get the maximum time.
167            .max_by(|a, b| a.partial_cmp(b).unwrap())
168            // Unwrap the maximum time.
169            .unwrap_or(0.);
170
171        // Sort the evidence by starting time, adding initial and ending time.
172        let times: Array1<_> = self
173            .evidence
174            .evidences()
175            .iter()
176            // Get the starting time of each event.
177            .flatten()
178            .map(|e| e.start_time())
179            // Add initial and ending time.
180            .chain([0., end_time])
181            // Sort the times.
182            .sorted_by(|a, b| a.partial_cmp(b).unwrap())
183            // Deduplicate the times to aggregate the events.
184            .dedup()
185            .collect();
186
187        // Allocate the matrix of events with unknown states.
188        let mut events = Array2::from_elem((times.len(), states.len()), M);
189
190        // Reduce the uncertain evidences to certain evidences.
191        let evidence = self.sample_evidence();
192
193        // Set the states of the events given the evidence.
194        Zip::from(&times)
195            .and(events.axis_iter_mut(Axis(0)))
196            .par_for_each(|time, mut event| {
197                // For each event, set the state of the variable at that time, if any.
198                event.iter_mut().enumerate().for_each(|(i, e)| {
199                    // Get the evidence vector for that variable.
200                    let e_i = &evidence.evidences()[i];
201                    // Get the evidence for that time.
202                    let e_i_t = e_i.iter().find(|e| e.contains(time));
203                    // If the evidence is present, set the state.
204                    if let Some(e_i_t) = e_i_t {
205                        match e_i_t {
206                            E::CertainPositiveInterval { state, .. } => *e = *state as CatType,
207                            E::CertainNegativeInterval { .. } => todo!(), // FIXME:
208                            _ => unreachable!(), // Due to the previous assertions, this should never happen.
209                        }
210                    }
211                });
212            });
213
214        // Get the events with no evidence at all.
215        let no_evidence: Vec<_> = events
216            .axis_iter(Axis(1))
217            .into_par_iter()
218            .enumerate()
219            .filter_map(|(i, e)| {
220                if e.iter().all(|&x| x == M) {
221                    Some(i)
222                } else {
223                    None
224                }
225            })
226            .collect();
227        // If no evidence is present, fill it randomly.
228        for i in no_evidence {
229            // Sample a state uniformly at random.
230            let random_state = Array::from_iter({
231                let random_state = || self.rng.random_range(0..(states[i].len() as CatType));
232                std::iter::repeat_with(random_state).take(events.nrows())
233            });
234            // Fill the event with the sampled state.
235            events.column_mut(i).assign(&random_state);
236        }
237
238        // Fill the unknown states by propagating the known states.
239        events
240            .axis_iter_mut(Axis(1))
241            .into_par_iter()
242            .for_each(|mut event| {
243                // Set the first known state position.
244                let mut first_known = 0;
245                // Check if the first state is known.
246                if event[first_known] == M {
247                    // If the first state is unknown, get the first known state.
248                    // NOTE: Safe unwrap since we know at least one state is present.
249                    first_known = event.iter().position(|e| *e != M).unwrap();
250                    // Get the event to fill with.
251                    let e = event[first_known];
252                    // Backward fill the unknown states.
253                    event.slice_mut(s![..first_known]).fill(e);
254                }
255                // Set the first known state position as the last known state position.
256                let mut last_known = first_known;
257                // Get the first unknown state.
258                while let Some(first_unknown) = event.iter().skip(last_known).position(|e| *e == M)
259                {
260                    // Add displacement to the first known state position because we skipped some elements.
261                    let first_unknown = first_unknown + last_known;
262                    // Get the last known state.
263                    // NOTE: Safe because we know at least one state is present.
264                    let e = event[first_unknown - 1];
265                    // Get the last unknown state after the first unknown state.
266                    // NOTE: We get the "first known state after the first unknown state",
267                    // but we fill with an excluding range, so we can use the same position.
268                    let last_unknown = event.iter().skip(first_unknown).position(|e| *e != M);
269                    // Add displacement to the first unknown state position because we skipped some elements.
270                    let last_unknown =
271                        last_unknown.map(|last_unknown| last_unknown + first_unknown);
272                    // If no last unknown state, set the end.
273                    let last_unknown = last_unknown.unwrap_or(event.len());
274                    // Fill the unknown states with the last known state, or till the end if none.
275                    event.slice_mut(s![first_unknown..last_unknown]).fill(e);
276                    // Set the last known state position as the last unknown state position.
277                    last_known = last_unknown;
278                }
279            });
280
281        // Initialize the events and times with first event and time, if any.
282        let mut new_events: Vec<_> = events
283            .rows()
284            .into_iter()
285            .map(|x| x.to_owned())
286            .take(1)
287            .collect();
288        let mut new_times: Vec<_> = times.iter().cloned().take(1).collect();
289
290        // Check if there is at max one state change per transition.
291        events
292            .rows()
293            .into_iter()
294            .zip(&times)
295            .tuple_windows()
296            .for_each(|((e_i, t_i), (e_j, t_j))| {
297                // Count the number of state changes.
298                let mut diff: Vec<_> = e_i
299                    .indexed_iter()
300                    .zip(e_j.indexed_iter())
301                    .filter_map(|(i, j)| if i != j { Some(j) } else { None })
302                    .collect();
303                // Check if there is at most one state change.
304                if diff.len() <= 1 {
305                    // Add the event and time to the new events.
306                    new_events.push(e_j.to_owned());
307                    new_times.push(*t_j);
308                    // Nothing to fix, just return.
309                    return;
310                }
311                // Otherwise, we have multiple state changes.
312                // Shuffle them to generate a transition order.
313                diff.shuffle(self.rng);
314                // Ignore the last state change to avoid overlap with the next event.
315                diff.pop();
316                // Get the first state change.
317                let (mut e_k, mut t_k) = (e_i.to_owned(), *t_i);
318                // Compute uniform time delta.
319                let t_delta = (t_j - t_i) / (diff.len() + 1) as f64;
320                // Generate the events to add to fill the gaps between e_i and e_j.
321                diff.into_iter().for_each(|(i, x)| {
322                    // Set the state to the event.
323                    e_k[i] = *x;
324                    // Set the time to the event.
325                    t_k += t_delta;
326                    // Add the event and time to the new events.
327                    new_events.push(e_k.clone());
328                    new_times.push(t_k);
329                });
330                // Add the last event and time to the new events.
331                new_events.push(e_j.to_owned());
332                new_times.push(*t_j);
333            });
334
335        // Reshape the events to the number of events and states.
336        let events = Array::from_iter(new_events.into_iter().flatten())
337            .into_shape_with_order((new_times.len(), states.len()))
338            .expect("Failed to reshape events.");
339        // Reshape the times to the number of events.
340        let times = Array::from_iter(new_times);
341
342        // Construct the fully observed trajectory.
343        CatTrj::new(states, events, times)
344    }
345}
346
347impl<'a, R: Rng + SeedableRng> RAWE<'a, R, CatTrjsEv, CatTrjs> {
348    /// Constructs a new raw estimator from the evidence.
349    ///
350    /// # Arguments
351    ///
352    /// * `evidence` - A reference to the evidence to fill.
353    ///
354    /// # Returns
355    ///
356    /// A new `RAWE` instance.
357    ///
358    pub fn par_new(rng: &'a mut R, evidence: &'a CatTrjsEv) -> Self {
359        // Get evidence.
360        let _evidence = evidence.evidences();
361        // Sample seed for parallel sampling.
362        let seeds: Vec<_> = (0.._evidence.len()).map(|_| rng.next_u64()).collect();
363        // Fill the evidence with the raw estimator.
364        let dataset: Option<CatTrjs> = Some(
365            seeds
366                .into_par_iter()
367                .zip(_evidence)
368                .map(|(seed, e)| {
369                    // Create a new random number generator with the seed.
370                    let mut rng = R::seed_from_u64(seed);
371                    // Fill the evidence with the raw estimator.
372                    RAWE::<'_, R, CatTrjEv, CatTrj>::par_new(&mut rng, e)
373                        .dataset
374                        .unwrap()
375                })
376                .collect(),
377        );
378
379        Self {
380            rng,
381            evidence,
382            dataset,
383        }
384    }
385}
386
387impl<R: Rng + SeedableRng> CIMEstimator<CatCIM> for RAWE<'_, R, CatTrjEv, CatTrj> {
388    fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCIM {
389        // Estimate the CIM with a uniform prior.
390        BE::new(self.dataset.as_ref().unwrap())
391            .with_prior((1, 1.))
392            .fit(x, z)
393    }
394}
395
396impl<R: Rng + SeedableRng> CIMEstimator<CatCIM> for RAWE<'_, R, CatTrjsEv, CatTrjs> {
397    fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCIM {
398        // Estimate the CIM with a uniform prior.
399        BE::new(self.dataset.as_ref().unwrap())
400            .with_prior((1, 1.))
401            .fit(x, z)
402    }
403}
404
405impl<R: Rng + SeedableRng> ParCIMEstimator<CatCIM> for RAWE<'_, R, CatTrjsEv, CatTrjs> {
406    fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCIM {
407        // Estimate the CIM with a uniform prior.
408        BE::new(self.dataset.as_ref().unwrap())
409            .with_prior((1, 1.))
410            .par_fit(x, z)
411    }
412}