1use derive_more::{Display, Error};
4
5use crate::core::HugrNode;
6use crate::ops::{DFG, DataflowParent, OpType};
7use crate::types::Substitution;
8use crate::{Direction, HugrView, Node};
9
10use super::{HugrMut, PatchHugrMut, PatchVerification};
11
12pub struct InlineCall<N = Node>(N);
14
15#[derive(Clone, Debug, Display, Error, PartialEq)]
17#[non_exhaustive]
18pub enum InlineCallError<N = Node> {
19 #[display("Node to inline {_0} expected to be a Call but actually {_1}")]
21 NotCallNode(N, OpType),
22 #[display("Call targetted node {_0} which must be a FuncDefn but was {_1}")]
25 CallTargetNotFuncDefn(N, OpType),
26}
27
28impl<N> InlineCall<N> {
29 pub fn new(node: N) -> Self {
32 Self(node)
33 }
34}
35
36impl<N: HugrNode> PatchVerification for InlineCall<N> {
37 type Error = InlineCallError<N>;
38 type Node = N;
39 fn verify(&self, h: &impl HugrView<Node = N>) -> Result<(), Self::Error> {
40 let call_ty = h.get_optype(self.0);
41 if !call_ty.is_call() {
42 return Err(InlineCallError::NotCallNode(self.0, call_ty.clone()));
43 }
44 let func = h.static_source(self.0).unwrap();
45 let func_ty = h.get_optype(func);
46 if !func_ty.is_func_defn() {
47 return Err(InlineCallError::CallTargetNotFuncDefn(
48 func,
49 func_ty.clone(),
50 ));
51 }
52 Ok(())
53 }
54
55 fn invalidated_nodes(&self, _: &impl HugrView<Node = N>) -> impl Iterator<Item = N> {
56 Some(self.0).into_iter()
57 }
58}
59
60impl<N: HugrNode> PatchHugrMut for InlineCall<N> {
61 type Outcome = ();
62 fn apply_hugr_mut(self, h: &mut impl HugrMut<Node = N>) -> Result<(), Self::Error> {
63 self.verify(h)?; let orig_func = h.static_source(self.0).unwrap();
65
66 h.disconnect(self.0, h.get_optype(self.0).static_input_port().unwrap());
67
68 let old_order_in = h.get_optype(self.0).other_input_port().unwrap();
71 let order_preds = h.linked_outputs(self.0, old_order_in).collect::<Vec<_>>();
72 h.disconnect(self.0, old_order_in); let new_op = OpType::from(DFG {
75 signature: h
76 .get_optype(orig_func)
77 .as_func_defn()
78 .unwrap()
79 .inner_signature()
80 .into_owned(),
81 });
82 let new_order_in = new_op.other_input_port().unwrap();
83
84 let ty_args = h
85 .replace_op(self.0, new_op)
86 .as_call()
87 .unwrap()
88 .type_args
89 .clone();
90
91 h.add_ports(self.0, Direction::Incoming, -1);
92
93 for (src, srcp) in order_preds {
95 h.connect(src, srcp, self.0, new_order_in);
96 }
97
98 h.copy_descendants(
99 orig_func,
100 self.0,
101 (!ty_args.is_empty()).then_some(Substitution::new(&ty_args)),
102 );
103 Ok(())
104 }
105
106 const UNCHANGED_ON_FAILURE: bool = true;
109}
110
111#[cfg(test)]
112mod test {
113 use std::iter::successors;
114
115 use itertools::Itertools;
116
117 use crate::builder::{
118 Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder,
119 ModuleBuilder,
120 };
121 use crate::extension::prelude::usize_t;
122 use crate::ops::handle::{FuncID, NodeHandle};
123 use crate::ops::{Input, OpType, Value};
124 use crate::std_extensions::arithmetic::int_types::INT_TYPES;
125 use crate::std_extensions::arithmetic::{int_ops::IntOpDef, int_types::ConstInt};
126
127 use crate::types::{PolyFuncType, Signature, Type, TypeBound};
128 use crate::{HugrView, Node};
129
130 use super::{HugrMut, InlineCall, InlineCallError};
131
132 fn calls(h: &impl HugrView<Node = Node>) -> Vec<Node> {
133 h.entry_descendants()
134 .filter(|n| h.get_optype(*n).is_call())
135 .collect()
136 }
137
138 fn extension_ops(h: &impl HugrView<Node = Node>) -> Vec<Node> {
139 h.entry_descendants()
140 .filter(|n| h.get_optype(*n).is_extension_op())
141 .collect()
142 }
143
144 #[test]
145 fn test_inline() -> Result<(), Box<dyn std::error::Error>> {
146 let mut mb = ModuleBuilder::new();
147 let cst3 = mb.add_constant(Value::from(ConstInt::new_u(4, 3)?));
148 let sig = Signature::new_endo(INT_TYPES[4].clone());
149 let func = {
150 let mut fb = mb.define_function("foo", sig.clone())?;
151 let c1 = fb.load_const(&cst3);
152 let mut inner = fb.dfg_builder(sig.clone(), fb.input_wires())?;
153 let [i] = inner.input_wires_arr();
154 let add = inner.add_dataflow_op(IntOpDef::iadd.with_log_width(4), [i, c1])?;
155 let inner_res = inner.finish_with_outputs(add.outputs())?;
156 fb.finish_with_outputs(inner_res.outputs())?
157 };
158 let mut main = mb.define_function("main", sig)?;
159 let call1 = main.call(func.handle(), &[], main.input_wires())?;
160 main.add_other_wire(main.input().node(), call1.node());
161 let call2 = main.call(func.handle(), &[], call1.outputs())?;
162 main.finish_with_outputs(call2.outputs())?;
163 let mut hugr = mb.finish_hugr()?;
164 let call1 = call1.node();
165 let call2 = call2.node();
166 assert_eq!(
167 hugr.output_neighbours(func.node()).collect_vec(),
168 [call1, call2]
169 );
170 assert_eq!(calls(&hugr), [call1, call2]);
171 assert_eq!(extension_ops(&hugr).len(), 1);
172
173 assert_eq!(
174 hugr.linked_outputs(
175 call1.node(),
176 hugr.get_optype(call1).other_input_port().unwrap()
177 )
178 .count(),
179 1
180 );
181 hugr.apply_patch(InlineCall(call1.node())).unwrap();
182 hugr.validate().unwrap();
183 assert_eq!(hugr.output_neighbours(func.node()).collect_vec(), [call2]);
184 assert_eq!(calls(&hugr), [call2]);
185 assert_eq!(extension_ops(&hugr).len(), 2);
186 assert_eq!(
187 hugr.linked_outputs(
188 call1.node(),
189 hugr.get_optype(call1).other_input_port().unwrap()
190 )
191 .count(),
192 1
193 );
194 hugr.apply_patch(InlineCall(call2.node())).unwrap();
195 hugr.validate().unwrap();
196 assert_eq!(hugr.output_neighbours(func.node()).next(), None);
197 assert_eq!(calls(&hugr), []);
198 assert_eq!(extension_ops(&hugr).len(), 3);
199
200 Ok(())
201 }
202
203 #[test]
204 fn test_recursion() -> Result<(), Box<dyn std::error::Error>> {
205 let mut mb = ModuleBuilder::new();
206 let sig = Signature::new_endo(INT_TYPES[5].clone());
207 let (func, rec_call) = {
208 let mut fb = mb.define_function("foo", sig.clone())?;
209 let cst1 = fb.add_load_value(ConstInt::new_u(5, 1)?);
210 let [i] = fb.input_wires_arr();
211 let add = fb.add_dataflow_op(IntOpDef::iadd.with_log_width(5), [i, cst1])?;
212 let call = fb.call(
213 &FuncID::<true>::from(fb.container_node()),
214 &[],
215 add.outputs(),
216 )?;
217 (fb.finish_with_outputs(call.outputs())?, call)
218 };
219 let mut main = mb.define_function("main", sig)?;
220 let call = main.call(func.handle(), &[], main.input_wires())?;
221 let main = main.finish_with_outputs(call.outputs())?;
222 let mut hugr = mb.finish_hugr()?;
223
224 let func = func.node();
225 let mut call = call.node();
226 for i in 2..10 {
227 hugr.apply_patch(InlineCall(call))?;
228 hugr.validate().unwrap();
229 assert_eq!(extension_ops(&hugr).len(), i);
230 let v = calls(&hugr);
231 assert!(v.iter().all(|n| hugr.static_source(*n) == Some(func)));
232
233 let [rec, nonrec] = v.try_into().expect("Should be two");
234 assert_eq!(rec, rec_call.node());
235 assert_eq!(hugr.output_neighbours(func).collect_vec(), [rec, nonrec]);
236 call = nonrec;
237
238 let mut ancestors = successors(hugr.get_parent(call), |n| hugr.get_parent(*n));
239 for _ in 1..i {
240 assert!(hugr.get_optype(ancestors.next().unwrap()).is_dfg());
241 }
242 assert_eq!(ancestors.next(), Some(main.node()));
243 assert_eq!(ancestors.next(), Some(hugr.entrypoint()));
244 assert_eq!(ancestors.next(), None);
245 }
246 Ok(())
247 }
248
249 #[test]
250 fn test_bad() {
251 let mut modb = ModuleBuilder::new();
252 let decl = modb
253 .declare(
254 "UndefinedFunc",
255 Signature::new_endo(INT_TYPES[3].clone()).into(),
256 )
257 .unwrap();
258 let mut main = modb
259 .define_function("main", Signature::new_endo(INT_TYPES[3].clone()))
260 .unwrap();
261 let call = main.call(&decl, &[], main.input_wires()).unwrap();
262 let main = main.finish_with_outputs(call.outputs()).unwrap();
263 let h = modb.finish_hugr().unwrap();
264 let mut h2 = h.clone();
265 assert_eq!(
266 h2.apply_patch(InlineCall(call.node())),
267 Err(InlineCallError::CallTargetNotFuncDefn(
268 decl.node(),
269 h.get_optype(decl.node()).clone()
270 ))
271 );
272 assert_eq!(h, h2);
273 let [inp, _out, _call] = h
274 .children(main.node())
275 .collect::<Vec<_>>()
276 .try_into()
277 .unwrap();
278 assert_eq!(
279 h2.apply_patch(InlineCall(inp)),
280 Err(InlineCallError::NotCallNode(
281 inp,
282 Input {
283 types: INT_TYPES[3].clone().into()
284 }
285 .into()
286 ))
287 );
288 }
289
290 #[test]
291 fn test_polymorphic() -> Result<(), Box<dyn std::error::Error>> {
292 let tuple_ty = Type::new_tuple(vec![usize_t(); 2]);
293 let mut fb = FunctionBuilder::new("mkpair", Signature::new(usize_t(), tuple_ty.clone()))?;
294 let helper = {
295 let mut mb = fb.module_root_builder();
296 let fb2 = mb.define_function(
297 "id",
298 PolyFuncType::new(
299 [TypeBound::Copyable.into()],
300 Signature::new_endo(Type::new_var_use(0, TypeBound::Copyable)),
301 ),
302 )?;
303 let inps = fb2.input_wires();
304 fb2.finish_with_outputs(inps)?
305 };
306 let call1 = fb.call(helper.handle(), &[usize_t().into()], fb.input_wires())?;
307 let [call1_out] = call1.outputs_arr();
308 let tup = fb.make_tuple([call1_out, call1_out])?;
309 let call2 = fb.call(helper.handle(), &[tuple_ty.into()], [tup])?;
310 let mut hugr = fb.finish_hugr_with_outputs(call2.outputs()).unwrap();
311
312 assert_eq!(
313 hugr.output_neighbours(helper.node()).collect::<Vec<_>>(),
314 [call1.node(), call2.node()]
315 );
316 hugr.apply_patch(InlineCall::new(call1.node()))?;
317
318 assert_eq!(
319 hugr.output_neighbours(helper.node()).collect::<Vec<_>>(),
320 [call2.node()]
321 );
322 assert!(hugr.get_optype(call1.node()).is_dfg());
323 assert!(matches!(
324 hugr.children(call1.node())
325 .map(|n| hugr.get_optype(n).clone())
326 .collect::<Vec<_>>()[..],
327 [OpType::Input(_), OpType::Output(_)]
328 ));
329 Ok(())
330 }
331}