use std::sync::Arc;
use arrow::datatypes::Schema;
use crate::execution::context::ExecutionConfig;
use crate::physical_plan::empty::EmptyExec;
use crate::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec};
use crate::physical_plan::projection::ProjectionExec;
use crate::physical_plan::{
expressions, AggregateExpr, ColumnStatistics, ExecutionPlan, Statistics,
};
use crate::scalar::ScalarValue;
use super::optimizer::PhysicalOptimizerRule;
use super::utils::optimize_children;
use crate::error::Result;
/// Physical optimizer rule that replaces aggregations over exactly-known
/// statistics (e.g. `COUNT(1)`, `COUNT(col)`, `MIN(col)`, `MAX(col)` with no
/// GROUP BY) with a projection of pre-computed constants, so the input is
/// never scanned at execution time.
#[derive(Default)]
pub struct AggregateStatistics {}

impl AggregateStatistics {
    /// Create a new instance of this optimizer rule.
    ///
    /// Equivalent to [`AggregateStatistics::default`]; the rule is stateless.
    pub fn new() -> Self {
        Self {}
    }
}
impl PhysicalOptimizerRule for AggregateStatistics {
    /// Try to replace an optimizable `Final` aggregation (see
    /// `take_optimizable`) with a `ProjectionExec` of constant values derived
    /// from exact input statistics. The rewrite is all-or-nothing: it is only
    /// applied when *every* aggregate expression can be answered from the
    /// statistics; otherwise the rule recurses into the plan's children.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        execution_config: &ExecutionConfig,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Guard clause: if this node is not an optimizable aggregation,
        // just keep optimizing the subtree.
        let optimizable = match take_optimizable(&*plan) {
            Some(node) => node,
            None => return optimize_children(self, plan, execution_config),
        };
        let partial_agg_exec = optimizable
            .as_any()
            .downcast_ref::<HashAggregateExec>()
            .expect("take_optimizable() ensures that this is a HashAggregateExec");
        let stats = partial_agg_exec.input().statistics();

        // Map every aggregate expression to a (literal, output-name) pair;
        // `collect::<Option<Vec<_>>>` short-circuits to `None` as soon as one
        // expression cannot be computed from the statistics.
        let projections: Option<Vec<_>> = partial_agg_exec
            .aggr_expr()
            .iter()
            .map(|expr| {
                take_optimizable_column_count(&**expr, &stats)
                    .or_else(|| {
                        take_optimizable_table_count(&**expr, &stats)
                            .map(|(value, name)| (value, name.to_owned()))
                    })
                    .or_else(|| take_optimizable_min(&**expr, &stats))
                    .or_else(|| take_optimizable_max(&**expr, &stats))
                    .map(|(value, name)| (expressions::lit(value), name))
            })
            .collect();

