use crate::{
db::{
executor::{StructuralAggregateTerminal, StructuralAggregateTerminalKind},
query::plan::{AggregateKind, FieldSlot, expr::Expr, resolve_aggregate_target_field_slot},
sql::lowering::{
SqlLoweringError,
aggregate::{
lowering::validate_model_bound_scalar_expr,
terminal::{AggregateInput, SqlGlobalAggregateTerminal},
},
},
},
model::entity::EntityModel,
};
/// Runtime classification of one prepared SQL scalar aggregate.
///
/// `CountRows` is assigned by the row-count strategy constructor; the other
/// variants are produced from an `AggregateKind` by `from_aggregate_kind`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum PreparedSqlScalarAggregateRuntimeDescriptor {
/// Count over rows, used when the aggregate has no field or expression target.
CountRows,
/// Shape `from_aggregate_kind` yields for `AggregateKind::Count`.
CountField,
/// Shape `from_aggregate_kind` yields for `AggregateKind::Sum` and
/// `AggregateKind::Avg`; `kind` records which one.
NumericField { kind: AggregateKind },
/// Shape `from_aggregate_kind` yields for `AggregateKind::Min` and
/// `AggregateKind::Max`; `kind` records which one.
ExtremalWinnerField { kind: AggregateKind },
}
/// Alias of `PreparedSqlScalarAggregateRuntimeDescriptor`; both names refer to
/// the same type, so values are interchangeable without conversion.
pub(crate) type PreparedSqlScalarAggregateDescriptorShape =
PreparedSqlScalarAggregateRuntimeDescriptor;
/// What a prepared scalar aggregate reads from: whole rows, one resolved
/// field slot, or a scalar input expression.
#[derive(Clone, Debug, Eq, PartialEq)]
enum AggregateTarget {
/// No field or expression input; the aggregate works on rows.
Rows,
/// A field already resolved to its slot.
Field(FieldSlot),
/// A scalar expression used as the aggregate input.
Expr(Expr),
}
impl AggregateTarget {
    /// Borrows the resolved field slot when this target is a plain field.
    ///
    /// Row and expression targets yield `None`.
    const fn field_slot(&self) -> Option<&FieldSlot> {
        match self {
            Self::Rows | Self::Expr(_) => None,
            Self::Field(slot) => Some(slot),
        }
    }

    /// Borrows the input expression when this target is an expression.
    ///
    /// Row and field targets yield `None`. Compiled for tests only.
    #[cfg(test)]
    const fn input_expr(&self) -> Option<&Expr> {
        match self {
            Self::Rows | Self::Field(_) => None,
            Self::Expr(expr) => Some(expr),
        }
    }

    /// Consumes the target and splits it into `(field slot, input expression)`.
    ///
    /// At most one side of the pair is `Some`; a row target yields `(None, None)`.
    fn into_executor_parts(self) -> (Option<FieldSlot>, Option<Expr>) {
        match self {
            Self::Expr(expr) => (None, Some(expr)),
            Self::Field(slot) => (Some(slot), None),
            Self::Rows => (None, None),
        }
    }
}
impl PreparedSqlScalarAggregateRuntimeDescriptor {
    /// Picks the descriptor for an aggregate that has a field or expression input.
    ///
    /// `Count` becomes `CountField`, `Sum`/`Avg` become `NumericField`, and
    /// `Min`/`Max` become `ExtremalWinnerField`; the latter two keep `kind`.
    ///
    /// # Panics
    ///
    /// Panics through `unreachable!` when `kind` is `Exists`, `First`, or `Last`.
    fn from_aggregate_kind(kind: AggregateKind) -> Self {
        match kind {
            AggregateKind::Exists | AggregateKind::First | AggregateKind::Last => {
                unreachable!("unsupported SQL aggregate kind reached scalar aggregate descriptor")
            }
            AggregateKind::Min | AggregateKind::Max => Self::ExtremalWinnerField { kind },
            AggregateKind::Sum | AggregateKind::Avg => Self::NumericField { kind },
            AggregateKind::Count => Self::CountField,
        }
    }

    /// Returns the descriptor unchanged.
    ///
    /// The descriptor-shape alias and the runtime descriptor are the same
    /// type, so there is nothing to convert.
    #[must_use]
    pub(crate) const fn runtime_descriptor(self) -> Self {
        self
    }

