use crate::db::sql::lowering::{
LoweredBaseQueryShape, LoweredSqlCommand, LoweredSqlCommandInner, PreparedSqlStatement,
SqlLoweringError, analyze_lowered_expr,
predicate::{lower_sql_where_bool_expr, lower_sql_where_expr},
};
#[cfg(test)]
use crate::{db::query::intent::Query, traits::EntityKind};
use crate::{
db::{
predicate::MissingRowPolicy,
query::{
builder::{
AggregateExpr,
aggregate::{avg, count, count_by, max_by, min_by, sum},
},
intent::StructuralQuery,
plan::{
AggregateKind, FieldSlot,
expr::{
Alias, Expr, ProjectionField, ProjectionSpec,
canonicalize_aggregate_input_expr, compile_scalar_projection_expr,
},
lower_global_aggregate_projection, resolve_aggregate_target_field_slot,
},
},
sql::{
lowering::expr::{SqlExprPhase, lower_sql_expr},
lowering::select::{
lower_global_aggregate_having_expr, lower_order_terms, lower_select_item_expr,
},
parser::{
SqlAggregateCall, SqlAggregateKind, SqlExplainMode, SqlExpr, SqlProjection,
SqlSelectItem, SqlSelectStatement, SqlStatement,
},
},
},
model::entity::EntityModel,
};
/// One aggregate terminal of a global (ungrouped) aggregate `SELECT`.
///
/// `*Field` variants target a named entity field, `*Expr` variants carry an
/// already-lowered input expression, and `CountRows` counts rows without any
/// input. `filter_expr` is the optional per-aggregate filter expression and
/// `distinct` marks DISTINCT input; MIN/MAX variants carry no DISTINCT flag.
///
/// `PartialEq` is used to deduplicate terminals that appear more than once in
/// the projection and HAVING clause.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum SqlGlobalAggregateTerminal {
    /// Row count with no input field or expression.
    CountRows {
        filter_expr: Option<Expr>,
    },
    /// COUNT over a named field.
    CountField {
        field: String,
        filter_expr: Option<Expr>,
        distinct: bool,
    },
    /// COUNT over a lowered input expression.
    CountExpr {
        input_expr: Expr,
        filter_expr: Option<Expr>,
        distinct: bool,
    },
    /// SUM over a named field.
    SumField {
        field: String,
        filter_expr: Option<Expr>,
        distinct: bool,
    },
    /// SUM over a lowered input expression.
    SumExpr {
        input_expr: Expr,
        filter_expr: Option<Expr>,
        distinct: bool,
    },
    /// AVG over a named field.
    AvgField {
        field: String,
        filter_expr: Option<Expr>,
        distinct: bool,
    },
    /// AVG over a lowered input expression.
    AvgExpr {
        input_expr: Expr,
        filter_expr: Option<Expr>,
        distinct: bool,
    },
    /// MIN over a named field.
    MinField {
        field: String,
        filter_expr: Option<Expr>,
    },
    /// MIN over a lowered input expression.
    MinExpr {
        input_expr: Expr,
        filter_expr: Option<Expr>,
    },
    /// MAX over a named field.
    MaxField {
        field: String,
        filter_expr: Option<Expr>,
    },
    /// MAX over a lowered input expression.
    MaxExpr {
        input_expr: Expr,
        filter_expr: Option<Expr>,
    },
}
/// Value domain a prepared scalar aggregate operates over.
///
/// Assigned per descriptor shape: COUNT(*) → `ExistingRows`, COUNT(field/expr)
/// → `ProjectionField`, SUM/AVG → `NumericField`, MIN/MAX → `ScalarExtremaValue`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum PreparedSqlScalarAggregateDomain {
    ExistingRows,
    ProjectionField,
    NumericField,
    ScalarExtremaValue,
}
/// Whether a prepared scalar aggregate needs field ordering.
///
/// Only the MIN/MAX descriptor shapes are assigned `FieldOrder`; all others use `None`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum PreparedSqlScalarAggregateOrderingRequirement {
    None,
    FieldOrder,
}
/// Where a prepared scalar aggregate reads its per-row input from.
///
/// COUNT(*) → `ExistingRows`, COUNT(field/expr) → `ProjectedField`,
/// SUM/AVG → `NumericField`, MIN/MAX → `ExtremalWinnerField`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum PreparedSqlScalarAggregateRowSource {
    ExistingRows,
    ProjectedField,
    NumericField,
    ExtremalWinnerField,
}
/// Result tag for an aggregate over an empty input set.
///
/// The COUNT descriptor shapes are tagged `Zero`; SUM/AVG/MIN/MAX are tagged `Null`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum PreparedSqlScalarAggregateEmptySetBehavior {
    Zero,
    Null,
}
/// Canonical shape of a prepared scalar aggregate.
///
/// Field-target and expression-input terminals of the same kind share a shape
/// (e.g. both `SumField` and `SumExpr` terminals prepare as `SumField`); the
/// remaining policy fields of a strategy are derived from this shape.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum PreparedSqlScalarAggregateDescriptorShape {
    CountRows,
    CountField,
    SumField,
    AvgField,
    MinField,
    MaxField,
}
/// Runtime-facing descriptor derived from a descriptor shape.
///
/// SUM/AVG collapse into `NumericField` and MIN/MAX into `ExtremalWinnerField`,
/// each carrying the concrete `AggregateKind`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum PreparedSqlScalarAggregateRuntimeDescriptor {
    CountRows,
    CountField,
    NumericField { kind: AggregateKind },
    ExtremalWinnerField { kind: AggregateKind },
}
/// Fixed policy bundle looked up from a descriptor shape when a strategy is
/// built; its fields are copied into the strategy.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct PreparedSqlScalarAggregateDescriptorPolicy {
    domain: PreparedSqlScalarAggregateDomain,
    ordering_requirement: PreparedSqlScalarAggregateOrderingRequirement,
    row_source: PreparedSqlScalarAggregateRowSource,
    empty_set_behavior: PreparedSqlScalarAggregateEmptySetBehavior,
}
/// Model-validated preparation of one global aggregate terminal.
///
/// Exactly one of `target_slot` (field target resolved against the model) or
/// `input_expr` (validated input expression) is set, except for COUNT(*) where
/// both are `None`. The policy fields are derived from `descriptor_shape`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) struct PreparedSqlScalarAggregateStrategy {
    /// Resolved field slot for field-target aggregates.
    target_slot: Option<FieldSlot>,
    /// Input expression for expression-input aggregates.
    input_expr: Option<Expr>,
    /// Optional per-aggregate filter expression.
    filter_expr: Option<Expr>,
    /// Whether the aggregate input is DISTINCT.
    distinct_input: bool,
    domain: PreparedSqlScalarAggregateDomain,
    ordering_requirement: PreparedSqlScalarAggregateOrderingRequirement,
    row_source: PreparedSqlScalarAggregateRowSource,
    empty_set_behavior: PreparedSqlScalarAggregateEmptySetBehavior,
    descriptor_shape: PreparedSqlScalarAggregateDescriptorShape,
}
impl PreparedSqlScalarAggregateStrategy {
const fn descriptor_policy(
descriptor_shape: PreparedSqlScalarAggregateDescriptorShape,
) -> PreparedSqlScalarAggregateDescriptorPolicy {
match descriptor_shape {
PreparedSqlScalarAggregateDescriptorShape::CountRows => {
PreparedSqlScalarAggregateDescriptorPolicy {
domain: PreparedSqlScalarAggregateDomain::ExistingRows,
ordering_requirement: PreparedSqlScalarAggregateOrderingRequirement::None,
row_source: PreparedSqlScalarAggregateRowSource::ExistingRows,
empty_set_behavior: PreparedSqlScalarAggregateEmptySetBehavior::Zero,
}
}
PreparedSqlScalarAggregateDescriptorShape::CountField => {
PreparedSqlScalarAggregateDescriptorPolicy {
domain: PreparedSqlScalarAggregateDomain::ProjectionField,
ordering_requirement: PreparedSqlScalarAggregateOrderingRequirement::None,
row_source: PreparedSqlScalarAggregateRowSource::ProjectedField,
empty_set_behavior: PreparedSqlScalarAggregateEmptySetBehavior::Zero,
}
}
PreparedSqlScalarAggregateDescriptorShape::SumField
| PreparedSqlScalarAggregateDescriptorShape::AvgField => {
PreparedSqlScalarAggregateDescriptorPolicy {
domain: PreparedSqlScalarAggregateDomain::NumericField,
ordering_requirement: PreparedSqlScalarAggregateOrderingRequirement::None,
row_source: PreparedSqlScalarAggregateRowSource::NumericField,
empty_set_behavior: PreparedSqlScalarAggregateEmptySetBehavior::Null,
}
}
PreparedSqlScalarAggregateDescriptorShape::MinField
| PreparedSqlScalarAggregateDescriptorShape::MaxField => {
PreparedSqlScalarAggregateDescriptorPolicy {
domain: PreparedSqlScalarAggregateDomain::ScalarExtremaValue,
ordering_requirement: PreparedSqlScalarAggregateOrderingRequirement::FieldOrder,
row_source: PreparedSqlScalarAggregateRowSource::ExtremalWinnerField,
empty_set_behavior: PreparedSqlScalarAggregateEmptySetBehavior::Null,
}
}
}
}
pub(in crate::db) const fn from_resolved_shape(
target_slot: Option<FieldSlot>,
input_expr: Option<Expr>,
filter_expr: Option<Expr>,
distinct_input: bool,
descriptor_shape: PreparedSqlScalarAggregateDescriptorShape,
) -> Self {
let policy = Self::descriptor_policy(descriptor_shape);
Self {
target_slot,
input_expr,
filter_expr,
distinct_input,
domain: policy.domain,
ordering_requirement: policy.ordering_requirement,
row_source: policy.row_source,
empty_set_behavior: policy.empty_set_behavior,
descriptor_shape,
}
}
#[expect(
clippy::too_many_lines,
reason = "aggregate terminal preparation keeps field and expression variants on one owner-local boundary"
)]
fn from_lowered_terminal(
model: &'static EntityModel,
terminal: &SqlGlobalAggregateTerminal,
) -> Result<Self, SqlLoweringError> {
let resolve_target_slot = |field: &str| {
resolve_aggregate_target_field_slot(model, field).map_err(SqlLoweringError::from)
};
let validate_input_expr = |input_expr: &Expr| {
if let Some(field) = analyze_lowered_expr(input_expr, Some(model)).first_unknown_field()
{
return Err(SqlLoweringError::unknown_field(field));
}
if compile_scalar_projection_expr(model, input_expr).is_none() {
return Err(SqlLoweringError::unsupported_aggregate_input_expressions());
}
Ok(())
};
match terminal {
SqlGlobalAggregateTerminal::CountRows { filter_expr } => Ok(Self::from_resolved_shape(
None,
None,
filter_expr.clone(),
false,
PreparedSqlScalarAggregateDescriptorShape::CountRows,
)),
SqlGlobalAggregateTerminal::CountField {
field,
filter_expr,
distinct,
} => {
let target_slot = resolve_target_slot(field.as_str())?;
Ok(Self::from_resolved_shape(
Some(target_slot),
None,
filter_expr.clone(),
*distinct,
PreparedSqlScalarAggregateDescriptorShape::CountField,
))
}
SqlGlobalAggregateTerminal::CountExpr {
input_expr,
filter_expr,
distinct,
} => {
validate_input_expr(input_expr)?;
Ok(Self::from_resolved_shape(
None,
Some(input_expr.clone()),
filter_expr.clone(),
*distinct,
PreparedSqlScalarAggregateDescriptorShape::CountField,
))
}
SqlGlobalAggregateTerminal::SumField {
field,
filter_expr,
distinct,
} => {
let target_slot = resolve_target_slot(field.as_str())?;
Ok(Self::from_resolved_shape(
Some(target_slot),
None,
filter_expr.clone(),
*distinct,
PreparedSqlScalarAggregateDescriptorShape::SumField,
))
}
SqlGlobalAggregateTerminal::SumExpr {
input_expr,
filter_expr,
distinct,
} => {
validate_input_expr(input_expr)?;
Ok(Self::from_resolved_shape(
None,
Some(input_expr.clone()),
filter_expr.clone(),
*distinct,
PreparedSqlScalarAggregateDescriptorShape::SumField,
))
}
SqlGlobalAggregateTerminal::AvgField {
field,
filter_expr,
distinct,
} => {
let target_slot = resolve_target_slot(field.as_str())?;
Ok(Self::from_resolved_shape(
Some(target_slot),
None,
filter_expr.clone(),
*distinct,
PreparedSqlScalarAggregateDescriptorShape::AvgField,
))
}
SqlGlobalAggregateTerminal::AvgExpr {
input_expr,
filter_expr,
distinct,
} => {
validate_input_expr(input_expr)?;
Ok(Self::from_resolved_shape(
None,
Some(input_expr.clone()),
filter_expr.clone(),
*distinct,
PreparedSqlScalarAggregateDescriptorShape::AvgField,
))
}
SqlGlobalAggregateTerminal::MinField { field, filter_expr } => {
let target_slot = resolve_target_slot(field.as_str())?;
Ok(Self::from_resolved_shape(
Some(target_slot),
None,
filter_expr.clone(),
false,
PreparedSqlScalarAggregateDescriptorShape::MinField,
))
}
SqlGlobalAggregateTerminal::MinExpr {
input_expr,
filter_expr,
} => {
validate_input_expr(input_expr)?;
Ok(Self::from_resolved_shape(
None,
Some(input_expr.clone()),
filter_expr.clone(),
false,
PreparedSqlScalarAggregateDescriptorShape::MinField,
))
}
SqlGlobalAggregateTerminal::MaxField { field, filter_expr } => {
let target_slot = resolve_target_slot(field.as_str())?;
Ok(Self::from_resolved_shape(
Some(target_slot),
None,
filter_expr.clone(),
false,
PreparedSqlScalarAggregateDescriptorShape::MaxField,
))
}
SqlGlobalAggregateTerminal::MaxExpr {
input_expr,
filter_expr,
} => {
validate_input_expr(input_expr)?;
Ok(Self::from_resolved_shape(
None,
Some(input_expr.clone()),
filter_expr.clone(),
false,
PreparedSqlScalarAggregateDescriptorShape::MaxField,
))
}
}
}
#[must_use]
pub(crate) const fn target_slot(&self) -> Option<&FieldSlot> {
self.target_slot.as_ref()
}
#[must_use]
pub(crate) const fn input_expr(&self) -> Option<&Expr> {
self.input_expr.as_ref()
}
#[must_use]
pub(crate) const fn filter_expr(&self) -> Option<&Expr> {
self.filter_expr.as_ref()
}
#[must_use]
pub(crate) const fn is_distinct(&self) -> bool {
self.distinct_input
}
#[cfg(test)]
#[must_use]
pub(crate) const fn domain(&self) -> PreparedSqlScalarAggregateDomain {
self.domain
}
#[cfg(test)]
#[must_use]
pub(crate) const fn descriptor_shape(&self) -> PreparedSqlScalarAggregateDescriptorShape {
self.descriptor_shape
}
#[must_use]
pub(crate) const fn runtime_descriptor(&self) -> PreparedSqlScalarAggregateRuntimeDescriptor {
match self.descriptor_shape {
PreparedSqlScalarAggregateDescriptorShape::CountRows => {
PreparedSqlScalarAggregateRuntimeDescriptor::CountRows
}
PreparedSqlScalarAggregateDescriptorShape::CountField => {
PreparedSqlScalarAggregateRuntimeDescriptor::CountField
}
PreparedSqlScalarAggregateDescriptorShape::SumField => {
PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
kind: AggregateKind::Sum,
}
}
PreparedSqlScalarAggregateDescriptorShape::AvgField => {
PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
kind: AggregateKind::Avg,
}
}
PreparedSqlScalarAggregateDescriptorShape::MinField => {
PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
kind: AggregateKind::Min,
}
}
PreparedSqlScalarAggregateDescriptorShape::MaxField => {
PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
kind: AggregateKind::Max,
}
}
}
}
#[must_use]
pub(crate) const fn aggregate_kind(&self) -> AggregateKind {
match self.descriptor_shape {
PreparedSqlScalarAggregateDescriptorShape::CountRows
| PreparedSqlScalarAggregateDescriptorShape::CountField => AggregateKind::Count,
PreparedSqlScalarAggregateDescriptorShape::SumField => AggregateKind::Sum,
PreparedSqlScalarAggregateDescriptorShape::AvgField => AggregateKind::Avg,
PreparedSqlScalarAggregateDescriptorShape::MinField => AggregateKind::Min,
PreparedSqlScalarAggregateDescriptorShape::MaxField => AggregateKind::Max,
}
}
#[must_use]
pub(crate) fn projected_field(&self) -> Option<&str> {
self.target_slot().map(FieldSlot::field)
}
#[cfg(test)]
#[must_use]
pub(crate) const fn ordering_requirement(
&self,
) -> PreparedSqlScalarAggregateOrderingRequirement {
self.ordering_requirement
}
#[cfg(test)]
#[must_use]
pub(crate) const fn row_source(&self) -> PreparedSqlScalarAggregateRowSource {
self.row_source
}
#[cfg(test)]
#[must_use]
pub(crate) const fn empty_set_behavior(&self) -> PreparedSqlScalarAggregateEmptySetBehavior {
self.empty_set_behavior
}
}
/// Entity-agnostic result of lowering a global aggregate `SELECT`.
///
/// Bound to a concrete entity later, either structurally (`into_structural`)
/// or, in tests, as a typed command (`into_typed`).
#[derive(Clone, Debug)]
pub(crate) struct LoweredSqlGlobalAggregateCommand {
    /// Lowered WHERE / ORDER BY / LIMIT / OFFSET shape.
    pub(in crate::db::sql::lowering) query: LoweredBaseQueryShape,
    /// Unique aggregate terminals referenced by the projection and HAVING.
    pub(in crate::db::sql::lowering) terminals: Vec<SqlGlobalAggregateTerminal>,
    /// Post-aggregate projection over the terminals.
    pub(in crate::db::sql::lowering) projection: ProjectionSpec,
    /// Optional lowered HAVING expression.
    pub(in crate::db::sql::lowering) having: Option<Expr>,
    /// Test-only: terminal index per output column; empty when any column wraps
    /// its aggregate(s) in a larger expression.
    #[cfg(test)]
    pub(in crate::db::sql::lowering) output_remap: Vec<usize>,
}
impl LoweredSqlGlobalAggregateCommand {
fn from_select_statement(statement: SqlSelectStatement) -> Result<Self, SqlLoweringError> {
let SqlSelectStatement {
projection,
projection_aliases,
predicate,
distinct,
group_by,
having,
order_by,
limit,
offset,
entity: _,
} = statement;
if distinct {
return Err(SqlLoweringError::unsupported_select_distinct());
}
if !group_by.is_empty() {
return Err(SqlLoweringError::global_aggregate_does_not_support_group_by());
}
let projection_for_having = projection.clone();
let order_by = strip_inert_global_aggregate_output_order_terms(
order_by,
&projection_for_having,
projection_aliases.as_slice(),
)?;
let mut lowered_terminals =
LoweredSqlGlobalAggregateTerminals::from_projection(projection, &projection_aliases)?;
let having =
lower_global_aggregate_having_expr(having, &projection_for_having, |aggregate| {
resolve_or_insert_global_aggregate_terminal_index_from_expr(
&mut lowered_terminals.terminals,
aggregate,
)
})?;
Ok(Self {
query: LoweredBaseQueryShape {
filter_expr: predicate
.as_ref()
.map(lower_sql_where_bool_expr)
.transpose()?,
predicate: predicate.as_ref().map(lower_sql_where_expr).transpose()?,
order_by: lower_order_terms(order_by)?,
limit,
offset,
},
terminals: lowered_terminals.terminals,
projection: lowered_terminals.projection,
having,
#[cfg(test)]
output_remap: lowered_terminals.output_remap,
})
}
#[cfg(test)]
fn into_typed<E: EntityKind>(
self,
consistency: MissingRowPolicy,
) -> Result<SqlGlobalAggregateCommand<E>, SqlLoweringError> {
let Self {
query,
terminals,
projection,
having,
output_remap,
} = self;
let terminals = terminals
.iter()
.map(|terminal| {
PreparedSqlScalarAggregateStrategy::from_lowered_terminal(E::MODEL, terminal)
})
.collect::<Result<Vec<_>, _>>()?;
Ok(SqlGlobalAggregateCommand {
query: Query::from_inner(crate::db::sql::lowering::apply_lowered_base_query_shape(
StructuralQuery::new(E::MODEL, consistency),
query,
)),
terminals,
projection,
having,
output_remap,
})
}
fn into_structural(
self,
model: &'static EntityModel,
consistency: MissingRowPolicy,
) -> Result<SqlGlobalAggregateCommandCore, SqlLoweringError> {
let Self {
query,
terminals,
projection,
having,
#[cfg(test)]
output_remap: _,
} = self;
let strategies = terminals
.iter()
.map(|terminal| {
PreparedSqlScalarAggregateStrategy::from_lowered_terminal(model, terminal)
})
.collect::<Result<Vec<_>, _>>()?;
Ok(SqlGlobalAggregateCommandCore {
query: crate::db::sql::lowering::apply_lowered_base_query_shape(
StructuralQuery::new(model, consistency),
query,
),
strategies,
projection,
having,
})
}
}
/// Removes ORDER BY terms that target an aggregate-only output column.
///
/// Terms whose expression equals an aggregate-only projection item (or that
/// item's alias) are treated as inert and dropped before ORDER BY lowering.
/// Surviving terms keep their original relative order.
///
/// # Errors
/// Propagates projection-item lowering errors raised while collecting targets.
fn strip_inert_global_aggregate_output_order_terms(
    mut order_by: Vec<crate::db::sql::parser::SqlOrderTerm>,
    projection: &SqlProjection,
    projection_aliases: &[Option<String>],
) -> Result<Vec<crate::db::sql::parser::SqlOrderTerm>, SqlLoweringError> {
    let inert_targets =
        collect_global_aggregate_output_order_targets(projection, projection_aliases)?;
    order_by.retain(|term| !inert_targets.contains(&term.field));
    Ok(order_by)
}
fn collect_global_aggregate_output_order_targets(
projection: &SqlProjection,
projection_aliases: &[Option<String>],
) -> Result<Vec<crate::db::sql::parser::SqlExpr>, SqlLoweringError> {
let SqlProjection::Items(items) = projection else {
return Ok(Vec::new());
};
let mut targets = Vec::with_capacity(items.len());
for (item, alias) in items.iter().zip(projection_aliases.iter()) {
let expr = lower_select_item_expr(item, SqlExprPhase::PostAggregate)?;
let analysis = analyze_lowered_expr(&expr, None);
if !analysis.contains_aggregate() || analysis.references_direct_fields() {
continue;
}
targets.push(crate::db::sql::parser::SqlExpr::from_select_item(item));
if let Some(alias) = alias {
targets.push(crate::db::sql::parser::SqlExpr::Field(alias.clone()));
}
}
Ok(targets)
}
/// Intermediate classification of a parsed aggregate call, produced by
/// `lower_sql_aggregate_shape` and consumed by `lower_aggregate_call`.
/// Filter expressions are already lowered at this point.
enum LoweredSqlAggregateShape {
    /// Star-input aggregate (no input expression), never DISTINCT.
    CountRows {
        filter_expr: Option<Expr>,
    },
    /// COUNT over a bare field reference.
    CountField {
        field: String,
        filter_expr: Option<Expr>,
        distinct: bool,
    },
    /// Non-COUNT aggregate over a bare field reference whose kind uses the
    /// shared field-target lowering.
    FieldTarget {
        kind: SqlAggregateKind,
        field: String,
        filter_expr: Option<Expr>,
        distinct: bool,
    },
    /// Aggregate over any other input, lowered and canonicalized as an expression.
    ExpressionInput {
        kind: SqlAggregateKind,
        input_expr: Expr,
        filter_expr: Option<Expr>,
        distinct: bool,
    },
}
/// Test-only global aggregate command bound to the typed entity `E`.
#[cfg(test)]
#[derive(Debug)]
pub(crate) struct SqlGlobalAggregateCommand<E: EntityKind> {
    /// Typed base query (WHERE / ORDER BY / LIMIT / OFFSET).
    query: Query<E>,
    /// Prepared strategy per unique aggregate terminal.
    terminals: Vec<PreparedSqlScalarAggregateStrategy>,
    /// Post-aggregate projection.
    projection: ProjectionSpec,
    /// Optional lowered HAVING expression.
    having: Option<Expr>,
    /// Terminal index per output column; empty when any column is a wrapped expression.
    output_remap: Vec<usize>,
}
#[cfg(test)]
impl<E: EntityKind> SqlGlobalAggregateCommand<E> {
    /// Borrows the typed base query.
    #[must_use]
    pub(crate) const fn query(&self) -> &Query<E> {
        &self.query
    }

