koto_bytecode/
instruction_reader.rs

1use crate::{Chunk, FunctionFlags, Instruction, Op, StringFormatFlags};
2use koto_memory::Ptr;
3use koto_parser::{StringFormatOptions, StringFormatRepresentation};
4use std::mem::MaybeUninit;
5
6/// An iterator that converts bytecode into a series of [Instruction]s
7#[derive(Clone, Default)]
8pub struct InstructionReader {
9    /// The chunk that the reader is reading from
10    pub chunk: Ptr<Chunk>,
11    /// The reader's instruction pointer
12    pub ip: usize,
13}
14
15impl InstructionReader {
16    /// Initializes a reader with the given chunk
17    pub fn new(chunk: Ptr<Chunk>) -> Self {
18        Self { chunk, ip: 0 }
19    }
20}
21
22impl Iterator for InstructionReader {
23    type Item = Instruction;
24
25    fn next(&mut self) -> Option<Self::Item> {
26        use Instruction::*;
27
28        let bytes = self.chunk.bytes.as_slice();
29
30        macro_rules! get_u8 {
31            () => {{
32                match bytes.get(self.ip) {
33                    Some(byte) => {
34                        self.ip += 1;
35                        *byte
36                    }
37                    None => return out_of_bounds_access_error(self.ip),
38                }
39            }};
40        }
41
42        macro_rules! get_u8_array {
43            ($n:expr) => {{
44                if bytes.len() >= self.ip + $n {
45                    // Safety:
46                    // - The size of `bytes` has been checked so we know its safe to access ip + $n
47                    // - `result` is fully initialized as a result of the copy_nonoverlapping call
48                    //   so it's safe to transmute.
49                    // Todo: Simplify once https://github.com/rust-lang/rust/issues/96097 is stable.
50                    unsafe {
51                        let mut result: [MaybeUninit<u8>; $n] = MaybeUninit::uninit().assume_init();
52                        std::ptr::copy_nonoverlapping(
53                            bytes.as_ptr().add(self.ip),
54                            result.as_mut_ptr() as *mut u8,
55                            $n,
56                        );
57                        self.ip += $n;
58                        // Convert `MaybeUninit<[u8; $n]>` to `[u8; $n]`.
59                        std::mem::transmute::<[MaybeUninit<u8>; $n], [u8; $n]>(result)
60                    }
61                } else {
62                    return out_of_bounds_access_error(self.ip);
63                }
64            }};
65        }
66        macro_rules! get_u8x2 {
67            () => {{ get_u8_array!(2) }};
68        }
69        macro_rules! get_u8x3 {
70            () => {{ get_u8_array!(3) }};
71        }
72        macro_rules! get_u8x4 {
73            () => {{ get_u8_array!(4) }};
74        }
75        macro_rules! get_u8x5 {
76            () => {{ get_u8_array!(5) }};
77        }
78        macro_rules! get_u8x6 {
79            () => {{ get_u8_array!(6) }};
80        }
81
82        macro_rules! get_u16 {
83            () => {{
84                let [a, b] = get_u8x2!();
85                u16::from_le_bytes([a, b])
86            }};
87        }
88
89        macro_rules! get_var_u32 {
90            () => {{
91                let mut result = 0;
92                let mut shift_amount = 0;
93                loop {
94                    let Some(&byte) = bytes.get(self.ip) else {
95                        return out_of_bounds_access_error(self.ip);
96                    };
97                    self.ip += 1;
98                    result |= (byte as u32 & 0x7f) << shift_amount;
99                    if byte & 0x80 == 0 {
100                        break;
101                    } else {
102                        shift_amount += 7;
103                    }
104                }
105                result
106            }};
107        }
108
109        macro_rules! get_var_u32_with_first_byte {
110            ($first_byte:expr) => {{
111                let mut byte = $first_byte;
112                let mut result = (byte as u32 & 0x7f);
113                let mut shift_amount = 0;
114                while byte & 0x80 != 0 {
115                    let Some(&next_byte) = bytes.get(self.ip) else {
116                        return out_of_bounds_access_error(self.ip);
117                    };
118
119                    byte = next_byte;
120                    self.ip += 1;
121                    shift_amount += 7;
122
123                    result |= (byte as u32 & 0x7f) << shift_amount;
124                }
125                result
126            }};
127        }
128
129        // Each op consists of at least two bytes
130        let op_ip = self.ip;
131        let (op, byte_a) = match bytes.get(op_ip..op_ip + 2) {
132            Some(&[op, byte]) => (Op::from(op), byte),
133            _ => return None,
134        };
135        self.ip += 2;
136
137        let instruction = match op {
138            Op::NewFrame => NewFrame {
139                register_count: byte_a,
140            },
141            Op::Copy => Copy {
142                target: byte_a,
143                source: get_u8!(),
144            },
145            Op::SetNull => SetNull { register: byte_a },
146            Op::SetFalse => SetBool {
147                register: byte_a,
148                value: false,
149            },
150            Op::SetTrue => SetBool {
151                register: byte_a,
152                value: true,
153            },
154            Op::Set0 => SetNumber {
155                register: byte_a,
156                value: 0,
157            },
158            Op::Set1 => SetNumber {
159                register: byte_a,
160                value: 1,
161            },
162            Op::SetNumberU8 => SetNumber {
163                register: byte_a,
164                value: get_u8!() as i64,
165            },
166            Op::SetNumberNegU8 => SetNumber {
167                register: byte_a,
168                value: -(get_u8!() as i64),
169            },
170            Op::LoadFloat => LoadFloat {
171                register: byte_a,
172                constant: get_var_u32!().into(),
173            },
174            Op::LoadInt => LoadInt {
175                register: byte_a,
176                constant: get_var_u32!().into(),
177            },
178            Op::LoadString => LoadString {
179                register: byte_a,
180                constant: get_var_u32!().into(),
181            },
182            Op::LoadNonLocal => LoadNonLocal {
183                register: byte_a,
184                constant: get_var_u32!().into(),
185            },
186            Op::ExportValue => ExportValue {
187                key: byte_a,
188                value: get_u8!(),
189            },
190            Op::ExportEntry => ExportEntry { entry: byte_a },
191            Op::Import => Import { register: byte_a },
192            Op::ImportAll => ImportAll { register: byte_a },
193            Op::MakeTempTuple => {
194                let [byte_b, byte_c] = get_u8x2!();
195                MakeTempTuple {
196                    register: byte_a,
197                    start: byte_b,
198                    count: byte_c,
199                }
200            }
201            Op::TempTupleToTuple => TempTupleToTuple {
202                register: byte_a,
203                source: get_u8!(),
204            },
205            Op::MakeMap => MakeMap {
206                register: byte_a,
207                size_hint: get_var_u32!(),
208            },
209            Op::SequenceStart => SequenceStart {
210                size_hint: get_var_u32_with_first_byte!(byte_a),
211            },
212            Op::SequencePush => SequencePush { value: byte_a },
213            Op::SequencePushN => SequencePushN {
214                start: byte_a,
215                count: get_u8!(),
216            },
217            Op::SequenceToList => SequenceToList { register: byte_a },
218            Op::SequenceToTuple => SequenceToTuple { register: byte_a },
219            Op::Range => {
220                let [byte_b, byte_c] = get_u8x2!();
221                Range {
222                    register: byte_a,
223                    start: byte_b,
224                    end: byte_c,
225                }
226            }
227            Op::RangeInclusive => {
228                let [byte_b, byte_c] = get_u8x2!();
229                RangeInclusive {
230                    register: byte_a,
231                    start: byte_b,
232                    end: byte_c,
233                }
234            }
235            Op::RangeTo => RangeTo {
236                register: byte_a,
237                end: get_u8!(),
238            },
239            Op::RangeToInclusive => RangeToInclusive {
240                register: byte_a,
241                end: get_u8!(),
242            },
243            Op::RangeFrom => RangeFrom {
244                register: byte_a,
245                start: get_u8!(),
246            },
247            Op::RangeFull => RangeFull { register: byte_a },
248            Op::MakeIterator => MakeIterator {
249                register: byte_a,
250                iterable: get_u8!(),
251            },
252            Op::Function => {
253                let register = byte_a;
254                let [
255                    arg_count,
256                    optional_arg_count,
257                    capture_count,
258                    flags,
259                    size_a,
260                    size_b,
261                ] = get_u8x6!();
262                match FunctionFlags::try_from(flags) {
263                    Ok(flags) => {
264                        let size = u16::from_le_bytes([size_a, size_b]);
265
266                        Function {
267                            register,
268                            arg_count,
269                            optional_arg_count,
270                            capture_count,
271                            flags,
272                            size,
273                        }
274                    }
275                    Err(e) => Error { message: e },
276                }
277            }
278            Op::Capture => {
279                let [byte_b, byte_c] = get_u8x2!();
280                Capture {
281                    function: byte_a,
282                    target: byte_b,
283                    source: byte_c,
284                }
285            }
286            Op::Negate => Negate {
287                register: byte_a,
288                value: get_u8!(),
289            },
290            Op::Not => Not {
291                register: byte_a,
292                value: get_u8!(),
293            },
294            Op::Add => {
295                let [byte_b, byte_c] = get_u8x2!();
296                Add {
297                    register: byte_a,
298                    lhs: byte_b,
299                    rhs: byte_c,
300                }
301            }
302            Op::Subtract => {
303                let [byte_b, byte_c] = get_u8x2!();
304                Subtract {
305                    register: byte_a,
306                    lhs: byte_b,
307                    rhs: byte_c,
308                }
309            }
310            Op::Multiply => {
311                let [byte_b, byte_c] = get_u8x2!();
312                Multiply {
313                    register: byte_a,
314                    lhs: byte_b,
315                    rhs: byte_c,
316                }
317            }
318            Op::Divide => {
319                let [byte_b, byte_c] = get_u8x2!();
320                Divide {
321                    register: byte_a,
322                    lhs: byte_b,
323                    rhs: byte_c,
324                }
325            }
326            Op::Remainder => {
327                let [byte_b, byte_c] = get_u8x2!();
328                Remainder {
329                    register: byte_a,
330                    lhs: byte_b,
331                    rhs: byte_c,
332                }
333            }
334            Op::Power => {
335                let [byte_b, byte_c] = get_u8x2!();
336                Power {
337                    register: byte_a,
338                    lhs: byte_b,
339                    rhs: byte_c,
340                }
341            }
342            Op::AddAssign => AddAssign {
343                lhs: byte_a,
344                rhs: get_u8!(),
345            },
346            Op::SubtractAssign => SubtractAssign {
347                lhs: byte_a,
348                rhs: get_u8!(),
349            },
350            Op::MultiplyAssign => MultiplyAssign {
351                lhs: byte_a,
352                rhs: get_u8!(),
353            },
354            Op::DivideAssign => DivideAssign {
355                lhs: byte_a,
356                rhs: get_u8!(),
357            },
358            Op::RemainderAssign => RemainderAssign {
359                lhs: byte_a,
360                rhs: get_u8!(),
361            },
362            Op::PowerAssign => PowerAssign {
363                lhs: byte_a,
364                rhs: get_u8!(),
365            },
366            Op::Less => {
367                let [lhs, rhs] = get_u8x2!();
368                Less {
369                    register: byte_a,
370                    lhs,
371                    rhs,
372                }
373            }
374            Op::LessOrEqual => {
375                let [lhs, rhs] = get_u8x2!();
376                LessOrEqual {
377                    register: byte_a,
378                    lhs,
379                    rhs,
380                }
381            }
382            Op::Greater => {
383                let [lhs, rhs] = get_u8x2!();
384                Greater {
385                    register: byte_a,
386                    lhs,
387                    rhs,
388                }
389            }
390            Op::GreaterOrEqual => {
391                let [lhs, rhs] = get_u8x2!();
392                GreaterOrEqual {
393                    register: byte_a,
394                    lhs,
395                    rhs,
396                }
397            }
398            Op::Equal => {
399                let [lhs, rhs] = get_u8x2!();
400                Equal {
401                    register: byte_a,
402                    lhs,
403                    rhs,
404                }
405            }
406            Op::NotEqual => {
407                let [byte_b, byte_c] = get_u8x2!();
408                NotEqual {
409                    register: byte_a,
410                    lhs: byte_b,
411                    rhs: byte_c,
412                }
413            }
414            Op::Jump => Jump {
415                offset: u16::from_le_bytes([byte_a, get_u8!()]),
416            },
417            Op::JumpBack => JumpBack {
418                offset: u16::from_le_bytes([byte_a, get_u8!()]),
419            },
420            Op::JumpIfTrue => JumpIfTrue {
421                register: byte_a,
422                offset: get_u16!(),
423            },
424            Op::JumpIfFalse => JumpIfFalse {
425                register: byte_a,
426                offset: get_u16!(),
427            },
428            Op::JumpIfNull => JumpIfNull {
429                register: byte_a,
430                offset: get_u16!(),
431            },
432            Op::Call => {
433                let [function, frame_base, arg_count, unpacked_arg_count] = get_u8x4!();
434                Call {
435                    result: byte_a,
436                    function,
437                    frame_base,
438                    arg_count,
439                    packed_arg_count: unpacked_arg_count,
440                }
441            }
442            Op::CallInstance => {
443                let [
444                    function,
445                    instance,
446                    frame_base,
447                    arg_count,
448                    unpacked_arg_count,
449                ] = get_u8x5!();
450                CallInstance {
451                    result: byte_a,
452                    function,
453                    instance,
454                    frame_base,
455                    arg_count,
456                    packed_arg_count: unpacked_arg_count,
457                }
458            }
459            Op::Return => Return { register: byte_a },
460            Op::Yield => Yield { register: byte_a },
461            Op::Throw => Throw { register: byte_a },
462            Op::Size => Size {
463                register: byte_a,
464                value: get_u8!(),
465            },
466            Op::IterNext => {
467                let [byte_b, byte_c, byte_d] = get_u8x3!();
468                IterNext {
469                    result: Some(byte_a),
470                    iterator: byte_b,
471                    jump_offset: u16::from_le_bytes([byte_c, byte_d]),
472                    temporary_output: false,
473                }
474            }
475            Op::IterNextTemp => {
476                let [byte_b, byte_c, byte_d] = get_u8x3!();
477                IterNext {
478                    result: Some(byte_a),
479                    iterator: byte_b,
480                    jump_offset: u16::from_le_bytes([byte_c, byte_d]),
481                    temporary_output: true,
482                }
483            }
484            Op::IterNextQuiet => IterNext {
485                result: None,
486                iterator: byte_a,
487                jump_offset: get_u16!(),
488                temporary_output: false,
489            },
490            Op::IterUnpack => IterNext {
491                result: Some(byte_a),
492                iterator: get_u8!(),
493                jump_offset: 0,
494                temporary_output: false,
495            },
496            Op::TempIndex => {
497                let [byte_b, byte_c] = get_u8x2!();
498                TempIndex {
499                    register: byte_a,
500                    value: byte_b,
501                    index: byte_c as i8,
502                }
503            }
504            Op::SliceFrom => {
505                let [byte_b, byte_c] = get_u8x2!();
506                SliceFrom {
507                    register: byte_a,
508                    value: byte_b,
509                    index: byte_c as i8,
510                }
511            }
512            Op::SliceTo => {
513                let [byte_b, byte_c] = get_u8x2!();
514                SliceTo {
515                    register: byte_a,
516                    value: byte_b,
517                    index: byte_c as i8,
518                }
519            }
520            Op::Index => {
521                let [value, index] = get_u8x2!();
522                Index {
523                    register: byte_a,
524                    value,
525                    index,
526                }
527            }
528            Op::IndexMut => {
529                let [index, value] = get_u8x2!();
530                IndexMut {
531                    register: byte_a,
532                    index,
533                    value,
534                }
535            }
536            Op::MapInsert => {
537                let [key, value] = get_u8x2!();
538                MapInsert {
539                    register: byte_a,
540                    key,
541                    value,
542                }
543            }
544            Op::MetaInsert => {
545                let register = byte_a;
546                let [meta_id, value] = get_u8x2!();
547                if let Ok(id) = meta_id.try_into() {
548                    {
549                        MetaInsert {
550                            register,
551                            value,
552                            id,
553                        }
554                    }
555                } else {
556                    Error {
557                        message: format!(
558                            "Unexpected meta id {meta_id} found at instruction {op_ip}",
559                        ),
560                    }
561                }
562            }
563            Op::MetaInsertNamed => {
564                let register = byte_a;
565                let [meta_id, name, value] = get_u8x3!();
566                if let Ok(id) = meta_id.try_into() {
567                    MetaInsertNamed {
568                        register,
569                        value,
570                        id,
571                        name,
572                    }
573                } else {
574                    Error {
575                        message: format!(
576                            "Unexpected meta id {meta_id} found at instruction {op_ip}",
577                        ),
578                    }
579                }
580            }
581            Op::MetaExport => {
582                let meta_id = byte_a;
583                let value = get_u8!();
584                if let Ok(id) = meta_id.try_into() {
585                    MetaExport { id, value }
586                } else {
587                    Error {
588                        message: format!(
589                            "Unexpected meta id {meta_id} found at instruction {op_ip}",
590                        ),
591                    }
592                }
593            }
594            Op::MetaExportNamed => {
595                let meta_id = byte_a;
596                let [name, value] = get_u8x2!();
597                if let Ok(id) = meta_id.try_into() {
598                    MetaExportNamed { id, value, name }
599                } else {
600                    Error {
601                        message: format!(
602                            "Unexpected meta id {meta_id} found at instruction {op_ip}",
603                        ),
604                    }
605                }
606            }
607            Op::Access => {
608                let [value, key_a] = get_u8x2!();
609                Access {
610                    register: byte_a,
611                    value,
612                    key: get_var_u32_with_first_byte!(key_a).into(),
613                }
614            }
615            Op::AccessString => {
616                let [byte_b, byte_c] = get_u8x2!();
617                AccessString {
618                    register: byte_a,
619                    value: byte_b,
620                    key: byte_c,
621                }
622            }
623            Op::TryStart => TryStart {
624                arg_register: byte_a,
625                catch_offset: get_u16!(),
626            },
627            Op::TryEnd => TryEnd,
628            Op::Debug => Debug {
629                register: byte_a,
630                constant: get_var_u32!().into(),
631            },
632            Op::CheckSizeEqual => CheckSizeEqual {
633                register: byte_a,
634                size: get_u8!() as usize,
635            },
636            Op::CheckSizeMin => CheckSizeMin {
637                register: byte_a,
638                size: get_u8!() as usize,
639            },
640            Op::AssertType => AssertType {
641                value: byte_a,
642                allow_null: false,
643                type_string: get_var_u32!().into(),
644            },
645            Op::CheckType => CheckType {
646                value: byte_a,
647                allow_null: false,
648                type_string: get_var_u32!().into(),
649                jump_offset: get_u16!(),
650            },
651            Op::AssertOptionalType => AssertType {
652                value: byte_a,
653                allow_null: true,
654                type_string: get_var_u32!().into(),
655            },
656            Op::CheckOptionalType => CheckType {
657                value: byte_a,
658                allow_null: true,
659                type_string: get_var_u32!().into(),
660                jump_offset: get_u16!(),
661            },
662            Op::StringStart => StringStart {
663                size_hint: get_var_u32_with_first_byte!(byte_a),
664            },
665            Op::StringPush => {
666                let value = byte_a;
667                let flags = get_u8!();
668
669                if flags != 0 {
670                    match StringFormatFlags::try_from(flags) {
671                        Ok(flags) => {
672                            let mut options = StringFormatOptions {
673                                alignment: flags.alignment(),
674                                ..Default::default()
675                            };
676                            if flags.has_min_width() {
677                                options.min_width = Some(get_var_u32!());
678                            }
679                            if flags.has_precision() {
680                                options.precision = Some(get_var_u32!());
681                            }
682                            if flags.has_fill_character() {
683                                options.fill_character = Some(get_var_u32!().into());
684                            }
685                            if flags.has_representation() {
686                                match StringFormatRepresentation::try_from(get_u8!()) {
687                                    Ok(representation) => {
688                                        options.representation = Some(representation);
689                                    }
690                                    Err(e) => return Some(Error { message: e }),
691                                }
692                            }
693                            StringPush {
694                                value,
695                                format_options: Some(options),
696                            }
697                        }
698                        Err(e) => Error { message: e },
699                    }
700                } else {
701                    StringPush {
702                        value,
703                        format_options: None,
704                    }
705                }
706            }
707            Op::StringFinish => StringFinish { register: byte_a },
708            _ => Error {
709                message: format!("Unexpected opcode {op:?} found at instruction {op_ip}"),
710            },
711        };
712
713        Some(instruction)
714    }
715}
716
717#[inline(never)]
718fn out_of_bounds_access_error(ip: usize) -> Option<Instruction> {
719    Some(Instruction::Error {
720        message: format!("Instruction access out of bounds at {ip}"),
721    })
722}