1use kimberlite_types::NonEmptyVec;
15use sqlparser::ast::{
16 BinaryOperator, ColumnDef as SqlColumnDef, DataType as SqlDataType, Expr, Ident, ObjectName,
17 OrderByExpr, Query, Select, SelectItem, SetExpr, Statement, Value as SqlValue,
18};
19use sqlparser::dialect::GenericDialect;
20use sqlparser::parser::Parser;
21
22use crate::error::{QueryError, Result};
23use crate::schema::ColumnName;
24use crate::value::Value;
25
26#[derive(Debug, Clone)]
32pub enum ParsedStatement {
33 Select(ParsedSelect),
35 Union(ParsedUnion),
37 CreateTable(ParsedCreateTable),
39 DropTable(String),
41 AlterTable(ParsedAlterTable),
43 CreateIndex(ParsedCreateIndex),
45 Insert(ParsedInsert),
47 Update(ParsedUpdate),
49 Delete(ParsedDelete),
51 CreateMask(ParsedCreateMask),
53 DropMask(String),
55 SetClassification(ParsedSetClassification),
57 ShowClassifications(String),
59 ShowTables,
61 ShowColumns(String),
63 CreateRole(String),
65 Grant(ParsedGrant),
67 CreateUser(ParsedCreateUser),
69}
70
/// Parsed form of a `GRANT` statement.
#[derive(Debug, Clone)]
pub struct ParsedGrant {
    /// Column list attached to a `SELECT` privilege, when one was given.
    /// `None` when no column list is present (including `ALL` privileges).
    pub columns: Option<Vec<String>>,
    /// Name of the single table the privileges apply to.
    pub table_name: String,
    /// Name of the single grantee.
    pub role_name: String,
}
81
/// Parsed form of `CREATE USER <name> WITH ROLE <role>`.
#[derive(Debug, Clone)]
pub struct ParsedCreateUser {
    /// The `<name>` token, kept exactly as written.
    pub username: String,
    /// The `<role>` token, kept exactly as written.
    pub role: String,
}
90
/// Parsed form of
/// `ALTER TABLE <table> MODIFY COLUMN <column> SET CLASSIFICATION '<class>'`.
#[derive(Debug, Clone)]
pub struct ParsedSetClassification {
    /// The `<table>` token.
    pub table_name: String,
    /// The `<column>` token.
    pub column_name: String,
    /// The classification label with its surrounding single quotes removed.
    pub classification: String,
}
101
/// Parsed form of `CREATE MASK <name> ON <table>.<column> USING <strategy>`.
#[derive(Debug, Clone)]
pub struct ParsedCreateMask {
    /// The `<name>` token.
    pub mask_name: String,
    /// Part of the `<table>.<column>` token before the first `.`.
    pub table_name: String,
    /// Part of the `<table>.<column>` token after the first `.`.
    pub column_name: String,
    /// The `<strategy>` token, converted to ASCII upper case by the parser.
    pub strategy: String,
}
114
115#[derive(Debug, Clone)]
117pub struct ParsedUnion {
118 pub left: ParsedSelect,
120 pub right: ParsedSelect,
122 pub all: bool,
124}
125
/// Join flavours the parser accepts; every other join operator is rejected
/// as unsupported.
#[derive(Debug, Clone)]
pub enum JoinType {
    /// `INNER JOIN`.
    Inner,
    /// `LEFT [OUTER] JOIN`.
    Left,
}
135
136#[derive(Debug, Clone)]
138pub struct ParsedJoin {
139 pub table: String,
141 pub join_type: JoinType,
143 pub on_condition: Vec<Predicate>,
145}
146
147#[derive(Debug, Clone)]
149pub struct ParsedCte {
150 pub name: String,
152 pub query: ParsedSelect,
154}
155
156#[derive(Debug, Clone)]
160pub struct ComputedColumn {
161 pub alias: ColumnName,
163 pub when_clauses: Vec<CaseWhenArm>,
165 pub else_value: Value,
167}
168
169#[derive(Debug, Clone)]
171pub struct CaseWhenArm {
172 pub condition: Vec<Predicate>,
174 pub result: Value,
176}
177
178#[derive(Debug, Clone)]
180pub struct ParsedSelect {
181 pub table: String,
183 pub joins: Vec<ParsedJoin>,
185 pub columns: Option<Vec<ColumnName>>,
187 pub case_columns: Vec<ComputedColumn>,
189 pub predicates: Vec<Predicate>,
191 pub order_by: Vec<OrderByClause>,
193 pub limit: Option<usize>,
195 pub aggregates: Vec<AggregateFunction>,
197 pub group_by: Vec<ColumnName>,
199 pub distinct: bool,
201 pub having: Vec<HavingCondition>,
203 pub ctes: Vec<ParsedCte>,
205 pub window_fns: Vec<ParsedWindowFn>,
209}
210
211#[derive(Debug, Clone)]
214pub struct ParsedWindowFn {
215 pub function: crate::window::WindowFunction,
217 pub partition_by: Vec<ColumnName>,
219 pub order_by: Vec<OrderByClause>,
221 pub alias: Option<String>,
223}
224
225#[derive(Debug, Clone)]
229pub enum HavingCondition {
230 AggregateComparison {
232 aggregate: AggregateFunction,
234 op: HavingOp,
236 value: Value,
238 },
239}
240
/// Comparison operators accepted inside `HAVING`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HavingOp {
    /// `=`
    Eq,
    /// `<`
    Lt,
    /// `<=`
    Le,
    /// `>`
    Gt,
    /// `>=`
    Ge,
}
250
251#[derive(Debug, Clone)]
259pub struct ParsedCreateTable {
260 pub table_name: String,
261 pub columns: NonEmptyVec<ParsedColumn>,
262 pub primary_key: Vec<String>,
263 pub if_not_exists: bool,
266}
267
/// A column definition taken from `CREATE TABLE` or `ALTER TABLE ... ADD COLUMN`.
#[derive(Debug, Clone)]
pub struct ParsedColumn {
    /// Column name.
    pub name: String,
    /// Data type, stored as text.
    pub data_type: String,
    /// Whether the column accepts `NULL`.
    pub nullable: bool,
}
275
276#[derive(Debug, Clone)]
278pub struct ParsedAlterTable {
279 pub table_name: String,
280 pub operation: AlterTableOperation,
281}
282
283#[derive(Debug, Clone)]
285pub enum AlterTableOperation {
286 AddColumn(ParsedColumn),
288 DropColumn(String),
290}
291
/// Parsed form of a `CREATE INDEX` statement.
#[derive(Debug, Clone)]
pub struct ParsedCreateIndex {
    /// Name of the index.
    pub index_name: String,
    /// Table the index is built on.
    pub table_name: String,
    /// Indexed columns.
    pub columns: Vec<String>,
}
299
300#[derive(Debug, Clone)]
302pub struct ParsedInsert {
303 pub table: String,
304 pub columns: Vec<String>,
305 pub values: Vec<Vec<Value>>, pub returning: Option<Vec<String>>, }
308
309#[derive(Debug, Clone)]
311pub struct ParsedUpdate {
312 pub table: String,
313 pub assignments: Vec<(String, Value)>, pub predicates: Vec<Predicate>,
315 pub returning: Option<Vec<String>>, }
317
318#[derive(Debug, Clone)]
320pub struct ParsedDelete {
321 pub table: String,
322 pub predicates: Vec<Predicate>,
323 pub returning: Option<Vec<String>>, }
325
326#[derive(Debug, Clone, PartialEq, Eq)]
328pub enum AggregateFunction {
329 CountStar,
331 Count(ColumnName),
333 Sum(ColumnName),
335 Avg(ColumnName),
337 Min(ColumnName),
339 Max(ColumnName),
341}
342
343#[derive(Debug, Clone)]
345pub enum Predicate {
346 Eq(ColumnName, PredicateValue),
348 Lt(ColumnName, PredicateValue),
350 Le(ColumnName, PredicateValue),
352 Gt(ColumnName, PredicateValue),
354 Ge(ColumnName, PredicateValue),
356 In(ColumnName, Vec<PredicateValue>),
358 Like(ColumnName, String),
360 IsNull(ColumnName),
362 IsNotNull(ColumnName),
364 Or(Vec<Predicate>, Vec<Predicate>),
366}
367
368impl Predicate {
369 #[allow(dead_code)]
373 pub fn column(&self) -> Option<&ColumnName> {
374 match self {
375 Predicate::Eq(col, _)
376 | Predicate::Lt(col, _)
377 | Predicate::Le(col, _)
378 | Predicate::Gt(col, _)
379 | Predicate::Ge(col, _)
380 | Predicate::In(col, _)
381 | Predicate::Like(col, _)
382 | Predicate::IsNull(col)
383 | Predicate::IsNotNull(col) => Some(col),
384 Predicate::Or(_, _) => None,
385 }
386 }
387}
388
389#[derive(Debug, Clone)]
391pub enum PredicateValue {
392 Int(i64),
394 String(String),
396 Bool(bool),
398 Null,
400 Param(usize),
402 Literal(Value),
404 ColumnRef(String),
407}
408
409#[derive(Debug, Clone)]
411pub struct OrderByClause {
412 pub column: ColumnName,
414 pub ascending: bool,
416}
417
418pub fn parse_statement(sql: &str) -> Result<ParsedStatement> {
424 if let Some(parsed) = try_parse_custom_statement(sql)? {
426 return Ok(parsed);
427 }
428
429 let dialect = GenericDialect {};
430 let statements =
431 Parser::parse_sql(&dialect, sql).map_err(|e| QueryError::ParseError(e.to_string()))?;
432
433 if statements.len() != 1 {
434 return Err(QueryError::ParseError(format!(
435 "expected exactly 1 statement, got {}",
436 statements.len()
437 )));
438 }
439
440 match &statements[0] {
441 Statement::Query(query) => parse_query_to_statement(query),
442 Statement::CreateTable(create_table) => {
443 let parsed = parse_create_table(create_table)?;
444 Ok(ParsedStatement::CreateTable(parsed))
445 }
446 Statement::Drop {
447 object_type,
448 names,
449 if_exists: _,
450 ..
451 } => {
452 if !matches!(object_type, sqlparser::ast::ObjectType::Table) {
453 return Err(QueryError::UnsupportedFeature(
454 "only DROP TABLE is supported".to_string(),
455 ));
456 }
457 if names.len() != 1 {
458 return Err(QueryError::ParseError(
459 "expected exactly 1 table in DROP TABLE".to_string(),
460 ));
461 }
462 let table_name = object_name_to_string(&names[0]);
463 Ok(ParsedStatement::DropTable(table_name))
464 }
465 Statement::CreateIndex(create_index) => {
466 let parsed = parse_create_index(create_index)?;
467 Ok(ParsedStatement::CreateIndex(parsed))
468 }
469 Statement::Insert(insert) => {
470 let parsed = parse_insert(insert)?;
471 Ok(ParsedStatement::Insert(parsed))
472 }
473 Statement::Update {
474 table,
475 assignments,
476 selection,
477 returning,
478 ..
479 } => {
480 let parsed = parse_update(table, assignments, selection.as_ref(), returning.as_ref())?;
481 Ok(ParsedStatement::Update(parsed))
482 }
483 Statement::Delete(delete) => {
484 let parsed = parse_delete_stmt(delete)?;
485 Ok(ParsedStatement::Delete(parsed))
486 }
487 Statement::AlterTable {
488 name, operations, ..
489 } => {
490 let parsed = parse_alter_table(name, operations)?;
491 Ok(ParsedStatement::AlterTable(parsed))
492 }
493 Statement::CreateRole { names, .. } => {
494 if names.len() != 1 {
495 return Err(QueryError::ParseError(
496 "expected exactly 1 role name".to_string(),
497 ));
498 }
499 let role_name = object_name_to_string(&names[0]);
500 Ok(ParsedStatement::CreateRole(role_name))
501 }
502 Statement::Grant {
503 privileges,
504 objects,
505 grantees,
506 ..
507 } => parse_grant(privileges, objects, grantees),
508 other => Err(QueryError::UnsupportedFeature(format!(
509 "statement type not supported: {other:?}"
510 ))),
511 }
512}
513
514pub fn try_parse_custom_statement(sql: &str) -> Result<Option<ParsedStatement>> {
523 let trimmed = sql.trim().trim_end_matches(';').trim();
524 let upper = trimmed.to_ascii_uppercase();
525
526 if upper.starts_with("CREATE MASK") {
528 let tokens: Vec<&str> = trimmed.split_whitespace().collect();
529 if tokens.len() != 7 {
531 return Err(QueryError::ParseError(
532 "expected: CREATE MASK <name> ON <table>.<column> USING <strategy>".to_string(),
533 ));
534 }
535 if !tokens[3].eq_ignore_ascii_case("ON") {
536 return Err(QueryError::ParseError(format!(
537 "expected ON after mask name, got '{}'",
538 tokens[3]
539 )));
540 }
541 if !tokens[5].eq_ignore_ascii_case("USING") {
542 return Err(QueryError::ParseError(format!(
543 "expected USING after column reference, got '{}'",
544 tokens[5]
545 )));
546 }
547
548 let table_col = tokens[4];
550 let dot_pos = table_col.find('.').ok_or_else(|| {
551 QueryError::ParseError(format!(
552 "expected <table>.<column> but got '{table_col}' (missing '.')"
553 ))
554 })?;
555 let table_name = table_col[..dot_pos].to_string();
556 let column_name = table_col[dot_pos + 1..].to_string();
557
558 if table_name.is_empty() || column_name.is_empty() {
559 return Err(QueryError::ParseError(
560 "table name and column name must not be empty".to_string(),
561 ));
562 }
563
564 let strategy = tokens[6].to_ascii_uppercase();
565
566 return Ok(Some(ParsedStatement::CreateMask(ParsedCreateMask {
567 mask_name: tokens[2].to_string(),
568 table_name,
569 column_name,
570 strategy,
571 })));
572 }
573
574 if upper.starts_with("DROP MASK") {
576 let tokens: Vec<&str> = trimmed.split_whitespace().collect();
577 if tokens.len() != 3 {
578 return Err(QueryError::ParseError(
579 "expected: DROP MASK <name>".to_string(),
580 ));
581 }
582 return Ok(Some(ParsedStatement::DropMask(tokens[2].to_string())));
583 }
584
585 if upper.starts_with("ALTER TABLE") && upper.contains("SET CLASSIFICATION") {
587 return parse_set_classification(trimmed);
588 }
589
590 if upper.starts_with("SHOW CLASSIFICATIONS") {
592 let tokens: Vec<&str> = trimmed.split_whitespace().collect();
593 if tokens.len() != 4 {
595 return Err(QueryError::ParseError(
596 "expected: SHOW CLASSIFICATIONS FOR <table>".to_string(),
597 ));
598 }
599 if !tokens[2].eq_ignore_ascii_case("FOR") {
600 return Err(QueryError::ParseError(format!(
601 "expected FOR after CLASSIFICATIONS, got '{}'",
602 tokens[2]
603 )));
604 }
605 return Ok(Some(ParsedStatement::ShowClassifications(
606 tokens[3].to_string(),
607 )));
608 }
609
610 if upper == "SHOW TABLES" {
612 return Ok(Some(ParsedStatement::ShowTables));
613 }
614
615 if upper.starts_with("SHOW COLUMNS") {
617 let tokens: Vec<&str> = trimmed.split_whitespace().collect();
618 if tokens.len() != 4 {
620 return Err(QueryError::ParseError(
621 "expected: SHOW COLUMNS FROM <table>".to_string(),
622 ));
623 }
624 if !tokens[2].eq_ignore_ascii_case("FROM") {
625 return Err(QueryError::ParseError(format!(
626 "expected FROM after COLUMNS, got '{}'",
627 tokens[2]
628 )));
629 }
630 return Ok(Some(ParsedStatement::ShowColumns(
631 tokens[3].to_string(),
632 )));
633 }
634
635 if upper.starts_with("CREATE USER") {
637 let tokens: Vec<&str> = trimmed.split_whitespace().collect();
638 if tokens.len() != 6 {
640 return Err(QueryError::ParseError(
641 "expected: CREATE USER <name> WITH ROLE <role>".to_string(),
642 ));
643 }
644 if !tokens[3].eq_ignore_ascii_case("WITH") {
645 return Err(QueryError::ParseError(format!(
646 "expected WITH after username, got '{}'",
647 tokens[3]
648 )));
649 }
650 if !tokens[4].eq_ignore_ascii_case("ROLE") {
651 return Err(QueryError::ParseError(format!(
652 "expected ROLE after WITH, got '{}'",
653 tokens[4]
654 )));
655 }
656 return Ok(Some(ParsedStatement::CreateUser(ParsedCreateUser {
657 username: tokens[2].to_string(),
658 role: tokens[5].to_string(),
659 })));
660 }
661
662 Ok(None)
663}
664
/// Point in history that a time-travel clause refers to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TimeTravel {
    /// Offset taken from a trailing `AT OFFSET <n>` clause.
    Offset(u64),
    /// Timestamp taken from an `AS OF '<rfc3339>'` clause, expressed in
    /// nanoseconds as returned by chrono's `timestamp_nanos_opt`.
    TimestampNs(i64),
}
689
/// Splits a trailing `AT OFFSET <n>` time-travel clause off a SQL string.
///
/// The clause must be the last thing in `sql`, optionally followed by a
/// single `;` and whitespace. Keywords are matched ASCII-case-insensitively
/// and may be separated by any whitespace (spaces, tabs, newlines). `AT` must
/// start the string or be preceded by whitespace, so identifiers that merely
/// end in `at` (e.g. `cat OFFSET 5`) are not mistaken for the clause.
///
/// # Returns
///
/// * `(sql_without_clause, Some(n))` when the clause is present; trailing
///   whitespace before the clause is removed.
/// * `(sql.to_string(), None)` when there is no such clause, or when `<n>`
///   does not fit in a `u64`.
pub fn extract_at_offset(sql: &str) -> (String, Option<u64>) {
    match split_trailing_at_offset(sql) {
        Some((before, offset)) => (before.to_string(), Some(offset)),
        None => (sql.to_string(), None),
    }
}

