hugr_core/builder/
cfg.rs

1use super::{
2    build_traits::SubContainer,
3    dataflow::{DFGBuilder, DFGWrapper},
4    handle::BuildHandle,
5    BasicBlockID, BuildError, CfgID, Container, Dataflow, HugrBuilder, Wire,
6};
7
8use crate::extension::TO_BE_INFERRED;
9use crate::ops::{self, handle::NodeHandle, DataflowBlock, DataflowParent, ExitBlock, OpType};
10use crate::{extension::ExtensionSet, types::Signature};
11use crate::{hugr::views::HugrView, types::TypeRow};
12
13use crate::Node;
14use crate::{hugr::HugrMut, type_row, Hugr};
15
16/// Builder for a [`crate::ops::CFG`] child control
17/// flow graph.
18///
19/// These builder methods should ensure that the first two children of a CFG
20/// node are the entry node and the exit node.
21///
22/// # Example
23/// ```
24/// /*  Build a control flow graph with the following structure:
25///            +-----------+
26///            |   Entry   |
27///            +-/-----\---+
28///             /       \
29///            /         \
30///           /           \
31///          /             \
32///   +-----/----+       +--\-------+
33///   | Branch A |       | Branch B |
34///   +-----\----+       +----/-----+
35///          \               /
36///           \             /
37///            \           /
38///             \         /
39///            +-\-------/--+
40///            |    Exit    |
41///            +------------+
42/// */
43/// use hugr::{
44///     builder::{BuildError, CFGBuilder, Container, Dataflow, HugrBuilder, endo_sig, inout_sig},
45///     extension::{prelude, ExtensionSet},
46///     ops, type_row,
47///     types::{Signature, SumType, Type},
48///     Hugr,
49///     extension::prelude::usize_t,
50/// };
51///
52/// fn make_cfg() -> Result<Hugr, BuildError> {
53///     let mut cfg_builder = CFGBuilder::new(Signature::new_endo(usize_t()))?;
54///
55///     // Outputs from basic blocks must be packed in a sum which corresponds to
56///     // which successor to pick. We'll either choose the first branch and pass
57///     // it a usize, or the second branch and pass it nothing.
58///     let sum_variants = vec![vec![usize_t()].into(), type_row![]];
59///
60///     // The second argument says what types will be passed through to every
61///     // successor, in addition to the appropriate `sum_variants` type.
62///     let mut entry_b = cfg_builder.entry_builder(sum_variants.clone(), vec![usize_t()].into())?;
63///
64///     let [inw] = entry_b.input_wires_arr();
65///     let entry = {
66///         // Pack the const "42" into the appropriate sum type.
67///         let left_42 = ops::Value::sum(
68///             0,
69///             [prelude::ConstUsize::new(42).into()],
70///             SumType::new(sum_variants.clone()),
71///         )?;
72///         let sum = entry_b.add_load_value(left_42);
73///
74///         entry_b.finish_with_outputs(sum, [inw])?
75///     };
76///
77///     // This block will be the first successor of the entry node. It takes two
78///     // `usize` arguments: one from the `sum_variants` type, and another from the
79///     // entry node's `other_outputs`.
80///     let mut successor_builder = cfg_builder.simple_block_builder(
81///         inout_sig(vec![usize_t(), usize_t()], usize_t()),
82///         1, // only one successor to this block
83///     )?;
84///     let successor_a = {
85///         // This block has one successor. The choice is denoted by a unary sum.
86///         let sum_unary = successor_builder.add_load_const(ops::Value::unary_unit_sum());
87///
88///         // The input wires of a node start with the data embedded in the variant
89///         // which selected this block.
90///         let [_forty_two, in_wire] = successor_builder.input_wires_arr();
91///         successor_builder.finish_with_outputs(sum_unary, [in_wire])?
92///     };
93///
94///     // The only argument to this block is the entry node's `other_outputs`.
95///     let mut successor_builder = cfg_builder.simple_block_builder(endo_sig(usize_t()), 1)?;
96///     let successor_b = {
97///         let sum_unary = successor_builder.add_load_value(ops::Value::unary_unit_sum());
98///         let [in_wire] = successor_builder.input_wires_arr();
99///         successor_builder.finish_with_outputs(sum_unary, [in_wire])?
100///     };
101///     let exit = cfg_builder.exit_block();
102///     cfg_builder.branch(&entry, 0, &successor_a)?; // branch 0 goes to successor_a
103///     cfg_builder.branch(&entry, 1, &successor_b)?; // branch 1 goes to successor_b
104///     cfg_builder.branch(&successor_a, 0, &exit)?;
105///     cfg_builder.branch(&successor_b, 0, &exit)?;
106///     let hugr = cfg_builder.finish_hugr()?;
107///     Ok(hugr)
108/// };
109/// #[cfg(not(feature = "extension_inference"))]
110/// assert!(make_cfg().is_ok());
111/// ```
112#[derive(Debug, PartialEq)]
113pub struct CFGBuilder<T> {
114    pub(super) base: T,
115    pub(super) cfg_node: Node,
116    pub(super) inputs: Option<TypeRow>,
117    pub(super) exit_node: Node,
118    pub(super) n_out_wires: usize,
119}
120
121impl<B: AsMut<Hugr> + AsRef<Hugr>> Container for CFGBuilder<B> {
122    #[inline]
123    fn container_node(&self) -> Node {
124        self.cfg_node
125    }
126
127    #[inline]
128    fn hugr_mut(&mut self) -> &mut Hugr {
129        self.base.as_mut()
130    }
131
132    #[inline]
133    fn hugr(&self) -> &Hugr {
134        self.base.as_ref()
135    }
136}
137
138impl<H: AsMut<Hugr> + AsRef<Hugr>> SubContainer for CFGBuilder<H> {
139    type ContainerHandle = BuildHandle<CfgID>;
140    #[inline]
141    fn finish_sub_container(self) -> Result<Self::ContainerHandle, BuildError> {
142        Ok((self.cfg_node, self.n_out_wires).into())
143    }
144}
145
146impl CFGBuilder<Hugr> {
147    /// New CFG rooted HUGR builder
148    pub fn new(signature: Signature) -> Result<Self, BuildError> {
149        let cfg_op = ops::CFG {
150            signature: signature.clone(),
151        };
152
153        let base = Hugr::new(cfg_op);
154        let cfg_node = base.root();
155        CFGBuilder::create(base, cfg_node, signature.input, signature.output)
156    }
157}
158
159impl HugrBuilder for CFGBuilder<Hugr> {
160    fn finish_hugr(mut self) -> Result<Hugr, crate::hugr::ValidationError> {
161        if cfg!(feature = "extension_inference") {
162            self.base.infer_extensions(false)?;
163        }
164        self.base.validate()?;
165        Ok(self.base)
166    }
167}
168
169impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
170    pub(super) fn create(
171        mut base: B,
172        cfg_node: Node,
173        input: TypeRow,
174        output: TypeRow,
175    ) -> Result<Self, BuildError> {
176        let n_out_wires = output.len();
177        let exit_block_type = OpType::ExitBlock(ExitBlock {
178            cfg_outputs: output,
179        });
180        let exit_node = base
181            .as_mut()
182            // Make the extensions a parameter
183            .add_node_with_parent(cfg_node, exit_block_type);
184        Ok(Self {
185            base,
186            cfg_node,
187            n_out_wires,
188            exit_node,
189            inputs: Some(input),
190        })
191    }
192
193    /// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs`
194    /// and `outputs` and the variants of the branching Sum value
195    /// specified by `sum_rows`. Extension delta will be inferred.
196    ///
197    /// # Errors
198    ///
199    /// This function will return an error if there is an error adding the node.
200    pub fn block_builder(
201        &mut self,
202        inputs: TypeRow,
203        sum_rows: impl IntoIterator<Item = TypeRow>,
204        other_outputs: TypeRow,
205    ) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
206        self.block_builder_exts(inputs, sum_rows, other_outputs, TO_BE_INFERRED)
207    }
208
209    /// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs`
210    /// and `outputs` and the variants of the branching Sum value
211    /// specified by `sum_rows`. Extension delta will be inferred.
212    ///
213    /// # Errors
214    ///
215    /// This function will return an error if there is an error adding the node.
216    pub fn block_builder_exts(
217        &mut self,
218        inputs: TypeRow,
219        sum_rows: impl IntoIterator<Item = TypeRow>,
220        other_outputs: TypeRow,
221        extension_delta: impl Into<ExtensionSet>,
222    ) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
223        self.any_block_builder(
224            inputs,
225            extension_delta.into(),
226            sum_rows,
227            other_outputs,
228            false,
229        )
230    }
231
232    fn any_block_builder(
233        &mut self,
234        inputs: TypeRow,
235        extension_delta: ExtensionSet,
236        sum_rows: impl IntoIterator<Item = TypeRow>,
237        other_outputs: TypeRow,
238        entry: bool,
239    ) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
240        let sum_rows: Vec<_> = sum_rows.into_iter().collect();
241        let op = OpType::DataflowBlock(DataflowBlock {
242            inputs: inputs.clone(),
243            other_outputs: other_outputs.clone(),
244            sum_rows,
245            extension_delta,
246        });
247        let parent = self.container_node();
248        let block_n = if entry {
249            let exit = self.exit_node;
250            // TODO: Make extensions a parameter
251            self.hugr_mut().add_node_before(exit, op)
252        } else {
253            // TODO: Make extensions a parameter
254            self.hugr_mut().add_node_with_parent(parent, op)
255        };
256
257        BlockBuilder::create(self.hugr_mut(), block_n)
258    }
259
260    /// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs`
261    /// and `outputs` and `extension_delta` explicitly specified, plus a UnitSum type
262    /// (a Sum of `n_cases` unit types) to select the successor.
263    ///
264    /// # Errors
265    ///
266    /// This function will return an error if there is an error adding the node.
267    pub fn simple_block_builder(
268        &mut self,
269        signature: Signature,
270        n_cases: usize,
271    ) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
272        self.block_builder_exts(
273            signature.input,
274            vec![type_row![]; n_cases],
275            signature.output,
276            signature.runtime_reqs,
277        )
278    }
279
280    /// Return a builder for the entry [`DataflowBlock`] child graph with `outputs`
281    /// and the variants of the branching Sum value specified by `sum_rows`.
282    /// Extension delta will be inferred.
283    ///
284    /// # Errors
285    ///
286    /// This function will return an error if an entry block has already been built.
287    pub fn entry_builder(
288        &mut self,
289        sum_rows: impl IntoIterator<Item = TypeRow>,
290        other_outputs: TypeRow,
291    ) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
292        self.entry_builder_exts(sum_rows, other_outputs, TO_BE_INFERRED)
293    }
294
295    /// Return a builder for the entry [`DataflowBlock`] child graph with `outputs`,
296    /// the variants of the branching Sum value specified by `sum_rows`, and
297    /// `extension_delta` explicitly specified. ([entry_builder](Self::entry_builder)
298    /// may be used to infer.)
299    ///
300    /// # Errors
301    ///
302    /// This function will return an error if an entry block has already been built.
303    pub fn entry_builder_exts(
304        &mut self,
305        sum_rows: impl IntoIterator<Item = TypeRow>,
306        other_outputs: TypeRow,
307        extension_delta: impl Into<ExtensionSet>,
308    ) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
309        let inputs = self
310            .inputs
311            .take()
312            .ok_or(BuildError::EntryBuiltError(self.cfg_node))?;
313        self.any_block_builder(
314            inputs,
315            extension_delta.into(),
316            sum_rows,
317            other_outputs,
318            true,
319        )
320    }
321
322    /// Return a builder for the entry [`DataflowBlock`] child graph with
323    /// `outputs` and a UnitSum type: a Sum of `n_cases` unit types.
324    ///
325    /// # Errors
326    ///
327    /// This function will return an error if there is an error adding the node.
328    pub fn simple_entry_builder(
329        &mut self,
330        outputs: TypeRow,
331        n_cases: usize,
332    ) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
333        self.entry_builder(vec![type_row![]; n_cases], outputs)
334    }
335
336    /// Return a builder for the entry [`DataflowBlock`] child graph with
337    /// `outputs` and a Sum of `n_cases` unit types, and explicit `extension_delta`.
338    /// ([simple_entry_builder](Self::simple_entry_builder) may be used to infer.)
339    ///
340    /// # Errors
341    ///
342    /// This function will return an error if there is an error adding the node.
343    pub fn simple_entry_builder_exts(
344        &mut self,
345        outputs: TypeRow,
346        n_cases: usize,
347        extension_delta: impl Into<ExtensionSet>,
348    ) -> Result<BlockBuilder<&mut Hugr>, BuildError> {
349        self.entry_builder_exts(vec![type_row![]; n_cases], outputs, extension_delta)
350    }
351
352    /// Returns the exit block of this [`CFGBuilder`].
353    pub fn exit_block(&self) -> BasicBlockID {
354        self.exit_node.into()
355    }
356
357    /// Set the `branch` index `successor` block of `predecessor`.
358    ///
359    /// # Errors
360    ///
361    /// This function will return an error if there is an error connecting the blocks.
362    pub fn branch(
363        &mut self,
364        predecessor: &BasicBlockID,
365        branch: usize,
366        successor: &BasicBlockID,
367    ) -> Result<(), BuildError> {
368        let from = predecessor.node();
369        let to = successor.node();
370        self.hugr_mut().connect(from, branch, to, 0);
371        Ok(())
372    }
373}
374
375/// Builder for a [`DataflowBlock`] child graph.
376pub type BlockBuilder<B> = DFGWrapper<B, BasicBlockID>;
377
378impl<B: AsMut<Hugr> + AsRef<Hugr>> BlockBuilder<B> {
379    /// Set the outputs of the block, with `branch_wire` carrying  the value of the
380    /// branch controlling Sum value.  `outputs` are the remaining outputs.
381    pub fn set_outputs(
382        &mut self,
383        branch_wire: Wire,
384        outputs: impl IntoIterator<Item = Wire>,
385    ) -> Result<(), BuildError> {
386        Dataflow::set_outputs(self, [branch_wire].into_iter().chain(outputs))
387    }
388    fn create(base: B, block_n: Node) -> Result<Self, BuildError> {
389        let block_op = base
390            .as_ref()
391            .get_optype(block_n)
392            .as_dataflow_block()
393            .unwrap();
394        let signature = block_op.inner_signature().into_owned();
395        let db = DFGBuilder::create_with_io(base, block_n, signature)?;
396        Ok(BlockBuilder::from_dfg_builder(db))
397    }
398
399    /// [Set outputs](BlockBuilder::set_outputs) and [finish](`BlockBuilder::finish_sub_container`).
400    pub fn finish_with_outputs(
401        mut self,
402        branch_wire: Wire,
403        outputs: impl IntoIterator<Item = Wire>,
404    ) -> Result<<Self as SubContainer>::ContainerHandle, BuildError>
405    where
406        Self: Sized,
407    {
408        self.set_outputs(branch_wire, outputs)?;
409        self.finish_sub_container()
410    }
411}
412
413impl BlockBuilder<Hugr> {
414    /// Initialize a [`DataflowBlock`] rooted HUGR builder.
415    /// Extension delta will be inferred.
416    pub fn new(
417        inputs: impl Into<TypeRow>,
418        sum_rows: impl IntoIterator<Item = TypeRow>,
419        other_outputs: impl Into<TypeRow>,
420    ) -> Result<Self, BuildError> {
421        Self::new_exts(inputs, sum_rows, other_outputs, TO_BE_INFERRED)
422    }
423
424    /// Initialize a [`DataflowBlock`] rooted HUGR builder.
425    /// `extension_delta` is explicitly specified; alternatively, [new](Self::new)
426    /// may be used to infer it.
427    pub fn new_exts(
428        inputs: impl Into<TypeRow>,
429        sum_rows: impl IntoIterator<Item = TypeRow>,
430        other_outputs: impl Into<TypeRow>,
431        extension_delta: impl Into<ExtensionSet>,
432    ) -> Result<Self, BuildError> {
433        let inputs = inputs.into();
434        let sum_rows: Vec<_> = sum_rows.into_iter().collect();
435        let other_outputs = other_outputs.into();
436        let op = DataflowBlock {
437            inputs: inputs.clone(),
438            other_outputs: other_outputs.clone(),
439            sum_rows,
440            extension_delta: extension_delta.into(),
441        };
442
443        let base = Hugr::new(op);
444        let root = base.root();
445        Self::create(base, root)
446    }
447
448    /// [Set outputs](BlockBuilder::set_outputs) and [finish_hugr](`BlockBuilder::finish_hugr`).
449    pub fn finish_hugr_with_outputs(
450        mut self,
451        branch_wire: Wire,
452        outputs: impl IntoIterator<Item = Wire>,
453    ) -> Result<Hugr, BuildError> {
454        self.set_outputs(branch_wire, outputs)?;
455        self.finish_hugr().map_err(BuildError::InvalidHUGR)
456    }
457}
458
#[cfg(test)]
pub(crate) mod test {
    use crate::builder::{DataflowSubContainer, ModuleBuilder};

