use std::sync::Arc;
use arrow::datatypes::{DataType, Schema};
use datafusion_common::{Result, ScalarValue, tree_node::Transformed};
use datafusion_expr::Operator;
use datafusion_expr_common::casts::try_cast_literal_to_type;
use crate::PhysicalExpr;
use crate::expressions::{BinaryExpr, CastExpr, Literal, TryCastExpr, lit};
/// Rewrites `CAST(expr) <op> literal` (or the mirrored `literal <op> CAST(expr)`)
/// into a comparison on `expr` itself, with the literal converted to the type of
/// `expr`. `TRY_CAST` is treated exactly like `CAST`.
///
/// Only the node passed in is inspected; it is not recursed into. To rewrite a
/// whole expression tree, drive this function from a `TreeNode` traversal such
/// as `transform_down`.
///
/// # Returns
/// `Transformed::yes` carrying the rewritten expression when a rewrite applied,
/// otherwise `Transformed::no` carrying the unchanged input.
///
/// # Errors
/// Propagates any error raised while resolving the data type of the un-cast
/// operand against `schema`.
pub(crate) fn unwrap_cast_in_comparison(
    expr: Arc<dyn PhysicalExpr>,
    schema: &Schema,
) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
    // Only binary expressions can have the `cast <op> literal` shape.
    let rewritten = match expr.as_any().downcast_ref::<BinaryExpr>() {
        Some(binary) => try_unwrap_cast_binary(binary, schema)?,
        None => None,
    };

    Ok(match rewritten {
        Some(new_expr) => Transformed::yes(new_expr),
        None => Transformed::no(expr),
    })
}
/// Attempts the cast-unwrapping rewrite on a single binary expression.
///
/// Two operand shapes are recognised:
/// * `CAST(expr) <op> literal` becomes `expr <op> literal'`
/// * `literal <op> CAST(expr)` becomes `expr <swapped op> literal'`
///
/// Here `literal'` is the literal converted to the data type of `expr` by
/// `try_cast_literal_to_type`. Only operators for which
/// `Operator::supports_propagation` holds are considered.
///
/// Returns `Ok(None)` in three cases: the operator is not eligible, the operands
/// do not have one of the shapes above, or the literal conversion is declined.
///
/// NOTE(review): the cast's own target type is never consulted, so narrowing
/// casts (e.g. `CAST(int64_col AS INT)`) are unwrapped as well. Confirm that
/// dropping the cast's overflow behaviour is acceptable in that case.
fn try_unwrap_cast_binary(
    binary: &BinaryExpr,
    schema: &Schema,
) -> Result<Option<Arc<dyn PhysicalExpr>>> {
    let op = *binary.op();

    // Neither orientation below can apply to an ineligible operator, so bail
    // out before inspecting the operands at all.
    if !op.supports_propagation() {
        return Ok(None);
    }

    let (left, right) = (binary.left(), binary.right());

    // Orientation 1: `CAST(expr) <op> literal`. The operator is kept as-is.
    if let Some((inner, _)) = extract_cast_info(left)
        && let Some(literal) = right.as_any().downcast_ref::<Literal>()
        && let Some(rewritten) =
            try_unwrap_cast_comparison(Arc::clone(inner), literal.value(), op, schema)?
    {
        return Ok(Some(rewritten));
    }

    // Orientation 2: `literal <op> CAST(expr)`. The rewritten comparison puts
    // `expr` on the left, so the operator is mirrored (e.g. `<` becomes `>`).
    match (
        left.as_any().downcast_ref::<Literal>(),
        extract_cast_info(right),
        op.swap(),
    ) {
        (Some(literal), Some((inner, _)), Some(mirrored)) => try_unwrap_cast_comparison(
            Arc::clone(inner),
            literal.value(),
            mirrored,
            schema,
        ),
        _ => Ok(None),
    }
}
/// Splits a `CAST` or `TRY_CAST` expression into its operand and target type.
///
/// Returns `Some((operand, target_type))` when `expr` is a `CastExpr` or a
/// `TryCastExpr`, and `None` for any other kind of expression.
fn extract_cast_info(
    expr: &Arc<dyn PhysicalExpr>,
) -> Option<(&Arc<dyn PhysicalExpr>, &DataType)> {
    let any = expr.as_any();

    if let Some(cast) = any.downcast_ref::<CastExpr>() {
        return Some((cast.expr(), cast.cast_type()));
    }

    any.downcast_ref::<TryCastExpr>()
        .map(|try_cast| (try_cast.expr(), try_cast.cast_type()))
}
/// Builds `inner_expr <op> literal'`, where `literal'` is `literal_value`
/// converted to the data type of `inner_expr`.
///
/// Returns `Ok(None)` when `try_cast_literal_to_type` declines the conversion,
/// for example when the value does not fit the target type.
///
/// # Errors
/// Fails if the data type of `inner_expr` cannot be resolved against `schema`.
fn try_unwrap_cast_comparison(
    inner_expr: Arc<dyn PhysicalExpr>,
    literal_value: &ScalarValue,
    op: Operator,
    schema: &Schema,
) -> Result<Option<Arc<dyn PhysicalExpr>>> {
    let target_type = inner_expr.data_type(schema)?;

    let Some(converted) = try_cast_literal_to_type(literal_value, &target_type) else {
        // The literal cannot be expressed in the operand's type; leave the
        // original comparison alone.
        return Ok(None);
    };

    let rewritten: Arc<dyn PhysicalExpr> =
        Arc::new(BinaryExpr::new(inner_expr, op, lit(converted)));
    Ok(Some(rewritten))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::{col, lit};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::{ScalarValue, tree_node::TreeNode};
    use datafusion_expr::Operator;

    /// Returns true if `expr` is a `CastExpr` or a `TryCastExpr`.
    fn is_cast_expr(expr: &Arc<dyn PhysicalExpr>) -> bool {
        expr.as_any().downcast_ref::<CastExpr>().is_some()
            || expr.as_any().downcast_ref::<TryCastExpr>().is_some()
    }

    /// Returns true if `binary` has a cast on one side and a literal on the other.
    fn is_binary_expr_with_cast_and_literal(binary: &BinaryExpr) -> bool {
        let left_cast_right_literal = is_cast_expr(binary.left())
            && binary.right().as_any().downcast_ref::<Literal>().is_some();
        let left_literal_right_cast =
            binary.left().as_any().downcast_ref::<Literal>().is_some()
                && is_cast_expr(binary.right());
        left_cast_right_literal || left_literal_right_cast
    }

    /// Schema shared by several tests: `c1: Int32`, `c2: Int64`, `c3: Utf8`.
    fn test_schema() -> Schema {
        Schema::new(vec![
            Field::new("c1", DataType::Int32, false),
            Field::new("c2", DataType::Int64, false),
            Field::new("c3", DataType::Utf8, false),
        ])
    }

    #[test]
    fn test_unwrap_cast_in_binary_comparison() {
        let schema = test_schema();
        let column_expr = col("c1", &schema).unwrap();
        let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None));
        let literal_expr = lit(10i64);
        let binary_expr =
            Arc::new(BinaryExpr::new(cast_expr, Operator::Gt, literal_expr));
        let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
        assert!(result.transformed);
        let optimized = result.data;
        let optimized_binary = optimized.as_any().downcast_ref::<BinaryExpr>().unwrap();
        assert!(!is_cast_expr(optimized_binary.left()));
        let right_literal = optimized_binary
            .right()
            .as_any()
            .downcast_ref::<Literal>()
            .unwrap();
        assert_eq!(right_literal.value(), &ScalarValue::Int32(Some(10)));
    }

    #[test]
    fn test_unwrap_cast_with_literal_on_left() {
        let schema = test_schema();
        let column_expr = col("c1", &schema).unwrap();
        let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None));
        let literal_expr = lit(10i64);
        let binary_expr =
            Arc::new(BinaryExpr::new(literal_expr, Operator::Lt, cast_expr));
        let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
        assert!(result.transformed);
        let optimized = result.data;
        let optimized_binary = optimized.as_any().downcast_ref::<BinaryExpr>().unwrap();
        // `10 < CAST(c1)` is mirrored into `c1 > 10`.
        assert_eq!(*optimized_binary.op(), Operator::Gt);
    }

    #[test]
    fn test_no_unwrap_when_types_unsupported() {
        // Float literals are expected to be rejected by
        // `try_cast_literal_to_type`, so the cast must stay in place.
        let schema = Schema::new(vec![Field::new("f1", DataType::Float32, false)]);
        let column_expr = col("f1", &schema).unwrap();
        let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Float64, None));
        let literal_expr = lit(10.5f64);
        let binary_expr =
            Arc::new(BinaryExpr::new(cast_expr, Operator::Gt, literal_expr));
        let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
        assert!(!result.transformed);
    }

    #[test]
    fn test_is_binary_expr_with_cast_and_literal() {
        let schema = test_schema();
        let column_expr = col("c1", &schema).unwrap();
        let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None));
        let literal_expr = lit(10i64);
        let binary_expr =
            Arc::new(BinaryExpr::new(cast_expr, Operator::Gt, literal_expr));
        let binary_ref = binary_expr.as_any().downcast_ref::<BinaryExpr>().unwrap();
        assert!(is_binary_expr_with_cast_and_literal(binary_ref));
    }

    #[test]
    fn test_unwrap_cast_literal_on_left_side() {
        let schema = Schema::new(vec![Field::new(
            "decimal_col",
            DataType::Decimal128(9, 2),
            true,
        )]);
        let column_expr = col("decimal_col", &schema).unwrap();
        let cast_expr = Arc::new(CastExpr::new(
            column_expr,
            DataType::Decimal128(22, 2),
            None,
        ));
        let literal_expr = lit(ScalarValue::Decimal128(Some(400), 22, 2));
        let binary_expr =
            Arc::new(BinaryExpr::new(literal_expr, Operator::LtEq, cast_expr));
        let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
        assert!(result.transformed);
        let optimized = result.data;
        let optimized_binary = optimized.as_any().downcast_ref::<BinaryExpr>().unwrap();
        assert_eq!(*optimized_binary.op(), Operator::GtEq);
        assert!(!is_cast_expr(optimized_binary.left()));
        let right_literal = optimized_binary
            .right()
            .as_any()
            .downcast_ref::<Literal>()
            .unwrap();
        // The literal is re-expressed in the column's narrower decimal type.
        assert_eq!(
            right_literal.value().data_type(),
            DataType::Decimal128(9, 2)
        );
    }

    #[test]
    fn test_unwrap_cast_with_different_comparison_operators() {
        // `literal <op> CAST(col)`: every comparison operator must be mirrored.
        let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, false)]);
        let operators = [
            (Operator::Lt, Operator::Gt),
            (Operator::LtEq, Operator::GtEq),
            (Operator::Gt, Operator::Lt),
            (Operator::GtEq, Operator::LtEq),
            (Operator::Eq, Operator::Eq),
            (Operator::NotEq, Operator::NotEq),
        ];
        for (original_op, expected_op) in operators {
            let column_expr = col("int_col", &schema).unwrap();
            let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None));
            let literal_expr = lit(100i64);
            let binary_expr =
                Arc::new(BinaryExpr::new(literal_expr, original_op, cast_expr));
            let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
            assert!(result.transformed);
            let optimized = result.data;
            let optimized_binary =
                optimized.as_any().downcast_ref::<BinaryExpr>().unwrap();
            assert_eq!(
                *optimized_binary.op(),
                expected_op,
                "Failed for operator {original_op:?} -> {expected_op:?}"
            );
            assert!(!is_cast_expr(optimized_binary.left()));
            let right_literal = optimized_binary
                .right()
                .as_any()
                .downcast_ref::<Literal>()
                .unwrap();
            assert_eq!(right_literal.value(), &ScalarValue::Int32(Some(100)));
        }
    }

    #[test]
    fn test_unwrap_cast_on_left_keeps_operator() {
        // `CAST(col) <op> literal`: every comparison operator is kept unchanged.
        let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, false)]);
        let operators = [
            Operator::Lt,
            Operator::LtEq,
            Operator::Gt,
            Operator::GtEq,
            Operator::Eq,
            Operator::NotEq,
        ];
        for op in operators {
            let column_expr = col("int_col", &schema).unwrap();
            let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None));
            let binary_expr = Arc::new(BinaryExpr::new(cast_expr, op, lit(100i64)));
            let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
            assert!(result.transformed, "expected a rewrite for operator {op:?}");
            let optimized = result.data;
            let optimized_binary =
                optimized.as_any().downcast_ref::<BinaryExpr>().unwrap();
            assert_eq!(*optimized_binary.op(), op, "operator changed for {op:?}");
            assert!(!is_cast_expr(optimized_binary.left()));
            let right_literal = optimized_binary
                .right()
                .as_any()
                .downcast_ref::<Literal>()
                .unwrap();
            assert_eq!(right_literal.value(), &ScalarValue::Int32(Some(100)));
        }
    }

    #[test]
    fn test_unwrap_cast_with_decimal_types() {
        // (column precision, column scale, cast precision, cast scale, raw value)
        let test_cases = [
            (9, 2, 22, 2, 400),
            (10, 3, 20, 3, 1000),
            (5, 1, 10, 1, 99),
        ];
        for (col_p, col_s, cast_p, cast_s, value) in test_cases {
            let schema = Schema::new(vec![Field::new(
                "decimal_col",
                DataType::Decimal128(col_p, col_s),
                true,
            )]);
            let column_expr = col("decimal_col", &schema).unwrap();

            // Cast on the left.
            let cast_expr = Arc::new(CastExpr::new(
                Arc::clone(&column_expr),
                DataType::Decimal128(cast_p, cast_s),
                None,
            ));
            let literal_expr = lit(ScalarValue::Decimal128(Some(value), cast_p, cast_s));
            let binary_expr =
                Arc::new(BinaryExpr::new(cast_expr, Operator::Gt, literal_expr));
            let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
            assert!(result.transformed);

            // Cast on the right.
            let cast_expr = Arc::new(CastExpr::new(
                column_expr,
                DataType::Decimal128(cast_p, cast_s),
                None,
            ));
            let literal_expr = lit(ScalarValue::Decimal128(Some(value), cast_p, cast_s));
            let binary_expr =
                Arc::new(BinaryExpr::new(literal_expr, Operator::Lt, cast_expr));
            let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
            assert!(result.transformed);
        }
    }

    #[test]
    fn test_unwrap_cast_with_null_literals() {
        let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, true)]);
        let column_expr = col("int_col", &schema).unwrap();
        let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None));
        let null_literal = lit(ScalarValue::Int64(None));
        let binary_expr =
            Arc::new(BinaryExpr::new(cast_expr, Operator::Eq, null_literal));
        let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
        assert!(result.transformed);
        let optimized = result.data;
        let optimized_binary = optimized.as_any().downcast_ref::<BinaryExpr>().unwrap();
        let right_literal = optimized_binary
            .right()
            .as_any()
            .downcast_ref::<Literal>()
            .unwrap();
        // A typed NULL is re-typed to the column's type.
        assert_eq!(right_literal.value(), &ScalarValue::Int32(None));
    }

    #[test]
    fn test_unwrap_cast_with_try_cast() {
        // Utf8 operand: the Int64 literal cannot be converted to Utf8, so the
        // TRY_CAST is left untouched.
        let schema = Schema::new(vec![Field::new("str_col", DataType::Utf8, true)]);
        let column_expr = col("str_col", &schema).unwrap();
        let try_cast_expr = Arc::new(TryCastExpr::new(column_expr, DataType::Int64));
        let literal_expr = lit(100i64);
        let binary_expr =
            Arc::new(BinaryExpr::new(try_cast_expr, Operator::Gt, literal_expr));
        let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
        assert!(!result.transformed);
    }

    #[test]
    fn test_unwrap_cast_preserves_non_comparison_operators() {
        let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, false)]);
        let column_expr = col("int_col", &schema).unwrap();
        let cast1 = Arc::new(CastExpr::new(
            Arc::clone(&column_expr),
            DataType::Int64,
            None,
        ));
        let lit1 = lit(10i64);
        let compare1 = Arc::new(BinaryExpr::new(cast1, Operator::Gt, lit1));
        let cast2 = Arc::new(CastExpr::new(column_expr, DataType::Int64, None));
        let lit2 = lit(20i64);
        let compare2 = Arc::new(BinaryExpr::new(cast2, Operator::Lt, lit2));
        let and_expr = Arc::new(BinaryExpr::new(compare1, Operator::And, compare2));
        let result = (and_expr as Arc<dyn PhysicalExpr>)
            .transform_down(|node| unwrap_cast_in_comparison(node, &schema))
            .unwrap();
        assert!(result.transformed);
        let optimized = result.data;
        let and_binary = optimized.as_any().downcast_ref::<BinaryExpr>().unwrap();
        assert_eq!(*and_binary.op(), Operator::And);
        let left_binary = and_binary
            .left()
            .as_any()
            .downcast_ref::<BinaryExpr>()
            .unwrap();
        let right_binary = and_binary
            .right()
            .as_any()
            .downcast_ref::<BinaryExpr>()
            .unwrap();
        assert!(!is_cast_expr(left_binary.left()));
        assert!(!is_cast_expr(right_binary.left()));
    }

    #[test]
    fn test_try_cast_unwrapping() {
        let schema = test_schema();
        let column_expr = col("c1", &schema).unwrap();
        let try_cast_expr = Arc::new(TryCastExpr::new(column_expr, DataType::Int64));
        let literal_expr = lit(100i64);
        let binary_expr =
            Arc::new(BinaryExpr::new(try_cast_expr, Operator::LtEq, literal_expr));
        let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
        assert!(result.transformed);
        let optimized = result.data;
        let optimized_binary = optimized.as_any().downcast_ref::<BinaryExpr>().unwrap();
        assert!(!is_cast_expr(optimized_binary.left()));
        let right_literal = optimized_binary
            .right()
            .as_any()
            .downcast_ref::<Literal>()
            .unwrap();
        assert_eq!(right_literal.value(), &ScalarValue::Int32(Some(100)));
    }

    #[test]
    fn test_non_swappable_operator() {
        let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, false)]);
        let column_expr = col("int_col", &schema).unwrap();
        let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None));
        let literal_expr = lit(10i64);
        let binary_expr =
            Arc::new(BinaryExpr::new(literal_expr, Operator::Plus, cast_expr));
        let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
        assert!(!result.transformed);
    }

    #[test]
    fn test_no_unwrap_for_arithmetic_with_cast_on_left() {
        // Mirror of `test_non_swappable_operator` with the cast on the left:
        // non-comparison operators must never be rewritten.
        let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, false)]);
        let column_expr = col("int_col", &schema).unwrap();
        let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None));
        let binary_expr =
            Arc::new(BinaryExpr::new(cast_expr, Operator::Plus, lit(10i64)));
        let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
        assert!(!result.transformed);
    }

    #[test]
    fn test_no_unwrap_when_other_side_is_not_literal() {
        let schema = test_schema();
        let c1_expr = col("c1", &schema).unwrap();
        let c1_cast = Arc::new(CastExpr::new(c1_expr, DataType::Int64, None));
        let c2_expr = col("c2", &schema).unwrap();
        let binary_expr = Arc::new(BinaryExpr::new(c1_cast, Operator::Gt, c2_expr));
        let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
        assert!(!result.transformed);
    }

    #[test]
    fn test_non_binary_expression_is_untouched() {
        let schema = test_schema();
        let column_expr = col("c1", &schema).unwrap();
        let result = unwrap_cast_in_comparison(column_expr, &schema).unwrap();
        assert!(!result.transformed);
    }

    #[test]
    fn test_cast_that_cannot_be_unwrapped_overflow() {
        // 1000 does not fit in Int8, so the literal cannot be narrowed.
        let schema = Schema::new(vec![Field::new("small_int", DataType::Int8, false)]);
        let column_expr = col("small_int", &schema).unwrap();
        let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None));
        let literal_expr = lit(1000i64);
        let binary_expr =
            Arc::new(BinaryExpr::new(cast_expr, Operator::Gt, literal_expr));
        let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
        assert!(!result.transformed);
    }

    #[test]
    fn test_complex_nested_expression() {
        let schema = test_schema();
        let c1_expr = col("c1", &schema).unwrap();
        let c1_cast = Arc::new(CastExpr::new(c1_expr, DataType::Int64, None));
        let c1_literal = lit(10i64);
        let c1_binary = Arc::new(BinaryExpr::new(c1_cast, Operator::Gt, c1_literal));
        // `c2` is Int64 and the cast narrows it to Int32; the literal is
        // converted back to Int64 when the cast is unwrapped.
        let c2_expr = col("c2", &schema).unwrap();
        let c2_cast = Arc::new(CastExpr::new(c2_expr, DataType::Int32, None));
        let c2_literal = lit(20i32);
        let c2_binary = Arc::new(BinaryExpr::new(c2_cast, Operator::Eq, c2_literal));
        let and_expr = Arc::new(BinaryExpr::new(c1_binary, Operator::And, c2_binary));
        let result = (and_expr as Arc<dyn PhysicalExpr>)
            .transform_down(|node| unwrap_cast_in_comparison(node, &schema))
            .unwrap();
        assert!(result.transformed);
        let optimized = result.data;
        let and_binary = optimized.as_any().downcast_ref::<BinaryExpr>().unwrap();
        let left_binary = and_binary
            .left()
            .as_any()
            .downcast_ref::<BinaryExpr>()
            .unwrap();
        assert!(!is_cast_expr(left_binary.left()));
        let left_literal = left_binary
            .right()
            .as_any()
            .downcast_ref::<Literal>()
            .unwrap();
        assert_eq!(left_literal.value(), &ScalarValue::Int32(Some(10)));
        let right_binary = and_binary
            .right()
            .as_any()
            .downcast_ref::<BinaryExpr>()
            .unwrap();
        assert!(!is_cast_expr(right_binary.left()));
        let right_literal = right_binary
            .right()
            .as_any()
            .downcast_ref::<Literal>()
            .unwrap();
        assert_eq!(right_literal.value(), &ScalarValue::Int64(Some(20)));
    }
}