hugr_core/ops/
dataflow.rs

1//! Dataflow operations.
2
3use std::borrow::Cow;
4
5use super::{impl_op_name, OpTag, OpTrait};
6
7use crate::extension::{ExtensionSet, SignatureError};
8use crate::ops::StaticTag;
9use crate::types::{EdgeKind, PolyFuncType, Signature, Substitution, Type, TypeArg, TypeRow};
10use crate::{type_row, IncomingPort};
11
12#[cfg(test)]
13use proptest_derive::Arbitrary;
14
15/// Trait implemented by all dataflow operations.
16pub trait DataflowOpTrait: Sized {
17    /// Tag identifying the operation.
18    const TAG: OpTag;
19
20    /// A human-readable description of the operation.
21    fn description(&self) -> &str;
22
23    /// The signature of the operation.
24    fn signature(&self) -> Cow<'_, Signature>;
25
26    /// The edge kind for the non-dataflow or constant inputs of the operation,
27    /// not described by the signature.
28    ///
29    /// If not None, a single extra output multiport of that kind will be
30    /// present.
31    #[inline]
32    fn other_input(&self) -> Option<EdgeKind> {
33        Some(EdgeKind::StateOrder)
34    }
35    /// The edge kind for the non-dataflow outputs of the operation, not
36    /// described by the signature.
37    ///
38    /// If not None, a single extra output multiport of that kind will be
39    /// present.
40    #[inline]
41    fn other_output(&self) -> Option<EdgeKind> {
42        Some(EdgeKind::StateOrder)
43    }
44
45    /// The edge kind for a single constant input of the operation, not
46    /// described by the dataflow signature.
47    ///
48    /// If not None, an extra input port of that kind will be present after the
49    /// dataflow input ports and before any [`DataflowOpTrait::other_input`] ports.
50    #[inline]
51    fn static_input(&self) -> Option<EdgeKind> {
52        None
53    }
54
55    /// Apply a type-level substitution to this OpType, i.e. replace
56    /// [type variables](TypeArg::new_var_use) with new types.
57    fn substitute(&self, _subst: &Substitution) -> Self;
58}
59
60/// Helpers to construct input and output nodes
61pub trait IOTrait {
62    /// Construct a new I/O node from a type row with no extension requirements
63    fn new(types: impl Into<TypeRow>) -> Self;
64}
65
66/// An input node.
67/// The outputs of this node are the inputs to the function.
68#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
69#[cfg_attr(test, derive(Arbitrary))]
70pub struct Input {
71    /// Input value types
72    pub types: TypeRow,
73}
74
75impl_op_name!(Input);
76
77impl IOTrait for Input {
78    fn new(types: impl Into<TypeRow>) -> Self {
79        Input {
80            types: types.into(),
81        }
82    }
83}
84
85/// An output node. The inputs are the outputs of the function.
86#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
87#[cfg_attr(test, derive(Arbitrary))]
88pub struct Output {
89    /// Output value types
90    pub types: TypeRow,
91}
92
93impl_op_name!(Output);
94
95impl IOTrait for Output {
96    fn new(types: impl Into<TypeRow>) -> Self {
97        Output {
98            types: types.into(),
99        }
100    }
101}
102
103impl DataflowOpTrait for Input {
104    const TAG: OpTag = OpTag::Input;
105
106    fn description(&self) -> &str {
107        "The input node for this dataflow subgraph"
108    }
109
110    fn other_input(&self) -> Option<EdgeKind> {
111        None
112    }
113
114    fn signature(&self) -> Cow<'_, Signature> {
115        // TODO: Store a cached signature
116        Cow::Owned(Signature::new(TypeRow::new(), self.types.clone()))
117    }
118
119    fn substitute(&self, subst: &Substitution) -> Self {
120        Self {
121            types: self.types.substitute(subst),
122        }
123    }
124}
125impl DataflowOpTrait for Output {
126    const TAG: OpTag = OpTag::Output;
127
128    fn description(&self) -> &str {
129        "The output node for this dataflow subgraph"
130    }
131
132    // Note: We know what the input extensions should be, so we *could* give an
133    // instantiated Signature instead
134    fn signature(&self) -> Cow<'_, Signature> {
135        // TODO: Store a cached signature
136        Cow::Owned(Signature::new(self.types.clone(), TypeRow::new()))
137    }
138
139    fn other_output(&self) -> Option<EdgeKind> {
140        None
141    }
142
143    fn substitute(&self, subst: &Substitution) -> Self {
144        Self {
145            types: self.types.substitute(subst),
146        }
147    }
148}
149
150impl<T: DataflowOpTrait + Clone> OpTrait for T {
151    fn description(&self) -> &str {
152        DataflowOpTrait::description(self)
153    }
154    fn tag(&self) -> OpTag {
155        T::TAG
156    }
157    fn dataflow_signature(&self) -> Option<Cow<'_, Signature>> {
158        Some(DataflowOpTrait::signature(self))
159    }
160    fn extension_delta(&self) -> ExtensionSet {
161        DataflowOpTrait::signature(self).runtime_reqs.clone()
162    }
163    fn other_input(&self) -> Option<EdgeKind> {
164        DataflowOpTrait::other_input(self)
165    }
166
167    fn other_output(&self) -> Option<EdgeKind> {
168        DataflowOpTrait::other_output(self)
169    }
170
171    fn static_input(&self) -> Option<EdgeKind> {
172        DataflowOpTrait::static_input(self)
173    }
174
175    fn substitute(&self, subst: &crate::types::Substitution) -> Self {
176        DataflowOpTrait::substitute(self, subst)
177    }
178}
179impl<T: DataflowOpTrait> StaticTag for T {
180    const TAG: OpTag = T::TAG;
181}
182
183/// Call a function directly.
184///
185/// The first ports correspond to the signature of the function being called.
186/// The port immediately following those those is connected to the def/declare
187/// block with a [`EdgeKind::Function`] edge.
188#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
189#[cfg_attr(test, derive(Arbitrary))]
190pub struct Call {
191    /// Signature of function being called.
192    pub func_sig: PolyFuncType,
193    /// The type arguments that instantiate `func_sig`.
194    pub type_args: Vec<TypeArg>,
195    /// The instantiation of `func_sig`.
196    pub instantiation: Signature, // Cache, so we can fail in try_new() not in signature()
197}
198impl_op_name!(Call);
199
200impl DataflowOpTrait for Call {
201    const TAG: OpTag = OpTag::FnCall;
202
203    fn description(&self) -> &str {
204        "Call a function directly"
205    }
206
207    fn signature(&self) -> Cow<'_, Signature> {
208        Cow::Borrowed(&self.instantiation)
209    }
210
211    fn static_input(&self) -> Option<EdgeKind> {
212        Some(EdgeKind::Function(self.called_function_type().clone()))
213    }
214
215    fn substitute(&self, subst: &Substitution) -> Self {
216        let type_args = self
217            .type_args
218            .iter()
219            .map(|ta| ta.substitute(subst))
220            .collect::<Vec<_>>();
221        let instantiation = self.instantiation.substitute(subst);
222        debug_assert_eq!(
223            self.func_sig.instantiate(&type_args).as_ref(),
224            Ok(&instantiation)
225        );
226        Self {
227            type_args,
228            instantiation,
229            func_sig: self.func_sig.clone(),
230        }
231    }
232}
233impl Call {
234    /// Try to make a new Call. Returns an error if the `type_args`` do not fit the [TypeParam]s
235    /// declared by the function.
236    ///
237    /// [TypeParam]: crate::types::type_param::TypeParam
238    pub fn try_new(
239        func_sig: PolyFuncType,
240        type_args: impl Into<Vec<TypeArg>>,
241    ) -> Result<Self, SignatureError> {
242        let type_args: Vec<_> = type_args.into();
243        let instantiation = func_sig.instantiate(&type_args)?;
244        Ok(Self {
245            func_sig,
246            type_args,
247            instantiation,
248        })
249    }
250
251    #[inline]
252    /// Return the signature of the function called by this op.
253    pub fn called_function_type(&self) -> &PolyFuncType {
254        &self.func_sig
255    }
256
257    /// The IncomingPort which links to the function being called.
258    ///
259    /// This matches [`OpType::static_input_port`].
260    ///
261    /// ```
262    /// # use hugr::ops::dataflow::Call;
263    /// # use hugr::ops::OpType;
264    /// # use hugr::types::Signature;
265    /// # use hugr::extension::prelude::qb_t;
266    /// # use hugr::extension::PRELUDE_REGISTRY;
267    /// let signature = Signature::new(vec![qb_t(), qb_t()], vec![qb_t(), qb_t()]);
268    /// let call = Call::try_new(signature.into(), &[]).unwrap();
269    /// let op = OpType::Call(call.clone());
270    /// assert_eq!(op.static_input_port(), Some(call.called_function_port()));
271    /// ```
272    ///
273    /// [`OpType::static_input_port`]: crate::ops::OpType::static_input_port
274    #[inline]
275    pub fn called_function_port(&self) -> IncomingPort {
276        self.instantiation.input_count().into()
277    }
278
279    pub(crate) fn validate(&self) -> Result<(), SignatureError> {
280        let other = Self::try_new(self.func_sig.clone(), self.type_args.clone())?;
281        if other.instantiation == self.instantiation {
282            Ok(())
283        } else {
284            Err(SignatureError::CallIncorrectlyAppliesType {
285                cached: self.instantiation.clone(),
286                expected: other.instantiation.clone(),
287            })
288        }
289    }
290}
291
292/// Call a function indirectly. Like call, but the function input is a value
293/// (runtime, not static) dataflow edge, and thus does not need any type-args.
294#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
295#[cfg_attr(test, derive(Arbitrary))]
296pub struct CallIndirect {
297    /// Signature of function being called
298    pub signature: Signature,
299}
300impl_op_name!(CallIndirect);
301
302impl DataflowOpTrait for CallIndirect {
303    const TAG: OpTag = OpTag::FnCall;
304
305    fn description(&self) -> &str {
306        "Call a function indirectly"
307    }
308
309    fn signature(&self) -> Cow<'_, Signature> {
310        // TODO: Store a cached signature
311        let mut s = self.signature.clone();
312        s.input
313            .to_mut()
314            .insert(0, Type::new_function(self.signature.clone()));
315        Cow::Owned(s)
316    }
317
318    fn substitute(&self, subst: &Substitution) -> Self {
319        Self {
320            signature: self.signature.substitute(subst),
321        }
322    }
323}
324
325/// Load a static constant in to the local dataflow graph.
326#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
327#[cfg_attr(test, derive(Arbitrary))]
328pub struct LoadConstant {
329    /// Constant type
330    pub datatype: Type,
331}
332impl_op_name!(LoadConstant);
333impl DataflowOpTrait for LoadConstant {
334    const TAG: OpTag = OpTag::LoadConst;
335
336    fn description(&self) -> &str {
337        "Load a static constant in to the local dataflow graph"
338    }
339
340    fn signature(&self) -> Cow<'_, Signature> {
341        // TODO: Store a cached signature
342        Cow::Owned(Signature::new(TypeRow::new(), vec![self.datatype.clone()]))
343    }
344
345    fn static_input(&self) -> Option<EdgeKind> {
346        Some(EdgeKind::Const(self.constant_type().clone()))
347    }
348
349    fn substitute(&self, _subst: &Substitution) -> Self {
350        // Constants cannot refer to TypeArgs, so neither can loading them
351        self.clone()
352    }
353}
354
355impl LoadConstant {
356    #[inline]
357    /// The type of the constant loaded by this op.
358    pub fn constant_type(&self) -> &Type {
359        &self.datatype
360    }
361
362    /// The IncomingPort which links to the loaded constant.
363    ///
364    /// This matches [`OpType::static_input_port`].
365    ///
366    /// ```
367    /// # use hugr::ops::dataflow::LoadConstant;
368    /// # use hugr::ops::OpType;
369    /// # use hugr::types::Type;
370    /// let datatype = Type::UNIT;
371    /// let load_constant = LoadConstant { datatype };
372    /// let op = OpType::LoadConstant(load_constant.clone());
373    /// assert_eq!(op.static_input_port(), Some(load_constant.constant_port()));
374    /// ```
375    ///
376    /// [`OpType::static_input_port`]: crate::ops::OpType::static_input_port
377    #[inline]
378    pub fn constant_port(&self) -> IncomingPort {
379        0.into()
380    }
381}
382
383/// Load a static function in to the local dataflow graph.
384#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
385#[cfg_attr(test, derive(Arbitrary))]
386pub struct LoadFunction {
387    /// Signature of the function
388    pub func_sig: PolyFuncType,
389    /// The type arguments that instantiate `func_sig`.
390    pub type_args: Vec<TypeArg>,
391    /// The instantiation of `func_sig`.
392    pub instantiation: Signature, // Cache, so we can fail in try_new() not in signature()
393}
394impl_op_name!(LoadFunction);
395impl DataflowOpTrait for LoadFunction {
396    const TAG: OpTag = OpTag::LoadFunc;
397
398    fn description(&self) -> &str {
399        "Load a static function in to the local dataflow graph"
400    }
401
402    fn signature(&self) -> Cow<'_, Signature> {
403        Cow::Owned(Signature::new(
404            type_row![],
405            Type::new_function(self.instantiation.clone()),
406        ))
407    }
408
409    fn static_input(&self) -> Option<EdgeKind> {
410        Some(EdgeKind::Function(self.func_sig.clone()))
411    }
412
413    fn substitute(&self, subst: &Substitution) -> Self {
414        let type_args = self
415            .type_args
416            .iter()
417            .map(|ta| ta.substitute(subst))
418            .collect::<Vec<_>>();
419        let instantiation = self.instantiation.substitute(subst);
420        debug_assert_eq!(
421            self.func_sig.instantiate(&type_args).as_ref(),
422            Ok(&instantiation)
423        );
424        Self {
425            func_sig: self.func_sig.clone(),
426            type_args,
427            instantiation,
428        }
429    }
430}
431impl LoadFunction {
432    /// Try to make a new LoadFunction op. Returns an error if the `type_args`` do not fit
433    /// the [TypeParam]s declared by the function.
434    ///
435    /// [TypeParam]: crate::types::type_param::TypeParam
436    pub fn try_new(
437        func_sig: PolyFuncType,
438        type_args: impl Into<Vec<TypeArg>>,
439    ) -> Result<Self, SignatureError> {
440        let type_args: Vec<_> = type_args.into();
441        let instantiation = func_sig.instantiate(&type_args)?;
442        Ok(Self {
443            func_sig,
444            type_args,
445            instantiation,
446        })
447    }
448
449    #[inline]
450    /// Return the type of the function loaded by this op.
451    pub fn function_type(&self) -> &PolyFuncType {
452        &self.func_sig
453    }
454
455    /// The IncomingPort which links to the loaded function.
456    ///
457    /// This matches [`OpType::static_input_port`].
458    ///
459    /// [`OpType::static_input_port`]: crate::ops::OpType::static_input_port
460    #[inline]
461    pub fn function_port(&self) -> IncomingPort {
462        0.into()
463    }
464
465    pub(crate) fn validate(&self) -> Result<(), SignatureError> {
466        let other = Self::try_new(self.func_sig.clone(), self.type_args.clone())?;
467        if other.instantiation == self.instantiation {
468            Ok(())
469        } else {
470            Err(SignatureError::LoadFunctionIncorrectlyAppliesType {
471                cached: self.instantiation.clone(),
472                expected: other.instantiation.clone(),
473            })
474        }
475    }
476}
477
478/// An operation that is the parent of a dataflow graph.
479///
480/// The children region contains an input and an output node matching the
481/// signature returned by [`DataflowParent::inner_signature`].
482pub trait DataflowParent {
483    /// Signature of the inner dataflow graph.
484    fn inner_signature(&self) -> Cow<'_, Signature>;
485}
486
487/// A simply nested dataflow graph.
488#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
489#[cfg_attr(test, derive(Arbitrary))]
490pub struct DFG {
491    /// Signature of DFG node
492    pub signature: Signature,
493}
494
495impl_op_name!(DFG);
496
497impl DataflowParent for DFG {
498    fn inner_signature(&self) -> Cow<'_, Signature> {
499        Cow::Borrowed(&self.signature)
500    }
501}
502
503impl DataflowOpTrait for DFG {
504    const TAG: OpTag = OpTag::Dfg;
505
506    fn description(&self) -> &str {
507        "A simply nested dataflow graph"
508    }
509
510    fn signature(&self) -> Cow<'_, Signature> {
511        self.inner_signature()
512    }
513
514    fn substitute(&self, subst: &Substitution) -> Self {
515        Self {
516            signature: self.signature.substitute(subst),
517        }
518    }
519}