1use crate::error::{HematiteError, Result};
4use crate::parser::types::{LiteralValue, SqlTypeName};
5
/// A single SQL statement: transaction control, introspection
/// (`EXPLAIN` / `DESCRIBE` / `SHOW ...`), queries, DML, and DDL.
#[derive(Debug, Clone)]
pub enum Statement {
    /// `BEGIN`
    Begin,
    /// `COMMIT`
    Commit,
    /// `ROLLBACK`
    Rollback,
    /// `SAVEPOINT <name>`
    Savepoint(String),
    /// `ROLLBACK TO SAVEPOINT <name>`
    RollbackToSavepoint(String),
    /// `RELEASE SAVEPOINT <name>`
    ReleaseSavepoint(String),
    /// `EXPLAIN <statement>`
    Explain(ExplainStatement),
    /// `DESCRIBE <table>`
    Describe(DescribeStatement),
    /// `SHOW TABLES`
    ShowTables,
    /// `SHOW VIEWS`
    ShowViews,
    /// `SHOW INDEXES`, optionally followed by `FROM <table>`.
    ShowIndexes(Option<String>),
    /// `SHOW TRIGGERS`, optionally followed by `FROM <table>`.
    ShowTriggers(Option<String>),
    /// `SHOW CREATE TABLE <table>`
    ShowCreateTable(String),
    /// `SHOW CREATE VIEW <view>`
    ShowCreateView(String),
    /// A query.
    Select(SelectStatement),
    /// A query rendered with an `INTO` target table; treated as schema-mutating.
    SelectInto(SelectIntoStatement),
    /// `UPDATE ... SET ...`
    Update(UpdateStatement),
    /// `INSERT INTO ...`
    Insert(InsertStatement),
    /// `DELETE ...`
    Delete(DeleteStatement),
    /// Table creation.
    Create(CreateStatement),
    /// View creation.
    CreateView(CreateViewStatement),
    /// Trigger creation.
    CreateTrigger(CreateTriggerStatement),
    /// Index creation.
    CreateIndex(CreateIndexStatement),
    /// Table alteration.
    Alter(AlterStatement),
    /// Table removal.
    Drop(DropStatement),
    /// View removal.
    DropView(DropViewStatement),
    /// Trigger removal.
    DropTrigger(DropTriggerStatement),
    /// Index removal.
    DropIndex(DropIndexStatement),
}
37
/// A query paired with the name of the table its SQL is rendered `INTO`.
#[derive(Debug, Clone)]
pub struct SelectIntoStatement {
    /// Name of the target table.
    pub table: String,
    /// The query itself.
    pub query: SelectStatement,
}
43
/// Payload of [`Statement::Explain`]: the statement being explained.
#[derive(Debug, Clone)]
pub struct ExplainStatement {
    /// The wrapped statement (boxed because `Statement` is recursive here).
    pub statement: Box<Statement>,
}
48
/// Payload of [`Statement::Describe`].
#[derive(Debug, Clone)]
pub struct DescribeStatement {
    /// Name of the table to describe.
    pub table: String,
}
53
/// A `SELECT` query, including its optional CTEs and a trailing set operation.
#[derive(Debug, Clone)]
pub struct SelectStatement {
    /// Common table expressions attached to the query (empty when none).
    pub with_clause: Vec<CommonTableExpression>,
    /// `DISTINCT` flag.
    pub distinct: bool,
    /// Items of the select list.
    pub columns: Vec<SelectItem>,
    /// Optional aliases for the select list.
    // NOTE(review): presumably parallel to `columns` (one entry per item) — confirm.
    pub column_aliases: Vec<Option<String>>,
    /// Source the query reads from.
    pub from: TableReference,
    /// Optional `WHERE` filter.
    pub where_clause: Option<WhereClause>,
    /// `GROUP BY` expressions (empty when none).
    pub group_by: Vec<Expression>,
    /// Optional `HAVING` filter; shares the `WhereClause` representation.
    pub having_clause: Option<WhereClause>,
    /// `ORDER BY` items (empty when none).
    pub order_by: Vec<OrderByItem>,
    /// Optional `LIMIT` value.
    pub limit: Option<usize>,
    /// Optional `OFFSET` value.
    pub offset: Option<usize>,
    /// Optional set operation combining this query with another one.
    pub set_operation: Option<SetOperation>,
}
69
/// One named entry of a query's `with_clause`.
#[derive(Debug, Clone)]
pub struct CommonTableExpression {
    /// Name the expression is bound to.
    pub name: String,
    /// Whether the expression was flagged as recursive.
    pub recursive: bool,
    /// The query that defines the expression.
    pub query: Box<SelectStatement>,
}
76
/// A set operator together with the right-hand query it applies to.
#[derive(Debug, Clone)]
pub struct SetOperation {
    /// Which set operator combines the two queries.
    pub operator: SetOperator,
    /// Right-hand operand query.
    pub right: Box<SelectStatement>,
}
82
/// Set operators usable in a [`SetOperation`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SetOperator {
    /// `UNION`
    Union,
    /// `UNION ALL`
    UnionAll,
    /// `INTERSECT`
    Intersect,
    /// `EXCEPT`
    Except,
}
90
/// One entry of a select list.
#[derive(Debug, Clone)]
pub enum SelectItem {
    /// `*`
    Wildcard,
    /// A column referenced by name.
    Column(String),
    /// An arbitrary expression.
    Expression(Expression),
    /// `COUNT(*)`
    CountAll,
    /// An aggregate function applied to a named column.
    Aggregate {
        function: AggregateFunction,
        column: String,
    },
    /// A window function with its `OVER (...)` specification.
    Window {
        function: WindowFunction,
        window: WindowSpec,
    },
}
106
/// Aggregate functions available to select items, expressions, and windows.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunction {
    Count,
    Sum,
    Avg,
    Min,
    Max,
}
115
/// Window specification attached to a [`SelectItem::Window`].
#[derive(Debug, Clone)]
pub struct WindowSpec {
    /// Partitioning expressions (empty when none).
    pub partition_by: Vec<Expression>,
    /// Ordering items inside the window (empty when none).
    pub order_by: Vec<OrderByItem>,
}
121
/// Functions that can be evaluated over a window.
#[derive(Debug, Clone)]
pub enum WindowFunction {
    RowNumber,
    Rank,
    DenseRank,
    /// An aggregate function evaluated over the window.
    Aggregate {
        function: AggregateFunction,
        target: AggregateTarget,
    },
}
132
/// A row source: a table, a derived subquery, or a join of two row sources.
#[derive(Debug, Clone)]
pub enum TableReference {
    /// A table by name.
    // NOTE(review): the second field is presumably the optional alias — confirm.
    Table(String, Option<String>),
    /// A subquery used as a row source under an alias.
    Derived {
        subquery: Box<SelectStatement>,
        alias: String,
    },
    /// Join of two sources without a join condition.
    CrossJoin(Box<TableReference>, Box<TableReference>),
    /// Inner join of `left` and `right` on condition `on`.
    InnerJoin {
        left: Box<TableReference>,
        right: Box<TableReference>,
        on: Condition,
    },
    /// Left join of `left` and `right` on condition `on`.
    LeftJoin {
        left: Box<TableReference>,
        right: Box<TableReference>,
        on: Condition,
    },
    /// Right join of `left` and `right` on condition `on`.
    RightJoin {
        left: Box<TableReference>,
        right: Box<TableReference>,
        on: Condition,
    },
    /// Full outer join of `left` and `right` on condition `on`.
    FullOuterJoin {
        left: Box<TableReference>,
        right: Box<TableReference>,
        on: Condition,
    },
}
162
/// A table name together with the optional alias it is bound under.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TableBinding {
    /// Name of the table.
    pub table_name: String,
    /// Alias, if one was given.
    pub alias: Option<String>,
}
168
/// Description of a row source and the columns it exposes.
///
/// Compiled only in test builds (`#[cfg(test)]`).
#[derive(Debug, Clone)]
#[cfg(test)]
struct SourceBinding {
    /// Name of the source.
    source_name: String,
    /// Alias, if one was given.
    alias: Option<String>,
    /// Column names of the source.
    columns: Vec<String>,
    /// Whether the source carries a hidden rowid.
    // NOTE(review): exact rowid semantics are defined by the users of this type — confirm.
    has_hidden_rowid: bool,
}
177
/// A filter made of one or more conditions.
///
/// When a statement is rendered to SQL the conditions are joined with ` AND `.
#[derive(Debug, Clone)]
pub struct WhereClause {
    /// The individual conditions.
    pub conditions: Vec<Condition>,
}
182
/// One ordering key: a column name and a direction.
#[derive(Debug, Clone)]
pub struct OrderByItem {
    /// Column to order by.
    pub column: String,
    /// Ascending or descending.
    pub direction: SortDirection,
}
188
/// Direction of an [`OrderByItem`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SortDirection {
    /// Ascending.
    Asc,
    /// Descending.
    Desc,
}
194
/// A boolean condition as used in filters and join `on` clauses.
///
/// Each `is_not` flag marks the negated form of its variant.
#[derive(Debug, Clone)]
pub enum Condition {
    /// `left <operator> right`
    Comparison {
        left: Expression,
        operator: ComparisonOperator,
        right: Expression,
    },
    /// Membership of `expr` in an explicit list of values.
    InList {
        expr: Expression,
        values: Vec<Expression>,
        is_not: bool,
    },
    /// Membership of `expr` in the result of a subquery.
    InSubquery {
        expr: Expression,
        subquery: Box<SelectStatement>,
        is_not: bool,
    },
    /// Range test of `expr` against `lower` and `upper`.
    Between {
        expr: Expression,
        lower: Expression,
        upper: Expression,
        is_not: bool,
    },
    /// Pattern match of `expr` against `pattern`.
    Like {
        expr: Expression,
        pattern: Expression,
        is_not: bool,
    },
    /// Existence test on a subquery.
    Exists {
        subquery: Box<SelectStatement>,
        is_not: bool,
    },
    /// Null test on `expr`.
    NullCheck {
        expr: Expression,
        is_not: bool,
    },
    /// Negation of the inner condition.
    Not(Box<Condition>),
    /// Two conditions combined with a logical operator.
    Logical {
        left: Box<Condition>,
        operator: LogicalOperator,
        right: Box<Condition>,
    },
}
238
/// A value-producing expression.
///
/// The predicate-shaped variants (`Comparison`, `InList`, `InSubquery`,
/// `Between`, `Like`, `Exists`, `NullCheck`, `Logical`) mirror the variants of
/// [`Condition`] with boxed operands.
#[derive(Debug, Clone)]
pub enum Expression {
    /// A column referenced by name.
    Column(String),
    /// A literal value.
    Literal(LiteralValue),
    /// An interval literal: its textual value plus a qualifier.
    IntervalLiteral {
        value: String,
        qualifier: IntervalQualifier,
    },
    /// A parameter placeholder identified by a zero-based index
    /// (`Statement::parameter_count` reports the highest index plus one).
    Parameter(usize),
    /// A subquery used as a scalar value.
    ScalarSubquery(Box<SelectStatement>),
    /// Conversion of `expr` to `target_type`.
    Cast {
        expr: Box<Expression>,
        target_type: SqlTypeName,
    },
    /// A `CASE` expression: `WHEN` branches and an optional `ELSE`.
    Case {
        branches: Vec<CaseWhenClause>,
        else_expr: Option<Box<Expression>>,
    },
    /// Call of a built-in scalar function.
    ScalarFunctionCall {
        function: ScalarFunction,
        args: Vec<Expression>,
    },
    /// Call of an aggregate function.
    AggregateCall {
        function: AggregateFunction,
        target: AggregateTarget,
    },
    /// Arithmetic negation.
    UnaryMinus(Box<Expression>),
    /// Logical negation.
    UnaryNot(Box<Expression>),
    /// Arithmetic on two operands.
    Binary {
        left: Box<Expression>,
        operator: ArithmeticOperator,
        right: Box<Expression>,
    },
    /// Comparison of two operands.
    Comparison {
        left: Box<Expression>,
        operator: ComparisonOperator,
        right: Box<Expression>,
    },
    /// Membership of `expr` in an explicit list of values.
    InList {
        expr: Box<Expression>,
        values: Vec<Expression>,
        is_not: bool,
    },
    /// Membership of `expr` in the result of a subquery.
    InSubquery {
        expr: Box<Expression>,
        subquery: Box<SelectStatement>,
        is_not: bool,
    },
    /// Range test of `expr` against `lower` and `upper`.
    Between {
        expr: Box<Expression>,
        lower: Box<Expression>,
        upper: Box<Expression>,
        is_not: bool,
    },
    /// Pattern match of `expr` against `pattern`.
    Like {
        expr: Box<Expression>,
        pattern: Box<Expression>,
        is_not: bool,
    },
    /// Existence test on a subquery.
    Exists {
        subquery: Box<SelectStatement>,
        is_not: bool,
    },
    /// Null test on `expr`.
    NullCheck {
        expr: Box<Expression>,
        is_not: bool,
    },
    /// Two operands combined with a logical operator.
    Logical {
        left: Box<Expression>,
        operator: LogicalOperator,
        right: Box<Expression>,
    },
}
312
/// Qualifier of an [`Expression::IntervalLiteral`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IntervalQualifier {
    YearToMonth,
    DayToSecond,
}
318
/// Operators usable in [`Expression::Binary`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ArithmeticOperator {
    Add,
    Subtract,
    Multiply,
    Divide,
    Modulo,
}
327
/// Built-in scalar functions callable through [`Expression::ScalarFunctionCall`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScalarFunction {
    // Null handling.
    Coalesce,
    IfNull,
    NullIf,
    // Date and time.
    DateFn,
    TimeFn,
    Year,
    Month,
    Day,
    Hour,
    Minute,
    Second,
    TimeToSec,
    SecToTime,
    UnixTimestamp,
    // Strings and numbers.
    Lower,
    Upper,
    Length,
    OctetLength,
    BitLength,
    Trim,
    Abs,
    Round,
    Concat,
    ConcatWs,
    Substring,
    LeftFn,
    RightFn,
    Greatest,
    Least,
    Replace,
    Repeat,
    Reverse,
    Locate,
    Hex,
    Unhex,
    Ceil,
    Floor,
    Power,
}
369
/// One `WHEN <condition> THEN <result>` branch of an [`Expression::Case`].
#[derive(Debug, Clone)]
pub struct CaseWhenClause {
    /// Expression tested by the branch.
    pub condition: Expression,
    /// Expression produced when the branch is taken.
    pub result: Expression,
}
375
/// What an aggregate function is applied to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AggregateTarget {
    /// Every row (the `*` form).
    All,
    /// A named column.
    Column(String),
}
381
/// Operators usable in comparison conditions and expressions.
#[derive(Debug, Clone)]
pub enum ComparisonOperator {
    Equal,
    NotEqual,
    LessThan,
    LessThanOrEqual,
    GreaterThan,
    GreaterThanOrEqual,
}
391
/// Operators that combine two conditions or boolean expressions.
#[derive(Debug, Clone)]
pub enum LogicalOperator {
    And,
    Or,
}
397
/// An `INSERT INTO` statement.
#[derive(Debug, Clone)]
pub struct InsertStatement {
    /// Target table.
    pub table: String,
    /// Target column names, rendered as a parenthesised list.
    pub columns: Vec<String>,
    /// Where the inserted rows come from.
    pub source: InsertSource,
    /// Assignments of the optional `ON DUPLICATE KEY UPDATE` clause.
    pub on_duplicate: Option<Vec<UpdateAssignment>>,
}
405
/// Row source of an [`InsertStatement`].
#[derive(Debug, Clone)]
pub enum InsertSource {
    /// Explicit rows: one inner vector of expressions per row.
    Values(Vec<Vec<Expression>>),
    /// Rows produced by a query.
    Select(Box<SelectStatement>),
}
411
/// A single `column = value` assignment.
#[derive(Debug, Clone)]
pub struct UpdateAssignment {
    /// Column being assigned.
    pub column: String,
    /// Expression providing the new value.
    pub value: Expression,
}
417
/// An `UPDATE` statement.
#[derive(Debug, Clone)]
pub struct UpdateStatement {
    /// Table being updated.
    pub table: String,
    /// Optional binding name for the target; `target_binding_name` falls back
    /// to `table` when this is `None`.
    pub target_binding: Option<String>,
    /// Optional explicit row source; `source()` falls back to the bare table
    /// when this is `None`.
    pub source: Option<TableReference>,
    /// The `SET` assignments.
    pub assignments: Vec<UpdateAssignment>,
    /// Optional `WHERE` filter.
    pub where_clause: Option<WhereClause>,
}
426
/// A `DELETE` statement.
#[derive(Debug, Clone)]
pub struct DeleteStatement {
    /// Table rows are deleted from.
    pub table: String,
    /// Optional binding name for the target; `target_binding_name` falls back
    /// to `table` when this is `None`.
    pub target_binding: Option<String>,
    /// Optional explicit row source; `source()` falls back to the bare table
    /// when this is `None`.
    pub source: Option<TableReference>,
    /// Optional `WHERE` filter.
    pub where_clause: Option<WhereClause>,
}
434
435impl UpdateStatement {
436 pub(crate) fn source(&self) -> TableReference {
437 self.source
438 .clone()
439 .unwrap_or_else(|| TableReference::Table(self.table.clone(), None))
440 }
441
442 pub(crate) fn target_binding_name(&self) -> &str {
443 self.target_binding.as_deref().unwrap_or(&self.table)
444 }
445}
446
447impl DeleteStatement {
448 pub(crate) fn source(&self) -> TableReference {
449 self.source
450 .clone()
451 .unwrap_or_else(|| TableReference::Table(self.table.clone(), None))
452 }
453
454 pub(crate) fn target_binding_name(&self) -> &str {
455 self.target_binding.as_deref().unwrap_or(&self.table)
456 }
457}
458
/// A table-creation statement.
#[derive(Debug, Clone)]
pub struct CreateStatement {
    /// Name of the table to create.
    pub table: String,
    /// Column definitions.
    pub columns: Vec<ColumnDefinition>,
    /// Table-level constraints.
    pub constraints: Vec<TableConstraint>,
    /// `IF NOT EXISTS` flag.
    pub if_not_exists: bool,
}
466
/// A view-creation statement.
#[derive(Debug, Clone)]
pub struct CreateViewStatement {
    /// Name of the view to create.
    pub view: String,
    /// `IF NOT EXISTS` flag.
    pub if_not_exists: bool,
    /// Query that defines the view.
    pub query: SelectStatement,
}
473
/// A trigger-creation statement.
#[derive(Debug, Clone)]
pub struct CreateTriggerStatement {
    /// Name of the trigger to create.
    pub trigger: String,
    /// Table the trigger is attached to.
    pub table: String,
    /// Event associated with the trigger.
    pub event: TriggerEvent,
    /// Statement that forms the trigger body.
    pub body: Box<Statement>,
}
481
/// Kind of table event a trigger is associated with.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TriggerEvent {
    Insert,
    Update,
    Delete,
}
488
/// An index-creation statement.
#[derive(Debug, Clone)]
pub struct CreateIndexStatement {
    /// Name of the index to create.
    pub index_name: String,
    /// Table the index belongs to.
    pub table: String,
    /// Indexed column names.
    pub columns: Vec<String>,
    /// Whether the index is declared unique.
    pub unique: bool,
    /// `IF NOT EXISTS` flag.
    pub if_not_exists: bool,
}
497
/// A table-removal statement.
#[derive(Debug, Clone)]
pub struct DropStatement {
    /// Name of the table to drop.
    pub table: String,
    /// `IF EXISTS` flag.
    pub if_exists: bool,
}
503
/// A view-removal statement.
#[derive(Debug, Clone)]
pub struct DropViewStatement {
    /// Name of the view to drop.
    pub view: String,
    /// `IF EXISTS` flag.
    pub if_exists: bool,
}
509
/// A trigger-removal statement.
#[derive(Debug, Clone)]
pub struct DropTriggerStatement {
    /// Name of the trigger to drop.
    pub trigger: String,
    /// `IF EXISTS` flag.
    pub if_exists: bool,
}
515
/// An index-removal statement.
#[derive(Debug, Clone)]
pub struct DropIndexStatement {
    /// Name of the index to drop.
    pub index_name: String,
    /// Table the index belongs to.
    pub table: String,
    /// `IF EXISTS` flag.
    pub if_exists: bool,
}
522
/// A table-alteration statement: the target table and one operation.
#[derive(Debug, Clone)]
pub struct AlterStatement {
    /// Table being altered.
    pub table: String,
    /// The change to apply.
    pub operation: AlterOperation,
}
528
/// The individual change carried by an [`AlterStatement`].
#[derive(Debug, Clone)]
pub enum AlterOperation {
    /// Rename the table to the given name.
    RenameTo(String),
    /// Rename a column.
    RenameColumn {
        old_name: String,
        new_name: String,
    },
    /// Add a column with the given definition.
    AddColumn(ColumnDefinition),
    /// Add a table-level constraint.
    AddConstraint(TableConstraint),
    /// Drop the named column.
    DropColumn(String),
    /// Drop the named constraint.
    DropConstraint(String),
    /// Set a column's default to a literal value.
    AlterColumnSetDefault {
        column_name: String,
        default_value: LiteralValue,
    },
    /// Remove a column's default.
    AlterColumnDropDefault {
        column_name: String,
    },
    /// Mark a column as `NOT NULL`.
    AlterColumnSetNotNull {
        column_name: String,
    },
    /// Remove a column's `NOT NULL` marking.
    AlterColumnDropNotNull {
        column_name: String,
    },
}
554
/// Definition of a single table column, including its column-level constraints.
#[derive(Debug, Clone)]
pub struct ColumnDefinition {
    /// Column name.
    pub name: String,
    /// Declared SQL type.
    pub data_type: SqlTypeName,
    /// Optional character set name.
    pub character_set: Option<String>,
    /// Optional collation name.
    pub collation: Option<String>,
    /// Whether the column accepts nulls.
    pub nullable: bool,
    /// Whether the column is declared as primary key.
    pub primary_key: bool,
    /// Whether the column is declared auto-incrementing.
    pub auto_increment: bool,
    /// Whether the column is declared unique.
    pub unique: bool,
    /// Optional literal default value.
    pub default_value: Option<LiteralValue>,
    /// Optional column-level check constraint.
    pub check_constraint: Option<CheckConstraintDefinition>,
    /// Optional column-level foreign-key reference.
    pub references: Option<ForeignKeyDefinition>,
}
569
/// A check constraint, with its expression kept as SQL text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CheckConstraintDefinition {
    /// Optional constraint name.
    pub name: Option<String>,
    /// The checked expression, stored as a SQL string.
    pub expression_sql: String,
}
575
/// A foreign-key constraint.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ForeignKeyDefinition {
    /// Optional constraint name.
    pub name: Option<String>,
    /// Referencing columns.
    pub columns: Vec<String>,
    /// Table being referenced.
    pub referenced_table: String,
    /// Referenced columns in `referenced_table`.
    pub referenced_columns: Vec<String>,
    /// Action for deletes on the referenced side.
    pub on_delete: ForeignKeyAction,
    /// Action for updates on the referenced side.
    pub on_update: ForeignKeyAction,
}
585
/// A uniqueness constraint over one or more columns.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UniqueConstraintDefinition {
    /// Optional constraint name.
    pub name: Option<String>,
    /// Columns covered by the constraint.
    pub columns: Vec<String>,
}
591
/// A table-level constraint.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TableConstraint {
    /// A check constraint.
    Check(CheckConstraintDefinition),
    /// A uniqueness constraint.
    Unique(UniqueConstraintDefinition),
    /// A foreign-key constraint.
    ForeignKey(ForeignKeyDefinition),
}
598
/// Referential action attached to a foreign key's `on_delete` / `on_update`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ForeignKeyAction {
    Restrict,
    Cascade,
    SetNull,
}
605
606impl Statement {
607 pub(crate) fn to_sql(&self) -> String {
608 match self {
609 Statement::Select(select) => select.to_sql(),
610 Statement::SelectInto(select_into) => select_into.to_sql(),
611 Statement::Insert(insert) => {
612 let mut sql = format!(
613 "INSERT INTO {} ({}) ",
614 insert.table,
615 insert.columns.join(", ")
616 );
617 match &insert.source {
618 InsertSource::Values(rows) => {
619 let rows_sql = rows
620 .iter()
621 .map(|row| {
622 format!(
623 "({})",
624 row.iter()
625 .map(Expression::to_sql)
626 .collect::<Vec<_>>()
627 .join(", ")
628 )
629 })
630 .collect::<Vec<_>>()
631 .join(", ");
632 sql.push_str(&format!("VALUES {rows_sql}"));
633 }
634 InsertSource::Select(select) => {
635 sql.push_str(&select.to_sql());
636 }
637 }
638 if let Some(assignments) = &insert.on_duplicate {
639 sql.push_str(" ON DUPLICATE KEY UPDATE ");
640 sql.push_str(
641 &assignments
642 .iter()
643 .map(|assignment| {
644 format!("{} = {}", assignment.column, assignment.value.to_sql())
645 })
646 .collect::<Vec<_>>()
647 .join(", "),
648 );
649 }
650 sql
651 }
652 Statement::Update(update) => {
653 let source = update.source();
654 let mut sql = format!(
655 "UPDATE {} SET {}",
656 source.to_sql(),
657 update
658 .assignments
659 .iter()
660 .map(|assignment| {
661 format!("{} = {}", assignment.column, assignment.value.to_sql())
662 })
663 .collect::<Vec<_>>()
664 .join(", ")
665 );
666 if let Some(where_clause) = &update.where_clause {
667 sql.push_str(&format!(
668 " WHERE {}",
669 where_clause
670 .conditions
671 .iter()
672 .map(Condition::to_sql)
673 .collect::<Vec<_>>()
674 .join(" AND ")
675 ));
676 }
677 sql
678 }
679 Statement::Delete(delete) => {
680 let mut sql = match delete.source.as_ref() {
681 Some(source) => format!(
682 "DELETE {} FROM {}",
683 delete.target_binding_name(),
684 source.to_sql()
685 ),
686 None => format!("DELETE FROM {}", delete.table),
687 };
688 if let Some(where_clause) = &delete.where_clause {
689 sql.push_str(&format!(
690 " WHERE {}",
691 where_clause
692 .conditions
693 .iter()
694 .map(Condition::to_sql)
695 .collect::<Vec<_>>()
696 .join(" AND ")
697 ));
698 }
699 sql
700 }
701 Statement::Explain(explain) => format!("EXPLAIN {}", explain.statement.to_sql()),
702 Statement::Describe(describe) => format!("DESCRIBE {}", describe.table),
703 Statement::ShowTables => "SHOW TABLES".to_string(),
704 Statement::ShowViews => "SHOW VIEWS".to_string(),
705 Statement::ShowIndexes(table) => match table {
706 Some(table) => format!("SHOW INDEXES FROM {table}"),
707 None => "SHOW INDEXES".to_string(),
708 },
709 Statement::ShowTriggers(table) => match table {
710 Some(table) => format!("SHOW TRIGGERS FROM {table}"),
711 None => "SHOW TRIGGERS".to_string(),
712 },
713 Statement::ShowCreateTable(table) => format!("SHOW CREATE TABLE {table}"),
714 Statement::ShowCreateView(view) => format!("SHOW CREATE VIEW {view}"),
715 Statement::Begin => "BEGIN".to_string(),
716 Statement::Commit => "COMMIT".to_string(),
717 Statement::Rollback => "ROLLBACK".to_string(),
718 Statement::Savepoint(name) => format!("SAVEPOINT {name}"),
719 Statement::RollbackToSavepoint(name) => format!("ROLLBACK TO SAVEPOINT {name}"),
720 Statement::ReleaseSavepoint(name) => format!("RELEASE SAVEPOINT {name}"),
721 Statement::Create(_)
722 | Statement::CreateView(_)
723 | Statement::CreateTrigger(_)
724 | Statement::CreateIndex(_)
725 | Statement::Alter(_)
726 | Statement::Drop(_)
727 | Statement::DropView(_)
728 | Statement::DropTrigger(_)
729 | Statement::DropIndex(_) => format!("{self:?}"),
730 }
731 }
732
733 #[cfg(test)]
734 pub fn validate(&self, catalog: &crate::catalog::Schema) -> Result<()> {
735 match self {
736 Statement::Begin
737 | Statement::Commit
738 | Statement::Rollback
739 | Statement::Savepoint(_)
740 | Statement::RollbackToSavepoint(_)
741 | Statement::ReleaseSavepoint(_)
742 | Statement::ShowTables
743 | Statement::ShowViews
744 | Statement::ShowIndexes(_)
745 | Statement::ShowTriggers(_)
746 | Statement::ShowCreateTable(_)
747 | Statement::ShowCreateView(_) => Ok(()),
748 Statement::Explain(explain) => explain.statement.validate(catalog),
749 Statement::Describe(describe) => {
750 if catalog.get_table_by_name(&describe.table).is_none() {
751 Err(HematiteError::ParseError(format!(
752 "Table '{}' does not exist",
753 describe.table
754 )))
755 } else {
756 Ok(())
757 }
758 }
759 Statement::Select(select) => select.validate(catalog),
760 Statement::SelectInto(select_into) => {
761 if catalog.get_table_by_name(&select_into.table).is_some()
762 || catalog.view(&select_into.table).is_some()
763 {
764 Err(HematiteError::ParseError(format!(
765 "Table '{}' already exists",
766 select_into.table
767 )))
768 } else {
769 select_into.query.validate(catalog)
770 }
771 }
772 Statement::Update(update) => update.validate(catalog),
773 Statement::Insert(insert) => insert.validate(catalog),
774 Statement::Delete(delete) => delete.validate(catalog),
775 Statement::Create(create) => create.validate(catalog),
776 Statement::CreateView(_create_view) => Ok(()),
777 Statement::CreateTrigger(_create_trigger) => Ok(()),
778 Statement::CreateIndex(create_index) => create_index.validate(catalog),
779 Statement::Alter(alter) => alter.validate(catalog),
780 Statement::Drop(drop) => drop.validate(catalog),
781 Statement::DropView(_drop_view) => Ok(()),
782 Statement::DropTrigger(_drop_trigger) => Ok(()),
783 Statement::DropIndex(drop_index) => drop_index.validate(catalog),
784 }
785 }
786
787 fn into_select(self) -> Result<SelectStatement> {
788 match self {
789 Statement::Select(select) => Ok(select),
790 _ => Err(HematiteError::InternalError(
791 "expected SELECT statement while binding a subquery".to_string(),
792 )),
793 }
794 }
795
796 pub fn is_read_only(&self) -> bool {
797 matches!(
798 self,
799 Statement::Explain(_)
800 | Statement::Describe(_)
801 | Statement::ShowTables
802 | Statement::ShowViews
803 | Statement::ShowIndexes(_)
804 | Statement::ShowTriggers(_)
805 | Statement::ShowCreateTable(_)
806 | Statement::ShowCreateView(_)
807 | Statement::Select(_)
808 )
809 }
810
811 pub fn mutates_schema(&self) -> bool {
812 matches!(
813 self,
814 Statement::Create(_)
815 | Statement::SelectInto(_)
816 | Statement::CreateView(_)
817 | Statement::CreateTrigger(_)
818 | Statement::CreateIndex(_)
819 | Statement::Alter(_)
820 | Statement::Drop(_)
821 | Statement::DropView(_)
822 | Statement::DropTrigger(_)
823 | Statement::DropIndex(_)
824 )
825 }
826
827 pub fn parameter_count(&self) -> usize {
828 let mut max_index: Option<usize> = None;
829 self.visit_parameters(&mut |index| {
830 max_index = Some(max_index.map_or(index, |current| current.max(index)));
831 });
832 max_index.map_or(0, |index| index + 1)
833 }
834
835 pub fn bind_parameters(&self, parameters: &[LiteralValue]) -> Result<Statement> {
836 self.bind_statement(parameters)
837 }
838
839 fn visit_parameters<F>(&self, f: &mut F)
840 where
841 F: FnMut(usize),
842 {
843 match self {
844 Statement::Begin
845 | Statement::Commit
846 | Statement::Rollback
847 | Statement::Savepoint(_)
848 | Statement::RollbackToSavepoint(_)
849 | Statement::ReleaseSavepoint(_)
850 | Statement::Describe(_)
851 | Statement::ShowTables
852 | Statement::ShowViews
853 | Statement::ShowIndexes(_)
854 | Statement::ShowTriggers(_)
855 | Statement::ShowCreateTable(_)
856 | Statement::ShowCreateView(_) => {}
857 Statement::Explain(explain) => explain.statement.visit_parameters(f),
858 Statement::Select(select) => {
859 select.visit_parameters(f);
860 }
861 Statement::SelectInto(select_into) => {
862 select_into.query.visit_parameters(f);
863 }
864 Statement::Update(update) => {
865 for assignment in &update.assignments {
866 assignment.value.visit_parameters(f);
867 }
868 if let Some(where_clause) = &update.where_clause {
869 where_clause.visit_parameters(f);
870 }
871 }
872 Statement::Insert(insert) => {
873 match &insert.source {
874 InsertSource::Values(rows) => {
875 for row in rows {
876 for expr in row {
877 expr.visit_parameters(f);
878 }
879 }
880 }
881 InsertSource::Select(select) => {
882 select.visit_parameters(f);
883 }
884 }
885 if let Some(assignments) = &insert.on_duplicate {
886 for assignment in assignments {
887 assignment.value.visit_parameters(f);
888 }
889 }
890 }
891 Statement::Delete(delete) => {
892 if let Some(where_clause) = &delete.where_clause {
893 where_clause.visit_parameters(f);
894 }
895 }
896 Statement::Create(_)
897 | Statement::CreateView(_)
898 | Statement::CreateTrigger(_)
899 | Statement::CreateIndex(_)
900 | Statement::Alter(_)
901 | Statement::Drop(_)
902 | Statement::DropView(_)
903 | Statement::DropTrigger(_)
904 | Statement::DropIndex(_) => {}
905 }
906 }
907
908 fn bind_statement(&self, parameters: &[LiteralValue]) -> Result<Statement> {
909 match self {
910 Statement::Begin => Ok(Statement::Begin),
911 Statement::Commit => Ok(Statement::Commit),
912 Statement::Rollback => Ok(Statement::Rollback),
913 Statement::Savepoint(name) => Ok(Statement::Savepoint(name.clone())),
914 Statement::RollbackToSavepoint(name) => {
915 Ok(Statement::RollbackToSavepoint(name.clone()))
916 }
917 Statement::ReleaseSavepoint(name) => Ok(Statement::ReleaseSavepoint(name.clone())),
918 Statement::Explain(explain) => Ok(Statement::Explain(ExplainStatement {
919 statement: Box::new(explain.statement.bind_parameters(parameters)?),
920 })),
921 Statement::Describe(describe) => Ok(Statement::Describe(describe.clone())),
922 Statement::ShowTables => Ok(Statement::ShowTables),
923 Statement::ShowViews => Ok(Statement::ShowViews),
924 Statement::ShowIndexes(table) => Ok(Statement::ShowIndexes(table.clone())),
925 Statement::ShowTriggers(table) => Ok(Statement::ShowTriggers(table.clone())),
926 Statement::ShowCreateTable(table) => Ok(Statement::ShowCreateTable(table.clone())),
927 Statement::ShowCreateView(view) => Ok(Statement::ShowCreateView(view.clone())),
928 Statement::Select(select) => Ok(Statement::Select(SelectStatement {
929 with_clause: select
930 .with_clause
931 .iter()
932 .map(|cte| {
933 Ok(CommonTableExpression {
934 name: cte.name.clone(),
935 recursive: cte.recursive,
936 query: Box::new(
937 Statement::Select((*cte.query).clone())
938 .bind_parameters(parameters)?
939 .into_select()?,
940 ),
941 })
942 })
943 .collect::<Result<Vec<_>>>()?,
944 distinct: select.distinct,
945 columns: select
946 .columns
947 .iter()
948 .map(|item| item.bind(parameters))
949 .collect::<Result<Vec<_>>>()?,
950 column_aliases: select.column_aliases.clone(),
951 from: select.from.clone(),
952 where_clause: select
953 .where_clause
954 .as_ref()
955 .map(|where_clause| where_clause.bind(parameters))
956 .transpose()?,
957 group_by: select
958 .group_by
959 .iter()
960 .map(|expr| expr.bind(parameters))
961 .collect::<Result<Vec<_>>>()?,
962 having_clause: select
963 .having_clause
964 .as_ref()
965 .map(|having_clause| having_clause.bind(parameters))
966 .transpose()?,
967 order_by: select.order_by.clone(),
968 limit: select.limit,
969 offset: select.offset,
970 set_operation: select
971 .set_operation
972 .as_ref()
973 .map(|set_operation| {
974 Ok::<SetOperation, HematiteError>(SetOperation {
975 operator: set_operation.operator,
976 right: Box::new(
977 Statement::Select((*set_operation.right).clone())
978 .bind_parameters(parameters)?
979 .into_select()?,
980 ),
981 })
982 })
983 .transpose()?,
984 })),
985 Statement::SelectInto(select_into) => Ok(Statement::SelectInto(SelectIntoStatement {
986 table: select_into.table.clone(),
987 query: Statement::Select(select_into.query.clone())
988 .bind_parameters(parameters)?
989 .into_select()?,
990 })),
991 Statement::Update(update) => Ok(Statement::Update(UpdateStatement {
992 table: update.table.clone(),
993 target_binding: update.target_binding.clone(),
994 source: update.source.clone(),
995 assignments: update
996 .assignments
997 .iter()
998 .map(|assignment| {
999 Ok(UpdateAssignment {
1000 column: assignment.column.clone(),
1001 value: assignment.value.bind(parameters)?,
1002 })
1003 })
1004 .collect::<Result<Vec<_>>>()?,
1005 where_clause: update
1006 .where_clause
1007 .as_ref()
1008 .map(|where_clause| where_clause.bind(parameters))
1009 .transpose()?,
1010 })),
1011 Statement::Insert(insert) => Ok(Statement::Insert(InsertStatement {
1012 table: insert.table.clone(),
1013 columns: insert.columns.clone(),
1014 source: match &insert.source {
1015 InsertSource::Values(rows) => InsertSource::Values(
1016 rows.iter()
1017 .map(|row| {
1018 row.iter()
1019 .map(|expr| expr.bind(parameters))
1020 .collect::<Result<Vec<_>>>()
1021 })
1022 .collect::<Result<Vec<_>>>()?,
1023 ),
1024 InsertSource::Select(select) => InsertSource::Select(Box::new(
1025 Statement::Select((**select).clone())
1026 .bind_parameters(parameters)?
1027 .into_select()?,
1028 )),
1029 },
1030 on_duplicate: insert
1031 .on_duplicate
1032 .as_ref()
1033 .map(|assignments| {
1034 assignments
1035 .iter()
1036 .map(|assignment| {
1037 Ok(UpdateAssignment {
1038 column: assignment.column.clone(),
1039 value: assignment.value.bind(parameters)?,
1040 })
1041 })
1042 .collect::<Result<Vec<_>>>()
1043 })
1044 .transpose()?,
1045 })),
1046 Statement::Delete(delete) => Ok(Statement::Delete(DeleteStatement {
1047 table: delete.table.clone(),
1048 target_binding: delete.target_binding.clone(),
1049 source: delete.source.clone(),
1050 where_clause: delete
1051 .where_clause
1052 .as_ref()
1053 .map(|where_clause| where_clause.bind(parameters))
1054 .transpose()?,
1055 })),
1056 Statement::Create(create) => Ok(Statement::Create(create.clone())),
1057 Statement::CreateView(create_view) => Ok(Statement::CreateView(CreateViewStatement {
1058 view: create_view.view.clone(),
1059 if_not_exists: create_view.if_not_exists,
1060 query: Statement::Select(create_view.query.clone())
1061 .bind_parameters(parameters)?
1062 .into_select()?,
1063 })),
1064 Statement::CreateTrigger(create_trigger) => {
1065 Ok(Statement::CreateTrigger(CreateTriggerStatement {
1066 trigger: create_trigger.trigger.clone(),
1067 table: create_trigger.table.clone(),
1068 event: create_trigger.event,
1069 body: Box::new(create_trigger.body.bind_parameters(parameters)?),
1070 }))
1071 }
1072 Statement::CreateIndex(create_index) => {
1073 Ok(Statement::CreateIndex(create_index.clone()))
1074 }
1075 Statement::Alter(alter) => Ok(Statement::Alter(alter.clone())),
1076 Statement::Drop(drop) => Ok(Statement::Drop(drop.clone())),
1077 Statement::DropView(drop_view) => Ok(Statement::DropView(drop_view.clone())),
1078 Statement::DropTrigger(drop_trigger) => {
1079 Ok(Statement::DropTrigger(drop_trigger.clone()))
1080 }
1081 Statement::DropIndex(drop_index) => Ok(Statement::DropIndex(drop_index.clone())),
1082 }
1083 }
1084}
1085
1086impl SelectIntoStatement {
1087 fn to_sql(&self) -> String {
1088 self.query.to_sql_with_into(&self.table)
1089 }
1090}
1091
1092impl WhereClause {
1093 fn visit_parameters<F>(&self, f: &mut F)
1094 where
1095 F: FnMut(usize),
1096 {
1097 for condition in &self.conditions {
1098 condition.visit_parameters(f);
1099 }
1100 }
1101
1102 fn bind(&self, parameters: &[LiteralValue]) -> Result<WhereClause> {
1103 Ok(WhereClause {
1104 conditions: self
1105 .conditions
1106 .iter()
1107 .map(|condition| condition.bind(parameters))
1108 .collect::<Result<Vec<_>>>()?,
1109 })
1110 }
1111}
1112
1113impl Condition {
1114 fn collect_dependency_names_into(&self, names: &mut std::collections::BTreeSet<String>) {
1115 match self {
1116 Condition::Comparison { left, right, .. } => {
1117 left.collect_dependency_names_into(names);
1118 right.collect_dependency_names_into(names);
1119 }
1120 Condition::InList { expr, values, .. } => {
1121 expr.collect_dependency_names_into(names);
1122 for value in values {
1123 value.collect_dependency_names_into(names);
1124 }
1125 }
1126 Condition::InSubquery { expr, subquery, .. } => {
1127 expr.collect_dependency_names_into(names);
1128 subquery.collect_dependency_names_into(names);
1129 }
1130 Condition::Between {
1131 expr, lower, upper, ..
1132 } => {
1133 expr.collect_dependency_names_into(names);
1134 lower.collect_dependency_names_into(names);
1135 upper.collect_dependency_names_into(names);
1136 }
1137 Condition::Like { expr, pattern, .. } => {
1138 expr.collect_dependency_names_into(names);
1139 pattern.collect_dependency_names_into(names);
1140 }
1141 Condition::Exists { subquery, .. } => subquery.collect_dependency_names_into(names),
1142 Condition::NullCheck { expr, .. } => expr.collect_dependency_names_into(names),
1143 Condition::Not(condition) => condition.collect_dependency_names_into(names),
1144 Condition::Logical { left, right, .. } => {
1145 left.collect_dependency_names_into(names);
1146 right.collect_dependency_names_into(names);
1147 }
1148 }
1149 }
1150
1151 fn visit_parameters<F>(&self, f: &mut F)
1152 where
1153 F: FnMut(usize),
1154 {
1155 match self {
1156 Condition::Comparison { left, right, .. } => {
1157 left.visit_parameters(f);
1158 right.visit_parameters(f);
1159 }
1160 Condition::InList { expr, values, .. } => {
1161 expr.visit_parameters(f);
1162 for value in values {
1163 value.visit_parameters(f);
1164 }
1165 }
1166 Condition::InSubquery { expr, subquery, .. } => {
1167 expr.visit_parameters(f);
1168 subquery.visit_parameters(f);
1169 }
1170 Condition::Between {
1171 expr, lower, upper, ..
1172 } => {
1173 expr.visit_parameters(f);
1174 lower.visit_parameters(f);
1175 upper.visit_parameters(f);
1176 }
1177 Condition::Like { expr, pattern, .. } => {
1178 expr.visit_parameters(f);
1179 pattern.visit_parameters(f);
1180 }
1181 Condition::Exists { subquery, .. } => subquery.visit_parameters(f),
1182 Condition::NullCheck { expr, .. } => expr.visit_parameters(f),
1183 Condition::Not(condition) => condition.visit_parameters(f),
1184 Condition::Logical { left, right, .. } => {
1185 left.visit_parameters(f);
1186 right.visit_parameters(f);
1187 }
1188 }
1189 }
1190
1191 fn bind(&self, parameters: &[LiteralValue]) -> Result<Condition> {
1192 match self {
1193 Condition::Comparison {
1194 left,
1195 operator,
1196 right,
1197 } => Ok(Condition::Comparison {
1198 left: left.bind(parameters)?,
1199 operator: operator.clone(),
1200 right: right.bind(parameters)?,
1201 }),
1202 Condition::InList {
1203 expr,
1204 values,
1205 is_not,
1206 } => Ok(Condition::InList {
1207 expr: expr.bind(parameters)?,
1208 values: values
1209 .iter()
1210 .map(|value| value.bind(parameters))
1211 .collect::<Result<Vec<_>>>()?,
1212 is_not: *is_not,
1213 }),
1214 Condition::InSubquery {
1215 expr,
1216 subquery,
1217 is_not,
1218 } => Ok(Condition::InSubquery {
1219 expr: expr.bind(parameters)?,
1220 subquery: Box::new(
1221 Statement::Select((**subquery).clone())
1222 .bind_parameters(parameters)?
1223 .into_select()?,
1224 ),
1225 is_not: *is_not,
1226 }),
1227 Condition::Between {
1228 expr,
1229 lower,
1230 upper,
1231 is_not,
1232 } => Ok(Condition::Between {
1233 expr: expr.bind(parameters)?,
1234 lower: lower.bind(parameters)?,
1235 upper: upper.bind(parameters)?,
1236 is_not: *is_not,
1237 }),
1238 Condition::Like {
1239 expr,
1240 pattern,
1241 is_not,
1242 } => Ok(Condition::Like {
1243 expr: expr.bind(parameters)?,
1244 pattern: pattern.bind(parameters)?,
1245 is_not: *is_not,
1246 }),
1247 Condition::Exists { subquery, is_not } => Ok(Condition::Exists {
1248 subquery: Box::new(
1249 Statement::Select((**subquery).clone())
1250 .bind_parameters(parameters)?
1251 .into_select()?,
1252 ),
1253 is_not: *is_not,
1254 }),
1255 Condition::NullCheck { expr, is_not } => Ok(Condition::NullCheck {
1256 expr: expr.bind(parameters)?,
1257 is_not: *is_not,
1258 }),
1259 Condition::Not(condition) => Ok(Condition::Not(Box::new(condition.bind(parameters)?))),
1260 Condition::Logical {
1261 left,
1262 operator,
1263 right,
1264 } => Ok(Condition::Logical {
1265 left: Box::new(left.bind(parameters)?),
1266 operator: operator.clone(),
1267 right: Box::new(right.bind(parameters)?),
1268 }),
1269 }
1270 }
1271}
1272
1273impl SelectItem {
1274 fn to_sql(&self) -> String {
1275 match self {
1276 SelectItem::Wildcard => "*".to_string(),
1277 SelectItem::Column(name) => name.clone(),
1278 SelectItem::Expression(expr) => expr.to_sql(),
1279 SelectItem::CountAll => "COUNT(*)".to_string(),
1280 SelectItem::Aggregate { function, column } => {
1281 format!("{}({})", function.to_sql(), column)
1282 }
1283 SelectItem::Window { function, window } => {
1284 format!("{} OVER ({})", function.to_sql(), window.to_sql())
1285 }
1286 }
1287 }
1288
1289 fn visit_parameters<F>(&self, f: &mut F)
1290 where
1291 F: FnMut(usize),
1292 {
1293 match self {
1294 SelectItem::Expression(expr) => expr.visit_parameters(f),
1295 SelectItem::Window { window, .. } => {
1296 for expr in &window.partition_by {
1297 expr.visit_parameters(f);
1298 }
1299 }
1300 SelectItem::Wildcard
1301 | SelectItem::Column(_)
1302 | SelectItem::CountAll
1303 | SelectItem::Aggregate { .. } => {}
1304 }
1305 }
1306
1307 fn bind(&self, parameters: &[LiteralValue]) -> Result<SelectItem> {
1308 match self {
1309 SelectItem::Wildcard => Ok(SelectItem::Wildcard),
1310 SelectItem::Column(name) => Ok(SelectItem::Column(name.clone())),
1311 SelectItem::Expression(expr) => Ok(SelectItem::Expression(expr.bind(parameters)?)),
1312 SelectItem::CountAll => Ok(SelectItem::CountAll),
1313 SelectItem::Aggregate { function, column } => Ok(SelectItem::Aggregate {
1314 function: *function,
1315 column: column.clone(),
1316 }),
1317 SelectItem::Window { function, window } => Ok(SelectItem::Window {
1318 function: function.clone(),
1319 window: WindowSpec {
1320 partition_by: window
1321 .partition_by
1322 .iter()
1323 .map(|expr| expr.bind(parameters))
1324 .collect::<Result<Vec<_>>>()?,
1325 order_by: window.order_by.clone(),
1326 },
1327 }),
1328 }
1329 }
1330}
1331
1332impl WindowSpec {
1333 fn to_sql(&self) -> String {
1334 let mut parts = Vec::new();
1335 if !self.partition_by.is_empty() {
1336 parts.push(format!(
1337 "PARTITION BY {}",
1338 self.partition_by
1339 .iter()
1340 .map(Expression::to_sql)
1341 .collect::<Vec<_>>()
1342 .join(", ")
1343 ));
1344 }
1345 if !self.order_by.is_empty() {
1346 parts.push(format!(
1347 "ORDER BY {}",
1348 self.order_by
1349 .iter()
1350 .map(|item| format!(
1351 "{} {}",
1352 item.column,
1353 match item.direction {
1354 SortDirection::Asc => "ASC",
1355 SortDirection::Desc => "DESC",
1356 }
1357 ))
1358 .collect::<Vec<_>>()
1359 .join(", ")
1360 ));
1361 }
1362 parts.join(" ")
1363 }
1364}
1365
1366impl WindowFunction {
1367 fn to_sql(&self) -> String {
1368 match self {
1369 WindowFunction::RowNumber => "ROW_NUMBER()".to_string(),
1370 WindowFunction::Rank => "RANK()".to_string(),
1371 WindowFunction::DenseRank => "DENSE_RANK()".to_string(),
1372 WindowFunction::Aggregate { function, target } => match target {
1373 AggregateTarget::All => format!("{}(*)", function.to_sql()),
1374 AggregateTarget::Column(column) => format!("{}({})", function.to_sql(), column),
1375 },
1376 }
1377 }
1378}
1379
1380impl Expression {
1381 fn collect_dependency_names_into(&self, names: &mut std::collections::BTreeSet<String>) {
1382 match self {
1383 Expression::ScalarSubquery(subquery) => subquery.collect_dependency_names_into(names),
1384 Expression::Case {
1385 branches,
1386 else_expr,
1387 } => {
1388 for branch in branches {
1389 branch.condition.collect_dependency_names_into(names);
1390 branch.result.collect_dependency_names_into(names);
1391 }
1392 if let Some(else_expr) = else_expr {
1393 else_expr.collect_dependency_names_into(names);
1394 }
1395 }
1396 Expression::ScalarFunctionCall { args, .. } => {
1397 for arg in args {
1398 arg.collect_dependency_names_into(names);
1399 }
1400 }
1401 Expression::AggregateCall { .. } => {}
1402 Expression::Cast { expr, .. }
1403 | Expression::UnaryMinus(expr)
1404 | Expression::UnaryNot(expr) => expr.collect_dependency_names_into(names),
1405 Expression::Binary { left, right, .. }
1406 | Expression::Comparison { left, right, .. }
1407 | Expression::Logical { left, right, .. } => {
1408 left.collect_dependency_names_into(names);
1409 right.collect_dependency_names_into(names);
1410 }
1411 Expression::InList { expr, values, .. } => {
1412 expr.collect_dependency_names_into(names);
1413 for value in values {
1414 value.collect_dependency_names_into(names);
1415 }
1416 }
1417 Expression::InSubquery { expr, subquery, .. } => {
1418 expr.collect_dependency_names_into(names);
1419 subquery.collect_dependency_names_into(names);
1420 }
1421 Expression::Between {
1422 expr, lower, upper, ..
1423 } => {
1424 expr.collect_dependency_names_into(names);
1425 lower.collect_dependency_names_into(names);
1426 upper.collect_dependency_names_into(names);
1427 }
1428 Expression::Like { expr, pattern, .. } => {
1429 expr.collect_dependency_names_into(names);
1430 pattern.collect_dependency_names_into(names);
1431 }
1432 Expression::Exists { subquery, .. } => subquery.collect_dependency_names_into(names),
1433 Expression::NullCheck { expr, .. } => expr.collect_dependency_names_into(names),
1434 Expression::Column(_)
1435 | Expression::Literal(_)
1436 | Expression::IntervalLiteral { .. }
1437 | Expression::Parameter(_) => {}
1438 }
1439 }
1440
1441 fn visit_parameters<F>(&self, f: &mut F)
1442 where
1443 F: FnMut(usize),
1444 {
1445 match self {
1446 Expression::Parameter(index) => f(*index),
1447 Expression::ScalarSubquery(subquery) => subquery.visit_parameters(f),
1448 Expression::Cast { expr, .. } => expr.visit_parameters(f),
1449 Expression::Case {
1450 branches,
1451 else_expr,
1452 } => {
1453 for branch in branches {
1454 branch.condition.visit_parameters(f);
1455 branch.result.visit_parameters(f);
1456 }
1457 if let Some(else_expr) = else_expr {
1458 else_expr.visit_parameters(f);
1459 }
1460 }
1461 Expression::ScalarFunctionCall { args, .. } => {
1462 for arg in args {
1463 arg.visit_parameters(f);
1464 }
1465 }
1466 Expression::AggregateCall { .. } => {}
1467 Expression::UnaryMinus(expr) => expr.visit_parameters(f),
1468 Expression::UnaryNot(expr) => expr.visit_parameters(f),
1469 Expression::Binary { left, right, .. } => {
1470 left.visit_parameters(f);
1471 right.visit_parameters(f);
1472 }
1473 Expression::Comparison { left, right, .. } => {
1474 left.visit_parameters(f);
1475 right.visit_parameters(f);
1476 }
1477 Expression::InList { expr, values, .. } => {
1478 expr.visit_parameters(f);
1479 for value in values {
1480 value.visit_parameters(f);
1481 }
1482 }
1483 Expression::InSubquery { expr, subquery, .. } => {
1484 expr.visit_parameters(f);
1485 subquery.visit_parameters(f);
1486 }
1487 Expression::Between {
1488 expr, lower, upper, ..
1489 } => {
1490 expr.visit_parameters(f);
1491 lower.visit_parameters(f);
1492 upper.visit_parameters(f);
1493 }
1494 Expression::Like { expr, pattern, .. } => {
1495 expr.visit_parameters(f);
1496 pattern.visit_parameters(f);
1497 }
1498 Expression::Exists { subquery, .. } => subquery.visit_parameters(f),
1499 Expression::NullCheck { expr, .. } => expr.visit_parameters(f),
1500 Expression::Logical { left, right, .. } => {
1501 left.visit_parameters(f);
1502 right.visit_parameters(f);
1503 }
1504 Expression::Column(_) | Expression::Literal(_) | Expression::IntervalLiteral { .. } => {
1505 }
1506 }
1507 }
1508
1509 fn bind(&self, parameters: &[LiteralValue]) -> Result<Expression> {
1510 match self {
1511 Expression::Column(name) => Ok(Expression::Column(name.clone())),
1512 Expression::Literal(value) => Ok(Expression::Literal(value.clone())),
1513 Expression::IntervalLiteral { value, qualifier } => Ok(Expression::IntervalLiteral {
1514 value: value.clone(),
1515 qualifier: *qualifier,
1516 }),
1517 Expression::Parameter(index) => parameters
1518 .get(*index)
1519 .cloned()
1520 .map(Expression::Literal)
1521 .ok_or_else(|| {
1522 HematiteError::ParseError(format!(
1523 "Missing bound value for parameter {}",
1524 index + 1
1525 ))
1526 }),
1527 Expression::ScalarSubquery(subquery) => Ok(Expression::ScalarSubquery(Box::new(
1528 Statement::Select((**subquery).clone())
1529 .bind_parameters(parameters)?
1530 .into_select()?,
1531 ))),
1532 Expression::Cast { expr, target_type } => Ok(Expression::Cast {
1533 expr: Box::new(expr.bind(parameters)?),
1534 target_type: target_type.clone(),
1535 }),
1536 Expression::Case {
1537 branches,
1538 else_expr,
1539 } => Ok(Expression::Case {
1540 branches: branches
1541 .iter()
1542 .map(|branch| {
1543 Ok(CaseWhenClause {
1544 condition: branch.condition.bind(parameters)?,
1545 result: branch.result.bind(parameters)?,
1546 })
1547 })
1548 .collect::<Result<Vec<_>>>()?,
1549 else_expr: else_expr
1550 .as_ref()
1551 .map(|expr| expr.bind(parameters).map(Box::new))
1552 .transpose()?,
1553 }),
1554 Expression::ScalarFunctionCall { function, args } => {
1555 let mut bound_args = Vec::with_capacity(args.len());
1556 for arg in args {
1557 bound_args.push(arg.bind(parameters)?);
1558 }
1559 Ok(Expression::ScalarFunctionCall {
1560 function: *function,
1561 args: bound_args,
1562 })
1563 }
1564 Expression::AggregateCall { function, target } => Ok(Expression::AggregateCall {
1565 function: *function,
1566 target: target.clone(),
1567 }),
1568 Expression::UnaryMinus(expr) => {
1569 Ok(Expression::UnaryMinus(Box::new(expr.bind(parameters)?)))
1570 }
1571 Expression::UnaryNot(expr) => {
1572 Ok(Expression::UnaryNot(Box::new(expr.bind(parameters)?)))
1573 }
1574 Expression::Binary {
1575 left,
1576 operator,
1577 right,
1578 } => Ok(Expression::Binary {
1579 left: Box::new(left.bind(parameters)?),
1580 operator: *operator,
1581 right: Box::new(right.bind(parameters)?),
1582 }),
1583 Expression::Comparison {
1584 left,
1585 operator,
1586 right,
1587 } => Ok(Expression::Comparison {
1588 left: Box::new(left.bind(parameters)?),
1589 operator: operator.clone(),
1590 right: Box::new(right.bind(parameters)?),
1591 }),
1592 Expression::InList {
1593 expr,
1594 values,
1595 is_not,
1596 } => Ok(Expression::InList {
1597 expr: Box::new(expr.bind(parameters)?),
1598 values: values
1599 .iter()
1600 .map(|value| value.bind(parameters))
1601 .collect::<Result<Vec<_>>>()?,
1602 is_not: *is_not,
1603 }),
1604 Expression::InSubquery {
1605 expr,
1606 subquery,
1607 is_not,
1608 } => Ok(Expression::InSubquery {
1609 expr: Box::new(expr.bind(parameters)?),
1610 subquery: Box::new(
1611 Statement::Select((**subquery).clone())
1612 .bind_parameters(parameters)?
1613 .into_select()?,
1614 ),
1615 is_not: *is_not,
1616 }),
1617 Expression::Between {
1618 expr,
1619 lower,
1620 upper,
1621 is_not,
1622 } => Ok(Expression::Between {
1623 expr: Box::new(expr.bind(parameters)?),
1624 lower: Box::new(lower.bind(parameters)?),
1625 upper: Box::new(upper.bind(parameters)?),
1626 is_not: *is_not,
1627 }),
1628 Expression::Like {
1629 expr,
1630 pattern,
1631 is_not,
1632 } => Ok(Expression::Like {
1633 expr: Box::new(expr.bind(parameters)?),
1634 pattern: Box::new(pattern.bind(parameters)?),
1635 is_not: *is_not,
1636 }),
1637 Expression::Exists { subquery, is_not } => Ok(Expression::Exists {
1638 subquery: Box::new(
1639 Statement::Select((**subquery).clone())
1640 .bind_parameters(parameters)?
1641 .into_select()?,
1642 ),
1643 is_not: *is_not,
1644 }),
1645 Expression::NullCheck { expr, is_not } => Ok(Expression::NullCheck {
1646 expr: Box::new(expr.bind(parameters)?),
1647 is_not: *is_not,
1648 }),
1649 Expression::Logical {
1650 left,
1651 operator,
1652 right,
1653 } => Ok(Expression::Logical {
1654 left: Box::new(left.bind(parameters)?),
1655 operator: operator.clone(),
1656 right: Box::new(right.bind(parameters)?),
1657 }),
1658 }
1659 }
1660}
1661
1662impl IntervalQualifier {
1663 pub fn to_sql(self) -> &'static str {
1664 match self {
1665 IntervalQualifier::YearToMonth => "YEAR TO MONTH",
1666 IntervalQualifier::DayToSecond => "DAY TO SECOND",
1667 }
1668 }
1669}
1670
1671impl SelectStatement {
1672 pub(crate) fn single_table_scope(table_name: &str) -> Self {
1673 Self {
1674 with_clause: Vec::new(),
1675 distinct: false,
1676 columns: Vec::new(),
1677 column_aliases: Vec::new(),
1678 from: TableReference::Table(table_name.to_string(), None),
1679 where_clause: None,
1680 group_by: Vec::new(),
1681 having_clause: None,
1682 order_by: Vec::new(),
1683 limit: None,
1684 offset: None,
1685 set_operation: None,
1686 }
1687 }
1688
1689 fn visit_parameters<F>(&self, f: &mut F)
1690 where
1691 F: FnMut(usize),
1692 {
1693 for item in &self.columns {
1694 item.visit_parameters(f);
1695 }
1696 for cte in &self.with_clause {
1697 cte.query.visit_parameters(f);
1698 }
1699 if let Some(where_clause) = &self.where_clause {
1700 where_clause.visit_parameters(f);
1701 }
1702 for expr in &self.group_by {
1703 expr.visit_parameters(f);
1704 }
1705 if let Some(having_clause) = &self.having_clause {
1706 having_clause.visit_parameters(f);
1707 }
1708 if let Some(set_operation) = &self.set_operation {
1709 set_operation.right.visit_parameters(f);
1710 }
1711 }
1712
/// Returns `true` when `name` spells `rowid`, compared ASCII
/// case-insensitively.
pub(crate) fn is_hidden_rowid(name: &str) -> bool {
    const HIDDEN_ROWID: &str = "rowid";
    HIDDEN_ROWID.eq_ignore_ascii_case(name)
}
1716
1717 pub(crate) fn lookup_cte<'a>(&'a self, name: &str) -> Option<&'a CommonTableExpression> {
1718 self.with_clause
1719 .iter()
1720 .find(|cte| cte.name.eq_ignore_ascii_case(name))
1721 }
1722
1723 pub(crate) fn references_cte(&self, name: &str) -> bool {
1724 self.lookup_cte(name).is_some()
1725 }
1726
1727 pub(crate) fn references_source_name(&self, name: &str) -> bool {
1728 self.from.references_source_name(name)
1729 || self.where_clause.as_ref().is_some_and(|where_clause| {
1730 where_clause
1731 .conditions
1732 .iter()
1733 .any(|condition| condition.references_source_name(name))
1734 })
1735 || self
1736 .group_by
1737 .iter()
1738 .any(|expr| expr.references_source_name(name))
1739 || self.having_clause.as_ref().is_some_and(|having_clause| {
1740 having_clause
1741 .conditions
1742 .iter()
1743 .any(|condition| condition.references_source_name(name))
1744 })
1745 || self
1746 .set_operation
1747 .as_ref()
1748 .is_some_and(|set_operation| set_operation.right.references_source_name(name))
1749 }
1750
1751 pub(crate) fn has_non_table_source(&self) -> bool {
1752 self.from.has_non_table_source(self)
1753 }
1754
/// Splits a column reference at its first `.` into `(qualifier, column)`.
///
/// Returns `(None, name)` when `name` contains no dot. Only the first dot is
/// significant, so `"a.b.c"` yields `(Some("a"), "b.c")`.
pub(crate) fn split_column_reference(name: &str) -> (Option<&str>, &str) {
    name.split_once('.')
        .map_or((None, name), |(qualifier, column_name)| {
            (Some(qualifier), column_name)
        })
}
1761
1762 pub(crate) fn column_reference_name(name: &str) -> &str {
1763 Self::split_column_reference(name).1
1764 }
1765
1766 pub(crate) fn default_output_name(item: &SelectItem, index: usize) -> Option<String> {
1767 match item {
1768 SelectItem::Wildcard => None,
1769 SelectItem::Column(name) => Some(Self::column_reference_name(name).to_string()),
1770 SelectItem::Expression(_) => Some(format!("expr{}", index + 1)),
1771 SelectItem::CountAll => Some("COUNT(*)".to_string()),
1772 SelectItem::Aggregate { function, column } => Some(format!(
1773 "{}({})",
1774 match function {
1775 AggregateFunction::Count => "COUNT",
1776 AggregateFunction::Sum => "SUM",
1777 AggregateFunction::Avg => "AVG",
1778 AggregateFunction::Min => "MIN",
1779 AggregateFunction::Max => "MAX",
1780 },
1781 column
1782 )),
1783 SelectItem::Window { function, .. } => Some(function.to_sql()),
1784 }
1785 }
1786
1787 pub(crate) fn output_name(&self, index: usize) -> Option<String> {
1788 self.column_aliases
1789 .get(index)
1790 .and_then(|alias| alias.clone())
1791 .or_else(|| {
1792 self.columns
1793 .get(index)
1794 .and_then(|item| Self::default_output_name(item, index))
1795 })
1796 }
1797
1798 pub(crate) fn dependency_names(&self) -> Vec<String> {
1799 let mut names = std::collections::BTreeSet::new();
1800 self.collect_dependency_names_into(&mut names);
1801 names.into_iter().collect()
1802 }
1803
1804 fn collect_dependency_names_into(&self, names: &mut std::collections::BTreeSet<String>) {
1805 for cte in &self.with_clause {
1806 cte.query.collect_dependency_names_into(names);
1807 }
1808 self.from.collect_dependency_names_into(names);
1809 if let Some(where_clause) = &self.where_clause {
1810 for condition in &where_clause.conditions {
1811 condition.collect_dependency_names_into(names);
1812 }
1813 }
1814 for expr in &self.group_by {
1815 expr.collect_dependency_names_into(names);
1816 }
1817 if let Some(having_clause) = &self.having_clause {
1818 for condition in &having_clause.conditions {
1819 condition.collect_dependency_names_into(names);
1820 }
1821 }
1822 if let Some(set_operation) = &self.set_operation {
1823 set_operation.right.collect_dependency_names_into(names);
1824 }
1825 }
1826
1827 pub(crate) fn to_sql_with_into(&self, table: &str) -> String {
1828 let mut parts = Vec::new();
1829 if !self.with_clause.is_empty() {
1830 let recursive = self.with_clause.iter().any(|cte| cte.recursive);
1831 let ctes = self
1832 .with_clause
1833 .iter()
1834 .map(|cte| format!("{} AS ({})", cte.name, cte.query.to_sql()))
1835 .collect::<Vec<_>>()
1836 .join(", ");
1837 parts.push(format!(
1838 "WITH {}{}",
1839 if recursive { "RECURSIVE " } else { "" },
1840 ctes
1841 ));
1842 }
1843
1844 let projections = self
1845 .columns
1846 .iter()
1847 .enumerate()
1848 .map(|(index, item)| {
1849 let base = item.to_sql();
1850 match self
1851 .column_aliases
1852 .get(index)
1853 .and_then(|alias| alias.clone())
1854 {
1855 Some(alias) => format!("{base} AS {alias}"),
1856 None => base,
1857 }
1858 })
1859 .collect::<Vec<_>>()
1860 .join(", ");
1861 parts.push(format!(
1862 "SELECT{} {} INTO {}",
1863 if self.distinct { " DISTINCT" } else { "" },
1864 projections,
1865 table
1866 ));
1867 parts.push(format!("FROM {}", self.from.to_sql()));
1868
1869 if let Some(where_clause) = &self.where_clause {
1870 parts.push(format!(
1871 "WHERE {}",
1872 where_clause
1873 .conditions
1874 .iter()
1875 .map(Condition::to_sql)
1876 .collect::<Vec<_>>()
1877 .join(" AND ")
1878 ));
1879 }
1880 if !self.group_by.is_empty() {
1881 parts.push(format!(
1882 "GROUP BY {}",
1883 self.group_by
1884 .iter()
1885 .map(Expression::to_sql)
1886 .collect::<Vec<_>>()
1887 .join(", ")
1888 ));
1889 }
1890 if let Some(having_clause) = &self.having_clause {
1891 parts.push(format!(
1892 "HAVING {}",
1893 having_clause
1894 .conditions
1895 .iter()
1896 .map(Condition::to_sql)
1897 .collect::<Vec<_>>()
1898 .join(" AND ")
1899 ));
1900 }
1901 if !self.order_by.is_empty() {
1902 parts.push(format!(
1903 "ORDER BY {}",
1904 self.order_by
1905 .iter()
1906 .map(|item| format!("{} {}", item.column, item.direction.to_sql()))
1907 .collect::<Vec<_>>()
1908 .join(", ")
1909 ));
1910 }
1911 if let Some(limit) = self.limit {
1912 parts.push(format!("LIMIT {}", limit));
1913 }
1914 if let Some(offset) = self.offset {
1915 parts.push(format!("OFFSET {}", offset));
1916 }
1917 if let Some(set_operation) = &self.set_operation {
1918 parts.push(format!(
1919 "{} {}",
1920 set_operation.operator.to_sql(),
1921 set_operation.right.to_sql()
1922 ));
1923 }
1924
1925 parts.join(" ")
1926 }
1927
1928 pub(crate) fn to_sql(&self) -> String {
1929 let mut parts = Vec::new();
1930 if !self.with_clause.is_empty() {
1931 let recursive = self.with_clause.iter().any(|cte| cte.recursive);
1932 let ctes = self
1933 .with_clause
1934 .iter()
1935 .map(|cte| format!("{} AS ({})", cte.name, cte.query.to_sql()))
1936 .collect::<Vec<_>>()
1937 .join(", ");
1938 parts.push(format!(
1939 "WITH {}{}",
1940 if recursive { "RECURSIVE " } else { "" },
1941 ctes
1942 ));
1943 }
1944
1945 let projections = self
1946 .columns
1947 .iter()
1948 .enumerate()
1949 .map(|(index, item)| {
1950 let base = item.to_sql();
1951 match self
1952 .column_aliases
1953 .get(index)
1954 .and_then(|alias| alias.clone())
1955 {
1956 Some(alias) => format!("{base} AS {alias}"),
1957 None => base,
1958 }
1959 })
1960 .collect::<Vec<_>>()
1961 .join(", ");
1962 parts.push(format!(
1963 "SELECT{} {}",
1964 if self.distinct { " DISTINCT" } else { "" },
1965 projections
1966 ));
1967 parts.push(format!("FROM {}", self.from.to_sql()));
1968
1969 if let Some(where_clause) = &self.where_clause {
1970 parts.push(format!(
1971 "WHERE {}",
1972 where_clause
1973 .conditions
1974 .iter()
1975 .map(Condition::to_sql)
1976 .collect::<Vec<_>>()
1977 .join(" AND ")
1978 ));
1979 }
1980 if !self.group_by.is_empty() {
1981 parts.push(format!(
1982 "GROUP BY {}",
1983 self.group_by
1984 .iter()
1985 .map(Expression::to_sql)
1986 .collect::<Vec<_>>()
1987 .join(", ")
1988 ));
1989 }
1990 if let Some(having_clause) = &self.having_clause {
1991 parts.push(format!(
1992 "HAVING {}",
1993 having_clause
1994 .conditions
1995 .iter()
1996 .map(Condition::to_sql)
1997 .collect::<Vec<_>>()
1998 .join(" AND ")
1999 ));
2000 }
2001 if !self.order_by.is_empty() {
2002 parts.push(format!(
2003 "ORDER BY {}",
2004 self.order_by
2005 .iter()
2006 .map(|item| format!("{} {}", item.column, item.direction.to_sql()))
2007 .collect::<Vec<_>>()
2008 .join(", ")
2009 ));
2010 }
2011 if let Some(limit) = self.limit {
2012 parts.push(format!("LIMIT {}", limit));
2013 }
2014 if let Some(offset) = self.offset {
2015 parts.push(format!("OFFSET {}", offset));
2016 }
2017 if let Some(set_operation) = &self.set_operation {
2018 parts.push(format!(
2019 "{} {}",
2020 set_operation.operator.to_sql(),
2021 set_operation.right.to_sql()
2022 ));
2023 }
2024
2025 parts.join(" ")
2026 }
2027
2028 pub(crate) fn collect_table_bindings(from: &TableReference) -> Vec<TableBinding> {
2029 let mut bindings = Vec::new();
2030 Self::collect_table_bindings_into(from, &mut bindings);
2031 bindings
2032 }
2033
2034 fn collect_table_bindings_into(from: &TableReference, bindings: &mut Vec<TableBinding>) {
2035 match from {
2036 TableReference::Table(table_name, alias) => bindings.push(TableBinding {
2037 table_name: table_name.clone(),
2038 alias: alias.clone(),
2039 }),
2040 TableReference::Derived { alias, .. } => bindings.push(TableBinding {
2041 table_name: alias.clone(),
2042 alias: None,
2043 }),
2044 TableReference::CrossJoin(left, right) => {
2045 Self::collect_table_bindings_into(left, bindings);
2046 Self::collect_table_bindings_into(right, bindings);
2047 }
2048 TableReference::InnerJoin { left, right, .. }
2049 | TableReference::LeftJoin { left, right, .. }
2050 | TableReference::RightJoin { left, right, .. }
2051 | TableReference::FullOuterJoin { left, right, .. } => {
2052 Self::collect_table_bindings_into(left, bindings);
2053 Self::collect_table_bindings_into(right, bindings);
2054 }
2055 }
2056 }
2057
/// Resolves every source in `from` to a `SourceBinding` (name, alias and
/// column list), using this statement's CTEs and `catalog`.
///
/// Compiled only in test builds. Fails with whatever error the recursive
/// collector reports.
#[cfg(test)]
fn collect_source_bindings(
    &self,
    catalog: &crate::catalog::Schema,
    from: &TableReference,
) -> Result<Vec<SourceBinding>> {
    let mut collected = Vec::new();
    let outcome = self.collect_source_bindings_into(catalog, from, &mut collected);
    outcome.map(|()| collected)
}
2068
/// Appends a `SourceBinding` for every leaf source in `from` to `bindings`.
///
/// Compiled only in test builds.
///
/// - A table name matching one of this statement's CTEs binds to that CTE.
///   A non-recursive CTE is validated and its projected column names are
///   used; a recursive CTE is not validated here and every projected column
///   must have an output name.
/// - Any other table name is looked up in `catalog` and exposes the hidden
///   rowid.
/// - A derived table is validated and bound under its alias.
/// - Joins recurse into the left side, then the right side.
///
/// # Errors
///
/// Returns `HematiteError::ParseError` when a table is missing from the
/// catalog or a recursive CTE column has no name, and propagates validation
/// and projection errors from nested queries.
#[cfg(test)]
fn collect_source_bindings_into(
    &self,
    catalog: &crate::catalog::Schema,
    from: &TableReference,
    bindings: &mut Vec<SourceBinding>,
) -> Result<()> {
    match from {
        TableReference::Table(table_name, alias) => {
            let binding = match self.lookup_cte(table_name) {
                Some(cte) => {
                    let columns = if cte.recursive {
                        let column_count = cte.query.columns.len();
                        let mut names = Vec::with_capacity(column_count);
                        for index in 0..column_count {
                            match cte.query.output_name(index) {
                                Some(name) => names.push(name),
                                None => {
                                    return Err(HematiteError::ParseError(format!(
                                        "Recursive CTE '{}' requires a name for projected column {}",
                                        cte.name,
                                        index + 1
                                    )))
                                }
                            }
                        }
                        names
                    } else {
                        cte.query.validate(catalog)?;
                        cte.query.projected_column_names(catalog)?
                    };
                    SourceBinding {
                        source_name: table_name.clone(),
                        alias: alias.clone(),
                        columns,
                        has_hidden_rowid: false,
                    }
                }
                None => {
                    let table = match catalog.get_table_by_name(table_name) {
                        Some(table) => table,
                        None => {
                            return Err(HematiteError::ParseError(format!(
                                "Table '{}' does not exist",
                                table_name
                            )))
                        }
                    };
                    let mut columns = Vec::with_capacity(table.columns.len());
                    for column in table.columns.iter() {
                        columns.push(column.name.clone());
                    }
                    SourceBinding {
                        source_name: table_name.clone(),
                        alias: alias.clone(),
                        columns,
                        has_hidden_rowid: true,
                    }
                }
            };
            bindings.push(binding);
            Ok(())
        }
        TableReference::Derived { subquery, alias } => {
            subquery.validate(catalog)?;
            let columns = subquery.projected_column_names(catalog)?;
            bindings.push(SourceBinding {
                source_name: alias.clone(),
                alias: None,
                columns,
                has_hidden_rowid: false,
            });
            Ok(())
        }
        TableReference::CrossJoin(left, right) => {
            self.collect_source_bindings_into(catalog, left, bindings)?;
            self.collect_source_bindings_into(catalog, right, bindings)
        }
        TableReference::InnerJoin { left, right, .. }
        | TableReference::LeftJoin { left, right, .. }
        | TableReference::RightJoin { left, right, .. }
        | TableReference::FullOuterJoin { left, right, .. } => {
            self.collect_source_bindings_into(catalog, left, bindings)?;
            self.collect_source_bindings_into(catalog, right, bindings)
        }
    }
}
2146
/// Returns the names this statement's projection exposes, in order.
///
/// Compiled only in test builds. An explicit alias always wins. Without one,
/// a wildcard expands to the columns of every source binding, a column is
/// validated and contributes its unqualified name, and `COUNT(*)`, aggregate
/// and window items contribute their default output name.
///
/// # Errors
///
/// Returns `HematiteError::ParseError` for an expression item without an
/// alias, and propagates errors from binding collection and column
/// validation.
#[cfg(test)]
fn projected_column_names(&self, catalog: &crate::catalog::Schema) -> Result<Vec<String>> {
    let mut names = Vec::with_capacity(self.columns.len());

    for (index, item) in self.columns.iter().enumerate() {
        if let Some(Some(alias)) = self.column_aliases.get(index) {
            names.push(alias.clone());
            continue;
        }

        match item {
            SelectItem::Expression(_) => {
                return Err(HematiteError::ParseError(
                    "Expression projections in derived tables or CTEs require aliases"
                        .to_string(),
                ))
            }
            SelectItem::Wildcard => {
                for binding in self.collect_source_bindings(catalog, &self.from)? {
                    names.extend(binding.columns);
                }
            }
            SelectItem::Column(name) => {
                self.validate_column_reference(name, catalog, &self.from)?;
                names.extend(Self::default_output_name(item, index));
            }
            SelectItem::CountAll | SelectItem::Aggregate { .. } | SelectItem::Window { .. } => {
                names.extend(Self::default_output_name(item, index));
            }
        }
    }

    Ok(names)
}
2194
/// Checks that `name` resolves to exactly one source in `from`, with no
/// outer-query bindings available.
///
/// Compiled only in test builds.
#[cfg(test)]
pub(crate) fn validate_column_reference(
    &self,
    name: &str,
    catalog: &crate::catalog::Schema,
    from: &TableReference,
) -> Result<()> {
    let no_outer_bindings: &[SourceBinding] = &[];
    self.validate_column_reference_with_outer(name, catalog, from, no_outer_bindings)
}
2204
/// Checks that the column reference `name` resolves unambiguously.
///
/// Compiled only in test builds. The sources in `from` are tried first: one
/// match succeeds and several matches are an ambiguity error. Only when
/// there is no local match are `outer_bindings` consulted, under the same
/// rule.
///
/// # Errors
///
/// Returns `HematiteError::ParseError` when the reference is ambiguous or
/// matches nothing; the "does not exist" message names the qualifier when one
/// was given. Errors from binding collection are propagated.
#[cfg(test)]
fn validate_column_reference_with_outer(
    &self,
    name: &str,
    catalog: &crate::catalog::Schema,
    from: &TableReference,
    outer_bindings: &[SourceBinding],
) -> Result<()> {
    let (qualifier, column_name) = Self::split_column_reference(name);
    let ambiguous = || {
        HematiteError::ParseError(format!("Column reference '{}' is ambiguous", name))
    };

    let local_bindings = self.collect_source_bindings(catalog, from)?;
    let local_match_count =
        Self::collect_matching_source_names(qualifier, column_name, &local_bindings)?.len();
    match local_match_count {
        0 => {}
        1 => return Ok(()),
        _ => return Err(ambiguous()),
    }

    let outer_match_count =
        Self::collect_matching_source_names(qualifier, column_name, outer_bindings)?.len();
    match (outer_match_count, qualifier) {
        (1, _) => Ok(()),
        (0, Some(qualifier)) => Err(HematiteError::ParseError(format!(
            "Column '{}' does not exist in table '{}'",
            column_name, qualifier
        ))),
        (0, None) => Err(HematiteError::ParseError(format!(
            "Column '{}' does not exist in the query source set",
            column_name
        ))),
        _ => Err(ambiguous()),
    }
}
2250
/// Returns the `source_name` of every binding that exposes `column_name`.
///
/// Compiled only in test builds. When `qualifier` is given, only bindings
/// whose source name or alias equals it exactly are considered. A binding
/// exposes the column when its column list contains it, or when it has a
/// hidden rowid and `column_name` spells `rowid`. This function always
/// returns `Ok`.
#[cfg(test)]
fn collect_matching_source_names(
    qualifier: Option<&str>,
    column_name: &str,
    bindings: &[SourceBinding],
) -> Result<Vec<String>> {
    let mut matched_tables = Vec::new();

    for binding in bindings {
        if let Some(qualifier) = qualifier {
            let qualifier_matches = binding.source_name == qualifier
                || binding.alias.as_deref() == Some(qualifier);
            if !qualifier_matches {
                continue;
            }
        }

        let listed = binding.columns.iter().any(|column| column == column_name);
        if listed || (binding.has_hidden_rowid && Self::is_hidden_rowid(column_name)) {
            matched_tables.push(binding.source_name.clone());
        }
    }

    Ok(matched_tables)
}
2283
/// Returns the bindings of `from` followed by clones of `outer_bindings`.
///
/// Compiled only in test builds. Fails when the bindings of `from` cannot be
/// collected.
#[cfg(test)]
fn combined_outer_bindings(
    &self,
    catalog: &crate::catalog::Schema,
    from: &TableReference,
    outer_bindings: &[SourceBinding],
) -> Result<Vec<SourceBinding>> {
    let mut combined = self.collect_source_bindings(catalog, from)?;
    combined.extend_from_slice(outer_bindings);
    Ok(combined)
}
2295
/// Validates this statement against `catalog` as a top-level query, i.e.
/// with no outer-query bindings.
///
/// Compiled only in test builds.
#[cfg(test)]
pub fn validate(&self, catalog: &crate::catalog::Schema) -> Result<()> {
    let no_outer_bindings: &[SourceBinding] = &[];
    self.validate_with_outer_bindings(catalog, no_outer_bindings)
}
2300
    /// Validates this SELECT statement against `catalog` (compiled for tests only).
    ///
    /// `outer_bindings` is forwarded unchanged to the column, expression and
    /// condition validators; presumably it carries the sources of enclosing
    /// queries so correlated subqueries can resolve them (confirm against
    /// `combined_outer_bindings`).
    ///
    /// The checks run in this order, and the first failure is returned:
    /// 1. the right-hand side of a set operation, plus matching column counts;
    /// 2. every CTE in the WITH clause (with extra rules for recursive CTEs);
    /// 3. the FROM sources;
    /// 4. DISTINCT versus aggregate select items;
    /// 5. each select item;
    /// 6. WHERE conditions and GROUP BY expressions;
    /// 7. GROUP BY / aggregate mixing rules and HAVING;
    /// 8. ORDER BY columns.
    ///
    /// # Errors
    ///
    /// Returns `HematiteError::ParseError` for the structural violations checked
    /// here, and propagates any error produced by the helper validators.
    #[cfg(test)]
    fn validate_with_outer_bindings(
        &self,
        catalog: &crate::catalog::Schema,
        outer_bindings: &[SourceBinding],
    ) -> Result<()> {
        // A set operation's right-hand query is validated with the same outer
        // bindings and must project as many columns as this query.
        if let Some(set_operation) = &self.set_operation {
            set_operation
                .right
                .validate_with_outer_bindings(catalog, outer_bindings)?;
            if self.columns.len() != set_operation.right.columns.len() {
                return Err(HematiteError::ParseError(
                    "Set operations require both queries to project the same number of columns"
                        .to_string(),
                ));
            }
        }

        for cte in &self.with_clause {
            if cte.recursive {
                // A recursive CTE must be `anchor UNION [ALL] recursive_term`.
                let set_operation = cte.query.set_operation.as_ref().ok_or_else(|| {
                    HematiteError::ParseError(format!(
                        "Recursive CTE '{}' requires UNION or UNION ALL",
                        cte.name
                    ))
                })?;
                if !matches!(
                    set_operation.operator,
                    SetOperator::Union | SetOperator::UnionAll
                ) {
                    return Err(HematiteError::ParseError(format!(
                        "Recursive CTE '{}' requires UNION or UNION ALL",
                        cte.name
                    )));
                }

                // The anchor is the CTE query with its set operation removed.
                let mut anchor = (*cte.query).clone();
                anchor.set_operation = None;
                if anchor.references_source_name(&cte.name) {
                    return Err(HematiteError::ParseError(format!(
                        "Recursive CTE '{}' anchor term cannot reference itself",
                        cte.name
                    )));
                }
                if !set_operation.right.references_source_name(&cte.name) {
                    return Err(HematiteError::ParseError(format!(
                        "Recursive CTE '{}' recursive term must reference itself",
                        cte.name
                    )));
                }
                if anchor.columns.len() != set_operation.right.columns.len() {
                    return Err(HematiteError::ParseError(format!(
                        "Recursive CTE '{}' anchor and recursive terms must project the same number of columns",
                        cte.name
                    )));
                }

                anchor.validate(catalog)?;

                // Validate the recursive term with the CTE name bound to a
                // non-recursive CTE that wraps the anchor, so the
                // self-reference resolves to the anchor's definition.
                let mut recursive_term = (*set_operation.right).clone();
                recursive_term.with_clause.push(CommonTableExpression {
                    name: cte.name.clone(),
                    recursive: false,
                    query: Box::new(anchor.clone()),
                });
                recursive_term.validate(catalog)?;
            } else {
                cte.query.validate(catalog)?;
            }
        }

        // The FROM clause must yield at least one source binding.
        let bindings = self.collect_source_bindings(catalog, &self.from)?;
        if bindings.is_empty() {
            return Err(HematiteError::ParseError(
                "SELECT requires at least one table source".to_string(),
            ));
        }
        self.validate_table_reference(catalog, &self.from, outer_bindings)?;

        // True when a select item is an aggregate (or an expression containing
        // one) or when HAVING contains an aggregate. Window, wildcard and plain
        // column items do not count.
        let has_aggregate = self.columns.iter().any(|item| match item {
            SelectItem::CountAll | SelectItem::Aggregate { .. } => true,
            SelectItem::Expression(expr) => Self::expression_contains_aggregate(expr),
            SelectItem::Window { .. } | SelectItem::Wildcard | SelectItem::Column(_) => false,
        }) || self.having_clause.as_ref().is_some_and(|having| {
            having
                .conditions
                .iter()
                .any(Self::condition_contains_aggregate)
        });
        if self.distinct && has_aggregate {
            return Err(HematiteError::ParseError(
                "DISTINCT cannot be combined with aggregate select items yet".to_string(),
            ));
        }

        // Validate the references carried by each select item.
        for item in &self.columns {
            match item {
                SelectItem::Column(name) => self.validate_column_reference_with_outer(
                    name,
                    catalog,
                    &self.from,
                    outer_bindings,
                )?,
                SelectItem::Expression(expr) => {
                    self.validate_expression(expr, catalog, &self.from, outer_bindings)?;
                }
                SelectItem::Aggregate { column, .. } => {
                    self.validate_column_reference_with_outer(
                        column,
                        catalog,
                        &self.from,
                        outer_bindings,
                    )?;
                }
                SelectItem::Window { window, .. } => {
                    for expr in &window.partition_by {
                        self.validate_expression(expr, catalog, &self.from, outer_bindings)?;
                    }
                    for item in &window.order_by {
                        self.validate_column_reference_with_outer(
                            &item.column,
                            catalog,
                            &self.from,
                            outer_bindings,
                        )?;
                    }
                }
                // Nothing to resolve for `*` or `COUNT(*)`.
                SelectItem::Wildcard | SelectItem::CountAll => {} }
        }

        if let Some(where_clause) = &self.where_clause {
            for condition in &where_clause.conditions {
                self.validate_condition(condition, catalog, &self.from, outer_bindings)?;
            }
        }

        for expr in &self.group_by {
            self.validate_expression(expr, catalog, &self.from, outer_bindings)?;
        }

        if !self.group_by.is_empty() {
            // With GROUP BY, every select item must be an aggregate or a plain
            // column that appears verbatim in the GROUP BY list.
            for item in &self.columns {
                match item {
                    SelectItem::Wildcard => {
                        return Err(HematiteError::ParseError(
                            "Wildcard select is not supported with GROUP BY".to_string(),
                        ));
                    }
                    SelectItem::Column(name) => {
                        let grouped = self.group_by.iter().any(|expr| {
                            matches!(expr, Expression::Column(group_name) if group_name == name)
                        });
                        if !grouped {
                            return Err(HematiteError::ParseError(format!(
                                "Selected column '{}' must appear in GROUP BY or be aggregated",
                                name
                            )));
                        }
                    }
                    SelectItem::Expression(_) => {
                        return Err(HematiteError::ParseError(
                            "Expression select items are not supported with GROUP BY yet"
                                .to_string(),
                        ));
                    }
                    SelectItem::Window { .. } => {
                        return Err(HematiteError::ParseError(
                            "Window functions cannot be combined with GROUP BY yet".to_string(),
                        ));
                    }
                    SelectItem::CountAll | SelectItem::Aggregate { .. } => {}
                }
            }
        } else if has_aggregate
            && self
                .columns
                .iter()
                .any(|item| !matches!(item, SelectItem::CountAll | SelectItem::Aggregate { .. }))
        {
            // Without GROUP BY, an aggregate query may only select
            // `CountAll` / `Aggregate` items.
            return Err(HematiteError::ParseError(
                "Aggregate select items cannot be combined with non-aggregate select items without GROUP BY"
                    .to_string(),
            ));
        }

        if self.having_clause.is_some() && self.group_by.is_empty() && !has_aggregate {
            return Err(HematiteError::ParseError(
                "HAVING requires GROUP BY or aggregate select items".to_string(),
            ));
        }

        // ORDER BY items are validated as plain column references.
        for item in &self.order_by {
            self.validate_column_reference_with_outer(
                &item.column,
                catalog,
                &self.from,
                outer_bindings,
            )?;
        }

        Ok(())
    }
