use std::sync::Arc;
use crate::cache::FunctionCache;
use crate::error::{JitError, JitResult};
use crate::ir::{Graph, Node, NodeId, Op};
use crate::optimize::Optimizer;
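/// A function compiled from a computation graph, executable either through
/// the built-in interpreter or through Cranelift-generated native code.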
#[derive(Clone)]
pub struct CompiledFunction {
graph: Arc<Graph>,
kind: CompiledKind,
}
#[derive(Clone)]
enum CompiledKind {
Interpreted,
Native {
code_ptr: *const u8,
code_size: usize,
},
}
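// SAFETY: `code_ptr` refers to immutable JIT-compiled code whose backing
// JITModule is intentionally leaked in `compile_native`, so the pointer stays
// valid for the program's lifetime and can be shared across threads.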
unsafe impl Send for CompiledKind {}
unsafe impl Sync for CompiledKind {}
impl CompiledFunction {
pub fn placeholder() -> Self {
Self {
graph: Arc::new(Graph::new()),
kind: CompiledKind::Interpreted,
}
}
pub fn graph(&self) -> &Graph {
&self.graph
}
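    /// Runs the function with named inputs, returning the flattened f32
    /// outputs. Every graph input must be supplied.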
pub fn run(&self, inputs: &[(&str, &[f32])]) -> JitResult<Vec<f32>> {
match &self.kind {
CompiledKind::Interpreted => self.run_interpreted(inputs),
CompiledKind::Native {
code_ptr,
code_size,
} => {
                unsafe {
                    // `code_ptr` is bound by reference here; dereference it
                    // before casting, otherwise the transmute would reinterpret
                    // the address of the field itself as the function pointer.
                    let func: extern "C" fn(*const f32, *mut f32) =
                        std::mem::transmute(*code_ptr);
                    let flat_inputs: Vec<f32> =
                        inputs.iter().flat_map(|(_, d)| d.iter().copied()).collect();
                    let output_size: usize = self.graph.outputs().values()
                        .map(|id| self.graph.node(*id).shape.numel())
                        .sum();
                    let mut output = vec![0.0f32; output_size];
                    func(flat_inputs.as_ptr(), output.as_mut_ptr());
                    let _ = code_size; // size is recorded but unused at runtime
                    Ok(output)
                }
}
}
}
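    /// Evaluates the graph node-by-node in graph order (assumed topological),
    /// caching each node's value. Only the first registered output is returned.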
fn run_interpreted(&self, inputs: &[(&str, &[f32])]) -> JitResult<Vec<f32>> {
let mut values: Vec<Option<Vec<f32>>> = vec![None; self.graph.len()];
for (name, data) in inputs {
if let Some(id) = self.graph.input(name) {
values[id.index()] = Some(data.to_vec());
} else {
return Err(JitError::InputNotFound(name.to_string()));
}
}
for node in self.graph.nodes() {
let result = self.eval_node(node, &values)?;
values[node.id.index()] = Some(result);
}
if let Some((_, output_id)) = self.graph.outputs().iter().next() {
let output_node = self.graph.node(*output_id);
if let Op::Output { input, .. } = &output_node.op {
return Ok(values[input.index()].clone().unwrap_or_default());
}
}
        Err(JitError::OutputNotFound("graph has no output node".to_string()))
}
fn eval_node(&self, node: &Node, values: &[Option<Vec<f32>>]) -> JitResult<Vec<f32>> {
let get = |id: NodeId| -> JitResult<&Vec<f32>> {
values[id.index()]
.as_ref()
.ok_or_else(|| JitError::RuntimeError(format!("Node {:?} not computed", id)))
};
match &node.op {
Op::Input { .. } => {
Ok(values[node.id.index()].clone().unwrap_or_default())
}
Op::Output { input, .. } => Ok(get(*input)?.clone()),
Op::Constant { value } => {
let numel = node.shape.numel();
Ok(vec![*value as f32; numel])
}
Op::Add { lhs, rhs } => {
let a = get(*lhs)?;
let b = get(*rhs)?;
Ok(a.iter().zip(b.iter()).map(|(x, y)| x + y).collect())
}
Op::Sub { lhs, rhs } => {
let a = get(*lhs)?;
let b = get(*rhs)?;
Ok(a.iter().zip(b.iter()).map(|(x, y)| x - y).collect())
}
Op::Mul { lhs, rhs } => {
let a = get(*lhs)?;
let b = get(*rhs)?;
Ok(a.iter().zip(b.iter()).map(|(x, y)| x * y).collect())
}
Op::Div { lhs, rhs } => {
let a = get(*lhs)?;
let b = get(*rhs)?;
Ok(a.iter().zip(b.iter()).map(|(x, y)| x / y).collect())
}
Op::Pow { base, exp } => {
let a = get(*base)?;
let b = get(*exp)?;
Ok(a.iter().zip(b.iter()).map(|(x, y)| x.powf(*y)).collect())
}
Op::Max { lhs, rhs } => {
let a = get(*lhs)?;
let b = get(*rhs)?;
Ok(a.iter().zip(b.iter()).map(|(x, y)| x.max(*y)).collect())
}
Op::Min { lhs, rhs } => {
let a = get(*lhs)?;
let b = get(*rhs)?;
Ok(a.iter().zip(b.iter()).map(|(x, y)| x.min(*y)).collect())
}
Op::AddScalar { input, scalar } => {
let a = get(*input)?;
Ok(a.iter().map(|x| x + *scalar as f32).collect())
}
Op::MulScalar { input, scalar } => {
let a = get(*input)?;
Ok(a.iter().map(|x| x * *scalar as f32).collect())
}
Op::Neg { input } => {
let a = get(*input)?;
Ok(a.iter().map(|x| -x).collect())
}
Op::Abs { input } => {
let a = get(*input)?;
Ok(a.iter().map(|x| x.abs()).collect())
}
Op::Sqrt { input } => {
let a = get(*input)?;
Ok(a.iter().map(|x| x.sqrt()).collect())
}
Op::Exp { input } => {
let a = get(*input)?;
Ok(a.iter().map(|x| x.exp()).collect())
}
Op::Log { input } => {
let a = get(*input)?;
Ok(a.iter().map(|x| x.ln()).collect())
}
Op::Sin { input } => {
let a = get(*input)?;
Ok(a.iter().map(|x| x.sin()).collect())
}
Op::Cos { input } => {
let a = get(*input)?;
Ok(a.iter().map(|x| x.cos()).collect())
}
Op::Tanh { input } => {
let a = get(*input)?;
Ok(a.iter().map(|x| x.tanh()).collect())
}
Op::Relu { input } => {
let a = get(*input)?;
Ok(a.iter().map(|x| x.max(0.0)).collect())
}
Op::Sigmoid { input } => {
let a = get(*input)?;
Ok(a.iter().map(|x| 1.0 / (1.0 + (-x).exp())).collect())
}
Op::Gelu { input } => {
let a = get(*input)?;
const SQRT_2_OVER_PI: f32 = 0.797_884_6;
Ok(a.iter()
.map(|x| 0.5 * x * (1.0 + (SQRT_2_OVER_PI * (x + 0.044715 * x.powi(3))).tanh()))
.collect())
}
Op::Silu { input } => {
let a = get(*input)?;
Ok(a.iter().map(|x| x / (1.0 + (-x).exp())).collect())
}
Op::Sum { input } => {
let a = get(*input)?;
Ok(vec![a.iter().sum()])
}
Op::Mean { input } => {
let a = get(*input)?;
let sum: f32 = a.iter().sum();
Ok(vec![sum / a.len() as f32])
}
Op::SumAxis {
input,
axis,
keepdim,
} => {
let a = get(*input)?;
let input_node = self.graph.node(*input);
let input_shape = input_node.shape.dims();
reduce_axis(a, input_shape, *axis, *keepdim, |x, y| x + y, 0.0)
}
Op::MeanAxis {
input,
axis,
keepdim,
} => {
let a = get(*input)?;
let input_node = self.graph.node(*input);
let input_shape = input_node.shape.dims();
let axis_size = input_shape[normalize_axis(*axis, input_shape.len())];
let sum = reduce_axis(a, input_shape, *axis, *keepdim, |x, y| x + y, 0.0)?;
Ok(sum.iter().map(|x| x / axis_size as f32).collect())
}
Op::MaxAxis {
input,
axis,
keepdim,
} => {
let a = get(*input)?;
let input_node = self.graph.node(*input);
let input_shape = input_node.shape.dims();
reduce_axis(a, input_shape, *axis, *keepdim, f32::max, f32::NEG_INFINITY)
}
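            // Shape-only ops on contiguous row-major data pass the buffer
            // through unchanged. Note: this is only correct for Broadcast when
            // the element count is unchanged, since no data is replicated.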
Op::Reshape { input, .. }
| Op::Squeeze { input, .. }
| Op::Unsqueeze { input, .. }
| Op::Broadcast { input, .. }
| Op::Contiguous { input } => Ok(get(*input)?.clone()),
Op::Transpose { input, dim0, dim1 } => {
let a = get(*input)?;
let input_shape = &self.graph.node(*input).shape;
let ndim = input_shape.dims().len();
if ndim < 2 || *dim0 >= ndim || *dim1 >= ndim || dim0 == dim1 {
return Ok(a.clone());
}
let dims = input_shape.dims();
let mut perm: Vec<usize> = (0..ndim).collect();
perm.swap(*dim0, *dim1);
                let new_shape: Vec<usize> = perm.iter().map(|&d| dims[d]).collect();
let numel: usize = dims.iter().product();
let mut result = vec![0.0f32; numel];
let mut in_strides = vec![1usize; ndim];
for d in (0..ndim - 1).rev() {
in_strides[d] = in_strides[d + 1] * dims[d + 1];
}
let mut out_strides = vec![1usize; ndim];
for d in (0..ndim - 1).rev() {
out_strides[d] = out_strides[d + 1] * new_shape[d + 1];
}
for flat in 0..numel {
let mut remaining = flat;
let mut out_idx = vec![0usize; ndim];
for d in 0..ndim {
out_idx[d] = remaining / out_strides[d];
remaining %= out_strides[d];
}
let mut in_flat = 0;
for d in 0..ndim {
in_flat += out_idx[d] * in_strides[perm[d]];
}
result[flat] = a[in_flat];
}
Ok(result)
}
Op::MatMul { lhs, rhs } => {
let a = get(*lhs)?;
let b = get(*rhs)?;
let lhs_node = self.graph.node(*lhs);
let rhs_node = self.graph.node(*rhs);
let lhs_shape = lhs_node.shape.dims();
let rhs_shape = rhs_node.shape.dims();
matmul_impl(a, b, lhs_shape, rhs_shape)
}
Op::Gt { lhs, rhs } => {
let a = get(*lhs)?;
let b = get(*rhs)?;
Ok(a.iter()
.zip(b.iter())
.map(|(x, y)| if x > y { 1.0 } else { 0.0 })
.collect())
}
Op::Lt { lhs, rhs } => {
let a = get(*lhs)?;
let b = get(*rhs)?;
Ok(a.iter()
.zip(b.iter())
.map(|(x, y)| if x < y { 1.0 } else { 0.0 })
.collect())
}
Op::Eq { lhs, rhs } => {
let a = get(*lhs)?;
let b = get(*rhs)?;
Ok(a.iter()
.zip(b.iter())
.map(|(x, y)| {
if (x - y).abs() < f32::EPSILON {
1.0
} else {
0.0
}
})
.collect())
}
Op::Where { condition, x, y } => {
let cond = get(*condition)?;
let a = get(*x)?;
let b = get(*y)?;
Ok(cond
.iter()
.zip(a.iter().zip(b.iter()))
.map(|(c, (a, b))| if *c != 0.0 { *a } else { *b })
.collect())
}
Op::Cast { input, .. } => {
Ok(get(*input)?.clone())
}
}
}
}
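/// Converts a possibly negative axis (counting from the end, Python-style)
/// into a non-negative index, e.g. axis -1 with ndim 3 maps to 2.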
fn normalize_axis(axis: i32, ndim: usize) -> usize {
if axis < 0 {
(ndim as i32 + axis) as usize
} else {
axis as usize
}
}
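/// Reduces row-major `data` of logical shape `shape` along `axis` using the
/// binary `op` seeded with `init`. With `keepdim` the reduced axis stays as
/// size 1; otherwise it is dropped from the output shape.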
fn reduce_axis(
data: &[f32],
shape: &[usize],
axis: i32,
keepdim: bool,
op: fn(f32, f32) -> f32,
init: f32,
) -> JitResult<Vec<f32>> {
let axis = normalize_axis(axis, shape.len());
let mut strides = vec![1usize; shape.len()];
for i in (0..shape.len() - 1).rev() {
strides[i] = strides[i + 1] * shape[i + 1];
}
let mut output_shape: Vec<usize> = shape.to_vec();
if keepdim {
output_shape[axis] = 1;
} else {
output_shape.remove(axis);
}
let output_numel: usize = output_shape.iter().product();
let mut result = vec![init; output_numel];
for (i, &val) in data.iter().enumerate() {
let mut multi_idx = vec![0usize; shape.len()];
let mut idx = i;
for (d, &st) in strides.iter().enumerate() {
multi_idx[d] = idx / st;
idx %= st;
}
        let out_idx = if keepdim {
            let mut temp_strides = vec![1usize; output_shape.len()];
            for d in (0..output_shape.len() - 1).rev() {
                temp_strides[d] = temp_strides[d + 1] * output_shape[d + 1];
            }
            let mut out_idx = 0;
            for d in 0..output_shape.len() {
                let dim_idx = if d == axis { 0 } else { multi_idx[d] };
                out_idx += dim_idx * temp_strides[d];
            }
            out_idx
        } else {
            let mut temp_strides = vec![1usize; output_shape.len()];
            for d in (0..output_shape.len().saturating_sub(1)).rev() {
                temp_strides[d] = temp_strides[d + 1] * output_shape[d + 1];
            }
            // Skip the reduced axis; the remaining input dims map one-to-one
            // onto the output dims in order.
            let mut out_idx = 0;
            let mut out_d = 0;
            for (d, &mi) in multi_idx.iter().enumerate() {
                if d == axis {
                    continue;
                }
                out_idx += mi * temp_strides[out_d];
                out_d += 1;
            }
            out_idx
        };
        result[out_idx] = op(result[out_idx], val);
}
Ok(result)
}
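/// Naive 2D matrix multiply over row-major buffers: `a` is m x k, `b` is
/// k x n, and the result is the m x n product computed with an O(m*n*k)
/// triple loop.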
fn matmul_impl(a: &[f32], b: &[f32], a_shape: &[usize], b_shape: &[usize]) -> JitResult<Vec<f32>> {
if a_shape.len() != 2 || b_shape.len() != 2 {
return Err(JitError::UnsupportedOp(
"Only 2D matmul supported in interpreter".to_string(),
));
}
let m = a_shape[0];
let k = a_shape[1];
let n = b_shape[1];
if k != b_shape[0] {
return Err(JitError::ShapeMismatch {
expected: vec![k],
found: vec![b_shape[0]],
});
}
let mut result = vec![0.0f32; m * n];
for i in 0..m {
for j in 0..n {
let mut sum = 0.0;
for p in 0..k {
sum += a[i * k + p] * b[p * n + j];
}
result[i * n + j] = sum;
}
}
Ok(result)
}
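/// Compiles traced graphs into [`CompiledFunction`]s, caching results by
/// graph hash. Functions run through the interpreter by default; `enable_native`
/// switches to Cranelift codegen for the small subset of scalar ops it
/// supports. Usage, mirroring the tests below:
///
/// ```ignore
/// let graph = trace(|tracer| {
///     let a = tracer.input("a", &[4]);
///     let b = tracer.input("b", &[4]);
///     tracer.output("sum", a.add(&b))
/// });
/// let compiler = JitCompiler::new();
/// let func = compiler.compile(&graph)?;
/// let (a, b) = ([1.0f32; 4], [2.0f32; 4]);
/// let out = func.run(&[("a", &a), ("b", &b)])?;
/// assert_eq!(out, vec![3.0; 4]);
/// ```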
pub struct JitCompiler {
optimizer: Optimizer,
cache: FunctionCache,
use_native: bool,
}
impl JitCompiler {
pub fn new() -> Self {
Self {
optimizer: Optimizer::default_passes(),
cache: FunctionCache::default_size(),
            use_native: false,
        }
}
pub fn with_optimizer(optimizer: Optimizer) -> Self {
Self {
optimizer,
cache: FunctionCache::default_size(),
use_native: false,
}
}
pub fn enable_native(&mut self, enable: bool) {
self.use_native = enable;
}
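    /// Compiles `graph`, returning a cached function when an identical graph
    /// was compiled before. On a cache miss the graph is validated, optimized,
    /// and then either lowered to native code or wrapped for interpretation.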
pub fn compile(&self, graph: &Graph) -> JitResult<CompiledFunction> {
let cache_key = FunctionCache::hash_graph(graph);
if let Some(cached) = self.cache.get(cache_key) {
return Ok(cached);
}
graph.validate().map_err(JitError::InvalidGraph)?;
let optimized = self.optimizer.optimize(graph.clone());
let func = if self.use_native {
self.compile_native(&optimized)?
} else {
self.compile_interpreted(&optimized)
};
self.cache.insert(cache_key, func.clone());
Ok(func)
}
fn compile_interpreted(&self, graph: &Graph) -> CompiledFunction {
CompiledFunction {
graph: Arc::new(graph.clone()),
kind: CompiledKind::Interpreted,
}
}
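    /// Lowers the graph to native code with Cranelift. This path is a scalar
    /// prototype: each node becomes a single f32 SSA value, only a subset of
    /// ops is supported, and only the first element of the first output is
    /// stored.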
fn compile_native(&self, graph: &Graph) -> JitResult<CompiledFunction> {
use cranelift::prelude::*;
use cranelift_jit::{JITBuilder, JITModule};
use cranelift_module::{Linkage, Module};
let mut flag_builder = settings::builder();
flag_builder.set("use_colocated_libcalls", "false").unwrap();
flag_builder.set("is_pic", "false").unwrap();
let isa_builder = cranelift_native::builder()
.map_err(|e| JitError::CompilationFailed(format!("Failed to get native ISA: {}", e)))?;
let isa = isa_builder
.finish(settings::Flags::new(flag_builder))
.map_err(|e| JitError::CompilationFailed(format!("Failed to build ISA: {}", e)))?;
let builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
let mut module = JITModule::new(builder);
let mut sig = module.make_signature();
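        // Signature: extern "C" fn(input_ptr: *const f32, output_ptr: *mut f32),
        // with both pointers passed as I64 values.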
        sig.params.push(AbiParam::new(types::I64));
        sig.params.push(AbiParam::new(types::I64));
let func_id = module
.declare_function("jit_kernel", Linkage::Export, &sig)
.map_err(|e| {
JitError::CompilationFailed(format!("Failed to declare function: {}", e))
})?;
let mut ctx = module.make_context();
ctx.func.signature = sig;
let mut builder_ctx = FunctionBuilderContext::new();
{
let mut builder = FunctionBuilder::new(&mut ctx.func, &mut builder_ctx);
let entry_block = builder.create_block();
builder.append_block_params_for_function_params(entry_block);
builder.switch_to_block(entry_block);
builder.seal_block(entry_block);
let input_ptr = builder.block_params(entry_block)[0];
let output_ptr = builder.block_params(entry_block)[1];
let mut values: Vec<Option<Value>> = vec![None; graph.len()];
for node in graph.nodes() {
let result = self.codegen_node(&mut builder, node, &values, input_ptr)?;
values[node.id.index()] = Some(result);
}
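            // Store the scalar result of the first output node; multi-element
            // outputs are beyond this codegen path.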
if let Some((_, output_id)) = graph.outputs().iter().next() {
let output_node = graph.node(*output_id);
if let Op::Output { input, .. } = &output_node.op {
if let Some(val) = values[input.index()] {
builder.ins().store(MemFlags::new(), val, output_ptr, 0);
}
}
}
builder.ins().return_(&[]);
builder.finalize();
}
module.define_function(func_id, &mut ctx).map_err(|e| {
JitError::CompilationFailed(format!("Failed to define function: {}", e))
})?;
module.clear_context(&mut ctx);
module
.finalize_definitions()
.map_err(|e| JitError::CompilationFailed(format!("Failed to finalize: {:?}", e)))?;
let code_ptr = module.get_finalized_function(func_id);
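        // cranelift-jit does not readily expose the emitted code size, so a
        // placeholder of 0 is recorded.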
let code_size = 0;
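        // Intentionally leak the module: dropping the JITModule would free the
        // memory backing `code_ptr` and leave the CompiledFunction dangling.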
std::mem::forget(module);
Ok(CompiledFunction {
graph: Arc::new(graph.clone()),
kind: CompiledKind::Native {
code_ptr,
code_size,
},
})
}
fn codegen_node(
&self,
builder: &mut cranelift::prelude::FunctionBuilder,
node: &Node,
values: &[Option<cranelift::prelude::Value>],
input_ptr: cranelift::prelude::Value,
) -> JitResult<cranelift::prelude::Value> {
use cranelift::prelude::*;
let get = |id: NodeId| -> JitResult<Value> {
values[id.index()]
.ok_or_else(|| JitError::RuntimeError(format!("Node {:?} not compiled", id)))
};
match &node.op {
Op::Input { name, .. } => {
let offset = self.get_input_offset(name);
Ok(builder
.ins()
.load(types::F32, MemFlags::new(), input_ptr, offset))
}
Op::Output { input, .. } => get(*input),
Op::Constant { value } => Ok(builder.ins().f32const(*value as f32)),
Op::Add { lhs, rhs } => {
let a = get(*lhs)?;
let b = get(*rhs)?;
Ok(builder.ins().fadd(a, b))
}
Op::Sub { lhs, rhs } => {
let a = get(*lhs)?;
let b = get(*rhs)?;
Ok(builder.ins().fsub(a, b))
}
Op::Mul { lhs, rhs } => {
let a = get(*lhs)?;
let b = get(*rhs)?;
Ok(builder.ins().fmul(a, b))
}
Op::Div { lhs, rhs } => {
let a = get(*lhs)?;
let b = get(*rhs)?;
Ok(builder.ins().fdiv(a, b))
}
Op::Neg { input } => {
let a = get(*input)?;
Ok(builder.ins().fneg(a))
}
Op::Abs { input } => {
let a = get(*input)?;
Ok(builder.ins().fabs(a))
}
Op::Sqrt { input } => {
let a = get(*input)?;
Ok(builder.ins().sqrt(a))
}
Op::AddScalar { input, scalar } => {
let a = get(*input)?;
let s = builder.ins().f32const(*scalar as f32);
Ok(builder.ins().fadd(a, s))
}
Op::MulScalar { input, scalar } => {
let a = get(*input)?;
let s = builder.ins().f32const(*scalar as f32);
Ok(builder.ins().fmul(a, s))
}
            _ => Err(JitError::UnsupportedOp(format!(
                "Operation {:?} not supported in native codegen; disable native compilation to use the interpreter",
                node.op
            ))),
}
}
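    /// Placeholder input layout: every input is assumed to live at offset 0,
    /// which is only correct for functions with a single scalar input.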
fn get_input_offset(&self, _name: &str) -> i32 {
0
}
pub fn cache_stats(&self) -> crate::cache::CacheStats {
self.cache.stats()
}
pub fn clear_cache(&self) {
self.cache.clear();
}
}
impl Default for JitCompiler {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::trace::trace;
#[test]
fn test_compile_simple() {
let graph = trace(|tracer| {
let a = tracer.input("a", &[4]);
let b = tracer.input("b", &[4]);
let c = a.add(&b);
tracer.output("result", c)
});
let compiler = JitCompiler::new();
let func = compiler.compile(&graph).unwrap();
let a = [1.0, 2.0, 3.0, 4.0];
let b = [5.0, 6.0, 7.0, 8.0];
let result = func.run(&[("a", &a), ("b", &b)]).unwrap();
assert_eq!(result, vec![6.0, 8.0, 10.0, 12.0]);
}
#[test]
fn test_compile_chain() {
let graph = trace(|tracer| {
let x = tracer.input("x", &[4]);
let y = x.relu().mul_scalar(2.0).add_scalar(1.0);
tracer.output("y", y)
});
let compiler = JitCompiler::new();
let func = compiler.compile(&graph).unwrap();
let x = [-1.0, 0.0, 1.0, 2.0];
let result = func.run(&[("x", &x)]).unwrap();
assert_eq!(result, vec![1.0, 1.0, 3.0, 5.0]);
}
#[test]
fn test_compile_activations() {
let graph = trace(|tracer| {
let x = tracer.input("x", &[3]);
let y = x.sigmoid();
tracer.output("y", y)
});
let compiler = JitCompiler::new();
let func = compiler.compile(&graph).unwrap();
let x = [0.0, 1.0, -1.0];
let result = func.run(&[("x", &x)]).unwrap();
assert!((result[0] - 0.5).abs() < 0.01);
assert!((result[1] - 0.731).abs() < 0.01);
}
#[test]
fn test_compile_matmul() {
let graph = trace(|tracer| {
let a = tracer.input("a", &[2, 3]);
let b = tracer.input("b", &[3, 2]);
let c = a.matmul(&b);
tracer.output("c", c)
});
let compiler = JitCompiler::new();
let func = compiler.compile(&graph).unwrap();
        // a = [[1, 0, 0], [0, 1, 0]] (2x3), b = [[1, 0], [0, 1], [0, 0]] (3x2),
        // so a.matmul(b) is the 2x2 identity, flattened row-major.
        let a = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let b = [1.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let result = func.run(&[("a", &a), ("b", &b)]).unwrap();
        assert_eq!(result, vec![1.0, 0.0, 0.0, 1.0]);
    }
#[test]
fn test_caching() {
let graph = trace(|tracer| {
let x = tracer.input("x", &[4]);
tracer.output("y", x.relu())
});
let compiler = JitCompiler::new();
assert_eq!(compiler.cache_stats().entries, 0);
let _ = compiler.compile(&graph).unwrap();
assert_eq!(compiler.cache_stats().entries, 1);
let _ = compiler.compile(&graph).unwrap();
assert_eq!(compiler.cache_stats().entries, 1);
}
}