use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Sub};
use crate::error::{LuaError, Result};
use crate::vm::{Numeric, Value};
/// Binary operators understood by the evaluator.
///
/// Every variant except `And` and `Or` is dispatched by the module-level
/// `bin_op` function to a dedicated `val_*` handler.
#[derive(Clone, Copy, Debug)]
pub enum BinOp {
/// Addition.
Add,
/// Subtraction.
Sub,
/// Multiplication.
Mul,
/// Division; the numeric handler always produces a float.
Div,
/// Floor division.
FlDiv,
/// Exponentiation; the numeric handler always produces a float.
Pow,
/// Modulo (remainder).
Mod,
/// Bitwise AND.
BitAnd,
/// Bitwise XOR.
BitXor,
/// Bitwise OR.
BitOr,
/// Left shift.
Shl,
/// Right shift.
Shr,
/// Concatenation of strings and/or numbers.
Conc,
/// Equality test.
Eq,
/// Inequality test (the negation of `Eq`).
Neq,
/// Less-than comparison.
Lt,
/// Less-than-or-equal comparison.
Leq,
/// Greater-than comparison.
Gt,
/// Greater-than-or-equal comparison.
Geq,
/// Logical `and`. `bin_op` treats this variant as unreachable, so it must be
/// evaluated elsewhere (presumably by the caller, to short-circuit — confirm).
And,
/// Logical `or`. `bin_op` treats this variant as unreachable, so it must be
/// evaluated elsewhere (presumably by the caller, to short-circuit — confirm).
Or,
}
/// Unary operators understood by the evaluator; dispatched by `un_op`.
#[derive(Clone, Copy, Debug)]
pub enum UnOp {
/// Arithmetic negation.
Minus,
/// Bitwise complement.
BitNot,
/// Logical negation (yields whether the operand is falsy).
Not,
/// Length of a string or table.
Len,
}
impl BinOp {
pub fn precedence(&self) -> usize {
match self {
BinOp::Or => 0,
BinOp::And => 1,
BinOp::Lt | BinOp::Leq | BinOp::Gt | BinOp::Geq | BinOp::Eq | BinOp::Neq => 2,
BinOp::BitOr => 3,
BinOp::BitXor => 4,
BinOp::BitAnd => 5,
BinOp::Shl | BinOp::Shr => 6,
BinOp::Conc => 7,
BinOp::Add | BinOp::Sub => 8,
BinOp::Mul | BinOp::Div | BinOp::FlDiv | BinOp::Mod => 9,
BinOp::Pow => 11,
}
}
pub fn is_left_ass(&self) -> bool {
!matches!(self, BinOp::Conc | BinOp::Pow)
}
}
impl UnOp {
    /// Returns the binding strength shared by every unary operator.
    ///
    /// The value is on the same scale as `BinOp::precedence`.
    pub fn precedence(&self) -> usize {
        // Every unary operator uses the same level regardless of variant.
        const UNARY_LEVEL: usize = 10;
        UNARY_LEVEL
    }
}
impl Value {
    /// Applies the binary operator `op` with `self` as the left operand and
    /// `rhs` as the right operand, delegating to the module-level `bin_op`.
    pub(super) fn bin_op(&self, op: BinOp, rhs: &Value) -> Result<Value> {
        let lhs = self;
        bin_op(lhs, op, rhs)
    }
}
impl Numeric {
    /// Computes `self` modulo `rhs` by delegating to `num_rem`.
    ///
    /// # Errors
    /// Propagates whatever `num_rem` reports (it fails on integer division by zero).
    pub fn rem(&self, rhs: &Self) -> Result<Self> {
        let dividend = self;
        num_rem(dividend, rhs)
    }
}
pub(super) fn bin_op(lhs: &Value, op: BinOp, rhs: &Value) -> Result<Value> {
match op {
BinOp::Add => val_add(lhs, rhs),
BinOp::Sub => val_sub(lhs, rhs),
BinOp::Mul => val_mul(lhs, rhs),
BinOp::Div => val_div(lhs, rhs),
BinOp::FlDiv => val_fldiv(lhs, rhs),
BinOp::Mod => val_rem(lhs, rhs),
BinOp::Pow => val_pow(lhs, rhs),
BinOp::BitAnd => val_bitand(lhs, rhs),
BinOp::BitOr => val_bitor(lhs, rhs),
BinOp::BitXor => val_bitxor(lhs, rhs),
BinOp::Shl => val_shl(lhs, rhs),
BinOp::Shr => val_shr(lhs, rhs),
BinOp::Conc => val_conc(lhs, rhs),
BinOp::Eq => val_eq(lhs, rhs),
BinOp::Neq => val_neq(lhs, rhs),
BinOp::Lt => val_lt(lhs, rhs),
BinOp::Leq => val_le(lhs, rhs),
BinOp::Gt => val_gt(lhs, rhs),
BinOp::Geq => val_ge(lhs, rhs),
BinOp::And | BinOp::Or => unreachable!(),
}
}
pub(super) fn un_op(lhs: &Value, op: UnOp) -> Result<Value> {
match op {
UnOp::Minus => val_minus(lhs),
UnOp::BitNot => val_bitnot(lhs),
UnOp::Not => val_not(lhs),
UnOp::Len => val_len(lhs),
}
}
/// Generates a `Value`-level arithmetic handler named `$fn` that coerces both
/// operands with `to_number_coerce` and forwards them to the numeric
/// implementation `$num_op`.
///
/// The generated function panics when either operand cannot be coerced.
macro_rules! op_val_arith {
    ( $fn:ident, $num_op:ident ) => {
        fn $fn(lhs: &Value, rhs: &Value) -> Result<Value> {
            // Both coercions are always attempted, left operand first.
            match (lhs.to_number_coerce(), rhs.to_number_coerce()) {
                (Ok(left), Ok(right)) => $num_op(&left, &right).map(Value::Number),
                _ => panic!("both operands must be numbers or string representation of a number"),
            }
        }
    };
}