2504
2505 #[cfg(test)]
2506 fn validate_table_reference(
2507 &self,
2508 catalog: &crate::catalog::Schema,
2509 from: &TableReference,
2510 outer_bindings: &[SourceBinding],
2511 ) -> Result<()> {
2512 match from {
2513 TableReference::Table(_, _) => Ok(()),
2514 TableReference::Derived { subquery, .. } => {
2515 subquery.validate(catalog)?;
2516 let _ = subquery.projected_column_names(catalog)?;
2517 Ok(())
2518 }
2519 TableReference::CrossJoin(left, right) => {
2520 self.validate_table_reference(catalog, left, outer_bindings)?;
2521 self.validate_table_reference(catalog, right, outer_bindings)
2522 }
2523 TableReference::InnerJoin { left, right, on }
2524 | TableReference::LeftJoin { left, right, on }
2525 | TableReference::RightJoin { left, right, on }
2526 | TableReference::FullOuterJoin { left, right, on } => {
2527 self.validate_table_reference(catalog, left, outer_bindings)?;
2528 self.validate_table_reference(catalog, right, outer_bindings)?;
2529 self.validate_condition(on, catalog, from, outer_bindings)
2530 }
2531 }
2532 }
2533
2534 #[cfg(test)]
2535 fn validate_condition(
2536 &self,
2537 condition: &Condition,
2538 catalog: &crate::catalog::Schema,
2539 from: &TableReference,
2540 outer_bindings: &[SourceBinding],
2541 ) -> Result<()> {
2542 match condition {
2543 Condition::Comparison { left, right, .. } => {
2544 self.validate_expression(left, catalog, from, outer_bindings)?;
2545 self.validate_expression(right, catalog, from, outer_bindings)?;
2546 }
2547 Condition::InList { expr, values, .. } => {
2548 self.validate_expression(expr, catalog, from, outer_bindings)?;
2549 for value in values {
2550 self.validate_expression(value, catalog, from, outer_bindings)?;
2551 }
2552 }
2553 Condition::InSubquery { expr, subquery, .. } => {
2554 self.validate_expression(expr, catalog, from, outer_bindings)?;
2555 subquery.validate_with_outer_bindings(
2556 catalog,
2557 &self.combined_outer_bindings(catalog, from, outer_bindings)?,
2558 )?;
2559 if subquery.columns.len() != 1 {
2560 return Err(HematiteError::ParseError(
2561 "Subquery predicates require exactly one selected column".to_string(),
2562 ));
2563 }
2564 }
2565 Condition::Between {
2566 expr, lower, upper, ..
2567 } => {
2568 self.validate_expression(expr, catalog, from, outer_bindings)?;
2569 self.validate_expression(lower, catalog, from, outer_bindings)?;
2570 self.validate_expression(upper, catalog, from, outer_bindings)?;
2571 }
2572 Condition::Like { expr, pattern, .. } => {
2573 self.validate_expression(expr, catalog, from, outer_bindings)?;
2574 self.validate_expression(pattern, catalog, from, outer_bindings)?;
2575 }
2576 Condition::Exists { subquery, .. } => {
2577 subquery.validate_with_outer_bindings(
2578 catalog,
2579 &self.combined_outer_bindings(catalog, from, outer_bindings)?,
2580 )?;
2581 }
2582 Condition::NullCheck { expr, .. } => {
2583 self.validate_expression(expr, catalog, from, outer_bindings)?;
2584 }
2585 Condition::Not(condition) => {
2586 self.validate_condition(condition, catalog, from, outer_bindings)?;
2587 }
2588 Condition::Logical { left, right, .. } => {
2589 self.validate_condition(left, catalog, from, outer_bindings)?;
2590 self.validate_condition(right, catalog, from, outer_bindings)?;
2591 }
2592 }
2593
2594 Ok(())
2595 }
2596
2597 #[cfg(test)]
2598 fn validate_expression(
2599 &self,
2600 expr: &Expression,
2601 catalog: &crate::catalog::Schema,
2602 from: &TableReference,
2603 outer_bindings: &[SourceBinding],
2604 ) -> Result<()> {
2605 match expr {
2606 Expression::Column(name) => {
2607 self.validate_column_reference_with_outer(name, catalog, from, outer_bindings)?
2608 }
2609 Expression::ScalarSubquery(subquery) => {
2610 subquery.validate_with_outer_bindings(
2611 catalog,
2612 &self.combined_outer_bindings(catalog, from, outer_bindings)?,
2613 )?;
2614 if subquery.columns.len() != 1 {
2615 return Err(HematiteError::ParseError(
2616 "Scalar subqueries require exactly one selected column".to_string(),
2617 ));
2618 }
2619 }
2620 Expression::Case {
2621 branches,
2622 else_expr,
2623 } => {
2624 for branch in branches {
2625 self.validate_expression(&branch.condition, catalog, from, outer_bindings)?;
2626 self.validate_expression(&branch.result, catalog, from, outer_bindings)?;
2627 }
2628 if let Some(else_expr) = else_expr {
2629 self.validate_expression(else_expr, catalog, from, outer_bindings)?;
2630 }
2631 }
2632 Expression::ScalarFunctionCall { args, .. } => {
2633 for arg in args {
2634 self.validate_expression(arg, catalog, from, outer_bindings)?;
2635 }
2636 }
2637 Expression::AggregateCall { target, .. } => {
2638 if let AggregateTarget::Column(name) = target {
2639 self.validate_column_reference_with_outer(name, catalog, from, outer_bindings)?;
2640 }
2641 }
2642 Expression::UnaryMinus(expr) => {
2643 self.validate_expression(expr, catalog, from, outer_bindings)?
2644 }
2645 Expression::UnaryNot(expr) => {
2646 self.validate_expression(expr, catalog, from, outer_bindings)?
2647 }
2648 Expression::Cast { expr, .. } => {
2649 self.validate_expression(expr, catalog, from, outer_bindings)?
2650 }
2651 Expression::Binary { left, right, .. } => {
2652 self.validate_expression(left, catalog, from, outer_bindings)?;
2653 self.validate_expression(right, catalog, from, outer_bindings)?;
2654 }
2655 Expression::Comparison { left, right, .. } => {
2656 self.validate_expression(left, catalog, from, outer_bindings)?;
2657 self.validate_expression(right, catalog, from, outer_bindings)?;
2658 }
2659 Expression::InList { expr, values, .. } => {
2660 self.validate_expression(expr, catalog, from, outer_bindings)?;
2661 for value in values {
2662 self.validate_expression(value, catalog, from, outer_bindings)?;
2663 }
2664 }
2665 Expression::InSubquery { expr, subquery, .. } => {
2666 self.validate_expression(expr, catalog, from, outer_bindings)?;
2667 subquery.validate_with_outer_bindings(
2668 catalog,
2669 &self.combined_outer_bindings(catalog, from, outer_bindings)?,
2670 )?;
2671 if subquery.columns.len() != 1 {
2672 return Err(HematiteError::ParseError(
2673 "Subquery predicates require exactly one selected column".to_string(),
2674 ));
2675 }
2676 }
2677 Expression::Between {
2678 expr, lower, upper, ..
2679 } => {
2680 self.validate_expression(expr, catalog, from, outer_bindings)?;
2681 self.validate_expression(lower, catalog, from, outer_bindings)?;
2682 self.validate_expression(upper, catalog, from, outer_bindings)?;
2683 }
2684 Expression::Like { expr, pattern, .. } => {
2685 self.validate_expression(expr, catalog, from, outer_bindings)?;
2686 self.validate_expression(pattern, catalog, from, outer_bindings)?;
2687 }
2688 Expression::Exists { subquery, .. } => {
2689 subquery.validate_with_outer_bindings(
2690 catalog,
2691 &self.combined_outer_bindings(catalog, from, outer_bindings)?,
2692 )?;
2693 }
2694 Expression::NullCheck { expr, .. } => {
2695 self.validate_expression(expr, catalog, from, outer_bindings)?;
2696 }
2697 Expression::Logical { left, right, .. } => {
2698 self.validate_expression(left, catalog, from, outer_bindings)?;
2699 self.validate_expression(right, catalog, from, outer_bindings)?;
2700 }
2701 Expression::Literal(_)
2702 | Expression::IntervalLiteral { .. }
2703 | Expression::Parameter(_) => {}
2704 }
2705
2706 Ok(())
2707 }
2708
2709 #[cfg(test)]
2710 fn expression_contains_aggregate(expr: &Expression) -> bool {
2711 match expr {
2712 Expression::AggregateCall { .. } => true,
2713 Expression::ScalarSubquery(_) => false,
2714 Expression::Case {
2715 branches,
2716 else_expr,
2717 } => {
2718 branches.iter().any(|branch| {
2719 Self::expression_contains_aggregate(&branch.condition)
2720 || Self::expression_contains_aggregate(&branch.result)
2721 }) || else_expr
2722 .as_ref()
2723 .is_some_and(|expr| Self::expression_contains_aggregate(expr))
2724 }
2725 Expression::ScalarFunctionCall { args, .. } => {
2726 args.iter().any(Self::expression_contains_aggregate)
2727 }
2728 Expression::Cast { expr, .. } => Self::expression_contains_aggregate(expr),
2729 Expression::UnaryMinus(expr) => Self::expression_contains_aggregate(expr),
2730 Expression::UnaryNot(expr) => Self::expression_contains_aggregate(expr),
2731 Expression::Binary { left, right, .. } => {
2732 Self::expression_contains_aggregate(left)
2733 || Self::expression_contains_aggregate(right)
2734 }
2735 Expression::Comparison { left, right, .. } => {
2736 Self::expression_contains_aggregate(left)
2737 || Self::expression_contains_aggregate(right)
2738 }
2739 Expression::InList { expr, values, .. } => {
2740 Self::expression_contains_aggregate(expr)
2741 || values.iter().any(Self::expression_contains_aggregate)
2742 }
2743 Expression::InSubquery { expr, subquery, .. } => {
2744 Self::expression_contains_aggregate(expr)
2745 || subquery.where_clause.as_ref().is_some_and(|where_clause| {
2746 where_clause
2747 .conditions
2748 .iter()
2749 .any(Self::condition_contains_aggregate)
2750 })
2751 }
2752 Expression::Between {
2753 expr, lower, upper, ..
2754 } => {
2755 Self::expression_contains_aggregate(expr)
2756 || Self::expression_contains_aggregate(lower)
2757 || Self::expression_contains_aggregate(upper)
2758 }
2759 Expression::Like { expr, pattern, .. } => {
2760 Self::expression_contains_aggregate(expr)
2761 || Self::expression_contains_aggregate(pattern)
2762 }
2763 Expression::Exists { subquery, .. } => {
2764 subquery.where_clause.as_ref().is_some_and(|where_clause| {
2765 where_clause
2766 .conditions
2767 .iter()
2768 .any(Self::condition_contains_aggregate)
2769 })
2770 }
2771 Expression::NullCheck { expr, .. } => Self::expression_contains_aggregate(expr),
2772 Expression::Logical { left, right, .. } => {
2773 Self::expression_contains_aggregate(left)
2774 || Self::expression_contains_aggregate(right)
2775 }
2776 Expression::Column(_)
2777 | Expression::Literal(_)
2778 | Expression::IntervalLiteral { .. }
2779 | Expression::Parameter(_) => false,
2780 }
2781 }
2782
2783 #[cfg(test)]
2784 fn condition_contains_aggregate(condition: &Condition) -> bool {
2785 match condition {
2786 Condition::Comparison { left, right, .. } => {
2787 Self::expression_contains_aggregate(left)
2788 || Self::expression_contains_aggregate(right)
2789 }
2790 Condition::InList { expr, values, .. } => {
2791 Self::expression_contains_aggregate(expr)
2792 || values.iter().any(Self::expression_contains_aggregate)
2793 }
2794 Condition::InSubquery { expr, subquery, .. } => {
2795 Self::expression_contains_aggregate(expr)
2796 || subquery.where_clause.as_ref().is_some_and(|where_clause| {
2797 where_clause
2798 .conditions
2799 .iter()
2800 .any(Self::condition_contains_aggregate)
2801 })
2802 }
2803 Condition::Between {
2804 expr, lower, upper, ..
2805 } => {
2806 Self::expression_contains_aggregate(expr)
2807 || Self::expression_contains_aggregate(lower)
2808 || Self::expression_contains_aggregate(upper)
2809 }
2810 Condition::Like { expr, pattern, .. } => {
2811 Self::expression_contains_aggregate(expr)
2812 || Self::expression_contains_aggregate(pattern)
2813 }
2814 Condition::Exists { subquery, .. } => {
2815 subquery.where_clause.as_ref().is_some_and(|where_clause| {
2816 where_clause
2817 .conditions
2818 .iter()
2819 .any(Self::condition_contains_aggregate)
2820 })
2821 }
2822 Condition::NullCheck { expr, .. } => Self::expression_contains_aggregate(expr),
2823 Condition::Not(condition) => Self::condition_contains_aggregate(condition),
2824 Condition::Logical { left, right, .. } => {
2825 Self::condition_contains_aggregate(left)
2826 || Self::condition_contains_aggregate(right)
2827 }
2828 }
2829 }
2830}
2831
2832impl TableReference {
2833 fn collect_dependency_names_into(&self, names: &mut std::collections::BTreeSet<String>) {
2834 match self {
2835 TableReference::Table(table_name, _) => {
2836 names.insert(table_name.clone());
2837 }
2838 TableReference::Derived { subquery, .. } => {
2839 subquery.collect_dependency_names_into(names)
2840 }
2841 TableReference::CrossJoin(left, right) => {
2842 left.collect_dependency_names_into(names);
2843 right.collect_dependency_names_into(names);
2844 }
2845 TableReference::InnerJoin { left, right, on }
2846 | TableReference::LeftJoin { left, right, on }
2847 | TableReference::RightJoin { left, right, on }
2848 | TableReference::FullOuterJoin { left, right, on } => {
2849 left.collect_dependency_names_into(names);
2850 right.collect_dependency_names_into(names);
2851 on.collect_dependency_names_into(names);
2852 }
2853 }
2854 }
2855
2856 fn to_sql(&self) -> String {
2857 match self {
2858 TableReference::Table(table_name, Some(alias)) => format!("{table_name} {alias}"),
2859 TableReference::Table(table_name, None) => table_name.clone(),
2860 TableReference::Derived { subquery, alias } => {
2861 format!("({}) {}", subquery.to_sql(), alias)
2862 }
2863 TableReference::CrossJoin(left, right) => {
2864 format!("{}, {}", left.to_sql(), right.to_sql())
2865 }
2866 TableReference::InnerJoin { left, right, on } => {
2867 format!(
2868 "{} INNER JOIN {} ON {}",
2869 left.to_sql(),
2870 right.to_sql(),
2871 on.to_sql()
2872 )
2873 }
2874 TableReference::LeftJoin { left, right, on } => {
2875 format!(
2876 "{} LEFT JOIN {} ON {}",
2877 left.to_sql(),
2878 right.to_sql(),
2879 on.to_sql()
2880 )
2881 }
2882 TableReference::RightJoin { left, right, on } => {
2883 format!(
2884 "{} RIGHT JOIN {} ON {}",
2885 left.to_sql(),
2886 right.to_sql(),
2887 on.to_sql()
2888 )
2889 }
2890 TableReference::FullOuterJoin { left, right, on } => format!(
2891 "{} FULL OUTER JOIN {} ON {}",
2892 left.to_sql(),
2893 right.to_sql(),
2894 on.to_sql()
2895 ),
2896 }
2897 }
2898
2899 pub(crate) fn references_source_name(&self, name: &str) -> bool {
2900 match self {
2901 TableReference::Table(table_name, _) => table_name.eq_ignore_ascii_case(name),
2902 TableReference::Derived { subquery, .. } => subquery.references_source_name(name),
2903 TableReference::CrossJoin(left, right) => {
2904 left.references_source_name(name) || right.references_source_name(name)
2905 }
2906 TableReference::InnerJoin { left, right, on }
2907 | TableReference::LeftJoin { left, right, on }
2908 | TableReference::RightJoin { left, right, on }
2909 | TableReference::FullOuterJoin { left, right, on } => {
2910 left.references_source_name(name)
2911 || right.references_source_name(name)
2912 || on.references_source_name(name)
2913 }
2914 }
2915 }
2916
2917 pub(crate) fn has_non_table_source(&self, statement: &SelectStatement) -> bool {
2918 match self {
2919 TableReference::Table(table_name, _) => statement.references_cte(table_name),
2920 TableReference::Derived { .. } => true,
2921 TableReference::CrossJoin(left, right)
2922 | TableReference::InnerJoin { left, right, .. }
2923 | TableReference::LeftJoin { left, right, .. }
2924 | TableReference::RightJoin { left, right, .. }
2925 | TableReference::FullOuterJoin { left, right, .. } => {
2926 left.has_non_table_source(statement) || right.has_non_table_source(statement)
2927 }
2928 }
2929 }
2930}
2931
/// Test-only semantic validation for INSERT statements.
#[cfg(test)]
impl InsertStatement {
    /// Validates this INSERT against `catalog`.
    ///
    /// The target table must exist, every listed column must be unique and
    /// present in that table, and at least one column must be listed. For a
    /// `VALUES` source each row must have one value per column and no value may
    /// be a bare column reference; for a `SELECT` source the number of select
    /// items must equal the number of columns.
    ///
    /// # Errors
    ///
    /// Returns `HematiteError::ParseError` describing the first violated rule.
    pub fn validate(&self, catalog: &crate::catalog::Schema) -> Result<()> {
        let Some(table) = catalog.get_table_by_name(&self.table) else {
            return Err(HematiteError::ParseError(format!(
                "Table '{}' does not exist",
                self.table
            )));
        };

        let mut listed = std::collections::HashSet::new();
        for col_name in &self.columns {
            // `insert` returns false when the name was already listed.
            let first_occurrence = listed.insert(col_name);
            if !first_occurrence {
                return Err(HematiteError::ParseError(format!(
                    "Duplicate column '{}' in INSERT",
                    col_name
                )));
            }
            if table.get_column_by_name(col_name).is_none() {
                return Err(HematiteError::ParseError(format!(
                    "Column '{}' does not exist in table '{}'",
                    col_name, self.table
                )));
            }
        }

        let expected = self.columns.len();
        if expected == 0 {
            return Err(HematiteError::ParseError(
                "INSERT must specify at least one column".to_string(),
            ));
        }

        match &self.source {
            InsertSource::Select(select) => {
                let projected = select.columns.len();
                if projected != expected {
                    return Err(HematiteError::ParseError(format!(
                        "INSERT SELECT returns {} columns, expected {}",
                        projected, expected
                    )));
                }
            }
            InsertSource::Values(rows) => {
                for (row_index, row) in rows.iter().enumerate() {
                    if row.len() != expected {
                        return Err(HematiteError::ParseError(format!(
                            "Value row {} has {} values, expected {}",
                            row_index,
                            row.len(),
                            expected
                        )));
                    }

                    // Only a top-level column reference is rejected here.
                    let references_column = row
                        .iter()
                        .any(|value| matches!(value, Expression::Column(_)));
                    if references_column {
                        return Err(HematiteError::ParseError(format!(
                            "INSERT value row {} cannot reference columns",
                            row_index
                        )));
                    }
                }
            }
        }

        Ok(())
    }
}
2999
/// Test-only semantic validation for UPDATE statements.
#[cfg(test)]
impl UpdateStatement {
    /// Validates this UPDATE against `catalog`.
    ///
    /// The target table must exist and at least one assignment must be present.
    /// Each assignment must name a distinct, existing column, and its value
    /// expression is validated in a scope whose only source is `self.source()`.
    /// WHERE conditions, when present, are validated in that same scope.
    ///
    /// # Errors
    ///
    /// Returns `HematiteError::ParseError` for the rules checked here and
    /// propagates errors from expression and condition validation.
    pub fn validate(&self, catalog: &crate::catalog::Schema) -> Result<()> {
        let Some(table) = catalog.get_table_by_name(&self.table) else {
            return Err(HematiteError::ParseError(format!(
                "Table '{}' does not exist",
                self.table
            )));
        };

        if self.assignments.is_empty() {
            return Err(HematiteError::ParseError(
                "UPDATE must specify at least one assignment".to_string(),
            ));
        }

        // An otherwise empty SELECT over the update target, used only to
        // resolve the names appearing in assignments and predicates.
        let scope = SelectStatement {
            with_clause: Vec::new(),
            distinct: false,
            columns: Vec::new(),
            column_aliases: Vec::new(),
            from: self.source(),
            where_clause: None,
            group_by: Vec::new(),
            having_clause: None,
            order_by: Vec::new(),
            limit: None,
            offset: None,
            set_operation: None,
        };

        let mut assigned = std::collections::HashSet::new();
        for assignment in &self.assignments {
            let target = &assignment.column;
            if !assigned.insert(target) {
                return Err(HematiteError::ParseError(format!(
                    "Duplicate column '{}' in UPDATE",
                    target
                )));
            }

            if table.get_column_by_name(target).is_none() {
                return Err(HematiteError::ParseError(format!(
                    "Column '{}' does not exist in table '{}'",
                    target, self.table
                )));
            }

            scope.validate_expression(&assignment.value, catalog, &scope.from, &[])?;
        }

        let predicates = self
            .where_clause
            .iter()
            .flat_map(|clause| clause.conditions.iter());
        for predicate in predicates {
            scope.validate_condition(predicate, catalog, &scope.from, &[])?;
        }

        Ok(())
    }
}
3055
/// Test-only semantic validation for CREATE TABLE statements.
#[cfg(test)]
impl CreateStatement {
    /// Validates this CREATE TABLE against `catalog`.
    ///
    /// Checks, in order:
    /// 1. the table does not already exist (an existing table is accepted when
    ///    `if_not_exists` is set, and nothing else is checked in that case);
    /// 2. column names are unique, and CHARACTER SET / COLLATE appear only on
    ///    types for which `supports_text_metadata()` is true;
    /// 3. at least one column is a primary key;
    /// 4. at most one column is AUTO_INCREMENT, and that column is `Int`/`UInt`,
    ///    a primary key, and has no DEFAULT;
    /// 5. every table-level UNIQUE constraint;
    /// 6. every column-level and table-level foreign key.
    ///
    /// # Errors
    ///
    /// Returns `HematiteError::ParseError` describing the first violated rule.
    pub fn validate(&self, catalog: &crate::catalog::Schema) -> Result<()> {
        if catalog.get_table_by_name(&self.table).is_some() {
            if self.if_not_exists {
                return Ok(());
            }
            return Err(HematiteError::ParseError(format!(
                "Table '{}' already exists",
                self.table
            )));
        }

        // Borrowed names are enough for duplicate detection; `insert` reports
        // whether the name was new, so a single lookup per column suffices.
        let mut column_names = std::collections::HashSet::new();
        for col in &self.columns {
            if !column_names.insert(&col.name) {
                return Err(HematiteError::ParseError(format!(
                    "Duplicate column name '{}'",
                    col.name
                )));
            }
            let has_text_metadata = col.character_set.is_some() || col.collation.is_some();
            if has_text_metadata && !col.data_type.supports_text_metadata() {
                return Err(HematiteError::ParseError(format!(
                    "Column '{}' can only use CHARACTER SET or COLLATE with CHAR, VARCHAR, or TEXT",
                    col.name
                )));
            }
        }

        if !self.columns.iter().any(|col| col.primary_key) {
            return Err(HematiteError::ParseError(
                "Table must have at least one primary key column".to_string(),
            ));
        }

        // Look at the first two AUTO_INCREMENT columns lazily instead of
        // collecting all of them: a second one is already an error.
        let mut auto_increment_columns = self.columns.iter().filter(|column| column.auto_increment);
        let first_auto_increment = auto_increment_columns.next();
        if auto_increment_columns.next().is_some() {
            return Err(HematiteError::ParseError(
                "Only one AUTO_INCREMENT column is allowed per table".to_string(),
            ));
        }
        if let Some(column) = first_auto_increment {
            if !matches!(column.data_type, SqlTypeName::Int | SqlTypeName::UInt) {
                return Err(HematiteError::ParseError(format!(
                    "AUTO_INCREMENT column '{}' must use an integer type",
                    column.name
                )));
            }
            if !column.primary_key {
                return Err(HematiteError::ParseError(format!(
                    "AUTO_INCREMENT column '{}' must be a PRIMARY KEY",
                    column.name
                )));
            }
            if column.default_value.is_some() {
                return Err(HematiteError::ParseError(format!(
                    "AUTO_INCREMENT column '{}' cannot also declare a DEFAULT value",
                    column.name
                )));
            }
        }

        for unique_constraint in self.unique_constraints() {
            self.validate_unique_constraint(unique_constraint)?;
        }

        for foreign_key in self.foreign_keys() {
            self.validate_foreign_key(catalog, foreign_key)?;
        }

        Ok(())
    }

