hugr_core/hugr/rewrite/
outline_cfg.rs

1//! Rewrite for inserting a CFG-node into the hierarchy containing a subsection of an existing CFG
2use std::collections::HashSet;
3
4use itertools::Itertools;
5use thiserror::Error;
6
7use crate::builder::{BlockBuilder, Container, Dataflow, SubContainer};
8use crate::extension::ExtensionSet;
9use crate::hugr::internal::HugrMutInternals;
10use crate::hugr::rewrite::Rewrite;
11use crate::hugr::views::sibling::SiblingMut;
12use crate::hugr::{HugrMut, HugrView};
13use crate::ops;
14use crate::ops::controlflow::BasicBlock;
15use crate::ops::dataflow::DataflowOpTrait;
16use crate::ops::handle::{BasicBlockID, CfgID, NodeHandle};
17use crate::ops::{DataflowBlock, OpType};
18use crate::PortIndex;
19use crate::{type_row, Node};
20
21/// Moves part of a Control-flow Sibling Graph into a new CFG-node
22/// that is the only child of a new Basic Block in the original CSG.
23pub struct OutlineCfg {
24    blocks: HashSet<Node>,
25}
26
27impl OutlineCfg {
28    /// Create a new OutlineCfg rewrite that will move the provided blocks.
29    pub fn new(blocks: impl IntoIterator<Item = Node>) -> Self {
30        Self {
31            blocks: HashSet::from_iter(blocks),
32        }
33    }
34
35    /// Compute the entry and exit nodes of the CFG which contains
36    /// [`self.blocks`], along with the output neighbour its parent graph and
37    /// the combined extension_deltas of all of the blocks.
38    fn compute_entry_exit_outside_extensions(
39        &self,
40        h: &impl HugrView<Node = Node>,
41    ) -> Result<(Node, Node, Node, ExtensionSet), OutlineCfgError> {
42        let cfg_n = match self
43            .blocks
44            .iter()
45            .map(|n| h.get_parent(*n))
46            .unique()
47            .exactly_one()
48        {
49            Ok(Some(n)) => n,
50            _ => return Err(OutlineCfgError::NotSiblings),
51        };
52        let o = h.get_optype(cfg_n);
53        let OpType::CFG(o) = o else {
54            return Err(OutlineCfgError::ParentNotCfg(cfg_n, o.clone()));
55        };
56        let cfg_entry = h.children(cfg_n).next().unwrap();
57        let mut entry = None;
58        let mut exit_succ = None;
59        let mut extension_delta = ExtensionSet::new();
60        for &n in self.blocks.iter() {
61            if n == cfg_entry
62                || h.input_neighbours(n)
63                    .any(|pred| !self.blocks.contains(&pred))
64            {
65                match entry {
66                    None => {
67                        entry = Some(n);
68                    }
69                    Some(prev) => {
70                        return Err(OutlineCfgError::MultipleEntryNodes(prev, n));
71                    }
72                }
73            }
74            extension_delta = extension_delta.union(o.signature().runtime_reqs.clone());
75            let external_succs = h.output_neighbours(n).filter(|s| !self.blocks.contains(s));
76            match external_succs.at_most_one() {
77                Ok(None) => (), // No external successors
78                Ok(Some(o)) => match exit_succ {
79                    None => {
80                        exit_succ = Some((n, o));
81                    }
82                    Some((prev, _)) => {
83                        return Err(OutlineCfgError::MultipleExitNodes(prev, n));
84                    }
85                },
86                Err(ext) => return Err(OutlineCfgError::MultipleExitEdges(n, ext.collect())),
87            };
88        }
89        match (entry, exit_succ) {
90            (Some(e), Some((x, o))) => Ok((e, x, o, extension_delta)),
91            (None, _) => Err(OutlineCfgError::NoEntryNode),
92            (_, None) => Err(OutlineCfgError::NoExitNode),
93        }
94    }
95}
96
97impl Rewrite for OutlineCfg {
98    type Error = OutlineCfgError;
99    /// The newly-created basic block, and the [CFG] node inside it
100    ///
101    /// [CFG]: OpType::CFG
102    type ApplyResult = (Node, Node);
103
104    const UNCHANGED_ON_FAILURE: bool = true;
105    fn verify(&self, h: &impl HugrView<Node = Node>) -> Result<(), OutlineCfgError> {
106        self.compute_entry_exit_outside_extensions(h)?;
107        Ok(())
108    }
109    fn apply(self, h: &mut impl HugrMut<Node = Node>) -> Result<(Node, Node), OutlineCfgError> {
110        let (entry, exit, outside, extension_delta) =
111            self.compute_entry_exit_outside_extensions(h)?;
112        // 1. Compute signature
113        // These panic()s only happen if the Hugr would not have passed validate()
114        let OpType::DataflowBlock(DataflowBlock { inputs, .. }) = h.get_optype(entry) else {
115            panic!("Entry node is not a basic block")
116        };
117        let inputs = inputs.clone();
118        let outputs = match h.get_optype(outside) {
119            OpType::DataflowBlock(dfb) => dfb.dataflow_input().clone(),
120            OpType::ExitBlock(exit) => exit.dataflow_input().clone(),
121            _ => panic!("External successor not a basic block"),
122        };
123        let outer_cfg = h.get_parent(entry).unwrap();
124        let outer_entry = h.children(outer_cfg).next().unwrap();
125
126        // 2. new_block contains input node, sub-cfg, exit node all connected
127        let (new_block, cfg_node) = {
128            let mut new_block_bldr = BlockBuilder::new_exts(
129                inputs.clone(),
130                vec![type_row![]],
131                outputs.clone(),
132                extension_delta.clone(),
133            )
134            .unwrap();
135            let wires_in = inputs.iter().cloned().zip(new_block_bldr.input_wires());
136            let cfg = new_block_bldr
137                .cfg_builder_exts(wires_in, outputs, extension_delta)
138                .unwrap();
139            let cfg = cfg.finish_sub_container().unwrap();
140            let unit_sum = new_block_bldr.add_constant(ops::Value::unary_unit_sum());
141            let pred_wire = new_block_bldr.load_const(&unit_sum);
142            new_block_bldr
143                .set_outputs(pred_wire, cfg.outputs())
144                .unwrap();
145            let ins_res = h.insert_hugr(outer_cfg, new_block_bldr.hugr().clone());
146            (
147                ins_res.new_root,
148                *ins_res.node_map.get(&cfg.node()).unwrap(),
149            )
150        };
151
152        // 3. Entry edges. Change any edges into entry_block from outside, to target new_block
153        let preds: Vec<_> = h
154            .linked_outputs(entry, h.node_inputs(entry).exactly_one().ok().unwrap())
155            .collect();
156        for (pred, br) in preds {
157            if !self.blocks.contains(&pred) {
158                h.disconnect(pred, br);
159                h.connect(pred, br, new_block, 0);
160            }
161        }
162        if entry == outer_entry {
163            // new_block must be the entry node, i.e. first child, of the enclosing CFG
164            // (the current entry node will be reparented inside new_block below)
165            h.move_before_sibling(new_block, outer_entry);
166        }
167
168        // 4(a). Exit edges.
169        // Remove edge from exit_node (that used to target outside)
170        let exit_port = h
171            .node_outputs(exit)
172            .filter(|p| {
173                let (t, p2) = h.single_linked_input(exit, *p).unwrap();
174                assert!(p2.index() == 0);
175                t == outside
176            })
177            .exactly_one()
178            .ok() // NodePorts does not implement Debug
179            .unwrap();
180        h.disconnect(exit, exit_port);
181        // And connect new_block to outside instead
182        h.connect(new_block, 0, outside, 0);
183
184        // 5. Children of new CFG.
185        let inner_exit = {
186            // These operations do not fit within any CSG/SiblingMut
187            // so we need to access the Hugr directly.
188            let h = h.hugr_mut();
189            let inner_exit = h.children(cfg_node).exactly_one().ok().unwrap();
190            // Entry node must be first
191            h.move_before_sibling(entry, inner_exit);
192            // And remaining nodes
193            for n in self.blocks {
194                // Do not move the entry node, as we have already
195                if n != entry {
196                    h.set_parent(n, cfg_node);
197                }
198            }
199            inner_exit
200        };
201
202        // 4(b). Reconnect exit edge to the new exit node within the inner CFG
203        // Use nested SiblingMut's in case the outer `h` is only a SiblingMut itself.
204        let mut in_bb_view: SiblingMut<'_, BasicBlockID> =
205            SiblingMut::try_new(h, new_block).unwrap();
206        let mut in_cfg_view: SiblingMut<'_, CfgID> =
207            SiblingMut::try_new(&mut in_bb_view, cfg_node).unwrap();
208        in_cfg_view.connect(exit, exit_port, inner_exit, 0);
209
210        Ok((new_block, cfg_node))
211    }
212
213    fn invalidation_set(&self) -> impl Iterator<Item = Node> {
214        self.blocks.iter().copied()
215    }
216}
217
218/// Errors that can occur in expressing an OutlineCfg rewrite.
219#[derive(Debug, Error)]
220#[non_exhaustive]
221pub enum OutlineCfgError {
222    /// The set of blocks were not siblings
223    #[error("The nodes did not all have the same parent")]
224    NotSiblings,
225    /// The parent node was not a CFG node
226    #[error("The parent node {0} was not a CFG but a {1}")]
227    ParentNotCfg(Node, OpType),
228    /// Multiple blocks had incoming edges
229    #[error("Multiple blocks had predecessors outside the set - at least {0} and {1}")]
230    MultipleEntryNodes(Node, Node),
231    /// Multiple blocks had outgoing edges
232    // Note possible TODO: straightforward if all outgoing edges target the same BB
233    #[error("Multiple blocks had edges leaving the set - at least {0} and {1}")]
234    MultipleExitNodes(Node, Node),
235    /// One block had multiple outgoing edges
236    #[error("Exit block {0} had edges to multiple external blocks {1:?}")]
237    MultipleExitEdges(Node, Vec<Node>),
238    /// No block was identified as an entry block
239    #[error("No block had predecessors outside the set")]
240    NoEntryNode,
241    /// No block was found with an edge leaving the set (so, must be an infinite loop)
242    #[error("No block had a successor outside the set")]
243    NoExitNode,
244}
245
246#[cfg(test)]
247mod test {
248    use std::collections::HashSet;
249
250    use crate::builder::{
251        BlockBuilder, BuildError, CFGBuilder, Container, Dataflow, DataflowSubContainer,
252        HugrBuilder, ModuleBuilder,
253    };
254    use crate::extension::prelude::usize_t;
255    use crate::hugr::views::sibling::SiblingMut;
256    use crate::hugr::HugrMut;
257    use crate::ops::constant::Value;
258    use crate::ops::handle::{BasicBlockID, CfgID, ConstID, NodeHandle};
259    use crate::types::Signature;
260    use crate::{Hugr, HugrView, Node};
261    use cool_asserts::assert_matches;
262    use itertools::Itertools;
263    use rstest::rstest;
264
265    use super::{OutlineCfg, OutlineCfgError};
266
267    ///      /-> left --\
268    /// entry            > merge -> head -> tail -> exit
269    ///      \-> right -/             \-<--<-/
270    struct CondThenLoopCfg {
271        h: Hugr,
272        left: Node,
273        right: Node,
274        merge: Node,
275        head: Node,
276        tail: Node,
277    }
278    impl CondThenLoopCfg {
279        fn new() -> Result<CondThenLoopCfg, BuildError> {
280            let block_ty = Signature::new_endo(usize_t());
281            let mut cfg_builder = CFGBuilder::new(block_ty.clone())?;
282            let pred_const = cfg_builder.add_constant(Value::unit_sum(0, 2).expect("0 < 2"));
283            let const_unit = cfg_builder.add_constant(Value::unary_unit_sum());
284            fn n_identity(
285                mut bbldr: BlockBuilder<&mut Hugr>,
286                cst: &ConstID,
287            ) -> Result<BasicBlockID, BuildError> {
288                let pred = bbldr.load_const(cst);
289                let vals = bbldr.input_wires();
290                bbldr.finish_with_outputs(pred, vals)
291            }
292            let id_block = |c: &mut CFGBuilder<_>| {
293                n_identity(c.simple_block_builder(block_ty.clone(), 1)?, &const_unit)
294            };
295
296            let entry = n_identity(
297                cfg_builder.simple_entry_builder(usize_t().into(), 2)?,
298                &pred_const,
299            )?;
300
301            let left = id_block(&mut cfg_builder)?;
302            let right = id_block(&mut cfg_builder)?;
303            cfg_builder.branch(&entry, 0, &left)?;
304            cfg_builder.branch(&entry, 1, &right)?;
305
306            let merge = id_block(&mut cfg_builder)?;
307            cfg_builder.branch(&left, 0, &merge)?;
308            cfg_builder.branch(&right, 0, &merge)?;
309
310            let head = id_block(&mut cfg_builder)?;
311            cfg_builder.branch(&merge, 0, &head)?;
312            let tail = n_identity(
313                cfg_builder.simple_block_builder(Signature::new_endo(usize_t()), 2)?,
314                &pred_const,
315            )?;
316            cfg_builder.branch(&tail, 1, &head)?;
317            cfg_builder.branch(&head, 0, &tail)?; // trivial "loop body"
318            let exit = cfg_builder.exit_block();
319            cfg_builder.branch(&tail, 0, &exit)?;
320
321            let h = cfg_builder.finish_hugr()?;
322            let (left, right) = (left.node(), right.node());
323            let (merge, head, tail) = (merge.node(), head.node(), tail.node());
324            Ok(Self {
325                h,
326                left,
327                right,
328                merge,
329                head,
330                tail,
331            })
332        }
333        fn entry_exit(&self) -> (Node, Node) {
334            self.h
335                .children(self.h.root())
336                .take(2)
337                .collect_tuple()
338                .unwrap()
339        }
340    }
341
342    #[rstest::fixture]
343    fn cond_then_loop_cfg() -> CondThenLoopCfg {
344        CondThenLoopCfg::new().unwrap()
345    }
346
347    #[rstest]
348    fn test_outline_cfg_errors(cond_then_loop_cfg: CondThenLoopCfg) {
349        let (entry, _) = cond_then_loop_cfg.entry_exit();
350        let CondThenLoopCfg {
351            mut h,
352            left,
353            right,
354            merge,
355            head,
356            tail,
357        } = cond_then_loop_cfg;
358        let backup = h.clone();
359
360        let r = h.apply_rewrite(OutlineCfg::new([tail]));
361        assert_matches!(r, Err(OutlineCfgError::MultipleExitEdges(_, _)));
362        assert_eq!(h, backup);
363
364        let r = h.apply_rewrite(OutlineCfg::new([entry, left, right]));
365        assert_matches!(r, Err(OutlineCfgError::MultipleExitNodes(a,b))
366            => assert_eq!(HashSet::from([a,b]), HashSet::from_iter([left, right])));
367        assert_eq!(h, backup);
368
369        let r = h.apply_rewrite(OutlineCfg::new([left, right, merge]));
370        assert_matches!(r, Err(OutlineCfgError::MultipleEntryNodes(a,b))
371            => assert_eq!(HashSet::from([a,b]), HashSet::from([left, right])));
372        assert_eq!(h, backup);
373
374        // The entry node implicitly has an extra incoming edge
375        let r = h.apply_rewrite(OutlineCfg::new([entry, left, right, merge, head]));
376        assert_matches!(r, Err(OutlineCfgError::MultipleEntryNodes(a,b))
377            => assert_eq!(HashSet::from([a,b]), HashSet::from([entry, head])));
378        assert_eq!(h, backup);
379    }
380
381    #[rstest::rstest]
382    fn test_outline_cfg(cond_then_loop_cfg: CondThenLoopCfg) {
383        // Outline the loop, producing:
384        //     /-> left -->\
385        // entry            merge -> newblock -> exit
386        //     \-> right ->/
387        let (_, exit) = cond_then_loop_cfg.entry_exit();
388        let CondThenLoopCfg {
389            mut h,
390            merge,
391            head,
392            tail,
393            ..
394        } = cond_then_loop_cfg;
395        let root = h.root();
396        let (new_block, _, exit_block) = outline_cfg_check_parents(&mut h, root, vec![head, tail]);
397        assert_eq!(h.output_neighbours(merge).collect_vec(), vec![new_block]);
398        assert_eq!(h.input_neighbours(exit).collect_vec(), vec![new_block]);
399        assert_eq!(
400            h.output_neighbours(tail).collect::<HashSet<Node>>(),
401            HashSet::from([head, exit_block])
402        );
403    }
404
405    #[rstest]
406    fn test_outline_cfg_multiple_in_edges(cond_then_loop_cfg: CondThenLoopCfg) {
407        // Outline merge, head and tail, producing
408        //     /-> left -->\
409        // entry            newblock -> exit
410        //     \-> right ->/
411        let (_, exit) = cond_then_loop_cfg.entry_exit();
412        let CondThenLoopCfg {
413            mut h,
414            left,
415            right,
416            merge,
417            head,
418            tail,
419        } = cond_then_loop_cfg;
420
421        let root = h.root();
422        let (new_block, _, inner_exit) =
423            outline_cfg_check_parents(&mut h, root, vec![merge, head, tail]);
424        assert_eq!(h.input_neighbours(exit).collect_vec(), vec![new_block]);
425        assert_eq!(
426            h.input_neighbours(new_block).collect::<HashSet<_>>(),
427            HashSet::from([left, right])
428        );
429        assert_eq!(
430            h.output_neighbours(tail).collect::<HashSet<Node>>(),
431            HashSet::from([head, inner_exit])
432        );
433    }
434
435    #[rstest]
436    fn test_outline_cfg_subregion(cond_then_loop_cfg: CondThenLoopCfg) {
437        // Outline the loop, as above, but with the CFG inside a Function + Module,
438        // operating via a SiblingMut
439        let mut module_builder = ModuleBuilder::new();
440        let mut fbuild = module_builder
441            .define_function("main", Signature::new(vec![usize_t()], vec![usize_t()]))
442            .unwrap();
443        let [i1] = fbuild.input_wires_arr();
444        let cfg = fbuild
445            .add_hugr_with_wires(cond_then_loop_cfg.h, [i1])
446            .unwrap();
447        fbuild.finish_with_outputs(cfg.outputs()).unwrap();
448        let mut h = module_builder.finish_hugr().unwrap();
449        // `add_hugr_with_wires` does not return an InsertionResult, so recover the nodes manually:
450        let cfg = cfg.node();
451        let exit_node = h.children(cfg).nth(1).unwrap();
452        let tail = h.input_neighbours(exit_node).exactly_one().ok().unwrap();
453        let head = h.input_neighbours(tail).exactly_one().ok().unwrap();
454        // Just sanity-check we have the correct nodes
455        assert!(h.get_optype(exit_node).is_exit_block());
456        assert_eq!(
457            h.output_neighbours(tail).collect::<HashSet<_>>(),
458            HashSet::from([head, exit_node])
459        );
460        outline_cfg_check_parents(
461            &mut SiblingMut::<'_, CfgID>::try_new(&mut h, cfg).unwrap(),
462            cfg,
463            vec![head, tail],
464        );
465        h.validate().unwrap();
466    }
467
468    #[rstest]
469    fn test_outline_cfg_move_entry(cond_then_loop_cfg: CondThenLoopCfg) {
470        // Outline the conditional, producing
471        //
472        //  newblock -> head -> tail -> exit
473        //                 \<--</
474        // (where the new block becomes the entry block)
475        let (entry, _) = cond_then_loop_cfg.entry_exit();
476        let CondThenLoopCfg {
477            mut h,
478            left,
479            right,
480            merge,
481            head,
482            ..
483        } = cond_then_loop_cfg;
484
485        let root = h.root();
486        let (new_block, _, _) =
487            outline_cfg_check_parents(&mut h, root, vec![entry, left, right, merge]);
488        h.validate().unwrap();
489        assert_eq!(new_block, h.children(h.root()).next().unwrap());
490        assert_eq!(h.output_neighbours(new_block).collect_vec(), [head]);
491    }
492
493    fn outline_cfg_check_parents(
494        h: &mut impl HugrMut,
495        cfg: Node,
496        blocks: Vec<Node>,
497    ) -> (Node, Node, Node) {
498        let mut other_blocks = h.children(cfg).collect::<HashSet<_>>();
499        assert!(blocks.iter().all(|b| other_blocks.remove(b)));
500        let (new_block, new_cfg) = h.apply_rewrite(OutlineCfg::new(blocks.clone())).unwrap();
501
502        for n in other_blocks {
503            assert_eq!(h.get_parent(n), Some(cfg))
504        }
505        assert_eq!(h.get_parent(new_block), Some(cfg));
506        assert!(h.get_optype(new_block).is_dataflow_block());
507        let b = h.base_hugr(); // To cope with `h` potentially being a SiblingMut
508        assert_eq!(b.get_parent(new_cfg), Some(new_block));
509        for n in blocks {
510            assert_eq!(b.get_parent(n), Some(new_cfg));
511        }
512        assert!(b.get_optype(new_cfg).is_cfg());
513        let exit_block = b.children(new_cfg).nth(1).unwrap();
514        assert!(b.get_optype(exit_block).is_exit_block());
515        (new_block, new_cfg, exit_block)
516    }
517}