/// Backwards-compatible re-exports of the shared `GroupsAccumulator` helpers
/// that live in `datafusion_functions_aggregate_common`.
pub(crate) mod groups_accumulator {
    // This nested module keeps the older
    // `groups_accumulator::accumulate::NullState` path resolving; the same
    // `NullState` is also re-exported one level up below, hence the allow.
    #[allow(unused_imports)]
    pub(crate) mod accumulate {
        pub use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState;
    }
    pub use datafusion_functions_aggregate_common::aggregate::groups_accumulator::{
        accumulate::NullState, GroupsAccumulatorAdapter,
    };
}
/// Re-export of [`StatsType`] (statistics kind used by statistical aggregates).
pub(crate) mod stats {
    pub use datafusion_functions_aggregate_common::stats::StatsType;
}
/// Re-exports of aggregate utility helpers.
///
/// `#[allow(deprecated)]` because some of the re-exported items are marked
/// deprecated upstream but are kept here for compatibility.
pub mod utils {
    #[allow(deprecated)] pub use datafusion_functions_aggregate_common::utils::{
        adjust_output_array, get_accum_scalar_values_as_arrays, get_sort_options,
        ordering_fields, DecimalAverager, Hashable,
    };
}
use std::fmt::Debug;
use std::sync::Arc;
use crate::expressions::Column;
use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue};
use datafusion_expr::{AggregateUDF, ReversedUDAF, SetMonotonicity};
use datafusion_expr_common::accumulator::Accumulator;
use datafusion_expr_common::groups_accumulator::GroupsAccumulator;
use datafusion_expr_common::type_coercion::aggregates::check_arg_count;
use datafusion_functions_aggregate_common::accumulator::{
AccumulatorArgs, StateFieldsArgs,
};
use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
use datafusion_physical_expr_common::utils::reverse_order_bys;
/// Builder for an [`AggregateFunctionExpr`].
///
/// Collects the function, its input expressions, and the optional settings
/// (alias, schema, ordering, DISTINCT, IGNORE NULLS, reversal) before
/// validating everything in [`AggregateExprBuilder::build`].
#[derive(Debug, Clone)]
pub struct AggregateExprBuilder {
    /// The aggregate UDF to evaluate
    fun: Arc<AggregateUDF>,
    /// Physical expressions producing the function's input arguments
    args: Vec<Arc<dyn PhysicalExpr>>,
    /// Output column name; required — `build` errors when unset
    alias: Option<String>,
    /// Schema the argument/ordering expressions are resolved against
    schema: SchemaRef,
    /// Optional `ORDER BY` requirement attached to the aggregate
    ordering_req: LexOrdering,
    /// Whether NULL input values are ignored
    ignore_nulls: bool,
    /// Whether this is a `DISTINCT` aggregate
    is_distinct: bool,
    /// Whether the aggregate evaluation order has been reversed
    is_reversed: bool,
}
impl AggregateExprBuilder {
    /// Creates a builder for `fun` applied to `args`, with every optional
    /// setting at its default (no alias, empty schema, no ordering).
    pub fn new(fun: Arc<AggregateUDF>, args: Vec<Arc<dyn PhysicalExpr>>) -> Self {
        Self {
            fun,
            args,
            alias: None,
            schema: Arc::new(Schema::empty()),
            ordering_req: LexOrdering::default(),
            ignore_nulls: false,
            is_distinct: false,
            is_reversed: false,
        }
    }

