1use crate::Node;
3use crate::core::HugrNode;
4use crate::types::{Type, TypeBound};
5
6use derive_more::From as DerFrom;
7use smol_str::SmolStr;
8
9use super::{AliasDecl, OpTag};
10
11pub trait NodeHandle<N = Node>: Clone {
14    const TAG: OpTag;
16
17    fn node(&self) -> N;
19
20    #[inline]
22    fn tag(&self) -> OpTag {
23        Self::TAG
24    }
25
26    fn try_cast<T: NodeHandle<N> + From<N>>(&self) -> Option<T> {
28        T::TAG.is_superset(Self::TAG).then(|| self.node().into())
29    }
30
31    #[must_use]
33    fn can_hold(tag: OpTag) -> bool {
34        Self::TAG.is_superset(tag)
35    }
36}
37
38pub trait ContainerHandle<N = Node>: NodeHandle<N> {
42    type ChildrenHandle: NodeHandle<N>;
44}
45
46#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
47pub struct DataflowOpID<N = Node>(N);
49
50#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
51pub struct DfgID<N = Node>(N);
53
54#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
55pub struct CfgID<N = Node>(N);
57
58#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
59pub struct ModuleRootID<N = Node>(N);
61
62#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
63pub struct ModuleID<N = Node>(N);
65
66#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
67pub struct FuncID<const DEF: bool, N = Node>(N);
73
74#[derive(Debug, Clone, PartialEq, Eq)]
75pub struct AliasID<const DEF: bool, N = Node> {
81    node: N,
82    name: SmolStr,
83    bound: TypeBound,
84}
85
86impl<const DEF: bool, N> AliasID<DEF, N> {
87    pub fn new(node: N, name: SmolStr, bound: TypeBound) -> Self {
89        Self { node, name, bound }
90    }
91
92    pub fn get_alias_type(&self) -> Type {
94        Type::new_alias(AliasDecl::new(self.name.clone(), self.bound))
95    }
96    pub fn get_name(&self) -> &SmolStr {
98        &self.name
99    }
100}
101
102#[derive(DerFrom, Debug, Clone, PartialEq, Eq)]
103pub struct ConstID<N = Node>(N);
105
106#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
107pub struct BasicBlockID<N = Node>(N);
109
110#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
111pub struct CaseID<N = Node>(N);
113
114#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
115pub struct TailLoopID<N = Node>(N);
117
118#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
119pub struct ConditionalID<N = Node>(N);
121
122#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)]
123pub struct DataflowParentID<N = Node>(N);
125
/// Implements [`NodeHandle`] for a handle type `$handle<N>`.
///
/// `$op_tag` becomes the handle's `TAG`. The optional third argument names the
/// field holding the node; it defaults to `0`, the sole field of a tuple struct.
macro_rules! impl_nodehandle {
    ($handle:ident, $op_tag:expr) => {
        impl_nodehandle!($handle, $op_tag, 0);
    };
    ($handle:ident, $op_tag:expr, $field:tt) => {
        impl<N: HugrNode> NodeHandle<N> for $handle<N> {
            const TAG: OpTag = $op_tag;

            #[inline]
            fn node(&self) -> N {
                self.$field
            }
        }
    };
}
146
147impl_nodehandle!(DataflowParentID, OpTag::DataflowParent);
148impl_nodehandle!(DataflowOpID, OpTag::DataflowChild);
149impl_nodehandle!(ConditionalID, OpTag::Conditional);
150impl_nodehandle!(CaseID, OpTag::Case);
151impl_nodehandle!(DfgID, OpTag::Dfg);
152impl_nodehandle!(TailLoopID, OpTag::TailLoop);
153impl_nodehandle!(CfgID, OpTag::Cfg);
154
155impl_nodehandle!(ModuleRootID, OpTag::ModuleRoot);
156impl_nodehandle!(ModuleID, OpTag::ModuleOp);
157impl_nodehandle!(ConstID, OpTag::Const);
158
159impl_nodehandle!(BasicBlockID, OpTag::DataflowBlock);
160
161impl<const DEF: bool, N: HugrNode> NodeHandle<N> for FuncID<DEF, N> {
162    const TAG: OpTag = OpTag::Function;
163    #[inline]
164    fn node(&self) -> N {
165        self.0
166    }
167}
168
169impl<const DEF: bool, N: HugrNode> NodeHandle<N> for AliasID<DEF, N> {
170    const TAG: OpTag = OpTag::Alias;
171    #[inline]
172    fn node(&self) -> N {
173        self.node
174    }
175}
176
177impl<N: HugrNode> NodeHandle<N> for N {
178    const TAG: OpTag = OpTag::Any;
179    #[inline]
180    fn node(&self) -> N {
181        *self
182    }
183}
184
/// Implements [`ContainerHandle`] for `$parent<N>`, using `$child<N>` as the
/// handle type for its children.
macro_rules! impl_containerHandle {
    ($parent:ident, $child:ident) => {
        impl<N: HugrNode> ContainerHandle<N> for $parent<N> {
            type ChildrenHandle = $child<N>;
        }
    };
}
193
194impl_containerHandle!(DataflowParentID, DataflowOpID);
195impl_containerHandle!(DfgID, DataflowOpID);
196impl_containerHandle!(TailLoopID, DataflowOpID);
197impl_containerHandle!(ConditionalID, CaseID);
198impl_containerHandle!(CaseID, DataflowOpID);
199impl_containerHandle!(ModuleRootID, ModuleID);
200impl_containerHandle!(CfgID, BasicBlockID);
201impl_containerHandle!(BasicBlockID, DataflowOpID);
202impl<N: HugrNode> ContainerHandle<N> for FuncID<true, N> {
203    type ChildrenHandle = DataflowOpID<N>;
204}
205impl<N: HugrNode> ContainerHandle<N> for AliasID<true, N> {
206    type ChildrenHandle = DataflowOpID<N>;
207}