1use std::collections::HashSet;
2use std::fmt::Debug;
3use std::ops::ControlFlow;
4
5use datafusion::sql::sqlparser::ast::Array;
6use datafusion::sql::sqlparser::ast::ArrayElemTypeDef;
7use datafusion::sql::sqlparser::ast::BinaryOperator;
8use datafusion::sql::sqlparser::ast::CastKind;
9use datafusion::sql::sqlparser::ast::DataType;
10use datafusion::sql::sqlparser::ast::Expr;
11use datafusion::sql::sqlparser::ast::Function;
12use datafusion::sql::sqlparser::ast::FunctionArg;
13use datafusion::sql::sqlparser::ast::FunctionArgExpr;
14use datafusion::sql::sqlparser::ast::FunctionArgumentList;
15use datafusion::sql::sqlparser::ast::FunctionArguments;
16use datafusion::sql::sqlparser::ast::Ident;
17use datafusion::sql::sqlparser::ast::ObjectName;
18use datafusion::sql::sqlparser::ast::ObjectNamePart;
19use datafusion::sql::sqlparser::ast::OrderByKind;
20use datafusion::sql::sqlparser::ast::Query;
21use datafusion::sql::sqlparser::ast::Select;
22use datafusion::sql::sqlparser::ast::SelectItem;
23use datafusion::sql::sqlparser::ast::SelectItemQualifiedWildcardKind;
24use datafusion::sql::sqlparser::ast::SetExpr;
25use datafusion::sql::sqlparser::ast::Statement;
26use datafusion::sql::sqlparser::ast::TableFactor;
27use datafusion::sql::sqlparser::ast::TableWithJoins;
28use datafusion::sql::sqlparser::ast::TypedString;
29use datafusion::sql::sqlparser::ast::UnaryOperator;
30use datafusion::sql::sqlparser::ast::Value;
31use datafusion::sql::sqlparser::ast::ValueWithSpan;
32use datafusion::sql::sqlparser::ast::VisitMut;
33use datafusion::sql::sqlparser::ast::VisitorMut;
34
35pub trait SqlStatementRewriteRule: Send + Sync + Debug {
36 fn rewrite(&self, s: Statement) -> Statement;
37}
38
39#[derive(Debug)]
48pub struct AliasDuplicatedProjectionRewrite;
49
50impl AliasDuplicatedProjectionRewrite {
51 fn rewrite_select_with_alias(select: &mut Box<Select>) {
53 let mut wildcard_tables = Vec::new();
55 let mut has_simple_wildcard = false;
56 for p in &select.projection {
57 match p {
58 SelectItem::QualifiedWildcard(name, _) => match name {
59 SelectItemQualifiedWildcardKind::ObjectName(objname) => {
60 let idents = objname
62 .0
63 .iter()
64 .map(|v| v.as_ident().unwrap().value.clone())
65 .collect::<Vec<_>>()
66 .join(".");
67
68 wildcard_tables.push(idents);
69 }
70 SelectItemQualifiedWildcardKind::Expr(_expr) => {
71 }
73 },
74 SelectItem::Wildcard(_) => {
75 has_simple_wildcard = true;
76 }
77 _ => {}
78 }
79 }
80
81 if wildcard_tables.is_empty() && !has_simple_wildcard {
83 return;
84 }
85
86 let mut new_projection = vec![];
88 for p in select.projection.drain(..) {
89 match p {
90 SelectItem::UnnamedExpr(expr) => {
91 let alias_partial = match &expr {
92 Expr::Identifier(ident) => Some(ident.clone()),
94 Expr::CompoundIdentifier(idents) => {
96 if idents.len() > 1 {
98 let table_name = &idents[..idents.len() - 1]
99 .iter()
100 .map(|i| i.value.clone())
101 .collect::<Vec<_>>()
102 .join(".");
103 if wildcard_tables.iter().any(|name| name == table_name) {
104 Some(idents[idents.len() - 1].clone())
105 } else {
106 None
107 }
108 } else {
109 None
110 }
111 }
112 _ => None,
113 };
114
115 if let Some(name) = alias_partial {
116 let alias = format!("__alias_{name}");
117 new_projection.push(SelectItem::ExprWithAlias {
118 expr,
119 alias: Ident::new(alias),
120 });
121 } else {
122 new_projection.push(SelectItem::UnnamedExpr(expr));
123 }
124 }
125 _ => new_projection.push(p),
127 }
128 }
129 select.projection = new_projection;
130 }
131}
132
133impl SqlStatementRewriteRule for AliasDuplicatedProjectionRewrite {
134 fn rewrite(&self, mut statement: Statement) -> Statement {
135 if let Statement::Query(query) = &mut statement {
136 if let SetExpr::Select(select) = query.body.as_mut() {
137 Self::rewrite_select_with_alias(select);
138 }
139 }
140
141 statement
142 }
143}
144
145#[derive(Debug)]
150pub struct ResolveUnqualifiedIdentifer;
151
152impl ResolveUnqualifiedIdentifer {
153 fn rewrite_unqualified_identifiers(query: &mut Box<Query>) {
154 if let SetExpr::Select(select) = query.body.as_mut() {
155 let table_aliases = Self::get_table_aliases(&select.from);
157
158 let qualified_wildcard_alias = Self::get_qualified_wildcard_alias(&select.projection);
160 if qualified_wildcard_alias.is_none() || table_aliases.is_empty() {
161 return; }
163
164 let wildcard_alias = qualified_wildcard_alias.unwrap();
165
166 let projection_aliases = Self::get_projection_aliases(&select.projection);
168
169 if let Some(selection) = &mut select.selection {
171 Self::rewrite_expr(
172 selection,
173 &wildcard_alias,
174 &table_aliases,
175 &projection_aliases,
176 );
177 }
178
179 if let Some(OrderByKind::Expressions(order_by_exprs)) =
180 query.order_by.as_mut().map(|o| &mut o.kind)
181 {
182 for order_by_expr in order_by_exprs {
183 Self::rewrite_expr(
184 &mut order_by_expr.expr,
185 &wildcard_alias,
186 &table_aliases,
187 &projection_aliases,
188 );
189 }
190 }
191 }
192 }
193
194 fn get_table_aliases(tables: &[TableWithJoins]) -> HashSet<String> {
195 let mut aliases = HashSet::new();
196 for table_with_joins in tables {
197 if let TableFactor::Table {
198 alias: Some(alias), ..
199 } = &table_with_joins.relation
200 {
201 aliases.insert(alias.name.value.clone());
202 }
203 for join in &table_with_joins.joins {
204 if let TableFactor::Table {
205 alias: Some(alias), ..
206 } = &join.relation
207 {
208 aliases.insert(alias.name.value.clone());
209 }
210 }
211 }
212 aliases
213 }
214
215 fn get_qualified_wildcard_alias(projection: &[SelectItem]) -> Option<String> {
216 let mut qualified_wildcards = projection
217 .iter()
218 .filter_map(|item| {
219 if let SelectItem::QualifiedWildcard(
220 SelectItemQualifiedWildcardKind::ObjectName(objname),
221 _,
222 ) = item
223 {
224 Some(
225 objname
226 .0
227 .iter()
228 .map(|v| v.as_ident().unwrap().value.clone())
229 .collect::<Vec<_>>()
230 .join("."),
231 )
232 } else {
233 None
234 }
235 })
236 .collect::<Vec<_>>();
237
238 if qualified_wildcards.len() == 1 {
239 Some(qualified_wildcards.remove(0))
240 } else {
241 None
242 }
243 }
244
245 fn get_projection_aliases(projection: &[SelectItem]) -> HashSet<String> {
246 let mut aliases = HashSet::new();
247 for item in projection {
248 match item {
249 SelectItem::ExprWithAlias { alias, .. } => {
250 aliases.insert(alias.value.clone());
251 }
252 SelectItem::UnnamedExpr(Expr::Identifier(ident)) => {
253 aliases.insert(ident.value.clone());
254 }
255 _ => {}
256 }
257 }
258 aliases
259 }
260
261 fn rewrite_expr(
262 expr: &mut Expr,
263 wildcard_alias: &str,
264 table_aliases: &HashSet<String>,
265 projection_aliases: &HashSet<String>,
266 ) {
267 match expr {
268 Expr::Identifier(ident) => {
269 if !table_aliases.contains(&ident.value)
271 && !projection_aliases.contains(&ident.value)
272 {
273 *expr = Expr::CompoundIdentifier(vec![
274 Ident::new(wildcard_alias.to_string()),
275 ident.clone(),
276 ]);
277 }
278 }
279 Expr::BinaryOp { left, right, .. } => {
280 Self::rewrite_expr(left, wildcard_alias, table_aliases, projection_aliases);
281 Self::rewrite_expr(right, wildcard_alias, table_aliases, projection_aliases);
282 }
283 _ => {}
285 }
286 }
287}
288
289impl SqlStatementRewriteRule for ResolveUnqualifiedIdentifer {
290 fn rewrite(&self, mut statement: Statement) -> Statement {
291 if let Statement::Query(query) = &mut statement {
292 Self::rewrite_unqualified_identifiers(query);
293 }
294
295 statement
296 }
297}
298
/// Rewrite rule holding the set of type names that are treated as
/// unsupported.
///
/// Names are stored in lower case, each both bare (`regclass`) and
/// `pg_catalog.`-qualified (`pg_catalog.regclass`).
#[derive(Debug)]
pub struct RemoveUnsupportedTypes {
    unsupported_types: HashSet<String>,
}

