use std::borrow::Cow;
use super::{OpTag, OpTrait, impl_op_name};
use crate::extension::SignatureError;
use crate::ops::StaticTag;
use crate::types::{EdgeKind, PolyFuncType, Signature, Substitution, Type, TypeArg, TypeRow};
use crate::{IncomingPort, type_row};
#[cfg(test)]
use {crate::types::proptest_utils::any_serde_type_arg_vec, proptest_derive::Arbitrary};
/// Behaviour shared by every operation that sits inside a dataflow region.
///
/// Implementors supply a static [`OpTag`], a description and a dataflow
/// [`Signature`]. The kinds of the "other" input/output ports and of the
/// static input port have defaults that individual operations may override.
pub trait DataflowOpTrait: Sized {
    /// The tag identifying this kind of operation.
    const TAG: OpTag;

    /// A human-readable description of the operation.
    fn description(&self) -> &str;

    /// The dataflow (value) signature of the operation.
    fn signature(&self) -> Cow<'_, Signature>;

    /// Edge kind accepted on the "other" input port, or `None` if there is none.
    ///
    /// Defaults to [`EdgeKind::StateOrder`].
    #[inline]
    fn other_input(&self) -> Option<EdgeKind> {
        Some(EdgeKind::StateOrder)
    }

    /// Edge kind produced on the "other" output port, or `None` if there is none.
    ///
    /// Defaults to [`EdgeKind::StateOrder`].
    #[inline]
    fn other_output(&self) -> Option<EdgeKind> {
        Some(EdgeKind::StateOrder)
    }

    /// Edge kind of the static input port, if the operation has one.
    ///
    /// Defaults to `None`.
    #[inline]
    fn static_input(&self) -> Option<EdgeKind> {
        None
    }

    /// Returns a copy of this operation with the substitution applied.
    fn substitute(&self, _subst: &Substitution) -> Self;
}
/// Shared constructor for the [`Input`] and [`Output`] nodes of a dataflow
/// subgraph.
pub trait IOTrait {
    /// Builds the node from the row of types it carries.
    fn new(types: impl Into<TypeRow>) -> Self;
}
/// The input node of a dataflow subgraph.
///
/// It has no dataflow inputs and one output per entry of `types`.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(Arbitrary))]
pub struct Input {
    /// Types of the values this node outputs.
    pub types: TypeRow,
}

impl_op_name!(Input);

impl IOTrait for Input {
    fn new(types: impl Into<TypeRow>) -> Self {
        let types: TypeRow = types.into();
        Self { types }
    }
}
/// The output node of a dataflow subgraph.
///
/// It has one dataflow input per entry of `types` and no outputs.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(Arbitrary))]
pub struct Output {
    /// Types of the values this node consumes.
    pub types: TypeRow,
}

impl_op_name!(Output);

impl IOTrait for Output {
    fn new(types: impl Into<TypeRow>) -> Self {
        let types: TypeRow = types.into();
        Self { types }
    }
}
impl DataflowOpTrait for Input {
    const TAG: OpTag = OpTag::Input;

    fn description(&self) -> &'static str {
        "The input node for this dataflow subgraph"
    }

    /// The input node accepts nothing on its "other" input port.
    fn other_input(&self) -> Option<EdgeKind> {
        None
    }

    /// No inputs; one output per entry of `self.types`.
    fn signature(&self) -> Cow<'_, Signature> {
        let outputs = self.types.clone();
        Cow::Owned(Signature::new(TypeRow::new(), outputs))
    }

    fn substitute(&self, subst: &Substitution) -> Self {
        let types = self.types.substitute(subst);
        Self { types }
    }
}
impl DataflowOpTrait for Output {
    const TAG: OpTag = OpTag::Output;

