Skip to main content

ling_codegen/cranelift/
aot.rs

1use super::numtype;
2use super::translate::{build_function_body, max_local_index, TransCtx};
3use crate::CodegenBackend;
4use crate::MirProgram;
5use anyhow::Result;
6use cranelift::codegen::ir::{FuncRef, GlobalValue};
7use cranelift::prelude::*;
8use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext};
9use cranelift_module::DataId;
10use cranelift_module::{DataDescription, FuncId, Linkage, Module};
11use cranelift_object::{ObjectBuilder, ObjectModule};
12use ling_mir::ir::*;
13use std::collections::{HashMap, HashSet};
14use std::io::IsTerminal;
15
16// ─── AOT Backend ───────────────────────────────────────────────────────────
17
18pub struct CraneliftBackend {
19    module: Option<ObjectModule>,
20    builder_ctx: FunctionBuilderContext,
21    progress: bool,
22}
23
/// Draw a one-line, tqdm-style progress bar on stderr, redrawing in place via
/// a leading `\r`: a teal filled segment on a grey track, then the percentage,
/// the `done/total` counter and `label` clipped to 24 characters.
///
/// `total == 0` is drawn as complete. Write/flush errors are ignored because
/// the bar is purely cosmetic.
fn render_progress(done: usize, total: usize, label: &str) {
    use std::io::Write as _;

    const WIDTH: usize = 28;
    const LABEL_CHARS: usize = 24;

    let ratio = match total {
        0 => 1.0,
        _ => done as f64 / total as f64,
    };
    let cells = WIDTH.min((ratio * WIDTH as f64).round() as usize);
    let bar_full = "█".repeat(cells);
    let bar_empty = "░".repeat(WIDTH - cells);
    let pct = (ratio * 100.0) as u32;
    // Clip by chars rather than bytes so a multi-byte name is never split,
    // and so the line stays short enough not to wrap.
    let label: String = label.chars().take(LABEL_CHARS).collect();

    let mut stderr = std::io::stderr();
    let _ = write!(
        stderr,
        "\r    compiling \x1b[38;5;37m{bar_full}\x1b[38;5;240m{bar_empty}\x1b[0m {pct:>3}% [{done}/{total}] {label:<24}",
    );
    let _ = stderr.flush();
}
43
44// ─── Runtime Function Declarations ─────────────────────────────────────────
45
46struct RuntimeDecl {
47    id: FuncId,
48}
49
50fn declare_runtime_functions(module: &mut ObjectModule) -> HashMap<String, RuntimeDecl> {
51    let mut decls = HashMap::new();
52
53    let runtime_fns: &[(&str, &[types::Type], types::Type)] = &[
54        ("ling_f64_add", &[types::F64, types::F64], types::F64),
55        ("ling_f64_sub", &[types::F64, types::F64], types::F64),
56        ("ling_f64_mul", &[types::F64, types::F64], types::F64),
57        ("ling_f64_div", &[types::F64, types::F64], types::F64),
58        ("ling_f64_rem", &[types::F64, types::F64], types::F64),
59        ("ling_f64_neg", &[types::F64], types::F64),
60        ("ling_f64_eq", &[types::F64, types::F64], types::I64),
61        ("ling_f64_lt", &[types::F64, types::F64], types::I64),
62        ("ling_f64_gt", &[types::F64, types::F64], types::I64),
63        ("ling_f64_le", &[types::F64, types::F64], types::I64),
64        ("ling_f64_ge", &[types::F64, types::F64], types::I64),
65        ("ling_sin", &[types::F64], types::F64),
66        ("ling_cos", &[types::F64], types::F64),
67        ("ling_sqrt", &[types::F64], types::F64),
68        ("ling_abs", &[types::F64], types::F64),
69        ("ling_floor", &[types::F64], types::F64),
70        ("ling_ceil", &[types::F64], types::F64),
71        ("ling_round", &[types::F64], types::F64),
72        ("ling_add", &[types::I64, types::I64], types::I64),
73        ("ling_sub", &[types::I64, types::I64], types::I64),
74        ("ling_mul", &[types::I64, types::I64], types::I64),
75        ("ling_div", &[types::I64, types::I64], types::I64),
76        ("ling_rem", &[types::I64, types::I64], types::I64),
77        ("ling_neg", &[types::I64, types::I64], types::I64),
78        ("ling_eq", &[types::I64, types::I64], types::I64),
79        ("ling_ne", &[types::I64, types::I64], types::I64),
80        ("ling_lt", &[types::I64, types::I64], types::I64),
81        ("ling_le", &[types::I64, types::I64], types::I64),
82        ("ling_gt", &[types::I64, types::I64], types::I64),
83        ("ling_ge", &[types::I64, types::I64], types::I64),
84        ("ling_and", &[types::I64, types::I64], types::I64),
85        ("ling_or", &[types::I64, types::I64], types::I64),
86        ("ling_not", &[types::I64], types::I64),
87        ("ling_bool_to_u64", &[types::I64], types::I64),
88        ("ling_alloc", &[types::I64], types::I64),
89        ("ling_free", &[types::I64], types::I64),
90        ("ling_panic", &[types::I64], types::I64),
91        ("ling_str_new", &[types::I64, types::I64], types::I64),
92        ("ling_str_len", &[types::I64], types::I64),
93        ("ling_str_concat", &[types::I64, types::I64], types::I64),
94        ("ling_str_eq", &[types::I64, types::I64], types::I64),
95        ("ling_list_new", &[], types::I64),
96        ("ling_list_push", &[types::I64, types::I64], types::I64),
97        ("ling_list_get", &[types::I64, types::I64], types::I64),
98        ("ling_list_len", &[types::I64], types::I64),
99        (
100            "ling_struct_new",
101            &[types::I64, types::I64, types::I64, types::I64],
102            types::I64,
103        ),
104        (
105            "ling_struct_get",
106            &[types::I64, types::I64, types::I64],
107            types::I64,
108        ),
109        ("ling_print", &[types::I64], types::I64),
110        ("ling_print_val", &[types::I64], types::I64),
111        ("ling_print_newline", &[], types::I64),
112        ("ling_time_now", &[], types::I64),
113        (
114            "ling_builtin",
115            &[types::I64, types::I64, types::I64, types::I64],
116            types::I64,
117        ),
118    ];
119
120    for &(name, params, ret) in runtime_fns {
121        let mut sig = module.make_signature();
122        for &pt in params {
123            sig.params.push(AbiParam::new(pt));
124        }
125        sig.returns.push(AbiParam::new(ret));
126        let id = module
127            .declare_function(name, Linkage::Import, &sig)
128            .unwrap();
129        decls.insert(name.to_string(), RuntimeDecl { id });
130    }
131
132    decls
133}
134
135// ─── String/Builtin name collection ─────────────────────────────────────────
136
137fn collect_string_constants(
138    functions: &[MirFunction],
139    module: &mut ObjectModule,
140) -> (HashMap<String, DataId>, HashMap<String, DataId>) {
141    let mut string_ids: HashMap<String, DataId> = HashMap::new();
142    let mut builtin_ids: HashMap<String, DataId> = HashMap::new();
143    for func in functions {
144        for bb in &func.basic_blocks {
145            for stmt in &bb.statements {
146                if let StatementKind::Assign(_, rval) = &stmt.kind {
147                    visit_rvalue_strings(rval, module, &mut string_ids);
148                    visit_rvalue_builtin_names(rval, module, &mut builtin_ids);
149                }
150            }
151            if let Some(term) = &bb.terminator {
152                visit_term_strings(term, module, &mut string_ids);
153            }
154        }
155    }
156    (string_ids, builtin_ids)
157}
158
159fn visit_operand_strings(
160    op: &Operand,
161    module: &mut ObjectModule,
162    string_ids: &mut HashMap<String, DataId>,
163) {
164    if let Operand::Constant(Constant::Str(s)) = op {
165        if !string_ids.contains_key(s) {
166            let name = format!("__str_{}", string_ids.len());
167            let data_id = module
168                .declare_data(&name, Linkage::Local, true, false)
169                .unwrap();
170            let mut desc = DataDescription::new();
171            desc.define(s.as_bytes().to_vec().into_boxed_slice());
172            desc.set_align(1);
173            module.define_data(data_id, &desc).unwrap();
174            string_ids.insert(s.clone(), data_id);
175        }
176    }
177}
178
179fn visit_rvalue_builtin_names(
180    rval: &Rvalue,
181    module: &mut ObjectModule,
182    builtin_ids: &mut HashMap<String, DataId>,
183) {
184    if let Rvalue::Call { func: Operand::Constant(Constant::Function(n)), .. } = rval {
185        if !builtin_ids.contains_key(n) {
186            let name = format!("__builtin_{}", builtin_ids.len());
187            let data_id = module
188                .declare_data(&name, Linkage::Local, true, false)
189                .unwrap();
190            let mut desc = DataDescription::new();
191            let mut bytes = n.as_bytes().to_vec();
192            bytes.push(0);
193            desc.define(bytes.into_boxed_slice());
194            desc.set_align(1);
195            module.define_data(data_id, &desc).unwrap();
196            builtin_ids.insert(n.clone(), data_id);
197        }
198    }
199}
200
201fn visit_rvalue_strings(
202    rval: &Rvalue,
203    module: &mut ObjectModule,
204    string_ids: &mut HashMap<String, DataId>,
205) {
206    match rval {
207        Rvalue::Use(op) | Rvalue::UnaryOp(_, op) => visit_operand_strings(op, module, string_ids),
208        Rvalue::BinaryOp(_, lhs, rhs) => {
209            visit_operand_strings(lhs, module, string_ids);
210            visit_operand_strings(rhs, module, string_ids);
211        },
212        Rvalue::Call { args, .. } => {
213            for arg in args {
214                visit_operand_strings(arg, module, string_ids);
215            }
216        },
217        Rvalue::Aggregate(_, ops) => {
218            for op in ops {
219                visit_operand_strings(op, module, string_ids);
220            }
221        },
222        _ => {},
223    }
224}
225
226fn visit_term_strings(
227    term: &Terminator,
228    module: &mut ObjectModule,
229    string_ids: &mut HashMap<String, DataId>,
230) {
231    if let TerminatorKind::SwitchInt { discr, .. } = &term.kind {
232        visit_operand_strings(discr, module, string_ids);
233    }
234}
235
236// ─── Per-function symbol references ──────────────────────────────────────────
237
/// Symbols a single MIR function refers to: the string literals it
/// materializes and the names it calls directly.
///
/// Declaring every program-wide string, builtin and user function into every
/// Cranelift function costs O(functions × symbols) in time and per-function IR
/// size, which exhausts memory on large programs. The translator only looks up
/// symbols a function actually references, so only those get declared.
#[derive(Default)]
struct FuncRefs {
    /// Direct callee names. Each resolves either to a user function (a
    /// `FuncId`) or to a builtin (a name-string `DataId`).
    names: HashSet<String>,
    /// String literals used as operands.
    strings: HashSet<String>,
}
252
253fn collect_func_refs(func: &MirFunction) -> FuncRefs {
254    let mut refs = FuncRefs::default();
255    for bb in &func.basic_blocks {
256        for stmt in &bb.statements {
257            if let StatementKind::Assign(_, rval) = &stmt.kind {
258                collect_rvalue_refs(rval, &mut refs);
259            }
260        }
261        if let Some(term) = &bb.terminator {
262            if let TerminatorKind::SwitchInt { discr, .. } = &term.kind {
263                collect_operand_str(discr, &mut refs);
264            }
265        }
266    }
267    refs
268}
269
270fn collect_operand_str(op: &Operand, refs: &mut FuncRefs) {
271    if let Operand::Constant(Constant::Str(s)) = op {
272        refs.strings.insert(s.clone());
273    }
274}
275
276fn collect_rvalue_refs(rval: &Rvalue, refs: &mut FuncRefs) {
277    match rval {
278        Rvalue::Use(op) | Rvalue::UnaryOp(_, op) => collect_operand_str(op, refs),
279        Rvalue::BinaryOp(_, lhs, rhs) => {
280            collect_operand_str(lhs, refs);
281            collect_operand_str(rhs, refs);
282        },
283        Rvalue::Call { func, args } => {
284            if let Operand::Constant(Constant::Function(n)) = func {
285                refs.names.insert(n.clone());
286            }
287            for arg in args {
288                collect_operand_str(arg, refs);
289            }
290        },
291        Rvalue::Aggregate(_, ops) => {
292            for op in ops {
293                collect_operand_str(op, refs);
294            }
295        },
296        _ => {},
297    }
298}
299
300// ─── Main AOT compilation ───────────────────────────────────────────────────
301
302impl CraneliftBackend {
303    pub fn new() -> Self {
304        let mut flag_builder = settings::builder();
305        flag_builder.set("is_pic", "true").unwrap();
306        flag_builder.set("opt_level", "speed").unwrap();
307        flag_builder.set("enable_alias_analysis", "true").unwrap();
308        flag_builder.set("enable_verifier", "false").unwrap();
309        let isa_builder = cranelift::native::builder()
310            .unwrap_or_else(|_| isa::lookup_by_name("aarch64").unwrap());
311        let isa = isa_builder
312            .finish(settings::Flags::new(flag_builder))
313            .unwrap();
314        let obj_builder = ObjectBuilder::new(
315            isa,
316            "ling_program",
317            cranelift_module::default_libcall_names(),
318        )
319        .expect("ObjectBuilder");
320        let module = ObjectModule::new(obj_builder);
321        Self {
322            module: Some(module),
323            builder_ctx: FunctionBuilderContext::new(),
324            progress: false,
325        }
326    }
327
328    /// Enable a tqdm-style progress bar over the per-function codegen loop.
329    /// Off by default so library/test use stays silent.
330    pub fn with_progress(mut self, on: bool) -> Self {
331        self.progress = on;
332        self
333    }
334}
335
336impl CodegenBackend for CraneliftBackend {
337    fn emit(&mut self, program: &MirProgram, out: &std::path::Path) -> Result<()> {
338        let module: &mut ObjectModule = self.module.as_mut().unwrap();
339
340        let num_types = numtype::analyze(&program.mir.functions);
341
342        // Phase 0: declare all runtime functions as imports
343        let runtime_decls = declare_runtime_functions(module);
344
345        // Phase 1: collect and declare string/builtin data objects
346        let (string_ids, builtin_ids) = collect_string_constants(&program.mir.functions, module);
347
348        // Phase 2: declare all user functions as exports
349        let mut func_ids: HashMap<String, FuncId> = HashMap::new();
350        for func in &program.mir.functions {
351            let mut sig = module.make_signature();
352            for _ in 0..func.arg_count {
353                sig.params.push(AbiParam::new(types::I64));
354            }
355            sig.returns.push(AbiParam::new(types::I64));
356            let id = module
357                .declare_function(&func.name, Linkage::Export, &sig)
358                .unwrap();
359            func_ids.insert(func.name.clone(), id);
360        }
361
362        // Phase 3: translate each function body
363        let total = program.mir.functions.len();
364        // Throttle redraws so huge programs don't spend real time on the bar.
365        let step = (total / 100).max(1);
366        // Only draw the live bar on a real terminal; piped/CI logs stay clean.
367        let show_progress = self.progress && total > 1 && std::io::stderr().is_terminal();
368        for (idx, func) in program.mir.functions.iter().enumerate() {
369            if show_progress && (idx % step == 0 || idx + 1 == total) {
370                let label = if func.name == "__main__" {
371                    "main"
372                } else {
373                    func.name.as_str()
374                };
375                render_progress(idx + 1, total, label);
376            }
377            let &fid = func_ids.get(&func.name).unwrap();
378
379            let mut ctx = module.make_context();
380            let mut sig = module.make_signature();
381            for _ in 0..func.arg_count {
382                sig.params.push(AbiParam::new(types::I64));
383            }
384            sig.returns.push(AbiParam::new(types::I64));
385            ctx.func.signature = sig;
386
387            let mut builder = FunctionBuilder::new(&mut ctx.func, &mut self.builder_ctx);
388            let blocks: Vec<Block> = func
389                .basic_blocks
390                .iter()
391                .map(|_| builder.create_block())
392                .collect();
393
394            let max_local = max_local_index(func);
395            let mut vars: HashMap<Local, Variable> = HashMap::new();
396            for i in 0..=max_local {
397                vars.insert(Local(i), builder.declare_var(types::I64));
398            }
399
400            // Declare only the symbols this function actually references. The
401            // global `string_ids`/`builtin_ids`/`func_ids` maps cover the whole
402            // program; declaring all of them into every function is O(functions ×
403            // symbols) and runs the compiler out of memory on large projects.
404            let refs = collect_func_refs(func);
405
406            // Build string GlobalValue map for this function
407            let mut string_gvs: HashMap<String, GlobalValue> = HashMap::new();
408            for s in &refs.strings {
409                if let Some(&data_id) = string_ids.get(s) {
410                    let gv = module.declare_data_in_func(data_id, builder.func);
411                    string_gvs.insert(s.clone(), gv);
412                }
413            }
414
415            // Build runtime FuncRef map for this function. The shared translator
416            // looks up runtime helpers by their JIT symbol name (`__ling_*`); the
417            // object backend declares them without that prefix, so re-key here.
418            let mut runtime_refs: HashMap<String, FuncRef> = HashMap::new();
419            for (name, decl) in &runtime_decls {
420                let fr = module.declare_func_in_func(decl.id, builder.func);
421                runtime_refs.insert(format!("__{name}"), fr);
422            }
423
424            // FuncRefs are function-local, so callee references must be declared
425            // into this builder (not shared across functions). A referenced name is
426            // either a user function (declared as a FuncRef) or a builtin (declared
427            // as a name-string GlobalValue for the `__ling_builtin` dispatch path).
428            let mut func_refs: HashMap<String, FuncRef> = HashMap::new();
429            let mut builtin_gvs: HashMap<String, GlobalValue> = HashMap::new();
430            for name in &refs.names {
431                if let Some(&id) = func_ids.get(name) {
432                    let fr = module.declare_func_in_func(id, builder.func);
433                    func_refs.insert(name.clone(), fr);
434                } else if let Some(&data_id) = builtin_ids.get(name) {
435                    let gv = module.declare_data_in_func(data_id, builder.func);
436                    builtin_gvs.insert(name.clone(), gv);
437                }
438            }
439
440            let tctx = TransCtx {
441                vars: &vars,
442                string_gvs: &string_gvs,
443                builtin_gvs: &builtin_gvs,
444                runtime_refs: &runtime_refs,
445                func_refs: &func_refs,
446                nt: &num_types,
447                fname: &func.name,
448            };
449            build_function_body(&mut builder, func, &blocks, &tctx);
450            builder.finalize();
451            module.define_function(fid, &mut ctx).unwrap();
452        }
453        if show_progress {
454            eprintln!();
455        }
456
457        // Phase 4: finalize and emit .o file
458        let obj = self.module.take().unwrap().finish();
459        let bytes = obj.emit().map_err(|e| anyhow::anyhow!("{:?}", e))?;
460        std::fs::write(out, bytes)?;
461        Ok(())
462    }
463}
464
465impl Default for CraneliftBackend {
466    fn default() -> Self {
467        Self::new()
468    }
469}
470
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CodegenBackend;
    use ling_ast::Span;

    /// An unnamed, immutable, non-owning local of type `MirType::Any`.
    fn decl() -> LocalDecl {
        LocalDecl {
            ty: MirType::Any,
            name: None,
            span: Span::DUMMY,
            is_mut: false,
            is_owning: false,
        }
    }

    /// Wrap `kind` in a statement with a dummy span.
    fn stmt(kind: StatementKind) -> Statement {
        Statement { kind, span: Span::DUMMY }
    }

    /// A `Return` terminator with a dummy span.
    fn ret() -> Terminator {
        Terminator { kind: TerminatorKind::Return, span: Span::DUMMY }
    }

    /// A program that uses a string constant, a builtin call, and a call to
    /// another user function must still compile when each function only declares
    /// the symbols it actually references (the O(functions × symbols) fix). The
    /// `helper` function references none of `__main__`'s symbols, so a regression
    /// that under-declares would surface here as a missing-symbol panic or an
    /// empty object.
    #[test]
    fn emit_declares_only_referenced_symbols() {
        // helper(x) = x
        let mut helper = MirFunction::new("helper", 1);
        helper.basic_blocks = vec![BasicBlock {
            statements: vec![stmt(StatementKind::Assign(
                Local(0),
                Rvalue::Use(Operand::Copy(Local(1))),
            ))],
            terminator: Some(ret()),
        }];

        // __main__: s = "hi"; print(s); r = helper(5); return r
        let mut main = MirFunction::new("__main__", 0);
        main.locals = vec![decl(), decl(), decl()]; // Local(1), Local(2), Local(3)
        main.basic_blocks = vec![BasicBlock {
            statements: vec![
                stmt(StatementKind::Assign(
                    Local(1),
                    Rvalue::Use(Operand::Constant(Constant::Str("hi".into()))),
                )),
                stmt(StatementKind::Assign(
                    Local(2),
                    Rvalue::Call {
                        func: Operand::Constant(Constant::Function("print".into())),
                        args: vec![Operand::Copy(Local(1))],
                    },
                )),
                stmt(StatementKind::Assign(
                    Local(3),
                    Rvalue::Call {
                        func: Operand::Constant(Constant::Function("helper".into())),
                        args: vec![Operand::Constant(Constant::I64(5))],
                    },
                )),
                stmt(StatementKind::Assign(
                    Local(0),
                    Rvalue::Use(Operand::Copy(Local(3))),
                )),
            ],
            terminator: Some(ret()),
        }];

        let program = ling_mir::MirProgram { functions: vec![helper, main] };
        let wrapped = crate::MirProgram::new(program, "test.ling");

        let mut backend = CraneliftBackend::new();
        // Include the test name as well as the pid: the harness runs tests on
        // several threads of one process, so a pid-only path could collide
        // with another test that emits an object.
        let out = std::env::temp_dir().join(format!(
            "ling_aot_emit_declares_only_referenced_symbols_{}.o",
            std::process::id()
        ));
        let emitted = backend.emit(&wrapped, &out);
        let read = std::fs::read(&out);
        // Remove the file before asserting so a failing assertion does not
        // leave it behind in the temp directory.
        let _ = std::fs::remove_file(&out);

        emitted.expect("emit should succeed");
        let bytes = read.expect("object file should exist");
        assert!(!bytes.is_empty(), "emitted object must be non-empty");
    }
}