1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
//! Data structure summarizing static nodes of a Hugr and their uses
use std::collections::BTreeMap;
use crate::{HugrView, Node, core::HugrNode, ops::OpType};
use petgraph::{Graph, visit::EdgeRef};
/// Weight for an edge in a [`ModuleGraph`]
///
/// An edge runs from the petgraph-node of the function (or entrypoint) containing a use
/// to the petgraph-node of the module-level function or constant being used. Each variant
/// carries the Hugr node that performs the use.
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[non_exhaustive]
pub enum StaticEdge<N = Node> {
/// Edge corresponds to a [`Call`](OpType::Call) node (specified) in the Hugr
Call(N),
/// Edge corresponds to a [`LoadFunction`](OpType::LoadFunction) node (specified) in the Hugr
LoadFunction(N),
/// Edge corresponds to a [`LoadConstant`](OpType::LoadConstant) node (specified) in the Hugr
LoadConstant(N),
}
/// Weight for a petgraph-node in a [`ModuleGraph`]
///
/// Every variant except [`StaticNode::NonFuncEntrypoint`] carries the Hugr node it stands for.
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[non_exhaustive]
pub enum StaticNode<N = Node> {
/// petgraph-node corresponds to a module-level [`FuncDecl`](OpType::FuncDecl) node (specified)
/// in the Hugr
FuncDecl(N),
/// petgraph-node corresponds to a module-level [`FuncDefn`](OpType::FuncDefn) node (specified)
/// in the Hugr
FuncDefn(N),
/// petgraph-node corresponds to the [HugrView::entrypoint], when that is not itself a
/// module-level [`FuncDefn`](OpType::FuncDefn), [`FuncDecl`](OpType::FuncDecl) or
/// [`Const`](OpType::Const) (those already have their own petgraph-nodes). Note that it will
/// not be a [Module](OpType::Module) either, as such a node could not have edges, so is not
/// represented in the petgraph. It never has incoming edges.
NonFuncEntrypoint,
/// petgraph-node corresponds to a module-level [`Const`](OpType::Const) node (specified);
/// will have no outgoing edges, and incoming edges will be [StaticEdge::LoadConstant]
Const(N),
}
/// Details the [`FuncDefn`]s, [`FuncDecl`]s and module-level [`Const`]s in a Hugr,
/// along with the [`Call`]s, [`LoadFunction`]s, and [`LoadConstant`]s connecting them.
///
/// Each node in the `ModuleGraph` corresponds to a module-level function or const;
/// each edge corresponds to a use of the target contained in the edge's source.
///
/// For Hugrs whose entrypoint is neither a [Module](OpType::Module) nor a module-level
/// [`FuncDefn`], [`FuncDecl`] or [`Const`], the static graph will have an additional
/// [`StaticNode::NonFuncEntrypoint`] corresponding to the Hugr's entrypoint, with no
/// incoming edges. A use contained in such an entrypoint produces an edge from the
/// entrypoint's petgraph-node *and* one from the module-level function enclosing it.
///
/// [`Call`]: OpType::Call
/// [`Const`]: OpType::Const
/// [`FuncDecl`]: OpType::FuncDecl
/// [`FuncDefn`]: OpType::FuncDefn
/// [`LoadConstant`]: OpType::LoadConstant
/// [`LoadFunction`]: OpType::LoadFunction
#[derive(Clone, Debug)]
pub struct ModuleGraph<N = Node> {
    /// The petgraph; node and edge weights identify the Hugr nodes involved.
    g: Graph<StaticNode<N>, StaticEdge<N>>,
    /// Maps every Hugr node that has a petgraph-node to that node's index.
    node_to_g: BTreeMap<N, petgraph::graph::NodeIndex<u32>>,
}
impl<N: HugrNode> ModuleGraph<N> {
    /// Makes a new `ModuleGraph` for a Hugr.
    ///
    /// Every [`FuncDecl`](OpType::FuncDecl), [`FuncDefn`](OpType::FuncDefn) and
    /// [`Const`](OpType::Const) child of the module root gets a petgraph-node. If the
    /// entrypoint is neither the module root nor one of those nodes, it gets an extra
    /// [`StaticNode::NonFuncEntrypoint`] petgraph-node.
    ///
    /// Then, for each petgraph-node, every [`Call`](OpType::Call),
    /// [`LoadFunction`](OpType::LoadFunction) and [`LoadConstant`](OpType::LoadConstant)
    /// among the corresponding Hugr node's descendants adds one edge, provided its static
    /// source is a child of the module root. Because the entrypoint may lie inside a
    /// module-level function, a use inside the entrypoint yields two edges: one from the
    /// enclosing function and one from the entrypoint.
    ///
    /// # Panics
    ///
    /// If a `Call`/`LoadFunction`/`LoadConstant` has a static source that is not a child
    /// of the module root, unless it is a `LoadConstant` of a (local) `Const`.
    pub fn new(hugr: &impl HugrView<Node = N>) -> Self {
        let mut g = Graph::default();
        let mut node_to_g = hugr
            .children(hugr.module_root())
            .filter_map(|n| {
                let weight = match hugr.get_optype(n) {
                    OpType::FuncDecl(_) => StaticNode::FuncDecl(n),
                    OpType::FuncDefn(_) => StaticNode::FuncDefn(n),
                    OpType::Const(_) => StaticNode::Const(n),
                    _ => return None,
                };
                Some((n, g.add_node(weight)))
            })
            .collect::<BTreeMap<_, _>>();
        let entrypoint = hugr.entrypoint();
        // A Module entrypoint has no edges of its own, and a module-level function/const
        // entrypoint already has a petgraph-node.
        if !hugr.entrypoint_optype().is_module() && !node_to_g.contains_key(&entrypoint) {
            node_to_g.insert(entrypoint, g.add_node(StaticNode::NonFuncEntrypoint));
        }
        for (container, &src_idx) in &node_to_g {
            for n in hugr.descendants(*container) {
                let weight = match hugr.get_optype(n) {
                    OpType::Call(_) => StaticEdge::Call(n),
                    OpType::LoadFunction(_) => StaticEdge::LoadFunction(n),
                    OpType::LoadConstant(_) => StaticEdge::LoadConstant(n),
                    _ => continue,
                };
                if let Some(target) = hugr.static_source(n) {
                    if hugr.get_parent(target) == Some(hugr.module_root()) {
                        g.add_edge(src_idx, node_to_g[&target], weight);
                    } else {
                        // Local constant (only global constants are in the graph). The one
                        // non-module-level node that can be in `node_to_g` is the entrypoint
                        // (as NonFuncEntrypoint), which never receives incoming edges, so a
                        // load of an entrypoint Const is simply not recorded.
                        assert!(
                            target == entrypoint || !node_to_g.contains_key(&target),
                            "static source outside the module root has a petgraph-node"
                        );
                        assert!(
                            matches!(weight, StaticEdge::LoadConstant(_)),
                            "only LoadConstant may target a non-module-level node"
                        );
                        assert!(
                            hugr.get_optype(target).is_const(),
                            "non-module-level static source must be a Const"
                        );
                    }
                }
            }
        }
        ModuleGraph { g, node_to_g }
    }

