use std::error::Error as StdError;
use std::fmt;
use std::sync::Arc;
use super::types::{JitNumeric, JitType, NumericValue, TypedVector};
#[cfg(feature = "jit")]
use cranelift::prelude::*;
#[cfg(feature = "jit")]
use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext};
#[cfg(feature = "jit")]
use cranelift_jit::JITModule;
#[cfg(feature = "jit")]
use cranelift_module::Module;
/// Errors produced by the JIT helpers in this module.
#[derive(Debug)]
pub enum JitError {
    /// Building, defining or finalizing Cranelift code failed; the payload describes why.
    CompilationError(String),
    /// A compiled function could not be run (for example, no kernel was attached).
    ExecutionError(String),
    /// A requested JIT capability is not available.
    FeatureNotAvailable(String),
}

impl fmt::Display for JitError {
    /// Renders the error as `JIT <kind>: <detail>`.
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::CompilationError(detail) => write!(out, "JIT compilation error: {}", detail),
            Self::ExecutionError(detail) => write!(out, "JIT execution error: {}", detail),
            Self::FeatureNotAvailable(detail) => {
                write!(out, "JIT feature not available: {}", detail)
            }
        }
    }
}

// No wrapped cause is stored, so the default `source()` (returning `None`) applies.
impl StdError for JitError {}

/// Shorthand for results whose error type is [`JitError`].
pub type JitResult<T> = Result<T, JitError>;
/// A callable that maps an argument value of type `Args` to an `Output`.
///
/// The second type parameter was previously named `Result`, which shadowed the
/// prelude `Result` inside the trait. Generic arguments are positional, so existing
/// `impl JitCompilable<A, R> for T` blocks are unaffected by the rename.
pub trait JitCompilable<Args, Output> {
    /// Runs the callable on `args` and returns its output.
    fn execute(&self, args: Args) -> Output;
}
/// Counterpart of `JitCompilable` whose input and output use this crate's
/// `TypedVector` and `NumericValue` types, plus names describing those types.
pub trait GenericJitCompilable {
    /// Executes on `values` and returns the numeric result.
    fn execute_typed(&self, values: TypedVector) -> NumericValue;

    /// Static name of the input element type this implementor accepts.
    fn input_type_name(&self) -> &'static str;

    /// Static name of the output type this implementor produces.
    fn output_type_name(&self) -> &'static str;
}
/// Trait-object type for a thread-shareable `Vec<f64> -> f64` function.
/// `JitFunction` stores its native implementation as `Arc<FloatArrayFn>`.
pub type FloatArrayFn = dyn Fn(Vec<f64>) -> f64 + Send + Sync;
/// `f32` counterpart of `FloatArrayFn`.
pub type Float32ArrayFn = dyn Fn(Vec<f32>) -> f32 + Send + Sync;
/// `i64` counterpart of `FloatArrayFn`.
pub type Int64ArrayFn = dyn Fn(Vec<i64>) -> i64 + Send + Sync;
/// `i32` counterpart of `FloatArrayFn`.
pub type Int32ArrayFn = dyn Fn(Vec<i32>) -> i32 + Send + Sync;
/// A named `Vec<f64> -> f64` function with a native closure and, when the `jit`
/// feature is enabled, an optional Cranelift-compiled kernel.
///
/// `Clone` is cheap: the closure and any JIT context are shared through `Arc`.
///
/// NOTE(review): with the `jit` feature on, `JitContext` holds a raw `*const u8`,
/// so `Arc<JitContext>` (and therefore this struct) is neither `Send` nor `Sync`,
/// even though the native closure is `Send + Sync`. Confirm callers don't rely on
/// moving a `JitFunction` across threads in `jit` builds.
#[derive(Clone)]
pub struct JitFunction {
    // Human-readable name; also used as the exported symbol name by `with_jit`.
    name: String,
    // Native implementation; always available and used whenever no JIT kernel runs.
    native_fn: Arc<FloatArrayFn>,
    // Type labels; set to "f64" by `JitFunction::new`.
    input_type: &'static str,
    output_type: &'static str,
    // Compiled kernel attached by `with_jit`; `None` until then.
    #[cfg(feature = "jit")]
    jit_context: Option<Arc<JitContext>>,
}
/// Counters describing how often a function ran and along which path.
#[derive(Default, Debug, Clone)]
pub struct JitStats {
    /// Total number of recorded executions (JIT and native combined).
    pub executions: u64,
    /// Sum of all recorded durations (the `duration_ns` arguments, nominally nanoseconds).
    pub execution_time_ns: u64,
    /// Executions recorded through `record_jit_execution`.
    pub jit_used: u64,
    /// Executions recorded through `record_native_execution`.
    pub native_used: u64,
}

