shape_jit/compiler/
setup.rs1use cranelift::prelude::*;
4use cranelift_jit::{JITBuilder, JITModule};
5use std::collections::HashMap;
6use std::sync::Mutex;
7
8use crate::context::{JITConfig, SimulationKernelConfig};
9use crate::error::JitError;
10use crate::ffi_symbols::{declare_ffi_functions, register_ffi_symbols};
11use shape_runtime::simulation::{KernelCompileConfig, KernelCompiler, SimulationKernelFn};
12use shape_vm::bytecode::BytecodeProgram;
13
14pub struct JITCompiler {
15 pub(super) module: JITModule,
16 pub(super) builder_context: FunctionBuilderContext,
17 #[allow(dead_code)]
18 pub(super) config: JITConfig,
19 pub(super) compiled_functions: HashMap<String, *const u8>,
20 pub(super) ffi_funcs: HashMap<String, cranelift_module::FuncId>,
21 pub(super) function_table: Vec<*const u8>,
22 pub(super) compiled_dc_funcs: HashMap<u16, (cranelift_module::FuncId, u16)>,
27}
28
29impl JITCompiler {
30 pub fn module_mut(&mut self) -> &mut JITModule {
32 &mut self.module
33 }
34
35 pub fn builder_context_mut(&mut self) -> &mut FunctionBuilderContext {
37 &mut self.builder_context
38 }
39}
40
41impl JITCompiler {
42 #[inline(always)]
43 pub fn new(config: JITConfig) -> Result<Self, JitError> {
44 let mut flag_builder = settings::builder();
45 let opt_level_str = if config.opt_level >= 2 {
46 "speed"
47 } else {
48 "speed_and_size"
49 };
50 flag_builder
51 .set("opt_level", opt_level_str)
52 .map_err(|e| JitError::Setup(format!("Failed to set opt_level: {}", e)))?;
53 flag_builder
54 .set("is_pic", "false")
55 .map_err(|e| JitError::Setup(format!("Failed to set is_pic: {}", e)))?;
56
57 let isa_builder = cranelift_native::builder()
58 .map_err(|e| JitError::Setup(format!("Failed to create ISA builder: {}", e)))?;
59 let isa = isa_builder
60 .finish(settings::Flags::new(flag_builder))
61 .map_err(|e| JitError::Setup(format!("Failed to create ISA: {}", e)))?;
62
63 let mut builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
64
65 register_ffi_symbols(&mut builder);
66
67 let mut module = JITModule::new(builder);
68
69 let ffi_funcs = declare_ffi_functions(&mut module);
70
71 Ok(Self {
72 module,
73 builder_context: FunctionBuilderContext::new(),
74 config,
75 compiled_functions: HashMap::new(),
76 ffi_funcs,
77 function_table: Vec::new(),
78 compiled_dc_funcs: HashMap::new(),
79 })
80 }
81}
82
83pub struct JITKernelCompiler {
93 compiler: Mutex<JITCompiler>,
95}
96
97impl JITKernelCompiler {
98 pub fn new() -> Result<Self, JitError> {
100 Ok(Self {
101 compiler: Mutex::new(JITCompiler::new(JITConfig::default())?),
102 })
103 }
104
105 pub fn with_config(config: JITConfig) -> Result<Self, JitError> {
107 Ok(Self {
108 compiler: Mutex::new(JITCompiler::new(config)?),
109 })
110 }
111}
112
113impl Default for JITKernelCompiler {
114 fn default() -> Self {
115 Self::new().expect("Failed to create JIT compiler with default config")
116 }
117}
118
119unsafe impl Send for JITKernelCompiler {}
125unsafe impl Sync for JITKernelCompiler {}
126
127impl KernelCompiler for JITKernelCompiler {
128 fn compile_kernel(
129 &self,
130 name: &str,
131 function_bytecode: &[u8],
132 config: &KernelCompileConfig,
133 ) -> Result<SimulationKernelFn, String> {
134 let program: BytecodeProgram = rmp_serde::from_slice(function_bytecode)
136 .or_else(|mp_err| {
137 serde_json::from_slice(function_bytecode).map_err(|json_err| {
138 format!(
139 "Failed to deserialize bytecode as MessagePack ({mp_err}) or JSON ({json_err})"
140 )
141 })
142 })?;
143
144 let mut jit_config =
146 SimulationKernelConfig::new(config.state_schema_id, config.column_count);
147
148 for (field_name, offset) in &config.state_field_offsets {
150 jit_config
151 .state_field_offsets
152 .push((field_name.clone(), *offset));
153 }
154
155 for (col_name, idx) in &config.column_map {
157 jit_config.column_map.push((col_name.clone(), *idx));
158 }
159
160 let mut compiler = self
162 .compiler
163 .lock()
164 .map_err(|e| format!("Failed to acquire JIT compiler lock: {}", e))?;
165
166 let kernel_fn = compiler.compile_simulation_kernel(name, &program, &jit_config)?;
168
169 Ok(unsafe { std::mem::transmute(kernel_fn) })
172 }
173
174 fn supports_feature(&self, feature: &str) -> bool {
175 match feature {
176 "typed_object" => true,
177 "closures" => false, "multi_table" => true,
179 _ => false,
180 }
181 }
182}