datafusion_pg_catalog/sql/
rules.rs

1use std::collections::HashSet;
2use std::fmt::Debug;
3use std::ops::ControlFlow;
4
5use datafusion::sql::sqlparser::ast::Array;
6use datafusion::sql::sqlparser::ast::ArrayElemTypeDef;
7use datafusion::sql::sqlparser::ast::BinaryOperator;
8use datafusion::sql::sqlparser::ast::CastKind;
9use datafusion::sql::sqlparser::ast::DataType;
10use datafusion::sql::sqlparser::ast::Expr;
11use datafusion::sql::sqlparser::ast::Function;
12use datafusion::sql::sqlparser::ast::FunctionArg;
13use datafusion::sql::sqlparser::ast::FunctionArgExpr;
14use datafusion::sql::sqlparser::ast::FunctionArgumentList;
15use datafusion::sql::sqlparser::ast::FunctionArguments;
16use datafusion::sql::sqlparser::ast::Ident;
17use datafusion::sql::sqlparser::ast::ObjectName;
18use datafusion::sql::sqlparser::ast::ObjectNamePart;
19use datafusion::sql::sqlparser::ast::OrderByKind;
20use datafusion::sql::sqlparser::ast::Query;
21use datafusion::sql::sqlparser::ast::Select;
22use datafusion::sql::sqlparser::ast::SelectItem;
23use datafusion::sql::sqlparser::ast::SelectItemQualifiedWildcardKind;
24use datafusion::sql::sqlparser::ast::SetExpr;
25use datafusion::sql::sqlparser::ast::Statement;
26use datafusion::sql::sqlparser::ast::TableFactor;
27use datafusion::sql::sqlparser::ast::TableWithJoins;
28use datafusion::sql::sqlparser::ast::TypedString;
29use datafusion::sql::sqlparser::ast::UnaryOperator;
30use datafusion::sql::sqlparser::ast::Value;
31use datafusion::sql::sqlparser::ast::ValueWithSpan;
32use datafusion::sql::sqlparser::ast::VisitMut;
33use datafusion::sql::sqlparser::ast::VisitorMut;
34
/// A single rewrite pass applied to a parsed SQL statement.
///
/// Implementors receive the statement by value and return the statement that
/// should be used in its place. Every implementor must be `Send + Sync + Debug`.
pub trait SqlStatementRewriteRule: Send + Sync + Debug {
    /// Consumes `s` and returns the (possibly modified) statement.
    fn rewrite(&self, s: Statement) -> Statement;
}
38
39/// Rewrite rule for adding alias to duplicated projection
40///
41/// This rule is to deal with sql like `SELECT n.oid, n.* FROM n`, which is a
42/// valid statement in postgres. But datafusion treat it as illegal because of
43/// duplicated column oid in projection.
44///
45/// This rule will add alias to column, when there is a wildcard found in
46/// projection.
47#[derive(Debug)]
48pub struct AliasDuplicatedProjectionRewrite;
49
50impl AliasDuplicatedProjectionRewrite {
51    // Rewrites a SELECT statement to alias explicit columns from the same table as a qualified wildcard.
52    fn rewrite_select_with_alias(select: &mut Box<Select>) {
53        // 1. Collect all table aliases from qualified wildcards.
54        let mut wildcard_tables = Vec::new();
55        let mut has_simple_wildcard = false;
56        for p in &select.projection {
57            match p {
58                SelectItem::QualifiedWildcard(name, _) => match name {
59                    SelectItemQualifiedWildcardKind::ObjectName(objname) => {
60                        // for n.oid,
61                        let idents = objname
62                            .0
63                            .iter()
64                            .map(|v| v.as_ident().unwrap().value.clone())
65                            .collect::<Vec<_>>()
66                            .join(".");
67
68                        wildcard_tables.push(idents);
69                    }
70                    SelectItemQualifiedWildcardKind::Expr(_expr) => {
71                        // FIXME:
72                    }
73                },
74                SelectItem::Wildcard(_) => {
75                    has_simple_wildcard = true;
76                }
77                _ => {}
78            }
79        }
80
81        // If there are no qualified wildcards, there's nothing to do.
82        if wildcard_tables.is_empty() && !has_simple_wildcard {
83            return;
84        }
85
86        // 2. Rewrite the projection, adding aliases to matching columns.
87        let mut new_projection = vec![];
88        for p in select.projection.drain(..) {
89            match p {
90                SelectItem::UnnamedExpr(expr) => {
91                    let alias_partial = match &expr {
92                        // Case for `oid` (unqualified identifier)
93                        Expr::Identifier(ident) => Some(ident.clone()),
94                        // Case for `n.oid` (compound identifier)
95                        Expr::CompoundIdentifier(idents) => {
96                            // compare every ident but the last
97                            if idents.len() > 1 {
98                                let table_name = &idents[..idents.len() - 1]
99                                    .iter()
100                                    .map(|i| i.value.clone())
101                                    .collect::<Vec<_>>()
102                                    .join(".");
103                                if wildcard_tables.iter().any(|name| name == table_name) {
104                                    Some(idents[idents.len() - 1].clone())
105                                } else {
106                                    None
107                                }
108                            } else {
109                                None
110                            }
111                        }
112                        _ => None,
113                    };
114
115                    if let Some(name) = alias_partial {
116                        let alias = format!("__alias_{name}");
117                        new_projection.push(SelectItem::ExprWithAlias {
118                            expr,
119                            alias: Ident::new(alias),
120                        });
121                    } else {
122                        new_projection.push(SelectItem::UnnamedExpr(expr));
123                    }
124                }
125                // Preserve existing aliases and wildcards.
126                _ => new_projection.push(p),
127            }
128        }
129        select.projection = new_projection;
130    }
131}
132
133impl SqlStatementRewriteRule for AliasDuplicatedProjectionRewrite {
134    fn rewrite(&self, mut statement: Statement) -> Statement {
135        if let Statement::Query(query) = &mut statement {
136            if let SetExpr::Select(select) = query.body.as_mut() {
137                Self::rewrite_select_with_alias(select);
138            }
139        }
140
141        statement
142    }
143}
144
145/// Prepend qualifier for order by or filter when there is qualified wildcard
146///
147/// Postgres allows unqualified identifier in ORDER BY and FILTER but it's not
148/// accepted by datafusion.
149#[derive(Debug)]
150pub struct ResolveUnqualifiedIdentifer;
151
152impl ResolveUnqualifiedIdentifer {
153    fn rewrite_unqualified_identifiers(query: &mut Box<Query>) {
154        if let SetExpr::Select(select) = query.body.as_mut() {
155            // Step 1: Find all table aliases from FROM and JOIN clauses.
156            let table_aliases = Self::get_table_aliases(&select.from);
157
158            // Step 2: Check for a single qualified wildcard in the projection.
159            let qualified_wildcard_alias = Self::get_qualified_wildcard_alias(&select.projection);
160            if qualified_wildcard_alias.is_none() || table_aliases.is_empty() {
161                return; // Conditions not met.
162            }
163
164            let wildcard_alias = qualified_wildcard_alias.unwrap();
165
166            // Step 2.5: Collect all projection aliases to avoid rewriting them
167            let projection_aliases = Self::get_projection_aliases(&select.projection);
168
169            // Step 3: Rewrite expressions in the WHERE and ORDER BY clauses.
170            if let Some(selection) = &mut select.selection {
171                Self::rewrite_expr(
172                    selection,
173                    &wildcard_alias,
174                    &table_aliases,
175                    &projection_aliases,
176                );
177            }
178
179            if let Some(OrderByKind::Expressions(order_by_exprs)) =
180                query.order_by.as_mut().map(|o| &mut o.kind)
181            {
182                for order_by_expr in order_by_exprs {
183                    Self::rewrite_expr(
184                        &mut order_by_expr.expr,
185                        &wildcard_alias,
186                        &table_aliases,
187                        &projection_aliases,
188                    );
189                }
190            }
191        }
192    }
193
194    fn get_table_aliases(tables: &[TableWithJoins]) -> HashSet<String> {
195        let mut aliases = HashSet::new();
196        for table_with_joins in tables {
197            if let TableFactor::Table {
198                alias: Some(alias), ..
199            } = &table_with_joins.relation
200            {
201                aliases.insert(alias.name.value.clone());
202            }
203            for join in &table_with_joins.joins {
204                if let TableFactor::Table {
205                    alias: Some(alias), ..
206                } = &join.relation
207                {
208                    aliases.insert(alias.name.value.clone());
209                }
210            }
211        }
212        aliases
213    }
214
215    fn get_qualified_wildcard_alias(projection: &[SelectItem]) -> Option<String> {
216        let mut qualified_wildcards = projection
217            .iter()
218            .filter_map(|item| {
219                if let SelectItem::QualifiedWildcard(
220                    SelectItemQualifiedWildcardKind::ObjectName(objname),
221                    _,
222                ) = item
223                {
224                    Some(
225                        objname
226                            .0
227                            .iter()
228                            .map(|v| v.as_ident().unwrap().value.clone())
229                            .collect::<Vec<_>>()
230                            .join("."),
231                    )
232                } else {
233                    None
234                }
235            })
236            .collect::<Vec<_>>();
237
238        if qualified_wildcards.len() == 1 {
239            Some(qualified_wildcards.remove(0))
240        } else {
241            None
242        }
243    }
244
245    fn get_projection_aliases(projection: &[SelectItem]) -> HashSet<String> {
246        let mut aliases = HashSet::new();
247        for item in projection {
248            match item {
249                SelectItem::ExprWithAlias { alias, .. } => {
250                    aliases.insert(alias.value.clone());
251                }
252                SelectItem::UnnamedExpr(Expr::Identifier(ident)) => {
253                    aliases.insert(ident.value.clone());
254                }
255                _ => {}
256            }
257        }
258        aliases
259    }
260
261    fn rewrite_expr(
262        expr: &mut Expr,
263        wildcard_alias: &str,
264        table_aliases: &HashSet<String>,
265        projection_aliases: &HashSet<String>,
266    ) {
267        match expr {
268            Expr::Identifier(ident) => {
269                // If the identifier is not a table alias itself and not already aliased in projection, rewrite it.
270                if !table_aliases.contains(&ident.value)
271                    && !projection_aliases.contains(&ident.value)
272                {
273                    *expr = Expr::CompoundIdentifier(vec![
274                        Ident::new(wildcard_alias.to_string()),
275                        ident.clone(),
276                    ]);
277                }
278            }
279            Expr::BinaryOp { left, right, .. } => {
280                Self::rewrite_expr(left, wildcard_alias, table_aliases, projection_aliases);
281                Self::rewrite_expr(right, wildcard_alias, table_aliases, projection_aliases);
282            }
283            // Add more cases for other expression types as needed (e.g., `InList`, `Between`, etc.)
284            _ => {}
285        }
286    }
287}
288
289impl SqlStatementRewriteRule for ResolveUnqualifiedIdentifer {
290    fn rewrite(&self, mut statement: Statement) -> Statement {
291        if let Statement::Query(query) = &mut statement {
292            Self::rewrite_unqualified_identifiers(query);
293        }
294
295        statement
296    }
297}
298
/// Remove datafusion unsupported type annotations
///
/// Each unsupported type name is registered both bare and with a
/// `pg_catalog.` qualifier.
#[derive(Debug)]
pub struct RemoveUnsupportedTypes {
    unsupported_types: HashSet<String>,
}

