hugr_core/builder/tail_loop.rs

1use crate::extension::{ExtensionSet, TO_BE_INFERRED};
2use crate::ops::{self, DataflowOpTrait};
3
4use crate::hugr::views::HugrView;
5use crate::types::{Signature, TypeRow};
6use crate::{Hugr, Node};
7
8use super::handle::BuildHandle;
9use super::{
10    dataflow::{DFGBuilder, DFGWrapper},
11    BuildError, Container, Dataflow, TailLoopID, Wire,
12};
13
14/// Builder for a [`ops::TailLoop`] node.
15pub type TailLoopBuilder<B> = DFGWrapper<B, BuildHandle<TailLoopID>>;
16
17impl<B: AsMut<Hugr> + AsRef<Hugr>> TailLoopBuilder<B> {
18    pub(super) fn create_with_io(
19        base: B,
20        loop_node: Node,
21        tail_loop: &ops::TailLoop,
22    ) -> Result<Self, BuildError> {
23        let signature = Signature::new(tail_loop.body_input_row(), tail_loop.body_output_row());
24        let dfg_build = DFGBuilder::create_with_io(base, loop_node, signature)?;
25
26        Ok(TailLoopBuilder::from_dfg_builder(dfg_build))
27    }
28    /// Set the outputs of the [`ops::TailLoop`], with `out_variant` as the value of the
29    /// termination Sum, and `rest` being the remaining outputs.
30    pub fn set_outputs(
31        &mut self,
32        out_variant: Wire,
33        rest: impl IntoIterator<Item = Wire>,
34    ) -> Result<(), BuildError> {
35        Dataflow::set_outputs(self, [out_variant].into_iter().chain(rest))
36    }
37
38    /// Get a reference to the [`ops::TailLoop`]
39    /// that defines the signature of the [`ops::TailLoop`]
40    pub fn loop_signature(&self) -> Result<&ops::TailLoop, BuildError> {
41        self.hugr()
42            .get_optype(self.container_node())
43            .as_tail_loop()
44            .ok_or(BuildError::UnexpectedType {
45                node: self.container_node(),
46                op_desc: "crate::ops::TailLoop",
47            })
48    }
49
50    /// The output types of the child graph, including the Sum as the first.
51    pub fn internal_output_row(&self) -> Result<TypeRow, BuildError> {
52        self.loop_signature().map(ops::TailLoop::body_output_row)
53    }
54
55    /// Set outputs and finish, see [`TailLoopBuilder::set_outputs`]
56    pub fn finish_with_outputs(
57        mut self,
58        out_variant: Wire,
59        rest: impl IntoIterator<Item = Wire>,
60    ) -> Result<BuildHandle<TailLoopID>, BuildError>
61    where
62        Self: Sized,
63    {
64        self.set_outputs(out_variant, rest)?;
65        Ok((
66            self.container_node(),
67            self.loop_signature()?.signature().output_count(),
68        )
69            .into())
70    }
71}
72
73impl TailLoopBuilder<Hugr> {
74    /// Initialize new builder for a [`ops::TailLoop`] rooted HUGR.
75    /// Extension delta will be inferred.
76    pub fn new(
77        just_inputs: impl Into<TypeRow>,
78        inputs_outputs: impl Into<TypeRow>,
79        just_outputs: impl Into<TypeRow>,
80    ) -> Result<Self, BuildError> {
81        Self::new_exts(just_inputs, inputs_outputs, just_outputs, TO_BE_INFERRED)
82    }
83
84    /// Initialize new builder for a [`ops::TailLoop`] rooted HUGR.
85    /// `extension_delta` is explicitly specified; alternatively, [new](Self::new)
86    /// may be used to infer it.
87    pub fn new_exts(
88        just_inputs: impl Into<TypeRow>,
89        inputs_outputs: impl Into<TypeRow>,
90        just_outputs: impl Into<TypeRow>,
91        extension_delta: impl Into<ExtensionSet>,
92    ) -> Result<Self, BuildError> {
93        let tail_loop = ops::TailLoop {
94            just_inputs: just_inputs.into(),
95            just_outputs: just_outputs.into(),
96            rest: inputs_outputs.into(),
97            extension_delta: extension_delta.into(),
98        };
99        let base = Hugr::new(tail_loop.clone());
100        let root = base.root();
101        Self::create_with_io(base, root, &tail_loop)
102    }
103}
104
#[cfg(test)]
mod test {
    use cool_asserts::assert_matches;

