use cranelift_codegen::ir::{AbiParam, Function, InstBuilder, Type, UserFuncName, Value, types};
use cranelift_codegen::settings::{self, Configurable};
use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext};
use cranelift_jit::{JITBuilder, JITModule};
use cranelift_module::{FuncId, Linkage, Module};
use std::collections::HashMap;
use std::time::Instant;
use crate::error::{DSLCompileError, Result};
use crate::final_tagless::{ASTRepr, VariableRegistry};
/// JIT compiler that lowers `ASTRepr<f64>` expressions to native code via Cranelift.
///
/// Compiled expressions are defined into the single `JITModule` held here, and their
/// entry points are obtained from it after finalization.
pub struct CraneliftCompiler {
    /// Module that receives function declarations/definitions and hands out
    /// finalized code pointers.
    module: JITModule,
    /// Builder scratch state passed to every `FunctionBuilder` this compiler creates.
    builder_context: FunctionBuilderContext,
    /// Flags the target ISA was built with.
    /// NOTE(review): not read anywhere in this file's visible code.
    settings: settings::Flags,
    /// Optimization level chosen at construction; copied into each function's metadata.
    opt_level: OptimizationLevel,
}
/// A JIT-compiled expression: native code taking `signature.input_count` `f64`
/// arguments and returning an `f64`, invoked through `call`.
///
/// NOTE(review): `func_ptr` points into code produced by the `JITModule` of the
/// compiler that built this value; confirm that memory stays mapped for every call.
#[derive(Debug)]
pub struct CompiledFunction {
    /// Entry point returned by `JITModule::get_finalized_function`.
    func_ptr: *const u8,
    /// Arity and return type the function was compiled with.
    signature: FunctionSignature,
    /// Statistics recorded while compiling.
    metadata: CompilationMetadata,
}
/// Native signature of a compiled expression.
#[derive(Debug, Clone)]
pub struct FunctionSignature {
    /// Number of `f64` parameters; taken from `VariableRegistry::len()` at compile time.
    pub input_count: usize,
    /// Cranelift type of the return value (`types::F64` for every function compiled in this file).
    pub return_type: Type,
}
/// Statistics recorded while compiling one expression.
#[derive(Debug, Clone)]
pub struct CompilationMetadata {
    /// Wall-clock milliseconds from the start of `compile_expression` until the
    /// finalized code pointer was obtained.
    pub compilation_time_ms: u64,
    /// Optimization level of the compiler that produced the function.
    pub optimization_level: OptimizationLevel,
    /// Result of `ASTRepr::count_operations` on the source expression.
    /// NOTE(review): presumably the number of operator nodes — confirm in `final_tagless`.
    pub expression_complexity: usize,
}
/// Optimization preset for `CraneliftCompiler`, translated into Cranelift settings
/// when the compiler is constructed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum OptimizationLevel {
    /// Cranelift `opt_level = "none"` with the IR verifier disabled.
    None,
    /// Cranelift `opt_level = "speed"` with the IR verifier enabled. This is the default.
    #[default]
    Basic,
    /// Cranelift `opt_level = "speed_and_size"` with the IR verifier and alias
    /// analysis enabled.
    Full,
}
/// Module-level ids of the math helpers imported into the JIT module
/// (declared with `Linkage::Import`).
struct ExternalMathFunctions {
    /// `sin(f64) -> f64`.
    sin_id: FuncId,
    /// `cos(f64) -> f64`.
    cos_id: FuncId,
    /// `exp(f64) -> f64`.
    exp_id: FuncId,
    /// `log(f64) -> f64` (natural logarithm, per the registered wrapper).
    log_id: FuncId,
    /// `pow(f64, f64) -> f64`.
    pow_id: FuncId,
}
/// Function-local references to the imported math helpers, created with
/// `declare_func_in_func`; these are what `call` instructions reference.
struct LocalMathFunctions {
    /// Reference to the imported `sin`.
    sin_ref: cranelift_codegen::ir::FuncRef,
    /// Reference to the imported `cos`.
    cos_ref: cranelift_codegen::ir::FuncRef,
    /// Reference to the imported `exp`.
    exp_ref: cranelift_codegen::ir::FuncRef,
    /// Reference to the imported `log` (natural logarithm).
    log_ref: cranelift_codegen::ir::FuncRef,
    /// Reference to the imported two-argument `pow`.
    pow_ref: cranelift_codegen::ir::FuncRef,
}
/// `extern "C"` shims over the `f64` math methods.
///
/// Their addresses are registered with the JIT under the symbol names
/// `sin`, `cos`, `exp`, `log` and `pow`, so generated code can call them.
mod math_wrappers {
    /// Sine of `angle` (radians).
    pub extern "C" fn sin_wrapper(angle: f64) -> f64 {
        f64::sin(angle)
    }

