hugr_core/ops/
dataflow.rs

1//! Dataflow operations.
2
3use std::borrow::Cow;
4
5use super::{OpTag, OpTrait, impl_op_name};
6
7use crate::extension::SignatureError;
8use crate::ops::StaticTag;
9use crate::types::{EdgeKind, PolyFuncType, Signature, Substitution, Type, TypeArg, TypeRow};
10use crate::{IncomingPort, type_row};
11
12#[cfg(test)]
13use {crate::types::proptest_utils::any_serde_type_arg_vec, proptest_derive::Arbitrary};
14
15/// Trait implemented by all dataflow operations.
16pub trait DataflowOpTrait: Sized {
17    /// Tag identifying the operation.
18    const TAG: OpTag;
19
20    /// A human-readable description of the operation.
21    fn description(&self) -> &str;
22
23    /// The signature of the operation.
24    fn signature(&self) -> Cow<'_, Signature>;
25
26    /// The edge kind for the non-dataflow or constant inputs of the operation,
27    /// not described by the signature.
28    ///
29    /// If not None, a single extra output multiport of that kind will be
30    /// present.
31    #[inline]
32    fn other_input(&self) -> Option<EdgeKind> {
33        Some(EdgeKind::StateOrder)
34    }
35    /// The edge kind for the non-dataflow outputs of the operation, not
36    /// described by the signature.
37    ///
38    /// If not None, a single extra output multiport of that kind will be
39    /// present.
40    #[inline]
41    fn other_output(&self) -> Option<EdgeKind> {
42        Some(EdgeKind::StateOrder)
43    }
44
45    /// The edge kind for a single constant input of the operation, not
46    /// described by the dataflow signature.
47    ///
48    /// If not None, an extra input port of that kind will be present after the
49    /// dataflow input ports and before any [`DataflowOpTrait::other_input`] ports.
50    #[inline]
51    fn static_input(&self) -> Option<EdgeKind> {
52        None
53    }
54
55    /// Apply a type-level substitution to this `OpType`, i.e. replace
56    /// [type variables](TypeArg::new_var_use) with new types.
57    fn substitute(&self, _subst: &Substitution) -> Self;
58}
59
60/// Helpers to construct input and output nodes
61pub trait IOTrait {
62    /// Construct a new I/O node from a type row with no extension requirements
63    fn new(types: impl Into<TypeRow>) -> Self;
64}
65
66/// An input node.
67/// The outputs of this node are the inputs to the function.
68#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
69#[cfg_attr(test, derive(Arbitrary))]
70pub struct Input {
71    /// Input value types
72    pub types: TypeRow,
73}
74
75impl_op_name!(Input);
76
77impl IOTrait for Input {
78    fn new(types: impl Into<TypeRow>) -> Self {
79        Input {
80            types: types.into(),
81        }
82    }
83}
84
85/// An output node. The inputs are the outputs of the function.
86#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
87#[cfg_attr(test, derive(Arbitrary))]
88pub struct Output {
89    /// Output value types
90    pub types: TypeRow,
91}
92
93impl_op_name!(Output);
94
95impl IOTrait for Output {
96    fn new(types: impl Into<TypeRow>) -> Self {
97        Output {
98            types: types.into(),
99        }
100    }
101}
102
103impl DataflowOpTrait for Input {
104    const TAG: OpTag = OpTag::Input;
105
106    fn description(&self) -> &'static str {
107        "The input node for this dataflow subgraph"
108    }
109
110    fn other_input(&self) -> Option<EdgeKind> {
111        None
112    }
113
114    fn signature(&self) -> Cow<'_, Signature> {
115        // TODO: Store a cached signature
116        Cow::Owned(Signature::new(TypeRow::new(), self.types.clone()))
117    }
118
119    fn substitute(&self, subst: &Substitution) -> Self {
120        Self {
121            types: self.types.substitute(subst),
122        }
123    }
124}
125impl DataflowOpTrait for Output {
126    const TAG: OpTag = OpTag::Output;
127
128    fn description(&self) -> &'static str {
129        "The output node for this dataflow subgraph"
130    }
131
132    // Note: We know what the input extensions should be, so we *could* give an
133    // instantiated Signature instead
134    fn signature(&self) -> Cow<'_, Signature> {
135        // TODO: Store a cached signature
136        Cow::Owned(Signature::new(self.types.clone(), TypeRow::new()))
137    }
138
139    fn other_output(&self) -> Option<EdgeKind> {
140        None
141    }
142
143    fn substitute(&self, subst: &Substitution) -> Self {
144        Self {
145            types: self.types.substitute(subst),
146        }
147    }
148}
149
150impl<T: DataflowOpTrait + Clone> OpTrait for T {
151    fn description(&self) -> &str {
152        DataflowOpTrait::description(self)
153    }
154
155    fn tag(&self) -> OpTag {
156        T::TAG
157    }
158
159    fn dataflow_signature(&self) -> Option<Cow<'_, Signature>> {
160        Some(DataflowOpTrait::signature(self))
161    }
162
163    fn other_input(&self) -> Option<EdgeKind> {
164        DataflowOpTrait::other_input(self)
165    }
166
167    fn other_output(&self) -> Option<EdgeKind> {
168        DataflowOpTrait::other_output(self)
169    }
170
171    fn static_input(&self) -> Option<EdgeKind> {
172        DataflowOpTrait::static_input(self)
173    }
174
175    fn substitute(&self, subst: &crate::types::Substitution) -> Self {
176        DataflowOpTrait::substitute(self, subst)
177    }
178}
179impl<T: DataflowOpTrait> StaticTag for T {
180    const TAG: OpTag = T::TAG;
181}
182
183/// Call a function directly.
184///
185/// The first ports correspond to the signature of the function being called.
186/// The port immediately following those those is connected to the def/declare
187/// block with a [`EdgeKind::Function`] edge.
188#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
189#[cfg_attr(test, derive(Arbitrary))]
190pub struct Call {
191    /// Signature of function being called.
192    pub func_sig: PolyFuncType,
193    /// The type arguments that instantiate `func_sig`.
194    #[cfg_attr(test, proptest(strategy = "any_serde_type_arg_vec()"))]
195    pub type_args: Vec<TypeArg>,
196    /// The instantiation of `func_sig`.
197    pub instantiation: Signature, // Cache, so we can fail in try_new() not in signature()
198}
199impl_op_name!(Call);
200
201impl DataflowOpTrait for Call {
202    const TAG: OpTag = OpTag::FnCall;
203
204    fn description(&self) -> &'static str {
205        "Call a function directly"
206    }
207
208    fn signature(&self) -> Cow<'_, Signature> {
209        Cow::Borrowed(&self.instantiation)
210    }
211
212    fn static_input(&self) -> Option<EdgeKind> {
213        Some(EdgeKind::Function(self.called_function_type().clone()))
214    }
215
216    fn substitute(&self, subst: &Substitution) -> Self {
217        let type_args = self
218            .type_args
219            .iter()
220            .map(|ta| ta.substitute(subst))
221            .collect::<Vec<_>>();
222        let instantiation = self.instantiation.substitute(subst);
223        debug_assert_eq!(
224            self.func_sig.instantiate(&type_args).as_ref(),
225            Ok(&instantiation)
226        );
227        Self {
228            type_args,
229            instantiation,
230            func_sig: self.func_sig.clone(),
231        }
232    }
233}
234impl Call {
235    /// Try to make a new Call. Returns an error if the `type_args`` do not fit the [TypeParam]s
236    /// declared by the function.
237    ///
238    /// [TypeParam]: crate::types::type_param::TypeParam
239    pub fn try_new(
240        func_sig: PolyFuncType,
241        type_args: impl Into<Vec<TypeArg>>,
242    ) -> Result<Self, SignatureError> {
243        let type_args: Vec<_> = type_args.into();
244        let instantiation = func_sig.instantiate(&type_args)?;
245        Ok(Self {
246            func_sig,
247            type_args,
248            instantiation,
249        })
250    }
251
252    #[inline]
253    /// Return the signature of the function called by this op.
254    #[must_use]
255    pub fn called_function_type(&self) -> &PolyFuncType {
256        &self.func_sig
257    }
258
259    /// The `IncomingPort` which links to the function being called.
260    ///
261    /// This matches [`OpType::static_input_port`].
262    ///
263    /// ```
264    /// # use hugr::ops::dataflow::Call;
265    /// # use hugr::ops::OpType;
266    /// # use hugr::types::Signature;
267    /// # use hugr::extension::prelude::qb_t;
268    /// # use hugr::extension::PRELUDE_REGISTRY;
269    /// let signature = Signature::new(vec![qb_t(), qb_t()], vec![qb_t(), qb_t()]);
270    /// let call = Call::try_new(signature.into(), &[]).unwrap();
271    /// let op = OpType::Call(call.clone());
272    /// assert_eq!(op.static_input_port(), Some(call.called_function_port()));
273    /// ```
274    ///
275    /// [`OpType::static_input_port`]: crate::ops::OpType::static_input_port
276    #[inline]
277    #[must_use]
278    pub fn called_function_port(&self) -> IncomingPort {
279        self.instantiation.input_count().into()
280    }
281
282    pub(crate) fn validate(&self) -> Result<(), SignatureError> {
283        let other = Self::try_new(self.func_sig.clone(), self.type_args.clone())?;
284        if other.instantiation == self.instantiation {
285            Ok(())
286        } else {
287            Err(SignatureError::CallIncorrectlyAppliesType {
288                cached: Box::new(self.instantiation.clone()),
289                expected: Box::new(other.instantiation.clone()),
290            })
291        }
292    }
293}
294
295/// Call a function indirectly. Like call, but the function input is a value
296/// (runtime, not static) dataflow edge, and thus does not need any type-args.
297#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
298#[cfg_attr(test, derive(Arbitrary))]
299pub struct CallIndirect {
300    /// Signature of function being called
301    pub signature: Signature,
302}
303impl_op_name!(CallIndirect);
304
305impl DataflowOpTrait for CallIndirect {
306    const TAG: OpTag = OpTag::DataflowChild;
307
308    fn description(&self) -> &'static str {
309        "Call a function indirectly"
310    }
311
312    fn signature(&self) -> Cow<'_, Signature> {
313        // TODO: Store a cached signature
314        let mut s = self.signature.clone();
315        s.input
316            .to_mut()
317            .insert(0, Type::new_function(self.signature.clone()));
318        Cow::Owned(s)
319    }
320
321    fn substitute(&self, subst: &Substitution) -> Self {
322        Self {
323            signature: self.signature.substitute(subst),
324        }
325    }
326}
327
328/// Load a static constant in to the local dataflow graph.
329#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
330#[cfg_attr(test, derive(Arbitrary))]
331pub struct LoadConstant {
332    /// Constant type
333    pub datatype: Type,
334}
335impl_op_name!(LoadConstant);
336impl DataflowOpTrait for LoadConstant {
337    const TAG: OpTag = OpTag::LoadConst;
338
339    fn description(&self) -> &'static str {
340        "Load a static constant in to the local dataflow graph"
341    }
342
343    fn signature(&self) -> Cow<'_, Signature> {
344        // TODO: Store a cached signature
345        Cow::Owned(Signature::new(TypeRow::new(), vec![self.datatype.clone()]))
346    }
347
348    fn static_input(&self) -> Option<EdgeKind> {
349        Some(EdgeKind::Const(self.constant_type().clone()))
350    }
351
352    fn substitute(&self, _subst: &Substitution) -> Self {
353        // Constants cannot refer to TypeArgs, so neither can loading them
354        self.clone()
355    }
356}
357
358impl LoadConstant {
359    #[inline]
360    /// The type of the constant loaded by this op.
361    #[must_use]
362    pub fn constant_type(&self) -> &Type {
363        &self.datatype
364    }
365
366    /// The `IncomingPort` which links to the loaded constant.
367    ///
368    /// This matches [`OpType::static_input_port`].
369    ///
370    /// ```
371    /// # use hugr::ops::dataflow::LoadConstant;
372    /// # use hugr::ops::OpType;
373    /// # use hugr::types::Type;
374    /// let datatype = Type::UNIT;
375    /// let load_constant = LoadConstant { datatype };
376    /// let op = OpType::LoadConstant(load_constant.clone());
377    /// assert_eq!(op.static_input_port(), Some(load_constant.constant_port()));
378    /// ```
379    ///
380    /// [`OpType::static_input_port`]: crate::ops::OpType::static_input_port
381    #[inline]
382    #[must_use]
383    pub fn constant_port(&self) -> IncomingPort {
384        0.into()
385    }
386}
387
388/// Load a static function in to the local dataflow graph.
389#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
390#[cfg_attr(test, derive(Arbitrary))]
391pub struct LoadFunction {
392    /// Signature of the function
393    pub func_sig: PolyFuncType,
394    /// The type arguments that instantiate `func_sig`.
395    #[cfg_attr(test, proptest(strategy = "any_serde_type_arg_vec()"))]
396    pub type_args: Vec<TypeArg>,
397    /// The instantiation of `func_sig`.
398    pub instantiation: Signature, // Cache, so we can fail in try_new() not in signature()
399}
400impl_op_name!(LoadFunction);
401impl DataflowOpTrait for LoadFunction {
402    const TAG: OpTag = OpTag::LoadFunc;
403
404    fn description(&self) -> &'static str {
405        "Load a static function in to the local dataflow graph"
406    }
407
408    fn signature(&self) -> Cow<'_, Signature> {
409        Cow::Owned(Signature::new(
410            type_row![],
411            Type::new_function(self.instantiation.clone()),
412        ))
413    }
414
415    fn static_input(&self) -> Option<EdgeKind> {
416        Some(EdgeKind::Function(self.func_sig.clone()))
417    }
418
419    fn substitute(&self, subst: &Substitution) -> Self {
420        let type_args = self
421            .type_args
422            .iter()
423            .map(|ta| ta.substitute(subst))
424            .collect::<Vec<_>>();
425        let instantiation = self.instantiation.substitute(subst);
426        debug_assert_eq!(
427            self.func_sig.instantiate(&type_args).as_ref(),
428            Ok(&instantiation)
429        );
430        Self {
431            func_sig: self.func_sig.clone(),
432            type_args,
433            instantiation,
434        }
435    }
436}
437impl LoadFunction {
438    /// Try to make a new LoadFunction op. Returns an error if the `type_args`` do not fit
439    /// the [TypeParam]s declared by the function.
440    ///
441    /// [TypeParam]: crate::types::type_param::TypeParam
442    pub fn try_new(
443        func_sig: PolyFuncType,
444        type_args: impl Into<Vec<TypeArg>>,
445    ) -> Result<Self, SignatureError> {
446        let type_args: Vec<_> = type_args.into();
447        let instantiation = func_sig.instantiate(&type_args)?;
448        Ok(Self {
449            func_sig,
450            type_args,
451            instantiation,
452        })
453    }
454
455    #[inline]
456    /// Return the type of the function loaded by this op.
457    #[must_use]
458    pub fn function_type(&self) -> &PolyFuncType {
459        &self.func_sig
460    }
461
462    /// The `IncomingPort` which links to the loaded function.
463    ///
464    /// This matches [`OpType::static_input_port`].
465    ///
466    /// [`OpType::static_input_port`]: crate::ops::OpType::static_input_port
467    #[inline]
468    #[must_use]
469    pub fn function_port(&self) -> IncomingPort {
470        0.into()
471    }
472
473    pub(crate) fn validate(&self) -> Result<(), SignatureError> {
474        let other = Self::try_new(self.func_sig.clone(), self.type_args.clone())?;
475        if other.instantiation == self.instantiation {
476            Ok(())
477        } else {
478            Err(SignatureError::LoadFunctionIncorrectlyAppliesType {
479                cached: Box::new(self.instantiation.clone()),
480                expected: Box::new(other.instantiation.clone()),
481            })
482        }
483    }
484}
485
486/// An operation that is the parent of a dataflow graph.
487///
488/// The children region contains an input and an output node matching the
489/// signature returned by [`DataflowParent::inner_signature`].
490pub trait DataflowParent {
491    /// Signature of the inner dataflow graph.
492    fn inner_signature(&self) -> Cow<'_, Signature>;
493}
494
495/// A simply nested dataflow graph.
496#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
497#[cfg_attr(test, derive(Arbitrary))]
498pub struct DFG {
499    /// Signature of DFG node
500    pub signature: Signature,
501}
502
503impl_op_name!(DFG);
504
505impl DataflowParent for DFG {
506    fn inner_signature(&self) -> Cow<'_, Signature> {
507        Cow::Borrowed(&self.signature)
508    }
509}
510
511impl DataflowOpTrait for DFG {
512    const TAG: OpTag = OpTag::Dfg;
513
514    fn description(&self) -> &'static str {
515        "A simply nested dataflow graph"
516    }
517
518    fn signature(&self) -> Cow<'_, Signature> {
519        self.inner_signature()
520    }
521
522    fn substitute(&self, subst: &Substitution) -> Self {
523        Self {
524            signature: self.signature.substitute(subst),
525        }
526    }
527}