    use crate::builder::{DataflowSubContainer, HugrBuilder, ModuleBuilder, SubContainer};
    use crate::extension::prelude::{bool_t, usize_t, ConstUsize, PRELUDE_ID};
    use crate::hugr::ValidationError;
    use crate::ops::Value;
    use crate::type_row;
    use crate::types::Signature;

    use super::*;

    /// A standalone tail loop that carries a `bool` through and immediately
    /// breaks with a constant `usize`.
    #[test]
    fn basic_loop() -> Result<(), BuildError> {
        let mut loop_b =
            TailLoopBuilder::new_exts(vec![], vec![bool_t()], vec![usize_t()], PRELUDE_ID)?;
        let [carried] = loop_b.input_wires_arr();
        let one = loop_b.add_load_value(ConstUsize::new(1));

        let loop_sig = loop_b.loop_signature()?.clone();
        let break_wire = loop_b.make_break(loop_sig, [one])?;
        loop_b.set_outputs(break_wire, [carried])?;
        let build_result: Result<Hugr, ValidationError> = loop_b.finish_hugr();

        assert_matches!(build_result, Ok(_));
        Ok(())
    }

    /// A function whose body is a tail loop. Inside the loop, a `Conditional`
    /// driven by a constant `true` predicate chooses between continuing and
    /// breaking.
    #[test]
    fn loop_with_conditional() -> Result<(), BuildError> {
        let mut module_builder = ModuleBuilder::new();
        let mut fbuild = module_builder.define_function(
            "main",
            Signature::new(vec![bool_t()], vec![usize_t()]).with_prelude(),
        )?;
        let [fn_input] = fbuild.input_wires_arr();

        let mut loop_b = fbuild.tail_loop_builder(
            vec![(bool_t(), fn_input)],
            vec![],
            vec![usize_t()].into(),
        )?;
        let loop_sig = loop_b.loop_signature()?.clone();
        let predicate = loop_b.add_load_const(Value::true_val());
        let [loop_input] = loop_b.input_wires_arr();
        let output_row = loop_b.internal_output_row()?;

        let mut conditional_b = loop_b.conditional_builder(
            ([type_row![], type_row![]], predicate),
            vec![(bool_t(), loop_input)],
            output_row,
        )?;

        // Case 0: feed the bool straight back into the next iteration.
        let mut continue_case = conditional_b.case_builder(0)?;
        let [case_input] = continue_case.input_wires_arr();
        let continue_wire = continue_case.make_continue(loop_sig.clone(), [case_input])?;
        continue_case.finish_with_outputs([continue_wire])?;

        // Case 1: leave the loop with a constant `usize`.
        let mut break_case = conditional_b.case_builder(1)?;
        let [_ignored] = break_case.input_wires_arr();
        let two = break_case.add_load_value(ConstUsize::new(2));
        let break_wire = break_case.make_break(loop_sig, [two])?;
        break_case.finish_with_outputs([break_wire])?;

        let conditional_id = conditional_b.finish_sub_container()?;
        let loop_id = loop_b.finish_with_outputs(conditional_id.out_wire(0), [])?;
        let _fdef = fbuild.finish_with_outputs(loop_id.outputs())?;

        let build_result = module_builder.finish_hugr();
        assert_matches!(build_result, Ok(_));

        Ok(())
    }

    /// Regression test for issue 1257: calling `outputs_arr` on a finished
    /// `TailLoop` handle used to yield an OrderEdge "output wire".
    #[test]
    fn tailloop_output_arr() {
        let mut builder = TailLoopBuilder::new(type_row![], type_row![], type_row![]).unwrap();
        let control = builder.add_load_value(Value::false_val());
        let loop_handle = builder.finish_with_outputs(control, []).unwrap();
        let [] = loop_handle.outputs_arr();
    }
}