1use kimberlite_types::NonEmptyVec;
15use sqlparser::ast::{
16 BinaryOperator, ColumnDef as SqlColumnDef, DataType as SqlDataType, Expr, ObjectName,
17 OrderByExpr, Query, Select, SelectItem, SetExpr, Statement, Value as SqlValue,
18};
19use sqlparser::dialect::{Dialect, GenericDialect};
20use sqlparser::parser::Parser;
21
/// SQL dialect used when handing text to `sqlparser`.
///
/// Wraps a [`GenericDialect`]; the `Dialect` impl for this type forwards the
/// identifier rules to it and additionally turns on
/// `supports_filter_during_aggregation`.
#[derive(Debug)]
struct KimberliteDialect {
    /// Dialect that the identifier-character checks are delegated to.
    inner: GenericDialect,
}
30
31impl KimberliteDialect {
32 const fn new() -> Self {
33 Self {
34 inner: GenericDialect {},
35 }
36 }
37}
38
39impl Dialect for KimberliteDialect {
40 fn is_identifier_start(&self, ch: char) -> bool {
41 self.inner.is_identifier_start(ch)
42 }
43
44 fn is_identifier_part(&self, ch: char) -> bool {
45 self.inner.is_identifier_part(ch)
46 }
47
48 fn supports_filter_during_aggregation(&self) -> bool {
49 true
50 }
51}
52
53use crate::error::{QueryError, Result};
54use crate::expression::ScalarExpr;
55use crate::schema::{ColumnName, DataType};
56use crate::value::Value;
57
/// Every statement shape that `parse_statement` can return.
///
/// Variants are produced either from the `sqlparser` AST or by the
/// token-based `try_parse_custom_statement` pre-pass (masking,
/// classification, `SHOW ...`, `CREATE USER`).
#[derive(Debug, Clone)]
pub enum ParsedStatement {
    /// A single `SELECT` query.
    Select(ParsedSelect),
    /// A set operation (`UNION` / `INTERSECT` / `EXCEPT`) over two selects.
    Union(ParsedUnion),
    /// `CREATE TABLE`.
    CreateTable(ParsedCreateTable),
    /// `DROP TABLE [IF EXISTS] <name>`.
    DropTable { name: String, if_exists: bool },
    /// `ALTER TABLE`.
    AlterTable(ParsedAlterTable),
    /// `CREATE INDEX`.
    CreateIndex(ParsedCreateIndex),
    /// `INSERT`.
    Insert(ParsedInsert),
    /// `UPDATE`.
    Update(ParsedUpdate),
    /// `DELETE`.
    Delete(ParsedDelete),
    /// `CREATE MASK <name> ON <table>.<column> USING <strategy>`.
    CreateMask(ParsedCreateMask),
    /// `DROP MASK <name>`; carries the mask name.
    DropMask(String),
    /// `CREATE MASKING POLICY ...`.
    CreateMaskingPolicy(ParsedCreateMaskingPolicy),
    /// `DROP MASKING POLICY <name>`; carries the policy name.
    DropMaskingPolicy(String),
    /// `ALTER TABLE <t> ALTER COLUMN <c> SET MASKING POLICY <name>`.
    AttachMaskingPolicy(ParsedAttachMaskingPolicy),
    /// `ALTER TABLE <t> ALTER COLUMN <c> DROP MASKING POLICY`.
    DetachMaskingPolicy(ParsedDetachMaskingPolicy),
    /// `ALTER TABLE <t> MODIFY COLUMN <c> SET CLASSIFICATION '<class>'`.
    SetClassification(ParsedSetClassification),
    /// `SHOW CLASSIFICATIONS FOR <table>`; carries the table name.
    ShowClassifications(String),
    /// `SHOW TABLES`.
    ShowTables,
    /// `SHOW COLUMNS FROM <table>`; carries the table name.
    ShowColumns(String),
    /// `CREATE ROLE <name>`; carries the role name.
    CreateRole(String),
    /// `GRANT ... ON <table> TO <role>`.
    Grant(ParsedGrant),
    /// `CREATE USER <name> WITH ROLE <role>`.
    CreateUser(ParsedCreateUser),
}
116
/// Parsed `GRANT` statement (single table, single grantee).
#[derive(Debug, Clone)]
pub struct ParsedGrant {
    /// Column list taken from a `SELECT (<cols>)` privilege; `None` when no
    /// such list was given or the privilege is `ALL`.
    pub columns: Option<Vec<String>>,
    /// Table the privileges are granted on.
    pub table_name: String,
    /// Grantee the privileges are granted to.
    pub role_name: String,
}
127
/// Parsed `CREATE USER <name> WITH ROLE <role>` statement.
#[derive(Debug, Clone)]
pub struct ParsedCreateUser {
    /// User name exactly as written in the statement.
    pub username: String,
    /// Role name exactly as written in the statement.
    pub role: String,
}
136
/// Parsed `ALTER TABLE <table> MODIFY COLUMN <column> SET CLASSIFICATION '<class>'`.
#[derive(Debug, Clone)]
pub struct ParsedSetClassification {
    /// Table that owns the column.
    pub table_name: String,
    /// Column being classified.
    pub column_name: String,
    /// Classification label with the surrounding single quotes removed.
    pub classification: String,
}
147
/// Parsed `CREATE MASK <name> ON <table>.<column> USING <strategy>` statement.
#[derive(Debug, Clone)]
pub struct ParsedCreateMask {
    /// Name of the mask.
    pub mask_name: String,
    /// Part of `<table>.<column>` before the first `.`.
    pub table_name: String,
    /// Part of `<table>.<column>` after the first `.`.
    pub column_name: String,
    /// Strategy keyword, upper-cased by the parser (not validated here).
    pub strategy: String,
}
160
/// Masking strategy named in `CREATE MASKING POLICY ... STRATEGY <kind> [<arg>]`.
///
/// The keyword accepted by the parser for each variant is noted on the variant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParsedMaskingStrategy {
    /// `REDACT_SSN`
    RedactSsn,
    /// `REDACT_PHONE`
    RedactPhone,
    /// `REDACT_EMAIL`
    RedactEmail,
    /// `REDACT_CC`
    RedactCreditCard,
    /// `REDACT_CUSTOM '<replacement>'`
    RedactCustom {
        /// Replacement text with the surrounding single quotes removed.
        replacement: String,
    },
    /// `HASH`
    Hash,
    /// `TOKENIZE`
    Tokenize,
    /// `TRUNCATE <n>`
    Truncate {
        /// Character count; the parser rejects `0`.
        max_chars: usize,
    },
    /// `NULL`
    Null,
}
191
/// Parsed `CREATE MASKING POLICY <name> STRATEGY <kind> [<arg>] EXEMPT ROLES (<r>, ...)`.
#[derive(Debug, Clone)]
pub struct ParsedCreateMaskingPolicy {
    /// Policy name.
    pub name: String,
    /// Masking strategy and its argument, if any.
    pub strategy: ParsedMaskingStrategy,
    /// Roles listed after `EXEMPT ROLES`; lower-cased, never empty when
    /// produced by the parser.
    pub exempt_roles: Vec<String>,
}
207
/// Parsed `ALTER TABLE <t> ALTER COLUMN <c> SET MASKING POLICY <name>`.
#[derive(Debug, Clone)]
// All three fields legitimately end in `_name`, hence the lint exemption.
#[allow(clippy::struct_field_names)] pub struct ParsedAttachMaskingPolicy {
    /// Table that owns the column.
    pub table_name: String,
    /// Column the policy is attached to.
    pub column_name: String,
    /// Name of the policy being attached.
    pub policy_name: String,
}
219
/// Parsed `ALTER TABLE <t> ALTER COLUMN <c> DROP MASKING POLICY`.
#[derive(Debug, Clone)]
pub struct ParsedDetachMaskingPolicy {
    /// Table that owns the column.
    pub table_name: String,
    /// Column the policy is removed from.
    pub column_name: String,
}
228
/// Set operator joining the two arms of a [`ParsedUnion`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SetOp {
    /// `UNION`
    Union,
    /// `INTERSECT`
    Intersect,
    /// `EXCEPT` (the parser also maps `MINUS` here).
    Except,
}
239
/// Parsed set operation over exactly two `SELECT` arms.
///
/// Despite the name it covers every [`SetOp`], not only `UNION`.
#[derive(Debug, Clone)]
pub struct ParsedUnion {
    /// Which set operator was used.
    pub op: SetOp,
    /// Left-hand `SELECT`.
    pub left: ParsedSelect,
    /// Right-hand `SELECT`.
    pub right: ParsedSelect,
    /// `true` when the operator carried the `ALL` quantifier.
    pub all: bool,
}
253
/// Kind of join recorded on a [`ParsedJoin`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JoinType {
    /// `INNER JOIN` / plain `JOIN`.
    Inner,
    /// `LEFT [OUTER] JOIN`.
    Left,
    /// `RIGHT [OUTER] JOIN`.
    Right,
    /// `FULL OUTER JOIN`.
    Full,
    /// `CROSS JOIN`.
    Cross,
}
268
/// One join clause of a `SELECT`.
#[derive(Debug, Clone)]
pub struct ParsedJoin {
    /// Name of the joined table.
    pub table: String,
    /// Kind of join.
    pub join_type: JoinType,
    /// Predicates from the join condition.
    // NOTE(review): presumably AND-ed together and empty for CROSS JOIN — confirm
    // against the join parser, which is outside this view.
    pub on_condition: Vec<Predicate>,
}
279
/// One common table expression attached to a `SELECT`.
#[derive(Debug, Clone)]
pub struct ParsedCte {
    /// Name the CTE is referenced by.
    pub name: String,
    /// Query that defines the CTE.
    pub query: ParsedSelect,
    /// Second arm of the CTE, when present.
    // NOTE(review): presumably the recursive member of a `WITH RECURSIVE` CTE —
    // confirm against `parse_ctes`, which is outside this view.
    pub recursive_arm: Option<ParsedSelect>,
}
291
/// A projected column computed from `WHEN` arms plus a fallback value
/// (stored in `ParsedSelect::case_columns`).
#[derive(Debug, Clone)]
pub struct ComputedColumn {
    /// Output name of the computed column.
    pub alias: ColumnName,
    /// `WHEN ... THEN ...` arms.
    // NOTE(review): presumably evaluated in order, first match wins — confirm in the executor.
    pub when_clauses: Vec<CaseWhenArm>,
    /// Value used when no arm applies.
    pub else_value: Value,
}
304
/// A single `WHEN <condition> THEN <result>` arm of a [`ComputedColumn`].
#[derive(Debug, Clone)]
pub struct CaseWhenArm {
    /// Predicates forming the arm's condition.
    pub condition: Vec<Predicate>,
    /// Value produced by the arm.
    pub result: Value,
}
313
/// Operand of a `LIMIT` or `OFFSET` clause.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LimitExpr {
    /// A literal row count.
    Literal(usize),
    /// A bound parameter.
    // NOTE(review): presumably the placeholder index (`$n`) — confirm against `parse_limit`.
    Param(usize),
}
328
/// Parsed `SELECT` statement.
///
/// NOTE(review): the code that fills this struct (`parse_select`) is outside
/// this view; field descriptions below are inferred from names and types
/// unless stated otherwise — confirm before relying on them.
#[derive(Debug, Clone)]
pub struct ParsedSelect {
    /// Primary table of the query.
    pub table: String,
    /// Join clauses.
    pub joins: Vec<ParsedJoin>,
    /// Projected columns; presumably `None` means `SELECT *`.
    pub columns: Option<Vec<ColumnName>>,
    /// Optional alias per projected column; presumably parallel to `columns`.
    pub column_aliases: Option<Vec<Option<String>>>,
    /// `CASE WHEN` computed columns.
    pub case_columns: Vec<ComputedColumn>,
    /// `WHERE` predicates.
    pub predicates: Vec<Predicate>,
    /// `ORDER BY` keys; filled from the enclosing query by
    /// `parse_query_to_statement`.
    pub order_by: Vec<OrderByClause>,
    /// `LIMIT`, filled from the enclosing query by `parse_query_to_statement`.
    pub limit: Option<LimitExpr>,
    /// `OFFSET`, filled from the enclosing query by `parse_query_to_statement`.
    pub offset: Option<LimitExpr>,
    /// Aggregate functions in the projection.
    pub aggregates: Vec<AggregateFunction>,
    /// Optional `FILTER (WHERE ...)` predicates; presumably parallel to `aggregates`.
    pub aggregate_filters: Vec<Option<Vec<Predicate>>>,
    /// `GROUP BY` columns.
    pub group_by: Vec<ColumnName>,
    /// Whether `DISTINCT` was requested.
    pub distinct: bool,
    /// `HAVING` conditions.
    pub having: Vec<HavingCondition>,
    /// CTEs visible to the query; `parse_query_to_statement` puts the
    /// `WITH`-clause CTEs first, followed by those returned by `parse_select`.
    pub ctes: Vec<ParsedCte>,
    /// Window functions in the projection.
    pub window_fns: Vec<ParsedWindowFn>,
    /// Scalar expressions in the projection.
    pub scalar_projections: Vec<ParsedScalarProjection>,
}
385
/// A scalar expression appearing in a `SELECT` projection.
#[derive(Debug, Clone)]
pub struct ParsedScalarProjection {
    /// The expression to evaluate.
    pub expr: ScalarExpr,
    /// Name of the resulting output column.
    pub output_name: ColumnName,
    /// Alias attached to the projection, if any.
    // NOTE(review): relationship between `alias` and `output_name` is not
    // visible here — confirm against the projection parser.
    pub alias: Option<String>,
}
400
/// A window function call in a `SELECT` projection.
#[derive(Debug, Clone)]
pub struct ParsedWindowFn {
    /// Which window function is applied.
    pub function: crate::window::WindowFunction,
    /// `PARTITION BY` columns.
    pub partition_by: Vec<ColumnName>,
    /// `ORDER BY` keys of the window.
    pub order_by: Vec<OrderByClause>,
    /// Alias of the output column, if any.
    pub alias: Option<String>,
}
414
/// A condition in a `HAVING` clause.
#[derive(Debug, Clone)]
pub enum HavingCondition {
    /// `<aggregate> <op> <value>`.
    AggregateComparison {
        /// Aggregate on the left-hand side.
        aggregate: AggregateFunction,
        /// Comparison operator.
        op: HavingOp,
        /// Value on the right-hand side.
        value: Value,
    },
}
430
/// Comparison operators usable in a [`HavingCondition`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HavingOp {
    /// Equal.
    Eq,
    /// Less than.
    Lt,
    /// Less than or equal.
    Le,
    /// Greater than.
    Gt,
    /// Greater than or equal.
    Ge,
}
440
/// Parsed `CREATE TABLE` statement.
#[derive(Debug, Clone)]
pub struct ParsedCreateTable {
    /// Name of the table to create.
    pub table_name: String,
    /// Column definitions; the `NonEmptyVec` type guarantees at least one.
    pub columns: NonEmptyVec<ParsedColumn>,
    /// Names of the primary-key columns.
    pub primary_key: Vec<String>,
    /// Whether `IF NOT EXISTS` was specified.
    pub if_not_exists: bool,
}
457
/// A column definition from `CREATE TABLE` or `ALTER TABLE ... ADD COLUMN`.
#[derive(Debug, Clone)]
pub struct ParsedColumn {
    /// Column name.
    pub name: String,
    // `data_type`: the column's type kept as text.
    // `nullable`: whether the column accepts NULL.
    pub data_type: String, pub nullable: bool,
}
465
/// Parsed `ALTER TABLE` statement carrying a single operation.
#[derive(Debug, Clone)]
pub struct ParsedAlterTable {
    /// Table being altered.
    pub table_name: String,
    /// The change to apply.
    pub operation: AlterTableOperation,
}
472
/// Operations supported by [`ParsedAlterTable`].
#[derive(Debug, Clone)]
pub enum AlterTableOperation {
    /// Add the given column.
    AddColumn(ParsedColumn),
    /// Drop the column with the given name.
    DropColumn(String),
}
481
/// Parsed `CREATE INDEX` statement.
#[derive(Debug, Clone)]
pub struct ParsedCreateIndex {
    /// Name of the index.
    pub index_name: String,
    /// Table the index is built on.
    pub table_name: String,
    /// Indexed columns.
    pub columns: Vec<String>,
}
489
/// Parsed `INSERT` statement.
#[derive(Debug, Clone)]
pub struct ParsedInsert {
    /// Target table.
    pub table: String,
    /// Column names given in the statement.
    pub columns: Vec<String>,
    // `values`: one inner vector per inserted row.
    // `returning`: columns of the `RETURNING` clause, if present.
    // `on_conflict`: the `ON CONFLICT` clause, if present.
    pub values: Vec<Vec<Value>>, pub returning: Option<Vec<String>>, pub on_conflict: Option<OnConflictClause>,
}
503
/// `ON CONFLICT (<target>) <action>` clause of an `INSERT`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OnConflictClause {
    /// Conflict-target column names.
    pub target: Vec<String>,
    /// What to do when the target conflicts.
    pub action: OnConflictAction,
}
522
/// Action taken by an [`OnConflictClause`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OnConflictAction {
    /// `DO NOTHING`.
    DoNothing,
    /// `DO UPDATE SET ...`.
    DoUpdate {
        /// `(column, new value)` pairs from the `SET` list.
        assignments: Vec<(String, UpsertExpr)>,
    },
}
540
/// Right-hand side of an assignment in `ON CONFLICT ... DO UPDATE SET`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UpsertExpr {
    /// A plain value.
    Value(Value),
    /// A reference carrying a column name.
    // NOTE(review): presumably `EXCLUDED.<column>` (the row proposed for insertion) — confirm.
    Excluded(String),
}
555
/// Parsed `UPDATE` statement.
#[derive(Debug, Clone)]
pub struct ParsedUpdate {
    /// Target table.
    pub table: String,
    // `assignments`: `(column, new value)` pairs from the `SET` list.
    // `predicates`: predicates from the `WHERE` clause.
    pub assignments: Vec<(String, Value)>, pub predicates: Vec<Predicate>,
    /// Columns of the `RETURNING` clause, if present.
    pub returning: Option<Vec<String>>, }
564
/// Parsed `DELETE` statement.
#[derive(Debug, Clone)]
pub struct ParsedDelete {
    /// Target table.
    pub table: String,
    /// Predicates from the `WHERE` clause.
    pub predicates: Vec<Predicate>,
    /// Columns of the `RETURNING` clause, if present.
    pub returning: Option<Vec<String>>, }
