//! luallaby 0.1.0-alpha.3
//!
//! **Work in progress**: a pure-Rust Lua interpreter/compiler.
use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Sub};

use crate::error::{LuaError, Result};
use crate::vm::{Numeric, Value};

/// Binary operators of the Lua expression grammar.
#[derive(Clone, Copy, Debug)]
pub enum BinOp {
    // Mathematical
    Add,
    Sub,
    Mul,
    Div,
    FlDiv,
    Pow,
    Mod,
    // Bitwise
    BitAnd,
    BitXor,
    BitOr,
    Shl,
    Shr,
    // String
    Conc,
    // Comparison
    Eq,
    Neq,
    Lt,
    Leq,
    Gt,
    Geq,
    // Logical
    And,
    Or,
}

/// Unary operators of the Lua expression grammar.
#[derive(Clone, Copy, Debug)]
pub enum UnOp {
    Minus,
    BitNot,
    Not,
    Len,
}

impl BinOp {
    /// Binding strength of this operator; a larger value binds tighter.
    ///
    /// From tightest to loosest: `Pow` (11), then the unary operators (10, see
    /// `UnOp::precedence`), multiplicative (9), additive (8), concatenation (7),
    /// shifts (6), `BitAnd` (5), `BitXor` (4), `BitOr` (3), comparisons (2),
    /// `And` (1) and `Or` (0).
    pub fn precedence(&self) -> usize {
        match *self {
            Self::Pow => 11,
            // 10 is deliberately skipped: it belongs to the unary operators.
            Self::Mul | Self::Div | Self::FlDiv | Self::Mod => 9,
            Self::Add | Self::Sub => 8,
            Self::Conc => 7,
            Self::Shl | Self::Shr => 6,
            Self::BitAnd => 5,
            Self::BitXor => 4,
            Self::BitOr => 3,
            Self::Eq | Self::Neq | Self::Lt | Self::Leq | Self::Gt | Self::Geq => 2,
            Self::And => 1,
            Self::Or => 0,
        }
    }

    /// Whether the operator associates to the left.
    ///
    /// Only concatenation and exponentiation are right-associative.
    pub fn is_left_ass(&self) -> bool {
        let right_assoc = matches!(*self, Self::Pow | Self::Conc);
        !right_assoc
    }
}

impl UnOp {
    /// Binding strength shared by every unary operator.
    ///
    /// The value 10 is above all binary operators except `Pow` (11).
    pub fn precedence(&self) -> usize {
        const UNARY: usize = 10;
        UNARY
    }
}

impl Value {
    pub(super) fn bin_op(&self, op: BinOp, rhs: &Value) -> Result<Value> {
        bin_op(self, op, rhs)
    }
}

impl Numeric {
    pub fn rem(&self, rhs: &Self) -> Result<Self> {
        num_rem(self, rhs)
    }
}

pub(super) fn bin_op(lhs: &Value, op: BinOp, rhs: &Value) -> Result<Value> {
    match op {
        // Arithmetic
        BinOp::Add => val_add(lhs, rhs),
        BinOp::Sub => val_sub(lhs, rhs),
        BinOp::Mul => val_mul(lhs, rhs),
        BinOp::Div => val_div(lhs, rhs),
        BinOp::FlDiv => val_fldiv(lhs, rhs),
        BinOp::Mod => val_rem(lhs, rhs),
        BinOp::Pow => val_pow(lhs, rhs),
        // Bitwise
        BinOp::BitAnd => val_bitand(lhs, rhs),
        BinOp::BitOr => val_bitor(lhs, rhs),
        BinOp::BitXor => val_bitxor(lhs, rhs),
        BinOp::Shl => val_shl(lhs, rhs),
        BinOp::Shr => val_shr(lhs, rhs),
        // String
        BinOp::Conc => val_conc(lhs, rhs),
        // Comparison
        BinOp::Eq => val_eq(lhs, rhs),
        BinOp::Neq => val_neq(lhs, rhs),
        BinOp::Lt => val_lt(lhs, rhs),
        BinOp::Leq => val_le(lhs, rhs),
        BinOp::Gt => val_gt(lhs, rhs),
        BinOp::Geq => val_ge(lhs, rhs),
        // Logical, short circuiting implemented in compiler, should not be computed through here
        BinOp::And | BinOp::Or => unreachable!(),
    }
}

pub(super) fn un_op(lhs: &Value, op: UnOp) -> Result<Value> {
    match op {
        UnOp::Minus => val_minus(lhs),
        UnOp::BitNot => val_bitnot(lhs),
        UnOp::Not => val_not(lhs),
        UnOp::Len => val_len(lhs),
    }
}

