hugr_core/builder/
cfg.rs

1use super::{
2    BasicBlockID, BuildError, CfgID, Container, Dataflow, HugrBuilder, Wire,
3    build_traits::SubContainer,
4    dataflow::{DFGBuilder, DFGWrapper},
5    handle::BuildHandle,
6};
7
8use crate::ops::{self, DataflowBlock, DataflowParent, ExitBlock, OpType, handle::NodeHandle};
9use crate::types::Signature;
10use crate::{hugr::views::HugrView, types::TypeRow};
11
12use crate::Node;
13use crate::{Hugr, hugr::HugrMut, type_row};
14
/// Builder for a [`crate::ops::CFG`] child control
/// flow graph.
///
/// These builder methods should ensure that the first two children of a CFG
/// node are the entry node and the exit node.
///
/// # Example
/// ```
/// /*  Build a control flow graph with the following structure:
///            +-----------+
///            |   Entry   |
///            +-/-----\---+
///             /       \
///            /         \
///           /           \
///          /             \
///   +-----/----+       +--\-------+
///   | Branch A |       | Branch B |
///   +-----\----+       +----/-----+
///          \               /
///           \             /
///            \           /
///             \         /
///            +-\-------/--+
///            |    Exit    |
///            +------------+
/// */
/// use hugr::{
///     builder::{BuildError, CFGBuilder, Container, Dataflow, HugrBuilder, endo_sig, inout_sig},
///     extension::{prelude, ExtensionSet},
///     ops, type_row,
///     types::{Signature, SumType, Type},
///     Hugr,
///     extension::prelude::usize_t,
/// };
///
/// fn make_cfg() -> Result<Hugr, BuildError> {
///     let mut cfg_builder = CFGBuilder::new(Signature::new_endo(usize_t()))?;
///
///     // Outputs from basic blocks must be packed in a sum which corresponds to
///     // which successor to pick. We'll either choose the first branch and pass
///     // it a usize, or the second branch and pass it nothing.
///     let sum_variants = vec![vec![usize_t()].into(), type_row![]];
///
///     // The second argument says what types will be passed through to every
///     // successor, in addition to the appropriate `sum_variants` type.
///     let mut entry_b = cfg_builder.entry_builder(sum_variants.clone(), vec![usize_t()].into())?;
///
///     let [inw] = entry_b.input_wires_arr();
///     let entry = {
///         // Pack the const "42" into the appropriate sum type.
///         let left_42 = ops::Value::sum(
///             0,
///             [prelude::ConstUsize::new(42).into()],
///             SumType::new(sum_variants.clone()),
///         )?;
///         let sum = entry_b.add_load_value(left_42);
///
///         entry_b.finish_with_outputs(sum, [inw])?
///     };
///
///     // This block will be the first successor of the entry node. It takes two
///     // `usize` arguments: one from the `sum_variants` type, and another from the
///     // entry node's `other_outputs`.
///     let mut successor_builder = cfg_builder.simple_block_builder(
///         inout_sig(vec![usize_t(), usize_t()], usize_t()),
///         1, // only one successor to this block
///     )?;
///     let successor_a = {
///         // This block has one successor. The choice is denoted by a unary sum.
///         let sum_unary = successor_builder.add_load_const(ops::Value::unary_unit_sum());
///
///         // The input wires of a node start with the data embedded in the variant
///         // which selected this block.
///         let [_forty_two, in_wire] = successor_builder.input_wires_arr();
///         successor_builder.finish_with_outputs(sum_unary, [in_wire])?
///     };
///
///     // The only argument to this block is the entry node's `other_outputs`.
///     let mut successor_builder = cfg_builder.simple_block_builder(endo_sig(usize_t()), 1)?;
///     let successor_b = {
///         let sum_unary = successor_builder.add_load_value(ops::Value::unary_unit_sum());
///         let [in_wire] = successor_builder.input_wires_arr();
///         successor_builder.finish_with_outputs(sum_unary, [in_wire])?
///     };
///     let exit = cfg_builder.exit_block();
///     cfg_builder.branch(&entry, 0, &successor_a)?; // branch 0 goes to successor_a
///     cfg_builder.branch(&entry, 1, &successor_b)?; // branch 1 goes to successor_b
///     cfg_builder.branch(&successor_a, 0, &exit)?;
///     cfg_builder.branch(&successor_b, 0, &exit)?;
///     let hugr = cfg_builder.finish_hugr()?;
///     Ok(hugr)
/// };
/// assert!(make_cfg().is_ok());
/// ```
#[derive(Debug, PartialEq)]
pub struct CFGBuilder<T> {
    /// Storage for the HUGR being built: anything implementing
    /// `AsMut<Hugr> + AsRef<Hugr>`, e.g. an owned `Hugr`.
    pub(super) base: T,
    /// The CFG node whose child blocks this builder creates.
    pub(super) cfg_node: Node,
    /// Input row of the CFG. It is `Some` until the entry block builder is
    /// requested, at which point it is taken and becomes the entry block's inputs.
    pub(super) inputs: Option<TypeRow>,
    /// The exit block of the CFG, created together with the builder.
    pub(super) exit_node: Node,
    /// Number of output wires of the CFG node (the length of its output row).
    pub(super) n_out_wires: usize,
}
118
119impl<B: AsMut<Hugr> + AsRef<Hugr>> Container for CFGBuilder<B> {
120    #[inline]
121    fn container_node(&self) -> Node {
122        self.cfg_node
123    }
124
125    #[inline]
126    fn hugr_mut(&mut self) -> &mut Hugr {
127        self.base.as_mut()
128    }
129
130    #[inline]
131    fn hugr(&self) -> &Hugr {
132        self.base.as_ref()
133    }
134}
135
136impl<H: AsMut<Hugr> + AsRef<Hugr>> SubContainer for CFGBuilder<H> {
137    type ContainerHandle = BuildHandle<CfgID>;
138    #[inline]
139    fn finish_sub_container(self) -> Result<Self::ContainerHandle, BuildError> {
140        Ok((self.cfg_node, self.n_out_wires).into())
141    }
142}
143
144impl CFGBuilder<Hugr> {
145    /// New CFG rooted HUGR builder
146    pub fn new(signature: Signature) -> Result<Self, BuildError> {
147        let cfg_op = ops::CFG {
148            signature: signature.clone(),
149        };
150
151        let base = Hugr::new_with_entrypoint(cfg_op).expect("CFG entrypoints be valid");
152        let cfg_node = base.entrypoint();
153        CFGBuilder::create(base, cfg_node, signature.input, signature.output)
154    }
155}
156
157impl HugrBuilder for CFGBuilder<Hugr> {
158    fn finish_hugr(self) -> Result<Hugr, crate::hugr::ValidationError<Node>> {
159        self.base.validate()?;
160        Ok(self.base)
161    }
162}
163
164impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
165    pub(super) fn create(
166        mut base: B,
167        cfg_node: Node,
168        input: TypeRow,
169        output: TypeRow,
170    ) -> Result<Self, BuildError> {
171        let n_out_wires = output.len();
172        let exit_block_type = OpType::ExitBlock(ExitBlock {
173            cfg_outputs: output,
174        });
175        let exit_node = base
176            .as_mut()
177            // Make the extensions a parameter
178            .add_node_with_parent(cfg_node, exit_block_type);
179        Ok(Self {
180            base,
181            cfg_node,
182            n_out_wires,
183            exit_node,
184            inputs: Some(input),
185        })
186    }
187
188    /// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs`
189    /// and `outputs` and the variants of the branching Sum value
190    /// specified by `sum_rows`.
191    ///
192    /// # Errors
193    ///
194    /// This function will return an error if there is an error adding the node.
195    pub fn block_builder(
196        &mut self,
197        inputs: TypeRow,
198        sum_rows: impl IntoIterator<Item = TypeRow>,
199        other_outputs: TypeRow,
200    ) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
201        self.any_block_builder(inputs, sum_rows, other_outputs, false)
202    }
203
204    fn any_block_builder(
205        &mut self,
206        inputs: TypeRow,
207        sum_rows: impl IntoIterator<Item = TypeRow>,
208        other_outputs: TypeRow,
209        entry: bool,
210    ) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
211        let sum_rows: Vec<_> = sum_rows.into_iter().collect();
212        let op = OpType::DataflowBlock(DataflowBlock {
213            inputs: inputs.clone(),
214            other_outputs: other_outputs.clone(),
215            sum_rows,
216        });
217        let parent = self.container_node();
218        let block_n = if entry {
219            let exit = self.exit_node;
220            // TODO: Make extensions a parameter
221            self.hugr_mut().add_node_before(exit, op)
222        } else {
223            // TODO: Make extensions a parameter
224            self.hugr_mut().add_node_with_parent(parent, op)
225        };
226
227        BlockBuilder::create_with_io(self.hugr_mut(), block_n)
228    }
229
230    /// Return a builder for a non-entry [`DataflowBlock`] child graph with
231    /// `inputs` and `outputs` , plus a `UnitSum` type (a Sum of `n_cases` unit
232    /// types) to select the successor.
233    ///
234    /// # Errors
235    ///
236    /// This function will return an error if there is an error adding the node.
237    pub fn simple_block_builder(
238        &mut self,
239        signature: Signature,
240        n_cases: usize,
241    ) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
242        self.block_builder(
243            signature.input,
244            vec![type_row![]; n_cases],
245            signature.output,
246        )
247    }
248
249    /// Return a builder for the entry [`DataflowBlock`] child graph with `outputs`
250    /// and the variants of the branching Sum value specified by `sum_rows`.
251    ///
252    /// # Errors
253    ///
254    /// This function will return an error if an entry block has already been built.
255    pub fn entry_builder(
256        &mut self,
257        sum_rows: impl IntoIterator<Item = TypeRow>,
258        other_outputs: TypeRow,
259    ) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
260        let inputs = self
261            .inputs
262            .take()
263            .ok_or(BuildError::EntryBuiltError(self.cfg_node))?;
264        self.any_block_builder(inputs, sum_rows, other_outputs, true)
265    }
266
267    /// Return a builder for the entry [`DataflowBlock`] child graph with
268    /// `outputs` and a `UnitSum` type: a Sum of `n_cases` unit types.
269    ///
270    /// # Errors
271    ///
272    /// This function will return an error if there is an error adding the node.
273    pub fn simple_entry_builder(
274        &mut self,
275        outputs: TypeRow,
276        n_cases: usize,
277    ) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
278        self.entry_builder(vec![type_row![]; n_cases], outputs)
279    }
280
281    /// Returns the exit block of this [`CFGBuilder`].
282    pub fn exit_block(&self) -> BasicBlockID {
283        self.exit_node.into()
284    }
285
286    /// Set the `branch` index `successor` block of `predecessor`.
287    ///
288    /// # Errors
289    ///
290    /// This function will return an error if there is an error connecting the blocks.
291    pub fn branch(
292        &mut self,
293        predecessor: &BasicBlockID,
294        branch: usize,
295        successor: &BasicBlockID,
296    ) -> Result<(), BuildError> {
297        let from = predecessor.node();
298        let to = successor.node();
299        self.hugr_mut().connect(from, branch, to, 0);
300        Ok(())
301    }
302}
303
/// Builder for a [`DataflowBlock`] child graph.
///
/// This is a [`DFGWrapper`] whose finished handle is a [`BasicBlockID`]; `B` is
/// the HUGR storage it builds into (an owned [`Hugr`], or `&mut Hugr` when the
/// block is created from a [`CFGBuilder`]).
pub type BlockBuilder<B> = DFGWrapper<B, BasicBlockID>;
306
307impl<B: AsMut<Hugr> + AsRef<Hugr>> BlockBuilder<B> {
308    /// Set the outputs of the block, with `branch_wire` carrying  the value of the
309    /// branch controlling Sum value.  `outputs` are the remaining outputs.
310    pub fn set_outputs(
311        &mut self,
312        branch_wire: Wire,
313        outputs: impl IntoIterator<Item = Wire>,
314    ) -> Result<(), BuildError> {
315        Dataflow::set_outputs(self, [branch_wire].into_iter().chain(outputs))
316    }
317
318    /// Create a new `BlockBuilder`.
319    ///
320    /// See [`BlockBuilder::create_with_io`] if you need to initialize the input
321    /// and output nodes.
322    ///
323    /// # Parameters
324    /// - `base`: The base HUGR to build on.
325    /// - `block_n`: The block we are building.
326    fn create(base: B, block_n: Node) -> Result<Self, BuildError> {
327        let db = DFGBuilder::create(base, block_n)?;
328        Ok(BlockBuilder::from_dfg_builder(db))
329    }
330
331    /// Create a new `BlockBuilder`, initializing the input and output nodes.
332    ///
333    /// See [`BlockBuilder::create`] if you don't need to initialize the input
334    /// and output nodes.
335    ///
336    /// # Parameters
337    /// - `base`: The base HUGR to build on.
338    /// - `block_n`: The block we are building.
339    fn create_with_io(base: B, block_n: Node) -> Result<Self, BuildError> {
340        let block_op = base
341            .as_ref()
342            .get_optype(block_n)
343            .as_dataflow_block()
344            .unwrap();
345        let signature = block_op.inner_signature().into_owned();
346        let db = DFGBuilder::create_with_io(base, block_n, signature)?;
347        Ok(BlockBuilder::from_dfg_builder(db))
348    }
349
350    /// [Set outputs](BlockBuilder::set_outputs) and [finish](`BlockBuilder::finish_sub_container`).
351    pub fn finish_with_outputs(
352        mut self,
353        branch_wire: Wire,
354        outputs: impl IntoIterator<Item = Wire>,
355    ) -> Result<<Self as SubContainer>::ContainerHandle, BuildError>
356    where
357        Self: Sized,
358    {
359        self.set_outputs(branch_wire, outputs)?;
360        self.finish_sub_container()
361    }
362}
363
364impl BlockBuilder<Hugr> {
365    /// Initialize a [`DataflowBlock`] rooted HUGR builder.
366    pub fn new(
367        inputs: impl Into<TypeRow>,
368        sum_rows: impl IntoIterator<Item = TypeRow>,
369        other_outputs: impl Into<TypeRow>,
370    ) -> Result<Self, BuildError> {
371        let inputs = inputs.into();
372        let sum_rows: Vec<_> = sum_rows.into_iter().collect();
373        let other_outputs: TypeRow = other_outputs.into();
374        let num_out_branches = sum_rows.len();
375
376        // We only support blocks where all the possible `sum_rows` branches have the same types,
377        // as that lets us branch it directly to an exit node.
378        if let Some(row) = sum_rows.first() {
379            if sum_rows.iter().skip(1).any(|r2| row != r2) {
380                return Err(BuildError::BasicBlockTooComplex);
381            }
382        }
383        let cfg_outputs = sum_rows.first().cloned().unwrap_or_default();
384        let cfg_outputs = cfg_outputs.extend(other_outputs.as_slice());
385
386        let mut cfg = CFGBuilder::new(Signature::new(inputs, cfg_outputs))?;
387        let block = cfg.entry_builder(sum_rows, other_outputs)?;
388        let block = block.finish_sub_container()?;
389        for i in 0..num_out_branches {
390            cfg.branch(&block, i, &cfg.exit_block())?;
391        }
392        let mut base = std::mem::take(cfg.hugr_mut());
393        let root = block.node();
394        base.set_entrypoint(root);
395        Self::create(base, root)
396    }
397
398    /// [Set outputs](BlockBuilder::set_outputs) and [`finish_hugr`](`BlockBuilder::finish_hugr`).
399    pub fn finish_hugr_with_outputs(
400        mut self,
401        branch_wire: Wire,
402        outputs: impl IntoIterator<Item = Wire>,
403    ) -> Result<Hugr, BuildError> {
404        self.set_outputs(branch_wire, outputs)?;
405        self.finish_hugr().map_err(BuildError::InvalidHUGR)
406    }
407}
408
/// Tests for [`CFGBuilder`] and [`BlockBuilder`].
#[cfg(test)]
pub(crate) mod test {
    use crate::builder::{DataflowSubContainer, ModuleBuilder};

