1use std::collections::HashSet;
2use std::fmt::Debug;
3use std::ops::ControlFlow;
4
5use datafusion::sql::sqlparser::ast::Array;
6use datafusion::sql::sqlparser::ast::ArrayElemTypeDef;
7use datafusion::sql::sqlparser::ast::BinaryOperator;
8use datafusion::sql::sqlparser::ast::CastKind;
9use datafusion::sql::sqlparser::ast::DataType;
10use datafusion::sql::sqlparser::ast::Expr;
11use datafusion::sql::sqlparser::ast::Function;
12use datafusion::sql::sqlparser::ast::FunctionArg;
13use datafusion::sql::sqlparser::ast::FunctionArgExpr;
14use datafusion::sql::sqlparser::ast::FunctionArgumentList;
15use datafusion::sql::sqlparser::ast::FunctionArguments;
16use datafusion::sql::sqlparser::ast::Ident;
17use datafusion::sql::sqlparser::ast::ObjectName;
18use datafusion::sql::sqlparser::ast::ObjectNamePart;
19use datafusion::sql::sqlparser::ast::OrderByKind;
20use datafusion::sql::sqlparser::ast::Query;
21use datafusion::sql::sqlparser::ast::Select;
22use datafusion::sql::sqlparser::ast::SelectItem;
23use datafusion::sql::sqlparser::ast::SelectItemQualifiedWildcardKind;
24use datafusion::sql::sqlparser::ast::SetExpr;
25use datafusion::sql::sqlparser::ast::Statement;
26use datafusion::sql::sqlparser::ast::TableFactor;
27use datafusion::sql::sqlparser::ast::TableWithJoins;
28use datafusion::sql::sqlparser::ast::UnaryOperator;
29use datafusion::sql::sqlparser::ast::Value;
30use datafusion::sql::sqlparser::ast::ValueWithSpan;
31use datafusion::sql::sqlparser::ast::VisitMut;
32use datafusion::sql::sqlparser::ast::VisitorMut;
33
34pub trait SqlStatementRewriteRule: Send + Sync + Debug {
35 fn rewrite(&self, s: Statement) -> Statement;
36}
37
38#[derive(Debug)]
47pub struct AliasDuplicatedProjectionRewrite;
48
49impl AliasDuplicatedProjectionRewrite {
50 fn rewrite_select_with_alias(select: &mut Box<Select>) {
52 let mut wildcard_tables = Vec::new();
54 let mut has_simple_wildcard = false;
55 for p in &select.projection {
56 match p {
57 SelectItem::QualifiedWildcard(name, _) => match name {
58 SelectItemQualifiedWildcardKind::ObjectName(objname) => {
59 let idents = objname
61 .0
62 .iter()
63 .map(|v| v.as_ident().unwrap().value.clone())
64 .collect::<Vec<_>>()
65 .join(".");
66
67 wildcard_tables.push(idents);
68 }
69 SelectItemQualifiedWildcardKind::Expr(_expr) => {
70 }
72 },
73 SelectItem::Wildcard(_) => {
74 has_simple_wildcard = true;
75 }
76 _ => {}
77 }
78 }
79
80 if wildcard_tables.is_empty() && !has_simple_wildcard {
82 return;
83 }
84
85 let mut new_projection = vec![];
87 for p in select.projection.drain(..) {
88 match p {
89 SelectItem::UnnamedExpr(expr) => {
90 let alias_partial = match &expr {
91 Expr::Identifier(ident) => Some(ident.clone()),
93 Expr::CompoundIdentifier(idents) => {
95 if idents.len() > 1 {
97 let table_name = &idents[..idents.len() - 1]
98 .iter()
99 .map(|i| i.value.clone())
100 .collect::<Vec<_>>()
101 .join(".");
102 if wildcard_tables.iter().any(|name| name == table_name) {
103 Some(idents[idents.len() - 1].clone())
104 } else {
105 None
106 }
107 } else {
108 None
109 }
110 }
111 _ => None,
112 };
113
114 if let Some(name) = alias_partial {
115 let alias = format!("__alias_{name}");
116 new_projection.push(SelectItem::ExprWithAlias {
117 expr,
118 alias: Ident::new(alias),
119 });
120 } else {
121 new_projection.push(SelectItem::UnnamedExpr(expr));
122 }
123 }
124 _ => new_projection.push(p),
126 }
127 }
128 select.projection = new_projection;
129 }
130}
131
132impl SqlStatementRewriteRule for AliasDuplicatedProjectionRewrite {
133 fn rewrite(&self, mut statement: Statement) -> Statement {
134 if let Statement::Query(query) = &mut statement {
135 if let SetExpr::Select(select) = query.body.as_mut() {
136 Self::rewrite_select_with_alias(select);
137 }
138 }
139
140 statement
141 }
142}
143
144#[derive(Debug)]
149pub struct ResolveUnqualifiedIdentifer;
150
151impl ResolveUnqualifiedIdentifer {
152 fn rewrite_unqualified_identifiers(query: &mut Box<Query>) {
153 if let SetExpr::Select(select) = query.body.as_mut() {
154 let table_aliases = Self::get_table_aliases(&select.from);
156
157 let qualified_wildcard_alias = Self::get_qualified_wildcard_alias(&select.projection);
159 if qualified_wildcard_alias.is_none() || table_aliases.is_empty() {
160 return; }
162
163 let wildcard_alias = qualified_wildcard_alias.unwrap();
164
165 if let Some(selection) = &mut select.selection {
167 Self::rewrite_expr(selection, &wildcard_alias, &table_aliases);
168 }
169
170 if let Some(OrderByKind::Expressions(order_by_exprs)) =
171 query.order_by.as_mut().map(|o| &mut o.kind)
172 {
173 for order_by_expr in order_by_exprs {
174 Self::rewrite_expr(&mut order_by_expr.expr, &wildcard_alias, &table_aliases);
175 }
176 }
177 }
178 }
179
180 fn get_table_aliases(tables: &[TableWithJoins]) -> HashSet<String> {
181 let mut aliases = HashSet::new();
182 for table_with_joins in tables {
183 if let TableFactor::Table {
184 alias: Some(alias), ..
185 } = &table_with_joins.relation
186 {
187 aliases.insert(alias.name.value.clone());
188 }
189 for join in &table_with_joins.joins {
190 if let TableFactor::Table {
191 alias: Some(alias), ..
192 } = &join.relation
193 {
194 aliases.insert(alias.name.value.clone());
195 }
196 }
197 }
198 aliases
199 }
200
201 fn get_qualified_wildcard_alias(projection: &[SelectItem]) -> Option<String> {
202 let mut qualified_wildcards = projection
203 .iter()
204 .filter_map(|item| {
205 if let SelectItem::QualifiedWildcard(
206 SelectItemQualifiedWildcardKind::ObjectName(objname),
207 _,
208 ) = item
209 {
210 Some(
211 objname
212 .0
213 .iter()
214 .map(|v| v.as_ident().unwrap().value.clone())
215 .collect::<Vec<_>>()
216 .join("."),
217 )
218 } else {
219 None
220 }
221 })
222 .collect::<Vec<_>>();
223
224 if qualified_wildcards.len() == 1 {
225 Some(qualified_wildcards.remove(0))
226 } else {
227 None
228 }
229 }
230
231 fn rewrite_expr(expr: &mut Expr, wildcard_alias: &str, table_aliases: &HashSet<String>) {
232 match expr {
233 Expr::Identifier(ident) => {
234 if !table_aliases.contains(&ident.value) {
236 *expr = Expr::CompoundIdentifier(vec![
237 Ident::new(wildcard_alias.to_string()),
238 ident.clone(),
239 ]);
240 }
241 }
242 Expr::BinaryOp { left, right, .. } => {
243 Self::rewrite_expr(left, wildcard_alias, table_aliases);
244 Self::rewrite_expr(right, wildcard_alias, table_aliases);
245 }
246 _ => {}
248 }
249 }
250}
251
252impl SqlStatementRewriteRule for ResolveUnqualifiedIdentifer {
253 fn rewrite(&self, mut statement: Statement) -> Statement {
254 if let Statement::Query(query) = &mut statement {
255 Self::rewrite_unqualified_identifiers(query);
256 }
257
258 statement
259 }
260}
261
/// Removes casts and typed literals whose target type is one of a fixed set
/// of PostgreSQL types, keeping only the underlying value.
///
/// The set holds `regclass`, `regproc`, `regtype`, `regtype[]`,
/// `regnamespace` and `oid`, each both bare and with a `pg_catalog.` prefix.
#[derive(Debug)]
pub struct RemoveUnsupportedTypes {
    unsupported_types: HashSet<String>,
}

