Skip to main content

shape_jit/
worker.rs

1//! JIT Compilation Backend
2//!
3//! Implements the `CompilationBackend` trait from `shape-vm` so the TierManager
4//! can drive JIT compilation on a background worker thread.
5
6use shape_vm::bytecode::BytecodeProgram;
7use shape_vm::tier::{CompilationBackend, CompilationRequest, CompilationResult, Tier};
8use shape_vm::type_tracking::FrameDescriptor;
9
10use crate::compiler::JITCompiler;
11use crate::context::JITConfig;
12use crate::translator::loop_analysis;
13use crate::translator::osr_compiler;
14
15/// JIT compilation backend that compiles hot loops to native code via Cranelift.
16///
17/// Owns a `JITCompiler` instance and implements the `CompilationBackend` trait.
18/// The `TierManager::set_backend()` spawns a worker thread that drives this.
19pub struct JitCompilationBackend {
20    jit: JITCompiler,
21}
22
23impl JitCompilationBackend {
24    /// Create a new JIT compilation backend with default configuration.
25    pub fn new() -> Result<Self, crate::error::JitError> {
26        Ok(Self {
27            jit: JITCompiler::new(JITConfig::default())?,
28        })
29    }
30
31    /// Create a new JIT compilation backend with custom configuration.
32    pub fn with_config(config: JITConfig) -> Result<Self, crate::error::JitError> {
33        Ok(Self {
34            jit: JITCompiler::new(config)?,
35        })
36    }
37
38    /// Compile an OSR loop from a compilation request.
39    fn compile_osr(
40        &mut self,
41        request: &CompilationRequest,
42        program: &BytecodeProgram,
43    ) -> CompilationResult {
44        let func_id = request.function_id;
45        let loop_header_ip = request.loop_header_ip;
46
47        // Get the target function
48        let function = match program.functions.get(func_id as usize) {
49            Some(f) => f,
50            None => {
51                return CompilationResult {
52                    function_id: func_id,
53                    compiled_tier: Tier::Interpreted,
54                    native_code: None,
55                    error: Some(format!("Function {} not found in program", func_id)),
56                    osr_entry: None,
57                    deopt_points: Vec::new(),
58                    loop_header_ip,
59                    shape_guards: Vec::new(),
60                };
61            }
62        };
63
64        // Extract the function's instruction range
65        let entry = function.entry_point;
66        let end = find_function_end(program, func_id as usize);
67        if entry >= program.instructions.len() || end > program.instructions.len() {
68            return CompilationResult {
69                function_id: func_id,
70                compiled_tier: Tier::Interpreted,
71                native_code: None,
72                error: Some(format!(
73                    "Function {} instruction range [{}, {}) out of bounds",
74                    func_id, entry, end
75                )),
76                osr_entry: None,
77                deopt_points: Vec::new(),
78                loop_header_ip,
79                shape_guards: Vec::new(),
80            };
81        }
82        let func_instructions = &program.instructions[entry..end];
83
84        // Run loop analysis on a sub-program containing just this function's instructions
85        let sub_program = build_sub_program(program, entry, end);
86        let loop_infos = loop_analysis::analyze_loops(&sub_program);
87
88        // Find the target loop. The loop_header_ip from the request is in
89        // global instruction coordinates; convert to function-local offset.
90        let target_local_ip = match loop_header_ip {
91            Some(ip) => {
92                if ip < entry {
93                    return CompilationResult {
94                        function_id: func_id,
95                        compiled_tier: Tier::Interpreted,
96                        native_code: None,
97                        error: Some(format!(
98                            "OSR loop header IP {} is before function entry {}",
99                            ip, entry
100                        )),
101                        osr_entry: None,
102                        deopt_points: Vec::new(),
103                        loop_header_ip: Some(ip),
104                        shape_guards: Vec::new(),
105                    };
106                }
107                ip - entry
108            }
109            None => {
110                return CompilationResult {
111                    function_id: func_id,
112                    compiled_tier: Tier::Interpreted,
113                    native_code: None,
114                    error: Some("OSR request without loop_header_ip".to_string()),
115                    osr_entry: None,
116                    deopt_points: Vec::new(),
117                    loop_header_ip: None,
118                    shape_guards: Vec::new(),
119                };
120            }
121        };
122
123        let loop_info = match loop_infos.get(&target_local_ip) {
124            Some(li) => li,
125            None => {
126                return CompilationResult {
127                    function_id: func_id,
128                    compiled_tier: Tier::Interpreted,
129                    native_code: None,
130                    error: Some(format!(
131                        "No loop found at local IP {} (global IP {:?})",
132                        target_local_ip, loop_header_ip
133                    )),
134                    osr_entry: None,
135                    deopt_points: Vec::new(),
136                    loop_header_ip,
137                    shape_guards: Vec::new(),
138                };
139            }
140        };
141
142        // Build frame descriptor (use function's if available, else default)
143        let default_frame = FrameDescriptor::default();
144        let frame_descriptor = function.frame_descriptor.as_ref().unwrap_or(&default_frame);
145
146        // Compile the loop
147        match osr_compiler::compile_osr_loop(
148            &mut self.jit,
149            function,
150            func_instructions,
151            loop_info,
152            frame_descriptor,
153        ) {
154            Ok(osr_result) => {
155                // Adjust entry point bytecode_ip back to global coordinates
156                let mut entry_point = osr_result.entry_point;
157                entry_point.bytecode_ip += entry;
158                entry_point.exit_ip += entry;
159
160                CompilationResult {
161                    function_id: func_id,
162                    compiled_tier: Tier::BaselineJit,
163                    native_code: Some(osr_result.native_code),
164                    error: None,
165                    osr_entry: Some(entry_point),
166                    deopt_points: osr_result.deopt_points,
167                    loop_header_ip,
168                    shape_guards: Vec::new(),
169                }
170            }
171            Err(e) => CompilationResult {
172                function_id: func_id,
173                compiled_tier: Tier::Interpreted,
174                native_code: None,
175                error: Some(e),
176                osr_entry: None,
177                deopt_points: Vec::new(),
178                loop_header_ip,
179                shape_guards: Vec::new(),
180            },
181        }
182    }
183}
184
185// SAFETY: JitCompilationBackend is used exclusively on its own worker thread.
186// The raw pointers in JITCompiler (compiled_functions, function_table) point
187// to JIT code that is immutable after compilation and valid for the module's
188// lifetime. Access is single-threaded (the worker thread).
189unsafe impl Send for JitCompilationBackend {}
190
191impl JitCompilationBackend {
192    /// Compile a whole function for Tier 1/2 promotion.
193    ///
194    /// Tier 1 (BaselineJit, no feedback): uses `compile_single_function` with
195    /// empty user_funcs — cross-function calls deopt to interpreter.
196    ///
197    /// Tier 2 (OptimizingJit, with feedback): uses `compile_optimizing_function`
198    /// which enables speculative calls based on monomorphic call feedback.
199    /// Self-recursive calls get direct-call FuncRefs. Cross-function monomorphic
200    /// calls get callee identity guard + FFI fallthrough (guard deopt on mismatch).
201    fn compile_function(
202        &mut self,
203        request: &CompilationRequest,
204        program: &BytecodeProgram,
205    ) -> CompilationResult {
206        let func_id = request.function_id;
207
208        // Tier 2: feedback-guided optimizing compilation with populated user_funcs
209        if let Some(fv) = request.feedback.clone() {
210            return match self
211                .jit
212                .compile_optimizing_function(program, func_id as usize, fv, &request.callee_feedback)
213            {
214                Ok((code_ptr, deopt_points, shape_guards)) => CompilationResult {
215                    function_id: func_id,
216                    compiled_tier: request.target_tier,
217                    native_code: Some(code_ptr),
218                    error: None,
219                    osr_entry: None,
220                    deopt_points,
221                    loop_header_ip: None,
222                    shape_guards,
223                },
224                Err(e) => CompilationResult {
225                    function_id: func_id,
226                    compiled_tier: Tier::Interpreted,
227                    native_code: None,
228                    error: Some(e),
229                    osr_entry: None,
230                    deopt_points: Vec::new(),
231                    loop_header_ip: None,
232                    shape_guards: Vec::new(),
233                },
234            };
235        }
236
237        // Tier 1: baseline compilation without cross-function speculation
238        match self
239            .jit
240            .compile_single_function(program, func_id as usize, None)
241        {
242            Ok((code_ptr, deopt_points, shape_guards)) => CompilationResult {
243                function_id: func_id,
244                compiled_tier: request.target_tier,
245                native_code: Some(code_ptr),
246                error: None,
247                osr_entry: None,
248                deopt_points,
249                loop_header_ip: None,
250                shape_guards,
251            },
252            Err(e) => CompilationResult {
253                function_id: func_id,
254                compiled_tier: Tier::Interpreted,
255                native_code: None,
256                error: Some(e),
257                osr_entry: None,
258                deopt_points: Vec::new(),
259                loop_header_ip: None,
260                shape_guards: Vec::new(),
261            },
262        }
263    }
264}
265
266impl CompilationBackend for JitCompilationBackend {
267    fn compile(
268        &mut self,
269        request: &CompilationRequest,
270        program: &BytecodeProgram,
271    ) -> CompilationResult {
272        if request.osr {
273            self.compile_osr(request, program)
274        } else {
275            self.compile_function(request, program)
276        }
277    }
278}
279
280/// Find the end of a function's instruction range.
281///
282/// For the last function, this is the end of the instruction stream.
283/// For other functions, this is the entry point of the next function.
284fn find_function_end(program: &BytecodeProgram, func_index: usize) -> usize {
285    let func = &program.functions[func_index];
286    func.entry_point + func.body_length
287}
288
289/// Build a minimal sub-program containing only the instructions in [start, end).
290///
291/// The sub-program's instructions are indexed from 0, making it compatible
292/// with `analyze_loops()` which expects a contiguous instruction stream.
293fn build_sub_program(program: &BytecodeProgram, start: usize, end: usize) -> BytecodeProgram {
294    BytecodeProgram {
295        instructions: program.instructions[start..end].to_vec(),
296        constants: program.constants.clone(),
297        strings: program.strings.clone(),
298        functions: vec![],
299        debug_info: Default::default(),
300        data_schema: None,
301        module_binding_names: vec![],
302        top_level_locals_count: 0,
303        top_level_local_storage_hints: vec![],
304        type_schema_registry: Default::default(),
305        module_binding_storage_hints: vec![],
306        function_local_storage_hints: vec![],
307        compiled_annotations: Default::default(),
308        trait_method_symbols: Default::default(),
309        expanded_function_defs: Default::default(),
310        string_index: Default::default(),
311        foreign_functions: Vec::new(),
312        native_struct_layouts: vec![],
313        content_addressed: None,
314        function_blob_hashes: vec![],
315        top_level_frame: None,
316    }
317}
318
319#[cfg(test)]
320mod tests {
321    use super::*;
322    use shape_vm::bytecode::*;
323    use shape_vm::type_tracking::{FrameDescriptor, SlotKind};
324
325    fn make_instr(opcode: OpCode, operand: Option<Operand>) -> Instruction {
326        Instruction { opcode, operand }
327    }
328
329    #[test]
330    fn test_backend_compiles_whole_function() {
331        let mut backend = JitCompilationBackend::new().unwrap();
332
333        // Simple function: return local 0 + local 1
334        let instrs = vec![
335            // Function body at entry_point=0
336            make_instr(OpCode::LoadLocal, Some(Operand::Local(0))), // 0
337            make_instr(OpCode::LoadLocal, Some(Operand::Local(1))), // 1
338            make_instr(OpCode::AddInt, None),                       // 2
339            make_instr(OpCode::ReturnValue, None),                  // 3
340            // Main code (trampoline target)
341            make_instr(OpCode::Halt, None), // 4
342        ];
343
344        let func = Function {
345            name: "add_two".to_string(),
346            arity: 2,
347            param_names: vec![],
348            locals_count: 2,
349            entry_point: 0,
350            body_length: 4,
351            is_closure: false,
352            captures_count: 0,
353            is_async: false,
354            ref_params: vec![],
355            ref_mutates: vec![],
356            mutable_captures: vec![],
357            frame_descriptor: Some(FrameDescriptor::from_slots(vec![
358                SlotKind::Int64, // arg0
359                SlotKind::Int64, // arg1
360            ])),
361            osr_entry_points: vec![],
362        };
363
364        let program = BytecodeProgram {
365            instructions: instrs,
366            constants: vec![],
367            strings: vec![],
368            functions: vec![func],
369            debug_info: Default::default(),
370            data_schema: None,
371            module_binding_names: vec![],
372            top_level_locals_count: 0,
373            top_level_local_storage_hints: vec![],
374            type_schema_registry: Default::default(),
375            module_binding_storage_hints: vec![],
376            function_local_storage_hints: vec![],
377            compiled_annotations: Default::default(),
378            trait_method_symbols: Default::default(),
379            expanded_function_defs: Default::default(),
380            string_index: Default::default(),
381            foreign_functions: Vec::new(),
382            native_struct_layouts: vec![],
383            content_addressed: None,
384            function_blob_hashes: vec![],
385            top_level_frame: None,
386            ..Default::default()
387        };
388
389        let request = CompilationRequest {
390            function_id: 0,
391            target_tier: Tier::BaselineJit,
392            blob_hash: None,
393            osr: false,
394            loop_header_ip: None,
395            feedback: None,
396            callee_feedback: std::collections::HashMap::new(),
397        };
398
399        let result = backend.compile(&request, &program);
400        assert!(
401            result.error.is_none(),
402            "Expected successful whole-function compilation, got: {:?}",
403            result.error
404        );
405        assert!(result.native_code.is_some());
406        assert_eq!(result.compiled_tier, Tier::BaselineJit);
407        assert!(result.osr_entry.is_none()); // Not an OSR result
408    }
409
410    #[test]
411    fn test_backend_whole_function_invalid_id() {
412        let mut backend = JitCompilationBackend::new().unwrap();
413        let program = BytecodeProgram {
414            instructions: vec![make_instr(OpCode::Halt, None)],
415            constants: vec![],
416            strings: vec![],
417            functions: vec![], // No functions
418            debug_info: Default::default(),
419            data_schema: None,
420            module_binding_names: vec![],
421            top_level_locals_count: 0,
422            top_level_local_storage_hints: vec![],
423            type_schema_registry: Default::default(),
424            module_binding_storage_hints: vec![],
425            function_local_storage_hints: vec![],
426            compiled_annotations: Default::default(),
427            trait_method_symbols: Default::default(),
428            expanded_function_defs: Default::default(),
429            string_index: Default::default(),
430            foreign_functions: Vec::new(),
431            native_struct_layouts: vec![],
432            content_addressed: None,
433            function_blob_hashes: vec![],
434            top_level_frame: None,
435            ..Default::default()
436        };
437        let request = CompilationRequest {
438            function_id: 99,
439            target_tier: Tier::BaselineJit,
440            blob_hash: None,
441            osr: false,
442            loop_header_ip: None,
443            feedback: None,
444            callee_feedback: std::collections::HashMap::new(),
445        };
446        let result = backend.compile(&request, &program);
447        assert!(result.error.is_some());
448        assert!(result.error.unwrap().contains("not found"));
449    }
450
451    #[test]
452    fn test_backend_osr_compiles_simple_loop() {
453        let mut backend = JitCompilationBackend::new().unwrap();
454
455        // Function at entry_point=0: for (i=0; i<n; i++) { sum += i }
456        let instrs = vec![
457            make_instr(OpCode::LoopStart, None),                       // 0
458            make_instr(OpCode::LoadLocal, Some(Operand::Local(0))),    // 1: i
459            make_instr(OpCode::LoadLocal, Some(Operand::Local(1))),    // 2: n
460            make_instr(OpCode::LtInt, None),                           // 3
461            make_instr(OpCode::JumpIfFalse, Some(Operand::Offset(7))), // 4
462            make_instr(OpCode::LoadLocal, Some(Operand::Local(2))),    // 5: sum
463            make_instr(OpCode::LoadLocal, Some(Operand::Local(0))),    // 6: i
464            make_instr(OpCode::AddInt, None),                          // 7
465            make_instr(OpCode::StoreLocal, Some(Operand::Local(2))),   // 8
466            make_instr(OpCode::LoadLocal, Some(Operand::Local(0))),    // 9: i
467            make_instr(OpCode::PushConst, Some(Operand::Const(0))),    // 10: 1
468            make_instr(OpCode::AddInt, None),                          // 11
469            make_instr(OpCode::StoreLocal, Some(Operand::Local(0))),   // 12
470            make_instr(OpCode::LoopEnd, None),                         // 13
471            make_instr(OpCode::ReturnValue, None),                     // 14
472        ];
473
474        let func = Function {
475            name: "test_loop".to_string(),
476            arity: 0,
477            param_names: vec![],
478            locals_count: 3,
479            entry_point: 0,
480            body_length: 15,
481            is_closure: false,
482            captures_count: 0,
483            is_async: false,
484            ref_params: vec![],
485            ref_mutates: vec![],
486            mutable_captures: vec![],
487            frame_descriptor: Some(FrameDescriptor::from_slots(vec![
488                SlotKind::Int64, // i
489                SlotKind::Int64, // n
490                SlotKind::Int64, // sum
491            ])),
492            osr_entry_points: vec![],
493        };
494
495        let program = BytecodeProgram {
496            instructions: instrs,
497            constants: vec![Constant::Int(1)],
498            strings: vec![],
499            functions: vec![func],
500            debug_info: Default::default(),
501            data_schema: None,
502            module_binding_names: vec![],
503            top_level_locals_count: 0,
504            top_level_local_storage_hints: vec![],
505            type_schema_registry: Default::default(),
506            module_binding_storage_hints: vec![],
507            function_local_storage_hints: vec![],
508            compiled_annotations: Default::default(),
509            trait_method_symbols: Default::default(),
510            expanded_function_defs: Default::default(),
511            string_index: Default::default(),
512            foreign_functions: Vec::new(),
513            native_struct_layouts: vec![],
514            content_addressed: None,
515            function_blob_hashes: vec![],
516            top_level_frame: None,
517            ..Default::default()
518        };
519
520        let request = CompilationRequest {
521            function_id: 0,
522            target_tier: Tier::BaselineJit,
523            blob_hash: None,
524            osr: true,
525            loop_header_ip: Some(0), // Global IP of LoopStart
526            feedback: None,
527            callee_feedback: std::collections::HashMap::new(),
528        };
529
530        let result = backend.compile(&request, &program);
531        assert!(
532            result.error.is_none(),
533            "Expected successful compilation, got: {:?}",
534            result.error
535        );
536        assert!(result.native_code.is_some());
537        assert!(result.osr_entry.is_some());
538        assert_eq!(result.compiled_tier, Tier::BaselineJit);
539
540        let entry = result.osr_entry.unwrap();
541        assert_eq!(entry.bytecode_ip, 0);
542        assert!(entry.live_locals.contains(&0)); // i
543        assert!(entry.live_locals.contains(&1)); // n
544        assert!(entry.live_locals.contains(&2)); // sum
545    }
546
547    #[test]
548    fn test_backend_osr_blacklists_unsupported_loop() {
549        let mut backend = JitCompilationBackend::new().unwrap();
550
551        // Function with a loop containing CallMethod (unsupported)
552        let instrs = vec![
553            make_instr(OpCode::LoopStart, None),
554            make_instr(OpCode::LoadLocal, Some(Operand::Local(0))),
555            make_instr(OpCode::CallMethod, None), // Unsupported!
556            make_instr(OpCode::Pop, None),
557            make_instr(OpCode::LoopEnd, None),
558            make_instr(OpCode::Halt, None),
559        ];
560
561        let func = Function {
562            name: "unsupported_loop".to_string(),
563            arity: 0,
564            param_names: vec![],
565            locals_count: 1,
566            entry_point: 0,
567            body_length: 6,
568            is_closure: false,
569            captures_count: 0,
570            is_async: false,
571            ref_params: vec![],
572            ref_mutates: vec![],
573            mutable_captures: vec![],
574            frame_descriptor: Some(FrameDescriptor::from_slots(vec![SlotKind::Unknown])),
575            osr_entry_points: vec![],
576        };
577
578        let program = BytecodeProgram {
579            instructions: instrs,
580            constants: vec![],
581            strings: vec![],
582            functions: vec![func],
583            debug_info: Default::default(),
584            data_schema: None,
585            module_binding_names: vec![],
586            top_level_locals_count: 0,
587            top_level_local_storage_hints: vec![],
588            type_schema_registry: Default::default(),
589            module_binding_storage_hints: vec![],
590            function_local_storage_hints: vec![],
591            compiled_annotations: Default::default(),
592            trait_method_symbols: Default::default(),
593            expanded_function_defs: Default::default(),
594            string_index: Default::default(),
595            foreign_functions: Vec::new(),
596            native_struct_layouts: vec![],
597            content_addressed: None,
598            function_blob_hashes: vec![],
599            top_level_frame: None,
600            ..Default::default()
601        };
602
603        let request = CompilationRequest {
604            function_id: 0,
605            target_tier: Tier::BaselineJit,
606            blob_hash: None,
607            osr: true,
608            loop_header_ip: Some(0),
609            feedback: None,
610            callee_feedback: std::collections::HashMap::new(),
611        };
612
613        let result = backend.compile(&request, &program);
614        assert!(result.error.is_some());
615        assert!(result.error.unwrap().contains("unsupported opcode"));
616        assert_eq!(result.loop_header_ip, Some(0)); // For blacklisting
617    }
618}