// datafusion_pg_catalog/sql/parser.rs

1use std::sync::Arc;
2
3use datafusion::sql::sqlparser::ast::Statement;
4use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
5use datafusion::sql::sqlparser::parser::Parser;
6use datafusion::sql::sqlparser::parser::ParserError;
7use datafusion::sql::sqlparser::tokenizer::Token;
8use datafusion::sql::sqlparser::tokenizer::TokenWithSpan;
9
10use super::rules::AliasDuplicatedProjectionRewrite;
11use super::rules::CurrentUserVariableToSessionUserFunctionCall;
12use super::rules::FixArrayLiteral;
13use super::rules::FixCollate;
14use super::rules::FixVersionColumnName;
15use super::rules::PrependUnqualifiedPgTableName;
16use super::rules::RemoveQualifier;
17use super::rules::RemoveSubqueryFromProjection;
18use super::rules::RemoveUnsupportedTypes;
19use super::rules::ResolveUnqualifiedIdentifer;
20use super::rules::RewriteArrayAnyAllOperation;
21use super::rules::SqlStatementRewriteRule;
22
/// Known-problematic client queries mapped to simplified stand-in queries.
///
/// Each entry pairs the SQL text a well-known client sends with a replacement
/// that yields an empty result set of the expected column shape
/// (`NULL::<type> AS <name> ... WHERE false`). Matching is done on the token
/// stream with whitespace and semicolons stripped (see
/// `PostgresCompatibilityParser::maybe_replace_tokens`), so differences in
/// layout between the entry and the incoming query do not matter, and `$1`
/// placeholders in an entry match any single token in the input.
const BLACKLIST_SQL_MAPPING: &[(&str, &str)] = &[
    // pgcli startup query
    (
"SELECT s_p.nspname AS parentschema,
                               t_p.relname AS parenttable,
                               unnest((
                                select
                                    array_agg(attname ORDER BY i)
                                from
                                    (select unnest(confkey) as attnum, generate_subscripts(confkey, 1) as i) x
                                    JOIN pg_catalog.pg_attribute c USING(attnum)
                                    WHERE c.attrelid = fk.confrelid
                                )) AS parentcolumn,
                               s_c.nspname AS childschema,
                               t_c.relname AS childtable,
                               unnest((
                                select
                                    array_agg(attname ORDER BY i)
                                from
                                    (select unnest(conkey) as attnum, generate_subscripts(conkey, 1) as i) x
                                    JOIN pg_catalog.pg_attribute c USING(attnum)
                                    WHERE c.attrelid = fk.conrelid
                                )) AS childcolumn
                        FROM pg_catalog.pg_constraint fk
                        JOIN pg_catalog.pg_class      t_p ON t_p.oid = fk.confrelid
                        JOIN pg_catalog.pg_namespace  s_p ON s_p.oid = t_p.relnamespace
                        JOIN pg_catalog.pg_class      t_c ON t_c.oid = fk.conrelid
                        JOIN pg_catalog.pg_namespace  s_c ON s_c.oid = t_c.relnamespace
                        WHERE fk.contype = 'f'",
"SELECT
   NULL::TEXT AS parentschema,
   NULL::TEXT AS parenttable,
   NULL::TEXT AS parentcolumn,
   NULL::TEXT AS childschema,
   NULL::TEXT AS childtable,
   NULL::TEXT AS childcolumn
 WHERE false"),

    // pgcli startup query
    (
"SELECT n.nspname schema_name,
                                       t.typname type_name
                                FROM   pg_catalog.pg_type t
                                       INNER JOIN pg_catalog.pg_namespace n
                                          ON n.oid = t.typnamespace
                                WHERE ( t.typrelid = 0  -- non-composite types
                                        OR (  -- composite type, but not a table
                                              SELECT c.relkind = 'c'
                                              FROM pg_catalog.pg_class c
                                              WHERE c.oid = t.typrelid
                                            )
                                      )
                                      AND NOT EXISTS( -- ignore array types
                                            SELECT  1
                                            FROM    pg_catalog.pg_type el
                                            WHERE   el.oid = t.typelem AND el.typarray = t.oid
                                          )
                                      AND n.nspname <> 'pg_catalog'
                                      AND n.nspname <> 'information_schema'
                                ORDER BY 1, 2;",
"SELECT NULL::TEXT AS schema_name, NULL::TEXT AS type_name WHERE false"
    ),

// psql \d <table> queries
    (
"SELECT pol.polname, pol.polpermissive,
          CASE WHEN pol.polroles = '{0}' THEN NULL ELSE pg_catalog.array_to_string(array(select rolname from pg_catalog.pg_roles where oid = any (pol.polroles) order by 1),',') END,
          pg_catalog.pg_get_expr(pol.polqual, pol.polrelid),
          pg_catalog.pg_get_expr(pol.polwithcheck, pol.polrelid),
          CASE pol.polcmd
            WHEN 'r' THEN 'SELECT'
            WHEN 'a' THEN 'INSERT'
            WHEN 'w' THEN 'UPDATE'
            WHEN 'd' THEN 'DELETE'
            END AS cmd
        FROM pg_catalog.pg_policy pol
        WHERE pol.polrelid = $1 ORDER BY 1;",
"SELECT
   NULL::TEXT AS polname,
   NULL::TEXT AS polpermissive,
   NULL::TEXT AS array_to_string,
   NULL::TEXT AS pg_get_expr_1,
   NULL::TEXT AS pg_get_expr_2,
   NULL::TEXT AS cmd
 WHERE false"
    ),

    // extended-statistics lookup (pg_statistic_ext) — presumably also from
    // psql's \d <table>; see comment above
    (
"SELECT oid, stxrelid::pg_catalog.regclass, stxnamespace::pg_catalog.regnamespace::pg_catalog.text AS nsp, stxname,
        pg_catalog.pg_get_statisticsobjdef_columns(oid) AS columns,
          'd' = any(stxkind) AS ndist_enabled,
          'f' = any(stxkind) AS deps_enabled,
          'm' = any(stxkind) AS mcv_enabled,
        stxstattarget
        FROM pg_catalog.pg_statistic_ext
        WHERE stxrelid = $1
        ORDER BY nsp, stxname;",
"SELECT
   NULL::INT AS oid,
   NULL::TEXT AS stxrelid,
   NULL::TEXT AS nsp,
   NULL::TEXT AS stxname,
   NULL::TEXT AS columns,
   NULL::BOOLEAN AS ndist_enabled,
   NULL::BOOLEAN AS deps_enabled,
   NULL::BOOLEAN AS mcv_enabled,
   NULL::TEXT AS stxstattarget
 WHERE false"
    ),

    // publication-membership lookup (pg_publication) — presumably also from
    // psql's \d <table>; see comment above
    (
"SELECT pubname
             , NULL
             , NULL
        FROM pg_catalog.pg_publication p
             JOIN pg_catalog.pg_publication_namespace pn ON p.oid = pn.pnpubid
             JOIN pg_catalog.pg_class pc ON pc.relnamespace = pn.pnnspid
        WHERE pc.oid = $1 and pg_catalog.pg_relation_is_publishable($1)
        UNION
        SELECT pubname
             , pg_get_expr(pr.prqual, c.oid)
             , (CASE WHEN pr.prattrs IS NOT NULL THEN
                 (SELECT string_agg(attname, ', ')
                   FROM pg_catalog.generate_series(0, pg_catalog.array_upper(pr.prattrs::pg_catalog.int2[], 1)) s,
                        pg_catalog.pg_attribute
                  WHERE attrelid = pr.prrelid AND attnum = prattrs[s])
                ELSE NULL END) FROM pg_catalog.pg_publication p
             JOIN pg_catalog.pg_publication_rel pr ON p.oid = pr.prpubid
             JOIN pg_catalog.pg_class c ON c.oid = pr.prrelid
        WHERE pr.prrelid = $1
        UNION
        SELECT pubname
             , NULL
             , NULL
        FROM pg_catalog.pg_publication p
        WHERE p.puballtables AND pg_catalog.pg_relation_is_publishable($1)
        ORDER BY 1;",
"SELECT
   NULL::TEXT AS pubname,
   NULL::TEXT AS _1,
   NULL::TEXT AS _2
 WHERE false"
    ),

    // grafana array index magic: this subquery is replaced by a bare ''
    // literal so it can also match as a fragment inside a larger statement
    // (the enclosing `IN (...)` stays syntactically valid — see
    // test_partial_match)
    (r#"SELECT
            CASE WHEN trim(s[i]) = '"$user"' THEN user ELSE trim(s[i]) END
        FROM
            generate_series(
                array_lower(string_to_array(current_setting('search_path'),','),1),
                array_upper(string_to_array(current_setting('search_path'),','),1)
            ) as i,
            string_to_array(current_setting('search_path'),',') s"#,
"''")
];
178
/// A parser with Postgres compatibility for DataFusion.
///
/// This parser will try its best to rewrite Postgres SQL into a form that
/// DataFusion supports. It also maintains a blacklist that will transform a
/// statement into a simplified stand-in when a real rewrite isn't worth the
/// effort for now.
#[derive(Debug)]
pub struct PostgresCompatibilityParser {
    // Token-level (pattern, replacement) pairs, pre-tokenized from
    // BLACKLIST_SQL_MAPPING with whitespace and semicolons stripped.
    blacklist: Vec<(Vec<Token>, Vec<Token>)>,
    // AST rewrite rules applied in order to every parsed statement.
    rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>>,
}
189
190impl Default for PostgresCompatibilityParser {
191    fn default() -> Self {
192        Self::new()
193    }
194}
195
196impl PostgresCompatibilityParser {
197    pub fn new() -> Self {
198        let mut mapping = Vec::with_capacity(BLACKLIST_SQL_MAPPING.len());
199
200        for (sql_from, sql_to) in BLACKLIST_SQL_MAPPING {
201            mapping.push((
202                Parser::new(&PostgreSqlDialect {})
203                    .try_with_sql(sql_from)
204                    .unwrap()
205                    .into_tokens()
206                    .into_iter()
207                    .map(|t| t.token)
208                    .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon))
209                    .collect(),
210                Parser::new(&PostgreSqlDialect {})
211                    .try_with_sql(sql_to)
212                    .unwrap()
213                    .into_tokens()
214                    .into_iter()
215                    .map(|t| t.token)
216                    .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon))
217                    .collect(),
218            ));
219        }
220
221        Self {
222            blacklist: mapping,
223            rewrite_rules: vec![
224                // make sure blacklist based rewriter it on the top to prevent sql
225                // being rewritten from other rewriters
226                Arc::new(AliasDuplicatedProjectionRewrite),
227                Arc::new(ResolveUnqualifiedIdentifer),
228                Arc::new(RewriteArrayAnyAllOperation),
229                Arc::new(PrependUnqualifiedPgTableName),
230                Arc::new(RemoveQualifier),
231                Arc::new(RemoveUnsupportedTypes::new()),
232                Arc::new(FixArrayLiteral),
233                Arc::new(CurrentUserVariableToSessionUserFunctionCall),
234                Arc::new(FixCollate),
235                Arc::new(RemoveSubqueryFromProjection),
236                Arc::new(FixVersionColumnName),
237            ],
238        }
239    }
240
    /// Tokenize `input` and return the token stream with any blacklisted
    /// query pattern replaced by its simplified stand-in.
    ///
    /// Matching is positional over whitespace-free tokens: at each position
    /// we try every blacklist pattern; `Token::Placeholder` (e.g. `$1`) in a
    /// pattern matches any single non-semicolon input token. Semicolons are
    /// kept in the output (they separate statements) but skipped while
    /// comparing against a pattern. Unmatched tokens pass through unchanged,
    /// which allows a blacklisted fragment inside a larger statement to be
    /// replaced in place.
    ///
    /// Returns a `ParserError` if `input` cannot be tokenized.
    fn maybe_replace_tokens(&self, input: &str) -> Result<Vec<Token>, ParserError> {
        let parser = Parser::new(&PostgreSqlDialect {});
        let tokens = parser.try_with_sql(input)?.into_tokens();

        // Get token values (without spans) and filter out only whitespace.
        // Keep semicolons as they separate statements.
        let filtered_tokens: Vec<Token> = tokens
            .iter()
            .map(|t| t.token.clone())
            .filter(|t| !matches!(t, Token::Whitespace(_)))
            .collect();

        // Handle empty input (e.g. "" or whitespace-only).
        if filtered_tokens.is_empty() {
            return Ok(Vec::new());
        }

        // Build result by processing filtered tokens sequentially.
        let mut result = Vec::new();
        let mut i = 0;

        while i < filtered_tokens.len() {
            // Keep semicolons as-is.
            if matches!(&filtered_tokens[i], Token::SemiColon) {
                result.push(filtered_tokens[i].clone());
                i += 1;
                continue;
            }

            // Try to find a blacklist pattern match starting at this position.
            let mut matched = false;
            for (pattern, replacement) in &self.blacklist {
                if pattern.is_empty() {
                    continue;
                }

                // Walk input (offset j) and pattern (pattern_idx) in lockstep
                // until either runs out or a token mismatches.
                let mut j = 0;
                let mut pattern_idx = 0;
                while i + j < filtered_tokens.len() && pattern_idx < pattern.len() {
                    // Skip semicolons in the input when matching patterns.
                    if matches!(&filtered_tokens[i + j], Token::SemiColon) {
                        j += 1;
                        continue;
                    }

                    match &pattern[pattern_idx] {
                        Token::Placeholder(_) => {
                            // Placeholder matches any non-semicolon token.
                            pattern_idx += 1;
                            j += 1;
                        }
                        _ => {
                            if filtered_tokens[i + j] != pattern[pattern_idx] {
                                break;
                            }
                            pattern_idx += 1;
                            j += 1;
                        }
                    }
                }

                // Check if we matched the entire pattern.
                if pattern_idx == pattern.len() {
                    // Add replacement tokens.
                    result.extend(replacement.iter().cloned());
                    // Skip the matched pattern (including any semicolons we skipped).
                    i += j;
                    matched = true;
                    break;
                }
            }

            if !matched {
                // No match, keep the original token.
                result.push(filtered_tokens[i].clone());
                i += 1;
            }
        }

        Ok(result)
    }