    fn description(&self) -> &'static str {
        "The output node for this dataflow subgraph"
    }

    /// One input per entry of `self.types`; no outputs.
    fn signature(&self) -> Cow<'_, Signature> {
        let inputs = self.types.clone();
        Cow::Owned(Signature::new(inputs, TypeRow::new()))
    }

    /// The output node produces nothing on its "other" output port.
    fn other_output(&self) -> Option<EdgeKind> {
        None
    }

    fn substitute(&self, subst: &Substitution) -> Self {
        let types = self.types.substitute(subst);
        Self { types }
    }
}
impl<T: DataflowOpTrait + Clone> OpTrait for T {
fn description(&self) -> &str {
DataflowOpTrait::description(self)
}
fn tag(&self) -> OpTag {
T::TAG
}
fn dataflow_signature(&self) -> Option<Cow<'_, Signature>> {
Some(DataflowOpTrait::signature(self))
}
fn other_input(&self) -> Option<EdgeKind> {
DataflowOpTrait::other_input(self)
}
fn other_output(&self) -> Option<EdgeKind> {
DataflowOpTrait::other_output(self)
}
fn static_input(&self) -> Option<EdgeKind> {
DataflowOpTrait::static_input(self)
}
fn substitute(&self, subst: &crate::types::Substitution) -> Self {
DataflowOpTrait::substitute(self, subst)
}
}
/// A dataflow operation's tag is known statically, so expose it via
/// [`StaticTag`] as well.
impl<T: DataflowOpTrait> StaticTag for T {
    const TAG: OpTag = <T as DataflowOpTrait>::TAG;
}
/// Call a function directly.
///
/// The callee arrives on a static `Function` edge. Its (possibly polymorphic)
/// type `func_sig` is instantiated with `type_args`, and the result is cached
/// in `instantiation`, which is this node's dataflow signature.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(Arbitrary))]
pub struct Call {
    /// Type of the function being called.
    pub func_sig: PolyFuncType,
    /// Type arguments applied to `func_sig`.
    #[cfg_attr(test, proptest(strategy = "any_serde_type_arg_vec()"))]
    pub type_args: Vec<TypeArg>,
    /// Cached result of instantiating `func_sig` with `type_args`.
    pub instantiation: Signature,
}

impl_op_name!(Call);
impl DataflowOpTrait for Call {
    const TAG: OpTag = OpTag::FnCall;

    fn description(&self) -> &'static str {
        "Call a function directly"
    }

    /// The cached instantiation, borrowed rather than rebuilt.
    fn signature(&self) -> Cow<'_, Signature> {
        Cow::Borrowed(&self.instantiation)
    }

    /// The callee is supplied on a static `Function` edge.
    fn static_input(&self) -> Option<EdgeKind> {
        let callee = self.called_function_type().clone();
        Some(EdgeKind::Function(callee))
    }

    /// Applies `subst` to the type arguments and to the cached instantiation.
    ///
    /// `func_sig` itself is copied unchanged. In debug builds this asserts that
    /// re-instantiating `func_sig` with the substituted arguments reproduces
    /// the substituted instantiation.
    fn substitute(&self, subst: &Substitution) -> Self {
        let mut type_args = Vec::with_capacity(self.type_args.len());
        for arg in &self.type_args {
            type_args.push(arg.substitute(subst));
        }
        let instantiation = self.instantiation.substitute(subst);
        debug_assert_eq!(
            self.func_sig.instantiate(&type_args).as_ref(),
            Ok(&instantiation)
        );
        Self {
            func_sig: self.func_sig.clone(),
            type_args,
            instantiation,
        }
    }
}
impl Call {
    /// Try to make a new `Call` of a function of type `func_sig`, instantiated
    /// with `type_args`.
    ///
    /// # Errors
    ///
    /// Returns a [`SignatureError`] if `func_sig` cannot be instantiated with
    /// `type_args`.
    pub fn try_new(
        func_sig: PolyFuncType,
        type_args: impl Into<Vec<TypeArg>>,
    ) -> Result<Self, SignatureError> {
        let type_args: Vec<_> = type_args.into();
        let instantiation = func_sig.instantiate(&type_args)?;
        Ok(Self {
            func_sig,
            type_args,
            instantiation,
        })
    }

    /// The (possibly polymorphic) type of the called function.
    #[inline]
    #[must_use]
    pub fn called_function_type(&self) -> &PolyFuncType {
        &self.func_sig
    }

    /// The incoming port on which the called function arrives.
    ///
    /// It follows all the value inputs of the instantiated signature, so its
    /// index equals that signature's input count.
    #[inline]
    #[must_use]
    pub fn called_function_port(&self) -> IncomingPort {
        self.instantiation.input_count().into()
    }

