causal_hub/inference/
approximate_inference.rs

1use std::cell::RefCell;
2
3use dry::macro_for;
4use rand::{Rng, SeedableRng};
5
6use crate::{
7    estimators::{CPDEstimator, MLE},
8    inference::Modelled,
9    models::{BN, CatBN, GaussBN, Labelled},
10    samplers::{BNSampler, ForwardSampler, ImportanceSampler, ParBNSampler},
11    types::Set,
12};
13
/// Sampling-based (approximate) inference engine.
///
/// The engine borrows a random number generator and a model, and optionally
/// carries a reference to some evidence and an explicit sample size.
///
/// # Type parameters
///
/// * `R` - The random number generator type.
/// * `M` - The model type.
/// * `E` - The evidence type, `()` for an instance built with `new`.
#[derive(Debug)]
pub struct ApproximateInference<'a, R, M, E> {
    /// Exclusive borrow of the RNG, kept behind a `RefCell` so that
    /// methods taking `&self` can still advance the generator.
    rng: RefCell<&'a mut R>,
    /// The model to draw samples from.
    model: &'a M,
    /// The evidence to condition on, if any was attached.
    evidence: Option<&'a E>,
    /// The user-requested sample size, if any was set.
    sample_size: Option<usize>,
}
22
23impl<'a, R, M> ApproximateInference<'a, R, M, ()> {
24    /// Construct a new approximate inference instance.
25    ///
26    /// # Arguments
27    ///
28    /// * `rng` - A random number generator.
29    /// * `model` - A reference to the model to sample from.
30    ///
31    /// # Returns
32    ///
33    /// Return a new approximate inference instance.
34    ///
35    #[inline]
36    pub const fn new(rng: &'a mut R, model: &'a M) -> Self {
37        // Wrap the RNG in a RefCell.
38        let rng = RefCell::new(rng);
39
40        Self {
41            rng,
42            model,
43            evidence: None,
44            sample_size: None,
45        }
46    }
47}
48
49impl<'a, R, M, E> ApproximateInference<'a, R, M, E> {
50    /// Add evidence to the approximate inference instance.
51    ///
52    /// # Arguments
53    ///
54    /// * `evidence` - A reference to the evidence.
55    ///
56    /// # Returns
57    ///
58    /// Return a new approximate inference instance with evidence.
59    ///
60    #[inline]
61    pub const fn with_evidence<T>(self, evidence: &'a T) -> ApproximateInference<'a, R, M, T> {
62        ApproximateInference {
63            rng: self.rng,
64            model: self.model,
65            evidence: Some(evidence),
66            sample_size: self.sample_size,
67        }
68    }
69
70    /// Set the sample size for the approximate inference instance.
71    ///
72    /// # Arguments
73    ///
74    /// * `n` - The sample size.
75    ///
76    /// # Panics
77    ///
78    /// * Panics if `n` is zero.
79    ///
80    /// # Returns
81    ///
82    /// Return a new approximate inference instance with the specified sample size.
83    ///
84    #[inline]
85    pub const fn with_sample_size(mut self, n: usize) -> Self {
86        // Assert the sample size is positive.
87        assert!(n > 0, "Sample size must be positive.");
88        // Set the sample size.
89        self.sample_size = Some(n);
90        self
91    }
92}
93
94impl<R, M, E> Modelled<M> for ApproximateInference<'_, R, M, E> {
95    #[inline]
96    fn model(&self) -> &M {
97        self.model
98    }
99}
100
101/// A trait for inference with Bayesian Networks.
102pub trait BNInference<T>
103where
104    T: BN,
105{
106    /// Estimate the values of `x` conditioned on `z` using `n` samples.
107    ///
108    /// # Arguments
109    ///
110    /// * `x` - The set of variables.
111    /// * `z` - The set of conditioning variables.
112    ///
113    /// # Panics
114    ///
115    /// * Panics if `x` is empty.
116    /// * Panics if `x` and `z` are not disjoint.
117    /// * Panics if `x` or `z` are not in the model.
118    ///
119    /// # Returns
120    ///
121    /// The estimated values of `x` conditioned on `z`.
122    ///
123    fn estimate(&self, x: &Set<usize>, z: &Set<usize>) -> T::CPD;
124}
125
126impl<'a, R, E> ApproximateInference<'a, R, CatBN, E> {
127    #[inline]
128    fn sample_size(&self, x: &Set<usize>, z: &Set<usize>) -> usize {
129        // Get the sample size or compute it if not provided.
130        self.sample_size.unwrap_or_else(|| {
131            // Get the shape of the variables X and Z.
132            let (x_shape, z_shape): (usize, usize) = (
133                x.iter().map(|&i| self.model.shape()[i]).product(),
134                z.iter().map(|&i| self.model.shape()[i]).product(),
135            );
136            // Return the sample size as PAC-like bounds:
137            //  (|Z| * (|X| - 1)) * ln(1 / delta) / epsilon^2, or approximately
138            //  (|Z| * (|X| - 1)) * 1200 for delta = 0.05 and epsilon = 0.05.
139            z_shape * (x_shape - 1) * 1200
140        })
141    }
142}
143
144impl<'a, R, E> ApproximateInference<'a, R, GaussBN, E> {
145    #[inline]
146    fn sample_size(&self, x: &Set<usize>, z: &Set<usize>) -> usize {
147        // Get the sample size or compute it if not provided.
148        self.sample_size.unwrap_or_else(|| {
149            // Get the shape of the variables X and Z.
150            let (x_shape, z_shape) = (x.len(), z.len());
151            // Return the sample size as PAC-like bounds:
152            //  (|X| * |Z| + (|X| * (|X| + 1)) / 2) * ln(1 / delta) / epsilon^2, or approximately
153            //  (|X| * |Z| + (|X| * (|X| + 1)) / 2) * 1200, for delta = 0.05 and epsilon = 0.05.
154            //  |X| * (|Z| + (|X| + 1) / 2) * 1200, for delta = 0.05 and epsilon = 0.05.
155            x_shape * (z_shape + x_shape.div_ceil(2)) * 1200
156        })
157    }
158}
159
160macro_for!($type in [CatBN, GaussBN] {
161
162    impl<R: Rng> BNInference<$type> for ApproximateInference<'_, R, $type, ()> {
163        fn estimate(&self, x: &Set<usize>, z: &Set<usize>) -> <$type as BN>::CPD {
164            // Assert X is not empty.
165            assert!(!x.is_empty(), "Variables X must not be empty.");
166            // Assert X and Z are disjoint.
167            assert!(x.is_disjoint(z), "Variables X and Z must be disjoint.");
168            // Assert X and Z are in the model.
169            assert!(
170                x.union(z).all(|&i| i < self.model.labels().len()),
171                "Variables X and Z must be in the model."
172            );
173
174            // Get the sample size.
175            let n = self.sample_size(x, z);
176            // Get the RNG.
177            let mut rng = self.rng.borrow_mut();
178            // Initialize the sampler.
179            let sampler = ForwardSampler::new(&mut rng, self.model);
180            // Generate n samples from the model.
181            // TODO: Avoid generating the full dataset,
182            //       e.g., by only sampling the variables in X U Z, and
183            //       by using batching to reduce memory usage.
184            let dataset = sampler.sample_n(n);
185            // Initialize the estimator.
186            let estimator = MLE::new(&dataset);
187            // Fit the CPD.
188            estimator.fit(x, z)
189        }
190    }
191
192    impl<R: Rng> BNInference<$type> for ApproximateInference<'_, R, $type, <$type as BN>::Evidence> {
193        fn estimate(&self, x: &Set<usize>, z: &Set<usize>) -> <$type as BN>::CPD {
194            // Assert X is not empty.
195            assert!(!x.is_empty(), "Variables X must not be empty.");
196            // Assert X and Z are disjoint.
197            assert!(x.is_disjoint(z), "Variables X and Z must be disjoint.");
198            // Assert X and Z are in the model.
199            assert!(
200                x.union(z).all(|&i| i < self.model.labels().len()),
201                "Variables X and Z must be in the model."
202            );
203
204            // Get the sample size.
205            let n = self.sample_size(x, z);
206            // Get the RNG.
207            let mut rng = self.rng.borrow_mut();
208            // Check if evidence is actually provided.
209            match self.evidence {
210                // Get the evidence.
211                Some(evidence) => {
212                    // Initialize the sampler.
213                    let sampler = ImportanceSampler::new(&mut rng, self.model, evidence);
214                    // Generate n samples from the model.
215                    // TODO: Avoid generating the full dataset,
216                    //       e.g., by only sampling the variables in X U Z, and
217                    //       by using batching to reduce memory usage.
218                    let dataset = sampler.sample_n(n);
219                    // Initialize the estimator.
220                    let estimator = MLE::new(&dataset);
221                    // Fit the CPD.
222                    estimator.fit(x, z)
223                }
224                // Delegate to empty evidence case.
225                None => ApproximateInference::new(&mut rng, self.model)
226                    .with_sample_size(n)
227                    .estimate(x, z),
228            }
229        }
230    }
231
232});
233
234/// A trait for parallel inference with Bayesian Networks.
235pub trait ParBNInference<T>
236where
237    T: BN,
238{
239    /// Estimate the values of `x` conditioned on `z` using `n` samples, in parallel.
240    ///
241    /// # Arguments
242    ///
243    /// * `x` - The set of variables.
244    /// * `z` - The set of conditioning variables.
245    ///
246    /// # Panics
247    ///
248    /// * Panics if `x` is empty.
249    /// * Panics if `x` and `z` are not disjoint.
250    /// * Panics if `x` or `z` are not in the model.
251    ///
252    /// # Returns
253    ///
254    /// The estimated values of `x` conditioned on `z`.
255    ///
256    fn par_estimate(&self, x: &Set<usize>, z: &Set<usize>) -> T::CPD;
257}
258
259macro_for!($type in [CatBN, GaussBN] {
260
261    impl<R: Rng + SeedableRng> ParBNInference<$type> for ApproximateInference<'_, R, $type, ()> {
262        fn par_estimate(&self, x: &Set<usize>, z: &Set<usize>) -> <$type as BN>::CPD {
263            // Assert X is not empty.
264            assert!(!x.is_empty(), "Variables X must not be empty.");
265            // Assert X and Z are disjoint.
266            assert!(x.is_disjoint(z), "Variables X and Z must be disjoint.");
267            // Assert X and Z are in the model.
268            assert!(
269                x.union(z).all(|&i| i < self.model.labels().len()),
270                "Variables X and Z must be in the model."
271            );
272
273            // Get the sample size.
274            let n = self.sample_size(x, z);
275            // Get the RNG.
276            let mut rng = self.rng.borrow_mut();
277            // Initialize the sampler.
278            let sampler = ForwardSampler::<R, _>::new(&mut rng, self.model);
279            // Generate n samples from the model.
280            // TODO: Avoid generating the full dataset,
281            //       e.g., by only sampling the variables in X U Z, and
282            //       by using batching to reduce memory usage.
283            let dataset = sampler.par_sample_n(n);
284            // Initialize the estimator.
285            let estimator = MLE::new(&dataset);
286            // Fit the CPD.
287            estimator.fit(x, z)
288        }
289    }
290
291    impl<R: Rng + SeedableRng> ParBNInference<$type> for ApproximateInference<'_, R, $type, <$type as BN>::Evidence> {
292        fn par_estimate(&self, x: &Set<usize>, z: &Set<usize>) -> <$type as BN>::CPD {
293            // Assert X is not empty.
294            assert!(!x.is_empty(), "Variables X must not be empty.");
295            // Assert X and Z are disjoint.
296            assert!(x.is_disjoint(z), "Variables X and Z must be disjoint.");
297            // Assert X and Z are in the model.
298            assert!(
299                x.union(z).all(|&i| i < self.model.labels().len()),
300                "Variables X and Z must be in the model."
301            );
302
303            // Get the sample size.
304            let n = self.sample_size(x, z);
305            // Get the RNG.
306            let mut rng = self.rng.borrow_mut();
307            // Check if evidence is actually provided.
308            match self.evidence {
309                // Get the evidence.
310                Some(evidence) => {
311                    // Initialize the sampler.
312                    let sampler = ImportanceSampler::<R, _, _>::new(&mut rng, self.model, evidence);
313                    // Generate n samples from the model.
314                    // TODO: Avoid generating the full dataset,
315                    //       e.g., by only sampling the variables in X U Z, and
316                    //       by using batching to reduce memory usage.
317                    let dataset = sampler.par_sample_n(n);
318                    // Initialize the estimator.
319                    let estimator = MLE::new(&dataset);
320                    // Fit the CPD.
321                    estimator.fit(x, z)
322                }
323                // Delegate to empty evidence case.
324                None => ApproximateInference::new(&mut rng, self.model)
325                    .with_sample_size(n)
326                    .estimate(x, z),
327            }
328        }
329    }
330
331});