hugr_llvm/utils/inline_constant_functions.rs

1use hugr_core::{
2    hugr::{hugrmut::HugrMut, internal::HugrMutInternals},
3    ops::{FuncDefn, LoadFunction, Value},
4    types::PolyFuncType,
5    HugrView, Node, NodeIndex as _,
6};
7
8use anyhow::{anyhow, bail, Result};
9
10fn const_fn_name(konst_n: Node) -> String {
11    format!("const_fun_{}", konst_n.index())
12}
13
14pub fn inline_constant_functions(hugr: &mut impl HugrMut<Node = Node>) -> Result<()> {
15    while inline_constant_functions_impl(hugr)? {}
16    Ok(())
17}
18
19fn inline_constant_functions_impl(hugr: &mut impl HugrMut<Node = Node>) -> Result<bool> {
20    let mut const_funs = vec![];
21
22    for n in hugr.entry_descendants() {
23        let konst_hugr = {
24            let Some(konst) = hugr.get_optype(n).as_const() else {
25                continue;
26            };
27            let Value::Function { hugr } = konst.value() else {
28                continue;
29            };
30            let optype = hugr.get_optype(hugr.entrypoint());
31            if !optype.is_dfg() && !optype.is_func_defn() {
32                bail!(
33                    "Constant function has unsupported root: {:?}",
34                    hugr.get_optype(hugr.entrypoint())
35                )
36            }
37            hugr.clone()
38        };
39        let mut lcs = vec![];
40        for load_constant in hugr.output_neighbours(n) {
41            if !hugr.get_optype(load_constant).is_load_constant() {
42                bail!(
43                    "Constant function has non-LoadConstant output-neighbour: {load_constant} {:?}",
44                    hugr.get_optype(load_constant)
45                )
46            }
47            lcs.push(load_constant);
48        }
49        const_funs.push((n, konst_hugr.as_ref().clone(), lcs));
50    }
51
52    let mut any_changes = false;
53
54    for (konst_n, mut func_hugr, load_constant_ns) in const_funs {
55        if !load_constant_ns.is_empty() {
56            let polysignature: PolyFuncType = func_hugr
57                .inner_function_type()
58                .ok_or(anyhow!(
59                    "Constant function hugr has no inner_func_type: {}",
60                    konst_n.index()
61                ))?
62                .into_owned()
63                .into();
64            let func_defn = FuncDefn {
65                name: const_fn_name(konst_n),
66                signature: polysignature.clone(),
67            };
68            func_hugr.replace_op(func_hugr.entrypoint(), func_defn);
69            let func_node = hugr
70                .insert_hugr(hugr.entrypoint(), func_hugr)
71                .inserted_entrypoint;
72            hugr.set_num_ports(func_node, 0, 1);
73
74            for lcn in load_constant_ns {
75                hugr.replace_op(lcn, LoadFunction::try_new(polysignature.clone(), [])?);
76
77                let src_port = hugr.node_outputs(func_node).next().unwrap();
78                let tgt_port = hugr.node_inputs(lcn).next().unwrap();
79                hugr.connect(func_node, src_port, lcn, tgt_port);
80            }
81            any_changes = true;
82        }
83        hugr.remove_node(konst_n);
84    }
85    Ok(any_changes)
86}
87
#[cfg(test)]
mod test {
    use hugr_core::{
        builder::{
            Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder,
            ModuleBuilder,
        },
        extension::prelude::qb_t,
        ops::{CallIndirect, Const, Value},
        types::Signature,
        Hugr, HugrView, Wire,
    };

    use super::inline_constant_functions;

    /// Signature shared by every function in these tests: `qubit -> qubit`.
    fn qb_sig() -> Signature {
        Signature::new_endo(qb_t())
    }

    /// Wraps a `qubit -> qubit` DFG, whose body is produced by `go`, in a function constant.
    fn build_const(go: impl FnOnce(&mut DFGBuilder<Hugr>) -> Wire) -> Const {
        let mut builder = DFGBuilder::new(qb_sig()).unwrap();
        let out = go(&mut builder);
        let func_hugr = builder.finish_hugr_with_outputs([out]).unwrap();
        Value::function(func_hugr).unwrap().into()
    }

    /// Adds a `CallIndirect` of `func` applied to `arg`, returning its single output wire.
    fn call_indirect(builder: &mut impl Dataflow, func: Wire, arg: Wire) -> Wire {
        let [out] = builder
            .add_dataflow_op(
                CallIndirect {
                    signature: qb_sig(),
                },
                [func, arg],
            )
            .unwrap()
            .outputs_arr();
        out
    }

    /// Builds a module whose `main` loads `konst` and calls it indirectly on its input.
    fn build_module(konst: Const) -> Hugr {
        let mut builder = ModuleBuilder::new();
        let const_node = builder.add_constant(konst);
        let mut main = builder.define_function("main", qb_sig()).unwrap();
        let [i] = main.input_wires_arr();
        let fun = main.load_const(&const_node);
        let r = call_indirect(&mut main, fun, i);
        main.finish_with_outputs([r]).unwrap();
        builder.finish_hugr().unwrap()
    }

    /// Asserts that no `Const` under the entrypoint of `hugr` still holds a function value.
    fn assert_no_const_functions(hugr: &Hugr) {
        let leftover = hugr
            .entry_descendants()
            .filter_map(|n| hugr.get_optype(n).as_const())
            .any(|konst| matches!(konst.value(), Value::Function { .. }));
        assert!(!leftover);
    }

    #[test]
    fn simple() {
        let mut hugr = build_module(build_const(|builder| {
            let [r] = builder.input_wires_arr();
            r
        }));

        inline_constant_functions(&mut hugr).unwrap();
        hugr.validate().unwrap();
        assert_no_const_functions(&hugr);
    }

    #[test]
    fn nested() {
        let mut hugr = build_module(build_const(|builder| {
            let [i] = builder.input_wires_arr();
            let func = builder.add_load_const(build_const(|builder| {
                let [r] = builder.input_wires_arr();
                r
            }));
            call_indirect(builder, func, i)
        }));

        inline_constant_functions(&mut hugr).unwrap();
        hugr.validate().unwrap();
        assert_no_const_functions(&hugr);
    }
}