324
325    fn parse_tokens(&self, tokens: Vec<Token>) -> Result<Vec<Statement>, ParserError> {
326        let parser = Parser::new(&PostgreSqlDialect {});
327        // Convert tokens to TokenWithSpan with dummy spans
328        let tokens_with_spans: Vec<TokenWithSpan> = tokens
329            .into_iter()
330            .map(|token| TokenWithSpan {
331                token,
332                span: datafusion::sql::sqlparser::tokenizer::Span::empty(),
333            })
334            .collect();
335        parser
336            .with_tokens_with_locations(tokens_with_spans)
337            .parse_statements()
338    }
339
340    pub fn parse(&self, input: &str) -> Result<Vec<Statement>, ParserError> {
341        let tokens = self.maybe_replace_tokens(input)?;
342        let statements = self.parse_tokens(tokens)?;
343
344        let statements: Vec<_> = statements.into_iter().map(|s| self.rewrite(s)).collect();
345        Ok(statements)
346    }
347
348    pub fn rewrite(&self, mut s: Statement) -> Statement {
349        for rule in &self.rewrite_rules {
350            s = rule.rewrite(s);
351        }
352
353        s
354    }
355}
356
357#[cfg(test)]
358mod tests {
359    use super::*;
360
361    #[test]
362    fn test_full_match() {
363        let sql = "SELECT pol.polname, pol.polpermissive,
364              CASE WHEN pol.polroles = '{0}' THEN NULL ELSE pg_catalog.array_to_string(array(select rolname from pg_catalog.pg_roles where oid = any (pol.polroles) order by 1),',') END,
365              pg_catalog.pg_get_expr(pol.polqual, pol.polrelid),
366              pg_catalog.pg_get_expr(pol.polwithcheck, pol.polrelid),
367              CASE pol.polcmd
368                WHEN 'r' THEN 'SELECT'
369                WHEN 'a' THEN 'INSERT'
370                WHEN 'w' THEN 'UPDATE'
371                WHEN 'd' THEN 'DELETE'
372                END AS cmd
373            FROM pg_catalog.pg_policy pol
374            WHERE pol.polrelid = '16384' ORDER BY 1;";
375
376        let parser = PostgresCompatibilityParser::new();
377        let actual_tokens = parser
378            .maybe_replace_tokens(sql)
379            .expect("failed to parse sql")
380            .into_iter()
381            .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon))
382            .collect::<Vec<_>>();
383
384        let expected_sql = r#"SELECT
385   NULL::TEXT AS polname,
386   NULL::TEXT AS polpermissive,
387   NULL::TEXT AS array_to_string,
388   NULL::TEXT AS pg_get_expr_1,
389   NULL::TEXT AS pg_get_expr_2,
390   NULL::TEXT AS cmd
391 WHERE false"#;
392
393        let expected_tokens = Parser::new(&PostgreSqlDialect {})
394            .try_with_sql(expected_sql)
395            .unwrap()
396            .into_tokens()
397            .into_iter()
398            .map(|t| t.token)
399            .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon))
400            .collect::<Vec<_>>();
401
402        assert_eq!(actual_tokens, expected_tokens);
403
404        let sql = "SELECT n.nspname schema_name,
405                                       t.typname type_name
406                                FROM   pg_catalog.pg_type t
407                                       INNER JOIN pg_catalog.pg_namespace n
408                                          ON n.oid = t.typnamespace
409                                WHERE ( t.typrelid = 0  -- non-composite types
410                                        OR (  -- composite type, but not a table
411                                              SELECT c.relkind = 'c'
412                                              FROM pg_catalog.pg_class c
413                                              WHERE c.oid = t.typrelid
414                                            )
415                                      )
416                                      AND NOT EXISTS( -- ignore array types
417                                            SELECT  1
418                                            FROM    pg_catalog.pg_type el
419                                            WHERE   el.oid = t.typelem AND el.typarray = t.oid
420                                          )
421                                      AND n.nspname <> 'pg_catalog'
422                                      AND n.nspname <> 'information_schema'
423                                ORDER BY 1, 2";
424
425        let parser = PostgresCompatibilityParser::new();
426
427        let actual_tokens = parser
428            .maybe_replace_tokens(sql)
429            .expect("failed to parse sql")
430            .into_iter()
431            .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon))
432            .collect::<Vec<_>>();
433
434        let expected_sql =
435            r#"SELECT NULL::TEXT AS schema_name, NULL::TEXT AS type_name WHERE false"#;
436
437        let expected_tokens = Parser::new(&PostgreSqlDialect {})
438            .try_with_sql(expected_sql)
439            .unwrap()
440            .into_tokens()
441            .into_iter()
442            .map(|t| t.token)
443            .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon))
444            .collect::<Vec<_>>();
445
446        assert_eq!(actual_tokens, expected_tokens);
447
448        let sql = "SELECT pubname
449             , NULL
450             , NULL
451        FROM pg_catalog.pg_publication p
452             JOIN pg_catalog.pg_publication_namespace pn ON p.oid = pn.pnpubid
453             JOIN pg_catalog.pg_class pc ON pc.relnamespace = pn.pnnspid
454        WHERE pc.oid ='16384' and pg_catalog.pg_relation_is_publishable('16384')
455        UNION
456        SELECT pubname
457             , pg_get_expr(pr.prqual, c.oid)
458             , (CASE WHEN pr.prattrs IS NOT NULL THEN
459                 (SELECT string_agg(attname, ', ')
460                   FROM pg_catalog.generate_series(0, pg_catalog.array_upper(pr.prattrs::pg_catalog.int2[], 1)) s,
461                        pg_catalog.pg_attribute
462                  WHERE attrelid = pr.prrelid AND attnum = prattrs[s])
463                ELSE NULL END) FROM pg_catalog.pg_publication p
464             JOIN pg_catalog.pg_publication_rel pr ON p.oid = pr.prpubid
465             JOIN pg_catalog.pg_class c ON c.oid = pr.prrelid
466        WHERE pr.prrelid = '16384'
467        UNION
468        SELECT pubname
469             , NULL
470             , NULL
471        FROM pg_catalog.pg_publication p
472        WHERE p.puballtables AND pg_catalog.pg_relation_is_publishable('16384')
473        ORDER BY 1;";
474
475        let parser = PostgresCompatibilityParser::new();
476
477        let actual_tokens = parser
478            .maybe_replace_tokens(sql)
479            .expect("failed to parse sql")
480            .into_iter()
481            .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon))
482            .collect::<Vec<_>>();
483
484        let expected_sql = r#"SELECT
485   NULL::TEXT AS pubname,
486   NULL::TEXT AS _1,
487   NULL::TEXT AS _2
488 WHERE false"#;
489
490        let expected_tokens = Parser::new(&PostgreSqlDialect {})
491            .try_with_sql(expected_sql)
492            .unwrap()
493            .into_tokens()
494            .into_iter()
495            .map(|t| t.token)
496            .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon))
497            .collect::<Vec<_>>();
498
499        assert_eq!(actual_tokens, expected_tokens);
500    }
501
502    #[test]
503    fn test_empty_query() {
504        let parser = PostgresCompatibilityParser::new();
505        let result = parser.parse(" ").expect("failed to parse sql");
506        assert!(result.is_empty());
507
508        let result = parser.parse("").expect("failed to parse sql");
509        assert!(result.is_empty());
510
511        let result = parser.parse(";").expect("failed to parse sql");
512        assert!(result.is_empty());
513    }
514
    #[test]
    fn test_partial_match() {
        let parser = PostgresCompatibilityParser::new();

        // Test partial match where a fragment of a larger statement matches a
        // blacklisted query. Using a simpler entry without placeholders: the
        // grafana search_path subquery appears twice inside this statement,
        // and each occurrence should be replaced by the '' literal while all
        // surrounding tokens pass through unchanged.
        let sql = r#"SELECT
        CASE WHEN
              quote_ident(table_schema) IN (
              SELECT
                CASE WHEN trim(s[i]) = '"$user"' THEN user ELSE trim(s[i]) END
              FROM
                generate_series(
                  array_lower(string_to_array(current_setting('search_path'),','),1),
                  array_upper(string_to_array(current_setting('search_path'),','),1)
                ) as i,
                string_to_array(current_setting('search_path'),',') s
              )
          THEN quote_ident(table_name)
          ELSE quote_ident(table_schema) || '.' || quote_ident(table_name)
        END AS "table"
        FROM information_schema.tables
        WHERE quote_ident(table_schema) NOT IN ('information_schema',
                                 'pg_catalog',
                                 '_timescaledb_cache',
                                 '_timescaledb_catalog',
                                 '_timescaledb_internal',
                                 '_timescaledb_config',
                                 'timescaledb_information',
                                 'timescaledb_experimental')
        ORDER BY CASE WHEN
              quote_ident(table_schema) IN (
              SELECT
                CASE WHEN trim(s[i]) = '"$user"' THEN user ELSE trim(s[i]) END
              FROM
                generate_series(
                  array_lower(string_to_array(current_setting('search_path'),','),1),
                  array_upper(string_to_array(current_setting('search_path'),','),1)
                ) as i,
                string_to_array(current_setting('search_path'),',') s
              ) THEN 0 ELSE 1 END, 1"#;

        let tokens = parser
            .maybe_replace_tokens(sql)
            .expect("failed to parse sql");
        // Should have the matched fragments replaced with '' and the rest preserved
        assert!(tokens.len() > 0);

        // Same statement with both subqueries collapsed to the '' literal.
        let expected_sql = r#"SELECT
        CASE WHEN
              quote_ident(table_schema) IN (
              '')
          THEN quote_ident(table_name)
          ELSE quote_ident(table_schema) || '.' || quote_ident(table_name)
        END AS "table"
        FROM information_schema.tables
        WHERE quote_ident(table_schema) NOT IN ('information_schema',
                                 'pg_catalog',
                                 '_timescaledb_cache',
                                 '_timescaledb_catalog',
                                 '_timescaledb_internal',
                                 '_timescaledb_config',
                                 'timescaledb_information',
                                 'timescaledb_experimental')
        ORDER BY CASE WHEN
              quote_ident(table_schema) IN (
              ''
              ) THEN 0 ELSE 1 END, 1"#;

        let expected_tokens = Parser::new(&PostgreSqlDialect {})
            .try_with_sql(expected_sql)
            .unwrap()
            .into_tokens();

        // Compare token values (ignoring spans and whitespace)
        let actual_tokens: Vec<_> = tokens
            .iter()
            .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon))
            .collect();
        let expected_token_values: Vec<_> = expected_tokens
            .iter()
            .map(|t| &t.token)
            .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon))
            .collect();

        assert_eq!(actual_tokens, expected_token_values);
    }
602}