pub mod constant;
pub mod controlflow;
pub mod custom;
pub mod dataflow;
pub mod handle;
pub mod module;
pub mod sum;
pub mod tag;
pub mod validate;
use crate::core::HugrNode;
use crate::extension::resolution::{
ExtensionCollectionError, collect_op_extension, collect_op_types_extensions,
};
use std::borrow::Cow;
use std::cmp::Ordering;
use crate::extension::simple_op::MakeExtensionOp;
use crate::extension::{ExtensionId, ExtensionRegistry};
use crate::types::{EdgeKind, Signature, Substitution};
use crate::{Direction, Node, OutgoingPort, Port};
use crate::{IncomingPort, PortIndex};
use handle::NodeHandle;
use pastey::paste;
use enum_dispatch::enum_dispatch;
pub use constant::{Const, Value};
pub use controlflow::{BasicBlock, CFG, Case, Conditional, DataflowBlock, ExitBlock, TailLoop};
pub use custom::{ExtensionOp, OpaqueOp};
pub use dataflow::{
Call, CallIndirect, DFG, DataflowOpTrait, DataflowParent, Input, LoadConstant, LoadFunction,
Output,
};
pub use module::{AliasDecl, AliasDefn, FuncDecl, FuncDefn, Module};
use smol_str::SmolStr;
pub use sum::Tag;
pub use tag::OpTag;
#[enum_dispatch(OpTrait, NamedOp, ValidateOp, OpParent)]
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
/// The concrete operation type of a node: one variant per operation struct.
///
/// Each variant wraps the struct of the same name (e.g. `OpType::Module(Module)`),
/// and `enum_dispatch` forwards the `OpTrait`, `NamedOp`, `ValidateOp` and
/// `OpParent` methods to that wrapped value. The serialized form is internally
/// tagged: an `"op"` field holds the variant name (or its serde `rename`).
#[non_exhaustive]
#[allow(missing_docs)]
#[serde(tag = "op")]
pub enum OpType {
    Module,
    FuncDefn,
    FuncDecl,
    AliasDecl,
    AliasDefn,
    Const,
    Input,
    Output,
    Call,
    CallIndirect,
    LoadConstant,
    LoadFunction,
    DFG,
    // `ExtensionOp` and `OpaqueOp` both serialize under the tag "Extension".
    // `ExtensionOp` is skipped when deserializing, so "Extension" data always
    // deserializes to `OpaqueOp`.
    #[serde(skip_deserializing, rename = "Extension")]
    ExtensionOp,
    #[serde(rename = "Extension")]
    OpaqueOp,
    Tag,
    DataflowBlock,
    ExitBlock,
    TailLoop,
    CFG,
    Conditional,
    Case,
}
/// Returns a fixed index for the variant of `optype`, numbered in the
/// declaration order of [`OpType`]'s variants (`Module` = 0 … `Case` = 21).
///
/// Used by the `PartialOrd` implementation to order operations by variant
/// before comparing their contents. The match is exhaustive, so adding a
/// variant forces this table to be extended.
fn optype_id(optype: &OpType) -> usize {
    match optype {
        OpType::Module(_) => 0,
        OpType::FuncDefn(_) => 1,
        OpType::FuncDecl(_) => 2,
        OpType::AliasDecl(_) => 3,
        OpType::AliasDefn(_) => 4,
        OpType::Const(_) => 5,
        OpType::Input(_) => 6,
        OpType::Output(_) => 7,
        OpType::Call(_) => 8,
        OpType::CallIndirect(_) => 9,
        OpType::LoadConstant(_) => 10,
        OpType::LoadFunction(_) => 11,
        OpType::DFG(_) => 12,
        OpType::ExtensionOp(_) => 13,
        OpType::OpaqueOp(_) => 14,
        OpType::Tag(_) => 15,
        OpType::DataflowBlock(_) => 16,
        OpType::ExitBlock(_) => 17,
        OpType::TailLoop(_) => 18,
        OpType::CFG(_) => 19,
        OpType::Conditional(_) => 20,
        OpType::Case(_) => 21,
    }
}
/// Orders operations first by variant (declaration order, see `optype_id`)
/// and then, within one variant, lexicographically by their `Debug` text.
///
/// Values that are equal under `PartialEq` always compare as
/// `Some(Ordering::Equal)`, as the `PartialOrd` contract requires. Two
/// *unequal* values of the same variant whose `Debug` text coincides (e.g.
/// payloads whose `PartialEq` is not reflexive) are reported as incomparable.
impl PartialOrd for OpType {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // `PartialOrd` must agree with `PartialEq`: `a == b` iff `Some(Equal)`.
        // Previously equal values fell through to `None`, making `a <= a` false.
        if self == other {
            return Some(Ordering::Equal);
        }
        match optype_id(self).cmp(&optype_id(other)) {
            Ordering::Equal => {}
            by_variant => return Some(by_variant),
        }
        // Same variant, different contents: fall back to the textual form.
        match format!("{self:?}").cmp(&format!("{other:?}")) {
            // Unequal values that print identically cannot be ordered consistently.
            Ordering::Equal => None,
            by_text => Some(by_text),
        }
    }
}
// Generates, for an `OpType` variant `$Op` that wraps the type of the same name:
// - `OpType::as_<name>(&self) -> Option<&$Op>`, returning the wrapped value
//   when `self` is that variant;
// - `OpType::is_<name>(&self) -> bool`;
// - `impl TryFrom<&OpType> for &$Op` with error type `()`.
// `<name>` is `$sname` converted by paste's `:snake` modifier; the
// one-argument form uses `$Op` itself as `$sname`.
macro_rules! impl_op_ref_try_into {
    ($Op: tt, $sname:ident) => {
        paste! {
            impl OpType {
                #[doc = "If the operation is an instance of `" $Op "`, returns a reference to it."]
                #[must_use] pub fn [<as_ $sname:snake>](&self) -> Option<&$Op> {
                    TryInto::<&$Op>::try_into(self).ok()
                }

                #[doc = "Returns `true` if the operation is an instance of `" $Op "`."]
                #[must_use] pub fn [<is_ $sname:snake>](&self) -> bool {
                    self.[<as_ $sname:snake>]().is_some()
                }
            }

            impl<'a> TryFrom<&'a OpType> for &'a $Op {
                type Error = ();

                fn try_from(optype: &'a OpType) -> Result<Self, Self::Error> {
                    if let OpType::$Op(l) = optype {
                        Ok(l)
                    } else {
                        Err(())
                    }
                }
            }
        }
    };
    ($Op:tt) => {
        impl_op_ref_try_into!($Op, $Op);
    };
}
// Generate `as_*` / `is_*` accessors and `TryFrom<&OpType>` impls per variant.
// `DFG` and `CFG` pass explicit lower-case names (`dfg`, `cfg`), presumably
// because `:snake` would otherwise split the all-caps identifiers.
// NOTE(review): no accessor is generated for `OpaqueOp` — confirm intentional.
impl_op_ref_try_into!(Module);
impl_op_ref_try_into!(FuncDefn);
impl_op_ref_try_into!(FuncDecl);
impl_op_ref_try_into!(AliasDecl);
impl_op_ref_try_into!(AliasDefn);
impl_op_ref_try_into!(Const);
impl_op_ref_try_into!(Input);
impl_op_ref_try_into!(Output);
impl_op_ref_try_into!(Call);
impl_op_ref_try_into!(CallIndirect);
impl_op_ref_try_into!(LoadConstant);
impl_op_ref_try_into!(LoadFunction);
impl_op_ref_try_into!(DFG, dfg);
impl_op_ref_try_into!(ExtensionOp);
impl_op_ref_try_into!(Tag);
impl_op_ref_try_into!(DataflowBlock);
impl_op_ref_try_into!(ExitBlock);
impl_op_ref_try_into!(TailLoop);
impl_op_ref_try_into!(CFG, cfg);
impl_op_ref_try_into!(Conditional);
impl_op_ref_try_into!(Case);
/// The value returned by `OpType::default()`: a [`Module`] operation built
/// with `Module::new()`.
pub const DEFAULT_OPTYPE: OpType = OpType::Module(Module::new());
/// Defaults to [`DEFAULT_OPTYPE`].
impl Default for OpType {
    fn default() -> Self {
        DEFAULT_OPTYPE
    }
}
/// Displays an operation as its name (the value of `NamedOp::name`).
/// Width/fill flags on the outer formatter are not applied.
impl std::fmt::Display for OpType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let op_name = self.name();
        f.write_str(&op_name)
    }
}
impl OpType {
#[inline]
#[must_use]
pub fn other_port_kind(&self, dir: Direction) -> Option<EdgeKind> {
match dir {
Direction::Incoming => self.other_input(),
Direction::Outgoing => self.other_output(),
}
}
#[inline]
#[must_use]
pub fn static_port_kind(&self, dir: Direction) -> Option<EdgeKind> {
match dir {
Direction::Incoming => self.static_input(),
Direction::Outgoing => self.static_output(),
}
}
pub fn port_kind(&self, port: impl Into<Port>) -> Option<EdgeKind> {
let signature = self.dataflow_signature().unwrap_or_default();
let port: Port = port.into();
let dir = port.direction();
let port_count = signature.port_count(dir);
if port.index() < port_count {
return signature.port_type(port).cloned().map(EdgeKind::Value);
}
let static_kind = self.static_port_kind(dir);
if port.index() == port_count
&& let Some(kind) = static_kind
{
return Some(kind);
}
self.other_port_kind(dir)
}
#[must_use]
pub fn other_port(&self, dir: Direction) -> Option<Port> {
let df_count = self.value_port_count(dir);
let non_df_count = self.non_df_port_count(dir);
let static_input =
usize::from(dir == Direction::Incoming && OpTag::StaticInput.is_superset(self.tag()));
if self.other_port_kind(dir).is_some() && non_df_count >= 1 {
Some(Port::new(dir, df_count + static_input))
} else {
None
}
}
#[inline]
#[must_use]
pub fn other_input_port(&self) -> Option<IncomingPort> {
self.other_port(Direction::Incoming)
.map(|p| p.as_incoming().unwrap())
}
#[inline]
#[must_use]
pub fn other_output_port(&self) -> Option<OutgoingPort> {
self.other_port(Direction::Outgoing)
.map(|p| p.as_outgoing().unwrap())
}
#[inline]
#[must_use]
pub fn static_port(&self, dir: Direction) -> Option<Port> {
self.static_port_kind(dir)?;
Some(Port::new(dir, self.value_port_count(dir)))
}
#[inline]
#[must_use]
pub fn static_input_port(&self) -> Option<IncomingPort> {
self.static_port(Direction::Incoming)
.map(|p| p.as_incoming().unwrap())
}
#[inline]
#[must_use]
pub fn static_output_port(&self) -> Option<OutgoingPort> {
self.static_port(Direction::Outgoing)
.map(|p| p.as_outgoing().unwrap())
}
#[inline]
#[must_use]
pub fn value_ports(&self, dir: Direction) -> impl DoubleEndedIterator<Item = Port> {
(0..self.value_port_count(dir)).map(move |i| Port::new(dir, i))
}
#[inline]
#[must_use]
pub fn value_input_ports(&self) -> impl DoubleEndedIterator<Item = IncomingPort> {
self.value_ports(Direction::Incoming)
.map(|p| p.as_incoming().unwrap())
}
#[inline]
#[must_use]
pub fn value_output_ports(&self) -> impl DoubleEndedIterator<Item = OutgoingPort> {
self.value_ports(Direction::Outgoing)
.map(|p| p.as_outgoing().unwrap())
}
#[inline]
#[must_use]
pub fn value_port_count(&self, dir: portgraph::Direction) -> usize {
self.dataflow_signature()
.map_or(0, |sig| sig.port_count(dir))
}
#[inline]
#[must_use]
pub fn value_input_count(&self) -> usize {
self.value_port_count(Direction::Incoming)
}
#[inline]
#[must_use]
pub fn value_output_count(&self) -> usize {
self.value_port_count(Direction::Outgoing)
}
#[inline]
#[must_use]
pub fn port_count(&self, dir: Direction) -> usize {
let has_static_port = self.static_port_kind(dir).is_some();
let non_df_count = self.non_df_port_count(dir);
self.value_port_count(dir) + usize::from(has_static_port) + non_df_count
}
#[inline]
#[must_use]
pub fn input_count(&self) -> usize {
self.port_count(Direction::Incoming)
}
#[inline]
#[must_use]
pub fn output_count(&self) -> usize {
self.port_count(Direction::Outgoing)
}
#[inline]
#[must_use]
pub fn is_container(&self) -> bool {
self.validity_flags::<Node>().allowed_children != OpTag::None
}
pub fn cast<T: MakeExtensionOp>(&self) -> Option<T> {
self.as_extension_op().and_then(ExtensionOp::cast)
}
#[must_use]
pub fn extension_id(&self) -> Option<&ExtensionId> {
match self {
OpType::OpaqueOp(opaque) => Some(opaque.extension()),
OpType::ExtensionOp(e) => Some(e.def().extension_id()),
_ => None,
}
}
pub fn used_extensions(&self) -> Result<ExtensionRegistry, ExtensionCollectionError> {
let mut reg = collect_op_types_extensions(None, self)?;
if let Some(ext) = collect_op_extension(None, self)? {
reg.register_updated(ext);
}
reg.extend_with_dependencies()?;
Ok(reg)
}
}
// Implements `NamedOp` for `$op`, using the type's identifier as its name.
macro_rules! impl_op_name {
    ($op:ident) => {
        impl $crate::ops::NamedOp for $op {
            fn name(&self) -> $crate::ops::OpName {
                $crate::ops::OpName::from(stringify!($op))
            }
        }
    };
}
use impl_op_name;
/// Owned name of an operation, stored as a [`SmolStr`].
pub type OpName = SmolStr;
/// Borrowed form of [`OpName`].
pub type OpNameRef = str;
/// Operations that expose a name (dispatched through [`OpType`] by
/// `enum_dispatch`). Kept separate from [`OpTrait`] so it can be implemented
/// with the `impl_op_name!` macro.
#[enum_dispatch]
pub(crate) trait NamedOp {
    /// Returns the name of the operation.
    fn name(&self) -> OpName;
}
/// Operations whose [`OpTag`] is known at compile time.
///
/// NOTE(review): presumably `TAG` is meant to contain the dynamic tag returned
/// by `OpTrait::tag` — confirm against the implementors.
pub trait StaticTag {
    /// The statically known tag of the operation.
    const TAG: OpTag;
}
/// Behaviour shared by every operation: description, tag, port kinds and
/// type substitution. Dispatched through [`OpType`] by `enum_dispatch`.
#[enum_dispatch]
pub trait OpTrait: Sized + Clone {
    /// Returns a textual description of the operation.
    fn description(&self) -> &str;

