causal_hub/models/
graphical_separation.rs

1use std::{collections::BTreeSet, fmt::Debug};
2
3use super::{GeneralizedIndependence, Independence, MoralGraph};
4use crate::{
5    graphs::directions,
6    prelude::{BaseGraph, DirectedGraph, UndirectedGraph, CC},
7    utils::UnionFind,
8    Adj, An, Ch, Ne, V,
9};
10
11/// Graphical independence struct
12#[derive(Clone, Debug)]
13pub struct GraphicalSeparation<'a, G, D>
14where
15    G: BaseGraph<Direction = D>,
16{
17    g: &'a G,
18}
19
20impl<'a, G, D> GraphicalSeparation<'a, G, D>
21where
22    G: BaseGraph<Direction = D>,
23{
24    /// Build a new graphical independence struct.
25    ///
26    /// # Panics
27    ///
28    /// If $\mathbf{X}$, $\mathbf{Y}$ and $\mathbf{Z}$
29    /// are not disjoint subsets of $\mathbf{V}$.
30    ///
31    /// # Examples
32    ///
33    /// ```
34    /// use causal_hub::prelude::*;
35    ///
36    /// // Build a new directed graph.
37    /// let g = DiGraph::new(
38    ///     ["A", "B", "C", "D", "E", "F"],
39    ///     [
40    ///         ("A", "C"),
41    ///         ("B", "C"),
42    ///         ("C", "D"),
43    ///         ("C", "E"),
44    ///     ]
45    /// );
46    ///
47    /// // Build d-separation query struct.
48    /// let q = GSeparation::from(&g);
49    ///
50    /// // Assert A _||_ B | { } .
51    /// assert!(q.are_independent([0], [1], []));
52    /// // Assert A _||_ B | { C } .
53    /// assert!(!q.are_independent([0], [1], [2]));
54    /// // Assert A _||_ D | { } .
55    /// assert!(!q.are_independent([0], [3], []));
56    /// // Assert A _||_ D | { C } .
57    /// assert!(q.are_independent([0], [3], [2]));
58    /// // Assert { A, B } _||_ { D, E } | { C } .
59    /// assert!(q.are_independent([0, 1], [3, 4], [2]));
60    /// ```
61    ///
62    #[inline]
63    pub const fn new(g: &'a G) -> Self {
64        Self { g }
65    }
66}
67
68impl<'a, G, D> From<&'a G> for GraphicalSeparation<'a, G, D>
69where
70    G: BaseGraph<Direction = D>,
71{
72    #[inline]
73    fn from(g: &'a G) -> Self {
74        Self::new(g)
75    }
76}
77
78/* Implement u-separation */
79impl<'a, G> Independence for GraphicalSeparation<'a, G, directions::Undirected>
80where
81    G: UndirectedGraph<Direction = directions::Undirected>,
82{
83    #[inline]
84    fn is_independent(&self, x: usize, y: usize, z: &[usize]) -> bool {
85        // TODO: Implement more efficient non-generalized version.
86        <Self as GeneralizedIndependence>::are_independent(self, [x], [y], z.iter().cloned())
87    }
88}
89
90impl<'a, G> GeneralizedIndependence for GraphicalSeparation<'a, G, directions::Undirected>
91where
92    G: UndirectedGraph<Direction = directions::Undirected>,
93{
94    /// Checks whether $\mathbf{X} \mathrlap{\thinspace\perp}{\perp}_{\mathcal{G}} \mathbf{Y} \mid \mathbf{Z}$ holds or not.
95    fn are_independent<I, J, K>(&self, x: I, y: J, z: K) -> bool
96    where
97        I: IntoIterator<Item = usize>,
98        J: IntoIterator<Item = usize>,
99        K: IntoIterator<Item = usize>,
100    {
101        // Check that X and Y are non-empty.
102        let x: BTreeSet<_> = x.into_iter().collect();
103        let y: BTreeSet<_> = y.into_iter().collect();
104        assert!(!x.is_empty() && !y.is_empty(), "X and Y must be non-empty");
105
106        // Check that X, Y and Z are disjoint, if not panic.
107        let z: BTreeSet<_> = z.into_iter().collect();
108        assert!(
109            x.is_disjoint(&y) && y.is_disjoint(&z) && z.is_disjoint(&x),
110            "X, Y and Z must be disjoint sets"
111        );
112
113        // Check that X, Y and Z are in V, if not panic.
114        let v: BTreeSet<_> = V!(self.g).collect();
115        assert!(
116            x.is_subset(&v) && y.is_subset(&v) && z.is_subset(&v),
117            "X, Y and Z must be subsets of V"
118        );
119
120        // Clone current graph.
121        let mut h = self.g.clone();
122
123        // Compute the set of out-going edges of Z.
124        let e_z = z
125            .into_iter()
126            .flat_map(|z| Ne!(self.g, z).map(move |w| (z, w)));
127        // Disconnect vertices in Z from the rest of the graph.
128        for (z, w) in e_z {
129            h.del_edge(z, w);
130        }
131
132        // Initialize union-find.
133        let mut union_find = UnionFind::new(h.order());
134        // Add X to union-find.
135        let root_x = *x.first().unwrap();
136        union_find.extend(x);
137        // Add X to union-find.
138        let root_y = *y.first().unwrap();
139        union_find.extend(y);
140
141        // Compute the connected components of the modified graph.
142        let mut cc = CC::from(&h);
143
144        // Check if there exists no connected component C s.t.
145        //          |C \cap X| > 0 && |C \cap Y| > 0 .
146        !cc.any(|c| {
147            // Add current connected component to union-find.
148            union_find.extend(c);
149            // Check if X and Y are in the same set.
150            union_find.contains(root_x, root_y)
151        })
152    }
153}
154
155/* Implement d-separation */
156impl<'a, G> Independence for GraphicalSeparation<'a, G, directions::Directed>
157where
158    G: DirectedGraph<Direction = directions::Directed> + MoralGraph,
159{
160    #[inline]
161    fn is_independent(&self, x: usize, y: usize, z: &[usize]) -> bool {
162        // TODO: Implement more efficient non-generalized version.
163        <Self as GeneralizedIndependence>::are_independent(self, [x], [y], z.iter().cloned())
164    }
165}
166
167impl<'a, G> GeneralizedIndependence for GraphicalSeparation<'a, G, directions::Directed>
168where
169    G: DirectedGraph<Direction = directions::Directed> + MoralGraph,
170{
171    /// Checks whether $\mathbf{X} \mathrlap{\thinspace\perp}{\perp}_{\mathcal{G}} \mathbf{Y} \mid \mathbf{Z}$ holds or not.
172    fn are_independent<I, J, K>(&self, x: I, y: J, z: K) -> bool
173    where
174        I: IntoIterator<Item = usize>,
175        J: IntoIterator<Item = usize>,
176        K: IntoIterator<Item = usize>,
177    {
178        // Check that X and Y are non-empty.
179        let x: BTreeSet<_> = x.into_iter().collect();
180        let y: BTreeSet<_> = y.into_iter().collect();
181        assert!(!x.is_empty() && !y.is_empty(), "X and Y must be non-empty");
182
183        // Check that X, Y and Z are disjoint, if not panic.
184        let z: BTreeSet<_> = z.into_iter().collect();
185        assert!(
186            x.is_disjoint(&y) && y.is_disjoint(&z) && z.is_disjoint(&x),
187            "X, Y and Z must be disjoint sets"
188        );
189
190        // Compute S = X \cup Y \cup Z.
191        let s = &(&x | &y) | &z;
192
193        // Check that X, Y and Z are in V, if not panic.
194        let v: BTreeSet<_> = V!(self.g).collect();
195        assert!(s.is_subset(&v), "X, Y and Z must be subsets of V");
196
197        // Clone current graph.
198        let mut h = self.g.to_undirected();
199
200        // Compute the ancestors of S.
201        let an_s = s.iter().flat_map(|&s| An!(self.g, s)).collect();
202        // Compute the ancestral set of S.
203        let an_s = &s | &an_s;
204
205        // Compute the set of out-going edges of V \ An_S.
206        let e_s = (&v - &an_s)
207            .into_iter()
208            .flat_map(|s| Adj!(self.g, s).flat_map(move |t| [(s, t), (t, s)]));
209        // Disconnect vertices in V \ S from the rest of the graph, i.e. compute the upward closure.
210        for (s, t) in e_s {
211            h.del_edge(s, t);
212        }
213
214        // Compute the set of out-going edges of Z.
215        let e_z = z
216            .into_iter()
217            .flat_map(|z| Ch!(self.g, z).map(move |w| (z, w)));
218        // Disconnect vertices in Z from the rest of the graph, i.e. compute the moral graph.
219        for (z, w) in e_z {
220            h.del_edge(z, w);
221        }
222
223        // Initialize union-find.
224        let mut union_find = UnionFind::new(h.order());
225        // Add X to union-find.
226        let root_x = *x.first().unwrap();
227        union_find.extend(x);
228        // Add X to union-find.
229        let root_y = *y.first().unwrap();
230        union_find.extend(y);
231
232        // Compute the connected components of the modified graph.
233        let mut cc = CC::from(&h);
234
235        // Check if there exists no connected component C s.t.
236        //          |C \cap X| > 0 && |C \cap Y| > 0 .
237        !cc.any(|c| {
238            // Add current connected component to union-find.
239            union_find.extend(c);
240            // Check if X and Y are in the same set.
241            union_find.contains(root_x, root_y)
242        })
243    }
244}