1use kimberlite_types::NonEmptyVec;
15use sqlparser::ast::{
16 BinaryOperator, ColumnDef as SqlColumnDef, DataType as SqlDataType, Expr, ObjectName,
17 OrderByExpr, Query, Select, SelectItem, SetExpr, Statement, Value as SqlValue,
18};
19use sqlparser::dialect::{Dialect, GenericDialect};
20use sqlparser::parser::Parser;
21
/// SQL dialect handed to `sqlparser` when parsing statements.
///
/// Wraps a [`GenericDialect`]; the `Dialect` impl for this type forwards the
/// identifier rules to it and reports support for `FILTER` during
/// aggregation.
#[derive(Debug)]
struct KimberliteDialect {
    /// Dialect that identifier classification is delegated to.
    inner: GenericDialect,
}
30
31impl KimberliteDialect {
32 const fn new() -> Self {
33 Self {
34 inner: GenericDialect {},
35 }
36 }
37}
38
39impl Dialect for KimberliteDialect {
40 fn is_identifier_start(&self, ch: char) -> bool {
41 self.inner.is_identifier_start(ch)
42 }
43
44 fn is_identifier_part(&self, ch: char) -> bool {
45 self.inner.is_identifier_part(ch)
46 }
47
48 fn supports_filter_during_aggregation(&self) -> bool {
49 true
50 }
51}
52
53use crate::error::{QueryError, Result};
54use crate::expression::ScalarExpr;
55use crate::schema::{ColumnName, DataType};
56use crate::value::Value;
57
/// Every statement shape the parser can hand back to callers.
///
/// Variants are produced either by `parse_statement` (via `sqlparser`) or by
/// `try_parse_custom_statement` (hand-tokenised extensions).
#[derive(Debug, Clone)]
pub enum ParsedStatement {
    /// A plain `SELECT` query.
    Select(ParsedSelect),
    /// A set operation (`UNION` / `INTERSECT` / `EXCEPT`) over two selects.
    Union(ParsedUnion),
    /// `CREATE TABLE`.
    CreateTable(ParsedCreateTable),
    /// `DROP TABLE` for exactly one table; `if_exists` mirrors the parsed flag.
    DropTable { name: String, if_exists: bool },
    /// `ALTER TABLE` handled through `sqlparser`.
    AlterTable(ParsedAlterTable),
    /// `CREATE INDEX`.
    CreateIndex(ParsedCreateIndex),
    /// `INSERT`.
    Insert(ParsedInsert),
    /// `UPDATE`.
    Update(ParsedUpdate),
    /// `DELETE`.
    Delete(ParsedDelete),
    /// `CREATE MASK <name> ON <table>.<column> USING <strategy>`.
    CreateMask(ParsedCreateMask),
    /// `DROP MASK <name>`; the payload is the mask name.
    DropMask(String),
    /// `CREATE MASKING POLICY ...`.
    CreateMaskingPolicy(ParsedCreateMaskingPolicy),
    /// `DROP MASKING POLICY <name>`; the payload is the policy name.
    DropMaskingPolicy(String),
    /// `ALTER TABLE <t> ALTER COLUMN <c> SET MASKING POLICY <name>`.
    AttachMaskingPolicy(ParsedAttachMaskingPolicy),
    /// `ALTER TABLE <t> ALTER COLUMN <c> DROP MASKING POLICY`.
    DetachMaskingPolicy(ParsedDetachMaskingPolicy),
    /// `ALTER TABLE <t> MODIFY COLUMN <c> SET CLASSIFICATION '<class>'`.
    SetClassification(ParsedSetClassification),
    /// `SHOW CLASSIFICATIONS FOR <table>`; the payload is the table name.
    ShowClassifications(String),
    /// `SHOW TABLES`.
    ShowTables,
    /// `SHOW COLUMNS FROM <table>`; the payload is the table name.
    ShowColumns(String),
    /// `CREATE ROLE` with exactly one role name (the payload).
    CreateRole(String),
    /// `GRANT ... ON <table> TO <role>`.
    Grant(ParsedGrant),
    /// `CREATE USER <name> WITH ROLE <role>`.
    CreateUser(ParsedCreateUser),
}
116
/// Result of parsing a table-level `GRANT` with a single table and grantee.
#[derive(Debug, Clone)]
pub struct ParsedGrant {
    /// Column list taken from a `SELECT (<cols>)` privilege; `None` when no
    /// such list was given or when all privileges were granted.
    pub columns: Option<Vec<String>>,
    /// Name of the single table the grant targets.
    pub table_name: String,
    /// Name of the single grantee.
    pub role_name: String,
}
127
/// Result of parsing `CREATE USER <name> WITH ROLE <role>`.
#[derive(Debug, Clone)]
pub struct ParsedCreateUser {
    /// The `<name>` token, kept exactly as written.
    pub username: String,
    /// The `<role>` token, kept exactly as written.
    pub role: String,
}
136
/// Result of parsing
/// `ALTER TABLE <table> MODIFY COLUMN <column> SET CLASSIFICATION '<class>'`.
#[derive(Debug, Clone)]
pub struct ParsedSetClassification {
    /// The `<table>` token.
    pub table_name: String,
    /// The `<column>` token.
    pub column_name: String,
    /// The classification label with its surrounding single quotes removed.
    pub classification: String,
}
147
/// Result of parsing `CREATE MASK <name> ON <table>.<column> USING <strategy>`.
#[derive(Debug, Clone)]
pub struct ParsedCreateMask {
    /// The `<name>` token.
    pub mask_name: String,
    /// Text before the first `.` in the `<table>.<column>` token.
    pub table_name: String,
    /// Text after the first `.` in the `<table>.<column>` token.
    pub column_name: String,
    /// The `<strategy>` token converted to ASCII upper case; it is not
    /// validated by the parser.
    pub strategy: String,
}
160
/// Strategy named in a `CREATE MASKING POLICY ... STRATEGY <kind> [<arg>]`
/// clause, as recognised by `parse_masking_strategy`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParsedMaskingStrategy {
    /// Keyword `REDACT_SSN` (no argument).
    RedactSsn,
    /// Keyword `REDACT_PHONE` (no argument).
    RedactPhone,
    /// Keyword `REDACT_EMAIL` (no argument).
    RedactEmail,
    /// Keyword `REDACT_CC` (no argument).
    RedactCreditCard,
    /// Keyword `REDACT_CUSTOM '<replacement>'`.
    RedactCustom {
        /// The argument with its surrounding single quotes removed.
        replacement: String,
    },
    /// Keyword `HASH` (no argument).
    Hash,
    /// Keyword `TOKENIZE` (no argument).
    Tokenize,
    /// Keyword `TRUNCATE <n>`.
    Truncate {
        /// The parsed `<n>`; the parser rejects zero.
        max_chars: usize,
    },
    /// Keyword `NULL` (no argument).
    Null,
}
191
/// Result of parsing
/// `CREATE MASKING POLICY <name> STRATEGY <kind> [<arg>] EXEMPT ROLES (<r>, ...)`.
#[derive(Debug, Clone)]
pub struct ParsedCreateMaskingPolicy {
    /// The `<name>` token.
    pub name: String,
    /// Strategy parsed from the `STRATEGY` clause.
    pub strategy: ParsedMaskingStrategy,
    /// Role names from the `EXEMPT ROLES` list: trimmed, stripped of single
    /// quotes and lower-cased. The parser guarantees at least one entry.
    pub exempt_roles: Vec<String>,
}
207
/// Result of parsing
/// `ALTER TABLE <t> ALTER COLUMN <c> SET MASKING POLICY <name>`.
#[derive(Debug, Clone)]
// Every field ends in `_name`, which trips `clippy::struct_field_names`.
#[allow(clippy::struct_field_names)]
pub struct ParsedAttachMaskingPolicy {
    /// The `<t>` token.
    pub table_name: String,
    /// The `<c>` token.
    pub column_name: String,
    /// The trailing `<name>` token.
    pub policy_name: String,
}
219
/// Result of parsing `ALTER TABLE <t> ALTER COLUMN <c> DROP MASKING POLICY`.
#[derive(Debug, Clone)]
pub struct ParsedDetachMaskingPolicy {
    /// The `<t>` token.
    pub table_name: String,
    /// The `<c>` token.
    pub column_name: String,
}
228
/// Set operator joining the two sides of a [`ParsedUnion`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SetOp {
    /// SQL `UNION`.
    Union,
    /// SQL `INTERSECT`.
    Intersect,
    /// SQL `EXCEPT`; the parser also maps `MINUS` to this variant.
    Except,
}
239
/// A set operation over exactly two plain `SELECT`s (nested set operations
/// are rejected by the parser).
#[derive(Debug, Clone)]
pub struct ParsedUnion {
    /// Which set operator combines the two sides.
    pub op: SetOp,
    /// Left-hand `SELECT`.
    pub left: ParsedSelect,
    /// Right-hand `SELECT`.
    pub right: ParsedSelect,
    /// `true` when the operator carried the `ALL` quantifier.
    pub all: bool,
}
253
/// Kind of join recorded on a [`ParsedJoin`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JoinType {
    /// Inner join (`INNER JOIN` or bare `JOIN`).
    Inner,
    /// Left join (`LEFT [OUTER] JOIN`).
    Left,
    /// Right join (`RIGHT [OUTER] JOIN`).
    Right,
    /// Full outer join.
    Full,
    /// Cross join.
    Cross,
}
268
/// One join attached to a [`ParsedSelect`].
#[derive(Debug, Clone)]
pub struct ParsedJoin {
    /// Name of the joined table.
    pub table: String,
    /// Kind of join.
    pub join_type: JoinType,
    /// Predicates of the join condition.
    // NOTE(review): presumably an implicit AND of the entries; the code that
    // builds and evaluates this list is outside this section - confirm.
    pub on_condition: Vec<Predicate>,
}
279
/// One common table expression attached to a [`ParsedSelect`].
#[derive(Debug, Clone)]
pub struct ParsedCte {
    /// Name the CTE is referenced by.
    pub name: String,
    /// The CTE's query.
    pub query: ParsedSelect,
    /// Optional second query.
    // NOTE(review): by its name this is the recursive branch of a
    // `WITH RECURSIVE` CTE; the code that fills it is not in view - confirm.
    pub recursive_arm: Option<ParsedSelect>,
}
291
/// A projected column whose value is chosen from a list of conditional arms.
// NOTE(review): the field names suggest this models `CASE WHEN ... ELSE ...`;
// the code that builds it is outside this section - confirm.
#[derive(Debug, Clone)]
pub struct ComputedColumn {
    /// Output name of the column.
    pub alias: ColumnName,
    /// Conditional arms.
    pub when_clauses: Vec<CaseWhenArm>,
    /// Value associated with the `else` branch.
    pub else_value: Value,
}
304
/// One arm of a [`ComputedColumn`]: a condition paired with a result value.
#[derive(Debug, Clone)]
pub struct CaseWhenArm {
    /// Predicates forming the arm's condition.
    pub condition: Vec<Predicate>,
    /// Value associated with this arm.
    pub result: Value,
}
313
/// Value of a `LIMIT` or `OFFSET` clause on a [`ParsedSelect`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LimitExpr {
    /// A count written directly in the SQL text.
    Literal(usize),
    /// A parameter reference.
    // NOTE(review): presumably the index of a bound placeholder; whether it
    // is 0- or 1-based is not visible here - confirm.
    Param(usize),
}
328
/// Parsed form of a single `SELECT`.
///
/// `order_by`, `limit`, `offset` and the outer `WITH` CTEs are filled in by
/// `parse_query_to_statement`; everything else comes from `parse_select`,
/// which is outside this section, so the notes on those fields are limited
/// to what their names and types state.
#[derive(Debug, Clone)]
pub struct ParsedSelect {
    /// Name of the primary table.
    pub table: String,
    /// Joins applied to the primary table.
    pub joins: Vec<ParsedJoin>,
    /// Projected columns, if an explicit list is recorded.
    // NOTE(review): `None` presumably means `SELECT *` - confirm.
    pub columns: Option<Vec<ColumnName>>,
    /// Optional per-column aliases.
    // NOTE(review): presumably parallel to `columns` - confirm.
    pub column_aliases: Option<Vec<Option<String>>>,
    /// Conditionally computed projection columns.
    pub case_columns: Vec<ComputedColumn>,
    /// Filter predicates.
    pub predicates: Vec<Predicate>,
    /// `ORDER BY` entries taken from the enclosing query.
    pub order_by: Vec<OrderByClause>,
    /// `LIMIT` taken from the enclosing query, if any.
    pub limit: Option<LimitExpr>,
    /// `OFFSET` taken from the enclosing query, if any.
    pub offset: Option<LimitExpr>,
    /// Aggregate functions in the projection.
    pub aggregates: Vec<AggregateFunction>,
    /// Optional filter predicates per aggregate.
    // NOTE(review): presumably parallel to `aggregates` - confirm.
    pub aggregate_filters: Vec<Option<Vec<Predicate>>>,
    /// `GROUP BY` columns.
    pub group_by: Vec<ColumnName>,
    /// Whether the select is marked distinct.
    pub distinct: bool,
    /// `HAVING` conditions.
    pub having: Vec<HavingCondition>,
    /// CTEs in scope: those of the outer `WITH` clause first, followed by
    /// any that `parse_select` produced itself.
    pub ctes: Vec<ParsedCte>,
    /// Window functions in the projection.
    pub window_fns: Vec<ParsedWindowFn>,
    /// Scalar-expression projections.
    pub scalar_projections: Vec<ParsedScalarProjection>,
}
385
/// A projection item that is a scalar expression.
#[derive(Debug, Clone)]
pub struct ParsedScalarProjection {
    /// The expression to evaluate.
    pub expr: ScalarExpr,
    /// Name of the resulting output column.
    pub output_name: ColumnName,
    /// Alias, when one is recorded.
    // NOTE(review): how `output_name` is derived when `alias` is `None` is
    // decided outside this section - confirm.
    pub alias: Option<String>,
}
400
/// A window-function item in a `SELECT` projection.
#[derive(Debug, Clone)]
pub struct ParsedWindowFn {
    /// The window function being applied.
    pub function: crate::window::WindowFunction,
    /// Partitioning columns.
    pub partition_by: Vec<ColumnName>,
    /// Ordering entries for the window.
    pub order_by: Vec<OrderByClause>,
    /// Alias, when one is recorded.
    pub alias: Option<String>,
}
414
/// A single `HAVING` condition on a [`ParsedSelect`].
#[derive(Debug, Clone)]
pub enum HavingCondition {
    /// Compares the result of an aggregate against a value.
    AggregateComparison {
        /// Aggregate on the left-hand side.
        aggregate: AggregateFunction,
        /// Comparison operator.
        op: HavingOp,
        /// Value on the right-hand side.
        value: Value,
    },
}
430
/// Comparison operator used by [`HavingCondition::AggregateComparison`].
///
/// Note that there is no not-equal variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HavingOp {
    /// Equal.
    Eq,
    /// Less than.
    Lt,
    /// Less than or equal.
    Le,
    /// Greater than.
    Gt,
    /// Greater than or equal.
    Ge,
}
440
/// Parsed form of `CREATE TABLE`.
#[derive(Debug, Clone)]
pub struct ParsedCreateTable {
    /// Name of the table to create.
    pub table_name: String,
    /// Column definitions; the type guarantees at least one.
    pub columns: NonEmptyVec<ParsedColumn>,
    /// Names of the primary-key columns.
    pub primary_key: Vec<String>,
    /// Whether the statement carried an if-not-exists flag.
    pub if_not_exists: bool,
}
457
/// One column definition from `CREATE TABLE` or `ALTER TABLE ... ADD COLUMN`.
#[derive(Debug, Clone)]
pub struct ParsedColumn {
    /// Column name.
    pub name: String,
    /// Column type, carried as text.
    pub data_type: String,
    /// Whether the column is nullable.
    pub nullable: bool,
}
465
/// Parsed form of an `ALTER TABLE` statement handled through `sqlparser`.
#[derive(Debug, Clone)]
pub struct ParsedAlterTable {
    /// Table being altered.
    pub table_name: String,
    /// The alteration to apply.
    pub operation: AlterTableOperation,
}
472
/// Alteration carried by a [`ParsedAlterTable`].
#[derive(Debug, Clone)]
pub enum AlterTableOperation {
    /// Add the given column.
    AddColumn(ParsedColumn),
    /// Drop a column; the payload is presumably the column name.
    DropColumn(String),
}
481
/// Parsed form of `CREATE INDEX`.
#[derive(Debug, Clone)]
pub struct ParsedCreateIndex {
    /// Name of the index.
    pub index_name: String,
    /// Table the index is created on.
    pub table_name: String,
    /// Indexed column names.
    pub columns: Vec<String>,
}
489
/// Parsed form of `INSERT`.
#[derive(Debug, Clone)]
pub struct ParsedInsert {
    /// Target table.
    pub table: String,
    /// Target column names.
    pub columns: Vec<String>,
    /// Values to insert.
    // NOTE(review): presumably one inner `Vec` per row - confirm.
    pub values: Vec<Vec<Value>>,
    /// `RETURNING` column names, when present.
    pub returning: Option<Vec<String>>,
    /// `ON CONFLICT` clause, when present.
    pub on_conflict: Option<OnConflictClause>,
}
503
/// `ON CONFLICT` clause attached to a [`ParsedInsert`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OnConflictClause {
    /// Conflict-target names (presumably column names - confirm).
    pub target: Vec<String>,
    /// What to do when the conflict occurs.
    pub action: OnConflictAction,
}
522
/// Action half of an [`OnConflictClause`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OnConflictAction {
    /// `DO NOTHING`.
    DoNothing,
    /// `DO UPDATE SET ...`.
    DoUpdate {
        /// `(name, expression)` pairs for the update.
        assignments: Vec<(String, UpsertExpr)>,
    },
}
540
/// Right-hand side of an assignment in [`OnConflictAction::DoUpdate`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UpsertExpr {
    /// A concrete value.
    Value(Value),
    /// A named reference.
    // NOTE(review): presumably `EXCLUDED.<column>`, with the payload being
    // the column name - confirm against the insert parser.
    Excluded(String),
}
555
/// Parsed form of `UPDATE`.
#[derive(Debug, Clone)]
pub struct ParsedUpdate {
    /// Target table.
    pub table: String,
    /// `(name, value)` pairs from the `SET` list.
    pub assignments: Vec<(String, Value)>,
    /// Predicates selecting the rows to update.
    pub predicates: Vec<Predicate>,
    /// `RETURNING` column names, when present.
    pub returning: Option<Vec<String>>,
}
564
/// Parsed form of `DELETE`.
#[derive(Debug, Clone)]
pub struct ParsedDelete {
    /// Target table.
    pub table: String,
    /// Predicates selecting the rows to delete.
    pub predicates: Vec<Predicate>,
    /// `RETURNING` column names, when present.
    pub returning: Option<Vec<String>>,
}
572
/// Aggregate function appearing in a projection or `HAVING` condition.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AggregateFunction {
    /// `COUNT(*)`.
    CountStar,
    /// `COUNT(<column>)`.
    Count(ColumnName),
    /// `SUM(<column>)`.
    Sum(ColumnName),
    /// `AVG(<column>)`.
    Avg(ColumnName),
    /// `MIN(<column>)`.
    Min(ColumnName),
    /// `MAX(<column>)`.
    Max(ColumnName),
}
589
/// A filter condition.
///
/// The code that builds and evaluates predicates is outside this section;
/// the variant notes below restate what the names and payload types say.
#[derive(Debug, Clone)]
pub enum Predicate {
    /// Column equals value.
    Eq(ColumnName, PredicateValue),
    /// Column less than value.
    Lt(ColumnName, PredicateValue),
    /// Column less than or equal to value.
    Le(ColumnName, PredicateValue),
    /// Column greater than value.
    Gt(ColumnName, PredicateValue),
    /// Column greater than or equal to value.
    Ge(ColumnName, PredicateValue),
    /// Column is one of the listed values.
    In(ColumnName, Vec<PredicateValue>),
    /// Column is none of the listed values.
    NotIn(ColumnName, Vec<PredicateValue>),
    /// Column is outside the range given by the two values.
    NotBetween(ColumnName, PredicateValue, PredicateValue),
    /// `LIKE` with the given pattern.
    Like(ColumnName, String),
    /// `NOT LIKE` with the given pattern.
    NotLike(ColumnName, String),
    /// `ILIKE` with the given pattern.
    ILike(ColumnName, String),
    /// `NOT ILIKE` with the given pattern.
    NotILike(ColumnName, String),
    /// Column `IS NULL`.
    IsNull(ColumnName),
    /// Column `IS NOT NULL`.
    IsNotNull(ColumnName),
    /// Equality test on a value extracted from a JSON column.
    JsonExtractEq {
        /// JSON column to extract from.
        column: ColumnName,
        /// Path of the extracted element.
        path: String,
        /// Flag selecting text extraction.
        // NOTE(review): presumably distinguishes `->>` from `->` - confirm.
        as_text: bool,
        /// Value compared against.
        value: PredicateValue,
    },
    /// Containment test on a JSON column.
    JsonContains {
        /// JSON column to test.
        column: ColumnName,
        /// Value tested for containment.
        value: PredicateValue,
    },
    /// Column `[NOT] IN (<subquery>)`.
    InSubquery {
        /// Column on the left-hand side.
        column: ColumnName,
        /// The subquery.
        subquery: Box<ParsedSelect>,
        /// `true` for the `NOT IN` form.
        negated: bool,
    },
    /// `[NOT] EXISTS (<subquery>)`.
    Exists {
        /// The subquery.
        subquery: Box<ParsedSelect>,
        /// `true` for the `NOT EXISTS` form.
        negated: bool,
    },
    /// A constant truth value.
    Always(bool),
    /// Disjunction of two predicate lists.
    // NOTE(review): each side is presumably an implicit AND - confirm.
    Or(Vec<Predicate>, Vec<Predicate>),
    /// Comparison between two scalar expressions.
    ScalarCmp {
        /// Left operand.
        lhs: ScalarExpr,
        /// Comparison operator.
        op: ScalarCmpOp,
        /// Right operand.
        rhs: ScalarExpr,
    },
}
685
/// Comparison operator used by [`Predicate::ScalarCmp`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScalarCmpOp {
    /// Equal.
    Eq,
    /// Not equal.
    NotEq,
    /// Less than.
    Lt,
    /// Less than or equal.
    Le,
    /// Greater than.
    Gt,
    /// Greater than or equal.
    Ge,
}
696
697impl Predicate {
698 #[allow(dead_code)]
702 pub fn column(&self) -> Option<&ColumnName> {
703 match self {
704 Predicate::Eq(col, _)
705 | Predicate::Lt(col, _)
706 | Predicate::Le(col, _)
707 | Predicate::Gt(col, _)
708 | Predicate::Ge(col, _)
709 | Predicate::In(col, _)
710 | Predicate::NotIn(col, _)
711 | Predicate::NotBetween(col, _, _)
712 | Predicate::Like(col, _)
713 | Predicate::NotLike(col, _)
714 | Predicate::ILike(col, _)
715 | Predicate::NotILike(col, _)
716 | Predicate::IsNull(col)
717 | Predicate::IsNotNull(col)
718 | Predicate::JsonExtractEq { column: col, .. }
719 | Predicate::JsonContains { column: col, .. }
720 | Predicate::InSubquery { column: col, .. } => Some(col),
721 Predicate::Or(_, _)
722 | Predicate::Exists { .. }
723 | Predicate::Always(_)
724 | Predicate::ScalarCmp { .. } => None,
725 }
726 }
727}
728
/// Right-hand side of a [`Predicate`] comparison.
#[derive(Debug, Clone)]
pub enum PredicateValue {
    /// Integer literal.
    Int(i64),
    /// String literal.
    String(String),
    /// Boolean literal.
    Bool(bool),
    /// SQL `NULL`.
    Null,
    /// Parameter reference.
    // NOTE(review): presumably a bound-placeholder index - confirm.
    Param(usize),
    /// Any other literal, carried as a [`Value`].
    Literal(Value),
    /// Reference to a column by name.
    ColumnRef(String),
}
748
/// One `ORDER BY` entry.
#[derive(Debug, Clone)]
pub struct OrderByClause {
    /// Column to sort by.
    pub column: ColumnName,
    /// `true` for ascending order, `false` for descending.
    pub ascending: bool,
}
757
758pub fn parse_statement(sql: &str) -> Result<ParsedStatement> {
764 crate::depth_check::check_sql_depth(sql)?;
765
766 if let Some(parsed) = try_parse_custom_statement(sql)? {
768 return Ok(parsed);
769 }
770
771 let dialect = KimberliteDialect::new();
772 let statements =
773 Parser::parse_sql(&dialect, sql).map_err(|e| QueryError::ParseError(e.to_string()))?;
774
775 if statements.len() != 1 {
776 return Err(QueryError::ParseError(format!(
777 "expected exactly 1 statement, got {}",
778 statements.len()
779 )));
780 }
781
782 match &statements[0] {
783 Statement::Query(query) => parse_query_to_statement(query),
784 Statement::CreateTable(create_table) => {
785 let parsed = parse_create_table(create_table)?;
786 Ok(ParsedStatement::CreateTable(parsed))
787 }
788 Statement::Drop {
789 object_type,
790 names,
791 if_exists,
792 ..
793 } => {
794 if !matches!(object_type, sqlparser::ast::ObjectType::Table) {
795 return Err(QueryError::UnsupportedFeature(
796 "only DROP TABLE is supported".to_string(),
797 ));
798 }
799 if names.len() != 1 {
800 return Err(QueryError::ParseError(
801 "expected exactly 1 table in DROP TABLE".to_string(),
802 ));
803 }
804 let table_name = object_name_to_string(&names[0]);
805 Ok(ParsedStatement::DropTable {
806 name: table_name,
807 if_exists: *if_exists,
808 })
809 }
810 Statement::CreateIndex(create_index) => {
811 let parsed = parse_create_index(create_index)?;
812 Ok(ParsedStatement::CreateIndex(parsed))
813 }
814 Statement::Insert(insert) => {
815 let parsed = parse_insert(insert)?;
816 Ok(ParsedStatement::Insert(parsed))
817 }
818 Statement::Update(update) => {
819 let parsed = parse_update(
820 &update.table,
821 &update.assignments,
822 update.selection.as_ref(),
823 update.returning.as_ref(),
824 )?;
825 Ok(ParsedStatement::Update(parsed))
826 }
827 Statement::Delete(delete) => {
828 let parsed = parse_delete_stmt(delete)?;
829 Ok(ParsedStatement::Delete(parsed))
830 }
831 Statement::AlterTable(alter_table) => {
832 let parsed = parse_alter_table(&alter_table.name, &alter_table.operations)?;
833 Ok(ParsedStatement::AlterTable(parsed))
834 }
835 Statement::CreateRole(create_role) => {
836 if create_role.names.len() != 1 {
837 return Err(QueryError::ParseError(
838 "expected exactly 1 role name".to_string(),
839 ));
840 }
841 let role_name = object_name_to_string(&create_role.names[0]);
842 Ok(ParsedStatement::CreateRole(role_name))
843 }
844 Statement::Grant(grant) => {
845 let objects = grant.objects.as_ref().ok_or_else(|| {
846 QueryError::ParseError(
847 "GRANT requires an ON clause specifying the target objects".to_string(),
848 )
849 })?;
850 parse_grant(&grant.privileges, objects, &grant.grantees)
851 }
852 other => Err(QueryError::UnsupportedFeature(format!(
853 "statement type not supported: {other:?}"
854 ))),
855 }
856}
857
858pub fn try_parse_custom_statement(sql: &str) -> Result<Option<ParsedStatement>> {
874 let trimmed = sql.trim().trim_end_matches(';').trim();
875 let upper = trimmed.to_ascii_uppercase();
876
877 if upper.starts_with("CREATE MASKING POLICY") {
879 return parse_create_masking_policy(trimmed).map(Some);
880 }
881
882 if upper.starts_with("DROP MASKING POLICY") {
884 let tokens: Vec<&str> = trimmed.split_whitespace().collect();
885 if tokens.len() != 4 {
887 return Err(QueryError::ParseError(
888 "expected: DROP MASKING POLICY <name>".to_string(),
889 ));
890 }
891 return Ok(Some(ParsedStatement::DropMaskingPolicy(
892 tokens[3].to_string(),
893 )));
894 }
895
896 if upper.starts_with("ALTER TABLE") && upper.contains("MASKING POLICY") {
898 return parse_alter_masking_policy(trimmed).map(Some);
899 }
900
901 if upper.starts_with("CREATE MASK") {
903 let tokens: Vec<&str> = trimmed.split_whitespace().collect();
904 if tokens.len() != 7 {
906 return Err(QueryError::ParseError(
907 "expected: CREATE MASK <name> ON <table>.<column> USING <strategy>".to_string(),
908 ));
909 }
910 if !tokens[3].eq_ignore_ascii_case("ON") {
911 return Err(QueryError::ParseError(format!(
912 "expected ON after mask name, got '{}'",
913 tokens[3]
914 )));
915 }
916 if !tokens[5].eq_ignore_ascii_case("USING") {
917 return Err(QueryError::ParseError(format!(
918 "expected USING after column reference, got '{}'",
919 tokens[5]
920 )));
921 }
922
923 let table_col = tokens[4];
925 let dot_pos = table_col.find('.').ok_or_else(|| {
926 QueryError::ParseError(format!(
927 "expected <table>.<column> but got '{table_col}' (missing '.')"
928 ))
929 })?;
930 let table_name = table_col[..dot_pos].to_string();
931 let column_name = table_col[dot_pos + 1..].to_string();
932
933 if table_name.is_empty() || column_name.is_empty() {
934 return Err(QueryError::ParseError(
935 "table name and column name must not be empty".to_string(),
936 ));
937 }
938
939 let strategy = tokens[6].to_ascii_uppercase();
940
941 return Ok(Some(ParsedStatement::CreateMask(ParsedCreateMask {
942 mask_name: tokens[2].to_string(),
943 table_name,
944 column_name,
945 strategy,
946 })));
947 }
948
949 if upper.starts_with("DROP MASK") {
951 let tokens: Vec<&str> = trimmed.split_whitespace().collect();
952 if tokens.len() != 3 {
953 return Err(QueryError::ParseError(
954 "expected: DROP MASK <name>".to_string(),
955 ));
956 }
957 return Ok(Some(ParsedStatement::DropMask(tokens[2].to_string())));
958 }
959
960 if upper.starts_with("ALTER TABLE") && upper.contains("SET CLASSIFICATION") {
962 return parse_set_classification(trimmed);
963 }
964
965 if upper.starts_with("SHOW CLASSIFICATIONS") {
967 let tokens: Vec<&str> = trimmed.split_whitespace().collect();
968 if tokens.len() != 4 {
970 return Err(QueryError::ParseError(
971 "expected: SHOW CLASSIFICATIONS FOR <table>".to_string(),
972 ));
973 }
974 if !tokens[2].eq_ignore_ascii_case("FOR") {
975 return Err(QueryError::ParseError(format!(
976 "expected FOR after CLASSIFICATIONS, got '{}'",
977 tokens[2]
978 )));
979 }
980 return Ok(Some(ParsedStatement::ShowClassifications(
981 tokens[3].to_string(),
982 )));
983 }
984
985 if upper == "SHOW TABLES" {
987 return Ok(Some(ParsedStatement::ShowTables));
988 }
989
990 if upper.starts_with("SHOW COLUMNS") {
992 let tokens: Vec<&str> = trimmed.split_whitespace().collect();
993 if tokens.len() != 4 {
995 return Err(QueryError::ParseError(
996 "expected: SHOW COLUMNS FROM <table>".to_string(),
997 ));
998 }
999 if !tokens[2].eq_ignore_ascii_case("FROM") {
1000 return Err(QueryError::ParseError(format!(
1001 "expected FROM after COLUMNS, got '{}'",
1002 tokens[2]
1003 )));
1004 }
1005 return Ok(Some(ParsedStatement::ShowColumns(tokens[3].to_string())));
1006 }
1007
1008 if upper.starts_with("CREATE USER") {
1010 let tokens: Vec<&str> = trimmed.split_whitespace().collect();
1011 if tokens.len() != 6 {
1013 return Err(QueryError::ParseError(
1014 "expected: CREATE USER <name> WITH ROLE <role>".to_string(),
1015 ));
1016 }
1017 if !tokens[3].eq_ignore_ascii_case("WITH") {
1018 return Err(QueryError::ParseError(format!(
1019 "expected WITH after username, got '{}'",
1020 tokens[3]
1021 )));
1022 }
1023 if !tokens[4].eq_ignore_ascii_case("ROLE") {
1024 return Err(QueryError::ParseError(format!(
1025 "expected ROLE after WITH, got '{}'",
1026 tokens[4]
1027 )));
1028 }
1029 return Ok(Some(ParsedStatement::CreateUser(ParsedCreateUser {
1030 username: tokens[2].to_string(),
1031 role: tokens[5].to_string(),
1032 })));
1033 }
1034
1035 Ok(None)
1036}
1037
1038fn parse_create_masking_policy(trimmed: &str) -> Result<ParsedStatement> {
1046 let after_keyword = trimmed
1050 .get("CREATE MASKING POLICY".len()..)
1051 .ok_or_else(|| QueryError::ParseError("missing policy body".to_string()))?
1052 .trim_start();
1053
1054 let upper_body = after_keyword.to_ascii_uppercase();
1057 let exempt_pos = upper_body.find("EXEMPT ROLES").ok_or_else(|| {
1058 QueryError::ParseError(
1059 "expected: CREATE MASKING POLICY <name> STRATEGY <kind> [<arg>] EXEMPT ROLES (<r>, ...)"
1060 .to_string(),
1061 )
1062 })?;
1063
1064 let header = after_keyword[..exempt_pos].trim();
1065 let exempt_tail = after_keyword[exempt_pos + "EXEMPT ROLES".len()..].trim();
1066
1067 let (name, strategy) = parse_masking_policy_header(header)?;
1068 let exempt_roles = parse_exempt_roles_list(exempt_tail)?;
1069
1070 Ok(ParsedStatement::CreateMaskingPolicy(
1071 ParsedCreateMaskingPolicy {
1072 name,
1073 strategy,
1074 exempt_roles,
1075 },
1076 ))
1077}
1078
1079fn parse_masking_policy_header(header: &str) -> Result<(String, ParsedMaskingStrategy)> {
1081 let tokens: Vec<&str> = header.split_whitespace().collect();
1082 if tokens.len() < 3 {
1083 return Err(QueryError::ParseError(
1084 "expected: <name> STRATEGY <kind> [<arg>]".to_string(),
1085 ));
1086 }
1087 let name = tokens[0].to_string();
1088 if name.is_empty() {
1089 return Err(QueryError::ParseError(
1090 "policy name must not be empty".to_string(),
1091 ));
1092 }
1093 if !tokens[1].eq_ignore_ascii_case("STRATEGY") {
1094 return Err(QueryError::ParseError(format!(
1095 "expected STRATEGY after policy name, got '{}'",
1096 tokens[1]
1097 )));
1098 }
1099
1100 let strategy = parse_masking_strategy(&tokens[2..])?;
1101 Ok((name, strategy))
1102}
1103
1104fn parse_masking_strategy(tokens: &[&str]) -> Result<ParsedMaskingStrategy> {
1106 debug_assert!(
1107 !tokens.is_empty(),
1108 "caller must pass at least the strategy keyword"
1109 );
1110 let kind = tokens[0].to_ascii_uppercase();
1111 match kind.as_str() {
1112 "REDACT_SSN" => {
1113 expect_no_strategy_arg(tokens, "REDACT_SSN").map(|()| ParsedMaskingStrategy::RedactSsn)
1114 }
1115 "REDACT_PHONE" => expect_no_strategy_arg(tokens, "REDACT_PHONE")
1116 .map(|()| ParsedMaskingStrategy::RedactPhone),
1117 "REDACT_EMAIL" => expect_no_strategy_arg(tokens, "REDACT_EMAIL")
1118 .map(|()| ParsedMaskingStrategy::RedactEmail),
1119 "REDACT_CC" => expect_no_strategy_arg(tokens, "REDACT_CC")
1120 .map(|()| ParsedMaskingStrategy::RedactCreditCard),
1121 "REDACT_CUSTOM" => {
1122 if tokens.len() != 2 {
1123 return Err(QueryError::ParseError(
1124 "REDACT_CUSTOM requires a single quoted replacement string".to_string(),
1125 ));
1126 }
1127 let replacement = unquote_string_literal(tokens[1]).ok_or_else(|| {
1128 QueryError::ParseError(
1129 "REDACT_CUSTOM replacement must be a single-quoted string".to_string(),
1130 )
1131 })?;
1132 Ok(ParsedMaskingStrategy::RedactCustom { replacement })
1133 }
1134 "HASH" => {
1135 expect_no_strategy_arg(tokens, "HASH")?;
1136 Ok(ParsedMaskingStrategy::Hash)
1137 }
1138 "TOKENIZE" => {
1139 expect_no_strategy_arg(tokens, "TOKENIZE")?;
1140 Ok(ParsedMaskingStrategy::Tokenize)
1141 }
1142 "TRUNCATE" => {
1143 if tokens.len() != 2 {
1144 return Err(QueryError::ParseError(
1145 "TRUNCATE requires a positive integer character count".to_string(),
1146 ));
1147 }
1148 let max_chars = tokens[1].parse::<usize>().map_err(|_| {
1149 QueryError::ParseError(format!(
1150 "TRUNCATE argument must be a non-negative integer, got '{}'",
1151 tokens[1]
1152 ))
1153 })?;
1154 if max_chars == 0 {
1155 return Err(QueryError::ParseError(
1156 "TRUNCATE character count must be > 0".to_string(),
1157 ));
1158 }
1159 Ok(ParsedMaskingStrategy::Truncate { max_chars })
1160 }
1161 "NULL" => {
1162 expect_no_strategy_arg(tokens, "NULL")?;
1163 Ok(ParsedMaskingStrategy::Null)
1164 }
1165 _ => Err(QueryError::ParseError(format!(
1166 "unknown masking strategy '{kind}' — expected one of REDACT_SSN, REDACT_PHONE, \
1167 REDACT_EMAIL, REDACT_CC, REDACT_CUSTOM, HASH, TOKENIZE, TRUNCATE, NULL"
1168 ))),
1169 }
1170}
1171
1172fn expect_no_strategy_arg(tokens: &[&str], kind: &str) -> Result<()> {
1173 if tokens.len() != 1 {
1174 return Err(QueryError::ParseError(format!(
1175 "{kind} takes no arguments (found {} extra token(s))",
1176 tokens.len() - 1
1177 )));
1178 }
1179 Ok(())
1180}
1181
/// Strips one pair of surrounding single quotes from `token`.
///
/// Returns `None` unless the token both starts and ends with `'` (a lone
/// `'` does not count). Quotes inside the literal are left untouched.
fn unquote_string_literal(token: &str) -> Option<String> {
    let inner = token.strip_prefix('\'')?.strip_suffix('\'')?;
    Some(inner.to_owned())
}
1191
1192fn parse_exempt_roles_list(tail: &str) -> Result<Vec<String>> {
1199 let trimmed = tail.trim();
1200 if !trimmed.starts_with('(') || !trimmed.ends_with(')') {
1201 return Err(QueryError::ParseError(
1202 "EXEMPT ROLES must be followed by a parenthesised list: EXEMPT ROLES (r1, r2, ...)"
1203 .to_string(),
1204 ));
1205 }
1206 let inner = &trimmed[1..trimmed.len() - 1];
1207 let roles: Vec<String> = inner
1208 .split(',')
1209 .map(|s| s.trim().trim_matches('\'').to_ascii_lowercase())
1210 .filter(|s| !s.is_empty())
1211 .collect();
1212 if roles.is_empty() {
1213 return Err(QueryError::ParseError(
1214 "EXEMPT ROLES list must contain at least one role".to_string(),
1215 ));
1216 }
1217 Ok(roles)
1218}
1219
1220fn parse_alter_masking_policy(trimmed: &str) -> Result<ParsedStatement> {
1222 let tokens: Vec<&str> = trimmed.split_whitespace().collect();
1223 if tokens.len() < 9 || tokens.len() > 10 {
1226 return Err(QueryError::ParseError(
1227 "expected: ALTER TABLE <t> ALTER COLUMN <c> { SET | DROP } MASKING POLICY [<name>]"
1228 .to_string(),
1229 ));
1230 }
1231 if !tokens[0].eq_ignore_ascii_case("ALTER")
1232 || !tokens[1].eq_ignore_ascii_case("TABLE")
1233 || !tokens[3].eq_ignore_ascii_case("ALTER")
1234 || !tokens[4].eq_ignore_ascii_case("COLUMN")
1235 || !tokens[7].eq_ignore_ascii_case("MASKING")
1236 || !tokens[8].eq_ignore_ascii_case("POLICY")
1237 {
1238 return Err(QueryError::ParseError(format!(
1239 "malformed ALTER ... MASKING POLICY statement: '{trimmed}'"
1240 )));
1241 }
1242 let table_name = tokens[2].to_string();
1243 let column_name = tokens[5].to_string();
1244 let action = tokens[6].to_ascii_uppercase();
1245 match action.as_str() {
1246 "SET" => {
1247 if tokens.len() != 10 {
1248 return Err(QueryError::ParseError(
1249 "SET MASKING POLICY requires a policy name".to_string(),
1250 ));
1251 }
1252 Ok(ParsedStatement::AttachMaskingPolicy(
1253 ParsedAttachMaskingPolicy {
1254 table_name,
1255 column_name,
1256 policy_name: tokens[9].to_string(),
1257 },
1258 ))
1259 }
1260 "DROP" => {
1261 if tokens.len() != 9 {
1262 return Err(QueryError::ParseError(
1263 "DROP MASKING POLICY takes no arguments after POLICY".to_string(),
1264 ));
1265 }
1266 Ok(ParsedStatement::DetachMaskingPolicy(
1267 ParsedDetachMaskingPolicy {
1268 table_name,
1269 column_name,
1270 },
1271 ))
1272 }
1273 _ => Err(QueryError::ParseError(format!(
1274 "expected SET or DROP after column name, got '{action}'"
1275 ))),
1276 }
1277}
1278
/// Point-in-time selector extracted from the tail of a SQL string by
/// `extract_time_travel`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TimeTravel {
    /// The number given in a trailing `AT OFFSET <n>` clause.
    Offset(u64),
    /// The RFC 3339 timestamp of a trailing `AS OF '<ts>'` /
    /// `FOR SYSTEM_TIME AS OF '<ts>'` clause, in nanoseconds since the Unix
    /// epoch.
    TimestampNs(i64),
}
1303
/// Splits a trailing `AT OFFSET <n>` clause off a SQL string.
///
/// The last occurrence of `AT OFFSET` (case-insensitive) is considered. It
/// is accepted only when it:
///
/// * starts the string or follows a space, tab, CR or LF,
/// * is separated from the number by at least one whitespace character,
/// * is followed by a run of ASCII digits that fits in a `u64`, and
/// * is followed by nothing else except whitespace and at most one `;`.
///
/// # Returns
///
/// `(sql_without_clause, Some(n))` when the clause is accepted; the returned
/// SQL has trailing whitespace removed. Otherwise `(sql.to_string(), None)`.
pub fn extract_at_offset(sql: &str) -> (String, Option<u64>) {
    const KEYWORD: &str = "AT OFFSET";
    let unchanged = || -> (String, Option<u64>) { (sql.to_string(), None) };

    // ASCII upper-casing keeps byte offsets identical to `sql`.
    let upper = sql.to_ascii_uppercase();
    let Some(at_pos) = upper.rfind(KEYWORD) else {
        return unchanged();
    };

    // Left word boundary: reject matches glued to a preceding token
    // (e.g. `...CAT OFFSET 5`).
    if at_pos > 0 && !matches!(sql.as_bytes()[at_pos - 1], b' ' | b'\t' | b'\n' | b'\r') {
        return unchanged();
    }

    // Right word boundary: the keyword must be followed by whitespace, so
    // `AT OFFSET5` or `AT OFFSETS 5` is not treated as the clause.
    let after_keyword = &sql[at_pos + KEYWORD.len()..];
    if !after_keyword.starts_with(char::is_whitespace) {
        return unchanged();
    }
    let after_keyword = after_keyword.trim_start();

    let digits_end = after_keyword
        .find(|c: char| !c.is_ascii_digit())
        .unwrap_or(after_keyword.len());
    if digits_end == 0 {
        return unchanged();
    }

    let (digits, tail) = after_keyword.split_at(digits_end);
    // Fails only on overflow, since `digits` is non-empty and all digits.
    let Ok(offset) = digits.parse::<u64>() else {
        return unchanged();
    };

    // The clause must be the last thing in the statement.
    let tail = tail.trim();
    if !tail.is_empty() && tail != ";" {
        return unchanged();
    }

    (sql[..at_pos].trim_end().to_string(), Some(offset))
}
1376
1377pub fn extract_time_travel(sql: &str) -> (String, Option<TimeTravel>) {
1405 let (after_offset_sql, offset) = extract_at_offset(sql);
1407 if let Some(o) = offset {
1408 return (after_offset_sql, Some(TimeTravel::Offset(o)));
1409 }
1410
1411 let upper = sql.to_ascii_uppercase();
1417
1418 let (keyword_pos, keyword_len) = if let Some(p) = upper.rfind("FOR SYSTEM_TIME AS OF") {
1424 (p, "FOR SYSTEM_TIME AS OF".len())
1425 } else if let Some(p) = upper.rfind("AS OF") {
1426 let after = sql[p + "AS OF".len()..].trim_start();
1429 if !after.starts_with('\'') {
1430 return (sql.to_string(), None);
1431 }
1432 (p, "AS OF".len())
1433 } else {
1434 return (sql.to_string(), None);
1435 };
1436
1437 if keyword_pos > 0 {
1439 let prev = sql.as_bytes()[keyword_pos - 1];
1440 if !matches!(prev, b' ' | b'\t' | b'\n' | b'\r') {
1441 return (sql.to_string(), None);
1442 }
1443 }
1444
1445 let after_keyword = sql[keyword_pos + keyword_len..].trim_start();
1446 if !after_keyword.starts_with('\'') {
1447 return (sql.to_string(), None);
1448 }
1449
1450 let ts_start = 1; let ts_end = match after_keyword[1..].find('\'') {
1455 Some(i) => i + 1,
1456 None => return (sql.to_string(), None),
1457 };
1458 let ts_str = &after_keyword[ts_start..ts_end];
1459
1460 let ts_ns = match chrono::DateTime::parse_from_rfc3339(ts_str) {
1462 Ok(dt) => dt.timestamp_nanos_opt(),
1463 Err(_) => return (sql.to_string(), None),
1464 };
1465 let ts_ns = match ts_ns {
1466 Some(n) => n,
1467 None => return (sql.to_string(), None),
1468 };
1469
1470 let remainder = after_keyword[ts_end + 1..].trim();
1472 if !remainder.is_empty() && remainder != ";" {
1473 return (sql.to_string(), None);
1474 }
1475
1476 let before = sql[..keyword_pos].trim_end();
1477 (before.to_string(), Some(TimeTravel::TimestampNs(ts_ns)))
1478}
1479
1480fn parse_set_classification(sql: &str) -> Result<Option<ParsedStatement>> {
1485 let tokens: Vec<&str> = sql.split_whitespace().collect();
1486 if tokens.len() != 9 {
1489 return Err(QueryError::ParseError(
1490 "expected: ALTER TABLE <table> MODIFY COLUMN <column> SET CLASSIFICATION '<class>'"
1491 .to_string(),
1492 ));
1493 }
1494
1495 if !tokens[3].eq_ignore_ascii_case("MODIFY") {
1496 return Err(QueryError::ParseError(format!(
1497 "expected MODIFY, got '{}'",
1498 tokens[3]
1499 )));
1500 }
1501 if !tokens[4].eq_ignore_ascii_case("COLUMN") {
1502 return Err(QueryError::ParseError(format!(
1503 "expected COLUMN after MODIFY, got '{}'",
1504 tokens[4]
1505 )));
1506 }
1507 if !tokens[6].eq_ignore_ascii_case("SET") {
1508 return Err(QueryError::ParseError(format!(
1509 "expected SET, got '{}'",
1510 tokens[6]
1511 )));
1512 }
1513 if !tokens[7].eq_ignore_ascii_case("CLASSIFICATION") {
1514 return Err(QueryError::ParseError(format!(
1515 "expected CLASSIFICATION, got '{}'",
1516 tokens[7]
1517 )));
1518 }
1519
1520 let table_name = tokens[2].to_string();
1521 let column_name = tokens[5].to_string();
1522
1523 let raw_class = tokens[8];
1525 let classification = raw_class
1526 .strip_prefix('\'')
1527 .and_then(|s| s.strip_suffix('\''))
1528 .ok_or_else(|| {
1529 QueryError::ParseError(format!(
1530 "classification must be quoted with single quotes, got '{raw_class}'"
1531 ))
1532 })?
1533 .to_string();
1534
1535 assert!(!table_name.is_empty(), "table name must not be empty");
1536 assert!(!column_name.is_empty(), "column name must not be empty");
1537 assert!(
1538 !classification.is_empty(),
1539 "classification must not be empty"
1540 );
1541
1542 Ok(Some(ParsedStatement::SetClassification(
1543 ParsedSetClassification {
1544 table_name,
1545 column_name,
1546 classification,
1547 },
1548 )))
1549}
1550
1551fn parse_grant(
1553 privileges: &sqlparser::ast::Privileges,
1554 objects: &sqlparser::ast::GrantObjects,
1555 grantees: &[sqlparser::ast::Grantee],
1556) -> Result<ParsedStatement> {
1557 use sqlparser::ast::{Action, GrantObjects, GranteeName, Privileges};
1558
1559 let columns = match privileges {
1561 Privileges::Actions(actions) => {
1562 let mut cols = None;
1563 for action in actions {
1564 if let Action::Select { columns: Some(c) } = action {
1565 cols = Some(c.iter().map(|i| i.value.clone()).collect());
1566 }
1567 }
1568 cols
1569 }
1570 Privileges::All { .. } => None,
1571 };
1572
1573 let table_name = match objects {
1575 GrantObjects::Tables(tables) => {
1576 if tables.len() != 1 {
1577 return Err(QueryError::ParseError(
1578 "expected exactly 1 table in GRANT".to_string(),
1579 ));
1580 }
1581 object_name_to_string(&tables[0])
1582 }
1583 _ => {
1584 return Err(QueryError::UnsupportedFeature(
1585 "GRANT only supports table-level privileges".to_string(),
1586 ));
1587 }
1588 };
1589
1590 if grantees.len() != 1 {
1592 return Err(QueryError::ParseError(
1593 "expected exactly 1 grantee in GRANT".to_string(),
1594 ));
1595 }
1596 let role_name = match &grantees[0].name {
1597 Some(GranteeName::ObjectName(name)) => object_name_to_string(name),
1598 _ => {
1599 return Err(QueryError::ParseError(
1600 "expected a role name in GRANT".to_string(),
1601 ));
1602 }
1603 };
1604
1605 Ok(ParsedStatement::Grant(ParsedGrant {
1606 columns,
1607 table_name,
1608 role_name,
1609 }))
1610}
1611
1612fn parse_query_to_statement(query: &Query) -> Result<ParsedStatement> {
1614 let ctes = match &query.with {
1616 Some(with) => parse_ctes(with)?,
1617 None => vec![],
1618 };
1619
1620 match query.body.as_ref() {
1621 SetExpr::Select(select) => {
1622 let parsed_select = parse_select(select)?;
1623
1624 let order_by = match &query.order_by {
1626 Some(ob) => parse_order_by(ob)?,
1627 None => vec![],
1628 };
1629
1630 let limit = parse_limit(query_limit_expr(query)?)?;
1632 let offset = parse_offset_clause(query_offset(query))?;
1633
1634 let mut all_ctes = ctes;
1636 all_ctes.extend(parsed_select.ctes);
1637
1638 Ok(ParsedStatement::Select(ParsedSelect {
1639 table: parsed_select.table,
1640 joins: parsed_select.joins,
1641 columns: parsed_select.columns,
1642 column_aliases: parsed_select.column_aliases,
1643 case_columns: parsed_select.case_columns,
1644 predicates: parsed_select.predicates,
1645 order_by,
1646 limit,
1647 offset,
1648 aggregates: parsed_select.aggregates,
1649 aggregate_filters: parsed_select.aggregate_filters,
1650 group_by: parsed_select.group_by,
1651 distinct: parsed_select.distinct,
1652 having: parsed_select.having,
1653 ctes: all_ctes,
1654 window_fns: parsed_select.window_fns,
1655 scalar_projections: parsed_select.scalar_projections,
1656 }))
1657 }
1658 SetExpr::SetOperation {
1659 op,
1660 set_quantifier,
1661 left,
1662 right,
1663 } => {
1664 use sqlparser::ast::SetOperator;
1665 use sqlparser::ast::SetQuantifier;
1666
1667 let parsed_op = match op {
1668 SetOperator::Union => SetOp::Union,
1669 SetOperator::Intersect => SetOp::Intersect,
1670 SetOperator::Except | SetOperator::Minus => SetOp::Except,
1672 };
1673
1674 let all = matches!(set_quantifier, SetQuantifier::All);
1675
1676 let left_select = match left.as_ref() {
1678 SetExpr::Select(s) => parse_select(s)?,
1679 _ => {
1680 return Err(QueryError::UnsupportedFeature(
1681 "nested set operations not supported".to_string(),
1682 ));
1683 }
1684 };
1685 let right_select = match right.as_ref() {
1686 SetExpr::Select(s) => parse_select(s)?,
1687 _ => {
1688 return Err(QueryError::UnsupportedFeature(
1689 "nested set operations not supported".to_string(),
1690 ));
1691 }
1692 };
1693
1694 Ok(ParsedStatement::Union(ParsedUnion {
1695 op: parsed_op,
1696 left: left_select,
1697 right: right_select,
1698 all,
1699 }))
1700 }
1701 other => Err(QueryError::UnsupportedFeature(format!(
1702 "unsupported query type: {other:?}"
1703 ))),
1704 }
1705}
1706
1707fn parse_join_with_subqueries(join: &sqlparser::ast::Join) -> Result<(ParsedJoin, Vec<ParsedCte>)> {
1709 use sqlparser::ast::{JoinConstraint, JoinOperator};
1710
1711 let join_type = match &join.join_operator {
1718 JoinOperator::Inner(_) | JoinOperator::Join(_) => JoinType::Inner,
1719 JoinOperator::LeftOuter(_) | JoinOperator::Left(_) => JoinType::Left,
1720 JoinOperator::RightOuter(_) | JoinOperator::Right(_) => JoinType::Right,
1721 JoinOperator::FullOuter(_) => JoinType::Full,
1722 JoinOperator::CrossJoin(_) => JoinType::Cross,
1723 other => {
1724 return Err(QueryError::UnsupportedFeature(format!(
1725 "join type not supported: {other:?}"
1726 )));
1727 }
1728 };
1729
1730 let mut inline_ctes = Vec::new();
1732 let table = match &join.relation {
1733 sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
1734 sqlparser::ast::TableFactor::Derived {
1735 subquery, alias, ..
1736 } => {
1737 let alias_name = alias
1738 .as_ref()
1739 .map(|a| a.name.value.clone())
1740 .ok_or_else(|| {
1741 QueryError::ParseError("subquery in JOIN requires an alias".to_string())
1742 })?;
1743
1744 let inner = match subquery.body.as_ref() {
1746 SetExpr::Select(s) => parse_select(s)?,
1747 _ => {
1748 return Err(QueryError::UnsupportedFeature(
1749 "subquery body must be a simple SELECT".to_string(),
1750 ));
1751 }
1752 };
1753
1754 let order_by = match &subquery.order_by {
1755 Some(ob) => parse_order_by(ob)?,
1756 None => vec![],
1757 };
1758 let limit = parse_limit(query_limit_expr(subquery)?)?;
1759
1760 inline_ctes.push(ParsedCte {
1761 name: alias_name.clone(),
1762 query: ParsedSelect {
1763 order_by,
1764 limit,
1765 ..inner
1766 },
1767 recursive_arm: None,
1768 });
1769
1770 alias_name
1771 }
1772 _ => {
1773 return Err(QueryError::UnsupportedFeature(
1774 "unsupported JOIN relation type".to_string(),
1775 ));
1776 }
1777 };
1778
1779 let on_condition = match &join.join_operator {
1781 JoinOperator::CrossJoin(_) => Vec::new(),
1782 JoinOperator::Inner(constraint)
1783 | JoinOperator::Join(constraint)
1784 | JoinOperator::LeftOuter(constraint)
1785 | JoinOperator::Left(constraint)
1786 | JoinOperator::RightOuter(constraint)
1787 | JoinOperator::Right(constraint)
1788 | JoinOperator::FullOuter(constraint) => match constraint {
1789 JoinConstraint::On(expr) => parse_join_condition(expr)?,
1790 JoinConstraint::Using(idents) => {
1791 let mut preds = Vec::new();
1796 for name in idents {
1797 if name.0.len() != 1 {
1798 return Err(QueryError::UnsupportedFeature(format!(
1799 "USING column must be a bare identifier, got {name}"
1800 )));
1801 }
1802 let col_name = name.0[0]
1803 .as_ident()
1804 .ok_or_else(|| {
1805 QueryError::UnsupportedFeature(format!(
1806 "USING column must be a bare identifier, got {name}"
1807 ))
1808 })?
1809 .value
1810 .clone();
1811 preds.push(Predicate::Eq(
1812 ColumnName::new(col_name.clone()),
1813 PredicateValue::ColumnRef(col_name),
1814 ));
1815 }
1816 preds
1817 }
1818 JoinConstraint::Natural => {
1819 return Err(QueryError::UnsupportedFeature(
1820 "NATURAL JOIN is not supported; use ON or USING explicitly".to_string(),
1821 ));
1822 }
1823 JoinConstraint::None => {
1824 return Err(QueryError::UnsupportedFeature(
1825 "join without ON or USING clause not supported".to_string(),
1826 ));
1827 }
1828 },
1829 _ => {
1830 return Err(QueryError::UnsupportedFeature(
1831 "join without ON clause not supported".to_string(),
1832 ));
1833 }
1834 };
1835
1836 Ok((
1837 ParsedJoin {
1838 table,
1839 join_type,
1840 on_condition,
1841 },
1842 inline_ctes,
1843 ))
1844}
1845
1846fn parse_join_condition(expr: &Expr) -> Result<Vec<Predicate>> {
1849 match expr {
1850 Expr::BinaryOp {
1851 left,
1852 op: BinaryOperator::And,
1853 right,
1854 } => {
1855 let mut predicates = parse_join_condition(left)?;
1856 predicates.extend(parse_join_condition(right)?);
1857 Ok(predicates)
1858 }
1859 _ => {
1860 parse_where_expr(expr)
1862 }
1863 }
1864}
1865
1866fn parse_select(select: &Select) -> Result<ParsedSelect> {
1867 let distinct = select.distinct.is_some();
1869
1870 if select.from.len() != 1 {
1872 return Err(QueryError::ParseError(format!(
1873 "expected exactly 1 table in FROM clause, got {}",
1874 select.from.len()
1875 )));
1876 }
1877
1878 let from = &select.from[0];
1879
1880 let mut inline_ctes = Vec::new();
1882
1883 let mut joins = Vec::new();
1885 for join in &from.joins {
1886 let (parsed_join, join_ctes) = parse_join_with_subqueries(join)?;
1887 joins.push(parsed_join);
1888 inline_ctes.extend(join_ctes);
1889 }
1890
1891 let table = match &from.relation {
1892 sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
1893 sqlparser::ast::TableFactor::Derived {
1894 subquery, alias, ..
1895 } => {
1896 let alias_name = alias
1897 .as_ref()
1898 .map(|a| a.name.value.clone())
1899 .ok_or_else(|| {
1900 QueryError::ParseError("subquery in FROM requires an alias".to_string())
1901 })?;
1902
1903 let inner = match subquery.body.as_ref() {
1905 SetExpr::Select(s) => parse_select(s)?,
1906 _ => {
1907 return Err(QueryError::UnsupportedFeature(
1908 "subquery body must be a simple SELECT".to_string(),
1909 ));
1910 }
1911 };
1912
1913 let order_by = match &subquery.order_by {
1914 Some(ob) => parse_order_by(ob)?,
1915 None => vec![],
1916 };
1917 let limit = parse_limit(query_limit_expr(subquery)?)?;
1918
1919 inline_ctes.push(ParsedCte {
1920 name: alias_name.clone(),
1921 query: ParsedSelect {
1922 order_by,
1923 limit,
1924 ..inner
1925 },
1926 recursive_arm: None,
1927 });
1928
1929 alias_name
1930 }
1931 other => {
1932 return Err(QueryError::UnsupportedFeature(format!(
1933 "unsupported FROM clause: {other:?}"
1934 )));
1935 }
1936 };
1937
1938 let (columns, column_aliases) = parse_select_items(&select.projection)?;
1940
1941 let case_columns = parse_case_columns_from_select_items(&select.projection)?;
1943
1944 let predicates = match &select.selection {
1946 Some(expr) => parse_where_expr(expr)?,
1947 None => vec![],
1948 };
1949
1950 let group_by = match &select.group_by {
1952 sqlparser::ast::GroupByExpr::Expressions(exprs, _) if !exprs.is_empty() => {
1953 parse_group_by_expr(exprs)?
1954 }
1955 sqlparser::ast::GroupByExpr::All(_) => {
1956 return Err(QueryError::UnsupportedFeature(
1957 "GROUP BY ALL is not supported".to_string(),
1958 ));
1959 }
1960 sqlparser::ast::GroupByExpr::Expressions(_, _) => vec![],
1961 };
1962
1963 let (aggregates, aggregate_filters) = parse_aggregates_from_select_items(&select.projection)?;
1965
1966 let having = match &select.having {
1968 Some(expr) => parse_having_expr(expr)?,
1969 None => vec![],
1970 };
1971
1972 let window_fns = parse_window_fns_from_select_items(&select.projection)?;
1975
1976 let scalar_projections = parse_scalar_columns_from_select_items(&select.projection)?;
1979
1980 Ok(ParsedSelect {
1981 table,
1982 joins,
1983 columns,
1984 column_aliases,
1985 case_columns,
1986 predicates,
1987 order_by: vec![],
1988 limit: None,
1989 offset: None,
1990 aggregates,
1991 aggregate_filters,
1992 group_by,
1993 distinct,
1994 having,
1995 ctes: inline_ctes,
1996 window_fns,
1997 scalar_projections,
1998 })
1999}
2000
2001fn parse_ctes(with: &sqlparser::ast::With) -> Result<Vec<ParsedCte>> {
2008 let max_ctes = 16;
2009 let mut ctes = Vec::new();
2010
2011 for (i, cte) in with.cte_tables.iter().enumerate() {
2012 if i >= max_ctes {
2013 return Err(QueryError::UnsupportedFeature(format!(
2014 "too many CTEs (max {max_ctes})"
2015 )));
2016 }
2017
2018 let name = cte.alias.name.value.clone();
2019
2020 let (inner_select, recursive_arm) = match cte.query.body.as_ref() {
2025 SetExpr::Select(s) => (parse_select(s)?, None),
2026 SetExpr::SetOperation {
2027 op, left, right, ..
2028 } if with.recursive => {
2029 use sqlparser::ast::SetOperator;
2030 if !matches!(op, SetOperator::Union) {
2031 return Err(QueryError::UnsupportedFeature(
2032 "recursive CTE body must use UNION (not INTERSECT/EXCEPT)".to_string(),
2033 ));
2034 }
2035 let anchor = match left.as_ref() {
2036 SetExpr::Select(s) => parse_select(s)?,
2037 _ => {
2038 return Err(QueryError::UnsupportedFeature(
2039 "recursive CTE anchor must be a simple SELECT".to_string(),
2040 ));
2041 }
2042 };
2043 let recursive = match right.as_ref() {
2044 SetExpr::Select(s) => parse_select(s)?,
2045 _ => {
2046 return Err(QueryError::UnsupportedFeature(
2047 "recursive CTE recursive arm must be a simple SELECT".to_string(),
2048 ));
2049 }
2050 };
2051 (anchor, Some(recursive))
2052 }
2053 _ => {
2054 return Err(QueryError::UnsupportedFeature(
2055 "CTE body must be a simple SELECT (or anchor UNION recursive for WITH RECURSIVE)".to_string(),
2056 ));
2057 }
2058 };
2059
2060 let order_by = match &cte.query.order_by {
2062 Some(ob) => parse_order_by(ob)?,
2063 None => vec![],
2064 };
2065 let limit = parse_limit(query_limit_expr(&cte.query)?)?;
2066
2067 ctes.push(ParsedCte {
2068 name,
2069 query: ParsedSelect {
2070 order_by,
2071 limit,
2072 ..inner_select
2073 },
2074 recursive_arm,
2075 });
2076 }
2077
2078 Ok(ctes)
2079}
2080
2081fn parse_having_expr(expr: &Expr) -> Result<Vec<HavingCondition>> {
2086 match expr {
2087 Expr::BinaryOp {
2088 left,
2089 op: BinaryOperator::And,
2090 right,
2091 } => {
2092 let mut conditions = parse_having_expr(left)?;
2093 conditions.extend(parse_having_expr(right)?);
2094 Ok(conditions)
2095 }
2096 Expr::BinaryOp { left, op, right } => {
2097 let aggregate = match left.as_ref() {
2099 Expr::Function(_) => {
2100 let (agg, _filter) = try_parse_aggregate(left)?.ok_or_else(|| {
2101 QueryError::UnsupportedFeature(
2102 "HAVING requires aggregate functions (COUNT, SUM, AVG, MIN, MAX)"
2103 .to_string(),
2104 )
2105 })?;
2106 agg
2107 }
2108 _ => {
2109 return Err(QueryError::UnsupportedFeature(
2110 "HAVING clause must reference aggregate functions".to_string(),
2111 ));
2112 }
2113 };
2114
2115 let value = expr_to_value(right)?;
2117
2118 let having_op = match op {
2120 BinaryOperator::Eq => HavingOp::Eq,
2121 BinaryOperator::Lt => HavingOp::Lt,
2122 BinaryOperator::LtEq => HavingOp::Le,
2123 BinaryOperator::Gt => HavingOp::Gt,
2124 BinaryOperator::GtEq => HavingOp::Ge,
2125 other => {
2126 return Err(QueryError::UnsupportedFeature(format!(
2127 "unsupported HAVING operator: {other:?}"
2128 )));
2129 }
2130 };
2131
2132 Ok(vec![HavingCondition::AggregateComparison {
2133 aggregate,
2134 op: having_op,
2135 value,
2136 }])
2137 }
2138 Expr::Nested(inner) => parse_having_expr(inner),
2139 other => Err(QueryError::UnsupportedFeature(format!(
2140 "unsupported HAVING expression: {other:?}"
2141 ))),
2142 }
2143}
2144
2145type ParsedSelectList = (Option<Vec<ColumnName>>, Option<Vec<Option<String>>>);
2159
2160fn parse_select_items(items: &[SelectItem]) -> Result<ParsedSelectList> {
2161 let mut columns = Vec::new();
2162 let mut aliases: Vec<Option<String>> = Vec::new();
2163
2164 for item in items {
2165 #[allow(clippy::match_same_arms)]
2170 match item {
2171 SelectItem::Wildcard(_) => {
2172 return Ok((None, None));
2175 }
2176 SelectItem::UnnamedExpr(Expr::Identifier(ident)) => {
2177 columns.push(ColumnName::new(ident.value.clone()));
2178 aliases.push(None);
2179 }
2180 SelectItem::UnnamedExpr(Expr::CompoundIdentifier(idents)) if idents.len() == 2 => {
2181 columns.push(ColumnName::new(idents[1].value.clone()));
2183 aliases.push(None);
2184 }
2185 SelectItem::ExprWithAlias {
2186 expr: Expr::Identifier(ident),
2187 alias,
2188 } => {
2189 columns.push(ColumnName::new(ident.value.clone()));
2190 aliases.push(Some(alias.value.clone()));
2191 }
2192 SelectItem::ExprWithAlias {
2193 expr: Expr::CompoundIdentifier(idents),
2194 alias,
2195 } if idents.len() == 2 => {
2196 columns.push(ColumnName::new(idents[1].value.clone()));
2198 aliases.push(Some(alias.value.clone()));
2199 }
2200 SelectItem::UnnamedExpr(Expr::Function(_))
2201 | SelectItem::ExprWithAlias {
2202 expr: Expr::Function(_) | Expr::Case { .. },
2203 ..
2204 } => {
2205 }
2209 SelectItem::UnnamedExpr(Expr::Cast { .. })
2213 | SelectItem::ExprWithAlias {
2214 expr: Expr::Cast { .. },
2215 ..
2216 } => {}
2217 SelectItem::UnnamedExpr(Expr::BinaryOp {
2218 op: BinaryOperator::StringConcat,
2219 ..
2220 })
2221 | SelectItem::ExprWithAlias {
2222 expr:
2223 Expr::BinaryOp {
2224 op: BinaryOperator::StringConcat,
2225 ..
2226 },
2227 ..
2228 } => {}
2229 other => {
2230 return Err(QueryError::UnsupportedFeature(format!(
2231 "unsupported SELECT item: {other:?}"
2232 )));
2233 }
2234 }
2235 }
2236
2237 Ok((Some(columns), Some(aliases)))
2238}
2239
2240type ParsedAggregateList = (Vec<AggregateFunction>, Vec<Option<Vec<Predicate>>>);
2248
2249fn parse_aggregates_from_select_items(items: &[SelectItem]) -> Result<ParsedAggregateList> {
2250 let mut aggregates = Vec::new();
2251 let mut filters = Vec::new();
2252
2253 for item in items {
2254 match item {
2255 SelectItem::UnnamedExpr(expr) | SelectItem::ExprWithAlias { expr, .. } => {
2256 if let Some((agg, filter)) = try_parse_aggregate(expr)? {
2257 aggregates.push(agg);
2258 filters.push(filter);
2259 }
2260 }
2261 _ => {
2262 }
2264 }
2265 }
2266
2267 Ok((aggregates, filters))
2268}
2269
2270fn parse_case_columns_from_select_items(items: &[SelectItem]) -> Result<Vec<ComputedColumn>> {
2275 let mut case_cols = Vec::new();
2276
2277 for item in items {
2278 if let SelectItem::ExprWithAlias {
2279 expr:
2280 Expr::Case {
2281 operand,
2282 conditions,
2283 else_result,
2284 ..
2285 },
2286 alias,
2287 } = item
2288 {
2289 let mut when_clauses = Vec::new();
2294 for case_when in conditions {
2295 let cond_expr = &case_when.condition;
2296 let result_expr = &case_when.result;
2297 let condition = match operand.as_deref() {
2298 None => parse_where_expr(cond_expr)?,
2299 Some(operand_expr) => parse_where_expr(&Expr::BinaryOp {
2300 left: Box::new(operand_expr.clone()),
2301 op: BinaryOperator::Eq,
2302 right: Box::new(cond_expr.clone()),
2303 })?,
2304 };
2305 let result = expr_to_value(result_expr)?;
2306 when_clauses.push(CaseWhenArm { condition, result });
2307 }
2308
2309 let else_value = match else_result {
2310 Some(expr) => expr_to_value(expr)?,
2311 None => Value::Null,
2312 };
2313
2314 case_cols.push(ComputedColumn {
2315 alias: ColumnName::new(alias.value.clone()),
2316 when_clauses,
2317 else_value,
2318 });
2319 }
2320 }
2321
2322 Ok(case_cols)
2323}
2324
2325fn parse_scalar_columns_from_select_items(
2335 items: &[SelectItem],
2336) -> Result<Vec<ParsedScalarProjection>> {
2337 let mut out = Vec::new();
2338 for item in items {
2339 let (expr, alias) = match item {
2340 SelectItem::UnnamedExpr(e) => (e, None),
2341 SelectItem::ExprWithAlias { expr, alias } => (expr, Some(alias.value.clone())),
2342 _ => continue,
2343 };
2344
2345 if !is_scalar_projection_shape(expr) {
2346 continue;
2347 }
2348
2349 let scalar = expr_to_scalar_expr(expr)?;
2350 let output_name = alias
2351 .clone()
2352 .unwrap_or_else(|| synthesize_column_name(expr));
2353 out.push(ParsedScalarProjection {
2354 expr: scalar,
2355 output_name: ColumnName::new(output_name),
2356 alias,
2357 });
2358 }
2359 Ok(out)
2360}
2361
2362fn is_scalar_projection_shape(expr: &Expr) -> bool {
2366 match expr {
2367 Expr::Function(func) => {
2368 if func.over.is_some() {
2370 return false;
2371 }
2372 let name = func.name.to_string().to_uppercase();
2373 !matches!(name.as_str(), "COUNT" | "SUM" | "AVG" | "MIN" | "MAX")
2374 }
2375 Expr::Cast { .. }
2376 | Expr::BinaryOp {
2377 op: BinaryOperator::StringConcat,
2378 ..
2379 } => true,
2380 _ => false,
2381 }
2382}
2383
2384fn synthesize_column_name(expr: &Expr) -> String {
2388 match expr {
2389 Expr::Function(func) => func.name.to_string().to_lowercase(),
2390 Expr::Cast { .. } => "cast".to_string(),
2391 Expr::BinaryOp {
2392 op: BinaryOperator::StringConcat,
2393 ..
2394 } => "concat".to_string(),
2395 _ => "expr".to_string(),
2396 }
2397}
2398
2399fn parse_window_fns_from_select_items(items: &[SelectItem]) -> Result<Vec<ParsedWindowFn>> {
2403 let mut out = Vec::new();
2404 for item in items {
2405 let (expr, alias) = match item {
2406 SelectItem::UnnamedExpr(e) => (e, None),
2407 SelectItem::ExprWithAlias { expr, alias } => (expr, Some(alias.value.clone())),
2408 _ => continue,
2409 };
2410 if let Some(parsed) = try_parse_window_fn(expr, alias)? {
2411 out.push(parsed);
2412 }
2413 }
2414 Ok(out)
2415}
2416
2417fn try_parse_window_fn(expr: &Expr, alias: Option<String>) -> Result<Option<ParsedWindowFn>> {
2418 let Expr::Function(func) = expr else {
2419 return Ok(None);
2420 };
2421 let Some(over) = &func.over else {
2422 return Ok(None);
2423 };
2424 let spec = match over {
2425 sqlparser::ast::WindowType::WindowSpec(s) => s,
2426 sqlparser::ast::WindowType::NamedWindow(_) => {
2427 return Err(QueryError::UnsupportedFeature(
2428 "named windows (OVER w) are not supported".into(),
2429 ));
2430 }
2431 };
2432 if spec.window_frame.is_some() {
2433 return Err(QueryError::UnsupportedFeature(
2434 "explicit window frames (ROWS/RANGE BETWEEN ...) are not supported; \
2435 omit the frame clause for default behaviour"
2436 .into(),
2437 ));
2438 }
2439
2440 let func_name = func.name.to_string().to_uppercase();
2441 let args = match &func.args {
2442 sqlparser::ast::FunctionArguments::List(list) => list.args.clone(),
2443 _ => Vec::new(),
2444 };
2445 let function = parse_window_function_name(&func_name, &args)?;
2446
2447 let partition_by: Vec<ColumnName> = spec
2448 .partition_by
2449 .iter()
2450 .map(parse_column_expr)
2451 .collect::<Result<_>>()?;
2452 let order_by: Vec<OrderByClause> = spec
2453 .order_by
2454 .iter()
2455 .map(parse_order_by_expr)
2456 .collect::<Result<_>>()?;
2457
2458 Ok(Some(ParsedWindowFn {
2459 function,
2460 partition_by,
2461 order_by,
2462 alias,
2463 }))
2464}
2465
2466fn parse_column_expr(expr: &Expr) -> Result<ColumnName> {
2467 match expr {
2468 Expr::Identifier(ident) => Ok(ColumnName::new(ident.value.clone())),
2469 Expr::CompoundIdentifier(idents) if idents.len() == 2 => {
2470 Ok(ColumnName::new(idents[1].value.clone()))
2471 }
2472 other => Err(QueryError::UnsupportedFeature(format!(
2473 "window PARTITION BY / argument must be a column reference, got: {other:?}"
2474 ))),
2475 }
2476}
2477
2478fn parse_window_function_name(
2479 name: &str,
2480 args: &[sqlparser::ast::FunctionArg],
2481) -> Result<crate::window::WindowFunction> {
2482 use crate::window::WindowFunction;
2483
2484 let arg_exprs: Vec<&Expr> = args
2485 .iter()
2486 .filter_map(|a| match a {
2487 sqlparser::ast::FunctionArg::Unnamed(sqlparser::ast::FunctionArgExpr::Expr(e)) => {
2488 Some(e)
2489 }
2490 _ => None,
2491 })
2492 .collect();
2493
2494 let single_col = || -> Result<ColumnName> {
2495 if arg_exprs.is_empty() {
2496 return Err(QueryError::ParseError(format!(
2497 "{name} requires a column argument"
2498 )));
2499 }
2500 parse_column_expr(arg_exprs[0])
2501 };
2502
2503 let parse_offset = || -> Result<usize> {
2504 if arg_exprs.len() < 2 {
2505 return Ok(1);
2506 }
2507 match arg_exprs[1] {
2508 Expr::Value(vws) => match &vws.value {
2509 SqlValue::Number(n, _) => n
2510 .parse::<usize>()
2511 .map_err(|_| QueryError::ParseError(format!("invalid {name} offset: {n}"))),
2512 other => Err(QueryError::UnsupportedFeature(format!(
2513 "{name} offset must be a literal integer; got {other:?}"
2514 ))),
2515 },
2516 other => Err(QueryError::UnsupportedFeature(format!(
2517 "{name} offset must be a literal integer; got {other:?}"
2518 ))),
2519 }
2520 };
2521
2522 match name {
2523 "ROW_NUMBER" => Ok(WindowFunction::RowNumber),
2524 "RANK" => Ok(WindowFunction::Rank),
2525 "DENSE_RANK" => Ok(WindowFunction::DenseRank),
2526 "LAG" => Ok(WindowFunction::Lag {
2527 column: single_col()?,
2528 offset: parse_offset()?,
2529 }),
2530 "LEAD" => Ok(WindowFunction::Lead {
2531 column: single_col()?,
2532 offset: parse_offset()?,
2533 }),
2534 "FIRST_VALUE" => Ok(WindowFunction::FirstValue {
2535 column: single_col()?,
2536 }),
2537 "LAST_VALUE" => Ok(WindowFunction::LastValue {
2538 column: single_col()?,
2539 }),
2540 other => Err(QueryError::UnsupportedFeature(format!(
2541 "unknown window function: {other}"
2542 ))),
2543 }
2544}
2545
2546type ParsedAggregate = (AggregateFunction, Option<Vec<Predicate>>);
2548
2549fn try_parse_aggregate(expr: &Expr) -> Result<Option<ParsedAggregate>> {
2553 let parsed_filter: Option<Vec<Predicate>> = match expr {
2554 Expr::Function(func) => match &func.filter {
2555 Some(filter_expr) => Some(parse_where_expr(filter_expr)?),
2556 None => None,
2557 },
2558 _ => None,
2559 };
2560 let func_only = try_parse_aggregate_func(expr)?;
2561 Ok(func_only.map(|f| (f, parsed_filter)))
2562}
2563
2564fn try_parse_aggregate_func(expr: &Expr) -> Result<Option<AggregateFunction>> {
2566 match expr {
2567 Expr::Function(func) => {
2568 if func.over.is_some() {
2573 return Ok(None);
2574 }
2575 let func_name = func.name.to_string().to_uppercase();
2576
2577 let args = match &func.args {
2579 sqlparser::ast::FunctionArguments::List(list) => &list.args,
2580 _ => {
2581 return Err(QueryError::UnsupportedFeature(
2582 "non-list function arguments not supported".to_string(),
2583 ));
2584 }
2585 };
2586
2587 match func_name.as_str() {
2588 "COUNT" => {
2589 if args.len() == 1 {
2591 match &args[0] {
2592 sqlparser::ast::FunctionArg::Unnamed(arg_expr) => match arg_expr {
2593 sqlparser::ast::FunctionArgExpr::Wildcard => {
2594 Ok(Some(AggregateFunction::CountStar))
2595 }
2596 sqlparser::ast::FunctionArgExpr::Expr(Expr::Identifier(ident)) => {
2597 Ok(Some(AggregateFunction::Count(ColumnName::new(
2598 ident.value.clone(),
2599 ))))
2600 }
2601 _ => Err(QueryError::UnsupportedFeature(
2602 "COUNT with complex expression not supported".to_string(),
2603 )),
2604 },
2605 _ => Err(QueryError::UnsupportedFeature(
2606 "named function arguments not supported".to_string(),
2607 )),
2608 }
2609 } else {
2610 Err(QueryError::ParseError(format!(
2611 "COUNT expects 1 argument, got {}",
2612 args.len()
2613 )))
2614 }
2615 }
2616 "SUM" | "AVG" | "MIN" | "MAX" => {
2617 if args.len() != 1 {
2619 return Err(QueryError::ParseError(format!(
2620 "{} expects 1 argument, got {}",
2621 func_name,
2622 args.len()
2623 )));
2624 }
2625
2626 match &args[0] {
2627 sqlparser::ast::FunctionArg::Unnamed(arg_expr) => match arg_expr {
2628 sqlparser::ast::FunctionArgExpr::Expr(Expr::Identifier(ident)) => {
2629 let column = ColumnName::new(ident.value.clone());
2630 match func_name.as_str() {
2631 "SUM" => Ok(Some(AggregateFunction::Sum(column))),
2632 "AVG" => Ok(Some(AggregateFunction::Avg(column))),
2633 "MIN" => Ok(Some(AggregateFunction::Min(column))),
2634 "MAX" => Ok(Some(AggregateFunction::Max(column))),
2635 _ => unreachable!(),
2636 }
2637 }
2638 _ => Err(QueryError::UnsupportedFeature(format!(
2639 "{func_name} with complex expression not supported"
2640 ))),
2641 },
2642 _ => Err(QueryError::UnsupportedFeature(
2643 "named function arguments not supported".to_string(),
2644 )),
2645 }
2646 }
2647 _ => {
2648 Ok(None)
2650 }
2651 }
2652 }
2653 _ => {
2654 Ok(None)
2656 }
2657 }
2658}
2659
2660fn parse_group_by_expr(exprs: &[Expr]) -> Result<Vec<ColumnName>> {
2662 let mut columns = Vec::new();
2663
2664 for expr in exprs {
2665 match expr {
2666 Expr::Identifier(ident) => {
2667 columns.push(ColumnName::new(ident.value.clone()));
2668 }
2669 _ => {
2670 return Err(QueryError::UnsupportedFeature(
2671 "complex GROUP BY expressions not supported".to_string(),
2672 ));
2673 }
2674 }
2675 }
2676
2677 Ok(columns)
2678}
2679
2680const MAX_WHERE_DEPTH: usize = 100;
2688
2689fn parse_where_expr(expr: &Expr) -> Result<Vec<Predicate>> {
2690 parse_where_expr_inner(expr, 0)
2691}
2692
2693fn parse_select_from_query(query: &sqlparser::ast::Query) -> Result<ParsedSelect> {
2699 match query.body.as_ref() {
2700 SetExpr::Select(s) => {
2701 let mut parsed = parse_select(s)?;
2702 if let Some(ob) = &query.order_by {
2703 parsed.order_by = parse_order_by(ob)?;
2704 }
2705 parsed.limit = parse_limit(query_limit_expr(query)?)?;
2706 parsed.offset = parse_offset_clause(query_offset(query))?;
2707 Ok(parsed)
2708 }
2709 _ => Err(QueryError::UnsupportedFeature(
2710 "subquery body must be a simple SELECT (no nested UNION/INTERSECT/EXCEPT)".to_string(),
2711 )),
2712 }
2713}
2714
2715fn parse_where_expr_inner(expr: &Expr, depth: usize) -> Result<Vec<Predicate>> {
2716 if depth >= MAX_WHERE_DEPTH {
2717 return Err(QueryError::ParseError(format!(
2718 "WHERE clause nesting exceeds maximum depth of {MAX_WHERE_DEPTH}"
2719 )));
2720 }
2721
2722 match expr {
2723 Expr::BinaryOp {
2725 left,
2726 op: BinaryOperator::And,
2727 right,
2728 } => {
2729 let mut predicates = parse_where_expr_inner(left, depth + 1)?;
2730 predicates.extend(parse_where_expr_inner(right, depth + 1)?);
2731 Ok(predicates)
2732 }
2733
2734 Expr::BinaryOp {
2736 left,
2737 op: BinaryOperator::Or,
2738 right,
2739 } => {
2740 let left_preds = parse_where_expr_inner(left, depth + 1)?;
2741 let right_preds = parse_where_expr_inner(right, depth + 1)?;
2742 Ok(vec![Predicate::Or(left_preds, right_preds)])
2743 }
2744
2745 Expr::Like {
2747 expr,
2748 pattern,
2749 negated,
2750 ..
2751 } => {
2752 let column = expr_to_column(expr)?;
2753 let pattern_str = match expr_to_predicate_value(pattern)? {
2754 PredicateValue::String(s) | PredicateValue::Literal(Value::Text(s)) => s,
2755 _ => {
2756 return Err(QueryError::UnsupportedFeature(
2757 "LIKE pattern must be a string literal".to_string(),
2758 ));
2759 }
2760 };
2761 let predicate = if *negated {
2762 Predicate::NotLike(column, pattern_str)
2763 } else {
2764 Predicate::Like(column, pattern_str)
2765 };
2766 Ok(vec![predicate])
2767 }
2768
2769 Expr::ILike {
2771 expr,
2772 pattern,
2773 negated,
2774 ..
2775 } => {
2776 let column = expr_to_column(expr)?;
2777 let pattern_str = match expr_to_predicate_value(pattern)? {
2778 PredicateValue::String(s) | PredicateValue::Literal(Value::Text(s)) => s,
2779 _ => {
2780 return Err(QueryError::UnsupportedFeature(
2781 "ILIKE pattern must be a string literal".to_string(),
2782 ));
2783 }
2784 };
2785 let predicate = if *negated {
2786 Predicate::NotILike(column, pattern_str)
2787 } else {
2788 Predicate::ILike(column, pattern_str)
2789 };
2790 Ok(vec![predicate])
2791 }
2792
2793 Expr::IsNull(expr) => {
2795 let column = expr_to_column(expr)?;
2796 Ok(vec![Predicate::IsNull(column)])
2797 }
2798
2799 Expr::IsNotNull(expr) => {
2800 let column = expr_to_column(expr)?;
2801 Ok(vec![Predicate::IsNotNull(column)])
2802 }
2803
2804 Expr::BinaryOp { left, op, right } => {
2806 let predicate = parse_comparison(left, op, right)?;
2807 Ok(vec![predicate])
2808 }
2809
2810 Expr::InList {
2812 expr,
2813 list,
2814 negated,
2815 } => {
2816 let column = expr_to_column(expr)?;
2817 let values: Result<Vec<_>> = list.iter().map(expr_to_predicate_value).collect();
2818 if *negated {
2819 Ok(vec![Predicate::NotIn(column, values?)])
2820 } else {
2821 Ok(vec![Predicate::In(column, values?)])
2822 }
2823 }
2824
2825 Expr::InSubquery {
2829 expr,
2830 subquery,
2831 negated,
2832 } => {
2833 let column = expr_to_column(expr)?;
2834 let inner = parse_select_from_query(subquery)?;
2835 Ok(vec![Predicate::InSubquery {
2836 column,
2837 subquery: Box::new(inner),
2838 negated: *negated,
2839 }])
2840 }
2841
2842 Expr::Exists { subquery, negated } => {
2844 let inner = parse_select_from_query(subquery)?;
2845 Ok(vec![Predicate::Exists {
2846 subquery: Box::new(inner),
2847 negated: *negated,
2848 }])
2849 }
2850
2851 Expr::Between {
2857 expr,
2858 negated,
2859 low,
2860 high,
2861 } => {
2862 let column = expr_to_column(expr)?;
2863 let low_val = expr_to_predicate_value(low)?;
2864 let high_val = expr_to_predicate_value(high)?;
2865
2866 if *negated {
2867 return Ok(vec![Predicate::NotBetween(column, low_val, high_val)]);
2868 }
2869
2870 kimberlite_properties::sometimes!(
2871 true,
2872 "query.between_desugared_to_ge_le",
2873 "BETWEEN predicate desugared into Ge + Le pair"
2874 );
2875
2876 Ok(vec![
2877 Predicate::Ge(column.clone(), low_val),
2878 Predicate::Le(column, high_val),
2879 ])
2880 }
2881
2882 Expr::Nested(inner) => parse_where_expr_inner(inner, depth + 1),
2884
2885 other => Err(QueryError::UnsupportedFeature(format!(
2886 "unsupported WHERE expression: {other:?}"
2887 ))),
2888 }
2889}
2890
2891fn parse_comparison(left: &Expr, op: &BinaryOperator, right: &Expr) -> Result<Predicate> {
2892 let left = match left {
2896 Expr::Nested(inner) => inner.as_ref(),
2897 other => other,
2898 };
2899
2900 if matches!(op, BinaryOperator::AtArrow) {
2903 let column = expr_to_column(left)?;
2904 let value = expr_to_predicate_value(right)?;
2905 return Ok(Predicate::JsonContains { column, value });
2906 }
2907
2908 if let Expr::BinaryOp {
2913 left: json_left,
2914 op: arrow_op @ (BinaryOperator::Arrow | BinaryOperator::LongArrow),
2915 right: path_expr,
2916 } = left
2917 {
2918 let as_text = matches!(arrow_op, BinaryOperator::LongArrow);
2919 let column = expr_to_column(json_left)?;
2920 let path = match path_expr.as_ref() {
2921 Expr::Value(vws) => match &vws.value {
2922 SqlValue::SingleQuotedString(s) | SqlValue::DoubleQuotedString(s) => s.clone(),
2923 SqlValue::Number(n, _) => n.clone(),
2924 _ => {
2925 return Err(QueryError::UnsupportedFeature(format!(
2926 "JSON path key must be a string or integer literal, got {path_expr:?}"
2927 )));
2928 }
2929 },
2930 other => {
2931 return Err(QueryError::UnsupportedFeature(format!(
2932 "JSON path key must be a string or integer literal, got {other:?}"
2933 )));
2934 }
2935 };
2936 let value = expr_to_predicate_value(right)?;
2937 if !matches!(op, BinaryOperator::Eq) {
2938 return Err(QueryError::UnsupportedFeature(format!(
2939 "JSON path extraction supports only `=` comparison; got {op:?}"
2940 )));
2941 }
2942 return Ok(Predicate::JsonExtractEq {
2943 column,
2944 path,
2945 as_text,
2946 value,
2947 });
2948 }
2949
2950 let cmp_op = sql_binop_to_scalar_cmp(op);
2954
2955 if !expr_needs_scalar(left) && !expr_needs_scalar(right) {
2958 if let (Ok(column), Ok(value)) = (expr_to_column(left), expr_to_predicate_value(right)) {
2959 return match op {
2960 BinaryOperator::Eq => Ok(Predicate::Eq(column, value)),
2961 BinaryOperator::Lt => Ok(Predicate::Lt(column, value)),
2962 BinaryOperator::LtEq => Ok(Predicate::Le(column, value)),
2963 BinaryOperator::Gt => Ok(Predicate::Gt(column, value)),
2964 BinaryOperator::GtEq => Ok(Predicate::Ge(column, value)),
2965 BinaryOperator::NotEq => {
2966 Ok(Predicate::ScalarCmp {
2969 lhs: ScalarExpr::Column(column),
2970 op: ScalarCmpOp::NotEq,
2971 rhs: predicate_value_to_scalar_expr(&value),
2972 })
2973 }
2974 other => Err(QueryError::UnsupportedFeature(format!(
2975 "unsupported operator: {other:?}"
2976 ))),
2977 };
2978 }
2979 }
2980
2981 let lhs = expr_to_scalar_expr(left)?;
2984 let rhs = expr_to_scalar_expr(right)?;
2985 let op = cmp_op.ok_or_else(|| {
2986 QueryError::UnsupportedFeature(format!("unsupported operator in scalar comparison: {op:?}"))
2987 })?;
2988 Ok(Predicate::ScalarCmp { lhs, op, rhs })
2989}
2990
2991fn expr_needs_scalar(expr: &Expr) -> bool {
2995 match expr {
2996 Expr::Function(_)
2997 | Expr::Cast { .. }
2998 | Expr::BinaryOp {
2999 op: BinaryOperator::StringConcat,
3000 ..
3001 } => true,
3002 Expr::Nested(inner) => expr_needs_scalar(inner),
3003 _ => false,
3004 }
3005}
3006
3007fn sql_binop_to_scalar_cmp(op: &BinaryOperator) -> Option<ScalarCmpOp> {
3008 Some(match op {
3009 BinaryOperator::Eq => ScalarCmpOp::Eq,
3010 BinaryOperator::NotEq => ScalarCmpOp::NotEq,
3011 BinaryOperator::Lt => ScalarCmpOp::Lt,
3012 BinaryOperator::LtEq => ScalarCmpOp::Le,
3013 BinaryOperator::Gt => ScalarCmpOp::Gt,
3014 BinaryOperator::GtEq => ScalarCmpOp::Ge,
3015 _ => return None,
3016 })
3017}
3018
3019fn predicate_value_to_scalar_expr(pv: &PredicateValue) -> ScalarExpr {
3020 match pv {
3021 PredicateValue::Int(n) => ScalarExpr::Literal(Value::BigInt(*n)),
3022 PredicateValue::String(s) => ScalarExpr::Literal(Value::Text(s.clone())),
3023 PredicateValue::Bool(b) => ScalarExpr::Literal(Value::Boolean(*b)),
3024 PredicateValue::Null => ScalarExpr::Literal(Value::Null),
3025 PredicateValue::Param(idx) => ScalarExpr::Literal(Value::Placeholder(*idx)),
3026 PredicateValue::Literal(v) => ScalarExpr::Literal(v.clone()),
3027 PredicateValue::ColumnRef(name) => {
3028 let col = name.rsplit('.').next().unwrap_or(name);
3030 ScalarExpr::Column(ColumnName::new(col.to_string()))
3031 }
3032 }
3033}
3034
3035fn expr_to_column(expr: &Expr) -> Result<ColumnName> {
3036 match expr {
3037 Expr::Identifier(ident) => Ok(ColumnName::new(ident.value.clone())),
3038 Expr::CompoundIdentifier(idents) if idents.len() == 2 => {
3039 Ok(ColumnName::new(idents[1].value.clone()))
3041 }
3042 other => Err(QueryError::UnsupportedFeature(format!(
3043 "expected column name, got {other:?}"
3044 ))),
3045 }
3046}
3047
3048pub fn expr_to_scalar_expr(expr: &Expr) -> Result<ScalarExpr> {
3061 match expr {
3062 Expr::Value(_) | Expr::UnaryOp { .. } => Ok(ScalarExpr::Literal(expr_to_value(expr)?)),
3065
3066 Expr::Identifier(ident) => Ok(ScalarExpr::Column(ColumnName::new(ident.value.clone()))),
3068 Expr::CompoundIdentifier(idents) if idents.len() == 2 => {
3069 Ok(ScalarExpr::Column(ColumnName::new(idents[1].value.clone())))
3070 }
3071
3072 Expr::BinaryOp {
3075 left,
3076 op: BinaryOperator::StringConcat,
3077 right,
3078 } => Ok(ScalarExpr::Concat(vec![
3079 expr_to_scalar_expr(left)?,
3080 expr_to_scalar_expr(right)?,
3081 ])),
3082
3083 Expr::Cast {
3085 expr: inner,
3086 data_type,
3087 ..
3088 } => {
3089 let target = sql_data_type_to_data_type(data_type)?;
3090 Ok(ScalarExpr::Cast(
3091 Box::new(expr_to_scalar_expr(inner)?),
3092 target,
3093 ))
3094 }
3095
3096 Expr::Nested(inner) => expr_to_scalar_expr(inner),
3098
3099 Expr::Function(func) => {
3101 if func.over.is_some() {
3102 return Err(QueryError::UnsupportedFeature(
3103 "window functions are not valid in this position".to_string(),
3104 ));
3105 }
3106 if func.filter.is_some() {
3107 return Err(QueryError::UnsupportedFeature(
3108 "FILTER clause only applies to aggregate functions".to_string(),
3109 ));
3110 }
3111 let name = func.name.to_string().to_uppercase();
3112 let args = match &func.args {
3113 sqlparser::ast::FunctionArguments::List(list) => &list.args,
3114 _ => {
3115 return Err(QueryError::UnsupportedFeature(
3116 "non-list function arguments not supported".to_string(),
3117 ));
3118 }
3119 };
3120
3121 let mut arg_exprs: Vec<&Expr> = Vec::with_capacity(args.len());
3123 for a in args {
3124 match a {
3125 sqlparser::ast::FunctionArg::Unnamed(
3126 sqlparser::ast::FunctionArgExpr::Expr(e),
3127 ) => arg_exprs.push(e),
3128 _ => {
3129 return Err(QueryError::UnsupportedFeature(format!(
3130 "unsupported argument form in scalar function {name}"
3131 )));
3132 }
3133 }
3134 }
3135
3136 let want_arity = |n: usize| -> Result<()> {
3137 if arg_exprs.len() == n {
3138 Ok(())
3139 } else {
3140 Err(QueryError::ParseError(format!(
3141 "{name} expects {n} argument(s), got {}",
3142 arg_exprs.len()
3143 )))
3144 }
3145 };
3146 let scalar = |e: &Expr| expr_to_scalar_expr(e);
3147
3148 match name.as_str() {
3149 "UPPER" => {
3150 want_arity(1)?;
3151 Ok(ScalarExpr::Upper(Box::new(scalar(arg_exprs[0])?)))
3152 }
3153 "LOWER" => {
3154 want_arity(1)?;
3155 Ok(ScalarExpr::Lower(Box::new(scalar(arg_exprs[0])?)))
3156 }
3157 "LENGTH" | "CHAR_LENGTH" | "CHARACTER_LENGTH" => {
3158 want_arity(1)?;
3159 Ok(ScalarExpr::Length(Box::new(scalar(arg_exprs[0])?)))
3160 }
3161 "TRIM" => {
3162 want_arity(1)?;
3163 Ok(ScalarExpr::Trim(Box::new(scalar(arg_exprs[0])?)))
3164 }
3165 "CONCAT" => {
3166 if arg_exprs.is_empty() {
3167 return Err(QueryError::ParseError(
3168 "CONCAT expects at least one argument".to_string(),
3169 ));
3170 }
3171 let parts = arg_exprs
3172 .iter()
3173 .map(|e| scalar(e))
3174 .collect::<Result<Vec<_>>>()?;
3175 Ok(ScalarExpr::Concat(parts))
3176 }
3177 "ABS" => {
3178 want_arity(1)?;
3179 Ok(ScalarExpr::Abs(Box::new(scalar(arg_exprs[0])?)))
3180 }
3181 "ROUND" => match arg_exprs.len() {
3182 1 => Ok(ScalarExpr::Round(Box::new(scalar(arg_exprs[0])?))),
3183 2 => {
3184 let n = match expr_to_value(arg_exprs[1])? {
3186 Value::BigInt(n) => i32::try_from(n).map_err(|_| {
3187 QueryError::ParseError("ROUND scale out of range".to_string())
3188 })?,
3189 other => {
3190 return Err(QueryError::ParseError(format!(
3191 "ROUND scale must be an integer literal, got {other:?}"
3192 )));
3193 }
3194 };
3195 Ok(ScalarExpr::RoundScale(Box::new(scalar(arg_exprs[0])?), n))
3196 }
3197 _ => Err(QueryError::ParseError(format!(
3198 "ROUND expects 1 or 2 arguments, got {}",
3199 arg_exprs.len()
3200 ))),
3201 },
3202 "CEIL" | "CEILING" => {
3203 want_arity(1)?;
3204 Ok(ScalarExpr::Ceil(Box::new(scalar(arg_exprs[0])?)))
3205 }
3206 "FLOOR" => {
3207 want_arity(1)?;
3208 Ok(ScalarExpr::Floor(Box::new(scalar(arg_exprs[0])?)))
3209 }
3210 "COALESCE" => {
3211 if arg_exprs.is_empty() {
3212 return Err(QueryError::ParseError(
3213 "COALESCE expects at least one argument".to_string(),
3214 ));
3215 }
3216 let parts = arg_exprs
3217 .iter()
3218 .map(|e| scalar(e))
3219 .collect::<Result<Vec<_>>>()?;
3220 Ok(ScalarExpr::Coalesce(parts))
3221 }
3222 "NULLIF" => {
3223 want_arity(2)?;
3224 Ok(ScalarExpr::Nullif(
3225 Box::new(scalar(arg_exprs[0])?),
3226 Box::new(scalar(arg_exprs[1])?),
3227 ))
3228 }
3229 "MOD" => {
3231 want_arity(2)?;
3232 Ok(ScalarExpr::Mod(
3233 Box::new(scalar(arg_exprs[0])?),
3234 Box::new(scalar(arg_exprs[1])?),
3235 ))
3236 }
3237 "POWER" | "POW" => {
3238 want_arity(2)?;
3239 Ok(ScalarExpr::Power(
3240 Box::new(scalar(arg_exprs[0])?),
3241 Box::new(scalar(arg_exprs[1])?),
3242 ))
3243 }
3244 "SQRT" => {
3245 want_arity(1)?;
3246 Ok(ScalarExpr::Sqrt(Box::new(scalar(arg_exprs[0])?)))
3247 }
3248 "SUBSTRING" | "SUBSTR" => {
3249 use kimberlite_types::SubstringRange;
3256 match arg_exprs.len() {
3257 2 => {
3258 let start = match expr_to_value(arg_exprs[1])? {
3259 Value::BigInt(n) => n,
3260 Value::Integer(n) => i64::from(n),
3261 other => {
3262 return Err(QueryError::ParseError(format!(
3263 "SUBSTRING start must be an integer literal, got {other:?}"
3264 )));
3265 }
3266 };
3267 Ok(ScalarExpr::Substring(
3268 Box::new(scalar(arg_exprs[0])?),
3269 SubstringRange::from_start(start),
3270 ))
3271 }
3272 3 => {
3273 let start = match expr_to_value(arg_exprs[1])? {
3274 Value::BigInt(n) => n,
3275 Value::Integer(n) => i64::from(n),
3276 other => {
3277 return Err(QueryError::ParseError(format!(
3278 "SUBSTRING start must be an integer literal, got {other:?}"
3279 )));
3280 }
3281 };
3282 let length = match expr_to_value(arg_exprs[2])? {
3283 Value::BigInt(n) => n,
3284 Value::Integer(n) => i64::from(n),
3285 other => {
3286 return Err(QueryError::ParseError(format!(
3287 "SUBSTRING length must be an integer literal, got {other:?}"
3288 )));
3289 }
3290 };
3291 let range = SubstringRange::try_new(start, length)
3292 .map_err(|e| QueryError::ParseError(format!("SUBSTRING: {e}")))?;
3293 Ok(ScalarExpr::Substring(
3294 Box::new(scalar(arg_exprs[0])?),
3295 range,
3296 ))
3297 }
3298 n => Err(QueryError::ParseError(format!(
3299 "SUBSTRING expects 2 or 3 arguments, got {n}"
3300 ))),
3301 }
3302 }
3303 "EXTRACT" => {
3304 use kimberlite_types::DateField;
3311 want_arity(2)?;
3312 let field_name = match expr_to_value(arg_exprs[0])? {
3313 Value::Text(s) => s,
3314 other => {
3315 return Err(QueryError::ParseError(format!(
3316 "EXTRACT field must be a string literal, got {other:?}"
3317 )));
3318 }
3319 };
3320 let field = DateField::parse(&field_name)
3321 .map_err(|e| QueryError::ParseError(format!("EXTRACT: {e}")))?;
3322 Ok(ScalarExpr::Extract(field, Box::new(scalar(arg_exprs[1])?)))
3323 }
3324 "DATE_TRUNC" | "DATETRUNC" => {
3325 use kimberlite_types::DateField;
3326 want_arity(2)?;
3327 let field_name = match expr_to_value(arg_exprs[0])? {
3328 Value::Text(s) => s,
3329 other => {
3330 return Err(QueryError::ParseError(format!(
3331 "DATE_TRUNC field must be a string literal, got {other:?}"
3332 )));
3333 }
3334 };
3335 let field = DateField::parse(&field_name)
3336 .map_err(|e| QueryError::ParseError(format!("DATE_TRUNC: {e}")))?;
3337 if !field.is_truncatable() {
3338 return Err(QueryError::ParseError(format!(
3339 "DATE_TRUNC field {field:?} is not truncatable (use one of YEAR, MONTH, DAY, HOUR, MINUTE, SECOND)"
3340 )));
3341 }
3342 Ok(ScalarExpr::DateTrunc(
3343 field,
3344 Box::new(scalar(arg_exprs[1])?),
3345 ))
3346 }
3347 "NOW" => {
3348 if !arg_exprs.is_empty() {
3349 return Err(QueryError::ParseError(format!(
3350 "NOW expects 0 arguments, got {}",
3351 arg_exprs.len()
3352 )));
3353 }
3354 Ok(ScalarExpr::Now)
3355 }
3356 "CURRENT_TIMESTAMP" => {
3357 if !arg_exprs.is_empty() {
3358 return Err(QueryError::ParseError(format!(
3359 "CURRENT_TIMESTAMP expects 0 arguments, got {}",
3360 arg_exprs.len()
3361 )));
3362 }
3363 Ok(ScalarExpr::CurrentTimestamp)
3364 }
3365 "CURRENT_DATE" => {
3366 if !arg_exprs.is_empty() {
3367 return Err(QueryError::ParseError(format!(
3368 "CURRENT_DATE expects 0 arguments, got {}",
3369 arg_exprs.len()
3370 )));
3371 }
3372 Ok(ScalarExpr::CurrentDate)
3373 }
3374 other => Err(QueryError::UnsupportedFeature(format!(
3375 "scalar function {other} is not supported"
3376 ))),
3377 }
3378 }
3379
3380 other => Err(QueryError::UnsupportedFeature(format!(
3381 "unsupported scalar expression: {other:?}"
3382 ))),
3383 }
3384}
3385
3386fn sql_data_type_to_data_type(sql_ty: &SqlDataType) -> Result<DataType> {
3390 Ok(match sql_ty {
3391 SqlDataType::TinyInt(_) => DataType::TinyInt,
3392 SqlDataType::SmallInt(_) => DataType::SmallInt,
3393 SqlDataType::Int(_) | SqlDataType::Integer(_) => DataType::Integer,
3394 SqlDataType::BigInt(_) => DataType::BigInt,
3395 SqlDataType::Real | SqlDataType::Float(_) | SqlDataType::Double(_) => DataType::Real,
3396 SqlDataType::Text | SqlDataType::Varchar(_) | SqlDataType::String(_) => DataType::Text,
3397 SqlDataType::Boolean | SqlDataType::Bool => DataType::Boolean,
3398 SqlDataType::Date => DataType::Date,
3399 SqlDataType::Time(_, _) => DataType::Time,
3400 SqlDataType::Timestamp(_, _) => DataType::Timestamp,
3401 SqlDataType::Uuid => DataType::Uuid,
3402 SqlDataType::JSON => DataType::Json,
3403 other => {
3404 return Err(QueryError::UnsupportedFeature(format!(
3405 "CAST to {other:?} is not supported"
3406 )));
3407 }
3408 })
3409}
3410
3411fn expr_to_predicate_value(expr: &Expr) -> Result<PredicateValue> {
3412 match expr {
3413 Expr::Identifier(ident) => {
3415 Ok(PredicateValue::ColumnRef(ident.value.clone()))
3417 }
3418 Expr::CompoundIdentifier(idents) if idents.len() == 2 => {
3419 Ok(PredicateValue::ColumnRef(format!(
3421 "{}.{}",
3422 idents[0].value, idents[1].value
3423 )))
3424 }
3425 Expr::Value(vws) => match &vws.value {
3426 SqlValue::Number(n, _) => {
3427 let value = parse_number_literal(n)?;
3428 match value {
3429 Value::BigInt(v) => Ok(PredicateValue::Int(v)),
3430 Value::Decimal(_, _) => Ok(PredicateValue::Literal(value)),
3431 _ => unreachable!("parse_number_literal only returns BigInt or Decimal"),
3432 }
3433 }
3434 SqlValue::SingleQuotedString(s) | SqlValue::DoubleQuotedString(s) => {
3435 Ok(PredicateValue::String(s.clone()))
3436 }
3437 SqlValue::Boolean(b) => Ok(PredicateValue::Bool(*b)),
3438 SqlValue::Null => Ok(PredicateValue::Null),
3439 SqlValue::Placeholder(p) => Ok(PredicateValue::Param(parse_placeholder_index(p)?)),
3440 other => Err(QueryError::UnsupportedFeature(format!(
3441 "unsupported value expression: {other:?}"
3442 ))),
3443 },
3444 Expr::UnaryOp {
3445 op: sqlparser::ast::UnaryOperator::Minus,
3446 expr,
3447 } => {
3448 if let Expr::Value(vws) = expr.as_ref()
3450 && let SqlValue::Number(n, _) = &vws.value
3451 {
3452 let value = parse_number_literal(n)?;
3453 match value {
3454 Value::BigInt(v) => Ok(PredicateValue::Int(-v)),
3455 Value::Decimal(v, scale) => {
3456 Ok(PredicateValue::Literal(Value::Decimal(-v, scale)))
3457 }
3458 _ => unreachable!("parse_number_literal only returns BigInt or Decimal"),
3459 }
3460 } else {
3461 Err(QueryError::UnsupportedFeature(format!(
3462 "unsupported unary minus operand: {expr:?}"
3463 )))
3464 }
3465 }
3466 other => Err(QueryError::UnsupportedFeature(format!(
3467 "unsupported value expression: {other:?}"
3468 ))),
3469 }
3470}
3471
3472fn parse_order_by(order_by: &sqlparser::ast::OrderBy) -> Result<Vec<OrderByClause>> {
3473 use sqlparser::ast::OrderByKind;
3474
3475 let exprs = match &order_by.kind {
3476 OrderByKind::Expressions(exprs) => exprs,
3477 OrderByKind::All(_) => {
3478 return Err(QueryError::UnsupportedFeature(
3479 "ORDER BY ALL is not supported".to_string(),
3480 ));
3481 }
3482 };
3483
3484 let mut clauses = Vec::new();
3485 for expr in exprs {
3486 clauses.push(parse_order_by_expr(expr)?);
3487 }
3488
3489 Ok(clauses)
3490}
3491
3492fn parse_order_by_expr(expr: &OrderByExpr) -> Result<OrderByClause> {
3493 let column = match &expr.expr {
3494 Expr::Identifier(ident) => ColumnName::new(ident.value.clone()),
3495 other => {
3496 return Err(QueryError::UnsupportedFeature(format!(
3497 "unsupported ORDER BY expression: {other:?}"
3498 )));
3499 }
3500 };
3501
3502 let ascending = expr.options.asc.unwrap_or(true);
3503
3504 Ok(OrderByClause { column, ascending })
3505}
3506
3507fn parse_limit(limit: Option<&Expr>) -> Result<Option<LimitExpr>> {
3508 match limit {
3509 None => Ok(None),
3510 Some(Expr::Value(vws)) => match &vws.value {
3511 SqlValue::Number(n, _) => {
3512 let v: usize = n
3513 .parse()
3514 .map_err(|_| QueryError::ParseError(format!("invalid LIMIT value: {n}")))?;
3515 Ok(Some(LimitExpr::Literal(v)))
3516 }
3517 SqlValue::Placeholder(p) => Ok(Some(LimitExpr::Param(parse_placeholder_index(p)?))),
3518 other => Err(QueryError::UnsupportedFeature(format!(
3519 "LIMIT must be an integer literal or parameter; got {other:?}"
3520 ))),
3521 },
3522 Some(other) => Err(QueryError::UnsupportedFeature(format!(
3523 "LIMIT must be an integer literal or parameter; got {other:?}"
3524 ))),
3525 }
3526}
3527
3528fn query_limit_expr(query: &Query) -> Result<Option<&Expr>> {
3531 use sqlparser::ast::LimitClause;
3532 match &query.limit_clause {
3533 None => Ok(None),
3534 Some(LimitClause::LimitOffset { limit, .. }) => Ok(limit.as_ref()),
3535 Some(LimitClause::OffsetCommaLimit { .. }) => Err(QueryError::UnsupportedFeature(
3536 "MySQL-style `LIMIT <offset>, <limit>` is not supported".to_string(),
3537 )),
3538 }
3539}
3540
3541fn query_offset(query: &Query) -> Option<&sqlparser::ast::Offset> {
3545 use sqlparser::ast::LimitClause;
3546 match &query.limit_clause {
3547 Some(LimitClause::LimitOffset { offset, .. }) => offset.as_ref(),
3548 _ => None,
3549 }
3550}
3551
3552fn parse_offset_clause(offset: Option<&sqlparser::ast::Offset>) -> Result<Option<LimitExpr>> {
3555 let Some(off) = offset else { return Ok(None) };
3556 match &off.value {
3557 Expr::Value(vws) => match &vws.value {
3558 SqlValue::Number(n, _) => {
3559 let v: usize = n
3560 .parse()
3561 .map_err(|_| QueryError::ParseError(format!("invalid OFFSET value: {n}")))?;
3562 Ok(Some(LimitExpr::Literal(v)))
3563 }
3564 SqlValue::Placeholder(p) => Ok(Some(LimitExpr::Param(parse_placeholder_index(p)?))),
3565 other => Err(QueryError::UnsupportedFeature(format!(
3566 "OFFSET must be an integer literal or parameter; got {other:?}"
3567 ))),
3568 },
3569 other => Err(QueryError::UnsupportedFeature(format!(
3570 "OFFSET must be an integer literal or parameter; got {other:?}"
3571 ))),
3572 }
3573}
3574
3575fn parse_placeholder_index(placeholder: &str) -> Result<usize> {
3581 let num_str = placeholder.strip_prefix('$').ok_or_else(|| {
3582 QueryError::ParseError(format!("unsupported placeholder format: {placeholder}"))
3583 })?;
3584 let idx: usize = num_str.parse().map_err(|_| {
3585 QueryError::ParseError(format!("invalid parameter placeholder: {placeholder}"))
3586 })?;
3587 if idx == 0 {
3588 return Err(QueryError::ParseError(
3589 "parameter indices start at $1, not $0".to_string(),
3590 ));
3591 }
3592 Ok(idx)
3593}
3594
3595fn object_name_to_string(name: &ObjectName) -> String {
3596 name.0
3597 .iter()
3598 .map(|part| match part.as_ident() {
3599 Some(ident) => ident.value.clone(),
3600 None => part.to_string(),
3601 })
3602 .collect::<Vec<_>>()
3603 .join(".")
3604}
3605
3606fn parse_create_table(create_table: &sqlparser::ast::CreateTable) -> Result<ParsedCreateTable> {
3611 let table_name = object_name_to_string(&create_table.name);
3612
3613 let mut raw_columns = Vec::new();
3623 for col_def in &create_table.columns {
3624 let parsed_col = parse_column_def(col_def)?;
3625 raw_columns.push(parsed_col);
3626 }
3627 let columns = NonEmptyVec::try_new(raw_columns).map_err(|_| {
3628 crate::error::QueryError::ParseError(format!(
3629 "CREATE TABLE {table_name} requires at least one column"
3630 ))
3631 })?;
3632
3633 let mut primary_key = Vec::new();
3635 for constraint in &create_table.constraints {
3636 if let sqlparser::ast::TableConstraint::PrimaryKey(pk) = constraint {
3637 for col in &pk.columns {
3638 if let Expr::Identifier(ident) = &col.column.expr {
3639 primary_key.push(ident.value.clone());
3640 } else {
3641 primary_key.push(col.column.expr.to_string());
3642 }
3643 }
3644 }
3645 }
3646
3647 if primary_key.is_empty() {
3649 for col_def in &create_table.columns {
3650 for option in &col_def.options {
3651 if matches!(&option.option, sqlparser::ast::ColumnOption::PrimaryKey(_)) {
3652 primary_key.push(col_def.name.value.clone());
3653 }
3654 }
3655 }
3656 }
3657
3658 Ok(ParsedCreateTable {
3659 table_name,
3660 columns,
3661 primary_key,
3662 if_not_exists: create_table.if_not_exists,
3663 })
3664}
3665
3666fn parse_column_def(col_def: &SqlColumnDef) -> Result<ParsedColumn> {
3667 let name = col_def.name.value.clone();
3668
3669 let data_type = match &col_def.data_type {
3672 SqlDataType::TinyInt(_) => "TINYINT".to_string(),
3674 SqlDataType::SmallInt(_) => "SMALLINT".to_string(),
3675 SqlDataType::Int(_) | SqlDataType::Integer(_) => "INTEGER".to_string(),
3676 SqlDataType::BigInt(_) => "BIGINT".to_string(),
3677
3678 SqlDataType::Real | SqlDataType::Float(_) | SqlDataType::Double(_) => "REAL".to_string(),
3680 SqlDataType::Decimal(precision_opt) => match precision_opt {
3681 sqlparser::ast::ExactNumberInfo::PrecisionAndScale(p, s) => {
3682 format!("DECIMAL({p},{s})")
3683 }
3684 sqlparser::ast::ExactNumberInfo::Precision(p) => {
3685 format!("DECIMAL({p},0)")
3686 }
3687 sqlparser::ast::ExactNumberInfo::None => "DECIMAL(18,2)".to_string(),
3688 },
3689
3690 SqlDataType::Text | SqlDataType::Varchar(_) | SqlDataType::String(_) => "TEXT".to_string(),
3692
3693 SqlDataType::Binary(_) | SqlDataType::Varbinary(_) | SqlDataType::Blob(_) => {
3695 "BYTES".to_string()
3696 }
3697
3698 SqlDataType::Boolean | SqlDataType::Bool => "BOOLEAN".to_string(),
3700
3701 SqlDataType::Date => "DATE".to_string(),
3703 SqlDataType::Time(_, _) => "TIME".to_string(),
3704 SqlDataType::Timestamp(_, _) => "TIMESTAMP".to_string(),
3705
3706 SqlDataType::Uuid => "UUID".to_string(),
3708 SqlDataType::JSON => "JSON".to_string(),
3709
3710 other => {
3711 return Err(QueryError::UnsupportedFeature(format!(
3712 "unsupported data type: {other:?}"
3713 )));
3714 }
3715 };
3716
3717 let mut nullable = true;
3719 for option in &col_def.options {
3720 if matches!(option.option, sqlparser::ast::ColumnOption::NotNull) {
3721 nullable = false;
3722 }
3723 }
3724
3725 Ok(ParsedColumn {
3726 name,
3727 data_type,
3728 nullable,
3729 })
3730}
3731
3732fn parse_alter_table(
3733 name: &sqlparser::ast::ObjectName,
3734 operations: &[sqlparser::ast::AlterTableOperation],
3735) -> Result<ParsedAlterTable> {
3736 let table_name = object_name_to_string(name);
3737
3738 if operations.len() != 1 {
3740 return Err(QueryError::UnsupportedFeature(
3741 "ALTER TABLE supports only one operation at a time".to_string(),
3742 ));
3743 }
3744
3745 let operation = match &operations[0] {
3746 sqlparser::ast::AlterTableOperation::AddColumn { column_def, .. } => {
3747 let parsed_col = parse_column_def(column_def)?;
3748 AlterTableOperation::AddColumn(parsed_col)
3749 }
3750 sqlparser::ast::AlterTableOperation::DropColumn {
3751 column_names,
3752 if_exists: _,
3753 ..
3754 } => {
3755 if column_names.len() != 1 {
3756 return Err(QueryError::UnsupportedFeature(
3757 "ALTER TABLE DROP COLUMN supports exactly one column".to_string(),
3758 ));
3759 }
3760 let col_name = column_names[0].value.clone();
3761 AlterTableOperation::DropColumn(col_name)
3762 }
3763 other => {
3764 return Err(QueryError::UnsupportedFeature(format!(
3765 "ALTER TABLE operation not supported: {other:?}"
3766 )));
3767 }
3768 };
3769
3770 Ok(ParsedAlterTable {
3771 table_name,
3772 operation,
3773 })
3774}
3775
3776fn parse_create_index(create_index: &sqlparser::ast::CreateIndex) -> Result<ParsedCreateIndex> {
3777 let index_name = match &create_index.name {
3778 Some(name) => object_name_to_string(name),
3779 None => {
3780 return Err(QueryError::ParseError(
3781 "CREATE INDEX requires an index name".to_string(),
3782 ));
3783 }
3784 };
3785
3786 let table_name = object_name_to_string(&create_index.table_name);
3787
3788 let mut columns = Vec::new();
3789 for col in &create_index.columns {
3790 columns.push(col.column.expr.to_string());
3791 }
3792
3793 Ok(ParsedCreateIndex {
3794 index_name,
3795 table_name,
3796 columns,
3797 })
3798}
3799
3800fn parse_insert(insert: &sqlparser::ast::Insert) -> Result<ParsedInsert> {
3805 let table = insert.table.to_string();
3807
3808 let columns: Vec<String> = insert.columns.iter().map(|c| c.value.clone()).collect();
3810
3811 let values = match insert.source.as_ref().map(|s| s.body.as_ref()) {
3813 Some(SetExpr::Values(values)) => {
3814 let mut all_rows = Vec::new();
3815 for row in &values.rows {
3816 let mut parsed_row = Vec::new();
3817 for expr in row {
3818 let val = expr_to_value(expr)?;
3819 parsed_row.push(val);
3820 }
3821 all_rows.push(parsed_row);
3822 }
3823 all_rows
3824 }
3825 _ => {
3826 return Err(QueryError::UnsupportedFeature(
3827 "only VALUES clause is supported in INSERT".to_string(),
3828 ));
3829 }
3830 };
3831
3832 let returning = parse_returning(insert.returning.as_ref())?;
3834
3835 let on_conflict = match insert.on.as_ref() {
3840 None => None,
3841 Some(sqlparser::ast::OnInsert::OnConflict(oc)) => Some(parse_on_conflict(oc)?),
3842 Some(sqlparser::ast::OnInsert::DuplicateKeyUpdate(_)) => {
3843 return Err(QueryError::UnsupportedFeature(
3844 "ON DUPLICATE KEY UPDATE is not supported; use ON CONFLICT (cols) DO UPDATE"
3845 .to_string(),
3846 ));
3847 }
3848 Some(other) => {
3849 return Err(QueryError::UnsupportedFeature(format!(
3850 "unsupported ON clause on INSERT: {other:?}"
3851 )));
3852 }
3853 };
3854
3855 Ok(ParsedInsert {
3856 table,
3857 columns,
3858 values,
3859 returning,
3860 on_conflict,
3861 })
3862}
3863
3864fn parse_on_conflict(oc: &sqlparser::ast::OnConflict) -> Result<OnConflictClause> {
3871 let target = match oc.conflict_target.as_ref() {
3873 Some(sqlparser::ast::ConflictTarget::Columns(cols)) => {
3874 if cols.is_empty() {
3875 return Err(QueryError::ParseError(
3876 "ON CONFLICT requires at least one target column".to_string(),
3877 ));
3878 }
3879 cols.iter().map(|i| i.value.clone()).collect()
3880 }
3881 Some(sqlparser::ast::ConflictTarget::OnConstraint(_)) => {
3882 return Err(QueryError::UnsupportedFeature(
3883 "ON CONFLICT ON CONSTRAINT <name> is not supported; use ON CONFLICT (cols) instead"
3884 .to_string(),
3885 ));
3886 }
3887 None => {
3888 return Err(QueryError::UnsupportedFeature(
3889 "ON CONFLICT without a target column list is not supported".to_string(),
3890 ));
3891 }
3892 };
3893
3894 let action = match &oc.action {
3895 sqlparser::ast::OnConflictAction::DoNothing => OnConflictAction::DoNothing,
3896 sqlparser::ast::OnConflictAction::DoUpdate(du) => {
3897 if du.selection.is_some() {
3898 return Err(QueryError::UnsupportedFeature(
3899 "ON CONFLICT DO UPDATE WHERE ... is not yet supported".to_string(),
3900 ));
3901 }
3902 let mut assignments = Vec::with_capacity(du.assignments.len());
3903 for a in &du.assignments {
3904 let col = a.target.to_string();
3905 let rhs = parse_upsert_expr(&a.value)?;
3906 assignments.push((col, rhs));
3907 }
3908 OnConflictAction::DoUpdate { assignments }
3909 }
3910 };
3911
3912 Ok(OnConflictClause { target, action })
3913}
3914
3915fn parse_upsert_expr(expr: &Expr) -> Result<UpsertExpr> {
3922 if let Expr::CompoundIdentifier(parts) = expr {
3924 if parts.len() == 2 && parts[0].value.eq_ignore_ascii_case("EXCLUDED") {
3925 return Ok(UpsertExpr::Excluded(parts[1].value.clone()));
3926 }
3927 }
3928 let v = expr_to_value(expr)?;
3931 Ok(UpsertExpr::Value(v))
3932}
3933
3934fn parse_update(
3935 table: &sqlparser::ast::TableWithJoins,
3936 assignments: &[sqlparser::ast::Assignment],
3937 selection: Option<&Expr>,
3938 returning: Option<&Vec<SelectItem>>,
3939) -> Result<ParsedUpdate> {
3940 let table_name = match &table.relation {
3941 sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
3942 other => {
3943 return Err(QueryError::UnsupportedFeature(format!(
3944 "unsupported table in UPDATE: {other:?}"
3945 )));
3946 }
3947 };
3948
3949 let mut parsed_assignments = Vec::new();
3951 for assignment in assignments {
3952 let col_name = assignment.target.to_string();
3953 let value = expr_to_value(&assignment.value)?;
3954 parsed_assignments.push((col_name, value));
3955 }
3956
3957 let predicates = match selection {
3959 Some(expr) => parse_where_expr(expr)?,
3960 None => vec![],
3961 };
3962
3963 let returning_cols = parse_returning(returning)?;
3965
3966 Ok(ParsedUpdate {
3967 table: table_name,
3968 assignments: parsed_assignments,
3969 predicates,
3970 returning: returning_cols,
3971 })
3972}
3973
3974fn parse_delete_stmt(delete: &sqlparser::ast::Delete) -> Result<ParsedDelete> {
3975 use sqlparser::ast::FromTable;
3977
3978 let table_name = match &delete.from {
3979 FromTable::WithFromKeyword(tables) => {
3980 if tables.len() != 1 {
3981 return Err(QueryError::ParseError(
3982 "expected exactly 1 table in DELETE FROM".to_string(),
3983 ));
3984 }
3985
3986 match &tables[0].relation {
3987 sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
3988 _ => {
3989 return Err(QueryError::ParseError(
3990 "DELETE only supports simple table names".to_string(),
3991 ));
3992 }
3993 }
3994 }
3995 FromTable::WithoutKeyword(tables) => {
3996 if tables.len() != 1 {
3997 return Err(QueryError::ParseError(
3998 "expected exactly 1 table in DELETE".to_string(),
3999 ));
4000 }
4001
4002 match &tables[0].relation {
4003 sqlparser::ast::TableFactor::Table { name, .. } => object_name_to_string(name),
4004 _ => {
4005 return Err(QueryError::ParseError(
4006 "DELETE only supports simple table names".to_string(),
4007 ));
4008 }
4009 }
4010 }
4011 };
4012
4013 let predicates = match &delete.selection {
4015 Some(expr) => parse_where_expr(expr)?,
4016 None => vec![],
4017 };
4018
4019 let returning_cols = parse_returning(delete.returning.as_ref())?;
4021
4022 Ok(ParsedDelete {
4023 table: table_name,
4024 predicates,
4025 returning: returning_cols,
4026 })
4027}
4028
4029fn parse_returning(returning: Option<&Vec<SelectItem>>) -> Result<Option<Vec<String>>> {
4031 match returning {
4032 None => Ok(None),
4033 Some(items) => {
4034 let mut columns = Vec::new();
4035 for item in items {
4036 match item {
4037 SelectItem::UnnamedExpr(Expr::Identifier(ident)) => {
4038 columns.push(ident.value.clone());
4039 }
4040 SelectItem::UnnamedExpr(Expr::CompoundIdentifier(parts)) => {
4041 if let Some(last) = parts.last() {
4043 columns.push(last.value.clone());
4044 } else {
4045 return Err(QueryError::ParseError(
4046 "invalid column in RETURNING clause".to_string(),
4047 ));
4048 }
4049 }
4050 _ => {
4051 return Err(QueryError::UnsupportedFeature(
4052 "only simple column names supported in RETURNING clause".to_string(),
4053 ));
4054 }
4055 }
4056 }
4057 Ok(Some(columns))
4058 }
4059 }
4060}
4061
4062fn parse_number_literal(n: &str) -> Result<Value> {
4066 use rust_decimal::Decimal;
4067 use std::str::FromStr;
4068
4069 if n.contains('.') {
4070 let decimal = Decimal::from_str(n)
4072 .map_err(|e| QueryError::ParseError(format!("invalid decimal '{n}': {e}")))?;
4073
4074 let scale = decimal.scale() as u8;
4076
4077 if scale > 38 {
4078 return Err(QueryError::ParseError(format!(
4079 "decimal scale too large (max 38): {n}"
4080 )));
4081 }
4082
4083 let mantissa = decimal.mantissa();
4086
4087 Ok(Value::Decimal(mantissa, scale))
4088 } else {
4089 let v: i64 = n
4091 .parse()
4092 .map_err(|_| QueryError::ParseError(format!("invalid integer: {n}")))?;
4093 Ok(Value::BigInt(v))
4094 }
4095}
4096
4097fn expr_to_value(expr: &Expr) -> Result<Value> {
4099 match expr {
4100 Expr::Value(vws) => match &vws.value {
4101 SqlValue::Number(n, _) => parse_number_literal(n),
4102 SqlValue::SingleQuotedString(s) | SqlValue::DoubleQuotedString(s) => {
4103 Ok(Value::Text(s.clone()))
4104 }
4105 SqlValue::Boolean(b) => Ok(Value::Boolean(*b)),
4106 SqlValue::Null => Ok(Value::Null),
4107 SqlValue::Placeholder(p) => Ok(Value::Placeholder(parse_placeholder_index(p)?)),
4108 other => Err(QueryError::UnsupportedFeature(format!(
4109 "unsupported value expression: {other:?}"
4110 ))),
4111 },
4112 Expr::UnaryOp {
4113 op: sqlparser::ast::UnaryOperator::Minus,
4114 expr,
4115 } => {
4116 if let Expr::Value(vws) = expr.as_ref()
4118 && let SqlValue::Number(n, _) = &vws.value
4119 {
4120 let value = parse_number_literal(n)?;
4121 match value {
4122 Value::BigInt(v) => Ok(Value::BigInt(-v)),
4123 Value::Decimal(v, scale) => Ok(Value::Decimal(-v, scale)),
4124 _ => unreachable!("parse_number_literal only returns BigInt or Decimal"),
4125 }
4126 } else {
4127 Err(QueryError::UnsupportedFeature(format!(
4128 "unsupported unary minus operand: {expr:?}"
4129 )))
4130 }
4131 }
4132 other => Err(QueryError::UnsupportedFeature(format!(
4133 "unsupported value expression: {other:?}"
4134 ))),
4135 }
4136}
4137
#[cfg(test)]
mod tests {
    use super::*;

