// ling_codegen/cranelift/aot.rs — ahead-of-time (object file) Cranelift backend.

1use super::numtype;
2use super::translate::{build_function_body, max_local_index, TransCtx};
3use crate::CodegenBackend;
4use crate::MirProgram;
5use anyhow::Result;
6use cranelift::codegen::ir::{FuncRef, GlobalValue};
7use cranelift::prelude::*;
8use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext};
9use cranelift_module::DataId;
10use cranelift_module::{DataDescription, FuncId, Linkage, Module};
11use cranelift_object::{ObjectBuilder, ObjectModule};
12use ling_mir::ir::*;
13use std::collections::HashMap;

// ─── AOT Backend ───────────────────────────────────────────────────────────

17pub struct CraneliftBackend {
18    module: Option<ObjectModule>,
19    builder_ctx: FunctionBuilderContext,
20}

// ─── Runtime Function Declarations ─────────────────────────────────────────

24struct RuntimeDecl {
25    id: FuncId,
26}
27
28fn declare_runtime_functions(module: &mut ObjectModule) -> HashMap<String, RuntimeDecl> {
29    let mut decls = HashMap::new();
30
31    let runtime_fns: &[(&str, &[types::Type], types::Type)] = &[
32        ("ling_f64_add", &[types::F64, types::F64], types::F64),
33        ("ling_f64_sub", &[types::F64, types::F64], types::F64),
34        ("ling_f64_mul", &[types::F64, types::F64], types::F64),
35        ("ling_f64_div", &[types::F64, types::F64], types::F64),
36        ("ling_f64_rem", &[types::F64, types::F64], types::F64),
37        ("ling_f64_neg", &[types::F64], types::F64),
38        ("ling_f64_eq", &[types::F64, types::F64], types::I64),
39        ("ling_f64_lt", &[types::F64, types::F64], types::I64),
40        ("ling_f64_gt", &[types::F64, types::F64], types::I64),
41        ("ling_f64_le", &[types::F64, types::F64], types::I64),
42        ("ling_f64_ge", &[types::F64, types::F64], types::I64),
43        ("ling_sin", &[types::F64], types::F64),
44        ("ling_cos", &[types::F64], types::F64),
45        ("ling_sqrt", &[types::F64], types::F64),
46        ("ling_abs", &[types::F64], types::F64),
47        ("ling_floor", &[types::F64], types::F64),
48        ("ling_ceil", &[types::F64], types::F64),
49        ("ling_round", &[types::F64], types::F64),
50        ("ling_add", &[types::I64, types::I64], types::I64),
51        ("ling_sub", &[types::I64, types::I64], types::I64),
52        ("ling_mul", &[types::I64, types::I64], types::I64),
53        ("ling_div", &[types::I64, types::I64], types::I64),
54        ("ling_rem", &[types::I64, types::I64], types::I64),
55        ("ling_neg", &[types::I64, types::I64], types::I64),
56        ("ling_eq", &[types::I64, types::I64], types::I64),
57        ("ling_ne", &[types::I64, types::I64], types::I64),
58        ("ling_lt", &[types::I64, types::I64], types::I64),
59        ("ling_le", &[types::I64, types::I64], types::I64),
60        ("ling_gt", &[types::I64, types::I64], types::I64),
61        ("ling_ge", &[types::I64, types::I64], types::I64),
62        ("ling_and", &[types::I64, types::I64], types::I64),
63        ("ling_or", &[types::I64, types::I64], types::I64),
64        ("ling_not", &[types::I64], types::I64),
65        ("ling_bool_to_u64", &[types::I64], types::I64),
66        ("ling_alloc", &[types::I64], types::I64),
67        ("ling_free", &[types::I64], types::I64),
68        ("ling_panic", &[types::I64], types::I64),
69        ("ling_str_new", &[types::I64, types::I64], types::I64),
70        ("ling_str_len", &[types::I64], types::I64),
71        ("ling_str_concat", &[types::I64, types::I64], types::I64),
72        ("ling_str_eq", &[types::I64, types::I64], types::I64),
73        ("ling_list_new", &[], types::I64),
74        ("ling_list_push", &[types::I64, types::I64], types::I64),
75        ("ling_list_get", &[types::I64, types::I64], types::I64),
76        ("ling_list_len", &[types::I64], types::I64),
77        (
78            "ling_struct_new",
79            &[types::I64, types::I64, types::I64, types::I64],
80            types::I64,
81        ),
82        (
83            "ling_struct_get",
84            &[types::I64, types::I64, types::I64],
85            types::I64,
86        ),
87        ("ling_print", &[types::I64], types::I64),
88        ("ling_print_val", &[types::I64], types::I64),
89        ("ling_print_newline", &[], types::I64),
90        ("ling_time_now", &[], types::I64),
91        (
92            "ling_builtin",
93            &[types::I64, types::I64, types::I64, types::I64],
94            types::I64,
95        ),
96    ];
97
98    for &(name, params, ret) in runtime_fns {
99        let mut sig = module.make_signature();
100        for &pt in params {
101            sig.params.push(AbiParam::new(pt));
102        }
103        sig.returns.push(AbiParam::new(ret));
104        let id = module
105            .declare_function(name, Linkage::Import, &sig)
106            .unwrap();
107        decls.insert(name.to_string(), RuntimeDecl { id });
108    }
109
110    decls
111}

