1use std::collections::BTreeMap;
3
4use crate::{HugrView, Node, core::HugrNode, ops::OpType};
5use petgraph::{Graph, visit::EdgeRef};
6
/// Weight of an edge in a [`ModuleGraph`].
///
/// Each variant carries the Hugr node of the operation that creates the static dependency;
/// the edge runs from the graph node containing that operation to the operation's static
/// source.
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[non_exhaustive]
pub enum StaticEdge<N = Node> {
    /// A [`Call`](OpType::Call) node.
    Call(N),
    /// A [`LoadFunction`](OpType::LoadFunction) node.
    LoadFunction(N),
    /// A [`LoadConstant`](OpType::LoadConstant) node.
    LoadConstant(N),
}
18
/// Weight of a node in a [`ModuleGraph`].
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[non_exhaustive]
pub enum StaticNode<N = Node> {
    /// A [`FuncDecl`](OpType::FuncDecl) that is a child of the module root.
    FuncDecl(N),
    /// A [`FuncDefn`](OpType::FuncDefn) that is a child of the module root.
    FuncDefn(N),
    /// The Hugr's entrypoint, when it is neither the module root nor one of the
    /// module-level nodes represented by the other variants. The node itself is not stored
    /// here; it is the one returned by `HugrView::entrypoint`.
    NonFuncEntrypoint,
    /// A [`Const`](OpType::Const) that is a child of the module root.
    Const(N),
}
35
36pub struct ModuleGraph<N = Node> {
53 g: Graph<StaticNode<N>, StaticEdge<N>>,
54 node_to_g: BTreeMap<N, petgraph::graph::NodeIndex<u32>>,
55}
56
57impl<N: HugrNode> ModuleGraph<N> {
58 pub fn new(hugr: &impl HugrView<Node = N>) -> Self {
60 let mut g = Graph::default();
61 let mut node_to_g = hugr
62 .children(hugr.module_root())
63 .filter_map(|n| {
64 let weight = match hugr.get_optype(n) {
65 OpType::FuncDecl(_) => StaticNode::FuncDecl(n),
66 OpType::FuncDefn(_) => StaticNode::FuncDefn(n),
67 OpType::Const(_) => StaticNode::Const(n),
68 _ => return None,
69 };
70 Some((n, g.add_node(weight)))
71 })
72 .collect::<BTreeMap<_, _>>();
73 if !hugr.entrypoint_optype().is_module() && !node_to_g.contains_key(&hugr.entrypoint()) {
74 node_to_g.insert(hugr.entrypoint(), g.add_node(StaticNode::NonFuncEntrypoint));
75 }
76 for (func, cg_node) in &node_to_g {
77 for n in hugr.descendants(*func) {
78 let weight = match hugr.get_optype(n) {
79 OpType::Call(_) => StaticEdge::Call(n),
80 OpType::LoadFunction(_) => StaticEdge::LoadFunction(n),
81 OpType::LoadConstant(_) => StaticEdge::LoadConstant(n),
82 _ => continue,
83 };
84 if let Some(target) = hugr.static_source(n) {
85 if hugr.get_parent(target) == Some(hugr.module_root()) {
86 g.add_edge(*cg_node, node_to_g[&target], weight);
87 } else {
88 assert!(!node_to_g.contains_key(&target));
90 assert!(hugr.get_optype(n).is_load_constant());
91 assert!(hugr.get_optype(target).is_const());
92 }
93 }
94 }
95 }
96 ModuleGraph { g, node_to_g }
97 }
98
99 #[must_use]
101 pub fn graph(&self) -> &Graph<StaticNode<N>, StaticEdge<N>> {
102 &self.g
103 }
104
105 pub fn node_index(&self, n: N) -> Option<petgraph::graph::NodeIndex<u32>> {
109 self.node_to_g.get(&n).copied()
110 }
111
112 pub fn out_edges(&self, n: N) -> impl Iterator<Item = (&StaticEdge<N>, &StaticNode<N>)> {
118 let g = self.graph();
119 self.node_index(n).into_iter().flat_map(move |n| {
120 self.graph().edges(n).map(|e| {
121 (
122 g.edge_weight(e.id()).unwrap(),
123 g.node_weight(e.target()).unwrap(),
124 )
125 })
126 })
127 }
128
129 pub fn in_edges(&self, n: N) -> impl Iterator<Item = (&StaticNode<N>, &StaticEdge<N>)> {
135 let g = self.graph();
136 self.node_index(n).into_iter().flat_map(move |n| {
137 self.graph()
138 .edges_directed(n, petgraph::Direction::Incoming)
139 .map(|e| {
140 (
141 g.node_weight(e.source()).unwrap(),
142 g.edge_weight(e.id()).unwrap(),
143 )
144 })
145 })
146 }
147}
148
#[cfg(test)]
mod test {
    use itertools::Itertools as _;
    use rstest::rstest;
    use std::collections::HashMap;