impl Default for RemoveUnsupportedTypes {
    /// Same as [`RemoveUnsupportedTypes::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl RemoveUnsupportedTypes {
    /// Builds the rule with the built-in list of unsupported type names.
    pub fn new() -> Self {
        const BASE_TYPES: [&str; 6] = [
            "regclass",
            "regproc",
            "regtype",
            "regtype[]",
            "regnamespace",
            "oid",
        ];

        // Register every name twice: bare and `pg_catalog.`-qualified.
        let unsupported_types = BASE_TYPES
            .iter()
            .flat_map(|name| [name.to_string(), format!("pg_catalog.{name}")])
            .collect();

        Self { unsupported_types }
    }
}
331
332struct RemoveUnsupportedTypesVisitor<'a> {
333    unsupported_types: &'a HashSet<String>,
334}
335
336impl VisitorMut for RemoveUnsupportedTypesVisitor<'_> {
337    type Break = ();
338
339    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
340        match expr {
341            // This is the key part: identify constants with type annotations.
342            Expr::TypedString(TypedString {
343                data_type,
344                value,
345                uses_odbc_syntax: _,
346            }) => {
347                if self
348                    .unsupported_types
349                    .contains(data_type.to_string().to_lowercase().as_str())
350                {
351                    *expr =
352                        Expr::Value(Value::SingleQuotedString(value.to_string()).with_empty_span());
353                }
354            }
355            Expr::Cast {
356                data_type,
357                expr: value,
358                ..
359            } => {
360                if self
361                    .unsupported_types
362                    .contains(data_type.to_string().to_lowercase().as_str())
363                {
364                    *expr = *value.clone();
365                }
366            }
367            // Add more match arms for other expression types (e.g., `Function`, `InList`) as needed.
368            _ => {}
369        }
370
371        ControlFlow::Continue(())
372    }
373}
374
375impl SqlStatementRewriteRule for RemoveUnsupportedTypes {
376    fn rewrite(&self, mut statement: Statement) -> Statement {
377        let mut visitor = RemoveUnsupportedTypesVisitor {
378            unsupported_types: &self.unsupported_types,
379        };
380        let _ = statement.visit(&mut visitor);
381        statement
382    }
383}
384
385/// Rewrite Postgres's ANY operator to array_contains
386#[derive(Debug)]
387pub struct RewriteArrayAnyAllOperation;
388
389struct RewriteArrayAnyAllOperationVisitor;
390
391impl RewriteArrayAnyAllOperationVisitor {
392    fn any_to_array_cofntains(&self, left: &Expr, right: &Expr) -> Expr {
393        let array = if let Expr::Value(ValueWithSpan {
394            value: Value::SingleQuotedString(array_literal),
395            ..
396        }) = right
397        {
398            let array_literal = array_literal.trim();
399            if array_literal.starts_with('{') && array_literal.ends_with('}') {
400                let items = array_literal.trim_matches(|c| c == '{' || c == '}' || c == ' ');
401                let items = items.split(',').map(|s| s.trim()).filter(|s| !s.is_empty());
402
403                // For now, we assume the data type is string
404                let elems = items
405                    .map(|s| {
406                        Expr::Value(Value::SingleQuotedString(s.to_string()).with_empty_span())
407                    })
408                    .collect();
409                Expr::Array(Array {
410                    elem: elems,
411                    named: true,
412                })
413            } else {
414                right.clone()
415            }
416        } else {
417            right.clone()
418        };
419
420        Expr::Function(Function {
421            name: ObjectName::from(vec![Ident::new("array_contains")]),
422            args: FunctionArguments::List(FunctionArgumentList {
423                args: vec![
424                    FunctionArg::Unnamed(FunctionArgExpr::Expr(array)),
425                    FunctionArg::Unnamed(FunctionArgExpr::Expr(left.clone())),
426                ],
427                duplicate_treatment: None,
428                clauses: vec![],
429            }),
430            uses_odbc_syntax: false,
431            parameters: FunctionArguments::None,
432            filter: None,
433            null_treatment: None,
434            over: None,
435            within_group: vec![],
436        })
437    }
438}
439
440impl VisitorMut for RewriteArrayAnyAllOperationVisitor {
441    type Break = ();
442
443    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
444        match expr {
445            Expr::AnyOp {
446                left,
447                compare_op,
448                right,
449                ..
450            } => match compare_op {
451                BinaryOperator::Eq => {
452                    *expr = self.any_to_array_cofntains(left.as_ref(), right.as_ref());
453                }
454                BinaryOperator::NotEq => {
455                    // TODO:left not equals to any element in array
456                }
457                _ => {}
458            },
459            Expr::AllOp {
460                left,
461                compare_op,
462                right,
463            } => match compare_op {
464                BinaryOperator::Eq => {
465                    // TODO: left equals to every element in array
466                }
467                BinaryOperator::NotEq => {
468                    *expr = Expr::UnaryOp {
469                        op: UnaryOperator::Not,
470                        expr: Box::new(self.any_to_array_cofntains(left.as_ref(), right.as_ref())),
471                    }
472                }
473                _ => {}
474            },
475            _ => {}
476        }
477
478        ControlFlow::Continue(())
479    }
480}
481
482impl SqlStatementRewriteRule for RewriteArrayAnyAllOperation {
483    fn rewrite(&self, mut s: Statement) -> Statement {
484        let mut visitor = RewriteArrayAnyAllOperationVisitor;
485
486        let _ = s.visit(&mut visitor);
487
488        s
489    }
490}
491
492/// Prepend qualifier to table_name
493///
494/// Postgres has pg_catalog in search_path by default so it allow access to
495/// `pg_namespace` without `pg_catalog.` qualifier
496#[derive(Debug)]
497pub struct PrependUnqualifiedPgTableName;
498
499struct PrependUnqualifiedPgTableNameVisitor;
500
501impl VisitorMut for PrependUnqualifiedPgTableNameVisitor {
502    type Break = ();
503
504    fn pre_visit_table_factor(
505        &mut self,
506        table_factor: &mut TableFactor,
507    ) -> ControlFlow<Self::Break> {
508        if let TableFactor::Table { name, args, .. } = table_factor {
509            // not a table function
510            if args.is_none() && name.0.len() == 1 {
511                if let ObjectNamePart::Identifier(ident) = &name.0[0] {
512                    if ident.value.starts_with("pg_") {
513                        *name = ObjectName(vec![
514                            ObjectNamePart::Identifier(Ident::new("pg_catalog")),
515                            name.0[0].clone(),
516                        ]);
517                    }
518                }
519            }
520        }
521
522        ControlFlow::Continue(())
523    }
524}
525
526impl SqlStatementRewriteRule for PrependUnqualifiedPgTableName {
527    fn rewrite(&self, mut s: Statement) -> Statement {
528        let mut visitor = PrependUnqualifiedPgTableNameVisitor;
529
530        let _ = s.visit(&mut visitor);
531        s
532    }
533}
534
535#[derive(Debug)]
536pub struct FixArrayLiteral;
537
538struct FixArrayLiteralVisitor;
539
540impl FixArrayLiteralVisitor {
541    fn is_string_type(dt: &DataType) -> bool {
542        matches!(
543            dt,
544            DataType::Text | DataType::Varchar(_) | DataType::Char(_) | DataType::String(_)
545        )
546    }
547}
548
549impl VisitorMut for FixArrayLiteralVisitor {
550    type Break = ();
551
552    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
553        if let Expr::Cast {
554            kind,
555            expr,
556            data_type,
557            ..
558        } = expr
559        {
560            if kind == &CastKind::DoubleColon {
561                if let DataType::Array(arr) = data_type {
562                    // cast some to
563                    if let Expr::Value(ValueWithSpan {
564                        value: Value::SingleQuotedString(array_literal),
565                        ..
566                    }) = expr.as_ref()
567                    {
568                        let items =
569                            array_literal.trim_matches(|c| c == '{' || c == '}' || c == ' ');
570                        let items = items.split(',').map(|s| s.trim()).filter(|s| !s.is_empty());
571
572                        let is_text = match arr {
573                            ArrayElemTypeDef::AngleBracket(dt) => Self::is_string_type(dt.as_ref()),
574                            ArrayElemTypeDef::SquareBracket(dt, _) => {
575                                Self::is_string_type(dt.as_ref())
576                            }
577                            ArrayElemTypeDef::Parenthesis(dt) => Self::is_string_type(dt.as_ref()),
578                            _ => false,
579                        };
580
581                        let elems = items
582                            .map(|s| {
583                                if is_text {
584                                    Expr::Value(
585                                        Value::SingleQuotedString(s.to_string()).with_empty_span(),
586                                    )
587                                } else {
588                                    Expr::Value(
589                                        Value::Number(s.to_string(), false).with_empty_span(),
590                                    )
591                                }
592                            })
593                            .collect();
594                        **expr = Expr::Array(Array {
595                            elem: elems,
596                            named: true,
597                        });
598                    }
599                }
600            }
601        }
602
603        ControlFlow::Continue(())
604    }
605}
606
607impl SqlStatementRewriteRule for FixArrayLiteral {
608    fn rewrite(&self, mut s: Statement) -> Statement {
609        let mut visitor = FixArrayLiteralVisitor;
610
611        let _ = s.visit(&mut visitor);
612        s
613    }
614}
615
616/// Remove qualifier from unsupported items
617///
618/// This rewriter removes qualifier from following items:
619/// 1. type cast: for example: `pg_catalog.text`
620/// 2. function name: for example: `pg_catalog.array_to_string`,
621/// 3. table function name
622#[derive(Debug)]
623pub struct RemoveQualifier;
624
625struct RemoveQualifierVisitor;
626
627impl VisitorMut for RemoveQualifierVisitor {
628    type Break = ();
629
630    fn pre_visit_table_factor(
631        &mut self,
632        table_factor: &mut TableFactor,
633    ) -> ControlFlow<Self::Break> {
634        // remove table function qualifier
635        if let TableFactor::Table { name, args, .. } = table_factor {
636            if args.is_some() {
637                //  multiple idents in name, which means it's a qualified table name
638                if name.0.len() > 1 {
639                    if let Some(last_ident) = name.0.pop() {
640                        *name = ObjectName(vec![last_ident]);
641                    }
642                }
643            }
644        }
645        ControlFlow::Continue(())
646    }
647
648    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
649        match expr {
650            Expr::Cast { data_type, .. } => {
651                // rewrite custom pg_catalog. qualified types
652                let data_type_str = data_type.to_string();
653                match data_type_str.as_str() {
654                    "pg_catalog.text" => {
655                        *data_type = DataType::Text;
656                    }
657                    "pg_catalog.int2[]" => {
658                        *data_type = DataType::Array(ArrayElemTypeDef::SquareBracket(
659                            Box::new(DataType::Int16),
660                            None,
661                        ));
662                    }
663                    _ => {}
664                }
665            }
666            Expr::Function(function) => {
667                // remove qualifier from pg_catalog.function
668                let name = &mut function.name;
669                if name.0.len() > 1 {
670                    if let Some(last_ident) = name.0.pop() {
671                        *name = ObjectName(vec![last_ident]);
672                    }
673                }
674            }
675
676            _ => {}
677        }
678        ControlFlow::Continue(())
679    }
680}
681
682impl SqlStatementRewriteRule for RemoveQualifier {
683    fn rewrite(&self, mut s: Statement) -> Statement {
684        let mut visitor = RemoveQualifierVisitor;
685
686        let _ = s.visit(&mut visitor);
687        s
688    }
689}
690
691/// Replace `current_user` with `session_user()`
692#[derive(Debug)]
693pub struct CurrentUserVariableToSessionUserFunctionCall;
694
695struct CurrentUserVariableToSessionUserFunctionCallVisitor;
696
697impl VisitorMut for CurrentUserVariableToSessionUserFunctionCallVisitor {
698    type Break = ();
699
700    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
701        if let Expr::Identifier(ident) = expr {
702            if ident.quote_style.is_none() && ident.value.to_lowercase() == "current_user" {
703                *expr = Expr::Function(Function {
704                    name: ObjectName::from(vec![Ident::new("session_user")]),
705                    args: FunctionArguments::None,
706                    uses_odbc_syntax: false,
707                    parameters: FunctionArguments::None,
708                    filter: None,
709                    null_treatment: None,
710                    over: None,
711                    within_group: vec![],
712                });
713            }
714        }
715
716        if let Expr::Function(func) = expr {
717            let fname = func
718                .name
719                .0
720                .iter()
721                .map(|ident| ident.to_string())
722                .collect::<Vec<String>>()
723                .join(".");
724            if fname.to_lowercase() == "current_user" {
725                func.name = ObjectName::from(vec![Ident::new("session_user")])
726            }
727        }
728
729        ControlFlow::Continue(())
730    }
731}
732
733impl SqlStatementRewriteRule for CurrentUserVariableToSessionUserFunctionCall {
734    fn rewrite(&self, mut s: Statement) -> Statement {
735        let mut visitor = CurrentUserVariableToSessionUserFunctionCallVisitor;
736
737        let _ = s.visit(&mut visitor);
738        s
739    }
740}
741
742/// Fix collate and regex calls
743#[derive(Debug)]
744pub struct FixCollate;
745
746struct FixCollateVisitor;
747
748impl VisitorMut for FixCollateVisitor {
749    type Break = ();
750
751    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
752        match expr {
753            Expr::Collate { expr: inner, .. } => {
754                *expr = inner.as_ref().clone();
755            }
756            Expr::BinaryOp { op, .. } => {
757                if let BinaryOperator::PGCustomBinaryOperator(ops) = op {
758                    if *ops == ["pg_catalog", "~"] {
759                        *op = BinaryOperator::PGRegexMatch;
760                    }
761                }
762            }
763            _ => {}
764        }
765
766        ControlFlow::Continue(())
767    }
768}
769
770impl SqlStatementRewriteRule for FixCollate {
771    fn rewrite(&self, mut s: Statement) -> Statement {
772        let mut visitor = FixCollateVisitor;
773
774        let _ = s.visit(&mut visitor);
775        s
776    }
777}
778
779/// Datafusion doesn't support subquery on projection
780#[derive(Debug)]
781pub struct RemoveSubqueryFromProjection;
782
783struct RemoveSubqueryFromProjectionVisitor;
784
785impl VisitorMut for RemoveSubqueryFromProjectionVisitor {
786    type Break = ();
787
788    fn pre_visit_query(&mut self, query: &mut Query) -> ControlFlow<Self::Break> {
789        if let SetExpr::Select(select) = query.body.as_mut() {
790            for projection in &mut select.projection {
791                match projection {
792                    SelectItem::UnnamedExpr(expr) => {
793                        if let Expr::Subquery(_) = expr {
794                            *expr = Expr::Value(Value::Null.with_empty_span());
795                        }
796                    }
797                    SelectItem::ExprWithAlias { expr, .. } => {
798                        if let Expr::Subquery(_) = expr {
799                            *expr = Expr::Value(Value::Null.with_empty_span());
800                        }
801                    }
802                    _ => {}
803                }
804            }
805        }
806
807        ControlFlow::Continue(())
808    }
809}
810
811impl SqlStatementRewriteRule for RemoveSubqueryFromProjection {
812    fn rewrite(&self, mut s: Statement) -> Statement {
813        let mut visitor = RemoveSubqueryFromProjectionVisitor;
814        let _ = s.visit(&mut visitor);
815
816        s
817    }
818}
819
820/// `select version()` should return column named `version` not `version()`
821#[derive(Debug)]
822pub struct FixVersionColumnName;
823
824struct FixVersionColumnNameVisitor;
825
826impl VisitorMut for FixVersionColumnNameVisitor {
827    type Break = ();
828
829    fn pre_visit_query(&mut self, query: &mut Query) -> ControlFlow<Self::Break> {
830        if let SetExpr::Select(select) = query.body.as_mut() {
831            for projection in &mut select.projection {
832                if let SelectItem::UnnamedExpr(Expr::Function(f)) = projection {
833                    if f.name.0.len() == 1 {
834                        if let ObjectNamePart::Identifier(part) = &f.name.0[0] {
835                            if part.value == "version" {
836                                if let FunctionArguments::List(args) = &f.args {
837                                    if args.args.is_empty() {
838                                        *projection = SelectItem::ExprWithAlias {
839                                            expr: Expr::Function(f.clone()),
840                                            alias: Ident::new("version"),
841                                        }
842                                    }
843                                }
844                            }
845                        }
846                    }
847                }
848            }
849        }
850
851        ControlFlow::Continue(())
852    }
853}
854
855impl SqlStatementRewriteRule for FixVersionColumnName {
856    fn rewrite(&self, mut s: Statement) -> Statement {
857        let mut visitor = FixVersionColumnNameVisitor;
858        let _ = s.visit(&mut visitor);
859
860        s
861    }
862}
863
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
    use datafusion::sql::sqlparser::parser::Parser;
    use datafusion::sql::sqlparser::parser::ParserError;
    use std::sync::Arc;

    /// Ordered list of rewrite rules applied by a test case.
    type RuleSet = Vec<Arc<dyn SqlStatementRewriteRule>>;

    /// Parses `sql` into statements using the PostgreSQL dialect.
    fn parse(sql: &str) -> Result<Vec<Statement>, ParserError> {
        Parser::parse_sql(&PostgreSqlDialect {}, sql)
    }

    /// Applies every rule in `rules`, in order, feeding each rule the output
    /// of the previous one.
    fn rewrite(s: Statement, rules: &[Arc<dyn SqlStatementRewriteRule>]) -> Statement {
        rules.iter().fold(s, |stmt, rule| rule.rewrite(stmt))
    }

    // Parses the first statement of `$orig`, runs it through `$rules` and
    // asserts that the rewritten statement renders as `$rewt`.
    macro_rules! assert_rewrite {
        ($rules:expr, $orig:expr, $rewt:expr) => {{
            let parsed = parse($orig).expect("Failed to parse").remove(0);
            let rewritten = rewrite(parsed, $rules);
            assert_eq!(rewritten.to_string(), $rewt);
        }};
    }

    #[test]
    fn test_alias_rewrite() {
        let rules: RuleSet = vec![Arc::new(AliasDuplicatedProjectionRewrite)];

        // Column qualified by the same table as the qualified wildcard.
        assert_rewrite!(
            &rules,
            "SELECT n.oid, n.* FROM pg_catalog.pg_namespace n",
            "SELECT n.oid AS __alias_oid, n.* FROM pg_catalog.pg_namespace AS n"
        );

        // Unqualified column next to an unqualified wildcard.
        assert_rewrite!(
            &rules,
            "SELECT oid, * FROM pg_catalog.pg_namespace",
            "SELECT oid AS __alias_oid, * FROM pg_catalog.pg_namespace"
        );

        // Column and wildcard use different qualifiers: left unchanged.
        assert_rewrite!(
            &rules,
            "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id",
            "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id"
        );

        // Same pattern inside a larger statement with a join and ORDER BY.
        assert_rewrite!(
            &rules,
            "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname",
            "SELECT n.oid AS __alias_oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY nspsname"
        );
    }

    #[test]
    fn test_qualifier_prepend() {
        let rules: RuleSet = vec![Arc::new(ResolveUnqualifiedIdentifer)];

        // Aliased table: unqualified columns in WHERE / ORDER BY get the alias.
        assert_rewrite!(
            &rules,
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE nspname = 'pg_catalog' ORDER BY nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
        );

        // No alias on the table: statement is left unchanged.
        assert_rewrite!(
            &rules,
            "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname",
            "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname"
        );

        assert_rewrite!(
            &rules,
            "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname",
            "SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY n.nspsname"
        );

        // `tabrelname` is a projection alias and stays unqualified in ORDER BY.
        assert_rewrite!(
            &rules,
            "SELECT i.*,i.indkey as keys,c.relname,c.relnamespace,c.relam,c.reltablespace,tc.relname as tabrelname,dsc.description FROM pg_catalog.pg_index i INNER JOIN pg_catalog.pg_class c ON c.oid=i.indexrelid INNER JOIN pg_catalog.pg_class tc ON tc.oid=i.indrelid LEFT OUTER JOIN pg_catalog.pg_description dsc ON i.indexrelid=dsc.objoid WHERE i.indrelid=1 ORDER BY tabrelname, c.relname",
            "SELECT i.*, i.indkey AS keys, c.relname, c.relnamespace, c.relam, c.reltablespace, tc.relname AS tabrelname, dsc.description FROM pg_catalog.pg_index AS i INNER JOIN pg_catalog.pg_class AS c ON c.oid = i.indexrelid INNER JOIN pg_catalog.pg_class AS tc ON tc.oid = i.indrelid LEFT OUTER JOIN pg_catalog.pg_description AS dsc ON i.indexrelid = dsc.objoid WHERE i.indrelid = 1 ORDER BY tabrelname, c.relname"
        );
    }

    #[test]
    fn test_remove_unsupported_types() {
        let rules: RuleSet = vec![
            Arc::new(RemoveQualifier),
            Arc::new(RemoveUnsupportedTypes::new()),
        ];

        // `::regclass` casts are dropped.
        assert_rewrite!(
            &rules,
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
        );

        assert_rewrite!(
            &rules,
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.oid = 1 AND n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.oid = 1 AND n.nspname = 'pg_catalog' ORDER BY n.nspname"
        );

        assert_rewrite!(
            &rules,
            "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace'::regclass ORDER BY nspname",
            "SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY nspname"
        );

        // Nothing to remove: only the rendering is normalised.
        assert_rewrite!(
            &rules,
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
        );

        // Chained qualified casts collapse to a plain `::TEXT`.
        assert_rewrite!(
            &rules,
            "SELECT c.relchecks, c.relkind, c.relhasindex, c.relhasrules, c.relhastriggers, c.relrowsecurity, c.relforcerowsecurity, false AS relhasoids, c.relispartition, '', c.reltablespace, CASE WHEN c.reloftype = 0 THEN '' ELSE c.reloftype::pg_catalog.regtype::pg_catalog.text END, c.relpersistence, c.relreplident, am.amname
    FROM pg_catalog.pg_class c
     LEFT JOIN pg_catalog.pg_class tc ON (c.reltoastrelid = tc.oid)
    LEFT JOIN pg_catalog.pg_am am ON (c.relam = am.oid)
    WHERE c.oid = '16386'",
            "SELECT c.relchecks, c.relkind, c.relhasindex, c.relhasrules, c.relhastriggers, c.relrowsecurity, c.relforcerowsecurity, false AS relhasoids, c.relispartition, '', c.reltablespace, CASE WHEN c.reloftype = 0 THEN '' ELSE c.reloftype::TEXT END, c.relpersistence, c.relreplident, am.amname FROM pg_catalog.pg_class AS c LEFT JOIN pg_catalog.pg_class AS tc ON (c.reltoastrelid = tc.oid) LEFT JOIN pg_catalog.pg_am AS am ON (c.relam = am.oid) WHERE c.oid = '16386'"
        );
    }

    #[test]
    fn test_any_to_array_contains() {
        let rules: RuleSet = vec![Arc::new(RewriteArrayAnyAllOperation)];

        assert_rewrite!(
            &rules,
            "SELECT a = ANY(current_schemas(true))",
            "SELECT array_contains(current_schemas(true), a)"
        );

        assert_rewrite!(
            &rules,
            "SELECT a <> ALL(current_schemas(true))",
            "SELECT NOT array_contains(current_schemas(true), a)"
        );

        // A `'{...}'` string literal on the right-hand side becomes an ARRAY.
        assert_rewrite!(
            &rules,
            "SELECT a = ANY('{r, l, e}')",
            "SELECT array_contains(ARRAY['r', 'l', 'e'], a)"
        );

        assert_rewrite!(
            &rules,
            "SELECT a FROM tbl WHERE a = ANY(current_schemas(true))",
            "SELECT a FROM tbl WHERE array_contains(current_schemas(true), a)"
        );
    }

    #[test]
    fn test_prepend_unqualified_table_name() {
        let rules: RuleSet = vec![Arc::new(PrependUnqualifiedPgTableName)];

        // Already qualified: unchanged.
        assert_rewrite!(
            &rules,
            "SELECT * FROM pg_catalog.pg_namespace",
            "SELECT * FROM pg_catalog.pg_namespace"
        );

        assert_rewrite!(
            &rules,
            "SELECT * FROM pg_namespace",
            "SELECT * FROM pg_catalog.pg_namespace"
        );

        // Only the unqualified table in the join is touched.
        assert_rewrite!(
            &rules,
            "SELECT typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN pg_namespace as ns ON ns.oid = oid",
            "SELECT typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN pg_catalog.pg_namespace AS ns ON ns.oid = oid"
        );
    }

    #[test]
    fn test_array_literal_fix() {
        let rules: RuleSet = vec![Arc::new(FixArrayLiteral)];

        assert_rewrite!(
            &rules,
            "SELECT '{a, abc}'::text[]",
            "SELECT ARRAY['a', 'abc']::TEXT[]"
        );

        assert_rewrite!(
            &rules,
            "SELECT '{1, 2}'::int[]",
            "SELECT ARRAY[1, 2]::INT[]"
        );

        assert_rewrite!(
            &rules,
            "SELECT '{t, f}'::bool[]",
            "SELECT ARRAY[t, f]::BOOL[]"
        );
    }

    #[test]
    fn test_remove_qualifier_from_table_function() {
        let rules: RuleSet = vec![Arc::new(RemoveQualifier)];

        assert_rewrite!(
            &rules,
            "SELECT * FROM pg_catalog.pg_get_keywords()",
            "SELECT * FROM pg_get_keywords()"
        );
    }

    #[test]
    fn test_current_user() {
        let rules: RuleSet = vec![Arc::new(CurrentUserVariableToSessionUserFunctionCall)];

        assert_rewrite!(&rules, "SELECT current_user", "SELECT session_user");

        assert_rewrite!(&rules, "SELECT CURRENT_USER", "SELECT session_user");

        // Also rewritten when used as a function argument.
        assert_rewrite!(
            &rules,
            "SELECT is_null(current_user)",
            "SELECT is_null(session_user)"
        );
    }

    #[test]
    fn test_collate_fix() {
        let rules: RuleSet = vec![Arc::new(FixCollate)];

        assert_rewrite!(
            &rules,
            "SELECT c.oid, c.relname FROM pg_catalog.pg_class c WHERE c.relname OPERATOR(pg_catalog.~) '^(tablename)$' COLLATE pg_catalog.default AND pg_catalog.pg_table_is_visible(c.oid) ORDER BY 2, 3;",
            "SELECT c.oid, c.relname FROM pg_catalog.pg_class AS c WHERE c.relname ~ '^(tablename)$' AND pg_catalog.pg_table_is_visible(c.oid) ORDER BY 2, 3"
        );
    }

    #[test]
    fn test_remove_subquery() {
        let rules: RuleSet = vec![Arc::new(RemoveSubqueryFromProjection)];

        // Both projected subqueries become NULL; an existing alias is kept.
        assert_rewrite!(
            &rules,
            "SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), (SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid, true) FROM pg_catalog.pg_attrdef d WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef), a.attnotnull, (SELECT c.collname FROM pg_catalog.pg_collation c, pg_catalog.pg_type t WHERE c.oid = a.attcollation AND t.oid = a.atttypid AND a.attcollation <> t.typcollation LIMIT 1) AS attcollation, a.attidentity, a.attgenerated FROM pg_catalog.pg_attribute a WHERE a.attrelid = '16384' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum;",
            "SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), NULL, a.attnotnull, NULL AS attcollation, a.attidentity, a.attgenerated FROM pg_catalog.pg_attribute AS a WHERE a.attrelid = '16384' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum"
        );
    }

    #[test]
    fn test_version_rewrite() {
        let rules: RuleSet = vec![Arc::new(FixVersionColumnName)];

        assert_rewrite!(&rules, "SELECT version()", "SELECT version() AS version");

        // Make sure we don't rewrite things we should leave alone
        assert_rewrite!(&rules, "SELECT version() as foo", "SELECT version() AS foo");
        assert_rewrite!(&rules, "SELECT version(foo)", "SELECT version(foo)");
        assert_rewrite!(&rules, "SELECT foo.version()", "SELECT foo.version()");
    }
}