use std::sync::Arc;
use crate::config::ConfigOptions;
use datafusion_common::tree_node::TreeNode;
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use crate::physical_plan::aggregates::{AggregateExec, AggregateMode};
use crate::physical_plan::empty::EmptyExec;
use crate::physical_plan::projection::ProjectionExec;
use crate::physical_plan::{
expressions, AggregateExpr, ColumnStatistics, ExecutionPlan, Statistics,
};
use crate::scalar::ScalarValue;
use super::optimizer::PhysicalOptimizerRule;
use crate::error::Result;
/// Physical optimizer rule that replaces aggregations whose results are
/// exactly known from input statistics (`COUNT(*)`, `COUNT(col)`,
/// `MIN(col)`, `MAX(col)` without grouping or filters) with a projection
/// of pre-computed constants over an empty one-row input.
#[derive(Default)]
pub struct AggregateStatistics {}

/// Output column name used for the optimized `COUNT(*)` expression; the
/// literal form `COUNT(UInt8(1))` corresponds to [`COUNT_STAR_EXPANSION`],
/// which is what `take_optimizable_table_count` matches against.
const COUNT_STAR_NAME: &str = "COUNT(UInt8(1))";
impl AggregateStatistics {
    /// Creates a new instance of this optimizer rule.
    pub fn new() -> Self {
        Self::default()
    }
}
impl PhysicalOptimizerRule for AggregateStatistics {
    /// Rewrites an eligible final/partial aggregation pair into a constant
    /// projection when *every* aggregate expression can be answered from
    /// the exact input statistics; otherwise recurses into the children.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        _config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        match take_optimizable(&*plan) {
            Some(optimizable) => {
                let partial_agg_exec = optimizable
                    .as_any()
                    .downcast_ref::<AggregateExec>()
                    .expect("take_optimizable() ensures that this is a AggregateExec");
                let stats = partial_agg_exec.input().statistics();
                let mut projections = Vec::new();
                for expr in partial_agg_exec.aggr_expr() {
                    // Try each statistics-based shortcut in turn (same order
                    // as the checks below); stop at the first aggregate that
                    // cannot be answered from statistics.
                    let replacement = take_optimizable_column_count(&**expr, &stats)
                        .map(|(value, name)| (expressions::lit(value), name))
                        .or_else(|| {
                            take_optimizable_table_count(&**expr, &stats).map(
                                |(value, name)| (expressions::lit(value), name.to_owned()),
                            )
                        })
                        .or_else(|| {
                            take_optimizable_min(&**expr, &stats)
                                .map(|(value, name)| (expressions::lit(value), name))
                        })
                        .or_else(|| {
                            take_optimizable_max(&**expr, &stats)
                                .map(|(value, name)| (expressions::lit(value), name))
                        });
                    match replacement {
                        Some(projection) => projections.push(projection),
                        None => break,
                    }
                }
                if projections.len() == partial_agg_exec.aggr_expr().len() {
                    // Every aggregate was answerable: replace the whole
                    // subtree with a projection over an empty one-row input.
                    Ok(Arc::new(ProjectionExec::try_new(
                        projections,
                        Arc::new(EmptyExec::new(true, plan.schema())),
                    )?))
                } else {
                    // At least one aggregate needs real execution; keep this
                    // node and try to optimize deeper in the tree.
                    plan.map_children(|child| self.optimize(child, _config))
                }
            }
            None => plan.map_children(|child| self.optimize(child, _config)),
        }
    }

    fn name(&self) -> &str {
        "aggregate_statistics"
    }

    fn schema_check(&self) -> bool {
        // Schema verification is intentionally disabled for this rule.
        false
    }
}
/// If `node` is a group-less `Final` aggregation, walk down its single-child
/// chain looking for a matching group-less, filter-less `Partial` aggregation
/// whose input statistics are exact; return that partial aggregation node.
fn take_optimizable(node: &dyn ExecutionPlan) -> Option<Arc<dyn ExecutionPlan>> {
    let final_agg_exec = node.as_any().downcast_ref::<AggregateExec>()?;
    if final_agg_exec.mode() != &AggregateMode::Final
        || !final_agg_exec.group_expr().is_empty()
    {
        return None;
    }
    let mut current = Arc::clone(final_agg_exec.input());
    loop {
        if let Some(partial_agg_exec) = current.as_any().downcast_ref::<AggregateExec>() {
            let candidate = partial_agg_exec.mode() == &AggregateMode::Partial
                && partial_agg_exec.group_expr().is_empty()
                && partial_agg_exec.filter_expr().iter().all(|e| e.is_none());
            // Only exact statistics may be substituted for real execution.
            if candidate && partial_agg_exec.input().statistics().is_exact {
                return Some(Arc::clone(&current));
            }
        }
        // Keep descending as long as the chain has exactly one child;
        // give up at a leaf or at a multi-input operator.
        match current.children().as_slice() {
            [only_child] => current = Arc::clone(only_child),
            _ => return None,
        }
    }
}
/// If the aggregate is `COUNT(*)` (i.e. `COUNT` of the count-star literal)
/// and the exact row count is known from `stats`, return that row count and
/// the canonical `COUNT(*)` output column name.
fn take_optimizable_table_count(
    agg_expr: &dyn AggregateExpr,
    stats: &Statistics,
) -> Option<(ScalarValue, &'static str)> {
    let num_rows = stats.num_rows?;
    let count = agg_expr.as_any().downcast_ref::<expressions::Count>()?;
    let args = count.expressions();
    // Only a single-argument COUNT over the count-star literal qualifies.
    if args.len() != 1 {
        return None;
    }
    let literal = args[0].as_any().downcast_ref::<expressions::Literal>()?;
    if literal.value() == &COUNT_STAR_EXPANSION {
        Some((ScalarValue::Int64(Some(num_rows as i64)), COUNT_STAR_NAME))
    } else {
        None
    }
}
/// If the aggregate is `COUNT(col)` over a column whose null count is known
/// (together with the exact table row count), return the pre-computed
/// non-null row count and the aggregate's output column name.
fn take_optimizable_column_count(
    agg_expr: &dyn AggregateExpr,
    stats: &Statistics,
) -> Option<(ScalarValue, String)> {
    let num_rows = stats.num_rows?;
    let col_stats = stats.column_statistics.as_ref()?;
    let casted_expr = agg_expr.as_any().downcast_ref::<expressions::Count>()?;
    let exprs = casted_expr.expressions();
    // Only a plain single-column COUNT can be answered from column statistics.
    if exprs.len() != 1 {
        return None;
    }
    let col_expr = exprs[0].as_any().downcast_ref::<expressions::Column>()?;
    // `.get` rather than indexing: a statistics vector shorter than the
    // column index must not panic the optimizer.
    let null_count = col_stats.get(col_expr.index())?.null_count?;
    // `saturating_sub` guards against inconsistent statistics where
    // null_count > num_rows, which would underflow the usize subtraction.
    let non_null_rows = num_rows.saturating_sub(null_count);
    Some((
        ScalarValue::Int64(Some(non_null_rows as i64)),
        casted_expr.name().to_string(),
    ))
}
/// If the aggregate is `MIN(col)` and the column's minimum value is known
/// from `stats`, return that value and the aggregate's output column name.
fn take_optimizable_min(
    agg_expr: &dyn AggregateExpr,
    stats: &Statistics,
) -> Option<(ScalarValue, String)> {
    let col_stats = stats.column_statistics.as_ref()?;
    let casted_expr = agg_expr.as_any().downcast_ref::<expressions::Min>()?;
    let exprs = casted_expr.expressions();
    // Only a plain single-column MIN can be answered from column statistics.
    if exprs.len() != 1 {
        return None;
    }
    let col_expr = exprs[0].as_any().downcast_ref::<expressions::Column>()?;
    // `.get` rather than indexing: a statistics vector shorter than the
    // column index must not panic the optimizer.
    let min_value = col_stats.get(col_expr.index())?.min_value.as_ref()?;
    Some((min_value.clone(), casted_expr.name().to_string()))
}
/// If the aggregate is `MAX(col)` and the column's maximum value is known
/// from `stats`, return that value and the aggregate's output column name.
fn take_optimizable_max(
    agg_expr: &dyn AggregateExpr,
    stats: &Statistics,
) -> Option<(ScalarValue, String)> {
    let col_stats = stats.column_statistics.as_ref()?;
    let casted_expr = agg_expr.as_any().downcast_ref::<expressions::Max>()?;
    let exprs = casted_expr.expressions();
    // Only a plain single-column MAX can be answered from column statistics.
    if exprs.len() != 1 {
        return None;
    }
    let col_expr = exprs[0].as_any().downcast_ref::<expressions::Column>()?;
    // `.get` rather than indexing: a statistics vector shorter than the
    // column index must not panic the optimizer.
    let max_value = col_stats.get(col_expr.index())?.max_value.as_ref()?;
    Some((max_value.clone(), casted_expr.name().to_string()))
}
#[cfg(test)]
mod tests {
    //! Tests for the `AggregateStatistics` rule: the optimization must fire
    //! (and produce correct counts) when input statistics are exact, and
    //! must leave the plan untouched when they are not.
    use super::*;
    use std::sync::Arc;
    use arrow::array::Int32Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion_common::cast::as_int64_array;
    use datafusion_physical_expr::expressions::cast;
    use datafusion_physical_expr::PhysicalExpr;
    use crate::error::Result;
    use crate::logical_expr::Operator;
    use crate::physical_plan::aggregates::{AggregateExec, PhysicalGroupBy};
    use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
    use crate::physical_plan::common;
    use crate::physical_plan::expressions::Count;
    use crate::physical_plan::filter::FilterExec;
    use crate::physical_plan::memory::MemoryExec;
    use crate::prelude::SessionContext;

