use crate::extension::ExtensionSet;
use crate::types::{EdgeKind, Signature, Type, TypeRow};
use crate::Direction;
use super::dataflow::{DataflowOpTrait, DataflowParent};
use super::{impl_op_name, NamedOp, OpTrait, StaticTag};
use super::{OpName, OpTag};
/// Operation for a tail-controlled loop.
///
/// The node's signature is `just_inputs` followed by `rest` as inputs, and
/// `just_outputs` followed by `rest` as outputs. The loop body takes the same
/// row as the node's inputs and produces a two-variant Sum (first variant
/// `just_inputs`, second variant `just_outputs`) followed by `rest`.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct TailLoop {
    /// Types that are inputs of the node but not outputs. They also form the
    /// first variant of the Sum produced by the body (matching the body's own
    /// inputs, so presumably the "iterate again" case — confirm).
    pub just_inputs: TypeRow,
    /// Types that are outputs of the node but not inputs. They also form the
    /// second variant of the Sum produced by the body (presumably the "exit
    /// loop" case — confirm).
    pub just_outputs: TypeRow,
    /// Types appended to the inputs and outputs of both the node and its body.
    pub rest: TypeRow,
    /// Extension delta attached to both the node signature and the body signature.
    pub extension_delta: ExtensionSet,
}
impl_op_name!(TailLoop);
impl DataflowOpTrait for TailLoop {
    const TAG: OpTag = OpTag::TailLoop;

    /// Human-readable summary of the operation.
    fn description(&self) -> &str {
        "A tail-controlled loop"
    }