// ─── String/Builtin name collection ─────────────────────────────────────────

115fn collect_string_constants(
116    functions: &[MirFunction],
117    module: &mut ObjectModule,
118) -> (HashMap<String, DataId>, HashMap<String, DataId>) {
119    let mut string_ids: HashMap<String, DataId> = HashMap::new();
120    let mut builtin_ids: HashMap<String, DataId> = HashMap::new();
121    for func in functions {
122        for bb in &func.basic_blocks {
123            for stmt in &bb.statements {
124                if let StatementKind::Assign(_, rval) = &stmt.kind {
125                    visit_rvalue_strings(rval, module, &mut string_ids);
126                    visit_rvalue_builtin_names(rval, module, &mut builtin_ids);
127                }
128            }
129            if let Some(term) = &bb.terminator {
130                visit_term_strings(term, module, &mut string_ids);
131            }
132        }
133    }
134    (string_ids, builtin_ids)
135}
136
137fn visit_operand_strings(
138    op: &Operand,
139    module: &mut ObjectModule,
140    string_ids: &mut HashMap<String, DataId>,
141) {
142    if let Operand::Constant(Constant::Str(s)) = op {
143        if !string_ids.contains_key(s) {
144            let name = format!("__str_{}", string_ids.len());
145            let data_id = module
146                .declare_data(&name, Linkage::Local, true, false)
147                .unwrap();
148            let mut desc = DataDescription::new();
149            desc.define(s.as_bytes().to_vec().into_boxed_slice());
150            desc.set_align(1);
151            module.define_data(data_id, &desc).unwrap();
152            string_ids.insert(s.clone(), data_id);
153        }
154    }
155}
156
157fn visit_rvalue_builtin_names(
158    rval: &Rvalue,
159    module: &mut ObjectModule,
160    builtin_ids: &mut HashMap<String, DataId>,
161) {
162    if let Rvalue::Call { func: Operand::Constant(Constant::Function(n)), .. } = rval {
163        if !builtin_ids.contains_key(n) {
164            let name = format!("__builtin_{}", builtin_ids.len());
165            let data_id = module
166                .declare_data(&name, Linkage::Local, true, false)
167                .unwrap();
168            let mut desc = DataDescription::new();
169            let mut bytes = n.as_bytes().to_vec();
170            bytes.push(0);
171            desc.define(bytes.into_boxed_slice());
172            desc.set_align(1);
173            module.define_data(data_id, &desc).unwrap();
174            builtin_ids.insert(n.clone(), data_id);
175        }
176    }
177}
178
179fn visit_rvalue_strings(
180    rval: &Rvalue,
181    module: &mut ObjectModule,
182    string_ids: &mut HashMap<String, DataId>,
183) {
184    match rval {
185        Rvalue::Use(op) | Rvalue::UnaryOp(_, op) => visit_operand_strings(op, module, string_ids),
186        Rvalue::BinaryOp(_, lhs, rhs) => {
187            visit_operand_strings(lhs, module, string_ids);
188            visit_operand_strings(rhs, module, string_ids);
189        },
190        Rvalue::Call { args, .. } => {
191            for arg in args {
192                visit_operand_strings(arg, module, string_ids);
193            }
194        },
195        Rvalue::Aggregate(_, ops) => {
196            for op in ops {
197                visit_operand_strings(op, module, string_ids);
198            }
199        },
200        _ => {},
201    }
202}
203
204fn visit_term_strings(
205    term: &Terminator,
206    module: &mut ObjectModule,
207    string_ids: &mut HashMap<String, DataId>,
208) {
209    if let TerminatorKind::SwitchInt { discr, .. } = &term.kind {
210        visit_operand_strings(discr, module, string_ids);
211    }
212}

// ─── Main AOT compilation ───────────────────────────────────────────────────

