1use shape_vm::bytecode::BytecodeProgram;
7use shape_vm::tier::{CompilationBackend, CompilationRequest, CompilationResult, Tier};
8use shape_vm::type_tracking::FrameDescriptor;
9
10use crate::compiler::JITCompiler;
11use crate::context::JITConfig;
12use crate::translator::loop_analysis;
13use crate::translator::osr_compiler;
14
/// Compilation backend that owns a [`JITCompiler`] and uses it to service
/// [`CompilationRequest`]s through the [`CompilationBackend`] trait.
pub struct JitCompilationBackend {
    /// Compiler instance shared by every whole-function and OSR compilation.
    jit: JITCompiler,
}
22
23impl JitCompilationBackend {
24 pub fn new() -> Result<Self, crate::error::JitError> {
26 Ok(Self {
27 jit: JITCompiler::new(JITConfig::default())?,
28 })
29 }
30
31 pub fn with_config(config: JITConfig) -> Result<Self, crate::error::JitError> {
33 Ok(Self {
34 jit: JITCompiler::new(config)?,
35 })
36 }
37
38 fn compile_osr(
40 &mut self,
41 request: &CompilationRequest,
42 program: &BytecodeProgram,
43 ) -> CompilationResult {
44 let func_id = request.function_id;
45 let loop_header_ip = request.loop_header_ip;
46
47 let function = match program.functions.get(func_id as usize) {
49 Some(f) => f,
50 None => {
51 return CompilationResult {
52 function_id: func_id,
53 compiled_tier: Tier::Interpreted,
54 native_code: None,
55 error: Some(format!("Function {} not found in program", func_id)),
56 osr_entry: None,
57 deopt_points: Vec::new(),
58 loop_header_ip,
59 shape_guards: Vec::new(),
60 };
61 }
62 };
63
64 let entry = function.entry_point;
66 let end = find_function_end(program, func_id as usize);
67 if entry >= program.instructions.len() || end > program.instructions.len() {
68 return CompilationResult {
69 function_id: func_id,
70 compiled_tier: Tier::Interpreted,
71 native_code: None,
72 error: Some(format!(
73 "Function {} instruction range [{}, {}) out of bounds",
74 func_id, entry, end
75 )),
76 osr_entry: None,
77 deopt_points: Vec::new(),
78 loop_header_ip,
79 shape_guards: Vec::new(),
80 };
81 }
82 let func_instructions = &program.instructions[entry..end];
83
84 let sub_program = build_sub_program(program, entry, end);
86 let loop_infos = loop_analysis::analyze_loops(&sub_program);
87
88 let target_local_ip = match loop_header_ip {
91 Some(ip) => {
92 if ip < entry {
93 return CompilationResult {
94 function_id: func_id,
95 compiled_tier: Tier::Interpreted,
96 native_code: None,
97 error: Some(format!(
98 "OSR loop header IP {} is before function entry {}",
99 ip, entry
100 )),
101 osr_entry: None,
102 deopt_points: Vec::new(),
103 loop_header_ip: Some(ip),
104 shape_guards: Vec::new(),
105 };
106 }
107 ip - entry
108 }
109 None => {
110 return CompilationResult {
111 function_id: func_id,
112 compiled_tier: Tier::Interpreted,
113 native_code: None,
114 error: Some("OSR request without loop_header_ip".to_string()),
115 osr_entry: None,
116 deopt_points: Vec::new(),
117 loop_header_ip: None,
118 shape_guards: Vec::new(),
119 };
120 }
121 };
122
123 let loop_info = match loop_infos.get(&target_local_ip) {
124 Some(li) => li,
125 None => {
126 return CompilationResult {
127 function_id: func_id,
128 compiled_tier: Tier::Interpreted,
129 native_code: None,
130 error: Some(format!(
131 "No loop found at local IP {} (global IP {:?})",
132 target_local_ip, loop_header_ip
133 )),
134 osr_entry: None,
135 deopt_points: Vec::new(),
136 loop_header_ip,
137 shape_guards: Vec::new(),
138 };
139 }
140 };
141
142 let default_frame = FrameDescriptor::default();
144 let frame_descriptor = function.frame_descriptor.as_ref().unwrap_or(&default_frame);
145
146 match osr_compiler::compile_osr_loop(
148 &mut self.jit,
149 function,
150 func_instructions,
151 loop_info,
152 frame_descriptor,
153 ) {
154 Ok(osr_result) => {
155 let mut entry_point = osr_result.entry_point;
157 entry_point.bytecode_ip += entry;
158 entry_point.exit_ip += entry;
159
160 CompilationResult {
161 function_id: func_id,
162 compiled_tier: Tier::BaselineJit,
163 native_code: Some(osr_result.native_code),
164 error: None,
165 osr_entry: Some(entry_point),
166 deopt_points: osr_result.deopt_points,
167 loop_header_ip,
168 shape_guards: Vec::new(),
169 }
170 }
171 Err(e) => CompilationResult {
172 function_id: func_id,
173 compiled_tier: Tier::Interpreted,
174 native_code: None,
175 error: Some(e),
176 osr_entry: None,
177 deopt_points: Vec::new(),
178 loop_header_ip,
179 shape_guards: Vec::new(),
180 },
181 }
182 }
183}
184
// SAFETY: NOTE(review): this asserts that `JitCompilationBackend`, and therefore the
// `JITCompiler` it owns, may be moved to another thread. Nothing in this file
// demonstrates why that holds; confirm that `JITCompiler` keeps no thread-affine
// state and record the actual invariant here.
unsafe impl Send for JitCompilationBackend {}
190
191impl JitCompilationBackend {
192 fn compile_function(
202 &mut self,
203 request: &CompilationRequest,
204 program: &BytecodeProgram,
205 ) -> CompilationResult {
206 let func_id = request.function_id;
207
208 if let Some(fv) = request.feedback.clone() {
210 return match self
211 .jit
212 .compile_optimizing_function(program, func_id as usize, fv, &request.callee_feedback)
213 {
214 Ok((code_ptr, deopt_points, shape_guards)) => CompilationResult {
215 function_id: func_id,
216 compiled_tier: request.target_tier,
217 native_code: Some(code_ptr),
218 error: None,
219 osr_entry: None,
220 deopt_points,
221 loop_header_ip: None,
222 shape_guards,
223 },
224 Err(e) => CompilationResult {
225 function_id: func_id,
226 compiled_tier: Tier::Interpreted,
227 native_code: None,
228 error: Some(e),
229 osr_entry: None,
230 deopt_points: Vec::new(),
231 loop_header_ip: None,
232 shape_guards: Vec::new(),
233 },
234 };
235 }
236
237 match self
239 .jit
240 .compile_single_function(program, func_id as usize, None)
241 {
242 Ok((code_ptr, deopt_points, shape_guards)) => CompilationResult {
243 function_id: func_id,
244 compiled_tier: request.target_tier,
245 native_code: Some(code_ptr),
246 error: None,
247 osr_entry: None,
248 deopt_points,
249 loop_header_ip: None,
250 shape_guards,
251 },
252 Err(e) => CompilationResult {
253 function_id: func_id,
254 compiled_tier: Tier::Interpreted,
255 native_code: None,
256 error: Some(e),
257 osr_entry: None,
258 deopt_points: Vec::new(),
259 loop_header_ip: None,
260 shape_guards: Vec::new(),
261 },
262 }
263 }
264}
265
266impl CompilationBackend for JitCompilationBackend {
267 fn compile(
268 &mut self,
269 request: &CompilationRequest,
270 program: &BytecodeProgram,
271 ) -> CompilationResult {
272 if request.osr {
273 self.compile_osr(request, program)
274 } else {
275 self.compile_function(request, program)
276 }
277 }
278}
279
280fn find_function_end(program: &BytecodeProgram, func_index: usize) -> usize {
285 let func = &program.functions[func_index];
286 func.entry_point + func.body_length
287}
288
289fn build_sub_program(program: &BytecodeProgram, start: usize, end: usize) -> BytecodeProgram {
294 BytecodeProgram {
295 instructions: program.instructions[start..end].to_vec(),
296 constants: program.constants.clone(),
297 strings: program.strings.clone(),
298 functions: vec![],
299 debug_info: Default::default(),
300 data_schema: None,
301 module_binding_names: vec![],
302 top_level_locals_count: 0,
303 top_level_local_storage_hints: vec![],
304 type_schema_registry: Default::default(),
305 module_binding_storage_hints: vec![],
306 function_local_storage_hints: vec![],
307 compiled_annotations: Default::default(),
308 trait_method_symbols: Default::default(),
309 expanded_function_defs: Default::default(),
310 string_index: Default::default(),
311 foreign_functions: Vec::new(),
312 native_struct_layouts: vec![],
313 content_addressed: None,
314 function_blob_hashes: vec![],
315 top_level_frame: None,
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use shape_vm::bytecode::*;
    use shape_vm::type_tracking::{FrameDescriptor, SlotKind};

    /// Builds a single instruction from an opcode and optional operand.
    fn make_instr(opcode: OpCode, operand: Option<Operand>) -> Instruction {
        Instruction { opcode, operand }
    }

    /// Builds a plain (non-closure, non-async) function starting at
    /// instruction 0 with zero arity and zero locals. Tests override `arity`
    /// and `locals_count` with struct-update syntax where they matter.
    fn make_function(name: &str, body_length: usize, slots: Vec<SlotKind>) -> Function {
        Function {
            name: name.to_string(),
            arity: 0,
            param_names: vec![],
            locals_count: 0,
            entry_point: 0,
            body_length,
            is_closure: false,
            captures_count: 0,
            is_async: false,
            ref_params: vec![],
            ref_mutates: vec![],
            mutable_captures: vec![],
            frame_descriptor: Some(FrameDescriptor::from_slots(slots)),
            osr_entry_points: vec![],
        }
    }

    /// Builds a program from the three parts the tests vary; every other
    /// field comes from `BytecodeProgram::default()`.
    fn make_program(
        instructions: Vec<Instruction>,
        constants: Vec<Constant>,
        functions: Vec<Function>,
    ) -> BytecodeProgram {
        BytecodeProgram {
            instructions,
            constants,
            functions,
            ..Default::default()
        }
    }

    /// Builds a baseline-JIT request for function 0 with no feedback attached.
    fn make_request(osr: bool, loop_header_ip: Option<usize>) -> CompilationRequest {
        CompilationRequest {
            function_id: 0,
            target_tier: Tier::BaselineJit,
            blob_hash: None,
            osr,
            loop_header_ip,
            feedback: None,
            callee_feedback: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn test_backend_compiles_whole_function() {
        let mut backend = JitCompilationBackend::new().unwrap();

        let instrs = vec![
            make_instr(OpCode::LoadLocal, Some(Operand::Local(0))),
            make_instr(OpCode::LoadLocal, Some(Operand::Local(1))),
            make_instr(OpCode::AddInt, None),
            make_instr(OpCode::ReturnValue, None),
            make_instr(OpCode::Halt, None),
        ];

        let func = Function {
            arity: 2,
            locals_count: 2,
            ..make_function("add_two", 4, vec![SlotKind::Int64, SlotKind::Int64])
        };

        let program = make_program(instrs, vec![], vec![func]);
        let request = make_request(false, None);

        let result = backend.compile(&request, &program);
        assert!(
            result.error.is_none(),
            "Expected successful whole-function compilation, got: {:?}",
            result.error
        );
        assert!(result.native_code.is_some());
        assert_eq!(result.compiled_tier, Tier::BaselineJit);
        // Whole-function compilation never produces an OSR entry.
        assert!(result.osr_entry.is_none());
    }

    #[test]
    fn test_backend_whole_function_invalid_id() {
        let mut backend = JitCompilationBackend::new().unwrap();

        // The program defines no functions, so any function id is invalid.
        let program = make_program(vec![make_instr(OpCode::Halt, None)], vec![], vec![]);
        let request = CompilationRequest {
            function_id: 99,
            ..make_request(false, None)
        };

        let result = backend.compile(&request, &program);
        assert!(result.error.is_some());
        assert!(result.error.unwrap().contains("not found"));
    }

    #[test]
    fn test_backend_osr_compiles_simple_loop() {
        let mut backend = JitCompilationBackend::new().unwrap();

        let instrs = vec![
            make_instr(OpCode::LoopStart, None),
            make_instr(OpCode::LoadLocal, Some(Operand::Local(0))),
            make_instr(OpCode::LoadLocal, Some(Operand::Local(1))),
            make_instr(OpCode::LtInt, None),
            make_instr(OpCode::JumpIfFalse, Some(Operand::Offset(7))),
            make_instr(OpCode::LoadLocal, Some(Operand::Local(2))),
            make_instr(OpCode::LoadLocal, Some(Operand::Local(0))),
            make_instr(OpCode::AddInt, None),
            make_instr(OpCode::StoreLocal, Some(Operand::Local(2))),
            make_instr(OpCode::LoadLocal, Some(Operand::Local(0))),
            make_instr(OpCode::PushConst, Some(Operand::Const(0))),
            make_instr(OpCode::AddInt, None),
            make_instr(OpCode::StoreLocal, Some(Operand::Local(0))),
            make_instr(OpCode::LoopEnd, None),
            make_instr(OpCode::ReturnValue, None),
        ];

        let func = Function {
            locals_count: 3,
            ..make_function(
                "test_loop",
                15,
                vec![SlotKind::Int64, SlotKind::Int64, SlotKind::Int64],
            )
        };

        let program = make_program(instrs, vec![Constant::Int(1)], vec![func]);
        let request = make_request(true, Some(0));

        let result = backend.compile(&request, &program);
        assert!(
            result.error.is_none(),
            "Expected successful compilation, got: {:?}",
            result.error
        );
        assert!(result.native_code.is_some());
        assert!(result.osr_entry.is_some());
        assert_eq!(result.compiled_tier, Tier::BaselineJit);

        let entry = result.osr_entry.unwrap();
        assert_eq!(entry.bytecode_ip, 0);
        assert!(entry.live_locals.contains(&0));
        assert!(entry.live_locals.contains(&1));
        assert!(entry.live_locals.contains(&2));
    }

    #[test]
    fn test_backend_osr_blacklists_unsupported_loop() {
        let mut backend = JitCompilationBackend::new().unwrap();

        let instrs = vec![
            make_instr(OpCode::LoopStart, None),
            make_instr(OpCode::LoadLocal, Some(Operand::Local(0))),
            make_instr(OpCode::CallMethod, None),
            make_instr(OpCode::Pop, None),
            make_instr(OpCode::LoopEnd, None),
            make_instr(OpCode::Halt, None),
        ];

        let func = Function {
            locals_count: 1,
            ..make_function("unsupported_loop", 6, vec![SlotKind::Unknown])
        };

        let program = make_program(instrs, vec![], vec![func]);
        let request = make_request(true, Some(0));

        let result = backend.compile(&request, &program);
        assert!(result.error.is_some());
        assert!(result.error.unwrap().contains("unsupported opcode"));
        // The failing loop header is echoed back in the result.
        assert_eq!(result.loop_header_ip, Some(0));
    }
}