op_val_arith!(val_add, num_add);
op_val_arith!(val_sub, num_sub);
op_val_arith!(val_mul, num_mul);
op_val_arith!(val_div, num_div);
op_val_arith!(val_fldiv, num_fldiv);
op_val_arith!(val_rem, num_rem);
op_val_arith!(val_pow, num_pow);
/// Generates a `Value`-level bitwise handler named `$fn` that forwards two
/// `Value::Number` operands to the numeric implementation `$num_op`.
///
/// Unlike the arithmetic handlers, no string coercion is attempted: the
/// generated function panics unless both operands are already numbers.
macro_rules! op_val_bitw {
    ( $fn:ident, $num_op:ident ) => {
        fn $fn(lhs: &Value, rhs: &Value) -> Result<Value> {
            match (lhs, rhs) {
                (Value::Number(left), Value::Number(right)) => {
                    $num_op(left, right).map(Value::Number)
                }
                _ => panic!("both operands must be numbers"),
            }
        }
    };
}

op_val_bitw!(val_bitand, num_bitand);
op_val_bitw!(val_bitor, num_bitor);
op_val_bitw!(val_bitxor, num_bitxor);
op_val_bitw!(val_shl, num_shl);
op_val_bitw!(val_shr, num_shr);
/// Concatenates the textual forms of `lhs` and `rhs` into a new string value.
///
/// Numbers are rendered through their `Display` implementation; strings are
/// copied as-is.
///
/// # Panics
/// Panics if either operand is neither a number nor a string. The left
/// operand is checked first.
fn val_conc(lhs: &Value, rhs: &Value) -> Result<Value> {
    // Renders a single concatenation operand as owned text.
    fn as_text(operand: &Value) -> String {
        match operand {
            Value::Number(n) => n.to_string(),
            Value::String(s) => s.to_string(),
            _ => panic!(),
        }
    }

    let mut joined = as_text(lhs);
    joined.push_str(&as_text(rhs));
    Ok(Value::String(joined))
}
/// Tests two values for equality and returns the outcome as `Value::Bool`.
///
/// Two numbers are compared by mathematical value:
/// * if both coerce to an integer (`coerce_int`), the integers are compared;
/// * if exactly one coerces, the numbers are unequal;
/// * otherwise both are compared as floats with exact IEEE equality, so `NaN`
///   is unequal to everything, including itself.
///
/// Any other pair of values falls back to `Value`'s `PartialEq`.
///
/// The float comparison is deliberately exact. An absolute tolerance would
/// make distinct small floats compare equal and would break the transitivity
/// of `==`.
// Exact float comparison is the intended semantics here, not an accident.
#[allow(clippy::float_cmp)]
fn val_eq(lhs: &Value, rhs: &Value) -> Result<Value> {
    let equal = match (lhs, rhs) {
        (Value::Number(n1), Value::Number(n2)) => match (n1.coerce_int(), n2.coerce_int()) {
            (Ok(i1), Ok(i2)) => i1 == i2,
            // Presumably one side has an exact integer value and the other does not.
            (Ok(_), Err(_)) | (Err(_), Ok(_)) => false,
            (Err(_), Err(_)) => n1.to_float() == n2.to_float(),
        },
        (v1, v2) => v1 == v2,
    };
    Ok(Value::Bool(equal))
}
/// Tests two values for inequality by negating the result of `BinOp::Eq`.
///
/// # Panics
/// Panics (via `unreachable!`) if the equality handler ever returns something
/// other than `Value::Bool`.
fn val_neq(lhs: &Value, rhs: &Value) -> Result<Value> {
    match bin_op(lhs, BinOp::Eq, rhs)? {
        Value::Bool(equal) => Ok(Value::Bool(!equal)),
        _ => unreachable!(),
    }
}
/// Generates a `Value`-level ordering handler named `$fn`.
///
/// Two numbers are compared with the numeric implementation `$num_op`; two
/// strings are compared with the `PartialOrd` method `$str_op`. The generated
/// function panics for every other combination of operands.
macro_rules! op_val_comp {
    ( $fn:ident, $num_op:ident, $str_op:ident ) => {
        fn $fn(lhs: &Value, rhs: &Value) -> Result<Value> {
            let holds = match (lhs, rhs) {
                (Value::Number(lhs), Value::Number(rhs)) => $num_op(lhs, rhs)?,
                (Value::String(lhs), Value::String(rhs)) => lhs.$str_op(&rhs),
                _ => panic!(),
            };
            Ok(Value::Bool(holds))
        }
    };
}