impl JitStats {
    /// Creates a counter set with every field at zero.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Shared bookkeeping: bumps the execution count and accumulates the duration.
    fn tally_execution(&mut self, duration_ns: u64) {
        self.executions += 1;
        self.execution_time_ns += duration_ns;
    }

    /// Records one execution that went through the JIT-compiled path.
    pub fn record_jit_execution(&mut self, duration_ns: u64) {
        self.tally_execution(duration_ns);
        self.jit_used += 1;
    }

    /// Records one execution that went through the native closure.
    pub fn record_native_execution(&mut self, duration_ns: u64) {
        self.tally_execution(duration_ns);
        self.native_used += 1;
    }

    /// Mean recorded duration per execution, or `0.0` if nothing was recorded.
    pub fn average_execution_time_ns(&self) -> f64 {
        match self.executions {
            0 => 0.0,
            count => self.execution_time_ns as f64 / count as f64,
        }
    }
}
impl JitFunction {
    /// Wraps a native `Vec<f64> -> f64` closure under `name`.
    ///
    /// Input and output type labels are both `"f64"`. No JIT kernel is attached;
    /// use `with_jit` (requires the `jit` feature) for that.
    pub fn new<F>(name: impl Into<String>, native_fn: F) -> Self
    where
        F: Fn(Vec<f64>) -> f64 + Send + Sync + 'static,
    {
        Self {
            name: name.into(),
            native_fn: Arc::new(native_fn),
            input_type: "f64",
            output_type: "f64",
            #[cfg(feature = "jit")]
            jit_context: None,
        }
    }

