petr_vm/
lib.rs

1//! Basic VM/interpreter for petr-ir. Primarily intended for testing the correctness of codegen and maybe some other features down the line,
2//! like a debugger or repl.
3
4// TODO should use fallible index maps since invalid IR can result in labels pointing to things that don't exist. don't want to
5// panic in those cases
6
7use std::collections::BTreeMap;
8
9use petr_ir::{DataLabel, DataSectionEntry, Intrinsic, IrOpcode, Reg, ReservedRegister};
10use petr_utils::{idx_map_key, IndexMap};
11use thiserror::Error;
12
#[cfg(test)]
mod tests {

    use expect_test::{expect, Expect};
    use petr_ir::Lowerer;
    use petr_resolve::resolve_symbols;
    use petr_typecheck::type_check;
    use petr_utils::render_error;

    use super::*;
    /// End-to-end harness: compiles `input` together with the stdlib sources (parse, resolve,
    /// type check, lower to IR), runs the IR on a fresh [`Vm`], and asserts that the `Debug`
    /// form of the returned value matches `expect`. When the program logged anything, the
    /// compared text is the value followed by a `___LOGS___` line and the logs joined by
    /// newlines.
    ///
    /// Panics (failing the test) if any compilation stage reports errors or the VM returns an
    /// error.
    fn check(
        input: impl Into<String>,
        expect: Expect,
    ) {
        let input = input.into();
        let mut sources = stdlib::stdlib();
        sources.push(("test", &input));
        let parser = petr_parse::Parser::new(sources);
        let (ast, errs, interner, source_map) = parser.into_result();
        if !errs.is_empty() {
            errs.into_iter().for_each(|err| eprintln!("{:?}", render_error(&source_map, err)));
            panic!("build failed: code didn't parse");
        }
        let (errs, resolved) = resolve_symbols(ast, interner, Default::default());
        if !errs.is_empty() {
            dbg!(&errs);
            panic!("build failed: resolution");
        }
        let (type_errs, type_checker) = type_check(resolved);

        if !type_errs.is_empty() {
            type_errs.iter().for_each(|err| eprintln!("{:?}", render_error(&source_map, err.clone())));
            panic!("build failed: code didn't type check");
        }

        let lowerer = match Lowerer::new(type_checker) {
            Ok(l) => l,
            Err(err) => panic!("lowering failed: {err:?}"),
        };
        let (data, ir) = lowerer.finalize();
        let vm = Vm::new(ir, data);
        // The leftover operand stack is not part of the snapshot.
        let (res, _stack, logs) = match vm.run() {
            Ok(o) => o,
            Err(err) => panic!("vm returned error: {err:?}"),
        };

        let mut res = format!("{res:?}");

        if !logs.is_empty() {
            res.push_str("\n___LOGS___\n");

            res.push_str(&logs.join("\n"));
        }

        expect.assert_eq(&res);
    }

