use datafusion::common::stats::Precision;
use datafusion::common::tree_node::Transformed;
use datafusion::common::{Result, ScalarValue};
use datafusion::datasource::source_as_provider;
use datafusion::logical_expr::expr::AggregateFunction;
use datafusion::logical_expr::{EmptyRelation, Expr, LogicalPlan, Projection, TableScan};
use datafusion::optimizer::optimizer::ApplyOrder;
use datafusion::optimizer::OptimizerRule;
use datafusion::prelude::lit;
use std::sync::Arc;
use tracing::{debug, info, trace, warn};
/// Optimizer rule that answers ungrouped `MIN`/`MAX` aggregates from exact table
/// statistics instead of scanning the table.
///
/// The rule carries no state; every instance behaves the same.
#[derive(Debug, Default)]
pub struct MinMaxStatisticsRule;

impl MinMaxStatisticsRule {
    /// Creates the rule. Equivalent to [`MinMaxStatisticsRule::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
/// Which extreme of a column an aggregate expression asks for.
#[derive(Clone, Copy, Debug)]
enum MinMaxType {
    /// The aggregate is `min(...)`.
    Min,
    /// The aggregate is `max(...)`.
    Max,
}
impl OptimizerRule for MinMaxStatisticsRule {
    /// Identifier of this rule.
    fn name(&self) -> &str {
        "minmax_statistics"
    }

    /// Requests bottom-up application of `rewrite`.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::BottomUp)
    }

    /// This rule implements `rewrite`.
    fn supports_rewrite(&self) -> bool {
        true
    }

