Skip to main content

shape_jit/
worker.rs

1//! JIT Compilation Backend
2//!
3//! Implements the `CompilationBackend` trait from `shape-vm` so the TierManager
4//! can drive JIT compilation on a background worker thread.
5
6use shape_vm::bytecode::BytecodeProgram;
7use shape_vm::tier::{CompilationBackend, CompilationRequest, CompilationResult, Tier};
8use shape_vm::type_tracking::FrameDescriptor;
9
10use crate::compiler::JITCompiler;
11use crate::context::JITConfig;
12use crate::translator::loop_analysis;
13use crate::translator::osr_compiler;
14
15/// JIT compilation backend that compiles hot loops to native code via Cranelift.
16///
17/// Owns a `JITCompiler` instance and implements the `CompilationBackend` trait.
18/// The `TierManager::set_backend()` spawns a worker thread that drives this.
19pub struct JitCompilationBackend {
20    jit: JITCompiler,
21}
22
23impl JitCompilationBackend {
24    /// Create a new JIT compilation backend with default configuration.
25    pub fn new() -> Result<Self, crate::error::JitError> {
26        Ok(Self {
27            jit: JITCompiler::new(JITConfig::default())?,
28        })
29    }
30
31    /// Create a new JIT compilation backend with custom configuration.
32    pub fn with_config(config: JITConfig) -> Result<Self, crate::error::JitError> {
33        Ok(Self {
34            jit: JITCompiler::new(config)?,
35        })
36    }
37
38    /// Compile an OSR loop from a compilation request.
39    fn compile_osr(
40        &mut self,
41        request: &CompilationRequest,
42        program: &BytecodeProgram,
43    ) -> CompilationResult {
44        let func_id = request.function_id;
45        let loop_header_ip = request.loop_header_ip;
46
47        // Get the target function
48        let function = match program.functions.get(func_id as usize) {
49            Some(f) => f,
50            None => {
51                return CompilationResult {
52                    function_id: func_id,
53                    compiled_tier: Tier::Interpreted,
54                    native_code: None,
55                    error: Some(format!("Function {} not found in program", func_id)),
56                    osr_entry: None,
57                    deopt_points: Vec::new(),
58                    loop_header_ip,
59                    shape_guards: Vec::new(),
60                };
61            }
62        };
63
64        // Extract the function's instruction range
65        let entry = function.entry_point;
66        let end = find_function_end(program, func_id as usize);
67        if entry >= program.instructions.len() || end > program.instructions.len() {
68            return CompilationResult {
69                function_id: func_id,
70                compiled_tier: Tier::Interpreted,
71                native_code: None,
72                error: Some(format!(
73                    "Function {} instruction range [{}, {}) out of bounds",
74                    func_id, entry, end
75                )),
76                osr_entry: None,
77                deopt_points: Vec::new(),
78                loop_header_ip,
79                shape_guards: Vec::new(),
80            };
81        }
82        let func_instructions = &program.instructions[entry..end];
83
84        // Run loop analysis on a sub-program containing just this function's instructions
85        let sub_program = build_sub_program(program, entry, end);
86        let loop_infos = loop_analysis::analyze_loops(&sub_program);
87
88        // Find the target loop. The loop_header_ip from the request is in
89        // global instruction coordinates; convert to function-local offset.
90        let target_local_ip = match loop_header_ip {
91            Some(ip) => {
92                if ip < entry {
93                    return CompilationResult {
94                        function_id: func_id,
95                        compiled_tier: Tier::Interpreted,
96                        native_code: None,
97                        error: Some(format!(
98                            "OSR loop header IP {} is before function entry {}",
99                            ip, entry
100                        )),
101                        osr_entry: None,
102                        deopt_points: Vec::new(),
103                        loop_header_ip: Some(ip),
104                        shape_guards: Vec::new(),
105                    };
106                }
107                ip - entry
108            }
109            None => {
110                return CompilationResult {
111                    function_id: func_id,
112                    compiled_tier: Tier::Interpreted,
113                    native_code: None,
114                    error: Some("OSR request without loop_header_ip".to_string()),
115                    osr_entry: None,
116                    deopt_points: Vec::new(),
117                    loop_header_ip: None,
118                    shape_guards: Vec::new(),
119                };
120            }
121        };
122
123        let loop_info = match loop_infos.get(&target_local_ip) {
124            Some(li) => li,
125            None => {
126                return CompilationResult {
127                    function_id: func_id,
128                    compiled_tier: Tier::Interpreted,
129                    native_code: None,
130                    error: Some(format!(
131                        "No loop found at local IP {} (global IP {:?})",
132                        target_local_ip, loop_header_ip
133                    )),
134                    osr_entry: None,
135                    deopt_points: Vec::new(),
136                    loop_header_ip,
137                    shape_guards: Vec::new(),
138                };
139            }
140        };
141
142        // Build frame descriptor (use function's if available, else default)
143        let default_frame = FrameDescriptor::default();
144        let frame_descriptor = function.frame_descriptor.as_ref().unwrap_or(&default_frame);
145
146        // Compile the loop
147        match osr_compiler::compile_osr_loop(
148            &mut self.jit,
149            function,
150            func_instructions,
151            loop_info,
152            frame_descriptor,
153        ) {
154            Ok(osr_result) => {
155                // Adjust entry point bytecode_ip back to global coordinates
156                let mut entry_point = osr_result.entry_point;
157                entry_point.bytecode_ip += entry;
158                entry_point.exit_ip += entry;
159
160                CompilationResult {
161                    function_id: func_id,
162                    compiled_tier: Tier::BaselineJit,
163                    native_code: Some(osr_result.native_code),
164                    error: None,
165                    osr_entry: Some(entry_point),
166                    deopt_points: osr_result.deopt_points,
167                    loop_header_ip,
168                    shape_guards: Vec::new(),
169                }
170            }
171            Err(e) => CompilationResult {
172                function_id: func_id,
173                compiled_tier: Tier::Interpreted,
174                native_code: None,
175                error: Some(e),
176                osr_entry: None,
177                deopt_points: Vec::new(),
178                loop_header_ip,
179                shape_guards: Vec::new(),
180            },
181        }
182    }
183}
184
185// SAFETY: JitCompilationBackend is used exclusively on its own worker thread.
186// The raw pointers in JITCompiler (compiled_functions, function_table) point
187// to JIT code that is immutable after compilation and valid for the module's
188// lifetime. Access is single-threaded (the worker thread).
189unsafe impl Send for JitCompilationBackend {}
190
191impl JitCompilationBackend {
192    /// Compile a whole function for Tier 1/2 promotion.
193    ///
194    /// Tier 1 (BaselineJit, no feedback): uses `compile_single_function` with
195    /// empty user_funcs — cross-function calls deopt to interpreter.
196    ///
197    /// Tier 2 (OptimizingJit, with feedback): uses `compile_optimizing_function`
198    /// which enables speculative calls based on monomorphic call feedback.
199    /// Self-recursive calls get direct-call FuncRefs. Cross-function monomorphic
200    /// calls get callee identity guard + FFI fallthrough (guard deopt on mismatch).
201    fn compile_function(
202        &mut self,
203        request: &CompilationRequest,
204        program: &BytecodeProgram,
205    ) -> CompilationResult {
206        let func_id = request.function_id;
207
208        // Tier 2: feedback-guided optimizing compilation with populated user_funcs
209        if let Some(fv) = request.feedback.clone() {
210            return match self.jit.compile_optimizing_function(
211                program,
212                func_id as usize,
213                fv,
214                &request.callee_feedback,
215            ) {
216                Ok((code_ptr, deopt_points, shape_guards)) => CompilationResult {
217                    function_id: func_id,
218                    compiled_tier: request.target_tier,
219                    native_code: Some(code_ptr),
220                    error: None,
221                    osr_entry: None,
222                    deopt_points,
223                    loop_header_ip: None,
224                    shape_guards,
225                },
226                Err(e) => CompilationResult {
227                    function_id: func_id,
228                    compiled_tier: Tier::Interpreted,
229                    native_code: None,
230                    error: Some(e),
231                    osr_entry: None,
232                    deopt_points: Vec::new(),
233                    loop_header_ip: None,
234                    shape_guards: Vec::new(),
235                },
236            };
237        }
238
239        // Tier 1: baseline compilation without cross-function speculation
240        match self
241            .jit
242            .compile_single_function(program, func_id as usize, None)
243        {
244            Ok((code_ptr, deopt_points, shape_guards)) => CompilationResult {
245                function_id: func_id,
246                compiled_tier: request.target_tier,
247                native_code: Some(code_ptr),
248                error: None,
249                osr_entry: None,
250                deopt_points,
251                loop_header_ip: None,
252                shape_guards,
253            },
254            Err(e) => CompilationResult {
255                function_id: func_id,
256                compiled_tier: Tier::Interpreted,
257                native_code: None,
258                error: Some(e),
259                osr_entry: None,
260                deopt_points: Vec::new(),
261                loop_header_ip: None,
262                shape_guards: Vec::new(),
263            },
264        }
265    }
266}
267
268impl CompilationBackend for JitCompilationBackend {
269    fn compile(
270        &mut self,
271        request: &CompilationRequest,
272        program: &BytecodeProgram,
273    ) -> CompilationResult {
274        if request.osr {
275            self.compile_osr(request, program)
276        } else {
277            self.compile_function(request, program)
278        }
279    }
280}
281
282/// Find the end of a function's instruction range.
283///
284/// For the last function, this is the end of the instruction stream.
285/// For other functions, this is the entry point of the next function.
286fn find_function_end(program: &BytecodeProgram, func_index: usize) -> usize {
287    let func = &program.functions[func_index];
288    func.entry_point + func.body_length
289}
290
291/// Build a minimal sub-program containing only the instructions in [start, end).
292///
293/// The sub-program's instructions are indexed from 0, making it compatible
294/// with `analyze_loops()` which expects a contiguous instruction stream.
295fn build_sub_program(program: &BytecodeProgram, start: usize, end: usize) -> BytecodeProgram {
296    BytecodeProgram {
297        instructions: program.instructions[start..end].to_vec(),
298        constants: program.constants.clone(),
299        strings: program.strings.clone(),
300        functions: vec![],
301        debug_info: Default::default(),
302        data_schema: None,
303        module_binding_names: vec![],
304        top_level_locals_count: 0,
305        top_level_local_storage_hints: vec![],
306        type_schema_registry: Default::default(),
307        module_binding_storage_hints: vec![],
308        function_local_storage_hints: vec![],
309        compiled_annotations: Default::default(),
310        trait_method_symbols: Default::default(),
311        expanded_function_defs: Default::default(),
312        string_index: Default::default(),
313        foreign_functions: Vec::new(),
314        native_struct_layouts: vec![],
315        content_addressed: None,
316        function_blob_hashes: vec![],
317        top_level_frame: None,
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use shape_vm::bytecode::*;
    use shape_vm::type_tracking::{FrameDescriptor, SlotKind};
    use std::collections::HashMap;

    fn make_instr(opcode: OpCode, operand: Option<Operand>) -> Instruction {
        Instruction { opcode, operand }
    }

    /// Non-closure, non-async, zero-arity function spanning
    /// `[entry_point, entry_point + body_length)` with one local per frame slot.
    fn make_function(
        name: &str,
        entry_point: usize,
        body_length: usize,
        slots: Vec<SlotKind>,
    ) -> Function {
        let locals_count = slots.len();
        Function {
            name: name.to_string(),
            arity: 0,
            param_names: vec![],
            locals_count: locals_count as _,
            entry_point,
            body_length,
            is_closure: false,
            captures_count: 0,
            is_async: false,
            ref_params: vec![],
            ref_mutates: vec![],
            mutable_captures: vec![],
            frame_descriptor: Some(FrameDescriptor::from_slots(slots)),
            osr_entry_points: vec![],
        }
    }

    /// Program with the given code, constants and functions; every other field
    /// takes its `Default` value.
    fn make_program(
        instructions: Vec<Instruction>,
        constants: Vec<Constant>,
        functions: Vec<Function>,
    ) -> BytecodeProgram {
        BytecodeProgram {
            instructions,
            constants,
            functions,
            ..Default::default()
        }
    }

    /// Request for function 0 at `Tier::BaselineJit` without feedback.
    fn make_request(osr: bool, loop_header_ip: Option<usize>) -> CompilationRequest {
        CompilationRequest {
            function_id: 0,
            target_tier: Tier::BaselineJit,
            blob_hash: None,
            osr,
            loop_header_ip,
            feedback: None,
            callee_feedback: HashMap::new(),
        }
    }

    /// `for (i = 0; i < n; i++) { sum += i }` with i, n, sum in locals 0, 1, 2.
    /// Function-local IPs run from 0 (`LoopStart`) to 14 (`ReturnValue`);
    /// constant 0 must be `Constant::Int(1)`.
    fn counting_loop() -> Vec<Instruction> {
        vec![
            make_instr(OpCode::LoopStart, None),                       // 0
            make_instr(OpCode::LoadLocal, Some(Operand::Local(0))),    // 1: i
            make_instr(OpCode::LoadLocal, Some(Operand::Local(1))),    // 2: n
            make_instr(OpCode::LtInt, None),                           // 3
            make_instr(OpCode::JumpIfFalse, Some(Operand::Offset(7))), // 4
            make_instr(OpCode::LoadLocal, Some(Operand::Local(2))),    // 5: sum
            make_instr(OpCode::LoadLocal, Some(Operand::Local(0))),    // 6: i
            make_instr(OpCode::AddInt, None),                          // 7
            make_instr(OpCode::StoreLocal, Some(Operand::Local(2))),   // 8
            make_instr(OpCode::LoadLocal, Some(Operand::Local(0))),    // 9: i
            make_instr(OpCode::PushConst, Some(Operand::Const(0))),    // 10: 1
            make_instr(OpCode::AddInt, None),                          // 11
            make_instr(OpCode::StoreLocal, Some(Operand::Local(0))),   // 12
            make_instr(OpCode::LoopEnd, None),                         // 13
            make_instr(OpCode::ReturnValue, None),                     // 14
        ]
    }

    fn counting_loop_slots() -> Vec<SlotKind> {
        vec![
            SlotKind::Int64, // i
            SlotKind::Int64, // n
            SlotKind::Int64, // sum
        ]
    }

    /// The counting loop as function 0 starting at global IP 0.
    fn counting_loop_program() -> BytecodeProgram {
        let func = make_function("test_loop", 0, 15, counting_loop_slots());
        make_program(counting_loop(), vec![Constant::Int(1)], vec![func])
    }

    /// The counting loop as function 0 starting at global IP 1, behind a
    /// one-instruction prefix, so global and local IPs differ by one.
    fn offset_counting_loop_program() -> BytecodeProgram {
        let mut instrs = vec![make_instr(OpCode::Halt, None)];
        instrs.extend(counting_loop());
        let func = make_function("test_loop", 1, 15, counting_loop_slots());
        make_program(instrs, vec![Constant::Int(1)], vec![func])
    }

    #[test]
    fn test_backend_compiles_whole_function() {
        let mut backend = JitCompilationBackend::new().unwrap();

        // Function body at entry_point=0: return local 0 + local 1.
        let instrs = vec![
            make_instr(OpCode::LoadLocal, Some(Operand::Local(0))), // 0
            make_instr(OpCode::LoadLocal, Some(Operand::Local(1))), // 1
            make_instr(OpCode::AddInt, None),                       // 2
            make_instr(OpCode::ReturnValue, None),                  // 3
            // Main code (trampoline target)
            make_instr(OpCode::Halt, None), // 4
        ];
        let func = Function {
            arity: 2,
            ..make_function("add_two", 0, 4, vec![SlotKind::Int64, SlotKind::Int64])
        };
        let program = make_program(instrs, vec![], vec![func]);

        let result = backend.compile(&make_request(false, None), &program);
        assert!(
            result.error.is_none(),
            "Expected successful whole-function compilation, got: {:?}",
            result.error
        );
        assert!(result.native_code.is_some());
        assert_eq!(result.compiled_tier, Tier::BaselineJit);
        assert!(result.osr_entry.is_none()); // Not an OSR result
    }

    #[test]
    fn test_backend_whole_function_invalid_id() {
        let mut backend = JitCompilationBackend::new().unwrap();
        let program = make_program(vec![make_instr(OpCode::Halt, None)], vec![], vec![]);
        let request = CompilationRequest {
            function_id: 99,
            ..make_request(false, None)
        };

        let result = backend.compile(&request, &program);
        assert!(result.error.is_some());
        assert!(result.error.unwrap().contains("not found"));
    }

    #[test]
    fn test_backend_osr_compiles_simple_loop() {
        let mut backend = JitCompilationBackend::new().unwrap();
        let program = counting_loop_program();

        // Global IP 0 is the LoopStart.
        let result = backend.compile(&make_request(true, Some(0)), &program);
        assert!(
            result.error.is_none(),
            "Expected successful compilation, got: {:?}",
            result.error
        );
        assert!(result.native_code.is_some());
        assert!(result.osr_entry.is_some());
        assert_eq!(result.compiled_tier, Tier::BaselineJit);

        let entry = result.osr_entry.unwrap();
        assert_eq!(entry.bytecode_ip, 0);
        assert!(entry.live_locals.contains(&0)); // i
        assert!(entry.live_locals.contains(&1)); // n
        assert!(entry.live_locals.contains(&2)); // sum
    }

    #[test]
    fn test_backend_osr_rebases_entry_ip_to_global() {
        let mut backend = JitCompilationBackend::new().unwrap();
        let program = offset_counting_loop_program();

        // Global IP 1 is the LoopStart (function-local IP 0).
        let result = backend.compile(&make_request(true, Some(1)), &program);
        assert!(
            result.error.is_none(),
            "Expected successful compilation, got: {:?}",
            result.error
        );
        let entry = result.osr_entry.unwrap();
        assert_eq!(entry.bytecode_ip, 1);
    }

    #[test]
    fn test_backend_osr_rejects_missing_loop_header() {
        let mut backend = JitCompilationBackend::new().unwrap();
        let program = counting_loop_program();

        let result = backend.compile(&make_request(true, None), &program);
        assert!(result.native_code.is_none());
        assert_eq!(result.compiled_tier, Tier::Interpreted);
        assert_eq!(result.loop_header_ip, None);
        assert!(result.error.unwrap().contains("without loop_header_ip"));
    }

    #[test]
    fn test_backend_osr_rejects_header_before_function_entry() {
        let mut backend = JitCompilationBackend::new().unwrap();
        let program = offset_counting_loop_program();

        let result = backend.compile(&make_request(true, Some(0)), &program);
        assert!(result.native_code.is_none());
        assert_eq!(result.compiled_tier, Tier::Interpreted);
        assert_eq!(result.loop_header_ip, Some(0)); // For blacklisting
        assert!(result.error.unwrap().contains("before function entry"));
    }

    #[test]
    fn test_backend_osr_blacklists_unsupported_loop() {
        let mut backend = JitCompilationBackend::new().unwrap();

        // Function with a loop containing CallMethod (unsupported)
        let instrs = vec![
            make_instr(OpCode::LoopStart, None),
            make_instr(OpCode::LoadLocal, Some(Operand::Local(0))),
            make_instr(OpCode::CallMethod, None), // Unsupported!
            make_instr(OpCode::Pop, None),
            make_instr(OpCode::LoopEnd, None),
            make_instr(OpCode::Halt, None),
        ];
        let func = make_function("unsupported_loop", 0, 6, vec![SlotKind::Unknown]);
        let program = make_program(instrs, vec![], vec![func]);

        let result = backend.compile(&make_request(true, Some(0)), &program);
        assert!(result.error.is_some());
        assert!(result.error.unwrap().contains("unsupported opcode"));
        assert_eq!(result.loop_header_ip, Some(0)); // For blacklisting
    }
}