1use std::borrow::Cow;
4
5use super::{impl_op_name, OpTag, OpTrait};
6
7use crate::extension::{ExtensionSet, SignatureError};
8use crate::ops::StaticTag;
9use crate::types::{EdgeKind, PolyFuncType, Signature, Substitution, Type, TypeArg, TypeRow};
10use crate::{type_row, IncomingPort};
11
12#[cfg(test)]
13use proptest_derive::Arbitrary;
14
15pub trait DataflowOpTrait: Sized {
17 const TAG: OpTag;
19
20 fn description(&self) -> &str;
22
23 fn signature(&self) -> Cow<'_, Signature>;
25
26 #[inline]
32 fn other_input(&self) -> Option<EdgeKind> {
33 Some(EdgeKind::StateOrder)
34 }
35 #[inline]
41 fn other_output(&self) -> Option<EdgeKind> {
42 Some(EdgeKind::StateOrder)
43 }
44
45 #[inline]
51 fn static_input(&self) -> Option<EdgeKind> {
52 None
53 }
54
55 fn substitute(&self, _subst: &Substitution) -> Self;
58}
59
60pub trait IOTrait {
62 fn new(types: impl Into<TypeRow>) -> Self;
64}
65
66#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
69#[cfg_attr(test, derive(Arbitrary))]
70pub struct Input {
71 pub types: TypeRow,
73}
74
75impl_op_name!(Input);
76
77impl IOTrait for Input {
78 fn new(types: impl Into<TypeRow>) -> Self {
79 Input {
80 types: types.into(),
81 }
82 }
83}
84
85#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
87#[cfg_attr(test, derive(Arbitrary))]
88pub struct Output {
89 pub types: TypeRow,
91}
92
93impl_op_name!(Output);
94
95impl IOTrait for Output {
96 fn new(types: impl Into<TypeRow>) -> Self {
97 Output {
98 types: types.into(),
99 }
100 }
101}
102
103impl DataflowOpTrait for Input {
104 const TAG: OpTag = OpTag::Input;
105
106 fn description(&self) -> &str {
107 "The input node for this dataflow subgraph"
108 }
109
110 fn other_input(&self) -> Option<EdgeKind> {
111 None
112 }
113
114 fn signature(&self) -> Cow<'_, Signature> {
115 Cow::Owned(Signature::new(TypeRow::new(), self.types.clone()))
117 }
118
119 fn substitute(&self, subst: &Substitution) -> Self {
120 Self {
121 types: self.types.substitute(subst),
122 }
123 }
124}
125impl DataflowOpTrait for Output {
126 const TAG: OpTag = OpTag::Output;
127
128 fn description(&self) -> &str {
129 "The output node for this dataflow subgraph"
130 }
131
132 fn signature(&self) -> Cow<'_, Signature> {
135 Cow::Owned(Signature::new(self.types.clone(), TypeRow::new()))
137 }
138
139 fn other_output(&self) -> Option<EdgeKind> {
140 None
141 }
142
143 fn substitute(&self, subst: &Substitution) -> Self {
144 Self {
145 types: self.types.substitute(subst),
146 }
147 }
148}
149
150impl<T: DataflowOpTrait + Clone> OpTrait for T {
151 fn description(&self) -> &str {
152 DataflowOpTrait::description(self)
153 }
154 fn tag(&self) -> OpTag {
155 T::TAG
156 }
157 fn dataflow_signature(&self) -> Option<Cow<'_, Signature>> {
158 Some(DataflowOpTrait::signature(self))
159 }
160 fn extension_delta(&self) -> ExtensionSet {
161 DataflowOpTrait::signature(self).runtime_reqs.clone()
162 }
163 fn other_input(&self) -> Option<EdgeKind> {
164 DataflowOpTrait::other_input(self)
165 }
166
167 fn other_output(&self) -> Option<EdgeKind> {
168 DataflowOpTrait::other_output(self)
169 }
170
171 fn static_input(&self) -> Option<EdgeKind> {
172 DataflowOpTrait::static_input(self)
173 }
174
175 fn substitute(&self, subst: &crate::types::Substitution) -> Self {
176 DataflowOpTrait::substitute(self, subst)
177 }
178}
179impl<T: DataflowOpTrait> StaticTag for T {
180 const TAG: OpTag = T::TAG;
181}
182
183#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
189#[cfg_attr(test, derive(Arbitrary))]
190pub struct Call {
191 pub func_sig: PolyFuncType,
193 pub type_args: Vec<TypeArg>,
195 pub instantiation: Signature, }
198impl_op_name!(Call);
199
200impl DataflowOpTrait for Call {
201 const TAG: OpTag = OpTag::FnCall;
202
203 fn description(&self) -> &str {
204 "Call a function directly"
205 }
206
207 fn signature(&self) -> Cow<'_, Signature> {
208 Cow::Borrowed(&self.instantiation)
209 }
210
211 fn static_input(&self) -> Option<EdgeKind> {
212 Some(EdgeKind::Function(self.called_function_type().clone()))
213 }
214
215 fn substitute(&self, subst: &Substitution) -> Self {
216 let type_args = self
217 .type_args
218 .iter()
219 .map(|ta| ta.substitute(subst))
220 .collect::<Vec<_>>();
221 let instantiation = self.instantiation.substitute(subst);
222 debug_assert_eq!(
223 self.func_sig.instantiate(&type_args).as_ref(),
224 Ok(&instantiation)
225 );
226 Self {
227 type_args,
228 instantiation,
229 func_sig: self.func_sig.clone(),
230 }
231 }
232}
233impl Call {
234 pub fn try_new(
239 func_sig: PolyFuncType,
240 type_args: impl Into<Vec<TypeArg>>,
241 ) -> Result<Self, SignatureError> {
242 let type_args: Vec<_> = type_args.into();
243 let instantiation = func_sig.instantiate(&type_args)?;
244 Ok(Self {
245 func_sig,
246 type_args,
247 instantiation,
248 })
249 }
250
251 #[inline]
252 pub fn called_function_type(&self) -> &PolyFuncType {
254 &self.func_sig
255 }
256
257 #[inline]
275 pub fn called_function_port(&self) -> IncomingPort {
276 self.instantiation.input_count().into()
277 }
278
279 pub(crate) fn validate(&self) -> Result<(), SignatureError> {
280 let other = Self::try_new(self.func_sig.clone(), self.type_args.clone())?;
281 if other.instantiation == self.instantiation {
282 Ok(())
283 } else {
284 Err(SignatureError::CallIncorrectlyAppliesType {
285 cached: self.instantiation.clone(),
286 expected: other.instantiation.clone(),
287 })
288 }
289 }
290}
291
292#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
295#[cfg_attr(test, derive(Arbitrary))]
296pub struct CallIndirect {
297 pub signature: Signature,
299}
300impl_op_name!(CallIndirect);
301
302impl DataflowOpTrait for CallIndirect {
303 const TAG: OpTag = OpTag::FnCall;
304
305 fn description(&self) -> &str {
306 "Call a function indirectly"
307 }
308
309 fn signature(&self) -> Cow<'_, Signature> {
310 let mut s = self.signature.clone();
312 s.input
313 .to_mut()
314 .insert(0, Type::new_function(self.signature.clone()));
315 Cow::Owned(s)
316 }
317
318 fn substitute(&self, subst: &Substitution) -> Self {
319 Self {
320 signature: self.signature.substitute(subst),
321 }
322 }
323}
324
325#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
327#[cfg_attr(test, derive(Arbitrary))]
328pub struct LoadConstant {
329 pub datatype: Type,
331}
332impl_op_name!(LoadConstant);
333impl DataflowOpTrait for LoadConstant {
334 const TAG: OpTag = OpTag::LoadConst;
335
336 fn description(&self) -> &str {
337 "Load a static constant in to the local dataflow graph"
338 }
339
340 fn signature(&self) -> Cow<'_, Signature> {
341 Cow::Owned(Signature::new(TypeRow::new(), vec![self.datatype.clone()]))
343 }
344
345 fn static_input(&self) -> Option<EdgeKind> {
346 Some(EdgeKind::Const(self.constant_type().clone()))
347 }
348
349 fn substitute(&self, _subst: &Substitution) -> Self {
350 self.clone()
352 }
353}
354
355impl LoadConstant {
356 #[inline]
357 pub fn constant_type(&self) -> &Type {
359 &self.datatype
360 }
361
362 #[inline]
378 pub fn constant_port(&self) -> IncomingPort {
379 0.into()
380 }
381}
382
383#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
385#[cfg_attr(test, derive(Arbitrary))]
386pub struct LoadFunction {
387 pub func_sig: PolyFuncType,
389 pub type_args: Vec<TypeArg>,
391 pub instantiation: Signature, }
394impl_op_name!(LoadFunction);
395impl DataflowOpTrait for LoadFunction {
396 const TAG: OpTag = OpTag::LoadFunc;
397
398 fn description(&self) -> &str {
399 "Load a static function in to the local dataflow graph"
400 }
401
402 fn signature(&self) -> Cow<'_, Signature> {
403 Cow::Owned(Signature::new(
404 type_row![],
405 Type::new_function(self.instantiation.clone()),
406 ))
407 }
408
409 fn static_input(&self) -> Option<EdgeKind> {
410 Some(EdgeKind::Function(self.func_sig.clone()))
411 }
412
413 fn substitute(&self, subst: &Substitution) -> Self {
414 let type_args = self
415 .type_args
416 .iter()
417 .map(|ta| ta.substitute(subst))
418 .collect::<Vec<_>>();
419 let instantiation = self.instantiation.substitute(subst);
420 debug_assert_eq!(
421 self.func_sig.instantiate(&type_args).as_ref(),
422 Ok(&instantiation)
423 );
424 Self {
425 func_sig: self.func_sig.clone(),
426 type_args,
427 instantiation,
428 }
429 }
430}
431impl LoadFunction {
432 pub fn try_new(
437 func_sig: PolyFuncType,
438 type_args: impl Into<Vec<TypeArg>>,
439 ) -> Result<Self, SignatureError> {
440 let type_args: Vec<_> = type_args.into();
441 let instantiation = func_sig.instantiate(&type_args)?;
442 Ok(Self {
443 func_sig,
444 type_args,
445 instantiation,
446 })
447 }
448
449 #[inline]
450 pub fn function_type(&self) -> &PolyFuncType {
452 &self.func_sig
453 }
454
455 #[inline]
461 pub fn function_port(&self) -> IncomingPort {
462 0.into()
463 }
464
465 pub(crate) fn validate(&self) -> Result<(), SignatureError> {
466 let other = Self::try_new(self.func_sig.clone(), self.type_args.clone())?;
467 if other.instantiation == self.instantiation {
468 Ok(())
469 } else {
470 Err(SignatureError::LoadFunctionIncorrectlyAppliesType {
471 cached: self.instantiation.clone(),
472 expected: other.instantiation.clone(),
473 })
474 }
475 }
476}
477
478pub trait DataflowParent {
483 fn inner_signature(&self) -> Cow<'_, Signature>;
485}
486
487#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
489#[cfg_attr(test, derive(Arbitrary))]
490pub struct DFG {
491 pub signature: Signature,
493}
494
495impl_op_name!(DFG);
496
497impl DataflowParent for DFG {
498 fn inner_signature(&self) -> Cow<'_, Signature> {
499 Cow::Borrowed(&self.signature)
500 }
501}
502
503impl DataflowOpTrait for DFG {
504 const TAG: OpTag = OpTag::Dfg;
505
506 fn description(&self) -> &str {
507 "A simply nested dataflow graph"
508 }
509
510 fn signature(&self) -> Cow<'_, Signature> {
511 self.inner_signature()
512 }
513
514 fn substitute(&self, subst: &Substitution) -> Self {
515 Self {
516 signature: self.signature.substitute(subst),
517 }
518 }
519}