    // ------------------------------------------------------------------
    // Shared helpers
    // ------------------------------------------------------------------

    /// Parses `sql` and returns the `SELECT` payload, panicking on a parse
    /// error or on any other statement kind.
    fn parse_test_select(sql: &str) -> ParsedSelect {
        let ParsedStatement::Select(select) = parse_statement(sql).unwrap() else {
            panic!("expected SELECT statement");
        };
        select
    }

    /// Parses `sql` and returns the `CREATE MASKING POLICY` payload, panicking
    /// on a parse error or on any other statement kind.
    fn parse_test_masking_policy(sql: &str) -> ParsedCreateMaskingPolicy {
        match parse_statement(sql).unwrap() {
            ParsedStatement::CreateMaskingPolicy(policy) => policy,
            other => panic!("expected CreateMaskingPolicy, got {other:?}"),
        }
    }

    /// Parses `sql` and returns the `INSERT` payload, panicking on a parse
    /// error or on any other statement kind.
    fn parse_test_insert(sql: &str) -> ParsedInsert {
        let statement = parse_statement(sql).unwrap_or_else(|e| panic!("parse failed: {e}"));
        match statement {
            ParsedStatement::Insert(insert) => insert,
            other => panic!("expected INSERT statement, got {other:?}"),
        }
    }

