1use std::borrow::Cow;
4
5use crate::extension::ExtensionSet;
6use crate::types::{EdgeKind, Signature, Type, TypeRow};
7use crate::Direction;
8
9use super::dataflow::{DataflowOpTrait, DataflowParent};
10use super::{impl_op_name, NamedOp, OpTrait, StaticTag};
11use super::{OpName, OpTag};
12
/// A tail-controlled loop.
///
/// The node's own signature is `just_inputs + rest -> just_outputs + rest`.
/// The loop body additionally produces a sum over `[just_inputs, just_outputs]`
/// ahead of `rest` to choose between continuing and breaking.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct TailLoop {
    /// Types that are only inputs of the loop (not outputs).
    pub just_inputs: TypeRow,
    /// Types that are only outputs of the loop (not inputs).
    pub just_outputs: TypeRow,
    /// Types appended to both the inputs and the outputs of the loop.
    pub rest: TypeRow,
    /// Extension requirements, attached both to the loop's signature and to
    /// the signature of its body.
    pub extension_delta: ExtensionSet,
}

impl_op_name!(TailLoop);
28
29impl DataflowOpTrait for TailLoop {
30    const TAG: OpTag = OpTag::TailLoop;
31
32    fn description(&self) -> &str {
33        "A tail-controlled loop"
34    }
35
36    fn signature(&self) -> Cow<'_, Signature> {
37        let [inputs, outputs] =
39            [&self.just_inputs, &self.just_outputs].map(|row| row.extend(self.rest.iter()));
40        Cow::Owned(
41            Signature::new(inputs, outputs).with_extension_delta(self.extension_delta.clone()),
42        )
43    }
44
45    fn substitute(&self, subst: &crate::types::Substitution) -> Self {
46        Self {
47            just_inputs: self.just_inputs.substitute(subst),
48            just_outputs: self.just_outputs.substitute(subst),
49            rest: self.rest.substitute(subst),
50            extension_delta: self.extension_delta.substitute(subst),
51        }
52    }
53}
54
55impl TailLoop {
56    pub const CONTINUE_TAG: usize = 0;
60
61    pub const BREAK_TAG: usize = 1;
65
66    pub(crate) fn body_output_row(&self) -> TypeRow {
68        let sum_type = Type::new_sum([self.just_inputs.clone(), self.just_outputs.clone()]);
69        let mut outputs = vec![sum_type];
70        outputs.extend_from_slice(&self.rest);
71        outputs.into()
72    }
73
74    pub(crate) fn body_input_row(&self) -> TypeRow {
76        self.just_inputs.extend(self.rest.iter())
77    }
78}
79
80impl DataflowParent for TailLoop {
81    fn inner_signature(&self) -> Cow<'_, Signature> {
82        Cow::Owned(
84            Signature::new(self.body_input_row(), self.body_output_row())
85                .with_extension_delta(self.extension_delta.clone()),
86        )
87    }
88}
89
/// A conditional operation.
///
/// The first input is a sum whose variant rows are `sum_rows`, and it is
/// followed by `other_inputs`. Case `i` receives `sum_rows[i]` followed by
/// `other_inputs`.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct Conditional {
    /// Row of each variant of the sum-typed first input, one per case.
    pub sum_rows: Vec<TypeRow>,
    /// Remaining input types, following the sum input.
    pub other_inputs: TypeRow,
    /// Output types of the conditional.
    pub outputs: TypeRow,
    /// Extension requirements, attached to the operation's signature.
    pub extension_delta: ExtensionSet,
}
impl_op_name!(Conditional);
104
105impl DataflowOpTrait for Conditional {
106    const TAG: OpTag = OpTag::Conditional;
107
108    fn description(&self) -> &str {
109        "HUGR conditional operation"
110    }
111
112    fn signature(&self) -> Cow<'_, Signature> {
113        let mut inputs = self.other_inputs.clone();
115        inputs
116            .to_mut()
117            .insert(0, Type::new_sum(self.sum_rows.clone()));
118        Cow::Owned(
119            Signature::new(inputs, self.outputs.clone())
120                .with_extension_delta(self.extension_delta.clone()),
121        )
122    }
123
124    fn substitute(&self, subst: &crate::types::Substitution) -> Self {
125        Self {
126            sum_rows: self.sum_rows.iter().map(|r| r.substitute(subst)).collect(),
127            other_inputs: self.other_inputs.substitute(subst),
128            outputs: self.outputs.substitute(subst),
129            extension_delta: self.extension_delta.substitute(subst),
130        }
131    }
132}
133
134impl Conditional {
135    pub(crate) fn case_input_row(&self, case: usize) -> Option<TypeRow> {
137        Some(self.sum_rows.get(case)?.extend(self.other_inputs.iter()))
138    }
139}
140
/// A dataflow node defined by a child control-flow graph.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[allow(missing_docs)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct CFG {
    /// The node's dataflow signature, returned as-is by `signature()`.
    pub signature: Signature,
}

