//! JIT trace code generation builder (lust/jit/codegen/builder.rs).

1use super::*;
2use crate::VM;
3use hashbrown::HashMap;
4impl JitCompiler {
5    pub fn new() -> Self {
6        Self {
7            ops: Assembler::new().unwrap(),
8            leaked_constants: Vec::new(),
9            fail_stack: Vec::new(),
10            exit_stack: Vec::new(),
11            inline_depth: 0,
12            specialization_registry: SpecializationRegistry::new(),
13            specialized_values: HashMap::new(),
14            next_specialized_id: 0,
15        }
16    }
17
18    pub(super) fn current_fail_label(&self) -> dynasmrt::DynamicLabel {
19        *self
20            .fail_stack
21            .last()
22            .expect("JIT fail label stack is empty")
23    }
24
25    pub(super) fn current_exit_label(&self) -> dynasmrt::DynamicLabel {
26        *self
27            .exit_stack
28            .last()
29            .expect("JIT exit label stack is empty")
30    }
31
    /// Compile a recorded trace into executable machine code.
    ///
    /// Emitted layout:
    ///   prologue → hoisted constants → preamble → loop body (jmp back to
    ///   loop start) → exit (postamble, return 0) → fail/unwind stub.
    ///
    /// Register conventions established by the prologue and relied on by all
    /// `compile_*` helpers:
    ///   r12 = register-file base (first C argument, rdi)
    ///   r13 = second C argument, rsi (passed as `*mut VM` per the `entry`
    ///         signature below)
    ///   r15 = head of the inline-frame metadata chain (0 = no inline frame)
    ///
    /// `hoisted_constants` are materialized once before the preamble so the
    /// loop body never reloads them. The finalized executable buffer is
    /// intentionally leaked so the returned `entry` pointer stays valid for
    /// the lifetime of the process.
    pub fn compile_trace(
        &mut self,
        trace: &Trace,
        trace_id: TraceId,
        parent: Option<TraceId>,
        hoisted_constants: Vec<(u8, Value)>,
    ) -> Result<CompiledTrace> {
        let stack_size = Self::compute_stack_size(trace);
        let mut guards = Vec::new();
        let mut guard_index = 0i32;
        let exit_label = self.ops.new_dynamic_label();
        let fail_label = self.ops.new_dynamic_label();
        // Push label scopes so nested compilation (e.g. inline calls) can
        // find the enclosing exit/fail targets.
        self.exit_stack.push(exit_label);
        self.fail_stack.push(fail_label);
        crate::jit::log(|| format!("🔧 JIT: Emitting prologue with sub rsp, {}", stack_size));
        // Prologue: save callee-saved registers, carve out the frame, clear
        // the inline-frame chain, and capture the two C arguments.
        dynasm!(self.ops
            ; push rbp
            ; mov rbp, rsp
            ; push rbx
            ; push r12
            ; push r13
            ; push r14
            ; push r15
            ; sub rsp, stack_size
            ; xor r15, r15
            ; mov r12, rdi
            ; mov r13, rsi
        );
        // Loop-invariant constants are loaded once, before the preamble.
        for (dest, value) in &hoisted_constants {
            self.compile_load_const(*dest, value)?;
        }

        // Compile preamble (executed once at trace entry)
        jit::log(|| format!("🔧 JIT: Compiling preamble ({} ops)", trace.preamble.len()));
        self.compile_ops(&trace.preamble, &mut guard_index, &mut guards)?;

        // Create a loop_start label AFTER preamble, BEFORE loop body
        let loop_start_label = self.ops.new_dynamic_label();
        dynasm!(self.ops
            ; => loop_start_label
            ; loop_start:
        );

        // Compile main trace body (the loop)
        let compile_result = self.compile_ops(&trace.ops, &mut guard_index, &mut guards);
        compile_result?;

        // At end of loop body, jump back to loop_start to loop
        dynasm!(self.ops
            ; jmp => loop_start_label
        );

        let unwind_label = self.ops.new_dynamic_label();
        let fail_return_label = self.ops.new_dynamic_label();
        dynasm!(self.ops
            ; => exit_label
            ; exit:
        );

        // Compile postamble (executed once at trace exit)
        jit::log(|| {
            format!(
                "🔧 JIT: Compiling postamble ({} ops)",
                trace.postamble.len()
            )
        });
        self.compile_ops(&trace.postamble, &mut guard_index, &mut guards)?;

        // Now pop the label stacks after everything is compiled
        self.exit_stack.pop();
        self.fail_stack.pop();

        // Set return value to 0 (success) AFTER postamble to avoid clobbering
        dynasm!(self.ops
            ; xor eax, eax
        );

        // Epilogue, then the fail/unwind stub:
        //   fail: set eax = -1 and fall into the unwind loop.
        //   unwind: while r15 != 0, pop one inline metadata frame — reclaim
        //   the callee register area ([r15] bytes), restore the caller's r12
        //   and the previous r15, then skip the metadata itself.
        //
        // NOTE(review): compile_inline_call builds 32-byte metadata frames
        // with an alignment-adjust slot at [r15 + 24], but this loop advances
        // rsp by only 24 and never applies that slot. If this path can run
        // with live inline frames, rsp would end up misaligned — confirm
        // whether inline_fail always unwinds its own frame before reaching
        // this label, or whether this should mirror the inline cleanup.
        dynasm!(self.ops
            ; add rsp, stack_size
            ; pop r15
            ; pop r14
            ; pop r13
            ; pop r12
            ; pop rbx
            ; pop rbp
            ; ret
            ; => fail_label
            ; fail:
            ; mov eax, DWORD -1
            ; => unwind_label
            ; test r15, r15
            ; je => fail_return_label
            ; mov eax, DWORD [r15]
            ; mov rbx, rax
            ; add rsp, rbx
            ; mov r12, [r15 + 8]
            ; mov r15, [r15 + 16]
            ; add rsp, 24
            ; jmp => unwind_label
            ; => fail_return_label
            ; jmp => exit_label
        );
        // Swap in a fresh assembler and finalize the old one into an
        // executable buffer; the entry point is the start of that buffer.
        let ops = mem::replace(&mut self.ops, Assembler::new().unwrap());
        let exec_buffer = ops.finalize().unwrap();
        let entry_point = exec_buffer.ptr(dynasmrt::AssemblyOffset(0));
        let entry: extern "C" fn(*mut Value, *mut VM, *const Function) -> i32 =
            unsafe { mem::transmute(entry_point) };
        // Optional debugging aid: dump the raw machine code when
        // LUST_JIT_DUMP is set (std builds only).
        #[cfg(feature = "std")]
        {
            if std::env::var("LUST_JIT_DUMP").is_ok() {
                use std::{fs, path::PathBuf};
                let len = exec_buffer.len();
                // SAFETY-adjacent: reads exactly `len` bytes of the live
                // executable buffer that `entry_point` points into.
                let bytes = unsafe { std::slice::from_raw_parts(entry_point as *const u8, len) };
                let mut path = PathBuf::from("target");
                let _ = fs::create_dir_all(&path);
                path.push(format!(
                    "jit_trace_{}_{}.bin",
                    trace_id.0,
                    parent.map(|p| p.0).unwrap_or(trace.function_idx)
                ));
                if let Err(err) = fs::write(&path, bytes) {
                    crate::jit::log(|| {
                        format!("⚠️  JIT: failed to dump trace to {:?}: {}", path, err)
                    });
                } else {
                    crate::jit::log(|| format!("📝 JIT: Dumped trace bytes to {:?}", path));
                }
            }
        }
        // Deliberate leak: the buffer must outlive `entry`, which is handed
        // out as a raw function pointer.
        Box::leak(Box::new(exec_buffer));
        let leaked_constants = mem::take(&mut self.leaked_constants);
        Ok(CompiledTrace {
            id: trace_id,
            entry,
            trace: trace.clone(),
            guards,
            parent,
            side_traces: Vec::new(),
            leaked_constants,
            hoisted_constants,
        })
    }
174
    /// Lower a slice of `TraceOp`s into machine code.
    ///
    /// Most ops delegate to a dedicated `compile_*` helper. Guard ops
    /// additionally append a `Guard` record and advance `guard_index`, so a
    /// runtime bailout can be mapped back to its recorded metadata. This is
    /// called for the preamble, loop body, postamble, and inlined call
    /// bodies alike — label targets come from the fail/exit stacks.
    fn compile_ops(
        &mut self,
        ops: &[TraceOp],
        guard_index: &mut i32,
        guards: &mut Vec<Guard>,
    ) -> Result<()> {
        for op in ops {
            match op {
                // --- data movement ---
                TraceOp::LoadConst { dest, value } => {
                    self.compile_load_const(*dest, value)?;
                }

                TraceOp::Move { dest, src } => {
                    self.compile_move(*dest, *src)?;
                }

                // --- arithmetic, specialized on the recorded operand types ---
                TraceOp::Add {
                    dest,
                    lhs,
                    rhs,
                    lhs_type,
                    rhs_type,
                } => {
                    self.compile_add_specialized(*dest, *lhs, *rhs, *lhs_type, *rhs_type)?;
                }

                TraceOp::Sub {
                    dest,
                    lhs,
                    rhs,
                    lhs_type,
                    rhs_type,
                } => {
                    self.compile_sub_specialized(*dest, *lhs, *rhs, *lhs_type, *rhs_type)?;
                }

                TraceOp::Mul {
                    dest,
                    lhs,
                    rhs,
                    lhs_type,
                    rhs_type,
                } => {
                    self.compile_mul_specialized(*dest, *lhs, *rhs, *lhs_type, *rhs_type)?;
                }

                TraceOp::Div {
                    dest,
                    lhs,
                    rhs,
                    lhs_type,
                    rhs_type,
                } => {
                    self.compile_div_specialized(*dest, *lhs, *rhs, *lhs_type, *rhs_type)?;
                }

                TraceOp::Mod {
                    dest,
                    lhs,
                    rhs,
                    lhs_type,
                    rhs_type,
                } => {
                    self.compile_mod_specialized(*dest, *lhs, *rhs, *lhs_type, *rhs_type)?;
                }

                TraceOp::Neg { dest, src } => {
                    self.compile_neg(*dest, *src)?;
                }

                // --- comparisons and boolean logic ---
                TraceOp::Lt { dest, lhs, rhs } => {
                    self.compile_lt(*dest, *lhs, *rhs)?;
                }

                TraceOp::Le { dest, lhs, rhs } => {
                    self.compile_le(*dest, *lhs, *rhs)?;
                }

                TraceOp::Gt { dest, lhs, rhs } => {
                    self.compile_gt(*dest, *lhs, *rhs)?;
                }

                TraceOp::Ge { dest, lhs, rhs } => {
                    self.compile_ge(*dest, *lhs, *rhs)?;
                }

                TraceOp::Eq { dest, lhs, rhs } => {
                    self.compile_eq(*dest, *lhs, *rhs)?;
                }

                TraceOp::Ne { dest, lhs, rhs } => {
                    self.compile_ne(*dest, *lhs, *rhs)?;
                }

                TraceOp::And { dest, lhs, rhs } => {
                    self.compile_and(*dest, *lhs, *rhs)?;
                }

                TraceOp::Or { dest, lhs, rhs } => {
                    self.compile_or(*dest, *lhs, *rhs)?;
                }

                TraceOp::Not { dest, src } => {
                    self.compile_not(*dest, *src)?;
                }

                // --- strings and arrays ---
                TraceOp::Concat { dest, lhs, rhs } => {
                    self.compile_concat(*dest, *lhs, *rhs)?;
                }

                TraceOp::GetIndex { dest, array, index } => {
                    self.compile_get_index(*dest, *array, *index)?;
                }

                TraceOp::ArrayLen { dest, array } => {
                    self.compile_array_len(*dest, *array)?;
                }

                // --- callee guards: each records a Guard and bumps the index ---
                TraceOp::GuardNativeFunction { register, function } => {
                    let expected_ptr = function.pointer();
                    crate::jit::log(|| format!("🔒 JIT: guard native reg {}", register));
                    let guard = self.compile_guard_native_function(
                        *register,
                        expected_ptr,
                        *guard_index as usize,
                    )?;
                    guards.push(guard);
                    *guard_index += 1;
                }

                TraceOp::GuardFunction {
                    register,
                    function_idx,
                } => {
                    crate::jit::log(|| {
                        format!(
                            "🔒 JIT: guard function reg {} -> idx {}",
                            register, function_idx
                        )
                    });
                    let guard = self.compile_guard_function(
                        *register,
                        *function_idx,
                        *guard_index as usize,
                    )?;
                    guards.push(guard);
                    *guard_index += 1;
                }

                TraceOp::GuardClosure {
                    register,
                    function_idx,
                    upvalues_ptr,
                } => {
                    crate::jit::log(|| {
                        format!(
                            "🔒 JIT: guard closure reg {} -> idx {}",
                            register, function_idx
                        )
                    });
                    let guard = self.compile_guard_closure(
                        *register,
                        *function_idx,
                        *upvalues_ptr,
                        *guard_index as usize,
                    )?;
                    guards.push(guard);
                    *guard_index += 1;
                }

                // --- calls ---
                TraceOp::CallNative {
                    dest,
                    callee,
                    function,
                    first_arg,
                    arg_count,
                } => {
                    let expected_ptr = function.pointer();
                    self.compile_call_native(*dest, *callee, expected_ptr, *first_arg, *arg_count)?;
                }

                TraceOp::CallFunction {
                    dest,
                    callee,
                    function_idx,
                    first_arg,
                    arg_count,
                    is_closure,
                    upvalues_ptr,
                } => {
                    self.compile_call_function(
                        *dest,
                        *callee,
                        *function_idx,
                        *first_arg,
                        *arg_count,
                        *is_closure,
                        *upvalues_ptr,
                    )?;
                }

                TraceOp::InlineCall {
                    dest,
                    callee,
                    trace,
                } => {
                    self.compile_inline_call(*dest, *callee, trace, guard_index, guards)?;
                }

                TraceOp::CallMethod {
                    dest,
                    object,
                    method_name,
                    first_arg,
                    arg_count,
                } => {
                    // Optimize common method calls with specialized JIT helpers
                    match (method_name.as_str(), *arg_count) {
                        ("push", 1) => {
                            self.compile_array_push(*object, *first_arg)?;
                        }
                        ("is_some", 0) => {
                            self.compile_enum_is_some(*dest, *object)?;
                        }
                        ("unwrap", 0) => {
                            self.compile_enum_unwrap(*dest, *object)?;
                        }
                        _ => {
                            // Generic path for every other method name/arity.
                            self.compile_call_method(
                                *dest,
                                *object,
                                method_name,
                                *first_arg,
                                *arg_count,
                            )?;
                        }
                    }
                }

                // --- struct/enum field access and construction ---
                TraceOp::GetField {
                    dest,
                    object,
                    field_name,
                    field_index,
                    value_type,
                    is_weak,
                } => {
                    self.compile_get_field(
                        *dest,
                        *object,
                        field_name,
                        *field_index,
                        *value_type,
                        *is_weak,
                    )?;
                }

                TraceOp::SetField {
                    object,
                    field_name,
                    value,
                    field_index,
                    value_type,
                    is_weak,
                } => {
                    self.compile_set_field(
                        *object,
                        field_name,
                        *value,
                        *field_index,
                        *value_type,
                        *is_weak,
                    )?;
                }

                TraceOp::NewArray {
                    dest,
                    first_element,
                    count,
                } => {
                    self.compile_new_array(*dest, *first_element, *count)?;
                }

                TraceOp::NewStruct {
                    dest,
                    struct_name,
                    field_names,
                    field_registers,
                } => {
                    self.compile_new_struct(*dest, struct_name, field_names, field_registers)?;
                }

                TraceOp::NewEnumUnit {
                    dest,
                    enum_name,
                    variant_name,
                } => {
                    self.compile_new_enum_unit(*dest, enum_name, variant_name)?;
                }

                TraceOp::NewEnumVariant {
                    dest,
                    enum_name,
                    variant_name,
                    value_registers,
                } => {
                    self.compile_new_enum_variant(*dest, enum_name, variant_name, value_registers)?;
                }

                TraceOp::IsEnumVariant {
                    dest,
                    value,
                    enum_name,
                    variant_name,
                } => {
                    self.compile_is_enum_variant(*dest, *value, enum_name, variant_name)?;
                }

                TraceOp::GetEnumValue {
                    dest,
                    enum_reg,
                    index,
                } => {
                    self.compile_get_enum_value(*dest, *enum_reg, *index)?;
                }

                // --- type/truthiness guards ---
                TraceOp::Guard {
                    register,
                    expected_type,
                } => {
                    let guard =
                        self.compile_guard(*register, *expected_type, *guard_index as usize)?;
                    guards.push(guard);
                    *guard_index += 1;
                }

                TraceOp::GuardLoopContinue {
                    condition_register,
                    expect_truthy,
                    bailout_ip,
                } => {
                    let guard = self.compile_truth_guard(
                        *condition_register,
                        *expect_truthy,
                        *bailout_ip,
                        *guard_index as usize,
                    )?;
                    guards.push(guard);
                    *guard_index += 1;
                }

                TraceOp::NestedLoopCall {
                    function_idx,
                    loop_start_ip,
                    bailout_ip,
                } => {
                    // Nested loop call - this will be replaced with a direct call to
                    // the compiled inner loop trace once it's compiled.
                    // For now, exit to interpreter which will:
                    // 1. Run the loop in interpreter
                    // 2. Eventually compile it as a hot trace
                    // 3. Later, this guard can become a side trace that calls the compiled loop

                    let exit_label = self.current_exit_label();
                    jit::log(|| {
                        format!(
                            "🔗 JIT: Nested loop at func {} ip {} - exiting to interpreter (guard #{})",
                            function_idx, loop_start_ip, *guard_index
                        )
                    });
                    guards.push(Guard {
                        index: *guard_index as usize,
                        bailout_ip: *bailout_ip,
                        kind: GuardKind::NestedLoop {
                            function_idx: *function_idx,
                            loop_start_ip: *loop_start_ip,
                        },
                        fail_count: 0,
                        side_trace: None,
                    });
                    // Guard exit code is 1-based (guard_index + 1); 0 is the
                    // clean-exit return value set by the trace epilogue.
                    let current_guard_index = *guard_index;
                    dynasm!(self.ops
                        ; mov eax, DWORD (current_guard_index + 1)
                        ; jmp => exit_label
                    );
                    *guard_index += 1;
                }

                // --- specialized (unboxed) value handling ---
                TraceOp::Unbox {
                    specialized_id,
                    source_reg,
                    layout,
                } => {
                    self.compile_unbox(*specialized_id, *source_reg, layout)?;
                }

                TraceOp::Rebox {
                    dest_reg,
                    specialized_id,
                    layout,
                } => {
                    self.compile_rebox(*dest_reg, *specialized_id, layout)?;
                }

                TraceOp::DropSpecialized {
                    specialized_id,
                    layout,
                } => {
                    self.compile_drop_specialized(*specialized_id, layout)?;
                }

                TraceOp::SpecializedOp { op, operands } => {
                    self.compile_specialized_op(op, operands)?;
                }

                // Intentionally a no-op at codegen time.
                TraceOp::Return { .. } => {}
            }
        }

        Ok(())
    }
596
597    fn compute_stack_size(trace: &Trace) -> i32 {
598        let specialized_slots = Self::count_specialized_slots(trace) as i32;
599        let specialized_bytes =
600            SPECIALIZED_STACK_BASE + (specialized_slots * SPECIALIZED_SLOT_SIZE);
601        let mut size = MIN_JIT_STACK_SIZE.max(specialized_bytes);
602        let remainder = size % 16;
603        if remainder != 8 {
604            size += (8 - remainder + 16) % 16;
605        }
606        crate::jit::log(|| {
607            format!(
608                "🧮 JIT: Trace requires {} specialized slots → stack {} bytes",
609                specialized_slots, size
610            )
611        });
612        size
613    }
614
615    fn count_specialized_slots(trace: &Trace) -> usize {
616        trace
617            .preamble
618            .iter()
619            .chain(trace.ops.iter())
620            .chain(trace.postamble.iter())
621            .filter(|op| matches!(op, TraceOp::Unbox { .. }))
622            .count()
623    }
624
    /// Inline a callee's recorded body directly into the current trace.
    ///
    /// Builds the callee frame on the native stack (low to high address):
    ///   callee registers (frame_size bytes) — r12 is rebased here,
    ///   alignment padding (align_adjust bytes),
    ///   32-byte metadata: [frame_size, caller r12, previous r15,
    ///   align_adjust] — r15 points at this block while the body compiles,
    ///   forming a chain of inline frames for unwinding.
    ///
    /// Falls back to an ordinary `compile_call_function` when the callee
    /// reports zero registers. On any `jit_move_safe` failure or guard miss
    /// inside the body, `inline_fail` tears this frame down and jumps to the
    /// enclosing fail label.
    fn compile_inline_call(
        &mut self,
        dest: u8,
        callee: u8,
        trace: &InlineTrace,
        guard_index: &mut i32,
        guards: &mut Vec<Guard>,
    ) -> Result<()> {
        self.inline_depth += 1;
        // Closure so inline_depth is decremented on every return path.
        let result = (|| -> Result<()> {
            if trace.register_count == 0 {
                crate::jit::log(|| {
                    format!(
                        "⚠️  JIT: Inline fallback for func {} (no registers)",
                        trace.function_idx
                    )
                });
                return self.compile_call_function(
                    dest,
                    callee,
                    trace.function_idx,
                    trace.first_arg,
                    trace.arg_count,
                    trace.is_closure,
                    trace.upvalues_ptr,
                );
            }

            crate::jit::log(|| {
                format!(
                    "✨ JIT: Inlining call to func {} into register R{}",
                    trace.function_idx, dest
                )
            });

            let value_size = mem::size_of::<Value>() as i32;
            let frame_size = trace.register_count as i32 * value_size;
            // Padding needed so the callee register area keeps 16-byte
            // stack alignment for calls made from the inlined body.
            let align_adjust = ((16 - (frame_size & 15)) & 15) as i32;
            let metadata_size = 32i32;
            let outer_fail = self.current_fail_label();
            let inline_fail = self.ops.new_dynamic_label();
            let inline_end = self.ops.new_dynamic_label();
            // Runtime helper: copies one Value; returns nonzero on success.
            extern "C" {
                fn jit_move_safe(src_ptr: *const Value, dest_ptr: *mut Value) -> u8;
            }

            // Save inline metadata (frame size, caller registers, previous inline frame).
            dynasm!(self.ops
                ; sub rsp, metadata_size
            );
            dynasm!(self.ops
                ; mov eax, DWORD frame_size as _
                ; mov [rsp], rax
                ; mov [rsp + 8], r12
                ; mov [rsp + 16], r15
            );
            // Record the alignment padding and chain this frame via r15.
            dynasm!(self.ops
                ; mov eax, DWORD align_adjust as _
                ; mov [rsp + 24], rax
                ; mov r15, rsp
            );
            if align_adjust != 0 {
                dynasm!(self.ops
                    ; sub rsp, align_adjust
                );
            }
            // Allocate space for callee registers.
            dynasm!(self.ops
                ; sub rsp, frame_size
                ; mov r12, rsp
            );

            // Initialize every callee register to Nil before copying args.
            for reg in 0..trace.register_count {
                self.compile_load_const(reg, &Value::Nil)?;
            }

            // Copy positional arguments into callee registers.
            // Caller registers are reached via the saved r12 at [r15 + 8].
            for (arg_index, src_reg) in trace.arg_registers.iter().enumerate() {
                let src_offset = (*src_reg as i32) * value_size;
                let dest_offset = (arg_index as i32) * value_size;
                dynasm!(self.ops
                    ; mov r14, [r15 + 8]
                    ; lea rdi, [r14 + src_offset]
                    ; lea rsi, [r12 + dest_offset]
                    ; mov rax, QWORD jit_move_safe as _
                    ; call rax
                    ; test al, al
                    ; jz =>inline_fail
                );
            }

            // While the body compiles, failures target this frame's cleanup.
            self.fail_stack.push(inline_fail);
            let inline_result = self.compile_ops(&trace.body, guard_index, guards);
            self.fail_stack.pop();
            inline_result?;

            if let Some(ret_reg) = trace.return_register {
                // Copy the return value out to the caller's dest register,
                // then tear the frame down (registers → padding → metadata).
                let ret_offset = (ret_reg as i32) * value_size;
                let dest_offset = (dest as i32) * value_size;
                dynasm!(self.ops
                    ; mov r14, [r15 + 8]
                    ; lea rdi, [r12 + ret_offset]
                    ; lea rsi, [r14 + dest_offset]
                    ; mov rax, QWORD jit_move_safe as _
                    ; call rax
                    ; test al, al
                    ; jz =>inline_fail
                );
                dynasm!(self.ops
                    ; add rsp, frame_size
                );
                dynasm!(self.ops
                    ; mov eax, DWORD [r15 + 24]
                    ; add rsp, rax
                    ; mov r12, [r15 + 8]
                    ; mov r15, [r15 + 16]
                    ; add rsp, metadata_size
                    ; jmp => inline_end
                );
            } else {
                // No return register: tear the frame down first, then store
                // Nil into the caller's dest register.
                dynasm!(self.ops
                    ; add rsp, frame_size
                );
                dynasm!(self.ops
                    ; mov eax, DWORD [r15 + 24]
                    ; add rsp, rax
                    ; mov r12, [r15 + 8]
                    ; mov r15, [r15 + 16]
                    ; add rsp, metadata_size
                );
                self.compile_load_const(dest, &Value::Nil)?;
                dynasm!(self.ops
                    ; jmp => inline_end
                );
            }

            // Failure cleanup: reclaim the register area using the recorded
            // frame size, then padding and metadata, restoring the caller's
            // r12/r15 before delegating to the enclosing fail label.
            dynasm!(self.ops
                ; => inline_fail
                ; mov eax, DWORD [r15]
                ; mov rbx, rax
                ; add rsp, rbx
            );
            dynasm!(self.ops
            ; mov eax, DWORD [r15 + 24]
            ; add rsp, rax
            ; mov r12, [r15 + 8]
            ; mov r15, [r15 + 16]
            ; add rsp, metadata_size
            ; jmp => outer_fail
            ; => inline_end
            );

            Ok(())
        })();
        self.inline_depth -= 1;
        result
    }
782}