use crate::logical_plan::producer::{
SubstraitProducer, from_aggregate_function, substrait_field_ref,
};
use datafusion::common::{DFSchemaRef, internal_err, not_impl_err};
use datafusion::logical_expr::expr::Alias;
use datafusion::logical_expr::utils::powerset;
use datafusion::logical_expr::{Aggregate, Distinct, Expr, GroupingSet};
use substrait::proto::aggregate_rel::{Grouping, Measure};
use substrait::proto::rel::RelType;
use substrait::proto::{AggregateRel, Expression, Rel};
/// Builds a Substrait [`AggregateRel`] from a DataFusion [`Aggregate`] node.
///
/// The aggregate's input plan is converted first via `producer.handle_plan`.
/// Grouping expressions and aggregate expressions are then both converted
/// against the schema of that input plan.
///
/// # Errors
///
/// Propagates the first error raised while converting the input plan, the
/// grouping expressions, or any of the aggregate expressions.
pub fn from_aggregate(
    producer: &mut impl SubstraitProducer,
    agg: &Aggregate,
) -> datafusion::common::Result<Box<Rel>> {
    let input = producer.handle_plan(agg.input.as_ref())?;

    // Grouping keys and measures are both expressed over the aggregate's input.
    let input_schema = agg.input.schema();
    let (grouping_expressions, groupings) =
        to_substrait_groupings(producer, &agg.group_expr, input_schema)?;

    let mut measures = Vec::with_capacity(agg.aggr_expr.len());
    for aggr_expr in &agg.aggr_expr {
        measures.push(to_substrait_agg_measure(producer, aggr_expr, input_schema)?);
    }

    let aggregate_rel = AggregateRel {
        common: None,
        input: Some(input),
        grouping_expressions,
        groupings,
        measures,
        advanced_extension: None,
    };
    Ok(Box::new(Rel {
        rel_type: Some(RelType::Aggregate(Box::new(aggregate_rel))),
    }))
}
pub fn from_distinct(
producer: &mut impl SubstraitProducer,
distinct: &Distinct,
) -> datafusion::common::Result<Box<Rel>> {
match distinct {
Distinct::All(plan) => {
let input = producer.handle_plan(plan.as_ref())?;
let grouping = (0..plan.schema().fields().len())
.map(substrait_field_ref)
.collect::<datafusion::common::Result<Vec<_>>>()?;
#[expect(deprecated)]
Ok(Box::new(Rel {
rel_type: Some(RelType::Aggregate(Box::new(AggregateRel {
common: None,
input: Some(input),
grouping_expressions: vec![],
groupings: vec![Grouping {
grouping_expressions: grouping,
expression_references: vec![],
}],
measures: vec![],
advanced_extension: None,
}))),
}))
}
Distinct::On(_) => not_impl_err!("Cannot convert Distinct::On"),
}
}
/// Converts the group-by expressions of an aggregate into Substrait groupings.
///
/// Returns a pair of:
/// - the shared list of converted grouping expressions, and
/// - one [`Grouping`] per grouping set, produced by
///   [`parse_flat_grouping_exprs`].
///
/// When `exprs` is exactly one [`Expr::GroupingSet`], it is expanded:
/// - `Cube`: one grouping per subset returned by `powerset`.
/// - `GroupingSets`: one grouping per listed set, in the given order.
/// - `Rollup`: one grouping per prefix of the expression list, from the full
///   list down to the empty prefix.
///
/// Any other input is treated as a single flat grouping over all of `exprs`.
///
/// # Errors
///
/// Propagates errors from `powerset` and from converting any expression.
pub fn to_substrait_groupings(
    producer: &mut impl SubstraitProducer,
    exprs: &[Expr],
    schema: &DFSchemaRef,
) -> datafusion::common::Result<(Vec<Expression>, Vec<Grouping>)> {
    let mut ref_group_exprs = vec![];
    let groupings = match exprs {
        // The inner match stays exhaustive on purpose: a new `GroupingSet`
        // variant should be a compile error, not a silent flat fallback.
        [Expr::GroupingSet(grouping_set)] => match grouping_set {
            GroupingSet::Cube(set) => powerset(set)?
                .iter()
                .map(|subset| {
                    // `powerset` yields borrowed elements; the converter
                    // needs a contiguous slice of owned expressions.
                    let owned: Vec<Expr> =
                        subset.iter().map(|expr| (*expr).clone()).collect();
                    parse_flat_grouping_exprs(
                        producer,
                        &owned,
                        schema,
                        &mut ref_group_exprs,
                    )
                })
                .collect::<datafusion::common::Result<Vec<_>>>()?,
            GroupingSet::GroupingSets(sets) => sets
                .iter()
                .map(|set| {
                    parse_flat_grouping_exprs(
                        producer,
                        set,
                        schema,
                        &mut ref_group_exprs,
                    )
                })
                .collect::<datafusion::common::Result<Vec<_>>>()?,
            // Walk the prefixes from longest to empty as borrowed slices,
            // so no expression has to be cloned.
            GroupingSet::Rollup(set) => (0..=set.len())
                .rev()
                .map(|prefix_len| {
                    parse_flat_grouping_exprs(
                        producer,
                        &set[..prefix_len],
                        schema,
                        &mut ref_group_exprs,
                    )
                })
                .collect::<datafusion::common::Result<Vec<_>>>()?,
        },
        _ => vec![parse_flat_grouping_exprs(
            producer,
            exprs,
            schema,
            &mut ref_group_exprs,
        )?],
    };
    Ok((ref_group_exprs, groupings))
}
pub fn parse_flat_grouping_exprs(
producer: &mut impl SubstraitProducer,
exprs: &[Expr],
schema: &DFSchemaRef,
ref_group_exprs: &mut Vec<Expression>,
) -> datafusion::common::Result<Grouping> {
let mut expression_references = vec![];
let mut grouping_expressions = vec![];
for e in exprs {
let rex = producer.handle_expr(e, schema)?;
grouping_expressions.push(rex.clone());
ref_group_exprs.push(rex);
expression_references.push((ref_group_exprs.len() - 1) as u32);
}
#[expect(deprecated)]
Ok(Grouping {
grouping_expressions,
expression_references,
})
}
/// Converts an aggregate expression into a Substrait [`Measure`].
///
/// Any number of enclosing [`Expr::Alias`] wrappers are stripped first. The
/// remaining expression must be an [`Expr::AggregateFunction`], which is handed
/// to `from_aggregate_function` together with `schema`.
///
/// # Errors
///
/// Returns an internal error if the unwrapped expression is not an aggregate
/// function, and propagates any error from `from_aggregate_function`.
pub fn to_substrait_agg_measure(
    producer: &mut impl SubstraitProducer,
    expr: &Expr,
    schema: &DFSchemaRef,
) -> datafusion::common::Result<Measure> {
    // Peel aliases iteratively until the underlying expression is reached.
    let mut unaliased = expr;
    while let Expr::Alias(Alias { expr: inner, .. }) = unaliased {
        unaliased = inner.as_ref();
    }

    if let Expr::AggregateFunction(agg_fn) = unaliased {
        return from_aggregate_function(producer, agg_fn, schema);
    }

    internal_err!(
        "Expression must be compatible with aggregation. Unsupported expression: {:?}. Expressiontype: {}",
        unaliased,
        unaliased.variant_name()
    )
}