1use std::{cmp::Ordering, hash::Hash, marker::PhantomData, ops::Deref};
6
7use hugr_core::{
8 core::HugrNode,
9 hugr::{views::HierarchyView, HugrError},
10 ops::{DataflowBlock, ExitBlock, Input, NamedOp, OpType, Output, CFG},
11 types::Type,
12 Hugr, HugrView, IncomingPort, Node, NodeIndex, OutgoingPort,
13};
14use itertools::Itertools as _;
15
16#[derive(Debug)]
29pub struct FatNode<'hugr, OT = OpType, H = Hugr, N = Node>
30where
31 H: ?Sized,
32{
33 hugr: &'hugr H,
34 node: N,
35 marker: PhantomData<OT>,
36}
37
38impl<'hugr, OT, H: HugrView + ?Sized> FatNode<'hugr, OT, H, H::Node>
39where
40 for<'a> &'a OpType: TryInto<&'a OT>,
41{
42 pub fn new(hugr: &'hugr H, node: H::Node, #[allow(unused)] ot: &OT) -> Self {
50 assert!(hugr.valid_node(node));
51 assert!(TryInto::<&OT>::try_into(hugr.get_optype(node)).is_ok());
52 Self {
54 hugr,
55 node,
56 marker: PhantomData,
57 }
58 }
59
60 pub fn try_new(hugr: &'hugr H, node: H::Node) -> Option<Self> {
66 (hugr.valid_node(node)).then_some(())?;
67 Some(Self::new(
68 hugr,
69 node,
70 hugr.get_optype(node).try_into().ok()?,
71 ))
72 }
73
74 pub fn generalise(self) -> FatNode<'hugr, OpType, H, H::Node> {
76 FatNode {
78 hugr: self.hugr,
79 node: self.node,
80 marker: PhantomData,
81 }
82 }
83}
84
85impl<'hugr, OT, H, N: HugrNode> FatNode<'hugr, OT, H, N> {
86 pub fn node(&self) -> N {
88 self.node
89 }
90
91 pub fn hugr(&self) -> &'hugr H {
93 self.hugr
94 }
95}
96
97impl<'hugr, H: HugrView + ?Sized> FatNode<'hugr, OpType, H, H::Node> {
98 pub fn new_optype(hugr: &'hugr H, node: H::Node) -> Self {
102 assert!(hugr.valid_node(node));
103 FatNode::new(hugr, node, hugr.get_optype(node))
104 }
105
106 pub fn try_into_ot<OT>(&self) -> Option<FatNode<'hugr, OT, H, H::Node>>
108 where
109 for<'a> &'a OpType: TryInto<&'a OT>,
110 {
111 FatNode::try_new(self.hugr, self.node)
112 }
113
114 pub fn into_ot<OT>(self, ot: &OT) -> FatNode<'hugr, OT, H, H::Node>
122 where
123 for<'a> &'a OpType: TryInto<&'a OT>,
124 {
125 FatNode::new(self.hugr, self.node, ot)
126 }
127}
128
129impl<'hugr, OT, H: HugrView + ?Sized> FatNode<'hugr, OT, H, H::Node> {
130 #[allow(clippy::type_complexity)]
133 pub fn single_linked_output(
134 &self,
135 port: IncomingPort,
136 ) -> Option<(FatNode<'hugr, OpType, H, H::Node>, OutgoingPort)> {
137 self.hugr
138 .single_linked_output(self.node, port)
139 .map(|(n, p)| (FatNode::new_optype(self.hugr, n), p))
140 }
141
142 pub fn out_value_types(&self) -> impl Iterator<Item = (OutgoingPort, Type)> + 'hugr {
145 self.hugr.out_value_types(self.node)
146 }
147
148 pub fn in_value_types(&self) -> impl Iterator<Item = (IncomingPort, Type)> + 'hugr {
151 self.hugr.in_value_types(self.node)
152 }
153
154 pub fn children(&self) -> impl Iterator<Item = FatNode<'hugr, OpType, H, H::Node>> + 'hugr {
156 self.hugr
157 .children(self.node)
158 .map(|n| FatNode::new_optype(self.hugr, n))
159 }
160
161 #[allow(clippy::type_complexity)]
164 pub fn get_io(
165 &self,
166 ) -> Option<(
167 FatNode<'hugr, Input, H, H::Node>,
168 FatNode<'hugr, Output, H, H::Node>,
169 )> {
170 let [i, o] = self.hugr.get_io(self.node)?;
171 Some((
172 FatNode::try_new(self.hugr, i)?,
173 FatNode::try_new(self.hugr, o)?,
174 ))
175 }
176
177 pub fn node_outputs(&self) -> impl Iterator<Item = OutgoingPort> + 'hugr {
179 self.hugr.node_outputs(self.node)
180 }
181
182 pub fn output_neighbours(
184 &self,
185 ) -> impl Iterator<Item = FatNode<'hugr, OpType, H, H::Node>> + 'hugr {
186 self.hugr
187 .output_neighbours(self.node)
188 .map(|n| FatNode::new_optype(self.hugr, n))
189 }
190
191 pub fn try_new_hierarchy_view<HV: HierarchyView<'hugr, Node = H::Node>>(
193 &self,
194 ) -> Result<HV, HugrError>
195 where
196 H: Sized,
197 {
198 HV::try_new(self.hugr, self.node)
199 }
200}
201
202impl<'hugr, H: HugrView> FatNode<'hugr, CFG, H, H::Node> {
203 #[allow(clippy::type_complexity)]
208 pub fn get_entry_exit(
209 &self,
210 ) -> (
211 FatNode<'hugr, DataflowBlock, H, H::Node>,
212 FatNode<'hugr, ExitBlock, H, H::Node>,
213 ) {
214 let [i, o] = self
215 .hugr
216 .children(self.node)
217 .take(2)
218 .collect_vec()
219 .try_into()
220 .unwrap();
221 (
222 FatNode::try_new(self.hugr, i).unwrap(),
223 FatNode::try_new(self.hugr, o).unwrap(),
224 )
225 }
226}
227
228impl<OT, H> PartialEq<Node> for FatNode<'_, OT, H, Node> {
229 fn eq(&self, other: &Node) -> bool {
230 &self.node == other
231 }
232}
233
234impl<OT, H> PartialEq<FatNode<'_, OT, H, Node>> for Node {
235 fn eq(&self, other: &FatNode<'_, OT, H, Node>) -> bool {
236 self == &other.node
237 }
238}
239
240impl<N: PartialEq, OT1, OT2, H1, H2> PartialEq<FatNode<'_, OT1, H1, N>>
241 for FatNode<'_, OT2, H2, N>
242{
243 fn eq(&self, other: &FatNode<'_, OT1, H1, N>) -> bool {
244 self.node == other.node
245 }
246}
247
248impl<N: Eq, OT, H> Eq for FatNode<'_, OT, H, N> {}
249
250impl<OT, H> PartialOrd<Node> for FatNode<'_, OT, H> {
251 fn partial_cmp(&self, other: &Node) -> Option<Ordering> {
252 self.node.partial_cmp(other)
253 }
254}
255
256impl<OT, H> PartialOrd<FatNode<'_, OT, H>> for Node {
257 fn partial_cmp(&self, other: &FatNode<'_, OT, H>) -> Option<Ordering> {
258 self.partial_cmp(&other.node)
259 }
260}
261
262impl<N: PartialOrd, OT1, OT2, H1, H2> PartialOrd<FatNode<'_, OT1, H1, N>>
263 for FatNode<'_, OT2, H2, N>
264{
265 fn partial_cmp(&self, other: &FatNode<'_, OT1, H1, N>) -> Option<Ordering> {
266 self.node.partial_cmp(&other.node)
267 }
268}
269
270impl<OT, H, N: Ord> Ord for FatNode<'_, OT, H, N> {
271 fn cmp(&self, other: &Self) -> Ordering {
272 self.node.cmp(&other.node)
273 }
274}
275
276impl<OT, H, N: Hash> Hash for FatNode<'_, OT, H, N> {
277 fn hash<HA: std::hash::Hasher>(&self, state: &mut HA) {
278 self.node.hash(state);
279 }
280}
281
282impl<OT, H: HugrView + ?Sized> AsRef<OT> for FatNode<'_, OT, H, H::Node>
283where
284 for<'a> &'a OpType: TryInto<&'a OT>,
285{
286 fn as_ref(&self) -> &OT {
287 self.hugr.get_optype(self.node).try_into().ok().unwrap()
288 }
289}
290
291impl<OT, H: HugrView + ?Sized> Deref for FatNode<'_, OT, H, H::Node>
292where
293 for<'a> &'a OpType: TryInto<&'a OT>,
294{
295 type Target = OT;
296
297 fn deref(&self) -> &Self::Target {
298 self.as_ref()
299 }
300}
301
302impl<OT, H> Copy for FatNode<'_, OT, H> {}
303
304impl<OT, H> Clone for FatNode<'_, OT, H> {
305 fn clone(&self) -> Self {
306 *self
307 }
308}
309
310impl<OT: NamedOp, H: HugrView + ?Sized> std::fmt::Display for FatNode<'_, OT, H, H::Node>
311where
312 for<'a> &'a OpType: TryInto<&'a OT>,
313{
314 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315 f.write_fmt(format_args!("N<{}:{}>", self.as_ref().name(), self.node))
316 }
317}
318
319impl<OT, H, N: NodeIndex> NodeIndex for FatNode<'_, OT, H, N> {
320 fn index(self) -> usize {
321 self.node.index()
322 }
323}
324
325impl<OT, H> NodeIndex for &FatNode<'_, OT, H> {
326 fn index(self) -> usize {
327 self.node.index()
328 }
329}
330
331pub trait FatExt: HugrView {
337 fn try_fat<OT>(&self, node: Self::Node) -> Option<FatNode<OT, Self, Self::Node>>
339 where
340 for<'a> &'a OpType: TryInto<&'a OT>,
341 {
342 FatNode::try_new(self, node)
343 }
344
345 fn fat_optype(&self, node: Self::Node) -> FatNode<OpType, Self, Self::Node> {
347 FatNode::new_optype(self, node)
348 }
349
350 #[allow(clippy::type_complexity)]
353 fn fat_io(
354 &self,
355 node: Self::Node,
356 ) -> Option<(
357 FatNode<Input, Self, Self::Node>,
358 FatNode<Output, Self, Self::Node>,
359 )> {
360 self.fat_optype(node).get_io()
361 }
362
363 fn fat_children(
365 &self,
366 node: Self::Node,
367 ) -> impl Iterator<Item = FatNode<OpType, Self, Self::Node>> {
368 self.children(node).map(|x| self.fat_optype(x))
369 }
370
371 fn fat_root<OT>(&self) -> Option<FatNode<OT, Self, Self::Node>>
373 where
374 for<'a> &'a OpType: TryInto<&'a OT>,
375 {
376 self.try_fat(self.root())
377 }
378}
379
380impl<H: HugrView> FatExt for H {}