// lust/jit/codegen/builder.rs

1use super::*;
2use crate::VM;
3impl JitCompiler {
4    pub fn new() -> Self {
5        Self {
6            ops: Assembler::new().unwrap(),
7            leaked_constants: Vec::new(),
8            fail_stack: Vec::new(),
9            exit_stack: Vec::new(),
10            inline_depth: 0,
11        }
12    }
13
14    fn current_fail_label(&self) -> dynasmrt::DynamicLabel {
15        *self
16            .fail_stack
17            .last()
18            .expect("JIT fail label stack is empty")
19    }
20
21    pub(super) fn current_exit_label(&self) -> dynasmrt::DynamicLabel {
22        *self
23            .exit_stack
24            .last()
25            .expect("JIT exit label stack is empty")
26    }
27
28    pub fn compile_trace(
29        &mut self,
30        trace: &Trace,
31        trace_id: TraceId,
32        parent: Option<TraceId>,
33        hoisted_constants: Vec<(u8, Value)>,
34    ) -> Result<CompiledTrace> {
35        let mut guards = Vec::new();
36        let mut guard_index = 0i32;
37        let exit_label = self.ops.new_dynamic_label();
38        let fail_label = self.ops.new_dynamic_label();
39        self.exit_stack.push(exit_label);
40        self.fail_stack.push(fail_label);
41        crate::jit::log(|| "🔧 JIT: Emitting prologue with sub rsp, 40".to_string());
42        dynasm!(self.ops
43            ; push rbp
44            ; mov rbp, rsp
45            ; push rbx
46            ; push r12
47            ; push r13
48            ; push r14
49            ; push r15
50            ; sub rsp, 40
51            ; xor r15, r15
52            ; mov r12, rdi
53            ; mov r13, rsi
54        );
55        for (dest, value) in &hoisted_constants {
56            self.compile_load_const(*dest, value)?;
57        }
58
59        let compile_result = self.compile_ops(&trace.ops, &mut guard_index, &mut guards);
60        self.exit_stack.pop();
61        self.fail_stack.pop();
62        compile_result?;
63        let unwind_label = self.ops.new_dynamic_label();
64        let fail_return_label = self.ops.new_dynamic_label();
65        dynasm!(self.ops
66            ; xor eax, eax
67            ; => exit_label
68            ; exit:
69            ; add rsp, 40
70            ; pop r15
71            ; pop r14
72            ; pop r13
73            ; pop r12
74            ; pop rbx
75            ; pop rbp
76            ; ret
77            ; => fail_label
78            ; fail:
79            ; mov eax, DWORD -1
80            ; => unwind_label
81            ; test r15, r15
82            ; je => fail_return_label
83            ; mov eax, DWORD [r15]
84            ; mov rbx, rax
85            ; add rsp, rbx
86            ; mov r12, [r15 + 8]
87            ; mov r15, [r15 + 16]
88            ; add rsp, 24
89            ; jmp => unwind_label
90            ; => fail_return_label
91            ; jmp => exit_label
92        );
93        let ops = mem::replace(&mut self.ops, Assembler::new().unwrap());
94        let exec_buffer = ops.finalize().unwrap();
95        let entry_point = exec_buffer.ptr(dynasmrt::AssemblyOffset(0));
96        let entry: extern "C" fn(*mut Value, *mut VM, *const Function) -> i32 =
97            unsafe { mem::transmute(entry_point) };
98        #[cfg(feature = "std")]
99        {
100            if std::env::var("LUST_JIT_DUMP").is_ok() {
101                use std::{fs, path::PathBuf};
102                let len = exec_buffer.len();
103                let bytes = unsafe { std::slice::from_raw_parts(entry_point as *const u8, len) };
104                let mut path = PathBuf::from("target");
105                let _ = fs::create_dir_all(&path);
106                path.push(format!(
107                    "jit_trace_{}_{}.bin",
108                    trace_id.0,
109                    parent.map(|p| p.0).unwrap_or(trace.function_idx)
110                ));
111                if let Err(err) = fs::write(&path, bytes) {
112                    crate::jit::log(|| {
113                        format!("⚠️  JIT: failed to dump trace to {:?}: {}", path, err)
114                    });
115                } else {
116                    crate::jit::log(|| format!("📝 JIT: Dumped trace bytes to {:?}", path));
117                }
118            }
119        }
120        Box::leak(Box::new(exec_buffer));
121        let leaked_constants = mem::take(&mut self.leaked_constants);
122        Ok(CompiledTrace {
123            id: trace_id,
124            entry,
125            trace: trace.clone(),
126            guards,
127            parent,
128            side_traces: Vec::new(),
129            leaked_constants,
130            hoisted_constants,
131        })
132    }
133
    /// Emit machine code for a linear sequence of trace operations.
    ///
    /// `guard_index` numbers guards consecutively across the whole trace,
    /// including inlined bodies; every emitted guard pushes a matching
    /// `Guard` record so the runtime can map a side-exit code back to its
    /// bailout point. The emission order here must stay in lockstep with
    /// the index increments — reordering arms would desynchronize them.
    fn compile_ops(
        &mut self,
        ops: &[TraceOp],
        guard_index: &mut i32,
        guards: &mut Vec<Guard>,
    ) -> Result<()> {
        for op in ops {
            match op {
                // --- Data movement ---------------------------------------
                TraceOp::LoadConst { dest, value } => {
                    self.compile_load_const(*dest, value)?;
                }

                TraceOp::Move { dest, src } => {
                    self.compile_move(*dest, *src)?;
                }

                // --- Arithmetic (specialized on recorded operand types) ---
                TraceOp::Add {
                    dest,
                    lhs,
                    rhs,
                    lhs_type,
                    rhs_type,
                } => {
                    self.compile_add_specialized(*dest, *lhs, *rhs, *lhs_type, *rhs_type)?;
                }

                TraceOp::Sub {
                    dest,
                    lhs,
                    rhs,
                    lhs_type,
                    rhs_type,
                } => {
                    self.compile_sub_specialized(*dest, *lhs, *rhs, *lhs_type, *rhs_type)?;
                }

                TraceOp::Mul {
                    dest,
                    lhs,
                    rhs,
                    lhs_type,
                    rhs_type,
                } => {
                    self.compile_mul_specialized(*dest, *lhs, *rhs, *lhs_type, *rhs_type)?;
                }

                TraceOp::Div {
                    dest,
                    lhs,
                    rhs,
                    lhs_type,
                    rhs_type,
                } => {
                    self.compile_div_specialized(*dest, *lhs, *rhs, *lhs_type, *rhs_type)?;
                }

                TraceOp::Mod {
                    dest,
                    lhs,
                    rhs,
                    lhs_type,
                    rhs_type,
                } => {
                    self.compile_mod_specialized(*dest, *lhs, *rhs, *lhs_type, *rhs_type)?;
                }

                TraceOp::Neg { dest, src } => {
                    self.compile_neg(*dest, *src)?;
                }

                // --- Comparisons and boolean logic ------------------------
                TraceOp::Lt { dest, lhs, rhs } => {
                    self.compile_lt(*dest, *lhs, *rhs)?;
                }

                TraceOp::Le { dest, lhs, rhs } => {
                    self.compile_le(*dest, *lhs, *rhs)?;
                }

                TraceOp::Gt { dest, lhs, rhs } => {
                    self.compile_gt(*dest, *lhs, *rhs)?;
                }

                TraceOp::Ge { dest, lhs, rhs } => {
                    self.compile_ge(*dest, *lhs, *rhs)?;
                }

                TraceOp::Eq { dest, lhs, rhs } => {
                    self.compile_eq(*dest, *lhs, *rhs)?;
                }

                TraceOp::Ne { dest, lhs, rhs } => {
                    self.compile_ne(*dest, *lhs, *rhs)?;
                }

                TraceOp::And { dest, lhs, rhs } => {
                    self.compile_and(*dest, *lhs, *rhs)?;
                }

                TraceOp::Or { dest, lhs, rhs } => {
                    self.compile_or(*dest, *lhs, *rhs)?;
                }

                TraceOp::Not { dest, src } => {
                    self.compile_not(*dest, *src)?;
                }

                // --- Strings and arrays -----------------------------------
                TraceOp::Concat { dest, lhs, rhs } => {
                    self.compile_concat(*dest, *lhs, *rhs)?;
                }

                TraceOp::GetIndex { dest, array, index } => {
                    self.compile_get_index(*dest, *array, *index)?;
                }

                TraceOp::ArrayLen { dest, array } => {
                    self.compile_array_len(*dest, *array)?;
                }

                // --- Callee-identity guards -------------------------------
                // Each guard verifies the recorded callee is still the one
                // seen at trace time; on mismatch the code bails out to the
                // interpreter at the guard's bailout point.
                TraceOp::GuardNativeFunction { register, function } => {
                    let expected_ptr = function.pointer();
                    crate::jit::log(|| format!("🔒 JIT: guard native reg {}", register));
                    let guard = self.compile_guard_native_function(
                        *register,
                        expected_ptr,
                        *guard_index as usize,
                    )?;
                    guards.push(guard);
                    *guard_index += 1;
                }

                TraceOp::GuardFunction {
                    register,
                    function_idx,
                } => {
                    crate::jit::log(|| {
                        format!(
                            "🔒 JIT: guard function reg {} -> idx {}",
                            register, function_idx
                        )
                    });
                    let guard = self.compile_guard_function(
                        *register,
                        *function_idx,
                        *guard_index as usize,
                    )?;
                    guards.push(guard);
                    *guard_index += 1;
                }

                TraceOp::GuardClosure {
                    register,
                    function_idx,
                    upvalues_ptr,
                } => {
                    crate::jit::log(|| {
                        format!(
                            "🔒 JIT: guard closure reg {} -> idx {}",
                            register, function_idx
                        )
                    });
                    let guard = self.compile_guard_closure(
                        *register,
                        *function_idx,
                        *upvalues_ptr,
                        *guard_index as usize,
                    )?;
                    guards.push(guard);
                    *guard_index += 1;
                }

                // --- Calls -------------------------------------------------
                TraceOp::CallNative {
                    dest,
                    callee,
                    function,
                    first_arg,
                    arg_count,
                } => {
                    let expected_ptr = function.pointer();
                    self.compile_call_native(*dest, *callee, expected_ptr, *first_arg, *arg_count)?;
                }

                TraceOp::CallFunction {
                    dest,
                    callee,
                    function_idx,
                    first_arg,
                    arg_count,
                    is_closure,
                    upvalues_ptr,
                } => {
                    self.compile_call_function(
                        *dest,
                        *callee,
                        *function_idx,
                        *first_arg,
                        *arg_count,
                        *is_closure,
                        *upvalues_ptr,
                    )?;
                }

                // Recursively emits the callee's recorded body in place of a
                // runtime call; see compile_inline_call.
                TraceOp::InlineCall {
                    dest,
                    callee,
                    trace,
                } => {
                    self.compile_inline_call(*dest, *callee, trace, guard_index, guards)?;
                }

                TraceOp::CallMethod {
                    dest,
                    object,
                    method_name,
                    first_arg,
                    arg_count,
                } => {
                    self.compile_call_method(*dest, *object, method_name, *first_arg, *arg_count)?;
                }

                // --- Struct / enum operations ------------------------------
                TraceOp::GetField {
                    dest,
                    object,
                    field_name,
                    field_index,
                    value_type,
                    is_weak,
                } => {
                    self.compile_get_field(
                        *dest,
                        *object,
                        field_name,
                        *field_index,
                        *value_type,
                        *is_weak,
                    )?;
                }

                TraceOp::SetField {
                    object,
                    field_name,
                    value,
                    field_index,
                    value_type,
                    is_weak,
                } => {
                    self.compile_set_field(
                        *object,
                        field_name,
                        *value,
                        *field_index,
                        *value_type,
                        *is_weak,
                    )?;
                }

                TraceOp::NewStruct {
                    dest,
                    struct_name,
                    field_names,
                    field_registers,
                } => {
                    self.compile_new_struct(*dest, struct_name, field_names, field_registers)?;
                }

                TraceOp::NewEnumUnit {
                    dest,
                    enum_name,
                    variant_name,
                } => {
                    self.compile_new_enum_unit(*dest, enum_name, variant_name)?;
                }

                TraceOp::NewEnumVariant {
                    dest,
                    enum_name,
                    variant_name,
                    value_registers,
                } => {
                    self.compile_new_enum_variant(*dest, enum_name, variant_name, value_registers)?;
                }

                TraceOp::IsEnumVariant {
                    dest,
                    value,
                    enum_name,
                    variant_name,
                } => {
                    self.compile_is_enum_variant(*dest, *value, enum_name, variant_name)?;
                }

                TraceOp::GetEnumValue {
                    dest,
                    enum_reg,
                    index,
                } => {
                    self.compile_get_enum_value(*dest, *enum_reg, *index)?;
                }

                // --- Type and loop-condition guards ------------------------
                TraceOp::Guard {
                    register,
                    expected_type,
                } => {
                    let guard =
                        self.compile_guard(*register, *expected_type, *guard_index as usize)?;
                    guards.push(guard);
                    *guard_index += 1;
                }

                TraceOp::GuardLoopContinue {
                    condition_register,
                    expect_truthy,
                    bailout_ip,
                } => {
                    let guard = self.compile_truth_guard(
                        *condition_register,
                        *expect_truthy,
                        *bailout_ip,
                        *guard_index as usize,
                    )?;
                    guards.push(guard);
                    *guard_index += 1;
                }

                // A nested loop is not compiled inline: record a guard and
                // emit an unconditional side exit. The exit code in eax is
                // guard_index + 1 (never 0, which means a normal exit) so
                // the runtime can identify which guard fired.
                TraceOp::NestedLoopCall {
                    function_idx,
                    loop_start_ip,
                    bailout_ip,
                } => {
                    let exit_label = self.current_exit_label();
                    jit::log(|| {
                        format!(
                            "🔗 JIT: Nested loop detected at func {} ip {} (guard #{})",
                            function_idx, loop_start_ip, *guard_index
                        )
                    });
                    guards.push(Guard {
                        index: *guard_index as usize,
                        bailout_ip: *bailout_ip,
                        kind: GuardKind::NestedLoop {
                            function_idx: *function_idx,
                            loop_start_ip: *loop_start_ip,
                        },
                        fail_count: 0,
                        side_trace: None,
                    });
                    let current_guard_index = *guard_index;
                    dynasm!(self.ops
                        ; mov eax, DWORD (current_guard_index + 1)
                        ; jmp => exit_label
                    );
                    *guard_index += 1;
                }

                // Returns need no code of their own here: result plumbing is
                // handled by the surrounding call/inline machinery (see how
                // compile_inline_call reads `trace.return_register`).
                TraceOp::Return { .. } => {}
            }
        }

        Ok(())
    }
493
    /// Emit an inlined copy of a callee trace directly into the current
    /// code stream instead of a runtime call.
    ///
    /// Layout of one inline frame, grown downward on the native stack
    /// (r15 points at the metadata, r12 at the callee register file):
    ///   [r15 + 0]  callee frame size in bytes (register_count * sizeof(Value))
    ///   [r15 + 8]  caller's r12 (caller register-file base)
    ///   [r15 + 16] previous r15 (outer inline frame; 0 at top level)
    ///   [r15 + 24] alignment pad size applied below the metadata
    ///   ...pad...  align_adjust bytes, keeping rsp 16-byte aligned for calls
    ///   [r12 ..]   callee registers (frame_size bytes)
    fn compile_inline_call(
        &mut self,
        dest: u8,
        callee: u8,
        trace: &InlineTrace,
        guard_index: &mut i32,
        guards: &mut Vec<Guard>,
    ) -> Result<()> {
        self.inline_depth += 1;
        // Immediately-invoked closure so inline_depth is decremented on
        // every path, including early `?` returns.
        let result = (|| -> Result<()> {
            // Degenerate callee: fall back to a real call rather than
            // building a zero-sized inline frame.
            if trace.register_count == 0 {
                crate::jit::log(|| {
                    format!(
                        "⚠️  JIT: Inline fallback for func {} (no registers)",
                        trace.function_idx
                    )
                });
                return self.compile_call_function(
                    dest,
                    callee,
                    trace.function_idx,
                    trace.first_arg,
                    trace.arg_count,
                    trace.is_closure,
                    trace.upvalues_ptr,
                );
            }

            crate::jit::log(|| {
                format!(
                    "✨ JIT: Inlining call to func {} into register R{}",
                    trace.function_idx, dest
                )
            });

            let value_size = mem::size_of::<Value>() as i32;
            let frame_size = trace.register_count as i32 * value_size;
            // Pad so frame_size + align_adjust is a multiple of 16.
            let align_adjust = ((16 - (frame_size & 15)) & 15) as i32;
            let metadata_size = 32i32;
            let outer_fail = self.current_fail_label();
            let inline_fail = self.ops.new_dynamic_label();
            let inline_end = self.ops.new_dynamic_label();
            // Runtime helper that copies one Value; returns 0 (al) on
            // failure, non-zero on success (see the `test al, al; jz` below).
            extern "C" {
                fn jit_move_safe(src_ptr: *const Value, dest_ptr: *mut Value) -> u8;
            }

            // Save inline metadata (frame size, caller registers, previous inline frame).
            dynasm!(self.ops
                ; sub rsp, metadata_size
            );
            dynasm!(self.ops
                ; mov eax, DWORD frame_size as _
                ; mov [rsp], rax
                ; mov [rsp + 8], r12
                ; mov [rsp + 16], r15
            );
            dynasm!(self.ops
                ; mov eax, DWORD align_adjust as _
                ; mov [rsp + 24], rax
                // r15 now heads the inline-frame chain for this call.
                ; mov r15, rsp
            );
            if align_adjust != 0 {
                dynasm!(self.ops
                    ; sub rsp, align_adjust
                );
            }
            // Allocate space for callee registers.
            dynasm!(self.ops
                ; sub rsp, frame_size
                ; mov r12, rsp
            );

            // Initialize every callee register to Nil before copying args.
            for reg in 0..trace.register_count {
                self.compile_load_const(reg, &Value::Nil)?;
            }

            // Copy positional arguments into callee registers.
            // [r15 + 8] holds the caller's register base saved above.
            for (arg_index, src_reg) in trace.arg_registers.iter().enumerate() {
                let src_offset = (*src_reg as i32) * value_size;
                let dest_offset = (arg_index as i32) * value_size;
                dynasm!(self.ops
                    ; mov r14, [r15 + 8]
                    ; lea rdi, [r14 + src_offset]
                    ; lea rsi, [r12 + dest_offset]
                    ; mov rax, QWORD jit_move_safe as _
                    ; call rax
                    ; test al, al
                    ; jz =>inline_fail
                );
            }

            // Compile the callee body with failures routed to inline_fail,
            // so a guard inside the inline frame unwinds this frame first.
            self.fail_stack.push(inline_fail);
            let inline_result = self.compile_ops(&trace.body, guard_index, guards);
            self.fail_stack.pop();
            inline_result?;

            if let Some(ret_reg) = trace.return_register {
                // Copy the callee's return register into the caller's dest,
                // then tear the frame down (registers, pad, metadata) and
                // restore the caller's r12/r15.
                let ret_offset = (ret_reg as i32) * value_size;
                let dest_offset = (dest as i32) * value_size;
                dynasm!(self.ops
                    ; mov r14, [r15 + 8]
                    ; lea rdi, [r12 + ret_offset]
                    ; lea rsi, [r14 + dest_offset]
                    ; mov rax, QWORD jit_move_safe as _
                    ; call rax
                    ; test al, al
                    ; jz =>inline_fail
                );
                dynasm!(self.ops
                    ; add rsp, frame_size
                );
                dynasm!(self.ops
                    ; mov eax, DWORD [r15 + 24]
                    ; add rsp, rax
                    ; mov r12, [r15 + 8]
                    ; mov r15, [r15 + 16]
                    ; add rsp, metadata_size
                    ; jmp => inline_end
                );
            } else {
                // No return value: tear down the frame first, then store
                // Nil into the caller's dest register.
                dynasm!(self.ops
                    ; add rsp, frame_size
                );
                dynasm!(self.ops
                    ; mov eax, DWORD [r15 + 24]
                    ; add rsp, rax
                    ; mov r12, [r15 + 8]
                    ; mov r15, [r15 + 16]
                    ; add rsp, metadata_size
                );
                self.compile_load_const(dest, &Value::Nil)?;
                dynasm!(self.ops
                    ; jmp => inline_end
                );
            }

            // Failure path: release this frame and chain to the enclosing
            // fail handler (the next outer inline_fail, or the trace's
            // top-level fail label).
            // NOTE(review): `add rsp, rbx` assumes rsp still equals the
            // callee frame base (r12) whenever a guard jumps here — TODO
            // confirm guard emission never leaves extra data on the stack.
            dynasm!(self.ops
                ; => inline_fail
                ; mov eax, DWORD [r15]
                ; mov rbx, rax
                ; add rsp, rbx
            );
            dynasm!(self.ops
            ; mov eax, DWORD [r15 + 24]
            ; add rsp, rax
            ; mov r12, [r15 + 8]
            ; mov r15, [r15 + 16]
            ; add rsp, metadata_size
            ; jmp => outer_fail
            ; => inline_end
            );

            Ok(())
        })();
        self.inline_depth -= 1;
        result
    }
651}