use std::collections::HashMap;
use bytes::Bytes;
use indexmap::IndexMap;
use small_len::{Length, SmallLen};
use crate::vm_value::{length_from_bytes, VMValue};
use crate::{Op, VMError};
/// Native function callable from bytecode through the `FN` opcode.
///
/// Functions are looked up by their position in `VM::functions`. They receive
/// the VM and the argument values (which `FN` removes from their registers),
/// and may return a value that `FN` stores in its output register; returning
/// `None` leaves that register untouched.
pub type VMFunction<T> = fn(&mut VM<T>, args: Vec<T>) -> Result<Option<T>, VMError>;
/// Hook run by the `HCF` opcode (when `VM::hcf_trigger` is set) just before
/// execution stops; an error it returns is propagated out of `VM::run`.
pub type HCFFunction<T> = fn(&mut VM<T>) -> Result<(), VMError>;
/// Value bound to a local variable name in a [`Frame`].
#[derive(Clone, Debug, PartialEq)]
pub enum FrameValue<T>
where
    T: Clone + PartialEq,
{
    /// An ordinary value.
    Value(T),
    /// A whole frame stored as a local (reading one through `VM::get_local`
    /// is not implemented yet).
    Frame(Frame<T>),
}
/// One unit of execution: its own bytecode, cursor, and local variables.
#[derive(Clone, Debug, PartialEq)]
pub struct Frame<T>
where
    T: Clone + PartialEq,
{
    /// Bytecode executed while this frame is current.
    pub instructions: Vec<u8>,
    /// Offset of the next byte to read from `instructions`.
    pub pc: usize,
    /// Locals bound in this frame, in insertion order.
    pub locals: IndexMap<String, FrameValue<T>>,
    /// Index (in `VM::frames`) of the frame that local lookups fall back to,
    /// or `None` for a root frame.
    pub parent: Option<usize>,
}
impl<T> Frame<T>
where
    T: Clone + PartialEq,
{
    /// Creates an empty root frame: no instructions, `pc` at 0, no locals and
    /// no parent.
    pub(crate) fn new() -> Self {
        Self {
            instructions: Vec::new(),
            pc: 0,
            locals: IndexMap::new(),
            parent: None,
        }
    }

    /// Creates an empty frame whose local lookups fall back to the frame at
    /// index `parent` in `VM::frames`.
    pub(crate) fn child(parent: usize) -> Self {
        Self {
            parent: Some(parent),
            ..Self::new()
        }
    }
}
/// Bytecode interpreter state.
pub struct VM<T>
where
    T: Clone + VMValue<T> + Default + PartialEq,
{
    /// Index into `frames` of the frame currently executing.
    pub fp: usize,
    /// Every frame known to the VM; `fp`, `stack`, `Frame::parent` and the
    /// index registers all refer to frames by position in this vector.
    pub frames: Vec<Frame<T>>,
    /// Native functions callable through the `FN` opcode, by position.
    pub functions: Vec<VMFunction<T>>,
    /// Value registers.
    pub registers: HashMap<Length, T>,
    /// Registers holding frame indices (written by `CFR`, consumed by `CALLR`).
    pub index_registers: HashMap<Length, Length>,
    /// Saved `fp` values pushed by `CALL`/`CALLR` and popped by `RET`.
    pub stack: Vec<usize>,
    /// NOTE(review): not read anywhere in this file; purpose TODO confirm.
    pub program_data: Bytes,
    /// Optional hook run by the `HCF` opcode before execution stops.
    pub hcf_trigger: Option<HCFFunction<T>>,
}
impl<T: Clone + VMValue<T> + Default + PartialEq> VM<T> {
    /// Creates a VM with no frames, functions, registers, call stack or
    /// program data, and no `HCF` hook.
    ///
    /// At least one frame must be pushed onto `frames` (with `fp` pointing at
    /// it) before calling [`VM::run`]; every frame accessor indexes
    /// `frames[fp]` directly.
    pub fn empty() -> VM<T> {
        VM {
            fp: 0,
            frames: vec![],
            stack: vec![],
            functions: vec![],
            registers: HashMap::new(),
            index_registers: HashMap::new(),
            program_data: Bytes::new(),
            hcf_trigger: None,
        }
    }