impl_op_name!(CFG);
150
151impl DataflowOpTrait for CFG {
152    const TAG: OpTag = OpTag::Cfg;
153
154    fn description(&self) -> &str {
155        "A dataflow node defined by a child CFG"
156    }
157
158    fn signature(&self) -> Cow<'_, Signature> {
159        Cow::Borrowed(&self.signature)
160    }
161
162    fn substitute(&self, subst: &crate::types::Substitution) -> Self {
163        Self {
164            signature: self.signature.substitute(subst),
165        }
166    }
167}
168
/// A basic block of a control-flow graph whose contents are a dataflow graph.
///
/// The inner dataflow signature is `inputs -> [Sum(sum_rows), other_outputs...]`.
/// The block has one outgoing control-flow port per entry of `sum_rows`.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
#[allow(missing_docs)]
pub struct DataflowBlock {
    /// Dataflow inputs of the block.
    pub inputs: TypeRow,
    /// Outputs following the control sum; also appended to every successor's input row.
    pub other_outputs: TypeRow,
    /// Variant rows of the control sum output, one per successor.
    pub sum_rows: Vec<TypeRow>,
    /// Extension requirements of the block's contents.
    pub extension_delta: ExtensionSet,
}
179
/// The exit block of a control-flow graph.
///
/// It has one incoming and no outgoing control-flow ports.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct ExitBlock {
    /// Dataflow input row of the exit block; as the name indicates, these are
    /// the outputs of the enclosing CFG.
    pub cfg_outputs: TypeRow,
}
188
189impl NamedOp for DataflowBlock {
190    fn name(&self) -> OpName {
191        "DataflowBlock".into()
192    }
193}
194
195impl NamedOp for ExitBlock {
196    fn name(&self) -> OpName {
197        "ExitBlock".into()
198    }
199}
200
impl StaticTag for DataflowBlock {
    /// Tag identifying dataflow basic blocks.
    const TAG: OpTag = OpTag::DataflowBlock;
}

