datafusion_pg_catalog/sql/
rules.rs

1use std::collections::HashSet;
2use std::fmt::Debug;
3use std::ops::ControlFlow;
4
5use datafusion::sql::sqlparser::ast::Array;
6use datafusion::sql::sqlparser::ast::ArrayElemTypeDef;
7use datafusion::sql::sqlparser::ast::BinaryOperator;
8use datafusion::sql::sqlparser::ast::CastKind;
9use datafusion::sql::sqlparser::ast::DataType;
10use datafusion::sql::sqlparser::ast::Expr;
11use datafusion::sql::sqlparser::ast::Function;
12use datafusion::sql::sqlparser::ast::FunctionArg;
13use datafusion::sql::sqlparser::ast::FunctionArgExpr;
14use datafusion::sql::sqlparser::ast::FunctionArgumentList;
15use datafusion::sql::sqlparser::ast::FunctionArguments;
16use datafusion::sql::sqlparser::ast::Ident;
17use datafusion::sql::sqlparser::ast::ObjectName;
18use datafusion::sql::sqlparser::ast::ObjectNamePart;
19use datafusion::sql::sqlparser::ast::OrderByKind;
20use datafusion::sql::sqlparser::ast::Query;
21use datafusion::sql::sqlparser::ast::Select;
22use datafusion::sql::sqlparser::ast::SelectItem;
23use datafusion::sql::sqlparser::ast::SelectItemQualifiedWildcardKind;
24use datafusion::sql::sqlparser::ast::SetExpr;
25use datafusion::sql::sqlparser::ast::Statement;
26use datafusion::sql::sqlparser::ast::TableFactor;
27use datafusion::sql::sqlparser::ast::TableWithJoins;
28use datafusion::sql::sqlparser::ast::UnaryOperator;
29use datafusion::sql::sqlparser::ast::Value;
30use datafusion::sql::sqlparser::ast::ValueWithSpan;
31use datafusion::sql::sqlparser::ast::VisitMut;
32use datafusion::sql::sqlparser::ast::VisitorMut;
33
34pub trait SqlStatementRewriteRule: Send + Sync + Debug {
35    fn rewrite(&self, s: Statement) -> Statement;
36}
37
38/// Rewrite rule for adding alias to duplicated projection
39///
40/// This rule is to deal with sql like `SELECT n.oid, n.* FROM n`, which is a
41/// valid statement in postgres. But datafusion treat it as illegal because of
42/// duplicated column oid in projection.
43///
44/// This rule will add alias to column, when there is a wildcard found in
45/// projection.
46#[derive(Debug)]
47pub struct AliasDuplicatedProjectionRewrite;
48
49impl AliasDuplicatedProjectionRewrite {
50    // Rewrites a SELECT statement to alias explicit columns from the same table as a qualified wildcard.
51    fn rewrite_select_with_alias(select: &mut Box<Select>) {
52        // 1. Collect all table aliases from qualified wildcards.
53        let mut wildcard_tables = Vec::new();
54        let mut has_simple_wildcard = false;
55        for p in &select.projection {
56            match p {
57                SelectItem::QualifiedWildcard(name, _) => match name {
58                    SelectItemQualifiedWildcardKind::ObjectName(objname) => {
59                        // for n.oid,
60                        let idents = objname
61                            .0
62                            .iter()
63                            .map(|v| v.as_ident().unwrap().value.clone())
64                            .collect::<Vec<_>>()
65                            .join(".");
66
67                        wildcard_tables.push(idents);
68                    }
69                    SelectItemQualifiedWildcardKind::Expr(_expr) => {
70                        // FIXME:
71                    }
72                },
73                SelectItem::Wildcard(_) => {
74                    has_simple_wildcard = true;
75                }
76                _ => {}
77            }
78        }
79
80        // If there are no qualified wildcards, there's nothing to do.
81        if wildcard_tables.is_empty() && !has_simple_wildcard {
82            return;
83        }
84
85        // 2. Rewrite the projection, adding aliases to matching columns.
86        let mut new_projection = vec![];
87        for p in select.projection.drain(..) {
88            match p {
89                SelectItem::UnnamedExpr(expr) => {
90                    let alias_partial = match &expr {
91                        // Case for `oid` (unqualified identifier)
92                        Expr::Identifier(ident) => Some(ident.clone()),
93                        // Case for `n.oid` (compound identifier)
94                        Expr::CompoundIdentifier(idents) => {
95                            // compare every ident but the last
96                            if idents.len() > 1 {
97                                let table_name = &idents[..idents.len() - 1]
98                                    .iter()
99                                    .map(|i| i.value.clone())
100                                    .collect::<Vec<_>>()
101                                    .join(".");
102                                if wildcard_tables.iter().any(|name| name == table_name) {
103                                    Some(idents[idents.len() - 1].clone())
104                                } else {
105                                    None
106                                }
107                            } else {
108                                None
109                            }
110                        }
111                        _ => None,
112                    };
113
114                    if let Some(name) = alias_partial {
115                        let alias = format!("__alias_{name}");
116                        new_projection.push(SelectItem::ExprWithAlias {
117                            expr,
118                            alias: Ident::new(alias),
119                        });
120                    } else {
121                        new_projection.push(SelectItem::UnnamedExpr(expr));
122                    }
123                }
124                // Preserve existing aliases and wildcards.
125                _ => new_projection.push(p),
126            }
127        }
128        select.projection = new_projection;
129    }
130}
131
132impl SqlStatementRewriteRule for AliasDuplicatedProjectionRewrite {
133    fn rewrite(&self, mut statement: Statement) -> Statement {
134        if let Statement::Query(query) = &mut statement {
135            if let SetExpr::Select(select) = query.body.as_mut() {
136                Self::rewrite_select_with_alias(select);
137            }
138        }
139
140        statement
141    }
142}
143
144/// Prepend qualifier for order by or filter when there is qualified wildcard
145///
146/// Postgres allows unqualified identifier in ORDER BY and FILTER but it's not
147/// accepted by datafusion.
148#[derive(Debug)]
149pub struct ResolveUnqualifiedIdentifer;
150
151impl ResolveUnqualifiedIdentifer {
152    fn rewrite_unqualified_identifiers(query: &mut Box<Query>) {
153        if let SetExpr::Select(select) = query.body.as_mut() {
154            // Step 1: Find all table aliases from FROM and JOIN clauses.
155            let table_aliases = Self::get_table_aliases(&select.from);
156
157            // Step 2: Check for a single qualified wildcard in the projection.
158            let qualified_wildcard_alias = Self::get_qualified_wildcard_alias(&select.projection);
159            if qualified_wildcard_alias.is_none() || table_aliases.is_empty() {
160                return; // Conditions not met.
161            }
162
163            let wildcard_alias = qualified_wildcard_alias.unwrap();
164
165            // Step 3: Rewrite expressions in the WHERE and ORDER BY clauses.
166            if let Some(selection) = &mut select.selection {
167                Self::rewrite_expr(selection, &wildcard_alias, &table_aliases);
168            }
169
170            if let Some(OrderByKind::Expressions(order_by_exprs)) =
171                query.order_by.as_mut().map(|o| &mut o.kind)
172            {
173                for order_by_expr in order_by_exprs {
174                    Self::rewrite_expr(&mut order_by_expr.expr, &wildcard_alias, &table_aliases);
175                }
176            }
177        }
178    }
179
180    fn get_table_aliases(tables: &[TableWithJoins]) -> HashSet<String> {
181        let mut aliases = HashSet::new();
182        for table_with_joins in tables {
183            if let TableFactor::Table {
184                alias: Some(alias), ..
185            } = &table_with_joins.relation
186            {
187                aliases.insert(alias.name.value.clone());
188            }
189            for join in &table_with_joins.joins {
190                if let TableFactor::Table {
191                    alias: Some(alias), ..
192                } = &join.relation
193                {
194                    aliases.insert(alias.name.value.clone());
195                }
196            }
197        }
198        aliases
199    }
200
201    fn get_qualified_wildcard_alias(projection: &[SelectItem]) -> Option<String> {
202        let mut qualified_wildcards = projection
203            .iter()
204            .filter_map(|item| {
205                if let SelectItem::QualifiedWildcard(
206                    SelectItemQualifiedWildcardKind::ObjectName(objname),
207                    _,
208                ) = item
209                {
210                    Some(
211                        objname
212                            .0
213                            .iter()
214                            .map(|v| v.as_ident().unwrap().value.clone())
215                            .collect::<Vec<_>>()
216                            .join("."),
217                    )
218                } else {
219                    None
220                }
221            })
222            .collect::<Vec<_>>();
223
224        if qualified_wildcards.len() == 1 {
225            Some(qualified_wildcards.remove(0))
226        } else {
227            None
228        }
229    }
230
231    fn rewrite_expr(expr: &mut Expr, wildcard_alias: &str, table_aliases: &HashSet<String>) {
232        match expr {
233            Expr::Identifier(ident) => {
234                // If the identifier is not a table alias itself, rewrite it.
235                if !table_aliases.contains(&ident.value) {
236                    *expr = Expr::CompoundIdentifier(vec![
237                        Ident::new(wildcard_alias.to_string()),
238                        ident.clone(),
239                    ]);
240                }
241            }
242            Expr::BinaryOp { left, right, .. } => {
243                Self::rewrite_expr(left, wildcard_alias, table_aliases);
244                Self::rewrite_expr(right, wildcard_alias, table_aliases);
245            }
246            // Add more cases for other expression types as needed (e.g., `InList`, `Between`, etc.)
247            _ => {}
248        }
249    }
250}
251
252impl SqlStatementRewriteRule for ResolveUnqualifiedIdentifer {
253    fn rewrite(&self, mut statement: Statement) -> Statement {
254        if let Statement::Query(query) = &mut statement {
255            Self::rewrite_unqualified_identifiers(query);
256        }
257
258        statement
259    }
260}
261
/// Rewrite rule that removes type annotations DataFusion does not support.
///
/// The affected types are `regclass`, `regproc`, `regtype`, `regtype[]`,
/// `regnamespace` and `oid`. Each is recognised both bare and with the
/// `pg_catalog.` qualifier. The expression rewriting itself is done by
/// `RemoveUnsupportedTypesVisitor`.
#[derive(Debug)]
pub struct RemoveUnsupportedTypes {
    /// Type names, bare and `pg_catalog.`-qualified, in lower case.
    unsupported_types: HashSet<String>,
}

