use crate::{
db::{
query::plan::{
AggregateKind, FieldSlot, expr::Expr, resolve_aggregate_target_field_slot_with_schema,
},
schema::SchemaInfo,
sql::lowering::{
SqlLoweringError,
aggregate::{
lowering::validate_analyzed_model_bound_scalar_expr,
semantics::{PreparedAggregateSemantics, PreparedAggregateTarget},
terminal::{LoweredAggregateInput, LoweredSqlGlobalAggregateTerminal},
},
},
},
model::entity::EntityModel,
};
/// Coarse execution shape of a prepared scalar (global) SQL aggregate.
///
/// Produced by `PreparedSqlScalarAggregateStrategy::plan_fragment` from the
/// prepared aggregate semantics.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum PreparedSqlScalarAggregatePlanFragment {
    /// COUNT whose target is the row set itself (presumably `COUNT(*)`).
    CountRows,
    /// COUNT over any non-row target (a field slot or an input expression).
    CountField,
    /// SUM or AVG; `kind` records which of the two it is.
    NumericField { kind: AggregateKind },
    /// MIN or MAX; `kind` records which of the two it is.
    ExtremalWinnerField { kind: AggregateKind },
}
/// Validated, lowered description of one global (scalar) SQL aggregate.
///
/// Built by `from_lowered_terminal_with_schema`, which validates the aggregate
/// input and the optional filter before storing them here.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) struct PreparedSqlScalarAggregateStrategy {
    // Aggregate kind, resolved target (rows / field slot / expression) and the
    // DISTINCT flag, as accepted by `PreparedAggregateSemantics`.
    semantics: PreparedAggregateSemantics,
    // Optional per-aggregate filter, already validated as a model-bound scalar
    // expression (presumably from a `FILTER (WHERE ...)` clause — TODO confirm).
    filter_expr: Option<Expr>,
}
impl PreparedSqlScalarAggregateStrategy {
    /// Assembles a strategy from already-validated semantics and an optional
    /// already-validated filter expression.
    const fn from_semantics(
        semantics: PreparedAggregateSemantics,
        filter_expr: Option<Expr>,
    ) -> Self {
        Self {
            semantics,
            filter_expr,
        }
    }

    /// Prepares a scalar aggregate strategy from one lowered global aggregate
    /// terminal.
    ///
    /// The terminal is split into its semantic key, aggregate input and
    /// optional filter. The input is then validated and resolved: rows, a
    /// schema field checked against its SQL aggregate capabilities and
    /// resolved to a slot, or a model-bound scalar expression. After that the
    /// filter is validated. Finally the semantics are built from the aggregate
    /// kind, the resolved target and the DISTINCT flag.
    ///
    /// # Errors
    ///
    /// Returns a [`SqlLoweringError`] when:
    /// - the field's SQL capabilities do not allow this aggregate kind;
    /// - the field slot cannot be resolved against `model`/`schema`;
    /// - the input or filter expression fails model-bound validation;
    /// - `PreparedAggregateSemantics::try_from_kind_target_and_distinct`
    ///   rejects the kind/target/distinct combination.
    pub(in crate::db::sql::lowering::aggregate) fn from_lowered_terminal_with_schema(
        model: &'static EntityModel,
        schema: &SchemaInfo,
        terminal: LoweredSqlGlobalAggregateTerminal,
    ) -> Result<Self, SqlLoweringError> {
        let (semantic_key, input, filter_expr) = terminal.into_parts();
        let (semantic_identity, semantic_filter_expr) = semantic_key.into_identity_and_filter();
        // The semantic key carries its own copy of the filter; in debug builds,
        // check it agrees with the analyzed filter retained on the terminal.
        debug_assert_eq!(
            semantic_filter_expr.as_ref(),
            filter_expr
                .as_ref()
                .map(crate::db::sql::lowering::AnalyzedLoweredExpr::expr),
            "global aggregate semantic key filter must match retained filter analysis",
        );
        let kind = semantic_identity.kind();
        let distinct_input = semantic_identity.distinct();
        let target = match input {
            LoweredAggregateInput::Rows => PreparedAggregateTarget::Rows,
            LoweredAggregateInput::Field(field) => {
                // Capability gate first, so an unsupported kind is reported
                // before slot resolution is attempted.
                validate_field_target_sql_aggregate_capabilities(schema, field.as_str(), kind)?;
                let target_slot =
                    resolve_aggregate_target_field_slot_with_schema(model, schema, field.as_str())
                        .map_err(SqlLoweringError::from)?;
                PreparedAggregateTarget::Field(target_slot)
            }
            LoweredAggregateInput::Expr(input_expr) => {
                validate_analyzed_model_bound_scalar_expr(
                    model,
                    schema,
                    &input_expr,
                    SqlLoweringError::unsupported_aggregate_input_expressions,
                )?;
                PreparedAggregateTarget::Expr(input_expr.into_expr())
            }
        };
        let filter_expr = match filter_expr {
            Some(filter_expr) => {
                Self::validate_global_aggregate_filter_expr(model, schema, &filter_expr)?;
                Some(filter_expr.into_expr())
            }
            None => None,
        };
        let semantics = PreparedAggregateSemantics::try_from_kind_target_and_distinct(
            kind,
            target,
            distinct_input,
        )?;
        Ok(Self::from_semantics(semantics, filter_expr))
    }