    /// Allows access to the petgraph
    #[must_use]
    pub fn graph(&self) -> &Graph<StaticNode<N>, StaticEdge<N>> {
        &self.g
    }

    /// Convert a Hugr [Node] into a petgraph node index.
    ///
    /// Result will be `None` unless `n` is a module-level [`FuncDefn`](OpType::FuncDefn),
    /// [`FuncDecl`](OpType::FuncDecl) or [`Const`](OpType::Const), or the
    /// [HugrView::entrypoint] (when that is not the module root).
    #[must_use]
    pub fn node_index(&self, n: N) -> Option<petgraph::graph::NodeIndex<u32>> {
        self.node_to_g.get(&n).copied()
    }

    /// Returns an iterator over the out-edges from the given Node, i.e.
    /// edges to the functions/constants called/loaded by it.
    ///
    /// The iterator is empty if `n` has no petgraph-node (see [Self::node_index]),
    /// and also for a [`Const`](OpType::Const), which never has outgoing edges.
    pub fn out_edges(&self, n: N) -> impl Iterator<Item = (&StaticEdge<N>, &StaticNode<N>)> {
        let g = self.graph();
        self.node_index(n)
            .into_iter()
            .flat_map(move |idx| g.edges(idx).map(|e| (&g[e.id()], &g[e.target()])))
    }

