1#[cfg(test)]
11mod tests;
12
13use crate::{
14 db::{
15 predicate::{MissingRowPolicy, Predicate},
16 query::{
17 builder::aggregate::{avg, count, count_by, max_by, min_by, sum},
18 intent::{Query, QueryError, StructuralQuery},
19 plan::ExpressionOrderTerm,
20 },
21 sql::identifier::{
22 identifier_last_segment, identifiers_tail_match, normalize_identifier_to_scope,
23 rewrite_field_identifiers,
24 },
25 sql::parser::{
26 SqlAggregateCall, SqlAggregateKind, SqlDeleteStatement, SqlExplainMode,
27 SqlExplainStatement, SqlExplainTarget, SqlHavingClause, SqlHavingSymbol,
28 SqlOrderDirection, SqlOrderTerm, SqlProjection, SqlSelectItem, SqlSelectStatement,
29 SqlStatement, SqlTextFunctionCall,
30 },
31 },
32 traits::EntityKind,
33};
34use thiserror::Error as ThisError;
35
36#[derive(Clone, Debug)]
45pub struct LoweredSqlCommand(LoweredSqlCommandInner);
46
47#[derive(Clone, Debug)]
48enum LoweredSqlCommandInner {
49 Query(LoweredSqlQuery),
50 Explain {
51 mode: SqlExplainMode,
52 query: LoweredSqlQuery,
53 },
54 ExplainGlobalAggregate {
55 mode: SqlExplainMode,
56 command: LoweredSqlGlobalAggregateCommand,
57 },
58 DescribeEntity,
59 ShowIndexesEntity,
60 ShowColumnsEntity,
61 ShowEntities,
62}
63
/// Typed counterpart of `LoweredSqlCommand`, bound to the entity type `E`.
///
/// Test-only; produced by `bind_lowered_sql_command`.
#[cfg(test)]
#[derive(Debug)]
pub(crate) enum SqlCommand<E: EntityKind> {
    /// A bound `SELECT` or `DELETE` query.
    Query(Query<E>),
    /// `EXPLAIN` over a bound query.
    Explain {
        mode: SqlExplainMode,
        query: Query<E>,
    },
    /// `EXPLAIN` over a bound global aggregate command.
    ExplainGlobalAggregate {
        mode: SqlExplainMode,
        command: SqlGlobalAggregateCommand<E>,
    },
    /// `DESCRIBE` of the entity.
    DescribeEntity,
    /// `SHOW INDEXES` for the entity.
    ShowIndexesEntity,
    /// `SHOW COLUMNS` for the entity.
    ShowColumnsEntity,
    /// `SHOW ENTITIES`.
    ShowEntities,
}
88
89impl LoweredSqlCommand {
90 #[must_use]
91 pub(in crate::db) const fn query(&self) -> Option<&LoweredSqlQuery> {
92 match &self.0 {
93 LoweredSqlCommandInner::Query(query) => Some(query),
94 LoweredSqlCommandInner::Explain { .. }
95 | LoweredSqlCommandInner::ExplainGlobalAggregate { .. }
96 | LoweredSqlCommandInner::DescribeEntity
97 | LoweredSqlCommandInner::ShowIndexesEntity
98 | LoweredSqlCommandInner::ShowColumnsEntity
99 | LoweredSqlCommandInner::ShowEntities => None,
100 }
101 }
102
103 #[must_use]
104 pub(in crate::db) const fn explain_query(&self) -> Option<(SqlExplainMode, &LoweredSqlQuery)> {
105 match &self.0 {
106 LoweredSqlCommandInner::Explain { mode, query } => Some((*mode, query)),
107 LoweredSqlCommandInner::Query(_)
108 | LoweredSqlCommandInner::ExplainGlobalAggregate { .. }
109 | LoweredSqlCommandInner::DescribeEntity
110 | LoweredSqlCommandInner::ShowIndexesEntity
111 | LoweredSqlCommandInner::ShowColumnsEntity
112 | LoweredSqlCommandInner::ShowEntities => None,
113 }
114 }
115}
116
117#[derive(Clone, Debug)]
124pub(crate) enum LoweredSqlQuery {
125 Select(LoweredSelectShape),
126 Delete(LoweredBaseQueryShape),
127}
128
129impl LoweredSqlQuery {
130 pub(crate) const fn has_grouping(&self) -> bool {
132 match self {
133 Self::Select(select) => select.has_grouping(),
134 Self::Delete(_) => false,
135 }
136 }
137}
138
139#[derive(Clone, Debug, Eq, PartialEq)]
147pub(crate) enum SqlGlobalAggregateTerminal {
148 CountRows,
149 CountField(String),
150 SumField(String),
151 AvgField(String),
152 MinField(String),
153 MaxField(String),
154}
155
156#[derive(Clone, Debug)]
165pub(crate) struct LoweredSqlGlobalAggregateCommand {
166 query: LoweredBaseQueryShape,
167 terminal: SqlGlobalAggregateTerminal,
168}
169
170enum LoweredSqlAggregateShape {
178 CountRows,
179 CountField(String),
180 FieldTarget {
181 kind: SqlAggregateKind,
182 field: String,
183 },
184}
185
186#[derive(Debug)]
193pub(crate) struct SqlGlobalAggregateCommand<E: EntityKind> {
194 query: Query<E>,
195 terminal: SqlGlobalAggregateTerminal,
196}
197
198impl<E: EntityKind> SqlGlobalAggregateCommand<E> {
199 #[must_use]
201 pub(crate) const fn query(&self) -> &Query<E> {
202 &self.query
203 }
204
205 #[must_use]
207 pub(crate) const fn terminal(&self) -> &SqlGlobalAggregateTerminal {
208 &self.terminal
209 }
210}
211
212#[derive(Debug)]
222pub(crate) struct SqlGlobalAggregateCommandCore {
223 query: StructuralQuery,
224 terminal: SqlGlobalAggregateTerminal,
225}
226
227impl SqlGlobalAggregateCommandCore {
228 #[must_use]
230 pub(in crate::db) const fn query(&self) -> &StructuralQuery {
231 &self.query
232 }
233
234 #[must_use]
236 pub(in crate::db) const fn terminal(&self) -> &SqlGlobalAggregateTerminal {
237 &self.terminal
238 }
239}
240
241#[derive(Debug, ThisError)]
248pub(crate) enum SqlLoweringError {
249 #[error("{0}")]
250 Parse(#[from] crate::db::sql::parser::SqlParseError),
251
252 #[error("{0}")]
253 Query(#[from] QueryError),
254
255 #[error("SQL entity '{sql_entity}' does not match requested entity type '{expected_entity}'")]
256 EntityMismatch {
257 sql_entity: String,
258 expected_entity: &'static str,
259 },
260
261 #[error(
262 "unsupported SQL SELECT projection; supported forms are SELECT *, field lists, or grouped aggregate shapes"
263 )]
264 UnsupportedSelectProjection,
265
266 #[error("unsupported SQL SELECT DISTINCT")]
267 UnsupportedSelectDistinct,
268
269 #[error("unsupported SQL GROUP BY projection shape")]
270 UnsupportedSelectGroupBy,
271
272 #[error("unsupported SQL HAVING shape")]
273 UnsupportedSelectHaving,
274}
275
276impl SqlLoweringError {
277 fn entity_mismatch(sql_entity: impl Into<String>, expected_entity: &'static str) -> Self {
279 Self::EntityMismatch {
280 sql_entity: sql_entity.into(),
281 expected_entity,
282 }
283 }
284
285 const fn unsupported_select_projection() -> Self {
287 Self::UnsupportedSelectProjection
288 }
289
290 const fn unsupported_select_distinct() -> Self {
292 Self::UnsupportedSelectDistinct
293 }
294
295 const fn unsupported_select_group_by() -> Self {
297 Self::UnsupportedSelectGroupBy
298 }
299
300 const fn unsupported_select_having() -> Self {
302 Self::UnsupportedSelectHaving
303 }
304}
305
306#[derive(Clone, Debug)]
317pub(crate) struct PreparedSqlStatement {
318 statement: SqlStatement,
319}
320
321#[derive(Clone, Copy, Debug, Eq, PartialEq)]
322pub(crate) enum LoweredSqlLaneKind {
323 Query,
324 Explain,
325 Describe,
326 ShowIndexes,
327 ShowColumns,
328 ShowEntities,
329}
330
/// Test helper: parses `sql`, prepares it for `E`, lowers it, and binds it to `E`.
#[cfg(test)]
pub(crate) fn compile_sql_command<E: EntityKind>(
    sql: &str,
    consistency: MissingRowPolicy,
) -> Result<SqlCommand<E>, SqlLoweringError> {
    compile_sql_command_from_statement::<E>(crate::db::sql::parser::parse_sql(sql)?, consistency)
}

