hugr_core/hugr/patch/
outline_cfg.rs

1//! Rewrite for inserting a CFG-node into the hierarchy containing a subsection
2//! of an existing CFG
3use std::collections::HashSet;
4
5use itertools::Itertools;
6use thiserror::Error;
7
8use crate::PortIndex;
9use crate::builder::{BlockBuilder, Container, Dataflow, SubContainer};
10use crate::hugr::{HugrMut, HugrView};
11use crate::ops;
12use crate::ops::controlflow::BasicBlock;
13use crate::ops::handle::NodeHandle;
14use crate::ops::{DataflowBlock, OpType};
15use crate::{Node, type_row};
16
17use super::{PatchHugrMut, PatchVerification};
18
19/// Moves some of the blocks in a Control-flow region into a new CFG-node that
20/// is the only child of a new Basic Block in the original region.
21pub struct OutlineCfg {
22    blocks: HashSet<Node>,
23}
24
25impl OutlineCfg {
26    /// Create a new `OutlineCfg` rewrite that will move the provided blocks.
27    pub fn new(blocks: impl IntoIterator<Item = Node>) -> Self {
28        Self {
29            blocks: HashSet::from_iter(blocks),
30        }
31    }
32
33    /// Compute the entry and exit nodes of the CFG which contains
34    /// [`self.blocks`], along with the output neighbour its parent graph.
35    fn compute_entry_exit(
36        &self,
37        h: &impl HugrView<Node = Node>,
38    ) -> Result<(Node, Node, Node), OutlineCfgError> {
39        let cfg_n = match self
40            .blocks
41            .iter()
42            .map(|n| h.get_parent(*n))
43            .unique()
44            .exactly_one()
45        {
46            Ok(Some(n)) => n,
47            _ => return Err(OutlineCfgError::NotSiblings),
48        };
49        let o = h.get_optype(cfg_n);
50        let OpType::CFG(_) = o else {
51            return Err(OutlineCfgError::ParentNotCfg(cfg_n, o.clone()));
52        };
53        let cfg_entry = h.children(cfg_n).next().unwrap();
54        let mut entry = None;
55        let mut exit_succ = None;
56        for &n in &self.blocks {
57            if n == cfg_entry
58                || h.input_neighbours(n)
59                    .any(|pred| !self.blocks.contains(&pred))
60            {
61                match entry {
62                    None => {
63                        entry = Some(n);
64                    }
65                    Some(prev) => {
66                        return Err(OutlineCfgError::MultipleEntryNodes(prev, n));
67                    }
68                }
69            }
70            let external_succs = h.output_neighbours(n).filter(|s| !self.blocks.contains(s));
71            match external_succs.at_most_one() {
72                Ok(None) => (), // No external successors
73                Ok(Some(o)) => match exit_succ {
74                    None => {
75                        exit_succ = Some((n, o));
76                    }
77                    Some((prev, _)) => {
78                        return Err(OutlineCfgError::MultipleExitNodes(prev, n));
79                    }
80                },
81                Err(ext) => return Err(OutlineCfgError::MultipleExitEdges(n, ext.collect())),
82            }
83        }
84        match (entry, exit_succ) {
85            (Some(e), Some((x, o))) => Ok((e, x, o)),
86            (None, _) => Err(OutlineCfgError::NoEntryNode),
87            (_, None) => Err(OutlineCfgError::NoExitNode),
88        }
89    }
90}
91
92impl PatchVerification for OutlineCfg {
93    type Error = OutlineCfgError;
94    type Node = Node;
95    fn verify(&self, h: &impl HugrView<Node = Node>) -> Result<(), OutlineCfgError> {
96        self.compute_entry_exit(h)?;
97        Ok(())
98    }
99
100    fn invalidation_set(&self) -> impl Iterator<Item = Node> {
101        self.blocks.iter().copied()
102    }
103}
104
impl PatchHugrMut for OutlineCfg {
    /// The newly-created basic block, and the [CFG] node inside it
    ///
    /// [CFG]: OpType::CFG
    type Outcome = [Node; 2];

