causal_hub/inference/
causal_inference.rs

1use dry::macro_for;
2
3use crate::{
4    inference::{BNInference, BackdoorCriterion, Modelled, ParBNInference},
5    models::{BN, CatBN, GaussBN, Labelled, Phi},
6    set,
7    types::Set,
8};
9
/// A causal inference engine.
///
/// Thin wrapper that borrows an underlying inference engine of type `E`
/// for the lifetime `'a`.
#[derive(Debug)]
pub struct CausalInference<'a, E> {
    /// The borrowed underlying inference engine.
    engine: &'a E,
}

// `Clone` is implemented by hand rather than derived: `#[derive(Clone)]` would
// require `E: Clone`, although cloning only copies the shared reference.
impl<E> Clone for CausalInference<'_, E> {
    fn clone(&self) -> Self {
        Self {
            engine: self.engine,
        }
    }
}
15
16impl<'a, E> CausalInference<'a, E> {
17    /// Create a new causal inference engine.
18    ///
19    /// # Arguments
20    ///
21    /// * `engine` - The underlying inference engine.
22    ///
23    /// # Returns
24    ///
25    /// The causal inference engine.
26    ///
27    pub fn new(engine: &'a E) -> Self {
28        Self { engine }
29    }
30}
31
32/// A trait for causal inference with Bayesian Networks.
33pub trait BNCausalInference<T>
34where
35    T: BN,
36{
37    /// Estimate the average causal effect of `X` on `Y` as E(Y | do(X)).
38    ///
39    /// # Arguments
40    ///
41    /// * `x` - The cause variables.
42    /// * `y` - The effect variables.
43    ///
44    /// # Panics
45    ///
46    /// * If `X` is empty.
47    /// * If `Y` is empty.
48    /// * If `X` and `Y` are not disjoint.
49    ///
50    /// # Returns
51    ///
52    /// The estimated average causal effect of `X` on `Y`.
53    ///
54    fn ace_estimate(&self, x: &Set<usize>, y: &Set<usize>) -> Option<T::CPD> {
55        self.cace_estimate(x, y, &set![])
56    }
57
58    /// Estimate the conditional average causal effect of `X` on `Y` given `Z` as E(Y | do(X), Z).
59    ///
60    /// # Arguments
61    ///
62    /// * `x` - The cause variables.
63    /// * `y` - The effect variables.
64    /// * `z` - The conditioning variables.
65    ///
66    /// # Panics
67    ///
68    /// * If `X` is empty.
69    /// * If `Y` is empty.
70    /// * If `X` and `Y` are not disjoint.
71    /// * If `X` and `Z` are not disjoint.
72    /// * If `Y` and `Z` are not disjoint.
73    ///
74    /// # Returns
75    ///
76    /// The estimated conditional average causal effect of `X` on `Y` given `Z`.
77    ///
78    fn cace_estimate(&self, x: &Set<usize>, y: &Set<usize>, z: &Set<usize>) -> Option<T::CPD>;
79}
80
81macro_for!($type in [CatBN, GaussBN] {
82
83    impl<E> BNCausalInference<$type> for CausalInference<'_, E>
84    where
85        E: Modelled<$type> + BNInference<$type>,
86    {
87        fn cace_estimate(&self, x: &Set<usize>, y: &Set<usize>, z: &Set<usize>) -> Option<<$type as BN>::CPD> {
88            // Assert X is not empty.
89            assert!(!x.is_empty(), "Variables X must not be empty.");
90            // Assert Y is not empty.
91            assert!(!y.is_empty(), "Variables Y must not be empty.");
92            // Assert X and Y are disjoint.
93            assert!(x.is_disjoint(y), "Variables X and Y must be disjoint.");
94            // Assert X and Z are disjoint.
95            assert!(x.is_disjoint(z), "Variables X and Z must be disjoint.");
96            // Assert Y and Z are disjoint.
97            assert!(y.is_disjoint(z), "Variables Y and Z must be disjoint.");
98
99            /* Effect Identification */
100
101            // Get the model.
102            let m = self.engine.model();
103            // Find a minimal backdoor adjustment set Z \cup S, if any.
104            let z_s = m.graph().find_minimal_backdoor_set(x, y, Some(z), None);
105
106            /* Effect Estimation */
107
108            // Match on the backdoor adjustment set.
109            match z_s {
110                // If no backdoor adjustment set exists, return None.
111                None => None,
112                // If the backdoor adjustment set is empty ...
113                Some(z_s) if z_s.is_empty() => {
114                    // ... estimate P(Y | do(X), Z) as P(Y | X, Z).
115                    Some(self.engine.estimate(y, &(x | z)))
116                }
117                // If the backdoor adjustment set is non-empty ...
118                Some(z_s) => {
119                    // Get the S part.
120                    let s = &(&z_s - z);
121                    // Estimate P(Y | X, Z, S) and P(S).
122                    let p_y_x_z_s = self.engine.estimate(y, &(x | s));
123                    let p_s = self.engine.estimate(s, &set![]);
124                    // Convert to potentials for aligned multiplication.
125                    let p_y_x_z_s = p_y_x_z_s.into_phi();
126                    let p_s = p_s.into_phi();
127                    // Compute P(Y | X, Z, S) * P(S) using potentials.
128                    let p_y_s_do_x_z = &p_y_x_z_s * &p_s;
129                    // Map BN indices to the potential indices.
130                    let s = p_y_s_do_x_z.indices_from(s, m.labels());
131                    // Marginalize over S.
132                    let p_y_do_x_z = p_y_s_do_x_z.marginalize(&s);
133                    // Map BN indices to the potential indices.
134                    let x = p_y_do_x_z.indices_from(x, m.labels());
135                    let y = p_y_do_x_z.indices_from(y, m.labels());
136                    let z = p_y_do_x_z.indices_from(z, m.labels());
137                    // Convert back to CPD.
138                    let p_y_do_x_z = p_y_do_x_z.into_cpd(&y, &(&x | &z));
139                    // Return the result.
140                    Some(p_y_do_x_z)
141                }
142            }
143        }
144    }
145
146});
147
148/// A trait for causal inference with Bayesian Networks in parallel.
149pub trait ParBNCausalInference<T>
150where
151    T: BN,
152{
153    /// Estimate the average causal effect of `X` on `Y` as E(Y | do(X)) in parallel.
154    ///
155    /// # Arguments
156    ///
157    /// * `x` - The cause variables.
158    /// * `y` - The effect variables.
159    ///
160    /// # Panics
161    ///
162    /// * If `X` is empty.
163    /// * If `Y` is empty.
164    /// * If `X` and `Y` are not disjoint.
165    ///
166    /// # Returns
167    ///
168    /// The estimated average causal effect of `X` on `Y`.
169    ///
170    fn par_ace_estimate(&self, x: &Set<usize>, y: &Set<usize>) -> Option<T::CPD> {
171        self.par_cace_estimate(x, y, &set![])
172    }
173
174    /// Estimate the conditional average causal effect of `X` on `Y` given `Z` as E(Y | do(X), Z) in parallel.
175    ///
176    /// # Arguments
177    ///
178    /// * `x` - The cause variables.
179    /// * `y` - The effect variables.
180    /// * `z` - The conditioning variables.
181    ///
182    /// # Panics
183    ///
184    /// * If `X` is empty.
185    /// * If `Y` is empty.
186    /// * If `X` and `Y` are not disjoint.
187    /// * If `X` and `Z` are not disjoint.
188    /// * If `Y` and `Z` are not disjoint.
189    ///
190    /// # Returns
191    ///
192    /// The estimated conditional average causal effect of `X` on `Y` given `Z`.
193    ///
194    fn par_cace_estimate(&self, x: &Set<usize>, y: &Set<usize>, z: &Set<usize>) -> Option<T::CPD>;
195}
196
197macro_for!($type in [CatBN, GaussBN] {
198
199    impl<E> ParBNCausalInference<$type> for CausalInference<'_, E>
200    where
201        E: Modelled<$type> + ParBNInference<$type>,
202    {
203        fn par_cace_estimate(&self, x: &Set<usize>, y: &Set<usize>, z: &Set<usize>) -> Option<<$type as BN>::CPD> {
204            // Assert X is not empty.
205            assert!(!x.is_empty(), "Variables X must not be empty.");
206            // Assert Y is not empty.
207            assert!(!y.is_empty(), "Variables Y must not be empty.");
208            // Assert X and Y are disjoint.
209            assert!(x.is_disjoint(y), "Variables X and Y must be disjoint.");
210            // Assert X and Z are disjoint.
211            assert!(x.is_disjoint(z), "Variables X and Z must be disjoint.");
212            // Assert Y and Z are disjoint.
213            assert!(y.is_disjoint(z), "Variables Y and Z must be disjoint.");
214
215            /* Effect Identification */
216
217            // Get the model.
218            let m = self.engine.model();
219            // Find a minimal backdoor adjustment set Z \cup S, if any.
220            let z_s = m.graph().find_minimal_backdoor_set(x, y, Some(z), None);
221
222            /* Effect Estimation */
223
224            // Match on the backdoor adjustment set.
225            match z_s {
226                // If no backdoor adjustment set exists, return None.
227                None => None,
228                // If the backdoor adjustment set is empty ...
229                Some(z_s) if z_s.is_empty() => {
230                    // ... estimate P(Y | do(X), Z) as P(Y | X, Z).
231                    Some(self.engine.par_estimate(y, &(x | z)))
232                }
233                // If the backdoor adjustment set is non-empty ...
234                Some(z_s) => {
235                    // Get the S part.
236                    let s = &(&z_s - z);
237                    // Estimate P(Y | X, Z, S) and P(S).
238                    let p_y_x_z_s = self.engine.par_estimate(y, &(x | s));
239                    let p_s = self.engine.par_estimate(s, &set![]);
240                    // Convert to potentials for aligned multiplication.
241                    let p_y_x_z_s = p_y_x_z_s.into_phi();
242                    let p_s = p_s.into_phi();
243                    // Compute P(Y | X, Z, S) * P(S) using potentials.
244                    let p_y_s_do_x_z = &p_y_x_z_s * &p_s;
245                    // Map BN indices to the potential indices.
246                    let s = p_y_s_do_x_z.indices_from(s, m.labels());
247                    // Marginalize over S.
248                    let p_y_do_x_z = p_y_s_do_x_z.marginalize(&s);
249                    // Map BN indices to the potential indices.
250                    let x = p_y_do_x_z.indices_from(x, m.labels());
251                    let y = p_y_do_x_z.indices_from(y, m.labels());
252                    let z = p_y_do_x_z.indices_from(z, m.labels());
253                    // Convert back to CPD.
254                    let p_y_do_x_z = p_y_do_x_z.into_cpd(&y, &(&x | &z));
255                    // Return the result.
256                    Some(p_y_do_x_z)
257                }
258            }
259        }
260    }
261
262});