1use std::collections::HashSet;
4
5use itertools::Itertools;
6use thiserror::Error;
7
8use crate::PortIndex;
9use crate::builder::{BlockBuilder, Container, Dataflow, SubContainer};
10use crate::hugr::{HugrMut, HugrView};
11use crate::ops;
12use crate::ops::controlflow::BasicBlock;
13use crate::ops::handle::NodeHandle;
14use crate::ops::{DataflowBlock, OpType};
15use crate::{Node, type_row};
16
17use super::{PatchHugrMut, PatchVerification};
18
19pub struct OutlineCfg {
22 blocks: HashSet<Node>,
23}
24
25impl OutlineCfg {
26 pub fn new(blocks: impl IntoIterator<Item = Node>) -> Self {
28 Self {
29 blocks: HashSet::from_iter(blocks),
30 }
31 }
32
33 fn compute_entry_exit(
36 &self,
37 h: &impl HugrView<Node = Node>,
38 ) -> Result<(Node, Node, Node), OutlineCfgError> {
39 let cfg_n = match self
40 .blocks
41 .iter()
42 .map(|n| h.get_parent(*n))
43 .unique()
44 .exactly_one()
45 {
46 Ok(Some(n)) => n,
47 _ => return Err(OutlineCfgError::NotSiblings),
48 };
49 let o = h.get_optype(cfg_n);
50 let OpType::CFG(_) = o else {
51 return Err(OutlineCfgError::ParentNotCfg(cfg_n, o.clone()));
52 };
53 let cfg_entry = h.children(cfg_n).next().unwrap();
54 let mut entry = None;
55 let mut exit_succ = None;
56 for &n in &self.blocks {
57 if n == cfg_entry
58 || h.input_neighbours(n)
59 .any(|pred| !self.blocks.contains(&pred))
60 {
61 match entry {
62 None => {
63 entry = Some(n);
64 }
65 Some(prev) => {
66 return Err(OutlineCfgError::MultipleEntryNodes(prev, n));
67 }
68 }
69 }
70 let external_succs = h.output_neighbours(n).filter(|s| !self.blocks.contains(s));
71 match external_succs.at_most_one() {
72 Ok(None) => (), Ok(Some(o)) => match exit_succ {
74 None => {
75 exit_succ = Some((n, o));
76 }
77 Some((prev, _)) => {
78 return Err(OutlineCfgError::MultipleExitNodes(prev, n));
79 }
80 },
81 Err(ext) => return Err(OutlineCfgError::MultipleExitEdges(n, ext.collect())),
82 }
83 }
84 match (entry, exit_succ) {
85 (Some(e), Some((x, o))) => Ok((e, x, o)),
86 (None, _) => Err(OutlineCfgError::NoEntryNode),
87 (_, None) => Err(OutlineCfgError::NoExitNode),
88 }
89 }
90}
91
92impl PatchVerification for OutlineCfg {
93 type Error = OutlineCfgError;
94 type Node = Node;
95 fn verify(&self, h: &impl HugrView<Node = Node>) -> Result<(), OutlineCfgError> {
96 self.compute_entry_exit(h)?;
97 Ok(())
98 }
99
100 fn invalidation_set(&self) -> impl Iterator<Item = Node> {
101 self.blocks.iter().copied()
102 }
103}
104
105impl PatchHugrMut for OutlineCfg {
106 type Outcome = [Node; 2];
110
111 const UNCHANGED_ON_FAILURE: bool = true;
112 fn apply_hugr_mut(
113 self,
114 h: &mut impl HugrMut<Node = Node>,
115 ) -> Result<[Node; 2], OutlineCfgError> {
116 let (entry, exit, outside) = self.compute_entry_exit(h)?;
117 let OpType::DataflowBlock(DataflowBlock { inputs, .. }) = h.get_optype(entry) else {
120 panic!("Entry node is not a basic block")
121 };
122 let inputs = inputs.clone();
123 let outputs = match h.get_optype(outside) {
124 OpType::DataflowBlock(dfb) => dfb.dataflow_input().clone(),
125 OpType::ExitBlock(exit) => exit.dataflow_input().clone(),
126 _ => panic!("External successor not a basic block"),
127 };
128 let outer_cfg = h.get_parent(entry).unwrap();
129 let outer_entry = h.children(outer_cfg).next().unwrap();
130
131 let (new_block, cfg_node) = {
133 let mut new_block_bldr =
134 BlockBuilder::new(inputs.clone(), vec![type_row![]], outputs.clone()).unwrap();
135 let wires_in = inputs.iter().cloned().zip(new_block_bldr.input_wires());
136 let cfg = new_block_bldr.cfg_builder(wires_in, outputs).unwrap();
137 let cfg = cfg.finish_sub_container().unwrap();
138 let unit_sum = new_block_bldr.add_constant(ops::Value::unary_unit_sum());
139 let pred_wire = new_block_bldr.load_const(&unit_sum);
140 new_block_bldr
141 .set_outputs(pred_wire, cfg.outputs())
142 .unwrap();
143 let new_block_hugr = std::mem::take(new_block_bldr.hugr_mut());
144 let ins_res = h.insert_hugr(outer_cfg, new_block_hugr);
145 (
146 ins_res.inserted_entrypoint,
147 *ins_res.node_map.get(&cfg.node()).unwrap(),
148 )
149 };
150
151 let preds: Vec<_> = h
153 .linked_outputs(entry, h.node_inputs(entry).exactly_one().ok().unwrap())
154 .collect();
155 for (pred, br) in preds {
156 if !self.blocks.contains(&pred) {
157 h.disconnect(pred, br);
158 h.connect(pred, br, new_block, 0);
159 }
160 }
161 if entry == outer_entry {
162 h.move_before_sibling(new_block, outer_entry);
165 }
166
167 let exit_port = h
170 .node_outputs(exit)
171 .filter(|p| {
172 let (t, p2) = h.single_linked_input(exit, *p).unwrap();
173 assert!(p2.index() == 0);
174 t == outside
175 })
176 .exactly_one()
177 .ok() .unwrap();
179 h.disconnect(exit, exit_port);
180 h.connect(new_block, 0, outside, 0);
182
183 let inner_exit = {
185 let inner_exit = h.children(cfg_node).exactly_one().ok().unwrap();
186
187 h.move_before_sibling(entry, inner_exit);
189 for n in self.blocks {
191 if n != entry {
193 h.set_parent(n, cfg_node);
194 }
195 }
196 inner_exit
197 };
198
199 h.connect(exit, exit_port, inner_exit, 0);
201
202 Ok([new_block, cfg_node])
203 }
204}
205
206#[derive(Debug, Error)]
208#[non_exhaustive]
209pub enum OutlineCfgError {
210 #[error("The nodes did not all have the same parent")]
212 NotSiblings,
213 #[error("The parent node {0} was not a CFG but a {1}")]
215 ParentNotCfg(Node, OpType),
216 #[error("Multiple blocks had predecessors outside the set - at least {0} and {1}")]
218 MultipleEntryNodes(Node, Node),
219 #[error("Multiple blocks had edges leaving the set - at least {0} and {1}")]
222 MultipleExitNodes(Node, Node),
223 #[error("Exit block {0} had edges to multiple external blocks {1:?}")]
225 MultipleExitEdges(Node, Vec<Node>),
226 #[error("No block had predecessors outside the set")]
228 NoEntryNode,
229 #[error("No block had a successor outside the set")]
231 NoExitNode,
232}
233
234#[cfg(test)]
235mod test {
236 use std::collections::HashSet;
237
238 use crate::builder::{
239 BlockBuilder, BuildError, CFGBuilder, Container, Dataflow, DataflowSubContainer,
240 HugrBuilder, ModuleBuilder,
241 };
242 use crate::extension::prelude::usize_t;
243 use crate::hugr::HugrMut;
244 use crate::ops::constant::Value;
245 use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle};
246 use crate::types::Signature;
247 use crate::{Hugr, HugrView, Node};
248 use cool_asserts::assert_matches;
249 use itertools::Itertools;
250 use rstest::rstest;
251
252 use super::{OutlineCfg, OutlineCfgError};
253
254 struct CondThenLoopCfg {
258 h: Hugr,
259 left: Node,
260 right: Node,
261 merge: Node,
262 head: Node,
263 tail: Node,
264 }
265 impl CondThenLoopCfg {
266 fn new() -> Result<CondThenLoopCfg, BuildError> {
267 let block_ty = Signature::new_endo(usize_t());
268 let mut cfg_builder = CFGBuilder::new(block_ty.clone())?;
269 let pred_const = cfg_builder.add_constant(Value::unit_sum(0, 2).expect("0 < 2"));
270 let const_unit = cfg_builder.add_constant(Value::unary_unit_sum());
271 fn n_identity(
272 mut bbldr: BlockBuilder<&mut Hugr>,
273 cst: &ConstID,
274 ) -> Result<BasicBlockID, BuildError> {
275 let pred = bbldr.load_const(cst);
276 let vals = bbldr.input_wires();
277 bbldr.finish_with_outputs(pred, vals)
278 }
279 let id_block = |c: &mut CFGBuilder<_>| {
280 n_identity(c.simple_block_builder(block_ty.clone(), 1)?, &const_unit)
281 };
282
283 let entry = n_identity(
284 cfg_builder.simple_entry_builder(usize_t().into(), 2)?,
285 &pred_const,
286 )?;
287
288 let left = id_block(&mut cfg_builder)?;
289 let right = id_block(&mut cfg_builder)?;
290 cfg_builder.branch(&entry, 0, &left)?;
291 cfg_builder.branch(&entry, 1, &right)?;
292
293 let merge = id_block(&mut cfg_builder)?;
294 cfg_builder.branch(&left, 0, &merge)?;
295 cfg_builder.branch(&right, 0, &merge)?;
296
297 let head = id_block(&mut cfg_builder)?;
298 cfg_builder.branch(&merge, 0, &head)?;
299 let tail = n_identity(
300 cfg_builder.simple_block_builder(Signature::new_endo(usize_t()), 2)?,
301 &pred_const,
302 )?;
303 cfg_builder.branch(&tail, 1, &head)?;
304 cfg_builder.branch(&head, 0, &tail)?; let exit = cfg_builder.exit_block();
306 cfg_builder.branch(&tail, 0, &exit)?;
307
308 let h = cfg_builder.finish_hugr()?;
309 let (left, right) = (left.node(), right.node());
310 let (merge, head, tail) = (merge.node(), head.node(), tail.node());
311 Ok(Self {
312 h,
313 left,
314 right,
315 merge,
316 head,
317 tail,
318 })
319 }
320 fn entry_exit(&self) -> (Node, Node) {
321 self.h
322 .children(self.h.entrypoint())
323 .take(2)
324 .collect_tuple()
325 .unwrap()
326 }
327 }
328
329 #[rstest::fixture]
330 fn cond_then_loop_cfg() -> CondThenLoopCfg {
331 CondThenLoopCfg::new().unwrap()
332 }
333
334 #[rstest]
335 fn test_outline_cfg_errors(cond_then_loop_cfg: CondThenLoopCfg) {
336 let (entry, _) = cond_then_loop_cfg.entry_exit();
337 let CondThenLoopCfg {
338 mut h,
339 left,
340 right,
341 merge,
342 head,
343 tail,
344 } = cond_then_loop_cfg;
345 let backup = h.clone();
346
347 let r = h.apply_patch(OutlineCfg::new([tail]));
348 assert_matches!(r, Err(OutlineCfgError::MultipleExitEdges(_, _)));
349 assert_eq!(h, backup);
350
351 let r = h.apply_patch(OutlineCfg::new([entry, left, right]));
352 assert_matches!(r, Err(OutlineCfgError::MultipleExitNodes(a,b))
353 => assert_eq!(HashSet::from([a,b]), HashSet::from_iter([left, right])));
354 assert_eq!(h, backup);
355
356 let r = h.apply_patch(OutlineCfg::new([left, right, merge]));
357 assert_matches!(r, Err(OutlineCfgError::MultipleEntryNodes(a,b))
358 => assert_eq!(HashSet::from([a,b]), HashSet::from([left, right])));
359 assert_eq!(h, backup);
360
361 let r = h.apply_patch(OutlineCfg::new([entry, left, right, merge, head]));
363 assert_matches!(r, Err(OutlineCfgError::MultipleEntryNodes(a,b))
364 => assert_eq!(HashSet::from([a,b]), HashSet::from([entry, head])));
365 assert_eq!(h, backup);
366 }
367
368 #[rstest::rstest]
369 fn test_outline_cfg(cond_then_loop_cfg: CondThenLoopCfg) {
370 let (_, exit) = cond_then_loop_cfg.entry_exit();
375 let CondThenLoopCfg {
376 mut h,
377 merge,
378 head,
379 tail,
380 ..
381 } = cond_then_loop_cfg;
382 let root = h.entrypoint();
383 let (new_block, _, exit_block) = outline_cfg_check_parents(&mut h, root, vec![head, tail]);
384 assert_eq!(h.output_neighbours(merge).collect_vec(), vec![new_block]);
385 assert_eq!(h.input_neighbours(exit).collect_vec(), vec![new_block]);
386 assert_eq!(
387 h.output_neighbours(tail).collect::<HashSet<Node>>(),
388 HashSet::from([head, exit_block])
389 );
390 }
391
392 #[rstest]
393 fn test_outline_cfg_multiple_in_edges(cond_then_loop_cfg: CondThenLoopCfg) {
394 let (_, exit) = cond_then_loop_cfg.entry_exit();
399 let CondThenLoopCfg {
400 mut h,
401 left,
402 right,
403 merge,
404 head,
405 tail,
406 } = cond_then_loop_cfg;
407
408 let root = h.entrypoint();
409 let (new_block, _, inner_exit) =
410 outline_cfg_check_parents(&mut h, root, vec![merge, head, tail]);
411 assert_eq!(h.input_neighbours(exit).collect_vec(), vec![new_block]);
412 assert_eq!(
413 h.input_neighbours(new_block).collect::<HashSet<_>>(),
414 HashSet::from([left, right])
415 );
416 assert_eq!(
417 h.output_neighbours(tail).collect::<HashSet<Node>>(),
418 HashSet::from([head, inner_exit])
419 );
420 }
421
422 #[rstest]
423 fn test_outline_cfg_subregion(cond_then_loop_cfg: CondThenLoopCfg) {
424 let mut module_builder = ModuleBuilder::new();
427 let mut fbuild = module_builder
428 .define_function("main", Signature::new(vec![usize_t()], vec![usize_t()]))
429 .unwrap();
430 let [i1] = fbuild.input_wires_arr();
431 let cfg = fbuild
432 .add_hugr_with_wires(cond_then_loop_cfg.h, [i1])
433 .unwrap();
434 fbuild.finish_with_outputs(cfg.outputs()).unwrap();
435 let mut h = module_builder.finish_hugr().unwrap();
436 let cfg = cfg.node();
438 let exit_node = h.children(cfg).nth(1).unwrap();
439 let tail = h.input_neighbours(exit_node).exactly_one().ok().unwrap();
440 let head = h.input_neighbours(tail).exactly_one().ok().unwrap();
441 assert!(h.get_optype(exit_node).is_exit_block());
443 assert_eq!(
444 h.output_neighbours(tail).collect::<HashSet<_>>(),
445 HashSet::from([head, exit_node])
446 );
447 outline_cfg_check_parents(&mut h, cfg, vec![head, tail]);
448 h.validate().unwrap();
449 }
450
451 #[rstest]
452 fn test_outline_cfg_move_entry(cond_then_loop_cfg: CondThenLoopCfg) {
453 let (entry, _) = cond_then_loop_cfg.entry_exit();
459 let CondThenLoopCfg {
460 mut h,
461 left,
462 right,
463 merge,
464 head,
465 ..
466 } = cond_then_loop_cfg;
467
468 let root = h.entrypoint();
469 let (new_block, _, _) =
470 outline_cfg_check_parents(&mut h, root, vec![entry, left, right, merge]);
471 h.validate().unwrap_or_else(|e| panic!("{e}"));
472 assert_eq!(new_block, h.children(h.entrypoint()).next().unwrap());
473 assert_eq!(h.output_neighbours(new_block).collect_vec(), [head]);
474 }
475
476 fn outline_cfg_check_parents(
477 h: &mut impl HugrMut<Node = Node>,
478 cfg: Node,
479 blocks: Vec<Node>,
480 ) -> (Node, Node, Node) {
481 let mut other_blocks = h.children(cfg).collect::<HashSet<_>>();
482 assert!(blocks.iter().all(|b| other_blocks.remove(b)));
483 let [new_block, new_cfg] = h.apply_patch(OutlineCfg::new(blocks.clone())).unwrap();
484
485 for n in other_blocks {
486 assert_eq!(h.get_parent(n), Some(cfg));
487 }
488 assert_eq!(h.get_parent(new_block), Some(cfg));
489 assert!(h.get_optype(new_block).is_dataflow_block());
490 assert_eq!(h.get_parent(new_cfg), Some(new_block));
491 for n in blocks {
492 assert_eq!(h.get_parent(n), Some(new_cfg));
493 }
494 assert!(h.get_optype(new_cfg).is_cfg());
495 let exit_block = h.children(new_cfg).nth(1).unwrap();
496 assert!(h.get_optype(exit_block).is_exit_block());
497 (new_block, new_cfg, exit_block)
498 }
499}