    /// Returns the [`OpTag`] classifying this operation.
    fn tag(&self) -> OpTag;

    /// Converts `node` into the handle type `H` when `H::TAG` is a superset
    /// of this operation's tag; otherwise returns `None`.
    fn try_node_handle<N, H>(&self, node: N) -> Option<H>
    where
        N: HugrNode,
        H: NodeHandle<N> + From<N>,
    {
        if H::TAG.is_superset(self.tag()) {
            Some(H::from(node))
        } else {
            None
        }
    }

    /// Returns the dataflow signature, if the operation has one.
    /// Defaults to `None`.
    fn dataflow_signature(&self) -> Option<Cow<'_, Signature>> {
        None
    }

    /// Edge kind of the non-dataflow input port, if any. Defaults to `None`.
    fn other_input(&self) -> Option<EdgeKind> {
        None
    }

    /// Edge kind of the non-dataflow output port, if any. Defaults to `None`.
    fn other_output(&self) -> Option<EdgeKind> {
        None
    }

    /// Edge kind of the static input port, if any. Defaults to `None`.
    fn static_input(&self) -> Option<EdgeKind> {
        None
    }

    /// Edge kind of the static output port, if any. Defaults to `None`.
    fn static_output(&self) -> Option<EdgeKind> {
        None
    }

    /// Number of non-dataflow ports in direction `dir`. The default is one
    /// when `other_input` / `other_output` returns a kind, and zero otherwise.
    fn non_df_port_count(&self, dir: Direction) -> usize {
        let other_kind = match dir {
            Direction::Incoming => self.other_input(),
            Direction::Outgoing => self.other_output(),
        };
        usize::from(other_kind.is_some())
    }

    /// Applies the type substitution `subst` to the operation. The default
    /// ignores `subst` and returns an unchanged clone.
    fn substitute(&self, _subst: &Substitution) -> Self {
        Self::clone(self)
    }
}
/// Operations that may expose an inner function type (dispatched through
/// [`OpType`] by `enum_dispatch`).
#[enum_dispatch]
pub trait OpParent {
    /// Returns the operation's inner function type, if it has one.
    /// Defaults to `None`.
    fn inner_function_type(&self) -> Option<Cow<'_, Signature>> {
        None
    }
}

