hugr_core/ops/controlflow.rs

1//! Control flow operations.
2
3use std::borrow::Cow;
4
5use crate::Direction;
6use crate::types::{EdgeKind, Signature, Type, TypeRow};
7
8use super::OpTag;
9use super::dataflow::{DataflowOpTrait, DataflowParent};
10use super::{OpTrait, StaticTag, impl_op_name};
11
12/// Tail-controlled loop.
13#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
14#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
15pub struct TailLoop {
16    /// Types that are only input
17    pub just_inputs: TypeRow,
18    /// Types that are only output
19    pub just_outputs: TypeRow,
20    /// Types that are appended to both input and output
21    pub rest: TypeRow,
22}
23
24impl_op_name!(TailLoop);
25
26impl DataflowOpTrait for TailLoop {
27    const TAG: OpTag = OpTag::TailLoop;
28
29    fn description(&self) -> &'static str {
30        "A tail-controlled loop"
31    }
32
33    fn signature(&self) -> Cow<'_, Signature> {
34        // TODO: Store a cached signature
35        let [inputs, outputs] =
36            [&self.just_inputs, &self.just_outputs].map(|row| row.extend(self.rest.iter()));
37        Cow::Owned(Signature::new(inputs, outputs))
38    }
39
40    fn substitute(&self, subst: &crate::types::Substitution) -> Self {
41        Self {
42            just_inputs: self.just_inputs.substitute(subst),
43            just_outputs: self.just_outputs.substitute(subst),
44            rest: self.rest.substitute(subst),
45        }
46    }
47}
48
49impl TailLoop {
50    /// The [tag] for a loop body output to indicate the loop should iterate again.
51    ///
52    /// [tag]: crate::ops::constant::Sum::tag
53    pub const CONTINUE_TAG: usize = 0;
54
55    /// The [tag] for a loop body output to indicate the loop should exit with the supplied values.
56    ///
57    /// [tag]: crate::ops::constant::Sum::tag
58    pub const BREAK_TAG: usize = 1;
59
60    /// Build the output `TypeRow` of the child graph of a `TailLoop` node.
61    pub(crate) fn body_output_row(&self) -> TypeRow {
62        let mut outputs = vec![Type::new_sum(self.control_variants())];
63        outputs.extend_from_slice(&self.rest);
64        outputs.into()
65    }
66
67    /// The variants (continue / break) of the first output from the child graph
68    pub(crate) fn control_variants(&self) -> [TypeRow; 2] {
69        [self.just_inputs.clone(), self.just_outputs.clone()]
70    }
71
72    /// Build the input `TypeRow` of the child graph of a `TailLoop` node.
73    pub(crate) fn body_input_row(&self) -> TypeRow {
74        self.just_inputs.extend(self.rest.iter())
75    }
76}
77
78impl DataflowParent for TailLoop {
79    fn inner_signature(&self) -> Cow<'_, Signature> {
80        // TODO: Store a cached signature
81        Cow::Owned(Signature::new(
82            self.body_input_row(),
83            self.body_output_row(),
84        ))
85    }
86}
87
88/// Conditional operation, defined by child `Case` nodes for each branch.
89#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
90#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
91pub struct Conditional {
92    /// The possible rows of the Sum input
93    pub sum_rows: Vec<TypeRow>,
94    /// Remaining input types
95    pub other_inputs: TypeRow,
96    /// Output types
97    pub outputs: TypeRow,
98}
99impl_op_name!(Conditional);
100
101impl DataflowOpTrait for Conditional {
102    const TAG: OpTag = OpTag::Conditional;
103
104    fn description(&self) -> &'static str {
105        "HUGR conditional operation"
106    }
107
108    fn signature(&self) -> Cow<'_, Signature> {
109        // TODO: Store a cached signature
110        let mut inputs = self.other_inputs.clone();
111        inputs
112            .to_mut()
113            .insert(0, Type::new_sum(self.sum_rows.clone()));
114        Cow::Owned(Signature::new(inputs, self.outputs.clone()))
115    }
116
117    fn substitute(&self, subst: &crate::types::Substitution) -> Self {
118        Self {
119            sum_rows: self.sum_rows.iter().map(|r| r.substitute(subst)).collect(),
120            other_inputs: self.other_inputs.substitute(subst),
121            outputs: self.outputs.substitute(subst),
122        }
123    }
124}
125
126impl Conditional {
127    /// Build the input `TypeRow` of the nth child graph of a Conditional node.
128    pub(crate) fn case_input_row(&self, case: usize) -> Option<TypeRow> {
129        Some(self.sum_rows.get(case)?.extend(self.other_inputs.iter()))
130    }
131}
132
133/// A dataflow node which is defined by a child CFG.
134#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
135#[allow(missing_docs)]
136#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
137pub struct CFG {
138    pub signature: Signature,
139}
140
141impl_op_name!(CFG);
142
143impl DataflowOpTrait for CFG {
144    const TAG: OpTag = OpTag::Cfg;
145
146    fn description(&self) -> &'static str {
147        "A dataflow node defined by a child CFG"
148    }
149
150    fn signature(&self) -> Cow<'_, Signature> {
151        Cow::Borrowed(&self.signature)
152    }
153
154    fn substitute(&self, subst: &crate::types::Substitution) -> Self {
155        Self {
156            signature: self.signature.substitute(subst),
157        }
158    }
159}
160
161#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
162#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
163/// A CFG basic block node. The signature is that of the internal Dataflow graph.
164#[allow(missing_docs)]
165pub struct DataflowBlock {
166    pub inputs: TypeRow,
167    pub other_outputs: TypeRow,
168    pub sum_rows: Vec<TypeRow>,
169}
170
171#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
172#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
173/// The single exit node of the CFG. Has no children,
174/// stores the types of the CFG node output.
175pub struct ExitBlock {
176    /// Output type row of the CFG.
177    pub cfg_outputs: TypeRow,
178}
179
180impl_op_name!(DataflowBlock);
181impl_op_name!(ExitBlock);
182
183impl StaticTag for DataflowBlock {
184    const TAG: OpTag = OpTag::DataflowBlock;
185}
186
187impl StaticTag for ExitBlock {
188    const TAG: OpTag = OpTag::BasicBlockExit;
189}
190
191impl DataflowParent for DataflowBlock {
192    fn inner_signature(&self) -> Cow<'_, Signature> {
193        // TODO: Store a cached signature
194        // The node outputs a Sum before the data outputs of the block node
195        let sum_type = Type::new_sum(self.sum_rows.clone());
196        let mut node_outputs = vec![sum_type];
197        node_outputs.extend_from_slice(&self.other_outputs);
198        Cow::Owned(Signature::new(
199            self.inputs.clone(),
200            TypeRow::from(node_outputs),
201        ))
202    }
203}
204
205impl OpTrait for DataflowBlock {
206    fn description(&self) -> &'static str {
207        "A CFG basic block node"
208    }
209    /// Tag identifying the operation.
210    fn tag(&self) -> OpTag {
211        Self::TAG
212    }
213
214    fn other_input(&self) -> Option<EdgeKind> {
215        Some(EdgeKind::ControlFlow)
216    }
217
218    fn other_output(&self) -> Option<EdgeKind> {
219        Some(EdgeKind::ControlFlow)
220    }
221
222    fn non_df_port_count(&self, dir: Direction) -> usize {
223        match dir {
224            Direction::Incoming => 1,
225            Direction::Outgoing => self.sum_rows.len(),
226        }
227    }
228
229    fn substitute(&self, subst: &crate::types::Substitution) -> Self {
230        Self {
231            inputs: self.inputs.substitute(subst),
232            other_outputs: self.other_outputs.substitute(subst),
233            sum_rows: self.sum_rows.iter().map(|r| r.substitute(subst)).collect(),
234        }
235    }
236}
237
238impl OpTrait for ExitBlock {
239    fn description(&self) -> &'static str {
240        "A CFG exit block node"
241    }
242    /// Tag identifying the operation.
243    fn tag(&self) -> OpTag {
244        Self::TAG
245    }
246
247    fn other_input(&self) -> Option<EdgeKind> {
248        Some(EdgeKind::ControlFlow)
249    }
250
251    fn other_output(&self) -> Option<EdgeKind> {
252        Some(EdgeKind::ControlFlow)
253    }
254
255    fn non_df_port_count(&self, dir: Direction) -> usize {
256        match dir {
257            Direction::Incoming => 1,
258            Direction::Outgoing => 0,
259        }
260    }
261
262    fn substitute(&self, subst: &crate::types::Substitution) -> Self {
263        Self {
264            cfg_outputs: self.cfg_outputs.substitute(subst),
265        }
266    }
267}
268
269/// Functionality shared by `DataflowBlock` and Exit CFG block types.
270pub trait BasicBlock {
271    /// The input dataflow signature of the CFG block.
272    fn dataflow_input(&self) -> &TypeRow;
273}
274
275impl BasicBlock for DataflowBlock {
276    fn dataflow_input(&self) -> &TypeRow {
277        &self.inputs
278    }
279}
280impl DataflowBlock {
281    /// The correct inputs of any successors. Returns None if successor is not a
282    /// valid index.
283    #[must_use]
284    pub fn successor_input(&self, successor: usize) -> Option<TypeRow> {
285        Some(
286            self.sum_rows
287                .get(successor)?
288                .extend(self.other_outputs.iter()),
289        )
290    }
291}
292
293impl BasicBlock for ExitBlock {
294    fn dataflow_input(&self) -> &TypeRow {
295        &self.cfg_outputs
296    }
297}
298
299#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
300#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
301/// Case ops - nodes valid inside Conditional nodes.
302pub struct Case {
303    /// The signature of the contained dataflow graph.
304    pub signature: Signature,
305}
306
307impl_op_name!(Case);
308
309impl StaticTag for Case {
310    const TAG: OpTag = OpTag::Case;
311}
312
313impl DataflowParent for Case {
314    fn inner_signature(&self) -> Cow<'_, Signature> {
315        Cow::Borrowed(&self.signature)
316    }
317}
318
319impl OpTrait for Case {
320    fn description(&self) -> &'static str {
321        "A case node inside a conditional"
322    }
323
324    fn tag(&self) -> OpTag {
325        <Self as StaticTag>::TAG
326    }
327
328    fn substitute(&self, subst: &crate::types::Substitution) -> Self {
329        Self {
330            signature: self.signature.substitute(subst),
331        }
332    }
333}
334
335impl Case {
336    /// The input signature of the contained dataflow graph.
337    #[must_use]
338    pub fn dataflow_input(&self) -> &TypeRow {
339        &self.signature.input
340    }
341
342    /// The output signature of the contained dataflow graph.
343    #[must_use]
344    pub fn dataflow_output(&self) -> &TypeRow {
345        &self.signature.output
346    }
347}
348
#[cfg(test)]
mod test {
    use crate::{
        extension::prelude::{qb_t, usize_t},
        ops::{Conditional, DataflowOpTrait, DataflowParent},
        types::{Signature, Substitution, Type, TypeArg, TypeBound, TypeRV, TypeRow},
    };

