hugr_llvm/utils/inline_constant_functions.rs

1use hugr_core::{
2    hugr::hugrmut::HugrMut,
3    ops::{FuncDefn, LoadFunction, Value},
4    types::PolyFuncType,
5    HugrView, Node, NodeIndex as _,
6};
7
8use anyhow::{anyhow, bail, Result};
9
10fn const_fn_name(konst_n: Node) -> String {
11    format!("const_fun_{}", konst_n.index())
12}
13
14pub fn inline_constant_functions(hugr: &mut impl HugrMut) -> Result<()> {
15    while inline_constant_functions_impl(hugr)? {}
16    Ok(())
17}
18
19fn inline_constant_functions_impl(hugr: &mut impl HugrMut) -> Result<bool> {
20    let mut const_funs = vec![];
21
22    for n in hugr.nodes() {
23        let konst_hugr = {
24            let Some(konst) = hugr.get_optype(n).as_const() else {
25                continue;
26            };
27            let Value::Function { hugr } = konst.value() else {
28                continue;
29            };
30            let optype = hugr.get_optype(hugr.root());
31            if !optype.is_dfg() && !optype.is_func_defn() {
32                bail!(
33                    "Constant function has unsupported root: {:?}",
34                    hugr.get_optype(hugr.root())
35                )
36            }
37            hugr.clone()
38        };
39        let mut lcs = vec![];
40        for load_constant in hugr.output_neighbours(n) {
41            if !hugr.get_optype(load_constant).is_load_constant() {
42                bail!(
43                    "Constant function has non-LoadConstant output-neighbour: {load_constant} {:?}",
44                    hugr.get_optype(load_constant)
45                )
46            }
47            lcs.push(load_constant);
48        }
49        const_funs.push((n, konst_hugr.as_ref().clone(), lcs));
50    }
51
52    let mut any_changes = false;
53
54    for (konst_n, func_hugr, load_constant_ns) in const_funs {
55        if !load_constant_ns.is_empty() {
56            let polysignature: PolyFuncType = func_hugr
57                .inner_function_type()
58                .ok_or(anyhow!(
59                    "Constant function hugr has no inner_func_type: {}",
60                    konst_n.index()
61                ))?
62                .into_owned()
63                .into();
64            let func_defn = FuncDefn {
65                name: const_fn_name(konst_n),
66                signature: polysignature.clone(),
67            };
68            let func_node = hugr.add_node_with_parent(hugr.root(), func_defn);
69            hugr.insert_hugr(func_node, func_hugr);
70
71            for lcn in load_constant_ns {
72                hugr.replace_op(lcn, LoadFunction::try_new(polysignature.clone(), [])?)?;
73            }
74            any_changes = true;
75        }
76        hugr.remove_node(konst_n);
77    }
78    Ok(any_changes)
79}
80
#[cfg(test)]
mod test {
    use hugr_core::{
        builder::{
            Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder,
            ModuleBuilder,
        },
        extension::prelude::qb_t,
        ops::{CallIndirect, Const, Value},
        types::Signature,
        Hugr, HugrView, Wire,
    };

    use super::inline_constant_functions;

    /// The qubit-to-qubit signature shared by every function in these tests.
    fn qb_endo() -> Signature {
        Signature::new_endo(qb_t())
    }

    /// Wraps the qubit-endomorphism DFG whose output wire is produced by `go` in a
    /// function-valued `Const`.
    fn build_const(go: impl FnOnce(&mut DFGBuilder<Hugr>) -> Wire) -> Const {
        let body = {
            let mut dfg = DFGBuilder::new(qb_endo()).unwrap();
            let out = go(&mut dfg);
            dfg.finish_hugr_with_outputs([out]).unwrap()
        };
        Value::function(body).unwrap().into()
    }

    /// Builds a module containing `konst` plus a `main` function that loads it and
    /// invokes it through `CallIndirect`.
    fn module_calling(konst: Const) -> Hugr {
        let mut module = ModuleBuilder::new();
        let const_node = module.add_constant(konst);
        {
            let mut main = module.define_function("main", qb_endo()).unwrap();
            let [q] = main.input_wires_arr();
            let callee = main.load_const(&const_node);
            let [out] = main
                .add_dataflow_op(CallIndirect { signature: qb_endo() }, [callee, q])
                .unwrap()
                .outputs_arr();
            main.finish_with_outputs([out]).unwrap();
        }
        module.finish_hugr().unwrap()
    }

    /// Panics if any `Const` node in `hugr` still holds a function value.
    fn assert_no_function_consts(hugr: &Hugr) {
        let has_function_const = hugr
            .nodes()
            .filter_map(|n| hugr.get_optype(n).as_const())
            .any(|konst| matches!(konst.value(), Value::Function { .. }));
        assert!(!has_function_const);
    }

    #[test]
    fn simple() {
        // The constant is the identity on a qubit.
        let mut hugr = module_calling(build_const(|dfg| {
            let [q] = dfg.input_wires_arr();
            q
        }));

        inline_constant_functions(&mut hugr).unwrap();

        assert_no_function_consts(&hugr);
    }

    #[test]
    fn nested() {
        // The constant's body itself loads and calls another function constant, so a
        // second inlining pass is required.
        let mut hugr = module_calling(build_const(|outer| {
            let [q] = outer.input_wires_arr();
            let callee = outer.add_load_const(build_const(|inner| {
                let [r] = inner.input_wires_arr();
                r
            }));
            let [out] = outer
                .add_dataflow_op(CallIndirect { signature: qb_endo() }, [callee, q])
                .unwrap()
                .outputs_arr();
            out
        }));

        inline_constant_functions(&mut hugr).unwrap();

        assert_no_function_consts(&hugr);
    }
}