572
/// Aggregate functions recognised in projections and `HAVING`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AggregateFunction {
    /// `COUNT(*)`
    CountStar,
    /// `COUNT(<column>)`
    Count(ColumnName),
    /// `SUM(<column>)`
    Sum(ColumnName),
    /// `AVG(<column>)`
    Avg(ColumnName),
    /// `MIN(<column>)`
    Min(ColumnName),
    /// `MAX(<column>)`
    Max(ColumnName),
}
589
/// A single filter condition.
///
/// Most variants are anchored on one column (returned by `Predicate::column`);
/// `Exists`, `Always`, `Or` and `ScalarCmp` are not.
#[derive(Debug, Clone)]
pub enum Predicate {
    /// `column = value`
    Eq(ColumnName, PredicateValue),
    /// `column < value`
    Lt(ColumnName, PredicateValue),
    /// `column <= value`
    Le(ColumnName, PredicateValue),
    /// `column > value`
    Gt(ColumnName, PredicateValue),
    /// `column >= value`
    Ge(ColumnName, PredicateValue),
    /// `column IN (values...)`
    In(ColumnName, Vec<PredicateValue>),
    /// `column NOT IN (values...)`
    NotIn(ColumnName, Vec<PredicateValue>),
    /// `column NOT BETWEEN low AND high`
    NotBetween(ColumnName, PredicateValue, PredicateValue),
    /// `column LIKE pattern`
    Like(ColumnName, String),
    /// `column NOT LIKE pattern`
    NotLike(ColumnName, String),
    /// `column ILIKE pattern`
    ILike(ColumnName, String),
    /// `column NOT ILIKE pattern`
    NotILike(ColumnName, String),
    /// `column IS NULL`
    IsNull(ColumnName),
    /// `column IS NOT NULL`
    IsNotNull(ColumnName),
    /// Equality test on a value extracted from a JSON column.
    JsonExtractEq {
        /// JSON column.
        column: ColumnName,
        /// Path/key of the extracted value.
        path: String,
        /// Whether the value is extracted as text.
        // NOTE(review): presumably distinguishes `->>` from `->` — confirm.
        as_text: bool,
        /// Value compared against.
        value: PredicateValue,
    },
    /// JSON containment test on a column.
    JsonContains {
        /// JSON column.
        column: ColumnName,
        /// Value that must be contained.
        value: PredicateValue,
    },
    /// `column [NOT] IN (<subquery>)`
    InSubquery {
        /// Column on the left-hand side.
        column: ColumnName,
        /// The subquery.
        subquery: Box<ParsedSelect>,
        /// `true` for `NOT IN`.
        negated: bool,
    },
    /// `[NOT] EXISTS (<subquery>)`
    Exists {
        /// The subquery.
        subquery: Box<ParsedSelect>,
        /// `true` for `NOT EXISTS`.
        negated: bool,
    },
    /// Constant truth value.
    Always(bool),
    /// Disjunction of two predicate lists.
    // NOTE(review): presumably each side is an AND-ed conjunction — confirm in the evaluator.
    Or(Vec<Predicate>, Vec<Predicate>),
    /// Comparison between two scalar expressions.
    ScalarCmp {
        /// Left operand.
        lhs: ScalarExpr,
        /// Comparison operator.
        op: ScalarCmpOp,
        /// Right operand.
        rhs: ScalarExpr,
    },
}
685
/// Comparison operators for `Predicate::ScalarCmp`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScalarCmpOp {
    /// Equal.
    Eq,
    /// Not equal.
    NotEq,
    /// Less than.
    Lt,
    /// Less than or equal.
    Le,
    /// Greater than.
    Gt,
    /// Greater than or equal.
    Ge,
}
696
697impl Predicate {
698 #[allow(dead_code)]
702 pub fn column(&self) -> Option<&ColumnName> {
703 match self {
704 Predicate::Eq(col, _)
705 | Predicate::Lt(col, _)
706 | Predicate::Le(col, _)
707 | Predicate::Gt(col, _)
708 | Predicate::Ge(col, _)
709 | Predicate::In(col, _)
710 | Predicate::NotIn(col, _)
711 | Predicate::NotBetween(col, _, _)
712 | Predicate::Like(col, _)
713 | Predicate::NotLike(col, _)
714 | Predicate::ILike(col, _)
715 | Predicate::NotILike(col, _)
716 | Predicate::IsNull(col)
717 | Predicate::IsNotNull(col)
718 | Predicate::JsonExtractEq { column: col, .. }
719 | Predicate::JsonContains { column: col, .. }
720 | Predicate::InSubquery { column: col, .. } => Some(col),
721 Predicate::Or(_, _)
722 | Predicate::Exists { .. }
723 | Predicate::Always(_)
724 | Predicate::ScalarCmp { .. } => None,
725 }
726 }
727}
728
/// Right-hand-side operand of a [`Predicate`].
#[derive(Debug, Clone)]
pub enum PredicateValue {
    /// Integer literal.
    Int(i64),
    /// String literal.
    String(String),
    /// Boolean literal.
    Bool(bool),
    /// SQL `NULL`.
    Null,
    /// Bound parameter.
    // NOTE(review): presumably the placeholder index (`$n`) — confirm.
    Param(usize),
    /// Any other literal, carried as a full `Value`.
    Literal(Value),
    /// Reference to another column, by name.
    ColumnRef(String),
}
748
/// One sort key of an `ORDER BY` clause.
#[derive(Debug, Clone)]
pub struct OrderByClause {
    /// Column to sort by.
    pub column: ColumnName,
    /// `true` for ascending order, `false` for descending.
    pub ascending: bool,
}
757
758pub fn parse_statement(sql: &str) -> Result<ParsedStatement> {
764 crate::depth_check::check_sql_depth(sql)?;
765
766 if let Some(parsed) = try_parse_custom_statement(sql)? {
768 return Ok(parsed);
769 }
770
771 let dialect = KimberliteDialect::new();
772 let statements =
773 Parser::parse_sql(&dialect, sql).map_err(|e| QueryError::ParseError(e.to_string()))?;
774
775 if statements.len() != 1 {
776 return Err(QueryError::ParseError(format!(
777 "expected exactly 1 statement, got {}",
778 statements.len()
779 )));
780 }
781
782 match &statements[0] {
783 Statement::Query(query) => parse_query_to_statement(query),
784 Statement::CreateTable(create_table) => {
785 let parsed = parse_create_table(create_table)?;
786 Ok(ParsedStatement::CreateTable(parsed))
787 }
788 Statement::Drop {
789 object_type,
790 names,
791 if_exists,
792 ..
793 } => {
794 if !matches!(object_type, sqlparser::ast::ObjectType::Table) {
795 return Err(QueryError::UnsupportedFeature(
796 "only DROP TABLE is supported".to_string(),
797 ));
798 }
799 if names.len() != 1 {
800 return Err(QueryError::ParseError(
801 "expected exactly 1 table in DROP TABLE".to_string(),
802 ));
803 }
804 let table_name = object_name_to_string(&names[0]);
805 Ok(ParsedStatement::DropTable {
806 name: table_name,
807 if_exists: *if_exists,
808 })
809 }
810 Statement::CreateIndex(create_index) => {
811 let parsed = parse_create_index(create_index)?;
812 Ok(ParsedStatement::CreateIndex(parsed))
813 }
814 Statement::Insert(insert) => {
815 let parsed = parse_insert(insert)?;
816 Ok(ParsedStatement::Insert(parsed))
817 }
818 Statement::Update(update) => {
819 let parsed = parse_update(
820 &update.table,
821 &update.assignments,
822 update.selection.as_ref(),
823 update.returning.as_ref(),
824 )?;
825 Ok(ParsedStatement::Update(parsed))
826 }
827 Statement::Delete(delete) => {
828 let parsed = parse_delete_stmt(delete)?;
829 Ok(ParsedStatement::Delete(parsed))
830 }
831 Statement::AlterTable(alter_table) => {
832 let parsed = parse_alter_table(&alter_table.name, &alter_table.operations)?;
833 Ok(ParsedStatement::AlterTable(parsed))
834 }
835 Statement::CreateRole(create_role) => {
836 if create_role.names.len() != 1 {
837 return Err(QueryError::ParseError(
838 "expected exactly 1 role name".to_string(),
839 ));
840 }
841 let role_name = object_name_to_string(&create_role.names[0]);
842 Ok(ParsedStatement::CreateRole(role_name))
843 }
844 Statement::Grant(grant) => {
845 let objects = grant.objects.as_ref().ok_or_else(|| {
846 QueryError::ParseError(
847 "GRANT requires an ON clause specifying the target objects".to_string(),
848 )
849 })?;
850 parse_grant(&grant.privileges, objects, &grant.grantees)
851 }
852 other => Err(QueryError::UnsupportedFeature(format!(
853 "statement type not supported: {other:?}"
854 ))),
855 }
856}
857
858pub fn try_parse_custom_statement(sql: &str) -> Result<Option<ParsedStatement>> {
874 let trimmed = sql.trim().trim_end_matches(';').trim();
875 let upper = trimmed.to_ascii_uppercase();
876
877 if upper.starts_with("CREATE MASKING POLICY") {
879 return parse_create_masking_policy(trimmed).map(Some);
880 }
881
882 if upper.starts_with("DROP MASKING POLICY") {
884 let tokens: Vec<&str> = trimmed.split_whitespace().collect();
885 if tokens.len() != 4 {
887 return Err(QueryError::ParseError(
888 "expected: DROP MASKING POLICY <name>".to_string(),
889 ));
890 }
891 return Ok(Some(ParsedStatement::DropMaskingPolicy(
892 tokens[3].to_string(),
893 )));
894 }
895
896 if upper.starts_with("ALTER TABLE") && upper.contains("MASKING POLICY") {
898 return parse_alter_masking_policy(trimmed).map(Some);
899 }
900
901 if upper.starts_with("CREATE MASK") {
903 let tokens: Vec<&str> = trimmed.split_whitespace().collect();
904 if tokens.len() != 7 {
906 return Err(QueryError::ParseError(
907 "expected: CREATE MASK <name> ON <table>.<column> USING <strategy>".to_string(),
908 ));
909 }
910 if !tokens[3].eq_ignore_ascii_case("ON") {
911 return Err(QueryError::ParseError(format!(
912 "expected ON after mask name, got '{}'",
913 tokens[3]
914 )));
915 }
916 if !tokens[5].eq_ignore_ascii_case("USING") {
917 return Err(QueryError::ParseError(format!(
918 "expected USING after column reference, got '{}'",
919 tokens[5]
920 )));
921 }
922
923 let table_col = tokens[4];
925 let dot_pos = table_col.find('.').ok_or_else(|| {
926 QueryError::ParseError(format!(
927 "expected <table>.<column> but got '{table_col}' (missing '.')"
928 ))
929 })?;
930 let table_name = table_col[..dot_pos].to_string();
931 let column_name = table_col[dot_pos + 1..].to_string();
932
933 if table_name.is_empty() || column_name.is_empty() {
934 return Err(QueryError::ParseError(
935 "table name and column name must not be empty".to_string(),
936 ));
937 }
938
939 let strategy = tokens[6].to_ascii_uppercase();
940
941 return Ok(Some(ParsedStatement::CreateMask(ParsedCreateMask {
942 mask_name: tokens[2].to_string(),
943 table_name,
944 column_name,
945 strategy,
946 })));
947 }
948
949 if upper.starts_with("DROP MASK") {
951 let tokens: Vec<&str> = trimmed.split_whitespace().collect();
952 if tokens.len() != 3 {
953 return Err(QueryError::ParseError(
954 "expected: DROP MASK <name>".to_string(),
955 ));
956 }
957 return Ok(Some(ParsedStatement::DropMask(tokens[2].to_string())));
958 }
959
960 if upper.starts_with("ALTER TABLE") && upper.contains("SET CLASSIFICATION") {
962 return parse_set_classification(trimmed);
963 }
964
965 if upper.starts_with("SHOW CLASSIFICATIONS") {
967 let tokens: Vec<&str> = trimmed.split_whitespace().collect();
968 if tokens.len() != 4 {
970 return Err(QueryError::ParseError(
971 "expected: SHOW CLASSIFICATIONS FOR <table>".to_string(),
972 ));
973 }
974 if !tokens[2].eq_ignore_ascii_case("FOR") {
975 return Err(QueryError::ParseError(format!(
976 "expected FOR after CLASSIFICATIONS, got '{}'",
977 tokens[2]
978 )));
979 }
980 return Ok(Some(ParsedStatement::ShowClassifications(
981 tokens[3].to_string(),
982 )));
983 }
984
985 if upper == "SHOW TABLES" {
987 return Ok(Some(ParsedStatement::ShowTables));
988 }
989
990 if upper.starts_with("SHOW COLUMNS") {
992 let tokens: Vec<&str> = trimmed.split_whitespace().collect();
993 if tokens.len() != 4 {
995 return Err(QueryError::ParseError(
996 "expected: SHOW COLUMNS FROM <table>".to_string(),
997 ));
998 }
999 if !tokens[2].eq_ignore_ascii_case("FROM") {
1000 return Err(QueryError::ParseError(format!(
1001 "expected FROM after COLUMNS, got '{}'",
1002 tokens[2]
1003 )));
1004 }
1005 return Ok(Some(ParsedStatement::ShowColumns(tokens[3].to_string())));
1006 }
1007
1008 if upper.starts_with("CREATE USER") {
1010 let tokens: Vec<&str> = trimmed.split_whitespace().collect();
1011 if tokens.len() != 6 {
1013 return Err(QueryError::ParseError(
1014 "expected: CREATE USER <name> WITH ROLE <role>".to_string(),
1015 ));
1016 }
1017 if !tokens[3].eq_ignore_ascii_case("WITH") {
1018 return Err(QueryError::ParseError(format!(
1019 "expected WITH after username, got '{}'",
1020 tokens[3]
1021 )));
1022 }
1023 if !tokens[4].eq_ignore_ascii_case("ROLE") {
1024 return Err(QueryError::ParseError(format!(
1025 "expected ROLE after WITH, got '{}'",
1026 tokens[4]
1027 )));
1028 }
1029 return Ok(Some(ParsedStatement::CreateUser(ParsedCreateUser {
1030 username: tokens[2].to_string(),
1031 role: tokens[5].to_string(),
1032 })));
1033 }
1034
1035 Ok(None)
1036}
1037
1038fn parse_create_masking_policy(trimmed: &str) -> Result<ParsedStatement> {
1046 let after_keyword = trimmed
1050 .get("CREATE MASKING POLICY".len()..)
1051 .ok_or_else(|| QueryError::ParseError("missing policy body".to_string()))?
1052 .trim_start();
1053
1054 let upper_body = after_keyword.to_ascii_uppercase();
1057 let exempt_pos = upper_body.find("EXEMPT ROLES").ok_or_else(|| {
1058 QueryError::ParseError(
1059 "expected: CREATE MASKING POLICY <name> STRATEGY <kind> [<arg>] EXEMPT ROLES (<r>, ...)"
1060 .to_string(),
1061 )
1062 })?;
1063
1064 let header = after_keyword[..exempt_pos].trim();
1065 let exempt_tail = after_keyword[exempt_pos + "EXEMPT ROLES".len()..].trim();
1066
1067 let (name, strategy) = parse_masking_policy_header(header)?;
1068 let exempt_roles = parse_exempt_roles_list(exempt_tail)?;
1069
1070 Ok(ParsedStatement::CreateMaskingPolicy(
1071 ParsedCreateMaskingPolicy {
1072 name,
1073 strategy,
1074 exempt_roles,
1075 },
1076 ))
1077}
1078
1079fn parse_masking_policy_header(header: &str) -> Result<(String, ParsedMaskingStrategy)> {
1081 let tokens: Vec<&str> = header.split_whitespace().collect();
1082 if tokens.len() < 3 {
1083 return Err(QueryError::ParseError(
1084 "expected: <name> STRATEGY <kind> [<arg>]".to_string(),
1085 ));
1086 }
1087 let name = tokens[0].to_string();
1088 if name.is_empty() {
1089 return Err(QueryError::ParseError(
1090 "policy name must not be empty".to_string(),
1091 ));
1092 }
1093 if !tokens[1].eq_ignore_ascii_case("STRATEGY") {
1094 return Err(QueryError::ParseError(format!(
1095 "expected STRATEGY after policy name, got '{}'",
1096 tokens[1]
1097 )));
1098 }
1099
1100 let strategy = parse_masking_strategy(&tokens[2..])?;
1101 Ok((name, strategy))
1102}
1103
1104fn parse_masking_strategy(tokens: &[&str]) -> Result<ParsedMaskingStrategy> {
1106 debug_assert!(
1107 !tokens.is_empty(),
1108 "caller must pass at least the strategy keyword"
1109 );
1110 let kind = tokens[0].to_ascii_uppercase();
1111 match kind.as_str() {
1112 "REDACT_SSN" => {
1113 expect_no_strategy_arg(tokens, "REDACT_SSN").map(|()| ParsedMaskingStrategy::RedactSsn)
1114 }
1115 "REDACT_PHONE" => expect_no_strategy_arg(tokens, "REDACT_PHONE")
1116 .map(|()| ParsedMaskingStrategy::RedactPhone),
1117 "REDACT_EMAIL" => expect_no_strategy_arg(tokens, "REDACT_EMAIL")
1118 .map(|()| ParsedMaskingStrategy::RedactEmail),
1119 "REDACT_CC" => expect_no_strategy_arg(tokens, "REDACT_CC")
1120 .map(|()| ParsedMaskingStrategy::RedactCreditCard),
1121 "REDACT_CUSTOM" => {
1122 if tokens.len() != 2 {
1123 return Err(QueryError::ParseError(
1124 "REDACT_CUSTOM requires a single quoted replacement string".to_string(),
1125 ));
1126 }
1127 let replacement = unquote_string_literal(tokens[1]).ok_or_else(|| {
1128 QueryError::ParseError(
1129 "REDACT_CUSTOM replacement must be a single-quoted string".to_string(),
1130 )
1131 })?;
1132 Ok(ParsedMaskingStrategy::RedactCustom { replacement })
1133 }
1134 "HASH" => {
1135 expect_no_strategy_arg(tokens, "HASH")?;
1136 Ok(ParsedMaskingStrategy::Hash)
1137 }
1138 "TOKENIZE" => {
1139 expect_no_strategy_arg(tokens, "TOKENIZE")?;
1140 Ok(ParsedMaskingStrategy::Tokenize)
1141 }
1142 "TRUNCATE" => {
1143 if tokens.len() != 2 {
1144 return Err(QueryError::ParseError(
1145 "TRUNCATE requires a positive integer character count".to_string(),
1146 ));
1147 }
1148 let max_chars = tokens[1].parse::<usize>().map_err(|_| {
1149 QueryError::ParseError(format!(
1150 "TRUNCATE argument must be a non-negative integer, got '{}'",
1151 tokens[1]
1152 ))
1153 })?;
1154 if max_chars == 0 {
1155 return Err(QueryError::ParseError(
1156 "TRUNCATE character count must be > 0".to_string(),
1157 ));
1158 }
1159 Ok(ParsedMaskingStrategy::Truncate { max_chars })
1160 }
1161 "NULL" => {
1162 expect_no_strategy_arg(tokens, "NULL")?;
1163 Ok(ParsedMaskingStrategy::Null)
1164 }
1165 _ => Err(QueryError::ParseError(format!(
1166 "unknown masking strategy '{kind}' — expected one of REDACT_SSN, REDACT_PHONE, \
1167 REDACT_EMAIL, REDACT_CC, REDACT_CUSTOM, HASH, TOKENIZE, TRUNCATE, NULL"
1168 ))),
1169 }
1170}
1171
1172fn expect_no_strategy_arg(tokens: &[&str], kind: &str) -> Result<()> {
1173 if tokens.len() != 1 {
1174 return Err(QueryError::ParseError(format!(
1175 "{kind} takes no arguments (found {} extra token(s))",
1176 tokens.len() - 1
1177 )));
1178 }
1179 Ok(())
1180}
1181
/// Strips one pair of enclosing single quotes from `token`.
///
/// Returns `None` unless the token both starts and ends with `'` (as two
/// distinct characters). The text between the quotes is returned as-is; no
/// escape sequences are interpreted.
fn unquote_string_literal(token: &str) -> Option<String> {
    token
        .strip_prefix('\'')
        .and_then(|rest| rest.strip_suffix('\''))
        .map(str::to_owned)
}
1191
1192fn parse_exempt_roles_list(tail: &str) -> Result<Vec<String>> {
1199 let trimmed = tail.trim();
1200 if !trimmed.starts_with('(') || !trimmed.ends_with(')') {
1201 return Err(QueryError::ParseError(
1202 "EXEMPT ROLES must be followed by a parenthesised list: EXEMPT ROLES (r1, r2, ...)"
1203 .to_string(),
1204 ));
1205 }
1206 let inner = &trimmed[1..trimmed.len() - 1];
1207 let roles: Vec<String> = inner
1208 .split(',')
1209 .map(|s| s.trim().trim_matches('\'').to_ascii_lowercase())
1210 .filter(|s| !s.is_empty())
1211 .collect();
1212 if roles.is_empty() {
1213 return Err(QueryError::ParseError(
1214 "EXEMPT ROLES list must contain at least one role".to_string(),
1215 ));
1216 }
1217 Ok(roles)
1218}
1219
1220fn parse_alter_masking_policy(trimmed: &str) -> Result<ParsedStatement> {
1222 let tokens: Vec<&str> = trimmed.split_whitespace().collect();
1223 if tokens.len() < 9 || tokens.len() > 10 {
1226 return Err(QueryError::ParseError(
1227 "expected: ALTER TABLE <t> ALTER COLUMN <c> { SET | DROP } MASKING POLICY [<name>]"
1228 .to_string(),
1229 ));
1230 }
1231 if !tokens[0].eq_ignore_ascii_case("ALTER")
1232 || !tokens[1].eq_ignore_ascii_case("TABLE")
1233 || !tokens[3].eq_ignore_ascii_case("ALTER")
1234 || !tokens[4].eq_ignore_ascii_case("COLUMN")
1235 || !tokens[7].eq_ignore_ascii_case("MASKING")
1236 || !tokens[8].eq_ignore_ascii_case("POLICY")
1237 {
1238 return Err(QueryError::ParseError(format!(
1239 "malformed ALTER ... MASKING POLICY statement: '{trimmed}'"
1240 )));
1241 }
1242 let table_name = tokens[2].to_string();
1243 let column_name = tokens[5].to_string();
1244 let action = tokens[6].to_ascii_uppercase();
1245 match action.as_str() {
1246 "SET" => {
1247 if tokens.len() != 10 {
1248 return Err(QueryError::ParseError(
1249 "SET MASKING POLICY requires a policy name".to_string(),
1250 ));
1251 }
1252 Ok(ParsedStatement::AttachMaskingPolicy(
1253 ParsedAttachMaskingPolicy {
1254 table_name,
1255 column_name,
1256 policy_name: tokens[9].to_string(),
1257 },
1258 ))
1259 }
1260 "DROP" => {
1261 if tokens.len() != 9 {
1262 return Err(QueryError::ParseError(
1263 "DROP MASKING POLICY takes no arguments after POLICY".to_string(),
1264 ));
1265 }
1266 Ok(ParsedStatement::DetachMaskingPolicy(
1267 ParsedDetachMaskingPolicy {
1268 table_name,
1269 column_name,
1270 },
1271 ))
1272 }
1273 _ => Err(QueryError::ParseError(format!(
1274 "expected SET or DROP after column name, got '{action}'"
1275 ))),
1276 }
1277}
1278
/// Point in history selected by a time-travel suffix on a query
/// (see `extract_time_travel`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TimeTravel {
    /// Value of an `AT OFFSET <n>` suffix.
    Offset(u64),
    /// Timestamp of an `AS OF '<rfc3339>'` / `FOR SYSTEM_TIME AS OF '<rfc3339>'`
    /// suffix, in nanoseconds since the Unix epoch.
    TimestampNs(i64),
}
1303
/// Splits a trailing `AT OFFSET <n>` clause off a SQL string.
///
/// The keyword is matched case-insensitively at its last occurrence. It must
/// be preceded by a space, tab, CR or LF (or sit at the very start), be
/// followed by a decimal number that fits in `u64`, and nothing but optional
/// whitespace and a single `;` may follow the number.
///
/// Returns `(sql_without_clause, Some(n))` on success, with trailing
/// whitespace before the clause removed. In every other case the input is
/// returned unchanged together with `None`.
pub fn extract_at_offset(sql: &str) -> (String, Option<u64>) {
    const KEYWORD: &str = "AT OFFSET";
    let unchanged = || (sql.to_string(), None::<u64>);

    // ASCII upper-casing keeps byte offsets, so the position is valid in `sql`.
    let Some(keyword_start) = sql.to_ascii_uppercase().rfind(KEYWORD) else {
        return unchanged();
    };

    // Word-boundary check: rejects matches inside e.g. `float OFFSET`.
    let head = &sql[..keyword_start];
    if let Some(&prev) = head.as_bytes().last() {
        if !matches!(prev, b' ' | b'\t' | b'\n' | b'\r') {
            return unchanged();
        }
    }

    let tail = sql[keyword_start + KEYWORD.len()..].trim_start();
    let digit_count = tail.bytes().take_while(u8::is_ascii_digit).count();
    let (digits, rest) = tail.split_at(digit_count);
    if digits.is_empty() {
        return unchanged();
    }
    // Fails only on overflow, since `digits` is all ASCII digits.
    let Ok(offset) = digits.parse::<u64>() else {
        return unchanged();
    };

    match rest.trim() {
        "" | ";" => (head.trim_end().to_string(), Some(offset)),
        _ => unchanged(),
    }
}
1376
1377pub fn extract_time_travel(sql: &str) -> (String, Option<TimeTravel>) {
1405 let (after_offset_sql, offset) = extract_at_offset(sql);
1407 if let Some(o) = offset {
1408 return (after_offset_sql, Some(TimeTravel::Offset(o)));
1409 }
1410
1411 let upper = sql.to_ascii_uppercase();
1417
1418 let (keyword_pos, keyword_len) = if let Some(p) = upper.rfind("FOR SYSTEM_TIME AS OF") {
1424 (p, "FOR SYSTEM_TIME AS OF".len())
1425 } else if let Some(p) = upper.rfind("AS OF") {
1426 let after = sql[p + "AS OF".len()..].trim_start();
1429 if !after.starts_with('\'') {
1430 return (sql.to_string(), None);
1431 }
1432 (p, "AS OF".len())
1433 } else {
1434 return (sql.to_string(), None);
1435 };
1436
1437 if keyword_pos > 0 {
1439 let prev = sql.as_bytes()[keyword_pos - 1];
1440 if !matches!(prev, b' ' | b'\t' | b'\n' | b'\r') {
1441 return (sql.to_string(), None);
1442 }
1443 }
1444
1445 let after_keyword = sql[keyword_pos + keyword_len..].trim_start();
1446 if !after_keyword.starts_with('\'') {
1447 return (sql.to_string(), None);
1448 }
1449
1450 let ts_start = 1; let ts_end = match after_keyword[1..].find('\'') {
1455 Some(i) => i + 1,
1456 None => return (sql.to_string(), None),
1457 };
1458 let ts_str = &after_keyword[ts_start..ts_end];
1459
1460 let ts_ns = match chrono::DateTime::parse_from_rfc3339(ts_str) {
1462 Ok(dt) => dt.timestamp_nanos_opt(),
1463 Err(_) => return (sql.to_string(), None),
1464 };
1465 let ts_ns = match ts_ns {
1466 Some(n) => n,
1467 None => return (sql.to_string(), None),
1468 };
1469
1470 let remainder = after_keyword[ts_end + 1..].trim();
1472 if !remainder.is_empty() && remainder != ";" {
1473 return (sql.to_string(), None);
1474 }
1475
1476 let before = sql[..keyword_pos].trim_end();
1477 (before.to_string(), Some(TimeTravel::TimestampNs(ts_ns)))
1478}
1479
1480fn parse_set_classification(sql: &str) -> Result<Option<ParsedStatement>> {
1485 let tokens: Vec<&str> = sql.split_whitespace().collect();
1486 if tokens.len() != 9 {
1489 return Err(QueryError::ParseError(
1490 "expected: ALTER TABLE <table> MODIFY COLUMN <column> SET CLASSIFICATION '<class>'"
1491 .to_string(),
1492 ));
1493 }
1494
1495 if !tokens[3].eq_ignore_ascii_case("MODIFY") {
1496 return Err(QueryError::ParseError(format!(
1497 "expected MODIFY, got '{}'",
1498 tokens[3]
1499 )));
1500 }
1501 if !tokens[4].eq_ignore_ascii_case("COLUMN") {
1502 return Err(QueryError::ParseError(format!(
1503 "expected COLUMN after MODIFY, got '{}'",
1504 tokens[4]
1505 )));
1506 }
1507 if !tokens[6].eq_ignore_ascii_case("SET") {
1508 return Err(QueryError::ParseError(format!(
1509 "expected SET, got '{}'",
1510 tokens[6]
1511 )));
1512 }
1513 if !tokens[7].eq_ignore_ascii_case("CLASSIFICATION") {
1514 return Err(QueryError::ParseError(format!(
1515 "expected CLASSIFICATION, got '{}'",
1516 tokens[7]
1517 )));
1518 }
1519
1520 let table_name = tokens[2].to_string();
1521 let column_name = tokens[5].to_string();
1522
1523 let raw_class = tokens[8];
1525 let classification = raw_class
1526 .strip_prefix('\'')
1527 .and_then(|s| s.strip_suffix('\''))
1528 .ok_or_else(|| {
1529 QueryError::ParseError(format!(
1530 "classification must be quoted with single quotes, got '{raw_class}'"
1531 ))
1532 })?
1533 .to_string();
1534
1535 assert!(!table_name.is_empty(), "table name must not be empty");
1536 assert!(!column_name.is_empty(), "column name must not be empty");
1537 assert!(
1538 !classification.is_empty(),
1539 "classification must not be empty"
1540 );
1541
1542 Ok(Some(ParsedStatement::SetClassification(
1543 ParsedSetClassification {
1544 table_name,
1545 column_name,
1546 classification,
1547 },
1548 )))
1549}
1550
1551fn parse_grant(
1553 privileges: &sqlparser::ast::Privileges,
1554 objects: &sqlparser::ast::GrantObjects,
1555 grantees: &[sqlparser::ast::Grantee],
1556) -> Result<ParsedStatement> {
1557 use sqlparser::ast::{Action, GrantObjects, GranteeName, Privileges};
1558
1559 let columns = match privileges {
1561 Privileges::Actions(actions) => {
1562 let mut cols = None;
1563 for action in actions {
1564 if let Action::Select { columns: Some(c) } = action {
1565 cols = Some(c.iter().map(|i| i.value.clone()).collect());
1566 }
1567 }
1568 cols
1569 }
1570 Privileges::All { .. } => None,
1571 };
1572
1573 let table_name = match objects {
1575 GrantObjects::Tables(tables) => {
1576 if tables.len() != 1 {
1577 return Err(QueryError::ParseError(
1578 "expected exactly 1 table in GRANT".to_string(),
1579 ));
1580 }
1581 object_name_to_string(&tables[0])
1582 }
1583 _ => {
1584 return Err(QueryError::UnsupportedFeature(
1585 "GRANT only supports table-level privileges".to_string(),
1586 ));
1587 }
1588 };
1589
1590 if grantees.len() != 1 {
1592 return Err(QueryError::ParseError(
1593 "expected exactly 1 grantee in GRANT".to_string(),
1594 ));
1595 }
1596 let role_name = match &grantees[0].name {
1597 Some(GranteeName::ObjectName(name)) => object_name_to_string(name),
1598 _ => {
1599 return Err(QueryError::ParseError(
1600 "expected a role name in GRANT".to_string(),
1601 ));
1602 }
1603 };
1604
1605 Ok(ParsedStatement::Grant(ParsedGrant {
1606 columns,
1607 table_name,
1608 role_name,
1609 }))
1610}
1611
1612fn parse_query_to_statement(query: &Query) -> Result<ParsedStatement> {
1614 let ctes = match &query.with {
1616 Some(with) => parse_ctes(with)?,
1617 None => vec![],
1618 };
1619
1620 match query.body.as_ref() {
1621 SetExpr::Select(select) => {
1622 let parsed_select = parse_select(select)?;
1623
1624 let order_by = match &query.order_by {
1626 Some(ob) => parse_order_by(ob)?,
1627 None => vec![],
1628 };
1629
1630 let limit = parse_limit(query_limit_expr(query)?)?;
1632 let offset = parse_offset_clause(query_offset(query))?;
1633
1634 let mut all_ctes = ctes;
1636 all_ctes.extend(parsed_select.ctes);
1637
1638 Ok(ParsedStatement::Select(ParsedSelect {
1639 table: parsed_select.table,
1640 joins: parsed_select.joins,
1641 columns: parsed_select.columns,
1642 column_aliases: parsed_select.column_aliases,
1643 case_columns: parsed_select.case_columns,
1644 predicates: parsed_select.predicates,
1645 order_by,
1646 limit,
1647 offset,
1648 aggregates: parsed_select.aggregates,
1649 aggregate_filters: parsed_select.aggregate_filters,
1650 group_by: parsed_select.group_by,
1651 distinct: parsed_select.distinct,
1652 having: parsed_select.having,
1653 ctes: all_ctes,
1654 window_fns: parsed_select.window_fns,
1655 scalar_projections: parsed_select.scalar_projections,
1656 }))
1657 }
1658 SetExpr::SetOperation {
1659 op,
1660 set_quantifier,
1661 left,
1662 right,
1663 } => {
1664 use sqlparser::ast::SetOperator;
1665 use sqlparser::ast::SetQuantifier;
1666
1667 let parsed_op = match op {
1668 SetOperator::Union => SetOp::Union,
1669 SetOperator::Intersect => SetOp::Intersect,
1670 SetOperator::Except | SetOperator::Minus => SetOp::Except,
1672 };
1673
1674 let all = matches!(set_quantifier, SetQuantifier::All);
1675
1676 let left_select = match left.as_ref() {
1678 SetExpr::Select(s) => parse_select(s)?,
1679 _ => {
1680 return Err(QueryError::UnsupportedFeature(
1681 "nested set operations not supported".to_string(),
1682 ));
1683 }
1684 };
1685 let right_select = match right.as_ref() {
1686 SetExpr::Select(s) => parse_select(s)?,
1687 _ => {
1688 return Err(QueryError::UnsupportedFeature(
1689 "nested set operations not supported".to_string(),
1690 ));
1691 }
1692 };
1693
1694 Ok(ParsedStatement::Union(ParsedUnion {
1695 op: parsed_op,
1696 left: left_select,
1697 right: right_select,
1698 all,
1699 }))
1700 }
1701 other => Err(QueryError::UnsupportedFeature(format!(
1702 "unsupported query type: {other:?}"
1703 ))),
1704 }
1705}
1706
1707fn parse_join_with_subqueries(join: &sqlparser::ast::Join) -> Result<(ParsedJoin, Vec<ParsedCte>)> {
1709 use sqlparser::ast::{JoinConstraint, JoinOperator};
1710
1711 let join_type = match &join.join_operator {
1718 JoinOperator::Inner(_) | JoinOperator::Join(_) => JoinType::Inner,
1719 JoinOperator::LeftOuter(_) | JoinOperator::Left(_) => JoinType::Left,
1720 JoinOperator::RightOuter(_) | JoinOperator::Right(_) => JoinType::Right,
1721 JoinOperator::FullOuter(_) => JoinType::Full,
1722 JoinOperator::CrossJoin(_) => JoinType::Cross,
1723 other => {
1724 return Err(QueryError::UnsupportedFeature(format!(
1725 "join type not supported: {other:?}"
1726 )));
1727 }
1728 };
1729
1730 let mut inline_ctes = Vec::new();
1732 let table = match &join.relation {
1733 sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
1734 sqlparser::ast::TableFactor::Derived {
1735 subquery, alias, ..
1736 } => {
1737 let alias_name = alias
1738 .as_ref()
1739 .map(|a| a.name.value.clone())
1740 .ok_or_else(|| {
1741 QueryError::ParseError("subquery in JOIN requires an alias".to_string())
1742 })?;
1743
1744 let inner = match subquery.body.as_ref() {
1746 SetExpr::Select(s) => parse_select(s)?,
1747 _ => {
1748 return Err(QueryError::UnsupportedFeature(
1749 "subquery body must be a simple SELECT".to_string(),
1750 ));
1751 }
1752 };
1753
1754 let order_by = match &subquery.order_by {
1755 Some(ob) => parse_order_by(ob)?,
1756 None => vec![],
1757 };
1758 let limit = parse_limit(query_limit_expr(subquery)?)?;
1759
1760 inline_ctes.push(ParsedCte {
1761 name: alias_name.clone(),
1762 query: ParsedSelect {
1763 order_by,
1764 limit,
1765 ..inner
1766 },
1767 recursive_arm: None,
1768 });
1769
1770 alias_name
1771 }
1772 _ => {
1773 return Err(QueryError::UnsupportedFeature(
1774 "unsupported JOIN relation type".to_string(),
1775 ));
1776 }
1777 };
1778
1779 let on_condition = match &join.join_operator {
1781 JoinOperator::CrossJoin(_) => Vec::new(),
1782 JoinOperator::Inner(constraint)
1783 | JoinOperator::Join(constraint)
1784 | JoinOperator::LeftOuter(constraint)
1785 | JoinOperator::Left(constraint)
1786 | JoinOperator::RightOuter(constraint)
1787 | JoinOperator::Right(constraint)
1788 | JoinOperator::FullOuter(constraint) => match constraint {
1789 JoinConstraint::On(expr) => parse_join_condition(expr)?,
1790 JoinConstraint::Using(idents) => {
1791 let mut preds = Vec::new();
1796 for name in idents {
1797 if name.0.len() != 1 {
1798 return Err(QueryError::UnsupportedFeature(format!(
1799 "USING column must be a bare identifier, got {name}"
1800 )));
1801 }
1802 let col_name = name.0[0]
1803 .as_ident()
1804 .ok_or_else(|| {
1805 QueryError::UnsupportedFeature(format!(
1806 "USING column must be a bare identifier, got {name}"
1807 ))
1808 })?
1809 .value
1810 .clone();
1811 preds.push(Predicate::Eq(
1812 ColumnName::new(col_name.clone()),
1813 PredicateValue::ColumnRef(col_name),
1814 ));
1815 }
1816 preds
1817 }
1818 JoinConstraint::Natural => {
1819 return Err(QueryError::UnsupportedFeature(
1820 "NATURAL JOIN is not supported; use ON or USING explicitly".to_string(),
1821 ));
1822 }
1823 JoinConstraint::None => {
1824 return Err(QueryError::UnsupportedFeature(
1825 "join without ON or USING clause not supported".to_string(),
1826 ));
1827 }
1828 },
1829 _ => {
1830 return Err(QueryError::UnsupportedFeature(
1831 "join without ON clause not supported".to_string(),
1832 ));
1833 }
1834 };
1835
1836 Ok((
1837 ParsedJoin {
1838 table,
1839 join_type,
1840 on_condition,
1841 },
1842 inline_ctes,
1843 ))
1844}
1845
1846fn parse_join_condition(expr: &Expr) -> Result<Vec<Predicate>> {
1849 match expr {
1850 Expr::BinaryOp {
1851 left,
1852 op: BinaryOperator::And,
1853 right,
1854 } => {
1855 let mut predicates = parse_join_condition(left)?;
1856 predicates.extend(parse_join_condition(right)?);
1857 Ok(predicates)
1858 }
1859 _ => {
1860 parse_where_expr(expr)
1862 }
1863 }
1864}
1865
1866fn parse_select(select: &Select) -> Result<ParsedSelect> {
1867 let distinct = select.distinct.is_some();
1869
1870 if select.from.len() != 1 {
1872 return Err(QueryError::ParseError(format!(
1873 "expected exactly 1 table in FROM clause, got {}",
1874 select.from.len()
1875 )));
1876 }
1877
1878 let from = &select.from[0];
1879
1880 let mut inline_ctes = Vec::new();
1882
1883 let mut joins = Vec::new();
1885 for join in &from.joins {
1886 let (parsed_join, join_ctes) = parse_join_with_subqueries(join)?;
1887 joins.push(parsed_join);
1888 inline_ctes.extend(join_ctes);
1889 }
1890
1891 let table = match &from.relation {
1892 sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
1893 sqlparser::ast::TableFactor::Derived {
1894 subquery, alias, ..
1895 } => {
1896 let alias_name = alias
1897 .as_ref()
1898 .map(|a| a.name.value.clone())
1899 .ok_or_else(|| {
1900 QueryError::ParseError("subquery in FROM requires an alias".to_string())
1901 })?;
1902
1903 let inner = match subquery.body.as_ref() {
1905 SetExpr::Select(s) => parse_select(s)?,
1906 _ => {
1907 return Err(QueryError::UnsupportedFeature(
1908 "subquery body must be a simple SELECT".to_string(),
1909 ));
1910 }
1911 };
1912
1913 let order_by = match &subquery.order_by {
1914 Some(ob) => parse_order_by(ob)?,
1915 None => vec![],
1916 };
1917 let limit = parse_limit(query_limit_expr(subquery)?)?;
1918
1919 inline_ctes.push(ParsedCte {
1920 name: alias_name.clone(),
1921 query: ParsedSelect {
1922 order_by,
1923 limit,
1924 ..inner
1925 },
1926 recursive_arm: None,
1927 });
1928
1929 alias_name
1930 }
1931 other => {
1932 return Err(QueryError::UnsupportedFeature(format!(
1933 "unsupported FROM clause: {other:?}"
1934 )));
1935 }
1936 };
1937
1938 let (columns, column_aliases) = parse_select_items(&select.projection)?;
1940
1941 let case_columns = parse_case_columns_from_select_items(&select.projection)?;
1943
1944 let predicates = match &select.selection {
1946 Some(expr) => parse_where_expr(expr)?,
1947 None => vec![],
1948 };
1949
1950 let group_by = match &select.group_by {
1952 sqlparser::ast::GroupByExpr::Expressions(exprs, _) if !exprs.is_empty() => {
1953 parse_group_by_expr(exprs)?
1954 }
1955 sqlparser::ast::GroupByExpr::All(_) => {
1956 return Err(QueryError::UnsupportedFeature(
1957 "GROUP BY ALL is not supported".to_string(),
1958 ));
1959 }
1960 sqlparser::ast::GroupByExpr::Expressions(_, _) => vec![],
1961 };
1962
1963 let (aggregates, aggregate_filters) = parse_aggregates_from_select_items(&select.projection)?;
1965
1966 let having = match &select.having {
1968 Some(expr) => parse_having_expr(expr)?,
1969 None => vec![],
1970 };
1971
1972 let window_fns = parse_window_fns_from_select_items(&select.projection)?;
1975
1976 let scalar_projections = parse_scalar_columns_from_select_items(&select.projection)?;
1979
1980 Ok(ParsedSelect {
1981 table,
1982 joins,
1983 columns,
1984 column_aliases,
1985 case_columns,
1986 predicates,
1987 order_by: vec![],
1988 limit: None,
1989 offset: None,
1990 aggregates,
1991 aggregate_filters,
1992 group_by,
1993 distinct,
1994 having,
1995 ctes: inline_ctes,
1996 window_fns,
1997 scalar_projections,
1998 })
1999}
2000
2001fn parse_ctes(with: &sqlparser::ast::With) -> Result<Vec<ParsedCte>> {
2008 let max_ctes = 16;
2009 let mut ctes = Vec::new();
2010
2011 for (i, cte) in with.cte_tables.iter().enumerate() {
2012 if i >= max_ctes {
2013 return Err(QueryError::UnsupportedFeature(format!(
2014 "too many CTEs (max {max_ctes})"
2015 )));
2016 }
2017
2018 let name = cte.alias.name.value.clone();
2019
2020 let (inner_select, recursive_arm) = match cte.query.body.as_ref() {
2025 SetExpr::Select(s) => (parse_select(s)?, None),
2026 SetExpr::SetOperation {
2027 op, left, right, ..
2028 } if with.recursive => {
2029 use sqlparser::ast::SetOperator;
2030 if !matches!(op, SetOperator::Union) {
2031 return Err(QueryError::UnsupportedFeature(
2032 "recursive CTE body must use UNION (not INTERSECT/EXCEPT)".to_string(),
2033 ));
2034 }
2035 let anchor = match left.as_ref() {
2036 SetExpr::Select(s) => parse_select(s)?,
2037 _ => {
2038 return Err(QueryError::UnsupportedFeature(
2039 "recursive CTE anchor must be a simple SELECT".to_string(),
2040 ));
2041 }
2042 };
2043 let recursive = match right.as_ref() {
2044 SetExpr::Select(s) => parse_select(s)?,
2045 _ => {
2046 return Err(QueryError::UnsupportedFeature(
2047 "recursive CTE recursive arm must be a simple SELECT".to_string(),
2048 ));
2049 }
2050 };
2051 (anchor, Some(recursive))
2052 }
2053 _ => {
2054 return Err(QueryError::UnsupportedFeature(
2055 "CTE body must be a simple SELECT (or anchor UNION recursive for WITH RECURSIVE)".to_string(),
2056 ));
2057 }
2058 };
2059
2060 let order_by = match &cte.query.order_by {
2062 Some(ob) => parse_order_by(ob)?,
2063 None => vec![],
2064 };
2065 let limit = parse_limit(query_limit_expr(&cte.query)?)?;
2066
2067 ctes.push(ParsedCte {
2068 name,
2069 query: ParsedSelect {
2070 order_by,
2071 limit,
2072 ..inner_select
2073 },
2074 recursive_arm,
2075 });
2076 }
2077
2078 Ok(ctes)
2079}
2080
2081fn parse_having_expr(expr: &Expr) -> Result<Vec<HavingCondition>> {
2086 match expr {
2087 Expr::BinaryOp {
2088 left,
2089 op: BinaryOperator::And,
2090 right,
2091 } => {
2092 let mut conditions = parse_having_expr(left)?;
2093 conditions.extend(parse_having_expr(right)?);
2094 Ok(conditions)
2095 }
2096 Expr::BinaryOp { left, op, right } => {
2097 let aggregate = match left.as_ref() {
2099 Expr::Function(_) => {
2100 let (agg, _filter) = try_parse_aggregate(left)?.ok_or_else(|| {
2101 QueryError::UnsupportedFeature(
2102 "HAVING requires aggregate functions (COUNT, SUM, AVG, MIN, MAX)"
2103 .to_string(),
2104 )
2105 })?;
2106 agg
2107 }
2108 _ => {
2109 return Err(QueryError::UnsupportedFeature(
2110 "HAVING clause must reference aggregate functions".to_string(),
2111 ));
2112 }
2113 };
2114
2115 let value = expr_to_value(right)?;
2117
2118 let having_op = match op {
2120 BinaryOperator::Eq => HavingOp::Eq,
2121 BinaryOperator::Lt => HavingOp::Lt,
2122 BinaryOperator::LtEq => HavingOp::Le,
2123 BinaryOperator::Gt => HavingOp::Gt,
2124 BinaryOperator::GtEq => HavingOp::Ge,
2125 other => {
2126 return Err(QueryError::UnsupportedFeature(format!(
2127 "unsupported HAVING operator: {other:?}"
2128 )));
2129 }
2130 };
2131
2132 Ok(vec![HavingCondition::AggregateComparison {
2133 aggregate,
2134 op: having_op,
2135 value,
2136 }])
2137 }
2138 Expr::Nested(inner) => parse_having_expr(inner),
2139 other => Err(QueryError::UnsupportedFeature(format!(
2140 "unsupported HAVING expression: {other:?}"
2141 ))),
2142 }
2143}
2144
2145type ParsedSelectList = (Option<Vec<ColumnName>>, Option<Vec<Option<String>>>);
2159
2160fn parse_select_items(items: &[SelectItem]) -> Result<ParsedSelectList> {
2161 let mut columns = Vec::new();
2162 let mut aliases: Vec<Option<String>> = Vec::new();
2163
2164 for item in items {
2165 #[allow(clippy::match_same_arms)]
2170 match item {
2171 SelectItem::Wildcard(_) => {
2172 return Ok((None, None));
2175 }
2176 SelectItem::UnnamedExpr(Expr::Identifier(ident)) => {
2177 columns.push(ColumnName::new(ident.value.clone()));
2178 aliases.push(None);
2179 }
2180 SelectItem::UnnamedExpr(Expr::CompoundIdentifier(idents)) if idents.len() == 2 => {
2181 columns.push(ColumnName::new(idents[1].value.clone()));
2183 aliases.push(None);
2184 }
2185 SelectItem::ExprWithAlias {
2186 expr: Expr::Identifier(ident),
2187 alias,
2188 } => {
2189 columns.push(ColumnName::new(ident.value.clone()));
2190 aliases.push(Some(alias.value.clone()));
2191 }
2192 SelectItem::ExprWithAlias {
2193 expr: Expr::CompoundIdentifier(idents),
2194 alias,
2195 } if idents.len() == 2 => {
2196 columns.push(ColumnName::new(idents[1].value.clone()));
2198 aliases.push(Some(alias.value.clone()));
2199 }
2200 SelectItem::UnnamedExpr(Expr::Function(_))
2201 | SelectItem::ExprWithAlias {
2202 expr: Expr::Function(_) | Expr::Case { .. },
2203 ..
2204 } => {
2205 }
2209 SelectItem::UnnamedExpr(Expr::Cast { .. })
2213 | SelectItem::ExprWithAlias {
2214 expr: Expr::Cast { .. },
2215 ..
2216 } => {}
2217 SelectItem::UnnamedExpr(Expr::BinaryOp {
2218 op: BinaryOperator::StringConcat,
2219 ..
2220 })
2221 | SelectItem::ExprWithAlias {
2222 expr:
2223 Expr::BinaryOp {
2224 op: BinaryOperator::StringConcat,
2225 ..
2226 },
2227 ..
2228 } => {}
2229 other => {
2230 return Err(QueryError::UnsupportedFeature(format!(
2231 "unsupported SELECT item: {other:?}"
2232 )));
2233 }
2234 }
2235 }
2236
2237 Ok((Some(columns), Some(aliases)))
2238}
2239
2240type ParsedAggregateList = (Vec<AggregateFunction>, Vec<Option<Vec<Predicate>>>);
2248
2249fn parse_aggregates_from_select_items(items: &[SelectItem]) -> Result<ParsedAggregateList> {
2250 let mut aggregates = Vec::new();
2251 let mut filters = Vec::new();
2252
2253 for item in items {
2254 match item {
2255 SelectItem::UnnamedExpr(expr) | SelectItem::ExprWithAlias { expr, .. } => {
2256 if let Some((agg, filter)) = try_parse_aggregate(expr)? {
2257 aggregates.push(agg);
2258 filters.push(filter);
2259 }
2260 }
2261 _ => {
2262 }
2264 }
2265 }
2266
2267 Ok((aggregates, filters))
2268}
2269
2270fn parse_case_columns_from_select_items(items: &[SelectItem]) -> Result<Vec<ComputedColumn>> {
2275 let mut case_cols = Vec::new();
2276
2277 for item in items {
2278 if let SelectItem::ExprWithAlias {
2279 expr:
2280 Expr::Case {
2281 operand,
2282 conditions,
2283 else_result,
2284 ..
2285 },
2286 alias,
2287 } = item
2288 {
2289 let mut when_clauses = Vec::new();
2294 for case_when in conditions {
2295 let cond_expr = &case_when.condition;
2296 let result_expr = &case_when.result;
2297 let condition = match operand.as_deref() {
2298 None => parse_where_expr(cond_expr)?,
2299 Some(operand_expr) => parse_where_expr(&Expr::BinaryOp {
2300 left: Box::new(operand_expr.clone()),
2301 op: BinaryOperator::Eq,
2302 right: Box::new(cond_expr.clone()),
2303 })?,
2304 };
2305 let result = expr_to_value(result_expr)?;
2306 when_clauses.push(CaseWhenArm { condition, result });
2307 }
2308
2309 let else_value = match else_result {
2310 Some(expr) => expr_to_value(expr)?,
2311 None => Value::Null,
2312 };
2313
2314 case_cols.push(ComputedColumn {
2315 alias: ColumnName::new(alias.value.clone()),
2316 when_clauses,
2317 else_value,
2318 });
2319 }
2320 }
2321
2322 Ok(case_cols)
2323}
2324
2325fn parse_scalar_columns_from_select_items(
2335 items: &[SelectItem],
2336) -> Result<Vec<ParsedScalarProjection>> {
2337 let mut out = Vec::new();
2338 for item in items {
2339 let (expr, alias) = match item {
2340 SelectItem::UnnamedExpr(e) => (e, None),
2341 SelectItem::ExprWithAlias { expr, alias } => (expr, Some(alias.value.clone())),
2342 _ => continue,
2343 };
2344
2345 if !is_scalar_projection_shape(expr) {
2346 continue;
2347 }
2348
2349 let scalar = expr_to_scalar_expr(expr)?;
2350 let output_name = alias
2351 .clone()
2352 .unwrap_or_else(|| synthesize_column_name(expr));
2353 out.push(ParsedScalarProjection {
2354 expr: scalar,
2355 output_name: ColumnName::new(output_name),
2356 alias,
2357 });
2358 }
2359 Ok(out)
2360}
2361
2362fn is_scalar_projection_shape(expr: &Expr) -> bool {
2366 match expr {
2367 Expr::Function(func) => {
2368 if func.over.is_some() {
2370 return false;
2371 }
2372 let name = func.name.to_string().to_uppercase();
2373 !matches!(name.as_str(), "COUNT" | "SUM" | "AVG" | "MIN" | "MAX")
2374 }
2375 Expr::Cast { .. }
2376 | Expr::BinaryOp {
2377 op: BinaryOperator::StringConcat,
2378 ..
2379 } => true,
2380 _ => false,
2381 }
2382}
2383
2384fn synthesize_column_name(expr: &Expr) -> String {
2388 match expr {
2389 Expr::Function(func) => func.name.to_string().to_lowercase(),
2390 Expr::Cast { .. } => "cast".to_string(),
2391 Expr::BinaryOp {
2392 op: BinaryOperator::StringConcat,
2393 ..
2394 } => "concat".to_string(),
2395 _ => "expr".to_string(),
2396 }
2397}
2398
2399fn parse_window_fns_from_select_items(items: &[SelectItem]) -> Result<Vec<ParsedWindowFn>> {
2403 let mut out = Vec::new();
2404 for item in items {
2405 let (expr, alias) = match item {
2406 SelectItem::UnnamedExpr(e) => (e, None),
2407 SelectItem::ExprWithAlias { expr, alias } => (expr, Some(alias.value.clone())),
2408 _ => continue,
2409 };
2410 if let Some(parsed) = try_parse_window_fn(expr, alias)? {
2411 out.push(parsed);
2412 }
2413 }
2414 Ok(out)
2415}
2416
2417fn try_parse_window_fn(expr: &Expr, alias: Option<String>) -> Result<Option<ParsedWindowFn>> {
2418 let Expr::Function(func) = expr else {
2419 return Ok(None);
2420 };
2421 let Some(over) = &func.over else {
2422 return Ok(None);
2423 };
2424 let spec = match over {
2425 sqlparser::ast::WindowType::WindowSpec(s) => s,
2426 sqlparser::ast::WindowType::NamedWindow(_) => {
2427 return Err(QueryError::UnsupportedFeature(
2428 "named windows (OVER w) are not supported".into(),
2429 ));
2430 }
2431 };
2432 if spec.window_frame.is_some() {
2433 return Err(QueryError::UnsupportedFeature(
2434 "explicit window frames (ROWS/RANGE BETWEEN ...) are not supported; \
2435 omit the frame clause for default behaviour"
2436 .into(),
2437 ));
2438 }
2439
2440 let func_name = func.name.to_string().to_uppercase();
2441 let args = match &func.args {
2442 sqlparser::ast::FunctionArguments::List(list) => list.args.clone(),
2443 _ => Vec::new(),
2444 };
2445 let function = parse_window_function_name(&func_name, &args)?;
2446
2447 let partition_by: Vec<ColumnName> = spec
2448 .partition_by
2449 .iter()
2450 .map(parse_column_expr)
2451 .collect::<Result<_>>()?;
2452 let order_by: Vec<OrderByClause> = spec
2453 .order_by
2454 .iter()
2455 .map(parse_order_by_expr)
2456 .collect::<Result<_>>()?;
2457
2458 Ok(Some(ParsedWindowFn {
2459 function,
2460 partition_by,
2461 order_by,
2462 alias,
2463 }))
2464}
2465
2466fn parse_column_expr(expr: &Expr) -> Result<ColumnName> {
2467 match expr {
2468 Expr::Identifier(ident) => Ok(ColumnName::new(ident.value.clone())),
2469 Expr::CompoundIdentifier(idents) if idents.len() == 2 => {
2470 Ok(ColumnName::new(idents[1].value.clone()))
2471 }
2472 other => Err(QueryError::UnsupportedFeature(format!(
2473 "window PARTITION BY / argument must be a column reference, got: {other:?}"
2474 ))),
2475 }
2476}
2477
2478fn parse_window_function_name(
2479 name: &str,
2480 args: &[sqlparser::ast::FunctionArg],
2481) -> Result<crate::window::WindowFunction> {
2482 use crate::window::WindowFunction;
2483
2484 let arg_exprs: Vec<&Expr> = args
2485 .iter()
2486 .filter_map(|a| match a {
2487 sqlparser::ast::FunctionArg::Unnamed(sqlparser::ast::FunctionArgExpr::Expr(e)) => {
2488 Some(e)
2489 }
2490 _ => None,
2491 })
2492 .collect();
2493
2494 let single_col = || -> Result<ColumnName> {
2495 if arg_exprs.is_empty() {
2496 return Err(QueryError::ParseError(format!(
2497 "{name} requires a column argument"
2498 )));
2499 }
2500 parse_column_expr(arg_exprs[0])
2501 };
2502
2503 let parse_offset = || -> Result<usize> {
2504 if arg_exprs.len() < 2 {
2505 return Ok(1);
2506 }
2507 match arg_exprs[1] {
2508 Expr::Value(vws) => match &vws.value {
2509 SqlValue::Number(n, _) => n
2510 .parse::<usize>()
2511 .map_err(|_| QueryError::ParseError(format!("invalid {name} offset: {n}"))),
2512 other => Err(QueryError::UnsupportedFeature(format!(
2513 "{name} offset must be a literal integer; got {other:?}"
2514 ))),
2515 },
2516 other => Err(QueryError::UnsupportedFeature(format!(
2517 "{name} offset must be a literal integer; got {other:?}"
2518 ))),
2519 }
2520 };
2521
2522 match name {
2523 "ROW_NUMBER" => Ok(WindowFunction::RowNumber),
2524 "RANK" => Ok(WindowFunction::Rank),
2525 "DENSE_RANK" => Ok(WindowFunction::DenseRank),
2526 "LAG" => Ok(WindowFunction::Lag {
2527 column: single_col()?,
2528 offset: parse_offset()?,
2529 }),
2530 "LEAD" => Ok(WindowFunction::Lead {
2531 column: single_col()?,
2532 offset: parse_offset()?,
2533 }),
2534 "FIRST_VALUE" => Ok(WindowFunction::FirstValue {
2535 column: single_col()?,
2536 }),
2537 "LAST_VALUE" => Ok(WindowFunction::LastValue {
2538 column: single_col()?,
2539 }),
2540 other => Err(QueryError::UnsupportedFeature(format!(
2541 "unknown window function: {other}"
2542 ))),
2543 }
2544}
2545
2546type ParsedAggregate = (AggregateFunction, Option<Vec<Predicate>>);
2548
2549fn try_parse_aggregate(expr: &Expr) -> Result<Option<ParsedAggregate>> {
2553 let parsed_filter: Option<Vec<Predicate>> = match expr {
2554 Expr::Function(func) => match &func.filter {
2555 Some(filter_expr) => Some(parse_where_expr(filter_expr)?),
2556 None => None,
2557 },
2558 _ => None,
2559 };
2560 let func_only = try_parse_aggregate_func(expr)?;
2561 Ok(func_only.map(|f| (f, parsed_filter)))
2562}
2563
2564fn try_parse_aggregate_func(expr: &Expr) -> Result<Option<AggregateFunction>> {
2566 match expr {
2567 Expr::Function(func) => {
2568 if func.over.is_some() {
2573 return Ok(None);
2574 }
2575 let func_name = func.name.to_string().to_uppercase();
2576
2577 let args = match &func.args {
2579 sqlparser::ast::FunctionArguments::List(list) => &list.args,
2580 _ => {
2581 return Err(QueryError::UnsupportedFeature(
2582 "non-list function arguments not supported".to_string(),
2583 ));
2584 }
2585 };
2586
2587 match func_name.as_str() {
2588 "COUNT" => {
2589 if args.len() == 1 {
2591 match &args[0] {
2592 sqlparser::ast::FunctionArg::Unnamed(arg_expr) => match arg_expr {
2593 sqlparser::ast::FunctionArgExpr::Wildcard => {
2594 Ok(Some(AggregateFunction::CountStar))
2595 }
2596 sqlparser::ast::FunctionArgExpr::Expr(Expr::Identifier(ident)) => {
2597 Ok(Some(AggregateFunction::Count(ColumnName::new(
2598 ident.value.clone(),
2599 ))))
2600 }
2601 _ => Err(QueryError::UnsupportedFeature(
2602 "COUNT with complex expression not supported".to_string(),
2603 )),
2604 },
2605 _ => Err(QueryError::UnsupportedFeature(
2606 "named function arguments not supported".to_string(),
2607 )),
2608 }
2609 } else {
2610 Err(QueryError::ParseError(format!(
2611 "COUNT expects 1 argument, got {}",
2612 args.len()
2613 )))
2614 }
2615 }
2616 "SUM" | "AVG" | "MIN" | "MAX" => {
2617 if args.len() != 1 {
2619 return Err(QueryError::ParseError(format!(
2620 "{} expects 1 argument, got {}",
2621 func_name,
2622 args.len()
2623 )));
2624 }
2625
2626 match &args[0] {
2627 sqlparser::ast::FunctionArg::Unnamed(arg_expr) => match arg_expr {
2628 sqlparser::ast::FunctionArgExpr::Expr(Expr::Identifier(ident)) => {
2629 let column = ColumnName::new(ident.value.clone());
2630 match func_name.as_str() {
2631 "SUM" => Ok(Some(AggregateFunction::Sum(column))),
2632 "AVG" => Ok(Some(AggregateFunction::Avg(column))),
2633 "MIN" => Ok(Some(AggregateFunction::Min(column))),
2634 "MAX" => Ok(Some(AggregateFunction::Max(column))),
2635 _ => unreachable!(),
2636 }
2637 }
2638 _ => Err(QueryError::UnsupportedFeature(format!(
2639 "{func_name} with complex expression not supported"
2640 ))),
2641 },
2642 _ => Err(QueryError::UnsupportedFeature(
2643 "named function arguments not supported".to_string(),
2644 )),
2645 }
2646 }
2647 _ => {
2648 Ok(None)
2650 }
2651 }
2652 }
2653 _ => {
2654 Ok(None)
2656 }
2657 }
2658}
2659
2660fn parse_group_by_expr(exprs: &[Expr]) -> Result<Vec<ColumnName>> {
2662 let mut columns = Vec::new();
2663
2664 for expr in exprs {
2665 match expr {
2666 Expr::Identifier(ident) => {
2667 columns.push(ColumnName::new(ident.value.clone()));
2668 }
2669 _ => {
2670 return Err(QueryError::UnsupportedFeature(
2671 "complex GROUP BY expressions not supported".to_string(),
2672 ));
2673 }
2674 }
2675 }
2676
2677 Ok(columns)
2678}
2679
/// Maximum recursion depth `parse_where_expr_inner` accepts before it rejects
/// the expression with a parse error.
const MAX_WHERE_DEPTH: usize = 100;