    /// Replaces an ungrouped aggregate made only of `MIN`/`MAX` over table columns
    /// with a projection of literal values read from the table's statistics.
    ///
    /// The plan is returned untouched (`Transformed::no`) unless all of these hold:
    /// * the node is an `Aggregate` with no `GROUP BY` and at least one aggregate;
    /// * its input reduces to a `TableScan` (see `unwrap_to_table_scan`) that has
    ///   neither pushed-down filters nor a row limit (`fetch`);
    /// * the table provider exposes statistics;
    /// * every aggregate expression is recognised by `extract_minmax_info` and the
    ///   matching column statistic is `Precision::Exact`.
    ///
    /// # Errors
    /// Propagates any error from building the replacement plan.
    fn rewrite(
        &self,
        plan: LogicalPlan,
        _config: &dyn datafusion::optimizer::OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        let LogicalPlan::Aggregate(aggregate) = &plan else {
            trace!("Skipping non-Aggregate node");
            return Ok(Transformed::no(plan));
        };
        debug!(
            group_by_count = aggregate.group_expr.len(),
            aggr_count = aggregate.aggr_expr.len(),
            "Evaluating Aggregate node for MIN/MAX optimization"
        );
        if !aggregate.group_expr.is_empty() {
            debug!(
                group_by_count = aggregate.group_expr.len(),
                "Skipping: has GROUP BY clause"
            );
            return Ok(Transformed::no(plan));
        }
        if aggregate.aggr_expr.is_empty() {
            debug!("Skipping: no aggregate expressions");
            return Ok(Transformed::no(plan));
        }

        let input = aggregate.input.as_ref();
        let Some(table_scan) = unwrap_to_table_scan(input) else {
            debug!("Skipping: could not find TableScan in input");
            return Ok(Transformed::no(plan));
        };
        let table_name = table_scan.table_name.to_string();
        debug!(table = %table_name, "Found TableScan");
        if !table_scan.filters.is_empty() {
            debug!(
                table = %table_name,
                filter_count = table_scan.filters.len(),
                "Skipping: query has WHERE filters"
            );
            return Ok(Transformed::no(plan));
        }
        // A scan limited to N rows only sees part of the table, so whole-table
        // statistics would not describe the rows the aggregate actually consumes.
        if let Some(fetch) = table_scan.fetch {
            debug!(
                table = %table_name,
                fetch,
                "Skipping: table scan has a row limit"
            );
            return Ok(Transformed::no(plan));
        }

        let provider = match source_as_provider(&table_scan.source) {
            Ok(p) => p,
            Err(e) => {
                warn!(
                    table = %table_name,
                    error = %e,
                    "Could not get TableProvider from TableSource"
                );
                return Ok(Transformed::no(plan));
            }
        };
        let Some(statistics) = provider.statistics() else {
            debug!(
                table = %table_name,
                "Skipping: table does not provide statistics"
            );
            return Ok(Transformed::no(plan));
        };
        // Column positions are resolved against the provider's schema and then
        // used to index the provider's column statistics.
        let schema = provider.schema();

        let mut minmax_values: Vec<(String, ScalarValue)> =
            Vec::with_capacity(aggregate.aggr_expr.len());
        for aggr_expr in &aggregate.aggr_expr {
            let Some((minmax_type, column_name)) = extract_minmax_info(aggr_expr, input) else {
                debug!("Skipping: found non-MIN/MAX aggregate function");
                return Ok(Transformed::no(plan));
            };
            let alias_name = get_expr_alias(aggr_expr, &aggregate.schema);

            let Some(idx) = schema
                .fields()
                .iter()
                .position(|f| f.name() == &column_name)
            else {
                debug!(column = %column_name, "Skipping: column not found in schema");
                return Ok(Transformed::no(plan));
            };
            let Some(column_stats) = statistics.column_statistics.get(idx) else {
                debug!(
                    column = %column_name,
                    "Skipping: column statistics not available"
                );
                return Ok(Transformed::no(plan));
            };

            let (kind, bound) = match minmax_type {
                MinMaxType::Min => ("min", &column_stats.min_value),
                MinMaxType::Max => ("max", &column_stats.max_value),
            };
            // Only an exact bound may stand in for the aggregate result.
            let value = match bound {
                Precision::Exact(v) => {
                    debug!(
                        alias = %alias_name,
                        column = %column_name,
                        value = %v,
                        "Found exact {} value from statistics",
                        kind.to_uppercase()
                    );
                    v.clone()
                }
                Precision::Inexact(_) => {
                    debug!(column = %column_name, "Skipping: {} value is inexact", kind);
                    return Ok(Transformed::no(plan));
                }
                Precision::Absent => {
                    debug!(column = %column_name, "Skipping: {} value not available", kind);
                    return Ok(Transformed::no(plan));
                }
            };
            // NOTE(review): when the aggregate argument is a CAST, `value` keeps the
            // type stored in the statistics rather than the cast target type —
            // confirm the replacement schema is acceptable in that case.
            minmax_values.push((alias_name, value));
        }

        info!(
            table = %table_name,
            minmax_expressions = minmax_values.len(),
            values = ?minmax_values,
            "MIN/MAX optimization applied - replacing with constants"
        );
        create_minmax_plan(&minmax_values)
    }
}
/// Recognises `min(col)` / `max(col)` (optionally aliased, optionally with the
/// argument wrapped in a cast) and returns which extreme is requested together
/// with the name of the referenced column.
///
/// Argument columns whose name starts with `__common_expr` are resolved to the
/// underlying column by searching `input_plan` (see `trace_column_in_plan`).
///
/// Returns `None` when the expression is not a plain MIN/MAX over a column, or
/// when the aggregate carries a `FILTER` clause: a filtered aggregate only sees a
/// subset of the rows, so it cannot be answered from whole-table statistics.
///
/// NOTE(review): for a cast argument the cast itself is dropped and only the
/// inner column name is reported — confirm callers account for the type change.
fn extract_minmax_info(expr: &Expr, input_plan: &LogicalPlan) -> Option<(MinMaxType, String)> {
    let Expr::AggregateFunction(AggregateFunction { func, params }) = unwrap_alias(expr) else {
        trace!("Not an AggregateFunction expression");
        return None;
    };

    let func_name = func.name().to_lowercase();
    let minmax_type = match func_name.as_str() {
        "min" => MinMaxType::Min,
        "max" => MinMaxType::Max,
        _ => {
            trace!(func_name = %func_name, "Not a MIN/MAX function");
            return None;
        }
    };

    if params.filter.is_some() {
        debug!(func_name = %func_name, "MIN/MAX has a FILTER clause");
        return None;
    }

    let arg = params.args.first()?;
    match unwrap_alias(arg) {
        Expr::Column(col) => {
            if col.name.starts_with("__common_expr") {
                trace_column_in_plan(&col.name, input_plan, minmax_type)
            } else {
                Some((minmax_type, col.name.clone()))
            }
        }
        Expr::Cast(cast) => {
            if let Expr::Column(col) = unwrap_alias(cast.expr.as_ref()) {
                return Some((minmax_type, col.name.clone()));
            }
            debug!(cast_expr_type = %cast.expr.variant_name(), "MIN/MAX cast argument is not a column");
            None
        }
        other => {
            debug!(expr_type = %other.variant_name(), "MIN/MAX argument is not a column or cast");
            None
        }
    }
}
/// Searches `plan` for a projection that defines the alias `expr_name` and, when
/// that alias is a column (optionally wrapped in a cast), returns the column's
/// name paired with `minmax_type`.
///
/// Projections that do not resolve the alias and subquery aliases are descended
/// into; a table scan ends the search. Any other node has its inputs searched in
/// order, and the first hit wins.
fn trace_column_in_plan(
    expr_name: &str,
    plan: &LogicalPlan,
    minmax_type: MinMaxType,
) -> Option<(MinMaxType, String)> {
    match plan {
        LogicalPlan::TableScan(_) => None,
        LogicalPlan::SubqueryAlias(subquery) => {
            trace_column_in_plan(expr_name, &subquery.input, minmax_type)
        }
        LogicalPlan::Projection(proj) => {
            for proj_expr in &proj.expr {
                let Expr::Alias(alias) = proj_expr else {
                    continue;
                };
                if alias.name != expr_name {
                    continue;
                }
                // Accept either `col` or `CAST(col AS ..)` behind the alias.
                let source_column = match unwrap_alias(&alias.expr) {
                    Expr::Column(col) => Some(col),
                    Expr::Cast(cast) => match unwrap_alias(cast.expr.as_ref()) {
                        Expr::Column(col) => Some(col),
                        _ => None,
                    },
                    _ => None,
                };
                if let Some(col) = source_column {
                    debug!(
                        common_expr = %expr_name,
                        original_col = %col.name,
                        "Traced common expression to original column"
                    );
                    return Some((minmax_type, col.name.clone()));
                }
            }
            // Not resolved at this level: keep looking below the projection.
            trace_column_in_plan(expr_name, &proj.input, minmax_type)
        }
        other => other
            .inputs()
            .into_iter()
            .find_map(|child| trace_column_in_plan(expr_name, child, minmax_type)),
    }
}
/// Output column name for `expr`: the alias name when `expr` is an alias,
/// otherwise its schema name. The schema argument is currently unused.
fn get_expr_alias(expr: &Expr, _schema: &Arc<datafusion::common::DFSchema>) -> String {
    match expr {
        Expr::Alias(alias) => alias.name.clone(),
        unaliased => unaliased.schema_name().to_string(),
    }
}
/// Strips every enclosing `Expr::Alias` layer and returns the innermost
/// expression; a non-alias expression is returned as is.
fn unwrap_alias(expr: &Expr) -> &Expr {
    let mut current = expr;
    while let Expr::Alias(alias) = current {
        current = alias.expr.as_ref();
    }
    current
}
/// Walks down through `Projection` and `SubqueryAlias` nodes and returns the
/// `TableScan` underneath, if that is what the chain ends in.
///
/// Returns `None` as soon as a `Filter` or any other node type is met.
fn unwrap_to_table_scan(plan: &LogicalPlan) -> Option<&TableScan> {
    let mut node = plan;
    loop {
        node = match node {
            LogicalPlan::TableScan(scan) => return Some(scan),
            LogicalPlan::Projection(Projection { input, .. }) => input.as_ref(),
            LogicalPlan::SubqueryAlias(alias) => alias.input.as_ref(),
            LogicalPlan::Filter(_) => {
                debug!("Found Filter node - min/max would change after filtering");
                return None;
            }
            other => {
                debug!(node_type = %other.display(), "Unsupported node type in input");
                return None;
            }
        };
    }
}
fn create_minmax_plan(minmax_values: &[(String, ScalarValue)]) -> Result<Transformed<LogicalPlan>> {
let exprs: Vec<Expr> = minmax_values
.iter()
.map(|(alias, value)| lit(value.clone()).alias(alias))
.collect();
let empty = LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: true,
schema: Arc::new(arrow::datatypes::Schema::empty().try_into()?),
});
let projection = LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(empty))?);
Ok(Transformed::yes(projection))
}