use std::{cmp::Ordering, hash::Hash, marker::PhantomData, ops::Deref};
use hugr::{
hugr::{views::HierarchyView, HugrError},
ops::{DataflowBlock, ExitBlock, Input, NamedOp, OpType, Output, CFG},
types::Type,
Hugr, HugrView, IncomingPort, Node, NodeIndex, OutgoingPort,
};
use itertools::Itertools as _;
/// A [`Node`] paired with a reference to the HUGR `H` that contains it.
///
/// The `OT` parameter records the optype the node is expected to have. The
/// constructors check that the node's optype converts to `&OT`, which is what
/// allows a `FatNode` to be dereferenced to an `OT`.
///
/// Only the HUGR reference and the node handle are stored at runtime; `OT` is
/// carried purely at the type level. Note that the derived `Debug` also formats
/// the referenced HUGR (and therefore requires `H: Debug`).
#[derive(Debug)]
pub struct FatNode<'c, OT = OpType, H: ?Sized = Hugr> {
    /// The HUGR that contains `node`.
    hugr: &'c H,
    /// The wrapped node handle.
    node: Node,
    /// Zero-sized type-level tag for the expected optype.
    marker: PhantomData<OT>,
}
impl<'c, OT, H: HugrView + ?Sized> FatNode<'c, OT, H>
where
    for<'a> &'a OpType: TryInto<&'a OT>,
{
    /// Creates a `FatNode` for `node` in `hugr`, typed by the optype `OT`.
    ///
    /// `ot` is not read by the body (hence `#[allow(unused)]`). Its type fixes
    /// `OT` at the call site.
    ///
    /// # Panics
    ///
    /// Panics if `node` is not a valid node of `hugr`, or if its optype does not
    /// convert to `&OT`.
    pub fn new(hugr: &'c H, node: Node, #[allow(unused)] ot: &OT) -> Self {
        assert!(hugr.valid_node(node));
        assert!(TryInto::<&OT>::try_into(hugr.get_optype(node)).is_ok());
        Self {
            hugr,
            node,
            marker: PhantomData,
        }
    }

    /// Creates a `FatNode` for `node` in `hugr`.
    ///
    /// Returns `None` in either of these cases:
    /// - `node` is not a valid node of `hugr`;
    /// - the node's optype does not convert to `&OT`.
    pub fn try_new(hugr: &'c H, node: Node) -> Option<Self> {
        // `||` short-circuits, so `get_optype` is only called on valid nodes.
        // These are exactly the invariants `new` asserts. Having checked them,
        // build the value directly instead of re-validating and re-converting
        // through `new`.
        if !hugr.valid_node(node) || TryInto::<&OT>::try_into(hugr.get_optype(node)).is_err() {
            return None;
        }
        Some(Self {
            hugr,
            node,
            marker: PhantomData,
        })
    }
}

impl<'c, OT, H: ?Sized> FatNode<'c, OT, H> {
    /// Forgets the optype marker, returning an untyped `FatNode<OpType>` for the
    /// same node in the same HUGR.
    ///
    /// No HUGR access or optype conversion is needed, so this carries no bounds.
    pub fn generalise(self) -> FatNode<'c, OpType, H> {
        FatNode {
            hugr: self.hugr,
            node: self.node,
            marker: PhantomData,
        }
    }
}
// `H: ?Sized` matches the struct definition. This makes the basic accessors
// available for every `FatNode`, not only those whose HUGR type is `Sized`.
impl<'c, OT, H: ?Sized> FatNode<'c, OT, H> {
    /// Returns the wrapped [`Node`].
    pub fn node(&self) -> Node {
        self.node
    }

    /// Returns the HUGR containing this node, with its full `'c` lifetime.
    pub fn hugr(&self) -> &'c H {
        self.hugr
    }
}
impl<'c, H: HugrView + ?Sized> FatNode<'c, OpType, H> {
    /// Creates an untyped `FatNode` (with `OT = OpType`) for `node` in `hugr`.
    ///
    /// # Panics
    ///
    /// Panics if `node` is not a valid node of `hugr`.
    pub fn new_optype(hugr: &'c H, node: Node) -> Self {
        // Validate first so `get_optype` is never reached for an invalid node.
        assert!(hugr.valid_node(node));
        let optype = hugr.get_optype(node);
        Self::new(hugr, node, optype)
    }

