use datafusion_common::{Result, internal_err, tree_node::Transformed};
use datafusion_expr::{Expr, Operator, and, lit, or};
use datafusion_expr_common::interval_arithmetic::Interval;
/// Rewrites a comparison `f(expr) <op> literal` into an equivalent predicate on
/// `expr` alone, given the preimage of `literal` under `f`.
///
/// The bounds of `preimage_interval` are used as a half-open range
/// `[lower, upper)`: lower inclusive, upper exclusive. For example, equality
/// becomes `expr >= lower AND expr < upper`. The ordering rewrites turn
/// `f(expr) < literal` into `expr < lower`, and so on. They are only sound when
/// `f` is non-decreasing, so callers are expected to guarantee that.
///
/// | `op`                   | rewritten predicate                                    |
/// |------------------------|--------------------------------------------------------|
/// | `<`                    | `expr < lower`                                         |
/// | `<=`                   | `expr < upper`                                         |
/// | `>`                    | `expr >= upper`                                        |
/// | `>=`                   | `expr >= lower`                                        |
/// | `=`                    | `expr >= lower AND expr < upper`                       |
/// | `!=`                   | `expr < lower OR expr >= upper`                        |
/// | `IS NOT DISTINCT FROM` | `expr IS NOT NULL AND expr >= lower AND expr < upper`  |
/// | `IS DISTINCT FROM`     | `expr < lower OR expr >= upper OR expr IS NULL`        |
///
/// The explicit null checks in the `IS [NOT] DISTINCT FROM` forms keep the
/// result non-NULL for a NULL `expr` (`FALSE AND ...` / `... OR TRUE`). This
/// matches the null-safe semantics of those operators.
///
/// NOTE(review): both bounds are passed to `lit` unconditionally. If a bound
/// can be a null `ScalarValue` (e.g. to mean "unbounded"), every comparison
/// against it evaluates to NULL. Confirm that callers never pass such intervals.
///
/// # Errors
///
/// Returns an internal error if `op` is not one of the operators listed above.
pub(super) fn rewrite_with_preimage(
    preimage_interval: Interval,
    op: Operator,
    expr: Expr,
) -> Result<Transformed<Expr>> {
    let (lower_bound, upper_bound) = preimage_interval.into_bounds();
    let lower = lit(lower_bound);
    let upper = lit(upper_bound);

    let rewritten_expr = match op {
        // Ordering comparisons collapse to a single check against one bound.
        Operator::Lt => expr.lt(lower),
        Operator::LtEq => expr.lt(upper),
        Operator::Gt => expr.gt_eq(upper),
        Operator::GtEq => expr.gt_eq(lower),
        // Equality: `expr` must lie inside [lower, upper).
        Operator::Eq => {
            let at_or_above_lower = expr.clone().gt_eq(lower);
            let below_upper = expr.lt(upper);
            and(at_or_above_lower, below_upper)
        }
        // Inequality: `expr` must lie outside [lower, upper).
        Operator::NotEq => {
            let below_lower = expr.clone().lt(lower);
            let at_or_above_upper = expr.gt_eq(upper);
            or(below_lower, at_or_above_upper)
        }
        Operator::IsNotDistinctFrom => {
            let not_null = expr.clone().is_not_null();
            let at_or_above_lower = expr.clone().gt_eq(lower);
            let below_upper = expr.lt(upper);
            not_null.and(at_or_above_lower).and(below_upper)
        }
        Operator::IsDistinctFrom => {
            let below_lower = expr.clone().lt(lower);
            let at_or_above_upper = expr.clone().gt_eq(upper);
            let null_check = expr.is_null();
            below_lower.or(at_or_above_upper).or(null_check)
        }
        _ => return internal_err!("Expect comparison operators"),
    };

    Ok(Transformed::yes(rewritten_expr))
}
#[cfg(test)]
mod test {
    use std::any::Any;
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field};
    use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue};
    use datafusion_expr::{
        ColumnarValue, Expr, Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
        Signature, Volatility, and, binary_expr, col, lit, or, preimage::PreimageResult,
        simplify::SimplifyContext,
    };

    use super::Interval;
    use crate::simplify_expressions::ExprSimplifier;

    /// Builds `left IS DISTINCT FROM right`.
    fn is_distinct_from(left: Expr, right: Expr) -> Expr {
        binary_expr(left, Operator::IsDistinctFrom, right)
    }

    /// Builds `left IS NOT DISTINCT FROM right`.
    fn is_not_distinct_from(left: Expr, right: Expr) -> Expr {
        binary_expr(left, Operator::IsNotDistinctFrom, right)
    }

    /// Test fixture UDF `preimage_func(Int32) -> Int32` with a hard-coded
    /// preimage.
    ///
    /// - Literal `500` maps to the interval bounds (100, 200).
    /// - Literal `600` maps to the interval bounds (300, 400).
    /// - Any other literal has no preimage.
    ///
    /// When `enabled` is `false`, `preimage` always returns
    /// `PreimageResult::None`.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct PreimageUdf {
        signature: Signature,
        enabled: bool,
    }

    impl PreimageUdf {
        /// Immutable, single-`Int32`-argument UDF with the preimage enabled.
        fn new() -> Self {
            Self {
                signature: Signature::exact(vec![DataType::Int32], Volatility::Immutable),
                enabled: true,
            }
        }

        /// Enables or disables the `preimage` implementation.
        fn with_enabled(mut self, enabled: bool) -> Self {
            self.enabled = enabled;
            self
        }

        /// Overrides the volatility recorded in the signature.
        fn with_volatility(mut self, volatility: Volatility) -> Self {
            self.signature.volatility = volatility;
            self
        }
    }

    impl ScalarUDFImpl for PreimageUdf {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "preimage_func"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            Ok(DataType::Int32)
        }

        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(500))))
        }

        fn preimage(
            &self,
            args: &[Expr],
            lit_expr: &Expr,
            _info: &SimplifyContext,
        ) -> Result<PreimageResult> {
            if !self.enabled {
                return Ok(PreimageResult::None);
            }
            if args.len() != 1 {
                return Ok(PreimageResult::None);
            }

            // Safe: the length check above guarantees exactly one argument.
            let expr = args.first().cloned().expect("Should be column expression");
            match lit_expr {
                Expr::Literal(ScalarValue::Int32(Some(500)), _) => {
                    Ok(PreimageResult::Range {
                        expr,
                        interval: Box::new(Interval::try_new(
                            ScalarValue::Int32(Some(100)),
                            ScalarValue::Int32(Some(200)),
                        )?),
                    })
                }
                Expr::Literal(ScalarValue::Int32(Some(600)), _) => {
                    Ok(PreimageResult::Range {
                        expr,
                        interval: Box::new(Interval::try_new(
                            ScalarValue::Int32(Some(300)),
                            ScalarValue::Int32(Some(400)),
                        )?),
                    })
                }
                _ => Ok(PreimageResult::None),
            }
        }
    }

    /// Runs the expression simplifier against `schema`, panicking on error.
    fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {
        let simplify_context = SimplifyContext::default().with_schema(Arc::clone(schema));
        ExprSimplifier::new(simplify_context)
            .simplify(expr)
            .unwrap()
    }

    /// `preimage_func(x)` with the preimage enabled.
    fn preimage_udf_expr() -> Expr {
        ScalarUDF::new_from_impl(PreimageUdf::new()).call(vec![col("x")])
    }

    /// `preimage_func(x)` declared volatile.
    fn non_immutable_udf_expr() -> Expr {
        ScalarUDF::new_from_impl(PreimageUdf::new().with_volatility(Volatility::Volatile))
            .call(vec![col("x")])
    }

    /// `preimage_func(x)` with the preimage disabled.
    fn no_preimage_udf_expr() -> Expr {
        ScalarUDF::new_from_impl(PreimageUdf::new().with_enabled(false))
            .call(vec![col("x")])
    }

    /// Schema with a single nullable `Int32` column `x`.
    ///
    /// NOTE(review): nullability presumably matters for the
    /// `IS [NOT] DISTINCT FROM` tests, which expect explicit null checks in the
    /// output. Confirm before changing it.
    fn test_schema() -> DFSchemaRef {
        Arc::new(
            DFSchema::from_unqualified_fields(
                vec![Field::new("x", DataType::Int32, true)].into(),
                Default::default(),
            )
            .unwrap(),
        )
    }

    /// Schema with two non-nullable `Int32` columns `x` and `y`.
    fn test_schema_xy() -> DFSchemaRef {
        Arc::new(
            DFSchema::from_unqualified_fields(
                vec![
                    Field::new("x", DataType::Int32, false),
                    Field::new("y", DataType::Int32, false),
                ]
                .into(),
                Default::default(),
            )
            .unwrap(),
        )
    }

    #[test]
    fn test_preimage_eq_rewrite() {
        let schema = test_schema();
        let expr = preimage_udf_expr().eq(lit(500));
        let expected = and(col("x").gt_eq(lit(100)), col("x").lt(lit(200)));
        assert_eq!(optimize_test(expr, &schema), expected);
    }

    #[test]
    fn test_preimage_noteq_rewrite() {
        let schema = test_schema();
        let expr = preimage_udf_expr().not_eq(lit(500));
        let expected = col("x").lt(lit(100)).or(col("x").gt_eq(lit(200)));
        assert_eq!(optimize_test(expr, &schema), expected);
    }

    #[test]
    fn test_preimage_eq_rewrite_swapped() {
        let schema = test_schema();
        let expr = lit(500).eq(preimage_udf_expr());
        let expected = and(col("x").gt_eq(lit(100)), col("x").lt(lit(200)));
        assert_eq!(optimize_test(expr, &schema), expected);
    }

    #[test]
    fn test_preimage_lt_rewrite() {
        let schema = test_schema();
        let expr = preimage_udf_expr().lt(lit(500));
        let expected = col("x").lt(lit(100));
        assert_eq!(optimize_test(expr, &schema), expected);
    }

    #[test]
    fn test_preimage_lteq_rewrite() {
        let schema = test_schema();
        let expr = preimage_udf_expr().lt_eq(lit(500));
        let expected = col("x").lt(lit(200));
        assert_eq!(optimize_test(expr, &schema), expected);
    }

    #[test]
    fn test_preimage_gt_rewrite() {
        let schema = test_schema();
        let expr = preimage_udf_expr().gt(lit(500));
        let expected = col("x").gt_eq(lit(200));
        assert_eq!(optimize_test(expr, &schema), expected);
    }

    #[test]
    fn test_preimage_gteq_rewrite() {
        let schema = test_schema();
        let expr = preimage_udf_expr().gt_eq(lit(500));
        let expected = col("x").gt_eq(lit(100));
        assert_eq!(optimize_test(expr, &schema), expected);
    }

    #[test]
    fn test_preimage_is_not_distinct_from_rewrite() {
        let schema = test_schema();
        let expr = is_not_distinct_from(preimage_udf_expr(), lit(500));
        let expected = col("x")
            .is_not_null()
            .and(col("x").gt_eq(lit(100)))
            .and(col("x").lt(lit(200)));
        assert_eq!(optimize_test(expr, &schema), expected);
    }

    #[test]
    fn test_preimage_is_distinct_from_rewrite() {
        let schema = test_schema();
        let expr = is_distinct_from(preimage_udf_expr(), lit(500));
        let expected = col("x")
            .lt(lit(100))
            .or(col("x").gt_eq(lit(200)))
            .or(col("x").is_null());
        assert_eq!(optimize_test(expr, &schema), expected);
    }

    #[test]
    fn test_preimage_in_list_rewrite() {
        let schema = test_schema();
        let expr = preimage_udf_expr().in_list(vec![lit(500), lit(600)], false);
        let expected = or(
            and(col("x").gt_eq(lit(100)), col("x").lt(lit(200))),
            and(col("x").gt_eq(lit(300)), col("x").lt(lit(400))),
        );
        assert_eq!(optimize_test(expr, &schema), expected);
    }

    #[test]
    fn test_preimage_not_in_list_rewrite() {
        let schema = test_schema();
        let expr = preimage_udf_expr().in_list(vec![lit(500), lit(600)], true);
        let expected = and(
            or(col("x").lt(lit(100)), col("x").gt_eq(lit(200))),
            or(col("x").lt(lit(300)), col("x").gt_eq(lit(400))),
        );
        assert_eq!(optimize_test(expr, &schema), expected);
    }

    #[test]
    fn test_preimage_in_list_long_list_no_rewrite() {
        let schema = test_schema();
        // Include 500, which *does* have a preimage. Without it, none of the
        // values could ever be rewritten, so the test would pass whether or
        // not long lists are excluded from the rewrite.
        let list: Vec<Expr> = (1..100)
            .map(lit)
            .chain(std::iter::once(lit(500)))
            .collect();
        let expr = preimage_udf_expr().in_list(list, false);
        assert_eq!(optimize_test(expr.clone(), &schema), expr);
    }

    #[test]
    fn test_preimage_non_literal_rhs_no_rewrite() {
        let schema = test_schema_xy();
        let expr = preimage_udf_expr().eq(col("y"));
        let expected = expr.clone();
        assert_eq!(optimize_test(expr, &schema), expected);
    }

    #[test]
    fn test_preimage_null_literal_no_rewrite_distinct_ops() {
        let schema = test_schema();
        let expr = is_distinct_from(preimage_udf_expr(), lit(ScalarValue::Int32(None)));
        assert_eq!(optimize_test(expr.clone(), &schema), expr);

        let expr =
            is_not_distinct_from(preimage_udf_expr(), lit(ScalarValue::Int32(None)));
        assert_eq!(optimize_test(expr.clone(), &schema), expr);
    }

    #[test]
    fn test_preimage_non_immutable_no_rewrite() {
        let schema = test_schema();
        let expr = non_immutable_udf_expr().eq(lit(500));
        let expected = expr.clone();
        assert_eq!(optimize_test(expr, &schema), expected);
    }

    #[test]
    fn test_preimage_no_preimage_no_rewrite() {
        let schema = test_schema();
        let expr = no_preimage_udf_expr().eq(lit(500));
        let expected = expr.clone();
        assert_eq!(optimize_test(expr, &schema), expected);
    }
}