datafusion_pg_catalog/sql/
parser.rs

1use std::sync::Arc;
2
3use datafusion::sql::sqlparser::ast::Statement;
4use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
5use datafusion::sql::sqlparser::parser::Parser;
6use datafusion::sql::sqlparser::parser::ParserError;
7use datafusion::sql::sqlparser::tokenizer::Token;
8use datafusion::sql::sqlparser::tokenizer::TokenWithSpan;
9
10use crate::sql::rules::FixVersionColumnName;
11
12use super::rules::AliasDuplicatedProjectionRewrite;
13use super::rules::CurrentUserVariableToSessionUserFunctionCall;
14use super::rules::FixArrayLiteral;
15use super::rules::FixCollate;
16use super::rules::PrependUnqualifiedPgTableName;
17use super::rules::RemoveQualifier;
18use super::rules::RemoveSubqueryFromProjection;
19use super::rules::RemoveUnsupportedTypes;
20use super::rules::ResolveUnqualifiedIdentifer;
21use super::rules::RewriteArrayAnyAllOperation;
22use super::rules::SqlStatementRewriteRule;
23
/// Queries that are swapped for a hand-written replacement instead of being
/// rewritten rule by rule.
///
/// Each entry is a pair `(blacklisted SQL, replacement SQL)`.
/// `PostgresCompatibilityParser::new` tokenizes the first element and parses
/// the second into a statement. An incoming query matches an entry when its
/// tokens, ignoring whitespace/comments and semicolons, equal the blacklisted
/// tokens one by one; a placeholder such as `$1` in the blacklisted SQL
/// matches any single token.
///
/// Every replacement below is a `SELECT ... WHERE false` over typed `NULL`
/// columns, i.e. it yields the column list but no rows.
const BLACKLIST_SQL_MAPPING: &[(&str, &str)] = &[
    // pgcli startup query
    (
"SELECT s_p.nspname AS parentschema,
                               t_p.relname AS parenttable,
                               unnest((
                                select
                                    array_agg(attname ORDER BY i)
                                from
                                    (select unnest(confkey) as attnum, generate_subscripts(confkey, 1) as i) x
                                    JOIN pg_catalog.pg_attribute c USING(attnum)
                                    WHERE c.attrelid = fk.confrelid
                                )) AS parentcolumn,
                               s_c.nspname AS childschema,
                               t_c.relname AS childtable,
                               unnest((
                                select
                                    array_agg(attname ORDER BY i)
                                from
                                    (select unnest(conkey) as attnum, generate_subscripts(conkey, 1) as i) x
                                    JOIN pg_catalog.pg_attribute c USING(attnum)
                                    WHERE c.attrelid = fk.conrelid
                                )) AS childcolumn
                        FROM pg_catalog.pg_constraint fk
                        JOIN pg_catalog.pg_class      t_p ON t_p.oid = fk.confrelid
                        JOIN pg_catalog.pg_namespace  s_p ON s_p.oid = t_p.relnamespace
                        JOIN pg_catalog.pg_class      t_c ON t_c.oid = fk.conrelid
                        JOIN pg_catalog.pg_namespace  s_c ON s_c.oid = t_c.relnamespace
                        WHERE fk.contype = 'f'",
"SELECT
   NULL::TEXT AS parentschema,
   NULL::TEXT AS parenttable,
   NULL::TEXT AS parentcolumn,
   NULL::TEXT AS childschema,
   NULL::TEXT AS childtable,
   NULL::TEXT AS childcolumn
 WHERE false"),

    // pgcli startup query
    (
"SELECT n.nspname schema_name,
                                       t.typname type_name
                                FROM   pg_catalog.pg_type t
                                       INNER JOIN pg_catalog.pg_namespace n
                                          ON n.oid = t.typnamespace
                                WHERE ( t.typrelid = 0  -- non-composite types
                                        OR (  -- composite type, but not a table
                                              SELECT c.relkind = 'c'
                                              FROM pg_catalog.pg_class c
                                              WHERE c.oid = t.typrelid
                                            )
                                      )
                                      AND NOT EXISTS( -- ignore array types
                                            SELECT  1
                                            FROM    pg_catalog.pg_type el
                                            WHERE   el.oid = t.typelem AND el.typarray = t.oid
                                          )
                                      AND n.nspname <> 'pg_catalog'
                                      AND n.nspname <> 'information_schema'
                                ORDER BY 1, 2;",
"SELECT NULL::TEXT AS schema_name, NULL::TEXT AS type_name WHERE false"
    ),

    // psql \d <table> queries; `$1` stands for the table oid sent by the client
    (
"SELECT pol.polname, pol.polpermissive,
          CASE WHEN pol.polroles = '{0}' THEN NULL ELSE pg_catalog.array_to_string(array(select rolname from pg_catalog.pg_roles where oid = any (pol.polroles) order by 1),',') END,
          pg_catalog.pg_get_expr(pol.polqual, pol.polrelid),
          pg_catalog.pg_get_expr(pol.polwithcheck, pol.polrelid),
          CASE pol.polcmd
            WHEN 'r' THEN 'SELECT'
            WHEN 'a' THEN 'INSERT'
            WHEN 'w' THEN 'UPDATE'
            WHEN 'd' THEN 'DELETE'
            END AS cmd
        FROM pg_catalog.pg_policy pol
        WHERE pol.polrelid = $1 ORDER BY 1;",
"SELECT
   NULL::TEXT AS polname,
   NULL::TEXT AS polpermissive,
   NULL::TEXT AS array_to_string,
   NULL::TEXT AS pg_get_expr_1,
   NULL::TEXT AS pg_get_expr_2,
   NULL::TEXT AS cmd
 WHERE false"
    ),

    (
"SELECT oid, stxrelid::pg_catalog.regclass, stxnamespace::pg_catalog.regnamespace::pg_catalog.text AS nsp, stxname,
        pg_catalog.pg_get_statisticsobjdef_columns(oid) AS columns,
          'd' = any(stxkind) AS ndist_enabled,
          'f' = any(stxkind) AS deps_enabled,
          'm' = any(stxkind) AS mcv_enabled,
        stxstattarget
        FROM pg_catalog.pg_statistic_ext
        WHERE stxrelid = $1
        ORDER BY nsp, stxname;",
"SELECT
   NULL::INT AS oid,
   NULL::TEXT AS stxrelid,
   NULL::TEXT AS nsp,
   NULL::TEXT AS stxname,
   NULL::TEXT AS columns,
   NULL::BOOLEAN AS ndist_enabled,
   NULL::BOOLEAN AS deps_enabled,
   NULL::BOOLEAN AS mcv_enabled,
   NULL::TEXT AS stxstattarget
 WHERE false"
    ),

    (
"SELECT pubname
             , NULL
             , NULL
        FROM pg_catalog.pg_publication p
             JOIN pg_catalog.pg_publication_namespace pn ON p.oid = pn.pnpubid
             JOIN pg_catalog.pg_class pc ON pc.relnamespace = pn.pnnspid
        WHERE pc.oid = $1 and pg_catalog.pg_relation_is_publishable($1)
        UNION
        SELECT pubname
             , pg_get_expr(pr.prqual, c.oid)
             , (CASE WHEN pr.prattrs IS NOT NULL THEN
                 (SELECT string_agg(attname, ', ')
                   FROM pg_catalog.generate_series(0, pg_catalog.array_upper(pr.prattrs::pg_catalog.int2[], 1)) s,
                        pg_catalog.pg_attribute
                  WHERE attrelid = pr.prrelid AND attnum = prattrs[s])
                ELSE NULL END) FROM pg_catalog.pg_publication p
             JOIN pg_catalog.pg_publication_rel pr ON p.oid = pr.prpubid
             JOIN pg_catalog.pg_class c ON c.oid = pr.prrelid
        WHERE pr.prrelid = $1
        UNION
        SELECT pubname
             , NULL
             , NULL
        FROM pg_catalog.pg_publication p
        WHERE p.puballtables AND pg_catalog.pg_relation_is_publishable($1)
        ORDER BY 1;",
"SELECT
   NULL::TEXT AS pubname,
   NULL::TEXT AS _1,
   NULL::TEXT AS _2
 WHERE false"
    ),
];
168
/// A parser with Postgres compatibility for DataFusion.
///
/// This parser tries its best to rewrite Postgres SQL into a form that
/// DataFusion supports. It also maintains a blacklist that replaces a
/// statement with a similar one when a proper rewrite is not worth the
/// effort for now.
#[derive(Debug)]
pub struct PostgresCompatibilityParser {
    /// Blacklisted queries as token sequences (whitespace and semicolons
    /// removed), each paired with the statement returned in their place.
    blacklist: Vec<(Vec<Token>, Statement)>,
    /// Rewrite rules applied, in order, to every statement returned by `parse`.
    rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>>,
}
179
180impl Default for PostgresCompatibilityParser {
181    fn default() -> Self {
182        Self::new()
183    }
184}
185
186impl PostgresCompatibilityParser {
187    pub fn new() -> Self {
188        let mut mapping = Vec::with_capacity(BLACKLIST_SQL_MAPPING.len());
189
190        for (sql_from, sql_to) in BLACKLIST_SQL_MAPPING {
191            mapping.push((
192                Parser::new(&PostgreSqlDialect {})
193                    .try_with_sql(sql_from)
194                    .unwrap()
195                    .into_tokens()
196                    .into_iter()
197                    .map(|t| t.token)
198                    .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon))
199                    .collect(),
200                Parser::new(&PostgreSqlDialect {})
201                    .try_with_sql(sql_to)
202                    .unwrap()
203                    .parse_statement()
204                    .unwrap(),
205            ));
206        }
207
208        Self {
209            blacklist: mapping,
210            rewrite_rules: vec![
211                // make sure blacklist based rewriter it on the top to prevent sql
212                // being rewritten from other rewriters
213                Arc::new(AliasDuplicatedProjectionRewrite),
214                Arc::new(ResolveUnqualifiedIdentifer),
215                Arc::new(RewriteArrayAnyAllOperation),
216                Arc::new(PrependUnqualifiedPgTableName),
217                Arc::new(RemoveQualifier),
218                Arc::new(RemoveUnsupportedTypes::new()),
219                Arc::new(FixArrayLiteral),
220                Arc::new(CurrentUserVariableToSessionUserFunctionCall),
221                Arc::new(FixCollate),
222                Arc::new(RemoveSubqueryFromProjection),
223                Arc::new(FixVersionColumnName),
224            ],
225        }
226    }
227
228    /// return statement if matched
229    fn parse_and_replace(&self, input: &str) -> Result<MatchResult, ParserError> {
230        let parser = Parser::new(&PostgreSqlDialect {});
231        let tokens = parser.try_with_sql(input)?.into_tokens();
232
233        let tokens_without_whitespace = tokens
234            .iter()
235            .filter(|t| !matches!(t.token, Token::Whitespace(_) | Token::SemiColon))
236            .collect::<Vec<_>>();
237
238        for (blacklisted_sql_tokens, replacement) in &self.blacklist {
239            if blacklisted_sql_tokens.len() == tokens_without_whitespace.len() {
240                let matches = blacklisted_sql_tokens
241                    .iter()
242                    .zip(tokens_without_whitespace.iter())
243                    .all(|(a, b)| {
244                        if matches!(a, Token::Placeholder(_)) {
245                            true
246                        } else {
247                            *a == b.token
248                        }
249                    });
250                if matches {
251                    return Ok(MatchResult::Matches(Box::new(replacement.clone())));
252                }
253            } else {
254                continue;
255            }
256        }
257
258        Ok(MatchResult::Unmatches(tokens))
259    }
260
261    fn parse_tokens(&self, tokens: Vec<TokenWithSpan>) -> Result<Vec<Statement>, ParserError> {
262        let parser = Parser::new(&PostgreSqlDialect {});
263        parser.with_tokens_with_locations(tokens).parse_statements()
264    }
265
266    pub fn parse(&self, input: &str) -> Result<Vec<Statement>, ParserError> {
267        let statements = match self.parse_and_replace(input)? {
268            MatchResult::Matches(statement) => vec![*statement],
269            MatchResult::Unmatches(tokens) => self.parse_tokens(tokens)?,
270        };
271
272        let statements = statements.into_iter().map(|s| self.rewrite(s)).collect();
273
274        Ok(statements)
275    }
276
277    pub fn rewrite(&self, mut s: Statement) -> Statement {
278        for rule in &self.rewrite_rules {
279            s = rule.rewrite(s);
280        }
281
282        s
283    }
284}
285
/// Outcome of checking an input against the parser's blacklist.
pub(crate) enum MatchResult {
    /// The input matched a blacklisted query; holds the replacement statement.
    Matches(Box<Statement>),
    /// No blacklist entry matched; holds the tokens of the input so it can be
    /// parsed without tokenizing again.
    Unmatches(Vec<TokenWithSpan>),
}
290
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that blacklisted queries are recognised despite differences in
    /// whitespace, bound parameter values and a trailing semicolon.
    #[test]
    fn test_sql_mapping() {
        // Same as the blacklisted pg_policy query, but indented differently
        // and with the literal '16384' where the blacklist has `$1`.
        let sql = "SELECT pol.polname, pol.polpermissive,
              CASE WHEN pol.polroles = '{0}' THEN NULL ELSE pg_catalog.array_to_string(array(select rolname from pg_catalog.pg_roles where oid = any (pol.polroles) order by 1),',') END,
              pg_catalog.pg_get_expr(pol.polqual, pol.polrelid),
              pg_catalog.pg_get_expr(pol.polwithcheck, pol.polrelid),
              CASE pol.polcmd
                WHEN 'r' THEN 'SELECT'
                WHEN 'a' THEN 'INSERT'
                WHEN 'w' THEN 'UPDATE'
                WHEN 'd' THEN 'DELETE'
                END AS cmd
            FROM pg_catalog.pg_policy pol
            WHERE pol.polrelid = '16384' ORDER BY 1;";

        let parser = PostgresCompatibilityParser::new();
        let match_result = parser.parse_and_replace(sql).expect("failed to parse sql");
        assert!(matches!(match_result, MatchResult::Matches(_)));

        // Same as the blacklisted pg_type query, but without the trailing
        // semicolon; the inline `--` comments are kept.
        let sql = "SELECT n.nspname schema_name,
                                       t.typname type_name
                                FROM   pg_catalog.pg_type t
                                       INNER JOIN pg_catalog.pg_namespace n
                                          ON n.oid = t.typnamespace
                                WHERE ( t.typrelid = 0  -- non-composite types
                                        OR (  -- composite type, but not a table
                                              SELECT c.relkind = 'c'
                                              FROM pg_catalog.pg_class c
                                              WHERE c.oid = t.typrelid
                                            )
                                      )
                                      AND NOT EXISTS( -- ignore array types
                                            SELECT  1
                                            FROM    pg_catalog.pg_type el
                                            WHERE   el.oid = t.typelem AND el.typarray = t.oid
                                          )
                                      AND n.nspname <> 'pg_catalog'
                                      AND n.nspname <> 'information_schema'
                                ORDER BY 1, 2";

        let parser = PostgresCompatibilityParser::new();
        let match_result = parser.parse_and_replace(sql).expect("failed to parse sql");
        assert!(matches!(match_result, MatchResult::Matches(_)));
    }
}