causal_hub/inference/graphical_separation.rs
1use std::collections::VecDeque;
2
3use crate::{
4 models::{DiGraph, Graph},
5 set,
6 types::Set,
7};
8
9/// A trait for graphical separation.
10pub trait GraphicalSeparation {
11 /// Checks if the `Z` is a separator set for `X` and `Y`.
12 ///
13 /// # Arguments
14 ///
15 /// * `x` - A set of vertices representing set `X`.
16 /// * `y` - A set of vertices representing set `Y`.
17 /// * `z` - A set of vertices representing set `Z`.
18 ///
19 /// # Panics
20 ///
21 /// * If any of the vertex in `X`, `Y`, or `Z` are out of bounds.
22 /// * If `X`, `Y` or `Z` are not disjoint sets.
23 /// * If `X` and `Y` are empty sets.
24 ///
25 /// # Returns
26 ///
27 /// `true` if `X` and `Y` are separated by `Z`, `false` otherwise.
28 ///
29 fn is_separator_set(&self, x: &Set<usize>, y: &Set<usize>, z: &Set<usize>) -> bool;
30
31 /// Checks if the `Z` is a minimal separator set for `X` and `Y`.
32 ///
33 /// # Arguments
34 ///
35 /// * `x` - A set of vertices representing set `X`.
36 /// * `y` - A set of vertices representing set `Y`.
37 /// * `z` - A set of vertices representing set `Z`.
38 /// * `w` - An optional iterable collection of vertices representing set `W`.
39 /// * `v` - An optional iterable collection of vertices representing set `V`.
40 ///
41 /// # Panics
42 ///
43 /// * If any of the vertex in `X`, `Y`, `Z`, `W` or `V` are out of bounds.
44 /// * If `X`, `Y` or `Z` are not disjoint sets.
45 /// * If `X` and `Y` are empty sets.
46 /// * If not `W` <= `Z` <= `V`.
47 ///
48 /// # Returns
49 ///
50 /// `true` if `Z` is a minimal separator set for `X` and `Y`, `false` otherwise.
51 ///
52 fn is_minimal_separator_set(
53 &self,
54 x: &Set<usize>,
55 y: &Set<usize>,
56 z: &Set<usize>,
57 w: Option<&Set<usize>>,
58 v: Option<&Set<usize>>,
59 ) -> bool;
60
61 /// Finds a minimal separator set for the vertex sets `X` and `Y`, if any.
62 ///
63 /// # Arguments
64 ///
65 /// * `x` - A set of vertices representing set `X`.
66 /// * `y` - A set of vertices representing set `Y`.
67 ///
68 /// # Panics
69 ///
70 /// * If any of the vertex in `X`, `Y`, `W` or `V` are out of bounds.
71 /// * If `X` and `Y` are not disjoint sets.
72 /// * If `X` or `Y` are empty sets.
73 /// * If not `W` <= `V`.
74 ///
75 /// # Returns
76 ///
77 /// `Some(Set)` containing the minimal separator set, or `None` if no separator set exists.
78 ///
79 fn find_minimal_separator_set(
80 &self,
81 x: &Set<usize>,
82 y: &Set<usize>,
83 w: Option<&Set<usize>>,
84 v: Option<&Set<usize>>,
85 ) -> Option<Set<usize>>;
86}
87
88// Implementation of the `GraphicalSeparation` trait for directed graphs.
89pub(crate) mod digraph {
90 use super::*;
91 use crate::inference::TopologicalOrder;
92
93 /// Asserts the validity of the sets and returns them as `Set<usize>`.
94 pub(crate) fn _assert(
95 g: &DiGraph,
96 x: &Set<usize>,
97 y: &Set<usize>,
98 z: Option<&Set<usize>>,
99 w: Option<&Set<usize>>,
100 v: Option<&Set<usize>>,
101 ) {
102 // Assert the included set is a subset of the restricted set.
103 if let (Some(w), Some(v)) = (w.as_ref(), v.as_ref()) {
104 assert!(w.is_subset(v), "Set W must be a subset of set V.");
105 }
106
107 // Convert X to set, while checking for out of bounds.
108 for &x in x {
109 assert!(g.has_vertex(x), "Vertex `{x}` in set X is out of bounds.");
110 }
111 // Convert Y to set, while checking for out of bounds.
112 for &y in y {
113 assert!(g.has_vertex(y), "Vertex `{y}` in set Y is out of bounds.");
114 }
115 // Convert Z to set, while checking for out of bounds.
116 if let Some(z) = z {
117 for &z in z {
118 assert!(g.has_vertex(z), "Vertex `{z}` in set Z is out of bounds.");
119 }
120 }
121
122 // Assert X is non-empty.
123 assert!(!x.is_empty(), "Set X must not be empty.");
124 // Assert Y is non-empty.
125 assert!(!y.is_empty(), "Set Y must not be empty.");
126
127 // Assert X and Y are disjoint.
128 assert!(x.is_disjoint(y), "Sets X and Y must be disjoint.");
129
130 // If Z is provided, convert it to a set.
131 if let Some(z) = &z {
132 // Assert X and Z are disjoint.
133 assert!(x.is_disjoint(z), "Sets X and Z must be disjoint.");
134 // Assert Y and Z are disjoint.
135 assert!(y.is_disjoint(z), "Sets Y and Z must be disjoint.");
136 // Assert Z includes.
137 if let Some(w) = w {
138 assert!(z.is_superset(w), "Set Z must be a superset of set W.");
139 }
140 // Assert Z is restricted.
141 if let Some(v) = v {
142 assert!(z.is_subset(v), "Set Z must be a subset of set V.");
143 }
144 }
145 }
146
147 fn _reachable(g: &DiGraph, x: &Set<usize>, an_x: &Set<usize>, z: &Set<usize>) -> Set<usize> {
148 // Assert the graph is a DAG.
149 assert!(g.topological_order().is_some(), "Graph must be a DAG.");
150
151 // Check if the ball passes or not.
152 let _pass = |e: bool, v: usize, f: bool, n: usize| {
153 let is_element_of_a = an_x.contains(&n);
154 let almost_definite_status = true; // NOTE: Always true for DAGs, not so for RCGs.
155 let collider_if_in_z = !z.contains(&v) || (e && !f);
156 // If the edge is forward, the vertex must be an ancestor or in Z.
157 is_element_of_a && collider_if_in_z && almost_definite_status
158 };
159
160 // Initialize the queue.
161 let mut queue: VecDeque<(bool, usize)> = Default::default();
162 // For each vertex in X ...
163 for &w in x {
164 // If the vertex has predecessors, add it to the queue as a backward edge.
165 if !g.parents(&set![w]).is_empty() {
166 queue.push_back((false, w));
167 }
168 // If the vertex has successors, add it to the queue as a forward edge.
169 if !g.children(&set![w]).is_empty() {
170 queue.push_back((true, w));
171 }
172 }
173
174 // Initialize the processed set with the queue.
175 let mut visited = queue.clone();
176
177 // For each element in the queue ...
178 while let Some((e, v)) = queue.pop_front() {
179 // Get the predecessors and successors of the vertex.
180 let pa_v = g.parents(&set![v]).into_iter().map(|n| (false, n));
181 let ch_v = g.children(&set![v]).into_iter().map(|n| (true, n));
182
183 // Create pairs of (forward, vertex) for predecessors and successors.
184 let f_n_pairs = pa_v.chain(ch_v);
185
186 // For each pair ...
187 for (f, n) in f_n_pairs {
188 // If the pair has not been processed and passes the condition ...
189 if !visited.contains(&(f, n)) && _pass(e, v, f, n) {
190 // Add it to the queue and mark it as processed.
191 queue.push_back((f, n));
192 visited.push_back((f, n));
193 }
194 }
195 }
196
197 // Return the set of visited vertices.
198 visited.into_iter().map(|(_, w)| w).collect()
199 }
200
201 impl GraphicalSeparation for DiGraph {
202 fn is_separator_set(&self, x: &Set<usize>, y: &Set<usize>, z: &Set<usize>) -> bool {
203 // Perform sanity checks and convert sets.
204 _assert(self, x, y, Some(z), None::<&Set<_>>, None::<&Set<_>>);
205
206 // Initialize the forward and backward deques and visited sets.
207
208 // Contains -> and <-> edges from starting vertex.
209 let mut forward_deque: VecDeque<usize> = Default::default();
210 let mut forward_visited: Set<usize> = set![];
211 // Contains <- and - edges from starting vertex.
212 let mut backward_deque: VecDeque<usize> = Default::default();
213 let mut backward_visited: Set<usize> = set![];
214
215 // Initialize the backward deque with the vertices in X.
216 backward_deque.extend(x.iter().cloned());
217
218 // Compute the ancestors of X and Z.
219 let ancestors_or_z = &self.ancestors(z) | &(z | x);
220
221 // While there are vertices to visit in the forward or backward deques ...
222 while !forward_deque.is_empty() || !backward_deque.is_empty() {
223 // If there are vertices in the backward deque ...
224 if let Some(w) = backward_deque.pop_front() {
225 // Mark the W as visited.
226 backward_visited.insert(w);
227 // If the W is in Y, return false (not separated).
228 if y.contains(&w) {
229 return false;
230 }
231 // If the W is in Z, continue to the next iteration.
232 if z.contains(&w) {
233 continue;
234 }
235 // Add all predecessors of the W to the backward deque.
236 for pred in self.parents(&set![w]) {
237 if !backward_visited.contains(&pred) {
238 backward_deque.push_back(pred);
239 }
240 }
241 // Add all successors of the W to the forward deque.
242 for succ in self.children(&set![w]) {
243 if !forward_visited.contains(&succ) {
244 forward_deque.push_back(succ);
245 }
246 }
247 }
248
249 // If there are vertices in the forward deque ...
250 if let Some(w) = forward_deque.pop_front() {
251 // Mark the W as visited.
252 forward_visited.insert(w);
253 // If the W is in Y, return false (not separated).
254 if y.contains(&w) {
255 return false;
256 }
257 // If the W is an ancestor or in Z, add its predecessors to the backward deque.
258 if ancestors_or_z.contains(&w) {
259 for pred in self.parents(&set![w]) {
260 if !backward_visited.contains(&pred) {
261 backward_deque.push_back(pred);
262 }
263 }
264 }
265 // If the W is not in Z, add its successors to the forward deque.
266 if !z.contains(&w) {
267 for succ in self.children(&set![w]) {
268 if !forward_visited.contains(&succ) {
269 forward_deque.push_back(succ);
270 }
271 }
272 }
273 }
274 }
275
276 // Otherwise, return true.
277 true
278 }
279
280 fn is_minimal_separator_set(
281 &self,
282 x: &Set<usize>,
283 y: &Set<usize>,
284 z: &Set<usize>,
285 w: Option<&Set<usize>>,
286 v: Option<&Set<usize>>,
287 ) -> bool {
288 // Perform sanity checks and convert sets.
289 _assert(self, x, y, Some(z), w, v);
290
291 // Set default values for W if not provided.
292 let w = match w {
293 Some(w) => w,
294 None => &set![],
295 };
296
297 // Compute the ancestors of X and Y.
298 let x_y_w = &(x | y) | w;
299 let an_x_y_w = &self.ancestors(&x_y_w) | &x_y_w;
300
301 // a) Check that Z is a separator.
302 let x_closure = _reachable(self, x, &an_x_y_w, z);
303 if !x_closure.is_disjoint(y) {
304 return false;
305 }
306
307 // b) Check that Z is constrained to An(X, Y).
308 if !z.is_subset(&an_x_y_w) {
309 return false;
310 }
311
312 // c) Check that Z is minimal.
313 let y_closure = _reachable(self, y, &an_x_y_w, z);
314 if !((z - w).is_subset(&(&x_closure & &y_closure))) {
315 return false;
316 }
317
318 // Otherwise, return true.
319 true
320 }
321
322 fn find_minimal_separator_set(
323 &self,
324 x: &Set<usize>,
325 y: &Set<usize>,
326 w: Option<&Set<usize>>,
327 v: Option<&Set<usize>>,
328 ) -> Option<Set<usize>> {
329 // Perform sanity checks and convert sets.
330 _assert(self, x, y, None::<&Set<_>>, w, v);
331
332 // Set default values for W and V if not provided.
333 let w = match w {
334 Some(w) => w,
335 None => &set![],
336 };
337 let v = match v {
338 Some(v) => v,
339 None => &self.vertices(),
340 };
341
342 // Compute the ancestors of X and Y.
343 let x_y_w = &(x | y) | w;
344 let an_x_y_w = &self.ancestors(&x_y_w) | &x_y_w;
345
346 // Initialize the restricted set with the intersection of X, Y, and included.
347 let z = v & &(&an_x_y_w - &(x | y));
348
349 // Check if Z is a separator.
350 let x_closure = _reachable(self, x, &an_x_y_w, &z);
351 if !x_closure.is_disjoint(y) {
352 return None; // No minimal separator exists.
353 }
354
355 // Update Z.
356 let z = &z & &(&x_closure | w);
357
358 // Check if Z is a separator.
359 let y_closure = _reachable(self, y, &an_x_y_w, &z);
360
361 // Return the minimal separator.
362 Some(&z & &(&y_closure | w))
363 }
364 }
365}