// datafusion_pg_catalog/sql/parser.rs
1use std::sync::Arc;
2
3use datafusion::sql::sqlparser::ast::Statement;
4use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
5use datafusion::sql::sqlparser::parser::Parser;
6use datafusion::sql::sqlparser::parser::ParserError;
7use datafusion::sql::sqlparser::tokenizer::Token;
8use datafusion::sql::sqlparser::tokenizer::TokenWithSpan;
9
10use super::rules::AliasDuplicatedProjectionRewrite;
11use super::rules::CurrentUserVariableToSessionUserFunctionCall;
12use super::rules::FixArrayLiteral;
13use super::rules::FixCollate;
14use super::rules::FixVersionColumnName;
15use super::rules::PrependUnqualifiedPgTableName;
16use super::rules::RemoveQualifier;
17use super::rules::RemoveSubqueryFromProjection;
18use super::rules::RemoveUnsupportedTypes;
19use super::rules::ResolveUnqualifiedIdentifer;
20use super::rules::RewriteArrayAnyAllOperation;
21use super::rules::RewriteRegclassCastToSubquery;
22use super::rules::SqlStatementRewriteRule;
23
/// Known client-tool SQL statements that are not worth rewriting, paired with
/// a harmless substitute that yields an empty result with the column names
/// the client expects.
///
/// Each `(original, replacement)` pair is tokenized once in
/// `PostgresCompatibilityParser::new`, with whitespace and semicolons
/// stripped, so matching is insensitive to the formatting of the incoming
/// SQL. `$1`-style placeholders in an original tokenize to `Token::Placeholder`
/// and match any single non-semicolon token at replacement time.
const BLACKLIST_SQL_MAPPING: &[(&str, &str)] = &[
    // pgcli startup query: foreign-key parent/child column listing
    (
"SELECT s_p.nspname AS parentschema,
                               t_p.relname AS parenttable,
                               unnest((
                                select
                                    array_agg(attname ORDER BY i)
                                from
                                    (select unnest(confkey) as attnum, generate_subscripts(confkey, 1) as i) x
                                    JOIN pg_catalog.pg_attribute c USING(attnum)
                                    WHERE c.attrelid = fk.confrelid
                                )) AS parentcolumn,
                               s_c.nspname AS childschema,
                               t_c.relname AS childtable,
                               unnest((
                                select
                                    array_agg(attname ORDER BY i)
                                from
                                    (select unnest(conkey) as attnum, generate_subscripts(conkey, 1) as i) x
                                    JOIN pg_catalog.pg_attribute c USING(attnum)
                                    WHERE c.attrelid = fk.conrelid
                                )) AS childcolumn
                        FROM pg_catalog.pg_constraint fk
                        JOIN pg_catalog.pg_class      t_p ON t_p.oid = fk.confrelid
                        JOIN pg_catalog.pg_namespace  s_p ON s_p.oid = t_p.relnamespace
                        JOIN pg_catalog.pg_class      t_c ON t_c.oid = fk.conrelid
                        JOIN pg_catalog.pg_namespace  s_c ON s_c.oid = t_c.relnamespace
                        WHERE fk.contype = 'f'",
"SELECT
   NULL::TEXT AS parentschema,
   NULL::TEXT AS parenttable,
   NULL::TEXT AS parentcolumn,
   NULL::TEXT AS childschema,
   NULL::TEXT AS childtable,
   NULL::TEXT AS childcolumn
 WHERE false"),

    // pgcli startup query: user-defined type listing
    (
"SELECT n.nspname schema_name,
                                       t.typname type_name
                                FROM   pg_catalog.pg_type t
                                       INNER JOIN pg_catalog.pg_namespace n
                                          ON n.oid = t.typnamespace
                                WHERE ( t.typrelid = 0  -- non-composite types
                                        OR (  -- composite type, but not a table
                                              SELECT c.relkind = 'c'
                                              FROM pg_catalog.pg_class c
                                              WHERE c.oid = t.typrelid
                                            )
                                      )
                                      AND NOT EXISTS( -- ignore array types
                                            SELECT  1
                                            FROM    pg_catalog.pg_type el
                                            WHERE   el.oid = t.typelem AND el.typarray = t.oid
                                          )
                                      AND n.nspname <> 'pg_catalog'
                                      AND n.nspname <> 'information_schema'
                                ORDER BY 1, 2;",
"SELECT NULL::TEXT AS schema_name, NULL::TEXT AS type_name WHERE false"
    ),

    // psql \d <table>: row-level-security policies ($1 is the table oid)
    (
"SELECT pol.polname, pol.polpermissive,
          CASE WHEN pol.polroles = '{0}' THEN NULL ELSE pg_catalog.array_to_string(array(select rolname from pg_catalog.pg_roles where oid = any (pol.polroles) order by 1),',') END,
          pg_catalog.pg_get_expr(pol.polqual, pol.polrelid),
          pg_catalog.pg_get_expr(pol.polwithcheck, pol.polrelid),
          CASE pol.polcmd
            WHEN 'r' THEN 'SELECT'
            WHEN 'a' THEN 'INSERT'
            WHEN 'w' THEN 'UPDATE'
            WHEN 'd' THEN 'DELETE'
            END AS cmd
        FROM pg_catalog.pg_policy pol
        WHERE pol.polrelid = $1 ORDER BY 1;",
"SELECT
   NULL::TEXT AS polname,
   NULL::TEXT AS polpermissive,
   NULL::TEXT AS array_to_string,
   NULL::TEXT AS pg_get_expr_1,
   NULL::TEXT AS pg_get_expr_2,
   NULL::TEXT AS cmd
 WHERE false"
    ),

    // psql \d <table>: extended statistics objects ($1 is the table oid)
    (
"SELECT oid, stxrelid::pg_catalog.regclass, stxnamespace::pg_catalog.regnamespace::pg_catalog.text AS nsp, stxname,
        pg_catalog.pg_get_statisticsobjdef_columns(oid) AS columns,
          'd' = any(stxkind) AS ndist_enabled,
          'f' = any(stxkind) AS deps_enabled,
          'm' = any(stxkind) AS mcv_enabled,
        stxstattarget
        FROM pg_catalog.pg_statistic_ext
        WHERE stxrelid = $1
        ORDER BY nsp, stxname;",
"SELECT
   NULL::INT AS oid,
   NULL::TEXT AS stxrelid,
   NULL::TEXT AS nsp,
   NULL::TEXT AS stxname,
   NULL::TEXT AS columns,
   NULL::BOOLEAN AS ndist_enabled,
   NULL::BOOLEAN AS deps_enabled,
   NULL::BOOLEAN AS mcv_enabled,
   NULL::TEXT AS stxstattarget
 WHERE false"
    ),

    // psql \d <table>: publications the table belongs to ($1 is the table oid)
    (
"SELECT pubname
             , NULL
             , NULL
        FROM pg_catalog.pg_publication p
             JOIN pg_catalog.pg_publication_namespace pn ON p.oid = pn.pnpubid
             JOIN pg_catalog.pg_class pc ON pc.relnamespace = pn.pnnspid
        WHERE pc.oid = $1 and pg_catalog.pg_relation_is_publishable($1)
        UNION
        SELECT pubname
             , pg_get_expr(pr.prqual, c.oid)
             , (CASE WHEN pr.prattrs IS NOT NULL THEN
                 (SELECT string_agg(attname, ', ')
                   FROM pg_catalog.generate_series(0, pg_catalog.array_upper(pr.prattrs::pg_catalog.int2[], 1)) s,
                        pg_catalog.pg_attribute
                  WHERE attrelid = pr.prrelid AND attnum = prattrs[s])
                ELSE NULL END) FROM pg_catalog.pg_publication p
             JOIN pg_catalog.pg_publication_rel pr ON p.oid = pr.prpubid
             JOIN pg_catalog.pg_class c ON c.oid = pr.prrelid
        WHERE pr.prrelid = $1
        UNION
        SELECT pubname
             , NULL
             , NULL
        FROM pg_catalog.pg_publication p
        WHERE p.puballtables AND pg_catalog.pg_relation_is_publishable($1)
        ORDER BY 1;",
"SELECT
   NULL::TEXT AS pubname,
   NULL::TEXT AS _1,
   NULL::TEXT AS _2
 WHERE false"
    ),

    // grafana array index magic: search_path expansion subquery.
    // Replaced by a bare '' literal because it appears embedded inside larger
    // statements (see test_partial_match), not as a standalone query.
    (r#"SELECT
            CASE WHEN trim(s[i]) = '"$user"' THEN user ELSE trim(s[i]) END
        FROM
            generate_series(
                array_lower(string_to_array(current_setting('search_path'),','),1),
                array_upper(string_to_array(current_setting('search_path'),','),1)
            ) as i,
            string_to_array(current_setting('search_path'),',') s"#,
"''")
];
179
/// A parser with Postgres compatibility for DataFusion
///
/// This parser will try its best to rewrite Postgres SQL into a form that
/// DataFusion supports. It also maintains a blacklist that will transform a
/// statement into a similar stub version when a rewrite isn't worth the
/// effort for now.
#[derive(Debug)]
pub struct PostgresCompatibilityParser {
    // Pre-tokenized (pattern, replacement) pairs built from
    // BLACKLIST_SQL_MAPPING; whitespace and semicolons already stripped.
    blacklist: Vec<(Vec<Token>, Vec<Token>)>,
    // Statement-level rewrite rules, applied in order by `rewrite`.
    rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>>,
}
190
191impl Default for PostgresCompatibilityParser {
192    fn default() -> Self {
193        Self::new()
194    }
195}
196
197impl PostgresCompatibilityParser {
198    pub fn new() -> Self {
199        let mut mapping = Vec::with_capacity(BLACKLIST_SQL_MAPPING.len());
200
201        for (sql_from, sql_to) in BLACKLIST_SQL_MAPPING {
202            mapping.push((
203                Parser::new(&PostgreSqlDialect {})
204                    .try_with_sql(sql_from)
205                    .unwrap()
206                    .into_tokens()
207                    .into_iter()
208                    .map(|t| t.token)
209                    .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon))
210                    .collect(),
211                Parser::new(&PostgreSqlDialect {})
212                    .try_with_sql(sql_to)
213                    .unwrap()
214                    .into_tokens()
215                    .into_iter()
216                    .map(|t| t.token)
217                    .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon))
218                    .collect(),
219            ));
220        }
221
222        Self {
223            blacklist: mapping,
224            rewrite_rules: vec![
225                // make sure blacklist based rewriter it on the top to prevent sql
226                // being rewritten from other rewriters
227                Arc::new(AliasDuplicatedProjectionRewrite),
228                Arc::new(ResolveUnqualifiedIdentifer),
229                Arc::new(RewriteArrayAnyAllOperation),
230                Arc::new(PrependUnqualifiedPgTableName),
231                Arc::new(RemoveQualifier),
232                Arc::new(RewriteRegclassCastToSubquery::new()),
233                Arc::new(RemoveUnsupportedTypes::new()),
234                Arc::new(FixArrayLiteral),
235                Arc::new(CurrentUserVariableToSessionUserFunctionCall),
236                Arc::new(FixCollate),
237                Arc::new(RemoveSubqueryFromProjection),
238                Arc::new(FixVersionColumnName),
239            ],
240        }
241    }
242
    /// Return the token stream for `input` with any blacklisted statement
    /// replaced by its stubbed substitute.
    ///
    /// Tokenizes `input`, drops whitespace (semicolons are kept since they
    /// separate statements), then scans left to right trying each blacklist
    /// pattern at every position. A `Token::Placeholder` in a pattern (e.g.
    /// `$1`) matches any single non-semicolon input token, which lets a
    /// blacklist entry written with `$1` match a concrete literal such as
    /// `'16384'`. Tokens not covered by any match pass through unchanged, so
    /// a blacklisted query embedded inside a larger statement is replaced
    /// in place (see `test_partial_match`).
    ///
    /// # Errors
    ///
    /// Returns `ParserError` if `input` cannot be tokenized.
    fn maybe_replace_tokens(&self, input: &str) -> Result<Vec<Token>, ParserError> {
        let parser = Parser::new(&PostgreSqlDialect {});
        let tokens = parser.try_with_sql(input)?.into_tokens();

        // Get token values (without spans) and filter out only whitespace.
        // Keep semicolons as they separate statements.
        let filtered_tokens: Vec<Token> = tokens
            .iter()
            .map(|t| t.token.clone())
            .filter(|t| !matches!(t, Token::Whitespace(_)))
            .collect();

        // Handle empty input
        if filtered_tokens.is_empty() {
            return Ok(Vec::new());
        }

        // Build result by processing filtered tokens sequentially
        let mut result = Vec::new();
        let mut i = 0;

        while i < filtered_tokens.len() {
            // Keep semicolons as-is
            if matches!(&filtered_tokens[i], Token::SemiColon) {
                result.push(filtered_tokens[i].clone());
                i += 1;
                continue;
            }

            // Try to find a blacklist pattern match starting at this position
            let mut matched = false;
            for (pattern, replacement) in &self.blacklist {
                if pattern.is_empty() {
                    continue;
                }

                // `j` walks the input, `pattern_idx` walks the pattern; they
                // can diverge because semicolons in the input are skipped.
                let mut j = 0;
                let mut pattern_idx = 0;
                while i + j < filtered_tokens.len() && pattern_idx < pattern.len() {
                    // Skip semicolons in the input when matching patterns
                    if matches!(&filtered_tokens[i + j], Token::SemiColon) {
                        j += 1;
                        continue;
                    }

                    match &pattern[pattern_idx] {
                        Token::Placeholder(_) => {
                            // Placeholder matches any non-semicolon token
                            pattern_idx += 1;
                            j += 1;
                        }
                        _ => {
                            // First mismatch abandons this pattern; the outer
                            // check below sees pattern_idx < pattern.len().
                            if filtered_tokens[i + j] != pattern[pattern_idx] {
                                break;
                            }
                            pattern_idx += 1;
                            j += 1;
                        }
                    }
                }

                // Check if we matched the entire pattern
                if pattern_idx == pattern.len() {
                    // Add replacement tokens
                    result.extend(replacement.iter().cloned());
                    // Skip the matched pattern (including any semicolons we skipped)
                    i += j;
                    matched = true;
                    break;
                }
            }

            if !matched {
                // No match, keep the original token
                result.push(filtered_tokens[i].clone());
                i += 1;
            }
        }

        Ok(result)
    }
