use super::{Expr, ExprEvaluator};
use crate::err::*;
use crate::eval::*;
use crate::value::*;
/// Expression node that applies a [`BinaryOperator`] to two operands, each
/// referenced by its [`Index`].
#[derive(Clone, Debug)]
pub struct BinaryOpExpr {
    /// Index of the left-hand operand.
    pub left: Index,
    /// Operator combining the two operand values.
    pub op: BinaryOperator,
    /// Index of the right-hand operand.
    pub right: Index,
}
impl Expr for BinaryOpExpr {
fn resolve(
&self,
index: Index,
evaluation: &mut dyn ExprEvaluator,
) -> Result<ValueOrReference, ErrorKind> {
evaluation.ensure_resolved(vec![self.left, self.right])?;
let left = evaluation.get_value(self.left)?;
let right = evaluation.get_value(self.right)?;
self.op
.apply(left, right)
.map(Into::into)
.map_err(|op_error| ErrorKind::BinaryOperationNotAllowed {
index,
left: self.left,
right: self.right,
error: op_error,
})
}
}
/// Operators that combine two values. `BinaryOperator::apply` defines which
/// operand types each operator accepts.
///
/// Variant order is significant for `as` casts, so new variants should be
/// appended rather than inserted.
#[derive(Clone, Copy, Debug)]
pub enum BinaryOperator {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `%`
    Mod,
    /// `==`
    Eq,
    /// `!=`
    NotEq,
    /// `<=`
    LessEq,
    /// `>=`
    GreaterEq,
    /// `<`
    Less,
    /// `>`
    Greater,
    /// `&&`
    And,
    /// `||`
    Or,
}
impl BinaryOperator {
pub fn verb(&self) -> &'static str {
use BinaryOperator::*;
match self {
Add => "add",
Sub => "subtract",
Mul => "multiply",
Div => "divide",
Mod => "modulo",
Eq | NotEq | LessEq | GreaterEq | Less | Greater => "compare",
And | Or => "logically chain",
}
}
pub fn apply(self, left: Value, right: Value) -> Result<Value, OpError> {
match self {
BinaryOperator::Add => match (&left, &right) {
(Value::Integer(l), Value::Integer(r)) => Ok(Value::Integer(l + r)),
(Value::Decimal(l), Value::Decimal(r)) => Ok(Value::Decimal(l + r)),
(Value::Integer(l), Value::Decimal(r)) => Ok(Value::Decimal(*l as f64 + r)),
(Value::Decimal(l), Value::Integer(r)) => Ok(Value::Decimal(l + *r as f64)),
(Value::String(l), Value::String(r)) => Ok(Value::String(l.clone() + r)),
_ => Err(OpError::BinaryOpTypeNotAllowed {
left: left.into(),
op: self,
right: right.into(),
}),
},
BinaryOperator::Sub => match (&left, &right) {
(Value::Integer(l), Value::Integer(r)) => Ok(Value::Integer(l - r)),
(Value::Decimal(l), Value::Decimal(r)) => Ok(Value::Decimal(l - r)),
(Value::Integer(l), Value::Decimal(r)) => Ok(Value::Decimal(*l as f64 - r)),
(Value::Decimal(l), Value::Integer(r)) => Ok(Value::Decimal(l - *r as f64)),
_ => Err(OpError::BinaryOpTypeNotAllowed {
left: left.into(),
op: self,
right: right.into(),
}),
},
BinaryOperator::Mul => match (&left, &right) {
(Value::Integer(l), Value::Integer(r)) => Ok(Value::Integer(l * r)),
(Value::Decimal(l), Value::Decimal(r)) => Ok(Value::Decimal(l * r)),
(Value::Integer(l), Value::Decimal(r)) => Ok(Value::Decimal(*l as f64 * r)),
(Value::Decimal(l), Value::Integer(r)) => Ok(Value::Decimal(l * *r as f64)),
_ => Err(OpError::BinaryOpTypeNotAllowed {
left: left.into(),
op: self,
right: right.into(),
}),
},
BinaryOperator::Div => match (&left, &right) {
(Value::Integer(l), Value::Integer(r)) => Ok(Value::Integer(l / r)),
(Value::Decimal(l), Value::Decimal(r)) => Ok(Value::Decimal(l / r)),
(Value::Integer(l), Value::Decimal(r)) => Ok(Value::Decimal(*l as f64 / r)),
(Value::Decimal(l), Value::Integer(r)) => Ok(Value::Decimal(l / *r as f64)),
_ => Err(OpError::BinaryOpTypeNotAllowed {
left: left.into(),
op: self,
right: right.into(),
}),
},
BinaryOperator::Mod => match (&left, &right) {
(Value::Integer(l), Value::Integer(r)) => Ok(Value::Integer(l % r)),
(Value::Decimal(l), Value::Decimal(r)) => Ok(Value::Decimal(l % r)),
(Value::Integer(l), Value::Decimal(r)) => Ok(Value::Decimal(*l as f64 % r)),
(Value::Decimal(l), Value::Integer(r)) => Ok(Value::Decimal(l % *r as f64)),
_ => Err(OpError::BinaryOpTypeNotAllowed {
left: left.into(),
op: self,
right: right.into(),
}),
},
BinaryOperator::Eq => match (&left, &right) {
(Value::Boolean(l), Value::Boolean(r)) => Ok(Value::Boolean(l == r)),
(Value::Integer(l), Value::Integer(r)) => Ok(Value::Boolean(l == r)),
(Value::Decimal(l), Value::Decimal(r)) => Ok(Value::Boolean(l == r)),
(Value::Integer(l), Value::Decimal(r)) => Ok(Value::Boolean(*l as f64 == *r)),
(Value::Decimal(l), Value::Integer(r)) => Ok(Value::Boolean(*l == *r as f64)),
(Value::String(l), Value::String(r)) => Ok(Value::Boolean(l == r)),
_ => Err(OpError::BinaryOpTypeNotAllowed {
left: left.into(),
op: self,
right: right.into(),
}),
},
BinaryOperator::NotEq => match (&left, &right) {
(Value::Boolean(l), Value::Boolean(r)) => Ok(Value::Boolean(l != r)),
(Value::Integer(l), Value::Integer(r)) => Ok(Value::Boolean(l != r)),
(Value::Decimal(l), Value::Decimal(r)) => Ok(Value::Boolean(l != r)),
(Value::Integer(l), Value::Decimal(r)) => Ok(Value::Boolean(*l as f64 != *r)),
(Value::Decimal(l), Value::Integer(r)) => Ok(Value::Boolean(*l != *r as f64)),
(Value::String(l), Value::String(r)) => Ok(Value::Boolean(l != r)),
_ => Err(OpError::BinaryOpTypeNotAllowed {
left: left.into(),
op: self,
right: right.into(),
}),
},
BinaryOperator::Less => match (&left, &right) {
(Value::Integer(l), Value::Integer(r)) => Ok(Value::Boolean(l < r)),
(Value::Decimal(l), Value::Decimal(r)) => Ok(Value::Boolean(l < r)),
(Value::Integer(l), Value::Decimal(r)) => Ok(Value::Boolean((*l as f64) < *r)),
(Value::Decimal(l), Value::Integer(r)) => Ok(Value::Boolean(*l < *r as f64)),
(Value::String(l), Value::String(r)) => Ok(Value::Boolean(l < r)),
_ => Err(OpError::BinaryOpTypeNotAllowed {
left: left.into(),
op: self,
right: right.into(),
}),
},
BinaryOperator::LessEq => match (&left, &right) {
(Value::Integer(l), Value::Integer(r)) => Ok(Value::Boolean(l <= r)),
(Value::Decimal(l), Value::Decimal(r)) => Ok(Value::Boolean(l <= r)),
(Value::Integer(l), Value::Decimal(r)) => Ok(Value::Boolean((*l as f64) <= *r)),
(Value::Decimal(l), Value::Integer(r)) => Ok(Value::Boolean(*l <= *r as f64)),
(Value::String(l), Value::String(r)) => Ok(Value::Boolean(l <= r)),
_ => Err(OpError::BinaryOpTypeNotAllowed {
left: left.into(),
op: self,
right: right.into(),
}),
},
BinaryOperator::Greater => match (&left, &right) {
(Value::Integer(l), Value::Integer(r)) => Ok(Value::Boolean(l > r)),
(Value::Decimal(l), Value::Decimal(r)) => Ok(Value::Boolean(l > r)),
(Value::Integer(l), Value::Decimal(r)) => Ok(Value::Boolean((*l as f64) > *r)),
(Value::Decimal(l), Value::Integer(r)) => Ok(Value::Boolean(*l > *r as f64)),
(Value::String(l), Value::String(r)) => Ok(Value::Boolean(l > r)),
_ => Err(OpError::BinaryOpTypeNotAllowed {
left: left.into(),
op: self,
right: right.into(),
}),
},
BinaryOperator::GreaterEq => match (&left, &right) {
(Value::Integer(l), Value::Integer(r)) => Ok(Value::Boolean(l >= r)),
(Value::Decimal(l), Value::Decimal(r)) => Ok(Value::Boolean(l >= r)),
(Value::Integer(l), Value::Decimal(r)) => Ok(Value::Boolean((*l as f64) >= *r)),
(Value::Decimal(l), Value::Integer(r)) => Ok(Value::Boolean(*l >= *r as f64)),
(Value::String(l), Value::String(r)) => Ok(Value::Boolean(l >= r)),
_ => Err(OpError::BinaryOpTypeNotAllowed {
left: left.into(),
op: self,
right: right.into(),
}),
},
BinaryOperator::And => match (&left, &right) {
(Value::Boolean(l), Value::Boolean(r)) => Ok(Value::Boolean(*l && *r)),
_ => Err(OpError::BinaryOpTypeNotAllowed {
left: left.into(),
op: self,
right: right.into(),
}),
},
BinaryOperator::Or => match (&left, &right) {
(Value::Boolean(l), Value::Boolean(r)) => Ok(Value::Boolean(*l || *r)),
_ => Err(OpError::BinaryOpTypeNotAllowed {
left: left.into(),
op: self,
right: right.into(),
}),
},
}
}
}
impl std::fmt::Display for BinaryOperator {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
BinaryOperator::Add => write!(f, "+"),
BinaryOperator::Sub => write!(f, "-"),
BinaryOperator::Mul => write!(f, "*"),
BinaryOperator::Div => write!(f, "/"),
BinaryOperator::Mod => write!(f, "%"),
BinaryOperator::Eq => write!(f, "=="),
BinaryOperator::NotEq => write!(f, "!="),
BinaryOperator::LessEq => write!(f, "<="),
BinaryOperator::GreaterEq => write!(f, ">="),
BinaryOperator::Less => write!(f, "<"),
BinaryOperator::Greater => write!(f, ">"),
BinaryOperator::And => write!(f, "&&"),
BinaryOperator::Or => write!(f, "||"),
}
}
}
#[derive(thiserror::Error, Clone, Debug)]
pub enum OpError {
#[error("Binary operation {op} not allowed between {left} and {right}")]
BinaryOpTypeNotAllowed {
left: ValueKind,
op: BinaryOperator,
right: ValueKind,
},
}