/// Define a `Value`-level arithmetic operator `$fn` in terms of the `Numeric`
/// operator `$num_op`.
///
/// Both operands are passed through `to_number_coerce` (so numeric strings are
/// accepted); the generated function panics if either conversion fails.
macro_rules! op_val_arith {
    ( $fn:ident, $num_op:ident ) => {
        fn $fn(lhs: &Value, rhs: &Value) -> Result<Value> {
            match (lhs.to_number_coerce(), rhs.to_number_coerce()) {
                (Ok(a), Ok(b)) => $num_op(&a, &b).map(Value::Number),
                _ => panic!("both operands must be numbers or string representation of a number"),
            }
        }
    };
}

op_val_arith!(val_add, num_add);
op_val_arith!(val_sub, num_sub);
op_val_arith!(val_mul, num_mul);
op_val_arith!(val_div, num_div);
op_val_arith!(val_fldiv, num_fldiv);
op_val_arith!(val_rem, num_rem);
op_val_arith!(val_pow, num_pow);

/// Define a `Value`-level bitwise operator `$fn` in terms of the `Numeric`
/// operator `$num_op`.
///
/// Lua converts numeric strings for bitwise operators just as it does for
/// arithmetic (`"3" & 1 == 1`), so both operands go through `to_number_coerce`,
/// consistent with `op_val_arith`. The subsequent integer conversion, including
/// the error for floats without an integer representation, happens in `$num_op`.
///
/// The generated function panics if either operand is not a number or numeric string.
macro_rules! op_val_bitw {
    ( $fn:ident, $num_op:ident ) => {
        fn $fn(lhs: &Value, rhs: &Value) -> Result<Value> {
            match (lhs.to_number_coerce(), rhs.to_number_coerce()) {
                (Ok(num_lhs), Ok(num_rhs)) => $num_op(&num_lhs, &num_rhs).map(Value::Number),
                _ => panic!("both operands must be numbers or string representation of a number"),
            }
        }
    };
}

op_val_bitw!(val_bitand, num_bitand);
op_val_bitw!(val_bitor, num_bitor);
op_val_bitw!(val_bitxor, num_bitxor);
op_val_bitw!(val_shl, num_shl);
op_val_bitw!(val_shr, num_shr);

/// Lua concatenation (`..`) of two strings or numbers.
///
/// Numbers are rendered with their `Display` implementation. Panics if either
/// operand is neither a number nor a string (the left operand is checked first).
fn val_conc(lhs: &Value, rhs: &Value) -> Result<Value> {
    /// Text contributed by one operand of `..`.
    fn piece(value: &Value) -> String {
        match value {
            Value::Number(n) => n.to_string(),
            Value::String(s) => s.to_string(),
            _ => panic!(),
        }
    }

    let mut joined = piece(lhs);
    joined.push_str(&piece(rhs));
    Ok(Value::String(joined))
}

/// Lua equality (`==`).
///
/// Numbers compare by mathematical value, so an integer equals a float holding
/// the same integral value. Two floats compare with exact IEEE-754 equality: there
/// is no tolerance (`0.1 + 0.2 == 0.3` is false in Lua), and NaN never equals
/// anything. All other operand pairs fall back to raw equality via `Value`'s
/// `PartialEq`.
// Exact float comparison is intentional: it is what Lua's `==` specifies.
#[allow(clippy::float_cmp)]
fn val_eq(lhs: &Value, rhs: &Value) -> Result<Value> {
    let equal = match (lhs, rhs) {
        // Only number comparison is overridden; everything else is raw equality.
        (Value::Number(n1), Value::Number(n2)) => match (n1.coerce_int(), n2.coerce_int()) {
            (Ok(i1), Ok(i2)) => i1 == i2,
            // Exactly one side converts to an integer. The other (presumably a
            // non-integral or out-of-range float) cannot be equal to it.
            (Ok(_), Err(_)) | (Err(_), Ok(_)) => false,
            // Neither side converts: compare the float values exactly.
            (Err(_), Err(_)) => n1.to_float() == n2.to_float(),
        },
        (v1, v2) => v1 == v2,
    };
    Ok(Value::Bool(equal))
}

fn val_neq(lhs: &Value, rhs: &Value) -> Result<Value> {
    bin_op(lhs, BinOp::Eq, rhs).map(|res| {
        if let Value::Bool(bool) = res {
            Value::Bool(!bool)
        } else {
            unreachable!()
        }
    })
}

