causal_hub/samplers/
importance.rs

1use std::cell::RefCell;
2
3use ndarray::prelude::*;
4use ndarray_stats::QuantileExt;
5use rand::{
6    Rng, SeedableRng,
7    distr::{Distribution, weighted::WeightedIndex},
8};
9use rand_distr::Exp;
10use rayon::prelude::*;
11
12use crate::{
13    datasets::{
14        CatEv, CatEvT, CatSample, CatTable, CatTrj, CatTrjEv, CatTrjEvT, CatType, CatWtdSample,
15        CatWtdTable, CatWtdTrj, CatWtdTrjs, GaussEv, GaussEvT, GaussTable, GaussType,
16        GaussWtdSample, GaussWtdTable,
17    },
18    models::{BN, CIM, CPD, CTBN, CatBN, CatCTBN, GaussBN, Labelled},
19    samplers::{BNSampler, CTBNSampler, ParBNSampler, ParCTBNSampler},
20    set,
21    types::{EPSILON, Set},
22};
23
/// A struct for sampling using importance sampling.
///
/// The sampler does not own anything: it borrows the random number generator,
/// the model, and the evidence for the lifetime `'a`.
#[derive(Debug)]
pub struct ImportanceSampler<'a, R, M, E> {
    /// Mutably borrowed random number generator, wrapped in a `RefCell` so
    /// that methods taking `&self` can still draw random numbers.
    rng: RefCell<&'a mut R>,
    /// The model to sample from.
    model: &'a M,
    /// The evidence the samples are weighted against.
    evidence: &'a E,
}
31
32impl<'a, R, M, E> ImportanceSampler<'a, R, M, E>
33where
34    M: Labelled,
35    E: Labelled,
36{
37    /// Construct a new importance sampler.
38    ///
39    /// # Arguments
40    ///
41    /// * `rng` - A random number generator.
42    /// * `model` - A reference to the model to sample from.
43    /// * `evidence` - A reference to the evidence to sample from.
44    ///
45    /// # Returns
46    ///
47    /// Return a new `ImportanceSampler` instance.
48    ///
49    #[inline]
50    pub fn new(rng: &'a mut R, model: &'a M, evidence: &'a E) -> Self {
51        // Wrap the RNG in a RefCell to allow interior mutability.
52        let rng = RefCell::new(rng);
53
54        // Assert the model and the evidences have the same labels.
55        assert_eq!(
56            model.labels(),
57            evidence.labels(),
58            "The model and the evidences must have the same variables."
59        );
60
61        Self {
62            rng,
63            model,
64            evidence,
65        }
66    }
67}
68
69impl<R: Rng> ImportanceSampler<'_, R, CatBN, CatEv> {
70    /// Sample uncertain evidence.
71    fn sample_evidence<T: Rng>(&self, rng: &mut T) -> CatEv {
72        // Get shortened variable type.
73        use CatEvT as E;
74
75        // Sample the evidence for each variable.
76        let certain_evidence = self
77            .evidence
78            // Flatten the evidence.
79            .evidences()
80            .iter()
81            // Filter empty evidences.
82            .filter_map(|e| {
83                e.as_ref().map(|e| {
84                    // Get the event index.
85                    let event = e.event();
86                    // Sample the evidence.
87                    match e {
88                        E::UncertainPositive { p_states, .. } => {
89                            // Construct the sampler.
90                            let state = WeightedIndex::new(p_states).unwrap();
91                            // Sample the state.
92                            let state = state.sample(rng);
93                            // Return the sample.
94                            E::CertainPositive { event, state }
95                        }
96                        E::UncertainNegative { p_not_states, .. } => {
97                            // Allocate the not states.
98                            let mut not_states: Set<_> = (0..p_not_states.len()).collect();
99                            // Repeat until only a subset of the not states are sampled.
100                            while not_states.len() == p_not_states.len() {
101                                // Sample the not states.
102                                not_states = p_not_states
103                                    .indexed_iter()
104                                    // For each (state, p_not_state) pair ...
105                                    .filter_map(|(i, &p_i)| {
106                                        // ... with p_i probability, retain the state.
107                                        Some(i).filter(|_| rng.random_bool(p_i))
108                                    })
109                                    .collect();
110                            }
111                            // Return the sample and weight.
112                            E::CertainNegative { event, not_states }
113                        }
114                        _ => e.clone(), // Due to evidence sampling.
115                    }
116                })
117            });
118
119        // Collect the certain evidence.
120        CatEv::new(self.evidence.states().clone(), certain_evidence)
121    }
122}
123
124impl<R: Rng> BNSampler<CatBN> for ImportanceSampler<'_, R, CatBN, CatEv> {
125    type Sample = CatWtdSample;
126    type Samples = CatWtdTable;
127
128    fn sample(&self) -> Self::Sample {
129        // Get shortened variable type.
130        use CatEvT as E;
131
132        // Assert the model and the evidences have the same states.
133        // TODO: Move this assertion to the constructor.
134        assert_eq!(
135            self.model.states(),
136            self.evidence.states(),
137            "The model and the evidences must have the same states."
138        );
139
140        // Get a mutable reference to the RNG.
141        let mut rng = self.rng.borrow_mut();
142        // Allocate the sample.
143        let mut sample = Array::zeros(self.model.labels().len());
144        // Initialize the weight.
145        let mut weight = 1.;
146
147        // Reduce the uncertain evidences to certain evidences.
148        let evidence = self.sample_evidence(&mut rng);
149
150        // For each vertex in the topological order ...
151        self.model.topological_order().iter().for_each(|&i| {
152            // Get the evidence of the vertex.
153            let e_i = &evidence.evidences()[i];
154
155            // Get the CPD.
156            let cpd_i = &self.model.cpds()[i];
157            // Compute the index on the parents to condition on.
158            let pa_i = self.model.graph().parents(&set![i]);
159            let pa_i = pa_i.iter().map(|&z| sample[z] as usize);
160            let pa_i = cpd_i.conditioning_multi_index().ravel(pa_i);
161            // Get the distribution of the vertex.
162            let p_i = cpd_i.parameters().row(pa_i);
163
164            // Get the evidence of the vertex.
165            let (s_i, w_i) = match e_i {
166                // If there is evidence, sample from the constrained distribution.
167                Some(e_i) => match e_i {
168                    E::CertainPositive { state, .. } => {
169                        // Get the state.
170                        let s_i = *state as CatType;
171                        // Return the state and its weight.
172                        (s_i, p_i[*state])
173                    }
174                    E::CertainNegative { not_states, .. } => {
175                        // Initialize the weight.
176                        let mut w_i = 1.;
177                        // Clone the distribution.
178                        let mut p_i = p_i.to_owned();
179                        // For each not state ...
180                        not_states.iter().for_each(|&j| {
181                            // Update the weight.
182                            w_i -= p_i[j];
183                            // Zero out the not states.
184                            p_i[j] = 0.;
185                        });
186                        // Normalize the probabilities.
187                        p_i /= p_i.sum();
188                        // Construct the sampler.
189                        let s_i = WeightedIndex::new(&p_i).unwrap();
190                        // Sample the state.
191                        let s_i = s_i.sample(&mut rng) as CatType;
192                        // Return the sample and weight.
193                        (s_i, w_i)
194                    }
195                    _ => unreachable!(), // Due to evidence sampling.
196                },
197                // If there is no evidence, sample as usual.
198                None => {
199                    // Construct the sampler.
200                    let s_i = WeightedIndex::new(&p_i).unwrap();
201                    // Sample the state.
202                    let s_i = s_i.sample(&mut rng) as CatType;
203                    // Return the sample and weight.
204                    (s_i, 1.)
205                }
206            };
207
208            // Sample from the distribution.
209            sample[i] = s_i;
210            // Update the weight.
211            weight *= w_i;
212        });
213
214        (sample, weight)
215    }
216
217    fn sample_n(&self, n: usize) -> Self::Samples {
218        // Allocate the samples.
219        let mut samples = Array2::zeros((n, self.model.labels().len()));
220        // Allocate the weights.
221        let mut weights = Array1::zeros(n);
222
223        // Sample the weighted samples.
224        samples
225            .rows_mut()
226            .into_iter()
227            .zip(weights.iter_mut())
228            .for_each(|(mut sample, weight)| {
229                // Sample a weighted sample.
230                let (s_i, w_i) = self.sample();
231                // Assign the sample.
232                sample.assign(&s_i);
233                // Assign the weight.
234                *weight = w_i;
235            });
236
237        // Construct the samples.
238        let samples = CatTable::new(self.model.states().clone(), samples);
239
240        // Return the weighted samples.
241        CatWtdTable::new(samples, weights)
242    }
243}
244
245impl<R: Rng> BNSampler<GaussBN> for ImportanceSampler<'_, R, GaussBN, GaussEv> {
246    type Sample = GaussWtdSample;
247    type Samples = GaussWtdTable;
248
249    fn sample(&self) -> Self::Sample {
250        // Get shortened variable type.
251        use GaussEvT as E;
252
253        // Get a mutable reference to the RNG.
254        let mut rng = self.rng.borrow_mut();
255        // Allocate the sample.
256        let mut sample = Array::zeros(self.model.labels().len());
257        // Initialize the weight.
258        let mut weight = 1.;
259
260        // For each vertex in the topological order ...
261        self.model.topological_order().iter().for_each(|&i| {
262            // Get the evidence of the vertex.
263            let e_i = &self.evidence.evidences()[i];
264
265            // Get the CPD.
266            let cpd_i = &self.model.cpds()[i];
267            // Compute the index on the parents to condition on.
268            let pa_i = self.model.graph().parents(&set![i]);
269            let pa_i = pa_i.iter().map(|&z| sample[z]).collect();
270
271            // Get the evidence of the vertex.
272            let (s_i, w_i) = match e_i {
273                // If there is evidence, sample from the constrained distribution.
274                Some(e_i) => match e_i {
275                    E::CertainPositive { value, .. } => {
276                        // Get the state.
277                        let s_i = *value;
278                        // Get the probability.
279                        let p_i = cpd_i.pf(&array![s_i], &pa_i);
280                        // Return the state and its weight.
281                        (s_i, p_i)
282                    }
283                },
284                // If there is no evidence, sample as usual.
285                None => {
286                    // Sample from the distribution.
287                    let s_i = cpd_i.sample(&mut rng, &pa_i)[0];
288                    // Return the sample and weight.
289                    (s_i, 1.)
290                }
291            };
292
293            // Sample from the distribution.
294            sample[i] = s_i;
295            // Update the weight.
296            weight *= w_i;
297        });
298
299        (sample, weight)
300    }
301    fn sample_n(&self, n: usize) -> Self::Samples {
302        // Allocate the samples.
303        let mut samples = Array2::zeros((n, self.model.labels().len()));
304        // Allocate the weights.
305        let mut weights = Array1::zeros(n);
306
307        // Sample the weighted samples.
308        samples
309            .rows_mut()
310            .into_iter()
311            .zip(weights.iter_mut())
312            .for_each(|(mut sample, weight)| {
313                // Sample a weighted sample.
314                let (s_i, w_i) = self.sample();
315                // Assign the sample.
316                sample.assign(&s_i);
317                // Assign the weight.
318                *weight = w_i;
319            });
320
321        // Construct the samples.
322        let samples = GaussTable::new(self.model.labels().clone(), samples);
323
324        // Return the weighted samples.
325        GaussWtdTable::new(samples, weights)
326    }
327}
328
329impl<R: Rng + SeedableRng> ParBNSampler<CatBN> for ImportanceSampler<'_, R, CatBN, CatEv> {
330    type Samples = CatWtdTable;
331
332    fn par_sample_n(&self, n: usize) -> Self::Samples {
333        // Allocate the samples.
334        let mut samples: Array2<CatType> = Array::zeros((n, self.model.labels().len()));
335        // Allocate the weights.
336        let mut weights: Array1<f64> = Array::zeros(n);
337
338        // Get a mutable reference to the RNG.
339        let rng = self.rng.borrow_mut();
340        // Generate a random seed for each trajectory.
341        let seeds: Vec<_> = rng.random_iter().take(n).collect();
342        // Sample the trajectories in parallel.
343        seeds
344            .into_par_iter()
345            .zip(samples.axis_iter_mut(Axis(0)))
346            .zip(weights.axis_iter_mut(Axis(0)))
347            .for_each(|((seed, mut sample), mut weight)| {
348                // Create a new RNG with the seed.
349                let mut rng = R::seed_from_u64(seed);
350                // Create a new sampler with the RNG.
351                let sampler = ImportanceSampler::new(&mut rng, self.model, self.evidence);
352                // Sample a weighted sample.
353                let (s_i, w_i) = sampler.sample();
354                // Assign the sample.
355                sample.assign(&s_i);
356                // Assign the weight.
357                weight.fill(w_i);
358            });
359
360        // Construct the samples.
361        let samples = CatTable::new(self.model.states().clone(), samples);
362
363        // Return the weighted samples.
364        CatWtdTable::new(samples, weights)
365    }
366}
367
368impl<R: Rng + SeedableRng> ParBNSampler<GaussBN> for ImportanceSampler<'_, R, GaussBN, GaussEv> {
369    type Samples = GaussWtdTable;
370
371    fn par_sample_n(&self, n: usize) -> Self::Samples {
372        // Allocate the samples.
373        let mut samples: Array2<GaussType> = Array::zeros((n, self.model.labels().len()));
374        // Allocate the weights.
375        let mut weights: Array1<f64> = Array::zeros(n);
376
377        // Get a mutable reference to the RNG.
378        let rng = self.rng.borrow_mut();
379        // Generate a random seed for each trajectory.
380        let seeds: Vec<_> = rng.random_iter().take(n).collect();
381        // Sample the trajectories in parallel.
382        seeds
383            .into_par_iter()
384            .zip(samples.axis_iter_mut(Axis(0)))
385            .zip(weights.axis_iter_mut(Axis(0)))
386            .for_each(|((seed, mut sample), mut weight)| {
387                // Create a new RNG with the seed.
388                let mut rng = R::seed_from_u64(seed);
389                // Create a new sampler with the RNG.
390                let sampler = ImportanceSampler::new(&mut rng, self.model, self.evidence);
391                // Sample a weighted sample.
392                let (s_i, w_i) = sampler.sample();
393                // Assign the sample.
394                sample.assign(&s_i);
395                // Assign the weight.
396                weight.fill(w_i);
397            });
398
399        // Construct the samples.
400        let samples = GaussTable::new(self.model.labels().clone(), samples);
401
402        // Return the weighted samples.
403        GaussWtdTable::new(samples, weights)
404    }
405}
406
407impl<R: Rng> ImportanceSampler<'_, R, CatCTBN, CatTrjEv> {
408    /// Sample uncertain evidence.
409    fn sample_evidence<T: Rng>(&self, rng: &mut T) -> CatTrjEv {
410        // Get shortened variable type.
411        use CatTrjEvT as E;
412
413        // Sample the evidence for each variable.
414        let certain_evidence = self
415            .evidence
416            // Flatten the evidence.
417            .evidences()
418            .iter()
419            // Map (label, [evidence]) to (label, evidence) pairs.
420            .flatten()
421            .flat_map(|e| {
422                // Get the variable index, starting time, and ending time.
423                let (event, start_time, end_time) = (e.event(), e.start_time(), e.end_time());
424                // Sample the evidence.
425                let e = match e {
426                    E::UncertainPositiveInterval { p_states, .. } => {
427                        // Construct the sampler.
428                        let state = WeightedIndex::new(p_states).unwrap();
429                        // Sample the state.
430                        let state = state.sample(rng);
431                        // Return the sample.
432                        E::CertainPositiveInterval {
433                            event,
434                            state,
435                            start_time,
436                            end_time,
437                        }
438                    }
439                    E::UncertainNegativeInterval { p_not_states, .. } => {
440                        // Allocate the not states.
441                        let mut not_states: Set<_> = (0..p_not_states.len()).collect();
442                        // Repeat until only a subset of the not states are sampled.
443                        while not_states.len() == p_not_states.len() {
444                            // Sample the not states.
445                            not_states = p_not_states
446                                .indexed_iter()
447                                // For each (state, p_not_state) pair ...
448                                .filter_map(|(i, &p_i)| {
449                                    // ... with p_i probability, retain the state.
450                                    Some(i).filter(|_| rng.random_bool(p_i))
451                                })
452                                .collect();
453                        }
454                        // Return the sample and weight.
455                        E::CertainNegativeInterval {
456                            event,
457                            not_states,
458                            start_time,
459                            end_time,
460                        }
461                    }
462                    _ => e.clone(), // Due to evidence sampling.
463                };
464
465                // Return the certain evidence.
466                Some(e)
467            });
468
469        // Collect the certain evidence.
470        CatTrjEv::new(self.evidence.states().clone(), certain_evidence)
471    }
472
473    /// Sample transition time for variable X_i with state x_i.
474    fn sample_time<T: Rng>(
475        &self,
476        rng: &mut T,
477        evidence: &CatTrjEv,
478        event: &CatSample,
479        i: usize,
480        t: f64,
481    ) -> f64 {
482        // Get shortened variable type.
483        use CatTrjEvT as E;
484
485        // Get the evidence of the vertex.
486        let e_i = &evidence.evidences()[i];
487
488        // Check if there is certain positive evidence at this point in time.
489        let e = e_i.iter().find(|e| match e {
490            E::CertainPositiveInterval { .. } => e.contains(&t),
491            E::CertainNegativeInterval { .. } => false, // Due to state sampling.
492            _ => unreachable!(),                        // Due to evidence sampling.
493        });
494
495        // If there is certain positive evidence return the time until the end.
496        if let Some(e) = e {
497            return e.end_time() - t;
498        }
499
500        // Cast the state to usize.
501        let x = event[i] as usize;
502        // Get the CIM.
503        let cim_i = &self.model.cims()[i];
504        // Compute the index on the parents to condition on.
505        let pa_i = self.model.graph().parents(&set![i]);
506        let pa_i = pa_i.iter().map(|&z| event[z] as usize);
507        let pa_i = cim_i.conditioning_multi_index().ravel(pa_i);
508        // Get the distribution of the vertex.
509        let q_i_x = -cim_i.parameters()[[pa_i, x, x]];
510
511        // Find an upcoming evidence, if any.
512        let e = e_i.iter().find(|e| t < e.start_time());
513        // Check if there is conflict between current state and upcoming evidence.
514        let e = e.filter(|e| match e {
515            E::CertainPositiveInterval { state, .. } => *state != x,
516            E::CertainNegativeInterval { not_states, .. } => not_states.contains(&x),
517            _ => unreachable!(), // Due to evidence sampling.
518        });
519
520        // If there is a conflict ...
521        if let Some(e) = e {
522            // Get the time until the conflict.
523            let t_c = e.start_time() - t;
524            // Sample from a uniform distribution in the range [0, 1).
525            let u = rng.random_range(0.0..1.0);
526            // Sample from a truncated exponential distribution, where:
527            //  1. The lower bound is 0.
528            //  2. The upper bound is the time until the conflict.
529            //  3. The rate is the negative of the transition rate.
530            return -1. / q_i_x * f64::ln(1. - u * (1. - f64::exp(-q_i_x * t_c)));
531        }
532
533        // If there is no conflict, initialize the exponential distribution.
534        let exp_i_x = Exp::new(q_i_x).unwrap();
535        // Sample the transition time.
536        let t_i = exp_i_x.sample(rng);
537
538        // Find an upcoming evidence, if any.
539        let e = e_i.iter().find(|e| t < e.start_time());
540        // Check if there is compliance between the current state and upcoming evidence ...
541        let e = e.filter(|e| match e {
542            // ... for which starting time is greater than the sampled transition time.
543            E::CertainPositiveInterval { state, .. } => (t_i + t) > e.start_time() && *state == x,
544            E::CertainNegativeInterval { .. } => false, // Due to state sampling.
545            _ => unreachable!(),                        // Due to evidence sampling.
546        });
547
548        // If there is compliance ...
549        if let Some(e) = e {
550            // Get the time until the compliance.
551            return e.start_time() - t;
552        }
553
554        // Otherwise, return the transition time.
555        t_i
556    }
557
    /// Compute the weight factor contributed by the interval from `t_a` to `t_b`.
    ///
    /// One factor is computed for every variable `j` with current state `y`
    /// in `event`, using the exit rate `q_j_y` (the negated diagonal entry of
    /// its CIM given the current states of its parents):
    ///
    /// 1. If a certain positive evidence contains `t_a`, and the first
    ///    evidence starting after `t_a` either does not conflict with `y`
    ///    (or does not exist) or is a certain negative one, the factor is
    ///    `exp(-q_j_y * (t_b - t_a))`.
    /// 2. Otherwise, if the first evidence starting after `t_b` conflicts
    ///    with `y` and starts at `t_e`, the factor is
    ///    `1 - exp(-q_j_y * (t_e - t_a))` when `j == i`, and that quantity
    ///    divided by `1 - exp(-q_j_y * (t_e - t_b))` otherwise.
    /// 3. Otherwise, the factor is one.
    ///
    /// Non-finite factors are replaced by one, the others are clamped to the
    /// range `[0, 1]`, and the product of all the factors is returned.
    ///
    /// # Arguments
    ///
    /// * `evidence` - The evidence; only certain evidence is expected here,
    ///   any other variant hits `unreachable!()`.
    /// * `event` - The current state of every variable.
    /// * `i` - The index of the variable that transitions.
    /// * `t_a` - The first time point of the interval.
    /// * `t_b` - The second time point of the interval; the visible caller
    ///   passes `0.` and the earliest sampled transition time as
    ///   `t_a` and `t_b`.
    fn update_weight(
        &self,
        evidence: &CatTrjEv,
        event: &CatSample,
        i: usize,
        t_a: f64,
        t_b: f64,
    ) -> f64 {
        // Get shortened variable type.
        use CatTrjEvT as E;

        // For each (variable, state) pair of the event ...
        event
            .indexed_iter()
            .map(|(j, &y)| {
                // Get the evidence of the vertex.
                let e_j = &evidence.evidences()[j];

                // Cast the state to usize.
                let y = y as usize;
                // Get the CIM.
                let cim_j = &self.model.cims()[j];
                // Compute the index on the parents to condition on.
                let pa_j = self.model.graph().parents(&set![j]);
                let pa_j = pa_j.iter().map(|&z| event[z] as usize);
                let pa_j = cim_j.conditioning_multi_index().ravel(pa_j);
                // Get the exit rate of the current state.
                let q_j_y = -cim_j.parameters()[[pa_j, y, y]];

                // Check if there is certain positive evidence at this point in time.
                let e = e_j.iter().find(|e| match e {
                    E::CertainPositiveInterval { .. } => e.contains(&t_a),
                    E::CertainNegativeInterval { .. } => false, // Due to state sampling.
                    _ => unreachable!(),                        // Due to evidence sampling.
                });
                // Find an upcoming evidence, if any. NOTE: t_a < start_time .
                let e_next = e_j.iter().find(|e| t_a < e.start_time());
                // Check if there is a difference between current state and upcoming evidence.
                let e_next = e_next.filter(|e| match e {
                    E::CertainPositiveInterval { state, .. } => *state != y,
                    E::CertainNegativeInterval { not_states, .. } => not_states.contains(&y),
                    _ => unreachable!(), // Due to evidence sampling.
                });
                // Check if current state has been set to a certain positive evidence, and
                // the conflicting upcoming evidence is non-existent or a certain negative one.
                if let (
                    Some(E::CertainPositiveInterval { .. }),
                    None | Some(E::CertainNegativeInterval { .. }),
                ) = (e, e_next)
                {
                    return f64::exp(-q_j_y * (t_b - t_a));
                }

                // Find an upcoming evidence, if any. NOTE: t_b < start_time .
                let e = e_j.iter().find(|e| t_b < e.start_time());
                // Check if there is conflict between current state and upcoming evidence.
                let e = e.filter(|e| match e {
                    E::CertainPositiveInterval { state, .. } => *state != y,
                    E::CertainNegativeInterval { not_states, .. } => not_states.contains(&y),
                    _ => unreachable!(), // Due to evidence sampling.
                });
                // If there is a conflict ...
                if let Some(e) = e {
                    // Get starting time of the evidence.
                    let t_e = e.start_time();
                    // Check if the variable is the same as the one that transitioned.
                    return if i == j {
                        1. - f64::exp(-q_j_y * (t_e - t_a))
                    } else {
                        (1. - f64::exp(-q_j_y * (t_e - t_a))) / // .
                        (1. - f64::exp(-q_j_y * (t_e - t_b)))
                    };
                }

                // Otherwise, return one.
                1.
            })
            // Check numeric stability: non-finite factors become one,
            // the others are clamped to [0, 1].
            .map(|w| if !w.is_finite() { 1. } else { w.clamp(0., 1.) })
            // Multiply the factors together.
            .product()
    }