/// Parses a WHERE-style boolean expression into a list of predicates.
///
/// Entry point for `parse_where_expr_inner`, starting at depth 0.
fn parse_where_expr(expr: &Expr) -> Result<Vec<Predicate>> {
    parse_where_expr_inner(expr, 0)
}
2692
2693fn parse_select_from_query(query: &sqlparser::ast::Query) -> Result<ParsedSelect> {
2699 match query.body.as_ref() {
2700 SetExpr::Select(s) => {
2701 let mut parsed = parse_select(s)?;
2702 if let Some(ob) = &query.order_by {
2703 parsed.order_by = parse_order_by(ob)?;
2704 }
2705 parsed.limit = parse_limit(query_limit_expr(query)?)?;
2706 parsed.offset = parse_offset_clause(query_offset(query))?;
2707 Ok(parsed)
2708 }
2709 _ => Err(QueryError::UnsupportedFeature(
2710 "subquery body must be a simple SELECT (no nested UNION/INTERSECT/EXCEPT)".to_string(),
2711 )),
2712 }
2713}
2714
/// Recursive worker behind `parse_where_expr`.
///
/// Translates a boolean expression into a list of predicates. `depth` is the
/// current recursion depth; it is incremented for every `AND` / `OR` operand
/// and every parenthesised sub-expression.
///
/// # Errors
///
/// * `QueryError::ParseError` once `depth` reaches `MAX_WHERE_DEPTH`.
/// * `QueryError::UnsupportedFeature` for non-literal `LIKE` / `ILIKE`
///   patterns and for any expression shape without an arm below.
/// * Any error propagated from the helper parsers.
fn parse_where_expr_inner(expr: &Expr, depth: usize) -> Result<Vec<Predicate>> {
    if depth >= MAX_WHERE_DEPTH {
        return Err(QueryError::ParseError(format!(
            "WHERE clause nesting exceeds maximum depth of {MAX_WHERE_DEPTH}"
        )));
    }

    match expr {
        // `a AND b`: both sides are flattened into one predicate list.
        Expr::BinaryOp {
            left,
            op: BinaryOperator::And,
            right,
        } => {
            let mut predicates = parse_where_expr_inner(left, depth + 1)?;
            predicates.extend(parse_where_expr_inner(right, depth + 1)?);
            Ok(predicates)
        }

        // `a OR b`: each side keeps its own predicate list inside one `Or`.
        Expr::BinaryOp {
            left,
            op: BinaryOperator::Or,
            right,
        } => {
            let left_preds = parse_where_expr_inner(left, depth + 1)?;
            let right_preds = parse_where_expr_inner(right, depth + 1)?;
            Ok(vec![Predicate::Or(left_preds, right_preds)])
        }

        // `col [NOT] LIKE 'pattern'`: the pattern must resolve to a string.
        Expr::Like {
            expr,
            pattern,
            negated,
            ..
        } => {
            let column = expr_to_column(expr)?;
            let pattern_str = match expr_to_predicate_value(pattern)? {
                PredicateValue::String(s) | PredicateValue::Literal(Value::Text(s)) => s,
                _ => {
                    return Err(QueryError::UnsupportedFeature(
                        "LIKE pattern must be a string literal".to_string(),
                    ));
                }
            };
            let predicate = if *negated {
                Predicate::NotLike(column, pattern_str)
            } else {
                Predicate::Like(column, pattern_str)
            };
            Ok(vec![predicate])
        }

        // `col [NOT] ILIKE 'pattern'`: same handling as LIKE.
        Expr::ILike {
            expr,
            pattern,
            negated,
            ..
        } => {
            let column = expr_to_column(expr)?;
            let pattern_str = match expr_to_predicate_value(pattern)? {
                PredicateValue::String(s) | PredicateValue::Literal(Value::Text(s)) => s,
                _ => {
                    return Err(QueryError::UnsupportedFeature(
                        "ILIKE pattern must be a string literal".to_string(),
                    ));
                }
            };
            let predicate = if *negated {
                Predicate::NotILike(column, pattern_str)
            } else {
                Predicate::ILike(column, pattern_str)
            };
            Ok(vec![predicate])
        }

        // `col IS NULL`
        Expr::IsNull(expr) => {
            let column = expr_to_column(expr)?;
            Ok(vec![Predicate::IsNull(column)])
        }

        // `col IS NOT NULL`
        Expr::IsNotNull(expr) => {
            let column = expr_to_column(expr)?;
            Ok(vec![Predicate::IsNotNull(column)])
        }

        // Any remaining binary operator is a comparison. This arm must stay
        // after the AND / OR arms above, which match the same variant.
        Expr::BinaryOp { left, op, right } => {
            let predicate = parse_comparison(left, op, right)?;
            Ok(vec![predicate])
        }

        // `col [NOT] IN (v1, v2, ...)`
        Expr::InList {
            expr,
            list,
            negated,
        } => {
            let column = expr_to_column(expr)?;
            let values: Result<Vec<_>> = list.iter().map(expr_to_predicate_value).collect();
            if *negated {
                Ok(vec![Predicate::NotIn(column, values?)])
            } else {
                Ok(vec![Predicate::In(column, values?)])
            }
        }

        // `col [NOT] IN (SELECT ...)`
        Expr::InSubquery {
            expr,
            subquery,
            negated,
        } => {
            let column = expr_to_column(expr)?;
            let inner = parse_select_from_query(subquery)?;
            Ok(vec![Predicate::InSubquery {
                column,
                subquery: Box::new(inner),
                negated: *negated,
            }])
        }

        // `[NOT] EXISTS (SELECT ...)`
        Expr::Exists { subquery, negated } => {
            let inner = parse_select_from_query(subquery)?;
            Ok(vec![Predicate::Exists {
                subquery: Box::new(inner),
                negated: *negated,
            }])
        }

        // `col [NOT] BETWEEN low AND high`
        Expr::Between {
            expr,
            negated,
            low,
            high,
        } => {
            let column = expr_to_column(expr)?;
            let low_val = expr_to_predicate_value(low)?;
            let high_val = expr_to_predicate_value(high)?;

            // The negated form has its own predicate.
            if *negated {
                return Ok(vec![Predicate::NotBetween(column, low_val, high_val)]);
            }

            // NOTE(review): presumably a property/coverage marker recording
            // that this path was exercised; confirm against the
            // `kimberlite_properties` crate.
            kimberlite_properties::sometimes!(
                true,
                "query.between_desugared_to_ge_le",
                "BETWEEN predicate desugared into Ge + Le pair"
            );

            // The non-negated form is desugared into `col >= low` and
            // `col <= high`.
            Ok(vec![
                Predicate::Ge(column.clone(), low_val),
                Predicate::Le(column, high_val),
            ])
        }

        // Parentheses are transparent but still count towards the depth limit.
        Expr::Nested(inner) => parse_where_expr_inner(inner, depth + 1),

        other => Err(QueryError::UnsupportedFeature(format!(
            "unsupported WHERE expression: {other:?}"
        ))),
    }
}
2890
/// Parses a single binary comparison `left <op> right` into a [`Predicate`].
///
/// Shapes are tried in this order:
/// 1. `col @> value` becomes `Predicate::JsonContains`.
/// 2. `col -> key = value` / `col ->> key = value` becomes
///    `Predicate::JsonExtractEq`, with `as_text` set for `->>`.
/// 3. When neither side needs scalar evaluation and the sides parse as a
///    column and a predicate value, the dedicated `Eq` / `Lt` / `Le` / `Gt` /
///    `Ge` predicates are used; `<>` becomes a `ScalarCmp`.
/// 4. Otherwise both sides are parsed as scalar expressions and compared with
///    `Predicate::ScalarCmp`.
///
/// # Errors
///
/// `QueryError::UnsupportedFeature` for a JSON path key that is not a string
/// or number literal, a JSON extraction compared with anything but `=`, or an
/// operator without a comparison mapping. Errors from the helper parsers are
/// propagated.
fn parse_comparison(left: &Expr, op: &BinaryOperator, right: &Expr) -> Result<Predicate> {
    // Strip one level of parentheses from the left operand.
    let left = match left {
        Expr::Nested(inner) => inner.as_ref(),
        other => other,
    };

    // JSON containment: `col @> value`.
    if matches!(op, BinaryOperator::AtArrow) {
        let column = expr_to_column(left)?;
        let value = expr_to_predicate_value(right)?;
        return Ok(Predicate::JsonContains { column, value });
    }

    // JSON path extraction on the left: `col -> key` or `col ->> key`.
    if let Expr::BinaryOp {
        left: json_left,
        op: arrow_op @ (BinaryOperator::Arrow | BinaryOperator::LongArrow),
        right: path_expr,
    } = left
    {
        // `->>` extracts as text.
        let as_text = matches!(arrow_op, BinaryOperator::LongArrow);
        let column = expr_to_column(json_left)?;
        // The key is a quoted string, or a number kept in its textual form.
        let path = match path_expr.as_ref() {
            Expr::Value(vws) => match &vws.value {
                SqlValue::SingleQuotedString(s) | SqlValue::DoubleQuotedString(s) => s.clone(),
                SqlValue::Number(n, _) => n.clone(),
                _ => {
                    return Err(QueryError::UnsupportedFeature(format!(
                        "JSON path key must be a string or integer literal, got {path_expr:?}"
                    )));
                }
            },
            other => {
                return Err(QueryError::UnsupportedFeature(format!(
                    "JSON path key must be a string or integer literal, got {other:?}"
                )));
            }
        };
        // The right-hand value is parsed before the operator is checked, so a
        // bad value is reported ahead of an unsupported operator.
        let value = expr_to_predicate_value(right)?;
        if !matches!(op, BinaryOperator::Eq) {
            return Err(QueryError::UnsupportedFeature(format!(
                "JSON path extraction supports only `=` comparison; got {op:?}"
            )));
        }
        return Ok(Predicate::JsonExtractEq {
            column,
            path,
            as_text,
            value,
        });
    }

    // `None` when the operator has no comparison mapping; only reported if
    // the scalar fallback at the bottom is reached.
    let cmp_op = sql_binop_to_scalar_cmp(op);

    // Fast path: plain `column <op> value` with no scalar expression on
    // either side.
    if !expr_needs_scalar(left) && !expr_needs_scalar(right) {
        // A failure on either side is not an error here; it falls through to
        // the scalar path below.
        if let (Ok(column), Ok(value)) = (expr_to_column(left), expr_to_predicate_value(right)) {
            return match op {
                BinaryOperator::Eq => Ok(Predicate::Eq(column, value)),
                BinaryOperator::Lt => Ok(Predicate::Lt(column, value)),
                BinaryOperator::LtEq => Ok(Predicate::Le(column, value)),
                BinaryOperator::Gt => Ok(Predicate::Gt(column, value)),
                BinaryOperator::GtEq => Ok(Predicate::Ge(column, value)),
                // `<>` is expressed through the scalar comparison form.
                BinaryOperator::NotEq => {
                    Ok(Predicate::ScalarCmp {
                        lhs: ScalarExpr::Column(column),
                        op: ScalarCmpOp::NotEq,
                        rhs: predicate_value_to_scalar_expr(&value),
                    })
                }
                other => Err(QueryError::UnsupportedFeature(format!(
                    "unsupported operator: {other:?}"
                ))),
            };
        }
    }

    // General path: compare two scalar expressions.
    let lhs = expr_to_scalar_expr(left)?;
    let rhs = expr_to_scalar_expr(right)?;
    let op = cmp_op.ok_or_else(|| {
        QueryError::UnsupportedFeature(format!("unsupported operator in scalar comparison: {op:?}"))
    })?;
    Ok(Predicate::ScalarCmp { lhs, op, rhs })
}
2990
2991fn expr_needs_scalar(expr: &Expr) -> bool {
2995 match expr {
2996 Expr::Function(_)
2997 | Expr::Cast { .. }
2998 | Expr::BinaryOp {
2999 op: BinaryOperator::StringConcat,
3000 ..
3001 } => true,
3002 Expr::Nested(inner) => expr_needs_scalar(inner),
3003 _ => false,
3004 }
3005}
3006
3007fn sql_binop_to_scalar_cmp(op: &BinaryOperator) -> Option<ScalarCmpOp> {
3008 Some(match op {
3009 BinaryOperator::Eq => ScalarCmpOp::Eq,
3010 BinaryOperator::NotEq => ScalarCmpOp::NotEq,
3011 BinaryOperator::Lt => ScalarCmpOp::Lt,
3012 BinaryOperator::LtEq => ScalarCmpOp::Le,
3013 BinaryOperator::Gt => ScalarCmpOp::Gt,
3014 BinaryOperator::GtEq => ScalarCmpOp::Ge,
3015 _ => return None,
3016 })
3017}
3018
3019fn predicate_value_to_scalar_expr(pv: &PredicateValue) -> ScalarExpr {
3020 match pv {
3021 PredicateValue::Int(n) => ScalarExpr::Literal(Value::BigInt(*n)),
3022 PredicateValue::String(s) => ScalarExpr::Literal(Value::Text(s.clone())),
3023 PredicateValue::Bool(b) => ScalarExpr::Literal(Value::Boolean(*b)),
3024 PredicateValue::Null => ScalarExpr::Literal(Value::Null),
3025 PredicateValue::Param(idx) => ScalarExpr::Literal(Value::Placeholder(*idx)),
3026 PredicateValue::Literal(v) => ScalarExpr::Literal(v.clone()),
3027 PredicateValue::ColumnRef(name) => {
3028 let col = name.rsplit('.').next().unwrap_or(name);
3030 ScalarExpr::Column(ColumnName::new(col.to_string()))
3031 }
3032 }
3033}
3034
3035fn expr_to_column(expr: &Expr) -> Result<ColumnName> {
3036 match expr {
3037 Expr::Identifier(ident) => Ok(ColumnName::new(ident.value.clone())),
3038 Expr::CompoundIdentifier(idents) if idents.len() == 2 => {
3039 Ok(ColumnName::new(idents[1].value.clone()))
3041 }
3042 other => Err(QueryError::UnsupportedFeature(format!(
3043 "expected column name, got {other:?}"
3044 ))),
3045 }
3046}
3047
3048pub fn expr_to_scalar_expr(expr: &Expr) -> Result<ScalarExpr> {
3061 match expr {
3062 Expr::Value(_) | Expr::UnaryOp { .. } => Ok(ScalarExpr::Literal(expr_to_value(expr)?)),
3065
3066 Expr::Identifier(ident) => Ok(ScalarExpr::Column(ColumnName::new(ident.value.clone()))),
3068 Expr::CompoundIdentifier(idents) if idents.len() == 2 => {
3069 Ok(ScalarExpr::Column(ColumnName::new(idents[1].value.clone())))
3070 }
3071
3072 Expr::BinaryOp {
3075 left,
3076 op: BinaryOperator::StringConcat,
3077 right,
3078 } => Ok(ScalarExpr::Concat(vec![
3079 expr_to_scalar_expr(left)?,
3080 expr_to_scalar_expr(right)?,
3081 ])),
3082
3083 Expr::Cast {
3085 expr: inner,
3086 data_type,
3087 ..
3088 } => {
3089 let target = sql_data_type_to_data_type(data_type)?;
3090 Ok(ScalarExpr::Cast(
3091 Box::new(expr_to_scalar_expr(inner)?),
3092 target,
3093 ))
3094 }
3095
3096 Expr::Nested(inner) => expr_to_scalar_expr(inner),
3098
3099 Expr::Function(func) => {
3101 if func.over.is_some() {
3102 return Err(QueryError::UnsupportedFeature(
3103 "window functions are not valid in this position".to_string(),
3104 ));
3105 }
3106 if func.filter.is_some() {
3107 return Err(QueryError::UnsupportedFeature(
3108 "FILTER clause only applies to aggregate functions".to_string(),
3109 ));
3110 }
3111 let name = func.name.to_string().to_uppercase();
3112 let args = match &func.args {
3113 sqlparser::ast::FunctionArguments::List(list) => &list.args,
3114 _ => {
3115 return Err(QueryError::UnsupportedFeature(
3116 "non-list function arguments not supported".to_string(),
3117 ));
3118 }
3119 };
3120
3121 let mut arg_exprs: Vec<&Expr> = Vec::with_capacity(args.len());
3123 for a in args {
3124 match a {
3125 sqlparser::ast::FunctionArg::Unnamed(
3126 sqlparser::ast::FunctionArgExpr::Expr(e),
3127 ) => arg_exprs.push(e),
3128 _ => {
3129 return Err(QueryError::UnsupportedFeature(format!(
3130 "unsupported argument form in scalar function {name}"
3131 )));
3132 }
3133 }
3134 }
3135
3136 let want_arity = |n: usize| -> Result<()> {
3137 if arg_exprs.len() == n {
3138 Ok(())
3139 } else {
3140 Err(QueryError::ParseError(format!(
3141 "{name} expects {n} argument(s), got {}",
3142 arg_exprs.len()
3143 )))
3144 }
3145 };
3146 let scalar = |e: &Expr| expr_to_scalar_expr(e);
3147
3148 match name.as_str() {
3149 "UPPER" => {
3150 want_arity(1)?;
3151 Ok(ScalarExpr::Upper(Box::new(scalar(arg_exprs[0])?)))
3152 }
3153 "LOWER" => {
3154 want_arity(1)?;
3155 Ok(ScalarExpr::Lower(Box::new(scalar(arg_exprs[0])?)))
3156 }
3157 "LENGTH" | "CHAR_LENGTH" | "CHARACTER_LENGTH" => {
3158 want_arity(1)?;
3159 Ok(ScalarExpr::Length(Box::new(scalar(arg_exprs[0])?)))
3160 }
3161 "TRIM" => {
3162 want_arity(1)?;
3163 Ok(ScalarExpr::Trim(Box::new(scalar(arg_exprs[0])?)))
3164 }
3165 "CONCAT" => {
3166 if arg_exprs.is_empty() {
3167 return Err(QueryError::ParseError(
3168 "CONCAT expects at least one argument".to_string(),
3169 ));
3170 }
3171 let parts = arg_exprs
3172 .iter()
3173 .map(|e| scalar(e))
3174 .collect::<Result<Vec<_>>>()?;
3175 Ok(ScalarExpr::Concat(parts))
3176 }
3177 "ABS" => {
3178 want_arity(1)?;
3179 Ok(ScalarExpr::Abs(Box::new(scalar(arg_exprs[0])?)))
3180 }
3181 "ROUND" => match arg_exprs.len() {
3182 1 => Ok(ScalarExpr::Round(Box::new(scalar(arg_exprs[0])?))),
3183 2 => {
3184 let n = match expr_to_value(arg_exprs[1])? {
3186 Value::BigInt(n) => i32::try_from(n).map_err(|_| {
3187 QueryError::ParseError("ROUND scale out of range".to_string())
3188 })?,
3189 other => {
3190 return Err(QueryError::ParseError(format!(
3191 "ROUND scale must be an integer literal, got {other:?}"
3192 )));
3193 }
3194 };
3195 Ok(ScalarExpr::RoundScale(Box::new(scalar(arg_exprs[0])?), n))
3196 }
3197 _ => Err(QueryError::ParseError(format!(
3198 "ROUND expects 1 or 2 arguments, got {}",
3199 arg_exprs.len()
3200 ))),
3201 },
3202 "CEIL" | "CEILING" => {
3203 want_arity(1)?;
3204 Ok(ScalarExpr::Ceil(Box::new(scalar(arg_exprs[0])?)))
3205 }
3206 "FLOOR" => {
3207 want_arity(1)?;
3208 Ok(ScalarExpr::Floor(Box::new(scalar(arg_exprs[0])?)))
3209 }
3210 "COALESCE" => {
3211 if arg_exprs.is_empty() {
3212 return Err(QueryError::ParseError(
3213 "COALESCE expects at least one argument".to_string(),
3214 ));
3215 }
3216 let parts = arg_exprs
3217 .iter()
3218 .map(|e| scalar(e))
3219 .collect::<Result<Vec<_>>>()?;
3220 Ok(ScalarExpr::Coalesce(parts))
3221 }
3222 "NULLIF" => {
3223 want_arity(2)?;
3224 Ok(ScalarExpr::Nullif(
3225 Box::new(scalar(arg_exprs[0])?),
3226 Box::new(scalar(arg_exprs[1])?),
3227 ))
3228 }
3229 other => Err(QueryError::UnsupportedFeature(format!(
3230 "scalar function {other} is not supported"
3231 ))),
3232 }
3233 }
3234
3235 other => Err(QueryError::UnsupportedFeature(format!(
3236 "unsupported scalar expression: {other:?}"
3237 ))),
3238 }
3239}
3240
3241fn sql_data_type_to_data_type(sql_ty: &SqlDataType) -> Result<DataType> {
3245 Ok(match sql_ty {
3246 SqlDataType::TinyInt(_) => DataType::TinyInt,
3247 SqlDataType::SmallInt(_) => DataType::SmallInt,
3248 SqlDataType::Int(_) | SqlDataType::Integer(_) => DataType::Integer,
3249 SqlDataType::BigInt(_) => DataType::BigInt,
3250 SqlDataType::Real | SqlDataType::Float(_) | SqlDataType::Double(_) => DataType::Real,
3251 SqlDataType::Text | SqlDataType::Varchar(_) | SqlDataType::String(_) => DataType::Text,
3252 SqlDataType::Boolean | SqlDataType::Bool => DataType::Boolean,
3253 SqlDataType::Date => DataType::Date,
3254 SqlDataType::Time(_, _) => DataType::Time,
3255 SqlDataType::Timestamp(_, _) => DataType::Timestamp,
3256 SqlDataType::Uuid => DataType::Uuid,
3257 SqlDataType::JSON => DataType::Json,
3258 other => {
3259 return Err(QueryError::UnsupportedFeature(format!(
3260 "CAST to {other:?} is not supported"
3261 )));
3262 }
3263 })
3264}
3265
3266fn expr_to_predicate_value(expr: &Expr) -> Result<PredicateValue> {
3267 match expr {
3268 Expr::Identifier(ident) => {
3270 Ok(PredicateValue::ColumnRef(ident.value.clone()))
3272 }
3273 Expr::CompoundIdentifier(idents) if idents.len() == 2 => {
3274 Ok(PredicateValue::ColumnRef(format!(
3276 "{}.{}",
3277 idents[0].value, idents[1].value
3278 )))
3279 }
3280 Expr::Value(vws) => match &vws.value {
3281 SqlValue::Number(n, _) => {
3282 let value = parse_number_literal(n)?;
3283 match value {
3284 Value::BigInt(v) => Ok(PredicateValue::Int(v)),
3285 Value::Decimal(_, _) => Ok(PredicateValue::Literal(value)),
3286 _ => unreachable!("parse_number_literal only returns BigInt or Decimal"),
3287 }
3288 }
3289 SqlValue::SingleQuotedString(s) | SqlValue::DoubleQuotedString(s) => {
3290 Ok(PredicateValue::String(s.clone()))
3291 }
3292 SqlValue::Boolean(b) => Ok(PredicateValue::Bool(*b)),
3293 SqlValue::Null => Ok(PredicateValue::Null),
3294 SqlValue::Placeholder(p) => Ok(PredicateValue::Param(parse_placeholder_index(p)?)),
3295 other => Err(QueryError::UnsupportedFeature(format!(
3296 "unsupported value expression: {other:?}"
3297 ))),
3298 },
3299 Expr::UnaryOp {
3300 op: sqlparser::ast::UnaryOperator::Minus,
3301 expr,
3302 } => {
3303 if let Expr::Value(vws) = expr.as_ref()
3305 && let SqlValue::Number(n, _) = &vws.value
3306 {
3307 let value = parse_number_literal(n)?;
3308 match value {
3309 Value::BigInt(v) => Ok(PredicateValue::Int(-v)),
3310 Value::Decimal(v, scale) => {
3311 Ok(PredicateValue::Literal(Value::Decimal(-v, scale)))
3312 }
3313 _ => unreachable!("parse_number_literal only returns BigInt or Decimal"),
3314 }
3315 } else {
3316 Err(QueryError::UnsupportedFeature(format!(
3317 "unsupported unary minus operand: {expr:?}"
3318 )))
3319 }
3320 }
3321 other => Err(QueryError::UnsupportedFeature(format!(
3322 "unsupported value expression: {other:?}"
3323 ))),
3324 }
3325}
3326
3327fn parse_order_by(order_by: &sqlparser::ast::OrderBy) -> Result<Vec<OrderByClause>> {
3328 use sqlparser::ast::OrderByKind;
3329
3330 let exprs = match &order_by.kind {
3331 OrderByKind::Expressions(exprs) => exprs,
3332 OrderByKind::All(_) => {
3333 return Err(QueryError::UnsupportedFeature(
3334 "ORDER BY ALL is not supported".to_string(),
3335 ));
3336 }
3337 };
3338
3339 let mut clauses = Vec::new();
3340 for expr in exprs {
3341 clauses.push(parse_order_by_expr(expr)?);
3342 }
3343
3344 Ok(clauses)
3345}
3346
3347fn parse_order_by_expr(expr: &OrderByExpr) -> Result<OrderByClause> {
3348 let column = match &expr.expr {
3349 Expr::Identifier(ident) => ColumnName::new(ident.value.clone()),
3350 other => {
3351 return Err(QueryError::UnsupportedFeature(format!(
3352 "unsupported ORDER BY expression: {other:?}"
3353 )));
3354 }
3355 };
3356
3357 let ascending = expr.options.asc.unwrap_or(true);
3358
3359 Ok(OrderByClause { column, ascending })
3360}
3361
3362fn parse_limit(limit: Option<&Expr>) -> Result<Option<LimitExpr>> {
3363 match limit {
3364 None => Ok(None),
3365 Some(Expr::Value(vws)) => match &vws.value {
3366 SqlValue::Number(n, _) => {
3367 let v: usize = n
3368 .parse()
3369 .map_err(|_| QueryError::ParseError(format!("invalid LIMIT value: {n}")))?;
3370 Ok(Some(LimitExpr::Literal(v)))
3371 }
3372 SqlValue::Placeholder(p) => Ok(Some(LimitExpr::Param(parse_placeholder_index(p)?))),
3373 other => Err(QueryError::UnsupportedFeature(format!(
3374 "LIMIT must be an integer literal or parameter; got {other:?}"
3375 ))),
3376 },
3377 Some(other) => Err(QueryError::UnsupportedFeature(format!(
3378 "LIMIT must be an integer literal or parameter; got {other:?}"
3379 ))),
3380 }
3381}
3382
3383fn query_limit_expr(query: &Query) -> Result<Option<&Expr>> {
3386 use sqlparser::ast::LimitClause;
3387 match &query.limit_clause {
3388 None => Ok(None),
3389 Some(LimitClause::LimitOffset { limit, .. }) => Ok(limit.as_ref()),
3390 Some(LimitClause::OffsetCommaLimit { .. }) => Err(QueryError::UnsupportedFeature(
3391 "MySQL-style `LIMIT <offset>, <limit>` is not supported".to_string(),
3392 )),
3393 }
3394}
3395
3396fn query_offset(query: &Query) -> Option<&sqlparser::ast::Offset> {
3400 use sqlparser::ast::LimitClause;
3401 match &query.limit_clause {
3402 Some(LimitClause::LimitOffset { offset, .. }) => offset.as_ref(),
3403 _ => None,
3404 }
3405}
3406
3407fn parse_offset_clause(offset: Option<&sqlparser::ast::Offset>) -> Result<Option<LimitExpr>> {
3410 let Some(off) = offset else { return Ok(None) };
3411 match &off.value {
3412 Expr::Value(vws) => match &vws.value {
3413 SqlValue::Number(n, _) => {
3414 let v: usize = n
3415 .parse()
3416 .map_err(|_| QueryError::ParseError(format!("invalid OFFSET value: {n}")))?;
3417 Ok(Some(LimitExpr::Literal(v)))
3418 }
3419 SqlValue::Placeholder(p) => Ok(Some(LimitExpr::Param(parse_placeholder_index(p)?))),
3420 other => Err(QueryError::UnsupportedFeature(format!(
3421 "OFFSET must be an integer literal or parameter; got {other:?}"
3422 ))),
3423 },
3424 other => Err(QueryError::UnsupportedFeature(format!(
3425 "OFFSET must be an integer literal or parameter; got {other:?}"
3426 ))),
3427 }
3428}
3429
3430fn parse_placeholder_index(placeholder: &str) -> Result<usize> {
3436 let num_str = placeholder.strip_prefix('$').ok_or_else(|| {
3437 QueryError::ParseError(format!("unsupported placeholder format: {placeholder}"))
3438 })?;
3439 let idx: usize = num_str.parse().map_err(|_| {
3440 QueryError::ParseError(format!("invalid parameter placeholder: {placeholder}"))
3441 })?;
3442 if idx == 0 {
3443 return Err(QueryError::ParseError(
3444 "parameter indices start at $1, not $0".to_string(),
3445 ));
3446 }
3447 Ok(idx)
3448}
3449
3450fn object_name_to_string(name: &ObjectName) -> String {
3451 name.0
3452 .iter()
3453 .map(|part| match part.as_ident() {
3454 Some(ident) => ident.value.clone(),
3455 None => part.to_string(),
3456 })
3457 .collect::<Vec<_>>()
3458 .join(".")
3459}
3460
3461fn parse_create_table(create_table: &sqlparser::ast::CreateTable) -> Result<ParsedCreateTable> {
3466 let table_name = object_name_to_string(&create_table.name);
3467
3468 let mut raw_columns = Vec::new();
3478 for col_def in &create_table.columns {
3479 let parsed_col = parse_column_def(col_def)?;
3480 raw_columns.push(parsed_col);
3481 }
3482 let columns = NonEmptyVec::try_new(raw_columns).map_err(|_| {
3483 crate::error::QueryError::ParseError(format!(
3484 "CREATE TABLE {table_name} requires at least one column"
3485 ))
3486 })?;
3487
3488 let mut primary_key = Vec::new();
3490 for constraint in &create_table.constraints {
3491 if let sqlparser::ast::TableConstraint::PrimaryKey(pk) = constraint {
3492 for col in &pk.columns {
3493 if let Expr::Identifier(ident) = &col.column.expr {
3494 primary_key.push(ident.value.clone());
3495 } else {
3496 primary_key.push(col.column.expr.to_string());
3497 }
3498 }
3499 }
3500 }
3501
3502 if primary_key.is_empty() {
3504 for col_def in &create_table.columns {
3505 for option in &col_def.options {
3506 if matches!(&option.option, sqlparser::ast::ColumnOption::PrimaryKey(_)) {
3507 primary_key.push(col_def.name.value.clone());
3508 }
3509 }
3510 }
3511 }
3512
3513 Ok(ParsedCreateTable {
3514 table_name,
3515 columns,
3516 primary_key,
3517 if_not_exists: create_table.if_not_exists,
3518 })
3519}
3520
3521fn parse_column_def(col_def: &SqlColumnDef) -> Result<ParsedColumn> {
3522 let name = col_def.name.value.clone();
3523
3524 let data_type = match &col_def.data_type {
3527 SqlDataType::TinyInt(_) => "TINYINT".to_string(),
3529 SqlDataType::SmallInt(_) => "SMALLINT".to_string(),
3530 SqlDataType::Int(_) | SqlDataType::Integer(_) => "INTEGER".to_string(),
3531 SqlDataType::BigInt(_) => "BIGINT".to_string(),
3532
3533 SqlDataType::Real | SqlDataType::Float(_) | SqlDataType::Double(_) => "REAL".to_string(),
3535 SqlDataType::Decimal(precision_opt) => match precision_opt {
3536 sqlparser::ast::ExactNumberInfo::PrecisionAndScale(p, s) => {
3537 format!("DECIMAL({p},{s})")
3538 }
3539 sqlparser::ast::ExactNumberInfo::Precision(p) => {
3540 format!("DECIMAL({p},0)")
3541 }
3542 sqlparser::ast::ExactNumberInfo::None => "DECIMAL(18,2)".to_string(),
3543 },
3544
3545 SqlDataType::Text | SqlDataType::Varchar(_) | SqlDataType::String(_) => "TEXT".to_string(),
3547
3548 SqlDataType::Binary(_) | SqlDataType::Varbinary(_) | SqlDataType::Blob(_) => {
3550 "BYTES".to_string()
3551 }
3552
3553 SqlDataType::Boolean | SqlDataType::Bool => "BOOLEAN".to_string(),
3555
3556 SqlDataType::Date => "DATE".to_string(),
3558 SqlDataType::Time(_, _) => "TIME".to_string(),
3559 SqlDataType::Timestamp(_, _) => "TIMESTAMP".to_string(),
3560
3561 SqlDataType::Uuid => "UUID".to_string(),
3563 SqlDataType::JSON => "JSON".to_string(),
3564
3565 other => {
3566 return Err(QueryError::UnsupportedFeature(format!(
3567 "unsupported data type: {other:?}"
3568 )));
3569 }
3570 };
3571
3572 let mut nullable = true;
3574 for option in &col_def.options {
3575 if matches!(option.option, sqlparser::ast::ColumnOption::NotNull) {
3576 nullable = false;
3577 }
3578 }
3579
3580 Ok(ParsedColumn {
3581 name,
3582 data_type,
3583 nullable,
3584 })
3585}
3586
3587fn parse_alter_table(
3588 name: &sqlparser::ast::ObjectName,
3589 operations: &[sqlparser::ast::AlterTableOperation],
3590) -> Result<ParsedAlterTable> {
3591 let table_name = object_name_to_string(name);
3592
3593 if operations.len() != 1 {
3595 return Err(QueryError::UnsupportedFeature(
3596 "ALTER TABLE supports only one operation at a time".to_string(),
3597 ));
3598 }
3599
3600 let operation = match &operations[0] {
3601 sqlparser::ast::AlterTableOperation::AddColumn { column_def, .. } => {
3602 let parsed_col = parse_column_def(column_def)?;
3603 AlterTableOperation::AddColumn(parsed_col)
3604 }
3605 sqlparser::ast::AlterTableOperation::DropColumn {
3606 column_names,
3607 if_exists: _,
3608 ..
3609 } => {
3610 if column_names.len() != 1 {
3611 return Err(QueryError::UnsupportedFeature(
3612 "ALTER TABLE DROP COLUMN supports exactly one column".to_string(),
3613 ));
3614 }
3615 let col_name = column_names[0].value.clone();
3616 AlterTableOperation::DropColumn(col_name)
3617 }
3618 other => {
3619 return Err(QueryError::UnsupportedFeature(format!(
3620 "ALTER TABLE operation not supported: {other:?}"
3621 )));
3622 }
3623 };
3624
3625 Ok(ParsedAlterTable {
3626 table_name,
3627 operation,
3628 })
3629}
3630
3631fn parse_create_index(create_index: &sqlparser::ast::CreateIndex) -> Result<ParsedCreateIndex> {
3632 let index_name = match &create_index.name {
3633 Some(name) => object_name_to_string(name),
3634 None => {
3635 return Err(QueryError::ParseError(
3636 "CREATE INDEX requires an index name".to_string(),
3637 ));
3638 }
3639 };
3640
3641 let table_name = object_name_to_string(&create_index.table_name);
3642
3643 let mut columns = Vec::new();
3644 for col in &create_index.columns {
3645 columns.push(col.column.expr.to_string());
3646 }
3647
3648 Ok(ParsedCreateIndex {
3649 index_name,
3650 table_name,
3651 columns,
3652 })
3653}
3654
3655fn parse_insert(insert: &sqlparser::ast::Insert) -> Result<ParsedInsert> {
3660 let table = insert.table.to_string();
3662
3663 let columns: Vec<String> = insert.columns.iter().map(|c| c.value.clone()).collect();
3665
3666 let values = match insert.source.as_ref().map(|s| s.body.as_ref()) {
3668 Some(SetExpr::Values(values)) => {
3669 let mut all_rows = Vec::new();
3670 for row in &values.rows {
3671 let mut parsed_row = Vec::new();
3672 for expr in row {
3673 let val = expr_to_value(expr)?;
3674 parsed_row.push(val);
3675 }
3676 all_rows.push(parsed_row);
3677 }
3678 all_rows
3679 }
3680 _ => {
3681 return Err(QueryError::UnsupportedFeature(
3682 "only VALUES clause is supported in INSERT".to_string(),
3683 ));
3684 }
3685 };
3686
3687 let returning = parse_returning(insert.returning.as_ref())?;
3689
3690 let on_conflict = match insert.on.as_ref() {
3695 None => None,
3696 Some(sqlparser::ast::OnInsert::OnConflict(oc)) => Some(parse_on_conflict(oc)?),
3697 Some(sqlparser::ast::OnInsert::DuplicateKeyUpdate(_)) => {
3698 return Err(QueryError::UnsupportedFeature(
3699 "ON DUPLICATE KEY UPDATE is not supported; use ON CONFLICT (cols) DO UPDATE"
3700 .to_string(),
3701 ));
3702 }
3703 Some(other) => {
3704 return Err(QueryError::UnsupportedFeature(format!(
3705 "unsupported ON clause on INSERT: {other:?}"
3706 )));
3707 }
3708 };
3709
3710 Ok(ParsedInsert {
3711 table,
3712 columns,
3713 values,
3714 returning,
3715 on_conflict,
3716 })
3717}
3718
3719fn parse_on_conflict(oc: &sqlparser::ast::OnConflict) -> Result<OnConflictClause> {
3726 let target = match oc.conflict_target.as_ref() {
3728 Some(sqlparser::ast::ConflictTarget::Columns(cols)) => {
3729 if cols.is_empty() {
3730 return Err(QueryError::ParseError(
3731 "ON CONFLICT requires at least one target column".to_string(),
3732 ));
3733 }
3734 cols.iter().map(|i| i.value.clone()).collect()
3735 }
3736 Some(sqlparser::ast::ConflictTarget::OnConstraint(_)) => {
3737 return Err(QueryError::UnsupportedFeature(
3738 "ON CONFLICT ON CONSTRAINT <name> is not supported; use ON CONFLICT (cols) instead"
3739 .to_string(),
3740 ));
3741 }
3742 None => {
3743 return Err(QueryError::UnsupportedFeature(
3744 "ON CONFLICT without a target column list is not supported".to_string(),
3745 ));
3746 }
3747 };
3748
3749 let action = match &oc.action {
3750 sqlparser::ast::OnConflictAction::DoNothing => OnConflictAction::DoNothing,
3751 sqlparser::ast::OnConflictAction::DoUpdate(du) => {
3752 if du.selection.is_some() {
3753 return Err(QueryError::UnsupportedFeature(
3754 "ON CONFLICT DO UPDATE WHERE ... is not yet supported".to_string(),
3755 ));
3756 }
3757 let mut assignments = Vec::with_capacity(du.assignments.len());
3758 for a in &du.assignments {
3759 let col = a.target.to_string();
3760 let rhs = parse_upsert_expr(&a.value)?;
3761 assignments.push((col, rhs));
3762 }
3763 OnConflictAction::DoUpdate { assignments }
3764 }
3765 };
3766
3767 Ok(OnConflictClause { target, action })
3768}
3769
3770fn parse_upsert_expr(expr: &Expr) -> Result<UpsertExpr> {
3777 if let Expr::CompoundIdentifier(parts) = expr {
3779 if parts.len() == 2 && parts[0].value.eq_ignore_ascii_case("EXCLUDED") {
3780 return Ok(UpsertExpr::Excluded(parts[1].value.clone()));
3781 }
3782 }
3783 let v = expr_to_value(expr)?;
3786 Ok(UpsertExpr::Value(v))
3787}
3788
3789fn parse_update(
3790 table: &sqlparser::ast::TableWithJoins,
3791 assignments: &[sqlparser::ast::Assignment],
3792 selection: Option<&Expr>,
3793 returning: Option<&Vec<SelectItem>>,
3794) -> Result<ParsedUpdate> {
3795 let table_name = match &table.relation {
3796 sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
3797 other => {
3798 return Err(QueryError::UnsupportedFeature(format!(
3799 "unsupported table in UPDATE: {other:?}"
3800 )));
3801 }
3802 };
3803
3804 let mut parsed_assignments = Vec::new();
3806 for assignment in assignments {
3807 let col_name = assignment.target.to_string();
3808 let value = expr_to_value(&assignment.value)?;
3809 parsed_assignments.push((col_name, value));
3810 }
3811
3812 let predicates = match selection {
3814 Some(expr) => parse_where_expr(expr)?,
3815 None => vec![],
3816 };
3817
3818 let returning_cols = parse_returning(returning)?;
3820
3821 Ok(ParsedUpdate {
3822 table: table_name,
3823 assignments: parsed_assignments,
3824 predicates,
3825 returning: returning_cols,
3826 })
3827}
3828
3829fn parse_delete_stmt(delete: &sqlparser::ast::Delete) -> Result<ParsedDelete> {
3830 use sqlparser::ast::FromTable;
3832
3833 let table_name = match &delete.from {
3834 FromTable::WithFromKeyword(tables) => {
3835 if tables.len() != 1 {
3836 return Err(QueryError::ParseError(
3837 "expected exactly 1 table in DELETE FROM".to_string(),
3838 ));
3839 }
3840
3841 match &tables[0].relation {
3842 sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
3843 _ => {
3844 return Err(QueryError::ParseError(
3845 "DELETE only supports simple table names".to_string(),
3846 ));
3847 }
3848 }
3849 }
3850 FromTable::WithoutKeyword(tables) => {
3851 if tables.len() != 1 {
3852 return Err(QueryError::ParseError(
3853 "expected exactly 1 table in DELETE".to_string(),
3854 ));
3855 }
3856
3857 match &tables[0].relation {
3858 sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
3859 _ => {
3860 return Err(QueryError::ParseError(
3861 "DELETE only supports simple table names".to_string(),
3862 ));
3863 }
3864 }
3865 }
3866 };
3867
3868 let predicates = match &delete.selection {
3870 Some(expr) => parse_where_expr(expr)?,
3871 None => vec![],
3872 };
3873
3874 let returning_cols = parse_returning(delete.returning.as_ref())?;
3876
3877 Ok(ParsedDelete {
3878 table: table_name,
3879 predicates,
3880 returning: returning_cols,
3881 })
3882}
3883
3884fn parse_returning(returning: Option<&Vec<SelectItem>>) -> Result<Option<Vec<String>>> {
3886 match returning {
3887 None => Ok(None),
3888 Some(items) => {
3889 let mut columns = Vec::new();
3890 for item in items {
3891 match item {
3892 SelectItem::UnnamedExpr(Expr::Identifier(ident)) => {
3893 columns.push(ident.value.clone());
3894 }
3895 SelectItem::UnnamedExpr(Expr::CompoundIdentifier(parts)) => {
3896 if let Some(last) = parts.last() {
3898 columns.push(last.value.clone());
3899 } else {
3900 return Err(QueryError::ParseError(
3901 "invalid column in RETURNING clause".to_string(),
3902 ));
3903 }
3904 }
3905 _ => {
3906 return Err(QueryError::UnsupportedFeature(
3907 "only simple column names supported in RETURNING clause".to_string(),
3908 ));
3909 }
3910 }
3911 }
3912 Ok(Some(columns))
3913 }
3914 }
3915}
3916
3917fn parse_number_literal(n: &str) -> Result<Value> {
3921 use rust_decimal::Decimal;
3922 use std::str::FromStr;
3923
3924 if n.contains('.') {
3925 let decimal = Decimal::from_str(n)
3927 .map_err(|e| QueryError::ParseError(format!("invalid decimal '{n}': {e}")))?;
3928
3929 let scale = decimal.scale() as u8;
3931
3932 if scale > 38 {
3933 return Err(QueryError::ParseError(format!(
3934 "decimal scale too large (max 38): {n}"
3935 )));
3936 }
3937
3938 let mantissa = decimal.mantissa();
3941
3942 Ok(Value::Decimal(mantissa, scale))
3943 } else {
3944 let v: i64 = n
3946 .parse()
3947 .map_err(|_| QueryError::ParseError(format!("invalid integer: {n}")))?;
3948 Ok(Value::BigInt(v))
3949 }
3950}
3951
3952fn expr_to_value(expr: &Expr) -> Result<Value> {
3954 match expr {
3955 Expr::Value(vws) => match &vws.value {
3956 SqlValue::Number(n, _) => parse_number_literal(n),
3957 SqlValue::SingleQuotedString(s) | SqlValue::DoubleQuotedString(s) => {
3958 Ok(Value::Text(s.clone()))
3959 }
3960 SqlValue::Boolean(b) => Ok(Value::Boolean(*b)),
3961 SqlValue::Null => Ok(Value::Null),
3962 SqlValue::Placeholder(p) => Ok(Value::Placeholder(parse_placeholder_index(p)?)),
3963 other => Err(QueryError::UnsupportedFeature(format!(
3964 "unsupported value expression: {other:?}"
3965 ))),
3966 },
3967 Expr::UnaryOp {
3968 op: sqlparser::ast::UnaryOperator::Minus,
3969 expr,
3970 } => {
3971 if let Expr::Value(vws) = expr.as_ref()
3973 && let SqlValue::Number(n, _) = &vws.value
3974 {
3975 let value = parse_number_literal(n)?;
3976 match value {
3977 Value::BigInt(v) => Ok(Value::BigInt(-v)),
3978 Value::Decimal(v, scale) => Ok(Value::Decimal(-v, scale)),
3979 _ => unreachable!("parse_number_literal only returns BigInt or Decimal"),
3980 }
3981 } else {
3982 Err(QueryError::UnsupportedFeature(format!(
3983 "unsupported unary minus operand: {expr:?}"
3984 )))
3985 }
3986 }
3987 other => Err(QueryError::UnsupportedFeature(format!(
3988 "unsupported value expression: {other:?}"
3989 ))),
3990 }
3991}
3992
3993#[cfg(test)]
3994mod tests {
3995 use super::*;
3996
3997 fn parse_test_select(sql: &str) -> ParsedSelect {
3998 match parse_statement(sql).unwrap() {
3999 ParsedStatement::Select(s) => s,
4000 _ => panic!("expected SELECT statement"),
4001 }
4002 }
4003
4004 #[test]
4005 fn test_parse_simple_select() {
4006 let result = parse_test_select("SELECT id, name FROM users");
4007 assert_eq!(result.table, "users");
4008 assert_eq!(
4009 result.columns,
4010 Some(vec![ColumnName::new("id"), ColumnName::new("name")])
4011 );
4012 assert!(result.predicates.is_empty());
4013 }
4014
4015 #[test]
4016 fn test_parse_select_star() {
4017 let result = parse_test_select("SELECT * FROM users");
4018 assert_eq!(result.table, "users");
4019 assert!(result.columns.is_none());
4020 }
4021
4022 #[test]
4023 fn test_parse_where_eq() {
4024 let result = parse_test_select("SELECT * FROM users WHERE id = 42");
4025 assert_eq!(result.predicates.len(), 1);
4026 match &result.predicates[0] {
4027 Predicate::Eq(col, PredicateValue::Int(42)) => {
4028 assert_eq!(col.as_str(), "id");
4029 }
4030 other => panic!("unexpected predicate: {other:?}"),
4031 }
4032 }
4033
4034 #[test]
4035 fn test_parse_where_string() {
4036 let result = parse_test_select("SELECT * FROM users WHERE name = 'alice'");
4037 match &result.predicates[0] {
4038 Predicate::Eq(col, PredicateValue::String(s)) => {
4039 assert_eq!(col.as_str(), "name");
4040 assert_eq!(s, "alice");
4041 }
4042 other => panic!("unexpected predicate: {other:?}"),
4043 }
4044 }
4045
4046 #[test]
4047 fn test_parse_where_and() {
4048 let result = parse_test_select("SELECT * FROM users WHERE id = 1 AND name = 'bob'");
4049 assert_eq!(result.predicates.len(), 2);
4050 }
4051
4052 #[test]
4053 fn test_parse_where_in() {
4054 let result = parse_test_select("SELECT * FROM users WHERE id IN (1, 2, 3)");
4055 match &result.predicates[0] {
4056 Predicate::In(col, values) => {
4057 assert_eq!(col.as_str(), "id");
4058 assert_eq!(values.len(), 3);
4059 }
4060 other => panic!("unexpected predicate: {other:?}"),
4061 }
4062 }
4063
4064 #[test]
4065 fn test_parse_order_by() {
4066 let result = parse_test_select("SELECT * FROM users ORDER BY name ASC, id DESC");
4067 assert_eq!(result.order_by.len(), 2);
4068 assert_eq!(result.order_by[0].column.as_str(), "name");
4069 assert!(result.order_by[0].ascending);
4070 assert_eq!(result.order_by[1].column.as_str(), "id");
4071 assert!(!result.order_by[1].ascending);
4072 }
4073
4074 #[test]
4075 fn test_parse_limit() {
4076 let result = parse_test_select("SELECT * FROM users LIMIT 10");
4077 assert_eq!(result.limit, Some(LimitExpr::Literal(10)));
4078 }
4079
4080 #[test]
4081 fn test_parse_limit_param() {
4082 let result = parse_test_select("SELECT * FROM users LIMIT $1");
4083 assert_eq!(result.limit, Some(LimitExpr::Param(1)));
4084 }
4085
4086 #[test]
4087 fn test_parse_offset_literal() {
4088 let result = parse_test_select("SELECT * FROM users LIMIT 10 OFFSET 5");
4089 assert_eq!(result.limit, Some(LimitExpr::Literal(10)));
4090 assert_eq!(result.offset, Some(LimitExpr::Literal(5)));
4091 }
4092
4093 #[test]
4094 fn test_parse_offset_param() {
4095 let result = parse_test_select("SELECT * FROM users LIMIT $1 OFFSET $2");
4096 assert_eq!(result.limit, Some(LimitExpr::Param(1)));
4097 assert_eq!(result.offset, Some(LimitExpr::Param(2)));
4098 }
4099
4100 #[test]
4101 fn test_parse_param() {
4102 let result = parse_test_select("SELECT * FROM users WHERE id = $1");
4103 match &result.predicates[0] {
4104 Predicate::Eq(_, PredicateValue::Param(1)) => {}
4105 other => panic!("unexpected predicate: {other:?}"),
4106 }
4107 }
4108
4109 #[test]
4110 fn test_parse_inner_join() {
4111 let result =
4112 parse_statement("SELECT * FROM users JOIN orders ON users.id = orders.user_id");
4113 if let Err(ref e) = result {
4114 eprintln!("Parse error: {e:?}");
4115 }
4116 assert!(result.is_ok());
4117 match result.unwrap() {
4118 ParsedStatement::Select(s) => {
4119 assert_eq!(s.table, "users");
4120 assert_eq!(s.joins.len(), 1);
4121 assert_eq!(s.joins[0].table, "orders");
4122 assert!(matches!(s.joins[0].join_type, JoinType::Inner));
4123 }
4124 _ => panic!("expected SELECT statement"),
4125 }
4126 }
4127
/// Checks that `LEFT JOIN ... ON` parses into one `JoinType::Left` join on `orders`.
#[test]
fn test_parse_left_join() {
    let statement =
        parse_statement("SELECT * FROM users LEFT JOIN orders ON users.id = orders.user_id")
            // Surface the parser error on failure rather than a bare `is_ok()` assertion.
            .unwrap_or_else(|e| panic!("LEFT JOIN should parse, got error: {e:?}"));
    match statement {
        ParsedStatement::Select(select) => {
            assert_eq!(select.table, "users");
            assert_eq!(select.joins.len(), 1);
            assert_eq!(select.joins[0].table, "orders");
            assert!(matches!(select.joins[0].join_type, JoinType::Left));
        }
        // Report the variant that was actually produced to ease debugging.
        other => panic!("expected SELECT statement, got {other:?}"),
    }
}
4143
/// Checks that two chained `JOIN` clauses are recorded in source order.
#[test]
fn test_parse_multi_join() {
    let statement = parse_statement(
        "SELECT * FROM users \
         JOIN orders ON users.id = orders.user_id \
         JOIN products ON orders.product_id = products.id",
    )
    // Surface the parser error on failure rather than a bare `is_ok()` assertion.
    .unwrap_or_else(|e| panic!("multi-join query should parse, got error: {e:?}"));
    match statement {
        ParsedStatement::Select(select) => {
            assert_eq!(select.table, "users");
            assert_eq!(select.joins.len(), 2);
            assert_eq!(select.joins[0].table, "orders");
            assert_eq!(select.joins[1].table, "products");
        }
        // Report the variant that was actually produced to ease debugging.
        other => panic!("expected SELECT statement, got {other:?}"),
    }
}
4162
/// Checks that a subquery in the `FROM` position is rejected.
#[test]
fn test_reject_subquery() {
    let sql = "SELECT * FROM (SELECT * FROM users)";
    let outcome = parse_statement(sql);
    assert!(outcome.is_err());
}
4168
/// Checks that ten parenthesised equality conditions joined by `AND` parse successfully.
#[test]
fn test_where_depth_within_limit() {
    // Produces `(id = 0) AND (id = 1) AND ... AND (id = 9)`.
    let conditions: Vec<String> = (0..10).map(|i| format!("(id = {i})")).collect();
    let sql = format!("SELECT * FROM users WHERE {}", conditions.join(" AND "));

    let outcome = parse_statement(&sql);
    assert!(
        outcome.is_ok(),
        "Moderate nesting should succeed, but got: {outcome:?}"
    );
}
4190
/// Checks that a predicate wrapped in 200 levels of parentheses is rejected.
#[test]
fn test_where_depth_nested_parens() {
    let opening = "(".repeat(200);
    let closing = ")".repeat(200);
    let sql = format!("SELECT * FROM users WHERE {opening}id = 1{closing}");

    let outcome = parse_statement(&sql);
    assert!(
        outcome.is_err(),
        "Excessive parenthesis nesting should be rejected"
    );
}
4210
/// Checks that a two-level mix of `AND` / `OR` groups parses successfully.
#[test]
fn test_where_depth_complex_and_or() {
    let sql = concat!(
        "SELECT * FROM users WHERE ",
        "((id = 1 AND name = 'a') OR (id = 2 AND name = 'b')) AND ",
        "((age > 10 AND age < 20) OR (age > 30 AND age < 40))",
    );

    let outcome = parse_statement(sql);
    assert!(outcome.is_ok(), "Complex AND/OR should succeed");
}
4221
/// Checks that `HAVING COUNT(*) > 5` yields one aggregate comparison with the
/// expected aggregate, operator and literal.
#[test]
fn test_parse_having() {
    let select =
        parse_test_select("SELECT name, COUNT(*) FROM users GROUP BY name HAVING COUNT(*) > 5");
    assert_eq!(select.group_by.len(), 1);
    assert_eq!(select.having.len(), 1);

    // Irrefutable destructuring: equivalent to a `match` with this single arm.
    let HavingCondition::AggregateComparison {
        aggregate,
        op,
        value,
    } = &select.having[0];
    assert!(matches!(aggregate, AggregateFunction::CountStar));
    assert_eq!(*op, HavingOp::Gt);
    assert_eq!(*value, Value::BigInt(5));
}
4240
/// Checks that two `HAVING` conditions joined by `AND` are both collected.
#[test]
fn test_parse_having_multiple() {
    let sql = "SELECT name, COUNT(*), SUM(age) FROM users GROUP BY name HAVING COUNT(*) > 1 AND SUM(age) < 100";
    let select = parse_test_select(sql);
    assert_eq!(select.having.len(), 2);
}
4248
/// Checks that `HAVING` is accepted without `GROUP BY`, leaving `group_by` empty.
#[test]
fn test_parse_having_without_group_by() {
    let sql = "SELECT COUNT(*) FROM users HAVING COUNT(*) > 0";
    let select = parse_test_select(sql);
    assert!(select.group_by.is_empty());
    assert_eq!(select.having.len(), 1);
}
4255
/// Checks that `UNION` parses into a `ParsedStatement::Union` with `all == false`.
#[test]
fn test_parse_union() {
    let statement = parse_statement("SELECT id FROM users UNION SELECT id FROM orders")
        // Surface the parser error on failure rather than a bare `is_ok()` assertion.
        .unwrap_or_else(|e| panic!("UNION should parse, got error: {e:?}"));
    match statement {
        ParsedStatement::Union(union) => {
            assert_eq!(union.left.table, "users");
            assert_eq!(union.right.table, "orders");
            assert!(!union.all);
        }
        // Report the variant that was actually produced to ease debugging.
        other => panic!("expected UNION statement, got {other:?}"),
    }
}
4269
/// Checks that `UNION ALL` parses into a `ParsedStatement::Union` with `all == true`.
#[test]
fn test_parse_union_all() {
    let statement = parse_statement("SELECT id FROM users UNION ALL SELECT id FROM orders")
        // Surface the parser error on failure rather than a bare `is_ok()` assertion.
        .unwrap_or_else(|e| panic!("UNION ALL should parse, got error: {e:?}"));
    match statement {
        ParsedStatement::Union(union) => {
            assert_eq!(union.left.table, "users");
            assert_eq!(union.right.table, "orders");
            assert!(union.all);
        }
        // Report the variant that was actually produced to ease debugging.
        other => panic!("expected UNION ALL statement, got {other:?}"),
    }
}
4283
/// Checks that `CREATE MASK` captures the mask name, table, column and strategy.
#[test]
fn test_parse_create_mask() {
    let statement =
        parse_statement("CREATE MASK ssn_mask ON patients.ssn USING REDACT").unwrap();
    let ParsedStatement::CreateMask(mask) = statement else {
        panic!("expected CREATE MASK statement");
    };
    assert_eq!(mask.mask_name, "ssn_mask");
    assert_eq!(mask.table_name, "patients");
    assert_eq!(mask.column_name, "ssn");
    assert_eq!(mask.strategy, "REDACT");
}
4297
/// Checks that a trailing semicolon does not leak into the parsed `CREATE MASK` fields.
#[test]
fn test_parse_create_mask_with_semicolon() {
    let statement =
        parse_statement("CREATE MASK ssn_mask ON patients.ssn USING REDACT;").unwrap();
    if let ParsedStatement::CreateMask(mask) = statement {
        assert_eq!(mask.mask_name, "ssn_mask");
        assert_eq!(mask.strategy, "REDACT");
    } else {
        panic!("expected CREATE MASK statement");
    }
}
4309
/// Checks that `CREATE MASK ... USING HASH` records `HASH` as the strategy.
#[test]
fn test_parse_create_mask_hash_strategy() {
    let statement =
        parse_statement("CREATE MASK email_hash ON users.email USING HASH").unwrap();
    let ParsedStatement::CreateMask(mask) = statement else {
        panic!("expected CREATE MASK statement");
    };
    assert_eq!(mask.mask_name, "email_hash");
    assert_eq!(mask.table_name, "users");
    assert_eq!(mask.column_name, "email");
    assert_eq!(mask.strategy, "HASH");
}
4323
/// Checks that `CREATE MASK` without the `ON` keyword is rejected.
#[test]
fn test_parse_create_mask_missing_on() {
    let sql = "CREATE MASK ssn_mask patients.ssn USING REDACT";
    let outcome = parse_statement(sql);
    assert!(outcome.is_err());
}
4329
/// Checks that a `CREATE MASK` target without a `table.column` dot is rejected.
#[test]
fn test_parse_create_mask_missing_dot() {
    let sql = "CREATE MASK ssn_mask ON patients_ssn USING REDACT";
    let outcome = parse_statement(sql);
    assert!(outcome.is_err());
}
4335
/// Checks that `DROP MASK` yields the mask name.
#[test]
fn test_parse_drop_mask() {
    let statement = parse_statement("DROP MASK ssn_mask").unwrap();
    if let ParsedStatement::DropMask(mask_name) = statement {
        assert_eq!(mask_name, "ssn_mask");
    } else {
        panic!("expected DROP MASK statement");
    }
}
4346
/// Checks that a trailing semicolon is not included in the dropped mask's name.
#[test]
fn test_parse_drop_mask_with_semicolon() {
    let statement = parse_statement("DROP MASK ssn_mask;").unwrap();
    let ParsedStatement::DropMask(mask_name) = statement else {
        panic!("expected DROP MASK statement");
    };
    assert_eq!(mask_name, "ssn_mask");
}
4357
/// Checks name, `REDACT_SSN` strategy and a two-entry exempt-role list.
#[test]
fn test_parse_create_masking_policy_redact_ssn() {
    let sql = "CREATE MASKING POLICY ssn_policy STRATEGY REDACT_SSN EXEMPT ROLES ('clinician', 'billing')";
    let policy = match parse_statement(sql).unwrap() {
        ParsedStatement::CreateMaskingPolicy(policy) => policy,
        other => panic!("expected CreateMaskingPolicy, got {other:?}"),
    };
    assert_eq!(policy.name, "ssn_policy");
    assert_eq!(policy.strategy, ParsedMaskingStrategy::RedactSsn);
    assert_eq!(policy.exempt_roles, vec!["clinician", "billing"]);
}
4377
/// Checks the `HASH` strategy with a single, unquoted exempt role.
#[test]
fn test_parse_create_masking_policy_hash_single_role() {
    let sql = "CREATE MASKING POLICY h STRATEGY HASH EXEMPT ROLES (admin)";
    let policy = match parse_statement(sql).unwrap() {
        ParsedStatement::CreateMaskingPolicy(policy) => policy,
        other => panic!("expected CreateMaskingPolicy, got {other:?}"),
    };
    assert_eq!(policy.name, "h");
    assert_eq!(policy.strategy, ParsedMaskingStrategy::Hash);
    assert_eq!(policy.exempt_roles, vec!["admin"]);
}
4391
/// Checks the `TOKENIZE` strategy on a statement terminated by a semicolon.
#[test]
fn test_parse_create_masking_policy_tokenize() {
    let sql = "CREATE MASKING POLICY note_tok STRATEGY TOKENIZE EXEMPT ROLES ('clinician');";
    let policy = match parse_statement(sql).unwrap() {
        ParsedStatement::CreateMaskingPolicy(policy) => policy,
        other => panic!("expected CreateMaskingPolicy, got {other:?}"),
    };
    assert_eq!(policy.strategy, ParsedMaskingStrategy::Tokenize);
    assert_eq!(policy.exempt_roles, vec!["clinician"]);
}
4406
/// Checks that `TRUNCATE 4` carries its numeric argument as `max_chars`.
#[test]
fn test_parse_create_masking_policy_truncate_with_arg() {
    let sql = "CREATE MASKING POLICY tr STRATEGY TRUNCATE 4 EXEMPT ROLES ('billing')";
    let policy = match parse_statement(sql).unwrap() {
        ParsedStatement::CreateMaskingPolicy(policy) => policy,
        other => panic!("expected CreateMaskingPolicy, got {other:?}"),
    };
    assert_eq!(
        policy.strategy,
        ParsedMaskingStrategy::Truncate { max_chars: 4 }
    );
}
4420
/// Checks that `REDACT_CUSTOM '***'` carries the replacement string.
#[test]
fn test_parse_create_masking_policy_redact_custom() {
    let sql = "CREATE MASKING POLICY c STRATEGY REDACT_CUSTOM '***' EXEMPT ROLES ('admin')";
    let policy = match parse_statement(sql).unwrap() {
        ParsedStatement::CreateMaskingPolicy(policy) => policy,
        other => panic!("expected CreateMaskingPolicy, got {other:?}"),
    };
    match policy.strategy {
        ParsedMaskingStrategy::RedactCustom { replacement } => assert_eq!(replacement, "***"),
        other => panic!("expected RedactCustom, got {other:?}"),
    }
}
4437
/// Checks that the `NULL` strategy keyword maps to `ParsedMaskingStrategy::Null`.
#[test]
fn test_parse_create_masking_policy_null_strategy() {
    let sql = "CREATE MASKING POLICY n STRATEGY NULL EXEMPT ROLES ('auditor')";
    let policy = match parse_statement(sql).unwrap() {
        ParsedStatement::CreateMaskingPolicy(policy) => policy,
        other => panic!("expected CreateMaskingPolicy, got {other:?}"),
    };
    assert_eq!(policy.strategy, ParsedMaskingStrategy::Null);
}
4450
/// Checks that mixed-case exempt role names come back lower-cased.
#[test]
fn test_parse_create_masking_policy_lowercases_roles() {
    let sql = "CREATE MASKING POLICY p STRATEGY HASH EXEMPT ROLES ('Clinician', 'NURSE')";
    let policy = match parse_statement(sql).unwrap() {
        ParsedStatement::CreateMaskingPolicy(policy) => policy,
        other => panic!("expected CreateMaskingPolicy, got {other:?}"),
    };
    assert_eq!(policy.exempt_roles, vec!["clinician", "nurse"]);
}
4466
/// Checks that an unrecognised strategy keyword (`SCRAMBLE`) is rejected.
#[test]
fn test_parse_create_masking_policy_rejects_unknown_strategy() {
    let sql = "CREATE MASKING POLICY p STRATEGY SCRAMBLE EXEMPT ROLES ('x')";
    let outcome = parse_statement(sql);
    assert!(outcome.is_err(), "expected unknown-strategy error");
}
4473
/// Checks that `TRUNCATE 0` is rejected.
#[test]
fn test_parse_create_masking_policy_rejects_zero_truncate() {
    let sql = "CREATE MASKING POLICY p STRATEGY TRUNCATE 0 EXEMPT ROLES ('x')";
    let outcome = parse_statement(sql);
    assert!(outcome.is_err(), "TRUNCATE 0 must be rejected");
}
4480
/// Checks that an empty `EXEMPT ROLES ()` list is rejected.
#[test]
fn test_parse_create_masking_policy_rejects_empty_exempt_list() {
    let sql = "CREATE MASKING POLICY p STRATEGY HASH EXEMPT ROLES ()";
    let outcome = parse_statement(sql);
    assert!(outcome.is_err(), "empty EXEMPT ROLES list must be rejected");
}
4486
/// Checks that omitting the `EXEMPT ROLES` clause entirely is rejected.
#[test]
fn test_parse_create_masking_policy_rejects_missing_exempt_roles() {
    let sql = "CREATE MASKING POLICY p STRATEGY HASH";
    let outcome = parse_statement(sql);
    assert!(
        outcome.is_err(),
        "missing EXEMPT ROLES clause must be rejected"
    );
}
4495
/// Checks that `DROP MASKING POLICY` yields the policy name.
#[test]
fn test_parse_drop_masking_policy() {
    let policy_name = match parse_statement("DROP MASKING POLICY ssn_policy").unwrap() {
        ParsedStatement::DropMaskingPolicy(policy_name) => policy_name,
        other => panic!("expected DropMaskingPolicy, got {other:?}"),
    };
    assert_eq!(policy_name, "ssn_policy");
}
4506
/// Checks that a trailing semicolon is not included in the dropped policy's name.
#[test]
fn test_parse_drop_masking_policy_with_semicolon() {
    let policy_name = match parse_statement("DROP MASKING POLICY ssn_policy;").unwrap() {
        ParsedStatement::DropMaskingPolicy(policy_name) => policy_name,
        other => panic!("expected DropMaskingPolicy, got {other:?}"),
    };
    assert_eq!(policy_name, "ssn_policy");
}
4517
/// Checks that `DROP MASK` still resolves to `DropMask`, not the masking-policy variant.
#[test]
fn test_parse_drop_masking_policy_does_not_swallow_drop_mask() {
    let sql = "DROP MASK ssn_mask";
    let statement = parse_statement(sql).unwrap();
    assert!(matches!(statement, ParsedStatement::DropMask(_)));
}
4526
/// Checks that `ALTER COLUMN ... SET MASKING POLICY` captures table, column and policy.
#[test]
fn test_parse_attach_masking_policy() {
    let sql = "ALTER TABLE patients ALTER COLUMN medicare_number SET MASKING POLICY ssn_policy";
    let attach = match parse_statement(sql).unwrap() {
        ParsedStatement::AttachMaskingPolicy(attach) => attach,
        other => panic!("expected AttachMaskingPolicy, got {other:?}"),
    };
    assert_eq!(attach.table_name, "patients");
    assert_eq!(attach.column_name, "medicare_number");
    assert_eq!(attach.policy_name, "ssn_policy");
}
4542
/// Checks that `ALTER COLUMN ... DROP MASKING POLICY` captures table and column.
#[test]
fn test_parse_detach_masking_policy() {
    let sql = "ALTER TABLE patients ALTER COLUMN medicare_number DROP MASKING POLICY";
    let detach = match parse_statement(sql).unwrap() {
        ParsedStatement::DetachMaskingPolicy(detach) => detach,
        other => panic!("expected DetachMaskingPolicy, got {other:?}"),
    };
    assert_eq!(detach.table_name, "patients");
    assert_eq!(detach.column_name, "medicare_number");
}
4557
/// Checks that `SET MASKING POLICY` without a policy name is rejected.
#[test]
fn test_parse_attach_masking_policy_rejects_missing_policy_name() {
    let sql = "ALTER TABLE patients ALTER COLUMN medicare_number SET MASKING POLICY";
    let outcome = parse_statement(sql);
    assert!(outcome.is_err());
}
4564
/// Checks that `CREATE MASKING POLICY` resolves to `CreateMaskingPolicy`.
#[test]
fn test_parse_create_masking_policy_does_not_match_legacy_create_mask() {
    let sql = "CREATE MASKING POLICY p STRATEGY HASH EXEMPT ROLES ('admin')";
    let statement = parse_statement(sql).unwrap();
    assert!(matches!(statement, ParsedStatement::CreateMaskingPolicy(_)));
}
4574
/// Checks that `MODIFY COLUMN ... SET CLASSIFICATION` captures table, column and label.
#[test]
fn test_parse_set_classification() {
    let sql = "ALTER TABLE patients MODIFY COLUMN ssn SET CLASSIFICATION 'PHI'";
    let ParsedStatement::SetClassification(classified) = parse_statement(sql).unwrap() else {
        panic!("expected SetClassification statement");
    };
    assert_eq!(classified.table_name, "patients");
    assert_eq!(classified.column_name, "ssn");
    assert_eq!(classified.classification, "PHI");
}
4593
/// Checks that `SET CLASSIFICATION` parses the same way when terminated by a semicolon.
#[test]
fn test_parse_set_classification_with_semicolon() {
    let sql = "ALTER TABLE patients MODIFY COLUMN diagnosis SET CLASSIFICATION 'MEDICAL';";
    if let ParsedStatement::SetClassification(classified) = parse_statement(sql).unwrap() {
        assert_eq!(classified.table_name, "patients");
        assert_eq!(classified.column_name, "diagnosis");
        assert_eq!(classified.classification, "MEDICAL");
    } else {
        panic!("expected SetClassification statement");
    }
}
4609
/// Checks that each listed classification label is returned unchanged.
#[test]
fn test_parse_set_classification_various_labels() {
    for label in ["PHI", "PII", "PCI", "MEDICAL", "FINANCIAL", "CONFIDENTIAL"] {
        let sql = format!("ALTER TABLE t MODIFY COLUMN c SET CLASSIFICATION '{label}'");
        let ParsedStatement::SetClassification(classified) = parse_statement(&sql).unwrap()
        else {
            panic!("expected SetClassification for {label}");
        };
        assert_eq!(classified.classification, label);
    }
}
4623
/// Checks that an unquoted classification label is rejected.
#[test]
fn test_parse_set_classification_missing_quotes() {
    let sql = "ALTER TABLE patients MODIFY COLUMN ssn SET CLASSIFICATION PHI";
    let outcome = parse_statement(sql);
    assert!(outcome.is_err(), "classification must be single-quoted");
}
4630
/// Checks that `SET CLASSIFICATION` without a `MODIFY COLUMN` clause is rejected.
#[test]
fn test_parse_set_classification_missing_modify() {
    let sql = "ALTER TABLE patients SET CLASSIFICATION 'PHI'";
    let outcome = parse_statement(sql);
    assert!(outcome.is_err());
}
4638
/// Checks that `SHOW CLASSIFICATIONS FOR <table>` yields the table name.
#[test]
fn test_parse_show_classifications() {
    let statement = parse_statement("SHOW CLASSIFICATIONS FOR patients").unwrap();
    let ParsedStatement::ShowClassifications(table_name) = statement else {
        panic!("expected ShowClassifications statement");
    };
    assert_eq!(table_name, "patients");
}
4653
/// Checks that a trailing semicolon is not included in the table name.
#[test]
fn test_parse_show_classifications_with_semicolon() {
    let statement = parse_statement("SHOW CLASSIFICATIONS FOR patients;").unwrap();
    if let ParsedStatement::ShowClassifications(table_name) = statement {
        assert_eq!(table_name, "patients");
    } else {
        panic!("expected ShowClassifications statement");
    }
}
4664
/// Checks that `SHOW CLASSIFICATIONS` without the `FOR` keyword is rejected.
#[test]
fn test_parse_show_classifications_missing_for() {
    let sql = "SHOW CLASSIFICATIONS patients";
    let outcome = parse_statement(sql);
    assert!(outcome.is_err());
}
4670
/// Checks that `SHOW CLASSIFICATIONS FOR` with no table name is rejected.
#[test]
fn test_parse_show_classifications_missing_table() {
    let sql = "SHOW CLASSIFICATIONS FOR";
    let outcome = parse_statement(sql);
    assert!(outcome.is_err());
}
4676
/// Checks that `CREATE ROLE` yields the role name.
#[test]
fn test_parse_create_role() {
    let statement = parse_statement("CREATE ROLE billing_clerk").unwrap();
    let ParsedStatement::CreateRole(role_name) = statement else {
        panic!("expected CreateRole");
    };
    assert_eq!(role_name, "billing_clerk");
}
4691
/// Checks that a trailing semicolon is not included in the created role's name.
#[test]
fn test_parse_create_role_with_semicolon() {
    let statement = parse_statement("CREATE ROLE doctor;").unwrap();
    if let ParsedStatement::CreateRole(role_name) = statement {
        assert_eq!(role_name, "doctor");
    } else {
        panic!("expected CreateRole");
    }
}
4702
/// Checks that `GRANT SELECT ON ...` without a column list leaves `columns` as `None`.
#[test]
fn test_parse_grant_select_all_columns() {
    let statement = parse_statement("GRANT SELECT ON patients TO doctor").unwrap();
    let ParsedStatement::Grant(grant) = statement else {
        panic!("expected Grant");
    };
    assert!(grant.columns.is_none());
    assert_eq!(grant.table_name, "patients");
    assert_eq!(grant.role_name, "doctor");
}
4715
/// Checks that a parenthesised column list after `GRANT SELECT` is preserved in order.
#[test]
fn test_parse_grant_select_specific_columns() {
    let sql = "GRANT SELECT (id, name, ssn) ON patients TO billing_clerk";
    let ParsedStatement::Grant(grant) = parse_statement(sql).unwrap() else {
        panic!("expected Grant");
    };
    let expected_columns: Vec<String> = vec!["id".into(), "name".into(), "ssn".into()];
    assert_eq!(grant.columns, Some(expected_columns));
    assert_eq!(grant.table_name, "patients");
    assert_eq!(grant.role_name, "billing_clerk");
}
4732
/// Checks that `CREATE USER ... WITH ROLE ...` captures the username and role.
#[test]
fn test_parse_create_user() {
    let statement = parse_statement("CREATE USER clerk1 WITH ROLE billing_clerk").unwrap();
    let ParsedStatement::CreateUser(user) = statement else {
        panic!("expected CreateUser");
    };
    assert_eq!(user.username, "clerk1");
    assert_eq!(user.role, "billing_clerk");
}
4744
/// Checks that a trailing semicolon is not included in the parsed role.
#[test]
fn test_parse_create_user_with_semicolon() {
    let statement = parse_statement("CREATE USER admin1 WITH ROLE admin;").unwrap();
    if let ParsedStatement::CreateUser(user) = statement {
        assert_eq!(user.username, "admin1");
        assert_eq!(user.role, "admin");
    } else {
        panic!("expected CreateUser");
    }
}
4756
/// Checks that `CREATE USER ... WITH <name>` lacking the `ROLE` keyword is rejected.
#[test]
fn test_parse_create_user_missing_role() {
    let sql = "CREATE USER clerk1 WITH billing_clerk";
    let outcome = parse_statement(sql);
    assert!(outcome.is_err());
}
4762
/// Checks that `CREATE TABLE` inputs that supply no columns are rejected.
#[test]
fn test_parse_create_table_rejects_zero_columns() {
    // Each case pairs an input with the message reported if it is wrongly accepted.
    let cases = [
        (
            "CREATE TABLE#USER",
            "zero-column CREATE TABLE must be rejected",
        ),
        (
            "CREATE TABLE t ()",
            "empty-column-list CREATE TABLE must be rejected",
        ),
    ];
    for (sql, why) in cases {
        assert!(parse_statement(sql).is_err(), "{why}");
    }
}
4782
4783 fn parse_test_insert(sql: &str) -> ParsedInsert {
4788 match parse_statement(sql).unwrap_or_else(|e| panic!("parse failed: {e}")) {
4789 ParsedStatement::Insert(i) => i,
4790 other => panic!("expected INSERT statement, got {other:?}"),
4791 }
4792 }
4793
/// Checks that `ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name` records the
/// conflict target and an `EXCLUDED` back-reference assignment.
#[test]
fn test_parse_insert_on_conflict_do_update() {
    let sql = "INSERT INTO users (id, name) VALUES (1, 'Alice') \
               ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name";
    let insert = parse_test_insert(sql);
    let conflict = insert.on_conflict.expect("on_conflict must be present");
    assert_eq!(conflict.target, vec!["id".to_string()]);

    let assignments = match conflict.action {
        OnConflictAction::DoUpdate { assignments } => assignments,
        other @ OnConflictAction::DoNothing => panic!("expected DoUpdate, got {other:?}"),
    };
    assert_eq!(assignments.len(), 1);
    assert_eq!(assignments[0].0, "name");
    assert_eq!(
        assignments[0].1,
        UpsertExpr::Excluded("name".to_string()),
        "RHS must be an EXCLUDED.col back-reference"
    );
}
4815
/// Checks that `ON CONFLICT (id) DO NOTHING` records the target and the `DoNothing` action.
#[test]
fn test_parse_insert_on_conflict_do_nothing() {
    let sql = "INSERT INTO users (id, name) VALUES (1, 'Alice') ON CONFLICT (id) DO NOTHING";
    let insert = parse_test_insert(sql);
    let conflict = insert.on_conflict.expect("on_conflict must be present");
    assert_eq!(conflict.target, vec!["id".to_string()]);
    assert!(
        matches!(conflict.action, OnConflictAction::DoNothing),
        "DO NOTHING must parse to OnConflictAction::DoNothing"
    );
}
4828
/// Checks that an `INSERT` without `ON CONFLICT` leaves `on_conflict` as `None`.
#[test]
fn test_parse_plain_insert_has_no_on_conflict() {
    let sql = "INSERT INTO users (id, name) VALUES (1, 'Alice')";
    let insert = parse_test_insert(sql);
    assert!(
        insert.on_conflict.is_none(),
        "plain INSERT must not carry an on_conflict clause"
    );
}
4837
/// Checks that a two-column conflict target is preserved in order.
#[test]
fn test_parse_insert_on_conflict_multi_column_target() {
    let sql = "INSERT INTO t (tenant_id, id, v) VALUES (1, 2, 3) \
               ON CONFLICT (tenant_id, id) DO UPDATE SET v = EXCLUDED.v";
    let insert = parse_test_insert(sql);
    let conflict = insert.on_conflict.expect("on_conflict must be present");
    let expected_target = vec!["tenant_id".to_string(), "id".to_string()];
    assert_eq!(conflict.target, expected_target);
}
4847
/// Checks that `RETURNING` after an `ON CONFLICT` clause is still captured.
#[test]
fn test_parse_insert_on_conflict_with_returning() {
    let sql = "INSERT INTO t (id, v) VALUES (1, 2) \
               ON CONFLICT (id) DO UPDATE SET v = EXCLUDED.v RETURNING id, v";
    let insert = parse_test_insert(sql);
    assert!(insert.on_conflict.is_some());
    let expected_returning = Some(vec!["id".to_string(), "v".to_string()]);
    assert_eq!(insert.returning, expected_returning);
}
4857
/// Checks that the `ON CONFLICT ON CONSTRAINT <name>` form is rejected.
#[test]
fn test_parse_insert_on_conflict_rejects_on_constraint_form() {
    let sql = "INSERT INTO t (id) VALUES (1) ON CONFLICT ON CONSTRAINT pk_t DO NOTHING";
    let outcome = parse_statement(sql);
    assert!(
        outcome.is_err(),
        "ON CONSTRAINT form must be rejected with a clear error"
    );
}
4869
/// Checks that a literal on the right-hand side of `DO UPDATE SET` becomes `UpsertExpr::Value`.
#[test]
fn test_parse_insert_on_conflict_literal_rhs() {
    let sql = "INSERT INTO t (id, v) VALUES (1, 2) \
               ON CONFLICT (id) DO UPDATE SET v = 42";
    let insert = parse_test_insert(sql);
    let conflict = insert.on_conflict.expect("on_conflict must be present");

    let assignments = match conflict.action {
        OnConflictAction::DoUpdate { assignments } => assignments,
        other @ OnConflictAction::DoNothing => panic!("expected DoUpdate, got {other:?}"),
    };
    assert_eq!(assignments[0].0, "v");
    assert!(matches!(assignments[0].1, UpsertExpr::Value(_)));
}
4886}