datafusion_pg_catalog/sql/
rules.rs

1use std::collections::HashSet;
2use std::fmt::Debug;
3use std::ops::ControlFlow;
4
5use datafusion::sql::sqlparser::ast::Array;
6use datafusion::sql::sqlparser::ast::ArrayElemTypeDef;
7use datafusion::sql::sqlparser::ast::BinaryOperator;
8use datafusion::sql::sqlparser::ast::CastKind;
9use datafusion::sql::sqlparser::ast::DataType;
10use datafusion::sql::sqlparser::ast::Expr;
11use datafusion::sql::sqlparser::ast::Function;
12use datafusion::sql::sqlparser::ast::FunctionArg;
13use datafusion::sql::sqlparser::ast::FunctionArgExpr;
14use datafusion::sql::sqlparser::ast::FunctionArgumentList;
15use datafusion::sql::sqlparser::ast::FunctionArguments;
16use datafusion::sql::sqlparser::ast::Ident;
17use datafusion::sql::sqlparser::ast::ObjectName;
18use datafusion::sql::sqlparser::ast::ObjectNamePart;
19use datafusion::sql::sqlparser::ast::OrderByKind;
20use datafusion::sql::sqlparser::ast::Query;
21use datafusion::sql::sqlparser::ast::Select;
22use datafusion::sql::sqlparser::ast::SelectItem;
23use datafusion::sql::sqlparser::ast::SelectItemQualifiedWildcardKind;
24use datafusion::sql::sqlparser::ast::SetExpr;
25use datafusion::sql::sqlparser::ast::Statement;
26use datafusion::sql::sqlparser::ast::TableFactor;
27use datafusion::sql::sqlparser::ast::TableWithJoins;
28use datafusion::sql::sqlparser::ast::UnaryOperator;
29use datafusion::sql::sqlparser::ast::Value;
30use datafusion::sql::sqlparser::ast::ValueWithSpan;
31use datafusion::sql::sqlparser::ast::VisitMut;
32use datafusion::sql::sqlparser::ast::VisitorMut;
33
/// A rewrite pass applied to a parsed SQL [`Statement`].
///
/// Implementors take ownership of the statement, optionally modify its AST,
/// and return it. The `Send + Sync + Debug` bounds let rules be stored as
/// shared trait objects (for example `Arc<dyn SqlStatementRewriteRule>`).
pub trait SqlStatementRewriteRule: Send + Sync + Debug {
    /// Consumes `s` and returns the (possibly rewritten) statement.
    fn rewrite(&self, s: Statement) -> Statement;
}
37
38/// Rewrite rule for adding alias to duplicated projection
39///
40/// This rule is to deal with sql like `SELECT n.oid, n.* FROM n`, which is a
41/// valid statement in postgres. But datafusion treat it as illegal because of
42/// duplicated column oid in projection.
43///
44/// This rule will add alias to column, when there is a wildcard found in
45/// projection.
46#[derive(Debug)]
47pub struct AliasDuplicatedProjectionRewrite;
48
49impl AliasDuplicatedProjectionRewrite {
50    // Rewrites a SELECT statement to alias explicit columns from the same table as a qualified wildcard.
51    fn rewrite_select_with_alias(select: &mut Box<Select>) {
52        // 1. Collect all table aliases from qualified wildcards.
53        let mut wildcard_tables = Vec::new();
54        let mut has_simple_wildcard = false;
55        for p in &select.projection {
56            match p {
57                SelectItem::QualifiedWildcard(name, _) => match name {
58                    SelectItemQualifiedWildcardKind::ObjectName(objname) => {
59                        // for n.oid,
60                        let idents = objname
61                            .0
62                            .iter()
63                            .map(|v| v.as_ident().unwrap().value.clone())
64                            .collect::<Vec<_>>()
65                            .join(".");
66
67                        wildcard_tables.push(idents);
68                    }
69                    SelectItemQualifiedWildcardKind::Expr(_expr) => {
70                        // FIXME:
71                    }
72                },
73                SelectItem::Wildcard(_) => {
74                    has_simple_wildcard = true;
75                }
76                _ => {}
77            }
78        }
79
80        // If there are no qualified wildcards, there's nothing to do.
81        if wildcard_tables.is_empty() && !has_simple_wildcard {
82            return;
83        }
84
85        // 2. Rewrite the projection, adding aliases to matching columns.
86        let mut new_projection = vec![];
87        for p in select.projection.drain(..) {
88            match p {
89                SelectItem::UnnamedExpr(expr) => {
90                    let alias_partial = match &expr {
91                        // Case for `oid` (unqualified identifier)
92                        Expr::Identifier(ident) => Some(ident.clone()),
93                        // Case for `n.oid` (compound identifier)
94                        Expr::CompoundIdentifier(idents) => {
95                            // compare every ident but the last
96                            if idents.len() > 1 {
97                                let table_name = &idents[..idents.len() - 1]
98                                    .iter()
99                                    .map(|i| i.value.clone())
100                                    .collect::<Vec<_>>()
101                                    .join(".");
102                                if wildcard_tables.iter().any(|name| name == table_name) {
103                                    Some(idents[idents.len() - 1].clone())
104                                } else {
105                                    None
106                                }
107                            } else {
108                                None
109                            }
110                        }
111                        _ => None,
112                    };
113
114                    if let Some(name) = alias_partial {
115                        let alias = format!("__alias_{name}");
116                        new_projection.push(SelectItem::ExprWithAlias {
117                            expr,
118                            alias: Ident::new(alias),
119                        });
120                    } else {
121                        new_projection.push(SelectItem::UnnamedExpr(expr));
122                    }
123                }
124                // Preserve existing aliases and wildcards.
125                _ => new_projection.push(p),
126            }
127        }
128        select.projection = new_projection;
129    }
130}
131
132impl SqlStatementRewriteRule for AliasDuplicatedProjectionRewrite {
133    fn rewrite(&self, mut statement: Statement) -> Statement {
134        if let Statement::Query(query) = &mut statement {
135            if let SetExpr::Select(select) = query.body.as_mut() {
136                Self::rewrite_select_with_alias(select);
137            }
138        }
139
140        statement
141    }
142}
143
144/// Prepend qualifier for order by or filter when there is qualified wildcard
145///
146/// Postgres allows unqualified identifier in ORDER BY and FILTER but it's not
147/// accepted by datafusion.
148#[derive(Debug)]
149pub struct ResolveUnqualifiedIdentifer;
150
151impl ResolveUnqualifiedIdentifer {
152    fn rewrite_unqualified_identifiers(query: &mut Box<Query>) {
153        if let SetExpr::Select(select) = query.body.as_mut() {
154            // Step 1: Find all table aliases from FROM and JOIN clauses.
155            let table_aliases = Self::get_table_aliases(&select.from);
156
157            // Step 2: Check for a single qualified wildcard in the projection.
158            let qualified_wildcard_alias = Self::get_qualified_wildcard_alias(&select.projection);
159            if qualified_wildcard_alias.is_none() || table_aliases.is_empty() {
160                return; // Conditions not met.
161            }
162
163            let wildcard_alias = qualified_wildcard_alias.unwrap();
164
165            // Step 2.5: Collect all projection aliases to avoid rewriting them
166            let projection_aliases = Self::get_projection_aliases(&select.projection);
167
168            // Step 3: Rewrite expressions in the WHERE and ORDER BY clauses.
169            if let Some(selection) = &mut select.selection {
170                Self::rewrite_expr(
171                    selection,
172                    &wildcard_alias,
173                    &table_aliases,
174                    &projection_aliases,
175                );
176            }
177
178            if let Some(OrderByKind::Expressions(order_by_exprs)) =
179                query.order_by.as_mut().map(|o| &mut o.kind)
180            {
181                for order_by_expr in order_by_exprs {
182                    Self::rewrite_expr(
183                        &mut order_by_expr.expr,
184                        &wildcard_alias,
185                        &table_aliases,
186                        &projection_aliases,
187                    );
188                }
189            }
190        }
191    }
192
193    fn get_table_aliases(tables: &[TableWithJoins]) -> HashSet<String> {
194        let mut aliases = HashSet::new();
195        for table_with_joins in tables {
196            if let TableFactor::Table {
197                alias: Some(alias), ..
198            } = &table_with_joins.relation
199            {
200                aliases.insert(alias.name.value.clone());
201            }
202            for join in &table_with_joins.joins {
203                if let TableFactor::Table {
204                    alias: Some(alias), ..
205                } = &join.relation
206                {
207                    aliases.insert(alias.name.value.clone());
208                }
209            }
210        }
211        aliases
212    }
213
214    fn get_qualified_wildcard_alias(projection: &[SelectItem]) -> Option<String> {
215        let mut qualified_wildcards = projection
216            .iter()
217            .filter_map(|item| {
218                if let SelectItem::QualifiedWildcard(
219                    SelectItemQualifiedWildcardKind::ObjectName(objname),
220                    _,
221                ) = item
222                {
223                    Some(
224                        objname
225                            .0
226                            .iter()
227                            .map(|v| v.as_ident().unwrap().value.clone())
228                            .collect::<Vec<_>>()
229                            .join("."),
230                    )
231                } else {
232                    None
233                }
234            })
235            .collect::<Vec<_>>();
236
237        if qualified_wildcards.len() == 1 {
238            Some(qualified_wildcards.remove(0))
239        } else {
240            None
241        }
242    }
243
244    fn get_projection_aliases(projection: &[SelectItem]) -> HashSet<String> {
245        let mut aliases = HashSet::new();
246        for item in projection {
247            match item {
248                SelectItem::ExprWithAlias { alias, .. } => {
249                    aliases.insert(alias.value.clone());
250                }
251                SelectItem::UnnamedExpr(Expr::Identifier(ident)) => {
252                    aliases.insert(ident.value.clone());
253                }
254                _ => {}
255            }
256        }
257        aliases
258    }
259
260    fn rewrite_expr(
261        expr: &mut Expr,
262        wildcard_alias: &str,
263        table_aliases: &HashSet<String>,
264        projection_aliases: &HashSet<String>,
265    ) {
266        match expr {
267            Expr::Identifier(ident) => {
268                // If the identifier is not a table alias itself and not already aliased in projection, rewrite it.
269                if !table_aliases.contains(&ident.value)
270                    && !projection_aliases.contains(&ident.value)
271                {
272                    *expr = Expr::CompoundIdentifier(vec![
273                        Ident::new(wildcard_alias.to_string()),
274                        ident.clone(),
275                    ]);
276                }
277            }
278            Expr::BinaryOp { left, right, .. } => {
279                Self::rewrite_expr(left, wildcard_alias, table_aliases, projection_aliases);
280                Self::rewrite_expr(right, wildcard_alias, table_aliases, projection_aliases);
281            }
282            // Add more cases for other expression types as needed (e.g., `InList`, `Between`, etc.)
283            _ => {}
284        }
285    }
286}
287
288impl SqlStatementRewriteRule for ResolveUnqualifiedIdentifer {
289    fn rewrite(&self, mut statement: Statement) -> Statement {
290        if let Statement::Query(query) = &mut statement {
291            Self::rewrite_unqualified_identifiers(query);
292        }
293
294        statement
295    }
296}
297
/// Remove datafusion unsupported type annotations
/// it also removes pg_catalog as qualifier
#[derive(Debug)]
pub struct RemoveUnsupportedTypes {
    /// Lower-case type names to strip, each stored both bare and with a
    /// `pg_catalog.` prefix.
    unsupported_types: HashSet<String>,
}