    // ------------------------------------------------------------------
    // SELECT: projection, WHERE, ORDER BY, LIMIT / OFFSET
    // ------------------------------------------------------------------

    /// An explicit projection list is captured in order, with no predicates.
    #[test]
    fn test_parse_simple_select() {
        let select = parse_test_select("SELECT id, name FROM users");
        assert_eq!(select.table, "users");

        let expected_columns = vec![ColumnName::new("id"), ColumnName::new("name")];
        assert_eq!(select.columns, Some(expected_columns));
        assert!(select.predicates.is_empty());
    }

    /// `SELECT *` leaves the projection list unset.
    #[test]
    fn test_parse_select_star() {
        let select = parse_test_select("SELECT * FROM users");
        assert_eq!(select.table, "users");
        assert!(select.columns.is_none());
    }

    /// An integer equality in `WHERE` becomes a single `Eq` predicate.
    #[test]
    fn test_parse_where_eq() {
        let select = parse_test_select("SELECT * FROM users WHERE id = 42");
        assert_eq!(select.predicates.len(), 1);

        let Predicate::Eq(column, PredicateValue::Int(42)) = &select.predicates[0] else {
            panic!("unexpected predicate: {:?}", select.predicates[0]);
        };
        assert_eq!(column.as_str(), "id");
    }