/// Every [`DataflowParent`] reports its [`DataflowParent::inner_signature`]
/// as its inner function type.
impl<T: DataflowParent> OpParent for T {
    fn inner_function_type(&self) -> Option<Cow<'_, Signature>> {
        let signature = <T as DataflowParent>::inner_signature(self);
        Some(signature)
    }
}
// Operations that are not `DataflowParent`s (otherwise they would conflict
// with the blanket impl) use the default `inner_function_type`, i.e. `None`.
impl OpParent for Module {}
impl OpParent for AliasDecl {}
impl OpParent for AliasDefn {}
impl OpParent for Const {}
impl OpParent for Input {}
impl OpParent for Output {}
impl OpParent for Call {}
impl OpParent for CallIndirect {}
impl OpParent for LoadConstant {}
impl OpParent for LoadFunction {}
impl OpParent for ExtensionOp {}
impl OpParent for OpaqueOp {}
impl OpParent for Tag {}
impl OpParent for CFG {}
impl OpParent for Conditional {}
impl OpParent for FuncDecl {}
impl OpParent for ExitBlock {}
/// Per-operation validation hooks (dispatched through [`OpType`] by
/// `enum_dispatch`).
#[enum_dispatch]
pub trait ValidateOp {
    /// Returns the validity flags for this operation. The default is
    /// `OpValidityFlags::default()`.
    #[inline]
    fn validity_flags<N: HugrNode>(&self) -> validate::OpValidityFlags<N> {
        validate::OpValidityFlags::default()
    }

    /// Checks the operation's children, given as `(node, op)` pairs.
    /// The default accepts any children.
    #[inline]
    fn validate_op_children<'a, N: HugrNode>(
        &self,
        _children: impl DoubleEndedIterator<Item = (N, &'a OpType)>,
    ) -> Result<(), validate::ChildrenValidationError<N>> {
        Ok(())
    }
}
// Implements `ValidateOp` for `$op`, relying entirely on the trait's defaults.
macro_rules! impl_validate_op {
    ($op:ident) => {
        impl $crate::ops::ValidateOp for $op {}
    };
}
use impl_validate_op;