1use std::collections::HashSet;
2use std::fmt::Debug;
3use std::ops::ControlFlow;
4
5use datafusion::sql::sqlparser::ast::Array;
6use datafusion::sql::sqlparser::ast::ArrayElemTypeDef;
7use datafusion::sql::sqlparser::ast::BinaryOperator;
8use datafusion::sql::sqlparser::ast::CastKind;
9use datafusion::sql::sqlparser::ast::DataType;
10use datafusion::sql::sqlparser::ast::Expr;
11use datafusion::sql::sqlparser::ast::Function;
12use datafusion::sql::sqlparser::ast::FunctionArg;
13use datafusion::sql::sqlparser::ast::FunctionArgExpr;
14use datafusion::sql::sqlparser::ast::FunctionArgumentList;
15use datafusion::sql::sqlparser::ast::FunctionArguments;
16use datafusion::sql::sqlparser::ast::Ident;
17use datafusion::sql::sqlparser::ast::LimitClause;
18use datafusion::sql::sqlparser::ast::ObjectName;
19use datafusion::sql::sqlparser::ast::ObjectNamePart;
20use datafusion::sql::sqlparser::ast::OrderByKind;
21use datafusion::sql::sqlparser::ast::Query;
22use datafusion::sql::sqlparser::ast::Select;
23use datafusion::sql::sqlparser::ast::SelectItem;
24use datafusion::sql::sqlparser::ast::SelectItemQualifiedWildcardKind;
25use datafusion::sql::sqlparser::ast::SetExpr;
26use datafusion::sql::sqlparser::ast::Statement;
27use datafusion::sql::sqlparser::ast::TableFactor;
28use datafusion::sql::sqlparser::ast::TableWithJoins;
29use datafusion::sql::sqlparser::ast::TypedString;
30use datafusion::sql::sqlparser::ast::UnaryOperator;
31use datafusion::sql::sqlparser::ast::Value;
32use datafusion::sql::sqlparser::ast::ValueWithSpan;
33use datafusion::sql::sqlparser::ast::VisitMut;
34use datafusion::sql::sqlparser::ast::Visitor;
35use datafusion::sql::sqlparser::ast::VisitorMut;
36use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
37use datafusion::sql::sqlparser::parser::Parser;
38
39pub trait SqlStatementRewriteRule: Send + Sync + Debug {
40 fn rewrite(&self, s: Statement) -> Statement;
41}
42
43#[derive(Debug)]
52pub struct AliasDuplicatedProjectionRewrite;
53
54impl AliasDuplicatedProjectionRewrite {
55 fn rewrite_select_with_alias(select: &mut Box<Select>) {
57 let mut wildcard_tables = Vec::new();
59 let mut has_simple_wildcard = false;
60 for p in &select.projection {
61 match p {
62 SelectItem::QualifiedWildcard(name, _) => match name {
63 SelectItemQualifiedWildcardKind::ObjectName(objname) => {
64 let idents = objname
66 .0
67 .iter()
68 .map(|v| v.as_ident().unwrap().value.clone())
69 .collect::<Vec<_>>()
70 .join(".");
71
72 wildcard_tables.push(idents);
73 }
74 SelectItemQualifiedWildcardKind::Expr(_expr) => {
75 }
77 },
78 SelectItem::Wildcard(_) => {
79 has_simple_wildcard = true;
80 }
81 _ => {}
82 }
83 }
84
85 if wildcard_tables.is_empty() && !has_simple_wildcard {
87 return;
88 }
89
90 let mut new_projection = vec![];
92 for p in select.projection.drain(..) {
93 match p {
94 SelectItem::UnnamedExpr(expr) => {
95 let alias_partial = match &expr {
96 Expr::Identifier(ident) => Some(ident.clone()),
98 Expr::CompoundIdentifier(idents) => {
100 if idents.len() > 1 {
102 let table_name = &idents[..idents.len() - 1]
103 .iter()
104 .map(|i| i.value.clone())
105 .collect::<Vec<_>>()
106 .join(".");
107 if wildcard_tables.iter().any(|name| name == table_name) {
108 Some(idents[idents.len() - 1].clone())
109 } else {
110 None
111 }
112 } else {
113 None
114 }
115 }
116 _ => None,
117 };
118
119 if let Some(name) = alias_partial {
120 let alias = format!("__alias_{name}");
121 new_projection.push(SelectItem::ExprWithAlias {
122 expr,
123 alias: Ident::new(alias),
124 });
125 } else {
126 new_projection.push(SelectItem::UnnamedExpr(expr));
127 }
128 }
129 _ => new_projection.push(p),
131 }
132 }
133 select.projection = new_projection;
134 }
135}
136
137impl SqlStatementRewriteRule for AliasDuplicatedProjectionRewrite {
138 fn rewrite(&self, mut statement: Statement) -> Statement {
139 if let Statement::Query(query) = &mut statement {
140 if let SetExpr::Select(select) = query.body.as_mut() {
141 Self::rewrite_select_with_alias(select);
142 }
143 }
144
145 statement
146 }
147}
148
149#[derive(Debug)]
154pub struct ResolveUnqualifiedIdentifer;
155
156impl ResolveUnqualifiedIdentifer {
157 fn rewrite_unqualified_identifiers(query: &mut Box<Query>) {
158 if let SetExpr::Select(select) = query.body.as_mut() {
159 let table_aliases = Self::get_table_aliases(&select.from);
161
162 let qualified_wildcard_alias = Self::get_qualified_wildcard_alias(&select.projection);
164 if qualified_wildcard_alias.is_none() || table_aliases.is_empty() {
165 return; }
167
168 let wildcard_alias = qualified_wildcard_alias.unwrap();
169
170 let projection_aliases = Self::get_projection_aliases(&select.projection);
172
173 if let Some(selection) = &mut select.selection {
175 Self::rewrite_expr(
176 selection,
177 &wildcard_alias,
178 &table_aliases,
179 &projection_aliases,
180 );
181 }
182
183 if let Some(OrderByKind::Expressions(order_by_exprs)) =
184 query.order_by.as_mut().map(|o| &mut o.kind)
185 {
186 for order_by_expr in order_by_exprs {
187 Self::rewrite_expr(
188 &mut order_by_expr.expr,
189 &wildcard_alias,
190 &table_aliases,
191 &projection_aliases,
192 );
193 }
194 }
195 }
196 }
197
198 fn get_table_aliases(tables: &[TableWithJoins]) -> HashSet<String> {
199 let mut aliases = HashSet::new();
200 for table_with_joins in tables {
201 if let TableFactor::Table {
202 alias: Some(alias), ..
203 } = &table_with_joins.relation
204 {
205 aliases.insert(alias.name.value.clone());
206 }
207 for join in &table_with_joins.joins {
208 if let TableFactor::Table {
209 alias: Some(alias), ..
210 } = &join.relation
211 {
212 aliases.insert(alias.name.value.clone());
213 }
214 }
215 }
216 aliases
217 }
218
219 fn get_qualified_wildcard_alias(projection: &[SelectItem]) -> Option<String> {
220 let mut qualified_wildcards = projection
221 .iter()
222 .filter_map(|item| {
223 if let SelectItem::QualifiedWildcard(
224 SelectItemQualifiedWildcardKind::ObjectName(objname),
225 _,
226 ) = item
227 {
228 Some(
229 objname
230 .0
231 .iter()
232 .map(|v| v.as_ident().unwrap().value.clone())
233 .collect::<Vec<_>>()
234 .join("."),
235 )
236 } else {
237 None
238 }
239 })
240 .collect::<Vec<_>>();
241
242 if qualified_wildcards.len() == 1 {
243 Some(qualified_wildcards.remove(0))
244 } else {
245 None
246 }
247 }
248
249 fn get_projection_aliases(projection: &[SelectItem]) -> HashSet<String> {
250 let mut aliases = HashSet::new();
251 for item in projection {
252 match item {
253 SelectItem::ExprWithAlias { alias, .. } => {
254 aliases.insert(alias.value.clone());
255 }
256 SelectItem::UnnamedExpr(Expr::Identifier(ident)) => {
257 aliases.insert(ident.value.clone());
258 }
259 _ => {}
260 }
261 }
262 aliases
263 }
264
265 fn rewrite_expr(
266 expr: &mut Expr,
267 wildcard_alias: &str,
268 table_aliases: &HashSet<String>,
269 projection_aliases: &HashSet<String>,
270 ) {
271 match expr {
272 Expr::Identifier(ident) => {
273 if !table_aliases.contains(&ident.value)
275 && !projection_aliases.contains(&ident.value)
276 {
277 *expr = Expr::CompoundIdentifier(vec![
278 Ident::new(wildcard_alias.to_string()),
279 ident.clone(),
280 ]);
281 }
282 }
283 Expr::BinaryOp { left, right, .. } => {
284 Self::rewrite_expr(left, wildcard_alias, table_aliases, projection_aliases);
285 Self::rewrite_expr(right, wildcard_alias, table_aliases, projection_aliases);
286 }
287 _ => {}
289 }
290 }
291}
292
293impl SqlStatementRewriteRule for ResolveUnqualifiedIdentifer {
294 fn rewrite(&self, mut statement: Statement) -> Statement {
295 if let Statement::Query(query) = &mut statement {
296 Self::rewrite_unqualified_identifiers(query);
297 }
298
299 statement
300 }
301}
302
303#[derive(Debug)]
306pub struct RemoveUnsupportedTypes {
307 unsupported_types: HashSet<String>,
308}
309
310impl Default for RemoveUnsupportedTypes {
311 fn default() -> Self {
312 Self::new()
313 }
314}
315
316impl RemoveUnsupportedTypes {
317 pub fn new() -> Self {
318 let mut unsupported_types = HashSet::new();
319
320 for item in [
321 "regclass",
322 "regproc",
323 "regtype",
324 "regtype[]",
325 "regnamespace",
326 "oid",
327 ] {
328 unsupported_types.insert(item.to_owned());
329 unsupported_types.insert(format!("pg_catalog.{item}"));
330 }
331
332 Self { unsupported_types }
333 }
334}
335
336struct RemoveUnsupportedTypesVisitor<'a> {
337 unsupported_types: &'a HashSet<String>,
338}
339
340impl VisitorMut for RemoveUnsupportedTypesVisitor<'_> {
341 type Break = ();
342
343 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
344 match expr {
345 Expr::TypedString(TypedString {
347 data_type,
348 value,
349 uses_odbc_syntax: _,
350 }) => {
351 if self
352 .unsupported_types
353 .contains(data_type.to_string().to_lowercase().as_str())
354 {
355 *expr =
356 Expr::Value(Value::SingleQuotedString(value.to_string()).with_empty_span());
357 }
358 }
359 Expr::Cast {
360 data_type,
361 expr: value,
362 ..
363 } => {
364 if self
365 .unsupported_types
366 .contains(data_type.to_string().to_lowercase().as_str())
367 {
368 *expr = *value.clone();
369 }
370 }
371 _ => {}
373 }
374
375 ControlFlow::Continue(())
376 }
377}
378
379impl SqlStatementRewriteRule for RemoveUnsupportedTypes {
380 fn rewrite(&self, mut statement: Statement) -> Statement {
381 let mut visitor = RemoveUnsupportedTypesVisitor {
382 unsupported_types: &self.unsupported_types,
383 };
384 let _ = statement.visit(&mut visitor);
385 statement
386 }
387}
388
389#[derive(Debug)]
394pub struct RewriteRegclassCastToSubquery(Box<Query>);
395
396impl Default for RewriteRegclassCastToSubquery {
397 fn default() -> Self {
398 Self::new()
399 }
400}
401
402impl RewriteRegclassCastToSubquery {
403 pub fn new() -> Self {
404 let sql = "SELECT c.oid
405FROM pg_catalog.pg_class c
406JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
407CROSS JOIN (SELECT parse_ident($1::TEXT) AS parts) p
408WHERE n.nspname = COALESCE(
409 CASE WHEN array_length(p.parts, 1) > 1 THEN p.parts[1] END,
410 current_schema()
411)
412AND c.relname = p.parts[-1]";
413 let dialect = PostgreSqlDialect {};
414 let query = Parser::parse_sql(&dialect, sql)
415 .map(|mut stmts| {
416 let stmt = stmts.remove(0);
417 if let Statement::Query(query) = stmt {
418 query
419 } else {
420 unreachable!()
421 }
422 })
423 .expect("Failed to parse prepared query");
424 Self(query)
425 }
426}
427
428struct RewriteRegclassCastToSubqueryVisitor(Box<Query>);
429
430impl RewriteRegclassCastToSubqueryVisitor {
431 pub fn new(query: Box<Query>) -> Self {
432 Self(query)
433 }
434
435 fn create_subquery(&self, expr: &Expr) -> Expr {
436 struct PlaceholderReplacer(Expr);
437
438 impl VisitorMut for PlaceholderReplacer {
439 type Break = ();
440
441 fn pre_visit_expr(&mut self, e: &mut Expr) -> ControlFlow<Self::Break> {
442 if let Expr::Value(ValueWithSpan {
443 value: Value::Placeholder(_placeholder),
444 ..
445 }) = e
446 {
447 *e = self.0.clone();
448 }
449 ControlFlow::Continue(())
450 }
451 }
452
453 let mut query = self.0.clone();
454 let mut replacer = PlaceholderReplacer(expr.clone());
455 let _ = query.visit(&mut replacer);
456 Expr::Subquery(query)
457 }
458
459 fn is_regclass_to_oid_cast(&self, expr: &Expr) -> bool {
460 if let Expr::Cast {
461 kind,
462 data_type,
463 expr: inner_expr,
464 format: _,
465 ..
466 } = expr
467 {
468 if *kind == CastKind::DoubleColon {
469 let dt_lower = data_type.to_string().to_lowercase();
470 if dt_lower == "oid" || dt_lower == "pg_catalog.oid" {
471 return self.is_regclass_cast(inner_expr);
472 }
473 }
474 }
475 false
476 }
477
478 fn is_regclass_cast(&self, expr: &Expr) -> bool {
479 if let Expr::Cast {
480 kind,
481 data_type,
482 expr: _,
483 format: _,
484 ..
485 } = expr
486 {
487 if *kind == CastKind::DoubleColon {
488 let dt_lower = data_type.to_string().to_lowercase();
489 return dt_lower == "regclass" || dt_lower == "pg_catalog.regclass";
490 }
491 }
492 false
493 }
494
495 fn extract_inner_expr(&self, expr: &Expr) -> Option<Expr> {
496 if let Expr::Cast {
497 kind,
498 data_type,
499 expr: inner_expr,
500 format: _,
501 ..
502 } = expr
503 {
504 if *kind == CastKind::DoubleColon {
505 let dt_lower = data_type.to_string().to_lowercase();
506 if dt_lower == "oid" || dt_lower == "pg_catalog.oid" {
507 if let Expr::Cast {
508 kind: inner_kind,
509 data_type: inner_data_type,
510 expr: inner_inner_expr,
511 format: _,
512 ..
513 } = inner_expr.as_ref()
514 {
515 if *inner_kind == CastKind::DoubleColon {
516 let inner_dt_lower = inner_data_type.to_string().to_lowercase();
517 if inner_dt_lower == "regclass"
518 || inner_dt_lower == "pg_catalog.regclass"
519 {
520 return Some((**inner_inner_expr).clone());
521 }
522 }
523 }
524 }
525 }
526 }
527 None
528 }
529}
530
531impl VisitorMut for RewriteRegclassCastToSubqueryVisitor {
532 type Break = ();
533
534 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
535 if self.is_regclass_to_oid_cast(expr) {
536 if let Some(inner_expr) = self.extract_inner_expr(expr) {
537 *expr = self.create_subquery(&inner_expr);
538 }
539 }
540 ControlFlow::Continue(())
541 }
542}
543
544impl SqlStatementRewriteRule for RewriteRegclassCastToSubquery {
545 fn rewrite(&self, mut s: Statement) -> Statement {
546 let mut visitor = RewriteRegclassCastToSubqueryVisitor::new(self.0.clone());
547 let _ = s.visit(&mut visitor);
548 s
549 }
550}
551
552#[derive(Debug)]
554pub struct RewriteArrayAnyAllOperation;
555
556struct RewriteArrayAnyAllOperationVisitor;
557
558impl RewriteArrayAnyAllOperationVisitor {
559 fn any_to_array_cofntains(&self, left: &Expr, right: &Expr) -> Expr {
560 let array = if let Expr::Value(ValueWithSpan {
561 value: Value::SingleQuotedString(array_literal),
562 ..
563 }) = right
564 {
565 let array_literal = array_literal.trim();
566 if array_literal.starts_with('{') && array_literal.ends_with('}') {
567 let items = array_literal.trim_matches(|c| c == '{' || c == '}' || c == ' ');
568 let items = items.split(',').map(|s| s.trim()).filter(|s| !s.is_empty());
569
570 let elems = items
572 .map(|s| {
573 Expr::Value(Value::SingleQuotedString(s.to_string()).with_empty_span())
574 })
575 .collect();
576 Expr::Array(Array {
577 elem: elems,
578 named: true,
579 })
580 } else {
581 right.clone()
582 }
583 } else {
584 right.clone()
585 };
586
587 Expr::Function(Function {
588 name: ObjectName::from(vec![Ident::new("array_contains")]),
589 args: FunctionArguments::List(FunctionArgumentList {
590 args: vec![
591 FunctionArg::Unnamed(FunctionArgExpr::Expr(array)),
592 FunctionArg::Unnamed(FunctionArgExpr::Expr(left.clone())),
593 ],
594 duplicate_treatment: None,
595 clauses: vec![],
596 }),
597 uses_odbc_syntax: false,
598 parameters: FunctionArguments::None,
599 filter: None,
600 null_treatment: None,
601 over: None,
602 within_group: vec![],
603 })
604 }
605}
606
607impl VisitorMut for RewriteArrayAnyAllOperationVisitor {
608 type Break = ();
609
610 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
611 match expr {
612 Expr::AnyOp {
613 left,
614 compare_op,
615 right,
616 ..
617 } => match compare_op {
618 BinaryOperator::Eq => {
619 *expr = self.any_to_array_cofntains(left.as_ref(), right.as_ref());
620 }
621 BinaryOperator::NotEq => {
622 }
624 _ => {}
625 },
626 Expr::AllOp {
627 left,
628 compare_op,
629 right,
630 } => match compare_op {
631 BinaryOperator::Eq => {
632 }
634 BinaryOperator::NotEq => {
635 *expr = Expr::UnaryOp {
636 op: UnaryOperator::Not,
637 expr: Box::new(self.any_to_array_cofntains(left.as_ref(), right.as_ref())),
638 }
639 }
640 _ => {}
641 },
642 _ => {}
643 }
644
645 ControlFlow::Continue(())
646 }
647}
648
649impl SqlStatementRewriteRule for RewriteArrayAnyAllOperation {
650 fn rewrite(&self, mut s: Statement) -> Statement {
651 let mut visitor = RewriteArrayAnyAllOperationVisitor;
652
653 let _ = s.visit(&mut visitor);
654
655 s
656 }
657}
658
659#[derive(Debug)]
664pub struct PrependUnqualifiedPgTableName;
665
666struct PrependUnqualifiedPgTableNameVisitor;
667
668impl VisitorMut for PrependUnqualifiedPgTableNameVisitor {
669 type Break = ();
670
671 fn pre_visit_table_factor(
672 &mut self,
673 table_factor: &mut TableFactor,
674 ) -> ControlFlow<Self::Break> {
675 if let TableFactor::Table { name, args, .. } = table_factor {
676 if args.is_none() && name.0.len() == 1 {
678 if let ObjectNamePart::Identifier(ident) = &name.0[0] {
679 if ident.value.starts_with("pg_") {
680 *name = ObjectName(vec![
681 ObjectNamePart::Identifier(Ident::new("pg_catalog")),
682 name.0[0].clone(),
683 ]);
684 }
685 }
686 }
687 }
688
689 ControlFlow::Continue(())
690 }
691}
692
693impl SqlStatementRewriteRule for PrependUnqualifiedPgTableName {
694 fn rewrite(&self, mut s: Statement) -> Statement {
695 let mut visitor = PrependUnqualifiedPgTableNameVisitor;
696
697 let _ = s.visit(&mut visitor);
698 s
699 }
700}
701
702#[derive(Debug)]
703pub struct FixArrayLiteral;
704
705struct FixArrayLiteralVisitor;
706
707impl FixArrayLiteralVisitor {
708 fn is_string_type(dt: &DataType) -> bool {
709 matches!(
710 dt,
711 DataType::Text | DataType::Varchar(_) | DataType::Char(_) | DataType::String(_)
712 )
713 }
714}
715
716impl VisitorMut for FixArrayLiteralVisitor {
717 type Break = ();
718
719 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
720 if let Expr::Cast {
721 kind,
722 expr,
723 data_type,
724 ..
725 } = expr
726 {
727 if kind == &CastKind::DoubleColon {
728 if let DataType::Array(arr) = data_type {
729 if let Expr::Value(ValueWithSpan {
731 value: Value::SingleQuotedString(array_literal),
732 ..
733 }) = expr.as_ref()
734 {
735 let items =
736 array_literal.trim_matches(|c| c == '{' || c == '}' || c == ' ');
737 let items = items.split(',').map(|s| s.trim()).filter(|s| !s.is_empty());
738
739 let is_text = match arr {
740 ArrayElemTypeDef::AngleBracket(dt) => Self::is_string_type(dt.as_ref()),
741 ArrayElemTypeDef::SquareBracket(dt, _) => {
742 Self::is_string_type(dt.as_ref())
743 }
744 ArrayElemTypeDef::Parenthesis(dt) => Self::is_string_type(dt.as_ref()),
745 _ => false,
746 };
747
748 let elems = items
749 .map(|s| {
750 if is_text {
751 Expr::Value(
752 Value::SingleQuotedString(s.to_string()).with_empty_span(),
753 )
754 } else {
755 Expr::Value(
756 Value::Number(s.to_string(), false).with_empty_span(),
757 )
758 }
759 })
760 .collect();
761 **expr = Expr::Array(Array {
762 elem: elems,
763 named: true,
764 });
765 }
766 }
767 }
768 }
769
770 ControlFlow::Continue(())
771 }
772}
773
774impl SqlStatementRewriteRule for FixArrayLiteral {
775 fn rewrite(&self, mut s: Statement) -> Statement {
776 let mut visitor = FixArrayLiteralVisitor;
777
778 let _ = s.visit(&mut visitor);
779 s
780 }
781}
782
783#[derive(Debug)]
790pub struct RemoveQualifier;
791
792struct RemoveQualifierVisitor;
793
794impl VisitorMut for RemoveQualifierVisitor {
795 type Break = ();
796
797 fn pre_visit_table_factor(
798 &mut self,
799 table_factor: &mut TableFactor,
800 ) -> ControlFlow<Self::Break> {
801 if let TableFactor::Table { name, args, .. } = table_factor {
803 if args.is_some() {
804 if name.0.len() > 1 {
806 if let Some(last_ident) = name.0.pop() {
807 *name = ObjectName(vec![last_ident]);
808 }
809 }
810 }
811 }
812 ControlFlow::Continue(())
813 }
814
815 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
816 match expr {
817 Expr::Cast { data_type, .. } => {
818 let data_type_str = data_type.to_string();
820 match data_type_str.as_str() {
821 "pg_catalog.text" => {
822 *data_type = DataType::Text;
823 }
824 "pg_catalog.int2[]" => {
825 *data_type = DataType::Array(ArrayElemTypeDef::SquareBracket(
826 Box::new(DataType::Int16),
827 None,
828 ));
829 }
830 _ => {}
831 }
832 }
833 Expr::Function(function) => {
834 let name = &mut function.name;
836 if name.0.len() > 1 {
837 if let Some(last_ident) = name.0.pop() {
838 *name = ObjectName(vec![last_ident]);
839 }
840 }
841 }
842
843 _ => {}
844 }
845 ControlFlow::Continue(())
846 }
847}
848
849impl SqlStatementRewriteRule for RemoveQualifier {
850 fn rewrite(&self, mut s: Statement) -> Statement {
851 let mut visitor = RemoveQualifierVisitor;
852
853 let _ = s.visit(&mut visitor);
854 s
855 }
856}
857
858#[derive(Debug)]
860pub struct CurrentUserVariableToSessionUserFunctionCall;
861
862struct CurrentUserVariableToSessionUserFunctionCallVisitor;
863
864impl VisitorMut for CurrentUserVariableToSessionUserFunctionCallVisitor {
865 type Break = ();
866
867 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
868 if let Expr::Identifier(ident) = expr {
869 if ident.quote_style.is_none() && ident.value.to_lowercase() == "current_user" {
870 *expr = Expr::Function(Function {
871 name: ObjectName::from(vec![Ident::new("session_user")]),
872 args: FunctionArguments::None,
873 uses_odbc_syntax: false,
874 parameters: FunctionArguments::None,
875 filter: None,
876 null_treatment: None,
877 over: None,
878 within_group: vec![],
879 });
880 }
881 }
882
883 if let Expr::Function(func) = expr {
884 let fname = func
885 .name
886 .0
887 .iter()
888 .map(|ident| ident.to_string())
889 .collect::<Vec<String>>()
890 .join(".");
891 if fname.to_lowercase() == "current_user" {
892 func.name = ObjectName::from(vec![Ident::new("session_user")])
893 }
894 }
895
896 ControlFlow::Continue(())
897 }
898}
899
900impl SqlStatementRewriteRule for CurrentUserVariableToSessionUserFunctionCall {
901 fn rewrite(&self, mut s: Statement) -> Statement {
902 let mut visitor = CurrentUserVariableToSessionUserFunctionCallVisitor;
903
904 let _ = s.visit(&mut visitor);
905 s
906 }
907}
908
909#[derive(Debug)]
911pub struct FixCollate;
912
913struct FixCollateVisitor;
914
915impl VisitorMut for FixCollateVisitor {
916 type Break = ();
917
918 fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
919 match expr {
920 Expr::Collate { expr: inner, .. } => {
921 *expr = inner.as_ref().clone();
922 }
923 Expr::BinaryOp { op, .. } => {
924 if let BinaryOperator::PGCustomBinaryOperator(ops) = op {
925 if *ops == ["pg_catalog", "~"] {
926 *op = BinaryOperator::PGRegexMatch;
927 }
928 }
929 }
930 _ => {}
931 }
932
933 ControlFlow::Continue(())
934 }
935}
936
937impl SqlStatementRewriteRule for FixCollate {
938 fn rewrite(&self, mut s: Statement) -> Statement {
939 let mut visitor = FixCollateVisitor;
940
941 let _ = s.visit(&mut visitor);
942 s
943 }
944}
945
946#[derive(Debug)]
951pub struct RemoveSubqueryFromProjection;
952
953struct RemoveSubqueryFromProjectionVisitor;
954
955impl RemoveSubqueryFromProjectionVisitor {
956 fn has_correlation(&self, query: &Query) -> bool {
957 if let SetExpr::Select(select) = &*query.body {
958 let table_aliases: HashSet<String> = select
959 .from
960 .iter()
961 .flat_map(|twj| {
962 let mut aliases = HashSet::new();
963 Self::collect_table_aliases_from_table_factor(&twj.relation, &mut aliases);
964 for join in &twj.joins {
965 Self::collect_table_aliases_from_table_factor(&join.relation, &mut aliases);
966 }
967 aliases
968 })
969 .collect();
970
971 let mut has_correlation = false;
972 let mut visitor = CorrelationCheckVisitor(&mut has_correlation, &table_aliases);
973 let _ = datafusion::logical_expr::sqlparser::ast::Visit::visit(query, &mut visitor);
974 has_correlation
975 } else {
976 false
977 }
978 }
979
980 fn has_limit(&self, query: &Query) -> bool {
981 query.limit_clause.is_some() || query.fetch.is_some()
982 }
983
984 fn collect_table_aliases_from_table_factor(
985 table_factor: &TableFactor,
986 aliases: &mut HashSet<String>,
987 ) {
988 if let TableFactor::Table {
989 alias: Some(alias), ..
990 } = table_factor
991 {
992 aliases.insert(alias.name.value.clone());
993 }
994 }
995}
996
997struct CorrelationCheckVisitor<'a>(&'a mut bool, &'a HashSet<String>);
998
999impl Visitor for CorrelationCheckVisitor<'_> {
1000 type Break = ();
1001
1002 fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
1003 match expr {
1004 Expr::Value(ValueWithSpan {
1005 value: Value::Placeholder(_placeholder),
1006 ..
1007 }) => {
1008 *self.0 = true;
1009 }
1010 Expr::CompoundIdentifier(idents) => {
1011 if !idents.is_empty() {
1012 let table_name = &idents[0].value;
1013 if !self.1.contains(table_name) {
1014 *self.0 = true;
1015 }
1016 }
1017 }
1018 _ => {}
1019 }
1020 ControlFlow::Continue(())
1021 }
1022}
1023
1024impl VisitorMut for RemoveSubqueryFromProjectionVisitor {
1025 type Break = ();
1026
1027 fn pre_visit_query(&mut self, query: &mut Query) -> ControlFlow<Self::Break> {
1028 if let SetExpr::Select(select) = query.body.as_mut() {
1029 for projection in &mut select.projection {
1030 match projection {
1031 SelectItem::UnnamedExpr(expr) => {
1032 if let Expr::Subquery(subquery) = expr {
1033 if self.has_correlation(subquery) {
1034 *expr = Expr::Value(Value::Null.with_empty_span());
1035 } else if !self.has_limit(subquery) {
1036 subquery.limit_clause = Some(LimitClause::LimitOffset {
1037 limit: Some(Expr::Value(
1038 Value::Number("1".to_string(), false).with_empty_span(),
1039 )),
1040 offset: None,
1041 limit_by: vec![],
1042 });
1043 }
1044 }
1045 }
1046 SelectItem::ExprWithAlias { expr, .. } => {
1047 if let Expr::Subquery(subquery) = expr {
1048 if self.has_correlation(subquery) {
1049 *expr = Expr::Value(Value::Null.with_empty_span());
1050 } else if !self.has_limit(subquery) {
1051 subquery.limit_clause = Some(LimitClause::LimitOffset {
1052 limit: Some(Expr::Value(
1053 Value::Number("1".to_string(), false).with_empty_span(),
1054 )),
1055 offset: None,
1056 limit_by: vec![],
1057 });
1058 }
1059 }
1060 }
1061 _ => {}
1062 }
1063 }
1064 }
1065
1066 ControlFlow::Continue(())
1067 }
1068}
1069
1070impl SqlStatementRewriteRule for RemoveSubqueryFromProjection {
1071 fn rewrite(&self, mut s: Statement) -> Statement {
1072 let mut visitor = RemoveSubqueryFromProjectionVisitor;
1073 let _ = s.visit(&mut visitor);
1074
1075 s
1076 }
1077}
1078
1079#[derive(Debug)]
1081pub struct FixVersionColumnName;
1082
1083struct FixVersionColumnNameVisitor;
1084
1085impl VisitorMut for FixVersionColumnNameVisitor {
1086 type Break = ();
1087
1088 fn pre_visit_query(&mut self, query: &mut Query) -> ControlFlow<Self::Break> {
1089 if let SetExpr::Select(select) = query.body.as_mut() {
1090 for projection in &mut select.projection {
1091 if let SelectItem::UnnamedExpr(Expr::Function(f)) = projection {
1092 if f.name.0.len() == 1 {
1093 if let ObjectNamePart::Identifier(part) = &f.name.0[0] {
1094 if part.value == "version" {
1095 if let FunctionArguments::List(args) = &f.args {
1096 if args.args.is_empty() {
1097 *projection = SelectItem::ExprWithAlias {
1098 expr: Expr::Function(f.clone()),
1099 alias: Ident::new("version"),
1100 }
1101 }
1102 }
1103 }
1104 }
1105 }
1106 }
1107 }
1108 }
1109
1110 ControlFlow::Continue(())
1111 }
1112}
1113
1114impl SqlStatementRewriteRule for FixVersionColumnName {
1115 fn rewrite(&self, mut s: Statement) -> Statement {
1116 let mut visitor = FixVersionColumnNameVisitor;
1117 let _ = s.visit(&mut visitor);
1118
1119 s
1120 }
1121}
1122
1123#[cfg(test)]
1124mod tests {
1125 use super::*;
1126 use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
1127 use datafusion::sql::sqlparser::parser::Parser;
1128 use datafusion::sql::sqlparser::parser::ParserError;
1129 use std::sync::Arc;
1130
1131 fn parse(sql: &str) -> Result<Vec<Statement>, ParserError> {
1132 let dialect = PostgreSqlDialect {};
1133
1134 Parser::parse_sql(&dialect, sql)
1135 }
1136
1137 fn rewrite(mut s: Statement, rules: &[Arc<dyn SqlStatementRewriteRule>]) -> Statement {
1138 for rule in rules {
1139 s = rule.rewrite(s);
1140 }
1141
1142 s
1143 }
1144
1145 macro_rules! assert_rewrite {
1146 ($rules:expr, $orig:expr, $rewt:expr) => {
1147 let sql = $orig;
1148 let statement = parse(sql).expect("Failed to parse").remove(0);
1149
1150 let statement = rewrite(statement, $rules);
1151 assert_eq!(statement.to_string(), $rewt);
1152 };
1153 }
1154
1155 #[test]
1156 fn test_alias_rewrite() {
1157 let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
1158 vec![Arc::new(AliasDuplicatedProjectionRewrite)];
1159
1160 assert_rewrite!(
1161 &rules,
1162 "SELECT n.oid, n.* FROM pg_catalog.pg_namespace n",
1163 "SELECT n.oid AS __alias_oid, n.* FROM pg_catalog.pg_namespace AS n"
1164 );
1165
1166 assert_rewrite!(
1167 &rules,
1168 "SELECT oid, * FROM pg_catalog.pg_namespace",
1169 "SELECT oid AS __alias_oid, * FROM pg_catalog.pg_namespace"
1170 );
1171
1172 assert_rewrite!(
1173 &rules,
1174 "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id",
1175 "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id"
1176 );
1177
1178 let sql = "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname";
1179 let statement = parse(sql).expect("Failed to parse").remove(0);
1180
1181 let statement = rewrite(statement, &rules);
1182 assert_eq!(
1183 statement.to_string(),
1184 "SELECT n.oid AS __alias_oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY nspsname"
1185 );
1186 }
1187
1188 #[test]
1189 fn test_qualifier_prepend() {
1190 let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
1191 vec![Arc::new(ResolveUnqualifiedIdentifer)];
1192
1193 assert_rewrite!(
1194 &rules,
1195 "SELECT n.* FROM pg_catalog.pg_namespace n WHERE nspname = 'pg_catalog' ORDER BY nspname",
1196 "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
1197 );
1198
1199 assert_rewrite!(
1200 &rules,
1201 "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname",
1202 "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname"
1203 );
1204
1205 assert_rewrite!(
1206 &rules,
1207 "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname",
1208 "SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY n.nspsname"
1209 );
1210
1211 assert_rewrite!(&rules,
1212 "SELECT i.*,i.indkey as keys,c.relname,c.relnamespace,c.relam,c.reltablespace,tc.relname as tabrelname,dsc.description FROM pg_catalog.pg_index i INNER JOIN pg_catalog.pg_class c ON c.oid=i.indexrelid INNER JOIN pg_catalog.pg_class tc ON tc.oid=i.indrelid LEFT OUTER JOIN pg_catalog.pg_description dsc ON i.indexrelid=dsc.objoid WHERE i.indrelid=1 ORDER BY tabrelname, c.relname",
1213 "SELECT i.*, i.indkey AS keys, c.relname, c.relnamespace, c.relam, c.reltablespace, tc.relname AS tabrelname, dsc.description FROM pg_catalog.pg_index AS i INNER JOIN pg_catalog.pg_class AS c ON c.oid = i.indexrelid INNER JOIN pg_catalog.pg_class AS tc ON tc.oid = i.indrelid LEFT OUTER JOIN pg_catalog.pg_description AS dsc ON i.indexrelid = dsc.objoid WHERE i.indrelid = 1 ORDER BY tabrelname, c.relname"
1214 );
1215 }
1216
/// With `RemoveQualifier` followed by `RemoveUnsupportedTypes`, `::regclass`
/// casts are dropped, and the `::pg_catalog.regtype::pg_catalog.text` cast
/// chain ends up as a single `::TEXT`. Statements without such casts only see
/// normalized formatting (e.g. `pg_namespace n` -> `pg_namespace AS n`).
#[test]
fn test_remove_unsupported_types() {
    let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![
        Arc::new(RemoveQualifier),
        Arc::new(RemoveUnsupportedTypes::new()),
    ];

    // (input SQL, expected SQL after the rule chain has run)
    let cases = [
        (
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname",
        ),
        (
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.oid = 1 AND n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.oid = 1 AND n.nspname = 'pg_catalog' ORDER BY n.nspname",
        ),
        (
            concat!(
                "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n ",
                "LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 ",
                "AND d.classoid='pg_namespace'::regclass ORDER BY nspname",
            ),
            concat!(
                "SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n ",
                "LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 ",
                "AND d.classoid = 'pg_namespace' ORDER BY nspname",
            ),
        ),
        // No cast to strip: only formatting changes.
        (
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname",
        ),
        // Multi-line input; the regtype/text cast chain inside CASE collapses to ::TEXT.
        (
            concat!(
                "SELECT c.relchecks, c.relkind, c.relhasindex, c.relhasrules, c.relhastriggers, ",
                "c.relrowsecurity, c.relforcerowsecurity, false AS relhasoids, c.relispartition, '', ",
                "c.reltablespace, CASE WHEN c.reloftype = 0 THEN '' ELSE c.reloftype::pg_catalog.regtype::pg_catalog.text END, ",
                "c.relpersistence, c.relreplident, am.amname\n",
                "FROM pg_catalog.pg_class c\n",
                "LEFT JOIN pg_catalog.pg_class tc ON (c.reltoastrelid = tc.oid)\n",
                "LEFT JOIN pg_catalog.pg_am am ON (c.relam = am.oid)\n",
                "WHERE c.oid = '16386'",
            ),
            concat!(
                "SELECT c.relchecks, c.relkind, c.relhasindex, c.relhasrules, c.relhastriggers, ",
                "c.relrowsecurity, c.relforcerowsecurity, false AS relhasoids, c.relispartition, '', ",
                "c.reltablespace, CASE WHEN c.reloftype = 0 THEN '' ELSE c.reloftype::TEXT END, ",
                "c.relpersistence, c.relreplident, am.amname ",
                "FROM pg_catalog.pg_class AS c ",
                "LEFT JOIN pg_catalog.pg_class AS tc ON (c.reltoastrelid = tc.oid) ",
                "LEFT JOIN pg_catalog.pg_am AS am ON (c.relam = am.oid) ",
                "WHERE c.oid = '16386'",
            ),
        ),
    ];

    for &(input, expected) in &cases {
        assert_rewrite!(&rules, input, expected);
    }
}
1258
/// `RewriteRegclassCastToSubquery` replaces `<expr>::regclass::oid` (with or
/// without a `pg_catalog.` prefix on either type) by a scalar subquery. That
/// subquery looks the name up in `pg_class`/`pg_namespace`, falling back to
/// `current_schema()` when the name has a single part.
#[test]
fn test_rewrite_regclass_cast_to_subquery() {
    let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
        vec![Arc::new(RewriteRegclassCastToSubquery::new())];

    // The scalar subquery the rule is expected to emit for the operand `arg`.
    let lookup = |arg: &str| {
        format!(
            "(SELECT c.oid FROM pg_catalog.pg_class AS c \
             JOIN pg_catalog.pg_namespace AS n ON n.oid = c.relnamespace \
             CROSS JOIN (SELECT parse_ident({}::TEXT) AS parts) AS p \
             WHERE n.nspname = COALESCE(CASE WHEN array_length(p.parts, 1) > 1 THEN p.parts[1] END, current_schema()) \
             AND c.relname = p.parts[-1])",
            arg
        )
    };

    // Every qualification variant of the placeholder cast yields the same query.
    let placeholder_expected = format!("SELECT {}", lookup("$1"));
    for &input in &[
        "SELECT $1::regclass::oid",
        "SELECT $1::pg_catalog.regclass::oid",
        "SELECT $1::pg_catalog.regclass::pg_catalog.oid",
    ] {
        assert_rewrite!(&rules, input, placeholder_expected.as_str());
    }

    // A string-literal operand nested inside a WHERE clause.
    let literal_expected = format!(
        "SELECT * FROM pg_catalog.pg_class WHERE oid = {}",
        lookup("'t'")
    );
    assert_rewrite!(
        &rules,
        "SELECT * FROM pg_catalog.pg_class WHERE oid = 't'::pg_catalog.regclass::pg_catalog.oid",
        literal_expected.as_str()
    );
}
1288
/// `RewriteArrayAnyAllOperation` turns `x = ANY(arr)` into
/// `array_contains(arr, x)` and `x <> ALL(arr)` into
/// `NOT array_contains(arr, x)`. A `'{...}'` string operand becomes an
/// `ARRAY[...]` of string elements, and rewrites also apply inside `WHERE`.
#[test]
fn test_any_to_array_contains() {
    let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
        vec![Arc::new(RewriteArrayAnyAllOperation)];

    let cases = [
        (
            "SELECT a = ANY(current_schemas(true))",
            "SELECT array_contains(current_schemas(true), a)",
        ),
        (
            "SELECT a <> ALL(current_schemas(true))",
            "SELECT NOT array_contains(current_schemas(true), a)",
        ),
        (
            "SELECT a = ANY('{r, l, e}')",
            "SELECT array_contains(ARRAY['r', 'l', 'e'], a)",
        ),
        (
            "SELECT a FROM tbl WHERE a = ANY(current_schemas(true))",
            "SELECT a FROM tbl WHERE array_contains(current_schemas(true), a)",
        ),
    ];

    for &(input, expected) in &cases {
        assert_rewrite!(&rules, input, expected);
    }
}
1318
/// `PrependUnqualifiedPgTableName` qualifies a bare catalog table reference
/// such as `pg_namespace` with `pg_catalog.`. References that are already
/// qualified are left alone, and join aliases are preserved.
#[test]
fn test_prepend_unqualified_table_name() {
    let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
        vec![Arc::new(PrependUnqualifiedPgTableName)];

    let cases = [
        // Already qualified: no change.
        (
            "SELECT * FROM pg_catalog.pg_namespace",
            "SELECT * FROM pg_catalog.pg_namespace",
        ),
        (
            "SELECT * FROM pg_namespace",
            "SELECT * FROM pg_catalog.pg_namespace",
        ),
        // Only the unqualified join target is rewritten.
        (
            "SELECT typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN pg_namespace as ns ON ns.oid = oid",
            "SELECT typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN pg_catalog.pg_namespace AS ns ON ns.oid = oid",
        ),
    ];

    for &(input, expected) in &cases {
        assert_rewrite!(&rules, input, expected);
    }
}
1342
/// `FixArrayLiteral` converts a Postgres `'{...}'::T[]` string literal into an
/// `ARRAY[...]::T[]` constructor.
///
/// NOTE(review): the `bool[]` case expects `ARRAY[t, f]`. Bare `t`/`f` are read
/// by SQL as column identifiers, not `true`/`false`. This pins the current
/// rule output; confirm whether boolean literals were intended.
#[test]
fn test_array_literal_fix() {
    let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(FixArrayLiteral)];

    let cases = [
        ("SELECT '{a, abc}'::text[]", "SELECT ARRAY['a', 'abc']::TEXT[]"),
        ("SELECT '{1, 2}'::int[]", "SELECT ARRAY[1, 2]::INT[]"),
        ("SELECT '{t, f}'::bool[]", "SELECT ARRAY[t, f]::BOOL[]"),
    ];

    for &(input, expected) in &cases {
        assert_rewrite!(&rules, input, expected);
    }
}
1365
/// `RemoveQualifier` strips the `pg_catalog.` schema from a table-valued
/// function call in `FROM`.
#[test]
fn test_remove_qualifier_from_table_function() {
    let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(RemoveQualifier)];

    let input = "SELECT * FROM pg_catalog.pg_get_keywords()";
    let expected = "SELECT * FROM pg_get_keywords()";

    assert_rewrite!(&rules, input, expected);
}
1376
/// `CurrentUserVariableToSessionUserFunctionCall` replaces `current_user`
/// with `session_user`. This holds in any letter case and when the keyword
/// appears as a function argument.
#[test]
fn test_current_user() {
    let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
        vec![Arc::new(CurrentUserVariableToSessionUserFunctionCall)];

    let cases = [
        ("SELECT current_user", "SELECT session_user"),
        ("SELECT CURRENT_USER", "SELECT session_user"),
        ("SELECT is_null(current_user)", "SELECT is_null(session_user)"),
    ];

    for &(input, expected) in &cases {
        assert_rewrite!(&rules, input, expected);
    }
}
1392
/// `FixCollate` drops `COLLATE pg_catalog.default` and reduces
/// `OPERATOR(pg_catalog.~)` to a plain `~`. The qualified function call in
/// the `WHERE` clause is left untouched.
#[test]
fn test_collate_fix() {
    let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(FixCollate)];

    let input = concat!(
        "SELECT c.oid, c.relname FROM pg_catalog.pg_class c ",
        "WHERE c.relname OPERATOR(pg_catalog.~) '^(tablename)$' COLLATE pg_catalog.default ",
        "AND pg_catalog.pg_table_is_visible(c.oid) ORDER BY 2, 3;",
    );
    let expected = concat!(
        "SELECT c.oid, c.relname FROM pg_catalog.pg_class AS c ",
        "WHERE c.relname ~ '^(tablename)$' ",
        "AND pg_catalog.pg_table_is_visible(c.oid) ORDER BY 2, 3",
    );

    assert_rewrite!(&rules, input, expected);
}
1399
/// `RemoveSubqueryFromProjection` replaces each projected scalar subquery by
/// `NULL` while keeping its column alias. Both subqueries here reference the
/// outer alias `a`.
#[test]
fn test_remove_subquery() {
    let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
        vec![Arc::new(RemoveSubqueryFromProjection)];

    let input = concat!(
        "SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), ",
        "(SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid, true) FROM pg_catalog.pg_attrdef d ",
        "WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef), ",
        "a.attnotnull, ",
        "(SELECT c.collname FROM pg_catalog.pg_collation c, pg_catalog.pg_type t ",
        "WHERE c.oid = a.attcollation AND t.oid = a.atttypid AND a.attcollation <> t.typcollation LIMIT 1) AS attcollation, ",
        "a.attidentity, a.attgenerated FROM pg_catalog.pg_attribute a ",
        "WHERE a.attrelid = '16384' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum;",
    );
    let expected = concat!(
        "SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), NULL, a.attnotnull, ",
        "NULL AS attcollation, a.attidentity, a.attgenerated FROM pg_catalog.pg_attribute AS a ",
        "WHERE a.attrelid = '16384' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum",
    );

    assert_rewrite!(&rules, input, expected);
}
1409
/// An uncorrelated aggregate subquery in the projection is kept, and
/// `RemoveSubqueryFromProjection` appends `LIMIT 1` to it.
#[test]
fn test_keep_simple_aggregated_subquery() {
    let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
        vec![Arc::new(RemoveSubqueryFromProjection)];

    let input = "SELECT id, (SELECT COUNT(*) FROM pg_catalog.pg_attribute) AS attr_count FROM pg_catalog.pg_class";
    let expected = "SELECT id, (SELECT COUNT(*) FROM pg_catalog.pg_attribute LIMIT 1) AS attr_count FROM pg_catalog.pg_class";

    assert_rewrite!(&rules, input, expected);
}
1420
/// An aggregate subquery that references the outer alias `a` is replaced by
/// `NULL`, and its column alias `count` is kept.
#[test]
fn test_remove_correlated_subquery() {
    let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
        vec![Arc::new(RemoveSubqueryFromProjection)];

    let input = "SELECT a.attname, (SELECT COUNT(*) FROM pg_catalog.pg_attribute WHERE attrelid = a.oid) AS count FROM pg_catalog.pg_attribute a";
    let expected = "SELECT a.attname, NULL AS count FROM pg_catalog.pg_attribute AS a";

    assert_rewrite!(&rules, input, expected);
}
1431
/// Despite its name, this test asserts that `RemoveSubqueryFromProjection`
/// leaves the statement unchanged. The projected subquery is non-aggregated,
/// does not reference the outer query, and already ends in `LIMIT 1`.
///
/// NOTE(review): the name says "remove" but the expectation is a no-op.
/// Consider renaming; it is kept as-is so existing test filters still match.
#[test]
fn test_remove_non_aggregated_subquery() {
    let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
        vec![Arc::new(RemoveSubqueryFromProjection)];

    let unchanged = "SELECT id, (SELECT attname FROM pg_catalog.pg_attribute LIMIT 1) AS first_attr FROM pg_catalog.pg_class";

    assert_rewrite!(&rules, unchanged, unchanged);
}
1442
/// Constant scalar subqueries without a `FROM` clause are kept, and
/// `RemoveSubqueryFromProjection` appends `LIMIT 1` to them.
#[test]
fn test_keep_simple_scalar_subquery() {
    let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
        vec![Arc::new(RemoveSubqueryFromProjection)];

    let cases = [
        (
            "SELECT (SELECT 1) AS constant",
            "SELECT (SELECT 1 LIMIT 1) AS constant",
        ),
        (
            "SELECT (SELECT 'value') AS str_val",
            "SELECT (SELECT 'value' LIMIT 1) AS str_val",
        ),
    ];

    for &(input, expected) in &cases {
        assert_rewrite!(&rules, input, expected);
    }
}
1460
/// `FixVersionColumnName` aliases a bare, argument-less `version()` call as
/// `version`. An explicit alias is kept, and calls with arguments or a
/// qualifier are not touched.
#[test]
fn test_version_rewrite() {
    let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(FixVersionColumnName)];

    let cases = [
        ("SELECT version()", "SELECT version() AS version"),
        ("SELECT version() as foo", "SELECT version() AS foo"),
        ("SELECT version(foo)", "SELECT version(foo)"),
        ("SELECT foo.version()", "SELECT foo.version()"),
    ];

    for &(input, expected) in &cases {
        assert_rewrite!(&rules, input, expected);
    }
}
1472}