/// Define a `Value`-level ordering comparison `$fn`.
///
/// Two numbers are compared with the `Numeric` helper `$num_op`. Two strings are
/// compared with the `PartialOrd` method `$str_op`, which is byte-wise
/// lexicographic for Rust strings. Any other pairing panics.
macro_rules! op_val_comp {
    ( $fn:ident, $num_op:ident, $str_op:ident ) => {
        fn $fn(lhs: &Value, rhs: &Value) -> Result<Value> {
            let holds = match (lhs, rhs) {
                (Value::Number(a), Value::Number(b)) => $num_op(a, b)?,
                (Value::String(a), Value::String(b)) => a.$str_op(b),
                _ => panic!(),
            };
            Ok(Value::Bool(holds))
        }
    };
}

op_val_comp!(val_lt, num_lt, lt);
op_val_comp!(val_le, num_le, le);
op_val_comp!(val_gt, num_gt, gt);
op_val_comp!(val_ge, num_ge, ge);

/// Lua unary minus (`-x`); numeric strings are accepted via `to_number_coerce`.
///
/// Panics if the operand cannot be converted to a number.
fn val_minus(op: &Value) -> Result<Value> {
    match op.to_number_coerce() {
        Ok(n) => Ok(Value::Number(num_minus(&n))),
        Err(_) => panic!("operand must be number or string representing a number"),
    }
}

/// Lua unary bitwise not (`~x`).
///
/// Lua converts numeric strings for bitwise operators, so the operand goes through
/// `to_number_coerce` (as in `val_minus`). The integer conversion, and its error for
/// floats without an integer representation, happens in `num_bitnot`.
///
/// Panics if the operand is neither a number nor a numeric string.
fn val_bitnot(op: &Value) -> Result<Value> {
    match op.to_number_coerce() {
        Ok(n) => num_bitnot(&n).map(Value::Number),
        Err(_) => panic!("operand must be number or string representing a number"),
    }
}

/// Lua logical not (`not x`): true exactly when the operand is falsy.
fn val_not(op: &Value) -> Result<Value> {
    let falsy = op.is_falsy();
    Ok(Value::Bool(falsy))
}

/// Lua length operator (`#x`).
///
/// For strings this is the byte length. For tables it is the table's `border()`.
/// Panics for any other operand type.
fn val_len(op: &Value) -> Result<Value> {
    let length = match op {
        Value::String(s) => s.len() as i64,
        Value::Table(t) => t.borrow().border() as i64,
        _ => panic!("operand must be string or table"),
    };
    Ok(Value::int(length))
}

/// Define a `Numeric` arithmetic operator `$fn`.
///
/// When both operands are integers, the integer method `$int_op` is used (the
/// invocations pass wrapping variants, giving Lua's two's-complement wraparound).
/// Otherwise both operands are widened to `f64` and combined with the `std::ops`
/// trait method `$flt_op`.
macro_rules! op_num_arith {
    ( $fn:ident, $int_op:ident, $flt_op:ident ) => {
        fn $fn(lhs: &Numeric, rhs: &Numeric) -> Result<Numeric> {
            let outcome = match (lhs.to_int(), rhs.to_int()) {
                (Ok(a), Ok(b)) => Numeric::Integer(a.$int_op(b)),
                _ => Numeric::Float(lhs.to_float().$flt_op(rhs.to_float())),
            };
            Ok(outcome)
        }
    };
}

op_num_arith!(num_add, wrapping_add, add);
op_num_arith!(num_sub, wrapping_sub, sub);
op_num_arith!(num_mul, wrapping_mul, mul);

