1#[cfg(test)]
11mod tests;
12
13use crate::{
14 db::{
15 predicate::{MissingRowPolicy, Predicate},
16 query::{
17 builder::aggregate::{avg, count, count_by, max_by, min_by, sum},
18 intent::{Query, QueryError, StructuralQuery},
19 },
20 sql::identifier::{
21 identifier_last_segment, identifiers_tail_match, normalize_identifier_to_scope,
22 rewrite_field_identifiers,
23 },
24 sql::parser::{
25 SqlAggregateCall, SqlAggregateKind, SqlDeleteStatement, SqlExplainMode,
26 SqlExplainStatement, SqlExplainTarget, SqlHavingClause, SqlHavingSymbol,
27 SqlOrderDirection, SqlOrderTerm, SqlProjection, SqlSelectItem, SqlSelectStatement,
28 SqlStatement, SqlTextFunctionCall, parse_sql,
29 },
30 },
31 traits::EntityKind,
32};
33use thiserror::Error as ThisError;
34
/// Opaque lowered SQL command.
///
/// Newtype over the private [`LoweredSqlCommandInner`] enum; the field is
/// private, so the individual command variants are only visible inside this
/// module.
#[derive(Clone, Debug)]
pub struct LoweredSqlCommand(LoweredSqlCommandInner);
45
/// Private payload of [`LoweredSqlCommand`]: one variant per lowered
/// statement kind.
#[derive(Clone, Debug)]
enum LoweredSqlCommandInner {
    /// A lowered `SELECT` or `DELETE` query.
    Query(LoweredSqlQuery),
    /// `EXPLAIN` over a lowered `SELECT` or `DELETE` query.
    Explain {
        mode: SqlExplainMode,
        query: LoweredSqlQuery,
    },
    /// `EXPLAIN` over a `SELECT` that only lowers as a global aggregate.
    ExplainGlobalAggregate {
        mode: SqlExplainMode,
        command: LoweredSqlGlobalAggregateCommand,
    },
    /// Lowered from a `Describe` statement.
    DescribeEntity,
    /// Lowered from a `ShowIndexes` statement.
    ShowIndexesEntity,
    /// Lowered from a `ShowColumns` statement.
    ShowColumnsEntity,
    /// Lowered from a `ShowEntities` statement.
    ShowEntities,
}
62
/// Test-only, entity-typed counterpart of [`LoweredSqlCommand`].
///
/// Produced by `bind_lowered_sql_command`, which binds each lowered query
/// into a typed `Query<E>`.
#[cfg(test)]
#[derive(Debug)]
pub(crate) enum SqlCommand<E: EntityKind> {
    /// A bound `SELECT` or `DELETE` query.
    Query(Query<E>),
    /// `EXPLAIN` over a bound query.
    Explain {
        mode: SqlExplainMode,
        query: Query<E>,
    },
    /// `EXPLAIN` over a bound global aggregate command.
    ExplainGlobalAggregate {
        mode: SqlExplainMode,
        command: SqlGlobalAggregateCommand<E>,
    },
    /// Bound from the lowered `DescribeEntity` command.
    DescribeEntity,
    /// Bound from the lowered `ShowIndexesEntity` command.
    ShowIndexesEntity,
    /// Bound from the lowered `ShowColumnsEntity` command.
    ShowColumnsEntity,
    /// Bound from the lowered `ShowEntities` command.
    ShowEntities,
}
87
88impl LoweredSqlCommand {
89 #[must_use]
90 pub(in crate::db) const fn query(&self) -> Option<&LoweredSqlQuery> {
91 match &self.0 {
92 LoweredSqlCommandInner::Query(query) => Some(query),
93 LoweredSqlCommandInner::Explain { .. }
94 | LoweredSqlCommandInner::ExplainGlobalAggregate { .. }
95 | LoweredSqlCommandInner::DescribeEntity
96 | LoweredSqlCommandInner::ShowIndexesEntity
97 | LoweredSqlCommandInner::ShowColumnsEntity
98 | LoweredSqlCommandInner::ShowEntities => None,
99 }
100 }
101}
102
/// Lowered query shape for the two supported query statements.
#[derive(Clone, Debug)]
pub(crate) enum LoweredSqlQuery {
    /// Lowered `SELECT`, including projection/grouping/having state.
    Select(LoweredSelectShape),
    /// Lowered `DELETE`; only the base predicate/order/limit/offset shape.
    Delete(LoweredBaseQueryShape),
}
114
/// Terminal aggregate of a global (ungrouped) aggregate `SELECT`.
///
/// Field-carrying variants hold the target field name taken from the parsed
/// aggregate call.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum SqlGlobalAggregateTerminal {
    /// `COUNT` call without a field target.
    CountRows,
    /// `COUNT` call with a field target.
    CountField(String),
    /// `SUM` over the named field.
    SumField(String),
    /// `AVG` over the named field.
    AvgField(String),
    /// `MIN` over the named field.
    MinField(String),
    /// `MAX` over the named field.
    MaxField(String),
}
131
/// Lowered global aggregate command: a base query shape plus the single
/// aggregate terminal selected by the projection.
#[derive(Clone, Debug)]
pub(crate) struct LoweredSqlGlobalAggregateCommand {
    /// Predicate/order/limit/offset carried over from the `SELECT`.
    query: LoweredBaseQueryShape,
    /// The one aggregate the statement projects.
    terminal: SqlGlobalAggregateTerminal,
}
145
/// Intermediate classification of a parsed aggregate call, produced by
/// `lower_sql_aggregate_shape`.
enum LoweredSqlAggregateShape {
    /// `COUNT` call without a field.
    CountRows,
    /// `COUNT` call with a field.
    CountField(String),
    /// `SUM`/`AVG`/`MIN`/`MAX` call over a field.
    FieldTarget {
        kind: SqlAggregateKind,
        field: String,
    },
}
161
/// Entity-typed global aggregate command: a bound `Query<E>` plus the
/// aggregate terminal to evaluate over it.
#[derive(Debug)]
pub(crate) struct SqlGlobalAggregateCommand<E: EntityKind> {
    /// Typed query built from the lowered base query shape.
    query: Query<E>,
    /// Aggregate terminal selected by the SQL projection.
    terminal: SqlGlobalAggregateTerminal,
}
173
174impl<E: EntityKind> SqlGlobalAggregateCommand<E> {
175 #[must_use]
177 pub(crate) const fn query(&self) -> &Query<E> {
178 &self.query
179 }
180
181 #[must_use]
183 pub(crate) const fn terminal(&self) -> &SqlGlobalAggregateTerminal {
184 &self.terminal
185 }
186}
187
/// Structural (non-generic) global aggregate command: a `StructuralQuery`
/// plus the aggregate terminal to evaluate over it.
#[derive(Debug)]
pub(crate) struct SqlGlobalAggregateCommandCore {
    /// Structural query built from the lowered base query shape.
    query: StructuralQuery,
    /// Aggregate terminal selected by the SQL projection.
    terminal: SqlGlobalAggregateTerminal,
}
202
203impl SqlGlobalAggregateCommandCore {
204 #[must_use]
206 pub(in crate::db) const fn query(&self) -> &StructuralQuery {
207 &self.query
208 }
209
210 #[must_use]
212 pub(in crate::db) const fn terminal(&self) -> &SqlGlobalAggregateTerminal {
213 &self.terminal
214 }
215}
216
/// Errors produced while preparing, lowering, or binding SQL statements.
#[derive(Debug, ThisError)]
pub(crate) enum SqlLoweringError {
    /// SQL text failed to parse; the parser error is displayed verbatim.
    #[error("{0}")]
    Parse(#[from] crate::db::sql::parser::SqlParseError),

    /// Query building or planning failed; the query error is displayed verbatim.
    #[error("{0}")]
    Query(#[from] QueryError),

    /// The entity named in the SQL statement is not the expected entity.
    #[error("SQL entity '{sql_entity}' does not match requested entity type '{expected_entity}'")]
    EntityMismatch {
        sql_entity: String,
        expected_entity: &'static str,
    },

    /// The `SELECT` projection is not one of the supported shapes.
    #[error(
        "unsupported SQL SELECT projection; supported forms are SELECT *, field lists, or grouped aggregate shapes"
    )]
    UnsupportedSelectProjection,

