causal_hub/inference/approximate_inference.rs
1use std::cell::RefCell;
2
3use dry::macro_for;
4use rand::{Rng, SeedableRng};
5
6use crate::{
7 estimators::{CPDEstimator, MLE},
8 inference::Modelled,
9 models::{BN, CatBN, GaussBN, Labelled},
10 samplers::{BNSampler, ForwardSampler, ImportanceSampler, ParBNSampler},
11 types::Set,
12};
13
/// Sampling-based (approximate) inference engine.
///
/// The engine borrows everything it works with for the lifetime `'a`:
///
/// * `R` - The random number generator type.
/// * `M` - The model type.
/// * `E` - The evidence type, `()` when the engine is built with `new`.
#[derive(Debug)]
pub struct ApproximateInference<'a, R, M, E> {
    /// Exclusive borrow of the RNG, wrapped in a `RefCell` so that methods
    /// taking `&self` can still borrow it mutably.
    rng: RefCell<&'a mut R>,
    /// Shared borrow of the model.
    model: &'a M,
    /// Optional shared borrow of the evidence.
    evidence: Option<&'a E>,
    /// Optional user-provided sample size.
    sample_size: Option<usize>,
}
22
23impl<'a, R, M> ApproximateInference<'a, R, M, ()> {
24 /// Construct a new approximate inference instance.
25 ///
26 /// # Arguments
27 ///
28 /// * `rng` - A random number generator.
29 /// * `model` - A reference to the model to sample from.
30 ///
31 /// # Returns
32 ///
33 /// Return a new approximate inference instance.
34 ///
35 #[inline]
36 pub const fn new(rng: &'a mut R, model: &'a M) -> Self {
37 // Wrap the RNG in a RefCell.
38 let rng = RefCell::new(rng);
39
40 Self {
41 rng,
42 model,
43 evidence: None,
44 sample_size: None,
45 }
46 }
47}
48
49impl<'a, R, M, E> ApproximateInference<'a, R, M, E> {
50 /// Add evidence to the approximate inference instance.
51 ///
52 /// # Arguments
53 ///
54 /// * `evidence` - A reference to the evidence.
55 ///
56 /// # Returns
57 ///
58 /// Return a new approximate inference instance with evidence.
59 ///
60 #[inline]
61 pub const fn with_evidence<T>(self, evidence: &'a T) -> ApproximateInference<'a, R, M, T> {
62 ApproximateInference {
63 rng: self.rng,
64 model: self.model,
65 evidence: Some(evidence),
66 sample_size: self.sample_size,
67 }
68 }
69
70 /// Set the sample size for the approximate inference instance.
71 ///
72 /// # Arguments
73 ///
74 /// * `n` - The sample size.
75 ///
76 /// # Panics
77 ///
78 /// * Panics if `n` is zero.
79 ///
80 /// # Returns
81 ///
82 /// Return a new approximate inference instance with the specified sample size.
83 ///
84 #[inline]
85 pub const fn with_sample_size(mut self, n: usize) -> Self {
86 // Assert the sample size is positive.
87 assert!(n > 0, "Sample size must be positive.");
88 // Set the sample size.
89 self.sample_size = Some(n);
90 self
91 }
92}
93
94impl<R, M, E> Modelled<M> for ApproximateInference<'_, R, M, E> {
95 #[inline]
96 fn model(&self) -> &M {
97 self.model
98 }
99}
100
101/// A trait for inference with Bayesian Networks.
102pub trait BNInference<T>
103where
104 T: BN,
105{
106 /// Estimate the values of `x` conditioned on `z` using `n` samples.
107 ///
108 /// # Arguments
109 ///
110 /// * `x` - The set of variables.
111 /// * `z` - The set of conditioning variables.
112 ///
113 /// # Panics
114 ///
115 /// * Panics if `x` is empty.
116 /// * Panics if `x` and `z` are not disjoint.
117 /// * Panics if `x` or `z` are not in the model.
118 ///
119 /// # Returns
120 ///
121 /// The estimated values of `x` conditioned on `z`.
122 ///
123 fn estimate(&self, x: &Set<usize>, z: &Set<usize>) -> T::CPD;
124}
125
126impl<'a, R, E> ApproximateInference<'a, R, CatBN, E> {
127 #[inline]
128 fn sample_size(&self, x: &Set<usize>, z: &Set<usize>) -> usize {
129 // Get the sample size or compute it if not provided.
130 self.sample_size.unwrap_or_else(|| {
131 // Get the shape of the variables X and Z.
132 let (x_shape, z_shape): (usize, usize) = (
133 x.iter().map(|&i| self.model.shape()[i]).product(),
134 z.iter().map(|&i| self.model.shape()[i]).product(),
135 );
136 // Return the sample size as PAC-like bounds:
137 // (|Z| * (|X| - 1)) * ln(1 / delta) / epsilon^2, or approximately
138 // (|Z| * (|X| - 1)) * 1200 for delta = 0.05 and epsilon = 0.05.
139 z_shape * (x_shape - 1) * 1200
140 })
141 }
142}
143
144impl<'a, R, E> ApproximateInference<'a, R, GaussBN, E> {
145 #[inline]
146 fn sample_size(&self, x: &Set<usize>, z: &Set<usize>) -> usize {
147 // Get the sample size or compute it if not provided.
148 self.sample_size.unwrap_or_else(|| {
149 // Get the shape of the variables X and Z.
150 let (x_shape, z_shape) = (x.len(), z.len());
151 // Return the sample size as PAC-like bounds:
152 // (|X| * |Z| + (|X| * (|X| + 1)) / 2) * ln(1 / delta) / epsilon^2, or approximately
153 // (|X| * |Z| + (|X| * (|X| + 1)) / 2) * 1200, for delta = 0.05 and epsilon = 0.05.
154 // |X| * (|Z| + (|X| + 1) / 2) * 1200, for delta = 0.05 and epsilon = 0.05.
155 x_shape * (z_shape + x_shape.div_ceil(2)) * 1200
156 })
157 }
158}
159
160macro_for!($type in [CatBN, GaussBN] {
161
162 impl<R: Rng> BNInference<$type> for ApproximateInference<'_, R, $type, ()> {
163 fn estimate(&self, x: &Set<usize>, z: &Set<usize>) -> <$type as BN>::CPD {
164 // Assert X is not empty.
165 assert!(!x.is_empty(), "Variables X must not be empty.");
166 // Assert X and Z are disjoint.
167 assert!(x.is_disjoint(z), "Variables X and Z must be disjoint.");
168 // Assert X and Z are in the model.
169 assert!(
170 x.union(z).all(|&i| i < self.model.labels().len()),
171 "Variables X and Z must be in the model."
172 );
173
174 // Get the sample size.
175 let n = self.sample_size(x, z);
176 // Get the RNG.
177 let mut rng = self.rng.borrow_mut();
178 // Initialize the sampler.
179 let sampler = ForwardSampler::new(&mut rng, self.model);
180 // Generate n samples from the model.
181 // TODO: Avoid generating the full dataset,
182 // e.g., by only sampling the variables in X U Z, and
183 // by using batching to reduce memory usage.
184 let dataset = sampler.sample_n(n);
185 // Initialize the estimator.
186 let estimator = MLE::new(&dataset);
187 // Fit the CPD.
188 estimator.fit(x, z)
189 }
190 }
191
192 impl<R: Rng> BNInference<$type> for ApproximateInference<'_, R, $type, <$type as BN>::Evidence> {
193 fn estimate(&self, x: &Set<usize>, z: &Set<usize>) -> <$type as BN>::CPD {
194 // Assert X is not empty.
195 assert!(!x.is_empty(), "Variables X must not be empty.");
196 // Assert X and Z are disjoint.
197 assert!(x.is_disjoint(z), "Variables X and Z must be disjoint.");
198 // Assert X and Z are in the model.
199 assert!(
200 x.union(z).all(|&i| i < self.model.labels().len()),
201 "Variables X and Z must be in the model."
202 );
203
204 // Get the sample size.
205 let n = self.sample_size(x, z);
206 // Get the RNG.
207 let mut rng = self.rng.borrow_mut();
208 // Check if evidence is actually provided.
209 match self.evidence {
210 // Get the evidence.
211 Some(evidence) => {
212 // Initialize the sampler.
213 let sampler = ImportanceSampler::new(&mut rng, self.model, evidence);
214 // Generate n samples from the model.
215 // TODO: Avoid generating the full dataset,
216 // e.g., by only sampling the variables in X U Z, and
217 // by using batching to reduce memory usage.
218 let dataset = sampler.sample_n(n);
219 // Initialize the estimator.
220 let estimator = MLE::new(&dataset);
221 // Fit the CPD.
222 estimator.fit(x, z)
223 }
224 // Delegate to empty evidence case.
225 None => ApproximateInference::new(&mut rng, self.model)
226 .with_sample_size(n)
227 .estimate(x, z),
228 }
229 }
230 }
231
232});
233
234/// A trait for parallel inference with Bayesian Networks.
235pub trait ParBNInference<T>
236where
237 T: BN,
238{
239 /// Estimate the values of `x` conditioned on `z` using `n` samples, in parallel.
240 ///
241 /// # Arguments
242 ///
243 /// * `x` - The set of variables.
244 /// * `z` - The set of conditioning variables.
245 ///
246 /// # Panics
247 ///
248 /// * Panics if `x` is empty.
249 /// * Panics if `x` and `z` are not disjoint.
250 /// * Panics if `x` or `z` are not in the model.
251 ///
252 /// # Returns
253 ///
254 /// The estimated values of `x` conditioned on `z`.
255 ///
256 fn par_estimate(&self, x: &Set<usize>, z: &Set<usize>) -> T::CPD;
257}
258
259macro_for!($type in [CatBN, GaussBN] {
260
261 impl<R: Rng + SeedableRng> ParBNInference<$type> for ApproximateInference<'_, R, $type, ()> {
262 fn par_estimate(&self, x: &Set<usize>, z: &Set<usize>) -> <$type as BN>::CPD {
263 // Assert X is not empty.
264 assert!(!x.is_empty(), "Variables X must not be empty.");
265 // Assert X and Z are disjoint.
266 assert!(x.is_disjoint(z), "Variables X and Z must be disjoint.");
267 // Assert X and Z are in the model.
268 assert!(
269 x.union(z).all(|&i| i < self.model.labels().len()),
270 "Variables X and Z must be in the model."
271 );
272
273 // Get the sample size.
274 let n = self.sample_size(x, z);
275 // Get the RNG.
276 let mut rng = self.rng.borrow_mut();
277 // Initialize the sampler.
278 let sampler = ForwardSampler::<R, _>::new(&mut rng, self.model);
279 // Generate n samples from the model.
280 // TODO: Avoid generating the full dataset,
281 // e.g., by only sampling the variables in X U Z, and
282 // by using batching to reduce memory usage.
283 let dataset = sampler.par_sample_n(n);
284 // Initialize the estimator.
285 let estimator = MLE::new(&dataset);
286 // Fit the CPD.
287 estimator.fit(x, z)
288 }
289 }
290
291 impl<R: Rng + SeedableRng> ParBNInference<$type> for ApproximateInference<'_, R, $type, <$type as BN>::Evidence> {
292 fn par_estimate(&self, x: &Set<usize>, z: &Set<usize>) -> <$type as BN>::CPD {
293 // Assert X is not empty.
294 assert!(!x.is_empty(), "Variables X must not be empty.");
295 // Assert X and Z are disjoint.
296 assert!(x.is_disjoint(z), "Variables X and Z must be disjoint.");
297 // Assert X and Z are in the model.
298 assert!(
299 x.union(z).all(|&i| i < self.model.labels().len()),
300 "Variables X and Z must be in the model."
301 );
302
303 // Get the sample size.
304 let n = self.sample_size(x, z);
305 // Get the RNG.
306 let mut rng = self.rng.borrow_mut();
307 // Check if evidence is actually provided.
308 match self.evidence {
309 // Get the evidence.
310 Some(evidence) => {
311 // Initialize the sampler.
312 let sampler = ImportanceSampler::<R, _, _>::new(&mut rng, self.model, evidence);
313 // Generate n samples from the model.
314 // TODO: Avoid generating the full dataset,
315 // e.g., by only sampling the variables in X U Z, and
316 // by using batching to reduce memory usage.
317 let dataset = sampler.par_sample_n(n);
318 // Initialize the estimator.
319 let estimator = MLE::new(&dataset);
320 // Fit the CPD.
321 estimator.fit(x, z)
322 }
323 // Delegate to empty evidence case.
324 None => ApproximateInference::new(&mut rng, self.model)
325 .with_sample_size(n)
326 .estimate(x, z),
327 }
328 }
329 }
330
331});