    // The only `Err` returned below comes from `compute_entry_exit`, which
    // takes the Hugr by shared reference and runs before any mutation.
    const UNCHANGED_ON_FAILURE: bool = true;
    /// Applies the rewrite to `h`, returning `[new_block, cfg_node]`.
    ///
    /// A new basic block containing a new CFG node is inserted as a child of
    /// the CFG that holds the blocks. Edges into the entry block from outside
    /// the set are redirected to the new block, the blocks are reparented
    /// under the new CFG node, and the single edge that left the set is
    /// replaced by two edges: one from the exit block to the exit node of the
    /// new CFG, and one from the new block to the former external successor.
    ///
    /// # Errors
    ///
    /// Returns whatever error `compute_entry_exit` reports; in that case `h`
    /// has not been modified.
    ///
    /// # Panics
    ///
    /// Contains several `unwrap`s, a `panic!` for unexpected operation types
    /// and an `assert!` on port indices. According to the comment in step 1,
    /// the operation-type panics only happen for a Hugr that would not pass
    /// validation.
    fn apply_hugr_mut(
        self,
        h: &mut impl HugrMut<Node = Node>,
    ) -> Result<[Node; 2], OutlineCfgError> {
        let (entry, exit, outside) = self.compute_entry_exit(h)?;
        // 1. Compute signature
        // These panic()s only happen if the Hugr would not have passed validate()
        // The new block takes the same inputs as the entry block...
        let OpType::DataflowBlock(DataflowBlock { inputs, .. }) = h.get_optype(entry) else {
            panic!("Entry node is not a basic block")
        };
        let inputs = inputs.clone();
        // ...and produces whatever the external successor takes as input.
        let outputs = match h.get_optype(outside) {
            OpType::DataflowBlock(dfb) => dfb.dataflow_input().clone(),
            OpType::ExitBlock(exit) => exit.dataflow_input().clone(),
            _ => panic!("External successor not a basic block"),
        };
        let outer_cfg = h.get_parent(entry).unwrap();
        // Remember the current first child of the outer CFG, so that step 3 can
        // tell whether the block being outlined is the outer CFG's entry.
        let outer_entry = h.children(outer_cfg).next().unwrap();

        // 2. new_block contains input node, sub-cfg, exit node all connected
        let (new_block, cfg_node) = {
            // Built as a standalone Hugr first, then inserted under `outer_cfg`.
            // NOTE(review): the single empty row presumably gives the block one
            // successor carrying no extra values - confirm against `BlockBuilder`.
            let mut new_block_bldr =
                BlockBuilder::new(inputs.clone(), vec![type_row![]], outputs.clone()).unwrap();
            let wires_in = inputs.iter().cloned().zip(new_block_bldr.input_wires());
            let cfg = new_block_bldr.cfg_builder(wires_in, outputs).unwrap();
            // No blocks are added here: the existing blocks are moved in by step 5.
            let cfg = cfg.finish_sub_container().unwrap();
            let unit_sum = new_block_bldr.add_constant(ops::Value::unary_unit_sum());
            let pred_wire = new_block_bldr.load_const(&unit_sum);
            new_block_bldr
                .set_outputs(pred_wire, cfg.outputs())
                .unwrap();
            // Take the finished Hugr out of the builder, leaving a default in its place.
            let new_block_hugr = std::mem::take(new_block_bldr.hugr_mut());
            let ins_res = h.insert_hugr(outer_cfg, new_block_hugr);
            (
                ins_res.inserted_entrypoint,
                // Translate the CFG node from the standalone Hugr to its copy in `h`.
                *ins_res.node_map.get(&cfg.node()).unwrap(),
            )
        };

        // 3. Entry edges. Change any edges into entry_block from outside, to target new_block
        // (`entry` is expected to have exactly one input port; collect first so
        // that `h` is not borrowed while it is modified below)
        let preds: Vec<_> = h
            .linked_outputs(entry, h.node_inputs(entry).exactly_one().ok().unwrap())
            .collect();
        for (pred, br) in preds {
            // Edges from blocks that are themselves being moved stay on `entry`.
            if !self.blocks.contains(&pred) {
                h.disconnect(pred, br);
                h.connect(pred, br, new_block, 0);
            }
        }
        if entry == outer_entry {
            // new_block must be the entry node, i.e. first child, of the enclosing CFG
            // (the current entry node will be reparented inside new_block below)
            h.move_before_sibling(new_block, outer_entry);
        }

        // 4(a). Exit edges.
        // Remove edge from exit_node (that used to target outside)
        let exit_port = h
            .node_outputs(exit)
            .filter(|p| {
                // Every output port of `exit` must link to exactly one input,
                // and that input must be port 0 of its node.
                let (t, p2) = h.single_linked_input(exit, *p).unwrap();
                assert!(p2.index() == 0);
                t == outside
            })
            .exactly_one()
            .ok() // NodePorts does not implement Debug
            .unwrap();
        h.disconnect(exit, exit_port);
        // And connect new_block to outside instead
        h.connect(new_block, 0, outside, 0);

        // 5. Children of new CFG.
        let inner_exit = {
            // The new CFG was built without blocks, so its only child is its exit node.
            let inner_exit = h.children(cfg_node).exactly_one().ok().unwrap();

            // Entry node must be first
            h.move_before_sibling(entry, inner_exit);
            // And remaining nodes
            for n in self.blocks {
                // Do not move the entry node, as we have already
                if n != entry {
                    h.set_parent(n, cfg_node);
                }
            }
            inner_exit
        };

        // 4(b). Reconnect exit edge to the new exit node within the inner CFG
        // (reusing the port that was disconnected in 4(a))
        h.connect(exit, exit_port, inner_exit, 0);

        Ok([new_block, cfg_node])
    }
}
205
/// Errors that can occur in expressing an `OutlineCfg` rewrite.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum OutlineCfgError {
    /// The blocks did not all share a single parent node. This is also
    /// reported when the set of blocks is empty.
    #[error("The nodes did not all have the same parent")]
    NotSiblings,
    /// The common parent of the blocks was not a CFG node. Carries the parent
    /// node and its actual operation.
    #[error("The parent node {0} was not a CFG but a {1}")]
    ParentNotCfg(Node, OpType),
    /// More than one block would have to act as the entry: each is either the
    /// entry (first child) of the containing CFG or has a predecessor outside
    /// the set. Carries two such blocks.
    #[error("Multiple blocks had predecessors outside the set - at least {0} and {1}")]
    MultipleEntryNodes(Node, Node),
    /// More than one block had an edge leaving the set. Carries two such
    /// blocks.
    // Note possible TODO: straightforward if all outgoing edges target the same BB
    #[error("Multiple blocks had edges leaving the set - at least {0} and {1}")]
    MultipleExitNodes(Node, Node),
    /// A single block had more than one edge leaving the set. Carries that
    /// block and the external successors found.
    #[error("Exit block {0} had edges to multiple external blocks {1:?}")]
    MultipleExitEdges(Node, Vec<Node>),
    /// No block was identified as an entry block: none was the entry of the
    /// containing CFG and none had a predecessor outside the set.
    #[error("No block had predecessors outside the set")]
    NoEntryNode,
    /// No block was found with an edge leaving the set (so, must be an infinite loop)
    #[error("No block had a successor outside the set")]
    NoExitNode,
}
233
234#[cfg(test)]
235mod test {
236    use std::collections::HashSet;
237
238    use crate::builder::{
239        BlockBuilder, BuildError, CFGBuilder, Container, Dataflow, DataflowSubContainer,
240        HugrBuilder, ModuleBuilder,
241    };
242    use crate::extension::prelude::usize_t;
243    use crate::hugr::HugrMut;
244    use crate::ops::constant::Value;
245    use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle};
246    use crate::types::Signature;
247    use crate::{Hugr, HugrView, Node};
248    use cool_asserts::assert_matches;
249    use itertools::Itertools;
250    use rstest::rstest;
251
252    use super::{OutlineCfg, OutlineCfgError};
253
    /// Test fixture: a CFG made of a conditional followed by a loop.
    ///
    /// ```text
    ///      /-> left --\
    /// entry            > merge -> head -> tail -> exit
    ///      \-> right -/             \-<--<-/
    /// ```
    struct CondThenLoopCfg {
        // The Hugr holding the CFG; the CFG node is its entrypoint.
        h: Hugr,
        // First branch target of the entry block.
        left: Node,
        // Second branch target of the entry block.
        right: Node,
        // Block where `left` and `right` rejoin.
        merge: Node,
        // Loop head: successor of `merge` and target of the back-edge from `tail`.
        head: Node,
        // Loop tail: branches either to the exit block or back to `head`.
        tail: Node,
    }
