1use crate::extension::{ExtensionSet, TO_BE_INFERRED};
2use crate::ops::{self, DataflowOpTrait};
3
4use crate::hugr::views::HugrView;
5use crate::types::{Signature, TypeRow};
6use crate::{Hugr, Node};
7
8use super::handle::BuildHandle;
9use super::{
10 dataflow::{DFGBuilder, DFGWrapper},
11 BuildError, Container, Dataflow, TailLoopID, Wire,
12};
13
14pub type TailLoopBuilder<B> = DFGWrapper<B, BuildHandle<TailLoopID>>;
16
17impl<B: AsMut<Hugr> + AsRef<Hugr>> TailLoopBuilder<B> {
18 pub(super) fn create_with_io(
19 base: B,
20 loop_node: Node,
21 tail_loop: &ops::TailLoop,
22 ) -> Result<Self, BuildError> {
23 let signature = Signature::new(tail_loop.body_input_row(), tail_loop.body_output_row());
24 let dfg_build = DFGBuilder::create_with_io(base, loop_node, signature)?;
25
26 Ok(TailLoopBuilder::from_dfg_builder(dfg_build))
27 }
28 pub fn set_outputs(
31 &mut self,
32 out_variant: Wire,
33 rest: impl IntoIterator<Item = Wire>,
34 ) -> Result<(), BuildError> {
35 Dataflow::set_outputs(self, [out_variant].into_iter().chain(rest))
36 }
37
38 pub fn loop_signature(&self) -> Result<&ops::TailLoop, BuildError> {
41 self.hugr()
42 .get_optype(self.container_node())
43 .as_tail_loop()
44 .ok_or(BuildError::UnexpectedType {
45 node: self.container_node(),
46 op_desc: "crate::ops::TailLoop",
47 })
48 }
49
50 pub fn internal_output_row(&self) -> Result<TypeRow, BuildError> {
52 self.loop_signature().map(ops::TailLoop::body_output_row)
53 }
54
55 pub fn finish_with_outputs(
57 mut self,
58 out_variant: Wire,
59 rest: impl IntoIterator<Item = Wire>,
60 ) -> Result<BuildHandle<TailLoopID>, BuildError>
61 where
62 Self: Sized,
63 {
64 self.set_outputs(out_variant, rest)?;
65 Ok((
66 self.container_node(),
67 self.loop_signature()?.signature().output_count(),
68 )
69 .into())
70 }
71}
72
73impl TailLoopBuilder<Hugr> {
74 pub fn new(
77 just_inputs: impl Into<TypeRow>,
78 inputs_outputs: impl Into<TypeRow>,
79 just_outputs: impl Into<TypeRow>,
80 ) -> Result<Self, BuildError> {
81 Self::new_exts(just_inputs, inputs_outputs, just_outputs, TO_BE_INFERRED)
82 }
83
84 pub fn new_exts(
88 just_inputs: impl Into<TypeRow>,
89 inputs_outputs: impl Into<TypeRow>,
90 just_outputs: impl Into<TypeRow>,
91 extension_delta: impl Into<ExtensionSet>,
92 ) -> Result<Self, BuildError> {
93 let tail_loop = ops::TailLoop {
94 just_inputs: just_inputs.into(),
95 just_outputs: just_outputs.into(),
96 rest: inputs_outputs.into(),
97 extension_delta: extension_delta.into(),
98 };
99 let base = Hugr::new(tail_loop.clone());
100 let root = base.root();
101 Self::create_with_io(base, root, &tail_loop)
102 }
103}
104
#[cfg(test)]
mod test {
    use cool_asserts::assert_matches;

    use crate::{
        builder::{DataflowSubContainer, HugrBuilder, ModuleBuilder, SubContainer},
        extension::prelude::{bool_t, usize_t, ConstUsize, PRELUDE_ID},
        hugr::ValidationError,
        ops::Value,
        type_row,
        types::Signature,
    };

    use super::*;

    // A standalone loop that breaks straight away with a constant and passes
    // its single input through unchanged.
    #[test]
    fn basic_loop() -> Result<(), BuildError> {
        let mut builder =
            TailLoopBuilder::new_exts(vec![], vec![bool_t()], vec![usize_t()], PRELUDE_ID)?;
        let [carried] = builder.input_wires_arr();
        let one = builder.add_load_value(ConstUsize::new(1));

        let loop_op = builder.loop_signature()?.clone();
        let break_wire = builder.make_break(loop_op, [one])?;
        builder.set_outputs(break_wire, [carried])?;

        let build_result: Result<Hugr, ValidationError> = builder.finish_hugr();
        assert_matches!(build_result, Ok(_));
        Ok(())
    }

    // A loop nested in a function whose body is a conditional: case 0 continues
    // the loop, case 1 breaks out of it with a constant.
    #[test]
    fn loop_with_conditional() -> Result<(), BuildError> {
        let mut module_builder = ModuleBuilder::new();
        let main_signature = Signature::new(vec![bool_t()], vec![usize_t()]).with_prelude();
        let mut fbuild = module_builder.define_function("main", main_signature)?;
        let [fn_input] = fbuild.input_wires_arr();

        let mut loop_b = fbuild.tail_loop_builder(
            vec![(bool_t(), fn_input)],
            vec![],
            vec![usize_t()].into(),
        )?;
        let signature = loop_b.loop_signature()?.clone();
        let const_wire = loop_b.add_load_const(Value::true_val());
        let [loop_input] = loop_b.input_wires_arr();

        let output_row = loop_b.internal_output_row()?;
        let mut conditional_b = loop_b.conditional_builder(
            ([type_row![], type_row![]], const_wire),
            vec![(bool_t(), loop_input)],
            output_row,
        )?;

        let mut continue_case = conditional_b.case_builder(0)?;
        let [case_input] = continue_case.input_wires_arr();
        let continue_wire = continue_case.make_continue(signature.clone(), [case_input])?;
        continue_case.finish_with_outputs([continue_wire])?;

        let mut break_case = conditional_b.case_builder(1)?;
        let [_case_input] = break_case.input_wires_arr();
        let two = break_case.add_load_value(ConstUsize::new(2));
        let break_wire = break_case.make_break(signature, [two])?;
        break_case.finish_with_outputs([break_wire])?;

        let conditional_id = conditional_b.finish_sub_container()?;
        let loop_id = loop_b.finish_with_outputs(conditional_id.out_wire(0), [])?;
        let _fdef = fbuild.finish_with_outputs(loop_id.outputs())?;

        let build_result = module_builder.finish_hugr();
        assert_matches!(build_result, Ok(_));

        Ok(())
    }

    // A loop built from all-empty rows finishes with a zero-length output array.
    #[test]
    fn tailloop_output_arr() {
        let mut loop_builder =
            TailLoopBuilder::new(type_row![], type_row![], type_row![]).unwrap();
        let control_wire = loop_builder.add_load_value(Value::false_val());
        let handle = loop_builder.finish_with_outputs(control_wire, []).unwrap();
        let [] = handle.outputs_arr();
    }
}