    /// Returns the aggregate kind this descriptor stands for.
    ///
    /// Both count variants report `AggregateKind::Count`; the field variants
    /// report the kind they carry.
    #[must_use]
    pub(crate) const fn aggregate_kind(self) -> AggregateKind {
        match self {
            Self::NumericField { kind } | Self::ExtremalWinnerField { kind } => kind,
            Self::CountRows | Self::CountField => AggregateKind::Count,
        }
    }
}
/// One SQL scalar aggregate after lowering: its input target, optional filter
/// expression, distinct flag, and descriptor shape.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) struct PreparedSqlScalarAggregateStrategy {
// Input the aggregate reads: rows, a resolved field slot, or an expression.
target: AggregateTarget,
// Optional filter expression carried over from the lowered terminal.
filter_expr: Option<Expr>,
// Distinct flag from the lowered terminal; the row-count constructor sets it to `false`.
distinct_input: bool,
// Descriptor later translated into the executor terminal kind.
descriptor_shape: PreparedSqlScalarAggregateDescriptorShape,
}
impl PreparedSqlScalarAggregateStrategy {
const fn from_resolved_shape(
target: AggregateTarget,
filter_expr: Option<Expr>,
distinct_input: bool,
descriptor_shape: PreparedSqlScalarAggregateDescriptorShape,
) -> Self {
Self {
target,
filter_expr,
distinct_input,
descriptor_shape,
}
}
const fn for_rows(filter_expr: Option<Expr>) -> Self {
Self::from_resolved_shape(
AggregateTarget::Rows,
filter_expr,
false,
PreparedSqlScalarAggregateDescriptorShape::CountRows,
)
}
fn for_field_target(
kind: AggregateKind,
target_slot: FieldSlot,
filter_expr: Option<Expr>,
distinct_input: bool,
) -> Self {
Self::from_resolved_shape(
AggregateTarget::Field(target_slot),
filter_expr,
distinct_input,
PreparedSqlScalarAggregateDescriptorShape::from_aggregate_kind(kind),
)
}
fn for_expression_input(
kind: AggregateKind,
input_expr: Expr,
filter_expr: Option<Expr>,
distinct_input: bool,
) -> Self {
Self::from_resolved_shape(
AggregateTarget::Expr(input_expr),
filter_expr,
distinct_input,
PreparedSqlScalarAggregateDescriptorShape::from_aggregate_kind(kind),
)
}
pub(in crate::db::sql::lowering::aggregate) fn from_lowered_terminal(
model: &'static EntityModel,
terminal: SqlGlobalAggregateTerminal,
) -> Result<Self, SqlLoweringError> {
match terminal.input {
AggregateInput::Rows => Ok(Self::for_rows(terminal.filter_expr)),
AggregateInput::Field(field) => {
let target_slot = resolve_aggregate_target_field_slot(model, field.as_str())
.map_err(SqlLoweringError::from)?;
Ok(Self::for_field_target(
terminal.kind,
target_slot,
terminal.filter_expr,
terminal.distinct,
))
}
AggregateInput::Expr(input_expr) => {
validate_model_bound_scalar_expr(
model,
&input_expr,
SqlLoweringError::unsupported_aggregate_input_expressions,
)?;
Ok(Self::for_expression_input(
terminal.kind,
input_expr,
terminal.filter_expr,
terminal.distinct,
))
}
}
}
#[must_use]
pub(crate) const fn target_slot(&self) -> Option<&FieldSlot> {
self.target.field_slot()
}
#[cfg(test)]
#[must_use]
pub(crate) const fn input_expr(&self) -> Option<&Expr> {
self.target.input_expr()
}
#[must_use]
pub(crate) const fn filter_expr(&self) -> Option<&Expr> {
self.filter_expr.as_ref()
}
#[cfg(test)]
#[must_use]
pub(crate) const fn is_distinct(&self) -> bool {
self.distinct_input
}
#[cfg(test)]
#[must_use]
pub(crate) const fn descriptor_shape(&self) -> PreparedSqlScalarAggregateDescriptorShape {
self.descriptor_shape
}
#[cfg(test)]
#[must_use]
pub(crate) const fn runtime_descriptor(&self) -> PreparedSqlScalarAggregateRuntimeDescriptor {
self.descriptor_shape.runtime_descriptor()
}
#[must_use]
pub(crate) const fn aggregate_kind(&self) -> AggregateKind {
self.descriptor_shape.aggregate_kind()
}
pub(in crate::db) fn into_executor_terminal(
self,
) -> Result<StructuralAggregateTerminal, &'static str> {
let Self {
target,
filter_expr,
distinct_input,
descriptor_shape,
} = self;
let (target_slot, input_expr) = target.into_executor_parts();
let kind = match descriptor_shape.runtime_descriptor() {
PreparedSqlScalarAggregateRuntimeDescriptor::CountRows => {
StructuralAggregateTerminalKind::CountRows
}
PreparedSqlScalarAggregateRuntimeDescriptor::CountField => {
StructuralAggregateTerminalKind::CountValues
}
PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
kind: AggregateKind::Sum,
} => StructuralAggregateTerminalKind::Sum,
PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
kind: AggregateKind::Avg,
} => StructuralAggregateTerminalKind::Avg,
PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
kind: AggregateKind::Min,
} => StructuralAggregateTerminalKind::Min,
PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
kind: AggregateKind::Max,
} => StructuralAggregateTerminalKind::Max,
PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. }
| PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { .. } => {
return Err("prepared SQL scalar aggregate strategy drifted outside SQL support");
}
};
Ok(StructuralAggregateTerminal::new(
kind,
target_slot,
input_expr,
filter_expr,
distinct_input,
))
}
#[must_use]
pub(crate) fn projected_field(&self) -> Option<&str> {
self.target_slot().map(FieldSlot::field)
}
}