    /// Builds a 3-row in-memory source with two nullable Int32 columns:
    /// `a = [1, 2, NULL]` and `b = [4, NULL, 6]`.
    fn mock_data() -> Result<Arc<MemoryExec>> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
        ]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int32Array::from(vec![Some(1), Some(2), None])),
                Arc::new(Int32Array::from(vec![Some(4), None, Some(6)])),
            ],
        )?;
        Ok(Arc::new(MemoryExec::try_new(
            &[vec![batch]],
            Arc::clone(&schema),
            None,
        )?))
    }

    /// Runs the rule on `plan`, asserts the plan was rewritten to a
    /// `ProjectionExec`, and checks that the optimized and non-optimized
    /// plans both produce a single batch with the expected count.
    async fn assert_count_optim_success(
        plan: AggregateExec,
        agg: TestAggregate,
    ) -> Result<()> {
        let session_ctx = SessionContext::new();
        let state = session_ctx.state();
        let plan = Arc::new(plan) as _;
        let optimized = AggregateStatistics::new()
            .optimize(Arc::clone(&plan), state.config_options())?;
        // A successful optimization replaces the aggregation with a projection.
        assert!(optimized.as_any().is::<ProjectionExec>());
        let optimized_result =
            common::collect(optimized.execute(0, session_ctx.task_ctx())?).await?;
        let nonoptimized_result =
            common::collect(plan.execute(0, session_ctx.task_ctx())?).await?;
        assert_eq!(optimized_result.len(), nonoptimized_result.len());
        // Both plans must produce exactly one batch with the same content.
        assert_eq!(optimized_result.len(), 1);
        check_batch(optimized_result.into_iter().next().unwrap(), &agg);
        assert_eq!(nonoptimized_result.len(), 1);
        check_batch(nonoptimized_result.into_iter().next().unwrap(), &agg);
        Ok(())
    }

    /// Asserts that `batch` holds a single Int64 column, named after `agg`,
    /// containing exactly the expected count value.
    fn check_batch(batch: RecordBatch, agg: &TestAggregate) {
        let schema = batch.schema();
        let fields = schema.fields();
        assert_eq!(fields.len(), 1);
        let field = &fields[0];
        assert_eq!(field.name(), agg.column_name());
        assert_eq!(field.data_type(), &DataType::Int64);
        assert_eq!(
            as_int64_array(batch.column(0)).unwrap().values(),
            &[agg.expected_count()]
        );
    }

    /// The aggregate under test: `COUNT(*)` or `COUNT(a)`.
    enum TestAggregate {
        // COUNT(*) — counts all rows regardless of nulls
        CountStar,
        // COUNT(a) — needs the schema to resolve column `a`
        ColumnA(Arc<Schema>),
    }

    impl TestAggregate {
        fn new_count_star() -> Self {
            Self::CountStar
        }

        fn new_count_column(schema: &Arc<Schema>) -> Self {
            Self::ColumnA(schema.clone())
        }

        /// The physical COUNT aggregate expression for this test case.
        fn count_expr(&self) -> Arc<dyn AggregateExpr> {
            Arc::new(Count::new(
                self.column(),
                self.column_name(),
                DataType::Int64,
            ))
        }

        /// The counted expression: the count-star literal or column `a`.
        fn column(&self) -> Arc<dyn PhysicalExpr> {
            match self {
                Self::CountStar => expressions::lit(COUNT_STAR_EXPANSION),
                Self::ColumnA(s) => expressions::col("a", s).unwrap(),
            }
        }

        /// The expected name of the single output column.
        fn column_name(&self) -> &'static str {
            match self {
                Self::CountStar => COUNT_STAR_NAME,
                Self::ColumnA(_) => "COUNT(a)",
            }
        }

        /// Expected result on `mock_data()`: 3 total rows, of which 2 have
        /// a non-null value in column `a`.
        fn expected_count(&self) -> i64 {
            match self {
                TestAggregate::CountStar => 3,
                TestAggregate::ColumnA(_) => 2,
            }
        }
    }

    /// COUNT(*): the partial aggregate is the direct child of the final one.
    #[tokio::test]
    async fn test_count_partial_direct_child() -> Result<()> {
        let source = mock_data()?;
        let schema = source.schema();
        let agg = TestAggregate::new_count_star();
        let partial_agg = AggregateExec::try_new(
            AggregateMode::Partial,
            PhysicalGroupBy::default(),
            vec![agg.count_expr()],
            vec![None],
            vec![None],
            source,
            Arc::clone(&schema),
        )?;
        let final_agg = AggregateExec::try_new(
            AggregateMode::Final,
            PhysicalGroupBy::default(),
            vec![agg.count_expr()],
            vec![None],
            vec![None],
            Arc::new(partial_agg),
            Arc::clone(&schema),
        )?;
        assert_count_optim_success(final_agg, agg).await?;
        Ok(())
    }

    /// COUNT(a) over a column containing nulls; partial agg is a direct child.
    #[tokio::test]
    async fn test_count_partial_with_nulls_direct_child() -> Result<()> {
        let source = mock_data()?;
        let schema = source.schema();
        let agg = TestAggregate::new_count_column(&schema);
        let partial_agg = AggregateExec::try_new(
            AggregateMode::Partial,
            PhysicalGroupBy::default(),
            vec![agg.count_expr()],
            vec![None],
            vec![None],
            source,
            Arc::clone(&schema),
        )?;
        let final_agg = AggregateExec::try_new(
            AggregateMode::Final,
            PhysicalGroupBy::default(),
            vec![agg.count_expr()],
            vec![None],
            vec![None],
            Arc::new(partial_agg),
            Arc::clone(&schema),
        )?;
        assert_count_optim_success(final_agg, agg).await?;
        Ok(())
    }

    /// COUNT(*): a CoalescePartitionsExec sits between the final and partial
    /// aggregates, so the rule has to look through intermediate nodes.
    #[tokio::test]
    async fn test_count_partial_indirect_child() -> Result<()> {
        let source = mock_data()?;
        let schema = source.schema();
        let agg = TestAggregate::new_count_star();
        let partial_agg = AggregateExec::try_new(
            AggregateMode::Partial,
            PhysicalGroupBy::default(),
            vec![agg.count_expr()],
            vec![None],
            vec![None],
            source,
            Arc::clone(&schema),
        )?;
        let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg));
        let final_agg = AggregateExec::try_new(
            AggregateMode::Final,
            PhysicalGroupBy::default(),
            vec![agg.count_expr()],
            vec![None],
            vec![None],
            Arc::new(coalesce),
            Arc::clone(&schema),
        )?;
        assert_count_optim_success(final_agg, agg).await?;
        Ok(())
    }

    /// COUNT(a) with nulls, with a CoalescePartitionsExec between final and
    /// partial aggregates.
    #[tokio::test]
    async fn test_count_partial_with_nulls_indirect_child() -> Result<()> {
        let source = mock_data()?;
        let schema = source.schema();
        let agg = TestAggregate::new_count_column(&schema);
        let partial_agg = AggregateExec::try_new(
            AggregateMode::Partial,
            PhysicalGroupBy::default(),
            vec![agg.count_expr()],
            vec![None],
            vec![None],
            source,
            Arc::clone(&schema),
        )?;
        let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg));
        let final_agg = AggregateExec::try_new(
            AggregateMode::Final,
            PhysicalGroupBy::default(),
            vec![agg.count_expr()],
            vec![None],
            vec![None],
            Arc::new(coalesce),
            Arc::clone(&schema),
        )?;
        assert_count_optim_success(final_agg, agg).await?;
        Ok(())
    }

    /// A FilterExec below the aggregation makes the statistics inexact, so
    /// the rule must NOT rewrite the plan.
    #[tokio::test]
    async fn test_count_inexact_stat() -> Result<()> {
        let source = mock_data()?;
        let schema = source.schema();
        let agg = TestAggregate::new_count_star();
        // `a > 1` filters out an unknown number of rows.
        let filter = Arc::new(FilterExec::try_new(
            expressions::binary(
                expressions::col("a", &schema)?,
                Operator::Gt,
                cast(expressions::lit(1u32), &schema, DataType::Int32)?,
                &schema,
            )?,
            source,
        )?);
        let partial_agg = AggregateExec::try_new(
            AggregateMode::Partial,
            PhysicalGroupBy::default(),
            vec![agg.count_expr()],
            vec![None],
            vec![None],
            filter,
            Arc::clone(&schema),
        )?;
        let final_agg = AggregateExec::try_new(
            AggregateMode::Final,
            PhysicalGroupBy::default(),
            vec![agg.count_expr()],
            vec![None],
            vec![None],
            Arc::new(partial_agg),
            Arc::clone(&schema),
        )?;
        let conf = ConfigOptions::new();
        let optimized =
            AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?;
        // The plan must still be an aggregation — no rewrite happened.
        assert!(optimized.as_any().is::<AggregateExec>());
        Ok(())
    }

    /// Same as above but for COUNT(a): inexact statistics must prevent the
    /// rewrite.
    #[tokio::test]
    async fn test_count_with_nulls_inexact_stat() -> Result<()> {
        let source = mock_data()?;
        let schema = source.schema();
        let agg = TestAggregate::new_count_column(&schema);
        // `a > 1` filters out an unknown number of rows.
        let filter = Arc::new(FilterExec::try_new(
            expressions::binary(
                expressions::col("a", &schema)?,
                Operator::Gt,
                cast(expressions::lit(1u32), &schema, DataType::Int32)?,
                &schema,
            )?,
            source,
        )?);
        let partial_agg = AggregateExec::try_new(
            AggregateMode::Partial,
            PhysicalGroupBy::default(),
            vec![agg.count_expr()],
            vec![None],
            vec![None],
            filter,
            Arc::clone(&schema),
        )?;
        let final_agg = AggregateExec::try_new(
            AggregateMode::Final,
            PhysicalGroupBy::default(),
            vec![agg.count_expr()],
            vec![None],
            vec![None],
            Arc::new(partial_agg),
            Arc::clone(&schema),
        )?;
        let conf = ConfigOptions::new();
        let optimized =
            AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?;
        // The plan must still be an aggregation — no rewrite happened.
        assert!(optimized.as_any().is::<AggregateExec>());
        Ok(())
    }
}