use crate::db::sql::lowering::{SqlLoweringError, select::lower_select_item_expr};
use crate::db::{
predicate::Predicate,
query::builder::scalar_projection::render_scalar_projection_expr_sql_label,
query::plan::expr::{
parse_supported_order_expr, render_supported_order_expr,
rewrite_supported_order_expr_fields,
},
sql::{
identifier::{
identifier_last_segment, identifiers_tail_match, normalize_identifier_to_scope,
rewrite_field_identifiers,
},
parser::{
SqlAggregateCall, SqlAggregateInputExpr, SqlArithmeticProjectionCall, SqlHavingClause,
SqlHavingValueExpr, SqlOrderTerm, SqlProjection, SqlProjectionOperand,
SqlRoundProjectionCall, SqlRoundProjectionInput, SqlSelectItem, SqlSelectStatement,
SqlTextFunctionCall,
},
},
};
/// Rewrites every identifier-bearing clause of `statement` against the entity scope
/// built from the statement's own entity name and `expected_entity`.
///
/// The clauses touched are the projection, GROUP BY, WHERE predicate, ORDER BY and
/// HAVING. Projection aliases themselves are left untouched.
///
/// # Errors
///
/// Propagates the error from ORDER BY alias resolution when an ORDER BY term names a
/// projection alias whose item cannot be turned into an order target.
pub(in crate::db::sql::lowering) fn normalize_select_statement_to_expected_entity(
    mut statement: SqlSelectStatement,
    expected_entity: &'static str,
) -> Result<SqlSelectStatement, SqlLoweringError> {
    let scope_candidates = sql_entity_scope_candidates(statement.entity.as_str(), expected_entity);
    let scope = scope_candidates.as_slice();

    // The projection must be normalized first: ORDER BY alias resolution below reads
    // the already-normalized projection items.
    statement.projection = normalize_projection_identifiers(statement.projection, scope);
    statement.group_by = normalize_identifier_list(statement.group_by, scope);
    if let Some(predicate) = statement.predicate.take() {
        statement.predicate = Some(adapt_predicate_identifiers_to_scope(predicate, scope));
    }

    let order_by = normalize_select_order_terms(
        statement.order_by,
        &statement.projection,
        statement.projection_aliases.as_slice(),
        scope,
    )?;
    statement.order_by = order_by;
    statement.having = normalize_having_clauses(statement.having, scope);

    Ok(statement)
}
/// Normalizes the field identifiers inside every HAVING clause against `entity_scope`.
///
/// Both sides of each comparison are normalized; the comparison operator is preserved.
pub(in crate::db::sql::lowering) fn normalize_having_clauses(
    clauses: Vec<SqlHavingClause>,
    entity_scope: &[String],
) -> Vec<SqlHavingClause> {
    let normalizer = SqlIdentifierNormalizer::new(entity_scope);
    normalizer.normalize_having_clauses(clauses)
}
/// Builds the entity-scope candidates that SQL identifiers are normalized against.
///
/// Candidates appear in first-seen order:
/// 1. the entity named in the SQL statement,
/// 2. the expected entity,
/// 3. the last segment of the SQL entity (when `identifier_last_segment` yields one),
/// 4. the last segment of the expected entity (likewise).
///
/// Exact duplicates are dropped, for example when the SQL entity already equals the
/// expected entity, or when a name's last segment is the whole name. Without this,
/// every scope check would repeat identical comparisons. First-seen order is kept,
/// so the relative precedence of the distinct candidates is unchanged.
pub(in crate::db::sql::lowering) fn sql_entity_scope_candidates(
    sql_entity: &str,
    expected_entity: &'static str,
) -> Vec<String> {
    // At most four candidates are ever produced.
    let mut out: Vec<String> = Vec::with_capacity(4);
    let mut push_unique = |candidate: String| {
        // Only exact (case-sensitive) duplicates are removed.
        if !out.contains(&candidate) {
            out.push(candidate);
        }
    };

    push_unique(sql_entity.to_string());
    push_unique(expected_entity.to_string());
    if let Some(last) = identifier_last_segment(sql_entity) {
        push_unique(last.to_string());
    }
    if let Some(last) = identifier_last_segment(expected_entity) {
        push_unique(last.to_string());
    }

    out
}
/// Normalizes the field identifiers referenced by a SELECT projection against
/// `entity_scope`. `SELECT *` is returned as-is.
fn normalize_projection_identifiers(
    projection: SqlProjection,
    entity_scope: &[String],
) -> SqlProjection {
    let normalizer = SqlIdentifierNormalizer::new(entity_scope);
    normalizer.normalize_projection(projection)
}
/// Walks SQL projection and HAVING expression trees and rewrites every field
/// identifier relative to a fixed set of entity-scope candidates.
///
/// The normalizer only borrows the scope slice, so it is `Copy`; every recursive
/// helper consumes `self` by value.
#[derive(Copy, Clone)]
struct SqlIdentifierNormalizer<'a> {
    /// Scope candidates that identifiers are normalized against.
    entity_scope: &'a [String],
}