640}
641
642impl<R: Rng> CTBNSampler<CatCTBN> for ImportanceSampler<'_, R, CatCTBN, CatTrjEv> {
643    type Sample = CatWtdTrj;
644    type Samples = CatWtdTrjs;
645
    /// Sample a weighted trajectory bounded by its number of events only.
    ///
    /// Delegates to `sample_by_length_or_time`, passing `f64::MAX` as the
    /// maximum time.
    #[inline]
    fn sample_by_length(&self, max_length: usize) -> Self::Sample {
        // Delegate to generic function.
        self.sample_by_length_or_time(max_length, f64::MAX)
    }
651
    /// Sample a weighted trajectory bounded by its duration only.
    ///
    /// Delegates to `sample_by_length_or_time`, passing `usize::MAX` as the
    /// maximum length.
    #[inline]
    fn sample_by_time(&self, max_time: f64) -> Self::Sample {
        // Delegate to generic function.
        self.sample_by_length_or_time(usize::MAX, max_time)
    }
657
658    fn sample_by_length_or_time(&self, max_length: usize, max_time: f64) -> Self::Sample {
659        // Get shortened variable type.
660        use CatTrjEvT as E;
661
662        // Assert the model and the evidences have the same states.
663        // TODO: Move this assertion to the constructor.
664        assert_eq!(
665            self.model.states(),
666            self.evidence.states(),
667            "The model and the evidences must have the same states."
668        );
669        // Assert length is positive.
670        assert!(
671            max_length > 0,
672            "The maximum length of the trajectory must be strictly positive."
673        );
674        // Assert time is positive.
675        assert!(max_time > 0., "The maximum time must be positive.");
676
677        // Get a mutable reference to the RNG.
678        let mut rng = self.rng.borrow_mut();
679
680        // Allocate the trajectory components.
681        let mut sample_events = Vec::new();
682        let mut sample_times = Vec::new();
683
684        // Reduce the uncertain evidences to certain evidences.
685        let evidence = self.sample_evidence(&mut rng);
686
687        // Sample the initial states with given initial evidence.
688        let (mut event, mut weight) = {
689            // Get the initial state distribution.
690            let initial_d = self.model.initial_distribution();
691            // Get the initial evidence.
692            let initial_e = &evidence.initial_evidence();
693            // Initialize the sampler for the initial state.
694            let initial = ImportanceSampler::new(&mut rng, initial_d, initial_e);
695            // Sample the initial state.
696            initial.sample()
697        };
698
699        // Append the initial state to the trajectory.
700        sample_events.push(event.clone());
701        sample_times.push(0.);
702
703        // Sample the transition time.
704        let mut times: Array1<_> = (0..event.len())
705            .map(|i| self.sample_time(&mut rng, &evidence, &event, i, 0.))
706            .collect();
707
708        // Get the variable that transitions first.
709        let mut i = times.argmin().unwrap();
710        // Update the weight.
711        weight *= self.update_weight(&evidence, &event, i, 0., times[i]);
712        // Set global time.
713        let mut time = times[i];
714
715        // While:
716        //  1. the length of the trajectory is less than max_length, and ...
717        //  2. the time is less than max_time ...
718        while sample_events.len() < max_length && time < max_time {
719            // Get evidence of the vertex.
720            let e_i = &evidence.evidences()[i];
721
722            // Cast the state to usize.
723            let x = event[i] as usize;
724
725            // Check if there is evidence at this point in time.
726            let e = e_i.iter().find(|e| e.contains(&time));
727            // Check if there is certain evidence at this point in time.
728            if e.is_some_and(|e| match e {
729                E::CertainPositiveInterval { state, .. } => *state == x,
730                E::CertainNegativeInterval { not_states, .. } => !not_states.contains(&x),
731                _ => false,
732            }) {
733                // Sample the transition time.
734                times[i] = time + self.sample_time(&mut rng, &evidence, &event, i, time);
735            } else {
736                // Get the CIM.
737                let cim_i = &self.model.cims()[i];
738                // Compute the index on the parents to condition on.
739                let pa_i = self.model.graph().parents(&set![i]);
740                let pa_i = pa_i.iter().map(|&z| event[z] as usize);
741                let pa_i = cim_i.conditioning_multi_index().ravel(pa_i);
742                // Get the distribution of the vertex.
743                let mut q_i_zx = cim_i.parameters().slice(s![pa_i, x, ..]).to_owned();
744                // Set the diagonal element to zero.
745                q_i_zx[x] = 0.;
746                // Normalize the probabilities.
747                q_i_zx /= q_i_zx.sum();
748
749                // Check if there is evidence at this point in time.
750                let (s_i, w_i) = if e.is_some_and(|e| match e {
751                    E::CertainPositiveInterval { state, .. } => *state != x,
752                    _ => false,
753                }) {
754                    // Get the state of the certain positive interval.
755                    match e {
756                        Some(E::CertainPositiveInterval { state, .. }) => {
757                            (*state as CatType, q_i_zx[*state])
758                        }
759                        _ => unreachable!(), // Due to previous checks.
760                    }
761                } else {
762                    //
763                    match e {
764                        Some(E::CertainNegativeInterval { not_states, .. }) => {
765                            // Initialize the weight.
766                            let mut w_i = 1.;
767                            // Clone the distribution.
768                            let mut q_i_zx = q_i_zx.to_owned();
769                            // For each not state ...
770                            not_states.iter().for_each(|&j| {
771                                // Update the weight.
772                                w_i -= q_i_zx[j];
773                                // Zero out the not states.
774                                q_i_zx[j] = 0.;
775                            });
776                            // Normalize the probabilities.
777                            q_i_zx /= q_i_zx.sum();
778                            // Construct the sampler.
779                            let s_i = WeightedIndex::new(&q_i_zx).unwrap();
780                            // Sample the state.
781                            let s_i = s_i.sample(&mut rng) as CatType;
782                            // Return the sample and weight.
783                            (s_i, w_i)
784                        }
785                        None => {
786                            // Initialize a weighted index sampler.
787                            let s_i_zx = WeightedIndex::new(&q_i_zx).unwrap();
788                            // Sample the next event.
789                            let s_i = s_i_zx.sample(&mut rng) as CatType;
790                            // Return the sample and weight.
791                            (s_i, 1.)
792                        }
793                        _ => unreachable!(), // Due to previous checks.
794                    }
795                };
796
797                // Set the state.
798                event[i] = s_i;
799                // Update the weight.
800                weight *= w_i;
801
802                // Append the event to the trajectory.
803                sample_events.push(event.clone());
804                sample_times.push(time);
805                // Update the transition times for { X } U Ch(X).
806                std::iter::once(i)
807                    .chain(self.model.graph().children(&set![i]))
808                    .for_each(|j| {
809                        // Sample the transition time.
810                        times[j] = time + self.sample_time(&mut rng, &evidence, &event, j, time);
811                    });
812            }
813
814            // Add a small epsilon to avoid zero transition times.
815            times += EPSILON;
816            // Get the variable to transition first.
817            i = times.argmin().unwrap();
818            // Update the weight.
819            weight *= self.update_weight(&evidence, &event, i, time, times[i].min(max_time));
820            // Update the global time.
821            time = times[i];
822        }
823
824        // Get the states of the CIMs.
825        let states = self.model.states().clone();
826
827        // Convert the events to a 2D array.
828        let shape = (sample_events.len(), sample_events[0].len());
829        let sample_events = Array::from_iter(sample_events.into_iter().flatten())
830            .into_shape_with_order(shape)
831            .expect("Failed to convert events to 2D array.");
832        // Convert the times to a 1D array.
833        let sample_times = Array::from_iter(sample_times);
834
835        // Construct the trajectory.
836        let trajectory = CatTrj::new(states, sample_events, sample_times);
837
838        // Return the trajectory and its weight.
839        (trajectory, weight).into()
840    }
841
842    #[inline]
843    fn sample_n_by_length(&self, max_length: usize, n: usize) -> Self::Samples {
844        (0..n).map(|_| self.sample_by_length(max_length)).collect()
845    }
846
847    #[inline]
848    fn sample_n_by_time(&self, max_time: f64, n: usize) -> Self::Samples {
849        (0..n).map(|_| self.sample_by_time(max_time)).collect()
850    }
851
852    #[inline]
853    fn sample_n_by_length_or_time(
854        &self,
855        max_length: usize,
856        max_time: f64,
857        n: usize,
858    ) -> Self::Samples {
859        (0..n)
860            .map(|_| self.sample_by_length_or_time(max_length, max_time))
861            .collect()
862    }
863}
864
865impl<R: Rng + SeedableRng> ParCTBNSampler<CatCTBN> for ImportanceSampler<'_, R, CatCTBN, CatTrjEv> {
866    type Samples = CatWtdTrjs;
867
868    #[inline]
869    fn par_sample_n_by_length(&self, max_length: usize, n: usize) -> Self::Samples {
870        self.par_sample_n_by_length_or_time(max_length, f64::MAX, n)
871    }
872
873    #[inline]
874    fn par_sample_n_by_time(&self, max_time: f64, n: usize) -> Self::Samples {
875        self.par_sample_n_by_length_or_time(usize::MAX, max_time, n)
876    }
877
878    fn par_sample_n_by_length_or_time(
879        &self,
880        max_length: usize,
881        max_time: f64,
882        n: usize,
883    ) -> Self::Samples {
884        // Get a mutable reference to the RNG.
885        let rng = self.rng.borrow_mut();
886        // Generate a random seed for each trajectory.
887        let seeds: Vec<_> = rng.random_iter().take(n).collect();
888        // Sample the trajectories in parallel.
889        seeds
890            .into_par_iter()
891            .map(|seed| {
892                // Create a new random number generator with the seed.
893                let mut rng = R::seed_from_u64(seed);
894                // Create a new sampler with the random number generator and model.
895                let sampler = ImportanceSampler::new(&mut rng, self.model, self.evidence);
896                // Sample the trajectory.
897                sampler.sample_by_length_or_time(max_length, max_time)
898            })
899            .collect()
900    }
901}