    /// Returns the program counter of the current frame (`frames[fp]`).
    ///
    /// # Panics
    /// Panics if `fp` is not a valid index into `frames`.
    pub fn pc(&self) -> usize {
        self.frames[self.fp].pc
    }

    /// Sets the program counter of the current frame and returns the new value.
    ///
    /// # Panics
    /// Panics if `fp` is not a valid index into `frames`.
    pub fn pc_set(&mut self, new: usize) -> usize {
        self.frames[self.fp].pc = new;
        new
    }

    /// Returns the instruction bytes of the current frame.
    ///
    /// # Panics
    /// Panics if `fp` is not a valid index into `frames`.
    pub fn instructions(&self) -> &Vec<u8> {
        &self.frames[self.fp].instructions
    }

    /// Returns the instruction byte at `index` in the current frame.
    ///
    /// # Panics
    /// Panics if `fp` or `index` is out of range.
    pub fn instruction(&self, index: usize) -> u8 {
        self.frames[self.fp].instructions[index]
    }

    /// Looks up `var`, starting in frame `fp` and following `parent` links
    /// until a frame binding it is found, and returns a clone of its value.
    ///
    /// # Errors
    /// `VMError::LocalNotFound` if no frame in the chain binds `var`.
    ///
    /// # Panics
    /// Panics if `fp` or a `parent` index is out of range, or if the binding
    /// is a `FrameValue::Frame` (not implemented yet).
    pub fn get_local(&self, fp: usize, var: String) -> Result<T, VMError> {
        let mut current = fp;
        loop {
            let frame = &self.frames[current];
            match frame.locals.get(&var) {
                Some(FrameValue::Value(v)) => return Ok(v.clone()),
                Some(FrameValue::Frame(_)) => todo!(),
                None => match frame.parent {
                    Some(parent) => current = parent,
                    None => {
                        return Err(VMError::LocalNotFound(format!("{} not found", var)))
                    }
                },
            }
        }
    }

    /// Finds the nearest frame, starting at `fp` and following `parent`
    /// links, that binds `var`.
    ///
    /// Returns `(true, owner)` with the owning frame's index, or `(false, 0)`
    /// when no frame in the chain binds it (the `0` is only a placeholder).
    ///
    /// # Panics
    /// Panics if `fp` or a `parent` index is out of range.
    pub fn local_exists(&self, fp: usize, var: &str) -> (bool, usize) {
        let mut current = fp;
        loop {
            let frame = &self.frames[current];
            if frame.locals.contains_key(var) {
                return (true, current);
            }
            match frame.parent {
                Some(parent) => current = parent,
                None => return (false, 0),
            }
        }
    }

    /// Binds `var` to `value` in frame `fp`.
    ///
    /// # Errors
    /// `VMError::ImmutableLocal` if `var` is already bound in `fp` or any of
    /// its ancestors, so existing bindings can neither be reassigned nor
    /// shadowed through this method.
    pub fn set_local(&mut self, fp: usize, var: String, value: T) -> Result<(), VMError> {
        let (exists, _) = self.local_exists(fp, &var);
        if exists {
            return Err(VMError::ImmutableLocal(format!("{} is immutable", var)));
        }
        self.frames[fp].locals.insert(var, FrameValue::Value(value));
        Ok(())
    }

    /// Assigns `value` to `var` in the nearest frame (from `fp` upward) that
    /// already binds it, or creates the binding in `fp` otherwise.
    ///
    /// This does not check how the existing binding was created, so it will
    /// also overwrite a local that was created by [`VM::set_local`].
    /// Always returns `Ok(())`.
    pub fn set_mutable_local(&mut self, fp: usize, var: String, value: T) -> Result<(), VMError> {
        let target = match self.local_exists(fp, &var) {
            (true, owner) => owner,
            (false, _) => fp,
        };
        self.frames[target].locals.insert(var, FrameValue::Value(value));
        Ok(())
    }

    /// Removes and returns the value in register `index`.
    ///
    /// # Errors
    /// `VMError::RegisterNotFound` if the register is empty.
    pub fn remove_register(&mut self, index: &Length) -> Result<T, VMError> {
        self.registers
            .remove(index)
            .ok_or_else(|| VMError::RegisterNotFound(format!("register {} not found", index)))
    }

