hugr_core/hugr/patch/
inline_call.rs

1//! Rewrite to inline a Call to a `FuncDefn` by copying the body of the function
2//! into a DFG which replaces the Call node.
3use derive_more::{Display, Error};
4
5use crate::core::HugrNode;
6use crate::ops::{DFG, DataflowParent, OpType};
7use crate::types::Substitution;
8use crate::{Direction, HugrView, Node};
9
10use super::{HugrMut, PatchHugrMut, PatchVerification};
11
12/// Rewrite to inline a [Call](OpType::Call) to a known [`FuncDefn`](OpType::FuncDefn)
13pub struct InlineCall<N = Node>(N);
14
15/// Error in performing [`InlineCall`] rewrite.
16#[derive(Clone, Debug, Display, Error, PartialEq)]
17#[non_exhaustive]
18pub enum InlineCallError<N = Node> {
19    /// The specified Node was not a [Call](OpType::Call)
20    #[display("Node to inline {_0} expected to be a Call but actually {_1}")]
21    NotCallNode(N, OpType),
22    /// The node was a Call, but the target was not a [`FuncDefn`](OpType::FuncDefn)
23    /// - presumably a [`FuncDecl`](OpType::FuncDecl), if the Hugr is valid.
24    #[display("Call targetted node {_0} which must be a FuncDefn but was {_1}")]
25    CallTargetNotFuncDefn(N, OpType),
26}
27
28impl<N> InlineCall<N> {
29    /// Create a new instance that will inline the specified node
30    /// (i.e. that should be a [Call](OpType::Call))
31    pub fn new(node: N) -> Self {
32        Self(node)
33    }
34}
35
36impl<N: HugrNode> PatchVerification for InlineCall<N> {
37    type Error = InlineCallError<N>;
38    type Node = N;
39    fn verify(&self, h: &impl HugrView<Node = N>) -> Result<(), Self::Error> {
40        let call_ty = h.get_optype(self.0);
41        if !call_ty.is_call() {
42            return Err(InlineCallError::NotCallNode(self.0, call_ty.clone()));
43        }
44        let func = h.static_source(self.0).unwrap();
45        let func_ty = h.get_optype(func);
46        if !func_ty.is_func_defn() {
47            return Err(InlineCallError::CallTargetNotFuncDefn(
48                func,
49                func_ty.clone(),
50            ));
51        }
52        Ok(())
53    }
54
55    fn invalidated_nodes(&self, _: &impl HugrView<Node = N>) -> impl Iterator<Item = N> {
56        Some(self.0).into_iter()
57    }
58}
59
60impl<N: HugrNode> PatchHugrMut for InlineCall<N> {
61    type Outcome = ();
62    fn apply_hugr_mut(self, h: &mut impl HugrMut<Node = N>) -> Result<(), Self::Error> {
63        self.verify(h)?; // Now we know we have a Call to a FuncDefn.
64        let orig_func = h.static_source(self.0).unwrap();
65
66        h.disconnect(self.0, h.get_optype(self.0).static_input_port().unwrap());
67
68        // The order input port gets renumbered because the static input
69        // (which comes between the value inports and the order inport) gets removed
70        let old_order_in = h.get_optype(self.0).other_input_port().unwrap();
71        let order_preds = h.linked_outputs(self.0, old_order_in).collect::<Vec<_>>();
72        h.disconnect(self.0, old_order_in); // PortGraph currently does this anyway
73
74        let new_op = OpType::from(DFG {
75            signature: h
76                .get_optype(orig_func)
77                .as_func_defn()
78                .unwrap()
79                .inner_signature()
80                .into_owned(),
81        });
82        let new_order_in = new_op.other_input_port().unwrap();
83
84        let ty_args = h
85            .replace_op(self.0, new_op)
86            .as_call()
87            .unwrap()
88            .type_args
89            .clone();
90
91        h.add_ports(self.0, Direction::Incoming, -1);
92
93        // Reconnect order predecessors
94        for (src, srcp) in order_preds {
95            h.connect(src, srcp, self.0, new_order_in);
96        }
97
98        h.copy_descendants(
99            orig_func,
100            self.0,
101            (!ty_args.is_empty()).then_some(Substitution::new(&ty_args)),
102        );
103        Ok(())
104    }
105
106    /// Failure only occurs if the node is not a Call, or the target not a `FuncDefn`.
107    /// (Any later failure means an invalid Hugr and `panic`.)
108    const UNCHANGED_ON_FAILURE: bool = true;
109}
110
111#[cfg(test)]
112mod test {
113    use std::iter::successors;
114
115    use itertools::Itertools;
116
117    use crate::builder::{
118        Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder,
119        ModuleBuilder,
120    };
121    use crate::extension::prelude::usize_t;
122    use crate::ops::handle::{FuncID, NodeHandle};
123    use crate::ops::{Input, OpType, Value};
124    use crate::std_extensions::arithmetic::int_types::INT_TYPES;
125    use crate::std_extensions::arithmetic::{int_ops::IntOpDef, int_types::ConstInt};
126
127    use crate::types::{PolyFuncType, Signature, Type, TypeBound};
128    use crate::{HugrView, Node};
129
130    use super::{HugrMut, InlineCall, InlineCallError};
131
132    fn calls(h: &impl HugrView<Node = Node>) -> Vec<Node> {
133        h.entry_descendants()
134            .filter(|n| h.get_optype(*n).is_call())
135            .collect()
136    }
137
138    fn extension_ops(h: &impl HugrView<Node = Node>) -> Vec<Node> {
139        h.entry_descendants()
140            .filter(|n| h.get_optype(*n).is_extension_op())
141            .collect()
142    }
143
144    #[test]
145    fn test_inline() -> Result<(), Box<dyn std::error::Error>> {
146        let mut mb = ModuleBuilder::new();
147        let cst3 = mb.add_constant(Value::from(ConstInt::new_u(4, 3)?));
148        let sig = Signature::new_endo(INT_TYPES[4].clone());
149        let func = {
150            let mut fb = mb.define_function("foo", sig.clone())?;
151            let c1 = fb.load_const(&cst3);
152            let mut inner = fb.dfg_builder(sig.clone(), fb.input_wires())?;
153            let [i] = inner.input_wires_arr();
154            let add = inner.add_dataflow_op(IntOpDef::iadd.with_log_width(4), [i, c1])?;
155            let inner_res = inner.finish_with_outputs(add.outputs())?;
156            fb.finish_with_outputs(inner_res.outputs())?
157        };
158        let mut main = mb.define_function("main", sig)?;
159        let call1 = main.call(func.handle(), &[], main.input_wires())?;
160        main.add_other_wire(main.input().node(), call1.node());
161        let call2 = main.call(func.handle(), &[], call1.outputs())?;
162        main.finish_with_outputs(call2.outputs())?;
163        let mut hugr = mb.finish_hugr()?;
164        let call1 = call1.node();
165        let call2 = call2.node();
166        assert_eq!(
167            hugr.output_neighbours(func.node()).collect_vec(),
168            [call1, call2]
169        );
170        assert_eq!(calls(&hugr), [call1, call2]);
171        assert_eq!(extension_ops(&hugr).len(), 1);
172
173        assert_eq!(
174            hugr.linked_outputs(
175                call1.node(),
176                hugr.get_optype(call1).other_input_port().unwrap()
177            )
178            .count(),
179            1
180        );
181        hugr.apply_patch(InlineCall(call1.node())).unwrap();
182        hugr.validate().unwrap();
183        assert_eq!(hugr.output_neighbours(func.node()).collect_vec(), [call2]);
184        assert_eq!(calls(&hugr), [call2]);
185        assert_eq!(extension_ops(&hugr).len(), 2);
186        assert_eq!(
187            hugr.linked_outputs(
188                call1.node(),
189                hugr.get_optype(call1).other_input_port().unwrap()
190            )
191            .count(),
192            1
193        );
194        hugr.apply_patch(InlineCall(call2.node())).unwrap();
195        hugr.validate().unwrap();
196        assert_eq!(hugr.output_neighbours(func.node()).next(), None);
197        assert_eq!(calls(&hugr), []);
198        assert_eq!(extension_ops(&hugr).len(), 3);
199
200        Ok(())
201    }
202
203    #[test]
204    fn test_recursion() -> Result<(), Box<dyn std::error::Error>> {
205        let mut mb = ModuleBuilder::new();
206        let sig = Signature::new_endo(INT_TYPES[5].clone());
207        let (func, rec_call) = {
208            let mut fb = mb.define_function("foo", sig.clone())?;
209            let cst1 = fb.add_load_value(ConstInt::new_u(5, 1)?);
210            let [i] = fb.input_wires_arr();
211            let add = fb.add_dataflow_op(IntOpDef::iadd.with_log_width(5), [i, cst1])?;
212            let call = fb.call(
213                &FuncID::<true>::from(fb.container_node()),
214                &[],
215                add.outputs(),
216            )?;
217            (fb.finish_with_outputs(call.outputs())?, call)
218        };
219        let mut main = mb.define_function("main", sig)?;
220        let call = main.call(func.handle(), &[], main.input_wires())?;
221        let main = main.finish_with_outputs(call.outputs())?;
222        let mut hugr = mb.finish_hugr()?;
223
224        let func = func.node();
225        let mut call = call.node();
226        for i in 2..10 {
227            hugr.apply_patch(InlineCall(call))?;
228            hugr.validate().unwrap();
229            assert_eq!(extension_ops(&hugr).len(), i);
230            let v = calls(&hugr);
231            assert!(v.iter().all(|n| hugr.static_source(*n) == Some(func)));
232
233            let [rec, nonrec] = v.try_into().expect("Should be two");
234            assert_eq!(rec, rec_call.node());
235            assert_eq!(hugr.output_neighbours(func).collect_vec(), [rec, nonrec]);
236            call = nonrec;
237
238            let mut ancestors = successors(hugr.get_parent(call), |n| hugr.get_parent(*n));
239            for _ in 1..i {
240                assert!(hugr.get_optype(ancestors.next().unwrap()).is_dfg());
241            }
242            assert_eq!(ancestors.next(), Some(main.node()));
243            assert_eq!(ancestors.next(), Some(hugr.entrypoint()));
244            assert_eq!(ancestors.next(), None);
245        }
246        Ok(())
247    }
248
249    #[test]
250    fn test_bad() {
251        let mut modb = ModuleBuilder::new();
252        let decl = modb
253            .declare(
254                "UndefinedFunc",
255                Signature::new_endo(INT_TYPES[3].clone()).into(),
256            )
257            .unwrap();
258        let mut main = modb
259            .define_function("main", Signature::new_endo(INT_TYPES[3].clone()))
260            .unwrap();
261        let call = main.call(&decl, &[], main.input_wires()).unwrap();
262        let main = main.finish_with_outputs(call.outputs()).unwrap();
263        let h = modb.finish_hugr().unwrap();
264        let mut h2 = h.clone();
265        assert_eq!(
266            h2.apply_patch(InlineCall(call.node())),
267            Err(InlineCallError::CallTargetNotFuncDefn(
268                decl.node(),
269                h.get_optype(decl.node()).clone()
270            ))
271        );
272        assert_eq!(h, h2);
273        let [inp, _out, _call] = h
274            .children(main.node())
275            .collect::<Vec<_>>()
276            .try_into()
277            .unwrap();
278        assert_eq!(
279            h2.apply_patch(InlineCall(inp)),
280            Err(InlineCallError::NotCallNode(
281                inp,
282                Input {
283                    types: INT_TYPES[3].clone().into()
284                }
285                .into()
286            ))
287        );
288    }
289
290    #[test]
291    fn test_polymorphic() -> Result<(), Box<dyn std::error::Error>> {
292        let tuple_ty = Type::new_tuple(vec![usize_t(); 2]);
293        let mut fb = FunctionBuilder::new("mkpair", Signature::new(usize_t(), tuple_ty.clone()))?;
294        let helper = {
295            let mut mb = fb.module_root_builder();
296            let fb2 = mb.define_function(
297                "id",
298                PolyFuncType::new(
299                    [TypeBound::Copyable.into()],
300                    Signature::new_endo(Type::new_var_use(0, TypeBound::Copyable)),
301                ),
302            )?;
303            let inps = fb2.input_wires();
304            fb2.finish_with_outputs(inps)?
305        };
306        let call1 = fb.call(helper.handle(), &[usize_t().into()], fb.input_wires())?;
307        let [call1_out] = call1.outputs_arr();
308        let tup = fb.make_tuple([call1_out, call1_out])?;
309        let call2 = fb.call(helper.handle(), &[tuple_ty.into()], [tup])?;
310        let mut hugr = fb.finish_hugr_with_outputs(call2.outputs()).unwrap();
311
312        assert_eq!(
313            hugr.output_neighbours(helper.node()).collect::<Vec<_>>(),
314            [call1.node(), call2.node()]
315        );
316        hugr.apply_patch(InlineCall::new(call1.node()))?;
317
318        assert_eq!(
319            hugr.output_neighbours(helper.node()).collect::<Vec<_>>(),
320            [call2.node()]
321        );
322        assert!(hugr.get_optype(call1.node()).is_dfg());
323        assert!(matches!(
324            hugr.children(call1.node())
325                .map(|n| hugr.get_optype(n).clone())
326                .collect::<Vec<_>>()[..],
327            [OpType::Input(_), OpType::Output(_)]
328        ));
329        Ok(())
330    }
331}