use std::{cmp::Ordering, fmt, hash::Hash, marker::PhantomData, ops::Deref};
use hugr_core::hugr::views::Rerooted;
use hugr_core::{
Hugr, HugrView, IncomingPort, Node, NodeIndex, OutgoingPort,
core::HugrNode,
ops::{CFG, DataflowBlock, ExitBlock, Input, Module, OpType, Output},
types::Type,
};
use itertools::Itertools as _;
/// A node identifier bundled with a shared reference to the hugr it lives in.
///
/// Type parameters:
/// - `OT`: the operation type the node is known to hold. It is tracked only at
///   the type level through a [`PhantomData`] marker; no `OT` value is stored.
/// - `H`: the borrowed hugr type (may be unsized).
/// - `N`: the node identifier type.
#[derive(Debug)]
pub struct FatNode<'hugr, OT = OpType, H: ?Sized = Hugr, N = Node> {
    /// The hugr that `node` refers into.
    hugr: &'hugr H,
    /// Identifier of the node within `hugr`.
    node: N,
    /// Zero-sized marker recording the operation type `OT`.
    marker: PhantomData<OT>,
}
impl<'hugr, OT, H: HugrView + ?Sized> FatNode<'hugr, OT, H, H::Node>
where
    for<'a> &'a OpType: TryInto<&'a OT>,
{
    /// Creates a `FatNode` for `node` in `hugr`.
    ///
    /// The `ot` argument only pins the type parameter `OT`; its value is never
    /// inspected, so it is not compared against the node's actual optype.
    ///
    /// # Panics
    ///
    /// Panics if `hugr` does not contain `node`, or if the node's optype cannot
    /// be converted into `&OT`.
    pub fn new(hugr: &'hugr H, node: H::Node, #[allow(unused)] ot: &OT) -> Self {
        assert!(
            hugr.contains_node(node),
            "node is not present in the hugr"
        );
        assert!(
            TryInto::<&OT>::try_into(hugr.get_optype(node)).is_ok(),
            "node's optype does not match the requested operation type"
        );
        Self {
            hugr,
            node,
            marker: PhantomData,
        }
    }

    /// Fallible counterpart of [`FatNode::new`].
    ///
    /// Returns `None` if `hugr` does not contain `node`, or if the node's
    /// optype cannot be converted into `&OT`.
    pub fn try_new(hugr: &'hugr H, node: H::Node) -> Option<Self> {
        // Validate exactly once and build the value directly instead of routing
        // through `new`, which would repeat both checks as assertions.
        // `&&` short-circuits, so the optype is only looked up for nodes that
        // the hugr actually contains.
        let is_ot = hugr.contains_node(node)
            && TryInto::<&OT>::try_into(hugr.get_optype(node)).is_ok();
        is_ot.then_some(Self {
            hugr,
            node,
            marker: PhantomData,
        })
    }

    /// Forgets the specific operation type, yielding a `FatNode` over the
    /// general [`OpType`] for the same hugr and node.
    pub fn generalise(self) -> FatNode<'hugr, OpType, H, H::Node> {
        FatNode {
            hugr: self.hugr,
            node: self.node,
            marker: PhantomData,
        }
    }
}
impl<'hugr, OT, H, N: HugrNode> FatNode<'hugr, OT, H, N> {
    /// Returns the node identifier held by this `FatNode`.
    pub fn node(&self) -> N {
        let Self { node, .. } = self;
        *node
    }

    /// Returns the hugr reference held by this `FatNode`, with its full
    /// `'hugr` lifetime.
    pub fn hugr(&self) -> &'hugr H {
        let Self { hugr, .. } = self;
        *hugr
    }
}
impl<'hugr, H: HugrView + ?Sized> FatNode<'hugr, OpType, H, H::Node> {
    /// Creates a general (`OpType`-typed) `FatNode` for `node` in `hugr`.
    ///
    /// # Panics
    ///
    /// Panics if `hugr` does not contain `node`.
    pub fn new_optype(hugr: &'hugr H, node: H::Node) -> Self {
        // Check membership up front: the optype lookup below happens before
        // `new` gets a chance to run its own assertions.
        assert!(hugr.contains_node(node));
        let optype = hugr.get_optype(node);
        Self::new(hugr, node, optype)
    }

    /// Attempts to narrow this general `FatNode` to the operation type `OT`.
    ///
    /// Returns `None` when [`FatNode::try_new`] rejects the node for `OT`.
    pub fn try_into_ot<OT>(&self) -> Option<FatNode<'hugr, OT, H, H::Node>>
    where
        for<'a> &'a OpType: TryInto<&'a OT>,
    {
        let (hugr, node) = (self.hugr, self.node);
        FatNode::try_new(hugr, node)
    }

