hugr_llvm/utils/
fat.rs

//! We define a type [`FatNode`], named by analogy with a "fat pointer".
2//!
3//! We define a trait [`FatExt`], an extension trait for [`HugrView`]. It provides
4//! methods that return [`FatNode`]s rather than [Node]s.
5use std::{cmp::Ordering, fmt, hash::Hash, marker::PhantomData, ops::Deref};
6
7use hugr_core::hugr::views::Rerooted;
8use hugr_core::{
9    Hugr, HugrView, IncomingPort, Node, NodeIndex, OutgoingPort,
10    core::HugrNode,
11    ops::{CFG, DataflowBlock, ExitBlock, Input, Module, OpType, Output},
12    types::Type,
13};
14use itertools::Itertools as _;
15
/// A Fat Node is a [Node] along with a reference to the [`HugrView`] whence it
/// originates. It carries a type `OT`, the [`OpType`] of that node. `OT` may be
/// general, i.e. exactly [`OpType`], or specific, e.g. [`FuncDefn`].
///
/// We provide a [`Deref<Target = OT>`](Deref) impl, so it can be used in place
/// of `OT`.
///
/// We provide [`PartialEq`], [Eq], [`PartialOrd`], [Ord], [Hash], so that this type
/// can be used in containers. Note that these implementations use only the
/// stored node, so will silently malfunction if used with [`FatNode`]s from
/// different base [Hugr]s. Note that [Node] has this same behaviour.
///
/// [`FuncDefn`]: hugr_core::ops::FuncDefn
#[derive(Debug)]
pub struct FatNode<'hugr, OT = OpType, H = Hugr, N = Node>
where
    H: ?Sized,
{
    // The view the node belongs to; borrowed for the whole `'hugr` lifetime.
    hugr: &'hugr H,
    // The node identifier within `hugr`.
    node: N,
    // Zero-sized marker recording `OT` at the type level; no `OT` value is
    // stored in the `FatNode` itself.
    marker: PhantomData<OT>,
}
37
38impl<'hugr, OT, H: HugrView + ?Sized> FatNode<'hugr, OT, H, H::Node>
39where
40    for<'a> &'a OpType: TryInto<&'a OT>,
41{
42    /// Create a `FatNode` from a [`HugrView`] and a [Node].
43    ///
44    /// Panics if the node is not valid in the `Hugr` or if it's `get_optype` is
45    /// not an `OT`.
46    ///
47    /// Note that while we do check the type of the node's `get_optype`, we
48    /// do not verify that it is actually equal to `ot`.
49    pub fn new(hugr: &'hugr H, node: H::Node, #[allow(unused)] ot: &OT) -> Self {
50        assert!(hugr.contains_node(node));
51        assert!(TryInto::<&OT>::try_into(hugr.get_optype(node)).is_ok());
52        // We don't actually check `ot == hugr.get_optype(node)` so as to not require OT: PartialEq`
53        Self {
54            hugr,
55            node,
56            marker: PhantomData,
57        }
58    }
59
60    /// Tries to create a `FatNode` from a [`HugrView`] and a node (typically a
61    /// [Node]).
62    ///
63    /// If the node is invalid, or if its `get_optype` is not `OT`, returns
64    /// `None`.
65    pub fn try_new(hugr: &'hugr H, node: H::Node) -> Option<Self> {
66        (hugr.contains_node(node)).then_some(())?;
67        Some(Self::new(
68            hugr,
69            node,
70            hugr.get_optype(node).try_into().ok()?,
71        ))
72    }
73
74    /// Create a general `FatNode` from a specific one.
75    pub fn generalise(self) -> FatNode<'hugr, OpType, H, H::Node> {
76        // guaranteed to be valid because self is valid
77        FatNode {
78            hugr: self.hugr,
79            node: self.node,
80            marker: PhantomData,
81        }
82    }
83}
84
85impl<'hugr, OT, H, N: HugrNode> FatNode<'hugr, OT, H, N> {
86    /// Gets the [Node] of the `FatNode`.
87    pub fn node(&self) -> N {
88        self.node
89    }
90
91    /// Gets the [`HugrView`] of the `FatNode`.
92    pub fn hugr(&self) -> &'hugr H {
93        self.hugr
94    }
95}
96
97impl<'hugr, H: HugrView + ?Sized> FatNode<'hugr, OpType, H, H::Node> {
98    /// Creates a new general `FatNode` from a [`HugrView`] and a [Node].
99    ///
100    /// Panics if the node is not valid in the [Hugr].
101    pub fn new_optype(hugr: &'hugr H, node: H::Node) -> Self {
102        assert!(hugr.contains_node(node));
103        FatNode::new(hugr, node, hugr.get_optype(node))
104    }
105
106    /// Tries to downcast a general `FatNode` into a specific `OT`.
107    pub fn try_into_ot<OT>(&self) -> Option<FatNode<'hugr, OT, H, H::Node>>
108    where
109        for<'a> &'a OpType: TryInto<&'a OT>,
110    {
111        FatNode::try_new(self.hugr, self.node)
112    }
113
114    /// Creates a specific `FatNode` from a general `FatNode`.
115    ///
116    /// Panics if the node is not valid in the `Hugr` or if its `get_optype` is
117    /// not an `OT`.
118    ///
119    /// Note that while we do check the type of the node's `get_optype`, we
120    /// do not verify that it is actually equal to `ot`.
121    pub fn into_ot<OT>(self, ot: &OT) -> FatNode<'hugr, OT, H, H::Node>
122    where
123        for<'a> &'a OpType: TryInto<&'a OT>,
124    {
125        FatNode::new(self.hugr, self.node, ot)
126    }
127}
128
// Convenience wrappers that forward to the underlying [`HugrView`] for this
// node and, where nodes are returned, re-wrap them as `FatNode`s.
impl<'hugr, OT, H: HugrView + ?Sized> FatNode<'hugr, OT, H, H::Node> {
    /// If there is exactly one `OutgoingPort` connected to this `IncomingPort`,
    /// return it and its node.
    #[allow(clippy::type_complexity)]
    pub fn single_linked_output(
        &self,
        port: IncomingPort,
    ) -> Option<(FatNode<'hugr, OpType, H, H::Node>, OutgoingPort)> {
        self.hugr
            .single_linked_output(self.node, port)
            .map(|(n, p)| (FatNode::new_optype(self.hugr, n), p))
    }