    /// Checks that the cached `instantiation` agrees with instantiating
    /// `func_sig` with `type_args`.
    ///
    /// # Errors
    ///
    /// * Any error from instantiating `func_sig` with `type_args`.
    /// * [`SignatureError::CallIncorrectlyAppliesType`] if the instantiation
    ///   succeeds but differs from the cached one.
    pub(crate) fn validate(&self) -> Result<(), SignatureError> {
        // Instantiate straight from the borrowed fields. Building a throwaway
        // `Call` through `try_new` would clone both `func_sig` and `type_args`.
        let expected = self.func_sig.instantiate(&self.type_args)?;
        if expected == self.instantiation {
            Ok(())
        } else {
            Err(SignatureError::CallIncorrectlyAppliesType {
                cached: Box::new(self.instantiation.clone()),
                // `expected` is owned and no longer needed, so move it in.
                expected: Box::new(expected),
            })
        }
    }
}
/// Call a function indirectly.
///
/// The function value to call is passed as an extra first value input, whose
/// type is a function of type `signature`.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(Arbitrary))]
pub struct CallIndirect {
    /// Signature of the function being called (excluding the function input).
    pub signature: Signature,
}

impl_op_name!(CallIndirect);
impl DataflowOpTrait for CallIndirect {
    const TAG: OpTag = OpTag::DataflowChild;

    fn description(&self) -> &'static str {
        "Call a function indirectly"
    }

    /// The called function's signature, with a value of that function type
    /// prepended as the first input.
    fn signature(&self) -> Cow<'_, Signature> {
        let callee = Type::new_function(self.signature.clone());
        let mut full = self.signature.clone();
        full.input.to_mut().insert(0, callee);
        Cow::Owned(full)
    }

    fn substitute(&self, subst: &Substitution) -> Self {
        let signature = self.signature.substitute(subst);
        Self { signature }
    }
}
/// Load a static constant into the local dataflow graph.
///
/// The constant arrives on a static `Const` edge and is output as a single
/// value of type `datatype`.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(Arbitrary))]
pub struct LoadConstant {
    /// Type of the constant being loaded.
    pub datatype: Type,
}

impl_op_name!(LoadConstant);
impl DataflowOpTrait for LoadConstant {
    const TAG: OpTag = OpTag::LoadConst;

    fn description(&self) -> &'static str {
        "Load a static constant in to the local dataflow graph"
    }

    /// No value inputs; a single output of the constant's type.
    fn signature(&self) -> Cow<'_, Signature> {
        let loaded = vec![self.datatype.clone()];
        Cow::Owned(Signature::new(TypeRow::new(), loaded))
    }

    /// The constant is supplied on a static `Const` edge.
    fn static_input(&self) -> Option<EdgeKind> {
        let ty = self.constant_type().clone();
        Some(EdgeKind::Const(ty))
    }

    /// Returns an unchanged copy; the substitution is ignored.
    // NOTE(review): presumably constant types cannot mention type variables,
    // so there is nothing to substitute. TODO confirm.
    fn substitute(&self, _subst: &Substitution) -> Self {
        self.clone()
    }
}
impl LoadConstant {
#[inline]
#[must_use]
pub fn constant_type(&self) -> &Type {
&self.datatype
}
#[inline]
#[must_use]
pub fn constant_port(&self) -> IncomingPort {
0.into()
}
}
/// Load a static function into the local dataflow graph.
///
/// The function arrives on a static `Function` edge. The node outputs a single
/// value whose function type is `func_sig` instantiated with `type_args`, and
/// that instantiation is cached in `instantiation`.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(Arbitrary))]
pub struct LoadFunction {
    /// Type of the function being loaded.
    pub func_sig: PolyFuncType,
    /// Type arguments applied to `func_sig`.
    #[cfg_attr(test, proptest(strategy = "any_serde_type_arg_vec()"))]
    pub type_args: Vec<TypeArg>,
    /// Cached result of instantiating `func_sig` with `type_args`.
    pub instantiation: Signature,
}

impl_op_name!(LoadFunction);
impl DataflowOpTrait for LoadFunction {
    const TAG: OpTag = OpTag::LoadFunc;