    /// Returns an iterator over the in-edges to the given Node, i.e.
    /// edges from the (necessarily) functions that call/load it.
    ///
    /// If the node is not recognised as a function or constant,
    /// for example if it is a non-function entrypoint, the iterator will be empty.
    pub fn in_edges(&self, n: N) -> impl Iterator<Item = (&StaticNode<N>, &StaticEdge<N>)> {
        let g = self.graph();
        self.node_index(n).into_iter().flat_map(move |idx| {
            g.edges_directed(idx, petgraph::Direction::Incoming)
                .map(|e| (&g[e.source()], &g[e.id()]))
        })
    }
}
#[cfg(test)]
mod test {
    use itertools::Itertools as _;
    use rstest::rstest;
    use std::collections::HashMap;
    use crate::builder::{
        Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder,
        ModuleBuilder, endo_sig, inout_sig,
    };
    use crate::extension::prelude::{ConstUsize, usize_t};
    use crate::hugr::hugrmut::HugrMut;
    use crate::ops::{Value, handle::NodeHandle};
    use super::*;

    /// A module holding a constant, an identity `callee`, and a `caller` that loads the
    /// constant and passes it to `callee`; checks the edges seen from each node.
    #[test]
    fn edges() {
        let mut module = ModuleBuilder::new();
        let konst = module.add_constant(Value::from(ConstUsize::new(42)));

        let callee_fb = module.define_function("callee", endo_sig([usize_t()])).unwrap();
        let identity_outs = callee_fb.input_wires();
        let callee = callee_fb.finish_with_outputs(identity_outs).unwrap();

        let mut caller_fb = module
            .define_function("caller", inout_sig(vec![], [usize_t()]))
            .unwrap();
        let loaded = caller_fb.load_const(&konst);
        let call = caller_fb.call(callee.handle(), &[], vec![loaded]).unwrap();
        let caller = caller_fb.finish_with_outputs(call.outputs()).unwrap();

        let hugr = module.finish_hugr().unwrap();
        let graph = ModuleGraph::new(&hugr);

        let call_edge = StaticEdge::Call(call.node());
        let load_edge = StaticEdge::LoadConstant(loaded.node());
        let caller_node = StaticNode::FuncDefn(caller.node());
        let callee_node = StaticNode::FuncDefn(callee.node());
        let konst_node = StaticNode::Const(konst.node());

        // callee: uses nothing, called once by caller.
        assert!(graph.out_edges(callee.node()).next().is_none());
        assert_eq!(
            graph.in_edges(callee.node()).collect_vec(),
            [(&caller_node, &call_edge)]
        );

        // caller: calls callee and loads the constant, used by nothing.
        assert_eq!(
            graph.out_edges(caller.node()).collect_vec(),
            [(&call_edge, &callee_node), (&load_edge, &konst_node)]
        );
        assert!(graph.in_edges(caller.node()).next().is_none());

        // constant: uses nothing, loaded once by caller.
        assert!(graph.out_edges(konst.node()).next().is_none());
        assert_eq!(
            graph.in_edges(konst.node()).collect_vec(),
            [(&caller_node, &load_edge)]
        );
    }

    /// A DFG entrypoint inside `main` calls a declared function; with the entrypoint set
    /// either to the DFG or to the call node itself, both `main` and the entrypoint are
    /// reported as callers.
    #[rstest]
    fn entrypoint(#[values(true, false)] single_node: bool) {
        let mut dfg = DFGBuilder::new(endo_sig([usize_t()])).unwrap();
        let decl = dfg
            .module_root_builder()
            .declare("called", endo_sig([usize_t()]).into())
            .unwrap();
        let dfg_inputs = dfg.input_wires();
        let call = dfg.call(&decl, &[], dfg_inputs).unwrap();
        let mut hugr = dfg.finish_hugr_with_outputs(call.outputs()).unwrap();
        let main = hugr.get_parent(hugr.entrypoint()).unwrap();
        if single_node {
            hugr.set_entrypoint(call.node());
        }

        let graph = ModuleGraph::new(&hugr);
        let call_edge = StaticEdge::Call(call.node());
        let main_node = StaticNode::FuncDefn(main);
        let decl_node = StaticNode::FuncDecl(decl.node());

        let mut callers: HashMap<_, _> = graph.in_edges(decl.node()).collect();
        assert_eq!(
            callers.remove(&StaticNode::NonFuncEntrypoint),
            Some(&call_edge)
        );
        assert_eq!(callers, HashMap::from_iter([(&main_node, &call_edge)]));

        let expected_out = vec![(&call_edge, &decl_node)];
        for n in [hugr.entrypoint(), main] {
            assert_eq!(graph.out_edges(n).collect_vec(), expected_out);
        }
    }
}