impl StaticTag for ExitBlock {
    /// Tag identifying CFG exit blocks.
    const TAG: OpTag = OpTag::BasicBlockExit;
}
208
209impl DataflowParent for DataflowBlock {
210    fn inner_signature(&self) -> Cow<'_, Signature> {
211        let sum_type = Type::new_sum(self.sum_rows.clone());
214        let mut node_outputs = vec![sum_type];
215        node_outputs.extend_from_slice(&self.other_outputs);
216        Cow::Owned(
217            Signature::new(self.inputs.clone(), TypeRow::from(node_outputs))
218                .with_extension_delta(self.extension_delta.clone()),
219        )
220    }
221}
222
223impl OpTrait for DataflowBlock {
224    fn description(&self) -> &str {
225        "A CFG basic block node"
226    }
227    fn tag(&self) -> OpTag {
229        Self::TAG
230    }
231
232    fn other_input(&self) -> Option<EdgeKind> {
233        Some(EdgeKind::ControlFlow)
234    }
235
236    fn other_output(&self) -> Option<EdgeKind> {
237        Some(EdgeKind::ControlFlow)
238    }
239
240    fn extension_delta(&self) -> ExtensionSet {
241        self.extension_delta.clone()
242    }
243
244    fn non_df_port_count(&self, dir: Direction) -> usize {
245        match dir {
246            Direction::Incoming => 1,
247            Direction::Outgoing => self.sum_rows.len(),
248        }
249    }
250
251    fn substitute(&self, subst: &crate::types::Substitution) -> Self {
252        Self {
253            inputs: self.inputs.substitute(subst),
254            other_outputs: self.other_outputs.substitute(subst),
255            sum_rows: self.sum_rows.iter().map(|r| r.substitute(subst)).collect(),
256            extension_delta: self.extension_delta.substitute(subst),
257        }
258    }
259}
260
261impl OpTrait for ExitBlock {
262    fn description(&self) -> &str {
263        "A CFG exit block node"
264    }
265    fn tag(&self) -> OpTag {
267        Self::TAG
268    }
269
270    fn other_input(&self) -> Option<EdgeKind> {
271        Some(EdgeKind::ControlFlow)
272    }
273
274    fn other_output(&self) -> Option<EdgeKind> {
275        Some(EdgeKind::ControlFlow)
276    }
277
278    fn non_df_port_count(&self, dir: Direction) -> usize {
279        match dir {
280            Direction::Incoming => 1,
281            Direction::Outgoing => 0,
282        }
283    }
284
285    fn substitute(&self, subst: &crate::types::Substitution) -> Self {
286        Self {
287            cfg_outputs: self.cfg_outputs.substitute(subst),
288        }
289    }
290}
291
/// Operations that can act as basic blocks of a CFG.
///
/// In this module it is implemented by [`DataflowBlock`] and [`ExitBlock`].
pub trait BasicBlock {
    /// The dataflow types entering the block.
    fn dataflow_input(&self) -> &TypeRow;
}
297
298impl BasicBlock for DataflowBlock {
299    fn dataflow_input(&self) -> &TypeRow {
300        &self.inputs
301    }
302}
303impl DataflowBlock {
304    pub fn successor_input(&self, successor: usize) -> Option<TypeRow> {
307        Some(
308            self.sum_rows
309                .get(successor)?
310                .extend(self.other_outputs.iter()),
311        )
312    }
313}
314
315impl BasicBlock for ExitBlock {
316    fn dataflow_input(&self) -> &TypeRow {
317        &self.cfg_outputs
318    }
319}
320
/// A case node inside a conditional.
///
/// Its dataflow contents are described by `signature`.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct Case {
    /// Signature of the case's inner dataflow graph. Its `runtime_reqs` are
    /// reported as the case's extension delta.
    pub signature: Signature,
}

impl_op_name!(Case);
330
impl StaticTag for Case {
    /// Tag identifying conditional case nodes.
    const TAG: OpTag = OpTag::Case;
}
334
335impl DataflowParent for Case {
336    fn inner_signature(&self) -> Cow<'_, Signature> {
337        Cow::Borrowed(&self.signature)
338    }
339}
340
341impl OpTrait for Case {
342    fn description(&self) -> &str {
343        "A case node inside a conditional"
344    }
345
346    fn extension_delta(&self) -> ExtensionSet {
347        self.signature.runtime_reqs.clone()
348    }
349
350    fn tag(&self) -> OpTag {
351        <Self as StaticTag>::TAG
352    }
353
354    fn substitute(&self, subst: &crate::types::Substitution) -> Self {
355        Self {
356            signature: self.signature.substitute(subst),
357        }
358    }
359}
360
361impl Case {
362    pub fn dataflow_input(&self) -> &TypeRow {
364        &self.signature.input
365    }
366
367    pub fn dataflow_output(&self) -> &TypeRow {
369        &self.signature.output
370    }
371}
372
#[cfg(test)]
mod test {
    use crate::{
        extension::{
            prelude::{qb_t, usize_t, PRELUDE_ID},
            ExtensionSet,
        },
        ops::{Conditional, DataflowOpTrait, DataflowParent},
        types::{Signature, Substitution, Type, TypeArg, TypeBound, TypeRV, TypeRow},
    };

    use super::{DataflowBlock, TailLoop};

    #[test]
    fn test_subst_dataflow_block() {
        use crate::ops::OpTrait;
        let tv0 = Type::new_var_use(0, TypeBound::Any);
        let dfb = DataflowBlock {
            inputs: vec![usize_t(), tv0.clone()].into(),
            other_outputs: vec![tv0.clone()].into(),
            sum_rows: vec![usize_t().into(), vec![qb_t(), tv0.clone()].into()],
            extension_delta: ExtensionSet::type_var(1),
        };
        let dfb2 = dfb.substitute(&Substitution::new(&[
            qb_t().into(),
            TypeArg::Extensions {
                es: PRELUDE_ID.into(),
            },
        ]));
        let st = Type::new_sum(vec![vec![usize_t()], vec![qb_t(); 2]]);
        assert_eq!(
            dfb2.inner_signature(),
            Signature::new(vec![usize_t(), qb_t()], vec![st, qb_t()])
                .with_extension_delta(PRELUDE_ID)
        );
    }

