hugr_core/builder/tail_loop.rs

1use crate::ops::{self, DataflowOpTrait};
2
3use crate::hugr::views::HugrView;
4use crate::types::{Signature, TypeRow};
5use crate::{Hugr, Node};
6
7use super::handle::BuildHandle;
8use super::{
9    BuildError, Container, Dataflow, TailLoopID, Wire,
10    dataflow::{DFGBuilder, DFGWrapper},
11};
12
/// Builder for a [`ops::TailLoop`] node.
///
/// This is a [`DFGWrapper`] whose container is the loop node; the loop body is
/// built as a dataflow graph. Finishing the builder yields a
/// [`BuildHandle`] to the loop, typed by [`TailLoopID`].
pub type TailLoopBuilder<B> = DFGWrapper<B, BuildHandle<TailLoopID>>;
15
16impl<B: AsMut<Hugr> + AsRef<Hugr>> TailLoopBuilder<B> {
17    pub(super) fn create_with_io(
18        base: B,
19        loop_node: Node,
20        tail_loop: &ops::TailLoop,
21    ) -> Result<Self, BuildError> {
22        let signature = Signature::new(tail_loop.body_input_row(), tail_loop.body_output_row());
23        let dfg_build = DFGBuilder::create_with_io(base, loop_node, signature)?;
24
25        Ok(TailLoopBuilder::from_dfg_builder(dfg_build))
26    }
27    /// Set the outputs of the [`ops::TailLoop`], with `out_variant` as the value of the
28    /// termination Sum, and `rest` being the remaining outputs.
29    pub fn set_outputs(
30        &mut self,
31        out_variant: Wire,
32        rest: impl IntoIterator<Item = Wire>,
33    ) -> Result<(), BuildError> {
34        Dataflow::set_outputs(self, [out_variant].into_iter().chain(rest))
35    }
36
37    /// Get a reference to the [`ops::TailLoop`]
38    /// that defines the signature of the [`ops::TailLoop`]
39    pub fn loop_signature(&self) -> Result<&ops::TailLoop, BuildError> {
40        self.hugr()
41            .get_optype(self.container_node())
42            .as_tail_loop()
43            .ok_or(BuildError::UnexpectedType {
44                node: self.container_node(),
45                op_desc: "crate::ops::TailLoop",
46            })
47    }
48
49    /// The output types of the child graph, including the Sum as the first.
50    pub fn internal_output_row(&self) -> Result<TypeRow, BuildError> {
51        self.loop_signature().map(ops::TailLoop::body_output_row)
52    }
53
54    /// Set outputs and finish, see [`TailLoopBuilder::set_outputs`]
55    pub fn finish_with_outputs(
56        mut self,
57        out_variant: Wire,
58        rest: impl IntoIterator<Item = Wire>,
59    ) -> Result<BuildHandle<TailLoopID>, BuildError>
60    where
61        Self: Sized,
62    {
63        self.set_outputs(out_variant, rest)?;
64        Ok((
65            self.container_node(),
66            self.loop_signature()?.signature().output_count(),
67        )
68            .into())
69    }
70}
71
72impl TailLoopBuilder<Hugr> {
73    /// Initialize new builder for a [`ops::TailLoop`] rooted HUGR.
74    pub fn new(
75        just_inputs: impl Into<TypeRow>,
76        inputs_outputs: impl Into<TypeRow>,
77        just_outputs: impl Into<TypeRow>,
78    ) -> Result<Self, BuildError> {
79        let tail_loop = ops::TailLoop {
80            just_inputs: just_inputs.into(),
81            just_outputs: just_outputs.into(),
82            rest: inputs_outputs.into(),
83        };
84        let base = Hugr::new_with_entrypoint(tail_loop.clone())
85            .expect("tail_loop entrypoint should be valid");
86        let root = base.entrypoint();
87        Self::create_with_io(base, root, &tail_loop)
88    }
89}
90
// Unit tests for `TailLoopBuilder`: each test builds a HUGR containing a
// TailLoop and checks that building/validation succeeds.
#[cfg(test)]
mod test {
    use cool_asserts::assert_matches;