    /// A string equality in `WHERE` carries the unquoted text.
    #[test]
    fn test_parse_where_string() {
        let select = parse_test_select("SELECT * FROM users WHERE name = 'alice'");

        let Predicate::Eq(column, PredicateValue::String(text)) = &select.predicates[0] else {
            panic!("unexpected predicate: {:?}", select.predicates[0]);
        };
        assert_eq!(column.as_str(), "name");
        assert_eq!(text, "alice");
    }

    /// Two conditions joined by `AND` yield two predicates.
    #[test]
    fn test_parse_where_and() {
        let select = parse_test_select("SELECT * FROM users WHERE id = 1 AND name = 'bob'");
        assert_eq!(select.predicates.len(), 2);
    }

    /// `IN (...)` keeps the column and every listed value.
    #[test]
    fn test_parse_where_in() {
        let select = parse_test_select("SELECT * FROM users WHERE id IN (1, 2, 3)");

        let Predicate::In(column, candidates) = &select.predicates[0] else {
            panic!("unexpected predicate: {:?}", select.predicates[0]);
        };
        assert_eq!(column.as_str(), "id");
        assert_eq!(candidates.len(), 3);
    }

    /// `ORDER BY` preserves key order and the ASC / DESC direction of each key.
    #[test]
    fn test_parse_order_by() {
        let select = parse_test_select("SELECT * FROM users ORDER BY name ASC, id DESC");
        let keys = &select.order_by;

        assert_eq!(keys.len(), 2);
        assert_eq!(keys[0].column.as_str(), "name");
        assert!(keys[0].ascending);
        assert_eq!(keys[1].column.as_str(), "id");
        assert!(!keys[1].ascending);
    }

