causal_hub/samplers/importance.rs
1use std::cell::RefCell;
2
3use ndarray::prelude::*;
4use ndarray_stats::QuantileExt;
5use rand::{
6 Rng, SeedableRng,
7 distr::{Distribution, weighted::WeightedIndex},
8};
9use rand_distr::Exp;
10use rayon::prelude::*;
11
12use crate::{
13 datasets::{
14 CatEv, CatEvT, CatSample, CatTable, CatTrj, CatTrjEv, CatTrjEvT, CatType, CatWtdSample,
15 CatWtdTable, CatWtdTrj, CatWtdTrjs, GaussEv, GaussEvT, GaussTable, GaussType,
16 GaussWtdSample, GaussWtdTable,
17 },
18 models::{BN, CIM, CPD, CTBN, CatBN, CatCTBN, GaussBN, Labelled},
19 samplers::{BNSampler, CTBNSampler, ParBNSampler, ParCTBNSampler},
20 set,
21 types::{EPSILON, Set},
22};
23
/// A struct for sampling using importance sampling.
///
/// The sampler only borrows its random number generator, model and evidence
/// for the lifetime `'a`; it owns none of them.
#[derive(Debug)]
pub struct ImportanceSampler<'a, R, M, E> {
    // Mutable borrow of the RNG, wrapped in a `RefCell` so that it can be
    // advanced through a shared reference to the sampler.
    rng: RefCell<&'a mut R>,
    // The model to sample from.
    model: &'a M,
    // The evidence associated with the model.
    evidence: &'a E,
}
31
32impl<'a, R, M, E> ImportanceSampler<'a, R, M, E>
33where
34 M: Labelled,
35 E: Labelled,
36{
37 /// Construct a new importance sampler.
38 ///
39 /// # Arguments
40 ///
41 /// * `rng` - A random number generator.
42 /// * `model` - A reference to the model to sample from.
43 /// * `evidence` - A reference to the evidence to sample from.
44 ///
45 /// # Returns
46 ///
47 /// Return a new `ImportanceSampler` instance.
48 ///
49 #[inline]
50 pub fn new(rng: &'a mut R, model: &'a M, evidence: &'a E) -> Self {
51 // Wrap the RNG in a RefCell to allow interior mutability.
52 let rng = RefCell::new(rng);
53
54 // Assert the model and the evidences have the same labels.
55 assert_eq!(
56 model.labels(),
57 evidence.labels(),
58 "The model and the evidences must have the same variables."
59 );
60
61 Self {
62 rng,
63 model,
64 evidence,
65 }
66 }
67}
68
69impl<R: Rng> ImportanceSampler<'_, R, CatBN, CatEv> {
70 /// Sample uncertain evidence.
71 fn sample_evidence<T: Rng>(&self, rng: &mut T) -> CatEv {
72 // Get shortened variable type.
73 use CatEvT as E;
74
75 // Sample the evidence for each variable.
76 let certain_evidence = self
77 .evidence
78 // Flatten the evidence.
79 .evidences()
80 .iter()
81 // Filter empty evidences.
82 .filter_map(|e| {
83 e.as_ref().map(|e| {
84 // Get the event index.
85 let event = e.event();
86 // Sample the evidence.
87 match e {
88 E::UncertainPositive { p_states, .. } => {
89 // Construct the sampler.
90 let state = WeightedIndex::new(p_states).unwrap();
91 // Sample the state.
92 let state = state.sample(rng);
93 // Return the sample.
94 E::CertainPositive { event, state }
95 }
96 E::UncertainNegative { p_not_states, .. } => {
97 // Allocate the not states.
98 let mut not_states: Set<_> = (0..p_not_states.len()).collect();
99 // Repeat until only a subset of the not states are sampled.
100 while not_states.len() == p_not_states.len() {
101 // Sample the not states.
102 not_states = p_not_states
103 .indexed_iter()
104 // For each (state, p_not_state) pair ...
105 .filter_map(|(i, &p_i)| {
106 // ... with p_i probability, retain the state.
107 Some(i).filter(|_| rng.random_bool(p_i))
108 })
109 .collect();
110 }
111 // Return the sample and weight.
112 E::CertainNegative { event, not_states }
113 }
114 _ => e.clone(), // Due to evidence sampling.
115 }
116 })
117 });
118
119 // Collect the certain evidence.
120 CatEv::new(self.evidence.states().clone(), certain_evidence)
121 }
122}
123
124impl<R: Rng> BNSampler<CatBN> for ImportanceSampler<'_, R, CatBN, CatEv> {
125 type Sample = CatWtdSample;
126 type Samples = CatWtdTable;
127
128 fn sample(&self) -> Self::Sample {
129 // Get shortened variable type.
130 use CatEvT as E;
131
132 // Assert the model and the evidences have the same states.
133 // TODO: Move this assertion to the constructor.
134 assert_eq!(
135 self.model.states(),
136 self.evidence.states(),
137 "The model and the evidences must have the same states."
138 );
139
140 // Get a mutable reference to the RNG.
141 let mut rng = self.rng.borrow_mut();
142 // Allocate the sample.
143 let mut sample = Array::zeros(self.model.labels().len());
144 // Initialize the weight.
145 let mut weight = 1.;
146
147 // Reduce the uncertain evidences to certain evidences.
148 let evidence = self.sample_evidence(&mut rng);
149
150 // For each vertex in the topological order ...
151 self.model.topological_order().iter().for_each(|&i| {
152 // Get the evidence of the vertex.
153 let e_i = &evidence.evidences()[i];
154
155 // Get the CPD.
156 let cpd_i = &self.model.cpds()[i];
157 // Compute the index on the parents to condition on.
158 let pa_i = self.model.graph().parents(&set![i]);
159 let pa_i = pa_i.iter().map(|&z| sample[z] as usize);
160 let pa_i = cpd_i.conditioning_multi_index().ravel(pa_i);
161 // Get the distribution of the vertex.
162 let p_i = cpd_i.parameters().row(pa_i);
163
164 // Get the evidence of the vertex.
165 let (s_i, w_i) = match e_i {
166 // If there is evidence, sample from the constrained distribution.
167 Some(e_i) => match e_i {
168 E::CertainPositive { state, .. } => {
169 // Get the state.
170 let s_i = *state as CatType;
171 // Return the state and its weight.
172 (s_i, p_i[*state])
173 }
174 E::CertainNegative { not_states, .. } => {
175 // Initialize the weight.
176 let mut w_i = 1.;
177 // Clone the distribution.
178 let mut p_i = p_i.to_owned();
179 // For each not state ...
180 not_states.iter().for_each(|&j| {
181 // Update the weight.
182 w_i -= p_i[j];
183 // Zero out the not states.
184 p_i[j] = 0.;
185 });
186 // Normalize the probabilities.
187 p_i /= p_i.sum();
188 // Construct the sampler.
189 let s_i = WeightedIndex::new(&p_i).unwrap();
190 // Sample the state.
191 let s_i = s_i.sample(&mut rng) as CatType;
192 // Return the sample and weight.
193 (s_i, w_i)
194 }
195 _ => unreachable!(), // Due to evidence sampling.
196 },
197 // If there is no evidence, sample as usual.
198 None => {
199 // Construct the sampler.
200 let s_i = WeightedIndex::new(&p_i).unwrap();
201 // Sample the state.
202 let s_i = s_i.sample(&mut rng) as CatType;
203 // Return the sample and weight.
204 (s_i, 1.)
205 }
206 };
207
208 // Sample from the distribution.
209 sample[i] = s_i;
210 // Update the weight.
211 weight *= w_i;
212 });
213
214 (sample, weight)
215 }
216
217 fn sample_n(&self, n: usize) -> Self::Samples {
218 // Allocate the samples.
219 let mut samples = Array2::zeros((n, self.model.labels().len()));
220 // Allocate the weights.
221 let mut weights = Array1::zeros(n);
222
223 // Sample the weighted samples.
224 samples
225 .rows_mut()
226 .into_iter()
227 .zip(weights.iter_mut())
228 .for_each(|(mut sample, weight)| {
229 // Sample a weighted sample.
230 let (s_i, w_i) = self.sample();
231 // Assign the sample.
232 sample.assign(&s_i);
233 // Assign the weight.
234 *weight = w_i;
235 });
236
237 // Construct the samples.
238 let samples = CatTable::new(self.model.states().clone(), samples);
239
240 // Return the weighted samples.
241 CatWtdTable::new(samples, weights)
242 }
243}
244
245impl<R: Rng> BNSampler<GaussBN> for ImportanceSampler<'_, R, GaussBN, GaussEv> {
246 type Sample = GaussWtdSample;
247 type Samples = GaussWtdTable;
248
249 fn sample(&self) -> Self::Sample {
250 // Get shortened variable type.
251 use GaussEvT as E;
252
253 // Get a mutable reference to the RNG.
254 let mut rng = self.rng.borrow_mut();
255 // Allocate the sample.
256 let mut sample = Array::zeros(self.model.labels().len());
257 // Initialize the weight.
258 let mut weight = 1.;
259
260 // For each vertex in the topological order ...
261 self.model.topological_order().iter().for_each(|&i| {
262 // Get the evidence of the vertex.
263 let e_i = &self.evidence.evidences()[i];
264
265 // Get the CPD.
266 let cpd_i = &self.model.cpds()[i];
267 // Compute the index on the parents to condition on.
268 let pa_i = self.model.graph().parents(&set![i]);
269 let pa_i = pa_i.iter().map(|&z| sample[z]).collect();
270
271 // Get the evidence of the vertex.
272 let (s_i, w_i) = match e_i {
273 // If there is evidence, sample from the constrained distribution.
274 Some(e_i) => match e_i {
275 E::CertainPositive { value, .. } => {
276 // Get the state.
277 let s_i = *value;
278 // Get the probability.
279 let p_i = cpd_i.pf(&array![s_i], &pa_i);
280 // Return the state and its weight.
281 (s_i, p_i)
282 }
283 },
284 // If there is no evidence, sample as usual.
285 None => {
286 // Sample from the distribution.
287 let s_i = cpd_i.sample(&mut rng, &pa_i)[0];
288 // Return the sample and weight.
289 (s_i, 1.)
290 }
291 };
292
293 // Sample from the distribution.
294 sample[i] = s_i;
295 // Update the weight.
296 weight *= w_i;
297 });
298
299 (sample, weight)
300 }
301 fn sample_n(&self, n: usize) -> Self::Samples {
302 // Allocate the samples.
303 let mut samples = Array2::zeros((n, self.model.labels().len()));
304 // Allocate the weights.
305 let mut weights = Array1::zeros(n);
306
307 // Sample the weighted samples.
308 samples
309 .rows_mut()
310 .into_iter()
311 .zip(weights.iter_mut())
312 .for_each(|(mut sample, weight)| {
313 // Sample a weighted sample.
314 let (s_i, w_i) = self.sample();
315 // Assign the sample.
316 sample.assign(&s_i);
317 // Assign the weight.
318 *weight = w_i;
319 });
320
321 // Construct the samples.
322 let samples = GaussTable::new(self.model.labels().clone(), samples);
323
324 // Return the weighted samples.
325 GaussWtdTable::new(samples, weights)
326 }
327}
328
329impl<R: Rng + SeedableRng> ParBNSampler<CatBN> for ImportanceSampler<'_, R, CatBN, CatEv> {
330 type Samples = CatWtdTable;
331
332 fn par_sample_n(&self, n: usize) -> Self::Samples {
333 // Allocate the samples.
334 let mut samples: Array2<CatType> = Array::zeros((n, self.model.labels().len()));
335 // Allocate the weights.
336 let mut weights: Array1<f64> = Array::zeros(n);
337
338 // Get a mutable reference to the RNG.
339 let rng = self.rng.borrow_mut();
340 // Generate a random seed for each trajectory.
341 let seeds: Vec<_> = rng.random_iter().take(n).collect();
342 // Sample the trajectories in parallel.
343 seeds
344 .into_par_iter()
345 .zip(samples.axis_iter_mut(Axis(0)))
346 .zip(weights.axis_iter_mut(Axis(0)))
347 .for_each(|((seed, mut sample), mut weight)| {
348 // Create a new RNG with the seed.
349 let mut rng = R::seed_from_u64(seed);
350 // Create a new sampler with the RNG.
351 let sampler = ImportanceSampler::new(&mut rng, self.model, self.evidence);
352 // Sample a weighted sample.
353 let (s_i, w_i) = sampler.sample();
354 // Assign the sample.
355 sample.assign(&s_i);
356 // Assign the weight.
357 weight.fill(w_i);
358 });
359
360 // Construct the samples.
361 let samples = CatTable::new(self.model.states().clone(), samples);
362
363 // Return the weighted samples.
364 CatWtdTable::new(samples, weights)
365 }
366}
367
368impl<R: Rng + SeedableRng> ParBNSampler<GaussBN> for ImportanceSampler<'_, R, GaussBN, GaussEv> {
369 type Samples = GaussWtdTable;
370
371 fn par_sample_n(&self, n: usize) -> Self::Samples {
372 // Allocate the samples.
373 let mut samples: Array2<GaussType> = Array::zeros((n, self.model.labels().len()));
374 // Allocate the weights.
375 let mut weights: Array1<f64> = Array::zeros(n);
376
377 // Get a mutable reference to the RNG.
378 let rng = self.rng.borrow_mut();
379 // Generate a random seed for each trajectory.
380 let seeds: Vec<_> = rng.random_iter().take(n).collect();
381 // Sample the trajectories in parallel.
382 seeds
383 .into_par_iter()
384 .zip(samples.axis_iter_mut(Axis(0)))
385 .zip(weights.axis_iter_mut(Axis(0)))
386 .for_each(|((seed, mut sample), mut weight)| {
387 // Create a new RNG with the seed.
388 let mut rng = R::seed_from_u64(seed);
389 // Create a new sampler with the RNG.
390 let sampler = ImportanceSampler::new(&mut rng, self.model, self.evidence);
391 // Sample a weighted sample.
392 let (s_i, w_i) = sampler.sample();
393 // Assign the sample.
394 sample.assign(&s_i);
395 // Assign the weight.
396 weight.fill(w_i);
397 });
398
399 // Construct the samples.
400 let samples = GaussTable::new(self.model.labels().clone(), samples);
401
402 // Return the weighted samples.
403 GaussWtdTable::new(samples, weights)
404 }
405}
406
impl<R: Rng> ImportanceSampler<'_, R, CatCTBN, CatTrjEv> {
    /// Sample uncertain evidence.
    ///
    /// Every uncertain interval evidence is replaced by a certain one over the
    /// same variable and the same time interval:
    ///
    /// * uncertain positive: the state is drawn with weights `p_states`;
    /// * uncertain negative: each state is excluded independently with its
    ///   `p_not_states` probability, redrawing while all states are excluded.
    ///
    /// Any other evidence is cloned unchanged.
    fn sample_evidence<T: Rng>(&self, rng: &mut T) -> CatTrjEv {
        // Get shortened variable type.
        use CatTrjEvT as E;

        // Sample the evidence for each variable.
        let certain_evidence = self
            .evidence
            // Flatten the evidence.
            .evidences()
            .iter()
            // Map (label, [evidence]) to (label, evidence) pairs.
            .flatten()
            .flat_map(|e| {
                // Get the variable index, starting time, and ending time.
                let (event, start_time, end_time) = (e.event(), e.start_time(), e.end_time());
                // Sample the evidence.
                let e = match e {
                    E::UncertainPositiveInterval { p_states, .. } => {
                        // Construct the sampler.
                        let state = WeightedIndex::new(p_states).unwrap();
                        // Sample the state.
                        let state = state.sample(rng);
                        // Return the sample.
                        E::CertainPositiveInterval {
                            event,
                            state,
                            start_time,
                            end_time,
                        }
                    }
                    E::UncertainNegativeInterval { p_not_states, .. } => {
                        // Allocate the not states, starting from the full set so
                        // that the loop below runs at least once.
                        let mut not_states: Set<_> = (0..p_not_states.len()).collect();
                        // Repeat until only a subset of the not states are sampled.
                        while not_states.len() == p_not_states.len() {
                            // Sample the not states.
                            not_states = p_not_states
                                .indexed_iter()
                                // For each (state, p_not_state) pair ...
                                .filter_map(|(i, &p_i)| {
                                    // ... with p_i probability, retain the state.
                                    Some(i).filter(|_| rng.random_bool(p_i))
                                })
                                .collect();
                        }
                        // Return the sampled certain negative evidence.
                        E::CertainNegativeInterval {
                            event,
                            not_states,
                            start_time,
                            end_time,
                        }
                    }
                    _ => e.clone(), // Already certain, nothing to sample.
                };

                // Return the certain evidence.
                Some(e)
            });

        // Collect the certain evidence.
        CatTrjEv::new(self.evidence.states().clone(), certain_evidence)
    }

    /// Sample transition time for variable X_i with state x_i.
    ///
    /// # Arguments
    ///
    /// * `rng` - The random number generator.
    /// * `evidence` - The evidence, expected to contain certain evidence only.
    /// * `event` - The current state of every variable.
    /// * `i` - The index of the variable.
    /// * `t` - The current time.
    ///
    /// # Returns
    ///
    /// The time to wait from `t`, i.e. a duration and not an absolute time:
    ///
    /// 1. the time left until the end of a certain positive evidence containing `t`, or
    /// 2. a draw from an exponential truncated at the start of the first
    ///    upcoming evidence, when that evidence conflicts with the current state, or
    /// 3. the time until the start of the first upcoming certain positive
    ///    evidence agreeing with the current state, when the exponential draw
    ///    would go past it, or
    /// 4. the exponential draw itself.
    ///
    /// # Panics
    ///
    /// Reaches `unreachable!()` if an uncertain evidence is inspected.
    fn sample_time<T: Rng>(
        &self,
        rng: &mut T,
        evidence: &CatTrjEv,
        event: &CatSample,
        i: usize,
        t: f64,
    ) -> f64 {
        // Get shortened variable type.
        use CatTrjEvT as E;

        // Get the evidence of the vertex.
        let e_i = &evidence.evidences()[i];

        // Check if there is certain positive evidence at this point in time.
        let e = e_i.iter().find(|e| match e {
            E::CertainPositiveInterval { .. } => e.contains(&t),
            E::CertainNegativeInterval { .. } => false, // Due to state sampling.
            _ => unreachable!(),                        // Due to evidence sampling.
        });

        // If there is certain positive evidence return the time until the end.
        if let Some(e) = e {
            return e.end_time() - t;
        }

        // Cast the state to usize.
        let x = event[i] as usize;
        // Get the CIM.
        let cim_i = &self.model.cims()[i];
        // Compute the index on the parents to condition on.
        let pa_i = self.model.graph().parents(&set![i]);
        let pa_i = pa_i.iter().map(|&z| event[z] as usize);
        let pa_i = cim_i.conditioning_multi_index().ravel(pa_i);
        // Get the rate of leaving state x: the negated diagonal entry of the CIM.
        let q_i_x = -cim_i.parameters()[[pa_i, x, x]];

        // Find an upcoming evidence, if any.
        // NOTE(review): this takes the first evidence in iteration order with
        // `t < start_time`, so it presumably relies on the evidences being
        // ordered by starting time - confirm.
        let e = e_i.iter().find(|e| t < e.start_time());
        // Check if there is conflict between current state and upcoming evidence.
        let e = e.filter(|e| match e {
            E::CertainPositiveInterval { state, .. } => *state != x,
            E::CertainNegativeInterval { not_states, .. } => not_states.contains(&x),
            _ => unreachable!(), // Due to evidence sampling.
        });

        // If there is a conflict ...
        if let Some(e) = e {
            // Get the time until the conflict.
            let t_c = e.start_time() - t;
            // Sample from a uniform distribution in the range [0, 1).
            let u = rng.random_range(0.0..1.0);
            // Sample from a truncated exponential distribution, where:
            //  1. The lower bound is 0.
            //  2. The upper bound is the time until the conflict.
            //  3. The rate is the negative of the transition rate.
            // NOTE(review): with `q_i_x == 0` this evaluates to
            // `-inf * ln(1) = NaN` - confirm the diagonal is never zero here.
            return -1. / q_i_x * f64::ln(1. - u * (1. - f64::exp(-q_i_x * t_c)));
        }

        // If there is no conflict, initialize the exponential distribution.
        let exp_i_x = Exp::new(q_i_x).unwrap();
        // Sample the transition time.
        let t_i = exp_i_x.sample(rng);

        // Find an upcoming evidence, if any (same lookup as above).
        let e = e_i.iter().find(|e| t < e.start_time());
        // Check if there is compliance between the current state and upcoming evidence ...
        let e = e.filter(|e| match e {
            // ... for which starting time is greater than the sampled transition time.
            E::CertainPositiveInterval { state, .. } => (t_i + t) > e.start_time() && *state == x,
            E::CertainNegativeInterval { .. } => false, // Due to state sampling.
            _ => unreachable!(),                        // Due to evidence sampling.
        });

        // If there is compliance ...
        if let Some(e) = e {
            // Get the time until the compliance.
            return e.start_time() - t;
        }

        // Otherwise, return the transition time.
        t_i
    }

    /// Compute the weight update for the time interval `[t_a, t_b]`.
    ///
    /// # Arguments
    ///
    /// * `evidence` - The evidence, expected to contain certain evidence only.
    /// * `event` - The state of every variable during the interval.
    /// * `i` - The index of the variable selected to transition next.
    /// * `t_a` - The start of the interval.
    /// * `t_b` - The end of the interval.
    ///
    /// # Returns
    ///
    /// The product over all variables of a per-variable factor. A factor that
    /// is not finite is replaced by one; any other factor is clamped to `[0, 1]`.
    ///
    /// # Panics
    ///
    /// Reaches `unreachable!()` if an uncertain evidence is inspected.
    fn update_weight(
        &self,
        evidence: &CatTrjEv,
        event: &CatSample,
        i: usize,
        t_a: f64,
        t_b: f64,
    ) -> f64 {
        // Get shortened variable type.
        use CatTrjEvT as E;

        // For each (variable, state) pair ...
        event
            .indexed_iter()
            .map(|(j, &y)| {
                // Get the evidence of the vertex.
                let e_j = &evidence.evidences()[j];

                // Cast the state to usize.
                let y = y as usize;
                // Get the CIM.
                let cim_j = &self.model.cims()[j];
                // Compute the index on the parents to condition on.
                let pa_j = self.model.graph().parents(&set![j]);
                let pa_j = pa_j.iter().map(|&z| event[z] as usize);
                let pa_j = cim_j.conditioning_multi_index().ravel(pa_j);
                // Get the rate of leaving state y: the negated diagonal entry of the CIM.
                let q_j_y = -cim_j.parameters()[[pa_j, y, y]];

                // Check if there is certain positive evidence at this point in time.
                let e = e_j.iter().find(|e| match e {
                    E::CertainPositiveInterval { .. } => e.contains(&t_a),
                    E::CertainNegativeInterval { .. } => false, // Due to state sampling.
                    _ => unreachable!(),                        // Due to evidence sampling.
                });
                // Find an upcoming evidence, if any. NOTE: t_a < start_time .
                let e_next = e_j.iter().find(|e| t_a < e.start_time());
                // Check if there is a difference between current state and upcoming evidence.
                let e_next = e_next.filter(|e| match e {
                    E::CertainPositiveInterval { state, .. } => *state != y,
                    E::CertainNegativeInterval { not_states, .. } => not_states.contains(&y),
                    _ => unreachable!(), // Due to evidence sampling.
                });
                // Check if current state has been set to a certain positive evidence, or
                // if the upcoming evidence is non-existent or set given a certain negative evidence.
                if let (
                    Some(E::CertainPositiveInterval { .. }),
                    None | Some(E::CertainNegativeInterval { .. }),
                ) = (e, e_next)
                {
                    // Factor for staying in state y over the whole interval.
                    return f64::exp(-q_j_y * (t_b - t_a));
                }

                // Find an upcoming evidence, if any. NOTE: t_b < start_time .
                let e = e_j.iter().find(|e| t_b < e.start_time());
                // Check if there is conflict between current state and upcoming evidence.
                let e = e.filter(|e| match e {
                    E::CertainPositiveInterval { state, .. } => *state != y,
                    E::CertainNegativeInterval { not_states, .. } => not_states.contains(&y),
                    _ => unreachable!(), // Due to evidence sampling.
                });
                // If there is a conflict ...
                if let Some(e) = e {
                    // Get starting time of the evidence.
                    let t_e = e.start_time();
                    // Check if the variable is the same as the one that transitioned.
                    return if i == j {
                        1. - f64::exp(-q_j_y * (t_e - t_a))
                    } else {
                        (1. - f64::exp(-q_j_y * (t_e - t_a))) / // .
                        (1. - f64::exp(-q_j_y * (t_e - t_b)))
                    };
                }

                // Otherwise, return one.
                1.
            })
            // Check numeric stability: non-finite factors become one, the
            // others are clamped to [0, 1].
            .map(|w| if !w.is_finite() { 1. } else { w.clamp(0., 1.) })
            // Multiply the per-variable factors together.
            .product()
    }
}
641
642impl<R: Rng> CTBNSampler<CatCTBN> for ImportanceSampler<'_, R, CatCTBN, CatTrjEv> {
643 type Sample = CatWtdTrj;
644 type Samples = CatWtdTrjs;
645
646 #[inline]
647 fn sample_by_length(&self, max_length: usize) -> Self::Sample {
648 // Delegate to generic function.
649 self.sample_by_length_or_time(max_length, f64::MAX)
650 }
651
652 #[inline]
653 fn sample_by_time(&self, max_time: f64) -> Self::Sample {
654 // Delegate to generic function.
655 self.sample_by_length_or_time(usize::MAX, max_time)
656 }
657
658 fn sample_by_length_or_time(&self, max_length: usize, max_time: f64) -> Self::Sample {
659 // Get shortened variable type.
660 use CatTrjEvT as E;
661
662 // Assert the model and the evidences have the same states.
663 // TODO: Move this assertion to the constructor.
664 assert_eq!(
665 self.model.states(),
666 self.evidence.states(),
667 "The model and the evidences must have the same states."
668 );
669 // Assert length is positive.
670 assert!(
671 max_length > 0,
672 "The maximum length of the trajectory must be strictly positive."
673 );
674 // Assert time is positive.
675 assert!(max_time > 0., "The maximum time must be positive.");
676
677 // Get a mutable reference to the RNG.
678 let mut rng = self.rng.borrow_mut();
679
680 // Allocate the trajectory components.
681 let mut sample_events = Vec::new();
682 let mut sample_times = Vec::new();
683
684 // Reduce the uncertain evidences to certain evidences.
685 let evidence = self.sample_evidence(&mut rng);
686
687 // Sample the initial states with given initial evidence.
688 let (mut event, mut weight) = {
689 // Get the initial state distribution.
690 let initial_d = self.model.initial_distribution();
691 // Get the initial evidence.
692 let initial_e = &evidence.initial_evidence();
693 // Initialize the sampler for the initial state.
694 let initial = ImportanceSampler::new(&mut rng, initial_d, initial_e);
695 // Sample the initial state.
696 initial.sample()
697 };
698
699 // Append the initial state to the trajectory.
700 sample_events.push(event.clone());
701 sample_times.push(0.);
702
703 // Sample the transition time.
704 let mut times: Array1<_> = (0..event.len())
705 .map(|i| self.sample_time(&mut rng, &evidence, &event, i, 0.))
706 .collect();
707
708 // Get the variable that transitions first.
709 let mut i = times.argmin().unwrap();
710 // Update the weight.
711 weight *= self.update_weight(&evidence, &event, i, 0., times[i]);
712 // Set global time.
713 let mut time = times[i];
714
715 // While:
716 // 1. the length of the trajectory is less than max_length, and ...
717 // 2. the time is less than max_time ...
718 while sample_events.len() < max_length && time < max_time {
719 // Get evidence of the vertex.
720 let e_i = &evidence.evidences()[i];
721
722 // Cast the state to usize.
723 let x = event[i] as usize;
724
725 // Check if there is evidence at this point in time.
726 let e = e_i.iter().find(|e| e.contains(&time));
727 // Check if there is certain evidence at this point in time.
728 if e.is_some_and(|e| match e {
729 E::CertainPositiveInterval { state, .. } => *state == x,
730 E::CertainNegativeInterval { not_states, .. } => !not_states.contains(&x),
731 _ => false,
732 }) {
733 // Sample the transition time.
734 times[i] = time + self.sample_time(&mut rng, &evidence, &event, i, time);
735 } else {
736 // Get the CIM.
737 let cim_i = &self.model.cims()[i];
738 // Compute the index on the parents to condition on.
739 let pa_i = self.model.graph().parents(&set![i]);
740 let pa_i = pa_i.iter().map(|&z| event[z] as usize);
741 let pa_i = cim_i.conditioning_multi_index().ravel(pa_i);
742 // Get the distribution of the vertex.
743 let mut q_i_zx = cim_i.parameters().slice(s![pa_i, x, ..]).to_owned();
744 // Set the diagonal element to zero.
745 q_i_zx[x] = 0.;
746 // Normalize the probabilities.
747 q_i_zx /= q_i_zx.sum();
748
749 // Check if there is evidence at this point in time.
750 let (s_i, w_i) = if e.is_some_and(|e| match e {
751 E::CertainPositiveInterval { state, .. } => *state != x,
752 _ => false,
753 }) {
754 // Get the state of the certain positive interval.
755 match e {
756 Some(E::CertainPositiveInterval { state, .. }) => {
757 (*state as CatType, q_i_zx[*state])
758 }
759 _ => unreachable!(), // Due to previous checks.
760 }
761 } else {
762 //
763 match e {
764 Some(E::CertainNegativeInterval { not_states, .. }) => {
765 // Initialize the weight.
766 let mut w_i = 1.;
767 // Clone the distribution.
768 let mut q_i_zx = q_i_zx.to_owned();
769 // For each not state ...
770 not_states.iter().for_each(|&j| {
771 // Update the weight.
772 w_i -= q_i_zx[j];
773 // Zero out the not states.
774 q_i_zx[j] = 0.;
775 });
776 // Normalize the probabilities.
777 q_i_zx /= q_i_zx.sum();
778 // Construct the sampler.
779 let s_i = WeightedIndex::new(&q_i_zx).unwrap();
780 // Sample the state.
781 let s_i = s_i.sample(&mut rng) as CatType;
782 // Return the sample and weight.
783 (s_i, w_i)
784 }
785 None => {
786 // Initialize a weighted index sampler.
787 let s_i_zx = WeightedIndex::new(&q_i_zx).unwrap();
788 // Sample the next event.
789 let s_i = s_i_zx.sample(&mut rng) as CatType;
790 // Return the sample and weight.
791 (s_i, 1.)
792 }
793 _ => unreachable!(), // Due to previous checks.
794 }
795 };
796
797 // Set the state.
798 event[i] = s_i;
799 // Update the weight.
800 weight *= w_i;
801
802 // Append the event to the trajectory.
803 sample_events.push(event.clone());
804 sample_times.push(time);
805 // Update the transition times for { X } U Ch(X).
806 std::iter::once(i)
807 .chain(self.model.graph().children(&set![i]))
808 .for_each(|j| {
809 // Sample the transition time.
810 times[j] = time + self.sample_time(&mut rng, &evidence, &event, j, time);
811 });
812 }
813
814 // Add a small epsilon to avoid zero transition times.
815 times += EPSILON;
816 // Get the variable to transition first.
817 i = times.argmin().unwrap();
818 // Update the weight.
819 weight *= self.update_weight(&evidence, &event, i, time, times[i].min(max_time));
820 // Update the global time.
821 time = times[i];
822 }
823
824 // Get the states of the CIMs.
825 let states = self.model.states().clone();
826
827 // Convert the events to a 2D array.
828 let shape = (sample_events.len(), sample_events[0].len());
829 let sample_events = Array::from_iter(sample_events.into_iter().flatten())
830 .into_shape_with_order(shape)
831 .expect("Failed to convert events to 2D array.");
832 // Convert the times to a 1D array.
833 let sample_times = Array::from_iter(sample_times);
834
835 // Construct the trajectory.
836 let trajectory = CatTrj::new(states, sample_events, sample_times);
837
838 // Return the trajectory and its weight.
839 (trajectory, weight).into()
840 }
841
842 #[inline]
843 fn sample_n_by_length(&self, max_length: usize, n: usize) -> Self::Samples {
844 (0..n).map(|_| self.sample_by_length(max_length)).collect()
845 }
846
847 #[inline]
848 fn sample_n_by_time(&self, max_time: f64, n: usize) -> Self::Samples {
849 (0..n).map(|_| self.sample_by_time(max_time)).collect()
850 }
851
852 #[inline]
853 fn sample_n_by_length_or_time(
854 &self,
855 max_length: usize,
856 max_time: f64,
857 n: usize,
858 ) -> Self::Samples {
859 (0..n)
860 .map(|_| self.sample_by_length_or_time(max_length, max_time))
861 .collect()
862 }
863}
864
865impl<R: Rng + SeedableRng> ParCTBNSampler<CatCTBN> for ImportanceSampler<'_, R, CatCTBN, CatTrjEv> {
866 type Samples = CatWtdTrjs;
867
868 #[inline]
869 fn par_sample_n_by_length(&self, max_length: usize, n: usize) -> Self::Samples {
870 self.par_sample_n_by_length_or_time(max_length, f64::MAX, n)
871 }
872
873 #[inline]
874 fn par_sample_n_by_time(&self, max_time: f64, n: usize) -> Self::Samples {
875 self.par_sample_n_by_length_or_time(usize::MAX, max_time, n)
876 }
877
878 fn par_sample_n_by_length_or_time(
879 &self,
880 max_length: usize,
881 max_time: f64,
882 n: usize,
883 ) -> Self::Samples {
884 // Get a mutable reference to the RNG.
885 let rng = self.rng.borrow_mut();
886 // Generate a random seed for each trajectory.
887 let seeds: Vec<_> = rng.random_iter().take(n).collect();
888 // Sample the trajectories in parallel.
889 seeds
890 .into_par_iter()
891 .map(|seed| {
892 // Create a new random number generator with the seed.
893 let mut rng = R::seed_from_u64(seed);
894 // Create a new sampler with the random number generator and model.
895 let sampler = ImportanceSampler::new(&mut rng, self.model, self.evidence);
896 // Sample the trajectory.
897 sampler.sample_by_length_or_time(max_length, max_time)
898 })
899 .collect()
900 }
901}