    /// Validates the collected settings and produces the final
    /// [`AggregateFunctionExpr`].
    ///
    /// # Errors
    /// Fails when `args` is empty, no alias was set, the argument count does
    /// not match the function signature, or any data type cannot be resolved
    /// against the schema.
    pub fn build(self) -> Result<AggregateFunctionExpr> {
        let Self {
            fun,
            args,
            alias,
            schema,
            ordering_req,
            ignore_nulls,
            is_distinct,
            is_reversed,
        } = self;
        if args.is_empty() {
            return internal_err!("args should not be empty");
        }

        // Resolve the fields describing the ordering expressions, if any.
        let ordering_fields = if ordering_req.is_empty() {
            vec![]
        } else {
            let ordering_types = ordering_req
                .iter()
                .map(|sort_expr| sort_expr.expr.data_type(&schema))
                .collect::<Result<Vec<_>>>()?;
            utils::ordering_fields(ordering_req.as_ref(), &ordering_types)
        };

        // Resolve the argument types and validate them against the signature.
        let input_exprs_types = args
            .iter()
            .map(|arg| arg.data_type(&schema))
            .collect::<Result<Vec<_>>>()?;
        check_arg_count(
            fun.name(),
            &input_exprs_types,
            &fun.signature().type_signature,
        )?;

        let data_type = fun.return_type(&input_exprs_types)?;
        let is_nullable = fun.is_nullable();
        let Some(name) = alias else {
            return internal_err!("alias should be provided");
        };

        Ok(AggregateFunctionExpr {
            fun: Arc::unwrap_or_clone(fun),
            args,
            data_type,
            name,
            schema: Arc::unwrap_or_clone(schema),
            ordering_req,
            ignore_nulls,
            ordering_fields,
            is_distinct,
            input_types: input_exprs_types,
            is_reversed,
            is_nullable,
        })
    }

    /// Sets the output column name of the aggregate.
    pub fn alias(mut self, alias: impl Into<String>) -> Self {
        self.alias = Some(alias.into());
        self
    }

    /// Sets the schema that argument expressions are resolved against.
    pub fn schema(mut self, schema: SchemaRef) -> Self {
        self.schema = schema;
        self
    }

    /// Sets the `ORDER BY` requirement of the aggregate.
    pub fn order_by(mut self, order_by: LexOrdering) -> Self {
        self.ordering_req = order_by;
        self
    }

    /// Marks the aggregate as reversed.
    pub fn reversed(mut self) -> Self {
        self.is_reversed = true;
        self
    }

    /// Sets the reversed flag explicitly.
    pub fn with_reversed(mut self, is_reversed: bool) -> Self {
        self.is_reversed = is_reversed;
        self
    }

    /// Marks the aggregate as `DISTINCT`.
    pub fn distinct(mut self) -> Self {
        self.is_distinct = true;
        self
    }

    /// Sets the `DISTINCT` flag explicitly.
    pub fn with_distinct(mut self, is_distinct: bool) -> Self {
        self.is_distinct = is_distinct;
        self
    }

    /// Marks the aggregate as ignoring NULL input values.
    pub fn ignore_nulls(mut self) -> Self {
        self.ignore_nulls = true;
        self
    }