    /// `successor_input` appends `other_outputs` to the chosen variant row and
    /// returns `None` for an out-of-range successor.
    #[test]
    fn test_successor_input() {
        let dfb = DataflowBlock {
            inputs: vec![usize_t()].into(),
            other_outputs: vec![qb_t()].into(),
            sum_rows: vec![usize_t().into(), vec![qb_t(), usize_t()].into()],
            extension_delta: ExtensionSet::new(),
        };
        assert_eq!(
            dfb.successor_input(0),
            Some(TypeRow::from(vec![usize_t(), qb_t()]))
        );
        assert_eq!(
            dfb.successor_input(1),
            Some(TypeRow::from(vec![qb_t(), usize_t(), qb_t()]))
        );
        assert_eq!(dfb.successor_input(2), None);
    }

    #[test]
    fn test_subst_conditional() {
        let tv1 = Type::new_var_use(1, TypeBound::Any);
        let cond = Conditional {
            sum_rows: vec![usize_t().into(), tv1.clone().into()],
            other_inputs: vec![Type::new_tuple(TypeRV::new_row_var_use(0, TypeBound::Any))].into(),
            outputs: vec![usize_t(), tv1].into(),
            extension_delta: ExtensionSet::new(),
        };
        let cond2 = cond.substitute(&Substitution::new(&[
            TypeArg::Sequence {
                elems: vec![usize_t().into(); 3],
            },
            qb_t().into(),
        ]));
        let st = Type::new_sum(vec![usize_t(), qb_t()]);
        assert_eq!(
            cond2.signature(),
            Signature::new(
                vec![st, Type::new_tuple(vec![usize_t(); 3])],
                vec![usize_t(), qb_t()]
            )
        );
    }

    /// `case_input_row` appends `other_inputs` to the chosen variant row and
    /// returns `None` for an out-of-range case.
    #[test]
    fn test_case_input_row() {
        let cond = Conditional {
            sum_rows: vec![usize_t().into(), vec![qb_t(), usize_t()].into()],
            other_inputs: vec![qb_t()].into(),
            outputs: vec![usize_t()].into(),
            extension_delta: ExtensionSet::new(),
        };
        assert_eq!(
            cond.case_input_row(0),
            Some(TypeRow::from(vec![usize_t(), qb_t()]))
        );
        assert_eq!(
            cond.case_input_row(1),
            Some(TypeRow::from(vec![qb_t(), usize_t(), qb_t()]))
        );
        assert_eq!(cond.case_input_row(2), None);
    }

    #[test]
    fn test_tail_loop() {
        let tv0 = Type::new_var_use(0, TypeBound::Copyable);
        let tail_loop = TailLoop {
            just_inputs: vec![qb_t(), tv0.clone()].into(),
            just_outputs: vec![tv0.clone(), qb_t()].into(),
            rest: vec![tv0.clone()].into(),
            extension_delta: ExtensionSet::type_var(1),
        };
        let tail2 = tail_loop.substitute(&Substitution::new(&[
            usize_t().into(),
            TypeArg::Extensions {
                es: PRELUDE_ID.into(),
            },
        ]));
        assert_eq!(
            tail2.signature(),
            Signature::new(
                vec![qb_t(), usize_t(), usize_t()],
                vec![usize_t(), qb_t(), usize_t()]
            )
            .with_extension_delta(PRELUDE_ID)
        );
        // The body sees `just_inputs + rest` and must produce a sum over
        // `[just_inputs, just_outputs]` followed by `rest`.
        let body_sum = Type::new_sum(vec![
            vec![qb_t(), usize_t()],
            vec![usize_t(), qb_t()],
        ]);
        assert_eq!(
            tail2.inner_signature(),
            Signature::new(
                vec![qb_t(), usize_t(), usize_t()],
                vec![body_sum, usize_t()]
            )
            .with_extension_delta(PRELUDE_ID)
        );
    }
}