Skip to main content

lutra_compiler/intermediate/
inliner.rs

1use std::collections::{HashMap, HashSet};
2
3use lutra_bin::ir;
4
5use super::fold::{self, IrFold};
6use crate::utils::IdGenerator;
7
8pub fn inline(program: ir::Program) -> ir::Program {
9    let (mut program, id_counts) = IdCounter::run(program);
10
11    // inline functions
12    let mut inliner = FuncInliner {
13        bindings: Default::default(),
14
15        currently_inlining: Default::default(),
16
17        generator_var_binding: IdGenerator::new_at(id_counts.max_var_id as usize),
18    };
19    program.main = inliner.fold_expr(program.main).unwrap();
20
21    tracing::debug!("ir (funcs inlined):\n{}", ir::print(&program));
22
23    // count bindings usage
24    let mut counter = BindingUsageCounter {
25        usage: Default::default(),
26        simple: Default::default(),
27    };
28    program.main = counter.fold_expr(program.main).unwrap();
29    tracing::debug!("binding_usage = {:?}", counter.usage);
30    tracing::debug!("simple_bindings = {:?}", counter.simple);
31
32    // inline bindings
33    let mut inliner = BindingInliner::new(counter.usage, counter.simple);
34    program.main = inliner.fold_expr(program.main).unwrap();
35
36    program
37}
38
/// State of the fold that inlines function calls.
struct FuncInliner {
    /// Function definitions taken from function-typed bindings, keyed by binding id.
    bindings: HashMap<u32, ir::Function>,

    /// Binding ids of the functions whose inlined body is being folded right now.
    /// A call to one of them means the function is recursive.
    currently_inlining: HashSet<u32>,

