causal_hub/estimators/structures/continuous_time_peter_clark.rs

1use itertools::Itertools;
2use log::debug;
3use ndarray::{Zip, prelude::*};
4use rayon::prelude::*;
5use statrs::distribution::{ChiSquared, ContinuousCDF, FisherSnedecor};
6
7use crate::{
8    estimators::{CIMEstimator, PK},
9    models::{CIM, CatCIM, DiGraph, Graph, Labelled},
10    set,
11    types::{Labels, Set},
12};
13
14/// A trait for conditional independence testing.
15pub trait CITest {
16    /// Test for conditional independence as X _||_ Y | Z.
17    ///
18    /// # Arguments
19    ///
20    /// * `x` - The first variable.
21    /// * `y` - The second variable.
22    /// * `z` - The conditioning set.
23    ///
24    /// # Returns
25    ///
26    /// `true` if X _||_ Y | Z, `false` otherwise.
27    ///
28    fn call(&self, x: &Set<usize>, y: &Set<usize>, z: &Set<usize>) -> bool;
29}
30
/// Chi-squared conditional independence test.
///
/// Stores a borrowed estimator, used to fit the compared models, and the
/// significance level at which the test rejects independence.
pub struct ChiSquaredTest<'a, E> {
    estimator: &'a E,
    alpha: f64,
}

impl<'a, E> ChiSquaredTest<'a, E> {
    /// Builds a chi-squared test from an estimator and a significance level.
    ///
    /// # Arguments
    ///
    /// * `estimator` - The estimator used to fit the compared models.
    /// * `alpha` - The significance level, expected in the closed interval [0, 1].
    ///
    /// # Panics
    ///
    /// Panics if `alpha` lies outside [0, 1] (a NaN `alpha` also panics).
    ///
    /// # Returns
    ///
    /// The configured `ChiSquaredTest`.
    ///
    #[inline]
    pub fn new(estimator: &'a E, alpha: f64) -> Self {
        // Guard clause: reject significance levels outside the unit interval.
        let is_valid = (0.0..=1.0).contains(&alpha);
        if !is_valid {
            panic!("Alpha must be in [0, 1]");
        }

        ChiSquaredTest { estimator, alpha }
    }
}
61
62impl<'a, E> Labelled for ChiSquaredTest<'a, E>
63where
64    E: Labelled,
65{
66    #[inline]
67    fn labels(&self) -> &Labels {
68        self.estimator.labels()
69    }
70}
71
72impl<E> CITest for ChiSquaredTest<'_, E>
73where
74    E: CIMEstimator<CatCIM>,
75{
76    fn call(&self, x: &Set<usize>, y: &Set<usize>, z: &Set<usize>) -> bool {
77        // Assert Y contains exactly one label.
78        // TODO: Refactor code and remove this assumption.
79        assert_eq!(y.len(), 1, "Y must contain exactly one label.");
80
81        // Compute the extended separation set.
82        let mut s = z.clone();
83        // Get the ordered position of Y in the extended separation set.
84        let s_y = z.binary_search(&y[0]).unwrap_err();
85        // Insert Y into the extended separation set in sorted order.
86        s.shift_insert(s_y, y[0]);
87
88        // Fit the intensity matrices.
89        let q_xz = self.estimator.fit(x, z);
90        let q_xs = self.estimator.fit(x, &s);
91        // Get the sufficient statistics for the sets.
92        let n_xz = q_xz
93            .sample_statistics()
94            .map(|s| s.sample_conditional_counts())
95            .unwrap();
96        let n_xs = q_xs
97            .sample_statistics()
98            .map(|s| s.sample_conditional_counts())
99            .unwrap();
100
101        // Get the shape of the extended separation set.
102        let c_s = q_xs.conditioning_shape();
103        // Get the shape of the parent and the remaining strides.
104        let (c_y, c_s) = (c_s[s_y], c_s.slice(s![(s_y + 1)..]).product());
105
106        // For each combination of the extended parent set ...
107        (0..n_xs.shape()[0]).all(|j| {
108            // Compute the corresponding index for the separation set.
109            let i = j % c_s + (j / (c_s * c_y)) * c_s;
110            // Get the parameters of the chi-squared distribution.
111            let k_xz = n_xz.index_axis(Axis(0), i);
112            let k_xs = n_xs.index_axis(Axis(0), j);
113            // Compute the scaling factors.
114            let k = &k_xz.sum_axis(Axis(1)) / &k_xs.sum_axis(Axis(1));
115            let k = k.sqrt().insert_axis(Axis(1));
116            let l = k.recip();
117            // Compute the chi-squared statistic for uneven number of samples.
118            let chi_sq_num = (&k * &k_xs - &l * &k_xz).powi(2);
119            let chi_sq_den = &k_xs + &k_xz;
120            let chi_sq = chi_sq_num / &chi_sq_den;
121            // Fix division by zero.
122            let chi_sq = chi_sq.mapv(|x| if x.is_finite() { x } else { 0. });
123            // Compute the chi-squared statistic.
124            let chi_sq = chi_sq.sum_axis(Axis(1));
125            // For each chi-squared statistic ...
126            chi_sq
127                .into_iter()
128                .zip(chi_sq_den.rows())
129                .map(|(c, d)| {
130                    // Count the non-zero degrees of freedom.
131                    let dof = d.mapv(|d| (d > 0.) as usize).sum();
132                    // Check if the degrees of freedom is at least 2.
133                    let dof = if dof >= 2 { dof } else { 2 };
134                    // Initialize the chi-squared distribution.
135                    let n = ChiSquared::new((dof - 1) as f64).unwrap();
136                    // Compute the p-value.
137                    n.cdf(c)
138                })
139                // Check if the p-value is in the alpha range.
140                .all(|p_value| p_value < (1. - self.alpha))
141        })
142    }
143}
144
/// Fisher-Snedecor (F) conditional independence test.
///
/// Stores a borrowed estimator, used to fit the compared models, and the
/// significance level of the two-sided test.
pub struct FTest<'a, E> {
    estimator: &'a E,
    alpha: f64,
}