    use super::{DataflowBlock, TailLoop};

    #[test]
    fn test_subst_dataflow_block() {
        use crate::ops::OpTrait;
        let tv0 = Type::new_var_use(0, TypeBound::Linear);
        let dfb = DataflowBlock {
            inputs: vec![usize_t(), tv0.clone()].into(),
            other_outputs: vec![tv0.clone()].into(),
            sum_rows: vec![usize_t().into(), vec![qb_t(), tv0.clone()].into()],
        };
        let dfb2 = dfb.substitute(&Substitution::new(&[qb_t().into()]));
        let st = Type::new_sum(vec![vec![usize_t()], vec![qb_t(); 2]]);
        assert_eq!(
            dfb2.inner_signature(),
            Signature::new(vec![usize_t(), qb_t()], vec![st, qb_t()])
        );

        // One outgoing control-flow port per Sum variant.
        assert_eq!(dfb2.non_df_port_count(crate::Direction::Outgoing), 2);
        // Each successor receives its Sum variant's row followed by `other_outputs`.
        assert_eq!(
            dfb2.successor_input(0),
            Some(TypeRow::from(vec![usize_t(), qb_t()]))
        );
        assert_eq!(dfb2.successor_input(1), Some(TypeRow::from(vec![qb_t(); 3])));
        assert_eq!(dfb2.successor_input(2), None);
    }