    use crate::builder::{
        Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder,
        ModuleBuilder, endo_sig, inout_sig,
    };
    use crate::extension::prelude::{ConstUsize, usize_t};
    use crate::hugr::hugrmut::HugrMut;
    use crate::ops::{Value, handle::NodeHandle};

    use super::*;

    /// One module-level constant, a `callee`, and a `caller` that loads the constant and
    /// calls `callee`: checks both directions of every static edge.
    #[test]
    fn edges() {
        let mut module = ModuleBuilder::new();
        let konst = module.add_constant(Value::from(ConstUsize::new(42)));

        let callee_builder = module
            .define_function("callee", endo_sig([usize_t()]))
            .unwrap();
        let passthrough = callee_builder.input_wires();
        let callee = callee_builder.finish_with_outputs(passthrough).unwrap();

        let mut caller_builder = module
            .define_function("caller", inout_sig(vec![], [usize_t()]))
            .unwrap();
        let loaded = caller_builder.load_const(&konst);
        let call = caller_builder
            .call(callee.handle(), &[], vec![loaded])
            .unwrap();
        let caller = caller_builder.finish_with_outputs(call.outputs()).unwrap();
        let hugr = module.finish_hugr().unwrap();

        let graph = ModuleGraph::new(&hugr);
        let call_edge = StaticEdge::Call(call.node());
        let load_edge = StaticEdge::LoadConstant(loaded.node());
        let caller_weight = StaticNode::FuncDefn(caller.node());
        let callee_weight = StaticNode::FuncDefn(callee.node());
        let konst_weight = StaticNode::Const(konst.node());

        // `callee` depends on nothing and is called only by `caller`.
        assert!(graph.out_edges(callee.node()).next().is_none());
        assert_eq!(
            graph.in_edges(callee.node()).collect_vec(),
            [(&caller_weight, &call_edge)]
        );

        // `caller` calls `callee` and loads the constant; nothing refers to it.
        assert_eq!(
            graph.out_edges(caller.node()).collect_vec(),
            [(&call_edge, &callee_weight), (&load_edge, &konst_weight)]
        );
        assert!(graph.in_edges(caller.node()).next().is_none());

        // The constant depends on nothing and is loaded only by `caller`.
        assert!(graph.out_edges(konst.node()).next().is_none());
        assert_eq!(
            graph.in_edges(konst.node()).collect_vec(),
            [(&caller_weight, &load_edge)]
        );
    }

    /// A call inside a non-module entrypoint is reported from the entrypoint and from the
    /// enclosing function, whether the entrypoint is the DFG or the call node itself.
    #[rstest]
    fn entrypoint(#[values(true, false)] single_node: bool) {
        let mut dfg = DFGBuilder::new(endo_sig([usize_t()])).unwrap();
        let decl = dfg
            .module_root_builder()
            .declare("called", endo_sig([usize_t()]).into())
            .unwrap();
        let inputs = dfg.input_wires();
        let call = dfg.call(&decl, &[], inputs).unwrap();
        let mut hugr = dfg.finish_hugr_with_outputs(call.outputs()).unwrap();
        let main = hugr.get_parent(hugr.entrypoint()).unwrap();
        if single_node {
            hugr.set_entrypoint(call.node());
        }

        let graph = ModuleGraph::new(&hugr);
        let call_edge = StaticEdge::Call(call.node());
        let main_weight = StaticNode::FuncDefn(main);
        let decl_weight = StaticNode::FuncDecl(decl.node());

        let mut callers: HashMap<_, _> = graph.in_edges(decl.node()).collect();
        assert_eq!(
            callers.remove(&StaticNode::NonFuncEntrypoint),
            Some(&call_edge)
        );
        assert_eq!(
            callers,
            HashMap::from_iter([(&main_weight, &call_edge)])
        );

        let expected_out = vec![(&call_edge, &decl_weight)];
        for source in [hugr.entrypoint(), main] {
            assert_eq!(graph.out_edges(source).collect_vec(), expected_out);
        }
    }
}