1#[cfg(test)]
11mod tests;
12
13use crate::{
14 db::{
15 predicate::{MissingRowPolicy, Predicate},
16 query::{
17 builder::aggregate::{avg, count, count_by, max_by, min_by, sum},
18 intent::{Query, QueryError, StructuralQuery},
19 },
20 sql::identifier::{
21 identifier_last_segment, identifiers_tail_match, normalize_identifier_to_scope,
22 rewrite_field_identifiers,
23 },
24 sql::parser::{
25 SqlAggregateCall, SqlAggregateKind, SqlDeleteStatement, SqlExplainMode,
26 SqlExplainStatement, SqlExplainTarget, SqlHavingClause, SqlHavingSymbol,
27 SqlOrderDirection, SqlOrderTerm, SqlProjection, SqlSelectItem, SqlSelectStatement,
28 SqlStatement, SqlTextFunctionCall,
29 },
30 },
31 traits::EntityKind,
32};
33use thiserror::Error as ThisError;
34
35#[derive(Clone, Debug)]
44pub struct LoweredSqlCommand(LoweredSqlCommandInner);
45
46#[derive(Clone, Debug)]
47enum LoweredSqlCommandInner {
48 Query(LoweredSqlQuery),
49 Explain {
50 mode: SqlExplainMode,
51 query: LoweredSqlQuery,
52 },
53 ExplainGlobalAggregate {
54 mode: SqlExplainMode,
55 command: LoweredSqlGlobalAggregateCommand,
56 },
57 DescribeEntity,
58 ShowIndexesEntity,
59 ShowColumnsEntity,
60 ShowEntities,
61}
62
/// Test-only command bound to entity type `E`.
///
/// This is what the `compile_sql_command*` helpers return. It mirrors
/// [`LoweredSqlCommand`], except that queries are already bound as `Query<E>`.
#[cfg(test)]
#[derive(Debug)]
pub(crate) enum SqlCommand<E: EntityKind> {
    /// A bound SELECT or DELETE query.
    Query(Query<E>),
    /// `EXPLAIN` over a bound query.
    Explain {
        mode: SqlExplainMode,
        query: Query<E>,
    },
    /// `EXPLAIN` over a bound global aggregate command.
    ExplainGlobalAggregate {
        mode: SqlExplainMode,
        command: SqlGlobalAggregateCommand<E>,
    },
    /// `DESCRIBE <entity>`.
    DescribeEntity,
    /// `SHOW INDEXES` for the entity.
    ShowIndexesEntity,
    /// `SHOW COLUMNS` for the entity.
    ShowColumnsEntity,
    /// `SHOW ENTITIES`.
    ShowEntities,
}
87
88impl LoweredSqlCommand {
89 #[must_use]
90 pub(in crate::db) const fn query(&self) -> Option<&LoweredSqlQuery> {
91 match &self.0 {
92 LoweredSqlCommandInner::Query(query) => Some(query),
93 LoweredSqlCommandInner::Explain { .. }
94 | LoweredSqlCommandInner::ExplainGlobalAggregate { .. }
95 | LoweredSqlCommandInner::DescribeEntity
96 | LoweredSqlCommandInner::ShowIndexesEntity
97 | LoweredSqlCommandInner::ShowColumnsEntity
98 | LoweredSqlCommandInner::ShowEntities => None,
99 }
100 }
101}
102
103#[derive(Clone, Debug)]
110pub(crate) enum LoweredSqlQuery {
111 Select(LoweredSelectShape),
112 Delete(LoweredBaseQueryShape),
113}
114
115impl LoweredSqlQuery {
116 pub(crate) const fn has_grouping(&self) -> bool {
118 match self {
119 Self::Select(select) => select.has_grouping(),
120 Self::Delete(_) => false,
121 }
122 }
123}
124
/// Terminal of a global aggregate SELECT.
///
/// A global aggregate SELECT has exactly one aggregate item and no DISTINCT,
/// GROUP BY or HAVING.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum SqlGlobalAggregateTerminal {
    /// `COUNT` without a field argument (presumably `COUNT(*)`).
    CountRows,
    /// `COUNT(field)`.
    CountField(String),
    /// `SUM(field)`.
    SumField(String),
    /// `AVG(field)`.
    AvgField(String),
    /// `MIN(field)`.
    MinField(String),
    /// `MAX(field)`.
    MaxField(String),
}
141
142#[derive(Clone, Debug)]
151pub(crate) struct LoweredSqlGlobalAggregateCommand {
152 query: LoweredBaseQueryShape,
153 terminal: SqlGlobalAggregateTerminal,
154}
155
156enum LoweredSqlAggregateShape {
164 CountRows,
165 CountField(String),
166 FieldTarget {
167 kind: SqlAggregateKind,
168 field: String,
169 },
170}
171
172#[derive(Debug)]
179pub(crate) struct SqlGlobalAggregateCommand<E: EntityKind> {
180 query: Query<E>,
181 terminal: SqlGlobalAggregateTerminal,
182}
183
184impl<E: EntityKind> SqlGlobalAggregateCommand<E> {
185 #[must_use]
187 pub(crate) const fn query(&self) -> &Query<E> {
188 &self.query
189 }
190
191 #[must_use]
193 pub(crate) const fn terminal(&self) -> &SqlGlobalAggregateTerminal {
194 &self.terminal
195 }
196}
197
198#[derive(Debug)]
208pub(crate) struct SqlGlobalAggregateCommandCore {
209 query: StructuralQuery,
210 terminal: SqlGlobalAggregateTerminal,
211}
212
213impl SqlGlobalAggregateCommandCore {
214 #[must_use]
216 pub(in crate::db) const fn query(&self) -> &StructuralQuery {
217 &self.query
218 }
219
220 #[must_use]
222 pub(in crate::db) const fn terminal(&self) -> &SqlGlobalAggregateTerminal {
223 &self.terminal
224 }
225}
226
227#[derive(Debug, ThisError)]
234pub(crate) enum SqlLoweringError {
235 #[error("{0}")]
236 Parse(#[from] crate::db::sql::parser::SqlParseError),
237
238 #[error("{0}")]
239 Query(#[from] QueryError),
240
241 #[error("SQL entity '{sql_entity}' does not match requested entity type '{expected_entity}'")]
242 EntityMismatch {
243 sql_entity: String,
244 expected_entity: &'static str,
245 },
246
247 #[error(
248 "unsupported SQL SELECT projection; supported forms are SELECT *, field lists, or grouped aggregate shapes"
249 )]
250 UnsupportedSelectProjection,
251
252 #[error("unsupported SQL SELECT DISTINCT")]
253 UnsupportedSelectDistinct,
254
255 #[error("unsupported SQL GROUP BY projection shape")]
256 UnsupportedSelectGroupBy,
257
258 #[error("unsupported SQL HAVING shape")]
259 UnsupportedSelectHaving,
260}
261
262impl SqlLoweringError {
263 fn entity_mismatch(sql_entity: impl Into<String>, expected_entity: &'static str) -> Self {
265 Self::EntityMismatch {
266 sql_entity: sql_entity.into(),
267 expected_entity,
268 }
269 }
270
271 const fn unsupported_select_projection() -> Self {
273 Self::UnsupportedSelectProjection
274 }
275
276 const fn unsupported_select_distinct() -> Self {
278 Self::UnsupportedSelectDistinct
279 }
280
281 const fn unsupported_select_group_by() -> Self {
283 Self::UnsupportedSelectGroupBy
284 }
285
286 const fn unsupported_select_having() -> Self {
288 Self::UnsupportedSelectHaving
289 }
290}
291
292#[derive(Clone, Debug)]
303pub(crate) struct PreparedSqlStatement {
304 statement: SqlStatement,
305}
306
/// Coarse classification of a [`LoweredSqlCommand`], computed by
/// `lowered_sql_command_lane`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum LoweredSqlLaneKind {
    /// Plain SELECT/DELETE query.
    Query,
    /// Any EXPLAIN, including EXPLAIN of a global aggregate.
    Explain,
    /// DESCRIBE.
    Describe,
    /// SHOW INDEXES.
    ShowIndexes,
    /// SHOW COLUMNS.
    ShowColumns,
    /// SHOW ENTITIES.
    ShowEntities,
}
316
/// Test helper: parses `sql` and compiles it into a [`SqlCommand`] for `E`.
///
/// # Errors
/// Returns parse errors and any preparation, lowering or binding error.
#[cfg(test)]
pub(crate) fn compile_sql_command<E: EntityKind>(
    sql: &str,
    consistency: MissingRowPolicy,
) -> Result<SqlCommand<E>, SqlLoweringError> {
    compile_sql_command_from_statement::<E>(
        crate::db::sql::parser::parse_sql(sql)?,
        consistency,
    )
}
326
/// Test helper: prepares an already-parsed statement against `E`'s entity
/// name, then compiles it.
///
/// # Errors
/// Returns entity mismatches and any lowering or binding error.
#[cfg(test)]
pub(crate) fn compile_sql_command_from_statement<E: EntityKind>(
    statement: SqlStatement,
    consistency: MissingRowPolicy,
) -> Result<SqlCommand<E>, SqlLoweringError> {
    let expected_entity = E::MODEL.name();
    compile_sql_command_from_prepared_statement::<E>(
        prepare_sql_statement(statement, expected_entity)?,
        consistency,
    )
}
336
/// Test helper: lowers a prepared statement using `E`'s primary key and binds
/// the result to `E`.
///
/// # Errors
/// Returns any lowering or binding error.
#[cfg(test)]
pub(crate) fn compile_sql_command_from_prepared_statement<E: EntityKind>(
    prepared: PreparedSqlStatement,
    consistency: MissingRowPolicy,
) -> Result<SqlCommand<E>, SqlLoweringError> {
    let primary_key_field = E::MODEL.primary_key.name;
    bind_lowered_sql_command::<E>(
        lower_sql_command_from_prepared_statement(prepared, primary_key_field)?,
        consistency,
    )
}
347
348#[inline(never)]
350pub(crate) fn lower_sql_command_from_prepared_statement(
351 prepared: PreparedSqlStatement,
352 primary_key_field: &str,
353) -> Result<LoweredSqlCommand, SqlLoweringError> {
354 lower_prepared_statement(prepared.statement, primary_key_field)
355}
356
357pub(crate) const fn lowered_sql_command_lane(command: &LoweredSqlCommand) -> LoweredSqlLaneKind {
358 match command.0 {
359 LoweredSqlCommandInner::Query(_) => LoweredSqlLaneKind::Query,
360 LoweredSqlCommandInner::Explain { .. }
361 | LoweredSqlCommandInner::ExplainGlobalAggregate { .. } => LoweredSqlLaneKind::Explain,
362 LoweredSqlCommandInner::DescribeEntity => LoweredSqlLaneKind::Describe,
363 LoweredSqlCommandInner::ShowIndexesEntity => LoweredSqlLaneKind::ShowIndexes,
364 LoweredSqlCommandInner::ShowColumnsEntity => LoweredSqlLaneKind::ShowColumns,
365 LoweredSqlCommandInner::ShowEntities => LoweredSqlLaneKind::ShowEntities,
366 }
367}
368
369pub(in crate::db) fn is_sql_global_aggregate_statement(statement: &SqlStatement) -> bool {
372 let SqlStatement::Select(statement) = statement else {
373 return false;
374 };
375
376 is_sql_global_aggregate_select(statement)
377}
378
379fn is_sql_global_aggregate_select(statement: &SqlSelectStatement) -> bool {
382 if statement.distinct || !statement.group_by.is_empty() || !statement.having.is_empty() {
383 return false;
384 }
385
386 lower_global_aggregate_terminal(statement.projection.clone()).is_ok()
387}
388
389#[inline(never)]
391pub(crate) fn render_lowered_sql_explain_plan_or_json(
392 lowered: &LoweredSqlCommand,
393 model: &'static crate::model::entity::EntityModel,
394 consistency: MissingRowPolicy,
395) -> Result<Option<String>, SqlLoweringError> {
396 let LoweredSqlCommandInner::Explain { mode, query } = &lowered.0 else {
397 return Ok(None);
398 };
399
400 let query = bind_lowered_sql_query_structural(model, query.clone(), consistency)?;
401 let rendered = match mode {
402 SqlExplainMode::Plan | SqlExplainMode::Json => {
403 let plan = query.build_plan()?;
404 let explain = plan.explain_with_model(model);
405
406 match mode {
407 SqlExplainMode::Plan => explain.render_text_canonical(),
408 SqlExplainMode::Json => explain.render_json_canonical(),
409 SqlExplainMode::Execution => unreachable!("execution mode handled above"),
410 }
411 }
412 SqlExplainMode::Execution => query.explain_execution_text()?,
413 };
414
415 Ok(Some(rendered))
416}
417
418pub(crate) fn bind_lowered_sql_explain_global_aggregate_structural(
421 lowered: &LoweredSqlCommand,
422 model: &'static crate::model::entity::EntityModel,
423 consistency: MissingRowPolicy,
424) -> Option<(SqlExplainMode, SqlGlobalAggregateCommandCore)> {
425 let LoweredSqlCommandInner::ExplainGlobalAggregate { mode, command } = &lowered.0 else {
426 return None;
427 };
428
429 Some((
430 *mode,
431 bind_lowered_sql_global_aggregate_command_structural(model, command.clone(), consistency),
432 ))
433}
434
/// Test helper: binds a lowered command to entity type `E`.
///
/// # Errors
/// Returns errors raised while binding query shapes.
#[cfg(test)]
pub(crate) fn bind_lowered_sql_command<E: EntityKind>(
    lowered: LoweredSqlCommand,
    consistency: MissingRowPolicy,
) -> Result<SqlCommand<E>, SqlLoweringError> {
    let command = match lowered.0 {
        LoweredSqlCommandInner::Query(query) => {
            SqlCommand::Query(bind_lowered_sql_query::<E>(query, consistency)?)
        }
        LoweredSqlCommandInner::Explain { mode, query } => {
            let query = bind_lowered_sql_query::<E>(query, consistency)?;
            SqlCommand::Explain { mode, query }
        }
        LoweredSqlCommandInner::ExplainGlobalAggregate { mode, command } => {
            let command = bind_lowered_sql_global_aggregate_command::<E>(command, consistency);
            SqlCommand::ExplainGlobalAggregate { mode, command }
        }
        LoweredSqlCommandInner::DescribeEntity => SqlCommand::DescribeEntity,
        LoweredSqlCommandInner::ShowIndexesEntity => SqlCommand::ShowIndexesEntity,
        LoweredSqlCommandInner::ShowColumnsEntity => SqlCommand::ShowColumnsEntity,
        LoweredSqlCommandInner::ShowEntities => SqlCommand::ShowEntities,
    };

    Ok(command)
}
462
463#[inline(never)]
465pub(crate) fn prepare_sql_statement(
466 statement: SqlStatement,
467 expected_entity: &'static str,
468) -> Result<PreparedSqlStatement, SqlLoweringError> {
469 let statement = prepare_statement(statement, expected_entity)?;
470
471 Ok(PreparedSqlStatement { statement })
472}
473
/// Test helper: parses, prepares and compiles `sql` as a global aggregate
/// command for `E`.
///
/// # Errors
/// Returns parse, preparation and global-aggregate lowering errors.
#[cfg(test)]
pub(crate) fn compile_sql_global_aggregate_command<E: EntityKind>(
    sql: &str,
    consistency: MissingRowPolicy,
) -> Result<SqlGlobalAggregateCommand<E>, SqlLoweringError> {
    let parsed = crate::db::sql::parser::parse_sql(sql)?;
    compile_sql_global_aggregate_command_from_prepared::<E>(
        prepare_sql_statement(parsed, E::MODEL.name())?,
        consistency,
    )
}
484
485pub(crate) fn compile_sql_global_aggregate_command_from_prepared<E: EntityKind>(
489 prepared: PreparedSqlStatement,
490 consistency: MissingRowPolicy,
491) -> Result<SqlGlobalAggregateCommand<E>, SqlLoweringError> {
492 let SqlStatement::Select(statement) = prepared.statement else {
493 return Err(SqlLoweringError::unsupported_select_projection());
494 };
495
496 Ok(bind_lowered_sql_global_aggregate_command::<E>(
497 lower_global_aggregate_select_shape(statement)?,
498 consistency,
499 ))
500}
501
502#[inline(never)]
503fn prepare_statement(
504 statement: SqlStatement,
505 expected_entity: &'static str,
506) -> Result<SqlStatement, SqlLoweringError> {
507 match statement {
508 SqlStatement::Select(statement) => Ok(SqlStatement::Select(prepare_select_statement(
509 statement,
510 expected_entity,
511 )?)),
512 SqlStatement::Delete(statement) => Ok(SqlStatement::Delete(prepare_delete_statement(
513 statement,
514 expected_entity,
515 )?)),
516 SqlStatement::Explain(statement) => Ok(SqlStatement::Explain(prepare_explain_statement(
517 statement,
518 expected_entity,
519 )?)),
520 SqlStatement::Describe(statement) => {
521 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
522
523 Ok(SqlStatement::Describe(statement))
524 }
525 SqlStatement::ShowIndexes(statement) => {
526 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
527
528 Ok(SqlStatement::ShowIndexes(statement))
529 }
530 SqlStatement::ShowColumns(statement) => {
531 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
532
533 Ok(SqlStatement::ShowColumns(statement))
534 }
535 SqlStatement::ShowEntities(statement) => Ok(SqlStatement::ShowEntities(statement)),
536 }
537}
538
539fn prepare_explain_statement(
540 statement: SqlExplainStatement,
541 expected_entity: &'static str,
542) -> Result<SqlExplainStatement, SqlLoweringError> {
543 let target = match statement.statement {
544 SqlExplainTarget::Select(select_statement) => {
545 SqlExplainTarget::Select(prepare_select_statement(select_statement, expected_entity)?)
546 }
547 SqlExplainTarget::Delete(delete_statement) => {
548 SqlExplainTarget::Delete(prepare_delete_statement(delete_statement, expected_entity)?)
549 }
550 };
551
552 Ok(SqlExplainStatement {
553 mode: statement.mode,
554 statement: target,
555 })
556}
557
558fn prepare_select_statement(
559 statement: SqlSelectStatement,
560 expected_entity: &'static str,
561) -> Result<SqlSelectStatement, SqlLoweringError> {
562 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
563
564 Ok(normalize_select_statement_to_expected_entity(
565 statement,
566 expected_entity,
567 ))
568}
569
570fn normalize_select_statement_to_expected_entity(
571 mut statement: SqlSelectStatement,
572 expected_entity: &'static str,
573) -> SqlSelectStatement {
574 let entity_scope = sql_entity_scope_candidates(statement.entity.as_str(), expected_entity);
577 statement.projection =
578 normalize_projection_identifiers(statement.projection, entity_scope.as_slice());
579 statement.group_by = normalize_identifier_list(statement.group_by, entity_scope.as_slice());
580 statement.predicate = statement
581 .predicate
582 .map(|predicate| adapt_predicate_identifiers_to_scope(predicate, entity_scope.as_slice()));
583 statement.order_by = normalize_order_terms(statement.order_by, entity_scope.as_slice());
584 statement.having = normalize_having_clauses(statement.having, entity_scope.as_slice());
585
586 statement
587}
588
589fn prepare_delete_statement(
590 mut statement: SqlDeleteStatement,
591 expected_entity: &'static str,
592) -> Result<SqlDeleteStatement, SqlLoweringError> {
593 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
594 let entity_scope = sql_entity_scope_candidates(statement.entity.as_str(), expected_entity);
595 statement.predicate = statement
596 .predicate
597 .map(|predicate| adapt_predicate_identifiers_to_scope(predicate, entity_scope.as_slice()));
598 statement.order_by = normalize_order_terms(statement.order_by, entity_scope.as_slice());
599
600 Ok(statement)
601}
602
603#[inline(never)]
604fn lower_prepared_statement(
605 statement: SqlStatement,
606 primary_key_field: &str,
607) -> Result<LoweredSqlCommand, SqlLoweringError> {
608 match statement {
609 SqlStatement::Select(statement) => Ok(LoweredSqlCommand(LoweredSqlCommandInner::Query(
610 LoweredSqlQuery::Select(lower_select_shape(statement, primary_key_field)?),
611 ))),
612 SqlStatement::Delete(statement) => Ok(LoweredSqlCommand(LoweredSqlCommandInner::Query(
613 LoweredSqlQuery::Delete(lower_delete_shape(statement)),
614 ))),
615 SqlStatement::Explain(statement) => lower_explain_prepared(statement, primary_key_field),
616 SqlStatement::Describe(_) => Ok(LoweredSqlCommand(LoweredSqlCommandInner::DescribeEntity)),
617 SqlStatement::ShowIndexes(_) => {
618 Ok(LoweredSqlCommand(LoweredSqlCommandInner::ShowIndexesEntity))
619 }
620 SqlStatement::ShowColumns(_) => {
621 Ok(LoweredSqlCommand(LoweredSqlCommandInner::ShowColumnsEntity))
622 }
623 SqlStatement::ShowEntities(_) => {
624 Ok(LoweredSqlCommand(LoweredSqlCommandInner::ShowEntities))
625 }
626 }
627}
628
629fn lower_explain_prepared(
630 statement: SqlExplainStatement,
631 primary_key_field: &str,
632) -> Result<LoweredSqlCommand, SqlLoweringError> {
633 let mode = statement.mode;
634
635 match statement.statement {
636 SqlExplainTarget::Select(select_statement) => {
637 lower_explain_select_prepared(select_statement, mode, primary_key_field)
638 }
639 SqlExplainTarget::Delete(delete_statement) => {
640 Ok(LoweredSqlCommand(LoweredSqlCommandInner::Explain {
641 mode,
642 query: LoweredSqlQuery::Delete(lower_delete_shape(delete_statement)),
643 }))
644 }
645 }
646}
647
648fn lower_explain_select_prepared(
649 statement: SqlSelectStatement,
650 mode: SqlExplainMode,
651 primary_key_field: &str,
652) -> Result<LoweredSqlCommand, SqlLoweringError> {
653 match lower_select_shape(statement.clone(), primary_key_field) {
654 Ok(query) => Ok(LoweredSqlCommand(LoweredSqlCommandInner::Explain {
655 mode,
656 query: LoweredSqlQuery::Select(query),
657 })),
658 Err(SqlLoweringError::UnsupportedSelectProjection) => {
659 let command = lower_global_aggregate_select_shape(statement)?;
660
661 Ok(LoweredSqlCommand(
662 LoweredSqlCommandInner::ExplainGlobalAggregate { mode, command },
663 ))
664 }
665 Err(err) => Err(err),
666 }
667}
668
669fn lower_global_aggregate_select_shape(
670 statement: SqlSelectStatement,
671) -> Result<LoweredSqlGlobalAggregateCommand, SqlLoweringError> {
672 let SqlSelectStatement {
673 projection,
674 predicate,
675 distinct,
676 group_by,
677 having,
678 order_by,
679 limit,
680 offset,
681 entity: _,
682 } = statement;
683
684 if distinct {
685 return Err(SqlLoweringError::unsupported_select_distinct());
686 }
687 if !group_by.is_empty() {
688 return Err(SqlLoweringError::unsupported_select_group_by());
689 }
690 if !having.is_empty() {
691 return Err(SqlLoweringError::unsupported_select_having());
692 }
693
694 let terminal = lower_global_aggregate_terminal(projection)?;
695
696 Ok(LoweredSqlGlobalAggregateCommand {
697 query: LoweredBaseQueryShape {
698 predicate,
699 order_by,
700 limit,
701 offset,
702 },
703 terminal,
704 })
705}
706
707#[derive(Clone, Debug)]
715enum ResolvedHavingClause {
716 GroupField {
717 field: String,
718 op: crate::db::predicate::CompareOp,
719 value: crate::value::Value,
720 },
721 Aggregate {
722 aggregate_index: usize,
723 op: crate::db::predicate::CompareOp,
724 value: crate::value::Value,
725 },
726}
727
728#[derive(Clone, Debug)]
735pub(crate) struct LoweredSelectShape {
736 scalar_projection_fields: Option<Vec<String>>,
737 grouped_projection_aggregates: Vec<SqlAggregateCall>,
738 group_by_fields: Vec<String>,
739 distinct: bool,
740 having: Vec<ResolvedHavingClause>,
741 predicate: Option<Predicate>,
742 order_by: Vec<crate::db::sql::parser::SqlOrderTerm>,
743 limit: Option<u32>,
744 offset: Option<u32>,
745}
746
747impl LoweredSelectShape {
748 const fn has_grouping(&self) -> bool {
750 !self.group_by_fields.is_empty()
751 }
752}
753
754#[derive(Clone, Debug)]
763pub(crate) struct LoweredBaseQueryShape {
764 predicate: Option<Predicate>,
765 order_by: Vec<SqlOrderTerm>,
766 limit: Option<u32>,
767 offset: Option<u32>,
768}
769
770#[inline(never)]
771fn lower_select_shape(
772 statement: SqlSelectStatement,
773 primary_key_field: &str,
774) -> Result<LoweredSelectShape, SqlLoweringError> {
775 let SqlSelectStatement {
776 projection,
777 predicate,
778 distinct,
779 group_by,
780 having,
781 order_by,
782 limit,
783 offset,
784 entity: _,
785 } = statement;
786 let projection_for_having = projection.clone();
787
788 let (scalar_projection_fields, grouped_projection_aggregates) = if group_by.is_empty() {
790 let scalar_projection_fields =
791 lower_scalar_projection_fields(projection, distinct, primary_key_field)?;
792 (scalar_projection_fields, Vec::new())
793 } else {
794 if distinct {
795 return Err(SqlLoweringError::unsupported_select_distinct());
796 }
797 let grouped_projection_aggregates =
798 grouped_projection_aggregate_calls(&projection, group_by.as_slice())?;
799 (None, grouped_projection_aggregates)
800 };
801
802 let having = lower_having_clauses(
804 having,
805 &projection_for_having,
806 group_by.as_slice(),
807 grouped_projection_aggregates.as_slice(),
808 )?;
809
810 Ok(LoweredSelectShape {
811 scalar_projection_fields,
812 grouped_projection_aggregates,
813 group_by_fields: group_by,
814 distinct,
815 having,
816 predicate,
817 order_by,
818 limit,
819 offset,
820 })
821}
822
823fn lower_scalar_projection_fields(
824 projection: SqlProjection,
825 distinct: bool,
826 primary_key_field: &str,
827) -> Result<Option<Vec<String>>, SqlLoweringError> {
828 let SqlProjection::Items(items) = projection else {
829 if distinct {
830 return Ok(None);
831 }
832
833 return Ok(None);
834 };
835
836 let has_aggregate = items
837 .iter()
838 .any(|item| matches!(item, SqlSelectItem::Aggregate(_)));
839 if has_aggregate {
840 return Err(SqlLoweringError::unsupported_select_projection());
841 }
842
843 let fields = items
844 .into_iter()
845 .map(|item| match item {
846 SqlSelectItem::Field(field) => Ok(field),
847 SqlSelectItem::Aggregate(_) | SqlSelectItem::TextFunction(_) => {
848 Err(SqlLoweringError::unsupported_select_projection())
849 }
850 })
851 .collect::<Result<Vec<_>, _>>()?;
852
853 validate_scalar_distinct_projection(distinct, fields.as_slice(), primary_key_field)?;
854
855 Ok(Some(fields))
856}
857
858fn validate_scalar_distinct_projection(
859 distinct: bool,
860 projection_fields: &[String],
861 primary_key_field: &str,
862) -> Result<(), SqlLoweringError> {
863 if !distinct {
864 return Ok(());
865 }
866
867 if projection_fields.is_empty() {
868 return Ok(());
869 }
870
871 let has_primary_key_field = projection_fields
872 .iter()
873 .any(|field| field == primary_key_field);
874 if !has_primary_key_field {
875 return Err(SqlLoweringError::unsupported_select_distinct());
876 }
877
878 Ok(())
879}
880
881fn lower_having_clauses(
882 having_clauses: Vec<SqlHavingClause>,
883 projection: &SqlProjection,
884 group_by_fields: &[String],
885 grouped_projection_aggregates: &[SqlAggregateCall],
886) -> Result<Vec<ResolvedHavingClause>, SqlLoweringError> {
887 if having_clauses.is_empty() {
888 return Ok(Vec::new());
889 }
890 if group_by_fields.is_empty() {
891 return Err(SqlLoweringError::unsupported_select_having());
892 }
893
894 let projection_aggregates = grouped_projection_aggregate_calls(projection, group_by_fields)
895 .map_err(|_| SqlLoweringError::unsupported_select_having())?;
896 if projection_aggregates.as_slice() != grouped_projection_aggregates {
897 return Err(SqlLoweringError::unsupported_select_having());
898 }
899
900 let mut lowered = Vec::with_capacity(having_clauses.len());
901 for clause in having_clauses {
902 match clause.symbol {
903 SqlHavingSymbol::Field(field) => lowered.push(ResolvedHavingClause::GroupField {
904 field,
905 op: clause.op,
906 value: clause.value,
907 }),
908 SqlHavingSymbol::Aggregate(aggregate) => {
909 let aggregate_index =
910 resolve_having_aggregate_index(&aggregate, grouped_projection_aggregates)?;
911 lowered.push(ResolvedHavingClause::Aggregate {
912 aggregate_index,
913 op: clause.op,
914 value: clause.value,
915 });
916 }
917 }
918 }
919
920 Ok(lowered)
921}
922
923#[inline(never)]
924pub(in crate::db) fn apply_lowered_select_shape(
925 mut query: StructuralQuery,
926 lowered: LoweredSelectShape,
927) -> Result<StructuralQuery, SqlLoweringError> {
928 let LoweredSelectShape {
929 scalar_projection_fields,
930 grouped_projection_aggregates,
931 group_by_fields,
932 distinct,
933 having,
934 predicate,
935 order_by,
936 limit,
937 offset,
938 } = lowered;
939
940 for field in group_by_fields {
942 query = query.group_by(field)?;
943 }
944
945 if distinct {
947 query = query.distinct();
948 }
949 if let Some(fields) = scalar_projection_fields {
950 query = query.select_fields(fields);
951 }
952 for aggregate in grouped_projection_aggregates {
953 query = query.aggregate(lower_aggregate_call(aggregate)?);
954 }
955
956 for clause in having {
958 match clause {
959 ResolvedHavingClause::GroupField { field, op, value } => {
960 query = query.having_group(field, op, value)?;
961 }
962 ResolvedHavingClause::Aggregate {
963 aggregate_index,
964 op,
965 value,
966 } => {
967 query = query.having_aggregate(aggregate_index, op, value)?;
968 }
969 }
970 }
971
972 Ok(apply_lowered_base_query_shape(
974 query,
975 LoweredBaseQueryShape {
976 predicate,
977 order_by,
978 limit,
979 offset,
980 },
981 ))
982}
983
984fn apply_lowered_base_query_shape(
985 mut query: StructuralQuery,
986 lowered: LoweredBaseQueryShape,
987) -> StructuralQuery {
988 if let Some(predicate) = lowered.predicate {
989 query = query.filter(predicate);
990 }
991 query = apply_order_terms_structural(query, lowered.order_by);
992 if let Some(limit) = lowered.limit {
993 query = query.limit(limit);
994 }
995 if let Some(offset) = lowered.offset {
996 query = query.offset(offset);
997 }
998
999 query
1000}
1001
1002pub(in crate::db) fn bind_lowered_sql_query_structural(
1003 model: &'static crate::model::entity::EntityModel,
1004 lowered: LoweredSqlQuery,
1005 consistency: MissingRowPolicy,
1006) -> Result<StructuralQuery, SqlLoweringError> {
1007 match lowered {
1008 LoweredSqlQuery::Select(select) => {
1009 apply_lowered_select_shape(StructuralQuery::new(model, consistency), select)
1010 }
1011 LoweredSqlQuery::Delete(delete) => Ok(bind_lowered_sql_delete_query_structural(
1012 model,
1013 delete,
1014 consistency,
1015 )),
1016 }
1017}
1018
1019pub(in crate::db) fn bind_lowered_sql_delete_query_structural(
1020 model: &'static crate::model::entity::EntityModel,
1021 delete: LoweredBaseQueryShape,
1022 consistency: MissingRowPolicy,
1023) -> StructuralQuery {
1024 apply_lowered_base_query_shape(StructuralQuery::new(model, consistency).delete(), delete)
1025}
1026
1027pub(in crate::db) fn bind_lowered_sql_query<E: EntityKind>(
1028 lowered: LoweredSqlQuery,
1029 consistency: MissingRowPolicy,
1030) -> Result<Query<E>, SqlLoweringError> {
1031 let structural = bind_lowered_sql_query_structural(E::MODEL, lowered, consistency)?;
1032
1033 Ok(Query::from_inner(structural))
1034}
1035
1036fn bind_lowered_sql_global_aggregate_command<E: EntityKind>(
1037 lowered: LoweredSqlGlobalAggregateCommand,
1038 consistency: MissingRowPolicy,
1039) -> SqlGlobalAggregateCommand<E> {
1040 SqlGlobalAggregateCommand {
1041 query: Query::from_inner(apply_lowered_base_query_shape(
1042 StructuralQuery::new(E::MODEL, consistency),
1043 lowered.query,
1044 )),
1045 terminal: lowered.terminal,
1046 }
1047}
1048
1049fn bind_lowered_sql_global_aggregate_command_structural(
1050 model: &'static crate::model::entity::EntityModel,
1051 lowered: LoweredSqlGlobalAggregateCommand,
1052 consistency: MissingRowPolicy,
1053) -> SqlGlobalAggregateCommandCore {
1054 SqlGlobalAggregateCommandCore {
1055 query: apply_lowered_base_query_shape(
1056 StructuralQuery::new(model, consistency),
1057 lowered.query,
1058 ),
1059 terminal: lowered.terminal,
1060 }
1061}
1062
1063fn lower_global_aggregate_terminal(
1064 projection: SqlProjection,
1065) -> Result<SqlGlobalAggregateTerminal, SqlLoweringError> {
1066 let SqlProjection::Items(items) = projection else {
1067 return Err(SqlLoweringError::unsupported_select_projection());
1068 };
1069 if items.len() != 1 {
1070 return Err(SqlLoweringError::unsupported_select_projection());
1071 }
1072
1073 let Some(SqlSelectItem::Aggregate(aggregate)) = items.into_iter().next() else {
1074 return Err(SqlLoweringError::unsupported_select_projection());
1075 };
1076
1077 match lower_sql_aggregate_shape(aggregate)? {
1078 LoweredSqlAggregateShape::CountRows => Ok(SqlGlobalAggregateTerminal::CountRows),
1079 LoweredSqlAggregateShape::CountField(field) => {
1080 Ok(SqlGlobalAggregateTerminal::CountField(field))
1081 }
1082 LoweredSqlAggregateShape::FieldTarget {
1083 kind: SqlAggregateKind::Sum,
1084 field,
1085 } => Ok(SqlGlobalAggregateTerminal::SumField(field)),
1086 LoweredSqlAggregateShape::FieldTarget {
1087 kind: SqlAggregateKind::Avg,
1088 field,
1089 } => Ok(SqlGlobalAggregateTerminal::AvgField(field)),
1090 LoweredSqlAggregateShape::FieldTarget {
1091 kind: SqlAggregateKind::Min,
1092 field,
1093 } => Ok(SqlGlobalAggregateTerminal::MinField(field)),
1094 LoweredSqlAggregateShape::FieldTarget {
1095 kind: SqlAggregateKind::Max,
1096 field,
1097 } => Ok(SqlGlobalAggregateTerminal::MaxField(field)),
1098 LoweredSqlAggregateShape::FieldTarget {
1099 kind: SqlAggregateKind::Count,
1100 ..
1101 } => Err(SqlLoweringError::unsupported_select_projection()),
1102 }
1103}
1104
1105fn lower_sql_aggregate_shape(
1106 call: SqlAggregateCall,
1107) -> Result<LoweredSqlAggregateShape, SqlLoweringError> {
1108 match (call.kind, call.field) {
1109 (SqlAggregateKind::Count, None) => Ok(LoweredSqlAggregateShape::CountRows),
1110 (SqlAggregateKind::Count, Some(field)) => Ok(LoweredSqlAggregateShape::CountField(field)),
1111 (
1112 kind @ (SqlAggregateKind::Sum
1113 | SqlAggregateKind::Avg
1114 | SqlAggregateKind::Min
1115 | SqlAggregateKind::Max),
1116 Some(field),
1117 ) => Ok(LoweredSqlAggregateShape::FieldTarget { kind, field }),
1118 _ => Err(SqlLoweringError::unsupported_select_projection()),
1119 }
1120}
1121
1122fn grouped_projection_aggregate_calls(
1123 projection: &SqlProjection,
1124 group_by_fields: &[String],
1125) -> Result<Vec<SqlAggregateCall>, SqlLoweringError> {
1126 if group_by_fields.is_empty() {
1127 return Err(SqlLoweringError::unsupported_select_group_by());
1128 }
1129
1130 let SqlProjection::Items(items) = projection else {
1131 return Err(SqlLoweringError::unsupported_select_group_by());
1132 };
1133
1134 let mut projected_group_fields = Vec::<String>::new();
1135 let mut aggregate_calls = Vec::<SqlAggregateCall>::new();
1136 let mut seen_aggregate = false;
1137
1138 for item in items {
1139 match item {
1140 SqlSelectItem::Field(field) => {
1141 if seen_aggregate {
1144 return Err(SqlLoweringError::unsupported_select_group_by());
1145 }
1146 projected_group_fields.push(field.clone());
1147 }
1148 SqlSelectItem::Aggregate(aggregate) => {
1149 seen_aggregate = true;
1150 aggregate_calls.push(aggregate.clone());
1151 }
1152 SqlSelectItem::TextFunction(_) => {
1153 return Err(SqlLoweringError::unsupported_select_group_by());
1154 }
1155 }
1156 }
1157
1158 if aggregate_calls.is_empty() || projected_group_fields.as_slice() != group_by_fields {
1159 return Err(SqlLoweringError::unsupported_select_group_by());
1160 }
1161
1162 Ok(aggregate_calls)
1163}
1164
1165fn lower_aggregate_call(
1166 call: SqlAggregateCall,
1167) -> Result<crate::db::query::builder::AggregateExpr, SqlLoweringError> {
1168 match lower_sql_aggregate_shape(call)? {
1169 LoweredSqlAggregateShape::CountRows => Ok(count()),
1170 LoweredSqlAggregateShape::CountField(field) => Ok(count_by(field)),
1171 LoweredSqlAggregateShape::FieldTarget {
1172 kind: SqlAggregateKind::Sum,
1173 field,
1174 } => Ok(sum(field)),
1175 LoweredSqlAggregateShape::FieldTarget {
1176 kind: SqlAggregateKind::Avg,
1177 field,
1178 } => Ok(avg(field)),
1179 LoweredSqlAggregateShape::FieldTarget {
1180 kind: SqlAggregateKind::Min,
1181 field,
1182 } => Ok(min_by(field)),
1183 LoweredSqlAggregateShape::FieldTarget {
1184 kind: SqlAggregateKind::Max,
1185 field,
1186 } => Ok(max_by(field)),
1187 LoweredSqlAggregateShape::FieldTarget {
1188 kind: SqlAggregateKind::Count,
1189 ..
1190 } => Err(SqlLoweringError::unsupported_select_projection()),
1191 }
1192}
1193
1194fn resolve_having_aggregate_index(
1195 target: &SqlAggregateCall,
1196 grouped_projection_aggregates: &[SqlAggregateCall],
1197) -> Result<usize, SqlLoweringError> {
1198 let mut matched = grouped_projection_aggregates
1199 .iter()
1200 .enumerate()
1201 .filter_map(|(index, aggregate)| (aggregate == target).then_some(index));
1202 let Some(index) = matched.next() else {
1203 return Err(SqlLoweringError::unsupported_select_having());
1204 };
1205 if matched.next().is_some() {
1206 return Err(SqlLoweringError::unsupported_select_having());
1207 }
1208
1209 Ok(index)
1210}
1211
1212fn lower_delete_shape(statement: SqlDeleteStatement) -> LoweredBaseQueryShape {
1213 let SqlDeleteStatement {
1214 predicate,
1215 order_by,
1216 limit,
1217 entity: _,
1218 } = statement;
1219
1220 LoweredBaseQueryShape {
1221 predicate,
1222 order_by,
1223 limit,
1224 offset: None,
1225 }
1226}
1227
1228fn apply_order_terms_structural(
1229 mut query: StructuralQuery,
1230 order_by: Vec<crate::db::sql::parser::SqlOrderTerm>,
1231) -> StructuralQuery {
1232 for term in order_by {
1233 query = match term.direction {
1234 SqlOrderDirection::Asc => query.order_by(term.field),
1235 SqlOrderDirection::Desc => query.order_by_desc(term.field),
1236 };
1237 }
1238
1239 query
1240}
1241
1242fn normalize_having_clauses(
1243 clauses: Vec<SqlHavingClause>,
1244 entity_scope: &[String],
1245) -> Vec<SqlHavingClause> {
1246 clauses
1247 .into_iter()
1248 .map(|clause| SqlHavingClause {
1249 symbol: normalize_having_symbol(clause.symbol, entity_scope),
1250 op: clause.op,
1251 value: clause.value,
1252 })
1253 .collect()
1254}
1255
1256fn normalize_having_symbol(symbol: SqlHavingSymbol, entity_scope: &[String]) -> SqlHavingSymbol {
1257 match symbol {
1258 SqlHavingSymbol::Field(field) => {
1259 SqlHavingSymbol::Field(normalize_identifier_to_scope(field, entity_scope))
1260 }
1261 SqlHavingSymbol::Aggregate(aggregate) => SqlHavingSymbol::Aggregate(
1262 normalize_aggregate_call_identifiers(aggregate, entity_scope),
1263 ),
1264 }
1265}
1266
1267fn normalize_aggregate_call_identifiers(
1268 aggregate: SqlAggregateCall,
1269 entity_scope: &[String],
1270) -> SqlAggregateCall {
1271 SqlAggregateCall {
1272 kind: aggregate.kind,
1273 field: aggregate
1274 .field
1275 .map(|field| normalize_identifier_to_scope(field, entity_scope)),
1276 }
1277}
1278
1279fn sql_entity_scope_candidates(sql_entity: &str, expected_entity: &'static str) -> Vec<String> {
1282 let mut out = Vec::new();
1283 out.push(sql_entity.to_string());
1284 out.push(expected_entity.to_string());
1285
1286 if let Some(last) = identifier_last_segment(sql_entity) {
1287 out.push(last.to_string());
1288 }
1289 if let Some(last) = identifier_last_segment(expected_entity) {
1290 out.push(last.to_string());
1291 }
1292
1293 out
1294}
1295
1296fn normalize_projection_identifiers(
1297 projection: SqlProjection,
1298 entity_scope: &[String],
1299) -> SqlProjection {
1300 match projection {
1301 SqlProjection::All => SqlProjection::All,
1302 SqlProjection::Items(items) => SqlProjection::Items(
1303 items
1304 .into_iter()
1305 .map(|item| match item {
1306 SqlSelectItem::Field(field) => {
1307 SqlSelectItem::Field(normalize_identifier(field, entity_scope))
1308 }
1309 SqlSelectItem::Aggregate(aggregate) => {
1310 SqlSelectItem::Aggregate(SqlAggregateCall {
1311 kind: aggregate.kind,
1312 field: aggregate
1313 .field
1314 .map(|field| normalize_identifier(field, entity_scope)),
1315 })
1316 }
1317 SqlSelectItem::TextFunction(SqlTextFunctionCall {
1318 function,
1319 field,
1320 literal,
1321 literal2,
1322 literal3,
1323 }) => SqlSelectItem::TextFunction(SqlTextFunctionCall {
1324 function,
1325 field: normalize_identifier(field, entity_scope),
1326 literal,
1327 literal2,
1328 literal3,
1329 }),
1330 })
1331 .collect(),
1332 ),
1333 }
1334}
1335
1336fn normalize_order_terms(
1337 terms: Vec<crate::db::sql::parser::SqlOrderTerm>,
1338 entity_scope: &[String],
1339) -> Vec<crate::db::sql::parser::SqlOrderTerm> {
1340 terms
1341 .into_iter()
1342 .map(|term| crate::db::sql::parser::SqlOrderTerm {
1343 field: normalize_identifier(term.field, entity_scope),
1344 direction: term.direction,
1345 })
1346 .collect()
1347}
1348
1349fn normalize_identifier_list(fields: Vec<String>, entity_scope: &[String]) -> Vec<String> {
1350 fields
1351 .into_iter()
1352 .map(|field| normalize_identifier(field, entity_scope))
1353 .collect()
1354}
1355
1356fn adapt_predicate_identifiers_to_scope(
1359 predicate: Predicate,
1360 entity_scope: &[String],
1361) -> Predicate {
1362 rewrite_field_identifiers(predicate, |field| normalize_identifier(field, entity_scope))
1363}
1364
1365fn normalize_identifier(identifier: String, entity_scope: &[String]) -> String {
1366 normalize_identifier_to_scope(identifier, entity_scope)
1367}
1368
1369fn ensure_entity_matches_expected(
1370 sql_entity: &str,
1371 expected_entity: &'static str,
1372) -> Result<(), SqlLoweringError> {
1373 if identifiers_tail_match(sql_entity, expected_entity) {
1374 return Ok(());
1375 }
1376
1377 Err(SqlLoweringError::entity_mismatch(
1378 sql_entity,
1379 expected_entity,
1380 ))
1381}