impl Default for RemoveUnsupportedTypes {
    /// Same as [`RemoveUnsupportedTypes::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl RemoveUnsupportedTypes {
    /// Builds the rule with the built-in list of unsupported Postgres types.
    pub fn new() -> Self {
        const BASE_TYPES: [&str; 6] = [
            "regclass",
            "regproc",
            "regtype",
            "regtype[]",
            "regnamespace",
            "oid",
        ];

        // Register every type under its bare and its qualified spelling.
        let unsupported_types = BASE_TYPES
            .iter()
            .flat_map(|name| [(*name).to_owned(), format!("pg_catalog.{name}")])
            .collect();

        Self { unsupported_types }
    }
}
330
331struct RemoveUnsupportedTypesVisitor<'a> {
332    unsupported_types: &'a HashSet<String>,
333}
334
335impl VisitorMut for RemoveUnsupportedTypesVisitor<'_> {
336    type Break = ();
337
338    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
339        match expr {
340            // This is the key part: identify constants with type annotations.
341            Expr::TypedString { value, data_type } => {
342                if self
343                    .unsupported_types
344                    .contains(data_type.to_string().to_lowercase().as_str())
345                {
346                    *expr =
347                        Expr::Value(Value::SingleQuotedString(value.to_string()).with_empty_span());
348                }
349            }
350            Expr::Cast {
351                data_type,
352                expr: value,
353                ..
354            } => {
355                if self
356                    .unsupported_types
357                    .contains(data_type.to_string().to_lowercase().as_str())
358                {
359                    *expr = *value.clone();
360                }
361            }
362            // Add more match arms for other expression types (e.g., `Function`, `InList`) as needed.
363            _ => {}
364        }
365
366        ControlFlow::Continue(())
367    }
368}
369
370impl SqlStatementRewriteRule for RemoveUnsupportedTypes {
371    fn rewrite(&self, mut statement: Statement) -> Statement {
372        let mut visitor = RemoveUnsupportedTypesVisitor {
373            unsupported_types: &self.unsupported_types,
374        };
375        let _ = statement.visit(&mut visitor);
376        statement
377    }
378}
379
380/// Rewrite Postgres's ANY operator to array_contains
381#[derive(Debug)]
382pub struct RewriteArrayAnyAllOperation;
383
384struct RewriteArrayAnyAllOperationVisitor;
385
386impl RewriteArrayAnyAllOperationVisitor {
387    fn any_to_array_cofntains(&self, left: &Expr, right: &Expr) -> Expr {
388        let array = if let Expr::Value(ValueWithSpan {
389            value: Value::SingleQuotedString(array_literal),
390            ..
391        }) = right
392        {
393            let array_literal = array_literal.trim();
394            if array_literal.starts_with('{') && array_literal.ends_with('}') {
395                let items = array_literal.trim_matches(|c| c == '{' || c == '}' || c == ' ');
396                let items = items.split(',').map(|s| s.trim()).filter(|s| !s.is_empty());
397
398                // For now, we assume the data type is string
399                let elems = items
400                    .map(|s| {
401                        Expr::Value(Value::SingleQuotedString(s.to_string()).with_empty_span())
402                    })
403                    .collect();
404                Expr::Array(Array {
405                    elem: elems,
406                    named: true,
407                })
408            } else {
409                right.clone()
410            }
411        } else {
412            right.clone()
413        };
414
415        Expr::Function(Function {
416            name: ObjectName::from(vec![Ident::new("array_contains")]),
417            args: FunctionArguments::List(FunctionArgumentList {
418                args: vec![
419                    FunctionArg::Unnamed(FunctionArgExpr::Expr(array)),
420                    FunctionArg::Unnamed(FunctionArgExpr::Expr(left.clone())),
421                ],
422                duplicate_treatment: None,
423                clauses: vec![],
424            }),
425            uses_odbc_syntax: false,
426            parameters: FunctionArguments::None,
427            filter: None,
428            null_treatment: None,
429            over: None,
430            within_group: vec![],
431        })
432    }
433}
434
435impl VisitorMut for RewriteArrayAnyAllOperationVisitor {
436    type Break = ();
437
438    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
439        match expr {
440            Expr::AnyOp {
441                left,
442                compare_op,
443                right,
444                ..
445            } => match compare_op {
446                BinaryOperator::Eq => {
447                    *expr = self.any_to_array_cofntains(left.as_ref(), right.as_ref());
448                }
449                BinaryOperator::NotEq => {
450                    // TODO:left not equals to any element in array
451                }
452                _ => {}
453            },
454            Expr::AllOp {
455                left,
456                compare_op,
457                right,
458            } => match compare_op {
459                BinaryOperator::Eq => {
460                    // TODO: left equals to every element in array
461                }
462                BinaryOperator::NotEq => {
463                    *expr = Expr::UnaryOp {
464                        op: UnaryOperator::Not,
465                        expr: Box::new(self.any_to_array_cofntains(left.as_ref(), right.as_ref())),
466                    }
467                }
468                _ => {}
469            },
470            _ => {}
471        }
472
473        ControlFlow::Continue(())
474    }
475}
476
477impl SqlStatementRewriteRule for RewriteArrayAnyAllOperation {
478    fn rewrite(&self, mut s: Statement) -> Statement {
479        let mut visitor = RewriteArrayAnyAllOperationVisitor;
480
481        let _ = s.visit(&mut visitor);
482
483        s
484    }
485}
486
487/// Prepend qualifier to table_name
488///
489/// Postgres has pg_catalog in search_path by default so it allow access to
490/// `pg_namespace` without `pg_catalog.` qualifier
491#[derive(Debug)]
492pub struct PrependUnqualifiedPgTableName;
493
494struct PrependUnqualifiedPgTableNameVisitor;
495
496impl VisitorMut for PrependUnqualifiedPgTableNameVisitor {
497    type Break = ();
498
499    fn pre_visit_table_factor(
500        &mut self,
501        table_factor: &mut TableFactor,
502    ) -> ControlFlow<Self::Break> {
503        if let TableFactor::Table { name, args, .. } = table_factor {
504            // not a table function
505            if args.is_none() && name.0.len() == 1 {
506                if let ObjectNamePart::Identifier(ident) = &name.0[0] {
507                    if ident.value.starts_with("pg_") {
508                        *name = ObjectName(vec![
509                            ObjectNamePart::Identifier(Ident::new("pg_catalog")),
510                            name.0[0].clone(),
511                        ]);
512                    }
513                }
514            }
515        }
516
517        ControlFlow::Continue(())
518    }
519}
520
521impl SqlStatementRewriteRule for PrependUnqualifiedPgTableName {
522    fn rewrite(&self, mut s: Statement) -> Statement {
523        let mut visitor = PrependUnqualifiedPgTableNameVisitor;
524
525        let _ = s.visit(&mut visitor);
526        s
527    }
528}
529
530#[derive(Debug)]
531pub struct FixArrayLiteral;
532
533struct FixArrayLiteralVisitor;
534
535impl FixArrayLiteralVisitor {
536    fn is_string_type(dt: &DataType) -> bool {
537        matches!(
538            dt,
539            DataType::Text | DataType::Varchar(_) | DataType::Char(_) | DataType::String(_)
540        )
541    }
542}
543
544impl VisitorMut for FixArrayLiteralVisitor {
545    type Break = ();
546
547    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
548        if let Expr::Cast {
549            kind,
550            expr,
551            data_type,
552            ..
553        } = expr
554        {
555            if kind == &CastKind::DoubleColon {
556                if let DataType::Array(arr) = data_type {
557                    // cast some to
558                    if let Expr::Value(ValueWithSpan {
559                        value: Value::SingleQuotedString(array_literal),
560                        ..
561                    }) = expr.as_ref()
562                    {
563                        let items =
564                            array_literal.trim_matches(|c| c == '{' || c == '}' || c == ' ');
565                        let items = items.split(',').map(|s| s.trim()).filter(|s| !s.is_empty());
566
567                        let is_text = match arr {
568                            ArrayElemTypeDef::AngleBracket(dt) => Self::is_string_type(dt.as_ref()),
569                            ArrayElemTypeDef::SquareBracket(dt, _) => {
570                                Self::is_string_type(dt.as_ref())
571                            }
572                            ArrayElemTypeDef::Parenthesis(dt) => Self::is_string_type(dt.as_ref()),
573                            _ => false,
574                        };
575
576                        let elems = items
577                            .map(|s| {
578                                if is_text {
579                                    Expr::Value(
580                                        Value::SingleQuotedString(s.to_string()).with_empty_span(),
581                                    )
582                                } else {
583                                    Expr::Value(
584                                        Value::Number(s.to_string(), false).with_empty_span(),
585                                    )
586                                }
587                            })
588                            .collect();
589                        *expr = Box::new(Expr::Array(Array {
590                            elem: elems,
591                            named: true,
592                        }));
593                    }
594                }
595            }
596        }
597
598        ControlFlow::Continue(())
599    }
600}
601
602impl SqlStatementRewriteRule for FixArrayLiteral {
603    fn rewrite(&self, mut s: Statement) -> Statement {
604        let mut visitor = FixArrayLiteralVisitor;
605
606        let _ = s.visit(&mut visitor);
607        s
608    }
609}
610
611/// Remove qualifier from unsupported items
612///
613/// This rewriter removes qualifier from following items:
614/// 1. type cast: for example: `pg_catalog.text`
615/// 2. function name: for example: `pg_catalog.array_to_string`,
616/// 3. table function name
617#[derive(Debug)]
618pub struct RemoveQualifier;
619
620struct RemoveQualifierVisitor;
621
622impl VisitorMut for RemoveQualifierVisitor {
623    type Break = ();
624
625    fn pre_visit_table_factor(
626        &mut self,
627        table_factor: &mut TableFactor,
628    ) -> ControlFlow<Self::Break> {
629        // remove table function qualifier
630        if let TableFactor::Table { name, args, .. } = table_factor {
631            if args.is_some() {
632                //  multiple idents in name, which means it's a qualified table name
633                if name.0.len() > 1 {
634                    if let Some(last_ident) = name.0.pop() {
635                        *name = ObjectName(vec![last_ident]);
636                    }
637                }
638            }
639        }
640        ControlFlow::Continue(())
641    }
642
643    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
644        match expr {
645            Expr::Cast { data_type, .. } => {
646                // rewrite custom pg_catalog. qualified types
647                let data_type_str = data_type.to_string();
648                match data_type_str.as_str() {
649                    "pg_catalog.text" => {
650                        *data_type = DataType::Text;
651                    }
652                    "pg_catalog.int2[]" => {
653                        *data_type = DataType::Array(ArrayElemTypeDef::SquareBracket(
654                            Box::new(DataType::Int16),
655                            None,
656                        ));
657                    }
658                    _ => {}
659                }
660            }
661            Expr::Function(function) => {
662                // remove qualifier from pg_catalog.function
663                let name = &mut function.name;
664                if name.0.len() > 1 {
665                    if let Some(last_ident) = name.0.pop() {
666                        *name = ObjectName(vec![last_ident]);
667                    }
668                }
669            }
670
671            _ => {}
672        }
673        ControlFlow::Continue(())
674    }
675}
676
677impl SqlStatementRewriteRule for RemoveQualifier {
678    fn rewrite(&self, mut s: Statement) -> Statement {
679        let mut visitor = RemoveQualifierVisitor;
680
681        let _ = s.visit(&mut visitor);
682        s
683    }
684}
685
686/// Replace `current_user` with `session_user()`
687#[derive(Debug)]
688pub struct CurrentUserVariableToSessionUserFunctionCall;
689
690struct CurrentUserVariableToSessionUserFunctionCallVisitor;
691
692impl VisitorMut for CurrentUserVariableToSessionUserFunctionCallVisitor {
693    type Break = ();
694
695    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
696        if let Expr::Identifier(ident) = expr {
697            if ident.quote_style.is_none() && ident.value.to_lowercase() == "current_user" {
698                *expr = Expr::Function(Function {
699                    name: ObjectName::from(vec![Ident::new("session_user")]),
700                    args: FunctionArguments::None,
701                    uses_odbc_syntax: false,
702                    parameters: FunctionArguments::None,
703                    filter: None,
704                    null_treatment: None,
705                    over: None,
706                    within_group: vec![],
707                });
708            }
709        }
710
711        if let Expr::Function(func) = expr {
712            let fname = func
713                .name
714                .0
715                .iter()
716                .map(|ident| ident.to_string())
717                .collect::<Vec<String>>()
718                .join(".");
719            if fname.to_lowercase() == "current_user" {
720                func.name = ObjectName::from(vec![Ident::new("session_user")])
721            }
722        }
723
724        ControlFlow::Continue(())
725    }
726}
727
728impl SqlStatementRewriteRule for CurrentUserVariableToSessionUserFunctionCall {
729    fn rewrite(&self, mut s: Statement) -> Statement {
730        let mut visitor = CurrentUserVariableToSessionUserFunctionCallVisitor;
731
732        let _ = s.visit(&mut visitor);
733        s
734    }
735}
736
737/// Fix collate and regex calls
738#[derive(Debug)]
739pub struct FixCollate;
740
741struct FixCollateVisitor;
742
743impl VisitorMut for FixCollateVisitor {
744    type Break = ();
745
746    fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
747        match expr {
748            Expr::Collate { expr: inner, .. } => {
749                *expr = inner.as_ref().clone();
750            }
751            Expr::BinaryOp { op, .. } => {
752                if let BinaryOperator::PGCustomBinaryOperator(ops) = op {
753                    if *ops == ["pg_catalog", "~"] {
754                        *op = BinaryOperator::PGRegexMatch;
755                    }
756                }
757            }
758            _ => {}
759        }
760
761        ControlFlow::Continue(())
762    }
763}
764
765impl SqlStatementRewriteRule for FixCollate {
766    fn rewrite(&self, mut s: Statement) -> Statement {
767        let mut visitor = FixCollateVisitor;
768
769        let _ = s.visit(&mut visitor);
770        s
771    }
772}
773
774/// Datafusion doesn't support subquery on projection
775#[derive(Debug)]
776pub struct RemoveSubqueryFromProjection;
777
778struct RemoveSubqueryFromProjectionVisitor;
779
780impl VisitorMut for RemoveSubqueryFromProjectionVisitor {
781    type Break = ();
782
783    fn pre_visit_query(&mut self, query: &mut Query) -> ControlFlow<Self::Break> {
784        if let SetExpr::Select(select) = query.body.as_mut() {
785            for projection in &mut select.projection {
786                match projection {
787                    SelectItem::UnnamedExpr(expr) => {
788                        if let Expr::Subquery(_) = expr {
789                            *expr = Expr::Value(Value::Null.with_empty_span());
790                        }
791                    }
792                    SelectItem::ExprWithAlias { expr, .. } => {
793                        if let Expr::Subquery(_) = expr {
794                            *expr = Expr::Value(Value::Null.with_empty_span());
795                        }
796                    }
797                    _ => {}
798                }
799            }
800        }
801
802        ControlFlow::Continue(())
803    }
804}
805
806impl SqlStatementRewriteRule for RemoveSubqueryFromProjection {
807    fn rewrite(&self, mut s: Statement) -> Statement {
808        let mut visitor = RemoveSubqueryFromProjectionVisitor;
809        let _ = s.visit(&mut visitor);
810
811        s
812    }
813}
814
815/// `select version()` should return column named `version` not `version()`
816#[derive(Debug)]
817pub struct FixVersionColumnName;
818
819struct FixVersionColumnNameVisitor;
820
821impl VisitorMut for FixVersionColumnNameVisitor {
822    type Break = ();
823
824    fn pre_visit_query(&mut self, query: &mut Query) -> ControlFlow<Self::Break> {
825        if let SetExpr::Select(select) = query.body.as_mut() {
826            for projection in &mut select.projection {
827                if let SelectItem::UnnamedExpr(Expr::Function(f)) = projection {
828                    if f.name.0.len() == 1 {
829                        if let ObjectNamePart::Identifier(part) = &f.name.0[0] {
830                            if part.value == "version" {
831                                if let FunctionArguments::List(args) = &f.args {
832                                    if args.args.is_empty() {
833                                        *projection = SelectItem::ExprWithAlias {
834                                            expr: Expr::Function(f.clone()),
835                                            alias: Ident::new("version"),
836                                        }
837                                    }
838                                }
839                            }
840                        }
841                    }
842                }
843            }
844        }
845
846        ControlFlow::Continue(())
847    }
848}
849
850impl SqlStatementRewriteRule for FixVersionColumnName {
851    fn rewrite(&self, mut s: Statement) -> Statement {
852        let mut visitor = FixVersionColumnNameVisitor;
853        let _ = s.visit(&mut visitor);
854
855        s
856    }
857}
858
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
    use datafusion::sql::sqlparser::parser::Parser;
    use datafusion::sql::sqlparser::parser::ParserError;
    use std::sync::Arc;

    /// Parses `sql` with the PostgreSQL dialect.
    fn parse(sql: &str) -> Result<Vec<Statement>, ParserError> {
        Parser::parse_sql(&PostgreSqlDialect {}, sql)
    }

    /// Threads the statement through every rule, in slice order.
    fn rewrite(s: Statement, rules: &[Arc<dyn SqlStatementRewriteRule>]) -> Statement {
        rules.iter().fold(s, |stmt, rule| rule.rewrite(stmt))
    }

    /// Parses `$orig`, applies `$rules` to its first statement and asserts
    /// that the result renders as `$rewt`.
    macro_rules! assert_rewrite {
        ($rules:expr, $orig:expr, $rewt:expr) => {{
            let parsed = parse($orig).expect("Failed to parse").remove(0);
            let rewritten = rewrite(parsed, $rules);
            assert_eq!(rewritten.to_string(), $rewt);
        }};
    }

    #[test]
    fn test_alias_rewrite() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(AliasDuplicatedProjectionRewrite)];

        // Column and wildcard share the same qualifier.
        assert_rewrite!(
            &rules,
            "SELECT n.oid, n.* FROM pg_catalog.pg_namespace n",
            "SELECT n.oid AS __alias_oid, n.* FROM pg_catalog.pg_namespace AS n"
        );

        // Unqualified column next to an unqualified wildcard.
        assert_rewrite!(
            &rules,
            "SELECT oid, * FROM pg_catalog.pg_namespace",
            "SELECT oid AS __alias_oid, * FROM pg_catalog.pg_namespace"
        );

        // Column and wildcard use different qualifiers: nothing changes.
        assert_rewrite!(
            &rules,
            "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id",
            "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id"
        );

        // Same rewrite inside a statement with a join and ORDER BY.
        assert_rewrite!(
            &rules,
            "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname",
            "SELECT n.oid AS __alias_oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY nspsname"
        );
    }

    #[test]
    fn test_qualifier_prepend() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(ResolveUnqualifiedIdentifer)];

        // Aliased table: bare identifiers pick up the alias.
        assert_rewrite!(
            &rules,
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE nspname = 'pg_catalog' ORDER BY nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
        );

        // No alias on the table: statement is unchanged.
        assert_rewrite!(
            &rules,
            "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname",
            "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname"
        );

        // With a join, the bare ORDER BY column is qualified with `n`.
        assert_rewrite!(
            &rules,
            "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname",
            "SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY n.nspsname"
        );

        // ORDER BY refers to a projection alias: it stays unqualified.
        assert_rewrite!(
            &rules,
            "SELECT i.*,i.indkey as keys,c.relname,c.relnamespace,c.relam,c.reltablespace,tc.relname as tabrelname,dsc.description FROM pg_catalog.pg_index i INNER JOIN pg_catalog.pg_class c ON c.oid=i.indexrelid INNER JOIN pg_catalog.pg_class tc ON tc.oid=i.indrelid LEFT OUTER JOIN pg_catalog.pg_description dsc ON i.indexrelid=dsc.objoid WHERE i.indrelid=1 ORDER BY tabrelname, c.relname",
            "SELECT i.*, i.indkey AS keys, c.relname, c.relnamespace, c.relam, c.reltablespace, tc.relname AS tabrelname, dsc.description FROM pg_catalog.pg_index AS i INNER JOIN pg_catalog.pg_class AS c ON c.oid = i.indexrelid INNER JOIN pg_catalog.pg_class AS tc ON tc.oid = i.indrelid LEFT OUTER JOIN pg_catalog.pg_description AS dsc ON i.indexrelid = dsc.objoid WHERE i.indrelid = 1 ORDER BY tabrelname, c.relname"
        );
    }

    #[test]
    fn test_remove_unsupported_types() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![
            Arc::new(RemoveQualifier),
            Arc::new(RemoveUnsupportedTypes::new()),
        ];

        // A `::regclass` cast in WHERE is dropped.
        assert_rewrite!(
            &rules,
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
        );

        // The cast is also dropped when it sits inside an AND chain.
        assert_rewrite!(
            &rules,
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.oid = 1 AND n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.oid = 1 AND n.nspname = 'pg_catalog' ORDER BY n.nspname"
        );

        // ... and inside a join condition.
        assert_rewrite!(
            &rules,
            "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace'::regclass ORDER BY nspname",
            "SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY nspname"
        );

        // Nothing to remove: only the rendering is normalised.
        assert_rewrite!(
            &rules,
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
        );

        // Chained casts `::pg_catalog.regtype::pg_catalog.text` collapse to `::TEXT`.
        assert_rewrite!(
            &rules,
            "SELECT c.relchecks, c.relkind, c.relhasindex, c.relhasrules, c.relhastriggers, c.relrowsecurity, c.relforcerowsecurity, false AS relhasoids, c.relispartition, '', c.reltablespace, CASE WHEN c.reloftype = 0 THEN '' ELSE c.reloftype::pg_catalog.regtype::pg_catalog.text END, c.relpersistence, c.relreplident, am.amname
    FROM pg_catalog.pg_class c
     LEFT JOIN pg_catalog.pg_class tc ON (c.reltoastrelid = tc.oid)
    LEFT JOIN pg_catalog.pg_am am ON (c.relam = am.oid)
    WHERE c.oid = '16386'",
            "SELECT c.relchecks, c.relkind, c.relhasindex, c.relhasrules, c.relhastriggers, c.relrowsecurity, c.relforcerowsecurity, false AS relhasoids, c.relispartition, '', c.reltablespace, CASE WHEN c.reloftype = 0 THEN '' ELSE c.reloftype::TEXT END, c.relpersistence, c.relreplident, am.amname FROM pg_catalog.pg_class AS c LEFT JOIN pg_catalog.pg_class AS tc ON (c.reltoastrelid = tc.oid) LEFT JOIN pg_catalog.pg_am AS am ON (c.relam = am.oid) WHERE c.oid = '16386'"
        );
    }

    #[test]
    fn test_any_to_array_contains() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(RewriteArrayAnyAllOperation)];

        // `= ANY(..)` becomes `array_contains(..)`.
        assert_rewrite!(
            &rules,
            "SELECT a = ANY(current_schemas(true))",
            "SELECT array_contains(current_schemas(true), a)"
        );

        // `<> ALL(..)` becomes a negated `array_contains(..)`.
        assert_rewrite!(
            &rules,
            "SELECT a <> ALL(current_schemas(true))",
            "SELECT NOT array_contains(current_schemas(true), a)"
        );

        // A `'{..}'` string operand is expanded into an ARRAY literal.
        assert_rewrite!(
            &rules,
            "SELECT a = ANY('{r, l, e}')",
            "SELECT array_contains(ARRAY['r', 'l', 'e'], a)"
        );

        // The rewrite also applies inside WHERE.
        assert_rewrite!(
            &rules,
            "SELECT a FROM tbl WHERE a = ANY(current_schemas(true))",
            "SELECT a FROM tbl WHERE array_contains(current_schemas(true), a)"
        );
    }

    #[test]
    fn test_prepend_unqualified_table_name() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(PrependUnqualifiedPgTableName)];

        // Already qualified: unchanged.
        assert_rewrite!(
            &rules,
            "SELECT * FROM pg_catalog.pg_namespace",
            "SELECT * FROM pg_catalog.pg_namespace"
        );

        // Bare `pg_` table gains the `pg_catalog.` prefix.
        assert_rewrite!(
            &rules,
            "SELECT * FROM pg_namespace",
            "SELECT * FROM pg_catalog.pg_namespace"
        );

        // The joined table is prefixed too.
        assert_rewrite!(
            &rules,
            "SELECT typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN pg_namespace as ns ON ns.oid = oid",
            "SELECT typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN pg_catalog.pg_namespace AS ns ON ns.oid = oid"
        );
    }

    #[test]
    fn test_array_literal_fix() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(FixArrayLiteral)];

        // Text elements are rendered as quoted strings.
        assert_rewrite!(
            &rules,
            "SELECT '{a, abc}'::text[]",
            "SELECT ARRAY['a', 'abc']::TEXT[]"
        );

        // Integer elements are rendered unquoted.
        assert_rewrite!(
            &rules,
            "SELECT '{1, 2}'::int[]",
            "SELECT ARRAY[1, 2]::INT[]"
        );

        // Boolean elements are rendered unquoted as well.
        assert_rewrite!(
            &rules,
            "SELECT '{t, f}'::bool[]",
            "SELECT ARRAY[t, f]::BOOL[]"
        );
    }

    #[test]
    fn test_remove_qualifier_from_table_function() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(RemoveQualifier)];

        // The schema prefix of a table function in FROM is stripped.
        assert_rewrite!(
            &rules,
            "SELECT * FROM pg_catalog.pg_get_keywords()",
            "SELECT * FROM pg_get_keywords()"
        );
    }

    #[test]
    fn test_current_user() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(CurrentUserVariableToSessionUserFunctionCall)];

        // Both spellings map to `session_user`.
        assert_rewrite!(&rules, "SELECT current_user", "SELECT session_user");
        assert_rewrite!(&rules, "SELECT CURRENT_USER", "SELECT session_user");

        // Also when used as a function argument.
        assert_rewrite!(
            &rules,
            "SELECT is_null(current_user)",
            "SELECT is_null(session_user)"
        );
    }

    #[test]
    fn test_collate_fix() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(FixCollate)];

        // `OPERATOR(pg_catalog.~)` becomes `~` and the COLLATE clause disappears.
        assert_rewrite!(
            &rules,
            "SELECT c.oid, c.relname FROM pg_catalog.pg_class c WHERE c.relname OPERATOR(pg_catalog.~) '^(tablename)$' COLLATE pg_catalog.default AND pg_catalog.pg_table_is_visible(c.oid) ORDER BY 2, 3;",
            "SELECT c.oid, c.relname FROM pg_catalog.pg_class AS c WHERE c.relname ~ '^(tablename)$' AND pg_catalog.pg_table_is_visible(c.oid) ORDER BY 2, 3"
        );
    }

    #[test]
    fn test_remove_subquery() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
            vec![Arc::new(RemoveSubqueryFromProjection)];

        // Each projected subquery is replaced by NULL; an existing alias is kept.
        assert_rewrite!(
            &rules,
            "SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), (SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid, true) FROM pg_catalog.pg_attrdef d WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef), a.attnotnull, (SELECT c.collname FROM pg_catalog.pg_collation c, pg_catalog.pg_type t WHERE c.oid = a.attcollation AND t.oid = a.atttypid AND a.attcollation <> t.typcollation LIMIT 1) AS attcollation, a.attidentity, a.attgenerated FROM pg_catalog.pg_attribute a WHERE a.attrelid = '16384' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum;",
            "SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), NULL, a.attnotnull, NULL AS attcollation, a.attidentity, a.attgenerated FROM pg_catalog.pg_attribute AS a WHERE a.attrelid = '16384' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum"
        );
    }

    #[test]
    fn test_version_rewrite() {
        let rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![Arc::new(FixVersionColumnName)];

        // The bare call receives the alias.
        assert_rewrite!(&rules, "SELECT version()", "SELECT version() AS version");

        // Already aliased, called with an argument, or schema-qualified:
        // these must pass through unchanged.
        assert_rewrite!(&rules, "SELECT version() as foo", "SELECT version() AS foo");
        assert_rewrite!(&rules, "SELECT version(foo)", "SELECT version(foo)");
        assert_rewrite!(&rules, "SELECT foo.version()", "SELECT foo.version()");
    }
}