    /// Cosine of `angle` (radians).
    pub extern "C" fn cos_wrapper(angle: f64) -> f64 {
        f64::cos(angle)
    }

    /// `e` raised to `power`.
    pub extern "C" fn exp_wrapper(power: f64) -> f64 {
        f64::exp(power)
    }

    /// Natural logarithm of `value` (NaN for negative input, -inf at zero).
    pub extern "C" fn log_wrapper(value: f64) -> f64 {
        f64::ln(value)
    }

    /// `base` raised to the floating-point `exponent`.
    pub extern "C" fn pow_wrapper(base: f64, exponent: f64) -> f64 {
        f64::powf(base, exponent)
    }
}
impl CraneliftCompiler {
pub fn new(opt_level: OptimizationLevel) -> Result<Self> {
let mut flag_builder = settings::builder();
match opt_level {
OptimizationLevel::None => {
flag_builder.set("opt_level", "none").unwrap();
flag_builder.set("enable_verifier", "false").unwrap();
}
OptimizationLevel::Basic => {
flag_builder.set("opt_level", "speed").unwrap();
flag_builder.set("enable_verifier", "true").unwrap();
}
OptimizationLevel::Full => {
flag_builder.set("opt_level", "speed_and_size").unwrap();
flag_builder.set("enable_verifier", "true").unwrap();
flag_builder.set("enable_alias_analysis", "true").unwrap();
}
}
flag_builder.set("use_colocated_libcalls", "false").unwrap();
flag_builder.set("is_pic", "false").unwrap();
flag_builder.set("enable_float", "true").unwrap();
let settings = settings::Flags::new(flag_builder);
let isa = cranelift_codegen::isa::lookup(target_lexicon::Triple::host())
.map_err(|e| DSLCompileError::JITError(format!("Failed to create ISA: {e}")))?
.finish(settings.clone())
.map_err(|e| DSLCompileError::JITError(format!("Failed to finish ISA: {e}")))?;
let mut builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
builder.symbol("sin", math_wrappers::sin_wrapper as *const u8);
builder.symbol("cos", math_wrappers::cos_wrapper as *const u8);
builder.symbol("exp", math_wrappers::exp_wrapper as *const u8);
builder.symbol("log", math_wrappers::log_wrapper as *const u8);
builder.symbol("pow", math_wrappers::pow_wrapper as *const u8);
let module = JITModule::new(builder);
Ok(Self {
module,
builder_context: FunctionBuilderContext::new(),
settings,
opt_level,
})
}
pub fn new_default() -> Result<Self> {
Self::new(OptimizationLevel::default())
}
pub fn compile_expression(
&mut self,
expr: &ASTRepr<f64>,
registry: &VariableRegistry,
) -> Result<CompiledFunction> {
let start_time = Instant::now();
let anf = crate::symbolic::anf::convert_to_anf(expr)?;
let math_functions = self.declare_external_math_functions()?;
let mut sig = self.module.make_signature();
for _ in 0..registry.len() {
sig.params.push(AbiParam::new(types::F64));
}
sig.returns.push(AbiParam::new(types::F64));
let func_id = self
.module
.declare_function("compiled_expr", Linkage::Export, &sig)
.map_err(|e| DSLCompileError::JITError(format!("Failed to declare function: {e}")))?;
let mut ctx = self.module.make_context();
ctx.func.signature = sig.clone();
self.build_function_body_from_anf(&mut ctx.func, &anf, registry, &math_functions)?;
self.module
.define_function(func_id, &mut ctx)
.map_err(|e| DSLCompileError::JITError(format!("Failed to define function: {e}")))?;
self.module.finalize_definitions();
let code_ptr = self.module.get_finalized_function(func_id);
let compilation_time = start_time.elapsed();
Ok(CompiledFunction {
func_ptr: code_ptr,
signature: FunctionSignature {
input_count: registry.len(),
return_type: types::F64,
},
metadata: CompilationMetadata {
compilation_time_ms: compilation_time.as_millis() as u64,
optimization_level: self.opt_level,
expression_complexity: expr.count_operations(),
},
})
}
fn declare_external_math_functions(&mut self) -> Result<ExternalMathFunctions> {
let mut single_arg_sig = self.module.make_signature();
single_arg_sig.params.push(AbiParam::new(types::F64));
single_arg_sig.returns.push(AbiParam::new(types::F64));
let mut double_arg_sig = self.module.make_signature();
double_arg_sig.params.push(AbiParam::new(types::F64));
double_arg_sig.params.push(AbiParam::new(types::F64));
double_arg_sig.returns.push(AbiParam::new(types::F64));
let sin_id = self
.module
.declare_function("sin", Linkage::Import, &single_arg_sig)
.map_err(|e| DSLCompileError::JITError(format!("Failed to declare sin: {e}")))?;
let cos_id = self
.module
.declare_function("cos", Linkage::Import, &single_arg_sig)
.map_err(|e| DSLCompileError::JITError(format!("Failed to declare cos: {e}")))?;
let exp_id = self
.module
.declare_function("exp", Linkage::Import, &single_arg_sig)
.map_err(|e| DSLCompileError::JITError(format!("Failed to declare exp: {e}")))?;
let log_id = self
.module
.declare_function("log", Linkage::Import, &single_arg_sig)
.map_err(|e| DSLCompileError::JITError(format!("Failed to declare log: {e}")))?;
let pow_id = self
.module
.declare_function("pow", Linkage::Import, &double_arg_sig)
.map_err(|e| DSLCompileError::JITError(format!("Failed to declare pow: {e}")))?;
Ok(ExternalMathFunctions {
sin_id,
cos_id,
exp_id,
log_id,
pow_id,
})
}
fn build_function_body_from_anf(
&mut self,
func: &mut Function,
anf: &crate::symbolic::anf::ANFExpr<f64>,
registry: &VariableRegistry,
math_functions: &ExternalMathFunctions,
) -> Result<()> {
let mut builder = FunctionBuilder::new(func, &mut self.builder_context);
let entry_block = builder.create_block();
builder.append_block_params_for_function_params(entry_block);
builder.switch_to_block(entry_block);
builder.seal_block(entry_block);
let params = builder.block_params(entry_block);
let mut var_values = HashMap::new();
for i in 0..registry.len() {
var_values.insert(i, params[i]);
}
let local_sin = self
.module
.declare_func_in_func(math_functions.sin_id, builder.func);
let local_cos = self
.module
.declare_func_in_func(math_functions.cos_id, builder.func);
let local_exp = self
.module
.declare_func_in_func(math_functions.exp_id, builder.func);
let local_log = self
.module
.declare_func_in_func(math_functions.log_id, builder.func);
let local_pow = self
.module
.declare_func_in_func(math_functions.pow_id, builder.func);
let local_math_functions = LocalMathFunctions {
sin_ref: local_sin,
cos_ref: local_cos,
exp_ref: local_exp,
log_ref: local_log,
pow_ref: local_pow,
};
let result = Self::generate_ir_from_anf(&mut builder, anf, &var_values, &local_math_functions)?;
builder.ins().return_(&[result]);
builder.finalize();
Ok(())
}
fn generate_ir_from_anf(
builder: &mut FunctionBuilder,
anf: &crate::symbolic::anf::ANFExpr<f64>,
user_vars: &HashMap<usize, Value>,
math_functions: &LocalMathFunctions,
) -> Result<Value> {
use crate::symbolic::anf::{ANFAtom, ANFComputation, ANFExpr, VarRef};
let mut bound_vars: HashMap<u32, Value> = HashMap::new();
Self::generate_ir_from_anf_with_bindings(builder, anf, user_vars, &mut bound_vars, math_functions)
}
fn generate_ir_from_anf_with_bindings(
builder: &mut FunctionBuilder,
anf: &crate::symbolic::anf::ANFExpr<f64>,
user_vars: &HashMap<usize, Value>,
bound_vars: &mut HashMap<u32, Value>,
math_functions: &LocalMathFunctions,
) -> Result<Value> {
use crate::symbolic::anf::{ANFAtom, ANFComputation, ANFExpr, VarRef};
match anf {
ANFExpr::Atom(atom) => {
Ok(Self::generate_ir_from_atom(builder, atom, user_vars, bound_vars))
}
ANFExpr::Let(var_ref, computation, body) => {
let comp_result = Self::generate_ir_from_computation(builder, computation, user_vars, bound_vars, math_functions)?;
if let VarRef::Bound(id) = var_ref {
bound_vars.insert(*id, comp_result);
}
Self::generate_ir_from_anf_with_bindings(builder, body, user_vars, bound_vars, math_functions)
}
}
}
fn generate_ir_from_atom(
builder: &mut FunctionBuilder,
atom: &crate::symbolic::anf::ANFAtom<f64>,
user_vars: &HashMap<usize, Value>,
bound_vars: &HashMap<u32, Value>,
) -> Value {
use crate::symbolic::anf::{ANFAtom, VarRef};
match atom {
ANFAtom::Constant(value) => builder.ins().f64const(*value),
ANFAtom::Variable(var_ref) => match var_ref {
VarRef::User(idx) => *user_vars.get(idx).expect("User variable not found"),
VarRef::Bound(id) => *bound_vars.get(id).expect("Bound variable not found"),
},
}
}
fn generate_ir_from_computation(
builder: &mut FunctionBuilder,
computation: &crate::symbolic::anf::ANFComputation<f64>,
user_vars: &HashMap<usize, Value>,
bound_vars: &HashMap<u32, Value>,
math_functions: &LocalMathFunctions,
) -> Result<Value> {
use crate::symbolic::anf::ANFComputation;
match computation {
ANFComputation::Add(left, right) => {
let left_val = Self::generate_ir_from_atom(builder, left, user_vars, bound_vars);
let right_val = Self::generate_ir_from_atom(builder, right, user_vars, bound_vars);
Ok(builder.ins().fadd(left_val, right_val))
}
ANFComputation::Sub(left, right) => {
let left_val = Self::generate_ir_from_atom(builder, left, user_vars, bound_vars);
let right_val = Self::generate_ir_from_atom(builder, right, user_vars, bound_vars);
Ok(builder.ins().fsub(left_val, right_val))
}
ANFComputation::Mul(left, right) => {
let left_val = Self::generate_ir_from_atom(builder, left, user_vars, bound_vars);
let right_val = Self::generate_ir_from_atom(builder, right, user_vars, bound_vars);
Ok(builder.ins().fmul(left_val, right_val))
}
ANFComputation::Div(left, right) => {
let left_val = Self::generate_ir_from_atom(builder, left, user_vars, bound_vars);
let right_val = Self::generate_ir_from_atom(builder, right, user_vars, bound_vars);
Ok(builder.ins().fdiv(left_val, right_val))
}
ANFComputation::Pow(left, right) => {
let left_val = Self::generate_ir_from_atom(builder, left, user_vars, bound_vars);
let right_val = Self::generate_ir_from_atom(builder, right, user_vars, bound_vars);
let call = builder.ins().call(math_functions.pow_ref, &[left_val, right_val]);
Ok(builder.inst_results(call)[0])
}
ANFComputation::Neg(operand) => {
let val = Self::generate_ir_from_atom(builder, operand, user_vars, bound_vars);
Ok(builder.ins().fneg(val))
}
ANFComputation::Sqrt(operand) => {
let val = Self::generate_ir_from_atom(builder, operand, user_vars, bound_vars);
Ok(builder.ins().sqrt(val))
}
ANFComputation::Sin(operand) => {
let val = Self::generate_ir_from_atom(builder, operand, user_vars, bound_vars);
let call = builder.ins().call(math_functions.sin_ref, &[val]);
Ok(builder.inst_results(call)[0])
}
ANFComputation::Cos(operand) => {
let val = Self::generate_ir_from_atom(builder, operand, user_vars, bound_vars);
let call = builder.ins().call(math_functions.cos_ref, &[val]);
Ok(builder.inst_results(call)[0])
}
ANFComputation::Exp(operand) => {
let val = Self::generate_ir_from_atom(builder, operand, user_vars, bound_vars);
let call = builder.ins().call(math_functions.exp_ref, &[val]);
Ok(builder.inst_results(call)[0])
}
ANFComputation::Ln(operand) => {
let val = Self::generate_ir_from_atom(builder, operand, user_vars, bound_vars);
let call = builder.ins().call(math_functions.log_ref, &[val]);
Ok(builder.inst_results(call)[0])
}
}
}
}
impl CompiledFunction {
    /// Runs the compiled expression, binding `args[i]` to registry variable `i`.
    ///
    /// # Errors
    ///
    /// Returns `DSLCompileError::JITError` when `args.len()` differs from the arity the
    /// function was compiled with, or when that arity exceeds 6. Only arities `0..=6`
    /// have a native signature dispatched here.
    pub fn call(&self, args: &[f64]) -> Result<f64> {
        let arity = self.signature.input_count;
        if args.len() != arity {
            return Err(DSLCompileError::JITError(format!(
                "Expected {} arguments, got {}",
                arity,
                args.len()
            )));
        }

        let entry = self.func_ptr;
        // SAFETY (every transmute below): `entry` is the finalized code pointer of a
        // function compiled with `arity` F64 parameters and one F64 return, and the
        // check above guarantees the chosen fn-pointer type has exactly that arity.
        // NOTE(review): this also assumes the ISA default call convention (used by
        // `make_signature`) is the platform C ABI, and that the JIT code is still mapped.
        let value = match *args {
            [] => {
                let f: extern "C" fn() -> f64 = unsafe { std::mem::transmute(entry) };
                f()
            }
            [a] => {
                let f: extern "C" fn(f64) -> f64 = unsafe { std::mem::transmute(entry) };
                f(a)
            }
            [a, b] => {
                let f: extern "C" fn(f64, f64) -> f64 = unsafe { std::mem::transmute(entry) };
                f(a, b)
            }
            [a, b, c] => {
                let f: extern "C" fn(f64, f64, f64) -> f64 =
                    unsafe { std::mem::transmute(entry) };
                f(a, b, c)
            }
            [a, b, c, d] => {
                let f: extern "C" fn(f64, f64, f64, f64) -> f64 =
                    unsafe { std::mem::transmute(entry) };
                f(a, b, c, d)
            }
            [a, b, c, d, e] => {
                let f: extern "C" fn(f64, f64, f64, f64, f64) -> f64 =
                    unsafe { std::mem::transmute(entry) };
                f(a, b, c, d, e)
            }
            [a, b, c, d, e, g] => {
                let f: extern "C" fn(f64, f64, f64, f64, f64, f64) -> f64 =
                    unsafe { std::mem::transmute(entry) };
                f(a, b, c, d, e, g)
            }
            _ => {
                return Err(DSLCompileError::JITError(format!(
                    "Unsupported number of arguments: {}. Maximum supported is 6.",
                    args.len()
                )));
            }
        };
        Ok(value)
    }

