1use std::borrow::Cow;
4
5use crate::extension::ExtensionSet;
6use crate::types::{EdgeKind, Signature, Type, TypeRow};
7use crate::Direction;
8
9use super::dataflow::{DataflowOpTrait, DataflowParent};
10use super::{impl_op_name, NamedOp, OpTrait, StaticTag};
11use super::{OpName, OpTag};
12
/// Tail-controlled loop operation.
///
/// The node's inputs are `just_inputs` followed by `rest`, and its outputs are
/// `just_outputs` followed by `rest` (see the `signature` implementation).
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct TailLoop {
    /// Types that appear among the loop's inputs but not its outputs; also the
    /// first variant (tag 0) of the sum produced by the loop body.
    pub just_inputs: TypeRow,
    /// Types that appear among the loop's outputs but not its inputs; also the
    /// second variant (tag 1) of the sum produced by the loop body.
    pub just_outputs: TypeRow,
    /// Types appended to both the inputs and the outputs of the loop.
    pub rest: TypeRow,
    /// Extension set attached to the loop's signature as its extension delta.
    pub extension_delta: ExtensionSet,
}

// Operation-name support is generated by `impl_op_name!`, which is defined
// outside this file.
impl_op_name!(TailLoop);
28
29impl DataflowOpTrait for TailLoop {
30 const TAG: OpTag = OpTag::TailLoop;
31
32 fn description(&self) -> &str {
33 "A tail-controlled loop"
34 }
35
36 fn signature(&self) -> Cow<'_, Signature> {
37 let [inputs, outputs] =
39 [&self.just_inputs, &self.just_outputs].map(|row| row.extend(self.rest.iter()));
40 Cow::Owned(
41 Signature::new(inputs, outputs).with_extension_delta(self.extension_delta.clone()),
42 )
43 }
44
45 fn substitute(&self, subst: &crate::types::Substitution) -> Self {
46 Self {
47 just_inputs: self.just_inputs.substitute(subst),
48 just_outputs: self.just_outputs.substitute(subst),
49 rest: self.rest.substitute(subst),
50 extension_delta: self.extension_delta.substitute(subst),
51 }
52 }
53}
54
55impl TailLoop {
56 pub const CONTINUE_TAG: usize = 0;
60
61 pub const BREAK_TAG: usize = 1;
65
66 pub(crate) fn body_output_row(&self) -> TypeRow {
68 let sum_type = Type::new_sum([self.just_inputs.clone(), self.just_outputs.clone()]);
69 let mut outputs = vec![sum_type];
70 outputs.extend_from_slice(&self.rest);
71 outputs.into()
72 }
73
74 pub(crate) fn body_input_row(&self) -> TypeRow {
76 self.just_inputs.extend(self.rest.iter())
77 }
78}
79
80impl DataflowParent for TailLoop {
81 fn inner_signature(&self) -> Cow<'_, Signature> {
82 Cow::Owned(
84 Signature::new(self.body_input_row(), self.body_output_row())
85 .with_extension_delta(self.extension_delta.clone()),
86 )
87 }
88}
89
/// Conditional operation.
///
/// Its first input is a sum type whose variants are `sum_rows`; the remaining
/// inputs are `other_inputs` (see the `signature` implementation).
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct Conditional {
    /// One row per variant of the sum type that forms the first input.
    pub sum_rows: Vec<TypeRow>,
    /// Inputs that follow the sum; they are also appended to each variant's
    /// row when computing a case's input row.
    pub other_inputs: TypeRow,
    /// Output types of the conditional node.
    pub outputs: TypeRow,
    /// Extension set attached to the node signature as its extension delta.
    pub extension_delta: ExtensionSet,
}
// Operation-name support is generated by `impl_op_name!`, which is defined
// outside this file.
impl_op_name!(Conditional);
104
105impl DataflowOpTrait for Conditional {
106 const TAG: OpTag = OpTag::Conditional;
107
108 fn description(&self) -> &str {
109 "HUGR conditional operation"
110 }
111
112 fn signature(&self) -> Cow<'_, Signature> {
113 let mut inputs = self.other_inputs.clone();
115 inputs
116 .to_mut()
117 .insert(0, Type::new_sum(self.sum_rows.clone()));
118 Cow::Owned(
119 Signature::new(inputs, self.outputs.clone())
120 .with_extension_delta(self.extension_delta.clone()),
121 )
122 }
123
124 fn substitute(&self, subst: &crate::types::Substitution) -> Self {
125 Self {
126 sum_rows: self.sum_rows.iter().map(|r| r.substitute(subst)).collect(),
127 other_inputs: self.other_inputs.substitute(subst),
128 outputs: self.outputs.substitute(subst),
129 extension_delta: self.extension_delta.substitute(subst),
130 }
131 }
132}
133
134impl Conditional {
135 pub(crate) fn case_input_row(&self, case: usize) -> Option<TypeRow> {
137 Some(self.sum_rows.get(case)?.extend(self.other_inputs.iter()))
138 }
139}
140
/// A dataflow node defined by a child CFG.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[allow(missing_docs)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct CFG {
    // The node's signature, returned as-is by `DataflowOpTrait::signature`.
    pub signature: Signature,
}

