harn-vm 0.8.3

Async bytecode virtual machine for the Harn programming language
Documentation
use std::rc::Rc;

use crate::value::{VmError, VmValue};

impl super::super::Vm {
    fn push_binary_result(
        &mut self,
        f: impl FnOnce(&Self, VmValue, VmValue) -> Result<VmValue, VmError>,
    ) -> Result<(), VmError> {
        let b = self.pop()?;
        let a = self.pop()?;
        let result = f(self, a, b)?;
        self.stack.push(result);
        Ok(())
    }

    pub(super) fn execute_add(&mut self) -> Result<(), VmError> {
        self.push_binary_result(Self::add)
    }

    pub(super) fn execute_sub(&mut self) -> Result<(), VmError> {
        self.push_binary_result(Self::sub)
    }

    pub(super) fn execute_mul(&mut self) -> Result<(), VmError> {
        self.push_binary_result(Self::mul)
    }

    pub(super) fn execute_div(&mut self) -> Result<(), VmError> {
        self.push_binary_result(Self::div)
    }

    pub(super) fn execute_mod(&mut self) -> Result<(), VmError> {
        self.push_binary_result(Self::modulo)
    }

    pub(super) fn execute_pow(&mut self) -> Result<(), VmError> {
        self.push_binary_result(Self::pow)
    }

    pub(super) fn execute_negate(&mut self) -> Result<(), VmError> {
        let v = self.pop()?;
        self.stack.push(match v {
            VmValue::Int(n) => VmValue::Int(n.wrapping_neg()),
            VmValue::Float(n) => VmValue::Float(-n),
            _ => {
                return Err(VmError::Runtime(format!(
                    "Cannot negate value of type {}",
                    v.type_name()
                )))
            }
        });
        Ok(())
    }

    pub(super) fn execute_add_int(&mut self) -> Result<(), VmError> {
        let b = self.pop()?;
        let a = self.pop()?;
        let (x, y) = typed_int_pair("add", a, b)?;
        self.stack.push(VmValue::Int(x.wrapping_add(y)));
        Ok(())
    }

    pub(super) fn execute_sub_int(&mut self) -> Result<(), VmError> {
        let b = self.pop()?;
        let a = self.pop()?;
        let (x, y) = typed_int_pair("subtract", a, b)?;
        self.stack.push(VmValue::Int(x.wrapping_sub(y)));
        Ok(())
    }

    pub(super) fn execute_mul_int(&mut self) -> Result<(), VmError> {
        let b = self.pop()?;
        let a = self.pop()?;
        let (x, y) = typed_int_pair("multiply", a, b)?;
        self.stack.push(VmValue::Int(x.wrapping_mul(y)));
        Ok(())
    }

    pub(super) fn execute_div_int(&mut self) -> Result<(), VmError> {
        let b = self.pop()?;
        let a = self.pop()?;
        let (x, y) = typed_int_pair("divide", a, b)?;
        if y == 0 {
            return Err(VmError::DivisionByZero);
        }
        self.stack.push(VmValue::Int(x / y));
        Ok(())
    }

    pub(super) fn execute_mod_int(&mut self) -> Result<(), VmError> {
        let b = self.pop()?;
        let a = self.pop()?;
        let (x, y) = typed_int_pair("modulo", a, b)?;
        if y == 0 {
            return Err(VmError::DivisionByZero);
        }
        self.stack.push(VmValue::Int(x % y));
        Ok(())
    }

    pub(super) fn execute_add_float(&mut self) -> Result<(), VmError> {
        let b = self.pop()?;
        let a = self.pop()?;
        let (x, y) = typed_float_pair("add", a, b)?;
        self.stack.push(VmValue::Float(x + y));
        Ok(())
    }

    pub(super) fn execute_sub_float(&mut self) -> Result<(), VmError> {
        let b = self.pop()?;
        let a = self.pop()?;
        let (x, y) = typed_float_pair("subtract", a, b)?;
        self.stack.push(VmValue::Float(x - y));
        Ok(())
    }

    pub(super) fn execute_mul_float(&mut self) -> Result<(), VmError> {
        let b = self.pop()?;
        let a = self.pop()?;
        let (x, y) = typed_float_pair("multiply", a, b)?;
        self.stack.push(VmValue::Float(x * y));
        Ok(())
    }

    pub(super) fn execute_div_float(&mut self) -> Result<(), VmError> {
        let b = self.pop()?;
        let a = self.pop()?;
        let (x, y) = typed_float_pair("divide", a, b)?;
        self.stack.push(VmValue::Float(x / y));
        Ok(())
    }