    /// Node signature: `just_inputs ++ rest` in, `just_outputs ++ rest` out,
    /// carrying the loop's extension delta.
    fn signature(&self) -> Signature {
        let node_inputs = self.just_inputs.extend(self.rest.iter());
        let node_outputs = self.just_outputs.extend(self.rest.iter());
        Signature::new(node_inputs, node_outputs)
            .with_extension_delta(self.extension_delta.clone())
    }
}
impl TailLoop {
pub(crate) fn body_output_row(&self) -> TypeRow {
let sum_type = Type::new_sum([self.just_inputs.clone(), self.just_outputs.clone()]);
let mut outputs = vec![sum_type];
outputs.extend_from_slice(&self.rest);
outputs.into()
}
pub(crate) fn body_input_row(&self) -> TypeRow {
self.just_inputs.extend(self.rest.iter())
}
}
impl DataflowParent for TailLoop {
    /// Signature of the loop body. It carries the same extension delta as the node.
    fn inner_signature(&self) -> Signature {
        let body_inputs = self.body_input_row();
        let body_outputs = self.body_output_row();
        let delta = self.extension_delta.clone();
        Signature::new(body_inputs, body_outputs).with_extension_delta(delta)
    }
}
/// Conditional operation.
///
/// Its first input is a Sum with one variant per entry of `sum_rows`,
/// followed by `other_inputs`; its outputs are `outputs`. Case `i` receives
/// `sum_rows[i]` followed by `other_inputs` (see `case_input_row`).
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct Conditional {
    /// Rows of the variants of the Sum-typed first input, one per case.
    pub sum_rows: Vec<TypeRow>,
    /// Remaining inputs after the Sum; every case receives these too.
    pub other_inputs: TypeRow,
    /// Output types of the node.
    pub outputs: TypeRow,
    /// Extension delta attached to the node signature.
    pub extension_delta: ExtensionSet,
}
impl_op_name!(Conditional);
impl DataflowOpTrait for Conditional {
const TAG: OpTag = OpTag::Conditional;
fn description(&self) -> &str {
"HUGR conditional operation"
}
fn signature(&self) -> Signature {
let mut inputs = self.other_inputs.clone();
inputs
.to_mut()
.insert(0, Type::new_sum(self.sum_rows.clone()));
Signature::new(inputs, self.outputs.clone())
.with_extension_delta(self.extension_delta.clone())
}
}
impl Conditional {
    /// Input row of case number `case`: that variant's row from `sum_rows`
    /// followed by `other_inputs`. Returns `None` if `case` is out of range.
    pub(crate) fn case_input_row(&self, case: usize) -> Option<TypeRow> {
        self.sum_rows
            .get(case)
            .map(|variant| variant.extend(self.other_inputs.iter()))
    }
}
/// A dataflow node whose body is a child control-flow graph (CFG).
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct CFG {
    /// Signature of the node. `DataflowOpTrait::signature` returns it unchanged.
    pub signature: Signature,
}
impl_op_name!(CFG);
impl DataflowOpTrait for CFG {
const TAG: OpTag = OpTag::Cfg;
fn description(&self) -> &str {
"A dataflow node defined by a child CFG"
}
fn signature(&self) -> Signature {
self.signature.clone()
}
}
/// A basic block of a CFG whose body is a dataflow graph.
///
/// The body takes `inputs` and produces a Sum over `sum_rows` followed by
/// `other_outputs`. The node has one control-flow output per entry of
/// `sum_rows`.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct DataflowBlock {
    /// Types of the values the block's body receives.
    pub inputs: TypeRow,
    /// Types output after the branching Sum. They are appended to the row
    /// passed to whichever successor is selected.
    pub other_outputs: TypeRow,
    /// One row per successor: the variants of the Sum that the body outputs first.
    pub sum_rows: Vec<TypeRow>,
    /// Extension delta attached to the block's body signature.
    pub extension_delta: ExtensionSet,
}
/// The exit block of a CFG.
///
/// It has one control-flow input and no outgoing ports (see its
/// `OpTrait::non_df_port_count`).
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct ExitBlock {
    /// Row of types the exit block receives (returned by
    /// `BasicBlock::dataflow_input`). By its name, presumably the outputs of
    /// the enclosing CFG — confirm.
    pub cfg_outputs: TypeRow,
}
impl NamedOp for DataflowBlock {
fn name(&self) -> OpName {
"DataflowBlock".into()
}
}
impl NamedOp for ExitBlock {
fn name(&self) -> OpName {
"ExitBlock".into()
}
}
// Static operation tags for the two kinds of CFG basic block; each block's
// `OpTrait::tag` returns its constant.
impl StaticTag for DataflowBlock {
    const TAG: OpTag = OpTag::DataflowBlock;
}
impl StaticTag for ExitBlock {
    const TAG: OpTag = OpTag::BasicBlockExit;
}
impl DataflowParent for DataflowBlock {
    /// Signature of the block's body: `inputs` in. Out: a Sum over `sum_rows`
    /// (one variant per successor), followed by `other_outputs`.
    fn inner_signature(&self) -> Signature {
        let branch = TypeRow::from(vec![Type::new_sum(self.sum_rows.clone())]);
        let outputs = branch.extend(self.other_outputs.iter());
        Signature::new(self.inputs.clone(), outputs)
            .with_extension_delta(self.extension_delta.clone())
    }
}
impl OpTrait for DataflowBlock {
fn description(&self) -> &str {
"A CFG basic block node"
}
fn tag(&self) -> OpTag {
Self::TAG
}
fn other_input(&self) -> Option<EdgeKind> {
Some(EdgeKind::ControlFlow)
}
fn other_output(&self) -> Option<EdgeKind> {
Some(EdgeKind::ControlFlow)
}
fn extension_delta(&self) -> ExtensionSet {
self.extension_delta.clone()
}
fn non_df_port_count(&self, dir: Direction) -> usize {
match dir {
Direction::Incoming => 1,
Direction::Outgoing => self.sum_rows.len(),
}
}
}
impl OpTrait for ExitBlock {
fn description(&self) -> &str {
"A CFG exit block node"
}
fn tag(&self) -> OpTag {
Self::TAG
}
fn other_input(&self) -> Option<EdgeKind> {
Some(EdgeKind::ControlFlow)
}
fn other_output(&self) -> Option<EdgeKind> {
Some(EdgeKind::ControlFlow)
}
fn non_df_port_count(&self, dir: Direction) -> usize {
match dir {
Direction::Incoming => 1,
Direction::Outgoing => 0,
}
}
}
/// Common interface of CFG basic blocks (implemented in this module by
/// [`DataflowBlock`] and [`ExitBlock`]).
pub trait BasicBlock {
    /// The row of dataflow types the block receives as input.
    fn dataflow_input(&self) -> &TypeRow;
}
impl BasicBlock for DataflowBlock {
    /// A dataflow block's input row is its `inputs` field.
    fn dataflow_input(&self) -> &TypeRow {
        let Self { inputs, .. } = self;
        inputs
    }
}
impl DataflowBlock {
    /// Returns the row of types passed to successor number `successor`: the
    /// matching entry of `sum_rows`, followed by `other_outputs`.
    ///
    /// Returns `None` if `successor` is not a valid index into `sum_rows`.
    pub fn successor_input(&self, successor: usize) -> Option<TypeRow> {
        self.sum_rows
            .get(successor)
            .map(|variant| variant.extend(self.other_outputs.iter()))
    }
}
impl BasicBlock for ExitBlock {
    /// The exit block's input row is its `cfg_outputs` field.
    fn dataflow_input(&self) -> &TypeRow {
        let Self { cfg_outputs } = self;
        cfg_outputs
    }
}
/// A case node inside a [`Conditional`].
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct Case {
    /// Signature of the case's body. Its `extension_reqs` also serve as the
    /// node's extension delta.
    pub signature: Signature,
}
impl_op_name!(Case);
// Static operation tag for case nodes, returned by `OpTrait::tag`.
impl StaticTag for Case {
    const TAG: OpTag = OpTag::Case;
}
impl DataflowParent for Case {
fn inner_signature(&self) -> Signature {
self.signature.clone()
}
}
impl OpTrait for Case {
fn description(&self) -> &str {
"A case node inside a conditional"
}
fn extension_delta(&self) -> ExtensionSet {
self.signature.extension_reqs.clone()
}
fn tag(&self) -> OpTag {
<Self as StaticTag>::TAG
}
}
impl Case {
    /// The row of types this case receives: its signature's input row.
    pub fn dataflow_input(&self) -> &TypeRow {
        let sig = &self.signature;
        &sig.input
    }

    /// The row of types this case produces: its signature's output row.
    pub fn dataflow_output(&self) -> &TypeRow {
        let sig = &self.signature;
        &sig.output
    }
}