    #[test]
    fn test_subst_conditional() {
        let tv1 = Type::new_var_use(1, TypeBound::Linear);
        let cond = Conditional {
            sum_rows: vec![usize_t().into(), tv1.clone().into()],
            other_inputs: vec![Type::new_tuple(TypeRV::new_row_var_use(
                0,
                TypeBound::Linear,
            ))]
            .into(),
            outputs: vec![usize_t(), tv1].into(),
        };
        let cond2 = cond.substitute(&Substitution::new(&[
            TypeArg::new_list([usize_t().into(), usize_t().into(), usize_t().into()]),
            qb_t().into(),
        ]));
        let st = Type::new_sum(vec![usize_t(), qb_t()]); //both single-element variants
        assert_eq!(
            cond2.signature(),
            Signature::new(
                vec![st, Type::new_tuple(vec![usize_t(); 3])],
                vec![usize_t(), qb_t()]
            )
        );

        // Case `i` receives Sum variant `i` followed by `other_inputs`.
        assert_eq!(
            cond2.case_input_row(1),
            Some(TypeRow::from(vec![
                qb_t(),
                Type::new_tuple(vec![usize_t(); 3])
            ]))
        );
        assert_eq!(cond2.case_input_row(2), None);
    }

    #[test]
    fn test_tail_loop() {
        let tv0 = Type::new_var_use(0, TypeBound::Copyable);
        let tail_loop = TailLoop {
            just_inputs: vec![qb_t(), tv0.clone()].into(),
            just_outputs: vec![tv0.clone(), qb_t()].into(),
            rest: vec![tv0.clone()].into(),
        };
        let tail2 = tail_loop.substitute(&Substitution::new(&[usize_t().into()]));
        assert_eq!(
            tail2.signature(),
            Signature::new(
                vec![qb_t(), usize_t(), usize_t()],
                vec![usize_t(), qb_t(), usize_t()]
            )
        );

        // The body consumes `just_inputs ++ rest` and produces the control Sum
        // (continue: `just_inputs`, break: `just_outputs`) followed by `rest`.
        let control = Type::new_sum(vec![vec![qb_t(), usize_t()], vec![usize_t(), qb_t()]]);
        assert_eq!(
            tail2.inner_signature(),
            Signature::new(
                vec![qb_t(), usize_t(), usize_t()],
                vec![control, usize_t()]
            )
        );
    }
}