1use std::{collections::BTreeSet, fmt::Debug};
2
3use super::{GeneralizedIndependence, Independence, MoralGraph};
4use crate::{
5 graphs::directions,
6 prelude::{BaseGraph, DirectedGraph, UndirectedGraph, CC},
7 utils::UnionFind,
8 Adj, An, Ch, Ne, V,
9};
10
11#[derive(Clone, Debug)]
13pub struct GraphicalSeparation<'a, G, D>
14where
15 G: BaseGraph<Direction = D>,
16{
17 g: &'a G,
18}
19
20impl<'a, G, D> GraphicalSeparation<'a, G, D>
21where
22 G: BaseGraph<Direction = D>,
23{
24 #[inline]
63 pub const fn new(g: &'a G) -> Self {
64 Self { g }
65 }
66}
67
68impl<'a, G, D> From<&'a G> for GraphicalSeparation<'a, G, D>
69where
70 G: BaseGraph<Direction = D>,
71{
72 #[inline]
73 fn from(g: &'a G) -> Self {
74 Self::new(g)
75 }
76}
77
78impl<'a, G> Independence for GraphicalSeparation<'a, G, directions::Undirected>
80where
81 G: UndirectedGraph<Direction = directions::Undirected>,
82{
83 #[inline]
84 fn is_independent(&self, x: usize, y: usize, z: &[usize]) -> bool {
85 <Self as GeneralizedIndependence>::are_independent(self, [x], [y], z.iter().cloned())
87 }
88}
89
90impl<'a, G> GeneralizedIndependence for GraphicalSeparation<'a, G, directions::Undirected>
91where
92 G: UndirectedGraph<Direction = directions::Undirected>,
93{
94 fn are_independent<I, J, K>(&self, x: I, y: J, z: K) -> bool
96 where
97 I: IntoIterator<Item = usize>,
98 J: IntoIterator<Item = usize>,
99 K: IntoIterator<Item = usize>,
100 {
101 let x: BTreeSet<_> = x.into_iter().collect();
103 let y: BTreeSet<_> = y.into_iter().collect();
104 assert!(!x.is_empty() && !y.is_empty(), "X and Y must be non-empty");
105
106 let z: BTreeSet<_> = z.into_iter().collect();
108 assert!(
109 x.is_disjoint(&y) && y.is_disjoint(&z) && z.is_disjoint(&x),
110 "X, Y and Z must be disjoint sets"
111 );
112
113 let v: BTreeSet<_> = V!(self.g).collect();
115 assert!(
116 x.is_subset(&v) && y.is_subset(&v) && z.is_subset(&v),
117 "X, Y and Z must be subsets of V"
118 );
119
120 let mut h = self.g.clone();
122
123 let e_z = z
125 .into_iter()
126 .flat_map(|z| Ne!(self.g, z).map(move |w| (z, w)));
127 for (z, w) in e_z {
129 h.del_edge(z, w);
130 }
131
132 let mut union_find = UnionFind::new(h.order());
134 let root_x = *x.first().unwrap();
136 union_find.extend(x);
137 let root_y = *y.first().unwrap();
139 union_find.extend(y);
140
141 let mut cc = CC::from(&h);
143
144 !cc.any(|c| {
147 union_find.extend(c);
149 union_find.contains(root_x, root_y)
151 })
152 }
153}
154
155impl<'a, G> Independence for GraphicalSeparation<'a, G, directions::Directed>
157where
158 G: DirectedGraph<Direction = directions::Directed> + MoralGraph,
159{
160 #[inline]
161 fn is_independent(&self, x: usize, y: usize, z: &[usize]) -> bool {
162 <Self as GeneralizedIndependence>::are_independent(self, [x], [y], z.iter().cloned())
164 }
165}
166
167impl<'a, G> GeneralizedIndependence for GraphicalSeparation<'a, G, directions::Directed>
168where
169 G: DirectedGraph<Direction = directions::Directed> + MoralGraph,
170{
171 fn are_independent<I, J, K>(&self, x: I, y: J, z: K) -> bool
173 where
174 I: IntoIterator<Item = usize>,
175 J: IntoIterator<Item = usize>,
176 K: IntoIterator<Item = usize>,
177 {
178 let x: BTreeSet<_> = x.into_iter().collect();
180 let y: BTreeSet<_> = y.into_iter().collect();
181 assert!(!x.is_empty() && !y.is_empty(), "X and Y must be non-empty");
182
183 let z: BTreeSet<_> = z.into_iter().collect();
185 assert!(
186 x.is_disjoint(&y) && y.is_disjoint(&z) && z.is_disjoint(&x),
187 "X, Y and Z must be disjoint sets"
188 );
189
190 let s = &(&x | &y) | &z;
192
193 let v: BTreeSet<_> = V!(self.g).collect();
195 assert!(s.is_subset(&v), "X, Y and Z must be subsets of V");
196
197 let mut h = self.g.to_undirected();
199
200 let an_s = s.iter().flat_map(|&s| An!(self.g, s)).collect();
202 let an_s = &s | &an_s;
204
205 let e_s = (&v - &an_s)
207 .into_iter()
208 .flat_map(|s| Adj!(self.g, s).flat_map(move |t| [(s, t), (t, s)]));
209 for (s, t) in e_s {
211 h.del_edge(s, t);
212 }
213
214 let e_z = z
216 .into_iter()
217 .flat_map(|z| Ch!(self.g, z).map(move |w| (z, w)));
218 for (z, w) in e_z {
220 h.del_edge(z, w);
221 }
222
223 let mut union_find = UnionFind::new(h.order());
225 let root_x = *x.first().unwrap();
227 union_find.extend(x);
228 let root_y = *y.first().unwrap();
230 union_find.extend(y);
231
232 let mut cc = CC::from(&h);
234
235 !cc.any(|c| {
238 union_find.extend(c);
240 union_find.contains(root_x, root_y)
242 })
243 }
244}