265    impl CondThenLoopCfg {
266        fn new() -> Result<CondThenLoopCfg, BuildError> {
267            let block_ty = Signature::new_endo(usize_t());
268            let mut cfg_builder = CFGBuilder::new(block_ty.clone())?;
269            let pred_const = cfg_builder.add_constant(Value::unit_sum(0, 2).expect("0 < 2"));
270            let const_unit = cfg_builder.add_constant(Value::unary_unit_sum());
271            fn n_identity(
272                mut bbldr: BlockBuilder<&mut Hugr>,
273                cst: &ConstID,
274            ) -> Result<BasicBlockID, BuildError> {
275                let pred = bbldr.load_const(cst);
276                let vals = bbldr.input_wires();
277                bbldr.finish_with_outputs(pred, vals)
278            }
279            let id_block = |c: &mut CFGBuilder<_>| {
280                n_identity(c.simple_block_builder(block_ty.clone(), 1)?, &const_unit)
281            };
282
283            let entry = n_identity(
284                cfg_builder.simple_entry_builder(usize_t().into(), 2)?,
285                &pred_const,
286            )?;
287
288            let left = id_block(&mut cfg_builder)?;
289            let right = id_block(&mut cfg_builder)?;
290            cfg_builder.branch(&entry, 0, &left)?;
291            cfg_builder.branch(&entry, 1, &right)?;
292
293            let merge = id_block(&mut cfg_builder)?;
294            cfg_builder.branch(&left, 0, &merge)?;
295            cfg_builder.branch(&right, 0, &merge)?;
296
297            let head = id_block(&mut cfg_builder)?;
298            cfg_builder.branch(&merge, 0, &head)?;
299            let tail = n_identity(
300                cfg_builder.simple_block_builder(Signature::new_endo(usize_t()), 2)?,
301                &pred_const,
302            )?;
303            cfg_builder.branch(&tail, 1, &head)?;
304            cfg_builder.branch(&head, 0, &tail)?; // trivial "loop body"
305            let exit = cfg_builder.exit_block();
306            cfg_builder.branch(&tail, 0, &exit)?;
307
308            let h = cfg_builder.finish_hugr()?;
309            let (left, right) = (left.node(), right.node());
310            let (merge, head, tail) = (merge.node(), head.node(), tail.node());
311            Ok(Self {
312                h,
313                left,
314                right,
315                merge,
316                head,
317                tail,
318            })
319        }
320        fn entry_exit(&self) -> (Node, Node) {
321            self.h
322                .children(self.h.entrypoint())
323                .take(2)
324                .collect_tuple()
325                .unwrap()
326        }
327    }
328
    /// `rstest` fixture supplying a freshly built [`CondThenLoopCfg`] to each
    /// test; panics if the CFG cannot be built.
    #[rstest::fixture]
    fn cond_then_loop_cfg() -> CondThenLoopCfg {
        CondThenLoopCfg::new().unwrap()
    }