        match projections {
            // All aggregates were resolved: project the constants over a
            // single-row empty source instead of executing the aggregation.
            Some(projections) => Ok(Arc::new(ProjectionExec::try_new(
                projections,
                Arc::new(EmptyExec::new(true, Arc::new(Schema::empty()))),
            )?)),
            None => optimize_children(self, plan, execution_config),
        }
    }

    fn name(&self) -> &str {
        "aggregate_statistics"
    }
}
/// If `node` is a `Final` hash aggregation with no GROUP BY whose subtree
/// (following single-child nodes) contains a matching `Partial` aggregation
/// over an input with *exact* statistics, return that partial aggregation
/// node; otherwise return `None`.
fn take_optimizable(node: &dyn ExecutionPlan) -> Option<Arc<dyn ExecutionPlan>> {
    let final_agg_exec = node.as_any().downcast_ref::<HashAggregateExec>()?;
    // Only ungrouped Final aggregations are candidates.
    if final_agg_exec.mode() != &AggregateMode::Final
        || !final_agg_exec.group_expr().is_empty()
    {
        return None;
    }
    // Walk down the plan looking for the matching Partial aggregation.
    let mut child = Arc::clone(final_agg_exec.input());
    loop {
        if let Some(partial) = child.as_any().downcast_ref::<HashAggregateExec>() {
            if partial.mode() == &AggregateMode::Partial
                && partial.group_expr().is_empty()
                // The rewrite is only sound when the statistics are exact,
                // not estimates.
                && partial.input().statistics().is_exact
            {
                return Some(child);
            }
        }
        // Descend only through nodes with exactly one child (e.g. a
        // CoalescePartitionsExec between the Final and Partial phases).
        let children = child.children();
        match children.as_slice() {
            [only_child] => child = Arc::clone(only_child),
            _ => return None,
        }
    }
}
/// If `agg_expr` is `COUNT(1)` (the planner's representation of `COUNT(*)`,
/// a count of a literal `UInt8(1)`) and the total row count is known, return
/// the row count as a `UInt64` scalar plus the conventional output name.
fn take_optimizable_table_count(
    agg_expr: &dyn AggregateExpr,
    stats: &Statistics,
) -> Option<(ScalarValue, &'static str)> {
    let num_rows = stats.num_rows?;
    let count = agg_expr.as_any().downcast_ref::<expressions::Count>()?;
    let args = count.expressions();
    // COUNT must have exactly one argument, and it must be a literal.
    let literal = match args.as_slice() {
        [single] => single.as_any().downcast_ref::<expressions::Literal>()?,
        _ => return None,
    };
    if literal.value() == &ScalarValue::UInt8(Some(1)) {
        Some((
            ScalarValue::UInt64(Some(num_rows as u64)),
            "COUNT(UInt8(1))",
        ))
    } else {
        None
    }
}
/// If `agg_expr` is `COUNT(col)` and both the total row count and the
/// column's null count are known, return the number of non-null values
/// (`num_rows - null_count`) as a `UInt64` scalar plus the output name.
fn take_optimizable_column_count(
    agg_expr: &dyn AggregateExpr,
    stats: &Statistics,
) -> Option<(ScalarValue, String)> {
    let num_rows = stats.num_rows?;
    let col_stats = stats.column_statistics.as_ref()?;
    let count = agg_expr.as_any().downcast_ref::<expressions::Count>()?;
    let args = count.expressions();
    // COUNT must have exactly one argument, and it must be a plain column.
    let column = match args.as_slice() {
        [single] => single.as_any().downcast_ref::<expressions::Column>()?,
        _ => return None,
    };
    let null_count = col_stats[column.index()].null_count?;
    // COUNT(col) counts only the non-null values.
    Some((
        ScalarValue::UInt64(Some((num_rows - null_count) as u64)),
        format!("COUNT({})", column.name()),
    ))
}
/// If `agg_expr` is `MIN(col)` and the column's minimum value is recorded in
/// the statistics, return that value plus the conventional output name.
fn take_optimizable_min(
    agg_expr: &dyn AggregateExpr,
    stats: &Statistics,
) -> Option<(ScalarValue, String)> {
    let col_stats = stats.column_statistics.as_ref()?;
    let min = agg_expr.as_any().downcast_ref::<expressions::Min>()?;
    let args = min.expressions();
    // MIN must have exactly one argument, and it must be a plain column.
    let column = match args.as_slice() {
        [single] => single.as_any().downcast_ref::<expressions::Column>()?,
        _ => return None,
    };
    let value = col_stats[column.index()].min_value.as_ref()?;
    Some((value.clone(), format!("MIN({})", column.name())))
}
/// If `agg_expr` is `MAX(col)` and the column's maximum value is recorded in
/// the statistics, return that value plus the conventional output name.
fn take_optimizable_max(
    agg_expr: &dyn AggregateExpr,
    stats: &Statistics,
) -> Option<(ScalarValue, String)> {
    let col_stats = stats.column_statistics.as_ref()?;
    let max = agg_expr.as_any().downcast_ref::<expressions::Max>()?;
    let args = max.expressions();
    // MAX must have exactly one argument, and it must be a plain column.
    let column = match args.as_slice() {
        [single] => single.as_any().downcast_ref::<expressions::Column>()?,
        _ => return None,
    };
    let value = col_stats[column.index()].max_value.as_ref()?;
    Some((value.clone(), format!("MAX({})", column.name())))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use arrow::array::{Int32Array, UInt64Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use crate::error::Result;
    use crate::execution::runtime_env::RuntimeEnv;
    use crate::logical_plan::Operator;
    use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
    use crate::physical_plan::common;
    use crate::physical_plan::expressions::Count;
    use crate::physical_plan::filter::FilterExec;
    use crate::physical_plan::hash_aggregate::HashAggregateExec;
    use crate::physical_plan::memory::MemoryExec;

    /// Single-partition in-memory source with 3 rows and two Int32 columns:
    /// `a = [1, 2, NULL]`, `b = [4, NULL, 6]`. `MemoryExec` is what gives the
    /// tests exact statistics, making the rule applicable.
    /// NOTE(review): both fields are declared non-nullable yet the arrays
    /// contain NULLs — presumably tolerated by this arrow version; confirm.
    fn mock_data() -> Result<Arc<MemoryExec>> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int32Array::from(vec![Some(1), Some(2), None])),
                Arc::new(Int32Array::from(vec![Some(4), None, Some(6)])),
            ],
        )?;
        Ok(Arc::new(MemoryExec::try_new(
            &[vec![batch]],
            Arc::clone(&schema),
            None,
        )?))
    }

    /// Run the `AggregateStatistics` rule on `plan` and assert that
    /// (1) the aggregation was replaced by a `ProjectionExec`, and
    /// (2) executing the projection yields the expected constant:
    /// 3 rows for `COUNT(UInt8(1))` (`nulls == false`), or 2 non-null values
    /// for `COUNT(a)` (`nulls == true`, since column `a` has one NULL).
    async fn assert_count_optim_success(
        plan: HashAggregateExec,
        nulls: bool,
    ) -> Result<()> {
        let conf = ExecutionConfig::new();
        let runtime = Arc::new(RuntimeEnv::default());
        let optimized = AggregateStatistics::new().optimize(Arc::new(plan), &conf)?;
        // Expected output field name and value depend on which COUNT form
        // the caller built (table count vs. column count).
        let (col, count) = match nulls {
            false => (Field::new("COUNT(UInt8(1))", DataType::UInt64, false), 3),
            true => (Field::new("COUNT(a)", DataType::UInt64, false), 2),
        };
        // The whole aggregation must have been replaced by a projection of
        // literal values.
        assert!(optimized.as_any().is::<ProjectionExec>());
        let result = common::collect(optimized.execute(0, runtime).await?).await?;
        assert_eq!(result[0].schema(), Arc::new(Schema::new(vec![col])));
        assert_eq!(
            result[0]
                .column(0)
                .as_any()
                .downcast_ref::<UInt64Array>()
                .unwrap()
                .values(),
            &[count]
        );
        Ok(())
    }

    /// Build a COUNT aggregate expression: `COUNT(UInt8(1))` when `schema` is
    /// `None`, otherwise `COUNT(<col>)` for the named column (in which case
    /// `col` must be `Some`, or this panics).
    fn count_expr(schema: Option<&Schema>, col: Option<&str>) -> Arc<dyn AggregateExpr> {
        let expr = match schema {
            None => expressions::lit(ScalarValue::UInt8(Some(1))),
            Some(s) => expressions::col(col.unwrap(), s).unwrap(),
        };
        Arc::new(Count::new(expr, "my_count_alias", DataType::UInt64))
    }

    /// COUNT(1) with the Partial aggregate as the *direct* child of the Final
    /// aggregate: the rule should produce the constant 3.
    #[tokio::test]
    async fn test_count_partial_direct_child() -> Result<()> {
        let source = mock_data()?;
        let schema = source.schema();
        let partial_agg = HashAggregateExec::try_new(
            AggregateMode::Partial,
            vec![],
            vec![count_expr(None, None)],
            source,
            Arc::clone(&schema),
        )?;
        let final_agg = HashAggregateExec::try_new(
            AggregateMode::Final,
            vec![],
            vec![count_expr(None, None)],
            Arc::new(partial_agg),
            Arc::clone(&schema),
        )?;
        assert_count_optim_success(final_agg, false).await?;
        Ok(())
    }

    /// COUNT(a) (column with one NULL) with the Partial aggregate as the
    /// direct child: the rule should produce the constant 2.
    #[tokio::test]
    async fn test_count_partial_with_nulls_direct_child() -> Result<()> {
        let source = mock_data()?;
        let schema = source.schema();
        let partial_agg = HashAggregateExec::try_new(
            AggregateMode::Partial,
            vec![],
            vec![count_expr(Some(&schema), Some("a"))],
            source,
            Arc::clone(&schema),
        )?;
        let final_agg = HashAggregateExec::try_new(
            AggregateMode::Final,
            vec![],
            vec![count_expr(Some(&schema), Some("a"))],
            Arc::new(partial_agg),
            Arc::clone(&schema),
        )?;
        assert_count_optim_success(final_agg, true).await?;
        Ok(())
    }

    /// COUNT(1) with a CoalescePartitionsExec between Final and Partial:
    /// take_optimizable() must descend through single-child nodes.
    #[tokio::test]
    async fn test_count_partial_indirect_child() -> Result<()> {
        let source = mock_data()?;
        let schema = source.schema();
        let partial_agg = HashAggregateExec::try_new(
            AggregateMode::Partial,
            vec![],
            vec![count_expr(None, None)],
            source,
            Arc::clone(&schema),
        )?;
        let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg));
        let final_agg = HashAggregateExec::try_new(
            AggregateMode::Final,
            vec![],
            vec![count_expr(None, None)],
            Arc::new(coalesce),
            Arc::clone(&schema),
        )?;
        assert_count_optim_success(final_agg, false).await?;
        Ok(())
    }

    /// COUNT(a) with an intermediate CoalescePartitionsExec: same descent,
    /// but using the column (non-null) count.
    #[tokio::test]
    async fn test_count_partial_with_nulls_indirect_child() -> Result<()> {
        let source = mock_data()?;
        let schema = source.schema();
        let partial_agg = HashAggregateExec::try_new(
            AggregateMode::Partial,
            vec![],
            vec![count_expr(Some(&schema), Some("a"))],
            source,
            Arc::clone(&schema),
        )?;
        let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg));
        let final_agg = HashAggregateExec::try_new(
            AggregateMode::Final,
            vec![],
            vec![count_expr(Some(&schema), Some("a"))],
            Arc::new(coalesce),
            Arc::clone(&schema),
        )?;
        assert_count_optim_success(final_agg, true).await?;
        Ok(())
    }

    /// Negative case: a filter between the source and the aggregation makes
    /// the statistics inexact, so the rule must NOT rewrite the plan and the
    /// optimized root stays a HashAggregateExec.
    #[tokio::test]
    async fn test_count_inexact_stat() -> Result<()> {
        let source = mock_data()?;
        let schema = source.schema();
        // `a > 1` — the filter's output row count is only an estimate.
        let filter = Arc::new(FilterExec::try_new(
            expressions::binary(
                expressions::col("a", &schema)?,
                Operator::Gt,
                expressions::lit(ScalarValue::from(1u32)),
                &schema,
            )?,
            source,
        )?);
        let partial_agg = HashAggregateExec::try_new(
            AggregateMode::Partial,
            vec![],
            vec![count_expr(None, None)],
            filter,
            Arc::clone(&schema),
        )?;
        let final_agg = HashAggregateExec::try_new(
            AggregateMode::Final,
            vec![],
            vec![count_expr(None, None)],
            Arc::new(partial_agg),
            Arc::clone(&schema),
        )?;
        let conf = ExecutionConfig::new();
        let optimized =
            AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?;
        // No rewrite happened: the plan root is still the aggregation.
        assert!(optimized.as_any().is::<HashAggregateExec>());
        Ok(())
    }

    /// Same negative case as above, but for the column-count form COUNT(a).
    #[tokio::test]
    async fn test_count_with_nulls_inexact_stat() -> Result<()> {
        let source = mock_data()?;
        let schema = source.schema();
        let filter = Arc::new(FilterExec::try_new(
            expressions::binary(
                expressions::col("a", &schema)?,
                Operator::Gt,
                expressions::lit(ScalarValue::from(1u32)),
                &schema,
            )?,
            source,
        )?);
        let partial_agg = HashAggregateExec::try_new(
            AggregateMode::Partial,
            vec![],
            vec![count_expr(Some(&schema), Some("a"))],
            filter,
            Arc::clone(&schema),
        )?;
        let final_agg = HashAggregateExec::try_new(
            AggregateMode::Final,
            vec![],
            vec![count_expr(Some(&schema), Some("a"))],
            Arc::new(partial_agg),
            Arc::clone(&schema),
        )?;
        let conf = ExecutionConfig::new();
        let optimized =
            AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?;
        // No rewrite happened: the plan root is still the aggregation.
        assert!(optimized.as_any().is::<HashAggregateExec>());
        Ok(())
    }
}