Skip to main content

datafusion_pg_catalog/sql/
rules.rs

1use std::collections::HashSet;
2use std::fmt::Debug;
3use std::ops::ControlFlow;
4
5use datafusion::sql::sqlparser::ast::Array;
6use datafusion::sql::sqlparser::ast::ArrayElemTypeDef;
7use datafusion::sql::sqlparser::ast::BinaryOperator;
8use datafusion::sql::sqlparser::ast::CastKind;
9use datafusion::sql::sqlparser::ast::DataType;
10use datafusion::sql::sqlparser::ast::Expr;
11use datafusion::sql::sqlparser::ast::Function;
12use datafusion::sql::sqlparser::ast::FunctionArg;
13use datafusion::sql::sqlparser::ast::FunctionArgExpr;
14use datafusion::sql::sqlparser::ast::FunctionArgumentList;
15use datafusion::sql::sqlparser::ast::FunctionArguments;
16use datafusion::sql::sqlparser::ast::Ident;
17use datafusion::sql::sqlparser::ast::LimitClause;
18use datafusion::sql::sqlparser::ast::ObjectName;
19use datafusion::sql::sqlparser::ast::ObjectNamePart;
20use datafusion::sql::sqlparser::ast::OrderByKind;
21use datafusion::sql::sqlparser::ast::Query;
22use datafusion::sql::sqlparser::ast::Select;
23use datafusion::sql::sqlparser::ast::SelectItem;
24use datafusion::sql::sqlparser::ast::SelectItemQualifiedWildcardKind;
25use datafusion::sql::sqlparser::ast::SetExpr;
26use datafusion::sql::sqlparser::ast::Statement;
27use datafusion::sql::sqlparser::ast::TableFactor;
28use datafusion::sql::sqlparser::ast::TableWithJoins;
29use datafusion::sql::sqlparser::ast::TypedString;
30use datafusion::sql::sqlparser::ast::UnaryOperator;
31use datafusion::sql::sqlparser::ast::Value;
32use datafusion::sql::sqlparser::ast::ValueWithSpan;
33use datafusion::sql::sqlparser::ast::VisitMut;
34use datafusion::sql::sqlparser::ast::Visitor;
35use datafusion::sql::sqlparser::ast::VisitorMut;
36use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
37use datafusion::sql::sqlparser::parser::Parser;
38
39pub trait SqlStatementRewriteRule: Send + Sync + Debug {
40    fn rewrite(&self, s: Statement) -> Statement;
41}
42
43/// Rewrite rule for adding alias to duplicated projection
44///
45/// This rule is to deal with sql like `SELECT n.oid, n.* FROM n`, which is a
46/// valid statement in postgres. But datafusion treat it as illegal because of
47/// duplicated column oid in projection.
48///
49/// This rule will add alias to column, when there is a wildcard found in
50/// projection.
51#[derive(Debug)]
52pub struct AliasDuplicatedProjectionRewrite;
53
54impl AliasDuplicatedProjectionRewrite {
55    // Rewrites a SELECT statement to alias explicit columns from the same table as a qualified wildcard.
56    fn rewrite_select_with_alias(select: &mut Box<Select>) {
57        // 1. Collect all table aliases from qualified wildcards.
58        let mut wildcard_tables = Vec::new();
59        let mut has_simple_wildcard = false;
60        for p in &select.projection {
61            match p {
62                SelectItem::QualifiedWildcard(name, _) => match name {
63                    SelectItemQualifiedWildcardKind::ObjectName(objname) => {
64                        // for n.oid,
65                        let idents = objname
66                            .0
67                            .iter()
68                            .map(|v| v.as_ident().unwrap().value.clone())
69                            .collect::<Vec<_>>()
70                            .join(".");
71
72                        wildcard_tables.push(idents);
73                    }
74                    SelectItemQualifiedWildcardKind::Expr(_expr) => {
75                        // FIXME:
76                    }
77                },
78                SelectItem::Wildcard(_) => {
79                    has_simple_wildcard = true;
80                }
81                _ => {}
82            }
83        }
84
85        // If there are no qualified wildcards, there's nothing to do.
86        if wildcard_tables.is_empty() && !has_simple_wildcard {
87            return;
88        }
89
90        // 2. Rewrite the projection, adding aliases to matching columns.
91        let mut new_projection = vec![];
92        for p in select.projection.drain(..) {
93            match p {
94                SelectItem::UnnamedExpr(expr) => {
95                    let alias_partial = match &expr {
96                        // Case for `oid` (unqualified identifier)
97                        Expr::Identifier(ident) => Some(ident.clone()),
98                        // Case for `n.oid` (compound identifier)
99                        Expr::CompoundIdentifier(idents) => {
100                            // compare every ident but the last
101                            if idents.len() > 1 {
102                                let table_name = &idents[..idents.len() - 1]
103                                    .iter()
104                                    .map(|i| i.value.clone())
105                                    .collect::<Vec<_>>()
106                                    .join(".");
107                                if wildcard_tables.iter().any(|name| name == table_name) {
108                                    Some(idents[idents.len() - 1].clone())
109                                } else {
110                                    None
111                                }
112                            } else {
113                                None
114                            }
115                        }
116                        _ => None,
117                    };
118
119                    if let Some(name) = alias_partial {
120                        let alias = format!("__alias_{name}");
121                        new_projection.push(SelectItem::ExprWithAlias {
122                            expr,
123                            alias: Ident::new(alias),
124                        });
125                    } else {
126                        new_projection.push(SelectItem::UnnamedExpr(expr));
127                    }
128                }
129                // Preserve existing aliases and wildcards.
130                _ => new_projection.push(p),
131            }
132        }
133        select.projection = new_projection;
134    }
135}
136
137impl SqlStatementRewriteRule for AliasDuplicatedProjectionRewrite {
138    fn rewrite(&self, mut statement: Statement) -> Statement {
139        if let Statement::Query(query) = &mut statement {
140            if let SetExpr::Select(select) = query.body.as_mut() {
141                Self::rewrite_select_with_alias(select);
142            }
143        }
144
145        statement
146    }
147}
148
149/// Prepend qualifier for order by or filter when there is qualified wildcard
150///
151/// Postgres allows unqualified identifier in ORDER BY and FILTER but it's not
152/// accepted by datafusion.
153#[derive(Debug)]
154pub struct ResolveUnqualifiedIdentifer;
155
156impl ResolveUnqualifiedIdentifer {
157    fn rewrite_unqualified_identifiers(query: &mut Box<Query>) {
158        if let SetExpr::Select(select) = query.body.as_mut() {
159            // Step 1: Find all table aliases from FROM and JOIN clauses.
160            let table_aliases = Self::get_table_aliases(&select.from);
161
162            // Step 2: Check for a single qualified wildcard in the projection.
163            let qualified_wildcard_alias = Self::get_qualified_wildcard_alias(&select.projection);
164            if qualified_wildcard_alias.is_none() || table_aliases.is_empty() {
165                return; // Conditions not met.
166            }
167
168            let wildcard_alias = qualified_wildcard_alias.unwrap();
169
170            // Step 2.5: Collect all projection aliases to avoid rewriting them
171            let projection_aliases = Self::get_projection_aliases(&select.projection);
172
173            // Step 3: Rewrite expressions in the WHERE and ORDER BY clauses.
174            if let Some(selection) = &mut select.selection {
175                Self::rewrite_expr(
176                    selection,
177                    &wildcard_alias,
178                    &table_aliases,
179                    &projection_aliases,
180                );
181            }
182
183            if let Some(OrderByKind::Expressions(order_by_exprs)) =
184                query.order_by.as_mut().map(|o| &mut o.kind)
185            {
186                for order_by_expr in order_by_exprs {
187                    Self::rewrite_expr(
188                        &mut order_by_expr.expr,
189                        &wildcard_alias,
190                        &table_aliases,
191                        &projection_aliases,
192                    );
193                }
194            }
195        }
196    }
197
198    fn get_table_aliases(tables: &[TableWithJoins]) -> HashSet<String> {
199        let mut aliases = HashSet::new();
200        for table_with_joins in tables {
201            if let TableFactor::Table {
202                alias: Some(alias), ..
203            } = &table_with_joins.relation
204            {
205                aliases.insert(alias.name.value.clone());
206            }
207            for join in &table_with_joins.joins {
208                if let TableFactor::Table {
209                    alias: Some(alias), ..
210                } = &join.relation
211                {
212                    aliases.insert(alias.name.value.clone());
213                }
214            }
215        }
216        aliases
217    }
218
219    fn get_qualified_wildcard_alias(projection: &[SelectItem]) -> Option<String> {
220        let mut qualified_wildcards = projection
221            .iter()
222            .filter_map(|item| {
223                if let SelectItem::QualifiedWildcard(
224                    SelectItemQualifiedWildcardKind::ObjectName(objname),
225                    _,
226                ) = item
227                {
228                    Some(
229                        objname
230                            .0
231                            .iter()
232                            .map(|v| v.as_ident().unwrap().value.clone())
233                            .collect::<Vec<_>>()
234                            .join("."),
235                    )
236                } else {
237                    None
238                }
239            })
240            .collect::<Vec<_>>();
241
242        if qualified_wildcards.len() == 1 {
243            Some(qualified_wildcards.remove(0))
244        } else {
245            None
246        }
247    }
248
249    fn get_projection_aliases(projection: &[SelectItem]) -> HashSet<String> {
250        let mut aliases = HashSet::new();
251        for item in projection {
252            match item {
253                SelectItem::ExprWithAlias { alias, .. } => {
254                    aliases.insert(alias.value.clone());
255                }
256                SelectItem::UnnamedExpr(Expr::Identifier(ident)) => {
257                    aliases.insert(ident.value.clone());
258                }
259                _ => {}
260            }
261        }
262        aliases
263    }
264
265    fn rewrite_expr(
266        expr: &mut Expr,
267        wildcard_alias: &str,
268        table_aliases: &HashSet<String>,
269        projection_aliases: &HashSet<String>,
270    ) {
271        match expr {
272            Expr::Identifier(ident) => {
273                // If the identifier is not a table alias itself and not already aliased in projection, rewrite it.
274                if !table_aliases.contains(&ident.value)
275                    && !projection_aliases.contains(&ident.value)
276                {
277                    *expr = Expr::CompoundIdentifier(vec![
278                        Ident::new(wildcard_alias.to_string()),
279                        ident.clone(),
280                    ]);
281                }
282            }
283            Expr::BinaryOp { left, right, .. } => {
284                Self::rewrite_expr(left, wildcard_alias, table_aliases, projection_aliases);
285                Self::rewrite_expr(right, wildcard_alias, table_aliases, projection_aliases);
286            }
287            // Add more cases for other expression types as needed (e.g., `InList`, `Between`, etc.)
288            _ => {}
289        }
290    }
291}
292
293impl SqlStatementRewriteRule for ResolveUnqualifiedIdentifer {
294    fn rewrite(&self, mut statement: Statement) -> Statement {
295        if let Statement::Query(query) = &mut statement {
296            Self::rewrite_unqualified_identifiers(query);
297        }
298
299        statement
300    }
301}
302
/// Remove type annotations that DataFusion does not support
///
/// Casts and typed string literals targeting one of the Postgres types listed
/// in [`RemoveUnsupportedTypes::new`] are replaced by their underlying value.
/// This covers the bare and the `pg_catalog.`-qualified spelling of each type.
#[derive(Debug)]
pub struct RemoveUnsupportedTypes {
    /// Lower-case type names, each stored bare and `pg_catalog.`-qualified.
    unsupported_types: HashSet<String>,
}