    /// Produces the ids of the bindings that hold the arguments of inlined calls.
    generator_var_binding: IdGenerator,
}
47
48impl fold::IrFold for FuncInliner {
49    fn fold_binding(&mut self, binding: ir::Binding, ty: ir::Ty) -> Result<ir::Expr, ()> {
50        // bindings of the function type
51        if binding.expr.ty.kind.is_function() {
52            match binding.expr.kind {
53                // that are function definitions
54                ir::ExprKind::Function(func) => {
55                    // fold the function
56                    let func = self.fold_func(*func, binding.expr.ty)?;
57                    let func = func.kind.into_function().unwrap();
58
59                    // store in self.bindings
60                    self.bindings.insert(binding.id, *func);
61
62                    // return just the main expr
63                    return self.fold_expr(binding.main);
64                }
65
66                ir::ExprKind::Pointer(_) => todo!(),
67
68                _ => panic!(),
69            }
70        }
71
72        fold::fold_binding(self, binding, ty)
73    }
74
75    fn fold_call(&mut self, call: ir::Call, ty: ir::Ty) -> Result<ir::Expr, ()> {
76        let args = fold::fold_exprs(self, call.args)?;
77
78        let function = match call.function.kind {
79            // calls to bound functions
80            ir::ExprKind::Pointer(ir::Pointer::Binding(ref binding_id)) => {
81                if self.currently_inlining.contains(binding_id) {
82                    panic!("recursive function cannot be inlined");
83                }
84                if let Some(func) = self.bindings.get(binding_id) {
85                    let expr = self.substitute_function(func.clone(), args);
86
87                    self.currently_inlining.insert(*binding_id);
88                    let expr = self.fold_expr(expr);
89                    self.currently_inlining.remove(binding_id);
90                    return expr;
91                } else {
92                    // panic!("binding not found: {binding_id}")
93                    call.function
94                }
95            }
96
97            // calls of lambda functions
98            ir::ExprKind::Function(func) => {
99                // substitute
100
101                let expr = self.substitute_function(*func, args);
102                return self.fold_expr(expr);
103            }
104
105            ir::ExprKind::Pointer(ir::Pointer::Parameter(_)) => call.function,
106
107            ir::ExprKind::Pointer(ir::Pointer::External(_)) => call.function,
108
109            _ => unreachable!(),
110        };
111
112        let kind = ir::ExprKind::Call(Box::new(ir::Call { function, args }));
113        Ok(ir::Expr { kind, ty })
114    }
115
116    fn fold_ptr(&mut self, ptr: ir::Pointer, ty: ir::Ty) -> Result<ir::Expr, ()> {
117        if let ir::Pointer::Binding(binding_id) = &ptr {
118            // special case: when there is a function ptr that is not called directly
119            // we have to inline it
120            if let Some(func) = self.bindings.get(binding_id) {
121                return Ok(ir::Expr {
122                    kind: ir::ExprKind::Function(Box::new(func.clone())),
123                    ty,
124                });
125            }
126        }
127        fold::fold_ptr(ptr, ty)
128    }
129
130    // optimization: inline unneeded switch cases
131    fn fold_switch(&mut self, branches: Vec<ir::SwitchBranch>, ty: ir::Ty) -> Result<ir::Expr, ()> {
132        // detect cases:
133        // (switch
134        //    (...cond..., bool_then)
135        //    (.anything., bool_else)
136        // )
137        fn as_bool_literal(expr: &ir::Expr) -> Option<bool> {
138            expr.kind.as_literal().and_then(|l| l.as_bool().cloned())
139        }
140        if matches!(ty.kind, ir::TyKind::Primitive(ir::TyPrimitive::bool))
141            && branches.len() == 2
142            && let Some(value_then) = as_bool_literal(&branches[0].value)
143            && let Some(value_else) = as_bool_literal(&branches[1].value)
144        {
145            let cond = branches.into_iter().next().unwrap().condition;
146
147            match (value_then, value_else) {
148                (true, true) | (false, false) => {
149                    // pathological case, we don't need to compare
150                    return Ok(ir::Expr::new_lit_bool(value_then));
151                }
152                (true, false) => {
153                    // this just executes the cond
154                    return self.fold_expr(cond);
155                }
156                (false, true) => {
157                    // this just inverts the cond
158                    let cond = self.fold_expr(cond)?;
159
160                    let ty_bool = ir::Ty::new(ir::TyPrimitive::bool);
161                    let std_not = ir::Expr::new(
162                        ir::ExternalPtr {
163                            id: "std::not".into(),
164                        },
165                        ir::Ty::new(ir::TyFunction {
166                            params: vec![ty_bool.clone()],
167                            body: ty_bool.clone(),
168                        }),
169                    );
170
171                    return Ok(ir::Expr::new(
172                        ir::Call {
173                            function: std_not,
174                            args: vec![cond],
175                        },
176                        ty_bool,
177                    ));
178                }
179            }
180        }
181
182        fold::fold_switch(self, branches, ty)
183    }
184}
185
186impl FuncInliner {
187    fn substitute_function(&mut self, func: ir::Function, args: Vec<ir::Expr>) -> ir::Expr {
188        // generate args bound to vars
189        let mut arg_var_ids = Vec::with_capacity(args.len());
190        let mut arg_pointers = Vec::with_capacity(args.len());
191        for arg in &args {
192            let id = self.generator_var_binding.next() as u32;
193            arg_var_ids.push(id);
194            arg_pointers.push(ir::Expr {
195                kind: ir::ExprKind::Pointer(ir::Pointer::Binding(id)),
196                ty: arg.ty.clone(),
197            });
198        }
199
200        // substitute
201        tracing::debug!("inlining call to function {} with {arg_var_ids:?}", func.id);
202        let mut expr = Substituter::run(func.body, func.id, arg_pointers);
203
204        // wrap in Bindings
205        for (id, arg) in std::iter::zip(arg_var_ids, args) {
206            expr = ir::Expr {
207                ty: expr.ty.clone(),
208                kind: ir::ExprKind::Binding(Box::new(ir::Binding {
209                    id,
210                    expr: arg,
211                    main: expr,
212                })),
213            }
214        }
215        expr
216    }
217}
218
219struct BindingUsageCounter {
220    usage: HashMap<u32, usize>,
221    simple: HashSet<u32>,
222}
223
224impl BindingUsageCounter {
225    fn is_simple_expr(expr: &ir::Expr) -> bool {
226        match &expr.kind {
227            ir::ExprKind::Literal(ir::Literal::text(_)) => false,
228            ir::ExprKind::Literal(_) => true,
229            ir::ExprKind::Pointer(_) => true,
230            ir::ExprKind::TupleLookup(lookup) => Self::is_simple_expr(&lookup.base),
231            ir::ExprKind::Tuple(fields) => fields.iter().all(|f| Self::is_simple_expr(&f.expr)),
232            _ => false,
233        }
234    }
235}
236
237impl fold::IrFold for BindingUsageCounter {
238    fn fold_binding(&mut self, binding: ir::Binding, ty: ir::Ty) -> Result<ir::Expr, ()> {
239        self.usage.insert(binding.id, 0);
240
241        // Check if this binding is simple
242        if Self::is_simple_expr(&binding.expr) {
243            self.simple.insert(binding.id);
244        }
245
246        fold::fold_binding(self, binding, ty)
247    }
248
249    fn fold_ptr(&mut self, ptr: ir::Pointer, ty: ir::Ty) -> Result<ir::Expr, ()> {
250        if let ir::Pointer::Binding(binding_id) = &ptr {
251            // count usage of bindings
252            *self.usage.entry(*binding_id).or_default() += 1;
253        }
254        fold::fold_ptr(ptr, ty)
255    }
256}
257
258struct BindingInliner {
259    bindings: HashMap<u32, ir::Expr>,
260
261    to_inline: HashSet<u32>,
262}
263
264impl BindingInliner {
265    fn new(bindings_usage: HashMap<u32, usize>, simple: HashSet<u32>) -> Self {
266        // inline vars that are used 1 or 0 times, or are simple
267        let to_inline: HashSet<u32> = bindings_usage
268            .into_iter()
269            .filter(|(id, usage_count)| *usage_count <= 1 || simple.contains(id))
270            .map(|(id, _)| id)
271            .collect();
272
273        tracing::debug!("inlining vars: {:?}", to_inline);
274
275        BindingInliner {
276            bindings: Default::default(),
277            to_inline,
278        }
279    }
280}
281
282impl IrFold for BindingInliner {
283    fn fold_binding(&mut self, binding: ir::Binding, ty: ir::Ty) -> Result<ir::Expr, ()> {
284        if self.to_inline.contains(&binding.id) {
285            // store in self.bindings
286            let expr = self.fold_expr(binding.expr)?;
287            self.bindings.insert(binding.id, expr);
288
289            // return just the main expr
290            return self.fold_expr(binding.main);
291        }
292        fold::fold_binding(self, binding, ty)
293    }
294    fn fold_ptr(&mut self, ptr: ir::Pointer, ty: ir::Ty) -> Result<ir::Expr, ()> {
295        if let ir::Pointer::Binding(binding_id) = &ptr {
296            // replace ptr with bound value
297
298            if let Some(value) = self.bindings.get(binding_id) {
299                return Ok(value.clone());
300            }
301        }
302        fold::fold_ptr(ptr, ty)
303    }
304
305    // optimization: simplify std::cmp
306    fn fold_enum_eq(&mut self, enum_eq: ir::EnumEq, ty: ir::Ty) -> Result<ir::Expr, ()> {
307        // normal fold
308        let enum_eq = ir::EnumEq {
309            tag: enum_eq.tag,
310            subject: self.fold_expr(enum_eq.subject)?,
311        };
312
313        // detect cases:
314        // (enum_eq
315        //   (call
316        //      external.std::cmp,
317        //      a,
318        //      b,
319        //   ),
320        //   tag
321        // )
322        if let ir::ExprKind::Call(call) = &enum_eq.subject.kind
323            && let ir::ExprKind::Pointer(ir::Pointer::External(func)) = &call.function.kind
324            && func.id == "std::cmp"
325        {
326            let (cmp_func, swap) = match enum_eq.tag {
327                0 => ("std::lt", false),
328                1 => ("std::eq", false),
329                2 => ("std::lt", true),
330                _ => unreachable!(),
331            };
332
333            let mut func_ty = call.function.ty.clone();
334            func_ty.kind.as_function_mut().unwrap().body = ir::Ty::new(ir::TyPrimitive::bool);
335
336            let function = ir::Expr::new(
337                ir::ExternalPtr {
338                    id: cmp_func.to_string(),
339                },
340                func_ty,
341            );
342
343            let mut args = call.args.clone();
344            if swap {
345                args.reverse();
346            }
347
348            return Ok(ir::Expr::new(ir::Call { function, args }, ty));
349        }
350
351        Ok(ir::Expr {
352            kind: ir::ExprKind::EnumEq(Box::new(enum_eq)),
353            ty,
354        })
355    }
356
357    // optimization: simplify call chains
358    fn fold_call(&mut self, call: ir::Call, ty: ir::Ty) -> Result<ir::Expr, ()> {
359        let expr = fold::fold_call(self, call, ty)?;
360
361        fn as_external(expr: &ir::Expr) -> Option<&str> {
362            expr.kind
363                .as_pointer()
364                .and_then(|p| p.as_external())
365                .map(|e| e.id.as_str())
366        }
367        fn as_external_mut(expr: &mut ir::Expr) -> Option<&mut String> {
368            expr.kind
369                .as_pointer_mut()
370                .and_then(|p| p.as_external_mut())
371                .map(|e| &mut e.id)
372        }
373
374        // detect:
375        // (call
376        //   external.outer_id,
377        //   (call
378        //     external.inner_id,
379        //     ..inner_args..
380        //   ),
381        //   ..outer_args..
382        // )
383        if let ir::ExprKind::Call(outer) = &expr.kind
384            && let Some(outer_id) = as_external(&outer.function)
385            && let ir::ExprKind::Call(inner) = &outer.args[0].kind
386            && let Some(inner_id) = as_external(&inner.function)
387        {
388            match (outer_id, inner_id) {
389                ("std::not", "std::not") => {
390                    // not(not(x)) --> x
391                    return Ok(inner.args[0].clone());
392                }
393
394                ("std::not", "std::lt") => {
395                    // not(lt(a, b)) --> lte(b, a)
396                    let mut call = *inner.clone();
397
398                    let func_id = as_external_mut(&mut call.function).unwrap();
399                    *func_id = "std::lte".to_string();
400
401                    call.args.reverse();
402
403                    return Ok(ir::Expr::new(call, expr.ty));
404                }
405                ("std::not", "std::lte") => {
406                    // not(lte(a, b)) --> lt(b, a)
407                    let mut call = *inner.clone();
408
409                    let func_id = as_external_mut(&mut call.function).unwrap();
410                    *func_id = "std::lt".to_string();
411
412                    call.args.reverse();
413
414                    return Ok(ir::Expr::new(call, expr.ty));
415                }
416                _ => {}
417            }
418        }
419
420        Ok(expr)
421    }
422}
423
424struct Substituter {
425    function_id: u32,
426    args: Vec<ir::Expr>,
427}
428
429impl Substituter {
430    fn run(expr: ir::Expr, function_id: u32, args: Vec<ir::Expr>) -> ir::Expr {
431        let mut s = Substituter { function_id, args };
432        s.fold_expr(expr).unwrap()
433    }
434}
435
436impl fold::IrFold for Substituter {
437    fn fold_ptr(&mut self, ptr: ir::Pointer, ty: ir::Ty) -> Result<ir::Expr, ()> {
438        match &ptr {
439            ir::Pointer::Parameter(ptr) if ptr.function_id == self.function_id => {
440                Ok(self.args[ptr.param_position as usize].clone())
441            }
442            _ => {
443                let kind = ir::ExprKind::Pointer(ptr);
444                Ok(ir::Expr { kind, ty })
445            }
446        }
447    }
448}
449
450#[derive(Default)]
451pub(crate) struct IdCounter {
452    pub max_var_id: u32,
453    pub max_func_id: u32,
454}
455
456impl IdCounter {
457    pub(crate) fn run(mut program: ir::Program) -> (ir::Program, IdCounter) {
458        let mut c = Self::default();
459        program.main = c.fold_expr(program.main).unwrap();
460        (program, c)
461    }
462}
463
464impl fold::IrFold for IdCounter {
465    fn fold_func(&mut self, func: ir::Function, ty: ir::Ty) -> Result<ir::Expr, ()> {
466        self.max_func_id = u32::max(self.max_func_id, func.id);
467        fold::fold_func(self, func, ty)
468    }
469
470    fn fold_binding(&mut self, binding: ir::Binding, ty: ir::Ty) -> Result<ir::Expr, ()> {
471        self.max_var_id = u32::max(self.max_var_id, binding.id);
472        fold::fold_binding(self, binding, ty)
473    }
474
475    fn fold_ty(&mut self, ty: ir::Ty) -> Result<ir::Ty, ()> {
476        Ok(ty)
477    }
478}