1use std::borrow::Cow;
4
5use super::{OpTag, OpTrait, impl_op_name};
6
7use crate::extension::SignatureError;
8use crate::ops::StaticTag;
9use crate::types::{EdgeKind, PolyFuncType, Signature, Substitution, Type, TypeArg, TypeRow};
10use crate::{IncomingPort, type_row};
11
12#[cfg(test)]
13use proptest_derive::Arbitrary;
14
15pub trait DataflowOpTrait: Sized {
17 const TAG: OpTag;
19
20 fn description(&self) -> &str;
22
23 fn signature(&self) -> Cow<'_, Signature>;
25
26 #[inline]
32 fn other_input(&self) -> Option<EdgeKind> {
33 Some(EdgeKind::StateOrder)
34 }
35 #[inline]
41 fn other_output(&self) -> Option<EdgeKind> {
42 Some(EdgeKind::StateOrder)
43 }
44
45 #[inline]
51 fn static_input(&self) -> Option<EdgeKind> {
52 None
53 }
54
55 fn substitute(&self, _subst: &Substitution) -> Self;
58}
59
/// Constructor shared by the [`Input`] and [`Output`] boundary nodes.
pub trait IOTrait {
    /// Creates the node from the row of types it carries.
    fn new(types: impl Into<TypeRow>) -> Self;
}
65
66#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
69#[cfg_attr(test, derive(Arbitrary))]
70pub struct Input {
71 pub types: TypeRow,
73}
74
75impl_op_name!(Input);
76
77impl IOTrait for Input {
78 fn new(types: impl Into<TypeRow>) -> Self {
79 Input {
80 types: types.into(),
81 }
82 }
83}
84
85#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
87#[cfg_attr(test, derive(Arbitrary))]
88pub struct Output {
89 pub types: TypeRow,
91}
92
93impl_op_name!(Output);
94
95impl IOTrait for Output {
96 fn new(types: impl Into<TypeRow>) -> Self {
97 Output {
98 types: types.into(),
99 }
100 }
101}
102
103impl DataflowOpTrait for Input {
104 const TAG: OpTag = OpTag::Input;
105
106 fn description(&self) -> &'static str {
107 "The input node for this dataflow subgraph"
108 }
109
110 fn other_input(&self) -> Option<EdgeKind> {
111 None
112 }
113
114 fn signature(&self) -> Cow<'_, Signature> {
115 Cow::Owned(Signature::new(TypeRow::new(), self.types.clone()))
117 }
118
119 fn substitute(&self, subst: &Substitution) -> Self {
120 Self {
121 types: self.types.substitute(subst),
122 }
123 }
124}
125impl DataflowOpTrait for Output {
126 const TAG: OpTag = OpTag::Output;
127
128 fn description(&self) -> &'static str {
129 "The output node for this dataflow subgraph"
130 }
131
132 fn signature(&self) -> Cow<'_, Signature> {
135 Cow::Owned(Signature::new(self.types.clone(), TypeRow::new()))
137 }
138
139 fn other_output(&self) -> Option<EdgeKind> {
140 None
141 }
142
143 fn substitute(&self, subst: &Substitution) -> Self {
144 Self {
145 types: self.types.substitute(subst),
146 }
147 }
148}
149
150impl<T: DataflowOpTrait + Clone> OpTrait for T {
151 fn description(&self) -> &str {
152 DataflowOpTrait::description(self)
153 }
154
155 fn tag(&self) -> OpTag {
156 T::TAG
157 }
158
159 fn dataflow_signature(&self) -> Option<Cow<'_, Signature>> {
160 Some(DataflowOpTrait::signature(self))
161 }
162
163 fn other_input(&self) -> Option<EdgeKind> {
164 DataflowOpTrait::other_input(self)
165 }
166
167 fn other_output(&self) -> Option<EdgeKind> {
168 DataflowOpTrait::other_output(self)
169 }
170
171 fn static_input(&self) -> Option<EdgeKind> {
172 DataflowOpTrait::static_input(self)
173 }
174
175 fn substitute(&self, subst: &crate::types::Substitution) -> Self {
176 DataflowOpTrait::substitute(self, subst)
177 }
178}
179impl<T: DataflowOpTrait> StaticTag for T {
180 const TAG: OpTag = T::TAG;
181}
182
183#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
189#[cfg_attr(test, derive(Arbitrary))]
190pub struct Call {
191 pub func_sig: PolyFuncType,
193 pub type_args: Vec<TypeArg>,
195 pub instantiation: Signature, }
198impl_op_name!(Call);
199
200impl DataflowOpTrait for Call {
201 const TAG: OpTag = OpTag::FnCall;
202
203 fn description(&self) -> &'static str {
204 "Call a function directly"
205 }
206
207 fn signature(&self) -> Cow<'_, Signature> {
208 Cow::Borrowed(&self.instantiation)
209 }
210
211 fn static_input(&self) -> Option<EdgeKind> {
212 Some(EdgeKind::Function(self.called_function_type().clone()))
213 }
214
215 fn substitute(&self, subst: &Substitution) -> Self {
216 let type_args = self
217 .type_args
218 .iter()
219 .map(|ta| ta.substitute(subst))
220 .collect::<Vec<_>>();
221 let instantiation = self.instantiation.substitute(subst);
222 debug_assert_eq!(
223 self.func_sig.instantiate(&type_args).as_ref(),
224 Ok(&instantiation)
225 );
226 Self {
227 type_args,
228 instantiation,
229 func_sig: self.func_sig.clone(),
230 }
231 }
232}
233impl Call {
234 pub fn try_new(
239 func_sig: PolyFuncType,
240 type_args: impl Into<Vec<TypeArg>>,
241 ) -> Result<Self, SignatureError> {
242 let type_args: Vec<_> = type_args.into();
243 let instantiation = func_sig.instantiate(&type_args)?;
244 Ok(Self {
245 func_sig,
246 type_args,
247 instantiation,
248 })
249 }
250
251 #[inline]
252 #[must_use]
254 pub fn called_function_type(&self) -> &PolyFuncType {
255 &self.func_sig
256 }
257
258 #[inline]
276 #[must_use]
277 pub fn called_function_port(&self) -> IncomingPort {
278 self.instantiation.input_count().into()
279 }
280
281 pub(crate) fn validate(&self) -> Result<(), SignatureError> {
282 let other = Self::try_new(self.func_sig.clone(), self.type_args.clone())?;
283 if other.instantiation == self.instantiation {
284 Ok(())
285 } else {
286 Err(SignatureError::CallIncorrectlyAppliesType {
287 cached: self.instantiation.clone(),
288 expected: other.instantiation.clone(),
289 })
290 }
291 }
292}
293
294#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
297#[cfg_attr(test, derive(Arbitrary))]
298pub struct CallIndirect {
299 pub signature: Signature,
301}
302impl_op_name!(CallIndirect);
303
304impl DataflowOpTrait for CallIndirect {
305 const TAG: OpTag = OpTag::DataflowChild;
306
307 fn description(&self) -> &'static str {
308 "Call a function indirectly"
309 }
310
311 fn signature(&self) -> Cow<'_, Signature> {
312 let mut s = self.signature.clone();
314 s.input
315 .to_mut()
316 .insert(0, Type::new_function(self.signature.clone()));
317 Cow::Owned(s)
318 }
319
320 fn substitute(&self, subst: &Substitution) -> Self {
321 Self {
322 signature: self.signature.substitute(subst),
323 }
324 }
325}
326
327#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
329#[cfg_attr(test, derive(Arbitrary))]
330pub struct LoadConstant {
331 pub datatype: Type,
333}
334impl_op_name!(LoadConstant);
335impl DataflowOpTrait for LoadConstant {
336 const TAG: OpTag = OpTag::LoadConst;
337
338 fn description(&self) -> &'static str {
339 "Load a static constant in to the local dataflow graph"
340 }
341
342 fn signature(&self) -> Cow<'_, Signature> {
343 Cow::Owned(Signature::new(TypeRow::new(), vec![self.datatype.clone()]))
345 }
346
347 fn static_input(&self) -> Option<EdgeKind> {
348 Some(EdgeKind::Const(self.constant_type().clone()))
349 }
350
351 fn substitute(&self, _subst: &Substitution) -> Self {
352 self.clone()
354 }
355}
356
357impl LoadConstant {
358 #[inline]
359 #[must_use]
361 pub fn constant_type(&self) -> &Type {
362 &self.datatype
363 }
364
365 #[inline]
381 #[must_use]
382 pub fn constant_port(&self) -> IncomingPort {
383 0.into()
384 }
385}
386
387#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
389#[cfg_attr(test, derive(Arbitrary))]
390pub struct LoadFunction {
391 pub func_sig: PolyFuncType,
393 pub type_args: Vec<TypeArg>,
395 pub instantiation: Signature, }
398impl_op_name!(LoadFunction);
399impl DataflowOpTrait for LoadFunction {
400 const TAG: OpTag = OpTag::LoadFunc;
401
402 fn description(&self) -> &'static str {
403 "Load a static function in to the local dataflow graph"
404 }
405
406 fn signature(&self) -> Cow<'_, Signature> {
407 Cow::Owned(Signature::new(
408 type_row![],
409 Type::new_function(self.instantiation.clone()),
410 ))
411 }
412
413 fn static_input(&self) -> Option<EdgeKind> {
414 Some(EdgeKind::Function(self.func_sig.clone()))
415 }
416
417 fn substitute(&self, subst: &Substitution) -> Self {
418 let type_args = self
419 .type_args
420 .iter()
421 .map(|ta| ta.substitute(subst))
422 .collect::<Vec<_>>();
423 let instantiation = self.instantiation.substitute(subst);
424 debug_assert_eq!(
425 self.func_sig.instantiate(&type_args).as_ref(),
426 Ok(&instantiation)
427 );
428 Self {
429 func_sig: self.func_sig.clone(),
430 type_args,
431 instantiation,
432 }
433 }
434}
435impl LoadFunction {
436 pub fn try_new(
441 func_sig: PolyFuncType,
442 type_args: impl Into<Vec<TypeArg>>,
443 ) -> Result<Self, SignatureError> {
444 let type_args: Vec<_> = type_args.into();
445 let instantiation = func_sig.instantiate(&type_args)?;
446 Ok(Self {
447 func_sig,
448 type_args,
449 instantiation,
450 })
451 }
452
453 #[inline]
454 #[must_use]
456 pub fn function_type(&self) -> &PolyFuncType {
457 &self.func_sig
458 }
459
460 #[inline]
466 #[must_use]
467 pub fn function_port(&self) -> IncomingPort {
468 0.into()
469 }
470
471 pub(crate) fn validate(&self) -> Result<(), SignatureError> {
472 let other = Self::try_new(self.func_sig.clone(), self.type_args.clone())?;
473 if other.instantiation == self.instantiation {
474 Ok(())
475 } else {
476 Err(SignatureError::LoadFunctionIncorrectlyAppliesType {
477 cached: self.instantiation.clone(),
478 expected: other.instantiation.clone(),
479 })
480 }
481 }
482}
483
/// Implemented by operations that are the parent of a dataflow graph, such as [`DFG`].
pub trait DataflowParent {
    /// Signature of the contained dataflow graph.
    fn inner_signature(&self) -> Cow<'_, Signature>;
}
492
493#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
495#[cfg_attr(test, derive(Arbitrary))]
496pub struct DFG {
497 pub signature: Signature,
499}
500
501impl_op_name!(DFG);
502
503impl DataflowParent for DFG {
504 fn inner_signature(&self) -> Cow<'_, Signature> {
505 Cow::Borrowed(&self.signature)
506 }
507}
508
509impl DataflowOpTrait for DFG {
510 const TAG: OpTag = OpTag::Dfg;
511
512 fn description(&self) -> &'static str {
513 "A simply nested dataflow graph"
514 }
515
516 fn signature(&self) -> Cow<'_, Signature> {
517 self.inner_signature()
518 }
519
520 fn substitute(&self, subst: &Substitution) -> Self {
521 Self {
522 signature: self.signature.substitute(subst),
523 }
524 }
525}