1use std::collections::{HashMap, HashSet};
2
3use lutra_bin::ir;
4
5use super::fold::{self, IrFold};
6use crate::utils::IdGenerator;
7
8pub fn inline(program: ir::Program) -> ir::Program {
9 let (mut program, id_counts) = IdCounter::run(program);
10
11 let mut inliner = FuncInliner {
13 bindings: Default::default(),
14
15 currently_inlining: Default::default(),
16
17 generator_var_binding: IdGenerator::new_at(id_counts.max_var_id as usize),
18 };
19 program.main = inliner.fold_expr(program.main).unwrap();
20
21 tracing::debug!("ir (funcs inlined):\n{}", ir::print(&program));
22
23 let mut counter = BindingUsageCounter {
25 usage: Default::default(),
26 simple: Default::default(),
27 };
28 program.main = counter.fold_expr(program.main).unwrap();
29 tracing::debug!("binding_usage = {:?}", counter.usage);
30 tracing::debug!("simple_bindings = {:?}", counter.simple);
31
32 let mut inliner = BindingInliner::new(counter.usage, counter.simple);
34 program.main = inliner.fold_expr(program.main).unwrap();
35
36 program
37}
38
39struct FuncInliner {
40 bindings: HashMap<u32, ir::Function>,
42
43 currently_inlining: HashSet<u32>,
44
45 generator_var_binding: IdGenerator,
46}
47
48impl fold::IrFold for FuncInliner {
49 fn fold_binding(&mut self, binding: ir::Binding, ty: ir::Ty) -> Result<ir::Expr, ()> {
50 if binding.expr.ty.kind.is_function() {
52 match binding.expr.kind {
53 ir::ExprKind::Function(func) => {
55 let func = self.fold_func(*func, binding.expr.ty)?;
57 let func = func.kind.into_function().unwrap();
58
59 self.bindings.insert(binding.id, *func);
61
62 return self.fold_expr(binding.main);
64 }
65
66 ir::ExprKind::Pointer(_) => todo!(),
67
68 _ => panic!(),
69 }
70 }
71
72 fold::fold_binding(self, binding, ty)
73 }
74
75 fn fold_call(&mut self, call: ir::Call, ty: ir::Ty) -> Result<ir::Expr, ()> {
76 let args = fold::fold_exprs(self, call.args)?;
77
78 let function = match call.function.kind {
79 ir::ExprKind::Pointer(ir::Pointer::Binding(ref binding_id)) => {
81 if self.currently_inlining.contains(binding_id) {
82 panic!("recursive function cannot be inlined");
83 }
84 if let Some(func) = self.bindings.get(binding_id) {
85 let expr = self.substitute_function(func.clone(), args);
86
87 self.currently_inlining.insert(*binding_id);
88 let expr = self.fold_expr(expr);
89 self.currently_inlining.remove(binding_id);
90 return expr;
91 } else {
92 call.function
94 }
95 }
96
97 ir::ExprKind::Function(func) => {
99 let expr = self.substitute_function(*func, args);
102 return self.fold_expr(expr);
103 }
104
105 ir::ExprKind::Pointer(ir::Pointer::Parameter(_)) => call.function,
106
107 ir::ExprKind::Pointer(ir::Pointer::External(_)) => call.function,
108
109 _ => unreachable!(),
110 };
111
112 let kind = ir::ExprKind::Call(Box::new(ir::Call { function, args }));
113 Ok(ir::Expr { kind, ty })
114 }
115
116 fn fold_ptr(&mut self, ptr: ir::Pointer, ty: ir::Ty) -> Result<ir::Expr, ()> {
117 if let ir::Pointer::Binding(binding_id) = &ptr {
118 if let Some(func) = self.bindings.get(binding_id) {
121 return Ok(ir::Expr {
122 kind: ir::ExprKind::Function(Box::new(func.clone())),
123 ty,
124 });
125 }
126 }
127 fold::fold_ptr(ptr, ty)
128 }
129
130 fn fold_switch(&mut self, branches: Vec<ir::SwitchBranch>, ty: ir::Ty) -> Result<ir::Expr, ()> {
132 fn as_bool_literal(expr: &ir::Expr) -> Option<bool> {
138 expr.kind.as_literal().and_then(|l| l.as_bool().cloned())
139 }
140 if matches!(ty.kind, ir::TyKind::Primitive(ir::TyPrimitive::bool))
141 && branches.len() == 2
142 && let Some(value_then) = as_bool_literal(&branches[0].value)
143 && let Some(value_else) = as_bool_literal(&branches[1].value)
144 {
145 let cond = branches.into_iter().next().unwrap().condition;
146
147 match (value_then, value_else) {
148 (true, true) | (false, false) => {
149 return Ok(ir::Expr::new_lit_bool(value_then));
151 }
152 (true, false) => {
153 return self.fold_expr(cond);
155 }
156 (false, true) => {
157 let cond = self.fold_expr(cond)?;
159
160 let ty_bool = ir::Ty::new(ir::TyPrimitive::bool);
161 let std_not = ir::Expr::new(
162 ir::ExternalPtr {
163 id: "std::not".into(),
164 },
165 ir::Ty::new(ir::TyFunction {
166 params: vec![ty_bool.clone()],
167 body: ty_bool.clone(),
168 }),
169 );
170
171 return Ok(ir::Expr::new(
172 ir::Call {
173 function: std_not,
174 args: vec![cond],
175 },
176 ty_bool,
177 ));
178 }
179 }
180 }
181
182 fold::fold_switch(self, branches, ty)
183 }
184}
185
186impl FuncInliner {
187 fn substitute_function(&mut self, func: ir::Function, args: Vec<ir::Expr>) -> ir::Expr {
188 let mut arg_var_ids = Vec::with_capacity(args.len());
190 let mut arg_pointers = Vec::with_capacity(args.len());
191 for arg in &args {
192 let id = self.generator_var_binding.next() as u32;
193 arg_var_ids.push(id);
194 arg_pointers.push(ir::Expr {
195 kind: ir::ExprKind::Pointer(ir::Pointer::Binding(id)),
196 ty: arg.ty.clone(),
197 });
198 }
199
200 tracing::debug!("inlining call to function {} with {arg_var_ids:?}", func.id);
202 let mut expr = Substituter::run(func.body, func.id, arg_pointers);
203
204 for (id, arg) in std::iter::zip(arg_var_ids, args) {
206 expr = ir::Expr {
207 ty: expr.ty.clone(),
208 kind: ir::ExprKind::Binding(Box::new(ir::Binding {
209 id,
210 expr: arg,
211 main: expr,
212 })),
213 }
214 }
215 expr
216 }
217}
218
/// Collects, per binding id, how often the binding is referenced and whether
/// its value is a "simple" expression.
struct BindingUsageCounter {
    /// Ids of bindings whose value passed `is_simple_expr`.
    simple: HashSet<u32>,
    /// Number of `Pointer::Binding` references seen for each binding id.
    usage: HashMap<u32, usize>,
}
223
224impl BindingUsageCounter {
225 fn is_simple_expr(expr: &ir::Expr) -> bool {
226 match &expr.kind {
227 ir::ExprKind::Literal(ir::Literal::text(_)) => false,
228 ir::ExprKind::Literal(_) => true,
229 ir::ExprKind::Pointer(_) => true,
230 ir::ExprKind::TupleLookup(lookup) => Self::is_simple_expr(&lookup.base),
231 ir::ExprKind::Tuple(fields) => fields.iter().all(|f| Self::is_simple_expr(&f.expr)),
232 _ => false,
233 }
234 }
235}
236
237impl fold::IrFold for BindingUsageCounter {
238 fn fold_binding(&mut self, binding: ir::Binding, ty: ir::Ty) -> Result<ir::Expr, ()> {
239 self.usage.insert(binding.id, 0);
240
241 if Self::is_simple_expr(&binding.expr) {
243 self.simple.insert(binding.id);
244 }
245
246 fold::fold_binding(self, binding, ty)
247 }
248
249 fn fold_ptr(&mut self, ptr: ir::Pointer, ty: ir::Ty) -> Result<ir::Expr, ()> {
250 if let ir::Pointer::Binding(binding_id) = &ptr {
251 *self.usage.entry(*binding_id).or_default() += 1;
253 }
254 fold::fold_ptr(ptr, ty)
255 }
256}
257
258struct BindingInliner {
259 bindings: HashMap<u32, ir::Expr>,
260
261 to_inline: HashSet<u32>,
262}
263
264impl BindingInliner {
265 fn new(bindings_usage: HashMap<u32, usize>, simple: HashSet<u32>) -> Self {
266 let to_inline: HashSet<u32> = bindings_usage
268 .into_iter()
269 .filter(|(id, usage_count)| *usage_count <= 1 || simple.contains(id))
270 .map(|(id, _)| id)
271 .collect();
272
273 tracing::debug!("inlining vars: {:?}", to_inline);
274
275 BindingInliner {
276 bindings: Default::default(),
277 to_inline,
278 }
279 }
280}
281
282impl IrFold for BindingInliner {
283 fn fold_binding(&mut self, binding: ir::Binding, ty: ir::Ty) -> Result<ir::Expr, ()> {
284 if self.to_inline.contains(&binding.id) {
285 let expr = self.fold_expr(binding.expr)?;
287 self.bindings.insert(binding.id, expr);
288
289 return self.fold_expr(binding.main);
291 }
292 fold::fold_binding(self, binding, ty)
293 }
294 fn fold_ptr(&mut self, ptr: ir::Pointer, ty: ir::Ty) -> Result<ir::Expr, ()> {
295 if let ir::Pointer::Binding(binding_id) = &ptr {
296 if let Some(value) = self.bindings.get(binding_id) {
299 return Ok(value.clone());
300 }
301 }
302 fold::fold_ptr(ptr, ty)
303 }
304
305 fn fold_enum_eq(&mut self, enum_eq: ir::EnumEq, ty: ir::Ty) -> Result<ir::Expr, ()> {
307 let enum_eq = ir::EnumEq {
309 tag: enum_eq.tag,
310 subject: self.fold_expr(enum_eq.subject)?,
311 };
312
313 if let ir::ExprKind::Call(call) = &enum_eq.subject.kind
323 && let ir::ExprKind::Pointer(ir::Pointer::External(func)) = &call.function.kind
324 && func.id == "std::cmp"
325 {
326 let (cmp_func, swap) = match enum_eq.tag {
327 0 => ("std::lt", false),
328 1 => ("std::eq", false),
329 2 => ("std::lt", true),
330 _ => unreachable!(),
331 };
332
333 let mut func_ty = call.function.ty.clone();
334 func_ty.kind.as_function_mut().unwrap().body = ir::Ty::new(ir::TyPrimitive::bool);
335
336 let function = ir::Expr::new(
337 ir::ExternalPtr {
338 id: cmp_func.to_string(),
339 },
340 func_ty,
341 );
342
343 let mut args = call.args.clone();
344 if swap {
345 args.reverse();
346 }
347
348 return Ok(ir::Expr::new(ir::Call { function, args }, ty));
349 }
350
351 Ok(ir::Expr {
352 kind: ir::ExprKind::EnumEq(Box::new(enum_eq)),
353 ty,
354 })
355 }
356
357 fn fold_call(&mut self, call: ir::Call, ty: ir::Ty) -> Result<ir::Expr, ()> {
359 let expr = fold::fold_call(self, call, ty)?;
360
361 fn as_external(expr: &ir::Expr) -> Option<&str> {
362 expr.kind
363 .as_pointer()
364 .and_then(|p| p.as_external())
365 .map(|e| e.id.as_str())
366 }
367 fn as_external_mut(expr: &mut ir::Expr) -> Option<&mut String> {
368 expr.kind
369 .as_pointer_mut()
370 .and_then(|p| p.as_external_mut())
371 .map(|e| &mut e.id)
372 }
373
374 if let ir::ExprKind::Call(outer) = &expr.kind
384 && let Some(outer_id) = as_external(&outer.function)
385 && let ir::ExprKind::Call(inner) = &outer.args[0].kind
386 && let Some(inner_id) = as_external(&inner.function)
387 {
388 match (outer_id, inner_id) {
389 ("std::not", "std::not") => {
390 return Ok(inner.args[0].clone());
392 }
393
394 ("std::not", "std::lt") => {
395 let mut call = *inner.clone();
397
398 let func_id = as_external_mut(&mut call.function).unwrap();
399 *func_id = "std::lte".to_string();
400
401 call.args.reverse();
402
403 return Ok(ir::Expr::new(call, expr.ty));
404 }
405 ("std::not", "std::lte") => {
406 let mut call = *inner.clone();
408
409 let func_id = as_external_mut(&mut call.function).unwrap();
410 *func_id = "std::lt".to_string();
411
412 call.args.reverse();
413
414 return Ok(ir::Expr::new(call, expr.ty));
415 }
416 _ => {}
417 }
418 }
419
420 Ok(expr)
421 }
422}
423
424struct Substituter {
425 function_id: u32,
426 args: Vec<ir::Expr>,
427}
428
429impl Substituter {
430 fn run(expr: ir::Expr, function_id: u32, args: Vec<ir::Expr>) -> ir::Expr {
431 let mut s = Substituter { function_id, args };
432 s.fold_expr(expr).unwrap()
433 }
434}
435
436impl fold::IrFold for Substituter {
437 fn fold_ptr(&mut self, ptr: ir::Pointer, ty: ir::Ty) -> Result<ir::Expr, ()> {
438 match &ptr {
439 ir::Pointer::Parameter(ptr) if ptr.function_id == self.function_id => {
440 Ok(self.args[ptr.param_position as usize].clone())
441 }
442 _ => {
443 let kind = ir::ExprKind::Pointer(ptr);
444 Ok(ir::Expr { kind, ty })
445 }
446 }
447 }
448}
449
/// Largest function and binding ids found in a program, so callers can
/// allocate ids that do not clash with existing ones.
#[derive(Default)]
pub(crate) struct IdCounter {
    /// Largest function id seen; stays 0 if no function was visited.
    pub max_func_id: u32,
    /// Largest binding id seen; stays 0 if no binding was visited.
    pub max_var_id: u32,
}
455
456impl IdCounter {
457 pub(crate) fn run(mut program: ir::Program) -> (ir::Program, IdCounter) {
458 let mut c = Self::default();
459 program.main = c.fold_expr(program.main).unwrap();
460 (program, c)
461 }
462}
463
464impl fold::IrFold for IdCounter {
465 fn fold_func(&mut self, func: ir::Function, ty: ir::Ty) -> Result<ir::Expr, ()> {
466 self.max_func_id = u32::max(self.max_func_id, func.id);
467 fold::fold_func(self, func, ty)
468 }
469
470 fn fold_binding(&mut self, binding: ir::Binding, ty: ir::Ty) -> Result<ir::Expr, ()> {
471 self.max_var_id = u32::max(self.max_var_id, binding.id);
472 fold::fold_binding(self, binding, ty)
473 }
474
475 fn fold_ty(&mut self, ty: ir::Ty) -> Result<ir::Ty, ()> {
476 Ok(ty)
477 }
478}