/// Lua modulo (`%`): the remainder of a division rounded towards minus infinity,
/// so a nonzero result has the sign of `rhs`.
///
/// Two integers give an integer result, and `rhs == 0` yields
/// `LuaError::DivideByZero`. Any other combination is computed on floats, where
/// `x % 0.0` and infinite dividends produce NaN as in IEEE-754 `fmod`.
fn num_rem(lhs: &Numeric, rhs: &Numeric) -> Result<Numeric> {
    if let (Numeric::Integer(lhs), Numeric::Integer(rhs)) = (lhs, rhs) {
        if *rhs == 0 {
            return err!(LuaError::DivideByZero);
        }
        // wrapping_rem: i64::MIN % -1 would overflow with `%`; its result is 0 anyway
        let mut rem = lhs.wrapping_rem(*rhs);
        // Truncated remainder has the dividend's sign; move it to the divisor's sign.
        // Cannot overflow: rem and rhs have opposite signs here.
        if rem != 0 && (rem ^ rhs) < 0 {
            rem += rhs;
        }
        Ok(Numeric::Integer(rem))
    } else {
        let lhs = lhs.to_float();
        let rhs = rhs.to_float();
        // `%` on f64 is C `fmod`: truncated, with the sign of the dividend.
        let mut rem = lhs % rhs;
        // Correct any nonzero remainder whose sign differs from the divisor's. The
        // check must be exact: an epsilon threshold would leave tiny remainders
        // uncorrected (e.g. `1e-20 % -1` must be about -1, not 1e-20). NaN fails
        // both comparisons and is left untouched.
        let needs_correction = if rem > 0.0 {
            rhs < 0.0
        } else {
            rem < 0.0 && rhs > 0.0
        };
        if needs_correction {
            rem += rhs;
        }
        Ok(Numeric::Float(rem))
    }
}

/// Lua float division (`/`): always computed on `f64`, even for two integers.
fn num_div(lhs: &Numeric, rhs: &Numeric) -> Result<Numeric> {
    let (numerator, denominator) = (lhs.to_float(), rhs.to_float());
    Ok(Numeric::Float(numerator / denominator))
}

/// Lua floor division (`//`): the quotient rounded towards minus infinity.
///
/// Two integers give an integer, and a zero divisor yields
/// `LuaError::DivideByZero`. Any other combination is `floor(lhs / rhs)` on floats.
fn num_fldiv(lhs: &Numeric, rhs: &Numeric) -> Result<Numeric> {
    match (lhs, rhs) {
        (Numeric::Integer(dividend), Numeric::Integer(divisor)) => {
            let (dividend, divisor) = (*dividend, *divisor);
            if divisor == 0 {
                return err!(LuaError::DivideByZero);
            }
            // Truncating division. The only overflow, i64::MIN // -1, wraps back to
            // i64::MIN (it does not saturate).
            let truncated = dividend.wrapping_div(divisor);
            // Truncation rounds a negative non-integral quotient up; step down one to
            // reach the floor. Operands of opposite sign exclude the MIN / -1 case.
            let inexact_negative = (dividend ^ divisor) < 0 && dividend.wrapping_rem(divisor) != 0;
            let floored = if inexact_negative { truncated - 1 } else { truncated };
            Ok(Numeric::Integer(floored))
        }
        _ => Ok(Numeric::Float(lhs.to_float().div(rhs.to_float()).floor())),
    }
}

/// Lua exponentiation (`^`): always computed on `f64`.
fn num_pow(lhs: &Numeric, rhs: &Numeric) -> Result<Numeric> {
    let (base, exponent) = (lhs.to_float(), rhs.to_float());
    Ok(Numeric::Float(base.powf(exponent)))
}

macro_rules! op_num_bitw {
    ( $fn:ident, $op:ident ) => {
        fn $fn(lhs: &Numeric, rhs: &Numeric) -> Result<Numeric> {
            let lhs = lhs.coerce_int()?;
            let rhs = rhs.coerce_int()?;
            Ok(Numeric::Integer(lhs.$op(rhs)))
        }
    };
}

op_num_bitw!(num_bitand, bitand);
op_num_bitw!(num_bitor, bitor);
op_num_bitw!(num_bitxor, bitxor);

/// Lua's shift primitive: shift `value` left by `shift` bits, treating the bits as
/// unsigned. A negative `shift` performs a logical right shift instead.
///
/// A shift whose magnitude is 64 or more moves every bit out and yields 0. This
/// covers huge shift counts, which must not be truncated to `u32`, and also
/// `i64::MIN`, whose negation is not representable.
fn lua_shift_left(value: i64, shift: i64) -> i64 {
    const BITS: i64 = 64;
    let bits = value as u64;
    let shifted = if shift <= -BITS || shift >= BITS {
        0
    } else if shift < 0 {
        // -shift is in 1..=63 here, so neither the negation nor the shift overflows
        bits >> ((-shift) as u32)
    } else {
        bits << (shift as u32)
    };
    shifted as i64
}

/// Lua left shift (`<<`). A negative shift count shifts right, and counts of
/// magnitude >= 64 give 0.
///
/// Both operands are converted with `coerce_int` (shift count first), and the first
/// conversion error is returned.
fn num_shl(lhs: &Numeric, rhs: &Numeric) -> Result<Numeric> {
    let shift = rhs.coerce_int()?;
    let value = lhs.coerce_int()?;
    Ok(Numeric::Integer(lua_shift_left(value, shift)))
}

