hugr_core/hugr/rewrite/
inline_call.rs

1//! Rewrite to inline a Call to a FuncDefn by copying the body of the function
2//! into a DFG which replaces the Call node.
3use derive_more::{Display, Error};
4
5use crate::ops::{DataflowParent, OpType, DFG};
6use crate::types::Substitution;
7use crate::{Direction, HugrView, Node};
8
9use super::{HugrMut, Rewrite};
10
11/// Rewrite to inline a [Call](OpType::Call) to a known [FuncDefn](OpType::FuncDefn)
12pub struct InlineCall(Node);
13
14/// Error in performing [InlineCall] rewrite.
15#[derive(Clone, Debug, Display, Error, PartialEq)]
16#[non_exhaustive]
17pub enum InlineCallError {
18    /// The specified Node was not a [Call](OpType::Call)
19    #[display("Node to inline {_0} expected to be a Call but actually {_1}")]
20    NotCallNode(Node, OpType),
21    /// The node was a Call, but the target was not a [FuncDefn](OpType::FuncDefn)
22    /// - presumably a [FuncDecl](OpType::FuncDecl), if the Hugr is valid.
23    #[display("Call targetted node {_0} which must be a FuncDefn but was {_1}")]
24    CallTargetNotFuncDefn(Node, OpType),
25}
26
27impl InlineCall {
28    /// Create a new instance that will inline the specified node
29    /// (i.e. that should be a [Call](OpType::Call))
30    pub fn new(node: Node) -> Self {
31        Self(node)
32    }
33}
34
35impl Rewrite for InlineCall {
36    type ApplyResult = ();
37    type Error = InlineCallError;
38    fn verify(&self, h: &impl HugrView<Node = Node>) -> Result<(), Self::Error> {
39        let call_ty = h.get_optype(self.0);
40        if !call_ty.is_call() {
41            return Err(InlineCallError::NotCallNode(self.0, call_ty.clone()));
42        }
43        let func = h.static_source(self.0).unwrap();
44        let func_ty = h.get_optype(func);
45        if !func_ty.is_func_defn() {
46            return Err(InlineCallError::CallTargetNotFuncDefn(
47                func,
48                func_ty.clone(),
49            ));
50        }
51        Ok(())
52    }
53
54    fn apply(self, h: &mut impl HugrMut) -> Result<(), Self::Error> {
55        self.verify(h)?; // Now we know we have a Call to a FuncDefn.
56        let orig_func = h.static_source(self.0).unwrap();
57
58        h.disconnect(self.0, h.get_optype(self.0).static_input_port().unwrap());
59
60        // The order input port gets renumbered because the static input
61        // (which comes between the value inports and the order inport) gets removed
62        let old_order_in = h.get_optype(self.0).other_input_port().unwrap();
63        let order_preds = h.linked_outputs(self.0, old_order_in).collect::<Vec<_>>();
64        h.disconnect(self.0, old_order_in); // PortGraph currently does this anyway
65
66        let new_op = OpType::from(DFG {
67            signature: h
68                .get_optype(orig_func)
69                .as_func_defn()
70                .unwrap()
71                .inner_signature()
72                .into_owned(),
73        });
74        let new_order_in = new_op.other_input_port().unwrap();
75
76        let ty_args = h
77            .replace_op(self.0, new_op)
78            .unwrap()
79            .as_call()
80            .unwrap()
81            .type_args
82            .clone();
83
84        h.add_ports(self.0, Direction::Incoming, -1);
85
86        // Reconnect order predecessors
87        for (src, srcp) in order_preds {
88            h.connect(src, srcp, self.0, new_order_in);
89        }
90
91        h.copy_descendants(
92            orig_func,
93            self.0,
94            (!ty_args.is_empty()).then_some(Substitution::new(&ty_args)),
95        );
96        Ok(())
97    }
98
99    /// Failure only occurs if the node is not a Call, or the target not a FuncDefn.
100    /// (Any later failure means an invalid Hugr and `panic`.)
101    const UNCHANGED_ON_FAILURE: bool = true;
102
103    fn invalidation_set(&self) -> impl Iterator<Item = Node> {
104        Some(self.0).into_iter()
105    }
106}
107
#[cfg(test)]
mod test {
    use std::iter::successors;

    use itertools::Itertools;

    use crate::builder::{
        Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder,
        ModuleBuilder,
    };
    use crate::extension::prelude::usize_t;
    use crate::hugr::views::RootChecked;
    use crate::ops::handle::{FuncID, ModuleRootID, NodeHandle};
    use crate::ops::{Input, OpType, Value};
    use crate::std_extensions::arithmetic::{
        int_ops::{self, IntOpDef},
        int_types::{self, ConstInt, INT_TYPES},
    };

    use crate::types::{PolyFuncType, Signature, Type, TypeBound};
    use crate::{HugrView, Node};

    use super::{HugrMut, InlineCall, InlineCallError};

    /// All nodes of `h` whose operation is a `Call`, in `h.nodes()` order.
    fn calls(h: &impl HugrView<Node = Node>) -> Vec<Node> {
        h.nodes().filter(|n| h.get_optype(*n).is_call()).collect()
    }