impl<'a> SqlIdentifierNormalizer<'a> {
    /// Creates a normalizer bound to `entity_scope`.
    const fn new(entity_scope: &'a [String]) -> Self {
        Self { entity_scope }
    }

    /// Normalizes each explicit projection item; `SELECT *` passes through unchanged.
    fn normalize_projection(self, projection: SqlProjection) -> SqlProjection {
        let SqlProjection::Items(items) = projection else {
            return SqlProjection::All;
        };
        let normalized = items
            .into_iter()
            .map(|item| self.normalize_select_item(item))
            .collect();
        SqlProjection::Items(normalized)
    }

    /// Normalizes both operands of every HAVING comparison, keeping the operator.
    fn normalize_having_clauses(self, clauses: Vec<SqlHavingClause>) -> Vec<SqlHavingClause> {
        let mut normalized = Vec::with_capacity(clauses.len());
        for SqlHavingClause { left, op, right } in clauses {
            let left = self.normalize_having_value_expr(left);
            let right = self.normalize_having_value_expr(right);
            normalized.push(SqlHavingClause { left, op, right });
        }
        normalized
    }

    /// Normalizes the identifiers referenced by a single SELECT item.
    fn normalize_select_item(self, item: SqlSelectItem) -> SqlSelectItem {
        match item {
            SqlSelectItem::Field(field) => SqlSelectItem::Field(self.normalize_identifier(field)),
            SqlSelectItem::Round(call) => SqlSelectItem::Round(self.normalize_round_call(call)),
            SqlSelectItem::Arithmetic(call) => {
                SqlSelectItem::Arithmetic(self.normalize_arithmetic_call(call))
            }
            SqlSelectItem::TextFunction(call) => {
                SqlSelectItem::TextFunction(self.normalize_text_function_call(call))
            }
            SqlSelectItem::Aggregate(aggregate) => {
                SqlSelectItem::Aggregate(self.normalize_aggregate_call(aggregate))
            }
        }
    }

    /// Normalizes one side of a HAVING comparison; literals are returned untouched.
    fn normalize_having_value_expr(self, expr: SqlHavingValueExpr) -> SqlHavingValueExpr {
        match expr {
            SqlHavingValueExpr::Literal(literal) => SqlHavingValueExpr::Literal(literal),
            SqlHavingValueExpr::Field(field) => {
                SqlHavingValueExpr::Field(self.normalize_identifier_to_scope(field))
            }
            SqlHavingValueExpr::Round(call) => {
                SqlHavingValueExpr::Round(self.normalize_round_call(call))
            }
            SqlHavingValueExpr::Arithmetic(call) => {
                SqlHavingValueExpr::Arithmetic(self.normalize_arithmetic_call(call))
            }
            SqlHavingValueExpr::Aggregate(aggregate) => {
                SqlHavingValueExpr::Aggregate(self.normalize_aggregate_call(aggregate))
            }
        }
    }

