use rust_decimal::Decimal;
use crate::ast::{BinaryOp, BinaryOperator, UnaryOp, UnaryOperator};
use crate::error::QueryError;
use super::Executor;
use super::types::{Interval, PostingContext, Value};
impl Executor<'_> {
pub(super) fn evaluate_binary_op(
&self,
op: &BinaryOp,
ctx: &PostingContext,
) -> Result<Value, QueryError> {
let left = self.evaluate_expr(&op.left, ctx)?;
let right = self.evaluate_expr(&op.right, ctx)?;
match op.op {
BinaryOperator::Eq => Ok(Value::Boolean(self.values_equal(&left, &right))),
BinaryOperator::Ne => Ok(Value::Boolean(!self.values_equal(&left, &right))),
BinaryOperator::Lt => self.compare_values(&left, &right, std::cmp::Ordering::is_lt),
BinaryOperator::Le => self.compare_values(&left, &right, std::cmp::Ordering::is_le),
BinaryOperator::Gt => self.compare_values(&left, &right, std::cmp::Ordering::is_gt),
BinaryOperator::Ge => self.compare_values(&left, &right, std::cmp::Ordering::is_ge),
BinaryOperator::And => {
let l = self.to_bool(&left)?;
let r = self.to_bool(&right)?;
Ok(Value::Boolean(l && r))
}
BinaryOperator::Or => {
let l = self.to_bool(&left)?;
let r = self.to_bool(&right)?;
Ok(Value::Boolean(l || r))
}
BinaryOperator::Regex => {
let s = match left {
Value::String(s) => s,
Value::Null => return Ok(Value::Boolean(false)),
_ => {
return Err(QueryError::Type(
"regex requires string left operand".to_string(),
));
}
};
let pattern = match right {
Value::String(p) => p,
_ => {
return Err(QueryError::Type(
"regex requires string pattern".to_string(),
));
}
};
let re = self.require_regex(&pattern)?;
Ok(Value::Boolean(re.is_match(&s)))
}
BinaryOperator::In => {
match right {
Value::StringSet(set) => {
let needle = match left {
Value::String(s) => s,
_ => {
return Err(QueryError::Type(
"IN requires string left operand".to_string(),
));
}
};
Ok(Value::Boolean(set.contains(&needle)))
}
_ => Err(QueryError::Type(
"IN requires set right operand".to_string(),
)),
}
}
BinaryOperator::NotRegex => {
let s = match left {
Value::String(s) => s,
Value::Null => return Ok(Value::Boolean(true)),
_ => {
return Err(QueryError::Type(
"!~ requires string left operand".to_string(),
));
}
};
let pattern = match right {
Value::String(p) => p,
_ => {
return Err(QueryError::Type("!~ requires string pattern".to_string()));
}
};
let re = self.require_regex(&pattern)?;
Ok(Value::Boolean(!re.is_match(&s)))
}
BinaryOperator::NotIn => {
match right {
Value::StringSet(set) => {
let needle = match left {
Value::String(s) => s,
_ => {
return Err(QueryError::Type(
"NOT IN requires string left operand".to_string(),
));
}
};
Ok(Value::Boolean(!set.contains(&needle)))
}
_ => Err(QueryError::Type(
"NOT IN requires set right operand".to_string(),
)),
}
}
BinaryOperator::Add => {
match (&left, &right) {
(Value::Date(d), Value::Interval(i)) | (Value::Interval(i), Value::Date(d)) => {
i.add_to_date(*d)
.map(Value::Date)
.ok_or_else(|| QueryError::Evaluation("date overflow".to_string()))
}
_ => self.arithmetic_op(&left, &right, |a, b| a + b),
}
}
BinaryOperator::Sub => {
match (&left, &right) {
(Value::Date(d), Value::Interval(i)) => {
let neg_count = i.count.checked_neg().ok_or_else(|| {
QueryError::Evaluation("interval count overflow".to_string())
})?;
let neg_interval = Interval::new(neg_count, i.unit);
neg_interval
.add_to_date(*d)
.map(Value::Date)
.ok_or_else(|| QueryError::Evaluation("date overflow".to_string()))
}
_ => self.arithmetic_op(&left, &right, |a, b| a - b),
}
}
BinaryOperator::Mul => self.arithmetic_op(&left, &right, |a, b| a * b),
BinaryOperator::Div => self.arithmetic_op(&left, &right, |a, b| a / b),
BinaryOperator::Mod => self.arithmetic_op(&left, &right, |a, b| a % b),
}
}
/// Evaluates a unary expression against a posting context.
///
/// The operand expression is evaluated first; the operator is then applied
/// to the resulting value by `unary_op_on_value`.
///
/// # Errors
///
/// Returns the operand's evaluation error, or whatever error
/// `unary_op_on_value` reports for the operator/value combination.
pub(super) fn evaluate_unary_op(
    &self,
    op: &UnaryOp,
    ctx: &PostingContext,
) -> Result<Value, QueryError> {
    self.evaluate_expr(&op.operand, ctx)
        .and_then(|operand| self.unary_op_on_value(op.op, &operand))
}
/// Applies a unary operator to an already-evaluated value.
///
/// - `Not`: converts the value with `to_bool` and inverts it.
/// - `Neg`: negates a `Number` or an `Integer`.
/// - `IsNull` / `IsNotNull`: test whether the value is `Value::Null`; these
///   never fail.
///
/// # Errors
///
/// - `QueryError::Type` when `Not` is applied to a value `to_bool` rejects,
///   or `Neg` is applied to a non-numeric value.
/// - `QueryError::Evaluation` when negating an `Integer` would overflow.
pub(super) fn unary_op_on_value(
    &self,
    op: UnaryOperator,
    val: &Value,
) -> Result<Value, QueryError> {
    match op {
        UnaryOperator::Not => self.to_bool(val).map(|b| Value::Boolean(!b)),
        UnaryOperator::Neg => match val {
            Value::Number(n) => Ok(Value::Number(-*n)),
            // The smallest representable integer has no positive counterpart:
            // a plain `-` panics in debug builds and silently wraps in release
            // builds. Report it as an evaluation error instead, the same way
            // interval negation is handled for date subtraction.
            Value::Integer(i) => i
                .checked_neg()
                .map(Value::Integer)
                .ok_or_else(|| QueryError::Evaluation("integer overflow".to_string())),
            _ => Err(QueryError::Type(
                "negation requires numeric value".to_string(),
            )),
        },
        UnaryOperator::IsNull => Ok(Value::Boolean(matches!(val, Value::Null))),
        UnaryOperator::IsNotNull => Ok(Value::Boolean(!matches!(val, Value::Null))),
    }
}
/// Reports whether two values are equal for the `=` / `!=` operators.
///
/// Same-variant pairs of `Null`, `String`, `Number`, `Integer`, `Date` and
/// `Boolean` compare by value. A `Number` and an `Integer` compare
/// numerically after the integer is widened to a `Decimal`. Every other
/// pairing — including `Null` against a non-null value and any variant not
/// listed above — is reported as not equal.
pub(super) fn values_equal(&self, left: &Value, right: &Value) -> bool {
    match (left, right) {
        (Value::Null, Value::Null) => true,
        (Value::Boolean(lhs), Value::Boolean(rhs)) => lhs == rhs,
        (Value::Date(lhs), Value::Date(rhs)) => lhs == rhs,
        (Value::String(lhs), Value::String(rhs)) => lhs == rhs,
        (Value::Integer(lhs), Value::Integer(rhs)) => lhs == rhs,
        (Value::Number(lhs), Value::Number(rhs)) => lhs == rhs,
        // Mixed numeric comparison is symmetric, so one arm covers both orders.
        (Value::Number(dec), Value::Integer(int)) | (Value::Integer(int), Value::Number(dec)) => {
            Decimal::from(*int) == *dec
        }
        _ => false,
    }
}
/// Orders two values and turns the ordering into a boolean `Value`.
///
/// `pred` receives the ordering of `left` relative to `right` (for example
/// `Ordering::is_lt` for `<`) and its result is wrapped in `Value::Boolean`.
///
/// Comparable pairs are `Number`/`Integer` in any combination (integers are
/// widened to `Decimal` when mixed), `String` with `String`, and `Date` with
/// `Date`.
///
/// # Errors
///
/// Returns `QueryError::Type` for any other pairing; `pred` is not called in
/// that case.
pub(super) fn compare_values<F>(
    &self,
    left: &Value,
    right: &Value,
    pred: F,
) -> Result<Value, QueryError>
where
    F: FnOnce(std::cmp::Ordering) -> bool,
{
    let ordering = match (left, right) {
        (Value::Date(lhs), Value::Date(rhs)) => Some(lhs.cmp(rhs)),
        (Value::String(lhs), Value::String(rhs)) => Some(lhs.cmp(rhs)),
        (Value::Integer(lhs), Value::Integer(rhs)) => Some(lhs.cmp(rhs)),
        (Value::Number(lhs), Value::Number(rhs)) => Some(lhs.cmp(rhs)),
        (Value::Number(lhs), Value::Integer(rhs)) => Some(lhs.cmp(&Decimal::from(*rhs))),
        (Value::Integer(lhs), Value::Number(rhs)) => Some(Decimal::from(*lhs).cmp(rhs)),
        _ => None,
    };

    ordering
        .map(|ord| Value::Boolean(pred(ord)))
        .ok_or_else(|| QueryError::Type("cannot compare values".to_string()))
}
/// Reports whether `left` orders strictly before `right`.
///
/// Accepts the same pairings as `compare_values`: `Number`/`Integer` in any
/// combination (integers are widened to `Decimal` when mixed), `String` with
/// `String`, and `Date` with `Date`.
///
/// # Errors
///
/// Returns `QueryError::Type` for any other pairing.
pub(super) fn value_less_than(&self, left: &Value, right: &Value) -> Result<bool, QueryError> {
    let ordering: Result<std::cmp::Ordering, QueryError> = match (left, right) {
        (Value::Date(lhs), Value::Date(rhs)) => Ok(lhs.cmp(rhs)),
        (Value::String(lhs), Value::String(rhs)) => Ok(lhs.cmp(rhs)),
        (Value::Integer(lhs), Value::Integer(rhs)) => Ok(lhs.cmp(rhs)),
        (Value::Number(lhs), Value::Number(rhs)) => Ok(lhs.cmp(rhs)),
        (Value::Number(lhs), Value::Integer(rhs)) => Ok(lhs.cmp(&Decimal::from(*rhs))),
        (Value::Integer(lhs), Value::Number(rhs)) => Ok(Decimal::from(*lhs).cmp(rhs)),
        _ => Err(QueryError::Type("cannot compare values".to_string())),
    };

    ordering.map(std::cmp::Ordering::is_lt)
}
/// Applies a decimal arithmetic operation to two numeric values.
///
/// Each operand must be a `Number` or an `Integer`; integers are widened to
/// `Decimal` before `op` is called. The result is always returned as
/// `Value::Number`, even when both operands were integers.
///
/// # Errors
///
/// Returns `QueryError::Type` if either operand is not numeric; `op` is not
/// called in that case.
pub(super) fn arithmetic_op<F>(
    &self,
    left: &Value,
    right: &Value,
    op: F,
) -> Result<Value, QueryError>
where
    F: FnOnce(Decimal, Decimal) -> Decimal,
{
    // Widen a numeric value to `Decimal`; anything else is not an operand.
    let as_decimal = |value: &Value| match value {
        Value::Number(n) => Some(*n),
        Value::Integer(i) => Some(Decimal::from(*i)),
        _ => None,
    };

    match (as_decimal(left), as_decimal(right)) {
        (Some(lhs), Some(rhs)) => Ok(Value::Number(op(lhs, rhs))),
        _ => Err(QueryError::Type(
            "arithmetic requires numeric values".to_string(),
        )),
    }
}
/// Interprets a value as a boolean condition.
///
/// `Boolean` values map to themselves and `Null` is treated as `false`.
///
/// # Errors
///
/// Returns `QueryError::Type` for every other kind of value.
pub(super) fn to_bool(&self, val: &Value) -> Result<bool, QueryError> {
    if let Value::Boolean(flag) = val {
        return Ok(*flag);
    }
    if matches!(val, Value::Null) {
        return Ok(false);
    }
    Err(QueryError::Type("expected boolean".to_string()))
}
pub(super) fn binary_op_on_values(
&self,
op: BinaryOperator,
left: &Value,
right: &Value,
) -> Result<Value, QueryError> {
match op {
BinaryOperator::Eq => Ok(Value::Boolean(self.values_equal(left, right))),
BinaryOperator::Ne => Ok(Value::Boolean(!self.values_equal(left, right))),
BinaryOperator::Lt => self.compare_values(left, right, std::cmp::Ordering::is_lt),
BinaryOperator::Le => self.compare_values(left, right, std::cmp::Ordering::is_le),
BinaryOperator::Gt => self.compare_values(left, right, std::cmp::Ordering::is_gt),
BinaryOperator::Ge => self.compare_values(left, right, std::cmp::Ordering::is_ge),
BinaryOperator::And => {
let l = self.to_bool(left)?;
let r = self.to_bool(right)?;
Ok(Value::Boolean(l && r))
}
BinaryOperator::Or => {
let l = self.to_bool(left)?;
let r = self.to_bool(right)?;
Ok(Value::Boolean(l || r))
}
BinaryOperator::Regex => {
let s = match left {
Value::String(s) => s,
Value::Null => return Ok(Value::Boolean(false)),
_ => {
return Err(QueryError::Type(
"regex requires string left operand".to_string(),
));
}
};
let pattern = match right {
Value::String(p) => p,
_ => {
return Err(QueryError::Type(
"regex requires string pattern".to_string(),
));
}
};
let re = self.require_regex(pattern)?;
Ok(Value::Boolean(re.is_match(s)))
}
BinaryOperator::In => {
match right {
Value::StringSet(set) => {
let needle = match left {
Value::String(s) => s,
_ => {
return Err(QueryError::Type(
"IN requires string left operand".to_string(),
));
}
};
Ok(Value::Boolean(set.contains(needle)))
}
_ => Err(QueryError::Type(
"IN requires set right operand".to_string(),
)),
}
}
BinaryOperator::NotRegex => {
let s = match left {
Value::String(s) => s,
Value::Null => return Ok(Value::Boolean(true)),
_ => {
return Err(QueryError::Type(
"!~ requires string left operand".to_string(),
));
}
};
let pattern = match right {
Value::String(p) => p,
_ => {
return Err(QueryError::Type("!~ requires string pattern".to_string()));
}
};
let re = self.require_regex(pattern)?;
Ok(Value::Boolean(!re.is_match(s)))
}
BinaryOperator::NotIn => {
match right {
Value::StringSet(set) => {
let needle = match left {
Value::String(s) => s,
_ => {
return Err(QueryError::Type(
"NOT IN requires string left operand".to_string(),
));
}
};
Ok(Value::Boolean(!set.contains(needle)))
}
_ => Err(QueryError::Type(
"NOT IN requires set right operand".to_string(),
)),
}
}
BinaryOperator::Add => {
match (left, right) {
(Value::Date(d), Value::Interval(i)) | (Value::Interval(i), Value::Date(d)) => {
i.add_to_date(*d)
.map(Value::Date)
.ok_or_else(|| QueryError::Evaluation("date overflow".to_string()))
}
_ => self.arithmetic_op(left, right, |a, b| a + b),
}
}
BinaryOperator::Sub => {
match (left, right) {
(Value::Date(d), Value::Interval(i)) => {
let neg_count = i.count.checked_neg().ok_or_else(|| {
QueryError::Evaluation("interval count overflow".to_string())
})?;
let neg_interval = Interval::new(neg_count, i.unit);
neg_interval
.add_to_date(*d)
.map(Value::Date)
.ok_or_else(|| QueryError::Evaluation("date overflow".to_string()))
}
_ => self.arithmetic_op(left, right, |a, b| a - b),
}
}
BinaryOperator::Mul => self.arithmetic_op(left, right, |a, b| a * b),
BinaryOperator::Div => self.arithmetic_op(left, right, |a, b| a / b),
BinaryOperator::Mod => self.arithmetic_op(left, right, |a, b| a % b),
}
}
/// Orders two values for sorting result rows; never fails.
///
/// - `Null` sorts after every non-null value; two nulls are equal.
/// - `Number` / `Integer` compare numerically in any combination (integers
///   are widened to `Decimal` when mixed).
/// - `String`, `Date` and `Boolean` compare with their own ordering.
/// - `Amount` compares by its `number` field only.
/// - `Position` compares by `units.number` only.
/// - `Inventory` compares by the `units.number` of its first position; an
///   inventory with no positions sorts after one that has any.
/// - `Interval` compares by `to_approx_days()`.
/// - Any other pairing (e.g. mismatched variants) is treated as equal.
pub(super) fn compare_values_for_sort(
    &self,
    left: &Value,
    right: &Value,
) -> std::cmp::Ordering {
    use std::cmp::Ordering;

    match (left, right) {
        // Null handling must come first so nulls always land at the end.
        (Value::Null, Value::Null) => Ordering::Equal,
        (Value::Null, _) => Ordering::Greater,
        (_, Value::Null) => Ordering::Less,

        (Value::Number(lhs), Value::Number(rhs)) => lhs.cmp(rhs),
        (Value::Integer(lhs), Value::Integer(rhs)) => lhs.cmp(rhs),
        (Value::Number(lhs), Value::Integer(rhs)) => lhs.cmp(&Decimal::from(*rhs)),
        (Value::Integer(lhs), Value::Number(rhs)) => Decimal::from(*lhs).cmp(rhs),

        (Value::String(lhs), Value::String(rhs)) => lhs.cmp(rhs),
        (Value::Date(lhs), Value::Date(rhs)) => lhs.cmp(rhs),
        (Value::Boolean(lhs), Value::Boolean(rhs)) => lhs.cmp(rhs),

        (Value::Amount(lhs), Value::Amount(rhs)) => lhs.number.cmp(&rhs.number),
        (Value::Position(lhs), Value::Position(rhs)) => {
            lhs.units.number.cmp(&rhs.units.number)
        }
        (Value::Inventory(lhs), Value::Inventory(rhs)) => {
            let lhs_first = lhs.positions().first().map(|pos| &pos.units.number);
            let rhs_first = rhs.positions().first().map(|pos| &pos.units.number);
            match (lhs_first, rhs_first) {
                (Some(l), Some(r)) => l.cmp(r),
                // Empty inventories go last, mirroring the null rule above.
                (Some(_), None) => Ordering::Less,
                (None, Some(_)) => Ordering::Greater,
                (None, None) => Ordering::Equal,
            }
        }
        (Value::Interval(lhs), Value::Interval(rhs)) => {
            lhs.to_approx_days().cmp(&rhs.to_approx_days())
        }

        _ => Ordering::Equal,
    }
}
}