    pub(super) fn execute_mod_float(&mut self) -> Result<(), VmError> {
        let b = self.pop()?;
        let a = self.pop()?;
        let (x, y) = typed_float_pair("modulo", a, b)?;
        if y == 0.0 {
            return Err(VmError::DivisionByZero);
        }
        self.stack.push(VmValue::Float(x % y));
        Ok(())
    }

    fn add(&self, a: VmValue, b: VmValue) -> Result<VmValue, VmError> {
        match (a, b) {
            (VmValue::Int(x), VmValue::Int(y)) => Ok(VmValue::Int(x.wrapping_add(y))),
            (VmValue::Float(x), VmValue::Float(y)) => Ok(VmValue::Float(x + y)),
            (VmValue::Int(x), VmValue::Float(y)) => Ok(VmValue::Float(x as f64 + y)),
            (VmValue::Float(x), VmValue::Int(y)) => Ok(VmValue::Float(x + y as f64)),
            (VmValue::String(x), VmValue::String(y)) => {
                if x.is_empty() {
                    return Ok(VmValue::String(y));
                }
                if y.is_empty() {
                    return Ok(VmValue::String(x));
                }
                let mut s = String::with_capacity(x.len() + y.len());
                s.push_str(&x);
                s.push_str(&y);
                Ok(VmValue::String(Rc::from(s)))
            }
            (VmValue::List(x), VmValue::List(y)) => {
                if x.is_empty() {
                    return Ok(VmValue::List(y));
                }
                if y.is_empty() {
                    return Ok(VmValue::List(x));
                }
                let y_len = y.len();
                let mut result = Rc::try_unwrap(x).unwrap_or_else(|items| items.as_ref().clone());
                result.reserve(y_len);
                match Rc::try_unwrap(y) {
                    Ok(items) => result.extend(items),
                    Err(items) => result.extend(items.iter().cloned()),
                }
                Ok(VmValue::List(Rc::new(result)))
            }
            (VmValue::Dict(x), VmValue::Dict(y)) => {
                if x.is_empty() {
                    return Ok(VmValue::Dict(y));
                }
                if y.is_empty() {
                    return Ok(VmValue::Dict(x));
                }
                let mut result =
                    Rc::try_unwrap(x).unwrap_or_else(|entries| entries.as_ref().clone());
                match Rc::try_unwrap(y) {
                    Ok(entries) => result.extend(entries),
                    Err(entries) => {
                        result.extend(entries.iter().map(|(k, v)| (k.clone(), v.clone())))
                    }
                }
                Ok(VmValue::Dict(Rc::new(result)))
            }
            (a, b) => Err(VmError::TypeError(format!(
                "Cannot add {} and {}",
                a.type_name(),
                b.type_name()
            ))),
        }
    }

    fn sub(&self, a: VmValue, b: VmValue) -> Result<VmValue, VmError> {
        match (&a, &b) {
            (VmValue::Int(x), VmValue::Int(y)) => Ok(VmValue::Int(x.wrapping_sub(*y))),
            (VmValue::Float(x), VmValue::Float(y)) => Ok(VmValue::Float(x - y)),
            (VmValue::Int(x), VmValue::Float(y)) => Ok(VmValue::Float(*x as f64 - y)),
            (VmValue::Float(x), VmValue::Int(y)) => Ok(VmValue::Float(x - *y as f64)),
            _ => Err(VmError::TypeError(format!(
                "Cannot subtract {} from {}",
                b.type_name(),
                a.type_name()
            ))),
        }
    }

    fn mul(&self, a: VmValue, b: VmValue) -> Result<VmValue, VmError> {
        match (&a, &b) {
            (VmValue::Int(x), VmValue::Int(y)) => Ok(VmValue::Int(x.wrapping_mul(*y))),
            (VmValue::Float(x), VmValue::Float(y)) => Ok(VmValue::Float(x * y)),
            (VmValue::Int(x), VmValue::Float(y)) => Ok(VmValue::Float(*x as f64 * y)),
            (VmValue::Float(x), VmValue::Int(y)) => Ok(VmValue::Float(x * *y as f64)),
            (VmValue::String(s), VmValue::Int(n)) | (VmValue::Int(n), VmValue::String(s)) => {
                let count = (*n).max(0) as usize;
                Ok(VmValue::String(s.repeat(count).into()))
            }
            _ => Err(VmError::TypeError(format!(
                "Cannot multiply {} and {}",
                a.type_name(),
                b.type_name()
            ))),
        }
    }

