use crate::db::{
query::{
builder::AggregateExpr,
plan::{AggregateKind, expr::aggregate_count_input_expr_is_non_null_literal},
},
sql::lowering::{
AnalyzedLoweredExpr, SqlLoweringError, aggregate::semantics::AggregateTerminalSemanticKey,
},
};
/// How the input of a lowered global aggregate terminal was resolved.
///
/// Produced by `LoweredSqlGlobalAggregateTerminal::resolve_input` (and by
/// `LoweredSqlGlobalAggregateTerminal::count_rows`, which always uses `Rows`).
#[derive(Clone, Debug)]
pub(in crate::db::sql::lowering::aggregate) enum LoweredAggregateInput {
    /// The aggregate counts rows rather than reading a value. Chosen for a `COUNT`
    /// with neither a target field nor an input expression, and for a non-DISTINCT
    /// `COUNT` whose input expression is a non-null literal.
    Rows,
    /// The aggregate reads the named target field (copied from
    /// `AggregateExpr::target_field`).
    Field(String),
    /// The aggregate reads an arbitrary input expression, wrapped via
    /// `AnalyzedLoweredExpr::new(input_expr, None)`.
    Expr(AnalyzedLoweredExpr),
}
/// One global aggregate terminal lowered from SQL: its semantic identity, its
/// resolved input, and an optional filter expression.
#[derive(Clone, Debug)]
pub(in crate::db::sql::lowering::aggregate) struct LoweredSqlGlobalAggregateTerminal {
    /// Semantic key of the aggregate. When built from an `AggregateExpr`, this is
    /// debug-asserted to equal `AggregateTerminalSemanticKey::from_aggregate_expr`
    /// of that expression.
    semantic_key: AggregateTerminalSemanticKey,
    /// What the aggregate consumes: rows, a named field, or an expression.
    input: LoweredAggregateInput,
    /// The aggregate expression's filter expression, if it had one. Presumably a
    /// SQL `FILTER (WHERE ...)` clause (NOTE(review): confirm against the parser).
    filter_expr: Option<AnalyzedLoweredExpr>,
}
impl LoweredSqlGlobalAggregateTerminal {
    /// Builds a row-counting terminal with no filter.
    ///
    /// The semantic key is derived from the query builder's `aggregate::count()`
    /// expression, so it matches what that builder expression would lower to.
    pub(in crate::db::sql::lowering::aggregate) fn count_rows() -> Self {
        let count_expr = crate::db::query::builder::aggregate::count();

        Self {
            semantic_key: AggregateTerminalSemanticKey::from_aggregate_expr(&count_expr),
            input: LoweredAggregateInput::Rows,
            filter_expr: None,
        }
    }

    /// Lowers `aggregate_expr` into a global aggregate terminal, reusing a semantic
    /// key the caller already computed for it.
    ///
    /// In debug builds, panics if `semantic_key` differs from
    /// `AggregateTerminalSemanticKey::from_aggregate_expr(aggregate_expr)`.
    ///
    /// # Errors
    ///
    /// Returns `SqlLoweringError::unsupported_global_aggregate_projection()` when the
    /// aggregate's input cannot be resolved (see `resolve_input`).
    pub(in crate::db::sql::lowering::aggregate) fn from_aggregate_expr_with_semantic_key(
        aggregate_expr: &AggregateExpr,
        semantic_key: AggregateTerminalSemanticKey,
    ) -> Result<Self, SqlLoweringError> {
        debug_assert_eq!(
            semantic_key,
            AggregateTerminalSemanticKey::from_aggregate_expr(aggregate_expr),
            "global aggregate terminal semantic key must match its aggregate expression",
        );

        // Resolve the input first so an unsupported aggregate bails out before the
        // filter expression is cloned.
        let input = Self::resolve_input(aggregate_expr)?;
        let filter_expr = aggregate_expr
            .filter_expr()
            .map(|filter| AnalyzedLoweredExpr::new(filter.clone(), None));

        Ok(Self {
            semantic_key,
            input,
            filter_expr,
        })
    }

    /// Consumes the terminal and returns `(semantic_key, input, filter_expr)`.
    pub(in crate::db::sql::lowering::aggregate) fn into_parts(
        self,
    ) -> (
        AggregateTerminalSemanticKey,
        LoweredAggregateInput,
        Option<AnalyzedLoweredExpr>,
    ) {
        (self.semantic_key, self.input, self.filter_expr)
    }

    /// Decides what a global aggregate consumes.
    ///
    /// Resolution order:
    /// 1. `EXISTS`, `FIRST` and `LAST` are rejected.
    /// 2. `COUNT` resolves to `Rows` when it has neither a target field nor an input
    ///    expression, or when it is non-DISTINCT over a non-null literal input.
    /// 3. A target field resolves to `Field`.
    /// 4. An input expression resolves to `Expr`.
    ///
    /// # Errors
    ///
    /// Returns `SqlLoweringError::unsupported_global_aggregate_projection()` for the
    /// rejected kinds, or when none of the cases above applies.
    fn resolve_input(
        aggregate_expr: &AggregateExpr,
    ) -> Result<LoweredAggregateInput, SqlLoweringError> {
        let kind = aggregate_expr.kind();

        let is_unsupported_kind = matches!(
            kind,
            AggregateKind::Exists | AggregateKind::First | AggregateKind::Last
        );
        if is_unsupported_kind {
            return Err(SqlLoweringError::unsupported_global_aggregate_projection());
        }

        if kind == AggregateKind::Count {
            let has_no_operand = aggregate_expr.target_field().is_none()
                && aggregate_expr.input_expr().is_none();
            // A non-DISTINCT count of a literal that can never be NULL counts every
            // row, so it needs no per-row input.
            if has_no_operand
                || (!aggregate_expr.is_distinct()
                    && aggregate_expr
                        .input_expr()
                        .is_some_and(aggregate_count_input_expr_is_non_null_literal))
            {
                return Ok(LoweredAggregateInput::Rows);
            }
        }

        // A target field takes precedence over an input expression.
        match (aggregate_expr.target_field(), aggregate_expr.input_expr()) {
            (Some(field), _) => Ok(LoweredAggregateInput::Field(field.to_string())),
            (None, Some(input_expr)) => Ok(LoweredAggregateInput::Expr(
                AnalyzedLoweredExpr::new(input_expr.clone(), None),
            )),
            (None, None) => Err(SqlLoweringError::unsupported_global_aggregate_projection()),
        }
    }
}