hugr_core/builder/
tail_loop.rs1use crate::ops::{self, DataflowOpTrait};
2
3use crate::hugr::views::HugrView;
4use crate::types::{Signature, TypeRow};
5use crate::{Hugr, Node};
6
7use super::handle::BuildHandle;
8use super::{
9 BuildError, Container, Dataflow, TailLoopID, Wire,
10 dataflow::{DFGBuilder, DFGWrapper},
11};
12
/// Builder for the body of an [`ops::TailLoop`] operation.
///
/// This is a [`DFGWrapper`] around the base builder `B`; finishing it yields a
/// [`BuildHandle<TailLoopID>`].
pub type TailLoopBuilder<B> = DFGWrapper<B, BuildHandle<TailLoopID>>;
15
16impl<B: AsMut<Hugr> + AsRef<Hugr>> TailLoopBuilder<B> {
17 pub(super) fn create_with_io(
18 base: B,
19 loop_node: Node,
20 tail_loop: &ops::TailLoop,
21 ) -> Result<Self, BuildError> {
22 let signature = Signature::new(tail_loop.body_input_row(), tail_loop.body_output_row());
23 let dfg_build = DFGBuilder::create_with_io(base, loop_node, signature)?;
24
25 Ok(TailLoopBuilder::from_dfg_builder(dfg_build))
26 }
27 pub fn set_outputs(
30 &mut self,
31 out_variant: Wire,
32 rest: impl IntoIterator<Item = Wire>,
33 ) -> Result<(), BuildError> {
34 Dataflow::set_outputs(self, [out_variant].into_iter().chain(rest))
35 }
36
37 pub fn loop_signature(&self) -> Result<&ops::TailLoop, BuildError> {
40 self.hugr()
41 .get_optype(self.container_node())
42 .as_tail_loop()
43 .ok_or(BuildError::UnexpectedType {
44 node: self.container_node(),
45 op_desc: "crate::ops::TailLoop",
46 })
47 }
48
49 pub fn internal_output_row(&self) -> Result<TypeRow, BuildError> {
51 self.loop_signature().map(ops::TailLoop::body_output_row)
52 }
53
54 pub fn finish_with_outputs(
56 mut self,
57 out_variant: Wire,
58 rest: impl IntoIterator<Item = Wire>,
59 ) -> Result<BuildHandle<TailLoopID>, BuildError>
60 where
61 Self: Sized,
62 {
63 self.set_outputs(out_variant, rest)?;
64 Ok((
65 self.container_node(),
66 self.loop_signature()?.signature().output_count(),
67 )
68 .into())
69 }
70}
71
72impl TailLoopBuilder<Hugr> {
73 pub fn new(
75 just_inputs: impl Into<TypeRow>,
76 inputs_outputs: impl Into<TypeRow>,
77 just_outputs: impl Into<TypeRow>,
78 ) -> Result<Self, BuildError> {
79 let tail_loop = ops::TailLoop {
80 just_inputs: just_inputs.into(),
81 just_outputs: just_outputs.into(),
82 rest: inputs_outputs.into(),
83 };
84 let base = Hugr::new_with_entrypoint(tail_loop.clone())
85 .expect("tail_loop entrypoint should be valid");
86 let root = base.entrypoint();
87 Self::create_with_io(base, root, &tail_loop)
88 }
89}
90
#[cfg(test)]
mod test {
    use cool_asserts::assert_matches;

    use crate::{
        builder::{DataflowSubContainer, HugrBuilder, ModuleBuilder, SubContainer},
        extension::prelude::{ConstUsize, bool_t, usize_t},
        hugr::ValidationError,
        ops::Value,
        type_row,
        types::Signature,
    };

    use super::*;

    /// A standalone loop whose body breaks with a constant; finishing the Hugr must succeed.
    #[test]
    fn basic_loop() -> Result<(), BuildError> {
        let mut builder = TailLoopBuilder::new(vec![], vec![bool_t()], vec![usize_t()])?;
        let [carried] = builder.input_wires_arr();
        let one = builder.add_load_value(ConstUsize::new(1));

        let loop_op = builder.loop_signature()?.clone();
        let break_wire = builder.make_break(loop_op, [one])?;
        builder.set_outputs(break_wire, [carried])?;

        let build_result: Result<Hugr, ValidationError<_>> = builder.finish_hugr();
        assert_matches!(build_result, Ok(_));
        Ok(())
    }

    /// A loop nested in a function, whose body is a two-case conditional:
    /// case 0 continues the loop and case 1 breaks out of it.
    #[test]
    fn loop_with_conditional() -> Result<(), BuildError> {
        let mut module_builder = ModuleBuilder::new();
        let mut fbuild = module_builder
            .define_function("main", Signature::new(vec![bool_t()], vec![usize_t()]))?;
        let [fn_input] = fbuild.input_wires_arr();

        let mut loop_b = fbuild.tail_loop_builder(
            vec![(bool_t(), fn_input)],
            vec![],
            vec![usize_t()].into(),
        )?;
        let loop_op = loop_b.loop_signature()?.clone();
        let predicate = loop_b.add_load_const(Value::true_val());
        let [loop_input] = loop_b.input_wires_arr();
        let body_outputs = loop_b.internal_output_row()?;

        let conditional_id = {
            let mut conditional_b = loop_b.conditional_builder(
                ([type_row![], type_row![]], predicate),
                vec![(bool_t(), loop_input)],
                body_outputs,
            )?;

            // Case 0: feed the carried value back into the loop.
            let mut continue_case = conditional_b.case_builder(0)?;
            let [carried] = continue_case.input_wires_arr();
            let continue_wire = continue_case.make_continue(loop_op.clone(), [carried])?;
            continue_case.finish_with_outputs([continue_wire])?;

            // Case 1: ignore the carried value and break with a constant.
            let mut break_case = conditional_b.case_builder(1)?;
            let [_carried] = break_case.input_wires_arr();
            let two = break_case.add_load_value(ConstUsize::new(2));
            let break_wire = break_case.make_break(loop_op, [two])?;
            break_case.finish_with_outputs([break_wire])?;

            conditional_b.finish_sub_container()?
        };

        let loop_id = loop_b.finish_with_outputs(conditional_id.out_wire(0), [])?;
        let _fdef = fbuild.finish_with_outputs(loop_id.outputs())?;

        let build_result = module_builder.finish_hugr();
        assert_matches!(build_result, Ok(_));

        Ok(())
    }

    /// A loop built from three empty rows; its handle's outputs are destructured
    /// as an empty array.
    #[test]
    fn tailloop_output_arr() {
        let mut loop_builder =
            TailLoopBuilder::new(type_row![], type_row![], type_row![]).unwrap();
        let control_wire = loop_builder.add_load_value(Value::false_val());
        let handle = loop_builder
            .finish_with_outputs(control_wire, [])
            .unwrap();
        let [] = handle.outputs_arr();
    }
}