1#[cfg(test)]
11mod tests;
12
13use crate::{
14 db::{
15 predicate::{MissingRowPolicy, Predicate},
16 query::{
17 builder::aggregate::{avg, count, count_by, max_by, min_by, sum},
18 intent::{Query, QueryError, StructuralQuery},
19 plan::ExpressionOrderTerm,
20 },
21 sql::identifier::{
22 identifier_last_segment, identifiers_tail_match, normalize_identifier_to_scope,
23 rewrite_field_identifiers,
24 },
25 sql::parser::{
26 SqlAggregateCall, SqlAggregateKind, SqlDeleteStatement, SqlExplainMode,
27 SqlExplainStatement, SqlExplainTarget, SqlHavingClause, SqlHavingSymbol,
28 SqlOrderDirection, SqlOrderTerm, SqlProjection, SqlSelectItem, SqlSelectStatement,
29 SqlStatement, SqlTextFunctionCall,
30 },
31 },
32 traits::EntityKind,
33};
34use thiserror::Error as ThisError;
35
#[derive(Clone, Debug)]
/// A prepared SQL statement lowered into a command shape that is not yet bound
/// to a concrete entity type.
///
/// The payload field is private to this module; callers inspect it through
/// `LoweredSqlCommand::query` and `lowered_sql_command_lane`, and bind it with
/// the `bind_lowered_*` helpers.
pub struct LoweredSqlCommand(LoweredSqlCommandInner);
46
/// Private payload of `LoweredSqlCommand`: one variant per supported statement kind.
#[derive(Clone, Debug)]
enum LoweredSqlCommandInner {
    /// A lowered `SELECT` or `DELETE` to be executed.
    Query(LoweredSqlQuery),
    /// `EXPLAIN` over a lowered `SELECT`/`DELETE` query.
    Explain {
        mode: SqlExplainMode,
        query: LoweredSqlQuery,
    },
    /// `EXPLAIN` over a `SELECT` whose projection is a single global aggregate
    /// (chosen when regular select lowering rejects the projection).
    ExplainGlobalAggregate {
        mode: SqlExplainMode,
        command: LoweredSqlGlobalAggregateCommand,
    },
    /// Lowered from `SqlStatement::Describe`.
    DescribeEntity,
    /// Lowered from `SqlStatement::ShowIndexes`.
    ShowIndexesEntity,
    /// Lowered from `SqlStatement::ShowColumns`.
    ShowColumnsEntity,
    /// Lowered from `SqlStatement::ShowEntities`.
    ShowEntities,
}
63
#[cfg(test)]
/// Test-only SQL command bound to the concrete entity type `E`.
///
/// Mirrors `LoweredSqlCommandInner`, with lowered queries bound to typed
/// `Query<E>` values by `bind_lowered_sql_command`.
#[derive(Debug)]
pub(crate) enum SqlCommand<E: EntityKind> {
    /// A bound `SELECT` or `DELETE`.
    Query(Query<E>),
    /// `EXPLAIN` over a bound query.
    Explain {
        mode: SqlExplainMode,
        query: Query<E>,
    },
    /// `EXPLAIN` over a bound global aggregate.
    ExplainGlobalAggregate {
        mode: SqlExplainMode,
        command: SqlGlobalAggregateCommand<E>,
    },
    DescribeEntity,
    ShowIndexesEntity,
    ShowColumnsEntity,
    ShowEntities,
}
88
89impl LoweredSqlCommand {
90 #[must_use]
91 pub(in crate::db) const fn query(&self) -> Option<&LoweredSqlQuery> {
92 match &self.0 {
93 LoweredSqlCommandInner::Query(query) => Some(query),
94 LoweredSqlCommandInner::Explain { .. }
95 | LoweredSqlCommandInner::ExplainGlobalAggregate { .. }
96 | LoweredSqlCommandInner::DescribeEntity
97 | LoweredSqlCommandInner::ShowIndexesEntity
98 | LoweredSqlCommandInner::ShowColumnsEntity
99 | LoweredSqlCommandInner::ShowEntities => None,
100 }
101 }
102}
103
#[derive(Clone, Debug)]
/// Entity-agnostic lowered query: either a `SELECT` shape or a `DELETE` base shape.
pub(crate) enum LoweredSqlQuery {
    /// Lowered `SELECT` (scalar or grouped).
    Select(LoweredSelectShape),
    /// Lowered `DELETE`; `lower_delete_shape` always leaves its offset as `None`.
    Delete(LoweredBaseQueryShape),
}
115
116impl LoweredSqlQuery {
117 pub(crate) const fn has_grouping(&self) -> bool {
119 match self {
120 Self::Select(select) => select.has_grouping(),
121 Self::Delete(_) => false,
122 }
123 }
124}
125
#[derive(Clone, Debug, Eq, PartialEq)]
/// Terminal operation of a global (ungrouped) SQL aggregate select.
///
/// Produced by `lower_global_aggregate_terminal`; field-carrying variants hold
/// the aggregated field name.
pub(crate) enum SqlGlobalAggregateTerminal {
    /// `COUNT` with no field argument.
    CountRows,
    /// `COUNT(field)`.
    CountField(String),
    /// `SUM(field)`.
    SumField(String),
    /// `AVG(field)`.
    AvgField(String),
    /// `MIN(field)`.
    MinField(String),
    /// `MAX(field)`.
    MaxField(String),
}
142
#[derive(Clone, Debug)]
/// Entity-agnostic global aggregate: the base query shape (predicate, order,
/// limit, offset) plus the terminal evaluated over it.
pub(crate) struct LoweredSqlGlobalAggregateCommand {
    query: LoweredBaseQueryShape,
    terminal: SqlGlobalAggregateTerminal,
}
156
/// Normalized shape of one SQL aggregate call.
///
/// Shared by `lower_global_aggregate_terminal` and `lower_aggregate_call`.
/// `lower_sql_aggregate_shape` only produces `FieldTarget` for SUM/AVG/MIN/MAX;
/// COUNT always maps to `CountRows` or `CountField`.
enum LoweredSqlAggregateShape {
    CountRows,
    CountField(String),
    FieldTarget {
        kind: SqlAggregateKind,
        field: String,
    },
}
172
#[derive(Debug)]
/// Global aggregate command bound to the concrete entity type `E`.
///
/// Holds the typed base query and the terminal to evaluate over it.
pub(crate) struct SqlGlobalAggregateCommand<E: EntityKind> {
    query: Query<E>,
    terminal: SqlGlobalAggregateTerminal,
}
184
185impl<E: EntityKind> SqlGlobalAggregateCommand<E> {
186 #[must_use]
188 pub(crate) const fn query(&self) -> &Query<E> {
189 &self.query
190 }
191
192 #[must_use]
194 pub(crate) const fn terminal(&self) -> &SqlGlobalAggregateTerminal {
195 &self.terminal
196 }
197}
198
#[derive(Debug)]
/// Global aggregate command bound to an entity model at runtime rather than to a
/// generic entity type: the base query is a `StructuralQuery`.
pub(crate) struct SqlGlobalAggregateCommandCore {
    query: StructuralQuery,
    terminal: SqlGlobalAggregateTerminal,
}
213
214impl SqlGlobalAggregateCommandCore {
215 #[must_use]
217 pub(in crate::db) const fn query(&self) -> &StructuralQuery {
218 &self.query
219 }
220
221 #[must_use]
223 pub(in crate::db) const fn terminal(&self) -> &SqlGlobalAggregateTerminal {
224 &self.terminal
225 }
226}
227
#[derive(Debug, ThisError)]
/// Errors produced while parsing, preparing, lowering, or binding a SQL statement.
pub(crate) enum SqlLoweringError {
    /// The SQL text failed to parse.
    #[error("{0}")]
    Parse(#[from] crate::db::sql::parser::SqlParseError),

    /// The query builder/planner reported an error.
    #[error("{0}")]
    Query(#[from] QueryError),

    /// The statement names an entity other than the one the caller expected.
    #[error("SQL entity '{sql_entity}' does not match requested entity type '{expected_entity}'")]
    EntityMismatch {
        sql_entity: String,
        expected_entity: &'static str,
    },

    /// The projection is not a supported shape, e.g. an aggregate or text
    /// function in an ungrouped field list, or an invalid aggregate call.
    #[error(
        "unsupported SQL SELECT projection; supported forms are SELECT *, field lists, or grouped aggregate shapes"
    )]
    UnsupportedSelectProjection,

