1use std::collections::HashSet;
3
4use itertools::Itertools;
5use thiserror::Error;
6
7use crate::builder::{BlockBuilder, Container, Dataflow, SubContainer};
8use crate::extension::ExtensionSet;
9use crate::hugr::internal::HugrMutInternals;
10use crate::hugr::rewrite::Rewrite;
11use crate::hugr::views::sibling::SiblingMut;
12use crate::hugr::{HugrMut, HugrView};
13use crate::ops;
14use crate::ops::controlflow::BasicBlock;
15use crate::ops::dataflow::DataflowOpTrait;
16use crate::ops::handle::{BasicBlockID, CfgID, NodeHandle};
17use crate::ops::{DataflowBlock, OpType};
18use crate::PortIndex;
19use crate::{type_row, Node};
20
21pub struct OutlineCfg {
24 blocks: HashSet<Node>,
25}
26
27impl OutlineCfg {
28 pub fn new(blocks: impl IntoIterator<Item = Node>) -> Self {
30 Self {
31 blocks: HashSet::from_iter(blocks),
32 }
33 }
34
35 fn compute_entry_exit_outside_extensions(
39 &self,
40 h: &impl HugrView<Node = Node>,
41 ) -> Result<(Node, Node, Node, ExtensionSet), OutlineCfgError> {
42 let cfg_n = match self
43 .blocks
44 .iter()
45 .map(|n| h.get_parent(*n))
46 .unique()
47 .exactly_one()
48 {
49 Ok(Some(n)) => n,
50 _ => return Err(OutlineCfgError::NotSiblings),
51 };
52 let o = h.get_optype(cfg_n);
53 let OpType::CFG(o) = o else {
54 return Err(OutlineCfgError::ParentNotCfg(cfg_n, o.clone()));
55 };
56 let cfg_entry = h.children(cfg_n).next().unwrap();
57 let mut entry = None;
58 let mut exit_succ = None;
59 let mut extension_delta = ExtensionSet::new();
60 for &n in self.blocks.iter() {
61 if n == cfg_entry
62 || h.input_neighbours(n)
63 .any(|pred| !self.blocks.contains(&pred))
64 {
65 match entry {
66 None => {
67 entry = Some(n);
68 }
69 Some(prev) => {
70 return Err(OutlineCfgError::MultipleEntryNodes(prev, n));
71 }
72 }
73 }
74 extension_delta = extension_delta.union(o.signature().runtime_reqs.clone());
75 let external_succs = h.output_neighbours(n).filter(|s| !self.blocks.contains(s));
76 match external_succs.at_most_one() {
77 Ok(None) => (), Ok(Some(o)) => match exit_succ {
79 None => {
80 exit_succ = Some((n, o));
81 }
82 Some((prev, _)) => {
83 return Err(OutlineCfgError::MultipleExitNodes(prev, n));
84 }
85 },
86 Err(ext) => return Err(OutlineCfgError::MultipleExitEdges(n, ext.collect())),
87 };
88 }
89 match (entry, exit_succ) {
90 (Some(e), Some((x, o))) => Ok((e, x, o, extension_delta)),
91 (None, _) => Err(OutlineCfgError::NoEntryNode),
92 (_, None) => Err(OutlineCfgError::NoExitNode),
93 }
94 }
95}
96
97impl Rewrite for OutlineCfg {
98 type Error = OutlineCfgError;
99 type ApplyResult = (Node, Node);
103
104 const UNCHANGED_ON_FAILURE: bool = true;
105 fn verify(&self, h: &impl HugrView<Node = Node>) -> Result<(), OutlineCfgError> {
106 self.compute_entry_exit_outside_extensions(h)?;
107 Ok(())
108 }
109 fn apply(self, h: &mut impl HugrMut<Node = Node>) -> Result<(Node, Node), OutlineCfgError> {
110 let (entry, exit, outside, extension_delta) =
111 self.compute_entry_exit_outside_extensions(h)?;
112 let OpType::DataflowBlock(DataflowBlock { inputs, .. }) = h.get_optype(entry) else {
115 panic!("Entry node is not a basic block")
116 };
117 let inputs = inputs.clone();
118 let outputs = match h.get_optype(outside) {
119 OpType::DataflowBlock(dfb) => dfb.dataflow_input().clone(),
120 OpType::ExitBlock(exit) => exit.dataflow_input().clone(),
121 _ => panic!("External successor not a basic block"),
122 };
123 let outer_cfg = h.get_parent(entry).unwrap();
124 let outer_entry = h.children(outer_cfg).next().unwrap();
125
126 let (new_block, cfg_node) = {
128 let mut new_block_bldr = BlockBuilder::new_exts(
129 inputs.clone(),
130 vec![type_row![]],
131 outputs.clone(),
132 extension_delta.clone(),
133 )
134 .unwrap();
135 let wires_in = inputs.iter().cloned().zip(new_block_bldr.input_wires());
136 let cfg = new_block_bldr
137 .cfg_builder_exts(wires_in, outputs, extension_delta)
138 .unwrap();
139 let cfg = cfg.finish_sub_container().unwrap();
140 let unit_sum = new_block_bldr.add_constant(ops::Value::unary_unit_sum());
141 let pred_wire = new_block_bldr.load_const(&unit_sum);
142 new_block_bldr
143 .set_outputs(pred_wire, cfg.outputs())
144 .unwrap();
145 let ins_res = h.insert_hugr(outer_cfg, new_block_bldr.hugr().clone());
146 (
147 ins_res.new_root,
148 *ins_res.node_map.get(&cfg.node()).unwrap(),
149 )
150 };
151
152 let preds: Vec<_> = h
154 .linked_outputs(entry, h.node_inputs(entry).exactly_one().ok().unwrap())
155 .collect();
156 for (pred, br) in preds {
157 if !self.blocks.contains(&pred) {
158 h.disconnect(pred, br);
159 h.connect(pred, br, new_block, 0);
160 }
161 }
162 if entry == outer_entry {
163 h.move_before_sibling(new_block, outer_entry);
166 }
167
168 let exit_port = h
171 .node_outputs(exit)
172 .filter(|p| {
173 let (t, p2) = h.single_linked_input(exit, *p).unwrap();
174 assert!(p2.index() == 0);
175 t == outside
176 })
177 .exactly_one()
178 .ok() .unwrap();
180 h.disconnect(exit, exit_port);
181 h.connect(new_block, 0, outside, 0);
183
184 let inner_exit = {
186 let h = h.hugr_mut();
189 let inner_exit = h.children(cfg_node).exactly_one().ok().unwrap();
190 h.move_before_sibling(entry, inner_exit);
192 for n in self.blocks {
194 if n != entry {
196 h.set_parent(n, cfg_node);
197 }
198 }
199 inner_exit
200 };
201
202 let mut in_bb_view: SiblingMut<'_, BasicBlockID> =
205 SiblingMut::try_new(h, new_block).unwrap();
206 let mut in_cfg_view: SiblingMut<'_, CfgID> =
207 SiblingMut::try_new(&mut in_bb_view, cfg_node).unwrap();
208 in_cfg_view.connect(exit, exit_port, inner_exit, 0);
209
210 Ok((new_block, cfg_node))
211 }
212
213 fn invalidation_set(&self) -> impl Iterator<Item = Node> {
214 self.blocks.iter().copied()
215 }
216}
217
218#[derive(Debug, Error)]
220#[non_exhaustive]
221pub enum OutlineCfgError {
222 #[error("The nodes did not all have the same parent")]
224 NotSiblings,
225 #[error("The parent node {0} was not a CFG but a {1}")]
227 ParentNotCfg(Node, OpType),
228 #[error("Multiple blocks had predecessors outside the set - at least {0} and {1}")]
230 MultipleEntryNodes(Node, Node),
231 #[error("Multiple blocks had edges leaving the set - at least {0} and {1}")]
234 MultipleExitNodes(Node, Node),
235 #[error("Exit block {0} had edges to multiple external blocks {1:?}")]
237 MultipleExitEdges(Node, Vec<Node>),
238 #[error("No block had predecessors outside the set")]
240 NoEntryNode,
241 #[error("No block had a successor outside the set")]
243 NoExitNode,
244}
245
#[cfg(test)]
mod test {
    use std::collections::HashSet;

