pub mod groups_accumulator;
pub mod stats;
pub mod utils;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{not_impl_err, Result};
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::type_coercion::aggregates::check_arg_count;
use datafusion_expr::ReversedUDAF;
use datafusion_expr::{
function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator,
};
use std::fmt::Debug;
use std::{any::Any, sync::Arc};
use self::utils::{down_cast_any_ref, ordering_fields};
use crate::physical_expr::PhysicalExpr;
use crate::sort_expr::{LexOrdering, PhysicalSortExpr};
use crate::utils::reverse_order_bys;
use datafusion_common::exec_err;
use datafusion_expr::utils::AggregateOrderSensitivity;
#[allow(clippy::too_many_arguments)]
pub fn create_aggregate_expr(
fun: &AggregateUDF,
input_phy_exprs: &[Arc<dyn PhysicalExpr>],
sort_exprs: &[Expr],
ordering_req: &[PhysicalSortExpr],
schema: &Schema,
name: impl Into<String>,
ignore_nulls: bool,
is_distinct: bool,
) -> Result<Arc<dyn AggregateExpr>> {
debug_assert_eq!(sort_exprs.len(), ordering_req.len());
let input_exprs_types = input_phy_exprs
.iter()
.map(|arg| arg.data_type(schema))
.collect::<Result<Vec<_>>>()?;
check_arg_count(
fun.name(),
&input_exprs_types,
&fun.signature().type_signature,
)?;
let ordering_types = ordering_req
.iter()
.map(|e| e.expr.data_type(schema))
.collect::<Result<Vec<_>>>()?;
let ordering_fields = ordering_fields(ordering_req, &ordering_types);
Ok(Arc::new(AggregateFunctionExpr {
fun: fun.clone(),
args: input_phy_exprs.to_vec(),
data_type: fun.return_type(&input_exprs_types)?,
name: name.into(),
schema: schema.clone(),
sort_exprs: sort_exprs.to_vec(),
ordering_req: ordering_req.to_vec(),
ignore_nulls,
ordering_fields,
is_distinct,
input_type: input_exprs_types[0].clone(),
}))
}
/// Physical-plan representation of an aggregate expression.
///
/// Implementors describe their output ([`field`](Self::field)), their inputs
/// ([`expressions`](Self::expressions)) and how to build the accumulators that
/// perform the aggregation. Most methods have conservative defaults
/// ("not supported" / `None`) that implementors override as needed.
pub trait AggregateExpr: Send + Sync + Debug + PartialEq<dyn Any> {
    /// Returns `self` as [`Any`] so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any;

    /// Returns the [`Field`] describing this aggregate's output.
    fn field(&self) -> Result<Field>;

    /// Creates a new [`Accumulator`] for this aggregate.
    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>>;

    /// Returns the fields describing the accumulator state of this aggregate.
    fn state_fields(&self) -> Result<Vec<Field>>;

    /// Returns the argument expressions of this aggregate.
    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>>;

    /// Returns the ordering this aggregate asks for, if any.
    /// The default reports no ordering requirement.
    fn order_bys(&self) -> Option<&[PhysicalSortExpr]> {
        None
    }

