use datafusion::common::stats::Precision;
use datafusion::common::tree_node::Transformed;
use datafusion::common::Result;
use datafusion::datasource::source_as_provider;
use datafusion::logical_expr::expr::AggregateFunction;
use datafusion::logical_expr::{EmptyRelation, Expr, LogicalPlan, Projection, TableScan};
use datafusion::optimizer::optimizer::ApplyOrder;
use datafusion::optimizer::OptimizerRule;
use datafusion::prelude::lit;
use std::sync::Arc;
use tracing::{debug, info, trace, warn};
/// Logical optimizer rule that answers `count` aggregates from table statistics.
///
/// When an un-grouped aggregate consists solely of `count` expressions over a
/// table scan whose provider reports an exact row count, the rule replaces the
/// aggregate with a projection of literal values, so the table is not scanned.
/// The type is a stateless unit struct.
#[derive(Debug, Default)]
pub struct CountStatisticsRule;
impl CountStatisticsRule {
pub fn new() -> Self {
Self
}
}
impl OptimizerRule for CountStatisticsRule {
fn name(&self) -> &str {
"count_statistics"
}
fn apply_order(&self) -> Option<ApplyOrder> {
Some(ApplyOrder::BottomUp)
}
fn supports_rewrite(&self) -> bool {
true
}
fn rewrite(
&self,
plan: LogicalPlan,
_config: &dyn datafusion::optimizer::OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
let LogicalPlan::Aggregate(aggregate) = &plan else {
trace!("Skipping non-Aggregate node");
return Ok(Transformed::no(plan));
};
debug!(
group_by_count = aggregate.group_expr.len(),
aggr_count = aggregate.aggr_expr.len(),
"Evaluating Aggregate node for count optimization"
);
if !aggregate.group_expr.is_empty() {
debug!(
group_by_count = aggregate.group_expr.len(),
"Skipping: has GROUP BY clause"
);
return Ok(Transformed::no(plan));
}
if aggregate.aggr_expr.is_empty() {
debug!("Skipping: no aggregate expressions");
return Ok(Transformed::no(plan));
}
let input = aggregate.input.as_ref();
let table_scan = match unwrap_to_table_scan(input) {
Some(scan) => scan,
None => {
debug!("Skipping: could not find TableScan in input");
return Ok(Transformed::no(plan));
}
};
let table_name = table_scan.table_name.to_string();
debug!(table = %table_name, "Found TableScan");
if !table_scan.filters.is_empty() {
debug!(
table = %table_name,
filter_count = table_scan.filters.len(),
"Skipping: query has WHERE filters"
);
return Ok(Transformed::no(plan));
}
let provider = match source_as_provider(&table_scan.source) {
Ok(p) => p,
Err(e) => {
warn!(
table = %table_name,
error = %e,
"Could not get TableProvider from TableSource"
);
return Ok(Transformed::no(plan));
}
};
let statistics = match provider.statistics() {
Some(stats) => stats,
None => {
debug!(
table = %table_name,
"Skipping: table does not provide statistics"
);
return Ok(Transformed::no(plan));
}
};
let num_rows: usize = match statistics.num_rows {
Precision::Exact(n) => {
debug!(table = %table_name, num_rows = n, "Got exact row count from statistics");
n
}
Precision::Inexact(n) => {
debug!(
table = %table_name,
approx_rows = n,
"Skipping: row count is inexact"
);
return Ok(Transformed::no(plan));
}
Precision::Absent => {
debug!(table = %table_name, "Skipping: row count not available");
return Ok(Transformed::no(plan));
}
};
let schema = provider.schema();
let mut count_values: Vec<(String, i64)> = Vec::new();
for aggr_expr in &aggregate.aggr_expr {
match extract_count_info(aggr_expr) {
Some((is_count_star, column_name)) => {
let alias_name = get_expr_alias(aggr_expr, &aggregate.schema);
let count_value = if is_count_star {
debug!(alias = %alias_name, "Found count(*)");
num_rows as i64
} else if let Some(ref col_name) = column_name {
let col_idx = schema.fields().iter().position(|f| f.name() == col_name);
if let Some(idx) = col_idx {
let col_stats = &statistics.column_statistics;
if idx < col_stats.len() {
match col_stats[idx].null_count {
Precision::Exact(0) => {
debug!(
alias = %alias_name,
column = %col_name,
"Column has no nulls, count = num_rows"
);
num_rows as i64
}
Precision::Exact(nulls) => {
let count = num_rows.saturating_sub(nulls) as i64;
debug!(
alias = %alias_name,
column = %col_name,
null_count = nulls,
count,
"Column has nulls, count = num_rows - null_count"
);
count
}
Precision::Inexact(nulls) => {
debug!(
column = %col_name,
approx_nulls = nulls,
"Skipping: null count is inexact"
);
return Ok(Transformed::no(plan));
}
Precision::Absent => {
debug!(
column = %col_name,
"Skipping: null count not available"
);
return Ok(Transformed::no(plan));
}
}
} else {
debug!(
column = %col_name,
"Skipping: column statistics not available"
);
return Ok(Transformed::no(plan));
}
} else {
debug!(column = %col_name, "Skipping: column not found in schema");
return Ok(Transformed::no(plan));
}
} else {
debug!("Skipping: column name not available");
return Ok(Transformed::no(plan));
};
count_values.push((alias_name, count_value));
}
None => {
debug!("Skipping: found non-count aggregate function");
return Ok(Transformed::no(plan));
}
}
}
info!(
table = %table_name,
count_expressions = count_values.len(),
values = ?count_values,
"Count optimization applied - replacing with constants"
);
create_multi_count_plan(&count_values)
}
}
/// Classifies `expr` as a `count` aggregate that can be answered from
/// row-count and null-count statistics.
///
/// Aliases around the expression are ignored. Returns:
///
/// * `Some((true, None))` for `count()` and `count(<non-null literal>)`,
///   which count every row;
/// * `Some((false, Some(name)))` for `count(column)`, where `name` is the
///   unqualified column name;
/// * `None` for anything else, including shapes whose result does not follow
///   from those statistics alone: `DISTINCT`, a `FILTER` clause, a `NULL`
///   literal argument (which counts nothing), several arguments, or a
///   non-column argument.
fn extract_count_info(expr: &Expr) -> Option<(bool, Option<String>)> {
    let Expr::AggregateFunction(AggregateFunction { func, params }) = unwrap_alias(expr) else {
        trace!("Not an AggregateFunction expression");
        return None;
    };
    if func.name() != "count" {
        trace!(func_name = %func.name(), "Not a count function");
        return None;
    }
    // Distinct counts need per-value information that row/null counts lack.
    if params.distinct {
        debug!("Unsupported: count(DISTINCT ...)");
        return None;
    }
    // A FILTER clause restricts the counted rows just like a WHERE would.
    if params.filter.is_some() {
        debug!("Unsupported: count with FILTER clause");
        return None;
    }
    match params.args.as_slice() {
        [] => Some((true, None)),
        // A NULL literal is never counted, so only non-null literals behave
        // like count(*).
        [Expr::Literal(value, _)] if !value.is_null() => Some((true, None)),
        [Expr::Column(col)] => Some((false, Some(col.name.clone()))),
        _ => {
            debug!("Unsupported count argument type");
            None
        }
    }
}
/// Picks the output column name for `expr`: the alias name when the
/// expression is an alias at the top level, otherwise its schema name.
///
/// `_schema` is accepted but not consulted.
fn get_expr_alias(expr: &Expr, _schema: &Arc<datafusion::common::DFSchema>) -> String {
    match expr {
        Expr::Alias(aliased) => aliased.name.clone(),
        plain => plain.schema_name().to_string(),
    }
}
/// Strips every layer of `Expr::Alias` and returns the innermost expression.
fn unwrap_alias(expr: &Expr) -> &Expr {
    let mut current = expr;
    while let Expr::Alias(aliased) = current {
        current = aliased.expr.as_ref();
    }
    current
}
fn unwrap_to_table_scan(plan: &LogicalPlan) -> Option<&TableScan> {
match plan {
LogicalPlan::TableScan(scan) => Some(scan),
LogicalPlan::Projection(Projection { input, .. }) => unwrap_to_table_scan(input),
LogicalPlan::SubqueryAlias(alias) => unwrap_to_table_scan(&alias.input),
LogicalPlan::Filter(_) => {
debug!("Found Filter node - count would change after filtering");
None
}
other => {
debug!(node_type = %other.display(), "Unsupported node type in input");
None
}
}
}
fn create_multi_count_plan(count_values: &[(String, i64)]) -> Result<Transformed<LogicalPlan>> {
let exprs: Vec<Expr> = count_values
.iter()
.map(|(alias, value)| lit(*value).alias(alias))
.collect();
let empty = LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: true,
schema: Arc::new(arrow::datatypes::Schema::empty().try_into()?),
});
let projection = LogicalPlan::Projection(Projection::try_new(exprs, Arc::new(empty))?);
Ok(Transformed::yes(projection))
}