impl Default for RemoveUnsupportedTypes {
    fn default() -> Self {
        RemoveUnsupportedTypes::new()
    }
}

impl RemoveUnsupportedTypes {
    /// Creates the rule with the built-in list of unsupported types.
    pub fn new() -> Self {
        const BASE_TYPES: [&str; 6] = [
            "regclass",
            "regproc",
            "regtype",
            "regtype[]",
            "regnamespace",
            "oid",
        ];

        let unsupported_types = BASE_TYPES
            .iter()
            .copied()
            .flat_map(|name: &str| [name.to_owned(), format!("pg_catalog.{name}")])
            .collect::<HashSet<String>>();

        Self { unsupported_types }
    }
}
335
336struct RemoveUnsupportedTypesVisitor<'a> {
337    unsupported_types: &'a HashSet<String>,
338}
339
340impl VisitorMut for RemoveUnsupportedTypesVisitor<'_> {
341    type Break = ();
342
343    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
344        match expr {
345            // This is the key part: identify constants with type annotations.
346            Expr::TypedString(TypedString {
347                data_type,
348                value,
349                uses_odbc_syntax: _,
350            }) => {
351                if self
352                    .unsupported_types
353                    .contains(data_type.to_string().to_lowercase().as_str())
354                {
355                    *expr =
356                        Expr::Value(Value::SingleQuotedString(value.to_string()).with_empty_span());
357                }
358            }
359            Expr::Cast {
360                data_type,
361                expr: value,
362                ..
363            } => {
364                if self
365                    .unsupported_types
366                    .contains(data_type.to_string().to_lowercase().as_str())
367                {
368                    *expr = *value.clone();
369                }
370            }
371            // Add more match arms for other expression types (e.g., `Function`, `InList`) as needed.
372            _ => {}
373        }
374
375        ControlFlow::Continue(())
376    }
377}
378
379impl SqlStatementRewriteRule for RemoveUnsupportedTypes {
380    fn rewrite(&self, mut statement: Statement) -> Statement {
381        let mut visitor = RemoveUnsupportedTypesVisitor {
382            unsupported_types: &self.unsupported_types,
383        };
384        let _ = statement.visit(&mut visitor);
385        statement
386    }
387}
388
389/// Rewrite regclass::oid cast to subquery
390///
391/// This rewrites patterns like `$1::regclass::oid` to
392/// `(SELECT oid FROM pg_catalog.pg_class WHERE relname = $1)`
393#[derive(Debug)]
394pub struct RewriteRegclassCastToSubquery(Box<Query>);
395
396impl Default for RewriteRegclassCastToSubquery {
397    fn default() -> Self {
398        Self::new()
399    }
400}
401
402impl RewriteRegclassCastToSubquery {
403    pub fn new() -> Self {
404        let sql = "SELECT c.oid
405FROM pg_catalog.pg_class c
406JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
407CROSS JOIN (SELECT parse_ident($1::TEXT) AS parts) p
408WHERE n.nspname = COALESCE(
409    CASE WHEN array_length(p.parts, 1) > 1 THEN p.parts[1] END,
410    current_schema()
411)
412AND c.relname = p.parts[-1]";
413        let dialect = PostgreSqlDialect {};
414        let query = Parser::parse_sql(&dialect, sql)
415            .map(|mut stmts| {
416                let stmt = stmts.remove(0);
417                if let Statement::Query(query) = stmt {
418                    query
419                } else {
420                    unreachable!()
421                }
422            })
423            .expect("Failed to parse prepared query");
424        Self(query)
425    }
426}
427
428struct RewriteRegclassCastToSubqueryVisitor(Box<Query>);
429
430impl RewriteRegclassCastToSubqueryVisitor {
431    pub fn new(query: Box<Query>) -> Self {
432        Self(query)
433    }
434
435    fn create_subquery(&self, expr: &Expr) -> Expr {
436        struct PlaceholderReplacer(Expr);
437
438        impl VisitorMut for PlaceholderReplacer {
439            type Break = ();
440
441            fn pre_visit_expr(&mut self, e: &mut Expr) -> ControlFlow<Self::Break> {
442                if let Expr::Value(ValueWithSpan {
443                    value: Value::Placeholder(_placeholder),
444                    ..
445                }) = e
446                {
447                    *e = self.0.clone();
448                }
449                ControlFlow::Continue(())
450            }
451        }
452
453        let mut query = self.0.clone();
454        let mut replacer = PlaceholderReplacer(expr.clone());
455        let _ = query.visit(&mut replacer);
456        Expr::Subquery(query)
457    }
458
459    fn is_regclass_to_oid_cast(&self, expr: &Expr) -> bool {
460        if let Expr::Cast {
461            kind,
462            data_type,
463            expr: inner_expr,
464            format: _,
465            ..
466        } = expr
467        {
468            if *kind == CastKind::DoubleColon {
469                let dt_lower = data_type.to_string().to_lowercase();
470                if dt_lower == "oid" || dt_lower == "pg_catalog.oid" {
471                    return self.is_regclass_cast(inner_expr);
472                }
473            }
474        }
475        false
476    }
477
478    fn is_regclass_cast(&self, expr: &Expr) -> bool {
479        if let Expr::Cast {
480            kind,
481            data_type,
482            expr: _,
483            format: _,
484            ..
485        } = expr
486        {
487            if *kind == CastKind::DoubleColon {
488                let dt_lower = data_type.to_string().to_lowercase();
489                return dt_lower == "regclass" || dt_lower == "pg_catalog.regclass";
490            }
491        }
492        false
493    }
494
495    fn extract_inner_expr(&self, expr: &Expr) -> Option<Expr> {
496        if let Expr::Cast {
497            kind,
498            data_type,
499            expr: inner_expr,
500            format: _,
501            ..
502        } = expr
503        {
504            if *kind == CastKind::DoubleColon {
505                let dt_lower = data_type.to_string().to_lowercase();
506                if dt_lower == "oid" || dt_lower == "pg_catalog.oid" {
507                    if let Expr::Cast {
508                        kind: inner_kind,
509                        data_type: inner_data_type,
510                        expr: inner_inner_expr,
511                        format: _,
512                        ..
513                    } = inner_expr.as_ref()
514                    {
515                        if *inner_kind == CastKind::DoubleColon {
516                            let inner_dt_lower = inner_data_type.to_string().to_lowercase();
517                            if inner_dt_lower == "regclass"
518                                || inner_dt_lower == "pg_catalog.regclass"
519                            {
520                                return Some((**inner_inner_expr).clone());
521                            }
522                        }
523                    }
524                }
525            }
526        }
527        None
528    }
529}
530
531impl VisitorMut for RewriteRegclassCastToSubqueryVisitor {
532    type Break = ();
533
534    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
535        if self.is_regclass_to_oid_cast(expr) {
536            if let Some(inner_expr) = self.extract_inner_expr(expr) {
537                *expr = self.create_subquery(&inner_expr);
538            }
539        }
540        ControlFlow::Continue(())
541    }
542}
543
544impl SqlStatementRewriteRule for RewriteRegclassCastToSubquery {
545    fn rewrite(&self, mut s: Statement) -> Statement {
546        let mut visitor = RewriteRegclassCastToSubqueryVisitor::new(self.0.clone());
547        let _ = s.visit(&mut visitor);
548        s
549    }
550}
551
552/// Rewrite Postgres's ANY operator to array_contains
553#[derive(Debug)]
554pub struct RewriteArrayAnyAllOperation;
555
556struct RewriteArrayAnyAllOperationVisitor;
557
558impl RewriteArrayAnyAllOperationVisitor {
559    fn any_to_array_cofntains(&self, left: &Expr, right: &Expr) -> Expr {
560        let array = if let Expr::Value(ValueWithSpan {
561            value: Value::SingleQuotedString(array_literal),
562            ..
563        }) = right
564        {
565            let array_literal = array_literal.trim();
566            if array_literal.starts_with('{') && array_literal.ends_with('}') {
567                let items = array_literal.trim_matches(|c| c == '{' || c == '}' || c == ' ');
568                let items = items.split(',').map(|s| s.trim()).filter(|s| !s.is_empty());
569
570                // For now, we assume the data type is string
571                let elems = items
572                    .map(|s| {
573                        Expr::Value(Value::SingleQuotedString(s.to_string()).with_empty_span())
574                    })
575                    .collect();
576                Expr::Array(Array {
577                    elem: elems,
578                    named: true,
579                })
580            } else {
581                right.clone()
582            }
583        } else {
584            right.clone()
585        };
586
587        Expr::Function(Function {
588            name: ObjectName::from(vec![Ident::new("array_contains")]),
589            args: FunctionArguments::List(FunctionArgumentList {
590                args: vec![
591                    FunctionArg::Unnamed(FunctionArgExpr::Expr(array)),
592                    FunctionArg::Unnamed(FunctionArgExpr::Expr(left.clone())),
593                ],
594                duplicate_treatment: None,
595                clauses: vec![],
596            }),
597            uses_odbc_syntax: false,
598            parameters: FunctionArguments::None,
599            filter: None,
600            null_treatment: None,
601            over: None,
602            within_group: vec![],
603        })
604    }
605}
606
607impl VisitorMut for RewriteArrayAnyAllOperationVisitor {
608    type Break = ();
609
610    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
611        match expr {
612            Expr::AnyOp {
613                left,
614                compare_op,
615                right,
616                ..
617            } => match compare_op {
618                BinaryOperator::Eq => {
619                    *expr = self.any_to_array_cofntains(left.as_ref(), right.as_ref());
620                }
621                BinaryOperator::NotEq => {
622                    // TODO:left not equals to any element in array
623                }
624                _ => {}
625            },
626            Expr::AllOp {
627                left,
628                compare_op,
629                right,
630            } => match compare_op {
631                BinaryOperator::Eq => {
632                    // TODO: left equals to every element in array
633                }
634                BinaryOperator::NotEq => {
635                    *expr = Expr::UnaryOp {
636                        op: UnaryOperator::Not,
637                        expr: Box::new(self.any_to_array_cofntains(left.as_ref(), right.as_ref())),
638                    }
639                }
640                _ => {}
641            },
642            _ => {}
643        }
644
645        ControlFlow::Continue(())
646    }
647}
648
649impl SqlStatementRewriteRule for RewriteArrayAnyAllOperation {
650    fn rewrite(&self, mut s: Statement) -> Statement {
651        let mut visitor = RewriteArrayAnyAllOperationVisitor;
652
653        let _ = s.visit(&mut visitor);
654
655        s
656    }
657}
658
659/// Prepend qualifier to table_name
660///
661/// Postgres has pg_catalog in search_path by default so it allow access to
662/// `pg_namespace` without `pg_catalog.` qualifier
663#[derive(Debug)]
664pub struct PrependUnqualifiedPgTableName;
665
666struct PrependUnqualifiedPgTableNameVisitor;
667
668impl VisitorMut for PrependUnqualifiedPgTableNameVisitor {
669    type Break = ();
670
671    fn pre_visit_table_factor(
672        &mut self,
673        table_factor: &mut TableFactor,
674    ) -> ControlFlow<Self::Break> {
675        if let TableFactor::Table { name, args, .. } = table_factor {
676            // not a table function
677            if args.is_none() && name.0.len() == 1 {
678                if let ObjectNamePart::Identifier(ident) = &name.0[0] {
679                    if ident.value.starts_with("pg_") {
680                        *name = ObjectName(vec![
681                            ObjectNamePart::Identifier(Ident::new("pg_catalog")),
682                            name.0[0].clone(),
683                        ]);
684                    }
685                }
686            }
687        }
688
689        ControlFlow::Continue(())
690    }
691}
692
693impl SqlStatementRewriteRule for PrependUnqualifiedPgTableName {
694    fn rewrite(&self, mut s: Statement) -> Statement {
695        let mut visitor = PrependUnqualifiedPgTableNameVisitor;
696
697        let _ = s.visit(&mut visitor);
698        s
699    }
700}
701
702#[derive(Debug)]
703pub struct FixArrayLiteral;
704
705struct FixArrayLiteralVisitor;
706
707impl FixArrayLiteralVisitor {
708    fn is_string_type(dt: &DataType) -> bool {
709        matches!(
710            dt,
711            DataType::Text | DataType::Varchar(_) | DataType::Char(_) | DataType::String(_)
712        )
713    }
714}
715
716impl VisitorMut for FixArrayLiteralVisitor {
717    type Break = ();
718
719    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
720        if let Expr::Cast {
721            kind,
722            expr,
723            data_type,
724            ..
725        } = expr
726        {
727            if kind == &CastKind::DoubleColon {
728                if let DataType::Array(arr) = data_type {
729                    // cast some to
730                    if let Expr::Value(ValueWithSpan {
731                        value: Value::SingleQuotedString(array_literal),
732                        ..
733                    }) = expr.as_ref()
734                    {
735                        let items =
736                            array_literal.trim_matches(|c| c == '{' || c == '}' || c == ' ');
737                        let items = items.split(',').map(|s| s.trim()).filter(|s| !s.is_empty());
738
739                        let is_text = match arr {
740                            ArrayElemTypeDef::AngleBracket(dt) => Self::is_string_type(dt.as_ref()),
741                            ArrayElemTypeDef::SquareBracket(dt, _) => {
742                                Self::is_string_type(dt.as_ref())
743                            }
744                            ArrayElemTypeDef::Parenthesis(dt) => Self::is_string_type(dt.as_ref()),
745                            _ => false,
746                        };
747
748                        let elems = items
749                            .map(|s| {
750                                if is_text {
751                                    Expr::Value(
752                                        Value::SingleQuotedString(s.to_string()).with_empty_span(),
753                                    )
754                                } else {
755                                    Expr::Value(
756                                        Value::Number(s.to_string(), false).with_empty_span(),
757                                    )
758                                }
759                            })
760                            .collect();
761                        **expr = Expr::Array(Array {
762                            elem: elems,
763                            named: true,
764                        });
765                    }
766                }
767            }
768        }
769
770        ControlFlow::Continue(())
771    }
772}
773
774impl SqlStatementRewriteRule for FixArrayLiteral {
775    fn rewrite(&self, mut s: Statement) -> Statement {
776        let mut visitor = FixArrayLiteralVisitor;
777
778        let _ = s.visit(&mut visitor);
779        s
780    }
781}
782
783/// Remove qualifier from unsupported items
784///
785/// This rewriter removes qualifier from following items:
786/// 1. type cast: for example: `pg_catalog.text`
787/// 2. function name: for example: `pg_catalog.array_to_string`,
788/// 3. table function name
789#[derive(Debug)]
790pub struct RemoveQualifier;
791
792struct RemoveQualifierVisitor;
793
794impl VisitorMut for RemoveQualifierVisitor {
795    type Break = ();
796
797    fn pre_visit_table_factor(
798        &mut self,
799        table_factor: &mut TableFactor,
800    ) -> ControlFlow<Self::Break> {
801        // remove table function qualifier
802        if let TableFactor::Table { name, args, .. } = table_factor {
803            if args.is_some() {
804                //  multiple idents in name, which means it's a qualified table name
805                if name.0.len() > 1 {
806                    if let Some(last_ident) = name.0.pop() {
807                        *name = ObjectName(vec![last_ident]);
808                    }
809                }
810            }
811        }
812        ControlFlow::Continue(())
813    }
814
815    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
816        match expr {
817            Expr::Cast { data_type, .. } => {
818                // rewrite custom pg_catalog. qualified types
819                let data_type_str = data_type.to_string();
820                match data_type_str.as_str() {
821                    "pg_catalog.text" => {
822                        *data_type = DataType::Text;
823                    }
824                    "pg_catalog.int2[]" => {
825                        *data_type = DataType::Array(ArrayElemTypeDef::SquareBracket(
826                            Box::new(DataType::Int16),
827                            None,
828                        ));
829                    }
830                    _ => {}
831                }
832            }
833            Expr::Function(function) => {
834                // remove qualifier from pg_catalog.function
835                let name = &mut function.name;
836                if name.0.len() > 1 {
837                    if let Some(last_ident) = name.0.pop() {
838                        *name = ObjectName(vec![last_ident]);
839                    }
840                }
841            }
842
843            _ => {}
844        }
845        ControlFlow::Continue(())
846    }
847}
848
849impl SqlStatementRewriteRule for RemoveQualifier {
850    fn rewrite(&self, mut s: Statement) -> Statement {
851        let mut visitor = RemoveQualifierVisitor;
852
853        let _ = s.visit(&mut visitor);
854        s
855    }
856}
857
858/// Replace `current_user` with `session_user()`
859#[derive(Debug)]
860pub struct CurrentUserVariableToSessionUserFunctionCall;
861
862struct CurrentUserVariableToSessionUserFunctionCallVisitor;
863
864impl VisitorMut for CurrentUserVariableToSessionUserFunctionCallVisitor {
865    type Break = ();
866
867    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
868        if let Expr::Identifier(ident) = expr {
869            if ident.quote_style.is_none() && ident.value.to_lowercase() == "current_user" {
870                *expr = Expr::Function(Function {
871                    name: ObjectName::from(vec![Ident::new("session_user")]),
872                    args: FunctionArguments::None,
873                    uses_odbc_syntax: false,
874                    parameters: FunctionArguments::None,
875                    filter: None,
876                    null_treatment: None,
877                    over: None,
878                    within_group: vec![],
879                });
880            }
881        }
882
883        if let Expr::Function(func) = expr {
884            let fname = func
885                .name
886                .0
887                .iter()
888                .map(|ident| ident.to_string())
889                .collect::<Vec<String>>()
890                .join(".");
891            if fname.to_lowercase() == "current_user" {
892                func.name = ObjectName::from(vec![Ident::new("session_user")])
893            }
894        }
895
896        ControlFlow::Continue(())
897    }
898}
899
900impl SqlStatementRewriteRule for CurrentUserVariableToSessionUserFunctionCall {
901    fn rewrite(&self, mut s: Statement) -> Statement {
902        let mut visitor = CurrentUserVariableToSessionUserFunctionCallVisitor;
903
904        let _ = s.visit(&mut visitor);
905        s
906    }
907}
908
909/// Fix collate and regex calls
910#[derive(Debug)]
911pub struct FixCollate;
912
913struct FixCollateVisitor;
914
915impl VisitorMut for FixCollateVisitor {
916    type Break = ();
917
918    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
919        match expr {
920            Expr::Collate { expr: inner, .. } => {
921                *expr = inner.as_ref().clone();
922            }
923            Expr::BinaryOp { op, .. } => {
924                if let BinaryOperator::PGCustomBinaryOperator(ops) = op {
925                    if *ops == ["pg_catalog", "~"] {
926                        *op = BinaryOperator::PGRegexMatch;
927                    }
928                }
929            }
930            _ => {}
931        }
932
933        ControlFlow::Continue(())
934    }
935}
936
937impl SqlStatementRewriteRule for FixCollate {
938    fn rewrite(&self, mut s: Statement) -> Statement {
939        let mut visitor = FixCollateVisitor;
940
941        let _ = s.visit(&mut visitor);
942        s
943    }
944}
945
946/// A processor to replace unsupported subquery from projection with NULL.
947///
948/// It will also add `LIMIT 1` to supported subquery to ensure it returns scalar
949/// value.
950#[derive(Debug)]
951pub struct RemoveSubqueryFromProjection;
952
953struct RemoveSubqueryFromProjectionVisitor;
954
955impl RemoveSubqueryFromProjectionVisitor {
956    fn has_correlation(&self, query: &Query) -> bool {
957        if let SetExpr::Select(select) = &*query.body {
958            let table_aliases: HashSet<String> = select
959                .from
960                .iter()
961                .flat_map(|twj| {
962                    let mut aliases = HashSet::new();
963                    Self::collect_table_aliases_from_table_factor(&twj.relation, &mut aliases);
964                    for join in &twj.joins {
965                        Self::collect_table_aliases_from_table_factor(&join.relation, &mut aliases);
966                    }
967                    aliases
968                })
969                .collect();
970
971            let mut has_correlation = false;
972            let mut visitor = CorrelationCheckVisitor(&mut has_correlation, &table_aliases);
973            let _ = datafusion::logical_expr::sqlparser::ast::Visit::visit(query, &mut visitor);
974            has_correlation
975        } else {
976            false
977        }
978    }
979
980    fn has_limit(&self, query: &Query) -> bool {
981        query.limit_clause.is_some() || query.fetch.is_some()
982    }
983
984    fn collect_table_aliases_from_table_factor(
985        table_factor: &TableFactor,
986        aliases: &mut HashSet<String>,
987    ) {
988        if let TableFactor::Table {
989            alias: Some(alias), ..
990        } = table_factor
991        {
992            aliases.insert(alias.name.value.clone());
993        }
994    }
995}
996
997struct CorrelationCheckVisitor<'a>(&'a mut bool, &'a HashSet<String>);
998
999impl Visitor for CorrelationCheckVisitor<'_> {
1000    type Break = ();
1001
1002    fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
1003        match expr {
1004            Expr::Value(ValueWithSpan {
1005                value: Value::Placeholder(_placeholder),
1006                ..
1007            }) => {
1008                *self.0 = true;
1009            }
1010            Expr::CompoundIdentifier(idents) => {
1011                if !idents.is_empty() {
1012                    let table_name = &idents[0].value;
1013                    if !self.1.contains(table_name) {
1014                        *self.0 = true;
1015                    }
1016                }
1017            }
1018            _ => {}
1019        }
1020        ControlFlow::Continue(())
1021    }
1022}
1023
1024impl VisitorMut for RemoveSubqueryFromProjectionVisitor {
1025    type Break = ();
1026
1027    fn pre_visit_query(&mut self, query: &mut Query) -> ControlFlow<Self::Break> {
1028        if let SetExpr::Select(select) = query.body.as_mut() {
1029            for projection in &mut select.projection {
1030                match projection {
1031                    SelectItem::UnnamedExpr(expr) => {
1032                        if let Expr::Subquery(subquery) = expr {
1033                            if self.has_correlation(subquery) {
1034                                *expr = Expr::Value(Value::Null.with_empty_span());
1035                            } else if !self.has_limit(subquery) {
1036                                subquery.limit_clause = Some(LimitClause::LimitOffset {
1037                                    limit: Some(Expr::Value(
1038                                        Value::Number("1".to_string(), false).with_empty_span(),
1039                                    )),
1040                                    offset: None,
1041                                    limit_by: vec![],
1042                                });
1043                            }
1044                        }
1045                    }
1046                    SelectItem::ExprWithAlias { expr, .. } => {
1047                        if let Expr::Subquery(subquery) = expr {
1048                            if self.has_correlation(subquery) {
1049                                *expr = Expr::Value(Value::Null.with_empty_span());
1050                            } else if !self.has_limit(subquery) {
1051                                subquery.limit_clause = Some(LimitClause::LimitOffset {
1052                                    limit: Some(Expr::Value(
1053                                        Value::Number("1".to_string(), false).with_empty_span(),
1054                                    )),
1055                                    offset: None,
1056                                    limit_by: vec![],
1057                                });
1058                            }
1059                        }
1060                    }
1061                    _ => {}
1062                }
1063            }
1064        }
1065
1066        ControlFlow::Continue(())
1067    }
1068}
1069
1070impl SqlStatementRewriteRule for RemoveSubqueryFromProjection {
1071    fn rewrite(&self, mut s: Statement) -> Statement {
1072        let mut visitor = RemoveSubqueryFromProjectionVisitor;
1073        let _ = s.visit(&mut visitor);
1074
1075        s
1076    }
1077}
1078
1079/// `select version()` should return column named `version` not `version()`
1080#[derive(Debug)]
1081pub struct FixVersionColumnName;
1082
1083struct FixVersionColumnNameVisitor;
1084
1085impl VisitorMut for FixVersionColumnNameVisitor {
1086    type Break = ();
1087
1088    fn pre_visit_query(&mut self, query: &mut Query) -> ControlFlow<Self::Break> {
1089        if let SetExpr::Select(select) = query.body.as_mut() {
1090            for projection in &mut select.projection {
1091                if let SelectItem::UnnamedExpr(Expr::Function(f)) = projection {
1092                    if f.name.0.len() == 1 {
1093                        if let ObjectNamePart::Identifier(part) = &f.name.0[0] {
1094                            if part.value == "version" {
1095                                if let FunctionArguments::List(args) = &f.args {
1096                                    if args.args.is_empty() {
1097                                        *projection = SelectItem::ExprWithAlias {
1098                                            expr: Expr::Function(f.clone()),
1099                                            alias: Ident::new("version"),
1100                                        }
1101                                    }
1102                                }
1103                            }
1104                        }
1105                    }
1106                }
1107            }
1108        }
1109
1110        ControlFlow::Continue(())
1111    }
1112}
1113
1114impl SqlStatementRewriteRule for FixVersionColumnName {
1115    fn rewrite(&self, mut s: Statement) -> Statement {
1116        let mut visitor = FixVersionColumnNameVisitor;
1117        let _ = s.visit(&mut visitor);
1118
1119        s
1120    }
1121}
1122
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
    use datafusion::sql::sqlparser::parser::Parser;
    use datafusion::sql::sqlparser::parser::ParserError;
    use std::sync::Arc;

    /// Parse `sql` with the PostgreSQL dialect into a list of statements.
    fn parse(sql: &str) -> Result<Vec<Statement>, ParserError> {
        let dialect = PostgreSqlDialect {};

        Parser::parse_sql(&dialect, sql)
    }

    /// Apply every rule in `rules` to `s`, in slice order, and return the result.
    fn rewrite(mut s: Statement, rules: &[Arc<dyn SqlStatementRewriteRule>]) -> Statement {
        for rule in rules {
            s = rule.rewrite(s);
        }

        s
    }

    /// Parse `$orig`, take its first statement, run `$rules` over it and
    /// assert that the rendered statement equals `$rewt`.
    macro_rules! assert_rewrite {
        ($rules:expr, $orig:expr, $rewt:expr) => {
            let sql = $orig;
            let statement = parse(sql).expect("Failed to parse").remove(0);

            let statement = rewrite(statement, $rules);
            assert_eq!(statement.to_string(), $rewt);
        };
    }

    // A projected column is aliased as `__alias_<name>` when a wildcard of the
    // same relation appears alongside it; the third case (`t1.oid` vs `t2.*`)
    // shows that wildcards of other relations do not trigger the alias.
    #[test]
    fn test_alias_rewrite() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(AliasDuplicatedProjectionRewrite)];

        assert_rewrite!(
            &rules,
            "SELECT n.oid, n.* FROM pg_catalog.pg_namespace n",
            "SELECT n.oid AS __alias_oid, n.* FROM pg_catalog.pg_namespace AS n"
        );

        assert_rewrite!(
            &rules,
            "SELECT oid, * FROM pg_catalog.pg_namespace",
            "SELECT oid AS __alias_oid, * FROM pg_catalog.pg_namespace"
        );

        assert_rewrite!(
            &rules,
            "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id",
            "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id"
        );

        let sql = "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname";
        let statement = parse(sql).expect("Failed to parse").remove(0);

        let statement = rewrite(statement, &rules);
        assert_eq!(
            statement.to_string(),
            "SELECT n.oid AS __alias_oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY nspsname"
        );
    }

    // Unqualified column references get the table alias prepended when the
    // query has an aliased relation; a query without aliases is left as is.
    #[test]
    fn test_qualifier_prepend() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(ResolveUnqualifiedIdentifer)];

        assert_rewrite!(
            &rules,
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE nspname = 'pg_catalog' ORDER BY nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
        );

        assert_rewrite!(
            &rules,
            "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname",
            "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname"
        );

        assert_rewrite!(
            &rules,
            "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname",
            "SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY n.nspsname"
        );

        assert_rewrite!(&rules,
            "SELECT i.*,i.indkey as keys,c.relname,c.relnamespace,c.relam,c.reltablespace,tc.relname as tabrelname,dsc.description FROM pg_catalog.pg_index i INNER JOIN pg_catalog.pg_class c ON c.oid=i.indexrelid INNER JOIN pg_catalog.pg_class tc ON tc.oid=i.indrelid LEFT OUTER JOIN pg_catalog.pg_description dsc ON i.indexrelid=dsc.objoid WHERE i.indrelid=1 ORDER BY tabrelname, c.relname",
            "SELECT i.*, i.indkey AS keys, c.relname, c.relnamespace, c.relam, c.reltablespace, tc.relname AS tabrelname, dsc.description FROM pg_catalog.pg_index AS i INNER JOIN pg_catalog.pg_class AS c ON c.oid = i.indexrelid INNER JOIN pg_catalog.pg_class AS tc ON tc.oid = i.indrelid LEFT OUTER JOIN pg_catalog.pg_description AS dsc ON i.indexrelid = dsc.objoid WHERE i.indrelid = 1 ORDER BY tabrelname, c.relname"
        );
    }

    // `::regclass` / `::regtype` casts are dropped, and a remaining
    // `::pg_catalog.text` cast is normalised to `::TEXT`.
    #[test]
    fn test_remove_unsupported_types() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![
            Arc::new(RemoveQualifier),
            Arc::new(RemoveUnsupportedTypes::new()),
        ];

        assert_rewrite!(
            &rules,
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
        );

        assert_rewrite!(
            &rules,
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.oid = 1 AND n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.oid = 1 AND n.nspname = 'pg_catalog' ORDER BY n.nspname"
        );

        assert_rewrite!(
            &rules,
            "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace'::regclass ORDER BY nspname",
            "SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY nspname"
        );

        assert_rewrite!(
            &rules,
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
        );

        assert_rewrite!(
            &rules,
            "SELECT c.relchecks, c.relkind, c.relhasindex, c.relhasrules, c.relhastriggers, c.relrowsecurity, c.relforcerowsecurity, false AS relhasoids, c.relispartition, '', c.reltablespace, CASE WHEN c.reloftype = 0 THEN '' ELSE c.reloftype::pg_catalog.regtype::pg_catalog.text END, c.relpersistence, c.relreplident, am.amname
    FROM pg_catalog.pg_class c
     LEFT JOIN pg_catalog.pg_class tc ON (c.reltoastrelid = tc.oid)
    LEFT JOIN pg_catalog.pg_am am ON (c.relam = am.oid)
    WHERE c.oid = '16386'",
            "SELECT c.relchecks, c.relkind, c.relhasindex, c.relhasrules, c.relhastriggers, c.relrowsecurity, c.relforcerowsecurity, false AS relhasoids, c.relispartition, '', c.reltablespace, CASE WHEN c.reloftype = 0 THEN '' ELSE c.reloftype::TEXT END, c.relpersistence, c.relreplident, am.amname FROM pg_catalog.pg_class AS c LEFT JOIN pg_catalog.pg_class AS tc ON (c.reltoastrelid = tc.oid) LEFT JOIN pg_catalog.pg_am AS am ON (c.relam = am.oid) WHERE c.oid = '16386'"
        );
    }

    // `<expr>::regclass::oid` (with or without `pg_catalog.` qualifiers) is
    // replaced by a scalar subquery that resolves the name via pg_class and
    // pg_namespace.
    #[test]
    fn test_rewrite_regclass_cast_to_subquery() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(RewriteRegclassCastToSubquery::new())];

        assert_rewrite!(
            &rules,
            "SELECT $1::regclass::oid",
            "SELECT (SELECT c.oid FROM pg_catalog.pg_class AS c JOIN pg_catalog.pg_namespace AS n ON n.oid = c.relnamespace CROSS JOIN (SELECT parse_ident($1::TEXT) AS parts) AS p WHERE n.nspname = COALESCE(CASE WHEN array_length(p.parts, 1) > 1 THEN p.parts[1] END, current_schema()) AND c.relname = p.parts[-1])"
        );

        assert_rewrite!(
            &rules,
            "SELECT $1::pg_catalog.regclass::oid",
            "SELECT (SELECT c.oid FROM pg_catalog.pg_class AS c JOIN pg_catalog.pg_namespace AS n ON n.oid = c.relnamespace CROSS JOIN (SELECT parse_ident($1::TEXT) AS parts) AS p WHERE n.nspname = COALESCE(CASE WHEN array_length(p.parts, 1) > 1 THEN p.parts[1] END, current_schema()) AND c.relname = p.parts[-1])"
        );

        assert_rewrite!(
            &rules,
            "SELECT $1::pg_catalog.regclass::pg_catalog.oid",
            "SELECT (SELECT c.oid FROM pg_catalog.pg_class AS c JOIN pg_catalog.pg_namespace AS n ON n.oid = c.relnamespace CROSS JOIN (SELECT parse_ident($1::TEXT) AS parts) AS p WHERE n.nspname = COALESCE(CASE WHEN array_length(p.parts, 1) > 1 THEN p.parts[1] END, current_schema()) AND c.relname = p.parts[-1])"
        );

        assert_rewrite!(
            &rules,
            "SELECT * FROM pg_catalog.pg_class WHERE oid = 't'::pg_catalog.regclass::pg_catalog.oid",
            "SELECT * FROM pg_catalog.pg_class WHERE oid = (SELECT c.oid FROM pg_catalog.pg_class AS c JOIN pg_catalog.pg_namespace AS n ON n.oid = c.relnamespace CROSS JOIN (SELECT parse_ident('t'::TEXT) AS parts) AS p WHERE n.nspname = COALESCE(CASE WHEN array_length(p.parts, 1) > 1 THEN p.parts[1] END, current_schema()) AND c.relname = p.parts[-1])"
        );
    }

    // `x = ANY(arr)` becomes `array_contains(arr, x)` and `x <> ALL(arr)`
    // becomes its negation; a `'{...}'` literal is expanded to `ARRAY[...]`.
    #[test]
    fn test_any_to_array_contains() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(RewriteArrayAnyAllOperation)];

        assert_rewrite!(
            &rules,
            "SELECT a = ANY(current_schemas(true))",
            "SELECT array_contains(current_schemas(true), a)"
        );

        assert_rewrite!(
            &rules,
            "SELECT a <> ALL(current_schemas(true))",
            "SELECT NOT array_contains(current_schemas(true), a)"
        );

        assert_rewrite!(
            &rules,
            "SELECT a = ANY('{r, l, e}')",
            "SELECT array_contains(ARRAY['r', 'l', 'e'], a)"
        );

        assert_rewrite!(
            &rules,
            "SELECT a FROM tbl WHERE a = ANY(current_schemas(true))",
            "SELECT a FROM tbl WHERE array_contains(current_schemas(true), a)"
        );
    }

    // Unqualified `pg_*` catalog tables are prefixed with `pg_catalog.`;
    // already-qualified names are left untouched.
    #[test]
    fn test_prepend_unqualified_table_name() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(PrependUnqualifiedPgTableName)];

        assert_rewrite!(
            &rules,
            "SELECT * FROM pg_catalog.pg_namespace",
            "SELECT * FROM pg_catalog.pg_namespace"
        );

        assert_rewrite!(
            &rules,
            "SELECT * FROM pg_namespace",
            "SELECT * FROM pg_catalog.pg_namespace"
        );

        assert_rewrite!(
            &rules,
            "SELECT typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN pg_namespace as ns ON ns.oid = oid",
            "SELECT typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN pg_catalog.pg_namespace AS ns ON ns.oid = oid"
        );
    }

    // `'{...}'::T[]` string literals are turned into `ARRAY[...]` expressions;
    // elements are quoted only for the text element type.
    #[test]
    fn test_array_literal_fix() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(FixArrayLiteral)];

        assert_rewrite!(
            &rules,
            "SELECT '{a, abc}'::text[]",
            "SELECT ARRAY['a', 'abc']::TEXT[]"
        );

        assert_rewrite!(
            &rules,
            "SELECT '{1, 2}'::int[]",
            "SELECT ARRAY[1, 2]::INT[]"
        );

        assert_rewrite!(
            &rules,
            "SELECT '{t, f}'::bool[]",
            "SELECT ARRAY[t, f]::BOOL[]"
        );
    }

    // The `pg_catalog.` qualifier is also removed from table functions in FROM.
    #[test]
    fn test_remove_qualifier_from_table_function() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(RemoveQualifier)];

        assert_rewrite!(
            &rules,
            "SELECT * FROM pg_catalog.pg_get_keywords()",
            "SELECT * FROM pg_get_keywords()"
        );
    }

    // `current_user` in any case, including as a function argument, renders
    // as `session_user`.
    #[test]
    fn test_current_user() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(CurrentUserVariableToSessionUserFunctionCall)];

        assert_rewrite!(&rules, "SELECT current_user", "SELECT session_user");

        assert_rewrite!(&rules, "SELECT CURRENT_USER", "SELECT session_user");

        assert_rewrite!(
            &rules,
            "SELECT is_null(current_user)",
            "SELECT is_null(session_user)"
        );
    }

    // `OPERATOR(pg_catalog.~)` becomes `~` and the `COLLATE` clause is dropped.
    #[test]
    fn test_collate_fix() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(FixCollate)];

        assert_rewrite!(&rules, "SELECT c.oid, c.relname FROM pg_catalog.pg_class c WHERE c.relname OPERATOR(pg_catalog.~) '^(tablename)$' COLLATE pg_catalog.default AND pg_catalog.pg_table_is_visible(c.oid) ORDER BY 2, 3;", "SELECT c.oid, c.relname FROM pg_catalog.pg_class AS c WHERE c.relname ~ '^(tablename)$' AND pg_catalog.pg_table_is_visible(c.oid) ORDER BY 2, 3");
    }

    // Both projected subqueries reference the outer alias `a`, so both are
    // replaced by NULL (keeping the `AS attcollation` alias).
    #[test]
    fn test_remove_subquery() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(RemoveSubqueryFromProjection)];

        assert_rewrite!(&rules,
            "SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), (SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid, true) FROM pg_catalog.pg_attrdef d WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef), a.attnotnull, (SELECT c.collname FROM pg_catalog.pg_collation c, pg_catalog.pg_type t WHERE c.oid = a.attcollation AND t.oid = a.atttypid AND a.attcollation <> t.typcollation LIMIT 1) AS attcollation, a.attidentity, a.attgenerated FROM pg_catalog.pg_attribute a WHERE a.attrelid = '16384' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum;",
            "SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), NULL, a.attnotnull, NULL AS attcollation, a.attidentity, a.attgenerated FROM pg_catalog.pg_attribute AS a WHERE a.attrelid = '16384' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum");
    }

    // An uncorrelated subquery without a limit is kept and gets `LIMIT 1`.
    #[test]
    fn test_keep_simple_aggregated_subquery() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(RemoveSubqueryFromProjection)];

        assert_rewrite!(&rules,
            "SELECT id, (SELECT COUNT(*) FROM pg_catalog.pg_attribute) AS attr_count FROM pg_catalog.pg_class",
            "SELECT id, (SELECT COUNT(*) FROM pg_catalog.pg_attribute LIMIT 1) AS attr_count FROM pg_catalog.pg_class"
        );
    }

    // A subquery referring to the outer alias `a` is replaced by NULL.
    #[test]
    fn test_remove_correlated_subquery() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(RemoveSubqueryFromProjection)];

        assert_rewrite!(&rules,
            "SELECT a.attname, (SELECT COUNT(*) FROM pg_catalog.pg_attribute WHERE attrelid = a.oid) AS count FROM pg_catalog.pg_attribute a",
            "SELECT a.attname, NULL AS count FROM pg_catalog.pg_attribute AS a"
        );
    }

    // An uncorrelated subquery that already has a LIMIT is left unchanged.
    #[test]
    fn test_remove_non_aggregated_subquery() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(RemoveSubqueryFromProjection)];

        assert_rewrite!(&rules,
            "SELECT id, (SELECT attname FROM pg_catalog.pg_attribute LIMIT 1) AS first_attr FROM pg_catalog.pg_class",
            "SELECT id, (SELECT attname FROM pg_catalog.pg_attribute LIMIT 1) AS first_attr FROM pg_catalog.pg_class"
        );
    }

    // Constant subqueries without FROM are kept and get `LIMIT 1`.
    #[test]
    fn test_keep_simple_scalar_subquery() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(RemoveSubqueryFromProjection)];

        assert_rewrite!(
            &rules,
            "SELECT (SELECT 1) AS constant",
            "SELECT (SELECT 1 LIMIT 1) AS constant"
        );

        assert_rewrite!(
            &rules,
            "SELECT (SELECT 'value') AS str_val",
            "SELECT (SELECT 'value' LIMIT 1) AS str_val"
        );
    }

    // Only an unaliased, unqualified, argument-less `version()` gets the
    // `AS version` alias.
    #[test]
    fn test_version_rewrite() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(FixVersionColumnName)];

        assert_rewrite!(&rules, "SELECT version()", "SELECT version() AS version");

        // Make sure we don't rewrite things we should leave alone
        assert_rewrite!(&rules, "SELECT version() as foo", "SELECT version() AS foo");
        assert_rewrite!(&rules, "SELECT version(foo)", "SELECT version(foo)");
        assert_rewrite!(&rules, "SELECT foo.version()", "SELECT foo.version()");
    }
}