    /// Attempts to re-type this node as a `FatNode<OT>`.
    ///
    /// Returns `None` if the node's optype does not convert to `&OT`.
    pub fn try_into_ot<OT>(&self) -> Option<FatNode<'c, OT, H>>
    where
        for<'a> &'a OpType: TryInto<&'a OT>,
    {
        let (hugr, node) = (self.hugr, self.node);
        FatNode::try_new(hugr, node)
    }

    /// Re-types this node as a `FatNode<OT>`, using `ot` to fix `OT`.
    ///
    /// # Panics
    ///
    /// Panics if the node's optype does not convert to `&OT`.
    pub fn into_ot<OT>(self, ot: &OT) -> FatNode<'c, OT, H>
    where
        for<'a> &'a OpType: TryInto<&'a OT>,
    {
        let Self { hugr, node, .. } = self;
        FatNode::new(hugr, node, ot)
    }
}
impl<'c, OT, H: HugrView + ?Sized> FatNode<'c, OT, H> {
    /// Delegates to [`HugrView::single_linked_output`] for `port` of this node.
    ///
    /// The linked node is returned as an untyped `FatNode`, together with its
    /// outgoing port.
    pub fn single_linked_output(
        &self,
        port: IncomingPort,
    ) -> Option<(FatNode<'c, OpType, H>, OutgoingPort)> {
        let hugr = self.hugr;
        let (source, source_port) = hugr.single_linked_output(self.node, port)?;
        Some((FatNode::new_optype(hugr, source), source_port))
    }

    /// Delegates to [`HugrView::out_value_types`] for this node.
    pub fn out_value_types(&self) -> impl Iterator<Item = (OutgoingPort, Type)> + 'c {
        let hugr = self.hugr;
        hugr.out_value_types(self.node)
    }

    /// Delegates to [`HugrView::in_value_types`] for this node.
    pub fn in_value_types(&self) -> impl Iterator<Item = (IncomingPort, Type)> + 'c {
        let hugr = self.hugr;
        hugr.in_value_types(self.node)
    }

    /// Iterates over this node's children, each wrapped as an untyped `FatNode`.
    pub fn children(&self) -> impl Iterator<Item = FatNode<'c, OpType, H>> + 'c {
        // Copy the `&'c H` out so the closure owns it and does not borrow `self`.
        let hugr = self.hugr;
        hugr.children(self.node)
            .map(move |child| FatNode::new_optype(hugr, child))
    }

    /// Returns this node's `Input` and `Output` children as typed `FatNode`s.
    ///
    /// Returns `None` in either of these cases:
    /// - [`HugrView::get_io`] returns `None`;
    /// - either node fails to convert to its expected optype.
    pub fn get_io(&self) -> Option<(FatNode<'c, Input, H>, FatNode<'c, Output, H>)> {
        let hugr = self.hugr;
        let [input, output] = hugr.get_io(self.node)?;
        let input: FatNode<'c, Input, H> = FatNode::try_new(hugr, input)?;
        let output: FatNode<'c, Output, H> = FatNode::try_new(hugr, output)?;
        Some((input, output))
    }

    /// Delegates to [`HugrView::node_outputs`] for this node.
    pub fn node_outputs(&self) -> impl Iterator<Item = OutgoingPort> + 'c {
        let hugr = self.hugr;
        hugr.node_outputs(self.node)
    }

    /// Iterates over the nodes reported by [`HugrView::output_neighbours`].
    ///
    /// Each neighbour is wrapped as an untyped `FatNode`.
    pub fn output_neighbours(&self) -> impl Iterator<Item = FatNode<'c, OpType, H>> + 'c {
        let hugr = self.hugr;
        hugr.output_neighbours(self.node)
            .map(move |neighbour| FatNode::new_optype(hugr, neighbour))
    }

    /// Builds a hierarchy view `HV` rooted at this node.
    ///
    /// Requires `H: Sized` because it forwards to [`HierarchyView::try_new`].
    pub fn try_new_hierarchy_view<HV: HierarchyView<'c>>(&self) -> Result<HV, HugrError>
    where
        H: Sized,
    {
        let (hugr, root) = (self.hugr, self.node);
        HV::try_new(hugr, root)
    }
}
// `?Sized` matches the other `FatNode` helper impls. Everything used here is
// available for unsized `H`.
impl<'c, H: HugrView + ?Sized> FatNode<'c, CFG, H> {
    /// Returns the entry and exit blocks of this CFG node.
    ///
    /// The first child is returned as the entry (a `DataflowBlock`). The second
    /// child is returned as the exit (an `ExitBlock`).
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - the CFG node has fewer than two children;
    /// - its first child is not a `DataflowBlock`;
    /// - its second child is not an `ExitBlock`.
    pub fn get_entry_exit(&self) -> (FatNode<'c, DataflowBlock, H>, FatNode<'c, ExitBlock, H>) {
        let hugr = self.hugr;
        let (entry, exit) = hugr
            .children(self.node)
            .take(2)
            .collect_tuple()
            .expect("CFG node must have at least two children (entry and exit blocks)");
        let entry: FatNode<'c, DataflowBlock, H> = FatNode::try_new(hugr, entry)
            .expect("first child of a CFG node must be a DataflowBlock");
        let exit: FatNode<'c, ExitBlock, H> = FatNode::try_new(hugr, exit)
            .expect("second child of a CFG node must be an ExitBlock");
        (entry, exit)
    }
}
/// A `FatNode` equals a bare [`Node`] when it wraps that node.
///
/// The HUGR reference and optype marker play no part in the comparison.
impl<OT, H> PartialEq<Node> for FatNode<'_, OT, H> {
    fn eq(&self, other: &Node) -> bool {
        self.node.eq(other)
    }
}
/// Mirror of `FatNode == Node`, so a bare [`Node`] can be compared with a `FatNode`.
impl<OT, H> PartialEq<FatNode<'_, OT, H>> for Node {
    fn eq(&self, other: &FatNode<'_, OT, H>) -> bool {
        <Node as PartialEq<Node>>::eq(self, &other.node)
    }
}
/// Two `FatNode`s are equal when they wrap the same node.
///
/// This holds regardless of their optype markers or HUGR types.
impl<OT1, OT2, H1, H2> PartialEq<FatNode<'_, OT1, H1>> for FatNode<'_, OT2, H2> {
    fn eq(&self, other: &FatNode<'_, OT1, H1>) -> bool {
        let (lhs, rhs) = (&self.node, &other.node);
        lhs.eq(rhs)
    }
}