// Operation-name support is generated by `impl_op_name!`, which is defined
// outside this file.
impl_op_name!(CFG);
150
151impl DataflowOpTrait for CFG {
152 const TAG: OpTag = OpTag::Cfg;
153
154 fn description(&self) -> &str {
155 "A dataflow node defined by a child CFG"
156 }
157
158 fn signature(&self) -> Cow<'_, Signature> {
159 Cow::Borrowed(&self.signature)
160 }
161
162 fn substitute(&self, subst: &crate::types::Substitution) -> Self {
163 Self {
164 signature: self.signature.substitute(subst),
165 }
166 }
167}
168
/// A CFG basic block node whose contents form a dataflow graph.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
#[allow(missing_docs)]
pub struct DataflowBlock {
    // Input row of the block's inner dataflow signature.
    pub inputs: TypeRow,
    // Types output after the sum; they are appended to the chosen variant's
    // row to form a successor's input.
    pub other_outputs: TypeRow,
    // One row per variant of the sum that is the block's first inner output;
    // the block has one outgoing non-dataflow port per row.
    pub sum_rows: Vec<TypeRow>,
    // Extension set attached to the inner signature as its extension delta.
    pub extension_delta: ExtensionSet,
}
179
/// A CFG exit block node.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct ExitBlock {
    /// Row of types received by the exit block; returned by
    /// `BasicBlock::dataflow_input`.
    pub cfg_outputs: TypeRow,
}
188
189impl NamedOp for DataflowBlock {
190 fn name(&self) -> OpName {
191 "DataflowBlock".into()
192 }
193}
194
195impl NamedOp for ExitBlock {
196 fn name(&self) -> OpName {
197 "ExitBlock".into()
198 }
199}
200
impl StaticTag for DataflowBlock {
    /// Static tag for dataflow blocks; also returned by `OpTrait::tag`.
    const TAG: OpTag = OpTag::DataflowBlock;
}
204
impl StaticTag for ExitBlock {
    /// Static tag for exit blocks; also returned by `OpTrait::tag`.
    const TAG: OpTag = OpTag::BasicBlockExit;
}
208
209impl DataflowParent for DataflowBlock {
210 fn inner_signature(&self) -> Cow<'_, Signature> {
211 let sum_type = Type::new_sum(self.sum_rows.clone());
214 let mut node_outputs = vec![sum_type];
215 node_outputs.extend_from_slice(&self.other_outputs);
216 Cow::Owned(
217 Signature::new(self.inputs.clone(), TypeRow::from(node_outputs))
218 .with_extension_delta(self.extension_delta.clone()),
219 )
220 }
221}
222
223impl OpTrait for DataflowBlock {
224 fn description(&self) -> &str {
225 "A CFG basic block node"
226 }
227 fn tag(&self) -> OpTag {
229 Self::TAG
230 }
231
232 fn other_input(&self) -> Option<EdgeKind> {
233 Some(EdgeKind::ControlFlow)
234 }
235
236 fn other_output(&self) -> Option<EdgeKind> {
237 Some(EdgeKind::ControlFlow)
238 }
239
240 fn extension_delta(&self) -> ExtensionSet {
241 self.extension_delta.clone()
242 }
243
244 fn non_df_port_count(&self, dir: Direction) -> usize {
245 match dir {
246 Direction::Incoming => 1,
247 Direction::Outgoing => self.sum_rows.len(),
248 }
249 }
250
251 fn substitute(&self, subst: &crate::types::Substitution) -> Self {
252 Self {
253 inputs: self.inputs.substitute(subst),
254 other_outputs: self.other_outputs.substitute(subst),
255 sum_rows: self.sum_rows.iter().map(|r| r.substitute(subst)).collect(),
256 extension_delta: self.extension_delta.substitute(subst),
257 }
258 }
259}
260
261impl OpTrait for ExitBlock {
262 fn description(&self) -> &str {
263 "A CFG exit block node"
264 }
265 fn tag(&self) -> OpTag {
267 Self::TAG
268 }
269
270 fn other_input(&self) -> Option<EdgeKind> {
271 Some(EdgeKind::ControlFlow)
272 }
273
274 fn other_output(&self) -> Option<EdgeKind> {
275 Some(EdgeKind::ControlFlow)
276 }
277
278 fn non_df_port_count(&self, dir: Direction) -> usize {
279 match dir {
280 Direction::Incoming => 1,
281 Direction::Outgoing => 0,
282 }
283 }
284
285 fn substitute(&self, subst: &crate::types::Substitution) -> Self {
286 Self {
287 cfg_outputs: self.cfg_outputs.substitute(subst),
288 }
289 }
290}
291
/// Behaviour shared by the CFG block operations in this file
/// (`DataflowBlock` and `ExitBlock`).
pub trait BasicBlock {
    /// The row of types the block takes as dataflow input.
    fn dataflow_input(&self) -> &TypeRow;
}
297
298impl BasicBlock for DataflowBlock {
299 fn dataflow_input(&self) -> &TypeRow {
300 &self.inputs
301 }
302}
303impl DataflowBlock {
304 pub fn successor_input(&self, successor: usize) -> Option<TypeRow> {
307 Some(
308 self.sum_rows
309 .get(successor)?
310 .extend(self.other_outputs.iter()),
311 )
312 }
313}
314
315impl BasicBlock for ExitBlock {
316 fn dataflow_input(&self) -> &TypeRow {
317 &self.cfg_outputs
318 }
319}
320
/// A case node inside a conditional.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct Case {
    /// Signature of the case's inner dataflow graph; its `runtime_reqs` is
    /// also reported as the case's extension delta.
    pub signature: Signature,
}