    /// Compilation statistics recorded for this function.
    pub fn metadata(&self) -> &CompilationMetadata {
        &self.metadata
    }

    /// Arity and return type this function was compiled with.
    pub fn signature(&self) -> &FunctionSignature {
        &self.signature
    }
}
/// Heuristic estimate, in bytes, of the machine code generated for `expr`.
///
/// Leaves have fixed costs (constants 8, variables 4). Each operator adds its own
/// overhead to the sum of its operands' estimates: 4 for arithmetic and negation,
/// 8 for `sqrt`, 16 for the unary transcendental calls, and 20 for `pow`.
fn estimate_code_size(expr: &ASTRepr<f64>) -> usize {
    match expr {
        ASTRepr::Constant(_) => 8,
        ASTRepr::Variable(_) => 4,
        ASTRepr::Neg(operand) => 4 + estimate_code_size(operand),
        ASTRepr::Sqrt(operand) => 8 + estimate_code_size(operand),
        ASTRepr::Sin(operand)
        | ASTRepr::Cos(operand)
        | ASTRepr::Exp(operand)
        | ASTRepr::Ln(operand) => 16 + estimate_code_size(operand),
        ASTRepr::Add(lhs, rhs)
        | ASTRepr::Sub(lhs, rhs)
        | ASTRepr::Mul(lhs, rhs)
        | ASTRepr::Div(lhs, rhs) => {
            let operands = estimate_code_size(lhs) + estimate_code_size(rhs);
            operands + 4
        }
        ASTRepr::Pow(lhs, rhs) => {
            let operands = estimate_code_size(lhs) + estimate_code_size(rhs);
            operands + 20
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::final_tagless::ASTEval;

    /// Absolute tolerance for comparing JIT output with reference values.
    const TOLERANCE: f64 = 1e-10;

    /// Asserts `actual` is within `TOLERANCE` of `expected`, reporting both on failure.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {expected}, got {actual}"
        );
    }

    /// Compiles `expr` with `compiler` and hands the compiler back alongside the result,
    /// so the caller keeps it alive for as long as it calls into the generated code.
    fn compile_using(
        mut compiler: CraneliftCompiler,
        expr: &ASTRepr<f64>,
        registry: &VariableRegistry,
    ) -> (CraneliftCompiler, CompiledFunction) {
        let compiled = compiler.compile_expression(expr, registry).unwrap();
        (compiler, compiled)
    }

    #[test]
    fn test_basic_compilation() {
        let mut registry = VariableRegistry::new();
        let x_idx = registry.register_variable();
        // x + 1
        let expr = ASTEval::add(ASTEval::var(x_idx), ASTEval::constant(1.0));
        let (_compiler, compiled) =
            compile_using(CraneliftCompiler::new_default().unwrap(), &expr, &registry);
        assert_close(compiled.call(&[2.0]).unwrap(), 3.0);
    }

    #[test]
    fn test_multiple_variables() {
        let mut registry = VariableRegistry::new();
        let x_idx = registry.register_variable();
        let y_idx = registry.register_variable();
        // x * y + 1
        let expr = ASTEval::add(
            ASTEval::mul(ASTEval::var(x_idx), ASTEval::var(y_idx)),
            ASTEval::constant(1.0),
        );
        let (_compiler, compiled) =
            compile_using(CraneliftCompiler::new_default().unwrap(), &expr, &registry);
        assert_close(compiled.call(&[3.0, 4.0]).unwrap(), 13.0);
    }

    #[test]
    fn test_optimization_levels() {
        let mut registry = VariableRegistry::new();
        let x_idx = registry.register_variable();
        // x ^ 2
        let expr = ASTEval::pow(ASTEval::var(x_idx), ASTEval::constant(2.0));
        for opt_level in [
            OptimizationLevel::None,
            OptimizationLevel::Basic,
            OptimizationLevel::Full,
        ] {
            let (_compiler, compiled) =
                compile_using(CraneliftCompiler::new(opt_level).unwrap(), &expr, &registry);
            assert_close(compiled.call(&[3.0]).unwrap(), 9.0);
            assert_eq!(compiled.metadata().optimization_level, opt_level);
        }
    }

    #[test]
    fn test_binary_exponentiation_optimization() {
        let mut registry = VariableRegistry::new();
        let x_idx = registry.register_variable();

        // Integer exponents applied to 2.0.
        let integer_cases = [
            (2, 4.0),
            (3, 8.0),
            (4, 16.0),
            (5, 32.0),
            (8, 256.0),
            (10, 1024.0),
            (16, 65536.0),
        ];
        for (exp, expected) in integer_cases {
            let expr = ASTEval::pow(ASTEval::var(x_idx), ASTEval::constant(exp as f64));
            let (_compiler, compiled) =
                compile_using(CraneliftCompiler::new_default().unwrap(), &expr, &registry);
            let result = compiled.call(&[2.0]).unwrap();
            assert!(
                (result - expected).abs() < TOLERANCE,
                "2^{exp} = {expected} but got {result}"
            );
        }

        // Negative and fractional exponents: (exponent, input, expected).
        let fractional_cases = [(-2.0, 2.0, 0.25), (0.5, 4.0, 2.0), (-0.5, 4.0, 0.5)];
        for (exponent, input, expected) in fractional_cases {
            let expr = ASTEval::pow(ASTEval::var(x_idx), ASTEval::constant(exponent));
            let (_compiler, compiled) =
                compile_using(CraneliftCompiler::new_default().unwrap(), &expr, &registry);
            assert_close(compiled.call(&[input]).unwrap(), expected);
        }
    }

    #[test]
    fn test_transcendental_functions() {
        use std::f64::consts::{E, PI};

        let mut registry = VariableRegistry::new();
        let x_idx = registry.register_variable();

        // (expression, [(input, expected output)])
        let cases = [
            (
                ASTEval::sin(ASTEval::var(x_idx)),
                vec![(0.0, 0.0), (PI / 2.0, 1.0)],
            ),
            (ASTEval::cos(ASTEval::var(x_idx)), vec![(0.0, 1.0)]),
            (ASTEval::exp(ASTEval::var(x_idx)), vec![(0.0, 1.0), (1.0, E)]),
            (ASTEval::ln(ASTEval::var(x_idx)), vec![(1.0, 0.0), (E, 1.0)]),
        ];
        for (expr, samples) in cases {
            let (_compiler, compiled) =
                compile_using(CraneliftCompiler::new_default().unwrap(), &expr, &registry);
            for (input, expected) in samples {
                assert_close(compiled.call(&[input]).unwrap(), expected);
            }
        }
    }
}