/// Equality on the wrapped node is a full equivalence relation.
impl<OT, H> Eq for FatNode<'_, OT, H> {}
/// Orders a `FatNode` against a bare [`Node`] by the wrapped node.
impl<OT, H> PartialOrd<Node> for FatNode<'_, OT, H> {
    fn partial_cmp(&self, other: &Node) -> Option<Ordering> {
        PartialOrd::partial_cmp(&self.node, other)
    }
}
/// Orders a bare [`Node`] against a `FatNode` by the wrapped node.
impl<OT, H> PartialOrd<FatNode<'_, OT, H>> for Node {
    fn partial_cmp(&self, other: &FatNode<'_, OT, H>) -> Option<Ordering> {
        <Node as PartialOrd>::partial_cmp(self, &other.node)
    }
}
/// Orders two `FatNode`s by their wrapped nodes.
///
/// Optype markers and HUGR types are ignored, consistent with the `PartialEq`
/// and `Ord` impls.
impl<OT1, OT2, H1, H2> PartialOrd<FatNode<'_, OT1, H1>> for FatNode<'_, OT2, H2> {
    fn partial_cmp(&self, other: &FatNode<'_, OT1, H1>) -> Option<Ordering> {
        // Compare the nodes directly rather than bouncing through the
        // `PartialOrd<Node>` impl via `self.partial_cmp(&other.node)`.
        self.node.partial_cmp(&other.node)
    }
}
/// Total order on `FatNode`s, given by the order of their wrapped nodes.
impl<OT, H> Ord for FatNode<'_, OT, H> {
    fn cmp(&self, other: &Self) -> Ordering {
        Ord::cmp(&self.node, &other.node)
    }
}
/// Hashes only the wrapped node, matching the node-only equality above.
impl<OT, H> Hash for FatNode<'_, OT, H> {
    fn hash<S: std::hash::Hasher>(&self, hasher: &mut S) {
        Hash::hash(&self.node, hasher);
    }
}
impl<OT, H: HugrView + ?Sized> AsRef<OT> for FatNode<'_, OT, H>
where
    for<'a> &'a OpType: TryInto<&'a OT>,
{
    /// Borrows the node's optype as an `OT`.
    ///
    /// # Panics
    ///
    /// Panics if the optype does not convert to `&OT`. The `FatNode`
    /// constructors check this conversion, so a panic here indicates a broken
    /// invariant.
    fn as_ref(&self) -> &OT {
        self.hugr
            .get_optype(self.node)
            .try_into()
            // The conversion error type is not required to be `Debug`, so go via
            // `Option` and report the node and expected optype ourselves.
            .ok()
            .unwrap_or_else(|| {
                panic!(
                    "FatNode invariant violated: optype of node {} does not convert to {}",
                    self.node,
                    std::any::type_name::<OT>()
                )
            })
    }
}
/// Lets a `FatNode<OT>` be used directly as its optype `OT`.
///
/// This forwards to the `AsRef<OT>` impl.
impl<OT, H: HugrView + ?Sized> Deref for FatNode<'_, OT, H>
where
    for<'a> &'a OpType: TryInto<&'a OT>,
{
    type Target = OT;

    fn deref(&self) -> &OT {
        <Self as AsRef<OT>>::as_ref(self)
    }
}
/// Cloning a `FatNode` copies its HUGR reference and node handle.
///
/// Unlike `#[derive(Clone, Copy)]`, these impls place no bounds on `OT` or `H`.
impl<OT, H> Clone for FatNode<'_, OT, H> {
    fn clone(&self) -> Self {
        // `FatNode` is `Copy`, so the canonical clone is a plain copy.
        *self
    }
}

