lust/jit/
trace.rs

1use crate::bytecode::value::NativeFn;
2use crate::bytecode::Instruction;
3use crate::bytecode::{Register, Value};
4use crate::LustError;
5use alloc::{
6    format,
7    rc::Rc,
8    string::{String, ToString},
9    vec::Vec,
10};
11use core::fmt;
12use hashbrown::HashSet;
13
14#[derive(Clone)]
15pub struct TracedNativeFn {
16    function: NativeFn,
17}
18
19impl TracedNativeFn {
20    pub fn new(function: NativeFn) -> Self {
21        Self { function }
22    }
23
24    pub fn pointer(&self) -> *const () {
25        Rc::as_ptr(&self.function) as *const ()
26    }
27}
28
29impl fmt::Debug for TracedNativeFn {
30    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31        write!(f, "NativeFn({:p})", Rc::as_ptr(&self.function))
32    }
33}
34
35#[derive(Debug, Clone)]
36pub struct Trace {
37    pub function_idx: usize,
38    pub start_ip: usize,
39    pub ops: Vec<TraceOp>,
40    pub inputs: Vec<Register>,
41    pub outputs: Vec<Register>,
42}
43
44#[derive(Debug, Clone)]
45pub enum TraceOp {
46    LoadConst {
47        dest: Register,
48        value: Value,
49    },
50    Move {
51        dest: Register,
52        src: Register,
53    },
54    Add {
55        dest: Register,
56        lhs: Register,
57        rhs: Register,
58        lhs_type: ValueType,
59        rhs_type: ValueType,
60    },
61    Sub {
62        dest: Register,
63        lhs: Register,
64        rhs: Register,
65        lhs_type: ValueType,
66        rhs_type: ValueType,
67    },
68    Mul {
69        dest: Register,
70        lhs: Register,
71        rhs: Register,
72        lhs_type: ValueType,
73        rhs_type: ValueType,
74    },
75    Div {
76        dest: Register,
77        lhs: Register,
78        rhs: Register,
79        lhs_type: ValueType,
80        rhs_type: ValueType,
81    },
82    Mod {
83        dest: Register,
84        lhs: Register,
85        rhs: Register,
86        lhs_type: ValueType,
87        rhs_type: ValueType,
88    },
89    Neg {
90        dest: Register,
91        src: Register,
92    },
93    Eq {
94        dest: Register,
95        lhs: Register,
96        rhs: Register,
97    },
98    Ne {
99        dest: Register,
100        lhs: Register,
101        rhs: Register,
102    },
103    Lt {
104        dest: Register,
105        lhs: Register,
106        rhs: Register,
107    },
108    Le {
109        dest: Register,
110        lhs: Register,
111        rhs: Register,
112    },
113    Gt {
114        dest: Register,
115        lhs: Register,
116        rhs: Register,
117    },
118    Ge {
119        dest: Register,
120        lhs: Register,
121        rhs: Register,
122    },
123    And {
124        dest: Register,
125        lhs: Register,
126        rhs: Register,
127    },
128    Or {
129        dest: Register,
130        lhs: Register,
131        rhs: Register,
132    },
133    Not {
134        dest: Register,
135        src: Register,
136    },
137    Concat {
138        dest: Register,
139        lhs: Register,
140        rhs: Register,
141    },
142    GetIndex {
143        dest: Register,
144        array: Register,
145        index: Register,
146    },
147    ArrayLen {
148        dest: Register,
149        array: Register,
150    },
151    GuardNativeFunction {
152        register: Register,
153        function: TracedNativeFn,
154    },
155    CallNative {
156        dest: Register,
157        callee: Register,
158        function: TracedNativeFn,
159        first_arg: Register,
160        arg_count: u8,
161    },
162    CallMethod {
163        dest: Register,
164        object: Register,
165        method_name: String,
166        first_arg: Register,
167        arg_count: u8,
168    },
169    GetField {
170        dest: Register,
171        object: Register,
172        field_name: String,
173        field_index: Option<usize>,
174        value_type: Option<ValueType>,
175        is_weak: bool,
176    },
177    SetField {
178        object: Register,
179        field_name: String,
180        value: Register,
181        field_index: Option<usize>,
182        value_type: Option<ValueType>,
183        is_weak: bool,
184    },
185    NewStruct {
186        dest: Register,
187        struct_name: String,
188        field_names: Vec<String>,
189        field_registers: Vec<Register>,
190    },
191    Guard {
192        register: Register,
193        expected_type: ValueType,
194    },
195    GuardLoopContinue {
196        condition_register: Register,
197        bailout_ip: usize,
198    },
199    NestedLoopCall {
200        function_idx: usize,
201        loop_start_ip: usize,
202        bailout_ip: usize,
203    },
204    Return {
205        value: Option<Register>,
206    },
207}
208
/// Coarse runtime type of a register value, as used by trace guards and by the
/// operand-type annotations on arithmetic ops.
///
/// Variant order is kept as-is in case other code relies on the implicit
/// discriminants.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ValueType {
    /// An integer value.
    Int,
    /// A floating-point value.
    Float,
    /// A boolean value.
    Bool,
    /// A string value.
    String,
    /// An array value.
    Array,
    /// A tuple value.
    Tuple,
    /// Any struct value, regardless of its layout.
    Struct,
}
219
220pub struct TraceRecorder {
221    pub trace: Trace,
222    max_length: usize,
223    recording: bool,
224    guarded_registers: HashSet<Register>,
225}
226
227impl TraceRecorder {
228    pub fn new(function_idx: usize, start_ip: usize, max_length: usize) -> Self {
229        Self {
230            trace: Trace {
231                function_idx,
232                start_ip,
233                ops: Vec::new(),
234                inputs: Vec::new(),
235                outputs: Vec::new(),
236            },
237            max_length,
238            recording: true,
239            guarded_registers: HashSet::new(),
240        }
241    }
242
243    pub fn record_instruction(
244        &mut self,
245        instruction: Instruction,
246        current_ip: usize,
247        registers: &[Value; 256],
248        function: &crate::bytecode::Function,
249        function_idx: usize,
250    ) -> Result<(), LustError> {
251        if !self.recording {
252            return Ok(());
253        }
254
255        if function_idx != self.trace.function_idx {
256            return Ok(());
257        }
258
259        let trace_op = match instruction {
260            Instruction::LoadConst(dest, _) => {
261                if let Some(_ty) = Self::get_value_type(&registers[dest as usize]) {
262                    self.guarded_registers.insert(dest);
263                }
264
265                TraceOp::LoadConst {
266                    dest,
267                    value: registers[dest as usize].clone(),
268                }
269            }
270
271            Instruction::LoadGlobal(dest, _) => {
272                if let Some(_ty) = Self::get_value_type(&registers[dest as usize]) {
273                    self.guarded_registers.insert(dest);
274                }
275
276                TraceOp::LoadConst {
277                    dest,
278                    value: registers[dest as usize].clone(),
279                }
280            }
281
282            Instruction::StoreGlobal(_, _) => {
283                return Ok(());
284            }
285
286            Instruction::Move(dest, src) => TraceOp::Move { dest, src },
287            Instruction::Add(dest, lhs, rhs) => {
288                self.add_type_guards(lhs, rhs, registers, function)?;
289                let lhs_type =
290                    Self::get_value_type(&registers[lhs as usize]).unwrap_or(ValueType::Int);
291                let rhs_type =
292                    Self::get_value_type(&registers[rhs as usize]).unwrap_or(ValueType::Int);
293                TraceOp::Add {
294                    dest,
295                    lhs,
296                    rhs,
297                    lhs_type,
298                    rhs_type,
299                }
300            }
301
302            Instruction::Sub(dest, lhs, rhs) => {
303                self.add_type_guards(lhs, rhs, registers, function)?;
304                let lhs_type =
305                    Self::get_value_type(&registers[lhs as usize]).unwrap_or(ValueType::Int);
306                let rhs_type =
307                    Self::get_value_type(&registers[rhs as usize]).unwrap_or(ValueType::Int);
308                TraceOp::Sub {
309                    dest,
310                    lhs,
311                    rhs,
312                    lhs_type,
313                    rhs_type,
314                }
315            }
316
317            Instruction::Mul(dest, lhs, rhs) => {
318                self.add_type_guards(lhs, rhs, registers, function)?;
319                let lhs_type =
320                    Self::get_value_type(&registers[lhs as usize]).unwrap_or(ValueType::Int);
321                let rhs_type =
322                    Self::get_value_type(&registers[rhs as usize]).unwrap_or(ValueType::Int);
323                TraceOp::Mul {
324                    dest,
325                    lhs,
326                    rhs,
327                    lhs_type,
328                    rhs_type,
329                }
330            }
331
332            Instruction::Div(dest, lhs, rhs) => {
333                self.add_type_guards(lhs, rhs, registers, function)?;
334                let lhs_type =
335                    Self::get_value_type(&registers[lhs as usize]).unwrap_or(ValueType::Int);
336                let rhs_type =
337                    Self::get_value_type(&registers[rhs as usize]).unwrap_or(ValueType::Int);
338                TraceOp::Div {
339                    dest,
340                    lhs,
341                    rhs,
342                    lhs_type,
343                    rhs_type,
344                }
345            }
346
347            Instruction::Mod(dest, lhs, rhs) => {
348                self.add_type_guards(lhs, rhs, registers, function)?;
349                let lhs_type =
350                    Self::get_value_type(&registers[lhs as usize]).unwrap_or(ValueType::Int);
351                let rhs_type =
352                    Self::get_value_type(&registers[rhs as usize]).unwrap_or(ValueType::Int);
353                TraceOp::Mod {
354                    dest,
355                    lhs,
356                    rhs,
357                    lhs_type,
358                    rhs_type,
359                }
360            }
361
362            Instruction::Neg(dest, src) => TraceOp::Neg { dest, src },
363            Instruction::Eq(dest, lhs, rhs) => TraceOp::Eq { dest, lhs, rhs },
364            Instruction::Ne(dest, lhs, rhs) => TraceOp::Ne { dest, lhs, rhs },
365            Instruction::Lt(dest, lhs, rhs) => TraceOp::Lt { dest, lhs, rhs },
366            Instruction::Le(dest, lhs, rhs) => TraceOp::Le { dest, lhs, rhs },
367            Instruction::Gt(dest, lhs, rhs) => TraceOp::Gt { dest, lhs, rhs },
368            Instruction::Ge(dest, lhs, rhs) => TraceOp::Ge { dest, lhs, rhs },
369            Instruction::And(dest, lhs, rhs) => TraceOp::And { dest, lhs, rhs },
370            Instruction::Or(dest, lhs, rhs) => TraceOp::Or { dest, lhs, rhs },
371            Instruction::Not(dest, src) => TraceOp::Not { dest, src },
372            Instruction::Concat(dest, lhs, rhs) => {
373                if let Some(ty) = Self::get_value_type(&registers[lhs as usize]) {
374                    if !self.guarded_registers.contains(&lhs) {
375                        self.trace.ops.push(TraceOp::Guard {
376                            register: lhs,
377                            expected_type: ty,
378                        });
379                        self.guarded_registers.insert(lhs);
380                    }
381                }
382
383                if let Some(ty) = Self::get_value_type(&registers[rhs as usize]) {
384                    if !self.guarded_registers.contains(&rhs) {
385                        self.trace.ops.push(TraceOp::Guard {
386                            register: rhs,
387                            expected_type: ty,
388                        });
389                        self.guarded_registers.insert(rhs);
390                    }
391                }
392
393                TraceOp::Concat { dest, lhs, rhs }
394            }
395
396            Instruction::GetIndex(dest, array, index) => {
397                if let Some(ty) = Self::get_value_type(&registers[array as usize]) {
398                    if !self.guarded_registers.contains(&array) {
399                        self.trace.ops.push(TraceOp::Guard {
400                            register: array,
401                            expected_type: ty,
402                        });
403                        self.guarded_registers.insert(array);
404                    }
405                }
406
407                if let Some(ty) = Self::get_value_type(&registers[index as usize]) {
408                    if !self.guarded_registers.contains(&index) {
409                        self.trace.ops.push(TraceOp::Guard {
410                            register: index,
411                            expected_type: ty,
412                        });
413                        self.guarded_registers.insert(index);
414                    }
415                }
416
417                TraceOp::GetIndex { dest, array, index }
418            }
419
420            Instruction::ArrayLen(dest, array) => {
421                if let Some(ty) = Self::get_value_type(&registers[array as usize]) {
422                    if !self.guarded_registers.contains(&array) {
423                        self.trace.ops.push(TraceOp::Guard {
424                            register: array,
425                            expected_type: ty,
426                        });
427                        self.guarded_registers.insert(array);
428                    }
429                }
430
431                TraceOp::ArrayLen { dest, array }
432            }
433
434            Instruction::CallMethod(obj_reg, method_name_idx, first_arg, arg_count, dest_reg) => {
435                let method_name = function.chunk.constants[method_name_idx as usize]
436                    .as_string()
437                    .unwrap_or("unknown")
438                    .to_string();
439                if let Some(ty) = Self::get_value_type(&registers[obj_reg as usize]) {
440                    if !self.guarded_registers.contains(&obj_reg) {
441                        self.trace.ops.push(TraceOp::Guard {
442                            register: obj_reg,
443                            expected_type: ty,
444                        });
445                        self.guarded_registers.insert(obj_reg);
446                    }
447                }
448
449                for i in 0..arg_count {
450                    let arg_reg = first_arg + i;
451                    if let Some(ty) = Self::get_value_type(&registers[arg_reg as usize]) {
452                        if !self.guarded_registers.contains(&arg_reg) {
453                            self.trace.ops.push(TraceOp::Guard {
454                                register: arg_reg,
455                                expected_type: ty,
456                            });
457                            self.guarded_registers.insert(arg_reg);
458                        }
459                    }
460                }
461
462                TraceOp::CallMethod {
463                    dest: dest_reg,
464                    object: obj_reg,
465                    method_name,
466                    first_arg,
467                    arg_count,
468                }
469            }
470
471            Instruction::GetField(dest, obj_reg, field_name_idx) => {
472                let field_name = function.chunk.constants[field_name_idx as usize]
473                    .as_string()
474                    .unwrap_or("unknown")
475                    .to_string();
476                let (field_index, is_weak_field) = match &registers[obj_reg as usize] {
477                    Value::Struct { layout, .. } => {
478                        let idx = layout.index_of_str(&field_name);
479                        let is_weak = idx.map(|i| layout.is_weak(i)).unwrap_or(false);
480                        (idx, is_weak)
481                    }
482
483                    _ => (None, false),
484                };
485                if let Some(ty) = Self::get_value_type(&registers[obj_reg as usize]) {
486                    if !self.guarded_registers.contains(&obj_reg) {
487                        self.trace.ops.push(TraceOp::Guard {
488                            register: obj_reg,
489                            expected_type: ty,
490                        });
491                        self.guarded_registers.insert(obj_reg);
492                    }
493                }
494
495                let value_type = Self::get_value_type(&registers[dest as usize]);
496                TraceOp::GetField {
497                    dest,
498                    object: obj_reg,
499                    field_name,
500                    field_index,
501                    value_type,
502                    is_weak: is_weak_field,
503                }
504            }
505
506            Instruction::SetField(obj_reg, field_name_idx, value_reg) => {
507                let field_name = function.chunk.constants[field_name_idx as usize]
508                    .as_string()
509                    .unwrap_or("unknown")
510                    .to_string();
511                let (field_index, is_weak_field) = match &registers[obj_reg as usize] {
512                    Value::Struct { layout, .. } => {
513                        let idx = layout.index_of_str(&field_name);
514                        let is_weak = idx.map(|i| layout.is_weak(i)).unwrap_or(false);
515                        (idx, is_weak)
516                    }
517
518                    _ => (None, false),
519                };
520                if let Some(ty) = Self::get_value_type(&registers[obj_reg as usize]) {
521                    if !self.guarded_registers.contains(&obj_reg) {
522                        self.trace.ops.push(TraceOp::Guard {
523                            register: obj_reg,
524                            expected_type: ty,
525                        });
526                        self.guarded_registers.insert(obj_reg);
527                    }
528                }
529
530                let value_type = Self::get_value_type(&registers[value_reg as usize]);
531                if let Some(ty) = value_type {
532                    if !self.guarded_registers.contains(&value_reg) {
533                        self.trace.ops.push(TraceOp::Guard {
534                            register: value_reg,
535                            expected_type: ty,
536                        });
537                        self.guarded_registers.insert(value_reg);
538                    }
539                }
540
541                TraceOp::SetField {
542                    object: obj_reg,
543                    field_name,
544                    value: value_reg,
545                    field_index,
546                    value_type,
547                    is_weak: is_weak_field,
548                }
549            }
550
551            Instruction::NewStruct(
552                dest,
553                struct_name_idx,
554                first_field_name_idx,
555                first_field_reg,
556                field_count,
557            ) => {
558                let struct_name = function.chunk.constants[struct_name_idx as usize]
559                    .as_string()
560                    .unwrap_or("unknown")
561                    .to_string();
562                let mut field_names = Vec::new();
563                for i in 0..field_count {
564                    let field_name_idx = first_field_name_idx + (i as u16);
565                    let field_name = function.chunk.constants[field_name_idx as usize]
566                        .as_string()
567                        .unwrap_or("unknown")
568                        .to_string();
569                    field_names.push(field_name);
570                }
571
572                let mut field_registers = Vec::new();
573                for i in 0..field_count {
574                    let field_reg = first_field_reg + i;
575                    field_registers.push(field_reg);
576                    if let Some(ty) = Self::get_value_type(&registers[field_reg as usize]) {
577                        if !self.guarded_registers.contains(&field_reg) {
578                            self.trace.ops.push(TraceOp::Guard {
579                                register: field_reg,
580                                expected_type: ty,
581                            });
582                            self.guarded_registers.insert(field_reg);
583                        }
584                    }
585                }
586
587                TraceOp::NewStruct {
588                    dest,
589                    struct_name,
590                    field_names,
591                    field_registers,
592                }
593            }
594
595            Instruction::Call(func_reg, first_arg, arg_count, dest_reg) => {
596                match &registers[func_reg as usize] {
597                    Value::NativeFunction(native_fn) => {
598                        let traced = TracedNativeFn::new(native_fn.clone());
599                        if !self.guarded_registers.contains(&func_reg) {
600                            self.trace.ops.push(TraceOp::GuardNativeFunction {
601                                register: func_reg,
602                                function: traced.clone(),
603                            });
604                            self.guarded_registers.insert(func_reg);
605                        }
606
607                        self.trace.ops.push(TraceOp::CallNative {
608                            dest: dest_reg,
609                            callee: func_reg,
610                            function: traced,
611                            first_arg,
612                            arg_count,
613                        });
614                        return Ok(());
615                    }
616
617                    _ => {
618                        self.recording = false;
619                        return Err(LustError::RuntimeError {
620                            message: "Trace aborted: unsupported operation".to_string(),
621                        });
622                    }
623                }
624            }
625
626            Instruction::NewArray(_, _, _)
627            | Instruction::NewMap(_)
628            | Instruction::SetIndex(_, _, _) => {
629                self.recording = false;
630                return Err(LustError::RuntimeError {
631                    message: "Trace aborted: unsupported operation".to_string(),
632                });
633            }
634
635            Instruction::Return(value_reg) => {
636                if function_idx == self.trace.function_idx {
637                    self.recording = false;
638                    return Ok(());
639                } else {
640                    TraceOp::Return {
641                        value: if value_reg == 255 {
642                            None
643                        } else {
644                            Some(value_reg)
645                        },
646                    }
647                }
648            }
649
650            Instruction::Jump(offset) => {
651                if offset < 0 {
652                    let target_calc = (current_ip as isize) + (offset as isize);
653                    if target_calc < 0 {
654                        self.recording = false;
655                        return Err(LustError::RuntimeError {
656                            message: format!(
657                                "Invalid jump target: offset={}, current_ip={}, target={}",
658                                offset, current_ip, target_calc
659                            ),
660                        });
661                    }
662
663                    let jump_target = target_calc as usize;
664                    if function_idx == self.trace.function_idx && jump_target == self.trace.start_ip
665                    {
666                        self.recording = false;
667                        return Ok(());
668                    } else {
669                        let bailout_ip = current_ip.saturating_sub(1);
670                        TraceOp::NestedLoopCall {
671                            function_idx,
672                            loop_start_ip: jump_target,
673                            bailout_ip,
674                        }
675                    }
676                } else {
677                    return Ok(());
678                }
679            }
680
681            Instruction::JumpIf(_cond, _) | Instruction::JumpIfNot(_cond, _) => {
682                return Ok(());
683            }
684
685            _ => {
686                self.recording = false;
687                return Err(LustError::RuntimeError {
688                    message: "Trace aborted: unsupported instruction".to_string(),
689                });
690            }
691        };
692        self.trace.ops.push(trace_op);
693        if self.trace.ops.len() >= self.max_length {
694            self.recording = false;
695            return Err(LustError::RuntimeError {
696                message: "Trace too long".to_string(),
697            });
698        }
699
700        Ok(())
701    }
702
703    fn add_type_guards(
704        &mut self,
705        lhs: Register,
706        rhs: Register,
707        registers: &[Value; 256],
708        function: &crate::bytecode::Function,
709    ) -> Result<(), LustError> {
710        if let Some(ty) = Self::get_value_type(&registers[lhs as usize]) {
711            let needs_guard = if self.guarded_registers.contains(&lhs) {
712                false
713            } else if let Some(static_type) = function.register_types.get(&lhs) {
714                !Self::type_kind_matches_value_type(static_type, ty)
715            } else {
716                true
717            };
718            if needs_guard {
719                self.trace.ops.push(TraceOp::Guard {
720                    register: lhs,
721                    expected_type: ty,
722                });
723                self.guarded_registers.insert(lhs);
724            } else {
725                self.guarded_registers.insert(lhs);
726            }
727        }
728
729        if let Some(ty) = Self::get_value_type(&registers[rhs as usize]) {
730            let needs_guard = if self.guarded_registers.contains(&rhs) {
731                false
732            } else if let Some(static_type) = function.register_types.get(&rhs) {
733                !Self::type_kind_matches_value_type(static_type, ty)
734            } else {
735                true
736            };
737            if needs_guard {
738                self.trace.ops.push(TraceOp::Guard {
739                    register: rhs,
740                    expected_type: ty,
741                });
742                self.guarded_registers.insert(rhs);
743            } else {
744                self.guarded_registers.insert(rhs);
745            }
746        }
747
748        Ok(())
749    }
750
751    fn type_kind_matches_value_type(
752        type_kind: &crate::ast::TypeKind,
753        value_type: ValueType,
754    ) -> bool {
755        use crate::ast::TypeKind;
756        match (type_kind, value_type) {
757            (TypeKind::Int, ValueType::Int) => true,
758            (TypeKind::Float, ValueType::Float) => true,
759            (TypeKind::Bool, ValueType::Bool) => true,
760            (TypeKind::String, ValueType::String) => true,
761            (TypeKind::Array(_), ValueType::Array) => true,
762            (TypeKind::Tuple(_), ValueType::Tuple) => true,
763            _ => false,
764        }
765    }
766
767    fn get_value_type(value: &Value) -> Option<ValueType> {
768        match value {
769            Value::Int(_) => Some(ValueType::Int),
770            Value::Float(_) => Some(ValueType::Float),
771            Value::Bool(_) => Some(ValueType::Bool),
772            Value::String(_) => Some(ValueType::String),
773            Value::Array(_) => Some(ValueType::Array),
774            Value::Tuple(_) => Some(ValueType::Tuple),
775            Value::Struct { .. } => Some(ValueType::Struct),
776            _ => None,
777        }
778    }
779
780    pub fn finish(self) -> Trace {
781        self.trace
782    }
783
784    pub fn is_recording(&self) -> bool {
785        self.recording
786    }
787
788    pub fn abort(&mut self) {
789        self.recording = false;
790    }
791}