    /// Normalizes the (optional, boxed) input of an aggregate call.
    ///
    /// The aggregate kind and the DISTINCT flag are carried over as-is.
    fn normalize_aggregate_call(self, aggregate: SqlAggregateCall) -> SqlAggregateCall {
        let SqlAggregateCall {
            kind,
            input,
            distinct,
        } = aggregate;
        let input = input.map(|boxed| Box::new(self.normalize_aggregate_input_expr(*boxed)));
        SqlAggregateCall {
            kind,
            input,
            distinct,
        }
    }

    /// Normalizes an aggregate's input expression; literals are returned untouched.
    fn normalize_aggregate_input_expr(self, expr: SqlAggregateInputExpr) -> SqlAggregateInputExpr {
        match expr {
            SqlAggregateInputExpr::Literal(literal) => SqlAggregateInputExpr::Literal(literal),
            SqlAggregateInputExpr::Field(field) => {
                SqlAggregateInputExpr::Field(self.normalize_identifier_to_scope(field))
            }
            SqlAggregateInputExpr::Round(call) => {
                SqlAggregateInputExpr::Round(self.normalize_round_call(call))
            }
            SqlAggregateInputExpr::Arithmetic(call) => {
                SqlAggregateInputExpr::Arithmetic(self.normalize_arithmetic_call(call))
            }
        }
    }

    /// Normalizes one operand of an arithmetic/ROUND projection, recursing into
    /// nested (boxed) arithmetic.
    fn normalize_projection_operand(self, operand: SqlProjectionOperand) -> SqlProjectionOperand {
        match operand {
            SqlProjectionOperand::Literal(literal) => SqlProjectionOperand::Literal(literal),
            SqlProjectionOperand::Field(field) => {
                SqlProjectionOperand::Field(self.normalize_identifier(field))
            }
            SqlProjectionOperand::Aggregate(aggregate) => {
                SqlProjectionOperand::Aggregate(self.normalize_aggregate_call(aggregate))
            }
            SqlProjectionOperand::Arithmetic(call) => {
                let inner = self.normalize_arithmetic_call(*call);
                SqlProjectionOperand::Arithmetic(Box::new(inner))
            }
        }
    }

    /// Normalizes both operands of a binary arithmetic projection, keeping the operator.
    fn normalize_arithmetic_call(
        self,
        call: SqlArithmeticProjectionCall,
    ) -> SqlArithmeticProjectionCall {
        let SqlArithmeticProjectionCall { left, op, right } = call;
        SqlArithmeticProjectionCall {
            left: self.normalize_projection_operand(left),
            op,
            right: self.normalize_projection_operand(right),
        }
    }

    /// Normalizes the input of a ROUND projection; the scale is carried over.
    fn normalize_round_call(self, call: SqlRoundProjectionCall) -> SqlRoundProjectionCall {
        let SqlRoundProjectionCall { input, scale } = call;
        SqlRoundProjectionCall {
            input: self.normalize_round_input(input),
            scale,
        }
    }

    /// Normalizes a ROUND input, which is either a single operand or an arithmetic call.
    fn normalize_round_input(self, input: SqlRoundProjectionInput) -> SqlRoundProjectionInput {
        match input {
            SqlRoundProjectionInput::Arithmetic(call) => {
                SqlRoundProjectionInput::Arithmetic(self.normalize_arithmetic_call(call))
            }
            SqlRoundProjectionInput::Operand(operand) => {
                SqlRoundProjectionInput::Operand(self.normalize_projection_operand(operand))
            }
        }
    }

