1use std::collections::BTreeMap;
8
9use petr_ir::{DataLabel, DataSectionEntry, Intrinsic, IrOpcode, Reg, ReservedRegister};
10use petr_utils::{idx_map_key, IndexMap};
11use thiserror::Error;
12
#[cfg(test)]
mod tests {

    use expect_test::{expect, Expect};
    use petr_ir::Lowerer;
    use petr_resolve::resolve_symbols;
    use petr_typecheck::type_check;
    use petr_utils::render_error;

    use super::*;
    /// End-to-end harness: compiles `input` together with the stdlib sources, runs the
    /// resulting IR on a fresh [`Vm`], and compares the `Debug` form of the returned value
    /// (followed by a `___LOGS___` line and the `Puts` output, if any) against `expect`.
    /// Panics at the first stage (parse, resolve, type check, lower, run) that fails.
    fn check(
        input: impl Into<String>,
        expect: Expect,
    ) {
        let input = input.into();
        let mut sources = stdlib::stdlib();
        sources.push(("test", &input));
        let parser = petr_parse::Parser::new(sources);
        let (ast, errs, interner, source_map) = parser.into_result();
        if !errs.is_empty() {
            errs.into_iter().for_each(|err| eprintln!("{:?}", render_error(&source_map, err)));
            panic!("build failed: code didn't parse");
        }
        let (errs, resolved) = resolve_symbols(ast, interner, Default::default());
        if !errs.is_empty() {
            dbg!(&errs);
            panic!("build failed: resolution");
        }
        let (type_errs, type_checker) = type_check(resolved);

        if !type_errs.is_empty() {
            type_errs.iter().for_each(|err| eprintln!("{:?}", render_error(&source_map, err.clone())));
            panic!("build failed: code didn't type check");
        }

        let lowerer = match Lowerer::new(type_checker) {
            Ok(l) => l,
            Err(err) => panic!("lowering failed: {err:?}"),
        };
        let (data, ir) = lowerer.finalize();
        let vm = Vm::new(ir, data);
        let (res, _stack, logs) = match vm.run() {
            Ok(o) => o,
            Err(err) => panic!("vm returned error: {err:?}"),
        };

        let mut res = format!("{res:?}");

        // Program output (if any) is appended after a fixed separator so a single snapshot
        // captures both the result value and the printed lines.
        if !logs.is_empty() {
            res.push_str("\n___LOGS___\n");

            res.push_str(&logs.join("\n"));
        }

        expect.assert_eq(&res);
    }

