1use kimberlite_types::NonEmptyVec;
15use sqlparser::ast::{
16 BinaryOperator, ColumnDef as SqlColumnDef, DataType as SqlDataType, Expr, ObjectName,
17 OrderByExpr, Query, Select, SelectItem, SetExpr, Statement, Value as SqlValue,
18};
19use sqlparser::dialect::{Dialect, GenericDialect};
20use sqlparser::parser::Parser;
21
/// SQL dialect handed to `sqlparser` when parsing statements.
///
/// Wraps a [`GenericDialect`]; identifier classification is forwarded to it
/// in the `Dialect` impl, which additionally reports support for `FILTER`
/// during aggregation.
#[derive(Debug)]
struct KimberliteDialect {
    /// Dialect that identifier rules are delegated to.
    inner: GenericDialect,
}
30
31impl KimberliteDialect {
32 const fn new() -> Self {
33 Self {
34 inner: GenericDialect {},
35 }
36 }
37}
38
39impl Dialect for KimberliteDialect {
40 fn is_identifier_start(&self, ch: char) -> bool {
41 self.inner.is_identifier_start(ch)
42 }
43
44 fn is_identifier_part(&self, ch: char) -> bool {
45 self.inner.is_identifier_part(ch)
46 }
47
48 fn supports_filter_during_aggregation(&self) -> bool {
49 true
50 }
51}
52
53use crate::error::{QueryError, Result};
54use crate::expression::ScalarExpr;
55use crate::schema::{ColumnName, DataType};
56use crate::value::Value;
57
/// A successfully parsed SQL statement, as produced by `parse_statement`.
#[derive(Debug, Clone)]
pub enum ParsedStatement {
    /// A single `SELECT` query.
    Select(ParsedSelect),
    /// A set operation between two `SELECT`s; despite the variant name this
    /// also carries `INTERSECT` and `EXCEPT` (see `ParsedUnion::op`).
    Union(ParsedUnion),
    /// `CREATE TABLE`.
    CreateTable(ParsedCreateTable),
    /// `DROP TABLE [IF EXISTS] <name>`.
    DropTable { name: String, if_exists: bool },
    /// `ALTER TABLE`.
    AlterTable(ParsedAlterTable),
    /// `CREATE INDEX`.
    CreateIndex(ParsedCreateIndex),
    /// `INSERT`.
    Insert(ParsedInsert),
    /// `UPDATE`.
    Update(ParsedUpdate),
    /// `DELETE`.
    Delete(ParsedDelete),
    /// `CREATE MASK <name> ON <table>.<column> USING <strategy>`.
    CreateMask(ParsedCreateMask),
    /// `DROP MASK <name>`; the payload is the mask name.
    DropMask(String),
    /// `CREATE MASKING POLICY <name> STRATEGY <kind> [<arg>] EXEMPT ROLES (...)`.
    CreateMaskingPolicy(ParsedCreateMaskingPolicy),
    /// `DROP MASKING POLICY <name>`; the payload is the policy name.
    DropMaskingPolicy(String),
    /// `ALTER TABLE <t> ALTER COLUMN <c> SET MASKING POLICY <name>`.
    AttachMaskingPolicy(ParsedAttachMaskingPolicy),
    /// `ALTER TABLE <t> ALTER COLUMN <c> DROP MASKING POLICY`.
    DetachMaskingPolicy(ParsedDetachMaskingPolicy),
    /// `ALTER TABLE <table> MODIFY COLUMN <column> SET CLASSIFICATION '<class>'`.
    SetClassification(ParsedSetClassification),
    /// `SHOW CLASSIFICATIONS FOR <table>`; the payload is the table name.
    ShowClassifications(String),
    /// `SHOW TABLES`.
    ShowTables,
    /// `SHOW COLUMNS FROM <table>`; the payload is the table name.
    ShowColumns(String),
    /// `CREATE ROLE <name>`; the payload is the role name.
    CreateRole(String),
    /// `GRANT ... ON <table> TO <role>`.
    Grant(ParsedGrant),
    /// `CREATE USER <name> WITH ROLE <role>`.
    CreateUser(ParsedCreateUser),
}
116
/// Parsed `GRANT` statement targeting exactly one table and one grantee.
#[derive(Debug, Clone)]
pub struct ParsedGrant {
    /// Column list attached to a `SELECT` privilege, if one was given;
    /// `None` when no such list is present or when `ALL` privileges are granted.
    pub columns: Option<Vec<String>>,
    /// Table the grant applies to.
    pub table_name: String,
    /// Role receiving the grant.
    pub role_name: String,
}
127
/// Parsed `CREATE USER <name> WITH ROLE <role>` statement.
#[derive(Debug, Clone)]
pub struct ParsedCreateUser {
    /// User name, taken verbatim from the statement.
    pub username: String,
    /// Role named after `WITH ROLE`, taken verbatim from the statement.
    pub role: String,
}
136
/// Parsed `ALTER TABLE <table> MODIFY COLUMN <column> SET CLASSIFICATION '<class>'`.
#[derive(Debug, Clone)]
pub struct ParsedSetClassification {
    /// Table owning the column.
    pub table_name: String,
    /// Column being classified.
    pub column_name: String,
    /// Classification label with the surrounding single quotes removed.
    pub classification: String,
}
147
/// Parsed `CREATE MASK <name> ON <table>.<column> USING <strategy>` statement.
#[derive(Debug, Clone)]
pub struct ParsedCreateMask {
    /// Name of the mask.
    pub mask_name: String,
    /// Table part of the `<table>.<column>` reference (text before the first `.`).
    pub table_name: String,
    /// Column part of the `<table>.<column>` reference (text after the first `.`).
    pub column_name: String,
    /// Strategy keyword, upper-cased by the parser; not validated at parse time.
    pub strategy: String,
}
160
/// Masking strategy named in a `CREATE MASKING POLICY` statement.
///
/// Each variant corresponds to one strategy keyword accepted by
/// `parse_masking_strategy`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParsedMaskingStrategy {
    /// `REDACT_SSN`.
    RedactSsn,
    /// `REDACT_PHONE`.
    RedactPhone,
    /// `REDACT_EMAIL`.
    RedactEmail,
    /// `REDACT_CC`.
    RedactCreditCard,
    /// `REDACT_CUSTOM '<replacement>'`.
    RedactCustom {
        /// Replacement text with the surrounding single quotes removed.
        replacement: String,
    },
    /// `HASH`.
    Hash,
    /// `TOKENIZE`.
    Tokenize,
    /// `TRUNCATE <n>`.
    Truncate {
        /// Character count; the parser rejects zero.
        max_chars: usize,
    },
    /// `NULL`.
    Null,
}
191
/// Parsed `CREATE MASKING POLICY <name> STRATEGY <kind> [<arg>] EXEMPT ROLES (...)`.
#[derive(Debug, Clone)]
pub struct ParsedCreateMaskingPolicy {
    /// Policy name.
    pub name: String,
    /// Strategy selected after the `STRATEGY` keyword.
    pub strategy: ParsedMaskingStrategy,
    /// Roles listed after `EXEMPT ROLES`; lower-cased by the parser and
    /// guaranteed non-empty when produced by it.
    pub exempt_roles: Vec<String>,
}
207
/// Parsed `ALTER TABLE <t> ALTER COLUMN <c> SET MASKING POLICY <name>`.
#[derive(Debug, Clone)]
// Every field ends in `_name`, which trips this lint; the names are kept
// because they mirror the sibling masking structs.
#[allow(clippy::struct_field_names)]
pub struct ParsedAttachMaskingPolicy {
    /// Table owning the column.
    pub table_name: String,
    /// Column the policy is attached to.
    pub column_name: String,
    /// Policy being attached.
    pub policy_name: String,
}
219
/// Parsed `ALTER TABLE <t> ALTER COLUMN <c> DROP MASKING POLICY`.
#[derive(Debug, Clone)]
pub struct ParsedDetachMaskingPolicy {
    /// Table owning the column.
    pub table_name: String,
    /// Column whose masking policy is removed.
    pub column_name: String,
}
228
/// Set operator combining the two arms of a [`ParsedUnion`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SetOp {
    /// `UNION`.
    Union,
    /// `INTERSECT`.
    Intersect,
    /// `EXCEPT` (the parser also maps `MINUS` to this variant).
    Except,
}
239
/// A set operation between exactly two plain `SELECT`s.
///
/// Nested set operations on either side are rejected by the parser.
#[derive(Debug, Clone)]
pub struct ParsedUnion {
    /// Which set operator joins the two arms.
    pub op: SetOp,
    /// Left-hand `SELECT`.
    pub left: ParsedSelect,
    /// Right-hand `SELECT`.
    pub right: ParsedSelect,
    /// `true` when the operator carried the `ALL` quantifier.
    pub all: bool,
}
253
/// Kind of join between two tables.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JoinType {
    /// Inner join (also used for a bare `JOIN`).
    Inner,
    /// Left (outer) join.
    Left,
    /// Right (outer) join.
    Right,
    /// Full outer join.
    Full,
    /// Cross join.
    Cross,
}
268
/// One `JOIN` clause of a `SELECT`.
#[derive(Debug, Clone)]
pub struct ParsedJoin {
    /// Name of the joined table.
    pub table: String,
    /// Kind of join.
    pub join_type: JoinType,
    /// Predicates from the join's `ON` condition.
    // NOTE(review): presumably combined with AND; confirm against the join parser.
    pub on_condition: Vec<Predicate>,
}
279
/// A common table expression attached to a query.
#[derive(Debug, Clone)]
pub struct ParsedCte {
    /// Name the CTE is referenced by.
    pub name: String,
    /// Query defining the CTE.
    pub query: ParsedSelect,
    /// Second arm of a recursive CTE, when present.
    // NOTE(review): how the anchor and recursive arms are split is decided
    // by `parse_ctes`, which is not visible here; confirm there.
    pub recursive_arm: Option<ParsedSelect>,
}
291
/// A projected column whose value is chosen by conditional arms.
// NOTE(review): presumably models `CASE WHEN ... THEN ... ELSE ... END`
// (see `ParsedSelect::case_columns`); confirm against the SELECT parser.
#[derive(Debug, Clone)]
pub struct ComputedColumn {
    /// Output name of the computed column.
    pub alias: ColumnName,
    /// Conditional arms.
    pub when_clauses: Vec<CaseWhenArm>,
    /// Value used when no arm applies.
    pub else_value: Value,
}
304
/// One conditional arm of a [`ComputedColumn`].
#[derive(Debug, Clone)]
pub struct CaseWhenArm {
    /// Predicates guarding this arm.
    pub condition: Vec<Predicate>,
    /// Value produced when the arm is selected.
    pub result: Value,
}
313
/// Value of a `LIMIT` or `OFFSET` clause.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LimitExpr {
    /// A literal count written in the SQL text.
    Literal(usize),
    /// A bound parameter, identified by its index.
    // NOTE(review): whether the index is 0- or 1-based is not visible here; confirm.
    Param(usize),
}
328
/// A parsed `SELECT` query.
///
/// `parse_query_to_statement` builds the final value from the result of
/// `parse_select`, then overrides `order_by`, `limit` and `offset` with the
/// query-level clauses and prepends the query-level CTEs to `ctes`.
#[derive(Debug, Clone)]
pub struct ParsedSelect {
    /// Table named in `FROM`.
    pub table: String,
    /// `JOIN` clauses.
    pub joins: Vec<ParsedJoin>,
    /// Projected columns.
    // NOTE(review): `None` presumably means `SELECT *`; confirm in `parse_select`.
    pub columns: Option<Vec<ColumnName>>,
    /// Optional alias per projected column.
    // NOTE(review): presumably parallel to `columns`; confirm.
    pub column_aliases: Option<Vec<Option<String>>>,
    /// Conditionally computed projection columns.
    pub case_columns: Vec<ComputedColumn>,
    /// `WHERE` predicates.
    pub predicates: Vec<Predicate>,
    /// `ORDER BY` clauses.
    pub order_by: Vec<OrderByClause>,
    /// `LIMIT`, when present.
    pub limit: Option<LimitExpr>,
    /// `OFFSET`, when present.
    pub offset: Option<LimitExpr>,
    /// Aggregate functions in the projection.
    pub aggregates: Vec<AggregateFunction>,
    /// Optional filter predicates per aggregate.
    // NOTE(review): presumably parallel to `aggregates` (one entry each); confirm.
    pub aggregate_filters: Vec<Option<Vec<Predicate>>>,
    /// `GROUP BY` columns.
    pub group_by: Vec<ColumnName>,
    /// Whether `DISTINCT` was requested.
    pub distinct: bool,
    /// `HAVING` conditions.
    pub having: Vec<HavingCondition>,
    /// Common table expressions available to this query.
    pub ctes: Vec<ParsedCte>,
    /// Window functions in the projection.
    pub window_fns: Vec<ParsedWindowFn>,
    /// Scalar-expression projections.
    pub scalar_projections: Vec<ParsedScalarProjection>,
}
385
/// A projected scalar expression.
#[derive(Debug, Clone)]
pub struct ParsedScalarProjection {
    /// Expression to evaluate.
    pub expr: ScalarExpr,
    /// Name of the resulting output column.
    pub output_name: ColumnName,
    /// Alias written by the user, if any.
    // NOTE(review): relationship between `alias` and `output_name` is set by
    // `parse_select`, which is not visible here; confirm.
    pub alias: Option<String>,
}
400
/// A window function call in the projection.
#[derive(Debug, Clone)]
pub struct ParsedWindowFn {
    /// Which window function is applied.
    pub function: crate::window::WindowFunction,
    /// `PARTITION BY` columns.
    pub partition_by: Vec<ColumnName>,
    /// `ORDER BY` clauses of the window.
    pub order_by: Vec<OrderByClause>,
    /// Output alias, if one was given.
    pub alias: Option<String>,
}
414
/// One condition of a `HAVING` clause.
#[derive(Debug, Clone)]
pub enum HavingCondition {
    /// Compares an aggregate result against a constant value.
    AggregateComparison {
        /// Aggregate on the left-hand side.
        aggregate: AggregateFunction,
        /// Comparison operator.
        op: HavingOp,
        /// Value on the right-hand side.
        value: Value,
    },
}
430
/// Comparison operator usable in a [`HavingCondition`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HavingOp {
    /// Equal.
    Eq,
    /// Less than.
    Lt,
    /// Less than or equal.
    Le,
    /// Greater than.
    Gt,
    /// Greater than or equal.
    Ge,
}
440
/// Parsed `CREATE TABLE` statement.
#[derive(Debug, Clone)]
pub struct ParsedCreateTable {
    /// Name of the table to create.
    pub table_name: String,
    /// Column definitions; the type guarantees at least one.
    pub columns: NonEmptyVec<ParsedColumn>,
    /// Names of the primary-key columns.
    pub primary_key: Vec<String>,
    /// Whether `IF NOT EXISTS` was specified.
    pub if_not_exists: bool,
}
457
/// A column definition from `CREATE TABLE` or `ALTER TABLE ... ADD COLUMN`.
#[derive(Debug, Clone)]
pub struct ParsedColumn {
    /// Column name.
    pub name: String,
    /// Column type, kept as text.
    // NOTE(review): exact spelling/normalisation of the type string is
    // decided by the column parser, which is not visible here.
    pub data_type: String,
    /// Whether the column accepts NULL.
    pub nullable: bool,
}
465
/// Parsed `ALTER TABLE` statement.
#[derive(Debug, Clone)]
pub struct ParsedAlterTable {
    /// Table being altered.
    pub table_name: String,
    /// The change to apply.
    pub operation: AlterTableOperation,
}
472
/// Operation carried by a [`ParsedAlterTable`].
#[derive(Debug, Clone)]
pub enum AlterTableOperation {
    /// Add the given column.
    AddColumn(ParsedColumn),
    /// Drop the column with the given name.
    DropColumn(String),
}
481
/// Parsed `CREATE INDEX` statement.
#[derive(Debug, Clone)]
pub struct ParsedCreateIndex {
    /// Name of the index.
    pub index_name: String,
    /// Table the index is built on.
    pub table_name: String,
    /// Indexed columns.
    pub columns: Vec<String>,
}
489
/// Parsed `INSERT` statement.
#[derive(Debug, Clone)]
pub struct ParsedInsert {
    /// Target table.
    pub table: String,
    /// Column names listed in the statement.
    pub columns: Vec<String>,
    /// Rows to insert; each inner vector holds the values of one row.
    pub values: Vec<Vec<Value>>,
    /// Columns named in a `RETURNING` clause, if present.
    pub returning: Option<Vec<String>>,
    /// `ON CONFLICT` clause, if present.
    pub on_conflict: Option<OnConflictClause>,
}
503
/// `ON CONFLICT` clause of an `INSERT`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OnConflictClause {
    /// Conflict target.
    // NOTE(review): presumably the column names in `ON CONFLICT (<cols>)`; confirm.
    pub target: Vec<String>,
    /// What to do when a conflict is detected.
    pub action: OnConflictAction,
}
522
/// Action taken by an [`OnConflictClause`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OnConflictAction {
    /// `DO NOTHING`.
    DoNothing,
    /// `DO UPDATE SET ...`.
    DoUpdate {
        /// `(column, expression)` pairs to assign.
        assignments: Vec<(String, UpsertExpr)>,
    },
}
540
/// Right-hand side of an assignment in [`OnConflictAction::DoUpdate`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UpsertExpr {
    /// A constant value.
    Value(Value),
    /// A reference to a named column of the would-be-inserted row.
    // NOTE(review): presumably `EXCLUDED.<column>`; confirm in the INSERT parser.
    Excluded(String),
}
555
/// Parsed `UPDATE` statement.
#[derive(Debug, Clone)]
pub struct ParsedUpdate {
    /// Target table.
    pub table: String,
    /// `(column, new value)` pairs from the `SET` clause.
    pub assignments: Vec<(String, Value)>,
    /// `WHERE` predicates.
    pub predicates: Vec<Predicate>,
    /// Columns named in a `RETURNING` clause, if present.
    pub returning: Option<Vec<String>>,
}
564
/// Parsed `DELETE` statement.
#[derive(Debug, Clone)]
pub struct ParsedDelete {
    /// Target table.
    pub table: String,
    /// `WHERE` predicates.
    pub predicates: Vec<Predicate>,
    /// Columns named in a `RETURNING` clause, if present.
    pub returning: Option<Vec<String>>,
}
572
/// An aggregate function call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AggregateFunction {
    /// `COUNT(*)`.
    CountStar,
    /// `COUNT(<column>)`.
    Count(ColumnName),
    /// `SUM(<column>)`.
    Sum(ColumnName),
    /// `AVG(<column>)`.
    Avg(ColumnName),
    /// `MIN(<column>)`.
    Min(ColumnName),
    /// `MAX(<column>)`.
    Max(ColumnName),
}
589
/// A filter condition appearing in `WHERE`, `ON` and similar positions.
///
/// Most variants test a single column; `Predicate::column` returns it.
#[derive(Debug, Clone)]
pub enum Predicate {
    /// `column = value`.
    Eq(ColumnName, PredicateValue),
    /// `column < value`.
    Lt(ColumnName, PredicateValue),
    /// `column <= value`.
    Le(ColumnName, PredicateValue),
    /// `column > value`.
    Gt(ColumnName, PredicateValue),
    /// `column >= value`.
    Ge(ColumnName, PredicateValue),
    /// `column IN (values...)`.
    In(ColumnName, Vec<PredicateValue>),
    /// `column NOT IN (values...)`.
    NotIn(ColumnName, Vec<PredicateValue>),
    /// `column NOT BETWEEN low AND high`.
    NotBetween(ColumnName, PredicateValue, PredicateValue),
    /// `column LIKE pattern`.
    Like(ColumnName, String),
    /// `column NOT LIKE pattern`.
    NotLike(ColumnName, String),
    /// `column ILIKE pattern`.
    ILike(ColumnName, String),
    /// `column NOT ILIKE pattern`.
    NotILike(ColumnName, String),
    /// `column IS NULL`.
    IsNull(ColumnName),
    /// `column IS NOT NULL`.
    IsNotNull(ColumnName),
    /// Equality test on a value extracted from a JSON column.
    JsonExtractEq {
        /// JSON column being read.
        column: ColumnName,
        /// Path of the extracted element.
        path: String,
        /// Whether the extracted element is compared as text.
        // NOTE(review): presumably distinguishes `->>` from `->`; confirm.
        as_text: bool,
        /// Value compared against.
        value: PredicateValue,
    },
    /// Containment test on a JSON column.
    JsonContains {
        /// JSON column being tested.
        column: ColumnName,
        /// Value that must be contained.
        value: PredicateValue,
    },
    /// `column [NOT] IN (<subquery>)`.
    InSubquery {
        /// Column being tested.
        column: ColumnName,
        /// Subquery supplying the candidate values.
        subquery: Box<ParsedSelect>,
        /// `true` for the `NOT IN` form.
        negated: bool,
    },
    /// `[NOT] EXISTS (<subquery>)`.
    Exists {
        /// Subquery being tested.
        subquery: Box<ParsedSelect>,
        /// `true` for the `NOT EXISTS` form.
        negated: bool,
    },
    /// A constant truth value.
    Always(bool),
    /// Disjunction of two predicate groups.
    // NOTE(review): each side is presumably an AND-ed list; confirm in the WHERE parser.
    Or(Vec<Predicate>, Vec<Predicate>),
    /// Comparison between two scalar expressions.
    ScalarCmp {
        /// Left-hand expression.
        lhs: ScalarExpr,
        /// Comparison operator.
        op: ScalarCmpOp,
        /// Right-hand expression.
        rhs: ScalarExpr,
    },
}
685
/// Comparison operator used by [`Predicate::ScalarCmp`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScalarCmpOp {
    /// Equal.
    Eq,
    /// Not equal.
    NotEq,
    /// Less than.
    Lt,
    /// Less than or equal.
    Le,
    /// Greater than.
    Gt,
    /// Greater than or equal.
    Ge,
}
696
697impl Predicate {
698 #[allow(dead_code)]
702 pub fn column(&self) -> Option<&ColumnName> {
703 match self {
704 Predicate::Eq(col, _)
705 | Predicate::Lt(col, _)
706 | Predicate::Le(col, _)
707 | Predicate::Gt(col, _)
708 | Predicate::Ge(col, _)
709 | Predicate::In(col, _)
710 | Predicate::NotIn(col, _)
711 | Predicate::NotBetween(col, _, _)
712 | Predicate::Like(col, _)
713 | Predicate::NotLike(col, _)
714 | Predicate::ILike(col, _)
715 | Predicate::NotILike(col, _)
716 | Predicate::IsNull(col)
717 | Predicate::IsNotNull(col)
718 | Predicate::JsonExtractEq { column: col, .. }
719 | Predicate::JsonContains { column: col, .. }
720 | Predicate::InSubquery { column: col, .. } => Some(col),
721 Predicate::Or(_, _)
722 | Predicate::Exists { .. }
723 | Predicate::Always(_)
724 | Predicate::ScalarCmp { .. } => None,
725 }
726 }
727}
728
/// Right-hand operand of a [`Predicate`].
#[derive(Debug, Clone)]
pub enum PredicateValue {
    /// Integer literal.
    Int(i64),
    /// String literal.
    String(String),
    /// Boolean literal.
    Bool(bool),
    /// `NULL`.
    Null,
    /// Bound parameter, identified by its index.
    // NOTE(review): whether the index is 0- or 1-based is not visible here; confirm.
    Param(usize),
    /// Any other literal, carried as a [`Value`].
    Literal(Value),
    /// Reference to another column by name.
    ColumnRef(String),
}
748
/// One sort key of an `ORDER BY` clause.
#[derive(Debug, Clone)]
pub struct OrderByClause {
    /// Column to sort by.
    pub column: ColumnName,
    /// `true` for ascending order, `false` for descending.
    pub ascending: bool,
}
757
758pub fn parse_statement(sql: &str) -> Result<ParsedStatement> {
764 crate::depth_check::check_sql_depth(sql)?;
765
766 if let Some(parsed) = try_parse_custom_statement(sql)? {
768 return Ok(parsed);
769 }
770
771 let dialect = KimberliteDialect::new();
772 let statements =
773 Parser::parse_sql(&dialect, sql).map_err(|e| QueryError::ParseError(e.to_string()))?;
774
775 if statements.len() != 1 {
776 return Err(QueryError::ParseError(format!(
777 "expected exactly 1 statement, got {}",
778 statements.len()
779 )));
780 }
781
782 match &statements[0] {
783 Statement::Query(query) => parse_query_to_statement(query),
784 Statement::CreateTable(create_table) => {
785 let parsed = parse_create_table(create_table)?;
786 Ok(ParsedStatement::CreateTable(parsed))
787 }
788 Statement::Drop {
789 object_type,
790 names,
791 if_exists,
792 ..
793 } => {
794 if !matches!(object_type, sqlparser::ast::ObjectType::Table) {
795 return Err(QueryError::UnsupportedFeature(
796 "only DROP TABLE is supported".to_string(),
797 ));
798 }
799 if names.len() != 1 {
800 return Err(QueryError::ParseError(
801 "expected exactly 1 table in DROP TABLE".to_string(),
802 ));
803 }
804 let table_name = object_name_to_string(&names[0]);
805 Ok(ParsedStatement::DropTable {
806 name: table_name,
807 if_exists: *if_exists,
808 })
809 }
810 Statement::CreateIndex(create_index) => {
811 let parsed = parse_create_index(create_index)?;
812 Ok(ParsedStatement::CreateIndex(parsed))
813 }
814 Statement::Insert(insert) => {
815 let parsed = parse_insert(insert)?;
816 Ok(ParsedStatement::Insert(parsed))
817 }
818 Statement::Update(update) => {
819 let parsed = parse_update(
820 &update.table,
821 &update.assignments,
822 update.selection.as_ref(),
823 update.returning.as_ref(),
824 )?;
825 Ok(ParsedStatement::Update(parsed))
826 }
827 Statement::Delete(delete) => {
828 let parsed = parse_delete_stmt(delete)?;
829 Ok(ParsedStatement::Delete(parsed))
830 }
831 Statement::AlterTable(alter_table) => {
832 let parsed = parse_alter_table(&alter_table.name, &alter_table.operations)?;
833 Ok(ParsedStatement::AlterTable(parsed))
834 }
835 Statement::CreateRole(create_role) => {
836 if create_role.names.len() != 1 {
837 return Err(QueryError::ParseError(
838 "expected exactly 1 role name".to_string(),
839 ));
840 }
841 let role_name = object_name_to_string(&create_role.names[0]);
842 Ok(ParsedStatement::CreateRole(role_name))
843 }
844 Statement::Grant(grant) => {
845 let objects = grant.objects.as_ref().ok_or_else(|| {
846 QueryError::ParseError(
847 "GRANT requires an ON clause specifying the target objects".to_string(),
848 )
849 })?;
850 parse_grant(&grant.privileges, objects, &grant.grantees)
851 }
852 other => Err(QueryError::UnsupportedFeature(format!(
853 "statement type not supported: {other:?}"
854 ))),
855 }
856}
857
858pub fn try_parse_custom_statement(sql: &str) -> Result<Option<ParsedStatement>> {
874 let trimmed = sql.trim().trim_end_matches(';').trim();
875 let upper = trimmed.to_ascii_uppercase();
876
877 if upper.starts_with("CREATE MASKING POLICY") {
879 return parse_create_masking_policy(trimmed).map(Some);
880 }
881
882 if upper.starts_with("DROP MASKING POLICY") {
884 let tokens: Vec<&str> = trimmed.split_whitespace().collect();
885 if tokens.len() != 4 {
887 return Err(QueryError::ParseError(
888 "expected: DROP MASKING POLICY <name>".to_string(),
889 ));
890 }
891 return Ok(Some(ParsedStatement::DropMaskingPolicy(
892 tokens[3].to_string(),
893 )));
894 }
895
896 if upper.starts_with("ALTER TABLE") && upper.contains("MASKING POLICY") {
898 return parse_alter_masking_policy(trimmed).map(Some);
899 }
900
901 if upper.starts_with("CREATE MASK") {
903 let tokens: Vec<&str> = trimmed.split_whitespace().collect();
904 if tokens.len() != 7 {
906 return Err(QueryError::ParseError(
907 "expected: CREATE MASK <name> ON <table>.<column> USING <strategy>".to_string(),
908 ));
909 }
910 if !tokens[3].eq_ignore_ascii_case("ON") {
911 return Err(QueryError::ParseError(format!(
912 "expected ON after mask name, got '{}'",
913 tokens[3]
914 )));
915 }
916 if !tokens[5].eq_ignore_ascii_case("USING") {
917 return Err(QueryError::ParseError(format!(
918 "expected USING after column reference, got '{}'",
919 tokens[5]
920 )));
921 }
922
923 let table_col = tokens[4];
925 let dot_pos = table_col.find('.').ok_or_else(|| {
926 QueryError::ParseError(format!(
927 "expected <table>.<column> but got '{table_col}' (missing '.')"
928 ))
929 })?;
930 let table_name = table_col[..dot_pos].to_string();
931 let column_name = table_col[dot_pos + 1..].to_string();
932
933 if table_name.is_empty() || column_name.is_empty() {
934 return Err(QueryError::ParseError(
935 "table name and column name must not be empty".to_string(),
936 ));
937 }
938
939 let strategy = tokens[6].to_ascii_uppercase();
940
941 return Ok(Some(ParsedStatement::CreateMask(ParsedCreateMask {
942 mask_name: tokens[2].to_string(),
943 table_name,
944 column_name,
945 strategy,
946 })));
947 }
948
949 if upper.starts_with("DROP MASK") {
951 let tokens: Vec<&str> = trimmed.split_whitespace().collect();
952 if tokens.len() != 3 {
953 return Err(QueryError::ParseError(
954 "expected: DROP MASK <name>".to_string(),
955 ));
956 }
957 return Ok(Some(ParsedStatement::DropMask(tokens[2].to_string())));
958 }
959
960 if upper.starts_with("ALTER TABLE") && upper.contains("SET CLASSIFICATION") {
962 return parse_set_classification(trimmed);
963 }
964
965 if upper.starts_with("SHOW CLASSIFICATIONS") {
967 let tokens: Vec<&str> = trimmed.split_whitespace().collect();
968 if tokens.len() != 4 {
970 return Err(QueryError::ParseError(
971 "expected: SHOW CLASSIFICATIONS FOR <table>".to_string(),
972 ));
973 }
974 if !tokens[2].eq_ignore_ascii_case("FOR") {
975 return Err(QueryError::ParseError(format!(
976 "expected FOR after CLASSIFICATIONS, got '{}'",
977 tokens[2]
978 )));
979 }
980 return Ok(Some(ParsedStatement::ShowClassifications(
981 tokens[3].to_string(),
982 )));
983 }
984
985 if upper == "SHOW TABLES" {
987 return Ok(Some(ParsedStatement::ShowTables));
988 }
989
990 if upper.starts_with("SHOW COLUMNS") {
992 let tokens: Vec<&str> = trimmed.split_whitespace().collect();
993 if tokens.len() != 4 {
995 return Err(QueryError::ParseError(
996 "expected: SHOW COLUMNS FROM <table>".to_string(),
997 ));
998 }
999 if !tokens[2].eq_ignore_ascii_case("FROM") {
1000 return Err(QueryError::ParseError(format!(
1001 "expected FROM after COLUMNS, got '{}'",
1002 tokens[2]
1003 )));
1004 }
1005 return Ok(Some(ParsedStatement::ShowColumns(tokens[3].to_string())));
1006 }
1007
1008 if upper.starts_with("CREATE USER") {
1010 let tokens: Vec<&str> = trimmed.split_whitespace().collect();
1011 if tokens.len() != 6 {
1013 return Err(QueryError::ParseError(
1014 "expected: CREATE USER <name> WITH ROLE <role>".to_string(),
1015 ));
1016 }
1017 if !tokens[3].eq_ignore_ascii_case("WITH") {
1018 return Err(QueryError::ParseError(format!(
1019 "expected WITH after username, got '{}'",
1020 tokens[3]
1021 )));
1022 }
1023 if !tokens[4].eq_ignore_ascii_case("ROLE") {
1024 return Err(QueryError::ParseError(format!(
1025 "expected ROLE after WITH, got '{}'",
1026 tokens[4]
1027 )));
1028 }
1029 return Ok(Some(ParsedStatement::CreateUser(ParsedCreateUser {
1030 username: tokens[2].to_string(),
1031 role: tokens[5].to_string(),
1032 })));
1033 }
1034
1035 Ok(None)
1036}
1037
1038fn parse_create_masking_policy(trimmed: &str) -> Result<ParsedStatement> {
1046 let after_keyword = trimmed
1050 .get("CREATE MASKING POLICY".len()..)
1051 .ok_or_else(|| QueryError::ParseError("missing policy body".to_string()))?
1052 .trim_start();
1053
1054 let upper_body = after_keyword.to_ascii_uppercase();
1057 let exempt_pos = upper_body.find("EXEMPT ROLES").ok_or_else(|| {
1058 QueryError::ParseError(
1059 "expected: CREATE MASKING POLICY <name> STRATEGY <kind> [<arg>] EXEMPT ROLES (<r>, ...)"
1060 .to_string(),
1061 )
1062 })?;
1063
1064 let header = after_keyword[..exempt_pos].trim();
1065 let exempt_tail = after_keyword[exempt_pos + "EXEMPT ROLES".len()..].trim();
1066
1067 let (name, strategy) = parse_masking_policy_header(header)?;
1068 let exempt_roles = parse_exempt_roles_list(exempt_tail)?;
1069
1070 Ok(ParsedStatement::CreateMaskingPolicy(
1071 ParsedCreateMaskingPolicy {
1072 name,
1073 strategy,
1074 exempt_roles,
1075 },
1076 ))
1077}
1078
1079fn parse_masking_policy_header(header: &str) -> Result<(String, ParsedMaskingStrategy)> {
1081 let tokens: Vec<&str> = header.split_whitespace().collect();
1082 if tokens.len() < 3 {
1083 return Err(QueryError::ParseError(
1084 "expected: <name> STRATEGY <kind> [<arg>]".to_string(),
1085 ));
1086 }
1087 let name = tokens[0].to_string();
1088 if name.is_empty() {
1089 return Err(QueryError::ParseError(
1090 "policy name must not be empty".to_string(),
1091 ));
1092 }
1093 if !tokens[1].eq_ignore_ascii_case("STRATEGY") {
1094 return Err(QueryError::ParseError(format!(
1095 "expected STRATEGY after policy name, got '{}'",
1096 tokens[1]
1097 )));
1098 }
1099
1100 let strategy = parse_masking_strategy(&tokens[2..])?;
1101 Ok((name, strategy))
1102}
1103
1104fn parse_masking_strategy(tokens: &[&str]) -> Result<ParsedMaskingStrategy> {
1106 debug_assert!(
1107 !tokens.is_empty(),
1108 "caller must pass at least the strategy keyword"
1109 );
1110 let kind = tokens[0].to_ascii_uppercase();
1111 match kind.as_str() {
1112 "REDACT_SSN" => {
1113 expect_no_strategy_arg(tokens, "REDACT_SSN").map(|()| ParsedMaskingStrategy::RedactSsn)
1114 }
1115 "REDACT_PHONE" => expect_no_strategy_arg(tokens, "REDACT_PHONE")
1116 .map(|()| ParsedMaskingStrategy::RedactPhone),
1117 "REDACT_EMAIL" => expect_no_strategy_arg(tokens, "REDACT_EMAIL")
1118 .map(|()| ParsedMaskingStrategy::RedactEmail),
1119 "REDACT_CC" => expect_no_strategy_arg(tokens, "REDACT_CC")
1120 .map(|()| ParsedMaskingStrategy::RedactCreditCard),
1121 "REDACT_CUSTOM" => {
1122 if tokens.len() != 2 {
1123 return Err(QueryError::ParseError(
1124 "REDACT_CUSTOM requires a single quoted replacement string".to_string(),
1125 ));
1126 }
1127 let replacement = unquote_string_literal(tokens[1]).ok_or_else(|| {
1128 QueryError::ParseError(
1129 "REDACT_CUSTOM replacement must be a single-quoted string".to_string(),
1130 )
1131 })?;
1132 Ok(ParsedMaskingStrategy::RedactCustom { replacement })
1133 }
1134 "HASH" => {
1135 expect_no_strategy_arg(tokens, "HASH")?;
1136 Ok(ParsedMaskingStrategy::Hash)
1137 }
1138 "TOKENIZE" => {
1139 expect_no_strategy_arg(tokens, "TOKENIZE")?;
1140 Ok(ParsedMaskingStrategy::Tokenize)
1141 }
1142 "TRUNCATE" => {
1143 if tokens.len() != 2 {
1144 return Err(QueryError::ParseError(
1145 "TRUNCATE requires a positive integer character count".to_string(),
1146 ));
1147 }
1148 let max_chars = tokens[1].parse::<usize>().map_err(|_| {
1149 QueryError::ParseError(format!(
1150 "TRUNCATE argument must be a non-negative integer, got '{}'",
1151 tokens[1]
1152 ))
1153 })?;
1154 if max_chars == 0 {
1155 return Err(QueryError::ParseError(
1156 "TRUNCATE character count must be > 0".to_string(),
1157 ));
1158 }
1159 Ok(ParsedMaskingStrategy::Truncate { max_chars })
1160 }
1161 "NULL" => {
1162 expect_no_strategy_arg(tokens, "NULL")?;
1163 Ok(ParsedMaskingStrategy::Null)
1164 }
1165 _ => Err(QueryError::ParseError(format!(
1166 "unknown masking strategy '{kind}' — expected one of REDACT_SSN, REDACT_PHONE, \
1167 REDACT_EMAIL, REDACT_CC, REDACT_CUSTOM, HASH, TOKENIZE, TRUNCATE, NULL"
1168 ))),
1169 }
1170}
1171
1172fn expect_no_strategy_arg(tokens: &[&str], kind: &str) -> Result<()> {
1173 if tokens.len() != 1 {
1174 return Err(QueryError::ParseError(format!(
1175 "{kind} takes no arguments (found {} extra token(s))",
1176 tokens.len() - 1
1177 )));
1178 }
1179 Ok(())
1180}
1181
/// Strips one pair of surrounding single quotes from `token`.
///
/// Returns `None` unless the token both starts and ends with `'` and the two
/// quotes are distinct characters (so a lone `'` is rejected). The inner text
/// is returned verbatim; doubled quotes (`''`) are not unescaped.
fn unquote_string_literal(token: &str) -> Option<String> {
    let without_open = token.strip_prefix('\'')?;
    let inner = without_open.strip_suffix('\'')?;
    Some(inner.to_string())
}
1191
1192fn parse_exempt_roles_list(tail: &str) -> Result<Vec<String>> {
1199 let trimmed = tail.trim();
1200 if !trimmed.starts_with('(') || !trimmed.ends_with(')') {
1201 return Err(QueryError::ParseError(
1202 "EXEMPT ROLES must be followed by a parenthesised list: EXEMPT ROLES (r1, r2, ...)"
1203 .to_string(),
1204 ));
1205 }
1206 let inner = &trimmed[1..trimmed.len() - 1];
1207 let roles: Vec<String> = inner
1208 .split(',')
1209 .map(|s| s.trim().trim_matches('\'').to_ascii_lowercase())
1210 .filter(|s| !s.is_empty())
1211 .collect();
1212 if roles.is_empty() {
1213 return Err(QueryError::ParseError(
1214 "EXEMPT ROLES list must contain at least one role".to_string(),
1215 ));
1216 }
1217 Ok(roles)
1218}
1219
1220fn parse_alter_masking_policy(trimmed: &str) -> Result<ParsedStatement> {
1222 let tokens: Vec<&str> = trimmed.split_whitespace().collect();
1223 if tokens.len() < 9 || tokens.len() > 10 {
1226 return Err(QueryError::ParseError(
1227 "expected: ALTER TABLE <t> ALTER COLUMN <c> { SET | DROP } MASKING POLICY [<name>]"
1228 .to_string(),
1229 ));
1230 }
1231 if !tokens[0].eq_ignore_ascii_case("ALTER")
1232 || !tokens[1].eq_ignore_ascii_case("TABLE")
1233 || !tokens[3].eq_ignore_ascii_case("ALTER")
1234 || !tokens[4].eq_ignore_ascii_case("COLUMN")
1235 || !tokens[7].eq_ignore_ascii_case("MASKING")
1236 || !tokens[8].eq_ignore_ascii_case("POLICY")
1237 {
1238 return Err(QueryError::ParseError(format!(
1239 "malformed ALTER ... MASKING POLICY statement: '{trimmed}'"
1240 )));
1241 }
1242 let table_name = tokens[2].to_string();
1243 let column_name = tokens[5].to_string();
1244 let action = tokens[6].to_ascii_uppercase();
1245 match action.as_str() {
1246 "SET" => {
1247 if tokens.len() != 10 {
1248 return Err(QueryError::ParseError(
1249 "SET MASKING POLICY requires a policy name".to_string(),
1250 ));
1251 }
1252 Ok(ParsedStatement::AttachMaskingPolicy(
1253 ParsedAttachMaskingPolicy {
1254 table_name,
1255 column_name,
1256 policy_name: tokens[9].to_string(),
1257 },
1258 ))
1259 }
1260 "DROP" => {
1261 if tokens.len() != 9 {
1262 return Err(QueryError::ParseError(
1263 "DROP MASKING POLICY takes no arguments after POLICY".to_string(),
1264 ));
1265 }
1266 Ok(ParsedStatement::DetachMaskingPolicy(
1267 ParsedDetachMaskingPolicy {
1268 table_name,
1269 column_name,
1270 },
1271 ))
1272 }
1273 _ => Err(QueryError::ParseError(format!(
1274 "expected SET or DROP after column name, got '{action}'"
1275 ))),
1276 }
1277}
1278
/// Point in history selected by a time-travel suffix on a query.
///
/// Produced by `extract_time_travel`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TimeTravel {
    /// Numeric offset given as `AT OFFSET <n>`.
    Offset(u64),
    /// Timestamp given as `AS OF '<rfc3339>'` or
    /// `FOR SYSTEM_TIME AS OF '<rfc3339>'`, converted to nanoseconds via
    /// chrono's `timestamp_nanos_opt`.
    TimestampNs(i64),
}
1303
/// Splits a trailing `AT OFFSET <n>` clause off a SQL string.
///
/// The last case-insensitive occurrence of `AT OFFSET` is considered. It is
/// accepted only when:
/// * it sits at the very start of the text or is preceded by a space, tab,
///   `\n` or `\r` (so `WHAT OFFSET` does not match),
/// * it is followed (after optional whitespace) by at least one ASCII digit
///   forming a value that fits in `u64`, and
/// * nothing but whitespace and at most a single `;` follows the number.
///
/// Note that no whitespace is required between `OFFSET` and the digits.
///
/// # Returns
///
/// On a match: the SQL before the clause (right-trimmed, trailing `;`
/// dropped) and `Some(offset)`. Otherwise: the input unchanged and `None`.
pub fn extract_at_offset(sql: &str) -> (String, Option<u64>) {
    const KEYWORD: &str = "AT OFFSET";
    let untouched = || -> (String, Option<u64>) { (sql.to_string(), None) };

    // ASCII-only upper-casing keeps byte offsets identical to `sql`.
    let folded = sql.to_ascii_uppercase();
    let Some(keyword_start) = folded.rfind(KEYWORD) else {
        return untouched();
    };
    let (head, from_keyword) = sql.split_at(keyword_start);

    // Word-boundary check on the left side of the keyword.
    if let Some(&previous) = head.as_bytes().last() {
        if !matches!(previous, b' ' | b'\t' | b'\n' | b'\r') {
            return untouched();
        }
    }

    let after_keyword = from_keyword[KEYWORD.len()..].trim_start();
    let digit_count = after_keyword
        .bytes()
        .take_while(u8::is_ascii_digit)
        .count();
    if digit_count == 0 {
        return untouched();
    }

    let (digits, trailing) = after_keyword.split_at(digit_count);
    let Ok(offset) = digits.parse::<u64>() else {
        return untouched();
    };

    // Only whitespace and an optional single terminator may follow.
    let trailing = trailing.trim();
    if !(trailing.is_empty() || trailing == ";") {
        return untouched();
    }

    (head.trim_end().to_string(), Some(offset))
}
1376
1377pub fn extract_time_travel(sql: &str) -> (String, Option<TimeTravel>) {
1406 let (after_offset_sql, offset) = extract_at_offset(sql);
1408 if let Some(o) = offset {
1409 return (after_offset_sql, Some(TimeTravel::Offset(o)));
1410 }
1411
1412 let upper = sql.to_ascii_uppercase();
1418
1419 let (keyword_pos, keyword_len) = if let Some(p) = upper.rfind("FOR SYSTEM_TIME AS OF") {
1425 (p, "FOR SYSTEM_TIME AS OF".len())
1426 } else if let Some(p) = upper.rfind("AS OF") {
1427 let after = sql[p + "AS OF".len()..].trim_start();
1430 if !after.starts_with('\'') {
1431 return (sql.to_string(), None);
1432 }
1433 (p, "AS OF".len())
1434 } else {
1435 return (sql.to_string(), None);
1436 };
1437
1438 if keyword_pos > 0 {
1440 let prev = sql.as_bytes()[keyword_pos - 1];
1441 if !matches!(prev, b' ' | b'\t' | b'\n' | b'\r') {
1442 return (sql.to_string(), None);
1443 }
1444 }
1445
1446 let after_keyword = sql[keyword_pos + keyword_len..].trim_start();
1447 if !after_keyword.starts_with('\'') {
1448 return (sql.to_string(), None);
1449 }
1450
1451 let ts_start = 1; let ts_end = match after_keyword[1..].find('\'') {
1456 Some(i) => i + 1,
1457 None => return (sql.to_string(), None),
1458 };
1459 let ts_str = &after_keyword[ts_start..ts_end];
1460
1461 let ts_ns = match chrono::DateTime::parse_from_rfc3339(ts_str) {
1463 Ok(dt) => dt.timestamp_nanos_opt(),
1464 Err(_) => return (sql.to_string(), None),
1465 };
1466 let ts_ns = match ts_ns {
1467 Some(n) => n,
1468 None => return (sql.to_string(), None),
1469 };
1470
1471 let remainder = after_keyword[ts_end + 1..].trim();
1473 if !remainder.is_empty() && remainder != ";" {
1474 return (sql.to_string(), None);
1475 }
1476
1477 let before = sql[..keyword_pos].trim_end();
1478 (before.to_string(), Some(TimeTravel::TimestampNs(ts_ns)))
1479}
1480
1481fn parse_set_classification(sql: &str) -> Result<Option<ParsedStatement>> {
1486 let tokens: Vec<&str> = sql.split_whitespace().collect();
1487 if tokens.len() != 9 {
1490 return Err(QueryError::ParseError(
1491 "expected: ALTER TABLE <table> MODIFY COLUMN <column> SET CLASSIFICATION '<class>'"
1492 .to_string(),
1493 ));
1494 }
1495
1496 if !tokens[3].eq_ignore_ascii_case("MODIFY") {
1497 return Err(QueryError::ParseError(format!(
1498 "expected MODIFY, got '{}'",
1499 tokens[3]
1500 )));
1501 }
1502 if !tokens[4].eq_ignore_ascii_case("COLUMN") {
1503 return Err(QueryError::ParseError(format!(
1504 "expected COLUMN after MODIFY, got '{}'",
1505 tokens[4]
1506 )));
1507 }
1508 if !tokens[6].eq_ignore_ascii_case("SET") {
1509 return Err(QueryError::ParseError(format!(
1510 "expected SET, got '{}'",
1511 tokens[6]
1512 )));
1513 }
1514 if !tokens[7].eq_ignore_ascii_case("CLASSIFICATION") {
1515 return Err(QueryError::ParseError(format!(
1516 "expected CLASSIFICATION, got '{}'",
1517 tokens[7]
1518 )));
1519 }
1520
1521 let table_name = tokens[2].to_string();
1522 let column_name = tokens[5].to_string();
1523
1524 let raw_class = tokens[8];
1526 let classification = raw_class
1527 .strip_prefix('\'')
1528 .and_then(|s| s.strip_suffix('\''))
1529 .ok_or_else(|| {
1530 QueryError::ParseError(format!(
1531 "classification must be quoted with single quotes, got '{raw_class}'"
1532 ))
1533 })?
1534 .to_string();
1535
1536 assert!(!table_name.is_empty(), "table name must not be empty");
1537 assert!(!column_name.is_empty(), "column name must not be empty");
1538 assert!(
1539 !classification.is_empty(),
1540 "classification must not be empty"
1541 );
1542
1543 Ok(Some(ParsedStatement::SetClassification(
1544 ParsedSetClassification {
1545 table_name,
1546 column_name,
1547 classification,
1548 },
1549 )))
1550}
1551
1552fn parse_grant(
1554 privileges: &sqlparser::ast::Privileges,
1555 objects: &sqlparser::ast::GrantObjects,
1556 grantees: &[sqlparser::ast::Grantee],
1557) -> Result<ParsedStatement> {
1558 use sqlparser::ast::{Action, GrantObjects, GranteeName, Privileges};
1559
1560 let columns = match privileges {
1562 Privileges::Actions(actions) => {
1563 let mut cols = None;
1564 for action in actions {
1565 if let Action::Select { columns: Some(c) } = action {
1566 cols = Some(c.iter().map(|i| i.value.clone()).collect());
1567 }
1568 }
1569 cols
1570 }
1571 Privileges::All { .. } => None,
1572 };
1573
1574 let table_name = match objects {
1576 GrantObjects::Tables(tables) => {
1577 if tables.len() != 1 {
1578 return Err(QueryError::ParseError(
1579 "expected exactly 1 table in GRANT".to_string(),
1580 ));
1581 }
1582 object_name_to_string(&tables[0])
1583 }
1584 _ => {
1585 return Err(QueryError::UnsupportedFeature(
1586 "GRANT only supports table-level privileges".to_string(),
1587 ));
1588 }
1589 };
1590
1591 if grantees.len() != 1 {
1593 return Err(QueryError::ParseError(
1594 "expected exactly 1 grantee in GRANT".to_string(),
1595 ));
1596 }
1597 let role_name = match &grantees[0].name {
1598 Some(GranteeName::ObjectName(name)) => object_name_to_string(name),
1599 _ => {
1600 return Err(QueryError::ParseError(
1601 "expected a role name in GRANT".to_string(),
1602 ));
1603 }
1604 };
1605
1606 Ok(ParsedStatement::Grant(ParsedGrant {
1607 columns,
1608 table_name,
1609 role_name,
1610 }))
1611}
1612
1613fn parse_query_to_statement(query: &Query) -> Result<ParsedStatement> {
1615 let ctes = match &query.with {
1617 Some(with) => parse_ctes(with)?,
1618 None => vec![],
1619 };
1620
1621 match query.body.as_ref() {
1622 SetExpr::Select(select) => {
1623 let parsed_select = parse_select(select)?;
1624
1625 let order_by = match &query.order_by {
1627 Some(ob) => parse_order_by(ob)?,
1628 None => vec![],
1629 };
1630
1631 let limit = parse_limit(query_limit_expr(query)?)?;
1633 let offset = parse_offset_clause(query_offset(query))?;
1634
1635 let mut all_ctes = ctes;
1637 all_ctes.extend(parsed_select.ctes);
1638
1639 Ok(ParsedStatement::Select(ParsedSelect {
1640 table: parsed_select.table,
1641 joins: parsed_select.joins,
1642 columns: parsed_select.columns,
1643 column_aliases: parsed_select.column_aliases,
1644 case_columns: parsed_select.case_columns,
1645 predicates: parsed_select.predicates,
1646 order_by,
1647 limit,
1648 offset,
1649 aggregates: parsed_select.aggregates,
1650 aggregate_filters: parsed_select.aggregate_filters,
1651 group_by: parsed_select.group_by,
1652 distinct: parsed_select.distinct,
1653 having: parsed_select.having,
1654 ctes: all_ctes,
1655 window_fns: parsed_select.window_fns,
1656 scalar_projections: parsed_select.scalar_projections,
1657 }))
1658 }
1659 SetExpr::SetOperation {
1660 op,
1661 set_quantifier,
1662 left,
1663 right,
1664 } => {
1665 use sqlparser::ast::SetOperator;
1666 use sqlparser::ast::SetQuantifier;
1667
1668 let parsed_op = match op {
1669 SetOperator::Union => SetOp::Union,
1670 SetOperator::Intersect => SetOp::Intersect,
1671 SetOperator::Except | SetOperator::Minus => SetOp::Except,
1673 };
1674
1675 let all = matches!(set_quantifier, SetQuantifier::All);
1676
1677 let left_select = match left.as_ref() {
1679 SetExpr::Select(s) => parse_select(s)?,
1680 _ => {
1681 return Err(QueryError::UnsupportedFeature(
1682 "nested set operations not supported".to_string(),
1683 ));
1684 }
1685 };
1686 let right_select = match right.as_ref() {
1687 SetExpr::Select(s) => parse_select(s)?,
1688 _ => {
1689 return Err(QueryError::UnsupportedFeature(
1690 "nested set operations not supported".to_string(),
1691 ));
1692 }
1693 };
1694
1695 Ok(ParsedStatement::Union(ParsedUnion {
1696 op: parsed_op,
1697 left: left_select,
1698 right: right_select,
1699 all,
1700 }))
1701 }
1702 other => Err(QueryError::UnsupportedFeature(format!(
1703 "unsupported query type: {other:?}"
1704 ))),
1705 }
1706}
1707
1708fn parse_join_with_subqueries(join: &sqlparser::ast::Join) -> Result<(ParsedJoin, Vec<ParsedCte>)> {
1710 use sqlparser::ast::{JoinConstraint, JoinOperator};
1711
1712 let join_type = match &join.join_operator {
1719 JoinOperator::Inner(_) | JoinOperator::Join(_) => JoinType::Inner,
1720 JoinOperator::LeftOuter(_) | JoinOperator::Left(_) => JoinType::Left,
1721 JoinOperator::RightOuter(_) | JoinOperator::Right(_) => JoinType::Right,
1722 JoinOperator::FullOuter(_) => JoinType::Full,
1723 JoinOperator::CrossJoin(_) => JoinType::Cross,
1724 other => {
1725 return Err(QueryError::UnsupportedFeature(format!(
1726 "join type not supported: {other:?}"
1727 )));
1728 }
1729 };
1730
1731 let mut inline_ctes = Vec::new();
1733 let table = match &join.relation {
1734 sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
1735 sqlparser::ast::TableFactor::Derived {
1736 subquery, alias, ..
1737 } => {
1738 let alias_name = alias
1739 .as_ref()
1740 .map(|a| a.name.value.clone())
1741 .ok_or_else(|| {
1742 QueryError::ParseError("subquery in JOIN requires an alias".to_string())
1743 })?;
1744
1745 let inner = match subquery.body.as_ref() {
1747 SetExpr::Select(s) => parse_select(s)?,
1748 _ => {
1749 return Err(QueryError::UnsupportedFeature(
1750 "subquery body must be a simple SELECT".to_string(),
1751 ));
1752 }
1753 };
1754
1755 let order_by = match &subquery.order_by {
1756 Some(ob) => parse_order_by(ob)?,
1757 None => vec![],
1758 };
1759 let limit = parse_limit(query_limit_expr(subquery)?)?;
1760
1761 inline_ctes.push(ParsedCte {
1762 name: alias_name.clone(),
1763 query: ParsedSelect {
1764 order_by,
1765 limit,
1766 ..inner
1767 },
1768 recursive_arm: None,
1769 });
1770
1771 alias_name
1772 }
1773 _ => {
1774 return Err(QueryError::UnsupportedFeature(
1775 "unsupported JOIN relation type".to_string(),
1776 ));
1777 }
1778 };
1779
1780 let on_condition = match &join.join_operator {
1782 JoinOperator::CrossJoin(_) => Vec::new(),
1783 JoinOperator::Inner(constraint)
1784 | JoinOperator::Join(constraint)
1785 | JoinOperator::LeftOuter(constraint)
1786 | JoinOperator::Left(constraint)
1787 | JoinOperator::RightOuter(constraint)
1788 | JoinOperator::Right(constraint)
1789 | JoinOperator::FullOuter(constraint) => match constraint {
1790 JoinConstraint::On(expr) => parse_join_condition(expr)?,
1791 JoinConstraint::Using(idents) => {
1792 let mut preds = Vec::new();
1797 for name in idents {
1798 if name.0.len() != 1 {
1799 return Err(QueryError::UnsupportedFeature(format!(
1800 "USING column must be a bare identifier, got {name}"
1801 )));
1802 }
1803 let col_name = name.0[0]
1804 .as_ident()
1805 .ok_or_else(|| {
1806 QueryError::UnsupportedFeature(format!(
1807 "USING column must be a bare identifier, got {name}"
1808 ))
1809 })?
1810 .value
1811 .clone();
1812 preds.push(Predicate::Eq(
1813 ColumnName::new(col_name.clone()),
1814 PredicateValue::ColumnRef(col_name),
1815 ));
1816 }
1817 preds
1818 }
1819 JoinConstraint::Natural => {
1820 return Err(QueryError::UnsupportedFeature(
1821 "NATURAL JOIN is not supported; use ON or USING explicitly".to_string(),
1822 ));
1823 }
1824 JoinConstraint::None => {
1825 return Err(QueryError::UnsupportedFeature(
1826 "join without ON or USING clause not supported".to_string(),
1827 ));
1828 }
1829 },
1830 _ => {
1831 return Err(QueryError::UnsupportedFeature(
1832 "join without ON clause not supported".to_string(),
1833 ));
1834 }
1835 };
1836
1837 Ok((
1838 ParsedJoin {
1839 table,
1840 join_type,
1841 on_condition,
1842 },
1843 inline_ctes,
1844 ))
1845}
1846
1847fn parse_join_condition(expr: &Expr) -> Result<Vec<Predicate>> {
1850 match expr {
1851 Expr::BinaryOp {
1852 left,
1853 op: BinaryOperator::And,
1854 right,
1855 } => {
1856 let mut predicates = parse_join_condition(left)?;
1857 predicates.extend(parse_join_condition(right)?);
1858 Ok(predicates)
1859 }
1860 _ => {
1861 parse_where_expr(expr)
1863 }
1864 }
1865}
1866
1867fn parse_select(select: &Select) -> Result<ParsedSelect> {
1868 let distinct = select.distinct.is_some();
1870
1871 if select.from.len() != 1 {
1873 return Err(QueryError::ParseError(format!(
1874 "expected exactly 1 table in FROM clause, got {}",
1875 select.from.len()
1876 )));
1877 }
1878
1879 let from = &select.from[0];
1880
1881 let mut inline_ctes = Vec::new();
1883
1884 let mut joins = Vec::new();
1886 for join in &from.joins {
1887 let (parsed_join, join_ctes) = parse_join_with_subqueries(join)?;
1888 joins.push(parsed_join);
1889 inline_ctes.extend(join_ctes);
1890 }
1891
1892 let table = match &from.relation {
1893 sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
1894 sqlparser::ast::TableFactor::Derived {
1895 subquery, alias, ..
1896 } => {
1897 let alias_name = alias
1898 .as_ref()
1899 .map(|a| a.name.value.clone())
1900 .ok_or_else(|| {
1901 QueryError::ParseError("subquery in FROM requires an alias".to_string())
1902 })?;
1903
1904 let inner = match subquery.body.as_ref() {
1906 SetExpr::Select(s) => parse_select(s)?,
1907 _ => {
1908 return Err(QueryError::UnsupportedFeature(
1909 "subquery body must be a simple SELECT".to_string(),
1910 ));
1911 }
1912 };
1913
1914 let order_by = match &subquery.order_by {
1915 Some(ob) => parse_order_by(ob)?,
1916 None => vec![],
1917 };
1918 let limit = parse_limit(query_limit_expr(subquery)?)?;
1919
1920 inline_ctes.push(ParsedCte {
1921 name: alias_name.clone(),
1922 query: ParsedSelect {
1923 order_by,
1924 limit,
1925 ..inner
1926 },
1927 recursive_arm: None,
1928 });
1929
1930 alias_name
1931 }
1932 other => {
1933 return Err(QueryError::UnsupportedFeature(format!(
1934 "unsupported FROM clause: {other:?}"
1935 )));
1936 }
1937 };
1938
1939 let (columns, column_aliases) = parse_select_items(&select.projection)?;
1941
1942 let case_columns = parse_case_columns_from_select_items(&select.projection)?;
1944
1945 let predicates = match &select.selection {
1947 Some(expr) => parse_where_expr(expr)?,
1948 None => vec![],
1949 };
1950
1951 let group_by = match &select.group_by {
1953 sqlparser::ast::GroupByExpr::Expressions(exprs, _) if !exprs.is_empty() => {
1954 parse_group_by_expr(exprs)?
1955 }
1956 sqlparser::ast::GroupByExpr::All(_) => {
1957 return Err(QueryError::UnsupportedFeature(
1958 "GROUP BY ALL is not supported".to_string(),
1959 ));
1960 }
1961 sqlparser::ast::GroupByExpr::Expressions(_, _) => vec![],
1962 };
1963
1964 let (aggregates, aggregate_filters) = parse_aggregates_from_select_items(&select.projection)?;
1966
1967 let having = match &select.having {
1969 Some(expr) => parse_having_expr(expr)?,
1970 None => vec![],
1971 };
1972
1973 let window_fns = parse_window_fns_from_select_items(&select.projection)?;
1976
1977 let scalar_projections = parse_scalar_columns_from_select_items(&select.projection)?;
1980
1981 Ok(ParsedSelect {
1982 table,
1983 joins,
1984 columns,
1985 column_aliases,
1986 case_columns,
1987 predicates,
1988 order_by: vec![],
1989 limit: None,
1990 offset: None,
1991 aggregates,
1992 aggregate_filters,
1993 group_by,
1994 distinct,
1995 having,
1996 ctes: inline_ctes,
1997 window_fns,
1998 scalar_projections,
1999 })
2000}
2001
2002fn parse_ctes(with: &sqlparser::ast::With) -> Result<Vec<ParsedCte>> {
2009 let max_ctes = 16;
2010 let mut ctes = Vec::new();
2011
2012 for (i, cte) in with.cte_tables.iter().enumerate() {
2013 if i >= max_ctes {
2014 return Err(QueryError::UnsupportedFeature(format!(
2015 "too many CTEs (max {max_ctes})"
2016 )));
2017 }
2018
2019 let name = cte.alias.name.value.clone();
2020
2021 let (inner_select, recursive_arm) = match cte.query.body.as_ref() {
2026 SetExpr::Select(s) => (parse_select(s)?, None),
2027 SetExpr::SetOperation {
2028 op, left, right, ..
2029 } if with.recursive => {
2030 use sqlparser::ast::SetOperator;
2031 if !matches!(op, SetOperator::Union) {
2032 return Err(QueryError::UnsupportedFeature(
2033 "recursive CTE body must use UNION (not INTERSECT/EXCEPT)".to_string(),
2034 ));
2035 }
2036 let anchor = match left.as_ref() {
2037 SetExpr::Select(s) => parse_select(s)?,
2038 _ => {
2039 return Err(QueryError::UnsupportedFeature(
2040 "recursive CTE anchor must be a simple SELECT".to_string(),
2041 ));
2042 }
2043 };
2044 let recursive = match right.as_ref() {
2045 SetExpr::Select(s) => parse_select(s)?,
2046 _ => {
2047 return Err(QueryError::UnsupportedFeature(
2048 "recursive CTE recursive arm must be a simple SELECT".to_string(),
2049 ));
2050 }
2051 };
2052 (anchor, Some(recursive))
2053 }
2054 _ => {
2055 return Err(QueryError::UnsupportedFeature(
2056 "CTE body must be a simple SELECT (or anchor UNION recursive for WITH RECURSIVE)".to_string(),
2057 ));
2058 }
2059 };
2060
2061 let order_by = match &cte.query.order_by {
2063 Some(ob) => parse_order_by(ob)?,
2064 None => vec![],
2065 };
2066 let limit = parse_limit(query_limit_expr(&cte.query)?)?;
2067
2068 ctes.push(ParsedCte {
2069 name,
2070 query: ParsedSelect {
2071 order_by,
2072 limit,
2073 ..inner_select
2074 },
2075 recursive_arm,
2076 });
2077 }
2078
2079 Ok(ctes)
2080}
2081
2082fn parse_having_expr(expr: &Expr) -> Result<Vec<HavingCondition>> {
2087 match expr {
2088 Expr::BinaryOp {
2089 left,
2090 op: BinaryOperator::And,
2091 right,
2092 } => {
2093 let mut conditions = parse_having_expr(left)?;
2094 conditions.extend(parse_having_expr(right)?);
2095 Ok(conditions)
2096 }
2097 Expr::BinaryOp { left, op, right } => {
2098 let aggregate = match left.as_ref() {
2100 Expr::Function(_) => {
2101 let (agg, _filter) = try_parse_aggregate(left)?.ok_or_else(|| {
2102 QueryError::UnsupportedFeature(
2103 "HAVING requires aggregate functions (COUNT, SUM, AVG, MIN, MAX)"
2104 .to_string(),
2105 )
2106 })?;
2107 agg
2108 }
2109 _ => {
2110 return Err(QueryError::UnsupportedFeature(
2111 "HAVING clause must reference aggregate functions".to_string(),
2112 ));
2113 }
2114 };
2115
2116 let value = expr_to_value(right)?;
2118
2119 let having_op = match op {
2121 BinaryOperator::Eq => HavingOp::Eq,
2122 BinaryOperator::Lt => HavingOp::Lt,
2123 BinaryOperator::LtEq => HavingOp::Le,
2124 BinaryOperator::Gt => HavingOp::Gt,
2125 BinaryOperator::GtEq => HavingOp::Ge,
2126 other => {
2127 return Err(QueryError::UnsupportedFeature(format!(
2128 "unsupported HAVING operator: {other:?}"
2129 )));
2130 }
2131 };
2132
2133 Ok(vec![HavingCondition::AggregateComparison {
2134 aggregate,
2135 op: having_op,
2136 value,
2137 }])
2138 }
2139 Expr::Nested(inner) => parse_having_expr(inner),
2140 other => Err(QueryError::UnsupportedFeature(format!(
2141 "unsupported HAVING expression: {other:?}"
2142 ))),
2143 }
2144}
2145
/// Result of `parse_select_items`: the projected column names and a second
/// list holding the optional alias recorded for each of those columns. Both
/// halves are `None` when the projection contains a `*` wildcard.
type ParsedSelectList = (Option<Vec<ColumnName>>, Option<Vec<Option<String>>>);
2160
2161fn parse_select_items(items: &[SelectItem]) -> Result<ParsedSelectList> {
2162 let mut columns = Vec::new();
2163 let mut aliases: Vec<Option<String>> = Vec::new();
2164
2165 for item in items {
2166 #[allow(clippy::match_same_arms)]
2171 match item {
2172 SelectItem::Wildcard(_) => {
2173 return Ok((None, None));
2176 }
2177 SelectItem::UnnamedExpr(Expr::Identifier(ident)) => {
2178 columns.push(ColumnName::new(ident.value.clone()));
2179 aliases.push(None);
2180 }
2181 SelectItem::UnnamedExpr(Expr::CompoundIdentifier(idents)) if idents.len() == 2 => {
2182 columns.push(ColumnName::new(idents[1].value.clone()));
2184 aliases.push(None);
2185 }
2186 SelectItem::ExprWithAlias {
2187 expr: Expr::Identifier(ident),
2188 alias,
2189 } => {
2190 columns.push(ColumnName::new(ident.value.clone()));
2191 aliases.push(Some(alias.value.clone()));
2192 }
2193 SelectItem::ExprWithAlias {
2194 expr: Expr::CompoundIdentifier(idents),
2195 alias,
2196 } if idents.len() == 2 => {
2197 columns.push(ColumnName::new(idents[1].value.clone()));
2199 aliases.push(Some(alias.value.clone()));
2200 }
2201 SelectItem::UnnamedExpr(Expr::Function(_))
2202 | SelectItem::ExprWithAlias {
2203 expr: Expr::Function(_) | Expr::Case { .. },
2204 ..
2205 } => {
2206 }
2210 SelectItem::UnnamedExpr(Expr::Cast { .. })
2214 | SelectItem::ExprWithAlias {
2215 expr: Expr::Cast { .. },
2216 ..
2217 } => {}
2218 SelectItem::UnnamedExpr(Expr::BinaryOp {
2219 op: BinaryOperator::StringConcat,
2220 ..
2221 })
2222 | SelectItem::ExprWithAlias {
2223 expr:
2224 Expr::BinaryOp {
2225 op: BinaryOperator::StringConcat,
2226 ..
2227 },
2228 ..
2229 } => {}
2230 other => {
2231 return Err(QueryError::UnsupportedFeature(format!(
2232 "unsupported SELECT item: {other:?}"
2233 )));
2234 }
2235 }
2236 }
2237
2238 Ok((Some(columns), Some(aliases)))
2239}
2240
/// Result of `parse_aggregates_from_select_items`: the aggregate functions
/// found in the projection, and a parallel list holding each aggregate's
/// parsed `FILTER` predicates (`None` when the aggregate has no filter).
type ParsedAggregateList = (Vec<AggregateFunction>, Vec<Option<Vec<Predicate>>>);
2249
2250fn parse_aggregates_from_select_items(items: &[SelectItem]) -> Result<ParsedAggregateList> {
2251 let mut aggregates = Vec::new();
2252 let mut filters = Vec::new();
2253
2254 for item in items {
2255 match item {
2256 SelectItem::UnnamedExpr(expr) | SelectItem::ExprWithAlias { expr, .. } => {
2257 if let Some((agg, filter)) = try_parse_aggregate(expr)? {
2258 aggregates.push(agg);
2259 filters.push(filter);
2260 }
2261 }
2262 _ => {
2263 }
2265 }
2266 }
2267
2268 Ok((aggregates, filters))
2269}
2270
2271fn parse_case_columns_from_select_items(items: &[SelectItem]) -> Result<Vec<ComputedColumn>> {
2276 let mut case_cols = Vec::new();
2277
2278 for item in items {
2279 if let SelectItem::ExprWithAlias {
2280 expr:
2281 Expr::Case {
2282 operand,
2283 conditions,
2284 else_result,
2285 ..
2286 },
2287 alias,
2288 } = item
2289 {
2290 let mut when_clauses = Vec::new();
2295 for case_when in conditions {
2296 let cond_expr = &case_when.condition;
2297 let result_expr = &case_when.result;
2298 let condition = match operand.as_deref() {
2299 None => parse_where_expr(cond_expr)?,
2300 Some(operand_expr) => parse_where_expr(&Expr::BinaryOp {
2301 left: Box::new(operand_expr.clone()),
2302 op: BinaryOperator::Eq,
2303 right: Box::new(cond_expr.clone()),
2304 })?,
2305 };
2306 let result = expr_to_value(result_expr)?;
2307 when_clauses.push(CaseWhenArm { condition, result });
2308 }
2309
2310 let else_value = match else_result {
2311 Some(expr) => expr_to_value(expr)?,
2312 None => Value::Null,
2313 };
2314
2315 case_cols.push(ComputedColumn {
2316 alias: ColumnName::new(alias.value.clone()),
2317 when_clauses,
2318 else_value,
2319 });
2320 }
2321 }
2322
2323 Ok(case_cols)
2324}
2325
2326fn parse_scalar_columns_from_select_items(
2336 items: &[SelectItem],
2337) -> Result<Vec<ParsedScalarProjection>> {
2338 let mut out = Vec::new();
2339 for item in items {
2340 let (expr, alias) = match item {
2341 SelectItem::UnnamedExpr(e) => (e, None),
2342 SelectItem::ExprWithAlias { expr, alias } => (expr, Some(alias.value.clone())),
2343 _ => continue,
2344 };
2345
2346 if !is_scalar_projection_shape(expr) {
2347 continue;
2348 }
2349
2350 let scalar = expr_to_scalar_expr(expr)?;
2351 let output_name = alias
2352 .clone()
2353 .unwrap_or_else(|| synthesize_column_name(expr));
2354 out.push(ParsedScalarProjection {
2355 expr: scalar,
2356 output_name: ColumnName::new(output_name),
2357 alias,
2358 });
2359 }
2360 Ok(out)
2361}
2362
2363fn is_scalar_projection_shape(expr: &Expr) -> bool {
2367 match expr {
2368 Expr::Function(func) => {
2369 if func.over.is_some() {
2371 return false;
2372 }
2373 let name = func.name.to_string().to_uppercase();
2374 !matches!(name.as_str(), "COUNT" | "SUM" | "AVG" | "MIN" | "MAX")
2375 }
2376 Expr::Cast { .. }
2377 | Expr::BinaryOp {
2378 op: BinaryOperator::StringConcat,
2379 ..
2380 } => true,
2381 _ => false,
2382 }
2383}
2384
2385fn synthesize_column_name(expr: &Expr) -> String {
2389 match expr {
2390 Expr::Function(func) => func.name.to_string().to_lowercase(),
2391 Expr::Cast { .. } => "cast".to_string(),
2392 Expr::BinaryOp {
2393 op: BinaryOperator::StringConcat,
2394 ..
2395 } => "concat".to_string(),
2396 _ => "expr".to_string(),
2397 }
2398}
2399
2400fn parse_window_fns_from_select_items(items: &[SelectItem]) -> Result<Vec<ParsedWindowFn>> {
2404 let mut out = Vec::new();
2405 for item in items {
2406 let (expr, alias) = match item {
2407 SelectItem::UnnamedExpr(e) => (e, None),
2408 SelectItem::ExprWithAlias { expr, alias } => (expr, Some(alias.value.clone())),
2409 _ => continue,
2410 };
2411 if let Some(parsed) = try_parse_window_fn(expr, alias)? {
2412 out.push(parsed);
2413 }
2414 }
2415 Ok(out)
2416}
2417
2418fn try_parse_window_fn(expr: &Expr, alias: Option<String>) -> Result<Option<ParsedWindowFn>> {
2419 let Expr::Function(func) = expr else {
2420 return Ok(None);
2421 };
2422 let Some(over) = &func.over else {
2423 return Ok(None);
2424 };
2425 let spec = match over {
2426 sqlparser::ast::WindowType::WindowSpec(s) => s,
2427 sqlparser::ast::WindowType::NamedWindow(_) => {
2428 return Err(QueryError::UnsupportedFeature(
2429 "named windows (OVER w) are not supported".into(),
2430 ));
2431 }
2432 };
2433 if spec.window_frame.is_some() {
2434 return Err(QueryError::UnsupportedFeature(
2435 "explicit window frames (ROWS/RANGE BETWEEN ...) are not supported; \
2436 omit the frame clause for default behaviour"
2437 .into(),
2438 ));
2439 }
2440
2441 let func_name = func.name.to_string().to_uppercase();
2442 let args = match &func.args {
2443 sqlparser::ast::FunctionArguments::List(list) => list.args.clone(),
2444 _ => Vec::new(),
2445 };
2446 let function = parse_window_function_name(&func_name, &args)?;
2447
2448 let partition_by: Vec<ColumnName> = spec
2449 .partition_by
2450 .iter()
2451 .map(parse_column_expr)
2452 .collect::<Result<_>>()?;
2453 let order_by: Vec<OrderByClause> = spec
2454 .order_by
2455 .iter()
2456 .map(parse_order_by_expr)
2457 .collect::<Result<_>>()?;
2458
2459 Ok(Some(ParsedWindowFn {
2460 function,
2461 partition_by,
2462 order_by,
2463 alias,
2464 }))
2465}
2466
2467fn parse_column_expr(expr: &Expr) -> Result<ColumnName> {
2468 match expr {
2469 Expr::Identifier(ident) => Ok(ColumnName::new(ident.value.clone())),
2470 Expr::CompoundIdentifier(idents) if idents.len() == 2 => {
2471 Ok(ColumnName::new(idents[1].value.clone()))
2472 }
2473 other => Err(QueryError::UnsupportedFeature(format!(
2474 "window PARTITION BY / argument must be a column reference, got: {other:?}"
2475 ))),
2476 }
2477}
2478
2479fn parse_window_function_name(
2480 name: &str,
2481 args: &[sqlparser::ast::FunctionArg],
2482) -> Result<crate::window::WindowFunction> {
2483 use crate::window::WindowFunction;
2484
2485 let arg_exprs: Vec<&Expr> = args
2486 .iter()
2487 .filter_map(|a| match a {
2488 sqlparser::ast::FunctionArg::Unnamed(sqlparser::ast::FunctionArgExpr::Expr(e)) => {
2489 Some(e)
2490 }
2491 _ => None,
2492 })
2493 .collect();
2494
2495 let single_col = || -> Result<ColumnName> {
2496 if arg_exprs.is_empty() {
2497 return Err(QueryError::ParseError(format!(
2498 "{name} requires a column argument"
2499 )));
2500 }
2501 parse_column_expr(arg_exprs[0])
2502 };
2503
2504 let parse_offset = || -> Result<usize> {
2505 if arg_exprs.len() < 2 {
2506 return Ok(1);
2507 }
2508 match arg_exprs[1] {
2509 Expr::Value(vws) => match &vws.value {
2510 SqlValue::Number(n, _) => n
2511 .parse::<usize>()
2512 .map_err(|_| QueryError::ParseError(format!("invalid {name} offset: {n}"))),
2513 other => Err(QueryError::UnsupportedFeature(format!(
2514 "{name} offset must be a literal integer; got {other:?}"
2515 ))),
2516 },
2517 other => Err(QueryError::UnsupportedFeature(format!(
2518 "{name} offset must be a literal integer; got {other:?}"
2519 ))),
2520 }
2521 };
2522
2523 match name {
2524 "ROW_NUMBER" => Ok(WindowFunction::RowNumber),
2525 "RANK" => Ok(WindowFunction::Rank),
2526 "DENSE_RANK" => Ok(WindowFunction::DenseRank),
2527 "LAG" => Ok(WindowFunction::Lag {
2528 column: single_col()?,
2529 offset: parse_offset()?,
2530 }),
2531 "LEAD" => Ok(WindowFunction::Lead {
2532 column: single_col()?,
2533 offset: parse_offset()?,
2534 }),
2535 "FIRST_VALUE" => Ok(WindowFunction::FirstValue {
2536 column: single_col()?,
2537 }),
2538 "LAST_VALUE" => Ok(WindowFunction::LastValue {
2539 column: single_col()?,
2540 }),
2541 other => Err(QueryError::UnsupportedFeature(format!(
2542 "unknown window function: {other}"
2543 ))),
2544 }
2545}
2546
/// Result of `try_parse_aggregate`: the aggregate function together with the
/// parsed predicates of its `FILTER` clause (`None` when there is no filter).
type ParsedAggregate = (AggregateFunction, Option<Vec<Predicate>>);
2549
2550fn try_parse_aggregate(expr: &Expr) -> Result<Option<ParsedAggregate>> {
2554 let parsed_filter: Option<Vec<Predicate>> = match expr {
2555 Expr::Function(func) => match &func.filter {
2556 Some(filter_expr) => Some(parse_where_expr(filter_expr)?),
2557 None => None,
2558 },
2559 _ => None,
2560 };
2561 let func_only = try_parse_aggregate_func(expr)?;
2562 Ok(func_only.map(|f| (f, parsed_filter)))
2563}
2564
2565fn try_parse_aggregate_func(expr: &Expr) -> Result<Option<AggregateFunction>> {
2567 match expr {
2568 Expr::Function(func) => {
2569 if func.over.is_some() {
2574 return Ok(None);
2575 }
2576 let func_name = func.name.to_string().to_uppercase();
2577
2578 let args = match &func.args {
2580 sqlparser::ast::FunctionArguments::List(list) => &list.args,
2581 _ => {
2582 return Err(QueryError::UnsupportedFeature(
2583 "non-list function arguments not supported".to_string(),
2584 ));
2585 }
2586 };
2587
2588 match func_name.as_str() {
2589 "COUNT" => {
2590 if args.len() == 1 {
2592 match &args[0] {
2593 sqlparser::ast::FunctionArg::Unnamed(arg_expr) => match arg_expr {
2594 sqlparser::ast::FunctionArgExpr::Wildcard => {
2595 Ok(Some(AggregateFunction::CountStar))
2596 }
2597 sqlparser::ast::FunctionArgExpr::Expr(Expr::Identifier(ident)) => {
2598 Ok(Some(AggregateFunction::Count(ColumnName::new(
2599 ident.value.clone(),
2600 ))))
2601 }
2602 _ => Err(QueryError::UnsupportedFeature(
2603 "COUNT with complex expression not supported".to_string(),
2604 )),
2605 },
2606 _ => Err(QueryError::UnsupportedFeature(
2607 "named function arguments not supported".to_string(),
2608 )),
2609 }
2610 } else {
2611 Err(QueryError::ParseError(format!(
2612 "COUNT expects 1 argument, got {}",
2613 args.len()
2614 )))
2615 }
2616 }
2617 "SUM" | "AVG" | "MIN" | "MAX" => {
2618 if args.len() != 1 {
2620 return Err(QueryError::ParseError(format!(
2621 "{} expects 1 argument, got {}",
2622 func_name,
2623 args.len()
2624 )));
2625 }
2626
2627 match &args[0] {
2628 sqlparser::ast::FunctionArg::Unnamed(arg_expr) => match arg_expr {
2629 sqlparser::ast::FunctionArgExpr::Expr(Expr::Identifier(ident)) => {
2630 let column = ColumnName::new(ident.value.clone());
2631 match func_name.as_str() {
2632 "SUM" => Ok(Some(AggregateFunction::Sum(column))),
2633 "AVG" => Ok(Some(AggregateFunction::Avg(column))),
2634 "MIN" => Ok(Some(AggregateFunction::Min(column))),
2635 "MAX" => Ok(Some(AggregateFunction::Max(column))),
2636 _ => unreachable!(),
2637 }
2638 }
2639 _ => Err(QueryError::UnsupportedFeature(format!(
2640 "{func_name} with complex expression not supported"
2641 ))),
2642 },
2643 _ => Err(QueryError::UnsupportedFeature(
2644 "named function arguments not supported".to_string(),
2645 )),
2646 }
2647 }
2648 _ => {
2649 Ok(None)
2651 }
2652 }
2653 }
2654 _ => {
2655 Ok(None)
2657 }
2658 }
2659}
2660
2661fn parse_group_by_expr(exprs: &[Expr]) -> Result<Vec<ColumnName>> {
2663 let mut columns = Vec::new();
2664
2665 for expr in exprs {
2666 match expr {
2667 Expr::Identifier(ident) => {
2668 columns.push(ColumnName::new(ident.value.clone()));
2669 }
2670 _ => {
2671 return Err(QueryError::UnsupportedFeature(
2672 "complex GROUP BY expressions not supported".to_string(),
2673 ));
2674 }
2675 }
2676 }
2677
2678 Ok(columns)
2679}
2680
/// Nesting depth at which `parse_where_expr_inner` stops recursing and
/// reports a `ParseError` instead.
const MAX_WHERE_DEPTH: usize = 100;
2689
2690fn parse_where_expr(expr: &Expr) -> Result<Vec<Predicate>> {
2691 parse_where_expr_inner(expr, 0)
2692}
2693
2694fn parse_select_from_query(query: &sqlparser::ast::Query) -> Result<ParsedSelect> {
2700 match query.body.as_ref() {
2701 SetExpr::Select(s) => {
2702 let mut parsed = parse_select(s)?;
2703 if let Some(ob) = &query.order_by {
2704 parsed.order_by = parse_order_by(ob)?;
2705 }
2706 parsed.limit = parse_limit(query_limit_expr(query)?)?;
2707 parsed.offset = parse_offset_clause(query_offset(query))?;
2708 Ok(parsed)
2709 }
2710 _ => Err(QueryError::UnsupportedFeature(
2711 "subquery body must be a simple SELECT (no nested UNION/INTERSECT/EXCEPT)".to_string(),
2712 )),
2713 }
2714}
2715
/// Recursively converts a `WHERE`-style expression into a list of predicates.
///
/// `depth` is the current nesting level; every recursive step (through
/// `AND`, `OR` or parentheses) passes `depth + 1`. The returned list is an
/// implicit conjunction: `AND` operands are concatenated, while `OR` yields a
/// single `Predicate::Or` holding the predicate lists of both sides.
///
/// # Errors
///
/// Returns `QueryError::ParseError` once `depth` reaches `MAX_WHERE_DEPTH`,
/// and `QueryError::UnsupportedFeature` for non-literal `LIKE` / `ILIKE`
/// patterns and for expression kinds not listed below. Errors from the
/// helper parsers are propagated.
fn parse_where_expr_inner(expr: &Expr, depth: usize) -> Result<Vec<Predicate>> {
    // Bound the recursion before looking at the expression.
    if depth >= MAX_WHERE_DEPTH {
        return Err(QueryError::ParseError(format!(
            "WHERE clause nesting exceeds maximum depth of {MAX_WHERE_DEPTH}"
        )));
    }

    match expr {
        // `a AND b`: flatten both sides into one predicate list.
        Expr::BinaryOp {
            left,
            op: BinaryOperator::And,
            right,
        } => {
            let mut predicates = parse_where_expr_inner(left, depth + 1)?;
            predicates.extend(parse_where_expr_inner(right, depth + 1)?);
            Ok(predicates)
        }

        // `a OR b`: keep each side's predicate list inside one `Or` node.
        Expr::BinaryOp {
            left,
            op: BinaryOperator::Or,
            right,
        } => {
            let left_preds = parse_where_expr_inner(left, depth + 1)?;
            let right_preds = parse_where_expr_inner(right, depth + 1)?;
            Ok(vec![Predicate::Or(left_preds, right_preds)])
        }

        // `col [NOT] LIKE 'pattern'`; the pattern must resolve to a string.
        Expr::Like {
            expr,
            pattern,
            negated,
            ..
        } => {
            let column = expr_to_column(expr)?;
            let pattern_str = match expr_to_predicate_value(pattern)? {
                PredicateValue::String(s) | PredicateValue::Literal(Value::Text(s)) => s,
                _ => {
                    return Err(QueryError::UnsupportedFeature(
                        "LIKE pattern must be a string literal".to_string(),
                    ));
                }
            };
            let predicate = if *negated {
                Predicate::NotLike(column, pattern_str)
            } else {
                Predicate::Like(column, pattern_str)
            };
            Ok(vec![predicate])
        }

        // `col [NOT] ILIKE 'pattern'`; same shape as LIKE.
        Expr::ILike {
            expr,
            pattern,
            negated,
            ..
        } => {
            let column = expr_to_column(expr)?;
            let pattern_str = match expr_to_predicate_value(pattern)? {
                PredicateValue::String(s) | PredicateValue::Literal(Value::Text(s)) => s,
                _ => {
                    return Err(QueryError::UnsupportedFeature(
                        "ILIKE pattern must be a string literal".to_string(),
                    ));
                }
            };
            let predicate = if *negated {
                Predicate::NotILike(column, pattern_str)
            } else {
                Predicate::ILike(column, pattern_str)
            };
            Ok(vec![predicate])
        }

        // `col IS NULL`
        Expr::IsNull(expr) => {
            let column = expr_to_column(expr)?;
            Ok(vec![Predicate::IsNull(column)])
        }

        // `col IS NOT NULL`
        Expr::IsNotNull(expr) => {
            let column = expr_to_column(expr)?;
            Ok(vec![Predicate::IsNotNull(column)])
        }

        // Any remaining binary operator (AND / OR were matched above) is
        // treated as a comparison.
        Expr::BinaryOp { left, op, right } => {
            let predicate = parse_comparison(left, op, right)?;
            Ok(vec![predicate])
        }

        // `col [NOT] IN (v1, v2, ...)`
        Expr::InList {
            expr,
            list,
            negated,
        } => {
            let column = expr_to_column(expr)?;
            let values: Result<Vec<_>> = list.iter().map(expr_to_predicate_value).collect();
            if *negated {
                Ok(vec![Predicate::NotIn(column, values?)])
            } else {
                Ok(vec![Predicate::In(column, values?)])
            }
        }

        // `col [NOT] IN (SELECT ...)`
        Expr::InSubquery {
            expr,
            subquery,
            negated,
        } => {
            let column = expr_to_column(expr)?;
            let inner = parse_select_from_query(subquery)?;
            Ok(vec![Predicate::InSubquery {
                column,
                subquery: Box::new(inner),
                negated: *negated,
            }])
        }

        // `[NOT] EXISTS (SELECT ...)`
        Expr::Exists { subquery, negated } => {
            let inner = parse_select_from_query(subquery)?;
            Ok(vec![Predicate::Exists {
                subquery: Box::new(inner),
                negated: *negated,
            }])
        }

        // `col [NOT] BETWEEN low AND high`
        Expr::Between {
            expr,
            negated,
            low,
            high,
        } => {
            let column = expr_to_column(expr)?;
            let low_val = expr_to_predicate_value(low)?;
            let high_val = expr_to_predicate_value(high)?;

            // The negated form keeps a dedicated predicate.
            if *negated {
                return Ok(vec![Predicate::NotBetween(column, low_val, high_val)]);
            }

            kimberlite_properties::sometimes!(
                true,
                "query.between_desugared_to_ge_le",
                "BETWEEN predicate desugared into Ge + Le pair"
            );

            // The plain form is desugared into `col >= low AND col <= high`.
            Ok(vec![
                Predicate::Ge(column.clone(), low_val),
                Predicate::Le(column, high_val),
            ])
        }

        // Parentheses: parse the inner expression one level deeper.
        Expr::Nested(inner) => parse_where_expr_inner(inner, depth + 1),

        other => Err(QueryError::UnsupportedFeature(format!(
            "unsupported WHERE expression: {other:?}"
        ))),
    }
}
2891
/// Parses a binary comparison `left <op> right` into a single predicate.
///
/// The forms are tried in this order:
/// 1. `col @> value` becomes `Predicate::JsonContains`.
/// 2. `col -> key = value` / `col ->> key = value` becomes
///    `Predicate::JsonExtractEq` (`as_text` is set for `->>`); only `=` is
///    accepted for this form.
/// 3. When neither side needs scalar evaluation (see `expr_needs_scalar`) and
///    the sides resolve to a column and a predicate value, the dedicated
///    `Eq` / `Lt` / `Le` / `Gt` / `Ge` predicates are produced; `<>` is
///    expressed as a `ScalarCmp` with `NotEq`.
/// 4. Otherwise both sides are converted to scalar expressions and compared
///    with `Predicate::ScalarCmp`.
///
/// # Errors
///
/// Returns `QueryError::UnsupportedFeature` for unsupported operators and
/// for JSON path keys that are not string or number literals; errors from
/// the helper converters are propagated.
fn parse_comparison(left: &Expr, op: &BinaryOperator, right: &Expr) -> Result<Predicate> {
    // Look through one level of parentheses on the left-hand side.
    let left = match left {
        Expr::Nested(inner) => inner.as_ref(),
        other => other,
    };

    // JSON containment: `col @> value`.
    if matches!(op, BinaryOperator::AtArrow) {
        let column = expr_to_column(left)?;
        let value = expr_to_predicate_value(right)?;
        return Ok(Predicate::JsonContains { column, value });
    }

    // JSON path extraction on the left: `col -> key` or `col ->> key`.
    if let Expr::BinaryOp {
        left: json_left,
        op: arrow_op @ (BinaryOperator::Arrow | BinaryOperator::LongArrow),
        right: path_expr,
    } = left
    {
        // `->>` extracts as text, `->` does not.
        let as_text = matches!(arrow_op, BinaryOperator::LongArrow);
        let column = expr_to_column(json_left)?;
        let path = match path_expr.as_ref() {
            Expr::Value(vws) => match &vws.value {
                SqlValue::SingleQuotedString(s) | SqlValue::DoubleQuotedString(s) => s.clone(),
                // A numeric key is kept in its textual form.
                SqlValue::Number(n, _) => n.clone(),
                _ => {
                    return Err(QueryError::UnsupportedFeature(format!(
                        "JSON path key must be a string or integer literal, got {path_expr:?}"
                    )));
                }
            },
            other => {
                return Err(QueryError::UnsupportedFeature(format!(
                    "JSON path key must be a string or integer literal, got {other:?}"
                )));
            }
        };
        // The right-hand value is parsed before the operator is validated,
        // so a bad value is reported ahead of a bad operator.
        let value = expr_to_predicate_value(right)?;
        if !matches!(op, BinaryOperator::Eq) {
            return Err(QueryError::UnsupportedFeature(format!(
                "JSON path extraction supports only `=` comparison; got {op:?}"
            )));
        }
        return Ok(Predicate::JsonExtractEq {
            column,
            path,
            as_text,
            value,
        });
    }

    // `None` when `op` is not one of the six comparison operators; only
    // consulted on the scalar path at the end.
    let cmp_op = sql_binop_to_scalar_cmp(op);

    // Fast path: plain `column <op> value` without scalar expressions.
    if !expr_needs_scalar(left) && !expr_needs_scalar(right) {
        // If either conversion fails, fall through to the scalar path.
        if let (Ok(column), Ok(value)) = (expr_to_column(left), expr_to_predicate_value(right)) {
            return match op {
                BinaryOperator::Eq => Ok(Predicate::Eq(column, value)),
                BinaryOperator::Lt => Ok(Predicate::Lt(column, value)),
                BinaryOperator::LtEq => Ok(Predicate::Le(column, value)),
                BinaryOperator::Gt => Ok(Predicate::Gt(column, value)),
                BinaryOperator::GtEq => Ok(Predicate::Ge(column, value)),
                BinaryOperator::NotEq => {
                    // `<>` is built as a scalar comparison rather than a
                    // dedicated predicate variant.
                    Ok(Predicate::ScalarCmp {
                        lhs: ScalarExpr::Column(column),
                        op: ScalarCmpOp::NotEq,
                        rhs: predicate_value_to_scalar_expr(&value),
                    })
                }
                other => Err(QueryError::UnsupportedFeature(format!(
                    "unsupported operator: {other:?}"
                ))),
            };
        }
    }

    // General path: compare two scalar expressions.
    let lhs = expr_to_scalar_expr(left)?;
    let rhs = expr_to_scalar_expr(right)?;
    let op = cmp_op.ok_or_else(|| {
        QueryError::UnsupportedFeature(format!("unsupported operator in scalar comparison: {op:?}"))
    })?;
    Ok(Predicate::ScalarCmp { lhs, op, rhs })
}
2991
2992fn expr_needs_scalar(expr: &Expr) -> bool {
2996 match expr {
2997 Expr::Function(_)
2998 | Expr::Cast { .. }
2999 | Expr::BinaryOp {
3000 op: BinaryOperator::StringConcat,
3001 ..
3002 } => true,
3003 Expr::Nested(inner) => expr_needs_scalar(inner),
3004 _ => false,
3005 }
3006}
3007
3008fn sql_binop_to_scalar_cmp(op: &BinaryOperator) -> Option<ScalarCmpOp> {
3009 Some(match op {
3010 BinaryOperator::Eq => ScalarCmpOp::Eq,
3011 BinaryOperator::NotEq => ScalarCmpOp::NotEq,
3012 BinaryOperator::Lt => ScalarCmpOp::Lt,
3013 BinaryOperator::LtEq => ScalarCmpOp::Le,
3014 BinaryOperator::Gt => ScalarCmpOp::Gt,
3015 BinaryOperator::GtEq => ScalarCmpOp::Ge,
3016 _ => return None,
3017 })
3018}
3019
3020fn predicate_value_to_scalar_expr(pv: &PredicateValue) -> ScalarExpr {
3021 match pv {
3022 PredicateValue::Int(n) => ScalarExpr::Literal(Value::BigInt(*n)),
3023 PredicateValue::String(s) => ScalarExpr::Literal(Value::Text(s.clone())),
3024 PredicateValue::Bool(b) => ScalarExpr::Literal(Value::Boolean(*b)),
3025 PredicateValue::Null => ScalarExpr::Literal(Value::Null),
3026 PredicateValue::Param(idx) => ScalarExpr::Literal(Value::Placeholder(*idx)),
3027 PredicateValue::Literal(v) => ScalarExpr::Literal(v.clone()),
3028 PredicateValue::ColumnRef(name) => {
3029 let col = name.rsplit('.').next().unwrap_or(name);
3031 ScalarExpr::Column(ColumnName::new(col.to_string()))
3032 }
3033 }
3034}
3035
3036fn expr_to_column(expr: &Expr) -> Result<ColumnName> {
3037 match expr {
3038 Expr::Identifier(ident) => Ok(ColumnName::new(ident.value.clone())),
3039 Expr::CompoundIdentifier(idents) if idents.len() == 2 => {
3040 Ok(ColumnName::new(idents[1].value.clone()))
3042 }
3043 other => Err(QueryError::UnsupportedFeature(format!(
3044 "expected column name, got {other:?}"
3045 ))),
3046 }
3047}
3048
3049pub fn expr_to_scalar_expr(expr: &Expr) -> Result<ScalarExpr> {
3062 match expr {
3063 Expr::Value(_) | Expr::UnaryOp { .. } => Ok(ScalarExpr::Literal(expr_to_value(expr)?)),
3066
3067 Expr::Identifier(ident) => Ok(ScalarExpr::Column(ColumnName::new(ident.value.clone()))),
3069 Expr::CompoundIdentifier(idents) if idents.len() == 2 => {
3070 Ok(ScalarExpr::Column(ColumnName::new(idents[1].value.clone())))
3071 }
3072
3073 Expr::BinaryOp {
3076 left,
3077 op: BinaryOperator::StringConcat,
3078 right,
3079 } => Ok(ScalarExpr::Concat(vec![
3080 expr_to_scalar_expr(left)?,
3081 expr_to_scalar_expr(right)?,
3082 ])),
3083
3084 Expr::Cast {
3086 expr: inner,
3087 data_type,
3088 ..
3089 } => {
3090 let target = sql_data_type_to_data_type(data_type)?;
3091 Ok(ScalarExpr::Cast(
3092 Box::new(expr_to_scalar_expr(inner)?),
3093 target,
3094 ))
3095 }
3096
3097 Expr::Nested(inner) => expr_to_scalar_expr(inner),
3099
3100 Expr::Function(func) => {
3102 if func.over.is_some() {
3103 return Err(QueryError::UnsupportedFeature(
3104 "window functions are not valid in this position".to_string(),
3105 ));
3106 }
3107 if func.filter.is_some() {
3108 return Err(QueryError::UnsupportedFeature(
3109 "FILTER clause only applies to aggregate functions".to_string(),
3110 ));
3111 }
3112 let name = func.name.to_string().to_uppercase();
3113 let args = match &func.args {
3114 sqlparser::ast::FunctionArguments::List(list) => &list.args,
3115 _ => {
3116 return Err(QueryError::UnsupportedFeature(
3117 "non-list function arguments not supported".to_string(),
3118 ));
3119 }
3120 };
3121
3122 let mut arg_exprs: Vec<&Expr> = Vec::with_capacity(args.len());
3124 for a in args {
3125 match a {
3126 sqlparser::ast::FunctionArg::Unnamed(
3127 sqlparser::ast::FunctionArgExpr::Expr(e),
3128 ) => arg_exprs.push(e),
3129 _ => {
3130 return Err(QueryError::UnsupportedFeature(format!(
3131 "unsupported argument form in scalar function {name}"
3132 )));
3133 }
3134 }
3135 }
3136
3137 let want_arity = |n: usize| -> Result<()> {
3138 if arg_exprs.len() == n {
3139 Ok(())
3140 } else {
3141 Err(QueryError::ParseError(format!(
3142 "{name} expects {n} argument(s), got {}",
3143 arg_exprs.len()
3144 )))
3145 }
3146 };
3147 let scalar = |e: &Expr| expr_to_scalar_expr(e);
3148
3149 match name.as_str() {
3150 "UPPER" => {
3151 want_arity(1)?;
3152 Ok(ScalarExpr::Upper(Box::new(scalar(arg_exprs[0])?)))
3153 }
3154 "LOWER" => {
3155 want_arity(1)?;
3156 Ok(ScalarExpr::Lower(Box::new(scalar(arg_exprs[0])?)))
3157 }
3158 "LENGTH" | "CHAR_LENGTH" | "CHARACTER_LENGTH" => {
3159 want_arity(1)?;
3160 Ok(ScalarExpr::Length(Box::new(scalar(arg_exprs[0])?)))
3161 }
3162 "TRIM" => {
3163 want_arity(1)?;
3164 Ok(ScalarExpr::Trim(Box::new(scalar(arg_exprs[0])?)))
3165 }
3166 "CONCAT" => {
3167 if arg_exprs.is_empty() {
3168 return Err(QueryError::ParseError(
3169 "CONCAT expects at least one argument".to_string(),
3170 ));
3171 }
3172 let parts = arg_exprs
3173 .iter()
3174 .map(|e| scalar(e))
3175 .collect::<Result<Vec<_>>>()?;
3176 Ok(ScalarExpr::Concat(parts))
3177 }
3178 "ABS" => {
3179 want_arity(1)?;
3180 Ok(ScalarExpr::Abs(Box::new(scalar(arg_exprs[0])?)))
3181 }
3182 "ROUND" => match arg_exprs.len() {
3183 1 => Ok(ScalarExpr::Round(Box::new(scalar(arg_exprs[0])?))),
3184 2 => {
3185 let n = match expr_to_value(arg_exprs[1])? {
3187 Value::BigInt(n) => i32::try_from(n).map_err(|_| {
3188 QueryError::ParseError("ROUND scale out of range".to_string())
3189 })?,
3190 other => {
3191 return Err(QueryError::ParseError(format!(
3192 "ROUND scale must be an integer literal, got {other:?}"
3193 )));
3194 }
3195 };
3196 Ok(ScalarExpr::RoundScale(Box::new(scalar(arg_exprs[0])?), n))
3197 }
3198 _ => Err(QueryError::ParseError(format!(
3199 "ROUND expects 1 or 2 arguments, got {}",
3200 arg_exprs.len()
3201 ))),
3202 },
3203 "CEIL" | "CEILING" => {
3204 want_arity(1)?;
3205 Ok(ScalarExpr::Ceil(Box::new(scalar(arg_exprs[0])?)))
3206 }
3207 "FLOOR" => {
3208 want_arity(1)?;
3209 Ok(ScalarExpr::Floor(Box::new(scalar(arg_exprs[0])?)))
3210 }
3211 "COALESCE" => {
3212 if arg_exprs.is_empty() {
3213 return Err(QueryError::ParseError(
3214 "COALESCE expects at least one argument".to_string(),
3215 ));
3216 }
3217 let parts = arg_exprs
3218 .iter()
3219 .map(|e| scalar(e))
3220 .collect::<Result<Vec<_>>>()?;
3221 Ok(ScalarExpr::Coalesce(parts))
3222 }
3223 "NULLIF" => {
3224 want_arity(2)?;
3225 Ok(ScalarExpr::Nullif(
3226 Box::new(scalar(arg_exprs[0])?),
3227 Box::new(scalar(arg_exprs[1])?),
3228 ))
3229 }
3230 "MOD" => {
3232 want_arity(2)?;
3233 Ok(ScalarExpr::Mod(
3234 Box::new(scalar(arg_exprs[0])?),
3235 Box::new(scalar(arg_exprs[1])?),
3236 ))
3237 }
3238 "POWER" | "POW" => {
3239 want_arity(2)?;
3240 Ok(ScalarExpr::Power(
3241 Box::new(scalar(arg_exprs[0])?),
3242 Box::new(scalar(arg_exprs[1])?),
3243 ))
3244 }
3245 "SQRT" => {
3246 want_arity(1)?;
3247 Ok(ScalarExpr::Sqrt(Box::new(scalar(arg_exprs[0])?)))
3248 }
3249 "SUBSTRING" | "SUBSTR" => {
3250 use kimberlite_types::SubstringRange;
3257 match arg_exprs.len() {
3258 2 => {
3259 let start = match expr_to_value(arg_exprs[1])? {
3260 Value::BigInt(n) => n,
3261 Value::Integer(n) => i64::from(n),
3262 other => {
3263 return Err(QueryError::ParseError(format!(
3264 "SUBSTRING start must be an integer literal, got {other:?}"
3265 )));
3266 }
3267 };
3268 Ok(ScalarExpr::Substring(
3269 Box::new(scalar(arg_exprs[0])?),
3270 SubstringRange::from_start(start),
3271 ))
3272 }
3273 3 => {
3274 let start = match expr_to_value(arg_exprs[1])? {
3275 Value::BigInt(n) => n,
3276 Value::Integer(n) => i64::from(n),
3277 other => {
3278 return Err(QueryError::ParseError(format!(
3279 "SUBSTRING start must be an integer literal, got {other:?}"
3280 )));
3281 }
3282 };
3283 let length = match expr_to_value(arg_exprs[2])? {
3284 Value::BigInt(n) => n,
3285 Value::Integer(n) => i64::from(n),
3286 other => {
3287 return Err(QueryError::ParseError(format!(
3288 "SUBSTRING length must be an integer literal, got {other:?}"
3289 )));
3290 }
3291 };
3292 let range = SubstringRange::try_new(start, length)
3293 .map_err(|e| QueryError::ParseError(format!("SUBSTRING: {e}")))?;
3294 Ok(ScalarExpr::Substring(
3295 Box::new(scalar(arg_exprs[0])?),
3296 range,
3297 ))
3298 }
3299 n => Err(QueryError::ParseError(format!(
3300 "SUBSTRING expects 2 or 3 arguments, got {n}"
3301 ))),
3302 }
3303 }
3304 "EXTRACT" => {
3305 use kimberlite_types::DateField;
3312 want_arity(2)?;
3313 let field_name = match expr_to_value(arg_exprs[0])? {
3314 Value::Text(s) => s,
3315 other => {
3316 return Err(QueryError::ParseError(format!(
3317 "EXTRACT field must be a string literal, got {other:?}"
3318 )));
3319 }
3320 };
3321 let field = DateField::parse(&field_name)
3322 .map_err(|e| QueryError::ParseError(format!("EXTRACT: {e}")))?;
3323 Ok(ScalarExpr::Extract(field, Box::new(scalar(arg_exprs[1])?)))
3324 }
3325 "DATE_TRUNC" | "DATETRUNC" => {
3326 use kimberlite_types::DateField;
3327 want_arity(2)?;
3328 let field_name = match expr_to_value(arg_exprs[0])? {
3329 Value::Text(s) => s,
3330 other => {
3331 return Err(QueryError::ParseError(format!(
3332 "DATE_TRUNC field must be a string literal, got {other:?}"
3333 )));
3334 }
3335 };
3336 let field = DateField::parse(&field_name)
3337 .map_err(|e| QueryError::ParseError(format!("DATE_TRUNC: {e}")))?;
3338 if !field.is_truncatable() {
3339 return Err(QueryError::ParseError(format!(
3340 "DATE_TRUNC field {field:?} is not truncatable (use one of YEAR, MONTH, DAY, HOUR, MINUTE, SECOND)"
3341 )));
3342 }
3343 Ok(ScalarExpr::DateTrunc(
3344 field,
3345 Box::new(scalar(arg_exprs[1])?),
3346 ))
3347 }
3348 "NOW" => {
3349 if !arg_exprs.is_empty() {
3350 return Err(QueryError::ParseError(format!(
3351 "NOW expects 0 arguments, got {}",
3352 arg_exprs.len()
3353 )));
3354 }
3355 Ok(ScalarExpr::Now)
3356 }
3357 "CURRENT_TIMESTAMP" => {
3358 if !arg_exprs.is_empty() {
3359 return Err(QueryError::ParseError(format!(
3360 "CURRENT_TIMESTAMP expects 0 arguments, got {}",
3361 arg_exprs.len()
3362 )));
3363 }
3364 Ok(ScalarExpr::CurrentTimestamp)
3365 }
3366 "CURRENT_DATE" => {
3367 if !arg_exprs.is_empty() {
3368 return Err(QueryError::ParseError(format!(
3369 "CURRENT_DATE expects 0 arguments, got {}",
3370 arg_exprs.len()
3371 )));
3372 }
3373 Ok(ScalarExpr::CurrentDate)
3374 }
3375 other => Err(QueryError::UnsupportedFeature(format!(
3376 "scalar function {other} is not supported"
3377 ))),
3378 }
3379 }
3380
3381 other => Err(QueryError::UnsupportedFeature(format!(
3382 "unsupported scalar expression: {other:?}"
3383 ))),
3384 }
3385}
3386
3387fn sql_data_type_to_data_type(sql_ty: &SqlDataType) -> Result<DataType> {
3391 Ok(match sql_ty {
3392 SqlDataType::TinyInt(_) => DataType::TinyInt,
3393 SqlDataType::SmallInt(_) => DataType::SmallInt,
3394 SqlDataType::Int(_) | SqlDataType::Integer(_) => DataType::Integer,
3395 SqlDataType::BigInt(_) => DataType::BigInt,
3396 SqlDataType::Real | SqlDataType::Float(_) | SqlDataType::Double(_) => DataType::Real,
3397 SqlDataType::Text | SqlDataType::Varchar(_) | SqlDataType::String(_) => DataType::Text,
3398 SqlDataType::Boolean | SqlDataType::Bool => DataType::Boolean,
3399 SqlDataType::Date => DataType::Date,
3400 SqlDataType::Time(_, _) => DataType::Time,
3401 SqlDataType::Timestamp(_, _) => DataType::Timestamp,
3402 SqlDataType::Uuid => DataType::Uuid,
3403 SqlDataType::JSON => DataType::Json,
3404 other => {
3405 return Err(QueryError::UnsupportedFeature(format!(
3406 "CAST to {other:?} is not supported"
3407 )));
3408 }
3409 })
3410}
3411
3412fn expr_to_predicate_value(expr: &Expr) -> Result<PredicateValue> {
3413 match expr {
3414 Expr::Identifier(ident) => {
3416 Ok(PredicateValue::ColumnRef(ident.value.clone()))
3418 }
3419 Expr::CompoundIdentifier(idents) if idents.len() == 2 => {
3420 Ok(PredicateValue::ColumnRef(format!(
3422 "{}.{}",
3423 idents[0].value, idents[1].value
3424 )))
3425 }
3426 Expr::Value(vws) => match &vws.value {
3427 SqlValue::Number(n, _) => {
3428 let value = parse_number_literal(n)?;
3429 match value {
3430 Value::BigInt(v) => Ok(PredicateValue::Int(v)),
3431 Value::Decimal(_, _) => Ok(PredicateValue::Literal(value)),
3432 _ => unreachable!("parse_number_literal only returns BigInt or Decimal"),
3433 }
3434 }
3435 SqlValue::SingleQuotedString(s) | SqlValue::DoubleQuotedString(s) => {
3436 Ok(PredicateValue::String(s.clone()))
3437 }
3438 SqlValue::Boolean(b) => Ok(PredicateValue::Bool(*b)),
3439 SqlValue::Null => Ok(PredicateValue::Null),
3440 SqlValue::Placeholder(p) => Ok(PredicateValue::Param(parse_placeholder_index(p)?)),
3441 other => Err(QueryError::UnsupportedFeature(format!(
3442 "unsupported value expression: {other:?}"
3443 ))),
3444 },
3445 Expr::UnaryOp {
3446 op: sqlparser::ast::UnaryOperator::Minus,
3447 expr,
3448 } => {
3449 if let Expr::Value(vws) = expr.as_ref()
3451 && let SqlValue::Number(n, _) = &vws.value
3452 {
3453 let value = parse_number_literal(n)?;
3454 match value {
3455 Value::BigInt(v) => Ok(PredicateValue::Int(-v)),
3456 Value::Decimal(v, scale) => {
3457 Ok(PredicateValue::Literal(Value::Decimal(-v, scale)))
3458 }
3459 _ => unreachable!("parse_number_literal only returns BigInt or Decimal"),
3460 }
3461 } else {
3462 Err(QueryError::UnsupportedFeature(format!(
3463 "unsupported unary minus operand: {expr:?}"
3464 )))
3465 }
3466 }
3467 other => Err(QueryError::UnsupportedFeature(format!(
3468 "unsupported value expression: {other:?}"
3469 ))),
3470 }
3471}
3472
3473fn parse_order_by(order_by: &sqlparser::ast::OrderBy) -> Result<Vec<OrderByClause>> {
3474 use sqlparser::ast::OrderByKind;
3475
3476 let exprs = match &order_by.kind {
3477 OrderByKind::Expressions(exprs) => exprs,
3478 OrderByKind::All(_) => {
3479 return Err(QueryError::UnsupportedFeature(
3480 "ORDER BY ALL is not supported".to_string(),
3481 ));
3482 }
3483 };
3484
3485 let mut clauses = Vec::new();
3486 for expr in exprs {
3487 clauses.push(parse_order_by_expr(expr)?);
3488 }
3489
3490 Ok(clauses)
3491}
3492
3493fn parse_order_by_expr(expr: &OrderByExpr) -> Result<OrderByClause> {
3494 let column = match &expr.expr {
3495 Expr::Identifier(ident) => ColumnName::new(ident.value.clone()),
3496 other => {
3497 return Err(QueryError::UnsupportedFeature(format!(
3498 "unsupported ORDER BY expression: {other:?}"
3499 )));
3500 }
3501 };
3502
3503 let ascending = expr.options.asc.unwrap_or(true);
3504
3505 Ok(OrderByClause { column, ascending })
3506}
3507
3508fn parse_limit(limit: Option<&Expr>) -> Result<Option<LimitExpr>> {
3509 match limit {
3510 None => Ok(None),
3511 Some(Expr::Value(vws)) => match &vws.value {
3512 SqlValue::Number(n, _) => {
3513 let v: usize = n
3514 .parse()
3515 .map_err(|_| QueryError::ParseError(format!("invalid LIMIT value: {n}")))?;
3516 Ok(Some(LimitExpr::Literal(v)))
3517 }
3518 SqlValue::Placeholder(p) => Ok(Some(LimitExpr::Param(parse_placeholder_index(p)?))),
3519 other => Err(QueryError::UnsupportedFeature(format!(
3520 "LIMIT must be an integer literal or parameter; got {other:?}"
3521 ))),
3522 },
3523 Some(other) => Err(QueryError::UnsupportedFeature(format!(
3524 "LIMIT must be an integer literal or parameter; got {other:?}"
3525 ))),
3526 }
3527}
3528
3529fn query_limit_expr(query: &Query) -> Result<Option<&Expr>> {
3532 use sqlparser::ast::LimitClause;
3533 match &query.limit_clause {
3534 None => Ok(None),
3535 Some(LimitClause::LimitOffset { limit, .. }) => Ok(limit.as_ref()),
3536 Some(LimitClause::OffsetCommaLimit { .. }) => Err(QueryError::UnsupportedFeature(
3537 "MySQL-style `LIMIT <offset>, <limit>` is not supported".to_string(),
3538 )),
3539 }
3540}
3541
3542fn query_offset(query: &Query) -> Option<&sqlparser::ast::Offset> {
3546 use sqlparser::ast::LimitClause;
3547 match &query.limit_clause {
3548 Some(LimitClause::LimitOffset { offset, .. }) => offset.as_ref(),
3549 _ => None,
3550 }
3551}
3552
3553fn parse_offset_clause(offset: Option<&sqlparser::ast::Offset>) -> Result<Option<LimitExpr>> {
3556 let Some(off) = offset else { return Ok(None) };
3557 match &off.value {
3558 Expr::Value(vws) => match &vws.value {
3559 SqlValue::Number(n, _) => {
3560 let v: usize = n
3561 .parse()
3562 .map_err(|_| QueryError::ParseError(format!("invalid OFFSET value: {n}")))?;
3563 Ok(Some(LimitExpr::Literal(v)))
3564 }
3565 SqlValue::Placeholder(p) => Ok(Some(LimitExpr::Param(parse_placeholder_index(p)?))),
3566 other => Err(QueryError::UnsupportedFeature(format!(
3567 "OFFSET must be an integer literal or parameter; got {other:?}"
3568 ))),
3569 },
3570 other => Err(QueryError::UnsupportedFeature(format!(
3571 "OFFSET must be an integer literal or parameter; got {other:?}"
3572 ))),
3573 }
3574}
3575
3576fn parse_placeholder_index(placeholder: &str) -> Result<usize> {
3582 let num_str = placeholder.strip_prefix('$').ok_or_else(|| {
3583 QueryError::ParseError(format!("unsupported placeholder format: {placeholder}"))
3584 })?;
3585 let idx: usize = num_str.parse().map_err(|_| {
3586 QueryError::ParseError(format!("invalid parameter placeholder: {placeholder}"))
3587 })?;
3588 if idx == 0 {
3589 return Err(QueryError::ParseError(
3590 "parameter indices start at $1, not $0".to_string(),
3591 ));
3592 }
3593 Ok(idx)
3594}
3595
3596fn object_name_to_string(name: &ObjectName) -> String {
3597 name.0
3598 .iter()
3599 .map(|part| match part.as_ident() {
3600 Some(ident) => ident.value.clone(),
3601 None => part.to_string(),
3602 })
3603 .collect::<Vec<_>>()
3604 .join(".")
3605}
3606
3607fn parse_create_table(create_table: &sqlparser::ast::CreateTable) -> Result<ParsedCreateTable> {
3612 let table_name = object_name_to_string(&create_table.name);
3613
3614 let mut raw_columns = Vec::new();
3624 for col_def in &create_table.columns {
3625 let parsed_col = parse_column_def(col_def)?;
3626 raw_columns.push(parsed_col);
3627 }
3628 let columns = NonEmptyVec::try_new(raw_columns).map_err(|_| {
3629 crate::error::QueryError::ParseError(format!(
3630 "CREATE TABLE {table_name} requires at least one column"
3631 ))
3632 })?;
3633
3634 let mut primary_key = Vec::new();
3636 for constraint in &create_table.constraints {
3637 if let sqlparser::ast::TableConstraint::PrimaryKey(pk) = constraint {
3638 for col in &pk.columns {
3639 if let Expr::Identifier(ident) = &col.column.expr {
3640 primary_key.push(ident.value.clone());
3641 } else {
3642 primary_key.push(col.column.expr.to_string());
3643 }
3644 }
3645 }
3646 }
3647
3648 if primary_key.is_empty() {
3650 for col_def in &create_table.columns {
3651 for option in &col_def.options {
3652 if matches!(&option.option, sqlparser::ast::ColumnOption::PrimaryKey(_)) {
3653 primary_key.push(col_def.name.value.clone());
3654 }
3655 }
3656 }
3657 }
3658
3659 Ok(ParsedCreateTable {
3660 table_name,
3661 columns,
3662 primary_key,
3663 if_not_exists: create_table.if_not_exists,
3664 })
3665}
3666
3667fn parse_column_def(col_def: &SqlColumnDef) -> Result<ParsedColumn> {
3668 let name = col_def.name.value.clone();
3669
3670 let data_type = match &col_def.data_type {
3673 SqlDataType::TinyInt(_) => "TINYINT".to_string(),
3675 SqlDataType::SmallInt(_) => "SMALLINT".to_string(),
3676 SqlDataType::Int(_) | SqlDataType::Integer(_) => "INTEGER".to_string(),
3677 SqlDataType::BigInt(_) => "BIGINT".to_string(),
3678
3679 SqlDataType::Real | SqlDataType::Float(_) | SqlDataType::Double(_) => "REAL".to_string(),
3681 SqlDataType::Decimal(precision_opt) => match precision_opt {
3682 sqlparser::ast::ExactNumberInfo::PrecisionAndScale(p, s) => {
3683 format!("DECIMAL({p},{s})")
3684 }
3685 sqlparser::ast::ExactNumberInfo::Precision(p) => {
3686 format!("DECIMAL({p},0)")
3687 }
3688 sqlparser::ast::ExactNumberInfo::None => "DECIMAL(18,2)".to_string(),
3689 },
3690
3691 SqlDataType::Text | SqlDataType::Varchar(_) | SqlDataType::String(_) => "TEXT".to_string(),
3693
3694 SqlDataType::Binary(_) | SqlDataType::Varbinary(_) | SqlDataType::Blob(_) => {
3696 "BYTES".to_string()
3697 }
3698
3699 SqlDataType::Boolean | SqlDataType::Bool => "BOOLEAN".to_string(),
3701
3702 SqlDataType::Date => "DATE".to_string(),
3704 SqlDataType::Time(_, _) => "TIME".to_string(),
3705 SqlDataType::Timestamp(_, _) => "TIMESTAMP".to_string(),
3706
3707 SqlDataType::Uuid => "UUID".to_string(),
3709 SqlDataType::JSON => "JSON".to_string(),
3710
3711 other => {
3712 return Err(QueryError::UnsupportedFeature(format!(
3713 "unsupported data type: {other:?}"
3714 )));
3715 }
3716 };
3717
3718 let mut nullable = true;
3720 for option in &col_def.options {
3721 if matches!(option.option, sqlparser::ast::ColumnOption::NotNull) {
3722 nullable = false;
3723 }
3724 }
3725
3726 Ok(ParsedColumn {
3727 name,
3728 data_type,
3729 nullable,
3730 })
3731}
3732
3733fn parse_alter_table(
3734 name: &sqlparser::ast::ObjectName,
3735 operations: &[sqlparser::ast::AlterTableOperation],
3736) -> Result<ParsedAlterTable> {
3737 let table_name = object_name_to_string(name);
3738
3739 if operations.len() != 1 {
3741 return Err(QueryError::UnsupportedFeature(
3742 "ALTER TABLE supports only one operation at a time".to_string(),
3743 ));
3744 }
3745
3746 let operation = match &operations[0] {
3747 sqlparser::ast::AlterTableOperation::AddColumn { column_def, .. } => {
3748 let parsed_col = parse_column_def(column_def)?;
3749 AlterTableOperation::AddColumn(parsed_col)
3750 }
3751 sqlparser::ast::AlterTableOperation::DropColumn {
3752 column_names,
3753 if_exists: _,
3754 ..
3755 } => {
3756 if column_names.len() != 1 {
3757 return Err(QueryError::UnsupportedFeature(
3758 "ALTER TABLE DROP COLUMN supports exactly one column".to_string(),
3759 ));
3760 }
3761 let col_name = column_names[0].value.clone();
3762 AlterTableOperation::DropColumn(col_name)
3763 }
3764 other => {
3765 return Err(QueryError::UnsupportedFeature(format!(
3766 "ALTER TABLE operation not supported: {other:?}"
3767 )));
3768 }
3769 };
3770
3771 Ok(ParsedAlterTable {
3772 table_name,
3773 operation,
3774 })
3775}
3776
3777fn parse_create_index(create_index: &sqlparser::ast::CreateIndex) -> Result<ParsedCreateIndex> {
3778 let index_name = match &create_index.name {
3779 Some(name) => object_name_to_string(name),
3780 None => {
3781 return Err(QueryError::ParseError(
3782 "CREATE INDEX requires an index name".to_string(),
3783 ));
3784 }
3785 };
3786
3787 let table_name = object_name_to_string(&create_index.table_name);
3788
3789 let mut columns = Vec::new();
3790 for col in &create_index.columns {
3791 columns.push(col.column.expr.to_string());
3792 }
3793
3794 Ok(ParsedCreateIndex {
3795 index_name,
3796 table_name,
3797 columns,
3798 })
3799}
3800
3801fn parse_insert(insert: &sqlparser::ast::Insert) -> Result<ParsedInsert> {
3806 let table = insert.table.to_string();
3808
3809 let columns: Vec<String> = insert.columns.iter().map(|c| c.value.clone()).collect();
3811
3812 let values = match insert.source.as_ref().map(|s| s.body.as_ref()) {
3814 Some(SetExpr::Values(values)) => {
3815 let mut all_rows = Vec::new();
3816 for row in &values.rows {
3817 let mut parsed_row = Vec::new();
3818 for expr in row {
3819 let val = expr_to_value(expr)?;
3820 parsed_row.push(val);
3821 }
3822 all_rows.push(parsed_row);
3823 }
3824 all_rows
3825 }
3826 _ => {
3827 return Err(QueryError::UnsupportedFeature(
3828 "only VALUES clause is supported in INSERT".to_string(),
3829 ));
3830 }
3831 };
3832
3833 let returning = parse_returning(insert.returning.as_ref())?;
3835
3836 let on_conflict = match insert.on.as_ref() {
3841 None => None,
3842 Some(sqlparser::ast::OnInsert::OnConflict(oc)) => Some(parse_on_conflict(oc)?),
3843 Some(sqlparser::ast::OnInsert::DuplicateKeyUpdate(_)) => {
3844 return Err(QueryError::UnsupportedFeature(
3845 "ON DUPLICATE KEY UPDATE is not supported; use ON CONFLICT (cols) DO UPDATE"
3846 .to_string(),
3847 ));
3848 }
3849 Some(other) => {
3850 return Err(QueryError::UnsupportedFeature(format!(
3851 "unsupported ON clause on INSERT: {other:?}"
3852 )));
3853 }
3854 };
3855
3856 Ok(ParsedInsert {
3857 table,
3858 columns,
3859 values,
3860 returning,
3861 on_conflict,
3862 })
3863}
3864
3865fn parse_on_conflict(oc: &sqlparser::ast::OnConflict) -> Result<OnConflictClause> {
3872 let target = match oc.conflict_target.as_ref() {
3874 Some(sqlparser::ast::ConflictTarget::Columns(cols)) => {
3875 if cols.is_empty() {
3876 return Err(QueryError::ParseError(
3877 "ON CONFLICT requires at least one target column".to_string(),
3878 ));
3879 }
3880 cols.iter().map(|i| i.value.clone()).collect()
3881 }
3882 Some(sqlparser::ast::ConflictTarget::OnConstraint(_)) => {
3883 return Err(QueryError::UnsupportedFeature(
3884 "ON CONFLICT ON CONSTRAINT <name> is not supported; use ON CONFLICT (cols) instead"
3885 .to_string(),
3886 ));
3887 }
3888 None => {
3889 return Err(QueryError::UnsupportedFeature(
3890 "ON CONFLICT without a target column list is not supported".to_string(),
3891 ));
3892 }
3893 };
3894
3895 let action = match &oc.action {
3896 sqlparser::ast::OnConflictAction::DoNothing => OnConflictAction::DoNothing,
3897 sqlparser::ast::OnConflictAction::DoUpdate(du) => {
3898 if du.selection.is_some() {
3899 return Err(QueryError::UnsupportedFeature(
3900 "ON CONFLICT DO UPDATE WHERE ... is not yet supported".to_string(),
3901 ));
3902 }
3903 let mut assignments = Vec::with_capacity(du.assignments.len());
3904 for a in &du.assignments {
3905 let col = a.target.to_string();
3906 let rhs = parse_upsert_expr(&a.value)?;
3907 assignments.push((col, rhs));
3908 }
3909 OnConflictAction::DoUpdate { assignments }
3910 }
3911 };
3912
3913 Ok(OnConflictClause { target, action })
3914}
3915
3916fn parse_upsert_expr(expr: &Expr) -> Result<UpsertExpr> {
3923 if let Expr::CompoundIdentifier(parts) = expr {
3925 if parts.len() == 2 && parts[0].value.eq_ignore_ascii_case("EXCLUDED") {
3926 return Ok(UpsertExpr::Excluded(parts[1].value.clone()));
3927 }
3928 }
3929 let v = expr_to_value(expr)?;
3932 Ok(UpsertExpr::Value(v))
3933}
3934
3935fn parse_update(
3936 table: &sqlparser::ast::TableWithJoins,
3937 assignments: &[sqlparser::ast::Assignment],
3938 selection: Option<&Expr>,
3939 returning: Option<&Vec<SelectItem>>,
3940) -> Result<ParsedUpdate> {
3941 let table_name = match &table.relation {
3942 sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
3943 other => {
3944 return Err(QueryError::UnsupportedFeature(format!(
3945 "unsupported table in UPDATE: {other:?}"
3946 )));
3947 }
3948 };
3949
3950 let mut parsed_assignments = Vec::new();
3952 for assignment in assignments {
3953 let col_name = assignment.target.to_string();
3954 let value = expr_to_value(&assignment.value)?;
3955 parsed_assignments.push((col_name, value));
3956 }
3957
3958 let predicates = match selection {
3960 Some(expr) => parse_where_expr(expr)?,
3961 None => vec![],
3962 };
3963
3964 let returning_cols = parse_returning(returning)?;
3966
3967 Ok(ParsedUpdate {
3968 table: table_name,
3969 assignments: parsed_assignments,
3970 predicates,
3971 returning: returning_cols,
3972 })
3973}
3974
3975fn parse_delete_stmt(delete: &sqlparser::ast::Delete) -> Result<ParsedDelete> {
3976 use sqlparser::ast::FromTable;
3978
3979 let table_name = match &delete.from {
3980 FromTable::WithFromKeyword(tables) => {
3981 if tables.len() != 1 {
3982 return Err(QueryError::ParseError(
3983 "expected exactly 1 table in DELETE FROM".to_string(),
3984 ));
3985 }
3986
3987 match &tables[0].relation {
3988 sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
3989 _ => {
3990 return Err(QueryError::ParseError(
3991 "DELETE only supports simple table names".to_string(),
3992 ));
3993 }
3994 }
3995 }
3996 FromTable::WithoutKeyword(tables) => {
3997 if tables.len() != 1 {
3998 return Err(QueryError::ParseError(
3999 "expected exactly 1 table in DELETE".to_string(),
4000 ));
4001 }
4002
4003 match &tables[0].relation {
4004 sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
4005 _ => {
4006 return Err(QueryError::ParseError(
4007 "DELETE only supports simple table names".to_string(),
4008 ));
4009 }
4010 }
4011 }
4012 };
4013
4014 let predicates = match &delete.selection {
4016 Some(expr) => parse_where_expr(expr)?,
4017 None => vec![],
4018 };
4019
4020 let returning_cols = parse_returning(delete.returning.as_ref())?;
4022
4023 Ok(ParsedDelete {
4024 table: table_name,
4025 predicates,
4026 returning: returning_cols,
4027 })
4028}
4029
4030fn parse_returning(returning: Option<&Vec<SelectItem>>) -> Result<Option<Vec<String>>> {
4032 match returning {
4033 None => Ok(None),
4034 Some(items) => {
4035 let mut columns = Vec::new();
4036 for item in items {
4037 match item {
4038 SelectItem::UnnamedExpr(Expr::Identifier(ident)) => {
4039 columns.push(ident.value.clone());
4040 }
4041 SelectItem::UnnamedExpr(Expr::CompoundIdentifier(parts)) => {
4042 if let Some(last) = parts.last() {
4044 columns.push(last.value.clone());
4045 } else {
4046 return Err(QueryError::ParseError(
4047 "invalid column in RETURNING clause".to_string(),
4048 ));
4049 }
4050 }
4051 _ => {
4052 return Err(QueryError::UnsupportedFeature(
4053 "only simple column names supported in RETURNING clause".to_string(),
4054 ));
4055 }
4056 }
4057 }
4058 Ok(Some(columns))
4059 }
4060 }
4061}
4062
4063fn parse_number_literal(n: &str) -> Result<Value> {
4067 use rust_decimal::Decimal;
4068 use std::str::FromStr;
4069
4070 if n.contains('.') {
4071 let decimal = Decimal::from_str(n)
4073 .map_err(|e| QueryError::ParseError(format!("invalid decimal '{n}': {e}")))?;
4074
4075 let scale = decimal.scale() as u8;
4077
4078 if scale > 38 {
4079 return Err(QueryError::ParseError(format!(
4080 "decimal scale too large (max 38): {n}"
4081 )));
4082 }
4083
4084 let mantissa = decimal.mantissa();
4087
4088 Ok(Value::Decimal(mantissa, scale))
4089 } else {
4090 let v: i64 = n
4092 .parse()
4093 .map_err(|_| QueryError::ParseError(format!("invalid integer: {n}")))?;
4094 Ok(Value::BigInt(v))
4095 }
4096}
4097
4098fn expr_to_value(expr: &Expr) -> Result<Value> {
4100 match expr {
4101 Expr::Value(vws) => match &vws.value {
4102 SqlValue::Number(n, _) => parse_number_literal(n),
4103 SqlValue::SingleQuotedString(s) | SqlValue::DoubleQuotedString(s) => {
4104 Ok(Value::Text(s.clone()))
4105 }
4106 SqlValue::Boolean(b) => Ok(Value::Boolean(*b)),
4107 SqlValue::Null => Ok(Value::Null),
4108 SqlValue::Placeholder(p) => Ok(Value::Placeholder(parse_placeholder_index(p)?)),
4109 other => Err(QueryError::UnsupportedFeature(format!(
4110 "unsupported value expression: {other:?}"
4111 ))),
4112 },
4113 Expr::UnaryOp {
4114 op: sqlparser::ast::UnaryOperator::Minus,
4115 expr,
4116 } => {
4117 if let Expr::Value(vws) = expr.as_ref()
4119 && let SqlValue::Number(n, _) = &vws.value
4120 {
4121 let value = parse_number_literal(n)?;
4122 match value {
4123 Value::BigInt(v) => Ok(Value::BigInt(-v)),
4124 Value::Decimal(v, scale) => Ok(Value::Decimal(-v, scale)),
4125 _ => unreachable!("parse_number_literal only returns BigInt or Decimal"),
4126 }
4127 } else {
4128 Err(QueryError::UnsupportedFeature(format!(
4129 "unsupported unary minus operand: {expr:?}"
4130 )))
4131 }
4132 }
4133 other => Err(QueryError::UnsupportedFeature(format!(
4134 "unsupported value expression: {other:?}"
4135 ))),
4136 }
4137}
4138
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `sql` and returns the SELECT payload, panicking on any other
    /// statement kind or on a parse error.
    fn parse_test_select(sql: &str) -> ParsedSelect {
        let ParsedStatement::Select(select) = parse_statement(sql).unwrap() else {
            panic!("expected SELECT statement");
        };
        select
    }

    /// Parses `sql` and returns the CREATE MASKING POLICY payload, panicking
    /// on any other statement kind or on a parse error.
    fn parse_test_masking_policy(sql: &str) -> ParsedCreateMaskingPolicy {
        match parse_statement(sql).unwrap() {
            ParsedStatement::CreateMaskingPolicy(p) => p,
            other => panic!("expected CreateMaskingPolicy, got {other:?}"),
        }
    }

    /// An explicit column list is captured in order, with no predicates.
    #[test]
    fn test_parse_simple_select() {
        let result = parse_test_select("SELECT id, name FROM users");
        let expected_columns = vec![ColumnName::new("id"), ColumnName::new("name")];
        assert_eq!(result.table, "users");
        assert_eq!(result.columns, Some(expected_columns));
        assert!(result.predicates.is_empty());
    }

    /// `SELECT *` is represented by `columns == None`.
    #[test]
    fn test_parse_select_star() {
        let result = parse_test_select("SELECT * FROM users");
        assert_eq!(result.table, "users");
        assert!(result.columns.is_none());
    }

    /// Integer equality in WHERE yields a single `Eq` predicate.
    #[test]
    fn test_parse_where_eq() {
        let result = parse_test_select("SELECT * FROM users WHERE id = 42");
        assert_eq!(result.predicates.len(), 1);
        match &result.predicates[0] {
            Predicate::Eq(col, PredicateValue::Int(42)) => assert_eq!(col.as_str(), "id"),
            other => panic!("unexpected predicate: {other:?}"),
        }
    }

    /// String equality in WHERE carries the unquoted string value.
    #[test]
    fn test_parse_where_string() {
        let result = parse_test_select("SELECT * FROM users WHERE name = 'alice'");
        match &result.predicates[0] {
            Predicate::Eq(col, PredicateValue::String(s)) => {
                assert_eq!(col.as_str(), "name");
                assert_eq!(s, "alice");
            }
            other => panic!("unexpected predicate: {other:?}"),
        }
    }

    /// Two conditions joined by AND produce two predicates.
    #[test]
    fn test_parse_where_and() {
        let result = parse_test_select("SELECT * FROM users WHERE id = 1 AND name = 'bob'");
        assert_eq!(result.predicates.len(), 2);
    }

    /// `IN (...)` yields an `In` predicate holding every listed value.
    #[test]
    fn test_parse_where_in() {
        let result = parse_test_select("SELECT * FROM users WHERE id IN (1, 2, 3)");
        match &result.predicates[0] {
            Predicate::In(col, values) => {
                assert_eq!(col.as_str(), "id");
                assert_eq!(values.len(), 3);
            }
            other => panic!("unexpected predicate: {other:?}"),
        }
    }

    /// ORDER BY keeps column order and the ASC/DESC direction of each key.
    #[test]
    fn test_parse_order_by() {
        let result = parse_test_select("SELECT * FROM users ORDER BY name ASC, id DESC");
        assert_eq!(result.order_by.len(), 2);

        assert_eq!(result.order_by[0].column.as_str(), "name");
        assert!(result.order_by[0].ascending);

        assert_eq!(result.order_by[1].column.as_str(), "id");
        assert!(!result.order_by[1].ascending);
    }

    /// A literal LIMIT is parsed as `LimitExpr::Literal`.
    #[test]
    fn test_parse_limit() {
        let result = parse_test_select("SELECT * FROM users LIMIT 10");
        assert_eq!(result.limit, Some(LimitExpr::Literal(10)));
    }

    /// A placeholder LIMIT is parsed as `LimitExpr::Param`.
    #[test]
    fn test_parse_limit_param() {
        let result = parse_test_select("SELECT * FROM users LIMIT $1");
        assert_eq!(result.limit, Some(LimitExpr::Param(1)));
    }

    /// Literal LIMIT and OFFSET are both captured.
    #[test]
    fn test_parse_offset_literal() {
        let result = parse_test_select("SELECT * FROM users LIMIT 10 OFFSET 5");
        assert_eq!(result.limit, Some(LimitExpr::Literal(10)));
        assert_eq!(result.offset, Some(LimitExpr::Literal(5)));
    }

    /// Placeholder LIMIT and OFFSET keep their parameter indices.
    #[test]
    fn test_parse_offset_param() {
        let result = parse_test_select("SELECT * FROM users LIMIT $1 OFFSET $2");
        assert_eq!(result.limit, Some(LimitExpr::Param(1)));
        assert_eq!(result.offset, Some(LimitExpr::Param(2)));
    }

    /// A `$1` placeholder on the right of `=` becomes `PredicateValue::Param(1)`.
    #[test]
    fn test_parse_param() {
        let result = parse_test_select("SELECT * FROM users WHERE id = $1");
        match &result.predicates[0] {
            Predicate::Eq(_, PredicateValue::Param(1)) => {}
            other => panic!("unexpected predicate: {other:?}"),
        }
    }

    /// A bare `JOIN` is recorded as an inner join on the joined table.
    #[test]
    fn test_parse_inner_join() {
        let result =
            parse_statement("SELECT * FROM users JOIN orders ON users.id = orders.user_id");
        // Surface the parser error in the test output before the assertion fires.
        if let Err(e) = &result {
            eprintln!("Parse error: {e:?}");
        }
        assert!(result.is_ok());

        let ParsedStatement::Select(s) = result.unwrap() else {
            panic!("expected SELECT statement");
        };
        assert_eq!(s.table, "users");
        assert_eq!(s.joins.len(), 1);
        assert_eq!(s.joins[0].table, "orders");
        assert!(matches!(s.joins[0].join_type, JoinType::Inner));
    }

    /// `LEFT JOIN` is recorded with `JoinType::Left`.
    #[test]
    fn test_parse_left_join() {
        let result =
            parse_statement("SELECT * FROM users LEFT JOIN orders ON users.id = orders.user_id");
        assert!(result.is_ok());

        let ParsedStatement::Select(s) = result.unwrap() else {
            panic!("expected SELECT statement");
        };
        assert_eq!(s.table, "users");
        assert_eq!(s.joins.len(), 1);
        assert_eq!(s.joins[0].table, "orders");
        assert!(matches!(s.joins[0].join_type, JoinType::Left));
    }

    /// Chained joins are kept in source order.
    #[test]
    fn test_parse_multi_join() {
        let result = parse_statement(
            "SELECT * FROM users \
             JOIN orders ON users.id = orders.user_id \
             JOIN products ON orders.product_id = products.id",
        );
        assert!(result.is_ok());

        let ParsedStatement::Select(s) = result.unwrap() else {
            panic!("expected SELECT statement");
        };
        assert_eq!(s.table, "users");
        assert_eq!(s.joins.len(), 2);
        assert_eq!(s.joins[0].table, "orders");
        assert_eq!(s.joins[1].table, "products");
    }

    /// A subquery in FROM is rejected.
    #[test]
    fn test_reject_subquery() {
        let result = parse_statement("SELECT * FROM (SELECT * FROM users)");
        assert!(result.is_err());
    }

    /// Ten parenthesised conditions joined by AND parse successfully.
    #[test]
    fn test_where_depth_within_limit() {
        let conditions: Vec<String> = (0..10).map(|i| format!("(id = {i})")).collect();
        let sql = format!("SELECT * FROM users WHERE {}", conditions.join(" AND "));

        let result = parse_statement(&sql);
        assert!(
            result.is_ok(),
            "Moderate nesting should succeed, but got: {result:?}"
        );
    }

    /// A condition wrapped in 200 levels of parentheses is rejected.
    #[test]
    fn test_where_depth_nested_parens() {
        let depth = 200;
        let sql = format!(
            "SELECT * FROM users WHERE {}id = 1{}",
            "(".repeat(depth),
            ")".repeat(depth)
        );

        let result = parse_statement(&sql);
        assert!(
            result.is_err(),
            "Excessive parenthesis nesting should be rejected"
        );
    }

    /// A moderately nested mix of AND/OR groups parses successfully.
    #[test]
    fn test_where_depth_complex_and_or() {
        let sql = "SELECT * FROM users WHERE \
                   ((id = 1 AND name = 'a') OR (id = 2 AND name = 'b')) AND \
                   ((age > 10 AND age < 20) OR (age > 30 AND age < 40))";

        let result = parse_statement(sql);
        assert!(result.is_ok(), "Complex AND/OR should succeed");
    }

    /// HAVING on `COUNT(*)` yields one aggregate comparison.
    #[test]
    fn test_parse_having() {
        let result =
            parse_test_select("SELECT name, COUNT(*) FROM users GROUP BY name HAVING COUNT(*) > 5");
        assert_eq!(result.group_by.len(), 1);
        assert_eq!(result.having.len(), 1);

        let HavingCondition::AggregateComparison {
            aggregate,
            op,
            value,
        } = &result.having[0];
        assert!(matches!(aggregate, AggregateFunction::CountStar));
        assert_eq!(*op, HavingOp::Gt);
        assert_eq!(*value, Value::BigInt(5));
    }

    /// Two HAVING conditions joined by AND are both captured.
    #[test]
    fn test_parse_having_multiple() {
        let result = parse_test_select(
            "SELECT name, COUNT(*), SUM(age) FROM users GROUP BY name HAVING COUNT(*) > 1 AND SUM(age) < 100",
        );
        assert_eq!(result.having.len(), 2);
    }

    /// HAVING is accepted even when there is no GROUP BY clause.
    #[test]
    fn test_parse_having_without_group_by() {
        let result = parse_test_select("SELECT COUNT(*) FROM users HAVING COUNT(*) > 0");
        assert!(result.group_by.is_empty());
        assert_eq!(result.having.len(), 1);
    }

    /// Plain UNION parses into left/right selects with `all == false`.
    #[test]
    fn test_parse_union() {
        let result = parse_statement("SELECT id FROM users UNION SELECT id FROM orders");
        assert!(result.is_ok());

        let ParsedStatement::Union(u) = result.unwrap() else {
            panic!("expected UNION statement");
        };
        assert_eq!(u.left.table, "users");
        assert_eq!(u.right.table, "orders");
        assert!(!u.all);
    }

    /// UNION ALL sets the `all` flag.
    #[test]
    fn test_parse_union_all() {
        let result = parse_statement("SELECT id FROM users UNION ALL SELECT id FROM orders");
        assert!(result.is_ok());

        let ParsedStatement::Union(u) = result.unwrap() else {
            panic!("expected UNION ALL statement");
        };
        assert_eq!(u.left.table, "users");
        assert_eq!(u.right.table, "orders");
        assert!(u.all);
    }

    /// CREATE MASK captures mask name, table, column and strategy.
    #[test]
    fn test_parse_create_mask() {
        let result = parse_statement("CREATE MASK ssn_mask ON patients.ssn USING REDACT").unwrap();
        let ParsedStatement::CreateMask(m) = result else {
            panic!("expected CREATE MASK statement");
        };
        assert_eq!(m.mask_name, "ssn_mask");
        assert_eq!(m.table_name, "patients");
        assert_eq!(m.column_name, "ssn");
        assert_eq!(m.strategy, "REDACT");
    }

    /// A trailing semicolon does not leak into the CREATE MASK strategy.
    #[test]
    fn test_parse_create_mask_with_semicolon() {
        let result = parse_statement("CREATE MASK ssn_mask ON patients.ssn USING REDACT;").unwrap();
        let ParsedStatement::CreateMask(m) = result else {
            panic!("expected CREATE MASK statement");
        };
        assert_eq!(m.mask_name, "ssn_mask");
        assert_eq!(m.strategy, "REDACT");
    }

    /// CREATE MASK accepts the HASH strategy.
    #[test]
    fn test_parse_create_mask_hash_strategy() {
        let result = parse_statement("CREATE MASK email_hash ON users.email USING HASH").unwrap();
        let ParsedStatement::CreateMask(m) = result else {
            panic!("expected CREATE MASK statement");
        };
        assert_eq!(m.mask_name, "email_hash");
        assert_eq!(m.table_name, "users");
        assert_eq!(m.column_name, "email");
        assert_eq!(m.strategy, "HASH");
    }

    /// CREATE MASK without the `ON` keyword is rejected.
    #[test]
    fn test_parse_create_mask_missing_on() {
        let result = parse_statement("CREATE MASK ssn_mask patients.ssn USING REDACT");
        assert!(result.is_err());
    }

    /// CREATE MASK requires a `table.column` target.
    #[test]
    fn test_parse_create_mask_missing_dot() {
        let result = parse_statement("CREATE MASK ssn_mask ON patients_ssn USING REDACT");
        assert!(result.is_err());
    }

    /// DROP MASK captures the mask name.
    #[test]
    fn test_parse_drop_mask() {
        let result = parse_statement("DROP MASK ssn_mask").unwrap();
        let ParsedStatement::DropMask(name) = result else {
            panic!("expected DROP MASK statement");
        };
        assert_eq!(name, "ssn_mask");
    }

    /// A trailing semicolon is not part of the dropped mask's name.
    #[test]
    fn test_parse_drop_mask_with_semicolon() {
        let result = parse_statement("DROP MASK ssn_mask;").unwrap();
        let ParsedStatement::DropMask(name) = result else {
            panic!("expected DROP MASK statement");
        };
        assert_eq!(name, "ssn_mask");
    }

    /// REDACT_SSN policy with two quoted exempt roles.
    #[test]
    fn test_parse_create_masking_policy_redact_ssn() {
        let p = parse_test_masking_policy(
            "CREATE MASKING POLICY ssn_policy STRATEGY REDACT_SSN EXEMPT ROLES ('clinician', 'billing')",
        );
        assert_eq!(p.name, "ssn_policy");
        assert_eq!(p.strategy, ParsedMaskingStrategy::RedactSsn);
        assert_eq!(p.exempt_roles, vec!["clinician", "billing"]);
    }

    /// HASH policy with a single, unquoted exempt role.
    #[test]
    fn test_parse_create_masking_policy_hash_single_role() {
        let p = parse_test_masking_policy(
            "CREATE MASKING POLICY h STRATEGY HASH EXEMPT ROLES (admin)",
        );
        assert_eq!(p.name, "h");
        assert_eq!(p.strategy, ParsedMaskingStrategy::Hash);
        assert_eq!(p.exempt_roles, vec!["admin"]);
    }

    /// TOKENIZE policy, with a trailing semicolon.
    #[test]
    fn test_parse_create_masking_policy_tokenize() {
        let p = parse_test_masking_policy(
            "CREATE MASKING POLICY note_tok STRATEGY TOKENIZE EXEMPT ROLES ('clinician');",
        );
        assert_eq!(p.strategy, ParsedMaskingStrategy::Tokenize);
        assert_eq!(p.exempt_roles, vec!["clinician"]);
    }

    /// TRUNCATE takes a numeric argument that becomes `max_chars`.
    #[test]
    fn test_parse_create_masking_policy_truncate_with_arg() {
        let p = parse_test_masking_policy(
            "CREATE MASKING POLICY tr STRATEGY TRUNCATE 4 EXEMPT ROLES ('billing')",
        );
        assert_eq!(p.strategy, ParsedMaskingStrategy::Truncate { max_chars: 4 });
    }

    /// REDACT_CUSTOM takes a quoted replacement string.
    #[test]
    fn test_parse_create_masking_policy_redact_custom() {
        let p = parse_test_masking_policy(
            "CREATE MASKING POLICY c STRATEGY REDACT_CUSTOM '***' EXEMPT ROLES ('admin')",
        );
        match p.strategy {
            ParsedMaskingStrategy::RedactCustom { replacement } => {
                assert_eq!(replacement, "***");
            }
            other => panic!("expected RedactCustom, got {other:?}"),
        }
    }

    /// The NULL strategy keyword is accepted.
    #[test]
    fn test_parse_create_masking_policy_null_strategy() {
        let p = parse_test_masking_policy(
            "CREATE MASKING POLICY n STRATEGY NULL EXEMPT ROLES ('auditor')",
        );
        assert_eq!(p.strategy, ParsedMaskingStrategy::Null);
    }

    /// Exempt role names are lower-cased by the parser.
    #[test]
    fn test_parse_create_masking_policy_lowercases_roles() {
        let p = parse_test_masking_policy(
            "CREATE MASKING POLICY p STRATEGY HASH EXEMPT ROLES ('Clinician', 'NURSE')",
        );
        assert_eq!(p.exempt_roles, vec!["clinician", "nurse"]);
    }

    /// An unrecognised strategy keyword is an error.
    #[test]
    fn test_parse_create_masking_policy_rejects_unknown_strategy() {
        let result =
            parse_statement("CREATE MASKING POLICY p STRATEGY SCRAMBLE EXEMPT ROLES ('x')");
        assert!(result.is_err(), "expected unknown-strategy error");
    }

    /// `TRUNCATE 0` is an error.
    #[test]
    fn test_parse_create_masking_policy_rejects_zero_truncate() {
        let result =
            parse_statement("CREATE MASKING POLICY p STRATEGY TRUNCATE 0 EXEMPT ROLES ('x')");
        assert!(result.is_err(), "TRUNCATE 0 must be rejected");
    }

    /// An empty `EXEMPT ROLES ()` list is an error.
    #[test]
    fn test_parse_create_masking_policy_rejects_empty_exempt_list() {
        let result = parse_statement("CREATE MASKING POLICY p STRATEGY HASH EXEMPT ROLES ()");
        assert!(result.is_err(), "empty EXEMPT ROLES list must be rejected");
    }

    /// Omitting the EXEMPT ROLES clause entirely is an error.
    #[test]
    fn test_parse_create_masking_policy_rejects_missing_exempt_roles() {
        let result = parse_statement("CREATE MASKING POLICY p STRATEGY HASH");
        assert!(
            result.is_err(),
            "missing EXEMPT ROLES clause must be rejected"
        );
    }

    /// DROP MASKING POLICY captures the policy name.
    #[test]
    fn test_parse_drop_masking_policy() {
        let result = parse_statement("DROP MASKING POLICY ssn_policy").unwrap();
        match result {
            ParsedStatement::DropMaskingPolicy(name) => assert_eq!(name, "ssn_policy"),
            other => panic!("expected DropMaskingPolicy, got {other:?}"),
        }
    }

    /// A trailing semicolon is not part of the dropped policy's name.
    #[test]
    fn test_parse_drop_masking_policy_with_semicolon() {
        let result = parse_statement("DROP MASKING POLICY ssn_policy;").unwrap();
        match result {
            ParsedStatement::DropMaskingPolicy(name) => assert_eq!(name, "ssn_policy"),
            other => panic!("expected DropMaskingPolicy, got {other:?}"),
        }
    }

    /// `DROP MASK` must still resolve to `DropMask`, not `DropMaskingPolicy`.
    #[test]
    fn test_parse_drop_masking_policy_does_not_swallow_drop_mask() {
        let result = parse_statement("DROP MASK ssn_mask").unwrap();
        assert!(matches!(result, ParsedStatement::DropMask(_)));
    }

    /// `ALTER COLUMN … SET MASKING POLICY` captures table, column and policy.
    #[test]
    fn test_parse_attach_masking_policy() {
        let result = parse_statement(
            "ALTER TABLE patients ALTER COLUMN medicare_number SET MASKING POLICY ssn_policy",
        )
        .unwrap();
        match result {
            ParsedStatement::AttachMaskingPolicy(a) => {
                assert_eq!(a.table_name, "patients");
                assert_eq!(a.column_name, "medicare_number");
                assert_eq!(a.policy_name, "ssn_policy");
            }
            other => panic!("expected AttachMaskingPolicy, got {other:?}"),
        }
    }

    /// `ALTER COLUMN … DROP MASKING POLICY` captures table and column.
    #[test]
    fn test_parse_detach_masking_policy() {
        let result = parse_statement(
            "ALTER TABLE patients ALTER COLUMN medicare_number DROP MASKING POLICY",
        )
        .unwrap();
        match result {
            ParsedStatement::DetachMaskingPolicy(d) => {
                assert_eq!(d.table_name, "patients");
                assert_eq!(d.column_name, "medicare_number");
            }
            other => panic!("expected DetachMaskingPolicy, got {other:?}"),
        }
    }

    /// SET MASKING POLICY without a policy name is an error.
    #[test]
    fn test_parse_attach_masking_policy_rejects_missing_policy_name() {
        let result =
            parse_statement("ALTER TABLE patients ALTER COLUMN medicare_number SET MASKING POLICY");
        assert!(result.is_err());
    }

    /// `CREATE MASKING POLICY` must resolve to `CreateMaskingPolicy`.
    #[test]
    fn test_parse_create_masking_policy_does_not_match_legacy_create_mask() {
        let result =
            parse_statement("CREATE MASKING POLICY p STRATEGY HASH EXEMPT ROLES ('admin')")
                .unwrap();
        assert!(matches!(result, ParsedStatement::CreateMaskingPolicy(_)));
    }

    /// `MODIFY COLUMN … SET CLASSIFICATION` captures table, column and label.
    #[test]
    fn test_parse_set_classification() {
        let result =
            parse_statement("ALTER TABLE patients MODIFY COLUMN ssn SET CLASSIFICATION 'PHI'")
                .unwrap();
        let ParsedStatement::SetClassification(sc) = result else {
            panic!("expected SetClassification statement");
        };
        assert_eq!(sc.table_name, "patients");
        assert_eq!(sc.column_name, "ssn");
        assert_eq!(sc.classification, "PHI");
    }

    /// A trailing semicolon is not part of the classification label.
    #[test]
    fn test_parse_set_classification_with_semicolon() {
        let result = parse_statement(
            "ALTER TABLE patients MODIFY COLUMN diagnosis SET CLASSIFICATION 'MEDICAL';",
        )
        .unwrap();
        let ParsedStatement::SetClassification(sc) = result else {
            panic!("expected SetClassification statement");
        };
        assert_eq!(sc.table_name, "patients");
        assert_eq!(sc.column_name, "diagnosis");
        assert_eq!(sc.classification, "MEDICAL");
    }

    /// Each listed label round-trips unchanged through SET CLASSIFICATION.
    #[test]
    fn test_parse_set_classification_various_labels() {
        for label in &["PHI", "PII", "PCI", "MEDICAL", "FINANCIAL", "CONFIDENTIAL"] {
            let sql = format!("ALTER TABLE t MODIFY COLUMN c SET CLASSIFICATION '{label}'");
            let ParsedStatement::SetClassification(sc) = parse_statement(&sql).unwrap() else {
                panic!("expected SetClassification for {label}");
            };
            assert_eq!(sc.classification, *label);
        }
    }

    /// An unquoted classification label is an error.
    #[test]
    fn test_parse_set_classification_missing_quotes() {
        let result =
            parse_statement("ALTER TABLE patients MODIFY COLUMN ssn SET CLASSIFICATION PHI");
        assert!(result.is_err(), "classification must be single-quoted");
    }

    /// SET CLASSIFICATION without `MODIFY COLUMN <name>` is an error.
    #[test]
    fn test_parse_set_classification_missing_modify() {
        let result = parse_statement("ALTER TABLE patients SET CLASSIFICATION 'PHI'");
        assert!(result.is_err());
    }

    /// SHOW CLASSIFICATIONS FOR captures the table name.
    #[test]
    fn test_parse_show_classifications() {
        let result = parse_statement("SHOW CLASSIFICATIONS FOR patients").unwrap();
        let ParsedStatement::ShowClassifications(table) = result else {
            panic!("expected ShowClassifications statement");
        };
        assert_eq!(table, "patients");
    }

    /// A trailing semicolon is not part of the table name.
    #[test]
    fn test_parse_show_classifications_with_semicolon() {
        let result = parse_statement("SHOW CLASSIFICATIONS FOR patients;").unwrap();
        let ParsedStatement::ShowClassifications(table) = result else {
            panic!("expected ShowClassifications statement");
        };
        assert_eq!(table, "patients");
    }

    /// SHOW CLASSIFICATIONS without `FOR` is an error.
    #[test]
    fn test_parse_show_classifications_missing_for() {
        let result = parse_statement("SHOW CLASSIFICATIONS patients");
        assert!(result.is_err());
    }

    /// SHOW CLASSIFICATIONS FOR without a table name is an error.
    #[test]
    fn test_parse_show_classifications_missing_table() {
        let result = parse_statement("SHOW CLASSIFICATIONS FOR");
        assert!(result.is_err());
    }

    /// CREATE ROLE captures the role name.
    #[test]
    fn test_parse_create_role() {
        let result = parse_statement("CREATE ROLE billing_clerk").unwrap();
        let ParsedStatement::CreateRole(name) = result else {
            panic!("expected CreateRole");
        };
        assert_eq!(name, "billing_clerk");
    }

    /// A trailing semicolon is not part of the role name.
    #[test]
    fn test_parse_create_role_with_semicolon() {
        let result = parse_statement("CREATE ROLE doctor;").unwrap();
        let ParsedStatement::CreateRole(name) = result else {
            panic!("expected CreateRole");
        };
        assert_eq!(name, "doctor");
    }

    /// GRANT SELECT without a column list leaves `columns` as `None`.
    #[test]
    fn test_parse_grant_select_all_columns() {
        let result = parse_statement("GRANT SELECT ON patients TO doctor").unwrap();
        let ParsedStatement::Grant(g) = result else {
            panic!("expected Grant");
        };
        assert!(g.columns.is_none());
        assert_eq!(g.table_name, "patients");
        assert_eq!(g.role_name, "doctor");
    }

    /// GRANT SELECT with a column list keeps the columns in order.
    #[test]
    fn test_parse_grant_select_specific_columns() {
        let result =
            parse_statement("GRANT SELECT (id, name, ssn) ON patients TO billing_clerk").unwrap();
        let ParsedStatement::Grant(g) = result else {
            panic!("expected Grant");
        };
        let expected_columns = vec!["id".to_string(), "name".to_string(), "ssn".to_string()];
        assert_eq!(g.columns, Some(expected_columns));
        assert_eq!(g.table_name, "patients");
        assert_eq!(g.role_name, "billing_clerk");
    }

    /// CREATE USER … WITH ROLE captures username and role.
    #[test]
    fn test_parse_create_user() {
        let result = parse_statement("CREATE USER clerk1 WITH ROLE billing_clerk").unwrap();
        let ParsedStatement::CreateUser(u) = result else {
            panic!("expected CreateUser");
        };
        assert_eq!(u.username, "clerk1");
        assert_eq!(u.role, "billing_clerk");
    }

    /// A trailing semicolon is not part of the role name.
    #[test]
    fn test_parse_create_user_with_semicolon() {
        let result = parse_statement("CREATE USER admin1 WITH ROLE admin;").unwrap();
        let ParsedStatement::CreateUser(u) = result else {
            panic!("expected CreateUser");
        };
        assert_eq!(u.username, "admin1");
        assert_eq!(u.role, "admin");
    }

    /// CREATE USER without the `ROLE` keyword is an error.
    #[test]
    fn test_parse_create_user_missing_role() {
        let result = parse_statement("CREATE USER clerk1 WITH billing_clerk");
        assert!(result.is_err());
    }

    /// CREATE TABLE with no columns must return an error.
    #[test]
    fn test_parse_create_table_rejects_zero_columns() {
        // Malformed input with no column list at all.
        let result = parse_statement("CREATE TABLE#USER");
        assert!(result.is_err(), "zero-column CREATE TABLE must be rejected");

        // Well-formed statement with an explicitly empty column list.
        let result = parse_statement("CREATE TABLE t ()");
        assert!(
            result.is_err(),
            "empty-column-list CREATE TABLE must be rejected"
        );
    }

    /// Parses `sql` and returns the INSERT payload, panicking with the parse
    /// error or the unexpected statement otherwise.
    fn parse_test_insert(sql: &str) -> ParsedInsert {
        let statement = match parse_statement(sql) {
            Ok(statement) => statement,
            Err(e) => panic!("parse failed: {e}"),
        };
        match statement {
            ParsedStatement::Insert(i) => i,
            other => panic!("expected INSERT statement, got {other:?}"),
        }
    }

    /// `DO UPDATE SET col = EXCLUDED.col` yields an `Excluded` back-reference.
    #[test]
    fn test_parse_insert_on_conflict_do_update() {
        let ins = parse_test_insert(
            "INSERT INTO users (id, name) VALUES (1, 'Alice') \
             ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name",
        );
        let oc = ins.on_conflict.expect("on_conflict must be present");
        assert_eq!(oc.target, vec!["id".to_string()]);

        match oc.action {
            OnConflictAction::DoUpdate { assignments } => {
                assert_eq!(assignments.len(), 1);
                assert_eq!(assignments[0].0, "name");
                assert_eq!(
                    assignments[0].1,
                    UpsertExpr::Excluded("name".to_string()),
                    "RHS must be an EXCLUDED.col back-reference"
                );
            }
            other @ OnConflictAction::DoNothing => panic!("expected DoUpdate, got {other:?}"),
        }
    }

    /// `DO NOTHING` yields `OnConflictAction::DoNothing`.
    #[test]
    fn test_parse_insert_on_conflict_do_nothing() {
        let ins = parse_test_insert(
            "INSERT INTO users (id, name) VALUES (1, 'Alice') ON CONFLICT (id) DO NOTHING",
        );
        let oc = ins.on_conflict.expect("on_conflict must be present");
        assert_eq!(oc.target, vec!["id".to_string()]);
        assert!(
            matches!(oc.action, OnConflictAction::DoNothing),
            "DO NOTHING must parse to OnConflictAction::DoNothing"
        );
    }

    /// An INSERT without ON CONFLICT has `on_conflict == None`.
    #[test]
    fn test_parse_plain_insert_has_no_on_conflict() {
        let ins = parse_test_insert("INSERT INTO users (id, name) VALUES (1, 'Alice')");
        assert!(
            ins.on_conflict.is_none(),
            "plain INSERT must not carry an on_conflict clause"
        );
    }

    /// A multi-column conflict target keeps every column in order.
    #[test]
    fn test_parse_insert_on_conflict_multi_column_target() {
        let ins = parse_test_insert(
            "INSERT INTO t (tenant_id, id, v) VALUES (1, 2, 3) \
             ON CONFLICT (tenant_id, id) DO UPDATE SET v = EXCLUDED.v",
        );
        let oc = ins.on_conflict.expect("on_conflict must be present");
        let expected_target = vec!["tenant_id".to_string(), "id".to_string()];
        assert_eq!(oc.target, expected_target);
    }

    /// ON CONFLICT and RETURNING can be combined in one INSERT.
    #[test]
    fn test_parse_insert_on_conflict_with_returning() {
        let ins = parse_test_insert(
            "INSERT INTO t (id, v) VALUES (1, 2) \
             ON CONFLICT (id) DO UPDATE SET v = EXCLUDED.v RETURNING id, v",
        );
        assert!(ins.on_conflict.is_some());
        let expected_returning = vec!["id".to_string(), "v".to_string()];
        assert_eq!(ins.returning, Some(expected_returning));
    }

    /// The `ON CONFLICT ON CONSTRAINT <name>` form must return an error.
    #[test]
    fn test_parse_insert_on_conflict_rejects_on_constraint_form() {
        let result = parse_statement(
            "INSERT INTO t (id) VALUES (1) ON CONFLICT ON CONSTRAINT pk_t DO NOTHING",
        );
        assert!(
            result.is_err(),
            "ON CONSTRAINT form must be rejected with a clear error"
        );
    }

    /// A literal on the right of `DO UPDATE SET` is an `UpsertExpr::Value`.
    #[test]
    fn test_parse_insert_on_conflict_literal_rhs() {
        let ins = parse_test_insert(
            "INSERT INTO t (id, v) VALUES (1, 2) \
             ON CONFLICT (id) DO UPDATE SET v = 42",
        );
        let oc = ins.on_conflict.expect("on_conflict must be present");

        match oc.action {
            OnConflictAction::DoUpdate { assignments } => {
                assert_eq!(assignments[0].0, "v");
                assert!(matches!(assignments[0].1, UpsertExpr::Value(_)));
            }
            other @ OnConflictAction::DoNothing => panic!("expected DoUpdate, got {other:?}"),
        }
    }
}