impl Default for RemoveUnsupportedTypes {
    fn default() -> Self {
        Self::new()
    }
}

impl RemoveUnsupportedTypes {
    /// Creates the rule with the built-in list of unsupported types.
    pub fn new() -> Self {
        const TYPE_NAMES: [&str; 6] = [
            "regclass",
            "regproc",
            "regtype",
            "regtype[]",
            "regnamespace",
            "oid",
        ];

        // Every type is registered twice: bare and schema-qualified.
        let unsupported_types = TYPE_NAMES
            .iter()
            .flat_map(|name| [(*name).to_owned(), format!("pg_catalog.{name}")])
            .collect();

        Self { unsupported_types }
    }
}
294
295struct RemoveUnsupportedTypesVisitor<'a> {
296    unsupported_types: &'a HashSet<String>,
297}
298
299impl VisitorMut for RemoveUnsupportedTypesVisitor<'_> {
300    type Break = ();
301
302    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
303        match expr {
304            // This is the key part: identify constants with type annotations.
305            Expr::TypedString { value, data_type } => {
306                if self
307                    .unsupported_types
308                    .contains(data_type.to_string().to_lowercase().as_str())
309                {
310                    *expr =
311                        Expr::Value(Value::SingleQuotedString(value.to_string()).with_empty_span());
312                }
313            }
314            Expr::Cast {
315                data_type,
316                expr: value,
317                ..
318            } => {
319                if self
320                    .unsupported_types
321                    .contains(data_type.to_string().to_lowercase().as_str())
322                {
323                    *expr = *value.clone();
324                }
325            }
326            // Add more match arms for other expression types (e.g., `Function`, `InList`) as needed.
327            _ => {}
328        }
329
330        ControlFlow::Continue(())
331    }
332}
333
334impl SqlStatementRewriteRule for RemoveUnsupportedTypes {
335    fn rewrite(&self, mut statement: Statement) -> Statement {
336        let mut visitor = RemoveUnsupportedTypesVisitor {
337            unsupported_types: &self.unsupported_types,
338        };
339        let _ = statement.visit(&mut visitor);
340        statement
341    }
342}
343
344/// Rewrite Postgres's ANY operator to array_contains
345#[derive(Debug)]
346pub struct RewriteArrayAnyAllOperation;
347
348struct RewriteArrayAnyAllOperationVisitor;
349
350impl RewriteArrayAnyAllOperationVisitor {
351    fn any_to_array_cofntains(&self, left: &Expr, right: &Expr) -> Expr {
352        let array = if let Expr::Value(ValueWithSpan {
353            value: Value::SingleQuotedString(array_literal),
354            ..
355        }) = right
356        {
357            let array_literal = array_literal.trim();
358            if array_literal.starts_with('{') && array_literal.ends_with('}') {
359                let items = array_literal.trim_matches(|c| c == '{' || c == '}' || c == ' ');
360                let items = items.split(',').map(|s| s.trim()).filter(|s| !s.is_empty());
361
362                // For now, we assume the data type is string
363                let elems = items
364                    .map(|s| {
365                        Expr::Value(Value::SingleQuotedString(s.to_string()).with_empty_span())
366                    })
367                    .collect();
368                Expr::Array(Array {
369                    elem: elems,
370                    named: true,
371                })
372            } else {
373                right.clone()
374            }
375        } else {
376            right.clone()
377        };
378
379        Expr::Function(Function {
380            name: ObjectName::from(vec![Ident::new("array_contains")]),
381            args: FunctionArguments::List(FunctionArgumentList {
382                args: vec![
383                    FunctionArg::Unnamed(FunctionArgExpr::Expr(array)),
384                    FunctionArg::Unnamed(FunctionArgExpr::Expr(left.clone())),
385                ],
386                duplicate_treatment: None,
387                clauses: vec![],
388            }),
389            uses_odbc_syntax: false,
390            parameters: FunctionArguments::None,
391            filter: None,
392            null_treatment: None,
393            over: None,
394            within_group: vec![],
395        })
396    }
397}
398
399impl VisitorMut for RewriteArrayAnyAllOperationVisitor {
400    type Break = ();
401
402    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
403        match expr {
404            Expr::AnyOp {
405                left,
406                compare_op,
407                right,
408                ..
409            } => match compare_op {
410                BinaryOperator::Eq => {
411                    *expr = self.any_to_array_cofntains(left.as_ref(), right.as_ref());
412                }
413                BinaryOperator::NotEq => {
414                    // TODO:left not equals to any element in array
415                }
416                _ => {}
417            },
418            Expr::AllOp {
419                left,
420                compare_op,
421                right,
422            } => match compare_op {
423                BinaryOperator::Eq => {
424                    // TODO: left equals to every element in array
425                }
426                BinaryOperator::NotEq => {
427                    *expr = Expr::UnaryOp {
428                        op: UnaryOperator::Not,
429                        expr: Box::new(self.any_to_array_cofntains(left.as_ref(), right.as_ref())),
430                    }
431                }
432                _ => {}
433            },
434            _ => {}
435        }
436
437        ControlFlow::Continue(())
438    }
439}
440
441impl SqlStatementRewriteRule for RewriteArrayAnyAllOperation {
442    fn rewrite(&self, mut s: Statement) -> Statement {
443        let mut visitor = RewriteArrayAnyAllOperationVisitor;
444
445        let _ = s.visit(&mut visitor);
446
447        s
448    }
449}
450
451/// Prepend qualifier to table_name
452///
453/// Postgres has pg_catalog in search_path by default so it allow access to
454/// `pg_namespace` without `pg_catalog.` qualifier
455#[derive(Debug)]
456pub struct PrependUnqualifiedPgTableName;
457
458struct PrependUnqualifiedPgTableNameVisitor;
459
460impl VisitorMut for PrependUnqualifiedPgTableNameVisitor {
461    type Break = ();
462
463    fn pre_visit_table_factor(
464        &mut self,
465        table_factor: &mut TableFactor,
466    ) -> ControlFlow<Self::Break> {
467        if let TableFactor::Table { name, args, .. } = table_factor {
468            // not a table function
469            if args.is_none() && name.0.len() == 1 {
470                if let ObjectNamePart::Identifier(ident) = &name.0[0] {
471                    if ident.value.starts_with("pg_") {
472                        *name = ObjectName(vec![
473                            ObjectNamePart::Identifier(Ident::new("pg_catalog")),
474                            name.0[0].clone(),
475                        ]);
476                    }
477                }
478            }
479        }
480
481        ControlFlow::Continue(())
482    }
483}
484
485impl SqlStatementRewriteRule for PrependUnqualifiedPgTableName {
486    fn rewrite(&self, mut s: Statement) -> Statement {
487        let mut visitor = PrependUnqualifiedPgTableNameVisitor;
488
489        let _ = s.visit(&mut visitor);
490        s
491    }
492}
493
494#[derive(Debug)]
495pub struct FixArrayLiteral;
496
497struct FixArrayLiteralVisitor;
498
499impl FixArrayLiteralVisitor {
500    fn is_string_type(dt: &DataType) -> bool {
501        matches!(
502            dt,
503            DataType::Text | DataType::Varchar(_) | DataType::Char(_) | DataType::String(_)
504        )
505    }
506}
507
508impl VisitorMut for FixArrayLiteralVisitor {
509    type Break = ();
510
511    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
512        if let Expr::Cast {
513            kind,
514            expr,
515            data_type,
516            ..
517        } = expr
518        {
519            if kind == &CastKind::DoubleColon {
520                if let DataType::Array(arr) = data_type {
521                    // cast some to
522                    if let Expr::Value(ValueWithSpan {
523                        value: Value::SingleQuotedString(array_literal),
524                        ..
525                    }) = expr.as_ref()
526                    {
527                        let items =
528                            array_literal.trim_matches(|c| c == '{' || c == '}' || c == ' ');
529                        let items = items.split(',').map(|s| s.trim()).filter(|s| !s.is_empty());
530
531                        let is_text = match arr {
532                            ArrayElemTypeDef::AngleBracket(dt) => Self::is_string_type(dt.as_ref()),
533                            ArrayElemTypeDef::SquareBracket(dt, _) => {
534                                Self::is_string_type(dt.as_ref())
535                            }
536                            ArrayElemTypeDef::Parenthesis(dt) => Self::is_string_type(dt.as_ref()),
537                            _ => false,
538                        };
539
540                        let elems = items
541                            .map(|s| {
542                                if is_text {
543                                    Expr::Value(
544                                        Value::SingleQuotedString(s.to_string()).with_empty_span(),
545                                    )
546                                } else {
547                                    Expr::Value(
548                                        Value::Number(s.to_string(), false).with_empty_span(),
549                                    )
550                                }
551                            })
552                            .collect();
553                        *expr = Box::new(Expr::Array(Array {
554                            elem: elems,
555                            named: true,
556                        }));
557                    }
558                }
559            }
560        }
561
562        ControlFlow::Continue(())
563    }
564}
565
566impl SqlStatementRewriteRule for FixArrayLiteral {
567    fn rewrite(&self, mut s: Statement) -> Statement {
568        let mut visitor = FixArrayLiteralVisitor;
569
570        let _ = s.visit(&mut visitor);
571        s
572    }
573}
574
575/// Remove qualifier from unsupported items
576///
577/// This rewriter removes qualifier from following items:
578/// 1. type cast: for example: `pg_catalog.text`
579/// 2. function name: for example: `pg_catalog.array_to_string`,
580/// 3. table function name
581#[derive(Debug)]
582pub struct RemoveQualifier;
583
584struct RemoveQualifierVisitor;
585
586impl VisitorMut for RemoveQualifierVisitor {
587    type Break = ();
588
589    fn pre_visit_table_factor(
590        &mut self,
591        table_factor: &mut TableFactor,
592    ) -> ControlFlow<Self::Break> {
593        // remove table function qualifier
594        if let TableFactor::Table { name, args, .. } = table_factor {
595            if args.is_some() {
596                //  multiple idents in name, which means it's a qualified table name
597                if name.0.len() > 1 {
598                    if let Some(last_ident) = name.0.pop() {
599                        *name = ObjectName(vec![last_ident]);
600                    }
601                }
602            }
603        }
604        ControlFlow::Continue(())
605    }
606
607    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
608        match expr {
609            Expr::Cast { data_type, .. } => {
610                // rewrite custom pg_catalog. qualified types
611                let data_type_str = data_type.to_string();
612                match data_type_str.as_str() {
613                    "pg_catalog.text" => {
614                        *data_type = DataType::Text;
615                    }
616                    "pg_catalog.int2[]" => {
617                        *data_type = DataType::Array(ArrayElemTypeDef::SquareBracket(
618                            Box::new(DataType::Int16),
619                            None,
620                        ));
621                    }
622                    _ => {}
623                }
624            }
625            Expr::Function(function) => {
626                // remove qualifier from pg_catalog.function
627                let name = &mut function.name;
628                if name.0.len() > 1 {
629                    if let Some(last_ident) = name.0.pop() {
630                        *name = ObjectName(vec![last_ident]);
631                    }
632                }
633            }
634
635            _ => {}
636        }
637        ControlFlow::Continue(())
638    }
639}
640
641impl SqlStatementRewriteRule for RemoveQualifier {
642    fn rewrite(&self, mut s: Statement) -> Statement {
643        let mut visitor = RemoveQualifierVisitor;
644
645        let _ = s.visit(&mut visitor);
646        s
647    }
648}
649
650/// Replace `current_user` with `session_user()`
651#[derive(Debug)]
652pub struct CurrentUserVariableToSessionUserFunctionCall;
653
654struct CurrentUserVariableToSessionUserFunctionCallVisitor;
655
656impl VisitorMut for CurrentUserVariableToSessionUserFunctionCallVisitor {
657    type Break = ();
658
659    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
660        if let Expr::Identifier(ident) = expr {
661            if ident.quote_style.is_none() && ident.value.to_lowercase() == "current_user" {
662                *expr = Expr::Function(Function {
663                    name: ObjectName::from(vec![Ident::new("session_user")]),
664                    args: FunctionArguments::None,
665                    uses_odbc_syntax: false,
666                    parameters: FunctionArguments::None,
667                    filter: None,
668                    null_treatment: None,
669                    over: None,
670                    within_group: vec![],
671                });
672            }
673        }
674
675        if let Expr::Function(func) = expr {
676            let fname = func
677                .name
678                .0
679                .iter()
680                .map(|ident| ident.to_string())
681                .collect::<Vec<String>>()
682                .join(".");
683            if fname.to_lowercase() == "current_user" {
684                func.name = ObjectName::from(vec![Ident::new("session_user")])
685            }
686        }
687
688        ControlFlow::Continue(())
689    }
690}
691
692impl SqlStatementRewriteRule for CurrentUserVariableToSessionUserFunctionCall {
693    fn rewrite(&self, mut s: Statement) -> Statement {
694        let mut visitor = CurrentUserVariableToSessionUserFunctionCallVisitor;
695
696        let _ = s.visit(&mut visitor);
697        s
698    }
699}
700
701/// Fix collate and regex calls
702#[derive(Debug)]
703pub struct FixCollate;
704
705struct FixCollateVisitor;
706
707impl VisitorMut for FixCollateVisitor {
708    type Break = ();
709
710    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
711        match expr {
712            Expr::Collate { expr: inner, .. } => {
713                *expr = inner.as_ref().clone();
714            }
715            Expr::BinaryOp { op, .. } => {
716                if let BinaryOperator::PGCustomBinaryOperator(ops) = op {
717                    if *ops == ["pg_catalog", "~"] {
718                        *op = BinaryOperator::PGRegexMatch;
719                    }
720                }
721            }
722            _ => {}
723        }
724
725        ControlFlow::Continue(())
726    }
727}
728
729impl SqlStatementRewriteRule for FixCollate {
730    fn rewrite(&self, mut s: Statement) -> Statement {
731        let mut visitor = FixCollateVisitor;
732
733        let _ = s.visit(&mut visitor);
734        s
735    }
736}
737
738/// Datafusion doesn't support subquery on projection
739#[derive(Debug)]
740pub struct RemoveSubqueryFromProjection;
741
742struct RemoveSubqueryFromProjectionVisitor;
743
744impl VisitorMut for RemoveSubqueryFromProjectionVisitor {
745    type Break = ();
746
747    fn pre_visit_query(&mut self, query: &mut Query) -> ControlFlow<Self::Break> {
748        if let SetExpr::Select(select) = query.body.as_mut() {
749            for projection in &mut select.projection {
750                match projection {
751                    SelectItem::UnnamedExpr(expr) => {
752                        if let Expr::Subquery(_) = expr {
753                            *expr = Expr::Value(Value::Null.with_empty_span());
754                        }
755                    }
756                    SelectItem::ExprWithAlias { expr, .. } => {
757                        if let Expr::Subquery(_) = expr {
758                            *expr = Expr::Value(Value::Null.with_empty_span());
759                        }
760                    }
761                    _ => {}
762                }
763            }
764        }
765
766        ControlFlow::Continue(())
767    }
768}
769
770impl SqlStatementRewriteRule for RemoveSubqueryFromProjection {
771    fn rewrite(&self, mut s: Statement) -> Statement {
772        let mut visitor = RemoveSubqueryFromProjectionVisitor;
773        let _ = s.visit(&mut visitor);
774
775        s
776    }
777}
778
779/// `select version()` should return column named `version` not `version()`
780#[derive(Debug)]
781pub struct FixVersionColumnName;
782
783struct FixVersionColumnNameVisitor;
784
785impl VisitorMut for FixVersionColumnNameVisitor {
786    type Break = ();
787
788    fn pre_visit_query(&mut self, query: &mut Query) -> ControlFlow<Self::Break> {
789        if let SetExpr::Select(select) = query.body.as_mut() {
790            for projection in &mut select.projection {
791                if let SelectItem::UnnamedExpr(Expr::Function(f)) = projection {
792                    if f.name.0.len() == 1 {
793                        if let ObjectNamePart::Identifier(part) = &f.name.0[0] {
794                            if part.value == "version" {
795                                if let FunctionArguments::List(args) = &f.args {
796                                    if args.args.is_empty() {
797                                        *projection = SelectItem::ExprWithAlias {
798                                            expr: Expr::Function(f.clone()),
799                                            alias: Ident::new("version"),
800                                        }
801                                    }
802                                }
803                            }
804                        }
805                    }
806                }
807            }
808        }
809
810        ControlFlow::Continue(())
811    }
812}
813
814impl SqlStatementRewriteRule for FixVersionColumnName {
815    fn rewrite(&self, mut s: Statement) -> Statement {
816        let mut visitor = FixVersionColumnNameVisitor;
817        let _ = s.visit(&mut visitor);
818
819        s
820    }
821}
822
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
    use datafusion::sql::sqlparser::parser::Parser;
    use datafusion::sql::sqlparser::parser::ParserError;
    use std::sync::Arc;

    /// Parses `sql` with the PostgreSQL dialect and returns every statement it contains.
    fn parse(sql: &str) -> Result<Vec<Statement>, ParserError> {
        let dialect = PostgreSqlDialect {};

        Parser::parse_sql(&dialect, sql)
    }

    /// Applies `rules` to `s` in order, feeding each rule the output of the previous one.
    fn rewrite(s: Statement, rules: &[Arc<dyn SqlStatementRewriteRule>]) -> Statement {
        rules
            .iter()
            .fold(s, |statement, rule| rule.rewrite(statement))
    }

    /// Parses the first statement of `$orig`, runs it through `$rules`, and asserts
    /// that its rendered SQL equals `$rewt`.
    ///
    /// The expansion is wrapped in its own block so each invocation is a single
    /// expression with private bindings. The input SQL is included in the failure
    /// message so a failing case can be identified directly.
    macro_rules! assert_rewrite {
        ($rules:expr, $orig:expr, $rewt:expr $(,)?) => {{
            let sql: &str = $orig;
            let statement = parse(sql)
                .expect("Failed to parse")
                .into_iter()
                .next()
                .expect("SQL contained no statement");

            let statement = rewrite(statement, $rules);
            assert_eq!(
                statement.to_string(),
                $rewt,
                "unexpected rewrite of `{}`",
                sql
            );
        }};
    }

    #[test]
    fn test_alias_rewrite() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(AliasDuplicatedProjectionRewrite)];

        assert_rewrite!(
            &rules,
            "SELECT n.oid, n.* FROM pg_catalog.pg_namespace n",
            "SELECT n.oid AS __alias_oid, n.* FROM pg_catalog.pg_namespace AS n"
        );

        assert_rewrite!(
            &rules,
            "SELECT oid, * FROM pg_catalog.pg_namespace",
            "SELECT oid AS __alias_oid, * FROM pg_catalog.pg_namespace"
        );

        // The column and the wildcard refer to different relations, so no alias is added.
        assert_rewrite!(
            &rules,
            "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id",
            "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id"
        );

        assert_rewrite!(
            &rules,
            "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname",
            "SELECT n.oid AS __alias_oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY nspsname"
        );
    }

    #[test]
    fn test_qualifier_prepend() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(ResolveUnqualifiedIdentifer)];

        assert_rewrite!(
            &rules,
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE nspname = 'pg_catalog' ORDER BY nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
        );

        // Without a table alias there is nothing to prepend.
        assert_rewrite!(
            &rules,
            "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname",
            "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname"
        );

        assert_rewrite!(
            &rules,
            "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname",
            "SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY n.nspsname"
        );
    }

    #[test]
    fn test_remove_unsupported_types() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![
            Arc::new(RemoveQualifier),
            Arc::new(RemoveUnsupportedTypes::new()),
        ];

        assert_rewrite!(
            &rules,
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
        );

        assert_rewrite!(
            &rules,
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.oid = 1 AND n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.oid = 1 AND n.nspname = 'pg_catalog' ORDER BY n.nspname"
        );

        assert_rewrite!(
            &rules,
            "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace'::regclass ORDER BY nspname",
            "SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY nspname"
        );

        assert_rewrite!(
            &rules,
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
        );

        // Chained casts: `regtype` is dropped and the qualified `pg_catalog.text` becomes `TEXT`.
        assert_rewrite!(
            &rules,
            "SELECT c.relchecks, c.relkind, c.relhasindex, c.relhasrules, c.relhastriggers, c.relrowsecurity, c.relforcerowsecurity, false AS relhasoids, c.relispartition, '', c.reltablespace, CASE WHEN c.reloftype = 0 THEN '' ELSE c.reloftype::pg_catalog.regtype::pg_catalog.text END, c.relpersistence, c.relreplident, am.amname
    FROM pg_catalog.pg_class c
     LEFT JOIN pg_catalog.pg_class tc ON (c.reltoastrelid = tc.oid)
    LEFT JOIN pg_catalog.pg_am am ON (c.relam = am.oid)
    WHERE c.oid = '16386'",
            "SELECT c.relchecks, c.relkind, c.relhasindex, c.relhasrules, c.relhastriggers, c.relrowsecurity, c.relforcerowsecurity, false AS relhasoids, c.relispartition, '', c.reltablespace, CASE WHEN c.reloftype = 0 THEN '' ELSE c.reloftype::TEXT END, c.relpersistence, c.relreplident, am.amname FROM pg_catalog.pg_class AS c LEFT JOIN pg_catalog.pg_class AS tc ON (c.reltoastrelid = tc.oid) LEFT JOIN pg_catalog.pg_am AS am ON (c.relam = am.oid) WHERE c.oid = '16386'"
        );
    }

    #[test]
    fn test_any_to_array_contains() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(RewriteArrayAnyAllOperation)];

        assert_rewrite!(
            &rules,
            "SELECT a = ANY(current_schemas(true))",
            "SELECT array_contains(current_schemas(true), a)"
        );

        assert_rewrite!(
            &rules,
            "SELECT a <> ALL(current_schemas(true))",
            "SELECT NOT array_contains(current_schemas(true), a)"
        );

        assert_rewrite!(
            &rules,
            "SELECT a = ANY('{r, l, e}')",
            "SELECT array_contains(ARRAY['r', 'l', 'e'], a)"
        );

        assert_rewrite!(
            &rules,
            "SELECT a FROM tbl WHERE a = ANY(current_schemas(true))",
            "SELECT a FROM tbl WHERE array_contains(current_schemas(true), a)"
        );
    }

    #[test]
    fn test_prepend_unqualified_table_name() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(PrependUnqualifiedPgTableName)];

        assert_rewrite!(
            &rules,
            "SELECT * FROM pg_catalog.pg_namespace",
            "SELECT * FROM pg_catalog.pg_namespace"
        );

        assert_rewrite!(
            &rules,
            "SELECT * FROM pg_namespace",
            "SELECT * FROM pg_catalog.pg_namespace"
        );

        assert_rewrite!(
            &rules,
            "SELECT typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN pg_namespace as ns ON ns.oid = oid",
            "SELECT typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN pg_catalog.pg_namespace AS ns ON ns.oid = oid"
        );
    }

    #[test]
    fn test_array_literal_fix() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(FixArrayLiteral)];

        assert_rewrite!(
            &rules,
            "SELECT '{a, abc}'::text[]",
            "SELECT ARRAY['a', 'abc']::TEXT[]"
        );

        assert_rewrite!(
            &rules,
            "SELECT '{1, 2}'::int[]",
            "SELECT ARRAY[1, 2]::INT[]"
        );

        // NOTE(review): this pins current behaviour, where `t`/`f` are emitted as bare
        // identifiers rather than boolean literals. PostgreSQL reads '{t, f}'::bool[]
        // as {true,false}. Confirm whether FixArrayLiteral should emit `true`/`false`.
        assert_rewrite!(
            &rules,
            "SELECT '{t, f}'::bool[]",
            "SELECT ARRAY[t, f]::BOOL[]"
        );
    }

    #[test]
    fn test_remove_qualifier_from_table_function() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(RemoveQualifier)];

        assert_rewrite!(
            &rules,
            "SELECT * FROM pg_catalog.pg_get_keywords()",
            "SELECT * FROM pg_get_keywords()"
        );
    }

    #[test]
    fn test_current_user() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(CurrentUserVariableToSessionUserFunctionCall)];

        assert_rewrite!(&rules, "SELECT current_user", "SELECT session_user");

        assert_rewrite!(&rules, "SELECT CURRENT_USER", "SELECT session_user");

        assert_rewrite!(
            &rules,
            "SELECT is_null(current_user)",
            "SELECT is_null(session_user)"
        );
    }

    #[test]
    fn test_collate_fix() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(FixCollate)];

        assert_rewrite!(
            &rules,
            "SELECT c.oid, c.relname FROM pg_catalog.pg_class c WHERE c.relname OPERATOR(pg_catalog.~) '^(tablename)$' COLLATE pg_catalog.default AND pg_catalog.pg_table_is_visible(c.oid) ORDER BY 2, 3;",
            "SELECT c.oid, c.relname FROM pg_catalog.pg_class AS c WHERE c.relname ~ '^(tablename)$' AND pg_catalog.pg_table_is_visible(c.oid) ORDER BY 2, 3"
        );
    }

    #[test]
    fn test_remove_subquery() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(RemoveSubqueryFromProjection)];

        assert_rewrite!(
            &rules,
            "SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), (SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid, true) FROM pg_catalog.pg_attrdef d WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef), a.attnotnull, (SELECT c.collname FROM pg_catalog.pg_collation c, pg_catalog.pg_type t WHERE c.oid = a.attcollation AND t.oid = a.atttypid AND a.attcollation <> t.typcollation LIMIT 1) AS attcollation, a.attidentity, a.attgenerated FROM pg_catalog.pg_attribute a WHERE a.attrelid = '16384' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum;",
            "SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), NULL, a.attnotnull, NULL AS attcollation, a.attidentity, a.attgenerated FROM pg_catalog.pg_attribute AS a WHERE a.attrelid = '16384' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum"
        );
    }

    #[test]
    fn test_version_rewrite() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(FixVersionColumnName)];

        assert_rewrite!(&rules, "SELECT version()", "SELECT version() AS version");

        // Make sure we don't rewrite things we should leave alone
        assert_rewrite!(&rules, "SELECT version() as foo", "SELECT version() AS foo");
        assert_rewrite!(&rules, "SELECT version(foo)", "SELECT version(foo)");
        assert_rewrite!(&rules, "SELECT foo.version()", "SELECT foo.version()");
    }
}