    /// Returns `self` with its name replaced by `name`.
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.name = name.into();
        self
    }

    /// JIT-compiles a kernel exported under this function's name and attaches it.
    ///
    /// NOTE(review): `JitContext::compile` always emits an f64 *array-sum* kernel; it
    /// never looks at `native_fn`. After this call, `execute` prefers the kernel, so
    /// only use this for functions whose native body is a plain sum of the input.
    ///
    /// # Errors
    ///
    /// Propagates the `JitError::CompilationError` returned by `JitContext::compile`.
    #[cfg(feature = "jit")]
    pub fn with_jit(mut self) -> JitResult<Self> {
        // The borrow of `self.name` ends when `compile` returns, so no clone is needed.
        let ctx = JitContext::compile(&self.name)?;
        self.jit_context = Some(Arc::new(ctx));
        Ok(self)
    }

    /// The function's name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Label of the input element type (currently always `"f64"`).
    pub fn input_type(&self) -> &'static str {
        self.input_type
    }

    /// Label of the output type (currently always `"f64"`).
    pub fn output_type(&self) -> &'static str {
        self.output_type
    }
}
impl JitCompilable<Vec<f64>, f64> for JitFunction {
    /// Evaluates the function on `args`.
    ///
    /// With the `jit` feature and an attached JIT context, the compiled kernel runs
    /// first. If it reports an error, execution falls back to the native closure.
    /// Without a JIT context, the native closure runs directly.
    fn execute(&self, args: Vec<f64>) -> f64 {
        #[cfg(feature = "jit")]
        {
            if let Some(ctx) = &self.jit_context {
                // Deliberate best-effort fast path: a JIT failure is not surfaced,
                // the native closure below simply takes over.
                if let Ok(result) = ctx.execute_array_sum(&args) {
                    return result;
                }
            }
        }
        (self.native_fn)(args)
    }
}
/// One Cranelift-compiled array-sum kernel together with the JIT module that emitted it.
///
/// NOTE(review): `compiled_fn` is a raw pointer, so this type is neither `Send` nor
/// `Sync`. As a result, `Arc<JitContext>` is not `Send`/`Sync` either.
#[cfg(feature = "jit")]
pub struct JitContext {
    /// Symbol name the kernel was declared under.
    // Written by `compile`; not read by any code visible here.
    #[allow(dead_code)]
    name: String,
    /// Entry point from `JITModule::get_finalized_function`; `None` means "not compiled".
    compiled_fn: Option<*const u8>,
    /// Module that produced `compiled_fn`, kept alongside the pointer (presumably so
    /// the code memory it manages stays associated with the pointer's owner).
    // Held only for ownership; never read.
    #[allow(dead_code)]
    jit_module: Option<cranelift_jit::JITModule>,
}
#[cfg(feature = "jit")]
impl JitContext {
    /// JIT-compiles an f64 array-sum kernel for the host and exports it as `name`.
    ///
    /// The emitted function has the shape `fn(ptr: i64, len: i64) -> f64`. It reads
    /// `len` consecutive 8-byte `f64` values starting at `ptr` and returns their sum,
    /// starting from `0.0`. `name` only labels the symbol; it does not change the kernel.
    ///
    /// NOTE(review): `JITModule::free_memory` is never called for the stored module, so
    /// each successful `compile` may keep its code memory for the process lifetime.
    ///
    /// # Errors
    ///
    /// Returns `JitError::CompilationError` in these cases:
    /// - the host ISA cannot be built;
    /// - the function cannot be declared or defined;
    /// - definitions cannot be finalized.
    pub fn compile(name: &str) -> JitResult<Self> {
        // `JITModule`, `Module`, `FunctionBuilder` and `FunctionBuilderContext` come
        // from the file-level imports; only the names used nowhere else are local.
        use cranelift_jit::JITBuilder;
        use cranelift_module::Linkage;

        // Target the host machine with Cranelift's default flag settings.
        let isa = cranelift_native::builder()
            .map_err(|e| {
                JitError::CompilationError(format!("Failed to create ISA builder: {}", e))
            })?
            .finish(settings::Flags::new(settings::builder()))
            .map_err(|e| JitError::CompilationError(format!("Failed to finish ISA: {}", e)))?;
        let builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
        let mut module = JITModule::new(builder);

        // Signature: (data pointer, element count) -> sum.
        // NOTE(review): the pointer is declared as I64, which assumes a 64-bit host.
        let mut sig = module.make_signature();
        sig.params.push(AbiParam::new(types::I64));
        sig.params.push(AbiParam::new(types::I64));
        sig.returns.push(AbiParam::new(types::F64));

        let func_id = module
            .declare_function(name, Linkage::Export, &sig)
            .map_err(|e| {
                JitError::CompilationError(format!("Function declaration failed: {}", e))
            })?;

        // Build the IR directly in the module-provided context. The previous code
        // built it in a second `codegen::Context` and moved the function over.
        let mut ctx = module.make_context();
        ctx.func.signature = sig;
        {
            let mut func_ctx = FunctionBuilderContext::new();
            let mut builder = FunctionBuilder::new(&mut ctx.func, &mut func_ctx);

            let entry_block = builder.create_block();
            builder.append_block_params_for_function_params(entry_block);
            builder.switch_to_block(entry_block);
            builder.seal_block(entry_block);
            let array_ptr = builder.block_params(entry_block)[0];
            let array_len = builder.block_params(entry_block)[1];

            // sum = 0.0; counter = 0
            let zero = builder.ins().f64const(0.0);
            let sum = builder.declare_var(types::F64);
            builder.def_var(sum, zero);
            let counter = builder.declare_var(types::I64);
            let zero_i64 = builder.ins().iconst(types::I64, 0);
            builder.def_var(counter, zero_i64);

            let loop_header = builder.create_block();
            let loop_body = builder.create_block();
            let loop_end = builder.create_block();
            builder.ins().jump(loop_header, &[]);

            // loop_header: keep going while counter < len (unsigned comparison).
            builder.switch_to_block(loop_header);
            let current_counter = builder.use_var(counter);
            let condition = builder
                .ins()
                .icmp(IntCC::UnsignedLessThan, current_counter, array_len);
            builder.ins().brif(condition, loop_body, &[], loop_end, &[]);

            // loop_body: sum += *(array_ptr + counter * 8); counter += 1
            builder.switch_to_block(loop_body);
            let current_counter = builder.use_var(counter);
            let element_offset = builder.ins().imul_imm(current_counter, 8);
            let element_ptr = builder.ins().iadd(array_ptr, element_offset);
            let element_value = builder
                .ins()
                .load(types::F64, MemFlags::new(), element_ptr, 0);
            let current_sum = builder.use_var(sum);
            let new_sum = builder.ins().fadd(current_sum, element_value);
            builder.def_var(sum, new_sum);
            let one = builder.ins().iconst(types::I64, 1);
            let next_counter = builder.ins().iadd(current_counter, one);
            builder.def_var(counter, next_counter);
            builder.ins().jump(loop_header, &[]);

            // loop_end: return the accumulated sum.
            builder.switch_to_block(loop_end);
            let final_sum = builder.use_var(sum);
            builder.ins().return_(&[final_sum]);

            // Sealing waits until every predecessor is emitted; the back edge into
            // `loop_header` only exists after the body above is built.
            builder.seal_block(loop_header);
            builder.seal_block(loop_body);
            builder.seal_block(loop_end);
            builder.finalize();
        }

        module.define_function(func_id, &mut ctx).map_err(|e| {
            JitError::CompilationError(format!("Function definition failed: {}", e))
        })?;
        module.finalize_definitions().map_err(|e| {
            JitError::CompilationError(format!("Failed to finalize definitions: {}", e))
        })?;
        let compiled_fn = module.get_finalized_function(func_id);
        Ok(Self {
            name: name.to_string(),
            compiled_fn: Some(compiled_fn),
            jit_module: Some(module),
        })
    }