326
327    fn parse_tokens(&self, tokens: Vec<Token>) -> Result<Vec<Statement>, ParserError> {
328        let parser = Parser::new(&PostgreSqlDialect {});
329        // Convert tokens to TokenWithSpan with dummy spans
330        let tokens_with_spans: Vec<TokenWithSpan> = tokens
331            .into_iter()
332            .map(|token| TokenWithSpan {
333                token,
334                span: datafusion::sql::sqlparser::tokenizer::Span::empty(),
335            })
336            .collect();
337        parser
338            .with_tokens_with_locations(tokens_with_spans)
339            .parse_statements()
340    }
341
342    pub fn parse(&self, input: &str) -> Result<Vec<Statement>, ParserError> {
343        let tokens = self.maybe_replace_tokens(input)?;
344        let statements = self.parse_tokens(tokens)?;
345
346        let statements: Vec<_> = statements.into_iter().map(|s| self.rewrite(s)).collect();
347        Ok(statements)
348    }
349
350    pub fn rewrite(&self, mut s: Statement) -> Statement {
351        for rule in &self.rewrite_rules {
352            s = rule.rewrite(s);
353        }
354
355        s
356    }
357}
358
359#[cfg(test)]
360mod tests {
361    use super::*;
362
363    #[test]
364    fn test_full_match() {
365        let sql = "SELECT pol.polname, pol.polpermissive,
366              CASE WHEN pol.polroles = '{0}' THEN NULL ELSE pg_catalog.array_to_string(array(select rolname from pg_catalog.pg_roles where oid = any (pol.polroles) order by 1),',') END,
367              pg_catalog.pg_get_expr(pol.polqual, pol.polrelid),
368              pg_catalog.pg_get_expr(pol.polwithcheck, pol.polrelid),
369              CASE pol.polcmd
370                WHEN 'r' THEN 'SELECT'
371                WHEN 'a' THEN 'INSERT'
372                WHEN 'w' THEN 'UPDATE'
373                WHEN 'd' THEN 'DELETE'
374                END AS cmd
375            FROM pg_catalog.pg_policy pol
376            WHERE pol.polrelid = '16384' ORDER BY 1;";
377
378        let parser = PostgresCompatibilityParser::new();
379        let actual_tokens = parser
380            .maybe_replace_tokens(sql)
381            .expect("failed to parse sql")
382            .into_iter()
383            .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon))
384            .collect::<Vec<_>>();
385
386        let expected_sql = r#"SELECT
387   NULL::TEXT AS polname,
388   NULL::TEXT AS polpermissive,
389   NULL::TEXT AS array_to_string,
390   NULL::TEXT AS pg_get_expr_1,
391   NULL::TEXT AS pg_get_expr_2,
392   NULL::TEXT AS cmd
393 WHERE false"#;
394
395        let expected_tokens = Parser::new(&PostgreSqlDialect {})
396            .try_with_sql(expected_sql)
397            .unwrap()
398            .into_tokens()
399            .into_iter()
400            .map(|t| t.token)
401            .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon))
402            .collect::<Vec<_>>();
403
404        assert_eq!(actual_tokens, expected_tokens);
405
406        let sql = "SELECT n.nspname schema_name,
407                                       t.typname type_name
408                                FROM   pg_catalog.pg_type t
409                                       INNER JOIN pg_catalog.pg_namespace n
410                                          ON n.oid = t.typnamespace
411                                WHERE ( t.typrelid = 0  -- non-composite types
412                                        OR (  -- composite type, but not a table
413                                              SELECT c.relkind = 'c'
414                                              FROM pg_catalog.pg_class c
415                                              WHERE c.oid = t.typrelid
416                                            )
417                                      )
418                                      AND NOT EXISTS( -- ignore array types
419                                            SELECT  1
420                                            FROM    pg_catalog.pg_type el
421                                            WHERE   el.oid = t.typelem AND el.typarray = t.oid
422                                          )
423                                      AND n.nspname <> 'pg_catalog'
424                                      AND n.nspname <> 'information_schema'
425                                ORDER BY 1, 2";
426
427        let parser = PostgresCompatibilityParser::new();
428
429        let actual_tokens = parser
430            .maybe_replace_tokens(sql)
431            .expect("failed to parse sql")
432            .into_iter()
433            .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon))
434            .collect::<Vec<_>>();
435
436        let expected_sql =
437            r#"SELECT NULL::TEXT AS schema_name, NULL::TEXT AS type_name WHERE false"#;
438
439        let expected_tokens = Parser::new(&PostgreSqlDialect {})
440            .try_with_sql(expected_sql)
441            .unwrap()
442            .into_tokens()
443            .into_iter()
444            .map(|t| t.token)
445            .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon))
446            .collect::<Vec<_>>();
447
448        assert_eq!(actual_tokens, expected_tokens);
449
450        let sql = "SELECT pubname
451             , NULL
452             , NULL
453        FROM pg_catalog.pg_publication p
454             JOIN pg_catalog.pg_publication_namespace pn ON p.oid = pn.pnpubid
455             JOIN pg_catalog.pg_class pc ON pc.relnamespace = pn.pnnspid
456        WHERE pc.oid ='16384' and pg_catalog.pg_relation_is_publishable('16384')
457        UNION
458        SELECT pubname
459             , pg_get_expr(pr.prqual, c.oid)
460             , (CASE WHEN pr.prattrs IS NOT NULL THEN
461                 (SELECT string_agg(attname, ', ')
462                   FROM pg_catalog.generate_series(0, pg_catalog.array_upper(pr.prattrs::pg_catalog.int2[], 1)) s,
463                        pg_catalog.pg_attribute
464                  WHERE attrelid = pr.prrelid AND attnum = prattrs[s])
465                ELSE NULL END) FROM pg_catalog.pg_publication p
466             JOIN pg_catalog.pg_publication_rel pr ON p.oid = pr.prpubid
467             JOIN pg_catalog.pg_class c ON c.oid = pr.prrelid
468        WHERE pr.prrelid = '16384'
469        UNION
470        SELECT pubname
471             , NULL
472             , NULL
473        FROM pg_catalog.pg_publication p
474        WHERE p.puballtables AND pg_catalog.pg_relation_is_publishable('16384')
475        ORDER BY 1;";
476
477        let parser = PostgresCompatibilityParser::new();
478
479        let actual_tokens = parser
480            .maybe_replace_tokens(sql)
481            .expect("failed to parse sql")
482            .into_iter()
483            .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon))
484            .collect::<Vec<_>>();
485
486        let expected_sql = r#"SELECT
487   NULL::TEXT AS pubname,
488   NULL::TEXT AS _1,
489   NULL::TEXT AS _2
490 WHERE false"#;
491
492        let expected_tokens = Parser::new(&PostgreSqlDialect {})
493            .try_with_sql(expected_sql)
494            .unwrap()
495            .into_tokens()
496            .into_iter()
497            .map(|t| t.token)
498            .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon))
499            .collect::<Vec<_>>();
500
501        assert_eq!(actual_tokens, expected_tokens);
502    }
503
504    #[test]
505    fn test_empty_query() {
506        let parser = PostgresCompatibilityParser::new();
507        let result = parser.parse(" ").expect("failed to parse sql");
508        assert!(result.is_empty());
509
510        let result = parser.parse("").expect("failed to parse sql");
511        assert!(result.is_empty());
512
513        let result = parser.parse(";").expect("failed to parse sql");
514        assert!(result.is_empty());
515    }
516
    /// A blacklisted query embedded inside a larger statement is replaced in
    /// place while the surrounding tokens are preserved.
    #[test]
    fn test_partial_match() {
        let parser = PostgresCompatibilityParser::new();

        // Test partial match where part of the statement matches a
        // blacklisted query (the grafana search_path subquery, which has no
        // placeholders, appears twice inside this statement).
        let sql = r#"SELECT
        CASE WHEN
              quote_ident(table_schema) IN (
              SELECT
                CASE WHEN trim(s[i]) = '"$user"' THEN user ELSE trim(s[i]) END
              FROM
                generate_series(
                  array_lower(string_to_array(current_setting('search_path'),','),1),
                  array_upper(string_to_array(current_setting('search_path'),','),1)
                ) as i,
                string_to_array(current_setting('search_path'),',') s
              )
          THEN quote_ident(table_name)
          ELSE quote_ident(table_schema) || '.' || quote_ident(table_name)
        END AS "table"
        FROM information_schema.tables
        WHERE quote_ident(table_schema) NOT IN ('information_schema',
                                 'pg_catalog',
                                 '_timescaledb_cache',
                                 '_timescaledb_catalog',
                                 '_timescaledb_internal',
                                 '_timescaledb_config',
                                 'timescaledb_information',
                                 'timescaledb_experimental')
        ORDER BY CASE WHEN
              quote_ident(table_schema) IN (
              SELECT
                CASE WHEN trim(s[i]) = '"$user"' THEN user ELSE trim(s[i]) END
              FROM
                generate_series(
                  array_lower(string_to_array(current_setting('search_path'),','),1),
                  array_upper(string_to_array(current_setting('search_path'),','),1)
                ) as i,
                string_to_array(current_setting('search_path'),',') s
              ) THEN 0 ELSE 1 END, 1"#;

        let tokens = parser
            .maybe_replace_tokens(sql)
            .expect("failed to parse sql");
        // Should have the embedded subqueries replaced and the rest preserved
        assert!(tokens.len() > 0);

        // Both embedded subqueries collapse to the '' stub from the grafana
        // blacklist entry; everything around them is untouched.
        let expected_sql = r#"SELECT
        CASE WHEN
              quote_ident(table_schema) IN (
              '')
          THEN quote_ident(table_name)
          ELSE quote_ident(table_schema) || '.' || quote_ident(table_name)
        END AS "table"
        FROM information_schema.tables
        WHERE quote_ident(table_schema) NOT IN ('information_schema',
                                 'pg_catalog',
                                 '_timescaledb_cache',
                                 '_timescaledb_catalog',
                                 '_timescaledb_internal',
                                 '_timescaledb_config',
                                 'timescaledb_information',
                                 'timescaledb_experimental')
        ORDER BY CASE WHEN
              quote_ident(table_schema) IN (
              ''
              ) THEN 0 ELSE 1 END, 1"#;

        let expected_tokens = Parser::new(&PostgreSqlDialect {})
            .try_with_sql(expected_sql)
            .unwrap()
            .into_tokens();

        // Compare token values (ignoring spans and whitespace)
        let actual_tokens: Vec<_> = tokens
            .iter()
            .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon))
            .collect();
        let expected_token_values: Vec<_> = expected_tokens
            .iter()
            .map(|t| &t.token)
            .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon))
            .collect();

        assert_eq!(actual_tokens, expected_token_values);
    }
604}