oximedia_graph/
subgraph.rs1#![allow(dead_code)]
3
4use std::collections::{HashMap, HashSet};
5
/// The nodes where a [`Subgraph`] meets the rest of the graph.
///
/// When produced by [`Subgraph::boundary`], `input_nodes` holds nodes *outside* the
/// subgraph that have an edge into it, and `output_nodes` holds nodes *inside* the
/// subgraph that have an edge leaving it.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SubgraphBoundary {
    /// Nodes that feed edges into the subgraph.
    pub input_nodes: Vec<usize>,
    /// Subgraph nodes that feed edges out of the subgraph.
    pub output_nodes: Vec<usize>,
}

impl SubgraphBoundary {
    /// Number of entries in `input_nodes`.
    #[must_use]
    pub fn input_count(&self) -> usize {
        self.input_nodes.len()
    }

    /// Number of entries in `output_nodes`.
    #[must_use]
    pub fn output_count(&self) -> usize {
        self.output_nodes.len()
    }

    /// Returns `true` when there is at least one input node *and* at least one output node.
    #[must_use]
    pub fn is_connected(&self) -> bool {
        !(self.input_nodes.is_empty() || self.output_nodes.is_empty())
    }
}
39
40#[derive(Debug, Clone)]
43pub struct Subgraph {
44 nodes: HashSet<usize>,
46 internal_edges: Vec<(usize, usize)>,
48 label: String,
50}
51
52impl Subgraph {
53 #[must_use]
55 pub fn new(label: impl Into<String>) -> Self {
56 Self {
57 nodes: HashSet::new(),
58 internal_edges: Vec::new(),
59 label: label.into(),
60 }
61 }
62
63 pub fn add_node(&mut self, node_idx: usize) {
65 self.nodes.insert(node_idx);
66 }
67
68 pub fn add_internal_edge(&mut self, from: usize, to: usize) -> Result<(), String> {
71 if !self.nodes.contains(&from) {
72 return Err(format!("Node {from} is not part of this subgraph"));
73 }
74 if !self.nodes.contains(&to) {
75 return Err(format!("Node {to} is not part of this subgraph"));
76 }
77 self.internal_edges.push((from, to));
78 Ok(())
79 }
80
81 #[must_use]
83 pub fn node_count(&self) -> usize {
84 self.nodes.len()
85 }
86
87 #[must_use]
89 pub fn edge_count(&self) -> usize {
90 self.internal_edges.len()
91 }
92
93 #[must_use]
95 pub fn contains_node(&self, node_idx: usize) -> bool {
96 self.nodes.contains(&node_idx)
97 }
98
99 #[must_use]
101 pub fn label(&self) -> &str {
102 &self.label
103 }
104
105 #[must_use]
107 pub fn internal_edges(&self) -> &[(usize, usize)] {
108 &self.internal_edges
109 }
110
111 #[must_use]
120 pub fn boundary(&self, all_edges: &[(usize, usize)]) -> SubgraphBoundary {
121 let mut inputs: HashSet<usize> = HashSet::new();
122 let mut outputs: HashSet<usize> = HashSet::new();
123
124 for &(from, to) in all_edges {
125 if !self.nodes.contains(&from) && self.nodes.contains(&to) {
126 inputs.insert(from);
127 }
128 if self.nodes.contains(&from) && !self.nodes.contains(&to) {
129 outputs.insert(from);
130 }
131 }
132
133 SubgraphBoundary {
134 input_nodes: inputs.into_iter().collect(),
135 output_nodes: outputs.into_iter().collect(),
136 }
137 }
138}
139
140pub struct SubgraphExtractor {
143 all_edges: Vec<(usize, usize)>,
145 node_count: usize,
147}
148
149impl SubgraphExtractor {
150 #[must_use]
152 pub fn new(node_count: usize, edges: Vec<(usize, usize)>) -> Self {
153 Self {
154 all_edges: edges,
155 node_count,
156 }
157 }
158
159 pub fn extract(&self, seeds: &[usize], label: impl Into<String>) -> Result<Subgraph, String> {
164 for &s in seeds {
165 if s >= self.node_count {
166 return Err(format!(
167 "Seed node {s} is out of range (node_count={})",
168 self.node_count
169 ));
170 }
171 }
172
173 let mut adj: HashMap<usize, Vec<usize>> = HashMap::new();
175 for &(f, t) in &self.all_edges {
176 adj.entry(f).or_default().push(t);
177 }
178
179 let mut visited: HashSet<usize> = HashSet::new();
180 let mut stack: Vec<usize> = seeds.to_vec();
181
182 while let Some(node) = stack.pop() {
183 if visited.insert(node) {
184 if let Some(neighbours) = adj.get(&node) {
185 for &nb in neighbours {
186 if !visited.contains(&nb) {
187 stack.push(nb);
188 }
189 }
190 }
191 }
192 }
193
194 let mut sg = Subgraph::new(label);
195 for n in &visited {
196 sg.add_node(*n);
197 }
198 for &(f, t) in &self.all_edges {
200 if visited.contains(&f) && visited.contains(&t) {
201 sg.add_internal_edge(f, t).ok();
202 }
203 }
204 Ok(sg)
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Diamond-shaped graph: 0 -> {1, 2} -> 3 -> 4.
    fn sample_extractor() -> SubgraphExtractor {
        SubgraphExtractor::new(5, vec![(0, 1), (0, 2), (1, 3), (2, 3), (3, 4)])
    }

    // ----- SubgraphBoundary -----

    #[test]
    fn test_boundary_input_count() {
        let b = SubgraphBoundary {
            input_nodes: vec![0, 1],
            output_nodes: vec![5],
        };
        assert_eq!(b.input_count(), 2);
    }

    #[test]
    fn test_boundary_output_count() {
        let b = SubgraphBoundary {
            input_nodes: vec![0],
            output_nodes: vec![5, 6],
        };
        assert_eq!(b.output_count(), 2);
    }

    #[test]
    fn test_boundary_is_connected() {
        let b = SubgraphBoundary {
            input_nodes: vec![0],
            output_nodes: vec![3],
        };
        assert!(b.is_connected());
    }

    #[test]
    fn test_boundary_not_connected_no_inputs() {
        let b = SubgraphBoundary {
            input_nodes: vec![],
            output_nodes: vec![3],
        };
        assert!(!b.is_connected());
    }

    #[test]
    fn test_boundary_not_connected_no_outputs() {
        let b = SubgraphBoundary {
            input_nodes: vec![0],
            output_nodes: vec![],
        };
        assert!(!b.is_connected());
    }

    #[test]
    fn test_boundary_default_is_empty() {
        let b = SubgraphBoundary::default();
        assert_eq!(b.input_count(), 0);
        assert_eq!(b.output_count(), 0);
        assert!(!b.is_connected());
    }

    // ----- Subgraph -----

    #[test]
    fn test_add_node_increments_count() {
        let mut sg = Subgraph::new("test");
        sg.add_node(0);
        sg.add_node(1);
        assert_eq!(sg.node_count(), 2);
    }

    #[test]
    fn test_add_node_is_idempotent() {
        let mut sg = Subgraph::new("test");
        sg.add_node(7);
        sg.add_node(7);
        assert_eq!(sg.node_count(), 1);
    }

    #[test]
    fn test_add_internal_edge_valid() {
        let mut sg = Subgraph::new("test");
        sg.add_node(0);
        sg.add_node(1);
        assert!(sg.add_internal_edge(0, 1).is_ok());
        assert_eq!(sg.edge_count(), 1);
    }

    #[test]
    fn test_add_internal_edge_missing_node_returns_error() {
        let mut sg = Subgraph::new("test");
        sg.add_node(0);
        assert!(sg.add_internal_edge(0, 99).is_err());
    }

    #[test]
    fn test_add_internal_edge_missing_source_returns_error() {
        let mut sg = Subgraph::new("test");
        sg.add_node(1);
        assert!(sg.add_internal_edge(99, 1).is_err());
        assert_eq!(sg.edge_count(), 0);
    }

    #[test]
    fn test_contains_node_true() {
        let mut sg = Subgraph::new("test");
        sg.add_node(5);
        assert!(sg.contains_node(5));
    }

    #[test]
    fn test_contains_node_false() {
        let sg = Subgraph::new("test");
        assert!(!sg.contains_node(0));
    }

    #[test]
    fn test_label_stored() {
        let sg = Subgraph::new("my_sub");
        assert_eq!(sg.label(), "my_sub");
    }

    #[test]
    fn test_boundary_detects_input_output_nodes() {
        let mut sg = Subgraph::new("inner");
        sg.add_node(1);
        sg.add_node(2);
        sg.add_node(3);
        let all_edges = vec![(0, 1), (0, 2), (1, 3), (2, 3), (3, 4)];
        let b = sg.boundary(&all_edges);
        assert!(b.input_nodes.contains(&0));
        assert!(b.output_nodes.contains(&3));
        // Node 0 feeds in twice but is reported once; likewise node 3 feeds out.
        assert_eq!(b.input_count(), 1);
        assert_eq!(b.output_count(), 1);
    }

    #[test]
    fn test_boundary_of_isolated_subgraph_is_empty() {
        let mut sg = Subgraph::new("island");
        sg.add_node(10);
        let b = sg.boundary(&[(0, 1), (1, 2)]);
        assert_eq!(b.input_count(), 0);
        assert_eq!(b.output_count(), 0);
    }

    // ----- SubgraphExtractor -----

    #[test]
    fn test_extract_from_root_contains_all_nodes() {
        let ext = sample_extractor();
        let sg = ext.extract(&[0], "full").expect("extract should succeed");
        assert_eq!(sg.node_count(), 5);
    }

    #[test]
    fn test_extract_from_midpoint() {
        let ext = sample_extractor();
        let sg = ext.extract(&[3], "tail").expect("extract should succeed");
        assert!(sg.contains_node(3));
        assert!(sg.contains_node(4));
        assert!(!sg.contains_node(0));
    }

    #[test]
    fn test_extract_includes_internal_edges() {
        let ext = sample_extractor();
        let sg = ext.extract(&[0], "full").expect("extract should succeed");
        assert!(sg.edge_count() > 0);
        // Edges are copied in the order they were given to the extractor.
        assert_eq!(
            sg.internal_edges().to_vec(),
            vec![(0, 1), (0, 2), (1, 3), (2, 3), (3, 4)]
        );
    }

    #[test]
    fn test_extract_multiple_seeds() {
        let ext = sample_extractor();
        let sg = ext
            .extract(&[1, 2], "branches")
            .expect("extract should succeed");
        assert_eq!(sg.node_count(), 4);
        assert!(!sg.contains_node(0));
        assert_eq!(sg.internal_edges().to_vec(), vec![(1, 3), (2, 3), (3, 4)]);
    }

    #[test]
    fn test_extract_terminates_on_cycle() {
        let ext = SubgraphExtractor::new(3, vec![(0, 1), (1, 2), (2, 0)]);
        let sg = ext.extract(&[1], "ring").expect("extract should succeed");
        assert_eq!(sg.node_count(), 3);
        assert_eq!(sg.edge_count(), 3);
    }

    #[test]
    fn test_extract_seed_at_upper_bound_is_valid() {
        let ext = sample_extractor();
        let sg = ext.extract(&[4], "leaf").expect("seed 4 is in range");
        assert_eq!(sg.node_count(), 1);
        assert_eq!(sg.edge_count(), 0);
    }

    #[test]
    fn test_extract_out_of_range_seed_returns_error() {
        let ext = sample_extractor();
        assert!(ext.extract(&[99], "bad").is_err());
    }

    #[test]
    fn test_extract_out_of_range_error_names_seed() {
        let ext = sample_extractor();
        let err = ext
            .extract(&[0, 5], "bad")
            .expect_err("seed 5 is out of range");
        assert!(err.contains("Seed node 5"));
    }

    #[test]
    fn test_extract_empty_seeds() {
        let ext = sample_extractor();
        let sg = ext.extract(&[], "empty").expect("extract should succeed");
        assert_eq!(sg.node_count(), 0);
    }
}