op_val_comp!(val_lt, num_lt, lt);
op_val_comp!(val_le, num_le, le);
op_val_comp!(val_gt, num_gt, gt);
op_val_comp!(val_ge, num_ge, ge);
/// Arithmetically negates `op` after coercing it with `to_number_coerce`.
///
/// # Panics
/// Panics if the operand cannot be coerced to a number.
fn val_minus(op: &Value) -> Result<Value> {
    match op.to_number_coerce() {
        Ok(number) => Ok(Value::Number(num_minus(&number))),
        Err(_) => panic!("operand must be number or string representing a number"),
    }
}
/// Computes the bitwise complement of a numeric operand.
///
/// # Errors
/// Propagates the error from `num_bitnot` when the number has no integer form.
///
/// # Panics
/// Panics if the operand is not a `Value::Number`.
fn val_bitnot(op: &Value) -> Result<Value> {
    if let Value::Number(number) = op {
        Ok(Value::Number(num_bitnot(number)?))
    } else {
        panic!("operand must be number")
    }
}
/// Logical negation: yields `true` exactly when the operand is falsy.
fn val_not(op: &Value) -> Result<Value> {
    let negated = op.is_falsy();
    Ok(Value::Bool(negated))
}
/// Returns the length of a string (its `len()`) or of a table (its
/// `border()`) as an integer value.
///
/// # Panics
/// Panics if the operand is neither a string nor a table.
fn val_len(op: &Value) -> Result<Value> {
    let length = match op {
        Value::String(s) => s.len() as i64,
        Value::Table(t) => t.borrow().border() as i64,
        _ => panic!("operand must be string or table"),
    };
    Ok(Value::int(length))
}
/// Generates a numeric arithmetic function named `$fn`.
///
/// When both operands convert with `to_int`, the integer method `$int_op` is
/// applied and the result stays an integer. Otherwise both operands are
/// converted with `to_float` and combined with the float method `$flt_op`.
macro_rules! op_num_arith {
    ( $fn:ident, $int_op:ident, $flt_op:ident ) => {
        fn $fn(lhs: &Numeric, rhs: &Numeric) -> Result<Numeric> {
            let outcome = match (lhs.to_int(), rhs.to_int()) {
                (Ok(left), Ok(right)) => Numeric::Integer(left.$int_op(right)),
                _ => Numeric::Float(lhs.to_float().$flt_op(rhs.to_float())),
            };
            Ok(outcome)
        }
    };
}