    use crate::builder::{
        BlockBuilder, BuildError, CFGBuilder, Container, Dataflow, DataflowSubContainer,
        HugrBuilder, ModuleBuilder,
    };
    use crate::extension::prelude::usize_t;
    use crate::hugr::views::sibling::SiblingMut;
    use crate::hugr::HugrMut;
    use crate::ops::constant::Value;
    use crate::ops::handle::{BasicBlockID, CfgID, ConstID, NodeHandle};
    use crate::types::Signature;
    use crate::{Hugr, HugrView, Node};
    use cool_asserts::assert_matches;
    use itertools::Itertools;
    use rstest::rstest;

    use super::{OutlineCfg, OutlineCfgError};

    /// A CFG (as the root of `h`) with a conditional followed by a loop:
    ///
    /// ```text
    ///       entry
    ///       /   \
    ///    left   right
    ///       \   /
    ///       merge
    ///         |
    ///        head <--+
    ///         |      |
    ///        tail ---+
    ///         |
    ///        exit
    /// ```
    ///
    /// Every block passes its single `usize` value through unchanged.
    struct CondThenLoopCfg {
        h: Hugr,
        left: Node,
        right: Node,
        merge: Node,
        head: Node,
        tail: Node,
    }
    impl CondThenLoopCfg {
        /// Builds the CFG pictured on [`CondThenLoopCfg`].
        fn new() -> Result<CondThenLoopCfg, BuildError> {
            let block_ty = Signature::new_endo(usize_t());
            let mut cfg_builder = CFGBuilder::new(block_ty.clone())?;
            // Predicate constants: tag 0 of a 2-way sum (for the two-successor blocks)
            // and the unit sum (for single-successor blocks).
            let pred_const = cfg_builder.add_constant(Value::unit_sum(0, 2).expect("0 < 2"));
            let const_unit = cfg_builder.add_constant(Value::unary_unit_sum());
            /// Finishes `bbldr` as a block that loads `cst` as its branch predicate and
            /// outputs its inputs unchanged.
            fn n_identity(
                mut bbldr: BlockBuilder<&mut Hugr>,
                cst: &ConstID,
            ) -> Result<BasicBlockID, BuildError> {
                let pred = bbldr.load_const(cst);
                let vals = bbldr.input_wires();
                bbldr.finish_with_outputs(pred, vals)
            }
            // Adds a pass-through block with a single successor.
            let id_block = |c: &mut CFGBuilder<_>| {
                n_identity(c.simple_block_builder(block_ty.clone(), 1)?, &const_unit)
            };

            let entry = n_identity(
                cfg_builder.simple_entry_builder(usize_t().into(), 2)?,
                &pred_const,
            )?;

            // Conditional: entry -> {left, right} -> merge.
            let left = id_block(&mut cfg_builder)?;
            let right = id_block(&mut cfg_builder)?;
            cfg_builder.branch(&entry, 0, &left)?;
            cfg_builder.branch(&entry, 1, &right)?;

            let merge = id_block(&mut cfg_builder)?;
            cfg_builder.branch(&left, 0, &merge)?;
            cfg_builder.branch(&right, 0, &merge)?;

            // Loop: merge -> head -> tail, with tail branching back to head (1) or to
            // the exit (0).
            let head = id_block(&mut cfg_builder)?;
            cfg_builder.branch(&merge, 0, &head)?;
            let tail = n_identity(
                cfg_builder.simple_block_builder(Signature::new_endo(usize_t()), 2)?,
                &pred_const,
            )?;
            cfg_builder.branch(&tail, 1, &head)?;
            cfg_builder.branch(&head, 0, &tail)?;
            let exit = cfg_builder.exit_block();
            cfg_builder.branch(&tail, 0, &exit)?;

            let h = cfg_builder.finish_hugr()?;
            let (left, right) = (left.node(), right.node());
            let (merge, head, tail) = (merge.node(), head.node(), tail.node());
            Ok(Self {
                h,
                left,
                right,
                merge,
                head,
                tail,
            })
        }
        /// Returns the first two children of the root CFG: its entry and exit blocks.
        fn entry_exit(&self) -> (Node, Node) {
            self.h
                .children(self.h.root())
                .take(2)
                .collect_tuple()
                .unwrap()
        }
    }

