// causal_hub/inference/graphical_separation.rs

use std::collections::VecDeque;

use crate::{
    models::{DiGraph, Graph},
    set,
    types::Set,
};

9/// A trait for graphical separation.
10pub trait GraphicalSeparation {
11    /// Checks if the `Z` is a separator set for `X` and `Y`.
12    ///
13    /// # Arguments
14    ///
15    /// * `x` - A set of vertices representing set `X`.
16    /// * `y` - A set of vertices representing set `Y`.
17    /// * `z` - A set of vertices representing set `Z`.
18    ///
19    /// # Panics
20    ///
21    /// * If any of the vertex in `X`, `Y`, or `Z` are out of bounds.
22    /// * If `X`, `Y` or `Z` are not disjoint sets.
23    /// * If `X` and `Y` are empty sets.
24    ///
25    /// # Returns
26    ///
27    /// `true` if `X` and `Y` are separated by `Z`, `false` otherwise.
28    ///
29    fn is_separator_set(&self, x: &Set<usize>, y: &Set<usize>, z: &Set<usize>) -> bool;
30
31    /// Checks if the `Z` is a minimal separator set for `X` and `Y`.
32    ///
33    /// # Arguments
34    ///
35    /// * `x` - A set of vertices representing set `X`.
36    /// * `y` - A set of vertices representing set `Y`.
37    /// * `z` - A set of vertices representing set `Z`.
38    /// * `w` - An optional iterable collection of vertices representing set `W`.
39    /// * `v` - An optional iterable collection of vertices representing set `V`.
40    ///
41    /// # Panics
42    ///
43    /// * If any of the vertex in `X`, `Y`, `Z`, `W` or `V` are out of bounds.
44    /// * If `X`, `Y` or `Z` are not disjoint sets.
45    /// * If `X` and `Y` are empty sets.
46    /// * If not `W` <= `Z` <= `V`.
47    ///
48    /// # Returns
49    ///
50    /// `true` if `Z` is a minimal separator set for `X` and `Y`, `false` otherwise.
51    ///
52    fn is_minimal_separator_set(
53        &self,
54        x: &Set<usize>,
55        y: &Set<usize>,
56        z: &Set<usize>,
57        w: Option<&Set<usize>>,
58        v: Option<&Set<usize>>,
59    ) -> bool;
60
61    /// Finds a minimal separator set for the vertex sets `X` and `Y`, if any.
62    ///
63    /// # Arguments
64    ///
65    /// * `x` - A set of vertices representing set `X`.
66    /// * `y` - A set of vertices representing set `Y`.
67    ///
68    /// # Panics
69    ///
70    /// * If any of the vertex in `X`, `Y`, `W` or `V` are out of bounds.
71    /// * If `X` and `Y` are not disjoint sets.
72    /// * If `X` or `Y` are empty sets.
73    /// * If not `W` <= `V`.
74    ///
75    /// # Returns
76    ///
77    /// `Some(Set)` containing the minimal separator set, or `None` if no separator set exists.
78    ///
79    fn find_minimal_separator_set(
80        &self,
81        x: &Set<usize>,
82        y: &Set<usize>,
83        w: Option<&Set<usize>>,
84        v: Option<&Set<usize>>,
85    ) -> Option<Set<usize>>;
86}