op_num_arith!(num_add, wrapping_add, add);
op_num_arith!(num_sub, wrapping_sub, sub);
op_num_arith!(num_mul, wrapping_mul, mul);
/// Computes the floored modulo `lhs % rhs`: a non-zero result always takes
/// the sign of the divisor.
///
/// Two `Numeric::Integer` operands produce an integer; any other combination
/// is computed in floating point.
///
/// # Errors
/// Returns `LuaError::DivideByZero` when both operands are integers and the
/// divisor is zero. A zero float divisor yields `NaN` rather than an error.
fn num_rem(lhs: &Numeric, rhs: &Numeric) -> Result<Numeric> {
    if let (Numeric::Integer(lhs), Numeric::Integer(rhs)) = (lhs, rhs) {
        if *rhs == 0 {
            return err!(LuaError::DivideByZero);
        }
        // `wrapping_rem` avoids the overflow trap of `i64::MIN % -1`.
        let mut rem = lhs.wrapping_rem(*rhs);
        // The truncated remainder carries the dividend's sign; shift it into
        // the divisor's sign range when the two disagree.
        if rem != 0 && (rem ^ *rhs) < 0 {
            rem += *rhs;
        }
        Ok(Numeric::Integer(rem))
    } else {
        let lhs = lhs.to_float();
        let rhs = rhs.to_float();
        let mut rem = lhs % rhs;
        // Compare against exact zero: a tolerance would leave tiny non-zero
        // remainders with the wrong sign.
        if rem != 0.0 && (rem < 0.0) != (rhs < 0.0) {
            rem += rhs;
        }
        Ok(Numeric::Float(rem))
    }
}
/// Divides `lhs` by `rhs` in floating point; the result is always a float,
/// even for two integer operands.
fn num_div(lhs: &Numeric, rhs: &Numeric) -> Result<Numeric> {
    let quotient = lhs.to_float().div(rhs.to_float());
    Ok(Numeric::Float(quotient))
}
/// Floor division: the quotient rounded towards negative infinity.
///
/// Two `Numeric::Integer` operands produce an integer; any other combination
/// is divided in floating point and then floored.
///
/// # Errors
/// Returns `LuaError::DivideByZero` when both operands are integers and the
/// divisor is zero.
fn num_fldiv(lhs: &Numeric, rhs: &Numeric) -> Result<Numeric> {
    match (lhs, rhs) {
        (Numeric::Integer(dividend), Numeric::Integer(divisor)) => {
            if *divisor == 0 {
                return err!(LuaError::DivideByZero);
            }
            // Wrapping keeps `i64::MIN / -1` from trapping.
            let truncated = dividend.wrapping_div(*divisor);
            // Truncation rounds towards zero; for operands of opposite sign
            // with a non-zero remainder that is one above the floor.
            let needs_adjustment = (dividend ^ divisor) < 0 && dividend % divisor != 0;
            let floored = if needs_adjustment {
                truncated - 1
            } else {
                truncated
            };
            Ok(Numeric::Integer(floored))
        }
        _ => Ok(Numeric::Float(lhs.to_float().div(rhs.to_float()).floor())),
    }
}
/// Raises `lhs` to the power `rhs` in floating point; the result is always a
/// float.
fn num_pow(lhs: &Numeric, rhs: &Numeric) -> Result<Numeric> {
    let base = lhs.to_float();
    let exponent = rhs.to_float();
    Ok(Numeric::Float(base.powf(exponent)))
}
/// Generates a numeric bitwise function named `$fn` that coerces both
/// operands with `coerce_int` (left operand first) and combines them with the
/// integer method `$op`. A failed coercion is propagated as the error.
macro_rules! op_num_bitw {
    ( $fn:ident, $op:ident ) => {
        fn $fn(lhs: &Numeric, rhs: &Numeric) -> Result<Numeric> {
            let bits = lhs.coerce_int()?.$op(rhs.coerce_int()?);
            Ok(Numeric::Integer(bits))
        }
    };
}

op_num_bitw!(num_bitand, bitand);
op_num_bitw!(num_bitor, bitor);
op_num_bitw!(num_bitxor, bitxor);
/// Logically shifts the 64-bit pattern of `value` by `shift` positions.
///
/// A positive `shift` moves bits to the left, a negative one to the right.
/// Vacated bits are filled with zeros, and a displacement of 64 or more in
/// either direction yields `0`.
fn shift_logical(value: i64, shift: i64) -> i64 {
    // Work on the unsigned pattern so right shifts never sign-extend.
    let bits = value as u64;
    match shift {
        0..=63 => (bits << shift) as i64,
        -63..=-1 => (bits >> (-shift)) as i64,
        _ => 0,
    }
}

/// Shifts `lhs` left by `rhs` bits; a negative `rhs` shifts right instead.
///
/// The displacement is compared as a full 64-bit integer, so amounts of 64 or
/// more (including those beyond `u32::MAX`) reliably produce `0`.
///
/// # Errors
/// Propagates the `coerce_int` error when either operand has no integer
/// form. The shift amount is checked first.
fn num_shl(lhs: &Numeric, rhs: &Numeric) -> Result<Numeric> {
    let shift = rhs.coerce_int()?;
    let value = lhs.coerce_int()?;
    Ok(Numeric::Integer(shift_logical(value, shift)))
}

