1use std::borrow::Cow;
4
5use crate::Direction;
6use crate::types::{EdgeKind, Signature, Type, TypeRow};
7
8use super::OpTag;
9use super::dataflow::{DataflowOpTrait, DataflowParent};
10use super::{OpTrait, StaticTag, impl_op_name};
11
/// A tail-controlled loop.
///
/// The node's inputs are `just_inputs` followed by `rest`; its outputs are
/// `just_outputs` followed by `rest`.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct TailLoop {
    /// Types that lead the node's inputs only. This is also the row carried by the
    /// control-sum variant at index `TailLoop::CONTINUE_TAG`.
    pub just_inputs: TypeRow,
    /// Types that lead the node's outputs only. This is also the row carried by the
    /// control-sum variant at index `TailLoop::BREAK_TAG`.
    pub just_outputs: TypeRow,
    /// Types appended to both the inputs and the outputs, of the node and of its body.
    pub rest: TypeRow,
}

impl_op_name!(TailLoop);
25
26impl DataflowOpTrait for TailLoop {
27 const TAG: OpTag = OpTag::TailLoop;
28
29 fn description(&self) -> &'static str {
30 "A tail-controlled loop"
31 }
32
33 fn signature(&self) -> Cow<'_, Signature> {
34 let [inputs, outputs] =
36 [&self.just_inputs, &self.just_outputs].map(|row| row.extend(self.rest.iter()));
37 Cow::Owned(Signature::new(inputs, outputs))
38 }
39
40 fn substitute(&self, subst: &crate::types::Substitution) -> Self {
41 Self {
42 just_inputs: self.just_inputs.substitute(subst),
43 just_outputs: self.just_outputs.substitute(subst),
44 rest: self.rest.substitute(subst),
45 }
46 }
47}
48
49impl TailLoop {
50 pub const CONTINUE_TAG: usize = 0;
54
55 pub const BREAK_TAG: usize = 1;
59
60 pub(crate) fn body_output_row(&self) -> TypeRow {
62 let mut outputs = vec![Type::new_sum(self.control_variants())];
63 outputs.extend_from_slice(&self.rest);
64 outputs.into()
65 }
66
67 pub(crate) fn control_variants(&self) -> [TypeRow; 2] {
69 [self.just_inputs.clone(), self.just_outputs.clone()]
70 }
71
72 pub(crate) fn body_input_row(&self) -> TypeRow {
74 self.just_inputs.extend(self.rest.iter())
75 }
76}
77
78impl DataflowParent for TailLoop {
79 fn inner_signature(&self) -> Cow<'_, Signature> {
80 Cow::Owned(Signature::new(
82 self.body_input_row(),
83 self.body_output_row(),
84 ))
85 }
86}
87
/// A conditional operation.
///
/// The node's first input is a sum type whose variants are `sum_rows`, followed by
/// `other_inputs`; its outputs are `outputs`. The input row of case `i` is
/// `sum_rows[i]` followed by `other_inputs`.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct Conditional {
    /// Variant rows of the sum-typed first input; row `i` starts the input of case `i`.
    pub sum_rows: Vec<TypeRow>,
    /// Inputs that follow the sum-typed first input; appended to every case's input row.
    pub other_inputs: TypeRow,
    /// Output types of the node.
    pub outputs: TypeRow,
}
impl_op_name!(Conditional);
100
101impl DataflowOpTrait for Conditional {
102 const TAG: OpTag = OpTag::Conditional;
103
104 fn description(&self) -> &'static str {
105 "HUGR conditional operation"
106 }
107
108 fn signature(&self) -> Cow<'_, Signature> {
109 let mut inputs = self.other_inputs.clone();
111 inputs
112 .to_mut()
113 .insert(0, Type::new_sum(self.sum_rows.clone()));
114 Cow::Owned(Signature::new(inputs, self.outputs.clone()))
115 }
116
117 fn substitute(&self, subst: &crate::types::Substitution) -> Self {
118 Self {
119 sum_rows: self.sum_rows.iter().map(|r| r.substitute(subst)).collect(),
120 other_inputs: self.other_inputs.substitute(subst),
121 outputs: self.outputs.substitute(subst),
122 }
123 }
124}
125
126impl Conditional {
127 pub(crate) fn case_input_row(&self, case: usize) -> Option<TypeRow> {
129 Some(self.sum_rows.get(case)?.extend(self.other_inputs.iter()))
130 }
131}
132
/// A dataflow node defined by a child control-flow graph (CFG).
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[allow(missing_docs)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct CFG {
    /// Dataflow signature of the node; returned as-is (borrowed) by `signature()`.
    pub signature: Signature,
}

impl_op_name!(CFG);
142
143impl DataflowOpTrait for CFG {
144 const TAG: OpTag = OpTag::Cfg;
145
146 fn description(&self) -> &'static str {
147 "A dataflow node defined by a child CFG"
148 }
149
150 fn signature(&self) -> Cow<'_, Signature> {
151 Cow::Borrowed(&self.signature)
152 }
153
154 fn substitute(&self, subst: &crate::types::Substitution) -> Self {
155 Self {
156 signature: self.signature.substitute(subst),
157 }
158 }
159}
160
/// A CFG basic block whose body is a dataflow graph.
///
/// Its inner signature is `inputs -> [Sum(sum_rows), other_outputs...]`, and it has
/// one outgoing control-flow port per entry of `sum_rows`.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
#[allow(missing_docs)]
pub struct DataflowBlock {
    /// Dataflow inputs of the block.
    pub inputs: TypeRow,
    /// Outputs that follow the sum-typed branch selector; appended to every successor's input row.
    pub other_outputs: TypeRow,
    /// Variant rows of the branch-selector sum; row `i` starts the input of successor `i`.
    pub sum_rows: Vec<TypeRow>,
}
170
/// The exit block of a CFG: one incoming control-flow port and no outgoing ones.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct ExitBlock {
    /// Output types of the enclosing CFG; these are the exit block's dataflow inputs.
    pub cfg_outputs: TypeRow,
}
179
impl_op_name!(DataflowBlock);
impl_op_name!(ExitBlock);

