hugr_llvm/utils/
inline_constant_functions.rs

1use hugr_core::{
2    HugrView, Node, NodeIndex as _,
3    hugr::{hugrmut::HugrMut, internal::HugrMutInternals},
4    ops::{FuncDefn, LoadFunction, Value},
5    types::PolyFuncType,
6};
7
8use anyhow::{Result, anyhow, bail};
9
10fn const_fn_name(konst_n: Node) -> String {
11    format!("const_fun_{}", konst_n.index())
12}
13
14pub fn inline_constant_functions(hugr: &mut impl HugrMut<Node = Node>) -> Result<()> {
15    while inline_constant_functions_impl(hugr)? {}
16    Ok(())
17}
18
19fn inline_constant_functions_impl(hugr: &mut impl HugrMut<Node = Node>) -> Result<bool> {
20    let mut const_funs = vec![];
21
22    for n in hugr.entry_descendants() {
23        let konst_hugr = {
24            let Some(konst) = hugr.get_optype(n).as_const() else {
25                continue;
26            };
27            let Value::Function { hugr } = konst.value() else {
28                continue;
29            };
30            let optype = hugr.get_optype(hugr.entrypoint());
31            if !optype.is_dfg() && !optype.is_func_defn() {
32                bail!(
33                    "Constant function has unsupported root: {:?}",
34                    hugr.get_optype(hugr.entrypoint())
35                )
36            }
37            hugr.clone()
38        };
39        let mut lcs = vec![];
40        for load_constant in hugr.output_neighbours(n) {
41            if !hugr.get_optype(load_constant).is_load_constant() {
42                bail!(
43                    "Constant function has non-LoadConstant output-neighbour: {load_constant} {:?}",
44                    hugr.get_optype(load_constant)
45                )
46            }
47            lcs.push(load_constant);
48        }
49        const_funs.push((n, konst_hugr.as_ref().clone(), lcs));
50    }
51
52    let mut any_changes = false;
53
54    for (konst_n, mut func_hugr, load_constant_ns) in const_funs {
55        if !load_constant_ns.is_empty() {
56            let polysignature: PolyFuncType = func_hugr
57                .inner_function_type()
58                .ok_or(anyhow!(
59                    "Constant function hugr has no inner_func_type: {}",
60                    konst_n.index()
61                ))?
62                .into_owned()
63                .into();
64            let func_defn = FuncDefn::new(const_fn_name(konst_n), polysignature.clone());
65            func_hugr.replace_op(func_hugr.entrypoint(), func_defn);
66            let func_node = hugr
67                .insert_hugr(hugr.entrypoint(), func_hugr)
68                .inserted_entrypoint;
69            hugr.set_num_ports(func_node, 0, 1);
70
71            for lcn in load_constant_ns {
72                hugr.replace_op(lcn, LoadFunction::try_new(polysignature.clone(), [])?);
73
74                let src_port = hugr.node_outputs(func_node).next().unwrap();
75                let tgt_port = hugr.node_inputs(lcn).next().unwrap();
76                hugr.connect(func_node, src_port, lcn, tgt_port);
77            }
78            any_changes = true;
79        }
80        hugr.remove_node(konst_n);
81    }
82    Ok(any_changes)
83}
84
#[cfg(test)]
mod test {
    use hugr_core::{
        Hugr, HugrView, Wire,
        builder::{
            Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder,
            ModuleBuilder,
        },
        extension::prelude::qb_t,
        ops::{CallIndirect, Const, Value},
        types::Signature,
    };

    use super::inline_constant_functions;

    /// Wraps a qubit-to-qubit DFG as a function constant; `go` populates the
    /// DFG and returns the wire to use as its single output.
    fn build_const(go: impl FnOnce(&mut DFGBuilder<Hugr>) -> Wire) -> Const {
        let mut dfg = DFGBuilder::new(Signature::new_endo(qb_t())).unwrap();
        let out = go(&mut dfg);
        let body = dfg.finish_hugr_with_outputs([out]).unwrap();
        Value::function(body).unwrap().into()
    }

    /// A function constant that forwards its input straight to its output.
    fn identity_const() -> Const {
        build_const(|dfg| {
            let [w] = dfg.input_wires_arr();
            w
        })
    }

    /// Adds a qubit-to-qubit `CallIndirect` of `callee` applied to `arg` and
    /// returns the call's result wire.
    fn call_indirect(builder: &mut impl Dataflow, callee: Wire, arg: Wire) -> Wire {
        let op = CallIndirect {
            signature: Signature::new_endo(qb_t()),
        };
        let [out] = builder
            .add_dataflow_op(op, [callee, arg])
            .unwrap()
            .outputs_arr();
        out
    }

    /// Builds a module holding `konst` plus a `main` function that loads the
    /// constant and calls it indirectly on its own input.
    fn module_calling(konst: Const) -> Hugr {
        let mut module = ModuleBuilder::new();
        let const_id = module.add_constant(konst);
        {
            let mut main = module
                .define_function("main", Signature::new_endo(qb_t()))
                .unwrap();
            let [arg] = main.input_wires_arr();
            let callee = main.load_const(&const_id);
            let res = call_indirect(&mut main, callee, arg);
            main.finish_with_outputs([res]).unwrap();
        }
        module.finish_hugr().unwrap()
    }

    /// Runs the pass, validates the result and checks that no function-valued
    /// constant node is left.
    fn inline_and_check(mut hugr: Hugr) {
        inline_constant_functions(&mut hugr).unwrap();
        hugr.validate().unwrap();

        let consts = hugr
            .entry_descendants()
            .filter_map(|n| hugr.get_optype(n).as_const());
        for konst in consts {
            assert!(!matches!(konst.value(), Value::Function { .. }));
        }
    }

    #[test]
    fn simple() {
        inline_and_check(module_calling(identity_const()));
    }

    #[test]
    fn nested() {
        // The outer constant itself loads and calls an inner function constant.
        let outer = build_const(|dfg| {
            let [arg] = dfg.input_wires_arr();
            let inner = dfg.add_load_const(identity_const());
            call_indirect(dfg, inner, arg)
        });
        inline_and_check(module_calling(outer));
    }
}