Skip to main content

ling_codegen/cranelift/
jit.rs

1use super::numtype::{self, NumberTypes};
2use super::runtime;
3use super::translate::{build_function_body, max_local_index, TransCtx};
4use crate::MirProgram;
5use anyhow::Result;
6use cranelift::codegen::ir::{FuncRef, GlobalValue, Signature};
7use cranelift::prelude::*;
8use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext};
9use cranelift_jit::{JITBuilder, JITModule};
10use cranelift_module::{DataDescription, FuncId, Linkage, Module};
11use ling_mir::ir::*;
12use std::collections::HashMap;
13
14// ─── JIT Backend ────────────────────────────────────────────────────────────
15
/// JIT backend that compiles MIR functions to native code with Cranelift and runs them
/// in-process.
16pub struct JitBackend {
    /// Cranelift JIT module in which every function and data object is declared and defined.
17    module: JITModule,
    /// Builder context handed to each `FunctionBuilder` created by `translate_function`.
18    builder_ctx: FunctionBuilderContext,
    /// MIR function name -> id of the function declared for it by `compile`.
19    func_ids: HashMap<String, FuncId>,
    /// Runtime symbol name -> (id, signature) as returned by `declare_runtime_functions`.
20    runtime_sigs: HashMap<String, (FuncId, Signature)>,
    /// String constant -> id of the data object holding its bytes (see `collect_strings`).
21    string_data_ids: HashMap<String, cranelift_module::DataId>,
    /// Constant callee name -> id of the data object holding the NUL-terminated name bytes.
22    builtin_data_ids: HashMap<String, cranelift_module::DataId>,
    /// Clone of the MIR functions passed to the most recent `compile` call.
23    functions: Vec<MirFunction>,
    /// Names of all compiled functions, appended in program order; `run_main` searches this
    /// list for the entry point.
24    compiled_names: Vec<String>,
25}
26
27// ─── Runtime function declarations ──────────────────────────────────────────
28
29fn declare_runtime_functions(module: &mut JITModule) -> HashMap<String, (FuncId, Signature)> {
30    use cranelift::codegen::ir::AbiParam;
31
32    let mut sigs = HashMap::new();
33    let runtime_names: &[(&str, &[types::Type], types::Type)] = &[
34        ("__ling_f64_add", &[types::F64, types::F64], types::F64),
35        ("__ling_f64_sub", &[types::F64, types::F64], types::F64),
36        ("__ling_f64_mul", &[types::F64, types::F64], types::F64),
37        ("__ling_f64_div", &[types::F64, types::F64], types::F64),
38        ("__ling_f64_rem", &[types::F64, types::F64], types::F64),
39        ("__ling_f64_neg", &[types::F64], types::F64),
40        ("__ling_f64_eq", &[types::F64, types::F64], types::I64),
41        ("__ling_f64_lt", &[types::F64, types::F64], types::I64),
42        ("__ling_f64_gt", &[types::F64, types::F64], types::I64),
43        ("__ling_f64_le", &[types::F64, types::F64], types::I64),
44        ("__ling_f64_ge", &[types::F64, types::F64], types::I64),
45        ("__ling_sin", &[types::F64], types::F64),
46        ("__ling_cos", &[types::F64], types::F64),
47        ("__ling_sqrt", &[types::F64], types::F64),
48        ("__ling_abs", &[types::F64], types::F64),
49        ("__ling_floor", &[types::F64], types::F64),
50        ("__ling_ceil", &[types::F64], types::F64),
51        ("__ling_round", &[types::F64], types::F64),
52        ("__ling_add", &[types::I64, types::I64], types::I64),
53        ("__ling_sub", &[types::I64, types::I64], types::I64),
54        ("__ling_mul", &[types::I64, types::I64], types::I64),
55        ("__ling_div", &[types::I64, types::I64], types::I64),
56        ("__ling_rem", &[types::I64, types::I64], types::I64),
57        ("__ling_neg", &[types::I64, types::I64], types::I64),
58        ("__ling_eq", &[types::I64, types::I64], types::I64),
59        ("__ling_ne", &[types::I64, types::I64], types::I64),
60        ("__ling_lt", &[types::I64, types::I64], types::I64),
61        ("__ling_le", &[types::I64, types::I64], types::I64),
62        ("__ling_gt", &[types::I64, types::I64], types::I64),
63        ("__ling_ge", &[types::I64, types::I64], types::I64),
64        ("__ling_and", &[types::I64, types::I64], types::I64),
65        ("__ling_or", &[types::I64, types::I64], types::I64),
66        ("__ling_not", &[types::I64], types::I64),
67        ("__ling_bool_to_u64", &[types::I64], types::I64),
68        ("__ling_alloc", &[types::I64], types::I64),
69        ("__ling_free", &[types::I64], types::I64),
70        ("__ling_panic", &[types::I64], types::I64),
71        ("__ling_str_new", &[types::I64, types::I64], types::I64),
72        ("__ling_str_len", &[types::I64], types::I64),
73        ("__ling_str_concat", &[types::I64, types::I64], types::I64),
74        ("__ling_str_eq", &[types::I64, types::I64], types::I64),
75        ("__ling_list_new", &[], types::I64),
76        ("__ling_list_push", &[types::I64, types::I64], types::I64),
77        ("__ling_list_get", &[types::I64, types::I64], types::I64),
78        ("__ling_list_len", &[types::I64], types::I64),
79        (
80            "__ling_struct_new",
81            &[types::I64, types::I64, types::I64, types::I64],
82            types::I64,
83        ),
84        (
85            "__ling_struct_get",
86            &[types::I64, types::I64, types::I64],
87            types::I64,
88        ),
89        ("__ling_print", &[types::I64], types::I64),
90        ("__ling_print_val", &[types::I64], types::I64),
91        ("__ling_print_newline", &[], types::I64),
92        ("__ling_time_now", &[], types::I64),
93        (
94            "__ling_builtin",
95            &[types::I64, types::I64, types::I64, types::I64],
96            types::I64,
97        ),
98    ];
99    for &(name, params, ret) in runtime_names {
100        let mut sig = module.make_signature();
101        for &pt in params {
102            sig.params.push(AbiParam::new(pt));
103        }
104        sig.returns.push(AbiParam::new(ret));
105        let id = module
106            .declare_function(name, Linkage::Import, &sig)
107            .unwrap();
108        sigs.insert(name.to_string(), (id, sig));
109    }
110    sigs
111}
112
113// ─── String constant collection ──────────────────────────────────────────
114
115fn collect_strings(
116    functions: &[MirFunction],
117    module: &mut JITModule,
118) -> (
119    HashMap<String, cranelift_module::DataId>,
120    HashMap<String, cranelift_module::DataId>,
121) {
122    let mut string_ids: HashMap<String, cranelift_module::DataId> = HashMap::new();
123    let mut builtin_ids: HashMap<String, cranelift_module::DataId> = HashMap::new();
124    for func in functions {
125        for bb in &func.basic_blocks {
126            for stmt in &bb.statements {
127                if let StatementKind::Assign(_, rval) = &stmt.kind {
128                    visit_rvalue_strings(rval, module, &mut string_ids);
129                    visit_rvalue_builtin_names(rval, module, &mut builtin_ids);
130                }
131            }
132            if let Some(term) = &bb.terminator {
133                visit_term_strings(term, module, &mut string_ids);
134            }
135        }
136    }
137    (string_ids, builtin_ids)
138}
139
140fn visit_operand_strings(
141    op: &Operand,
142    module: &mut JITModule,
143    string_ids: &mut HashMap<String, cranelift_module::DataId>,
144) {
145    if let Operand::Constant(Constant::Str(s)) = op {
146        if !string_ids.contains_key(s) {
147            let name = format!("__str_{}", string_ids.len());
148            let data_id = module
149                .declare_data(&name, Linkage::Local, true, false)
150                .unwrap();
151            let mut desc = DataDescription::new();
152            desc.define(s.as_bytes().to_vec().into_boxed_slice());
153            desc.set_align(1);
154            module.define_data(data_id, &desc).unwrap();
155            string_ids.insert(s.clone(), data_id);
156        }
157    }
158}
159
160fn visit_rvalue_builtin_names(
161    rval: &Rvalue,
162    module: &mut JITModule,
163    builtin_ids: &mut HashMap<String, cranelift_module::DataId>,
164) {
165    if let Rvalue::Call { func: Operand::Constant(Constant::Function(n)), .. } = rval {
166        if !builtin_ids.contains_key(n) {
167            let name = format!("__builtin_{}", builtin_ids.len());
168            let data_id = module
169                .declare_data(&name, Linkage::Local, true, false)
170                .unwrap();
171            let mut desc = DataDescription::new();
172            let mut bytes = n.as_bytes().to_vec();
173            bytes.push(0);
174            desc.define(bytes.into_boxed_slice());
175            desc.set_align(1);
176            module.define_data(data_id, &desc).unwrap();
177            builtin_ids.insert(n.clone(), data_id);
178        }
179    }
180}
181
182fn visit_rvalue_strings(
183    rval: &Rvalue,
184    module: &mut JITModule,
185    string_ids: &mut HashMap<String, cranelift_module::DataId>,
186) {
187    match rval {
188        Rvalue::Use(op) | Rvalue::UnaryOp(_, op) => visit_operand_strings(op, module, string_ids),
189        Rvalue::BinaryOp(_, lhs, rhs) => {
190            visit_operand_strings(lhs, module, string_ids);
191            visit_operand_strings(rhs, module, string_ids);
192        },
193        Rvalue::Call { args, .. } => {
194            for arg in args {
195                visit_operand_strings(arg, module, string_ids);
196            }
197        },
198        Rvalue::Aggregate(_, ops) => {
199            for op in ops {
200                visit_operand_strings(op, module, string_ids);
201            }
202        },
203        _ => {},
204    }
205}
206
207fn visit_term_strings(
208    term: &Terminator,
209    module: &mut JITModule,
210    string_ids: &mut HashMap<String, cranelift_module::DataId>,
211) {
212    if let TerminatorKind::SwitchInt { discr, .. } = &term.kind {
213        visit_operand_strings(discr, module, string_ids);
214    }
215}
216
217impl JitBackend {
218    /// Create a new JIT backend. Symbols can be registered via `register_symbols_fn`.
219    pub fn new<F>(register_symbols_fn: F) -> Self
220    where
221        F: FnOnce(&mut JITBuilder),
222    {
223        let mut flag_builder = settings::builder();
224        flag_builder.set("use_colocated_libcalls", "false").unwrap();
225        flag_builder.set("is_pic", "false").unwrap();
226        flag_builder.set("opt_level", "speed").unwrap();
227        flag_builder.set("enable_alias_analysis", "true").unwrap();
228        flag_builder.set("enable_verifier", "false").unwrap();
229
230        let isa_builder = cranelift_native::builder()
231            .unwrap_or_else(|msg| panic!("host architecture not supported: {msg}"));
232        let isa = isa_builder
233            .finish(settings::Flags::new(flag_builder))
234            .unwrap_or_else(|msg| panic!("host architecture not supported: {msg}"));
235
236        let mut builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
237        register_symbols_fn(&mut builder);
238        let module = JITModule::new(builder);
239
240        Self {
241            module,
242            builder_ctx: FunctionBuilderContext::new(),
243            func_ids: HashMap::new(),
244            runtime_sigs: HashMap::new(),
245            string_data_ids: HashMap::new(),
246            builtin_data_ids: HashMap::new(),
247            functions: Vec::new(),
248            compiled_names: Vec::new(),
249        }
250    }
251
252    /// Compile all functions in the MIR program into JIT memory.
253    pub fn compile(&mut self, program: &MirProgram) -> Result<()> {
254        let num_types = numtype::analyze(&program.mir.functions);
255        self.runtime_sigs = declare_runtime_functions(&mut self.module);
256
257        let (string_ids, builtin_ids) = collect_strings(&program.mir.functions, &mut self.module);
258        self.string_data_ids = string_ids;
259        self.builtin_data_ids = builtin_ids;
260
261        for func in &program.mir.functions {
262            let mut sig = self.module.make_signature();
263            for _ in 0..func.arg_count {
264                sig.params.push(AbiParam::new(types::I64));
265            }
266            sig.returns.push(AbiParam::new(types::I64));
267            let id = self
268                .module
269                .declare_function(&func.name, Linkage::Export, &sig)
270                .unwrap();
271            self.func_ids.insert(func.name.clone(), id);
272        }
273
274        for func in &program.mir.functions {
275            self.translate_function(func, &num_types);
276        }
277
278        self.module.finalize_definitions().unwrap();
279
280        self.functions = program.mir.functions.clone();
281        for func in &program.mir.functions {
282            self.compiled_names.push(func.name.clone());
283        }
284
285        Ok(())
286    }
287
288    fn translate_function(&mut self, func: &MirFunction, nt: &NumberTypes) {
289        let &fid = self.func_ids.get(&func.name).unwrap();
290        let mut ctx = self.module.make_context();
291        let mut sig = self.module.make_signature();
292        for _ in 0..func.arg_count {
293            sig.params.push(AbiParam::new(types::I64));
294        }
295        sig.returns.push(AbiParam::new(types::I64));
296        ctx.func.signature = sig;
297
298        let mut builder = FunctionBuilder::new(&mut ctx.func, &mut self.builder_ctx);
299        let blocks: Vec<Block> = func
300            .basic_blocks
301            .iter()
302            .map(|_| builder.create_block())
303            .collect();
304        let max_local = max_local_index(func);
305        let mut vars: HashMap<Local, Variable> = HashMap::new();
306        for i in 0..=max_local {
307            vars.insert(Local(i), builder.declare_var(types::I64));
308        }
309
310        let mut string_gvs: HashMap<String, GlobalValue> = HashMap::new();
311        for (s, &data_id) in &self.string_data_ids {
312            let gv = self.module.declare_data_in_func(data_id, builder.func);
313            string_gvs.insert(s.clone(), gv);
314        }
315        let mut builtin_gvs: HashMap<String, GlobalValue> = HashMap::new();
316        for (s, &data_id) in &self.builtin_data_ids {
317            let gv = self.module.declare_data_in_func(data_id, builder.func);
318            builtin_gvs.insert(s.clone(), gv);
319        }
320
321        let mut runtime_refs: HashMap<String, FuncRef> = HashMap::new();
322        for (name, (id, _sig)) in &self.runtime_sigs {
323            let fr = self.module.declare_func_in_func(*id, builder.func);
324            runtime_refs.insert(name.clone(), fr);
325        }
326        let mut func_refs: HashMap<String, FuncRef> = HashMap::new();
327        for (name, &id) in &self.func_ids {
328            let fr = self.module.declare_func_in_func(id, builder.func);
329            func_refs.insert(name.clone(), fr);
330        }
331
332        let tctx = TransCtx {
333            vars: &vars,
334            string_gvs: &string_gvs,
335            builtin_gvs: &builtin_gvs,
336            runtime_refs: &runtime_refs,
337            func_refs: &func_refs,
338            nt,
339            fname: &func.name,
340        };
341        build_function_body(&mut builder, func, &blocks, &tctx);
342        builder.finalize();
343        self.module.define_function(fid, &mut ctx).unwrap();
344    }
345
346    pub fn get_function(&mut self, name: &str) -> Option<*const u8> {
347        let func_id = self.func_ids.get(name)?;
348        Some(self.module.get_finalized_function(*func_id))
349    }
350
351    pub fn run_main(&mut self) -> Result<u64> {
352        let main_name = self
353            .compiled_names
354            .iter()
355            .find(|n| {
356                n.as_str() == "__main__"
357                    || n.as_str() == "main"
358                    || n.as_str() == "start"
359                    || n.as_str() == "เริ่ม"
360            })
361            .cloned()
362            .unwrap_or_else(|| self.compiled_names.first().cloned().unwrap_or_default());
363        if main_name.is_empty() {
364            return Ok(runtime::TAG_UNIT);
365        }
366        match self.get_function(&main_name) {
367            Some(ptr) => {
368                let func: unsafe extern "C" fn() -> u64 = unsafe { std::mem::transmute(ptr) };
369                Ok(unsafe { func() })
370            },
371            None => Ok(runtime::TAG_UNIT),
372        }
373    }
374
375    pub fn run_function(&mut self, name: &str, args: &[u64]) -> Result<u64> {
376        let fn_ptr = match self.get_function(name) {
377            Some(p) => p,
378            None => return Ok(runtime::TAG_UNIT),
379        };
380        unsafe {
381            match args.len() {
382                0 => {
383                    let f: unsafe extern "C" fn() -> u64 = std::mem::transmute(fn_ptr);
384                    Ok(f())
385                },
386                1 => {
387                    let f: unsafe extern "C" fn(u64) -> u64 = std::mem::transmute(fn_ptr);
388                    Ok(f(args[0]))
389                },
390                2 => {
391                    let f: unsafe extern "C" fn(u64, u64) -> u64 = std::mem::transmute(fn_ptr);
392                    Ok(f(args[0], args[1]))
393                },
394                3 => {
395                    let f: unsafe extern "C" fn(u64, u64, u64) -> u64 = std::mem::transmute(fn_ptr);
396                    Ok(f(args[0], args[1], args[2]))
397                },
398                n => {
399                    let f: unsafe extern "C" fn(*const u64, usize) -> u64 =
400                        std::mem::transmute(fn_ptr);
401                    Ok(f(args.as_ptr(), n))
402                },
403            }
404        }
405    }
406}