333
334    #[rstest]
335    fn test_outline_cfg_errors(cond_then_loop_cfg: CondThenLoopCfg) {
336        let (entry, _) = cond_then_loop_cfg.entry_exit();
337        let CondThenLoopCfg {
338            mut h,
339            left,
340            right,
341            merge,
342            head,
343            tail,
344        } = cond_then_loop_cfg;
345        let backup = h.clone();
346
347        let r = h.apply_patch(OutlineCfg::new([tail]));
348        assert_matches!(r, Err(OutlineCfgError::MultipleExitEdges(_, _)));
349        assert_eq!(h, backup);
350
351        let r = h.apply_patch(OutlineCfg::new([entry, left, right]));
352        assert_matches!(r, Err(OutlineCfgError::MultipleExitNodes(a,b))
353            => assert_eq!(HashSet::from([a,b]), HashSet::from_iter([left, right])));
354        assert_eq!(h, backup);
355
356        let r = h.apply_patch(OutlineCfg::new([left, right, merge]));
357        assert_matches!(r, Err(OutlineCfgError::MultipleEntryNodes(a,b))
358            => assert_eq!(HashSet::from([a,b]), HashSet::from([left, right])));
359        assert_eq!(h, backup);
360
361        // The entry node implicitly has an extra incoming edge
362        let r = h.apply_patch(OutlineCfg::new([entry, left, right, merge, head]));
363        assert_matches!(r, Err(OutlineCfgError::MultipleEntryNodes(a,b))
364            => assert_eq!(HashSet::from([a,b]), HashSet::from([entry, head])));
365        assert_eq!(h, backup);
366    }
367
368    #[rstest::rstest]
369    fn test_outline_cfg(cond_then_loop_cfg: CondThenLoopCfg) {
370        // Outline the loop, producing:
371        //     /-> left -->\
372        // entry            merge -> newblock -> exit
373        //     \-> right ->/
374        let (_, exit) = cond_then_loop_cfg.entry_exit();
375        let CondThenLoopCfg {
376            mut h,
377            merge,
378            head,
379            tail,
380            ..
381        } = cond_then_loop_cfg;
382        let root = h.entrypoint();
383        let (new_block, _, exit_block) = outline_cfg_check_parents(&mut h, root, vec![head, tail]);
384        assert_eq!(h.output_neighbours(merge).collect_vec(), vec![new_block]);
385        assert_eq!(h.input_neighbours(exit).collect_vec(), vec![new_block]);
386        assert_eq!(
387            h.output_neighbours(tail).collect::<HashSet<Node>>(),
388            HashSet::from([head, exit_block])
389        );
390    }
391
392    #[rstest]
393    fn test_outline_cfg_multiple_in_edges(cond_then_loop_cfg: CondThenLoopCfg) {
394        // Outline merge, head and tail, producing
395        //     /-> left -->\
396        // entry            newblock -> exit
397        //     \-> right ->/
398        let (_, exit) = cond_then_loop_cfg.entry_exit();
399        let CondThenLoopCfg {
400            mut h,
401            left,
402            right,
403            merge,
404            head,
405            tail,
406        } = cond_then_loop_cfg;
407
408        let root = h.entrypoint();
409        let (new_block, _, inner_exit) =
410            outline_cfg_check_parents(&mut h, root, vec![merge, head, tail]);
411        assert_eq!(h.input_neighbours(exit).collect_vec(), vec![new_block]);
412        assert_eq!(
413            h.input_neighbours(new_block).collect::<HashSet<_>>(),
414            HashSet::from([left, right])
415        );
416        assert_eq!(
417            h.output_neighbours(tail).collect::<HashSet<Node>>(),
418            HashSet::from([head, inner_exit])
419        );
420    }
421
422    #[rstest]
423    fn test_outline_cfg_subregion(cond_then_loop_cfg: CondThenLoopCfg) {
424        // Outline the loop, as above, but with the CFG inside a Function + Module,
425        // operating via a SiblingMut
426        let mut module_builder = ModuleBuilder::new();
427        let mut fbuild = module_builder
428            .define_function("main", Signature::new(vec![usize_t()], vec![usize_t()]))
429            .unwrap();
430        let [i1] = fbuild.input_wires_arr();
431        let cfg = fbuild
432            .add_hugr_with_wires(cond_then_loop_cfg.h, [i1])
433            .unwrap();
434        fbuild.finish_with_outputs(cfg.outputs()).unwrap();
435        let mut h = module_builder.finish_hugr().unwrap();
436        // `add_hugr_with_wires` does not return an InsertionResult, so recover the nodes manually:
437        let cfg = cfg.node();
438        let exit_node = h.children(cfg).nth(1).unwrap();
439        let tail = h.input_neighbours(exit_node).exactly_one().ok().unwrap();
440        let head = h.input_neighbours(tail).exactly_one().ok().unwrap();
441        // Just sanity-check we have the correct nodes
442        assert!(h.get_optype(exit_node).is_exit_block());
443        assert_eq!(
444            h.output_neighbours(tail).collect::<HashSet<_>>(),
445            HashSet::from([head, exit_node])
446        );
447        outline_cfg_check_parents(&mut h, cfg, vec![head, tail]);
448        h.validate().unwrap();
449    }
450
451    #[rstest]
452    fn test_outline_cfg_move_entry(cond_then_loop_cfg: CondThenLoopCfg) {
453        // Outline the conditional, producing
454        //
455        //  newblock -> head -> tail -> exit
456        //                 \<--</
457        // (where the new block becomes the entry block)
458        let (entry, _) = cond_then_loop_cfg.entry_exit();
459        let CondThenLoopCfg {
460            mut h,
461            left,
462            right,
463            merge,
464            head,
465            ..
466        } = cond_then_loop_cfg;
467
468        let root = h.entrypoint();
469        let (new_block, _, _) =
470            outline_cfg_check_parents(&mut h, root, vec![entry, left, right, merge]);
471        h.validate().unwrap_or_else(|e| panic!("{e}"));
472        assert_eq!(new_block, h.children(h.entrypoint()).next().unwrap());
473        assert_eq!(h.output_neighbours(new_block).collect_vec(), [head]);
474    }
475
476    fn outline_cfg_check_parents(
477        h: &mut impl HugrMut<Node = Node>,
478        cfg: Node,
479        blocks: Vec<Node>,
480    ) -> (Node, Node, Node) {
481        let mut other_blocks = h.children(cfg).collect::<HashSet<_>>();
482        assert!(blocks.iter().all(|b| other_blocks.remove(b)));
483        let [new_block, new_cfg] = h.apply_patch(OutlineCfg::new(blocks.clone())).unwrap();
484
485        for n in other_blocks {
486            assert_eq!(h.get_parent(n), Some(cfg));
487        }
488        assert_eq!(h.get_parent(new_block), Some(cfg));
489        assert!(h.get_optype(new_block).is_dataflow_block());
490        assert_eq!(h.get_parent(new_cfg), Some(new_block));
491        for n in blocks {
492            assert_eq!(h.get_parent(n), Some(new_cfg));
493        }
494        assert!(h.get_optype(new_cfg).is_cfg());
495        let exit_block = h.children(new_cfg).nth(1).unwrap();
496        assert!(h.get_optype(exit_block).is_exit_block());
497        (new_block, new_cfg, exit_block)
498    }
499}