1use std::sync::Arc;
2
3use datafusion::sql::sqlparser::ast::Statement;
4use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
5use datafusion::sql::sqlparser::parser::Parser;
6use datafusion::sql::sqlparser::parser::ParserError;
7use datafusion::sql::sqlparser::tokenizer::Token;
8use datafusion::sql::sqlparser::tokenizer::TokenWithSpan;
9
10use super::rules::AliasDuplicatedProjectionRewrite;
11use super::rules::CurrentUserVariableToSessionUserFunctionCall;
12use super::rules::FixArrayLiteral;
13use super::rules::FixCollate;
14use super::rules::FixVersionColumnName;
15use super::rules::PrependUnqualifiedPgTableName;
16use super::rules::RemoveQualifier;
17use super::rules::RemoveSubqueryFromProjection;
18use super::rules::RemoveUnsupportedTypes;
19use super::rules::ResolveUnqualifiedIdentifer;
20use super::rules::RewriteArrayAnyAllOperation;
21use super::rules::RewriteRegclassCastToSubquery;
22use super::rules::SqlStatementRewriteRule;
23
/// Token-level query substitutions applied before parsing, as
/// `(query to detect, replacement query)` pairs.
///
/// `PostgresCompatibilityParser::new` tokenizes both sides with the PostgreSQL
/// dialect and drops whitespace tokens and semicolons. Matching is therefore
/// insensitive to formatting. A `$N` placeholder in a pattern matches any single
/// input token, such as a literal OID sent by the client.
///
/// Every replacement selects typed `NULL` columns `WHERE false` (or, for the last
/// entry, a single empty string literal), giving clients a result of the expected
/// shape without executing the original catalog query.
/// NOTE(review): presumably these originals use constructs the engine cannot run — confirm.
const BLACKLIST_SQL_MAPPING: &[(&str, &str)] = &[
    // Foreign-key listing (pg_constraint joined to pg_class / pg_namespace for
    // parent and child): replaced by an empty six-column TEXT result.
    (
        "SELECT s_p.nspname AS parentschema,
 t_p.relname AS parenttable,
 unnest((
 select
 array_agg(attname ORDER BY i)
 from
 (select unnest(confkey) as attnum, generate_subscripts(confkey, 1) as i) x
 JOIN pg_catalog.pg_attribute c USING(attnum)
 WHERE c.attrelid = fk.confrelid
 )) AS parentcolumn,
 s_c.nspname AS childschema,
 t_c.relname AS childtable,
 unnest((
 select
 array_agg(attname ORDER BY i)
 from
 (select unnest(conkey) as attnum, generate_subscripts(conkey, 1) as i) x
 JOIN pg_catalog.pg_attribute c USING(attnum)
 WHERE c.attrelid = fk.conrelid
 )) AS childcolumn
 FROM pg_catalog.pg_constraint fk
 JOIN pg_catalog.pg_class t_p ON t_p.oid = fk.confrelid
 JOIN pg_catalog.pg_namespace s_p ON s_p.oid = t_p.relnamespace
 JOIN pg_catalog.pg_class t_c ON t_c.oid = fk.conrelid
 JOIN pg_catalog.pg_namespace s_c ON s_c.oid = t_c.relnamespace
 WHERE fk.contype = 'f'",
        "SELECT
 NULL::TEXT AS parentschema,
 NULL::TEXT AS parenttable,
 NULL::TEXT AS parentcolumn,
 NULL::TEXT AS childschema,
 NULL::TEXT AS childtable,
 NULL::TEXT AS childcolumn
 WHERE false",
    ),
    // User-defined type listing outside pg_catalog / information_schema:
    // replaced by an empty (schema_name, type_name) result.
    (
        "SELECT n.nspname schema_name,
 t.typname type_name
 FROM pg_catalog.pg_type t
 INNER JOIN pg_catalog.pg_namespace n
 ON n.oid = t.typnamespace
 WHERE ( t.typrelid = 0 -- non-composite types
 OR ( -- composite type, but not a table
 SELECT c.relkind = 'c'
 FROM pg_catalog.pg_class c
 WHERE c.oid = t.typrelid
 )
 )
 AND NOT EXISTS( -- ignore array types
 SELECT 1
 FROM pg_catalog.pg_type el
 WHERE el.oid = t.typelem AND el.typarray = t.oid
 )
 AND n.nspname <> 'pg_catalog'
 AND n.nspname <> 'information_schema'
 ORDER BY 1, 2;",
        "SELECT NULL::TEXT AS schema_name, NULL::TEXT AS type_name WHERE false",
    ),
    // Row-level-security policies of one relation (`$1` is the relation OID):
    // replaced by an empty six-column TEXT result.
    (
        "SELECT pol.polname, pol.polpermissive,
 CASE WHEN pol.polroles = '{0}' THEN NULL ELSE pg_catalog.array_to_string(array(select rolname from pg_catalog.pg_roles where oid = any (pol.polroles) order by 1),',') END,
 pg_catalog.pg_get_expr(pol.polqual, pol.polrelid),
 pg_catalog.pg_get_expr(pol.polwithcheck, pol.polrelid),
 CASE pol.polcmd
 WHEN 'r' THEN 'SELECT'
 WHEN 'a' THEN 'INSERT'
 WHEN 'w' THEN 'UPDATE'
 WHEN 'd' THEN 'DELETE'
 END AS cmd
 FROM pg_catalog.pg_policy pol
 WHERE pol.polrelid = $1 ORDER BY 1;",
        "SELECT
 NULL::TEXT AS polname,
 NULL::TEXT AS polpermissive,
 NULL::TEXT AS array_to_string,
 NULL::TEXT AS pg_get_expr_1,
 NULL::TEXT AS pg_get_expr_2,
 NULL::TEXT AS cmd
 WHERE false",
    ),
    // Extended statistics objects of one relation (`$1` is the relation OID):
    // replaced by an empty result with the same column names.
    (
        "SELECT oid, stxrelid::pg_catalog.regclass, stxnamespace::pg_catalog.regnamespace::pg_catalog.text AS nsp, stxname,
 pg_catalog.pg_get_statisticsobjdef_columns(oid) AS columns,
 'd' = any(stxkind) AS ndist_enabled,
 'f' = any(stxkind) AS deps_enabled,
 'm' = any(stxkind) AS mcv_enabled,
 stxstattarget
 FROM pg_catalog.pg_statistic_ext
 WHERE stxrelid = $1
 ORDER BY nsp, stxname;",
        "SELECT
 NULL::INT AS oid,
 NULL::TEXT AS stxrelid,
 NULL::TEXT AS nsp,
 NULL::TEXT AS stxname,
 NULL::TEXT AS columns,
 NULL::BOOLEAN AS ndist_enabled,
 NULL::BOOLEAN AS deps_enabled,
 NULL::BOOLEAN AS mcv_enabled,
 NULL::TEXT AS stxstattarget
 WHERE false",
    ),
    // Publications that include one relation (`$1` is the relation OID; each
    // occurrence is matched independently): replaced by an empty three-column result.
    (
        "SELECT pubname
 , NULL
 , NULL
 FROM pg_catalog.pg_publication p
 JOIN pg_catalog.pg_publication_namespace pn ON p.oid = pn.pnpubid
 JOIN pg_catalog.pg_class pc ON pc.relnamespace = pn.pnnspid
 WHERE pc.oid = $1 and pg_catalog.pg_relation_is_publishable($1)
 UNION
 SELECT pubname
 , pg_get_expr(pr.prqual, c.oid)
 , (CASE WHEN pr.prattrs IS NOT NULL THEN
 (SELECT string_agg(attname, ', ')
 FROM pg_catalog.generate_series(0, pg_catalog.array_upper(pr.prattrs::pg_catalog.int2[], 1)) s,
 pg_catalog.pg_attribute
 WHERE attrelid = pr.prrelid AND attnum = prattrs[s])
 ELSE NULL END) FROM pg_catalog.pg_publication p
 JOIN pg_catalog.pg_publication_rel pr ON p.oid = pr.prpubid
 JOIN pg_catalog.pg_class c ON c.oid = pr.prrelid
 WHERE pr.prrelid = $1
 UNION
 SELECT pubname
 , NULL
 , NULL
 FROM pg_catalog.pg_publication p
 WHERE p.puballtables AND pg_catalog.pg_relation_is_publishable($1)
 ORDER BY 1;",
        "SELECT
 NULL::TEXT AS pubname,
 NULL::TEXT AS _1,
 NULL::TEXT AS _2
 WHERE false",
    ),
    // Subquery that expands `search_path` into schema names. It usually appears
    // nested, e.g. `IN (<this>)`, so it is replaced by a single `''` literal and
    // the surrounding query is kept (see `test_partial_match`).
    (
        r#"SELECT
 CASE WHEN trim(s[i]) = '"$user"' THEN user ELSE trim(s[i]) END
 FROM
 generate_series(
 array_lower(string_to_array(current_setting('search_path'),','),1),
 array_upper(string_to_array(current_setting('search_path'),','),1)
 ) as i,
 string_to_array(current_setting('search_path'),',') s"#,
        "''",
    ),
];
179
/// Parser front-end that adapts SQL sent by PostgreSQL clients.
///
/// `parse` runs three steps:
/// 1. Token-level replacement of known queries from `BLACKLIST_SQL_MAPPING`.
/// 2. Parsing with the PostgreSQL dialect.
/// 3. Applying the configured rewrite rules to every resulting statement.
#[derive(Debug)]
pub struct PostgresCompatibilityParser {
    /// `(pattern, replacement)` token sequences built from `BLACKLIST_SQL_MAPPING`,
    /// with whitespace and semicolon tokens removed from both sides.
    blacklist: Vec<(Vec<Token>, Vec<Token>)>,
    /// Statement rewrite rules, applied one after another in vector order by `rewrite`.
    rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>>,
}
190
191impl Default for PostgresCompatibilityParser {
192 fn default() -> Self {
193 Self::new()
194 }
195}
196
197impl PostgresCompatibilityParser {
198 pub fn new() -> Self {
199 let mut mapping = Vec::with_capacity(BLACKLIST_SQL_MAPPING.len());
200
201 for (sql_from, sql_to) in BLACKLIST_SQL_MAPPING {
202 mapping.push((
203 Parser::new(&PostgreSqlDialect {})
204 .try_with_sql(sql_from)
205 .unwrap()
206 .into_tokens()
207 .into_iter()
208 .map(|t| t.token)
209 .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon))
210 .collect(),
211 Parser::new(&PostgreSqlDialect {})
212 .try_with_sql(sql_to)
213 .unwrap()
214 .into_tokens()
215 .into_iter()
216 .map(|t| t.token)
217 .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon))
218 .collect(),
219 ));
220 }
221
222 Self {
223 blacklist: mapping,
224 rewrite_rules: vec![
225 Arc::new(AliasDuplicatedProjectionRewrite),
228 Arc::new(ResolveUnqualifiedIdentifer),
229 Arc::new(RewriteArrayAnyAllOperation),
230 Arc::new(PrependUnqualifiedPgTableName),
231 Arc::new(RemoveQualifier),
232 Arc::new(RewriteRegclassCastToSubquery::new()),
233 Arc::new(RemoveUnsupportedTypes::new()),
234 Arc::new(FixArrayLiteral),
235 Arc::new(CurrentUserVariableToSessionUserFunctionCall),
236 Arc::new(FixCollate),
237 Arc::new(RemoveSubqueryFromProjection),
238 Arc::new(FixVersionColumnName),
239 ],
240 }
241 }
242
243 fn maybe_replace_tokens(&self, input: &str) -> Result<Vec<Token>, ParserError> {
245 let parser = Parser::new(&PostgreSqlDialect {});
246 let tokens = parser.try_with_sql(input)?.into_tokens();
247
248 let filtered_tokens: Vec<Token> = tokens
251 .iter()
252 .map(|t| t.token.clone())
253 .filter(|t| !matches!(t, Token::Whitespace(_)))
254 .collect();
255
256 if filtered_tokens.is_empty() {
258 return Ok(Vec::new());
259 }
260
261 let mut result = Vec::new();
263 let mut i = 0;
264
265 while i < filtered_tokens.len() {
266 if matches!(&filtered_tokens[i], Token::SemiColon) {
268 result.push(filtered_tokens[i].clone());
269 i += 1;
270 continue;
271 }
272
273 let mut matched = false;
275 for (pattern, replacement) in &self.blacklist {
276 if pattern.is_empty() {
277 continue;
278 }
279
280 let mut j = 0;
282 let mut pattern_idx = 0;
283 while i + j < filtered_tokens.len() && pattern_idx < pattern.len() {
284 if matches!(&filtered_tokens[i + j], Token::SemiColon) {
286 j += 1;
287 continue;
288 }
289
290 match &pattern[pattern_idx] {
291 Token::Placeholder(_) => {
292 pattern_idx += 1;
294 j += 1;
295 }
296 _ => {
297 if filtered_tokens[i + j] != pattern[pattern_idx] {
298 break;
299 }
300 pattern_idx += 1;
301 j += 1;
302 }
303 }
304 }
305
306 if pattern_idx == pattern.len() {
308 result.extend(replacement.iter().cloned());
310 i += j;
312 matched = true;
313 break;
314 }
315 }
316
317 if !matched {
318 result.push(filtered_tokens[i].clone());
320 i += 1;
321 }
322 }
323
324 Ok(result)
325 }
326
327 fn parse_tokens(&self, tokens: Vec<Token>) -> Result<Vec<Statement>, ParserError> {
328 let parser = Parser::new(&PostgreSqlDialect {});
329 let tokens_with_spans: Vec<TokenWithSpan> = tokens
331 .into_iter()
332 .map(|token| TokenWithSpan {
333 token,
334 span: datafusion::sql::sqlparser::tokenizer::Span::empty(),
335 })
336 .collect();
337 parser
338 .with_tokens_with_locations(tokens_with_spans)
339 .parse_statements()
340 }
341
342 pub fn parse(&self, input: &str) -> Result<Vec<Statement>, ParserError> {
343 let tokens = self.maybe_replace_tokens(input)?;
344 let statements = self.parse_tokens(tokens)?;
345
346 let statements: Vec<_> = statements.into_iter().map(|s| self.rewrite(s)).collect();
347 Ok(statements)
348 }
349
350 pub fn rewrite(&self, mut s: Statement) -> Statement {
351 for rule in &self.rewrite_rules {
352 s = rule.rewrite(s);
353 }
354
355 s
356 }
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    /// Tokenizes `sql` with the PostgreSQL dialect and drops whitespace tokens
    /// (semicolons are kept).
    fn tokens_without_whitespace(sql: &str) -> Vec<Token> {
        Parser::new(&PostgreSqlDialect {})
            .try_with_sql(sql)
            .expect("failed to tokenize sql")
            .into_tokens()
            .into_iter()
            .map(|t| t.token)
            .filter(|t| !matches!(t, Token::Whitespace(_)))
            .collect()
    }

    /// Removes whitespace and semicolon tokens so token lists compare by content only.
    fn significant(tokens: Vec<Token>) -> Vec<Token> {
        tokens
            .into_iter()
            .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon))
            .collect()
    }

    /// Significant tokens of `sql` as written.
    fn significant_tokens(sql: &str) -> Vec<Token> {
        significant(tokens_without_whitespace(sql))
    }

    /// Significant tokens of `sql` after blacklist replacement.
    fn replaced_tokens(parser: &PostgresCompatibilityParser, sql: &str) -> Vec<Token> {
        significant(
            parser
                .maybe_replace_tokens(sql)
                .expect("failed to parse sql"),
        )
    }

    #[test]
    fn test_full_match() {
        let parser = PostgresCompatibilityParser::new();

        let sql = "SELECT pol.polname, pol.polpermissive,
    CASE WHEN pol.polroles = '{0}' THEN NULL ELSE pg_catalog.array_to_string(array(select rolname from pg_catalog.pg_roles where oid = any (pol.polroles) order by 1),',') END,
    pg_catalog.pg_get_expr(pol.polqual, pol.polrelid),
    pg_catalog.pg_get_expr(pol.polwithcheck, pol.polrelid),
    CASE pol.polcmd
    WHEN 'r' THEN 'SELECT'
    WHEN 'a' THEN 'INSERT'
    WHEN 'w' THEN 'UPDATE'
    WHEN 'd' THEN 'DELETE'
    END AS cmd
    FROM pg_catalog.pg_policy pol
    WHERE pol.polrelid = '16384' ORDER BY 1;";

        let expected_sql = r#"SELECT
    NULL::TEXT AS polname,
    NULL::TEXT AS polpermissive,
    NULL::TEXT AS array_to_string,
    NULL::TEXT AS pg_get_expr_1,
    NULL::TEXT AS pg_get_expr_2,
    NULL::TEXT AS cmd
    WHERE false"#;

        assert_eq!(
            replaced_tokens(&parser, sql),
            significant_tokens(expected_sql)
        );

        let sql = "SELECT n.nspname schema_name,
    t.typname type_name
    FROM pg_catalog.pg_type t
    INNER JOIN pg_catalog.pg_namespace n
    ON n.oid = t.typnamespace
    WHERE ( t.typrelid = 0 -- non-composite types
    OR ( -- composite type, but not a table
    SELECT c.relkind = 'c'
    FROM pg_catalog.pg_class c
    WHERE c.oid = t.typrelid
    )
    )
    AND NOT EXISTS( -- ignore array types
    SELECT 1
    FROM pg_catalog.pg_type el
    WHERE el.oid = t.typelem AND el.typarray = t.oid
    )
    AND n.nspname <> 'pg_catalog'
    AND n.nspname <> 'information_schema'
    ORDER BY 1, 2";

        let expected_sql =
            r#"SELECT NULL::TEXT AS schema_name, NULL::TEXT AS type_name WHERE false"#;

        assert_eq!(
            replaced_tokens(&parser, sql),
            significant_tokens(expected_sql)
        );

        let sql = "SELECT pubname
    , NULL
    , NULL
    FROM pg_catalog.pg_publication p
    JOIN pg_catalog.pg_publication_namespace pn ON p.oid = pn.pnpubid
    JOIN pg_catalog.pg_class pc ON pc.relnamespace = pn.pnnspid
    WHERE pc.oid ='16384' and pg_catalog.pg_relation_is_publishable('16384')
    UNION
    SELECT pubname
    , pg_get_expr(pr.prqual, c.oid)
    , (CASE WHEN pr.prattrs IS NOT NULL THEN
    (SELECT string_agg(attname, ', ')
    FROM pg_catalog.generate_series(0, pg_catalog.array_upper(pr.prattrs::pg_catalog.int2[], 1)) s,
    pg_catalog.pg_attribute
    WHERE attrelid = pr.prrelid AND attnum = prattrs[s])
    ELSE NULL END) FROM pg_catalog.pg_publication p
    JOIN pg_catalog.pg_publication_rel pr ON p.oid = pr.prpubid
    JOIN pg_catalog.pg_class c ON c.oid = pr.prrelid
    WHERE pr.prrelid = '16384'
    UNION
    SELECT pubname
    , NULL
    , NULL
    FROM pg_catalog.pg_publication p
    WHERE p.puballtables AND pg_catalog.pg_relation_is_publishable('16384')
    ORDER BY 1;";

        let expected_sql = r#"SELECT
    NULL::TEXT AS pubname,
    NULL::TEXT AS _1,
    NULL::TEXT AS _2
    WHERE false"#;

        assert_eq!(
            replaced_tokens(&parser, sql),
            significant_tokens(expected_sql)
        );
    }

    #[test]
    fn test_empty_query() {
        let parser = PostgresCompatibilityParser::new();
        let result = parser.parse(" ").expect("failed to parse sql");
        assert!(result.is_empty());

        let result = parser.parse("").expect("failed to parse sql");
        assert!(result.is_empty());

        let result = parser.parse(";").expect("failed to parse sql");
        assert!(result.is_empty());
    }

    #[test]
    fn test_partial_match() {
        let parser = PostgresCompatibilityParser::new();

        let sql = r#"SELECT
    CASE WHEN
    quote_ident(table_schema) IN (
    SELECT
    CASE WHEN trim(s[i]) = '"$user"' THEN user ELSE trim(s[i]) END
    FROM
    generate_series(
    array_lower(string_to_array(current_setting('search_path'),','),1),
    array_upper(string_to_array(current_setting('search_path'),','),1)
    ) as i,
    string_to_array(current_setting('search_path'),',') s
    )
    THEN quote_ident(table_name)
    ELSE quote_ident(table_schema) || '.' || quote_ident(table_name)
    END AS "table"
    FROM information_schema.tables
    WHERE quote_ident(table_schema) NOT IN ('information_schema',
    'pg_catalog',
    '_timescaledb_cache',
    '_timescaledb_catalog',
    '_timescaledb_internal',
    '_timescaledb_config',
    'timescaledb_information',
    'timescaledb_experimental')
    ORDER BY CASE WHEN
    quote_ident(table_schema) IN (
    SELECT
    CASE WHEN trim(s[i]) = '"$user"' THEN user ELSE trim(s[i]) END
    FROM
    generate_series(
    array_lower(string_to_array(current_setting('search_path'),','),1),
    array_upper(string_to_array(current_setting('search_path'),','),1)
    ) as i,
    string_to_array(current_setting('search_path'),',') s
    ) THEN 0 ELSE 1 END, 1"#;

        let actual_tokens = replaced_tokens(&parser, sql);
        assert!(!actual_tokens.is_empty());

        let expected_sql = r#"SELECT
    CASE WHEN
    quote_ident(table_schema) IN (
    '')
    THEN quote_ident(table_name)
    ELSE quote_ident(table_schema) || '.' || quote_ident(table_name)
    END AS "table"
    FROM information_schema.tables
    WHERE quote_ident(table_schema) NOT IN ('information_schema',
    'pg_catalog',
    '_timescaledb_cache',
    '_timescaledb_catalog',
    '_timescaledb_internal',
    '_timescaledb_config',
    'timescaledb_information',
    'timescaledb_experimental')
    ORDER BY CASE WHEN
    quote_ident(table_schema) IN (
    ''
    ) THEN 0 ELSE 1 END, 1"#;

        assert_eq!(actual_tokens, significant_tokens(expected_sql));
    }

    #[test]
    fn test_statement_separators_survive_replacement() {
        let parser = PostgresCompatibilityParser::new();

        let sql = r#"SELECT 1;
    SELECT CASE WHEN trim(s[i]) = '"$user"' THEN user ELSE trim(s[i]) END
    FROM
    generate_series(
    array_lower(string_to_array(current_setting('search_path'),','),1),
    array_upper(string_to_array(current_setting('search_path'),','),1)
    ) as i,
    string_to_array(current_setting('search_path'),',') s;
    SELECT 2"#;

        let actual = parser
            .maybe_replace_tokens(sql)
            .expect("failed to parse sql");

        assert_eq!(actual, tokens_without_whitespace("SELECT 1; ''; SELECT 2"));
    }

    #[test]
    fn test_pattern_does_not_match_across_semicolon() {
        let parser = PostgresCompatibilityParser::new();

        // The search_path blacklist query, split into two statements by a `;`.
        let sql = r#"SELECT CASE WHEN trim(s[i]) = '"$user"' THEN user ELSE trim(s[i]) END;
    FROM
    generate_series(
    array_lower(string_to_array(current_setting('search_path'),','),1),
    array_upper(string_to_array(current_setting('search_path'),','),1)
    ) as i,
    string_to_array(current_setting('search_path'),',') s"#;

        let actual = parser
            .maybe_replace_tokens(sql)
            .expect("failed to parse sql");

        assert!(actual.contains(&Token::SemiColon));
        assert_eq!(actual, tokens_without_whitespace(sql));
    }
}