Skip to main content

shape_jit/compiler/
setup.rs

//! JIT compiler initialization and setup.
2
3use cranelift::prelude::*;
4use cranelift_jit::{JITBuilder, JITModule};
5use std::collections::HashMap;
6use std::sync::Mutex;
7
8use crate::context::{JITConfig, SimulationKernelConfig};
9use crate::error::JitError;
10use crate::ffi_symbols::{declare_ffi_functions, register_ffi_symbols};
11use shape_runtime::simulation::{KernelCompileConfig, KernelCompiler, SimulationKernelFn};
12use shape_vm::bytecode::BytecodeProgram;
13
14pub struct JITCompiler {
15    pub(super) module: JITModule,
16    pub(super) builder_context: FunctionBuilderContext,
17    #[allow(dead_code)]
18    pub(super) config: JITConfig,
19    pub(super) compiled_functions: HashMap<String, *const u8>,
20    pub(super) ffi_funcs: HashMap<String, cranelift_module::FuncId>,
21    pub(super) function_table: Vec<*const u8>,
22    /// Maps function_id → FuncId of their `opt_dc_*` direct-call entry point.
23    /// Used for cross-function speculative direct calls: when function A has
24    /// monomorphic feedback for callee B, and B has already been Tier-2
25    /// compiled, A can emit a direct `call` to B's direct-call entry.
26    pub(super) compiled_dc_funcs: HashMap<u16, (cranelift_module::FuncId, u16)>,
27}
28
29impl JITCompiler {
30    /// Borrow the underlying JITModule (for declaring/defining functions).
31    pub fn module_mut(&mut self) -> &mut JITModule {
32        &mut self.module
33    }
34
35    /// Borrow the FunctionBuilderContext (reused across compilations).
36    pub fn builder_context_mut(&mut self) -> &mut FunctionBuilderContext {
37        &mut self.builder_context
38    }
39}
40
41impl JITCompiler {
42    #[inline(always)]
43    pub fn new(config: JITConfig) -> Result<Self, JitError> {
44        let mut flag_builder = settings::builder();
45        let opt_level_str = if config.opt_level >= 2 {
46            "speed"
47        } else {
48            "speed_and_size"
49        };
50        flag_builder
51            .set("opt_level", opt_level_str)
52            .map_err(|e| JitError::Setup(format!("Failed to set opt_level: {}", e)))?;
53        flag_builder
54            .set("is_pic", "false")
55            .map_err(|e| JitError::Setup(format!("Failed to set is_pic: {}", e)))?;
56
57        let isa_builder = cranelift_native::builder()
58            .map_err(|e| JitError::Setup(format!("Failed to create ISA builder: {}", e)))?;
59        let isa = isa_builder
60            .finish(settings::Flags::new(flag_builder))
61            .map_err(|e| JitError::Setup(format!("Failed to create ISA: {}", e)))?;
62
63        let mut builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
64
65        register_ffi_symbols(&mut builder);
66
67        let mut module = JITModule::new(builder);
68
69        let ffi_funcs = declare_ffi_functions(&mut module);
70
71        Ok(Self {
72            module,
73            builder_context: FunctionBuilderContext::new(),
74            config,
75            compiled_functions: HashMap::new(),
76            ffi_funcs,
77            function_table: Vec::new(),
78            compiled_dc_funcs: HashMap::new(),
79        })
80    }
81}
82
// ============================================================================
// KernelCompiler Trait Implementation
// ============================================================================
86
87/// Thread-safe wrapper around JITCompiler for use with ExecutionContext.
88///
89/// This wrapper implements the `KernelCompiler` trait from `shape-runtime`,
90/// enabling JIT kernel compilation to be injected into ExecutionContext without
91/// circular dependencies.
92pub struct JITKernelCompiler {
93    /// Inner JIT compiler protected by mutex for thread safety
94    compiler: Mutex<JITCompiler>,
95}
96
97impl JITKernelCompiler {
98    /// Create a new JIT kernel compiler with default configuration.
99    pub fn new() -> Result<Self, JitError> {
100        Ok(Self {
101            compiler: Mutex::new(JITCompiler::new(JITConfig::default())?),
102        })
103    }
104
105    /// Create a new JIT kernel compiler with custom configuration.
106    pub fn with_config(config: JITConfig) -> Result<Self, JitError> {
107        Ok(Self {
108            compiler: Mutex::new(JITCompiler::new(config)?),
109        })
110    }
111}
112
113impl Default for JITKernelCompiler {
114    fn default() -> Self {
115        Self::new().expect("Failed to create JIT compiler with default config")
116    }
117}
118
119// Safety: JITKernelCompiler is thread-safe because:
120// 1. All access to the inner JITCompiler is protected by a Mutex
121// 2. The raw pointers in JITCompiler are function pointers to compiled code,
122//    which are immutable after compilation
123// 3. We never expose the raw pointers outside the Mutex
124unsafe impl Send for JITKernelCompiler {}
125unsafe impl Sync for JITKernelCompiler {}
126
127impl KernelCompiler for JITKernelCompiler {
128    fn compile_kernel(
129        &self,
130        name: &str,
131        function_bytecode: &[u8],
132        config: &KernelCompileConfig,
133    ) -> Result<SimulationKernelFn, String> {
134        // Deserialize bytecode (wire-native MessagePack first, JSON fallback for compatibility).
135        let program: BytecodeProgram = rmp_serde::from_slice(function_bytecode)
136            .or_else(|mp_err| {
137                serde_json::from_slice(function_bytecode).map_err(|json_err| {
138                    format!(
139                        "Failed to deserialize bytecode as MessagePack ({mp_err}) or JSON ({json_err})"
140                    )
141                })
142            })?;
143
144        // Convert KernelCompileConfig to SimulationKernelConfig
145        let mut jit_config =
146            SimulationKernelConfig::new(config.state_schema_id, config.column_count);
147
148        // Add state field offsets
149        for (field_name, offset) in &config.state_field_offsets {
150            jit_config
151                .state_field_offsets
152                .push((field_name.clone(), *offset));
153        }
154
155        // Add column mappings
156        for (col_name, idx) in &config.column_map {
157            jit_config.column_map.push((col_name.clone(), *idx));
158        }
159
160        // Acquire lock and compile
161        let mut compiler = self
162            .compiler
163            .lock()
164            .map_err(|e| format!("Failed to acquire JIT compiler lock: {}", e))?;
165
166        // Call the JIT compiler's compile_simulation_kernel
167        let kernel_fn = compiler.compile_simulation_kernel(name, &program, &jit_config)?;
168
169        // The function pointer type is the same, just transmute
170        // Safety: SimulationKernelFn in shape-jit has the same signature as in shape-runtime
171        Ok(unsafe { std::mem::transmute(kernel_fn) })
172    }
173
174    fn supports_feature(&self, feature: &str) -> bool {
175        match feature {
176            "typed_object" => true,
177            "closures" => false, // Phase 1: no closure support
178            "multi_table" => true,
179            _ => false,
180        }
181    }
182}