Skip to main content

ling_codegen/cranelift/
numtype.rs

1//! Whole-program number-type inference.
2//!
3//! Ling values are NaN-boxed: a `u64` is either a raw `f64` or a tagged pointer/
4//! singleton. The naive code path tag-checks both operands on every arithmetic op
5//! and branches to a runtime fallback. That overhead dominates numeric loops.
6//!
7//! This pass recovers static type information the untyped MIR throws away. A local
8//! is a *number* when every value that can flow into it is provably an `f64`.
9//! Because Ling discards declared parameter types, parameter and return types are
10//! inferred interprocedurally from call sites with a greatest fixpoint: assume
11//! everything is a number, then retract that wherever a non-number can reach.
12//!
13//! The JIT/AOT backends consult the result to emit raw `fadd`/`fcmp`/… with no tag
14//! checks wherever both operands are known numbers, closing most of the gap to
15//! native code.
16
17use ling_ast::ast::{BinOp, UnOp};
18use ling_mir::ir::*;
19use std::collections::{HashMap, HashSet};
20
/// Per-function static types: which local indices are proven numbers (`f64`) or
/// strict booleans (a `TAG_TRUE`/`TAG_FALSE` singleton).
///
/// Built by [`analyze`] and queried by the JIT/AOT backends through the
/// `*_is_num` / `*_is_bool` accessors.
#[derive(Debug, Default)]
pub struct NumberTypes {
    /// Function name -> indices of locals proven to always hold an `f64`.
    locals: HashMap<String, HashSet<usize>>,
    /// Function name -> indices of locals proven to always hold a strict boolean.
    bools: HashMap<String, HashSet<usize>>,
}
28
29impl NumberTypes {
30    /// Whether `local` in function `func` is statically known to be a number.
31    pub fn local_is_num(&self, func: &str, local: usize) -> bool {
32        self.locals.get(func).is_some_and(|s| s.contains(&local))
33    }
34
35    /// Whether `op` evaluates to a number inside function `func`.
36    pub fn operand_is_num(&self, func: &str, op: &Operand) -> bool {
37        match op {
38            Operand::Copy(l) | Operand::Move(l) => self.local_is_num(func, l.0),
39            Operand::Constant(c) => matches!(c, Constant::I64(_) | Constant::F64(_)),
40        }
41    }
42
43    /// Whether `op` is a strict boolean inside function `func` (so a branch can
44    /// test it directly against `TAG_TRUE` rather than running full truthiness).
45    pub fn operand_is_bool(&self, func: &str, op: &Operand) -> bool {
46        match op {
47            Operand::Copy(l) | Operand::Move(l) => {
48                self.bools.get(func).is_some_and(|s| s.contains(&l.0))
49            },
50            Operand::Constant(Constant::Bool(_)) => true,
51            _ => false,
52        }
53    }
54}
55
56fn bool_binop(op: &BinOp) -> bool {
57    matches!(
58        op,
59        BinOp::Eq
60            | BinOp::Ne
61            | BinOp::Lt
62            | BinOp::Le
63            | BinOp::Gt
64            | BinOp::Ge
65            | BinOp::And
66            | BinOp::Or
67    )
68}
69
70/// Returns true if a binary op always produces a number given numeric inputs.
71fn arith_binop(op: &BinOp) -> bool {
72    matches!(
73        op,
74        BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div | BinOp::Rem
75    )
76}
77
78/// Compute number-ness for every function in the program.
79pub fn analyze(functions: &[MirFunction]) -> NumberTypes {
80    let by_name: HashMap<&str, &MirFunction> =
81        functions.iter().map(|f| (f.name.as_str(), f)).collect();
82
83    // Call sites: callee name -> list of (caller name, args).
84    let mut call_sites: HashMap<String, Vec<(String, Vec<Operand>)>> = HashMap::new();
85    // Functions whose name is used as a value (not a direct callee): their
86    // parameters can be invoked with unknown arguments, so stay non-number.
87    let mut address_taken: HashSet<String> = HashSet::new();
88
89    for func in functions {
90        for bb in &func.basic_blocks {
91            for stmt in &bb.statements {
92                if let StatementKind::Assign(_, rval) = &stmt.kind {
93                    match rval {
94                        Rvalue::Call { func: callee, args } => {
95                            if let Operand::Constant(Constant::Function(name)) = callee {
96                                call_sites
97                                    .entry(name.clone())
98                                    .or_default()
99                                    .push((func.name.clone(), args.clone()));
100                            }
101                            // A function passed as a call argument is address-taken.
102                            for a in args {
103                                if let Operand::Constant(Constant::Function(n)) = a {
104                                    address_taken.insert(n.clone());
105                                }
106                            }
107                        },
108                        Rvalue::Use(Operand::Constant(Constant::Function(n))) => {
109                            address_taken.insert(n.clone());
110                        },
111                        _ => {},
112                    }
113                }
114            }
115        }
116    }
117
118    // State: per function, which locals are currently believed to be numbers.
119    let mut state: HashMap<String, HashSet<usize>> = HashMap::new();
120    for func in functions {
121        // Optimistically assume every local is a number.
122        let all: HashSet<usize> = (0..func.locals.len() + func.arg_count + 1).collect();
123        state.insert(func.name.clone(), all);
124    }
125
126    let num_of = |state: &HashMap<String, HashSet<usize>>, func: &str, op: &Operand| -> bool {
127        match op {
128            Operand::Copy(l) | Operand::Move(l) => {
129                state.get(func).is_some_and(|s| s.contains(&l.0))
130            },
131            Operand::Constant(c) => matches!(c, Constant::I64(_) | Constant::F64(_)),
132        }
133    };
134
135    let mut changed = true;
136    while changed {
137        changed = false;
138
139        // 1. Parameters: a parameter is a number iff every call site passes a
140        //    number, the function is reachable from a call, and it is not invoked
141        //    indirectly with unknown arguments.
142        let mut param_num: HashMap<String, Vec<bool>> = HashMap::new();
143        for func in functions {
144            let mut pnums = vec![false; func.arg_count];
145            let sites = call_sites.get(&func.name);
146            let callable_directly = sites.is_some() && !address_taken.contains(&func.name);
147            if callable_directly {
148                for (j, pnum) in pnums.iter_mut().enumerate() {
149                    *pnum = sites.unwrap().iter().all(|(caller, args)| {
150                        args.get(j).is_some_and(|a| num_of(&state, caller, a))
151                    });
152                }
153            }
154            param_num.insert(func.name.clone(), pnums);
155        }
156
157        // 2. Locals: recompute from assignments. Parameters take their inferred
158        //    type; temporaries and the return slot are the meet of all writers.
159        for func in functions {
160            let pnums = &param_num[&func.name];
161            // Gather assignments per local.
162            let mut writers: HashMap<usize, Vec<&Rvalue>> = HashMap::new();
163            for bb in &func.basic_blocks {
164                for stmt in &bb.statements {
165                    if let StatementKind::Assign(l, rval) = &stmt.kind {
166                        writers.entry(l.0).or_default().push(rval);
167                    }
168                }
169            }
170
171            let total = func.locals.len() + func.arg_count + 1;
172            let mut new_set = HashSet::new();
173            for idx in 0..total {
174                // Parameters: Local(1..=arg_count).
175                if idx >= 1 && idx <= func.arg_count {
176                    if pnums[idx - 1] {
177                        new_set.insert(idx);
178                    }
179                    continue;
180                }
181                let assigns = writers.get(&idx);
182                let is_num = match assigns {
183                    // Never written and not a parameter: treat as non-number.
184                    None => false,
185                    Some(rvals) => rvals
186                        .iter()
187                        .all(|r| rvalue_is_num(r, &state, &param_num, func, &by_name)),
188                };
189                if is_num {
190                    new_set.insert(idx);
191                }
192            }
193
194            let prev = state.get(&func.name);
195            if prev != Some(&new_set) {
196                changed = true;
197                state.insert(func.name.clone(), new_set);
198            }
199        }
200    }
201
202    // Booleans are intra-procedural: a local is bool when every writer is a
203    // comparison/logical op, `!`, a bool constant, or a copy of a bool. Iterate to
204    // a fixpoint so copy chains converge.
205    let mut bools: HashMap<String, HashSet<usize>> = HashMap::new();
206    for func in functions {
207        let mut writers: HashMap<usize, Vec<&Rvalue>> = HashMap::new();
208        for bb in &func.basic_blocks {
209            for stmt in &bb.statements {
210                if let StatementKind::Assign(l, rval) = &stmt.kind {
211                    writers.entry(l.0).or_default().push(rval);
212                }
213            }
214        }
215        let mut set: HashSet<usize> = HashSet::new();
216        let mut changed = true;
217        while changed {
218            changed = false;
219            for (&idx, rvals) in &writers {
220                if set.contains(&idx) {
221                    continue;
222                }
223                let is_bool = rvals.iter().all(|r| match r {
224                    Rvalue::BinaryOp(op, _, _) => bool_binop(op),
225                    Rvalue::UnaryOp(UnOp::Not, _) => true,
226                    Rvalue::Use(Operand::Constant(Constant::Bool(_))) => true,
227                    Rvalue::Use(Operand::Copy(l)) | Rvalue::Use(Operand::Move(l)) => {
228                        set.contains(&l.0)
229                    },
230                    _ => false,
231                });
232                if is_bool {
233                    set.insert(idx);
234                    changed = true;
235                }
236            }
237        }
238        bools.insert(func.name.clone(), set);
239    }
240
241    NumberTypes { locals: state, bools }
242}
243
244/// Whether an rvalue produces a number, given the current global estimate.
245fn rvalue_is_num(
246    rval: &Rvalue,
247    state: &HashMap<String, HashSet<usize>>,
248    param_num: &HashMap<String, Vec<bool>>,
249    func: &MirFunction,
250    by_name: &HashMap<&str, &MirFunction>,
251) -> bool {
252    let op_num = |op: &Operand| -> bool {
253        match op {
254            Operand::Copy(l) | Operand::Move(l) => {
255                state.get(&func.name).is_some_and(|s| s.contains(&l.0))
256            },
257            Operand::Constant(c) => matches!(c, Constant::I64(_) | Constant::F64(_)),
258        }
259    };
260    match rval {
261        Rvalue::Use(op) => op_num(op),
262        Rvalue::BinaryOp(op, a, b) => arith_binop(op) && op_num(a) && op_num(b),
263        Rvalue::UnaryOp(UnOp::Neg, a) => op_num(a),
264        Rvalue::UnaryOp(_, _) => false,
265        Rvalue::Call { func: callee, .. } => {
266            // Return type of a directly-called function: number-ness of its Local 0.
267            if let Operand::Constant(Constant::Function(name)) = callee {
268                if by_name.contains_key(name.as_str()) {
269                    // Use param_num presence to confirm it's a known function, then
270                    // read its return slot from the running estimate.
271                    let _ = param_num;
272                    return state.get(name).is_some_and(|s| s.contains(&0));
273                }
274            }
275            false
276        },
277        _ => false,
278    }
279}