1use core::f64;
2use std::cell::RefCell;
3
4use ndarray::prelude::*;
5use ndarray_stats::QuantileExt;
6use rand::{
7 Rng, SeedableRng,
8 distr::{Distribution, weighted::WeightedIndex},
9};
10use rand_distr::Exp;
11use rayon::prelude::*;
12
13use crate::{
14 datasets::{CatSample, CatTable, CatTrj, CatType, GaussTable},
15 models::{BN, CIM, CPD, CTBN, CatBN, CatCTBN, GaussBN, Labelled},
16 samplers::{BNSampler, CTBNSampler, ParBNSampler, ParCTBNSampler},
17 set,
18 types::EPSILON,
19};
20
/// Forward sampler for the models in this module.
///
/// Bayesian networks are sampled by visiting variables in topological order; continuous-time
/// Bayesian networks are simulated forward in time from an initial state.
///
/// The random number generator is wrapped in a [`RefCell`] so that the sampling methods,
/// which only take `&self`, can still draw from it mutably.
#[derive(Debug)]
pub struct ForwardSampler<'a, R, M> {
    /// Random number generator, mutably borrowed for the sampler's whole lifetime.
    rng: RefCell<&'a mut R>,
    /// Model to sample from.
    model: &'a M,
}
27
28impl<'a, R, M> ForwardSampler<'a, R, M> {
29 #[inline]
41 pub const fn new(rng: &'a mut R, model: &'a M) -> Self {
42 let rng = RefCell::new(rng);
44
45 Self { rng, model }
46 }
47}
48
49impl<R: Rng> BNSampler<CatBN> for ForwardSampler<'_, R, CatBN> {
50 type Sample = <CatBN as BN>::Sample;
51 type Samples = <CatBN as BN>::Samples;
52
53 fn sample(&self) -> Self::Sample {
54 let mut rng = self.rng.borrow_mut();
56 let mut sample = Array::zeros(self.model.labels().len());
58
59 self.model.topological_order().iter().for_each(|&i| {
61 let cpd_i = &self.model.cpds()[i];
63 let pa_i = self.model.graph().parents(&set![i]);
65 let pa_i = pa_i.iter().map(|&z| sample[z]).collect();
66 sample[i] = cpd_i.sample(&mut rng, &pa_i)[0];
68 });
69
70 sample
71 }
72
73 fn sample_n(&self, n: usize) -> Self::Samples {
74 let mut dataset = Array::zeros((n, self.model.labels().len()));
76
77 dataset.rows_mut().into_iter().for_each(|mut row| {
79 row.assign(&self.sample());
81 });
82
83 CatTable::new(self.model.states().clone(), dataset)
85 }
86}
87
88impl<R: Rng + SeedableRng> ParBNSampler<CatBN> for ForwardSampler<'_, R, CatBN> {
89 type Samples = <CatBN as BN>::Samples;
90
91 fn par_sample_n(&self, n: usize) -> Self::Samples {
92 let rng = self.rng.borrow_mut();
94 let seeds: Vec<_> = rng.random_iter().take(n).collect();
96
97 let mut samples = Array::zeros((n, self.model.labels().len()));
99
100 seeds
102 .into_par_iter()
103 .zip(samples.axis_iter_mut(Axis(0)))
104 .for_each(|(seed, mut row)| {
105 let mut rng = R::seed_from_u64(seed);
107 let sampler = ForwardSampler::new(&mut rng, self.model);
109 row.assign(&sampler.sample());
111 });
112
113 CatTable::new(self.model.states().clone(), samples)
115 }
116}
117
118impl<R: Rng> BNSampler<GaussBN> for ForwardSampler<'_, R, GaussBN> {
119 type Sample = <GaussBN as BN>::Sample;
120 type Samples = <GaussBN as BN>::Samples;
121
122 fn sample(&self) -> Self::Sample {
123 let mut rng = self.rng.borrow_mut();
125 let mut sample = Array::zeros(self.model.labels().len());
127
128 self.model.topological_order().iter().for_each(|&i| {
130 let cpd_i = &self.model.cpds()[i];
132 let pa_i = self.model.graph().parents(&set![i]);
134 let pa_i = pa_i.iter().map(|&z| sample[z]).collect();
135 sample[i] = cpd_i.sample(&mut rng, &pa_i)[0];
137 });
138
139 sample
140 }
141
142 fn sample_n(&self, n: usize) -> Self::Samples {
143 let mut samples = Array::zeros((n, self.model.labels().len()));
145
146 samples.rows_mut().into_iter().for_each(|mut row| {
148 row.assign(&self.sample());
150 });
151
152 GaussTable::new(self.model.labels().clone(), samples)
154 }
155}
156
157impl<R: Rng + SeedableRng> ParBNSampler<GaussBN> for ForwardSampler<'_, R, GaussBN> {
158 type Samples = <GaussBN as BN>::Samples;
159
160 fn par_sample_n(&self, n: usize) -> Self::Samples {
161 let rng = self.rng.borrow_mut();
163 let seeds: Vec<_> = rng.random_iter().take(n).collect();
165
166 let mut samples = Array::zeros((n, self.model.labels().len()));
168
169 seeds
171 .into_par_iter()
172 .zip(samples.axis_iter_mut(Axis(0)))
173 .for_each(|(seed, mut row)| {
174 let mut rng = R::seed_from_u64(seed);
176 let sampler = ForwardSampler::new(&mut rng, self.model);
178 row.assign(&sampler.sample());
180 });
181
182 GaussTable::new(self.model.labels().clone(), samples)
184 }
185}
186
187impl<R: Rng> ForwardSampler<'_, R, CatCTBN> {
188 fn sample_time(&self, event: &CatSample, i: usize) -> f64 {
190 let x = event[i] as usize;
192 let cim_i = &self.model.cims()[i];
194 let pa_i = self.model.graph().parents(&set![i]);
196 let pa_i = pa_i.iter().map(|&z| event[z] as usize);
197 let pa_i = cim_i.conditioning_multi_index().ravel(pa_i);
198 let q_i_x = -cim_i.parameters()[[pa_i, x, x]];
200 let exp_i_x = Exp::new(q_i_x).unwrap();
202 exp_i_x.sample(&mut self.rng.borrow_mut())
204 }
205}
206
impl<R: Rng> CTBNSampler<CatCTBN> for ForwardSampler<'_, R, CatCTBN> {
    type Sample = <CatCTBN as CTBN>::Trajectory;
    type Samples = <CatCTBN as CTBN>::Trajectories;

    /// Samples one trajectory with at most `max_length` events (the initial one
    /// included). The time bound is `f64::MAX`, i.e. effectively no time limit.
    #[inline]
    fn sample_by_length(&self, max_length: usize) -> Self::Sample {
        self.sample_by_length_or_time(max_length, f64::MAX)
    }

    /// Samples one trajectory whose transitions all occur strictly before `max_time`.
    /// The length bound is `usize::MAX`, i.e. effectively no length limit.
    #[inline]
    fn sample_by_time(&self, max_time: f64) -> Self::Sample {
        self.sample_by_length_or_time(usize::MAX, max_time)
    }

    /// Samples one trajectory by simulating the CTBN forward in time.
    ///
    /// 1. The initial state is drawn from the model's initial distribution and recorded
    ///    at time `0`.
    /// 2. Every variable is then scheduled with a holding time (see `sample_time`).
    /// 3. The variable with the smallest scheduled time transitions next. Its new state is
    ///    drawn with probabilities proportional to the off-diagonal entries of its CIM row.
    /// 4. After each transition, only that variable and its children are rescheduled,
    ///    because only their conditioning state changed. The other scheduled times are
    ///    kept, which relies on exponential holding times being memoryless.
    ///
    /// Simulation stops once the trajectory holds `max_length` events, or once the next
    /// scheduled transition is not strictly before `max_time`.
    ///
    /// # Panics
    ///
    /// Panics if:
    /// - `max_length` is zero;
    /// - `max_time` is not strictly positive (this includes NaN);
    /// - `argmin` or `WeightedIndex::new` fails (the `unwrap`s below), e.g. when the
    ///   off-diagonal row of the transitioning state cannot be turned into valid weights.
    fn sample_by_length_or_time(&self, max_length: usize, max_time: f64) -> Self::Sample {
        assert!(
            max_length > 0,
            "The maximum length of the trajectory must be strictly positive."
        );
        assert!(max_time > 0., "The maximum time must be positive.");

        // Recorded events (full joint state after each transition) and their times.
        let mut sample_events = Vec::new();
        let mut sample_times = Vec::new();

        // Draw the initial state. The generator borrow is confined to this scope because
        // `sample_time` below borrows it again.
        let mut event = {
            let mut rng = self.rng.borrow_mut();
            let initial = self.model.initial_distribution();
            let initial = ForwardSampler::new(&mut rng, initial);
            initial.sample()
        };
        sample_events.push(event.clone());
        sample_times.push(0.);

        // Absolute time at which each variable is next scheduled to transition.
        let mut times: Array1<_> = (0..event.len())
            .map(|i| self.sample_time(&event, i))
            .collect();

        // Next variable to transition and the time at which it does.
        let mut i = times.argmin().unwrap();
        let mut time = times[i];

        while sample_events.len() < max_length && time < max_time {
            // Current state of the transitioning variable.
            let x = event[i] as usize;
            // Its CIM, indexed by the current parent configuration.
            let cim_i = &self.model.cims()[i];
            let pa_i = self.model.graph().parents(&set![i]);
            let pa_i = pa_i.iter().map(|&z| event[z] as usize);
            let pa_i = cim_i.conditioning_multi_index().ravel(pa_i);
            // Row of rates out of state `x`; drop the diagonal so only moves to a
            // different state remain, then normalize into probabilities.
            let mut q_i_zx = cim_i.parameters().slice(s![pa_i, x, ..]).to_owned();
            q_i_zx[x] = 0.;
            q_i_zx /= q_i_zx.sum();
            // Draw the destination state and record the new event at the current time.
            let s_i_zx = WeightedIndex::new(&q_i_zx).unwrap();
            event[i] = s_i_zx.sample(&mut self.rng.borrow_mut()) as CatType;
            sample_events.push(event.clone());
            sample_times.push(time);
            // Reschedule the variable itself and its children: their rates are
            // conditioned on the state that just changed.
            std::iter::once(i)
                .chain(self.model.graph().children(&set![i]))
                .for_each(|j| {
                    times[j] = time + self.sample_time(&event, j);
                });
            // Shift every scheduled time by `EPSILON`.
            // NOTE(review): presumably this keeps recorded timestamps strictly increasing
            // (e.g. when two variables are scheduled at the same instant). It also delays
            // every pending transition by `EPSILON` per step. Confirm this is intended.
            times += EPSILON;
            // Pick the next transition.
            i = times.argmin().unwrap();
            time = times[i];
        }

        let states = self.model.states().clone();

        // Stack the recorded events row-wise into a (#events x #variables) matrix.
        let shape = (sample_events.len(), sample_events[0].len());
        let sample_events = Array::from_iter(sample_events.into_iter().flatten())
            .into_shape_with_order(shape)
            .expect("Failed to convert events to 2D array.");
        let sample_times = Array::from_iter(sample_times);

        CatTrj::new(states, sample_events, sample_times)
    }

    /// Samples `n` trajectories sequentially, each with at most `max_length` events.
    #[inline]
    fn sample_n_by_length(&self, max_length: usize, n: usize) -> Self::Samples {
        (0..n).map(|_| self.sample_by_length(max_length)).collect()
    }

    /// Samples `n` trajectories sequentially, each with transitions before `max_time`.
    #[inline]
    fn sample_n_by_time(&self, max_time: f64, n: usize) -> Self::Samples {
        (0..n).map(|_| self.sample_by_time(max_time)).collect()
    }

    /// Samples `n` trajectories sequentially, each bounded by both `max_length` events
    /// and `max_time`.
    #[inline]
    fn sample_n_by_length_or_time(
        &self,
        max_length: usize,
        max_time: f64,
        n: usize,
    ) -> Self::Samples {
        (0..n)
            .map(|_| self.sample_by_length_or_time(max_length, max_time))
            .collect()
    }
}
333}
334
335impl<R: Rng + SeedableRng> ParCTBNSampler<CatCTBN> for ForwardSampler<'_, R, CatCTBN> {
336 type Samples = <CatCTBN as CTBN>::Trajectories;
337
338 #[inline]
339 fn par_sample_n_by_length(&self, max_length: usize, n: usize) -> Self::Samples {
340 self.par_sample_n_by_length_or_time(max_length, f64::MAX, n)
341 }
342
343 #[inline]
344 fn par_sample_n_by_time(&self, max_time: f64, n: usize) -> Self::Samples {
345 self.par_sample_n_by_length_or_time(usize::MAX, max_time, n)
346 }
347
348 fn par_sample_n_by_length_or_time(
349 &self,
350 max_length: usize,
351 max_time: f64,
352 n: usize,
353 ) -> Self::Samples {
354 let rng = self.rng.borrow_mut();
356 let seeds: Vec<_> = rng.random_iter().take(n).collect();
358 seeds
360 .into_par_iter()
361 .map(|seed| {
362 let mut rng = R::seed_from_u64(seed);
364 let sampler = ForwardSampler::new(&mut rng, self.model);
366 sampler.sample_by_length_or_time(max_length, max_time)
368 })
369 .collect()
370 }
371}