1use derive_more::{Display, Error};
4
5use crate::ops::{DataflowParent, OpType, DFG};
6use crate::types::Substitution;
7use crate::{Direction, HugrView, Node};
8
9use super::{HugrMut, Rewrite};
10
11pub struct InlineCall(Node);
13
14#[derive(Clone, Debug, Display, Error, PartialEq)]
16#[non_exhaustive]
17pub enum InlineCallError {
18 #[display("Node to inline {_0} expected to be a Call but actually {_1}")]
20 NotCallNode(Node, OpType),
21 #[display("Call targetted node {_0} which must be a FuncDefn but was {_1}")]
24 CallTargetNotFuncDefn(Node, OpType),
25}
26
27impl InlineCall {
28 pub fn new(node: Node) -> Self {
31 Self(node)
32 }
33}
34
35impl Rewrite for InlineCall {
36 type ApplyResult = ();
37 type Error = InlineCallError;
38 fn verify(&self, h: &impl HugrView<Node = Node>) -> Result<(), Self::Error> {
39 let call_ty = h.get_optype(self.0);
40 if !call_ty.is_call() {
41 return Err(InlineCallError::NotCallNode(self.0, call_ty.clone()));
42 }
43 let func = h.static_source(self.0).unwrap();
44 let func_ty = h.get_optype(func);
45 if !func_ty.is_func_defn() {
46 return Err(InlineCallError::CallTargetNotFuncDefn(
47 func,
48 func_ty.clone(),
49 ));
50 }
51 Ok(())
52 }
53
54 fn apply(self, h: &mut impl HugrMut) -> Result<(), Self::Error> {
55 self.verify(h)?; let orig_func = h.static_source(self.0).unwrap();
57
58 h.disconnect(self.0, h.get_optype(self.0).static_input_port().unwrap());
59
60 let old_order_in = h.get_optype(self.0).other_input_port().unwrap();
63 let order_preds = h.linked_outputs(self.0, old_order_in).collect::<Vec<_>>();
64 h.disconnect(self.0, old_order_in); let new_op = OpType::from(DFG {
67 signature: h
68 .get_optype(orig_func)
69 .as_func_defn()
70 .unwrap()
71 .inner_signature()
72 .into_owned(),
73 });
74 let new_order_in = new_op.other_input_port().unwrap();
75
76 let ty_args = h
77 .replace_op(self.0, new_op)
78 .unwrap()
79 .as_call()
80 .unwrap()
81 .type_args
82 .clone();
83
84 h.add_ports(self.0, Direction::Incoming, -1);
85
86 for (src, srcp) in order_preds {
88 h.connect(src, srcp, self.0, new_order_in);
89 }
90
91 h.copy_descendants(
92 orig_func,
93 self.0,
94 (!ty_args.is_empty()).then_some(Substitution::new(&ty_args)),
95 );
96 Ok(())
97 }
98
99 const UNCHANGED_ON_FAILURE: bool = true;
102
103 fn invalidation_set(&self) -> impl Iterator<Item = Node> {
104 Some(self.0).into_iter()
105 }
106}
107
#[cfg(test)]
mod test {
    use std::iter::successors;

    use itertools::Itertools;

    use crate::builder::{
        Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder,
        ModuleBuilder,
    };
    use crate::extension::prelude::usize_t;
    use crate::hugr::views::RootChecked;
    use crate::ops::handle::{FuncID, ModuleRootID, NodeHandle};
    use crate::ops::{Input, OpType, Value};
    use crate::std_extensions::arithmetic::{
        int_ops::{self, IntOpDef},
        int_types::{self, ConstInt, INT_TYPES},
    };

    use crate::types::{PolyFuncType, Signature, Type, TypeBound};
    use crate::{HugrView, Node};

    use super::{HugrMut, InlineCall, InlineCallError};

    /// All Call nodes of `h`, in the order `h.nodes()` yields them.
    fn calls(h: &impl HugrView<Node = Node>) -> Vec<Node> {
        h.nodes().filter(|n| h.get_optype(*n).is_call()).collect()
    }

    /// All extension-op nodes of `h`, in the order `h.nodes()` yields them.
    fn extension_ops(h: &impl HugrView<Node = Node>) -> Vec<Node> {
        h.nodes()
            .filter(|n| h.get_optype(*n).is_extension_op())
            .collect()
    }