    /// Validates a global aggregate filter as a model-bound scalar expression.
    ///
    /// # Errors
    ///
    /// Propagates the error from `validate_analyzed_model_bound_scalar_expr`.
    /// `SqlLoweringError::unsupported_where_expression` is passed to it as the
    /// error constructor. One exception: `SqlLoweringError::UnknownField` is
    /// replaced by the error from `QueryError::invariant()`.
    fn validate_global_aggregate_filter_expr(
        model: &'static EntityModel,
        schema: &SchemaInfo,
        filter_expr: &crate::db::sql::lowering::AnalyzedLoweredExpr,
    ) -> Result<(), SqlLoweringError> {
        match validate_analyzed_model_bound_scalar_expr(
            model,
            schema,
            filter_expr,
            SqlLoweringError::unsupported_where_expression,
        ) {
            // NOTE(review): an unknown field in the retained (already analyzed)
            // filter is presumably an internal invariant breach rather than a
            // user error, hence the remap — confirm against the analysis pass.
            Err(SqlLoweringError::UnknownField { .. }) => {
                Err(crate::db::QueryError::invariant().into())
            }
            result => result,
        }
    }

    /// Resolved field slot of the aggregate target, if the target is a field.
    #[cfg(any(test, feature = "sql-explain"))]
    #[must_use]
    pub(crate) const fn target_slot(&self) -> Option<&FieldSlot> {
        self.semantics.target_slot()
    }

    /// Input expression of the aggregate target, if the target is an expression.
    #[cfg(test)]
    #[must_use]
    pub(in crate::db) const fn input_expr(&self) -> Option<&Expr> {
        self.semantics.input_expr()
    }

    /// Validated filter expression attached to this aggregate, if any.
    #[must_use]
    pub(in crate::db) const fn filter_expr(&self) -> Option<&Expr> {
        self.filter_expr.as_ref()
    }

    /// Whether the aggregate input is DISTINCT.
    #[cfg(test)]
    #[must_use]
    pub(crate) const fn is_distinct(&self) -> bool {
        self.semantics.distinct_input()
    }

    /// Coarse execution shape of this aggregate.
    #[must_use]
    pub(crate) const fn plan_fragment(&self) -> PreparedSqlScalarAggregatePlanFragment {
        self.prepared_plan_fragment()
    }

    /// Aggregate kind of the prepared semantics.
    #[cfg(any(test, feature = "sql-explain"))]
    #[must_use]
    pub(crate) const fn aggregate_kind(&self) -> AggregateKind {
        self.semantics.aggregate_kind()
    }

