causal_hub/estimators/structures/continuous_time_peter_clark.rs
1use itertools::Itertools;
2use log::debug;
3use ndarray::{Zip, prelude::*};
4use rayon::prelude::*;
5use statrs::distribution::{ChiSquared, ContinuousCDF, FisherSnedecor};
6
7use crate::{
8 estimators::{CIMEstimator, PK},
9 models::{CIM, CatCIM, DiGraph, Graph, Labelled},
10 set,
11 types::{Labels, Set},
12};
13
14/// A trait for conditional independence testing.
15pub trait CITest {
16 /// Test for conditional independence as X _||_ Y | Z.
17 ///
18 /// # Arguments
19 ///
20 /// * `x` - The first variable.
21 /// * `y` - The second variable.
22 /// * `z` - The conditioning set.
23 ///
24 /// # Returns
25 ///
26 /// `true` if X _||_ Y | Z, `false` otherwise.
27 ///
28 fn call(&self, x: &Set<usize>, y: &Set<usize>, z: &Set<usize>) -> bool;
29}
30
/// Chi-squared conditional independence test built on top of an estimator.
pub struct ChiSquaredTest<'a, E> {
    /// Borrowed estimator the test delegates to.
    estimator: &'a E,
    /// Significance level, validated to lie in [0, 1] by `new`.
    alpha: f64,
}

impl<'a, E> ChiSquaredTest<'a, E> {
    /// Builds a `ChiSquaredTest` from an estimator and a significance level.
    ///
    /// # Arguments
    ///
    /// * `estimator` - A reference to the estimator.
    /// * `alpha` - The significance level.
    ///
    /// # Panics
    ///
    /// Panics if `alpha` lies outside [0, 1]; NaN is rejected as well, since
    /// it is not contained in any range.
    ///
    /// # Returns
    ///
    /// The newly created `ChiSquaredTest`.
    ///
    #[inline]
    pub fn new(estimator: &'a E, alpha: f64) -> Self {
        // Guard clause: refuse any significance level outside [0, 1].
        if !(0.0..=1.0).contains(&alpha) {
            panic!("Alpha must be in [0, 1]");
        }

        Self { estimator, alpha }
    }
}
61
62impl<'a, E> Labelled for ChiSquaredTest<'a, E>
63where
64 E: Labelled,
65{
66 #[inline]
67 fn labels(&self) -> &Labels {
68 self.estimator.labels()
69 }
70}
71
72impl<E> CITest for ChiSquaredTest<'_, E>
73where
74 E: CIMEstimator<CatCIM>,
75{
76 fn call(&self, x: &Set<usize>, y: &Set<usize>, z: &Set<usize>) -> bool {
77 // Assert Y contains exactly one label.
78 // TODO: Refactor code and remove this assumption.
79 assert_eq!(y.len(), 1, "Y must contain exactly one label.");
80
81 // Compute the extended separation set.
82 let mut s = z.clone();
83 // Get the ordered position of Y in the extended separation set.
84 let s_y = z.binary_search(&y[0]).unwrap_err();
85 // Insert Y into the extended separation set in sorted order.
86 s.shift_insert(s_y, y[0]);
87
88 // Fit the intensity matrices.
89 let q_xz = self.estimator.fit(x, z);
90 let q_xs = self.estimator.fit(x, &s);
91 // Get the sufficient statistics for the sets.
92 let n_xz = q_xz
93 .sample_statistics()
94 .map(|s| s.sample_conditional_counts())
95 .unwrap();
96 let n_xs = q_xs
97 .sample_statistics()
98 .map(|s| s.sample_conditional_counts())
99 .unwrap();
100
101 // Get the shape of the extended separation set.
102 let c_s = q_xs.conditioning_shape();
103 // Get the shape of the parent and the remaining strides.
104 let (c_y, c_s) = (c_s[s_y], c_s.slice(s![(s_y + 1)..]).product());
105
106 // For each combination of the extended parent set ...
107 (0..n_xs.shape()[0]).all(|j| {
108 // Compute the corresponding index for the separation set.
109 let i = j % c_s + (j / (c_s * c_y)) * c_s;
110 // Get the parameters of the chi-squared distribution.
111 let k_xz = n_xz.index_axis(Axis(0), i);
112 let k_xs = n_xs.index_axis(Axis(0), j);
113 // Compute the scaling factors.
114 let k = &k_xz.sum_axis(Axis(1)) / &k_xs.sum_axis(Axis(1));
115 let k = k.sqrt().insert_axis(Axis(1));
116 let l = k.recip();
117 // Compute the chi-squared statistic for uneven number of samples.
118 let chi_sq_num = (&k * &k_xs - &l * &k_xz).powi(2);
119 let chi_sq_den = &k_xs + &k_xz;
120 let chi_sq = chi_sq_num / &chi_sq_den;
121 // Fix division by zero.
122 let chi_sq = chi_sq.mapv(|x| if x.is_finite() { x } else { 0. });
123 // Compute the chi-squared statistic.
124 let chi_sq = chi_sq.sum_axis(Axis(1));
125 // For each chi-squared statistic ...
126 chi_sq
127 .into_iter()
128 .zip(chi_sq_den.rows())
129 .map(|(c, d)| {
130 // Count the non-zero degrees of freedom.
131 let dof = d.mapv(|d| (d > 0.) as usize).sum();
132 // Check if the degrees of freedom is at least 2.
133 let dof = if dof >= 2 { dof } else { 2 };
134 // Initialize the chi-squared distribution.
135 let n = ChiSquared::new((dof - 1) as f64).unwrap();
136 // Compute the p-value.
137 n.cdf(c)
138 })
139 // Check if the p-value is in the alpha range.
140 .all(|p_value| p_value < (1. - self.alpha))
141 })
142 }
143}
144
/// F conditional independence test built on top of an estimator.
pub struct FTest<'a, E> {
    /// Borrowed estimator the test delegates to.
    estimator: &'a E,
    /// Significance level, validated to lie in [0, 1] by `new`.
    alpha: f64,
}