    use crate::extension::prelude::usize_t;
    use crate::hugr::validate::InterGraphEdgeError;
    use crate::hugr::ValidationError;
    use crate::type_row;
    use cool_asserts::assert_matches;

    use super::*;

    /// The `usize -> usize` signature used by most CFGs and blocks below.
    fn usize_to_usize() -> Signature {
        Signature::new(vec![usize_t()], vec![usize_t()])
    }

    /// A CFG nested inside a function of a module builds and validates.
    #[test]
    fn basic_module_cfg() -> Result<(), BuildError> {
        let mut module_builder = ModuleBuilder::new();
        let mut func_builder = module_builder.define_function("main", usize_to_usize())?;
        let [int] = func_builder.input_wires_arr();

        let mut cfg_builder =
            func_builder.cfg_builder(vec![(usize_t(), int)], vec![usize_t()].into())?;
        build_basic_cfg(&mut cfg_builder)?;
        let cfg_id = cfg_builder.finish_sub_container()?;

        let _f_id = func_builder.finish_with_outputs(cfg_id.outputs())?;
        let build_result = module_builder.finish_hugr();

        assert!(build_result.is_ok(), "{}", build_result.unwrap_err());

        Ok(())
    }

    /// A CFG-rooted HUGR builds and validates.
    #[test]
    fn basic_cfg_hugr() -> Result<(), BuildError> {
        let mut cfg_builder = CFGBuilder::new(usize_to_usize())?;
        build_basic_cfg(&mut cfg_builder)?;
        assert_matches!(cfg_builder.finish_hugr(), Ok(_));

        Ok(())
    }

