causal_hub/samplers/
forward.rs

1use core::f64;
2use std::cell::RefCell;
3
4use ndarray::prelude::*;
5use ndarray_stats::QuantileExt;
6use rand::{
7    Rng, SeedableRng,
8    distr::{Distribution, weighted::WeightedIndex},
9};
10use rand_distr::Exp;
11use rayon::prelude::*;
12
13use crate::{
14    datasets::{CatSample, CatTable, CatTrj, CatType, GaussTable},
15    models::{BN, CIM, CPD, CTBN, CatBN, CatCTBN, GaussBN, Labelled},
16    samplers::{BNSampler, CTBNSampler, ParBNSampler, ParCTBNSampler},
17    set,
18    types::EPSILON,
19};
20
/// A forward sampler.
///
/// Pairs an exclusively borrowed random number generator of type `R` with a
/// shared borrow of a model of type `M`. Both borrows live for `'a`.
#[derive(Debug)]
pub struct ForwardSampler<'a, R, M> {
    /// The random number generator, kept behind a `RefCell` so it can be
    /// advanced through a shared reference to the sampler.
    rng: RefCell<&'a mut R>,
    /// The model to sample from.
    model: &'a M,
}
27
28impl<'a, R, M> ForwardSampler<'a, R, M> {
29    /// Construct a new forward sampler.
30    ///
31    /// # Arguments
32    ///
33    /// * `rng` - A random number generator.
34    /// * `model` - A reference to the model to sample from.
35    ///
36    /// # Returns
37    ///
38    /// Return a new `ForwardSampler` instance.
39    ///
40    #[inline]
41    pub const fn new(rng: &'a mut R, model: &'a M) -> Self {
42        // Wrap the RNG in a RefCell to allow interior mutability.
43        let rng = RefCell::new(rng);
44
45        Self { rng, model }
46    }
47}
48
49impl<R: Rng> BNSampler<CatBN> for ForwardSampler<'_, R, CatBN> {
50    type Sample = <CatBN as BN>::Sample;
51    type Samples = <CatBN as BN>::Samples;
52
53    fn sample(&self) -> Self::Sample {
54        // Get a mutable reference to the RNG.
55        let mut rng = self.rng.borrow_mut();
56        // Allocate the sample.
57        let mut sample = Array::zeros(self.model.labels().len());
58
59        // For each vertex in the topological order ...
60        self.model.topological_order().iter().for_each(|&i| {
61            // Get the CPD.
62            let cpd_i = &self.model.cpds()[i];
63            // Compute the index on the parents to condition on.
64            let pa_i = self.model.graph().parents(&set![i]);
65            let pa_i = pa_i.iter().map(|&z| sample[z]).collect();
66            // Sample from the distribution.
67            sample[i] = cpd_i.sample(&mut rng, &pa_i)[0];
68        });
69
70        sample
71    }
72
73    fn sample_n(&self, n: usize) -> Self::Samples {
74        // Allocate the dataset.
75        let mut dataset = Array::zeros((n, self.model.labels().len()));
76
77        // For each sample ...
78        dataset.rows_mut().into_iter().for_each(|mut row| {
79            // Sample from the distribution.
80            row.assign(&self.sample());
81        });
82
83        // Construct the dataset.
84        CatTable::new(self.model.states().clone(), dataset)
85    }
86}
87
88impl<R: Rng + SeedableRng> ParBNSampler<CatBN> for ForwardSampler<'_, R, CatBN> {
89    type Samples = <CatBN as BN>::Samples;
90
91    fn par_sample_n(&self, n: usize) -> Self::Samples {
92        // Get a mutable reference to the RNG.
93        let rng = self.rng.borrow_mut();
94        // Generate a random seed for each sample.
95        let seeds: Vec<_> = rng.random_iter().take(n).collect();
96
97        // Allocate the samples.
98        let mut samples = Array::zeros((n, self.model.labels().len()));
99
100        // Sample the samples in parallel.
101        seeds
102            .into_par_iter()
103            .zip(samples.axis_iter_mut(Axis(0)))
104            .for_each(|(seed, mut row)| {
105                // Create a new random number generator with the seed.
106                let mut rng = R::seed_from_u64(seed);
107                // Create a new sampler with the random number generator and model.
108                let sampler = ForwardSampler::new(&mut rng, self.model);
109                // Sample from the distribution.
110                row.assign(&sampler.sample());
111            });
112
113        // Construct the dataset.
114        CatTable::new(self.model.states().clone(), samples)
115    }
116}
117
118impl<R: Rng> BNSampler<GaussBN> for ForwardSampler<'_, R, GaussBN> {
119    type Sample = <GaussBN as BN>::Sample;
120    type Samples = <GaussBN as BN>::Samples;
121
122    fn sample(&self) -> Self::Sample {
123        // Get a mutable reference to the RNG.
124        let mut rng = self.rng.borrow_mut();
125        // Allocate the sample.
126        let mut sample = Array::zeros(self.model.labels().len());
127
128        // For each vertex in the topological order ...
129        self.model.topological_order().iter().for_each(|&i| {
130            // Get the CPD.
131            let cpd_i = &self.model.cpds()[i];
132            // Get the parents.
133            let pa_i = self.model.graph().parents(&set![i]);
134            let pa_i = pa_i.iter().map(|&z| sample[z]).collect();
135            // Compute the value of the variable.
136            sample[i] = cpd_i.sample(&mut rng, &pa_i)[0];
137        });
138
139        sample
140    }
141
142    fn sample_n(&self, n: usize) -> Self::Samples {
143        // Allocate the samples.
144        let mut samples = Array::zeros((n, self.model.labels().len()));
145
146        // For each sample ...
147        samples.rows_mut().into_iter().for_each(|mut row| {
148            // Sample from the distribution.
149            row.assign(&self.sample());
150        });
151
152        // Construct the dataset.
153        GaussTable::new(self.model.labels().clone(), samples)
154    }
155}
156
157impl<R: Rng + SeedableRng> ParBNSampler<GaussBN> for ForwardSampler<'_, R, GaussBN> {
158    type Samples = <GaussBN as BN>::Samples;
159
160    fn par_sample_n(&self, n: usize) -> Self::Samples {
161        // Get a mutable reference to the RNG.
162        let rng = self.rng.borrow_mut();
163        // Generate a random seed for each sample.
164        let seeds: Vec<_> = rng.random_iter().take(n).collect();
165
166        // Allocate the samples.
167        let mut samples = Array::zeros((n, self.model.labels().len()));
168
169        // Sample the samples in parallel.
170        seeds
171            .into_par_iter()
172            .zip(samples.axis_iter_mut(Axis(0)))
173            .for_each(|(seed, mut row)| {
174                // Create a new random number generator with the seed.
175                let mut rng = R::seed_from_u64(seed);
176                // Create a new sampler with the random number generator and model.
177                let sampler = ForwardSampler::new(&mut rng, self.model);
178                // Sample from the distribution.
179                row.assign(&sampler.sample());
180            });
181
182        // Construct the dataset.
183        GaussTable::new(self.model.labels().clone(), samples)
184    }
185}
186
187impl<R: Rng> ForwardSampler<'_, R, CatCTBN> {
188    /// Sample transition time for variable X_i with state x_i.
189    fn sample_time(&self, event: &CatSample, i: usize) -> f64 {
190        // Cast the state to usize.
191        let x = event[i] as usize;
192        // Get the CIM.
193        let cim_i = &self.model.cims()[i];
194        // Compute the index on the parents to condition on.
195        let pa_i = self.model.graph().parents(&set![i]);
196        let pa_i = pa_i.iter().map(|&z| event[z] as usize);
197        let pa_i = cim_i.conditioning_multi_index().ravel(pa_i);
198        // Get the distribution of the vertex.
199        let q_i_x = -cim_i.parameters()[[pa_i, x, x]];
200        // Initialize the exponential distribution.
201        let exp_i_x = Exp::new(q_i_x).unwrap();
202        // Sample the transition time.
203        exp_i_x.sample(&mut self.rng.borrow_mut())
204    }
205}
206
207impl<R: Rng> CTBNSampler<CatCTBN> for ForwardSampler<'_, R, CatCTBN> {
208    type Sample = <CatCTBN as CTBN>::Trajectory;
209    type Samples = <CatCTBN as CTBN>::Trajectories;
210
211    #[inline]
212    fn sample_by_length(&self, max_length: usize) -> Self::Sample {
213        // Delegate to generic function.
214        self.sample_by_length_or_time(max_length, f64::MAX)
215    }
216
217    #[inline]
218    fn sample_by_time(&self, max_time: f64) -> Self::Sample {
219        // Delegate to generic function.
220        self.sample_by_length_or_time(usize::MAX, max_time)
221    }
222
223    fn sample_by_length_or_time(&self, max_length: usize, max_time: f64) -> Self::Sample {
224        // Assert length is positive.
225        assert!(
226            max_length > 0,
227            "The maximum length of the trajectory must be strictly positive."
228        );
229        // Assert time is positive.
230        assert!(max_time > 0., "The maximum time must be positive.");
231
232        // Allocate the trajectory components.
233        let mut sample_events = Vec::new();
234        let mut sample_times = Vec::new();
235
236        // Sample the initial states.
237        let mut event = {
238            let mut rng = self.rng.borrow_mut();
239            let initial = self.model.initial_distribution();
240            let initial = ForwardSampler::new(&mut rng, initial);
241            initial.sample()
242        };
243        // Append the initial state to the trajectory.
244        sample_events.push(event.clone());
245        sample_times.push(0.);
246
247        // Sample the transition time.
248        let mut times: Array1<_> = (0..event.len())
249            .map(|i| self.sample_time(&event, i))
250            .collect();
251
252        // Get the variable that transitions first.
253        let mut i = times.argmin().unwrap();
254        // Set global time.
255        let mut time = times[i];
256
257        // While:
258        //  1. the length of the trajectory is less than max_length, and ...
259        //  2. the time is less than max_time ...
260        while sample_events.len() < max_length && time < max_time {
261            // Cast the state to usize.
262            let x = event[i] as usize;
263            // Get the CIM.
264            let cim_i = &self.model.cims()[i];
265            // Compute the index on the parents to condition on.
266            let pa_i = self.model.graph().parents(&set![i]);
267            let pa_i = pa_i.iter().map(|&z| event[z] as usize);
268            let pa_i = cim_i.conditioning_multi_index().ravel(pa_i);
269            // Get the distribution of the vertex.
270            let mut q_i_zx = cim_i.parameters().slice(s![pa_i, x, ..]).to_owned();
271            // Set the diagonal element to zero.
272            q_i_zx[x] = 0.;
273            // Normalize the probabilities.
274            q_i_zx /= q_i_zx.sum();
275            // Initialize a weighted index sampler.
276            let s_i_zx = WeightedIndex::new(&q_i_zx).unwrap();
277            // Sample the next event.
278            event[i] = s_i_zx.sample(&mut self.rng.borrow_mut()) as CatType;
279            // Append the event to the trajectory.
280            sample_events.push(event.clone());
281            sample_times.push(time);
282            // Update the transition times for { X } U Ch(X).
283            std::iter::once(i)
284                .chain(self.model.graph().children(&set![i]))
285                .for_each(|j| {
286                    // Sample the transition time.
287                    times[j] = time + self.sample_time(&event, j);
288                });
289            // Add a small epsilon to avoid zero transition times.
290            times += EPSILON;
291            // Get the variable to transition first.
292            i = times.argmin().unwrap();
293            // Update the global time.
294            time = times[i];
295        }
296
297        // Get the states of the CIMs.
298        let states = self.model.states().clone();
299
300        // Convert the events to a 2D array.
301        let shape = (sample_events.len(), sample_events[0].len());
302        let sample_events = Array::from_iter(sample_events.into_iter().flatten())
303            .into_shape_with_order(shape)
304            .expect("Failed to convert events to 2D array.");
305        // Convert the times to a 1D array.
306        let sample_times = Array::from_iter(sample_times);
307
308        // Return the trajectory.
309        CatTrj::new(states, sample_events, sample_times)
310    }
311
312    #[inline]
313    fn sample_n_by_length(&self, max_length: usize, n: usize) -> Self::Samples {
314        (0..n).map(|_| self.sample_by_length(max_length)).collect()
315    }
316
317    #[inline]
318    fn sample_n_by_time(&self, max_time: f64, n: usize) -> Self::Samples {
319        (0..n).map(|_| self.sample_by_time(max_time)).collect()
320    }
321
322    #[inline]
323    fn sample_n_by_length_or_time(
324        &self,
325        max_length: usize,
326        max_time: f64,
327        n: usize,
328    ) -> Self::Samples {
329        (0..n)
330            .map(|_| self.sample_by_length_or_time(max_length, max_time))
331            .collect()
332    }
333}
334
335impl<R: Rng + SeedableRng> ParCTBNSampler<CatCTBN> for ForwardSampler<'_, R, CatCTBN> {
336    type Samples = <CatCTBN as CTBN>::Trajectories;
337
338    #[inline]
339    fn par_sample_n_by_length(&self, max_length: usize, n: usize) -> Self::Samples {
340        self.par_sample_n_by_length_or_time(max_length, f64::MAX, n)
341    }
342
343    #[inline]
344    fn par_sample_n_by_time(&self, max_time: f64, n: usize) -> Self::Samples {
345        self.par_sample_n_by_length_or_time(usize::MAX, max_time, n)
346    }
347
348    fn par_sample_n_by_length_or_time(
349        &self,
350        max_length: usize,
351        max_time: f64,
352        n: usize,
353    ) -> Self::Samples {
354        // Get a mutable reference to the RNG.
355        let rng = self.rng.borrow_mut();
356        // Generate a random seed for each trajectory.
357        let seeds: Vec<_> = rng.random_iter().take(n).collect();
358        // Sample the trajectories in parallel.
359        seeds
360            .into_par_iter()
361            .map(|seed| {
362                // Create a new random number generator with the seed.
363                let mut rng = R::seed_from_u64(seed);
364                // Create a new sampler with the random number generator and model.
365                let sampler = ForwardSampler::new(&mut rng, self.model);
366                // Sample the trajectory.
367                sampler.sample_by_length_or_time(max_length, max_time)
368            })
369            .collect()
370    }
371}