1use shape_ast::Program;
4use shape_runtime::engine::{ExecutionType, ProgramExecutor, ShapeEngine};
5use shape_runtime::error::Result;
6use shape_wire::WireValue;
7use std::time::Instant;
8
/// Program executor that compiles a program to bytecode and then runs it
/// through the JIT compiler (see its `ProgramExecutor` implementation).
///
/// The type carries no state of its own.
pub struct JITExecutor;
15
16impl ProgramExecutor for JITExecutor {
17 fn execute_program(
18 &mut self,
19 engine: &mut ShapeEngine,
20 program: &Program,
21 ) -> Result<shape_runtime::engine::ProgramExecutorResult> {
22 use shape_vm::BytecodeCompiler;
23 let emit_phase_metrics = std::env::var_os("SHAPE_JIT_PHASE_METRICS").is_some();
24
25 let source_for_compilation = engine.current_source().map(|s| s.to_string());
27
28 let runtime = engine.get_runtime_mut();
30
31 let known_bindings: Vec<String> = if let Some(ctx) = runtime.persistent_context() {
33 let names = ctx.root_scope_binding_names();
34 if names.is_empty() {
35 shape_vm::stdlib::core_binding_names()
36 } else {
37 names
38 }
39 } else {
40 shape_vm::stdlib::core_binding_names()
41 };
42
43 let mut loader = shape_runtime::module_loader::ModuleLoader::new();
45 let (graph, stdlib_names, prelude_imports) =
46 shape_vm::module_resolution::build_graph_and_stdlib_names(
47 program,
48 &mut loader,
49 &[],
50 )
51 .map_err(|e| shape_runtime::error::ShapeError::RuntimeError {
52 message: format!("Module graph construction failed: {}", e),
53 location: None,
54 })?;
55
56 let bytecode_compile_start = Instant::now();
57 let mut compiler = BytecodeCompiler::new();
58 compiler.stdlib_function_names = stdlib_names;
59 compiler.register_known_bindings(&known_bindings);
60 if let Some(source) = &source_for_compilation {
61 compiler.set_source(source);
62 }
63 let bytecode = compiler
64 .compile_with_graph_and_prelude(program, graph, &prelude_imports)
65 .map_err(|e| shape_runtime::error::ShapeError::RuntimeError {
66 message: format!("Bytecode compilation failed: {}", e),
67 location: None,
68 })?;
69 let bytecode_compile_ms = bytecode_compile_start.elapsed().as_millis();
70
71 self.execute_with_jit(engine, &bytecode, bytecode_compile_ms, emit_phase_metrics)
72 }
73}
74
75impl JITExecutor {
76 fn execute_with_jit(
77 &self,
78 engine: &mut ShapeEngine,
79 bytecode: &shape_vm::bytecode::BytecodeProgram,
80 bytecode_compile_ms: u128,
81 emit_phase_metrics: bool,
82 ) -> Result<shape_runtime::engine::ProgramExecutorResult> {
83 use crate::JITConfig;
84 use crate::JITContext;
85 use crate::compiler::JITCompiler;
86
87 let jit_config = JITConfig::default();
89 let mut jit = JITCompiler::new(jit_config).map_err(|e| {
90 shape_runtime::error::ShapeError::RuntimeError {
91 message: format!("JIT compiler initialization failed: {}", e),
92 location: None,
93 }
94 })?;
95
96 if std::env::var_os("SHAPE_JIT_DEBUG").is_some() {
99 eprintln!(
100 "[jit-debug] starting compile_program_selective with {} instructions, {} functions",
101 bytecode.instructions.len(),
102 bytecode.functions.len()
103 );
104 for (i, instr) in bytecode.instructions.iter().enumerate() {
105 eprintln!(
106 "[jit-debug] instr[{}]: {:?} {:?}",
107 i, instr.opcode, instr.operand
108 );
109 }
110 }
111 let jit_compile_start = Instant::now();
112 let compile_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
113 jit.compile_program_selective("main", bytecode)
114 }));
115 let jit_compile_ms = jit_compile_start.elapsed().as_millis();
116 let (jit_fn, _mixed_table) = match compile_result {
117 Ok(Ok(result)) => result,
118 Ok(Err(e)) => {
119 return Err(shape_runtime::error::ShapeError::RuntimeError {
120 message: format!("JIT compilation failed: {}", e),
121 location: None,
122 });
123 }
124 Err(panic_info) => {
125 let msg = if let Some(s) = panic_info.downcast_ref::<String>() {
126 s.clone()
127 } else if let Some(s) = panic_info.downcast_ref::<&str>() {
128 s.to_string()
129 } else {
130 "unknown panic".to_string()
131 };
132 return Err(shape_runtime::error::ShapeError::RuntimeError {
133 message: format!("JIT compilation panicked: {}", msg),
134 location: None,
135 });
136 }
137 };
138
139 let foreign_bridge = {
140 let runtime = engine.get_runtime_mut();
141 crate::foreign_bridge::link_foreign_functions_for_jit(
142 bytecode,
143 runtime.persistent_context(),
144 )
145 .map_err(|e| shape_runtime::error::ShapeError::RuntimeError {
146 message: format!("JIT foreign-function linking failed: {}", e),
147 location: None,
148 })?
149 };
150
151 let mut jit_ctx = JITContext::default();
153 if let Some(state) = foreign_bridge.as_ref() {
154 jit_ctx.foreign_bridge_ptr = state.as_ref() as *const _ as *const std::ffi::c_void;
155 }
156
157 {
159 let runtime = engine.get_runtime_mut();
160 if let Some(ctx) = runtime.persistent_context_mut() {
161 jit_ctx.exec_context_ptr = ctx as *mut _ as *mut std::ffi::c_void;
162 }
163 }
164
165 if std::env::var_os("SHAPE_JIT_DEBUG").is_some() {
167 eprintln!("[jit-debug] compilation OK, about to execute...");
168 }
169 let jit_exec_start = Instant::now();
170 let signal = unsafe { jit_fn(&mut jit_ctx) };
171 let jit_exec_ms = jit_exec_start.elapsed().as_millis();
172
173 let raw_result = if jit_ctx.stack_ptr > 0 {
175 jit_ctx.stack[0]
176 } else {
177 crate::nan_boxing::TAG_NULL
178 };
179
180 if signal < 0 {
182 return Err(shape_runtime::error::ShapeError::RuntimeError {
183 message: format!("JIT execution error (code: {})", signal),
184 location: None,
185 });
186 }
187
188 let return_hint = bytecode.top_level_frame.as_ref().and_then(|fd| {
191 if fd.return_kind != shape_vm::type_tracking::SlotKind::Unknown {
192 Some(fd.return_kind)
193 } else {
194 fd.slots.last().copied()
195 }
196 });
197 let result_scalar =
198 crate::ffi::object::conversion::jit_bits_to_typed_scalar(raw_result, return_hint);
199
200 let wire_value = self.typed_scalar_to_wire(&result_scalar, raw_result);
202
203 if emit_phase_metrics {
204 let total_ms = bytecode_compile_ms + jit_compile_ms + jit_exec_ms;
205 eprintln!(
206 "[shape-jit-phases] bytecode_compile_ms={} jit_compile_ms={} jit_exec_ms={} total_ms={}",
207 bytecode_compile_ms, jit_compile_ms, jit_exec_ms, total_ms
208 );
209 }
210
211 Ok(shape_runtime::engine::ProgramExecutorResult {
212 wire_value,
213 type_info: None,
214 execution_type: ExecutionType::Script,
215 content_json: None,
216 content_html: None,
217 content_terminal: None,
218 })
219 }
220
221 fn typed_scalar_to_wire(&self, ts: &shape_value::TypedScalar, raw_bits: u64) -> WireValue {
226 use shape_value::ScalarKind;
227
228 match ts.kind {
229 ScalarKind::I8
230 | ScalarKind::I16
231 | ScalarKind::I32
232 | ScalarKind::I64
233 | ScalarKind::U8
234 | ScalarKind::U16
235 | ScalarKind::U32
236 | ScalarKind::U64
237 | ScalarKind::I128
238 | ScalarKind::U128 => {
239 WireValue::Number(ts.payload_lo as i64 as f64)
241 }
242 ScalarKind::F64 | ScalarKind::F32 => WireValue::Number(f64::from_bits(ts.payload_lo)),
243 ScalarKind::Bool => WireValue::Bool(ts.payload_lo != 0),
244 ScalarKind::Unit => WireValue::Null,
245 ScalarKind::None => {
246 self.nan_boxed_to_wire(raw_bits)
249 }
250 }
251 }
252
253 fn nan_boxed_to_wire(&self, bits: u64) -> WireValue {
254 use crate::nan_boxing::{
255 HK_STRING, TAG_BOOL_FALSE, TAG_BOOL_TRUE, TAG_NULL, is_heap_kind, is_number, jit_unbox,
256 unbox_number,
257 };
258 use shape_value::tags::{TAG_INT, get_payload, get_tag, is_tagged, sign_extend_i48};
259
260 if is_number(bits) {
261 WireValue::Number(unbox_number(bits))
262 } else if bits == TAG_NULL {
263 WireValue::Null
264 } else if bits == TAG_BOOL_TRUE {
265 WireValue::Bool(true)
266 } else if bits == TAG_BOOL_FALSE {
267 WireValue::Bool(false)
268 } else if is_tagged(bits) && get_tag(bits) == TAG_INT {
269 let int_val = sign_extend_i48(get_payload(bits));
271 WireValue::Integer(int_val)
272 } else if is_heap_kind(bits, HK_STRING) {
273 let s = unsafe { jit_unbox::<String>(bits) };
274 WireValue::String(s.clone())
275 } else {
276 WireValue::Number(f64::from_bits(bits))
278 }
279 }
280}