causal_hub/inference/causal_inference.rs
1use dry::macro_for;
2
3use crate::{
4 inference::{BNInference, BackdoorCriterion, Modelled, ParBNInference},
5 models::{BN, CatBN, GaussBN, Labelled, Phi},
6 set,
7 types::Set,
8};
9
/// A causal inference engine.
///
/// Holds a shared reference to an underlying inference engine `E`, which is used
/// to compute the observational estimates that causal queries are reduced to.
#[derive(Debug)]
pub struct CausalInference<'a, E> {
    /// The underlying inference engine.
    engine: &'a E,
}

// `Clone` and `Copy` are implemented by hand: `#[derive(Clone)]` would add an
// `E: Clone` bound, although copying a shared reference never requires it.
impl<E> Clone for CausalInference<'_, E> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<E> Copy for CausalInference<'_, E> {}
15
16impl<'a, E> CausalInference<'a, E> {
17 /// Create a new causal inference engine.
18 ///
19 /// # Arguments
20 ///
21 /// * `engine` - The underlying inference engine.
22 ///
23 /// # Returns
24 ///
25 /// The causal inference engine.
26 ///
27 pub fn new(engine: &'a E) -> Self {
28 Self { engine }
29 }
30}
31
32/// A trait for causal inference with Bayesian Networks.
33pub trait BNCausalInference<T>
34where
35 T: BN,
36{
37 /// Estimate the average causal effect of `X` on `Y` as E(Y | do(X)).
38 ///
39 /// # Arguments
40 ///
41 /// * `x` - The cause variables.
42 /// * `y` - The effect variables.
43 ///
44 /// # Panics
45 ///
46 /// * If `X` is empty.
47 /// * If `Y` is empty.
48 /// * If `X` and `Y` are not disjoint.
49 ///
50 /// # Returns
51 ///
52 /// The estimated average causal effect of `X` on `Y`.
53 ///
54 fn ace_estimate(&self, x: &Set<usize>, y: &Set<usize>) -> Option<T::CPD> {
55 self.cace_estimate(x, y, &set![])
56 }
57
58 /// Estimate the conditional average causal effect of `X` on `Y` given `Z` as E(Y | do(X), Z).
59 ///
60 /// # Arguments
61 ///
62 /// * `x` - The cause variables.
63 /// * `y` - The effect variables.
64 /// * `z` - The conditioning variables.
65 ///
66 /// # Panics
67 ///
68 /// * If `X` is empty.
69 /// * If `Y` is empty.
70 /// * If `X` and `Y` are not disjoint.
71 /// * If `X` and `Z` are not disjoint.
72 /// * If `Y` and `Z` are not disjoint.
73 ///
74 /// # Returns
75 ///
76 /// The estimated conditional average causal effect of `X` on `Y` given `Z`.
77 ///
78 fn cace_estimate(&self, x: &Set<usize>, y: &Set<usize>, z: &Set<usize>) -> Option<T::CPD>;
79}
80
81macro_for!($type in [CatBN, GaussBN] {
82
83 impl<E> BNCausalInference<$type> for CausalInference<'_, E>
84 where
85 E: Modelled<$type> + BNInference<$type>,
86 {
87 fn cace_estimate(&self, x: &Set<usize>, y: &Set<usize>, z: &Set<usize>) -> Option<<$type as BN>::CPD> {
88 // Assert X is not empty.
89 assert!(!x.is_empty(), "Variables X must not be empty.");
90 // Assert Y is not empty.
91 assert!(!y.is_empty(), "Variables Y must not be empty.");
92 // Assert X and Y are disjoint.
93 assert!(x.is_disjoint(y), "Variables X and Y must be disjoint.");
94 // Assert X and Z are disjoint.
95 assert!(x.is_disjoint(z), "Variables X and Z must be disjoint.");
96 // Assert Y and Z are disjoint.
97 assert!(y.is_disjoint(z), "Variables Y and Z must be disjoint.");
98
99 /* Effect Identification */
100
101 // Get the model.
102 let m = self.engine.model();
103 // Find a minimal backdoor adjustment set Z \cup S, if any.
104 let z_s = m.graph().find_minimal_backdoor_set(x, y, Some(z), None);
105
106 /* Effect Estimation */
107
108 // Match on the backdoor adjustment set.
109 match z_s {
110 // If no backdoor adjustment set exists, return None.
111 None => None,
112 // If the backdoor adjustment set is empty ...
113 Some(z_s) if z_s.is_empty() => {
114 // ... estimate P(Y | do(X), Z) as P(Y | X, Z).
115 Some(self.engine.estimate(y, &(x | z)))
116 }
117 // If the backdoor adjustment set is non-empty ...
118 Some(z_s) => {
119 // Get the S part.
120 let s = &(&z_s - z);
121 // Estimate P(Y | X, Z, S) and P(S).
122 let p_y_x_z_s = self.engine.estimate(y, &(x | s));
123 let p_s = self.engine.estimate(s, &set![]);
124 // Convert to potentials for aligned multiplication.
125 let p_y_x_z_s = p_y_x_z_s.into_phi();
126 let p_s = p_s.into_phi();
127 // Compute P(Y | X, Z, S) * P(S) using potentials.
128 let p_y_s_do_x_z = &p_y_x_z_s * &p_s;
129 // Map BN indices to the potential indices.
130 let s = p_y_s_do_x_z.indices_from(s, m.labels());
131 // Marginalize over S.
132 let p_y_do_x_z = p_y_s_do_x_z.marginalize(&s);
133 // Map BN indices to the potential indices.
134 let x = p_y_do_x_z.indices_from(x, m.labels());
135 let y = p_y_do_x_z.indices_from(y, m.labels());
136 let z = p_y_do_x_z.indices_from(z, m.labels());
137 // Convert back to CPD.
138 let p_y_do_x_z = p_y_do_x_z.into_cpd(&y, &(&x | &z));
139 // Return the result.
140 Some(p_y_do_x_z)
141 }
142 }
143 }
144 }
145
146});
147
148/// A trait for causal inference with Bayesian Networks in parallel.
149pub trait ParBNCausalInference<T>
150where
151 T: BN,
152{
153 /// Estimate the average causal effect of `X` on `Y` as E(Y | do(X)) in parallel.
154 ///
155 /// # Arguments
156 ///
157 /// * `x` - The cause variables.
158 /// * `y` - The effect variables.
159 ///
160 /// # Panics
161 ///
162 /// * If `X` is empty.
163 /// * If `Y` is empty.
164 /// * If `X` and `Y` are not disjoint.
165 ///
166 /// # Returns
167 ///
168 /// The estimated average causal effect of `X` on `Y`.
169 ///
170 fn par_ace_estimate(&self, x: &Set<usize>, y: &Set<usize>) -> Option<T::CPD> {
171 self.par_cace_estimate(x, y, &set![])
172 }
173
174 /// Estimate the conditional average causal effect of `X` on `Y` given `Z` as E(Y | do(X), Z) in parallel.
175 ///
176 /// # Arguments
177 ///
178 /// * `x` - The cause variables.
179 /// * `y` - The effect variables.
180 /// * `z` - The conditioning variables.
181 ///
182 /// # Panics
183 ///
184 /// * If `X` is empty.
185 /// * If `Y` is empty.
186 /// * If `X` and `Y` are not disjoint.
187 /// * If `X` and `Z` are not disjoint.
188 /// * If `Y` and `Z` are not disjoint.
189 ///
190 /// # Returns
191 ///
192 /// The estimated conditional average causal effect of `X` on `Y` given `Z`.
193 ///
194 fn par_cace_estimate(&self, x: &Set<usize>, y: &Set<usize>, z: &Set<usize>) -> Option<T::CPD>;
195}
196
197macro_for!($type in [CatBN, GaussBN] {
198
199 impl<E> ParBNCausalInference<$type> for CausalInference<'_, E>
200 where
201 E: Modelled<$type> + ParBNInference<$type>,
202 {
203 fn par_cace_estimate(&self, x: &Set<usize>, y: &Set<usize>, z: &Set<usize>) -> Option<<$type as BN>::CPD> {
204 // Assert X is not empty.
205 assert!(!x.is_empty(), "Variables X must not be empty.");
206 // Assert Y is not empty.
207 assert!(!y.is_empty(), "Variables Y must not be empty.");
208 // Assert X and Y are disjoint.
209 assert!(x.is_disjoint(y), "Variables X and Y must be disjoint.");
210 // Assert X and Z are disjoint.
211 assert!(x.is_disjoint(z), "Variables X and Z must be disjoint.");
212 // Assert Y and Z are disjoint.
213 assert!(y.is_disjoint(z), "Variables Y and Z must be disjoint.");
214
215 /* Effect Identification */
216
217 // Get the model.
218 let m = self.engine.model();
219 // Find a minimal backdoor adjustment set Z \cup S, if any.
220 let z_s = m.graph().find_minimal_backdoor_set(x, y, Some(z), None);
221
222 /* Effect Estimation */
223
224 // Match on the backdoor adjustment set.
225 match z_s {
226 // If no backdoor adjustment set exists, return None.
227 None => None,
228 // If the backdoor adjustment set is empty ...
229 Some(z_s) if z_s.is_empty() => {
230 // ... estimate P(Y | do(X), Z) as P(Y | X, Z).
231 Some(self.engine.par_estimate(y, &(x | z)))
232 }
233 // If the backdoor adjustment set is non-empty ...
234 Some(z_s) => {
235 // Get the S part.
236 let s = &(&z_s - z);
237 // Estimate P(Y | X, Z, S) and P(S).
238 let p_y_x_z_s = self.engine.par_estimate(y, &(x | s));
239 let p_s = self.engine.par_estimate(s, &set![]);
240 // Convert to potentials for aligned multiplication.
241 let p_y_x_z_s = p_y_x_z_s.into_phi();
242 let p_s = p_s.into_phi();
243 // Compute P(Y | X, Z, S) * P(S) using potentials.
244 let p_y_s_do_x_z = &p_y_x_z_s * &p_s;
245 // Map BN indices to the potential indices.
246 let s = p_y_s_do_x_z.indices_from(s, m.labels());
247 // Marginalize over S.
248 let p_y_do_x_z = p_y_s_do_x_z.marginalize(&s);
249 // Map BN indices to the potential indices.
250 let x = p_y_do_x_z.indices_from(x, m.labels());
251 let y = p_y_do_x_z.indices_from(y, m.labels());
252 let z = p_y_do_x_z.indices_from(z, m.labels());
253 // Convert back to CPD.
254 let p_y_do_x_z = p_y_do_x_z.into_cpd(&y, &(&x | &z));
255 // Return the result.
256 Some(p_y_do_x_z)
257 }
258 }
259 }
260 }
261
262});