use arrow::datatypes::i256;
use datafusion_common::{Result, ScalarValue, internal_err};
use datafusion_expr::{
Case, Expr, Like, Operator,
expr::{Between, BinaryExpr, InList},
expr_fn::{and, bitwise_and, bitwise_or, or},
};
/// Powers of ten that fit in an `i128`: `POWS_OF_TEN[i] == 10^i` for every
/// `i` in `0..38`.
///
/// [`is_one`] indexes this table by decimal scale to recognise a `Decimal128`
/// literal whose unscaled value equals `10^scale`.
pub static POWS_OF_TEN: [i128; 38] = {
    // Generated at compile time so that no entry can drift from `10^i`
    // through a mistyped literal. The largest entry, `10^37`, fits in `i128`.
    let mut table = [1_i128; 38];
    let mut exp = 1;
    while exp < table.len() {
        table[exp] = table[exp - 1] * 10;
        exp += 1;
    }
    table
};
/// Returns `true` if `needle` appears as an operand in a chain of `search_op`
/// binary expressions rooted at `expr`, e.g. `(A AND B) AND C` for
/// `search_op == Operator::And`.
///
/// Only nodes whose operator equals `search_op` are descended into; any other
/// expression is treated as a leaf and compared to `needle` for equality.
fn expr_contains_inner(expr: &Expr, needle: &Expr, search_op: Operator) -> bool {
    match expr {
        // Part of the chain: look at the left operand first, then the right.
        Expr::BinaryExpr(binary) if binary.op == search_op => {
            [&binary.left, &binary.right]
                .into_iter()
                .any(|operand| expr_contains_inner(operand, needle, search_op))
        }
        // Leaf of the chain.
        leaf => leaf == needle,
    }
}
/// Returns `true` if `needle` is found in the chain of `search_op`
/// expressions rooted at `expr` and `needle` is not volatile.
///
/// The volatility check is only evaluated once the needle has been found.
pub fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool {
    if !expr_contains_inner(expr, needle, search_op) {
        return false;
    }
    !needle.is_volatile()
}
/// Removes occurrences of `needle` from a chain of `BitwiseXor` expressions
/// rooted at `expr`, such as `A ^ (A ^ (B ^ A))`.
///
/// The number of removals is counted while the chain is rebuilt:
///
/// * if what remains is equal to `needle`, a clone of `needle` is returned;
/// * otherwise, if an even number of needles was removed, one `needle` is
///   xor-ed back onto the remainder, on the left when `is_left` is `true`
///   and on the right otherwise;
/// * otherwise the remainder is returned unchanged.
pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> Expr {
    /// Rebuilds `node` with needles dropped from its xor chain, bumping
    /// `removed` once for every xor node that loses an operand.
    fn strip_needles(node: &Expr, needle: &Expr, removed: &mut i32) -> Expr {
        let Expr::BinaryExpr(BinaryExpr {
            left,
            op: Operator::BitwiseXor,
            right,
        }) = node
        else {
            // Anything that is not an xor node is kept verbatim.
            return node.clone();
        };

        let lhs = strip_needles(left, needle, removed);
        let rhs = strip_needles(right, needle, removed);

        // The left operand is checked first, so `needle ^ needle` yields the
        // right-hand needle and counts a single removal.
        if lhs == *needle {
            *removed += 1;
            rhs
        } else if rhs == *needle {
            *removed += 1;
            lhs
        } else {
            Expr::BinaryExpr(BinaryExpr::new(
                Box::new(lhs),
                Operator::BitwiseXor,
                Box::new(rhs),
            ))
        }
    }

    let mut removed: i32 = 0;
    let remainder = strip_needles(expr, needle, &mut removed);

    if remainder == *needle {
        needle.clone()
    } else if removed % 2 != 0 {
        remainder
    } else {
        // Even number of removals: put a single needle back on the requested side.
        let (lhs, rhs) = if is_left {
            (needle.clone(), remainder)
        } else {
            (remainder, needle.clone())
        };
        Expr::BinaryExpr(BinaryExpr::new(
            Box::new(lhs),
            Operator::BitwiseXor,
            Box::new(rhs),
        ))
    }
}
/// Returns `true` if `s` is a non-null numeric literal equal to zero.
///
/// Recognised types are the signed and unsigned 8/16/32/64-bit integers,
/// `Float32`/`Float64` (compared with `==`), and `Decimal128`/`Decimal256`,
/// for which precision and scale are ignored. Every other expression,
/// including null literals, yields `false`.
pub fn is_zero(s: &Expr) -> bool {
    let Expr::Literal(value, _) = s else {
        return false;
    };
    match value {
        ScalarValue::Int8(Some(0))
        | ScalarValue::Int16(Some(0))
        | ScalarValue::Int32(Some(0))
        | ScalarValue::Int64(Some(0))
        | ScalarValue::UInt8(Some(0))
        | ScalarValue::UInt16(Some(0))
        | ScalarValue::UInt32(Some(0))
        | ScalarValue::UInt64(Some(0)) => true,
        ScalarValue::Float32(Some(v)) => *v == 0.,
        ScalarValue::Float64(Some(v)) => *v == 0.,
        // A decimal zero has an unscaled value of zero whatever the scale is.
        ScalarValue::Decimal128(Some(v), _, _) => *v == 0,
        ScalarValue::Decimal256(Some(v), _, _) => *v == i256::ZERO,
        _ => false,
    }
}
/// Returns `true` if `s` is a non-null numeric literal equal to one.
///
/// Integers and floats are compared to `1` directly. A decimal literal is
/// one when its unscaled value equals `10^scale`; the function returns
/// `false` for a negative scale, for a `Decimal128` scale outside
/// [`POWS_OF_TEN`], and for a `Decimal256` scale whose power of ten
/// overflows `i256`.
pub fn is_one(s: &Expr) -> bool {
    let Expr::Literal(value, _) = s else {
        return false;
    };
    match value {
        ScalarValue::Int8(Some(1))
        | ScalarValue::Int16(Some(1))
        | ScalarValue::Int32(Some(1))
        | ScalarValue::Int64(Some(1))
        | ScalarValue::UInt8(Some(1))
        | ScalarValue::UInt16(Some(1))
        | ScalarValue::UInt32(Some(1))
        | ScalarValue::UInt64(Some(1)) => true,
        ScalarValue::Float32(Some(v)) => *v == 1.,
        ScalarValue::Float64(Some(v)) => *v == 1.,
        ScalarValue::Decimal128(Some(v), _, scale) => {
            // The sign check must come first: the cast below would wrap a
            // negative scale into a huge index.
            *scale >= 0
                && POWS_OF_TEN
                    .get(*scale as usize)
                    .is_some_and(|pow| pow == v)
        }
        ScalarValue::Decimal256(Some(v), _, scale) => {
            *scale >= 0
                && i256::from(10)
                    .checked_pow(*scale as u32)
                    .is_some_and(|pow| pow == *v)
        }
        _ => false,
    }
}
/// Returns `true` if `expr` is the boolean literal `true`.
///
/// A null boolean literal and every non-literal expression yield `false`.
pub fn is_true(expr: &Expr) -> bool {
    matches!(expr, Expr::Literal(ScalarValue::Boolean(Some(true)), _))
}
/// Returns `true` if `expr` is a boolean literal, whether `true`, `false`
/// or null.
pub fn is_bool_lit(expr: &Expr) -> bool {
    match expr {
        Expr::Literal(ScalarValue::Boolean(_), _) => true,
        _ => false,
    }
}
/// Builds a null literal of boolean type, with no metadata attached.
pub fn lit_bool_null() -> Expr {
    let null_bool = ScalarValue::Boolean(None);
    Expr::Literal(null_bool, None)
}
/// Returns `true` if `expr` is a literal whose scalar value is null.
///
/// Non-literal expressions always yield `false`.
pub fn is_null(expr: &Expr) -> bool {
    if let Expr::Literal(value, _) = expr {
        value.is_null()
    } else {
        false
    }
}
/// Returns `true` if `expr` is the boolean literal `false`.
///
/// A null boolean literal and every non-literal expression yield `false`.
pub fn is_false(expr: &Expr) -> bool {
    matches!(expr, Expr::Literal(ScalarValue::Boolean(Some(false)), _))
}
/// Returns `true` if `haystack` is a binary expression with operator
/// `target_op` that has `needle` as its left or right operand, and `needle`
/// is not volatile.
///
/// Only the direct operands of `haystack` are inspected.
pub fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> bool {
    let Expr::BinaryExpr(BinaryExpr { left, op, right }) = haystack else {
        return false;
    };
    if op != &target_op {
        return false;
    }
    let is_operand = needle == left.as_ref() || needle == right.as_ref();
    // Volatility is checked last, and only when the needle is an operand.
    is_operand && !needle.is_volatile()
}
/// Returns `true` if `haystack` is `a >= b` and `needle` is `a <= b` with the
/// same left and right operands, so that together they pin `a` to `b`.
///
/// The check is positional: `haystack` must be the `>=` comparison and
/// `needle` the `<=` comparison, and operands are compared side by side.
pub fn can_reduce_to_equal_statement(haystack: &Expr, needle: &Expr) -> bool {
    let Expr::BinaryExpr(BinaryExpr {
        left: ge_left,
        op: Operator::GtEq,
        right: ge_right,
    }) = haystack
    else {
        return false;
    };
    let Expr::BinaryExpr(BinaryExpr {
        left: le_left,
        op: Operator::LtEq,
        right: le_right,
    }) = needle
    else {
        return false;
    };
    ge_left == le_left && ge_right == le_right
}
/// Returns `true` if `not_expr` is `Expr::Not` wrapping an expression equal
/// to `expr`, i.e. `not_expr` is `NOT expr`.
pub fn is_not_of(not_expr: &Expr, expr: &Expr) -> bool {
    match not_expr {
        Expr::Not(inner) => expr == inner.as_ref(),
        _ => false,
    }
}
/// Returns `true` if `not_expr` is `Expr::Negative` wrapping an expression
/// equal to `expr`.
pub fn is_negative_of(not_expr: &Expr, expr: &Expr) -> bool {
    match not_expr {
        Expr::Negative(inner) => expr == inner.as_ref(),
        _ => false,
    }
}
/// Extracts the value of a boolean literal.
///
/// Returns `Ok(Some(b))` for a `true`/`false` literal and `Ok(None)` for a
/// null boolean literal.
///
/// # Errors
///
/// Returns an internal error if `expr` is not a boolean literal.
pub fn as_bool_lit(expr: &Expr) -> Result<Option<bool>> {
    if let Expr::Literal(ScalarValue::Boolean(value), _) = expr {
        return Ok(*value);
    }
    internal_err!("Expected boolean literal, got {expr:?}")
}
/// Returns `true` if `expr` is a `CASE` without a base expression
/// (`CASE WHEN .. THEN ..`) whose every `THEN` result is a literal and whose
/// `ELSE` result, if present, is a literal as well.
pub fn is_case_with_literal_outputs(expr: &Expr) -> bool {
    let Expr::Case(Case {
        expr: None,
        when_then_expr,
        else_expr,
    }) = expr
    else {
        return false;
    };
    let thens_are_literals = when_then_expr.iter().all(|(_, then)| is_lit(then));
    // A missing ELSE branch does not disqualify the CASE.
    thens_are_literals && else_expr.as_deref().is_none_or(is_lit)
}
/// Unwraps an `Expr::Case` into its inner [`Case`].
///
/// # Errors
///
/// Returns an internal error if `expr` is any other kind of expression.
pub fn into_case(expr: Expr) -> Result<Case> {
    if let Expr::Case(case) = expr {
        Ok(case)
    } else {
        internal_err!("Expected case, got {expr:?}")
    }
}
/// Returns `true` if `expr` is a literal of any type, null literals included.
pub fn is_lit(expr: &Expr) -> bool {
    match expr {
        Expr::Literal(_, _) => true,
        _ => false,
    }
}
pub fn is_eq_and_ne_with_different_literal(eq_expr: &Expr, ne_expr: &Expr) -> bool {
fn extract_var_and_literal(expr: &Expr) -> Option<(&Expr, &Expr)> {
match expr {
Expr::BinaryExpr(BinaryExpr {
left,
op: Operator::Eq,
right,
})
| Expr::BinaryExpr(BinaryExpr {
left,
op: Operator::NotEq,
right,
}) => match (left.as_ref(), right.as_ref()) {
(Expr::Literal(_, _), var) => Some((var, left)),
(var, Expr::Literal(_, _)) => Some((var, right)),
_ => None,
},
_ => None,
}
}
match (eq_expr, ne_expr) {
(
Expr::BinaryExpr(BinaryExpr {
op: Operator::Eq, ..
}),
Expr::BinaryExpr(BinaryExpr {
op: Operator::NotEq,
..
}),
) => {
if let (Some((var1, lit1)), Some((var2, lit2))) = (
extract_var_and_literal(eq_expr),
extract_var_and_literal(ne_expr),
) && var1 == var2
&& lit1 != lit2
{
return true;
}
false
}
_ => false,
}
}
pub fn negate_clause(expr: Expr) -> Expr {
match expr {
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
if let Some(negated_op) = op.negate() {
return Expr::BinaryExpr(BinaryExpr::new(left, negated_op, right));
}
match op {
Operator::And => {
let left = negate_clause(*left);
let right = negate_clause(*right);
or(left, right)
}
Operator::Or => {
let left = negate_clause(*left);
let right = negate_clause(*right);
and(left, right)
}
_ => Expr::Not(Box::new(Expr::BinaryExpr(BinaryExpr::new(
left, op, right,
)))),
}
}
Expr::Not(expr) => *expr,
Expr::IsNotNull(expr) => expr.is_null(),
Expr::IsNull(expr) => expr.is_not_null(),
Expr::InList(InList {
expr,
list,
negated,
}) => expr.in_list(list, !negated),
Expr::Between(between) => Expr::Between(Between::new(
between.expr,
!between.negated,
between.low,
between.high,
)),
Expr::Like(like) => Expr::Like(Like::new(
!like.negated,
like.expr,
like.pattern,
like.escape_char,
like.case_insensitive,
)),
_ => Expr::Not(Box::new(expr)),
}
}
/// Pushes a negation, represented as `Expr::Negative`, through bitwise
/// `AND`/`OR` expressions in De Morgan style.
///
/// * `A & B` becomes `neg(A) | neg(B)`, and `A | B` becomes `neg(A) & neg(B)`,
///   recursing into both operands;
/// * `Expr::Negative(A)` becomes `A`;
/// * any other expression is wrapped in `Expr::Negative`.
///
/// NOTE(review): the rewrite reads as a bitwise-NOT distribution, but the
/// wrapper used is `Expr::Negative`; confirm the intended semantics with the
/// callers.
pub fn distribute_negation(expr: Expr) -> Expr {
    match expr {
        Expr::BinaryExpr(BinaryExpr {
            left,
            op: Operator::BitwiseAnd,
            right,
        }) => bitwise_or(distribute_negation(*left), distribute_negation(*right)),
        Expr::BinaryExpr(BinaryExpr {
            left,
            op: Operator::BitwiseOr,
            right,
        }) => bitwise_and(distribute_negation(*left), distribute_negation(*right)),
        // Double negation cancels out.
        Expr::Negative(inner) => *inner,
        // Other binary operators and all remaining expressions are wrapped as-is.
        other => Expr::Negative(Box::new(other)),
    }
}
#[cfg(test)]
mod tests {
    use super::{is_one, is_zero};
    use arrow::datatypes::i256;
    use datafusion_common::ScalarValue;
    use datafusion_expr::lit;

