Skip to main content

hugr_core/
module_graph.rs

//! Data structure summarizing the module-level static nodes (functions and constants) of a Hugr, and their uses.
2use std::collections::BTreeMap;
3
4use crate::{HugrView, Node, core::HugrNode, ops::OpType};
5use petgraph::{Graph, visit::EdgeRef};
6
7/// Weight for an edge in a [`ModuleGraph`]
8#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
9#[non_exhaustive]
10pub enum StaticEdge<N = Node> {
11    /// Edge corresponds to a [Call](OpType::Call) node (specified) in the Hugr
12    Call(N),
13    /// Edge corresponds to a [`LoadFunction`](OpType::LoadFunction) node (specified) in the Hugr
14    LoadFunction(N),
15    /// Edge corresponds to a [LoadConstant](OpType::LoadConstant) node (specified) in the Hugr
16    LoadConstant(N),
17}
18
19/// Weight for a petgraph-node in a [`ModuleGraph`]
20#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
21#[non_exhaustive]
22pub enum StaticNode<N = Node> {
23    /// petgraph-node corresponds to a [`FuncDecl`](OpType::FuncDecl) node (specified) in the Hugr
24    FuncDecl(N),
25    /// petgraph-node corresponds to a [`FuncDefn`](OpType::FuncDefn) node (specified) in the Hugr
26    FuncDefn(N),
27    /// petgraph-node corresponds to the [HugrView::entrypoint], that is not
28    /// a [`FuncDefn`](OpType::FuncDefn). Note that it will not be a [Module](OpType::Module)
29    /// either, as such a node could not have edges, so is not represented in the petgraph.
30    NonFuncEntrypoint,
31    /// petgraph-node corresponds to a constant; will have no outgoing edges, and incoming
32    /// edges will be [StaticEdge::LoadConstant]
33    Const(N),
34}
35
36/// Details the [`FuncDefn`]s, [`FuncDecl`]s and module-level [`Const`]s in a Hugr,
37/// in a Hugr, along with the [`Call`]s, [`LoadFunction`]s, and [`LoadConstant`]s connecting them.
38///
39/// Each node in the `ModuleGraph` corresponds to a module-level function or const;
40/// each edge corresponds to a use of the target contained in the edge's source.
41///
42/// For Hugrs whose entrypoint is neither a [Module](OpType::Module) nor a [`FuncDefn`],
43/// the static graph will have an additional [`StaticNode::NonFuncEntrypoint`]
44/// corresponding to the Hugr's entrypoint, with no incoming edges.
45///
46/// [`Call`]: OpType::Call
47/// [`Const`]: OpType::Const
48/// [`FuncDecl`]: OpType::FuncDecl
49/// [`FuncDefn`]: OpType::FuncDefn
50/// [`LoadConstant`]: OpType::LoadConstant
51/// [`LoadFunction`]: OpType::LoadFunction
52pub struct ModuleGraph<N = Node> {
53    g: Graph<StaticNode<N>, StaticEdge<N>>,
54    node_to_g: BTreeMap<N, petgraph::graph::NodeIndex<u32>>,
55}
56
57impl<N: HugrNode> ModuleGraph<N> {
58    /// Makes a new `ModuleGraph` for a Hugr.
59    pub fn new(hugr: &impl HugrView<Node = N>) -> Self {
60        let mut g = Graph::default();
61        let mut node_to_g = hugr
62            .children(hugr.module_root())
63            .filter_map(|n| {
64                let weight = match hugr.get_optype(n) {
65                    OpType::FuncDecl(_) => StaticNode::FuncDecl(n),
66                    OpType::FuncDefn(_) => StaticNode::FuncDefn(n),
67                    OpType::Const(_) => StaticNode::Const(n),
68                    _ => return None,
69                };
70                Some((n, g.add_node(weight)))
71            })
72            .collect::<BTreeMap<_, _>>();
73        if !hugr.entrypoint_optype().is_module() && !node_to_g.contains_key(&hugr.entrypoint()) {
74            node_to_g.insert(hugr.entrypoint(), g.add_node(StaticNode::NonFuncEntrypoint));
75        }
76        for (func, cg_node) in &node_to_g {
77            for n in hugr.descendants(*func) {
78                let weight = match hugr.get_optype(n) {
79                    OpType::Call(_) => StaticEdge::Call(n),
80                    OpType::LoadFunction(_) => StaticEdge::LoadFunction(n),
81                    OpType::LoadConstant(_) => StaticEdge::LoadConstant(n),
82                    _ => continue,
83                };
84                if let Some(target) = hugr.static_source(n) {
85                    if hugr.get_parent(target) == Some(hugr.module_root()) {
86                        g.add_edge(*cg_node, node_to_g[&target], weight);
87                    } else {
88                        // Local constant (only global constants are in the graph)
89                        assert!(!node_to_g.contains_key(&target));
90                        assert!(hugr.get_optype(n).is_load_constant());
91                        assert!(hugr.get_optype(target).is_const());
92                    }
93                }
94            }
95        }
96        ModuleGraph { g, node_to_g }
97    }
98
99    /// Allows access to the petgraph
100    #[must_use]
101    pub fn graph(&self) -> &Graph<StaticNode<N>, StaticEdge<N>> {
102        &self.g
103    }
104
105    /// Convert a Hugr [Node] into a petgraph node index.
106    /// Result will be `None` if `n` is not a [`FuncDefn`](OpType::FuncDefn),
107    /// [`FuncDecl`](OpType::FuncDecl) or the [HugrView::entrypoint].
108    pub fn node_index(&self, n: N) -> Option<petgraph::graph::NodeIndex<u32>> {
109        self.node_to_g.get(&n).copied()
110    }
111
112    /// Returns an iterator over the out-edges from the given Node, i.e.
113    /// edges to the functions/constants called/loaded by it.
114    ///
115    /// If the node is not recognised as a function or the entrypoint,
116    /// for example if it is a [`Const`](OpType::Const), the iterator will be empty.
117    pub fn out_edges(&self, n: N) -> impl Iterator<Item = (&StaticEdge<N>, &StaticNode<N>)> {
118        let g = self.graph();
119        self.node_index(n).into_iter().flat_map(move |n| {
120            self.graph().edges(n).map(|e| {
121                (
122                    g.edge_weight(e.id()).unwrap(),
123                    g.node_weight(e.target()).unwrap(),
124                )
125            })
126        })
127    }
128
129    /// Returns an iterator over the in-edges to the given Node, i.e.
130    /// edges from the (necessarily) functions that call/load it.
131    ///
132    /// If the node is not recognised as a function or constant,
133    /// for example if it is a non-function entrypoint, the iterator will be empty.
134    pub fn in_edges(&self, n: N) -> impl Iterator<Item = (&StaticNode<N>, &StaticEdge<N>)> {
135        let g = self.graph();
136        self.node_index(n).into_iter().flat_map(move |n| {
137            self.graph()
138                .edges_directed(n, petgraph::Direction::Incoming)
139                .map(|e| {
140                    (
141                        g.node_weight(e.source()).unwrap(),
142                        g.edge_weight(e.id()).unwrap(),
143                    )
144                })
145        })
146    }
147}
148
#[cfg(test)]
mod test {
    use itertools::Itertools as _;
    use rstest::rstest;
    use std::collections::HashMap;