impl<OT, H> Copy for FatNode<'_, OT, H> {}
/// Formats the node as `N<{optype name}:{node}>`.
impl<OT: NamedOp, H: HugrView + ?Sized> std::fmt::Display for FatNode<'_, OT, H>
where
    for<'a> &'a OpType: TryInto<&'a OT>,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "N<{}:{}>", self.as_ref().name(), self.node)
    }
}
/// A `FatNode` has the same index as its wrapped node.
impl<OT, H> NodeIndex for FatNode<'_, OT, H> {
    fn index(self) -> usize {
        let Self { node, .. } = self;
        node.index()
    }
}
impl<OT, H> NodeIndex for &FatNode<'_, OT, H> {
fn index(self) -> usize {
self.node.index()
}
}
/// Extension methods for building `FatNode`s directly from a [`HugrView`].
pub trait FatExt: HugrView {
    /// Returns a `FatNode<OT>` for `node`.
    ///
    /// Returns `None` in either of these cases:
    /// - `node` is not valid in `self`;
    /// - its optype does not convert to `&OT`.
    fn try_fat<OT>(&self, node: Node) -> Option<FatNode<'_, OT, Self>>
    where
        for<'a> &'a OpType: TryInto<&'a OT>,
    {
        FatNode::try_new(self, node)
    }

    /// Returns an untyped `FatNode` for `node`.
    ///
    /// # Panics
    ///
    /// Panics if `node` is not a valid node of `self`.
    fn fat_optype(&self, node: Node) -> FatNode<'_, OpType, Self> {
        FatNode::new_optype(self, node)
    }

    /// Returns the `Input` and `Output` children of `node`, if it has them.
    ///
    /// # Panics
    ///
    /// Panics if `node` is not a valid node of `self`.
    fn fat_io(&self, node: Node) -> Option<(FatNode<'_, Input, Self>, FatNode<'_, Output, Self>)> {
        let parent = self.fat_optype(node);
        parent.get_io()
    }

    /// Iterates over the children of `node`, each wrapped as an untyped `FatNode`.
    fn fat_children(&self, node: Node) -> impl Iterator<Item = FatNode<'_, OpType, Self>> {
        self.children(node)
            .map(move |child| self.fat_optype(child))
    }

    /// Returns the root of `self` as a `FatNode<OT>`.
    ///
    /// Returns `None` if the root's optype does not convert to `&OT`.
    fn fat_root<OT>(&self) -> Option<FatNode<'_, OT, Self>>
    where
        for<'a> &'a OpType: TryInto<&'a OT>,
    {
        let root = self.root();
        self.try_fat(root)
    }
}
/// Every (sized) [`HugrView`] gets the [`FatExt`] helpers.
impl<T> FatExt for T where T: HugrView {}