    /// A literal `LIMIT` is stored as `LimitExpr::Literal`.
    #[test]
    fn test_parse_limit() {
        let select = parse_test_select("SELECT * FROM users LIMIT 10");
        assert_eq!(select.limit, Some(LimitExpr::Literal(10)));
    }

    /// A placeholder `LIMIT` is stored as `LimitExpr::Param`.
    #[test]
    fn test_parse_limit_param() {
        let select = parse_test_select("SELECT * FROM users LIMIT $1");
        assert_eq!(select.limit, Some(LimitExpr::Param(1)));
    }

    /// Literal `LIMIT` and `OFFSET` are both captured.
    #[test]
    fn test_parse_offset_literal() {
        let select = parse_test_select("SELECT * FROM users LIMIT 10 OFFSET 5");
        assert_eq!(select.limit, Some(LimitExpr::Literal(10)));
        assert_eq!(select.offset, Some(LimitExpr::Literal(5)));
    }

    /// Placeholder `LIMIT` and `OFFSET` keep their parameter indices.
    #[test]
    fn test_parse_offset_param() {
        let select = parse_test_select("SELECT * FROM users LIMIT $1 OFFSET $2");
        assert_eq!(select.limit, Some(LimitExpr::Param(1)));
        assert_eq!(select.offset, Some(LimitExpr::Param(2)));
    }