/// Returns the text preceding a trailing `AT OFFSET <n>` clause together with
/// `n`, or `None` when `sql` does not end with such a clause.
fn split_trailing_at_offset(sql: &str) -> Option<(&str, u64)> {
    // Ignore trailing whitespace and at most one statement terminator.
    let tail = sql.trim_end();
    let tail = tail.strip_suffix(';').map_or(tail, str::trim_end);

    // The clause ends with a run of ASCII digits. Splitting right before an
    // ASCII byte is always a valid char boundary.
    let digit_count = tail.bytes().rev().take_while(u8::is_ascii_digit).count();
    if digit_count == 0 {
        return None;
    }
    let (head, digits) = tail.split_at(tail.len() - digit_count);
    let offset = digits.parse::<u64>().ok()?;

    let head = strip_keyword_suffix(head.trim_end(), "OFFSET")?;

    // `AT` and `OFFSET` must be separated by at least one whitespace char.
    let before_gap = head.trim_end();
    if before_gap.len() == head.len() {
        return None;
    }
    let head = strip_keyword_suffix(before_gap, "AT")?;

    // Word boundary: `AT` must not be the tail of a longer identifier.
    if !head.is_empty() && !head.ends_with(char::is_whitespace) {
        return None;
    }

    Some((head.trim_end(), offset))
}

