1use shape_vm::bytecode::BytecodeProgram;
7use shape_vm::tier::{CompilationBackend, CompilationRequest, CompilationResult, Tier};
8use shape_vm::type_tracking::FrameDescriptor;
9
10use crate::compiler::JITCompiler;
11use crate::context::JITConfig;
12use crate::translator::loop_analysis;
13use crate::translator::osr_compiler;
14
/// [`CompilationBackend`] implementation backed by a [`JITCompiler`].
///
/// Serves both whole-function requests and OSR (single hot loop) requests;
/// the `CompilationBackend::compile` impl below chooses between them based on
/// `CompilationRequest::osr`.
pub struct JitCompilationBackend {
    /// The JIT compiler used for every compilation this backend performs.
    jit: JITCompiler,
}
22
23impl JitCompilationBackend {
24 pub fn new() -> Result<Self, crate::error::JitError> {
26 Ok(Self {
27 jit: JITCompiler::new(JITConfig::default())?,
28 })
29 }
30
31 pub fn with_config(config: JITConfig) -> Result<Self, crate::error::JitError> {
33 Ok(Self {
34 jit: JITCompiler::new(config)?,
35 })
36 }
37
38 fn compile_osr(
40 &mut self,
41 request: &CompilationRequest,
42 program: &BytecodeProgram,
43 ) -> CompilationResult {
44 let func_id = request.function_id;
45 let loop_header_ip = request.loop_header_ip;
46
47 let function = match program.functions.get(func_id as usize) {
49 Some(f) => f,
50 None => {
51 return CompilationResult {
52 function_id: func_id,
53 compiled_tier: Tier::Interpreted,
54 native_code: None,
55 error: Some(format!("Function {} not found in program", func_id)),
56 osr_entry: None,
57 deopt_points: Vec::new(),
58 loop_header_ip,
59 shape_guards: Vec::new(),
60 };
61 }
62 };
63
64 let entry = function.entry_point;
66 let end = find_function_end(program, func_id as usize);
67 if entry >= program.instructions.len() || end > program.instructions.len() {
68 return CompilationResult {
69 function_id: func_id,
70 compiled_tier: Tier::Interpreted,
71 native_code: None,
72 error: Some(format!(
73 "Function {} instruction range [{}, {}) out of bounds",
74 func_id, entry, end
75 )),
76 osr_entry: None,
77 deopt_points: Vec::new(),
78 loop_header_ip,
79 shape_guards: Vec::new(),
80 };
81 }
82 let func_instructions = &program.instructions[entry..end];
83
84 let sub_program = build_sub_program(program, entry, end);
86 let loop_infos = loop_analysis::analyze_loops(&sub_program);
87
88 let target_local_ip = match loop_header_ip {
91 Some(ip) => {
92 if ip < entry {
93 return CompilationResult {
94 function_id: func_id,
95 compiled_tier: Tier::Interpreted,
96 native_code: None,
97 error: Some(format!(
98 "OSR loop header IP {} is before function entry {}",
99 ip, entry
100 )),
101 osr_entry: None,
102 deopt_points: Vec::new(),
103 loop_header_ip: Some(ip),
104 shape_guards: Vec::new(),
105 };
106 }
107 ip - entry
108 }
109 None => {
110 return CompilationResult {
111 function_id: func_id,
112 compiled_tier: Tier::Interpreted,
113 native_code: None,
114 error: Some("OSR request without loop_header_ip".to_string()),
115 osr_entry: None,
116 deopt_points: Vec::new(),
117 loop_header_ip: None,
118 shape_guards: Vec::new(),
119 };
120 }
121 };
122
123 let loop_info = match loop_infos.get(&target_local_ip) {
124 Some(li) => li,
125 None => {
126 return CompilationResult {
127 function_id: func_id,
128 compiled_tier: Tier::Interpreted,
129 native_code: None,
130 error: Some(format!(
131 "No loop found at local IP {} (global IP {:?})",
132 target_local_ip, loop_header_ip
133 )),
134 osr_entry: None,
135 deopt_points: Vec::new(),
136 loop_header_ip,
137 shape_guards: Vec::new(),
138 };
139 }
140 };
141
142 let default_frame = FrameDescriptor::default();
144 let frame_descriptor = function.frame_descriptor.as_ref().unwrap_or(&default_frame);
145
146 match osr_compiler::compile_osr_loop(
148 &mut self.jit,
149 function,
150 func_instructions,
151 loop_info,
152 frame_descriptor,
153 ) {
154 Ok(osr_result) => {
155 let mut entry_point = osr_result.entry_point;
157 entry_point.bytecode_ip += entry;
158 entry_point.exit_ip += entry;
159
160 CompilationResult {
161 function_id: func_id,
162 compiled_tier: Tier::BaselineJit,
163 native_code: Some(osr_result.native_code),
164 error: None,
165 osr_entry: Some(entry_point),
166 deopt_points: osr_result.deopt_points,
167 loop_header_ip,
168 shape_guards: Vec::new(),
169 }
170 }
171 Err(e) => CompilationResult {
172 function_id: func_id,
173 compiled_tier: Tier::Interpreted,
174 native_code: None,
175 error: Some(e),
176 osr_entry: None,
177 deopt_points: Vec::new(),
178 loop_header_ip,
179 shape_guards: Vec::new(),
180 },
181 }
182 }
183}
184
// SAFETY: NOTE(review): this manually asserts that `JitCompilationBackend` may be
// moved to another thread. Its only field is a `JITCompiler`, which presumably is
// not `Send` automatically (otherwise this impl would be redundant). Soundness
// therefore depends on `JITCompiler` holding no thread-affine state, such as
// thread-local handles or non-atomically shared pointers. TODO confirm against
// `JITCompiler`'s fields. This impl asserts only `Send`, not `Sync`.
unsafe impl Send for JitCompilationBackend {}
190
191impl JitCompilationBackend {
192 fn compile_function(
202 &mut self,
203 request: &CompilationRequest,
204 program: &BytecodeProgram,
205 ) -> CompilationResult {
206 let func_id = request.function_id;
207
208 if let Some(fv) = request.feedback.clone() {
210 return match self.jit.compile_optimizing_function(
211 program,
212 func_id as usize,
213 fv,
214 &request.callee_feedback,
215 ) {
216 Ok((code_ptr, deopt_points, shape_guards)) => CompilationResult {
217 function_id: func_id,
218 compiled_tier: request.target_tier,
219 native_code: Some(code_ptr),
220 error: None,
221 osr_entry: None,
222 deopt_points,
223 loop_header_ip: None,
224 shape_guards,
225 },
226 Err(e) => CompilationResult {
227 function_id: func_id,
228 compiled_tier: Tier::Interpreted,
229 native_code: None,
230 error: Some(e),
231 osr_entry: None,
232 deopt_points: Vec::new(),
233 loop_header_ip: None,
234 shape_guards: Vec::new(),
235 },
236 };
237 }
238
239 match self
241 .jit
242 .compile_single_function(program, func_id as usize, None)
243 {
244 Ok((code_ptr, deopt_points, shape_guards)) => CompilationResult {
245 function_id: func_id,
246 compiled_tier: request.target_tier,
247 native_code: Some(code_ptr),
248 error: None,
249 osr_entry: None,
250 deopt_points,
251 loop_header_ip: None,
252 shape_guards,
253 },
254 Err(e) => CompilationResult {
255 function_id: func_id,
256 compiled_tier: Tier::Interpreted,
257 native_code: None,
258 error: Some(e),
259 osr_entry: None,
260 deopt_points: Vec::new(),
261 loop_header_ip: None,
262 shape_guards: Vec::new(),
263 },
264 }
265 }
266}
267
268impl CompilationBackend for JitCompilationBackend {
269 fn compile(
270 &mut self,
271 request: &CompilationRequest,
272 program: &BytecodeProgram,
273 ) -> CompilationResult {
274 if request.osr {
275 self.compile_osr(request, program)
276 } else {
277 self.compile_function(request, program)
278 }
279 }
280}
281
282fn find_function_end(program: &BytecodeProgram, func_index: usize) -> usize {
287 let func = &program.functions[func_index];
288 func.entry_point + func.body_length
289}
290
291fn build_sub_program(program: &BytecodeProgram, start: usize, end: usize) -> BytecodeProgram {
296 BytecodeProgram {
297 instructions: program.instructions[start..end].to_vec(),
298 constants: program.constants.clone(),
299 strings: program.strings.clone(),
300 functions: vec![],
301 debug_info: Default::default(),
302 data_schema: None,
303 module_binding_names: vec![],
304 top_level_locals_count: 0,
305 top_level_local_storage_hints: vec![],
306 type_schema_registry: Default::default(),
307 module_binding_storage_hints: vec![],
308 function_local_storage_hints: vec![],
309 compiled_annotations: Default::default(),
310 trait_method_symbols: Default::default(),
311 expanded_function_defs: Default::default(),
312 string_index: Default::default(),
313 foreign_functions: Vec::new(),
314 native_struct_layouts: vec![],
315 content_addressed: None,
316 function_blob_hashes: vec![],
317 top_level_frame: None,
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use shape_vm::bytecode::*;
    use shape_vm::type_tracking::{FrameDescriptor, SlotKind};

    fn make_instr(opcode: OpCode, operand: Option<Operand>) -> Instruction {
        Instruction { opcode, operand }
    }

    /// Builds a program from just `instructions`, `constants` and `functions`;
    /// every other field takes its `Default` value.
    fn make_program(
        instructions: Vec<Instruction>,
        constants: Vec<Constant>,
        functions: Vec<Function>,
    ) -> BytecodeProgram {
        BytecodeProgram {
            instructions,
            constants,
            functions,
            ..Default::default()
        }
    }

    /// Builds a `BaselineJit` request for function 0 with no type feedback.
    fn make_request(osr: bool, loop_header_ip: Option<usize>) -> CompilationRequest {
        CompilationRequest {
            function_id: 0,
            target_tier: Tier::BaselineJit,
            blob_hash: None,
            osr,
            loop_header_ip,
            feedback: None,
            callee_feedback: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn test_backend_compiles_whole_function() {
        let mut backend = JitCompilationBackend::new().unwrap();

        let instrs = vec![
            make_instr(OpCode::LoadLocal, Some(Operand::Local(0))),
            make_instr(OpCode::LoadLocal, Some(Operand::Local(1))),
            make_instr(OpCode::AddInt, None),
            make_instr(OpCode::ReturnValue, None),
            // Index 4 lies outside the function body (body_length = 4).
            make_instr(OpCode::Halt, None),
        ];

        let func = Function {
            name: "add_two".to_string(),
            arity: 2,
            param_names: vec![],
            locals_count: 2,
            entry_point: 0,
            body_length: 4,
            is_closure: false,
            captures_count: 0,
            is_async: false,
            ref_params: vec![],
            ref_mutates: vec![],
            mutable_captures: vec![],
            frame_descriptor: Some(FrameDescriptor::from_slots(vec![
                SlotKind::Int64,
                SlotKind::Int64,
            ])),
            osr_entry_points: vec![],
        };

        let program = make_program(instrs, vec![], vec![func]);
        let result = backend.compile(&make_request(false, None), &program);

        assert!(
            result.error.is_none(),
            "Expected successful whole-function compilation, got: {:?}",
            result.error
        );
        assert!(result.native_code.is_some());
        assert_eq!(result.compiled_tier, Tier::BaselineJit);
        // Whole-function compilation never produces an OSR entry.
        assert!(result.osr_entry.is_none());
    }

    #[test]
    fn test_backend_whole_function_invalid_id() {
        let mut backend = JitCompilationBackend::new().unwrap();

        // No functions at all, so any function id is invalid.
        let program = make_program(vec![make_instr(OpCode::Halt, None)], vec![], vec![]);
        let request = CompilationRequest {
            function_id: 99,
            ..make_request(false, None)
        };

        let result = backend.compile(&request, &program);
        assert!(result.error.is_some());
        assert!(result.error.unwrap().contains("not found"));
    }

    #[test]
    fn test_backend_osr_compiles_simple_loop() {
        let mut backend = JitCompilationBackend::new().unwrap();

        let instrs = vec![
            make_instr(OpCode::LoopStart, None), // 0: loop header
            // 1-4: condition `local0 < local1`
            make_instr(OpCode::LoadLocal, Some(Operand::Local(0))),
            make_instr(OpCode::LoadLocal, Some(Operand::Local(1))),
            make_instr(OpCode::LtInt, None),
            make_instr(OpCode::JumpIfFalse, Some(Operand::Offset(7))),
            // 5-8: local2 = local2 + local0
            make_instr(OpCode::LoadLocal, Some(Operand::Local(2))),
            make_instr(OpCode::LoadLocal, Some(Operand::Local(0))),
            make_instr(OpCode::AddInt, None),
            make_instr(OpCode::StoreLocal, Some(Operand::Local(2))),
            // 9-12: local0 = local0 + constants[0]
            make_instr(OpCode::LoadLocal, Some(Operand::Local(0))),
            make_instr(OpCode::PushConst, Some(Operand::Const(0))),
            make_instr(OpCode::AddInt, None),
            make_instr(OpCode::StoreLocal, Some(Operand::Local(0))),
            make_instr(OpCode::LoopEnd, None),     // 13
            make_instr(OpCode::ReturnValue, None), // 14
        ];

        let func = Function {
            name: "test_loop".to_string(),
            arity: 0,
            param_names: vec![],
            locals_count: 3,
            entry_point: 0,
            body_length: 15,
            is_closure: false,
            captures_count: 0,
            is_async: false,
            ref_params: vec![],
            ref_mutates: vec![],
            mutable_captures: vec![],
            frame_descriptor: Some(FrameDescriptor::from_slots(vec![
                SlotKind::Int64,
                SlotKind::Int64,
                SlotKind::Int64,
            ])),
            osr_entry_points: vec![],
        };

        let program = make_program(instrs, vec![Constant::Int(1)], vec![func]);
        let result = backend.compile(&make_request(true, Some(0)), &program);

        assert!(
            result.error.is_none(),
            "Expected successful compilation, got: {:?}",
            result.error
        );
        assert!(result.native_code.is_some());
        assert!(result.osr_entry.is_some());
        assert_eq!(result.compiled_tier, Tier::BaselineJit);

        let entry = result.osr_entry.unwrap();
        assert_eq!(entry.bytecode_ip, 0);
        // All three locals are read inside the loop.
        assert!(entry.live_locals.contains(&0));
        assert!(entry.live_locals.contains(&1));
        assert!(entry.live_locals.contains(&2));
    }

    #[test]
    fn test_backend_osr_blacklists_unsupported_loop() {
        let mut backend = JitCompilationBackend::new().unwrap();

        let instrs = vec![
            make_instr(OpCode::LoopStart, None),
            make_instr(OpCode::LoadLocal, Some(Operand::Local(0))),
            // The OSR compiler is expected to reject this opcode.
            make_instr(OpCode::CallMethod, None),
            make_instr(OpCode::Pop, None),
            make_instr(OpCode::LoopEnd, None),
            make_instr(OpCode::Halt, None),
        ];

        let func = Function {
            name: "unsupported_loop".to_string(),
            arity: 0,
            param_names: vec![],
            locals_count: 1,
            entry_point: 0,
            body_length: 6,
            is_closure: false,
            captures_count: 0,
            is_async: false,
            ref_params: vec![],
            ref_mutates: vec![],
            mutable_captures: vec![],
            frame_descriptor: Some(FrameDescriptor::from_slots(vec![SlotKind::Unknown])),
            osr_entry_points: vec![],
        };

        let program = make_program(instrs, vec![], vec![func]);
        let result = backend.compile(&make_request(true, Some(0)), &program);

        assert!(result.error.is_some());
        assert!(result.error.unwrap().contains("unsupported opcode"));
        // Failed OSR results still echo the requested loop header.
        assert_eq!(result.loop_header_ip, Some(0));
    }
}