    use crate::extension::prelude::{bool_t, usize_t};
    use crate::hugr::ValidationError;
    use crate::hugr::validate::InterGraphEdgeError;
    use crate::type_row;
    use cool_asserts::assert_matches;

    use super::*;
    /// Build the basic CFG inside a function `main` of a module and check that
    /// the module HUGR can be finished.
    #[test]
    fn basic_module_cfg() -> Result<(), BuildError> {
        let build_result = {
            let mut module_builder = ModuleBuilder::new();
            let mut func_builder = module_builder
                .define_function("main", Signature::new(vec![usize_t()], vec![usize_t()]))?;
            let _f_id = {
                let [int] = func_builder.input_wires_arr();

                let cfg_id = {
                    // The function's input wire feeds the CFG's single `usize` input.
                    let mut cfg_builder =
                        func_builder.cfg_builder(vec![(usize_t(), int)], vec![usize_t()].into())?;
                    build_basic_cfg(&mut cfg_builder)?;

                    cfg_builder.finish_sub_container()?
                };

                func_builder.finish_with_outputs(cfg_id.outputs())?
            };
            module_builder.finish_hugr()
        };

        assert!(build_result.is_ok(), "{}", build_result.unwrap_err());

        Ok(())
    }
    /// Build the basic CFG as the root of its own HUGR and check it validates.
    #[test]
    fn basic_cfg_hugr() -> Result<(), BuildError> {
        let mut cfg_builder = CFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()]))?;
        build_basic_cfg(&mut cfg_builder)?;
        assert_matches!(cfg_builder.finish_hugr(), Ok(_));

        Ok(())
    }

    /// `BlockBuilder::new` rejects differing sum rows, and accepts identical
    /// rows producing a HUGR that validates.
    #[test]
    fn basic_cfg_block() -> Result<(), BuildError> {
        // Variants carry different types, so they cannot all go to the exit.
        assert_eq!(
            BlockBuilder::new(
                vec![],
                [vec![usize_t()].into(), vec![bool_t()].into()],
                vec![]
            ),
            Err(BuildError::BasicBlockTooComplex)
        );

        let sum_rows: Vec<TypeRow> = vec![vec![usize_t()].into(), vec![usize_t()].into()];
        let mut block_builder =
            BlockBuilder::new(vec![usize_t()], sum_rows.clone(), vec![usize_t()])?;
        let [inp] = block_builder.input_wires_arr();
        let branch = block_builder.make_sum(0, sum_rows, [inp])?;
        let hugr = block_builder.finish_hugr_with_outputs(branch, [inp])?;

        hugr.validate().unwrap();

        Ok(())
    }

    /// Populate `cfg_builder` (whose CFG takes and returns one `usize`) with:
    /// entry --(variant 0)--> middle --> exit, and entry --(variant 1)--> exit.
    ///
    /// The entry block branches on a Sum of two `usize` variants and always
    /// selects variant 1 carrying its input; it has no other outputs.
    pub(crate) fn build_basic_cfg<T: AsMut<Hugr> + AsRef<Hugr>>(
        cfg_builder: &mut CFGBuilder<T>,
    ) -> Result<(), BuildError> {
        let usize_row: TypeRow = vec![usize_t()].into();
        let sum2_variants = vec![usize_row.clone(), usize_row];
        let mut entry_b = cfg_builder.entry_builder(sum2_variants.clone(), type_row![])?;
        let entry = {
            let [inw] = entry_b.input_wires_arr();

            let sum = entry_b.make_sum(1, sum2_variants, [inw])?;
            entry_b.finish_with_outputs(sum, [])?
        };
        // Single-successor block that forwards its `usize` input.
        let mut middle_b = cfg_builder
            .simple_block_builder(Signature::new(vec![usize_t()], vec![usize_t()]), 1)?;
        let middle = {
            let c = middle_b.add_load_const(ops::Value::unary_unit_sum());
            let [inw] = middle_b.input_wires_arr();
            middle_b.finish_with_outputs(c, [inw])?
        };
        let exit = cfg_builder.exit_block();
        cfg_builder.branch(&entry, 0, &middle)?;
        cfg_builder.branch(&middle, 0, &exit)?;
        cfg_builder.branch(&entry, 1, &exit)?;
        Ok(())
    }
    /// A wire from the entry block's input is used in `middle`, which the
    /// entry block dominates, so the HUGR is valid.
    #[test]
    fn test_dom_edge() -> Result<(), BuildError> {
        let mut cfg_builder = CFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()]))?;
        // One constant, added via the CFG builder and loaded in both blocks.
        let sum_tuple_const = cfg_builder.add_constant(ops::Value::unary_unit_sum());
        let sum_variants = vec![type_row![]];

        let mut entry_b = cfg_builder.entry_builder(sum_variants.clone(), type_row![])?;
        let [inw] = entry_b.input_wires_arr();
        let entry = {
            let sum = entry_b.load_const(&sum_tuple_const);

            entry_b.finish_with_outputs(sum, [])?
        };
        let mut middle_b =
            cfg_builder.simple_block_builder(Signature::new(type_row![], vec![usize_t()]), 1)?;
        let middle = {
            let c = middle_b.load_const(&sum_tuple_const);
            // `inw` comes from the entry block, which dominates `middle`.
            middle_b.finish_with_outputs(c, [inw])?
        };
        let exit = cfg_builder.exit_block();
        cfg_builder.branch(&entry, 0, &middle)?;
        cfg_builder.branch(&middle, 0, &exit)?;
        assert_matches!(cfg_builder.finish_hugr(), Ok(_));

        Ok(())
    }

    /// A wire from `middle`'s input is used in the entry block, which `middle`
    /// does not dominate, so validation reports `NonDominatedAncestor`.
    #[test]
    fn test_non_dom_edge() -> Result<(), BuildError> {
        let mut cfg_builder = CFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()]))?;
        let sum_tuple_const = cfg_builder.add_constant(ops::Value::unary_unit_sum());
        let sum_variants = vec![type_row![]];
        // Build the non-entry block first so its input wire exists before the entry block.
        let mut middle_b = cfg_builder
            .simple_block_builder(Signature::new(vec![usize_t()], vec![usize_t()]), 1)?;
        let [inw] = middle_b.input_wires_arr();
        let middle = {
            let c = middle_b.load_const(&sum_tuple_const);
            middle_b.finish_with_outputs(c, [inw])?
        };

        let mut entry_b =
            cfg_builder.entry_builder(sum_variants.clone(), vec![usize_t()].into())?;
        let entry = {
            let sum = entry_b.load_const(&sum_tuple_const);
            // entry block uses wire from middle block even though middle block
            // does not dominate entry
            entry_b.finish_with_outputs(sum, [inw])?
        };
        let exit = cfg_builder.exit_block();
        cfg_builder.branch(&entry, 0, &middle)?;
        cfg_builder.branch(&middle, 0, &exit)?;
        assert_matches!(
            cfg_builder.finish_hugr(),
            Err(ValidationError::InterGraphEdgeError(
                InterGraphEdgeError::NonDominatedAncestor { .. }
            ))
        );

        Ok(())
    }
}