    /// All nodes of `h` whose operation is an extension op, in `h.nodes()` order.
    fn extension_ops(h: &impl HugrView<Node = Node>) -> Vec<Node> {
        h.nodes()
            .filter(|n| h.get_optype(*n).is_extension_op())
            .collect()
    }

    /// Inlines, one at a time, two calls from `main` to a monomorphic `foo`.
    ///
    /// `foo` loads a module-level constant and contains a nested DFG holding a
    /// single `iadd`. After each inlining the test checks:
    /// - the Hugr validates;
    /// - the call is gone from `foo`'s static outputs and from the set of Calls;
    /// - one more `iadd` exists (copied in).
    ///
    /// It also checks that the order edge into the first call survives.
    #[test]
    fn test_inline() -> Result<(), Box<dyn std::error::Error>> {
        let mut mb = ModuleBuilder::new();
        let cst3 = mb.add_constant(Value::from(ConstInt::new_u(4, 3)?));
        let sig = Signature::new_endo(INT_TYPES[4].clone())
            .with_extension_delta(int_ops::EXTENSION_ID)
            .with_extension_delta(int_types::EXTENSION_ID);
        let func = {
            let mut fb = mb.define_function("foo", sig.clone())?;
            let c1 = fb.load_const(&cst3);
            let mut inner = fb.dfg_builder(sig.clone(), fb.input_wires())?;
            let [i] = inner.input_wires_arr();
            let add = inner.add_dataflow_op(IntOpDef::iadd.with_log_width(4), [i, c1])?;
            let inner_res = inner.finish_with_outputs(add.outputs())?;
            fb.finish_with_outputs(inner_res.outputs())?
        };
        let mut main = mb.define_function("main", sig)?;
        let call1 = main.call(func.handle(), &[], main.input_wires())?;
        // Order edge Input -> call1, checked below to still be present after inlining.
        main.add_other_wire(main.input().node(), call1.node());
        let call2 = main.call(func.handle(), &[], call1.outputs())?;
        main.finish_with_outputs(call2.outputs())?;
        let mut hugr = mb.finish_hugr()?;
        let call1 = call1.node();
        let call2 = call2.node();
        assert_eq!(
            hugr.output_neighbours(func.node()).collect_vec(),
            [call1, call2]
        );
        assert_eq!(calls(&hugr), [call1, call2]);
        assert_eq!(extension_ops(&hugr).len(), 1);

        assert_eq!(
            hugr.linked_outputs(
                call1.node(),
                hugr.get_optype(call1).other_input_port().unwrap()
            )
            .count(),
            1
        );
        // The first rewrite is applied through a `RootChecked` wrapper rather
        // than directly on the Hugr.
        RootChecked::<_, ModuleRootID>::try_new(&mut hugr)
            .unwrap()
            .apply_rewrite(InlineCall(call1.node()))
            .unwrap();
        hugr.validate().unwrap();
        assert_eq!(hugr.output_neighbours(func.node()).collect_vec(), [call2]);
        assert_eq!(calls(&hugr), [call2]);
        assert_eq!(extension_ops(&hugr).len(), 2);
        // call1 (now a DFG) still has exactly one order predecessor.
        assert_eq!(
            hugr.linked_outputs(
                call1.node(),
                hugr.get_optype(call1).other_input_port().unwrap()
            )
            .count(),
            1
        );
        hugr.apply_rewrite(InlineCall(call2.node())).unwrap();
        hugr.validate().unwrap();
        assert_eq!(hugr.output_neighbours(func.node()).next(), None);
        assert_eq!(calls(&hugr), []);
        assert_eq!(extension_ops(&hugr).len(), 3);

        Ok(())
    }