216impl CraneliftBackend {
217    pub fn new() -> Self {
218        let mut flag_builder = settings::builder();
219        flag_builder.set("is_pic", "true").unwrap();
220        flag_builder.set("opt_level", "speed").unwrap();
221        flag_builder.set("enable_alias_analysis", "true").unwrap();
222        flag_builder.set("enable_verifier", "false").unwrap();
223        let isa_builder = cranelift::native::builder()
224            .unwrap_or_else(|_| isa::lookup_by_name("aarch64").unwrap());
225        let isa = isa_builder
226            .finish(settings::Flags::new(flag_builder))
227            .unwrap();
228        let obj_builder = ObjectBuilder::new(
229            isa,
230            "ling_program",
231            cranelift_module::default_libcall_names(),
232        )
233        .expect("ObjectBuilder");
234        let module = ObjectModule::new(obj_builder);
235        Self {
236            module: Some(module),
237            builder_ctx: FunctionBuilderContext::new(),
238        }
239    }
240}
241
242impl CodegenBackend for CraneliftBackend {
243    fn emit(&mut self, program: &MirProgram, out: &std::path::Path) -> Result<()> {
244        let module: &mut ObjectModule = self.module.as_mut().unwrap();
245
246        let num_types = numtype::analyze(&program.mir.functions);
247
248        // Phase 0: declare all runtime functions as imports
249        let runtime_decls = declare_runtime_functions(module);
250
251        // Phase 1: collect and declare string/builtin data objects
252        let (string_ids, builtin_ids) = collect_string_constants(&program.mir.functions, module);
253
254        // Phase 2: declare all user functions as exports
255        let mut func_ids: HashMap<String, FuncId> = HashMap::new();
256        for func in &program.mir.functions {
257            let mut sig = module.make_signature();
258            for _ in 0..func.arg_count {
259                sig.params.push(AbiParam::new(types::I64));
260            }
261            sig.returns.push(AbiParam::new(types::I64));
262            let id = module
263                .declare_function(&func.name, Linkage::Export, &sig)
264                .unwrap();
265            func_ids.insert(func.name.clone(), id);
266        }
267
268        // Phase 3: translate each function body
269        for func in &program.mir.functions {
270            let &fid = func_ids.get(&func.name).unwrap();
271
272            let mut ctx = module.make_context();
273            let mut sig = module.make_signature();
274            for _ in 0..func.arg_count {
275                sig.params.push(AbiParam::new(types::I64));
276            }
277            sig.returns.push(AbiParam::new(types::I64));
278            ctx.func.signature = sig;
279
280            let mut builder = FunctionBuilder::new(&mut ctx.func, &mut self.builder_ctx);
281            let blocks: Vec<Block> = func
282                .basic_blocks
283                .iter()
284                .map(|_| builder.create_block())
285                .collect();
286
287            let max_local = max_local_index(func);
288            let mut vars: HashMap<Local, Variable> = HashMap::new();
289            for i in 0..=max_local {
290                vars.insert(Local(i), builder.declare_var(types::I64));
291            }
292
293            // Build string GlobalValue map for this function
294            let mut string_gvs: HashMap<String, GlobalValue> = HashMap::new();
295            for (s, &data_id) in &string_ids {
296                let gv = module.declare_data_in_func(data_id, builder.func);
297                string_gvs.insert(s.clone(), gv);
298            }
299
300            // Build builtin name GlobalValue map for this function
301            let mut builtin_gvs: HashMap<String, GlobalValue> = HashMap::new();
302            for (s, &data_id) in &builtin_ids {
303                let gv = module.declare_data_in_func(data_id, builder.func);
304                builtin_gvs.insert(s.clone(), gv);
305            }
306
307            // Build runtime FuncRef map for this function. The shared translator
308            // looks up runtime helpers by their JIT symbol name (`__ling_*`); the
309            // object backend declares them without that prefix, so re-key here.
310            let mut runtime_refs: HashMap<String, FuncRef> = HashMap::new();
311            for (name, decl) in &runtime_decls {
312                let fr = module.declare_func_in_func(decl.id, builder.func);
313                runtime_refs.insert(format!("__{name}"), fr);
314            }
315
316            // FuncRefs are function-local, so user-function references must be
317            // declared into this builder (not shared across functions).
318            let mut func_refs: HashMap<String, FuncRef> = HashMap::new();
319            for (name, &id) in &func_ids {
320                let fr = module.declare_func_in_func(id, builder.func);
321                func_refs.insert(name.clone(), fr);
322            }
323
324            let tctx = TransCtx {
325                vars: &vars,
326                string_gvs: &string_gvs,
327                builtin_gvs: &builtin_gvs,
328                runtime_refs: &runtime_refs,
329                func_refs: &func_refs,
330                nt: &num_types,
331                fname: &func.name,
332            };
333            build_function_body(&mut builder, func, &blocks, &tctx);
334            builder.finalize();
335            module.define_function(fid, &mut ctx).unwrap();
336        }
337
338        // Phase 4: finalize and emit .o file
339        let obj = self.module.take().unwrap().finish();
340        let bytes = obj.emit().map_err(|e| anyhow::anyhow!("{:?}", e))?;
341        std::fs::write(out, bytes)?;
342        Ok(())
343    }
344}
345
346impl Default for CraneliftBackend {
347    fn default() -> Self {
348        Self::new()
349    }
350}