    /// Borrows every prepared terminal strategy, in terminal order.
    #[must_use]
    pub(crate) fn terminals(&self) -> &[PreparedSqlScalarAggregateStrategy] {
        self.terminals.as_slice()
    }

    /// Borrows the post-aggregate projection.
    #[must_use]
    pub(crate) const fn projection(&self) -> &ProjectionSpec {
        &self.projection
    }

    /// Borrows the lowered HAVING expression, if any.
    #[must_use]
    pub(crate) const fn having(&self) -> Option<&Expr> {
        self.having.as_ref()
    }

    /// Borrows the output-column → terminal-index remap.
    #[must_use]
    pub(crate) fn output_remap(&self) -> &[usize] {
        self.output_remap.as_slice()
    }

    /// Borrows the first terminal strategy.
    ///
    /// # Panics
    /// Panics if the command holds no terminals.
    #[must_use]
    pub(crate) fn terminal(&self) -> &PreparedSqlScalarAggregateStrategy {
        self.terminals()
            .first()
            .expect("global aggregate command must contain at least one terminal")
    }
}
/// Global aggregate command bound to an `EntityModel` rather than a typed
/// entity parameter; built by `into_structural`.
#[derive(Clone, Debug)]
pub(crate) struct SqlGlobalAggregateCommandCore {
    /// Structural base query (WHERE / ORDER BY / LIMIT / OFFSET applied).
    query: StructuralQuery,
    /// Prepared strategy per unique aggregate terminal.
    strategies: Vec<PreparedSqlScalarAggregateStrategy>,
    /// Post-aggregate projection.
    projection: ProjectionSpec,
    /// Optional lowered HAVING expression.
    having: Option<Expr>,
}
impl SqlGlobalAggregateCommandCore {
    /// Borrows the structural base query.
    #[must_use]
    pub(in crate::db) const fn query(&self) -> &StructuralQuery {
        &self.query
    }