    #[rstest::fixture]
    fn cond_then_loop_cfg() -> CondThenLoopCfg {
        CondThenLoopCfg::new().unwrap()
    }

    /// Invalid block sets are rejected with the expected error and leave the Hugr unchanged.
    #[rstest]
    fn test_outline_cfg_errors(cond_then_loop_cfg: CondThenLoopCfg) {
        let (entry, _) = cond_then_loop_cfg.entry_exit();
        let CondThenLoopCfg {
            mut h,
            left,
            right,
            merge,
            head,
            tail,
        } = cond_then_loop_cfg;
        let backup = h.clone();

        // `tail` alone has two edges leaving the set: to `head` and to the exit.
        let r = h.apply_rewrite(OutlineCfg::new([tail]));
        assert_matches!(r, Err(OutlineCfgError::MultipleExitEdges(_, _)));
        assert_eq!(h, backup);

        // Both `left` and `right` branch to `merge`, which is outside the set.
        let r = h.apply_rewrite(OutlineCfg::new([entry, left, right]));
        assert_matches!(r, Err(OutlineCfgError::MultipleExitNodes(a,b))
            => assert_eq!(HashSet::from([a,b]), HashSet::from_iter([left, right])));
        assert_eq!(h, backup);

        // Both `left` and `right` are reached from `entry`, which is outside the set.
        let r = h.apply_rewrite(OutlineCfg::new([left, right, merge]));
        assert_matches!(r, Err(OutlineCfgError::MultipleEntryNodes(a,b))
            => assert_eq!(HashSet::from([a,b]), HashSet::from([left, right])));
        assert_eq!(h, backup);

        // `entry` is the CFG's entry block, and `head` is also reached from `tail`
        // outside the set.
        let r = h.apply_rewrite(OutlineCfg::new([entry, left, right, merge, head]));
        assert_matches!(r, Err(OutlineCfgError::MultipleEntryNodes(a,b))
            => assert_eq!(HashSet::from([a,b]), HashSet::from([entry, head])));
        assert_eq!(h, backup);
    }