    /// DISTINCT is used on a grouped select, on a global aggregate, or on a
    /// scalar field list that omits the primary key field.
    #[error("unsupported SQL SELECT DISTINCT")]
    UnsupportedSelectDistinct,

    /// GROUP BY is combined with an unsupported projection, or with a global aggregate.
    #[error("unsupported SQL GROUP BY projection shape")]
    UnsupportedSelectGroupBy,

    /// HAVING is used without GROUP BY, references an aggregate that is not
    /// projected exactly once, or appears on a global aggregate.
    #[error("unsupported SQL HAVING shape")]
    UnsupportedSelectHaving,
}
262
263impl SqlLoweringError {
264 fn entity_mismatch(sql_entity: impl Into<String>, expected_entity: &'static str) -> Self {
266 Self::EntityMismatch {
267 sql_entity: sql_entity.into(),
268 expected_entity,
269 }
270 }
271
272 const fn unsupported_select_projection() -> Self {
274 Self::UnsupportedSelectProjection
275 }
276
277 const fn unsupported_select_distinct() -> Self {
279 Self::UnsupportedSelectDistinct
280 }
281
282 const fn unsupported_select_group_by() -> Self {
284 Self::UnsupportedSelectGroupBy
285 }
286
287 const fn unsupported_select_having() -> Self {
289 Self::UnsupportedSelectHaving
290 }
291}
292
#[derive(Clone, Debug)]
/// A parsed statement that passed `prepare_sql_statement`.
///
/// SELECT/DELETE/EXPLAIN/DESCRIBE/SHOW INDEXES/SHOW COLUMNS have had their entity
/// name checked against the expected entity (SHOW ENTITIES is passed through
/// unchecked). SELECT/DELETE bodies, including those under EXPLAIN, also have
/// their identifiers normalized to the entity scope.
pub(crate) struct PreparedSqlStatement {
    statement: SqlStatement,
}
307
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
/// Coarse execution lane of a lowered command, as reported by
/// `lowered_sql_command_lane`. Both explain variants map to `Explain`.
pub(crate) enum LoweredSqlLaneKind {
    Query,
    Explain,
    Describe,
    ShowIndexes,
    ShowColumns,
    ShowEntities,
}
317
#[cfg(test)]
/// Parse `sql` and compile it into a typed `SqlCommand<E>` (test helper).
///
/// # Errors
/// Returns parse errors converted into `SqlLoweringError`, plus any error from
/// preparation, lowering, or binding.
pub(crate) fn compile_sql_command<E: EntityKind>(
    sql: &str,
    consistency: MissingRowPolicy,
) -> Result<SqlCommand<E>, SqlLoweringError> {
    crate::db::sql::parser::parse_sql(sql)
        .map_err(SqlLoweringError::from)
        .and_then(|statement| compile_sql_command_from_statement::<E>(statement, consistency))
}
327
#[cfg(test)]
/// Prepare an already-parsed statement against `E`'s model name, then compile it
/// into a typed `SqlCommand<E>` (test helper).
///
/// # Errors
/// Propagates preparation, lowering, and binding errors.
pub(crate) fn compile_sql_command_from_statement<E: EntityKind>(
    statement: SqlStatement,
    consistency: MissingRowPolicy,
) -> Result<SqlCommand<E>, SqlLoweringError> {
    prepare_sql_statement(statement, E::MODEL.name()).and_then(|prepared| {
        compile_sql_command_from_prepared_statement::<E>(prepared, consistency)
    })
}
337
#[cfg(test)]
/// Lower a prepared statement using `E`'s primary key field, then bind it to `E`
/// (test helper).
///
/// # Errors
/// Propagates lowering and binding errors.
pub(crate) fn compile_sql_command_from_prepared_statement<E: EntityKind>(
    prepared: PreparedSqlStatement,
    consistency: MissingRowPolicy,
) -> Result<SqlCommand<E>, SqlLoweringError> {
    let primary_key_field = E::MODEL.primary_key.name;

    lower_sql_command_from_prepared_statement(prepared, primary_key_field)
        .and_then(|lowered| bind_lowered_sql_command::<E>(lowered, consistency))
}
348
349#[inline(never)]
351pub(crate) fn lower_sql_command_from_prepared_statement(
352 prepared: PreparedSqlStatement,
353 primary_key_field: &str,
354) -> Result<LoweredSqlCommand, SqlLoweringError> {
355 lower_prepared_statement(prepared.statement, primary_key_field)
356}
357
358pub(crate) const fn lowered_sql_command_lane(command: &LoweredSqlCommand) -> LoweredSqlLaneKind {
359 match command.0 {
360 LoweredSqlCommandInner::Query(_) => LoweredSqlLaneKind::Query,
361 LoweredSqlCommandInner::Explain { .. }
362 | LoweredSqlCommandInner::ExplainGlobalAggregate { .. } => LoweredSqlLaneKind::Explain,
363 LoweredSqlCommandInner::DescribeEntity => LoweredSqlLaneKind::Describe,
364 LoweredSqlCommandInner::ShowIndexesEntity => LoweredSqlLaneKind::ShowIndexes,
365 LoweredSqlCommandInner::ShowColumnsEntity => LoweredSqlLaneKind::ShowColumns,
366 LoweredSqlCommandInner::ShowEntities => LoweredSqlLaneKind::ShowEntities,
367 }
368}
369
370pub(in crate::db) fn is_sql_global_aggregate_statement(statement: &SqlStatement) -> bool {
373 let SqlStatement::Select(statement) = statement else {
374 return false;
375 };
376
377 is_sql_global_aggregate_select(statement)
378}
379
380fn is_sql_global_aggregate_select(statement: &SqlSelectStatement) -> bool {
383 if statement.distinct || !statement.group_by.is_empty() || !statement.having.is_empty() {
384 return false;
385 }
386
387 lower_global_aggregate_terminal(statement.projection.clone()).is_ok()
388}
389
390#[inline(never)]
392pub(crate) fn render_lowered_sql_explain_plan_or_json(
393 lowered: &LoweredSqlCommand,
394 model: &'static crate::model::entity::EntityModel,
395 consistency: MissingRowPolicy,
396) -> Result<Option<String>, SqlLoweringError> {
397 let LoweredSqlCommandInner::Explain { mode, query } = &lowered.0 else {
398 return Ok(None);
399 };
400
401 let query = bind_lowered_sql_query_structural(model, query.clone(), consistency)?;
402 let rendered = match mode {
403 SqlExplainMode::Plan | SqlExplainMode::Json => {
404 let plan = query.build_plan()?;
405 let explain = plan.explain_with_model(model);
406
407 match mode {
408 SqlExplainMode::Plan => explain.render_text_canonical(),
409 SqlExplainMode::Json => explain.render_json_canonical(),
410 SqlExplainMode::Execution => unreachable!("execution mode handled above"),
411 }
412 }
413 SqlExplainMode::Execution => query.explain_execution_text()?,
414 };
415
416 Ok(Some(rendered))
417}
418
419pub(crate) fn bind_lowered_sql_explain_global_aggregate_structural(
422 lowered: &LoweredSqlCommand,
423 model: &'static crate::model::entity::EntityModel,
424 consistency: MissingRowPolicy,
425) -> Option<(SqlExplainMode, SqlGlobalAggregateCommandCore)> {
426 let LoweredSqlCommandInner::ExplainGlobalAggregate { mode, command } = &lowered.0 else {
427 return None;
428 };
429
430 Some((
431 *mode,
432 bind_lowered_sql_global_aggregate_command_structural(model, command.clone(), consistency),
433 ))
434}
435
#[cfg(test)]
/// Bind an entity-agnostic lowered command to the concrete entity type `E`
/// (test helper).
///
/// # Errors
/// Propagates errors from binding lowered queries.
pub(crate) fn bind_lowered_sql_command<E: EntityKind>(
    lowered: LoweredSqlCommand,
    consistency: MissingRowPolicy,
) -> Result<SqlCommand<E>, SqlLoweringError> {
    use self::LoweredSqlCommandInner as Inner;

    let command = match lowered.0 {
        Inner::Query(query) => SqlCommand::Query(bind_lowered_sql_query::<E>(query, consistency)?),
        Inner::Explain { mode, query } => {
            let query = bind_lowered_sql_query::<E>(query, consistency)?;
            SqlCommand::Explain { mode, query }
        }
        Inner::ExplainGlobalAggregate { mode, command } => {
            let command = bind_lowered_sql_global_aggregate_command::<E>(command, consistency);
            SqlCommand::ExplainGlobalAggregate { mode, command }
        }
        Inner::DescribeEntity => SqlCommand::DescribeEntity,
        Inner::ShowIndexesEntity => SqlCommand::ShowIndexesEntity,
        Inner::ShowColumnsEntity => SqlCommand::ShowColumnsEntity,
        Inner::ShowEntities => SqlCommand::ShowEntities,
    };

    Ok(command)
}
463
464#[inline(never)]
466pub(crate) fn prepare_sql_statement(
467 statement: SqlStatement,
468 expected_entity: &'static str,
469) -> Result<PreparedSqlStatement, SqlLoweringError> {
470 let statement = prepare_statement(statement, expected_entity)?;
471
472 Ok(PreparedSqlStatement { statement })
473}
474
#[cfg(test)]
/// Parse and prepare `sql`, then compile it as a typed global aggregate command
/// (test helper).
///
/// # Errors
/// Returns parse, preparation, or unsupported-shape errors.
pub(crate) fn compile_sql_global_aggregate_command<E: EntityKind>(
    sql: &str,
    consistency: MissingRowPolicy,
) -> Result<SqlGlobalAggregateCommand<E>, SqlLoweringError> {
    let parsed = crate::db::sql::parser::parse_sql(sql)?;
    let expected_entity = E::MODEL.name();

    prepare_sql_statement(parsed, expected_entity).and_then(|prepared| {
        compile_sql_global_aggregate_command_from_prepared::<E>(prepared, consistency)
    })
}
485
486pub(crate) fn compile_sql_global_aggregate_command_from_prepared<E: EntityKind>(
490 prepared: PreparedSqlStatement,
491 consistency: MissingRowPolicy,
492) -> Result<SqlGlobalAggregateCommand<E>, SqlLoweringError> {
493 let SqlStatement::Select(statement) = prepared.statement else {
494 return Err(SqlLoweringError::unsupported_select_projection());
495 };
496
497 Ok(bind_lowered_sql_global_aggregate_command::<E>(
498 lower_global_aggregate_select_shape(statement)?,
499 consistency,
500 ))
501}
502
503#[inline(never)]
504fn prepare_statement(
505 statement: SqlStatement,
506 expected_entity: &'static str,
507) -> Result<SqlStatement, SqlLoweringError> {
508 match statement {
509 SqlStatement::Select(statement) => Ok(SqlStatement::Select(prepare_select_statement(
510 statement,
511 expected_entity,
512 )?)),
513 SqlStatement::Delete(statement) => Ok(SqlStatement::Delete(prepare_delete_statement(
514 statement,
515 expected_entity,
516 )?)),
517 SqlStatement::Explain(statement) => Ok(SqlStatement::Explain(prepare_explain_statement(
518 statement,
519 expected_entity,
520 )?)),
521 SqlStatement::Describe(statement) => {
522 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
523
524 Ok(SqlStatement::Describe(statement))
525 }
526 SqlStatement::ShowIndexes(statement) => {
527 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
528
529 Ok(SqlStatement::ShowIndexes(statement))
530 }
531 SqlStatement::ShowColumns(statement) => {
532 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
533
534 Ok(SqlStatement::ShowColumns(statement))
535 }
536 SqlStatement::ShowEntities(statement) => Ok(SqlStatement::ShowEntities(statement)),
537 }
538}
539
540fn prepare_explain_statement(
541 statement: SqlExplainStatement,
542 expected_entity: &'static str,
543) -> Result<SqlExplainStatement, SqlLoweringError> {
544 let target = match statement.statement {
545 SqlExplainTarget::Select(select_statement) => {
546 SqlExplainTarget::Select(prepare_select_statement(select_statement, expected_entity)?)
547 }
548 SqlExplainTarget::Delete(delete_statement) => {
549 SqlExplainTarget::Delete(prepare_delete_statement(delete_statement, expected_entity)?)
550 }
551 };
552
553 Ok(SqlExplainStatement {
554 mode: statement.mode,
555 statement: target,
556 })
557}
558
559fn prepare_select_statement(
560 statement: SqlSelectStatement,
561 expected_entity: &'static str,
562) -> Result<SqlSelectStatement, SqlLoweringError> {
563 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
564
565 Ok(normalize_select_statement_to_expected_entity(
566 statement,
567 expected_entity,
568 ))
569}
570
571fn normalize_select_statement_to_expected_entity(
572 mut statement: SqlSelectStatement,
573 expected_entity: &'static str,
574) -> SqlSelectStatement {
575 let entity_scope = sql_entity_scope_candidates(statement.entity.as_str(), expected_entity);
578 statement.projection =
579 normalize_projection_identifiers(statement.projection, entity_scope.as_slice());
580 statement.group_by = normalize_identifier_list(statement.group_by, entity_scope.as_slice());
581 statement.predicate = statement
582 .predicate
583 .map(|predicate| adapt_predicate_identifiers_to_scope(predicate, entity_scope.as_slice()));
584 statement.order_by = normalize_order_terms(statement.order_by, entity_scope.as_slice());
585 statement.having = normalize_having_clauses(statement.having, entity_scope.as_slice());
586
587 statement
588}
589
590fn prepare_delete_statement(
591 mut statement: SqlDeleteStatement,
592 expected_entity: &'static str,
593) -> Result<SqlDeleteStatement, SqlLoweringError> {
594 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
595 let entity_scope = sql_entity_scope_candidates(statement.entity.as_str(), expected_entity);
596 statement.predicate = statement
597 .predicate
598 .map(|predicate| adapt_predicate_identifiers_to_scope(predicate, entity_scope.as_slice()));
599 statement.order_by = normalize_order_terms(statement.order_by, entity_scope.as_slice());
600
601 Ok(statement)
602}
603
604#[inline(never)]
605fn lower_prepared_statement(
606 statement: SqlStatement,
607 primary_key_field: &str,
608) -> Result<LoweredSqlCommand, SqlLoweringError> {
609 match statement {
610 SqlStatement::Select(statement) => Ok(LoweredSqlCommand(LoweredSqlCommandInner::Query(
611 LoweredSqlQuery::Select(lower_select_shape(statement, primary_key_field)?),
612 ))),
613 SqlStatement::Delete(statement) => Ok(LoweredSqlCommand(LoweredSqlCommandInner::Query(
614 LoweredSqlQuery::Delete(lower_delete_shape(statement)),
615 ))),
616 SqlStatement::Explain(statement) => lower_explain_prepared(statement, primary_key_field),
617 SqlStatement::Describe(_) => Ok(LoweredSqlCommand(LoweredSqlCommandInner::DescribeEntity)),
618 SqlStatement::ShowIndexes(_) => {
619 Ok(LoweredSqlCommand(LoweredSqlCommandInner::ShowIndexesEntity))
620 }
621 SqlStatement::ShowColumns(_) => {
622 Ok(LoweredSqlCommand(LoweredSqlCommandInner::ShowColumnsEntity))
623 }
624 SqlStatement::ShowEntities(_) => {
625 Ok(LoweredSqlCommand(LoweredSqlCommandInner::ShowEntities))
626 }
627 }
628}
629
630fn lower_explain_prepared(
631 statement: SqlExplainStatement,
632 primary_key_field: &str,
633) -> Result<LoweredSqlCommand, SqlLoweringError> {
634 let mode = statement.mode;
635
636 match statement.statement {
637 SqlExplainTarget::Select(select_statement) => {
638 lower_explain_select_prepared(select_statement, mode, primary_key_field)
639 }
640 SqlExplainTarget::Delete(delete_statement) => {
641 Ok(LoweredSqlCommand(LoweredSqlCommandInner::Explain {
642 mode,
643 query: LoweredSqlQuery::Delete(lower_delete_shape(delete_statement)),
644 }))
645 }
646 }
647}
648
649fn lower_explain_select_prepared(
650 statement: SqlSelectStatement,
651 mode: SqlExplainMode,
652 primary_key_field: &str,
653) -> Result<LoweredSqlCommand, SqlLoweringError> {
654 match lower_select_shape(statement.clone(), primary_key_field) {
655 Ok(query) => Ok(LoweredSqlCommand(LoweredSqlCommandInner::Explain {
656 mode,
657 query: LoweredSqlQuery::Select(query),
658 })),
659 Err(SqlLoweringError::UnsupportedSelectProjection) => {
660 let command = lower_global_aggregate_select_shape(statement)?;
661
662 Ok(LoweredSqlCommand(
663 LoweredSqlCommandInner::ExplainGlobalAggregate { mode, command },
664 ))
665 }
666 Err(err) => Err(err),
667 }
668}
669
670fn lower_global_aggregate_select_shape(
671 statement: SqlSelectStatement,
672) -> Result<LoweredSqlGlobalAggregateCommand, SqlLoweringError> {
673 let SqlSelectStatement {
674 projection,
675 predicate,
676 distinct,
677 group_by,
678 having,
679 order_by,
680 limit,
681 offset,
682 entity: _,
683 } = statement;
684
685 if distinct {
686 return Err(SqlLoweringError::unsupported_select_distinct());
687 }
688 if !group_by.is_empty() {
689 return Err(SqlLoweringError::unsupported_select_group_by());
690 }
691 if !having.is_empty() {
692 return Err(SqlLoweringError::unsupported_select_having());
693 }
694
695 let terminal = lower_global_aggregate_terminal(projection)?;
696
697 Ok(LoweredSqlGlobalAggregateCommand {
698 query: LoweredBaseQueryShape {
699 predicate,
700 order_by,
701 limit,
702 offset,
703 },
704 terminal,
705 })
706}
707
#[derive(Clone, Debug)]
/// A HAVING comparison whose left-hand symbol has been resolved during lowering.
enum ResolvedHavingClause {
    /// Comparison against a GROUP BY field, by name.
    GroupField {
        field: String,
        op: crate::db::predicate::CompareOp,
        value: crate::value::Value,
    },
    /// Comparison against a projected aggregate, by its index in the grouped
    /// projection aggregate list (see `resolve_having_aggregate_index`).
    Aggregate {
        aggregate_index: usize,
        op: crate::db::predicate::CompareOp,
        value: crate::value::Value,
    },
}
728
#[derive(Clone, Debug)]
/// Entity-agnostic lowered `SELECT`, applied to a `StructuralQuery` by
/// `apply_lowered_select_shape`.
pub(crate) struct LoweredSelectShape {
    /// `Some(fields)` for an explicit field list on an ungrouped select;
    /// `None` for a non-`Items` projection or any grouped select.
    scalar_projection_fields: Option<Vec<String>>,
    /// Aggregates from a grouped projection, in projection order; empty when ungrouped.
    grouped_projection_aggregates: Vec<SqlAggregateCall>,
    group_by_fields: Vec<String>,
    distinct: bool,
    having: Vec<ResolvedHavingClause>,
    predicate: Option<Predicate>,
    order_by: Vec<crate::db::sql::parser::SqlOrderTerm>,
    limit: Option<u32>,
    offset: Option<u32>,
}
747
748impl LoweredSelectShape {
749 const fn has_grouping(&self) -> bool {
751 !self.group_by_fields.is_empty()
752 }
753}
754
#[derive(Clone, Debug)]
/// Predicate/order/window tail shared by lowered deletes, global aggregates, and
/// selects. It is applied by `apply_lowered_base_query_shape`.
pub(crate) struct LoweredBaseQueryShape {
    predicate: Option<Predicate>,
    order_by: Vec<SqlOrderTerm>,
    limit: Option<u32>,
    offset: Option<u32>,
}
770
771#[inline(never)]
772fn lower_select_shape(
773 statement: SqlSelectStatement,
774 primary_key_field: &str,
775) -> Result<LoweredSelectShape, SqlLoweringError> {
776 let SqlSelectStatement {
777 projection,
778 predicate,
779 distinct,
780 group_by,
781 having,
782 order_by,
783 limit,
784 offset,
785 entity: _,
786 } = statement;
787 let projection_for_having = projection.clone();
788
789 let (scalar_projection_fields, grouped_projection_aggregates) = if group_by.is_empty() {
791 let scalar_projection_fields =
792 lower_scalar_projection_fields(projection, distinct, primary_key_field)?;
793 (scalar_projection_fields, Vec::new())
794 } else {
795 if distinct {
796 return Err(SqlLoweringError::unsupported_select_distinct());
797 }
798 let grouped_projection_aggregates =
799 grouped_projection_aggregate_calls(&projection, group_by.as_slice())?;
800 (None, grouped_projection_aggregates)
801 };
802
803 let having = lower_having_clauses(
805 having,
806 &projection_for_having,
807 group_by.as_slice(),
808 grouped_projection_aggregates.as_slice(),
809 )?;
810
811 Ok(LoweredSelectShape {
812 scalar_projection_fields,
813 grouped_projection_aggregates,
814 group_by_fields: group_by,
815 distinct,
816 having,
817 predicate,
818 order_by,
819 limit,
820 offset,
821 })
822}
823
824fn lower_scalar_projection_fields(
825 projection: SqlProjection,
826 distinct: bool,
827 primary_key_field: &str,
828) -> Result<Option<Vec<String>>, SqlLoweringError> {
829 let SqlProjection::Items(items) = projection else {
830 if distinct {
831 return Ok(None);
832 }
833
834 return Ok(None);
835 };
836
837 let has_aggregate = items
838 .iter()
839 .any(|item| matches!(item, SqlSelectItem::Aggregate(_)));
840 if has_aggregate {
841 return Err(SqlLoweringError::unsupported_select_projection());
842 }
843
844 let fields = items
845 .into_iter()
846 .map(|item| match item {
847 SqlSelectItem::Field(field) => Ok(field),
848 SqlSelectItem::Aggregate(_) | SqlSelectItem::TextFunction(_) => {
849 Err(SqlLoweringError::unsupported_select_projection())
850 }
851 })
852 .collect::<Result<Vec<_>, _>>()?;
853
854 validate_scalar_distinct_projection(distinct, fields.as_slice(), primary_key_field)?;
855
856 Ok(Some(fields))
857}
858
859fn validate_scalar_distinct_projection(
860 distinct: bool,
861 projection_fields: &[String],
862 primary_key_field: &str,
863) -> Result<(), SqlLoweringError> {
864 if !distinct {
865 return Ok(());
866 }
867
868 if projection_fields.is_empty() {
869 return Ok(());
870 }
871
872 let has_primary_key_field = projection_fields
873 .iter()
874 .any(|field| field == primary_key_field);
875 if !has_primary_key_field {
876 return Err(SqlLoweringError::unsupported_select_distinct());
877 }
878
879 Ok(())
880}
881
882fn lower_having_clauses(
883 having_clauses: Vec<SqlHavingClause>,
884 projection: &SqlProjection,
885 group_by_fields: &[String],
886 grouped_projection_aggregates: &[SqlAggregateCall],
887) -> Result<Vec<ResolvedHavingClause>, SqlLoweringError> {
888 if having_clauses.is_empty() {
889 return Ok(Vec::new());
890 }
891 if group_by_fields.is_empty() {
892 return Err(SqlLoweringError::unsupported_select_having());
893 }
894
895 let projection_aggregates = grouped_projection_aggregate_calls(projection, group_by_fields)
896 .map_err(|_| SqlLoweringError::unsupported_select_having())?;
897 if projection_aggregates.as_slice() != grouped_projection_aggregates {
898 return Err(SqlLoweringError::unsupported_select_having());
899 }
900
901 let mut lowered = Vec::with_capacity(having_clauses.len());
902 for clause in having_clauses {
903 match clause.symbol {
904 SqlHavingSymbol::Field(field) => lowered.push(ResolvedHavingClause::GroupField {
905 field,
906 op: clause.op,
907 value: clause.value,
908 }),
909 SqlHavingSymbol::Aggregate(aggregate) => {
910 let aggregate_index =
911 resolve_having_aggregate_index(&aggregate, grouped_projection_aggregates)?;
912 lowered.push(ResolvedHavingClause::Aggregate {
913 aggregate_index,
914 op: clause.op,
915 value: clause.value,
916 });
917 }
918 }
919 }
920
921 Ok(lowered)
922}
923
924#[inline(never)]
925pub(in crate::db) fn apply_lowered_select_shape(
926 mut query: StructuralQuery,
927 lowered: LoweredSelectShape,
928) -> Result<StructuralQuery, SqlLoweringError> {
929 let LoweredSelectShape {
930 scalar_projection_fields,
931 grouped_projection_aggregates,
932 group_by_fields,
933 distinct,
934 having,
935 predicate,
936 order_by,
937 limit,
938 offset,
939 } = lowered;
940
941 for field in group_by_fields {
943 query = query.group_by(field)?;
944 }
945
946 if distinct {
948 query = query.distinct();
949 }
950 if let Some(fields) = scalar_projection_fields {
951 query = query.select_fields(fields);
952 }
953 for aggregate in grouped_projection_aggregates {
954 query = query.aggregate(lower_aggregate_call(aggregate)?);
955 }
956
957 for clause in having {
959 match clause {
960 ResolvedHavingClause::GroupField { field, op, value } => {
961 query = query.having_group(field, op, value)?;
962 }
963 ResolvedHavingClause::Aggregate {
964 aggregate_index,
965 op,
966 value,
967 } => {
968 query = query.having_aggregate(aggregate_index, op, value)?;
969 }
970 }
971 }
972
973 Ok(apply_lowered_base_query_shape(
975 query,
976 LoweredBaseQueryShape {
977 predicate,
978 order_by,
979 limit,
980 offset,
981 },
982 ))
983}
984
985fn apply_lowered_base_query_shape(
986 mut query: StructuralQuery,
987 lowered: LoweredBaseQueryShape,
988) -> StructuralQuery {
989 if let Some(predicate) = lowered.predicate {
990 query = query.filter(predicate);
991 }
992 query = apply_order_terms_structural(query, lowered.order_by);
993 if let Some(limit) = lowered.limit {
994 query = query.limit(limit);
995 }
996 if let Some(offset) = lowered.offset {
997 query = query.offset(offset);
998 }
999
1000 query
1001}
1002
1003pub(in crate::db) fn bind_lowered_sql_query_structural(
1004 model: &'static crate::model::entity::EntityModel,
1005 lowered: LoweredSqlQuery,
1006 consistency: MissingRowPolicy,
1007) -> Result<StructuralQuery, SqlLoweringError> {
1008 match lowered {
1009 LoweredSqlQuery::Select(select) => {
1010 apply_lowered_select_shape(StructuralQuery::new(model, consistency), select)
1011 }
1012 LoweredSqlQuery::Delete(delete) => Ok(bind_lowered_sql_delete_query_structural(
1013 model,
1014 delete,
1015 consistency,
1016 )),
1017 }
1018}
1019
1020pub(in crate::db) fn bind_lowered_sql_delete_query_structural(
1021 model: &'static crate::model::entity::EntityModel,
1022 delete: LoweredBaseQueryShape,
1023 consistency: MissingRowPolicy,
1024) -> StructuralQuery {
1025 apply_lowered_base_query_shape(StructuralQuery::new(model, consistency).delete(), delete)
1026}
1027
1028pub(in crate::db) fn bind_lowered_sql_query<E: EntityKind>(
1029 lowered: LoweredSqlQuery,
1030 consistency: MissingRowPolicy,
1031) -> Result<Query<E>, SqlLoweringError> {
1032 let structural = bind_lowered_sql_query_structural(E::MODEL, lowered, consistency)?;
1033
1034 Ok(Query::from_inner(structural))
1035}
1036
1037fn bind_lowered_sql_global_aggregate_command<E: EntityKind>(
1038 lowered: LoweredSqlGlobalAggregateCommand,
1039 consistency: MissingRowPolicy,
1040) -> SqlGlobalAggregateCommand<E> {
1041 SqlGlobalAggregateCommand {
1042 query: Query::from_inner(apply_lowered_base_query_shape(
1043 StructuralQuery::new(E::MODEL, consistency),
1044 lowered.query,
1045 )),
1046 terminal: lowered.terminal,
1047 }
1048}
1049
1050fn bind_lowered_sql_global_aggregate_command_structural(
1051 model: &'static crate::model::entity::EntityModel,
1052 lowered: LoweredSqlGlobalAggregateCommand,
1053 consistency: MissingRowPolicy,
1054) -> SqlGlobalAggregateCommandCore {
1055 SqlGlobalAggregateCommandCore {
1056 query: apply_lowered_base_query_shape(
1057 StructuralQuery::new(model, consistency),
1058 lowered.query,
1059 ),
1060 terminal: lowered.terminal,
1061 }
1062}
1063
1064fn lower_global_aggregate_terminal(
1065 projection: SqlProjection,
1066) -> Result<SqlGlobalAggregateTerminal, SqlLoweringError> {
1067 let SqlProjection::Items(items) = projection else {
1068 return Err(SqlLoweringError::unsupported_select_projection());
1069 };
1070 if items.len() != 1 {
1071 return Err(SqlLoweringError::unsupported_select_projection());
1072 }
1073
1074 let Some(SqlSelectItem::Aggregate(aggregate)) = items.into_iter().next() else {
1075 return Err(SqlLoweringError::unsupported_select_projection());
1076 };
1077
1078 match lower_sql_aggregate_shape(aggregate)? {
1079 LoweredSqlAggregateShape::CountRows => Ok(SqlGlobalAggregateTerminal::CountRows),
1080 LoweredSqlAggregateShape::CountField(field) => {
1081 Ok(SqlGlobalAggregateTerminal::CountField(field))
1082 }
1083 LoweredSqlAggregateShape::FieldTarget {
1084 kind: SqlAggregateKind::Sum,
1085 field,
1086 } => Ok(SqlGlobalAggregateTerminal::SumField(field)),
1087 LoweredSqlAggregateShape::FieldTarget {
1088 kind: SqlAggregateKind::Avg,
1089 field,
1090 } => Ok(SqlGlobalAggregateTerminal::AvgField(field)),
1091 LoweredSqlAggregateShape::FieldTarget {
1092 kind: SqlAggregateKind::Min,
1093 field,
1094 } => Ok(SqlGlobalAggregateTerminal::MinField(field)),
1095 LoweredSqlAggregateShape::FieldTarget {
1096 kind: SqlAggregateKind::Max,
1097 field,
1098 } => Ok(SqlGlobalAggregateTerminal::MaxField(field)),
1099 LoweredSqlAggregateShape::FieldTarget {
1100 kind: SqlAggregateKind::Count,
1101 ..
1102 } => Err(SqlLoweringError::unsupported_select_projection()),
1103 }
1104}
1105
1106fn lower_sql_aggregate_shape(
1107 call: SqlAggregateCall,
1108) -> Result<LoweredSqlAggregateShape, SqlLoweringError> {
1109 match (call.kind, call.field) {
1110 (SqlAggregateKind::Count, None) => Ok(LoweredSqlAggregateShape::CountRows),
1111 (SqlAggregateKind::Count, Some(field)) => Ok(LoweredSqlAggregateShape::CountField(field)),
1112 (
1113 kind @ (SqlAggregateKind::Sum
1114 | SqlAggregateKind::Avg
1115 | SqlAggregateKind::Min
1116 | SqlAggregateKind::Max),
1117 Some(field),
1118 ) => Ok(LoweredSqlAggregateShape::FieldTarget { kind, field }),
1119 _ => Err(SqlLoweringError::unsupported_select_projection()),
1120 }
1121}
1122
1123fn grouped_projection_aggregate_calls(
1124 projection: &SqlProjection,
1125 group_by_fields: &[String],
1126) -> Result<Vec<SqlAggregateCall>, SqlLoweringError> {
1127 if group_by_fields.is_empty() {
1128 return Err(SqlLoweringError::unsupported_select_group_by());
1129 }
1130
1131 let SqlProjection::Items(items) = projection else {
1132 return Err(SqlLoweringError::unsupported_select_group_by());
1133 };
1134
1135 let mut projected_group_fields = Vec::<String>::new();
1136 let mut aggregate_calls = Vec::<SqlAggregateCall>::new();
1137 let mut seen_aggregate = false;
1138
1139 for item in items {
1140 match item {
1141 SqlSelectItem::Field(field) => {
1142 if seen_aggregate {
1145 return Err(SqlLoweringError::unsupported_select_group_by());
1146 }
1147 projected_group_fields.push(field.clone());
1148 }
1149 SqlSelectItem::Aggregate(aggregate) => {
1150 seen_aggregate = true;
1151 aggregate_calls.push(aggregate.clone());
1152 }
1153 SqlSelectItem::TextFunction(_) => {
1154 return Err(SqlLoweringError::unsupported_select_group_by());
1155 }
1156 }
1157 }
1158
1159 if aggregate_calls.is_empty() || projected_group_fields.as_slice() != group_by_fields {
1160 return Err(SqlLoweringError::unsupported_select_group_by());
1161 }
1162
1163 Ok(aggregate_calls)
1164}
1165
1166fn lower_aggregate_call(
1167 call: SqlAggregateCall,
1168) -> Result<crate::db::query::builder::AggregateExpr, SqlLoweringError> {
1169 match lower_sql_aggregate_shape(call)? {
1170 LoweredSqlAggregateShape::CountRows => Ok(count()),
1171 LoweredSqlAggregateShape::CountField(field) => Ok(count_by(field)),
1172 LoweredSqlAggregateShape::FieldTarget {
1173 kind: SqlAggregateKind::Sum,
1174 field,
1175 } => Ok(sum(field)),
1176 LoweredSqlAggregateShape::FieldTarget {
1177 kind: SqlAggregateKind::Avg,
1178 field,
1179 } => Ok(avg(field)),
1180 LoweredSqlAggregateShape::FieldTarget {
1181 kind: SqlAggregateKind::Min,
1182 field,
1183 } => Ok(min_by(field)),
1184 LoweredSqlAggregateShape::FieldTarget {
1185 kind: SqlAggregateKind::Max,
1186 field,
1187 } => Ok(max_by(field)),
1188 LoweredSqlAggregateShape::FieldTarget {
1189 kind: SqlAggregateKind::Count,
1190 ..
1191 } => Err(SqlLoweringError::unsupported_select_projection()),
1192 }
1193}
1194
1195fn resolve_having_aggregate_index(
1196 target: &SqlAggregateCall,
1197 grouped_projection_aggregates: &[SqlAggregateCall],
1198) -> Result<usize, SqlLoweringError> {
1199 let mut matched = grouped_projection_aggregates
1200 .iter()
1201 .enumerate()
1202 .filter_map(|(index, aggregate)| (aggregate == target).then_some(index));
1203 let Some(index) = matched.next() else {
1204 return Err(SqlLoweringError::unsupported_select_having());
1205 };
1206 if matched.next().is_some() {
1207 return Err(SqlLoweringError::unsupported_select_having());
1208 }
1209
1210 Ok(index)
1211}
1212
1213fn lower_delete_shape(statement: SqlDeleteStatement) -> LoweredBaseQueryShape {
1214 let SqlDeleteStatement {
1215 predicate,
1216 order_by,
1217 limit,
1218 entity: _,
1219 } = statement;
1220
1221 LoweredBaseQueryShape {
1222 predicate,
1223 order_by,
1224 limit,
1225 offset: None,
1226 }
1227}
1228
1229fn apply_order_terms_structural(
1230 mut query: StructuralQuery,
1231 order_by: Vec<crate::db::sql::parser::SqlOrderTerm>,
1232) -> StructuralQuery {
1233 for term in order_by {
1234 query = match term.direction {
1235 SqlOrderDirection::Asc => query.order_by(term.field),
1236 SqlOrderDirection::Desc => query.order_by_desc(term.field),
1237 };
1238 }
1239
1240 query
1241}
1242
1243fn normalize_having_clauses(
1244 clauses: Vec<SqlHavingClause>,
1245 entity_scope: &[String],
1246) -> Vec<SqlHavingClause> {
1247 clauses
1248 .into_iter()
1249 .map(|clause| SqlHavingClause {
1250 symbol: normalize_having_symbol(clause.symbol, entity_scope),
1251 op: clause.op,
1252 value: clause.value,
1253 })
1254 .collect()
1255}
1256
1257fn normalize_having_symbol(symbol: SqlHavingSymbol, entity_scope: &[String]) -> SqlHavingSymbol {
1258 match symbol {
1259 SqlHavingSymbol::Field(field) => {
1260 SqlHavingSymbol::Field(normalize_identifier_to_scope(field, entity_scope))
1261 }
1262 SqlHavingSymbol::Aggregate(aggregate) => SqlHavingSymbol::Aggregate(
1263 normalize_aggregate_call_identifiers(aggregate, entity_scope),
1264 ),
1265 }
1266}
1267
1268fn normalize_aggregate_call_identifiers(
1269 aggregate: SqlAggregateCall,
1270 entity_scope: &[String],
1271) -> SqlAggregateCall {
1272 SqlAggregateCall {
1273 kind: aggregate.kind,
1274 field: aggregate
1275 .field
1276 .map(|field| normalize_identifier_to_scope(field, entity_scope)),
1277 }
1278}
1279
1280fn sql_entity_scope_candidates(sql_entity: &str, expected_entity: &'static str) -> Vec<String> {
1283 let mut out = Vec::new();
1284 out.push(sql_entity.to_string());
1285 out.push(expected_entity.to_string());
1286
1287 if let Some(last) = identifier_last_segment(sql_entity) {
1288 out.push(last.to_string());
1289 }
1290 if let Some(last) = identifier_last_segment(expected_entity) {
1291 out.push(last.to_string());
1292 }
1293
1294 out
1295}
1296
1297fn normalize_projection_identifiers(
1298 projection: SqlProjection,
1299 entity_scope: &[String],
1300) -> SqlProjection {
1301 match projection {
1302 SqlProjection::All => SqlProjection::All,
1303 SqlProjection::Items(items) => SqlProjection::Items(
1304 items
1305 .into_iter()
1306 .map(|item| match item {
1307 SqlSelectItem::Field(field) => {
1308 SqlSelectItem::Field(normalize_identifier(field, entity_scope))
1309 }
1310 SqlSelectItem::Aggregate(aggregate) => {
1311 SqlSelectItem::Aggregate(SqlAggregateCall {
1312 kind: aggregate.kind,
1313 field: aggregate
1314 .field
1315 .map(|field| normalize_identifier(field, entity_scope)),
1316 })
1317 }
1318 SqlSelectItem::TextFunction(SqlTextFunctionCall {
1319 function,
1320 field,
1321 literal,
1322 literal2,
1323 literal3,
1324 }) => SqlSelectItem::TextFunction(SqlTextFunctionCall {
1325 function,
1326 field: normalize_identifier(field, entity_scope),
1327 literal,
1328 literal2,
1329 literal3,
1330 }),
1331 })
1332 .collect(),
1333 ),
1334 }
1335}
1336
1337fn normalize_order_terms(
1338 terms: Vec<crate::db::sql::parser::SqlOrderTerm>,
1339 entity_scope: &[String],
1340) -> Vec<crate::db::sql::parser::SqlOrderTerm> {
1341 terms
1342 .into_iter()
1343 .map(|term| crate::db::sql::parser::SqlOrderTerm {
1344 field: normalize_order_term_identifier(term.field, entity_scope),
1345 direction: term.direction,
1346 })
1347 .collect()
1348}
1349
1350fn normalize_order_term_identifier(identifier: String, entity_scope: &[String]) -> String {
1351 let Some(expression) = ExpressionOrderTerm::parse(identifier.as_str()) else {
1352 return normalize_identifier(identifier, entity_scope);
1353 };
1354 let normalized_field = normalize_identifier(expression.field().to_string(), entity_scope);
1355
1356 expression.canonical_text_with_field(normalized_field.as_str())
1357}
1358
1359fn normalize_identifier_list(fields: Vec<String>, entity_scope: &[String]) -> Vec<String> {
1360 fields
1361 .into_iter()
1362 .map(|field| normalize_identifier(field, entity_scope))
1363 .collect()
1364}
1365
1366fn adapt_predicate_identifiers_to_scope(
1369 predicate: Predicate,
1370 entity_scope: &[String],
1371) -> Predicate {
1372 rewrite_field_identifiers(predicate, |field| normalize_identifier(field, entity_scope))
1373}
1374
1375fn normalize_identifier(identifier: String, entity_scope: &[String]) -> String {
1376 normalize_identifier_to_scope(identifier, entity_scope)
1377}
1378
1379fn ensure_entity_matches_expected(
1380 sql_entity: &str,
1381 expected_entity: &'static str,
1382) -> Result<(), SqlLoweringError> {
1383 if identifiers_tail_match(sql_entity, expected_entity) {
1384 return Ok(());
1385 }
1386
1387 Err(SqlLoweringError::entity_mismatch(
1388 sql_entity,
1389 expected_entity,
1390 ))
1391}