    /// Reports how the aggregate reacts to input ordering.
    /// The default is [`AggregateOrderSensitivity::Insensitive`].
    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
        AggregateOrderSensitivity::Insensitive
    }

    /// Optionally returns a variant of this aggregate adapted to whether its
    /// ordering requirement is satisfied; `Ok(None)` means "no variant".
    ///
    /// The default ignores `_requirement_satisfied`. It returns an execution
    /// error when the aggregate both declares an ordering and reports a
    /// "beneficial" order sensitivity, because such an aggregate is expected
    /// to override this method.
    fn with_beneficial_ordering(
        self: Arc<Self>,
        _requirement_satisfied: bool,
    ) -> Result<Option<Arc<dyn AggregateExpr>>> {
        let must_override =
            self.order_bys().is_some() && self.order_sensitivity().is_beneficial();
        if !must_override {
            return Ok(None);
        }
        exec_err!(
            "Should implement with satisfied for aggregator :{:?}",
            self.name()
        )
    }

    /// Human-readable name of the aggregate; the default is a placeholder.
    fn name(&self) -> &str {
        "AggregateExpr: default name"
    }

    /// Whether [`create_groups_accumulator`](Self::create_groups_accumulator)
    /// is available. Defaults to `false`.
    fn groups_accumulator_supported(&self) -> bool {
        false
    }

    /// Creates a [`GroupsAccumulator`]; the default returns a
    /// "not implemented" error.
    fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
        not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet")
    }

    /// Returns the reversed form of this aggregate, or `None` (the default)
    /// when no reverse is available.
    fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> {
        None
    }

    /// Creates a retractable ("sliding") accumulator; the default returns a
    /// "not implemented" error.
    fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
        not_impl_err!("Retractable Accumulator hasn't been implemented for {self:?} yet")
    }

    /// Collects the argument expressions together with the expressions of the
    /// ordering requirement (empty when [`order_bys`](Self::order_bys) is `None`).
    fn all_expressions(&self) -> AggregatePhysicalExpressions {
        let args = self.expressions();
        let mut order_by_exprs = Vec::new();
        if let Some(sort_keys) = self.order_bys() {
            order_by_exprs.extend(sort_keys.iter().map(|key| Arc::clone(&key.expr)));
        }
        AggregatePhysicalExpressions {
            args,
            order_by_exprs,
        }
    }

    /// Rebuilds the aggregate with replacement argument and ORDER BY
    /// expressions. The default does not support this and returns `None`.
    fn with_new_expressions(
        &self,
        _args: Vec<Arc<dyn PhysicalExpr>>,
        _order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Option<Arc<dyn AggregateExpr>> {
        None
    }
}
/// The physical expressions used by an aggregate, split into its arguments and
/// the expressions of its ordering requirement. Produced by
/// `AggregateExpr::all_expressions`.
pub struct AggregatePhysicalExpressions {
/// Argument expressions, as returned by `AggregateExpr::expressions`.
pub args: Vec<Arc<dyn PhysicalExpr>>,
/// One expression per sort key returned by `AggregateExpr::order_bys`
/// (empty when there is no ordering requirement).
pub order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
}
/// Physical aggregate expression backed by a user-defined aggregate function
/// (`AggregateUDF`). Instances are built by `create_aggregate_expr`.
#[derive(Debug, Clone)]
pub struct AggregateFunctionExpr {
/// The aggregate UDF that accumulator creation and state description are delegated to.
fun: AggregateUDF,
/// Physical expressions for the aggregate's arguments.
args: Vec<Arc<dyn PhysicalExpr>>,
/// Return type of the aggregate, computed by `fun.return_type` at construction.
data_type: DataType,
/// Name of the expression; also used as the name of the output field.
name: String,
/// Schema the expression was created against; forwarded to the UDF when building accumulators.
schema: Schema,
/// Logical sort expressions; forwarded to the UDF when building accumulators.
sort_exprs: Vec<Expr>,
/// Physical ordering requirement; exposed through `order_bys` when the UDF is order sensitive.
ordering_req: LexOrdering,
/// Flag forwarded to the UDF when building accumulators.
ignore_nulls: bool,
/// Fields derived from `ordering_req` and its data types; forwarded to the UDF's `state_fields`.
ordering_fields: Vec<Field>,
/// Whether the aggregate was created as DISTINCT; forwarded to the UDF.
is_distinct: bool,
/// Data type of the FIRST argument expression only.
input_type: DataType,
}
impl AggregateFunctionExpr {
/// Returns the aggregate UDF wrapped by this expression.
pub fn fun(&self) -> &AggregateUDF {
&self.fun
}
/// Returns the `is_distinct` flag this expression was created with.
pub fn is_distinct(&self) -> bool {
self.is_distinct
}
}
impl AggregateExpr for AggregateFunctionExpr {
fn as_any(&self) -> &dyn Any {
self
}
fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
self.args.clone()
}
fn state_fields(&self) -> Result<Vec<Field>> {
let args = StateFieldsArgs {
name: &self.name,
input_type: &self.input_type,
return_type: &self.data_type,
ordering_fields: &self.ordering_fields,
is_distinct: self.is_distinct,
};
self.fun.state_fields(args)
}
fn field(&self) -> Result<Field> {
Ok(Field::new(&self.name, self.data_type.clone(), true))
}
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
let acc_args = AccumulatorArgs {
data_type: &self.data_type,
schema: &self.schema,
ignore_nulls: self.ignore_nulls,
sort_exprs: &self.sort_exprs,
is_distinct: self.is_distinct,
input_type: &self.input_type,
args_num: self.args.len(),
name: &self.name,
};
self.fun.accumulator(acc_args)
}
fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
let args = AccumulatorArgs {
data_type: &self.data_type,
schema: &self.schema,
ignore_nulls: self.ignore_nulls,
sort_exprs: &self.sort_exprs,
is_distinct: self.is_distinct,
input_type: &self.input_type,
args_num: self.args.len(),
name: &self.name,
};
let accumulator = self.fun.create_sliding_accumulator(args)?;
if !accumulator.supports_retract_batch() {
return not_impl_err!(
"Aggregate can not be used as a sliding accumulator because \
`retract_batch` is not implemented: {}",
self.name
);
}
Ok(accumulator)
}
fn name(&self) -> &str {
&self.name
}
fn groups_accumulator_supported(&self) -> bool {
let args = AccumulatorArgs {
data_type: &self.data_type,
schema: &self.schema,
ignore_nulls: self.ignore_nulls,
sort_exprs: &self.sort_exprs,
is_distinct: self.is_distinct,
input_type: &self.input_type,
args_num: self.args.len(),
name: &self.name,
};
self.fun.groups_accumulator_supported(args)
}
fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
let args = AccumulatorArgs {
data_type: &self.data_type,
schema: &self.schema,
ignore_nulls: self.ignore_nulls,
sort_exprs: &self.sort_exprs,
is_distinct: self.is_distinct,
input_type: &self.input_type,
args_num: self.args.len(),
name: &self.name,
};
self.fun.create_groups_accumulator(args)
}
fn order_bys(&self) -> Option<&[PhysicalSortExpr]> {
if self.ordering_req.is_empty() {
return None;
}
if !self.order_sensitivity().is_insensitive() {
return Some(&self.ordering_req);
}
None
}
fn order_sensitivity(&self) -> AggregateOrderSensitivity {
if !self.ordering_req.is_empty() {
self.fun.order_sensitivity()
} else {
AggregateOrderSensitivity::Insensitive
}
}
fn with_beneficial_ordering(
self: Arc<Self>,
beneficial_ordering: bool,
) -> Result<Option<Arc<dyn AggregateExpr>>> {
let Some(updated_fn) = self
.fun
.clone()
.with_beneficial_ordering(beneficial_ordering)?
else {
return Ok(None);
};
create_aggregate_expr(
&updated_fn,
&self.args,
&self.sort_exprs,
&self.ordering_req,
&self.schema,
self.name(),
self.ignore_nulls,
self.is_distinct,
)
.map(Some)
}
fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> {
match self.fun.reverse_udf() {
ReversedUDAF::NotSupported => None,
ReversedUDAF::Identical => Some(Arc::new(self.clone())),
ReversedUDAF::Reversed(reverse_udf) => {
let reverse_ordering_req = reverse_order_bys(&self.ordering_req);
let reverse_sort_exprs = self
.sort_exprs
.iter()
.map(|e| {
if let Expr::Sort(s) = e {
Expr::Sort(s.reverse())
} else {
unreachable!()
}
})
.collect::<Vec<_>>();
let mut name = self.name().to_string();
replace_order_by_clause(&mut name);
replace_fn_name_clause(&mut name, self.fun.name(), reverse_udf.name());
let reverse_aggr = create_aggregate_expr(
&reverse_udf,
&self.args,
&reverse_sort_exprs,
&reverse_ordering_req,
&self.schema,
name,
self.ignore_nulls,
self.is_distinct,
)
.unwrap();
Some(reverse_aggr)
}
}
}
}
impl PartialEq<dyn Any> for AggregateFunctionExpr {
    /// Two expressions are equal when `other` downcasts to
    /// `AggregateFunctionExpr` and name, return type, UDF and argument
    /// expressions (compared pairwise, in order) all match. Other fields are
    /// not part of the comparison.
    fn eq(&self, other: &dyn Any) -> bool {
        let Some(rhs) = down_cast_any_ref(other).downcast_ref::<Self>() else {
            return false;
        };

        let same_identity = self.name == rhs.name
            && self.data_type == rhs.data_type
            && self.fun == rhs.fun
            && self.args.len() == rhs.args.len();

        same_identity
            && self
                .args
                .iter()
                .zip(rhs.args.iter())
                .all(|(mine, theirs)| mine.eq(theirs))
    }
}
/// Rewrites the first `ORDER BY [...]` clause inside `order_by` so that every
/// sort key is reversed: `ASC` <-> `DESC` and `NULLS FIRST` <-> `NULLS LAST`.
///
/// The clause is taken to run from the `[` following `ORDER BY ` up to the
/// first `]` after it. A sort specification is only rewritten when it
/// terminates a key, i.e. is directly followed by `,` or `]`. Strings without
/// such a clause are left untouched.
///
/// All keys are flipped (not just the last one) so that the text stays
/// consistent with an ordering in which every sort expression was reversed.
fn replace_order_by_clause(order_by: &mut String) {
    // Each specification maps to the one with both direction and null
    // placement inverted.
    const FLIPS: [(&str, &str); 4] = [
        (" DESC NULLS FIRST", " ASC NULLS LAST"),
        (" ASC NULLS FIRST", " DESC NULLS LAST"),
        (" DESC NULLS LAST", " ASC NULLS FIRST"),
        (" ASC NULLS LAST", " DESC NULLS FIRST"),
    ];
    const MARKER: &str = "ORDER BY [";

    let Some(start) = order_by.find(MARKER) else {
        return;
    };
    let Some(end) = order_by[start..].find(']') else {
        return;
    };

    // Byte positions of the opening '[' and the closing ']' (both ASCII).
    let clause_start = start + MARKER.len() - 1;
    let clause_end = start + end;
    let clause = &order_by[clause_start..=clause_end];

    let mut reversed = String::with_capacity(clause.len() + FLIPS.len());
    let mut rest = clause;
    while let Some(ch) = rest.chars().next() {
        let flipped = FLIPS.iter().find_map(|&(from, to)| {
            let tail = rest.strip_prefix(from)?;
            let ends_key = tail.starts_with(',') || tail.starts_with(']');
            ends_key.then_some((from.len(), to))
        });
        match flipped {
            Some((consumed, to)) => {
                reversed.push_str(to);
                rest = &rest[consumed..];
            }
            None => {
                reversed.push(ch);
                rest = &rest[ch.len_utf8()..];
            }
        }
    }

    order_by.replace_range(clause_start..=clause_end, &reversed);
}
/// Replaces every occurrence of `fn_name_old` in `aggr_name` with
/// `fn_name_new`, updating `aggr_name` in place.
fn replace_fn_name_clause(aggr_name: &mut String, fn_name_old: &str, fn_name_new: &str) {
    let renamed = str::replace(aggr_name.as_str(), fn_name_old, fn_name_new);
    *aggr_name = renamed;
}