    /// Repeatedly inlines the non-recursive call to a self-recursive `foo`.
    ///
    /// Each inlining copies in `foo`'s body, including a fresh recursive Call.
    /// That Call becomes the next one to inline, nested one DFG deeper inside
    /// `main` each time.
    #[test]
    fn test_recursion() -> Result<(), Box<dyn std::error::Error>> {
        let mut mb = ModuleBuilder::new();
        let sig = Signature::new_endo(INT_TYPES[5].clone())
            .with_extension_delta(int_ops::EXTENSION_ID)
            .with_extension_delta(int_types::EXTENSION_ID);
        let (func, rec_call) = {
            let mut fb = mb.define_function("foo", sig.clone())?;
            let cst1 = fb.add_load_value(ConstInt::new_u(5, 1)?);
            let [i] = fb.input_wires_arr();
            let add = fb.add_dataflow_op(IntOpDef::iadd.with_log_width(5), [i, cst1])?;
            // foo calls itself.
            let call = fb.call(
                &FuncID::<true>::from(fb.container_node()),
                &[],
                add.outputs(),
            )?;
            (fb.finish_with_outputs(call.outputs())?, call)
        };
        let mut main = mb.define_function("main", sig)?;
        let call = main.call(func.handle(), &[], main.input_wires())?;
        let main = main.finish_with_outputs(call.outputs())?;
        let mut hugr = mb.finish_hugr()?;

        let func = func.node();
        let mut call = call.node();
        for i in 2..10 {
            hugr.apply_rewrite(InlineCall(call))?;
            hugr.validate().unwrap();
            assert_eq!(extension_ops(&hugr).len(), i);
            let v = calls(&hugr);
            assert!(v.iter().all(|n| hugr.static_source(*n) == Some(func)));

            // Exactly two Calls: the original recursive one inside foo, and the
            // copy just inlined into main.
            let [rec, nonrec] = v.try_into().expect("Should be two");
            assert_eq!(rec, rec_call.node());
            assert_eq!(hugr.output_neighbours(func).collect_vec(), [rec, nonrec]);
            call = nonrec;

            // After i-1 inlinings, the new call sits under i-1 nested DFGs in main.
            let mut ancestors = successors(hugr.get_parent(call), |n| hugr.get_parent(*n));
            for _ in 1..i {
                assert!(hugr.get_optype(ancestors.next().unwrap()).is_dfg());
            }
            assert_eq!(ancestors.next(), Some(main.node()));
            assert_eq!(ancestors.next(), Some(hugr.root()));
            assert_eq!(ancestors.next(), None);
        }
        Ok(())
    }

    /// Error cases.
    ///
    /// Inlining a Call whose target is a FuncDecl fails with
    /// `CallTargetNotFuncDefn` and leaves the Hugr unchanged. Inlining a
    /// non-Call node (here an `Input`) fails with `NotCallNode`.
    #[test]
    fn test_bad() {
        let mut modb = ModuleBuilder::new();
        let decl = modb
            .declare(
                "UndefinedFunc",
                Signature::new_endo(INT_TYPES[3].clone()).into(),
            )
            .unwrap();
        let mut main = modb
            .define_function("main", Signature::new_endo(INT_TYPES[3].clone()))
            .unwrap();
        let call = main.call(&decl, &[], main.input_wires()).unwrap();
        let main = main.finish_with_outputs(call.outputs()).unwrap();
        let h = modb.finish_hugr().unwrap();
        let mut h2 = h.clone();
        assert_eq!(
            h2.apply_rewrite(InlineCall(call.node())),
            Err(InlineCallError::CallTargetNotFuncDefn(
                decl.node(),
                h.get_optype(decl.node()).clone()
            ))
        );
        // The failed rewrite must not have modified the Hugr.
        assert_eq!(h, h2);
        let [inp, _out, _call] = h
            .children(main.node())
            .collect::<Vec<_>>()
            .try_into()
            .unwrap();
        assert_eq!(
            h2.apply_rewrite(InlineCall(inp)),
            Err(InlineCallError::NotCallNode(
                inp,
                Input {
                    types: INT_TYPES[3].clone().into()
                }
                .into()
            ))
        )
    }

    /// Inlines one of two calls to a polymorphic identity function `id`.
    ///
    /// The inlined call is instantiated at `usize`. The test checks that the
    /// call node becomes a DFG whose only children are an Input and an Output.
    ///
    /// NOTE(review): unlike the other tests, this one does not call
    /// `hugr.validate()` after inlining. It therefore does not check that the
    /// new DFG's signature agrees with its type-substituted children.
    #[test]
    fn test_polymorphic() -> Result<(), Box<dyn std::error::Error>> {
        let tuple_ty = Type::new_tuple(vec![usize_t(); 2]);
        let mut fb = FunctionBuilder::new(
            "mkpair",
            Signature::new(usize_t(), tuple_ty.clone()).with_prelude(),
        )?;
        let inner = fb.define_function(
            "id",
            PolyFuncType::new(
                [TypeBound::Copyable.into()],
                Signature::new_endo(Type::new_var_use(0, TypeBound::Copyable)),
            ),
        )?;
        let inps = inner.input_wires();
        let inner = inner.finish_with_outputs(inps)?;
        let call1 = fb.call(inner.handle(), &[usize_t().into()], fb.input_wires())?;
        let [call1_out] = call1.outputs_arr();
        let tup = fb.make_tuple([call1_out, call1_out])?;
        let call2 = fb.call(inner.handle(), &[tuple_ty.into()], [tup])?;
        let mut hugr = fb.finish_hugr_with_outputs(call2.outputs()).unwrap();

        assert_eq!(
            hugr.output_neighbours(inner.node()).collect::<Vec<_>>(),
            [call1.node(), call2.node()]
        );
        hugr.apply_rewrite(InlineCall::new(call1.node()))?;

        assert_eq!(
            hugr.output_neighbours(inner.node()).collect::<Vec<_>>(),
            [call2.node()]
        );
        assert!(hugr.get_optype(call1.node()).is_dfg());
        assert!(matches!(
            hugr.children(call1.node())
                .map(|n| hugr.get_optype(n).clone())
                .collect::<Vec<_>>()[..],
            [OpType::Input(_), OpType::Output(_)]
        ));
        Ok(())
    }
}
336}