/// Test helper: prepares an already-parsed statement for `E`, then lowers and binds it.
#[cfg(test)]
pub(crate) fn compile_sql_command_from_statement<E: EntityKind>(
    statement: SqlStatement,
    consistency: MissingRowPolicy,
) -> Result<SqlCommand<E>, SqlLoweringError> {
    prepare_sql_statement(statement, E::MODEL.name())
        .and_then(|prepared| compile_sql_command_from_prepared_statement::<E>(prepared, consistency))
}

/// Test helper: lowers a prepared statement using `E`'s primary key, then binds it to `E`.
#[cfg(test)]
pub(crate) fn compile_sql_command_from_prepared_statement<E: EntityKind>(
    prepared: PreparedSqlStatement,
    consistency: MissingRowPolicy,
) -> Result<SqlCommand<E>, SqlLoweringError> {
    lower_sql_command_from_prepared_statement(prepared, E::MODEL.primary_key.name)
        .and_then(|lowered| bind_lowered_sql_command::<E>(lowered, consistency))
}
361
362#[inline(never)]
364pub(crate) fn lower_sql_command_from_prepared_statement(
365 prepared: PreparedSqlStatement,
366 primary_key_field: &str,
367) -> Result<LoweredSqlCommand, SqlLoweringError> {
368 lower_prepared_statement(prepared.statement, primary_key_field)
369}
370
371pub(crate) const fn lowered_sql_command_lane(command: &LoweredSqlCommand) -> LoweredSqlLaneKind {
372 match command.0 {
373 LoweredSqlCommandInner::Query(_) => LoweredSqlLaneKind::Query,
374 LoweredSqlCommandInner::Explain { .. }
375 | LoweredSqlCommandInner::ExplainGlobalAggregate { .. } => LoweredSqlLaneKind::Explain,
376 LoweredSqlCommandInner::DescribeEntity => LoweredSqlLaneKind::Describe,
377 LoweredSqlCommandInner::ShowIndexesEntity => LoweredSqlLaneKind::ShowIndexes,
378 LoweredSqlCommandInner::ShowColumnsEntity => LoweredSqlLaneKind::ShowColumns,
379 LoweredSqlCommandInner::ShowEntities => LoweredSqlLaneKind::ShowEntities,
380 }
381}
382
383pub(in crate::db) fn is_sql_global_aggregate_statement(statement: &SqlStatement) -> bool {
386 let SqlStatement::Select(statement) = statement else {
387 return false;
388 };
389
390 is_sql_global_aggregate_select(statement)
391}
392
393fn is_sql_global_aggregate_select(statement: &SqlSelectStatement) -> bool {
396 if statement.distinct || !statement.group_by.is_empty() || !statement.having.is_empty() {
397 return false;
398 }
399
400 lower_global_aggregate_terminal(statement.projection.clone()).is_ok()
401}
402
403#[inline(never)]
405pub(crate) fn render_lowered_sql_explain_plan_or_json(
406 lowered: &LoweredSqlCommand,
407 model: &'static crate::model::entity::EntityModel,
408 consistency: MissingRowPolicy,
409) -> Result<Option<String>, SqlLoweringError> {
410 let LoweredSqlCommandInner::Explain { mode, query } = &lowered.0 else {
411 return Ok(None);
412 };
413
414 let query = bind_lowered_sql_query_structural(model, query.clone(), consistency)?;
415 let rendered = match mode {
416 SqlExplainMode::Plan | SqlExplainMode::Json => {
417 let plan = query.build_plan()?;
418 let explain = plan.explain_with_model(model);
419
420 match mode {
421 SqlExplainMode::Plan => explain.render_text_canonical(),
422 SqlExplainMode::Json => explain.render_json_canonical(),
423 SqlExplainMode::Execution => unreachable!("execution mode handled above"),
424 }
425 }
426 SqlExplainMode::Execution => query.explain_execution_text()?,
427 };
428
429 Ok(Some(rendered))
430}
431
432pub(crate) fn bind_lowered_sql_explain_global_aggregate_structural(
435 lowered: &LoweredSqlCommand,
436 model: &'static crate::model::entity::EntityModel,
437 consistency: MissingRowPolicy,
438) -> Option<(SqlExplainMode, SqlGlobalAggregateCommandCore)> {
439 let LoweredSqlCommandInner::ExplainGlobalAggregate { mode, command } = &lowered.0 else {
440 return None;
441 };
442
443 Some((
444 *mode,
445 bind_lowered_sql_global_aggregate_command_structural(model, command.clone(), consistency),
446 ))
447}
448
/// Test helper: binds an entity-agnostic lowered command to the entity type `E`.
///
/// # Errors
/// Propagates failures from binding the contained query.
#[cfg(test)]
pub(crate) fn bind_lowered_sql_command<E: EntityKind>(
    lowered: LoweredSqlCommand,
    consistency: MissingRowPolicy,
) -> Result<SqlCommand<E>, SqlLoweringError> {
    let command = match lowered.0 {
        LoweredSqlCommandInner::Query(query) => {
            SqlCommand::Query(bind_lowered_sql_query::<E>(query, consistency)?)
        }
        LoweredSqlCommandInner::Explain { mode, query } => {
            let query = bind_lowered_sql_query::<E>(query, consistency)?;
            SqlCommand::Explain { mode, query }
        }
        LoweredSqlCommandInner::ExplainGlobalAggregate { mode, command } => {
            let command = bind_lowered_sql_global_aggregate_command::<E>(command, consistency);
            SqlCommand::ExplainGlobalAggregate { mode, command }
        }
        LoweredSqlCommandInner::DescribeEntity => SqlCommand::DescribeEntity,
        LoweredSqlCommandInner::ShowIndexesEntity => SqlCommand::ShowIndexesEntity,
        LoweredSqlCommandInner::ShowColumnsEntity => SqlCommand::ShowColumnsEntity,
        LoweredSqlCommandInner::ShowEntities => SqlCommand::ShowEntities,
    };

    Ok(command)
}
476
477#[inline(never)]
479pub(crate) fn prepare_sql_statement(
480 statement: SqlStatement,
481 expected_entity: &'static str,
482) -> Result<PreparedSqlStatement, SqlLoweringError> {
483 let statement = prepare_statement(statement, expected_entity)?;
484
485 Ok(PreparedSqlStatement { statement })
486}
487
488#[cfg(test)]
490pub(crate) fn compile_sql_global_aggregate_command<E: EntityKind>(
491 sql: &str,
492 consistency: MissingRowPolicy,
493) -> Result<SqlGlobalAggregateCommand<E>, SqlLoweringError> {
494 let statement = crate::db::sql::parser::parse_sql(sql)?;
495 let prepared = prepare_sql_statement(statement, E::MODEL.name())?;
496 compile_sql_global_aggregate_command_from_prepared::<E>(prepared, consistency)
497}
498
499pub(crate) fn compile_sql_global_aggregate_command_from_prepared<E: EntityKind>(
503 prepared: PreparedSqlStatement,
504 consistency: MissingRowPolicy,
505) -> Result<SqlGlobalAggregateCommand<E>, SqlLoweringError> {
506 let SqlStatement::Select(statement) = prepared.statement else {
507 return Err(SqlLoweringError::unsupported_select_projection());
508 };
509
510 Ok(bind_lowered_sql_global_aggregate_command::<E>(
511 lower_global_aggregate_select_shape(statement)?,
512 consistency,
513 ))
514}
515
516#[inline(never)]
517fn prepare_statement(
518 statement: SqlStatement,
519 expected_entity: &'static str,
520) -> Result<SqlStatement, SqlLoweringError> {
521 match statement {
522 SqlStatement::Select(statement) => Ok(SqlStatement::Select(prepare_select_statement(
523 statement,
524 expected_entity,
525 )?)),
526 SqlStatement::Delete(statement) => Ok(SqlStatement::Delete(prepare_delete_statement(
527 statement,
528 expected_entity,
529 )?)),
530 SqlStatement::Explain(statement) => Ok(SqlStatement::Explain(prepare_explain_statement(
531 statement,
532 expected_entity,
533 )?)),
534 SqlStatement::Describe(statement) => {
535 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
536
537 Ok(SqlStatement::Describe(statement))
538 }
539 SqlStatement::ShowIndexes(statement) => {
540 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
541
542 Ok(SqlStatement::ShowIndexes(statement))
543 }
544 SqlStatement::ShowColumns(statement) => {
545 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
546
547 Ok(SqlStatement::ShowColumns(statement))
548 }
549 SqlStatement::ShowEntities(statement) => Ok(SqlStatement::ShowEntities(statement)),
550 }
551}
552
553fn prepare_explain_statement(
554 statement: SqlExplainStatement,
555 expected_entity: &'static str,
556) -> Result<SqlExplainStatement, SqlLoweringError> {
557 let target = match statement.statement {
558 SqlExplainTarget::Select(select_statement) => {
559 SqlExplainTarget::Select(prepare_select_statement(select_statement, expected_entity)?)
560 }
561 SqlExplainTarget::Delete(delete_statement) => {
562 SqlExplainTarget::Delete(prepare_delete_statement(delete_statement, expected_entity)?)
563 }
564 };
565
566 Ok(SqlExplainStatement {
567 mode: statement.mode,
568 statement: target,
569 })
570}
571
572fn prepare_select_statement(
573 statement: SqlSelectStatement,
574 expected_entity: &'static str,
575) -> Result<SqlSelectStatement, SqlLoweringError> {
576 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
577
578 Ok(normalize_select_statement_to_expected_entity(
579 statement,
580 expected_entity,
581 ))
582}
583
584fn normalize_select_statement_to_expected_entity(
585 mut statement: SqlSelectStatement,
586 expected_entity: &'static str,
587) -> SqlSelectStatement {
588 let entity_scope = sql_entity_scope_candidates(statement.entity.as_str(), expected_entity);
591 statement.projection =
592 normalize_projection_identifiers(statement.projection, entity_scope.as_slice());
593 statement.group_by = normalize_identifier_list(statement.group_by, entity_scope.as_slice());
594 statement.predicate = statement
595 .predicate
596 .map(|predicate| adapt_predicate_identifiers_to_scope(predicate, entity_scope.as_slice()));
597 statement.order_by = normalize_order_terms(statement.order_by, entity_scope.as_slice());
598 statement.having = normalize_having_clauses(statement.having, entity_scope.as_slice());
599
600 statement
601}
602
603fn prepare_delete_statement(
604 mut statement: SqlDeleteStatement,
605 expected_entity: &'static str,
606) -> Result<SqlDeleteStatement, SqlLoweringError> {
607 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
608 let entity_scope = sql_entity_scope_candidates(statement.entity.as_str(), expected_entity);
609 statement.predicate = statement
610 .predicate
611 .map(|predicate| adapt_predicate_identifiers_to_scope(predicate, entity_scope.as_slice()));
612 statement.order_by = normalize_order_terms(statement.order_by, entity_scope.as_slice());
613
614 Ok(statement)
615}
616
617#[inline(never)]
618fn lower_prepared_statement(
619 statement: SqlStatement,
620 primary_key_field: &str,
621) -> Result<LoweredSqlCommand, SqlLoweringError> {
622 match statement {
623 SqlStatement::Select(statement) => Ok(LoweredSqlCommand(LoweredSqlCommandInner::Query(
624 LoweredSqlQuery::Select(lower_select_shape(statement, primary_key_field)?),
625 ))),
626 SqlStatement::Delete(statement) => Ok(LoweredSqlCommand(LoweredSqlCommandInner::Query(
627 LoweredSqlQuery::Delete(lower_delete_shape(statement)),
628 ))),
629 SqlStatement::Explain(statement) => lower_explain_prepared(statement, primary_key_field),
630 SqlStatement::Describe(_) => Ok(LoweredSqlCommand(LoweredSqlCommandInner::DescribeEntity)),
631 SqlStatement::ShowIndexes(_) => {
632 Ok(LoweredSqlCommand(LoweredSqlCommandInner::ShowIndexesEntity))
633 }
634 SqlStatement::ShowColumns(_) => {
635 Ok(LoweredSqlCommand(LoweredSqlCommandInner::ShowColumnsEntity))
636 }
637 SqlStatement::ShowEntities(_) => {
638 Ok(LoweredSqlCommand(LoweredSqlCommandInner::ShowEntities))
639 }
640 }
641}
642
643fn lower_explain_prepared(
644 statement: SqlExplainStatement,
645 primary_key_field: &str,
646) -> Result<LoweredSqlCommand, SqlLoweringError> {
647 let mode = statement.mode;
648
649 match statement.statement {
650 SqlExplainTarget::Select(select_statement) => {
651 lower_explain_select_prepared(select_statement, mode, primary_key_field)
652 }
653 SqlExplainTarget::Delete(delete_statement) => {
654 Ok(LoweredSqlCommand(LoweredSqlCommandInner::Explain {
655 mode,
656 query: LoweredSqlQuery::Delete(lower_delete_shape(delete_statement)),
657 }))
658 }
659 }
660}
661
662fn lower_explain_select_prepared(
663 statement: SqlSelectStatement,
664 mode: SqlExplainMode,
665 primary_key_field: &str,
666) -> Result<LoweredSqlCommand, SqlLoweringError> {
667 match lower_select_shape(statement.clone(), primary_key_field) {
668 Ok(query) => Ok(LoweredSqlCommand(LoweredSqlCommandInner::Explain {
669 mode,
670 query: LoweredSqlQuery::Select(query),
671 })),
672 Err(SqlLoweringError::UnsupportedSelectProjection) => {
673 let command = lower_global_aggregate_select_shape(statement)?;
674
675 Ok(LoweredSqlCommand(
676 LoweredSqlCommandInner::ExplainGlobalAggregate { mode, command },
677 ))
678 }
679 Err(err) => Err(err),
680 }
681}
682
683fn lower_global_aggregate_select_shape(
684 statement: SqlSelectStatement,
685) -> Result<LoweredSqlGlobalAggregateCommand, SqlLoweringError> {
686 let SqlSelectStatement {
687 projection,
688 predicate,
689 distinct,
690 group_by,
691 having,
692 order_by,
693 limit,
694 offset,
695 entity: _,
696 } = statement;
697
698 if distinct {
699 return Err(SqlLoweringError::unsupported_select_distinct());
700 }
701 if !group_by.is_empty() {
702 return Err(SqlLoweringError::unsupported_select_group_by());
703 }
704 if !having.is_empty() {
705 return Err(SqlLoweringError::unsupported_select_having());
706 }
707
708 let terminal = lower_global_aggregate_terminal(projection)?;
709
710 Ok(LoweredSqlGlobalAggregateCommand {
711 query: LoweredBaseQueryShape {
712 predicate,
713 order_by,
714 limit,
715 offset,
716 },
717 terminal,
718 })
719}
720
721#[derive(Clone, Debug)]
729enum ResolvedHavingClause {
730 GroupField {
731 field: String,
732 op: crate::db::predicate::CompareOp,
733 value: crate::value::Value,
734 },
735 Aggregate {
736 aggregate_index: usize,
737 op: crate::db::predicate::CompareOp,
738 value: crate::value::Value,
739 },
740}
741
742#[derive(Clone, Debug)]
749pub(crate) struct LoweredSelectShape {
750 scalar_projection_fields: Option<Vec<String>>,
751 grouped_projection_aggregates: Vec<SqlAggregateCall>,
752 group_by_fields: Vec<String>,
753 distinct: bool,
754 having: Vec<ResolvedHavingClause>,
755 predicate: Option<Predicate>,
756 order_by: Vec<crate::db::sql::parser::SqlOrderTerm>,
757 limit: Option<u32>,
758 offset: Option<u32>,
759}
760
761impl LoweredSelectShape {
762 const fn has_grouping(&self) -> bool {
764 !self.group_by_fields.is_empty()
765 }
766}
767
768#[derive(Clone, Debug)]
777pub(crate) struct LoweredBaseQueryShape {
778 predicate: Option<Predicate>,
779 order_by: Vec<SqlOrderTerm>,
780 limit: Option<u32>,
781 offset: Option<u32>,
782}
783
784#[inline(never)]
785fn lower_select_shape(
786 statement: SqlSelectStatement,
787 primary_key_field: &str,
788) -> Result<LoweredSelectShape, SqlLoweringError> {
789 let SqlSelectStatement {
790 projection,
791 predicate,
792 distinct,
793 group_by,
794 having,
795 order_by,
796 limit,
797 offset,
798 entity: _,
799 } = statement;
800 let projection_for_having = projection.clone();
801
802 let (scalar_projection_fields, grouped_projection_aggregates) = if group_by.is_empty() {
804 let scalar_projection_fields =
805 lower_scalar_projection_fields(projection, distinct, primary_key_field)?;
806 (scalar_projection_fields, Vec::new())
807 } else {
808 if distinct {
809 return Err(SqlLoweringError::unsupported_select_distinct());
810 }
811 let grouped_projection_aggregates =
812 grouped_projection_aggregate_calls(&projection, group_by.as_slice())?;
813 (None, grouped_projection_aggregates)
814 };
815
816 let having = lower_having_clauses(
818 having,
819 &projection_for_having,
820 group_by.as_slice(),
821 grouped_projection_aggregates.as_slice(),
822 )?;
823
824 Ok(LoweredSelectShape {
825 scalar_projection_fields,
826 grouped_projection_aggregates,
827 group_by_fields: group_by,
828 distinct,
829 having,
830 predicate,
831 order_by,
832 limit,
833 offset,
834 })
835}
836
837fn lower_scalar_projection_fields(
838 projection: SqlProjection,
839 distinct: bool,
840 primary_key_field: &str,
841) -> Result<Option<Vec<String>>, SqlLoweringError> {
842 let SqlProjection::Items(items) = projection else {
843 if distinct {
844 return Ok(None);
845 }
846
847 return Ok(None);
848 };
849
850 let has_aggregate = items
851 .iter()
852 .any(|item| matches!(item, SqlSelectItem::Aggregate(_)));
853 if has_aggregate {
854 return Err(SqlLoweringError::unsupported_select_projection());
855 }
856
857 let fields = items
858 .into_iter()
859 .map(|item| match item {
860 SqlSelectItem::Field(field) => Ok(field),
861 SqlSelectItem::Aggregate(_) | SqlSelectItem::TextFunction(_) => {
862 Err(SqlLoweringError::unsupported_select_projection())
863 }
864 })
865 .collect::<Result<Vec<_>, _>>()?;
866
867 validate_scalar_distinct_projection(distinct, fields.as_slice(), primary_key_field)?;
868
869 Ok(Some(fields))
870}
871
872fn validate_scalar_distinct_projection(
873 distinct: bool,
874 projection_fields: &[String],
875 primary_key_field: &str,
876) -> Result<(), SqlLoweringError> {
877 if !distinct {
878 return Ok(());
879 }
880
881 if projection_fields.is_empty() {
882 return Ok(());
883 }
884
885 let has_primary_key_field = projection_fields
886 .iter()
887 .any(|field| field == primary_key_field);
888 if !has_primary_key_field {
889 return Err(SqlLoweringError::unsupported_select_distinct());
890 }
891
892 Ok(())
893}
894
895fn lower_having_clauses(
896 having_clauses: Vec<SqlHavingClause>,
897 projection: &SqlProjection,
898 group_by_fields: &[String],
899 grouped_projection_aggregates: &[SqlAggregateCall],
900) -> Result<Vec<ResolvedHavingClause>, SqlLoweringError> {
901 if having_clauses.is_empty() {
902 return Ok(Vec::new());
903 }
904 if group_by_fields.is_empty() {
905 return Err(SqlLoweringError::unsupported_select_having());
906 }
907
908 let projection_aggregates = grouped_projection_aggregate_calls(projection, group_by_fields)
909 .map_err(|_| SqlLoweringError::unsupported_select_having())?;
910 if projection_aggregates.as_slice() != grouped_projection_aggregates {
911 return Err(SqlLoweringError::unsupported_select_having());
912 }
913
914 let mut lowered = Vec::with_capacity(having_clauses.len());
915 for clause in having_clauses {
916 match clause.symbol {
917 SqlHavingSymbol::Field(field) => lowered.push(ResolvedHavingClause::GroupField {
918 field,
919 op: clause.op,
920 value: clause.value,
921 }),
922 SqlHavingSymbol::Aggregate(aggregate) => {
923 let aggregate_index =
924 resolve_having_aggregate_index(&aggregate, grouped_projection_aggregates)?;
925 lowered.push(ResolvedHavingClause::Aggregate {
926 aggregate_index,
927 op: clause.op,
928 value: clause.value,
929 });
930 }
931 }
932 }
933
934 Ok(lowered)
935}
936
937#[inline(never)]
938pub(in crate::db) fn apply_lowered_select_shape(
939 mut query: StructuralQuery,
940 lowered: LoweredSelectShape,
941) -> Result<StructuralQuery, SqlLoweringError> {
942 let LoweredSelectShape {
943 scalar_projection_fields,
944 grouped_projection_aggregates,
945 group_by_fields,
946 distinct,
947 having,
948 predicate,
949 order_by,
950 limit,
951 offset,
952 } = lowered;
953
954 for field in group_by_fields {
956 query = query.group_by(field)?;
957 }
958
959 if distinct {
961 query = query.distinct();
962 }
963 if let Some(fields) = scalar_projection_fields {
964 query = query.select_fields(fields);
965 }
966 for aggregate in grouped_projection_aggregates {
967 query = query.aggregate(lower_aggregate_call(aggregate)?);
968 }
969
970 for clause in having {
972 match clause {
973 ResolvedHavingClause::GroupField { field, op, value } => {
974 query = query.having_group(field, op, value)?;
975 }
976 ResolvedHavingClause::Aggregate {
977 aggregate_index,
978 op,
979 value,
980 } => {
981 query = query.having_aggregate(aggregate_index, op, value)?;
982 }
983 }
984 }
985
986 Ok(apply_lowered_base_query_shape(
988 query,
989 LoweredBaseQueryShape {
990 predicate,
991 order_by,
992 limit,
993 offset,
994 },
995 ))
996}
997
998fn apply_lowered_base_query_shape(
999 mut query: StructuralQuery,
1000 lowered: LoweredBaseQueryShape,
1001) -> StructuralQuery {
1002 if let Some(predicate) = lowered.predicate {
1003 query = query.filter(predicate);
1004 }
1005 query = apply_order_terms_structural(query, lowered.order_by);
1006 if let Some(limit) = lowered.limit {
1007 query = query.limit(limit);
1008 }
1009 if let Some(offset) = lowered.offset {
1010 query = query.offset(offset);
1011 }
1012
1013 query
1014}
1015
1016pub(in crate::db) fn bind_lowered_sql_query_structural(
1017 model: &'static crate::model::entity::EntityModel,
1018 lowered: LoweredSqlQuery,
1019 consistency: MissingRowPolicy,
1020) -> Result<StructuralQuery, SqlLoweringError> {
1021 match lowered {
1022 LoweredSqlQuery::Select(select) => {
1023 apply_lowered_select_shape(StructuralQuery::new(model, consistency), select)
1024 }
1025 LoweredSqlQuery::Delete(delete) => Ok(bind_lowered_sql_delete_query_structural(
1026 model,
1027 delete,
1028 consistency,
1029 )),
1030 }
1031}
1032
1033pub(in crate::db) fn bind_lowered_sql_delete_query_structural(
1034 model: &'static crate::model::entity::EntityModel,
1035 delete: LoweredBaseQueryShape,
1036 consistency: MissingRowPolicy,
1037) -> StructuralQuery {
1038 apply_lowered_base_query_shape(StructuralQuery::new(model, consistency).delete(), delete)
1039}
1040
1041pub(in crate::db) fn bind_lowered_sql_query<E: EntityKind>(
1042 lowered: LoweredSqlQuery,
1043 consistency: MissingRowPolicy,
1044) -> Result<Query<E>, SqlLoweringError> {
1045 let structural = bind_lowered_sql_query_structural(E::MODEL, lowered, consistency)?;
1046
1047 Ok(Query::from_inner(structural))
1048}
1049
1050fn bind_lowered_sql_global_aggregate_command<E: EntityKind>(
1051 lowered: LoweredSqlGlobalAggregateCommand,
1052 consistency: MissingRowPolicy,
1053) -> SqlGlobalAggregateCommand<E> {
1054 SqlGlobalAggregateCommand {
1055 query: Query::from_inner(apply_lowered_base_query_shape(
1056 StructuralQuery::new(E::MODEL, consistency),
1057 lowered.query,
1058 )),
1059 terminal: lowered.terminal,
1060 }
1061}
1062
1063fn bind_lowered_sql_global_aggregate_command_structural(
1064 model: &'static crate::model::entity::EntityModel,
1065 lowered: LoweredSqlGlobalAggregateCommand,
1066 consistency: MissingRowPolicy,
1067) -> SqlGlobalAggregateCommandCore {
1068 SqlGlobalAggregateCommandCore {
1069 query: apply_lowered_base_query_shape(
1070 StructuralQuery::new(model, consistency),
1071 lowered.query,
1072 ),
1073 terminal: lowered.terminal,
1074 }
1075}
1076
1077fn lower_global_aggregate_terminal(
1078 projection: SqlProjection,
1079) -> Result<SqlGlobalAggregateTerminal, SqlLoweringError> {
1080 let SqlProjection::Items(items) = projection else {
1081 return Err(SqlLoweringError::unsupported_select_projection());
1082 };
1083 if items.len() != 1 {
1084 return Err(SqlLoweringError::unsupported_select_projection());
1085 }
1086
1087 let Some(SqlSelectItem::Aggregate(aggregate)) = items.into_iter().next() else {
1088 return Err(SqlLoweringError::unsupported_select_projection());
1089 };
1090
1091 match lower_sql_aggregate_shape(aggregate)? {
1092 LoweredSqlAggregateShape::CountRows => Ok(SqlGlobalAggregateTerminal::CountRows),
1093 LoweredSqlAggregateShape::CountField(field) => {
1094 Ok(SqlGlobalAggregateTerminal::CountField(field))
1095 }
1096 LoweredSqlAggregateShape::FieldTarget {
1097 kind: SqlAggregateKind::Sum,
1098 field,
1099 } => Ok(SqlGlobalAggregateTerminal::SumField(field)),
1100 LoweredSqlAggregateShape::FieldTarget {
1101 kind: SqlAggregateKind::Avg,
1102 field,
1103 } => Ok(SqlGlobalAggregateTerminal::AvgField(field)),
1104 LoweredSqlAggregateShape::FieldTarget {
1105 kind: SqlAggregateKind::Min,
1106 field,
1107 } => Ok(SqlGlobalAggregateTerminal::MinField(field)),
1108 LoweredSqlAggregateShape::FieldTarget {
1109 kind: SqlAggregateKind::Max,
1110 field,
1111 } => Ok(SqlGlobalAggregateTerminal::MaxField(field)),
1112 LoweredSqlAggregateShape::FieldTarget {
1113 kind: SqlAggregateKind::Count,
1114 ..
1115 } => Err(SqlLoweringError::unsupported_select_projection()),
1116 }
1117}
1118
1119fn lower_sql_aggregate_shape(
1120 call: SqlAggregateCall,
1121) -> Result<LoweredSqlAggregateShape, SqlLoweringError> {
1122 match (call.kind, call.field) {
1123 (SqlAggregateKind::Count, None) => Ok(LoweredSqlAggregateShape::CountRows),
1124 (SqlAggregateKind::Count, Some(field)) => Ok(LoweredSqlAggregateShape::CountField(field)),
1125 (
1126 kind @ (SqlAggregateKind::Sum
1127 | SqlAggregateKind::Avg
1128 | SqlAggregateKind::Min
1129 | SqlAggregateKind::Max),
1130 Some(field),
1131 ) => Ok(LoweredSqlAggregateShape::FieldTarget { kind, field }),
1132 _ => Err(SqlLoweringError::unsupported_select_projection()),
1133 }
1134}
1135
1136fn grouped_projection_aggregate_calls(
1137 projection: &SqlProjection,
1138 group_by_fields: &[String],
1139) -> Result<Vec<SqlAggregateCall>, SqlLoweringError> {
1140 if group_by_fields.is_empty() {
1141 return Err(SqlLoweringError::unsupported_select_group_by());
1142 }
1143
1144 let SqlProjection::Items(items) = projection else {
1145 return Err(SqlLoweringError::unsupported_select_group_by());
1146 };
1147
1148 let mut projected_group_fields = Vec::<String>::new();
1149 let mut aggregate_calls = Vec::<SqlAggregateCall>::new();
1150 let mut seen_aggregate = false;
1151
1152 for item in items {
1153 match item {
1154 SqlSelectItem::Field(field) => {
1155 if seen_aggregate {
1158 return Err(SqlLoweringError::unsupported_select_group_by());
1159 }
1160 projected_group_fields.push(field.clone());
1161 }
1162 SqlSelectItem::Aggregate(aggregate) => {
1163 seen_aggregate = true;
1164 aggregate_calls.push(aggregate.clone());
1165 }
1166 SqlSelectItem::TextFunction(_) => {
1167 return Err(SqlLoweringError::unsupported_select_group_by());
1168 }
1169 }
1170 }
1171
1172 if aggregate_calls.is_empty() || projected_group_fields.as_slice() != group_by_fields {
1173 return Err(SqlLoweringError::unsupported_select_group_by());
1174 }
1175
1176 Ok(aggregate_calls)
1177}
1178
1179fn lower_aggregate_call(
1180 call: SqlAggregateCall,
1181) -> Result<crate::db::query::builder::AggregateExpr, SqlLoweringError> {
1182 match lower_sql_aggregate_shape(call)? {
1183 LoweredSqlAggregateShape::CountRows => Ok(count()),
1184 LoweredSqlAggregateShape::CountField(field) => Ok(count_by(field)),
1185 LoweredSqlAggregateShape::FieldTarget {
1186 kind: SqlAggregateKind::Sum,
1187 field,
1188 } => Ok(sum(field)),
1189 LoweredSqlAggregateShape::FieldTarget {
1190 kind: SqlAggregateKind::Avg,
1191 field,
1192 } => Ok(avg(field)),
1193 LoweredSqlAggregateShape::FieldTarget {
1194 kind: SqlAggregateKind::Min,
1195 field,
1196 } => Ok(min_by(field)),
1197 LoweredSqlAggregateShape::FieldTarget {
1198 kind: SqlAggregateKind::Max,
1199 field,
1200 } => Ok(max_by(field)),
1201 LoweredSqlAggregateShape::FieldTarget {
1202 kind: SqlAggregateKind::Count,
1203 ..
1204 } => Err(SqlLoweringError::unsupported_select_projection()),
1205 }
1206}
1207
1208fn resolve_having_aggregate_index(
1209 target: &SqlAggregateCall,
1210 grouped_projection_aggregates: &[SqlAggregateCall],
1211) -> Result<usize, SqlLoweringError> {
1212 let mut matched = grouped_projection_aggregates
1213 .iter()
1214 .enumerate()
1215 .filter_map(|(index, aggregate)| (aggregate == target).then_some(index));
1216 let Some(index) = matched.next() else {
1217 return Err(SqlLoweringError::unsupported_select_having());
1218 };
1219 if matched.next().is_some() {
1220 return Err(SqlLoweringError::unsupported_select_having());
1221 }
1222
1223 Ok(index)
1224}
1225
1226fn lower_delete_shape(statement: SqlDeleteStatement) -> LoweredBaseQueryShape {
1227 let SqlDeleteStatement {
1228 predicate,
1229 order_by,
1230 limit,
1231 entity: _,
1232 } = statement;
1233
1234 LoweredBaseQueryShape {
1235 predicate,
1236 order_by,
1237 limit,
1238 offset: None,
1239 }
1240}
1241
1242fn apply_order_terms_structural(
1243 mut query: StructuralQuery,
1244 order_by: Vec<crate::db::sql::parser::SqlOrderTerm>,
1245) -> StructuralQuery {
1246 for term in order_by {
1247 query = match term.direction {
1248 SqlOrderDirection::Asc => query.order_by(term.field),
1249 SqlOrderDirection::Desc => query.order_by_desc(term.field),
1250 };
1251 }
1252
1253 query
1254}
1255
1256fn normalize_having_clauses(
1257 clauses: Vec<SqlHavingClause>,
1258 entity_scope: &[String],
1259) -> Vec<SqlHavingClause> {
1260 clauses
1261 .into_iter()
1262 .map(|clause| SqlHavingClause {
1263 symbol: normalize_having_symbol(clause.symbol, entity_scope),
1264 op: clause.op,
1265 value: clause.value,
1266 })
1267 .collect()
1268}
1269
1270fn normalize_having_symbol(symbol: SqlHavingSymbol, entity_scope: &[String]) -> SqlHavingSymbol {
1271 match symbol {
1272 SqlHavingSymbol::Field(field) => {
1273 SqlHavingSymbol::Field(normalize_identifier_to_scope(field, entity_scope))
1274 }
1275 SqlHavingSymbol::Aggregate(aggregate) => SqlHavingSymbol::Aggregate(
1276 normalize_aggregate_call_identifiers(aggregate, entity_scope),
1277 ),
1278 }
1279}
1280
1281fn normalize_aggregate_call_identifiers(
1282 aggregate: SqlAggregateCall,
1283 entity_scope: &[String],
1284) -> SqlAggregateCall {
1285 SqlAggregateCall {
1286 kind: aggregate.kind,
1287 field: aggregate
1288 .field
1289 .map(|field| normalize_identifier_to_scope(field, entity_scope)),
1290 }
1291}
1292
1293fn sql_entity_scope_candidates(sql_entity: &str, expected_entity: &'static str) -> Vec<String> {
1296 let mut out = Vec::new();
1297 out.push(sql_entity.to_string());
1298 out.push(expected_entity.to_string());
1299
1300 if let Some(last) = identifier_last_segment(sql_entity) {
1301 out.push(last.to_string());
1302 }
1303 if let Some(last) = identifier_last_segment(expected_entity) {
1304 out.push(last.to_string());
1305 }
1306
1307 out
1308}
1309
1310fn normalize_projection_identifiers(
1311 projection: SqlProjection,
1312 entity_scope: &[String],
1313) -> SqlProjection {
1314 match projection {
1315 SqlProjection::All => SqlProjection::All,
1316 SqlProjection::Items(items) => SqlProjection::Items(
1317 items
1318 .into_iter()
1319 .map(|item| match item {
1320 SqlSelectItem::Field(field) => {
1321 SqlSelectItem::Field(normalize_identifier(field, entity_scope))
1322 }
1323 SqlSelectItem::Aggregate(aggregate) => {
1324 SqlSelectItem::Aggregate(SqlAggregateCall {
1325 kind: aggregate.kind,
1326 field: aggregate
1327 .field
1328 .map(|field| normalize_identifier(field, entity_scope)),
1329 })
1330 }
1331 SqlSelectItem::TextFunction(SqlTextFunctionCall {
1332 function,
1333 field,
1334 literal,
1335 literal2,
1336 literal3,
1337 }) => SqlSelectItem::TextFunction(SqlTextFunctionCall {
1338 function,
1339 field: normalize_identifier(field, entity_scope),
1340 literal,
1341 literal2,
1342 literal3,
1343 }),
1344 })
1345 .collect(),
1346 ),
1347 }
1348}
1349
1350fn normalize_order_terms(
1351 terms: Vec<crate::db::sql::parser::SqlOrderTerm>,
1352 entity_scope: &[String],
1353) -> Vec<crate::db::sql::parser::SqlOrderTerm> {
1354 terms
1355 .into_iter()
1356 .map(|term| crate::db::sql::parser::SqlOrderTerm {
1357 field: normalize_order_term_identifier(term.field, entity_scope),
1358 direction: term.direction,
1359 })
1360 .collect()
1361}
1362
1363fn normalize_order_term_identifier(identifier: String, entity_scope: &[String]) -> String {
1364 let Some(expression) = ExpressionOrderTerm::parse(identifier.as_str()) else {
1365 return normalize_identifier(identifier, entity_scope);
1366 };
1367 let normalized_field = normalize_identifier(expression.field().to_string(), entity_scope);
1368
1369 expression.canonical_text_with_field(normalized_field.as_str())
1370}
1371
1372fn normalize_identifier_list(fields: Vec<String>, entity_scope: &[String]) -> Vec<String> {
1373 fields
1374 .into_iter()
1375 .map(|field| normalize_identifier(field, entity_scope))
1376 .collect()
1377}
1378
1379fn adapt_predicate_identifiers_to_scope(
1382 predicate: Predicate,
1383 entity_scope: &[String],
1384) -> Predicate {
1385 rewrite_field_identifiers(predicate, |field| normalize_identifier(field, entity_scope))
1386}
1387
1388fn normalize_identifier(identifier: String, entity_scope: &[String]) -> String {
1389 normalize_identifier_to_scope(identifier, entity_scope)
1390}
1391
1392fn ensure_entity_matches_expected(
1393 sql_entity: &str,
1394 expected_entity: &'static str,
1395) -> Result<(), SqlLoweringError> {
1396 if identifiers_tail_match(sql_entity, expected_entity) {
1397 return Ok(());
1398 }
1399
1400 Err(SqlLoweringError::entity_mismatch(
1401 sql_entity,
1402 expected_entity,
1403 ))
1404}