    /// Borrows the prepared terminal strategies, in terminal order.
    #[must_use]
    pub(in crate::db) const fn strategies(&self) -> &[PreparedSqlScalarAggregateStrategy] {
        self.strategies.as_slice()
    }

    /// Borrows the post-aggregate projection evaluated over the strategies' results.
    #[must_use]
    pub(in crate::db) const fn projection(&self) -> &ProjectionSpec {
        &self.projection
    }

    /// Borrows the lowered HAVING expression, when the statement had one.
    #[must_use]
    pub(in crate::db) const fn having(&self) -> Option<&Expr> {
        self.having.as_ref()
    }
}
impl SqlStatement {
    /// Reports whether this statement is a `SELECT` that lowers cleanly onto the
    /// global aggregate lane; every non-`SELECT` statement reports `false`.
    #[must_use]
    pub(in crate::db) fn is_global_aggregate_lane_shape(&self) -> bool {
        if let Self::Select(select) = self {
            select.is_global_aggregate_lane_shape()
        } else {
            false
        }
    }
}
impl SqlSelectStatement {
    /// Reports whether this `SELECT` lowers onto the global aggregate lane.
    ///
    /// The cheap structural rejections (DISTINCT, GROUP BY, no aggregate and no
    /// HAVING) run first; only then is a full lowering attempted on a clone.
    #[must_use]
    fn is_global_aggregate_lane_shape(&self) -> bool {
        let structurally_eligible = !self.distinct && self.group_by.is_empty();
        structurally_eligible
            && self.might_require_global_aggregate_lane()
            && LoweredSqlGlobalAggregateCommand::from_select_statement(self.clone()).is_ok()
    }

