hugr_llvm/utils/
inline_constant_functions.rs1use hugr_core::{
2 hugr::hugrmut::HugrMut,
3 ops::{FuncDefn, LoadFunction, Value},
4 types::PolyFuncType,
5 HugrView, Node, NodeIndex as _,
6};
7
8use anyhow::{anyhow, bail, Result};
9
10fn const_fn_name(konst_n: Node) -> String {
11 format!("const_fun_{}", konst_n.index())
12}
13
14pub fn inline_constant_functions(hugr: &mut impl HugrMut) -> Result<()> {
15 while inline_constant_functions_impl(hugr)? {}
16 Ok(())
17}
18
19fn inline_constant_functions_impl(hugr: &mut impl HugrMut) -> Result<bool> {
20 let mut const_funs = vec![];
21
22 for n in hugr.nodes() {
23 let konst_hugr = {
24 let Some(konst) = hugr.get_optype(n).as_const() else {
25 continue;
26 };
27 let Value::Function { hugr } = konst.value() else {
28 continue;
29 };
30 let optype = hugr.get_optype(hugr.root());
31 if !optype.is_dfg() && !optype.is_func_defn() {
32 bail!(
33 "Constant function has unsupported root: {:?}",
34 hugr.get_optype(hugr.root())
35 )
36 }
37 hugr.clone()
38 };
39 let mut lcs = vec![];
40 for load_constant in hugr.output_neighbours(n) {
41 if !hugr.get_optype(load_constant).is_load_constant() {
42 bail!(
43 "Constant function has non-LoadConstant output-neighbour: {load_constant} {:?}",
44 hugr.get_optype(load_constant)
45 )
46 }
47 lcs.push(load_constant);
48 }
49 const_funs.push((n, konst_hugr.as_ref().clone(), lcs));
50 }
51
52 let mut any_changes = false;
53
54 for (konst_n, func_hugr, load_constant_ns) in const_funs {
55 if !load_constant_ns.is_empty() {
56 let polysignature: PolyFuncType = func_hugr
57 .inner_function_type()
58 .ok_or(anyhow!(
59 "Constant function hugr has no inner_func_type: {}",
60 konst_n.index()
61 ))?
62 .into_owned()
63 .into();
64 let func_defn = FuncDefn {
65 name: const_fn_name(konst_n),
66 signature: polysignature.clone(),
67 };
68 let func_node = hugr.add_node_with_parent(hugr.root(), func_defn);
69 hugr.insert_hugr(func_node, func_hugr);
70
71 for lcn in load_constant_ns {
72 hugr.replace_op(lcn, LoadFunction::try_new(polysignature.clone(), [])?)?;
73 }
74 any_changes = true;
75 }
76 hugr.remove_node(konst_n);
77 }
78 Ok(any_changes)
79}
80
#[cfg(test)]
mod test {
    use hugr_core::{
        builder::{
            Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder,
            ModuleBuilder,
        },
        extension::prelude::qb_t,
        ops::{CallIndirect, Const, Value},
        types::Signature,
        Hugr, HugrView, Wire,
    };

    use super::inline_constant_functions;

    /// Wrap the `qubit -> qubit` DFG whose output wire is produced by `go`
    /// into a function-valued `Const`.
    fn build_const(go: impl FnOnce(&mut DFGBuilder<Hugr>) -> Wire) -> Const {
        let mut dfg = DFGBuilder::new(Signature::new_endo(qb_t())).unwrap();
        let out = go(&mut dfg);
        let body = dfg.finish_hugr_with_outputs([out]).unwrap();
        Value::function(body).unwrap().into()
    }

    /// Add a `CallIndirect` of the `qubit -> qubit` function value `fun` on
    /// `arg`, returning its single output wire.
    fn call_qb_endo(builder: &mut impl Dataflow, fun: Wire, arg: Wire) -> Wire {
        let call = CallIndirect {
            signature: Signature::new_endo(qb_t()),
        };
        let [out] = builder
            .add_dataflow_op(call, [fun, arg])
            .unwrap()
            .outputs_arr();
        out
    }

    /// Build a module holding `konst` and a `main` function that loads it and
    /// calls it indirectly on its input qubit.
    fn module_calling(konst: Const) -> Hugr {
        let mut module = ModuleBuilder::new();
        let const_node = module.add_constant(konst);
        {
            let mut main = module
                .define_function("main", Signature::new_endo(qb_t()))
                .unwrap();
            let [arg] = main.input_wires_arr();
            let fun = main.load_const(&const_node);
            let out = call_qb_endo(&mut main, fun, arg);
            main.finish_with_outputs([out]).unwrap();
        }
        module.finish_hugr().unwrap()
    }

    /// Assert that no `Const` node in `hugr` still holds a function value.
    fn assert_no_const_functions(hugr: &Hugr) {
        for konst in hugr.nodes().filter_map(|n| hugr.get_optype(n).as_const()) {
            assert!(!matches!(konst.value(), Value::Function { .. }));
        }
    }

    #[test]
    fn simple() {
        let identity = build_const(|dfg| {
            let [arg] = dfg.input_wires_arr();
            arg
        });
        let mut hugr = module_calling(identity);

        inline_constant_functions(&mut hugr).unwrap();

        assert_no_const_functions(&hugr);
    }

    #[test]
    fn nested() {
        let inner = build_const(|dfg| {
            let [arg] = dfg.input_wires_arr();
            arg
        });
        let outer = build_const(move |dfg| {
            let [arg] = dfg.input_wires_arr();
            let fun = dfg.add_load_const(inner);
            call_qb_endo(dfg, fun, arg)
        });
        let mut hugr = module_calling(outer);

        inline_constant_functions(&mut hugr).unwrap();

        assert_no_const_functions(&hugr);
    }
}