    /// Runs the compiled kernel over `data` and returns the sum of its elements.
    ///
    /// # Errors
    ///
    /// Returns `JitError::ExecutionError` if no kernel is attached (`compiled_fn` is `None`).
    pub fn execute_array_sum(&self, data: &[f64]) -> JitResult<f64> {
        let func_ptr = self
            .compiled_fn
            .ok_or_else(|| JitError::ExecutionError("Function not compiled".to_string()))?;
        // SAFETY: `func_ptr` was returned by `get_finalized_function` for a function
        // declared in `compile` with params (I64, I64) and return F64, using the
        // module's default calling convention. That convention is assumed to match
        // `extern "C"` on the supported 64-bit hosts. The owning module is stored in
        // `self.jit_module`.
        let func: unsafe extern "C" fn(*const f64, i64) -> f64 =
            unsafe { std::mem::transmute(func_ptr) };
        // SAFETY: the kernel loads exactly `data.len()` f64 values starting at
        // `data.as_ptr()`, all inside the borrowed slice. For an empty slice it loads
        // nothing. A slice length never exceeds `isize::MAX`, so the cast is lossless.
        let result = unsafe { func(data.as_ptr(), data.len() as i64) };
        Ok(result)
    }
}
/// Wraps the native closure `f` in a `JitFunction` named `name`.
///
/// This is equivalent to `JitFunction::new(name, f)` in every build configuration.
/// It does not JIT-compile anything, even with the `jit` feature enabled. Call
/// `JitFunction::with_jit` explicitly to attach a compiled kernel.
pub fn jit<F>(name: impl Into<String>, f: F) -> JitFunction
where
    F: Fn(Vec<f64>) -> f64 + Send + Sync + 'static,
{
    // The previous body had two cfg-gated branches that both returned the freshly
    // built function unchanged (one via a needless `return`).
    JitFunction::new(name, f)
}