/// Shifts `lhs` right by `rhs` bits, filling with zeros; a negative `rhs`
/// shifts left instead.
///
/// # Errors
/// Propagates the `coerce_int` error when either operand has no integer
/// form. The shift amount is checked first.
fn num_shr(lhs: &Numeric, rhs: &Numeric) -> Result<Numeric> {
    let shift = rhs.coerce_int()?;
    let value = lhs.coerce_int()?;
    // `i64::MIN` cannot be negated; saturate instead, since any displacement
    // of that magnitude clears every bit anyway.
    let leftward = shift.checked_neg().unwrap_or(i64::MAX);
    Ok(Numeric::Integer(shift_logical(value, leftward)))
}
/// Tests `lhs < rhs` without losing precision when an integer is compared
/// with a float.
///
/// * Both convert with `to_int`: plain integer comparison.
/// * Only `lhs` converts: for an integer `i`, `i < f` holds exactly when
///   `i < ceil(f)`. If the rounded float cannot be coerced to an integer, the
///   answer is whether it is positive.
/// * Only `rhs` converts: `f < i` holds exactly when `floor(f) < i`. If the
///   rounded float cannot be coerced to an integer, the answer is whether it
///   is negative.
/// * Neither converts: plain float comparison.
fn num_lt(lhs: &Numeric, rhs: &Numeric) -> Result<bool> {
    let is_less = match (lhs.to_int(), rhs.to_int()) {
        (Ok(left), Ok(right)) => left < right,
        (Ok(left), _) => {
            let ceiled = Numeric::Float(rhs.to_float().ceil());
            match ceiled.coerce_int() {
                Ok(bound) => left < bound,
                Err(_) => ceiled.to_float() > 0.,
            }
        }
        (_, Ok(right)) => {
            let floored = Numeric::Float(lhs.to_float().floor());
            match floored.coerce_int() {
                Ok(bound) => bound < right,
                Err(_) => floored.to_float() < 0.,
            }
        }
        _ => lhs.to_float() < rhs.to_float(),
    };
    Ok(is_less)
}
/// Tests `lhs <= rhs` without losing precision when an integer is compared
/// with a float.
///
/// * Both convert with `to_int`: plain integer comparison.
/// * Only `lhs` converts: for an integer `i`, `i <= f` holds exactly when
///   `i <= floor(f)`. If the rounded float cannot be coerced to an integer,
///   the answer is whether it is positive.
/// * Only `rhs` converts: `f <= i` holds exactly when `ceil(f) <= i`. If the
///   rounded float cannot be coerced to an integer, the answer is whether it
///   is negative.
/// * Neither converts: plain float comparison.
fn num_le(lhs: &Numeric, rhs: &Numeric) -> Result<bool> {
    let is_less_or_equal = match (lhs.to_int(), rhs.to_int()) {
        (Ok(left), Ok(right)) => left <= right,
        (Ok(left), _) => {
            let floored = Numeric::Float(rhs.to_float().floor());
            match floored.coerce_int() {
                Ok(bound) => left <= bound,
                Err(_) => floored.to_float() > 0.,
            }
        }
        (_, Ok(right)) => {
            let ceiled = Numeric::Float(lhs.to_float().ceil());
            match ceiled.coerce_int() {
                Ok(bound) => bound <= right,
                Err(_) => ceiled.to_float() < 0.,
            }
        }
        _ => lhs.to_float() <= rhs.to_float(),
    };
    Ok(is_less_or_equal)
}
/// Tests `lhs > rhs`, expressed as `rhs < lhs` so the logic of `num_lt` is
/// reused.
fn num_gt(lhs: &Numeric, rhs: &Numeric) -> Result<bool> {
    let (smaller, larger) = (rhs, lhs);
    num_lt(smaller, larger)
}
/// Tests `lhs >= rhs`, expressed as `rhs <= lhs` so the logic of `num_le` is
/// reused.
fn num_ge(lhs: &Numeric, rhs: &Numeric) -> Result<bool> {
    let (smaller, larger) = (rhs, lhs);
    num_le(smaller, larger)
}
/// Negates a number while keeping its representation: integers negate with
/// wrap-around (so `i64::MIN` maps to itself), floats flip their sign.
fn num_minus(op: &Numeric) -> Numeric {
    match *op {
        Numeric::Float(value) => Numeric::Float(-value),
        Numeric::Integer(value) => Numeric::Integer(value.wrapping_neg()),
    }
}
/// Computes the bitwise complement of `op` after coercing it with
/// `coerce_int`; a failed coercion is propagated as the error.
fn num_bitnot(op: &Numeric) -> Result<Numeric> {
    let bits = op.coerce_int()?;
    Ok(Numeric::Integer(!bits))
}