1use std::{cmp::Ordering, fmt, hash::Hash, marker::PhantomData, ops::Deref};
6
7use hugr_core::hugr::views::Rerooted;
8use hugr_core::{
9 Hugr, HugrView, IncomingPort, Node, NodeIndex, OutgoingPort,
10 core::HugrNode,
11 ops::{CFG, DataflowBlock, ExitBlock, Input, Module, OpType, Output},
12 types::Type,
13};
14use itertools::Itertools as _;
15
16#[derive(Debug)]
29pub struct FatNode<'hugr, OT = OpType, H = Hugr, N = Node>
30where
31 H: ?Sized,
32{
33 hugr: &'hugr H,
34 node: N,
35 marker: PhantomData<OT>,
36}
37
38impl<'hugr, OT, H: HugrView + ?Sized> FatNode<'hugr, OT, H, H::Node>
39where
40 for<'a> &'a OpType: TryInto<&'a OT>,
41{
42 pub fn new(hugr: &'hugr H, node: H::Node, #[allow(unused)] ot: &OT) -> Self {
50 assert!(hugr.contains_node(node));
51 assert!(TryInto::<&OT>::try_into(hugr.get_optype(node)).is_ok());
52 Self {
54 hugr,
55 node,
56 marker: PhantomData,
57 }
58 }
59
60 pub fn try_new(hugr: &'hugr H, node: H::Node) -> Option<Self> {
66 (hugr.contains_node(node)).then_some(())?;
67 Some(Self::new(
68 hugr,
69 node,
70 hugr.get_optype(node).try_into().ok()?,
71 ))
72 }
73
74 pub fn generalise(self) -> FatNode<'hugr, OpType, H, H::Node> {
76 FatNode {
78 hugr: self.hugr,
79 node: self.node,
80 marker: PhantomData,
81 }
82 }
83}
84
85impl<'hugr, OT, H, N: HugrNode> FatNode<'hugr, OT, H, N> {
86 pub fn node(&self) -> N {
88 self.node
89 }
90
91 pub fn hugr(&self) -> &'hugr H {
93 self.hugr
94 }
95}
96
97impl<'hugr, H: HugrView + ?Sized> FatNode<'hugr, OpType, H, H::Node> {
98 pub fn new_optype(hugr: &'hugr H, node: H::Node) -> Self {
102 assert!(hugr.contains_node(node));
103 FatNode::new(hugr, node, hugr.get_optype(node))
104 }
105
106 pub fn try_into_ot<OT>(&self) -> Option<FatNode<'hugr, OT, H, H::Node>>
108 where
109 for<'a> &'a OpType: TryInto<&'a OT>,
110 {
111 FatNode::try_new(self.hugr, self.node)
112 }
113
114 pub fn into_ot<OT>(self, ot: &OT) -> FatNode<'hugr, OT, H, H::Node>
122 where
123 for<'a> &'a OpType: TryInto<&'a OT>,
124 {
125 FatNode::new(self.hugr, self.node, ot)
126 }
127}
128
129impl<'hugr, OT, H: HugrView + ?Sized> FatNode<'hugr, OT, H, H::Node> {
130 #[allow(clippy::type_complexity)]
133 pub fn single_linked_output(
134 &self,
135 port: IncomingPort,
136 ) -> Option<(FatNode<'hugr, OpType, H, H::Node>, OutgoingPort)> {
137 self.hugr
138 .single_linked_output(self.node, port)
139 .map(|(n, p)| (FatNode::new_optype(self.hugr, n), p))
140 }
141
142 pub fn out_value_types(
145 &self,
146 ) -> impl Iterator<Item = (OutgoingPort, Type)> + 'hugr + use<'hugr, OT, H> {
147 self.hugr.out_value_types(self.node)
148 }
149
150 pub fn in_value_types(
153 &self,
154 ) -> impl Iterator<Item = (IncomingPort, Type)> + 'hugr + use<'hugr, OT, H> {
155 self.hugr.in_value_types(self.node)
156 }
157
158 pub fn children(
160 &self,
161 ) -> impl Iterator<Item = FatNode<'hugr, OpType, H, H::Node>> + 'hugr + use<'hugr, OT, H> {
162 self.hugr
163 .children(self.node)
164 .map(|n| FatNode::new_optype(self.hugr, n))
165 }
166
167 #[allow(clippy::type_complexity)]
170 pub fn get_io(
171 &self,
172 ) -> Option<(
173 FatNode<'hugr, Input, H, H::Node>,
174 FatNode<'hugr, Output, H, H::Node>,
175 )> {
176 let [i, o] = self.hugr.get_io(self.node)?;
177 Some((
178 FatNode::try_new(self.hugr, i)?,
179 FatNode::try_new(self.hugr, o)?,
180 ))
181 }
182
183 pub fn node_outputs(&self) -> impl Iterator<Item = OutgoingPort> + 'hugr + use<'hugr, OT, H> {
185 self.hugr.node_outputs(self.node)
186 }
187
188 pub fn output_neighbours(
190 &self,
191 ) -> impl Iterator<Item = FatNode<'hugr, OpType, H, H::Node>> + 'hugr + use<'hugr, OT, H> {
192 self.hugr
193 .output_neighbours(self.node)
194 .map(|n| FatNode::new_optype(self.hugr, n))
195 }
196
197 pub fn as_entrypoint(&self) -> Rerooted<&H>
199 where
200 H: Sized,
201 {
202 self.hugr.with_entrypoint(self.node)
203 }
204}
205
206impl<'hugr, H: HugrView> FatNode<'hugr, CFG, H, H::Node> {
207 #[allow(clippy::type_complexity)]
212 pub fn get_entry_exit(
213 &self,
214 ) -> (
215 FatNode<'hugr, DataflowBlock, H, H::Node>,
216 FatNode<'hugr, ExitBlock, H, H::Node>,
217 ) {
218 let [i, o] = self
219 .hugr
220 .children(self.node)
221 .take(2)
222 .collect_vec()
223 .try_into()
224 .unwrap();
225 (
226 FatNode::try_new(self.hugr, i).unwrap(),
227 FatNode::try_new(self.hugr, o).unwrap(),
228 )
229 }
230}
231
232impl<OT, H> PartialEq<Node> for FatNode<'_, OT, H, Node> {
233 fn eq(&self, other: &Node) -> bool {
234 &self.node == other
235 }
236}
237
238impl<OT, H> PartialEq<FatNode<'_, OT, H, Node>> for Node {
239 fn eq(&self, other: &FatNode<'_, OT, H, Node>) -> bool {
240 self == &other.node
241 }
242}
243
244impl<N: PartialEq, OT1, OT2, H1, H2> PartialEq<FatNode<'_, OT1, H1, N>>
245 for FatNode<'_, OT2, H2, N>
246{
247 fn eq(&self, other: &FatNode<'_, OT1, H1, N>) -> bool {
248 self.node == other.node
249 }
250}
251
252impl<N: Eq, OT, H> Eq for FatNode<'_, OT, H, N> {}
253
254impl<OT, H> PartialOrd<Node> for FatNode<'_, OT, H> {
255 fn partial_cmp(&self, other: &Node) -> Option<Ordering> {
256 self.node.partial_cmp(other)
257 }
258}
259
260impl<OT, H> PartialOrd<FatNode<'_, OT, H>> for Node {
261 fn partial_cmp(&self, other: &FatNode<'_, OT, H>) -> Option<Ordering> {
262 self.partial_cmp(&other.node)
263 }
264}
265
266impl<N: PartialOrd, OT1, OT2, H1, H2> PartialOrd<FatNode<'_, OT1, H1, N>>
267 for FatNode<'_, OT2, H2, N>
268{
269 fn partial_cmp(&self, other: &FatNode<'_, OT1, H1, N>) -> Option<Ordering> {
270 self.node.partial_cmp(&other.node)
271 }
272}
273
274impl<OT, H, N: Ord> Ord for FatNode<'_, OT, H, N> {
275 fn cmp(&self, other: &Self) -> Ordering {
276 self.node.cmp(&other.node)
277 }
278}
279
280impl<OT, H, N: Hash> Hash for FatNode<'_, OT, H, N> {
281 fn hash<HA: std::hash::Hasher>(&self, state: &mut HA) {
282 self.node.hash(state);
283 }
284}
285
286impl<OT, H: HugrView + ?Sized> AsRef<OT> for FatNode<'_, OT, H, H::Node>
287where
288 for<'a> &'a OpType: TryInto<&'a OT>,
289{
290 fn as_ref(&self) -> &OT {
291 self.hugr.get_optype(self.node).try_into().ok().unwrap()
292 }
293}
294
295impl<OT, H: HugrView + ?Sized> Deref for FatNode<'_, OT, H, H::Node>
296where
297 for<'a> &'a OpType: TryInto<&'a OT>,
298{
299 type Target = OT;
300
301 fn deref(&self) -> &Self::Target {
302 self.as_ref()
303 }
304}
305
306impl<OT, H> Copy for FatNode<'_, OT, H> {}
307
308impl<OT, H> Clone for FatNode<'_, OT, H> {
309 fn clone(&self) -> Self {
310 *self
311 }
312}
313
314impl<OT: fmt::Display, H: HugrView + ?Sized> fmt::Display for FatNode<'_, OT, H, H::Node>
315where
316 for<'a> &'a OpType: TryInto<&'a OT>,
317{
318 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
319 f.write_fmt(format_args!("N<{}:{}>", self.as_ref(), self.node))
320 }
321}
322
323impl<OT, H, N: NodeIndex> NodeIndex for FatNode<'_, OT, H, N> {
324 fn index(self) -> usize {
325 self.node.index()
326 }
327}
328
329impl<OT, H> NodeIndex for &FatNode<'_, OT, H> {
330 fn index(self) -> usize {
331 self.node.index()
332 }
333}
334
335pub trait FatExt: HugrView {
341 fn try_fat<OT>(&self, node: Self::Node) -> Option<FatNode<OT, Self, Self::Node>>
343 where
344 for<'a> &'a OpType: TryInto<&'a OT>,
345 {
346 FatNode::try_new(self, node)
347 }
348
349 fn fat_optype(&self, node: Self::Node) -> FatNode<OpType, Self, Self::Node> {
351 FatNode::new_optype(self, node)
352 }
353
354 #[allow(clippy::type_complexity)]
357 fn fat_io(
358 &self,
359 node: Self::Node,
360 ) -> Option<(
361 FatNode<Input, Self, Self::Node>,
362 FatNode<Output, Self, Self::Node>,
363 )> {
364 self.fat_optype(node).get_io()
365 }
366
367 fn fat_children(
369 &self,
370 node: Self::Node,
371 ) -> impl Iterator<Item = FatNode<OpType, Self, Self::Node>> {
372 self.children(node).map(|x| self.fat_optype(x))
373 }
374
375 fn fat_root(&self) -> Option<FatNode<Module, Self, Self::Node>> {
377 self.try_fat(self.module_root())
378 }
379
380 fn fat_entrypoint<OT>(&self) -> Option<FatNode<OT, Self, Self::Node>>
382 where
383 for<'a> &'a OpType: TryInto<&'a OT>,
384 {
385 self.try_fat(self.entrypoint())
386 }
387}
388
389impl<H: HugrView> FatExt for H {}