1use std::{cmp::Ordering, fmt, hash::Hash, marker::PhantomData, ops::Deref};
6
7use hugr_core::hugr::views::Rerooted;
8use hugr_core::{
9 Hugr, HugrView, IncomingPort, Node, NodeIndex, OutgoingPort,
10 core::HugrNode,
11 ops::{CFG, DataflowBlock, ExitBlock, Input, Module, OpType, Output},
12 types::Type,
13};
14use itertools::Itertools as _;
15
16#[derive(Debug)]
31pub struct FatNode<'hugr, OT = OpType, H = Hugr, N = Node>
32where
33 H: ?Sized,
34{
35 hugr: &'hugr H,
36 node: N,
37 marker: PhantomData<OT>,
38}
39
40impl<'hugr, OT, H: HugrView + ?Sized> FatNode<'hugr, OT, H, H::Node>
41where
42 for<'a> &'a OpType: TryInto<&'a OT>,
43{
44 pub fn new(hugr: &'hugr H, node: H::Node, #[allow(unused)] ot: &OT) -> Self {
52 assert!(hugr.contains_node(node));
53 assert!(TryInto::<&OT>::try_into(hugr.get_optype(node)).is_ok());
54 Self {
56 hugr,
57 node,
58 marker: PhantomData,
59 }
60 }
61
62 pub fn try_new(hugr: &'hugr H, node: H::Node) -> Option<Self> {
68 (hugr.contains_node(node)).then_some(())?;
69 Some(Self::new(
70 hugr,
71 node,
72 hugr.get_optype(node).try_into().ok()?,
73 ))
74 }
75
76 pub fn generalise(self) -> FatNode<'hugr, OpType, H, H::Node> {
78 FatNode {
80 hugr: self.hugr,
81 node: self.node,
82 marker: PhantomData,
83 }
84 }
85}
86
87impl<'hugr, OT, H, N: HugrNode> FatNode<'hugr, OT, H, N> {
88 pub fn node(&self) -> N {
90 self.node
91 }
92
93 pub fn hugr(&self) -> &'hugr H {
95 self.hugr
96 }
97}
98
99impl<'hugr, H: HugrView + ?Sized> FatNode<'hugr, OpType, H, H::Node> {
100 pub fn new_optype(hugr: &'hugr H, node: H::Node) -> Self {
104 assert!(hugr.contains_node(node));
105 FatNode::new(hugr, node, hugr.get_optype(node))
106 }
107
108 pub fn try_into_ot<OT>(&self) -> Option<FatNode<'hugr, OT, H, H::Node>>
110 where
111 for<'a> &'a OpType: TryInto<&'a OT>,
112 {
113 FatNode::try_new(self.hugr, self.node)
114 }
115
116 pub fn into_ot<OT>(self, ot: &OT) -> FatNode<'hugr, OT, H, H::Node>
124 where
125 for<'a> &'a OpType: TryInto<&'a OT>,
126 {
127 FatNode::new(self.hugr, self.node, ot)
128 }
129}
130
131impl<'hugr, OT, H: HugrView + ?Sized> FatNode<'hugr, OT, H, H::Node> {
132 #[allow(clippy::type_complexity)]
135 pub fn single_linked_output(
136 &self,
137 port: IncomingPort,
138 ) -> Option<(FatNode<'hugr, OpType, H, H::Node>, OutgoingPort)> {
139 self.hugr
140 .single_linked_output(self.node, port)
141 .map(|(n, p)| (FatNode::new_optype(self.hugr, n), p))
142 }
143
144 pub fn out_value_types(
147 &self,
148 ) -> impl Iterator<Item = (OutgoingPort, Type)> + 'hugr + use<'hugr, OT, H> {
149 self.hugr.out_value_types(self.node)
150 }
151
152 pub fn in_value_types(
155 &self,
156 ) -> impl Iterator<Item = (IncomingPort, Type)> + 'hugr + use<'hugr, OT, H> {
157 self.hugr.in_value_types(self.node)
158 }
159
160 pub fn children(
162 &self,
163 ) -> impl Iterator<Item = FatNode<'hugr, OpType, H, H::Node>> + 'hugr + use<'hugr, OT, H> {
164 self.hugr
165 .children(self.node)
166 .map(|n| FatNode::new_optype(self.hugr, n))
167 }
168
169 #[allow(clippy::type_complexity)]
172 pub fn get_io(
173 &self,
174 ) -> Option<(
175 FatNode<'hugr, Input, H, H::Node>,
176 FatNode<'hugr, Output, H, H::Node>,
177 )> {
178 let [i, o] = self.hugr.get_io(self.node)?;
179 Some((
180 FatNode::try_new(self.hugr, i)?,
181 FatNode::try_new(self.hugr, o)?,
182 ))
183 }
184
185 pub fn node_outputs(&self) -> impl Iterator<Item = OutgoingPort> + 'hugr + use<'hugr, OT, H> {
187 self.hugr.node_outputs(self.node)
188 }
189
190 pub fn output_neighbours(
192 &self,
193 ) -> impl Iterator<Item = FatNode<'hugr, OpType, H, H::Node>> + 'hugr + use<'hugr, OT, H> {
194 self.hugr
195 .output_neighbours(self.node)
196 .map(|n| FatNode::new_optype(self.hugr, n))
197 }
198
199 pub fn as_entrypoint(&self) -> Rerooted<&H>
201 where
202 H: Sized,
203 {
204 self.hugr.with_entrypoint(self.node)
205 }
206}
207
208impl<'hugr, H: HugrView> FatNode<'hugr, CFG, H, H::Node> {
209 #[allow(clippy::type_complexity)]
214 pub fn get_entry_exit(
215 &self,
216 ) -> (
217 FatNode<'hugr, DataflowBlock, H, H::Node>,
218 FatNode<'hugr, ExitBlock, H, H::Node>,
219 ) {
220 let [i, o] = self
221 .hugr
222 .children(self.node)
223 .take(2)
224 .collect_vec()
225 .try_into()
226 .unwrap();
227 (
228 FatNode::try_new(self.hugr, i).unwrap(),
229 FatNode::try_new(self.hugr, o).unwrap(),
230 )
231 }
232}
233
234impl<OT, H> PartialEq<Node> for FatNode<'_, OT, H, Node> {
235 fn eq(&self, other: &Node) -> bool {
236 &self.node == other
237 }
238}
239
240impl<OT, H> PartialEq<FatNode<'_, OT, H, Node>> for Node {
241 fn eq(&self, other: &FatNode<'_, OT, H, Node>) -> bool {
242 self == &other.node
243 }
244}
245
246impl<N: PartialEq, OT1, OT2, H1, H2> PartialEq<FatNode<'_, OT1, H1, N>>
247 for FatNode<'_, OT2, H2, N>
248{
249 fn eq(&self, other: &FatNode<'_, OT1, H1, N>) -> bool {
250 self.node == other.node
251 }
252}
253
254impl<N: Eq, OT, H> Eq for FatNode<'_, OT, H, N> {}
255
256impl<OT, H> PartialOrd<Node> for FatNode<'_, OT, H> {
257 fn partial_cmp(&self, other: &Node) -> Option<Ordering> {
258 self.node.partial_cmp(other)
259 }
260}
261
262impl<OT, H> PartialOrd<FatNode<'_, OT, H>> for Node {
263 fn partial_cmp(&self, other: &FatNode<'_, OT, H>) -> Option<Ordering> {
264 self.partial_cmp(&other.node)
265 }
266}
267
268impl<N: PartialOrd, OT1, OT2, H1, H2> PartialOrd<FatNode<'_, OT1, H1, N>>
269 for FatNode<'_, OT2, H2, N>
270{
271 fn partial_cmp(&self, other: &FatNode<'_, OT1, H1, N>) -> Option<Ordering> {
272 self.node.partial_cmp(&other.node)
273 }
274}
275
276impl<OT, H, N: Ord> Ord for FatNode<'_, OT, H, N> {
277 fn cmp(&self, other: &Self) -> Ordering {
278 self.node.cmp(&other.node)
279 }
280}
281
282impl<OT, H, N: Hash> Hash for FatNode<'_, OT, H, N> {
283 fn hash<HA: std::hash::Hasher>(&self, state: &mut HA) {
284 self.node.hash(state);
285 }
286}
287
288impl<OT, H: HugrView + ?Sized> AsRef<OT> for FatNode<'_, OT, H, H::Node>
289where
290 for<'a> &'a OpType: TryInto<&'a OT>,
291{
292 fn as_ref(&self) -> &OT {
293 self.hugr.get_optype(self.node).try_into().ok().unwrap()
294 }
295}
296
297impl<OT, H: HugrView + ?Sized> Deref for FatNode<'_, OT, H, H::Node>
298where
299 for<'a> &'a OpType: TryInto<&'a OT>,
300{
301 type Target = OT;
302
303 fn deref(&self) -> &Self::Target {
304 self.as_ref()
305 }
306}
307
308impl<OT, H> Copy for FatNode<'_, OT, H> {}
309
310impl<OT, H> Clone for FatNode<'_, OT, H> {
311 fn clone(&self) -> Self {
312 *self
313 }
314}
315
316impl<OT: fmt::Display, H: HugrView + ?Sized> fmt::Display for FatNode<'_, OT, H, H::Node>
317where
318 for<'a> &'a OpType: TryInto<&'a OT>,
319{
320 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
321 f.write_fmt(format_args!("N<{}:{}>", self.as_ref(), self.node))
322 }
323}
324
325impl<OT, H, N: NodeIndex> NodeIndex for FatNode<'_, OT, H, N> {
326 fn index(self) -> usize {
327 self.node.index()
328 }
329}
330
331impl<OT, H> NodeIndex for &FatNode<'_, OT, H> {
332 fn index(self) -> usize {
333 self.node.index()
334 }
335}
336
337pub trait FatExt: HugrView {
343 fn try_fat<OT>(&self, node: Self::Node) -> Option<FatNode<'_, OT, Self, Self::Node>>
345 where
346 for<'a> &'a OpType: TryInto<&'a OT>,
347 {
348 FatNode::try_new(self, node)
349 }
350
351 fn fat_optype(&self, node: Self::Node) -> FatNode<'_, OpType, Self, Self::Node> {
353 FatNode::new_optype(self, node)
354 }
355
356 #[allow(clippy::type_complexity)]
359 fn fat_io(
360 &self,
361 node: Self::Node,
362 ) -> Option<(
363 FatNode<'_, Input, Self, Self::Node>,
364 FatNode<'_, Output, Self, Self::Node>,
365 )> {
366 self.fat_optype(node).get_io()
367 }
368
369 fn fat_children(
371 &self,
372 node: Self::Node,
373 ) -> impl Iterator<Item = FatNode<'_, OpType, Self, Self::Node>> {
374 self.children(node).map(|x| self.fat_optype(x))
375 }
376
377 fn fat_root(&self) -> Option<FatNode<'_, Module, Self, Self::Node>> {
379 self.try_fat(self.module_root())
380 }
381
382 fn fat_entrypoint<OT>(&self) -> Option<FatNode<'_, OT, Self, Self::Node>>
384 where
385 for<'a> &'a OpType: TryInto<&'a OT>,
386 {
387 self.try_fat(self.entrypoint())
388 }
389}
390
391impl<H: HugrView> FatExt for H {}