use crate::logical_plan::consumer::{
SubstraitConsumer, from_substrait_func_args, substrait_fun_name,
};
use datafusion::common::{DFSchema, ScalarValue, not_impl_datafusion_err, plan_err};
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{Expr, SortExpr, expr};
use std::sync::Arc;
use substrait::proto::AggregateFunction;
/// Converts a Substrait [`AggregateFunction`] into a DataFusion aggregate [`Expr`].
///
/// The function anchor (`f.function_reference`) is resolved through the consumer's
/// registered extensions to a Substrait function signature. The DataFusion function
/// name derived from that signature is then looked up as a UDAF in the consumer's
/// function registry. The Substrait arguments are converted against `input_schema`.
///
/// * `filter`, `order_by` and `distinct` are forwarded unchanged to
///   [`expr::AggregateFunction::new_udf`].
///
/// # Errors
///
/// * A plan error if the anchor has no entry in the consumer's extensions.
/// * A "not implemented" error if the registry has no UDAF with the resolved name.
/// * Any error produced while converting the function arguments.
pub async fn from_substrait_agg_func(
    consumer: &impl SubstraitConsumer,
    f: &AggregateFunction,
    input_schema: &DFSchema,
    filter: Option<Box<Expr>>,
    order_by: Vec<SortExpr>,
    distinct: bool,
) -> datafusion::common::Result<Arc<Expr>> {
    let anchor = &f.function_reference;

    // Resolve the anchor to the signature declared in the plan's extensions.
    let signature = match consumer.get_extensions().functions.get(anchor) {
        Some(sig) => sig,
        None => {
            return plan_err!(
                "Aggregate function not registered: function anchor = {:?}",
                anchor
            );
        }
    };

    // The registry error is replaced by one that names the Substrait signature.
    let udaf = consumer
        .get_function_registry()
        .udaf(substrait_fun_name(signature))
        .map_err(|_| {
            not_impl_datafusion_err!(
                "Aggregate function {} is not supported: function anchor = {:?}",
                signature,
                anchor
            )
        })?;

    let mut args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?;

    // A `count` with no arguments is rewritten to `count(1)`.
    // NOTE(review): presumably this is how producers encode `COUNT(*)`; confirm.
    if args.is_empty() && udaf.name() == "count" {
        args.push(Expr::Literal(ScalarValue::Int64(Some(1)), None));
    }

    let aggregate =
        expr::AggregateFunction::new_udf(udaf, args, distinct, filter, order_by, None);
    Ok(Arc::new(Expr::AggregateFunction(aggregate)))
}