    /// Populate `cfg_builder` with an entry block that branches either to a
    /// middle block (branch 0) or straight to the exit (branch 1); the middle
    /// block leads to the exit.
    pub(crate) fn build_basic_cfg<T>(cfg_builder: &mut CFGBuilder<T>) -> Result<(), BuildError>
    where
        T: AsMut<Hugr> + AsRef<Hugr>,
    {
        let usize_row: TypeRow = vec![usize_t()].into();
        let sum2_variants = vec![usize_row.clone(), usize_row];

        // Entry block: wrap the input in variant 1 of the two-variant sum.
        let mut entry_b = cfg_builder.entry_builder_exts(
            sum2_variants.clone(),
            type_row![],
            ExtensionSet::new(),
        )?;
        let [entry_in] = entry_b.input_wires_arr();
        let sum = entry_b.make_sum(1, sum2_variants, [entry_in])?;
        let entry = entry_b.finish_with_outputs(sum, [])?;

        // Middle block: pass the input through with a unary branch choice.
        let mut middle_b = cfg_builder.simple_block_builder(usize_to_usize(), 1)?;
        let unit_choice = middle_b.add_load_const(ops::Value::unary_unit_sum());
        let [middle_in] = middle_b.input_wires_arr();
        let middle = middle_b.finish_with_outputs(unit_choice, [middle_in])?;

        let exit = cfg_builder.exit_block();
        cfg_builder.branch(&entry, 0, &middle)?;
        cfg_builder.branch(&middle, 0, &exit)?;
        cfg_builder.branch(&entry, 1, &exit)?;
        Ok(())
    }