    /// Collects all foreign keys: column-level `references` first (in column
    /// order), followed by table-level FOREIGN KEY constraints.
    fn foreign_keys(&self) -> Vec<&ForeignKeyDefinition> {
        let mut foreign_keys = self
            .columns
            .iter()
            .filter_map(|column| column.references.as_ref())
            .collect::<Vec<_>>();

        foreign_keys.extend(
            self.constraints
                .iter()
                .filter_map(|constraint| match constraint {
                    TableConstraint::Check(_) | TableConstraint::Unique(_) => None,
                    TableConstraint::ForeignKey(foreign_key) => Some(foreign_key),
                }),
        );

        foreign_keys
    }

    /// Collects the table-level UNIQUE constraints in declaration order.
    fn unique_constraints(&self) -> Vec<&UniqueConstraintDefinition> {
        self.constraints
            .iter()
            .filter_map(|constraint| match constraint {
                TableConstraint::Unique(unique) => Some(unique),
                TableConstraint::Check(_) | TableConstraint::ForeignKey(_) => None,
            })
            .collect()
    }

    /// Checks that a UNIQUE constraint lists at least one column and that its
    /// columns pass `validate_local_constraint_columns`.
    fn validate_unique_constraint(
        &self,
        unique_constraint: &UniqueConstraintDefinition,
    ) -> Result<()> {
        if unique_constraint.columns.is_empty() {
            return Err(HematiteError::ParseError(
                "UNIQUE constraint must reference at least one column".to_string(),
            ));
        }

        self.validate_local_constraint_columns(&unique_constraint.columns, "UNIQUE constraint")
    }

