hugr_llvm/utils/
fat.rs

1//! We define a type [`FatNode`], named for analogy with a "fat pointer".
2//!
3//! We define a trait [`FatExt`], an extension trait for [`HugrView`]. It provides
4//! methods that return [`FatNode`]s rather than [Node]s.
5use std::{cmp::Ordering, fmt, hash::Hash, marker::PhantomData, ops::Deref};
6
7use hugr_core::hugr::views::Rerooted;
8use hugr_core::{
9    Hugr, HugrView, IncomingPort, Node, NodeIndex, OutgoingPort,
10    core::HugrNode,
11    ops::{CFG, DataflowBlock, ExitBlock, Input, Module, OpType, Output},
12    types::Type,
13};
14use itertools::Itertools as _;
15
16/// A Fat Node is a [Node] along with a reference to the [`HugrView`] whence it
17/// originates.
18///
19/// It carries a type `OT`, the [`OpType`] of that node. `OT` may be
20/// general, i.e. exactly [`OpType`], or specifec, e.g. [`FuncDefn`].
21///
22/// We provide a [Deref<Target=OT>] impl, so it can be used in place of `OT`.
23///
24/// We provide [`PartialEq`], [Eq], [`PartialOrd`], [Ord], [Hash], so that this type
25/// can be used in containers. Note that these implementations use only the
26/// stored node, so will silently malfunction if used with [`FatNode`]s from
27/// different base [Hugr]s. Note that [Node] has this same behaviour.
28///
29/// [`FuncDefn`]: [hugr_core::ops::FuncDefn]
30#[derive(Debug)]
31pub struct FatNode<'hugr, OT = OpType, H = Hugr, N = Node>
32where
33    H: ?Sized,
34{
35    hugr: &'hugr H,
36    node: N,
37    marker: PhantomData<OT>,
38}
39
40impl<'hugr, OT, H: HugrView + ?Sized> FatNode<'hugr, OT, H, H::Node>
41where
42    for<'a> &'a OpType: TryInto<&'a OT>,
43{
44    /// Create a `FatNode` from a [`HugrView`] and a [Node].
45    ///
46    /// Panics if the node is not valid in the `Hugr` or if it's `get_optype` is
47    /// not an `OT`.
48    ///
49    /// Note that while we do check the type of the node's `get_optype`, we
50    /// do not verify that it is actually equal to `ot`.
51    pub fn new(hugr: &'hugr H, node: H::Node, #[allow(unused)] ot: &OT) -> Self {
52        assert!(hugr.contains_node(node));
53        assert!(TryInto::<&OT>::try_into(hugr.get_optype(node)).is_ok());
54        // We don't actually check `ot == hugr.get_optype(node)` so as to not require OT: PartialEq`
55        Self {
56            hugr,
57            node,
58            marker: PhantomData,
59        }
60    }
61
62    /// Tries to create a `FatNode` from a [`HugrView`] and a node (typically a
63    /// [Node]).
64    ///
65    /// If the node is invalid, or if its `get_optype` is not `OT`, returns
66    /// `None`.
67    pub fn try_new(hugr: &'hugr H, node: H::Node) -> Option<Self> {
68        (hugr.contains_node(node)).then_some(())?;
69        Some(Self::new(
70            hugr,
71            node,
72            hugr.get_optype(node).try_into().ok()?,
73        ))
74    }
75
76    /// Create a general `FatNode` from a specific one.
77    pub fn generalise(self) -> FatNode<'hugr, OpType, H, H::Node> {
78        // guaranteed to be valid because self is valid
79        FatNode {
80            hugr: self.hugr,
81            node: self.node,
82            marker: PhantomData,
83        }
84    }
85}
86
87impl<'hugr, OT, H, N: HugrNode> FatNode<'hugr, OT, H, N> {
88    /// Gets the [Node] of the `FatNode`.
89    pub fn node(&self) -> N {
90        self.node
91    }
92
93    /// Gets the [`HugrView`] of the `FatNode`.
94    pub fn hugr(&self) -> &'hugr H {
95        self.hugr
96    }
97}
98
99impl<'hugr, H: HugrView + ?Sized> FatNode<'hugr, OpType, H, H::Node> {
100    /// Creates a new general `FatNode` from a [`HugrView`] and a [Node].
101    ///
102    /// Panics if the node is not valid in the [Hugr].
103    pub fn new_optype(hugr: &'hugr H, node: H::Node) -> Self {
104        assert!(hugr.contains_node(node));
105        FatNode::new(hugr, node, hugr.get_optype(node))
106    }
107
108    /// Tries to downcast a general `FatNode` into a specific `OT`.
109    pub fn try_into_ot<OT>(&self) -> Option<FatNode<'hugr, OT, H, H::Node>>
110    where
111        for<'a> &'a OpType: TryInto<&'a OT>,
112    {
113        FatNode::try_new(self.hugr, self.node)
114    }
115
116    /// Creates a specific `FatNode` from a general `FatNode`.
117    ///
118    /// Panics if the node is not valid in the `Hugr` or if its `get_optype` is
119    /// not an `OT`.
120    ///
121    /// Note that while we do check the type of the node's `get_optype`, we
122    /// do not verify that it is actually equal to `ot`.
123    pub fn into_ot<OT>(self, ot: &OT) -> FatNode<'hugr, OT, H, H::Node>
124    where
125        for<'a> &'a OpType: TryInto<&'a OT>,
126    {
127        FatNode::new(self.hugr, self.node, ot)
128    }
129}
130
131impl<'hugr, OT, H: HugrView + ?Sized> FatNode<'hugr, OT, H, H::Node> {
132    /// If there is exactly one `OutgoingPort` connected to this `IncomingPort`,
133    /// return it and its node.
134    #[allow(clippy::type_complexity)]
135    pub fn single_linked_output(
136        &self,
137        port: IncomingPort,
138    ) -> Option<(FatNode<'hugr, OpType, H, H::Node>, OutgoingPort)> {
139        self.hugr
140            .single_linked_output(self.node, port)
141            .map(|(n, p)| (FatNode::new_optype(self.hugr, n), p))
142    }
143
144    /// Iterator over all incoming ports that have Value type, along
145    /// with corresponding types.
146    pub fn out_value_types(
147        &self,
148    ) -> impl Iterator<Item = (OutgoingPort, Type)> + 'hugr + use<'hugr, OT, H> {
149        self.hugr.out_value_types(self.node)
150    }
151
152    /// Iterator over all incoming ports that have Value type, along
153    /// with corresponding types.
154    pub fn in_value_types(
155        &self,
156    ) -> impl Iterator<Item = (IncomingPort, Type)> + 'hugr + use<'hugr, OT, H> {
157        self.hugr.in_value_types(self.node)
158    }
159
160    /// Return iterator over the direct children of node.
161    pub fn children(
162        &self,
163    ) -> impl Iterator<Item = FatNode<'hugr, OpType, H, H::Node>> + 'hugr + use<'hugr, OT, H> {
164        self.hugr
165            .children(self.node)
166            .map(|n| FatNode::new_optype(self.hugr, n))
167    }
168
169    /// Get the input and output child nodes of a dataflow parent.
170    /// If the node isn't a dataflow parent, then return None
171    #[allow(clippy::type_complexity)]
172    pub fn get_io(
173        &self,
174    ) -> Option<(
175        FatNode<'hugr, Input, H, H::Node>,
176        FatNode<'hugr, Output, H, H::Node>,
177    )> {
178        let [i, o] = self.hugr.get_io(self.node)?;
179        Some((
180            FatNode::try_new(self.hugr, i)?,
181            FatNode::try_new(self.hugr, o)?,
182        ))
183    }
184
185    /// Iterator over output ports of node.
186    pub fn node_outputs(&self) -> impl Iterator<Item = OutgoingPort> + 'hugr + use<'hugr, OT, H> {
187        self.hugr.node_outputs(self.node)
188    }
189
190    /// Iterates over the output neighbours of the `node`.
191    pub fn output_neighbours(
192        &self,
193    ) -> impl Iterator<Item = FatNode<'hugr, OpType, H, H::Node>> + 'hugr + use<'hugr, OT, H> {
194        self.hugr
195            .output_neighbours(self.node)
196            .map(|n| FatNode::new_optype(self.hugr, n))
197    }
198
199    /// Returns a view of the internal [`HugrView`] with this [Node] as entrypoint.
200    pub fn as_entrypoint(&self) -> Rerooted<&H>
201    where
202        H: Sized,
203    {
204        self.hugr.with_entrypoint(self.node)
205    }
206}
207
208impl<'hugr, H: HugrView> FatNode<'hugr, CFG, H, H::Node> {
209    /// Returns the entry and exit nodes of a CFG.
210    ///
211    /// These are guaranteed to exist the `Hugr` is valid. Panics if they do not
212    /// exist.
213    #[allow(clippy::type_complexity)]
214    pub fn get_entry_exit(
215        &self,
216    ) -> (
217        FatNode<'hugr, DataflowBlock, H, H::Node>,
218        FatNode<'hugr, ExitBlock, H, H::Node>,
219    ) {
220        let [i, o] = self
221            .hugr
222            .children(self.node)
223            .take(2)
224            .collect_vec()
225            .try_into()
226            .unwrap();
227        (
228            FatNode::try_new(self.hugr, i).unwrap(),
229            FatNode::try_new(self.hugr, o).unwrap(),
230        )
231    }
232}
233
234impl<OT, H> PartialEq<Node> for FatNode<'_, OT, H, Node> {
235    fn eq(&self, other: &Node) -> bool {
236        &self.node == other
237    }
238}
239
240impl<OT, H> PartialEq<FatNode<'_, OT, H, Node>> for Node {
241    fn eq(&self, other: &FatNode<'_, OT, H, Node>) -> bool {
242        self == &other.node
243    }
244}
245
246impl<N: PartialEq, OT1, OT2, H1, H2> PartialEq<FatNode<'_, OT1, H1, N>>
247    for FatNode<'_, OT2, H2, N>
248{
249    fn eq(&self, other: &FatNode<'_, OT1, H1, N>) -> bool {
250        self.node == other.node
251    }
252}
253
254impl<N: Eq, OT, H> Eq for FatNode<'_, OT, H, N> {}
255
256impl<OT, H> PartialOrd<Node> for FatNode<'_, OT, H> {
257    fn partial_cmp(&self, other: &Node) -> Option<Ordering> {
258        self.node.partial_cmp(other)
259    }
260}
261
262impl<OT, H> PartialOrd<FatNode<'_, OT, H>> for Node {
263    fn partial_cmp(&self, other: &FatNode<'_, OT, H>) -> Option<Ordering> {
264        self.partial_cmp(&other.node)
265    }
266}
267
268impl<N: PartialOrd, OT1, OT2, H1, H2> PartialOrd<FatNode<'_, OT1, H1, N>>
269    for FatNode<'_, OT2, H2, N>
270{
271    fn partial_cmp(&self, other: &FatNode<'_, OT1, H1, N>) -> Option<Ordering> {
272        self.node.partial_cmp(&other.node)
273    }
274}
275
276impl<OT, H, N: Ord> Ord for FatNode<'_, OT, H, N> {
277    fn cmp(&self, other: &Self) -> Ordering {
278        self.node.cmp(&other.node)
279    }
280}
281
282impl<OT, H, N: Hash> Hash for FatNode<'_, OT, H, N> {
283    fn hash<HA: std::hash::Hasher>(&self, state: &mut HA) {
284        self.node.hash(state);
285    }
286}
287
288impl<OT, H: HugrView + ?Sized> AsRef<OT> for FatNode<'_, OT, H, H::Node>
289where
290    for<'a> &'a OpType: TryInto<&'a OT>,
291{
292    fn as_ref(&self) -> &OT {
293        self.hugr.get_optype(self.node).try_into().ok().unwrap()
294    }
295}
296
297impl<OT, H: HugrView + ?Sized> Deref for FatNode<'_, OT, H, H::Node>
298where
299    for<'a> &'a OpType: TryInto<&'a OT>,
300{
301    type Target = OT;
302
303    fn deref(&self) -> &Self::Target {
304        self.as_ref()
305    }
306}
307
308impl<OT, H> Copy for FatNode<'_, OT, H> {}
309
310impl<OT, H> Clone for FatNode<'_, OT, H> {
311    fn clone(&self) -> Self {
312        *self
313    }
314}
315
316impl<OT: fmt::Display, H: HugrView + ?Sized> fmt::Display for FatNode<'_, OT, H, H::Node>
317where
318    for<'a> &'a OpType: TryInto<&'a OT>,
319{
320    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
321        f.write_fmt(format_args!("N<{}:{}>", self.as_ref(), self.node))
322    }
323}
324
325impl<OT, H, N: NodeIndex> NodeIndex for FatNode<'_, OT, H, N> {
326    fn index(self) -> usize {
327        self.node.index()
328    }
329}
330
331impl<OT, H> NodeIndex for &FatNode<'_, OT, H> {
332    fn index(self) -> usize {
333        self.node.index()
334    }
335}
336
337/// An extension trait for [`HugrView`] which provides methods that delegate to
338/// [`HugrView`] and then return the result in [`FatNode`] form. See for example
339/// [`FatExt::fat_io`].
340///
341/// TODO: Add the remaining [`HugrView`] equivalents that make sense.
342pub trait FatExt: HugrView {
343    /// Try to create a specific [`FatNode`] for a given [Node].
344    fn try_fat<OT>(&self, node: Self::Node) -> Option<FatNode<'_, OT, Self, Self::Node>>
345    where
346        for<'a> &'a OpType: TryInto<&'a OT>,
347    {
348        FatNode::try_new(self, node)
349    }
350
351    /// Create a general [`FatNode`] for a given [Node].
352    fn fat_optype(&self, node: Self::Node) -> FatNode<'_, OpType, Self, Self::Node> {
353        FatNode::new_optype(self, node)
354    }
355
356    /// Try to create [Input] and [Output] [`FatNode`]s for a given [Node]. This
357    /// will succeed only for `DataFlow` Parent Nodes.
358    #[allow(clippy::type_complexity)]
359    fn fat_io(
360        &self,
361        node: Self::Node,
362    ) -> Option<(
363        FatNode<'_, Input, Self, Self::Node>,
364        FatNode<'_, Output, Self, Self::Node>,
365    )> {
366        self.fat_optype(node).get_io()
367    }
368
369    /// Create general [`FatNode`]s for each of a [Node]'s children.
370    fn fat_children(
371        &self,
372        node: Self::Node,
373    ) -> impl Iterator<Item = FatNode<'_, OpType, Self, Self::Node>> {
374        self.children(node).map(|x| self.fat_optype(x))
375    }
376
377    /// Try to create a specific [`FatNode`] for the root of a [`HugrView`].
378    fn fat_root(&self) -> Option<FatNode<'_, Module, Self, Self::Node>> {
379        self.try_fat(self.module_root())
380    }
381
382    /// Try to create a specific [`FatNode`] for the entrypoint of a [`HugrView`].
383    fn fat_entrypoint<OT>(&self) -> Option<FatNode<'_, OT, Self, Self::Node>>
384    where
385        for<'a> &'a OpType: TryInto<&'a OT>,
386    {
387        self.try_fat(self.entrypoint())
388    }
389}
390
391impl<H: HugrView> FatExt for H {}