// Compile-time operation tags for the two kinds of CFG basic block.
impl StaticTag for DataflowBlock {
    const TAG: OpTag = OpTag::DataflowBlock;
}

impl StaticTag for ExitBlock {
    const TAG: OpTag = OpTag::BasicBlockExit;
}
190
191impl DataflowParent for DataflowBlock {
192 fn inner_signature(&self) -> Cow<'_, Signature> {
193 let sum_type = Type::new_sum(self.sum_rows.clone());
196 let mut node_outputs = vec![sum_type];
197 node_outputs.extend_from_slice(&self.other_outputs);
198 Cow::Owned(Signature::new(
199 self.inputs.clone(),
200 TypeRow::from(node_outputs),
201 ))
202 }
203}
204
205impl OpTrait for DataflowBlock {
206 fn description(&self) -> &'static str {
207 "A CFG basic block node"
208 }
209 fn tag(&self) -> OpTag {
211 Self::TAG
212 }
213
214 fn other_input(&self) -> Option<EdgeKind> {
215 Some(EdgeKind::ControlFlow)
216 }
217
218 fn other_output(&self) -> Option<EdgeKind> {
219 Some(EdgeKind::ControlFlow)
220 }
221
222 fn non_df_port_count(&self, dir: Direction) -> usize {
223 match dir {
224 Direction::Incoming => 1,
225 Direction::Outgoing => self.sum_rows.len(),
226 }
227 }
228
229 fn substitute(&self, subst: &crate::types::Substitution) -> Self {
230 Self {
231 inputs: self.inputs.substitute(subst),
232 other_outputs: self.other_outputs.substitute(subst),
233 sum_rows: self.sum_rows.iter().map(|r| r.substitute(subst)).collect(),
234 }
235 }
236}
237
238impl OpTrait for ExitBlock {
239 fn description(&self) -> &'static str {
240 "A CFG exit block node"
241 }
242 fn tag(&self) -> OpTag {
244 Self::TAG
245 }
246
247 fn other_input(&self) -> Option<EdgeKind> {
248 Some(EdgeKind::ControlFlow)
249 }
250
251 fn other_output(&self) -> Option<EdgeKind> {
252 Some(EdgeKind::ControlFlow)
253 }
254
255 fn non_df_port_count(&self, dir: Direction) -> usize {
256 match dir {
257 Direction::Incoming => 1,
258 Direction::Outgoing => 0,
259 }
260 }
261
262 fn substitute(&self, subst: &crate::types::Substitution) -> Self {
263 Self {
264 cfg_outputs: self.cfg_outputs.substitute(subst),
265 }
266 }
267}
268
/// Common interface of CFG basic-block operations
/// (implemented in this file by `DataflowBlock` and `ExitBlock`).
pub trait BasicBlock {
    /// The dataflow types the block receives as input.
    fn dataflow_input(&self) -> &TypeRow;
}
274
impl BasicBlock for DataflowBlock {
    /// A dataflow block's input is its `inputs` row.
    fn dataflow_input(&self) -> &TypeRow {
        &self.inputs
    }
}
280impl DataflowBlock {
281 #[must_use]
284 pub fn successor_input(&self, successor: usize) -> Option<TypeRow> {
285 Some(
286 self.sum_rows
287 .get(successor)?
288 .extend(self.other_outputs.iter()),
289 )
290 }
291}
292
impl BasicBlock for ExitBlock {
    /// The exit block receives the enclosing CFG's outputs.
    fn dataflow_input(&self) -> &TypeRow {
        &self.cfg_outputs
    }
}
298
/// A case node inside a conditional.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct Case {
    /// Signature of the case's body; returned (borrowed) as its inner signature.
    pub signature: Signature,
}