    /// Inlines, one at a time, two calls to a monomorphic function whose body
    /// (an `iadd` inside a nested DFG) should be copied once per inlining.
    /// Also checks that an order edge into the first call survives inlining.
    #[test]
    fn test_inline() -> Result<(), Box<dyn std::error::Error>> {
        let mut mb = ModuleBuilder::new();
        let cst3 = mb.add_constant(Value::from(ConstInt::new_u(4, 3)?));
        let sig = Signature::new_endo(INT_TYPES[4].clone())
            .with_extension_delta(int_ops::EXTENSION_ID)
            .with_extension_delta(int_types::EXTENSION_ID);
        let func = {
            let mut fb = mb.define_function("foo", sig.clone())?;
            let c1 = fb.load_const(&cst3);
            let mut inner = fb.dfg_builder(sig.clone(), fb.input_wires())?;
            let [i] = inner.input_wires_arr();
            let add = inner.add_dataflow_op(IntOpDef::iadd.with_log_width(4), [i, c1])?;
            let inner_res = inner.finish_with_outputs(add.outputs())?;
            fb.finish_with_outputs(inner_res.outputs())?
        };
        let mut main = mb.define_function("main", sig)?;
        let call1 = main.call(func.handle(), &[], main.input_wires())?;
        // Order edge from main's Input into call1; checked before and after inlining.
        main.add_other_wire(main.input().node(), call1.node());
        let call2 = main.call(func.handle(), &[], call1.outputs())?;
        main.finish_with_outputs(call2.outputs())?;
        let mut hugr = mb.finish_hugr()?;
        let call1 = call1.node();
        let call2 = call2.node();
        assert_eq!(
            hugr.output_neighbours(func.node()).collect_vec(),
            [call1, call2]
        );
        assert_eq!(calls(&hugr), [call1, call2]);
        // Only the single `iadd` in foo's body so far.
        assert_eq!(extension_ops(&hugr).len(), 1);

        assert_eq!(
            hugr.linked_outputs(
                call1.node(),
                hugr.get_optype(call1).other_input_port().unwrap()
            )
            .count(),
            1
        );
        // First inlining goes through a RootChecked view, the second (below)
        // directly on the Hugr.
        RootChecked::<_, ModuleRootID>::try_new(&mut hugr)
            .unwrap()
            .apply_rewrite(InlineCall(call1.node()))
            .unwrap();
        hugr.validate().unwrap();
        assert_eq!(hugr.output_neighbours(func.node()).collect_vec(), [call2]);
        assert_eq!(calls(&hugr), [call2]);
        // One more `iadd`, copied from foo's body.
        assert_eq!(extension_ops(&hugr).len(), 2);
        // The order edge into call1's node is still present at its (renumbered)
        // order input port.
        assert_eq!(
            hugr.linked_outputs(
                call1.node(),
                hugr.get_optype(call1).other_input_port().unwrap()
            )
            .count(),
            1
        );
        hugr.apply_rewrite(InlineCall(call2.node())).unwrap();
        hugr.validate().unwrap();
        assert_eq!(hugr.output_neighbours(func.node()).next(), None);
        assert_eq!(calls(&hugr), []);
        assert_eq!(extension_ops(&hugr).len(), 3);

        Ok(())
    }

    /// Repeatedly inlines a call to a self-recursive function. Each inlining
    /// copies the recursive call, so there are always two calls: the original
    /// one inside `foo`, and a new one nested one DFG deeper inside `main`.
    #[test]
    fn test_recursion() -> Result<(), Box<dyn std::error::Error>> {
        let mut mb = ModuleBuilder::new();
        let sig = Signature::new_endo(INT_TYPES[5].clone())
            .with_extension_delta(int_ops::EXTENSION_ID)
            .with_extension_delta(int_types::EXTENSION_ID);
        let (func, rec_call) = {
            let mut fb = mb.define_function("foo", sig.clone())?;
            let cst1 = fb.add_load_value(ConstInt::new_u(5, 1)?);
            let [i] = fb.input_wires_arr();
            let add = fb.add_dataflow_op(IntOpDef::iadd.with_log_width(5), [i, cst1])?;
            // foo calls itself.
            let call = fb.call(
                &FuncID::<true>::from(fb.container_node()),
                &[],
                add.outputs(),
            )?;
            (fb.finish_with_outputs(call.outputs())?, call)
        };
        let mut main = mb.define_function("main", sig)?;
        let call = main.call(func.handle(), &[], main.input_wires())?;
        let main = main.finish_with_outputs(call.outputs())?;
        let mut hugr = mb.finish_hugr()?;

        let func = func.node();
        let mut call = call.node();
        for i in 2..10 {
            hugr.apply_rewrite(InlineCall(call))?;
            hugr.validate().unwrap();
            // One `iadd` per copy of foo's body, plus the original.
            assert_eq!(extension_ops(&hugr).len(), i);
            let v = calls(&hugr);
            assert!(v.iter().all(|n| hugr.static_source(*n) == Some(func)));

            let [rec, nonrec] = v.try_into().expect("Should be two");
            assert_eq!(rec, rec_call.node());
            assert_eq!(hugr.output_neighbours(func).collect_vec(), [rec, nonrec]);
            // Next iteration inlines the freshly-copied call.
            call = nonrec;

            // That call is nested inside i-1 DFGs, which sit inside main,
            // which is a child of the module root.
            let mut ancestors = successors(hugr.get_parent(call), |n| hugr.get_parent(*n));
            for _ in 1..i {
                assert!(hugr.get_optype(ancestors.next().unwrap()).is_dfg());
            }
            assert_eq!(ancestors.next(), Some(main.node()));
            assert_eq!(ancestors.next(), Some(hugr.root()));
            assert_eq!(ancestors.next(), None);
        }
        Ok(())
    }