    /// A wire from the entry block may be used in a later block.
    #[test]
    fn test_dom_edge() -> Result<(), BuildError> {
        let mut cfg_builder = CFGBuilder::new(usize_to_usize())?;
        let sum_tuple_const = cfg_builder.add_constant(ops::Value::unary_unit_sum());
        let sum_variants = vec![type_row![]];

        let mut entry_b = cfg_builder.entry_builder_exts(
            sum_variants.clone(),
            type_row![],
            ExtensionSet::new(),
        )?;
        let [inw] = entry_b.input_wires_arr();
        let entry_sum = entry_b.load_const(&sum_tuple_const);
        let entry = entry_b.finish_with_outputs(entry_sum, [])?;

        let mut middle_b =
            cfg_builder.simple_block_builder(Signature::new(type_row![], vec![usize_t()]), 1)?;
        let middle_sum = middle_b.load_const(&sum_tuple_const);
        // The middle block outputs the entry block's input wire.
        let middle = middle_b.finish_with_outputs(middle_sum, [inw])?;

        let exit = cfg_builder.exit_block();
        cfg_builder.branch(&entry, 0, &middle)?;
        cfg_builder.branch(&middle, 0, &exit)?;
        assert_matches!(cfg_builder.finish_hugr(), Ok(_));

        Ok(())
    }

