hugr_llvm/utils/
fat.rs

1//! We define a type [FatNode], named for analogy with a "fat pointer".
2//!
3//! We define a trait [FatExt], an extension trait for [HugrView]. It provides
4//! methods that return [FatNode]s rather than [Node]s.
5use std::{cmp::Ordering, hash::Hash, marker::PhantomData, ops::Deref};
6
7use hugr_core::{
8    core::HugrNode,
9    hugr::{views::HierarchyView, HugrError},
10    ops::{DataflowBlock, ExitBlock, Input, NamedOp, OpType, Output, CFG},
11    types::Type,
12    Hugr, HugrView, IncomingPort, Node, NodeIndex, OutgoingPort,
13};
14use itertools::Itertools as _;
15
16/// A Fat Node is a [Node] along with a reference to the [HugrView] whence it
17/// originates. It carries a type `OT`, the [OpType] of that node. `OT` may be
18/// general, i.e. exactly [OpType], or specifec, e.g. [FuncDefn].
19///
20/// We provide a [Deref<Target=OT>] impl, so it can be used in place of `OT`.
21///
22/// We provide [PartialEq], [Eq], [PartialOrd], [Ord], [Hash], so that this type
23/// can be used in containers. Note that these implementations use only the
24/// stored node, so will silently malfunction if used with [FatNode]s from
25/// different base [Hugr]s. Note that [Node] has this same behaviour.
26///
27/// [FuncDefn]: [hugr_core::ops::FuncDefn]
28#[derive(Debug)]
29pub struct FatNode<'hugr, OT = OpType, H = Hugr, N = Node>
30where
31    H: ?Sized,
32{
33    hugr: &'hugr H,
34    node: N,
35    marker: PhantomData<OT>,
36}
37
38impl<'hugr, OT, H: HugrView + ?Sized> FatNode<'hugr, OT, H, H::Node>
39where
40    for<'a> &'a OpType: TryInto<&'a OT>,
41{
42    /// Create a `FatNode` from a [HugrView] and a [Node].
43    ///
44    /// Panics if the node is not valid in the `Hugr` or if it's `get_optype` is
45    /// not an `OT`.
46    ///
47    /// Note that while we do check the type of the node's `get_optype`, we
48    /// do not verify that it is actually equal to `ot`.
49    pub fn new(hugr: &'hugr H, node: H::Node, #[allow(unused)] ot: &OT) -> Self {
50        assert!(hugr.valid_node(node));
51        assert!(TryInto::<&OT>::try_into(hugr.get_optype(node)).is_ok());
52        // We don't actually check `ot == hugr.get_optype(node)` so as to not require OT: PartialEq`
53        Self {
54            hugr,
55            node,
56            marker: PhantomData,
57        }
58    }
59
60    /// Tries to create a `FatNode` from a [HugrView] and a node (typically a
61    /// [Node]).
62    ///
63    /// If the node is invalid, or if its `get_optype` is not `OT`, returns
64    /// `None`.
65    pub fn try_new(hugr: &'hugr H, node: H::Node) -> Option<Self> {
66        (hugr.valid_node(node)).then_some(())?;
67        Some(Self::new(
68            hugr,
69            node,
70            hugr.get_optype(node).try_into().ok()?,
71        ))
72    }
73
74    /// Create a general `FatNode` from a specific one.
75    pub fn generalise(self) -> FatNode<'hugr, OpType, H, H::Node> {
76        // guaranteed to be valid because self is valid
77        FatNode {
78            hugr: self.hugr,
79            node: self.node,
80            marker: PhantomData,
81        }
82    }
83}
84
85impl<'hugr, OT, H, N: HugrNode> FatNode<'hugr, OT, H, N> {
86    /// Gets the [Node] of the `FatNode`.
87    pub fn node(&self) -> N {
88        self.node
89    }
90
91    /// Gets the [HugrView] of the `FatNode`.
92    pub fn hugr(&self) -> &'hugr H {
93        self.hugr
94    }
95}
96
97impl<'hugr, H: HugrView + ?Sized> FatNode<'hugr, OpType, H, H::Node> {
98    /// Creates a new general `FatNode` from a [HugrView] and a [Node].
99    ///
100    /// Panics if the node is not valid in the [Hugr].
101    pub fn new_optype(hugr: &'hugr H, node: H::Node) -> Self {
102        assert!(hugr.valid_node(node));
103        FatNode::new(hugr, node, hugr.get_optype(node))
104    }
105
106    /// Tries to downcast a general `FatNode` into a specific `OT`.
107    pub fn try_into_ot<OT>(&self) -> Option<FatNode<'hugr, OT, H, H::Node>>
108    where
109        for<'a> &'a OpType: TryInto<&'a OT>,
110    {
111        FatNode::try_new(self.hugr, self.node)
112    }
113
114    /// Creates a specific `FatNode` from a general `FatNode`.
115    ///
116    /// Panics if the node is not valid in the `Hugr` or if its `get_optype` is
117    /// not an `OT`.
118    ///
119    /// Note that while we do check the type of the node's `get_optype`, we
120    /// do not verify that it is actually equal to `ot`.
121    pub fn into_ot<OT>(self, ot: &OT) -> FatNode<'hugr, OT, H, H::Node>
122    where
123        for<'a> &'a OpType: TryInto<&'a OT>,
124    {
125        FatNode::new(self.hugr, self.node, ot)
126    }
127}
128
129impl<'hugr, OT, H: HugrView + ?Sized> FatNode<'hugr, OT, H, H::Node> {
130    /// If there is exactly one OutgoingPort connected to this IncomingPort,
131    /// return it and its node.
132    #[allow(clippy::type_complexity)]
133    pub fn single_linked_output(
134        &self,
135        port: IncomingPort,
136    ) -> Option<(FatNode<'hugr, OpType, H, H::Node>, OutgoingPort)> {
137        self.hugr
138            .single_linked_output(self.node, port)
139            .map(|(n, p)| (FatNode::new_optype(self.hugr, n), p))
140    }
141
142    /// Iterator over all incoming ports that have Value type, along
143    /// with corresponding types.
144    pub fn out_value_types(&self) -> impl Iterator<Item = (OutgoingPort, Type)> + 'hugr {
145        self.hugr.out_value_types(self.node)
146    }
147
148    /// Iterator over all incoming ports that have Value type, along
149    /// with corresponding types.
150    pub fn in_value_types(&self) -> impl Iterator<Item = (IncomingPort, Type)> + 'hugr {
151        self.hugr.in_value_types(self.node)
152    }
153
154    /// Return iterator over the direct children of node.
155    pub fn children(&self) -> impl Iterator<Item = FatNode<'hugr, OpType, H, H::Node>> + 'hugr {
156        self.hugr
157            .children(self.node)
158            .map(|n| FatNode::new_optype(self.hugr, n))
159    }
160
161    /// Get the input and output child nodes of a dataflow parent.
162    /// If the node isn't a dataflow parent, then return None
163    #[allow(clippy::type_complexity)]
164    pub fn get_io(
165        &self,
166    ) -> Option<(
167        FatNode<'hugr, Input, H, H::Node>,
168        FatNode<'hugr, Output, H, H::Node>,
169    )> {
170        let [i, o] = self.hugr.get_io(self.node)?;
171        Some((
172            FatNode::try_new(self.hugr, i)?,
173            FatNode::try_new(self.hugr, o)?,
174        ))
175    }
176
177    /// Iterator over output ports of node.
178    pub fn node_outputs(&self) -> impl Iterator<Item = OutgoingPort> + 'hugr {
179        self.hugr.node_outputs(self.node)
180    }
181
182    /// Iterates over the output neighbours of the `node`.
183    pub fn output_neighbours(
184        &self,
185    ) -> impl Iterator<Item = FatNode<'hugr, OpType, H, H::Node>> + 'hugr {
186        self.hugr
187            .output_neighbours(self.node)
188            .map(|n| FatNode::new_optype(self.hugr, n))
189    }
190
191    /// Delegates to `HV::try_new` with the internal [HugrView] and [Node].
192    pub fn try_new_hierarchy_view<HV: HierarchyView<'hugr, Node = H::Node>>(
193        &self,
194    ) -> Result<HV, HugrError>
195    where
196        H: Sized,
197    {
198        HV::try_new(self.hugr, self.node)
199    }
200}
201
202impl<'hugr, H: HugrView> FatNode<'hugr, CFG, H, H::Node> {
203    /// Returns the entry and exit nodes of a CFG.
204    ///
205    /// These are guaranteed to exist the `Hugr` is valid. Panics if they do not
206    /// exist.
207    #[allow(clippy::type_complexity)]
208    pub fn get_entry_exit(
209        &self,
210    ) -> (
211        FatNode<'hugr, DataflowBlock, H, H::Node>,
212        FatNode<'hugr, ExitBlock, H, H::Node>,
213    ) {
214        let [i, o] = self
215            .hugr
216            .children(self.node)
217            .take(2)
218            .collect_vec()
219            .try_into()
220            .unwrap();
221        (
222            FatNode::try_new(self.hugr, i).unwrap(),
223            FatNode::try_new(self.hugr, o).unwrap(),
224        )
225    }
226}
227
228impl<OT, H> PartialEq<Node> for FatNode<'_, OT, H, Node> {
229    fn eq(&self, other: &Node) -> bool {
230        &self.node == other
231    }
232}
233
234impl<OT, H> PartialEq<FatNode<'_, OT, H, Node>> for Node {
235    fn eq(&self, other: &FatNode<'_, OT, H, Node>) -> bool {
236        self == &other.node
237    }
238}
239
240impl<N: PartialEq, OT1, OT2, H1, H2> PartialEq<FatNode<'_, OT1, H1, N>>
241    for FatNode<'_, OT2, H2, N>
242{
243    fn eq(&self, other: &FatNode<'_, OT1, H1, N>) -> bool {
244        self.node == other.node
245    }
246}
247
248impl<N: Eq, OT, H> Eq for FatNode<'_, OT, H, N> {}
249
250impl<OT, H> PartialOrd<Node> for FatNode<'_, OT, H> {
251    fn partial_cmp(&self, other: &Node) -> Option<Ordering> {
252        self.node.partial_cmp(other)
253    }
254}
255
256impl<OT, H> PartialOrd<FatNode<'_, OT, H>> for Node {
257    fn partial_cmp(&self, other: &FatNode<'_, OT, H>) -> Option<Ordering> {
258        self.partial_cmp(&other.node)
259    }
260}
261
262impl<N: PartialOrd, OT1, OT2, H1, H2> PartialOrd<FatNode<'_, OT1, H1, N>>
263    for FatNode<'_, OT2, H2, N>
264{
265    fn partial_cmp(&self, other: &FatNode<'_, OT1, H1, N>) -> Option<Ordering> {
266        self.node.partial_cmp(&other.node)
267    }
268}
269
270impl<OT, H, N: Ord> Ord for FatNode<'_, OT, H, N> {
271    fn cmp(&self, other: &Self) -> Ordering {
272        self.node.cmp(&other.node)
273    }
274}
275
276impl<OT, H, N: Hash> Hash for FatNode<'_, OT, H, N> {
277    fn hash<HA: std::hash::Hasher>(&self, state: &mut HA) {
278        self.node.hash(state);
279    }
280}
281
282impl<OT, H: HugrView + ?Sized> AsRef<OT> for FatNode<'_, OT, H, H::Node>
283where
284    for<'a> &'a OpType: TryInto<&'a OT>,
285{
286    fn as_ref(&self) -> &OT {
287        self.hugr.get_optype(self.node).try_into().ok().unwrap()
288    }
289}
290
291impl<OT, H: HugrView + ?Sized> Deref for FatNode<'_, OT, H, H::Node>
292where
293    for<'a> &'a OpType: TryInto<&'a OT>,
294{
295    type Target = OT;
296
297    fn deref(&self) -> &Self::Target {
298        self.as_ref()
299    }
300}
301
302impl<OT, H> Copy for FatNode<'_, OT, H> {}
303
304impl<OT, H> Clone for FatNode<'_, OT, H> {
305    fn clone(&self) -> Self {
306        *self
307    }
308}
309
310impl<OT: NamedOp, H: HugrView + ?Sized> std::fmt::Display for FatNode<'_, OT, H, H::Node>
311where
312    for<'a> &'a OpType: TryInto<&'a OT>,
313{
314    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315        f.write_fmt(format_args!("N<{}:{}>", self.as_ref().name(), self.node))
316    }
317}
318
319impl<OT, H, N: NodeIndex> NodeIndex for FatNode<'_, OT, H, N> {
320    fn index(self) -> usize {
321        self.node.index()
322    }
323}
324
325impl<OT, H> NodeIndex for &FatNode<'_, OT, H> {
326    fn index(self) -> usize {
327        self.node.index()
328    }
329}
330
331/// An extension trait for [HugrView] which provides methods that delegate to
332/// [HugrView] and then return the result in [FatNode] form. See for example
333/// [FatExt::fat_io].
334///
335/// TODO: Add the remaining [HugrView] equivalents that make sense.
336pub trait FatExt: HugrView {
337    /// Try to create a specific [FatNode] for a given [Node].
338    fn try_fat<OT>(&self, node: Self::Node) -> Option<FatNode<OT, Self, Self::Node>>
339    where
340        for<'a> &'a OpType: TryInto<&'a OT>,
341    {
342        FatNode::try_new(self, node)
343    }
344
345    /// Create a general [FatNode] for a given [Node].
346    fn fat_optype(&self, node: Self::Node) -> FatNode<OpType, Self, Self::Node> {
347        FatNode::new_optype(self, node)
348    }
349
350    /// Try to create [Input] and [Output] [FatNode]s for a given [Node]. This
351    /// will succeed only for DataFlow Parent Nodes.
352    #[allow(clippy::type_complexity)]
353    fn fat_io(
354        &self,
355        node: Self::Node,
356    ) -> Option<(
357        FatNode<Input, Self, Self::Node>,
358        FatNode<Output, Self, Self::Node>,
359    )> {
360        self.fat_optype(node).get_io()
361    }
362
363    /// Create general [FatNode]s for each of a [Node]'s children.
364    fn fat_children(
365        &self,
366        node: Self::Node,
367    ) -> impl Iterator<Item = FatNode<OpType, Self, Self::Node>> {
368        self.children(node).map(|x| self.fat_optype(x))
369    }
370
371    /// Try to create a specific [FatNode] for the root of a [HugrView].
372    fn fat_root<OT>(&self) -> Option<FatNode<OT, Self, Self::Node>>
373    where
374        for<'a> &'a OpType: TryInto<&'a OT>,
375    {
376        self.try_fat(self.root())
377    }
378}
379
380impl<H: HugrView> FatExt for H {}