    /// Both error cases: a Call to a declared (not defined) function, and a
    /// node that is not a Call at all. The Hugr must be left unchanged.
    #[test]
    fn test_bad() {
        let mut modb = ModuleBuilder::new();
        let decl = modb
            .declare(
                "UndefinedFunc",
                Signature::new_endo(INT_TYPES[3].clone()).into(),
            )
            .unwrap();
        let mut main = modb
            .define_function("main", Signature::new_endo(INT_TYPES[3].clone()))
            .unwrap();
        let call = main.call(&decl, &[], main.input_wires()).unwrap();
        let main = main.finish_with_outputs(call.outputs()).unwrap();
        let h = modb.finish_hugr().unwrap();
        let mut h2 = h.clone();
        assert_eq!(
            h2.apply_rewrite(InlineCall(call.node())),
            Err(InlineCallError::CallTargetNotFuncDefn(
                decl.node(),
                h.get_optype(decl.node()).clone()
            ))
        );
        // The failed rewrite left the Hugr untouched.
        assert_eq!(h, h2);
        let [inp, _out, _call] = h
            .children(main.node())
            .collect::<Vec<_>>()
            .try_into()
            .unwrap();
        assert_eq!(
            h2.apply_rewrite(InlineCall(inp)),
            Err(InlineCallError::NotCallNode(
                inp,
                Input {
                    types: INT_TYPES[3].clone().into()
                }
                .into()
            ))
        )
    }

    /// Inlines one of two calls to a polymorphic identity function, each of
    /// which instantiates the type variable differently.
    #[test]
    fn test_polymorphic() -> Result<(), Box<dyn std::error::Error>> {
        let tuple_ty = Type::new_tuple(vec![usize_t(); 2]);
        let mut fb = FunctionBuilder::new(
            "mkpair",
            Signature::new(usize_t(), tuple_ty.clone()).with_prelude(),
        )?;
        let inner = fb.define_function(
            "id",
            PolyFuncType::new(
                [TypeBound::Copyable.into()],
                Signature::new_endo(Type::new_var_use(0, TypeBound::Copyable)),
            ),
        )?;
        let inps = inner.input_wires();
        let inner = inner.finish_with_outputs(inps)?;
        let call1 = fb.call(inner.handle(), &[usize_t().into()], fb.input_wires())?;
        let [call1_out] = call1.outputs_arr();
        let tup = fb.make_tuple([call1_out, call1_out])?;
        let call2 = fb.call(inner.handle(), &[tuple_ty.into()], [tup])?;
        let mut hugr = fb.finish_hugr_with_outputs(call2.outputs()).unwrap();

        assert_eq!(
            hugr.output_neighbours(inner.node()).collect::<Vec<_>>(),
            [call1.node(), call2.node()]
        );
        hugr.apply_rewrite(InlineCall::new(call1.node()))?;
        // NOTE(review): unlike the other tests, the Hugr is not validated after
        // inlining here; consider adding `hugr.validate()?`.

        assert_eq!(
            hugr.output_neighbours(inner.node()).collect::<Vec<_>>(),
            [call2.node()]
        );
        assert!(hugr.get_optype(call1.node()).is_dfg());
        // `id`'s body is just Input wired to Output.
        assert!(matches!(
            hugr.children(call1.node())
                .map(|n| hugr.get_optype(n).clone())
                .collect::<Vec<_>>()[..],
            [OpType::Input(_), OpType::Output(_)]
        ));
        Ok(())
    }
}