    /// A HAVING clause, or any projection item containing an aggregate, makes
    /// the global aggregate lane a candidate. `SELECT *` alone never does.
    fn might_require_global_aggregate_lane(&self) -> bool {
        !self.having.is_empty()
            || matches!(
                &self.projection,
                SqlProjection::Items(items) if items.iter().any(SqlSelectItem::contains_aggregate)
            )
    }
}
/// Binds the inner command of an `EXPLAIN` over a global aggregate.
///
/// Returns `Ok(None)` when `lowered` is any other command kind; otherwise
/// returns the explain mode, the verbose flag, and the structurally bound command.
///
/// # Errors
/// Propagates terminal preparation errors from binding the inner command.
pub(crate) fn bind_lowered_sql_explain_global_aggregate_structural(
    lowered: &LoweredSqlCommand,
    model: &'static EntityModel,
    consistency: MissingRowPolicy,
) -> Result<Option<(SqlExplainMode, bool, SqlGlobalAggregateCommandCore)>, SqlLoweringError> {
    if let LoweredSqlCommandInner::ExplainGlobalAggregate {
        mode,
        verbose,
        command,
    } = &lowered.0
    {
        let core =
            bind_lowered_sql_global_aggregate_command_structural(model, command.clone(), consistency)?;
        return Ok(Some((*mode, *verbose, core)));
    }
    Ok(None)
}
/// Test-only: parses `sql`, prepares it against entity `E`, and compiles it
/// into a typed global aggregate command.
#[cfg(test)]
pub(crate) fn compile_sql_global_aggregate_command<E: EntityKind>(
    sql: &str,
    consistency: MissingRowPolicy,
) -> Result<SqlGlobalAggregateCommand<E>, SqlLoweringError> {
    let prepared = crate::db::sql::lowering::prepare_sql_statement(
        crate::db::sql::parser::parse_sql(sql)?,
        E::MODEL.name(),
    )?;
    compile_sql_global_aggregate_command_from_prepared::<E>(prepared, consistency)
}
/// Test-only: compiles a prepared `SELECT` into a typed global aggregate command.
/// Non-`SELECT` statements are rejected as an unsupported projection.
#[cfg(test)]
pub(crate) fn compile_sql_global_aggregate_command_from_prepared<E: EntityKind>(
    prepared: PreparedSqlStatement,
    consistency: MissingRowPolicy,
) -> Result<SqlGlobalAggregateCommand<E>, SqlLoweringError> {
    match prepared.statement {
        SqlStatement::Select(select) => {
            let lowered = lower_global_aggregate_select_shape(select)?;
            bind_lowered_sql_global_aggregate_command::<E>(lowered, consistency)
        }
        _ => Err(SqlLoweringError::unsupported_select_projection()),
    }
}
/// Compiles a prepared `SELECT` into the structural global aggregate command
/// bound to `model`. Non-`SELECT` statements are rejected as an unsupported
/// projection.
///
/// # Errors
/// Propagates lowering and terminal preparation errors.
pub(in crate::db) fn compile_sql_global_aggregate_command_core_from_prepared(
    prepared: PreparedSqlStatement,
    model: &'static EntityModel,
    consistency: MissingRowPolicy,
) -> Result<SqlGlobalAggregateCommandCore, SqlLoweringError> {
    match prepared.statement {
        SqlStatement::Select(select) => {
            let lowered = lower_global_aggregate_select_shape(select)?;
            bind_lowered_sql_global_aggregate_command_structural(model, lowered, consistency)
        }
        _ => Err(SqlLoweringError::unsupported_select_projection()),
    }
}
/// Lowers a `SELECT` into the entity-agnostic global aggregate command shape.
///
/// # Errors
/// Propagates any lowering error for DISTINCT, GROUP BY, projection, HAVING,
/// WHERE, or ORDER BY.
pub(in crate::db::sql::lowering) fn lower_global_aggregate_select_shape(
    statement: SqlSelectStatement,
) -> Result<LoweredSqlGlobalAggregateCommand, SqlLoweringError> {
    LoweredSqlGlobalAggregateCommand::from_select_statement(statement)
}