    use crate::extension::prelude::bool_t;
    use crate::{
        builder::{DataflowSubContainer, HugrBuilder, ModuleBuilder, SubContainer},
        extension::prelude::{ConstUsize, usize_t},
        hugr::ValidationError,
        ops::Value,
        type_row,
        types::Signature,
    };

    use super::*;
    // A TailLoop-rooted HUGR with no `just_inputs`, a `bool` threaded through
    // as `rest`, and a `usize` `just_outputs`: the body breaks immediately with
    // the constant 1 and passes the bool input straight through.
    #[test]
    fn basic_loop() -> Result<(), BuildError> {
        let build_result: Result<Hugr, ValidationError<_>> = {
            let mut loop_b = TailLoopBuilder::new(vec![], vec![bool_t()], vec![usize_t()])?;
            let [i1] = loop_b.input_wires_arr();
            let const_wire = loop_b.add_load_value(ConstUsize::new(1));

            let break_wire = loop_b.make_break(loop_b.loop_signature()?.clone(), [const_wire])?;
            loop_b.set_outputs(break_wire, [i1])?;
            loop_b.finish_hugr()
        };

        assert_matches!(build_result, Ok(_));
        Ok(())
    }

    // A function `main: bool -> usize` whose body is a TailLoop. Inside the loop
    // a Conditional (on a constant `true`) either continues with the bool
    // (case 0) or breaks with the constant 2 (case 1). The Conditional's output
    // row is the loop body's output row, so its single output is the loop's
    // termination Sum. Only validation is checked; the HUGR is not executed.
    #[test]
    fn loop_with_conditional() -> Result<(), BuildError> {
        let build_result = {
            let mut module_builder = ModuleBuilder::new();
            let mut fbuild = module_builder
                .define_function("main", Signature::new(vec![bool_t()], vec![usize_t()]))?;
            let _fdef = {
                let [b1] = fbuild.input_wires_arr();
                let loop_id = {
                    let mut loop_b = fbuild.tail_loop_builder(
                        vec![(bool_t(), b1)],
                        vec![],
                        vec![usize_t()].into(),
                    )?;
                    // Cloned so it can be used after `loop_b` is mutably borrowed
                    // by the conditional builder below.
                    let signature = loop_b.loop_signature()?.clone();
                    let const_wire = loop_b.add_load_const(Value::true_val());
                    let [b1] = loop_b.input_wires_arr();
                    let conditional_id = {
                        let output_row = loop_b.internal_output_row()?;
                        let mut conditional_b = loop_b.conditional_builder(
                            ([type_row![], type_row![]], const_wire),
                            vec![(bool_t(), b1)],
                            output_row,
                        )?;

                        // Case 0: loop again, feeding the bool back in.
                        let mut branch_0 = conditional_b.case_builder(0)?;
                        let [b1] = branch_0.input_wires_arr();

                        let continue_wire = branch_0.make_continue(signature.clone(), [b1])?;
                        branch_0.finish_with_outputs([continue_wire])?;

                        // Case 1: exit the loop with the constant 2.
                        let mut branch_1 = conditional_b.case_builder(1)?;
                        let [_b1] = branch_1.input_wires_arr();

                        let wire = branch_1.add_load_value(ConstUsize::new(2));
                        let break_wire = branch_1.make_break(signature, [wire])?;
                        branch_1.finish_with_outputs([break_wire])?;

                        conditional_b.finish_sub_container()?
                    };
                    loop_b.finish_with_outputs(conditional_id.out_wire(0), [])?
                };
                fbuild.finish_with_outputs(loop_id.outputs())?
            };
            module_builder.finish_hugr()
        };

        assert_matches!(build_result, Ok(_));

        Ok(())
    }

    #[test]
    // Regression test for issue 1257: `outputs_arr` on a finished TailLoop with
    // no value outputs used to yield an Order-edge "output wire"; it must now
    // yield no wires at all.
    fn tailloop_output_arr() {
        let mut builder = TailLoopBuilder::new(type_row![], type_row![], type_row![]).unwrap();
        let control = builder.add_load_value(Value::false_val());
        let tailloop = builder.finish_with_outputs(control, []).unwrap();
        let [] = tailloop.outputs_arr();
    }
}