1use super::numtype::{self, NumberTypes};
2use super::runtime;
3use super::translate::{build_function_body, max_local_index, TransCtx};
4use crate::MirProgram;
5use anyhow::Result;
6use cranelift::codegen::ir::{FuncRef, GlobalValue, Signature};
7use cranelift::prelude::*;
8use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext};
9use cranelift_jit::{JITBuilder, JITModule};
10use cranelift_module::{DataDescription, FuncId, Linkage, Module};
11use ling_mir::ir::*;
12use std::collections::HashMap;
13
/// Cranelift-based JIT backend for Ling MIR programs.
///
/// Every compiled MIR function is declared with `arg_count` `I64` parameters
/// and a single `I64` return value; values cross the JIT boundary as raw `u64`s.
pub struct JitBackend {
    /// Cranelift JIT module that owns all declarations and generated code.
    module: JITModule,
    /// Scratch state reused by `FunctionBuilder` across function translations.
    builder_ctx: FunctionBuilderContext,
    /// Module ids of the user (MIR) functions, keyed by function name.
    func_ids: HashMap<String, FuncId>,
    /// Imported runtime helpers (`__ling_*`), keyed by symbol name: the module
    /// id together with the signature they were declared with.
    runtime_sigs: HashMap<String, (FuncId, Signature)>,
    /// Data objects holding the raw bytes of each string literal, keyed by literal.
    string_data_ids: HashMap<String, cranelift_module::DataId>,
    /// Data objects holding a NUL-terminated copy of each constant call target's
    /// name, keyed by that name.
    builtin_data_ids: HashMap<String, cranelift_module::DataId>,
    /// Copy of the MIR functions passed to the most recent `compile` call.
    functions: Vec<MirFunction>,
    /// Names of functions whose code has been finalized, in compilation order.
    compiled_names: Vec<String>,
}
26
27fn declare_runtime_functions(module: &mut JITModule) -> HashMap<String, (FuncId, Signature)> {
30 use cranelift::codegen::ir::AbiParam;
31
32 let mut sigs = HashMap::new();
33 let runtime_names: &[(&str, &[types::Type], types::Type)] = &[
34 ("__ling_f64_add", &[types::F64, types::F64], types::F64),
35 ("__ling_f64_sub", &[types::F64, types::F64], types::F64),
36 ("__ling_f64_mul", &[types::F64, types::F64], types::F64),
37 ("__ling_f64_div", &[types::F64, types::F64], types::F64),
38 ("__ling_f64_rem", &[types::F64, types::F64], types::F64),
39 ("__ling_f64_neg", &[types::F64], types::F64),
40 ("__ling_f64_eq", &[types::F64, types::F64], types::I64),
41 ("__ling_f64_lt", &[types::F64, types::F64], types::I64),
42 ("__ling_f64_gt", &[types::F64, types::F64], types::I64),
43 ("__ling_f64_le", &[types::F64, types::F64], types::I64),
44 ("__ling_f64_ge", &[types::F64, types::F64], types::I64),
45 ("__ling_sin", &[types::F64], types::F64),
46 ("__ling_cos", &[types::F64], types::F64),
47 ("__ling_sqrt", &[types::F64], types::F64),
48 ("__ling_abs", &[types::F64], types::F64),
49 ("__ling_floor", &[types::F64], types::F64),
50 ("__ling_ceil", &[types::F64], types::F64),
51 ("__ling_round", &[types::F64], types::F64),
52 ("__ling_add", &[types::I64, types::I64], types::I64),
53 ("__ling_sub", &[types::I64, types::I64], types::I64),
54 ("__ling_mul", &[types::I64, types::I64], types::I64),
55 ("__ling_div", &[types::I64, types::I64], types::I64),
56 ("__ling_rem", &[types::I64, types::I64], types::I64),
57 ("__ling_neg", &[types::I64, types::I64], types::I64),
58 ("__ling_eq", &[types::I64, types::I64], types::I64),
59 ("__ling_ne", &[types::I64, types::I64], types::I64),
60 ("__ling_lt", &[types::I64, types::I64], types::I64),
61 ("__ling_le", &[types::I64, types::I64], types::I64),
62 ("__ling_gt", &[types::I64, types::I64], types::I64),
63 ("__ling_ge", &[types::I64, types::I64], types::I64),
64 ("__ling_and", &[types::I64, types::I64], types::I64),
65 ("__ling_or", &[types::I64, types::I64], types::I64),
66 ("__ling_not", &[types::I64], types::I64),
67 ("__ling_bool_to_u64", &[types::I64], types::I64),
68 ("__ling_alloc", &[types::I64], types::I64),
69 ("__ling_free", &[types::I64], types::I64),
70 ("__ling_panic", &[types::I64], types::I64),
71 ("__ling_str_new", &[types::I64, types::I64], types::I64),
72 ("__ling_str_len", &[types::I64], types::I64),
73 ("__ling_str_concat", &[types::I64, types::I64], types::I64),
74 ("__ling_str_eq", &[types::I64, types::I64], types::I64),
75 ("__ling_list_new", &[], types::I64),
76 ("__ling_list_push", &[types::I64, types::I64], types::I64),
77 ("__ling_list_get", &[types::I64, types::I64], types::I64),
78 ("__ling_list_len", &[types::I64], types::I64),
79 (
80 "__ling_struct_new",
81 &[types::I64, types::I64, types::I64, types::I64],
82 types::I64,
83 ),
84 (
85 "__ling_struct_get",
86 &[types::I64, types::I64, types::I64],
87 types::I64,
88 ),
89 ("__ling_print", &[types::I64], types::I64),
90 ("__ling_print_val", &[types::I64], types::I64),
91 ("__ling_print_newline", &[], types::I64),
92 ("__ling_time_now", &[], types::I64),
93 (
94 "__ling_builtin",
95 &[types::I64, types::I64, types::I64, types::I64],
96 types::I64,
97 ),
98 ];
99 for &(name, params, ret) in runtime_names {
100 let mut sig = module.make_signature();
101 for &pt in params {
102 sig.params.push(AbiParam::new(pt));
103 }
104 sig.returns.push(AbiParam::new(ret));
105 let id = module
106 .declare_function(name, Linkage::Import, &sig)
107 .unwrap();
108 sigs.insert(name.to_string(), (id, sig));
109 }
110 sigs
111}
112
113fn collect_strings(
116 functions: &[MirFunction],
117 module: &mut JITModule,
118) -> (
119 HashMap<String, cranelift_module::DataId>,
120 HashMap<String, cranelift_module::DataId>,
121) {
122 let mut string_ids: HashMap<String, cranelift_module::DataId> = HashMap::new();
123 let mut builtin_ids: HashMap<String, cranelift_module::DataId> = HashMap::new();
124 for func in functions {
125 for bb in &func.basic_blocks {
126 for stmt in &bb.statements {
127 if let StatementKind::Assign(_, rval) = &stmt.kind {
128 visit_rvalue_strings(rval, module, &mut string_ids);
129 visit_rvalue_builtin_names(rval, module, &mut builtin_ids);
130 }
131 }
132 if let Some(term) = &bb.terminator {
133 visit_term_strings(term, module, &mut string_ids);
134 }
135 }
136 }
137 (string_ids, builtin_ids)
138}
139
140fn visit_operand_strings(
141 op: &Operand,
142 module: &mut JITModule,
143 string_ids: &mut HashMap<String, cranelift_module::DataId>,
144) {
145 if let Operand::Constant(Constant::Str(s)) = op {
146 if !string_ids.contains_key(s) {
147 let name = format!("__str_{}", string_ids.len());
148 let data_id = module
149 .declare_data(&name, Linkage::Local, true, false)
150 .unwrap();
151 let mut desc = DataDescription::new();
152 desc.define(s.as_bytes().to_vec().into_boxed_slice());
153 desc.set_align(1);
154 module.define_data(data_id, &desc).unwrap();
155 string_ids.insert(s.clone(), data_id);
156 }
157 }
158}
159
160fn visit_rvalue_builtin_names(
161 rval: &Rvalue,
162 module: &mut JITModule,
163 builtin_ids: &mut HashMap<String, cranelift_module::DataId>,
164) {
165 if let Rvalue::Call { func: Operand::Constant(Constant::Function(n)), .. } = rval {
166 if !builtin_ids.contains_key(n) {
167 let name = format!("__builtin_{}", builtin_ids.len());
168 let data_id = module
169 .declare_data(&name, Linkage::Local, true, false)
170 .unwrap();
171 let mut desc = DataDescription::new();
172 let mut bytes = n.as_bytes().to_vec();
173 bytes.push(0);
174 desc.define(bytes.into_boxed_slice());
175 desc.set_align(1);
176 module.define_data(data_id, &desc).unwrap();
177 builtin_ids.insert(n.clone(), data_id);
178 }
179 }
180}
181
182fn visit_rvalue_strings(
183 rval: &Rvalue,
184 module: &mut JITModule,
185 string_ids: &mut HashMap<String, cranelift_module::DataId>,
186) {
187 match rval {
188 Rvalue::Use(op) | Rvalue::UnaryOp(_, op) => visit_operand_strings(op, module, string_ids),
189 Rvalue::BinaryOp(_, lhs, rhs) => {
190 visit_operand_strings(lhs, module, string_ids);
191 visit_operand_strings(rhs, module, string_ids);
192 },
193 Rvalue::Call { args, .. } => {
194 for arg in args {
195 visit_operand_strings(arg, module, string_ids);
196 }
197 },
198 Rvalue::Aggregate(_, ops) => {
199 for op in ops {
200 visit_operand_strings(op, module, string_ids);
201 }
202 },
203 _ => {},
204 }
205}
206
207fn visit_term_strings(
208 term: &Terminator,
209 module: &mut JITModule,
210 string_ids: &mut HashMap<String, cranelift_module::DataId>,
211) {
212 if let TerminatorKind::SwitchInt { discr, .. } = &term.kind {
213 visit_operand_strings(discr, module, string_ids);
214 }
215}
216
217impl JitBackend {
218 pub fn new<F>(register_symbols_fn: F) -> Self
220 where
221 F: FnOnce(&mut JITBuilder),
222 {
223 let mut flag_builder = settings::builder();
224 flag_builder.set("use_colocated_libcalls", "false").unwrap();
225 flag_builder.set("is_pic", "false").unwrap();
226 flag_builder.set("opt_level", "speed").unwrap();
227 flag_builder.set("enable_alias_analysis", "true").unwrap();
228 flag_builder.set("enable_verifier", "false").unwrap();
229
230 let isa_builder = cranelift_native::builder()
231 .unwrap_or_else(|msg| panic!("host architecture not supported: {msg}"));
232 let isa = isa_builder
233 .finish(settings::Flags::new(flag_builder))
234 .unwrap_or_else(|msg| panic!("host architecture not supported: {msg}"));
235
236 let mut builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
237 register_symbols_fn(&mut builder);
238 let module = JITModule::new(builder);
239
240 Self {
241 module,
242 builder_ctx: FunctionBuilderContext::new(),
243 func_ids: HashMap::new(),
244 runtime_sigs: HashMap::new(),
245 string_data_ids: HashMap::new(),
246 builtin_data_ids: HashMap::new(),
247 functions: Vec::new(),
248 compiled_names: Vec::new(),
249 }
250 }
251
252 pub fn compile(&mut self, program: &MirProgram) -> Result<()> {
254 let num_types = numtype::analyze(&program.mir.functions);
255 self.runtime_sigs = declare_runtime_functions(&mut self.module);
256
257 let (string_ids, builtin_ids) = collect_strings(&program.mir.functions, &mut self.module);
258 self.string_data_ids = string_ids;
259 self.builtin_data_ids = builtin_ids;
260
261 for func in &program.mir.functions {
262 let mut sig = self.module.make_signature();
263 for _ in 0..func.arg_count {
264 sig.params.push(AbiParam::new(types::I64));
265 }
266 sig.returns.push(AbiParam::new(types::I64));
267 let id = self
268 .module
269 .declare_function(&func.name, Linkage::Export, &sig)
270 .unwrap();
271 self.func_ids.insert(func.name.clone(), id);
272 }
273
274 for func in &program.mir.functions {
275 self.translate_function(func, &num_types);
276 }
277
278 self.module.finalize_definitions().unwrap();
279
280 self.functions = program.mir.functions.clone();
281 for func in &program.mir.functions {
282 self.compiled_names.push(func.name.clone());
283 }
284
285 Ok(())
286 }
287
288 fn translate_function(&mut self, func: &MirFunction, nt: &NumberTypes) {
289 let &fid = self.func_ids.get(&func.name).unwrap();
290 let mut ctx = self.module.make_context();
291 let mut sig = self.module.make_signature();
292 for _ in 0..func.arg_count {
293 sig.params.push(AbiParam::new(types::I64));
294 }
295 sig.returns.push(AbiParam::new(types::I64));
296 ctx.func.signature = sig;
297
298 let mut builder = FunctionBuilder::new(&mut ctx.func, &mut self.builder_ctx);
299 let blocks: Vec<Block> = func
300 .basic_blocks
301 .iter()
302 .map(|_| builder.create_block())
303 .collect();
304 let max_local = max_local_index(func);
305 let mut vars: HashMap<Local, Variable> = HashMap::new();
306 for i in 0..=max_local {
307 vars.insert(Local(i), builder.declare_var(types::I64));
308 }
309
310 let mut string_gvs: HashMap<String, GlobalValue> = HashMap::new();
311 for (s, &data_id) in &self.string_data_ids {
312 let gv = self.module.declare_data_in_func(data_id, builder.func);
313 string_gvs.insert(s.clone(), gv);
314 }
315 let mut builtin_gvs: HashMap<String, GlobalValue> = HashMap::new();
316 for (s, &data_id) in &self.builtin_data_ids {
317 let gv = self.module.declare_data_in_func(data_id, builder.func);
318 builtin_gvs.insert(s.clone(), gv);
319 }
320
321 let mut runtime_refs: HashMap<String, FuncRef> = HashMap::new();
322 for (name, (id, _sig)) in &self.runtime_sigs {
323 let fr = self.module.declare_func_in_func(*id, builder.func);
324 runtime_refs.insert(name.clone(), fr);
325 }
326 let mut func_refs: HashMap<String, FuncRef> = HashMap::new();
327 for (name, &id) in &self.func_ids {
328 let fr = self.module.declare_func_in_func(id, builder.func);
329 func_refs.insert(name.clone(), fr);
330 }
331
332 let tctx = TransCtx {
333 vars: &vars,
334 string_gvs: &string_gvs,
335 builtin_gvs: &builtin_gvs,
336 runtime_refs: &runtime_refs,
337 func_refs: &func_refs,
338 nt,
339 fname: &func.name,
340 };
341 build_function_body(&mut builder, func, &blocks, &tctx);
342 builder.finalize();
343 self.module.define_function(fid, &mut ctx).unwrap();
344 }
345
346 pub fn get_function(&mut self, name: &str) -> Option<*const u8> {
347 let func_id = self.func_ids.get(name)?;
348 Some(self.module.get_finalized_function(*func_id))
349 }
350
351 pub fn run_main(&mut self) -> Result<u64> {
352 let main_name = self
353 .compiled_names
354 .iter()
355 .find(|n| {
356 n.as_str() == "__main__"
357 || n.as_str() == "main"
358 || n.as_str() == "start"
359 || n.as_str() == "เริ่ม"
360 })
361 .cloned()
362 .unwrap_or_else(|| self.compiled_names.first().cloned().unwrap_or_default());
363 if main_name.is_empty() {
364 return Ok(runtime::TAG_UNIT);
365 }
366 match self.get_function(&main_name) {
367 Some(ptr) => {
368 let func: unsafe extern "C" fn() -> u64 = unsafe { std::mem::transmute(ptr) };
369 Ok(unsafe { func() })
370 },
371 None => Ok(runtime::TAG_UNIT),
372 }
373 }
374
375 pub fn run_function(&mut self, name: &str, args: &[u64]) -> Result<u64> {
376 let fn_ptr = match self.get_function(name) {
377 Some(p) => p,
378 None => return Ok(runtime::TAG_UNIT),
379 };
380 unsafe {
381 match args.len() {
382 0 => {
383 let f: unsafe extern "C" fn() -> u64 = std::mem::transmute(fn_ptr);
384 Ok(f())
385 },
386 1 => {
387 let f: unsafe extern "C" fn(u64) -> u64 = std::mem::transmute(fn_ptr);
388 Ok(f(args[0]))
389 },
390 2 => {
391 let f: unsafe extern "C" fn(u64, u64) -> u64 = std::mem::transmute(fn_ptr);
392 Ok(f(args[0], args[1]))
393 },
394 3 => {
395 let f: unsafe extern "C" fn(u64, u64, u64) -> u64 = std::mem::transmute(fn_ptr);
396 Ok(f(args[0], args[1], args[2]))
397 },
398 n => {
399 let f: unsafe extern "C" fn(*const u64, usize) -> u64 =
400 std::mem::transmute(fn_ptr);
401 Ok(f(args.as_ptr(), n))
402 },
403 }
404 }
405 }
406}