#[cfg(test)]
mod tests;
use crate::{
db::{
predicate::{CoercionId, CompareOp, MissingRowPolicy, Predicate},
query::{
builder::aggregate::{avg, count, count_by, max_by, min_by, sum},
intent::{Query, QueryError, StructuralQuery},
plan::{ExpressionOrderTerm, FieldSlot, resolve_aggregate_target_field_slot},
},
sql::identifier::{
identifier_last_segment, identifiers_tail_match, normalize_identifier_to_scope,
rewrite_field_identifiers,
},
sql::parser::{
SqlAggregateCall, SqlAggregateKind, SqlDeleteStatement, SqlExplainMode,
SqlExplainStatement, SqlExplainTarget, SqlHavingClause, SqlHavingSymbol,
SqlOrderDirection, SqlOrderTerm, SqlProjection, SqlSelectItem, SqlSelectStatement,
SqlStatement, SqlTextFunctionCall,
},
},
model::{entity::EntityModel, field::FieldKind},
traits::EntityKind,
value::Value,
};
use thiserror::Error as ThisError;
/// Entity-agnostic result of lowering a prepared SQL statement.
///
/// The payload is private; callers inspect it through the `query` /
/// `explain_query` accessors and the free functions in this module
/// (e.g. `lowered_sql_command_lane`).
#[derive(Clone, Debug)]
pub struct LoweredSqlCommand(LoweredSqlCommandInner);
/// Private payload of `LoweredSqlCommand`: one variant per lowered command form.
#[derive(Clone, Debug)]
enum LoweredSqlCommandInner {
    /// A plain SELECT or DELETE lowered to a query shape.
    Query(LoweredSqlQuery),
    /// EXPLAIN over a SELECT or DELETE that lowered to a regular query shape.
    Explain {
        mode: SqlExplainMode,
        query: LoweredSqlQuery,
    },
    /// EXPLAIN over a SELECT whose projection is a single global aggregate;
    /// produced when ordinary SELECT lowering rejects the projection.
    ExplainGlobalAggregate {
        mode: SqlExplainMode,
        command: LoweredSqlGlobalAggregateCommand,
    },
    /// Describe statement; no payload is kept after entity validation.
    DescribeEntity,
    /// Show-indexes statement; no payload is kept after entity validation.
    ShowIndexesEntity,
    /// Show-columns statement; no payload is kept after entity validation.
    ShowColumnsEntity,
    /// Show-entities statement (not tied to a single entity).
    ShowEntities,
}
/// Test-only typed counterpart of `LoweredSqlCommand`, with every query shape
/// already bound to entity type `E` (see `bind_lowered_sql_command`).
#[cfg(test)]
#[derive(Debug)]
pub(crate) enum SqlCommand<E: EntityKind> {
    /// Bound SELECT or DELETE query.
    Query(Query<E>),
    /// Bound EXPLAIN over a SELECT or DELETE query.
    Explain {
        mode: SqlExplainMode,
        query: Query<E>,
    },
    /// Bound EXPLAIN over a global-aggregate SELECT.
    ExplainGlobalAggregate {
        mode: SqlExplainMode,
        command: SqlGlobalAggregateCommand<E>,
    },
    DescribeEntity,
    ShowIndexesEntity,
    ShowColumnsEntity,
    ShowEntities,
}
impl LoweredSqlCommand {
    /// Borrow the lowered SELECT/DELETE payload of a plain query command.
    ///
    /// Every other command form (both EXPLAIN variants and the metadata
    /// commands) yields `None`.
    #[must_use]
    pub(in crate::db) const fn query(&self) -> Option<&LoweredSqlQuery> {
        match &self.0 {
            LoweredSqlCommandInner::Explain { .. }
            | LoweredSqlCommandInner::ExplainGlobalAggregate { .. }
            | LoweredSqlCommandInner::DescribeEntity
            | LoweredSqlCommandInner::ShowIndexesEntity
            | LoweredSqlCommandInner::ShowColumnsEntity
            | LoweredSqlCommandInner::ShowEntities => None,
            LoweredSqlCommandInner::Query(lowered_query) => Some(lowered_query),
        }
    }
    /// Return the EXPLAIN mode and target for an EXPLAIN over a regular
    /// SELECT/DELETE shape.
    ///
    /// Global-aggregate EXPLAINs and all non-EXPLAIN commands yield `None`.
    #[must_use]
    pub(in crate::db) const fn explain_query(&self) -> Option<(SqlExplainMode, &LoweredSqlQuery)> {
        match &self.0 {
            LoweredSqlCommandInner::Query(_)
            | LoweredSqlCommandInner::ExplainGlobalAggregate { .. }
            | LoweredSqlCommandInner::DescribeEntity
            | LoweredSqlCommandInner::ShowIndexesEntity
            | LoweredSqlCommandInner::ShowColumnsEntity
            | LoweredSqlCommandInner::ShowEntities => None,
            LoweredSqlCommandInner::Explain {
                mode,
                query: explained,
            } => Some((*mode, explained)),
        }
    }
}
/// Entity-agnostic lowered query: either a SELECT shape or a DELETE shape.
#[derive(Clone, Debug)]
pub(crate) enum LoweredSqlQuery {
    /// Lowered SELECT (scalar or grouped).
    Select(LoweredSelectShape),
    /// Lowered DELETE; only predicate/ORDER BY/LIMIT apply (offset is always `None`).
    Delete(LoweredBaseQueryShape),
}
impl LoweredSqlQuery {
    /// Whether this lowered query has GROUP BY keys.
    ///
    /// DELETE shapes never group; SELECT shapes delegate to
    /// `LoweredSelectShape::has_grouping`.
    pub(crate) const fn has_grouping(&self) -> bool {
        match self {
            Self::Delete(_) => false,
            Self::Select(select_shape) => select_shape.has_grouping(),
        }
    }
}
/// Global aggregate terminal with its target field still named by string
/// (not yet resolved against an entity model).
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum SqlGlobalAggregateTerminal {
    /// `COUNT(*)`-style row count (aggregate call with no field).
    CountRows,
    CountField(String),
    SumField(String),
    AvgField(String),
    MinField(String),
    MaxField(String),
}
/// Global aggregate terminal whose target field has been resolved to a
/// `FieldSlot` for a concrete entity (see
/// `bind_lowered_sql_global_aggregate_terminal`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum TypedSqlGlobalAggregateTerminal {
    CountRows,
    CountField(FieldSlot),
    SumField(FieldSlot),
    AvgField(FieldSlot),
    MinField(FieldSlot),
    MaxField(FieldSlot),
}
/// Entity-agnostic global aggregate command: the base query it runs over
/// plus the (string-named) terminal to compute.
#[derive(Clone, Debug)]
pub(crate) struct LoweredSqlGlobalAggregateCommand {
    query: LoweredBaseQueryShape,
    terminal: SqlGlobalAggregateTerminal,
}
/// Normalized form of a parsed aggregate call, produced by
/// `lower_sql_aggregate_shape`.
///
/// That function never emits `FieldTarget` with `kind: Count`; field counts
/// use `CountField` instead.
enum LoweredSqlAggregateShape {
    /// Aggregate `COUNT` with no field argument.
    CountRows,
    /// `COUNT(field)`.
    CountField(String),
    /// SUM/AVG/MIN/MAX over a named field.
    FieldTarget {
        kind: SqlAggregateKind,
        field: String,
    },
}
/// Global aggregate command bound to entity type `E`: a typed base query
/// plus a terminal with resolved field slots.
#[derive(Debug)]
pub(crate) struct SqlGlobalAggregateCommand<E: EntityKind> {
    query: Query<E>,
    terminal: TypedSqlGlobalAggregateTerminal,
}
impl<E: EntityKind> SqlGlobalAggregateCommand<E> {
    /// The typed base query the aggregate is evaluated over.
    #[must_use]
    pub(crate) const fn query(&self) -> &Query<E> {
        let Self { query, .. } = self;
        query
    }
    /// The aggregate terminal with its target field already resolved.
    #[must_use]
    pub(crate) const fn terminal(&self) -> &TypedSqlGlobalAggregateTerminal {
        let Self { terminal, .. } = self;
        terminal
    }
}
/// Non-generic (structural) global aggregate command. The terminal keeps its
/// string field names; they are not resolved to slots on this path.
#[derive(Debug)]
pub(crate) struct SqlGlobalAggregateCommandCore {
    query: StructuralQuery,
    terminal: SqlGlobalAggregateTerminal,
}
impl SqlGlobalAggregateCommandCore {
    /// The structural base query the aggregate is evaluated over.
    #[must_use]
    pub(in crate::db) const fn query(&self) -> &StructuralQuery {
        let Self { query, .. } = self;
        query
    }
    /// The unresolved (string-named) aggregate terminal.
    #[must_use]
    pub(in crate::db) const fn terminal(&self) -> &SqlGlobalAggregateTerminal {
        let Self { terminal, .. } = self;
        terminal
    }
}
/// Errors raised while preparing, lowering, or binding a SQL statement.
#[derive(Debug, ThisError)]
pub(crate) enum SqlLoweringError {
    /// Parser failure; displays the parser's message unchanged. `#[from]`
    /// provides the conversion used by `?` on `parse_sql` results.
    #[error("{0}")]
    Parse(#[from] crate::db::sql::parser::SqlParseError),
    /// Query-builder failure; displays the builder's message unchanged.
    #[error("{0}")]
    Query(#[from] QueryError),
    /// The statement names an entity other than the one being compiled for.
    #[error("SQL entity '{sql_entity}' does not match requested entity type '{expected_entity}'")]
    EntityMismatch {
        sql_entity: String,
        expected_entity: &'static str,
    },
    /// Projection shape not supported. EXPLAIN lowering treats this variant
    /// specially: it retries the SELECT as a global aggregate.
    #[error(
        "unsupported SQL SELECT projection; supported forms are SELECT *, field lists, or grouped aggregate shapes"
    )]
    UnsupportedSelectProjection,
    /// DISTINCT used in a shape that does not support it.
    #[error("unsupported SQL SELECT DISTINCT")]
    UnsupportedSelectDistinct,
    /// GROUP BY projection does not match the supported grouped shape.
    #[error("unsupported SQL GROUP BY projection shape")]
    UnsupportedSelectGroupBy,
    /// HAVING used without GROUP BY, or referencing an unresolvable aggregate.
    #[error("unsupported SQL HAVING shape")]
    UnsupportedSelectHaving,
}
impl SqlLoweringError {
    /// Build an `EntityMismatch` from the statement's entity name and the
    /// entity the caller expected.
    fn entity_mismatch(sql_entity: impl Into<String>, expected_entity: &'static str) -> Self {
        let sql_entity: String = sql_entity.into();
        Self::EntityMismatch {
            expected_entity,
            sql_entity,
        }
    }
    /// Shorthand for `UnsupportedSelectProjection`.
    const fn unsupported_select_projection() -> Self {
        Self::UnsupportedSelectProjection
    }
    /// Shorthand for `UnsupportedSelectDistinct`.
    const fn unsupported_select_distinct() -> Self {
        Self::UnsupportedSelectDistinct
    }
    /// Shorthand for `UnsupportedSelectGroupBy`.
    const fn unsupported_select_group_by() -> Self {
        Self::UnsupportedSelectGroupBy
    }
    /// Shorthand for `UnsupportedSelectHaving`.
    const fn unsupported_select_having() -> Self {
        Self::UnsupportedSelectHaving
    }
}
/// A parsed statement that passed `prepare_statement`: entity names were
/// checked against the expected entity and, for SELECT/DELETE, identifiers
/// were normalized to that entity's scope.
#[derive(Clone, Debug)]
pub(crate) struct PreparedSqlStatement {
    statement: SqlStatement,
}
/// Coarse classification of a `LoweredSqlCommand`, as returned by
/// `lowered_sql_command_lane`. Both EXPLAIN forms map to `Explain`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum LoweredSqlLaneKind {
    Query,
    Explain,
    Describe,
    ShowIndexes,
    ShowColumns,
    ShowEntities,
}
/// Test helper: parse `sql` and compile it all the way to a typed
/// `SqlCommand<E>`.
///
/// # Errors
/// Parse errors are converted into `SqlLoweringError::Parse`; later stages
/// report their own lowering errors.
#[cfg(test)]
pub(crate) fn compile_sql_command<E: EntityKind>(
    sql: &str,
    consistency: MissingRowPolicy,
) -> Result<SqlCommand<E>, SqlLoweringError> {
    crate::db::sql::parser::parse_sql(sql)
        .map_err(SqlLoweringError::from)
        .and_then(|parsed| compile_sql_command_from_statement::<E>(parsed, consistency))
}
/// Test helper: prepare an already-parsed statement against `E`'s model name,
/// then lower and bind it.
#[cfg(test)]
pub(crate) fn compile_sql_command_from_statement<E: EntityKind>(
    statement: SqlStatement,
    consistency: MissingRowPolicy,
) -> Result<SqlCommand<E>, SqlLoweringError> {
    let entity_name = E::MODEL.name();
    prepare_sql_statement(statement, entity_name).and_then(|prepared| {
        compile_sql_command_from_prepared_statement::<E>(prepared, consistency)
    })
}
/// Test helper: lower a prepared statement using `E`'s primary key, then
/// bind it to `E`.
#[cfg(test)]
pub(crate) fn compile_sql_command_from_prepared_statement<E: EntityKind>(
    prepared: PreparedSqlStatement,
    consistency: MissingRowPolicy,
) -> Result<SqlCommand<E>, SqlLoweringError> {
    let primary_key_field = E::MODEL.primary_key.name;
    lower_sql_command_from_prepared_statement(prepared, primary_key_field)
        .and_then(|lowered| bind_lowered_sql_command::<E>(lowered, consistency))
}
#[inline(never)]
pub(crate) fn lower_sql_command_from_prepared_statement(
prepared: PreparedSqlStatement,
primary_key_field: &str,
) -> Result<LoweredSqlCommand, SqlLoweringError> {
lower_prepared_statement(prepared.statement, primary_key_field)
}
/// Classify a lowered command into its execution lane.
///
/// Both EXPLAIN forms (regular query and global aggregate) share the
/// `Explain` lane.
pub(crate) const fn lowered_sql_command_lane(command: &LoweredSqlCommand) -> LoweredSqlLaneKind {
    match &command.0 {
        LoweredSqlCommandInner::Explain { .. }
        | LoweredSqlCommandInner::ExplainGlobalAggregate { .. } => LoweredSqlLaneKind::Explain,
        LoweredSqlCommandInner::Query(_) => LoweredSqlLaneKind::Query,
        LoweredSqlCommandInner::ShowEntities => LoweredSqlLaneKind::ShowEntities,
        LoweredSqlCommandInner::ShowColumnsEntity => LoweredSqlLaneKind::ShowColumns,
        LoweredSqlCommandInner::ShowIndexesEntity => LoweredSqlLaneKind::ShowIndexes,
        LoweredSqlCommandInner::DescribeEntity => LoweredSqlLaneKind::Describe,
    }
}
/// True when `statement` is a SELECT that would lower as a global aggregate
/// (see `is_sql_global_aggregate_select`); every other statement kind is false.
pub(in crate::db) fn is_sql_global_aggregate_statement(statement: &SqlStatement) -> bool {
    matches!(
        statement,
        SqlStatement::Select(select) if is_sql_global_aggregate_select(select)
    )
}
/// True when the SELECT has no DISTINCT/GROUP BY/HAVING and its projection
/// lowers to a single global aggregate terminal.
fn is_sql_global_aggregate_select(statement: &SqlSelectStatement) -> bool {
    let has_disqualifying_clause =
        statement.distinct || !statement.group_by.is_empty() || !statement.having.is_empty();
    // `&&` short-circuits, so the projection is only cloned when it matters.
    !has_disqualifying_clause
        && lower_global_aggregate_terminal(statement.projection.clone()).is_ok()
}
/// Bind an EXPLAIN-global-aggregate command to `model` without a typed entity.
///
/// Returns `None` for every other command form. The lowered command is
/// borrowed, so its aggregate payload is cloned for binding.
pub(crate) fn bind_lowered_sql_explain_global_aggregate_structural(
    lowered: &LoweredSqlCommand,
    model: &'static crate::model::entity::EntityModel,
    consistency: MissingRowPolicy,
) -> Option<(SqlExplainMode, SqlGlobalAggregateCommandCore)> {
    match &lowered.0 {
        LoweredSqlCommandInner::ExplainGlobalAggregate { mode, command } => {
            let explain_mode = *mode;
            let core = bind_lowered_sql_global_aggregate_command_structural(
                model,
                command.clone(),
                consistency,
            );
            Some((explain_mode, core))
        }
        _ => None,
    }
}
/// Test helper: bind every query shape inside a lowered command to entity `E`.
///
/// Metadata commands carry no query and map one-to-one.
#[cfg(test)]
pub(crate) fn bind_lowered_sql_command<E: EntityKind>(
    lowered: LoweredSqlCommand,
    consistency: MissingRowPolicy,
) -> Result<SqlCommand<E>, SqlLoweringError> {
    let command = match lowered.0 {
        LoweredSqlCommandInner::Query(lowered_query) => {
            SqlCommand::Query(bind_lowered_sql_query::<E>(lowered_query, consistency)?)
        }
        LoweredSqlCommandInner::Explain { mode, query } => {
            let query = bind_lowered_sql_query::<E>(query, consistency)?;
            SqlCommand::Explain { mode, query }
        }
        LoweredSqlCommandInner::ExplainGlobalAggregate { mode, command } => {
            let command = bind_lowered_sql_global_aggregate_command::<E>(command, consistency)?;
            SqlCommand::ExplainGlobalAggregate { mode, command }
        }
        LoweredSqlCommandInner::DescribeEntity => SqlCommand::DescribeEntity,
        LoweredSqlCommandInner::ShowIndexesEntity => SqlCommand::ShowIndexesEntity,
        LoweredSqlCommandInner::ShowColumnsEntity => SqlCommand::ShowColumnsEntity,
        LoweredSqlCommandInner::ShowEntities => SqlCommand::ShowEntities,
    };
    Ok(command)
}
/// Validate a parsed statement against `expected_entity` and normalize its
/// identifiers, wrapping the result as a `PreparedSqlStatement`.
///
/// # Errors
/// `EntityMismatch` when the statement targets a different entity.
#[inline(never)]
pub(crate) fn prepare_sql_statement(
    statement: SqlStatement,
    expected_entity: &'static str,
) -> Result<PreparedSqlStatement, SqlLoweringError> {
    prepare_statement(statement, expected_entity)
        .map(|normalized| PreparedSqlStatement {
            statement: normalized,
        })
}
/// Test helper: parse and prepare `sql`, then compile it as a typed global
/// aggregate command for `E`.
#[cfg(test)]
pub(crate) fn compile_sql_global_aggregate_command<E: EntityKind>(
    sql: &str,
    consistency: MissingRowPolicy,
) -> Result<SqlGlobalAggregateCommand<E>, SqlLoweringError> {
    let parsed = crate::db::sql::parser::parse_sql(sql)?;
    let entity_name = E::MODEL.name();
    prepare_sql_statement(parsed, entity_name).and_then(|prepared| {
        compile_sql_global_aggregate_command_from_prepared::<E>(prepared, consistency)
    })
}
/// Compile a prepared SELECT as a typed global aggregate command for `E`.
///
/// # Errors
/// Non-SELECT statements yield `UnsupportedSelectProjection`; SELECT shape
/// and field-resolution errors are propagated.
pub(crate) fn compile_sql_global_aggregate_command_from_prepared<E: EntityKind>(
    prepared: PreparedSqlStatement,
    consistency: MissingRowPolicy,
) -> Result<SqlGlobalAggregateCommand<E>, SqlLoweringError> {
    match prepared.statement {
        SqlStatement::Select(select) => {
            let lowered = lower_global_aggregate_select_shape(select)?;
            bind_lowered_sql_global_aggregate_command::<E>(lowered, consistency)
        }
        _ => Err(SqlLoweringError::unsupported_select_projection()),
    }
}
fn bind_lowered_sql_global_aggregate_terminal<E: EntityKind>(
terminal: SqlGlobalAggregateTerminal,
) -> Result<TypedSqlGlobalAggregateTerminal, SqlLoweringError> {
let resolve_target_slot = |field: &str| {
resolve_aggregate_target_field_slot(E::MODEL, field).map_err(SqlLoweringError::from)
};
match terminal {
SqlGlobalAggregateTerminal::CountRows => Ok(TypedSqlGlobalAggregateTerminal::CountRows),
SqlGlobalAggregateTerminal::CountField(field) => Ok(
TypedSqlGlobalAggregateTerminal::CountField(resolve_target_slot(field.as_str())?),
),
SqlGlobalAggregateTerminal::SumField(field) => Ok(
TypedSqlGlobalAggregateTerminal::SumField(resolve_target_slot(field.as_str())?),
),
SqlGlobalAggregateTerminal::AvgField(field) => Ok(
TypedSqlGlobalAggregateTerminal::AvgField(resolve_target_slot(field.as_str())?),
),
SqlGlobalAggregateTerminal::MinField(field) => Ok(
TypedSqlGlobalAggregateTerminal::MinField(resolve_target_slot(field.as_str())?),
),
SqlGlobalAggregateTerminal::MaxField(field) => Ok(
TypedSqlGlobalAggregateTerminal::MaxField(resolve_target_slot(field.as_str())?),
),
}
}
/// Entity-check and normalize a parsed statement.
///
/// SELECT/DELETE/EXPLAIN are delegated to their dedicated preparers; the
/// describe/show-indexes/show-columns statements only get an entity check;
/// show-entities passes through untouched.
#[inline(never)]
fn prepare_statement(
    statement: SqlStatement,
    expected_entity: &'static str,
) -> Result<SqlStatement, SqlLoweringError> {
    let prepared = match statement {
        SqlStatement::Select(select) => {
            SqlStatement::Select(prepare_select_statement(select, expected_entity)?)
        }
        SqlStatement::Delete(delete) => {
            SqlStatement::Delete(prepare_delete_statement(delete, expected_entity)?)
        }
        SqlStatement::Explain(explain) => {
            SqlStatement::Explain(prepare_explain_statement(explain, expected_entity)?)
        }
        SqlStatement::Describe(describe) => {
            ensure_entity_matches_expected(describe.entity.as_str(), expected_entity)?;
            SqlStatement::Describe(describe)
        }
        SqlStatement::ShowIndexes(show_indexes) => {
            ensure_entity_matches_expected(show_indexes.entity.as_str(), expected_entity)?;
            SqlStatement::ShowIndexes(show_indexes)
        }
        SqlStatement::ShowColumns(show_columns) => {
            ensure_entity_matches_expected(show_columns.entity.as_str(), expected_entity)?;
            SqlStatement::ShowColumns(show_columns)
        }
        SqlStatement::ShowEntities(show_entities) => SqlStatement::ShowEntities(show_entities),
    };
    Ok(prepared)
}
/// Prepare the SELECT or DELETE wrapped by an EXPLAIN, keeping its mode.
fn prepare_explain_statement(
    statement: SqlExplainStatement,
    expected_entity: &'static str,
) -> Result<SqlExplainStatement, SqlLoweringError> {
    let SqlExplainStatement {
        mode,
        statement: target,
    } = statement;
    let statement = match target {
        SqlExplainTarget::Select(select) => {
            SqlExplainTarget::Select(prepare_select_statement(select, expected_entity)?)
        }
        SqlExplainTarget::Delete(delete) => {
            SqlExplainTarget::Delete(prepare_delete_statement(delete, expected_entity)?)
        }
    };
    Ok(SqlExplainStatement { mode, statement })
}
/// Reject a SELECT aimed at another entity, then normalize its identifiers
/// to the expected entity's scope.
fn prepare_select_statement(
    statement: SqlSelectStatement,
    expected_entity: &'static str,
) -> Result<SqlSelectStatement, SqlLoweringError> {
    let sql_entity = statement.entity.as_str();
    ensure_entity_matches_expected(sql_entity, expected_entity)?;
    let normalized = normalize_select_statement_to_expected_entity(statement, expected_entity);
    Ok(normalized)
}
fn normalize_select_statement_to_expected_entity(
mut statement: SqlSelectStatement,
expected_entity: &'static str,
) -> SqlSelectStatement {
let entity_scope = sql_entity_scope_candidates(statement.entity.as_str(), expected_entity);
statement.projection =
normalize_projection_identifiers(statement.projection, entity_scope.as_slice());
statement.group_by = normalize_identifier_list(statement.group_by, entity_scope.as_slice());
statement.predicate = statement
.predicate
.map(|predicate| adapt_predicate_identifiers_to_scope(predicate, entity_scope.as_slice()));
statement.order_by = normalize_order_terms(statement.order_by, entity_scope.as_slice());
statement.having = normalize_having_clauses(statement.having, entity_scope.as_slice());
statement
}
fn prepare_delete_statement(
mut statement: SqlDeleteStatement,
expected_entity: &'static str,
) -> Result<SqlDeleteStatement, SqlLoweringError> {
ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
let entity_scope = sql_entity_scope_candidates(statement.entity.as_str(), expected_entity);
statement.predicate = statement
.predicate
.map(|predicate| adapt_predicate_identifiers_to_scope(predicate, entity_scope.as_slice()));
statement.order_by = normalize_order_terms(statement.order_by, entity_scope.as_slice());
Ok(statement)
}
/// Lower a prepared statement into its command form.
///
/// EXPLAIN is delegated to `lower_explain_prepared` because it may fall
/// back to global-aggregate lowering. Metadata statements carry no payload.
#[inline(never)]
fn lower_prepared_statement(
    statement: SqlStatement,
    primary_key_field: &str,
) -> Result<LoweredSqlCommand, SqlLoweringError> {
    let inner = match statement {
        SqlStatement::Explain(explain) => {
            return lower_explain_prepared(explain, primary_key_field);
        }
        SqlStatement::Select(select) => LoweredSqlCommandInner::Query(LoweredSqlQuery::Select(
            lower_select_shape(select, primary_key_field)?,
        )),
        SqlStatement::Delete(delete) => {
            LoweredSqlCommandInner::Query(LoweredSqlQuery::Delete(lower_delete_shape(delete)))
        }
        SqlStatement::Describe(_) => LoweredSqlCommandInner::DescribeEntity,
        SqlStatement::ShowIndexes(_) => LoweredSqlCommandInner::ShowIndexesEntity,
        SqlStatement::ShowColumns(_) => LoweredSqlCommandInner::ShowColumnsEntity,
        SqlStatement::ShowEntities(_) => LoweredSqlCommandInner::ShowEntities,
    };
    Ok(LoweredSqlCommand(inner))
}
fn lower_explain_prepared(
statement: SqlExplainStatement,
primary_key_field: &str,
) -> Result<LoweredSqlCommand, SqlLoweringError> {
let mode = statement.mode;
match statement.statement {
SqlExplainTarget::Select(select_statement) => {
lower_explain_select_prepared(select_statement, mode, primary_key_field)
}
SqlExplainTarget::Delete(delete_statement) => {
Ok(LoweredSqlCommand(LoweredSqlCommandInner::Explain {
mode,
query: LoweredSqlQuery::Delete(lower_delete_shape(delete_statement)),
}))
}
}
}
/// Lower the SELECT under an EXPLAIN.
///
/// First tries ordinary SELECT lowering. If that rejects the projection
/// (`UnsupportedSelectProjection`), the SELECT is retried as a global
/// aggregate. Any other error is returned as-is.
fn lower_explain_select_prepared(
    statement: SqlSelectStatement,
    mode: SqlExplainMode,
    primary_key_field: &str,
) -> Result<LoweredSqlCommand, SqlLoweringError> {
    // The first attempt consumes the statement, so keep a copy for the fallback.
    let fallback = statement.clone();
    let inner = match lower_select_shape(statement, primary_key_field) {
        Ok(select_shape) => LoweredSqlCommandInner::Explain {
            mode,
            query: LoweredSqlQuery::Select(select_shape),
        },
        Err(SqlLoweringError::UnsupportedSelectProjection) => {
            let command = lower_global_aggregate_select_shape(fallback)?;
            LoweredSqlCommandInner::ExplainGlobalAggregate { mode, command }
        }
        Err(other) => return Err(other),
    };
    Ok(LoweredSqlCommand(inner))
}
/// Lower a SELECT whose projection is a single global aggregate.
///
/// DISTINCT, GROUP BY, and HAVING are rejected, checked in that order. The
/// remaining WHERE/ORDER BY/LIMIT/OFFSET become the base query shape.
///
/// # Errors
/// `UnsupportedSelectDistinct`, `UnsupportedSelectGroupBy`,
/// `UnsupportedSelectHaving`, or the projection error from
/// `lower_global_aggregate_terminal`.
fn lower_global_aggregate_select_shape(
    statement: SqlSelectStatement,
) -> Result<LoweredSqlGlobalAggregateCommand, SqlLoweringError> {
    let SqlSelectStatement {
        entity: _,
        projection,
        predicate,
        distinct,
        group_by,
        having,
        order_by,
        limit,
        offset,
    } = statement;
    let rejection = if distinct {
        Some(SqlLoweringError::unsupported_select_distinct())
    } else if !group_by.is_empty() {
        Some(SqlLoweringError::unsupported_select_group_by())
    } else if !having.is_empty() {
        Some(SqlLoweringError::unsupported_select_having())
    } else {
        None
    };
    if let Some(err) = rejection {
        return Err(err);
    }
    let terminal = lower_global_aggregate_terminal(projection)?;
    let query = LoweredBaseQueryShape {
        predicate,
        order_by,
        limit,
        offset,
    };
    Ok(LoweredSqlGlobalAggregateCommand { query, terminal })
}
/// A HAVING comparison whose left-hand side has been resolved.
#[derive(Clone, Debug)]
enum ResolvedHavingClause {
    /// Comparison against a GROUP BY key, by (scope-normalized) field name.
    GroupField {
        field: String,
        op: crate::db::predicate::CompareOp,
        value: crate::value::Value,
    },
    /// Comparison against a projected aggregate, identified by its index in
    /// `LoweredSelectShape::grouped_projection_aggregates`.
    Aggregate {
        aggregate_index: usize,
        op: crate::db::predicate::CompareOp,
        value: crate::value::Value,
    },
}
/// Entity-agnostic lowering of a SELECT statement.
#[derive(Clone, Debug)]
pub(crate) struct LoweredSelectShape {
    /// Explicit field list for non-grouped SELECTs; `None` means no explicit
    /// list (always `None` for grouped SELECTs).
    scalar_projection_fields: Option<Vec<String>>,
    /// Projected aggregates for grouped SELECTs; empty when ungrouped.
    grouped_projection_aggregates: Vec<SqlAggregateCall>,
    /// GROUP BY keys; empty means the SELECT is not grouped.
    group_by_fields: Vec<String>,
    distinct: bool,
    /// Resolved HAVING clauses (only possible for grouped SELECTs).
    having: Vec<ResolvedHavingClause>,
    /// WHERE predicate, if any.
    predicate: Option<Predicate>,
    order_by: Vec<crate::db::sql::parser::SqlOrderTerm>,
    limit: Option<u32>,
    offset: Option<u32>,
}
impl LoweredSelectShape {
    /// True when the SELECT has at least one GROUP BY key.
    const fn has_grouping(&self) -> bool {
        !self.group_by_fields.is_empty()
    }
}
/// The clause tail shared by DELETE, global-aggregate SELECT, and the final
/// stage of SELECT binding: WHERE, ORDER BY, LIMIT, OFFSET.
///
/// DELETE lowering always sets `offset` to `None`.
#[derive(Clone, Debug)]
pub(crate) struct LoweredBaseQueryShape {
    predicate: Option<Predicate>,
    order_by: Vec<SqlOrderTerm>,
    limit: Option<u32>,
    offset: Option<u32>,
}
/// Lower a prepared SELECT into an entity-agnostic `LoweredSelectShape`.
///
/// Without GROUP BY, the projection is lowered by
/// `lower_scalar_projection_fields` and any HAVING clause is rejected.
/// With GROUP BY, DISTINCT is rejected, the projection must list the group
/// keys followed by aggregates, and HAVING clauses are resolved against
/// those aggregates.
///
/// # Errors
/// The first unsupported-shape error encountered. In both branches the
/// projection is checked before HAVING, so a bad non-grouped projection
/// reports `UnsupportedSelectProjection`. EXPLAIN lowering relies on that
/// variant to fall back to global-aggregate lowering.
#[inline(never)]
fn lower_select_shape(
    statement: SqlSelectStatement,
    primary_key_field: &str,
) -> Result<LoweredSelectShape, SqlLoweringError> {
    let SqlSelectStatement {
        projection,
        predicate,
        distinct,
        group_by,
        having,
        order_by,
        limit,
        offset,
        entity: _,
    } = statement;
    let (scalar_projection_fields, grouped_projection_aggregates, having) =
        if group_by.is_empty() {
            let scalar_projection_fields =
                lower_scalar_projection_fields(projection, distinct, primary_key_field)?;
            // HAVING requires GROUP BY; this mirrors the rejection inside
            // `lower_having_clauses`. Doing it here means the projection no
            // longer has to be cloned just so HAVING can re-inspect it.
            if !having.is_empty() {
                return Err(SqlLoweringError::unsupported_select_having());
            }
            (scalar_projection_fields, Vec::new(), Vec::new())
        } else {
            if distinct {
                return Err(SqlLoweringError::unsupported_select_distinct());
            }
            // The grouped path only borrows the projection, so HAVING can
            // validate against the same value without a copy.
            let grouped_projection_aggregates =
                grouped_projection_aggregate_calls(&projection, group_by.as_slice())?;
            let having = lower_having_clauses(
                having,
                &projection,
                group_by.as_slice(),
                grouped_projection_aggregates.as_slice(),
            )?;
            (None, grouped_projection_aggregates, having)
        };
    Ok(LoweredSelectShape {
        scalar_projection_fields,
        grouped_projection_aggregates,
        group_by_fields: group_by,
        distinct,
        having,
        predicate,
        order_by,
        limit,
        offset,
    })
}
/// Lower a non-grouped SELECT projection into explicit field names.
///
/// Returns `Ok(None)` for the non-`Items` projection form (presumably
/// `SELECT *`; that variant is not visible here), regardless of `distinct`.
/// For an item list, every item must be a plain field. DISTINCT field lists
/// are further checked by `validate_scalar_distinct_projection`.
///
/// # Errors
/// `UnsupportedSelectProjection` if any item is an aggregate or text
/// function; `UnsupportedSelectDistinct` from the DISTINCT validation.
fn lower_scalar_projection_fields(
    projection: SqlProjection,
    distinct: bool,
    primary_key_field: &str,
) -> Result<Option<Vec<String>>, SqlLoweringError> {
    let SqlProjection::Items(items) = projection else {
        // No explicit field list, so there is nothing for DISTINCT to validate.
        return Ok(None);
    };
    // A single pass rejects both aggregates and text functions with the same
    // error, so a separate aggregate pre-scan is unnecessary.
    let fields = items
        .into_iter()
        .map(|item| match item {
            SqlSelectItem::Field(field) => Ok(field),
            SqlSelectItem::Aggregate(_) | SqlSelectItem::TextFunction(_) => {
                Err(SqlLoweringError::unsupported_select_projection())
            }
        })
        .collect::<Result<Vec<_>, _>>()?;
    validate_scalar_distinct_projection(distinct, fields.as_slice(), primary_key_field)?;
    Ok(Some(fields))
}
/// Require that a non-empty SELECT DISTINCT field list includes the primary
/// key field.
///
/// Non-DISTINCT projections and empty field lists are always accepted.
fn validate_scalar_distinct_projection(
    distinct: bool,
    projection_fields: &[String],
    primary_key_field: &str,
) -> Result<(), SqlLoweringError> {
    let needs_primary_key = distinct && !projection_fields.is_empty();
    let includes_primary_key = || projection_fields.iter().any(|name| name == primary_key_field);
    if needs_primary_key && !includes_primary_key() {
        return Err(SqlLoweringError::unsupported_select_distinct());
    }
    Ok(())
}
fn lower_having_clauses(
having_clauses: Vec<SqlHavingClause>,
projection: &SqlProjection,
group_by_fields: &[String],
grouped_projection_aggregates: &[SqlAggregateCall],
) -> Result<Vec<ResolvedHavingClause>, SqlLoweringError> {
if having_clauses.is_empty() {
return Ok(Vec::new());
}
if group_by_fields.is_empty() {
return Err(SqlLoweringError::unsupported_select_having());
}
let projection_aggregates = grouped_projection_aggregate_calls(projection, group_by_fields)
.map_err(|_| SqlLoweringError::unsupported_select_having())?;
if projection_aggregates.as_slice() != grouped_projection_aggregates {
return Err(SqlLoweringError::unsupported_select_having());
}
let mut lowered = Vec::with_capacity(having_clauses.len());
for clause in having_clauses {
match clause.symbol {
SqlHavingSymbol::Field(field) => lowered.push(ResolvedHavingClause::GroupField {
field,
op: clause.op,
value: clause.value,
}),
SqlHavingSymbol::Aggregate(aggregate) => {
let aggregate_index =
resolve_having_aggregate_index(&aggregate, grouped_projection_aggregates)?;
lowered.push(ResolvedHavingClause::Aggregate {
aggregate_index,
op: clause.op,
value: clause.value,
});
}
}
}
Ok(lowered)
}
/// Recursively canonicalize comparison literals in a SQL predicate against
/// `model`'s field kinds.
///
/// Boolean connectives are rebuilt around their canonicalized children. Only
/// `Compare` leaves are rewritten (see `canonicalize_sql_compare_for_model`);
/// every other leaf is returned unchanged.
fn canonicalize_sql_predicate_for_model(
    model: &'static EntityModel,
    predicate: Predicate,
) -> Predicate {
    let canonicalize = |child: Predicate| canonicalize_sql_predicate_for_model(model, child);
    match predicate {
        Predicate::Compare(mut compare) => {
            canonicalize_sql_compare_for_model(model, &mut compare);
            Predicate::Compare(compare)
        }
        Predicate::And(children) => {
            Predicate::And(children.into_iter().map(canonicalize).collect())
        }
        Predicate::Or(children) => Predicate::Or(children.into_iter().map(canonicalize).collect()),
        Predicate::Not(inner) => Predicate::Not(Box::new(canonicalize(*inner))),
        leaf @ (Predicate::True
        | Predicate::False
        | Predicate::IsNull { .. }
        | Predicate::IsNotNull { .. }
        | Predicate::IsMissing { .. }
        | Predicate::IsEmpty { .. }
        | Predicate::IsNotEmpty { .. }
        | Predicate::TextContains { .. }
        | Predicate::TextContainsCi { .. }) => leaf,
    }
}
fn model_field_kind(model: &'static EntityModel, field: &str) -> Option<FieldKind> {
model
.fields()
.iter()
.find(|candidate| candidate.name() == field)
.map(crate::model::field::FieldModel::kind)
}
fn canonicalize_sql_compare_for_model(
model: &'static EntityModel,
cmp: &mut crate::db::predicate::ComparePredicate,
) {
if cmp.coercion.id != CoercionId::Strict {
return;
}
let Some(field_kind) = model_field_kind(model, &cmp.field) else {
return;
};
match cmp.op {
CompareOp::Eq | CompareOp::Ne => {
if let Some(value) =
canonicalize_strict_sql_numeric_value_for_kind(&field_kind, &cmp.value)
{
cmp.value = value;
}
}
CompareOp::In | CompareOp::NotIn => {
let Value::List(items) = &cmp.value else {
return;
};
let items = items
.iter()
.map(|item| {
canonicalize_strict_sql_numeric_value_for_kind(&field_kind, item)
.unwrap_or_else(|| item.clone())
})
.collect();
cmp.value = Value::List(items);
}
CompareOp::Lt
| CompareOp::Lte
| CompareOp::Gt
| CompareOp::Gte
| CompareOp::Contains
| CompareOp::StartsWith
| CompareOp::EndsWith => {}
}
}
/// Convert an integer literal to the integer variant expected by `kind`.
///
/// `Int` fields accept `Value::Int`, or `Value::Uint` when it fits in `i64`.
/// `Uint` fields accept `Value::Uint`, or `Value::Int` when it is
/// non-negative. Relations recurse into their key kind. Any other kind/value
/// combination, or an out-of-range conversion, yields `None` (the caller
/// keeps the original literal).
fn canonicalize_strict_sql_numeric_value_for_kind(
    kind: &FieldKind,
    value: &Value,
) -> Option<Value> {
    match (kind, value) {
        (FieldKind::Relation { key_kind, .. }, _) => {
            canonicalize_strict_sql_numeric_value_for_kind(key_kind, value)
        }
        (FieldKind::Int, Value::Int(signed)) => Some(Value::Int(*signed)),
        (FieldKind::Int, Value::Uint(unsigned)) => i64::try_from(*unsigned).ok().map(Value::Int),
        (FieldKind::Uint, Value::Int(signed)) => u64::try_from(*signed).ok().map(Value::Uint),
        (FieldKind::Uint, Value::Uint(unsigned)) => Some(Value::Uint(*unsigned)),
        // Integer fields compared against a non-integer literal.
        (FieldKind::Int | FieldKind::Uint, _) => None,
        // Every other field kind is never rewritten.
        (
            FieldKind::Account
            | FieldKind::Blob
            | FieldKind::Bool
            | FieldKind::Date
            | FieldKind::Decimal { .. }
            | FieldKind::Duration
            | FieldKind::Enum { .. }
            | FieldKind::Float32
            | FieldKind::Float64
            | FieldKind::Int128
            | FieldKind::IntBig
            | FieldKind::List(_)
            | FieldKind::Map { .. }
            | FieldKind::Principal
            | FieldKind::Set(_)
            | FieldKind::Structured { .. }
            | FieldKind::Subaccount
            | FieldKind::Text
            | FieldKind::Timestamp
            | FieldKind::Uint128
            | FieldKind::UintBig
            | FieldKind::Ulid
            | FieldKind::Unit,
            _,
        ) => None,
    }
}
/// Apply a lowered SELECT shape onto a `StructuralQuery` builder.
///
/// Builder calls are issued in a fixed sequence: GROUP BY keys, DISTINCT,
/// explicit field projection, grouped aggregates, HAVING clauses, and finally
/// the WHERE/ORDER BY/LIMIT/OFFSET tail via `apply_lowered_base_query_shape`.
/// Literals in HAVING group-key comparisons and in the WHERE predicate are
/// first passed through the numeric canonicalization helpers for `model`.
///
/// # Errors
/// Propagates errors from the builder's `group_by`, `having_group`, and
/// `having_aggregate` calls and from `lower_aggregate_call`.
#[inline(never)]
pub(in crate::db) fn apply_lowered_select_shape(
    mut query: StructuralQuery,
    lowered: LoweredSelectShape,
) -> Result<StructuralQuery, SqlLoweringError> {
    let LoweredSelectShape {
        scalar_projection_fields,
        grouped_projection_aggregates,
        group_by_fields,
        distinct,
        having,
        predicate,
        order_by,
        limit,
        offset,
    } = lowered;
    let model = query.model();
    for field in group_by_fields {
        query = query.group_by(field)?;
    }
    if distinct {
        query = query.distinct();
    }
    // `None` means no explicit field list, so the builder's default projection is kept.
    if let Some(fields) = scalar_projection_fields {
        query = query.select_fields(fields);
    }
    // Aggregates are registered in projection order; HAVING's `aggregate_index`
    // values were resolved against this same order.
    for aggregate in grouped_projection_aggregates {
        query = query.aggregate(lower_aggregate_call(aggregate)?);
    }
    for clause in having {
        match clause {
            ResolvedHavingClause::GroupField { field, op, value } => {
                // NOTE(review): unlike the WHERE path, this canonicalization runs
                // for every comparison op and without a coercion check; confirm
                // that is intended.
                let value = model_field_kind(model, &field)
                    .and_then(|field_kind| {
                        canonicalize_strict_sql_numeric_value_for_kind(&field_kind, &value)
                    })
                    .unwrap_or(value);
                query = query.having_group(field, op, value)?;
            }
            ResolvedHavingClause::Aggregate {
                aggregate_index,
                op,
                value,
            } => {
                query = query.having_aggregate(aggregate_index, op, value)?;
            }
        }
    }
    Ok(apply_lowered_base_query_shape(
        query,
        LoweredBaseQueryShape {
            predicate: predicate
                .map(|predicate| canonicalize_sql_predicate_for_model(model, predicate)),
            order_by,
            limit,
            offset,
        },
    ))
}
fn apply_lowered_base_query_shape(
mut query: StructuralQuery,
lowered: LoweredBaseQueryShape,
) -> StructuralQuery {
if let Some(predicate) = lowered.predicate {
query = query.filter(predicate);
}
query = apply_order_terms_structural(query, lowered.order_by);
if let Some(limit) = lowered.limit {
query = query.limit(limit);
}
if let Some(offset) = lowered.offset {
query = query.offset(offset);
}
query
}
/// Bind a lowered SELECT or DELETE to `model` as an untyped `StructuralQuery`.
///
/// # Errors
/// Only the SELECT path can fail (see `apply_lowered_select_shape`).
pub(in crate::db) fn bind_lowered_sql_query_structural(
    model: &'static crate::model::entity::EntityModel,
    lowered: LoweredSqlQuery,
    consistency: MissingRowPolicy,
) -> Result<StructuralQuery, SqlLoweringError> {
    match lowered {
        LoweredSqlQuery::Delete(delete_shape) => Ok(bind_lowered_sql_delete_query_structural(
            model,
            delete_shape,
            consistency,
        )),
        LoweredSqlQuery::Select(select_shape) => {
            let base = StructuralQuery::new(model, consistency);
            apply_lowered_select_shape(base, select_shape)
        }
    }
}
/// Build a delete-mode `StructuralQuery` for `model` and apply the lowered
/// DELETE's WHERE/ORDER BY/LIMIT tail.
pub(in crate::db) fn bind_lowered_sql_delete_query_structural(
    model: &'static crate::model::entity::EntityModel,
    delete: LoweredBaseQueryShape,
    consistency: MissingRowPolicy,
) -> StructuralQuery {
    let delete_base = StructuralQuery::new(model, consistency).delete();
    apply_lowered_base_query_shape(delete_base, delete)
}
/// Bind a lowered query to `E::MODEL` and wrap it as a typed `Query<E>`.
pub(in crate::db) fn bind_lowered_sql_query<E: EntityKind>(
    lowered: LoweredSqlQuery,
    consistency: MissingRowPolicy,
) -> Result<Query<E>, SqlLoweringError> {
    bind_lowered_sql_query_structural(E::MODEL, lowered, consistency).map(Query::<E>::from_inner)
}
fn bind_lowered_sql_global_aggregate_command<E: EntityKind>(
lowered: LoweredSqlGlobalAggregateCommand,
consistency: MissingRowPolicy,
) -> Result<SqlGlobalAggregateCommand<E>, SqlLoweringError> {
let terminal = bind_lowered_sql_global_aggregate_terminal::<E>(lowered.terminal)?;
Ok(SqlGlobalAggregateCommand {
query: Query::from_inner(apply_lowered_base_query_shape(
StructuralQuery::new(E::MODEL, consistency),
lowered.query,
)),
terminal,
})
}
fn bind_lowered_sql_global_aggregate_command_structural(
model: &'static crate::model::entity::EntityModel,
lowered: LoweredSqlGlobalAggregateCommand,
consistency: MissingRowPolicy,
) -> SqlGlobalAggregateCommandCore {
SqlGlobalAggregateCommandCore {
query: apply_lowered_base_query_shape(
StructuralQuery::new(model, consistency),
lowered.query,
),
terminal: lowered.terminal,
}
}
/// Lower a projection consisting of exactly one aggregate item into a global
/// aggregate terminal.
///
/// # Errors
/// `UnsupportedSelectProjection` for a non-item projection, a projection
/// with zero or several items, a non-aggregate item, or an unsupported
/// aggregate shape.
fn lower_global_aggregate_terminal(
    projection: SqlProjection,
) -> Result<SqlGlobalAggregateTerminal, SqlLoweringError> {
    let SqlProjection::Items(items) = projection else {
        return Err(SqlLoweringError::unsupported_select_projection());
    };
    let mut remaining = items.into_iter();
    // Exactly one item, and it must be an aggregate.
    let (Some(SqlSelectItem::Aggregate(aggregate)), None) = (remaining.next(), remaining.next())
    else {
        return Err(SqlLoweringError::unsupported_select_projection());
    };
    let terminal = match lower_sql_aggregate_shape(aggregate)? {
        LoweredSqlAggregateShape::CountRows => SqlGlobalAggregateTerminal::CountRows,
        LoweredSqlAggregateShape::CountField(field) => SqlGlobalAggregateTerminal::CountField(field),
        LoweredSqlAggregateShape::FieldTarget { kind, field } => match kind {
            SqlAggregateKind::Sum => SqlGlobalAggregateTerminal::SumField(field),
            SqlAggregateKind::Avg => SqlGlobalAggregateTerminal::AvgField(field),
            SqlAggregateKind::Min => SqlGlobalAggregateTerminal::MinField(field),
            SqlAggregateKind::Max => SqlGlobalAggregateTerminal::MaxField(field),
            SqlAggregateKind::Count => {
                return Err(SqlLoweringError::unsupported_select_projection());
            }
        },
    };
    Ok(terminal)
}
/// Normalize a parsed aggregate call.
///
/// COUNT with no field counts rows and COUNT with a field counts that field.
/// SUM/AVG/MIN/MAX require a field.
///
/// # Errors
/// `UnsupportedSelectProjection` for SUM/AVG/MIN/MAX without a field.
fn lower_sql_aggregate_shape(
    call: SqlAggregateCall,
) -> Result<LoweredSqlAggregateShape, SqlLoweringError> {
    match (call.kind, call.field) {
        (SqlAggregateKind::Count, maybe_field) => Ok(match maybe_field {
            None => LoweredSqlAggregateShape::CountRows,
            Some(field) => LoweredSqlAggregateShape::CountField(field),
        }),
        (
            kind @ (SqlAggregateKind::Sum
            | SqlAggregateKind::Avg
            | SqlAggregateKind::Min
            | SqlAggregateKind::Max),
            Some(field),
        ) => Ok(LoweredSqlAggregateShape::FieldTarget { kind, field }),
        (
            SqlAggregateKind::Sum
            | SqlAggregateKind::Avg
            | SqlAggregateKind::Min
            | SqlAggregateKind::Max,
            None,
        ) => Err(SqlLoweringError::unsupported_select_projection()),
    }
}
/// Extract the aggregate calls from a grouped SELECT projection.
///
/// The projection must be an item list made of the GROUP BY keys (in
/// GROUP BY order) followed by at least one aggregate. A field after an
/// aggregate, or any text function, is rejected.
///
/// # Errors
/// `UnsupportedSelectGroupBy` for every violation, including an empty
/// `group_by_fields`.
fn grouped_projection_aggregate_calls(
    projection: &SqlProjection,
    group_by_fields: &[String],
) -> Result<Vec<SqlAggregateCall>, SqlLoweringError> {
    if group_by_fields.is_empty() {
        return Err(SqlLoweringError::unsupported_select_group_by());
    }
    let SqlProjection::Items(items) = projection else {
        return Err(SqlLoweringError::unsupported_select_group_by());
    };
    let mut projected_keys = Vec::<String>::new();
    let mut aggregates = Vec::<SqlAggregateCall>::new();
    for item in items {
        match item {
            // Keys are only allowed before the first aggregate.
            SqlSelectItem::Field(field) if aggregates.is_empty() => {
                projected_keys.push(field.clone());
            }
            SqlSelectItem::Aggregate(aggregate) => aggregates.push(aggregate.clone()),
            SqlSelectItem::Field(_) | SqlSelectItem::TextFunction(_) => {
                return Err(SqlLoweringError::unsupported_select_group_by());
            }
        }
    }
    let keys_match = projected_keys.as_slice() == group_by_fields;
    if aggregates.is_empty() || !keys_match {
        return Err(SqlLoweringError::unsupported_select_group_by());
    }
    Ok(aggregates)
}
/// Convert a parsed aggregate call into the query builder's aggregate
/// expression (`count`, `count_by`, `sum`, `avg`, `min_by`, `max_by`).
///
/// # Errors
/// `UnsupportedSelectProjection` for unsupported aggregate shapes.
fn lower_aggregate_call(
    call: SqlAggregateCall,
) -> Result<crate::db::query::builder::AggregateExpr, SqlLoweringError> {
    let expr = match lower_sql_aggregate_shape(call)? {
        LoweredSqlAggregateShape::CountRows => count(),
        LoweredSqlAggregateShape::CountField(field) => count_by(field),
        LoweredSqlAggregateShape::FieldTarget { kind, field } => match kind {
            SqlAggregateKind::Sum => sum(field),
            SqlAggregateKind::Avg => avg(field),
            SqlAggregateKind::Min => min_by(field),
            SqlAggregateKind::Max => max_by(field),
            SqlAggregateKind::Count => {
                return Err(SqlLoweringError::unsupported_select_projection());
            }
        },
    };
    Ok(expr)
}
/// Find the position of `target` among the grouped projection's aggregates.
///
/// The match must be unique. HAVING references an aggregate by equality, so
/// zero matches or more than one match are both rejected.
///
/// # Errors
///
/// Returns `SqlLoweringError::unsupported_select_having()` if `target` does not
/// appear in `grouped_projection_aggregates`, or appears more than once.
fn resolve_having_aggregate_index(
    target: &SqlAggregateCall,
    grouped_projection_aggregates: &[SqlAggregateCall],
) -> Result<usize, SqlLoweringError> {
    let Some(first) = grouped_projection_aggregates
        .iter()
        .position(|candidate| candidate == target)
    else {
        return Err(SqlLoweringError::unsupported_select_having());
    };

    // `first` is a valid index, so `first + 1` is at most `len()` and the slice is in bounds.
    let duplicated = grouped_projection_aggregates[first + 1..]
        .iter()
        .any(|candidate| candidate == target);
    if duplicated {
        return Err(SqlLoweringError::unsupported_select_having());
    }

    Ok(first)
}
/// Convert a parsed `DELETE` statement into the shared base query shape.
///
/// The predicate, ORDER BY terms and LIMIT carry over unchanged. DELETE has no
/// OFFSET, so `offset` is always `None`. The entity name is dropped here. The
/// destructuring is exhaustive, so adding a field to `SqlDeleteStatement` is a
/// compile error until it is handled.
fn lower_delete_shape(
    SqlDeleteStatement {
        predicate,
        order_by,
        limit,
        entity: _,
    }: SqlDeleteStatement,
) -> LoweredBaseQueryShape {
    LoweredBaseQueryShape {
        offset: None,
        limit,
        order_by,
        predicate,
    }
}
fn apply_order_terms_structural(
mut query: StructuralQuery,
order_by: Vec<crate::db::sql::parser::SqlOrderTerm>,
) -> StructuralQuery {
for term in order_by {
query = match term.direction {
SqlOrderDirection::Asc => query.order_by(term.field),
SqlOrderDirection::Desc => query.order_by_desc(term.field),
};
}
query
}
/// Normalize the identifiers in every HAVING clause against `entity_scope`.
///
/// Only each clause's symbol is rewritten, via `normalize_having_symbol`. The
/// operator and comparison value are kept as they are. Clause order is preserved.
fn normalize_having_clauses(
    clauses: Vec<SqlHavingClause>,
    entity_scope: &[String],
) -> Vec<SqlHavingClause> {
    let mut normalized = Vec::with_capacity(clauses.len());
    for SqlHavingClause { symbol, op, value } in clauses {
        normalized.push(SqlHavingClause {
            symbol: normalize_having_symbol(symbol, entity_scope),
            op,
            value,
        });
    }
    normalized
}
/// Normalize the identifier(s) a single HAVING symbol refers to.
///
/// A field symbol is passed through `normalize_identifier_to_scope`. An
/// aggregate symbol has its target field normalized by
/// `normalize_aggregate_call_identifiers`.
fn normalize_having_symbol(symbol: SqlHavingSymbol, entity_scope: &[String]) -> SqlHavingSymbol {
    match symbol {
        SqlHavingSymbol::Aggregate(call) => {
            let scoped_call = normalize_aggregate_call_identifiers(call, entity_scope);
            SqlHavingSymbol::Aggregate(scoped_call)
        }
        SqlHavingSymbol::Field(name) => {
            let scoped_name = normalize_identifier_to_scope(name, entity_scope);
            SqlHavingSymbol::Field(scoped_name)
        }
    }
}
/// Normalize the optional target field of an aggregate call against `entity_scope`.
///
/// The aggregate kind is kept. A call with no field (e.g. a row count) is
/// returned with `field` still `None`.
fn normalize_aggregate_call_identifiers(
    aggregate: SqlAggregateCall,
    entity_scope: &[String],
) -> SqlAggregateCall {
    let SqlAggregateCall { kind, field } = aggregate;
    let field = field.map(|name| normalize_identifier_to_scope(name, entity_scope));
    SqlAggregateCall { kind, field }
}
fn sql_entity_scope_candidates(sql_entity: &str, expected_entity: &'static str) -> Vec<String> {
let mut out = Vec::new();
out.push(sql_entity.to_string());
out.push(expected_entity.to_string());
if let Some(last) = identifier_last_segment(sql_entity) {
out.push(last.to_string());
}
if let Some(last) = identifier_last_segment(expected_entity) {
out.push(last.to_string());
}
out
}
fn normalize_projection_identifiers(
projection: SqlProjection,
entity_scope: &[String],
) -> SqlProjection {
match projection {
SqlProjection::All => SqlProjection::All,
SqlProjection::Items(items) => SqlProjection::Items(
items
.into_iter()
.map(|item| match item {
SqlSelectItem::Field(field) => {
SqlSelectItem::Field(normalize_identifier(field, entity_scope))
}
SqlSelectItem::Aggregate(aggregate) => {
SqlSelectItem::Aggregate(SqlAggregateCall {
kind: aggregate.kind,
field: aggregate
.field
.map(|field| normalize_identifier(field, entity_scope)),
})
}
SqlSelectItem::TextFunction(SqlTextFunctionCall {
function,
field,
literal,
literal2,
literal3,
}) => SqlSelectItem::TextFunction(SqlTextFunctionCall {
function,
field: normalize_identifier(field, entity_scope),
literal,
literal2,
literal3,
}),
})
.collect(),
),
}
}
fn normalize_order_terms(
terms: Vec<crate::db::sql::parser::SqlOrderTerm>,
entity_scope: &[String],
) -> Vec<crate::db::sql::parser::SqlOrderTerm> {
terms
.into_iter()
.map(|term| crate::db::sql::parser::SqlOrderTerm {
field: normalize_order_term_identifier(term.field, entity_scope),
direction: term.direction,
})
.collect()
}
fn normalize_order_term_identifier(identifier: String, entity_scope: &[String]) -> String {
let Some(expression) = ExpressionOrderTerm::parse(identifier.as_str()) else {
return normalize_identifier(identifier, entity_scope);
};
let normalized_field = normalize_identifier(expression.field().to_string(), entity_scope);
expression.canonical_text_with_field(normalized_field.as_str())
}
fn normalize_identifier_list(fields: Vec<String>, entity_scope: &[String]) -> Vec<String> {
fields
.into_iter()
.map(|field| normalize_identifier(field, entity_scope))
.collect()
}
fn adapt_predicate_identifiers_to_scope(
predicate: Predicate,
entity_scope: &[String],
) -> Predicate {
rewrite_field_identifiers(predicate, |field| normalize_identifier(field, entity_scope))
}
/// Normalize one identifier against the entity scope.
///
/// Thin local alias for `normalize_identifier_to_scope`.
fn normalize_identifier(raw: String, scope: &[String]) -> String {
    normalize_identifier_to_scope(raw, scope)
}
/// Check that the entity named in SQL refers to the expected entity.
///
/// The names are compared with `identifiers_tail_match`.
///
/// # Errors
///
/// Returns `SqlLoweringError::entity_mismatch(sql_entity, expected_entity)` when
/// `identifiers_tail_match` reports no match.
fn ensure_entity_matches_expected(
    sql_entity: &str,
    expected_entity: &'static str,
) -> Result<(), SqlLoweringError> {
    if !identifiers_tail_match(sql_entity, expected_entity) {
        return Err(SqlLoweringError::entity_mismatch(
            sql_entity,
            expected_entity,
        ));
    }
    Ok(())
}