impl_op_name!(Case);
308
impl StaticTag for Case {
    const TAG: OpTag = OpTag::Case;
}

impl DataflowParent for Case {
    /// A case's inner signature is its stored `signature`, borrowed without copying.
    fn inner_signature(&self) -> Cow<'_, Signature> {
        Cow::Borrowed(&self.signature)
    }
}
318
319impl OpTrait for Case {
320 fn description(&self) -> &'static str {
321 "A case node inside a conditional"
322 }
323
324 fn tag(&self) -> OpTag {
325 <Self as StaticTag>::TAG
326 }
327
328 fn substitute(&self, subst: &crate::types::Substitution) -> Self {
329 Self {
330 signature: self.signature.substitute(subst),
331 }
332 }
333}
334
impl Case {
    /// Input types of the case's signature.
    #[must_use]
    pub fn dataflow_input(&self) -> &TypeRow {
        &self.signature.input
    }

    /// Output types of the case's signature.
    #[must_use]
    pub fn dataflow_output(&self) -> &TypeRow {
        &self.signature.output
    }
}
348
#[cfg(test)]
mod test {
    use crate::{
        extension::prelude::{qb_t, usize_t},
        ops::{Conditional, DataflowOpTrait, DataflowParent},
        types::{Signature, Substitution, Type, TypeArg, TypeBound, TypeRV, TypeRow},
    };

    use super::{DataflowBlock, TailLoop};

    #[test]
    fn test_subst_dataflow_block() {
        use crate::ops::OpTrait;
        let tv0 = Type::new_var_use(0, TypeBound::Linear);
        let dfb = DataflowBlock {
            inputs: vec![usize_t(), tv0.clone()].into(),
            other_outputs: vec![tv0.clone()].into(),
            sum_rows: vec![usize_t().into(), vec![qb_t(), tv0.clone()].into()],
        };
        let dfb2 = dfb.substitute(&Substitution::new(&[qb_t().into()]));
        let st = Type::new_sum(vec![vec![usize_t()], vec![qb_t(); 2]]);
        assert_eq!(
            dfb2.inner_signature(),
            Signature::new(vec![usize_t(), qb_t()], vec![st, qb_t()])
        );
    }