    /// Narrows this general `FatNode` to the operation type `OT`.
    ///
    /// `ot` is forwarded to [`FatNode::new`], which uses it only to fix `OT`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`FatNode::new`].
    pub fn into_ot<OT>(self, ot: &OT) -> FatNode<'hugr, OT, H, H::Node>
    where
        for<'a> &'a OpType: TryInto<&'a OT>,
    {
        let (hugr, node) = (self.hugr, self.node);
        FatNode::new(hugr, node, ot)
    }
}
impl<'hugr, OT, H: HugrView + ?Sized> FatNode<'hugr, OT, H, H::Node> {
    /// Looks up the single outgoing port linked to `port` on this node, via
    /// the hugr's `single_linked_output`.
    ///
    /// On success the linked node is returned as a general `FatNode` together
    /// with its outgoing port; `None` is passed through from the hugr.
    #[allow(clippy::type_complexity)]
    pub fn single_linked_output(
        &self,
        port: IncomingPort,
    ) -> Option<(FatNode<'hugr, OpType, H, H::Node>, OutgoingPort)> {
        let (source, source_port) = self.hugr.single_linked_output(self.node, port)?;
        Some((FatNode::new_optype(self.hugr, source), source_port))
    }

    /// Forwards to the hugr's `out_value_types` for this node, yielding
    /// `(OutgoingPort, Type)` pairs.
    pub fn out_value_types(
        &self,
    ) -> impl Iterator<Item = (OutgoingPort, Type)> + 'hugr + use<'hugr, OT, H> {
        let hugr: &'hugr H = self.hugr;
        hugr.out_value_types(self.node)
    }

    /// Forwards to the hugr's `in_value_types` for this node, yielding
    /// `(IncomingPort, Type)` pairs.
    pub fn in_value_types(
        &self,
    ) -> impl Iterator<Item = (IncomingPort, Type)> + 'hugr + use<'hugr, OT, H> {
        let hugr: &'hugr H = self.hugr;
        hugr.in_value_types(self.node)
    }