    /// A `$1` placeholder on the right-hand side becomes `PredicateValue::Param(1)`.
    #[test]
    fn test_parse_param() {
        let select = parse_test_select("SELECT * FROM users WHERE id = $1");
        let first = &select.predicates[0];

        if !matches!(first, Predicate::Eq(_, PredicateValue::Param(1))) {
            panic!("unexpected predicate: {first:?}");
        }
    }

    // ------------------------------------------------------------------
    // JOINs and subqueries
    // ------------------------------------------------------------------

    /// A bare `JOIN` parses as a single inner join.
    #[test]
    fn test_parse_inner_join() {
        let outcome =
            parse_statement("SELECT * FROM users JOIN orders ON users.id = orders.user_id");
        // Surface the parser error on stderr before the assertion fires.
        if let Err(err) = &outcome {
            eprintln!("Parse error: {err:?}");
        }
        assert!(outcome.is_ok());

        let ParsedStatement::Select(select) = outcome.unwrap() else {
            panic!("expected SELECT statement");
        };
        assert_eq!(select.table, "users");
        assert_eq!(select.joins.len(), 1);
        assert_eq!(select.joins[0].table, "orders");
        assert!(matches!(select.joins[0].join_type, JoinType::Inner));
    }

    /// `LEFT JOIN` parses as a single left join.
    #[test]
    fn test_parse_left_join() {
        let outcome =
            parse_statement("SELECT * FROM users LEFT JOIN orders ON users.id = orders.user_id");
        assert!(outcome.is_ok());

        let ParsedStatement::Select(select) = outcome.unwrap() else {
            panic!("expected SELECT statement");
        };
        assert_eq!(select.table, "users");
        assert_eq!(select.joins.len(), 1);
        assert_eq!(select.joins[0].table, "orders");
        assert!(matches!(select.joins[0].join_type, JoinType::Left));
    }