    #[test]
    fn test_subst_conditional() {
        let tv1 = Type::new_var_use(1, TypeBound::Linear);
        let cond = Conditional {
            sum_rows: vec![usize_t().into(), tv1.clone().into()],
            other_inputs: vec![Type::new_tuple(TypeRV::new_row_var_use(
                0,
                TypeBound::Linear,
            ))]
            .into(),
            outputs: vec![usize_t(), tv1].into(),
        };
        let cond2 = cond.substitute(&Substitution::new(&[
            TypeArg::new_list([usize_t().into(), usize_t().into(), usize_t().into()]),
            qb_t().into(),
        ]));
        // Both variants of the substituted sum are single-element rows.
        let st = Type::new_sum(vec![usize_t(), qb_t()]);
        assert_eq!(
            cond2.signature(),
            Signature::new(
                vec![st, Type::new_tuple(vec![usize_t(); 3])],
                vec![usize_t(), qb_t()]
            )
        );
    }

    #[test]
    fn test_tail_loop() {
        let tv0 = Type::new_var_use(0, TypeBound::Copyable);
        let tail_loop = TailLoop {
            just_inputs: vec![qb_t(), tv0.clone()].into(),
            just_outputs: vec![tv0.clone(), qb_t()].into(),
            rest: vec![tv0.clone()].into(),
        };
        let tail2 = tail_loop.substitute(&Substitution::new(&[usize_t().into()]));
        assert_eq!(
            tail2.signature(),
            Signature::new(
                vec![qb_t(), usize_t(), usize_t()],
                vec![usize_t(), qb_t(), usize_t()]
            )
        );
    }

    /// The loop body takes `just_inputs ++ rest` and yields the control sum followed by `rest`.
    #[test]
    fn test_tail_loop_inner_signature() {
        let tail_loop = TailLoop {
            just_inputs: vec![qb_t()].into(),
            just_outputs: vec![usize_t()].into(),
            rest: vec![qb_t(), usize_t()].into(),
        };
        let control = Type::new_sum(vec![vec![qb_t()], vec![usize_t()]]);
        assert_eq!(
            tail_loop.inner_signature(),
            Signature::new(
                vec![qb_t(), qb_t(), usize_t()],
                vec![control, qb_t(), usize_t()]
            )
        );
    }

    /// Case `i` receives variant `i`'s row followed by the shared inputs;
    /// out-of-range cases yield `None`.
    #[test]
    fn test_case_input_row() {
        let cond = Conditional {
            sum_rows: vec![usize_t().into(), vec![qb_t(), usize_t()].into()],
            other_inputs: vec![qb_t()].into(),
            outputs: vec![usize_t()].into(),
        };
        assert_eq!(
            cond.case_input_row(0),
            Some(TypeRow::from(vec![usize_t(), qb_t()]))
        );
        assert_eq!(
            cond.case_input_row(1),
            Some(TypeRow::from(vec![qb_t(), usize_t(), qb_t()]))
        );
        assert_eq!(cond.case_input_row(2), None);
    }

    /// Successor `i` receives variant `i`'s row followed by `other_outputs`,
    /// and there is one outgoing control-flow port per successor.
    #[test]
    fn test_successor_input() {
        use crate::Direction;
        use crate::ops::OpTrait;
        let dfb = DataflowBlock {
            inputs: vec![usize_t()].into(),
            other_outputs: vec![qb_t()].into(),
            sum_rows: vec![usize_t().into(), vec![qb_t(), qb_t()].into()],
        };
        assert_eq!(
            dfb.successor_input(0),
            Some(TypeRow::from(vec![usize_t(), qb_t()]))
        );
        assert_eq!(dfb.successor_input(1), Some(TypeRow::from(vec![qb_t(); 3])));
        assert_eq!(dfb.successor_input(2), None);
        assert_eq!(dfb.non_df_port_count(Direction::Incoming), 1);
        assert_eq!(dfb.non_df_port_count(Direction::Outgoing), 2);
    }
}