    /// Removes and returns the frame index held in index register `index`.
    ///
    /// # Errors
    /// `VMError::RegisterNotFound` if the index register is empty.
    pub fn remove_index_register(&mut self, index: &Length) -> Result<Length, VMError> {
        self.index_registers.remove(index).ok_or_else(|| {
            VMError::RegisterNotFound(format!("index_register {} not found", index))
        })
    }

    /// Executes bytecode starting at the current frame's `pc`.
    ///
    /// Each iteration reads one opcode byte from `frames[fp]`, advances `pc`
    /// past it, and dispatches; operand bytes are consumed by the handler.
    /// Execution stops when the *current* frame's `pc` reaches the end of its
    /// instructions, after an `HCF`, or on the first error. Because the end
    /// check uses whichever frame is current, running off the end of a called
    /// frame stops the whole run rather than returning to the caller.
    ///
    /// # Errors
    /// - `InvalidFunction`: `FN` names an index outside `functions`.
    /// - `RegisterNotFound`: an opcode reads an empty register or index register.
    /// - `StackUnderflow`: `RET` with an empty call stack.
    /// - `InvalidInstruction`: the `IVD` opcode.
    /// - `LocalNotFound` / `ImmutableLocal`: from the local-variable opcodes.
    /// - Any error returned by a native function or by the `HCF` hook.
    ///
    /// # Panics
    /// Panics if `fp` is not a valid frame index (including when `frames` is
    /// empty). It can also panic when an operand read runs past the end of the
    /// current frame's instructions, because the byte readers in this impl
    /// index without bounds checks.
    pub fn run(&mut self) -> Result<(), VMError> {
        while self.pc() < self.instructions().len() {
            let pc = self.pc();
            let raw = self.instruction(pc);
            let command = Op::from(raw);
            // Consume the opcode byte; operands (if any) are read below.
            self.pc_set(pc + 1);
            match command {
                // The opcode byte was already consumed above. Advancing again
                // here would silently skip the next instruction's first byte.
                Op::NOP => {}
                Op::LOAD => {
                    let register = self.next_len();
                    let value = T::from_bytes(self);
                    self.set_register_value(register, value);
                }
                Op::COPY => {
                    let source = self.next_len();
                    let dest = self.next_len();
                    // Report a missing source register as an error, like the
                    // other register-reading opcodes, instead of panicking.
                    let value = self.registers.get(&source).cloned().ok_or_else(|| {
                        VMError::RegisterNotFound(format!("register {} not found", source))
                    })?;
                    self.set_register_value(dest, value);
                }
                Op::FN => {
                    let command = self.next_len().index();
                    if self.functions.len() <= command {
                        return Err(VMError::InvalidFunction(format!(
                            "Invalid function: {}",
                            command
                        )));
                    }
                    let function = self.functions[command];
                    let args_len = self.next_len().index();
                    let output_register = self.next_len();
                    let mut args = Vec::new();
                    for _ in 0..args_len {
                        let register = self.next_len();
                        args.push(self.remove_register(&register)?);
                    }
                    let result = function(self, args)?;
                    if let Some(result) = result {
                        self.set_register_value(output_register, result);
                    }
                }
                Op::CALL => {
                    // NOTE(review): the frame index is not validated. An
                    // out-of-range index makes the next `self.pc()` panic. The
                    // called frame also resumes from its stored `pc`, which is
                    // not reset to 0.
                    let frame = self.next_len();
                    self.stack.push(self.fp);
                    self.fp = frame.index();
                }
                Op::CALLR => {
                    let index = self.next_len();
                    self.stack.push(self.fp);
                    self.fp = self.remove_index_register(&index)?.index();
                }
                Op::RET => {
                    let from = self.next_len();
                    let to = self.next_len();
                    // The register move happens before the stack pop, so on
                    // StackUnderflow the value has already moved to `to`.
                    let value = self.remove_register(&from)?;
                    self.set_register_value(to, value);
                    self.fp = match self.stack.pop() {
                        None => {
                            return Err(VMError::StackUnderflow(
                                "Attempted to return with empty stack".into(),
                            ))
                        }
                        Some(f) => f,
                    };
                }
                Op::GLV => {
                    let register = self.next_len();
                    let var = self.next_str();
                    let value = self.get_local(self.fp, var)?;
                    self.set_register_value(register, value);
                }
                Op::SLV => {
                    let var = self.next_str();
                    let value = T::from_bytes(self);
                    self.set_local(self.fp, var, value)?;
                }
                Op::SMV => {
                    let var = self.next_str();
                    let value = T::from_bytes(self);
                    self.set_mutable_local(self.fp, var, value)?;
                }
                Op::SLR => {
                    let register = self.next_len();
                    let var = self.next_str();
                    let value = self.remove_register(&register)?;
                    self.set_local(self.fp, var, value)?;
                }
                Op::SMR => {
                    let register = self.next_len();
                    let var = self.next_str();
                    let value = self.remove_register(&register)?;
                    self.set_mutable_local(self.fp, var, value)?;
                }
                Op::CFR => {
                    // Create a child of the current frame from an inline,
                    // length-prefixed byte sequence. Store its index in an
                    // index register.
                    let register = self.next_len();
                    let mut f = Frame::child(self.fp);
                    let len = self.next_len();
                    let next = self.next_n_bytes_vec(len);
                    f.instructions.extend(next);
                    self.frames.push(f);
                    self.set_index_register_value(register, self.frames.small_len() - 1);
                }
                Op::DFR => {
                    // NOTE(review): `Vec::remove` shifts every later frame down
                    // by one, so any `fp`, saved `stack` entry, `parent` link or
                    // index register referring to a later frame now points at
                    // the wrong frame. It also panics if `index` is out of range.
                    let index = self.next_len().index();
                    self.frames.remove(index);
                }
                Op::HCF => {
                    if let Some(hcf) = self.hcf_trigger {
                        hcf(self)?;
                    }
                    break;
                }
                Op::IVD => {
                    return Err(VMError::InvalidInstruction(format!(
                        "Invalid instruction: {}",
                        raw
                    )))
                }
            }
        }
        Ok(())
    }