    fn description(&self) -> &'static str {
        "Load a static function in to the local dataflow graph"
    }

    /// No value inputs; a single output whose function type is the cached
    /// instantiation.
    fn signature(&self) -> Cow<'_, Signature> {
        let loaded = Type::new_function(self.instantiation.clone());
        Cow::Owned(Signature::new(type_row![], [loaded]))
    }

    /// The function is supplied on a static `Function` edge.
    fn static_input(&self) -> Option<EdgeKind> {
        let fn_ty = self.function_type().clone();
        Some(EdgeKind::Function(fn_ty))
    }

    /// Applies `subst` to the type arguments and to the cached instantiation.
    ///
    /// `func_sig` itself is copied unchanged. In debug builds this asserts that
    /// re-instantiating `func_sig` with the substituted arguments reproduces
    /// the substituted instantiation.
    fn substitute(&self, subst: &Substitution) -> Self {
        let mut type_args = Vec::with_capacity(self.type_args.len());
        for arg in &self.type_args {
            type_args.push(arg.substitute(subst));
        }
        let instantiation = self.instantiation.substitute(subst);
        debug_assert_eq!(
            self.func_sig.instantiate(&type_args).as_ref(),
            Ok(&instantiation)
        );
        Self {
            func_sig: self.func_sig.clone(),
            type_args,
            instantiation,
        }
    }
}
impl LoadFunction {
    /// Try to make a new `LoadFunction` for a function of type `func_sig`,
    /// instantiated with `type_args`.
    ///
    /// # Errors
    ///
    /// Returns a [`SignatureError`] if `func_sig` cannot be instantiated with
    /// `type_args`.
    pub fn try_new(
        func_sig: PolyFuncType,
        type_args: impl Into<Vec<TypeArg>>,
    ) -> Result<Self, SignatureError> {
        let type_args: Vec<_> = type_args.into();
        let instantiation = func_sig.instantiate(&type_args)?;
        Ok(Self {
            func_sig,
            type_args,
            instantiation,
        })
    }

    /// The (possibly polymorphic) type of the loaded function.
    #[inline]
    #[must_use]
    pub fn function_type(&self) -> &PolyFuncType {
        &self.func_sig
    }

    /// The incoming port on which the static function arrives (always port 0).
    #[inline]
    #[must_use]
    pub fn function_port(&self) -> IncomingPort {
        0.into()
    }

    /// Checks that the cached `instantiation` agrees with instantiating
    /// `func_sig` with `type_args`.
    ///
    /// # Errors
    ///
    /// * Any error from instantiating `func_sig` with `type_args`.
    /// * [`SignatureError::LoadFunctionIncorrectlyAppliesType`] if the
    ///   instantiation succeeds but differs from the cached one.
    pub(crate) fn validate(&self) -> Result<(), SignatureError> {
        // Instantiate straight from the borrowed fields. Building a throwaway
        // `LoadFunction` through `try_new` would clone both `func_sig` and
        // `type_args`.
        let expected = self.func_sig.instantiate(&self.type_args)?;
        if expected == self.instantiation {
            Ok(())
        } else {
            Err(SignatureError::LoadFunctionIncorrectlyAppliesType {
                cached: Box::new(self.instantiation.clone()),
                // `expected` is owned and no longer needed, so move it in.
                expected: Box::new(expected),
            })
        }
    }
}
/// Implemented by operations that contain a nested dataflow graph, such as
/// [`DFG`].
pub trait DataflowParent {
    /// The signature of the nested dataflow graph.
    fn inner_signature(&self) -> Cow<'_, Signature>;
}
/// A simply nested dataflow graph.
///
/// Its external dataflow signature is the same as the signature of the graph
/// it contains.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(Arbitrary))]
pub struct DFG {
    /// Signature of the nested graph, which is also this node's signature.
    pub signature: Signature,
}

impl_op_name!(DFG);
impl DataflowParent for DFG {
    /// The stored signature, borrowed.
    fn inner_signature(&self) -> Cow<'_, Signature> {
        let inner: &Signature = &self.signature;
        Cow::Borrowed(inner)
    }
}
impl DataflowOpTrait for DFG {
const TAG: OpTag = OpTag::Dfg;
fn description(&self) -> &'static str {
"A simply nested dataflow graph"
}
fn signature(&self) -> Cow<'_, Signature> {
self.inner_signature()
}
fn substitute(&self, subst: &Substitution) -> Self {
Self {
signature: self.signature.substitute(subst),
}
}
}