    /// `SELECT DISTINCT` was used in an unsupported combination.
    #[error("unsupported SQL SELECT DISTINCT")]
    UnsupportedSelectDistinct,

    /// The `GROUP BY` clause and projection do not form a supported shape.
    #[error("unsupported SQL GROUP BY projection shape")]
    UnsupportedSelectGroupBy,

    /// The `HAVING` clause is not supported for this statement shape.
    #[error("unsupported SQL HAVING shape")]
    UnsupportedSelectHaving,

    /// The query surface was given a statement other than `SELECT`, `DELETE`,
    /// or an `EXPLAIN` of either.
    #[error(
        "generated SQL query dispatch requires SELECT, DELETE, EXPLAIN SELECT, or EXPLAIN DELETE"
    )]
    UnsupportedQuerySurfaceStatement,
}
256
257impl SqlLoweringError {
258 fn entity_mismatch(sql_entity: impl Into<String>, expected_entity: &'static str) -> Self {
260 Self::EntityMismatch {
261 sql_entity: sql_entity.into(),
262 expected_entity,
263 }
264 }
265
266 const fn unsupported_select_projection() -> Self {
268 Self::UnsupportedSelectProjection
269 }
270
271 const fn unsupported_select_distinct() -> Self {
273 Self::UnsupportedSelectDistinct
274 }
275
276 const fn unsupported_select_group_by() -> Self {
278 Self::UnsupportedSelectGroupBy
279 }
280
281 const fn unsupported_select_having() -> Self {
283 Self::UnsupportedSelectHaving
284 }
285
286 const fn unsupported_query_surface_statement() -> Self {
288 Self::UnsupportedQuerySurfaceStatement
289 }
290}
291
/// A parsed statement that has passed through `prepare_sql_statement`
/// (entity check and identifier normalization) and is ready for lowering.
#[derive(Clone, Debug)]
pub(crate) struct PreparedSqlStatement {
    /// The prepared statement; private so it can only be built via preparation.
    statement: SqlStatement,
}
306
/// Prepared statement restricted to the query surface: `SELECT`, `DELETE`,
/// and `EXPLAIN` of either. Built by `prepare_query_surface_statement`.
#[derive(Clone, Debug)]
pub(crate) enum PreparedSqlQuerySurfaceStatement {
    /// Prepared `SELECT` statement.
    Select(SqlSelectStatement),
    /// Prepared `DELETE` statement.
    Delete(SqlDeleteStatement),
    /// Prepared `EXPLAIN SELECT` with its explain mode.
    ExplainSelect {
        mode: SqlExplainMode,
        statement: SqlSelectStatement,
    },
    /// Prepared `EXPLAIN DELETE` with its explain mode.
    ExplainDelete {
        mode: SqlExplainMode,
        statement: SqlDeleteStatement,
    },
}
321
/// Coarse classification of a lowered command, as reported by
/// `lowered_sql_command_lane`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum LoweredSqlLaneKind {
    /// Plain `SELECT`/`DELETE` query.
    Query,
    /// Any explain command (regular or global aggregate).
    Explain,
    /// Describe command.
    Describe,
    /// Show-indexes command.
    ShowIndexes,
    /// Show-columns command.
    ShowColumns,
    /// Show-entities command.
    ShowEntities,
}
331
/// Test-only: parse `sql` and compile it into a typed [`SqlCommand`] for `E`.
///
/// # Errors
///
/// Returns a parse error when `sql` does not parse, or whatever
/// `compile_sql_command_from_statement` reports for the parsed statement.
#[cfg(test)]
pub(crate) fn compile_sql_command<E: EntityKind>(
    sql: &str,
    consistency: MissingRowPolicy,
) -> Result<SqlCommand<E>, SqlLoweringError> {
    compile_sql_command_from_statement::<E>(parse_sql(sql)?, consistency)
}
341
/// Test-only: prepare an already-parsed statement against `E`'s model name
/// and compile it into a typed [`SqlCommand`].
///
/// # Errors
///
/// Propagates preparation, lowering, and binding errors.
#[cfg(test)]
pub(crate) fn compile_sql_command_from_statement<E: EntityKind>(
    statement: SqlStatement,
    consistency: MissingRowPolicy,
) -> Result<SqlCommand<E>, SqlLoweringError> {
    prepare_sql_statement(statement, E::MODEL.name()).and_then(|prepared| {
        compile_sql_command_from_prepared_statement::<E>(prepared, consistency)
    })
}
351
/// Test-only: lower a prepared statement using `E`'s primary-key field name,
/// then bind the lowered command to `E`.
///
/// # Errors
///
/// Propagates lowering and binding errors.
#[cfg(test)]
pub(crate) fn compile_sql_command_from_prepared_statement<E: EntityKind>(
    prepared: PreparedSqlStatement,
    consistency: MissingRowPolicy,
) -> Result<SqlCommand<E>, SqlLoweringError> {
    let primary_key_field = E::MODEL.primary_key.name;

    lower_sql_command_from_prepared_statement(prepared, primary_key_field)
        .and_then(|lowered| bind_lowered_sql_command::<E>(lowered, consistency))
}
362
363#[inline(never)]
365pub(crate) fn lower_sql_command_from_prepared_statement(
366 prepared: PreparedSqlStatement,
367 primary_key_field: &str,
368) -> Result<LoweredSqlCommand, SqlLoweringError> {
369 lower_prepared_statement(prepared.statement, primary_key_field)
370}
371
372#[inline(never)]
374pub(crate) fn prepare_query_surface_statement(
375 statement: SqlStatement,
376 expected_entity: &'static str,
377) -> Result<PreparedSqlQuerySurfaceStatement, SqlLoweringError> {
378 match statement {
379 SqlStatement::Select(statement) => Ok(PreparedSqlQuerySurfaceStatement::Select(
380 normalize_select_statement_to_expected_entity(statement, expected_entity),
381 )),
382 SqlStatement::Delete(statement) => Ok(PreparedSqlQuerySurfaceStatement::Delete(
383 prepare_delete_statement(statement, expected_entity)?,
384 )),
385 SqlStatement::Explain(SqlExplainStatement {
386 mode,
387 statement: SqlExplainTarget::Select(select_statement),
388 }) => Ok(PreparedSqlQuerySurfaceStatement::ExplainSelect {
389 mode,
390 statement: normalize_select_statement_to_expected_entity(
391 select_statement,
392 expected_entity,
393 ),
394 }),
395 SqlStatement::Explain(SqlExplainStatement {
396 mode,
397 statement: SqlExplainTarget::Delete(delete_statement),
398 }) => Ok(PreparedSqlQuerySurfaceStatement::ExplainDelete {
399 mode,
400 statement: prepare_delete_statement(delete_statement, expected_entity)?,
401 }),
402 SqlStatement::Describe(_)
403 | SqlStatement::ShowIndexes(_)
404 | SqlStatement::ShowColumns(_)
405 | SqlStatement::ShowEntities(_) => {
406 Err(SqlLoweringError::unsupported_query_surface_statement())
407 }
408 }
409}
410
/// Lower a prepared query-surface statement into a [`LoweredSqlCommand`].
///
/// Thin crate-visible entry point over the private
/// `lower_query_surface_prepared_statement`.
///
/// # Errors
///
/// Propagates any error from lowering the statement.
#[inline(never)]
pub(crate) fn lower_query_surface_command_from_prepared_statement(
    prepared: PreparedSqlQuerySurfaceStatement,
    primary_key_field: &str,
) -> Result<LoweredSqlCommand, SqlLoweringError> {
    lower_query_surface_prepared_statement(prepared, primary_key_field)
}
419
420pub(crate) const fn lowered_sql_command_lane(command: &LoweredSqlCommand) -> LoweredSqlLaneKind {
421 match command.0 {
422 LoweredSqlCommandInner::Query(_) => LoweredSqlLaneKind::Query,
423 LoweredSqlCommandInner::Explain { .. }
424 | LoweredSqlCommandInner::ExplainGlobalAggregate { .. } => LoweredSqlLaneKind::Explain,
425 LoweredSqlCommandInner::DescribeEntity => LoweredSqlLaneKind::Describe,
426 LoweredSqlCommandInner::ShowIndexesEntity => LoweredSqlLaneKind::ShowIndexes,
427 LoweredSqlCommandInner::ShowColumnsEntity => LoweredSqlLaneKind::ShowColumns,
428 LoweredSqlCommandInner::ShowEntities => LoweredSqlLaneKind::ShowEntities,
429 }
430}
431
432#[inline(never)]
434pub(crate) fn render_lowered_sql_explain_plan_or_json(
435 lowered: &LoweredSqlCommand,
436 model: &'static crate::model::entity::EntityModel,
437 consistency: MissingRowPolicy,
438) -> Result<Option<String>, SqlLoweringError> {
439 let LoweredSqlCommandInner::Explain { mode, query } = &lowered.0 else {
440 return Ok(None);
441 };
442
443 let query = bind_lowered_sql_query_structural(model, query.clone(), consistency)?;
444 let rendered = match mode {
445 SqlExplainMode::Plan | SqlExplainMode::Json => {
446 let plan = query.build_plan()?;
447 let explain = plan.explain_with_model(model);
448
449 match mode {
450 SqlExplainMode::Plan => explain.render_text_canonical(),
451 SqlExplainMode::Json => explain.render_json_canonical(),
452 SqlExplainMode::Execution => unreachable!("execution mode handled above"),
453 }
454 }
455 SqlExplainMode::Execution => query.explain_execution_text()?,
456 };
457
458 Ok(Some(rendered))
459}
460
461pub(crate) fn bind_lowered_sql_explain_global_aggregate_structural(
464 lowered: &LoweredSqlCommand,
465 model: &'static crate::model::entity::EntityModel,
466 consistency: MissingRowPolicy,
467) -> Option<(SqlExplainMode, SqlGlobalAggregateCommandCore)> {
468 let LoweredSqlCommandInner::ExplainGlobalAggregate { mode, command } = &lowered.0 else {
469 return None;
470 };
471
472 Some((
473 *mode,
474 bind_lowered_sql_global_aggregate_command_structural(model, command.clone(), consistency),
475 ))
476}
477
/// Test-only: bind a lowered command to entity `E`, producing a typed
/// [`SqlCommand`].
///
/// Query-carrying variants are bound through `bind_lowered_sql_query`;
/// describe/show variants map one-to-one.
///
/// # Errors
///
/// Propagates errors from binding the lowered query.
#[cfg(test)]
pub(crate) fn bind_lowered_sql_command<E: EntityKind>(
    lowered: LoweredSqlCommand,
    consistency: MissingRowPolicy,
) -> Result<SqlCommand<E>, SqlLoweringError> {
    let LoweredSqlCommand(inner) = lowered;

    let command = match inner {
        LoweredSqlCommandInner::Query(query) => {
            SqlCommand::Query(bind_lowered_sql_query::<E>(query, consistency)?)
        }
        LoweredSqlCommandInner::Explain { mode, query } => SqlCommand::Explain {
            mode,
            query: bind_lowered_sql_query::<E>(query, consistency)?,
        },
        LoweredSqlCommandInner::ExplainGlobalAggregate { mode, command } => {
            SqlCommand::ExplainGlobalAggregate {
                mode,
                command: bind_lowered_sql_global_aggregate_command::<E>(command, consistency),
            }
        }
        LoweredSqlCommandInner::DescribeEntity => SqlCommand::DescribeEntity,
        LoweredSqlCommandInner::ShowIndexesEntity => SqlCommand::ShowIndexesEntity,
        LoweredSqlCommandInner::ShowColumnsEntity => SqlCommand::ShowColumnsEntity,
        LoweredSqlCommandInner::ShowEntities => SqlCommand::ShowEntities,
    };

    Ok(command)
}
505
506#[inline(never)]
508pub(crate) fn prepare_sql_statement(
509 statement: SqlStatement,
510 expected_entity: &'static str,
511) -> Result<PreparedSqlStatement, SqlLoweringError> {
512 let statement = prepare_statement(statement, expected_entity)?;
513
514 Ok(PreparedSqlStatement { statement })
515}
516
517pub(crate) fn compile_sql_global_aggregate_command<E: EntityKind>(
519 sql: &str,
520 consistency: MissingRowPolicy,
521) -> Result<SqlGlobalAggregateCommand<E>, SqlLoweringError> {
522 let statement = parse_sql(sql)?;
523 let prepared = prepare_sql_statement(statement, E::MODEL.name())?;
524 compile_sql_global_aggregate_command_from_prepared::<E>(prepared, consistency)
525}
526
527fn compile_sql_global_aggregate_command_from_prepared<E: EntityKind>(
528 prepared: PreparedSqlStatement,
529 consistency: MissingRowPolicy,
530) -> Result<SqlGlobalAggregateCommand<E>, SqlLoweringError> {
531 let SqlStatement::Select(statement) = prepared.statement else {
532 return Err(SqlLoweringError::unsupported_select_projection());
533 };
534
535 Ok(bind_lowered_sql_global_aggregate_command::<E>(
536 lower_global_aggregate_select_shape(statement)?,
537 consistency,
538 ))
539}
540
541#[inline(never)]
542fn prepare_statement(
543 statement: SqlStatement,
544 expected_entity: &'static str,
545) -> Result<SqlStatement, SqlLoweringError> {
546 match statement {
547 SqlStatement::Select(statement) => Ok(SqlStatement::Select(prepare_select_statement(
548 statement,
549 expected_entity,
550 )?)),
551 SqlStatement::Delete(statement) => Ok(SqlStatement::Delete(prepare_delete_statement(
552 statement,
553 expected_entity,
554 )?)),
555 SqlStatement::Explain(statement) => Ok(SqlStatement::Explain(prepare_explain_statement(
556 statement,
557 expected_entity,
558 )?)),
559 SqlStatement::Describe(statement) => {
560 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
561
562 Ok(SqlStatement::Describe(statement))
563 }
564 SqlStatement::ShowIndexes(statement) => {
565 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
566
567 Ok(SqlStatement::ShowIndexes(statement))
568 }
569 SqlStatement::ShowColumns(statement) => {
570 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
571
572 Ok(SqlStatement::ShowColumns(statement))
573 }
574 SqlStatement::ShowEntities(statement) => Ok(SqlStatement::ShowEntities(statement)),
575 }
576}
577
578fn prepare_explain_statement(
579 statement: SqlExplainStatement,
580 expected_entity: &'static str,
581) -> Result<SqlExplainStatement, SqlLoweringError> {
582 let target = match statement.statement {
583 SqlExplainTarget::Select(select_statement) => {
584 SqlExplainTarget::Select(prepare_select_statement(select_statement, expected_entity)?)
585 }
586 SqlExplainTarget::Delete(delete_statement) => {
587 SqlExplainTarget::Delete(prepare_delete_statement(delete_statement, expected_entity)?)
588 }
589 };
590
591 Ok(SqlExplainStatement {
592 mode: statement.mode,
593 statement: target,
594 })
595}
596
597fn prepare_select_statement(
598 statement: SqlSelectStatement,
599 expected_entity: &'static str,
600) -> Result<SqlSelectStatement, SqlLoweringError> {
601 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
602
603 Ok(normalize_select_statement_to_expected_entity(
604 statement,
605 expected_entity,
606 ))
607}
608
609fn normalize_select_statement_to_expected_entity(
610 mut statement: SqlSelectStatement,
611 expected_entity: &'static str,
612) -> SqlSelectStatement {
613 let entity_scope = sql_entity_scope_candidates(statement.entity.as_str(), expected_entity);
616 statement.projection =
617 normalize_projection_identifiers(statement.projection, entity_scope.as_slice());
618 statement.group_by = normalize_identifier_list(statement.group_by, entity_scope.as_slice());
619 statement.predicate = statement
620 .predicate
621 .map(|predicate| adapt_predicate_identifiers_to_scope(predicate, entity_scope.as_slice()));
622 statement.order_by = normalize_order_terms(statement.order_by, entity_scope.as_slice());
623 statement.having = normalize_having_clauses(statement.having, entity_scope.as_slice());
624
625 statement
626}
627
628fn prepare_delete_statement(
629 mut statement: SqlDeleteStatement,
630 expected_entity: &'static str,
631) -> Result<SqlDeleteStatement, SqlLoweringError> {
632 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
633 let entity_scope = sql_entity_scope_candidates(statement.entity.as_str(), expected_entity);
634 statement.predicate = statement
635 .predicate
636 .map(|predicate| adapt_predicate_identifiers_to_scope(predicate, entity_scope.as_slice()));
637 statement.order_by = normalize_order_terms(statement.order_by, entity_scope.as_slice());
638
639 Ok(statement)
640}
641
642#[inline(never)]
643fn lower_prepared_statement(
644 statement: SqlStatement,
645 primary_key_field: &str,
646) -> Result<LoweredSqlCommand, SqlLoweringError> {
647 match statement {
648 SqlStatement::Select(statement) => Ok(LoweredSqlCommand(LoweredSqlCommandInner::Query(
649 LoweredSqlQuery::Select(lower_select_shape(statement, primary_key_field)?),
650 ))),
651 SqlStatement::Delete(statement) => Ok(LoweredSqlCommand(LoweredSqlCommandInner::Query(
652 LoweredSqlQuery::Delete(lower_delete_shape(statement)),
653 ))),
654 SqlStatement::Explain(statement) => lower_explain_prepared(statement, primary_key_field),
655 SqlStatement::Describe(_) => Ok(LoweredSqlCommand(LoweredSqlCommandInner::DescribeEntity)),
656 SqlStatement::ShowIndexes(_) => {
657 Ok(LoweredSqlCommand(LoweredSqlCommandInner::ShowIndexesEntity))
658 }
659 SqlStatement::ShowColumns(_) => {
660 Ok(LoweredSqlCommand(LoweredSqlCommandInner::ShowColumnsEntity))
661 }
662 SqlStatement::ShowEntities(_) => {
663 Ok(LoweredSqlCommand(LoweredSqlCommandInner::ShowEntities))
664 }
665 }
666}
667
668#[inline(never)]
669fn lower_query_surface_prepared_statement(
670 statement: PreparedSqlQuerySurfaceStatement,
671 primary_key_field: &str,
672) -> Result<LoweredSqlCommand, SqlLoweringError> {
673 match statement {
674 PreparedSqlQuerySurfaceStatement::Select(statement) => {
675 Ok(LoweredSqlCommand(LoweredSqlCommandInner::Query(
676 LoweredSqlQuery::Select(lower_select_shape(statement, primary_key_field)?),
677 )))
678 }
679 PreparedSqlQuerySurfaceStatement::Delete(statement) => Ok(LoweredSqlCommand(
680 LoweredSqlCommandInner::Query(LoweredSqlQuery::Delete(lower_delete_shape(statement))),
681 )),
682 PreparedSqlQuerySurfaceStatement::ExplainSelect { mode, statement } => {
683 lower_explain_select_prepared(statement, mode, primary_key_field)
684 }
685 PreparedSqlQuerySurfaceStatement::ExplainDelete { mode, statement } => {
686 Ok(LoweredSqlCommand(LoweredSqlCommandInner::Explain {
687 mode,
688 query: LoweredSqlQuery::Delete(lower_delete_shape(statement)),
689 }))
690 }
691 }
692}
693
694fn lower_explain_prepared(
695 statement: SqlExplainStatement,
696 primary_key_field: &str,
697) -> Result<LoweredSqlCommand, SqlLoweringError> {
698 let mode = statement.mode;
699
700 match statement.statement {
701 SqlExplainTarget::Select(select_statement) => {
702 lower_explain_select_prepared(select_statement, mode, primary_key_field)
703 }
704 SqlExplainTarget::Delete(delete_statement) => {
705 Ok(LoweredSqlCommand(LoweredSqlCommandInner::Explain {
706 mode,
707 query: LoweredSqlQuery::Delete(lower_delete_shape(delete_statement)),
708 }))
709 }
710 }
711}
712
713fn lower_explain_select_prepared(
714 statement: SqlSelectStatement,
715 mode: SqlExplainMode,
716 primary_key_field: &str,
717) -> Result<LoweredSqlCommand, SqlLoweringError> {
718 match lower_select_shape(statement.clone(), primary_key_field) {
719 Ok(query) => Ok(LoweredSqlCommand(LoweredSqlCommandInner::Explain {
720 mode,
721 query: LoweredSqlQuery::Select(query),
722 })),
723 Err(SqlLoweringError::UnsupportedSelectProjection) => {
724 let command = lower_global_aggregate_select_shape(statement)?;
725
726 Ok(LoweredSqlCommand(
727 LoweredSqlCommandInner::ExplainGlobalAggregate { mode, command },
728 ))
729 }
730 Err(err) => Err(err),
731 }
732}
733
734fn lower_global_aggregate_select_shape(
735 statement: SqlSelectStatement,
736) -> Result<LoweredSqlGlobalAggregateCommand, SqlLoweringError> {
737 let SqlSelectStatement {
738 projection,
739 predicate,
740 distinct,
741 group_by,
742 having,
743 order_by,
744 limit,
745 offset,
746 entity: _,
747 } = statement;
748
749 if distinct {
750 return Err(SqlLoweringError::unsupported_select_distinct());
751 }
752 if !group_by.is_empty() {
753 return Err(SqlLoweringError::unsupported_select_group_by());
754 }
755 if !having.is_empty() {
756 return Err(SqlLoweringError::unsupported_select_having());
757 }
758
759 let terminal = lower_global_aggregate_terminal(projection)?;
760
761 Ok(LoweredSqlGlobalAggregateCommand {
762 query: LoweredBaseQueryShape {
763 predicate,
764 order_by,
765 limit,
766 offset,
767 },
768 terminal,
769 })
770}
771
/// A `HAVING` clause after lowering, with its symbol resolved.
#[derive(Clone, Debug)]
enum ResolvedHavingClause {
    /// Comparison against a field symbol (applied via `having_group`).
    GroupField {
        field: String,
        op: crate::db::predicate::CompareOp,
        value: crate::value::Value,
    },
    /// Comparison against an aggregate symbol. `aggregate_index` is the
    /// value returned by `resolve_having_aggregate_index` for the grouped
    /// projection aggregates.
    Aggregate {
        aggregate_index: usize,
        op: crate::db::predicate::CompareOp,
        value: crate::value::Value,
    },
}
792
/// Lowered form of a `SELECT` statement, produced by `lower_select_shape`
/// and consumed by `apply_lowered_select_shape`.
#[derive(Clone, Debug)]
pub(crate) struct LoweredSelectShape {
    /// Explicit scalar field list; `None` for grouped selects and for
    /// projections that are not an item list.
    scalar_projection_fields: Option<Vec<String>>,
    /// Aggregate calls projected by a grouped select; empty when ungrouped.
    grouped_projection_aggregates: Vec<SqlAggregateCall>,
    /// `GROUP BY` field names, in statement order.
    group_by_fields: Vec<String>,
    /// Whether the statement used `DISTINCT`.
    distinct: bool,
    /// Resolved `HAVING` clauses.
    having: Vec<ResolvedHavingClause>,
    /// Optional filter predicate.
    predicate: Option<Predicate>,
    /// `ORDER BY` terms.
    order_by: Vec<crate::db::sql::parser::SqlOrderTerm>,
    /// Optional row limit.
    limit: Option<u32>,
    /// Optional row offset.
    offset: Option<u32>,
}
811
/// Clauses shared by every lowered query: filter, ordering, and window.
///
/// Used directly for `DELETE` and global aggregate commands, and as the
/// final stage when applying a lowered `SELECT`.
#[derive(Clone, Debug)]
pub(crate) struct LoweredBaseQueryShape {
    /// Optional filter predicate.
    predicate: Option<Predicate>,
    /// `ORDER BY` terms.
    order_by: Vec<SqlOrderTerm>,
    /// Optional row limit.
    limit: Option<u32>,
    /// Optional row offset.
    offset: Option<u32>,
}
827
828#[inline(never)]
829fn lower_select_shape(
830 statement: SqlSelectStatement,
831 primary_key_field: &str,
832) -> Result<LoweredSelectShape, SqlLoweringError> {
833 let SqlSelectStatement {
834 projection,
835 predicate,
836 distinct,
837 group_by,
838 having,
839 order_by,
840 limit,
841 offset,
842 entity: _,
843 } = statement;
844 let projection_for_having = projection.clone();
845
846 let (scalar_projection_fields, grouped_projection_aggregates) = if group_by.is_empty() {
848 let scalar_projection_fields =
849 lower_scalar_projection_fields(projection, distinct, primary_key_field)?;
850 (scalar_projection_fields, Vec::new())
851 } else {
852 if distinct {
853 return Err(SqlLoweringError::unsupported_select_distinct());
854 }
855 let grouped_projection_aggregates =
856 grouped_projection_aggregate_calls(&projection, group_by.as_slice())?;
857 (None, grouped_projection_aggregates)
858 };
859
860 let having = lower_having_clauses(
862 having,
863 &projection_for_having,
864 group_by.as_slice(),
865 grouped_projection_aggregates.as_slice(),
866 )?;
867
868 Ok(LoweredSelectShape {
869 scalar_projection_fields,
870 grouped_projection_aggregates,
871 group_by_fields: group_by,
872 distinct,
873 having,
874 predicate,
875 order_by,
876 limit,
877 offset,
878 })
879}
880
881fn lower_scalar_projection_fields(
882 projection: SqlProjection,
883 distinct: bool,
884 primary_key_field: &str,
885) -> Result<Option<Vec<String>>, SqlLoweringError> {
886 let SqlProjection::Items(items) = projection else {
887 if distinct {
888 return Ok(None);
889 }
890
891 return Ok(None);
892 };
893
894 let has_aggregate = items
895 .iter()
896 .any(|item| matches!(item, SqlSelectItem::Aggregate(_)));
897 if has_aggregate {
898 return Err(SqlLoweringError::unsupported_select_projection());
899 }
900
901 let fields = items
902 .into_iter()
903 .map(|item| match item {
904 SqlSelectItem::Field(field) => Ok(field),
905 SqlSelectItem::Aggregate(_) | SqlSelectItem::TextFunction(_) => {
906 Err(SqlLoweringError::unsupported_select_projection())
907 }
908 })
909 .collect::<Result<Vec<_>, _>>()?;
910
911 validate_scalar_distinct_projection(distinct, fields.as_slice(), primary_key_field)?;
912
913 Ok(Some(fields))
914}
915
916fn validate_scalar_distinct_projection(
917 distinct: bool,
918 projection_fields: &[String],
919 primary_key_field: &str,
920) -> Result<(), SqlLoweringError> {
921 if !distinct {
922 return Ok(());
923 }
924
925 if projection_fields.is_empty() {
926 return Ok(());
927 }
928
929 let has_primary_key_field = projection_fields
930 .iter()
931 .any(|field| field == primary_key_field);
932 if !has_primary_key_field {
933 return Err(SqlLoweringError::unsupported_select_distinct());
934 }
935
936 Ok(())
937}
938
939fn lower_having_clauses(
940 having_clauses: Vec<SqlHavingClause>,
941 projection: &SqlProjection,
942 group_by_fields: &[String],
943 grouped_projection_aggregates: &[SqlAggregateCall],
944) -> Result<Vec<ResolvedHavingClause>, SqlLoweringError> {
945 if having_clauses.is_empty() {
946 return Ok(Vec::new());
947 }
948 if group_by_fields.is_empty() {
949 return Err(SqlLoweringError::unsupported_select_having());
950 }
951
952 let projection_aggregates = grouped_projection_aggregate_calls(projection, group_by_fields)
953 .map_err(|_| SqlLoweringError::unsupported_select_having())?;
954 if projection_aggregates.as_slice() != grouped_projection_aggregates {
955 return Err(SqlLoweringError::unsupported_select_having());
956 }
957
958 let mut lowered = Vec::with_capacity(having_clauses.len());
959 for clause in having_clauses {
960 match clause.symbol {
961 SqlHavingSymbol::Field(field) => lowered.push(ResolvedHavingClause::GroupField {
962 field,
963 op: clause.op,
964 value: clause.value,
965 }),
966 SqlHavingSymbol::Aggregate(aggregate) => {
967 let aggregate_index =
968 resolve_having_aggregate_index(&aggregate, grouped_projection_aggregates)?;
969 lowered.push(ResolvedHavingClause::Aggregate {
970 aggregate_index,
971 op: clause.op,
972 value: clause.value,
973 });
974 }
975 }
976 }
977
978 Ok(lowered)
979}
980
981#[inline(never)]
982pub(in crate::db) fn apply_lowered_select_shape(
983 mut query: StructuralQuery,
984 lowered: LoweredSelectShape,
985) -> Result<StructuralQuery, SqlLoweringError> {
986 let LoweredSelectShape {
987 scalar_projection_fields,
988 grouped_projection_aggregates,
989 group_by_fields,
990 distinct,
991 having,
992 predicate,
993 order_by,
994 limit,
995 offset,
996 } = lowered;
997
998 for field in group_by_fields {
1000 query = query.group_by(field)?;
1001 }
1002
1003 if distinct {
1005 query = query.distinct();
1006 }
1007 if let Some(fields) = scalar_projection_fields {
1008 query = query.select_fields(fields);
1009 }
1010 for aggregate in grouped_projection_aggregates {
1011 query = query.aggregate(lower_aggregate_call(aggregate)?);
1012 }
1013
1014 for clause in having {
1016 match clause {
1017 ResolvedHavingClause::GroupField { field, op, value } => {
1018 query = query.having_group(field, op, value)?;
1019 }
1020 ResolvedHavingClause::Aggregate {
1021 aggregate_index,
1022 op,
1023 value,
1024 } => {
1025 query = query.having_aggregate(aggregate_index, op, value)?;
1026 }
1027 }
1028 }
1029
1030 Ok(apply_lowered_base_query_shape(
1032 query,
1033 LoweredBaseQueryShape {
1034 predicate,
1035 order_by,
1036 limit,
1037 offset,
1038 },
1039 ))
1040}
1041
1042fn apply_lowered_base_query_shape(
1043 mut query: StructuralQuery,
1044 lowered: LoweredBaseQueryShape,
1045) -> StructuralQuery {
1046 if let Some(predicate) = lowered.predicate {
1047 query = query.filter(predicate);
1048 }
1049 query = apply_order_terms_structural(query, lowered.order_by);
1050 if let Some(limit) = lowered.limit {
1051 query = query.limit(limit);
1052 }
1053 if let Some(offset) = lowered.offset {
1054 query = query.offset(offset);
1055 }
1056
1057 query
1058}
1059
1060pub(in crate::db) fn bind_lowered_sql_query_structural(
1061 model: &'static crate::model::entity::EntityModel,
1062 lowered: LoweredSqlQuery,
1063 consistency: MissingRowPolicy,
1064) -> Result<StructuralQuery, SqlLoweringError> {
1065 match lowered {
1066 LoweredSqlQuery::Select(select) => {
1067 apply_lowered_select_shape(StructuralQuery::new(model, consistency), select)
1068 }
1069 LoweredSqlQuery::Delete(delete) => Ok(bind_lowered_sql_delete_query_structural(
1070 model,
1071 delete,
1072 consistency,
1073 )),
1074 }
1075}
1076
1077pub(in crate::db) fn bind_lowered_sql_delete_query_structural(
1078 model: &'static crate::model::entity::EntityModel,
1079 delete: LoweredBaseQueryShape,
1080 consistency: MissingRowPolicy,
1081) -> StructuralQuery {
1082 apply_lowered_base_query_shape(StructuralQuery::new(model, consistency).delete(), delete)
1083}
1084
1085pub(in crate::db) fn bind_lowered_sql_query<E: EntityKind>(
1086 lowered: LoweredSqlQuery,
1087 consistency: MissingRowPolicy,
1088) -> Result<Query<E>, SqlLoweringError> {
1089 let structural = bind_lowered_sql_query_structural(E::MODEL, lowered, consistency)?;
1090
1091 Ok(Query::from_inner(structural))
1092}
1093
1094fn bind_lowered_sql_global_aggregate_command<E: EntityKind>(
1095 lowered: LoweredSqlGlobalAggregateCommand,
1096 consistency: MissingRowPolicy,
1097) -> SqlGlobalAggregateCommand<E> {
1098 SqlGlobalAggregateCommand {
1099 query: Query::from_inner(apply_lowered_base_query_shape(
1100 StructuralQuery::new(E::MODEL, consistency),
1101 lowered.query,
1102 )),
1103 terminal: lowered.terminal,
1104 }
1105}
1106
1107fn bind_lowered_sql_global_aggregate_command_structural(
1108 model: &'static crate::model::entity::EntityModel,
1109 lowered: LoweredSqlGlobalAggregateCommand,
1110 consistency: MissingRowPolicy,
1111) -> SqlGlobalAggregateCommandCore {
1112 SqlGlobalAggregateCommandCore {
1113 query: apply_lowered_base_query_shape(
1114 StructuralQuery::new(model, consistency),
1115 lowered.query,
1116 ),
1117 terminal: lowered.terminal,
1118 }
1119}
1120
1121fn lower_global_aggregate_terminal(
1122 projection: SqlProjection,
1123) -> Result<SqlGlobalAggregateTerminal, SqlLoweringError> {
1124 let SqlProjection::Items(items) = projection else {
1125 return Err(SqlLoweringError::unsupported_select_projection());
1126 };
1127 if items.len() != 1 {
1128 return Err(SqlLoweringError::unsupported_select_projection());
1129 }
1130
1131 let Some(SqlSelectItem::Aggregate(aggregate)) = items.into_iter().next() else {
1132 return Err(SqlLoweringError::unsupported_select_projection());
1133 };
1134
1135 match lower_sql_aggregate_shape(aggregate)? {
1136 LoweredSqlAggregateShape::CountRows => Ok(SqlGlobalAggregateTerminal::CountRows),
1137 LoweredSqlAggregateShape::CountField(field) => {
1138 Ok(SqlGlobalAggregateTerminal::CountField(field))
1139 }
1140 LoweredSqlAggregateShape::FieldTarget {
1141 kind: SqlAggregateKind::Sum,
1142 field,
1143 } => Ok(SqlGlobalAggregateTerminal::SumField(field)),
1144 LoweredSqlAggregateShape::FieldTarget {
1145 kind: SqlAggregateKind::Avg,
1146 field,
1147 } => Ok(SqlGlobalAggregateTerminal::AvgField(field)),
1148 LoweredSqlAggregateShape::FieldTarget {
1149 kind: SqlAggregateKind::Min,
1150 field,
1151 } => Ok(SqlGlobalAggregateTerminal::MinField(field)),
1152 LoweredSqlAggregateShape::FieldTarget {
1153 kind: SqlAggregateKind::Max,
1154 field,
1155 } => Ok(SqlGlobalAggregateTerminal::MaxField(field)),
1156 LoweredSqlAggregateShape::FieldTarget {
1157 kind: SqlAggregateKind::Count,
1158 ..
1159 } => Err(SqlLoweringError::unsupported_select_projection()),
1160 }
1161}
1162
1163fn lower_sql_aggregate_shape(
1164 call: SqlAggregateCall,
1165) -> Result<LoweredSqlAggregateShape, SqlLoweringError> {
1166 match (call.kind, call.field) {
1167 (SqlAggregateKind::Count, None) => Ok(LoweredSqlAggregateShape::CountRows),
1168 (SqlAggregateKind::Count, Some(field)) => Ok(LoweredSqlAggregateShape::CountField(field)),
1169 (
1170 kind @ (SqlAggregateKind::Sum
1171 | SqlAggregateKind::Avg
1172 | SqlAggregateKind::Min
1173 | SqlAggregateKind::Max),
1174 Some(field),
1175 ) => Ok(LoweredSqlAggregateShape::FieldTarget { kind, field }),
1176 _ => Err(SqlLoweringError::unsupported_select_projection()),
1177 }
1178}
1179
1180fn grouped_projection_aggregate_calls(
1181 projection: &SqlProjection,
1182 group_by_fields: &[String],
1183) -> Result<Vec<SqlAggregateCall>, SqlLoweringError> {
1184 if group_by_fields.is_empty() {
1185 return Err(SqlLoweringError::unsupported_select_group_by());
1186 }
1187
1188 let SqlProjection::Items(items) = projection else {
1189 return Err(SqlLoweringError::unsupported_select_group_by());
1190 };
1191
1192 let mut projected_group_fields = Vec::<String>::new();
1193 let mut aggregate_calls = Vec::<SqlAggregateCall>::new();
1194 let mut seen_aggregate = false;
1195
1196 for item in items {
1197 match item {
1198 SqlSelectItem::Field(field) => {
1199 if seen_aggregate {
1202 return Err(SqlLoweringError::unsupported_select_group_by());
1203 }
1204 projected_group_fields.push(field.clone());
1205 }
1206 SqlSelectItem::Aggregate(aggregate) => {
1207 seen_aggregate = true;
1208 aggregate_calls.push(aggregate.clone());
1209 }
1210 SqlSelectItem::TextFunction(_) => {
1211 return Err(SqlLoweringError::unsupported_select_group_by());
1212 }
1213 }
1214 }
1215
1216 if aggregate_calls.is_empty() || projected_group_fields.as_slice() != group_by_fields {
1217 return Err(SqlLoweringError::unsupported_select_group_by());
1218 }
1219
1220 Ok(aggregate_calls)
1221}
1222
1223fn lower_aggregate_call(
1224 call: SqlAggregateCall,
1225) -> Result<crate::db::query::builder::AggregateExpr, SqlLoweringError> {
1226 match lower_sql_aggregate_shape(call)? {
1227 LoweredSqlAggregateShape::CountRows => Ok(count()),
1228 LoweredSqlAggregateShape::CountField(field) => Ok(count_by(field)),
1229 LoweredSqlAggregateShape::FieldTarget {
1230 kind: SqlAggregateKind::Sum,
1231 field,
1232 } => Ok(sum(field)),
1233 LoweredSqlAggregateShape::FieldTarget {
1234 kind: SqlAggregateKind::Avg,
1235 field,
1236 } => Ok(avg(field)),
1237 LoweredSqlAggregateShape::FieldTarget {
1238 kind: SqlAggregateKind::Min,
1239 field,
1240 } => Ok(min_by(field)),
1241 LoweredSqlAggregateShape::FieldTarget {
1242 kind: SqlAggregateKind::Max,
1243 field,
1244 } => Ok(max_by(field)),
1245 LoweredSqlAggregateShape::FieldTarget {
1246 kind: SqlAggregateKind::Count,
1247 ..
1248 } => Err(SqlLoweringError::unsupported_select_projection()),
1249 }
1250}
1251
1252fn resolve_having_aggregate_index(
1253 target: &SqlAggregateCall,
1254 grouped_projection_aggregates: &[SqlAggregateCall],
1255) -> Result<usize, SqlLoweringError> {
1256 let mut matched = grouped_projection_aggregates
1257 .iter()
1258 .enumerate()
1259 .filter_map(|(index, aggregate)| (aggregate == target).then_some(index));
1260 let Some(index) = matched.next() else {
1261 return Err(SqlLoweringError::unsupported_select_having());
1262 };
1263 if matched.next().is_some() {
1264 return Err(SqlLoweringError::unsupported_select_having());
1265 }
1266
1267 Ok(index)
1268}
1269
1270fn lower_delete_shape(statement: SqlDeleteStatement) -> LoweredBaseQueryShape {
1271 let SqlDeleteStatement {
1272 predicate,
1273 order_by,
1274 limit,
1275 entity: _,
1276 } = statement;
1277
1278 LoweredBaseQueryShape {
1279 predicate,
1280 order_by,
1281 limit,
1282 offset: None,
1283 }
1284}
1285
1286fn apply_order_terms_structural(
1287 mut query: StructuralQuery,
1288 order_by: Vec<crate::db::sql::parser::SqlOrderTerm>,
1289) -> StructuralQuery {
1290 for term in order_by {
1291 query = match term.direction {
1292 SqlOrderDirection::Asc => query.order_by(term.field),
1293 SqlOrderDirection::Desc => query.order_by_desc(term.field),
1294 };
1295 }
1296
1297 query
1298}
1299
1300fn normalize_having_clauses(
1301 clauses: Vec<SqlHavingClause>,
1302 entity_scope: &[String],
1303) -> Vec<SqlHavingClause> {
1304 clauses
1305 .into_iter()
1306 .map(|clause| SqlHavingClause {
1307 symbol: normalize_having_symbol(clause.symbol, entity_scope),
1308 op: clause.op,
1309 value: clause.value,
1310 })
1311 .collect()
1312}
1313
1314fn normalize_having_symbol(symbol: SqlHavingSymbol, entity_scope: &[String]) -> SqlHavingSymbol {
1315 match symbol {
1316 SqlHavingSymbol::Field(field) => {
1317 SqlHavingSymbol::Field(normalize_identifier_to_scope(field, entity_scope))
1318 }
1319 SqlHavingSymbol::Aggregate(aggregate) => SqlHavingSymbol::Aggregate(
1320 normalize_aggregate_call_identifiers(aggregate, entity_scope),
1321 ),
1322 }
1323}
1324
1325fn normalize_aggregate_call_identifiers(
1326 aggregate: SqlAggregateCall,
1327 entity_scope: &[String],
1328) -> SqlAggregateCall {
1329 SqlAggregateCall {
1330 kind: aggregate.kind,
1331 field: aggregate
1332 .field
1333 .map(|field| normalize_identifier_to_scope(field, entity_scope)),
1334 }
1335}
1336
1337fn sql_entity_scope_candidates(sql_entity: &str, expected_entity: &'static str) -> Vec<String> {
1340 let mut out = Vec::new();
1341 out.push(sql_entity.to_string());
1342 out.push(expected_entity.to_string());
1343
1344 if let Some(last) = identifier_last_segment(sql_entity) {
1345 out.push(last.to_string());
1346 }
1347 if let Some(last) = identifier_last_segment(expected_entity) {
1348 out.push(last.to_string());
1349 }
1350
1351 out
1352}
1353
1354fn normalize_projection_identifiers(
1355 projection: SqlProjection,
1356 entity_scope: &[String],
1357) -> SqlProjection {
1358 match projection {
1359 SqlProjection::All => SqlProjection::All,
1360 SqlProjection::Items(items) => SqlProjection::Items(
1361 items
1362 .into_iter()
1363 .map(|item| match item {
1364 SqlSelectItem::Field(field) => {
1365 SqlSelectItem::Field(normalize_identifier(field, entity_scope))
1366 }
1367 SqlSelectItem::Aggregate(aggregate) => {
1368 SqlSelectItem::Aggregate(SqlAggregateCall {
1369 kind: aggregate.kind,
1370 field: aggregate
1371 .field
1372 .map(|field| normalize_identifier(field, entity_scope)),
1373 })
1374 }
1375 SqlSelectItem::TextFunction(SqlTextFunctionCall {
1376 function,
1377 field,
1378 literal,
1379 literal2,
1380 literal3,
1381 }) => SqlSelectItem::TextFunction(SqlTextFunctionCall {
1382 function,
1383 field: normalize_identifier(field, entity_scope),
1384 literal,
1385 literal2,
1386 literal3,
1387 }),
1388 })
1389 .collect(),
1390 ),
1391 }
1392}
1393
1394fn normalize_order_terms(
1395 terms: Vec<crate::db::sql::parser::SqlOrderTerm>,
1396 entity_scope: &[String],
1397) -> Vec<crate::db::sql::parser::SqlOrderTerm> {
1398 terms
1399 .into_iter()
1400 .map(|term| crate::db::sql::parser::SqlOrderTerm {
1401 field: normalize_identifier(term.field, entity_scope),
1402 direction: term.direction,
1403 })
1404 .collect()
1405}
1406
1407fn normalize_identifier_list(fields: Vec<String>, entity_scope: &[String]) -> Vec<String> {
1408 fields
1409 .into_iter()
1410 .map(|field| normalize_identifier(field, entity_scope))
1411 .collect()
1412}
1413
1414fn adapt_predicate_identifiers_to_scope(
1417 predicate: Predicate,
1418 entity_scope: &[String],
1419) -> Predicate {
1420 rewrite_field_identifiers(predicate, |field| normalize_identifier(field, entity_scope))
1421}
1422
/// Normalizes one identifier against `entity_scope`.
///
/// Thin local wrapper: both arguments are forwarded unchanged to
/// `normalize_identifier_to_scope` and its result is returned as is.
fn normalize_identifier(identifier: String, entity_scope: &[String]) -> String {
    normalize_identifier_to_scope(identifier, entity_scope)
}
1426
1427fn ensure_entity_matches_expected(
1428 sql_entity: &str,
1429 expected_entity: &'static str,
1430) -> Result<(), SqlLoweringError> {
1431 if identifiers_tail_match(sql_entity, expected_entity) {
1432 return Ok(());
1433 }
1434
1435 Err(SqlLoweringError::entity_mismatch(
1436 sql_entity,
1437 expected_entity,
1438 ))
1439}