/// Strips `keyword` from the end of `text`, ignoring ASCII case.
fn strip_keyword_suffix<'a>(text: &'a str, keyword: &str) -> Option<&'a str> {
    let split = text.len().checked_sub(keyword.len())?;
    // `get` yields `None` when `split` is not on a char boundary.
    let candidate = text.get(split..)?;
    candidate
        .eq_ignore_ascii_case(keyword)
        .then(|| &text[..split])
}
762
763pub fn extract_time_travel(sql: &str) -> (String, Option<TimeTravel>) {
791 let (after_offset_sql, offset) = extract_at_offset(sql);
793 if let Some(o) = offset {
794 return (after_offset_sql, Some(TimeTravel::Offset(o)));
795 }
796
797 let upper = sql.to_ascii_uppercase();
803
804 let (keyword_pos, keyword_len) = if let Some(p) = upper.rfind("FOR SYSTEM_TIME AS OF") {
810 (p, "FOR SYSTEM_TIME AS OF".len())
811 } else if let Some(p) = upper.rfind("AS OF") {
812 let after = sql[p + "AS OF".len()..].trim_start();
815 if !after.starts_with('\'') {
816 return (sql.to_string(), None);
817 }
818 (p, "AS OF".len())
819 } else {
820 return (sql.to_string(), None);
821 };
822
823 if keyword_pos > 0 {
825 let prev = sql.as_bytes()[keyword_pos - 1];
826 if !matches!(prev, b' ' | b'\t' | b'\n' | b'\r') {
827 return (sql.to_string(), None);
828 }
829 }
830
831 let after_keyword = sql[keyword_pos + keyword_len..].trim_start();
832 if !after_keyword.starts_with('\'') {
833 return (sql.to_string(), None);
834 }
835
836 let ts_start = 1; let ts_end = match after_keyword[1..].find('\'') {
841 Some(i) => i + 1,
842 None => return (sql.to_string(), None),
843 };
844 let ts_str = &after_keyword[ts_start..ts_end];
845
846 let ts_ns = match chrono::DateTime::parse_from_rfc3339(ts_str) {
848 Ok(dt) => dt.timestamp_nanos_opt(),
849 Err(_) => return (sql.to_string(), None),
850 };
851 let ts_ns = match ts_ns {
852 Some(n) => n,
853 None => return (sql.to_string(), None),
854 };
855
856 let remainder = after_keyword[ts_end + 1..].trim();
858 if !remainder.is_empty() && remainder != ";" {
859 return (sql.to_string(), None);
860 }
861
862 let before = sql[..keyword_pos].trim_end();
863 (before.to_string(), Some(TimeTravel::TimestampNs(ts_ns)))
864}
865
866fn parse_set_classification(sql: &str) -> Result<Option<ParsedStatement>> {
871 let tokens: Vec<&str> = sql.split_whitespace().collect();
872 if tokens.len() != 9 {
875 return Err(QueryError::ParseError(
876 "expected: ALTER TABLE <table> MODIFY COLUMN <column> SET CLASSIFICATION '<class>'"
877 .to_string(),
878 ));
879 }
880
881 if !tokens[3].eq_ignore_ascii_case("MODIFY") {
882 return Err(QueryError::ParseError(format!(
883 "expected MODIFY, got '{}'",
884 tokens[3]
885 )));
886 }
887 if !tokens[4].eq_ignore_ascii_case("COLUMN") {
888 return Err(QueryError::ParseError(format!(
889 "expected COLUMN after MODIFY, got '{}'",
890 tokens[4]
891 )));
892 }
893 if !tokens[6].eq_ignore_ascii_case("SET") {
894 return Err(QueryError::ParseError(format!(
895 "expected SET, got '{}'",
896 tokens[6]
897 )));
898 }
899 if !tokens[7].eq_ignore_ascii_case("CLASSIFICATION") {
900 return Err(QueryError::ParseError(format!(
901 "expected CLASSIFICATION, got '{}'",
902 tokens[7]
903 )));
904 }
905
906 let table_name = tokens[2].to_string();
907 let column_name = tokens[5].to_string();
908
909 let raw_class = tokens[8];
911 let classification = raw_class
912 .strip_prefix('\'')
913 .and_then(|s| s.strip_suffix('\''))
914 .ok_or_else(|| {
915 QueryError::ParseError(format!(
916 "classification must be quoted with single quotes, got '{raw_class}'"
917 ))
918 })?
919 .to_string();
920
921 assert!(!table_name.is_empty(), "table name must not be empty");
922 assert!(!column_name.is_empty(), "column name must not be empty");
923 assert!(
924 !classification.is_empty(),
925 "classification must not be empty"
926 );
927
928 Ok(Some(ParsedStatement::SetClassification(
929 ParsedSetClassification {
930 table_name,
931 column_name,
932 classification,
933 },
934 )))
935}
936
937fn parse_grant(
939 privileges: &sqlparser::ast::Privileges,
940 objects: &sqlparser::ast::GrantObjects,
941 grantees: &[sqlparser::ast::Grantee],
942) -> Result<ParsedStatement> {
943 use sqlparser::ast::{Action, GrantObjects, GranteeName, Privileges};
944
945 let columns = match privileges {
947 Privileges::Actions(actions) => {
948 let mut cols = None;
949 for action in actions {
950 if let Action::Select { columns: Some(c) } = action {
951 cols = Some(c.iter().map(|i| i.value.clone()).collect());
952 }
953 }
954 cols
955 }
956 Privileges::All { .. } => None,
957 };
958
959 let table_name = match objects {
961 GrantObjects::Tables(tables) => {
962 if tables.len() != 1 {
963 return Err(QueryError::ParseError(
964 "expected exactly 1 table in GRANT".to_string(),
965 ));
966 }
967 object_name_to_string(&tables[0])
968 }
969 _ => {
970 return Err(QueryError::UnsupportedFeature(
971 "GRANT only supports table-level privileges".to_string(),
972 ));
973 }
974 };
975
976 if grantees.len() != 1 {
978 return Err(QueryError::ParseError(
979 "expected exactly 1 grantee in GRANT".to_string(),
980 ));
981 }
982 let role_name = match &grantees[0].name {
983 Some(GranteeName::ObjectName(name)) => object_name_to_string(name),
984 _ => {
985 return Err(QueryError::ParseError(
986 "expected a role name in GRANT".to_string(),
987 ));
988 }
989 };
990
991 Ok(ParsedStatement::Grant(ParsedGrant {
992 columns,
993 table_name,
994 role_name,
995 }))
996}
997
998fn parse_query_to_statement(query: &Query) -> Result<ParsedStatement> {
1000 let ctes = match &query.with {
1002 Some(with) => {
1003 if with.recursive {
1004 return Err(QueryError::UnsupportedFeature(
1005 "WITH RECURSIVE is not supported".to_string(),
1006 ));
1007 }
1008 parse_ctes(with)?
1009 }
1010 None => vec![],
1011 };
1012
1013 match query.body.as_ref() {
1014 SetExpr::Select(select) => {
1015 let parsed_select = parse_select(select)?;
1016
1017 let order_by = match &query.order_by {
1019 Some(ob) => parse_order_by(ob)?,
1020 None => vec![],
1021 };
1022
1023 let limit = parse_limit(query.limit.as_ref())?;
1025
1026 let mut all_ctes = ctes;
1028 all_ctes.extend(parsed_select.ctes);
1029
1030 Ok(ParsedStatement::Select(ParsedSelect {
1031 table: parsed_select.table,
1032 joins: parsed_select.joins,
1033 columns: parsed_select.columns,
1034 case_columns: parsed_select.case_columns,
1035 predicates: parsed_select.predicates,
1036 order_by,
1037 limit,
1038 aggregates: parsed_select.aggregates,
1039 group_by: parsed_select.group_by,
1040 distinct: parsed_select.distinct,
1041 having: parsed_select.having,
1042 ctes: all_ctes,
1043 window_fns: parsed_select.window_fns,
1044 }))
1045 }
1046 SetExpr::SetOperation {
1047 op,
1048 set_quantifier,
1049 left,
1050 right,
1051 } => {
1052 use sqlparser::ast::SetOperator;
1053 use sqlparser::ast::SetQuantifier;
1054
1055 if !matches!(op, SetOperator::Union) {
1056 return Err(QueryError::UnsupportedFeature(format!(
1057 "set operation not supported: {op:?} (only UNION is supported)"
1058 )));
1059 }
1060
1061 let all = matches!(set_quantifier, SetQuantifier::All);
1062
1063 let left_select = match left.as_ref() {
1065 SetExpr::Select(s) => parse_select(s)?,
1066 _ => {
1067 return Err(QueryError::UnsupportedFeature(
1068 "nested set operations not supported".to_string(),
1069 ));
1070 }
1071 };
1072 let right_select = match right.as_ref() {
1073 SetExpr::Select(s) => parse_select(s)?,
1074 _ => {
1075 return Err(QueryError::UnsupportedFeature(
1076 "nested set operations not supported".to_string(),
1077 ));
1078 }
1079 };
1080
1081 Ok(ParsedStatement::Union(ParsedUnion {
1082 left: left_select,
1083 right: right_select,
1084 all,
1085 }))
1086 }
1087 other => Err(QueryError::UnsupportedFeature(format!(
1088 "unsupported query type: {other:?}"
1089 ))),
1090 }
1091}
1092
1093fn parse_join_with_subqueries(join: &sqlparser::ast::Join) -> Result<(ParsedJoin, Vec<ParsedCte>)> {
1095 use sqlparser::ast::{JoinConstraint, JoinOperator};
1096
1097 let join_type = match &join.join_operator {
1099 JoinOperator::Inner(_) => JoinType::Inner,
1100 JoinOperator::LeftOuter(_) => JoinType::Left,
1101 other => {
1102 return Err(QueryError::UnsupportedFeature(format!(
1103 "join type not supported: {other:?}"
1104 )));
1105 }
1106 };
1107
1108 let mut inline_ctes = Vec::new();
1110 let table = match &join.relation {
1111 sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
1112 sqlparser::ast::TableFactor::Derived {
1113 subquery, alias, ..
1114 } => {
1115 let alias_name = alias
1116 .as_ref()
1117 .map(|a| a.name.value.clone())
1118 .ok_or_else(|| {
1119 QueryError::ParseError("subquery in JOIN requires an alias".to_string())
1120 })?;
1121
1122 let inner = match subquery.body.as_ref() {
1124 SetExpr::Select(s) => parse_select(s)?,
1125 _ => {
1126 return Err(QueryError::UnsupportedFeature(
1127 "subquery body must be a simple SELECT".to_string(),
1128 ));
1129 }
1130 };
1131
1132 let order_by = match &subquery.order_by {
1133 Some(ob) => parse_order_by(ob)?,
1134 None => vec![],
1135 };
1136 let limit = parse_limit(subquery.limit.as_ref())?;
1137
1138 inline_ctes.push(ParsedCte {
1139 name: alias_name.clone(),
1140 query: ParsedSelect {
1141 order_by,
1142 limit,
1143 ..inner
1144 },
1145 });
1146
1147 alias_name
1148 }
1149 _ => {
1150 return Err(QueryError::UnsupportedFeature(
1151 "unsupported JOIN relation type".to_string(),
1152 ));
1153 }
1154 };
1155
1156 let on_condition = match &join.join_operator {
1158 JoinOperator::Inner(JoinConstraint::On(expr))
1159 | JoinOperator::LeftOuter(JoinConstraint::On(expr)) => parse_join_condition(expr)?,
1160 JoinOperator::Inner(JoinConstraint::Using(_))
1161 | JoinOperator::LeftOuter(JoinConstraint::Using(_)) => {
1162 return Err(QueryError::UnsupportedFeature(
1163 "USING clause not supported".to_string(),
1164 ));
1165 }
1166 _ => {
1167 return Err(QueryError::UnsupportedFeature(
1168 "join without ON clause not supported".to_string(),
1169 ));
1170 }
1171 };
1172
1173 Ok((
1174 ParsedJoin {
1175 table,
1176 join_type,
1177 on_condition,
1178 },
1179 inline_ctes,
1180 ))
1181}
1182
1183fn parse_join_condition(expr: &Expr) -> Result<Vec<Predicate>> {
1186 match expr {
1187 Expr::BinaryOp {
1188 left,
1189 op: BinaryOperator::And,
1190 right,
1191 } => {
1192 let mut predicates = parse_join_condition(left)?;
1193 predicates.extend(parse_join_condition(right)?);
1194 Ok(predicates)
1195 }
1196 _ => {
1197 parse_where_expr(expr)
1199 }
1200 }
1201}
1202
1203fn parse_select(select: &Select) -> Result<ParsedSelect> {
1204 let distinct = select.distinct.is_some();
1206
1207 if select.from.len() != 1 {
1209 return Err(QueryError::ParseError(format!(
1210 "expected exactly 1 table in FROM clause, got {}",
1211 select.from.len()
1212 )));
1213 }
1214
1215 let from = &select.from[0];
1216
1217 let mut inline_ctes = Vec::new();
1219
1220 let mut joins = Vec::new();
1222 for join in &from.joins {
1223 let (parsed_join, join_ctes) = parse_join_with_subqueries(join)?;
1224 joins.push(parsed_join);
1225 inline_ctes.extend(join_ctes);
1226 }
1227
1228 let table = match &from.relation {
1229 sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
1230 sqlparser::ast::TableFactor::Derived {
1231 subquery, alias, ..
1232 } => {
1233 let alias_name = alias
1234 .as_ref()
1235 .map(|a| a.name.value.clone())
1236 .ok_or_else(|| {
1237 QueryError::ParseError("subquery in FROM requires an alias".to_string())
1238 })?;
1239
1240 let inner = match subquery.body.as_ref() {
1242 SetExpr::Select(s) => parse_select(s)?,
1243 _ => {
1244 return Err(QueryError::UnsupportedFeature(
1245 "subquery body must be a simple SELECT".to_string(),
1246 ));
1247 }
1248 };
1249
1250 let order_by = match &subquery.order_by {
1251 Some(ob) => parse_order_by(ob)?,
1252 None => vec![],
1253 };
1254 let limit = parse_limit(subquery.limit.as_ref())?;
1255
1256 inline_ctes.push(ParsedCte {
1257 name: alias_name.clone(),
1258 query: ParsedSelect {
1259 order_by,
1260 limit,
1261 ..inner
1262 },
1263 });
1264
1265 alias_name
1266 }
1267 other => {
1268 return Err(QueryError::UnsupportedFeature(format!(
1269 "unsupported FROM clause: {other:?}"
1270 )));
1271 }
1272 };
1273
1274 let columns = parse_select_items(&select.projection)?;
1276
1277 let case_columns = parse_case_columns_from_select_items(&select.projection)?;
1279
1280 let predicates = match &select.selection {
1282 Some(expr) => parse_where_expr(expr)?,
1283 None => vec![],
1284 };
1285
1286 let group_by = match &select.group_by {
1288 sqlparser::ast::GroupByExpr::Expressions(exprs, _) if !exprs.is_empty() => {
1289 parse_group_by_expr(exprs)?
1290 }
1291 sqlparser::ast::GroupByExpr::All(_) => {
1292 return Err(QueryError::UnsupportedFeature(
1293 "GROUP BY ALL is not supported".to_string(),
1294 ));
1295 }
1296 sqlparser::ast::GroupByExpr::Expressions(_, _) => vec![],
1297 };
1298
1299 let aggregates = parse_aggregates_from_select_items(&select.projection)?;
1301
1302 let having = match &select.having {
1304 Some(expr) => parse_having_expr(expr)?,
1305 None => vec![],
1306 };
1307
1308 let window_fns = parse_window_fns_from_select_items(&select.projection)?;
1311
1312 Ok(ParsedSelect {
1313 table,
1314 joins,
1315 columns,
1316 case_columns,
1317 predicates,
1318 order_by: vec![],
1319 limit: None,
1320 aggregates,
1321 group_by,
1322 distinct,
1323 having,
1324 ctes: inline_ctes,
1325 window_fns,
1326 })
1327}
1328
1329fn parse_ctes(with: &sqlparser::ast::With) -> Result<Vec<ParsedCte>> {
1331 let max_ctes = 16;
1332 let mut ctes = Vec::new();
1333
1334 for (i, cte) in with.cte_tables.iter().enumerate() {
1335 if i >= max_ctes {
1336 return Err(QueryError::UnsupportedFeature(format!(
1337 "too many CTEs (max {max_ctes})"
1338 )));
1339 }
1340
1341 let name = cte.alias.name.value.clone();
1342
1343 let inner_select = match cte.query.body.as_ref() {
1345 SetExpr::Select(s) => parse_select(s)?,
1346 _ => {
1347 return Err(QueryError::UnsupportedFeature(
1348 "CTE body must be a simple SELECT".to_string(),
1349 ));
1350 }
1351 };
1352
1353 let order_by = match &cte.query.order_by {
1355 Some(ob) => parse_order_by(ob)?,
1356 None => vec![],
1357 };
1358 let limit = parse_limit(cte.query.limit.as_ref())?;
1359
1360 ctes.push(ParsedCte {
1361 name,
1362 query: ParsedSelect {
1363 order_by,
1364 limit,
1365 ..inner_select
1366 },
1367 });
1368 }
1369
1370 Ok(ctes)
1371}
1372
1373fn parse_having_expr(expr: &Expr) -> Result<Vec<HavingCondition>> {
1378 match expr {
1379 Expr::BinaryOp {
1380 left,
1381 op: BinaryOperator::And,
1382 right,
1383 } => {
1384 let mut conditions = parse_having_expr(left)?;
1385 conditions.extend(parse_having_expr(right)?);
1386 Ok(conditions)
1387 }
1388 Expr::BinaryOp { left, op, right } => {
1389 let aggregate = match left.as_ref() {
1391 Expr::Function(_) => try_parse_aggregate(left)?.ok_or_else(|| {
1392 QueryError::UnsupportedFeature(
1393 "HAVING requires aggregate functions (COUNT, SUM, AVG, MIN, MAX)"
1394 .to_string(),
1395 )
1396 })?,
1397 _ => {
1398 return Err(QueryError::UnsupportedFeature(
1399 "HAVING clause must reference aggregate functions".to_string(),
1400 ));
1401 }
1402 };
1403
1404 let value = expr_to_value(right)?;
1406
1407 let having_op = match op {
1409 BinaryOperator::Eq => HavingOp::Eq,
1410 BinaryOperator::Lt => HavingOp::Lt,
1411 BinaryOperator::LtEq => HavingOp::Le,
1412 BinaryOperator::Gt => HavingOp::Gt,
1413 BinaryOperator::GtEq => HavingOp::Ge,
1414 other => {
1415 return Err(QueryError::UnsupportedFeature(format!(
1416 "unsupported HAVING operator: {other:?}"
1417 )));
1418 }
1419 };
1420
1421 Ok(vec![HavingCondition::AggregateComparison {
1422 aggregate,
1423 op: having_op,
1424 value,
1425 }])
1426 }
1427 Expr::Nested(inner) => parse_having_expr(inner),
1428 other => Err(QueryError::UnsupportedFeature(format!(
1429 "unsupported HAVING expression: {other:?}"
1430 ))),
1431 }
1432}
1433
1434fn parse_select_items(items: &[SelectItem]) -> Result<Option<Vec<ColumnName>>> {
1435 let mut columns = Vec::new();
1436
1437 for item in items {
1438 match item {
1439 SelectItem::Wildcard(_) => {
1440 return Ok(None);
1442 }
1443 SelectItem::UnnamedExpr(Expr::Identifier(ident)) => {
1444 columns.push(ColumnName::new(ident.value.clone()));
1445 }
1446 SelectItem::UnnamedExpr(Expr::CompoundIdentifier(idents)) if idents.len() == 2 => {
1447 columns.push(ColumnName::new(idents[1].value.clone()));
1449 }
1450 SelectItem::ExprWithAlias {
1451 expr: Expr::Identifier(ident),
1452 alias,
1453 } => {
1454 let _ = alias;
1456 columns.push(ColumnName::new(ident.value.clone()));
1457 }
1458 SelectItem::ExprWithAlias {
1459 expr: Expr::CompoundIdentifier(idents),
1460 alias,
1461 } if idents.len() == 2 => {
1462 let _ = alias;
1464 columns.push(ColumnName::new(idents[1].value.clone()));
1465 }
1466 SelectItem::UnnamedExpr(Expr::Function(_))
1467 | SelectItem::ExprWithAlias {
1468 expr: Expr::Function(_) | Expr::Case { .. },
1469 ..
1470 } => {
1471 }
1474 other => {
1475 return Err(QueryError::UnsupportedFeature(format!(
1476 "unsupported SELECT item: {other:?}"
1477 )));
1478 }
1479 }
1480 }
1481
1482 Ok(Some(columns))
1483}
1484
1485fn parse_aggregates_from_select_items(items: &[SelectItem]) -> Result<Vec<AggregateFunction>> {
1487 let mut aggregates = Vec::new();
1488
1489 for item in items {
1490 match item {
1491 SelectItem::UnnamedExpr(expr) | SelectItem::ExprWithAlias { expr, .. } => {
1492 if let Some(agg) = try_parse_aggregate(expr)? {
1493 aggregates.push(agg);
1494 }
1495 }
1496 _ => {
1497 }
1499 }
1500 }
1501
1502 Ok(aggregates)
1503}
1504
1505fn parse_case_columns_from_select_items(items: &[SelectItem]) -> Result<Vec<ComputedColumn>> {
1510 let mut case_cols = Vec::new();
1511
1512 for item in items {
1513 if let SelectItem::ExprWithAlias {
1514 expr:
1515 Expr::Case {
1516 operand,
1517 conditions,
1518 results,
1519 else_result,
1520 },
1521 alias,
1522 } = item
1523 {
1524 if operand.is_some() {
1526 return Err(QueryError::UnsupportedFeature(
1527 "simple CASE (CASE expr WHEN val THEN ...) is not supported; use searched CASE (CASE WHEN cond THEN ...)".to_string(),
1528 ));
1529 }
1530
1531 if conditions.len() != results.len() {
1532 return Err(QueryError::ParseError(
1533 "CASE expression has mismatched WHEN/THEN count".to_string(),
1534 ));
1535 }
1536
1537 let mut when_clauses = Vec::new();
1538 for (cond_expr, result_expr) in conditions.iter().zip(results.iter()) {
1539 let condition = parse_where_expr(cond_expr)?;
1540 let result = expr_to_value(result_expr)?;
1541 when_clauses.push(CaseWhenArm { condition, result });
1542 }
1543
1544 let else_value = match else_result {
1545 Some(expr) => expr_to_value(expr)?,
1546 None => Value::Null,
1547 };
1548
1549 case_cols.push(ComputedColumn {
1550 alias: ColumnName::new(alias.value.clone()),
1551 when_clauses,
1552 else_value,
1553 });
1554 }
1555 }
1556
1557 Ok(case_cols)
1558}
1559
1560fn parse_window_fns_from_select_items(items: &[SelectItem]) -> Result<Vec<ParsedWindowFn>> {
1564 let mut out = Vec::new();
1565 for item in items {
1566 let (expr, alias) = match item {
1567 SelectItem::UnnamedExpr(e) => (e, None),
1568 SelectItem::ExprWithAlias { expr, alias } => (expr, Some(alias.value.clone())),
1569 _ => continue,
1570 };
1571 if let Some(parsed) = try_parse_window_fn(expr, alias)? {
1572 out.push(parsed);
1573 }
1574 }
1575 Ok(out)
1576}
1577
1578fn try_parse_window_fn(expr: &Expr, alias: Option<String>) -> Result<Option<ParsedWindowFn>> {
1579 let Expr::Function(func) = expr else {
1580 return Ok(None);
1581 };
1582 let Some(over) = &func.over else {
1583 return Ok(None);
1584 };
1585 let spec = match over {
1586 sqlparser::ast::WindowType::WindowSpec(s) => s,
1587 sqlparser::ast::WindowType::NamedWindow(_) => {
1588 return Err(QueryError::UnsupportedFeature(
1589 "named windows (OVER w) are not supported".into(),
1590 ));
1591 }
1592 };
1593 if spec.window_frame.is_some() {
1594 return Err(QueryError::UnsupportedFeature(
1595 "explicit window frames (ROWS/RANGE BETWEEN ...) are not supported; \
1596 omit the frame clause for default behaviour"
1597 .into(),
1598 ));
1599 }
1600
1601 let func_name = func.name.to_string().to_uppercase();
1602 let args = match &func.args {
1603 sqlparser::ast::FunctionArguments::List(list) => list.args.clone(),
1604 _ => Vec::new(),
1605 };
1606 let function = parse_window_function_name(&func_name, &args)?;
1607
1608 let partition_by: Vec<ColumnName> = spec
1609 .partition_by
1610 .iter()
1611 .map(parse_column_expr)
1612 .collect::<Result<_>>()?;
1613 let order_by: Vec<OrderByClause> = spec
1614 .order_by
1615 .iter()
1616 .map(parse_order_by_expr)
1617 .collect::<Result<_>>()?;
1618
1619 Ok(Some(ParsedWindowFn {
1620 function,
1621 partition_by,
1622 order_by,
1623 alias,
1624 }))
1625}
1626
1627fn parse_column_expr(expr: &Expr) -> Result<ColumnName> {
1628 match expr {
1629 Expr::Identifier(ident) => Ok(ColumnName::new(ident.value.clone())),
1630 Expr::CompoundIdentifier(idents) if idents.len() == 2 => {
1631 Ok(ColumnName::new(idents[1].value.clone()))
1632 }
1633 other => Err(QueryError::UnsupportedFeature(format!(
1634 "window PARTITION BY / argument must be a column reference, got: {other:?}"
1635 ))),
1636 }
1637}
1638
1639fn parse_window_function_name(
1640 name: &str,
1641 args: &[sqlparser::ast::FunctionArg],
1642) -> Result<crate::window::WindowFunction> {
1643 use crate::window::WindowFunction;
1644
1645 let arg_exprs: Vec<&Expr> = args
1646 .iter()
1647 .filter_map(|a| match a {
1648 sqlparser::ast::FunctionArg::Unnamed(
1649 sqlparser::ast::FunctionArgExpr::Expr(e),
1650 ) => Some(e),
1651 _ => None,
1652 })
1653 .collect();
1654
1655 let single_col = || -> Result<ColumnName> {
1656 if arg_exprs.is_empty() {
1657 return Err(QueryError::ParseError(format!(
1658 "{name} requires a column argument"
1659 )));
1660 }
1661 parse_column_expr(arg_exprs[0])
1662 };
1663
1664 let parse_offset = || -> Result<usize> {
1665 if arg_exprs.len() < 2 {
1666 return Ok(1);
1667 }
1668 match arg_exprs[1] {
1669 Expr::Value(SqlValue::Number(n, _)) => n.parse::<usize>().map_err(|_| {
1670 QueryError::ParseError(format!("invalid {name} offset: {n}"))
1671 }),
1672 other => Err(QueryError::UnsupportedFeature(format!(
1673 "{name} offset must be a literal integer; got {other:?}"
1674 ))),
1675 }
1676 };
1677
1678 match name {
1679 "ROW_NUMBER" => Ok(WindowFunction::RowNumber),
1680 "RANK" => Ok(WindowFunction::Rank),
1681 "DENSE_RANK" => Ok(WindowFunction::DenseRank),
1682 "LAG" => Ok(WindowFunction::Lag {
1683 column: single_col()?,
1684 offset: parse_offset()?,
1685 }),
1686 "LEAD" => Ok(WindowFunction::Lead {
1687 column: single_col()?,
1688 offset: parse_offset()?,
1689 }),
1690 "FIRST_VALUE" => Ok(WindowFunction::FirstValue {
1691 column: single_col()?,
1692 }),
1693 "LAST_VALUE" => Ok(WindowFunction::LastValue {
1694 column: single_col()?,
1695 }),
1696 other => Err(QueryError::UnsupportedFeature(format!(
1697 "unknown window function: {other}"
1698 ))),
1699 }
1700}
1701
1702fn try_parse_aggregate(expr: &Expr) -> Result<Option<AggregateFunction>> {
1705 match expr {
1706 Expr::Function(func) => {
1707 if func.over.is_some() {
1712 return Ok(None);
1713 }
1714 let func_name = func.name.to_string().to_uppercase();
1715
1716 let args = match &func.args {
1718 sqlparser::ast::FunctionArguments::List(list) => &list.args,
1719 _ => {
1720 return Err(QueryError::UnsupportedFeature(
1721 "non-list function arguments not supported".to_string(),
1722 ));
1723 }
1724 };
1725
1726 match func_name.as_str() {
1727 "COUNT" => {
1728 if args.len() == 1 {
1730 match &args[0] {
1731 sqlparser::ast::FunctionArg::Unnamed(arg_expr) => match arg_expr {
1732 sqlparser::ast::FunctionArgExpr::Wildcard => {
1733 Ok(Some(AggregateFunction::CountStar))
1734 }
1735 sqlparser::ast::FunctionArgExpr::Expr(Expr::Identifier(ident)) => {
1736 Ok(Some(AggregateFunction::Count(ColumnName::new(
1737 ident.value.clone(),
1738 ))))
1739 }
1740 _ => Err(QueryError::UnsupportedFeature(
1741 "COUNT with complex expression not supported".to_string(),
1742 )),
1743 },
1744 _ => Err(QueryError::UnsupportedFeature(
1745 "named function arguments not supported".to_string(),
1746 )),
1747 }
1748 } else {
1749 Err(QueryError::ParseError(format!(
1750 "COUNT expects 1 argument, got {}",
1751 args.len()
1752 )))
1753 }
1754 }
1755 "SUM" | "AVG" | "MIN" | "MAX" => {
1756 if args.len() != 1 {
1758 return Err(QueryError::ParseError(format!(
1759 "{} expects 1 argument, got {}",
1760 func_name,
1761 args.len()
1762 )));
1763 }
1764
1765 match &args[0] {
1766 sqlparser::ast::FunctionArg::Unnamed(arg_expr) => match arg_expr {
1767 sqlparser::ast::FunctionArgExpr::Expr(Expr::Identifier(ident)) => {
1768 let column = ColumnName::new(ident.value.clone());
1769 match func_name.as_str() {
1770 "SUM" => Ok(Some(AggregateFunction::Sum(column))),
1771 "AVG" => Ok(Some(AggregateFunction::Avg(column))),
1772 "MIN" => Ok(Some(AggregateFunction::Min(column))),
1773 "MAX" => Ok(Some(AggregateFunction::Max(column))),
1774 _ => unreachable!(),
1775 }
1776 }
1777 _ => Err(QueryError::UnsupportedFeature(format!(
1778 "{func_name} with complex expression not supported"
1779 ))),
1780 },
1781 _ => Err(QueryError::UnsupportedFeature(
1782 "named function arguments not supported".to_string(),
1783 )),
1784 }
1785 }
1786 _ => {
1787 Ok(None)
1789 }
1790 }
1791 }
1792 _ => {
1793 Ok(None)
1795 }
1796 }
1797}
1798
1799fn parse_group_by_expr(exprs: &[Expr]) -> Result<Vec<ColumnName>> {
1801 let mut columns = Vec::new();
1802
1803 for expr in exprs {
1804 match expr {
1805 Expr::Identifier(ident) => {
1806 columns.push(ColumnName::new(ident.value.clone()));
1807 }
1808 _ => {
1809 return Err(QueryError::UnsupportedFeature(
1810 "complex GROUP BY expressions not supported".to_string(),
1811 ));
1812 }
1813 }
1814 }
1815
1816 Ok(columns)
1817}
1818
1819const MAX_WHERE_DEPTH: usize = 100;
1827
1828fn parse_where_expr(expr: &Expr) -> Result<Vec<Predicate>> {
1829 parse_where_expr_inner(expr, 0)
1830}
1831
1832fn parse_where_expr_inner(expr: &Expr, depth: usize) -> Result<Vec<Predicate>> {
1833 if depth >= MAX_WHERE_DEPTH {
1834 return Err(QueryError::ParseError(format!(
1835 "WHERE clause nesting exceeds maximum depth of {MAX_WHERE_DEPTH}"
1836 )));
1837 }
1838
1839 match expr {
1840 Expr::BinaryOp {
1842 left,
1843 op: BinaryOperator::And,
1844 right,
1845 } => {
1846 let mut predicates = parse_where_expr_inner(left, depth + 1)?;
1847 predicates.extend(parse_where_expr_inner(right, depth + 1)?);
1848 Ok(predicates)
1849 }
1850
1851 Expr::BinaryOp {
1853 left,
1854 op: BinaryOperator::Or,
1855 right,
1856 } => {
1857 let left_preds = parse_where_expr_inner(left, depth + 1)?;
1858 let right_preds = parse_where_expr_inner(right, depth + 1)?;
1859 Ok(vec![Predicate::Or(left_preds, right_preds)])
1860 }
1861
1862 Expr::Like {
1864 expr,
1865 pattern,
1866 negated,
1867 ..
1868 } => {
1869 if *negated {
1870 return Err(QueryError::UnsupportedFeature(
1871 "NOT LIKE is not supported".to_string(),
1872 ));
1873 }
1874
1875 let column = expr_to_column(expr)?;
1876 let pattern_value = expr_to_predicate_value(pattern)?;
1877
1878 match pattern_value {
1879 PredicateValue::String(pattern_str)
1880 | PredicateValue::Literal(Value::Text(pattern_str)) => {
1881 Ok(vec![Predicate::Like(column, pattern_str)])
1882 }
1883 _ => Err(QueryError::UnsupportedFeature(
1884 "LIKE pattern must be a string literal".to_string(),
1885 )),
1886 }
1887 }
1888
1889 Expr::IsNull(expr) => {
1891 let column = expr_to_column(expr)?;
1892 Ok(vec![Predicate::IsNull(column)])
1893 }
1894
1895 Expr::IsNotNull(expr) => {
1896 let column = expr_to_column(expr)?;
1897 Ok(vec![Predicate::IsNotNull(column)])
1898 }
1899
1900 Expr::BinaryOp { left, op, right } => {
1902 let predicate = parse_comparison(left, op, right)?;
1903 Ok(vec![predicate])
1904 }
1905
1906 Expr::InList {
1908 expr,
1909 list,
1910 negated,
1911 } => {
1912 if *negated {
1913 return Err(QueryError::UnsupportedFeature(
1914 "NOT IN is not supported".to_string(),
1915 ));
1916 }
1917
1918 let column = expr_to_column(expr)?;
1919 let values: Result<Vec<_>> = list.iter().map(expr_to_predicate_value).collect();
1920 Ok(vec![Predicate::In(column, values?)])
1921 }
1922
1923 Expr::Between {
1925 expr,
1926 negated,
1927 low,
1928 high,
1929 } => {
1930 if *negated {
1931 return Err(QueryError::UnsupportedFeature(
1932 "NOT BETWEEN is not supported".to_string(),
1933 ));
1934 }
1935
1936 let column = expr_to_column(expr)?;
1937 let low_val = expr_to_predicate_value(low)?;
1938 let high_val = expr_to_predicate_value(high)?;
1939
1940 kimberlite_properties::sometimes!(
1941 true,
1942 "query.between_desugared_to_ge_le",
1943 "BETWEEN predicate desugared into Ge + Le pair"
1944 );
1945
1946 Ok(vec![
1947 Predicate::Ge(column.clone(), low_val),
1948 Predicate::Le(column, high_val),
1949 ])
1950 }
1951
1952 Expr::Nested(inner) => parse_where_expr_inner(inner, depth + 1),
1954
1955 other => Err(QueryError::UnsupportedFeature(format!(
1956 "unsupported WHERE expression: {other:?}"
1957 ))),
1958 }
1959}
1960
1961fn parse_comparison(left: &Expr, op: &BinaryOperator, right: &Expr) -> Result<Predicate> {
1962 let column = expr_to_column(left)?;
1963 let value = expr_to_predicate_value(right)?;
1964
1965 match op {
1966 BinaryOperator::Eq => Ok(Predicate::Eq(column, value)),
1967 BinaryOperator::Lt => Ok(Predicate::Lt(column, value)),
1968 BinaryOperator::LtEq => Ok(Predicate::Le(column, value)),
1969 BinaryOperator::Gt => Ok(Predicate::Gt(column, value)),
1970 BinaryOperator::GtEq => Ok(Predicate::Ge(column, value)),
1971 other => Err(QueryError::UnsupportedFeature(format!(
1972 "unsupported operator: {other:?}"
1973 ))),
1974 }
1975}
1976
1977fn expr_to_column(expr: &Expr) -> Result<ColumnName> {
1978 match expr {
1979 Expr::Identifier(ident) => Ok(ColumnName::new(ident.value.clone())),
1980 Expr::CompoundIdentifier(idents) if idents.len() == 2 => {
1981 Ok(ColumnName::new(idents[1].value.clone()))
1983 }
1984 other => Err(QueryError::UnsupportedFeature(format!(
1985 "expected column name, got {other:?}"
1986 ))),
1987 }
1988}
1989
1990fn expr_to_predicate_value(expr: &Expr) -> Result<PredicateValue> {
1991 match expr {
1992 Expr::Identifier(ident) => {
1994 Ok(PredicateValue::ColumnRef(ident.value.clone()))
1996 }
1997 Expr::CompoundIdentifier(idents) if idents.len() == 2 => {
1998 Ok(PredicateValue::ColumnRef(format!(
2000 "{}.{}",
2001 idents[0].value, idents[1].value
2002 )))
2003 }
2004 Expr::Value(SqlValue::Number(n, _)) => {
2005 let value = parse_number_literal(n)?;
2006 match value {
2007 Value::BigInt(v) => Ok(PredicateValue::Int(v)),
2008 Value::Decimal(_, _) => Ok(PredicateValue::Literal(value)),
2009 _ => unreachable!("parse_number_literal only returns BigInt or Decimal"),
2010 }
2011 }
2012 Expr::Value(SqlValue::SingleQuotedString(s) | SqlValue::DoubleQuotedString(s)) => {
2013 Ok(PredicateValue::String(s.clone()))
2014 }
2015 Expr::Value(SqlValue::Boolean(b)) => Ok(PredicateValue::Bool(*b)),
2016 Expr::Value(SqlValue::Null) => Ok(PredicateValue::Null),
2017 Expr::Value(SqlValue::Placeholder(p)) => {
2018 if let Some(num_str) = p.strip_prefix('$') {
2020 let idx: usize = num_str.parse().map_err(|_| {
2021 QueryError::ParseError(format!("invalid parameter placeholder: {p}"))
2022 })?;
2023 if idx == 0 {
2025 return Err(QueryError::ParseError(
2026 "parameter indices start at $1, not $0".to_string(),
2027 ));
2028 }
2029 Ok(PredicateValue::Param(idx))
2030 } else {
2031 Err(QueryError::ParseError(format!(
2032 "unsupported placeholder format: {p}"
2033 )))
2034 }
2035 }
2036 Expr::UnaryOp {
2037 op: sqlparser::ast::UnaryOperator::Minus,
2038 expr,
2039 } => {
2040 if let Expr::Value(SqlValue::Number(n, _)) = expr.as_ref() {
2042 let value = parse_number_literal(n)?;
2043 match value {
2044 Value::BigInt(v) => Ok(PredicateValue::Int(-v)),
2045 Value::Decimal(v, scale) => {
2046 Ok(PredicateValue::Literal(Value::Decimal(-v, scale)))
2047 }
2048 _ => unreachable!("parse_number_literal only returns BigInt or Decimal"),
2049 }
2050 } else {
2051 Err(QueryError::UnsupportedFeature(format!(
2052 "unsupported unary minus operand: {expr:?}"
2053 )))
2054 }
2055 }
2056 other => Err(QueryError::UnsupportedFeature(format!(
2057 "unsupported value expression: {other:?}"
2058 ))),
2059 }
2060}
2061
2062fn parse_order_by(order_by: &sqlparser::ast::OrderBy) -> Result<Vec<OrderByClause>> {
2063 let mut clauses = Vec::new();
2064
2065 for expr in &order_by.exprs {
2066 clauses.push(parse_order_by_expr(expr)?);
2067 }
2068
2069 Ok(clauses)
2070}
2071
2072fn parse_order_by_expr(expr: &OrderByExpr) -> Result<OrderByClause> {
2073 let column = match &expr.expr {
2074 Expr::Identifier(ident) => ColumnName::new(ident.value.clone()),
2075 other => {
2076 return Err(QueryError::UnsupportedFeature(format!(
2077 "unsupported ORDER BY expression: {other:?}"
2078 )));
2079 }
2080 };
2081
2082 let ascending = expr.asc.unwrap_or(true);
2083
2084 Ok(OrderByClause { column, ascending })
2085}
2086
2087fn parse_limit(limit: Option<&Expr>) -> Result<Option<usize>> {
2088 match limit {
2089 None => Ok(None),
2090 Some(Expr::Value(SqlValue::Number(n, _))) => {
2091 let v: usize = n
2092 .parse()
2093 .map_err(|_| QueryError::ParseError(format!("invalid LIMIT value: {n}")))?;
2094 Ok(Some(v))
2095 }
2096 Some(other) => Err(QueryError::UnsupportedFeature(format!(
2097 "unsupported LIMIT expression: {other:?}"
2098 ))),
2099 }
2100}
2101
2102fn object_name_to_string(name: &ObjectName) -> String {
2103 name.0
2104 .iter()
2105 .map(|i: &Ident| i.value.clone())
2106 .collect::<Vec<_>>()
2107 .join(".")
2108}
2109
2110fn parse_create_table(create_table: &sqlparser::ast::CreateTable) -> Result<ParsedCreateTable> {
2115 let table_name = object_name_to_string(&create_table.name);
2116
2117 let mut raw_columns = Vec::new();
2127 for col_def in &create_table.columns {
2128 let parsed_col = parse_column_def(col_def)?;
2129 raw_columns.push(parsed_col);
2130 }
2131 let columns = NonEmptyVec::try_new(raw_columns).map_err(|_| {
2132 crate::error::QueryError::ParseError(format!(
2133 "CREATE TABLE {table_name} requires at least one column"
2134 ))
2135 })?;
2136
2137 let mut primary_key = Vec::new();
2139 for constraint in &create_table.constraints {
2140 if let sqlparser::ast::TableConstraint::PrimaryKey {
2141 columns: pk_cols, ..
2142 } = constraint
2143 {
2144 for col in pk_cols {
2145 primary_key.push(col.value.clone());
2146 }
2147 }
2148 }
2149
2150 if primary_key.is_empty() {
2152 for col_def in &create_table.columns {
2153 for option in &col_def.options {
2154 if matches!(
2155 &option.option,
2156 sqlparser::ast::ColumnOption::Unique { is_primary, .. } if *is_primary
2157 ) {
2158 primary_key.push(col_def.name.value.clone());
2159 }
2160 }
2161 }
2162 }
2163
2164 Ok(ParsedCreateTable {
2165 table_name,
2166 columns,
2167 primary_key,
2168 if_not_exists: create_table.if_not_exists,
2169 })
2170}
2171
2172fn parse_column_def(col_def: &SqlColumnDef) -> Result<ParsedColumn> {
2173 let name = col_def.name.value.clone();
2174
2175 let data_type = match &col_def.data_type {
2178 SqlDataType::TinyInt(_) => "TINYINT".to_string(),
2180 SqlDataType::SmallInt(_) => "SMALLINT".to_string(),
2181 SqlDataType::Int(_) | SqlDataType::Integer(_) => "INTEGER".to_string(),
2182 SqlDataType::BigInt(_) => "BIGINT".to_string(),
2183
2184 SqlDataType::Real | SqlDataType::Float(_) | SqlDataType::Double(_) => "REAL".to_string(),
2186 SqlDataType::Decimal(precision_opt) => match precision_opt {
2187 sqlparser::ast::ExactNumberInfo::PrecisionAndScale(p, s) => {
2188 format!("DECIMAL({p},{s})")
2189 }
2190 sqlparser::ast::ExactNumberInfo::Precision(p) => {
2191 format!("DECIMAL({p},0)")
2192 }
2193 sqlparser::ast::ExactNumberInfo::None => "DECIMAL(18,2)".to_string(),
2194 },
2195
2196 SqlDataType::Text | SqlDataType::Varchar(_) | SqlDataType::String(_) => "TEXT".to_string(),
2198
2199 SqlDataType::Binary(_) | SqlDataType::Varbinary(_) | SqlDataType::Blob(_) => {
2201 "BYTES".to_string()
2202 }
2203
2204 SqlDataType::Boolean | SqlDataType::Bool => "BOOLEAN".to_string(),
2206
2207 SqlDataType::Date => "DATE".to_string(),
2209 SqlDataType::Time(_, _) => "TIME".to_string(),
2210 SqlDataType::Timestamp(_, _) => "TIMESTAMP".to_string(),
2211
2212 SqlDataType::Uuid => "UUID".to_string(),
2214 SqlDataType::JSON => "JSON".to_string(),
2215
2216 other => {
2217 return Err(QueryError::UnsupportedFeature(format!(
2218 "unsupported data type: {other:?}"
2219 )));
2220 }
2221 };
2222
2223 let mut nullable = true;
2225 for option in &col_def.options {
2226 if matches!(option.option, sqlparser::ast::ColumnOption::NotNull) {
2227 nullable = false;
2228 }
2229 }
2230
2231 Ok(ParsedColumn {
2232 name,
2233 data_type,
2234 nullable,
2235 })
2236}
2237
2238fn parse_alter_table(
2239 name: &sqlparser::ast::ObjectName,
2240 operations: &[sqlparser::ast::AlterTableOperation],
2241) -> Result<ParsedAlterTable> {
2242 let table_name = object_name_to_string(name);
2243
2244 if operations.len() != 1 {
2246 return Err(QueryError::UnsupportedFeature(
2247 "ALTER TABLE supports only one operation at a time".to_string(),
2248 ));
2249 }
2250
2251 let operation = match &operations[0] {
2252 sqlparser::ast::AlterTableOperation::AddColumn { column_def, .. } => {
2253 let parsed_col = parse_column_def(column_def)?;
2254 AlterTableOperation::AddColumn(parsed_col)
2255 }
2256 sqlparser::ast::AlterTableOperation::DropColumn {
2257 column_name,
2258 if_exists: _,
2259 ..
2260 } => {
2261 let col_name = column_name.value.clone();
2262 AlterTableOperation::DropColumn(col_name)
2263 }
2264 other => {
2265 return Err(QueryError::UnsupportedFeature(format!(
2266 "ALTER TABLE operation not supported: {other:?}"
2267 )));
2268 }
2269 };
2270
2271 Ok(ParsedAlterTable {
2272 table_name,
2273 operation,
2274 })
2275}
2276
2277fn parse_create_index(create_index: &sqlparser::ast::CreateIndex) -> Result<ParsedCreateIndex> {
2278 let index_name = match &create_index.name {
2279 Some(name) => object_name_to_string(name),
2280 None => {
2281 return Err(QueryError::ParseError(
2282 "CREATE INDEX requires an index name".to_string(),
2283 ));
2284 }
2285 };
2286
2287 let table_name = object_name_to_string(&create_index.table_name);
2288
2289 let mut columns = Vec::new();
2290 for col in &create_index.columns {
2291 columns.push(col.expr.to_string());
2292 }
2293
2294 Ok(ParsedCreateIndex {
2295 index_name,
2296 table_name,
2297 columns,
2298 })
2299}
2300
2301fn parse_insert(insert: &sqlparser::ast::Insert) -> Result<ParsedInsert> {
2306 let table = insert.table.to_string();
2308
2309 let columns: Vec<String> = insert.columns.iter().map(|c| c.value.clone()).collect();
2311
2312 let values = match insert.source.as_ref().map(|s| s.body.as_ref()) {
2314 Some(SetExpr::Values(values)) => {
2315 let mut all_rows = Vec::new();
2316 for row in &values.rows {
2317 let mut parsed_row = Vec::new();
2318 for expr in row {
2319 let val = expr_to_value(expr)?;
2320 parsed_row.push(val);
2321 }
2322 all_rows.push(parsed_row);
2323 }
2324 all_rows
2325 }
2326 _ => {
2327 return Err(QueryError::UnsupportedFeature(
2328 "only VALUES clause is supported in INSERT".to_string(),
2329 ));
2330 }
2331 };
2332
2333 let returning = parse_returning(insert.returning.as_ref())?;
2335
2336 Ok(ParsedInsert {
2337 table,
2338 columns,
2339 values,
2340 returning,
2341 })
2342}
2343
2344fn parse_update(
2345 table: &sqlparser::ast::TableWithJoins,
2346 assignments: &[sqlparser::ast::Assignment],
2347 selection: Option<&Expr>,
2348 returning: Option<&Vec<SelectItem>>,
2349) -> Result<ParsedUpdate> {
2350 let table_name = match &table.relation {
2351 sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
2352 other => {
2353 return Err(QueryError::UnsupportedFeature(format!(
2354 "unsupported table in UPDATE: {other:?}"
2355 )));
2356 }
2357 };
2358
2359 let mut parsed_assignments = Vec::new();
2361 for assignment in assignments {
2362 let col_name = assignment.target.to_string();
2363 let value = expr_to_value(&assignment.value)?;
2364 parsed_assignments.push((col_name, value));
2365 }
2366
2367 let predicates = match selection {
2369 Some(expr) => parse_where_expr(expr)?,
2370 None => vec![],
2371 };
2372
2373 let returning_cols = parse_returning(returning)?;
2375
2376 Ok(ParsedUpdate {
2377 table: table_name,
2378 assignments: parsed_assignments,
2379 predicates,
2380 returning: returning_cols,
2381 })
2382}
2383
2384fn parse_delete_stmt(delete: &sqlparser::ast::Delete) -> Result<ParsedDelete> {
2385 use sqlparser::ast::FromTable;
2387
2388 let table_name = match &delete.from {
2389 FromTable::WithFromKeyword(tables) => {
2390 if tables.len() != 1 {
2391 return Err(QueryError::ParseError(
2392 "expected exactly 1 table in DELETE FROM".to_string(),
2393 ));
2394 }
2395
2396 match &tables[0].relation {
2397 sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
2398 _ => {
2399 return Err(QueryError::ParseError(
2400 "DELETE only supports simple table names".to_string(),
2401 ));
2402 }
2403 }
2404 }
2405 FromTable::WithoutKeyword(tables) => {
2406 if tables.len() != 1 {
2407 return Err(QueryError::ParseError(
2408 "expected exactly 1 table in DELETE".to_string(),
2409 ));
2410 }
2411
2412 match &tables[0].relation {
2413 sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
2414 _ => {
2415 return Err(QueryError::ParseError(
2416 "DELETE only supports simple table names".to_string(),
2417 ));
2418 }
2419 }
2420 }
2421 };
2422
2423 let predicates = match &delete.selection {
2425 Some(expr) => parse_where_expr(expr)?,
2426 None => vec![],
2427 };
2428
2429 let returning_cols = parse_returning(delete.returning.as_ref())?;
2431
2432 Ok(ParsedDelete {
2433 table: table_name,
2434 predicates,
2435 returning: returning_cols,
2436 })
2437}
2438
2439fn parse_returning(returning: Option<&Vec<SelectItem>>) -> Result<Option<Vec<String>>> {
2441 match returning {
2442 None => Ok(None),
2443 Some(items) => {
2444 let mut columns = Vec::new();
2445 for item in items {
2446 match item {
2447 SelectItem::UnnamedExpr(Expr::Identifier(ident)) => {
2448 columns.push(ident.value.clone());
2449 }
2450 SelectItem::UnnamedExpr(Expr::CompoundIdentifier(parts)) => {
2451 if let Some(last) = parts.last() {
2453 columns.push(last.value.clone());
2454 } else {
2455 return Err(QueryError::ParseError(
2456 "invalid column in RETURNING clause".to_string(),
2457 ));
2458 }
2459 }
2460 _ => {
2461 return Err(QueryError::UnsupportedFeature(
2462 "only simple column names supported in RETURNING clause".to_string(),
2463 ));
2464 }
2465 }
2466 }
2467 Ok(Some(columns))
2468 }
2469 }
2470}
2471
2472fn parse_number_literal(n: &str) -> Result<Value> {
2476 use rust_decimal::Decimal;
2477 use std::str::FromStr;
2478
2479 if n.contains('.') {
2480 let decimal = Decimal::from_str(n)
2482 .map_err(|e| QueryError::ParseError(format!("invalid decimal '{n}': {e}")))?;
2483
2484 let scale = decimal.scale() as u8;
2486
2487 if scale > 38 {
2488 return Err(QueryError::ParseError(format!(
2489 "decimal scale too large (max 38): {n}"
2490 )));
2491 }
2492
2493 let mantissa = decimal.mantissa();
2496
2497 Ok(Value::Decimal(mantissa, scale))
2498 } else {
2499 let v: i64 = n
2501 .parse()
2502 .map_err(|_| QueryError::ParseError(format!("invalid integer: {n}")))?;
2503 Ok(Value::BigInt(v))
2504 }
2505}
2506
2507fn expr_to_value(expr: &Expr) -> Result<Value> {
2509 match expr {
2510 Expr::Value(SqlValue::Number(n, _)) => parse_number_literal(n),
2511 Expr::Value(SqlValue::SingleQuotedString(s) | SqlValue::DoubleQuotedString(s)) => {
2512 Ok(Value::Text(s.clone()))
2513 }
2514 Expr::Value(SqlValue::Boolean(b)) => Ok(Value::Boolean(*b)),
2515 Expr::Value(SqlValue::Null) => Ok(Value::Null),
2516 Expr::Value(SqlValue::Placeholder(p)) => {
2517 if let Some(num_str) = p.strip_prefix('$') {
2519 let idx: usize = num_str.parse().map_err(|_| {
2520 QueryError::ParseError(format!("invalid parameter placeholder: {p}"))
2521 })?;
2522 if idx == 0 {
2524 return Err(QueryError::ParseError(
2525 "parameter indices start at $1, not $0".to_string(),
2526 ));
2527 }
2528 Ok(Value::Placeholder(idx))
2529 } else {
2530 Err(QueryError::ParseError(format!(
2531 "unsupported placeholder format: {p}"
2532 )))
2533 }
2534 }
2535 Expr::UnaryOp {
2536 op: sqlparser::ast::UnaryOperator::Minus,
2537 expr,
2538 } => {
2539 if let Expr::Value(SqlValue::Number(n, _)) = expr.as_ref() {
2541 let value = parse_number_literal(n)?;
2542 match value {
2543 Value::BigInt(v) => Ok(Value::BigInt(-v)),
2544 Value::Decimal(v, scale) => Ok(Value::Decimal(-v, scale)),
2545 _ => unreachable!("parse_number_literal only returns BigInt or Decimal"),
2546 }
2547 } else {
2548 Err(QueryError::UnsupportedFeature(format!(
2549 "unsupported unary minus operand: {expr:?}"
2550 )))
2551 }
2552 }
2553 other => Err(QueryError::UnsupportedFeature(format!(
2554 "unsupported value expression: {other:?}"
2555 ))),
2556 }
2557}
2558
2559#[cfg(test)]
2560mod tests {
2561 use super::*;
2562
2563 fn parse_test_select(sql: &str) -> ParsedSelect {
2564 match parse_statement(sql).unwrap() {
2565 ParsedStatement::Select(s) => s,
2566 _ => panic!("expected SELECT statement"),
2567 }
2568 }
2569
2570 #[test]
2571 fn test_parse_simple_select() {
2572 let result = parse_test_select("SELECT id, name FROM users");
2573 assert_eq!(result.table, "users");
2574 assert_eq!(
2575 result.columns,
2576 Some(vec![ColumnName::new("id"), ColumnName::new("name")])
2577 );
2578 assert!(result.predicates.is_empty());
2579 }
2580
2581 #[test]
2582 fn test_parse_select_star() {
2583 let result = parse_test_select("SELECT * FROM users");
2584 assert_eq!(result.table, "users");
2585 assert!(result.columns.is_none());
2586 }
2587
2588 #[test]
2589 fn test_parse_where_eq() {
2590 let result = parse_test_select("SELECT * FROM users WHERE id = 42");
2591 assert_eq!(result.predicates.len(), 1);
2592 match &result.predicates[0] {
2593 Predicate::Eq(col, PredicateValue::Int(42)) => {
2594 assert_eq!(col.as_str(), "id");
2595 }
2596 other => panic!("unexpected predicate: {other:?}"),
2597 }
2598 }
2599
2600 #[test]
2601 fn test_parse_where_string() {
2602 let result = parse_test_select("SELECT * FROM users WHERE name = 'alice'");
2603 match &result.predicates[0] {
2604 Predicate::Eq(col, PredicateValue::String(s)) => {
2605 assert_eq!(col.as_str(), "name");
2606 assert_eq!(s, "alice");
2607 }
2608 other => panic!("unexpected predicate: {other:?}"),
2609 }
2610 }
2611
2612 #[test]
2613 fn test_parse_where_and() {
2614 let result = parse_test_select("SELECT * FROM users WHERE id = 1 AND name = 'bob'");
2615 assert_eq!(result.predicates.len(), 2);
2616 }
2617
2618 #[test]
2619 fn test_parse_where_in() {
2620 let result = parse_test_select("SELECT * FROM users WHERE id IN (1, 2, 3)");
2621 match &result.predicates[0] {
2622 Predicate::In(col, values) => {
2623 assert_eq!(col.as_str(), "id");
2624 assert_eq!(values.len(), 3);
2625 }
2626 other => panic!("unexpected predicate: {other:?}"),
2627 }
2628 }
2629
2630 #[test]
2631 fn test_parse_order_by() {
2632 let result = parse_test_select("SELECT * FROM users ORDER BY name ASC, id DESC");
2633 assert_eq!(result.order_by.len(), 2);
2634 assert_eq!(result.order_by[0].column.as_str(), "name");
2635 assert!(result.order_by[0].ascending);
2636 assert_eq!(result.order_by[1].column.as_str(), "id");
2637 assert!(!result.order_by[1].ascending);
2638 }
2639
2640 #[test]
2641 fn test_parse_limit() {
2642 let result = parse_test_select("SELECT * FROM users LIMIT 10");
2643 assert_eq!(result.limit, Some(10));
2644 }
2645
2646 #[test]
2647 fn test_parse_param() {
2648 let result = parse_test_select("SELECT * FROM users WHERE id = $1");
2649 match &result.predicates[0] {
2650 Predicate::Eq(_, PredicateValue::Param(1)) => {}
2651 other => panic!("unexpected predicate: {other:?}"),
2652 }
2653 }
2654
2655 #[test]
2656 fn test_parse_inner_join() {
2657 let result =
2658 parse_statement("SELECT * FROM users JOIN orders ON users.id = orders.user_id");
2659 if let Err(ref e) = result {
2660 eprintln!("Parse error: {e:?}");
2661 }
2662 assert!(result.is_ok());
2663 match result.unwrap() {
2664 ParsedStatement::Select(s) => {
2665 assert_eq!(s.table, "users");
2666 assert_eq!(s.joins.len(), 1);
2667 assert_eq!(s.joins[0].table, "orders");
2668 assert!(matches!(s.joins[0].join_type, JoinType::Inner));
2669 }
2670 _ => panic!("expected SELECT statement"),
2671 }
2672 }
2673
2674 #[test]
2675 fn test_parse_left_join() {
2676 let result =
2677 parse_statement("SELECT * FROM users LEFT JOIN orders ON users.id = orders.user_id");
2678 assert!(result.is_ok());
2679 match result.unwrap() {
2680 ParsedStatement::Select(s) => {
2681 assert_eq!(s.table, "users");
2682 assert_eq!(s.joins.len(), 1);
2683 assert_eq!(s.joins[0].table, "orders");
2684 assert!(matches!(s.joins[0].join_type, JoinType::Left));
2685 }
2686 _ => panic!("expected SELECT statement"),
2687 }
2688 }
2689
2690 #[test]
2691 fn test_parse_multi_join() {
2692 let result = parse_statement(
2693 "SELECT * FROM users \
2694 JOIN orders ON users.id = orders.user_id \
2695 JOIN products ON orders.product_id = products.id",
2696 );
2697 assert!(result.is_ok());
2698 match result.unwrap() {
2699 ParsedStatement::Select(s) => {
2700 assert_eq!(s.table, "users");
2701 assert_eq!(s.joins.len(), 2);
2702 assert_eq!(s.joins[0].table, "orders");
2703 assert_eq!(s.joins[1].table, "products");
2704 }
2705 _ => panic!("expected SELECT statement"),
2706 }
2707 }
2708
2709 #[test]
2710 fn test_reject_subquery() {
2711 let result = parse_statement("SELECT * FROM (SELECT * FROM users)");
2712 assert!(result.is_err());
2713 }
2714
2715 #[test]
2716 fn test_where_depth_within_limit() {
2717 let mut sql = String::from("SELECT * FROM users WHERE ");
2720 for i in 0..10 {
2721 if i > 0 {
2722 sql.push_str(" AND ");
2723 }
2724 sql.push('(');
2725 sql.push_str("id = ");
2726 sql.push_str(&i.to_string());
2727 sql.push(')');
2728 }
2729
2730 let result = parse_statement(&sql);
2731 assert!(
2732 result.is_ok(),
2733 "Moderate nesting should succeed, but got: {result:?}"
2734 );
2735 }
2736
2737 #[test]
2738 fn test_where_depth_nested_parens() {
2739 let mut sql = String::from("SELECT * FROM users WHERE ");
2742 for _ in 0..200 {
2743 sql.push('(');
2744 }
2745 sql.push_str("id = 1");
2746 for _ in 0..200 {
2747 sql.push(')');
2748 }
2749
2750 let result = parse_statement(&sql);
2751 assert!(
2752 result.is_err(),
2753 "Excessive parenthesis nesting should be rejected"
2754 );
2755 }
2756
#[test]
// A WHERE clause combining two OR-groups of AND-pairs with a top-level AND
// must parse successfully.
fn test_where_depth_complex_and_or() {
    let sql = "SELECT * FROM users WHERE \
               ((id = 1 AND name = 'a') OR (id = 2 AND name = 'b')) AND \
               ((age > 10 AND age < 20) OR (age > 30 AND age < 40))";

    let result = parse_statement(sql);
    // Report the parse result on failure, matching the message style of
    // `test_where_depth_within_limit`, so a regression shows the actual
    // error rather than only the expectation.
    assert!(
        result.is_ok(),
        "Complex AND/OR should succeed, but got: {result:?}"
    );
}
2767
#[test]
// `HAVING COUNT(*) > 5` with one GROUP BY column yields exactly one HAVING
// condition: COUNT(*) compared with `Gt` against `BigInt(5)`.
fn test_parse_having() {
    let select =
        parse_test_select("SELECT name, COUNT(*) FROM users GROUP BY name HAVING COUNT(*) > 5");
    assert_eq!(select.group_by.len(), 1);
    assert_eq!(select.having.len(), 1);

    let first_condition = &select.having[0];
    match first_condition {
        HavingCondition::AggregateComparison {
            aggregate: func,
            op: comparison,
            value: rhs,
        } => {
            assert!(matches!(func, AggregateFunction::CountStar));
            assert_eq!(comparison, &HavingOp::Gt);
            assert_eq!(rhs, &Value::BigInt(5));
        }
    }
}
2786
#[test]
// Two aggregate comparisons joined by AND in HAVING produce two conditions.
fn test_parse_having_multiple() {
    let sql = "SELECT name, COUNT(*), SUM(age) FROM users GROUP BY name HAVING COUNT(*) > 1 AND SUM(age) < 100";
    let select = parse_test_select(sql);
    assert_eq!(select.having.len(), 2);
}
2794
#[test]
// HAVING is accepted without GROUP BY: the group-by list stays empty and the
// single HAVING condition is still recorded.
fn test_parse_having_without_group_by() {
    let sql = "SELECT COUNT(*) FROM users HAVING COUNT(*) > 0";
    let select = parse_test_select(sql);
    assert!(select.group_by.is_empty());
    assert_eq!(select.having.len(), 1);
}
2801
#[test]
// A plain `UNION` parses into `ParsedStatement::Union` with the left and
// right SELECT tables preserved and `all == false`.
fn test_parse_union() {
    // `expect` replaces the former `assert!(result.is_ok())` + `unwrap()`
    // pair so that a failure surfaces the parser error in the panic message.
    let stmt = parse_statement("SELECT id FROM users UNION SELECT id FROM orders")
        .expect("UNION statement should parse");

    match stmt {
        ParsedStatement::Union(u) => {
            assert_eq!(u.left.table, "users");
            assert_eq!(u.right.table, "orders");
            assert!(!u.all);
        }
        _ => panic!("expected UNION statement"),
    }
}
2815
#[test]
// `UNION ALL` parses into `ParsedStatement::Union` with the left and right
// SELECT tables preserved and `all == true`.
fn test_parse_union_all() {
    // `expect` replaces the former `assert!(result.is_ok())` + `unwrap()`
    // pair so that a failure surfaces the parser error in the panic message.
    let stmt = parse_statement("SELECT id FROM users UNION ALL SELECT id FROM orders")
        .expect("UNION ALL statement should parse");

    match stmt {
        ParsedStatement::Union(u) => {
            assert_eq!(u.left.table, "users");
            assert_eq!(u.right.table, "orders");
            assert!(u.all);
        }
        _ => panic!("expected UNION ALL statement"),
    }
}
2829
#[test]
// `CREATE MASK <name> ON <table>.<column> USING <strategy>` fills in the mask
// name, table, column and strategy.
fn test_parse_create_mask() {
    let stmt = parse_statement("CREATE MASK ssn_mask ON patients.ssn USING REDACT").unwrap();
    if let ParsedStatement::CreateMask(mask) = stmt {
        assert_eq!(mask.mask_name, "ssn_mask");
        assert_eq!(mask.table_name, "patients");
        assert_eq!(mask.column_name, "ssn");
        assert_eq!(mask.strategy, "REDACT");
    } else {
        panic!("expected CREATE MASK statement");
    }
}
2844
#[test]
// A trailing semicolon on CREATE MASK is accepted; the mask name and the
// strategy are checked (the strategy must not include the `;`).
fn test_parse_create_mask_with_semicolon() {
    let stmt = parse_statement("CREATE MASK ssn_mask ON patients.ssn USING REDACT;").unwrap();
    if let ParsedStatement::CreateMask(mask) = stmt {
        assert_eq!(mask.mask_name, "ssn_mask");
        assert_eq!(mask.strategy, "REDACT");
    } else {
        panic!("expected CREATE MASK statement");
    }
}
2857
#[test]
// CREATE MASK with the `HASH` strategy on `users.email` populates all four
// fields.
fn test_parse_create_mask_hash_strategy() {
    match parse_statement("CREATE MASK email_hash ON users.email USING HASH").unwrap() {
        ParsedStatement::CreateMask(mask) => {
            assert_eq!(mask.mask_name, "email_hash");
            assert_eq!(mask.table_name, "users");
            assert_eq!(mask.column_name, "email");
            assert_eq!(mask.strategy, "HASH");
        }
        _ => panic!("expected CREATE MASK statement"),
    }
}
2872
#[test]
// CREATE MASK without the `ON` keyword must be an error.
fn test_parse_create_mask_missing_on() {
    assert!(parse_statement("CREATE MASK ssn_mask patients.ssn USING REDACT").is_err());
}
2878
#[test]
// CREATE MASK whose target is not written as `<table>.<column>` must be an
// error.
fn test_parse_create_mask_missing_dot() {
    assert!(parse_statement("CREATE MASK ssn_mask ON patients_ssn USING REDACT").is_err());
}
2884
#[test]
// `DROP MASK <name>` yields `ParsedStatement::DropMask` carrying the name.
fn test_parse_drop_mask() {
    let stmt = parse_statement("DROP MASK ssn_mask").unwrap();
    if let ParsedStatement::DropMask(mask_name) = stmt {
        assert_eq!(mask_name, "ssn_mask");
    } else {
        panic!("expected DROP MASK statement");
    }
}
2895
#[test]
// A trailing semicolon on DROP MASK is accepted and is not part of the name.
fn test_parse_drop_mask_with_semicolon() {
    match parse_statement("DROP MASK ssn_mask;").unwrap() {
        ParsedStatement::DropMask(mask_name) => assert_eq!(mask_name, "ssn_mask"),
        _ => panic!("expected DROP MASK statement"),
    }
}
2906
#[test]
// `ALTER TABLE <t> MODIFY COLUMN <c> SET CLASSIFICATION '<label>'` yields the
// table, column and the label with its quotes removed.
fn test_parse_set_classification() {
    let sql = "ALTER TABLE patients MODIFY COLUMN ssn SET CLASSIFICATION 'PHI'";
    let stmt = parse_statement(sql).unwrap();
    if let ParsedStatement::SetClassification(parsed) = stmt {
        assert_eq!(parsed.table_name, "patients");
        assert_eq!(parsed.column_name, "ssn");
        assert_eq!(parsed.classification, "PHI");
    } else {
        panic!("expected SetClassification statement");
    }
}
2926
#[test]
// A trailing semicolon after the quoted label is accepted; table, column and
// label are still extracted.
fn test_parse_set_classification_with_semicolon() {
    let sql = "ALTER TABLE patients MODIFY COLUMN diagnosis SET CLASSIFICATION 'MEDICAL';";
    match parse_statement(sql).unwrap() {
        ParsedStatement::SetClassification(parsed) => {
            assert_eq!(parsed.table_name, "patients");
            assert_eq!(parsed.column_name, "diagnosis");
            assert_eq!(parsed.classification, "MEDICAL");
        }
        _ => panic!("expected SetClassification statement"),
    }
}
2942
#[test]
// Each label in the list below must parse and come back unchanged as the
// classification.
fn test_parse_set_classification_various_labels() {
    for label in ["PHI", "PII", "PCI", "MEDICAL", "FINANCIAL", "CONFIDENTIAL"] {
        let stmt = parse_statement(&format!(
            "ALTER TABLE t MODIFY COLUMN c SET CLASSIFICATION '{label}'"
        ))
        .unwrap();

        if let ParsedStatement::SetClassification(parsed) = stmt {
            assert_eq!(parsed.classification, label);
        } else {
            panic!("expected SetClassification for {label}");
        }
    }
}
2958
#[test]
// An unquoted classification label must be an error.
fn test_parse_set_classification_missing_quotes() {
    let sql = "ALTER TABLE patients MODIFY COLUMN ssn SET CLASSIFICATION PHI";
    assert!(
        parse_statement(sql).is_err(),
        "classification must be single-quoted"
    );
}
2966
#[test]
// SET CLASSIFICATION without a preceding `MODIFY COLUMN <column>` must be an
// error.
fn test_parse_set_classification_missing_modify() {
    let sql = "ALTER TABLE patients SET CLASSIFICATION 'PHI'";
    assert!(parse_statement(sql).is_err());
}
2976
#[test]
// `SHOW CLASSIFICATIONS FOR <table>` yields `ShowClassifications` carrying
// the table name.
fn test_parse_show_classifications() {
    let stmt = parse_statement("SHOW CLASSIFICATIONS FOR patients").unwrap();
    if let ParsedStatement::ShowClassifications(table_name) = stmt {
        assert_eq!(table_name, "patients");
    } else {
        panic!("expected ShowClassifications statement");
    }
}
2991
#[test]
// A trailing semicolon is accepted and is not part of the table name.
fn test_parse_show_classifications_with_semicolon() {
    match parse_statement("SHOW CLASSIFICATIONS FOR patients;").unwrap() {
        ParsedStatement::ShowClassifications(table_name) => assert_eq!(table_name, "patients"),
        _ => panic!("expected ShowClassifications statement"),
    }
}
3002
#[test]
// SHOW CLASSIFICATIONS without the `FOR` keyword must be an error.
fn test_parse_show_classifications_missing_for() {
    assert!(parse_statement("SHOW CLASSIFICATIONS patients").is_err());
}
3008
#[test]
// SHOW CLASSIFICATIONS FOR with no table name after `FOR` must be an error.
fn test_parse_show_classifications_missing_table() {
    assert!(parse_statement("SHOW CLASSIFICATIONS FOR").is_err());
}
3014
#[test]
// `CREATE ROLE <name>` yields `CreateRole` carrying the role name.
fn test_parse_create_role() {
    let stmt = parse_statement("CREATE ROLE billing_clerk").unwrap();
    if let ParsedStatement::CreateRole(role_name) = stmt {
        assert_eq!(role_name, "billing_clerk");
    } else {
        panic!("expected CreateRole");
    }
}
3029
#[test]
// A trailing semicolon on CREATE ROLE is accepted and is not part of the
// role name.
fn test_parse_create_role_with_semicolon() {
    match parse_statement("CREATE ROLE doctor;").unwrap() {
        ParsedStatement::CreateRole(role_name) => assert_eq!(role_name, "doctor"),
        _ => panic!("expected CreateRole"),
    }
}
3040
#[test]
// GRANT SELECT without a column list leaves `columns` as `None` and records
// the table and role names.
fn test_parse_grant_select_all_columns() {
    let stmt = parse_statement("GRANT SELECT ON patients TO doctor").unwrap();
    if let ParsedStatement::Grant(grant) = stmt {
        assert!(grant.columns.is_none());
        assert_eq!(grant.table_name, "patients");
        assert_eq!(grant.role_name, "doctor");
    } else {
        panic!("expected Grant");
    }
}
3053
#[test]
// GRANT SELECT with a parenthesised column list records those columns in
// order, along with the table and role names.
fn test_parse_grant_select_specific_columns() {
    let expected_columns: Vec<String> = ["id", "name", "ssn"]
        .iter()
        .copied()
        .map(String::from)
        .collect();

    let stmt =
        parse_statement("GRANT SELECT (id, name, ssn) ON patients TO billing_clerk").unwrap();
    if let ParsedStatement::Grant(grant) = stmt {
        assert_eq!(grant.columns, Some(expected_columns));
        assert_eq!(grant.table_name, "patients");
        assert_eq!(grant.role_name, "billing_clerk");
    } else {
        panic!("expected Grant");
    }
}
3067
#[test]
// `CREATE USER <name> WITH ROLE <role>` yields the username and the role.
fn test_parse_create_user() {
    let stmt = parse_statement("CREATE USER clerk1 WITH ROLE billing_clerk").unwrap();
    if let ParsedStatement::CreateUser(user) = stmt {
        assert_eq!(user.username, "clerk1");
        assert_eq!(user.role, "billing_clerk");
    } else {
        panic!("expected CreateUser");
    }
}
3080
#[test]
// A trailing semicolon on CREATE USER is accepted and is not part of the
// role.
fn test_parse_create_user_with_semicolon() {
    match parse_statement("CREATE USER admin1 WITH ROLE admin;").unwrap() {
        ParsedStatement::CreateUser(user) => {
            assert_eq!(user.username, "admin1");
            assert_eq!(user.role, "admin");
        }
        _ => panic!("expected CreateUser"),
    }
}
3093
#[test]
// CREATE USER with `WITH` but no `ROLE` keyword must be an error.
fn test_parse_create_user_missing_role() {
    assert!(parse_statement("CREATE USER clerk1 WITH billing_clerk").is_err());
}
3099
#[test]
// CREATE TABLE inputs that define no columns must be rejected, whether the
// column list is absent or present but empty.
fn test_parse_create_table_rejects_zero_columns() {
    // NOTE(review): "CREATE TABLE#USER" looks like a fuzzer-derived
    // regression input; keep it byte-for-byte — confirm against history.
    let cases = [
        ("CREATE TABLE#USER", "zero-column CREATE TABLE must be rejected"),
        (
            "CREATE TABLE t ()",
            "empty-column-list CREATE TABLE must be rejected",
        ),
    ];

    for (sql, message) in cases {
        assert!(parse_statement(sql).is_err(), "{message}");
    }
}
3116}