impl<'a, E> FTest<'a, E> {
    /// Builds an `FTest` from an estimator and a significance level.
    ///
    /// # Arguments
    ///
    /// * `estimator` - A reference to the estimator.
    /// * `alpha` - The significance level.
    ///
    /// # Panics
    ///
    /// Panics if `alpha` lies outside [0, 1]; NaN is rejected as well, since
    /// it is not contained in any range.
    ///
    /// # Returns
    ///
    /// The newly created `FTest`.
    ///
    #[inline]
    pub fn new(estimator: &'a E, alpha: f64) -> Self {
        // Guard clause: refuse any significance level outside [0, 1].
        if !(0.0..=1.0).contains(&alpha) {
            panic!("Alpha must be in [0, 1]");
        }

        Self { estimator, alpha }
    }
}
175
176impl<E> Labelled for FTest<'_, E>
177where
178 E: Labelled,
179{
180 #[inline]
181 fn labels(&self) -> &Labels {
182 self.estimator.labels()
183 }
184}
185
186impl<E> CITest for FTest<'_, E>
187where
188 E: CIMEstimator<CatCIM>,
189{
190 fn call(&self, x: &Set<usize>, y: &Set<usize>, z: &Set<usize>) -> bool {
191 // Assert Y contains exactly one label.
192 // TODO: Refactor code and remove this assumption.
193 assert_eq!(y.len(), 1, "Y must contain exactly one label.");
194
195 // Compute the alpha range.
196 let alpha = (self.alpha / 2.)..=(1. - self.alpha / 2.);
197
198 // Compute the extended separation set.
199 let mut s = z.clone();
200 // Get the ordered position of Y in the extended separation set.
201 let s_y = z.binary_search(&y[0]).unwrap_err();
202 // Insert Y into the extended separation set in sorted order.
203 s.shift_insert(s_y, y[0]);
204
205 // Fit the intensity matrices.
206 let q_xz = self.estimator.fit(x, z);
207 let q_xs = self.estimator.fit(x, &s);
208 // Get the sufficient statistics for the sets.
209 let n_xz = q_xz
210 .sample_statistics()
211 .map(|s| s.sample_conditional_counts())
212 .unwrap();
213 let n_xs = q_xs
214 .sample_statistics()
215 .map(|s| s.sample_conditional_counts())
216 .unwrap();
217
218 // Get the shape of the extended separation set.
219 let c_s = q_xs.conditioning_shape();
220 // Get the shape of the parent and the remaining strides.
221 let (c_y, c_s) = (c_s[s_y], c_s.slice(s![(s_y + 1)..]).product());
222
223 // For each combination of the extended parent set ...
224 (0..n_xs.shape()[0]).all(|j| {
225 // Compute the corresponding index for the separation set.
226 let i = j % c_s + (j / (c_s * c_y)) * c_s;
227 // Get the parameters of the Fisher-Snedecor distribution.
228 let r_xz = n_xz.index_axis(Axis(0), i).sum_axis(Axis(1));
229 let r_xs = n_xs.index_axis(Axis(0), j).sum_axis(Axis(1));
230 // Get the intensity matrices for the separation sets.
231 let q_xz = q_xz.parameters().index_axis(Axis(0), i);
232 let q_xs = q_xs.parameters().index_axis(Axis(0), j);
233 // Perform the F-test.
234 Zip::from(&r_xz)
235 .and(&r_xs)
236 .and(q_xz.diag())
237 .and(q_xs.diag())
238 .all(|&r_xz, &r_xs, &q_xz, &q_xs| {
239 // Initialize the Fisher-Snedecor distribution.
240 let f = FisherSnedecor::new(r_xz, r_xs).unwrap();
241 // Compute the p-value.
242 let p_value = f.cdf(q_xz / q_xs);
243 // Check if the p-value is in the alpha range.
244 alpha.contains(&p_value)
245 })
246 })
247 }
248}
249
250/// A struct representing a continuous-time Peter-Clark estimator.
251#[derive(Clone, Debug)]
252pub struct CTPC<'a, T, S> {
253 initial_graph: &'a DiGraph,
254 null_time: &'a T,
255 null_state: &'a S,
256 prior_knowledge: Option<&'a PK>,
257}
258
259impl<'a, T, S> CTPC<'a, T, S>
260where
261 T: CITest + Labelled,
262 S: CITest + Labelled,
263{
264 /// Creates a new `CTPC` instance.
265 ///
266 /// # Arguments
267 ///
268 /// * `initial_graph` - A reference to the initial graph.
269 /// * `null_time` - A reference to the null time to transition hypothesis test.
270 /// * `null_state` - A reference to the null state-to-state transition hypothesis test.
271 ///
272 /// # Returns
273 ///
274 /// A new `CTPC` instance.
275 ///
276 #[inline]
277 pub fn new(initial_graph: &'a DiGraph, null_time: &'a T, null_state: &'a S) -> Self {
278 // Assert labels of the initial graph and the estimator are the same.
279 assert_eq!(
280 initial_graph.labels(),
281 null_time.labels(),
282 "Labels of initial graph and estimator must be the same: \n\
283 \t expected: {:?}, \n\
284 \t found: {:?}.",
285 initial_graph.labels(),
286 null_time.labels()
287 );
288 // Assert labels of the initial graph and the estimator are the same.
289 assert_eq!(
290 initial_graph.labels(),
291 null_state.labels(),
292 "Labels of initial graph and estimator must be the same: \n\
293 \t expected: {:?}, \n\
294 \t found: {:?}.",
295 initial_graph.labels(),
296 null_state.labels()
297 );
298
299 Self {
300 initial_graph,
301 null_time,
302 null_state,
303 prior_knowledge: None,
304 }
305 }
306
307 /// Sets the prior knowledge for the algorithm.
308 ///
309 /// # Arguments
310 ///
311 /// * `prior_knowledge` - The prior knowledge to use.
312 ///
313 /// # Returns
314 ///
315 /// A mutable reference to the current instance.
316 ///
317 #[inline]
318 pub fn with_prior_knowledge(mut self, prior_knowledge: &'a PK) -> Self {
319 // Assert labels of prior knowledge and initial graph are the same.
320 assert_eq!(
321 self.initial_graph.labels(),
322 prior_knowledge.labels(),
323 "Labels of initial graph and prior knowledge must be the same: \n\
324 \t expected: {:?}, \n\
325 \t found: {:?}.",
326 self.initial_graph.labels(),
327 prior_knowledge.labels()
328 );
329 // Assert prior knowledge is consistent with initial graph.
330 self.initial_graph
331 .vertices()
332 .into_iter()
333 .permutations(2)
334 .for_each(|edge| {
335 // Get the edge indices.
336 let (i, j) = (edge[0], edge[1]);
337 // Assert edge must be either present and not forbidden ...
338 if self.initial_graph.has_edge(i, j) {
339 assert!(
340 !prior_knowledge.is_forbidden(i, j),
341 "Initial graph contains forbidden edge ({i}, {j})."
342 );
343 // ... or absent and not required.
344 } else {
345 assert!(
346 !prior_knowledge.is_required(i, j),
347 "Initial graph does not contain required edge ({i}, {j})."
348 );
349 }
350 });
351 // Set prior knowledge.
352 self.prior_knowledge = Some(prior_knowledge);
353 self
354 }
355
356 /// Execute the CTPC algorithm.
357 ///
358 /// # Returns
359 ///
360 /// The fitted graph.
361 ///
362 pub fn fit(&self) -> DiGraph {
363 // Clone the initial graph.
364 let mut graph = self.initial_graph.clone();
365
366 // For each vertex in the graph ...
367 for i in graph.vertices() {
368 // Get the parents of the vertex.
369 let mut pa_i = graph.parents(&set![i]);
370
371 // Initialize the counter.
372 let mut k = 0;
373
374 // While the counter is smaller than the number of parents ...
375 while k < pa_i.len() {
376 // Initialize the set of vertices to remove, to ensure stability.
377 let mut not_pa_i = Vec::new();
378
379 // For each parent ...
380 for &j in &pa_i {
381 // Check prior knowledge, if available.
382 if let Some(pk) = self.prior_knowledge {
383 // If the edge is required, skip the tests.
384 // NOTE: Since CTPC only removes edges,
385 // it is sufficient to check for required edges.
386 if pk.is_required(j, i) {
387 // Log the skipped CIT.
388 debug!("CIT for {j} _||_ {i} | [*] ... SKIPPED");
389 continue;
390 }
391 }
392 // Filter out the parent.
393 let pa_i_not_j = pa_i.iter().filter(|&&z| z != j).cloned();
394 // For any combination of size k of Pa(X_i) \ { X_j } ...
395 for s_ij in pa_i_not_j.combinations(k).map(Set::from_iter) {
396 // Log the current combination.
397 debug!("CIT for {i} _||_ {j} | {s_ij:?} ...");
398 // If X_i _||_ X_j | S_{X_i, X_j} ...
399 if self.null_time.call(&set![i], &set![j], &s_ij)
400 && self.null_state.call(&set![i], &set![j], &s_ij)
401 {
402 // Log the result of the CIT.
403 debug!("CIT for {i} _||_ {j} | {s_ij:?} ... PASSED");
404 // Add the parent to the set of vertices to remove.
405 not_pa_i.push(j);
406 // Break the outer loop.
407 break;
408 }
409 }
410 }
411
412 // Remove the vertices from the graph.
413 for &j in ¬_pa_i {
414 // Remove the vertex from the parents.
415 pa_i.retain(|&x| x != j);
416 // Remove the edge from the graph.
417 graph.del_edge(j, i);
418 }
419
420 // Increment the counter.
421 k += 1;
422 }
423 }
424
425 // Return the fitted graph.
426 graph
427 }
428}
429
430impl<'a, T, S> CTPC<'a, T, S>
431where
432 T: CITest + Sync,
433 S: CITest + Sync,
434{
435 /// Execute the CTPC algorithm and return the fitted graph in parallel.
436 ///
437 /// # Returns
438 ///
439 /// The fitted graph.
440 ///
441 pub fn par_fit(&self) -> DiGraph {
442 // For each vertex in the graph ...
443 let parents: Vec<_> = self
444 .initial_graph
445 .vertices()
446 .into_par_iter()
447 .map(|i| {
448 // Get the parents of the vertex.
449 let mut pa_i = self.initial_graph.parents(&set![i]);
450
451 // Initialize the counter.
452 let mut k = 0;
453
454 // While the counter is smaller than the number of parents ...
455 while k < pa_i.len() {
456 // Filter the parents in parallel.
457 pa_i = pa_i
458 .par_iter()
459 .filter_map(|&j| {
460 // Check prior knowledge, if available.
461 if let Some(pk) = self.prior_knowledge {
462 // If the edge is required, skip the tests.
463 // NOTE: Since CTPC only removes edges,
464 // it is sufficient to check for required edges.
465 if pk.is_required(j, i) {
466 // Log the skipped CIT.
467 debug!("CIT for {j} _||_ {i} | [*] ... SKIPPED");
468 return Some(j);
469 }
470 }
471 // Filter out the parent.
472 let pa_i_not_j = pa_i.iter().filter(|&&z| z != j).cloned();
473 // For any combination of size k of Pa(X_i) \ { X_j } ...
474 for s_ij in pa_i_not_j.combinations(k).map(Set::from_iter) {
475 // Log the current combination.
476 debug!("CIT for {i} _||_ {j} | {s_ij:?} ...");
477 // If X_i _||_ X_j | S_{X_i, X_j} ...
478 if self.null_time.call(&set![i], &set![j], &s_ij)
479 && self.null_state.call(&set![i], &set![j], &s_ij)
480 {
481 // Log the result of the CIT.
482 debug!("CIT for {i} _||_ {j} | {s_ij:?} ... PASSED");
483 // Add the parent to the set of vertices to remove.
484 return None;
485 }
486 }
487 // Otherwise, keep the parent.
488 Some(j)
489 })
490 .collect();
491 // Increment the counter.
492 k += 1;
493 }
494
495 // Return the parents of the vertex.
496 pa_i
497 })
498 .collect();
499
500 // Initialize an empty graph.
501 let mut graph = DiGraph::empty(self.initial_graph.labels());
502
503 // Set the parents of each vertex.
504 parents.into_iter().enumerate().for_each(|(i, pa_i)| {
505 // For each parent ...
506 pa_i.into_iter().for_each(|j| {
507 // Add the edge to the graph.
508 graph.add_edge(j, i);
509 })
510 });
511
512 // Return the fitted graph.
513 graph
514 }
515}