    /// Maps the prepared semantics to a plan fragment.
    ///
    /// The mapping is:
    /// - COUNT over rows → `CountRows`;
    /// - any other COUNT → `CountField`;
    /// - SUM/AVG → `NumericField`;
    /// - MIN/MAX → `ExtremalWinnerField`.
    const fn prepared_plan_fragment(&self) -> PreparedSqlScalarAggregatePlanFragment {
        match &self.semantics {
            PreparedAggregateSemantics::Count {
                target: PreparedAggregateTarget::Rows,
                ..
            } => PreparedSqlScalarAggregatePlanFragment::CountRows,
            PreparedAggregateSemantics::Count { .. } => {
                PreparedSqlScalarAggregatePlanFragment::CountField
            }
            PreparedAggregateSemantics::Sum { .. } => {
                PreparedSqlScalarAggregatePlanFragment::NumericField {
                    kind: AggregateKind::Sum,
                }
            }
            PreparedAggregateSemantics::Avg { .. } => {
                PreparedSqlScalarAggregatePlanFragment::NumericField {
                    kind: AggregateKind::Avg,
                }
            }
            PreparedAggregateSemantics::Min { .. } => {
                PreparedSqlScalarAggregatePlanFragment::ExtremalWinnerField {
                    kind: AggregateKind::Min,
                }
            }
            PreparedAggregateSemantics::Max { .. } => {
                PreparedSqlScalarAggregatePlanFragment::ExtremalWinnerField {
                    kind: AggregateKind::Max,
                }
            }
        }
    }

    /// Consumes the strategy and returns the structural terminal inputs.
    ///
    /// The tuple is, in order:
    /// `(plan_fragment, target_slot, input_expr, filter_expr, distinct_input)`.
    pub(in crate::db) fn into_structural_terminal_inputs(
        self,
    ) -> (
        PreparedSqlScalarAggregatePlanFragment,
        Option<FieldSlot>,
        Option<Expr>,
        Option<Expr>,
        bool,
    ) {
        // Compute the fragment while `self` is still intact; the semantics are
        // moved out below.
        let descriptor = self.plan_fragment();
        let Self {
            semantics,
            filter_expr,
        } = self;
        let distinct_input = semantics.distinct_input();
        let (target_slot, input_expr) = semantics.into_terminal_inputs();
        (
            descriptor,
            target_slot,
            input_expr,
            filter_expr,
            distinct_input,
        )
    }

    /// Field name of the target slot (via `FieldSlot::field`), if the target
    /// is a field.
    #[cfg(feature = "sql-explain")]
    #[must_use]
    pub(crate) fn projected_field(&self) -> Option<&str> {
        self.target_slot().map(FieldSlot::field)
    }
}
/// Checks that a field's recorded SQL aggregate-input capabilities permit the
/// aggregate `kind` applied to it.
///
/// Fields with no capability record in `schema` are accepted unchanged.
/// EXISTS, FIRST and LAST are never accepted as SQL global aggregate inputs
/// on a capability-tracked field.
///
/// # Errors
///
/// Returns `SqlLoweringError::unsupported_global_aggregate_projection()` when
/// the capability check fails.
fn validate_field_target_sql_aggregate_capabilities(
    schema: &SchemaInfo,
    field_name: &str,
    kind: AggregateKind,
) -> Result<(), SqlLoweringError> {
    schema
        .sql_capabilities(field_name)
        .map_or(Ok(()), |capabilities| {
            let input_caps = capabilities.aggregate_input();
            let kind_allowed = match kind {
                AggregateKind::Count => input_caps.count(),
                AggregateKind::Sum | AggregateKind::Avg => input_caps.numeric(),
                AggregateKind::Min | AggregateKind::Max => input_caps.extrema(),
                AggregateKind::Exists | AggregateKind::First | AggregateKind::Last => false,
            };
            if kind_allowed {
                Ok(())
            } else {
                Err(SqlLoweringError::unsupported_global_aggregate_projection())
            }
        })
}