    /// Chained joins are recorded in source order.
    #[test]
    fn test_parse_multi_join() {
        let sql = "SELECT * FROM users \
                   JOIN orders ON users.id = orders.user_id \
                   JOIN products ON orders.product_id = products.id";
        let outcome = parse_statement(sql);
        assert!(outcome.is_ok());

        let ParsedStatement::Select(select) = outcome.unwrap() else {
            panic!("expected SELECT statement");
        };
        assert_eq!(select.table, "users");
        assert_eq!(select.joins.len(), 2);
        assert_eq!(select.joins[0].table, "orders");
        assert_eq!(select.joins[1].table, "products");
    }

    /// A subquery in `FROM` is rejected.
    #[test]
    fn test_reject_subquery() {
        let outcome = parse_statement("SELECT * FROM (SELECT * FROM users)");
        assert!(outcome.is_err());
    }

    // ------------------------------------------------------------------
    // WHERE nesting depth
    // ------------------------------------------------------------------

    /// Ten parenthesised comparisons joined by `AND` parse successfully.
    #[test]
    fn test_where_depth_within_limit() {
        let clauses: Vec<String> = (0..10).map(|i| format!("(id = {i})")).collect();
        let sql = format!("SELECT * FROM users WHERE {}", clauses.join(" AND "));

        let outcome = parse_statement(&sql);
        assert!(
            outcome.is_ok(),
            "Moderate nesting should succeed, but got: {outcome:?}"
        );
    }

    /// A comparison wrapped in 200 levels of parentheses is rejected.
    #[test]
    fn test_where_depth_nested_parens() {
        let sql = format!(
            "SELECT * FROM users WHERE {}id = 1{}",
            "(".repeat(200),
            ")".repeat(200)
        );

        let outcome = parse_statement(&sql);
        assert!(
            outcome.is_err(),
            "Excessive parenthesis nesting should be rejected"
        );
    }

    /// A two-level mix of `AND` / `OR` groups parses successfully.
    #[test]
    fn test_where_depth_complex_and_or() {
        let sql = "SELECT * FROM users WHERE \
                   ((id = 1 AND name = 'a') OR (id = 2 AND name = 'b')) AND \
                   ((age > 10 AND age < 20) OR (age > 30 AND age < 40))";

        let outcome = parse_statement(sql);
        assert!(outcome.is_ok(), "Complex AND/OR should succeed");
    }

    // ------------------------------------------------------------------
    // GROUP BY / HAVING
    // ------------------------------------------------------------------

    /// `HAVING COUNT(*) > 5` becomes one aggregate comparison.
    #[test]
    fn test_parse_having() {
        let sql = "SELECT name, COUNT(*) FROM users GROUP BY name HAVING COUNT(*) > 5";
        let select = parse_test_select(sql);
        assert_eq!(select.group_by.len(), 1);
        assert_eq!(select.having.len(), 1);

        let HavingCondition::AggregateComparison {
            aggregate,
            op,
            value,
        } = &select.having[0];
        assert!(matches!(aggregate, AggregateFunction::CountStar));
        assert_eq!(*op, HavingOp::Gt);
        assert_eq!(*value, Value::BigInt(5));
    }

    /// Two `HAVING` conditions joined by `AND` yield two entries.
    #[test]
    fn test_parse_having_multiple() {
        let sql = "SELECT name, COUNT(*), SUM(age) FROM users GROUP BY name HAVING COUNT(*) > 1 AND SUM(age) < 100";
        let select = parse_test_select(sql);
        assert_eq!(select.having.len(), 2);
    }

    /// `HAVING` is accepted even when there is no `GROUP BY`.
    #[test]
    fn test_parse_having_without_group_by() {
        let select = parse_test_select("SELECT COUNT(*) FROM users HAVING COUNT(*) > 0");
        assert!(select.group_by.is_empty());
        assert_eq!(select.having.len(), 1);
    }

    // ------------------------------------------------------------------
    // UNION
    // ------------------------------------------------------------------

    /// Plain `UNION` records both sides with `all == false`.
    #[test]
    fn test_parse_union() {
        let outcome = parse_statement("SELECT id FROM users UNION SELECT id FROM orders");
        assert!(outcome.is_ok());

        let ParsedStatement::Union(union) = outcome.unwrap() else {
            panic!("expected UNION statement");
        };
        assert_eq!(union.left.table, "users");
        assert_eq!(union.right.table, "orders");
        assert!(!union.all);
    }

    /// `UNION ALL` records both sides with `all == true`.
    #[test]
    fn test_parse_union_all() {
        let outcome = parse_statement("SELECT id FROM users UNION ALL SELECT id FROM orders");
        assert!(outcome.is_ok());

        let ParsedStatement::Union(union) = outcome.unwrap() else {
            panic!("expected UNION ALL statement");
        };
        assert_eq!(union.left.table, "users");
        assert_eq!(union.right.table, "orders");
        assert!(union.all);
    }

    // ------------------------------------------------------------------
    // CREATE MASK / DROP MASK
    // ------------------------------------------------------------------

    /// `CREATE MASK` captures mask name, table, column and strategy.
    #[test]
    fn test_parse_create_mask() {
        let statement =
            parse_statement("CREATE MASK ssn_mask ON patients.ssn USING REDACT").unwrap();

        let ParsedStatement::CreateMask(mask) = statement else {
            panic!("expected CREATE MASK statement");
        };
        assert_eq!(mask.mask_name, "ssn_mask");
        assert_eq!(mask.table_name, "patients");
        assert_eq!(mask.column_name, "ssn");
        assert_eq!(mask.strategy, "REDACT");
    }

    /// A trailing semicolon does not leak into the parsed `CREATE MASK` fields.
    #[test]
    fn test_parse_create_mask_with_semicolon() {
        let statement =
            parse_statement("CREATE MASK ssn_mask ON patients.ssn USING REDACT;").unwrap();

        let ParsedStatement::CreateMask(mask) = statement else {
            panic!("expected CREATE MASK statement");
        };
        assert_eq!(mask.mask_name, "ssn_mask");
        assert_eq!(mask.strategy, "REDACT");
    }

    /// `CREATE MASK ... USING HASH` records the `HASH` strategy.
    #[test]
    fn test_parse_create_mask_hash_strategy() {
        let statement =
            parse_statement("CREATE MASK email_hash ON users.email USING HASH").unwrap();

        let ParsedStatement::CreateMask(mask) = statement else {
            panic!("expected CREATE MASK statement");
        };
        assert_eq!(mask.mask_name, "email_hash");
        assert_eq!(mask.table_name, "users");
        assert_eq!(mask.column_name, "email");
        assert_eq!(mask.strategy, "HASH");
    }