    /// Runs `validate_named_columns` over `columns`, rejecting any name that is
    /// not a column of this table. `constraint_label` prefixes the error text.
    fn validate_local_constraint_columns(
        &self,
        columns: &[String],
        constraint_label: &str,
    ) -> Result<()> {
        validate_named_columns(columns, constraint_label, |column| {
            if self
                .columns
                .iter()
                .any(|candidate| candidate.name == column)
            {
                Ok(())
            } else {
                Err(HematiteError::ParseError(format!(
                    "{} column '{}' does not exist in table '{}'",
                    constraint_label, column, self.table
                )))
            }
        })
    }

    /// Validates one foreign key definition.
    ///
    /// The key must list at least one local column, list as many local columns
    /// as referenced columns, and name only columns of this table. The
    /// referenced table and columns must exist in `catalog`, and the referenced
    /// column list must equal the parent's primary key or the column list of
    /// one of its unique secondary indexes.
    fn validate_foreign_key(
        &self,
        catalog: &crate::catalog::Schema,
        foreign_key: &ForeignKeyDefinition,
    ) -> Result<()> {
        if foreign_key.columns.is_empty() {
            return Err(HematiteError::ParseError(
                "Foreign key must reference at least one local column".to_string(),
            ));
        }
        if foreign_key.columns.len() != foreign_key.referenced_columns.len() {
            return Err(HematiteError::ParseError(format!(
                "Foreign key on table '{}' must reference the same number of local and parent columns",
                self.table
            )));
        }
        self.validate_local_constraint_columns(&foreign_key.columns, "Foreign key")?;

        let referenced_table = catalog
            .get_table_by_name(&foreign_key.referenced_table)
            .ok_or_else(|| {
                HematiteError::ParseError(format!(
                    "Referenced table '{}' does not exist",
                    foreign_key.referenced_table
                ))
            })?;
        let referenced_column_indices =
            self.referenced_column_indices(referenced_table, foreign_key)?;
        // The comparison is order-sensitive: the referenced list must match the
        // key's column list exactly.
        let references_primary_key =
            referenced_table.primary_key_columns == referenced_column_indices;
        let references_unique_index = referenced_table
            .secondary_indexes
            .iter()
            .any(|index| index.unique && index.column_indices == referenced_column_indices);

        if !references_primary_key && !references_unique_index {
            return Err(HematiteError::ParseError(format!(
                "Foreign key '{}.{:?}' must reference a PRIMARY KEY or UNIQUE index with the same column list",
                foreign_key.referenced_table, foreign_key.referenced_columns
            )));
        }

        Ok(())
    }

