causal_hub/estimators/parameters/raw.rs
1use std::ops::Deref;
2
3use itertools::Itertools;
4use ndarray::{Zip, prelude::*};
5use rand::{Rng, SeedableRng, seq::SliceRandom};
6use rand_distr::{Distribution, weighted::WeightedIndex};
7use rayon::prelude::*;
8
9use crate::{
10 datasets::{CatTrj, CatTrjEv, CatTrjEvT, CatTrjs, CatTrjsEv, CatType},
11 estimators::{BE, CIMEstimator, ParCIMEstimator},
12 models::{CatCIM, Labelled},
13 types::{Labels, Set},
14};
15
16// TODO: This must be refactored to be stateless.
17
/// Raw estimator: produces an initial guess of the parameters from (partial) evidence.
///
/// The evidence is completed into a fully observed dataset, which is then used to fit
/// the parameters. The result is meant as a starting point for other estimators,
/// such as EM.
///
#[derive(Debug)]
pub struct RAWE<'a, R, E, D> {
    /// Random number generator used while sampling and completing the evidence.
    rng: &'a mut R,
    /// The (possibly uncertain) evidence to be completed.
    evidence: &'a E,
    /// The completed dataset; `None` only until a constructor has filled it.
    dataset: Option<D>,
}
29
30impl<R, E, D> Deref for RAWE<'_, R, E, D> {
31 type Target = D;
32
33 fn deref(&self) -> &Self::Target {
34 self.dataset.as_ref().unwrap()
35 }
36}
37
38impl<R, E, D> Labelled for RAWE<'_, R, E, D>
39where
40 D: Labelled,
41{
42 #[inline]
43 fn labels(&self) -> &Labels {
44 self.dataset.as_ref().unwrap().labels()
45 }
46}
47
48impl<'a, R: Rng + SeedableRng> RAWE<'a, R, CatTrjEv, CatTrj> {
49 /// Constructs a new raw estimator from the evidence.
50 ///
51 /// # Arguments
52 ///
53 /// * `evidence` - A reference to the evidence to fill.
54 ///
55 /// # Returns
56 ///
57 /// A new `RAWE` instance.
58 ///
59 pub fn par_new(rng: &'a mut R, evidence: &'a CatTrjEv) -> Self {
60 // Initialize the estimator.
61 let mut estimator = Self {
62 rng,
63 evidence,
64 dataset: None,
65 };
66
67 // Fill the evidence with the raw estimator.
68 estimator.dataset = Some(estimator.par_fill());
69
70 estimator
71 }
72
73 /// Sample uncertain evidence.
74 /// TODO: Taken from importance sampling, deduplicate.
75 fn sample_evidence(&mut self) -> CatTrjEv {
76 // Get shortened variable type.
77 use CatTrjEvT as E;
78
79 // Sample the evidence for each variable.
80 let certain_evidence = self
81 .evidence
82 // Flatten the evidence.
83 .evidences()
84 .iter()
85 // Map (label, [evidence]) to (label, evidence) pairs.
86 .flatten()
87 .flat_map(|e| {
88 // Get the variable index, starting time, and ending time.
89 let (event, start_time, end_time) = (e.event(), e.start_time(), e.end_time());
90 // Sample the evidence.
91 let e = match e {
92 E::UncertainPositiveInterval { p_states, .. } => {
93 // Construct the sampler.
94 let state = WeightedIndex::new(p_states).unwrap();
95 // Sample the state.
96 let state = state.sample(self.rng);
97 // Return the sample.
98 E::CertainPositiveInterval {
99 event,
100 state,
101 start_time,
102 end_time,
103 }
104 }
105 E::UncertainNegativeInterval { p_not_states, .. } => {
106 // Allocate the not states.
107 let mut not_states: Set<_> = (0..p_not_states.len()).collect();
108 // Repeat until only a subset of the not states are sampled.
109 while not_states.len() == p_not_states.len() {
110 // Sample the not states.
111 not_states = p_not_states
112 .indexed_iter()
113 // For each (state, p_not_state) pair ...
114 .filter_map(|(i, &p_i)| {
115 // ... with p_i probability, retain the state.
116 Some(i).filter(|_| self.rng.random_bool(p_i))
117 })
118 .collect();
119 }
120 // Return the sample and weight.
121 E::CertainNegativeInterval {
122 event,
123 not_states,
124 start_time,
125 end_time,
126 }
127 }
128 _ => e.clone(), // Due to evidence sampling.
129 };
130
131 // Return the certain evidence.
132 Some(e)
133 });
134
135 // Collect the certain evidence.
136 CatTrjEv::new(self.evidence.states().clone(), certain_evidence)
137 }
138
139 /// Fills the evidence with the raw estimator.
140 ///
141 /// # Arguments
142 ///
143 /// * `evidence` - A reference to the evidence to fill.
144 ///
145 /// # Returns
146 ///
147 /// A new `CatTrj` instance.
148 ///
149 fn par_fill(&mut self) -> CatTrj {
150 // Short the evidence name.
151 use CatTrjEvT as E;
152 // Set missing placeholder.
153 const M: CatType = CatType::MAX;
154
155 // Get labels and states.
156 let states = self.evidence.states().clone();
157
158 // Get the ending time of the last event.
159 let end_time = self
160 .evidence
161 .evidences()
162 .iter()
163 // Get the ending time of each event.
164 .flatten()
165 .map(|e| e.end_time())
166 // Get the maximum time.
167 .max_by(|a, b| a.partial_cmp(b).unwrap())
168 // Unwrap the maximum time.
169 .unwrap_or(0.);
170
171 // Sort the evidence by starting time, adding initial and ending time.
172 let times: Array1<_> = self
173 .evidence
174 .evidences()
175 .iter()
176 // Get the starting time of each event.
177 .flatten()
178 .map(|e| e.start_time())
179 // Add initial and ending time.
180 .chain([0., end_time])
181 // Sort the times.
182 .sorted_by(|a, b| a.partial_cmp(b).unwrap())
183 // Deduplicate the times to aggregate the events.
184 .dedup()
185 .collect();
186
187 // Allocate the matrix of events with unknown states.
188 let mut events = Array2::from_elem((times.len(), states.len()), M);
189
190 // Reduce the uncertain evidences to certain evidences.
191 let evidence = self.sample_evidence();
192
193 // Set the states of the events given the evidence.
194 Zip::from(×)
195 .and(events.axis_iter_mut(Axis(0)))
196 .par_for_each(|time, mut event| {
197 // For each event, set the state of the variable at that time, if any.
198 event.iter_mut().enumerate().for_each(|(i, e)| {
199 // Get the evidence vector for that variable.
200 let e_i = &evidence.evidences()[i];
201 // Get the evidence for that time.
202 let e_i_t = e_i.iter().find(|e| e.contains(time));
203 // If the evidence is present, set the state.
204 if let Some(e_i_t) = e_i_t {
205 match e_i_t {
206 E::CertainPositiveInterval { state, .. } => *e = *state as CatType,
207 E::CertainNegativeInterval { .. } => todo!(), // FIXME:
208 _ => unreachable!(), // Due to the previous assertions, this should never happen.
209 }
210 }
211 });
212 });
213
214 // Get the events with no evidence at all.
215 let no_evidence: Vec<_> = events
216 .axis_iter(Axis(1))
217 .into_par_iter()
218 .enumerate()
219 .filter_map(|(i, e)| {
220 if e.iter().all(|&x| x == M) {
221 Some(i)
222 } else {
223 None
224 }
225 })
226 .collect();
227 // If no evidence is present, fill it randomly.
228 for i in no_evidence {
229 // Sample a state uniformly at random.
230 let random_state = Array::from_iter({
231 let random_state = || self.rng.random_range(0..(states[i].len() as CatType));
232 std::iter::repeat_with(random_state).take(events.nrows())
233 });
234 // Fill the event with the sampled state.
235 events.column_mut(i).assign(&random_state);
236 }
237
238 // Fill the unknown states by propagating the known states.
239 events
240 .axis_iter_mut(Axis(1))
241 .into_par_iter()
242 .for_each(|mut event| {
243 // Set the first known state position.
244 let mut first_known = 0;
245 // Check if the first state is known.
246 if event[first_known] == M {
247 // If the first state is unknown, get the first known state.
248 // NOTE: Safe unwrap since we know at least one state is present.
249 first_known = event.iter().position(|e| *e != M).unwrap();
250 // Get the event to fill with.
251 let e = event[first_known];
252 // Backward fill the unknown states.
253 event.slice_mut(s![..first_known]).fill(e);
254 }
255 // Set the first known state position as the last known state position.
256 let mut last_known = first_known;
257 // Get the first unknown state.
258 while let Some(first_unknown) = event.iter().skip(last_known).position(|e| *e == M)
259 {
260 // Add displacement to the first known state position because we skipped some elements.
261 let first_unknown = first_unknown + last_known;
262 // Get the last known state.
263 // NOTE: Safe because we know at least one state is present.
264 let e = event[first_unknown - 1];
265 // Get the last unknown state after the first unknown state.
266 // NOTE: We get the "first known state after the first unknown state",
267 // but we fill with an excluding range, so we can use the same position.
268 let last_unknown = event.iter().skip(first_unknown).position(|e| *e != M);
269 // Add displacement to the first unknown state position because we skipped some elements.
270 let last_unknown =
271 last_unknown.map(|last_unknown| last_unknown + first_unknown);
272 // If no last unknown state, set the end.
273 let last_unknown = last_unknown.unwrap_or(event.len());
274 // Fill the unknown states with the last known state, or till the end if none.
275 event.slice_mut(s![first_unknown..last_unknown]).fill(e);
276 // Set the last known state position as the last unknown state position.
277 last_known = last_unknown;
278 }
279 });
280
281 // Initialize the events and times with first event and time, if any.
282 let mut new_events: Vec<_> = events
283 .rows()
284 .into_iter()
285 .map(|x| x.to_owned())
286 .take(1)
287 .collect();
288 let mut new_times: Vec<_> = times.iter().cloned().take(1).collect();
289
290 // Check if there is at max one state change per transition.
291 events
292 .rows()
293 .into_iter()
294 .zip(×)
295 .tuple_windows()
296 .for_each(|((e_i, t_i), (e_j, t_j))| {
297 // Count the number of state changes.
298 let mut diff: Vec<_> = e_i
299 .indexed_iter()
300 .zip(e_j.indexed_iter())
301 .filter_map(|(i, j)| if i != j { Some(j) } else { None })
302 .collect();
303 // Check if there is at most one state change.
304 if diff.len() <= 1 {
305 // Add the event and time to the new events.
306 new_events.push(e_j.to_owned());
307 new_times.push(*t_j);
308 // Nothing to fix, just return.
309 return;
310 }
311 // Otherwise, we have multiple state changes.
312 // Shuffle them to generate a transition order.
313 diff.shuffle(self.rng);
314 // Ignore the last state change to avoid overlap with the next event.
315 diff.pop();
316 // Get the first state change.
317 let (mut e_k, mut t_k) = (e_i.to_owned(), *t_i);
318 // Compute uniform time delta.
319 let t_delta = (t_j - t_i) / (diff.len() + 1) as f64;
320 // Generate the events to add to fill the gaps between e_i and e_j.
321 diff.into_iter().for_each(|(i, x)| {
322 // Set the state to the event.
323 e_k[i] = *x;
324 // Set the time to the event.
325 t_k += t_delta;
326 // Add the event and time to the new events.
327 new_events.push(e_k.clone());
328 new_times.push(t_k);
329 });
330 // Add the last event and time to the new events.
331 new_events.push(e_j.to_owned());
332 new_times.push(*t_j);
333 });
334
335 // Reshape the events to the number of events and states.
336 let events = Array::from_iter(new_events.into_iter().flatten())
337 .into_shape_with_order((new_times.len(), states.len()))
338 .expect("Failed to reshape events.");
339 // Reshape the times to the number of events.
340 let times = Array::from_iter(new_times);
341
342 // Construct the fully observed trajectory.
343 CatTrj::new(states, events, times)
344 }
345}
346
347impl<'a, R: Rng + SeedableRng> RAWE<'a, R, CatTrjsEv, CatTrjs> {
348 /// Constructs a new raw estimator from the evidence.
349 ///
350 /// # Arguments
351 ///
352 /// * `evidence` - A reference to the evidence to fill.
353 ///
354 /// # Returns
355 ///
356 /// A new `RAWE` instance.
357 ///
358 pub fn par_new(rng: &'a mut R, evidence: &'a CatTrjsEv) -> Self {
359 // Get evidence.
360 let _evidence = evidence.evidences();
361 // Sample seed for parallel sampling.
362 let seeds: Vec<_> = (0.._evidence.len()).map(|_| rng.next_u64()).collect();
363 // Fill the evidence with the raw estimator.
364 let dataset: Option<CatTrjs> = Some(
365 seeds
366 .into_par_iter()
367 .zip(_evidence)
368 .map(|(seed, e)| {
369 // Create a new random number generator with the seed.
370 let mut rng = R::seed_from_u64(seed);
371 // Fill the evidence with the raw estimator.
372 RAWE::<'_, R, CatTrjEv, CatTrj>::par_new(&mut rng, e)
373 .dataset
374 .unwrap()
375 })
376 .collect(),
377 );
378
379 Self {
380 rng,
381 evidence,
382 dataset,
383 }
384 }
385}
386
387impl<R: Rng + SeedableRng> CIMEstimator<CatCIM> for RAWE<'_, R, CatTrjEv, CatTrj> {
388 fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCIM {
389 // Estimate the CIM with a uniform prior.
390 BE::new(self.dataset.as_ref().unwrap())
391 .with_prior((1, 1.))
392 .fit(x, z)
393 }
394}
395
396impl<R: Rng + SeedableRng> CIMEstimator<CatCIM> for RAWE<'_, R, CatTrjsEv, CatTrjs> {
397 fn fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCIM {
398 // Estimate the CIM with a uniform prior.
399 BE::new(self.dataset.as_ref().unwrap())
400 .with_prior((1, 1.))
401 .fit(x, z)
402 }
403}
404
405impl<R: Rng + SeedableRng> ParCIMEstimator<CatCIM> for RAWE<'_, R, CatTrjsEv, CatTrjs> {
406 fn par_fit(&self, x: &Set<usize>, z: &Set<usize>) -> CatCIM {
407 // Estimate the CIM with a uniform prior.
408 BE::new(self.dataset.as_ref().unwrap())
409 .with_prior((1, 1.))
410 .par_fit(x, z)
411 }
412}