// Implementation of the `GraphicalSeparation` trait for directed graphs.
89pub(crate) mod digraph {
90    use super::*;
91    use crate::inference::TopologicalOrder;
92
93    /// Asserts the validity of the sets and returns them as `Set<usize>`.
94    pub(crate) fn _assert(
95        g: &DiGraph,
96        x: &Set<usize>,
97        y: &Set<usize>,
98        z: Option<&Set<usize>>,
99        w: Option<&Set<usize>>,
100        v: Option<&Set<usize>>,
101    ) {
102        // Assert the included set is a subset of the restricted set.
103        if let (Some(w), Some(v)) = (w.as_ref(), v.as_ref()) {
104            assert!(w.is_subset(v), "Set W must be a subset of set V.");
105        }
106
107        // Convert X to set, while checking for out of bounds.
108        for &x in x {
109            assert!(g.has_vertex(x), "Vertex `{x}` in set X is out of bounds.");
110        }
111        // Convert Y to set, while checking for out of bounds.
112        for &y in y {
113            assert!(g.has_vertex(y), "Vertex `{y}` in set Y is out of bounds.");
114        }
115        // Convert Z to set, while checking for out of bounds.
116        if let Some(z) = z {
117            for &z in z {
118                assert!(g.has_vertex(z), "Vertex `{z}` in set Z is out of bounds.");
119            }
120        }
121
122        // Assert X is non-empty.
123        assert!(!x.is_empty(), "Set X must not be empty.");
124        // Assert Y is non-empty.
125        assert!(!y.is_empty(), "Set Y must not be empty.");
126
127        // Assert X and Y are disjoint.
128        assert!(x.is_disjoint(y), "Sets X and Y must be disjoint.");
129
130        // If Z is provided, convert it to a set.
131        if let Some(z) = &z {
132            // Assert X and Z are disjoint.
133            assert!(x.is_disjoint(z), "Sets X and Z must be disjoint.");
134            // Assert Y and Z are disjoint.
135            assert!(y.is_disjoint(z), "Sets Y and Z must be disjoint.");
136            // Assert Z includes.
137            if let Some(w) = w {
138                assert!(z.is_superset(w), "Set Z must be a superset of set W.");
139            }
140            // Assert Z is restricted.
141            if let Some(v) = v {
142                assert!(z.is_subset(v), "Set Z must be a subset of set V.");
143            }
144        }
145    }
146
147    fn _reachable(g: &DiGraph, x: &Set<usize>, an_x: &Set<usize>, z: &Set<usize>) -> Set<usize> {
148        // Assert the graph is a DAG.
149        assert!(g.topological_order().is_some(), "Graph must be a DAG.");
150
151        // Check if the ball passes or not.
152        let _pass = |e: bool, v: usize, f: bool, n: usize| {
153            let is_element_of_a = an_x.contains(&n);
154            let almost_definite_status = true; // NOTE: Always true for DAGs, not so for RCGs.
155            let collider_if_in_z = !z.contains(&v) || (e && !f);
156            // If the edge is forward, the vertex must be an ancestor or in Z.
157            is_element_of_a && collider_if_in_z && almost_definite_status
158        };
159
160        // Initialize the queue.
161        let mut queue: VecDeque<(bool, usize)> = Default::default();
162        // For each vertex in X ...
163        for &w in x {
164            // If the vertex has predecessors, add it to the queue as a backward edge.
165            if !g.parents(&set![w]).is_empty() {
166                queue.push_back((false, w));
167            }
168            // If the vertex has successors, add it to the queue as a forward edge.
169            if !g.children(&set![w]).is_empty() {
170                queue.push_back((true, w));
171            }
172        }
173
174        // Initialize the processed set with the queue.
175        let mut visited = queue.clone();
176
177        // For each element in the queue ...
178        while let Some((e, v)) = queue.pop_front() {
179            // Get the predecessors and successors of the vertex.
180            let pa_v = g.parents(&set![v]).into_iter().map(|n| (false, n));
181            let ch_v = g.children(&set![v]).into_iter().map(|n| (true, n));
182
183            // Create pairs of (forward, vertex) for predecessors and successors.
184            let f_n_pairs = pa_v.chain(ch_v);
185
186            // For each pair ...
187            for (f, n) in f_n_pairs {
188                // If the pair has not been processed and passes the condition ...
189                if !visited.contains(&(f, n)) && _pass(e, v, f, n) {
190                    // Add it to the queue and mark it as processed.
191                    queue.push_back((f, n));
192                    visited.push_back((f, n));
193                }
194            }
195        }
196
197        // Return the set of visited vertices.
198        visited.into_iter().map(|(_, w)| w).collect()
199    }
200
201    impl GraphicalSeparation for DiGraph {
202        fn is_separator_set(&self, x: &Set<usize>, y: &Set<usize>, z: &Set<usize>) -> bool {
203            // Perform sanity checks and convert sets.
204            _assert(self, x, y, Some(z), None::<&Set<_>>, None::<&Set<_>>);
205
206            // Initialize the forward and backward deques and visited sets.
207
208            // Contains -> and <-> edges from starting vertex.
209            let mut forward_deque: VecDeque<usize> = Default::default();
210            let mut forward_visited: Set<usize> = set![];
211            // Contains <- and - edges from starting vertex.
212            let mut backward_deque: VecDeque<usize> = Default::default();
213            let mut backward_visited: Set<usize> = set![];
214
215            // Initialize the backward deque with the vertices in X.
216            backward_deque.extend(x.iter().cloned());
217
218            // Compute the ancestors of X and Z.
219            let ancestors_or_z = &self.ancestors(z) | &(z | x);
220
221            // While there are vertices to visit in the forward or backward deques ...
222            while !forward_deque.is_empty() || !backward_deque.is_empty() {
223                // If there are vertices in the backward deque ...
224                if let Some(w) = backward_deque.pop_front() {
225                    // Mark the W as visited.
226                    backward_visited.insert(w);
227                    // If the W is in Y, return false (not separated).
228                    if y.contains(&w) {
229                        return false;
230                    }
231                    // If the W is in Z, continue to the next iteration.
232                    if z.contains(&w) {
233                        continue;
234                    }
235                    // Add all predecessors of the W to the backward deque.
236                    for pred in self.parents(&set![w]) {
237                        if !backward_visited.contains(&pred) {
238                            backward_deque.push_back(pred);
239                        }
240                    }
241                    // Add all successors of the W to the forward deque.
242                    for succ in self.children(&set![w]) {
243                        if !forward_visited.contains(&succ) {
244                            forward_deque.push_back(succ);
245                        }
246                    }
247                }
248
249                // If there are vertices in the forward deque ...
250                if let Some(w) = forward_deque.pop_front() {
251                    // Mark the W as visited.
252                    forward_visited.insert(w);
253                    // If the W is in Y, return false (not separated).
254                    if y.contains(&w) {
255                        return false;
256                    }
257                    // If the W is an ancestor or in Z, add its predecessors to the backward deque.
258                    if ancestors_or_z.contains(&w) {
259                        for pred in self.parents(&set![w]) {
260                            if !backward_visited.contains(&pred) {
261                                backward_deque.push_back(pred);
262                            }
263                        }
264                    }
265                    // If the W is not in Z, add its successors to the forward deque.
266                    if !z.contains(&w) {
267                        for succ in self.children(&set![w]) {
268                            if !forward_visited.contains(&succ) {
269                                forward_deque.push_back(succ);
270                            }
271                        }
272                    }
273                }
274            }
275
276            // Otherwise, return true.
277            true
278        }
279
280        fn is_minimal_separator_set(
281            &self,
282            x: &Set<usize>,
283            y: &Set<usize>,
284            z: &Set<usize>,
285            w: Option<&Set<usize>>,
286            v: Option<&Set<usize>>,
287        ) -> bool {
288            // Perform sanity checks and convert sets.
289            _assert(self, x, y, Some(z), w, v);
290
291            // Set default values for W if not provided.
292            let w = match w {
293                Some(w) => w,
294                None => &set![],
295            };
296
297            // Compute the ancestors of X and Y.
298            let x_y_w = &(x | y) | w;
299            let an_x_y_w = &self.ancestors(&x_y_w) | &x_y_w;
300
301            // a) Check that Z is a separator.
302            let x_closure = _reachable(self, x, &an_x_y_w, z);
303            if !x_closure.is_disjoint(y) {
304                return false;
305            }
306
307            // b) Check that Z is constrained to An(X, Y).
308            if !z.is_subset(&an_x_y_w) {
309                return false;
310            }
311
312            // c) Check that Z is minimal.
313            let y_closure = _reachable(self, y, &an_x_y_w, z);
314            if !((z - w).is_subset(&(&x_closure & &y_closure))) {
315                return false;
316            }
317
318            // Otherwise, return true.
319            true
320        }
321
322        fn find_minimal_separator_set(
323            &self,
324            x: &Set<usize>,
325            y: &Set<usize>,
326            w: Option<&Set<usize>>,
327            v: Option<&Set<usize>>,
328        ) -> Option<Set<usize>> {
329            // Perform sanity checks and convert sets.
330            _assert(self, x, y, None::<&Set<_>>, w, v);
331
332            // Set default values for W and V if not provided.
333            let w = match w {
334                Some(w) => w,
335                None => &set![],
336            };
337            let v = match v {
338                Some(v) => v,
339                None => &self.vertices(),
340            };
341
342            // Compute the ancestors of X and Y.
343            let x_y_w = &(x | y) | w;
344            let an_x_y_w = &self.ancestors(&x_y_w) | &x_y_w;
345
346            // Initialize the restricted set with the intersection of X, Y, and included.
347            let z = v & &(&an_x_y_w - &(x | y));
348
349            // Check if Z is a separator.
350            let x_closure = _reachable(self, x, &an_x_y_w, &z);
351            if !x_closure.is_disjoint(y) {
352                return None; // No minimal separator exists.
353            }
354
355            // Update Z.
356            let z = &z & &(&x_closure | w);
357
358            // Check if Z is a separator.
359            let y_closure = _reachable(self, y, &an_x_y_w, &z);
360
361            // Return the minimal separator.
362            Some(&z & &(&y_closure | w))
363        }
364    }
365}