Skip to main content

just_engine/runner/jit/
reg_jit.rs

1//! Cranelift-based JIT for register bytecode (numeric prototype).
2
3use std::mem;
4
5use cranelift::prelude::{types, AbiParam, FunctionBuilder, FunctionBuilderContext, InstBuilder, MemFlags, Variable, FloatCC};
6use cranelift_jit::{JITBuilder, JITModule};
7use cranelift_module::{default_libcall_names, FuncId, Linkage, Module};
8
9use crate::runner::ds::error::JErrorType;
10use crate::runner::ds::value::{JsNumberType, JsValue};
11
12use super::reg_bytecode::{RegChunk, RegOpCode};
13
/// Native entry-point signature of a compiled register chunk.
///
/// The single argument points at the first `f64` slot of the register file; the
/// return value is the chunk's numeric result.
pub type RegJitFn = unsafe extern "C" fn(registers: *mut f64) -> f64;
15
16pub struct RegJitFunction {
17    _func_id: FuncId,
18    code_ptr: *const u8,
19}
20
21impl RegJitFunction {
22    pub unsafe fn call(&self, registers: &mut [f64]) -> f64 {
23        let func: RegJitFn = mem::transmute(self.code_ptr);
24        func(registers.as_mut_ptr())
25    }
26}
27
28pub struct RegJit {
29    module: JITModule,
30}
31
32impl RegJit {
33    pub fn new() -> Result<Self, JErrorType> {
34        if !cfg!(target_arch = "x86_64") {
35            return Err(JErrorType::TypeError(
36                "Cranelift JIT prototype is only supported on x86_64".to_string(),
37            ));
38        }
39        let builder = JITBuilder::new(default_libcall_names())
40            .map_err(|e| JErrorType::TypeError(format!("Cranelift init failed: {}", e)))?;
41        let module = JITModule::new(builder);
42        Ok(RegJit { module })
43    }
44
45    /// Pre-scan: return Err if the chunk contains ops the numeric JIT cannot handle.
46    fn can_jit_compile(chunk: &RegChunk) -> Result<(), JErrorType> {
47        for instr in &chunk.code {
48            match instr.op {
49                RegOpCode::GetProp
50                | RegOpCode::SetProp
51                | RegOpCode::GetElem
52                | RegOpCode::SetElem
53                | RegOpCode::Call
54                | RegOpCode::CallMethod
55                | RegOpCode::TypeOf => {
56                    return Err(JErrorType::TypeError(format!(
57                        "JIT bail: unsupported op {:?}",
58                        instr.op
59                    )));
60                }
61                _ => {}
62            }
63        }
64        Ok(())
65    }
66
67    pub fn compile(&mut self, chunk: &RegChunk) -> Result<RegJitFunction, JErrorType> {
68        Self::can_jit_compile(chunk)?;
69        let ptr_ty = self.module.target_config().pointer_type();
70        let mut sig = self.module.make_signature();
71        sig.params.push(AbiParam::new(ptr_ty));
72        sig.returns.push(AbiParam::new(types::F64));
73
74        let func_id = self
75            .module
76            .declare_function("reg_jit_entry", Linkage::Local, &sig)
77            .map_err(|e| JErrorType::TypeError(format!("declare_function failed: {}", e)))?;
78
79        let mut ctx = self.module.make_context();
80        ctx.func.signature = sig;
81        let mut builder_ctx = FunctionBuilderContext::new();
82        let mut builder = FunctionBuilder::new(&mut ctx.func, &mut builder_ctx);
83
84        let entry_block = builder.create_block();
85        builder.append_block_params_for_function_params(entry_block);
86        builder.switch_to_block(entry_block);
87        builder.seal_block(entry_block);
88
89        let regs_ptr = builder.block_params(entry_block)[0];
90        let reg_var = Variable::from_u32(0);
91        builder.declare_var(reg_var, ptr_ty);
92        builder.def_var(reg_var, regs_ptr);
93
94        let mut local_regs = vec![None; chunk.names.len()];
95        for local in &chunk.locals {
96            if let Some(slot) = local_regs.get_mut(local.name_idx as usize) {
97                *slot = Some(local.reg);
98            }
99        }
100        let mut lexical_names = vec![false; chunk.names.len()];
101        for instr in &chunk.code {
102            if matches!(instr.op, RegOpCode::DeclareLet | RegOpCode::DeclareConst | RegOpCode::InitBinding) {
103                if let Some(slot) = lexical_names.get_mut(instr.imm as usize) {
104                    *slot = true;
105                }
106            }
107        }
108
109        let mut blocks = Vec::with_capacity(chunk.code.len());
110        for _ in 0..chunk.code.len() {
111            blocks.push(builder.create_block());
112        }
113        let exit_block = builder.create_block();
114
115        builder.ins().jump(blocks[0], &[]);
116
117        for (idx, instr) in chunk.code.iter().enumerate() {
118            let block = blocks[idx];
119            builder.switch_to_block(block);
120
121            let next_block = if idx + 1 < blocks.len() {
122                Some(blocks[idx + 1])
123            } else {
124                None
125            };
126
127            match instr.op {
128                RegOpCode::LoadConst => {
129                    let value = chunk
130                        .constants
131                        .get(instr.imm as usize)
132                        .ok_or_else(|| JErrorType::TypeError("Const out of range".to_string()))?;
133                    let num = match value {
134                        JsValue::Number(n) => self.num_to_f64(n),
135                        _ => {
136                            return Err(JErrorType::TypeError(
137                                "JIT supports only numeric constants".to_string(),
138                            ))
139                        }
140                    };
141                    let val = builder.ins().f64const(num);
142                    self.store_reg(&mut builder, reg_var, instr.dst, val);
143                }
144                RegOpCode::LoadUndefined => {
145                    let val = builder.ins().f64const(f64::NAN);
146                    self.store_reg(&mut builder, reg_var, instr.dst, val);
147                }
148                RegOpCode::LoadNull | RegOpCode::LoadFalse => {
149                    let val = builder.ins().f64const(0.0);
150                    self.store_reg(&mut builder, reg_var, instr.dst, val);
151                }
152                RegOpCode::LoadTrue => {
153                    let val = builder.ins().f64const(1.0);
154                    self.store_reg(&mut builder, reg_var, instr.dst, val);
155                }
156                RegOpCode::Move => {
157                    let val = self.load_reg(&mut builder, reg_var, instr.src1);
158                    self.store_reg(&mut builder, reg_var, instr.dst, val);
159                }
160                RegOpCode::Add | RegOpCode::Sub | RegOpCode::Mul | RegOpCode::Div => {
161                    let a = self.load_reg(&mut builder, reg_var, instr.src1);
162                    let b = self.load_reg(&mut builder, reg_var, instr.src2);
163                    let res = match instr.op {
164                        RegOpCode::Add => builder.ins().fadd(a, b),
165                        RegOpCode::Sub => builder.ins().fsub(a, b),
166                        RegOpCode::Mul => builder.ins().fmul(a, b),
167                        RegOpCode::Div => builder.ins().fdiv(a, b),
168                        _ => unreachable!(),
169                    };
170                    self.store_reg(&mut builder, reg_var, instr.dst, res);
171                }
172                RegOpCode::GetVar => {
173                    let name_idx = instr.imm as usize;
174                    if name_idx >= local_regs.len() || lexical_names[name_idx] {
175                        return Err(JErrorType::TypeError(
176                            "JIT supports only local var GetVar".to_string(),
177                        ));
178                    }
179                    let local_reg = local_regs[name_idx].ok_or_else(|| {
180                        JErrorType::TypeError("JIT supports only local var GetVar".to_string())
181                    })?;
182                    if instr.dst != local_reg {
183                        let val = self.load_reg(&mut builder, reg_var, local_reg);
184                        self.store_reg(&mut builder, reg_var, instr.dst, val);
185                    }
186                }
187                RegOpCode::SetVar => {
188                    let name_idx = instr.imm as usize;
189                    if name_idx >= local_regs.len() || lexical_names[name_idx] {
190                        return Err(JErrorType::TypeError(
191                            "JIT supports only local var SetVar".to_string(),
192                        ));
193                    }
194                    let local_reg = local_regs[name_idx].ok_or_else(|| {
195                        JErrorType::TypeError("JIT supports only local var SetVar".to_string())
196                    })?;
197                    if instr.src1 != local_reg {
198                        let val = self.load_reg(&mut builder, reg_var, instr.src1);
199                        self.store_reg(&mut builder, reg_var, local_reg, val);
200                    }
201                }
202                RegOpCode::DeclareVar | RegOpCode::DeclareLet | RegOpCode::DeclareConst => {
203                    let name_idx = instr.imm as usize;
204                    if name_idx >= local_regs.len() {
205                        return Err(JErrorType::TypeError(
206                            "JIT supports only local var Declare*".to_string(),
207                        ));
208                    }
209                    let _local_reg = local_regs[name_idx].ok_or_else(|| {
210                        JErrorType::TypeError("JIT supports only local var Declare*".to_string())
211                    })?;
212                }
213                RegOpCode::InitVar | RegOpCode::InitBinding => {
214                    let name_idx = instr.imm as usize;
215                    if name_idx >= local_regs.len() {
216                        return Err(JErrorType::TypeError(
217                            "JIT supports only local var Init*".to_string(),
218                        ));
219                    }
220                    let local_reg = local_regs[name_idx].ok_or_else(|| {
221                        JErrorType::TypeError("JIT supports only local var Init*".to_string())
222                    })?;
223                    if instr.src1 != local_reg {
224                        let val = self.load_reg(&mut builder, reg_var, instr.src1);
225                        self.store_reg(&mut builder, reg_var, local_reg, val);
226                    }
227                }
228                RegOpCode::Mod => {
229                    let a = self.load_reg(&mut builder, reg_var, instr.src1);
230                    let b = self.load_reg(&mut builder, reg_var, instr.src2);
231                    let zero = builder.ins().f64const(0.0);
232                    let b_is_zero = builder.ins().fcmp(FloatCC::Equal, b, zero);
233                    let b_is_nan = builder.ins().fcmp(FloatCC::Unordered, b, b);
234                    let a_is_nan = builder.ins().fcmp(FloatCC::Unordered, a, a);
235                    let b_bad = builder.ins().bor(b_is_zero, b_is_nan);
236                    let is_bad = builder.ins().bor(b_bad, a_is_nan);
237
238                    let ok_block = builder.create_block();
239                    let bad_block = builder.create_block();
240                    let cont_block = next_block.unwrap_or(exit_block);
241
242                    builder.ins().brif(is_bad, bad_block, &[], ok_block, &[]);
243                    builder.seal_block(block);
244
245                    builder.switch_to_block(ok_block);
246                    let div = builder.ins().fdiv(a, b);
247                    let quot = builder.ins().fcvt_to_sint(types::I64, div);
248                    let quot_f = builder.ins().fcvt_from_sint(types::F64, quot);
249                    let prod = builder.ins().fmul(quot_f, b);
250                    let res = builder.ins().fsub(a, prod);
251                    self.store_reg(&mut builder, reg_var, instr.dst, res);
252                    builder.ins().jump(cont_block, &[]);
253                    builder.seal_block(ok_block);
254
255                    builder.switch_to_block(bad_block);
256                    let nan = builder.ins().f64const(f64::NAN);
257                    self.store_reg(&mut builder, reg_var, instr.dst, nan);
258                    builder.ins().jump(cont_block, &[]);
259                    builder.seal_block(bad_block);
260                    continue;
261                }
262                RegOpCode::Negate => {
263                    let a = self.load_reg(&mut builder, reg_var, instr.src1);
264                    let zero = builder.ins().f64const(0.0);
265                    let res = builder.ins().fsub(zero, a);
266                    self.store_reg(&mut builder, reg_var, instr.dst, res);
267                }
268                RegOpCode::UnaryPlus => {
269                    let val = self.load_reg(&mut builder, reg_var, instr.src1);
270                    self.store_reg(&mut builder, reg_var, instr.dst, val);
271                }
272                RegOpCode::Not => {
273                    let val = self.load_reg(&mut builder, reg_var, instr.src1);
274                    let truthy = self.emit_truthy(&mut builder, val);
275                    let one = builder.ins().f64const(1.0);
276                    let zero = builder.ins().f64const(0.0);
277                    let res = builder.ins().select(truthy, zero, one);
278                    self.store_reg(&mut builder, reg_var, instr.dst, res);
279                }
280                RegOpCode::BitAnd | RegOpCode::BitOr | RegOpCode::BitXor => {
281                    let a = self.load_reg(&mut builder, reg_var, instr.src1);
282                    let b = self.load_reg(&mut builder, reg_var, instr.src2);
283                    let a_i32 = self.emit_to_i32(&mut builder, a);
284                    let b_i32 = self.emit_to_i32(&mut builder, b);
285                    let res_i32 = match instr.op {
286                        RegOpCode::BitAnd => builder.ins().band(a_i32, b_i32),
287                        RegOpCode::BitOr => builder.ins().bor(a_i32, b_i32),
288                        RegOpCode::BitXor => builder.ins().bxor(a_i32, b_i32),
289                        _ => unreachable!(),
290                    };
291                    let res = self.emit_from_i32(&mut builder, res_i32);
292                    self.store_reg(&mut builder, reg_var, instr.dst, res);
293                }
294                RegOpCode::BitNot => {
295                    let val = self.load_reg(&mut builder, reg_var, instr.src1);
296                    let a_i32 = self.emit_to_i32(&mut builder, val);
297                    let res_i32 = builder.ins().bnot(a_i32);
298                    let res = self.emit_from_i32(&mut builder, res_i32);
299                    self.store_reg(&mut builder, reg_var, instr.dst, res);
300                }
301                RegOpCode::ShiftLeft | RegOpCode::ShiftRight | RegOpCode::UShiftRight => {
302                    let a = self.load_reg(&mut builder, reg_var, instr.src1);
303                    let b = self.load_reg(&mut builder, reg_var, instr.src2);
304                    let mask = builder.ins().iconst(types::I32, 0x1f);
305                    let b_u32 = self.emit_to_u32(&mut builder, b);
306                    let shift = builder.ins().band(b_u32, mask);
307                    let res = match instr.op {
308                        RegOpCode::ShiftLeft => {
309                            let a_i32 = self.emit_to_i32(&mut builder, a);
310                            let res_i32 = builder.ins().ishl(a_i32, shift);
311                            self.emit_from_i32(&mut builder, res_i32)
312                        }
313                        RegOpCode::ShiftRight => {
314                            let a_i32 = self.emit_to_i32(&mut builder, a);
315                            let res_i32 = builder.ins().sshr(a_i32, shift);
316                            self.emit_from_i32(&mut builder, res_i32)
317                        }
318                        RegOpCode::UShiftRight => {
319                            let a_u32 = self.emit_to_u32(&mut builder, a);
320                            let res_u32 = builder.ins().ushr(a_u32, shift);
321                            self.emit_from_u32(&mut builder, res_u32)
322                        }
323                        _ => unreachable!(),
324                    };
325                    self.store_reg(&mut builder, reg_var, instr.dst, res);
326                }
327                RegOpCode::LessThan
328                | RegOpCode::LessEqual
329                | RegOpCode::GreaterThan
330                | RegOpCode::GreaterEqual
331                | RegOpCode::Equal
332                | RegOpCode::NotEqual
333                | RegOpCode::StrictEqual
334                | RegOpCode::StrictNotEqual => {
335                    let a = self.load_reg(&mut builder, reg_var, instr.src1);
336                    let b = self.load_reg(&mut builder, reg_var, instr.src2);
337                    let cc = match instr.op {
338                        RegOpCode::LessThan => FloatCC::LessThan,
339                        RegOpCode::LessEqual => FloatCC::LessThanOrEqual,
340                        RegOpCode::GreaterThan => FloatCC::GreaterThan,
341                        RegOpCode::GreaterEqual => FloatCC::GreaterThanOrEqual,
342                        RegOpCode::Equal | RegOpCode::StrictEqual => FloatCC::Equal,
343                        RegOpCode::NotEqual | RegOpCode::StrictNotEqual => FloatCC::NotEqual,
344                        _ => FloatCC::Equal,
345                    };
346                    let cmp = builder.ins().fcmp(cc, a, b);
347                    let one = builder.ins().f64const(1.0);
348                    let zero = builder.ins().f64const(0.0);
349                    let res = builder.ins().select(cmp, one, zero);
350                    self.store_reg(&mut builder, reg_var, instr.dst, res);
351                }
352                RegOpCode::Jump => {
353                    builder.ins().jump(blocks[instr.imm as usize], &[]);
354                    builder.seal_block(block);
355                    continue;
356                }
357                RegOpCode::JumpIfFalse | RegOpCode::JumpIfTrue => {
358                    let val = self.load_reg(&mut builder, reg_var, instr.src1);
359                    let truthy = self.emit_truthy(&mut builder, val);
360                    let target = blocks[instr.imm as usize];
361                    let fallthrough = next_block.unwrap_or(exit_block);
362                    if instr.op == RegOpCode::JumpIfFalse {
363                        builder.ins().brif(truthy, fallthrough, &[], target, &[]);
364                    } else {
365                        builder.ins().brif(truthy, target, &[], fallthrough, &[]);
366                    }
367                    builder.seal_block(block);
368                    continue;
369                }
370                // These ops are rejected by can_jit_compile; unreachable here.
371                RegOpCode::GetProp
372                | RegOpCode::SetProp
373                | RegOpCode::GetElem
374                | RegOpCode::SetElem
375                | RegOpCode::Call
376                | RegOpCode::CallMethod
377                | RegOpCode::TypeOf => {
378                    unreachable!("pre-scan should have rejected {:?}", instr.op);
379                }
380                RegOpCode::Return | RegOpCode::Halt => {
381                    let val = if instr.op == RegOpCode::Return {
382                        self.load_reg(&mut builder, reg_var, instr.src1)
383                    } else {
384                        builder.ins().f64const(0.0)
385                    };
386                    builder.ins().return_(&[val]);
387                    builder.seal_block(block);
388                    continue;
389                }
390                #[allow(unreachable_patterns)]
391                _ => {
392                    return Err(JErrorType::TypeError(format!(
393                        "Unsupported op in JIT prototype: {:?}",
394                        instr.op
395                    )));
396                }
397            }
398
399            if let Some(next) = next_block {
400                builder.ins().jump(next, &[]);
401            } else {
402                builder.ins().jump(exit_block, &[]);
403            }
404            builder.seal_block(block);
405        }
406
407        builder.switch_to_block(exit_block);
408        let zero = builder.ins().f64const(0.0);
409        builder.ins().return_(&[zero]);
410        builder.seal_block(exit_block);
411
412        builder.finalize();
413
414        self.module
415            .define_function(func_id, &mut ctx)
416            .map_err(|e| JErrorType::TypeError(format!("define_function failed: {}", e)))?;
417        self.module.clear_context(&mut ctx);
418        let _ = self.module.finalize_definitions();
419
420        let code_ptr = self.module.get_finalized_function(func_id);
421        Ok(RegJitFunction { _func_id: func_id, code_ptr })
422    }
423
424    pub fn execute(&mut self, chunk: &RegChunk) -> Result<(f64, Vec<f64>), JErrorType> {
425        let func = self.compile(chunk)?;
426        let mut registers = vec![0.0; chunk.register_count as usize];
427        let result = unsafe { func.call(&mut registers) };
428        Ok((result, registers))
429    }
430
431    fn load_reg(&self, builder: &mut FunctionBuilder, reg_var: Variable, reg: u32) -> cranelift::prelude::Value {
432        let ptr = builder.use_var(reg_var);
433        let offset = (reg as i32) * 8;
434        builder.ins().load(types::F64, MemFlags::new(), ptr, offset)
435    }
436
437    fn store_reg(
438        &self,
439        builder: &mut FunctionBuilder,
440        reg_var: Variable,
441        reg: u32,
442        value: cranelift::prelude::Value,
443    ) {
444        let ptr = builder.use_var(reg_var);
445        let offset = (reg as i32) * 8;
446        builder.ins().store(MemFlags::new(), value, ptr, offset);
447    }
448
449    fn emit_truthy(&self, builder: &mut FunctionBuilder, val: cranelift::prelude::Value) -> cranelift::prelude::Value {
450        let zero = builder.ins().f64const(0.0);
451        let is_zero = builder.ins().fcmp(FloatCC::Equal, val, zero);
452        let is_nan = builder.ins().fcmp(FloatCC::Unordered, val, val);
453        let is_false = builder.ins().bor(is_zero, is_nan);
454        builder.ins().bnot(is_false)
455    }
456
457    fn emit_to_i32(&self, builder: &mut FunctionBuilder, val: cranelift::prelude::Value) -> cranelift::prelude::Value {
458        builder.ins().fcvt_to_sint(types::I32, val)
459    }
460
461    fn emit_to_u32(&self, builder: &mut FunctionBuilder, val: cranelift::prelude::Value) -> cranelift::prelude::Value {
462        builder.ins().fcvt_to_uint(types::I32, val)
463    }
464
465    fn emit_from_i32(&self, builder: &mut FunctionBuilder, val: cranelift::prelude::Value) -> cranelift::prelude::Value {
466        builder.ins().fcvt_from_sint(types::F64, val)
467    }
468
469    fn emit_from_u32(&self, builder: &mut FunctionBuilder, val: cranelift::prelude::Value) -> cranelift::prelude::Value {
470        builder.ins().fcvt_from_uint(types::F64, val)
471    }
472
473    fn num_to_f64(&self, n: &JsNumberType) -> f64 {
474        match n {
475            JsNumberType::Integer(i) => *i as f64,
476            JsNumberType::Float(f) => *f,
477            JsNumberType::NaN => f64::NAN,
478            JsNumberType::PositiveInfinity => f64::INFINITY,
479            JsNumberType::NegativeInfinity => f64::NEG_INFINITY,
480        }
481    }
482}