hugr_core/hugr/patch/
outline_cfg.rs

1//! Rewrite for inserting a CFG-node into the hierarchy containing a subsection
2//! of an existing CFG
3use std::collections::HashSet;
4
5use itertools::Itertools;
6use thiserror::Error;
7
8use crate::PortIndex;
9use crate::builder::{BlockBuilder, Container, Dataflow, SubContainer};
10use crate::hugr::{HugrMut, HugrView};
11use crate::ops;
12use crate::ops::controlflow::BasicBlock;
13use crate::ops::handle::NodeHandle;
14use crate::ops::{DataflowBlock, OpType};
15use crate::{Node, type_row};
16
17use super::{PatchHugrMut, PatchVerification};
18
19/// Moves some of the blocks in a Control-flow region into a new CFG-node that
20/// is the only child of a new Basic Block in the original region.
21pub struct OutlineCfg {
22    blocks: HashSet<Node>,
23}
24
25impl OutlineCfg {
26    /// Create a new `OutlineCfg` rewrite that will move the provided blocks.
27    pub fn new(blocks: impl IntoIterator<Item = Node>) -> Self {
28        Self {
29            blocks: HashSet::from_iter(blocks),
30        }
31    }
32
33    /// Compute the entry and exit nodes of the CFG which contains
34    /// [`self.blocks`], along with the output neighbour its parent graph.
35    fn compute_entry_exit(
36        &self,
37        h: &impl HugrView<Node = Node>,
38    ) -> Result<(Node, Node, Node), OutlineCfgError> {
39        let cfg_n = match self
40            .blocks
41            .iter()
42            .map(|n| h.get_parent(*n))
43            .unique()
44            .exactly_one()
45        {
46            Ok(Some(n)) => n,
47            _ => return Err(OutlineCfgError::NotSiblings),
48        };
49        let o = h.get_optype(cfg_n);
50        let OpType::CFG(_) = o else {
51            return Err(OutlineCfgError::ParentNotCfg(cfg_n, Box::new(o.clone())));
52        };
53        let cfg_entry = h.children(cfg_n).next().unwrap();
54        let mut entry = None;
55        let mut exit_succ = None;
56        for &n in &self.blocks {
57            if n == cfg_entry
58                || h.input_neighbours(n)
59                    .any(|pred| !self.blocks.contains(&pred))
60            {
61                match entry {
62                    None => {
63                        entry = Some(n);
64                    }
65                    Some(prev) => {
66                        return Err(OutlineCfgError::MultipleEntryNodes(prev, n));
67                    }
68                }
69            }
70            let external_succs = h.output_neighbours(n).filter(|s| !self.blocks.contains(s));
71            match external_succs.at_most_one() {
72                Ok(None) => (), // No external successors
73                Ok(Some(o)) => match exit_succ {
74                    None => {
75                        exit_succ = Some((n, o));
76                    }
77                    Some((prev, _)) => {
78                        return Err(OutlineCfgError::MultipleExitNodes(prev, n));
79                    }
80                },
81                Err(ext) => return Err(OutlineCfgError::MultipleExitEdges(n, ext.collect())),
82            }
83        }
84        match (entry, exit_succ) {
85            (Some(e), Some((x, o))) => Ok((e, x, o)),
86            (None, _) => Err(OutlineCfgError::NoEntryNode),
87            (_, None) => Err(OutlineCfgError::NoExitNode),
88        }
89    }
90}
91
92impl PatchVerification for OutlineCfg {
93    type Error = OutlineCfgError;
94    type Node = Node;
95    fn verify(&self, h: &impl HugrView<Node = Node>) -> Result<(), OutlineCfgError> {
96        self.compute_entry_exit(h)?;
97        Ok(())
98    }
99
100    fn invalidated_nodes(
101        &self,
102        _: &impl HugrView<Node = Self::Node>,
103    ) -> impl Iterator<Item = Self::Node> {
104        self.blocks.iter().copied()
105    }
106}
107
108impl PatchHugrMut for OutlineCfg {
109    /// The newly-created basic block, and the [CFG] node inside it
110    ///
111    /// [CFG]: OpType::CFG
112    type Outcome = [Node; 2];
113
114    const UNCHANGED_ON_FAILURE: bool = true;
115    fn apply_hugr_mut(
116        self,
117        h: &mut impl HugrMut<Node = Node>,
118    ) -> Result<[Node; 2], OutlineCfgError> {
119        let (entry, exit, outside) = self.compute_entry_exit(h)?;
120        // 1. Compute signature
121        // These panic()s only happen if the Hugr would not have passed validate()
122        let OpType::DataflowBlock(DataflowBlock { inputs, .. }) = h.get_optype(entry) else {
123            panic!("Entry node is not a basic block")
124        };
125        let inputs = inputs.clone();
126        let outputs = match h.get_optype(outside) {
127            OpType::DataflowBlock(dfb) => dfb.dataflow_input().clone(),
128            OpType::ExitBlock(exit) => exit.dataflow_input().clone(),
129            _ => panic!("External successor not a basic block"),
130        };
131        let outer_cfg = h.get_parent(entry).unwrap();
132        let outer_entry = h.children(outer_cfg).next().unwrap();
133
134        // 2. new_block contains input node, sub-cfg, exit node all connected
135        let (new_block, cfg_node) = {
136            let mut new_block_bldr =
137                BlockBuilder::new(inputs.clone(), vec![type_row![]], outputs.clone()).unwrap();
138            let wires_in = inputs.iter().cloned().zip(new_block_bldr.input_wires());
139            let cfg = new_block_bldr.cfg_builder(wires_in, outputs).unwrap();
140            let cfg = cfg.finish_sub_container().unwrap();
141            let unit_sum = new_block_bldr.add_constant(ops::Value::unary_unit_sum());
142            let pred_wire = new_block_bldr.load_const(&unit_sum);
143            new_block_bldr
144                .set_outputs(pred_wire, cfg.outputs())
145                .unwrap();
146            let new_block_hugr = std::mem::take(new_block_bldr.hugr_mut());
147            let ins_res = h.insert_hugr(outer_cfg, new_block_hugr);
148            (
149                ins_res.inserted_entrypoint,
150                *ins_res.node_map.get(&cfg.node()).unwrap(),
151            )
152        };
153
154        // 3. Entry edges. Change any edges into entry_block from outside, to target new_block
155        let preds: Vec<_> = h
156            .linked_outputs(entry, h.node_inputs(entry).exactly_one().ok().unwrap())
157            .collect();
158        for (pred, br) in preds {
159            if !self.blocks.contains(&pred) {
160                h.disconnect(pred, br);
161                h.connect(pred, br, new_block, 0);
162            }
163        }
164        if entry == outer_entry {
165            // new_block must be the entry node, i.e. first child, of the enclosing CFG
166            // (the current entry node will be reparented inside new_block below)
167            h.move_before_sibling(new_block, outer_entry);
168        }
169
170        // 4(a). Exit edges.
171        // Remove edge from exit_node (that used to target outside)
172        let exit_port = h
173            .node_outputs(exit)
174            .filter(|p| {
175                let (t, p2) = h.single_linked_input(exit, *p).unwrap();
176                assert!(p2.index() == 0);
177                t == outside
178            })
179            .exactly_one()
180            .ok() // NodePorts does not implement Debug
181            .unwrap();
182        h.disconnect(exit, exit_port);
183        // And connect new_block to outside instead
184        h.connect(new_block, 0, outside, 0);
185
186        // 5. Children of new CFG.
187        let inner_exit = {
188            let inner_exit = h.children(cfg_node).exactly_one().ok().unwrap();
189
190            // Entry node must be first
191            h.move_before_sibling(entry, inner_exit);
192            // And remaining nodes
193            for n in self.blocks {
194                // Do not move the entry node, as we have already
195                if n != entry {
196                    h.set_parent(n, cfg_node);
197                }
198            }
199            inner_exit
200        };
201
202        // 4(b). Reconnect exit edge to the new exit node within the inner CFG
203        h.connect(exit, exit_port, inner_exit, 0);
204
205        Ok([new_block, cfg_node])
206    }
207}
208
209/// Errors that can occur in expressing an `OutlineCfg` rewrite.
210#[derive(Debug, Error)]
211#[non_exhaustive]
212pub enum OutlineCfgError {
213    /// The set of blocks were not siblings
214    #[error("The nodes did not all have the same parent")]
215    NotSiblings,
216    /// The parent node was not a CFG node
217    #[error("The parent node {0} was not a CFG but a {1}")]
218    ParentNotCfg(Node, Box<OpType>),
219    /// Multiple blocks had incoming edges
220    #[error("Multiple blocks had predecessors outside the set - at least {0} and {1}")]
221    MultipleEntryNodes(Node, Node),
222    /// Multiple blocks had outgoing edges
223    // Note possible TODO: straightforward if all outgoing edges target the same BB
224    #[error("Multiple blocks had edges leaving the set - at least {0} and {1}")]
225    MultipleExitNodes(Node, Node),
226    /// One block had multiple outgoing edges
227    #[error("Exit block {0} had edges to multiple external blocks {1:?}")]
228    MultipleExitEdges(Node, Vec<Node>),
229    /// No block was identified as an entry block
230    #[error("No block had predecessors outside the set")]
231    NoEntryNode,
232    /// No block was found with an edge leaving the set (so, must be an infinite loop)
233    #[error("No block had a successor outside the set")]
234    NoExitNode,
235}
236
237#[cfg(test)]
238mod test {
239    use std::collections::HashSet;
240
241    use crate::builder::{
242        BlockBuilder, BuildError, CFGBuilder, Container, Dataflow, DataflowSubContainer,
243        HugrBuilder, ModuleBuilder,
244    };
245    use crate::extension::prelude::usize_t;
246    use crate::hugr::HugrMut;
247    use crate::ops::constant::Value;
248    use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle};
249    use crate::types::Signature;
250    use crate::{Hugr, HugrView, Node};
251    use cool_asserts::assert_matches;
252    use itertools::Itertools;
253    use rstest::rstest;
254
255    use super::{OutlineCfg, OutlineCfgError};
256
257    ///      /-> left --\
258    /// entry            > merge -> head -> tail -> exit
259    ///      \-> right -/             \-<--<-/
260    struct CondThenLoopCfg {
261        h: Hugr,
262        left: Node,
263        right: Node,
264        merge: Node,
265        head: Node,
266        tail: Node,
267    }
268    impl CondThenLoopCfg {
269        fn new() -> Result<CondThenLoopCfg, BuildError> {
270            let block_ty = Signature::new_endo(usize_t());
271            let mut cfg_builder = CFGBuilder::new(block_ty.clone())?;
272            let pred_const = cfg_builder.add_constant(Value::unit_sum(0, 2).expect("0 < 2"));
273            let const_unit = cfg_builder.add_constant(Value::unary_unit_sum());
274            fn n_identity(
275                mut bbldr: BlockBuilder<&mut Hugr>,
276                cst: &ConstID,
277            ) -> Result<BasicBlockID, BuildError> {
278                let pred = bbldr.load_const(cst);
279                let vals = bbldr.input_wires();
280                bbldr.finish_with_outputs(pred, vals)
281            }
282            let id_block = |c: &mut CFGBuilder<_>| {
283                n_identity(c.simple_block_builder(block_ty.clone(), 1)?, &const_unit)
284            };
285
286            let entry = n_identity(
287                cfg_builder.simple_entry_builder(usize_t().into(), 2)?,
288                &pred_const,
289            )?;
290
291            let left = id_block(&mut cfg_builder)?;
292            let right = id_block(&mut cfg_builder)?;
293            cfg_builder.branch(&entry, 0, &left)?;
294            cfg_builder.branch(&entry, 1, &right)?;
295
296            let merge = id_block(&mut cfg_builder)?;
297            cfg_builder.branch(&left, 0, &merge)?;
298            cfg_builder.branch(&right, 0, &merge)?;
299
300            let head = id_block(&mut cfg_builder)?;
301            cfg_builder.branch(&merge, 0, &head)?;
302            let tail = n_identity(
303                cfg_builder.simple_block_builder(Signature::new_endo(usize_t()), 2)?,
304                &pred_const,
305            )?;
306            cfg_builder.branch(&tail, 1, &head)?;
307            cfg_builder.branch(&head, 0, &tail)?; // trivial "loop body"
308            let exit = cfg_builder.exit_block();
309            cfg_builder.branch(&tail, 0, &exit)?;
310
311            let h = cfg_builder.finish_hugr()?;
312            let (left, right) = (left.node(), right.node());
313            let (merge, head, tail) = (merge.node(), head.node(), tail.node());
314            Ok(Self {
315                h,
316                left,
317                right,
318                merge,
319                head,
320                tail,
321            })
322        }
323        fn entry_exit(&self) -> (Node, Node) {
324            self.h
325                .children(self.h.entrypoint())
326                .take(2)
327                .collect_tuple()
328                .unwrap()
329        }
330    }
331
332    #[rstest::fixture]
333    fn cond_then_loop_cfg() -> CondThenLoopCfg {
334        CondThenLoopCfg::new().unwrap()
335    }
336
337    #[rstest]
338    fn test_outline_cfg_errors(cond_then_loop_cfg: CondThenLoopCfg) {
339        let (entry, _) = cond_then_loop_cfg.entry_exit();
340        let CondThenLoopCfg {
341            mut h,
342            left,
343            right,
344            merge,
345            head,
346            tail,
347        } = cond_then_loop_cfg;
348        let backup = h.clone();
349
350        let r = h.apply_patch(OutlineCfg::new([tail]));
351        assert_matches!(r, Err(OutlineCfgError::MultipleExitEdges(_, _)));
352        assert_eq!(h, backup);
353
354        let r = h.apply_patch(OutlineCfg::new([entry, left, right]));
355        assert_matches!(r, Err(OutlineCfgError::MultipleExitNodes(a,b))
356            => assert_eq!(HashSet::from([a,b]), HashSet::from_iter([left, right])));
357        assert_eq!(h, backup);
358
359        let r = h.apply_patch(OutlineCfg::new([left, right, merge]));
360        assert_matches!(r, Err(OutlineCfgError::MultipleEntryNodes(a,b))
361            => assert_eq!(HashSet::from([a,b]), HashSet::from([left, right])));
362        assert_eq!(h, backup);
363
364        // The entry node implicitly has an extra incoming edge
365        let r = h.apply_patch(OutlineCfg::new([entry, left, right, merge, head]));
366        assert_matches!(r, Err(OutlineCfgError::MultipleEntryNodes(a,b))
367            => assert_eq!(HashSet::from([a,b]), HashSet::from([entry, head])));
368        assert_eq!(h, backup);
369    }
370
371    #[rstest::rstest]
372    fn test_outline_cfg(cond_then_loop_cfg: CondThenLoopCfg) {
373        // Outline the loop, producing:
374        //     /-> left -->\
375        // entry            merge -> newblock -> exit
376        //     \-> right ->/
377        let (_, exit) = cond_then_loop_cfg.entry_exit();
378        let CondThenLoopCfg {
379            mut h,
380            merge,
381            head,
382            tail,
383            ..
384        } = cond_then_loop_cfg;
385        let root = h.entrypoint();
386        let (new_block, _, exit_block) = outline_cfg_check_parents(&mut h, root, vec![head, tail]);
387        assert_eq!(h.output_neighbours(merge).collect_vec(), vec![new_block]);
388        assert_eq!(h.input_neighbours(exit).collect_vec(), vec![new_block]);
389        assert_eq!(
390            h.output_neighbours(tail).collect::<HashSet<Node>>(),
391            HashSet::from([head, exit_block])
392        );
393    }
394
395    #[rstest]
396    fn test_outline_cfg_multiple_in_edges(cond_then_loop_cfg: CondThenLoopCfg) {
397        // Outline merge, head and tail, producing
398        //     /-> left -->\
399        // entry            newblock -> exit
400        //     \-> right ->/
401        let (_, exit) = cond_then_loop_cfg.entry_exit();
402        let CondThenLoopCfg {
403            mut h,
404            left,
405            right,
406            merge,
407            head,
408            tail,
409        } = cond_then_loop_cfg;
410
411        let root = h.entrypoint();
412        let (new_block, _, inner_exit) =
413            outline_cfg_check_parents(&mut h, root, vec![merge, head, tail]);
414        assert_eq!(h.input_neighbours(exit).collect_vec(), vec![new_block]);
415        assert_eq!(
416            h.input_neighbours(new_block).collect::<HashSet<_>>(),
417            HashSet::from([left, right])
418        );
419        assert_eq!(
420            h.output_neighbours(tail).collect::<HashSet<Node>>(),
421            HashSet::from([head, inner_exit])
422        );
423    }
424
425    #[rstest]
426    fn test_outline_cfg_subregion(cond_then_loop_cfg: CondThenLoopCfg) {
427        // Outline the loop, as above, but with the CFG inside a Function + Module,
428        // operating via a SiblingMut
429        let mut module_builder = ModuleBuilder::new();
430        let mut fbuild = module_builder
431            .define_function("main", Signature::new(vec![usize_t()], vec![usize_t()]))
432            .unwrap();
433        let [i1] = fbuild.input_wires_arr();
434        let cfg = fbuild
435            .add_hugr_with_wires(cond_then_loop_cfg.h, [i1])
436            .unwrap();
437        fbuild.finish_with_outputs(cfg.outputs()).unwrap();
438        let mut h = module_builder.finish_hugr().unwrap();
439        // `add_hugr_with_wires` does not return an InsertionResult, so recover the nodes manually:
440        let cfg = cfg.node();
441        let exit_node = h.children(cfg).nth(1).unwrap();
442        let tail = h.input_neighbours(exit_node).exactly_one().ok().unwrap();
443        let head = h.input_neighbours(tail).exactly_one().ok().unwrap();
444        // Just sanity-check we have the correct nodes
445        assert!(h.get_optype(exit_node).is_exit_block());
446        assert_eq!(
447            h.output_neighbours(tail).collect::<HashSet<_>>(),
448            HashSet::from([head, exit_node])
449        );
450        outline_cfg_check_parents(&mut h, cfg, vec![head, tail]);
451        h.validate().unwrap();
452    }
453
454    #[rstest]
455    fn test_outline_cfg_move_entry(cond_then_loop_cfg: CondThenLoopCfg) {
456        // Outline the conditional, producing
457        //
458        //  newblock -> head -> tail -> exit
459        //                 \<--</
460        // (where the new block becomes the entry block)
461        let (entry, _) = cond_then_loop_cfg.entry_exit();
462        let CondThenLoopCfg {
463            mut h,
464            left,
465            right,
466            merge,
467            head,
468            ..
469        } = cond_then_loop_cfg;
470
471        let root = h.entrypoint();
472        let (new_block, _, _) =
473            outline_cfg_check_parents(&mut h, root, vec![entry, left, right, merge]);
474        h.validate().unwrap_or_else(|e| panic!("{e}"));
475        assert_eq!(new_block, h.children(h.entrypoint()).next().unwrap());
476        assert_eq!(h.output_neighbours(new_block).collect_vec(), [head]);
477    }
478
479    fn outline_cfg_check_parents(
480        h: &mut impl HugrMut<Node = Node>,
481        cfg: Node,
482        blocks: Vec<Node>,
483    ) -> (Node, Node, Node) {
484        let mut other_blocks = h.children(cfg).collect::<HashSet<_>>();
485        assert!(blocks.iter().all(|b| other_blocks.remove(b)));
486        let [new_block, new_cfg] = h.apply_patch(OutlineCfg::new(blocks.clone())).unwrap();
487
488        for n in other_blocks {
489            assert_eq!(h.get_parent(n), Some(cfg));
490        }
491        assert_eq!(h.get_parent(new_block), Some(cfg));
492        assert!(h.get_optype(new_block).is_dataflow_block());
493        assert_eq!(h.get_parent(new_cfg), Some(new_block));
494        for n in blocks {
495            assert_eq!(h.get_parent(n), Some(new_cfg));
496        }
497        assert!(h.get_optype(new_cfg).is_cfg());
498        let exit_block = h.children(new_cfg).nth(1).unwrap();
499        assert!(h.get_optype(exit_block).is_exit_block());
500        (new_block, new_cfg, exit_block)
501    }
502}