    /// Iterates over this node's children, each wrapped as a general `FatNode`.
    pub fn children(
        &self,
    ) -> impl Iterator<Item = FatNode<'hugr, OpType, H, H::Node>> + 'hugr + use<'hugr, OT, H> {
        // Copy the `'hugr` reference out so the closure does not borrow `self`.
        let hugr: &'hugr H = self.hugr;
        hugr.children(self.node)
            .map(move |child| FatNode::new_optype(hugr, child))
    }

    /// Returns the input and output nodes reported by the hugr's `get_io` for
    /// this node, as typed `FatNode`s.
    ///
    /// Returns `None` if the hugr reports no such pair, or if either node
    /// fails the [`FatNode::try_new`] check for `Input` / `Output`.
    #[allow(clippy::type_complexity)]
    pub fn get_io(
        &self,
    ) -> Option<(
        FatNode<'hugr, Input, H, H::Node>,
        FatNode<'hugr, Output, H, H::Node>,
    )> {
        let [input, output] = self.hugr.get_io(self.node)?;
        let input = FatNode::try_new(self.hugr, input)?;
        let output = FatNode::try_new(self.hugr, output)?;
        Some((input, output))
    }

    /// Forwards to the hugr's `node_outputs` for this node.
    pub fn node_outputs(&self) -> impl Iterator<Item = OutgoingPort> + 'hugr + use<'hugr, OT, H> {
        let hugr: &'hugr H = self.hugr;
        hugr.node_outputs(self.node)
    }

    /// Iterates over the nodes reported by the hugr's `output_neighbours` for
    /// this node, each wrapped as a general `FatNode`.
    pub fn output_neighbours(
        &self,
    ) -> impl Iterator<Item = FatNode<'hugr, OpType, H, H::Node>> + 'hugr + use<'hugr, OT, H> {
        // Copy the `'hugr` reference out so the closure does not borrow `self`.
        let hugr: &'hugr H = self.hugr;
        hugr.output_neighbours(self.node)
            .map(move |neighbour| FatNode::new_optype(hugr, neighbour))
    }

    /// Returns a view of the hugr whose entrypoint is this node, via the
    /// hugr's `with_entrypoint`.
    pub fn as_entrypoint(&self) -> Rerooted<&H>
    where
        H: Sized,
    {
        let entrypoint = self.node;
        self.hugr.with_entrypoint(entrypoint)
    }
}
impl<'hugr, H: HugrView> FatNode<'hugr, CFG, H, H::Node> {
    /// Returns the entry and exit blocks of this CFG node.
    ///
    /// The first two children of the node are taken to be, in order, the entry
    /// [`DataflowBlock`] and the [`ExitBlock`].
    ///
    /// # Panics
    ///
    /// Panics if the node has fewer than two children, or if the first two
    /// children are not a `DataflowBlock` and an `ExitBlock` respectively.
    #[allow(clippy::type_complexity)]
    pub fn get_entry_exit(
        &self,
    ) -> (
        FatNode<'hugr, DataflowBlock, H, H::Node>,
        FatNode<'hugr, ExitBlock, H, H::Node>,
    ) {
        // Pull exactly the first two children off the iterator; no intermediate
        // `Vec` is needed just to destructure a pair.
        let (entry, exit): (H::Node, H::Node) = self
            .hugr
            .children(self.node)
            .next_tuple()
            .expect("CFG node must have an entry block and an exit block as its first two children");
        (
            FatNode::try_new(self.hugr, entry)
                .expect("first child of a CFG node must be a DataflowBlock"),
            FatNode::try_new(self.hugr, exit)
                .expect("second child of a CFG node must be an ExitBlock"),
        )
    }
}
/// A `FatNode` equals a bare [`Node`] when its node identifier equals it; the
/// hugr reference and operation type are not part of the comparison.
impl<OT, H> PartialEq<Node> for FatNode<'_, OT, H, Node> {
    fn eq(&self, other: &Node) -> bool {
        self.node == *other
    }
}
/// A bare [`Node`] equals a `FatNode` when it equals the `FatNode`'s node
/// identifier (the mirror of the `FatNode == Node` comparison).
impl<OT, H> PartialEq<FatNode<'_, OT, H, Node>> for Node {
    fn eq(&self, other: &FatNode<'_, OT, H, Node>) -> bool {
        *self == other.node
    }
}
/// Two `FatNode`s are equal when their node identifiers are equal.
///
/// The operation types and hugr types may differ between the two sides, and
/// the hugr references themselves are not compared.
impl<N: PartialEq, OT1, OT2, H1, H2> PartialEq<FatNode<'_, OT1, H1, N>>
    for FatNode<'_, OT2, H2, N>
{
    fn eq(&self, other: &FatNode<'_, OT1, H1, N>) -> bool {
        let (lhs, rhs) = (&self.node, &other.node);
        PartialEq::eq(lhs, rhs)
    }
}
/// Equality on `FatNode` delegates to the node identifier, so it is a full
/// equivalence relation whenever the identifier's equality is.
impl<OT, H, N> Eq for FatNode<'_, OT, H, N> where N: Eq {}
/// Orders a `FatNode` against a bare [`Node`] by comparing node identifiers.
impl<OT, H> PartialOrd<Node> for FatNode<'_, OT, H, Node> {
    fn partial_cmp(&self, other: &Node) -> Option<Ordering> {
        PartialOrd::partial_cmp(&self.node, other)
    }
}
/// Orders a bare [`Node`] against a `FatNode` by comparing node identifiers.
impl<OT, H> PartialOrd<FatNode<'_, OT, H, Node>> for Node {
    fn partial_cmp(&self, other: &FatNode<'_, OT, H, Node>) -> Option<Ordering> {
        PartialOrd::partial_cmp(self, &other.node)
    }
}
/// Orders two `FatNode`s by their node identifiers.
///
/// The operation types and hugr types may differ between the two sides; only
/// the node identifiers take part in the comparison.
impl<N: PartialOrd, OT1, OT2, H1, H2> PartialOrd<FatNode<'_, OT1, H1, N>>
    for FatNode<'_, OT2, H2, N>
{
    fn partial_cmp(&self, other: &FatNode<'_, OT1, H1, N>) -> Option<Ordering> {
        let (lhs, rhs) = (&self.node, &other.node);
        PartialOrd::partial_cmp(lhs, rhs)
    }
}
/// Total order on `FatNode`s, given by the order of their node identifiers.
impl<OT, H, N: Ord> Ord for FatNode<'_, OT, H, N> {
    fn cmp(&self, other: &Self) -> Ordering {
        Ord::cmp(&self.node, &other.node)
    }
}
/// Hashes only the node identifier, matching the equality impls, which also
/// look at nothing but the node identifier.
impl<OT, H, N: Hash> Hash for FatNode<'_, OT, H, N> {
    fn hash<HA: std::hash::Hasher>(&self, state: &mut HA) {
        Hash::hash(&self.node, state);
    }
}
/// Gives access to the node's operation as the specific type `OT`.
impl<OT, H: HugrView + ?Sized> AsRef<OT> for FatNode<'_, OT, H, H::Node>
where
    for<'a> &'a OpType: TryInto<&'a OT>,
{
    /// Looks up the node's optype in the hugr and converts it to `&OT`.
    ///
    /// # Panics
    ///
    /// Panics if the conversion to `&OT` fails.
    fn as_ref(&self) -> &OT {
        let optype: &OpType = self.hugr.get_optype(self.node);
        // The conversion error is discarded with `ok()` before unwrapping.
        let specific: Option<&OT> = optype.try_into().ok();
        specific.unwrap()
    }
}
/// Dereferences a `FatNode` to its operation `OT`, so the operation's fields
/// and methods can be used directly on the `FatNode`.
impl<OT, H: HugrView + ?Sized> Deref for FatNode<'_, OT, H, H::Node>
where
    for<'a> &'a OpType: TryInto<&'a OT>,
{
    type Target = OT;

    /// Delegates to the [`AsRef<OT>`] implementation.
    fn deref(&self) -> &Self::Target {
        <Self as AsRef<OT>>::as_ref(self)
    }
}
// `Copy`/`Clone` are written by hand rather than derived so that no `Clone`
// bounds are placed on `OT` or `H`: a `FatNode` only holds a shared reference
// to `H`, a node identifier, and a zero-sized marker for `OT`.
//
// The impls apply to every copyable node identifier type and to unsized `H`,
// not only to the default `N = Node` with a sized hugr.

