hugr_llvm/utils/
inline_constant_functions.rs1use hugr_core::{
2 HugrView, Node, NodeIndex as _,
3 hugr::{hugrmut::HugrMut, internal::HugrMutInternals},
4 ops::{FuncDefn, LoadFunction, Value},
5 types::PolyFuncType,
6};
7
8use anyhow::{Result, anyhow, bail};
9
10fn const_fn_name(konst_n: Node) -> String {
11 format!("const_fun_{}", konst_n.index())
12}
13
14pub fn inline_constant_functions(hugr: &mut impl HugrMut<Node = Node>) -> Result<()> {
15 while inline_constant_functions_impl(hugr)? {}
16 Ok(())
17}
18
19fn inline_constant_functions_impl(hugr: &mut impl HugrMut<Node = Node>) -> Result<bool> {
20 let mut const_funs = vec![];
21
22 for n in hugr.entry_descendants() {
23 let konst_hugr = {
24 let Some(konst) = hugr.get_optype(n).as_const() else {
25 continue;
26 };
27 let Value::Function { hugr } = konst.value() else {
28 continue;
29 };
30 let optype = hugr.get_optype(hugr.entrypoint());
31 if !optype.is_dfg() && !optype.is_func_defn() {
32 bail!(
33 "Constant function has unsupported root: {:?}",
34 hugr.get_optype(hugr.entrypoint())
35 )
36 }
37 hugr.clone()
38 };
39 let mut lcs = vec![];
40 for load_constant in hugr.output_neighbours(n) {
41 if !hugr.get_optype(load_constant).is_load_constant() {
42 bail!(
43 "Constant function has non-LoadConstant output-neighbour: {load_constant} {:?}",
44 hugr.get_optype(load_constant)
45 )
46 }
47 lcs.push(load_constant);
48 }
49 const_funs.push((n, konst_hugr.as_ref().clone(), lcs));
50 }
51
52 let mut any_changes = false;
53
54 for (konst_n, mut func_hugr, load_constant_ns) in const_funs {
55 if !load_constant_ns.is_empty() {
56 let polysignature: PolyFuncType = func_hugr
57 .inner_function_type()
58 .ok_or(anyhow!(
59 "Constant function hugr has no inner_func_type: {}",
60 konst_n.index()
61 ))?
62 .into_owned()
63 .into();
64 let func_defn = FuncDefn::new(const_fn_name(konst_n), polysignature.clone());
65 func_hugr.replace_op(func_hugr.entrypoint(), func_defn);
66 let func_node = hugr
67 .insert_hugr(hugr.entrypoint(), func_hugr)
68 .inserted_entrypoint;
69 hugr.set_num_ports(func_node, 0, 1);
70
71 for lcn in load_constant_ns {
72 hugr.replace_op(lcn, LoadFunction::try_new(polysignature.clone(), [])?);
73
74 let src_port = hugr.node_outputs(func_node).next().unwrap();
75 let tgt_port = hugr.node_inputs(lcn).next().unwrap();
76 hugr.connect(func_node, src_port, lcn, tgt_port);
77 }
78 any_changes = true;
79 }
80 hugr.remove_node(konst_n);
81 }
82 Ok(any_changes)
83}
84
#[cfg(test)]
mod test {
    use hugr_core::{
        Hugr, HugrView, Wire,
        builder::{
            Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder,
            ModuleBuilder,
        },
        extension::prelude::qb_t,
        ops::{CallIndirect, Const, Value},
        types::Signature,
    };

    use super::inline_constant_functions;

    /// Signature `qubit -> qubit`, shared by every function in these tests.
    fn qb_endo() -> Signature {
        Signature::new_endo(qb_t())
    }

    /// Wraps a `qubit -> qubit` DFG, whose output wire is produced by `go`, in a
    /// function-valued `Const`.
    fn build_const(go: impl FnOnce(&mut DFGBuilder<Hugr>) -> Wire) -> Const {
        let mut dfg = DFGBuilder::new(qb_endo()).unwrap();
        let out = go(&mut dfg);
        let fn_hugr = dfg.finish_hugr_with_outputs([out]).unwrap();
        Value::function(fn_hugr).unwrap().into()
    }

    /// Builds a module holding `konst` plus a `main` function that loads it and
    /// calls it indirectly on its qubit input.
    fn module_calling_const(konst: Const) -> Hugr {
        let sig = qb_endo();
        let mut module = ModuleBuilder::new();
        let const_id = module.add_constant(konst);
        {
            let mut main_fn = module.define_function("main", sig.clone()).unwrap();
            let [input] = main_fn.input_wires_arr();
            let loaded = main_fn.load_const(&const_id);
            let [out] = main_fn
                .add_dataflow_op(CallIndirect { signature: sig }, [loaded, input])
                .unwrap()
                .outputs_arr();
            main_fn.finish_with_outputs([out]).unwrap();
        }
        module.finish_hugr().unwrap()
    }

    /// Runs the pass, validates the result, and checks that no function-valued
    /// `Const` remains under the entrypoint.
    fn inline_and_check(mut hugr: Hugr) {
        inline_constant_functions(&mut hugr).unwrap();
        hugr.validate().unwrap();

        let has_function_const = hugr.entry_descendants().any(|n| {
            hugr.get_optype(n)
                .as_const()
                .is_some_and(|konst| matches!(konst.value(), Value::Function { .. }))
        });
        assert!(!has_function_const);
    }

    #[test]
    fn simple() {
        let identity = build_const(|builder| {
            let [w] = builder.input_wires_arr();
            w
        });
        inline_and_check(module_calling_const(identity));
    }

    #[test]
    fn nested() {
        let outer = build_const(|builder| {
            let [input] = builder.input_wires_arr();
            let inner_fn = builder.add_load_const(build_const(|inner| {
                let [w] = inner.input_wires_arr();
                w
            }));
            let [out] = builder
                .add_dataflow_op(
                    CallIndirect {
                        signature: qb_endo(),
                    },
                    [inner_fn, input],
                )
                .unwrap()
                .outputs_arr();
            out
        });
        inline_and_check(module_calling_const(outer));
    }
}