    /// Maps each referenced column name to its index in `referenced_table`,
    /// failing on the first name the table does not contain.
    fn referenced_column_indices(
        &self,
        referenced_table: &crate::catalog::Table,
        foreign_key: &ForeignKeyDefinition,
    ) -> Result<Vec<usize>> {
        foreign_key
            .referenced_columns
            .iter()
            .map(|column| {
                referenced_table.get_column_index(column).ok_or_else(|| {
                    HematiteError::ParseError(format!(
                        "Referenced column '{}.{}' does not exist",
                        foreign_key.referenced_table, column
                    ))
                })
            })
            .collect()
    }
}
3268
/// Test-only semantic validation for DELETE statements.
#[cfg(test)]
impl DeleteStatement {
    /// Validates this DELETE against `catalog`.
    ///
    /// The target table is looked up through `require_table`, then every WHERE
    /// condition (if any) is validated in a scope whose only source is
    /// `self.source()`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `require_table` or from condition validation.
    pub fn validate(&self, catalog: &crate::catalog::Schema) -> Result<()> {
        // Only the success of the lookup matters; the table itself is unused.
        let _ = require_table(catalog, &self.table)?;

        // An otherwise empty SELECT over the delete target, used only to
        // resolve the names appearing in the predicates.
        let scope = SelectStatement {
            with_clause: Vec::new(),
            distinct: false,
            columns: Vec::new(),
            column_aliases: Vec::new(),
            from: self.source(),
            where_clause: None,
            group_by: Vec::new(),
            having_clause: None,
            order_by: Vec::new(),
            limit: None,
            offset: None,
            set_operation: None,
        };

        let predicates = self
            .where_clause
            .iter()
            .flat_map(|clause| clause.conditions.iter());
        for predicate in predicates {
            scope.validate_condition(predicate, catalog, &scope.from, &[])?;
        }

        Ok(())
    }
}
3297
/// Test-only semantic validation for DROP TABLE statements.
#[cfg(test)]
impl DropStatement {
    /// Validates this DROP TABLE against `catalog`.
    ///
    /// With `if_exists` set, a missing table is accepted. Otherwise the table
    /// must be found by `require_table`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `require_table`.
    pub fn validate(&self, catalog: &crate::catalog::Schema) -> Result<()> {
        let tolerated_missing =
            self.if_exists && catalog.get_table_by_name(&self.table).is_none();
        if !tolerated_missing {
            // Only the success of the lookup matters; the table itself is unused.
            let _ = require_table(catalog, &self.table)?;
        }
        Ok(())
    }
}
3308
3309#[cfg(test)]
3310impl AlterStatement {
3311 pub fn validate(&self, catalog: &crate::catalog::Schema) -> Result<()> {
3312 match &self.operation {
3313 AlterOperation::RenameTo(new_name) => {
3314 self.require_table(catalog)?;
3315 if new_name == &self.table {
3316 return Err(HematiteError::ParseError(
3317 "ALTER TABLE RENAME TO requires a different table name".to_string(),
3318 ));
3319 }
3320 if catalog.get_table_by_name(new_name).is_some() {
3321 return Err(HematiteError::ParseError(format!(
3322 "Table '{}' already exists",
3323 new_name
3324 )));
3325 }
3326 }
3327 AlterOperation::RenameColumn { old_name, new_name } => {
3328 self.validate_rename_column(catalog, old_name, new_name)?;
3329 }
3330 AlterOperation::AddColumn(column) => {
3331 let table = self.require_table(catalog)?;
3332 if table.get_column_by_name(&column.name).is_some() {
3333 return Err(HematiteError::ParseError(format!(
3334 "Column '{}' already exists in table '{}'",
3335 column.name, self.table
3336 )));
3337 }
3338 if column.primary_key {
3339 return Err(HematiteError::ParseError(
3340 "ALTER TABLE ADD COLUMN cannot add a PRIMARY KEY column".to_string(),
3341 ));
3342 }
3343 if column.auto_increment {
3344 return Err(HematiteError::ParseError(
3345 "ALTER TABLE ADD COLUMN does not support AUTO_INCREMENT columns"
3346 .to_string(),
3347 ));
3348 }
3349 if column.unique {
3350 return Err(HematiteError::ParseError(
3351 "ALTER TABLE ADD COLUMN does not support UNIQUE columns; add a UNIQUE index separately".to_string(),
3352 ));
3353 }
3354 if !column.nullable && column.default_value.is_none() {
3355 return Err(HematiteError::ParseError(
3356 "ALTER TABLE ADD COLUMN requires the new column to be nullable or have a DEFAULT value".to_string(),
3357 ));
3358 }
3359 if column.check_constraint.is_some() {
3360 return Err(HematiteError::ParseError(
3361 "ALTER TABLE ADD COLUMN does not support CHECK constraints".to_string(),
3362 ));
3363 }
3364 if column.references.is_some() {
3365 return Err(HematiteError::ParseError(
3366 "ALTER TABLE ADD COLUMN does not support FOREIGN KEY constraints"
3367 .to_string(),
3368 ));
3369 }
3370 if let Some(default_value) = &column.default_value {
3371 if default_value.is_null() && !column.nullable {
3372 return Err(HematiteError::ParseError(format!(
3373 "Column '{}' cannot use DEFAULT NULL when declared NOT NULL",
3374 column.name
3375 )));
3376 }
3377 if !default_value.is_null()
3378 && !default_value.is_compatible_with(column.data_type.clone())
3379 {
3380 return Err(HematiteError::ParseError(format!(
3381 "DEFAULT value for column '{}' is incompatible with {:?}",
3382 column.name, column.data_type
3383 )));
3384 }
3385 }
3386 }
3387 AlterOperation::AddConstraint(constraint) => match constraint {
3388 TableConstraint::Check(check) => {
3389 if check.name.is_none() {
3390 return Err(HematiteError::ParseError(
3391 "ALTER TABLE ADD CONSTRAINT requires a constraint name".to_string(),
3392 ));
3393 }
3394 }
3395 TableConstraint::Unique(unique) => {
3396 if unique.name.is_none() {
3397 return Err(HematiteError::ParseError(
3398 "ALTER TABLE ADD CONSTRAINT requires a constraint name".to_string(),
3399 ));
3400 }
3401 }
3402 TableConstraint::ForeignKey(foreign_key) => {
3403 if foreign_key.name.is_none() {
3404 return Err(HematiteError::ParseError(
3405 "ALTER TABLE ADD CONSTRAINT requires a constraint name".to_string(),
3406 ));
3407 }
3408 }
3409 },
3410 AlterOperation::DropColumn(column_name) => {
3411 self.validate_drop_column(catalog, column_name)?;
3412 }
3413 AlterOperation::DropConstraint(constraint_name) => {
3414 let table = self.require_table(catalog)?;
3415 if !table
3416 .list_named_constraints()
3417 .iter()
3418 .any(|constraint| constraint.name == *constraint_name)
3419 {
3420 return Err(HematiteError::ParseError(format!(
3421 "Constraint '{}' does not exist on table '{}'",
3422 constraint_name, self.table
3423 )));
3424 }
3425 }
3426 AlterOperation::AlterColumnSetDefault {
3427 column_name,
3428 default_value,
3429 } => {
3430 self.validate_set_column_default(catalog, column_name, default_value)?;
3431 }
3432 AlterOperation::AlterColumnDropDefault { column_name } => {
3433 self.validate_existing_column(catalog, column_name)?;
3434 }
3435 AlterOperation::AlterColumnSetNotNull { column_name } => {
3436 self.validate_existing_column(catalog, column_name)?;
3437 }
3438 AlterOperation::AlterColumnDropNotNull { column_name } => {
3439 self.validate_drop_not_null(catalog, column_name)?;
3440 }
3441 }
3442
3443 Ok(())
3444 }
3445
3446 fn require_table<'a>(
3447 &self,
3448 catalog: &'a crate::catalog::Schema,
3449 ) -> Result<&'a crate::catalog::Table> {
3450 require_table(catalog, &self.table)
3451 }
3452
3453 fn validate_rename_column(
3454 &self,
3455 catalog: &crate::catalog::Schema,
3456 old_name: &str,
3457 new_name: &str,
3458 ) -> Result<()> {
3459 let table = self.require_table(catalog)?;
3460 if old_name == new_name {
3461 return Err(HematiteError::ParseError(
3462 "ALTER TABLE RENAME COLUMN requires a different column name".to_string(),
3463 ));
3464 }
3465 if table.get_column_by_name(old_name).is_none() {
3466 return Err(HematiteError::ParseError(format!(
3467 "Column '{}' does not exist in table '{}'",
3468 old_name, self.table
3469 )));
3470 }
3471 if table.get_column_by_name(new_name).is_some() {
3472 return Err(HematiteError::ParseError(format!(
3473 "Column '{}' already exists in table '{}'",
3474 new_name, self.table
3475 )));
3476 }
3477 Ok(())
3478 }
3479
3480 fn validate_existing_column(
3481 &self,
3482 catalog: &crate::catalog::Schema,
3483 column_name: &str,
3484 ) -> Result<()> {
3485 let table = self.require_table(catalog)?;
3486 if table.get_column_by_name(column_name).is_none() {
3487 return Err(HematiteError::ParseError(format!(
3488 "Column '{}' does not exist in table '{}'",
3489 column_name, self.table
3490 )));
3491 }
3492 Ok(())
3493 }
3494
3495 fn validate_set_column_default(
3496 &self,
3497 catalog: &crate::catalog::Schema,
3498 column_name: &str,
3499 default_value: &LiteralValue,
3500 ) -> Result<()> {
3501 let table = self.require_table(catalog)?;
3502 let column = table.get_column_by_name(column_name).ok_or_else(|| {
3503 HematiteError::ParseError(format!(
3504 "Column '{}' does not exist in table '{}'",
3505 column_name, self.table
3506 ))
3507 })?;
3508 if default_value.is_null() && !column.nullable {
3509 return Err(HematiteError::ParseError(format!(
3510 "Column '{}' cannot use DEFAULT NULL while declared NOT NULL",
3511 column_name
3512 )));
3513 }
3514 if !default_value.is_null()
3515 && !default_value
3516 .is_compatible_with(sql_type_name_for_catalog_type(column.data_type.clone()))
3517 {
3518 return Err(HematiteError::ParseError(format!(
3519 "DEFAULT value for column '{}' is incompatible with {:?}",
3520 column_name, column.data_type
3521 )));
3522 }
3523 Ok(())
3524 }
3525
3526 fn validate_drop_not_null(
3527 &self,
3528 catalog: &crate::catalog::Schema,
3529 column_name: &str,
3530 ) -> Result<()> {
3531 let table = self.require_table(catalog)?;
3532 let column = table.get_column_by_name(column_name).ok_or_else(|| {
3533 HematiteError::ParseError(format!(
3534 "Column '{}' does not exist in table '{}'",
3535 column_name, self.table
3536 ))
3537 })?;
3538 if column.primary_key {
3539 return Err(HematiteError::ParseError(format!(
3540 "Primary-key column '{}' cannot drop NOT NULL",
3541 column_name
3542 )));
3543 }
3544 if column.auto_increment {
3545 return Err(HematiteError::ParseError(format!(
3546 "AUTO_INCREMENT column '{}' cannot drop NOT NULL",
3547 column_name
3548 )));
3549 }
3550 Ok(())
3551 }
3552
3553 fn validate_drop_column(
3554 &self,
3555 catalog: &crate::catalog::Schema,
3556 column_name: &str,
3557 ) -> Result<()> {
3558 let table = self.require_table(catalog)?;
3559 let column_index = table.get_column_index(column_name).ok_or_else(|| {
3560 HematiteError::ParseError(format!(
3561 "Column '{}' does not exist in table '{}'",
3562 column_name, self.table
3563 ))
3564 })?;
3565 if table.columns.len() == 1 {
3566 return Err(HematiteError::ParseError(
3567 "ALTER TABLE DROP COLUMN cannot remove the last column".to_string(),
3568 ));
3569 }
3570 if table.primary_key_columns.contains(&column_index) {
3571 return Err(HematiteError::ParseError(format!(
3572 "Cannot drop primary-key column '{}'",
3573 column_name
3574 )));
3575 }
3576 if table
3577 .secondary_indexes
3578 .iter()
3579 .any(|index| index.column_indices.contains(&column_index))
3580 {
3581 return Err(HematiteError::ParseError(format!(
3582 "Cannot drop column '{}' because it is used by an index",
3583 column_name
3584 )));
3585 }
3586 if table
3587 .foreign_keys
3588 .iter()
3589 .any(|foreign_key| foreign_key.column_indices.contains(&column_index))
3590 {
3591 return Err(HematiteError::ParseError(format!(
3592 "Cannot drop column '{}' because it is used by a foreign key",
3593 column_name
3594 )));
3595 }
3596 for constraint in &table.check_constraints {
3597 let condition =
3598 crate::parser::parser::parse_condition_fragment(&constraint.expression_sql)?;
3599 if condition.references_column(column_name, Some(&table.name)) {
3600 return Err(HematiteError::ParseError(format!(
3601 "Cannot drop column '{}' because it is used by a CHECK constraint",
3602 column_name
3603 )));
3604 }
3605 }
3606 if catalog.tables().values().any(|other_table| {
3607 other_table.name != table.name
3608 && other_table.foreign_keys.iter().any(|foreign_key| {
3609 foreign_key.referenced_table == table.name
3610 && foreign_key
3611 .referenced_columns
3612 .iter()
3613 .any(|referenced_column| referenced_column == column_name)
3614 })
3615 }) {
3616 return Err(HematiteError::ParseError(format!(
3617 "Cannot drop column '{}' because it is referenced by a foreign key",
3618 column_name
3619 )));
3620 }
3621 Ok(())
3622 }
3623}
3624
3625impl Condition {
3626 pub(crate) fn references_source_name(&self, name: &str) -> bool {
3627 match self {
3628 Condition::Comparison { left, right, .. } => {
3629 left.references_source_name(name) || right.references_source_name(name)
3630 }
3631 Condition::InList { expr, values, .. } => {
3632 expr.references_source_name(name)
3633 || values
3634 .iter()
3635 .any(|value| value.references_source_name(name))
3636 }
3637 Condition::InSubquery { expr, subquery, .. } => {
3638 expr.references_source_name(name) || subquery.references_source_name(name)
3639 }
3640 Condition::Between {
3641 expr, lower, upper, ..
3642 } => {
3643 expr.references_source_name(name)
3644 || lower.references_source_name(name)
3645 || upper.references_source_name(name)
3646 }
3647 Condition::Like { expr, pattern, .. } => {
3648 expr.references_source_name(name) || pattern.references_source_name(name)
3649 }
3650 Condition::Exists { subquery, .. } => subquery.references_source_name(name),
3651 Condition::NullCheck { expr, .. } => expr.references_source_name(name),
3652 Condition::Not(condition) => condition.references_source_name(name),
3653 Condition::Logical { left, right, .. } => {
3654 left.references_source_name(name) || right.references_source_name(name)
3655 }
3656 }
3657 }
3658
3659 pub(crate) fn references_column(&self, column_name: &str, table_name: Option<&str>) -> bool {
3660 match self {
3661 Condition::Comparison { left, right, .. } => {
3662 left.references_column(column_name, table_name)
3663 || right.references_column(column_name, table_name)
3664 }
3665 Condition::InList { expr, values, .. } => {
3666 expr.references_column(column_name, table_name)
3667 || values
3668 .iter()
3669 .any(|value| value.references_column(column_name, table_name))
3670 }
3671 Condition::InSubquery { expr, .. } => expr.references_column(column_name, table_name),
3672 Condition::Between {
3673 expr, lower, upper, ..
3674 } => {
3675 expr.references_column(column_name, table_name)
3676 || lower.references_column(column_name, table_name)
3677 || upper.references_column(column_name, table_name)
3678 }
3679 Condition::Like { expr, pattern, .. } => {
3680 expr.references_column(column_name, table_name)
3681 || pattern.references_column(column_name, table_name)
3682 }
3683 Condition::Exists { .. } => false,
3684 Condition::NullCheck { expr, .. } => expr.references_column(column_name, table_name),
3685 Condition::Not(condition) => condition.references_column(column_name, table_name),
3686 Condition::Logical { left, right, .. } => {
3687 left.references_column(column_name, table_name)
3688 || right.references_column(column_name, table_name)
3689 }
3690 }
3691 }
3692
3693 pub(crate) fn rename_column_references(
3694 &mut self,
3695 old_name: &str,
3696 new_name: &str,
3697 table_name: Option<&str>,
3698 ) {
3699 match self {
3700 Condition::Comparison { left, right, .. } => {
3701 left.rename_column_references(old_name, new_name, table_name);
3702 right.rename_column_references(old_name, new_name, table_name);
3703 }
3704 Condition::InList { expr, values, .. } => {
3705 expr.rename_column_references(old_name, new_name, table_name);
3706 for value in values {
3707 value.rename_column_references(old_name, new_name, table_name);
3708 }
3709 }
3710 Condition::InSubquery { expr, .. } => {
3711 expr.rename_column_references(old_name, new_name, table_name);
3712 }
3713 Condition::Between {
3714 expr, lower, upper, ..
3715 } => {
3716 expr.rename_column_references(old_name, new_name, table_name);
3717 lower.rename_column_references(old_name, new_name, table_name);
3718 upper.rename_column_references(old_name, new_name, table_name);
3719 }
3720 Condition::Like { expr, pattern, .. } => {
3721 expr.rename_column_references(old_name, new_name, table_name);
3722 pattern.rename_column_references(old_name, new_name, table_name);
3723 }
3724 Condition::Exists { .. } => {}
3725 Condition::NullCheck { expr, .. } => {
3726 expr.rename_column_references(old_name, new_name, table_name);
3727 }
3728 Condition::Not(condition) => {
3729 condition.rename_column_references(old_name, new_name, table_name);
3730 }
3731 Condition::Logical { left, right, .. } => {
3732 left.rename_column_references(old_name, new_name, table_name);
3733 right.rename_column_references(old_name, new_name, table_name);
3734 }
3735 }
3736 }
3737
3738 pub fn to_sql(&self) -> String {
3739 match self {
3740 Condition::Comparison {
3741 left,
3742 operator,
3743 right,
3744 } => format!("{} {} {}", left.to_sql(), operator.to_sql(), right.to_sql()),
3745 Condition::InList {
3746 expr,
3747 values,
3748 is_not,
3749 } => format!(
3750 "{} {}IN ({})",
3751 expr.to_sql(),
3752 if *is_not { "NOT " } else { "" },
3753 values
3754 .iter()
3755 .map(Expression::to_sql)
3756 .collect::<Vec<_>>()
3757 .join(", ")
3758 ),
3759 Condition::InSubquery { expr, is_not, .. } => format!(
3760 "{} {}IN (<subquery>)",
3761 expr.to_sql(),
3762 if *is_not { "NOT " } else { "" }
3763 ),
3764 Condition::Between {
3765 expr,
3766 lower,
3767 upper,
3768 is_not,
3769 } => format!(
3770 "{} {}BETWEEN {} AND {}",
3771 expr.to_sql(),
3772 if *is_not { "NOT " } else { "" },
3773 lower.to_sql(),
3774 upper.to_sql()
3775 ),
3776 Condition::Like {
3777 expr,
3778 pattern,
3779 is_not,
3780 } => format!(
3781 "{} {}LIKE {}",
3782 expr.to_sql(),
3783 if *is_not { "NOT " } else { "" },
3784 pattern.to_sql()
3785 ),
3786 Condition::Exists { is_not, .. } => {
3787 format!("{}EXISTS (<subquery>)", if *is_not { "NOT " } else { "" })
3788 }
3789 Condition::NullCheck { expr, is_not } => format!(
3790 "{} IS {}NULL",
3791 expr.to_sql(),
3792 if *is_not { "NOT " } else { "" }
3793 ),
3794 Condition::Not(inner) => format!("NOT ({})", inner.to_sql()),
3795 Condition::Logical {
3796 left,
3797 operator,
3798 right,
3799 } => format!(
3800 "({}) {} ({})",
3801 left.to_sql(),
3802 operator.to_sql(),
3803 right.to_sql()
3804 ),
3805 }
3806 }
3807}
3808
3809impl Expression {
3810 pub(crate) fn references_source_name(&self, name: &str) -> bool {
3811 match self {
3812 Expression::ScalarSubquery(subquery) => subquery.references_source_name(name),
3813 Expression::Case {
3814 branches,
3815 else_expr,
3816 } => {
3817 branches.iter().any(|branch| {
3818 branch.condition.references_source_name(name)
3819 || branch.result.references_source_name(name)
3820 }) || else_expr
3821 .as_ref()
3822 .is_some_and(|expr| expr.references_source_name(name))
3823 }
3824 Expression::ScalarFunctionCall { args, .. } => {
3825 args.iter().any(|arg| arg.references_source_name(name))
3826 }
3827 Expression::Cast { expr, .. } => expr.references_source_name(name),
3828 Expression::UnaryMinus(expr) => expr.references_source_name(name),
3829 Expression::UnaryNot(expr) => expr.references_source_name(name),
3830 Expression::Binary { left, right, .. } => {
3831 left.references_source_name(name) || right.references_source_name(name)
3832 }
3833 Expression::Comparison { left, right, .. } => {
3834 left.references_source_name(name) || right.references_source_name(name)
3835 }
3836 Expression::InList { expr, values, .. } => {
3837 expr.references_source_name(name)
3838 || values
3839 .iter()
3840 .any(|value| value.references_source_name(name))
3841 }
3842 Expression::InSubquery { expr, subquery, .. } => {
3843 expr.references_source_name(name) || subquery.references_source_name(name)
3844 }
3845 Expression::Between {
3846 expr, lower, upper, ..
3847 } => {
3848 expr.references_source_name(name)
3849 || lower.references_source_name(name)
3850 || upper.references_source_name(name)
3851 }
3852 Expression::Like { expr, pattern, .. } => {
3853 expr.references_source_name(name) || pattern.references_source_name(name)
3854 }
3855 Expression::Exists { subquery, .. } => subquery.references_source_name(name),
3856 Expression::NullCheck { expr, .. } => expr.references_source_name(name),
3857 Expression::Logical { left, right, .. } => {
3858 left.references_source_name(name) || right.references_source_name(name)
3859 }
3860 Expression::Column(_)
3861 | Expression::Literal(_)
3862 | Expression::IntervalLiteral { .. }
3863 | Expression::Parameter(_)
3864 | Expression::AggregateCall { .. } => false,
3865 }
3866 }
3867
3868 pub(crate) fn references_column(&self, column_name: &str, table_name: Option<&str>) -> bool {
3869 match self {
3870 Expression::Column(name) => column_name_matches(name, column_name, table_name),
3871 Expression::ScalarSubquery(_) => false,
3872 Expression::Case {
3873 branches,
3874 else_expr,
3875 } => {
3876 branches.iter().any(|branch| {
3877 branch.condition.references_column(column_name, table_name)
3878 || branch.result.references_column(column_name, table_name)
3879 }) || else_expr
3880 .as_ref()
3881 .is_some_and(|expr| expr.references_column(column_name, table_name))
3882 }
3883 Expression::ScalarFunctionCall { args, .. } => args
3884 .iter()
3885 .any(|arg| arg.references_column(column_name, table_name)),
3886 Expression::AggregateCall { target, .. } => match target {
3887 AggregateTarget::All => false,
3888 AggregateTarget::Column(name) => column_name_matches(name, column_name, table_name),
3889 },
3890 Expression::Cast { expr, .. } => expr.references_column(column_name, table_name),
3891 Expression::UnaryMinus(expr) => expr.references_column(column_name, table_name),
3892 Expression::UnaryNot(expr) => expr.references_column(column_name, table_name),
3893 Expression::Binary { left, right, .. } => {
3894 left.references_column(column_name, table_name)
3895 || right.references_column(column_name, table_name)
3896 }
3897 Expression::Comparison { left, right, .. } => {
3898 left.references_column(column_name, table_name)
3899 || right.references_column(column_name, table_name)
3900 }
3901 Expression::InList { expr, values, .. } => {
3902 expr.references_column(column_name, table_name)
3903 || values
3904 .iter()
3905 .any(|value| value.references_column(column_name, table_name))
3906 }
3907 Expression::InSubquery { expr, .. } => expr.references_column(column_name, table_name),
3908 Expression::Between {
3909 expr, lower, upper, ..
3910 } => {
3911 expr.references_column(column_name, table_name)
3912 || lower.references_column(column_name, table_name)
3913 || upper.references_column(column_name, table_name)
3914 }
3915 Expression::Like { expr, pattern, .. } => {
3916 expr.references_column(column_name, table_name)
3917 || pattern.references_column(column_name, table_name)
3918 }
3919 Expression::Exists { .. } => false,
3920 Expression::NullCheck { expr, .. } => expr.references_column(column_name, table_name),
3921 Expression::Logical { left, right, .. } => {
3922 left.references_column(column_name, table_name)
3923 || right.references_column(column_name, table_name)
3924 }
3925 Expression::Literal(_)
3926 | Expression::IntervalLiteral { .. }
3927 | Expression::Parameter(_) => false,
3928 }
3929 }
3930
3931 pub(crate) fn rename_column_references(
3932 &mut self,
3933 old_name: &str,
3934 new_name: &str,
3935 table_name: Option<&str>,
3936 ) {
3937 match self {
3938 Expression::Column(name) => {
3939 rename_column_name(name, old_name, new_name, table_name);
3940 }
3941 Expression::ScalarSubquery(_) => {}
3942 Expression::Case {
3943 branches,
3944 else_expr,
3945 } => {
3946 for branch in branches {
3947 branch
3948 .condition
3949 .rename_column_references(old_name, new_name, table_name);
3950 branch
3951 .result
3952 .rename_column_references(old_name, new_name, table_name);
3953 }
3954 if let Some(else_expr) = else_expr {
3955 else_expr.rename_column_references(old_name, new_name, table_name);
3956 }
3957 }
3958 Expression::ScalarFunctionCall { args, .. } => {
3959 for arg in args {
3960 arg.rename_column_references(old_name, new_name, table_name);
3961 }
3962 }
3963 Expression::AggregateCall { target, .. } => {
3964 if let AggregateTarget::Column(name) = target {
3965 rename_column_name(name, old_name, new_name, table_name);
3966 }
3967 }
3968 Expression::Cast { expr, .. } => {
3969 expr.rename_column_references(old_name, new_name, table_name);
3970 }
3971 Expression::UnaryMinus(expr) => {
3972 expr.rename_column_references(old_name, new_name, table_name);
3973 }
3974 Expression::UnaryNot(expr) => {
3975 expr.rename_column_references(old_name, new_name, table_name);
3976 }
3977 Expression::Binary { left, right, .. } => {
3978 left.rename_column_references(old_name, new_name, table_name);
3979 right.rename_column_references(old_name, new_name, table_name);
3980 }
3981 Expression::Comparison { left, right, .. } => {
3982 left.rename_column_references(old_name, new_name, table_name);
3983 right.rename_column_references(old_name, new_name, table_name);
3984 }
3985 Expression::InList { expr, values, .. } => {
3986 expr.rename_column_references(old_name, new_name, table_name);
3987 for value in values {
3988 value.rename_column_references(old_name, new_name, table_name);
3989 }
3990 }
3991 Expression::InSubquery { expr, .. } => {
3992 expr.rename_column_references(old_name, new_name, table_name);
3993 }
3994 Expression::Between {
3995 expr, lower, upper, ..
3996 } => {
3997 expr.rename_column_references(old_name, new_name, table_name);
3998 lower.rename_column_references(old_name, new_name, table_name);
3999 upper.rename_column_references(old_name, new_name, table_name);
4000 }
4001 Expression::Like { expr, pattern, .. } => {
4002 expr.rename_column_references(old_name, new_name, table_name);
4003 pattern.rename_column_references(old_name, new_name, table_name);
4004 }
4005 Expression::Exists { .. } => {}
4006 Expression::NullCheck { expr, .. } => {
4007 expr.rename_column_references(old_name, new_name, table_name);
4008 }
4009 Expression::Logical { left, right, .. } => {
4010 left.rename_column_references(old_name, new_name, table_name);
4011 right.rename_column_references(old_name, new_name, table_name);
4012 }
4013 Expression::Literal(_)
4014 | Expression::IntervalLiteral { .. }
4015 | Expression::Parameter(_) => {}
4016 }
4017 }
4018
4019 pub fn to_sql(&self) -> String {
4020 match self {
4021 Expression::Column(name) => name.clone(),
4022 Expression::Literal(value) => match value {
4023 LiteralValue::Integer(i) => i.to_string(),
4024 LiteralValue::Text(s) => format!("'{}'", s.replace('\'', "''")),
4025 LiteralValue::Blob(bytes) => {
4026 format!(
4027 "X'{}'",
4028 bytes
4029 .iter()
4030 .map(|byte| format!("{byte:02X}"))
4031 .collect::<String>()
4032 )
4033 }
4034 LiteralValue::Boolean(true) => "TRUE".to_string(),
4035 LiteralValue::Boolean(false) => "FALSE".to_string(),
4036 LiteralValue::Float(f) => f.to_string(),
4037 LiteralValue::Null => "NULL".to_string(),
4038 },
4039 Expression::IntervalLiteral { value, qualifier } => {
4040 format!(
4041 "INTERVAL '{}' {}",
4042 value.replace('\'', "''"),
4043 qualifier.to_sql()
4044 )
4045 }
4046 Expression::Parameter(index) => format!("?{}", index + 1),
4047 Expression::ScalarSubquery(_) => "(<subquery>)".to_string(),
4048 Expression::Case {
4049 branches,
4050 else_expr,
4051 } => {
4052 let mut parts = vec!["CASE".to_string()];
4053 for branch in branches {
4054 parts.push(format!(
4055 "WHEN {} THEN {}",
4056 branch.condition.to_sql(),
4057 branch.result.to_sql()
4058 ));
4059 }
4060 if let Some(else_expr) = else_expr {
4061 parts.push(format!("ELSE {}", else_expr.to_sql()));
4062 }
4063 parts.push("END".to_string());
4064 parts.join(" ")
4065 }
4066 Expression::ScalarFunctionCall { function, args } => format!(
4067 "{}({})",
4068 function.to_sql(),
4069 args.iter()
4070 .map(Expression::to_sql)
4071 .collect::<Vec<_>>()
4072 .join(", ")
4073 ),
4074 Expression::AggregateCall { function, target } => {
4075 format!("{}({})", function.to_sql(), target.to_sql())
4076 }
4077 Expression::Cast { expr, target_type } => {
4078 format!("CAST({} AS {})", expr.to_sql(), target_type.to_sql())
4079 }
4080 Expression::UnaryMinus(expr) => format!("-{}", expr.to_sql()),
4081 Expression::UnaryNot(expr) => format!("NOT {}", expr.to_sql()),
4082 Expression::Binary {
4083 left,
4084 operator,
4085 right,
4086 } => format!(
4087 "({} {} {})",
4088 left.to_sql(),
4089 operator.to_sql(),
4090 right.to_sql()
4091 ),
4092 Expression::Comparison {
4093 left,
4094 operator,
4095 right,
4096 } => format!("{} {} {}", left.to_sql(), operator.to_sql(), right.to_sql()),
4097 Expression::InList {
4098 expr,
4099 values,
4100 is_not,
4101 } => format!(
4102 "{} {}IN ({})",
4103 expr.to_sql(),
4104 if *is_not { "NOT " } else { "" },
4105 values
4106 .iter()
4107 .map(Expression::to_sql)
4108 .collect::<Vec<_>>()
4109 .join(", ")
4110 ),
4111 Expression::InSubquery { expr, is_not, .. } => format!(
4112 "{} {}IN (<subquery>)",
4113 expr.to_sql(),
4114 if *is_not { "NOT " } else { "" }
4115 ),
4116 Expression::Between {
4117 expr,
4118 lower,
4119 upper,
4120 is_not,
4121 } => format!(
4122 "{} {}BETWEEN {} AND {}",
4123 expr.to_sql(),
4124 if *is_not { "NOT " } else { "" },
4125 lower.to_sql(),
4126 upper.to_sql()
4127 ),
4128 Expression::Like {
4129 expr,
4130 pattern,
4131 is_not,
4132 } => format!(
4133 "{} {}LIKE {}",
4134 expr.to_sql(),
4135 if *is_not { "NOT " } else { "" },
4136 pattern.to_sql()
4137 ),
4138 Expression::Exists { is_not, .. } => {
4139 format!("{}EXISTS (<subquery>)", if *is_not { "NOT " } else { "" })
4140 }
4141 Expression::NullCheck { expr, is_not } => format!(
4142 "{} IS {}NULL",
4143 expr.to_sql(),
4144 if *is_not { "NOT " } else { "" }
4145 ),
4146 Expression::Logical {
4147 left,
4148 operator,
4149 right,
4150 } => format!(
4151 "({}) {} ({})",
4152 left.to_sql(),
4153 operator.to_sql(),
4154 right.to_sql()
4155 ),
4156 }
4157 }
4158}
4159
4160impl AggregateFunction {
4161 fn to_sql(self) -> &'static str {
4162 match self {
4163 AggregateFunction::Count => "COUNT",
4164 AggregateFunction::Sum => "SUM",
4165 AggregateFunction::Avg => "AVG",
4166 AggregateFunction::Min => "MIN",
4167 AggregateFunction::Max => "MAX",
4168 }
4169 }
4170}
4171
4172impl ScalarFunction {
4173 pub fn from_identifier(name: &str) -> Option<Self> {
4174 match name.to_ascii_uppercase().as_str() {
4175 "COALESCE" => Some(Self::Coalesce),
4176 "IFNULL" => Some(Self::IfNull),
4177 "NULLIF" => Some(Self::NullIf),
4178 "DATE" => Some(Self::DateFn),
4179 "TIME" => Some(Self::TimeFn),
4180 "YEAR" => Some(Self::Year),
4181 "MONTH" => Some(Self::Month),
4182 "DAY" => Some(Self::Day),
4183 "HOUR" => Some(Self::Hour),
4184 "MINUTE" => Some(Self::Minute),
4185 "SECOND" => Some(Self::Second),
4186 "TIME_TO_SEC" => Some(Self::TimeToSec),
4187 "SEC_TO_TIME" => Some(Self::SecToTime),
4188 "UNIX_TIMESTAMP" => Some(Self::UnixTimestamp),
4189 "LOWER" => Some(Self::Lower),
4190 "UPPER" => Some(Self::Upper),
4191 "LENGTH" => Some(Self::Length),
4192 "OCTET_LENGTH" => Some(Self::OctetLength),
4193 "BIT_LENGTH" => Some(Self::BitLength),
4194 "TRIM" => Some(Self::Trim),
4195 "ABS" => Some(Self::Abs),
4196 "ROUND" => Some(Self::Round),
4197 "CONCAT" => Some(Self::Concat),
4198 "CONCAT_WS" => Some(Self::ConcatWs),
4199 "SUBSTRING" | "SUBSTR" => Some(Self::Substring),
4200 "LEFT" => Some(Self::LeftFn),
4201 "RIGHT" => Some(Self::RightFn),
4202 "GREATEST" => Some(Self::Greatest),
4203 "LEAST" => Some(Self::Least),
4204 "REPLACE" => Some(Self::Replace),
4205 "REPEAT" => Some(Self::Repeat),
4206 "REVERSE" => Some(Self::Reverse),
4207 "LOCATE" => Some(Self::Locate),
4208 "HEX" => Some(Self::Hex),
4209 "UNHEX" => Some(Self::Unhex),
4210 "CEIL" | "CEILING" => Some(Self::Ceil),
4211 "FLOOR" => Some(Self::Floor),
4212 "POWER" | "POW" => Some(Self::Power),
4213 _ => None,
4214 }
4215 }
4216
4217 pub(crate) fn to_sql(self) -> &'static str {
4218 match self {
4219 ScalarFunction::Coalesce => "COALESCE",
4220 ScalarFunction::IfNull => "IFNULL",
4221 ScalarFunction::NullIf => "NULLIF",
4222 ScalarFunction::DateFn => "DATE",
4223 ScalarFunction::TimeFn => "TIME",
4224 ScalarFunction::Year => "YEAR",
4225 ScalarFunction::Month => "MONTH",
4226 ScalarFunction::Day => "DAY",
4227 ScalarFunction::Hour => "HOUR",
4228 ScalarFunction::Minute => "MINUTE",
4229 ScalarFunction::Second => "SECOND",
4230 ScalarFunction::TimeToSec => "TIME_TO_SEC",
4231 ScalarFunction::SecToTime => "SEC_TO_TIME",
4232 ScalarFunction::UnixTimestamp => "UNIX_TIMESTAMP",
4233 ScalarFunction::Lower => "LOWER",
4234 ScalarFunction::Upper => "UPPER",
4235 ScalarFunction::Length => "LENGTH",
4236 ScalarFunction::OctetLength => "OCTET_LENGTH",
4237 ScalarFunction::BitLength => "BIT_LENGTH",
4238 ScalarFunction::Trim => "TRIM",
4239 ScalarFunction::Abs => "ABS",
4240 ScalarFunction::Round => "ROUND",
4241 ScalarFunction::Concat => "CONCAT",
4242 ScalarFunction::ConcatWs => "CONCAT_WS",
4243 ScalarFunction::Substring => "SUBSTRING",
4244 ScalarFunction::LeftFn => "LEFT",
4245 ScalarFunction::RightFn => "RIGHT",
4246 ScalarFunction::Greatest => "GREATEST",
4247 ScalarFunction::Least => "LEAST",
4248 ScalarFunction::Replace => "REPLACE",
4249 ScalarFunction::Repeat => "REPEAT",
4250 ScalarFunction::Reverse => "REVERSE",
4251 ScalarFunction::Locate => "LOCATE",
4252 ScalarFunction::Hex => "HEX",
4253 ScalarFunction::Unhex => "UNHEX",
4254 ScalarFunction::Ceil => "CEIL",
4255 ScalarFunction::Floor => "FLOOR",
4256 ScalarFunction::Power => "POWER",
4257 }
4258 }
4259}
4260
4261impl AggregateTarget {
4262 fn to_sql(&self) -> String {
4263 match self {
4264 AggregateTarget::All => "*".to_string(),
4265 AggregateTarget::Column(column) => column.clone(),
4266 }
4267 }
4268}
4269
4270impl ComparisonOperator {
4271 fn to_sql(&self) -> &'static str {
4272 match self {
4273 ComparisonOperator::Equal => "=",
4274 ComparisonOperator::NotEqual => "!=",
4275 ComparisonOperator::LessThan => "<",
4276 ComparisonOperator::LessThanOrEqual => "<=",
4277 ComparisonOperator::GreaterThan => ">",
4278 ComparisonOperator::GreaterThanOrEqual => ">=",
4279 }
4280 }
4281}
4282
4283impl LogicalOperator {
4284 fn to_sql(&self) -> &'static str {
4285 match self {
4286 LogicalOperator::And => "AND",
4287 LogicalOperator::Or => "OR",
4288 }
4289 }
4290}
4291
4292impl ArithmeticOperator {
4293 fn to_sql(&self) -> &'static str {
4294 match self {
4295 ArithmeticOperator::Add => "+",
4296 ArithmeticOperator::Subtract => "-",
4297 ArithmeticOperator::Multiply => "*",
4298 ArithmeticOperator::Divide => "/",
4299 ArithmeticOperator::Modulo => "%",
4300 }
4301 }
4302}
4303
4304impl SortDirection {
4305 fn to_sql(&self) -> &'static str {
4306 match self {
4307 SortDirection::Asc => "ASC",
4308 SortDirection::Desc => "DESC",
4309 }
4310 }
4311}
4312
4313impl SetOperator {
4314 fn to_sql(&self) -> &'static str {
4315 match self {
4316 SetOperator::Union => "UNION",
4317 SetOperator::UnionAll => "UNION ALL",
4318 SetOperator::Intersect => "INTERSECT",
4319 SetOperator::Except => "EXCEPT",
4320 }
4321 }
4322}
4323
#[cfg(test)]
impl CreateIndexStatement {
    /// Validates this `CREATE INDEX` statement against `catalog`.
    ///
    /// The target table must exist, at least one column must be listed, and
    /// the column list must pass `validate_named_columns` with every column
    /// present on the table. An index of the same name already on the table
    /// is an error unless the statement carries `if_not_exists`.
    ///
    /// # Errors
    ///
    /// Returns `HematiteError::ParseError` for the first rule that fails, or
    /// the error produced by the table lookup or by `validate_named_columns`.
    pub fn validate(&self, catalog: &crate::catalog::Schema) -> Result<()> {
        let table = require_table(catalog, &self.table)?;

        if self.columns.is_empty() {
            return Err(HematiteError::ParseError(
                "CREATE INDEX must specify at least one column".to_string(),
            ));
        }

        validate_named_columns(&self.columns, "CREATE INDEX", |column| {
            match table.get_column_by_name(column) {
                Some(_) => Ok(()),
                None => Err(HematiteError::ParseError(format!(
                    "Column '{}' does not exist in table '{}'",
                    column, self.table
                ))),
            }
        })?;

        // A name clash is tolerated only when IF NOT EXISTS was requested.
        match table.get_secondary_index(&self.index_name) {
            Some(_) if !self.if_not_exists => Err(HematiteError::ParseError(format!(
                "Index '{}' already exists on table '{}'",
                self.index_name, self.table
            ))),
            _ => Ok(()),
        }
    }
}
4359
/// Looks up `table_name` in `catalog`.
///
/// # Errors
///
/// Returns `HematiteError::ParseError` when the catalog has no table of that name.
#[cfg(test)]
fn require_table<'a>(
    catalog: &'a crate::catalog::Schema,
    table_name: &str,
) -> Result<&'a crate::catalog::Table> {
    match catalog.get_table_by_name(table_name) {
        Some(table) => Ok(table),
        None => Err(HematiteError::ParseError(format!(
            "Table '{}' does not exist",
            table_name
        ))),
    }
}
4369
/// Converts a catalog `DataType` into the `SqlTypeName` variant of the same name.
///
/// Length, value-list, precision and scale payloads are carried over unchanged.
///
/// # Panics
///
/// Panics for `IntervalYearMonth` and `IntervalDaySecond`, which have no
/// `SqlTypeName` counterpart in this mapping.
#[cfg(test)]
fn sql_type_name_for_catalog_type(data_type: crate::catalog::DataType) -> SqlTypeName {
    use crate::catalog::DataType;

    match data_type {
        // Signed integers
        DataType::Int8 => SqlTypeName::Int8,
        DataType::Int16 => SqlTypeName::Int16,
        DataType::Int => SqlTypeName::Int,
        DataType::Int64 => SqlTypeName::Int64,
        DataType::Int128 => SqlTypeName::Int128,
        // Unsigned integers
        DataType::UInt8 => SqlTypeName::UInt8,
        DataType::UInt16 => SqlTypeName::UInt16,
        DataType::UInt => SqlTypeName::UInt,
        DataType::UInt64 => SqlTypeName::UInt64,
        DataType::UInt128 => SqlTypeName::UInt128,
        // Other numeric types and booleans
        DataType::Boolean => SqlTypeName::Boolean,
        DataType::Float32 => SqlTypeName::Float32,
        DataType::Float => SqlTypeName::Float,
        DataType::Decimal { precision, scale } => SqlTypeName::Decimal { precision, scale },
        // Character and binary data
        DataType::Text => SqlTypeName::Text,
        DataType::Char(length) => SqlTypeName::Char(length),
        DataType::VarChar(length) => SqlTypeName::VarChar(length),
        DataType::Binary(length) => SqlTypeName::Binary(length),
        DataType::VarBinary(length) => SqlTypeName::VarBinary(length),
        DataType::Blob => SqlTypeName::Blob,
        DataType::Enum(values) => SqlTypeName::Enum(values),
        // Temporal types
        DataType::Date => SqlTypeName::Date,
        DataType::Time => SqlTypeName::Time,
        DataType::DateTime => SqlTypeName::DateTime,
        DataType::TimeWithTimeZone => SqlTypeName::TimeWithTimeZone,
        DataType::IntervalYearMonth | DataType::IntervalDaySecond => {
            panic!("interval runtime types do not map to schema SQL type names")
        }
    }
}
4406
/// Walks `columns` in order, rejecting the first repeated name and otherwise
/// running `validate_column` on each entry.
///
/// `constraint_label` is used as the prefix of the duplicate-column error
/// message. Processing stops at the first error, whether it comes from the
/// duplicate check or from `validate_column`.
///
/// # Errors
///
/// Returns `HematiteError::ParseError` for a repeated column, or whatever
/// error `validate_column` produces.
#[cfg(test)]
fn validate_named_columns<F>(
    columns: &[String],
    constraint_label: &str,
    mut validate_column: F,
) -> Result<()>
where
    F: FnMut(&str) -> Result<()>,
{
    let mut already_listed = std::collections::HashSet::new();
    columns.iter().try_for_each(|column| {
        // `insert` returns false when the name was already present.
        let first_occurrence = already_listed.insert(column.as_str());
        if !first_occurrence {
            return Err(HematiteError::ParseError(format!(
                "{} repeats column '{}'",
                constraint_label, column
            )));
        }
        validate_column(column.as_str())
    })
}
4428
/// Rewrites `name` in place when it refers to the column `old_name`.
///
/// A bare match (`name == old_name`) becomes `new_name`. Otherwise, when
/// `table_name` is given and `name` is exactly `<table_name>.<old_name>`, it
/// becomes `<table_name>.<new_name>`. Any other value is left untouched.
fn rename_column_name(name: &mut String, old_name: &str, new_name: &str, table_name: Option<&str>) {
    if name.as_str() == old_name {
        *name = new_name.to_owned();
        return;
    }

    if let Some(table) = table_name {
        // Equivalent to comparing against "<table>.<old_name>", without
        // building that string first.
        let is_qualified_match = name
            .strip_prefix(table)
            .and_then(|rest| rest.strip_prefix('.'))
            .map_or(false, |column| column == old_name);
        if is_qualified_match {
            *name = format!("{}.{}", table, new_name);
        }
    }
}
4439
4440fn column_name_matches(name: &str, column_name: &str, table_name: Option<&str>) -> bool {
4441 let (qualifier, bare_name) = SelectStatement::split_column_reference(name);
4442 if let Some(qualifier) = qualifier {
4443 qualifier == table_name.unwrap_or_default() && bare_name == column_name
4444 } else {
4445 name == column_name
4446 }
4447}
4448
#[cfg(test)]
impl DropIndexStatement {
    /// Checks this `DROP INDEX` statement against `catalog`.
    ///
    /// The target table must exist and must carry a secondary index named
    /// `index_name`. When `if_exists` is set, a missing table or a missing
    /// index is accepted instead of being reported.
    ///
    /// # Errors
    ///
    /// Returns `HematiteError::ParseError` when the table or the index is
    /// missing and `if_exists` is not set.
    pub fn validate(&self, catalog: &crate::catalog::Schema) -> Result<()> {
        let table = match require_table(catalog, &self.table) {
            Ok(table) => table,
            // The lookup only fails for a missing table, which IF EXISTS tolerates.
            Err(_) if self.if_exists => return Ok(()),
            Err(error) => return Err(error),
        };

        let index_missing = table.get_secondary_index(&self.index_name).is_none();
        if index_missing && !self.if_exists {
            return Err(HematiteError::ParseError(format!(
                "Index '{}' does not exist on table '{}'",
                self.index_name, self.table
            )));
        }

        Ok(())
    }
}