    /// Outlining the loop `{head, tail}`: `merge` and the outer exit are connected only
    /// to the new block, and inside the nested CFG `tail` branches to `head` or the
    /// inner exit.
    #[rstest::rstest]
    fn test_outline_cfg(cond_then_loop_cfg: CondThenLoopCfg) {
        let (_, exit) = cond_then_loop_cfg.entry_exit();
        let CondThenLoopCfg {
            mut h,
            merge,
            head,
            tail,
            ..
        } = cond_then_loop_cfg;
        let root = h.root();
        let (new_block, _, exit_block) = outline_cfg_check_parents(&mut h, root, vec![head, tail]);
        assert_eq!(h.output_neighbours(merge).collect_vec(), vec![new_block]);
        assert_eq!(h.input_neighbours(exit).collect_vec(), vec![new_block]);
        assert_eq!(
            h.output_neighbours(tail).collect::<HashSet<Node>>(),
            HashSet::from([head, exit_block])
        );
    }

    /// Outlining `{merge, head, tail}`, whose entry `merge` has two external
    /// predecessors: both edges must be redirected to the new block.
    #[rstest]
    fn test_outline_cfg_multiple_in_edges(cond_then_loop_cfg: CondThenLoopCfg) {
        let (_, exit) = cond_then_loop_cfg.entry_exit();
        let CondThenLoopCfg {
            mut h,
            left,
            right,
            merge,
            head,
            tail,
        } = cond_then_loop_cfg;

        let root = h.root();
        let (new_block, _, inner_exit) =
            outline_cfg_check_parents(&mut h, root, vec![merge, head, tail]);
        assert_eq!(h.input_neighbours(exit).collect_vec(), vec![new_block]);
        assert_eq!(
            h.input_neighbours(new_block).collect::<HashSet<_>>(),
            HashSet::from([left, right])
        );
        assert_eq!(
            h.output_neighbours(tail).collect::<HashSet<Node>>(),
            HashSet::from([head, inner_exit])
        );
    }