    /// `CREATE MASK` without the `ON` keyword is an error.
    #[test]
    fn test_parse_create_mask_missing_on() {
        let outcome = parse_statement("CREATE MASK ssn_mask patients.ssn USING REDACT");
        assert!(outcome.is_err());
    }

    /// `CREATE MASK` whose target has no `table.column` dot is an error.
    #[test]
    fn test_parse_create_mask_missing_dot() {
        let outcome = parse_statement("CREATE MASK ssn_mask ON patients_ssn USING REDACT");
        assert!(outcome.is_err());
    }

    /// `DROP MASK` yields the mask name.
    #[test]
    fn test_parse_drop_mask() {
        let statement = parse_statement("DROP MASK ssn_mask").unwrap();

        let ParsedStatement::DropMask(mask_name) = statement else {
            panic!("expected DROP MASK statement");
        };
        assert_eq!(mask_name, "ssn_mask");
    }

    /// A trailing semicolon is not part of the dropped mask's name.
    #[test]
    fn test_parse_drop_mask_with_semicolon() {
        let statement = parse_statement("DROP MASK ssn_mask;").unwrap();

        let ParsedStatement::DropMask(mask_name) = statement else {
            panic!("expected DROP MASK statement");
        };
        assert_eq!(mask_name, "ssn_mask");
    }

    // ------------------------------------------------------------------
    // CREATE / DROP / ATTACH / DETACH MASKING POLICY
    // ------------------------------------------------------------------

    /// `REDACT_SSN` with two quoted exempt roles.
    #[test]
    fn test_parse_create_masking_policy_redact_ssn() {
        let policy = parse_test_masking_policy(
            "CREATE MASKING POLICY ssn_policy STRATEGY REDACT_SSN EXEMPT ROLES ('clinician', 'billing')",
        );
        assert_eq!(policy.name, "ssn_policy");
        assert_eq!(policy.strategy, ParsedMaskingStrategy::RedactSsn);
        assert_eq!(policy.exempt_roles, vec!["clinician", "billing"]);
    }

    /// `HASH` with a single unquoted exempt role.
    #[test]
    fn test_parse_create_masking_policy_hash_single_role() {
        let policy = parse_test_masking_policy(
            "CREATE MASKING POLICY h STRATEGY HASH EXEMPT ROLES (admin)",
        );
        assert_eq!(policy.name, "h");
        assert_eq!(policy.strategy, ParsedMaskingStrategy::Hash);
        assert_eq!(policy.exempt_roles, vec!["admin"]);
    }

    /// `TOKENIZE` with a trailing semicolon.
    #[test]
    fn test_parse_create_masking_policy_tokenize() {
        let policy = parse_test_masking_policy(
            "CREATE MASKING POLICY note_tok STRATEGY TOKENIZE EXEMPT ROLES ('clinician');",
        );
        assert_eq!(policy.strategy, ParsedMaskingStrategy::Tokenize);
        assert_eq!(policy.exempt_roles, vec!["clinician"]);
    }

    /// `TRUNCATE 4` carries its numeric argument as `max_chars`.
    #[test]
    fn test_parse_create_masking_policy_truncate_with_arg() {
        let policy = parse_test_masking_policy(
            "CREATE MASKING POLICY tr STRATEGY TRUNCATE 4 EXEMPT ROLES ('billing')",
        );
        assert_eq!(
            policy.strategy,
            ParsedMaskingStrategy::Truncate { max_chars: 4 }
        );
    }

    /// `REDACT_CUSTOM '***'` carries its replacement text.
    #[test]
    fn test_parse_create_masking_policy_redact_custom() {
        let policy = parse_test_masking_policy(
            "CREATE MASKING POLICY c STRATEGY REDACT_CUSTOM '***' EXEMPT ROLES ('admin')",
        );
        match policy.strategy {
            ParsedMaskingStrategy::RedactCustom { replacement } => {
                assert_eq!(replacement, "***");
            }
            other => panic!("expected RedactCustom, got {other:?}"),
        }
    }

    /// `STRATEGY NULL` maps to `ParsedMaskingStrategy::Null`.
    #[test]
    fn test_parse_create_masking_policy_null_strategy() {
        let policy = parse_test_masking_policy(
            "CREATE MASKING POLICY n STRATEGY NULL EXEMPT ROLES ('auditor')",
        );
        assert_eq!(policy.strategy, ParsedMaskingStrategy::Null);
    }

    /// Mixed-case role names come back lower-cased.
    #[test]
    fn test_parse_create_masking_policy_lowercases_roles() {
        let policy = parse_test_masking_policy(
            "CREATE MASKING POLICY p STRATEGY HASH EXEMPT ROLES ('Clinician', 'NURSE')",
        );
        assert_eq!(policy.exempt_roles, vec!["clinician", "nurse"]);
    }

    /// An unrecognised strategy keyword is an error.
    #[test]
    fn test_parse_create_masking_policy_rejects_unknown_strategy() {
        let outcome =
            parse_statement("CREATE MASKING POLICY p STRATEGY SCRAMBLE EXEMPT ROLES ('x')");
        assert!(outcome.is_err(), "expected unknown-strategy error");
    }

    /// `TRUNCATE 0` is an error.
    #[test]
    fn test_parse_create_masking_policy_rejects_zero_truncate() {
        let outcome =
            parse_statement("CREATE MASKING POLICY p STRATEGY TRUNCATE 0 EXEMPT ROLES ('x')");
        assert!(outcome.is_err(), "TRUNCATE 0 must be rejected");
    }

    /// An empty `EXEMPT ROLES ()` list is an error.
    #[test]
    fn test_parse_create_masking_policy_rejects_empty_exempt_list() {
        let outcome = parse_statement("CREATE MASKING POLICY p STRATEGY HASH EXEMPT ROLES ()");
        assert!(outcome.is_err(), "empty EXEMPT ROLES list must be rejected");
    }

    /// Omitting the `EXEMPT ROLES` clause entirely is an error.
    #[test]
    fn test_parse_create_masking_policy_rejects_missing_exempt_roles() {
        let outcome = parse_statement("CREATE MASKING POLICY p STRATEGY HASH");
        assert!(
            outcome.is_err(),
            "missing EXEMPT ROLES clause must be rejected"
        );
    }

    /// `DROP MASKING POLICY` yields the policy name.
    #[test]
    fn test_parse_drop_masking_policy() {
        match parse_statement("DROP MASKING POLICY ssn_policy").unwrap() {
            ParsedStatement::DropMaskingPolicy(policy_name) => {
                assert_eq!(policy_name, "ssn_policy");
            }
            other => panic!("expected DropMaskingPolicy, got {other:?}"),
        }
    }

    /// A trailing semicolon is not part of the dropped policy's name.
    #[test]
    fn test_parse_drop_masking_policy_with_semicolon() {
        match parse_statement("DROP MASKING POLICY ssn_policy;").unwrap() {
            ParsedStatement::DropMaskingPolicy(policy_name) => {
                assert_eq!(policy_name, "ssn_policy");
            }
            other => panic!("expected DropMaskingPolicy, got {other:?}"),
        }
    }

    /// `DROP MASK` still parses as `DropMask`, not as a masking-policy drop.
    #[test]
    fn test_parse_drop_masking_policy_does_not_swallow_drop_mask() {
        let statement = parse_statement("DROP MASK ssn_mask").unwrap();
        assert!(matches!(statement, ParsedStatement::DropMask(_)));
    }

    /// `ALTER COLUMN ... SET MASKING POLICY` captures table, column and policy.
    #[test]
    fn test_parse_attach_masking_policy() {
        let sql =
            "ALTER TABLE patients ALTER COLUMN medicare_number SET MASKING POLICY ssn_policy";
        match parse_statement(sql).unwrap() {
            ParsedStatement::AttachMaskingPolicy(attach) => {
                assert_eq!(attach.table_name, "patients");
                assert_eq!(attach.column_name, "medicare_number");
                assert_eq!(attach.policy_name, "ssn_policy");
            }
            other => panic!("expected AttachMaskingPolicy, got {other:?}"),
        }
    }

    /// `ALTER COLUMN ... DROP MASKING POLICY` captures table and column.
    #[test]
    fn test_parse_detach_masking_policy() {
        let sql = "ALTER TABLE patients ALTER COLUMN medicare_number DROP MASKING POLICY";
        match parse_statement(sql).unwrap() {
            ParsedStatement::DetachMaskingPolicy(detach) => {
                assert_eq!(detach.table_name, "patients");
                assert_eq!(detach.column_name, "medicare_number");
            }
            other => panic!("expected DetachMaskingPolicy, got {other:?}"),
        }
    }

    /// `SET MASKING POLICY` with no policy name is an error.
    #[test]
    fn test_parse_attach_masking_policy_rejects_missing_policy_name() {
        let outcome =
            parse_statement("ALTER TABLE patients ALTER COLUMN medicare_number SET MASKING POLICY");
        assert!(outcome.is_err());
    }

    /// `CREATE MASKING POLICY` parses as `CreateMaskingPolicy`, not as the
    /// older `CREATE MASK` form.
    #[test]
    fn test_parse_create_masking_policy_does_not_match_legacy_create_mask() {
        let statement =
            parse_statement("CREATE MASKING POLICY p STRATEGY HASH EXEMPT ROLES ('admin')")
                .unwrap();
        assert!(matches!(statement, ParsedStatement::CreateMaskingPolicy(_)));
    }

    // ------------------------------------------------------------------
    // SET CLASSIFICATION / SHOW CLASSIFICATIONS
    // ------------------------------------------------------------------

    /// `MODIFY COLUMN ... SET CLASSIFICATION` captures table, column and label.
    #[test]
    fn test_parse_set_classification() {
        let statement =
            parse_statement("ALTER TABLE patients MODIFY COLUMN ssn SET CLASSIFICATION 'PHI'")
                .unwrap();

        let ParsedStatement::SetClassification(parsed) = statement else {
            panic!("expected SetClassification statement");
        };
        assert_eq!(parsed.table_name, "patients");
        assert_eq!(parsed.column_name, "ssn");
        assert_eq!(parsed.classification, "PHI");
    }

    /// A trailing semicolon does not leak into the classification label.
    #[test]
    fn test_parse_set_classification_with_semicolon() {
        let statement = parse_statement(
            "ALTER TABLE patients MODIFY COLUMN diagnosis SET CLASSIFICATION 'MEDICAL';",
        )
        .unwrap();

        let ParsedStatement::SetClassification(parsed) = statement else {
            panic!("expected SetClassification statement");
        };
        assert_eq!(parsed.table_name, "patients");
        assert_eq!(parsed.column_name, "diagnosis");
        assert_eq!(parsed.classification, "MEDICAL");
    }

    /// Each of several labels round-trips unchanged.
    #[test]
    fn test_parse_set_classification_various_labels() {
        const LABELS: [&str; 6] = ["PHI", "PII", "PCI", "MEDICAL", "FINANCIAL", "CONFIDENTIAL"];

        for label in LABELS {
            let sql = format!("ALTER TABLE t MODIFY COLUMN c SET CLASSIFICATION '{label}'");
            let ParsedStatement::SetClassification(parsed) = parse_statement(&sql).unwrap() else {
                panic!("expected SetClassification for {label}");
            };
            assert_eq!(parsed.classification, label);
        }
    }

    /// An unquoted classification label is an error.
    #[test]
    fn test_parse_set_classification_missing_quotes() {
        let outcome =
            parse_statement("ALTER TABLE patients MODIFY COLUMN ssn SET CLASSIFICATION PHI");
        assert!(outcome.is_err(), "classification must be single-quoted");
    }

    /// `SET CLASSIFICATION` without a `MODIFY COLUMN` clause is an error.
    #[test]
    fn test_parse_set_classification_missing_modify() {
        let outcome = parse_statement("ALTER TABLE patients SET CLASSIFICATION 'PHI'");
        assert!(outcome.is_err());
    }

    /// `SHOW CLASSIFICATIONS FOR <table>` yields the table name.
    #[test]
    fn test_parse_show_classifications() {
        let statement = parse_statement("SHOW CLASSIFICATIONS FOR patients").unwrap();

        let ParsedStatement::ShowClassifications(table_name) = statement else {
            panic!("expected ShowClassifications statement");
        };
        assert_eq!(table_name, "patients");
    }

    /// A trailing semicolon is not part of the table name.
    #[test]
    fn test_parse_show_classifications_with_semicolon() {
        let statement = parse_statement("SHOW CLASSIFICATIONS FOR patients;").unwrap();

        let ParsedStatement::ShowClassifications(table_name) = statement else {
            panic!("expected ShowClassifications statement");
        };
        assert_eq!(table_name, "patients");
    }

    /// `SHOW CLASSIFICATIONS` without `FOR` is an error.
    #[test]
    fn test_parse_show_classifications_missing_for() {
        let outcome = parse_statement("SHOW CLASSIFICATIONS patients");
        assert!(outcome.is_err());
    }

    /// `SHOW CLASSIFICATIONS FOR` with no table name is an error.
    #[test]
    fn test_parse_show_classifications_missing_table() {
        let outcome = parse_statement("SHOW CLASSIFICATIONS FOR");
        assert!(outcome.is_err());
    }

    // ------------------------------------------------------------------
    // Roles, grants and users
    // ------------------------------------------------------------------

    /// `CREATE ROLE` yields the role name.
    #[test]
    fn test_parse_create_role() {
        let statement = parse_statement("CREATE ROLE billing_clerk").unwrap();

        let ParsedStatement::CreateRole(role_name) = statement else {
            panic!("expected CreateRole");
        };
        assert_eq!(role_name, "billing_clerk");
    }

    /// A trailing semicolon is not part of the role name.
    #[test]
    fn test_parse_create_role_with_semicolon() {
        let statement = parse_statement("CREATE ROLE doctor;").unwrap();

        let ParsedStatement::CreateRole(role_name) = statement else {
            panic!("expected CreateRole");
        };
        assert_eq!(role_name, "doctor");
    }

    /// `GRANT SELECT ON ...` without a column list leaves `columns` unset.
    #[test]
    fn test_parse_grant_select_all_columns() {
        let statement = parse_statement("GRANT SELECT ON patients TO doctor").unwrap();

        let ParsedStatement::Grant(grant) = statement else {
            panic!("expected Grant");
        };
        assert!(grant.columns.is_none());
        assert_eq!(grant.table_name, "patients");
        assert_eq!(grant.role_name, "doctor");
    }

    /// `GRANT SELECT (cols) ON ...` records the listed columns in order.
    #[test]
    fn test_parse_grant_select_specific_columns() {
        let statement =
            parse_statement("GRANT SELECT (id, name, ssn) ON patients TO billing_clerk").unwrap();

        let ParsedStatement::Grant(grant) = statement else {
            panic!("expected Grant");
        };
        let expected_columns = vec![
            String::from("id"),
            String::from("name"),
            String::from("ssn"),
        ];
        assert_eq!(grant.columns, Some(expected_columns));
        assert_eq!(grant.table_name, "patients");
        assert_eq!(grant.role_name, "billing_clerk");
    }

    /// `CREATE USER ... WITH ROLE ...` captures username and role.
    #[test]
    fn test_parse_create_user() {
        let statement = parse_statement("CREATE USER clerk1 WITH ROLE billing_clerk").unwrap();

        let ParsedStatement::CreateUser(user) = statement else {
            panic!("expected CreateUser");
        };
        assert_eq!(user.username, "clerk1");
        assert_eq!(user.role, "billing_clerk");
    }

    /// A trailing semicolon is not part of the role name.
    #[test]
    fn test_parse_create_user_with_semicolon() {
        let statement = parse_statement("CREATE USER admin1 WITH ROLE admin;").unwrap();

        let ParsedStatement::CreateUser(user) = statement else {
            panic!("expected CreateUser");
        };
        assert_eq!(user.username, "admin1");
        assert_eq!(user.role, "admin");
    }

    /// `CREATE USER` without the `ROLE` keyword is an error.
    #[test]
    fn test_parse_create_user_missing_role() {
        let outcome = parse_statement("CREATE USER clerk1 WITH billing_clerk");
        assert!(outcome.is_err());
    }

    // ------------------------------------------------------------------
    // CREATE TABLE
    // ------------------------------------------------------------------

    /// `CREATE TABLE` inputs that define no columns must produce an error.
    #[test]
    fn test_parse_create_table_rejects_zero_columns() {
        let no_column_list = parse_statement("CREATE TABLE#USER");
        assert!(
            no_column_list.is_err(),
            "zero-column CREATE TABLE must be rejected"
        );

        let empty_column_list = parse_statement("CREATE TABLE t ()");
        assert!(
            empty_column_list.is_err(),
            "empty-column-list CREATE TABLE must be rejected"
        );
    }

    // ------------------------------------------------------------------
    // INSERT ... ON CONFLICT
    // ------------------------------------------------------------------

    /// `DO UPDATE SET col = EXCLUDED.col` yields one `Excluded` assignment.
    #[test]
    fn test_parse_insert_on_conflict_do_update() {
        let insert = parse_test_insert(
            "INSERT INTO users (id, name) VALUES (1, 'Alice') \
             ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name",
        );
        let conflict = insert.on_conflict.expect("on_conflict must be present");
        assert_eq!(conflict.target, vec!["id".to_string()]);

        match conflict.action {
            OnConflictAction::DoUpdate { assignments } => {
                assert_eq!(assignments.len(), 1);
                let assignment = &assignments[0];
                assert_eq!(assignment.0, "name");
                assert_eq!(
                    assignment.1,
                    UpsertExpr::Excluded("name".to_string()),
                    "RHS must be an EXCLUDED.col back-reference"
                );
            }
            other @ OnConflictAction::DoNothing => panic!("expected DoUpdate, got {other:?}"),
        }
    }

    /// `DO NOTHING` maps to `OnConflictAction::DoNothing`.
    #[test]
    fn test_parse_insert_on_conflict_do_nothing() {
        let insert = parse_test_insert(
            "INSERT INTO users (id, name) VALUES (1, 'Alice') ON CONFLICT (id) DO NOTHING",
        );
        let conflict = insert.on_conflict.expect("on_conflict must be present");
        assert_eq!(conflict.target, vec!["id".to_string()]);
        assert!(
            matches!(conflict.action, OnConflictAction::DoNothing),
            "DO NOTHING must parse to OnConflictAction::DoNothing"
        );
    }

    /// An `INSERT` without `ON CONFLICT` has no conflict clause.
    #[test]
    fn test_parse_plain_insert_has_no_on_conflict() {
        let insert = parse_test_insert("INSERT INTO users (id, name) VALUES (1, 'Alice')");
        assert!(
            insert.on_conflict.is_none(),
            "plain INSERT must not carry an on_conflict clause"
        );
    }

    /// A multi-column conflict target is kept in source order.
    #[test]
    fn test_parse_insert_on_conflict_multi_column_target() {
        let insert = parse_test_insert(
            "INSERT INTO t (tenant_id, id, v) VALUES (1, 2, 3) \
             ON CONFLICT (tenant_id, id) DO UPDATE SET v = EXCLUDED.v",
        );
        let conflict = insert.on_conflict.expect("on_conflict must be present");
        assert_eq!(
            conflict.target,
            vec!["tenant_id".to_string(), "id".to_string()]
        );
    }

    /// `RETURNING` after `ON CONFLICT` is parsed alongside the conflict clause.
    #[test]
    fn test_parse_insert_on_conflict_with_returning() {
        let insert = parse_test_insert(
            "INSERT INTO t (id, v) VALUES (1, 2) \
             ON CONFLICT (id) DO UPDATE SET v = EXCLUDED.v RETURNING id, v",
        );
        assert!(insert.on_conflict.is_some());
        assert_eq!(
            insert.returning,
            Some(vec!["id".to_string(), "v".to_string()])
        );
    }

    /// The `ON CONFLICT ON CONSTRAINT <name>` form is an error.
    #[test]
    fn test_parse_insert_on_conflict_rejects_on_constraint_form() {
        let outcome = parse_statement(
            "INSERT INTO t (id) VALUES (1) ON CONFLICT ON CONSTRAINT pk_t DO NOTHING",
        );
        assert!(
            outcome.is_err(),
            "ON CONSTRAINT form must be rejected with a clear error"
        );
    }

    /// A literal on the right-hand side of `DO UPDATE SET` becomes `UpsertExpr::Value`.
    #[test]
    fn test_parse_insert_on_conflict_literal_rhs() {
        let insert = parse_test_insert(
            "INSERT INTO t (id, v) VALUES (1, 2) \
             ON CONFLICT (id) DO UPDATE SET v = 42",
        );
        let conflict = insert.on_conflict.expect("on_conflict must be present");

        match conflict.action {
            OnConflictAction::DoUpdate { assignments } => {
                let assignment = &assignments[0];
                assert_eq!(assignment.0, "v");
                assert!(matches!(assignment.1, UpsertExpr::Value(_)));
            }
            other @ OnConflictAction::DoNothing => panic!("expected DoUpdate, got {other:?}"),
        }
    }
}