/// A `FatNode` is freely copyable whenever its node identifier is.
impl<OT, H: ?Sized, N: Copy> Copy for FatNode<'_, OT, H, N> {}

/// Cloning a `FatNode` is a plain bitwise copy.
impl<OT, H: ?Sized, N: Copy> Clone for FatNode<'_, OT, H, N> {
    fn clone(&self) -> Self {
        *self
    }
}
/// Formats a `FatNode` as `N<op:node>`, where `op` is the `Display` output of
/// the node's operation and `node` is the `Display` output of its identifier.
impl<OT: fmt::Display, H: HugrView + ?Sized> fmt::Display for FatNode<'_, OT, H, H::Node>
where
    for<'a> &'a OpType: TryInto<&'a OT>,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let op: &OT = self.as_ref();
        write!(f, "N<{}:{}>", op, self.node)
    }
}
/// A `FatNode`'s index is the index of its node identifier.
impl<OT, H, N: NodeIndex> NodeIndex for FatNode<'_, OT, H, N> {
    fn index(self) -> usize {
        N::index(self.node)
    }
}
impl<OT, H> NodeIndex for &FatNode<'_, OT, H> {
fn index(self) -> usize {
self.node.index()
}
}
/// Extension methods on [`HugrView`] that hand out [`FatNode`]s instead of
/// bare node identifiers.
pub trait FatExt: HugrView {
    /// Wraps `node` as a `FatNode` typed by the operation type `OT`.
    ///
    /// Returns `None` when [`FatNode::try_new`] rejects the node.
    fn try_fat<OT>(&self, node: Self::Node) -> Option<FatNode<'_, OT, Self, Self::Node>>
    where
        for<'a> &'a OpType: TryInto<&'a OT>,
    {
        let fat = FatNode::try_new(self, node)?;
        Some(fat)
    }

    /// Wraps `node` as a general (`OpType`-typed) `FatNode`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`FatNode::new_optype`].
    fn fat_optype(&self, node: Self::Node) -> FatNode<'_, OpType, Self, Self::Node> {
        let hugr = self;
        FatNode::new_optype(hugr, node)
    }

    /// Returns the input and output nodes of `node` as typed `FatNode`s, via
    /// [`FatNode::get_io`].
    #[allow(clippy::type_complexity)]
    fn fat_io(
        &self,
        node: Self::Node,
    ) -> Option<(
        FatNode<'_, Input, Self, Self::Node>,
        FatNode<'_, Output, Self, Self::Node>,
    )> {
        let parent = self.fat_optype(node);
        parent.get_io()
    }

    /// Iterates over the children of `node`, each wrapped as a general
    /// `FatNode`.
    fn fat_children(
        &self,
        node: Self::Node,
    ) -> impl Iterator<Item = FatNode<'_, OpType, Self, Self::Node>> {
        let children = self.children(node);
        children.map(|child| self.fat_optype(child))
    }

    /// Wraps the hugr's module root as a [`Module`]-typed `FatNode`.
    ///
    /// Returns `None` when [`FatExt::try_fat`] rejects the root node.
    fn fat_root(&self) -> Option<FatNode<'_, Module, Self, Self::Node>> {
        let root = self.module_root();
        self.try_fat(root)
    }

    /// Wraps the hugr's entrypoint as a `FatNode` typed by `OT`.
    ///
    /// Returns `None` when [`FatExt::try_fat`] rejects the entrypoint node.
    fn fat_entrypoint<OT>(&self) -> Option<FatNode<'_, OT, Self, Self::Node>>
    where
        for<'a> &'a OpType: TryInto<&'a OT>,
    {
        let entrypoint = self.entrypoint();
        self.try_fat(entrypoint)
    }
}
/// Every (sized) [`HugrView`] gets the [`FatExt`] methods for free.
impl<H> FatExt for H where H: HugrView {}