use datafusion_common::{Column, Result, ScalarValue};
use datafusion_expr::{BinaryExpr, Expr, Operator};
use std::collections::BTreeMap;
/// Collapses redundant comparison predicates that constrain the same column.
///
/// A predicate is eligible when it is a binary `>`, `>=`, `<`, `<=` or `=`
/// expression with a bare column on one side and a literal on the other.
/// Eligible predicates are bucketed by column and each bucket is reduced by
/// `simplify_column_predicates`; everything else is passed through untouched.
///
/// The returned vector lists the passed-through predicates first (in their
/// input order), followed by the reduced buckets in `Column` order.
///
/// # Errors
///
/// Propagates any error returned while reducing a bucket.
pub fn simplify_predicates(predicates: Vec<Expr>) -> Result<Vec<Expr>> {
    // With at most one predicate there is nothing to merge.
    if predicates.len() <= 1 {
        return Ok(predicates);
    }

    let mut by_column: BTreeMap<Column, Vec<Expr>> = BTreeMap::new();
    let mut simplified = Vec::new();

    for predicate in predicates {
        // Determine which column (if any) this predicate bounds.
        let bounded_column = match &predicate {
            Expr::BinaryExpr(BinaryExpr {
                left,
                op:
                    Operator::Gt
                    | Operator::GtEq
                    | Operator::Lt
                    | Operator::LtEq
                    | Operator::Eq,
                right,
            }) => match (
                extract_column_from_expr(left),
                extract_column_from_expr(right),
            ) {
                // `column <op> literal`
                (Some(column), _) if right.as_literal().is_some() => Some(column),
                // `literal <op> column`
                (_, Some(column)) if left.as_literal().is_some() => Some(column),
                _ => None,
            },
            _ => None,
        };

        match bounded_column {
            Some(column) => by_column.entry(column).or_default().push(predicate),
            None => simplified.push(predicate),
        }
    }

    for bucket in by_column.into_values() {
        simplified.extend(simplify_column_predicates(bucket)?);
    }

    Ok(simplified)
}
fn simplify_column_predicates(predicates: Vec<Expr>) -> Result<Vec<Expr>> {
if predicates.len() <= 1 {
return Ok(predicates);
}
let mut greater_predicates = Vec::new(); let mut less_predicates = Vec::new(); let mut eq_predicates = Vec::new();
for pred in predicates {
match &pred {
Expr::BinaryExpr(BinaryExpr { left: _, op, right }) => {
match (op, right.as_literal().is_some()) {
(Operator::Gt, true)
| (Operator::Lt, false)
| (Operator::GtEq, true)
| (Operator::LtEq, false) => greater_predicates.push(pred),
(Operator::Lt, true)
| (Operator::Gt, false)
| (Operator::LtEq, true)
| (Operator::GtEq, false) => less_predicates.push(pred),
(Operator::Eq, _) => eq_predicates.push(pred),
_ => unreachable!("Unexpected operator: {}", op),
}
}
_ => unreachable!("Unexpected predicate {}", pred.to_string()),
}
}
let mut result = Vec::new();
if !eq_predicates.is_empty() {
if eq_predicates.len() == 1
|| eq_predicates.iter().all(|e| e == &eq_predicates[0])
{
result.push(eq_predicates.pop().unwrap());
} else {
result.push(Expr::Literal(ScalarValue::Boolean(Some(false)), None));
}
}
if !greater_predicates.is_empty() {
if let Some(most_restrictive) =
find_most_restrictive_predicate(&greater_predicates, true)?
{
result.push(most_restrictive);
} else {
result.extend(greater_predicates);
}
}
if !less_predicates.is_empty() {
if let Some(most_restrictive) =
find_most_restrictive_predicate(&less_predicates, false)?
{
result.push(most_restrictive);
} else {
result.extend(less_predicates);
}
}
Ok(result)
}
/// Picks the tightest bound out of a group of same-direction comparisons.
///
/// Every element of `predicates` is expected to be a binary comparison
/// between a column and a literal, all bounding the column from the same
/// side. With `find_greater == true` the predicates are lower bounds and the
/// one with the largest literal wins; otherwise they are upper bounds and the
/// smallest literal wins. When two literals are equal, a strict comparison
/// (`>` / `<`) beats a non-strict one (`>=` / `<=`).
///
/// Returns `Ok(None)` when `predicates` is empty or when one of the literals
/// is NULL, in which case no predicate can safely be dropped and the caller
/// should keep them all. Otherwise returns a clone of the winning predicate.
///
/// # Errors
///
/// Propagates the error from `ScalarValue::try_cmp` when two literals cannot
/// be ordered against each other.
fn find_most_restrictive_predicate(
    predicates: &[Expr],
    find_greater: bool,
) -> Result<Option<Expr>> {
    use std::cmp::Ordering;

    if predicates.is_empty() {
        return Ok(None);
    }

    // The ordering, relative to the current best, that tightens the bound.
    let tightening = if find_greater {
        Ordering::Greater
    } else {
        Ordering::Less
    };

    let mut most_restrictive_idx = 0;
    let mut best_value: Option<&ScalarValue> = None;

    for (idx, pred) in predicates.iter().enumerate() {
        // Skip anything that is not `<expr> <op> <expr>` with a literal side.
        let (op, scalar) = match pred {
            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
                match right.as_literal().or_else(|| left.as_literal()) {
                    Some(scalar) => (op, scalar),
                    None => continue,
                }
            }
            _ => continue,
        };

        // A comparison against NULL evaluates to NULL for every row, so it is
        // not implied by any other bound and must not be discarded.
        if scalar.is_null() {
            return Ok(None);
        }

        let is_better = match best_value {
            None => true,
            Some(current_best) => {
                let comparison = scalar.try_cmp(current_best)?;
                // All predicates in the group point the same way, so `>` and
                // `<` are both strict here: `5 < a` is the strict lower bound
                // `a > 5` written with the literal on the left.
                let is_strict = matches!(op, Operator::Gt | Operator::Lt);
                comparison == tightening
                    || (comparison == Ordering::Equal && is_strict)
            }
        };

        if is_better {
            best_value = Some(scalar);
            most_restrictive_idx = idx;
        }
    }

    Ok(Some(predicates[most_restrictive_idx].clone()))
}
/// Returns a clone of the column when `expr` is a bare column reference.
///
/// Any other expression, including a column wrapped in a cast, yields `None`.
fn extract_column_from_expr(expr: &Expr) -> Option<Column> {
    if let Expr::Column(column) = expr {
        Some(column.clone())
    } else {
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::DataType;
    use datafusion_expr::{cast, col, lit};

    /// Returns true when `exprs` contains the binary expression
    /// `lhs <expected_op> rhs`.
    fn has_binary(
        exprs: &[Expr],
        lhs: &Expr,
        expected_op: Operator,
        rhs: &Expr,
    ) -> bool {
        exprs.iter().any(|e| match e {
            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
                op == &expected_op && **left == *lhs && **right == *rhs
            }
            _ => false,
        })
    }

    /// A predicate on a cast column must not be merged with predicates on
    /// the bare column.
    #[test]
    fn test_simplify_predicates_with_cast() {
        let input = vec![
            col("a").lt(lit(5i32)),
            cast(col("a"), DataType::Utf8).lt(lit("abc")),
            col("a").lt(lit(6i32)),
        ];

        let simplified = simplify_predicates(input).unwrap();
        assert_eq!(simplified.len(), 2);

        let cast_kept = simplified.iter().any(|e| match e {
            Expr::BinaryExpr(BinaryExpr {
                left,
                op: Operator::Lt,
                right,
            }) => matches!(left.as_ref(), Expr::Cast(_)) && **right == lit("abc"),
            _ => false,
        });
        assert!(cast_kept, "Cast predicate should be preserved");

        assert!(
            has_binary(&simplified, &col("a"), Operator::Lt, &lit(5i32)),
            "Should have a < 5 predicate"
        );
    }

    /// Only bare column references are extracted; casts are ignored.
    #[test]
    fn test_extract_column_ignores_cast() {
        let wrapped = cast(col("a"), DataType::Utf8);
        assert_eq!(extract_column_from_expr(&wrapped), None);

        let bare = col("a");
        assert_eq!(extract_column_from_expr(&bare), Some(Column::from("a")));
    }

    /// Each column keeps only its tightest upper / lower bound.
    #[test]
    fn test_simplify_predicates_direct_columns_only() {
        let input = vec![
            col("a").lt(lit(5i32)),
            col("a").lt(lit(3i32)),
            col("b").gt(lit(10i32)),
            col("b").gt(lit(20i32)),
        ];

        let simplified = simplify_predicates(input).unwrap();
        assert_eq!(simplified.len(), 2);

        assert!(
            has_binary(&simplified, &col("a"), Operator::Lt, &lit(3i32)),
            "Should have a < 3 predicate"
        );
        assert!(
            has_binary(&simplified, &col("b"), Operator::Gt, &lit(20i32)),
            "Should have b > 20 predicate"
        );
    }
}