1#[cfg(test)]
11mod tests;
12
13use crate::{
14 db::{
15 predicate::{MissingRowPolicy, Predicate},
16 query::{
17 builder::aggregate::{avg, count, count_by, max_by, min_by, sum},
18 intent::{Query, QueryError, StructuralQuery},
19 },
20 sql::identifier::{
21 identifier_last_segment, identifiers_tail_match, normalize_identifier_to_scope,
22 rewrite_field_identifiers,
23 },
24 sql::parser::{
25 SqlAggregateCall, SqlAggregateKind, SqlDeleteStatement, SqlExplainMode,
26 SqlExplainStatement, SqlExplainTarget, SqlHavingClause, SqlHavingSymbol,
27 SqlOrderDirection, SqlOrderTerm, SqlProjection, SqlSelectItem, SqlSelectStatement,
28 SqlStatement, SqlTextFunctionCall, parse_sql,
29 },
30 },
31 traits::EntityKind,
32};
33use thiserror::Error as ThisError;
34
35#[derive(Clone, Debug)]
44pub struct LoweredSqlCommand(LoweredSqlCommandInner);
45
46#[derive(Clone, Debug)]
47enum LoweredSqlCommandInner {
48 Query(LoweredSqlQuery),
49 Explain {
50 mode: SqlExplainMode,
51 query: LoweredSqlQuery,
52 },
53 ExplainGlobalAggregate {
54 mode: SqlExplainMode,
55 command: LoweredSqlGlobalAggregateCommand,
56 },
57 DescribeEntity,
58 ShowIndexesEntity,
59 ShowColumnsEntity,
60 ShowEntities,
61}
62
/// Test-only typed counterpart of `LoweredSqlCommand`, with queries bound to entity `E`.
#[cfg(test)]
#[derive(Debug)]
pub(crate) enum SqlCommand<E>
where
    E: EntityKind,
{
    /// A bound `SELECT` or `DELETE` query.
    Query(Query<E>),
    /// `EXPLAIN` over a bound query.
    Explain {
        mode: SqlExplainMode,
        query: Query<E>,
    },
    /// `EXPLAIN` over a bound global aggregate command.
    ExplainGlobalAggregate {
        mode: SqlExplainMode,
        command: SqlGlobalAggregateCommand<E>,
    },
    /// `Describe` statement marker.
    DescribeEntity,
    /// `ShowIndexes` statement marker.
    ShowIndexesEntity,
    /// `ShowColumns` statement marker.
    ShowColumnsEntity,
    /// `ShowEntities` statement marker.
    ShowEntities,
}
87
88impl LoweredSqlCommand {
89 #[must_use]
90 pub(in crate::db) const fn query(&self) -> Option<&LoweredSqlQuery> {
91 match &self.0 {
92 LoweredSqlCommandInner::Query(query) => Some(query),
93 LoweredSqlCommandInner::Explain { .. }
94 | LoweredSqlCommandInner::ExplainGlobalAggregate { .. }
95 | LoweredSqlCommandInner::DescribeEntity
96 | LoweredSqlCommandInner::ShowIndexesEntity
97 | LoweredSqlCommandInner::ShowColumnsEntity
98 | LoweredSqlCommandInner::ShowEntities => None,
99 }
100 }
101}
102
103#[derive(Clone, Debug)]
110pub(crate) enum LoweredSqlQuery {
111 Select(LoweredSelectShape),
112 Delete(LoweredBaseQueryShape),
113}
114
115impl LoweredSqlQuery {
116 pub(crate) const fn has_grouping(&self) -> bool {
118 match self {
119 Self::Select(select) => select.has_grouping(),
120 Self::Delete(_) => false,
121 }
122 }
123}
124
/// Terminal aggregate of a global (ungrouped) SQL aggregate query.
///
/// Field-carrying variants hold the aggregate's target field name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum SqlGlobalAggregateTerminal {
    /// COUNT with no target field.
    CountRows,
    /// COUNT over the named field.
    CountField(String),
    /// SUM over the named field.
    SumField(String),
    /// AVG over the named field.
    AvgField(String),
    /// MIN over the named field.
    MinField(String),
    /// MAX over the named field.
    MaxField(String),
}
141
142#[derive(Clone, Debug)]
151pub(crate) struct LoweredSqlGlobalAggregateCommand {
152 query: LoweredBaseQueryShape,
153 terminal: SqlGlobalAggregateTerminal,
154}
155
156enum LoweredSqlAggregateShape {
164 CountRows,
165 CountField(String),
166 FieldTarget {
167 kind: SqlAggregateKind,
168 field: String,
169 },
170}
171
172#[derive(Debug)]
179pub(crate) struct SqlGlobalAggregateCommand<E: EntityKind> {
180 query: Query<E>,
181 terminal: SqlGlobalAggregateTerminal,
182}
183
184impl<E: EntityKind> SqlGlobalAggregateCommand<E> {
185 #[must_use]
187 pub(crate) const fn query(&self) -> &Query<E> {
188 &self.query
189 }
190
191 #[must_use]
193 pub(crate) const fn terminal(&self) -> &SqlGlobalAggregateTerminal {
194 &self.terminal
195 }
196}
197
198#[derive(Debug)]
208pub(crate) struct SqlGlobalAggregateCommandCore {
209 query: StructuralQuery,
210 terminal: SqlGlobalAggregateTerminal,
211}
212
213impl SqlGlobalAggregateCommandCore {
214 #[must_use]
216 pub(in crate::db) const fn query(&self) -> &StructuralQuery {
217 &self.query
218 }
219
220 #[must_use]
222 pub(in crate::db) const fn terminal(&self) -> &SqlGlobalAggregateTerminal {
223 &self.terminal
224 }
225}
226
227#[derive(Debug, ThisError)]
234pub(crate) enum SqlLoweringError {
235 #[error("{0}")]
236 Parse(#[from] crate::db::sql::parser::SqlParseError),
237
238 #[error("{0}")]
239 Query(#[from] QueryError),
240
241 #[error("SQL entity '{sql_entity}' does not match requested entity type '{expected_entity}'")]
242 EntityMismatch {
243 sql_entity: String,
244 expected_entity: &'static str,
245 },
246
247 #[error(
248 "unsupported SQL SELECT projection; supported forms are SELECT *, field lists, or grouped aggregate shapes"
249 )]
250 UnsupportedSelectProjection,
251
252 #[error("unsupported SQL SELECT DISTINCT")]
253 UnsupportedSelectDistinct,
254
255 #[error("unsupported SQL GROUP BY projection shape")]
256 UnsupportedSelectGroupBy,
257
258 #[error("unsupported SQL HAVING shape")]
259 UnsupportedSelectHaving,
260
261 #[error(
262 "generated SQL query dispatch requires SELECT, DELETE, EXPLAIN SELECT, or EXPLAIN DELETE"
263 )]
264 UnsupportedQuerySurfaceStatement,
265}
266
267impl SqlLoweringError {
268 fn entity_mismatch(sql_entity: impl Into<String>, expected_entity: &'static str) -> Self {
270 Self::EntityMismatch {
271 sql_entity: sql_entity.into(),
272 expected_entity,
273 }
274 }
275
276 const fn unsupported_select_projection() -> Self {
278 Self::UnsupportedSelectProjection
279 }
280
281 const fn unsupported_select_distinct() -> Self {
283 Self::UnsupportedSelectDistinct
284 }
285
286 const fn unsupported_select_group_by() -> Self {
288 Self::UnsupportedSelectGroupBy
289 }
290
291 const fn unsupported_select_having() -> Self {
293 Self::UnsupportedSelectHaving
294 }
295
296 const fn unsupported_query_surface_statement() -> Self {
298 Self::UnsupportedQuerySurfaceStatement
299 }
300}
301
302#[derive(Clone, Debug)]
313pub(crate) struct PreparedSqlStatement {
314 statement: SqlStatement,
315}
316
/// Coarse classification of a lowered command, as produced by `lowered_sql_command_lane`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum LoweredSqlLaneKind {
    /// Plain SELECT/DELETE query.
    Query,
    /// Any EXPLAIN form, including EXPLAIN over a global aggregate.
    Explain,
    /// DESCRIBE.
    Describe,
    /// SHOW INDEXES.
    ShowIndexes,
    /// SHOW COLUMNS.
    ShowColumns,
    /// SHOW ENTITIES.
    ShowEntities,
}
326
/// Test helper: parses `sql` and compiles it into a typed `SqlCommand` for entity `E`.
#[cfg(test)]
pub(crate) fn compile_sql_command<E: EntityKind>(
    sql: &str,
    consistency: MissingRowPolicy,
) -> Result<SqlCommand<E>, SqlLoweringError> {
    compile_sql_command_from_statement::<E>(parse_sql(sql)?, consistency)
}
336
/// Test helper: prepares an already-parsed statement against `E`'s entity
/// name, then lowers and binds it.
#[cfg(test)]
pub(crate) fn compile_sql_command_from_statement<E: EntityKind>(
    statement: SqlStatement,
    consistency: MissingRowPolicy,
) -> Result<SqlCommand<E>, SqlLoweringError> {
    let entity_name = E::MODEL.name();

    compile_sql_command_from_prepared_statement::<E>(
        prepare_sql_statement(statement, entity_name)?,
        consistency,
    )
}
346
/// Test helper: lowers a prepared statement using `E`'s primary key, then
/// binds the result to `E`.
#[cfg(test)]
pub(crate) fn compile_sql_command_from_prepared_statement<E: EntityKind>(
    prepared: PreparedSqlStatement,
    consistency: MissingRowPolicy,
) -> Result<SqlCommand<E>, SqlLoweringError> {
    let primary_key = E::MODEL.primary_key.name;
    let lowered_command = lower_sql_command_from_prepared_statement(prepared, primary_key)?;

    bind_lowered_sql_command::<E>(lowered_command, consistency)
}
357
358#[inline(never)]
360pub(crate) fn lower_sql_command_from_prepared_statement(
361 prepared: PreparedSqlStatement,
362 primary_key_field: &str,
363) -> Result<LoweredSqlCommand, SqlLoweringError> {
364 lower_prepared_statement(prepared.statement, primary_key_field)
365}
366
367#[inline(never)]
369pub(crate) fn prepare_query_surface_statement(
370 statement: SqlStatement,
371 expected_entity: &'static str,
372) -> Result<PreparedSqlStatement, SqlLoweringError> {
373 let prepared = prepare_sql_statement(statement, expected_entity)?;
374
375 match prepared.statement {
376 SqlStatement::Select(_) | SqlStatement::Delete(_) | SqlStatement::Explain(_) => {
377 Ok(prepared)
378 }
379 SqlStatement::Describe(_)
380 | SqlStatement::ShowIndexes(_)
381 | SqlStatement::ShowColumns(_)
382 | SqlStatement::ShowEntities(_) => {
383 Err(SqlLoweringError::unsupported_query_surface_statement())
384 }
385 }
386}
387
388#[inline(never)]
390pub(crate) fn lower_query_surface_command_from_prepared_statement(
391 prepared: PreparedSqlStatement,
392 primary_key_field: &str,
393) -> Result<LoweredSqlCommand, SqlLoweringError> {
394 let lowered = lower_sql_command_from_prepared_statement(prepared, primary_key_field)?;
395
396 match lowered_sql_command_lane(&lowered) {
397 LoweredSqlLaneKind::Query | LoweredSqlLaneKind::Explain => Ok(lowered),
398 LoweredSqlLaneKind::Describe
399 | LoweredSqlLaneKind::ShowIndexes
400 | LoweredSqlLaneKind::ShowColumns
401 | LoweredSqlLaneKind::ShowEntities => {
402 Err(SqlLoweringError::unsupported_query_surface_statement())
403 }
404 }
405}
406
407pub(crate) const fn lowered_sql_command_lane(command: &LoweredSqlCommand) -> LoweredSqlLaneKind {
408 match command.0 {
409 LoweredSqlCommandInner::Query(_) => LoweredSqlLaneKind::Query,
410 LoweredSqlCommandInner::Explain { .. }
411 | LoweredSqlCommandInner::ExplainGlobalAggregate { .. } => LoweredSqlLaneKind::Explain,
412 LoweredSqlCommandInner::DescribeEntity => LoweredSqlLaneKind::Describe,
413 LoweredSqlCommandInner::ShowIndexesEntity => LoweredSqlLaneKind::ShowIndexes,
414 LoweredSqlCommandInner::ShowColumnsEntity => LoweredSqlLaneKind::ShowColumns,
415 LoweredSqlCommandInner::ShowEntities => LoweredSqlLaneKind::ShowEntities,
416 }
417}
418
419pub(in crate::db) fn is_sql_global_aggregate_statement(statement: &SqlStatement) -> bool {
422 let SqlStatement::Select(statement) = statement else {
423 return false;
424 };
425
426 is_sql_global_aggregate_select(statement)
427}
428
429fn is_sql_global_aggregate_select(statement: &SqlSelectStatement) -> bool {
432 if statement.distinct || !statement.group_by.is_empty() || !statement.having.is_empty() {
433 return false;
434 }
435
436 lower_global_aggregate_terminal(statement.projection.clone()).is_ok()
437}
438
439#[inline(never)]
441pub(crate) fn render_lowered_sql_explain_plan_or_json(
442 lowered: &LoweredSqlCommand,
443 model: &'static crate::model::entity::EntityModel,
444 consistency: MissingRowPolicy,
445) -> Result<Option<String>, SqlLoweringError> {
446 let LoweredSqlCommandInner::Explain { mode, query } = &lowered.0 else {
447 return Ok(None);
448 };
449
450 let query = bind_lowered_sql_query_structural(model, query.clone(), consistency)?;
451 let rendered = match mode {
452 SqlExplainMode::Plan | SqlExplainMode::Json => {
453 let plan = query.build_plan()?;
454 let explain = plan.explain_with_model(model);
455
456 match mode {
457 SqlExplainMode::Plan => explain.render_text_canonical(),
458 SqlExplainMode::Json => explain.render_json_canonical(),
459 SqlExplainMode::Execution => unreachable!("execution mode handled above"),
460 }
461 }
462 SqlExplainMode::Execution => query.explain_execution_text()?,
463 };
464
465 Ok(Some(rendered))
466}
467
468pub(crate) fn bind_lowered_sql_explain_global_aggregate_structural(
471 lowered: &LoweredSqlCommand,
472 model: &'static crate::model::entity::EntityModel,
473 consistency: MissingRowPolicy,
474) -> Option<(SqlExplainMode, SqlGlobalAggregateCommandCore)> {
475 let LoweredSqlCommandInner::ExplainGlobalAggregate { mode, command } = &lowered.0 else {
476 return None;
477 };
478
479 Some((
480 *mode,
481 bind_lowered_sql_global_aggregate_command_structural(model, command.clone(), consistency),
482 ))
483}
484
/// Test helper: binds every query inside a lowered command to entity `E`.
#[cfg(test)]
pub(crate) fn bind_lowered_sql_command<E: EntityKind>(
    lowered: LoweredSqlCommand,
    consistency: MissingRowPolicy,
) -> Result<SqlCommand<E>, SqlLoweringError> {
    let LoweredSqlCommand(inner) = lowered;

    let command = match inner {
        LoweredSqlCommandInner::Query(query) => {
            SqlCommand::Query(bind_lowered_sql_query::<E>(query, consistency)?)
        }
        LoweredSqlCommandInner::Explain { mode, query } => {
            let query = bind_lowered_sql_query::<E>(query, consistency)?;
            SqlCommand::Explain { mode, query }
        }
        LoweredSqlCommandInner::ExplainGlobalAggregate { mode, command } => {
            let command = bind_lowered_sql_global_aggregate_command::<E>(command, consistency);
            SqlCommand::ExplainGlobalAggregate { mode, command }
        }
        LoweredSqlCommandInner::DescribeEntity => SqlCommand::DescribeEntity,
        LoweredSqlCommandInner::ShowIndexesEntity => SqlCommand::ShowIndexesEntity,
        LoweredSqlCommandInner::ShowColumnsEntity => SqlCommand::ShowColumnsEntity,
        LoweredSqlCommandInner::ShowEntities => SqlCommand::ShowEntities,
    };

    Ok(command)
}
512
513#[inline(never)]
515pub(crate) fn prepare_sql_statement(
516 statement: SqlStatement,
517 expected_entity: &'static str,
518) -> Result<PreparedSqlStatement, SqlLoweringError> {
519 let statement = prepare_statement(statement, expected_entity)?;
520
521 Ok(PreparedSqlStatement { statement })
522}
523
524pub(crate) fn compile_sql_global_aggregate_command<E: EntityKind>(
526 sql: &str,
527 consistency: MissingRowPolicy,
528) -> Result<SqlGlobalAggregateCommand<E>, SqlLoweringError> {
529 let statement = parse_sql(sql)?;
530 let prepared = prepare_sql_statement(statement, E::MODEL.name())?;
531 compile_sql_global_aggregate_command_from_prepared::<E>(prepared, consistency)
532}
533
534fn compile_sql_global_aggregate_command_from_prepared<E: EntityKind>(
535 prepared: PreparedSqlStatement,
536 consistency: MissingRowPolicy,
537) -> Result<SqlGlobalAggregateCommand<E>, SqlLoweringError> {
538 let SqlStatement::Select(statement) = prepared.statement else {
539 return Err(SqlLoweringError::unsupported_select_projection());
540 };
541
542 Ok(bind_lowered_sql_global_aggregate_command::<E>(
543 lower_global_aggregate_select_shape(statement)?,
544 consistency,
545 ))
546}
547
548#[inline(never)]
549fn prepare_statement(
550 statement: SqlStatement,
551 expected_entity: &'static str,
552) -> Result<SqlStatement, SqlLoweringError> {
553 match statement {
554 SqlStatement::Select(statement) => Ok(SqlStatement::Select(prepare_select_statement(
555 statement,
556 expected_entity,
557 )?)),
558 SqlStatement::Delete(statement) => Ok(SqlStatement::Delete(prepare_delete_statement(
559 statement,
560 expected_entity,
561 )?)),
562 SqlStatement::Explain(statement) => Ok(SqlStatement::Explain(prepare_explain_statement(
563 statement,
564 expected_entity,
565 )?)),
566 SqlStatement::Describe(statement) => {
567 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
568
569 Ok(SqlStatement::Describe(statement))
570 }
571 SqlStatement::ShowIndexes(statement) => {
572 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
573
574 Ok(SqlStatement::ShowIndexes(statement))
575 }
576 SqlStatement::ShowColumns(statement) => {
577 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
578
579 Ok(SqlStatement::ShowColumns(statement))
580 }
581 SqlStatement::ShowEntities(statement) => Ok(SqlStatement::ShowEntities(statement)),
582 }
583}
584
585fn prepare_explain_statement(
586 statement: SqlExplainStatement,
587 expected_entity: &'static str,
588) -> Result<SqlExplainStatement, SqlLoweringError> {
589 let target = match statement.statement {
590 SqlExplainTarget::Select(select_statement) => {
591 SqlExplainTarget::Select(prepare_select_statement(select_statement, expected_entity)?)
592 }
593 SqlExplainTarget::Delete(delete_statement) => {
594 SqlExplainTarget::Delete(prepare_delete_statement(delete_statement, expected_entity)?)
595 }
596 };
597
598 Ok(SqlExplainStatement {
599 mode: statement.mode,
600 statement: target,
601 })
602}
603
604fn prepare_select_statement(
605 statement: SqlSelectStatement,
606 expected_entity: &'static str,
607) -> Result<SqlSelectStatement, SqlLoweringError> {
608 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
609
610 Ok(normalize_select_statement_to_expected_entity(
611 statement,
612 expected_entity,
613 ))
614}
615
616fn normalize_select_statement_to_expected_entity(
617 mut statement: SqlSelectStatement,
618 expected_entity: &'static str,
619) -> SqlSelectStatement {
620 let entity_scope = sql_entity_scope_candidates(statement.entity.as_str(), expected_entity);
623 statement.projection =
624 normalize_projection_identifiers(statement.projection, entity_scope.as_slice());
625 statement.group_by = normalize_identifier_list(statement.group_by, entity_scope.as_slice());
626 statement.predicate = statement
627 .predicate
628 .map(|predicate| adapt_predicate_identifiers_to_scope(predicate, entity_scope.as_slice()));
629 statement.order_by = normalize_order_terms(statement.order_by, entity_scope.as_slice());
630 statement.having = normalize_having_clauses(statement.having, entity_scope.as_slice());
631
632 statement
633}
634
635fn prepare_delete_statement(
636 mut statement: SqlDeleteStatement,
637 expected_entity: &'static str,
638) -> Result<SqlDeleteStatement, SqlLoweringError> {
639 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
640 let entity_scope = sql_entity_scope_candidates(statement.entity.as_str(), expected_entity);
641 statement.predicate = statement
642 .predicate
643 .map(|predicate| adapt_predicate_identifiers_to_scope(predicate, entity_scope.as_slice()));
644 statement.order_by = normalize_order_terms(statement.order_by, entity_scope.as_slice());
645
646 Ok(statement)
647}
648
649#[inline(never)]
650fn lower_prepared_statement(
651 statement: SqlStatement,
652 primary_key_field: &str,
653) -> Result<LoweredSqlCommand, SqlLoweringError> {
654 match statement {
655 SqlStatement::Select(statement) => Ok(LoweredSqlCommand(LoweredSqlCommandInner::Query(
656 LoweredSqlQuery::Select(lower_select_shape(statement, primary_key_field)?),
657 ))),
658 SqlStatement::Delete(statement) => Ok(LoweredSqlCommand(LoweredSqlCommandInner::Query(
659 LoweredSqlQuery::Delete(lower_delete_shape(statement)),
660 ))),
661 SqlStatement::Explain(statement) => lower_explain_prepared(statement, primary_key_field),
662 SqlStatement::Describe(_) => Ok(LoweredSqlCommand(LoweredSqlCommandInner::DescribeEntity)),
663 SqlStatement::ShowIndexes(_) => {
664 Ok(LoweredSqlCommand(LoweredSqlCommandInner::ShowIndexesEntity))
665 }
666 SqlStatement::ShowColumns(_) => {
667 Ok(LoweredSqlCommand(LoweredSqlCommandInner::ShowColumnsEntity))
668 }
669 SqlStatement::ShowEntities(_) => {
670 Ok(LoweredSqlCommand(LoweredSqlCommandInner::ShowEntities))
671 }
672 }
673}
674
675fn lower_explain_prepared(
676 statement: SqlExplainStatement,
677 primary_key_field: &str,
678) -> Result<LoweredSqlCommand, SqlLoweringError> {
679 let mode = statement.mode;
680
681 match statement.statement {
682 SqlExplainTarget::Select(select_statement) => {
683 lower_explain_select_prepared(select_statement, mode, primary_key_field)
684 }
685 SqlExplainTarget::Delete(delete_statement) => {
686 Ok(LoweredSqlCommand(LoweredSqlCommandInner::Explain {
687 mode,
688 query: LoweredSqlQuery::Delete(lower_delete_shape(delete_statement)),
689 }))
690 }
691 }
692}
693
694fn lower_explain_select_prepared(
695 statement: SqlSelectStatement,
696 mode: SqlExplainMode,
697 primary_key_field: &str,
698) -> Result<LoweredSqlCommand, SqlLoweringError> {
699 match lower_select_shape(statement.clone(), primary_key_field) {
700 Ok(query) => Ok(LoweredSqlCommand(LoweredSqlCommandInner::Explain {
701 mode,
702 query: LoweredSqlQuery::Select(query),
703 })),
704 Err(SqlLoweringError::UnsupportedSelectProjection) => {
705 let command = lower_global_aggregate_select_shape(statement)?;
706
707 Ok(LoweredSqlCommand(
708 LoweredSqlCommandInner::ExplainGlobalAggregate { mode, command },
709 ))
710 }
711 Err(err) => Err(err),
712 }
713}
714
715fn lower_global_aggregate_select_shape(
716 statement: SqlSelectStatement,
717) -> Result<LoweredSqlGlobalAggregateCommand, SqlLoweringError> {
718 let SqlSelectStatement {
719 projection,
720 predicate,
721 distinct,
722 group_by,
723 having,
724 order_by,
725 limit,
726 offset,
727 entity: _,
728 } = statement;
729
730 if distinct {
731 return Err(SqlLoweringError::unsupported_select_distinct());
732 }
733 if !group_by.is_empty() {
734 return Err(SqlLoweringError::unsupported_select_group_by());
735 }
736 if !having.is_empty() {
737 return Err(SqlLoweringError::unsupported_select_having());
738 }
739
740 let terminal = lower_global_aggregate_terminal(projection)?;
741
742 Ok(LoweredSqlGlobalAggregateCommand {
743 query: LoweredBaseQueryShape {
744 predicate,
745 order_by,
746 limit,
747 offset,
748 },
749 terminal,
750 })
751}
752
753#[derive(Clone, Debug)]
761enum ResolvedHavingClause {
762 GroupField {
763 field: String,
764 op: crate::db::predicate::CompareOp,
765 value: crate::value::Value,
766 },
767 Aggregate {
768 aggregate_index: usize,
769 op: crate::db::predicate::CompareOp,
770 value: crate::value::Value,
771 },
772}
773
774#[derive(Clone, Debug)]
781pub(crate) struct LoweredSelectShape {
782 scalar_projection_fields: Option<Vec<String>>,
783 grouped_projection_aggregates: Vec<SqlAggregateCall>,
784 group_by_fields: Vec<String>,
785 distinct: bool,
786 having: Vec<ResolvedHavingClause>,
787 predicate: Option<Predicate>,
788 order_by: Vec<crate::db::sql::parser::SqlOrderTerm>,
789 limit: Option<u32>,
790 offset: Option<u32>,
791}
792
793impl LoweredSelectShape {
794 const fn has_grouping(&self) -> bool {
796 !self.group_by_fields.is_empty()
797 }
798}
799
800#[derive(Clone, Debug)]
809pub(crate) struct LoweredBaseQueryShape {
810 predicate: Option<Predicate>,
811 order_by: Vec<SqlOrderTerm>,
812 limit: Option<u32>,
813 offset: Option<u32>,
814}
815
816#[inline(never)]
817fn lower_select_shape(
818 statement: SqlSelectStatement,
819 primary_key_field: &str,
820) -> Result<LoweredSelectShape, SqlLoweringError> {
821 let SqlSelectStatement {
822 projection,
823 predicate,
824 distinct,
825 group_by,
826 having,
827 order_by,
828 limit,
829 offset,
830 entity: _,
831 } = statement;
832 let projection_for_having = projection.clone();
833
834 let (scalar_projection_fields, grouped_projection_aggregates) = if group_by.is_empty() {
836 let scalar_projection_fields =
837 lower_scalar_projection_fields(projection, distinct, primary_key_field)?;
838 (scalar_projection_fields, Vec::new())
839 } else {
840 if distinct {
841 return Err(SqlLoweringError::unsupported_select_distinct());
842 }
843 let grouped_projection_aggregates =
844 grouped_projection_aggregate_calls(&projection, group_by.as_slice())?;
845 (None, grouped_projection_aggregates)
846 };
847
848 let having = lower_having_clauses(
850 having,
851 &projection_for_having,
852 group_by.as_slice(),
853 grouped_projection_aggregates.as_slice(),
854 )?;
855
856 Ok(LoweredSelectShape {
857 scalar_projection_fields,
858 grouped_projection_aggregates,
859 group_by_fields: group_by,
860 distinct,
861 having,
862 predicate,
863 order_by,
864 limit,
865 offset,
866 })
867}
868
869fn lower_scalar_projection_fields(
870 projection: SqlProjection,
871 distinct: bool,
872 primary_key_field: &str,
873) -> Result<Option<Vec<String>>, SqlLoweringError> {
874 let SqlProjection::Items(items) = projection else {
875 if distinct {
876 return Ok(None);
877 }
878
879 return Ok(None);
880 };
881
882 let has_aggregate = items
883 .iter()
884 .any(|item| matches!(item, SqlSelectItem::Aggregate(_)));
885 if has_aggregate {
886 return Err(SqlLoweringError::unsupported_select_projection());
887 }
888
889 let fields = items
890 .into_iter()
891 .map(|item| match item {
892 SqlSelectItem::Field(field) => Ok(field),
893 SqlSelectItem::Aggregate(_) | SqlSelectItem::TextFunction(_) => {
894 Err(SqlLoweringError::unsupported_select_projection())
895 }
896 })
897 .collect::<Result<Vec<_>, _>>()?;
898
899 validate_scalar_distinct_projection(distinct, fields.as_slice(), primary_key_field)?;
900
901 Ok(Some(fields))
902}
903
904fn validate_scalar_distinct_projection(
905 distinct: bool,
906 projection_fields: &[String],
907 primary_key_field: &str,
908) -> Result<(), SqlLoweringError> {
909 if !distinct {
910 return Ok(());
911 }
912
913 if projection_fields.is_empty() {
914 return Ok(());
915 }
916
917 let has_primary_key_field = projection_fields
918 .iter()
919 .any(|field| field == primary_key_field);
920 if !has_primary_key_field {
921 return Err(SqlLoweringError::unsupported_select_distinct());
922 }
923
924 Ok(())
925}
926
927fn lower_having_clauses(
928 having_clauses: Vec<SqlHavingClause>,
929 projection: &SqlProjection,
930 group_by_fields: &[String],
931 grouped_projection_aggregates: &[SqlAggregateCall],
932) -> Result<Vec<ResolvedHavingClause>, SqlLoweringError> {
933 if having_clauses.is_empty() {
934 return Ok(Vec::new());
935 }
936 if group_by_fields.is_empty() {
937 return Err(SqlLoweringError::unsupported_select_having());
938 }
939
940 let projection_aggregates = grouped_projection_aggregate_calls(projection, group_by_fields)
941 .map_err(|_| SqlLoweringError::unsupported_select_having())?;
942 if projection_aggregates.as_slice() != grouped_projection_aggregates {
943 return Err(SqlLoweringError::unsupported_select_having());
944 }
945
946 let mut lowered = Vec::with_capacity(having_clauses.len());
947 for clause in having_clauses {
948 match clause.symbol {
949 SqlHavingSymbol::Field(field) => lowered.push(ResolvedHavingClause::GroupField {
950 field,
951 op: clause.op,
952 value: clause.value,
953 }),
954 SqlHavingSymbol::Aggregate(aggregate) => {
955 let aggregate_index =
956 resolve_having_aggregate_index(&aggregate, grouped_projection_aggregates)?;
957 lowered.push(ResolvedHavingClause::Aggregate {
958 aggregate_index,
959 op: clause.op,
960 value: clause.value,
961 });
962 }
963 }
964 }
965
966 Ok(lowered)
967}
968
969#[inline(never)]
970pub(in crate::db) fn apply_lowered_select_shape(
971 mut query: StructuralQuery,
972 lowered: LoweredSelectShape,
973) -> Result<StructuralQuery, SqlLoweringError> {
974 let LoweredSelectShape {
975 scalar_projection_fields,
976 grouped_projection_aggregates,
977 group_by_fields,
978 distinct,
979 having,
980 predicate,
981 order_by,
982 limit,
983 offset,
984 } = lowered;
985
986 for field in group_by_fields {
988 query = query.group_by(field)?;
989 }
990
991 if distinct {
993 query = query.distinct();
994 }
995 if let Some(fields) = scalar_projection_fields {
996 query = query.select_fields(fields);
997 }
998 for aggregate in grouped_projection_aggregates {
999 query = query.aggregate(lower_aggregate_call(aggregate)?);
1000 }
1001
1002 for clause in having {
1004 match clause {
1005 ResolvedHavingClause::GroupField { field, op, value } => {
1006 query = query.having_group(field, op, value)?;
1007 }
1008 ResolvedHavingClause::Aggregate {
1009 aggregate_index,
1010 op,
1011 value,
1012 } => {
1013 query = query.having_aggregate(aggregate_index, op, value)?;
1014 }
1015 }
1016 }
1017
1018 Ok(apply_lowered_base_query_shape(
1020 query,
1021 LoweredBaseQueryShape {
1022 predicate,
1023 order_by,
1024 limit,
1025 offset,
1026 },
1027 ))
1028}
1029
1030fn apply_lowered_base_query_shape(
1031 mut query: StructuralQuery,
1032 lowered: LoweredBaseQueryShape,
1033) -> StructuralQuery {
1034 if let Some(predicate) = lowered.predicate {
1035 query = query.filter(predicate);
1036 }
1037 query = apply_order_terms_structural(query, lowered.order_by);
1038 if let Some(limit) = lowered.limit {
1039 query = query.limit(limit);
1040 }
1041 if let Some(offset) = lowered.offset {
1042 query = query.offset(offset);
1043 }
1044
1045 query
1046}
1047
1048pub(in crate::db) fn bind_lowered_sql_query_structural(
1049 model: &'static crate::model::entity::EntityModel,
1050 lowered: LoweredSqlQuery,
1051 consistency: MissingRowPolicy,
1052) -> Result<StructuralQuery, SqlLoweringError> {
1053 match lowered {
1054 LoweredSqlQuery::Select(select) => {
1055 apply_lowered_select_shape(StructuralQuery::new(model, consistency), select)
1056 }
1057 LoweredSqlQuery::Delete(delete) => Ok(bind_lowered_sql_delete_query_structural(
1058 model,
1059 delete,
1060 consistency,
1061 )),
1062 }
1063}
1064
1065pub(in crate::db) fn bind_lowered_sql_delete_query_structural(
1066 model: &'static crate::model::entity::EntityModel,
1067 delete: LoweredBaseQueryShape,
1068 consistency: MissingRowPolicy,
1069) -> StructuralQuery {
1070 apply_lowered_base_query_shape(StructuralQuery::new(model, consistency).delete(), delete)
1071}
1072
1073pub(in crate::db) fn bind_lowered_sql_query<E: EntityKind>(
1074 lowered: LoweredSqlQuery,
1075 consistency: MissingRowPolicy,
1076) -> Result<Query<E>, SqlLoweringError> {
1077 let structural = bind_lowered_sql_query_structural(E::MODEL, lowered, consistency)?;
1078
1079 Ok(Query::from_inner(structural))
1080}
1081
1082fn bind_lowered_sql_global_aggregate_command<E: EntityKind>(
1083 lowered: LoweredSqlGlobalAggregateCommand,
1084 consistency: MissingRowPolicy,
1085) -> SqlGlobalAggregateCommand<E> {
1086 SqlGlobalAggregateCommand {
1087 query: Query::from_inner(apply_lowered_base_query_shape(
1088 StructuralQuery::new(E::MODEL, consistency),
1089 lowered.query,
1090 )),
1091 terminal: lowered.terminal,
1092 }
1093}
1094
1095fn bind_lowered_sql_global_aggregate_command_structural(
1096 model: &'static crate::model::entity::EntityModel,
1097 lowered: LoweredSqlGlobalAggregateCommand,
1098 consistency: MissingRowPolicy,
1099) -> SqlGlobalAggregateCommandCore {
1100 SqlGlobalAggregateCommandCore {
1101 query: apply_lowered_base_query_shape(
1102 StructuralQuery::new(model, consistency),
1103 lowered.query,
1104 ),
1105 terminal: lowered.terminal,
1106 }
1107}
1108
1109fn lower_global_aggregate_terminal(
1110 projection: SqlProjection,
1111) -> Result<SqlGlobalAggregateTerminal, SqlLoweringError> {
1112 let SqlProjection::Items(items) = projection else {
1113 return Err(SqlLoweringError::unsupported_select_projection());
1114 };
1115 if items.len() != 1 {
1116 return Err(SqlLoweringError::unsupported_select_projection());
1117 }
1118
1119 let Some(SqlSelectItem::Aggregate(aggregate)) = items.into_iter().next() else {
1120 return Err(SqlLoweringError::unsupported_select_projection());
1121 };
1122
1123 match lower_sql_aggregate_shape(aggregate)? {
1124 LoweredSqlAggregateShape::CountRows => Ok(SqlGlobalAggregateTerminal::CountRows),
1125 LoweredSqlAggregateShape::CountField(field) => {
1126 Ok(SqlGlobalAggregateTerminal::CountField(field))
1127 }
1128 LoweredSqlAggregateShape::FieldTarget {
1129 kind: SqlAggregateKind::Sum,
1130 field,
1131 } => Ok(SqlGlobalAggregateTerminal::SumField(field)),
1132 LoweredSqlAggregateShape::FieldTarget {
1133 kind: SqlAggregateKind::Avg,
1134 field,
1135 } => Ok(SqlGlobalAggregateTerminal::AvgField(field)),
1136 LoweredSqlAggregateShape::FieldTarget {
1137 kind: SqlAggregateKind::Min,
1138 field,
1139 } => Ok(SqlGlobalAggregateTerminal::MinField(field)),
1140 LoweredSqlAggregateShape::FieldTarget {
1141 kind: SqlAggregateKind::Max,
1142 field,
1143 } => Ok(SqlGlobalAggregateTerminal::MaxField(field)),
1144 LoweredSqlAggregateShape::FieldTarget {
1145 kind: SqlAggregateKind::Count,
1146 ..
1147 } => Err(SqlLoweringError::unsupported_select_projection()),
1148 }
1149}
1150
1151fn lower_sql_aggregate_shape(
1152 call: SqlAggregateCall,
1153) -> Result<LoweredSqlAggregateShape, SqlLoweringError> {
1154 match (call.kind, call.field) {
1155 (SqlAggregateKind::Count, None) => Ok(LoweredSqlAggregateShape::CountRows),
1156 (SqlAggregateKind::Count, Some(field)) => Ok(LoweredSqlAggregateShape::CountField(field)),
1157 (
1158 kind @ (SqlAggregateKind::Sum
1159 | SqlAggregateKind::Avg
1160 | SqlAggregateKind::Min
1161 | SqlAggregateKind::Max),
1162 Some(field),
1163 ) => Ok(LoweredSqlAggregateShape::FieldTarget { kind, field }),
1164 _ => Err(SqlLoweringError::unsupported_select_projection()),
1165 }
1166}
1167
1168fn grouped_projection_aggregate_calls(
1169 projection: &SqlProjection,
1170 group_by_fields: &[String],
1171) -> Result<Vec<SqlAggregateCall>, SqlLoweringError> {
1172 if group_by_fields.is_empty() {
1173 return Err(SqlLoweringError::unsupported_select_group_by());
1174 }
1175
1176 let SqlProjection::Items(items) = projection else {
1177 return Err(SqlLoweringError::unsupported_select_group_by());
1178 };
1179
1180 let mut projected_group_fields = Vec::<String>::new();
1181 let mut aggregate_calls = Vec::<SqlAggregateCall>::new();
1182 let mut seen_aggregate = false;
1183
1184 for item in items {
1185 match item {
1186 SqlSelectItem::Field(field) => {
1187 if seen_aggregate {
1190 return Err(SqlLoweringError::unsupported_select_group_by());
1191 }
1192 projected_group_fields.push(field.clone());
1193 }
1194 SqlSelectItem::Aggregate(aggregate) => {
1195 seen_aggregate = true;
1196 aggregate_calls.push(aggregate.clone());
1197 }
1198 SqlSelectItem::TextFunction(_) => {
1199 return Err(SqlLoweringError::unsupported_select_group_by());
1200 }
1201 }
1202 }
1203
1204 if aggregate_calls.is_empty() || projected_group_fields.as_slice() != group_by_fields {
1205 return Err(SqlLoweringError::unsupported_select_group_by());
1206 }
1207
1208 Ok(aggregate_calls)
1209}
1210
1211fn lower_aggregate_call(
1212 call: SqlAggregateCall,
1213) -> Result<crate::db::query::builder::AggregateExpr, SqlLoweringError> {
1214 match lower_sql_aggregate_shape(call)? {
1215 LoweredSqlAggregateShape::CountRows => Ok(count()),
1216 LoweredSqlAggregateShape::CountField(field) => Ok(count_by(field)),
1217 LoweredSqlAggregateShape::FieldTarget {
1218 kind: SqlAggregateKind::Sum,
1219 field,
1220 } => Ok(sum(field)),
1221 LoweredSqlAggregateShape::FieldTarget {
1222 kind: SqlAggregateKind::Avg,
1223 field,
1224 } => Ok(avg(field)),
1225 LoweredSqlAggregateShape::FieldTarget {
1226 kind: SqlAggregateKind::Min,
1227 field,
1228 } => Ok(min_by(field)),
1229 LoweredSqlAggregateShape::FieldTarget {
1230 kind: SqlAggregateKind::Max,
1231 field,
1232 } => Ok(max_by(field)),
1233 LoweredSqlAggregateShape::FieldTarget {
1234 kind: SqlAggregateKind::Count,
1235 ..
1236 } => Err(SqlLoweringError::unsupported_select_projection()),
1237 }
1238}
1239
1240fn resolve_having_aggregate_index(
1241 target: &SqlAggregateCall,
1242 grouped_projection_aggregates: &[SqlAggregateCall],
1243) -> Result<usize, SqlLoweringError> {
1244 let mut matched = grouped_projection_aggregates
1245 .iter()
1246 .enumerate()
1247 .filter_map(|(index, aggregate)| (aggregate == target).then_some(index));
1248 let Some(index) = matched.next() else {
1249 return Err(SqlLoweringError::unsupported_select_having());
1250 };
1251 if matched.next().is_some() {
1252 return Err(SqlLoweringError::unsupported_select_having());
1253 }
1254
1255 Ok(index)
1256}
1257
1258fn lower_delete_shape(statement: SqlDeleteStatement) -> LoweredBaseQueryShape {
1259 let SqlDeleteStatement {
1260 predicate,
1261 order_by,
1262 limit,
1263 entity: _,
1264 } = statement;
1265
1266 LoweredBaseQueryShape {
1267 predicate,
1268 order_by,
1269 limit,
1270 offset: None,
1271 }
1272}
1273
1274fn apply_order_terms_structural(
1275 mut query: StructuralQuery,
1276 order_by: Vec<crate::db::sql::parser::SqlOrderTerm>,
1277) -> StructuralQuery {
1278 for term in order_by {
1279 query = match term.direction {
1280 SqlOrderDirection::Asc => query.order_by(term.field),
1281 SqlOrderDirection::Desc => query.order_by_desc(term.field),
1282 };
1283 }
1284
1285 query
1286}
1287
1288fn normalize_having_clauses(
1289 clauses: Vec<SqlHavingClause>,
1290 entity_scope: &[String],
1291) -> Vec<SqlHavingClause> {
1292 clauses
1293 .into_iter()
1294 .map(|clause| SqlHavingClause {
1295 symbol: normalize_having_symbol(clause.symbol, entity_scope),
1296 op: clause.op,
1297 value: clause.value,
1298 })
1299 .collect()
1300}
1301
1302fn normalize_having_symbol(symbol: SqlHavingSymbol, entity_scope: &[String]) -> SqlHavingSymbol {
1303 match symbol {
1304 SqlHavingSymbol::Field(field) => {
1305 SqlHavingSymbol::Field(normalize_identifier_to_scope(field, entity_scope))
1306 }
1307 SqlHavingSymbol::Aggregate(aggregate) => SqlHavingSymbol::Aggregate(
1308 normalize_aggregate_call_identifiers(aggregate, entity_scope),
1309 ),
1310 }
1311}
1312
1313fn normalize_aggregate_call_identifiers(
1314 aggregate: SqlAggregateCall,
1315 entity_scope: &[String],
1316) -> SqlAggregateCall {
1317 SqlAggregateCall {
1318 kind: aggregate.kind,
1319 field: aggregate
1320 .field
1321 .map(|field| normalize_identifier_to_scope(field, entity_scope)),
1322 }
1323}
1324
1325fn sql_entity_scope_candidates(sql_entity: &str, expected_entity: &'static str) -> Vec<String> {
1328 let mut out = Vec::new();
1329 out.push(sql_entity.to_string());
1330 out.push(expected_entity.to_string());
1331
1332 if let Some(last) = identifier_last_segment(sql_entity) {
1333 out.push(last.to_string());
1334 }
1335 if let Some(last) = identifier_last_segment(expected_entity) {
1336 out.push(last.to_string());
1337 }
1338
1339 out
1340}
1341
1342fn normalize_projection_identifiers(
1343 projection: SqlProjection,
1344 entity_scope: &[String],
1345) -> SqlProjection {
1346 match projection {
1347 SqlProjection::All => SqlProjection::All,
1348 SqlProjection::Items(items) => SqlProjection::Items(
1349 items
1350 .into_iter()
1351 .map(|item| match item {
1352 SqlSelectItem::Field(field) => {
1353 SqlSelectItem::Field(normalize_identifier(field, entity_scope))
1354 }
1355 SqlSelectItem::Aggregate(aggregate) => {
1356 SqlSelectItem::Aggregate(SqlAggregateCall {
1357 kind: aggregate.kind,
1358 field: aggregate
1359 .field
1360 .map(|field| normalize_identifier(field, entity_scope)),
1361 })
1362 }
1363 SqlSelectItem::TextFunction(SqlTextFunctionCall {
1364 function,
1365 field,
1366 literal,
1367 literal2,
1368 literal3,
1369 }) => SqlSelectItem::TextFunction(SqlTextFunctionCall {
1370 function,
1371 field: normalize_identifier(field, entity_scope),
1372 literal,
1373 literal2,
1374 literal3,
1375 }),
1376 })
1377 .collect(),
1378 ),
1379 }
1380}
1381
1382fn normalize_order_terms(
1383 terms: Vec<crate::db::sql::parser::SqlOrderTerm>,
1384 entity_scope: &[String],
1385) -> Vec<crate::db::sql::parser::SqlOrderTerm> {
1386 terms
1387 .into_iter()
1388 .map(|term| crate::db::sql::parser::SqlOrderTerm {
1389 field: normalize_identifier(term.field, entity_scope),
1390 direction: term.direction,
1391 })
1392 .collect()
1393}
1394
1395fn normalize_identifier_list(fields: Vec<String>, entity_scope: &[String]) -> Vec<String> {
1396 fields
1397 .into_iter()
1398 .map(|field| normalize_identifier(field, entity_scope))
1399 .collect()
1400}
1401
1402fn adapt_predicate_identifiers_to_scope(
1405 predicate: Predicate,
1406 entity_scope: &[String],
1407) -> Predicate {
1408 rewrite_field_identifiers(predicate, |field| normalize_identifier(field, entity_scope))
1409}
1410
1411fn normalize_identifier(identifier: String, entity_scope: &[String]) -> String {
1412 normalize_identifier_to_scope(identifier, entity_scope)
1413}
1414
1415fn ensure_entity_matches_expected(
1416 sql_entity: &str,
1417 expected_entity: &'static str,
1418) -> Result<(), SqlLoweringError> {
1419 if identifiers_tail_match(sql_entity, expected_entity) {
1420 return Ok(());
1421 }
1422
1423 Err(SqlLoweringError::entity_mismatch(
1424 sql_entity,
1425 expected_entity,
1426 ))
1427}