    /// Zero literals are recognised for integers, floats and decimals,
    /// whatever the decimal scale is.
    #[test]
    fn test_is_zero() {
        let zeros = [
            ScalarValue::Int8(Some(0)),
            ScalarValue::Float32(Some(0.0)),
            ScalarValue::Decimal128(Some(i128::from(0)), 9, 0),
            ScalarValue::Decimal128(Some(i128::from(0)), 9, 5),
            ScalarValue::Decimal256(Some(i256::ZERO), 9, 0),
            ScalarValue::Decimal256(Some(i256::ZERO), 9, 5),
        ];
        for value in zeros {
            let expr = lit(value);
            assert!(is_zero(&expr), "expected {expr:?} to be zero");
        }
    }

    /// A decimal literal is one when its unscaled value equals `10^scale`;
    /// a negative scale never qualifies.
    #[test]
    fn test_is_one() {
        let ones = [
            ScalarValue::Int8(Some(1)),
            ScalarValue::Float32(Some(1.0)),
            ScalarValue::Decimal128(Some(i128::from(1)), 9, 0),
            ScalarValue::Decimal128(Some(i128::from(10)), 9, 1),
            ScalarValue::Decimal128(Some(i128::from(100)), 9, 2),
            ScalarValue::Decimal256(Some(i256::from(1)), 9, 0),
            ScalarValue::Decimal256(Some(i256::from(10)), 9, 1),
            ScalarValue::Decimal256(Some(i256::from(100)), 9, 2),
        ];
        for value in ones {
            let expr = lit(value);
            assert!(is_one(&expr), "expected {expr:?} to be one");
        }

        let negative_scale = lit(ScalarValue::Decimal256(Some(i256::from(100)), 9, -1));
        assert!(!is_one(&negative_scale));
    }
}