1use std::collections::HashSet;
4
5use itertools::Itertools;
6use thiserror::Error;
7
8use crate::PortIndex;
9use crate::builder::{BlockBuilder, Container, Dataflow, SubContainer};
10use crate::hugr::{HugrMut, HugrView};
11use crate::ops;
12use crate::ops::controlflow::BasicBlock;
13use crate::ops::handle::NodeHandle;
14use crate::ops::{DataflowBlock, OpType};
15use crate::{Node, type_row};
16
17use super::{PatchHugrMut, PatchVerification};
18
19pub struct OutlineCfg {
22 blocks: HashSet<Node>,
23}
24
25impl OutlineCfg {
26 pub fn new(blocks: impl IntoIterator<Item = Node>) -> Self {
28 Self {
29 blocks: HashSet::from_iter(blocks),
30 }
31 }
32
33 fn compute_entry_exit(
36 &self,
37 h: &impl HugrView<Node = Node>,
38 ) -> Result<(Node, Node, Node), OutlineCfgError> {
39 let cfg_n = match self
40 .blocks
41 .iter()
42 .map(|n| h.get_parent(*n))
43 .unique()
44 .exactly_one()
45 {
46 Ok(Some(n)) => n,
47 _ => return Err(OutlineCfgError::NotSiblings),
48 };
49 let o = h.get_optype(cfg_n);
50 let OpType::CFG(_) = o else {
51 return Err(OutlineCfgError::ParentNotCfg(cfg_n, Box::new(o.clone())));
52 };
53 let cfg_entry = h.children(cfg_n).next().unwrap();
54 let mut entry = None;
55 let mut exit_succ = None;
56 for &n in &self.blocks {
57 if n == cfg_entry
58 || h.input_neighbours(n)
59 .any(|pred| !self.blocks.contains(&pred))
60 {
61 match entry {
62 None => {
63 entry = Some(n);
64 }
65 Some(prev) => {
66 return Err(OutlineCfgError::MultipleEntryNodes(prev, n));
67 }
68 }
69 }
70 let external_succs = h.output_neighbours(n).filter(|s| !self.blocks.contains(s));
71 match external_succs.at_most_one() {
72 Ok(None) => (), Ok(Some(o)) => match exit_succ {
74 None => {
75 exit_succ = Some((n, o));
76 }
77 Some((prev, _)) => {
78 return Err(OutlineCfgError::MultipleExitNodes(prev, n));
79 }
80 },
81 Err(ext) => return Err(OutlineCfgError::MultipleExitEdges(n, ext.collect())),
82 }
83 }
84 match (entry, exit_succ) {
85 (Some(e), Some((x, o))) => Ok((e, x, o)),
86 (None, _) => Err(OutlineCfgError::NoEntryNode),
87 (_, None) => Err(OutlineCfgError::NoExitNode),
88 }
89 }
90}
91
92impl PatchVerification for OutlineCfg {
93 type Error = OutlineCfgError;
94 type Node = Node;
95 fn verify(&self, h: &impl HugrView<Node = Node>) -> Result<(), OutlineCfgError> {
96 self.compute_entry_exit(h)?;
97 Ok(())
98 }
99
100 fn invalidated_nodes(
101 &self,
102 _: &impl HugrView<Node = Self::Node>,
103 ) -> impl Iterator<Item = Self::Node> {
104 self.blocks.iter().copied()
105 }
106}
107
108impl PatchHugrMut for OutlineCfg {
109 type Outcome = [Node; 2];
113
114 const UNCHANGED_ON_FAILURE: bool = true;
115 fn apply_hugr_mut(
116 self,
117 h: &mut impl HugrMut<Node = Node>,
118 ) -> Result<[Node; 2], OutlineCfgError> {
119 let (entry, exit, outside) = self.compute_entry_exit(h)?;
120 let OpType::DataflowBlock(DataflowBlock { inputs, .. }) = h.get_optype(entry) else {
123 panic!("Entry node is not a basic block")
124 };
125 let inputs = inputs.clone();
126 let outputs = match h.get_optype(outside) {
127 OpType::DataflowBlock(dfb) => dfb.dataflow_input().clone(),
128 OpType::ExitBlock(exit) => exit.dataflow_input().clone(),
129 _ => panic!("External successor not a basic block"),
130 };
131 let outer_cfg = h.get_parent(entry).unwrap();
132 let outer_entry = h.children(outer_cfg).next().unwrap();
133
134 let (new_block, cfg_node) = {
136 let mut new_block_bldr =
137 BlockBuilder::new(inputs.clone(), vec![type_row![]], outputs.clone()).unwrap();
138 let wires_in = inputs.iter().cloned().zip(new_block_bldr.input_wires());
139 let cfg = new_block_bldr.cfg_builder(wires_in, outputs).unwrap();
140 let cfg = cfg.finish_sub_container().unwrap();
141 let unit_sum = new_block_bldr.add_constant(ops::Value::unary_unit_sum());
142 let pred_wire = new_block_bldr.load_const(&unit_sum);
143 new_block_bldr
144 .set_outputs(pred_wire, cfg.outputs())
145 .unwrap();
146 let new_block_hugr = std::mem::take(new_block_bldr.hugr_mut());
147 let ins_res = h.insert_hugr(outer_cfg, new_block_hugr);
148 (
149 ins_res.inserted_entrypoint,
150 *ins_res.node_map.get(&cfg.node()).unwrap(),
151 )
152 };
153
154 let preds: Vec<_> = h
156 .linked_outputs(entry, h.node_inputs(entry).exactly_one().ok().unwrap())
157 .collect();
158 for (pred, br) in preds {
159 if !self.blocks.contains(&pred) {
160 h.disconnect(pred, br);
161 h.connect(pred, br, new_block, 0);
162 }
163 }
164 if entry == outer_entry {
165 h.move_before_sibling(new_block, outer_entry);
168 }
169
170 let exit_port = h
173 .node_outputs(exit)
174 .filter(|p| {
175 let (t, p2) = h.single_linked_input(exit, *p).unwrap();
176 assert!(p2.index() == 0);
177 t == outside
178 })
179 .exactly_one()
180 .ok() .unwrap();
182 h.disconnect(exit, exit_port);
183 h.connect(new_block, 0, outside, 0);
185
186 let inner_exit = {
188 let inner_exit = h.children(cfg_node).exactly_one().ok().unwrap();
189
190 h.move_before_sibling(entry, inner_exit);
192 for n in self.blocks {
194 if n != entry {
196 h.set_parent(n, cfg_node);
197 }
198 }
199 inner_exit
200 };
201
202 h.connect(exit, exit_port, inner_exit, 0);
204
205 Ok([new_block, cfg_node])
206 }
207}
208
209#[derive(Debug, Error)]
211#[non_exhaustive]
212pub enum OutlineCfgError {
213 #[error("The nodes did not all have the same parent")]
215 NotSiblings,
216 #[error("The parent node {0} was not a CFG but a {1}")]
218 ParentNotCfg(Node, Box<OpType>),
219 #[error("Multiple blocks had predecessors outside the set - at least {0} and {1}")]
221 MultipleEntryNodes(Node, Node),
222 #[error("Multiple blocks had edges leaving the set - at least {0} and {1}")]
225 MultipleExitNodes(Node, Node),
226 #[error("Exit block {0} had edges to multiple external blocks {1:?}")]
228 MultipleExitEdges(Node, Vec<Node>),
229 #[error("No block had predecessors outside the set")]
231 NoEntryNode,
232 #[error("No block had a successor outside the set")]
234 NoExitNode,
235}
236
237#[cfg(test)]
238mod test {
239 use std::collections::HashSet;
240
241 use crate::builder::{
242 BlockBuilder, BuildError, CFGBuilder, Container, Dataflow, DataflowSubContainer,
243 HugrBuilder, ModuleBuilder,
244 };
245 use crate::extension::prelude::usize_t;
246 use crate::hugr::HugrMut;
247 use crate::ops::constant::Value;
248 use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle};
249 use crate::types::Signature;
250 use crate::{Hugr, HugrView, Node};
251 use cool_asserts::assert_matches;
252 use itertools::Itertools;
253 use rstest::rstest;
254
255 use super::{OutlineCfg, OutlineCfgError};
256
257 struct CondThenLoopCfg {
261 h: Hugr,
262 left: Node,
263 right: Node,
264 merge: Node,
265 head: Node,
266 tail: Node,
267 }
268 impl CondThenLoopCfg {
269 fn new() -> Result<CondThenLoopCfg, BuildError> {
270 let block_ty = Signature::new_endo(usize_t());
271 let mut cfg_builder = CFGBuilder::new(block_ty.clone())?;
272 let pred_const = cfg_builder.add_constant(Value::unit_sum(0, 2).expect("0 < 2"));
273 let const_unit = cfg_builder.add_constant(Value::unary_unit_sum());
274 fn n_identity(
275 mut bbldr: BlockBuilder<&mut Hugr>,
276 cst: &ConstID,
277 ) -> Result<BasicBlockID, BuildError> {
278 let pred = bbldr.load_const(cst);
279 let vals = bbldr.input_wires();
280 bbldr.finish_with_outputs(pred, vals)
281 }
282 let id_block = |c: &mut CFGBuilder<_>| {
283 n_identity(c.simple_block_builder(block_ty.clone(), 1)?, &const_unit)
284 };
285
286 let entry = n_identity(
287 cfg_builder.simple_entry_builder(usize_t().into(), 2)?,
288 &pred_const,
289 )?;
290
291 let left = id_block(&mut cfg_builder)?;
292 let right = id_block(&mut cfg_builder)?;
293 cfg_builder.branch(&entry, 0, &left)?;
294 cfg_builder.branch(&entry, 1, &right)?;
295
296 let merge = id_block(&mut cfg_builder)?;
297 cfg_builder.branch(&left, 0, &merge)?;
298 cfg_builder.branch(&right, 0, &merge)?;
299
300 let head = id_block(&mut cfg_builder)?;
301 cfg_builder.branch(&merge, 0, &head)?;
302 let tail = n_identity(
303 cfg_builder.simple_block_builder(Signature::new_endo(usize_t()), 2)?,
304 &pred_const,
305 )?;
306 cfg_builder.branch(&tail, 1, &head)?;
307 cfg_builder.branch(&head, 0, &tail)?; let exit = cfg_builder.exit_block();
309 cfg_builder.branch(&tail, 0, &exit)?;
310
311 let h = cfg_builder.finish_hugr()?;
312 let (left, right) = (left.node(), right.node());
313 let (merge, head, tail) = (merge.node(), head.node(), tail.node());
314 Ok(Self {
315 h,
316 left,
317 right,
318 merge,
319 head,
320 tail,
321 })
322 }
323 fn entry_exit(&self) -> (Node, Node) {
324 self.h
325 .children(self.h.entrypoint())
326 .take(2)
327 .collect_tuple()
328 .unwrap()
329 }
330 }
331
332 #[rstest::fixture]
333 fn cond_then_loop_cfg() -> CondThenLoopCfg {
334 CondThenLoopCfg::new().unwrap()
335 }
336
337 #[rstest]
338 fn test_outline_cfg_errors(cond_then_loop_cfg: CondThenLoopCfg) {
339 let (entry, _) = cond_then_loop_cfg.entry_exit();
340 let CondThenLoopCfg {
341 mut h,
342 left,
343 right,
344 merge,
345 head,
346 tail,
347 } = cond_then_loop_cfg;
348 let backup = h.clone();
349
350 let r = h.apply_patch(OutlineCfg::new([tail]));
351 assert_matches!(r, Err(OutlineCfgError::MultipleExitEdges(_, _)));
352 assert_eq!(h, backup);
353
354 let r = h.apply_patch(OutlineCfg::new([entry, left, right]));
355 assert_matches!(r, Err(OutlineCfgError::MultipleExitNodes(a,b))
356 => assert_eq!(HashSet::from([a,b]), HashSet::from_iter([left, right])));
357 assert_eq!(h, backup);
358
359 let r = h.apply_patch(OutlineCfg::new([left, right, merge]));
360 assert_matches!(r, Err(OutlineCfgError::MultipleEntryNodes(a,b))
361 => assert_eq!(HashSet::from([a,b]), HashSet::from([left, right])));
362 assert_eq!(h, backup);
363
364 let r = h.apply_patch(OutlineCfg::new([entry, left, right, merge, head]));
366 assert_matches!(r, Err(OutlineCfgError::MultipleEntryNodes(a,b))
367 => assert_eq!(HashSet::from([a,b]), HashSet::from([entry, head])));
368 assert_eq!(h, backup);
369 }
370
371 #[rstest::rstest]
372 fn test_outline_cfg(cond_then_loop_cfg: CondThenLoopCfg) {
373 let (_, exit) = cond_then_loop_cfg.entry_exit();
378 let CondThenLoopCfg {
379 mut h,
380 merge,
381 head,
382 tail,
383 ..
384 } = cond_then_loop_cfg;
385 let root = h.entrypoint();
386 let (new_block, _, exit_block) = outline_cfg_check_parents(&mut h, root, vec![head, tail]);
387 assert_eq!(h.output_neighbours(merge).collect_vec(), vec![new_block]);
388 assert_eq!(h.input_neighbours(exit).collect_vec(), vec![new_block]);
389 assert_eq!(
390 h.output_neighbours(tail).collect::<HashSet<Node>>(),
391 HashSet::from([head, exit_block])
392 );
393 }
394
395 #[rstest]
396 fn test_outline_cfg_multiple_in_edges(cond_then_loop_cfg: CondThenLoopCfg) {
397 let (_, exit) = cond_then_loop_cfg.entry_exit();
402 let CondThenLoopCfg {
403 mut h,
404 left,
405 right,
406 merge,
407 head,
408 tail,
409 } = cond_then_loop_cfg;
410
411 let root = h.entrypoint();
412 let (new_block, _, inner_exit) =
413 outline_cfg_check_parents(&mut h, root, vec![merge, head, tail]);
414 assert_eq!(h.input_neighbours(exit).collect_vec(), vec![new_block]);
415 assert_eq!(
416 h.input_neighbours(new_block).collect::<HashSet<_>>(),
417 HashSet::from([left, right])
418 );
419 assert_eq!(
420 h.output_neighbours(tail).collect::<HashSet<Node>>(),
421 HashSet::from([head, inner_exit])
422 );
423 }
424
425 #[rstest]
426 fn test_outline_cfg_subregion(cond_then_loop_cfg: CondThenLoopCfg) {
427 let mut module_builder = ModuleBuilder::new();
430 let mut fbuild = module_builder
431 .define_function("main", Signature::new(vec![usize_t()], vec![usize_t()]))
432 .unwrap();
433 let [i1] = fbuild.input_wires_arr();
434 let cfg = fbuild
435 .add_hugr_with_wires(cond_then_loop_cfg.h, [i1])
436 .unwrap();
437 fbuild.finish_with_outputs(cfg.outputs()).unwrap();
438 let mut h = module_builder.finish_hugr().unwrap();
439 let cfg = cfg.node();
441 let exit_node = h.children(cfg).nth(1).unwrap();
442 let tail = h.input_neighbours(exit_node).exactly_one().ok().unwrap();
443 let head = h.input_neighbours(tail).exactly_one().ok().unwrap();
444 assert!(h.get_optype(exit_node).is_exit_block());
446 assert_eq!(
447 h.output_neighbours(tail).collect::<HashSet<_>>(),
448 HashSet::from([head, exit_node])
449 );
450 outline_cfg_check_parents(&mut h, cfg, vec![head, tail]);
451 h.validate().unwrap();
452 }
453
454 #[rstest]
455 fn test_outline_cfg_move_entry(cond_then_loop_cfg: CondThenLoopCfg) {
456 let (entry, _) = cond_then_loop_cfg.entry_exit();
462 let CondThenLoopCfg {
463 mut h,
464 left,
465 right,
466 merge,
467 head,
468 ..
469 } = cond_then_loop_cfg;
470
471 let root = h.entrypoint();
472 let (new_block, _, _) =
473 outline_cfg_check_parents(&mut h, root, vec![entry, left, right, merge]);
474 h.validate().unwrap_or_else(|e| panic!("{e}"));
475 assert_eq!(new_block, h.children(h.entrypoint()).next().unwrap());
476 assert_eq!(h.output_neighbours(new_block).collect_vec(), [head]);
477 }
478
479 fn outline_cfg_check_parents(
480 h: &mut impl HugrMut<Node = Node>,
481 cfg: Node,
482 blocks: Vec<Node>,
483 ) -> (Node, Node, Node) {
484 let mut other_blocks = h.children(cfg).collect::<HashSet<_>>();
485 assert!(blocks.iter().all(|b| other_blocks.remove(b)));
486 let [new_block, new_cfg] = h.apply_patch(OutlineCfg::new(blocks.clone())).unwrap();
487
488 for n in other_blocks {
489 assert_eq!(h.get_parent(n), Some(cfg));
490 }
491 assert_eq!(h.get_parent(new_block), Some(cfg));
492 assert!(h.get_optype(new_block).is_dataflow_block());
493 assert_eq!(h.get_parent(new_cfg), Some(new_block));
494 for n in blocks {
495 assert_eq!(h.get_parent(n), Some(new_cfg));
496 }
497 assert!(h.get_optype(new_cfg).is_cfg());
498 let exit_block = h.children(new_cfg).nth(1).unwrap();
499 assert!(h.get_optype(exit_block).is_exit_block());
500 (new_block, new_cfg, exit_block)
501 }
502}