1use std::collections::HashSet;
2use std::fmt::Debug;
3use std::ops::ControlFlow;
4
5use datafusion::sql::sqlparser::ast::Array;
6use datafusion::sql::sqlparser::ast::ArrayElemTypeDef;
7use datafusion::sql::sqlparser::ast::BinaryOperator;
8use datafusion::sql::sqlparser::ast::CastKind;
9use datafusion::sql::sqlparser::ast::DataType;
10use datafusion::sql::sqlparser::ast::Expr;
11use datafusion::sql::sqlparser::ast::Function;
12use datafusion::sql::sqlparser::ast::FunctionArg;
13use datafusion::sql::sqlparser::ast::FunctionArgExpr;
14use datafusion::sql::sqlparser::ast::FunctionArgumentList;
15use datafusion::sql::sqlparser::ast::FunctionArguments;
16use datafusion::sql::sqlparser::ast::Ident;
17use datafusion::sql::sqlparser::ast::ObjectName;
18use datafusion::sql::sqlparser::ast::ObjectNamePart;
19use datafusion::sql::sqlparser::ast::OrderByKind;
20use datafusion::sql::sqlparser::ast::Query;
21use datafusion::sql::sqlparser::ast::Select;
22use datafusion::sql::sqlparser::ast::SelectItem;
23use datafusion::sql::sqlparser::ast::SelectItemQualifiedWildcardKind;
24use datafusion::sql::sqlparser::ast::SetExpr;
25use datafusion::sql::sqlparser::ast::Statement;
26use datafusion::sql::sqlparser::ast::TableFactor;
27use datafusion::sql::sqlparser::ast::TableWithJoins;
28use datafusion::sql::sqlparser::ast::UnaryOperator;
29use datafusion::sql::sqlparser::ast::Value;
30use datafusion::sql::sqlparser::ast::ValueWithSpan;
31use datafusion::sql::sqlparser::ast::VisitMut;
32use datafusion::sql::sqlparser::ast::VisitorMut;
33
34pub trait SqlStatementRewriteRule: Send + Sync + Debug {
35 fn rewrite(&self, s: Statement) -> Statement;
36}
37
38#[derive(Debug)]
47pub struct AliasDuplicatedProjectionRewrite;
48
49impl AliasDuplicatedProjectionRewrite {
50 fn rewrite_select_with_alias(select: &mut Box<Select>) {
52 let mut wildcard_tables = Vec::new();
54 let mut has_simple_wildcard = false;
55 for p in &select.projection {
56 match p {
57 SelectItem::QualifiedWildcard(name, _) => match name {
58 SelectItemQualifiedWildcardKind::ObjectName(objname) => {
59 let idents = objname
61 .0
62 .iter()
63 .map(|v| v.as_ident().unwrap().value.clone())
64 .collect::<Vec<_>>()
65 .join(".");
66
67 wildcard_tables.push(idents);
68 }
69 SelectItemQualifiedWildcardKind::Expr(_expr) => {
70 }
72 },
73 SelectItem::Wildcard(_) => {
74 has_simple_wildcard = true;
75 }
76 _ => {}
77 }
78 }
79
80 if wildcard_tables.is_empty() && !has_simple_wildcard {
82 return;
83 }
84
85 let mut new_projection = vec![];
87 for p in select.projection.drain(..) {
88 match p {
89 SelectItem::UnnamedExpr(expr) => {
90 let alias_partial = match &expr {
91 Expr::Identifier(ident) => Some(ident.clone()),
93 Expr::CompoundIdentifier(idents) => {
95 if idents.len() > 1 {
97 let table_name = &idents[..idents.len() - 1]
98 .iter()
99 .map(|i| i.value.clone())
100 .collect::<Vec<_>>()
101 .join(".");
102 if wildcard_tables.iter().any(|name| name == table_name) {
103 Some(idents[idents.len() - 1].clone())
104 } else {
105 None
106 }
107 } else {
108 None
109 }
110 }
111 _ => None,
112 };
113
114 if let Some(name) = alias_partial {
115 let alias = format!("__alias_{name}");
116 new_projection.push(SelectItem::ExprWithAlias {
117 expr,
118 alias: Ident::new(alias),
119 });
120 } else {
121 new_projection.push(SelectItem::UnnamedExpr(expr));
122 }
123 }
124 _ => new_projection.push(p),
126 }
127 }
128 select.projection = new_projection;
129 }
130}
131
132impl SqlStatementRewriteRule for AliasDuplicatedProjectionRewrite {
133 fn rewrite(&self, mut statement: Statement) -> Statement {
134 if let Statement::Query(query) = &mut statement {
135 if let SetExpr::Select(select) = query.body.as_mut() {
136 Self::rewrite_select_with_alias(select);
137 }
138 }
139
140 statement
141 }
142}
143
144#[derive(Debug)]
149pub struct ResolveUnqualifiedIdentifer;
150
151impl ResolveUnqualifiedIdentifer {
152 fn rewrite_unqualified_identifiers(query: &mut Box<Query>) {
153 if let SetExpr::Select(select) = query.body.as_mut() {
154 let table_aliases = Self::get_table_aliases(&select.from);
156
157 let qualified_wildcard_alias = Self::get_qualified_wildcard_alias(&select.projection);
159 if qualified_wildcard_alias.is_none() || table_aliases.is_empty() {
160 return; }
162
163 let wildcard_alias = qualified_wildcard_alias.unwrap();
164
165 let projection_aliases = Self::get_projection_aliases(&select.projection);
167
168 if let Some(selection) = &mut select.selection {
170 Self::rewrite_expr(
171 selection,
172 &wildcard_alias,
173 &table_aliases,
174 &projection_aliases,
175 );
176 }
177
178 if let Some(OrderByKind::Expressions(order_by_exprs)) =
179 query.order_by.as_mut().map(|o| &mut o.kind)
180 {
181 for order_by_expr in order_by_exprs {
182 Self::rewrite_expr(
183 &mut order_by_expr.expr,
184 &wildcard_alias,
185 &table_aliases,
186 &projection_aliases,
187 );
188 }
189 }
190 }
191 }
192
193 fn get_table_aliases(tables: &[TableWithJoins]) -> HashSet<String> {
194 let mut aliases = HashSet::new();
195 for table_with_joins in tables {
196 if let TableFactor::Table {
197 alias: Some(alias), ..
198 } = &table_with_joins.relation
199 {
200 aliases.insert(alias.name.value.clone());
201 }
202 for join in &table_with_joins.joins {
203 if let TableFactor::Table {
204 alias: Some(alias), ..
205 } = &join.relation
206 {
207 aliases.insert(alias.name.value.clone());
208 }
209 }
210 }
211 aliases
212 }
213
214 fn get_qualified_wildcard_alias(projection: &[SelectItem]) -> Option<String> {
215 let mut qualified_wildcards = projection
216 .iter()
217 .filter_map(|item| {
218 if let SelectItem::QualifiedWildcard(
219 SelectItemQualifiedWildcardKind::ObjectName(objname),
220 _,
221 ) = item
222 {
223 Some(
224 objname
225 .0
226 .iter()
227 .map(|v| v.as_ident().unwrap().value.clone())
228 .collect::<Vec<_>>()
229 .join("."),
230 )
231 } else {
232 None
233 }
234 })
235 .collect::<Vec<_>>();
236
237 if qualified_wildcards.len() == 1 {
238 Some(qualified_wildcards.remove(0))
239 } else {
240 None
241 }
242 }
243
244 fn get_projection_aliases(projection: &[SelectItem]) -> HashSet<String> {
245 let mut aliases = HashSet::new();
246 for item in projection {
247 match item {
248 SelectItem::ExprWithAlias { alias, .. } => {
249 aliases.insert(alias.value.clone());
250 }
251 SelectItem::UnnamedExpr(Expr::Identifier(ident)) => {
252 aliases.insert(ident.value.clone());
253 }
254 _ => {}
255 }
256 }
257 aliases
258 }
259
260 fn rewrite_expr(
261 expr: &mut Expr,
262 wildcard_alias: &str,
263 table_aliases: &HashSet<String>,
264 projection_aliases: &HashSet<String>,
265 ) {
266 match expr {
267 Expr::Identifier(ident) => {
268 if !table_aliases.contains(&ident.value)
270 && !projection_aliases.contains(&ident.value)
271 {
272 *expr = Expr::CompoundIdentifier(vec![
273 Ident::new(wildcard_alias.to_string()),
274 ident.clone(),
275 ]);
276 }
277 }
278 Expr::BinaryOp { left, right, .. } => {
279 Self::rewrite_expr(left, wildcard_alias, table_aliases, projection_aliases);
280 Self::rewrite_expr(right, wildcard_alias, table_aliases, projection_aliases);
281 }
282 _ => {}
284 }
285 }
286}
287
288impl SqlStatementRewriteRule for ResolveUnqualifiedIdentifer {
289 fn rewrite(&self, mut statement: Statement) -> Statement {
290 if let Statement::Query(query) = &mut statement {
291 Self::rewrite_unqualified_identifiers(query);
292 }
293
294 statement
295 }
296}
297
298#[derive(Debug)]
301pub struct RemoveUnsupportedTypes {
302 unsupported_types: HashSet<String>,
303}
304
305impl Default for RemoveUnsupportedTypes {
306 fn default() -> Self {
307 Self::new()
308 }
309}
310
311impl RemoveUnsupportedTypes {
312 pub fn new() -> Self {
313 let mut unsupported_types = HashSet::new();
314
315 for item in [
316 "regclass",
317 "regproc",
318 "regtype",
319 "regtype[]",
320 "regnamespace",
321 "oid",
322 ] {
323 unsupported_types.insert(item.to_owned());
324 unsupported_types.insert(format!("pg_catalog.{item}"));
325 }
326
327 Self { unsupported_types }
328 }
329}
330
331struct RemoveUnsupportedTypesVisitor<'a> {
332 unsupported_types: &'a HashSet<String>,
333}
334
335impl VisitorMut for RemoveUnsupportedTypesVisitor<'_> {
336 type Break = ();
337
338 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
339 match expr {
340 Expr::TypedString { value, data_type } => {
342 if self
343 .unsupported_types
344 .contains(data_type.to_string().to_lowercase().as_str())
345 {
346 *expr =
347 Expr::Value(Value::SingleQuotedString(value.to_string()).with_empty_span());
348 }
349 }
350 Expr::Cast {
351 data_type,
352 expr: value,
353 ..
354 } => {
355 if self
356 .unsupported_types
357 .contains(data_type.to_string().to_lowercase().as_str())
358 {
359 *expr = *value.clone();
360 }
361 }
362 _ => {}
364 }
365
366 ControlFlow::Continue(())
367 }
368}
369
370impl SqlStatementRewriteRule for RemoveUnsupportedTypes {
371 fn rewrite(&self, mut statement: Statement) -> Statement {
372 let mut visitor = RemoveUnsupportedTypesVisitor {
373 unsupported_types: &self.unsupported_types,
374 };
375 let _ = statement.visit(&mut visitor);
376 statement
377 }
378}
379
380#[derive(Debug)]
382pub struct RewriteArrayAnyAllOperation;
383
384struct RewriteArrayAnyAllOperationVisitor;
385
386impl RewriteArrayAnyAllOperationVisitor {
387 fn any_to_array_cofntains(&self, left: &Expr, right: &Expr) -> Expr {
388 let array = if let Expr::Value(ValueWithSpan {
389 value: Value::SingleQuotedString(array_literal),
390 ..
391 }) = right
392 {
393 let array_literal = array_literal.trim();
394 if array_literal.starts_with('{') && array_literal.ends_with('}') {
395 let items = array_literal.trim_matches(|c| c == '{' || c == '}' || c == ' ');
396 let items = items.split(',').map(|s| s.trim()).filter(|s| !s.is_empty());
397
398 let elems = items
400 .map(|s| {
401 Expr::Value(Value::SingleQuotedString(s.to_string()).with_empty_span())
402 })
403 .collect();
404 Expr::Array(Array {
405 elem: elems,
406 named: true,
407 })
408 } else {
409 right.clone()
410 }
411 } else {
412 right.clone()
413 };
414
415 Expr::Function(Function {
416 name: ObjectName::from(vec![Ident::new("array_contains")]),
417 args: FunctionArguments::List(FunctionArgumentList {
418 args: vec![
419 FunctionArg::Unnamed(FunctionArgExpr::Expr(array)),
420 FunctionArg::Unnamed(FunctionArgExpr::Expr(left.clone())),
421 ],
422 duplicate_treatment: None,
423 clauses: vec![],
424 }),
425 uses_odbc_syntax: false,
426 parameters: FunctionArguments::None,
427 filter: None,
428 null_treatment: None,
429 over: None,
430 within_group: vec![],
431 })
432 }
433}
434
435impl VisitorMut for RewriteArrayAnyAllOperationVisitor {
436 type Break = ();
437
438 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
439 match expr {
440 Expr::AnyOp {
441 left,
442 compare_op,
443 right,
444 ..
445 } => match compare_op {
446 BinaryOperator::Eq => {
447 *expr = self.any_to_array_cofntains(left.as_ref(), right.as_ref());
448 }
449 BinaryOperator::NotEq => {
450 }
452 _ => {}
453 },
454 Expr::AllOp {
455 left,
456 compare_op,
457 right,
458 } => match compare_op {
459 BinaryOperator::Eq => {
460 }
462 BinaryOperator::NotEq => {
463 *expr = Expr::UnaryOp {
464 op: UnaryOperator::Not,
465 expr: Box::new(self.any_to_array_cofntains(left.as_ref(), right.as_ref())),
466 }
467 }
468 _ => {}
469 },
470 _ => {}
471 }
472
473 ControlFlow::Continue(())
474 }
475}
476
477impl SqlStatementRewriteRule for RewriteArrayAnyAllOperation {
478 fn rewrite(&self, mut s: Statement) -> Statement {
479 let mut visitor = RewriteArrayAnyAllOperationVisitor;
480
481 let _ = s.visit(&mut visitor);
482
483 s
484 }
485}
486
487#[derive(Debug)]
492pub struct PrependUnqualifiedPgTableName;
493
494struct PrependUnqualifiedPgTableNameVisitor;
495
496impl VisitorMut for PrependUnqualifiedPgTableNameVisitor {
497 type Break = ();
498
499 fn pre_visit_table_factor(
500 &mut self,
501 table_factor: &mut TableFactor,
502 ) -> ControlFlow<Self::Break> {
503 if let TableFactor::Table { name, args, .. } = table_factor {
504 if args.is_none() && name.0.len() == 1 {
506 if let ObjectNamePart::Identifier(ident) = &name.0[0] {
507 if ident.value.starts_with("pg_") {
508 *name = ObjectName(vec![
509 ObjectNamePart::Identifier(Ident::new("pg_catalog")),
510 name.0[0].clone(),
511 ]);
512 }
513 }
514 }
515 }
516
517 ControlFlow::Continue(())
518 }
519}
520
521impl SqlStatementRewriteRule for PrependUnqualifiedPgTableName {
522 fn rewrite(&self, mut s: Statement) -> Statement {
523 let mut visitor = PrependUnqualifiedPgTableNameVisitor;
524
525 let _ = s.visit(&mut visitor);
526 s
527 }
528}
529
530#[derive(Debug)]
531pub struct FixArrayLiteral;
532
533struct FixArrayLiteralVisitor;
534
535impl FixArrayLiteralVisitor {
536 fn is_string_type(dt: &DataType) -> bool {
537 matches!(
538 dt,
539 DataType::Text | DataType::Varchar(_) | DataType::Char(_) | DataType::String(_)
540 )
541 }
542}
543
544impl VisitorMut for FixArrayLiteralVisitor {
545 type Break = ();
546
547 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
548 if let Expr::Cast {
549 kind,
550 expr,
551 data_type,
552 ..
553 } = expr
554 {
555 if kind == &CastKind::DoubleColon {
556 if let DataType::Array(arr) = data_type {
557 if let Expr::Value(ValueWithSpan {
559 value: Value::SingleQuotedString(array_literal),
560 ..
561 }) = expr.as_ref()
562 {
563 let items =
564 array_literal.trim_matches(|c| c == '{' || c == '}' || c == ' ');
565 let items = items.split(',').map(|s| s.trim()).filter(|s| !s.is_empty());
566
567 let is_text = match arr {
568 ArrayElemTypeDef::AngleBracket(dt) => Self::is_string_type(dt.as_ref()),
569 ArrayElemTypeDef::SquareBracket(dt, _) => {
570 Self::is_string_type(dt.as_ref())
571 }
572 ArrayElemTypeDef::Parenthesis(dt) => Self::is_string_type(dt.as_ref()),
573 _ => false,
574 };
575
576 let elems = items
577 .map(|s| {
578 if is_text {
579 Expr::Value(
580 Value::SingleQuotedString(s.to_string()).with_empty_span(),
581 )
582 } else {
583 Expr::Value(
584 Value::Number(s.to_string(), false).with_empty_span(),
585 )
586 }
587 })
588 .collect();
589 *expr = Box::new(Expr::Array(Array {
590 elem: elems,
591 named: true,
592 }));
593 }
594 }
595 }
596 }
597
598 ControlFlow::Continue(())
599 }
600}
601
602impl SqlStatementRewriteRule for FixArrayLiteral {
603 fn rewrite(&self, mut s: Statement) -> Statement {
604 let mut visitor = FixArrayLiteralVisitor;
605
606 let _ = s.visit(&mut visitor);
607 s
608 }
609}
610
611#[derive(Debug)]
618pub struct RemoveQualifier;
619
620struct RemoveQualifierVisitor;
621
622impl VisitorMut for RemoveQualifierVisitor {
623 type Break = ();
624
625 fn pre_visit_table_factor(
626 &mut self,
627 table_factor: &mut TableFactor,
628 ) -> ControlFlow<Self::Break> {
629 if let TableFactor::Table { name, args, .. } = table_factor {
631 if args.is_some() {
632 if name.0.len() > 1 {
634 if let Some(last_ident) = name.0.pop() {
635 *name = ObjectName(vec![last_ident]);
636 }
637 }
638 }
639 }
640 ControlFlow::Continue(())
641 }
642
643 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
644 match expr {
645 Expr::Cast { data_type, .. } => {
646 let data_type_str = data_type.to_string();
648 match data_type_str.as_str() {
649 "pg_catalog.text" => {
650 *data_type = DataType::Text;
651 }
652 "pg_catalog.int2[]" => {
653 *data_type = DataType::Array(ArrayElemTypeDef::SquareBracket(
654 Box::new(DataType::Int16),
655 None,
656 ));
657 }
658 _ => {}
659 }
660 }
661 Expr::Function(function) => {
662 let name = &mut function.name;
664 if name.0.len() > 1 {
665 if let Some(last_ident) = name.0.pop() {
666 *name = ObjectName(vec![last_ident]);
667 }
668 }
669 }
670
671 _ => {}
672 }
673 ControlFlow::Continue(())
674 }
675}
676
677impl SqlStatementRewriteRule for RemoveQualifier {
678 fn rewrite(&self, mut s: Statement) -> Statement {
679 let mut visitor = RemoveQualifierVisitor;
680
681 let _ = s.visit(&mut visitor);
682 s
683 }
684}
685
686#[derive(Debug)]
688pub struct CurrentUserVariableToSessionUserFunctionCall;
689
690struct CurrentUserVariableToSessionUserFunctionCallVisitor;
691
692impl VisitorMut for CurrentUserVariableToSessionUserFunctionCallVisitor {
693 type Break = ();
694
695 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
696 if let Expr::Identifier(ident) = expr {
697 if ident.quote_style.is_none() && ident.value.to_lowercase() == "current_user" {
698 *expr = Expr::Function(Function {
699 name: ObjectName::from(vec![Ident::new("session_user")]),
700 args: FunctionArguments::None,
701 uses_odbc_syntax: false,
702 parameters: FunctionArguments::None,
703 filter: None,
704 null_treatment: None,
705 over: None,
706 within_group: vec![],
707 });
708 }
709 }
710
711 if let Expr::Function(func) = expr {
712 let fname = func
713 .name
714 .0
715 .iter()
716 .map(|ident| ident.to_string())
717 .collect::<Vec<String>>()
718 .join(".");
719 if fname.to_lowercase() == "current_user" {
720 func.name = ObjectName::from(vec![Ident::new("session_user")])
721 }
722 }
723
724 ControlFlow::Continue(())
725 }
726}
727
728impl SqlStatementRewriteRule for CurrentUserVariableToSessionUserFunctionCall {
729 fn rewrite(&self, mut s: Statement) -> Statement {
730 let mut visitor = CurrentUserVariableToSessionUserFunctionCallVisitor;
731
732 let _ = s.visit(&mut visitor);
733 s
734 }
735}
736
737#[derive(Debug)]
739pub struct FixCollate;
740
741struct FixCollateVisitor;
742
743impl VisitorMut for FixCollateVisitor {
744 type Break = ();
745
746 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
747 match expr {
748 Expr::Collate { expr: inner, .. } => {
749 *expr = inner.as_ref().clone();
750 }
751 Expr::BinaryOp { op, .. } => {
752 if let BinaryOperator::PGCustomBinaryOperator(ops) = op {
753 if *ops == ["pg_catalog", "~"] {
754 *op = BinaryOperator::PGRegexMatch;
755 }
756 }
757 }
758 _ => {}
759 }
760
761 ControlFlow::Continue(())
762 }
763}
764
765impl SqlStatementRewriteRule for FixCollate {
766 fn rewrite(&self, mut s: Statement) -> Statement {
767 let mut visitor = FixCollateVisitor;
768
769 let _ = s.visit(&mut visitor);
770 s
771 }
772}
773
774#[derive(Debug)]
776pub struct RemoveSubqueryFromProjection;
777
778struct RemoveSubqueryFromProjectionVisitor;
779
780impl VisitorMut for RemoveSubqueryFromProjectionVisitor {
781 type Break = ();
782
783 fn pre_visit_query(&mut self, query: &mut Query) -> ControlFlow<Self::Break> {
784 if let SetExpr::Select(select) = query.body.as_mut() {
785 for projection in &mut select.projection {
786 match projection {
787 SelectItem::UnnamedExpr(expr) => {
788 if let Expr::Subquery(_) = expr {
789 *expr = Expr::Value(Value::Null.with_empty_span());
790 }
791 }
792 SelectItem::ExprWithAlias { expr, .. } => {
793 if let Expr::Subquery(_) = expr {
794 *expr = Expr::Value(Value::Null.with_empty_span());
795 }
796 }
797 _ => {}
798 }
799 }
800 }
801
802 ControlFlow::Continue(())
803 }
804}
805
806impl SqlStatementRewriteRule for RemoveSubqueryFromProjection {
807 fn rewrite(&self, mut s: Statement) -> Statement {
808 let mut visitor = RemoveSubqueryFromProjectionVisitor;
809 let _ = s.visit(&mut visitor);
810
811 s
812 }
813}
814
815#[derive(Debug)]
817pub struct FixVersionColumnName;
818
819struct FixVersionColumnNameVisitor;
820
821impl VisitorMut for FixVersionColumnNameVisitor {
822 type Break = ();
823
824 fn pre_visit_query(&mut self, query: &mut Query) -> ControlFlow<Self::Break> {
825 if let SetExpr::Select(select) = query.body.as_mut() {
826 for projection in &mut select.projection {
827 if let SelectItem::UnnamedExpr(Expr::Function(f)) = projection {
828 if f.name.0.len() == 1 {
829 if let ObjectNamePart::Identifier(part) = &f.name.0[0] {
830 if part.value == "version" {
831 if let FunctionArguments::List(args) = &f.args {
832 if args.args.is_empty() {
833 *projection = SelectItem::ExprWithAlias {
834 expr: Expr::Function(f.clone()),
835 alias: Ident::new("version"),
836 }
837 }
838 }
839 }
840 }
841 }
842 }
843 }
844 }
845
846 ControlFlow::Continue(())
847 }
848}
849
850impl SqlStatementRewriteRule for FixVersionColumnName {
851 fn rewrite(&self, mut s: Statement) -> Statement {
852 let mut visitor = FixVersionColumnNameVisitor;
853 let _ = s.visit(&mut visitor);
854
855 s
856 }
857}
858
859#[cfg(test)]
860mod tests {
861 use super::*;
862 use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
863 use datafusion::sql::sqlparser::parser::Parser;
864 use datafusion::sql::sqlparser::parser::ParserError;
865 use std::sync::Arc;
866
867 fn parse(sql: &str) -> Result<Vec<Statement>, ParserError> {
868 let dialect = PostgreSqlDialect {};
869
870 Parser::parse_sql(&dialect, sql)
871 }
872
873 fn rewrite(mut s: Statement, rules: &[Arc<dyn SqlStatementRewriteRule>]) -> Statement {
874 for rule in rules {
875 s = rule.rewrite(s);
876 }
877
878 s
879 }
880
881 macro_rules! assert_rewrite {
882 ($rules:expr, $orig:expr, $rewt:expr) => {
883 let sql = $orig;
884 let statement = parse(sql).expect("Failed to parse").remove(0);
885
886 let statement = rewrite(statement, $rules);
887 assert_eq!(statement.to_string(), $rewt);
888 };
889 }
890
891 #[test]
892 fn test_alias_rewrite() {
893 let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
894 vec![Arc::new(AliasDuplicatedProjectionRewrite)];
895
896 assert_rewrite!(
897 &rules,
898 "SELECT n.oid, n.* FROM pg_catalog.pg_namespace n",
899 "SELECT n.oid AS __alias_oid, n.* FROM pg_catalog.pg_namespace AS n"
900 );
901
902 assert_rewrite!(
903 &rules,
904 "SELECT oid, * FROM pg_catalog.pg_namespace",
905 "SELECT oid AS __alias_oid, * FROM pg_catalog.pg_namespace"
906 );
907
908 assert_rewrite!(
909 &rules,
910 "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id",
911 "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id"
912 );
913
914 let sql = "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname";
915 let statement = parse(sql).expect("Failed to parse").remove(0);
916
917 let statement = rewrite(statement, &rules);
918 assert_eq!(
919 statement.to_string(),
920 "SELECT n.oid AS __alias_oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY nspsname"
921 );
922 }
923
924 #[test]
925 fn test_qualifier_prepend() {
926 let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
927 vec![Arc::new(ResolveUnqualifiedIdentifer)];
928
929 assert_rewrite!(
930 &rules,
931 "SELECT n.* FROM pg_catalog.pg_namespace n WHERE nspname = 'pg_catalog' ORDER BY nspname",
932 "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
933 );
934
935 assert_rewrite!(
936 &rules,
937 "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname",
938 "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname"
939 );
940
941 assert_rewrite!(
942 &rules,
943 "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname",
944 "SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY n.nspsname"
945 );
946
947 assert_rewrite!(&rules,
948 "SELECT i.*,i.indkey as keys,c.relname,c.relnamespace,c.relam,c.reltablespace,tc.relname as tabrelname,dsc.description FROM pg_catalog.pg_index i INNER JOIN pg_catalog.pg_class c ON c.oid=i.indexrelid INNER JOIN pg_catalog.pg_class tc ON tc.oid=i.indrelid LEFT OUTER JOIN pg_catalog.pg_description dsc ON i.indexrelid=dsc.objoid WHERE i.indrelid=1 ORDER BY tabrelname, c.relname",
949 "SELECT i.*, i.indkey AS keys, c.relname, c.relnamespace, c.relam, c.reltablespace, tc.relname AS tabrelname, dsc.description FROM pg_catalog.pg_index AS i INNER JOIN pg_catalog.pg_class AS c ON c.oid = i.indexrelid INNER JOIN pg_catalog.pg_class AS tc ON tc.oid = i.indrelid LEFT OUTER JOIN pg_catalog.pg_description AS dsc ON i.indexrelid = dsc.objoid WHERE i.indrelid = 1 ORDER BY tabrelname, c.relname"
950 );
951 }
952
953 #[test]
954 fn test_remove_unsupported_types() {
955 let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![
956 Arc::new(RemoveQualifier),
957 Arc::new(RemoveUnsupportedTypes::new()),
958 ];
959
960 assert_rewrite!(
961 &rules,
962 "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname",
963 "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
964 );
965
966 assert_rewrite!(
967 &rules,
968 "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.oid = 1 AND n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname",
969 "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.oid = 1 AND n.nspname = 'pg_catalog' ORDER BY n.nspname"
970 );
971
972 assert_rewrite!(
973 &rules,
974 "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace'::regclass ORDER BY nspname",
975 "SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY nspname"
976 );
977
978 assert_rewrite!(
979 &rules,
980 "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname",
981 "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
982 );
983
984 assert_rewrite!(
985 &rules,
986 "SELECT c.relchecks, c.relkind, c.relhasindex, c.relhasrules, c.relhastriggers, c.relrowsecurity, c.relforcerowsecurity, false AS relhasoids, c.relispartition, '', c.reltablespace, CASE WHEN c.reloftype = 0 THEN '' ELSE c.reloftype::pg_catalog.regtype::pg_catalog.text END, c.relpersistence, c.relreplident, am.amname
987 FROM pg_catalog.pg_class c
988 LEFT JOIN pg_catalog.pg_class tc ON (c.reltoastrelid = tc.oid)
989 LEFT JOIN pg_catalog.pg_am am ON (c.relam = am.oid)
990 WHERE c.oid = '16386'",
991 "SELECT c.relchecks, c.relkind, c.relhasindex, c.relhasrules, c.relhastriggers, c.relrowsecurity, c.relforcerowsecurity, false AS relhasoids, c.relispartition, '', c.reltablespace, CASE WHEN c.reloftype = 0 THEN '' ELSE c.reloftype::TEXT END, c.relpersistence, c.relreplident, am.amname FROM pg_catalog.pg_class AS c LEFT JOIN pg_catalog.pg_class AS tc ON (c.reltoastrelid = tc.oid) LEFT JOIN pg_catalog.pg_am AS am ON (c.relam = am.oid) WHERE c.oid = '16386'"
992 );
993 }
994
995 #[test]
996 fn test_any_to_array_contains() {
997 let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
998 vec![Arc::new(RewriteArrayAnyAllOperation)];
999
1000 assert_rewrite!(
1001 &rules,
1002 "SELECT a = ANY(current_schemas(true))",
1003 "SELECT array_contains(current_schemas(true), a)"
1004 );
1005
1006 assert_rewrite!(
1007 &rules,
1008 "SELECT a <> ALL(current_schemas(true))",
1009 "SELECT NOT array_contains(current_schemas(true), a)"
1010 );
1011
1012 assert_rewrite!(
1013 &rules,
1014 "SELECT a = ANY('{r, l, e}')",
1015 "SELECT array_contains(ARRAY['r', 'l', 'e'], a)"
1016 );
1017
1018 assert_rewrite!(
1019 &rules,
1020 "SELECT a FROM tbl WHERE a = ANY(current_schemas(true))",
1021 "SELECT a FROM tbl WHERE array_contains(current_schemas(true), a)"
1022 );
1023 }
1024
1025 #[test]
1026 fn test_prepend_unqualified_table_name() {
1027 let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
1028 vec![Arc::new(PrependUnqualifiedPgTableName)];
1029
1030 assert_rewrite!(
1031 &rules,
1032 "SELECT * FROM pg_catalog.pg_namespace",
1033 "SELECT * FROM pg_catalog.pg_namespace"
1034 );
1035
1036 assert_rewrite!(
1037 &rules,
1038 "SELECT * FROM pg_namespace",
1039 "SELECT * FROM pg_catalog.pg_namespace"
1040 );
1041
1042 assert_rewrite!(
1043 &rules,
1044 "SELECT typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN pg_namespace as ns ON ns.oid = oid",
1045 "SELECT typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN pg_catalog.pg_namespace AS ns ON ns.oid = oid"
1046 );
1047 }
1048
1049 #[test]
1050 fn test_array_literal_fix() {
1051 let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(FixArrayLiteral)];
1052
1053 assert_rewrite!(
1054 &rules,
1055 "SELECT '{a, abc}'::text[]",
1056 "SELECT ARRAY['a', 'abc']::TEXT[]"
1057 );
1058
1059 assert_rewrite!(
1060 &rules,
1061 "SELECT '{1, 2}'::int[]",
1062 "SELECT ARRAY[1, 2]::INT[]"
1063 );
1064
1065 assert_rewrite!(
1066 &rules,
1067 "SELECT '{t, f}'::bool[]",
1068 "SELECT ARRAY[t, f]::BOOL[]"
1069 );
1070 }
1071
1072 #[test]
1073 fn test_remove_qualifier_from_table_function() {
1074 let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(RemoveQualifier)];
1075
1076 assert_rewrite!(
1077 &rules,
1078 "SELECT * FROM pg_catalog.pg_get_keywords()",
1079 "SELECT * FROM pg_get_keywords()"
1080 );
1081 }
1082
1083 #[test]
1084 fn test_current_user() {
1085 let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
1086 vec![Arc::new(CurrentUserVariableToSessionUserFunctionCall)];
1087
1088 assert_rewrite!(&rules, "SELECT current_user", "SELECT session_user");
1089
1090 assert_rewrite!(&rules, "SELECT CURRENT_USER", "SELECT session_user");
1091
1092 assert_rewrite!(
1093 &rules,
1094 "SELECT is_null(current_user)",
1095 "SELECT is_null(session_user)"
1096 );
1097 }
1098
1099 #[test]
1100 fn test_collate_fix() {
1101 let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(FixCollate)];
1102
1103 assert_rewrite!(&rules, "SELECT c.oid, c.relname FROM pg_catalog.pg_class c WHERE c.relname OPERATOR(pg_catalog.~) '^(tablename)$' COLLATE pg_catalog.default AND pg_catalog.pg_table_is_visible(c.oid) ORDER BY 2, 3;", "SELECT c.oid, c.relname FROM pg_catalog.pg_class AS c WHERE c.relname ~ '^(tablename)$' AND pg_catalog.pg_table_is_visible(c.oid) ORDER BY 2, 3");
1104 }
1105
1106 #[test]
1107 fn test_remove_subquery() {
1108 let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
1109 vec![Arc::new(RemoveSubqueryFromProjection)];
1110
1111 assert_rewrite!(&rules,
1112 "SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), (SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid, true) FROM pg_catalog.pg_attrdef d WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef), a.attnotnull, (SELECT c.collname FROM pg_catalog.pg_collation c, pg_catalog.pg_type t WHERE c.oid = a.attcollation AND t.oid = a.atttypid AND a.attcollation <> t.typcollation LIMIT 1) AS attcollation, a.attidentity, a.attgenerated FROM pg_catalog.pg_attribute a WHERE a.attrelid = '16384' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum;",
1113 "SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), NULL, a.attnotnull, NULL AS attcollation, a.attidentity, a.attgenerated FROM pg_catalog.pg_attribute AS a WHERE a.attrelid = '16384' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum");
1114 }
1115
1116 #[test]
1117 fn test_version_rewrite() {
1118 let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(FixVersionColumnName)];
1119
1120 assert_rewrite!(&rules, "SELECT version()", "SELECT version() AS version");
1121
1122 assert_rewrite!(&rules, "SELECT version() as foo", "SELECT version() AS foo");
1124 assert_rewrite!(&rules, "SELECT version(foo)", "SELECT version(foo)");
1125 assert_rewrite!(&rules, "SELECT foo.version()", "SELECT foo.version()");
1126 }
1127}