1#[cfg(test)]
11mod tests;
12
13use crate::{
14 db::{
15 predicate::{MissingRowPolicy, Predicate},
16 query::{
17 builder::aggregate::{avg, count, count_by, max_by, min_by, sum},
18 intent::{Query, QueryError, StructuralQuery},
19 },
20 sql::identifier::{
21 identifier_last_segment, identifiers_tail_match, normalize_identifier_to_scope,
22 rewrite_field_identifiers,
23 },
24 sql::parser::{
25 SqlAggregateCall, SqlAggregateKind, SqlDeleteStatement, SqlExplainMode,
26 SqlExplainStatement, SqlExplainTarget, SqlHavingClause, SqlHavingSymbol,
27 SqlOrderDirection, SqlOrderTerm, SqlProjection, SqlSelectItem, SqlSelectStatement,
28 SqlStatement, parse_sql,
29 },
30 },
31 traits::EntityKind,
32};
33use thiserror::Error as ThisError;
34
35#[derive(Debug)]
47pub(crate) enum SqlCommand<E: EntityKind> {
48 Query(Query<E>),
49 Explain {
50 mode: SqlExplainMode,
51 query: Query<E>,
52 },
53 ExplainGlobalAggregate {
54 mode: SqlExplainMode,
55 command: SqlGlobalAggregateCommand<E>,
56 },
57 DescribeEntity,
58 ShowIndexesEntity,
59 ShowColumnsEntity,
60 ShowEntities,
61}
62
63#[derive(Clone, Debug)]
72pub struct LoweredSqlCommand(LoweredSqlCommandInner);
73
74#[derive(Clone, Debug)]
75enum LoweredSqlCommandInner {
76 Query(LoweredSqlQuery),
77 Explain {
78 mode: SqlExplainMode,
79 query: LoweredSqlQuery,
80 },
81 ExplainGlobalAggregate {
82 mode: SqlExplainMode,
83 command: LoweredSqlGlobalAggregateCommand,
84 },
85 DescribeEntity,
86 ShowIndexesEntity,
87 ShowColumnsEntity,
88 ShowEntities,
89}
90
91impl LoweredSqlCommand {
92 #[must_use]
93 pub(in crate::db) const fn query(&self) -> Option<&LoweredSqlQuery> {
94 match &self.0 {
95 LoweredSqlCommandInner::Query(query) => Some(query),
96 LoweredSqlCommandInner::Explain { .. }
97 | LoweredSqlCommandInner::ExplainGlobalAggregate { .. }
98 | LoweredSqlCommandInner::DescribeEntity
99 | LoweredSqlCommandInner::ShowIndexesEntity
100 | LoweredSqlCommandInner::ShowColumnsEntity
101 | LoweredSqlCommandInner::ShowEntities => None,
102 }
103 }
104}
105
106#[derive(Clone, Debug)]
113pub(crate) enum LoweredSqlQuery {
114 Select(LoweredSelectShape),
115 Delete(LoweredBaseQueryShape),
116}
117
/// The single aggregate computed by a global (non-grouped) aggregate `SELECT`.
///
/// Field-carrying variants hold the scope-normalized field name taken from
/// the aggregate call.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum SqlGlobalAggregateTerminal {
    /// `COUNT` without a field argument.
    CountRows,
    /// `COUNT(field)`.
    CountField(String),
    /// `SUM(field)`.
    SumField(String),
    /// `AVG(field)`.
    AvgField(String),
    /// `MIN(field)`.
    MinField(String),
    /// `MAX(field)`.
    MaxField(String),
}
134
135#[derive(Clone, Debug)]
144pub(crate) struct LoweredSqlGlobalAggregateCommand {
145 query: LoweredBaseQueryShape,
146 terminal: SqlGlobalAggregateTerminal,
147}
148
149#[derive(Debug)]
156pub(crate) struct SqlGlobalAggregateCommand<E: EntityKind> {
157 query: Query<E>,
158 terminal: SqlGlobalAggregateTerminal,
159}
160
161impl<E: EntityKind> SqlGlobalAggregateCommand<E> {
162 #[must_use]
164 pub(crate) const fn query(&self) -> &Query<E> {
165 &self.query
166 }
167
168 #[must_use]
170 pub(crate) const fn terminal(&self) -> &SqlGlobalAggregateTerminal {
171 &self.terminal
172 }
173}
174
175#[derive(Debug)]
185pub(crate) struct StructuralSqlGlobalAggregateCommand {
186 query: StructuralQuery,
187 terminal: SqlGlobalAggregateTerminal,
188}
189
190impl StructuralSqlGlobalAggregateCommand {
191 #[must_use]
193 pub(in crate::db) const fn query(&self) -> &StructuralQuery {
194 &self.query
195 }
196
197 #[must_use]
199 pub(in crate::db) const fn terminal(&self) -> &SqlGlobalAggregateTerminal {
200 &self.terminal
201 }
202}
203
204#[derive(Debug, ThisError)]
211pub(crate) enum SqlLoweringError {
212 #[error("{0}")]
213 Parse(#[from] crate::db::sql::parser::SqlParseError),
214
215 #[error("{0}")]
216 Query(#[from] QueryError),
217
218 #[error("SQL entity '{sql_entity}' does not match requested entity type '{expected_entity}'")]
219 EntityMismatch {
220 sql_entity: String,
221 expected_entity: &'static str,
222 },
223
224 #[error(
225 "unsupported SQL SELECT projection in this release; executable forms are SELECT *, direct field lists, or constrained grouped aggregate projection shapes"
226 )]
227 UnsupportedSelectProjection,
228
229 #[error("unsupported SQL SELECT DISTINCT in this release")]
230 UnsupportedSelectDistinct,
231
232 #[error("unsupported SQL GROUP BY projection shape in this release")]
233 UnsupportedSelectGroupBy,
234
235 #[error("unsupported SQL HAVING shape in this release")]
236 UnsupportedSelectHaving,
237}
238
239#[derive(Clone, Debug)]
250pub(crate) struct PreparedSqlStatement {
251 statement: SqlStatement,
252}
253
/// Coarse dispatch lane for a lowered command, as reported by
/// `lowered_sql_command_lane`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum LoweredSqlLaneKind {
    /// Plain `SELECT` / `DELETE`.
    Query,
    /// Any `EXPLAIN`, including global aggregate explains.
    Explain,
    /// Describe command.
    Describe,
    /// Show-indexes command.
    ShowIndexes,
    /// Show-columns command.
    ShowColumns,
    /// Show-entities command.
    ShowEntities,
}
263
264pub(crate) fn compile_sql_command<E: EntityKind>(
266 sql: &str,
267 consistency: MissingRowPolicy,
268) -> Result<SqlCommand<E>, SqlLoweringError> {
269 let statement = parse_sql(sql)?;
270 compile_sql_command_from_statement::<E>(statement, consistency)
271}
272
273pub(crate) fn compile_sql_command_from_statement<E: EntityKind>(
275 statement: SqlStatement,
276 consistency: MissingRowPolicy,
277) -> Result<SqlCommand<E>, SqlLoweringError> {
278 let prepared = prepare_sql_statement(statement, E::MODEL.entity_name())?;
279 compile_sql_command_from_prepared_statement::<E>(prepared, consistency)
280}
281
282pub(crate) fn compile_sql_command_from_prepared_statement<E: EntityKind>(
284 prepared: PreparedSqlStatement,
285 consistency: MissingRowPolicy,
286) -> Result<SqlCommand<E>, SqlLoweringError> {
287 let lowered = lower_sql_command_from_prepared_statement(prepared, E::MODEL.primary_key.name)?;
288
289 bind_lowered_sql_command::<E>(lowered, consistency)
290}
291
292pub(crate) fn lower_sql_command_from_prepared_statement(
294 prepared: PreparedSqlStatement,
295 primary_key_field: &str,
296) -> Result<LoweredSqlCommand, SqlLoweringError> {
297 lower_prepared_statement(prepared.statement, primary_key_field)
298}
299
300pub(crate) const fn lowered_sql_command_lane(command: &LoweredSqlCommand) -> LoweredSqlLaneKind {
301 match command.0 {
302 LoweredSqlCommandInner::Query(_) => LoweredSqlLaneKind::Query,
303 LoweredSqlCommandInner::Explain { .. }
304 | LoweredSqlCommandInner::ExplainGlobalAggregate { .. } => LoweredSqlLaneKind::Explain,
305 LoweredSqlCommandInner::DescribeEntity => LoweredSqlLaneKind::Describe,
306 LoweredSqlCommandInner::ShowIndexesEntity => LoweredSqlLaneKind::ShowIndexes,
307 LoweredSqlCommandInner::ShowColumnsEntity => LoweredSqlLaneKind::ShowColumns,
308 LoweredSqlCommandInner::ShowEntities => LoweredSqlLaneKind::ShowEntities,
309 }
310}
311
312pub(crate) fn render_lowered_sql_explain_plan_or_json(
313 lowered: &LoweredSqlCommand,
314 model: &'static crate::model::entity::EntityModel,
315 consistency: MissingRowPolicy,
316) -> Result<Option<String>, SqlLoweringError> {
317 let LoweredSqlCommandInner::Explain { mode, query } = &lowered.0 else {
318 return Ok(None);
319 };
320
321 let query = bind_lowered_sql_query_structural(model, query.clone(), consistency)?;
322 let rendered = match mode {
323 SqlExplainMode::Plan | SqlExplainMode::Json => {
324 let plan = query.build_plan()?;
325 let explain = plan.explain_with_model(model);
326
327 match mode {
328 SqlExplainMode::Plan => explain.render_text_canonical(),
329 SqlExplainMode::Json => explain.render_json_canonical(),
330 SqlExplainMode::Execution => unreachable!("execution mode handled above"),
331 }
332 }
333 SqlExplainMode::Execution => query.explain_execution_text()?,
334 };
335
336 Ok(Some(rendered))
337}
338
339pub(crate) fn bind_lowered_sql_explain_global_aggregate_structural(
342 lowered: &LoweredSqlCommand,
343 model: &'static crate::model::entity::EntityModel,
344 consistency: MissingRowPolicy,
345) -> Option<(SqlExplainMode, StructuralSqlGlobalAggregateCommand)> {
346 let LoweredSqlCommandInner::ExplainGlobalAggregate { mode, command } = &lowered.0 else {
347 return None;
348 };
349
350 Some((
351 *mode,
352 bind_lowered_sql_global_aggregate_command_structural(model, command.clone(), consistency),
353 ))
354}
355
356pub(crate) fn bind_lowered_sql_command<E: EntityKind>(
358 lowered: LoweredSqlCommand,
359 consistency: MissingRowPolicy,
360) -> Result<SqlCommand<E>, SqlLoweringError> {
361 match lowered.0 {
362 LoweredSqlCommandInner::Query(query) => Ok(SqlCommand::Query(bind_lowered_sql_query::<E>(
363 query,
364 consistency,
365 )?)),
366 LoweredSqlCommandInner::Explain { mode, query } => Ok(SqlCommand::Explain {
367 mode,
368 query: bind_lowered_sql_query::<E>(query, consistency)?,
369 }),
370 LoweredSqlCommandInner::ExplainGlobalAggregate { mode, command } => {
371 Ok(SqlCommand::ExplainGlobalAggregate {
372 mode,
373 command: bind_lowered_sql_global_aggregate_command::<E>(command, consistency),
374 })
375 }
376 LoweredSqlCommandInner::DescribeEntity => Ok(SqlCommand::DescribeEntity),
377 LoweredSqlCommandInner::ShowIndexesEntity => Ok(SqlCommand::ShowIndexesEntity),
378 LoweredSqlCommandInner::ShowColumnsEntity => Ok(SqlCommand::ShowColumnsEntity),
379 LoweredSqlCommandInner::ShowEntities => Ok(SqlCommand::ShowEntities),
380 }
381}
382
383pub(crate) fn prepare_sql_statement(
385 statement: SqlStatement,
386 expected_entity: &'static str,
387) -> Result<PreparedSqlStatement, SqlLoweringError> {
388 let statement = prepare_statement(statement, expected_entity)?;
389
390 Ok(PreparedSqlStatement { statement })
391}
392
393pub(crate) fn compile_sql_global_aggregate_command<E: EntityKind>(
395 sql: &str,
396 consistency: MissingRowPolicy,
397) -> Result<SqlGlobalAggregateCommand<E>, SqlLoweringError> {
398 let statement = parse_sql(sql)?;
399 let prepared = prepare_sql_statement(statement, E::MODEL.entity_name())?;
400 compile_sql_global_aggregate_command_from_prepared::<E>(prepared, consistency)
401}
402
403fn compile_sql_global_aggregate_command_from_prepared<E: EntityKind>(
404 prepared: PreparedSqlStatement,
405 consistency: MissingRowPolicy,
406) -> Result<SqlGlobalAggregateCommand<E>, SqlLoweringError> {
407 let SqlStatement::Select(statement) = prepared.statement else {
408 return Err(SqlLoweringError::UnsupportedSelectProjection);
409 };
410
411 Ok(bind_lowered_sql_global_aggregate_command::<E>(
412 lower_global_aggregate_select_shape(statement)?,
413 consistency,
414 ))
415}
416
417fn prepare_statement(
418 statement: SqlStatement,
419 expected_entity: &'static str,
420) -> Result<SqlStatement, SqlLoweringError> {
421 match statement {
422 SqlStatement::Select(statement) => Ok(SqlStatement::Select(prepare_select_statement(
423 statement,
424 expected_entity,
425 )?)),
426 SqlStatement::Delete(statement) => Ok(SqlStatement::Delete(prepare_delete_statement(
427 statement,
428 expected_entity,
429 )?)),
430 SqlStatement::Explain(statement) => Ok(SqlStatement::Explain(prepare_explain_statement(
431 statement,
432 expected_entity,
433 )?)),
434 SqlStatement::Describe(statement) => {
435 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
436
437 Ok(SqlStatement::Describe(statement))
438 }
439 SqlStatement::ShowIndexes(statement) => {
440 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
441
442 Ok(SqlStatement::ShowIndexes(statement))
443 }
444 SqlStatement::ShowColumns(statement) => {
445 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
446
447 Ok(SqlStatement::ShowColumns(statement))
448 }
449 SqlStatement::ShowEntities(statement) => Ok(SqlStatement::ShowEntities(statement)),
450 }
451}
452
453fn prepare_explain_statement(
454 statement: SqlExplainStatement,
455 expected_entity: &'static str,
456) -> Result<SqlExplainStatement, SqlLoweringError> {
457 let target = match statement.statement {
458 SqlExplainTarget::Select(select_statement) => {
459 SqlExplainTarget::Select(prepare_select_statement(select_statement, expected_entity)?)
460 }
461 SqlExplainTarget::Delete(delete_statement) => {
462 SqlExplainTarget::Delete(prepare_delete_statement(delete_statement, expected_entity)?)
463 }
464 };
465
466 Ok(SqlExplainStatement {
467 mode: statement.mode,
468 statement: target,
469 })
470}
471
472fn prepare_select_statement(
473 mut statement: SqlSelectStatement,
474 expected_entity: &'static str,
475) -> Result<SqlSelectStatement, SqlLoweringError> {
476 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
477 let entity_scope = sql_entity_scope_candidates(statement.entity.as_str(), expected_entity);
478 statement.projection =
479 normalize_projection_identifiers(statement.projection, entity_scope.as_slice());
480 statement.group_by = normalize_identifier_list(statement.group_by, entity_scope.as_slice());
481 statement.predicate = statement
482 .predicate
483 .map(|predicate| adapt_predicate_identifiers_to_scope(predicate, entity_scope.as_slice()));
484 statement.order_by = normalize_order_terms(statement.order_by, entity_scope.as_slice());
485 statement.having = normalize_having_clauses(statement.having, entity_scope.as_slice());
486
487 Ok(statement)
488}
489
490fn prepare_delete_statement(
491 mut statement: SqlDeleteStatement,
492 expected_entity: &'static str,
493) -> Result<SqlDeleteStatement, SqlLoweringError> {
494 ensure_entity_matches_expected(statement.entity.as_str(), expected_entity)?;
495 let entity_scope = sql_entity_scope_candidates(statement.entity.as_str(), expected_entity);
496 statement.predicate = statement
497 .predicate
498 .map(|predicate| adapt_predicate_identifiers_to_scope(predicate, entity_scope.as_slice()));
499 statement.order_by = normalize_order_terms(statement.order_by, entity_scope.as_slice());
500
501 Ok(statement)
502}
503
504fn lower_prepared_statement(
505 statement: SqlStatement,
506 primary_key_field: &str,
507) -> Result<LoweredSqlCommand, SqlLoweringError> {
508 match statement {
509 SqlStatement::Select(statement) => Ok(LoweredSqlCommand(LoweredSqlCommandInner::Query(
510 LoweredSqlQuery::Select(lower_select_shape(statement, primary_key_field)?),
511 ))),
512 SqlStatement::Delete(statement) => Ok(LoweredSqlCommand(LoweredSqlCommandInner::Query(
513 LoweredSqlQuery::Delete(lower_delete_shape(statement)),
514 ))),
515 SqlStatement::Explain(statement) => lower_explain_prepared(statement, primary_key_field),
516 SqlStatement::Describe(_) => Ok(LoweredSqlCommand(LoweredSqlCommandInner::DescribeEntity)),
517 SqlStatement::ShowIndexes(_) => {
518 Ok(LoweredSqlCommand(LoweredSqlCommandInner::ShowIndexesEntity))
519 }
520 SqlStatement::ShowColumns(_) => {
521 Ok(LoweredSqlCommand(LoweredSqlCommandInner::ShowColumnsEntity))
522 }
523 SqlStatement::ShowEntities(_) => {
524 Ok(LoweredSqlCommand(LoweredSqlCommandInner::ShowEntities))
525 }
526 }
527}
528
529fn lower_explain_prepared(
530 statement: SqlExplainStatement,
531 primary_key_field: &str,
532) -> Result<LoweredSqlCommand, SqlLoweringError> {
533 let mode = statement.mode;
534
535 match statement.statement {
536 SqlExplainTarget::Select(select_statement) => {
537 lower_explain_select_prepared(select_statement, mode, primary_key_field)
538 }
539 SqlExplainTarget::Delete(delete_statement) => {
540 Ok(LoweredSqlCommand(LoweredSqlCommandInner::Explain {
541 mode,
542 query: LoweredSqlQuery::Delete(lower_delete_shape(delete_statement)),
543 }))
544 }
545 }
546}
547
548fn lower_explain_select_prepared(
549 statement: SqlSelectStatement,
550 mode: SqlExplainMode,
551 primary_key_field: &str,
552) -> Result<LoweredSqlCommand, SqlLoweringError> {
553 match lower_select_shape(statement.clone(), primary_key_field) {
554 Ok(query) => Ok(LoweredSqlCommand(LoweredSqlCommandInner::Explain {
555 mode,
556 query: LoweredSqlQuery::Select(query),
557 })),
558 Err(SqlLoweringError::UnsupportedSelectProjection) => {
559 let command = lower_global_aggregate_select_shape(statement)?;
560
561 Ok(LoweredSqlCommand(
562 LoweredSqlCommandInner::ExplainGlobalAggregate { mode, command },
563 ))
564 }
565 Err(err) => Err(err),
566 }
567}
568
569fn lower_global_aggregate_select_shape(
570 statement: SqlSelectStatement,
571) -> Result<LoweredSqlGlobalAggregateCommand, SqlLoweringError> {
572 let SqlSelectStatement {
573 projection,
574 predicate,
575 distinct,
576 group_by,
577 having,
578 order_by,
579 limit,
580 offset,
581 entity: _,
582 } = statement;
583
584 if distinct {
585 return Err(SqlLoweringError::UnsupportedSelectDistinct);
586 }
587 if !group_by.is_empty() {
588 return Err(SqlLoweringError::UnsupportedSelectGroupBy);
589 }
590 if !having.is_empty() {
591 return Err(SqlLoweringError::UnsupportedSelectHaving);
592 }
593
594 let terminal = lower_global_aggregate_terminal(projection)?;
595
596 Ok(LoweredSqlGlobalAggregateCommand {
597 query: LoweredBaseQueryShape {
598 predicate,
599 order_by,
600 limit,
601 offset,
602 },
603 terminal,
604 })
605}
606
607#[derive(Clone, Debug)]
615enum ResolvedHavingClause {
616 GroupField {
617 field: String,
618 op: crate::db::predicate::CompareOp,
619 value: crate::value::Value,
620 },
621 Aggregate {
622 aggregate_index: usize,
623 op: crate::db::predicate::CompareOp,
624 value: crate::value::Value,
625 },
626}
627
628#[derive(Clone, Debug)]
635pub(crate) struct LoweredSelectShape {
636 scalar_projection_fields: Option<Vec<String>>,
637 grouped_projection_aggregates: Vec<SqlAggregateCall>,
638 group_by_fields: Vec<String>,
639 distinct: bool,
640 having: Vec<ResolvedHavingClause>,
641 predicate: Option<Predicate>,
642 order_by: Vec<crate::db::sql::parser::SqlOrderTerm>,
643 limit: Option<u32>,
644 offset: Option<u32>,
645}
646
647#[derive(Clone, Debug)]
656pub(crate) struct LoweredBaseQueryShape {
657 predicate: Option<Predicate>,
658 order_by: Vec<SqlOrderTerm>,
659 limit: Option<u32>,
660 offset: Option<u32>,
661}
662
663fn lower_select_shape(
664 statement: SqlSelectStatement,
665 primary_key_field: &str,
666) -> Result<LoweredSelectShape, SqlLoweringError> {
667 let SqlSelectStatement {
668 projection,
669 predicate,
670 distinct,
671 group_by,
672 having,
673 order_by,
674 limit,
675 offset,
676 entity: _,
677 } = statement;
678 let projection_for_having = projection.clone();
679
680 let (scalar_projection_fields, grouped_projection_aggregates) = if group_by.is_empty() {
682 let scalar_projection_fields =
683 lower_scalar_projection_fields(projection, distinct, primary_key_field)?;
684 (scalar_projection_fields, Vec::new())
685 } else {
686 if distinct {
687 return Err(SqlLoweringError::UnsupportedSelectDistinct);
688 }
689 let grouped_projection_aggregates =
690 grouped_projection_aggregate_calls(&projection, group_by.as_slice())?;
691 (None, grouped_projection_aggregates)
692 };
693
694 let having = lower_having_clauses(
696 having,
697 &projection_for_having,
698 group_by.as_slice(),
699 grouped_projection_aggregates.as_slice(),
700 )?;
701
702 Ok(LoweredSelectShape {
703 scalar_projection_fields,
704 grouped_projection_aggregates,
705 group_by_fields: group_by,
706 distinct,
707 having,
708 predicate,
709 order_by,
710 limit,
711 offset,
712 })
713}
714
715fn lower_scalar_projection_fields(
716 projection: SqlProjection,
717 distinct: bool,
718 primary_key_field: &str,
719) -> Result<Option<Vec<String>>, SqlLoweringError> {
720 let SqlProjection::Items(items) = projection else {
721 if distinct {
722 return Ok(None);
723 }
724
725 return Ok(None);
726 };
727
728 let has_aggregate = items
729 .iter()
730 .any(|item| matches!(item, SqlSelectItem::Aggregate(_)));
731 if has_aggregate {
732 return Err(SqlLoweringError::UnsupportedSelectProjection);
733 }
734
735 let fields = items
736 .into_iter()
737 .map(|item| match item {
738 SqlSelectItem::Field(field) => Ok(field),
739 SqlSelectItem::Aggregate(_) => Err(SqlLoweringError::UnsupportedSelectProjection),
740 })
741 .collect::<Result<Vec<_>, _>>()?;
742
743 validate_scalar_distinct_projection(distinct, fields.as_slice(), primary_key_field)?;
744
745 Ok(Some(fields))
746}
747
748fn validate_scalar_distinct_projection(
749 distinct: bool,
750 projection_fields: &[String],
751 primary_key_field: &str,
752) -> Result<(), SqlLoweringError> {
753 if !distinct {
754 return Ok(());
755 }
756
757 if projection_fields.is_empty() {
758 return Ok(());
759 }
760
761 let has_primary_key_field = projection_fields
762 .iter()
763 .any(|field| field == primary_key_field);
764 if !has_primary_key_field {
765 return Err(SqlLoweringError::UnsupportedSelectDistinct);
766 }
767
768 Ok(())
769}
770
771fn lower_having_clauses(
772 having_clauses: Vec<SqlHavingClause>,
773 projection: &SqlProjection,
774 group_by_fields: &[String],
775 grouped_projection_aggregates: &[SqlAggregateCall],
776) -> Result<Vec<ResolvedHavingClause>, SqlLoweringError> {
777 if having_clauses.is_empty() {
778 return Ok(Vec::new());
779 }
780 if group_by_fields.is_empty() {
781 return Err(SqlLoweringError::UnsupportedSelectHaving);
782 }
783
784 let projection_aggregates = grouped_projection_aggregate_calls(projection, group_by_fields)
785 .map_err(|_| SqlLoweringError::UnsupportedSelectHaving)?;
786 if projection_aggregates.as_slice() != grouped_projection_aggregates {
787 return Err(SqlLoweringError::UnsupportedSelectHaving);
788 }
789
790 let mut lowered = Vec::with_capacity(having_clauses.len());
791 for clause in having_clauses {
792 match clause.symbol {
793 SqlHavingSymbol::Field(field) => lowered.push(ResolvedHavingClause::GroupField {
794 field,
795 op: clause.op,
796 value: clause.value,
797 }),
798 SqlHavingSymbol::Aggregate(aggregate) => {
799 let aggregate_index =
800 resolve_having_aggregate_index(&aggregate, grouped_projection_aggregates)?;
801 lowered.push(ResolvedHavingClause::Aggregate {
802 aggregate_index,
803 op: clause.op,
804 value: clause.value,
805 });
806 }
807 }
808 }
809
810 Ok(lowered)
811}
812
813fn apply_lowered_select_shape(
814 mut query: StructuralQuery,
815 lowered: LoweredSelectShape,
816) -> Result<StructuralQuery, SqlLoweringError> {
817 for field in lowered.group_by_fields {
819 query = query.group_by(field)?;
820 }
821
822 if lowered.distinct {
824 query = query.distinct();
825 }
826 if let Some(fields) = lowered.scalar_projection_fields {
827 query = query.select_fields(fields);
828 }
829 for aggregate in lowered.grouped_projection_aggregates {
830 query = query.aggregate(lower_aggregate_call(aggregate)?);
831 }
832
833 for clause in lowered.having {
835 match clause {
836 ResolvedHavingClause::GroupField { field, op, value } => {
837 query = query.having_group(field, op, value)?;
838 }
839 ResolvedHavingClause::Aggregate {
840 aggregate_index,
841 op,
842 value,
843 } => {
844 query = query.having_aggregate(aggregate_index, op, value)?;
845 }
846 }
847 }
848
849 if let Some(predicate) = lowered.predicate {
851 query = query.filter(predicate);
852 }
853 query = apply_order_terms_structural(query, lowered.order_by);
854 if let Some(limit) = lowered.limit {
855 query = query.limit(limit);
856 }
857 if let Some(offset) = lowered.offset {
858 query = query.offset(offset);
859 }
860
861 Ok(query)
862}
863
864fn apply_lowered_base_query_shape(
865 mut query: StructuralQuery,
866 lowered: LoweredBaseQueryShape,
867) -> StructuralQuery {
868 if let Some(predicate) = lowered.predicate {
869 query = query.filter(predicate);
870 }
871 query = apply_order_terms_structural(query, lowered.order_by);
872 if let Some(limit) = lowered.limit {
873 query = query.limit(limit);
874 }
875 if let Some(offset) = lowered.offset {
876 query = query.offset(offset);
877 }
878
879 query
880}
881
882pub(in crate::db) fn bind_lowered_sql_query_structural(
883 model: &'static crate::model::entity::EntityModel,
884 lowered: LoweredSqlQuery,
885 consistency: MissingRowPolicy,
886) -> Result<StructuralQuery, SqlLoweringError> {
887 match lowered {
888 LoweredSqlQuery::Select(select) => {
889 apply_lowered_select_shape(StructuralQuery::new(model, consistency), select)
890 }
891 LoweredSqlQuery::Delete(delete) => Ok(apply_lowered_base_query_shape(
892 StructuralQuery::new(model, consistency).delete(),
893 delete,
894 )),
895 }
896}
897
898pub(in crate::db) fn bind_lowered_sql_query<E: EntityKind>(
899 lowered: LoweredSqlQuery,
900 consistency: MissingRowPolicy,
901) -> Result<Query<E>, SqlLoweringError> {
902 let structural = bind_lowered_sql_query_structural(E::MODEL, lowered, consistency)?;
903
904 Ok(Query::from_inner(structural))
905}
906
907fn bind_lowered_sql_global_aggregate_command<E: EntityKind>(
908 lowered: LoweredSqlGlobalAggregateCommand,
909 consistency: MissingRowPolicy,
910) -> SqlGlobalAggregateCommand<E> {
911 SqlGlobalAggregateCommand {
912 query: Query::from_inner(apply_lowered_base_query_shape(
913 StructuralQuery::new(E::MODEL, consistency),
914 lowered.query,
915 )),
916 terminal: lowered.terminal,
917 }
918}
919
920fn bind_lowered_sql_global_aggregate_command_structural(
921 model: &'static crate::model::entity::EntityModel,
922 lowered: LoweredSqlGlobalAggregateCommand,
923 consistency: MissingRowPolicy,
924) -> StructuralSqlGlobalAggregateCommand {
925 StructuralSqlGlobalAggregateCommand {
926 query: apply_lowered_base_query_shape(
927 StructuralQuery::new(model, consistency),
928 lowered.query,
929 ),
930 terminal: lowered.terminal,
931 }
932}
933
934fn lower_global_aggregate_terminal(
935 projection: SqlProjection,
936) -> Result<SqlGlobalAggregateTerminal, SqlLoweringError> {
937 let SqlProjection::Items(items) = projection else {
938 return Err(SqlLoweringError::UnsupportedSelectProjection);
939 };
940 if items.len() != 1 {
941 return Err(SqlLoweringError::UnsupportedSelectProjection);
942 }
943
944 let Some(SqlSelectItem::Aggregate(aggregate)) = items.into_iter().next() else {
945 return Err(SqlLoweringError::UnsupportedSelectProjection);
946 };
947
948 match (aggregate.kind, aggregate.field) {
949 (SqlAggregateKind::Count, None) => Ok(SqlGlobalAggregateTerminal::CountRows),
950 (SqlAggregateKind::Count, Some(field)) => Ok(SqlGlobalAggregateTerminal::CountField(field)),
951 (SqlAggregateKind::Sum, Some(field)) => Ok(SqlGlobalAggregateTerminal::SumField(field)),
952 (SqlAggregateKind::Avg, Some(field)) => Ok(SqlGlobalAggregateTerminal::AvgField(field)),
953 (SqlAggregateKind::Min, Some(field)) => Ok(SqlGlobalAggregateTerminal::MinField(field)),
954 (SqlAggregateKind::Max, Some(field)) => Ok(SqlGlobalAggregateTerminal::MaxField(field)),
955 _ => Err(SqlLoweringError::UnsupportedSelectProjection),
956 }
957}
958
959fn grouped_projection_aggregate_calls(
960 projection: &SqlProjection,
961 group_by_fields: &[String],
962) -> Result<Vec<SqlAggregateCall>, SqlLoweringError> {
963 if group_by_fields.is_empty() {
964 return Err(SqlLoweringError::UnsupportedSelectGroupBy);
965 }
966
967 let SqlProjection::Items(items) = projection else {
968 return Err(SqlLoweringError::UnsupportedSelectGroupBy);
969 };
970
971 let mut projected_group_fields = Vec::<String>::new();
972 let mut aggregate_calls = Vec::<SqlAggregateCall>::new();
973 let mut seen_aggregate = false;
974
975 for item in items {
976 match item {
977 SqlSelectItem::Field(field) => {
978 if seen_aggregate {
981 return Err(SqlLoweringError::UnsupportedSelectGroupBy);
982 }
983 projected_group_fields.push(field.clone());
984 }
985 SqlSelectItem::Aggregate(aggregate) => {
986 seen_aggregate = true;
987 aggregate_calls.push(aggregate.clone());
988 }
989 }
990 }
991
992 if aggregate_calls.is_empty() || projected_group_fields.as_slice() != group_by_fields {
993 return Err(SqlLoweringError::UnsupportedSelectGroupBy);
994 }
995
996 Ok(aggregate_calls)
997}
998
999fn lower_aggregate_call(
1000 call: SqlAggregateCall,
1001) -> Result<crate::db::query::builder::AggregateExpr, SqlLoweringError> {
1002 match (call.kind, call.field) {
1003 (SqlAggregateKind::Count, None) => Ok(count()),
1004 (SqlAggregateKind::Count, Some(field)) => Ok(count_by(field)),
1005 (SqlAggregateKind::Sum, Some(field)) => Ok(sum(field)),
1006 (SqlAggregateKind::Avg, Some(field)) => Ok(avg(field)),
1007 (SqlAggregateKind::Min, Some(field)) => Ok(min_by(field)),
1008 (SqlAggregateKind::Max, Some(field)) => Ok(max_by(field)),
1009 _ => Err(SqlLoweringError::UnsupportedSelectProjection),
1010 }
1011}
1012
1013fn resolve_having_aggregate_index(
1014 target: &SqlAggregateCall,
1015 grouped_projection_aggregates: &[SqlAggregateCall],
1016) -> Result<usize, SqlLoweringError> {
1017 let mut matched = grouped_projection_aggregates
1018 .iter()
1019 .enumerate()
1020 .filter_map(|(index, aggregate)| (aggregate == target).then_some(index));
1021 let Some(index) = matched.next() else {
1022 return Err(SqlLoweringError::UnsupportedSelectHaving);
1023 };
1024 if matched.next().is_some() {
1025 return Err(SqlLoweringError::UnsupportedSelectHaving);
1026 }
1027
1028 Ok(index)
1029}
1030
1031fn lower_delete_shape(statement: SqlDeleteStatement) -> LoweredBaseQueryShape {
1032 let SqlDeleteStatement {
1033 predicate,
1034 order_by,
1035 limit,
1036 entity: _,
1037 } = statement;
1038
1039 LoweredBaseQueryShape {
1040 predicate,
1041 order_by,
1042 limit,
1043 offset: None,
1044 }
1045}
1046
1047fn apply_order_terms_structural(
1048 mut query: StructuralQuery,
1049 order_by: Vec<crate::db::sql::parser::SqlOrderTerm>,
1050) -> StructuralQuery {
1051 for term in order_by {
1052 query = match term.direction {
1053 SqlOrderDirection::Asc => query.order_by(term.field),
1054 SqlOrderDirection::Desc => query.order_by_desc(term.field),
1055 };
1056 }
1057
1058 query
1059}
1060
1061fn normalize_having_clauses(
1062 clauses: Vec<SqlHavingClause>,
1063 entity_scope: &[String],
1064) -> Vec<SqlHavingClause> {
1065 clauses
1066 .into_iter()
1067 .map(|clause| SqlHavingClause {
1068 symbol: normalize_having_symbol(clause.symbol, entity_scope),
1069 op: clause.op,
1070 value: clause.value,
1071 })
1072 .collect()
1073}
1074
1075fn normalize_having_symbol(symbol: SqlHavingSymbol, entity_scope: &[String]) -> SqlHavingSymbol {
1076 match symbol {
1077 SqlHavingSymbol::Field(field) => {
1078 SqlHavingSymbol::Field(normalize_identifier_to_scope(field, entity_scope))
1079 }
1080 SqlHavingSymbol::Aggregate(aggregate) => SqlHavingSymbol::Aggregate(
1081 normalize_aggregate_call_identifiers(aggregate, entity_scope),
1082 ),
1083 }
1084}
1085
1086fn normalize_aggregate_call_identifiers(
1087 aggregate: SqlAggregateCall,
1088 entity_scope: &[String],
1089) -> SqlAggregateCall {
1090 SqlAggregateCall {
1091 kind: aggregate.kind,
1092 field: aggregate
1093 .field
1094 .map(|field| normalize_identifier_to_scope(field, entity_scope)),
1095 }
1096}
1097
1098fn sql_entity_scope_candidates(sql_entity: &str, expected_entity: &'static str) -> Vec<String> {
1101 let mut out = Vec::new();
1102 out.push(sql_entity.to_string());
1103 out.push(expected_entity.to_string());
1104
1105 if let Some(last) = identifier_last_segment(sql_entity) {
1106 out.push(last.to_string());
1107 }
1108 if let Some(last) = identifier_last_segment(expected_entity) {
1109 out.push(last.to_string());
1110 }
1111
1112 out
1113}
1114
1115fn normalize_projection_identifiers(
1116 projection: SqlProjection,
1117 entity_scope: &[String],
1118) -> SqlProjection {
1119 match projection {
1120 SqlProjection::All => SqlProjection::All,
1121 SqlProjection::Items(items) => SqlProjection::Items(
1122 items
1123 .into_iter()
1124 .map(|item| match item {
1125 SqlSelectItem::Field(field) => {
1126 SqlSelectItem::Field(normalize_identifier(field, entity_scope))
1127 }
1128 SqlSelectItem::Aggregate(aggregate) => {
1129 SqlSelectItem::Aggregate(SqlAggregateCall {
1130 kind: aggregate.kind,
1131 field: aggregate
1132 .field
1133 .map(|field| normalize_identifier(field, entity_scope)),
1134 })
1135 }
1136 })
1137 .collect(),
1138 ),
1139 }
1140}
1141
1142fn normalize_order_terms(
1143 terms: Vec<crate::db::sql::parser::SqlOrderTerm>,
1144 entity_scope: &[String],
1145) -> Vec<crate::db::sql::parser::SqlOrderTerm> {
1146 terms
1147 .into_iter()
1148 .map(|term| crate::db::sql::parser::SqlOrderTerm {
1149 field: normalize_identifier(term.field, entity_scope),
1150 direction: term.direction,
1151 })
1152 .collect()
1153}
1154
1155fn normalize_identifier_list(fields: Vec<String>, entity_scope: &[String]) -> Vec<String> {
1156 fields
1157 .into_iter()
1158 .map(|field| normalize_identifier(field, entity_scope))
1159 .collect()
1160}
1161
1162fn adapt_predicate_identifiers_to_scope(
1165 predicate: Predicate,
1166 entity_scope: &[String],
1167) -> Predicate {
1168 rewrite_field_identifiers(predicate, |field| normalize_identifier(field, entity_scope))
1169}
1170
1171fn normalize_identifier(identifier: String, entity_scope: &[String]) -> String {
1172 normalize_identifier_to_scope(identifier, entity_scope)
1173}
1174
1175fn ensure_entity_matches_expected(
1176 sql_entity: &str,
1177 expected_entity: &'static str,
1178) -> Result<(), SqlLoweringError> {
1179 if identifiers_tail_match(sql_entity, expected_entity) {
1180 return Ok(());
1181 }
1182
1183 Err(SqlLoweringError::EntityMismatch {
1184 sql_entity: sql_entity.to_string(),
1185 expected_entity,
1186 })
1187}