    /// Sets the ignore-nulls flag explicitly.
    pub fn with_ignore_nulls(mut self, ignore_nulls: bool) -> Self {
        self.ignore_nulls = ignore_nulls;
        self
    }
}
/// Physical aggregate expression: an [`AggregateUDF`] bound to concrete
/// argument expressions, resolved types, and evaluation flags.
///
/// Constructed via [`AggregateExprBuilder`].
#[derive(Debug, Clone)]
pub struct AggregateFunctionExpr {
    /// The aggregate UDF being evaluated
    fun: AggregateUDF,
    /// Physical expressions producing the input arguments
    args: Vec<Arc<dyn PhysicalExpr>>,
    /// Resolved output data type
    data_type: DataType,
    /// Output column name
    name: String,
    /// Schema the expressions were resolved against
    schema: Schema,
    /// `ORDER BY` requirement attached to the aggregate
    ordering_req: LexOrdering,
    /// Whether NULL input values are ignored
    ignore_nulls: bool,
    /// Fields describing the ordering expressions (empty when no ordering)
    ordering_fields: Vec<Field>,
    /// Whether this is a `DISTINCT` aggregate
    is_distinct: bool,
    /// Whether the aggregate evaluation order has been reversed
    is_reversed: bool,
    /// Resolved data types of the input arguments
    input_types: Vec<DataType>,
    /// Whether the output may be NULL
    is_nullable: bool,
}
impl AggregateFunctionExpr {
pub fn fun(&self) -> &AggregateUDF {
&self.fun
}
pub fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
self.args.clone()
}
pub fn name(&self) -> &str {
&self.name
}
pub fn is_distinct(&self) -> bool {
self.is_distinct
}
pub fn ignore_nulls(&self) -> bool {
self.ignore_nulls
}
pub fn is_reversed(&self) -> bool {
self.is_reversed
}
pub fn is_nullable(&self) -> bool {
self.is_nullable
}
pub fn field(&self) -> Field {
Field::new(&self.name, self.data_type.clone(), self.is_nullable)
}
pub fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
let acc_args = AccumulatorArgs {
return_type: &self.data_type,
schema: &self.schema,
ignore_nulls: self.ignore_nulls,
ordering_req: self.ordering_req.as_ref(),
is_distinct: self.is_distinct,
name: &self.name,
is_reversed: self.is_reversed,
exprs: &self.args,
};
self.fun.accumulator(acc_args)
}
pub fn state_fields(&self) -> Result<Vec<Field>> {
let args = StateFieldsArgs {
name: &self.name,
input_types: &self.input_types,
return_type: &self.data_type,
ordering_fields: &self.ordering_fields,
is_distinct: self.is_distinct,
};
self.fun.state_fields(args)
}
pub fn order_bys(&self) -> Option<&LexOrdering> {
if self.ordering_req.is_empty() {
return None;
}
if !self.order_sensitivity().is_insensitive() {
return Some(self.ordering_req.as_ref());
}
None
}
pub fn order_sensitivity(&self) -> AggregateOrderSensitivity {
if !self.ordering_req.is_empty() {
self.fun.order_sensitivity()
} else {
AggregateOrderSensitivity::Insensitive
}
}
pub fn with_beneficial_ordering(
self: Arc<Self>,
beneficial_ordering: bool,
) -> Result<Option<AggregateFunctionExpr>> {
let Some(updated_fn) = self
.fun
.clone()
.with_beneficial_ordering(beneficial_ordering)?
else {
return Ok(None);
};
AggregateExprBuilder::new(Arc::new(updated_fn), self.args.to_vec())
.order_by(self.ordering_req.clone())
.schema(Arc::new(self.schema.clone()))
.alias(self.name().to_string())
.with_ignore_nulls(self.ignore_nulls)
.with_distinct(self.is_distinct)
.with_reversed(self.is_reversed)
.build()
.map(Some)
}
pub fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
let args = AccumulatorArgs {
return_type: &self.data_type,
schema: &self.schema,
ignore_nulls: self.ignore_nulls,
ordering_req: self.ordering_req.as_ref(),
is_distinct: self.is_distinct,
name: &self.name,
is_reversed: self.is_reversed,
exprs: &self.args,
};
let accumulator = self.fun.create_sliding_accumulator(args)?;
if !accumulator.supports_retract_batch() {
return not_impl_err!(
"Aggregate can not be used as a sliding accumulator because \
`retract_batch` is not implemented: {}",
self.name
);
}
Ok(accumulator)
}
pub fn groups_accumulator_supported(&self) -> bool {
let args = AccumulatorArgs {
return_type: &self.data_type,
schema: &self.schema,
ignore_nulls: self.ignore_nulls,
ordering_req: self.ordering_req.as_ref(),
is_distinct: self.is_distinct,
name: &self.name,
is_reversed: self.is_reversed,
exprs: &self.args,
};
self.fun.groups_accumulator_supported(args)
}
pub fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
let args = AccumulatorArgs {
return_type: &self.data_type,
schema: &self.schema,
ignore_nulls: self.ignore_nulls,
ordering_req: self.ordering_req.as_ref(),
is_distinct: self.is_distinct,
name: &self.name,
is_reversed: self.is_reversed,
exprs: &self.args,
};
self.fun.create_groups_accumulator(args)
}
pub fn reverse_expr(&self) -> Option<AggregateFunctionExpr> {
match self.fun.reverse_udf() {
ReversedUDAF::NotSupported => None,
ReversedUDAF::Identical => Some(self.clone()),
ReversedUDAF::Reversed(reverse_udf) => {
let reverse_ordering_req = reverse_order_bys(self.ordering_req.as_ref());
let mut name = self.name().to_string();
if self.fun().name() == reverse_udf.name() {
} else {
replace_order_by_clause(&mut name);
}
replace_fn_name_clause(&mut name, self.fun.name(), reverse_udf.name());
AggregateExprBuilder::new(reverse_udf, self.args.to_vec())
.order_by(reverse_ordering_req)
.schema(Arc::new(self.schema.clone()))
.alias(name)
.with_ignore_nulls(self.ignore_nulls)
.with_distinct(self.is_distinct)
.with_reversed(!self.is_reversed)
.build()
.ok()
}
}
}
pub fn all_expressions(&self) -> AggregatePhysicalExpressions {
let args = self.expressions();
let order_bys = self
.order_bys()
.cloned()
.unwrap_or_else(LexOrdering::default);
let order_by_exprs = order_bys
.iter()
.map(|sort_expr| Arc::clone(&sort_expr.expr))
.collect::<Vec<_>>();
AggregatePhysicalExpressions {
args,
order_by_exprs,
}
}
pub fn with_new_expressions(
&self,
_args: Vec<Arc<dyn PhysicalExpr>>,
_order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
) -> Option<AggregateFunctionExpr> {
None
}
pub fn get_minmax_desc(&self) -> Option<(Field, bool)> {
self.fun.is_descending().map(|flag| (self.field(), flag))
}
pub fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
self.fun.default_value(data_type)
}
pub fn set_monotonicity(&self) -> SetMonotonicity {
let field = self.field();
let data_type = field.data_type();
self.fun.inner().set_monotonicity(data_type)
}
pub fn get_result_ordering(&self, aggr_func_idx: usize) -> Option<PhysicalSortExpr> {
let monotonicity = self.set_monotonicity();
if monotonicity == SetMonotonicity::NotMonotonic {
return None;
}
let expr = Arc::new(Column::new(self.name(), aggr_func_idx));
let options =
SortOptions::new(monotonicity == SetMonotonicity::Decreasing, false);
Some(PhysicalSortExpr { expr, options })
}
}
/// The physical expressions referenced by an [`AggregateFunctionExpr`],
/// split into argument expressions and ORDER BY expressions.
pub struct AggregatePhysicalExpressions {
    /// Aggregate function argument expressions
    pub args: Vec<Arc<dyn PhysicalExpr>>,
    /// ORDER BY expressions
    pub order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
}
/// Two aggregate expressions are equal when their name, output type, UDF,
/// and argument expressions all match pairwise.
impl PartialEq for AggregateFunctionExpr {
    fn eq(&self, other: &Self) -> bool {
        if self.name != other.name
            || self.data_type != other.data_type
            || self.fun != other.fun
            || self.args.len() != other.args.len()
        {
            return false;
        }
        self.args
            .iter()
            .zip(other.args.iter())
            .all(|(lhs, rhs)| lhs.eq(rhs))
    }
}
/// Flips the direction of the first bracketed `ORDER BY [...]` clause in a
/// display name, e.g. `... ORDER BY [a ASC NULLS LAST]` becomes
/// `... ORDER BY [a DESC NULLS FIRST]`.
///
/// Only the trailing sort-option suffix inside the brackets is rewritten; if
/// no clause or no known suffix is found, the string is left untouched.
fn replace_order_by_clause(order_by: &mut String) {
    // (current suffix, reversed suffix) pairs; checked in this order.
    const SUFFIX_PAIRS: [(&str, &str); 4] = [
        (" DESC NULLS FIRST]", " ASC NULLS LAST]"),
        (" ASC NULLS FIRST]", " DESC NULLS LAST]"),
        (" DESC NULLS LAST]", " ASC NULLS FIRST]"),
        (" ASC NULLS LAST]", " DESC NULLS FIRST]"),
    ];
    let Some(clause_start) = order_by.find("ORDER BY [") else {
        return;
    };
    let Some(rel_end) = order_by[clause_start..].find(']') else {
        return;
    };
    // Span of the bracketed column list, including both brackets.
    let span = (clause_start + 9)..=(clause_start + rel_end);
    let bracketed = order_by[span.clone()].to_string();
    if let Some((suffix, replacement)) =
        SUFFIX_PAIRS.iter().find(|(suffix, _)| bracketed.ends_with(suffix))
    {
        let rewritten = bracketed.replace(suffix, replacement);
        order_by.replace_range(span, &rewritten);
    }
}
/// Substitutes every occurrence of `fn_name_old` in `aggr_name` with
/// `fn_name_new` (used when a reversed UDF has a different name).
fn replace_fn_name_clause(aggr_name: &mut String, fn_name_old: &str, fn_name_new: &str) {
    let renamed = aggr_name.replace(fn_name_old, fn_name_new);
    *aggr_name = renamed;
}