    // Only `a` (bound to the first argument, 42) is returned; the other bindings are unused.
    #[test]
    fn let_bindings() {
        check(
            r#"
fn hi(x in 'int, y in 'int) returns 'int
    let a = x;
        b = y;
        c = 20;
        d = 30;
        e = 12;
    a
fn main() returns 'int ~hi(42, 3)
"#,
            expect!["Value(42)"],
        )
    }
    // Calling an imported stdlib function; the printed string shows up in the VM logs.
    #[test]
    fn import_call() {
        check(
            r#"
import std.io.print

fn main() returns 'unit 
  ~print("hello, world!")
  "#,
            expect![[r#"
                Value(0)
                ___LOGS___
                hello, world!"#]],
        )
    }

    // Prefix `+` chain over all bindings: 1 + 3 + 20 + 30 + 42 = 96.
    #[test]
    fn addition() {
        check(
            r#"
            fn hi(x in 'int, y in 'int) returns 'int
    let a = x;
        b = y;
        c = 20;
        d = 30;
        e = 42;
    + a + b + c + d e

fn main() returns 'int ~hi(1, 3)
"#,
            expect!("Value(96)"),
        )
    }

    // Same sum as `addition`, with the outermost addition written as a path call to
    // `std.ops.add`.
    #[test]
    fn addition_path_res() {
        check(
            r#"
            fn hi(x in 'int, y in 'int) returns 'int
    let a = x;
        b = y;
        c = 20;
        d = 30;
        e = 42;
    ~std.ops.add(a,  + b + c + d e)

fn main() returns 'int ~hi(1, 3)
"#,
            expect!("Value(96)"),
        )
    }

    // Only `d = 20 + x` is returned: 20 + 100 = 120.
    #[test]
    fn subtraction() {
        check(
            r#"
            fn hi(x in 'int) returns 'int
    let a = + x 1;
        b = - x 1;
        c = - 20 x;
        d = + 20 x
        d

fn main() returns 'int ~hi(100)
"#,
            expect!("Value(120)"),
        )
    }

    // `0 - 1` wraps around to u64::MAX instead of trapping.
    #[test]
    fn overflowing_sub() {
        check(
            r#"
fn main() returns 'int - 0 1
"#,
            expect!("Value(18446744073709551615)"),
        )
    }

    // NOTE(review): the expected pointer 7 is consistent with bump allocation in words
    // (0, 1, 2, then 7 after the size-5 block) — confirm against how `@malloc` is lowered.
    #[test]
    fn basic_malloc() {
        check(
            r#"
fn main() returns 'int
    let a = @malloc 1
    let b = @malloc 1
    let c = @malloc 5
    let d = @malloc 1
    d
"#,
            expect!("Value(7)"),
        )
    }

    // Returns the second allocation's pointer. Its exact value depends on everything placed
    // in VM memory before it, so this snapshot is sensitive to allocator/lowering changes.
    #[test]
    fn ptr_mem() {
        check(
            r#"
fn main() returns 'Ptr
    let pointer = ~std.mem.malloc(20);
    let pointer2 = ~std.mem.malloc(20);
    side_effect = ~std.io.print "Hello, World!"

    pointer2
      "#,
            expect!([r#"
                Value(56)
                ___LOGS___
                Hello, World!"#]),
        )
    }
}
198
199pub struct Vm {
200    state:        VmState,
201    instructions: IndexMap<ProgramOffset, IrOpcode>,
202    /// any messages that were logged during execution
203    stdout:       Vec<String>,
204}
205
// Newtype keys for `IndexMap`. `ProgramOffset` addresses `Vm::instructions` and the call stack.
// NOTE(review): `Register` is not referenced anywhere else in this file (registers are keyed
// by `petr_ir::Reg`) — confirm whether it is still needed.
idx_map_key!(Register);
idx_map_key!(ProgramOffset);
208
209#[derive(Default)]
210pub struct VmState {
211    stack:           Vec<Value>,
212    static_data:     IndexMap<DataLabel, DataSectionEntry>,
213    registers:       BTreeMap<Reg, Value>,
214    program_counter: ProgramOffset,
215    memory:          Vec<u64>,
216    call_stack:      Vec<ProgramOffset>,
217}
218
219impl Default for ProgramOffset {
220    fn default() -> Self {
221        0.into()
222    }
223}
224
/// One 64-bit machine word as handled by the VM (register contents, stack slots, results).
#[derive(Clone, Copy, Debug)]
pub struct Value(u64);

impl Value {
    /// Returns the raw word carried by this value.
    pub fn inner(&self) -> u64 {
        let Value(raw) = *self;
        raw
    }
}
233
234#[derive(Debug, Error)]
235pub enum VmError {
236    #[error("Function label not found when executing opcode {0}")]
237    FunctionLabelNotFound(IrOpcode),
238    #[error("Popped empty stack when executing opcode {0}")]
239    PoppedEmptyStack(IrOpcode),
240    #[error("Register {0} not found")]
241    RegisterNotFound(Reg),
242    #[error("PC value of {0} is out of bounds for program of length {1}")]
243    ProgramCounterOutOfBounds(ProgramOffset, u64),
244    #[error("Returned to an empty call stack when executing opcode {0}")]
245    PoppedEmptyCallStack(IrOpcode),
246    #[error("Attempted to write to memory at index {0} but memory only has length {1}")]
247    OutOfBoundsMemoryWrite(usize, usize),
248}
249
250type Result<T> = std::result::Result<T, VmError>;
251
252enum VmControlFlow {
253    Continue,
254    Terminate(Value),
255}
256
257pub type VmLogs = Vec<String>;
258
259impl Vm {
260    pub fn new(
261        instructions: Vec<IrOpcode>,
262        static_data: IndexMap<DataLabel, DataSectionEntry>,
263    ) -> Self {
264        let mut idx_map = IndexMap::default();
265        for instr in instructions {
266            idx_map.insert(instr);
267        }
268        Self {
269            state:        VmState {
270                stack: Default::default(),
271                static_data,
272                registers: Default::default(),
273                program_counter: 0.into(),
274                memory: Vec::with_capacity(100),
275                call_stack: Default::default(),
276            },
277            instructions: idx_map,
278            stdout:       vec![],
279        }
280    }
281
282    pub fn run(mut self) -> Result<(Value, Vec<Value>, VmLogs)> {
283        use VmControlFlow::*;
284        let val = loop {
285            match self.execute() {
286                Ok(Continue) => continue,
287                Ok(Terminate(val)) => break val,
288                Err(e) => return Err(e),
289            }
290        };
291        Ok((val, self.state.stack, self.stdout))
292    }
293
294    fn execute(&mut self) -> Result<VmControlFlow> {
295        use VmControlFlow::*;
296        if self.state.program_counter.0 >= self.instructions.len() {
297            return Err(VmError::ProgramCounterOutOfBounds(
298                self.state.program_counter,
299                self.instructions.len() as u64,
300            ));
301        }
302        let opcode = self.instructions.get(self.state.program_counter).clone();
303        self.state.program_counter = (self.state.program_counter.0 + 1).into();
304        match opcode {
305            IrOpcode::JumpImmediate(label) => {
306                let Some(offset) = self
307                    .instructions
308                    .iter()
309                    .find_map(|(position, op)| if *op == IrOpcode::FunctionLabel(label) { Some(position) } else { None })
310                else {
311                    return Err(VmError::FunctionLabelNotFound(opcode));
312                };
313                self.state.program_counter = offset;
314                Ok(Continue)
315            },
316            IrOpcode::Add(dest, lhs, rhs) => {
317                let lhs = self.get_register(lhs)?;
318                let rhs = self.get_register(rhs)?;
319                self.set_register(dest, Value(lhs.0.wrapping_add(rhs.0)));
320                Ok(Continue)
321            },
322            IrOpcode::Multiply(dest, lhs, rhs) => {
323                let lhs = self.get_register(lhs)?;
324                let rhs = self.get_register(rhs)?;
325                self.set_register(dest, Value(lhs.0.wrapping_mul(rhs.0)));
326                Ok(Continue)
327            },
328            IrOpcode::Subtract(dest, lhs, rhs) => {
329                let lhs = self.get_register(lhs)?;
330                let rhs = self.get_register(rhs)?;
331                self.set_register(dest, Value(lhs.0.wrapping_sub(rhs.0)));
332                Ok(Continue)
333            },
334            IrOpcode::Divide(dest, lhs, rhs) => {
335                let lhs = self.get_register(lhs)?;
336                let rhs = self.get_register(rhs)?;
337                self.set_register(dest, Value(lhs.0 / rhs.0));
338                Ok(Continue)
339            },
340            IrOpcode::LoadData(dest, data_label) => {
341                let data = self.state.static_data.get(data_label).clone();
342                let data = self.data_section_to_val(&data);
343                self.set_register(dest, data);
344                Ok(Continue)
345            },
346            IrOpcode::StackPop(ref dest) => {
347                let Some(data) = self.state.stack.pop() else {
348                    return Err(VmError::PoppedEmptyStack(opcode));
349                };
350                self.set_register(dest.reg, data);
351                Ok(Continue)
352            },
353            IrOpcode::StackPush(val) => {
354                let data = self.get_register(val.reg)?;
355                self.state.stack.push(data);
356                Ok(Continue)
357            },
358            IrOpcode::Intrinsic(intrinsic) => {
359                match intrinsic {
360                    Intrinsic::Puts(reg) => {
361                        let ptr = self.get_register(reg)?.0;
362                        let ptr = ptr as usize;
363
364                        let len = self.state.memory[ptr];
365                        let len = len as usize;
366
367                        // strings are padded to the nearest u64 boundary right now and that's working
368                        // ...should be fine?
369                        let str = &self.state.memory[ptr + 1..ptr + 1 + len];
370                        let str = str.iter().flat_map(|num| num.to_ne_bytes()).collect::<Vec<u8>>();
371                        // convert vec of usizes to string
372                        let string: String = str.iter().map(|&c| c as char).collect();
373                        self.stdout.push(string.clone());
374                    },
375                };
376                Ok(Continue)
377            },
378            IrOpcode::FunctionLabel(_) => Ok(Continue),
379            IrOpcode::LoadImmediate(dest, imm) => {
380                self.set_register(dest, Value(imm));
381                Ok(Continue)
382            },
383            IrOpcode::Copy(dest, src) => {
384                let val = self.get_register(src)?;
385                self.set_register(dest, val);
386                Ok(Continue)
387            },
388            IrOpcode::Jump(_) => todo!(),
389            IrOpcode::Label(_) => todo!(),
390            IrOpcode::Return() => {
391                let val = self.get_register(Reg::Reserved(ReservedRegister::ReturnValueRegister))?;
392                // pop the fn stack, return there
393                let Some(offset) = self.state.call_stack.pop() else {
394                    return Ok(Terminate(val));
395                };
396                self.state.program_counter = offset;
397                Ok(Continue)
398            },
399            IrOpcode::PushPc() => {
400                self.state.call_stack.push((self.state.program_counter.0 + 1).into());
401                Ok(Continue)
402            },
403            IrOpcode::StackPushImmediate(imm) => {
404                self.state.stack.push(Value(imm));
405                Ok(Continue)
406            },
407            IrOpcode::ReturnImmediate(imm) => {
408                let Some(offset) = self.state.call_stack.pop() else {
409                    return Ok(Terminate(Value(imm)));
410                };
411                self.state.program_counter = offset;
412                Ok(Continue)
413            },
414            IrOpcode::Malloc(ptr_dest, size) => {
415                let size = self.get_register(size)?;
416                let ptr = self.state.memory.len();
417                self.state.memory.resize(ptr + size.0 as usize, 0);
418                self.set_register(ptr_dest, Value(ptr as u64));
419                Ok(Continue)
420            },
421            IrOpcode::MallocImmediate(ptr_dest, size) => {
422                let ptr = self.state.memory.len();
423                self.state.memory.resize(ptr + size.num_bytes(), 0);
424                self.set_register(ptr_dest, Value(ptr as u64));
425                Ok(Continue)
426            },
427            IrOpcode::WriteRegisterToMemory(reg, dest_ptr) => {
428                let dest_ptr = self.get_register(dest_ptr)?.0 as usize;
429                let val = self.get_register(reg)?.0;
430
431                if self.state.memory.len() <= dest_ptr {
432                    return Err(VmError::OutOfBoundsMemoryWrite(dest_ptr, self.state.memory.len()));
433                };
434                self.state.memory[dest_ptr] = val;
435                Ok(Continue)
436            },
437            IrOpcode::Comment(_) => Ok(Continue),
438        }
439    }
440
441    fn get_register(
442        &self,
443        reg: petr_ir::Reg,
444    ) -> Result<Value> {
445        self.state.registers.get(&reg).copied().ok_or(VmError::RegisterNotFound(reg))
446    }
447
448    fn set_register(
449        &mut self,
450        dest: petr_ir::Reg,
451        val: Value,
452    ) {
453        self.state.registers.insert(dest, val);
454    }
455
456    // TODO things larger than a register
457    fn data_section_to_val(
458        &mut self,
459        data: &DataSectionEntry,
460    ) -> Value {
461        match data {
462            DataSectionEntry::Int64(x) => Value(*x as u64),
463            DataSectionEntry::String(val) => {
464                let str_as_bytes = val.as_bytes();
465                let bytes_compressed_as_u64s = str_as_bytes
466                    .chunks(8)
467                    .map(|chunk| {
468                        let mut bytes = [0u8; 8];
469                        // pad the chunk with 0s if it isn't a multiple of 8
470                        let len = chunk.len();
471                        let chunk = if len < 8 {
472                            let mut padded = [0u8; 8];
473                            padded[..len].copy_from_slice(chunk);
474                            padded.to_vec()
475                        } else {
476                            chunk.to_vec()
477                        };
478                        bytes.copy_from_slice(&chunk[..]);
479                        u64::from_ne_bytes(bytes)
480                    })
481                    .collect::<Vec<_>>();
482                let ptr = self.state.memory.len();
483                // first slot of a string is the len, then the content
484                self.state.memory.push(bytes_compressed_as_u64s.len() as u64);
485                self.state.memory.extend_from_slice(&bytes_compressed_as_u64s);
486                Value(ptr as u64)
487            },
488            DataSectionEntry::Bool(x) => Value(if *x { 1 } else { 0 }),
489        }
490    }
491}