hugr_core/ops/
dataflow.rs

1//! Dataflow operations.
2
3use std::borrow::Cow;
4
5use super::{OpTag, OpTrait, impl_op_name};
6
7use crate::extension::SignatureError;
8use crate::ops::StaticTag;
9use crate::types::{EdgeKind, PolyFuncType, Signature, Substitution, Type, TypeArg, TypeRow};
10use crate::{IncomingPort, type_row};
11
12#[cfg(test)]
13use proptest_derive::Arbitrary;
14
15/// Trait implemented by all dataflow operations.
16pub trait DataflowOpTrait: Sized {
17    /// Tag identifying the operation.
18    const TAG: OpTag;
19
20    /// A human-readable description of the operation.
21    fn description(&self) -> &str;
22
23    /// The signature of the operation.
24    fn signature(&self) -> Cow<'_, Signature>;
25
26    /// The edge kind for the non-dataflow or constant inputs of the operation,
27    /// not described by the signature.
28    ///
29    /// If not None, a single extra output multiport of that kind will be
30    /// present.
31    #[inline]
32    fn other_input(&self) -> Option<EdgeKind> {
33        Some(EdgeKind::StateOrder)
34    }
35    /// The edge kind for the non-dataflow outputs of the operation, not
36    /// described by the signature.
37    ///
38    /// If not None, a single extra output multiport of that kind will be
39    /// present.
40    #[inline]
41    fn other_output(&self) -> Option<EdgeKind> {
42        Some(EdgeKind::StateOrder)
43    }
44
45    /// The edge kind for a single constant input of the operation, not
46    /// described by the dataflow signature.
47    ///
48    /// If not None, an extra input port of that kind will be present after the
49    /// dataflow input ports and before any [`DataflowOpTrait::other_input`] ports.
50    #[inline]
51    fn static_input(&self) -> Option<EdgeKind> {
52        None
53    }
54
55    /// Apply a type-level substitution to this `OpType`, i.e. replace
56    /// [type variables](TypeArg::new_var_use) with new types.
57    fn substitute(&self, _subst: &Substitution) -> Self;
58}
59
60/// Helpers to construct input and output nodes
61pub trait IOTrait {
62    /// Construct a new I/O node from a type row with no extension requirements
63    fn new(types: impl Into<TypeRow>) -> Self;
64}
65
66/// An input node.
67/// The outputs of this node are the inputs to the function.
68#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
69#[cfg_attr(test, derive(Arbitrary))]
70pub struct Input {
71    /// Input value types
72    pub types: TypeRow,
73}
74
75impl_op_name!(Input);
76
77impl IOTrait for Input {
78    fn new(types: impl Into<TypeRow>) -> Self {
79        Input {
80            types: types.into(),
81        }
82    }
83}
84
85/// An output node. The inputs are the outputs of the function.
86#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
87#[cfg_attr(test, derive(Arbitrary))]
88pub struct Output {
89    /// Output value types
90    pub types: TypeRow,
91}
92
93impl_op_name!(Output);
94
95impl IOTrait for Output {
96    fn new(types: impl Into<TypeRow>) -> Self {
97        Output {
98            types: types.into(),
99        }
100    }
101}
102
103impl DataflowOpTrait for Input {
104    const TAG: OpTag = OpTag::Input;
105
106    fn description(&self) -> &'static str {
107        "The input node for this dataflow subgraph"
108    }
109
110    fn other_input(&self) -> Option<EdgeKind> {
111        None
112    }
113
114    fn signature(&self) -> Cow<'_, Signature> {
115        // TODO: Store a cached signature
116        Cow::Owned(Signature::new(TypeRow::new(), self.types.clone()))
117    }
118
119    fn substitute(&self, subst: &Substitution) -> Self {
120        Self {
121            types: self.types.substitute(subst),
122        }
123    }
124}
125impl DataflowOpTrait for Output {
126    const TAG: OpTag = OpTag::Output;
127
128    fn description(&self) -> &'static str {
129        "The output node for this dataflow subgraph"
130    }
131
132    // Note: We know what the input extensions should be, so we *could* give an
133    // instantiated Signature instead
134    fn signature(&self) -> Cow<'_, Signature> {
135        // TODO: Store a cached signature
136        Cow::Owned(Signature::new(self.types.clone(), TypeRow::new()))
137    }
138
139    fn other_output(&self) -> Option<EdgeKind> {
140        None
141    }
142
143    fn substitute(&self, subst: &Substitution) -> Self {
144        Self {
145            types: self.types.substitute(subst),
146        }
147    }
148}
149
150impl<T: DataflowOpTrait + Clone> OpTrait for T {
151    fn description(&self) -> &str {
152        DataflowOpTrait::description(self)
153    }
154
155    fn tag(&self) -> OpTag {
156        T::TAG
157    }
158
159    fn dataflow_signature(&self) -> Option<Cow<'_, Signature>> {
160        Some(DataflowOpTrait::signature(self))
161    }
162
163    fn other_input(&self) -> Option<EdgeKind> {
164        DataflowOpTrait::other_input(self)
165    }
166
167    fn other_output(&self) -> Option<EdgeKind> {
168        DataflowOpTrait::other_output(self)
169    }
170
171    fn static_input(&self) -> Option<EdgeKind> {
172        DataflowOpTrait::static_input(self)
173    }
174
175    fn substitute(&self, subst: &crate::types::Substitution) -> Self {
176        DataflowOpTrait::substitute(self, subst)
177    }
178}
179impl<T: DataflowOpTrait> StaticTag for T {
180    const TAG: OpTag = T::TAG;
181}
182
183/// Call a function directly.
184///
185/// The first ports correspond to the signature of the function being called.
186/// The port immediately following those those is connected to the def/declare
187/// block with a [`EdgeKind::Function`] edge.
188#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
189#[cfg_attr(test, derive(Arbitrary))]
190pub struct Call {
191    /// Signature of function being called.
192    pub func_sig: PolyFuncType,
193    /// The type arguments that instantiate `func_sig`.
194    pub type_args: Vec<TypeArg>,
195    /// The instantiation of `func_sig`.
196    pub instantiation: Signature, // Cache, so we can fail in try_new() not in signature()
197}
198impl_op_name!(Call);
199
200impl DataflowOpTrait for Call {
201    const TAG: OpTag = OpTag::FnCall;
202
203    fn description(&self) -> &'static str {
204        "Call a function directly"
205    }
206
207    fn signature(&self) -> Cow<'_, Signature> {
208        Cow::Borrowed(&self.instantiation)
209    }
210
211    fn static_input(&self) -> Option<EdgeKind> {
212        Some(EdgeKind::Function(self.called_function_type().clone()))
213    }
214
215    fn substitute(&self, subst: &Substitution) -> Self {
216        let type_args = self
217            .type_args
218            .iter()
219            .map(|ta| ta.substitute(subst))
220            .collect::<Vec<_>>();
221        let instantiation = self.instantiation.substitute(subst);
222        debug_assert_eq!(
223            self.func_sig.instantiate(&type_args).as_ref(),
224            Ok(&instantiation)
225        );
226        Self {
227            type_args,
228            instantiation,
229            func_sig: self.func_sig.clone(),
230        }
231    }
232}
233impl Call {
234    /// Try to make a new Call. Returns an error if the `type_args`` do not fit the [TypeParam]s
235    /// declared by the function.
236    ///
237    /// [TypeParam]: crate::types::type_param::TypeParam
238    pub fn try_new(
239        func_sig: PolyFuncType,
240        type_args: impl Into<Vec<TypeArg>>,
241    ) -> Result<Self, SignatureError> {
242        let type_args: Vec<_> = type_args.into();
243        let instantiation = func_sig.instantiate(&type_args)?;
244        Ok(Self {
245            func_sig,
246            type_args,
247            instantiation,
248        })
249    }
250
251    #[inline]
252    /// Return the signature of the function called by this op.
253    #[must_use]
254    pub fn called_function_type(&self) -> &PolyFuncType {
255        &self.func_sig
256    }
257
258    /// The `IncomingPort` which links to the function being called.
259    ///
260    /// This matches [`OpType::static_input_port`].
261    ///
262    /// ```
263    /// # use hugr::ops::dataflow::Call;
264    /// # use hugr::ops::OpType;
265    /// # use hugr::types::Signature;
266    /// # use hugr::extension::prelude::qb_t;
267    /// # use hugr::extension::PRELUDE_REGISTRY;
268    /// let signature = Signature::new(vec![qb_t(), qb_t()], vec![qb_t(), qb_t()]);
269    /// let call = Call::try_new(signature.into(), &[]).unwrap();
270    /// let op = OpType::Call(call.clone());
271    /// assert_eq!(op.static_input_port(), Some(call.called_function_port()));
272    /// ```
273    ///
274    /// [`OpType::static_input_port`]: crate::ops::OpType::static_input_port
275    #[inline]
276    #[must_use]
277    pub fn called_function_port(&self) -> IncomingPort {
278        self.instantiation.input_count().into()
279    }
280
281    pub(crate) fn validate(&self) -> Result<(), SignatureError> {
282        let other = Self::try_new(self.func_sig.clone(), self.type_args.clone())?;
283        if other.instantiation == self.instantiation {
284            Ok(())
285        } else {
286            Err(SignatureError::CallIncorrectlyAppliesType {
287                cached: self.instantiation.clone(),
288                expected: other.instantiation.clone(),
289            })
290        }
291    }
292}
293
294/// Call a function indirectly. Like call, but the function input is a value
295/// (runtime, not static) dataflow edge, and thus does not need any type-args.
296#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
297#[cfg_attr(test, derive(Arbitrary))]
298pub struct CallIndirect {
299    /// Signature of function being called
300    pub signature: Signature,
301}
302impl_op_name!(CallIndirect);
303
304impl DataflowOpTrait for CallIndirect {
305    const TAG: OpTag = OpTag::DataflowChild;
306
307    fn description(&self) -> &'static str {
308        "Call a function indirectly"
309    }
310
311    fn signature(&self) -> Cow<'_, Signature> {
312        // TODO: Store a cached signature
313        let mut s = self.signature.clone();
314        s.input
315            .to_mut()
316            .insert(0, Type::new_function(self.signature.clone()));
317        Cow::Owned(s)
318    }
319
320    fn substitute(&self, subst: &Substitution) -> Self {
321        Self {
322            signature: self.signature.substitute(subst),
323        }
324    }
325}
326
327/// Load a static constant in to the local dataflow graph.
328#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
329#[cfg_attr(test, derive(Arbitrary))]
330pub struct LoadConstant {
331    /// Constant type
332    pub datatype: Type,
333}
334impl_op_name!(LoadConstant);
335impl DataflowOpTrait for LoadConstant {
336    const TAG: OpTag = OpTag::LoadConst;
337
338    fn description(&self) -> &'static str {
339        "Load a static constant in to the local dataflow graph"
340    }
341
342    fn signature(&self) -> Cow<'_, Signature> {
343        // TODO: Store a cached signature
344        Cow::Owned(Signature::new(TypeRow::new(), vec![self.datatype.clone()]))
345    }
346
347    fn static_input(&self) -> Option<EdgeKind> {
348        Some(EdgeKind::Const(self.constant_type().clone()))
349    }
350
351    fn substitute(&self, _subst: &Substitution) -> Self {
352        // Constants cannot refer to TypeArgs, so neither can loading them
353        self.clone()
354    }
355}
356
357impl LoadConstant {
358    #[inline]
359    /// The type of the constant loaded by this op.
360    #[must_use]
361    pub fn constant_type(&self) -> &Type {
362        &self.datatype
363    }
364
365    /// The `IncomingPort` which links to the loaded constant.
366    ///
367    /// This matches [`OpType::static_input_port`].
368    ///
369    /// ```
370    /// # use hugr::ops::dataflow::LoadConstant;
371    /// # use hugr::ops::OpType;
372    /// # use hugr::types::Type;
373    /// let datatype = Type::UNIT;
374    /// let load_constant = LoadConstant { datatype };
375    /// let op = OpType::LoadConstant(load_constant.clone());
376    /// assert_eq!(op.static_input_port(), Some(load_constant.constant_port()));
377    /// ```
378    ///
379    /// [`OpType::static_input_port`]: crate::ops::OpType::static_input_port
380    #[inline]
381    #[must_use]
382    pub fn constant_port(&self) -> IncomingPort {
383        0.into()
384    }
385}
386
387/// Load a static function in to the local dataflow graph.
388#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
389#[cfg_attr(test, derive(Arbitrary))]
390pub struct LoadFunction {
391    /// Signature of the function
392    pub func_sig: PolyFuncType,
393    /// The type arguments that instantiate `func_sig`.
394    pub type_args: Vec<TypeArg>,
395    /// The instantiation of `func_sig`.
396    pub instantiation: Signature, // Cache, so we can fail in try_new() not in signature()
397}
398impl_op_name!(LoadFunction);
399impl DataflowOpTrait for LoadFunction {
400    const TAG: OpTag = OpTag::LoadFunc;
401
402    fn description(&self) -> &'static str {
403        "Load a static function in to the local dataflow graph"
404    }
405
406    fn signature(&self) -> Cow<'_, Signature> {
407        Cow::Owned(Signature::new(
408            type_row![],
409            Type::new_function(self.instantiation.clone()),
410        ))
411    }
412
413    fn static_input(&self) -> Option<EdgeKind> {
414        Some(EdgeKind::Function(self.func_sig.clone()))
415    }
416
417    fn substitute(&self, subst: &Substitution) -> Self {
418        let type_args = self
419            .type_args
420            .iter()
421            .map(|ta| ta.substitute(subst))
422            .collect::<Vec<_>>();
423        let instantiation = self.instantiation.substitute(subst);
424        debug_assert_eq!(
425            self.func_sig.instantiate(&type_args).as_ref(),
426            Ok(&instantiation)
427        );
428        Self {
429            func_sig: self.func_sig.clone(),
430            type_args,
431            instantiation,
432        }
433    }
434}
435impl LoadFunction {
436    /// Try to make a new LoadFunction op. Returns an error if the `type_args`` do not fit
437    /// the [TypeParam]s declared by the function.
438    ///
439    /// [TypeParam]: crate::types::type_param::TypeParam
440    pub fn try_new(
441        func_sig: PolyFuncType,
442        type_args: impl Into<Vec<TypeArg>>,
443    ) -> Result<Self, SignatureError> {
444        let type_args: Vec<_> = type_args.into();
445        let instantiation = func_sig.instantiate(&type_args)?;
446        Ok(Self {
447            func_sig,
448            type_args,
449            instantiation,
450        })
451    }
452
453    #[inline]
454    /// Return the type of the function loaded by this op.
455    #[must_use]
456    pub fn function_type(&self) -> &PolyFuncType {
457        &self.func_sig
458    }
459
460    /// The `IncomingPort` which links to the loaded function.
461    ///
462    /// This matches [`OpType::static_input_port`].
463    ///
464    /// [`OpType::static_input_port`]: crate::ops::OpType::static_input_port
465    #[inline]
466    #[must_use]
467    pub fn function_port(&self) -> IncomingPort {
468        0.into()
469    }
470
471    pub(crate) fn validate(&self) -> Result<(), SignatureError> {
472        let other = Self::try_new(self.func_sig.clone(), self.type_args.clone())?;
473        if other.instantiation == self.instantiation {
474            Ok(())
475        } else {
476            Err(SignatureError::LoadFunctionIncorrectlyAppliesType {
477                cached: self.instantiation.clone(),
478                expected: other.instantiation.clone(),
479            })
480        }
481    }
482}
483
484/// An operation that is the parent of a dataflow graph.
485///
486/// The children region contains an input and an output node matching the
487/// signature returned by [`DataflowParent::inner_signature`].
488pub trait DataflowParent {
489    /// Signature of the inner dataflow graph.
490    fn inner_signature(&self) -> Cow<'_, Signature>;
491}
492
493/// A simply nested dataflow graph.
494#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
495#[cfg_attr(test, derive(Arbitrary))]
496pub struct DFG {
497    /// Signature of DFG node
498    pub signature: Signature,
499}
500
501impl_op_name!(DFG);
502
503impl DataflowParent for DFG {
504    fn inner_signature(&self) -> Cow<'_, Signature> {
505        Cow::Borrowed(&self.signature)
506    }
507}
508
509impl DataflowOpTrait for DFG {
510    const TAG: OpTag = OpTag::Dfg;
511
512    fn description(&self) -> &'static str {
513        "A simply nested dataflow graph"
514    }
515
516    fn signature(&self) -> Cow<'_, Signature> {
517        self.inner_signature()
518    }
519
520    fn substitute(&self, subst: &Substitution) -> Self {
521        Self {
522            signature: self.signature.substitute(subst),
523        }
524    }
525}