// Operation-name support is generated by `impl_op_name!`, which is defined
// outside this file.
impl_op_name!(Case);
330
impl StaticTag for Case {
    /// Static tag for case nodes; also returned by `OpTrait::tag`.
    const TAG: OpTag = OpTag::Case;
}
334
335impl DataflowParent for Case {
336 fn inner_signature(&self) -> Cow<'_, Signature> {
337 Cow::Borrowed(&self.signature)
338 }
339}
340
341impl OpTrait for Case {
342 fn description(&self) -> &str {
343 "A case node inside a conditional"
344 }
345
346 fn extension_delta(&self) -> ExtensionSet {
347 self.signature.runtime_reqs.clone()
348 }
349
350 fn tag(&self) -> OpTag {
351 <Self as StaticTag>::TAG
352 }
353
354 fn substitute(&self, subst: &crate::types::Substitution) -> Self {
355 Self {
356 signature: self.signature.substitute(subst),
357 }
358 }
359}
360
361impl Case {
362 pub fn dataflow_input(&self) -> &TypeRow {
364 &self.signature.input
365 }
366
367 pub fn dataflow_output(&self) -> &TypeRow {
369 &self.signature.output
370 }
371}
372
// Tests that type-variable substitution propagates through the signatures of
// the control-flow operations defined in this file.
#[cfg(test)]
mod test {
    use crate::{
        extension::{
            prelude::{qb_t, usize_t, PRELUDE_ID},
            ExtensionSet,
        },
        ops::{Conditional, DataflowOpTrait, DataflowParent},
        types::{Signature, Substitution, Type, TypeArg, TypeBound, TypeRV},
    };

