use super::*;
use inkwell::OptimizationLevel;
use inkwell::context::Context;
use inkwell::targets::{InitializationConfig, Target};
/// Reports whether every instruction in `chunk` belongs to the numeric-only subset the LLVM
/// backend can lower.
///
/// Accepted opcodes are the register/register and register/constant float arithmetic forms,
/// `OP_MOVE`, `OP_NEG`, `OP_RET`, and `OP_LOADK` when its 16-bit `Bx` operand names an
/// in-range `Value::Number` constant. Any other opcode makes the whole chunk ineligible.
pub(crate) fn is_jit_eligible(chunk: &Chunk) -> bool {
    chunk.code.iter().all(|&inst| match (inst >> 24) as u8 {
        OP_ADD_NN | OP_SUB_NN | OP_MUL_NN | OP_DIV_NN | OP_ADDK_N | OP_SUBK_N | OP_MULK_N
        | OP_DIVK_N | OP_MOVE | OP_NEG | OP_RET => true,
        // Out-of-range indices and non-numeric constants both fail the pattern.
        OP_LOADK => matches!(
            chunk.constants.get((inst & 0xFFFF) as usize),
            Some(Value::Number(_))
        ),
        _ => false,
    })
}
/// A natively compiled numeric function produced by [`compile`].
///
/// Pairs the raw entry point of the generated machine code with the LLVM `Context` it was
/// generated in. Invoke it through [`call`], which checks the argument count against
/// `param_count` before jumping to `func_ptr`.
pub(crate) struct JitFunction {
    /// Entry address of the compiled `jit_func` (f64 parameters, f64 result).
    func_ptr: *const u8,
    /// Number of f64 parameters the compiled function expects.
    param_count: usize,
    /// LLVM context the function was built in; owned for as long as this handle exists.
    _context: Context,
}
// SAFETY: `JitFunction` does not implement `Sync`, so sending it moves exclusive ownership of the
// context and the code pointer to one other thread; the pointer targets already-generated code.
// NOTE(review): relies on an LLVM context being usable from any single thread at a time — confirm.
unsafe impl Send for JitFunction {}
pub(crate) fn compile(chunk: &Chunk, nan_consts: &[NanVal]) -> Option<JitFunction> {
if !is_jit_eligible(chunk) {
return None;
}
Target::initialize_native(&InitializationConfig::default()).ok()?;
let context = Context::create();
let module = context.create_module("jit");
let builder = context.create_builder();
let f64_type = context.f64_type();
let param_types: Vec<_> = (0..chunk.param_count).map(|_| f64_type.into()).collect();
let fn_type = f64_type.fn_type(¶m_types, false);
let function = module.add_function("jit_func", fn_type, None);
let entry = context.append_basic_block(function, "entry");
builder.position_at_end(entry);
let reg_count = chunk.reg_count.max(chunk.param_count) as usize;
let mut regs: Vec<inkwell::values::FloatValue> = Vec::with_capacity(reg_count);
for i in 0..chunk.param_count as usize {
regs.push(function.get_nth_param(i as u32).unwrap().into_float_value());
}
for _ in chunk.param_count as usize..reg_count {
regs.push(f64_type.const_float(0.0));
}
for &inst in &chunk.code {
let op = (inst >> 24) as u8;
let a = ((inst >> 16) & 0xFF) as usize;
let b = ((inst >> 8) & 0xFF) as usize;
let c = (inst & 0xFF) as usize;
match op {
OP_ADD_NN => {
let result = builder.build_float_add(regs[b], regs[c], "add").ok()?;
regs[a] = result;
}
OP_SUB_NN => {
let result = builder.build_float_sub(regs[b], regs[c], "sub").ok()?;
regs[a] = result;
}
OP_MUL_NN => {
let result = builder.build_float_mul(regs[b], regs[c], "mul").ok()?;
regs[a] = result;
}
OP_DIV_NN => {
let result = builder.build_float_div(regs[b], regs[c], "div").ok()?;
regs[a] = result;
}
OP_ADDK_N => {
let kv = nan_consts.get(c)?.as_number();
let kval = f64_type.const_float(kv);
let result = builder.build_float_add(regs[b], kval, "addk").ok()?;
regs[a] = result;
}
OP_SUBK_N => {
let kv = nan_consts.get(c)?.as_number();
let kval = f64_type.const_float(kv);
let result = builder.build_float_sub(regs[b], kval, "subk").ok()?;
regs[a] = result;
}
OP_MULK_N => {
let kv = nan_consts.get(c)?.as_number();
let kval = f64_type.const_float(kv);
let result = builder.build_float_mul(regs[b], kval, "mulk").ok()?;
regs[a] = result;
}
OP_DIVK_N => {
let kv = nan_consts.get(c)?.as_number();
let kval = f64_type.const_float(kv);
let result = builder.build_float_div(regs[b], kval, "divk").ok()?;
regs[a] = result;
}
OP_LOADK => {
let bx = (inst & 0xFFFF) as usize;
let val = match &chunk.constants[bx] {
Value::Number(n) => *n,
_ => return None,
};
regs[a] = f64_type.const_float(val);
}
OP_MOVE => {
if a != b {
regs[a] = regs[b];
}
}
OP_NEG => {
let result = builder.build_float_neg(regs[b], "neg").ok()?;
regs[a] = result;
}
OP_RET => {
builder.build_return(Some(®s[a])).ok()?;
}
_ => return None,
}
}
let engine = module
.create_jit_execution_engine(OptimizationLevel::Aggressive)
.ok()?;
let func_ptr = engine.get_function_address("jit_func").ok()? as *const u8;
std::mem::forget(engine);
Some(JitFunction {
_context: context,
func_ptr,
param_count: chunk.param_count as usize,
})
}
/// Invokes a JIT-compiled function with `args` as its f64 parameters.
///
/// Returns `None`, without calling anything, when `args.len()` differs from the function's
/// parameter count or exceeds eight (only arities 0..=8 have a native signature below).
/// Otherwise returns `Some` with the function's f64 result.
pub(crate) fn call(func: &JitFunction, args: &[f64]) -> Option<f64> {
    if args.len() != func.param_count {
        return None;
    }
    let ptr = func.func_ptr;
    // SAFETY (every arm): `ptr` is the entry point `compile` produced for a function with
    // `param_count` (== args.len()) f64 parameters and an f64 result, declared without an
    // explicit calling convention (LLVM's default C convention), so reinterpreting it as the
    // matching `extern "C" fn` type is sound.
    let result = match *args {
        [] => {
            let f: extern "C" fn() -> f64 = unsafe { std::mem::transmute(ptr) };
            f()
        }
        [a0] => {
            let f: extern "C" fn(f64) -> f64 = unsafe { std::mem::transmute(ptr) };
            f(a0)
        }
        [a0, a1] => {
            let f: extern "C" fn(f64, f64) -> f64 = unsafe { std::mem::transmute(ptr) };
            f(a0, a1)
        }
        [a0, a1, a2] => {
            let f: extern "C" fn(f64, f64, f64) -> f64 = unsafe { std::mem::transmute(ptr) };
            f(a0, a1, a2)
        }
        [a0, a1, a2, a3] => {
            let f: extern "C" fn(f64, f64, f64, f64) -> f64 =
                unsafe { std::mem::transmute(ptr) };
            f(a0, a1, a2, a3)
        }
        [a0, a1, a2, a3, a4] => {
            let f: extern "C" fn(f64, f64, f64, f64, f64) -> f64 =
                unsafe { std::mem::transmute(ptr) };
            f(a0, a1, a2, a3, a4)
        }
        [a0, a1, a2, a3, a4, a5] => {
            let f: extern "C" fn(f64, f64, f64, f64, f64, f64) -> f64 =
                unsafe { std::mem::transmute(ptr) };
            f(a0, a1, a2, a3, a4, a5)
        }
        [a0, a1, a2, a3, a4, a5, a6] => {
            let f: extern "C" fn(f64, f64, f64, f64, f64, f64, f64) -> f64 =
                unsafe { std::mem::transmute(ptr) };
            f(a0, a1, a2, a3, a4, a5, a6)
        }
        [a0, a1, a2, a3, a4, a5, a6, a7] => {
            let f: extern "C" fn(f64, f64, f64, f64, f64, f64, f64, f64) -> f64 =
                unsafe { std::mem::transmute(ptr) };
            f(a0, a1, a2, a3, a4, a5, a6, a7)
        }
        _ => return None,
    };
    Some(result)
}
/// JIT-compiles `chunk` and immediately runs it with `args`.
///
/// Returns `None` if [`compile`] declines the chunk or [`call`] rejects the arguments. A fresh
/// [`JitFunction`] is built and dropped on every invocation, so callers running the same chunk
/// repeatedly should call [`compile`] once and reuse the handle.
pub(crate) fn compile_and_call(chunk: &Chunk, nan_consts: &[NanVal], args: &[f64]) -> Option<f64> {
    compile(chunk, nan_consts).and_then(|jitted| call(&jitted, args))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer;
    use crate::parser;

    /// Lexes `source`, keeping only the tokens and discarding their spans.
    fn lex_tokens(source: &str) -> Vec<crate::lexer::Token> {
        lexer::lex(source)
            .unwrap()
            .into_iter()
            .map(|(token, _)| token)
            .collect()
    }

    /// Parses and bytecode-compiles `source`, then JIT-runs `func_name` with `args`.
    ///
    /// Returns `None` if `func_name` is not defined or the JIT declines or rejects the call.
    fn jit_run_numeric(source: &str, func_name: &str, args: &[f64]) -> Option<f64> {
        let program = parser::parse_tokens(lex_tokens(source)).unwrap();
        let compiled = crate::vm::compile(&program).unwrap();
        let idx = compiled.func_names.iter().position(|name| name == func_name)?;
        compile_and_call(&compiled.chunks[idx], &compiled.nan_constants[idx], args)
    }

    #[test]
    fn llvm_add_nn() {
        assert_eq!(jit_run_numeric("f a:n b:n>n;+a b", "f", &[3.0, 7.0]), Some(10.0));
    }

    #[test]
    fn llvm_sub_nn() {
        assert_eq!(jit_run_numeric("f a:n b:n>n;-a b", "f", &[10.0, 3.0]), Some(7.0));
    }

    #[test]
    fn llvm_mul_nn() {
        assert_eq!(jit_run_numeric("f a:n b:n>n;*a b", "f", &[4.0, 5.0]), Some(20.0));
    }

    #[test]
    fn llvm_div_nn() {
        assert_eq!(jit_run_numeric("f a:n b:n>n;/a b", "f", &[10.0, 2.0]), Some(5.0));
    }

    #[test]
    fn llvm_neg() {
        assert_eq!(jit_run_numeric("f x:n>n;-x", "f", &[5.0]), Some(-5.0));
    }

    #[test]
    fn llvm_addk_n() {
        assert_eq!(jit_run_numeric("f x:n>n;+x 10", "f", &[5.0]), Some(15.0));
    }

    #[test]
    fn llvm_subk_n() {
        assert_eq!(jit_run_numeric("f x:n>n;-x 3", "f", &[10.0]), Some(7.0));
    }

    #[test]
    fn llvm_mulk_n() {
        assert_eq!(jit_run_numeric("f x:n>n;*x 4", "f", &[5.0]), Some(20.0));
    }

    #[test]
    fn llvm_divk_n() {
        assert_eq!(jit_run_numeric("f x:n>n;/x 4", "f", &[20.0]), Some(5.0));
    }

    #[test]
    fn llvm_loadk_constant() {
        assert_eq!(jit_run_numeric("f>n;42", "f", &[]), Some(42.0));
    }

    #[test]
    fn llvm_move_passthrough() {
        assert_eq!(jit_run_numeric("f x:n>n;x", "f", &[7.0]), Some(7.0));
    }

    #[test]
    fn llvm_move_via_let_binding() {
        assert_eq!(jit_run_numeric("f x:n>n;y=x;y", "f", &[7.0]), Some(7.0));
    }

    #[test]
    fn llvm_zero_args() {
        assert_eq!(jit_run_numeric("f>n;99", "f", &[]), Some(99.0));
    }

    #[test]
    fn llvm_two_args() {
        assert_eq!(jit_run_numeric("f a:n b:n>n;+a b", "f", &[3.0, 4.0]), Some(7.0));
    }

    #[test]
    fn llvm_four_args() {
        assert_eq!(
            jit_run_numeric("f a:n b:n c:n d:n>n;+a +b +c d", "f", &[1.0, 2.0, 3.0, 4.0]),
            Some(10.0)
        );
    }

    #[test]
    fn llvm_arg_mismatch_returns_none() {
        let program = parser::parse_tokens(lex_tokens("f a:n b:n>n;+a b")).unwrap();
        let compiled = crate::vm::compile(&program).unwrap();
        let idx = compiled.func_names.iter().position(|name| name == "f").unwrap();
        let jitted = compile(&compiled.chunks[idx], &compiled.nan_constants[idx]).unwrap();
        // A two-parameter function called with a single argument must be refused.
        assert_eq!(call(&jitted, &[1.0]), None);
    }

    #[test]
    fn llvm_ineligible_function_returns_none() {
        let program = parser::parse_tokens(lex_tokens("f a:t b:t>t;+a b")).unwrap();
        let compiled = crate::vm::compile(&program).unwrap();
        let idx = compiled.func_names.iter().position(|name| name == "f").unwrap();
        // Text-typed operands fall outside the numeric subset the JIT accepts.
        assert!(compile(&compiled.chunks[idx], &compiled.nan_constants[idx]).is_none());
    }

    #[test]
    fn llvm_compound_arithmetic() {
        assert_eq!(jit_run_numeric("f a:n b:n>n;* +a b -a b", "f", &[5.0, 3.0]), Some(16.0));
    }

    #[test]
    fn llvm_nested_constants() {
        assert_eq!(jit_run_numeric("f x:n>n;+ *x 2 10", "f", &[5.0]), Some(20.0));
    }
}