    /// Reads a length-prefixed string operand from the current frame.
    ///
    /// # Panics
    /// Panics if the bytes are not valid UTF-8 or run past the end of the
    /// instructions.
    pub fn next_str(&mut self) -> String {
        let len = self.next_len();
        let bytes = self.next_n_bytes_vec(len);
        // Take ownership of the buffer rather than validating a borrow and
        // copying it into a second allocation.
        String::from_utf8(bytes).expect("string operand in bytecode is not valid UTF-8")
    }

    /// Stores `value` in `register`, replacing any previous value; returns
    /// `self` for chaining.
    pub fn set_register_value(&mut self, register: Length, value: T) -> &mut Self {
        self.registers.insert(register, value);
        self
    }

    /// Stores the frame index `value` in index register `register`, replacing
    /// any previous value; returns `self` for chaining.
    pub fn set_index_register_value(&mut self, register: Length, value: Length) -> &mut Self {
        self.index_registers.insert(register, value);
        self
    }

    /// Decodes a `Length` operand from the current frame via
    /// `length_from_bytes`.
    pub fn next_len(&mut self) -> Length {
        length_from_bytes(self)
    }

    /// Reads the byte at the current frame's `pc` and advances `pc` by one.
    ///
    /// # Panics
    /// Panics if `pc` is past the end of the instructions.
    pub fn next_byte(&mut self) -> u8 {
        let pc = self.pc();
        let result = self.instruction(pc);
        self.pc_set(pc + 1);
        result
    }

    /// Reads the next `N` bytes of the current frame into an array and
    /// advances `pc` past them.
    ///
    /// # Panics
    /// Panics if fewer than `N` bytes remain.
    pub fn next_n_bytes<const N: usize>(&mut self) -> [u8; N] {
        let mut result = [0; N];
        let pc = self.pc();
        result.copy_from_slice(&self.instructions()[pc..pc + N]);
        self.pc_set(pc + N);
        result
    }

    /// Reads the next `n` bytes of the current frame into a `Vec` and advances
    /// `pc` past them.
    ///
    /// # Panics
    /// Panics if fewer than `n` bytes remain.
    pub fn next_n_bytes_vec(&mut self, n: Length) -> Vec<u8> {
        let pc = self.pc();
        let end = (pc + n).index();
        let result = self.instructions()[pc..end].to_vec();
        self.pc_set(end);
        result
    }
}