use crate::input::proto::substrait;
use crate::output::comment;
use crate::output::diagnostic;
use crate::output::type_system::data;
use crate::parse::context;
use crate::parse::expressions;
use crate::parse::expressions::functions;
use std::collections::HashSet;
/// Role that a field plays in the schema returned by an aggregate relation.
enum FieldType {
/// Value of a grouping expression that appears in every grouping set. The
/// field keeps the data type of the expression in the output schema.
GroupedField,
/// Value of a grouping expression that does not appear in every grouping
/// set. This is the initial classification of every grouping expression;
/// the field is made nullable in the output schema and is described as
/// null when it is not part of the grouping set being returned.
NullableGroupedField,
/// Result of an aggregate function (measure).
Measure,
/// Integer index of the matched grouping set. Only appended when the
/// relation defines more than one grouping set.
GroupingSetIndex,
}
/// A single field of the schema returned by an aggregate relation.
struct Field {
/// Expression that yields the value of the field; used when describing
/// the significance of the field in the relation summary.
expression: expressions::Expression,
/// Data type recorded for the expression when it was parsed. Nullability
/// for `NullableGroupedField` fields is applied later, when the output
/// schema is built.
data_type: data::Type,
/// Role of the field within the output schema.
field_type: FieldType,
}
fn parse_measure(
x: &substrait::aggregate_rel::Measure,
y: &mut context::Context,
) -> diagnostic::Result<expressions::Expression> {
let (n, e) = proto_required_field!(x, y, measure, functions::parse_aggregate_function);
let data_type = n.data_type();
let expression = e.unwrap_or_default();
y.set_data_type(data_type.clone());
if x.filter.is_some() {
let (n, e) = proto_required_field!(x, y, filter, expressions::parse_predicate);
let filter = e.unwrap_or_default();
let filter_type = n.data_type();
summary!(
y,
"Applies aggregate function {expression:#} to all rows for \
which {filter:#} returns true."
);
let filtered_expression = expressions::Expression::Function(
String::from("filter"),
vec![
expressions::functions::FunctionArgument::Value(filter_type, filter),
expressions::functions::FunctionArgument::Value(data_type, expression),
],
);
describe!(
y,
Expression,
"Filtered aggregate function: {filtered_expression}"
);
Ok(filtered_expression)
} else {
summary!(y, "Applies aggregate function {expression:#} to all rows.");
describe!(y, Expression, "Aggregate function: {expression}");
Ok(expression)
}
}
/// Parses and validates an aggregate relation.
///
/// The output schema is built from, in order:
///
///  1. one field per *distinct* grouping expression, in order of first
///     appearance across all grouping sets (nullable unless the expression
///     is part of every grouping set);
///  2. one field per measure;
///  3. an integer grouping set index, only when more than one grouping set
///     is specified.
///
/// * `x` - the aggregate relation message to parse.
/// * `y` - the parse context; receives the resulting schema, the relation
///   description and summary, and any diagnostics.
///
/// An error-level diagnostic is pushed to `y` when the relation yields no
/// fields at all. The function itself returns `Ok(())` when it reaches the
/// end of its body.
// NOTE(review): `allow(deprecated)` is presumably needed because one of the
// protobuf fields accessed below is marked deprecated - confirm which one.
#[allow(deprecated)]
pub fn parse_aggregate_rel(
    x: &substrait::AggregateRel,
    y: &mut context::Context,
) -> diagnostic::Result<()> {
    // Parse the input and use its type as the current schema while the
    // grouping expressions and measures are parsed.
    let in_type = handle_rel_input!(x, y);
    y.set_schema(in_type);

    // Distinct grouping expressions seen so far; the position of an
    // expression in this vector equals the index of its field in `fields`.
    let mut grouping_set_expressions: Vec<substrait::Expression> = vec![];
    // Output fields, in schema order.
    let mut fields = vec![];
    // For each grouping set, the indices into `fields` of its expressions.
    let mut sets = vec![];

    proto_repeated_field!(x, y, groupings, |x, y| {
        sets.push(vec![]);
        proto_repeated_field!(x, y, grouping_expressions, |x, y| {
            let result = expressions::parse_expression(x, y);

            // Reuse the field of an identical expression that was seen
            // before; otherwise register a new field. Fields start out as
            // nullable and are promoted below if they occur in all sets.
            let index = grouping_set_expressions
                .iter()
                .position(|e| e == x)
                .unwrap_or_else(|| {
                    grouping_set_expressions.push(x.clone());
                    fields.push(Field {
                        expression: result.as_ref().cloned().unwrap_or_default(),
                        data_type: y.data_type(),
                        field_type: FieldType::NullableGroupedField,
                    });
                    fields.len() - 1
                });
            sets.last_mut()
                .expect("a set is pushed for every grouping before its expressions are parsed")
                .push(index);
            result
        });
        match x.grouping_expressions.len() {
            0 => summary!(y, "A grouping set that aggregates all rows."),
            1 => summary!(
                y,
                "A grouping set that aggregates all rows for which \
                the expression yields the same value."
            ),
            n => summary!(
                y,
                "A grouping set that aggregates all rows for which \
                the {n} expressions yield the same tuple of values."
            ),
        }
        Ok(())
    });
    drop(grouping_set_expressions);

    // A grouping field is only non-nullable when it is part of every
    // grouping set, so intersect the index sets of all grouping sets.
    let sets = sets;
    let mut set_iter = sets.iter();
    if let Some(first_set) = set_iter.next() {
        let mut fields_in_all_sets = first_set.iter().copied().collect::<HashSet<_>>();
        for set in set_iter {
            fields_in_all_sets = &fields_in_all_sets & &set.iter().copied().collect::<HashSet<_>>();
        }
        for index in fields_in_all_sets {
            fields[index].field_type = FieldType::GroupedField;
        }
    }

    // Measure fields follow the grouping fields.
    proto_repeated_field!(x, y, measures, |x, y| {
        let result = parse_measure(x, y);
        fields.push(Field {
            expression: result.as_ref().cloned().unwrap_or_default(),
            data_type: y.data_type(),
            field_type: FieldType::Measure,
        });
        result
    });

    // The relation is invalid if it does not yield any field. This is
    // checked before the grouping set index is appended.
    if fields.is_empty() {
        diagnostic!(
            y,
            Error,
            RelationInvalid,
            "aggregate relations must have at least one grouping expression or measure"
        );
    }

    // With more than one grouping set, an extra field identifies the set
    // that a returned row belongs to.
    if sets.len() > 1 {
        fields.push(Field {
            expression: expressions::Expression::Function(String::from("group_index"), vec![]),
            data_type: data::new_integer(),
            field_type: FieldType::GroupingSetIndex,
        });
    }

    // Derive the output schema; only grouping fields that are missing from
    // at least one grouping set are made nullable.
    let fields = fields;
    y.set_schema(data::new_struct(
        fields.iter().map(|field| {
            if matches!(field.field_type, FieldType::NullableGroupedField) {
                field.data_type.make_nullable()
            } else {
                field.data_type.clone()
            }
        }),
        false,
    ));

    // Describe the relation as a whole.
    if x.groupings.is_empty() {
        describe!(y, Relation, "Aggregate");
        summary!(
            y,
            "This relation computes {} aggregate function(s) over all rows, \
            returning a single row.",
            x.measures.len()
        );
    } else if x.measures.is_empty() {
        describe!(y, Relation, "Group");
        summary!(
            y,
            "This relation groups rows from the input by the result of some \
            expression(s)."
        );
    } else {
        describe!(y, Relation, "Group & aggregate");
        summary!(
            y,
            "This relation groups rows from the input by the result of some \
            expression(s), and also computes {} aggregate function(s) over \
            each group.",
            x.measures.len()
        );
    }

    // Explain the significance of each returned field in an ordered list.
    let mut comment = comment::Comment::new()
        .plain("The significance of the returned field(s) is:")
        .lo();
    for (index, field) in fields.iter().enumerate() {
        comment = comment.li().plain(match field.field_type {
            FieldType::GroupedField => format!(
                "Field {index}: value of grouping expression {:#}.",
                field.expression
            ),
            FieldType::NullableGroupedField => format!(
                "Field {index}: value of grouping expression {:#} if it is \
                part of the grouping set being returned, null otherwise.",
                field.expression
            ),
            FieldType::Measure => {
                if x.groupings.is_empty() {
                    format!(
                        "Field {index}: result of aggregate function {:#} \
                        applied to all input rows.",
                        field.expression
                    )
                } else {
                    format!(
                        "Field {index}: result of aggregate function {:#} \
                        applied to the rows from the current group.",
                        field.expression
                    )
                }
            }
            FieldType::GroupingSetIndex => {
                // This field only exists when there is more than one
                // grouping set, so the subtraction cannot underflow.
                format!(
                    "Field {index}: integer between 0 and {} inclusive, \
                    representing the index of the matched grouping set.",
                    x.groupings.len() - 1
                )
            }
        });
    }
    y.push_summary(comment.lc());

    handle_rel_common!(x, y);
    handle_advanced_extension!(x, y);
    Ok(())
}