impl<'a, E> FTest<'a, E> {
    /// Builds an F-test from an estimator and a significance level.
    ///
    /// # Arguments
    ///
    /// * `estimator` - The estimator used to fit the compared models.
    /// * `alpha` - The significance level, expected in the closed interval [0, 1].
    ///
    /// # Panics
    ///
    /// Panics if `alpha` lies outside [0, 1] (a NaN `alpha` also panics).
    ///
    /// # Returns
    ///
    /// The configured `FTest`.
    ///
    #[inline]
    pub fn new(estimator: &'a E, alpha: f64) -> Self {
        if (0.0..=1.0).contains(&alpha) {
            FTest { estimator, alpha }
        } else {
            panic!("Alpha must be in [0, 1]")
        }
    }
}
175
176impl<E> Labelled for FTest<'_, E>
177where
178    E: Labelled,
179{
180    #[inline]
181    fn labels(&self) -> &Labels {
182        self.estimator.labels()
183    }
184}
185
186impl<E> CITest for FTest<'_, E>
187where
188    E: CIMEstimator<CatCIM>,
189{
190    fn call(&self, x: &Set<usize>, y: &Set<usize>, z: &Set<usize>) -> bool {
191        // Assert Y contains exactly one label.
192        // TODO: Refactor code and remove this assumption.
193        assert_eq!(y.len(), 1, "Y must contain exactly one label.");
194
195        // Compute the alpha range.
196        let alpha = (self.alpha / 2.)..=(1. - self.alpha / 2.);
197
198        // Compute the extended separation set.
199        let mut s = z.clone();
200        // Get the ordered position of Y in the extended separation set.
201        let s_y = z.binary_search(&y[0]).unwrap_err();
202        // Insert Y into the extended separation set in sorted order.
203        s.shift_insert(s_y, y[0]);
204
205        // Fit the intensity matrices.
206        let q_xz = self.estimator.fit(x, z);
207        let q_xs = self.estimator.fit(x, &s);
208        // Get the sufficient statistics for the sets.
209        let n_xz = q_xz
210            .sample_statistics()
211            .map(|s| s.sample_conditional_counts())
212            .unwrap();
213        let n_xs = q_xs
214            .sample_statistics()
215            .map(|s| s.sample_conditional_counts())
216            .unwrap();
217
218        // Get the shape of the extended separation set.
219        let c_s = q_xs.conditioning_shape();
220        // Get the shape of the parent and the remaining strides.
221        let (c_y, c_s) = (c_s[s_y], c_s.slice(s![(s_y + 1)..]).product());
222
223        // For each combination of the extended parent set ...
224        (0..n_xs.shape()[0]).all(|j| {
225            // Compute the corresponding index for the separation set.
226            let i = j % c_s + (j / (c_s * c_y)) * c_s;
227            // Get the parameters of the Fisher-Snedecor distribution.
228            let r_xz = n_xz.index_axis(Axis(0), i).sum_axis(Axis(1));
229            let r_xs = n_xs.index_axis(Axis(0), j).sum_axis(Axis(1));
230            // Get the intensity matrices for the separation sets.
231            let q_xz = q_xz.parameters().index_axis(Axis(0), i);
232            let q_xs = q_xs.parameters().index_axis(Axis(0), j);
233            // Perform the F-test.
234            Zip::from(&r_xz)
235                .and(&r_xs)
236                .and(q_xz.diag())
237                .and(q_xs.diag())
238                .all(|&r_xz, &r_xs, &q_xz, &q_xs| {
239                    // Initialize the Fisher-Snedecor distribution.
240                    let f = FisherSnedecor::new(r_xz, r_xs).unwrap();
241                    // Compute the p-value.
242                    let p_value = f.cdf(q_xz / q_xs);
243                    // Check if the p-value is in the alpha range.
244                    alpha.contains(&p_value)
245                })
246        })
247    }
248}
249
250/// A struct representing a continuous-time Peter-Clark estimator.
251#[derive(Clone, Debug)]
252pub struct CTPC<'a, T, S> {
253    initial_graph: &'a DiGraph,
254    null_time: &'a T,
255    null_state: &'a S,
256    prior_knowledge: Option<&'a PK>,
257}
258
259impl<'a, T, S> CTPC<'a, T, S>
260where
261    T: CITest + Labelled,
262    S: CITest + Labelled,
263{
264    /// Creates a new `CTPC` instance.
265    ///
266    /// # Arguments
267    ///
268    /// * `initial_graph` - A reference to the initial graph.
269    /// * `null_time` - A reference to the null time to transition hypothesis test.
270    /// * `null_state` - A reference to the null state-to-state transition hypothesis test.
271    ///
272    /// # Returns
273    ///
274    /// A new `CTPC` instance.
275    ///
276    #[inline]
277    pub fn new(initial_graph: &'a DiGraph, null_time: &'a T, null_state: &'a S) -> Self {
278        // Assert labels of the initial graph and the estimator are the same.
279        assert_eq!(
280            initial_graph.labels(),
281            null_time.labels(),
282            "Labels of initial graph and estimator must be the same: \n\
283            \t expected:    {:?}, \n\
284            \t found:       {:?}.",
285            initial_graph.labels(),
286            null_time.labels()
287        );
288        // Assert labels of the initial graph and the estimator are the same.
289        assert_eq!(
290            initial_graph.labels(),
291            null_state.labels(),
292            "Labels of initial graph and estimator must be the same: \n\
293            \t expected:    {:?}, \n\
294            \t found:       {:?}.",
295            initial_graph.labels(),
296            null_state.labels()
297        );
298
299        Self {
300            initial_graph,
301            null_time,
302            null_state,
303            prior_knowledge: None,
304        }
305    }
306
307    /// Sets the prior knowledge for the algorithm.
308    ///
309    /// # Arguments
310    ///
311    /// * `prior_knowledge` - The prior knowledge to use.
312    ///
313    /// # Returns
314    ///
315    /// A mutable reference to the current instance.
316    ///
317    #[inline]
318    pub fn with_prior_knowledge(mut self, prior_knowledge: &'a PK) -> Self {
319        // Assert labels of prior knowledge and initial graph are the same.
320        assert_eq!(
321            self.initial_graph.labels(),
322            prior_knowledge.labels(),
323            "Labels of initial graph and prior knowledge must be the same: \n\
324            \t expected:    {:?}, \n\
325            \t found:       {:?}.",
326            self.initial_graph.labels(),
327            prior_knowledge.labels()
328        );
329        // Assert prior knowledge is consistent with initial graph.
330        self.initial_graph
331            .vertices()
332            .into_iter()
333            .permutations(2)
334            .for_each(|edge| {
335                // Get the edge indices.
336                let (i, j) = (edge[0], edge[1]);
337                // Assert edge must be either present and not forbidden ...
338                if self.initial_graph.has_edge(i, j) {
339                    assert!(
340                        !prior_knowledge.is_forbidden(i, j),
341                        "Initial graph contains forbidden edge ({i}, {j})."
342                    );
343                // ... or absent and not required.
344                } else {
345                    assert!(
346                        !prior_knowledge.is_required(i, j),
347                        "Initial graph does not contain required edge ({i}, {j})."
348                    );
349                }
350            });
351        // Set prior knowledge.
352        self.prior_knowledge = Some(prior_knowledge);
353        self
354    }
355
356    /// Execute the CTPC algorithm.
357    ///
358    /// # Returns
359    ///
360    /// The fitted graph.
361    ///
362    pub fn fit(&self) -> DiGraph {
363        // Clone the initial graph.
364        let mut graph = self.initial_graph.clone();
365
366        // For each vertex in the graph ...
367        for i in graph.vertices() {
368            // Get the parents of the vertex.
369            let mut pa_i = graph.parents(&set![i]);
370
371            // Initialize the counter.
372            let mut k = 0;
373
374            // While the counter is smaller than the number of parents ...
375            while k < pa_i.len() {
376                // Initialize the set of vertices to remove, to ensure stability.
377                let mut not_pa_i = Vec::new();
378
379                // For each parent ...
380                for &j in &pa_i {
381                    // Check prior knowledge, if available.
382                    if let Some(pk) = self.prior_knowledge {
383                        // If the edge is required, skip the tests.
384                        // NOTE: Since CTPC only removes edges,
385                        //  it is sufficient to check for required edges.
386                        if pk.is_required(j, i) {
387                            // Log the skipped CIT.
388                            debug!("CIT for {j} _||_ {i} | [*] ... SKIPPED");
389                            continue;
390                        }
391                    }
392                    // Filter out the parent.
393                    let pa_i_not_j = pa_i.iter().filter(|&&z| z != j).cloned();
394                    // For any combination of size k of Pa(X_i) \ { X_j } ...
395                    for s_ij in pa_i_not_j.combinations(k).map(Set::from_iter) {
396                        // Log the current combination.
397                        debug!("CIT for {i} _||_ {j} | {s_ij:?} ...");
398                        // If X_i _||_ X_j | S_{X_i, X_j} ...
399                        if self.null_time.call(&set![i], &set![j], &s_ij)
400                            && self.null_state.call(&set![i], &set![j], &s_ij)
401                        {
402                            // Log the result of the CIT.
403                            debug!("CIT for {i} _||_ {j} | {s_ij:?} ... PASSED");
404                            // Add the parent to the set of vertices to remove.
405                            not_pa_i.push(j);
406                            // Break the outer loop.
407                            break;
408                        }
409                    }
410                }
411
412                // Remove the vertices from the graph.
413                for &j in &not_pa_i {
414                    // Remove the vertex from the parents.
415                    pa_i.retain(|&x| x != j);
416                    // Remove the edge from the graph.
417                    graph.del_edge(j, i);
418                }
419
420                // Increment the counter.
421                k += 1;
422            }
423        }
424
425        // Return the fitted graph.
426        graph
427    }
428}
429
430impl<'a, T, S> CTPC<'a, T, S>
431where
432    T: CITest + Sync,
433    S: CITest + Sync,
434{
435    /// Execute the CTPC algorithm and return the fitted graph in parallel.
436    ///
437    /// # Returns
438    ///
439    /// The fitted graph.
440    ///
441    pub fn par_fit(&self) -> DiGraph {
442        // For each vertex in the graph ...
443        let parents: Vec<_> = self
444            .initial_graph
445            .vertices()
446            .into_par_iter()
447            .map(|i| {
448                // Get the parents of the vertex.
449                let mut pa_i = self.initial_graph.parents(&set![i]);
450
451                // Initialize the counter.
452                let mut k = 0;
453
454                // While the counter is smaller than the number of parents ...
455                while k < pa_i.len() {
456                    // Filter the parents in parallel.
457                    pa_i = pa_i
458                        .par_iter()
459                        .filter_map(|&j| {
460                            // Check prior knowledge, if available.
461                            if let Some(pk) = self.prior_knowledge {
462                                // If the edge is required, skip the tests.
463                                // NOTE: Since CTPC only removes edges,
464                                //  it is sufficient to check for required edges.
465                                if pk.is_required(j, i) {
466                                    // Log the skipped CIT.
467                                    debug!("CIT for {j} _||_ {i} | [*] ... SKIPPED");
468                                    return Some(j);
469                                }
470                            }
471                            // Filter out the parent.
472                            let pa_i_not_j = pa_i.iter().filter(|&&z| z != j).cloned();
473                            // For any combination of size k of Pa(X_i) \ { X_j } ...
474                            for s_ij in pa_i_not_j.combinations(k).map(Set::from_iter) {
475                                // Log the current combination.
476                                debug!("CIT for {i} _||_ {j} | {s_ij:?} ...");
477                                // If X_i _||_ X_j | S_{X_i, X_j} ...
478                                if self.null_time.call(&set![i], &set![j], &s_ij)
479                                    && self.null_state.call(&set![i], &set![j], &s_ij)
480                                {
481                                    // Log the result of the CIT.
482                                    debug!("CIT for {i} _||_ {j} | {s_ij:?} ... PASSED");
483                                    // Add the parent to the set of vertices to remove.
484                                    return None;
485                                }
486                            }
487                            // Otherwise, keep the parent.
488                            Some(j)
489                        })
490                        .collect();
491                    // Increment the counter.
492                    k += 1;
493                }
494
495                // Return the parents of the vertex.
496                pa_i
497            })
498            .collect();
499
500        // Initialize an empty graph.
501        let mut graph = DiGraph::empty(self.initial_graph.labels());
502
503        // Set the parents of each vertex.
504        parents.into_iter().enumerate().for_each(|(i, pa_i)| {
505            // For each parent ...
506            pa_i.into_iter().for_each(|j| {
507                // Add the edge to the graph.
508                graph.add_edge(j, i);
509            })
510        });
511
512        // Return the fitted graph.
513        graph
514    }
515}