    fn div(&self, a: VmValue, b: VmValue) -> Result<VmValue, VmError> {
        match (&a, &b) {
            (VmValue::Int(_), VmValue::Int(y)) if *y == 0 => Err(VmError::DivisionByZero),
            (VmValue::Int(x), VmValue::Int(y)) => Ok(VmValue::Int(x / y)),
            (VmValue::Float(x), VmValue::Float(y)) => Ok(VmValue::Float(x / y)),
            (VmValue::Int(x), VmValue::Float(y)) => Ok(VmValue::Float(*x as f64 / y)),
            (VmValue::Float(x), VmValue::Int(y)) => Ok(VmValue::Float(x / *y as f64)),
            _ => Err(VmError::Runtime(format!(
                "Cannot divide {} by {}",
                a.type_name(),
                b.type_name()
            ))),
        }
    }

    fn modulo(&self, a: VmValue, b: VmValue) -> Result<VmValue, VmError> {
        match (&a, &b) {
            (VmValue::Int(_), VmValue::Int(y)) if *y == 0 => Err(VmError::DivisionByZero),
            (VmValue::Int(x), VmValue::Int(y)) => Ok(VmValue::Int(x % y)),
            (VmValue::Float(_), VmValue::Float(y)) if *y == 0.0 => Err(VmError::DivisionByZero),
            (VmValue::Float(x), VmValue::Float(y)) => Ok(VmValue::Float(x % y)),
            (VmValue::Int(_), VmValue::Float(y)) if *y == 0.0 => Err(VmError::DivisionByZero),
            (VmValue::Int(x), VmValue::Float(y)) => Ok(VmValue::Float(*x as f64 % y)),
            (VmValue::Float(_), VmValue::Int(y)) if *y == 0 => Err(VmError::DivisionByZero),
            (VmValue::Float(x), VmValue::Int(y)) => Ok(VmValue::Float(x % *y as f64)),
            _ => Err(VmError::Runtime(format!(
                "Cannot modulo {} by {}",
                a.type_name(),
                b.type_name()
            ))),
        }
    }

    fn pow(&self, a: VmValue, b: VmValue) -> Result<VmValue, VmError> {
        match (&a, &b) {
            (VmValue::Int(base), VmValue::Int(exp)) => {
                if *exp >= 0 && *exp <= u32::MAX as i64 {
                    Ok(VmValue::Int(base.wrapping_pow(*exp as u32)))
                } else {
                    Ok(VmValue::Float((*base as f64).powf(*exp as f64)))
                }
            }
            (VmValue::Float(base), VmValue::Int(exp)) => {
                if *exp >= i32::MIN as i64 && *exp <= i32::MAX as i64 {
                    Ok(VmValue::Float(base.powi(*exp as i32)))
                } else {
                    Ok(VmValue::Float(base.powf(*exp as f64)))
                }
            }
            (VmValue::Int(base), VmValue::Float(exp)) => {
                Ok(VmValue::Float((*base as f64).powf(*exp)))
            }
            (VmValue::Float(base), VmValue::Float(exp)) => Ok(VmValue::Float(base.powf(*exp))),
            _ => Err(VmError::TypeError(format!(
                "Cannot exponentiate {} by {}",
                a.type_name(),
                b.type_name()
            ))),
        }
    }
}

/// Extracts the two `i64` payloads for a typed-int instruction.
///
/// `name` is the operation verb (e.g. "add") and only appears in the error text.
///
/// # Errors
/// Returns `VmError::TypeError` unless both `a` and `b` are `VmValue::Int`.
#[inline]
fn typed_int_pair(name: &str, a: VmValue, b: VmValue) -> Result<(i64, i64), VmError> {
    if let (VmValue::Int(lhs), VmValue::Int(rhs)) = (&a, &b) {
        return Ok((*lhs, *rhs));
    }
    let message = format!(
        "Typed int {name} expected int operands, got {} and {}",
        a.type_name(),
        b.type_name()
    );
    Err(VmError::TypeError(message))
}

/// Extracts the two `f64` payloads for a typed-float instruction.
///
/// `name` is the operation verb (e.g. "divide") and only appears in the error text.
///
/// # Errors
/// Returns `VmError::TypeError` unless both `a` and `b` are `VmValue::Float`.
#[inline]
fn typed_float_pair(name: &str, a: VmValue, b: VmValue) -> Result<(f64, f64), VmError> {
    if let (VmValue::Float(lhs), VmValue::Float(rhs)) = (&a, &b) {
        return Ok((*lhs, *rhs));
    }
    let message = format!(
        "Typed float {name} expected float operands, got {} and {}",
        a.type_name(),
        b.type_name()
    );
    Err(VmError::TypeError(message))
}