impl Default for RemoveUnsupportedTypes {
    /// Same as [`RemoveUnsupportedTypes::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl RemoveUnsupportedTypes {
    /// Creates the rule with the built-in list of unsupported type names.
    pub fn new() -> Self {
        const BASE_NAMES: [&str; 6] = [
            "regclass",
            "regproc",
            "regtype",
            "regtype[]",
            "regnamespace",
            "oid",
        ];

        // Register every name twice: bare and schema-qualified.
        let unsupported_types = BASE_NAMES
            .iter()
            .flat_map(|name| [name.to_string(), format!("pg_catalog.{name}")])
            .collect();

        Self { unsupported_types }
    }
}
331
332struct RemoveUnsupportedTypesVisitor<'a> {
333 unsupported_types: &'a HashSet<String>,
334}
335
336impl VisitorMut for RemoveUnsupportedTypesVisitor<'_> {
337 type Break = ();
338
339 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
340 match expr {
341 Expr::TypedString(TypedString {
343 data_type,
344 value,
345 uses_odbc_syntax: _,
346 }) => {
347 if self
348 .unsupported_types
349 .contains(data_type.to_string().to_lowercase().as_str())
350 {
351 *expr =
352 Expr::Value(Value::SingleQuotedString(value.to_string()).with_empty_span());
353 }
354 }
355 Expr::Cast {
356 data_type,
357 expr: value,
358 ..
359 } => {
360 if self
361 .unsupported_types
362 .contains(data_type.to_string().to_lowercase().as_str())
363 {
364 *expr = *value.clone();
365 }
366 }
367 _ => {}
369 }
370
371 ControlFlow::Continue(())
372 }
373}
374
375impl SqlStatementRewriteRule for RemoveUnsupportedTypes {
376 fn rewrite(&self, mut statement: Statement) -> Statement {
377 let mut visitor = RemoveUnsupportedTypesVisitor {
378 unsupported_types: &self.unsupported_types,
379 };
380 let _ = statement.visit(&mut visitor);
381 statement
382 }
383}
384
385#[derive(Debug)]
387pub struct RewriteArrayAnyAllOperation;
388
389struct RewriteArrayAnyAllOperationVisitor;
390
391impl RewriteArrayAnyAllOperationVisitor {
392 fn any_to_array_cofntains(&self, left: &Expr, right: &Expr) -> Expr {
393 let array = if let Expr::Value(ValueWithSpan {
394 value: Value::SingleQuotedString(array_literal),
395 ..
396 }) = right
397 {
398 let array_literal = array_literal.trim();
399 if array_literal.starts_with('{') && array_literal.ends_with('}') {
400 let items = array_literal.trim_matches(|c| c == '{' || c == '}' || c == ' ');
401 let items = items.split(',').map(|s| s.trim()).filter(|s| !s.is_empty());
402
403 let elems = items
405 .map(|s| {
406 Expr::Value(Value::SingleQuotedString(s.to_string()).with_empty_span())
407 })
408 .collect();
409 Expr::Array(Array {
410 elem: elems,
411 named: true,
412 })
413 } else {
414 right.clone()
415 }
416 } else {
417 right.clone()
418 };
419
420 Expr::Function(Function {
421 name: ObjectName::from(vec![Ident::new("array_contains")]),
422 args: FunctionArguments::List(FunctionArgumentList {
423 args: vec![
424 FunctionArg::Unnamed(FunctionArgExpr::Expr(array)),
425 FunctionArg::Unnamed(FunctionArgExpr::Expr(left.clone())),
426 ],
427 duplicate_treatment: None,
428 clauses: vec![],
429 }),
430 uses_odbc_syntax: false,
431 parameters: FunctionArguments::None,
432 filter: None,
433 null_treatment: None,
434 over: None,
435 within_group: vec![],
436 })
437 }
438}
439
440impl VisitorMut for RewriteArrayAnyAllOperationVisitor {
441 type Break = ();
442
443 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
444 match expr {
445 Expr::AnyOp {
446 left,
447 compare_op,
448 right,
449 ..
450 } => match compare_op {
451 BinaryOperator::Eq => {
452 *expr = self.any_to_array_cofntains(left.as_ref(), right.as_ref());
453 }
454 BinaryOperator::NotEq => {
455 }
457 _ => {}
458 },
459 Expr::AllOp {
460 left,
461 compare_op,
462 right,
463 } => match compare_op {
464 BinaryOperator::Eq => {
465 }
467 BinaryOperator::NotEq => {
468 *expr = Expr::UnaryOp {
469 op: UnaryOperator::Not,
470 expr: Box::new(self.any_to_array_cofntains(left.as_ref(), right.as_ref())),
471 }
472 }
473 _ => {}
474 },
475 _ => {}
476 }
477
478 ControlFlow::Continue(())
479 }
480}
481
482impl SqlStatementRewriteRule for RewriteArrayAnyAllOperation {
483 fn rewrite(&self, mut s: Statement) -> Statement {
484 let mut visitor = RewriteArrayAnyAllOperationVisitor;
485
486 let _ = s.visit(&mut visitor);
487
488 s
489 }
490}
491
492#[derive(Debug)]
497pub struct PrependUnqualifiedPgTableName;
498
499struct PrependUnqualifiedPgTableNameVisitor;
500
501impl VisitorMut for PrependUnqualifiedPgTableNameVisitor {
502 type Break = ();
503
504 fn pre_visit_table_factor(
505 &mut self,
506 table_factor: &mut TableFactor,
507 ) -> ControlFlow<Self::Break> {
508 if let TableFactor::Table { name, args, .. } = table_factor {
509 if args.is_none() && name.0.len() == 1 {
511 if let ObjectNamePart::Identifier(ident) = &name.0[0] {
512 if ident.value.starts_with("pg_") {
513 *name = ObjectName(vec![
514 ObjectNamePart::Identifier(Ident::new("pg_catalog")),
515 name.0[0].clone(),
516 ]);
517 }
518 }
519 }
520 }
521
522 ControlFlow::Continue(())
523 }
524}
525
526impl SqlStatementRewriteRule for PrependUnqualifiedPgTableName {
527 fn rewrite(&self, mut s: Statement) -> Statement {
528 let mut visitor = PrependUnqualifiedPgTableNameVisitor;
529
530 let _ = s.visit(&mut visitor);
531 s
532 }
533}
534
535#[derive(Debug)]
536pub struct FixArrayLiteral;
537
538struct FixArrayLiteralVisitor;
539
540impl FixArrayLiteralVisitor {
541 fn is_string_type(dt: &DataType) -> bool {
542 matches!(
543 dt,
544 DataType::Text | DataType::Varchar(_) | DataType::Char(_) | DataType::String(_)
545 )
546 }
547}
548
549impl VisitorMut for FixArrayLiteralVisitor {
550 type Break = ();
551
552 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
553 if let Expr::Cast {
554 kind,
555 expr,
556 data_type,
557 ..
558 } = expr
559 {
560 if kind == &CastKind::DoubleColon {
561 if let DataType::Array(arr) = data_type {
562 if let Expr::Value(ValueWithSpan {
564 value: Value::SingleQuotedString(array_literal),
565 ..
566 }) = expr.as_ref()
567 {
568 let items =
569 array_literal.trim_matches(|c| c == '{' || c == '}' || c == ' ');
570 let items = items.split(',').map(|s| s.trim()).filter(|s| !s.is_empty());
571
572 let is_text = match arr {
573 ArrayElemTypeDef::AngleBracket(dt) => Self::is_string_type(dt.as_ref()),
574 ArrayElemTypeDef::SquareBracket(dt, _) => {
575 Self::is_string_type(dt.as_ref())
576 }
577 ArrayElemTypeDef::Parenthesis(dt) => Self::is_string_type(dt.as_ref()),
578 _ => false,
579 };
580
581 let elems = items
582 .map(|s| {
583 if is_text {
584 Expr::Value(
585 Value::SingleQuotedString(s.to_string()).with_empty_span(),
586 )
587 } else {
588 Expr::Value(
589 Value::Number(s.to_string(), false).with_empty_span(),
590 )
591 }
592 })
593 .collect();
594 *expr = Box::new(Expr::Array(Array {
595 elem: elems,
596 named: true,
597 }));
598 }
599 }
600 }
601 }
602
603 ControlFlow::Continue(())
604 }
605}
606
607impl SqlStatementRewriteRule for FixArrayLiteral {
608 fn rewrite(&self, mut s: Statement) -> Statement {
609 let mut visitor = FixArrayLiteralVisitor;
610
611 let _ = s.visit(&mut visitor);
612 s
613 }
614}
615
616#[derive(Debug)]
623pub struct RemoveQualifier;
624
625struct RemoveQualifierVisitor;
626
627impl VisitorMut for RemoveQualifierVisitor {
628 type Break = ();
629
630 fn pre_visit_table_factor(
631 &mut self,
632 table_factor: &mut TableFactor,
633 ) -> ControlFlow<Self::Break> {
634 if let TableFactor::Table { name, args, .. } = table_factor {
636 if args.is_some() {
637 if name.0.len() > 1 {
639 if let Some(last_ident) = name.0.pop() {
640 *name = ObjectName(vec![last_ident]);
641 }
642 }
643 }
644 }
645 ControlFlow::Continue(())
646 }
647
648 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
649 match expr {
650 Expr::Cast { data_type, .. } => {
651 let data_type_str = data_type.to_string();
653 match data_type_str.as_str() {
654 "pg_catalog.text" => {
655 *data_type = DataType::Text;
656 }
657 "pg_catalog.int2[]" => {
658 *data_type = DataType::Array(ArrayElemTypeDef::SquareBracket(
659 Box::new(DataType::Int16),
660 None,
661 ));
662 }
663 _ => {}
664 }
665 }
666 Expr::Function(function) => {
667 let name = &mut function.name;
669 if name.0.len() > 1 {
670 if let Some(last_ident) = name.0.pop() {
671 *name = ObjectName(vec![last_ident]);
672 }
673 }
674 }
675
676 _ => {}
677 }
678 ControlFlow::Continue(())
679 }
680}
681
682impl SqlStatementRewriteRule for RemoveQualifier {
683 fn rewrite(&self, mut s: Statement) -> Statement {
684 let mut visitor = RemoveQualifierVisitor;
685
686 let _ = s.visit(&mut visitor);
687 s
688 }
689}
690
691#[derive(Debug)]
693pub struct CurrentUserVariableToSessionUserFunctionCall;
694
695struct CurrentUserVariableToSessionUserFunctionCallVisitor;
696
697impl VisitorMut for CurrentUserVariableToSessionUserFunctionCallVisitor {
698 type Break = ();
699
700 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
701 if let Expr::Identifier(ident) = expr {
702 if ident.quote_style.is_none() && ident.value.to_lowercase() == "current_user" {
703 *expr = Expr::Function(Function {
704 name: ObjectName::from(vec![Ident::new("session_user")]),
705 args: FunctionArguments::None,
706 uses_odbc_syntax: false,
707 parameters: FunctionArguments::None,
708 filter: None,
709 null_treatment: None,
710 over: None,
711 within_group: vec![],
712 });
713 }
714 }
715
716 if let Expr::Function(func) = expr {
717 let fname = func
718 .name
719 .0
720 .iter()
721 .map(|ident| ident.to_string())
722 .collect::<Vec<String>>()
723 .join(".");
724 if fname.to_lowercase() == "current_user" {
725 func.name = ObjectName::from(vec![Ident::new("session_user")])
726 }
727 }
728
729 ControlFlow::Continue(())
730 }
731}
732
733impl SqlStatementRewriteRule for CurrentUserVariableToSessionUserFunctionCall {
734 fn rewrite(&self, mut s: Statement) -> Statement {
735 let mut visitor = CurrentUserVariableToSessionUserFunctionCallVisitor;
736
737 let _ = s.visit(&mut visitor);
738 s
739 }
740}
741
742#[derive(Debug)]
744pub struct FixCollate;
745
746struct FixCollateVisitor;
747
748impl VisitorMut for FixCollateVisitor {
749 type Break = ();
750
751 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
752 match expr {
753 Expr::Collate { expr: inner, .. } => {
754 *expr = inner.as_ref().clone();
755 }
756 Expr::BinaryOp { op, .. } => {
757 if let BinaryOperator::PGCustomBinaryOperator(ops) = op {
758 if *ops == ["pg_catalog", "~"] {
759 *op = BinaryOperator::PGRegexMatch;
760 }
761 }
762 }
763 _ => {}
764 }
765
766 ControlFlow::Continue(())
767 }
768}
769
770impl SqlStatementRewriteRule for FixCollate {
771 fn rewrite(&self, mut s: Statement) -> Statement {
772 let mut visitor = FixCollateVisitor;
773
774 let _ = s.visit(&mut visitor);
775 s
776 }
777}
778
779#[derive(Debug)]
781pub struct RemoveSubqueryFromProjection;
782
783struct RemoveSubqueryFromProjectionVisitor;
784
785impl VisitorMut for RemoveSubqueryFromProjectionVisitor {
786 type Break = ();
787
788 fn pre_visit_query(&mut self, query: &mut Query) -> ControlFlow<Self::Break> {
789 if let SetExpr::Select(select) = query.body.as_mut() {
790 for projection in &mut select.projection {
791 match projection {
792 SelectItem::UnnamedExpr(expr) => {
793 if let Expr::Subquery(_) = expr {
794 *expr = Expr::Value(Value::Null.with_empty_span());
795 }
796 }
797 SelectItem::ExprWithAlias { expr, .. } => {
798 if let Expr::Subquery(_) = expr {
799 *expr = Expr::Value(Value::Null.with_empty_span());
800 }
801 }
802 _ => {}
803 }
804 }
805 }
806
807 ControlFlow::Continue(())
808 }
809}
810
811impl SqlStatementRewriteRule for RemoveSubqueryFromProjection {
812 fn rewrite(&self, mut s: Statement) -> Statement {
813 let mut visitor = RemoveSubqueryFromProjectionVisitor;
814 let _ = s.visit(&mut visitor);
815
816 s
817 }
818}
819
820#[derive(Debug)]
822pub struct FixVersionColumnName;
823
824struct FixVersionColumnNameVisitor;
825
826impl VisitorMut for FixVersionColumnNameVisitor {
827 type Break = ();
828
829 fn pre_visit_query(&mut self, query: &mut Query) -> ControlFlow<Self::Break> {
830 if let SetExpr::Select(select) = query.body.as_mut() {
831 for projection in &mut select.projection {
832 if let SelectItem::UnnamedExpr(Expr::Function(f)) = projection {
833 if f.name.0.len() == 1 {
834 if let ObjectNamePart::Identifier(part) = &f.name.0[0] {
835 if part.value == "version" {
836 if let FunctionArguments::List(args) = &f.args {
837 if args.args.is_empty() {
838 *projection = SelectItem::ExprWithAlias {
839 expr: Expr::Function(f.clone()),
840 alias: Ident::new("version"),
841 }
842 }
843 }
844 }
845 }
846 }
847 }
848 }
849 }
850
851 ControlFlow::Continue(())
852 }
853}
854
855impl SqlStatementRewriteRule for FixVersionColumnName {
856 fn rewrite(&self, mut s: Statement) -> Statement {
857 let mut visitor = FixVersionColumnNameVisitor;
858 let _ = s.visit(&mut visitor);
859
860 s
861 }
862}
863
// Unit tests: each test parses SQL with the PostgreSQL dialect, applies one or
// more rewrite rules, and compares the statement's rendered SQL text.
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
    use datafusion::sql::sqlparser::parser::Parser;
    use datafusion::sql::sqlparser::parser::ParserError;
    use std::sync::Arc;

    // Parses `sql` into statements using the PostgreSQL dialect.
    fn parse(sql: &str) -> Result<Vec<Statement>, ParserError> {
        let dialect = PostgreSqlDialect {};

        Parser::parse_sql(&dialect, sql)
    }

    // Applies `rules` to `s` one after another, in slice order.
    fn rewrite(mut s: Statement, rules: &[Arc<dyn SqlStatementRewriteRule>]) -> Statement {
        for rule in rules {
            s = rule.rewrite(s);
        }

        s
    }

    // Parses `$orig`, rewrites its first statement with `$rules`, and asserts
    // that the result renders as `$rewt`.
    macro_rules! assert_rewrite {
        ($rules:expr, $orig:expr, $rewt:expr) => {
            let sql = $orig;
            let statement = parse(sql).expect("Failed to parse").remove(0);

            let statement = rewrite(statement, $rules);
            assert_eq!(statement.to_string(), $rewt);
        };
    }

    // Columns projected next to a wildcard get an `__alias_` alias.
    #[test]
    fn test_alias_rewrite() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(AliasDuplicatedProjectionRewrite)];

        assert_rewrite!(
            &rules,
            "SELECT n.oid, n.* FROM pg_catalog.pg_namespace n",
            "SELECT n.oid AS __alias_oid, n.* FROM pg_catalog.pg_namespace AS n"
        );

        assert_rewrite!(
            &rules,
            "SELECT oid, * FROM pg_catalog.pg_namespace",
            "SELECT oid AS __alias_oid, * FROM pg_catalog.pg_namespace"
        );

        // The column's table differs from the wildcard's table: no alias.
        assert_rewrite!(
            &rules,
            "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id",
            "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id"
        );

        let sql = "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname";
        let statement = parse(sql).expect("Failed to parse").remove(0);

        let statement = rewrite(statement, &rules);
        assert_eq!(
            statement.to_string(),
            "SELECT n.oid AS __alias_oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY nspsname"
        );
    }

    // Bare identifiers in WHERE / ORDER BY are qualified with the wildcard's alias.
    #[test]
    fn test_qualifier_prepend() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(ResolveUnqualifiedIdentifer)];

        assert_rewrite!(
            &rules,
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE nspname = 'pg_catalog' ORDER BY nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
        );

        // No table alias and no qualified wildcard: nothing changes.
        assert_rewrite!(
            &rules,
            "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname",
            "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname"
        );

        assert_rewrite!(
            &rules,
            "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname",
            "SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY n.nspsname"
        );

        // `tabrelname` is a projection alias, so it stays unqualified.
        assert_rewrite!(&rules,
            "SELECT i.*,i.indkey as keys,c.relname,c.relnamespace,c.relam,c.reltablespace,tc.relname as tabrelname,dsc.description FROM pg_catalog.pg_index i INNER JOIN pg_catalog.pg_class c ON c.oid=i.indexrelid INNER JOIN pg_catalog.pg_class tc ON tc.oid=i.indrelid LEFT OUTER JOIN pg_catalog.pg_description dsc ON i.indexrelid=dsc.objoid WHERE i.indrelid=1 ORDER BY tabrelname, c.relname",
            "SELECT i.*, i.indkey AS keys, c.relname, c.relnamespace, c.relam, c.reltablespace, tc.relname AS tabrelname, dsc.description FROM pg_catalog.pg_index AS i INNER JOIN pg_catalog.pg_class AS c ON c.oid = i.indexrelid INNER JOIN pg_catalog.pg_class AS tc ON tc.oid = i.indrelid LEFT OUTER JOIN pg_catalog.pg_description AS dsc ON i.indexrelid = dsc.objoid WHERE i.indrelid = 1 ORDER BY tabrelname, c.relname"
        );
    }

    // Casts to unsupported types are removed (after qualifiers are stripped).
    #[test]
    fn test_remove_unsupported_types() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![
            Arc::new(RemoveQualifier),
            Arc::new(RemoveUnsupportedTypes::new()),
        ];

        assert_rewrite!(
            &rules,
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
        );

        assert_rewrite!(
            &rules,
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.oid = 1 AND n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.oid = 1 AND n.nspname = 'pg_catalog' ORDER BY n.nspname"
        );

        assert_rewrite!(
            &rules,
            "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace'::regclass ORDER BY nspname",
            "SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY nspname"
        );

        // Nothing to remove: the statement only gets normalized by rendering.
        assert_rewrite!(
            &rules,
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
        );

        // Chained casts: the `regtype` cast is dropped, the `text` cast is kept.
        assert_rewrite!(
            &rules,
            "SELECT c.relchecks, c.relkind, c.relhasindex, c.relhasrules, c.relhastriggers, c.relrowsecurity, c.relforcerowsecurity, false AS relhasoids, c.relispartition, '', c.reltablespace, CASE WHEN c.reloftype = 0 THEN '' ELSE c.reloftype::pg_catalog.regtype::pg_catalog.text END, c.relpersistence, c.relreplident, am.amname
                FROM pg_catalog.pg_class c
                LEFT JOIN pg_catalog.pg_class tc ON (c.reltoastrelid = tc.oid)
                LEFT JOIN pg_catalog.pg_am am ON (c.relam = am.oid)
                WHERE c.oid = '16386'",
            "SELECT c.relchecks, c.relkind, c.relhasindex, c.relhasrules, c.relhastriggers, c.relrowsecurity, c.relforcerowsecurity, false AS relhasoids, c.relispartition, '', c.reltablespace, CASE WHEN c.reloftype = 0 THEN '' ELSE c.reloftype::TEXT END, c.relpersistence, c.relreplident, am.amname FROM pg_catalog.pg_class AS c LEFT JOIN pg_catalog.pg_class AS tc ON (c.reltoastrelid = tc.oid) LEFT JOIN pg_catalog.pg_am AS am ON (c.relam = am.oid) WHERE c.oid = '16386'"
        );
    }

    // `= ANY(..)` and `<> ALL(..)` become (negated) `array_contains` calls.
    #[test]
    fn test_any_to_array_contains() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(RewriteArrayAnyAllOperation)];

        assert_rewrite!(
            &rules,
            "SELECT a = ANY(current_schemas(true))",
            "SELECT array_contains(current_schemas(true), a)"
        );

        assert_rewrite!(
            &rules,
            "SELECT a <> ALL(current_schemas(true))",
            "SELECT NOT array_contains(current_schemas(true), a)"
        );

        // A `'{...}'` string literal is expanded into an ARRAY of strings.
        assert_rewrite!(
            &rules,
            "SELECT a = ANY('{r, l, e}')",
            "SELECT array_contains(ARRAY['r', 'l', 'e'], a)"
        );

        assert_rewrite!(
            &rules,
            "SELECT a FROM tbl WHERE a = ANY(current_schemas(true))",
            "SELECT a FROM tbl WHERE array_contains(current_schemas(true), a)"
        );
    }

    // Unqualified `pg_*` tables get the `pg_catalog` schema prefix.
    #[test]
    fn test_prepend_unqualified_table_name() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(PrependUnqualifiedPgTableName)];

        assert_rewrite!(
            &rules,
            "SELECT * FROM pg_catalog.pg_namespace",
            "SELECT * FROM pg_catalog.pg_namespace"
        );

        assert_rewrite!(
            &rules,
            "SELECT * FROM pg_namespace",
            "SELECT * FROM pg_catalog.pg_namespace"
        );

        assert_rewrite!(
            &rules,
            "SELECT typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN pg_namespace as ns ON ns.oid = oid",
            "SELECT typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN pg_catalog.pg_namespace AS ns ON ns.oid = oid"
        );
    }

    // `'{...}'::type[]` literals become `ARRAY[...]::type[]`.
    #[test]
    fn test_array_literal_fix() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(FixArrayLiteral)];

        assert_rewrite!(
            &rules,
            "SELECT '{a, abc}'::text[]",
            "SELECT ARRAY['a', 'abc']::TEXT[]"
        );

        assert_rewrite!(
            &rules,
            "SELECT '{1, 2}'::int[]",
            "SELECT ARRAY[1, 2]::INT[]"
        );

        // Non-text element types are emitted unquoted.
        assert_rewrite!(
            &rules,
            "SELECT '{t, f}'::bool[]",
            "SELECT ARRAY[t, f]::BOOL[]"
        );
    }

    // The schema qualifier is stripped from table functions.
    #[test]
    fn test_remove_qualifier_from_table_function() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(RemoveQualifier)];

        assert_rewrite!(
            &rules,
            "SELECT * FROM pg_catalog.pg_get_keywords()",
            "SELECT * FROM pg_get_keywords()"
        );
    }

    // `current_user` in any letter case is replaced by `session_user`.
    #[test]
    fn test_current_user() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(CurrentUserVariableToSessionUserFunctionCall)];

        assert_rewrite!(&rules, "SELECT current_user", "SELECT session_user");

        assert_rewrite!(&rules, "SELECT CURRENT_USER", "SELECT session_user");

        assert_rewrite!(
            &rules,
            "SELECT is_null(current_user)",
            "SELECT is_null(session_user)"
        );
    }

    // COLLATE is dropped and `OPERATOR(pg_catalog.~)` becomes `~`.
    #[test]
    fn test_collate_fix() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(FixCollate)];

        assert_rewrite!(&rules, "SELECT c.oid, c.relname FROM pg_catalog.pg_class c WHERE c.relname OPERATOR(pg_catalog.~) '^(tablename)$' COLLATE pg_catalog.default AND pg_catalog.pg_table_is_visible(c.oid) ORDER BY 2, 3;", "SELECT c.oid, c.relname FROM pg_catalog.pg_class AS c WHERE c.relname ~ '^(tablename)$' AND pg_catalog.pg_table_is_visible(c.oid) ORDER BY 2, 3");
    }

    // Subqueries in the projection are replaced by NULL, aliases are kept.
    #[test]
    fn test_remove_subquery() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(RemoveSubqueryFromProjection)];

        assert_rewrite!(&rules,
            "SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), (SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid, true) FROM pg_catalog.pg_attrdef d WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef), a.attnotnull, (SELECT c.collname FROM pg_catalog.pg_collation c, pg_catalog.pg_type t WHERE c.oid = a.attcollation AND t.oid = a.atttypid AND a.attcollation <> t.typcollation LIMIT 1) AS attcollation, a.attidentity, a.attgenerated FROM pg_catalog.pg_attribute a WHERE a.attrelid = '16384' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum;",
            "SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), NULL, a.attnotnull, NULL AS attcollation, a.attidentity, a.attgenerated FROM pg_catalog.pg_attribute AS a WHERE a.attrelid = '16384' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum");
    }

    // Only an unaliased, unqualified, argument-less `version()` gets the alias.
    #[test]
    fn test_version_rewrite() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(FixVersionColumnName)];

        assert_rewrite!(&rules, "SELECT version()", "SELECT version() AS version");

        assert_rewrite!(&rules, "SELECT version() as foo", "SELECT version() AS foo");
        assert_rewrite!(&rules, "SELECT version(foo)", "SELECT version(foo)");
        assert_rewrite!(&rules, "SELECT foo.version()", "SELECT foo.version()");
    }
}