    /// Normalizes the field argument of a text function; the function and its
    /// literal arguments are carried over.
    fn normalize_text_function_call(self, call: SqlTextFunctionCall) -> SqlTextFunctionCall {
        let SqlTextFunctionCall {
            function,
            field,
            literal,
            literal2,
            literal3,
        } = call;
        SqlTextFunctionCall {
            function,
            field: self.normalize_identifier(field),
            literal,
            literal2,
            literal3,
        }
    }

    /// Normalizes `identifier` through the module-level `normalize_identifier` helper.
    fn normalize_identifier(self, identifier: String) -> String {
        normalize_identifier(identifier, self.entity_scope)
    }

    /// Normalizes `identifier` through the imported `normalize_identifier_to_scope`.
    fn normalize_identifier_to_scope(self, identifier: String) -> String {
        normalize_identifier_to_scope(identifier, self.entity_scope)
    }
}
/// Normalizes ORDER BY terms for a SELECT statement.
///
/// Each term is first checked against the projection aliases (case-insensitively).
/// A term that names an alias is replaced by the order target derived from the
/// aliased projection item. Every resulting target is then normalized against
/// `entity_scope`. Term directions are preserved.
///
/// # Errors
///
/// Returns the first alias-resolution error encountered; later terms are not processed.
fn normalize_select_order_terms(
    terms: Vec<SqlOrderTerm>,
    projection: &SqlProjection,
    projection_aliases: &[Option<String>],
    entity_scope: &[String],
) -> Result<Vec<SqlOrderTerm>, SqlLoweringError> {
    let mut normalized = Vec::with_capacity(terms.len());
    for SqlOrderTerm { field, direction } in terms {
        let resolved =
            resolve_projection_order_alias(field.as_str(), projection, projection_aliases)?
                .unwrap_or(field);
        normalized.push(SqlOrderTerm {
            field: normalize_order_term_identifier(resolved, entity_scope),
            direction,
        });
    }
    Ok(normalized)
}
/// Resolves an ORDER BY target that names a projection alias.
///
/// Returns `Ok(None)` when the projection is `SELECT *` or when no alias matches
/// `order_target` (ASCII case-insensitively). Items and aliases are paired
/// positionally, and the first matching alias wins.
///
/// # Errors
///
/// Returns `unsupported_order_by_alias` when the aliased item cannot be turned into
/// an order target.
fn resolve_projection_order_alias(
    order_target: &str,
    projection: &SqlProjection,
    projection_aliases: &[Option<String>],
) -> Result<Option<String>, SqlLoweringError> {
    let SqlProjection::Items(items) = projection else {
        return Ok(None);
    };

    let aliased_item = items
        .iter()
        .zip(projection_aliases)
        .find_map(|(item, alias)| {
            alias
                .as_deref()
                .filter(|name| name.eq_ignore_ascii_case(order_target))
                .map(|_| item)
        });

    match aliased_item {
        None => Ok(None),
        Some(item) => order_target_from_projection_item(item)
            .map(Some)
            .ok_or_else(|| SqlLoweringError::unsupported_order_by_alias(order_target)),
    }
}
/// Derives the ORDER BY target string for an aliased projection item.
///
/// - Plain fields order by the field name itself; no lowering is performed.
/// - Aggregates order by their rendered scalar-projection label.
/// - Text functions order only by a supported order expression.
/// - Arithmetic and ROUND items prefer a supported order expression and fall back
///   to the rendered label.
///
/// Returns `None` when lowering fails, or when a text function has no supported
/// order-expression rendering.
fn order_target_from_projection_item(item: &SqlSelectItem) -> Option<String> {
    let lower = || lower_select_item_expr(item).ok();

    match item {
        SqlSelectItem::Field(field) => Some(field.clone()),
        SqlSelectItem::Aggregate(_) => {
            let expr = lower()?;
            Some(render_scalar_projection_expr_sql_label(&expr))
        }
        SqlSelectItem::TextFunction(_) => {
            let expr = lower()?;
            render_supported_order_expr(&expr)
        }
        SqlSelectItem::Arithmetic(_) | SqlSelectItem::Round(_) => {
            let expr = lower()?;
            Some(
                render_supported_order_expr(&expr)
                    .unwrap_or_else(|| render_scalar_projection_expr_sql_label(&expr)),
            )
        }
    }
}
/// Normalizes each ORDER BY target against `entity_scope`, preserving directions.
///
/// Unlike the SELECT-specific variant, no projection-alias resolution is performed.
pub(in crate::db::sql::lowering) fn normalize_order_terms(
    terms: Vec<SqlOrderTerm>,
    entity_scope: &[String],
) -> Vec<SqlOrderTerm> {
    let mut normalized = Vec::with_capacity(terms.len());
    for SqlOrderTerm { field, direction } in terms {
        normalized.push(SqlOrderTerm {
            field: normalize_order_term_identifier(field, entity_scope),
            direction,
        });
    }
    normalized
}
/// Normalizes a single ORDER BY target against `entity_scope`.
///
/// Targets that parse as a supported order expression have each embedded field
/// normalized, and are then re-rendered. Any other target is normalized as a
/// plain identifier.
///
/// # Panics
///
/// Panics if rewriting or re-rendering an already-parsed supported order expression
/// fails; the `expect` messages treat that as a broken invariant.
fn normalize_order_term_identifier(identifier: String, entity_scope: &[String]) -> String {
    let Some(parsed) = parse_supported_order_expr(identifier.as_str()) else {
        // Not an expression form: treat the whole target as a bare identifier.
        return normalize_identifier(identifier, entity_scope);
    };

    // Rewrite only the field references inside the parsed expression.
    let normalized_expr = rewrite_supported_order_expr_fields(&parsed, |field| {
        normalize_identifier(field.to_string(), entity_scope)
    })
    .expect("supported order expression rewrite must preserve the admitted order family");

    render_supported_order_expr(&normalized_expr)
        .expect("supported order expression rendering must preserve the admitted order family")
}
/// Normalizes every identifier in `fields` (e.g. GROUP BY columns) against
/// `entity_scope`, preserving order.
pub(in crate::db::sql::lowering) fn normalize_identifier_list(
    fields: Vec<String>,
    entity_scope: &[String],
) -> Vec<String> {
    let mut normalized = Vec::with_capacity(fields.len());
    for field in fields {
        normalized.push(normalize_identifier(field, entity_scope));
    }
    normalized
}
/// Rewrites every field identifier inside `predicate` so it is normalized against
/// `entity_scope`.
pub(in crate::db::sql::lowering) fn adapt_predicate_identifiers_to_scope(
    predicate: Predicate,
    entity_scope: &[String],
) -> Predicate {
    let normalize_field = |field: String| normalize_identifier(field, entity_scope);
    rewrite_field_identifiers(predicate, normalize_field)
}
/// Normalizes one SQL identifier against `entity_scope`.
///
/// This is the single module-local entry point for identifier normalization; it
/// delegates directly to `normalize_identifier_to_scope`.
fn normalize_identifier(identifier: String, entity_scope: &[String]) -> String {
    // Delegate directly; keeping one choke point lets every clause share the same rule.
    normalize_identifier_to_scope(identifier, entity_scope)
}
/// Checks that the entity named in the SQL statement matches `expected_entity`,
/// using `identifiers_tail_match`.
///
/// # Errors
///
/// Returns `SqlLoweringError::entity_mismatch` when the two names do not match.
pub(in crate::db::sql::lowering) fn ensure_entity_matches_expected(
    sql_entity: &str,
    expected_entity: &'static str,
) -> Result<(), SqlLoweringError> {
    if identifiers_tail_match(sql_entity, expected_entity) {
        Ok(())
    } else {
        Err(SqlLoweringError::entity_mismatch(sql_entity, expected_entity))
    }
}