impl Default for RemoveUnsupportedTypes {
    fn default() -> Self {
        Self::new()
    }
}

impl RemoveUnsupportedTypes {
    /// Creates the rule with its built-in list of type names (lowercase).
    pub fn new() -> Self {
        const TYPE_NAMES: [&str; 6] = [
            "regclass",
            "regproc",
            "regtype",
            "regtype[]",
            "regnamespace",
            "oid",
        ];

        // Register every name twice: unqualified and schema-qualified.
        let unsupported_types = TYPE_NAMES
            .iter()
            .flat_map(|name| [name.to_string(), format!("pg_catalog.{name}")])
            .collect();

        Self { unsupported_types }
    }
}
294
295struct RemoveUnsupportedTypesVisitor<'a> {
296 unsupported_types: &'a HashSet<String>,
297}
298
299impl VisitorMut for RemoveUnsupportedTypesVisitor<'_> {
300 type Break = ();
301
302 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
303 match expr {
304 Expr::TypedString { value, data_type } => {
306 if self
307 .unsupported_types
308 .contains(data_type.to_string().to_lowercase().as_str())
309 {
310 *expr =
311 Expr::Value(Value::SingleQuotedString(value.to_string()).with_empty_span());
312 }
313 }
314 Expr::Cast {
315 data_type,
316 expr: value,
317 ..
318 } => {
319 if self
320 .unsupported_types
321 .contains(data_type.to_string().to_lowercase().as_str())
322 {
323 *expr = *value.clone();
324 }
325 }
326 _ => {}
328 }
329
330 ControlFlow::Continue(())
331 }
332}
333
334impl SqlStatementRewriteRule for RemoveUnsupportedTypes {
335 fn rewrite(&self, mut statement: Statement) -> Statement {
336 let mut visitor = RemoveUnsupportedTypesVisitor {
337 unsupported_types: &self.unsupported_types,
338 };
339 let _ = statement.visit(&mut visitor);
340 statement
341 }
342}
343
344#[derive(Debug)]
346pub struct RewriteArrayAnyAllOperation;
347
348struct RewriteArrayAnyAllOperationVisitor;
349
350impl RewriteArrayAnyAllOperationVisitor {
351 fn any_to_array_cofntains(&self, left: &Expr, right: &Expr) -> Expr {
352 let array = if let Expr::Value(ValueWithSpan {
353 value: Value::SingleQuotedString(array_literal),
354 ..
355 }) = right
356 {
357 let array_literal = array_literal.trim();
358 if array_literal.starts_with('{') && array_literal.ends_with('}') {
359 let items = array_literal.trim_matches(|c| c == '{' || c == '}' || c == ' ');
360 let items = items.split(',').map(|s| s.trim()).filter(|s| !s.is_empty());
361
362 let elems = items
364 .map(|s| {
365 Expr::Value(Value::SingleQuotedString(s.to_string()).with_empty_span())
366 })
367 .collect();
368 Expr::Array(Array {
369 elem: elems,
370 named: true,
371 })
372 } else {
373 right.clone()
374 }
375 } else {
376 right.clone()
377 };
378
379 Expr::Function(Function {
380 name: ObjectName::from(vec![Ident::new("array_contains")]),
381 args: FunctionArguments::List(FunctionArgumentList {
382 args: vec![
383 FunctionArg::Unnamed(FunctionArgExpr::Expr(array)),
384 FunctionArg::Unnamed(FunctionArgExpr::Expr(left.clone())),
385 ],
386 duplicate_treatment: None,
387 clauses: vec![],
388 }),
389 uses_odbc_syntax: false,
390 parameters: FunctionArguments::None,
391 filter: None,
392 null_treatment: None,
393 over: None,
394 within_group: vec![],
395 })
396 }
397}
398
399impl VisitorMut for RewriteArrayAnyAllOperationVisitor {
400 type Break = ();
401
402 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
403 match expr {
404 Expr::AnyOp {
405 left,
406 compare_op,
407 right,
408 ..
409 } => match compare_op {
410 BinaryOperator::Eq => {
411 *expr = self.any_to_array_cofntains(left.as_ref(), right.as_ref());
412 }
413 BinaryOperator::NotEq => {
414 }
416 _ => {}
417 },
418 Expr::AllOp {
419 left,
420 compare_op,
421 right,
422 } => match compare_op {
423 BinaryOperator::Eq => {
424 }
426 BinaryOperator::NotEq => {
427 *expr = Expr::UnaryOp {
428 op: UnaryOperator::Not,
429 expr: Box::new(self.any_to_array_cofntains(left.as_ref(), right.as_ref())),
430 }
431 }
432 _ => {}
433 },
434 _ => {}
435 }
436
437 ControlFlow::Continue(())
438 }
439}
440
441impl SqlStatementRewriteRule for RewriteArrayAnyAllOperation {
442 fn rewrite(&self, mut s: Statement) -> Statement {
443 let mut visitor = RewriteArrayAnyAllOperationVisitor;
444
445 let _ = s.visit(&mut visitor);
446
447 s
448 }
449}
450
451#[derive(Debug)]
456pub struct PrependUnqualifiedPgTableName;
457
458struct PrependUnqualifiedPgTableNameVisitor;
459
460impl VisitorMut for PrependUnqualifiedPgTableNameVisitor {
461 type Break = ();
462
463 fn pre_visit_table_factor(
464 &mut self,
465 table_factor: &mut TableFactor,
466 ) -> ControlFlow<Self::Break> {
467 if let TableFactor::Table { name, args, .. } = table_factor {
468 if args.is_none() && name.0.len() == 1 {
470 if let ObjectNamePart::Identifier(ident) = &name.0[0] {
471 if ident.value.starts_with("pg_") {
472 *name = ObjectName(vec![
473 ObjectNamePart::Identifier(Ident::new("pg_catalog")),
474 name.0[0].clone(),
475 ]);
476 }
477 }
478 }
479 }
480
481 ControlFlow::Continue(())
482 }
483}
484
485impl SqlStatementRewriteRule for PrependUnqualifiedPgTableName {
486 fn rewrite(&self, mut s: Statement) -> Statement {
487 let mut visitor = PrependUnqualifiedPgTableNameVisitor;
488
489 let _ = s.visit(&mut visitor);
490 s
491 }
492}
493
494#[derive(Debug)]
495pub struct FixArrayLiteral;
496
497struct FixArrayLiteralVisitor;
498
499impl FixArrayLiteralVisitor {
500 fn is_string_type(dt: &DataType) -> bool {
501 matches!(
502 dt,
503 DataType::Text | DataType::Varchar(_) | DataType::Char(_) | DataType::String(_)
504 )
505 }
506}
507
508impl VisitorMut for FixArrayLiteralVisitor {
509 type Break = ();
510
511 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
512 if let Expr::Cast {
513 kind,
514 expr,
515 data_type,
516 ..
517 } = expr
518 {
519 if kind == &CastKind::DoubleColon {
520 if let DataType::Array(arr) = data_type {
521 if let Expr::Value(ValueWithSpan {
523 value: Value::SingleQuotedString(array_literal),
524 ..
525 }) = expr.as_ref()
526 {
527 let items =
528 array_literal.trim_matches(|c| c == '{' || c == '}' || c == ' ');
529 let items = items.split(',').map(|s| s.trim()).filter(|s| !s.is_empty());
530
531 let is_text = match arr {
532 ArrayElemTypeDef::AngleBracket(dt) => Self::is_string_type(dt.as_ref()),
533 ArrayElemTypeDef::SquareBracket(dt, _) => {
534 Self::is_string_type(dt.as_ref())
535 }
536 ArrayElemTypeDef::Parenthesis(dt) => Self::is_string_type(dt.as_ref()),
537 _ => false,
538 };
539
540 let elems = items
541 .map(|s| {
542 if is_text {
543 Expr::Value(
544 Value::SingleQuotedString(s.to_string()).with_empty_span(),
545 )
546 } else {
547 Expr::Value(
548 Value::Number(s.to_string(), false).with_empty_span(),
549 )
550 }
551 })
552 .collect();
553 *expr = Box::new(Expr::Array(Array {
554 elem: elems,
555 named: true,
556 }));
557 }
558 }
559 }
560 }
561
562 ControlFlow::Continue(())
563 }
564}
565
566impl SqlStatementRewriteRule for FixArrayLiteral {
567 fn rewrite(&self, mut s: Statement) -> Statement {
568 let mut visitor = FixArrayLiteralVisitor;
569
570 let _ = s.visit(&mut visitor);
571 s
572 }
573}
574
575#[derive(Debug)]
582pub struct RemoveQualifier;
583
584struct RemoveQualifierVisitor;
585
586impl VisitorMut for RemoveQualifierVisitor {
587 type Break = ();
588
589 fn pre_visit_table_factor(
590 &mut self,
591 table_factor: &mut TableFactor,
592 ) -> ControlFlow<Self::Break> {
593 if let TableFactor::Table { name, args, .. } = table_factor {
595 if args.is_some() {
596 if name.0.len() > 1 {
598 if let Some(last_ident) = name.0.pop() {
599 *name = ObjectName(vec![last_ident]);
600 }
601 }
602 }
603 }
604 ControlFlow::Continue(())
605 }
606
607 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
608 match expr {
609 Expr::Cast { data_type, .. } => {
610 let data_type_str = data_type.to_string();
612 match data_type_str.as_str() {
613 "pg_catalog.text" => {
614 *data_type = DataType::Text;
615 }
616 "pg_catalog.int2[]" => {
617 *data_type = DataType::Array(ArrayElemTypeDef::SquareBracket(
618 Box::new(DataType::Int16),
619 None,
620 ));
621 }
622 _ => {}
623 }
624 }
625 Expr::Function(function) => {
626 let name = &mut function.name;
628 if name.0.len() > 1 {
629 if let Some(last_ident) = name.0.pop() {
630 *name = ObjectName(vec![last_ident]);
631 }
632 }
633 }
634
635 _ => {}
636 }
637 ControlFlow::Continue(())
638 }
639}
640
641impl SqlStatementRewriteRule for RemoveQualifier {
642 fn rewrite(&self, mut s: Statement) -> Statement {
643 let mut visitor = RemoveQualifierVisitor;
644
645 let _ = s.visit(&mut visitor);
646 s
647 }
648}
649
650#[derive(Debug)]
652pub struct CurrentUserVariableToSessionUserFunctionCall;
653
654struct CurrentUserVariableToSessionUserFunctionCallVisitor;
655
656impl VisitorMut for CurrentUserVariableToSessionUserFunctionCallVisitor {
657 type Break = ();
658
659 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
660 if let Expr::Identifier(ident) = expr {
661 if ident.quote_style.is_none() && ident.value.to_lowercase() == "current_user" {
662 *expr = Expr::Function(Function {
663 name: ObjectName::from(vec![Ident::new("session_user")]),
664 args: FunctionArguments::None,
665 uses_odbc_syntax: false,
666 parameters: FunctionArguments::None,
667 filter: None,
668 null_treatment: None,
669 over: None,
670 within_group: vec![],
671 });
672 }
673 }
674
675 if let Expr::Function(func) = expr {
676 let fname = func
677 .name
678 .0
679 .iter()
680 .map(|ident| ident.to_string())
681 .collect::<Vec<String>>()
682 .join(".");
683 if fname.to_lowercase() == "current_user" {
684 func.name = ObjectName::from(vec![Ident::new("session_user")])
685 }
686 }
687
688 ControlFlow::Continue(())
689 }
690}
691
692impl SqlStatementRewriteRule for CurrentUserVariableToSessionUserFunctionCall {
693 fn rewrite(&self, mut s: Statement) -> Statement {
694 let mut visitor = CurrentUserVariableToSessionUserFunctionCallVisitor;
695
696 let _ = s.visit(&mut visitor);
697 s
698 }
699}
700
701#[derive(Debug)]
703pub struct FixCollate;
704
705struct FixCollateVisitor;
706
707impl VisitorMut for FixCollateVisitor {
708 type Break = ();
709
710 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
711 match expr {
712 Expr::Collate { expr: inner, .. } => {
713 *expr = inner.as_ref().clone();
714 }
715 Expr::BinaryOp { op, .. } => {
716 if let BinaryOperator::PGCustomBinaryOperator(ops) = op {
717 if *ops == ["pg_catalog", "~"] {
718 *op = BinaryOperator::PGRegexMatch;
719 }
720 }
721 }
722 _ => {}
723 }
724
725 ControlFlow::Continue(())
726 }
727}
728
729impl SqlStatementRewriteRule for FixCollate {
730 fn rewrite(&self, mut s: Statement) -> Statement {
731 let mut visitor = FixCollateVisitor;
732
733 let _ = s.visit(&mut visitor);
734 s
735 }
736}
737
738#[derive(Debug)]
740pub struct RemoveSubqueryFromProjection;
741
742struct RemoveSubqueryFromProjectionVisitor;
743
744impl VisitorMut for RemoveSubqueryFromProjectionVisitor {
745 type Break = ();
746
747 fn pre_visit_query(&mut self, query: &mut Query) -> ControlFlow<Self::Break> {
748 if let SetExpr::Select(select) = query.body.as_mut() {
749 for projection in &mut select.projection {
750 match projection {
751 SelectItem::UnnamedExpr(expr) => {
752 if let Expr::Subquery(_) = expr {
753 *expr = Expr::Value(Value::Null.with_empty_span());
754 }
755 }
756 SelectItem::ExprWithAlias { expr, .. } => {
757 if let Expr::Subquery(_) = expr {
758 *expr = Expr::Value(Value::Null.with_empty_span());
759 }
760 }
761 _ => {}
762 }
763 }
764 }
765
766 ControlFlow::Continue(())
767 }
768}
769
770impl SqlStatementRewriteRule for RemoveSubqueryFromProjection {
771 fn rewrite(&self, mut s: Statement) -> Statement {
772 let mut visitor = RemoveSubqueryFromProjectionVisitor;
773 let _ = s.visit(&mut visitor);
774
775 s
776 }
777}
778
779#[derive(Debug)]
781pub struct FixVersionColumnName;
782
783struct FixVersionColumnNameVisitor;
784
785impl VisitorMut for FixVersionColumnNameVisitor {
786 type Break = ();
787
788 fn pre_visit_query(&mut self, query: &mut Query) -> ControlFlow<Self::Break> {
789 if let SetExpr::Select(select) = query.body.as_mut() {
790 for projection in &mut select.projection {
791 if let SelectItem::UnnamedExpr(Expr::Function(f)) = projection {
792 if f.name.0.len() == 1 {
793 if let ObjectNamePart::Identifier(part) = &f.name.0[0] {
794 if part.value == "version" {
795 if let FunctionArguments::List(args) = &f.args {
796 if args.args.is_empty() {
797 *projection = SelectItem::ExprWithAlias {
798 expr: Expr::Function(f.clone()),
799 alias: Ident::new("version"),
800 }
801 }
802 }
803 }
804 }
805 }
806 }
807 }
808 }
809
810 ControlFlow::Continue(())
811 }
812}
813
814impl SqlStatementRewriteRule for FixVersionColumnName {
815 fn rewrite(&self, mut s: Statement) -> Statement {
816 let mut visitor = FixVersionColumnNameVisitor;
817 let _ = s.visit(&mut visitor);
818
819 s
820 }
821}
822
// Unit tests for the rewrite rules. Each test parses PostgreSQL SQL, applies
// one or more rules in order, and compares the re-rendered SQL text.
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
    use datafusion::sql::sqlparser::parser::Parser;
    use datafusion::sql::sqlparser::parser::ParserError;
    use std::sync::Arc;

    /// Parses `sql` with the PostgreSQL dialect.
    fn parse(sql: &str) -> Result<Vec<Statement>, ParserError> {
        let dialect = PostgreSqlDialect {};

        Parser::parse_sql(&dialect, sql)
    }

    /// Applies `rules` to `s` in order, feeding each rule's output to the next.
    fn rewrite(mut s: Statement, rules: &[Arc<dyn SqlStatementRewriteRule>]) -> Statement {
        for rule in rules {
            s = rule.rewrite(s);
        }

        s
    }

    /// Parses `$orig`, rewrites its first statement with `$rules`, and asserts
    /// that the rewritten statement renders exactly as `$rewt`.
    macro_rules! assert_rewrite {
        ($rules:expr, $orig:expr, $rewt:expr) => {
            let sql = $orig;
            let statement = parse(sql).expect("Failed to parse").remove(0);

            let statement = rewrite(statement, $rules);
            assert_eq!(statement.to_string(), $rewt);
        };
    }

    /// Columns that clash with a wildcard get an `__alias_` alias; columns
    /// from a table that is not wildcarded are left alone.
    #[test]
    fn test_alias_rewrite() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(AliasDuplicatedProjectionRewrite)];

        assert_rewrite!(
            &rules,
            "SELECT n.oid, n.* FROM pg_catalog.pg_namespace n",
            "SELECT n.oid AS __alias_oid, n.* FROM pg_catalog.pg_namespace AS n"
        );

        assert_rewrite!(
            &rules,
            "SELECT oid, * FROM pg_catalog.pg_namespace",
            "SELECT oid AS __alias_oid, * FROM pg_catalog.pg_namespace"
        );

        // `t1` has no wildcard, so `t1.oid` keeps its name.
        assert_rewrite!(
            &rules,
            "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id",
            "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id"
        );

        let sql = "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname";
        let statement = parse(sql).expect("Failed to parse").remove(0);

        let statement = rewrite(statement, &rules);
        assert_eq!(
            statement.to_string(),
            "SELECT n.oid AS __alias_oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY nspsname"
        );
    }

    /// Bare identifiers in WHERE / ORDER BY are qualified with the single
    /// qualified wildcard's table alias; nothing changes without one.
    #[test]
    fn test_qualifier_prepend() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(ResolveUnqualifiedIdentifer)];

        assert_rewrite!(
            &rules,
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE nspname = 'pg_catalog' ORDER BY nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
        );

        assert_rewrite!(
            &rules,
            "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname",
            "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname"
        );

        assert_rewrite!(
            &rules,
            "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname",
            "SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY n.nspsname"
        );
    }

    /// `regclass`/`regtype` casts are stripped; `RemoveQualifier` runs first
    /// so `::pg_catalog.text` is simplified to `::TEXT`.
    #[test]
    fn test_remove_unsupported_types() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![
            Arc::new(RemoveQualifier),
            Arc::new(RemoveUnsupportedTypes::new()),
        ];

        assert_rewrite!(
            &rules,
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
        );

        assert_rewrite!(
            &rules,
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.oid = 1 AND n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.oid = 1 AND n.nspname = 'pg_catalog' ORDER BY n.nspname"
        );

        assert_rewrite!(
            &rules,
            "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace'::regclass ORDER BY nspname",
            "SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY nspname"
        );

        // No unsupported types: the statement is unchanged.
        assert_rewrite!(
            &rules,
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
        );

        // Nested casts: the inner `::pg_catalog.regtype` is removed and the
        // outer `::pg_catalog.text` becomes `::TEXT`.
        assert_rewrite!(
            &rules,
            "SELECT c.relchecks, c.relkind, c.relhasindex, c.relhasrules, c.relhastriggers, c.relrowsecurity, c.relforcerowsecurity, false AS relhasoids, c.relispartition, '', c.reltablespace, CASE WHEN c.reloftype = 0 THEN '' ELSE c.reloftype::pg_catalog.regtype::pg_catalog.text END, c.relpersistence, c.relreplident, am.amname
            FROM pg_catalog.pg_class c
            LEFT JOIN pg_catalog.pg_class tc ON (c.reltoastrelid = tc.oid)
            LEFT JOIN pg_catalog.pg_am am ON (c.relam = am.oid)
            WHERE c.oid = '16386'",
            "SELECT c.relchecks, c.relkind, c.relhasindex, c.relhasrules, c.relhastriggers, c.relrowsecurity, c.relforcerowsecurity, false AS relhasoids, c.relispartition, '', c.reltablespace, CASE WHEN c.reloftype = 0 THEN '' ELSE c.reloftype::TEXT END, c.relpersistence, c.relreplident, am.amname FROM pg_catalog.pg_class AS c LEFT JOIN pg_catalog.pg_class AS tc ON (c.reltoastrelid = tc.oid) LEFT JOIN pg_catalog.pg_am AS am ON (c.relam = am.oid) WHERE c.oid = '16386'"
        );
    }

    /// `= ANY(...)` / `<> ALL(...)` become (negated) `array_contains` calls;
    /// string array literals are expanded into `ARRAY[...]`.
    #[test]
    fn test_any_to_array_contains() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(RewriteArrayAnyAllOperation)];

        assert_rewrite!(
            &rules,
            "SELECT a = ANY(current_schemas(true))",
            "SELECT array_contains(current_schemas(true), a)"
        );

        assert_rewrite!(
            &rules,
            "SELECT a <> ALL(current_schemas(true))",
            "SELECT NOT array_contains(current_schemas(true), a)"
        );

        assert_rewrite!(
            &rules,
            "SELECT a = ANY('{r, l, e}')",
            "SELECT array_contains(ARRAY['r', 'l', 'e'], a)"
        );

        assert_rewrite!(
            &rules,
            "SELECT a FROM tbl WHERE a = ANY(current_schemas(true))",
            "SELECT a FROM tbl WHERE array_contains(current_schemas(true), a)"
        );
    }

    /// Unqualified `pg_*` tables get a `pg_catalog.` prefix; qualified ones
    /// are untouched.
    #[test]
    fn test_prepend_unqualified_table_name() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(PrependUnqualifiedPgTableName)];

        assert_rewrite!(
            &rules,
            "SELECT * FROM pg_catalog.pg_namespace",
            "SELECT * FROM pg_catalog.pg_namespace"
        );

        assert_rewrite!(
            &rules,
            "SELECT * FROM pg_namespace",
            "SELECT * FROM pg_catalog.pg_namespace"
        );

        assert_rewrite!(
            &rules,
            "SELECT typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN pg_namespace as ns ON ns.oid = oid",
            "SELECT typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN pg_catalog.pg_namespace AS ns ON ns.oid = oid"
        );
    }

    /// `'{...}'::T[]` literals become `ARRAY[...]::T[]`.
    #[test]
    fn test_array_literal_fix() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(FixArrayLiteral)];

        assert_rewrite!(
            &rules,
            "SELECT '{a, abc}'::text[]",
            "SELECT ARRAY['a', 'abc']::TEXT[]"
        );

        assert_rewrite!(
            &rules,
            "SELECT '{1, 2}'::int[]",
            "SELECT ARRAY[1, 2]::INT[]"
        );

        // NOTE(review): non-text element types are emitted as number
        // literals, so bool elements render as bare `t` / `f`; this pins the
        // current behavior rather than an ideal one.
        assert_rewrite!(
            &rules,
            "SELECT '{t, f}'::bool[]",
            "SELECT ARRAY[t, f]::BOOL[]"
        );
    }

    /// Schema qualifiers are dropped from table-function calls.
    #[test]
    fn test_remove_qualifier_from_table_function() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(RemoveQualifier)];

        assert_rewrite!(
            &rules,
            "SELECT * FROM pg_catalog.pg_get_keywords()",
            "SELECT * FROM pg_get_keywords()"
        );
    }

    /// `current_user` (any case, also as a function argument) becomes
    /// `session_user`.
    #[test]
    fn test_current_user() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(CurrentUserVariableToSessionUserFunctionCall)];

        assert_rewrite!(&rules, "SELECT current_user", "SELECT session_user");

        assert_rewrite!(&rules, "SELECT CURRENT_USER", "SELECT session_user");

        assert_rewrite!(
            &rules,
            "SELECT is_null(current_user)",
            "SELECT is_null(session_user)"
        );
    }

    /// `COLLATE` is removed and `OPERATOR(pg_catalog.~)` becomes `~`.
    #[test]
    fn test_collate_fix() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(FixCollate)];

        assert_rewrite!(&rules, "SELECT c.oid, c.relname FROM pg_catalog.pg_class c WHERE c.relname OPERATOR(pg_catalog.~) '^(tablename)$' COLLATE pg_catalog.default AND pg_catalog.pg_table_is_visible(c.oid) ORDER BY 2, 3;", "SELECT c.oid, c.relname FROM pg_catalog.pg_class AS c WHERE c.relname ~ '^(tablename)$' AND pg_catalog.pg_table_is_visible(c.oid) ORDER BY 2, 3");
    }

    /// Scalar subqueries in the projection become `NULL`, keeping aliases.
    #[test]
    fn test_remove_subquery() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(RemoveSubqueryFromProjection)];

        assert_rewrite!(&rules,
            "SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), (SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid, true) FROM pg_catalog.pg_attrdef d WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef), a.attnotnull, (SELECT c.collname FROM pg_catalog.pg_collation c, pg_catalog.pg_type t WHERE c.oid = a.attcollation AND t.oid = a.atttypid AND a.attcollation <> t.typcollation LIMIT 1) AS attcollation, a.attidentity, a.attgenerated FROM pg_catalog.pg_attribute a WHERE a.attrelid = '16384' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum;",
            "SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), NULL, a.attnotnull, NULL AS attcollation, a.attidentity, a.attgenerated FROM pg_catalog.pg_attribute AS a WHERE a.attrelid = '16384' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum");
    }

    /// Only a bare, unaliased `version()` gets the `version` alias.
    #[test]
    fn test_version_rewrite() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(FixVersionColumnName)];

        assert_rewrite!(&rules, "SELECT version()", "SELECT version() AS version");

        assert_rewrite!(&rules, "SELECT version() as foo", "SELECT version() AS foo");

        assert_rewrite!(&rules, "SELECT version(foo)", "SELECT version(foo)");
        assert_rewrite!(&rules, "SELECT foo.version()", "SELECT foo.version()");
    }
}