/// Lua logical right shift (`>>`). Zeros are shifted in (Rust's `>>` on `i64` would
/// be arithmetic). A negative shift count shifts left, and counts of magnitude >= 64
/// give 0.
///
/// Both operands are converted with `coerce_int` (shift count first), and the first
/// conversion error is returned.
fn num_shr(lhs: &Numeric, rhs: &Numeric) -> Result<Numeric> {
    let shift = rhs.coerce_int()?;
    let value = lhs.coerce_int()?;
    // wrapping_neg maps i64::MIN to itself, which still lands in the "shift
    // everything out" range of lua_shift_left, instead of overflowing.
    Ok(Numeric::Integer(lua_shift_left(value, shift.wrapping_neg())))
}

/// Lua `<` on numbers, exact even when mixing integers and floats.
///
/// A mixed comparison rounds the float towards the integer side and compares as
/// integers: `i < f` iff `i < ceil(f)`, and `f < i` iff `floor(f) < i`. When the
/// rounded float is rejected by `coerce_int`, only its sign decides, and NaN
/// compares false either way. Presumably `coerce_int` rejects it because it lies
/// outside the `i64` range or is NaN; confirm against `Numeric::coerce_int`.
fn num_lt(lhs: &Numeric, rhs: &Numeric) -> Result<bool> {
    let less = match (lhs.to_int(), rhs.to_int()) {
        (Ok(a), Ok(b)) => a < b,
        (Ok(a), Err(_)) => {
            let ceiled = Numeric::Float(rhs.to_float().ceil());
            match ceiled.coerce_int() {
                Ok(b) => a < b,
                Err(_) => ceiled.to_float() > 0.,
            }
        }
        (Err(_), Ok(b)) => {
            let floored = Numeric::Float(lhs.to_float().floor());
            match floored.coerce_int() {
                Ok(a) => a < b,
                Err(_) => floored.to_float() < 0.,
            }
        }
        (Err(_), Err(_)) => lhs.to_float() < rhs.to_float(),
    };
    Ok(less)
}

/// Lua `<=` on numbers, exact even when mixing integers and floats.
///
/// A mixed comparison rounds the float towards the integer side and compares as
/// integers: `i <= f` iff `i <= floor(f)`, and `f <= i` iff `ceil(f) <= i`. When the
/// rounded float is rejected by `coerce_int`, only its sign decides, and NaN
/// compares false either way. Presumably `coerce_int` rejects it because it lies
/// outside the `i64` range or is NaN; confirm against `Numeric::coerce_int`.
fn num_le(lhs: &Numeric, rhs: &Numeric) -> Result<bool> {
    let not_greater = match (lhs.to_int(), rhs.to_int()) {
        (Ok(a), Ok(b)) => a <= b,
        (Ok(a), Err(_)) => {
            let floored = Numeric::Float(rhs.to_float().floor());
            match floored.coerce_int() {
                Ok(b) => a <= b,
                Err(_) => floored.to_float() > 0.,
            }
        }
        (Err(_), Ok(b)) => {
            let ceiled = Numeric::Float(lhs.to_float().ceil());
            match ceiled.coerce_int() {
                Ok(a) => a <= b,
                Err(_) => ceiled.to_float() < 0.,
            }
        }
        (Err(_), Err(_)) => lhs.to_float() <= rhs.to_float(),
    };
    Ok(not_greater)
}

/// Lua `>` on numbers: `a > b` is evaluated as `b < a`.
fn num_gt(lhs: &Numeric, rhs: &Numeric) -> Result<bool> {
    let (smaller, larger) = (rhs, lhs);
    num_lt(smaller, larger)
}

/// Lua `>=` on numbers: `a >= b` is evaluated as `b <= a`.
fn num_ge(lhs: &Numeric, rhs: &Numeric) -> Result<bool> {
    let (smaller, larger) = (rhs, lhs);
    num_le(smaller, larger)
}

/// Numeric negation. Integers wrap, so `-i64::MIN == i64::MIN`; floats flip sign.
fn num_minus(op: &Numeric) -> Numeric {
    match *op {
        Numeric::Float(f) => Numeric::Float(-f),
        Numeric::Integer(i) => Numeric::Integer(i.wrapping_neg()),
    }
}

/// Bitwise complement of the operand after `coerce_int`; its conversion error is
/// propagated.
fn num_bitnot(op: &Numeric) -> Result<Numeric> {
    let int = op.coerce_int()?;
    Ok(Numeric::Integer(!int))
}