    // `let` bindings: only `a` (bound to the first argument) is returned.
    #[test]
    fn let_bindings() {
        check(
            r#"
fn hi(x in 'int, y in 'int) returns 'int
    let a = x;
    b = y;
    c = 20;
    d = 30;
    e = 12;
    a
fn main() returns 'int ~hi(42, 3)
"#,
            expect!["Value(42)"],
        )
    }
    // Calling an imported stdlib function that prints; the printed line shows up in the logs.
    #[test]
    fn import_call() {
        check(
            r#"
import std.io.print

fn main() returns 'unit
    ~print("hello, world!")
    "#,
            expect![[r#"
                Value(0)
                ___LOGS___
                hello, world! "#]],
        )
    }

    // Prefix-notation addition chain: 1 + 3 + 20 + 30 + 42 = 96.
    #[test]
    fn addition() {
        check(
            r#"
    fn hi(x in 'int, y in 'int) returns 'int
    let a = x;
    b = y;
    c = 20;
    d = 30;
    e = 42;
    + a + b + c + d e

fn main() returns 'int ~hi(1, 3)
"#,
            expect!("Value(96)"),
        )
    }

    // Same sum as `addition`, but the outermost add goes through the `std.ops.add` path.
    #[test]
    fn addition_path_res() {
        check(
            r#"
    fn hi(x in 'int, y in 'int) returns 'int
    let a = x;
    b = y;
    c = 20;
    d = 30;
    e = 42;
    ~std.ops.add(a, + b + c + d e)

fn main() returns 'int ~hi(1, 3)
"#,
            expect!("Value(96)"),
        )
    }

    // Mixed add/subtract bindings; only `d = 20 + 100` is returned.
    #[test]
    fn subtraction() {
        check(
            r#"
    fn hi(x in 'int) returns 'int
    let a = + x 1;
    b = - x 1;
    c = - 20 x;
    d = + 20 x
    d

fn main() returns 'int ~hi(100)
"#,
            expect!("Value(120)"),
        )
    }

    // The VM subtracts with `wrapping_sub` on u64, so 0 - 1 wraps to u64::MAX.
    #[test]
    fn overflowing_sub() {
        check(
            r#"
fn main() returns 'int - 0 1
"#,
            expect!("Value(18446744073709551615)"),
        )
    }

    // Successive allocations are handed out at increasing offsets in VM memory.
    #[test]
    fn basic_malloc() {
        check(
            r#"
fn main() returns 'int
    let a = @malloc 1
    let b = @malloc 1
    let c = @malloc 5
    let d = @malloc 1
    d
"#,
            expect!("Value(7)"),
        )
    }

    // Pointer returned from `std.mem.malloc`, plus a print side effect captured in the logs.
    // NOTE(review): the exact pointer value depends on how much memory lowering/stdlib use first.
    #[test]
    fn ptr_mem() {
        check(
            r#"
fn main() returns 'Ptr
    let pointer = ~std.mem.malloc(20);
    let pointer2 = ~std.mem.malloc(20);
    side_effect = ~std.io.print "Hello, World!"

    pointer2
    "#,
            expect!([r#"
                Value(56)
                ___LOGS___
                Hello, World! "#]),
        )
    }
}
198
199pub struct Vm {
200 state: VmState,
201 instructions: IndexMap<ProgramOffset, IrOpcode>,
202 stdout: Vec<String>,
204}
205
// Typed index keys for `IndexMap`. `ProgramOffset` addresses `Vm::instructions` and is the
// program counter / return-address type.
// NOTE(review): `Register` is not referenced in this part of the file (registers are keyed by
// `petr_ir::Reg`); confirm whether anything else still uses it before removing.
idx_map_key!(Register);
idx_map_key!(ProgramOffset);
208
209#[derive(Default)]
210pub struct VmState {
211 stack: Vec<Value>,
212 static_data: IndexMap<DataLabel, DataSectionEntry>,
213 registers: BTreeMap<Reg, Value>,
214 program_counter: ProgramOffset,
215 memory: Vec<u64>,
216 call_stack: Vec<ProgramOffset>,
217}
218
219impl Default for ProgramOffset {
220 fn default() -> Self {
221 0.into()
222 }
223}
224
/// One 64-bit machine word as seen by the VM: register contents, stack slots, and results.
#[derive(Debug, Copy, Clone)]
pub struct Value(u64);

impl Value {
    /// Returns the raw 64-bit payload of this word.
    pub fn inner(&self) -> u64 {
        let Self(raw) = *self;
        raw
    }
}
233
234#[derive(Debug, Error)]
235pub enum VmError {
236 #[error("Function label not found when executing opcode {0}")]
237 FunctionLabelNotFound(IrOpcode),
238 #[error("Popped empty stack when executing opcode {0}")]
239 PoppedEmptyStack(IrOpcode),
240 #[error("Register {0} not found")]
241 RegisterNotFound(Reg),
242 #[error("PC value of {0} is out of bounds for program of length {1}")]
243 ProgramCounterOutOfBounds(ProgramOffset, u64),
244 #[error("Returned to an empty call stack when executing opcode {0}")]
245 PoppedEmptyCallStack(IrOpcode),
246 #[error("Attempted to write to memory at index {0} but memory only has length {1}")]
247 OutOfBoundsMemoryWrite(usize, usize),
248}
249
250type Result<T> = std::result::Result<T, VmError>;
251
252enum VmControlFlow {
253 Continue,
254 Terminate(Value),
255}
256
257pub type VmLogs = Vec<String>;
258
259impl Vm {
260 pub fn new(
261 instructions: Vec<IrOpcode>,
262 static_data: IndexMap<DataLabel, DataSectionEntry>,
263 ) -> Self {
264 let mut idx_map = IndexMap::default();
265 for instr in instructions {
266 idx_map.insert(instr);
267 }
268 Self {
269 state: VmState {
270 stack: Default::default(),
271 static_data,
272 registers: Default::default(),
273 program_counter: 0.into(),
274 memory: Vec::with_capacity(100),
275 call_stack: Default::default(),
276 },
277 instructions: idx_map,
278 stdout: vec![],
279 }
280 }
281
282 pub fn run(mut self) -> Result<(Value, Vec<Value>, VmLogs)> {
283 use VmControlFlow::*;
284 let val = loop {
285 match self.execute() {
286 Ok(Continue) => continue,
287 Ok(Terminate(val)) => break val,
288 Err(e) => return Err(e),
289 }
290 };
291 Ok((val, self.state.stack, self.stdout))
292 }
293
294 fn execute(&mut self) -> Result<VmControlFlow> {
295 use VmControlFlow::*;
296 if self.state.program_counter.0 >= self.instructions.len() {
297 return Err(VmError::ProgramCounterOutOfBounds(
298 self.state.program_counter,
299 self.instructions.len() as u64,
300 ));
301 }
302 let opcode = self.instructions.get(self.state.program_counter).clone();
303 self.state.program_counter = (self.state.program_counter.0 + 1).into();
304 match opcode {
305 IrOpcode::JumpImmediate(label) => {
306 let Some(offset) = self
307 .instructions
308 .iter()
309 .find_map(|(position, op)| if *op == IrOpcode::FunctionLabel(label) { Some(position) } else { None })
310 else {
311 return Err(VmError::FunctionLabelNotFound(opcode));
312 };
313 self.state.program_counter = offset;
314 Ok(Continue)
315 },
316 IrOpcode::Add(dest, lhs, rhs) => {
317 let lhs = self.get_register(lhs)?;
318 let rhs = self.get_register(rhs)?;
319 self.set_register(dest, Value(lhs.0.wrapping_add(rhs.0)));
320 Ok(Continue)
321 },
322 IrOpcode::Multiply(dest, lhs, rhs) => {
323 let lhs = self.get_register(lhs)?;
324 let rhs = self.get_register(rhs)?;
325 self.set_register(dest, Value(lhs.0.wrapping_mul(rhs.0)));
326 Ok(Continue)
327 },
328 IrOpcode::Subtract(dest, lhs, rhs) => {
329 let lhs = self.get_register(lhs)?;
330 let rhs = self.get_register(rhs)?;
331 self.set_register(dest, Value(lhs.0.wrapping_sub(rhs.0)));
332 Ok(Continue)
333 },
334 IrOpcode::Divide(dest, lhs, rhs) => {
335 let lhs = self.get_register(lhs)?;
336 let rhs = self.get_register(rhs)?;
337 self.set_register(dest, Value(lhs.0 / rhs.0));
338 Ok(Continue)
339 },
340 IrOpcode::LoadData(dest, data_label) => {
341 let data = self.state.static_data.get(data_label).clone();
342 let data = self.data_section_to_val(&data);
343 self.set_register(dest, data);
344 Ok(Continue)
345 },
346 IrOpcode::StackPop(ref dest) => {
347 let Some(data) = self.state.stack.pop() else {
348 return Err(VmError::PoppedEmptyStack(opcode));
349 };
350 self.set_register(dest.reg, data);
351 Ok(Continue)
352 },
353 IrOpcode::StackPush(val) => {
354 let data = self.get_register(val.reg)?;
355 self.state.stack.push(data);
356 Ok(Continue)
357 },
358 IrOpcode::Intrinsic(intrinsic) => {
359 match intrinsic {
360 Intrinsic::Puts(reg) => {
361 let ptr = self.get_register(reg)?.0;
362 let ptr = ptr as usize;
363
364 let len = self.state.memory[ptr];
365 let len = len as usize;
366
367 let str = &self.state.memory[ptr + 1..ptr + 1 + len];
370 let str = str.iter().flat_map(|num| num.to_ne_bytes()).collect::<Vec<u8>>();
371 let string: String = str.iter().map(|&c| c as char).collect();
373 self.stdout.push(string.clone());
374 },
375 };
376 Ok(Continue)
377 },
378 IrOpcode::FunctionLabel(_) => Ok(Continue),
379 IrOpcode::LoadImmediate(dest, imm) => {
380 self.set_register(dest, Value(imm));
381 Ok(Continue)
382 },
383 IrOpcode::Copy(dest, src) => {
384 let val = self.get_register(src)?;
385 self.set_register(dest, val);
386 Ok(Continue)
387 },
388 IrOpcode::Jump(_) => todo!(),
389 IrOpcode::Label(_) => todo!(),
390 IrOpcode::Return() => {
391 let val = self.get_register(Reg::Reserved(ReservedRegister::ReturnValueRegister))?;
392 let Some(offset) = self.state.call_stack.pop() else {
394 return Ok(Terminate(val));
395 };
396 self.state.program_counter = offset;
397 Ok(Continue)
398 },
399 IrOpcode::PushPc() => {
400 self.state.call_stack.push((self.state.program_counter.0 + 1).into());
401 Ok(Continue)
402 },
403 IrOpcode::StackPushImmediate(imm) => {
404 self.state.stack.push(Value(imm));
405 Ok(Continue)
406 },
407 IrOpcode::ReturnImmediate(imm) => {
408 let Some(offset) = self.state.call_stack.pop() else {
409 return Ok(Terminate(Value(imm)));
410 };
411 self.state.program_counter = offset;
412 Ok(Continue)
413 },
414 IrOpcode::Malloc(ptr_dest, size) => {
415 let size = self.get_register(size)?;
416 let ptr = self.state.memory.len();
417 self.state.memory.resize(ptr + size.0 as usize, 0);
418 self.set_register(ptr_dest, Value(ptr as u64));
419 Ok(Continue)
420 },
421 IrOpcode::MallocImmediate(ptr_dest, size) => {
422 let ptr = self.state.memory.len();
423 self.state.memory.resize(ptr + size.num_bytes(), 0);
424 self.set_register(ptr_dest, Value(ptr as u64));
425 Ok(Continue)
426 },
427 IrOpcode::WriteRegisterToMemory(reg, dest_ptr) => {
428 let dest_ptr = self.get_register(dest_ptr)?.0 as usize;
429 let val = self.get_register(reg)?.0;
430
431 if self.state.memory.len() <= dest_ptr {
432 return Err(VmError::OutOfBoundsMemoryWrite(dest_ptr, self.state.memory.len()));
433 };
434 self.state.memory[dest_ptr] = val;
435 Ok(Continue)
436 },
437 IrOpcode::Comment(_) => Ok(Continue),
438 }
439 }
440
441 fn get_register(
442 &self,
443 reg: petr_ir::Reg,
444 ) -> Result<Value> {
445 self.state.registers.get(®).copied().ok_or(VmError::RegisterNotFound(reg))
446 }
447
448 fn set_register(
449 &mut self,
450 dest: petr_ir::Reg,
451 val: Value,
452 ) {
453 self.state.registers.insert(dest, val);
454 }
455
456 fn data_section_to_val(
458 &mut self,
459 data: &DataSectionEntry,
460 ) -> Value {
461 match data {
462 DataSectionEntry::Int64(x) => Value(*x as u64),
463 DataSectionEntry::String(val) => {
464 let str_as_bytes = val.as_bytes();
465 let bytes_compressed_as_u64s = str_as_bytes
466 .chunks(8)
467 .map(|chunk| {
468 let mut bytes = [0u8; 8];
469 let len = chunk.len();
471 let chunk = if len < 8 {
472 let mut padded = [0u8; 8];
473 padded[..len].copy_from_slice(chunk);
474 padded.to_vec()
475 } else {
476 chunk.to_vec()
477 };
478 bytes.copy_from_slice(&chunk[..]);
479 u64::from_ne_bytes(bytes)
480 })
481 .collect::<Vec<_>>();
482 let ptr = self.state.memory.len();
483 self.state.memory.push(bytes_compressed_as_u64s.len() as u64);
485 self.state.memory.extend_from_slice(&bytes_compressed_as_u64s);
486 Value(ptr as u64)
487 },
488 DataSectionEntry::Bool(x) => Value(if *x { 1 } else { 0 }),
489 }
490 }
491}