Skip to main content

shape_jit/
executor.rs

1//! JIT executor implementing the ProgramExecutor trait
2
3use shape_ast::Program;
4use shape_runtime::engine::{ExecutionType, ProgramExecutor, ShapeEngine};
5use shape_runtime::error::Result;
6use shape_wire::WireValue;
7use std::time::Instant;
8
/// Runs Shape programs through the JIT backend.
///
/// Programs are first lowered to bytecode and then compiled one function at a
/// time. Functions the JIT can handle become native code. Functions it cannot
/// handle (for example ones relying on async, pattern matching, or builtins the
/// JIT does not support) stay as `Interpreted` entries in the mixed function
/// table and are executed by the VM instead.
///
/// The executor itself carries no state.
pub struct JITExecutor;
15
16impl ProgramExecutor for JITExecutor {
17    fn execute_program(
18        &mut self,
19        engine: &mut ShapeEngine,
20        program: &Program,
21    ) -> Result<shape_runtime::engine::ProgramExecutorResult> {
22        use shape_vm::BytecodeCompiler;
23        let emit_phase_metrics = std::env::var_os("SHAPE_JIT_PHASE_METRICS").is_some();
24
25        // Capture source text before getting runtime reference (for error messages)
26        let source_for_compilation = engine.current_source().map(|s| s.to_string());
27
28        // Compile to bytecode first to check JIT compatibility
29        let runtime = engine.get_runtime_mut();
30
31        // Get known module bindings — prefer persistent context, fallback to precompiled names
32        let known_bindings: Vec<String> = if let Some(ctx) = runtime.persistent_context() {
33            let names = ctx.root_scope_binding_names();
34            if names.is_empty() {
35                shape_vm::stdlib::core_binding_names()
36            } else {
37                names
38            }
39        } else {
40            shape_vm::stdlib::core_binding_names()
41        };
42
43        // Build module graph and compile via graph pipeline
44        let mut loader = shape_runtime::module_loader::ModuleLoader::new();
45        let (graph, stdlib_names, prelude_imports) =
46            shape_vm::module_resolution::build_graph_and_stdlib_names(
47                program,
48                &mut loader,
49                &[],
50            )
51            .map_err(|e| shape_runtime::error::ShapeError::RuntimeError {
52                message: format!("Module graph construction failed: {}", e),
53                location: None,
54            })?;
55
56        let bytecode_compile_start = Instant::now();
57        let mut compiler = BytecodeCompiler::new();
58        compiler.stdlib_function_names = stdlib_names;
59        compiler.register_known_bindings(&known_bindings);
60        if let Some(source) = &source_for_compilation {
61            compiler.set_source(source);
62        }
63        let bytecode = compiler
64            .compile_with_graph_and_prelude(program, graph, &prelude_imports)
65            .map_err(|e| shape_runtime::error::ShapeError::RuntimeError {
66                message: format!("Bytecode compilation failed: {}", e),
67                location: None,
68            })?;
69        let bytecode_compile_ms = bytecode_compile_start.elapsed().as_millis();
70
71        self.execute_with_jit(engine, &bytecode, bytecode_compile_ms, emit_phase_metrics)
72    }
73}
74
75impl JITExecutor {
76    fn execute_with_jit(
77        &self,
78        engine: &mut ShapeEngine,
79        bytecode: &shape_vm::bytecode::BytecodeProgram,
80        bytecode_compile_ms: u128,
81        emit_phase_metrics: bool,
82    ) -> Result<shape_runtime::engine::ProgramExecutorResult> {
83        use crate::JITConfig;
84        use crate::JITContext;
85        use crate::compiler::JITCompiler;
86
87        // JIT compile the bytecode
88        let jit_config = JITConfig::default();
89        let mut jit = JITCompiler::new(jit_config).map_err(|e| {
90            shape_runtime::error::ShapeError::RuntimeError {
91                message: format!("JIT compiler initialization failed: {}", e),
92                location: None,
93            }
94        })?;
95
96        // Use selective compilation: JIT-compatible functions get native code,
97        // incompatible ones get Interpreted entries for VM fallback.
98        if std::env::var_os("SHAPE_JIT_DEBUG").is_some() {
99            eprintln!(
100                "[jit-debug] starting compile_program_selective with {} instructions, {} functions",
101                bytecode.instructions.len(),
102                bytecode.functions.len()
103            );
104            for (i, instr) in bytecode.instructions.iter().enumerate() {
105                eprintln!(
106                    "[jit-debug] instr[{}]: {:?} {:?}",
107                    i, instr.opcode, instr.operand
108                );
109            }
110        }
111        let jit_compile_start = Instant::now();
112        let compile_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
113            jit.compile_program_selective("main", bytecode)
114        }));
115        let jit_compile_ms = jit_compile_start.elapsed().as_millis();
116        let (jit_fn, _mixed_table) = match compile_result {
117            Ok(Ok(result)) => result,
118            Ok(Err(e)) => {
119                return Err(shape_runtime::error::ShapeError::RuntimeError {
120                    message: format!("JIT compilation failed: {}", e),
121                    location: None,
122                });
123            }
124            Err(panic_info) => {
125                let msg = if let Some(s) = panic_info.downcast_ref::<String>() {
126                    s.clone()
127                } else if let Some(s) = panic_info.downcast_ref::<&str>() {
128                    s.to_string()
129                } else {
130                    "unknown panic".to_string()
131                };
132                return Err(shape_runtime::error::ShapeError::RuntimeError {
133                    message: format!("JIT compilation panicked: {}", msg),
134                    location: None,
135                });
136            }
137        };
138
139        let foreign_bridge = {
140            let runtime = engine.get_runtime_mut();
141            crate::foreign_bridge::link_foreign_functions_for_jit(
142                bytecode,
143                runtime.persistent_context(),
144            )
145            .map_err(|e| shape_runtime::error::ShapeError::RuntimeError {
146                message: format!("JIT foreign-function linking failed: {}", e),
147                location: None,
148            })?
149        };
150
151        // Create JIT context and execute
152        let mut jit_ctx = JITContext::default();
153        if let Some(state) = foreign_bridge.as_ref() {
154            jit_ctx.foreign_bridge_ptr = state.as_ref() as *const _ as *const std::ffi::c_void;
155        }
156
157        // Set exec_context_ptr so JIT FFI can access cached data
158        {
159            let runtime = engine.get_runtime_mut();
160            if let Some(ctx) = runtime.persistent_context_mut() {
161                jit_ctx.exec_context_ptr = ctx as *mut _ as *mut std::ffi::c_void;
162            }
163        }
164
165        // Execute the JIT-compiled function
166        if std::env::var_os("SHAPE_JIT_DEBUG").is_some() {
167            eprintln!("[jit-debug] compilation OK, about to execute...");
168        }
169        let jit_exec_start = Instant::now();
170        let signal = unsafe { jit_fn(&mut jit_ctx) };
171        let jit_exec_ms = jit_exec_start.elapsed().as_millis();
172
173        // Get result from JIT context stack via TypedScalar boundary
174        let raw_result = if jit_ctx.stack_ptr > 0 {
175            jit_ctx.stack[0]
176        } else {
177            crate::nan_boxing::TAG_NULL
178        };
179
180        // Check for errors
181        if signal < 0 {
182            return Err(shape_runtime::error::ShapeError::RuntimeError {
183                message: format!("JIT execution error (code: {})", signal),
184                location: None,
185            });
186        }
187
188        // Use FrameDescriptor hint to preserve integer type identity.
189        // Prefer return_kind when populated; fall back to last slot.
190        let return_hint = bytecode.top_level_frame.as_ref().and_then(|fd| {
191            if fd.return_kind != shape_vm::type_tracking::SlotKind::Unknown {
192                Some(fd.return_kind)
193            } else {
194                fd.slots.last().copied()
195            }
196        });
197        let result_scalar =
198            crate::ffi::object::conversion::jit_bits_to_typed_scalar(raw_result, return_hint);
199
200        // Convert TypedScalar to WireValue
201        let wire_value = self.typed_scalar_to_wire(&result_scalar, raw_result);
202
203        if emit_phase_metrics {
204            let total_ms = bytecode_compile_ms + jit_compile_ms + jit_exec_ms;
205            eprintln!(
206                "[shape-jit-phases] bytecode_compile_ms={} jit_compile_ms={} jit_exec_ms={} total_ms={}",
207                bytecode_compile_ms, jit_compile_ms, jit_exec_ms, total_ms
208            );
209        }
210
211        Ok(shape_runtime::engine::ProgramExecutorResult {
212            wire_value,
213            type_info: None,
214            execution_type: ExecutionType::Script,
215            content_json: None,
216            content_html: None,
217            content_terminal: None,
218        })
219    }
220
221    /// Convert a TypedScalar result to WireValue.
222    ///
223    /// For scalar types, the TypedScalar carries enough information. For heap types
224    /// (strings, arrays) that TypedScalar can't represent, we fall back to raw bits.
225    fn typed_scalar_to_wire(&self, ts: &shape_value::TypedScalar, raw_bits: u64) -> WireValue {
226        use shape_value::ScalarKind;
227
228        match ts.kind {
229            ScalarKind::I8
230            | ScalarKind::I16
231            | ScalarKind::I32
232            | ScalarKind::I64
233            | ScalarKind::U8
234            | ScalarKind::U16
235            | ScalarKind::U32
236            | ScalarKind::U64
237            | ScalarKind::I128
238            | ScalarKind::U128 => {
239                // Integer result — preserve as exact integer in WireValue::Number
240                WireValue::Number(ts.payload_lo as i64 as f64)
241            }
242            ScalarKind::F64 | ScalarKind::F32 => WireValue::Number(f64::from_bits(ts.payload_lo)),
243            ScalarKind::Bool => WireValue::Bool(ts.payload_lo != 0),
244            ScalarKind::Unit => WireValue::Null,
245            ScalarKind::None => {
246                // None could also be a fallback for non-scalar heap types.
247                // Check if raw_bits is actually a heap value.
248                self.nan_boxed_to_wire(raw_bits)
249            }
250        }
251    }
252
253    fn nan_boxed_to_wire(&self, bits: u64) -> WireValue {
254        use crate::nan_boxing::{
255            HK_STRING, TAG_BOOL_FALSE, TAG_BOOL_TRUE, TAG_NULL, is_heap_kind, is_number, jit_unbox,
256            unbox_number,
257        };
258        use shape_value::tags::{TAG_INT, get_payload, get_tag, is_tagged, sign_extend_i48};
259
260        if is_number(bits) {
261            WireValue::Number(unbox_number(bits))
262        } else if bits == TAG_NULL {
263            WireValue::Null
264        } else if bits == TAG_BOOL_TRUE {
265            WireValue::Bool(true)
266        } else if bits == TAG_BOOL_FALSE {
267            WireValue::Bool(false)
268        } else if is_tagged(bits) && get_tag(bits) == TAG_INT {
269            // NaN-boxed i48 integer — sign-extend to i64 and return as integer
270            let int_val = sign_extend_i48(get_payload(bits));
271            WireValue::Integer(int_val)
272        } else if is_heap_kind(bits, HK_STRING) {
273            let s = unsafe { jit_unbox::<String>(bits) };
274            WireValue::String(s.clone())
275        } else {
276            // Default to interpreting as a number for unknown tags
277            WireValue::Number(f64::from_bits(bits))
278        }
279    }
280}