    use crate::builder::{
        Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder,
        ModuleBuilder, endo_sig, inout_sig,
    };
    use crate::extension::prelude::{ConstUsize, usize_t};
    use crate::hugr::hugrmut::HugrMut;
    use crate::ops::{Value, handle::NodeHandle};

    use super::*;

    /// A module holding a constant, a `callee`, and a `caller` that loads the
    /// constant and calls `callee`; checks the edges seen from every node.
    #[test]
    fn edges() {
        let mut module = ModuleBuilder::new();
        let forty_two = module.add_constant(Value::from(ConstUsize::new(42)));

        let callee_builder = module.define_function("callee", endo_sig([usize_t()])).unwrap();
        let passthrough = callee_builder.input_wires();
        let callee = callee_builder.finish_with_outputs(passthrough).unwrap();

        let mut caller_builder = module
            .define_function("caller", inout_sig(vec![], [usize_t()]))
            .unwrap();
        // Node-creation order (load before call) is kept as-is: the expected order of
        // `caller`'s out-edges below depends on it.
        let loaded = caller_builder.load_const(&forty_two);
        let call = caller_builder
            .call(callee.handle(), &[], vec![loaded])
            .unwrap();
        let caller = caller_builder.finish_with_outputs(call.outputs()).unwrap();
        let hugr = module.finish_hugr().unwrap();

        let graph = ModuleGraph::new(&hugr);
        let call_edge = StaticEdge::Call(call.node());
        let load_edge = StaticEdge::LoadConstant(loaded.node());
        let caller_node = StaticNode::FuncDefn(caller.node());
        let callee_node = StaticNode::FuncDefn(callee.node());
        let const_node = StaticNode::Const(forty_two.node());

        // `callee` uses nothing and is used only by `caller`.
        assert!(graph.out_edges(callee.node()).next().is_none());
        assert_eq!(
            graph.in_edges(callee.node()).collect_vec(),
            [(&caller_node, &call_edge)]
        );

        // `caller` uses both `callee` and the constant, and is itself unused.
        assert_eq!(
            graph.out_edges(caller.node()).collect_vec(),
            [(&call_edge, &callee_node), (&load_edge, &const_node)]
        );
        assert!(graph.in_edges(caller.node()).next().is_none());

        // The constant uses nothing and is loaded only by `caller`.
        assert!(graph.out_edges(forty_two.node()).next().is_none());
        assert_eq!(
            graph.in_edges(forty_two.node()).collect_vec(),
            [(&caller_node, &load_edge)]
        );
    }

    /// A DFG entrypoint (optionally narrowed to its single `Call` node) calling a
    /// declared function: the call is attributed both to the non-function
    /// entrypoint and to the enclosing `main`.
    #[rstest]
    fn entrypoint(#[values(true, false)] single_node: bool) {
        let mut dfg = DFGBuilder::new(endo_sig([usize_t()])).unwrap();
        let decl = dfg
            .module_root_builder()
            .declare("called", endo_sig([usize_t()]).into())
            .unwrap();
        let inputs = dfg.input_wires();
        let call = dfg.call(&decl, &[], inputs).unwrap();
        let mut hugr = dfg.finish_hugr_with_outputs(call.outputs()).unwrap();
        // The DFG entrypoint's parent is the module-level function wrapping it.
        let main = hugr.get_parent(hugr.entrypoint()).unwrap();
        if single_node {
            hugr.set_entrypoint(call.node());
        }

        let graph = ModuleGraph::new(&hugr);
        let call_edge = StaticEdge::Call(call.node());
        let main_node = StaticNode::FuncDefn(main);
        let decl_node = StaticNode::FuncDecl(decl.node());

        let mut callers: HashMap<_, _> = graph.in_edges(decl.node()).collect();
        assert_eq!(
            callers.remove(&StaticNode::NonFuncEntrypoint),
            Some(&call_edge)
        );
        assert_eq!(callers, HashMap::from([(&main_node, &call_edge)]));

        for node in [hugr.entrypoint(), main] {
            assert_eq!(
                graph.out_edges(node).collect_vec(),
                [(&call_edge, &decl_node)]
            );
        }
    }
}