    /// Iterator over all outgoing ports that have Value type, along
    /// with corresponding types.
    pub fn out_value_types(
        &self,
    ) -> impl Iterator<Item = (OutgoingPort, Type)> + 'hugr + use<'hugr, OT, H> {
        self.hugr.out_value_types(self.node)
    }

    /// Iterator over all incoming ports that have Value type, along
    /// with corresponding types.
    pub fn in_value_types(
        &self,
    ) -> impl Iterator<Item = (IncomingPort, Type)> + 'hugr + use<'hugr, OT, H> {
        self.hugr.in_value_types(self.node)
    }

    /// Return iterator over the direct children of node, each wrapped as a
    /// general `FatNode`.
    pub fn children(
        &self,
    ) -> impl Iterator<Item = FatNode<'hugr, OpType, H, H::Node>> + 'hugr + use<'hugr, OT, H> {
        self.hugr
            .children(self.node)
            .map(|n| FatNode::new_optype(self.hugr, n))
    }

    /// Get the input and output child nodes of a dataflow parent.
    /// If the node isn't a dataflow parent, then return None.
    ///
    /// `None` is also returned if either child cannot be wrapped as an
    /// [`Input`] / [`Output`] `FatNode` respectively.
    #[allow(clippy::type_complexity)]
    pub fn get_io(
        &self,
    ) -> Option<(
        FatNode<'hugr, Input, H, H::Node>,
        FatNode<'hugr, Output, H, H::Node>,
    )> {
        let [i, o] = self.hugr.get_io(self.node)?;
        Some((
            FatNode::try_new(self.hugr, i)?,
            FatNode::try_new(self.hugr, o)?,
        ))
    }

    /// Iterator over output ports of node.
    pub fn node_outputs(&self) -> impl Iterator<Item = OutgoingPort> + 'hugr + use<'hugr, OT, H> {
        self.hugr.node_outputs(self.node)
    }

    /// Iterates over the output neighbours of the `node`, each wrapped as a
    /// general `FatNode`.
    pub fn output_neighbours(
        &self,
    ) -> impl Iterator<Item = FatNode<'hugr, OpType, H, H::Node>> + 'hugr + use<'hugr, OT, H> {
        self.hugr
            .output_neighbours(self.node)
            .map(|n| FatNode::new_optype(self.hugr, n))
    }

    /// Returns a view of the internal [`HugrView`] with this [Node] as entrypoint.
    pub fn as_entrypoint(&self) -> Rerooted<&H>
    where
        H: Sized,
    {
        self.hugr.with_entrypoint(self.node)
    }
}
205
206impl<'hugr, H: HugrView> FatNode<'hugr, CFG, H, H::Node> {
207    /// Returns the entry and exit nodes of a CFG.
208    ///
209    /// These are guaranteed to exist the `Hugr` is valid. Panics if they do not
210    /// exist.
211    #[allow(clippy::type_complexity)]
212    pub fn get_entry_exit(
213        &self,
214    ) -> (
215        FatNode<'hugr, DataflowBlock, H, H::Node>,
216        FatNode<'hugr, ExitBlock, H, H::Node>,
217    ) {
218        let [i, o] = self
219            .hugr
220            .children(self.node)
221            .take(2)
222            .collect_vec()
223            .try_into()
224            .unwrap();
225        (
226            FatNode::try_new(self.hugr, i).unwrap(),
227            FatNode::try_new(self.hugr, o).unwrap(),
228        )
229    }
230}
231
232impl<OT, H> PartialEq<Node> for FatNode<'_, OT, H, Node> {
233    fn eq(&self, other: &Node) -> bool {
234        &self.node == other
235    }
236}
237
238impl<OT, H> PartialEq<FatNode<'_, OT, H, Node>> for Node {
239    fn eq(&self, other: &FatNode<'_, OT, H, Node>) -> bool {
240        self == &other.node
241    }
242}
243
244impl<N: PartialEq, OT1, OT2, H1, H2> PartialEq<FatNode<'_, OT1, H1, N>>
245    for FatNode<'_, OT2, H2, N>
246{
247    fn eq(&self, other: &FatNode<'_, OT1, H1, N>) -> bool {
248        self.node == other.node
249    }
250}
251
252impl<N: Eq, OT, H> Eq for FatNode<'_, OT, H, N> {}
253
254impl<OT, H> PartialOrd<Node> for FatNode<'_, OT, H> {
255    fn partial_cmp(&self, other: &Node) -> Option<Ordering> {
256        self.node.partial_cmp(other)
257    }
258}
259
260impl<OT, H> PartialOrd<FatNode<'_, OT, H>> for Node {
261    fn partial_cmp(&self, other: &FatNode<'_, OT, H>) -> Option<Ordering> {
262        self.partial_cmp(&other.node)
263    }
264}
265
266impl<N: PartialOrd, OT1, OT2, H1, H2> PartialOrd<FatNode<'_, OT1, H1, N>>
267    for FatNode<'_, OT2, H2, N>
268{
269    fn partial_cmp(&self, other: &FatNode<'_, OT1, H1, N>) -> Option<Ordering> {
270        self.node.partial_cmp(&other.node)
271    }
272}
273
274impl<OT, H, N: Ord> Ord for FatNode<'_, OT, H, N> {
275    fn cmp(&self, other: &Self) -> Ordering {
276        self.node.cmp(&other.node)
277    }
278}
279
280impl<OT, H, N: Hash> Hash for FatNode<'_, OT, H, N> {
281    fn hash<HA: std::hash::Hasher>(&self, state: &mut HA) {
282        self.node.hash(state);
283    }
284}
285
286impl<OT, H: HugrView + ?Sized> AsRef<OT> for FatNode<'_, OT, H, H::Node>
287where
288    for<'a> &'a OpType: TryInto<&'a OT>,
289{
290    fn as_ref(&self) -> &OT {
291        self.hugr.get_optype(self.node).try_into().ok().unwrap()
292    }
293}
294
295impl<OT, H: HugrView + ?Sized> Deref for FatNode<'_, OT, H, H::Node>
296where
297    for<'a> &'a OpType: TryInto<&'a OT>,
298{
299    type Target = OT;
300
301    fn deref(&self) -> &Self::Target {
302        self.as_ref()
303    }
304}
305
306impl<OT, H> Copy for FatNode<'_, OT, H> {}
307
308impl<OT, H> Clone for FatNode<'_, OT, H> {
309    fn clone(&self) -> Self {
310        *self
311    }
312}
313
314impl<OT: fmt::Display, H: HugrView + ?Sized> fmt::Display for FatNode<'_, OT, H, H::Node>
315where
316    for<'a> &'a OpType: TryInto<&'a OT>,
317{
318    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
319        f.write_fmt(format_args!("N<{}:{}>", self.as_ref(), self.node))
320    }
321}
322
323impl<OT, H, N: NodeIndex> NodeIndex for FatNode<'_, OT, H, N> {
324    fn index(self) -> usize {
325        self.node.index()
326    }
327}
328
329impl<OT, H> NodeIndex for &FatNode<'_, OT, H> {
330    fn index(self) -> usize {
331        self.node.index()
332    }
333}
334
335/// An extension trait for [`HugrView`] which provides methods that delegate to
336/// [`HugrView`] and then return the result in [`FatNode`] form. See for example
337/// [`FatExt::fat_io`].
338///
339/// TODO: Add the remaining [`HugrView`] equivalents that make sense.
340pub trait FatExt: HugrView {
341    /// Try to create a specific [`FatNode`] for a given [Node].
342    fn try_fat<OT>(&self, node: Self::Node) -> Option<FatNode<OT, Self, Self::Node>>
343    where
344        for<'a> &'a OpType: TryInto<&'a OT>,
345    {
346        FatNode::try_new(self, node)
347    }
348
349    /// Create a general [`FatNode`] for a given [Node].
350    fn fat_optype(&self, node: Self::Node) -> FatNode<OpType, Self, Self::Node> {
351        FatNode::new_optype(self, node)
352    }
353
354    /// Try to create [Input] and [Output] [`FatNode`]s for a given [Node]. This
355    /// will succeed only for `DataFlow` Parent Nodes.
356    #[allow(clippy::type_complexity)]
357    fn fat_io(
358        &self,
359        node: Self::Node,
360    ) -> Option<(
361        FatNode<Input, Self, Self::Node>,
362        FatNode<Output, Self, Self::Node>,
363    )> {
364        self.fat_optype(node).get_io()
365    }
366
367    /// Create general [`FatNode`]s for each of a [Node]'s children.
368    fn fat_children(
369        &self,
370        node: Self::Node,
371    ) -> impl Iterator<Item = FatNode<OpType, Self, Self::Node>> {
372        self.children(node).map(|x| self.fat_optype(x))
373    }
374
375    /// Try to create a specific [`FatNode`] for the root of a [`HugrView`].
376    fn fat_root(&self) -> Option<FatNode<Module, Self, Self::Node>> {
377        self.try_fat(self.module_root())
378    }
379
380    /// Try to create a specific [`FatNode`] for the entrypoint of a [`HugrView`].
381    fn fat_entrypoint<OT>(&self) -> Option<FatNode<OT, Self, Self::Node>>
382    where
383        for<'a> &'a OpType: TryInto<&'a OT>,
384    {
385        self.try_fat(self.entrypoint())
386    }
387}
388
389impl<H: HugrView> FatExt for H {}