    use super::{DataflowBlock, TailLoop};

    // Substituting variable 0 with the qubit type and variable 1 with the
    // prelude extension must show up in the block's inner signature.
    #[test]
    fn test_subst_dataflow_block() {
        use crate::ops::OpTrait;
        let tv0 = Type::new_var_use(0, TypeBound::Any);
        let dfb = DataflowBlock {
            inputs: vec![usize_t(), tv0.clone()].into(),
            other_outputs: vec![tv0.clone()].into(),
            sum_rows: vec![usize_t().into(), vec![qb_t(), tv0.clone()].into()],
            extension_delta: ExtensionSet::type_var(1),
        };
        let dfb2 = dfb.substitute(&Substitution::new(&[
            qb_t().into(),
            TypeArg::Extensions {
                es: PRELUDE_ID.into(),
            },
        ]));
        // Expected sum: first variant unchanged, second variant with tv0 -> qubit.
        let st = Type::new_sum(vec![vec![usize_t()], vec![qb_t(); 2]]);
        assert_eq!(
            dfb2.inner_signature(),
            Signature::new(vec![usize_t(), qb_t()], vec![st, qb_t()])
                .with_extension_delta(PRELUDE_ID)
        );
    }

    // Substituting row variable 0 with three `usize` types and variable 1 with
    // the qubit type must show up in the conditional's node signature.
    #[test]
    fn test_subst_conditional() {
        let tv1 = Type::new_var_use(1, TypeBound::Any);
        let cond = Conditional {
            sum_rows: vec![usize_t().into(), tv1.clone().into()],
            other_inputs: vec![Type::new_tuple(TypeRV::new_row_var_use(0, TypeBound::Any))].into(),
            outputs: vec![usize_t(), tv1].into(),
            extension_delta: ExtensionSet::new(),
        };
        let cond2 = cond.substitute(&Substitution::new(&[
            TypeArg::Sequence {
                elems: vec![usize_t().into(); 3],
            },
            qb_t().into(),
        ]));
        // Expected sum over the two substituted variants.
        let st = Type::new_sum(vec![usize_t(), qb_t()]);
        assert_eq!(
            cond2.signature(),
            Signature::new(
                vec![st, Type::new_tuple(vec![usize_t(); 3])],
                vec![usize_t(), qb_t()]
            )
        );
    }

    // Substituting variable 0 with `usize` and variable 1 with the prelude
    // extension must show up in the loop's node signature, with `rest`
    // appended to both the inputs and the outputs.
    #[test]
    fn test_tail_loop() {
        let tv0 = Type::new_var_use(0, TypeBound::Copyable);
        let tail_loop = TailLoop {
            just_inputs: vec![qb_t(), tv0.clone()].into(),
            just_outputs: vec![tv0.clone(), qb_t()].into(),
            rest: vec![tv0.clone()].into(),
            extension_delta: ExtensionSet::type_var(1),
        };
        let tail2 = tail_loop.substitute(&Substitution::new(&[
            usize_t().into(),
            TypeArg::Extensions {
                es: PRELUDE_ID.into(),
            },
        ]));
        assert_eq!(
            tail2.signature(),
            Signature::new(
                vec![qb_t(), usize_t(), usize_t()],
                vec![usize_t(), qb_t(), usize_t()]
            )
            .with_extension_delta(PRELUDE_ID)
        );
    }
}