    /// Using a wire from a block that does not dominate the user is rejected
    /// by validation.
    #[test]
    fn test_non_dom_edge() -> Result<(), BuildError> {
        let mut cfg_builder = CFGBuilder::new(usize_to_usize())?;
        let sum_tuple_const = cfg_builder.add_constant(ops::Value::unary_unit_sum());
        let sum_variants = vec![type_row![]];

        let mut middle_b = cfg_builder.simple_block_builder(usize_to_usize(), 1)?;
        let [inw] = middle_b.input_wires_arr();
        let middle_sum = middle_b.load_const(&sum_tuple_const);
        let middle = middle_b.finish_with_outputs(middle_sum, [inw])?;

        let mut entry_b =
            cfg_builder.entry_builder(sum_variants.clone(), vec![usize_t()].into())?;
        let entry_sum = entry_b.load_const(&sum_tuple_const);
        // entry block uses wire from middle block even though middle block
        // does not dominate entry
        let entry = entry_b.finish_with_outputs(entry_sum, [inw])?;

        let exit = cfg_builder.exit_block();
        cfg_builder.branch(&entry, 0, &middle)?;
        cfg_builder.branch(&middle, 0, &exit)?;
        assert_matches!(
            cfg_builder.finish_hugr(),
            Err(ValidationError::InterGraphEdgeError(
                InterGraphEdgeError::NonDominatedAncestor { .. }
            ))
        );

        Ok(())
    }
}