/// Test-only: binds a lowered command to entity `E`, producing a typed command.
#[cfg(test)]
pub(in crate::db::sql::lowering) fn bind_lowered_sql_global_aggregate_command<E: EntityKind>(
    lowered: LoweredSqlGlobalAggregateCommand,
    consistency: MissingRowPolicy,
) -> Result<SqlGlobalAggregateCommand<E>, SqlLoweringError> {
    LoweredSqlGlobalAggregateCommand::into_typed::<E>(lowered, consistency)
}

/// Binds a lowered command to `model`, producing the structural command core.
fn bind_lowered_sql_global_aggregate_command_structural(
    model: &'static EntityModel,
    lowered: LoweredSqlGlobalAggregateCommand,
    consistency: MissingRowPolicy,
) -> Result<SqlGlobalAggregateCommandCore, SqlLoweringError> {
    LoweredSqlGlobalAggregateCommand::into_structural(lowered, model, consistency)
}
fn lower_global_aggregate_terminal(
aggregate_expr: &AggregateExpr,
) -> Result<SqlGlobalAggregateTerminal, SqlLoweringError> {
let distinct = aggregate_expr.is_distinct();
let filter_expr = aggregate_expr.filter_expr().cloned();
match (
aggregate_expr.kind(),
aggregate_expr.target_field().map(str::to_string),
aggregate_expr.input_expr().cloned(),
) {
(AggregateKind::Count, None, None) => {
Ok(SqlGlobalAggregateTerminal::CountRows { filter_expr })
}
(AggregateKind::Count, Some(field), _) => Ok(SqlGlobalAggregateTerminal::CountField {
field,
filter_expr,
distinct,
}),
(AggregateKind::Count, None, Some(input_expr)) => {
Ok(SqlGlobalAggregateTerminal::CountExpr {
input_expr,
filter_expr,
distinct,
})
}
(AggregateKind::Sum, Some(field), _) => Ok(SqlGlobalAggregateTerminal::SumField {
field,
filter_expr,
distinct,
}),
(AggregateKind::Sum, None, Some(input_expr)) => Ok(SqlGlobalAggregateTerminal::SumExpr {
input_expr,
filter_expr,
distinct,
}),
(AggregateKind::Avg, Some(field), _) => Ok(SqlGlobalAggregateTerminal::AvgField {
field,
filter_expr,
distinct,
}),
(AggregateKind::Avg, None, Some(input_expr)) => Ok(SqlGlobalAggregateTerminal::AvgExpr {
input_expr,
filter_expr,
distinct,
}),
(AggregateKind::Min, Some(field), _) => {
Ok(SqlGlobalAggregateTerminal::MinField { field, filter_expr })
}
(AggregateKind::Min, None, Some(input_expr)) => Ok(SqlGlobalAggregateTerminal::MinExpr {
input_expr,
filter_expr,
}),
(AggregateKind::Max, Some(field), _) => {
Ok(SqlGlobalAggregateTerminal::MaxField { field, filter_expr })
}
(AggregateKind::Max, None, Some(input_expr)) => Ok(SqlGlobalAggregateTerminal::MaxExpr {
input_expr,
filter_expr,
}),
(AggregateKind::Exists | AggregateKind::First | AggregateKind::Last, _, _)
| (_, None, None) => Err(SqlLoweringError::unsupported_global_aggregate_projection()),
}
}
/// Lowers `aggregate_expr` to a terminal and returns its index in `terminals`,
/// appending it first if no equal terminal is already present.
///
/// # Errors
/// Propagates `lower_global_aggregate_terminal` errors; `terminals` is left
/// untouched in that case.
fn resolve_or_insert_global_aggregate_terminal_index_from_expr(
    terminals: &mut Vec<SqlGlobalAggregateTerminal>,
    aggregate_expr: &AggregateExpr,
) -> Result<usize, SqlLoweringError> {
    let candidate = lower_global_aggregate_terminal(aggregate_expr)?;
    if let Some(existing) = terminals.iter().position(|current| current == &candidate) {
        return Ok(existing);
    }
    terminals.push(candidate);
    Ok(terminals.len() - 1)
}
/// Finds the single grouped projection aggregate that lowers to `target`.
///
/// Calls that fail to lower are skipped rather than reported.
///
/// # Errors
/// `unsupported_select_having` when no call matches, or when more than one does
/// (the reference would be ambiguous).
pub(in crate::db::sql::lowering) fn resolve_having_aggregate_expr_index(
    target: &AggregateExpr,
    grouped_projection_aggregates: &[SqlAggregateCall],
) -> Result<usize, SqlLoweringError> {
    let mut resolved = None;
    for (index, call) in grouped_projection_aggregates.iter().enumerate() {
        let Ok(lowered) = lower_aggregate_call(call.clone()) else {
            continue;
        };
        if lowered != *target {
            continue;
        }
        if resolved.replace(index).is_some() {
            return Err(SqlLoweringError::unsupported_select_having());
        }
    }
    resolved.ok_or_else(SqlLoweringError::unsupported_select_having)
}
/// Result of lowering a global aggregate projection list.
struct LoweredSqlGlobalAggregateTerminals {
    /// Unique aggregate terminals referenced by the projection.
    terminals: Vec<SqlGlobalAggregateTerminal>,
    /// Post-aggregate projection built from the lowered items.
    projection: ProjectionSpec,
    /// Test-only: terminal index per output column; empty when any column wraps
    /// its aggregate(s) in a larger expression.
    #[cfg(test)]
    output_remap: Vec<usize>,
}
impl LoweredSqlGlobalAggregateTerminals {
    /// Lowers every projection item of a global aggregate `SELECT`.
    ///
    /// Each item must contain an aggregate and must not reference direct
    /// fields. Aliases are matched by item index, and a missing alias leaves
    /// the output column unaliased.
    ///
    /// # Errors
    /// `unsupported_global_aggregate_projection` for `SELECT *`, an empty item
    /// list, or a non-aggregate item; otherwise propagates item lowering errors.
    fn from_projection(
        projection: SqlProjection,
        projection_aliases: &[Option<String>],
    ) -> Result<Self, SqlLoweringError> {
        let items = match projection {
            SqlProjection::Items(items) if !items.is_empty() => items,
            SqlProjection::Items(_) | SqlProjection::All => {
                return Err(SqlLoweringError::unsupported_global_aggregate_projection());
            }
        };

        let item_count = items.len();
        let mut terminals = Vec::<SqlGlobalAggregateTerminal>::with_capacity(item_count);
        let mut fields = Vec::<ProjectionField>::with_capacity(item_count);
        // Test-only positional remap; it collapses to `None` as soon as any item
        // is not a bare aggregate, and becomes an empty vector in that case.
        #[cfg(test)]
        let mut output_remap = Some(Vec::<usize>::with_capacity(item_count));

        for (index, item) in items.into_iter().enumerate() {
            let expr = lower_select_item_expr(&item, SqlExprPhase::PostAggregate)?;
            let analysis = analyze_lowered_expr(&expr, None);
            if !analysis.contains_aggregate() || analysis.references_direct_fields() {
                return Err(SqlLoweringError::unsupported_global_aggregate_projection());
            }

            let direct_terminal_index =
                collect_unique_global_aggregate_terminals_from_expr(&expr, &mut terminals)?;
            #[cfg(test)]
            {
                output_remap = output_remap.and_then(|mut remap| {
                    remap.push(direct_terminal_index?);
                    Some(remap)
                });
            }
            #[cfg(not(test))]
            let _ = direct_terminal_index;

            let alias = projection_aliases
                .get(index)
                .and_then(Option::as_deref)
                .map(Alias::new);
            fields.push(ProjectionField::Scalar { expr, alias });
        }

        Ok(Self {
            terminals,
            projection: lower_global_aggregate_projection(fields),
            #[cfg(test)]
            output_remap: output_remap.unwrap_or_default(),
        })
    }
}
/// Reports whether `expr` reads any direct (non-aggregated) field, analyzed
/// without an entity model.
pub(in crate::db::sql::lowering) fn expr_references_global_direct_fields(expr: &Expr) -> bool {
    let analysis = analyze_lowered_expr(expr, None);
    analysis.references_direct_fields()
}
fn collect_unique_global_aggregate_terminals_from_expr(
expr: &Expr,
terminals: &mut Vec<SqlGlobalAggregateTerminal>,
) -> Result<Option<usize>, SqlLoweringError> {
let mut direct_terminal_index = None;
expr.try_for_each_tree_aggregate(&mut |aggregate_expr| {
let terminal = lower_global_aggregate_terminal(aggregate_expr)?;
let unique_index = terminals
.iter()
.position(|current| current == &terminal)
.unwrap_or_else(|| {
let index = terminals.len();
terminals.push(terminal);
index
});
if direct_terminal_index.is_none() && matches!(expr, Expr::Aggregate(_)) {
direct_terminal_index = Some(unique_index);
}
Ok::<(), SqlLoweringError>(())
})?;
Ok(direct_terminal_index)
}
/// Classifies a parsed aggregate call and lowers its filter expression.
///
/// - No input: `CountRows` when the kind supports star input and the call is
///   not DISTINCT.
/// - Bare field under COUNT: `CountField`.
/// - Bare field under a kind with shared field-target lowering: `FieldTarget`.
/// - Anything else: `ExpressionInput` with a canonicalized, pre-aggregate
///   lowered input.
///
/// # Errors
/// `unsupported_select_projection` for DISTINCT combined with a filter, or for
/// star input the kind cannot take; otherwise propagates expression lowering errors.
fn lower_sql_aggregate_shape(
    call: SqlAggregateCall,
) -> Result<LoweredSqlAggregateShape, SqlLoweringError> {
    let SqlAggregateCall {
        kind,
        input,
        filter_expr,
        distinct,
    } = call;
    let filter_expr = filter_expr
        .map(|raw_filter| lower_sql_where_bool_expr(raw_filter.as_ref()))
        .transpose()?;
    if distinct && filter_expr.is_some() {
        return Err(SqlLoweringError::unsupported_select_projection());
    }

    let Some(input) = input.map(|boxed| *boxed) else {
        return if kind.supports_star_input() && !distinct {
            Ok(LoweredSqlAggregateShape::CountRows { filter_expr })
        } else {
            Err(SqlLoweringError::unsupported_select_projection())
        };
    };

    match input {
        SqlExpr::Field(field) if matches!(kind, SqlAggregateKind::Count) => {
            Ok(LoweredSqlAggregateShape::CountField {
                field,
                filter_expr,
                distinct,
            })
        }
        SqlExpr::Field(field) if kind.lowers_shared_field_target_shape() => {
            Ok(LoweredSqlAggregateShape::FieldTarget {
                kind,
                field,
                filter_expr,
                distinct,
            })
        }
        other => {
            let lowered_input = lower_sql_expr(&other, SqlExprPhase::PreAggregate)?;
            Ok(LoweredSqlAggregateShape::ExpressionInput {
                kind,
                input_expr: canonicalize_aggregate_input_expr(kind.aggregate_kind(), lowered_input),
                filter_expr,
                distinct,
            })
        }
    }
}
/// Collects the unique aggregate calls of a grouped `SELECT` projection,
/// validating each item against the GROUP BY field list.
///
/// # Errors
/// `unsupported_select_group_by` for an empty GROUP BY list (checked first),
/// `grouped_projection_requires_explicit_list` for `SELECT *`, and any
/// collector validation error.
pub(in crate::db::sql::lowering) fn grouped_projection_aggregate_calls(
    projection: &SqlProjection,
    group_by_fields: &[String],
    model: &'static EntityModel,
) -> Result<Vec<SqlAggregateCall>, SqlLoweringError> {
    if group_by_fields.is_empty() {
        return Err(SqlLoweringError::unsupported_select_group_by());
    }
    match projection {
        SqlProjection::Items(items) => {
            GroupedProjectionAggregateCollector::new(group_by_fields, model)?
                .collect_from_items(items)
        }
        SqlProjection::All => Err(SqlLoweringError::grouped_projection_requires_explicit_list()),
    }
}
/// Appends every aggregate call found in `expr` to `aggregate_calls`, skipping
/// calls already present.
pub(in crate::db::sql::lowering) fn extend_unique_sql_expr_aggregate_calls(
    aggregate_calls: &mut Vec<SqlAggregateCall>,
    expr: &SqlExpr,
) {
    expr.for_each_tree_aggregate(&mut |found| {
        push_unique_sql_aggregate_call(aggregate_calls, found.clone());
    });
}
/// Appends the aggregate calls of one select item to `aggregate_calls`,
/// skipping calls already present.
pub(in crate::db::sql::lowering) fn extend_unique_sql_select_item_aggregate_calls(
    aggregate_calls: &mut Vec<SqlAggregateCall>,
    item: &SqlSelectItem,
) {
    match item {
        SqlSelectItem::Aggregate(call) => {
            push_unique_sql_aggregate_call(aggregate_calls, call.clone());
        }
        SqlSelectItem::Expr(expr) => extend_unique_sql_expr_aggregate_calls(aggregate_calls, expr),
        // A plain field item carries no aggregate calls.
        SqlSelectItem::Field(_) => {}
    }
}
/// Walks a grouped `SELECT` projection, validating items against the GROUP BY
/// fields and collecting unique aggregate calls.
struct GroupedProjectionAggregateCollector<'a> {
    /// GROUP BY field names that non-aggregate references may use.
    grouped_field_names: Vec<&'a str>,
    /// Entity model used for unknown-field analysis.
    model: &'static EntityModel,
    /// Unique aggregate calls collected so far, in first-seen order.
    aggregate_calls: Vec<SqlAggregateCall>,
    /// Set once any item containing an aggregate has been seen.
    seen_aggregate: bool,
}
impl<'a> GroupedProjectionAggregateCollector<'a> {
fn new(
group_by_fields: &'a [String],
model: &'static EntityModel,
) -> Result<Self, SqlLoweringError> {
if group_by_fields.is_empty() {
return Err(SqlLoweringError::unsupported_select_group_by());
}
Ok(Self {
grouped_field_names: group_by_fields.iter().map(String::as_str).collect(),
model,
aggregate_calls: Vec::new(),
seen_aggregate: false,
})
}
fn collect_from_items(
mut self,
items: &[SqlSelectItem],
) -> Result<Vec<SqlAggregateCall>, SqlLoweringError> {
for (index, item) in items.iter().enumerate() {
self.collect_item(index, item)?;
}
if self.aggregate_calls.is_empty() {
return Err(SqlLoweringError::grouped_projection_requires_aggregate());
}
Ok(self.aggregate_calls)
}
fn collect_item(&mut self, index: usize, item: &SqlSelectItem) -> Result<(), SqlLoweringError> {
let expr = crate::db::sql::lowering::select::lower_select_item_expr(
item,
SqlExprPhase::PostAggregate,
)?;
let analysis = analyze_lowered_expr(&expr, Some(self.model));
let contains_aggregate = analysis.contains_aggregate();
if self.seen_aggregate && !contains_aggregate {
return Err(SqlLoweringError::grouped_projection_scalar_after_aggregate(
index,
));
}
if let Some(field) = analysis.first_unknown_field() {
return Err(SqlLoweringError::unknown_field(field));
}
if !expr.references_only_fields(self.grouped_field_names.as_slice()) {
return Err(SqlLoweringError::grouped_projection_references_non_group_field(index));
}
if contains_aggregate {
self.seen_aggregate = true;
extend_unique_sql_select_item_aggregate_calls(&mut self.aggregate_calls, item);
}
Ok(())
}
}
/// Appends `aggregate` unless an equal call is already in `aggregate_calls`.
fn push_unique_sql_aggregate_call(
    aggregate_calls: &mut Vec<SqlAggregateCall>,
    aggregate: SqlAggregateCall,
) {
    if !aggregate_calls.contains(&aggregate) {
        aggregate_calls.push(aggregate);
    }
}
/// Lowers a parsed aggregate call into a builder `AggregateExpr`, attaching its
/// filter expression when present.
///
/// # Errors
/// Propagates classification errors and field-target lowering errors.
pub(in crate::db::sql::lowering) fn lower_aggregate_call(
    call: SqlAggregateCall,
) -> Result<crate::db::query::builder::AggregateExpr, SqlLoweringError> {
    let (aggregate, filter_expr) = match lower_sql_aggregate_shape(call)? {
        LoweredSqlAggregateShape::CountRows { filter_expr } => (count(), filter_expr),
        LoweredSqlAggregateShape::CountField {
            field,
            filter_expr,
            distinct,
        } => {
            let counted = count_by(field);
            let counted = if distinct { counted.distinct() } else { counted };
            (counted, filter_expr)
        }
        // Field-target lowering attaches the filter itself.
        LoweredSqlAggregateShape::FieldTarget {
            kind,
            field,
            filter_expr,
            distinct,
        } => return kind.lower_field_target_aggregate(field, filter_expr, distinct),
        LoweredSqlAggregateShape::ExpressionInput {
            kind,
            input_expr,
            filter_expr,
            distinct,
        } => (
            kind.lower_expression_owned_aggregate(input_expr, distinct),
            filter_expr,
        ),
    };
    Ok(apply_aggregate_filter_expr(aggregate, filter_expr))
}
/// Lowers an aggregate call for a grouped query and validates its input and
/// filter expressions against `model`.
///
/// # Errors
/// Propagates lowering errors, then scalar sub-expression validation errors.
pub(in crate::db::sql::lowering) fn lower_grouped_aggregate_call(
    model: &'static EntityModel,
    call: SqlAggregateCall,
) -> Result<crate::db::query::builder::AggregateExpr, SqlLoweringError> {
    lower_aggregate_call(call).and_then(|lowered| {
        validate_grouped_aggregate_scalar_subexpressions(model, &lowered).map(|()| lowered)
    })
}
/// Attaches `filter_expr` to `aggregate` when present; otherwise returns
/// `aggregate` unchanged.
fn apply_aggregate_filter_expr(
    aggregate: AggregateExpr,
    filter_expr: Option<Expr>,
) -> AggregateExpr {
    if let Some(filter) = filter_expr {
        aggregate.with_filter_expr(filter)
    } else {
        aggregate
    }
}
/// Validates an aggregate's input expression, then its filter expression,
/// against `model`.
///
/// # Errors
/// `unknown_field` for unknown references; `unsupported_aggregate_input_expressions`
/// or `unsupported_where_expression` when the respective expression does not
/// compile as a scalar projection.
fn validate_grouped_aggregate_scalar_subexpressions(
    model: &'static EntityModel,
    aggregate: &AggregateExpr,
) -> Result<(), SqlLoweringError> {
    aggregate.input_expr().map_or(Ok(()), |input| {
        validate_grouped_model_bound_scalar_expr(
            model,
            input,
            SqlLoweringError::unsupported_aggregate_input_expressions,
        )
    })?;
    aggregate.filter_expr().map_or(Ok(()), |filter| {
        validate_grouped_model_bound_scalar_expr(
            model,
            filter,
            SqlLoweringError::unsupported_where_expression,
        )
    })
}
/// Checks that `expr` only names fields known to `model` and compiles as a
/// scalar projection.
///
/// # Errors
/// `unknown_field` for the first unknown field; otherwise the error built by
/// `unsupported` when scalar compilation fails.
fn validate_grouped_model_bound_scalar_expr(
    model: &'static EntityModel,
    expr: &Expr,
    unsupported: impl FnOnce() -> SqlLoweringError,
) -> Result<(), SqlLoweringError> {
    let analysis = analyze_lowered_expr(expr, Some(model));
    if let Some(unknown) = analysis.first_unknown_field() {
        Err(SqlLoweringError::unknown_field(unknown))
    } else if compile_scalar_projection_expr(model, expr).is_some() {
        Ok(())
    } else {
        Err(unsupported())
    }
}
impl SqlAggregateKind {
    /// Builds the field-target aggregate for SUM, AVG, MIN, or MAX and attaches
    /// the optional filter.
    ///
    /// DISTINCT is applied for SUM and AVG only; MIN and MAX ignore it.
    ///
    /// # Errors
    /// `unsupported_select_projection` for COUNT. Field-target COUNT calls are
    /// classified as `CountField` before reaching this path.
    fn lower_field_target_aggregate(
        self,
        field: String,
        filter_expr: Option<Expr>,
        distinct: bool,
    ) -> Result<AggregateExpr, SqlLoweringError> {
        let aggregate = match self {
            Self::Sum if distinct => sum(field).distinct(),
            Self::Sum => sum(field),
            Self::Avg if distinct => avg(field).distinct(),
            Self::Avg => avg(field),
            Self::Min => min_by(field),
            Self::Max => max_by(field),
            Self::Count => return Err(SqlLoweringError::unsupported_select_projection()),
        };
        Ok(apply_aggregate_filter_expr(aggregate, filter_expr))
    }

    /// Builds an expression-input aggregate of this kind, marked DISTINCT on request.
    fn lower_expression_owned_aggregate(self, input_expr: Expr, distinct: bool) -> AggregateExpr {
        let base = AggregateExpr::from_expression_input(self.aggregate_kind(), input_expr);
        if distinct { base.distinct() } else { base }
    }
}