    /// Applies the rewrite through a `SiblingMut` view of a CFG nested inside a
    /// function, rather than to a Hugr whose root is the CFG.
    #[rstest]
    fn test_outline_cfg_subregion(cond_then_loop_cfg: CondThenLoopCfg) {
        let mut module_builder = ModuleBuilder::new();
        let mut fbuild = module_builder
            .define_function("main", Signature::new(vec![usize_t()], vec![usize_t()]))
            .unwrap();
        let [i1] = fbuild.input_wires_arr();
        let cfg = fbuild
            .add_hugr_with_wires(cond_then_loop_cfg.h, [i1])
            .unwrap();
        fbuild.finish_with_outputs(cfg.outputs()).unwrap();
        let mut h = module_builder.finish_hugr().unwrap();
        let cfg = cfg.node();
        // The fixture's node handles refer to the original Hugr, so re-locate `tail` and
        // `head` in the new one by walking back from the exit block.
        let exit_node = h.children(cfg).nth(1).unwrap();
        let tail = h.input_neighbours(exit_node).exactly_one().ok().unwrap();
        let head = h.input_neighbours(tail).exactly_one().ok().unwrap();
        assert!(h.get_optype(exit_node).is_exit_block());
        assert_eq!(
            h.output_neighbours(tail).collect::<HashSet<_>>(),
            HashSet::from([head, exit_node])
        );
        outline_cfg_check_parents(
            &mut SiblingMut::<'_, CfgID>::try_new(&mut h, cfg).unwrap(),
            cfg,
            vec![head, tail],
        );
        h.validate().unwrap();
    }

    /// Outlining a set that contains the CFG's entry block: the new block must become
    /// the CFG's entry, i.e. its first child.
    #[rstest]
    fn test_outline_cfg_move_entry(cond_then_loop_cfg: CondThenLoopCfg) {
        let (entry, _) = cond_then_loop_cfg.entry_exit();
        let CondThenLoopCfg {
            mut h,
            left,
            right,
            merge,
            head,
            ..
        } = cond_then_loop_cfg;

        let root = h.root();
        let (new_block, _, _) =
            outline_cfg_check_parents(&mut h, root, vec![entry, left, right, merge]);
        h.validate().unwrap();
        assert_eq!(new_block, h.children(h.root()).next().unwrap());
        assert_eq!(h.output_neighbours(new_block).collect_vec(), [head]);
    }

    /// Applies `OutlineCfg` to `blocks` (which must all be children of `cfg`) and checks
    /// the resulting hierarchy: other children stay in `cfg`, the new block is a dataflow
    /// block in `cfg`, and `blocks` now live in the new CFG inside it.
    ///
    /// Returns `(new_block, new_cfg, exit_block)`, where `exit_block` is the second
    /// child (the exit) of `new_cfg`.
    fn outline_cfg_check_parents(
        h: &mut impl HugrMut,
        cfg: Node,
        blocks: Vec<Node>,
    ) -> (Node, Node, Node) {
        // Every block must be a child of `cfg`; whatever remains is left untouched.
        let mut other_blocks = h.children(cfg).collect::<HashSet<_>>();
        assert!(blocks.iter().all(|b| other_blocks.remove(b)));
        let (new_block, new_cfg) = h.apply_rewrite(OutlineCfg::new(blocks.clone())).unwrap();

        for n in other_blocks {
            assert_eq!(h.get_parent(n), Some(cfg))
        }
        assert_eq!(h.get_parent(new_block), Some(cfg));
        assert!(h.get_optype(new_block).is_dataflow_block());
        // `h` may be a `SiblingMut` view (see `test_outline_cfg_subregion`), so inspect
        // the nested nodes through the underlying Hugr.
        let b = h.base_hugr();
        assert_eq!(b.get_parent(new_cfg), Some(new_block));
        for n in blocks {
            assert_eq!(b.get_parent(n), Some(new_cfg));
        }
        assert!(b.get_optype(new_cfg).is_cfg());
        let exit_block = b.children(new_cfg).nth(1).unwrap();
        assert!(b.get_optype(exit_block).is_exit_block());
        (new_block, new_cfg, exit_block)
    }
}